diff --git a/bls/signature.go b/bls/signature.go
index 673f24307..83fdb7d2d 100644
--- a/bls/signature.go
+++ b/bls/signature.go
@@ -227,6 +227,14 @@ func (publicKey *PublicKey) FromString(publicKeyString string) (*PublicKey, erro
 	return publicKey, err
 }
 
+func (publicKey *PublicKey) ToAbbreviatedString() string {
+	str := publicKey.ToString()
+	if len(str) <= 8 {
+		return str
+	}
+	return str[:8] + "..." + str[len(str)-8:]
+}
+
 func (publicKey *PublicKey) MarshalJSON() ([]byte, error) {
 	// This is called automatically by the JSON library when converting a
 	// bls.PublicKey to JSON. This is useful when passing a bls.PublicKey
@@ -324,6 +332,14 @@ func (signature *Signature) FromString(signatureString string) (*Signature, erro
 	return signature, nil
 }
 
+func (signature *Signature) ToAbbreviatedString() string {
+	str := signature.ToString()
+	if len(str) <= 8 {
+		return str
+	}
+	return str[:8] + "..." + str[len(str)-8:]
+}
+
 func (signature *Signature) MarshalJSON() ([]byte, error) {
 	// This is called automatically by the JSON library when converting a
 	// bls.Signature to JSON. This is useful when passing a bls.Signature
diff --git a/bls/signature_no_relic.go b/bls/signature_no_relic.go
index b3d472739..0f3988253 100644
--- a/bls/signature_no_relic.go
+++ b/bls/signature_no_relic.go
@@ -84,6 +84,10 @@ func (publicKey *PublicKey) FromString(publicKeyString string) (*PublicKey, erro
 	panic(BLSNoRelicError)
 }
 
+func (publicKey *PublicKey) ToAbbreviatedString() string {
+	panic(BLSNoRelicError)
+}
+
 func (publicKey *PublicKey) MarshalJSON() ([]byte, error) {
 	panic(BLSNoRelicError)
 }
@@ -136,6 +140,10 @@ func (signature *Signature) FromString(signatureString string) (*Signature, erro
 	panic(BLSNoRelicError)
 }
 
+func (signature *Signature) ToAbbreviatedString() string {
+	panic(BLSNoRelicError)
+}
+
 func (signature *Signature) MarshalJSON() ([]byte, error) {
 	panic(BLSNoRelicError)
 }
diff --git a/cmd/config.go b/cmd/config.go
index 3e3a47f97..302ba2eb7 100644
--- a/cmd/config.go
+++ b/cmd/config.go
@@ -49,11 +49,13 @@ type Config struct {
 	PosTimeoutBaseDurationMilliseconds uint64
 
 	// Mempool
-	MempoolBackupIntervalMillis             uint64
-	MaxMempoolPosSizeBytes                  uint64
-	MempoolFeeEstimatorNumMempoolBlocks     uint64
-	MempoolFeeEstimatorNumPastBlocks        uint64
-	AugmentedBlockViewRefreshIntervalMillis uint64
+	MempoolBackupIntervalMillis                uint64
+	MaxMempoolPosSizeBytes                     uint64
+	MempoolFeeEstimatorNumMempoolBlocks        uint64
+	MempoolFeeEstimatorNumPastBlocks           uint64
+	MempoolMaxValidationViewConnects           uint64
+	TransactionValidationRefreshIntervalMillis uint64
+	AugmentedBlockViewRefreshIntervalMillis    uint64
 
 	// Mining
 	MinerPublicKeys []string
@@ -80,7 +82,8 @@ type Config struct {
 	TimeEvents bool
 
 	// State Syncer
-	StateChangeDir string
+	StateChangeDir                 string
+	StateSyncerMempoolTxnSyncLimit uint64
 }
 
 func LoadConfig() *Config {
@@ -130,6 +133,8 @@ func LoadConfig() *Config {
 	config.MaxMempoolPosSizeBytes = viper.GetUint64("max-mempool-pos-size-bytes")
 	config.MempoolFeeEstimatorNumMempoolBlocks = viper.GetUint64("mempool-fee-estimator-num-mempool-blocks")
 	config.MempoolFeeEstimatorNumPastBlocks = viper.GetUint64("mempool-fee-estimator-num-past-blocks")
+	config.MempoolMaxValidationViewConnects = viper.GetUint64("mempool-max-validation-view-connects")
+	config.TransactionValidationRefreshIntervalMillis = viper.GetUint64("transaction-validation-refresh-interval-millis")
 	config.AugmentedBlockViewRefreshIntervalMillis = viper.GetUint64("augmented-block-view-refresh-interval-millis")
 
 	// Peers
@@ -176,6 +181,7 @@ func LoadConfig() *Config {
 
 	// State Syncer
 	config.StateChangeDir = viper.GetString("state-change-dir")
+	config.StateSyncerMempoolTxnSyncLimit = viper.GetUint64("state-syncer-mempool-txn-sync-limit")
 
 	return &config
 }
diff --git a/cmd/node.go b/cmd/node.go
index 4e64870f4..f14c09d87 100644
--- a/cmd/node.go
+++ b/cmd/node.go
@@ -27,12 +27,13 @@ import (
 )
 
 type Node struct {
-	Server   *lib.Server
-	ChainDB  *badger.DB
-	TXIndex  *lib.TXIndex
-	Params   *lib.DeSoParams
-	Config   *Config
-	Postgres *lib.Postgres
+	Server    *lib.Server
+	ChainDB   *badger.DB
+	TXIndex   *lib.TXIndex
+	Params    *lib.DeSoParams
+	Config    *Config
+	Postgres  *lib.Postgres
+	Listeners []net.Listener
 
 	// IsRunning is false when a NewNode is created, set to true on Start(), set to false
 	// after Stop() is called. Mainly used in testing.
@@ -117,8 +118,7 @@ func (node *Node) Start(exitChannels ...*chan struct{}) {
 
 	// This just gets localhost listening addresses on the protocol port.
 	// Such as [{127.0.0.1 18000 } {::1 18000 }], and associated listener structs.
-	listeningAddrs, listeners := GetAddrsToListenOn(node.Config.ProtocolPort)
-	_ = listeningAddrs
+	_, node.Listeners = GetAddrsToListenOn(node.Config.ProtocolPort)
 
 	// If --connect-ips is not passed, we will connect the addresses from
 	// --add-ips, DNSSeeds, and DNSSeedGenerators.
@@ -238,7 +238,7 @@ func (node *Node) Start(exitChannels ...*chan struct{}) {
 	shouldRestart := false
 	node.Server, err, shouldRestart = lib.NewServer(
 		node.Params,
-		listeners,
+		node.Listeners,
 		desoAddrMgr,
 		node.Config.ConnectIPs,
 		node.ChainDB,
@@ -279,9 +279,12 @@ func (node *Node) Start(exitChannels ...*chan struct{}) {
 		node.Config.MempoolBackupIntervalMillis,
 		node.Config.MempoolFeeEstimatorNumMempoolBlocks,
 		node.Config.MempoolFeeEstimatorNumPastBlocks,
+		node.Config.MempoolMaxValidationViewConnects,
+		node.Config.TransactionValidationRefreshIntervalMillis,
 		node.Config.AugmentedBlockViewRefreshIntervalMillis,
 		node.Config.PosBlockProductionIntervalMilliseconds,
 		node.Config.PosTimeoutBaseDurationMilliseconds,
+		node.Config.StateSyncerMempoolTxnSyncLimit,
 	)
 	if err != nil {
 		// shouldRestart can be true if, on the previous run, we did not finish flushing all ancestral
diff --git a/cmd/run.go b/cmd/run.go
index c5ebe87a8..d5394d7cd 100644
--- a/cmd/run.go
+++ b/cmd/run.go
@@ -102,9 +102,14 @@ func SetupRunFlags(cmd *cobra.Command) {
 		"The number of future blocks to break the PoS mempool into when estimating txn fee for the next block.")
 	cmd.PersistentFlags().Uint64("mempool-fee-estimator-num-past-blocks", 50,
 		"The number of past blocks to use when estimating txn fee for the next block from the PoS mempool.")
+	cmd.PersistentFlags().Uint64("mempool-max-validation-view-connects", 10000,
+		"The maximum number of connects that the mempool transaction validation routine will perform.")
+	cmd.PersistentFlags().Uint64("transaction-validation-refresh-interval-millis", 10,
+		"The frequency in milliseconds with which the transaction validation routine is run in mempool. "+
+			"The default value is 10 milliseconds.")
 	cmd.PersistentFlags().Uint64("augmented-block-view-refresh-interval-millis", 10,
 		"The frequency in milliseconds with which the augmented block view will be refreshed. "+
-			"The default value is 100 milliseconds.")
+			"The default value is 10 milliseconds.")
 
 	// Peers
 	cmd.PersistentFlags().StringSlice("connect-ips", []string{},
@@ -222,6 +227,8 @@ func SetupRunFlags(cmd *cobra.Command) {
 	cmd.PersistentFlags().Bool("time-events", false, "Enable simple event timer, helpful in hands-on performance testing")
 	cmd.PersistentFlags().String("state-change-dir", "", "The directory for state change logs. WARNING: Changing this "+
 		"from an empty string to a non-empty string (or from a non-empty string to the empty string) requires a resync.")
+	cmd.PersistentFlags().Uint("state-syncer-mempool-txn-sync-limit", 10000, "The maximum number of transactions to "+
+		"process in the mempool tx state syncer at a time.")
 	cmd.PersistentFlags().VisitAll(func(flag *pflag.Flag) {
 		viper.BindPFlag(flag.Name, flag)
 	})
diff --git a/collections/concurrent_map.go b/collections/concurrent_map.go
new file mode 100644
index 000000000..e16d64dc7
--- /dev/null
+++ b/collections/concurrent_map.go
@@ -0,0 +1,80 @@
+package collections
+
+import "sync"
+
+type ConcurrentMap[Key comparable, Value any] struct {
+	mtx sync.RWMutex
+	m   map[Key]Value
+}
+
+func NewConcurrentMap[Key comparable, Value any]() *ConcurrentMap[Key, Value] {
+	return &ConcurrentMap[Key, Value]{
+		m: make(map[Key]Value),
+	}
+}
+
+func (cm *ConcurrentMap[Key, Value]) Set(key Key, val Value) {
+	cm.mtx.Lock()
+	defer cm.mtx.Unlock()
+
+	cm.m[key] = val
+}
+
+func (cm *ConcurrentMap[Key, Value]) Remove(key Key) {
+	cm.mtx.Lock()
+	defer cm.mtx.Unlock()
+
+	_, ok := cm.m[key]
+	if !ok {
+		return
+	}
+	delete(cm.m, key)
+}
+
+func (cm *ConcurrentMap[Key, Value]) Get(key Key) (Value, bool) {
+	cm.mtx.RLock()
+	defer cm.mtx.RUnlock()
+
+	val, ok := cm.m[key]
+	return val, ok
+}
+
+func (cm *ConcurrentMap[Key, Value]) Clone() *ConcurrentMap[Key, Value] {
+	cm.mtx.RLock()
+	defer cm.mtx.RUnlock()
+
+	clone := NewConcurrentMap[Key, Value]()
+	for key, val := range cm.m {
+		clone.Set(key, val)
+	}
+	return clone
+}
+
+func (cm *ConcurrentMap[Key, Value]) ToMap() map[Key]Value {
+	cm.mtx.RLock()
+	defer cm.mtx.RUnlock()
+
+	index := make(map[Key]Value)
+	for key, node := range cm.m {
+		index[key] = node
+	}
+	return index
+}
+
+func (cm *ConcurrentMap[Key, Value]) GetAll() []Value {
+	cm.mtx.RLock()
+	defer cm.mtx.RUnlock()
+
+	var vals []Value
+	for _, val := range cm.m {
+		vals = append(vals, val)
+	}
+	return vals
+}
+
+func (cm *ConcurrentMap[Key, Value]) Count() int {
+	cm.mtx.RLock()
+	defer cm.mtx.RUnlock()
+
+	return len(cm.m)
+}
diff --git a/collections/concurrent_map_test.go b/collections/concurrent_map_test.go
new file mode 100644
index 000000000..aac89b2fb
--- /dev/null
+++ b/collections/concurrent_map_test.go
@@ -0,0 +1,61 @@
+package collections
+
+import (
+	"fmt"
+	"testing"
+)
+
+func TestConcurrentMap(t *testing.T) {
+	m := NewConcurrentMap[string, int]()
+	control := make(map[string]int)
+
+	// test add
+	for ii := 0; ii < 100; ii++ {
+		key := fmt.Sprintf("%v", ii)
+		m.Set(key, ii)
+		control[key] = ii
+	}
+
+	for key, val := range control {
+		if mVal, ok := m.Get(key); !ok || mVal != val {
+			t.Errorf("Expected %d, got %d", val, m.m[key])
+		}
+	}
+
+	// test remove
+	for ii := 0; ii < 50; ii++ {
+		key := fmt.Sprintf("%v", ii)
+		m.Remove(key)
+		delete(control, key)
+	}
+
+	for key, val := range control {
+		if mVal, ok := m.Get(key); !ok || mVal != val {
+			t.Errorf("Expected %d, got %d", val, m.m[key])
+		}
+	}
+
+	// test copy
+	copy := m.ToMap()
+	for key, val := range control {
+		if mVal, ok := copy[key]; !ok || mVal != val {
+			t.Errorf("Expected %d, got %d", val, m.m[key])
+		}
+	}
+	if len(copy) != len(control) {
+		t.Errorf("Expected %d, got %d", len(control), len(copy))
+	}
+
+	// test get all
+	vals := m.GetAll()
+	for _, val := range vals {
+		if _, ok := control[fmt.Sprintf("%v", val)]; !ok {
+			t.Errorf("Expected %d, got %d", val, m.m[fmt.Sprintf("%v", val)])
+		}
+	}
+
+	// test size
+	if m.Count() != len(control) {
+		t.Errorf("Expected %d, got %d", len(control), m.Count())
+	}
+}
diff --git a/consensus/integration_test_types.go b/consensus/integration_test_types.go
index dfc2397db..962342500 100644
--- a/consensus/integration_test_types.go
+++ b/consensus/integration_test_types.go
@@ -104,6 +104,10 @@ func (node *validatorNode) GetStakeAmount() *uint256.Int {
 	return node.stake
 }
 
+func (node *validatorNode) GetDomains() [][]byte {
+	return [][]byte{}
+}
+
 func (node *validatorNode) ProcessBlock(incomingBlock *block) {
 	node.lock.Lock()
 	defer node.lock.Unlock()
diff --git a/consensus/types.go b/consensus/types.go
index b724b062e..5ba6c3713 100644
--- a/consensus/types.go
+++ b/consensus/types.go
@@ -92,6 +92,7 @@ type BlockHash interface {
 type Validator interface {
 	GetPublicKey() *bls.PublicKey
 	GetStakeAmount() *uint256.Int
+	GetDomains() [][]byte
 }
 
 type AggregateQuorumCertificate interface {
diff --git a/consensus/types_internal.go b/consensus/types_internal.go
index 7b98ce1a2..1eb40eeb7 100644
--- a/consensus/types_internal.go
+++ b/consensus/types_internal.go
@@ -35,6 +35,10 @@ func (v *validator) GetStakeAmount() *uint256.Int {
 	return v.stakeAmount
 }
 
+func (v *validator) GetDomains() [][]byte {
+	return [][]byte{}
+}
+
 ////////////////////////////////////////////////////////////////////////
 // AggregateQuorumCertificate interface implementation for internal use.
 // We use this type for unit tests, and to construct timeout QCs for
diff --git a/integration_testing/blocksync_test.go b/integration_testing/blocksync_test.go
index 8be96d735..cf077f2c1 100644
--- a/integration_testing/blocksync_test.go
+++ b/integration_testing/blocksync_test.go
@@ -1,11 +1,6 @@
 package integration_testing
 
 import (
-	"fmt"
-	"github.com/deso-protocol/core/cmd"
-	"github.com/deso-protocol/core/lib"
-	"github.com/stretchr/testify/require"
-	"os"
 	"testing"
 )
 
@@ -16,41 +11,22 @@ import (
 // 4. node2 syncs MaxSyncBlockHeight blocks from node1.
 // 5. compare node1 db matches node2 db.
 func TestSimpleBlockSync(t *testing.T) {
-	require := require.New(t)
-	_ = require
-
-	dbDir1 := getDirectory(t)
-	dbDir2 := getDirectory(t)
-	defer os.RemoveAll(dbDir1)
-	defer os.RemoveAll(dbDir2)
-
-	config1 := generateConfig(t, 18000, dbDir1, 10)
-	config1.SyncType = lib.NodeSyncTypeBlockSync
-	config2 := generateConfig(t, 18001, dbDir2, 10)
-	config2.SyncType = lib.NodeSyncTypeBlockSync
-
-	config1.ConnectIPs = []string{"deso-seed-2.io:17000"}
-
-	node1 := cmd.NewNode(config1)
-	node2 := cmd.NewNode(config2)
-
+	node1 := spawnNodeProtocol1(t, 18000, "node1")
+	node1.Config.ConnectIPs = []string{"deso-seed-2.io:17000"}
 	node1 = startNode(t, node1)
-	node2 = startNode(t, node2)
 
 	// wait for node1 to sync blocks
 	waitForNodeToFullySync(node1)
 
-	// bridge the nodes together.
-	bridge := NewConnectionBridge(node1, node2)
-	require.NoError(bridge.Start())
+	node2 := spawnNodeProtocol1(t, 18001, "node2")
+	node2.Config.ConnectIPs = []string{"127.0.0.1:18000"}
+	node2 = startNode(t, node2)
 
 	// wait for node2 to sync blocks.
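
Aside (editorial note, not part of the diff): a minimal usage sketch for the generic ConcurrentMap added in collections/concurrent_map.go above. It assumes the github.com/deso-protocol/core/collections import path used elsewhere in this change and only exercises the methods the new file defines (Set, Get, ToMap, Count).

```go
package main

import (
	"fmt"
	"sync"

	"github.com/deso-protocol/core/collections"
)

func main() {
	// Track a height per peer ID; writers run concurrently.
	peerHeights := collections.NewConcurrentMap[uint64, uint64]()

	var wg sync.WaitGroup
	for id := uint64(0); id < 10; id++ {
		wg.Add(1)
		go func(id uint64) {
			defer wg.Done()
			peerHeights.Set(id, id*100) // Set takes the write lock internally.
		}(id)
	}
	wg.Wait()

	if height, ok := peerHeights.Get(3); ok {
		fmt.Println("peer 3 height:", height)
	}
	// ToMap returns a point-in-time copy, so iteration happens outside the lock.
	for id, height := range peerHeights.ToMap() {
		fmt.Println(id, height)
	}
	fmt.Println("count:", peerHeights.Count())
}
```
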
waitForNodeToFullySync(node2) compareNodesByDB(t, node1, node2, 0) - fmt.Println("Databases match!") - node1.Stop() - node2.Stop() + t.Logf("Databases match!") } // TestSimpleSyncRestart tests if a node can successfully restart while syncing blocks. @@ -62,45 +38,26 @@ func TestSimpleBlockSync(t *testing.T) { // 6. node2 reconnects with node1 and syncs remaining blocks. // 7. compare node1 db matches node2 db. func TestSimpleSyncRestart(t *testing.T) { - require := require.New(t) - _ = require - - dbDir1 := getDirectory(t) - dbDir2 := getDirectory(t) - defer os.RemoveAll(dbDir1) - defer os.RemoveAll(dbDir2) - - config1 := generateConfig(t, 18000, dbDir1, 10) - config1.SyncType = lib.NodeSyncTypeBlockSync - config2 := generateConfig(t, 18001, dbDir2, 10) - config2.SyncType = lib.NodeSyncTypeBlockSync - - config1.ConnectIPs = []string{"deso-seed-2.io:17000"} - - node1 := cmd.NewNode(config1) - node2 := cmd.NewNode(config2) - + node1 := spawnNodeProtocol1(t, 18000, "node1") + node1.Config.ConnectIPs = []string{"deso-seed-2.io:17000"} node1 = startNode(t, node1) - node2 = startNode(t, node2) // wait for node1 to sync blocks waitForNodeToFullySync(node1) - // bridge the nodes together. - bridge := NewConnectionBridge(node1, node2) - require.NoError(bridge.Start()) + node2 := spawnNodeProtocol1(t, 18001, "node2") + node2.Config.ConnectIPs = []string{"127.0.0.1:18000"} + node2 = startNode(t, node2) - randomHeight := randomUint32Between(t, 10, config2.MaxSyncBlockHeight) - fmt.Println("Random height for a restart (re-use if test failed):", randomHeight) + randomHeight := randomUint32Between(t, 10, node2.Config.MaxSyncBlockHeight) + t.Logf("Random height for a restart (re-use if test failed): %v", randomHeight) // Reboot node2 at a specific height and reconnect it with node1 - node2, bridge = restartAtHeightAndReconnectNode(t, node2, node1, bridge, randomHeight) + node2 = restartAtHeight(t, node2, randomHeight) waitForNodeToFullySync(node2) compareNodesByDB(t, node1, node2, 0) - fmt.Println("Random restart successful! Random height was", randomHeight) - fmt.Println("Databases match!") - node1.Stop() - node2.Stop() + t.Logf("Random restart successful! Random height was: %v", randomHeight) + t.Logf("Databases match!") } // TestSimpleSyncDisconnectWithSwitchingToNewPeer tests if a node can successfully restart while syncing blocks, and @@ -114,60 +71,35 @@ func TestSimpleSyncRestart(t *testing.T) { // 7. compare node1 state matches node2 state. // 8. compare node3 state matches node2 state. 
func TestSimpleSyncDisconnectWithSwitchingToNewPeer(t *testing.T) { - require := require.New(t) - _ = require - - dbDir1 := getDirectory(t) - dbDir2 := getDirectory(t) - dbDir3 := getDirectory(t) - defer os.RemoveAll(dbDir1) - defer os.RemoveAll(dbDir2) - defer os.RemoveAll(dbDir3) - - config1 := generateConfig(t, 18000, dbDir1, 10) - config1.SyncType = lib.NodeSyncTypeBlockSync - config2 := generateConfig(t, 18001, dbDir2, 10) - config2.SyncType = lib.NodeSyncTypeBlockSync - config3 := generateConfig(t, 18002, dbDir3, 10) - config3.SyncType = lib.NodeSyncTypeBlockSync - - config1.ConnectIPs = []string{"deso-seed-2.io:17000"} - config3.ConnectIPs = []string{"deso-seed-2.io:17000"} - - node1 := cmd.NewNode(config1) - node2 := cmd.NewNode(config2) - node3 := cmd.NewNode(config3) - + node1 := spawnNodeProtocol1(t, 18000, "node1") + node1.Config.ConnectIPs = []string{"deso-seed-2.io:17000"} node1 = startNode(t, node1) - node2 = startNode(t, node2) - node3 = startNode(t, node3) // wait for node1 to sync blocks waitForNodeToFullySync(node1) + + node3 := spawnNodeProtocol1(t, 18002, "node3") + node3.Config.ConnectIPs = []string{"deso-seed-2.io:17000"} + node3 = startNode(t, node3) + // wait for node3 to sync blocks waitForNodeToFullySync(node3) - // bridge the nodes together. - bridge12 := NewConnectionBridge(node1, node2) - require.NoError(bridge12.Start()) - - randomHeight := randomUint32Between(t, 10, config2.MaxSyncBlockHeight) - fmt.Println("Random height for a restart (re-use if test failed):", randomHeight) - disconnectAtBlockHeight(t, node2, bridge12, randomHeight) + node2 := spawnNodeProtocol1(t, 18001, "node2") + node2.Config.ConnectIPs = []string{"127.0.0.1:18000"} + node2 = startNode(t, node2) - // bridge the nodes together. - bridge23 := NewConnectionBridge(node2, node3) - require.NoError(bridge23.Start()) + randomHeight := randomUint32Between(t, 10, node2.Config.MaxSyncBlockHeight) + t.Logf("Random height for a restart (re-use if test failed): %v", randomHeight) - // Reboot node2 at a specific height and reconnect it with node1 - //node2, bridge12 = restartAtHeightAndReconnectNode(t, node2, node1, bridge12, randomHeight) + // Reboot node2 at a specific height and reconnect it with node3 + node2 = shutdownAtHeight(t, node2, randomHeight) + node2.Config.ConnectIPs = []string{"127.0.0.1:18002"} + node2 = startNode(t, node2) waitForNodeToFullySync(node2) compareNodesByDB(t, node1, node2, 0) compareNodesByDB(t, node3, node2, 0) - fmt.Println("Random restart successful! Random height was", randomHeight) - fmt.Println("Databases match!") - node1.Stop() - node2.Stop() - node3.Stop() + t.Logf("Random restart successful! Random height was %v", randomHeight) + t.Logf("Databases match!") } diff --git a/integration_testing/connection_bridge.go b/integration_testing/connection_bridge.go index 1d0228467..b93fabac5 100644 --- a/integration_testing/connection_bridge.go +++ b/integration_testing/connection_bridge.go @@ -13,6 +13,7 @@ import ( "time" ) +// TODO: DEPRECATE // ConnectionBridge is a bidirectional communication channel between two nodes. A bridge creates a pair of inbound and // outbound peers for each of the nodes to handle communication. In total, it creates four peers. // @@ -111,13 +112,13 @@ func (bridge *ConnectionBridge) createInboundConnection(node *cmd.Node) *lib.Pee } // This channel is redundant in our setting. 
- messagesFromPeer := make(chan *lib.ServerMessage) + messagesFromPeer := make(chan *lib.ServerMessage, 100) + donePeerChan := make(chan *lib.Peer, 100) // Because it is an inbound Peer of the node, it is simultaneously a "fake" outbound Peer of the bridge. // Hence, we will mark the _isOutbound parameter as "true" in NewPeer. - peer := lib.NewPeer(conn, true, netAddress, true, - 10000, 0, &lib.DeSoMainnetParams, - messagesFromPeer, nil, nil, lib.NodeSyncTypeAny) - peer.ID = uint64(lib.RandInt64(math.MaxInt64)) + peer := lib.NewPeer(uint64(lib.RandInt64(math.MaxInt64)), conn, true, + netAddress, true, 10000, 0, &lib.DeSoMainnetParams, + messagesFromPeer, nil, nil, lib.NodeSyncTypeAny, donePeerChan) return peer } @@ -139,27 +140,27 @@ func (bridge *ConnectionBridge) createOutboundConnection(node *cmd.Node, otherNo fmt.Println("createOutboundConnection: Got a connection from remote:", conn.RemoteAddr().String(), "on listener:", ll.Addr().String()) - na, err := lib.IPToNetAddr(conn.RemoteAddr().String(), otherNode.Server.GetConnectionManager().AddrMgr, - otherNode.Params) - messagesFromPeer := make(chan *lib.ServerMessage) - peer := lib.NewPeer(conn, false, na, false, - 10000, 0, bridge.nodeB.Params, - messagesFromPeer, nil, nil, lib.NodeSyncTypeAny) - peer.ID = uint64(lib.RandInt64(math.MaxInt64)) + addrMgr := addrmgr.New("", net.LookupIP) + na, err := lib.IPToNetAddr(conn.RemoteAddr().String(), addrMgr, otherNode.Params) + messagesFromPeer := make(chan *lib.ServerMessage, 100) + donePeerChan := make(chan *lib.Peer, 100) + peer := lib.NewPeer(uint64(lib.RandInt64(math.MaxInt64)), conn, + false, na, false, 10000, 0, bridge.nodeB.Params, + messagesFromPeer, nil, nil, lib.NodeSyncTypeAny, donePeerChan) bridge.newPeerChan <- peer //} }(ll) // Make the provided node to make an outbound connection to our listener. - netAddress, _ := lib.IPToNetAddr(ll.Addr().String(), addrmgr.New("", net.LookupIP), &lib.DeSoMainnetParams) - fmt.Println("createOutboundConnection: IP:", netAddress.IP, "Port:", netAddress.Port) - go node.Server.GetConnectionManager().ConnectPeer(nil, netAddress) + addrMgr := addrmgr.New("", net.LookupIP) + addr, _ := lib.IPToNetAddr(ll.Addr().String(), addrMgr, node.Params) + go node.Server.GetConnectionManager().DialOutboundConnection(addr, uint64(lib.RandInt64(math.MaxInt64))) } // getVersionMessage simulates a version message that the provided node would have sent. 
func (bridge *ConnectionBridge) getVersionMessage(node *cmd.Node) *lib.MsgDeSoVersion { ver := lib.NewMessage(lib.MsgTypeVersion).(*lib.MsgDeSoVersion) - ver.Version = node.Params.ProtocolVersion + ver.Version = node.Params.ProtocolVersion.ToUint64() ver.TstampSecs = time.Now().Unix() ver.Nonce = uint64(lib.RandInt64(math.MaxInt64)) ver.UserAgent = node.Params.UserAgent @@ -172,27 +173,43 @@ func (bridge *ConnectionBridge) getVersionMessage(node *cmd.Node) *lib.MsgDeSoVe } if node.Server != nil { - ver.StartBlockHeight = uint32(node.Server.GetBlockchain().BlockTip().Header.Height) + ver.LatestBlockHeight = node.Server.GetBlockchain().BlockTip().Header.Height } ver.MinFeeRateNanosPerKB = node.Config.MinFeerate return ver } +func ReadWithTimeout(readFunc func() error, readTimeout time.Duration) error { + errChan := make(chan error) + go func() { + errChan <- readFunc() + }() + select { + case err := <-errChan: + { + return err + } + case <-time.After(readTimeout): + { + return fmt.Errorf("ReadWithTimeout: Timed out reading message") + } + } +} + // startConnection starts the connection by performing version and verack exchange with // the provided connection, pretending to be the otherNode. func (bridge *ConnectionBridge) startConnection(connection *lib.Peer, otherNode *cmd.Node) error { // Prepare the version message. versionMessage := bridge.getVersionMessage(otherNode) - connection.VersionNonceSent = versionMessage.Nonce // Send the version message. - fmt.Println("Sending version message:", versionMessage, versionMessage.StartBlockHeight) + fmt.Println("Sending version message:", versionMessage, versionMessage.LatestBlockHeight) if err := connection.WriteDeSoMessage(versionMessage); err != nil { return err } // Wait for a response to the version message. - if err := connection.ReadWithTimeout( + if err := ReadWithTimeout( func() error { msg, err := connection.ReadDeSoMessage() if err != nil { @@ -204,7 +221,6 @@ func (bridge *ConnectionBridge) startConnection(connection *lib.Peer, otherNode return err } - connection.VersionNonceReceived = verMsg.Nonce connection.TimeConnected = time.Unix(verMsg.TstampSecs, 0) connection.TimeOffsetSecs = verMsg.TstampSecs - time.Now().Unix() return nil @@ -215,7 +231,6 @@ func (bridge *ConnectionBridge) startConnection(connection *lib.Peer, otherNode // Now prepare the verack message. verackMsg := lib.NewMessage(lib.MsgTypeVerack) - verackMsg.(*lib.MsgDeSoVerack).Nonce = connection.VersionNonceReceived // And send it to the connection. if err := connection.WriteDeSoMessage(verackMsg); err != nil { @@ -223,7 +238,7 @@ func (bridge *ConnectionBridge) startConnection(connection *lib.Peer, otherNode } // And finally wait for connection's response to the verack message. - if err := connection.ReadWithTimeout( + if err := ReadWithTimeout( func() error { msg, err := connection.ReadDeSoMessage() if err != nil { @@ -233,17 +248,11 @@ func (bridge *ConnectionBridge) startConnection(connection *lib.Peer, otherNode if msg.GetMsgType() != lib.MsgTypeVerack { return fmt.Errorf("message is not verack! 
Type: %v", msg.GetMsgType()) } - verackMsg := msg.(*lib.MsgDeSoVerack) - if verackMsg.Nonce != connection.VersionNonceSent { - return fmt.Errorf("verack message nonce doesn't match (received: %v, sent: %v)", - verackMsg.Nonce, connection.VersionNonceSent) - } return nil }, lib.DeSoMainnetParams.VersionNegotiationTimeout); err != nil { return err } - connection.VersionNegotiated = true return nil } diff --git a/integration_testing/hypersync_test.go b/integration_testing/hypersync_test.go index aad90ee0e..b76b1db48 100644 --- a/integration_testing/hypersync_test.go +++ b/integration_testing/hypersync_test.go @@ -1,11 +1,7 @@ package integration_testing import ( - "fmt" - "github.com/deso-protocol/core/cmd" "github.com/deso-protocol/core/lib" - "github.com/stretchr/testify/require" - "os" "testing" ) @@ -16,35 +12,19 @@ import ( // 4. node2 hypersyncs from node1 // 5. once done, compare node1 state, db, and checksum matches node2. func TestSimpleHyperSync(t *testing.T) { - require := require.New(t) - _ = require - - dbDir1 := getDirectory(t) - dbDir2 := getDirectory(t) - defer os.RemoveAll(dbDir1) - defer os.RemoveAll(dbDir2) - - config1 := generateConfig(t, 18000, dbDir1, 10) - config1.SyncType = lib.NodeSyncTypeBlockSync - config2 := generateConfig(t, 18001, dbDir2, 10) - config2.SyncType = lib.NodeSyncTypeHyperSync - - config1.HyperSync = true - config2.HyperSync = true - config1.ConnectIPs = []string{"deso-seed-2.io:17000"} - - node1 := cmd.NewNode(config1) - node2 := cmd.NewNode(config2) - + node1 := spawnNodeProtocol1(t, 18000, "node1") + node1.Config.HyperSync = true + node1.Config.ConnectIPs = []string{"deso-seed-2.io:17000"} node1 = startNode(t, node1) - node2 = startNode(t, node2) // wait for node1 to sync blocks waitForNodeToFullySync(node1) - // bridge the nodes together. - bridge := NewConnectionBridge(node1, node2) - require.NoError(bridge.Start()) + node2 := spawnNodeProtocol1(t, 18001, "node2") + node2.Config.SyncType = lib.NodeSyncTypeHyperSync + node2.Config.HyperSync = true + node2.Config.ConnectIPs = []string{"127.0.0.1:18000"} + node2 = startNode(t, node2) // wait for node2 to sync blocks. waitForNodeToFullySync(node2) @@ -52,9 +32,7 @@ func TestSimpleHyperSync(t *testing.T) { compareNodesByState(t, node1, node2, 0) //compareNodesByDB(t, node1, node2, 0) compareNodesByChecksum(t, node1, node2) - fmt.Println("Databases match!") - node1.Stop() - node2.Stop() + t.Logf("Databases match!") } // TestHyperSyncFromHyperSyncedNode test if a node can successfully hypersync from another hypersynced node: @@ -65,49 +43,28 @@ func TestSimpleHyperSync(t *testing.T) { // 5. once done, bridge node3 and node2 so that node3 hypersyncs from node2. // 6. compare node1 state, db, and checksum matches node2, and node3. 
func TestHyperSyncFromHyperSyncedNode(t *testing.T) { - require := require.New(t) - _ = require - - dbDir1 := getDirectory(t) - dbDir2 := getDirectory(t) - dbDir3 := getDirectory(t) - defer os.RemoveAll(dbDir1) - defer os.RemoveAll(dbDir2) - defer os.RemoveAll(dbDir3) - - config1 := generateConfig(t, 18000, dbDir1, 10) - config1.SyncType = lib.NodeSyncTypeBlockSync - config2 := generateConfig(t, 18001, dbDir2, 10) - config2.SyncType = lib.NodeSyncTypeHyperSyncArchival - config3 := generateConfig(t, 18002, dbDir3, 10) - config3.SyncType = lib.NodeSyncTypeHyperSyncArchival - - config1.HyperSync = true - config2.HyperSync = true - config3.HyperSync = true - config1.ConnectIPs = []string{"deso-seed-2.io:17000"} - - node1 := cmd.NewNode(config1) - node2 := cmd.NewNode(config2) - node3 := cmd.NewNode(config3) - + node1 := spawnNodeProtocol1(t, 18000, "node1") + node1.Config.HyperSync = true + node1.Config.ConnectIPs = []string{"deso-seed-2.io:17000"} node1 = startNode(t, node1) - node2 = startNode(t, node2) - node3 = startNode(t, node3) // wait for node1 to sync blocks waitForNodeToFullySync(node1) - // bridge the nodes together. - bridge12 := NewConnectionBridge(node1, node2) - require.NoError(bridge12.Start()) + node2 := spawnNodeProtocol1(t, 18001, "node2") + node2.Config.SyncType = lib.NodeSyncTypeHyperSyncArchival + node2.Config.HyperSync = true + node2.Config.ConnectIPs = []string{"127.0.0.1:18000"} + node2 = startNode(t, node2) // wait for node2 to sync blocks. waitForNodeToFullySync(node2) - // bridge node3 to node2 to kick off hyper sync from a hyper synced node - bridge23 := NewConnectionBridge(node2, node3) - require.NoError(bridge23.Start()) + node3 := spawnNodeProtocol1(t, 18002, "node3") + node3.Config.SyncType = lib.NodeSyncTypeHyperSyncArchival + node3.Config.HyperSync = true + node3.Config.ConnectIPs = []string{"127.0.0.1:18001"} + node3 = startNode(t, node3) // wait for node2 to sync blocks. waitForNodeToFullySync(node3) @@ -121,10 +78,7 @@ func TestHyperSyncFromHyperSyncedNode(t *testing.T) { //compareNodesByDB(t, node2, node3, 0) compareNodesByChecksum(t, node2, node3) - fmt.Println("Databases match!") - node1.Stop() - node2.Stop() - node3.Stop() + t.Logf("Databases match!") } // TestSimpleHyperSyncRestart test if a node can successfully hyper sync from another node: @@ -135,51 +89,34 @@ func TestHyperSyncFromHyperSyncedNode(t *testing.T) { // 5. node2 reconnects to node1 and hypersyncs again. // 6. Once node2 finishes sync, compare node1 state, db, and checksum matches node2. func TestSimpleHyperSyncRestart(t *testing.T) { - require := require.New(t) - _ = require - - dbDir1 := getDirectory(t) - dbDir2 := getDirectory(t) - defer os.RemoveAll(dbDir1) - defer os.RemoveAll(dbDir2) - - config1 := generateConfig(t, 18000, dbDir1, 10) - config2 := generateConfig(t, 18001, dbDir2, 10) - - config1.HyperSync = true - config1.SyncType = lib.NodeSyncTypeBlockSync - config2.HyperSync = true - config2.SyncType = lib.NodeSyncTypeHyperSyncArchival - config1.ConnectIPs = []string{"deso-seed-2.io:17000"} - - node1 := cmd.NewNode(config1) - node2 := cmd.NewNode(config2) - + node1 := spawnNodeProtocol1(t, 18000, "node1") + node1.Config.HyperSync = true + node1.Config.ConnectIPs = []string{"deso-seed-2.io:17000"} node1 = startNode(t, node1) - node2 = startNode(t, node2) // wait for node1 to sync blocks waitForNodeToFullySync(node1) - // bridge the nodes together. 
- bridge := NewConnectionBridge(node1, node2) - require.NoError(bridge.Start()) + node2 := spawnNodeProtocol1(t, 18001, "node2") + node2.Config.SyncType = lib.NodeSyncTypeHyperSyncArchival + node2.Config.HyperSync = true + node2.Config.ConnectIPs = []string{"127.0.0.1:18000"} + node2 = startNode(t, node2) syncIndex := randomUint32Between(t, 0, uint32(len(lib.StatePrefixes.StatePrefixesList))) syncPrefix := lib.StatePrefixes.StatePrefixesList[syncIndex] - fmt.Println("Random sync prefix for a restart (re-use if test failed):", syncPrefix) + t.Logf("Random sync prefix for a restart (re-use if test failed): %v", syncPrefix) + // Reboot node2 at a specific sync prefix and reconnect it with node1 - node2, bridge = restartAtSyncPrefixAndReconnectNode(t, node2, node1, bridge, syncPrefix) + node2 = restartAtSyncPrefix(t, node2, syncPrefix) // wait for node2 to sync blocks. waitForNodeToFullySync(node2) compareNodesByState(t, node1, node2, 0) //compareNodesByDB(t, node1, node2, 0) compareNodesByChecksum(t, node1, node2) - fmt.Println("Random restart successful! Random sync prefix was", syncPrefix) - fmt.Println("Databases match!") - node1.Stop() - node2.Stop() + t.Logf("Random restart successful! Random sync prefix was: %v", syncPrefix) + t.Logf("Databases match!") } // TestSimpleHyperSyncDisconnectWithSwitchingToNewPeer tests if a node can successfully restart while hypersyncing. @@ -190,57 +127,34 @@ func TestSimpleHyperSyncRestart(t *testing.T) { // 5. after restart, bridge node2 with node3 and resume hypersync. // 6. once node2 finishes, compare node1, node2, node3 state, db, and checksums are identical. func TestSimpleHyperSyncDisconnectWithSwitchingToNewPeer(t *testing.T) { - require := require.New(t) - _ = require - - dbDir1 := getDirectory(t) - dbDir2 := getDirectory(t) - dbDir3 := getDirectory(t) - defer os.RemoveAll(dbDir1) - defer os.RemoveAll(dbDir2) - defer os.RemoveAll(dbDir3) - - config1 := generateConfig(t, 18000, dbDir1, 10) - config1.SyncType = lib.NodeSyncTypeBlockSync - config2 := generateConfig(t, 18001, dbDir2, 10) - config2.SyncType = lib.NodeSyncTypeHyperSyncArchival - config3 := generateConfig(t, 18002, dbDir3, 10) - config3.SyncType = lib.NodeSyncTypeBlockSync - - config1.HyperSync = true - config2.HyperSync = true - config3.HyperSync = true - config1.ConnectIPs = []string{"deso-seed-2.io:17000"} - config3.ConnectIPs = []string{"deso-seed-2.io:17000"} - - node1 := cmd.NewNode(config1) - node2 := cmd.NewNode(config2) - node3 := cmd.NewNode(config3) - + node1 := spawnNodeProtocol1(t, 18000, "node1") + node1.Config.HyperSync = true + node1.Config.ConnectIPs = []string{"deso-seed-2.io:17000"} node1 = startNode(t, node1) - node2 = startNode(t, node2) - node3 = startNode(t, node3) - // wait for node1 to sync blocks waitForNodeToFullySync(node1) + + node3 := spawnNodeProtocol1(t, 18002, "node3") + node3.Config.HyperSync = true + node3.Config.ConnectIPs = []string{"127.0.0.1:18000"} + node3 = startNode(t, node3) // wait for node3 to sync blocks waitForNodeToFullySync(node3) - // bridge the nodes together. 
- bridge12 := NewConnectionBridge(node1, node2) - require.NoError(bridge12.Start()) + node2 := spawnNodeProtocol1(t, 18001, "node2") + node2.Config.SyncType = lib.NodeSyncTypeHyperSyncArchival + node2.Config.HyperSync = true + node2.Config.ConnectIPs = []string{"127.0.0.1:18000"} + node2 = startNode(t, node2) + // Reboot node2 at a specific height and reconnect it with node1 syncIndex := randomUint32Between(t, 0, uint32(len(lib.StatePrefixes.StatePrefixesList))) syncPrefix := lib.StatePrefixes.StatePrefixesList[syncIndex] - fmt.Println("Random prefix for a restart (re-use if test failed):", syncPrefix) - disconnectAtSyncPrefix(t, node2, bridge12, syncPrefix) - - // bridge the nodes together. - bridge23 := NewConnectionBridge(node2, node3) - require.NoError(bridge23.Start()) + t.Logf("Random prefix for a restart (re-use if test failed): %v", syncPrefix) + node2 = shutdownAtSyncPrefix(t, node2, syncPrefix) + node2.Config.ConnectIPs = []string{"127.0.0.1:18002"} + node2 = startNode(t, node2) - // Reboot node2 at a specific height and reconnect it with node1 - //node2, bridge12 = restartAtHeightAndReconnectNode(t, node2, node1, bridge12, randomHeight) // wait for node2 to sync blocks. waitForNodeToFullySync(node2) @@ -253,11 +167,8 @@ func TestSimpleHyperSyncDisconnectWithSwitchingToNewPeer(t *testing.T) { compareNodesByState(t, node1, node2, 0) //compareNodesByDB(t, node1, node2, 0) compareNodesByChecksum(t, node1, node2) - fmt.Println("Random restart successful! Random sync prefix was", syncPrefix) - fmt.Println("Databases match!") - node1.Stop() - node2.Stop() - node3.Stop() + t.Logf("Random restart successful! Random sync prefix was: %v", syncPrefix) + t.Logf("Databases match!") } // TODO: disconnecting the provider peer during hypersync doesn't work. @@ -311,92 +222,49 @@ func TestSimpleHyperSyncDisconnectWithSwitchingToNewPeer(t *testing.T) { //} func TestArchivalMode(t *testing.T) { - require := require.New(t) - _ = require - - dbDir1 := getDirectory(t) - dbDir2 := getDirectory(t) - defer os.RemoveAll(dbDir1) - defer os.RemoveAll(dbDir2) - - config1 := generateConfig(t, 18000, dbDir1, 10) - config2 := generateConfig(t, 18001, dbDir2, 10) - - config1.HyperSync = true - config2.HyperSync = true - config1.ConnectIPs = []string{"deso-seed-2.io:17000"} - config1.SyncType = lib.NodeSyncTypeBlockSync - config2.SyncType = lib.NodeSyncTypeHyperSyncArchival - - node1 := cmd.NewNode(config1) - node2 := cmd.NewNode(config2) - + node1 := spawnNodeProtocol1(t, 18000, "node1") + node1.Config.HyperSync = true + node1.Config.ConnectIPs = []string{"deso-seed-2.io:17000"} node1 = startNode(t, node1) - node2 = startNode(t, node2) // wait for node1 to sync blocks waitForNodeToFullySync(node1) - // bridge the nodes together. - bridge := NewConnectionBridge(node1, node2) - require.NoError(bridge.Start()) + node2 := spawnNodeProtocol1(t, 18001, "node2") + node2.Config.SyncType = lib.NodeSyncTypeHyperSyncArchival + node2.Config.HyperSync = true + node2.Config.ConnectIPs = []string{"127.0.0.1:18000"} + node2 = startNode(t, node2) // wait for node2 to sync blocks. 
waitForNodeToFullySync(node2) compareNodesByDB(t, node1, node2, 0) - - //compareNodesByDB(t, node1, node2, 0) compareNodesByChecksum(t, node1, node2) - fmt.Println("Databases match!") - node1.Stop() - node2.Stop() + t.Logf("Databases match!") } func TestBlockSyncFromArchivalModeHyperSync(t *testing.T) { - require := require.New(t) - _ = require - - dbDir1 := getDirectory(t) - dbDir2 := getDirectory(t) - dbDir3 := getDirectory(t) - defer os.RemoveAll(dbDir1) - defer os.RemoveAll(dbDir2) - defer os.RemoveAll(dbDir3) - - config1 := generateConfig(t, 18000, dbDir1, 10) - config2 := generateConfig(t, 18001, dbDir2, 10) - config3 := generateConfig(t, 18002, dbDir3, 10) - - config1.HyperSync = true - config1.SyncType = lib.NodeSyncTypeBlockSync - config2.HyperSync = true - config2.SyncType = lib.NodeSyncTypeHyperSyncArchival - config3.HyperSync = false - config3.SyncType = lib.NodeSyncTypeBlockSync - config1.ConnectIPs = []string{"deso-seed-2.io:17000"} - - node1 := cmd.NewNode(config1) - node2 := cmd.NewNode(config2) - node3 := cmd.NewNode(config3) - + node1 := spawnNodeProtocol1(t, 18000, "node1") + node1.Config.HyperSync = true + node1.Config.ConnectIPs = []string{"deso-seed-2.io:17000"} node1 = startNode(t, node1) - node2 = startNode(t, node2) - node3 = startNode(t, node3) - // wait for node1 to sync blocks waitForNodeToFullySync(node1) - // bridge the nodes together. - bridge12 := NewConnectionBridge(node1, node2) - require.NoError(bridge12.Start()) - + node2 := spawnNodeProtocol1(t, 18001, "node2") + node2.Config.SyncType = lib.NodeSyncTypeHyperSyncArchival + node2.Config.HyperSync = true + node2.Config.ConnectIPs = []string{"127.0.0.1:18000"} + node2 = startNode(t, node2) // wait for node2 to sync blocks. waitForNodeToFullySync(node2) - bridge23 := NewConnectionBridge(node2, node3) - require.NoError(bridge23.Start()) - + node3 := spawnNodeProtocol1(t, 18002, "node3") + node3.Config.SyncType = lib.NodeSyncTypeBlockSync + node3.Config.HyperSync = true + node3.Config.ConnectIPs = []string{"127.0.0.1:18001"} + node3 = startNode(t, node3) // wait for node3 to sync blocks. waitForNodeToFullySync(node3) @@ -405,7 +273,5 @@ func TestBlockSyncFromArchivalModeHyperSync(t *testing.T) { //compareNodesByDB(t, node1, node2, 0) compareNodesByChecksum(t, node1, node2) - fmt.Println("Databases match!") - node1.Stop() - node2.Stop() + t.Logf("Databases match!") } diff --git a/integration_testing/migrations_test.go b/integration_testing/migrations_test.go index b0a692b52..067a2f3b6 100644 --- a/integration_testing/migrations_test.go +++ b/integration_testing/migrations_test.go @@ -1,64 +1,39 @@ package integration_testing import ( - "fmt" - "github.com/deso-protocol/core/cmd" - "github.com/deso-protocol/core/lib" "github.com/stretchr/testify/require" - "os" "testing" ) // TODO: Add an encoder migration height in constants.go then modify some // random struct like UtxoEntry. Until we have a migration, we can't fully test this. 
func TestEncoderMigrations(t *testing.T) { - require := require.New(t) - _ = require - - dbDir1 := getDirectory(t) - dbDir2 := getDirectory(t) - defer os.RemoveAll(dbDir1) - defer os.RemoveAll(dbDir2) - - config1 := generateConfig(t, 18000, dbDir1, 10) - config1.SyncType = lib.NodeSyncTypeBlockSync - config2 := generateConfig(t, 18001, dbDir2, 10) - config2.SyncType = lib.NodeSyncTypeHyperSync - - config1.ConnectIPs = []string{"deso-seed-2.io:17000"} - config1.HyperSync = true - config2.HyperSync = true - - node1 := cmd.NewNode(config1) - node2 := cmd.NewNode(config2) - + node1 := spawnNodeProtocol1(t, 18000, "node1") + node1.Config.HyperSync = true + node1.Config.ConnectIPs = []string{"deso-seed-2.io:17000"} node1 = startNode(t, node1) - node2 = startNode(t, node2) - // wait for node1 to sync blocks waitForNodeToFullySync(node1) - // bridge the nodes together. - bridge := NewConnectionBridge(node1, node2) - require.NoError(bridge.Start()) - + node2 := spawnNodeProtocol1(t, 18001, "node2") + node2.Config.HyperSync = true + node2.Config.ConnectIPs = []string{"127.0.0.1:18000"} + node2 = startNode(t, node2) // wait for node2 to sync blocks. waitForNodeToFullySync(node2) - fmt.Println("Chain state and operation channel", node2.Server.GetBlockchain().ChainState(), + t.Logf("Chain state and operation channel (state: %v), (len: %v)", node2.Server.GetBlockchain().ChainState(), len(node2.Server.GetBlockchain().Snapshot().OperationChannel.OperationChannel)) compareNodesByState(t, node1, node2, 0) - fmt.Println("node1 checksum:", computeNodeStateChecksum(t, node1, 1500)) - fmt.Println("node2 checksum:", computeNodeStateChecksum(t, node2, 1500)) + t.Logf("node1 checksum: %v", computeNodeStateChecksum(t, node1, 1500)) + t.Logf("node2 checksum: %v", computeNodeStateChecksum(t, node2, 1500)) checksum1, err := node1.Server.GetBlockchain().Snapshot().Checksum.ToBytes() - require.NoError(err) + require.NoError(t, err) checksum2, err := node2.Server.GetBlockchain().Snapshot().Checksum.ToBytes() - require.NoError(err) - fmt.Println("node1 server checksum:", checksum1) - fmt.Println("node2 server checksum:", checksum2) + require.NoError(t, err) + t.Logf("node1 server checksum: %v", checksum1) + t.Logf("node2 server checksum: %v", checksum2) compareNodesByChecksum(t, node1, node2) - fmt.Println("Databases match!") - node1.Stop() - node2.Stop() + t.Logf("Databases match!") } diff --git a/integration_testing/mining_test.go b/integration_testing/mining_test.go index 49a23333c..facbce226 100644 --- a/integration_testing/mining_test.go +++ b/integration_testing/mining_test.go @@ -1,37 +1,22 @@ package integration_testing import ( - "github.com/deso-protocol/core/cmd" "github.com/deso-protocol/core/lib" - "github.com/stretchr/testify/require" - "os" "testing" ) // TestSimpleBlockSync test if a node can mine blocks on regtest func TestRegtestMiner(t *testing.T) { - require := require.New(t) - _ = require - - dbDir1 := getDirectory(t) - defer os.RemoveAll(dbDir1) - - config1 := generateConfig(t, 18000, dbDir1, 10) - config1.SyncType = lib.NodeSyncTypeBlockSync - config1.Params = &lib.DeSoTestnetParams - config1.MaxSyncBlockHeight = 0 - config1.MinerPublicKeys = []string{"tBCKVERmG9nZpHTk2AVPqknWc1Mw9HHAnqrTpW1RnXpXMQ4PsQgnmV"} - - config1.Regtest = true - - node1 := cmd.NewNode(config1) + node1 := spawnNodeProtocol1(t, 18000, "node1") + params := lib.DeSoTestnetParams + node1.Config.Params = ¶ms + node1.Params = ¶ms + node1.Config.MaxSyncBlockHeight = 0 + node1.Config.MinerPublicKeys = 
[]string{"tBCKVERmG9nZpHTk2AVPqknWc1Mw9HHAnqrTpW1RnXpXMQ4PsQgnmV"} + node1.Config.Regtest = true node1 = startNode(t, node1) // wait for node1 to sync blocks mineHeight := uint32(40) - listener := make(chan bool) - listenForBlockHeight(t, node1, mineHeight, listener) - <-listener - - node1.Stop() + <-listenForBlockHeight(node1, mineHeight) } diff --git a/integration_testing/network_manager_routines_test.go b/integration_testing/network_manager_routines_test.go new file mode 100644 index 000000000..f89bb2465 --- /dev/null +++ b/integration_testing/network_manager_routines_test.go @@ -0,0 +1,584 @@ +package integration_testing + +import ( + "fmt" + "github.com/deso-protocol/core/bls" + "github.com/deso-protocol/core/cmd" + "github.com/deso-protocol/core/collections" + "github.com/deso-protocol/core/consensus" + "github.com/deso-protocol/core/lib" + "github.com/stretchr/testify/require" + "github.com/tyler-smith/go-bip39" + "testing" + "time" +) + +func TestConnectionControllerInitiatePersistentConnections(t *testing.T) { + // NonValidator Node1 will set its --connect-ips to two non-validators node2 and node3, + // and two validators node4 and node5. + node1 := spawnNonValidatorNodeProtocol2(t, 18000, "node1") + node2 := spawnNonValidatorNodeProtocol2(t, 18001, "node2") + node3 := spawnNonValidatorNodeProtocol2(t, 18002, "node3") + blsSeedPhrase4, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node4 := spawnValidatorNodeProtocol2(t, 18003, "node4", blsSeedPhrase4) + blsSeedPhrase5, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node5 := spawnValidatorNodeProtocol2(t, 18004, "node5", blsSeedPhrase5) + + node2 = startNode(t, node2) + node3 = startNode(t, node3) + node4 = startNode(t, node4) + node5 = startNode(t, node5) + + node1.Config.ConnectIPs = []string{ + node2.Listeners[0].Addr().String(), + node3.Listeners[0].Addr().String(), + node4.Listeners[0].Addr().String(), + node5.Listeners[0].Addr().String(), + } + node1 = startNode(t, node1) + activeValidatorsMap := getActiveValidatorsMapWithValidatorNodes(t, node4, node5) + setActiveValidators(activeValidatorsMap, node1, node2, node3, node4, node5) + waitForNonValidatorOutboundConnection(t, node1, node2) + waitForNonValidatorOutboundConnection(t, node1, node3) + waitForValidatorConnection(t, node1, node4) + waitForValidatorConnection(t, node1, node5) + waitForValidatorConnection(t, node4, node5) + waitForCountRemoteNodeIndexerHandshakeCompleted(t, node1, 4, 2, 2, 0) + waitForCountRemoteNodeIndexerHandshakeCompleted(t, node2, 1, 0, 0, 1) + waitForCountRemoteNodeIndexerHandshakeCompleted(t, node3, 1, 0, 0, 1) + waitForCountRemoteNodeIndexerHandshakeCompleted(t, node4, 2, 1, 0, 1) + waitForCountRemoteNodeIndexerHandshakeCompleted(t, node5, 2, 1, 0, 1) + node1.Stop() + t.Logf("Test #1 passed | Successfully run non-validator node1 with --connect-ips set to node2, node3, node4, node5") + + // Now try again with a validator node6, with connect-ips set to node2, node3, node4, node5. 
+ blsSeedPhrase6, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node6 := spawnValidatorNodeProtocol2(t, 18005, "node6", blsSeedPhrase6) + node6.Config.ConnectIPs = []string{ + node2.Listeners[0].Addr().String(), + node3.Listeners[0].Addr().String(), + node4.Listeners[0].Addr().String(), + node5.Listeners[0].Addr().String(), + } + node6 = startNode(t, node6) + activeValidatorsMap = getActiveValidatorsMapWithValidatorNodes(t, node4, node5, node6) + setActiveValidators(activeValidatorsMap, node1, node2, node3, node4, node5, node6) + waitForNonValidatorOutboundConnection(t, node6, node2) + waitForNonValidatorOutboundConnection(t, node6, node3) + waitForValidatorConnection(t, node6, node4) + waitForValidatorConnection(t, node6, node5) + waitForValidatorConnection(t, node4, node5) + waitForCountRemoteNodeIndexerHandshakeCompleted(t, node6, 4, 2, 2, 0) + waitForCountRemoteNodeIndexerHandshakeCompleted(t, node2, 1, 1, 0, 0) + waitForCountRemoteNodeIndexerHandshakeCompleted(t, node3, 1, 1, 0, 0) + waitForCountRemoteNodeIndexerHandshakeCompleted(t, node4, 2, 2, 0, 0) + waitForCountRemoteNodeIndexerHandshakeCompleted(t, node5, 2, 2, 0, 0) + t.Logf("Test #2 passed | Successfully run validator node6 with --connect-ips set to node2, node3, node4, node5") +} + +func TestConnectionControllerNonValidatorCircularConnectIps(t *testing.T) { + node1 := spawnNonValidatorNodeProtocol2(t, 18000, "node1") + node2 := spawnNonValidatorNodeProtocol2(t, 18001, "node2") + + node1.Config.ConnectIPs = []string{"127.0.0.1:18001"} + node2.Config.ConnectIPs = []string{"127.0.0.1:18000"} + + node1 = startNode(t, node1) + node2 = startNode(t, node2) + + waitForCountRemoteNodeIndexerHandshakeCompleted(t, node1, 2, 0, 1, 1) + waitForCountRemoteNodeIndexerHandshakeCompleted(t, node2, 2, 0, 1, 1) +} + +func TestNetworkManagerPersistentConnectorReconnect(t *testing.T) { + // Ensure that a node that is disconnected from a persistent connection will be reconnected to. + // Spawn three nodes: a non-validator node1, and node2, and a validator node3. Then set node1 connectIps + // to node2, node3, as well as a non-existing ip. Then we will stop node2, and wait for node1 to drop the + // connection. Then we will restart node2, and wait for node1 to reconnect to node2. We will repeat this + // process for node3. + + node1 := spawnNonValidatorNodeProtocol2(t, 18000, "node1") + // Set TargetOutboundPeers to 0 to ensure the non-validator connector doesn't interfere. + node1.Config.TargetOutboundPeers = 0 + + node2 := spawnNonValidatorNodeProtocol2(t, 18001, "node2") + blsSeedPhrase3, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node3 := spawnValidatorNodeProtocol2(t, 18002, "node3", blsSeedPhrase3) + + node2 = startNode(t, node2) + node3 = startNode(t, node3) + + node1.Config.ConnectIPs = []string{ + node2.Listeners[0].Addr().String(), + node3.Listeners[0].Addr().String(), + "127.0.0.1:18003", + } + node1 = startNode(t, node1) + activeValidatorsMap := getActiveValidatorsMapWithValidatorNodes(t, node3) + setActiveValidators(activeValidatorsMap, node1, node2, node3) + + waitForNonValidatorOutboundConnection(t, node1, node2) + waitForValidatorConnection(t, node1, node3) + waitForCountRemoteNodeIndexer(t, node1, 3, 1, 2, 0) + + node2.Stop() + waitForCountRemoteNodeIndexer(t, node1, 2, 1, 1, 0) + // node1 should reopen the connection to node2, and it should be re-indexed as a non-validator (attempted). 
+ waitForCountRemoteNodeIndexer(t, node1, 3, 1, 2, 0) + node2 = startNode(t, node2) + setActiveValidators(activeValidatorsMap, node2) + waitForCountRemoteNodeIndexer(t, node1, 3, 1, 2, 0) + t.Logf("Test #1 passed | Successfully run reconnect test with non-validator node1 with --connect-ips for node2") + + // Now we will do the same for node3. + node3.Stop() + waitForCountRemoteNodeIndexer(t, node1, 2, 0, 2, 0) + // node1 should reopen the connection to node3, and it should be re-indexed as a non-validator (attempted). + waitForCountRemoteNodeIndexer(t, node1, 3, 0, 3, 0) + node3 = startNode(t, node3) + setActiveValidators(activeValidatorsMap, node3) + waitForValidatorConnection(t, node1, node3) + waitForCountRemoteNodeIndexer(t, node1, 3, 1, 2, 0) + t.Logf("Test #2 passed | Successfully run reconnect test with non-validator node1 with --connect-ips for node3") +} + +func TestConnectionControllerValidatorConnector(t *testing.T) { + // Spawn 5 validators node1, node2, node3, node4, node5 and two non-validators node6 and node7. + // All the validators are initially in the validator set. And later, node1 and node2 will be removed from the + // validator set. Then, make node3 inactive, and node2 active again. Then, make all the validators inactive. + // Make node6, and node7 connect-ips to all the validators. + + blsSeedPhrase1, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node1 := spawnValidatorNodeProtocol2(t, 18000, "node1", blsSeedPhrase1) + blsSeedPhrase2, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node2 := spawnValidatorNodeProtocol2(t, 18001, "node2", blsSeedPhrase2) + blsSeedPhrase3, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node3 := spawnValidatorNodeProtocol2(t, 18002, "node3", blsSeedPhrase3) + blsSeedPhrase4, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node4 := spawnValidatorNodeProtocol2(t, 18003, "node4", blsSeedPhrase4) + blsSeedPhrase5, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node5 := spawnValidatorNodeProtocol2(t, 18004, "node5", blsSeedPhrase5) + + node6 := spawnNonValidatorNodeProtocol2(t, 18005, "node6") + node7 := spawnNonValidatorNodeProtocol2(t, 18006, "node7") + + node1 = startNode(t, node1) + node2 = startNode(t, node2) + node3 = startNode(t, node3) + node4 = startNode(t, node4) + node5 = startNode(t, node5) + + node6.Config.ConnectIPs = []string{ + node1.Listeners[0].Addr().String(), + node2.Listeners[0].Addr().String(), + node3.Listeners[0].Addr().String(), + node4.Listeners[0].Addr().String(), + node5.Listeners[0].Addr().String(), + } + node7.Config.ConnectIPs = node6.Config.ConnectIPs + node6 = startNode(t, node6) + node7 = startNode(t, node7) + activeValidatorsMap := getActiveValidatorsMapWithValidatorNodes(t, node1, node2, node3, node4, node5) + setActiveValidators(activeValidatorsMap, node1, node2, node3, node4, node5, node6, node7) + + // Verify full graph between active validators. + waitForValidatorFullGraph(t, node1, node2, node3, node4, node5) + // Verify connections of non-validators. + for _, nonValidator := range []*cmd.Node{node6, node7} { + waitForValidatorConnectionOneWay(t, nonValidator, node1, node2, node3, node4, node5) + } + // Verify connections of initial validators. 
+ for _, validator := range []*cmd.Node{node1, node2, node3, node4, node5} { + waitForNonValidatorInboundConnection(t, validator, node6) + waitForNonValidatorInboundConnection(t, validator, node7) + } + // Verify connection counts of active validators. + for _, validator := range []*cmd.Node{node1, node2, node3, node4, node5} { + waitForMinNonValidatorCountRemoteNodeIndexer(t, validator, 6, 4, 0, 2) + } + // NOOP Verify connection counts of inactive validators. + // Verify connection counts of non-validators. + waitForCountRemoteNodeIndexer(t, node6, 5, 5, 0, 0) + waitForCountRemoteNodeIndexer(t, node7, 5, 5, 0, 0) + t.Logf("Test #1 passed | Successfully run validators node1, node2, node3, node4, node5; non-validators node6, node7") + + // Remove node1 and node2 from the validator set. + activeValidatorsMap = getActiveValidatorsMapWithValidatorNodes(t, node3, node4, node5) + setActiveValidators(activeValidatorsMap, node1, node2, node3, node4, node5, node6, node7) + // Verify full graph between active validators. + waitForValidatorFullGraph(t, node3, node4, node5) + // Verify connections of non-validators. + for _, nonValidator := range []*cmd.Node{node1, node2, node6, node7} { + waitForValidatorConnectionOneWay(t, nonValidator, node3, node4, node5) + } + // Verify connections of initial validators. + for _, validator := range []*cmd.Node{node1, node2, node3, node4, node5} { + waitForNonValidatorInboundConnection(t, validator, node6) + waitForNonValidatorInboundConnection(t, validator, node7) + } + // Verify connections of active validators. + for _, validator := range []*cmd.Node{node3, node4, node5} { + waitForNonValidatorInboundXOROutboundConnection(t, validator, node1) + waitForNonValidatorInboundXOROutboundConnection(t, validator, node2) + waitForMinNonValidatorCountRemoteNodeIndexer(t, validator, 6, 2, 0, 2) + } + // Verify connection counts of inactive validators. + for _, validator := range []*cmd.Node{node1, node2} { + waitForMinNonValidatorCountRemoteNodeIndexer(t, validator, 6, 3, 0, 2) + } + // Verify connection counts of non-validators. + waitForCountRemoteNodeIndexerHandshakeCompleted(t, node6, 5, 3, 2, 0) + waitForCountRemoteNodeIndexerHandshakeCompleted(t, node7, 5, 3, 2, 0) + t.Logf("Test #2 passed | Successfully run validators node3, node4, node5; inactive-validators node1, node2; " + + "non-validators node6, node7") + + // Remove node3 from the validator set. Make node1 active again. + activeValidatorsMap = getActiveValidatorsMapWithValidatorNodes(t, node1, node4, node5) + setActiveValidators(activeValidatorsMap, node1, node2, node3, node4, node5, node6, node7) + // Verify full graph between active validators. + waitForValidatorFullGraph(t, node1, node4, node5) + // Verify connections of non-validators. + for _, nonValidator := range []*cmd.Node{node2, node3, node6, node7} { + waitForValidatorConnectionOneWay(t, nonValidator, node1, node4, node5) + } + // Verify connections of initial validators. + for _, validator := range []*cmd.Node{node1, node2, node3, node4, node5} { + waitForNonValidatorInboundConnection(t, validator, node6) + waitForNonValidatorInboundConnection(t, validator, node7) + } + // Verify connections of active validators. + for _, validator := range []*cmd.Node{node1, node4, node5} { + waitForNonValidatorInboundXOROutboundConnection(t, validator, node2) + waitForNonValidatorInboundXOROutboundConnection(t, validator, node3) + waitForMinNonValidatorCountRemoteNodeIndexer(t, validator, 6, 2, 0, 2) + } + // Verify connection counts of inactive validators. 
+ for _, validator := range []*cmd.Node{node2, node3} { + waitForMinNonValidatorCountRemoteNodeIndexer(t, validator, 6, 3, 0, 2) + } + // Verify connection counts of non-validators. + waitForCountRemoteNodeIndexerHandshakeCompleted(t, node6, 5, 3, 2, 0) + waitForCountRemoteNodeIndexerHandshakeCompleted(t, node7, 5, 3, 2, 0) + t.Logf("Test #3 passed | Successfully run validators node1, node4, node5; inactive validators node2, node3; " + + "non-validators node6, node7") + + // Make all validators inactive. + activeValidatorsMap = getActiveValidatorsMapWithValidatorNodes(t) + setActiveValidators(activeValidatorsMap, node1, node2, node3, node4, node5, node6, node7) + // NOOP Verify full graph between active validators. + // NOOP Verify connections of non-validators. + // Verify connections of initial validators. + for _, validator := range []*cmd.Node{node1, node2, node3, node4, node5} { + waitForNonValidatorInboundConnection(t, validator, node6) + waitForNonValidatorInboundConnection(t, validator, node7) + } + // NOOP Verify connections of active validators. + // Verify connections and counts of inactive validators. + inactiveValidators := []*cmd.Node{node1, node2, node3, node4, node5} + for ii := 0; ii < len(inactiveValidators); ii++ { + for jj := ii + 1; jj < len(inactiveValidators); jj++ { + waitForNonValidatorInboundXOROutboundConnection(t, inactiveValidators[ii], inactiveValidators[jj]) + } + } + inactiveValidatorsRev := []*cmd.Node{node5, node4, node3, node2, node1} + for ii := 0; ii < len(inactiveValidatorsRev); ii++ { + for jj := ii + 1; jj < len(inactiveValidatorsRev); jj++ { + waitForNonValidatorInboundXOROutboundConnection(t, inactiveValidatorsRev[ii], inactiveValidatorsRev[jj]) + } + } + for _, validator := range inactiveValidators { + waitForMinNonValidatorCountRemoteNodeIndexer(t, validator, 6, 0, 0, 2) + } + // Verify connection counts of non-validators. + waitForCountRemoteNodeIndexerHandshakeCompleted(t, node6, 5, 0, 5, 0) + waitForCountRemoteNodeIndexerHandshakeCompleted(t, node7, 5, 0, 5, 0) + t.Logf("Test #4 passed | Successfully run inactive validators node1, node2, node3, node4, node5; " + + "non-validators node6, node7") +} + +func TestConnectionControllerValidatorInboundDeduplication(t *testing.T) { + // Spawn a non-validator node1, and two validators node2, node3. The validator nodes will have the same public key. + // Node2 and node3 will not initially be in the validator set. First, node2 will start an outbound connection to + // node1. We wait until the node2 is re-indexed as non-validator by node1, and then we make node3 open an outbound + // connection to node1. We wait until node3 is re-indexed as non-validator by node1. Then, we make node2 and node3 + // join the validator set (i.e. add one entry with the duplicated public key). Now, node1 should disconnect from + // either node2 or node3 because of duplicate public key. + + node1 := spawnNonValidatorNodeProtocol2(t, 18000, "node1") + blsSeedPhrase2, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node2 := spawnValidatorNodeProtocol2(t, 18001, "node2", blsSeedPhrase2) + node3 := spawnValidatorNodeProtocol2(t, 18002, "node3", blsSeedPhrase2) + + node1 = startNode(t, node1) + node2 = startNode(t, node2) + node3 = startNode(t, node3) + + nm2 := node2.Server.GetNetworkManager() + require.NoError(t, nm2.CreateNonValidatorOutboundConnection(node1.Listeners[0].Addr().String())) + // First wait for node2 to be indexed as a validator by node1. 
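+ // node2 advertises itself as a validator during the handshake but is not yet in node1's active validator
+ // set, so node1 initially indexes it as a validator and then re-indexes it as a non-validator.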
+ waitForValidatorConnection(t, node1, node2) + // Now wait for node2 to be re-indexed as a non-validator. + waitForNonValidatorInboundConnectionDynamic(t, node1, node2, true) + waitForNonValidatorOutboundConnection(t, node2, node1) + + // Now connect node3 to node1. + nm3 := node3.Server.GetNetworkManager() + require.NoError(t, nm3.CreateNonValidatorOutboundConnection(node1.Listeners[0].Addr().String())) + // First wait for node3 to be indexed as a validator by node1. + waitForValidatorConnection(t, node1, node3) + // Now wait for node3 to be re-indexed as a non-validator. + waitForNonValidatorInboundConnectionDynamic(t, node1, node3, true) + waitForNonValidatorOutboundConnection(t, node3, node1) + + // Now add node2 and node3 to the validator set. + activeValidatorsMap := getActiveValidatorsMapWithValidatorNodes(t, node2) + setActiveValidators(activeValidatorsMap, node1, node2, node3) + // Now wait for node1 to disconnect from either node2 or node3. + waitForCountRemoteNodeIndexerHandshakeCompleted(t, node1, 1, 1, 0, 0) + t.Logf("Test #1 passed | Successfully run non-validator node1; validators node2, node3 with duplicate public key") +} + +func TestConnectionControllerNonValidatorConnectorOutbound(t *testing.T) { + // Spawn 6 non-validators node1, node2, node3, node4, node5, node6. Set node1's targetOutboundPeers to 3. Then make + // node1 create persistent outbound connections to node2, node3, and node4, as well as non-validator connections to + // node5 and node6. + node1 := spawnNonValidatorNodeProtocol2(t, 18000, "node1") + node1.Config.TargetOutboundPeers = 0 + node2 := spawnNonValidatorNodeProtocol2(t, 18001, "node2") + node3 := spawnNonValidatorNodeProtocol2(t, 18002, "node3") + node4 := spawnNonValidatorNodeProtocol2(t, 18003, "node4") + node5 := spawnNonValidatorNodeProtocol2(t, 18004, "node5") + node6 := spawnNonValidatorNodeProtocol2(t, 18005, "node6") + + node2 = startNode(t, node2) + node3 = startNode(t, node3) + node4 = startNode(t, node4) + node5 = startNode(t, node5) + node6 = startNode(t, node6) + + node1.Config.ConnectIPs = []string{ + node2.Listeners[0].Addr().String(), + node3.Listeners[0].Addr().String(), + node4.Listeners[0].Addr().String(), + } + node1 = startNode(t, node1) + + nm := node1.Server.GetNetworkManager() + require.NoError(t, nm.CreateNonValidatorOutboundConnection(node5.Listeners[0].Addr().String())) + require.NoError(t, nm.CreateNonValidatorOutboundConnection(node6.Listeners[0].Addr().String())) + + waitForCountRemoteNodeIndexerHandshakeCompleted(t, node1, 3, 0, 3, 0) + waitForNonValidatorOutboundConnection(t, node1, node2) + waitForNonValidatorOutboundConnection(t, node1, node3) + waitForNonValidatorOutboundConnection(t, node1, node4) +} + +func TestConnectionControllerNonValidatorConnectorInbound(t *testing.T) { + // Spawn validators node1, node2, node3, node4, node5, node6. Also spawn non-validators node7, node8, node9, node10. + // Set node1's targetOutboundPeers to 0 and targetInboundPeers to 1. Then make node1 create outbound connections to + // node2, node3, and make node4, node5, node6 create inbound connections to node1. Then make node1 create outbound + // connections to node7, node8, and make node9, node10 create inbound connections to node1. 
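+ // With TargetOutboundPeers set to 0 and MaxInboundPeers set to 1, node1 is expected to end up with its
+ // five validator connections plus a single non-validator inbound connection (6 remote nodes in total).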
+ blsSeedPhrase1, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node1 := spawnValidatorNodeProtocol2(t, 18000, "node1", blsSeedPhrase1) + node1.Config.TargetOutboundPeers = 0 + node1.Config.MaxInboundPeers = 1 + node1.Params.DialTimeout = 1 * time.Second + node1.Params.VerackNegotiationTimeout = 1 * time.Second + node1.Params.VersionNegotiationTimeout = 1 * time.Second + + blsSeedPhrase2, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node2 := spawnValidatorNodeProtocol2(t, 18001, "node2", blsSeedPhrase2) + node2.Config.GlogV = 0 + blsSeedPhrase3, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node3 := spawnValidatorNodeProtocol2(t, 18002, "node3", blsSeedPhrase3) + node3.Config.GlogV = 0 + blsSeedPhrase4, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node4 := spawnValidatorNodeProtocol2(t, 18003, "node4", blsSeedPhrase4) + node4.Config.GlogV = 0 + blsSeedPhrase5, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node5 := spawnValidatorNodeProtocol2(t, 18004, "node5", blsSeedPhrase5) + node5.Config.GlogV = 0 + blsSeedPhrase6, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node6 := spawnValidatorNodeProtocol2(t, 18005, "node6", blsSeedPhrase6) + node6.Config.GlogV = 0 + + node7 := spawnNonValidatorNodeProtocol2(t, 18006, "node7") + node8 := spawnNonValidatorNodeProtocol2(t, 18007, "node8") + node9 := spawnNonValidatorNodeProtocol2(t, 18008, "node9") + node10 := spawnNonValidatorNodeProtocol2(t, 18009, "node10") + + node1 = startNode(t, node1) + node2 = startNode(t, node2) + node3 = startNode(t, node3) + node4 = startNode(t, node4) + node5 = startNode(t, node5) + node6 = startNode(t, node6) + node7 = startNode(t, node7) + node8 = startNode(t, node8) + node9 = startNode(t, node9) + node10 = startNode(t, node10) + + // Connect node1 to node2, node3, node7, and node8. + nm1 := node1.Server.GetNetworkManager() + require.NoError(t, nm1.CreateNonValidatorOutboundConnection(node2.Listeners[0].Addr().String())) + require.NoError(t, nm1.CreateNonValidatorOutboundConnection(node3.Listeners[0].Addr().String())) + require.NoError(t, nm1.CreateNonValidatorOutboundConnection(node7.Listeners[0].Addr().String())) + require.NoError(t, nm1.CreateNonValidatorOutboundConnection(node8.Listeners[0].Addr().String())) + // Connect node4, node5, node6 to node1. + nm4 := node4.Server.GetNetworkManager() + require.NoError(t, nm4.CreateNonValidatorOutboundConnection(node1.Listeners[0].Addr().String())) + nm5 := node5.Server.GetNetworkManager() + require.NoError(t, nm5.CreateNonValidatorOutboundConnection(node1.Listeners[0].Addr().String())) + nm6 := node6.Server.GetNetworkManager() + require.NoError(t, nm6.CreateNonValidatorOutboundConnection(node1.Listeners[0].Addr().String())) + + // Connect node9, node10 to node1. 
+ nm9 := node9.Server.GetNetworkManager() + require.NoError(t, nm9.CreateNonValidatorOutboundConnection(node1.Listeners[0].Addr().String())) + nm10 := node10.Server.GetNetworkManager() + require.NoError(t, nm10.CreateNonValidatorOutboundConnection(node1.Listeners[0].Addr().String())) + + activeValidatorsMap := getActiveValidatorsMapWithValidatorNodes(t, node1, node2, node3, node4, node5, node6) + setActiveValidators(activeValidatorsMap, node1, node2, node3, node4, node5, node6, node7, node8, node9, node10) + + waitForValidatorConnection(t, node1, node2) + waitForValidatorConnection(t, node1, node3) + waitForValidatorConnection(t, node1, node4) + waitForValidatorConnection(t, node1, node5) + waitForValidatorConnection(t, node1, node6) + waitForCountRemoteNodeIndexerHandshakeCompleted(t, node1, 6, 5, 0, 1) +} + +func TestConnectionControllerNonValidatorConnectorAddressMgr(t *testing.T) { + // Spawn a non-validator node1. Set node1's targetOutboundPeers to 1 and targetInboundPeers to 0. Then + // add one ip address to AddrMgr. Make sure that node1 creates outbound connections to this node. + node1 := spawnNodeProtocol1(t, 18000, "node1") + node1.Config.TargetOutboundPeers = 1 + node1.Config.MaxInboundPeers = 0 + node1.Config.MaxSyncBlockHeight = 1 + + node1 = startNode(t, node1) + nm := node1.Server.GetNetworkManager() + na1, err := nm.ConvertIPStringToNetAddress("deso-seed-2.io:17000") + require.NoError(t, err) + nm.AddrMgr.AddAddress(na1, na1) + waitForCountRemoteNodeIndexerHandshakeCompleted(t, node1, 1, 0, 1, 0) +} + +func TestConnectionControllerNonValidatorConnectorAddIps(t *testing.T) { + // Spawn a non-validator node1. Set node1's targetOutboundPeers to 2 and targetInboundPeers to 0. Then + // add two ip addresses to AddIps. Make sure that node1 creates outbound connections to these nodes. + node1 := spawnNodeProtocol1(t, 18000, "node1") + node1.Config.TargetOutboundPeers = 2 + node1.Config.MaxInboundPeers = 0 + node1.Config.MaxSyncBlockHeight = 1 + node1.Config.AddIPs = []string{"deso-seed-2.io", "deso-seed-3.io"} + + node1 = startNode(t, node1) + waitForCountRemoteNodeIndexer(t, node1, 2, 0, 2, 0) +} + +func getActiveValidatorsMapWithValidatorNodes(t *testing.T, validators ...*cmd.Node) *collections.ConcurrentMap[bls.SerializedPublicKey, consensus.Validator] { + mapping := collections.NewConcurrentMap[bls.SerializedPublicKey, consensus.Validator]() + for _, validator := range validators { + seed := validator.Config.PosValidatorSeed + if seed == "" { + t.Fatalf("Validator node %s does not have a PosValidatorSeed set", validator.Params.UserAgent) + } + keystore, err := lib.NewBLSKeystore(seed) + require.NoError(t, err) + mapping.Set(keystore.GetSigner().GetPublicKey().Serialize(), createSimpleValidatorEntry(validator)) + } + return mapping +} + +func setActiveValidators(validatorMap *collections.ConcurrentMap[bls.SerializedPublicKey, consensus.Validator], nodes ...*cmd.Node) { + for _, node := range nodes { + node.Server.GetNetworkManager().SetActiveValidatorsMap(validatorMap) + } +} + +func createSimpleValidatorEntry(node *cmd.Node) *lib.ValidatorEntry { + return &lib.ValidatorEntry{ + Domains: [][]byte{[]byte(node.Listeners[0].Addr().String())}, + } +} + +func waitForValidatorFullGraph(t *testing.T, validators ...*cmd.Node) { + for ii := 0; ii < len(validators); ii++ { + waitForValidatorConnectionOneWay(t, validators[ii], validators[ii+1:]...) 
+ } +} + +func waitForValidatorConnectionOneWay(t *testing.T, n *cmd.Node, validators ...*cmd.Node) { + if len(validators) == 0 { + return + } + for _, validator := range validators { + waitForValidatorConnection(t, n, validator) + } +} + +func waitForNonValidatorInboundXOROutboundConnection(t *testing.T, node1 *cmd.Node, node2 *cmd.Node) { + userAgentN1 := node1.Params.UserAgent + userAgentN2 := node2.Params.UserAgent + conditionInbound := conditionNonValidatorInboundConnectionDynamic(t, node1, node2, true) + conditionOutbound := conditionNonValidatorOutboundConnectionDynamic(t, node1, node2, true) + xorCondition := func() bool { + return conditionInbound() != conditionOutbound() + } + waitForCondition(t, fmt.Sprintf("Waiting for Node (%s) to connect to inbound XOR outbound non-validator Node (%s)", + userAgentN1, userAgentN2), xorCondition) +} + +func waitForMinNonValidatorCountRemoteNodeIndexer(t *testing.T, node *cmd.Node, allCount int, validatorCount int, + minNonValidatorOutboundCount int, minNonValidatorInboundCount int) { + + userAgent := node.Params.UserAgent + nm := node.Server.GetNetworkManager() + condition := func() bool { + return checkRemoteNodeIndexerMinNonValidatorCount(nm, allCount, validatorCount, + minNonValidatorOutboundCount, minNonValidatorInboundCount) + } + waitForCondition(t, fmt.Sprintf("Waiting for Node (%s) to have at least %d non-validator outbound nodes and %d non-validator inbound nodes", + userAgent, minNonValidatorOutboundCount, minNonValidatorInboundCount), condition) +} + +func checkRemoteNodeIndexerMinNonValidatorCount(manager *lib.NetworkManager, allCount int, validatorCount int, + minNonValidatorOutboundCount int, minNonValidatorInboundCount int) bool { + + if allCount != manager.GetAllRemoteNodes().Count() { + return false + } + if validatorCount != manager.GetValidatorIndex().Count() { + return false + } + if minNonValidatorOutboundCount > manager.GetNonValidatorOutboundIndex().Count() { + return false + } + if minNonValidatorInboundCount > manager.GetNonValidatorInboundIndex().Count() { + return false + } + if allCount != manager.GetValidatorIndex().Count()+ + manager.GetNonValidatorOutboundIndex().Count()+ + manager.GetNonValidatorInboundIndex().Count() { + return false + } + return true +} diff --git a/integration_testing/network_manager_test.go b/integration_testing/network_manager_test.go new file mode 100644 index 000000000..8c883b973 --- /dev/null +++ b/integration_testing/network_manager_test.go @@ -0,0 +1,466 @@ +package integration_testing + +import ( + "github.com/deso-protocol/core/bls" + "github.com/deso-protocol/core/lib" + "github.com/stretchr/testify/require" + "github.com/tyler-smith/go-bip39" + "testing" +) + +func TestConnectionControllerNonValidator(t *testing.T) { + node1 := spawnNonValidatorNodeProtocol2(t, 18000, "node1") + node1.Params.DisableNetworkManagerRoutines = true + node1 = startNode(t, node1) + + // Make sure NonValidator Node1 can create an outbound connection to NonValidator Node2 + node2 := spawnNonValidatorNodeProtocol2(t, 18001, "node2") + node2.Params.DisableNetworkManagerRoutines = true + node2 = startNode(t, node2) + + nm := node1.Server.GetNetworkManager() + require.NoError(t, nm.CreateNonValidatorOutboundConnection(node2.Listeners[0].Addr().String())) + waitForNonValidatorOutboundConnection(t, node1, node2) + waitForNonValidatorInboundConnection(t, node2, node1) + + node2.Stop() + waitForEmptyRemoteNodeIndexer(t, node1) + t.Logf("Test #1 passed | Successfully created outbound connection from NonValidator 
Node1 to NonValidator Node2") + + // Make sure NonValidator Node1 can create an outbound connection to validator Node3 + blsSeedPhrase3, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node3 := spawnValidatorNodeProtocol2(t, 18002, "node3", blsSeedPhrase3) + node3.Params.DisableNetworkManagerRoutines = true + node3 = startNode(t, node3) + + nm = node1.Server.GetNetworkManager() + require.NoError(t, nm.CreateNonValidatorOutboundConnection(node3.Listeners[0].Addr().String())) + waitForValidatorConnection(t, node1, node3) + waitForNonValidatorInboundConnection(t, node3, node1) + + node3.Stop() + waitForEmptyRemoteNodeIndexer(t, node1) + t.Logf("Test #2 passed | Successfully created outbound connection from NonValidator Node1 to Validator Node3") + + // Make sure NonValidator Node1 can create a non-validator connection to validator Node4 + blsSeedPhrase4, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node4 := spawnValidatorNodeProtocol2(t, 18003, "node4", blsSeedPhrase4) + node4.Params.DisableNetworkManagerRoutines = true + node4 = startNode(t, node4) + + nm = node1.Server.GetNetworkManager() + require.NoError(t, nm.CreateNonValidatorOutboundConnection(node4.Listeners[0].Addr().String())) + waitForValidatorConnection(t, node1, node4) + waitForNonValidatorInboundConnection(t, node4, node1) + t.Logf("Test #3 passed | Successfully created outbound connection from NonValidator Node1 to Validator Node4") +} + +func TestConnectionControllerValidator(t *testing.T) { + blsSeedPhrase1, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node1 := spawnValidatorNodeProtocol2(t, 18000, "node1", blsSeedPhrase1) + node1.Params.DisableNetworkManagerRoutines = true + node1 = startNode(t, node1) + + // Make sure Validator Node1 can create an outbound connection to Validator Node2 + blsSeedPhrase2, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + blsKeyStore2, err := lib.NewBLSKeystore(blsSeedPhrase2) + require.NoError(t, err) + blsPub2 := blsKeyStore2.GetSigner().GetPublicKey() + node2 := spawnValidatorNodeProtocol2(t, 18001, "node2", blsSeedPhrase2) + node2.Params.DisableNetworkManagerRoutines = true + node2 = startNode(t, node2) + + nm := node1.Server.GetNetworkManager() + require.NoError(t, nm.CreateValidatorConnection(node2.Listeners[0].Addr().String(), blsPub2)) + waitForValidatorConnection(t, node1, node2) + waitForValidatorConnection(t, node2, node1) + + node2.Stop() + waitForEmptyRemoteNodeIndexer(t, node1) + t.Logf("Test #1 passed | Successfully created outbound connection from Validator Node1 to Validator Node2") + + // Make sure Validator Node1 can create an outbound connection to NonValidator Node3 + node3 := spawnNonValidatorNodeProtocol2(t, 18002, "node3") + node3.Params.DisableNetworkManagerRoutines = true + node3 = startNode(t, node3) + + nm = node1.Server.GetNetworkManager() + require.NoError(t, nm.CreateNonValidatorOutboundConnection(node3.Listeners[0].Addr().String())) + waitForNonValidatorOutboundConnection(t, node1, node3) + waitForValidatorConnection(t, node3, node1) + + node3.Stop() + waitForEmptyRemoteNodeIndexer(t, node1) + t.Logf("Test #2 passed | Successfully created outbound connection from Validator Node1 to NonValidator Node3") + + // Make sure Validator Node1 can create an outbound non-validator connection to Validator Node4 + blsSeedPhrase4, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node4 := spawnValidatorNodeProtocol2(t, 18003, "node4", 
blsSeedPhrase4) + node4.Params.DisableNetworkManagerRoutines = true + node4 = startNode(t, node4) + + nm = node1.Server.GetNetworkManager() + require.NoError(t, nm.CreateNonValidatorOutboundConnection(node4.Listeners[0].Addr().String())) + waitForValidatorConnection(t, node1, node4) + waitForValidatorConnection(t, node4, node1) + t.Logf("Test #3 passed | Successfully created non-validator outbound connection from Validator Node1 to Validator Node4") +} + +func TestConnectionControllerHandshakeDataErrors(t *testing.T) { + blsSeedPhrase1, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node1 := spawnValidatorNodeProtocol2(t, 18000, "node1", blsSeedPhrase1) + node1.Params.DisableNetworkManagerRoutines = true + + // This node should have ProtocolVersion2, but it has ProtocolVersion1 as we want it to disconnect. + blsSeedPhrase2, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node2 := spawnValidatorNodeProtocol2(t, 18001, "node2", blsSeedPhrase2) + node2.Params.DisableNetworkManagerRoutines = true + node2.Params.ProtocolVersion = lib.ProtocolVersion1 + + node1 = startNode(t, node1) + node2 = startNode(t, node2) + + nm := node2.Server.GetNetworkManager() + require.NoError(t, nm.CreateNonValidatorOutboundConnection(node1.Listeners[0].Addr().String())) + waitForEmptyRemoteNodeIndexer(t, node1) + waitForEmptyRemoteNodeIndexer(t, node2) + t.Logf("Test #1 passed | Successfuly disconnected node with SFValidator flag and ProtocolVersion1 mismatch") + + // This node shouldn't have ProtocolVersion3, which is beyond latest ProtocolVersion2, meaning nodes should disconnect. + blsSeedPhrase3, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node3 := spawnValidatorNodeProtocol2(t, 18002, "node3", blsSeedPhrase3) + node3.Params.DisableNetworkManagerRoutines = true + node3.Params.ProtocolVersion = lib.ProtocolVersionType(3) + node3 = startNode(t, node3) + + nm = node1.Server.GetNetworkManager() + require.NoError(t, nm.CreateNonValidatorOutboundConnection(node3.Listeners[0].Addr().String())) + waitForEmptyRemoteNodeIndexer(t, node1) + waitForEmptyRemoteNodeIndexer(t, node3) + t.Logf("Test #2 passed | Successfuly disconnected node with ProtocolVersion3") + + // This node shouldn't have ProtocolVersion0, which is outdated. + node4 := spawnNonValidatorNodeProtocol2(t, 18003, "node4") + node4.Params.DisableNetworkManagerRoutines = true + node4.Params.ProtocolVersion = lib.ProtocolVersion0 + node4 = startNode(t, node4) + + nm = node1.Server.GetNetworkManager() + require.NoError(t, nm.CreateNonValidatorOutboundConnection(node4.Listeners[0].Addr().String())) + waitForEmptyRemoteNodeIndexer(t, node1) + waitForEmptyRemoteNodeIndexer(t, node4) + t.Logf("Test #3 passed | Successfuly disconnected node with ProtocolVersion0") + + // This node will have a different public key than the one it's supposed to have. 
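+ // node1 will dial node5 expecting blsKeyStore5Wrong's public key, so the handshake should fail the key
+ // check and both nodes should end up with empty RemoteNode indexes.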
+ blsSeedPhrase5, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + blsSeedPhrase5Wrong, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + blsKeyStore5Wrong, err := lib.NewBLSKeystore(blsSeedPhrase5Wrong) + require.NoError(t, err) + node5 := spawnValidatorNodeProtocol2(t, 18004, "node5", blsSeedPhrase5) + node5.Params.DisableNetworkManagerRoutines = true + node5 = startNode(t, node5) + + nm = node1.Server.GetNetworkManager() + require.NoError(t, nm.CreateValidatorConnection(node5.Listeners[0].Addr().String(), blsKeyStore5Wrong.GetSigner().GetPublicKey())) + waitForEmptyRemoteNodeIndexer(t, node1) + waitForEmptyRemoteNodeIndexer(t, node5) + t.Logf("Test #4 passed | Successfuly disconnected node with public key mismatch") + + // This node will be missing SFPosValidator flag while being connected as a validator. + blsPriv6, err := bls.NewPrivateKey() + require.NoError(t, err) + node6 := spawnNonValidatorNodeProtocol2(t, 18005, "node6") + node6.Params.DisableNetworkManagerRoutines = true + node6 = startNode(t, node6) + + nm = node1.Server.GetNetworkManager() + require.NoError(t, nm.CreateValidatorConnection(node6.Listeners[0].Addr().String(), blsPriv6.PublicKey())) + waitForEmptyRemoteNodeIndexer(t, node1) + waitForEmptyRemoteNodeIndexer(t, node6) + t.Logf("Test #5 passed | Successfuly disconnected supposed validator node with missing SFPosValidator flag") + + // This node will have ProtocolVersion1 and be connected as an outbound non-validator node. + node7 := spawnNonValidatorNodeProtocol2(t, 18006, "node7") + node7.Params.DisableNetworkManagerRoutines = true + node7.Params.ProtocolVersion = lib.ProtocolVersion1 + node7 = startNode(t, node7) + + nm = node1.Server.GetNetworkManager() + require.NoError(t, nm.CreateNonValidatorOutboundConnection(node7.Listeners[0].Addr().String())) + waitForEmptyRemoteNodeIndexer(t, node1) + waitForEmptyRemoteNodeIndexer(t, node7) + t.Logf("Test #6 passed | Successfuly disconnected outbound non-validator node with ProtocolVersion1") +} + +func TestConnectionControllerHandshakeTimeouts(t *testing.T) { + // Set version negotiation timeout to 0 to make sure that the node will be disconnected + node1 := spawnNonValidatorNodeProtocol2(t, 18000, "node1") + node1.Params.DisableNetworkManagerRoutines = true + node1.Params.VersionNegotiationTimeout = 0 + node1 = startNode(t, node1) + + node2 := spawnNonValidatorNodeProtocol2(t, 18001, "node2") + node2.Params.DisableNetworkManagerRoutines = true + node2 = startNode(t, node2) + + nm := node1.Server.GetNetworkManager() + require.NoError(t, nm.CreateNonValidatorOutboundConnection(node2.Listeners[0].Addr().String())) + waitForEmptyRemoteNodeIndexer(t, node1) + waitForEmptyRemoteNodeIndexer(t, node2) + t.Logf("Test #1 passed | Successfuly disconnected node after version negotiation timeout") + + // Now let's try timing out verack exchange + node1.Params.VersionNegotiationTimeout = lib.DeSoTestnetParams.VersionNegotiationTimeout + node3 := spawnNonValidatorNodeProtocol2(t, 18002, "node3") + node3.Params.DisableNetworkManagerRoutines = true + node3.Params.VerackNegotiationTimeout = 0 + node3 = startNode(t, node3) + + nm = node3.Server.GetNetworkManager() + require.NoError(t, nm.CreateNonValidatorOutboundConnection(node1.Listeners[0].Addr().String())) + waitForEmptyRemoteNodeIndexer(t, node1) + waitForEmptyRemoteNodeIndexer(t, node3) + t.Logf("Test #2 passed | Successfuly disconnected node after verack exchange timeout") + + // Now let's try timing out handshake between two 
validators node4 and node5 + blsSeedPhrase4, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node4 := spawnValidatorNodeProtocol2(t, 18003, "node4", blsSeedPhrase4) + node4.Params.DisableNetworkManagerRoutines = true + node4.Params.HandshakeTimeoutMicroSeconds = 0 + node4 = startNode(t, node4) + + blsSeedPhrase5, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + blsKeyStore5, err := lib.NewBLSKeystore(blsSeedPhrase5) + require.NoError(t, err) + node5 := spawnValidatorNodeProtocol2(t, 18004, "node5", blsSeedPhrase5) + node5.Params.DisableNetworkManagerRoutines = true + node5 = startNode(t, node5) + + nm = node4.Server.GetNetworkManager() + require.NoError(t, nm.CreateValidatorConnection(node5.Listeners[0].Addr().String(), blsKeyStore5.GetSigner().GetPublicKey())) + waitForEmptyRemoteNodeIndexer(t, node4) + waitForEmptyRemoteNodeIndexer(t, node5) + t.Logf("Test #3 passed | Successfuly disconnected validator node after handshake timeout") +} + +func TestConnectionControllerValidatorDuplication(t *testing.T) { + node1 := spawnNonValidatorNodeProtocol2(t, 18000, "node1") + node1.Params.DisableNetworkManagerRoutines = true + node1 = startNode(t, node1) + + // Create a validator Node2 + blsSeedPhrase2, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + blsKeyStore2, err := lib.NewBLSKeystore(blsSeedPhrase2) + require.NoError(t, err) + node2 := spawnValidatorNodeProtocol2(t, 18001, "node2", blsSeedPhrase2) + node2.Params.DisableNetworkManagerRoutines = true + node2 = startNode(t, node2) + + // Create a duplicate validator Node3 + node3 := spawnValidatorNodeProtocol2(t, 18002, "node3", blsSeedPhrase2) + node3.Params.DisableNetworkManagerRoutines = true + node3 = startNode(t, node3) + + // Create validator connection from Node1 to Node2 and from Node1 to Node3 + nm := node1.Server.GetNetworkManager() + require.NoError(t, nm.CreateValidatorConnection(node2.Listeners[0].Addr().String(), blsKeyStore2.GetSigner().GetPublicKey())) + // This should fail out right because Node3 has a duplicate public key. + require.Error(t, nm.CreateValidatorConnection(node3.Listeners[0].Addr().String(), blsKeyStore2.GetSigner().GetPublicKey())) + waitForValidatorConnection(t, node1, node2) + waitForNonValidatorInboundConnection(t, node2, node1) + + // Now create an outbound connection from Node3 to Node1, which should pass handshake, but then fail because + // Node1 already has a validator connection to Node2 with the same public key. 
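+ // node3's outbound connection should therefore be torn down after the handshake, leaving node1 with only
+ // its validator connection to node2.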
+ nm3 := node3.Server.GetNetworkManager() + require.NoError(t, nm3.CreateNonValidatorOutboundConnection(node1.Listeners[0].Addr().String())) + waitForEmptyRemoteNodeIndexer(t, node3) + waitForCountRemoteNodeIndexer(t, node1, 1, 1, 0, 0) + t.Logf("Test #1 passed | Successfuly rejected duplicate validator connection with inbound/outbound validators") + + node3.Stop() + node2.Stop() + waitForEmptyRemoteNodeIndexer(t, node1) + + // Create two more validators Node4, Node5 with duplicate public keys + blsSeedPhrase4, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node4 := spawnValidatorNodeProtocol2(t, 18003, "node4", blsSeedPhrase4) + node4.Params.DisableNetworkManagerRoutines = true + node4 = startNode(t, node4) + + node5 := spawnValidatorNodeProtocol2(t, 18004, "node5", blsSeedPhrase4) + node5.Params.DisableNetworkManagerRoutines = true + node5 = startNode(t, node5) + + // Create validator connections from Node4 to Node1 and from Node5 to Node1 + nm4 := node4.Server.GetNetworkManager() + require.NoError(t, nm4.CreateNonValidatorOutboundConnection(node1.Listeners[0].Addr().String())) + waitForValidatorConnection(t, node1, node4) + waitForNonValidatorOutboundConnection(t, node4, node1) + nm5 := node5.Server.GetNetworkManager() + require.NoError(t, nm5.CreateNonValidatorOutboundConnection(node1.Listeners[0].Addr().String())) + waitForEmptyRemoteNodeIndexer(t, node5) + waitForCountRemoteNodeIndexer(t, node1, 1, 1, 0, 0) + t.Logf("Test #2 passed | Successfuly rejected duplicate validator connection with multiple outbound validators") +} + +func TestConnectionControllerProtocolDifference(t *testing.T) { + // Create a ProtocolVersion1 Node1 + node1 := spawnNonValidatorNodeProtocol2(t, 18000, "node1") + node1.Params.DisableNetworkManagerRoutines = true + node1.Params.ProtocolVersion = lib.ProtocolVersion1 + node1 = startNode(t, node1) + + // Create a ProtocolVersion2 NonValidator Node2 + node2 := spawnNonValidatorNodeProtocol2(t, 18001, "node2") + node2.Params.DisableNetworkManagerRoutines = true + node2 = startNode(t, node2) + + // Create non-validator connection from Node1 to Node2 + nm := node1.Server.GetNetworkManager() + require.NoError(t, nm.CreateNonValidatorOutboundConnection(node2.Listeners[0].Addr().String())) + waitForNonValidatorOutboundConnection(t, node1, node2) + waitForNonValidatorInboundConnection(t, node2, node1) + t.Logf("Test #1 passed | Successfuly connected to a ProtocolVersion1 node with a ProtocolVersion2 non-validator") + + // Create a ProtocolVersion2 Validator Node3 + blsSeedPhrase3, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + blsKeyStore3, err := lib.NewBLSKeystore(blsSeedPhrase3) + require.NoError(t, err) + node3 := spawnValidatorNodeProtocol2(t, 18002, "node3", blsSeedPhrase3) + node3.Params.DisableNetworkManagerRoutines = true + node3 = startNode(t, node3) + + // Create validator connection from Node1 to Node3 + require.NoError(t, nm.CreateValidatorConnection(node3.Listeners[0].Addr().String(), blsKeyStore3.GetSigner().GetPublicKey())) + waitForValidatorConnection(t, node1, node3) + waitForNonValidatorInboundConnection(t, node3, node1) + t.Logf("Test #2 passed | Successfuly connected to a ProtocolVersion1 node with a ProtocolVersion2 validator") + + node2.Stop() + node3.Stop() + waitForEmptyRemoteNodeIndexer(t, node1) + + // Create a ProtocolVersion2 validator Node4 + blsSeedPhrase4, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + blsKeyStore4, err := lib.NewBLSKeystore(blsSeedPhrase4) + 
require.NoError(t, err) + node4 := spawnValidatorNodeProtocol2(t, 18003, "node4", blsSeedPhrase4) + node4.Params.DisableNetworkManagerRoutines = true + node4 = startNode(t, node4) + + // Attempt to create non-validator connection from Node4 to Node1 + nm = node4.Server.GetNetworkManager() + require.NoError(t, nm.CreateNonValidatorOutboundConnection(node1.Listeners[0].Addr().String())) + waitForEmptyRemoteNodeIndexer(t, node4) + waitForEmptyRemoteNodeIndexer(t, node1) + t.Logf("Test #3 passed | Successfuly rejected outbound connection from ProtocolVersion2 node to ProtcolVersion1 node") + + // Attempt to create validator connection from Node4 to Node1 + require.NoError(t, nm.CreateValidatorConnection(node1.Listeners[0].Addr().String(), blsKeyStore4.GetSigner().GetPublicKey())) + waitForEmptyRemoteNodeIndexer(t, node4) + waitForEmptyRemoteNodeIndexer(t, node1) + t.Logf("Test #4 passed | Successfuly rejected validator connection from ProtocolVersion2 node to ProtcolVersion1 node") + + // Create a ProtocolVersion2 non-validator Node5 + node5 := spawnNonValidatorNodeProtocol2(t, 18004, "node5") + node5.Params.DisableNetworkManagerRoutines = true + node5 = startNode(t, node5) + + // Attempt to create non-validator connection from Node5 to Node1 + nm = node5.Server.GetNetworkManager() + require.NoError(t, nm.CreateNonValidatorOutboundConnection(node1.Listeners[0].Addr().String())) + waitForEmptyRemoteNodeIndexer(t, node5) + waitForEmptyRemoteNodeIndexer(t, node1) + t.Logf("Test #5 passed | Successfuly rejected outbound connection from ProtocolVersion2 node to ProtcolVersion1 node") +} + +func TestConnectionControllerPersistentConnection(t *testing.T) { + // Create a NonValidator Node1 + node1 := spawnNonValidatorNodeProtocol2(t, 18000, "node1") + node1.Params.DisableNetworkManagerRoutines = true + node1 = startNode(t, node1) + + // Create a Validator Node2 + blsSeedPhrase2, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node2 := spawnValidatorNodeProtocol2(t, 18001, "node2", blsSeedPhrase2) + node2.Params.DisableNetworkManagerRoutines = true + node2 = startNode(t, node2) + + // Create a persistent connection from Node1 to Node2 + nm := node1.Server.GetNetworkManager() + _, err = nm.CreateNonValidatorPersistentOutboundConnection(node2.Listeners[0].Addr().String()) + require.NoError(t, err) + waitForValidatorConnection(t, node1, node2) + waitForNonValidatorInboundConnection(t, node2, node1) + node2.Stop() + waitForEmptyRemoteNodeIndexer(t, node1) + t.Logf("Test #1 passed | Successfuly created persistent connection from non-validator Node1 to validator Node2") + + // Create a Non-validator Node3 + node3 := spawnNonValidatorNodeProtocol2(t, 18002, "node3") + node3.Params.DisableNetworkManagerRoutines = true + node3 = startNode(t, node3) + + // Create a persistent connection from Node1 to Node3 + _, err = nm.CreateNonValidatorPersistentOutboundConnection(node3.Listeners[0].Addr().String()) + require.NoError(t, err) + waitForNonValidatorOutboundConnection(t, node1, node3) + waitForNonValidatorInboundConnection(t, node3, node1) + node3.Stop() + waitForEmptyRemoteNodeIndexer(t, node1) + node1.Stop() + t.Logf("Test #2 passed | Successfuly created persistent connection from non-validator Node1 to non-validator Node3") + + // Create a Validator Node4 + blsSeedPhrase4, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node4 := spawnValidatorNodeProtocol2(t, 18003, "node4", blsSeedPhrase4) + node4.Params.DisableNetworkManagerRoutines = true + node4 = 
startNode(t, node4) + + // Create a non-validator Node5 + node5 := spawnNonValidatorNodeProtocol2(t, 18004, "node5") + node5.Params.DisableNetworkManagerRoutines = true + node5 = startNode(t, node5) + + // Create a persistent connection from Node4 to Node5 + nm = node4.Server.GetNetworkManager() + _, err = nm.CreateNonValidatorPersistentOutboundConnection(node5.Listeners[0].Addr().String()) + require.NoError(t, err) + waitForNonValidatorOutboundConnection(t, node4, node5) + waitForValidatorConnection(t, node5, node4) + node5.Stop() + waitForEmptyRemoteNodeIndexer(t, node4) + t.Logf("Test #3 passed | Successfuly created persistent connection from validator Node4 to non-validator Node5") + + // Create a Validator Node6 + blsSeedPhrase6, err := bip39.NewMnemonic(lib.RandomBytes(32)) + require.NoError(t, err) + node6 := spawnValidatorNodeProtocol2(t, 18005, "node6", blsSeedPhrase6) + node6.Params.DisableNetworkManagerRoutines = true + node6 = startNode(t, node6) + + // Create a persistent connection from Node4 to Node6 + _, err = nm.CreateNonValidatorPersistentOutboundConnection(node6.Listeners[0].Addr().String()) + require.NoError(t, err) + waitForValidatorConnection(t, node4, node6) + waitForValidatorConnection(t, node6, node4) + t.Logf("Test #4 passed | Successfuly created persistent connection from validator Node4 to validator Node6") +} diff --git a/integration_testing/network_manager_utils_test.go b/integration_testing/network_manager_utils_test.go new file mode 100644 index 000000000..6c1e95010 --- /dev/null +++ b/integration_testing/network_manager_utils_test.go @@ -0,0 +1,294 @@ +package integration_testing + +import ( + "fmt" + "github.com/deso-protocol/core/cmd" + "github.com/deso-protocol/core/lib" + "os" + "testing" +) + +func waitForValidatorConnection(t *testing.T, node1 *cmd.Node, node2 *cmd.Node) { + userAgentN1 := node1.Params.UserAgent + userAgentN2 := node2.Params.UserAgent + nmN1 := node1.Server.GetNetworkManager() + n1ValidatedN2 := func() bool { + if true != checkRemoteNodeIndexerUserAgent(nmN1, userAgentN2, true, false, false) { + return false + } + rnFromN2 := getRemoteNodeWithUserAgent(node1, userAgentN2) + if rnFromN2 == nil { + return false + } + if !rnFromN2.IsHandshakeCompleted() { + return false + } + return true + } + waitForCondition(t, fmt.Sprintf("Waiting for Node (%s) to connect to validator Node (%s)", userAgentN1, userAgentN2), n1ValidatedN2) +} + +func waitForNonValidatorOutboundConnection(t *testing.T, node1 *cmd.Node, node2 *cmd.Node) { + userAgentN1 := node1.Params.UserAgent + userAgentN2 := node2.Params.UserAgent + condition := conditionNonValidatorOutboundConnection(t, node1, node2) + waitForCondition(t, fmt.Sprintf("Waiting for Node (%s) to connect to outbound non-validator Node (%s)", userAgentN1, userAgentN2), condition) +} + +func conditionNonValidatorOutboundConnection(t *testing.T, node1 *cmd.Node, node2 *cmd.Node) func() bool { + return conditionNonValidatorOutboundConnectionDynamic(t, node1, node2, false) +} + +func conditionNonValidatorOutboundConnectionDynamic(t *testing.T, node1 *cmd.Node, node2 *cmd.Node, inactiveValidator bool) func() bool { + userAgentN2 := node2.Params.UserAgent + nmN1 := node1.Server.GetNetworkManager() + return func() bool { + if true != checkRemoteNodeIndexerUserAgent(nmN1, userAgentN2, false, true, false) { + return false + } + rnFromN2 := getRemoteNodeWithUserAgent(node1, userAgentN2) + if rnFromN2 == nil { + return false + } + if !rnFromN2.IsHandshakeCompleted() { + return false + } + // inactiveValidator should 
have the public key. + if inactiveValidator { + return rnFromN2.GetValidatorPublicKey() != nil + } + return rnFromN2.GetValidatorPublicKey() == nil + } +} + +func waitForNonValidatorInboundConnection(t *testing.T, node1 *cmd.Node, node2 *cmd.Node) { + userAgentN1 := node1.Params.UserAgent + userAgentN2 := node2.Params.UserAgent + condition := conditionNonValidatorInboundConnection(t, node1, node2) + waitForCondition(t, fmt.Sprintf("Waiting for Node (%s) to connect to inbound non-validator Node (%s)", userAgentN1, userAgentN2), condition) +} + +func waitForNonValidatorInboundConnectionDynamic(t *testing.T, node1 *cmd.Node, node2 *cmd.Node, inactiveValidator bool) { + userAgentN1 := node1.Params.UserAgent + userAgentN2 := node2.Params.UserAgent + condition := conditionNonValidatorInboundConnectionDynamic(t, node1, node2, inactiveValidator) + waitForCondition(t, fmt.Sprintf("Waiting for Node (%s) to connect to inbound non-validator Node (%s), "+ + "inactiveValidator (%v)", userAgentN1, userAgentN2, inactiveValidator), condition) +} + +func conditionNonValidatorInboundConnection(t *testing.T, node1 *cmd.Node, node2 *cmd.Node) func() bool { + return conditionNonValidatorInboundConnectionDynamic(t, node1, node2, false) +} + +func conditionNonValidatorInboundConnectionDynamic(t *testing.T, node1 *cmd.Node, node2 *cmd.Node, inactiveValidator bool) func() bool { + userAgentN2 := node2.Params.UserAgent + nmN1 := node1.Server.GetNetworkManager() + return func() bool { + if true != checkRemoteNodeIndexerUserAgent(nmN1, userAgentN2, false, false, true) { + return false + } + rnFromN2 := getRemoteNodeWithUserAgent(node1, userAgentN2) + if rnFromN2 == nil { + return false + } + if !rnFromN2.IsHandshakeCompleted() { + return false + } + // inactiveValidator should have the public key. 
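+ // For an inactive validator we expect the BLS public key to be populated on the RemoteNode, while a plain
+ // non-validator should have no validator public key set.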
+ if inactiveValidator { + return rnFromN2.GetValidatorPublicKey() != nil + } + return rnFromN2.GetValidatorPublicKey() == nil + } +} + +func waitForEmptyRemoteNodeIndexer(t *testing.T, node1 *cmd.Node) { + userAgentN1 := node1.Params.UserAgent + nmN1 := node1.Server.GetNetworkManager() + n1ValidatedN2 := func() bool { + if true != checkRemoteNodeIndexerEmpty(nmN1) { + return false + } + return true + } + waitForCondition(t, fmt.Sprintf("Waiting for Node (%s) to disconnect from all RemoteNodes", userAgentN1), n1ValidatedN2) +} + +func waitForCountRemoteNodeIndexer(t *testing.T, node1 *cmd.Node, allCount int, validatorCount int, + nonValidatorOutboundCount int, nonValidatorInboundCount int) { + + userAgent := node1.Params.UserAgent + nm := node1.Server.GetNetworkManager() + condition := func() bool { + if true != checkRemoteNodeIndexerCount(nm, allCount, validatorCount, nonValidatorOutboundCount, nonValidatorInboundCount) { + return false + } + return true + } + waitForCondition(t, fmt.Sprintf("Waiting for Node (%s) to have appropriate RemoteNodes counts", userAgent), condition) +} + +func waitForCountRemoteNodeIndexerHandshakeCompleted(t *testing.T, node1 *cmd.Node, allCount, validatorCount int, + nonValidatorOutboundCount int, nonValidatorInboundCount int) { + + userAgent := node1.Params.UserAgent + nm := node1.Server.GetNetworkManager() + condition := func() bool { + return checkRemoteNodeIndexerCountHandshakeCompleted(nm, allCount, validatorCount, + nonValidatorOutboundCount, nonValidatorInboundCount) + } + waitForCondition(t, fmt.Sprintf("Waiting for Node (%s) to have appropriate RemoteNodes counts", userAgent), condition) +} + +func checkRemoteNodeIndexerUserAgent(manager *lib.NetworkManager, userAgent string, validator bool, + nonValidatorOutbound bool, nonValidatorInbound bool) bool { + + if true != checkUserAgentInRemoteNodeList(userAgent, manager.GetAllRemoteNodes().GetAll()) { + return false + } + if validator != checkUserAgentInRemoteNodeList(userAgent, manager.GetValidatorIndex().GetAll()) { + return false + } + if nonValidatorOutbound != checkUserAgentInRemoteNodeList(userAgent, manager.GetNonValidatorOutboundIndex().GetAll()) { + return false + } + if nonValidatorInbound != checkUserAgentInRemoteNodeList(userAgent, manager.GetNonValidatorInboundIndex().GetAll()) { + return false + } + + return true +} + +func checkRemoteNodeIndexerCount(manager *lib.NetworkManager, allCount int, validatorCount int, + nonValidatorOutboundCount int, nonValidatorInboundCount int) bool { + + if allCount != manager.GetAllRemoteNodes().Count() { + return false + } + if validatorCount != manager.GetValidatorIndex().Count() { + return false + } + if nonValidatorOutboundCount != manager.GetNonValidatorOutboundIndex().Count() { + return false + } + if nonValidatorInboundCount != manager.GetNonValidatorInboundIndex().Count() { + return false + } + + return true +} + +func checkRemoteNodeIndexerCountHandshakeCompleted(manager *lib.NetworkManager, allCount int, validatorCount int, + nonValidatorOutboundCount int, nonValidatorInboundCount int) bool { + + if allCount != manager.GetAllRemoteNodes().Count() { + return false + } + if validatorCount != manager.GetValidatorIndex().Count() { + return false + } + for _, rn := range manager.GetValidatorIndex().GetAll() { + if !rn.IsHandshakeCompleted() { + return false + } + } + + if nonValidatorOutboundCount != manager.GetNonValidatorOutboundIndex().Count() { + return false + } + for _, rn := range manager.GetNonValidatorOutboundIndex().GetAll() { + if 
!rn.IsHandshakeCompleted() { + return false + } + } + + if nonValidatorInboundCount != manager.GetNonValidatorInboundIndex().Count() { + return false + } + for _, rn := range manager.GetNonValidatorInboundIndex().GetAll() { + if !rn.IsHandshakeCompleted() { + return false + } + } + + return true +} + +func checkRemoteNodeIndexerEmpty(manager *lib.NetworkManager) bool { + if manager.GetAllRemoteNodes().Count() != 0 { + return false + } + if manager.GetValidatorIndex().Count() != 0 { + return false + } + if manager.GetNonValidatorOutboundIndex().Count() != 0 { + return false + } + if manager.GetNonValidatorInboundIndex().Count() != 0 { + return false + } + return true +} + +func checkUserAgentInRemoteNodeList(userAgent string, rnList []*lib.RemoteNode) bool { + for _, rn := range rnList { + if rn == nil { + continue + } + if rn.GetUserAgent() == userAgent { + return true + } + } + return false +} + +func getRemoteNodeWithUserAgent(node *cmd.Node, userAgent string) *lib.RemoteNode { + nm := node.Server.GetNetworkManager() + rnList := nm.GetAllRemoteNodes().GetAll() + for _, rn := range rnList { + if rn.GetUserAgent() == userAgent { + return rn + } + } + return nil +} + +func spawnNodeProtocol1(t *testing.T, port uint32, id string) *cmd.Node { + dbDir := getDirectory(t) + t.Cleanup(func() { + os.RemoveAll(dbDir) + }) + config := generateConfig(t, port, dbDir, 10) + config.SyncType = lib.NodeSyncTypeBlockSync + node := cmd.NewNode(config) + node.Params.UserAgent = id + node.Params.ProtocolVersion = lib.ProtocolVersion1 + return node +} + +func spawnNonValidatorNodeProtocol2(t *testing.T, port uint32, id string) *cmd.Node { + dbDir := getDirectory(t) + t.Cleanup(func() { + os.RemoveAll(dbDir) + }) + config := generateConfig(t, port, dbDir, 10) + config.SyncType = lib.NodeSyncTypeBlockSync + node := cmd.NewNode(config) + node.Params.UserAgent = id + node.Params.ProtocolVersion = lib.ProtocolVersion2 + return node +} + +func spawnValidatorNodeProtocol2(t *testing.T, port uint32, id string, blsSeedPhrase string) *cmd.Node { + dbDir := getDirectory(t) + t.Cleanup(func() { + os.RemoveAll(dbDir) + }) + config := generateConfig(t, port, dbDir, 10) + config.SyncType = lib.NodeSyncTypeBlockSync + config.PosValidatorSeed = blsSeedPhrase + node := cmd.NewNode(config) + node.Params.UserAgent = id + node.Params.ProtocolVersion = lib.ProtocolVersion2 + return node +} diff --git a/integration_testing/rollback_test.go b/integration_testing/rollback_test.go index 8028866ac..c7b440b2b 100644 --- a/integration_testing/rollback_test.go +++ b/integration_testing/rollback_test.go @@ -10,7 +10,10 @@ import ( ) // Start blocks to height 5000 and then disconnect +// TODO: This test won't work now. 
func TestStateRollback(t *testing.T) { + t.Skipf("DisconnectBlocksToHeight doesn't work in PoS") + require := require.New(t) _ = require diff --git a/integration_testing/tools.go b/integration_testing/tools.go index c73b82873..43733cbb9 100644 --- a/integration_testing/tools.go +++ b/integration_testing/tools.go @@ -82,6 +82,16 @@ func generateConfig(t *testing.T, port uint32, dataDir string, maxPeers uint32) config.SnapshotBlockHeightPeriod = HyperSyncSnapshotPeriod config.MaxSyncBlockHeight = MaxSyncBlockHeight config.SyncType = lib.NodeSyncTypeBlockSync + config.MempoolBackupIntervalMillis = 30000 + config.MaxMempoolPosSizeBytes = 3000000000 + config.MempoolFeeEstimatorNumMempoolBlocks = 1 + config.MempoolFeeEstimatorNumPastBlocks = 50 + config.MempoolMaxValidationViewConnects = 10000 + config.TransactionValidationRefreshIntervalMillis = 10 + config.AugmentedBlockViewRefreshIntervalMillis = 10 + config.PosBlockProductionIntervalMilliseconds = 1500 + config.PosTimeoutBaseDurationMilliseconds = 30000 + //config.ArchivalMode = true return config @@ -150,7 +160,8 @@ func compareNodesByChecksum(t *testing.T, nodeA *cmd.Node, nodeB *cmd.Node) { // compareNodesByState will look through all state records in nodeA and nodeB databases and will compare them. // The nodes pass this comparison iff they have identical states. func compareNodesByState(t *testing.T, nodeA *cmd.Node, nodeB *cmd.Node, verbose int) { - compareNodesByStateWithPrefixList(t, nodeA.ChainDB, nodeB.ChainDB, lib.StatePrefixes.StatePrefixesList, verbose) + compareNodesByStateWithPrefixList(t, nodeA.Server.GetBlockchain().DB(), nodeB.Server.GetBlockchain().DB(), + lib.StatePrefixes.StatePrefixesList, verbose) } // compareNodesByDB will look through all records in nodeA and nodeB databases and will compare them. @@ -164,7 +175,8 @@ func compareNodesByDB(t *testing.T, nodeA *cmd.Node, nodeB *cmd.Node, verbose in } prefixList = append(prefixList, []byte{prefix}) } - compareNodesByStateWithPrefixList(t, nodeA.ChainDB, nodeB.ChainDB, prefixList, verbose) + compareNodesByStateWithPrefixList(t, nodeA.Server.GetBlockchain().DB(), nodeB.Server.GetBlockchain().DB(), + prefixList, verbose) } // compareNodesByDB will look through all records in nodeA and nodeB txindex databases and will compare them. @@ -386,25 +398,25 @@ func restartNode(t *testing.T, node *cmd.Node) *cmd.Node { } // listenForBlockHeight busy-waits until the node's block tip reaches provided height. -func listenForBlockHeight(t *testing.T, node *cmd.Node, height uint32, signal chan<- bool) { +func listenForBlockHeight(node *cmd.Node, height uint32) (_listener chan bool) { + listener := make(chan bool) ticker := time.NewTicker(1 * time.Millisecond) go func() { for { <-ticker.C if node.Server.GetBlockchain().BlockTip().Height >= height { - signal <- true + listener <- true break } } }() + return listener } // disconnectAtBlockHeight busy-waits until the node's block tip reaches provided height, and then disconnects // from the provided bridge. 
-func disconnectAtBlockHeight(t *testing.T, syncingNode *cmd.Node, bridge *ConnectionBridge, height uint32) { - listener := make(chan bool) - listenForBlockHeight(t, syncingNode, height, listener) - <-listener +func disconnectAtBlockHeight(syncingNode *cmd.Node, bridge *ConnectionBridge, height uint32) { + <-listenForBlockHeight(syncingNode, height) bridge.Disconnect() } @@ -414,7 +426,7 @@ func restartAtHeightAndReconnectNode(t *testing.T, node *cmd.Node, source *cmd.N height uint32) (_node *cmd.Node, _bridge *ConnectionBridge) { require := require.New(t) - disconnectAtBlockHeight(t, node, currentBridge, height) + disconnectAtBlockHeight(node, currentBridge, height) newNode := restartNode(t, node) // Wait after the restart. time.Sleep(1 * time.Second) @@ -425,6 +437,16 @@ func restartAtHeightAndReconnectNode(t *testing.T, node *cmd.Node, source *cmd.N return newNode, bridge } +func restartAtHeight(t *testing.T, node *cmd.Node, height uint32) *cmd.Node { + <-listenForBlockHeight(node, height) + return restartNode(t, node) +} + +func shutdownAtHeight(t *testing.T, node *cmd.Node, height uint32) *cmd.Node { + <-listenForBlockHeight(node, height) + return shutdownNode(t, node) +} + // listenForSyncPrefix will wait until the node starts downloading the provided syncPrefix in hypersync, and then sends // a message to the provided signal channel. func listenForSyncPrefix(t *testing.T, node *cmd.Node, syncPrefix []byte, signal chan<- bool) { @@ -468,6 +490,20 @@ func restartAtSyncPrefixAndReconnectNode(t *testing.T, node *cmd.Node, source *c return newNode, bridge } +func restartAtSyncPrefix(t *testing.T, node *cmd.Node, syncPrefix []byte) *cmd.Node { + listener := make(chan bool) + listenForSyncPrefix(t, node, syncPrefix, listener) + <-listener + return restartNode(t, node) +} + +func shutdownAtSyncPrefix(t *testing.T, node *cmd.Node, syncPrefix []byte) *cmd.Node { + listener := make(chan bool) + listenForSyncPrefix(t, node, syncPrefix, listener) + <-listener + return shutdownNode(t, node) +} + func randomUint32Between(t *testing.T, min, max uint32) uint32 { require := require.New(t) randomNumber, err := wire.RandomUint64() @@ -475,3 +511,23 @@ func randomUint32Between(t *testing.T, min, max uint32) uint32 { randomHeight := uint32(randomNumber) % (max - min) return randomHeight + min } + +func waitForCondition(t *testing.T, id string, condition func() bool) { + signalChan := make(chan struct{}) + go func() { + for { + if condition() { + signalChan <- struct{}{} + return + } + time.Sleep(100 * time.Millisecond) + } + }() + + select { + case <-signalChan: + return + case <-time.After(5 * time.Second): + t.Fatalf("Condition timed out | %s", id) + } +} diff --git a/integration_testing/txindex_test.go b/integration_testing/txindex_test.go index aa13fd265..702e63c10 100644 --- a/integration_testing/txindex_test.go +++ b/integration_testing/txindex_test.go @@ -1,11 +1,7 @@ package integration_testing import ( - "fmt" - "github.com/deso-protocol/core/cmd" "github.com/deso-protocol/core/lib" - "github.com/stretchr/testify/require" - "os" "testing" ) @@ -16,39 +12,21 @@ import ( // 4. node2 syncs MaxSyncBlockHeight blocks from node1, and builds txindex afterwards. // 5. compare node1 db and txindex matches node2. 
func TestSimpleTxIndex(t *testing.T) { - require := require.New(t) - _ = require - - dbDir1 := getDirectory(t) - dbDir2 := getDirectory(t) - defer os.RemoveAll(dbDir1) - defer os.RemoveAll(dbDir2) - - config1 := generateConfig(t, 18000, dbDir1, 10) - config1.HyperSync = true - config1.SyncType = lib.NodeSyncTypeBlockSync - config2 := generateConfig(t, 18001, dbDir2, 10) - config2.HyperSync = true - config2.SyncType = lib.NodeSyncTypeHyperSyncArchival - - config1.TXIndex = true - config2.TXIndex = true - config1.ConnectIPs = []string{"deso-seed-2.io:17000"} - - node1 := cmd.NewNode(config1) - node2 := cmd.NewNode(config2) - + node1 := spawnNodeProtocol1(t, 18000, "node1") + node1.Config.ConnectIPs = []string{"deso-seed-2.io:17000"} + node1.Config.HyperSync = true + node1.Config.TXIndex = true node1 = startNode(t, node1) - node2 = startNode(t, node2) - // wait for node1 to sync blocks waitForNodeToFullySync(node1) - // bridge the nodes together. - bridge := NewConnectionBridge(node1, node2) - require.NoError(bridge.Start()) - - // wait for node2 to sync blocks. + node2 := spawnNodeProtocol1(t, 18001, "node2") + node2.Config.SyncType = lib.NodeSyncTypeHyperSyncArchival + node2.Config.ConnectIPs = []string{"127.0.0.1:18000"} + node2.Config.HyperSync = true + node2.Config.TXIndex = true + node2 = startNode(t, node2) + // wait for node1 to sync blocks waitForNodeToFullySync(node2) waitForNodeToFullySyncTxIndex(node1) @@ -56,7 +34,5 @@ func TestSimpleTxIndex(t *testing.T) { compareNodesByDB(t, node1, node2, 0) compareNodesByTxIndex(t, node1, node2, 0) - fmt.Println("Databases match!") - node1.Stop() - node2.Stop() + t.Logf("Databases match!") } diff --git a/lib/block_producer.go b/lib/block_producer.go index a4763abf6..b5ad66cb0 100644 --- a/lib/block_producer.go +++ b/lib/block_producer.go @@ -4,6 +4,7 @@ import ( "encoding/hex" "fmt" "math" + "strings" "sync" "sync/atomic" "time" @@ -80,15 +81,26 @@ func NewDeSoBlockProducer( var privKey *btcec.PrivateKey if blockProducerSeed != "" { - seedBytes, err := bip39.NewSeedWithErrorChecking(blockProducerSeed, "") - if err != nil { - return nil, fmt.Errorf("NewDeSoBlockProducer: Error converting mnemonic: %+v", err) - } + // If a blockProducerSeed is provided then we use it to generate a private key. + // If the block producer seed beings with 0x, we treat it as a hex seed. Otherwise, + // we treat it as a seed phrase. 
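+ // For example, a value of "0x" followed by 64 hex characters would be decoded directly into raw private
+ // key bytes, while any other value goes through the BIP39 mnemonic derivation path below.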
+ if strings.HasPrefix(blockProducerSeed, "0x") { + privKeyBytes, err := hex.DecodeString(blockProducerSeed[2:]) + if err != nil { + return nil, fmt.Errorf("NewDeSoBlockProducer: Error decoding hex seed: %+v", err) + } + privKey, _ = btcec.PrivKeyFromBytes(btcec.S256(), privKeyBytes) + } else { + seedBytes, err := bip39.NewSeedWithErrorChecking(blockProducerSeed, "") + if err != nil { + return nil, fmt.Errorf("NewDeSoBlockProducer: Error converting mnemonic: %+v", err) + } - _, privKey, _, err = ComputeKeysFromSeed(seedBytes, 0, params) - if err != nil { - return nil, fmt.Errorf( - "NewDeSoBlockProducer: Error computing keys from seed: %+v", err) + _, privKey, _, err = ComputeKeysFromSeed(seedBytes, 0, params) + if err != nil { + return nil, fmt.Errorf( + "NewDeSoBlockProducer: Error computing keys from seed: %+v", err) + } } } diff --git a/lib/block_view.go b/lib/block_view.go index 168f2a997..652013370 100644 --- a/lib/block_view.go +++ b/lib/block_view.go @@ -3920,6 +3920,89 @@ func (bav *UtxoView) _connectTransaction( return utxoOpsForTxn, totalInput, totalOutput, fees, nil } +func (bav *UtxoView) ConnectTransactions( + txns []*MsgDeSoTxn, txHashes []*BlockHash, blockHeight uint32, blockTimestampNanoSecs int64, + verifySignatures bool, ignoreUtxos bool, ignoreFailing bool) ( + _combinedUtxoOps [][]*UtxoOperation, _totalInputs []uint64, _totalOutputs []uint64, + _fees []uint64, _successFlags []bool, _err error) { + + return bav._connectTransactions(txns, txHashes, blockHeight, blockTimestampNanoSecs, verifySignatures, + ignoreUtxos, ignoreFailing, 0) +} + +func (bav *UtxoView) ConnectTransactionsWithLimit( + txns []*MsgDeSoTxn, txHashes []*BlockHash, blockHeight uint32, blockTimestampNanoSecs int64, + verifySignatures bool, ignoreUtxos bool, ignoreFailing bool, transactionConnectLimit uint64) ( + _combinedUtxoOps [][]*UtxoOperation, _totalInputs []uint64, _totalOutputs []uint64, + _fees []uint64, _successFlags []bool, _err error) { + + return bav._connectTransactions(txns, txHashes, blockHeight, blockTimestampNanoSecs, verifySignatures, + ignoreUtxos, ignoreFailing, transactionConnectLimit) +} + +func (bav *UtxoView) _connectTransactions( + txns []*MsgDeSoTxn, txHashes []*BlockHash, blockHeight uint32, blockTimestampNanoSecs int64, + verifySignatures bool, ignoreUtxos bool, ignoreFailing bool, transactionConnectLimit uint64) ( + _combinedUtxoOps [][]*UtxoOperation, _totalInputs []uint64, _totalOutputs []uint64, + _fees []uint64, _successFlags []bool, _err error) { + + var combinedUtxoOps [][]*UtxoOperation + var totalInputs []uint64 + var totalOutputs []uint64 + var fees []uint64 + var successFlags []bool + var totalConnectedTxns uint64 + + updateValues := func(utxoOps []*UtxoOperation, totalInput uint64, totalOutput uint64, fee uint64, success bool) { + combinedUtxoOps = append(combinedUtxoOps, utxoOps) + totalInputs = append(totalInputs, totalInput) + totalOutputs = append(totalOutputs, totalOutput) + fees = append(fees, fee) + successFlags = append(successFlags, success) + } + + // Connect the transactions in the order they are given. + for ii, txn := range txns { + // Create a copy of the view to connect the transactions to in the event we have a failing txn. + copiedView, err := bav.CopyUtxoView() + if err != nil { + return nil, nil, nil, nil, nil, + errors.Wrapf(err, "ConnectTransactions: Problem copying UtxoView") + } + + // Connect the transaction. 
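+ // Connect the txn on the copied view first so that a failing txn can be skipped (when ignoreFailing is
+ // set) without mutating bav; only if this dry run succeeds do we connect the txn on bav itself below.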
+		utxoOpsForTxn, totalInput, totalOutput, fee, err := copiedView.ConnectTransaction(
+			txn, txHashes[ii], blockHeight, blockTimestampNanoSecs, verifySignatures, ignoreUtxos)
+		if err != nil && ignoreFailing {
+			glog.V(2).Infof("ConnectTransactions: Ignoring failing txn %d: %v", ii, err)
+			updateValues(nil, 0, 0, 0, false)
+			continue
+		} else if err != nil {
+			return nil, nil, nil, nil, nil,
+				errors.Wrapf(err, "ConnectTransactions: Problem connecting txn %d on copy view", ii)
+		}
+
+		utxoOpsForTxn, totalInput, totalOutput, fee, err = bav.ConnectTransaction(
+			txn, txHashes[ii], blockHeight, blockTimestampNanoSecs, verifySignatures, ignoreUtxos)
+		if err != nil {
+			return nil, nil, nil, nil, nil,
+				errors.Wrapf(err, "ConnectTransactions: Problem connecting txn %d", ii)
+		}
+		updateValues(utxoOpsForTxn, totalInput, totalOutput, fee, true)
+
+		totalConnectedTxns++
+		// A transactionConnectLimit of 0 means there is no limit on the number of txns to connect.
+		if transactionConnectLimit == 0 {
+			continue
+		}
+		if totalConnectedTxns >= transactionConnectLimit {
+			break
+		}
+	}
+
+	return combinedUtxoOps, totalInputs, totalOutputs, fees, successFlags, nil
+}
+
 func (bav *UtxoView) ValidateTransactionNonce(txn *MsgDeSoTxn, blockHeight uint64) error {
 	if txn == nil || txn.TxnNonce == nil {
 		return fmt.Errorf("ValidateTransactionNonce: Nonce or txn is nil for public key %v",
@@ -3952,103 +4035,6 @@ func (bav *UtxoView) ValidateTransactionNonce(txn *MsgDeSoTxn, blockHeight uint6
 	return nil
 }
 
-// _connectFailingTransaction is used to process the fee and burn associated with the user submitting a failing transaction.
-// A failing transaction is a txn that passes formatting validation, yet fails connecting to the UtxoView. This can happen for a
-// number of reasons, such as insufficient DESO balance, wrong public key, etc. With Revolution's Fee-Time block ordering, these
-// failing transactions are included in the blocks and their fees are burned. In addition, a major part of the effective
-// fees of this transaction is burned with BMF. This makes spam attacks economically disadvantageous. Attacker's funds
-// are burned, to the benefit of everyone else on the network. BMF algorithm also computes a utility fee, which is
-// distributed to the block producer.
-func (bav *UtxoView) _connectFailingTransaction(txn *MsgDeSoTxn, blockHeight uint32, verifySignatures bool) (
-	_utxoOps []*UtxoOperation, _burnFee uint64, _utilityFee uint64, _err error) {
-
-	// Failing transactions are only allowed after ProofOfStake2ConsensusCutoverBlockHeight.
-	if blockHeight < bav.Params.ForkHeights.ProofOfStake2ConsensusCutoverBlockHeight {
-		return nil, 0, 0, fmt.Errorf("_connectFailingTransaction: Failing transactions " +
-			"not allowed before ProofOfStake2ConsensusCutoverBlockHeight")
-	}
-
-	// Sanity check the transaction to make sure it is properly formatted.
-	if err := CheckTransactionSanity(txn, blockHeight, bav.Params); err != nil {
-		return nil, 0, 0, errors.Wrapf(err, "_connectFailingTransaction: "+
-			"Problem checking txn sanity")
-	}
-
-	if err := ValidateDeSoTxnSanityBalanceModel(txn, uint64(blockHeight), bav.Params, bav.GlobalParamsEntry); err != nil {
-		return nil, 0, 0, errors.Wrapf(err, "_connectFailingTransaction: "+
-			"Problem checking txn sanity under balance model")
-	}
-
-	if err := bav.ValidateTransactionNonce(txn, uint64(blockHeight)); err != nil {
-		return nil, 0, 0, errors.Wrapf(err, "_connectFailingTransaction: "+
-			"Problem validating transaction nonce")
-	}
-
-	// Get the FailingTransactionBMFMultiplierBasisPoints from the global params entry.
We then compute the effective fee - // as: effectiveFee = txn.TxnFeeNanos * FailingTransactionBMFMultiplierBasisPoints / 10000 - gp := bav.GetCurrentGlobalParamsEntry() - - failingTransactionRate := uint256.NewInt().SetUint64(gp.FailingTransactionBMFMultiplierBasisPoints) - failingTransactionFee := uint256.NewInt().SetUint64(txn.TxnFeeNanos) - basisPointsAsUint256 := uint256.NewInt().SetUint64(MaxBasisPoints) - - effectiveFeeU256 := uint256.NewInt() - if effectiveFeeU256.MulOverflow(failingTransactionRate, failingTransactionFee) { - return nil, 0, 0, fmt.Errorf("_connectFailingTransaction: Problem computing effective fee") - } - effectiveFeeU256.Div(effectiveFeeU256, basisPointsAsUint256) - - // We should never overflow on the effective fee, since FailingTransactionBMFMultiplierBasisPoints is <= 10000. - // But if for some magical reason we do, we set the effective fee to the max uint64. We don't error, and - // instead let _spendBalance handle the overflow. - if !effectiveFeeU256.IsUint64() { - effectiveFeeU256.SetUint64(math.MaxUint64) - } - effectiveFee := effectiveFeeU256.Uint64() - - // Serialize the transaction to bytes so we can compute its size. - txnBytes, err := txn.ToBytes(false) - if err != nil { - return nil, 0, 0, errors.Wrapf(err, "_connectFailingTransaction: Problem serializing transaction: ") - } - txnSizeBytes := uint64(len(txnBytes)) - - // If the effective fee rate per KB is less than the minimum network fee rate per KB, we set it to the minimum - // network fee rate per KB. We multiply by 1000 and divide by the txn bytes to convert the txn's total effective - // fee to a fee rate per KB. - // - // The effectiveFee * 1000 computation is guaranteed to not overflow because an overflow check is already - // performed in ValidateDeSoTxnSanityBalanceModel above. - effectiveFeeRateNanosPerKB := (effectiveFee * 1000) / txnSizeBytes - if effectiveFeeRateNanosPerKB < gp.MinimumNetworkFeeNanosPerKB { - // The minimum effective fee for the txn is the txn size * the minimum network fee rate per KB. - effectiveFee = (gp.MinimumNetworkFeeNanosPerKB * txnSizeBytes) / 1000 - } - - burnFee, utilityFee := computeBMF(effectiveFee) - - var utxoOps []*UtxoOperation - // When spending balances, we need to check for immature block rewards. Since we don't have - // the block rewards yet for the current block, we subtract one from the current block height - // when spending balances. - feeUtxoOp, err := bav._spendBalance(effectiveFee, txn.PublicKey, blockHeight-1) - if err != nil { - return nil, 0, 0, errors.Wrapf(err, "_connectFailingTransaction: Problem "+ - "spending balance") - } - utxoOps = append(utxoOps, feeUtxoOp) - - // If verifySignatures is passed, we check transaction signature. - if verifySignatures { - if err := bav._verifyTxnSignature(txn, blockHeight); err != nil { - return nil, 0, 0, errors.Wrapf(err, "_connectFailingTransaction: Problem "+ - "verifying signature") - } - } - - return utxoOps, burnFee, utilityFee, nil -} - // computeBMF computes the burn fee and the utility fee for a given fee. The acronym stands for Burn Maximizing Fee, which // entails that the burn function is designed to maximize the amount of DESO burned, while providing the minimal viable // utility fee to the block producer. 
This is so that block producers have no advantage over other network participants @@ -4156,66 +4142,24 @@ func (bav *UtxoView) ConnectBlock( for txIndex, txn := range desoBlock.Txns { txHash := txHashes[txIndex] - // PoS introduced a concept of a failing transaction, or transactions that fail UtxoView's ConnectTransaction. - // In PoS, these failing transactions are included in the block and their fees are burned. - - // To determine if we're dealing with a connecting or failing transaction, we first check if we're on a PoS block - // height. Otherwise, the transaction is expected to connect. - hasPoWBlockHeight := bav.Params.IsPoWBlockHeight(blockHeight) - // Also, the first transaction in the block, the block reward transaction, should always be a connecting transaction. - isBlockRewardTxn := (txIndex == 0) && (txn.TxnMeta.GetTxnType() == TxnTypeBlockReward) - // Finally, if the transaction is not the first in the block, we check the TxnConnectStatusByIndex to see if - // it's marked by the block producer as a connecting transaction. PoS blocks should reflect this in TxnConnectStatusByIndex. - hasConnectingPoSTxnStatus := false - if bav.Params.IsPoSBlockHeight(blockHeight) && (txIndex > 0) && (desoBlock.TxnConnectStatusByIndex != nil) { - // Note that TxnConnectStatusByIndex doesn't include the first block reward transaction. - hasConnectingPoSTxnStatus = desoBlock.TxnConnectStatusByIndex.Get(txIndex - 1) - } - // Now, we can determine if the transaction is expected to connect. - txnConnects := hasPoWBlockHeight || isBlockRewardTxn || hasConnectingPoSTxnStatus - var utilityFee uint64 var utxoOpsForTxn []*UtxoOperation var err error var currentFees uint64 - if txnConnects { - // ConnectTransaction validates all of the transactions in the block and - // is responsible for verifying signatures. - // - // TODO: We currently don't check that the min transaction fee is satisfied when - // connecting blocks. We skip this check because computing the transaction's size - // would slow down block processing significantly. We should figure out a way to - // enforce this check in the future, but for now the only attack vector is one in - // which a miner is trying to spam the network, which should generally never happen. - utxoOpsForTxn, _, _, currentFees, err = bav.ConnectTransaction( - txn, txHash, uint32(blockHeader.Height), blockHeader.TstampNanoSecs, verifySignatures, false) - if err != nil { - return nil, errors.Wrapf(err, "ConnectBlock: error connecting txn #%d", txIndex) - } - _, utilityFee = computeBMF(currentFees) - } else { - // If the transaction is not supposed to connect, we need to verify that it won't connect. - // We need to construct a copy of the view to verify that the transaction won't connect - // without side effects. - var utxoViewCopy *UtxoView - utxoViewCopy, err = bav.CopyUtxoView() - if err != nil { - return nil, errors.Wrapf(err, "ConnectBlock: error copying UtxoView") - } - _, _, _, _, err = utxoViewCopy.ConnectTransaction( - txn, txHash, uint32(blockHeader.Height), blockHeader.TstampNanoSecs, verifySignatures, false) - if err == nil { - return nil, errors.Errorf("ConnectBlock: txn #%d should not connect but err is nil", txIndex) - } - var burnFee uint64 - // Connect the failing transaction to get the fees and utility fee. 
- utxoOpsForTxn, burnFee, utilityFee, err = bav._connectFailingTransaction( - txn, uint32(blockHeader.Height), verifySignatures) - if err != nil { - return nil, errors.Wrapf(err, "ConnectBlock: error connecting failing txn #%d", txIndex) - } - currentFees = burnFee + utilityFee + // ConnectTransaction validates all of the transactions in the block and + // is responsible for verifying signatures. + // + // TODO: We currently don't check that the min transaction fee is satisfied when + // connecting blocks. We skip this check because computing the transaction's size + // would slow down block processing significantly. We should figure out a way to + // enforce this check in the future, but for now the only attack vector is one in + // which a miner is trying to spam the network, which should generally never happen. + utxoOpsForTxn, _, _, currentFees, err = bav.ConnectTransaction( + txn, txHash, uint32(blockHeader.Height), blockHeader.TstampNanoSecs, verifySignatures, false) + if err != nil { + return nil, errors.Wrapf(err, "ConnectBlock: error connecting txn #%d", txIndex) } + _, utilityFee = computeBMF(currentFees) // After the block reward patch block height, we only include fees from transactions // where the transactor is not the block reward output public key. This prevents diff --git a/lib/block_view_test.go b/lib/block_view_test.go index 0056fbbed..ee7f173bf 100644 --- a/lib/block_view_test.go +++ b/lib/block_view_test.go @@ -12,8 +12,6 @@ import ( "github.com/deso-protocol/core/bls" - "math/rand" - "github.com/btcsuite/btcd/btcec" "github.com/decred/dcrd/lru" "github.com/dgraph-io/badger/v3" @@ -2217,166 +2215,3 @@ func TestBlockRewardPatch(t *testing.T) { require.NoError(t, err) } } - -func TestConnectFailingTransaction(t *testing.T) { - setBalanceModelBlockHeights(t) - setPoSBlockHeights(t, 3, 3) - require := require.New(t) - seed := int64(1011) - rand := rand.New(rand.NewSource(seed)) - - globalParams := _testGetDefaultGlobalParams() - feeMin := globalParams.MinimumNetworkFeeNanosPerKB - feeMax := uint64(10000) - - chain, params, db := NewLowDifficultyBlockchain(t) - mempool, miner := NewTestMiner(t, chain, params, true) - // Mine a few blocks to give the senderPkString some money. - _, err := miner.MineAndProcessSingleBlock(0 /*threadIndex*/, mempool) - require.NoError(err) - _, err = miner.MineAndProcessSingleBlock(0 /*threadIndex*/, mempool) - require.NoError(err) - - m0PubBytes, _, _ := Base58CheckDecode(m0Pub) - m0PublicKeyBase58Check := Base58CheckEncode(m0PubBytes, false, params) - - _, _, _ = _doBasicTransferWithViewFlush( - t, chain, db, params, senderPkString, m0PublicKeyBase58Check, - senderPrivString, 200000, 11) - - blockHeight := chain.BlockTip().Height + 1 - - // Set up the test meta. - testMeta := &TestMeta{ - t: t, - chain: chain, - params: params, - db: db, - mempool: mempool, - miner: miner, - savedHeight: blockHeight, - feeRateNanosPerKb: uint64(201), - } - // Allow m0 to update global params. 
- params.ExtraRegtestParamUpdaterKeys[MakePkMapKey(m0PubBytes)] = true - - // Test failing txn with default global params - { - blockView, err := NewUtxoView(db, params, nil, nil, chain.eventManager) - require.NoError(err) - txn := _generateTestTxn(t, rand, feeMin, feeMax, m0PubBytes, m0Priv, 100, 0) - utxoOps, burnFee, utilityFee, err := blockView._connectFailingTransaction(txn, blockHeight, true) - require.NoError(err) - require.Equal(1, len(utxoOps)) - expectedBurnFee, expectedUtilityFee := _getBMFForTxn(txn, globalParams) - require.Equal(expectedBurnFee, burnFee) - require.Equal(expectedUtilityFee, utilityFee) - - err = blockView.FlushToDb(uint64(blockHeight)) - require.NoError(err) - } - - // Test case where the failing txn fee rate is applied as expected. - { - - { - // Set FailingTransactionBMFMultiplierBasisPoints=7000 or 70%. - _updateGlobalParamsEntryWithExtraData( - testMeta, - testMeta.feeRateNanosPerKb, - m0Pub, - m0Priv, - map[string][]byte{FailingTransactionBMFMultiplierBasisPointsKey: UintToBuf(7000)}, - ) - } - blockView, err := NewUtxoView(db, params, nil, nil, chain.eventManager) - require.NoError(err) - - newParams := blockView.GetCurrentGlobalParamsEntry() - require.Equal(uint64(7000), newParams.FailingTransactionBMFMultiplierBasisPoints) - - startingBalance, err := blockView.GetDeSoBalanceNanosForPublicKey(m0PubBytes) - require.NoError(err) - - // Try connecting another failing transaction, and make sure the burn and utility fees are computed accurately. - txn := _generateTestTxn(t, rand, feeMin, feeMax, m0PubBytes, m0Priv, 100, 0) - - utxoOps, burnFee, utilityFee, err := blockView._connectFailingTransaction(txn, blockHeight, true) - require.NoError(err) - require.Equal(1, len(utxoOps)) - - // The final balance is m0's starting balance minus the failing txn fee paid. - finalBalance, err := blockView.GetDeSoBalanceNanosForPublicKey(m0PubBytes) - require.NoError(err) - - // Recompute the failing txn fee, which is expected to use the minimum network fee rate because - // the failing txn fee rate is too low on its own. - expectedFailingTxnFee := txn.TxnFeeNanos * newParams.FailingTransactionBMFMultiplierBasisPoints / MaxBasisPoints - require.Equal(startingBalance, finalBalance+expectedFailingTxnFee) - - expectedBurnFee, expectedUtilityFee := _getBMFForTxn(txn, newParams) - require.Equal(expectedBurnFee, burnFee) - require.Equal(expectedUtilityFee, utilityFee) - - err = blockView.FlushToDb(uint64(blockHeight)) - require.NoError(err) - } - - // Test case where the failing txn fee rate is too low and replaced by the minimum network fee. - { - { - // Set FailingTransactionBMFMultiplierBasisPoints=1 or 0.01%. - _updateGlobalParamsEntryWithExtraData( - testMeta, - testMeta.feeRateNanosPerKb, - m0Pub, - m0Priv, - map[string][]byte{FailingTransactionBMFMultiplierBasisPointsKey: UintToBuf(1)}, - ) - } - - // Set the txn fee to ~1000 nanos, which guarantees that the effective failing txn fee rate is too low. 
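Aside on the fee numbers in the removed test around this point: the ~1000-nano fee combined with the 1-basis-point multiplier makes the effective failing-txn fee essentially zero, so the per-KB floor enforced by the (also removed) _connectFailingTransaction kicks in. A tiny self-contained sketch of that floor arithmetic, with made-up numbers:

package main

import "fmt"

// applyMinFeeFloor mirrors the removed floor logic: convert the effective fee to a
// nanos-per-KB rate and, if it is below the network minimum, replace the fee with
// the minimum fee implied by the transaction's size.
func applyMinFeeFloor(effectiveFee, txnSizeBytes, minNetworkFeeNanosPerKB uint64) uint64 {
	effectiveFeeRateNanosPerKB := (effectiveFee * 1000) / txnSizeBytes
	if effectiveFeeRateNanosPerKB < minNetworkFeeNanosPerKB {
		return (minNetworkFeeNanosPerKB * txnSizeBytes) / 1000
	}
	return effectiveFee
}

func main() {
	// A 250-byte txn with a 10-nano effective fee has a rate of 40 nanos/KB, which is
	// under a 1000 nanos/KB minimum, so the floor bumps the fee to 250 nanos.
	fmt.Println(applyMinFeeFloor(10, 250, 1000)) // 250
	// A fee already above the floor passes through unchanged.
	fmt.Println(applyMinFeeFloor(500, 250, 1000)) // 500
}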
- feeMin := uint64(1000) - feeMax := uint64(1001) - - blockView, err := NewUtxoView(db, params, nil, nil, chain.eventManager) - require.NoError(err) - - newParams := blockView.GetCurrentGlobalParamsEntry() - require.Equal(uint64(1), newParams.FailingTransactionBMFMultiplierBasisPoints) - - startingBalance, err := blockView.GetDeSoBalanceNanosForPublicKey(m0PubBytes) - require.NoError(err) - - txn := _generateTestTxn(t, rand, feeMin, feeMax, m0PubBytes, m0Priv, 100, 0) - utxoOps, burnFee, utilityFee, err := blockView._connectFailingTransaction(txn, blockHeight, true) - require.NoError(err) - require.Equal(1, len(utxoOps)) - - // The final balance is m0's starting balance minus the failing txn fee paid. - finalBalance, err := blockView.GetDeSoBalanceNanosForPublicKey(m0PubBytes) - require.NoError(err) - - txnBytes, err := txn.ToBytes(false) - require.NoError(err) - - // Recompute the failing txn fee, which is expected to use the minimum network fee rate because - // the failing txn fee rate is too low on its own. - expectedFailingTxnFee := uint64(len(txnBytes)) * newParams.MinimumNetworkFeeNanosPerKB / 1000 - require.Equal(startingBalance, finalBalance+expectedFailingTxnFee) - - expectedBurnFee, expectedUtilityFee := computeBMF(expectedFailingTxnFee) - require.Equal(expectedBurnFee, burnFee) - require.Equal(expectedUtilityFee, utilityFee) - - err = blockView.FlushToDb(uint64(blockHeight)) - require.NoError(err) - } -} - -func _getBMFForTxn(txn *MsgDeSoTxn, gp *GlobalParamsEntry) (_burnFee uint64, _utilityFee uint64) { - failingTransactionRate := NewFloat().SetUint64(gp.FailingTransactionBMFMultiplierBasisPoints) - failingTransactionRate.Quo(failingTransactionRate, NewFloat().SetUint64(10000)) - failingTransactionFee, _ := NewFloat().Mul(failingTransactionRate, NewFloat().SetUint64(txn.TxnFeeNanos)).Uint64() - return computeBMF(failingTransactionFee) -} diff --git a/lib/block_view_types.go b/lib/block_view_types.go index 04d52e1ae..b487c91ec 100644 --- a/lib/block_view_types.go +++ b/lib/block_view_types.go @@ -682,7 +682,8 @@ const ( OperationTypeStakeDistributionRestake OperationType = 49 OperationTypeStakeDistributionPayToBalance OperationType = 50 OperationTypeSetValidatorLastActiveAtEpoch OperationType = 51 - // NEXT_TAG = 52 + OperationTypeFailingTxn OperationType = 52 + // NEXT_TAG = 53 ) func (op OperationType) String() string { diff --git a/lib/blockchain.go b/lib/blockchain.go index 3a8b9484c..764a53681 100644 --- a/lib/blockchain.go +++ b/lib/blockchain.go @@ -1146,6 +1146,14 @@ func (bc *Blockchain) HasBlock(blockHash *BlockHash) bool { return true } +func (bc *Blockchain) HasBlockInBlockIndex(blockHash *BlockHash) bool { + bc.ChainLock.RLock() + defer bc.ChainLock.RUnlock() + + _, exists := bc.blockIndexByHash[*blockHash] + return exists +} + // This needs to hold a lock on the blockchain because it read from an in-memory map that is // not thread-safe. 
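Quick aside on HasBlockInBlockIndex, added in the blockchain.go hunk above: it is the usual read-lock-guarded existence check on an in-memory index, which is why the surrounding comments insist on holding the lock. A toy, self-contained version of the same pattern (the types here are stand-ins, not the repo's):

package main

import (
	"fmt"
	"sync"
)

type blockHash [32]byte

// toyIndex stands in for an in-memory block index that is not safe for
// concurrent use without a lock.
type toyIndex struct {
	lock    sync.RWMutex
	entries map[blockHash]bool
}

// Has takes only the read lock, so many readers can check membership
// concurrently while writers remain excluded.
func (idx *toyIndex) Has(h blockHash) bool {
	idx.lock.RLock()
	defer idx.lock.RUnlock()
	_, exists := idx.entries[h]
	return exists
}

func main() {
	idx := &toyIndex{entries: map[blockHash]bool{{1}: true}}
	fmt.Println(idx.Has(blockHash{1}), idx.Has(blockHash{2})) // true false
}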
func (bc *Blockchain) GetBlockHeaderFromIndex(blockHash *BlockHash) *MsgDeSoHeader { diff --git a/lib/connection_manager.go b/lib/connection_manager.go index a14742c8b..ffc56aa25 100644 --- a/lib/connection_manager.go +++ b/lib/connection_manager.go @@ -4,7 +4,7 @@ import ( "fmt" "math" "net" - "strconv" + "sync" "sync/atomic" "time" @@ -12,9 +12,7 @@ import ( chainlib "github.com/btcsuite/btcd/blockchain" "github.com/btcsuite/btcd/wire" "github.com/decred/dcrd/lru" - "github.com/deso-protocol/go-deadlock" "github.com/golang/glog" - "github.com/pkg/errors" ) // connection_manager.go contains most of the logic for creating and managing @@ -36,24 +34,10 @@ type ConnectionManager struct { // doesn't need a reference to the Server object. But for now we keep things lazy. srv *Server - // When --connectips is set, we don't connect to anything from the addrmgr. - connectIps []string - - // The address manager keeps track of peer addresses we're aware of. When - // we need to connect to a new outbound peer, it chooses one of the addresses - // it's aware of at random and provides it to us. - AddrMgr *addrmgr.AddrManager // The interfaces we listen on for new incoming connections. listeners []net.Listener // The parameters we are initialized with. params *DeSoParams - // The target number of outbound peers we want to have. - targetOutboundPeers uint32 - // The maximum number of inbound peers we allow. - maxInboundPeers uint32 - // When true, only one connection per IP is allowed. Prevents eclipse attacks - // among other things. - limitOneInboundConnectionPerIP bool // When --hypersync is set to true we will attempt fast block synchronization HyperSync bool @@ -80,17 +64,26 @@ type ConnectionManager struct { // concurrently by many goroutines to figure out if outbound connections // should be made to particular addresses. - mtxOutboundConnIPGroups deadlock.Mutex + mtxOutboundConnIPGroups sync.Mutex outboundConnIPGroups map[string]int // The peer maps map peer ID to peers for various types of peer connections. // // A persistent peer is typically one we got through a commandline argument. // The reason it's called persistent is because we maintain a connection to // it, and retry the connection if it fails. - mtxPeerMaps deadlock.RWMutex + mtxPeerMaps sync.RWMutex persistentPeers map[uint64]*Peer outboundPeers map[uint64]*Peer inboundPeers map[uint64]*Peer + connectedPeers map[uint64]*Peer + + mtxConnectionAttempts sync.Mutex + // outboundConnectionAttempts keeps track of the outbound connections, mapping attemptId [uint64] -> connection attempt. + outboundConnectionAttempts map[uint64]*OutboundConnectionAttempt + // outboundConnectionChan is used to signal successful outbound connections to the connection manager. + outboundConnectionChan chan *outboundConnection + // inboundConnectionChan is used to signal successful inbound connections to the connection manager. + inboundConnectionChan chan *inboundConnection // Track the number of outbound peers we have so that this value can // be accessed concurrently when deciding whether or not to add more // outbound peers. @@ -102,11 +95,9 @@ type ConnectionManager struct { // avoid choosing them in the address manager. We need a mutex on this // guy because many goroutines will be querying the address manager // at once. - mtxConnectedOutboundAddrs deadlock.RWMutex - connectedOutboundAddrs map[string]bool - - // Used to set peer ids. Must be incremented atomically. 
- peerIndex uint64 + mtxAddrsMaps sync.RWMutex + connectedOutboundAddrs map[string]bool + attemptedOutboundAddrs map[string]bool serverMessageQueue chan *ServerMessage @@ -114,9 +105,8 @@ type ConnectionManager struct { // peers' time. timeSource chainlib.MedianTimeSource - // Events that can happen to a peer. - newPeerChan chan *Peer - donePeerChan chan *Peer + // peerDisconnectedChan is notified whenever a peer exits. + peerDisconnectedChan chan *Peer // stallTimeoutSeconds is how long we wait to receive responses from Peers // for certain types of messages. @@ -129,10 +119,8 @@ type ConnectionManager struct { } func NewConnectionManager( - _params *DeSoParams, _addrMgr *addrmgr.AddrManager, _listeners []net.Listener, + _params *DeSoParams, _listeners []net.Listener, _connectIps []string, _timeSource chainlib.MedianTimeSource, - _targetOutboundPeers uint32, _maxInboundPeers uint32, - _limitOneInboundConnectionPerIP bool, _hyperSync bool, _syncType NodeSyncType, _stallTimeoutSeconds uint64, @@ -143,55 +131,52 @@ func NewConnectionManager( ValidateHyperSyncFlags(_hyperSync, _syncType) return &ConnectionManager{ - srv: _srv, - params: _params, - AddrMgr: _addrMgr, - listeners: _listeners, - connectIps: _connectIps, + srv: _srv, + params: _params, + listeners: _listeners, // We keep track of the last N nonces we've sent in order to detect // self connections. sentNonces: lru.NewCache(1000), timeSource: _timeSource, - //newestBlock: _newestBlock, // Initialize the peer data structures. - outboundConnIPGroups: make(map[string]int), - persistentPeers: make(map[uint64]*Peer), - outboundPeers: make(map[uint64]*Peer), - inboundPeers: make(map[uint64]*Peer), - connectedOutboundAddrs: make(map[string]bool), + outboundConnIPGroups: make(map[string]int), + persistentPeers: make(map[uint64]*Peer), + outboundPeers: make(map[uint64]*Peer), + inboundPeers: make(map[uint64]*Peer), + connectedPeers: make(map[uint64]*Peer), + outboundConnectionAttempts: make(map[uint64]*OutboundConnectionAttempt), + connectedOutboundAddrs: make(map[string]bool), + attemptedOutboundAddrs: make(map[string]bool), // Initialize the channels. - newPeerChan: make(chan *Peer), - donePeerChan: make(chan *Peer), - - targetOutboundPeers: _targetOutboundPeers, - maxInboundPeers: _maxInboundPeers, - limitOneInboundConnectionPerIP: _limitOneInboundConnectionPerIP, - HyperSync: _hyperSync, - SyncType: _syncType, - serverMessageQueue: _serverMessageQueue, - stallTimeoutSeconds: _stallTimeoutSeconds, - minFeeRateNanosPerKB: _minFeeRateNanosPerKB, - } -} + peerDisconnectedChan: make(chan *Peer, 100), + outboundConnectionChan: make(chan *outboundConnection, 100), + inboundConnectionChan: make(chan *inboundConnection, 100), -func (cmgr *ConnectionManager) GetAddrManager() *addrmgr.AddrManager { - return cmgr.AddrMgr + HyperSync: _hyperSync, + SyncType: _syncType, + serverMessageQueue: _serverMessageQueue, + stallTimeoutSeconds: _stallTimeoutSeconds, + minFeeRateNanosPerKB: _minFeeRateNanosPerKB, + } } -// Check if the address passed shares a group with any addresses already in our -// data structures. -func (cmgr *ConnectionManager) isRedundantGroupKey(na *wire.NetAddress) bool { +// Check if the address passed shares a group with any addresses already in our data structures. +func (cmgr *ConnectionManager) IsFromRedundantOutboundIPAddress(na *wire.NetAddress) bool { groupKey := addrmgr.GroupKey(na) + // For the sake of running multiple nodes on the same machine, we allow localhost connections. 
+ if groupKey == "local" { + return false + } cmgr.mtxOutboundConnIPGroups.Lock() numGroupsForKey := cmgr.outboundConnIPGroups[groupKey] cmgr.mtxOutboundConnIPGroups.Unlock() if numGroupsForKey != 0 && numGroupsForKey != 1 { - glog.V(2).Infof("isRedundantGroupKey: Found numGroupsForKey != (0 or 1). Is (%d) "+ + glog.V(2).Infof("IsFromRedundantOutboundIPAddress: Found numGroupsForKey != (0 or 1). Is (%d) "+ "instead for addr (%s) and group key (%s). This "+ "should never happen.", numGroupsForKey, na.IP.String(), groupKey) } @@ -202,7 +187,7 @@ func (cmgr *ConnectionManager) isRedundantGroupKey(na *wire.NetAddress) bool { return true } -func (cmgr *ConnectionManager) addToGroupKey(na *wire.NetAddress) { +func (cmgr *ConnectionManager) AddToGroupKey(na *wire.NetAddress) { groupKey := addrmgr.GroupKey(na) cmgr.mtxOutboundConnIPGroups.Lock() @@ -218,48 +203,13 @@ func (cmgr *ConnectionManager) subFromGroupKey(na *wire.NetAddress) { cmgr.mtxOutboundConnIPGroups.Unlock() } -func (cmgr *ConnectionManager) getRandomAddr() *wire.NetAddress { - for tries := 0; tries < 100; tries++ { - // Lock the address map since multiple threads will be trying to read - // and modify it at the same time. - cmgr.mtxConnectedOutboundAddrs.RLock() - addr := cmgr.AddrMgr.GetAddress() - cmgr.mtxConnectedOutboundAddrs.RUnlock() - - if addr == nil { - glog.V(2).Infof("ConnectionManager.getRandomAddr: addr from GetAddressWithExclusions was nil") - break - } - - if cmgr.connectedOutboundAddrs[addrmgr.NetAddressKey(addr.NetAddress())] { - glog.V(2).Infof("ConnectionManager.getRandomAddr: Not choosing already connected address %v:%v", addr.NetAddress().IP, addr.NetAddress().Port) - continue - } - - // We can only have one outbound address per /16. This is similar to - // Bitcoin and we do it to prevent Sybil attacks. - if cmgr.isRedundantGroupKey(addr.NetAddress()) { - glog.V(2).Infof("ConnectionManager.getRandomAddr: Not choosing address due to redundant group key %v:%v", addr.NetAddress().IP, addr.NetAddress().Port) - continue - } - - glog.V(2).Infof("ConnectionManager.getRandomAddr: Returning %v:%v at %d iterations", - addr.NetAddress().IP, addr.NetAddress().Port, tries) - return addr.NetAddress() - } - - glog.V(2).Infof("ConnectionManager.getRandomAddr: Returning nil") - return nil -} - -func _delayRetry(retryCount int, persistentAddrForLogging *wire.NetAddress) { +func _delayRetry(retryCount uint64, persistentAddrForLogging *wire.NetAddress, unit time.Duration) (_retryDuration time.Duration) { // No delay if we haven't tried yet or if the number of retries isn't positive. if retryCount <= 0 { - time.Sleep(time.Second) - return + return 0 } numSecs := int(math.Pow(2.0, float64(retryCount))) - retryDelay := time.Duration(numSecs) * time.Second + retryDelay := time.Duration(numSecs) * unit if persistentAddrForLogging != nil { glog.V(1).Infof("Retrying connection to outbound persistent peer: "+ @@ -268,122 +218,70 @@ func _delayRetry(retryCount int, persistentAddrForLogging *wire.NetAddress) { } else { glog.V(2).Infof("Retrying connection to outbound non-persistent peer in (%d) seconds.", numSecs) } - time.Sleep(retryDelay) + return retryDelay } -func (cmgr *ConnectionManager) enoughOutboundPeers() bool { - val := atomic.LoadUint32(&cmgr.numOutboundPeers) - if val > cmgr.targetOutboundPeers { - glog.Errorf("enoughOutboundPeers: Connected to too many outbound "+ - "peers: (%d). 
Should be "+ - "no more than (%d).", val, cmgr.targetOutboundPeers) - return true - } - - if val == cmgr.targetOutboundPeers { - return true - } - return false +func (cmgr *ConnectionManager) IsConnectedOutboundIpAddress(netAddr *wire.NetAddress) bool { + cmgr.mtxAddrsMaps.RLock() + defer cmgr.mtxAddrsMaps.RUnlock() + return cmgr.connectedOutboundAddrs[addrmgr.NetAddressKey(netAddr)] } -// Chooses a random address and tries to connect to it. Repeats this process until -// it finds a peer that can pass version negotiation. -func (cmgr *ConnectionManager) _getOutboundConn(persistentAddr *wire.NetAddress) net.Conn { - // If a persistentAddr was provided then the connection is a persistent - // one. - isPersistent := (persistentAddr != nil) - retryCount := 0 - for { - if atomic.LoadInt32(&cmgr.shutdown) != 0 { - glog.Info("_getOutboundConn: Ignoring connection due to shutdown") - return nil - } - // We want to start backing off exponentially once we've gone through enough - // unsuccessful retries. However, we want to give more slack to non-persistent - // peers before we start backing off, which is why it's not as cut and dry as - // just delaying based on the raw number of retries. - adjustedRetryCount := retryCount - if !isPersistent { - // If the address is not persistent, only start backing off once there - // has been a large number of failed attempts in a row as this likely indicates - // that there's a connection issue we need to wait out. - adjustedRetryCount = retryCount - 5 - } - _delayRetry(adjustedRetryCount, persistentAddr) - retryCount++ - - // If the connection manager is saturated with non-persistent - // outbound peers, no need to keep trying non-persistent outbound - // connections. - if !isPersistent && cmgr.enoughOutboundPeers() { - glog.V(1).Infof("Dropping connection request to non-persistent outbound " + - "peer because we have enough of them.") - return nil - } - - // If we don't have a persistentAddr, pick one from our addrmgr. - ipNetAddr := persistentAddr - if ipNetAddr == nil { - ipNetAddr = cmgr.getRandomAddr() - } - if ipNetAddr == nil { - // This should never happen but if it does, sleep a bit and try again. - glog.V(1).Infof("_getOutboundConn: No valid addresses to connect to.") - time.Sleep(time.Second) - continue - } +func (cmgr *ConnectionManager) IsAttemptedOutboundIpAddress(netAddr *wire.NetAddress) bool { + cmgr.mtxAddrsMaps.RLock() + defer cmgr.mtxAddrsMaps.RUnlock() + return cmgr.attemptedOutboundAddrs[addrmgr.NetAddressKey(netAddr)] +} - netAddr := net.TCPAddr{ - IP: ipNetAddr.IP, - Port: int(ipNetAddr.Port), - } +func (cmgr *ConnectionManager) AddAttemptedOutboundAddrs(netAddr *wire.NetAddress) { + cmgr.mtxAddrsMaps.Lock() + defer cmgr.mtxAddrsMaps.Unlock() + cmgr.attemptedOutboundAddrs[addrmgr.NetAddressKey(netAddr)] = true +} - // If the peer is not persistent, update the addrmgr. - glog.V(1).Infof("Attempting to connect to addr: %v", netAddr) - if !isPersistent { - cmgr.AddrMgr.Attempt(ipNetAddr) - } - var err error - conn, err := net.DialTimeout(netAddr.Network(), netAddr.String(), cmgr.params.DialTimeout) - if err != nil { - // If we failed to connect to this peer, get a new address and try again. 
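Aside on the reworked _delayRetry above: instead of sleeping inline it now returns an exponential backoff duration (2^retryCount units, with no delay before the first retry), leaving the caller to decide how to wait. A minimal standalone sketch of that computation (names are illustrative):

package main

import (
	"fmt"
	"math"
	"time"
)

// backoffDelay mirrors the computation in _delayRetry: no delay before the first
// retry, then 2^retryCount units thereafter.
func backoffDelay(retryCount uint64, unit time.Duration) time.Duration {
	if retryCount == 0 {
		return 0
	}
	numUnits := int(math.Pow(2.0, float64(retryCount)))
	return time.Duration(numUnits) * unit
}

func main() {
	for ii := uint64(0); ii <= 4; ii++ {
		fmt.Println(ii, backoffDelay(ii, time.Second)) // 0s 2s 4s 8s 16s
	}
}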
- glog.V(1).Infof("Connection to addr (%v) failed: %v", netAddr, err) - continue - } +func (cmgr *ConnectionManager) RemoveAttemptedOutboundAddrs(netAddr *wire.NetAddress) { + cmgr.mtxAddrsMaps.Lock() + defer cmgr.mtxAddrsMaps.Unlock() + delete(cmgr.attemptedOutboundAddrs, addrmgr.NetAddressKey(netAddr)) +} - // We were able to dial successfully so we'll break out now. - glog.V(1).Infof("Connected to addr: %v", netAddr) +// DialPersistentOutboundConnection attempts to connect to a persistent peer. +func (cmgr *ConnectionManager) DialPersistentOutboundConnection(persistentAddr *wire.NetAddress, attemptId uint64) (_attemptId uint64) { + glog.V(2).Infof("ConnectionManager.DialPersistentOutboundConnection: Connecting to peer (IP=%v, Port=%v)", + persistentAddr.IP.String(), persistentAddr.Port) + return cmgr._dialOutboundConnection(persistentAddr, attemptId, true) +} - // If this was a non-persistent outbound connection, mark the address as - // connected in the addrmgr. - if !isPersistent { - cmgr.AddrMgr.Connected(ipNetAddr) - } +// DialOutboundConnection attempts to connect to a non-persistent peer. +func (cmgr *ConnectionManager) DialOutboundConnection(addr *wire.NetAddress, attemptId uint64) { + glog.V(2).Infof("ConnectionManager.ConnectOutboundConnection: Connecting to peer (IP=%v, Port=%v)", + addr.IP.String(), addr.Port) + cmgr._dialOutboundConnection(addr, attemptId, false) +} - // We made a successful outbound connection so return. - return conn +// CloseAttemptedConnection closes an ongoing connection attempt. +func (cmgr *ConnectionManager) CloseAttemptedConnection(attemptId uint64) { + glog.V(2).Infof("ConnectionManager.CloseAttemptedConnection: Closing connection attempt %d", attemptId) + cmgr.mtxConnectionAttempts.Lock() + defer cmgr.mtxConnectionAttempts.Unlock() + if attempt, exists := cmgr.outboundConnectionAttempts[attemptId]; exists { + attempt.Stop() + delete(cmgr.outboundConnectionAttempts, attemptId) } } -func IPToNetAddr(ipStr string, addrMgr *addrmgr.AddrManager, params *DeSoParams) (*wire.NetAddress, error) { - port := params.DefaultSocketPort - host, portstr, err := net.SplitHostPort(ipStr) - if err != nil { - // No port specified so leave port=default and set - // host to the ipStr. - host = ipStr - } else { - pp, err := strconv.ParseUint(portstr, 10, 16) - if err != nil { - return nil, errors.Wrapf(err, "IPToNetAddr: Can not parse port from %s for ip", ipStr) - } - port = uint16(pp) - } - netAddr, err := addrMgr.HostToNetAddress(host, port, 0) - if err != nil { - return nil, errors.Wrapf(err, "IPToNetAddr: Can not parse port from %s for ip", ipStr) - } - return netAddr, nil +// _dialOutboundConnection is the internal method that spawns and initiates an OutboundConnectionAttempt, which handles the +// connection attempt logic. It returns the attemptId of the attempt that was created. +func (cmgr *ConnectionManager) _dialOutboundConnection(addr *wire.NetAddress, attemptId uint64, isPersistent bool) (_attemptId uint64) { + connectionAttempt := NewOutboundConnectionAttempt(attemptId, addr, isPersistent, + cmgr.params.DialTimeout, cmgr.outboundConnectionChan) + cmgr.mtxConnectionAttempts.Lock() + cmgr.outboundConnectionAttempts[connectionAttempt.attemptId] = connectionAttempt + cmgr.mtxConnectionAttempts.Unlock() + cmgr.AddAttemptedOutboundAddrs(addr) + + connectionAttempt.Start() + return attemptId } // ConnectPeer connects either an INBOUND or OUTBOUND peer. 
If Conn == nil, @@ -392,158 +290,40 @@ func IPToNetAddr(ipStr string, addrMgr *addrmgr.AddrManager, params *DeSoParams) // is set, then we will connect only to that addr. Otherwise, we will use // the addrmgr to randomly select addrs and create OUTBOUND connections // with them until we find a worthy peer. -func (cmgr *ConnectionManager) ConnectPeer(conn net.Conn, persistentAddr *wire.NetAddress) { - // If we don't have a connection object then we will try and make an - // outbound connection to a peer to get one. - isOutbound := false - if conn == nil { - isOutbound = true - } - isPersistent := (persistentAddr != nil) - retryCount := 0 - for { - // If the peer is persistent use exponential back off delay before retrying. - if isPersistent { - _delayRetry(retryCount, persistentAddr) - } - retryCount++ - - // If this is an outbound peer, create an outbound connection. - if isOutbound { - conn = cmgr._getOutboundConn(persistentAddr) - } - - if conn == nil { - // Conn should only be nil if this is a non-persistent outbound peer. - if isPersistent { - glog.Errorf("ConnectPeer: Got a nil connection for a persistent peer. This should never happen: (%s)", persistentAddr.IP.String()) - } - - // If we end up without a connection object, it implies we had enough - // outbound peers so just return. - return - } - - // At this point Conn is set so create a peer object to do - // a version negotiation. - na, err := IPToNetAddr(conn.RemoteAddr().String(), cmgr.AddrMgr, cmgr.params) - if err != nil { - glog.Errorf("ConnectPeer: Problem calling ipToNetAddr for addr: (%s) err: (%v)", conn.RemoteAddr().String(), err) +func (cmgr *ConnectionManager) ConnectPeer(id uint64, conn net.Conn, na *wire.NetAddress, isOutbound bool, + isPersistent bool) *Peer { - // If we get an error in the conversion and this is an - // outbound connection, keep trying it. Otherwise, just return. - if isOutbound { - continue - } - return - } - peer := NewPeer(conn, isOutbound, na, isPersistent, - cmgr.stallTimeoutSeconds, - cmgr.minFeeRateNanosPerKB, - cmgr.params, - cmgr.srv.incomingMessages, cmgr, cmgr.srv, cmgr.SyncType) - - if err := peer.NegotiateVersion(cmgr.params.VersionNegotiationTimeout); err != nil { - glog.Errorf("ConnectPeer: Problem negotiating version with peer with addr: (%s) err: (%v)", conn.RemoteAddr().String(), err) - - // If we have an error in the version negotiation we disconnect - // from this peer. - peer.Conn.Close() - - // If the connection is outbound, then - // we try a new connection until we get one that works. Otherwise - // we break. - if isOutbound { - continue - } - return - } - peer._logVersionSuccess() + // At this point Conn is set so create a peer object to do a version negotiation. + peer := NewPeer(id, conn, isOutbound, na, isPersistent, + cmgr.stallTimeoutSeconds, + cmgr.minFeeRateNanosPerKB, + cmgr.params, + cmgr.srv.incomingMessages, cmgr, cmgr.srv, cmgr.SyncType, + cmgr.peerDisconnectedChan) - // If the version negotiation worked and we have an outbound non-persistent - // connection, mark the address as good in the addrmgr. - if isOutbound && !isPersistent { - cmgr.AddrMgr.Good(na) - } - - // We connected to the peer and it passed its version negotiation. - // Handle the next steps in the main loop. - cmgr.newPeerChan <- peer - - // Once we've successfully connected to a valid peer we're done. 
The connection - // manager will handle starting the peer and, if this is an outbound peer and - // the peer later disconnects, - // it will potentially try and reconnect the peer or replace the peer with - // a new one so that we always maintain a fixed number of outbound peers. - return - } -} + // Now we can add the peer to our data structures. + peer._logAddPeer() + cmgr.addPeer(peer) -func (cmgr *ConnectionManager) _initiateOutboundConnections() { - // This is a hack to make outbound connections go away. - if cmgr.targetOutboundPeers == 0 { - return - } - if len(cmgr.connectIps) > 0 { - // Connect to addresses passed via the --connect-ips flag. These addresses - // are persistent in the sense that if we disconnect from one, we will - // try to reconnect to the same one. - for _, connectIp := range cmgr.connectIps { - ipNetAddr, err := IPToNetAddr(connectIp, cmgr.AddrMgr, cmgr.params) - if err != nil { - glog.Error(errors.Errorf("Couldn't connect to IP %v: %v", connectIp, err)) - continue - } + // Start the peer's message loop. + peer.Start() - go func(na *wire.NetAddress) { - cmgr.ConnectPeer(nil, na) - }(ipNetAddr) - } - return - } - // Only connect to addresses from the addrmgr if we don't specify --connect-ips. - // These addresses are *not* persistent, meaning if we disconnect from one we'll - // try a different one. - // - // TODO: We should try more addresses than we need initially to increase the - // speed at which we saturate our outbound connections. The ConnectionManager - // will handle the disconnection from peers once we have enough outbound - // connections. I had this as the logic before but removed it because it caused - // contention of the AddrMgr's lock. - for ii := 0; ii < int(cmgr.targetOutboundPeers); ii++ { - go cmgr.ConnectPeer(nil, nil) - } + return peer } -func (cmgr *ConnectionManager) _isFromRedundantInboundIPAddress(addrToCheck net.Addr) bool { +func (cmgr *ConnectionManager) IsDuplicateInboundIPAddress(netAddr *wire.NetAddress) bool { cmgr.mtxPeerMaps.RLock() defer cmgr.mtxPeerMaps.RUnlock() // Loop through all the peers to see if any have the same IP // address. This map is normally pretty small so doing this // every time a Peer connects should be fine. - netAddr, err := IPToNetAddr(addrToCheck.String(), cmgr.AddrMgr, cmgr.params) - if err != nil { - // Return true in case we have an error. We do this because it - // will result in the peer connection not being accepted, which - // is desired in this case. - glog.Warningf(errors.Wrapf(err, - "ConnectionManager._isFromRedundantInboundIPAddress: Problem parsing "+ - "net.Addr to wire.NetAddress so marking as redundant and not "+ - "making connection").Error()) - return true - } - if netAddr == nil { - glog.Warningf("ConnectionManager._isFromRedundantInboundIPAddress: " + - "address was nil after parsing so marking as redundant and not " + - "making connection") - return true - } + // If the IP is a localhost IP let it slide. This is useful for testing fake // nodes on a local machine. // TODO: Should this be a flag? if net.IP([]byte{127, 0, 0, 1}).Equal(netAddr.IP) { - glog.V(1).Infof("ConnectionManager._isFromRedundantInboundIPAddress: Allowing " + + glog.V(1).Infof("ConnectionManager.IsDuplicateInboundIPAddress: Allowing " + "localhost IP address to connect") return false } @@ -578,38 +358,9 @@ func (cmgr *ConnectionManager) _handleInboundConnections() { continue } - // As a quick check, reject the peer if we have too many already. 
Note that - // this check isn't perfect but we have a later check at the end after doing - // a version negotiation that will properly reject the peer if this check - // messes up e.g. due to a concurrency issue. - // - // TODO: We should instead have eviction logic here to prevent - // someone from monopolizing a node's inbound connections. - numInboundPeers := atomic.LoadUint32(&cmgr.numInboundPeers) - if numInboundPeers > cmgr.maxInboundPeers { - - glog.Infof("Rejecting INBOUND peer (%s) due to max inbound peers (%d) hit.", - conn.RemoteAddr().String(), cmgr.maxInboundPeers) - conn.Close() - - continue - } - - // If we want to limit inbound connections to one per IP address, check to - // make sure this address isn't already connected. - if cmgr.limitOneInboundConnectionPerIP && - cmgr._isFromRedundantInboundIPAddress(conn.RemoteAddr()) { - - glog.Infof("Rejecting INBOUND peer (%s) due to already having an "+ - "inbound connection from the same IP with "+ - "limit_one_inbound_connection_per_ip set.", - conn.RemoteAddr().String()) - conn.Close() - - continue + cmgr.inboundConnectionChan <- &inboundConnection{ + connection: conn, } - - go cmgr.ConnectPeer(conn, nil) } }(outerListener) } @@ -622,13 +373,7 @@ func (cmgr *ConnectionManager) GetAllPeers() []*Peer { defer cmgr.mtxPeerMaps.RUnlock() allPeers := []*Peer{} - for _, pp := range cmgr.persistentPeers { - allPeers = append(allPeers, pp) - } - for _, pp := range cmgr.outboundPeers { - allPeers = append(allPeers, pp) - } - for _, pp := range cmgr.inboundPeers { + for _, pp := range cmgr.connectedPeers { allPeers = append(allPeers, pp) } @@ -686,12 +431,11 @@ func (cmgr *ConnectionManager) addPeer(pp *Peer) { // number of outbound peers. Also add the peer's address to // our map. if _, ok := peerList[pp.ID]; !ok { - cmgr.addToGroupKey(pp.netAddr) atomic.AddUint32(&cmgr.numOutboundPeers, 1) - cmgr.mtxConnectedOutboundAddrs.Lock() + cmgr.mtxAddrsMaps.Lock() cmgr.connectedOutboundAddrs[addrmgr.NetAddressKey(pp.netAddr)] = true - cmgr.mtxConnectedOutboundAddrs.Unlock() + cmgr.mtxAddrsMaps.Unlock() } } else { // This is an inbound peer. @@ -700,10 +444,45 @@ func (cmgr *ConnectionManager) addPeer(pp *Peer) { } peerList[pp.ID] = pp + cmgr.connectedPeers[pp.ID] = pp +} + +func (cmgr *ConnectionManager) getPeer(id uint64) *Peer { + cmgr.mtxPeerMaps.RLock() + defer cmgr.mtxPeerMaps.RUnlock() + + if peer, ok := cmgr.connectedPeers[id]; ok { + return peer + } + return nil +} + +func (cmgr *ConnectionManager) SendMessage(msg DeSoMessage, peerId uint64) error { + peer := cmgr.getPeer(peerId) + if peer == nil { + return fmt.Errorf("SendMessage: Peer with ID %d not found", peerId) + } + glog.V(1).Infof("SendMessage: Sending message %v to peer %d", msg.GetMsgType().String(), peerId) + peer.AddDeSoMessage(msg, false) + return nil +} + +func (cmgr *ConnectionManager) CloseConnection(peerId uint64) { + glog.V(2).Infof("ConnectionManager.CloseConnection: Closing connection to peer (id= %v)", peerId) + + var peer *Peer + var ok bool + cmgr.mtxPeerMaps.Lock() + peer, ok = cmgr.connectedPeers[peerId] + cmgr.mtxPeerMaps.Unlock() + if !ok { + return + } + peer.Disconnect() } // Update our data structures to remove this peer. -func (cmgr *ConnectionManager) RemovePeer(pp *Peer) { +func (cmgr *ConnectionManager) removePeer(pp *Peer) { // Acquire the mtxPeerMaps lock for writing. 
cmgr.mtxPeerMaps.Lock() defer cmgr.mtxPeerMaps.Unlock() @@ -724,9 +503,9 @@ func (cmgr *ConnectionManager) RemovePeer(pp *Peer) { cmgr.subFromGroupKey(pp.netAddr) atomic.AddUint32(&cmgr.numOutboundPeers, Uint32Dec) - cmgr.mtxConnectedOutboundAddrs.Lock() + cmgr.mtxAddrsMaps.Lock() delete(cmgr.connectedOutboundAddrs, addrmgr.NetAddressKey(pp.netAddr)) - cmgr.mtxConnectedOutboundAddrs.Unlock() + cmgr.mtxAddrsMaps.Unlock() } } else { // This is an inbound peer. @@ -737,25 +516,12 @@ func (cmgr *ConnectionManager) RemovePeer(pp *Peer) { // Update the last seen time before we finish removing the peer. // TODO: Really, we call 'Connected()' on removing a peer? // I can't find a Disconnected() but seems odd. - cmgr.AddrMgr.Connected(pp.netAddr) + // FIXME: Move this to Done Peer + //cmgr.AddrMgr.Connected(pp.netAddr) // Remove the peer from our data structure. delete(peerList, pp.ID) -} - -func (cmgr *ConnectionManager) _maybeReplacePeer(pp *Peer) { - // If the peer was outbound, replace her with a - // new peer to maintain a fixed number of outbound connections. - if pp.isOutbound { - // If the peer is not persistent then we don't want to pass an - // address to connectPeer. The lack of an address will cause it - // to choose random addresses from the addrmgr until one works. - na := pp.netAddr - if !pp.isPersistent { - na = nil - } - go cmgr.ConnectPeer(nil, na) - } + delete(cmgr.connectedPeers, pp.ID) } func (cmgr *ConnectionManager) _logOutboundPeerData() { @@ -763,24 +529,32 @@ func (cmgr *ConnectionManager) _logOutboundPeerData() { numInboundPeers := int(atomic.LoadUint32(&cmgr.numInboundPeers)) numPersistentPeers := int(atomic.LoadUint32(&cmgr.numPersistentPeers)) glog.V(1).Infof("Num peers: OUTBOUND(%d) INBOUND(%d) PERSISTENT(%d)", numOutboundPeers, numInboundPeers, numPersistentPeers) +} - cmgr.mtxOutboundConnIPGroups.Lock() - for _, vv := range cmgr.outboundConnIPGroups { - if vv != 0 && vv != 1 { - glog.V(1).Infof("_logOutboundPeerData: Peer group count != (0 or 1). "+ - "Is (%d) instead. This "+ - "should never happen.", vv) - } - } - cmgr.mtxOutboundConnIPGroups.Unlock() +func (cmgr *ConnectionManager) AddTimeSample(addrStr string, timeSample time.Time) { + cmgr.timeSource.AddTimeSample(addrStr, timeSample) +} + +func (cmgr *ConnectionManager) GetNumInboundPeers() uint32 { + return atomic.LoadUint32(&cmgr.numInboundPeers) +} + +func (cmgr *ConnectionManager) GetNumOutboundPeers() uint32 { + return atomic.LoadUint32(&cmgr.numOutboundPeers) } func (cmgr *ConnectionManager) Stop() { + cmgr.mtxPeerMaps.Lock() + defer cmgr.mtxPeerMaps.Unlock() + if atomic.AddInt32(&cmgr.shutdown, 1) != 1 { glog.Warningf("ConnectionManager.Stop is already in the process of " + "shutting down") return } + for id := range cmgr.outboundConnectionAttempts { + cmgr.CloseAttemptedConnection(id) + } glog.Infof("ConnectionManager: Stopping, number of inbound peers (%v), number of outbound "+ "peers (%v), number of persistent peers (%v).", len(cmgr.inboundPeers), len(cmgr.outboundPeers), len(cmgr.persistentPeers)) @@ -823,10 +597,6 @@ func (cmgr *ConnectionManager) Start() { // - Have the peer enter a switch statement listening for all kinds of messages. // - Send addr and getaddr messages as appropriate. - // Initiate outbound connections with peers either using the --connect-ips passed - // in or using the addrmgr. - cmgr._initiateOutboundConnections() - // Accept inbound connections from peers on our listeners. 
cmgr._handleInboundConnections() @@ -837,90 +607,51 @@ func (cmgr *ConnectionManager) Start() { cmgr._logOutboundPeerData() select { - case pp := <-cmgr.newPeerChan: - { - // We have successfully connected to a peer and it passed its version - // negotiation. - - // if this is a non-persistent outbound peer and we already have enough - // outbound peers, then don't bother adding this one. - if !pp.isPersistent && pp.isOutbound && cmgr.enoughOutboundPeers() { - // TODO: Make this less verbose - glog.V(1).Infof("Dropping peer because we already have enough outbound peer connections.") - pp.Conn.Close() - continue - } - - // If this is a non-persistent outbound peer and the group key - // overlaps with another peer we're already connected to then - // abort mission. We only connect to one peer per IP group in - // order to prevent Sybil attacks. - if pp.isOutbound && - !pp.isPersistent && - cmgr.isRedundantGroupKey(pp.netAddr) { - - // TODO: Make this less verbose - glog.Infof("Rejecting OUTBOUND NON-PERSISTENT peer (%v) with "+ - "redundant group key (%s).", - pp, addrmgr.GroupKey(pp.netAddr)) - - pp.Conn.Close() - cmgr._maybeReplacePeer(pp) - continue - } - - // Check that we have not exceeded the maximum number of inbound - // peers allowed. - // - // TODO: We should instead have eviction logic to prevent - // someone from monopolizing a node's inbound connections. - numInboundPeers := atomic.LoadUint32(&cmgr.numInboundPeers) - if !pp.isOutbound && numInboundPeers > cmgr.maxInboundPeers { - - // TODO: Make this less verbose - glog.Infof("Rejecting INBOUND peer (%v) due to max inbound peers (%d) hit.", - pp, cmgr.maxInboundPeers) - - pp.Conn.Close() - continue - } - - // Now we can add the peer to our data structures. - pp._logAddPeer() - cmgr.addPeer(pp) - - // Start the peer's message loop. - pp.Start() - - // Signal the server about the new Peer in case it wants to do something with it. - cmgr.serverMessageQueue <- &ServerMessage{ - Peer: pp, - Msg: &MsgDeSoNewPeer{}, - } - + case oc := <-cmgr.outboundConnectionChan: + if oc.failed { + glog.V(2).Infof("ConnectionManager.Start: Failed to establish an outbound connection with "+ + "(id= %v)", oc.attemptId) + } else { + glog.V(2).Infof("ConnectionManager.Start: Successfully established an outbound connection with "+ + "(addr= %v) (id= %v)", oc.connection.RemoteAddr(), oc.attemptId) + } + cmgr.mtxConnectionAttempts.Lock() + delete(cmgr.outboundConnectionAttempts, oc.attemptId) + cmgr.mtxConnectionAttempts.Unlock() + cmgr.serverMessageQueue <- &ServerMessage{ + Peer: nil, + Msg: &MsgDeSoNewConnection{ + Connection: oc, + }, } - case pp := <-cmgr.donePeerChan: + case ic := <-cmgr.inboundConnectionChan: + glog.V(2).Infof("ConnectionManager.Start: Successfully received an inbound connection from "+ + "(addr= %v)", ic.connection.RemoteAddr()) + cmgr.serverMessageQueue <- &ServerMessage{ + Peer: nil, + Msg: &MsgDeSoNewConnection{ + Connection: ic, + }, + } + case pp := <-cmgr.peerDisconnectedChan: { // By the time we get here, it can be assumed that the Peer's Disconnect function // has already been called, since that is what's responsible for adding the peer // to this queue in the first place. - glog.V(1).Infof("Done with peer (%v).", pp) + glog.V(1).Infof("Done with peer (id=%v).", pp.ID) - if !pp.PeerManuallyRemovedFromConnectionManager { - // Remove the peer from our data structures. - cmgr.RemovePeer(pp) + // Remove the peer from our data structures. + cmgr.removePeer(pp) - // Potentially replace the peer. 
For example, if the Peer was an outbound Peer - // then we want to find a new peer in order to maintain our TargetOutboundPeers. - cmgr._maybeReplacePeer(pp) - } + // Potentially replace the peer. For example, if the Peer was an outbound Peer + // then we want to find a new peer in order to maintain our TargetOutboundPeers. // Signal the server about the Peer being done in case it wants to do something // with it. cmgr.serverMessageQueue <- &ServerMessage{ Peer: pp, - Msg: &MsgDeSoDonePeer{}, + Msg: &MsgDeSoDisconnectedPeer{}, } } } diff --git a/lib/constants.go b/lib/constants.go index 9c92a47a3..3bb345abb 100644 --- a/lib/constants.go +++ b/lib/constants.go @@ -496,6 +496,30 @@ func GetEncoderMigrationHeightsList(forkHeights *ForkHeights) ( return migrationHeightsList } +type ProtocolVersionType uint64 + +const ( + ProtocolVersion0 ProtocolVersionType = 0 + ProtocolVersion1 ProtocolVersionType = 1 + ProtocolVersion2 ProtocolVersionType = 2 +) + +func NewProtocolVersionType(version uint64) ProtocolVersionType { + return ProtocolVersionType(version) +} + +func (pvt ProtocolVersionType) ToUint64() uint64 { + return uint64(pvt) +} + +func (pvt ProtocolVersionType) Before(version ProtocolVersionType) bool { + return pvt.ToUint64() < version.ToUint64() +} + +func (pvt ProtocolVersionType) After(version ProtocolVersionType) bool { + return pvt.ToUint64() > version.ToUint64() +} + // DeSoParams defines the full list of possible parameters for the // DeSo network. type DeSoParams struct { @@ -504,7 +528,7 @@ type DeSoParams struct { // Set to true when we're running in regtest mode. This is useful for testing. ExtraRegtestParamUpdaterKeys map[PkMapKey]bool // The current protocol version we're running. - ProtocolVersion uint64 + ProtocolVersion ProtocolVersionType // The minimum protocol version we'll allow a peer we connect to // to have. MinProtocolVersion uint64 @@ -562,6 +586,11 @@ type DeSoParams struct { DialTimeout time.Duration // The amount of time we wait to receive a version message from a peer. VersionNegotiationTimeout time.Duration + // The amount of time we wait to receive a verack message from a peer. + VerackNegotiationTimeout time.Duration + + // The maximum number of addresses to broadcast to peers. + MaxAddressesToBroadcast uint32 // The genesis block to use as the base of our chain. GenesisBlock *MsgDeSoBlock @@ -755,6 +784,12 @@ type DeSoParams struct { // for a description of its usage. DefaultMempoolFeeEstimatorNumPastBlocks uint64 + // HandshakeTimeoutMicroSeconds is the timeout for the peer handshake certificate. The default value is 15 minutes. + HandshakeTimeoutMicroSeconds uint64 + + // DisableNetworkManagerRoutines is a testing flag that disables the network manager routines. + DisableNetworkManagerRoutines bool + ForkHeights ForkHeights EncoderMigrationHeights *EncoderMigrationHeights @@ -819,6 +854,9 @@ func (params *DeSoParams) EnableRegtest() { // Clear the seeds params.DNSSeeds = []string{} + // Set the protocol version + params.ProtocolVersion = ProtocolVersion2 + // Mine blocks incredibly quickly params.TimeBetweenBlocks = 2 * time.Second params.TimeBetweenDifficultyRetargets = 6 * time.Second @@ -992,7 +1030,7 @@ var MainnetForkHeights = ForkHeights{ // DeSoMainnetParams defines the DeSo parameters for the mainnet. 
var DeSoMainnetParams = DeSoParams{ NetworkType: NetworkType_MAINNET, - ProtocolVersion: 1, + ProtocolVersion: ProtocolVersion1, MinProtocolVersion: 1, UserAgent: "Architect", DNSSeeds: []string{ @@ -1075,6 +1113,9 @@ var DeSoMainnetParams = DeSoParams{ DialTimeout: 30 * time.Second, VersionNegotiationTimeout: 30 * time.Second, + VerackNegotiationTimeout: 30 * time.Second, + + MaxAddressesToBroadcast: 10, BlockRewardMaturity: time.Hour * 3, @@ -1214,6 +1255,12 @@ var DeSoMainnetParams = DeSoParams{ // The number of past blocks to consider when estimating the mempool fee. DefaultMempoolFeeEstimatorNumPastBlocks: 50, + // The peer handshake certificate timeout. + HandshakeTimeoutMicroSeconds: uint64(900000000), + + // DisableNetworkManagerRoutines is a testing flag that disables the network manager routines. + DisableNetworkManagerRoutines: false, + ForkHeights: MainnetForkHeights, EncoderMigrationHeights: GetEncoderMigrationHeights(&MainnetForkHeights), EncoderMigrationHeightsList: GetEncoderMigrationHeightsList(&MainnetForkHeights), @@ -1313,7 +1360,7 @@ var TestnetForkHeights = ForkHeights{ // DeSoTestnetParams defines the DeSo parameters for the testnet. var DeSoTestnetParams = DeSoParams{ NetworkType: NetworkType_TESTNET, - ProtocolVersion: 0, + ProtocolVersion: ProtocolVersion0, MinProtocolVersion: 0, UserAgent: "Architect", DNSSeeds: []string{ @@ -1356,6 +1403,9 @@ var DeSoTestnetParams = DeSoParams{ DialTimeout: 30 * time.Second, VersionNegotiationTimeout: 30 * time.Second, + VerackNegotiationTimeout: 30 * time.Second, + + MaxAddressesToBroadcast: 10, GenesisBlock: &GenesisBlock, GenesisBlockHashHex: GenesisBlockHashHex, @@ -1497,6 +1547,12 @@ var DeSoTestnetParams = DeSoParams{ // The number of past blocks to consider when estimating the mempool fee. DefaultMempoolFeeEstimatorNumPastBlocks: 50, + // The peer handshake certificate timeout. + HandshakeTimeoutMicroSeconds: uint64(900000000), + + // DisableNetworkManagerRoutines is a testing flag that disables the network manager routines. + DisableNetworkManagerRoutines: false, + ForkHeights: TestnetForkHeights, EncoderMigrationHeights: GetEncoderMigrationHeights(&TestnetForkHeights), EncoderMigrationHeightsList: GetEncoderMigrationHeightsList(&TestnetForkHeights), diff --git a/lib/db_utils.go b/lib/db_utils.go index e0aa9f9de..5ee4e649a 100644 --- a/lib/db_utils.go +++ b/lib/db_utils.go @@ -595,7 +595,7 @@ type DBPrefixes struct { // PrefixSnapshotValidatorBLSPublicKeyPKIDPairEntry: Retrieve a snapshotted BLSPublicKeyPKIDPairEntry // by BLS Public Key and SnapshotAtEpochNumber. 
// Prefix, , -> *BLSPublicKeyPKIDPairEntry - PrefixSnapshotValidatorBLSPublicKeyPKIDPairEntry []byte `prefix_id:"[96]" is_state:"true"` + PrefixSnapshotValidatorBLSPublicKeyPKIDPairEntry []byte `prefix_id:"[96]" is_state:"true" core_state:"true"` // NEXT_TAG: 97 } diff --git a/lib/legacy_mempool.go b/lib/legacy_mempool.go index c7e3bb770..10eebc11b 100644 --- a/lib/legacy_mempool.go +++ b/lib/legacy_mempool.go @@ -241,7 +241,7 @@ func (mp *DeSoMempool) IsRunning() bool { return !mp.stopped } -func (mp *DeSoMempool) AddTransaction(txn *MempoolTransaction, verifySignature bool) error { +func (mp *DeSoMempool) AddTransaction(txn *MempoolTransaction) error { return errors.New("Not implemented") } @@ -254,13 +254,13 @@ func (mp *DeSoMempool) GetTransaction(txnHash *BlockHash) *MempoolTransaction { if !exists { return nil } - return NewMempoolTransaction(mempoolTx.Tx, mempoolTx.Added) + return NewMempoolTransaction(mempoolTx.Tx, mempoolTx.Added, true) } func (mp *DeSoMempool) GetTransactions() []*MempoolTransaction { return collections.Transform( mp.GetOrderedTransactions(), func(mempoolTx *MempoolTx) *MempoolTransaction { - return NewMempoolTransaction(mempoolTx.Tx, mempoolTx.Added) + return NewMempoolTransaction(mempoolTx.Tx, mempoolTx.Added, true) }, ) } @@ -270,11 +270,6 @@ func (mp *DeSoMempool) GetIterator() MempoolIterator { panic("implement me") } -func (mp *DeSoMempool) Refresh() error { - //TODO implement me - panic("implement me") -} - func (mp *DeSoMempool) UpdateLatestBlock(blockView *UtxoView, blockHeight uint64) { //TODO implement me panic("implement me") @@ -286,6 +281,8 @@ func (mp *DeSoMempool) UpdateGlobalParams(globalParams *GlobalParamsEntry) { } func (mp *DeSoMempool) GetOrderedTransactions() []*MempoolTx { + mp.mtx.RLock() + defer mp.mtx.RUnlock() orderedTxns, _, _ := mp.GetTransactionsOrderedByTimeAdded() return orderedTxns } @@ -2129,24 +2126,6 @@ func _computeBitcoinExchangeFields(params *DeSoParams, }, PkToString(publicKey.SerializeCompressed(), params), nil } -func ConnectTxnAndComputeTransactionMetadata( - txn *MsgDeSoTxn, utxoView *UtxoView, blockHash *BlockHash, - blockHeight uint32, blockTimestampNanoSecs int64, txnIndexInBlock uint64) (*TransactionMetadata, error) { - - totalNanosPurchasedBefore := utxoView.NanosPurchased - usdCentsPerBitcoinBefore := utxoView.GetCurrentUSDCentsPerBitcoin() - utxoOps, totalInput, totalOutput, fees, err := utxoView._connectTransaction( - txn, txn.Hash(), blockHeight, blockTimestampNanoSecs, false, false, - ) - if err != nil { - return nil, fmt.Errorf( - "UpdateTxindex: Error connecting txn to UtxoView: %v", err) - } - - return ComputeTransactionMetadata(txn, utxoView, blockHash, totalNanosPurchasedBefore, - usdCentsPerBitcoinBefore, totalInput, totalOutput, fees, txnIndexInBlock, utxoOps, uint64(blockHeight)), nil -} - // This is the main function used for adding a new txn to the pool. It will // run all needed validation on the txn before adding it, and it will only // accept the txn if these validations pass. 
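// The GetOrderedTransactions/GetTransactions changes above follow a simple pattern: take the
// pool's RWMutex read lock before touching shared state, then map the internal *MempoolTx
// representation into the externally visible *MempoolTransaction. Below is a minimal sketch of
// that read-locked snapshot pattern, using a simplified pool type that is not part of this
// patch (assumes the surrounding lib package and a "sync" import):
type sketchTxnPool struct {
	mtx  sync.RWMutex
	txns []*MsgDeSoTxn
}

// Snapshot copies the pool's current ordering under the read lock so callers never observe
// later mutations of the slice.
func (sp *sketchTxnPool) Snapshot() []*MsgDeSoTxn {
	sp.mtx.RLock()
	defer sp.mtx.RUnlock()
	out := make([]*MsgDeSoTxn, len(sp.txns))
	copy(out, sp.txns)
	return out
}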
@@ -2446,7 +2425,21 @@ func EstimateMaxTxnFeeV1(txn *MsgDeSoTxn, minFeeRateNanosPerKB uint64) uint64 { func (mp *DeSoMempool) EstimateFee(txn *MsgDeSoTxn, minFeeRateNanosPerKB uint64, _ uint64, _ uint64, _ uint64, _ uint64, _ uint64) (uint64, error) { - return EstimateMaxTxnFeeV1(txn, minFeeRateNanosPerKB), nil + feeRate, _ := mp.EstimateFeeRate(minFeeRateNanosPerKB, 0, 0, 0, 0, 0) + return EstimateMaxTxnFeeV1(txn, feeRate), nil +} + +func (mp *DeSoMempool) EstimateFeeRate( + minFeeRateNanosPerKB uint64, + _ uint64, + _ uint64, + _ uint64, + _ uint64, + _ uint64) (uint64, error) { + if minFeeRateNanosPerKB < mp.readOnlyUtxoView.GlobalParamsEntry.MinimumNetworkFeeNanosPerKB { + return mp.readOnlyUtxoView.GlobalParamsEntry.MinimumNetworkFeeNanosPerKB, nil + } + return minFeeRateNanosPerKB, nil } func convertMempoolTxsToSummaryStats(mempoolTxs []*MempoolTx) map[string]*SummaryStats { diff --git a/lib/miner.go b/lib/miner.go index 06272da42..5bc1e45d1 100644 --- a/lib/miner.go +++ b/lib/miner.go @@ -198,10 +198,6 @@ func (desoMiner *DeSoMiner) MineAndProcessSingleBlock(threadIndex uint32, mempoo return nil, fmt.Errorf("DeSoMiner._startThread: _mineSingleBlock returned nil; should only happen if we're stopping") } - if desoMiner.params.IsPoSBlockHeight(blockToMine.Header.Height) { - return nil, fmt.Errorf("DeSoMiner._startThread: _mineSingleBlock returned a block that is past the Proof of Stake Cutover") - } - // Log information on the block we just mined. bestHash, _ := blockToMine.Hash() glog.Infof("================== YOU MINED A NEW BLOCK! ================== Height: %d, Hash: %s", blockToMine.Header.Height, hex.EncodeToString(bestHash[:])) @@ -293,6 +289,12 @@ func (desoMiner *DeSoMiner) _startThread(threadIndex uint32) { continue } + // Exit if blockchain has connected a block at the final PoW block height. + currentTip := desoMiner.BlockProducer.chain.blockTip() + if currentTip.Header.Height >= desoMiner.params.GetFinalPoWBlockHeight() { + return + } + newBlock, err := desoMiner.MineAndProcessSingleBlock(threadIndex, nil /*mempoolToUpdate*/) if err != nil { glog.Errorf(err.Error()) @@ -317,8 +319,12 @@ func (desoMiner *DeSoMiner) Start() { "start the miner") return } - glog.Infof("DeSoMiner.Start: Starting miner with difficulty target %s", desoMiner.params.MinDifficultyTargetHex) blockTip := desoMiner.BlockProducer.chain.blockTip() + if desoMiner.params.IsPoSBlockHeight(blockTip.Header.Height) { + glog.Infof("DeSoMiner.Start: NOT starting miner because we are at a PoS block height %d", blockTip.Header.Height) + return + } + glog.Infof("DeSoMiner.Start: Starting miner with difficulty target %s", desoMiner.params.MinDifficultyTargetHex) glog.Infof("DeSoMiner.Start: Block tip height %d, cum work %v, and difficulty %v", blockTip.Header.Height, BigintToHash(blockTip.CumWork), blockTip.DifficultyTarget) // Start a bunch of threads to mine for blocks. 
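// The two miner guards added above key off the PoW/PoS cutover height: Start refuses to spin up
// mining threads when the block tip is already at a PoS height, and each mining loop iteration
// exits once a block at the final PoW height has been connected. A minimal sketch of the same
// cutover check (helper name hypothetical, not part of this patch):
func shouldKeepMining(params *DeSoParams, tipHeight uint64) bool {
	// Stop mining as soon as the chain tip reaches the last Proof of Work block.
	return tipHeight < params.GetFinalPoWBlockHeight()
}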
@@ -400,6 +406,10 @@ func HashToBigint(hash *BlockHash) *big.Int { } func BigintToHash(bigint *big.Int) *BlockHash { + if bigint == nil { + glog.Errorf("BigintToHash: Bigint is nil") + return nil + } hexStr := bigint.Text(16) if len(hexStr)%2 != 0 { // If we have an odd number of bytes add one to the beginning (remember @@ -410,6 +420,7 @@ func BigintToHash(bigint *big.Int) *BlockHash { if err != nil { glog.Errorf("Failed in converting bigint (%#v) with hex "+ "string (%s) to hash.", bigint, hexStr) + return nil } if len(hexBytes) > HashSizeBytes { glog.Errorf("BigintToHash: Bigint %v overflows the hash size %d", bigint, HashSizeBytes) diff --git a/lib/network.go b/lib/network.go index efcc1772b..8442cbf0a 100644 --- a/lib/network.go +++ b/lib/network.go @@ -17,7 +17,6 @@ import ( "strings" "time" - "github.com/deso-protocol/core/collections/bitset" "github.com/golang/glog" "github.com/decred/dcrd/dcrec/secp256k1/v4" @@ -109,11 +108,12 @@ const ( // TODO: Should probably split these out into a separate channel in the server to // make things more parallelized. - MsgTypeQuit MsgType = ControlMessagesStart - MsgTypeNewPeer MsgType = ControlMessagesStart + 1 - MsgTypeDonePeer MsgType = ControlMessagesStart + 2 - MsgTypeBlockAccepted MsgType = ControlMessagesStart + 3 - MsgTypeBitcoinManagerUpdate MsgType = ControlMessagesStart + 4 // Deprecated + MsgTypeQuit MsgType = ControlMessagesStart + MsgTypeDisconnectedPeer MsgType = ControlMessagesStart + 1 + MsgTypeBlockAccepted MsgType = ControlMessagesStart + 2 + MsgTypeBitcoinManagerUpdate MsgType = ControlMessagesStart + 3 // Deprecated + MsgTypePeerHandshakeComplete MsgType = ControlMessagesStart + 4 + MsgTypeNewConnection MsgType = ControlMessagesStart + 5 // NEXT_TAG = 7 ) @@ -171,14 +171,16 @@ func (msgType MsgType) String() string { return "GET_ADDR" case MsgTypeQuit: return "QUIT" - case MsgTypeNewPeer: - return "NEW_PEER" - case MsgTypeDonePeer: + case MsgTypeDisconnectedPeer: return "DONE_PEER" case MsgTypeBlockAccepted: return "BLOCK_ACCEPTED" case MsgTypeBitcoinManagerUpdate: return "BITCOIN_MANAGER_UPDATE" + case MsgTypePeerHandshakeComplete: + return "PEER_HANDSHAKE_COMPLETE" + case MsgTypeNewConnection: + return "NEW_CONNECTION" case MsgTypeGetSnapshot: return "GET_SNAPSHOT" case MsgTypeSnapshotData: @@ -835,34 +837,47 @@ func (msg *MsgDeSoQuit) FromBytes(data []byte) error { return fmt.Errorf("MsgDeSoQuit.FromBytes not implemented") } -type MsgDeSoNewPeer struct { +type MsgDeSoDisconnectedPeer struct { } -func (msg *MsgDeSoNewPeer) GetMsgType() MsgType { - return MsgTypeNewPeer +func (msg *MsgDeSoDisconnectedPeer) GetMsgType() MsgType { + return MsgTypeDisconnectedPeer } -func (msg *MsgDeSoNewPeer) ToBytes(preSignature bool) ([]byte, error) { - return nil, fmt.Errorf("MsgDeSoNewPeer.ToBytes: Not implemented") +func (msg *MsgDeSoDisconnectedPeer) ToBytes(preSignature bool) ([]byte, error) { + return nil, fmt.Errorf("MsgDeSoDisconnectedPeer.ToBytes: Not implemented") } -func (msg *MsgDeSoNewPeer) FromBytes(data []byte) error { - return fmt.Errorf("MsgDeSoNewPeer.FromBytes not implemented") +func (msg *MsgDeSoDisconnectedPeer) FromBytes(data []byte) error { + return fmt.Errorf("MsgDeSoDisconnectedPeer.FromBytes not implemented") } -type MsgDeSoDonePeer struct { +type ConnectionType uint8 + +const ( + ConnectionTypeOutbound ConnectionType = iota + ConnectionTypeInbound +) + +type Connection interface { + GetConnectionType() ConnectionType + Close() +} + +type MsgDeSoNewConnection struct { + Connection Connection } -func (msg 
*MsgDeSoDonePeer) GetMsgType() MsgType { - return MsgTypeDonePeer +func (msg *MsgDeSoNewConnection) GetMsgType() MsgType { + return MsgTypeNewConnection } -func (msg *MsgDeSoDonePeer) ToBytes(preSignature bool) ([]byte, error) { - return nil, fmt.Errorf("MsgDeSoDonePeer.ToBytes: Not implemented") +func (msg *MsgDeSoNewConnection) ToBytes(preSignature bool) ([]byte, error) { + return nil, fmt.Errorf("MsgDeSoNewConnection.ToBytes: Not implemented") } -func (msg *MsgDeSoDonePeer) FromBytes(data []byte) error { - return fmt.Errorf("MsgDeSoDonePeer.FromBytes not implemented") +func (msg *MsgDeSoNewConnection) FromBytes(data []byte) error { + return fmt.Errorf("MsgDeSoNewConnection.FromBytes not implemented") } // ================================================================== @@ -1509,16 +1524,21 @@ func (msg *MsgDeSoPong) FromBytes(data []byte) error { type ServiceFlag uint64 const ( - // SFFullNodeDeprecated is deprecated, and set on all nodes by default - // now. We basically split it into SFHyperSync and SFArchivalMode. - SFFullNodeDeprecated ServiceFlag = 1 << iota + // SFFullNodeDeprecated is deprecated, and set on all nodes by default now. + SFFullNodeDeprecated ServiceFlag = 1 << 0 // SFHyperSync is a flag used to indicate that the peer supports hyper sync. - SFHyperSync + SFHyperSync ServiceFlag = 1 << 1 // SFArchivalNode is a flag complementary to SFHyperSync. If node is a hypersync node then // it might not be able to support block sync anymore, unless it has archival mode turned on. - SFArchivalNode + SFArchivalNode ServiceFlag = 1 << 2 + // SFPosValidator is a flag used to indicate that the peer is running a PoS validator. + SFPosValidator ServiceFlag = 1 << 3 ) +func (sf ServiceFlag) HasService(serviceFlag ServiceFlag) bool { + return sf&serviceFlag == serviceFlag +} + type MsgDeSoVersion struct { // What is the current version we're on? Version uint64 @@ -1542,8 +1562,7 @@ type MsgDeSoVersion struct { // The height of the last block on the main chain for // this node. // - // TODO: We need to update this to uint64 - StartBlockHeight uint32 + LatestBlockHeight uint64 // MinFeeRateNanosPerKB is the minimum feerate that a peer will // accept from other peers when validating transactions. @@ -1575,11 +1594,11 @@ func (msg *MsgDeSoVersion) ToBytes(preSignature bool) ([]byte, error) { retBytes = append(retBytes, UintToBuf(uint64(len(msg.UserAgent)))...) retBytes = append(retBytes, msg.UserAgent...) - // StartBlockHeight - retBytes = append(retBytes, UintToBuf(uint64(msg.StartBlockHeight))...) + // LatestBlockHeight + retBytes = append(retBytes, UintToBuf(msg.LatestBlockHeight)...) // MinFeeRateNanosPerKB - retBytes = append(retBytes, UintToBuf(uint64(msg.MinFeeRateNanosPerKB))...) + retBytes = append(retBytes, UintToBuf(msg.MinFeeRateNanosPerKB)...) // JSONAPIPort - deprecated retBytes = append(retBytes, UintToBuf(uint64(0))...) 
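// The ServiceFlag values above are independent bits, so a node advertises its services as a
// bitwise OR of flags and HasService tests a single bit. A minimal sketch (this particular flag
// combination is hypothetical, not part of this patch):
func sketchServiceFlagUsage() {
	services := SFHyperSync | SFArchivalNode | SFPosValidator
	_ = services.HasService(SFPosValidator)       // true: the SFPosValidator bit is set
	_ = services.HasService(SFFullNodeDeprecated) // false: that bit was never set
}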
@@ -1653,13 +1672,13 @@ func (msg *MsgDeSoVersion) FromBytes(data []byte) error { retVer.UserAgent = string(userAgent) } - // StartBlockHeight + // LatestBlockHeight { - lastBlockHeight, err := ReadUvarint(rr) - if err != nil || lastBlockHeight > math.MaxUint32 { + latestBlockHeight, err := ReadUvarint(rr) + if err != nil || latestBlockHeight > math.MaxUint32 { return errors.Wrapf(err, "MsgDeSoVersion.FromBytes: Problem converting msg.LatestBlockHeight") } - retVer.StartBlockHeight = uint32(lastBlockHeight) + retVer.LatestBlockHeight = latestBlockHeight } // MinFeeRateNanosPerKB @@ -1862,34 +1881,144 @@ func (msg *MsgDeSoGetAddr) GetMsgType() MsgType { // VERACK Message // ================================================================== -// VERACK messages have no payload. +type VerackVersion uint64 + +func NewVerackVersion(version uint64) VerackVersion { + return VerackVersion(version) +} + +const ( + VerackVersion0 VerackVersion = 0 + VerackVersion1 VerackVersion = 1 +) + +func (vv VerackVersion) ToUint64() uint64 { + return uint64(vv) +} + type MsgDeSoVerack struct { - // A verack message must contain the nonce the peer received in the - // initial version message. This ensures the peer that is communicating - // with us actually controls the address she says she does similar to - // "SYN Cookie" DDOS protection. - Nonce uint64 + // The VerackVersion0 message contains only the NonceReceived field, which is the nonce the sender received in the + // initial version message from the peer. This ensures the sender controls the network address, similarly to the + // "SYN Cookie" DDOS protection. The Version field in the VerackVersion0 message is implied, based on the msg length. + // + // The VerackVersion1 message contains the tuple of <NonceReceived, NonceSent, TstampMicro>, which correspond to the + // received and sent nonces in the version message from the sender's perspective, as well as a recent timestamp. + // The VerackVersion1 message is used in the context of Proof of Stake, where validators register their BLS public keys + // as part of their validator entry. The sender of this message must be a registered validator, and they must attach + // their public key to the message, along with a BLS signature of the tuple. + Version VerackVersion + + NonceReceived uint64 + NonceSent uint64 + TstampMicro uint64 + + PublicKey *bls.PublicKey + Signature *bls.Signature } func (msg *MsgDeSoVerack) ToBytes(preSignature bool) ([]byte, error) { + switch msg.Version { + case VerackVersion0: + return msg.EncodeVerackV0() + case VerackVersion1: + return msg.EncodeVerackV1() + default: + return nil, fmt.Errorf("MsgDeSoVerack.ToBytes: Unrecognized version: %v", msg.Version) + } +} + +func (msg *MsgDeSoVerack) EncodeVerackV0() ([]byte, error) { retBytes := []byte{} // Nonce - retBytes = append(retBytes, UintToBuf(msg.Nonce)...) + retBytes = append(retBytes, UintToBuf(msg.NonceReceived)...) + return retBytes, nil +} + +func (msg *MsgDeSoVerack) EncodeVerackV1() ([]byte, error) { + retBytes := []byte{} + + // Version + retBytes = append(retBytes, UintToBuf(msg.Version.ToUint64())...) + // Nonce Received + retBytes = append(retBytes, UintToBuf(msg.NonceReceived)...) + // Nonce Sent + retBytes = append(retBytes, UintToBuf(msg.NonceSent)...) + // Tstamp Micro + retBytes = append(retBytes, UintToBuf(msg.TstampMicro)...) + // PublicKey + retBytes = append(retBytes, EncodeBLSPublicKey(msg.PublicKey)...) + // Signature + retBytes = append(retBytes, EncodeBLSSignature(msg.Signature)...) 
+ return retBytes, nil } func (msg *MsgDeSoVerack) FromBytes(data []byte) error { rr := bytes.NewReader(data) - retMsg := NewMessage(MsgTypeVerack).(*MsgDeSoVerack) - { - nonce, err := ReadUvarint(rr) - if err != nil { - return errors.Wrapf(err, "MsgDeSoVerack.FromBytes: Problem reading Nonce") - } - retMsg.Nonce = nonce + // The V0 verack message is determined from the message length. The V0 message will only contain the NonceReceived field. + if len(data) <= MaxVarintLen64 { + return msg.FromBytesV0(data) + } + + version, err := ReadUvarint(rr) + if err != nil { + return errors.Wrapf(err, "MsgDeSoVerack.FromBytes: Problem reading Version") + } + msg.Version = NewVerackVersion(version) + switch msg.Version { + case VerackVersion0: + return fmt.Errorf("MsgDeSoVerack.FromBytes: Outdated Version=0 used for new encoding") + case VerackVersion1: + return msg.FromBytesV1(data) + default: + return fmt.Errorf("MsgDeSoVerack.FromBytes: Unrecognized version: %v", msg.Version) + } +} + +func (msg *MsgDeSoVerack) FromBytesV0(data []byte) error { + var err error + rr := bytes.NewReader(data) + msg.NonceReceived, err = ReadUvarint(rr) + if err != nil { + return errors.Wrapf(err, "MsgDeSoVerack.FromBytes: Problem reading Nonce") + } + return nil +} + +func (msg *MsgDeSoVerack) FromBytesV1(data []byte) error { + var err error + rr := bytes.NewReader(data) + version, err := ReadUvarint(rr) + if err != nil { + return errors.Wrapf(err, "MsgDeSoVerack.FromBytes: Problem reading Version") + } + msg.Version = NewVerackVersion(version) + + msg.NonceReceived, err = ReadUvarint(rr) + if err != nil { + return errors.Wrapf(err, "MsgDeSoVerack.FromBytes: Problem reading Nonce Received") + } + + msg.NonceSent, err = ReadUvarint(rr) + if err != nil { + return errors.Wrapf(err, "MsgDeSoVerack.FromBytes: Problem reading Nonce Sent") + } + + msg.TstampMicro, err = ReadUvarint(rr) + if err != nil { + return errors.Wrapf(err, "MsgDeSoVerack.FromBytes: Problem reading Tstamp Micro") + } + + msg.PublicKey, err = DecodeBLSPublicKey(rr) + if err != nil { + return errors.Wrapf(err, "MsgDeSoVerack.FromBytes: Problem reading PublicKey") + } + + msg.Signature, err = DecodeBLSSignature(rr) + if err != nil { + return errors.Wrapf(err, "MsgDeSoVerack.FromBytes: Problem reading Signature") } - *msg = *retMsg return nil } @@ -1951,14 +2080,6 @@ type MsgDeSoHeader struct { // event that ASICs become powerful enough to have birthday problems in the future. ExtraNonce uint64 - // TransactionsConnectStatus is only used for Proof of Stake blocks, starting with - // MsgDeSoHeader version 2. For all earlier versions, this field will default to nil. - // - // The hash of the TxnConnectStatusByIndex field in MsgDeSoBlock. It is stored to ensure - // that the TxnConnectStatusByIndex is part of the header hash, which is signed by the - // proposer. The full index is stored in the block to offload space complexity. - TxnConnectStatusByIndexHash *BlockHash - // ProposerVotingPublicKey is only used for Proof of Stake blocks, starting with // MsgDeSoHeader version 2. For all earlier versions, this field will default to nil. // @@ -2175,12 +2296,6 @@ func (msg *MsgDeSoHeader) EncodeHeaderVersion2(preSignature bool) ([]byte, error // The Nonce and ExtraNonce fields are unused in version 2. We skip them // during both encoding and decoding. 
- // TxnConnectStatusByIndexHash - if msg.TxnConnectStatusByIndexHash == nil { - return nil, fmt.Errorf("EncodeHeaderVersion2: TxnConnectStatusByIndexHash must be non-nil") - } - retBytes = append(retBytes, msg.TxnConnectStatusByIndexHash[:]...) - // ProposerVotingPublicKey if msg.ProposerVotingPublicKey == nil { return nil, fmt.Errorf("EncodeHeaderVersion2: ProposerVotingPublicKey must be non-nil") @@ -2390,13 +2505,6 @@ func DecodeHeaderVersion2(rr io.Reader) (*MsgDeSoHeader, error) { retHeader.Nonce = 0 retHeader.ExtraNonce = 0 - // TxnConnectStatusByIndexHash - retHeader.TxnConnectStatusByIndexHash = &BlockHash{} - _, err = io.ReadFull(rr, retHeader.TxnConnectStatusByIndexHash[:]) - if err != nil { - return nil, errors.Wrapf(err, "MsgDeSoHeader.FromBytes: Problem decoding TxnConnectStatusByIndexHash") - } - // ProposerVotingPublicKey retHeader.ProposerVotingPublicKey, err = DecodeBLSPublicKey(rr) if err != nil { @@ -2610,11 +2718,6 @@ type MsgDeSoBlock struct { // entity, which can be useful for nodes that want to restrict who they accept blocks // from. BlockProducerInfo *BlockProducerInfo - - // This bitset field stores information whether each transaction in the block passes - // or fails to connect. The bit at i-th position is set to 1 if the i-th transaction - // in the block passes connect, and 0 otherwise. - TxnConnectStatusByIndex *bitset.Bitset } func (msg *MsgDeSoBlock) EncodeBlockCommmon(preSignature bool) ([]byte, error) { @@ -2667,29 +2770,12 @@ func (msg *MsgDeSoBlock) EncodeBlockVersion1(preSignature bool) ([]byte, error) return data, nil } -func (msg *MsgDeSoBlock) EncodeBlockVersion2(preSignature bool) ([]byte, error) { - data, err := msg.EncodeBlockCommmon(preSignature) - if err != nil { - return nil, err - } - - // TxnConnectStatusByIndex - if msg.TxnConnectStatusByIndex == nil { - return nil, fmt.Errorf("MsgDeSoBlock.EncodeBlockVersion2: TxnConnectStatusByIndex should not be nil") - } - data = append(data, EncodeBitset(msg.TxnConnectStatusByIndex)...) - - return data, nil -} - func (msg *MsgDeSoBlock) ToBytes(preSignature bool) ([]byte, error) { switch msg.Header.Version { case HeaderVersion0: return msg.EncodeBlockVersion0(preSignature) case HeaderVersion1: return msg.EncodeBlockVersion1(preSignature) - case HeaderVersion2: - return msg.EncodeBlockVersion2(preSignature) default: return nil, fmt.Errorf("MsgDeSoBlock.ToBytes: Error encoding version: %v", msg.Header.Version) } @@ -2782,14 +2868,6 @@ func (msg *MsgDeSoBlock) FromBytes(data []byte) error { } } - // Version 2 blocks have a TxnStatusConnectedIndex attached to them. - if ret.Header.Version == HeaderVersion2 { - ret.TxnConnectStatusByIndex, err = DecodeBitset(rr) - if err != nil { - return errors.Wrapf(err, "MsgDeSoBlock.FromBytes: Error decoding TxnConnectStatusByIndex") - } - } - *msg = *ret return nil } diff --git a/lib/network_connection.go b/lib/network_connection.go new file mode 100644 index 000000000..4d50d22a8 --- /dev/null +++ b/lib/network_connection.go @@ -0,0 +1,220 @@ +package lib + +import ( + "github.com/btcsuite/btcd/wire" + "github.com/golang/glog" + "net" + "sync" + "time" +) + +// outboundConnection is used to store an established connection with a peer. It can also be used to signal that the +// connection was unsuccessful, in which case the failed flag is set to true. outboundConnection is created after an +// OutboundConnectionAttempt concludes. outboundConnection implements the Connection interface. 
+type outboundConnection struct { + mtx sync.Mutex + terminated bool + + attemptId uint64 + address *wire.NetAddress + connection net.Conn + isPersistent bool + failed bool +} + +func (oc *outboundConnection) GetConnectionType() ConnectionType { + return ConnectionTypeOutbound +} + +func (oc *outboundConnection) Close() { + oc.mtx.Lock() + defer oc.mtx.Unlock() + + if oc.terminated { + return + } + if oc.connection != nil { + oc.connection.Close() + } + oc.terminated = true +} + +// inboundConnection is used to store an established connection with a peer. inboundConnection is created after +// an external peer connects to the node. inboundConnection implements the Connection interface. +type inboundConnection struct { + mtx sync.Mutex + terminated bool + + connection net.Conn +} + +func (ic *inboundConnection) GetConnectionType() ConnectionType { + return ConnectionTypeInbound +} + +func (ic *inboundConnection) Close() { + ic.mtx.Lock() + defer ic.mtx.Unlock() + + if ic.terminated { + return + } + + if ic.connection != nil { + ic.connection.Close() + } + ic.terminated = true +} + +// OutboundConnectionAttempt is used to store the state of an outbound connection attempt. It is used to initiate +// an outbound connection to a peer, and manage the lifecycle of the connection attempt. +type OutboundConnectionAttempt struct { + mtx sync.Mutex + + // attemptId is used to identify the connection attempt. It will later be the id of the peer, + // if the connection is successful. + attemptId uint64 + + // netAddr is the address of the peer we are attempting to connect to. + netAddr *wire.NetAddress + // isPersistent is used to indicate whether we should retry connecting to the peer if the connection attempt fails. + // If isPersistent is true, we will retry connecting to the peer until we are successful. Each time such connection + // fails, we will sleep according to exponential backoff. Otherwise, we will only attempt to connect to the peer once. + isPersistent bool + // dialTimeout is the amount of time we will wait before timing out an individual connection attempt. + dialTimeout time.Duration + // timeoutUnit is the unit of time we will use to calculate the exponential backoff delay. The initial timeout is + // calculated as timeoutUnit * 2^0, the second timeout is calculated as timeoutUnit * 2^1, and so on. + timeoutUnit time.Duration + // retryCount is the number of times we have attempted to connect to the peer. + retryCount uint64 + // connectionChan is used to send the result of the connection attempt to the caller thread. 
+ connectionChan chan *outboundConnection + + startGroup sync.WaitGroup + exitChan chan bool + status outboundConnectionAttemptStatus +} + +type outboundConnectionAttemptStatus int + +const ( + outboundConnectionAttemptInitialized outboundConnectionAttemptStatus = 0 + outboundConnectionAttemptRunning outboundConnectionAttemptStatus = 1 + outboundConnectionAttemptTerminated outboundConnectionAttemptStatus = 2 +) + +func NewOutboundConnectionAttempt(attemptId uint64, netAddr *wire.NetAddress, isPersistent bool, + dialTimeout time.Duration, connectionChan chan *outboundConnection) *OutboundConnectionAttempt { + + return &OutboundConnectionAttempt{ + attemptId: attemptId, + netAddr: netAddr, + isPersistent: isPersistent, + dialTimeout: dialTimeout, + timeoutUnit: time.Second, + exitChan: make(chan bool), + connectionChan: connectionChan, + status: outboundConnectionAttemptInitialized, + } +} + +func (oca *OutboundConnectionAttempt) Start() { + oca.mtx.Lock() + defer oca.mtx.Unlock() + + if oca.status != outboundConnectionAttemptInitialized { + return + } + + oca.startGroup.Add(1) + go oca.start() + oca.startGroup.Wait() + oca.status = outboundConnectionAttemptRunning +} + +func (oca *OutboundConnectionAttempt) start() { + oca.startGroup.Done() + oca.retryCount = 0 + +out: + for { + sleepDuration := 0 * time.Second + // for persistent peers, calculate the exponential backoff delay. + if oca.isPersistent { + sleepDuration = _delayRetry(oca.retryCount, oca.netAddr, oca.timeoutUnit) + } + + select { + case <-oca.exitChan: + break out + case <-time.After(sleepDuration): + // If the peer is persistent use exponential back off delay before retrying. + // We want to start backing off exponentially once we've gone through enough + // unsuccessful retries. + if oca.isPersistent { + oca.retryCount++ + } + + conn := oca.attemptOutboundConnection() + if conn == nil && oca.isPersistent { + break + } + if conn == nil { + break out + } + + oca.connectionChan <- &outboundConnection{ + attemptId: oca.attemptId, + address: oca.netAddr, + connection: conn, + isPersistent: oca.isPersistent, + failed: false, + } + return + } + } + oca.connectionChan <- &outboundConnection{ + attemptId: oca.attemptId, + address: oca.netAddr, + connection: nil, + isPersistent: oca.isPersistent, + failed: true, + } +} + +func (oca *OutboundConnectionAttempt) Stop() { + oca.mtx.Lock() + defer oca.mtx.Unlock() + + if oca.status == outboundConnectionAttemptTerminated { + return + } + close(oca.exitChan) + oca.status = outboundConnectionAttemptTerminated +} + +func (oca *OutboundConnectionAttempt) SetTimeoutUnit(timeoutUnit time.Duration) { + oca.timeoutUnit = timeoutUnit +} + +// attemptOutboundConnection dials the peer. If the connection attempt is successful, it will return the connection. +// Otherwise, it will return nil. +func (oca *OutboundConnectionAttempt) attemptOutboundConnection() net.Conn { + // If the peer is not persistent, update the addrmgr. + glog.V(1).Infof("Attempting to connect to addr: %v:%v", oca.netAddr.IP.String(), oca.netAddr.Port) + + var err error + tcpAddr := net.TCPAddr{ + IP: oca.netAddr.IP, + Port: int(oca.netAddr.Port), + } + conn, err := net.DialTimeout(tcpAddr.Network(), tcpAddr.String(), oca.dialTimeout) + if err != nil { + // If we failed to connect to this peer, get a new address and try again. 
+ glog.V(2).Infof("Connection to addr (%v) failed: %v", tcpAddr, err) + return nil + } + + return conn +} diff --git a/lib/network_connection_test.go b/lib/network_connection_test.go new file mode 100644 index 000000000..5d3008f72 --- /dev/null +++ b/lib/network_connection_test.go @@ -0,0 +1,167 @@ +package lib + +import ( + "fmt" + "github.com/btcsuite/btcd/addrmgr" + "github.com/btcsuite/btcd/wire" + "github.com/stretchr/testify/require" + "net" + "sync" + "testing" + "time" +) + +type simpleListener struct { + t *testing.T + ll net.Listener + addr *wire.NetAddress + closed bool + + connectionChan chan Connection + + exitChan chan struct{} + startGroup sync.WaitGroup + stopGroup sync.WaitGroup +} + +func newSimpleListener(t *testing.T) *simpleListener { + require := require.New(t) + ll, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(err) + params := &DeSoTestnetParams + addr := ll.Addr() + addrMgr := addrmgr.New("", net.LookupIP) + na, err := IPToNetAddr(addr.String(), addrMgr, params) + + return &simpleListener{ + t: t, + ll: ll, + addr: na, + closed: false, + connectionChan: make(chan Connection, 100), + exitChan: make(chan struct{}), + } +} + +func (sl *simpleListener) start() { + require := require.New(sl.t) + if sl.closed { + ll, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%v", sl.addr.Port)) + require.NoError(err) + sl.ll = ll + sl.connectionChan = make(chan Connection, 100) + sl.exitChan = make(chan struct{}) + sl.closed = false + } + sl.startGroup.Add(1) + sl.stopGroup.Add(1) + + go func() { + sl.startGroup.Done() + defer sl.stopGroup.Done() + for { + select { + case <-sl.exitChan: + return + default: + conn, err := sl.ll.Accept() + if err != nil { + fmt.Println("simpleListener.start: ll.Accept:", err) + return + } + sl.connectionChan <- &inboundConnection{ + connection: conn, + } + } + } + }() + sl.startGroup.Wait() +} + +func (sl *simpleListener) stop() { + sl.ll.Close() + sl.closed = true + close(sl.exitChan) + close(sl.connectionChan) + sl.stopGroup.Wait() + fmt.Println("simpleListener.stop: stopped") +} + +func (sl *simpleListener) getTCPAddr() *net.TCPAddr { + return sl.ll.Addr().(*net.TCPAddr) +} + +func verifyOutboundConnection(t *testing.T, conn *outboundConnection, sl *simpleListener, attemptId uint64, isPersistent bool, failed bool) { + require := require.New(t) + require.Equal(attemptId, conn.attemptId) + require.Equal(isPersistent, conn.isPersistent) + require.Equal(failed, conn.failed) + if failed { + require.Nil(conn.connection) + return + } + + require.Equal(conn.address.IP.String(), sl.getTCPAddr().IP.String()) + require.Equal(conn.address.Port, uint16(sl.getTCPAddr().Port)) + require.Equal(conn.address.IP.String(), sl.getTCPAddr().IP.String()) + require.Equal(conn.address.Port, uint16(sl.getTCPAddr().Port)) +} + +func verifyOutboundConnectionSelect(t *testing.T, connectionChan chan *outboundConnection, timeoutDuration time.Duration, + sl *simpleListener, attemptId uint64, isPersistent bool, failed bool) { + + select { + case conn := <-connectionChan: + verifyOutboundConnection(t, conn, sl, attemptId, isPersistent, failed) + case <-time.After(2 * timeoutDuration): + panic("Timed out waiting for outbound connection.") + } +} + +func TestOutboundConnectionAttempt(t *testing.T) { + require := require.New(t) + _ = require + timeoutDuration := 100 * time.Millisecond + + sl := newSimpleListener(t) + sl.start() + + connectionChan := make(chan *outboundConnection, 100) + attempt := NewOutboundConnectionAttempt(0, sl.addr, false, timeoutDuration, 
connectionChan) + attempt.Start() + verifyOutboundConnectionSelect(t, connectionChan, 2*timeoutDuration, sl, 0, false, false) + t.Log("TestOutboundConnectionAttempt #1 | Happy path, non-persistent | PASS") + + sl.stop() + attemptFailed := NewOutboundConnectionAttempt(1, sl.addr, false, timeoutDuration, connectionChan) + attemptFailed.Start() + verifyOutboundConnectionSelect(t, connectionChan, 2*timeoutDuration, sl, 1, false, true) + t.Log("TestOutboundConnectionAttempt #2 | Failed connection, non-persistent | PASS") + + sl2 := newSimpleListener(t) + sl2.start() + + attemptPersistent := NewOutboundConnectionAttempt(2, sl2.addr, true, timeoutDuration, connectionChan) + attemptPersistent.Start() + verifyOutboundConnectionSelect(t, connectionChan, 2*timeoutDuration, sl2, 2, true, false) + t.Log("TestOutboundConnectionAttempt #3 | Happy path, persistent | PASS") + + sl2.stop() + attemptPersistentDelay := NewOutboundConnectionAttempt(3, sl2.addr, true, timeoutDuration, connectionChan) + attemptPersistentDelay.SetTimeoutUnit(timeoutDuration) + attemptPersistentDelay.Start() + time.Sleep(timeoutDuration) + sl2.start() + verifyOutboundConnectionSelect(t, connectionChan, 2*timeoutDuration, sl2, 3, true, false) + require.Greater(attemptPersistentDelay.retryCount, uint64(0)) + t.Log("TestOutboundConnectionAttempt #4 | Failed connection, persistent, delayed | PASS") + + sl2.stop() + attemptPersistentCancel := NewOutboundConnectionAttempt(4, sl2.addr, true, timeoutDuration, connectionChan) + attemptPersistentCancel.Start() + time.Sleep(timeoutDuration) + attemptPersistentCancel.Stop() + verifyOutboundConnectionSelect(t, connectionChan, 2*timeoutDuration, sl2, 4, true, true) + require.Greater(attemptPersistentCancel.retryCount, uint64(0)) + t.Log("TestOutboundConnectionAttempt #5 | Failed connection, persistent, delayed, canceled | PASS") +} diff --git a/lib/network_manager.go b/lib/network_manager.go new file mode 100644 index 000000000..d48a281e0 --- /dev/null +++ b/lib/network_manager.go @@ -0,0 +1,1230 @@ +package lib + +import ( + "fmt" + "github.com/btcsuite/btcd/addrmgr" + "github.com/btcsuite/btcd/wire" + "github.com/decred/dcrd/lru" + "github.com/deso-protocol/core/bls" + "github.com/deso-protocol/core/collections" + "github.com/deso-protocol/core/consensus" + "github.com/golang/glog" + "github.com/pkg/errors" + "math" + "net" + "strconv" + "sync" + "sync/atomic" + "time" +) + +// NetworkManager is a structure that oversees all connections to RemoteNodes. NetworkManager has the following +// responsibilities in regard to the lifecycle of RemoteNodes: +// - Maintain a list of all RemoteNodes that the node is connected to through the RemoteNodeManager. +// - Initialize RemoteNodes from established outbound and inbound peer connections. +// - Initiate and handle the communication of the handshake process with RemoteNodes. +// +// The NetworkManager is also responsible for opening and closing connections. It does this by running a set of +// goroutines that periodically check the state of different categories of RemoteNodes, and disconnects or connects +// RemoteNodes as needed. These categories of RemoteNodes include: +// - Persistent RemoteNodes: These are RemoteNodes that we want to maintain a persistent (constant) connection to. +// These are specified by the --connect-ips flag. +// - Validators: These are RemoteNodes that are in the active validators set. We want to maintain a connection to +// all active validators. We also want to disconnect from any validators that are no longer active. 
+// - Non-Validators: These are RemoteNodes that are not in the active validators set. We want to maintain a connection +// to at most a target number of outbound and inbound non-validators. If we have more than the target number of +// outbound or inbound non-validators, we will disconnect the excess RemoteNodes. +// +// The NetworkManager also runs an auxiliary goroutine that periodically cleans up RemoteNodes that may have timed out +// the handshake process, or became invalid for some other reason. +type NetworkManager struct { + mtx sync.Mutex + mtxHandshakeComplete sync.Mutex + + // The parameters we are initialized with. + params *DeSoParams + + srv *Server + bc *Blockchain + cmgr *ConnectionManager + keystore *BLSKeystore + + // configs + minTxFeeRateNanosPerKB uint64 + nodeServices ServiceFlag + + // Used to set remote node ids. Must be incremented atomically. + remoteNodeIndex uint64 + // AllRemoteNodes is a map storing all remote nodes by their IDs. + AllRemoteNodes *collections.ConcurrentMap[RemoteNodeId, *RemoteNode] + + // Indices for various types of remote nodes. + ValidatorIndex *collections.ConcurrentMap[bls.SerializedPublicKey, *RemoteNode] + NonValidatorOutboundIndex *collections.ConcurrentMap[RemoteNodeId, *RemoteNode] + NonValidatorInboundIndex *collections.ConcurrentMap[RemoteNodeId, *RemoteNode] + + // Cache of nonces used during handshake. + usedNonces lru.Cache + + // The address manager keeps track of peer addresses we're aware of. When + // we need to connect to a new outbound peer, it chooses one of the addresses + // it's aware of at random and provides it to us. + AddrMgr *addrmgr.AddrManager + + // When --connect-ips is set, we don't connect to anything from the addrmgr. + connectIps []string + // persistentIpToRemoteNodeIdsMap maps persistent IP addresses, like the --connect-ips, to the RemoteNodeIds of the + // corresponding RemoteNodes. This is used to ensure that we don't connect to the same persistent IP address twice. + // And that we can reconnect to the same persistent IP address if we disconnect from it. + persistentIpToRemoteNodeIdsMap *collections.ConcurrentMap[string, RemoteNodeId] + + activeValidatorsMapLock sync.RWMutex + // activeValidatorsMap is a map of all currently active validators registered in consensus. It will be updated + // periodically by the owner of the NetworkManager. + activeValidatorsMap *collections.ConcurrentMap[bls.SerializedPublicKey, consensus.Validator] + + // The target number of non-validator outbound remote nodes we want to have. We will disconnect remote nodes once + // we've exceeded this number of outbound connections. + targetNonValidatorOutboundRemoteNodes uint32 + // The target number of non-validator inbound remote nodes we want to have. We will disconnect remote nodes once + // we've exceeded this number of inbound connections. + targetNonValidatorInboundRemoteNodes uint32 + // When true, only one connection per IP is allowed. Prevents eclipse attacks + // among other things. 
+ limitOneInboundRemoteNodePerIP bool + + startGroup sync.WaitGroup + exitChan chan struct{} + exitGroup sync.WaitGroup +} + +func NewNetworkManager(params *DeSoParams, srv *Server, bc *Blockchain, cmgr *ConnectionManager, + blsKeystore *BLSKeystore, addrMgr *addrmgr.AddrManager, connectIps []string, + targetNonValidatorOutboundRemoteNodes uint32, targetNonValidatorInboundRemoteNodes uint32, + limitOneInboundConnectionPerIP bool, minTxFeeRateNanosPerKB uint64, nodeServices ServiceFlag) *NetworkManager { + + return &NetworkManager{ + params: params, + srv: srv, + bc: bc, + cmgr: cmgr, + keystore: blsKeystore, + AddrMgr: addrMgr, + minTxFeeRateNanosPerKB: minTxFeeRateNanosPerKB, + nodeServices: nodeServices, + AllRemoteNodes: collections.NewConcurrentMap[RemoteNodeId, *RemoteNode](), + ValidatorIndex: collections.NewConcurrentMap[bls.SerializedPublicKey, *RemoteNode](), + NonValidatorOutboundIndex: collections.NewConcurrentMap[RemoteNodeId, *RemoteNode](), + NonValidatorInboundIndex: collections.NewConcurrentMap[RemoteNodeId, *RemoteNode](), + usedNonces: lru.NewCache(1000), + connectIps: connectIps, + persistentIpToRemoteNodeIdsMap: collections.NewConcurrentMap[string, RemoteNodeId](), + activeValidatorsMap: collections.NewConcurrentMap[bls.SerializedPublicKey, consensus.Validator](), + targetNonValidatorOutboundRemoteNodes: targetNonValidatorOutboundRemoteNodes, + targetNonValidatorInboundRemoteNodes: targetNonValidatorInboundRemoteNodes, + limitOneInboundRemoteNodePerIP: limitOneInboundConnectionPerIP, + exitChan: make(chan struct{}), + } +} + +func (nm *NetworkManager) Start() { + // If the NetworkManager routines are disabled, we do nothing. + if nm.params.DisableNetworkManagerRoutines { + return + } + + // Start the NetworkManager goroutines. The startGroup is used to ensure that all goroutines have started before + // exiting the context of this function. + nm.startGroup.Add(4) + go nm.startPersistentConnector() + go nm.startValidatorConnector() + go nm.startNonValidatorConnector() + go nm.startRemoteNodeCleanup() + + nm.startGroup.Wait() +} + +func (nm *NetworkManager) Stop() { + if !nm.params.DisableNetworkManagerRoutines { + nm.exitGroup.Add(4) + close(nm.exitChan) + nm.exitGroup.Wait() + } + nm.DisconnectAll() +} + +func (nm *NetworkManager) SetTargetOutboundPeers(numPeers uint32) { + nm.targetNonValidatorOutboundRemoteNodes = numPeers +} + +// ########################### +// ## NetworkManager Routines +// ########################### + +// startPersistentConnector is responsible for ensuring that the node is connected to all persistent IP addresses. It +// does this by periodically checking the persistentIpToRemoteNodeIdsMap, and connecting to any persistent IP addresses +// that are not already connected. +func (nm *NetworkManager) startPersistentConnector() { + nm.startGroup.Done() + for { + select { + case <-nm.exitChan: + nm.exitGroup.Done() + return + case <-time.After(1 * time.Second): + nm.refreshConnectIps() + } + } +} + +// startValidatorConnector is responsible for ensuring that the node is connected to all active validators. It does +// this in two steps. First, it looks through the already established connections and checks if any of these connections +// are validators. If they are, it adds them to the validator index. It also checks if any of the existing validators +// are no longer active and removes them from the validator index. Second, it checks if any of the active validators +// are missing from the validator index. If they are, it attempts to connect to them. 
+func (nm *NetworkManager) startValidatorConnector() { + nm.startGroup.Done() + for { + select { + case <-nm.exitChan: + nm.exitGroup.Done() + return + case <-time.After(1 * time.Second): + activeValidatorsMap := nm.getActiveValidatorsMap() + nm.refreshValidatorIndex(activeValidatorsMap) + nm.connectValidators(activeValidatorsMap) + } + } +} + +// startNonValidatorConnector is responsible for ensuring that the node is connected to the target number of outbound +// and inbound remote nodes. To do this, it periodically checks the number of outbound and inbound remote nodes, and +// if the number is above the target number, it disconnects the excess remote nodes. If the number is below the target +// number, it attempts to connect to new remote nodes. +func (nm *NetworkManager) startNonValidatorConnector() { + nm.startGroup.Done() + + for { + select { + case <-nm.exitChan: + nm.exitGroup.Done() + return + case <-time.After(1 * time.Second): + nm.refreshNonValidatorOutboundIndex() + nm.refreshNonValidatorInboundIndex() + nm.connectNonValidators() + } + } +} + +// startRemoteNodeCleanup is responsible for cleaning up RemoteNodes that may have timed out the handshake process, +// or became invalid for some other reason. +func (nm *NetworkManager) startRemoteNodeCleanup() { + nm.startGroup.Done() + + for { + select { + case <-nm.exitChan: + nm.exitGroup.Done() + return + case <-time.After(1 * time.Second): + nm.Cleanup() + } + } + +} + +// ########################### +// ## Handlers (Peer, DeSoMessage) +// ########################### + +// _handleVersionMessage is called when a new version message is received. +func (nm *NetworkManager) _handleVersionMessage(origin *Peer, desoMsg DeSoMessage) { + if desoMsg.GetMsgType() != MsgTypeVersion { + return + } + + rn := nm.GetRemoteNodeFromPeer(origin) + if rn == nil { + // This should never happen. + return + } + + var verMsg *MsgDeSoVersion + var ok bool + if verMsg, ok = desoMsg.(*MsgDeSoVersion); !ok { + glog.Errorf("NetworkManager.handleVersionMessage: Disconnecting RemoteNode with id: (%v) "+ + "error casting version message", origin.ID) + nm.Disconnect(rn) + return + } + + // If we've seen this nonce before then return an error since this is a connection from ourselves. + msgNonce := verMsg.Nonce + if nm.usedNonces.Contains(msgNonce) { + nm.usedNonces.Delete(msgNonce) + glog.Errorf("NetworkManager.handleVersionMessage: Disconnecting RemoteNode with id: (%v) "+ + "nonce collision, nonce (%v)", origin.ID, msgNonce) + nm.Disconnect(rn) + return + } + + // Call HandleVersionMessage on the RemoteNode. + responseNonce := uint64(RandInt64(math.MaxInt64)) + if err := rn.HandleVersionMessage(verMsg, responseNonce); err != nil { + glog.Errorf("NetworkManager.handleVersionMessage: Requesting PeerDisconnect for id: (%v) "+ + "error handling version message: %v", origin.ID, err) + nm.Disconnect(rn) + return + + } + nm.usedNonces.Add(responseNonce) +} + +// _handleVerackMessage is called when a new verack message is received. +func (nm *NetworkManager) _handleVerackMessage(origin *Peer, desoMsg DeSoMessage) { + if desoMsg.GetMsgType() != MsgTypeVerack { + return + } + + rn := nm.GetRemoteNodeFromPeer(origin) + if rn == nil { + // This should never happen. 
+ return + } + + var vrkMsg *MsgDeSoVerack + var ok bool + if vrkMsg, ok = desoMsg.(*MsgDeSoVerack); !ok { + glog.Errorf("NetworkManager.handleVerackMessage: Disconnecting RemoteNode with id: (%v) "+ + "error casting verack message", origin.ID) + nm.Disconnect(rn) + return + } + + // Call HandleVerackMessage on the RemoteNode. + if err := rn.HandleVerackMessage(vrkMsg); err != nil { + glog.Errorf("NetworkManager.handleVerackMessage: Requesting PeerDisconnect for id: (%v) "+ + "error handling verack message: %v", origin.ID, err) + nm.Disconnect(rn) + return + } + + nm.handleHandshakeComplete(rn) +} + +// _handleDisconnectedPeerMessage is called when a peer is disconnected. It is responsible for cleaning up the +// RemoteNode associated with the peer. +func (nm *NetworkManager) _handleDisconnectedPeerMessage(origin *Peer, desoMsg DeSoMessage) { + if desoMsg.GetMsgType() != MsgTypeDisconnectedPeer { + return + } + + glog.V(2).Infof("NetworkManager._handleDisconnectedPeerMessage: Handling disconnected peer message for "+ + "id=%v", origin.ID) + nm.DisconnectById(NewRemoteNodeId(origin.ID)) + // Update the persistentIpToRemoteNodeIdsMap, in case the disconnected peer was a persistent peer. + ipRemoteNodeIdMap := nm.persistentIpToRemoteNodeIdsMap.ToMap() + for ip, id := range ipRemoteNodeIdMap { + if id.ToUint64() == origin.ID { + nm.persistentIpToRemoteNodeIdsMap.Remove(ip) + } + } +} + +// _handleNewConnectionMessage is called when a new outbound or inbound connection is established. It is responsible +// for creating a RemoteNode from the connection and initiating the handshake. The incoming DeSoMessage is a control message. +func (nm *NetworkManager) _handleNewConnectionMessage(origin *Peer, desoMsg DeSoMessage) { + if desoMsg.GetMsgType() != MsgTypeNewConnection { + return + } + + msg, ok := desoMsg.(*MsgDeSoNewConnection) + if !ok { + return + } + + var remoteNode *RemoteNode + var err error + // We create the RemoteNode differently depending on whether the connection is inbound or outbound. + switch msg.Connection.GetConnectionType() { + case ConnectionTypeInbound: + remoteNode, err = nm.processInboundConnection(msg.Connection) + if err != nil { + glog.Errorf("NetworkManager.handleNewConnectionMessage: Problem handling inbound connection: %v", err) + nm.cleanupFailedInboundConnection(remoteNode, msg.Connection) + return + } + case ConnectionTypeOutbound: + remoteNode, err = nm.processOutboundConnection(msg.Connection) + if err != nil { + glog.Errorf("NetworkManager.handleNewConnectionMessage: Problem handling outbound connection: %v", err) + nm.cleanupFailedOutboundConnection(msg.Connection) + return + } + } + + // If we made it here, we have a valid remote node. We will now initiate the handshake. + nm.InitiateHandshake(remoteNode) +} + +// processInboundConnection is called when a new inbound connection is established. At this point, the connection is not validated, +// nor is it assigned to a RemoteNode. This function is responsible for validating the connection and creating a RemoteNode from it. +// Once the RemoteNode is created, we will initiate handshake. +func (nm *NetworkManager) processInboundConnection(conn Connection) (*RemoteNode, error) { + var ic *inboundConnection + var ok bool + if ic, ok = conn.(*inboundConnection); !ok { + return nil, fmt.Errorf("NetworkManager.handleInboundConnection: Connection is not an inboundConnection") + } + + // If we want to limit inbound connections to one per IP address, check to make sure this address isn't already connected. 
+ if nm.limitOneInboundRemoteNodePerIP && + nm.isDuplicateInboundIPAddress(ic.connection.RemoteAddr()) { + + return nil, fmt.Errorf("NetworkManager.handleInboundConnection: Rejecting INBOUND peer (%s) due to "+ + "already having an inbound connection from the same IP with limit_one_inbound_connection_per_ip set", + ic.connection.RemoteAddr().String()) + } + + na, err := nm.ConvertIPStringToNetAddress(ic.connection.RemoteAddr().String()) + if err != nil { + return nil, errors.Wrapf(err, "NetworkManager.handleInboundConnection: Problem calling "+ + "ConvertIPStringToNetAddress for addr: (%s)", ic.connection.RemoteAddr().String()) + } + + remoteNode, err := nm.AttachInboundConnection(ic.connection, na) + if remoteNode == nil || err != nil { + return nil, errors.Wrapf(err, "NetworkManager.handleInboundConnection: Problem calling "+ + "AttachInboundConnection for addr: (%s)", ic.connection.RemoteAddr().String()) + } + + return remoteNode, nil +} + +// processOutboundConnection is called when a new outbound connection is established. At this point, the connection is not validated, +// nor is it assigned to a RemoteNode. This function is responsible for validating the connection and creating a RemoteNode from it. +// Once the RemoteNode is created, we will initiate handshake. +func (nm *NetworkManager) processOutboundConnection(conn Connection) (*RemoteNode, error) { + var oc *outboundConnection + var ok bool + if oc, ok = conn.(*outboundConnection); !ok { + return nil, fmt.Errorf("NetworkManager.handleOutboundConnection: Connection is not an outboundConnection") + } + + if oc.failed { + return nil, fmt.Errorf("NetworkManager.handleOutboundConnection: Failed to connect to peer (%s:%v)", + oc.address.IP.String(), oc.address.Port) + } + + if !oc.isPersistent { + nm.AddrMgr.Connected(oc.address) + nm.AddrMgr.Good(oc.address) + } + + // If this is a non-persistent outbound peer and the group key overlaps with another peer we're already connected to then + // abort mission. We only connect to one peer per IP group in order to prevent Sybil attacks. + if !oc.isPersistent && nm.cmgr.IsFromRedundantOutboundIPAddress(oc.address) { + return nil, fmt.Errorf("NetworkManager.handleOutboundConnection: Rejecting OUTBOUND NON-PERSISTENT "+ + "connection with redundant group key (%s).", addrmgr.GroupKey(oc.address)) + } + + na, err := nm.ConvertIPStringToNetAddress(oc.connection.RemoteAddr().String()) + if err != nil { + return nil, errors.Wrapf(err, "NetworkManager.handleOutboundConnection: Problem calling ipToNetAddr "+ + "for addr: (%s)", oc.connection.RemoteAddr().String()) + } + + // Attach the connection before additional validation steps because it is already established. + remoteNode, err := nm.AttachOutboundConnection(oc.connection, na, oc.attemptId, oc.isPersistent) + if remoteNode == nil || err != nil { + return nil, errors.Wrapf(err, "NetworkManager.handleOutboundConnection: Problem calling AttachOutboundConnection "+ + "for addr: (%s)", oc.connection.RemoteAddr().String()) + } + + // If this is a persistent remote node or a validator, we don't need to do any extra connection validation. + if remoteNode.IsPersistent() || remoteNode.IsExpectedValidator() { + return remoteNode, nil + } + + // If we get here, it means we're dealing with a non-persistent or non-validator remote node. We perform additional + // connection validation. + + // If the group key overlaps with another peer we're already connected to then abort mission. 
We only connect to + // one peer per IP group in order to prevent Sybil attacks. + if nm.cmgr.IsFromRedundantOutboundIPAddress(oc.address) { + return nil, fmt.Errorf("NetworkManager.handleOutboundConnection: Rejecting OUTBOUND NON-PERSISTENT "+ + "connection with redundant group key (%s).", addrmgr.GroupKey(oc.address)) + } + nm.cmgr.AddToGroupKey(na) + + return remoteNode, nil +} + +// cleanupFailedInboundConnection is called when an inbound connection fails to be processed. It is responsible for +// cleaning up the RemoteNode and the connection. Most of the time, the RemoteNode will be nil, but if the RemoteNode +// was successfully created, we will disconnect it. +func (nm *NetworkManager) cleanupFailedInboundConnection(remoteNode *RemoteNode, connection Connection) { + glog.V(2).Infof("NetworkManager.cleanupFailedInboundConnection: Cleaning up failed inbound connection") + if remoteNode != nil { + nm.Disconnect(remoteNode) + } + connection.Close() +} + +// cleanupFailedOutboundConnection is called when an outbound connection fails to be processed. It is responsible for +// cleaning up the RemoteNode and the connection. +func (nm *NetworkManager) cleanupFailedOutboundConnection(connection Connection) { + oc, ok := connection.(*outboundConnection) + if !ok { + return + } + glog.V(2).Infof("NetworkManager.cleanupFailedOutboundConnection: Cleaning up failed outbound connection") + + // Find the RemoteNode associated with the connection. It should almost always exist, since we create the RemoteNode + // as we're attempting to connect to the address. + id := NewRemoteNodeId(oc.attemptId) + rn := nm.GetRemoteNodeById(id) + if rn != nil { + nm.Disconnect(rn) + } + oc.Close() + nm.cmgr.RemoveAttemptedOutboundAddrs(oc.address) +} + +// ########################### +// ## Persistent Connections +// ########################### + +// refreshConnectIps is called periodically by the persistent connector. It is responsible for connecting to all +// persistent IP addresses that we are not already connected to. +func (nm *NetworkManager) refreshConnectIps() { + // Connect to addresses passed via the --connect-ips flag. These addresses are persistent in the sense that if we + // disconnect from one, we will try to reconnect to the same one. + for _, connectIp := range nm.connectIps { + if _, ok := nm.persistentIpToRemoteNodeIdsMap.Get(connectIp); ok { + continue + } + + glog.Infof("NetworkManager.initiatePersistentConnections: Connecting to connectIp: %v", connectIp) + id, err := nm.CreateNonValidatorPersistentOutboundConnection(connectIp) + if err != nil { + glog.Errorf("NetworkManager.initiatePersistentConnections: Problem connecting "+ + "to connectIp %v: %v", connectIp, err) + continue + } + + nm.persistentIpToRemoteNodeIdsMap.Set(connectIp, id) + } +} + +// ########################### +// ## Validator Connections +// ########################### + +// SetActiveValidatorsMap is called by the owner of the NetworkManager to update the activeValidatorsMap. This should +// generally be done whenever the active validators set changes. 
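+// A minimal usage sketch (call site hypothetical, not part of this patch): the consensus owner
+// pushes a fresh snapshot whenever the validator set changes, e.g.
+//
+//	nm.SetActiveValidatorsMap(validatorsAtCurrentEpoch)
+//
+// and the validator connector's next tick re-indexes and dials the updated set.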
+func (nm *NetworkManager) SetActiveValidatorsMap(activeValidatorsMap *collections.ConcurrentMap[bls.SerializedPublicKey, consensus.Validator]) { + nm.activeValidatorsMapLock.Lock() + defer nm.activeValidatorsMapLock.Unlock() + nm.activeValidatorsMap = activeValidatorsMap.Clone() + +} + +func (nm *NetworkManager) getActiveValidatorsMap() *collections.ConcurrentMap[bls.SerializedPublicKey, consensus.Validator] { + nm.activeValidatorsMapLock.RLock() + defer nm.activeValidatorsMapLock.RUnlock() + return nm.activeValidatorsMap.Clone() +} + +// refreshValidatorIndex re-indexes validators based on the activeValidatorsMap. It is called periodically by the +// validator connector. +func (nm *NetworkManager) refreshValidatorIndex(activeValidatorsMap *collections.ConcurrentMap[bls.SerializedPublicKey, consensus.Validator]) { + // De-index inactive validators. We skip any checks regarding RemoteNodes connection status, nor do we verify whether + // de-indexing the validator would result in an excess number of outbound/inbound connections. Any excess connections + // will be cleaned up by the NonValidator connector. + validatorRemoteNodeMap := nm.GetValidatorIndex().ToMap() + for pk, rn := range validatorRemoteNodeMap { + // If the validator is no longer active, de-index it. + if _, ok := activeValidatorsMap.Get(pk); !ok { + nm.SetNonValidator(rn) + nm.UnsetValidator(rn) + } + } + + // Look for validators in our existing outbound / inbound connections. + allNonValidators := nm.GetAllNonValidators() + for _, rn := range allNonValidators { + // It is possible for a RemoteNode to be in the non-validator indices, and still have a public key. This can happen + // if the RemoteNode advertised support for the SFValidator service flag during handshake, and provided us + // with a public key, and a corresponding proof of possession signature. + pk := rn.GetValidatorPublicKey() + if pk == nil { + continue + } + // It is possible that through unlikely concurrence, and malevolence, two non-validators happen to have the same + // public key, which goes undetected during handshake. To prevent this from affecting the indexing of the validator + // set, we check that the non-validator's public key is not already present in the validator index. + if _, ok := nm.GetValidatorIndex().Get(pk.Serialize()); ok { + glog.V(2).Infof("NetworkManager.refreshValidatorIndex: Disconnecting Validator RemoteNode "+ + "(%v) has validator public key (%v) that is already present in validator index", rn, pk) + nm.Disconnect(rn) + continue + } + + // If the RemoteNode turns out to be in the validator set, index it. + if _, ok := activeValidatorsMap.Get(pk.Serialize()); ok { + nm.SetValidator(rn) + nm.UnsetNonValidator(rn) + } + } +} + +// connectValidators attempts to connect to all active validators that are not already connected. It is called +// periodically by the validator connector. +func (nm *NetworkManager) connectValidators(activeValidatorsMap *collections.ConcurrentMap[bls.SerializedPublicKey, consensus.Validator]) { + // Look through the active validators and connect to any that we're not already connected to. + if nm.keystore == nil { + return + } + + validators := activeValidatorsMap.ToMap() + for pk, validator := range validators { + _, exists := nm.GetValidatorIndex().Get(pk) + // If we're already connected to the validator, continue. + if exists { + continue + } + // If the validator is our node, continue. 
+ if nm.keystore.GetSigner().GetPublicKey().Serialize() == pk { + continue + } + + publicKey, err := pk.Deserialize() + if err != nil { + continue + } + + // For now, we only dial the first domain in the validator's domain list. + if len(validator.GetDomains()) == 0 { + continue + } + address := string(validator.GetDomains()[0]) + if err := nm.CreateValidatorConnection(address, publicKey); err != nil { + glog.V(2).Infof("NetworkManager.connectValidators: Problem connecting to validator %v: %v", address, err) + continue + } + } +} + +// ########################### +// ## NonValidator Connections +// ########################### + +// refreshNonValidatorOutboundIndex is called periodically by the NonValidator connector. It is responsible for +// disconnecting excess outbound remote nodes. +func (nm *NetworkManager) refreshNonValidatorOutboundIndex() { + // There are three categories of outbound remote nodes: attempted, connected, and persistent. All of these + // remote nodes are stored in the same non-validator outbound index. We want to disconnect excess remote nodes that + // are not persistent, starting with the attempted nodes first. + + // First let's run a quick check to see if the number of our non-validator remote nodes exceeds our target. Note that + // this number will include the persistent nodes. + numOutboundRemoteNodes := uint32(nm.GetNonValidatorOutboundIndex().Count()) + if numOutboundRemoteNodes <= nm.targetNonValidatorOutboundRemoteNodes { + return + } + + // If we get here, it means that we should potentially disconnect some remote nodes. Let's first separate the + // attempted and connected remote nodes, ignoring the persistent ones. + allOutboundRemoteNodes := nm.GetNonValidatorOutboundIndex().GetAll() + var attemptedOutboundRemoteNodes, connectedOutboundRemoteNodes []*RemoteNode + for _, rn := range allOutboundRemoteNodes { + if rn.IsPersistent() || rn.IsExpectedValidator() { + // We do nothing for persistent remote nodes or expected validators. + continue + } else if rn.IsHandshakeCompleted() { + connectedOutboundRemoteNodes = append(connectedOutboundRemoteNodes, rn) + } else { + attemptedOutboundRemoteNodes = append(attemptedOutboundRemoteNodes, rn) + } + } + + // Having separated the attempted and connected remote nodes, we can now find the actual number of attempted and + // connected remote nodes. We can then find out how many remote nodes we need to disconnect. + numOutboundRemoteNodes = uint32(len(attemptedOutboundRemoteNodes) + len(connectedOutboundRemoteNodes)) + excessiveOutboundRemoteNodes := uint32(0) + if numOutboundRemoteNodes > nm.targetNonValidatorOutboundRemoteNodes { + excessiveOutboundRemoteNodes = numOutboundRemoteNodes - nm.targetNonValidatorOutboundRemoteNodes + } + + // First disconnect the attempted remote nodes. + for _, rn := range attemptedOutboundRemoteNodes { + if excessiveOutboundRemoteNodes == 0 { + break + } + glog.V(2).Infof("NetworkManager.refreshNonValidatorOutboundIndex: Disconnecting attempted remote "+ + "node (id=%v) due to excess outbound RemoteNodes", rn.GetId()) + nm.Disconnect(rn) + excessiveOutboundRemoteNodes-- + } + // Now disconnect the connected remote nodes, if we still have too many remote nodes. 
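The routines above are documented as being "called periodically by the validator connector," but the connector loop itself is not part of this hunk. The following is a minimal sketch, assuming a simple ticker-driven goroutine (the startValidatorConnector name, the one-second interval, and the exitChan plumbing are illustrative assumptions, not part of this change, and the time package is assumed to be imported), of how such a driver could tie together getActiveValidatorsMap, refreshValidatorIndex, and connectValidators:

func (nm *NetworkManager) startValidatorConnector(exitChan chan struct{}) {
	ticker := time.NewTicker(time.Second) // interval chosen arbitrarily for illustration
	defer ticker.Stop()
	for {
		select {
		case <-exitChan:
			return
		case <-ticker.C:
			// Snapshot the active validator set, then re-index and dial validators.
			activeValidatorsMap := nm.getActiveValidatorsMap()
			nm.refreshValidatorIndex(activeValidatorsMap)
			nm.connectValidators(activeValidatorsMap)
		}
	}
}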
+ for _, rn := range connectedOutboundRemoteNodes { + if excessiveOutboundRemoteNodes == 0 { + break + } + glog.V(2).Infof("NetworkManager.refreshNonValidatorOutboundIndex: Disconnecting connected remote "+ + "node (id=%v) due to excess outbound RemoteNodes", rn.GetId()) + nm.Disconnect(rn) + excessiveOutboundRemoteNodes-- + } +} + +// refreshNonValidatorInboundIndex is called periodically by the non-validator connector. It is responsible for +// disconnecting excess inbound remote nodes. +func (nm *NetworkManager) refreshNonValidatorInboundIndex() { + // First let's check if we have an excess number of inbound remote nodes. If we do, we'll disconnect some of them. + numConnectedInboundRemoteNodes := uint32(nm.GetNonValidatorInboundIndex().Count()) + if numConnectedInboundRemoteNodes <= nm.targetNonValidatorInboundRemoteNodes { + return + } + + // Disconnect random inbound non-validators if we have too many of them. + inboundRemoteNodes := nm.GetNonValidatorInboundIndex().GetAll() + var connectedInboundRemoteNodes []*RemoteNode + for _, rn := range inboundRemoteNodes { + // We only want to disconnect remote nodes that have completed handshake. RemoteNodes that don't have the + // handshake completed status could be validators, in which case we don't want to disconnect them. It is also + // possible that the RemoteNodes without completed handshake will end up never finishing it, in which case + // they will be removed by the cleanup goroutine, once the handshake timeout is reached. + if rn.IsHandshakeCompleted() { + connectedInboundRemoteNodes = append(connectedInboundRemoteNodes, rn) + } + } + + // Having separated the connected remote nodes, we can now find the actual number of connected inbound remote nodes + // that have completed handshake. We can then find out how many remote nodes we need to disconnect. + numConnectedInboundRemoteNodes = uint32(len(connectedInboundRemoteNodes)) + excessiveInboundRemoteNodes := uint32(0) + if numConnectedInboundRemoteNodes > nm.targetNonValidatorInboundRemoteNodes { + excessiveInboundRemoteNodes = numConnectedInboundRemoteNodes - nm.targetNonValidatorInboundRemoteNodes + } + for _, rn := range connectedInboundRemoteNodes { + if excessiveInboundRemoteNodes == 0 { + break + } + glog.V(2).Infof("NetworkManager.refreshNonValidatorInboundIndex: Disconnecting inbound remote "+ + "node (id=%v) due to excess inbound RemoteNodes", rn.GetId()) + nm.Disconnect(rn) + excessiveInboundRemoteNodes-- + } +} + +// connectNonValidators attempts to connect to new outbound nonValidator remote nodes. It is called periodically by the +// nonValidator connector. +func (nm *NetworkManager) connectNonValidators() { + // First, find all nonValidator outbound remote nodes that are not persistent. + allOutboundRemoteNodes := nm.GetNonValidatorOutboundIndex().GetAll() + var nonValidatorOutboundRemoteNodes []*RemoteNode + for _, rn := range allOutboundRemoteNodes { + if rn.IsPersistent() || rn.IsExpectedValidator() { + // We do nothing for persistent remote nodes or expected validators. + continue + } else { + nonValidatorOutboundRemoteNodes = append(nonValidatorOutboundRemoteNodes, rn) + } + } + // Now find the number of nonValidator, non-persistent outbound remote nodes. + numOutboundRemoteNodes := uint32(len(nonValidatorOutboundRemoteNodes)) + remainingOutboundRemoteNodes := uint32(0) + // Check if we need to connect to more nonValidator outbound remote nodes. 
+ if numOutboundRemoteNodes < nm.targetNonValidatorOutboundRemoteNodes { + remainingOutboundRemoteNodes = nm.targetNonValidatorOutboundRemoteNodes - numOutboundRemoteNodes + } + for ii := uint32(0); ii < remainingOutboundRemoteNodes; ii++ { + // Get a random unconnected address from the address manager. If we can't find one, we break out of the loop. + addr := nm.getRandomUnconnectedAddress() + if addr == nil { + break + } + // Attempt to connect to the address. + nm.AddrMgr.Attempt(addr) + if err := nm.createNonValidatorOutboundConnection(addr); err != nil { + glog.V(2).Infof("NetworkManager.connectNonValidators: Problem creating non-validator outbound "+ + "connection to addr: %v; err: %v", addr, err) + } + } +} + +// getRandomUnconnectedAddress returns a random address from the address manager that we are not already connected to. +func (nm *NetworkManager) getRandomUnconnectedAddress() *wire.NetAddress { + for tries := 0; tries < 100; tries++ { + addr := nm.AddrMgr.GetAddress() + if addr == nil { + break + } + + if nm.cmgr.IsConnectedOutboundIpAddress(addr.NetAddress()) { + continue + } + + if nm.cmgr.IsAttemptedOutboundIpAddress(addr.NetAddress()) { + continue + } + + // We can only have one outbound address per /16. This is similar to + // Bitcoin and we do it to prevent Sybil attacks. + if nm.cmgr.IsFromRedundantOutboundIPAddress(addr.NetAddress()) { + continue + } + + return addr.NetAddress() + } + + return nil +} + +// ########################### +// ## Create RemoteNode Functions +// ########################### + +func (nm *NetworkManager) CreateValidatorConnection(ipStr string, publicKey *bls.PublicKey) error { + netAddr, err := nm.ConvertIPStringToNetAddress(ipStr) + if err != nil { + return err + } + if netAddr == nil || publicKey == nil { + return fmt.Errorf("NetworkManager.CreateValidatorConnection: netAddr or public key is nil") + } + + if _, ok := nm.GetValidatorIndex().Get(publicKey.Serialize()); ok { + return fmt.Errorf("NetworkManager.CreateValidatorConnection: RemoteNode already exists for public key: %v", publicKey) + } + + remoteNode := nm.newRemoteNode(publicKey, false) + if err := remoteNode.DialOutboundConnection(netAddr); err != nil { + return errors.Wrapf(err, "NetworkManager.CreateValidatorConnection: Problem calling DialPersistentOutboundConnection "+ + "for addr: (%s:%v)", netAddr.IP.String(), netAddr.Port) + } + nm.setRemoteNode(remoteNode) + nm.GetValidatorIndex().Set(publicKey.Serialize(), remoteNode) + return nil +} + +func (nm *NetworkManager) CreateNonValidatorPersistentOutboundConnection(ipStr string) (RemoteNodeId, error) { + netAddr, err := nm.ConvertIPStringToNetAddress(ipStr) + if err != nil { + return 0, err + } + if netAddr == nil { + return 0, fmt.Errorf("NetworkManager.CreateNonValidatorPersistentOutboundConnection: netAddr is nil") + } + + remoteNode := nm.newRemoteNode(nil, true) + if err := remoteNode.DialPersistentOutboundConnection(netAddr); err != nil { + return 0, errors.Wrapf(err, "NetworkManager.CreateNonValidatorPersistentOutboundConnection: Problem calling DialPersistentOutboundConnection "+ + "for addr: (%s:%v)", netAddr.IP.String(), netAddr.Port) + } + nm.setRemoteNode(remoteNode) + nm.GetNonValidatorOutboundIndex().Set(remoteNode.GetId(), remoteNode) + return remoteNode.GetId(), nil +} + +func (nm *NetworkManager) CreateNonValidatorOutboundConnection(ipStr string) error { + netAddr, err := nm.ConvertIPStringToNetAddress(ipStr) + if err != nil { + return err + } + return nm.createNonValidatorOutboundConnection(netAddr) +} + +func 
(nm *NetworkManager) createNonValidatorOutboundConnection(netAddr *wire.NetAddress) error { + if netAddr == nil { + return fmt.Errorf("NetworkManager.CreateNonValidatorOutboundConnection: netAddr is nil") + } + + remoteNode := nm.newRemoteNode(nil, false) + if err := remoteNode.DialOutboundConnection(netAddr); err != nil { + return errors.Wrapf(err, "NetworkManager.CreateNonValidatorOutboundConnection: Problem calling DialOutboundConnection "+ + "for addr: (%s:%v)", netAddr.IP.String(), netAddr.Port) + } + nm.setRemoteNode(remoteNode) + nm.GetNonValidatorOutboundIndex().Set(remoteNode.GetId(), remoteNode) + return nil +} + +func (nm *NetworkManager) AttachInboundConnection(conn net.Conn, + na *wire.NetAddress) (*RemoteNode, error) { + + remoteNode := nm.newRemoteNode(nil, false) + if err := remoteNode.AttachInboundConnection(conn, na); err != nil { + return remoteNode, errors.Wrapf(err, "NetworkManager.AttachInboundConnection: Problem calling AttachInboundConnection "+ + "for addr: (%s)", conn.RemoteAddr().String()) + } + + nm.setRemoteNode(remoteNode) + return remoteNode, nil +} + +func (nm *NetworkManager) AttachOutboundConnection(conn net.Conn, na *wire.NetAddress, + remoteNodeId uint64, isPersistent bool) (*RemoteNode, error) { + + id := NewRemoteNodeId(remoteNodeId) + remoteNode := nm.GetRemoteNodeById(id) + if remoteNode == nil { + return nil, fmt.Errorf("NetworkManager.AttachOutboundConnection: Problem getting remote node by id (%d)", + id.ToUint64()) + } + + if err := remoteNode.AttachOutboundConnection(conn, na, isPersistent); err != nil { + nm.Disconnect(remoteNode) + return nil, errors.Wrapf(err, "NetworkManager.AttachOutboundConnection: Problem calling AttachOutboundConnection "+ + "for addr: (%s). Disconnecting remote node (id=%v)", conn.RemoteAddr().String(), remoteNode.GetId()) + } + + return remoteNode, nil +} + +// ########################### +// ## RemoteNode Management +// ########################### + +func (nm *NetworkManager) DisconnectAll() { + allRemoteNodes := nm.GetAllRemoteNodes().GetAll() + for _, rn := range allRemoteNodes { + glog.V(2).Infof("NetworkManager.DisconnectAll: Disconnecting from remote node (id=%v)", rn.GetId()) + nm.Disconnect(rn) + } +} + +func (nm *NetworkManager) newRemoteNode(validatorPublicKey *bls.PublicKey, isPersistent bool) *RemoteNode { + id := atomic.AddUint64(&nm.remoteNodeIndex, 1) + remoteNodeId := NewRemoteNodeId(id) + latestBlockHeight := uint64(nm.bc.BlockTip().Height) + return NewRemoteNode(remoteNodeId, validatorPublicKey, isPersistent, nm.srv, nm.cmgr, nm.keystore, + nm.params, nm.minTxFeeRateNanosPerKB, latestBlockHeight, nm.nodeServices) +} + +func (nm *NetworkManager) ProcessCompletedHandshake(remoteNode *RemoteNode) { + if remoteNode == nil { + return + } + + if remoteNode.IsValidator() { + nm.SetValidator(remoteNode) + nm.UnsetNonValidator(remoteNode) + } else { + nm.UnsetValidator(remoteNode) + nm.SetNonValidator(remoteNode) + } + nm.srv.HandleAcceptedPeer(remoteNode) + nm.srv.maybeRequestAddresses(remoteNode) +} + +func (nm *NetworkManager) Disconnect(rn *RemoteNode) { + if rn == nil { + return + } + glog.V(2).Infof("NetworkManager.Disconnect: Disconnecting from remote node id=%v", rn.GetId()) + rn.Disconnect() + nm.removeRemoteNodeFromIndexer(rn) +} + +func (nm *NetworkManager) DisconnectById(id RemoteNodeId) { + rn := nm.GetRemoteNodeById(id) + if rn == nil { + return + } + + nm.Disconnect(rn) +} + +func (nm *NetworkManager) SendMessage(rn *RemoteNode, desoMessage DeSoMessage) error { + if rn == nil { + return 
fmt.Errorf("NetworkManager.SendMessage: RemoteNode is nil") + } + + return rn.SendMessage(desoMessage) +} + +func (nm *NetworkManager) removeRemoteNodeFromIndexer(rn *RemoteNode) { + nm.mtx.Lock() + defer nm.mtx.Unlock() + + if rn == nil { + return + } + + nm.GetAllRemoteNodes().Remove(rn.GetId()) + nm.GetNonValidatorOutboundIndex().Remove(rn.GetId()) + nm.GetNonValidatorInboundIndex().Remove(rn.GetId()) + + // Try to evict the remote node from the validator index. If the remote node is not a validator, then there is nothing to do. + if rn.GetValidatorPublicKey() == nil { + return + } + // Only remove from the validator index if the fetched remote node is the same as the one we are trying to remove. + // Otherwise, we could have a fun edge-case where a duplicated validator connection ends up removing an + // existing validator connection from the index. + fetchedRn, ok := nm.GetValidatorIndex().Get(rn.GetValidatorPublicKey().Serialize()) + if ok && fetchedRn.GetId() == rn.GetId() { + nm.GetValidatorIndex().Remove(rn.GetValidatorPublicKey().Serialize()) + } +} + +func (nm *NetworkManager) Cleanup() { + allRemoteNodes := nm.GetAllRemoteNodes().GetAll() + for _, rn := range allRemoteNodes { + if rn.IsTimedOut() { + glog.V(2).Infof("NetworkManager.Cleanup: Disconnecting from remote node (id=%v)", rn.GetId()) + nm.Disconnect(rn) + } + } +} + +// ########################### +// ## RemoteNode Setters +// ########################### + +func (nm *NetworkManager) setRemoteNode(rn *RemoteNode) { + nm.mtx.Lock() + defer nm.mtx.Unlock() + + if rn == nil || rn.IsTerminated() { + return + } + + nm.GetAllRemoteNodes().Set(rn.GetId(), rn) +} + +func (nm *NetworkManager) SetNonValidator(rn *RemoteNode) { + nm.mtx.Lock() + defer nm.mtx.Unlock() + + if rn == nil || rn.IsTerminated() { + return + } + + if rn.IsOutbound() { + nm.GetNonValidatorOutboundIndex().Set(rn.GetId(), rn) + } else { + nm.GetNonValidatorInboundIndex().Set(rn.GetId(), rn) + } +} + +func (nm *NetworkManager) SetValidator(remoteNode *RemoteNode) { + nm.mtx.Lock() + defer nm.mtx.Unlock() + + if remoteNode == nil || remoteNode.IsTerminated() { + return + } + + pk := remoteNode.GetValidatorPublicKey() + if pk == nil { + return + } + nm.GetValidatorIndex().Set(pk.Serialize(), remoteNode) +} + +func (nm *NetworkManager) UnsetValidator(remoteNode *RemoteNode) { + nm.mtx.Lock() + defer nm.mtx.Unlock() + + if remoteNode == nil || remoteNode.IsTerminated() { + return + } + + pk := remoteNode.GetValidatorPublicKey() + if pk == nil { + return + } + nm.GetValidatorIndex().Remove(pk.Serialize()) +} + +func (nm *NetworkManager) UnsetNonValidator(rn *RemoteNode) { + nm.mtx.Lock() + defer nm.mtx.Unlock() + + if rn == nil || rn.IsTerminated() { + return + } + + if rn.IsOutbound() { + nm.GetNonValidatorOutboundIndex().Remove(rn.GetId()) + } else { + nm.GetNonValidatorInboundIndex().Remove(rn.GetId()) + } +} + +// ########################### +// ## RemoteNode Getters +// ########################### + +func (nm *NetworkManager) GetAllRemoteNodes() *collections.ConcurrentMap[RemoteNodeId, *RemoteNode] { + return nm.AllRemoteNodes +} + +func (nm *NetworkManager) GetValidatorIndex() *collections.ConcurrentMap[bls.SerializedPublicKey, *RemoteNode] { + return nm.ValidatorIndex +} + +func (nm *NetworkManager) GetNonValidatorOutboundIndex() *collections.ConcurrentMap[RemoteNodeId, *RemoteNode] { + return nm.NonValidatorOutboundIndex +} + +func (nm *NetworkManager) GetNonValidatorInboundIndex() *collections.ConcurrentMap[RemoteNodeId, *RemoteNode] { + return 
nm.NonValidatorInboundIndex } + +func (nm *NetworkManager) GetRemoteNodeFromPeer(peer *Peer) *RemoteNode { + if peer == nil { + return nil + } + id := NewRemoteNodeId(peer.GetId()) + rn, _ := nm.GetAllRemoteNodes().Get(id) + return rn +} + +func (nm *NetworkManager) GetRemoteNodeById(id RemoteNodeId) *RemoteNode { + rn, ok := nm.GetAllRemoteNodes().Get(id) + if !ok { + return nil + } + return rn +} + +func (nm *NetworkManager) GetAllNonValidators() []*RemoteNode { + outboundRemoteNodes := nm.GetNonValidatorOutboundIndex().GetAll() + inboundRemoteNodes := nm.GetNonValidatorInboundIndex().GetAll() + return append(outboundRemoteNodes, inboundRemoteNodes...) +} + +// ########################### +// ## RemoteNode Handshake +// ########################### + +// InitiateHandshake kicks off the handshake with a remote node. +func (nm *NetworkManager) InitiateHandshake(rn *RemoteNode) { + nonce := uint64(RandInt64(math.MaxInt64)) + if err := rn.InitiateHandshake(nonce); err != nil { + glog.Errorf("NetworkManager.InitiateHandshake: Error initiating handshake: %v", err) + nm.Disconnect(rn) + } + nm.usedNonces.Add(nonce) +} + +// handleHandshakeComplete is called on a completed handshake with a RemoteNode. +func (nm *NetworkManager) handleHandshakeComplete(remoteNode *RemoteNode) { + // Prevent race conditions while handling handshake complete messages. + nm.mtxHandshakeComplete.Lock() + defer nm.mtxHandshakeComplete.Unlock() + + // Get the handshake information of this peer. + if remoteNode == nil { + return + } + + if remoteNode.GetNegotiatedProtocolVersion().Before(ProtocolVersion2) { + nm.ProcessCompletedHandshake(remoteNode) + return + } + + if err := nm.handleHandshakeCompletePoSMessage(remoteNode); err != nil { + glog.Errorf("NetworkManager.handleHandshakeComplete: Error handling PoS handshake peer message: %v, "+ + "remoteNodePk (%s)", err, remoteNode.GetValidatorPublicKey().Serialize()) + nm.Disconnect(remoteNode) + return + } + nm.ProcessCompletedHandshake(remoteNode) +} + +func (nm *NetworkManager) handleHandshakeCompletePoSMessage(remoteNode *RemoteNode) error { + validatorPk := remoteNode.GetValidatorPublicKey() + // If the remote node is not a potential validator, we don't need to do anything. + if validatorPk == nil { + return nil + } + + // Look up the validator in the ValidatorIndex with the same public key. + existingValidator, ok := nm.GetValidatorIndex().Get(validatorPk.Serialize()) + // For inbound RemoteNodes, we should ensure that there isn't an existing validator connected with the same public key. + // Inbound nodes are not initiated by us, so we shouldn't have added the RemoteNode to the ValidatorIndex yet. + if remoteNode.IsInbound() && ok { + return fmt.Errorf("NetworkManager.handleHandshakeCompletePoSMessage: Inbound RemoteNode with duplicate validator public key") + } + // For outbound RemoteNodes, we have two possible scenarios. Either the RemoteNode has been initiated as a validator, + // in which case it should already be in the ValidatorIndex. Or the RemoteNode has been initiated as a regular node, + // in which case it should not be in the ValidatorIndex, but in the NonValidatorOutboundIndex. So to ensure there is + // no duplicate connection with the same public key, we only check whether there is a validator in the ValidatorIndex + // with the RemoteNode's public key. If there is one, we want to ensure that these two RemoteNodes have identical ids.
+ if remoteNode.IsOutbound() && ok { + if remoteNode.GetId() != existingValidator.GetId() { + return fmt.Errorf("NetworkManager.handleHandshakeCompletePoSMessage: Outbound RemoteNode with duplicate validator public key. "+ + "Existing validator id: %v, new validator id: %v", existingValidator.GetId().ToUint64(), remoteNode.GetId().ToUint64()) + } + } + return nil +} + +// ########################### +// ## Helper Functions +// ########################### + +func (nm *NetworkManager) ConvertIPStringToNetAddress(ipStr string) (*wire.NetAddress, error) { + netAddr, err := IPToNetAddr(ipStr, nm.AddrMgr, nm.params) + if err != nil { + return nil, errors.Wrapf(err, + "NetworkManager.ConvertIPStringToNetAddress: Problem parsing "+ + "ipString to wire.NetAddress") + } + if netAddr == nil { + return nil, fmt.Errorf("NetworkManager.ConvertIPStringToNetAddress: " + + "address was nil after parsing") + } + return netAddr, nil +} + +func IPToNetAddr(ipStr string, addrMgr *addrmgr.AddrManager, params *DeSoParams) (*wire.NetAddress, error) { + port := params.DefaultSocketPort + host, portstr, err := net.SplitHostPort(ipStr) + if err != nil { + // No port specified so leave port=default and set + // host to the ipStr. + host = ipStr + } else { + pp, err := strconv.ParseUint(portstr, 10, 16) + if err != nil { + return nil, errors.Wrapf(err, "IPToNetAddr: Cannot parse port from %s", ipStr) + } + port = uint16(pp) + } + netAddr, err := addrMgr.HostToNetAddress(host, port, 0) + if err != nil { + return nil, errors.Wrapf(err, "IPToNetAddr: Cannot convert host to net address for %s", ipStr) + } + return netAddr, nil +} + +func (nm *NetworkManager) isDuplicateInboundIPAddress(addr net.Addr) bool { + netAddr, err := IPToNetAddr(addr.String(), nm.AddrMgr, nm.params) + if err != nil { + // Return true in case we have an error. We do this because it + // will result in the peer connection not being accepted, which + // is desired in this case.
+ glog.Warningf(errors.Wrapf(err, + "NetworkManager.isDuplicateInboundIPAddress: Problem parsing "+ + "net.Addr to wire.NetAddress so marking as redundant and not "+ + "making connection").Error()) + return true + } + if netAddr == nil { + glog.Warningf("NetworkManager.isDuplicateInboundIPAddress: " + + "address was nil after parsing so marking as redundant and not " + + "making connection") + return true + } + + return nm.cmgr.IsDuplicateInboundIPAddress(netAddr) +} diff --git a/lib/network_test.go b/lib/network_test.go index 85dc3d85e..fef075d20 100644 --- a/lib/network_test.go +++ b/lib/network_test.go @@ -5,6 +5,8 @@ package lib import ( "bytes" "encoding/hex" + "github.com/deso-protocol/core/bls" + "golang.org/x/crypto/sha3" "math/big" "math/rand" "reflect" @@ -41,7 +43,7 @@ var expectedVer = &MsgDeSoVersion{ TstampSecs: 2, Nonce: uint64(0xffffffffffffffff), UserAgent: "abcdef", - StartBlockHeight: 4, + LatestBlockHeight: 4, MinFeeRateNanosPerKB: 10, } @@ -68,7 +70,7 @@ func TestVersionConversion(t *testing.T) { "works, add the new field to the test case, and fix this error.") } -func TestVerack(t *testing.T) { +func TestVerackV0(t *testing.T) { assert := assert.New(t) require := require.New(t) _ = assert @@ -78,13 +80,51 @@ func TestVerack(t *testing.T) { var buf bytes.Buffer nonce := uint64(12345678910) - _, err := WriteMessage(&buf, &MsgDeSoVerack{Nonce: nonce}, networkType) + _, err := WriteMessage(&buf, &MsgDeSoVerack{Version: VerackVersion0, NonceReceived: nonce}, networkType) require.NoError(err) verBytes := buf.Bytes() testMsg, _, err := ReadMessage(bytes.NewReader(verBytes), networkType) require.NoError(err) - require.Equal(&MsgDeSoVerack{Nonce: nonce}, testMsg) + require.Equal(&MsgDeSoVerack{Version: VerackVersion0, NonceReceived: nonce}, testMsg) +} + +func TestVerackV1(t *testing.T) { + require := require.New(t) + + networkType := NetworkType_MAINNET + var buf1, buf2 bytes.Buffer + + nonceReceived := uint64(12345678910) + nonceSent := nonceReceived + 1 + tstamp := uint64(2345678910) + // First, test that nil public key and signature are allowed. + msg := &MsgDeSoVerack{ + Version: VerackVersion1, + NonceReceived: nonceReceived, + NonceSent: nonceSent, + TstampMicro: tstamp, + PublicKey: nil, + Signature: nil, + } + _, err := WriteMessage(&buf1, msg, networkType) + require.NoError(err) + payload := append(UintToBuf(nonceReceived), UintToBuf(nonceSent)...) + payload = append(payload, UintToBuf(tstamp)...) 
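+ // Note: the payload assembled above is the concatenation of the UintToBuf (varint) encodings of + // NonceReceived, NonceSent, and TstampMicro, in that order. The SHA3-256 hash of this payload is + // what gets signed with the node's BLS private key below.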
+ hash := sha3.Sum256(payload) + + priv, err := bls.NewPrivateKey() + require.NoError(err) + msg.PublicKey = priv.PublicKey() + msg.Signature, err = priv.Sign(hash[:]) + require.NoError(err) + _, err = WriteMessage(&buf2, msg, networkType) + require.NoError(err) + + verBytes := buf2.Bytes() + testMsg, _, err := ReadMessage(bytes.NewReader(verBytes), networkType) + require.NoError(err) + require.Equal(msg, testMsg) } // Creates fully formatted a PoS block header with random signatures @@ -102,12 +142,6 @@ func createTestBlockHeaderVersion2(t *testing.T, includeTimeoutQC bool) *MsgDeSo 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x60, 0x61, 0x62, 0x63, 0x64, 0x65, } - testTxnConnectStatusByIndex := BlockHash{ - 0x00, 0x03, 0x04, 0x21, 0x06, 0x07, 0x08, 0x09, 0x10, 0x19, - 0x12, 0x13, 0x14, 0x15, 0x44, 0x17, 0x18, 0x19, 0x20, 0x21, - 0x02, 0x23, 0x24, 0x25, 0x26, 0x27, 0x33, 0x29, 0x30, 0x31, - 0x32, 0x33, - } testBitset := bitset.NewBitset().Set(0, true).Set(3, true) testBLSPublicKey, testBLSSignature := _generateValidatorVotingPublicKeyAndSignature(t) @@ -147,7 +181,6 @@ func createTestBlockHeaderVersion2(t *testing.T, includeTimeoutQC bool) *MsgDeSo // Nonce and ExtraNonce are unused and set to 0 starting in version 2. Nonce: uint64(0), ExtraNonce: uint64(0), - TxnConnectStatusByIndexHash: &testTxnConnectStatusByIndex, ProposerVotingPublicKey: testBLSPublicKey, ProposerRandomSeedSignature: testBLSSignature, ProposedInView: uint64(1432101234), @@ -413,10 +446,6 @@ func createTestBlockVersion2(t *testing.T, includeTimeoutQC bool) *MsgDeSoBlock // Set V2 header. block.Header = createTestBlockHeaderVersion2(t, includeTimeoutQC) - // Set the block's TxnConnectStatusByIndex and update its hash in the header. - block.TxnConnectStatusByIndex = bitset.NewBitset().Set(0, true).Set(3, true) - block.Header.TxnConnectStatusByIndexHash = HashBitset(block.TxnConnectStatusByIndex) - return &block } diff --git a/lib/peer.go b/lib/peer.go index 780f72f62..1f759c3c4 100644 --- a/lib/peer.go +++ b/lib/peer.go @@ -2,14 +2,13 @@ package lib import ( "fmt" - "math" + "github.com/decred/dcrd/lru" "net" "sort" + "sync" "sync/atomic" "time" - "github.com/decred/dcrd/lru" - "github.com/btcsuite/btcd/wire" "github.com/deso-protocol/go-deadlock" "github.com/golang/glog" @@ -49,7 +48,6 @@ type Peer struct { StatsMtx deadlock.RWMutex TimeOffsetSecs int64 TimeConnected time.Time - startingHeight uint32 ID uint64 // Ping-related fields. LastPingNonce uint64 @@ -64,36 +62,17 @@ type Peer struct { stallTimeoutSeconds uint64 Params *DeSoParams MessageChan chan *ServerMessage - // A hack to make it so that we can allow an API endpoint to manually - // delete a peer. - PeerManuallyRemovedFromConnectionManager bool - - // In order to complete a version negotiation successfully, the peer must - // reply to the initial version message we send them with a verack message - // containing the nonce from that initial version message. This ensures that - // the peer's IP isn't being spoofed since the only way to actually produce - // a verack with the appropriate response is to actually own the IP that - // the peer claims it has. As such, we maintain the version nonce we sent - // the peer and the version nonce they sent us here. - // - // TODO: The way we synchronize the version nonce is currently a bit - // messy; ideally we could do it without keeping global state. - VersionNonceSent uint64 - VersionNonceReceived uint64 // A pointer to the Server srv *Server // Basic state. 
- PeerInfoMtx deadlock.Mutex - serviceFlags ServiceFlag - addrStr string - netAddr *wire.NetAddress - userAgent string - advertisedProtocolVersion uint64 - negotiatedProtocolVersion uint64 - VersionNegotiated bool - minTxFeeRateNanosPerKB uint64 + PeerInfoMtx deadlock.Mutex + serviceFlags ServiceFlag + latestHeight uint64 + addrStr string + netAddr *wire.NetAddress + minTxFeeRateNanosPerKB uint64 // Messages for which we are expecting a reply within a fixed // amount of time. This list is always sorted by ExpectedTime, // with the item having the earliest time at the front. @@ -104,7 +83,8 @@ type Peer struct { knownAddressesMap map[string]bool // Output queue for messages that need to be sent to the peer. - outputQueueChan chan DeSoMessage + outputQueueChan chan DeSoMessage + peerDisconnectedChan chan *Peer // Set to zero until Disconnect has been called on the Peer. Used to make it // so that the logic in Disconnect will only be executed once. @@ -143,6 +123,13 @@ type Peer struct { // SyncType indicates whether blocksync should not be requested for this peer. If set to true // then we'll only hypersync from this peer. syncType NodeSyncType + + // startGroup ensures that all the Peer's go routines are started when we call Start(). + startGroup sync.WaitGroup +} + +func (pp *Peer) GetId() uint64 { + return pp.ID } func (pp *Peer) AddDeSoMessage(desoMessage DeSoMessage, inbound bool) { @@ -193,7 +180,7 @@ func (pp *Peer) HandleGetTransactionsMsg(getTxnMsg *MsgDeSoGetTransactions) { // If the transaction isn't in the pool, just continue without adding // it. It is generally OK to respond with only a subset of the transactions // that were requested. - if mempoolTx == nil { + if mempoolTx == nil || !mempoolTx.IsValidated() { continue } @@ -325,7 +312,7 @@ func (pp *Peer) HelpHandleInv(msg *MsgDeSoInv) { // For transactions, check that the transaction isn't in the // mempool and that it isn't currently being requested. _, requestIsInFlight := pp.srv.requestedTransactionsMap[currentHash] - if requestIsInFlight || pp.srv.mempool.IsTransactionInPool(¤tHash) { + if requestIsInFlight || pp.srv.GetMempool().GetTransaction(¤tHash) != nil { continue } @@ -551,6 +538,7 @@ func (pp *Peer) cleanupMessageProcessor() { } func (pp *Peer) StartDeSoMessageProcessor() { + pp.startGroup.Done() glog.Infof("StartDeSoMessageProcessor: Starting for peer %v", pp) for { if pp.disconnected != 0 { @@ -614,15 +602,17 @@ func (pp *Peer) StartDeSoMessageProcessor() { } // NewPeer creates a new Peer object. 
-func NewPeer(_conn net.Conn, _isOutbound bool, _netAddr *wire.NetAddress, +func NewPeer(_id uint64, _conn net.Conn, _isOutbound bool, _netAddr *wire.NetAddress, _isPersistent bool, _stallTimeoutSeconds uint64, _minFeeRateNanosPerKB uint64, params *DeSoParams, messageChan chan *ServerMessage, _cmgr *ConnectionManager, _srv *Server, - _syncType NodeSyncType) *Peer { + _syncType NodeSyncType, + peerDisconnectedChan chan *Peer) *Peer { pp := Peer{ + ID: _id, cmgr: _cmgr, srv: _srv, Conn: _conn, @@ -631,6 +621,7 @@ func NewPeer(_conn net.Conn, _isOutbound bool, _netAddr *wire.NetAddress, isOutbound: _isOutbound, isPersistent: _isPersistent, outputQueueChan: make(chan DeSoMessage), + peerDisconnectedChan: peerDisconnectedChan, quit: make(chan interface{}), knownInventory: lru.NewCache(maxKnownInventory), blocksToSend: make(map[BlockHash]bool), @@ -642,9 +633,6 @@ func NewPeer(_conn net.Conn, _isOutbound bool, _netAddr *wire.NetAddress, requestedBlocks: make(map[BlockHash]bool), syncType: _syncType, } - if _cmgr != nil { - pp.ID = atomic.AddUint64(&_cmgr.peerIndex, 1) - } // TODO: Before, we would give each Peer its own Logger object. Now we // have a much better way of debugging which is that we include a nonce @@ -679,10 +667,10 @@ func (pp *Peer) MinFeeRateNanosPerKB() uint64 { } // StartingBlockHeight is the height of the peer's blockchain tip. -func (pp *Peer) StartingBlockHeight() uint32 { +func (pp *Peer) StartingBlockHeight() uint64 { pp.StatsMtx.RLock() defer pp.StatsMtx.RUnlock() - return pp.startingHeight + return pp.latestHeight } // NumBlocksToSend is the number of blocks the Peer has requested from @@ -738,6 +726,7 @@ func (pp *Peer) HandlePongMsg(msg *MsgDeSoPong) { } func (pp *Peer) PingHandler() { + pp.startGroup.Done() glog.V(1).Infof("Peer.PingHandler: Starting ping handler for Peer %v", pp) pingTicker := time.NewTicker(pingInterval) defer pingTicker.Stop() @@ -787,6 +776,10 @@ func (pp *Peer) Address() string { return pp.addrStr } +func (pp *Peer) NetAddress() *wire.NetAddress { + return pp.netAddr +} + func (pp *Peer) IP() string { return pp.netAddr.IP.String() } @@ -799,6 +792,10 @@ func (pp *Peer) IsOutbound() bool { return pp.isOutbound } +func (pp *Peer) IsPersistent() bool { + return pp.isPersistent +} + func (pp *Peer) QueueMessage(desoMessage DeSoMessage) { // If the peer is disconnected, don't queue anything. if !pp.Connected() { @@ -898,7 +895,22 @@ func (pp *Peer) _setKnownAddressesMap(key string, val bool) { pp.knownAddressesMap[key] = val } +func (pp *Peer) SetLatestBlockHeight(height uint64) { + pp.StatsMtx.Lock() + defer pp.StatsMtx.Unlock() + + pp.latestHeight = height +} + +func (pp *Peer) SetServiceFlag(sf ServiceFlag) { + pp.PeerInfoMtx.Lock() + defer pp.PeerInfoMtx.Unlock() + + pp.serviceFlags = sf +} + func (pp *Peer) outHandler() { + pp.startGroup.Done() glog.V(1).Infof("Peer.outHandler: Starting outHandler for Peer %v", pp) stallTicker := time.NewTicker(time.Second) out: @@ -1078,6 +1090,7 @@ func (pp *Peer) _handleInExpectedResponse(rmsg DeSoMessage) error { // inHandler handles all incoming messages for the peer. It must be run as a // goroutine. func (pp *Peer) inHandler() { + pp.startGroup.Done() glog.V(1).Infof("Peer.inHandler: Starting inHandler for Peer %v", pp) // The timer is stopped when a new message is received and reset after it @@ -1134,20 +1147,6 @@ out: // This switch actually processes the message. For most messages, we just // pass them onto the Server. 
switch msg := rmsg.(type) { - case *MsgDeSoVersion: - // We always receive the VERSION from the Peer before starting this select - // statement, so getting one here is an error. - - glog.Errorf("Peer.inHandler: Already received 'version' from peer %v -- disconnecting", pp) - break out - - case *MsgDeSoVerack: - // We always receive the VERACK from the Peer before starting this select - // statement, so getting one here is an error. - - glog.Errorf("Peer.inHandler: Already received 'verack' from peer %v -- disconnecting", pp) - break out - case *MsgDeSoPing: // Respond to a ping with a pong. pp.HandlePingMsg(msg) @@ -1156,7 +1155,7 @@ out: // Measure the ping time when we receive a pong. pp.HandlePongMsg(msg) - case *MsgDeSoNewPeer, *MsgDeSoDonePeer, *MsgDeSoQuit: + case *MsgDeSoDisconnectedPeer, *MsgDeSoQuit: // We should never receive control messages from a Peer. Disconnect if we do. glog.Errorf("Peer.inHandler: Received control message of type %v from "+ @@ -1189,20 +1188,12 @@ func (pp *Peer) Start() { glog.Infof("Peer.Start: Starting peer %v", pp) // The protocol has been negotiated successfully so start processing input // and output messages. + pp.startGroup.Add(4) go pp.PingHandler() go pp.outHandler() go pp.inHandler() go pp.StartDeSoMessageProcessor() - - // If the address manager needs more addresses, then send a GetAddr message - // to the peer. This is best-effort. - if pp.cmgr != nil { - if pp.cmgr.AddrMgr.NeedMoreAddresses() { - go func() { - pp.QueueMessage(&MsgDeSoGetAddr{}) - }() - } - } + pp.startGroup.Wait() // Send our verack message now that the IO processing machinery has started. } @@ -1284,226 +1275,17 @@ func (pp *Peer) ReadDeSoMessage() (DeSoMessage, error) { return msg, nil } -func (pp *Peer) NewVersionMessage(params *DeSoParams) *MsgDeSoVersion { - ver := NewMessage(MsgTypeVersion).(*MsgDeSoVersion) - - ver.Version = params.ProtocolVersion - ver.TstampSecs = time.Now().Unix() - // We use an int64 instead of a uint64 for convenience but - // this should be fine since we're just looking to generate a - // unique value. - ver.Nonce = uint64(RandInt64(math.MaxInt64)) - ver.UserAgent = params.UserAgent - // TODO: Right now all peers are full nodes. Later on we'll want to change this, - // at which point we'll need to do a little refactoring. - ver.Services = SFFullNodeDeprecated - if pp.cmgr != nil && pp.cmgr.HyperSync { - ver.Services |= SFHyperSync - } - if pp.srv.blockchain.archivalMode { - ver.Services |= SFArchivalNode - } - - // When a node asks you for what height you have, you should reply with - // the height of the latest actual block you have. This makes it so that - // peers who have up-to-date headers but missing blocks won't be considered - // for initial block download. - // - // TODO: This is ugly. It would be nice if the Peer required zero knowledge of the - // Server and the Blockchain. - if pp.srv != nil { - ver.StartBlockHeight = uint32(pp.srv.blockchain.blockTip().Header.Height) - } else { - ver.StartBlockHeight = uint32(0) - } - - // Set the minimum fee rate the peer will accept. - ver.MinFeeRateNanosPerKB = pp.minTxFeeRateNanosPerKB - - return ver -} - -func (pp *Peer) sendVerack() error { - verackMsg := NewMessage(MsgTypeVerack) - // Include the nonce we received in the peer's version message so - // we can validate that we actually control our IP address. 
- verackMsg.(*MsgDeSoVerack).Nonce = pp.VersionNonceReceived - if err := pp.WriteDeSoMessage(verackMsg); err != nil { - return errors.Wrap(err, "sendVerack: ") - } - - return nil -} - -func (pp *Peer) readVerack() error { - msg, err := pp.ReadDeSoMessage() - if err != nil { - return errors.Wrap(err, "readVerack: ") - } - if msg.GetMsgType() != MsgTypeVerack { - return fmt.Errorf( - "readVerack: Received message with type %s but expected type VERACK. ", - msg.GetMsgType().String()) - } - verackMsg := msg.(*MsgDeSoVerack) - if verackMsg.Nonce != pp.VersionNonceSent { - return fmt.Errorf( - "readVerack: Received VERACK message with nonce %d but expected nonce %d", - verackMsg.Nonce, pp.VersionNonceSent) - } - - return nil -} - -func (pp *Peer) sendVersion() error { - // For an outbound peer, we send a version message and then wait to - // hear back for one. - verMsg := pp.NewVersionMessage(pp.Params) - - // Record the nonce of this version message before we send it so we can - // detect self connections and so we can validate that the peer actually - // controls the IP she's supposedly communicating to us from. - pp.VersionNonceSent = verMsg.Nonce - if pp.cmgr != nil { - pp.cmgr.sentNonces.Add(pp.VersionNonceSent) - } - - if err := pp.WriteDeSoMessage(verMsg); err != nil { - return errors.Wrap(err, "sendVersion: ") - } - - return nil -} - -func (pp *Peer) readVersion() error { - msg, err := pp.ReadDeSoMessage() - if err != nil { - return errors.Wrap(err, "readVersion: ") - } - - verMsg, ok := msg.(*MsgDeSoVersion) - if !ok { - return fmt.Errorf( - "readVersion: Received message with type %s but expected type VERSION. "+ - "The VERSION message must preceed all others", msg.GetMsgType().String()) - } - if verMsg.Version < pp.Params.MinProtocolVersion { - return fmt.Errorf("readVersion: Peer's protocol version too low: %d (min: %v)", - verMsg.Version, pp.Params.MinProtocolVersion) - } - - // If we've sent this nonce before then return an error since this is - // a connection from ourselves. - msgNonce := verMsg.Nonce - if pp.cmgr != nil { - if pp.cmgr.sentNonces.Contains(msgNonce) { - pp.cmgr.sentNonces.Delete(msgNonce) - return fmt.Errorf("readVersion: Rejecting connection to self") - } - } - // Save the version nonce so we can include it in our verack message. - pp.VersionNonceReceived = msgNonce - - // Set the peer info-related fields. - pp.PeerInfoMtx.Lock() - pp.userAgent = verMsg.UserAgent - pp.serviceFlags = verMsg.Services - pp.advertisedProtocolVersion = verMsg.Version - negotiatedVersion := pp.Params.ProtocolVersion - if pp.advertisedProtocolVersion < pp.Params.ProtocolVersion { - negotiatedVersion = pp.advertisedProtocolVersion - } - pp.negotiatedProtocolVersion = negotiatedVersion - pp.PeerInfoMtx.Unlock() - - // Set the stats-related fields. - pp.StatsMtx.Lock() - pp.startingHeight = verMsg.StartBlockHeight - pp.minTxFeeRateNanosPerKB = verMsg.MinFeeRateNanosPerKB - pp.TimeConnected = time.Unix(verMsg.TstampSecs, 0) - pp.TimeOffsetSecs = verMsg.TstampSecs - time.Now().Unix() - pp.StatsMtx.Unlock() - - // Update the timeSource now that we've gotten a version message from the - // peer. 
- if pp.cmgr != nil { - pp.cmgr.timeSource.AddTimeSample(pp.addrStr, pp.TimeConnected) - } - - return nil -} - -func (pp *Peer) ReadWithTimeout(readFunc func() error, readTimeout time.Duration) error { - errChan := make(chan error) - go func() { - errChan <- readFunc() - }() - select { - case err := <-errChan: - { - return err - } - case <-time.After(readTimeout): - { - return fmt.Errorf("ReadWithTimeout: Timed out reading message from peer: (%v)", pp) - } - } -} - -func (pp *Peer) NegotiateVersion(versionNegotiationTimeout time.Duration) error { - if pp.isOutbound { - // Write a version message. - if err := pp.sendVersion(); err != nil { - return errors.Wrapf(err, "negotiateVersion: Problem sending version to Peer %v", pp) - } - // Read the peer's version. - if err := pp.ReadWithTimeout( - pp.readVersion, - versionNegotiationTimeout); err != nil { - - return errors.Wrapf(err, "negotiateVersion: Problem reading OUTBOUND peer version for Peer %v", pp) - } - } else { - // Read the version first since this is an inbound peer. - if err := pp.ReadWithTimeout( - pp.readVersion, - versionNegotiationTimeout); err != nil { - - return errors.Wrapf(err, "negotiateVersion: Problem reading INBOUND peer version for Peer %v", pp) - } - if err := pp.sendVersion(); err != nil { - return errors.Wrapf(err, "negotiateVersion: Problem sending version to Peer %v", pp) - } - } - - // After sending and receiving a compatible version, complete the - // negotiation by sending and receiving a verack message. - if err := pp.sendVerack(); err != nil { - return errors.Wrapf(err, "negotiateVersion: Problem sending verack to Peer %v", pp) - } - if err := pp.ReadWithTimeout( - pp.readVerack, - versionNegotiationTimeout); err != nil { - - return errors.Wrapf(err, "negotiateVersion: Problem reading VERACK message from Peer %v", pp) - } - pp.VersionNegotiated = true - - // At this point we have sent a version and validated our peer's - // version. So the negotiation should be complete. - return nil -} - // Disconnect closes a peer's network connection. func (pp *Peer) Disconnect() { // Only run the logic the first time Disconnect is called. glog.V(1).Infof(CLog(Yellow, "Peer.Disconnect: Starting")) - if atomic.AddInt32(&pp.disconnected, 1) != 1 { + if atomic.LoadInt32(&pp.disconnected) != 0 { glog.V(1).Infof("Peer.Disconnect: Disconnect call ignored since it was already called before for Peer %v", pp) return } + atomic.AddInt32(&pp.disconnected, 1) - glog.V(1).Infof("Peer.Disconnect: Running Disconnect for the first time for Peer %v", pp) + glog.V(2).Infof("Peer.Disconnect: Running Disconnect for the first time for Peer %v", pp) // Close the connection object. pp.Conn.Close() @@ -1513,9 +1295,7 @@ func (pp *Peer) Disconnect() { // Add the Peer to donePeers so that the ConnectionManager and Server can do any // cleanup they need to do. 
- if pp.cmgr != nil && atomic.LoadInt32(&pp.cmgr.shutdown) == 0 && pp.cmgr.donePeerChan != nil { - pp.cmgr.donePeerChan <- pp - } + pp.peerDisconnectedChan <- pp } func (pp *Peer) _logVersionSuccess() { diff --git a/lib/pos_block_producer.go b/lib/pos_block_producer.go index d216a1153..372dc94a1 100644 --- a/lib/pos_block_producer.go +++ b/lib/pos_block_producer.go @@ -7,7 +7,6 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/deso-protocol/core/bls" - "github.com/deso-protocol/core/collections/bitset" "github.com/pkg/errors" ) @@ -109,8 +108,6 @@ func (pbp *PosBlockProducer) createBlockTemplate(latestBlockView *UtxoView, newB block.Header.ProposerVotingPublicKey = pbp.proposerVotingPublicKey block.Header.ProposerRandomSeedSignature = proposerRandomSeedSignature - // Hash the TxnConnectStatusByIndex - block.Header.TxnConnectStatusByIndexHash = HashBitset(block.TxnConnectStatusByIndex) return block, nil } @@ -139,7 +136,7 @@ func (pbp *PosBlockProducer) createBlockWithoutHeader( } // Get block transactions from the mempool. - feeTimeTxns, txnConnectStatusByIndex, maxUtilityFee, err := pbp.getBlockTransactions( + feeTimeTxns, maxUtilityFee, err := pbp.getBlockTransactions( pbp.proposerPublicKey, latestBlockView, newBlockHeight, @@ -155,7 +152,6 @@ func (pbp *PosBlockProducer) createBlockWithoutHeader( block.Txns = append([]*MsgDeSoTxn{blockRewardTxn}, feeTimeTxns...) // Set the RevolutionMetadata - block.TxnConnectStatusByIndex = txnConnectStatusByIndex return block, nil } @@ -168,7 +164,6 @@ func (pbp *PosBlockProducer) getBlockTransactions( maxBlockSizeBytes uint64, ) ( _txns []*MsgDeSoTxn, - _txnConnectStatusByIndex *bitset.Bitset, _maxUtilityFee uint64, _err error, ) { @@ -177,17 +172,16 @@ func (pbp *PosBlockProducer) getBlockTransactions( // Try to connect transactions one by one. blocksTxns := []*MsgDeSoTxn{} - txnConnectStatusByIndex := bitset.NewBitset() maxUtilityFee := uint64(0) currentBlockSize := uint64(0) blockUtxoView, err := latestBlockView.CopyUtxoView() if err != nil { - return nil, nil, 0, errors.Wrapf(err, "Error copying UtxoView: ") + return nil, 0, errors.Wrapf(err, "Error copying UtxoView: ") } for _, txn := range feeTimeTxns { txnBytes, err := txn.ToBytes(false) if err != nil { - return nil, nil, 0, errors.Wrapf(err, "Error getting transaction size: ") + return nil, 0, errors.Wrapf(err, "Error getting transaction size: ") } // Skip over transactions that are too big. @@ -197,68 +191,34 @@ func (pbp *PosBlockProducer) getBlockTransactions( blockUtxoViewCopy, err := blockUtxoView.CopyUtxoView() if err != nil { - return nil, nil, 0, errors.Wrapf(err, "Error copying UtxoView: ") + return nil, 0, errors.Wrapf(err, "Error copying UtxoView: ") } _, _, _, fees, err := blockUtxoViewCopy._connectTransaction( txn.GetTxn(), txn.Hash(), uint32(newBlockHeight), newBlockTimestampNanoSecs, true, false) // Check if the transaction connected. - if err == nil { - blockUtxoView = blockUtxoViewCopy - txnConnectStatusByIndex.Set(len(blocksTxns), true) - blocksTxns = append(blocksTxns, txn.GetTxn()) - currentBlockSize += uint64(len(txnBytes)) - - // If the transactor is the block producer, then they won't receive the utility - // fee. - if blockProducerPublicKey.Equal(*NewPublicKey(txn.PublicKey)) { - continue - } - - // Compute BMF for the transaction. 
- _, utilityFee := computeBMF(fees) - maxUtilityFee, err = SafeUint64().Add(maxUtilityFee, utilityFee) - if err != nil { - return nil, nil, 0, errors.Wrapf(err, "Error computing max utility fee: ") - } - continue - } - - // If the transaction didn't connect, we will try to add it as a failing transaction. - blockUtxoViewCopy, err = blockUtxoView.CopyUtxoView() if err != nil { - return nil, nil, 0, errors.Wrapf(err, "Error copying UtxoView: ") - } - - _, _, utilityFee, err := blockUtxoViewCopy._connectFailingTransaction(txn.GetTxn(), uint32(newBlockHeight), true) - if err != nil { - // If the transaction still doesn't connect, this means we encountered an invalid transaction. We will skip - // it and let some other process figure out what to do with it. Removing invalid transactions is a fast - // process, so we don't need to worry about it here. continue } - - // If we get to this point, it means the transaction didn't connect but it was a valid transaction. We will - // add it to the block as a failing transaction. blockUtxoView = blockUtxoViewCopy - txnConnectStatusByIndex.Set(len(blocksTxns), false) blocksTxns = append(blocksTxns, txn.GetTxn()) currentBlockSize += uint64(len(txnBytes)) - // If the transactor is the block producer, then they won't receive the utility - // fee. + // If the transactor is the block producer, then they won't receive the utility fee. if blockProducerPublicKey.Equal(*NewPublicKey(txn.PublicKey)) { continue } + // Compute BMF for the transaction. + _, utilityFee := computeBMF(fees) maxUtilityFee, err = SafeUint64().Add(maxUtilityFee, utilityFee) if err != nil { - return nil, nil, 0, errors.Wrapf(err, "Error computing max utility fee: ") + return nil, 0, errors.Wrapf(err, "Error computing max utility fee: ") } } - return blocksTxns, txnConnectStatusByIndex, maxUtilityFee, nil + return blocksTxns, maxUtilityFee, nil } func _maxInt64(a, b int64) int64 { diff --git a/lib/pos_block_producer_test.go b/lib/pos_block_producer_test.go index a5f90c519..2fb62beea 100644 --- a/lib/pos_block_producer_test.go +++ b/lib/pos_block_producer_test.go @@ -9,7 +9,6 @@ import ( "time" "github.com/deso-protocol/core/bls" - "github.com/deso-protocol/core/collections/bitset" "github.com/stretchr/testify/require" ) @@ -36,7 +35,7 @@ func TestCreateBlockTemplate(t *testing.T) { mempool := NewPosMempool() require.NoError(mempool.Init( params, globalParams, latestBlockView, 2, dir, false, maxMempoolPosSizeBytes, mempoolBackupIntervalMillis, 1, - nil, 1, 100, + nil, 1, 10000, 100, 100, )) require.NoError(mempool.Start()) defer mempool.Stop() @@ -76,7 +75,6 @@ func TestCreateBlockTemplate(t *testing.T) { require.Equal(blockTemplate.Header.ProposedInView, uint64(10)) require.Equal(blockTemplate.Header.ProposerVotingPublicKey, pub) require.True(blockTemplate.Header.ProposerRandomSeedSignature.Eq(seedSignature)) - require.Equal(blockTemplate.Header.TxnConnectStatusByIndexHash, HashBitset(blockTemplate.TxnConnectStatusByIndex)) } func TestCreateBlockWithoutHeader(t *testing.T) { @@ -102,7 +100,7 @@ func TestCreateBlockWithoutHeader(t *testing.T) { mempool := NewPosMempool() require.NoError(mempool.Init( params, globalParams, latestBlockView, 2, dir, false, maxMempoolPosSizeBytes, mempoolBackupIntervalMillis, 1, - nil, 1, 100, + nil, 1, 10000, 100, 100, )) require.NoError(mempool.Start()) defer mempool.Stop() @@ -122,14 +120,13 @@ func TestCreateBlockWithoutHeader(t *testing.T) { // Test cases where the block producer is the transactor for the mempool txns { pbp := NewPosBlockProducer(mempool, 
params, NewPublicKey(m0PubBytes), blsPubKey, time.Now().UnixNano()) - txns, txnConnectStatus, _, err := pbp.getBlockTransactions( + txns, _, err := pbp.getBlockTransactions( NewPublicKey(m0PubBytes), latestBlockView, 3, 0, 50000) require.NoError(err) blockTemplate, err := pbp.createBlockWithoutHeader(latestBlockView, 3, 0) require.NoError(err) require.Equal(txns, blockTemplate.Txns[1:]) - require.Equal(txnConnectStatus, blockTemplate.TxnConnectStatusByIndex) require.Equal(uint64(0), blockTemplate.Txns[0].TxOutputs[0].AmountNanos) require.Equal(NewMessage(MsgTypeHeader).(*MsgDeSoHeader), blockTemplate.Header) require.Nil(blockTemplate.BlockProducerInfo) @@ -138,14 +135,13 @@ func TestCreateBlockWithoutHeader(t *testing.T) { // Test cases where the block producer is not the transactor for the mempool txns { pbp := NewPosBlockProducer(mempool, params, NewPublicKey(m1PubBytes), blsPubKey, time.Now().UnixNano()) - txns, txnConnectStatus, maxUtilityFee, err := pbp.getBlockTransactions( + txns, maxUtilityFee, err := pbp.getBlockTransactions( NewPublicKey(m1PubBytes), latestBlockView, 3, 0, 50000) require.NoError(err) blockTemplate, err := pbp.createBlockWithoutHeader(latestBlockView, 3, 0) require.NoError(err) require.Equal(txns, blockTemplate.Txns[1:]) - require.Equal(txnConnectStatus, blockTemplate.TxnConnectStatusByIndex) require.Equal(maxUtilityFee, blockTemplate.Txns[0].TxOutputs[0].AmountNanos) require.Equal(NewMessage(MsgTypeHeader).(*MsgDeSoHeader), blockTemplate.Header) require.Nil(blockTemplate.BlockProducerInfo) @@ -179,7 +175,7 @@ func TestGetBlockTransactions(t *testing.T) { mempool := NewPosMempool() require.NoError(mempool.Init( params, globalParams, latestBlockView, 2, dir, false, maxMempoolPosSizeBytes, mempoolBackupIntervalMillis, 1, - nil, 1, 100, + nil, 1, 10000, 100, 100, )) require.NoError(mempool.Start()) defer mempool.Stop() @@ -246,11 +242,10 @@ func TestGetBlockTransactions(t *testing.T) { latestBlockViewCopy, err := latestBlockView.CopyUtxoView() require.NoError(err) - txns, txnConnectStatus, maxUtilityFee, err := pbp.getBlockTransactions(NewPublicKey(m1PubBytes), latestBlockView, 3, 0, 1000) + txns, maxUtilityFee, err := pbp.getBlockTransactions(NewPublicKey(m1PubBytes), latestBlockView, 3, 0, 1000) require.NoError(err) require.Equal(latestBlockViewCopy, latestBlockView) require.Equal(true, len(passingTxns) > len(txns)) - require.Equal(true, len(passingTxns) > txnConnectStatus.Size()) totalUtilityFee = 0 for _, txn := range txns { _, utilityFee := computeBMF(txn.TxnFeeNanos) @@ -264,15 +259,15 @@ func TestGetBlockTransactions(t *testing.T) { testMempool := NewPosMempool() testMempool.Init( params, globalParams, latestBlockView, 2, "", true, maxMempoolPosSizeBytes, mempoolBackupIntervalMillis, 1, - nil, 1, 100, + nil, 1, 10000, 100, 100, ) require.NoError(testMempool.Start()) defer testMempool.Stop() currentTime := time.Now() for ii, txn := range txns { // Use the Simulated Transaction Timestamp. 
- mtxn := NewMempoolTransaction(txn, currentTime.Add(time.Duration(ii)*time.Microsecond)) - require.NoError(testMempool.AddTransaction(mtxn, false)) + mtxn := NewMempoolTransaction(txn, currentTime.Add(time.Duration(ii)*time.Microsecond), false) + require.NoError(testMempool.AddTransaction(mtxn)) } newTxns := testMempool.GetTransactions() require.Equal(len(txns), len(newTxns)) @@ -282,7 +277,7 @@ func TestGetBlockTransactions(t *testing.T) { } func _testProduceBlockNoSizeLimit(t *testing.T, mp *PosMempool, pbp *PosBlockProducer, latestBlockView *UtxoView, blockHeight uint64, - numPassing int, numFailing int, numInvalid int) (_txns []*MsgDeSoTxn, _txnConnectStatusByIndex *bitset.Bitset, _maxUtilityFee uint64) { + numPassing int, numFailing int, numInvalid int) (_txns []*MsgDeSoTxn, _maxUtilityFee uint64) { require := require.New(t) totalAcceptedTxns := numPassing + numFailing @@ -291,17 +286,10 @@ func _testProduceBlockNoSizeLimit(t *testing.T, mp *PosMempool, pbp *PosBlockPro latestBlockViewCopy, err := latestBlockView.CopyUtxoView() require.NoError(err) - txns, txnConnectStatus, maxUtilityFee, err := pbp.getBlockTransactions(pbp.proposerPublicKey, latestBlockView, blockHeight, 0, math.MaxUint64) + txns, maxUtilityFee, err := pbp.getBlockTransactions(pbp.proposerPublicKey, latestBlockView, blockHeight, 0, math.MaxUint64) require.NoError(err) require.Equal(latestBlockViewCopy, latestBlockView) require.Equal(totalAcceptedTxns, len(txns)) - require.True(totalAcceptedTxns >= txnConnectStatus.Size()) - numConnected := 0 - for ii := range txns { - if txnConnectStatus.Get(ii) { - numConnected++ - } - } - require.Equal(numPassing, numConnected) - return txns, txnConnectStatus, maxUtilityFee + + return txns, maxUtilityFee } diff --git a/lib/pos_blockchain.go b/lib/pos_blockchain.go index 65ade5b80..72044588b 100644 --- a/lib/pos_blockchain.go +++ b/lib/pos_blockchain.go @@ -53,6 +53,19 @@ func (bc *Blockchain) processHeaderPoS(header *MsgDeSoHeader) ( return false, false, errors.Wrapf(err, "processHeaderPoS: Problem hashing header") } + // If the incoming header is already part of the best header chain, then we can exit early. + // The header is not part of a fork, and is already an ancestor of the current header chain tip. + if _, isInBestHeaderChain := bc.bestHeaderChainMap[*headerHash]; isInBestHeaderChain { + return true, false, nil + } + + // If the incoming header is part of a reorg that uncommits the committed tip from the best chain, + // then we exit early. Such headers are invalid and should not be synced. + committedBlockchainTip, _ := bc.GetCommittedTip() + if committedBlockchainTip != nil && committedBlockchainTip.Header.Height >= header.Height { + return false, false, errors.New("processHeaderPoS: Header conflicts with committed tip") + } + // Validate the header and index it in the block index. blockNode, isOrphan, err := bc.validateAndIndexHeaderPoS(header, headerHash) if err != nil { @@ -257,7 +270,7 @@ func (bc *Blockchain) processBlockPoS(block *MsgDeSoBlock, currentView uint64, v } if !passedSpamPreventionCheck { // If the block fails the spam prevention check, we throw it away. - return false, false, nil, errors.New("processBlockPoS: Block failed spam prevention check") + return false, false, nil, errors.Wrapf(RuleErrorFailedSpamPreventionsCheck, "processBlockPoS: Block failed spam prevention check: ") } // Validate the block and store it in the block index. The block is guaranteed to not be an orphan. 
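To make the new processHeaderPoS guards above concrete, here is a small worked illustration (the heights are made up for this example):

// Suppose the committed tip is at height 1000.
// - A header already present in bestHeaderChainMap returns early with (true, false, nil).
// - A header at height 1000 or below that is not in the best header chain is rejected with
//   "processHeaderPoS: Header conflicts with committed tip", since accepting it would imply
//   uncommitting a committed block.
// - A header at height 1001 or above falls through to validateAndIndexHeaderPoS as before.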
@@ -899,16 +912,6 @@ func (bc *Blockchain) isProperlyFormedBlockPoS(block *MsgDeSoBlock) error { return RuleErrorBlockWithNoTxns } - // Make sure TxnConnectStatusByIndex is non-nil - if block.TxnConnectStatusByIndex == nil { - return RuleErrorNilTxnConnectStatusByIndex - } - - // Make sure the TxnConnectStatusByIndex matches the TxnConnectStatusByIndexHash - if !(HashBitset(block.TxnConnectStatusByIndex).IsEqual(block.Header.TxnConnectStatusByIndexHash)) { - return RuleErrorTxnConnectStatusByIndexHashMismatch - } - // Make sure that the first txn in each block is a block reward txn. if block.Txns[0].TxnMeta.GetTxnType() != TxnTypeBlockReward { return RuleErrorBlockDoesNotStartWithRewardTxn @@ -949,11 +952,6 @@ func (bc *Blockchain) isProperlyFormedBlockHeaderPoS(header *MsgDeSoHeader) erro return RuleErrorInvalidPoSBlockHeaderVersion } - // Must have TxnConnectStatusByIndexHash - if header.TxnConnectStatusByIndexHash == nil { - return RuleErrorNilTxnConnectStatusByIndexHash - } - // Require header to have either vote or timeout QC isTimeoutQCEmpty := header.ValidatorsTimeoutAggregateQC.isEmpty() isVoteQCEmpty := header.ValidatorsVoteQC.isEmpty() @@ -1227,7 +1225,7 @@ func (bc *Blockchain) isValidPoSQuorumCertificate(block *MsgDeSoBlock, validator // including the committed tip. The first block in the returned slice is the first uncommitted // ancestor. func (bc *Blockchain) getLineageFromCommittedTip(header *MsgDeSoHeader) ([]*BlockNode, error) { - highestCommittedBlock, idx := bc.getCommittedTip() + highestCommittedBlock, idx := bc.GetCommittedTip() if idx == -1 || highestCommittedBlock == nil { return nil, errors.New("getLineageFromCommittedTip: No committed blocks found") } @@ -1549,8 +1547,6 @@ func (bc *Blockchain) shouldReorg(blockNode *BlockNode, currentView uint64) bool func (bc *Blockchain) addTipBlockToBestChain(blockNode *BlockNode) { bc.bestChain = append(bc.bestChain, blockNode) bc.bestChainMap[*blockNode.Hash] = blockNode - bc.bestHeaderChain = append(bc.bestHeaderChain, blockNode) - bc.bestHeaderChainMap[*blockNode.Hash] = blockNode } // removeTipBlockFromBestChain removes the current tip from the best chain. It @@ -1562,8 +1558,6 @@ func (bc *Blockchain) removeTipBlockFromBestChain() *BlockNode { lastBlock := bc.bestChain[len(bc.bestChain)-1] delete(bc.bestChainMap, *lastBlock.Hash) bc.bestChain = bc.bestChain[:len(bc.bestChain)-1] - bc.bestHeaderChain = bc.bestHeaderChain[:len(bc.bestChain)] - delete(bc.bestHeaderChainMap, *lastBlock.Hash) return lastBlock } @@ -1579,7 +1573,7 @@ func (bc *Blockchain) runCommitRuleOnBestChain() error { return nil } // Find all uncommitted ancestors of block to commit - _, idx := bc.getCommittedTip() + _, idx := bc.GetCommittedTip() if idx == -1 { // This is an edge case we'll never hit in practice since all the PoW blocks // are committed. @@ -1722,6 +1716,50 @@ func (bc *Blockchain) commitBlockPoS(blockHash *BlockHash) error { return nil } +// GetUncommittedFullBlocks is a helper that the state syncer uses to fetch all uncommitted +// blocks, so it can flush them just like we would with mempool transactions. It returns +// all uncommitted blocks from the specified tip to the last uncommitted block. +// Note: it would be more efficient if we cached these results. 
+func (bc *Blockchain) GetUncommittedFullBlocks(tipHash *BlockHash) ([]*MsgDeSoBlock, error) { + if tipHash == nil { + tipHash = bc.BlockTip().Hash + } + bc.ChainLock.RLock() + defer bc.ChainLock.RUnlock() + tipBlock, exists := bc.bestChainMap[*tipHash] + if !exists { + return nil, errors.Errorf("GetUncommittedFullBlocks: Block %v not found in best chain map", tipHash.String()) + } + // If the tip block is committed, we can't get uncommitted blocks from it so we return an empty slice. + if tipBlock.IsCommitted() { + return []*MsgDeSoBlock{}, nil + } + var uncommittedBlocks []*MsgDeSoBlock + currentBlock := tipBlock + for !currentBlock.IsCommitted() { + fullBlock, err := GetBlock(currentBlock.Hash, bc.db, bc.snapshot) + if err != nil { + return nil, errors.Wrapf(err, "GetUncommittedFullBlocks: Problem fetching block %v", + currentBlock.Hash.String()) + } + uncommittedBlocks = append(uncommittedBlocks, fullBlock) + currentParentHash := currentBlock.Header.PrevBlockHash + if currentParentHash == nil { + return nil, errors.Errorf("GetUncommittedFullBlocks: Block %v has nil PrevBlockHash", currentBlock.Hash) + } + currentBlock = bc.blockIndexByHash[*currentParentHash] + if currentBlock == nil { + return nil, errors.Errorf("GetUncommittedFullBlocks: Block %v not found in block index", currentParentHash) + } + } + return collections.Reverse(uncommittedBlocks), nil +} + +// GetCommittedTipView builds a UtxoView to the committed tip. +func (bc *Blockchain) GetCommittedTipView() (*UtxoView, error) { + return NewUtxoView(bc.db, bc.params, bc.postgres, bc.snapshot, nil) +} + // GetUncommittedTipView builds a UtxoView to the uncommitted tip. func (bc *Blockchain) GetUncommittedTipView() (*UtxoView, error) { // Connect the uncommitted blocks to the tip so that we can validate subsequent blocks @@ -1739,7 +1777,7 @@ func (bc *Blockchain) getUtxoViewAtBlockHash(blockHash BlockHash) (*UtxoView, er // If the provided block is committed, we need to make sure it's the committed tip. // Otherwise, we return an error. if currentBlock.IsCommitted() { - highestCommittedBlock, _ := bc.getCommittedTip() + highestCommittedBlock, _ := bc.GetCommittedTip() if highestCommittedBlock == nil { return nil, errors.Errorf("getUtxoViewAtBlockHash: No committed blocks found") } @@ -1786,8 +1824,8 @@ func (bc *Blockchain) getUtxoViewAtBlockHash(blockHash BlockHash) (*UtxoView, er return utxoView, nil } -// getCommittedTip returns the highest committed block and its index in the best chain. -func (bc *Blockchain) getCommittedTip() (*BlockNode, int) { +// GetCommittedTip returns the highest committed block and its index in the best chain. +func (bc *Blockchain) GetCommittedTip() (*BlockNode, int) { for ii := len(bc.bestChain) - 1; ii >= 0; ii-- { if bc.bestChain[ii].IsCommitted() { return bc.bestChain[ii], ii @@ -1818,7 +1856,7 @@ func (bc *Blockchain) GetSafeBlocks() ([]*MsgDeSoHeader, error) { func (bc *Blockchain) getSafeBlockNodes() ([]*BlockNode, error) { // First get committed tip.
- committedTip, idx := bc.getCommittedTip() + committedTip, idx := bc.GetCommittedTip() if idx == -1 || committedTip == nil { return nil, errors.New("getSafeBlockNodes: No committed blocks found") } @@ -1909,9 +1947,6 @@ const ( RuleErrorPoSBlockTstampNanoSecsTooOld RuleError = "RuleErrorPoSBlockTstampNanoSecsTooOld" RuleErrorPoSBlockTstampNanoSecsInFuture RuleError = "RuleErrorPoSBlockTstampNanoSecsInFuture" RuleErrorInvalidPoSBlockHeaderVersion RuleError = "RuleErrorInvalidPoSBlockHeaderVersion" - RuleErrorNilTxnConnectStatusByIndex RuleError = "RuleErrorNilTxnConnectStatusByIndex" - RuleErrorNilTxnConnectStatusByIndexHash RuleError = "RuleErrorNilTxnConnectStatusByIndexHash" - RuleErrorTxnConnectStatusByIndexHashMismatch RuleError = "RuleErrorTxnConnectStatusByIndexHashMismatch" RuleErrorNoTimeoutOrVoteQC RuleError = "RuleErrorNoTimeoutOrVoteQC" RuleErrorBothTimeoutAndVoteQC RuleError = "RuleErrorBothTimeoutAndVoteQC" RuleErrorBlockWithNoTxns RuleError = "RuleErrorBlockWithNoTxns" @@ -1922,6 +1957,7 @@ const ( RuleErrorAncestorBlockValidationFailed RuleError = "RuleErrorAncestorBlockValidationFailed" RuleErrorParentBlockHasViewGreaterOrEqualToChildBlock RuleError = "RuleErrorParentBlockHasViewGreaterOrEqualToChildBlock" RuleErrorParentBlockHeightNotSequentialWithChildBlockHeight RuleError = "RuleErrorParentBlockHeightNotSequentialWithChildBlockHeight" + RuleErrorFailedSpamPreventionsCheck RuleError = "RuleErrorFailedSpamPreventionsCheck" RuleErrorNilMerkleRoot RuleError = "RuleErrorNilMerkleRoot" RuleErrorInvalidMerkleRoot RuleError = "RuleErrorInvalidMerkleRoot" diff --git a/lib/pos_blockchain_test.go b/lib/pos_blockchain_test.go index 8f40597e8..b7fbc25dd 100644 --- a/lib/pos_blockchain_test.go +++ b/lib/pos_blockchain_test.go @@ -77,10 +77,8 @@ func TestIsProperlyFormedBlockPoSAndIsBlockTimestampValidRelativeToParentPoS(t * ProposerRandomSeedSignature: signature, ProposerVotingPublicKey: randomBLSPrivateKey.PublicKey(), TransactionMerkleRoot: merkleRoot, - TxnConnectStatusByIndexHash: HashBitset(bitset.NewBitset().Set(0, true)), }, - Txns: txns, - TxnConnectStatusByIndex: bitset.NewBitset().Set(0, true), + Txns: txns, } // Validate the block with a valid timeout QC and header. 
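As a usage note for the GetUncommittedFullBlocks helper introduced earlier in this diff: a state syncer can walk the uncommitted tail of the chain and flush each block the same way it flushes mempool transactions. The sketch below is hypothetical caller code assumed to live in package lib; the helper name and the flush callback are illustrative only, not part of this diff.

// Hypothetical usage sketch of GetUncommittedFullBlocks; blocks are returned oldest-first,
// from the first uncommitted ancestor up to the supplied tip hash.
func exampleFlushUncommittedBlocks(bc *Blockchain, flush func(*MsgDeSoBlock) error) error {
	uncommittedBlocks, err := bc.GetUncommittedFullBlocks(bc.BlockTip().Hash)
	if err != nil {
		return err
	}
	for _, block := range uncommittedBlocks {
		if err := flush(block); err != nil {
			return err
		}
	}
	return nil
}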
@@ -183,23 +181,6 @@ func TestIsProperlyFormedBlockPoSAndIsBlockTimestampValidRelativeToParentPoS(t * }, } - // TxnConnectStatusByIndex tests - // TxnConnectStatusByIndex must be non-nil - block.TxnConnectStatusByIndex = nil - err = bc.isProperlyFormedBlockPoS(block) - require.Equal(t, err, RuleErrorNilTxnConnectStatusByIndex) - // TxnConnectStatusByIndexHash must be non-nil - block.TxnConnectStatusByIndex = bitset.NewBitset().Set(0, true) - block.Header.TxnConnectStatusByIndexHash = nil - err = bc.isProperlyFormedBlockPoS(block) - require.Equal(t, err, RuleErrorNilTxnConnectStatusByIndexHash) - // The hashed version of TxnConnectStatusByIndex must match the actual TxnConnectStatusByIndexHash - block.Header.TxnConnectStatusByIndexHash = HashBitset(bitset.NewBitset().Set(0, false)) - err = bc.isProperlyFormedBlockPoS(block) - require.Equal(t, err, RuleErrorTxnConnectStatusByIndexHashMismatch) - // Reset TxnConnectStatusByIndexHash - block.Header.TxnConnectStatusByIndexHash = HashBitset(block.TxnConnectStatusByIndex) - // Block must have valid proposer voting public key block.Header.ProposerVotingPublicKey = nil err = bc.isProperlyFormedBlockPoS(block) @@ -386,14 +367,12 @@ func TestUpsertBlockAndBlockNodeToDB(t *testing.T) { SignersList: bitset.NewBitset(), }, }, - TxnConnectStatusByIndexHash: NewBlockHash(bitset.NewBitset().ToBytes()), }, Txns: []*MsgDeSoTxn{ { TxnMeta: &BlockRewardMetadataa{}, }, }, - TxnConnectStatusByIndex: bitset.NewBitset(), } blockNode, err := bc.storeBlockInBlockIndex(block) require.NoError(t, err) @@ -1891,8 +1870,6 @@ func testProcessBlockPoS(t *testing.T, testMeta *TestMeta) { var malformedOrphanBlock *MsgDeSoBlock malformedOrphanBlock = _generateRealBlock(testMeta, 18, 18, 9273, testMeta.chain.BlockTip().Hash, false) malformedOrphanBlock.Header.PrevBlockHash = randomHash - // Modify anything to make the block malformed, but make sure a hash can still be generated. - malformedOrphanBlock.Header.TxnConnectStatusByIndexHash = randomHash // Resign the block. updateProposerVotePartialSignatureForBlock(testMeta, malformedOrphanBlock) malformedOrphanBlockHash, err := malformedOrphanBlock.Hash() @@ -1909,7 +1886,6 @@ func testProcessBlockPoS(t *testing.T, testMeta *TestMeta) { require.True(t, malformedOrphanBlockInIndex.IsStored()) // If a block can't be hashed, we expect to get an error. - malformedOrphanBlock.Header.TxnConnectStatusByIndexHash = nil success, isOrphan, missingBlockHashes, err = testMeta.chain.ProcessBlockPoS(malformedOrphanBlock, 18, true) require.False(t, success) require.False(t, isOrphan) @@ -1920,7 +1896,6 @@ func testProcessBlockPoS(t *testing.T, testMeta *TestMeta) { { var blockWithFailingTxn *MsgDeSoBlock blockWithFailingTxn = _generateRealBlockWithFailingTxn(testMeta, 18, 18, 123722, orphanBlockHash, false, 1, 0) - require.Equal(t, blockWithFailingTxn.TxnConnectStatusByIndex.Get(len(blockWithFailingTxn.Txns)-1), false) success, _, _, err := testMeta.chain.ProcessBlockPoS(blockWithFailingTxn, 18, true) require.True(t, success) blockWithFailingTxnHash, err = blockWithFailingTxn.Hash() @@ -2595,9 +2570,8 @@ func _getVoteQC(testMeta *TestMeta, blockHeight uint64, qcBlockHash *BlockHash, return voteQC } -// _getFullRealBlockTemplate is a helper function that generates a block template with a valid TxnConnectStatusByIndexHash -// and a valid TxnConnectStatusByIndex, a valid vote or timeout QC, does all the required signing by validators, -// and generates the proper ProposerVotePartialSignature. 
+// _getFullRealBlockTemplate is a helper function that generates a block template with a valid vote or timeout QC, +// does all the required signing by validators, and generates the proper ProposerVotePartialSignature. func _getFullRealBlockTemplate( testMeta *TestMeta, blockHeight uint64, @@ -2610,7 +2584,6 @@ func _getFullRealBlockTemplate( testMeta.posMempool.readOnlyLatestBlockView, blockHeight, view, seedSignature) require.NoError(testMeta.t, err) require.NotNil(testMeta.t, blockTemplate) - blockTemplate.Header.TxnConnectStatusByIndexHash = HashBitset(blockTemplate.TxnConnectStatusByIndex) // Figure out who the leader is supposed to be. leaderPublicKey, leaderPublicKeyBytes := getLeaderForBlockHeightAndView(testMeta, blockHeight, view) @@ -2686,13 +2659,11 @@ func _getFullRealBlockTemplate( return blockTemplate } -// _getFullDummyBlockTemplate is a helper function that generates a block template with a dummy TxnConnectStatusByIndexHash -// and a dummy ValidatorsVoteQC. +// _getFullDummyBlockTemplate is a helper function that generates a block template with a dummy ValidatorsVoteQC. func _getFullDummyBlockTemplate(testMeta *TestMeta, latestBlockView *UtxoView, blockHeight uint64, view uint64, seedSignature *bls.Signature) BlockTemplate { blockTemplate, err := testMeta.posBlockProducer.createBlockTemplate(latestBlockView, blockHeight, view, seedSignature) require.NoError(testMeta.t, err) require.NotNil(testMeta.t, blockTemplate) - blockTemplate.Header.TxnConnectStatusByIndexHash = HashBitset(blockTemplate.TxnConnectStatusByIndex) // Add a dummy vote QC proposerVotingPublicKey := _generateRandomBLSPrivateKey(testMeta.t) dummySig, err := proposerVotingPublicKey.Sign(RandomBytes(32)) @@ -2790,7 +2761,7 @@ func NewTestPoSBlockchainWithValidators(t *testing.T) *TestMeta { mempool := NewPosMempool() require.NoError(t, mempool.Init( params, _testGetDefaultGlobalParams(), latestBlockView, 11, _dbDirSetup(t), false, maxMempoolPosSizeBytes, - mempoolBackupIntervalMillis, 1, nil, 1, 100, + mempoolBackupIntervalMillis, 1, nil, 1, 10000, 100, 100, )) require.NoError(t, mempool.Start()) require.True(t, mempool.IsRunning()) diff --git a/lib/pos_consensus.go b/lib/pos_consensus.go index 99c582d71..18ca73e26 100644 --- a/lib/pos_consensus.go +++ b/lib/pos_consensus.go @@ -14,6 +14,7 @@ import ( type FastHotStuffConsensus struct { lock sync.RWMutex + networkManager *NetworkManager blockchain *Blockchain fastHotStuffEventLoop consensus.FastHotStuffEventLoop mempool Mempool @@ -25,6 +26,7 @@ type FastHotStuffConsensus struct { func NewFastHotStuffConsensus( params *DeSoParams, + networkManager *NetworkManager, blockchain *Blockchain, mempool Mempool, signer *BLSSigner, @@ -32,6 +34,7 @@ func NewFastHotStuffConsensus( timeoutBaseDurationMilliseconds uint64, ) *FastHotStuffConsensus { return &FastHotStuffConsensus{ + networkManager: networkManager, blockchain: blockchain, fastHotStuffEventLoop: consensus.NewFastHotStuffEventLoop(), mempool: mempool, @@ -99,10 +102,22 @@ func (fc *FastHotStuffConsensus) Start() error { blockProductionInterval := time.Millisecond * time.Duration(fc.blockProductionIntervalMilliseconds) timeoutBaseDuration := time.Millisecond * time.Duration(fc.timeoutBaseDurationMilliseconds) - // Initialize and start the event loop - fc.fastHotStuffEventLoop.Init(blockProductionInterval, timeoutBaseDuration, genesisQC, tipBlockWithValidators[0], safeBlocksWithValidators) + // Initialize the event loop. This should never fail. If it does, we return the error to the caller. 
+ // The caller can handle the error and decide when to retry. + err = fc.fastHotStuffEventLoop.Init(blockProductionInterval, timeoutBaseDuration, genesisQC, tipBlockWithValidators[0], safeBlocksWithValidators) + if err != nil { + return errors.Errorf("FastHotStuffConsensus.Start: Error initializing FastHotStuffEventLoop: %v", err) + } + + // Start the event loop fc.fastHotStuffEventLoop.Start() + // Update the validator connections in the NetworkManager. This is a best effort operation. If it fails, + // we log the error and continue. + if err = fc.updateActiveValidatorConnections(); err != nil { + glog.Errorf("FastHotStuffConsensus.Start: Error updating validator connections: %v", err) + } + return nil } @@ -281,7 +296,18 @@ func (fc *FastHotStuffConsensus) handleBlockProposalEvent( ) } - // TODO: Broadcast the block proposal to the network + // Broadcast the block to the validator network + validators := fc.networkManager.GetValidatorIndex().GetAll() + for _, validator := range validators { + sendMessageToRemoteNodeAsync(validator, blockProposal) + } + + // Broadcast the block to all inbound non-validator peers. This allows them to sync + // blocks from us. + nonValidators := fc.networkManager.GetNonValidatorInboundIndex().GetAll() + for _, nonValidator := range nonValidators { + sendMessageToRemoteNodeAsync(nonValidator, blockProposal) + } fc.logBlockProposal(blockProposal, blockHash) return nil @@ -346,8 +372,11 @@ func (fc *FastHotStuffConsensus) HandleLocalVoteEvent(event *consensus.FastHotSt return errors.Errorf("FastHotStuffConsensus.HandleLocalVoteEvent: Error processing vote locally: %v", err) } - // Broadcast the vote message to the network - // TODO: Broadcast the vote message to the network or alternatively to just the block proposer + // Broadcast the vote message to the validator network + validators := fc.networkManager.GetValidatorIndex().GetAll() + for _, validator := range validators { + sendMessageToRemoteNodeAsync(validator, voteMsg) + } return nil } @@ -355,6 +384,8 @@ func (fc *FastHotStuffConsensus) HandleLocalVoteEvent(event *consensus.FastHotSt // HandleValidatorVote is called when we receive a validator vote message from a peer. This function processes // the vote locally in the FastHotStuffEventLoop. func (fc *FastHotStuffConsensus) HandleValidatorVote(pp *Peer, msg *MsgDeSoValidatorVote) error { + glog.V(2).Infof("FastHotStuffConsensus.HandleValidatorVote: Received vote msg %s", msg.ToString()) + // No need to hold a lock on the consensus because this function is a pass-through // for the FastHotStuffEventLoop which guarantees thread-safety for its callers @@ -362,7 +393,8 @@ func (fc *FastHotStuffConsensus) HandleValidatorVote(pp *Peer, msg *MsgDeSoValid if err := fc.fastHotStuffEventLoop.ProcessValidatorVote(msg); err != nil { // If we can't process the vote locally, then it must somehow be malformed, stale, // or a duplicate vote/timeout for the same view.
- return errors.Wrapf(err, "FastHotStuffConsensus.HandleValidatorVote: Error processing vote: ") + glog.Errorf("FastHotStuffConsensus.HandleValidatorVote: Error processing vote msg: %v", err) + return errors.Wrapf(err, "FastHotStuffConsensus.HandleValidatorVote: Error processing vote msg: ") } // Happy path @@ -461,8 +493,11 @@ func (fc *FastHotStuffConsensus) HandleLocalTimeoutEvent(event *consensus.FastHo return errors.Errorf("FastHotStuffConsensus.HandleLocalTimeoutEvent: Error processing timeout locally: %v", err) } - // Broadcast the timeout message to the network - // TODO: Broadcast the timeout message to the network or alternatively to just the block proposer + // Broadcast the timeout message to the validator network + validators := fc.networkManager.GetValidatorIndex().GetAll() + for _, validator := range validators { + sendMessageToRemoteNodeAsync(validator, timeoutMsg) + } return nil } @@ -470,14 +505,32 @@ func (fc *FastHotStuffConsensus) HandleLocalTimeoutEvent(event *consensus.FastHo // HandleValidatorTimeout is called when we receive a validator timeout message from a peer. This function // processes the timeout locally in the FastHotStuffEventLoop. func (fc *FastHotStuffConsensus) HandleValidatorTimeout(pp *Peer, msg *MsgDeSoValidatorTimeout) error { - // No need to hold a lock on the consensus because this function is a pass-through - // for the FastHotStuffEventLoop which guarantees thread-safety for its callers. + glog.V(2).Infof("FastHotStuffConsensus.HandleValidatorTimeout: Received timeout msg: %s", msg.ToString()) + + // Hold a write lock on the consensus, since we need to update the timeout message in the + // FastHotStuffEventLoop. + fc.lock.Lock() + defer fc.lock.Unlock() + + if !fc.fastHotStuffEventLoop.IsRunning() { + return errors.Errorf("FastHotStuffConsensus.HandleValidatorTimeout: FastHotStuffEventLoop is not running") + } + + // If we don't have the highQC's block on hand, then we need to request it from the peer. We do + // that first before storing the timeout message locally in the FastHotStuffEventLoop. This + // prevents spamming of timeout messages by peers. + if !fc.blockchain.HasBlockInBlockIndex(msg.HighQC.BlockHash) { + fc.trySendMessageToPeer(pp, &MsgDeSoGetBlocks{HashList: []*BlockHash{msg.HighQC.BlockHash}}) + glog.Errorf("FastHotStuffConsensus.HandleValidatorTimeout: Requesting missing highQC's block: %v", msg.HighQC.BlockHash) + return errors.Errorf("FastHotStuffConsensus.HandleValidatorTimeout: Missing highQC's block: %v", msg.HighQC.BlockHash) + } // Process the timeout message locally in the FastHotStuffEventLoop if err := fc.fastHotStuffEventLoop.ProcessValidatorTimeout(msg); err != nil { // If we can't process the timeout locally, then it must somehow be malformed, stale, // or a duplicate vote/timeout for the same view. - return errors.Wrapf(err, "FastHotStuffConsensus.HandleValidatorTimeout: Error processing timeout: ") + glog.Errorf("FastHotStuffConsensus.HandleValidatorTimeout: Error processing timeout msg: %v", err) + return errors.Wrapf(err, "FastHotStuffConsensus.HandleValidatorTimeout: Error processing timeout msg: ") } // Happy path @@ -519,9 +572,11 @@ func (fc *FastHotStuffConsensus) HandleBlock(pp *Peer, msg *MsgDeSoBlock) error // // See https://github.com/deso-protocol/core/pull/875#discussion_r1460183510 for more details.
if len(missingBlockHashes) > 0 { - pp.QueueMessage(&MsgDeSoGetBlocks{ - HashList: missingBlockHashes, - }) + remoteNode := fc.networkManager.GetRemoteNodeFromPeer(pp) + if remoteNode == nil { + return errors.Errorf("FastHotStuffConsensus.HandleBlock: RemoteNode not found for peer: %v", pp) + } + sendMessageToRemoteNodeAsync(remoteNode, &MsgDeSoGetBlocks{HashList: missingBlockHashes}) } return nil @@ -599,6 +654,12 @@ func (fc *FastHotStuffConsensus) tryProcessBlockAsNewTip(block *MsgDeSoBlock) ([ return nil, errors.Errorf("Error processing tip block locally: %v", err) } + // Update the validator connections in the NetworkManager. This is a best effort operation. If it fails, + // we log the error and continue. + if err = fc.updateActiveValidatorConnections(); err != nil { + glog.Errorf("FastHotStuffConsensus.tryProcessBlockAsNewTip: Error updating validator connections: %v", err) + } + // Happy path. The block was processed successfully and applied as the new tip. Nothing left to do. return nil, nil } @@ -775,6 +836,70 @@ func (fc *FastHotStuffConsensus) createBlockProducer(bav *UtxoView, previousBloc return blockProducer, nil } +func (fc *FastHotStuffConsensus) updateActiveValidatorConnections() error { + // Fetch the committed tip view. This ends up being as good as using the uncommitted tip view + // but without the overhead of connecting at least two blocks' worth of txns to the view. + utxoView, err := fc.blockchain.GetCommittedTipView() + if err != nil { + return errors.Errorf("FastHotStuffConsensus.updateActiveValidatorConnections: Error fetching committed tip view: %v", err) + } + + // Get the current snapshot epoch number from the committed tip. This will be behind the uncommitted tip + // by up to two blocks, but this is fine since we fetch both the current epoch's and next epoch's validator + // sets. + snapshotEpochNumber, err := utxoView.GetCurrentSnapshotEpochNumber() + if err != nil { + return errors.Errorf("FastHotStuffConsensus.updateActiveValidatorConnections: Error fetching snapshot epoch number: %v", err) + } + + // Fetch the current snapshot epoch's validator set. + currentValidatorList, err := utxoView.GetAllSnapshotValidatorSetEntriesByStakeAtEpochNumber(snapshotEpochNumber) + if err != nil { + return errors.Errorf("FastHotStuffConsensus.updateActiveValidatorConnections: Error fetching current epoch's validator list: %v", err) + } + + // Fetch the next snapshot epoch's validator set. This is useful when we're close to epoch transitions and + // allows us to pre-connect to the next epoch's validator set. In the event that there is a timeout at + // the epoch transition, reverting us to the previous epoch, this allows us to maintain connections to the + // next epoch's validators. + // + // TODO: There is an optimization we can add here to only fetch the next epoch's validator list once we're + // within 300 blocks of the next epoch. This way, we don't prematurely attempt connections to the next + // epoch's validators. In production, this will reduce the lead time with which we connect to the next epoch's + // validator set from 1 hour to 5 minutes. + nextValidatorList, err := utxoView.GetAllSnapshotValidatorSetEntriesByStakeAtEpochNumber(snapshotEpochNumber + 1) + if err != nil { + return errors.Errorf("FastHotStuffConsensus.updateActiveValidatorConnections: Error fetching next epoch's validator list: %v", err) + } + + // Merge the current and next validator lists. Place the current epoch's validators last so that they override + // the next epoch's validators in the event of a conflict. + mergedValidatorList := append(nextValidatorList, currentValidatorList...)
+ validatorsMap := collections.NewConcurrentMap[bls.SerializedPublicKey, consensus.Validator]() + for _, validator := range mergedValidatorList { + if validator.VotingPublicKey.Eq(fc.signer.GetPublicKey()) { + continue + } + validatorsMap.Set(validator.VotingPublicKey.Serialize(), validator) + } + + // Update the active validators map in the network manager + fc.networkManager.SetActiveValidatorsMap(validatorsMap) + + return nil +} + +func (fc *FastHotStuffConsensus) trySendMessageToPeer(pp *Peer, msg DeSoMessage) { + remoteNode := fc.networkManager.GetRemoteNodeFromPeer(pp) + if remoteNode == nil { + glog.Errorf("FastHotStuffConsensus.trySendMessageToPeer: RemoteNode not found for peer: %v", pp) + return + } + + // Send the message to the peer + remoteNode.SendMessage(msg) +} + // Finds the epoch entry for the block and returns the epoch number. func getEpochEntryForBlockHeight(blockHeight uint64, epochEntries []*EpochEntry) (*EpochEntry, error) { for _, epochEntry := range epochEntries { @@ -818,6 +943,10 @@ func isProperlyFormedBlockProposalEvent(event *consensus.FastHotStuffEvent) bool return false } +func sendMessageToRemoteNodeAsync(remoteNode *RemoteNode, msg DeSoMessage) { + go func(rn *RemoteNode, m DeSoMessage) { rn.SendMessage(m) }(remoteNode, msg) +} + ////////////////////////////////////////// Logging Helper Functions /////////////////////////////////////////////// func (fc *FastHotStuffConsensus) logBlockProposal(block *MsgDeSoBlock, blockHash *BlockHash) { @@ -836,13 +965,15 @@ func (fc *FastHotStuffConsensus) logBlockProposal(block *MsgDeSoBlock, blockHash "\n Timestamp: %d, View: %d, Height: %d, BlockHash: %v"+ "\n Proposer Voting PKey: %s"+ "\n Proposer Signature: %s"+ + "\n Proposer Random Seed Signature: %s"+ "\n High QC View: %d, High QC Num Validators: %d, High QC BlockHash: %s"+ "\n Timeout Agg QC View: %d, Timeout Agg QC Num Validators: %d, Timeout High QC Views: %s"+ "\n Num Block Transactions: %d, Num Transactions Remaining In Mempool: %d"+ - "\n=================================================================================================================", + "\n=================================================================================================================\n", block.Header.GetTstampSecs(), block.Header.GetView(), block.Header.Height, blockHash.String(), block.Header.ProposerVotingPublicKey.ToString(), block.Header.ProposerVotePartialSignature.ToString(), + block.Header.ProposerRandomSeedSignature.ToString(), block.Header.GetQC().GetView(), block.Header.GetQC().GetAggregatedSignature().GetSignersList().Size(), block.Header.PrevBlockHash.String(), aggQCView, aggQCNumValidators, aggQCHighQCViews, len(block.Txns), len(fc.mempool.GetTransactions()), diff --git a/lib/pos_consensus_test.go b/lib/pos_consensus_test.go index 9bdb136a0..b5d1f3bb1 100644 --- a/lib/pos_consensus_test.go +++ b/lib/pos_consensus_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/deso-protocol/core/bls" + "github.com/deso-protocol/core/collections" "github.com/deso-protocol/core/consensus" "github.com/deso-protocol/go-deadlock" "github.com/pkg/errors" @@ -26,7 +27,8 @@ func TestFastHotStuffConsensusHandleLocalVoteEvent(t *testing.T) { // Create a mock consensus fastHotStuffConsensus := FastHotStuffConsensus{ - lock: sync.RWMutex{}, + lock: sync.RWMutex{}, + networkManager: _createMockNetworkManagerForConsensus(), blockchain: &Blockchain{ params: &DeSoTestnetParams, }, @@ -104,7 +106,8 @@ func TestFastHotStuffConsensusHandleLocalTimeoutEvent(t *testing.T) { // Create a 
mock consensus fastHotStuffConsensus := FastHotStuffConsensus{ - lock: sync.RWMutex{}, + lock: sync.RWMutex{}, + networkManager: _createMockNetworkManagerForConsensus(), signer: &BLSSigner{ privateKey: blsPrivateKey, }, @@ -200,3 +203,12 @@ func TestFastHotStuffConsensusHandleLocalTimeoutEvent(t *testing.T) { func alwaysReturnTrue() bool { return true } + +func _createMockNetworkManagerForConsensus() *NetworkManager { + return &NetworkManager{ + AllRemoteNodes: collections.NewConcurrentMap[RemoteNodeId, *RemoteNode](), + ValidatorIndex: collections.NewConcurrentMap[bls.SerializedPublicKey, *RemoteNode](), + NonValidatorOutboundIndex: collections.NewConcurrentMap[RemoteNodeId, *RemoteNode](), + NonValidatorInboundIndex: collections.NewConcurrentMap[RemoteNodeId, *RemoteNode](), + } +} diff --git a/lib/pos_fee_estimator.go b/lib/pos_fee_estimator.go index aae60d662..ad43ca405 100644 --- a/lib/pos_fee_estimator.go +++ b/lib/pos_fee_estimator.go @@ -230,16 +230,18 @@ func (posFeeEstimator *PoSFeeEstimator) sortCachedBlocks() { // and past blocks using the congestionFactorBasisPoints, priorityPercentileBasisPoints, and // maxBlockSize params. func (posFeeEstimator *PoSFeeEstimator) EstimateFeeRateNanosPerKB( - congestionFactorBasisPoints uint64, - priorityPercentileBasisPoints uint64, + mempoolCongestionFactorBasisPoints uint64, + mempoolPriorityPercentileBasisPoints uint64, + pastBlocksCongestionFactorBasisPoints uint64, + pastBlocksPriorityPercentileBasisPoints uint64, maxBlockSize uint64, ) (uint64, error) { posFeeEstimator.rwLock.RLock() defer posFeeEstimator.rwLock.RUnlock() pastBlockFeeRate, err := posFeeEstimator.estimateFeeRateNanosPerKBGivenTransactionRegister( posFeeEstimator.pastBlocksTransactionRegister, - congestionFactorBasisPoints, - priorityPercentileBasisPoints, + pastBlocksCongestionFactorBasisPoints, + pastBlocksPriorityPercentileBasisPoints, posFeeEstimator.numPastBlocks, maxBlockSize, ) @@ -248,8 +250,8 @@ func (posFeeEstimator *PoSFeeEstimator) EstimateFeeRateNanosPerKB( } mempoolFeeRate, err := posFeeEstimator.estimateFeeRateNanosPerKBGivenTransactionRegister( posFeeEstimator.mempoolTransactionRegister, - congestionFactorBasisPoints, - priorityPercentileBasisPoints, + mempoolCongestionFactorBasisPoints, + mempoolPriorityPercentileBasisPoints, posFeeEstimator.numMempoolBlocks, maxBlockSize, ) diff --git a/lib/pos_fee_estimator_test.go b/lib/pos_fee_estimator_test.go index 08e1a18ee..e672e2460 100644 --- a/lib/pos_fee_estimator_test.go +++ b/lib/pos_fee_estimator_test.go @@ -26,7 +26,7 @@ func TestFeeEstimator(t *testing.T) { mempool := NewPosMempool() err = mempool.Init( params, globalParams, latestBlockView, 2, dir, false, maxMempoolPosSizeBytes, mempoolBackupIntervalMillis, 1, - nil, 1, 100, + nil, 1, 10000, 100, 100, ) require.NoError(t, err) require.NoError(t, mempool.Start()) diff --git a/lib/pos_mempool.go b/lib/pos_mempool.go index 864b78ed2..440b30d1d 100644 --- a/lib/pos_mempool.go +++ b/lib/pos_mempool.go @@ -25,12 +25,11 @@ type Mempool interface { Start() error Stop() IsRunning() bool - AddTransaction(txn *MempoolTransaction, verifySignature bool) error + AddTransaction(txn *MempoolTransaction) error RemoveTransaction(txnHash *BlockHash) error GetTransaction(txnHash *BlockHash) *MempoolTransaction GetTransactions() []*MempoolTransaction GetIterator() MempoolIterator - Refresh() error UpdateLatestBlock(blockView *UtxoView, blockHeight uint64) UpdateGlobalParams(globalParams *GlobalParamsEntry) @@ -51,6 +50,14 @@ type Mempool interface { 
pastBlocksPriorityPercentileBasisPoints uint64, maxBlockSize uint64, ) (uint64, error) + EstimateFeeRate( + minFeeRateNanosPerKB uint64, + mempoolCongestionFactorBasisPoints uint64, + mempoolPriorityPercentileBasisPoints uint64, + pastBlocksCongestionFactorBasisPoints uint64, + pastBlocksPriorityPercentileBasisPoints uint64, + maxBlockSize uint64, + ) (uint64, error) } type MempoolIterator interface { @@ -63,12 +70,14 @@ type MempoolIterator interface { type MempoolTransaction struct { *MsgDeSoTxn TimestampUnixMicro time.Time + Validated bool } -func NewMempoolTransaction(txn *MsgDeSoTxn, timestamp time.Time) *MempoolTransaction { +func NewMempoolTransaction(txn *MsgDeSoTxn, timestamp time.Time, validated bool) *MempoolTransaction { return &MempoolTransaction{ MsgDeSoTxn: txn, TimestampUnixMicro: timestamp, + Validated: validated, } } @@ -80,11 +89,18 @@ func (mtxn *MempoolTransaction) GetTimestamp() time.Time { return mtxn.TimestampUnixMicro } +func (mtxn *MempoolTransaction) IsValidated() bool { + return mtxn.Validated +} + // PosMempool is used by the node to keep track of uncommitted transactions. The main responsibilities of the PosMempool // include addition/removal of transactions, back up of transaction to database, and retrieval of transactions ordered // by Fee-Time algorithm. More on the Fee-Time algorithm can be found in the documentation of TransactionRegister. type PosMempool struct { sync.RWMutex + startGroup sync.WaitGroup + exitGroup sync.WaitGroup + status PosMempoolStatus // params of the blockchain params *DeSoParams @@ -107,9 +123,6 @@ type PosMempool struct { // The persister runs on its dedicated thread and events are used to notify the persister thread whenever // transactions are added/removed from the mempool. The persister thread then updates the database accordingly. persister *MempoolPersister - // ledger is a simple data structure that keeps track of cumulative transaction fees in the mempool. - // The ledger keeps track of how much each user would have spent in fees across all their transactions in the mempool. - ledger *BalanceLedger // nonceTracker is responsible for keeping track of a (public key, nonce) -> Txn index. The index is useful in // facilitating a "replace by higher fee" feature. This feature gives users the ability to replace their existing // mempool transaction with a new transaction having the same nonce but higher fee. @@ -141,6 +154,11 @@ type PosMempool struct { // based off the current state of the mempool and the most n recent blocks. feeEstimator *PoSFeeEstimator + maxValidationViewConnects uint64 + + // transactionValidationRoutineRefreshIntervalMillis is the frequency with which the transactionValidationRoutine is run. + transactionValidationRefreshIntervalMillis uint64 + // augmentedBlockViewRefreshIntervalMillis is the frequency with which the augmentedLatestBlockView is updated. 
augmentedBlockViewRefreshIntervalMillis uint64 @@ -164,7 +182,7 @@ func (it *PosMempoolIterator) Value() (*MempoolTransaction, bool) { if txn == nil || txn.Tx == nil { return nil, ok } - return NewMempoolTransaction(txn.Tx, txn.Added), ok + return NewMempoolTransaction(txn.Tx, txn.Added, txn.IsValidated()), ok } func (it *PosMempoolIterator) Initialized() bool { @@ -180,7 +198,6 @@ func NewPosMempool() *PosMempool { status: PosMempoolStatusNotInitialized, txnRegister: NewTransactionRegister(), feeEstimator: NewPoSFeeEstimator(), - ledger: NewBalanceLedger(), nonceTracker: NewNonceTracker(), quit: make(chan interface{}), } @@ -198,6 +215,8 @@ func (mp *PosMempool) Init( feeEstimatorNumMempoolBlocks uint64, feeEstimatorPastBlocks []*MsgDeSoBlock, feeEstimatorNumPastBlocks uint64, + maxValidationViewConnects uint64, + transactionValidationRefreshIntervalMillis uint64, augmentedBlockViewRefreshIntervalMillis uint64, ) error { if mp.status != PosMempoolStatusNotInitialized { @@ -220,6 +239,8 @@ func (mp *PosMempool) Init( mp.inMemoryOnly = inMemoryOnly mp.maxMempoolPosSizeBytes = maxMempoolPosSizeBytes mp.mempoolBackupIntervalMillis = mempoolBackupIntervalMillis + mp.maxValidationViewConnects = maxValidationViewConnects + mp.transactionValidationRefreshIntervalMillis = transactionValidationRefreshIntervalMillis mp.augmentedBlockViewRefreshIntervalMillis = augmentedBlockViewRefreshIntervalMillis // TODO: parameterize num blocks. Also, how to pass in blocks. @@ -248,7 +269,6 @@ func (mp *PosMempool) Start() error { // Create the transaction register, the ledger, and the nonce tracker, mp.txnRegister = NewTransactionRegister() mp.txnRegister.Init(mp.globalParams) - mp.ledger = NewBalanceLedger() mp.nonceTracker = NewNonceTracker() // Setup the database and create the persister @@ -269,14 +289,35 @@ func (mp *PosMempool) Start() error { return errors.Wrapf(err, "PosMempool.Start: Problem loading persisted transactions") } } + mp.startGroup.Add(2) + mp.exitGroup.Add(2) + mp.startTransactionValidationRoutine() mp.startAugmentedViewRefreshRoutine() - + mp.startGroup.Wait() mp.status = PosMempoolStatusRunning return nil } +func (mp *PosMempool) startTransactionValidationRoutine() { + go func() { + mp.startGroup.Done() + for { + select { + case <-time.After(time.Duration(mp.transactionValidationRefreshIntervalMillis) * time.Millisecond): + if err := mp.validateTransactions(); err != nil { + glog.Errorf("PosMempool.startTransactionValidationRoutine: Problem validating transactions: %v", err) + } + case <-mp.quit: + mp.exitGroup.Done() + return + } + } + }() +} + func (mp *PosMempool) startAugmentedViewRefreshRoutine() { go func() { + mp.startGroup.Done() for { select { case <-time.After(time.Duration(mp.augmentedBlockViewRefreshIntervalMillis) * time.Millisecond): @@ -307,26 +348,6 @@ func (mp *PosMempool) startAugmentedViewRefreshRoutine() { // and proceed to the next transaction. if err == nil { newView = copiedView - continue - } - // If the transaction failed to connect, we connect the transaction as a failed txn - // directly on newView. - if mp.params.IsPoSBlockHeight(mp.latestBlockHeight + 1) { - // Copy the view again in case we hit an error. 
- copiedView, err = newView.CopyUtxoView() - if err != nil { - glog.Errorf("PosMempool.startAugmentedViewRefreshRoutine: Problem copying utxo view inner: %v", err) - continue - } - // Try to connect as failing txn directly to newView - _, _, _, err = copiedView._connectFailingTransaction( - txn.GetTxn(), uint32(mp.latestBlockHeight+1), false) - if err != nil { - glog.Errorf( - "PosMempool.startAugmentedViewRefreshRoutine: Problem connecting transaction: %v", err) - continue - } - newView = copiedView } } // Grab the augmentedLatestBlockViewMutex write lock and update the augmentedLatestBlockView. @@ -336,6 +357,7 @@ func (mp *PosMempool) startAugmentedViewRefreshRoutine() { // Increment the augmentedLatestBlockViewSequenceNumber. atomic.AddInt64(&mp.augmentedLatestBlockViewSequenceNumber, 1) case <-mp.quit: + mp.exitGroup.Done() return } } @@ -362,10 +384,10 @@ func (mp *PosMempool) Stop() { // Reset the transaction register, the ledger, and the nonce tracker. mp.txnRegister.Reset() - mp.ledger.Reset() mp.nonceTracker.Reset() mp.feeEstimator = NewPoSFeeEstimator() close(mp.quit) + mp.exitGroup.Wait() mp.status = PosMempoolStatusNotInitialized } @@ -464,7 +486,7 @@ func (mp *PosMempool) OnBlockDisconnected(block *MsgDeSoBlock) { // AddTransaction validates a MsgDeSoTxn transaction and adds it to the mempool if it is valid. // If the mempool overflows as a result of adding the transaction, the mempool is pruned. The // transaction signature verification can be skipped if verifySignature is passed as true. -func (mp *PosMempool) AddTransaction(mtxn *MempoolTransaction, verifySignature bool) error { +func (mp *PosMempool) AddTransaction(mtxn *MempoolTransaction) error { if mtxn == nil || mtxn.GetTxn() == nil { return fmt.Errorf("PosMempool.AddTransaction: Cannot add a nil transaction") } @@ -472,7 +494,7 @@ func (mp *PosMempool) AddTransaction(mtxn *MempoolTransaction, verifySignature b // First, validate that the transaction is properly formatted according to BalanceModel. We acquire a read lock on // the mempool. This allows multiple goroutines to safely perform transaction validation concurrently. In particular, // transaction signature verification can be parallelized. - if err := mp.validateTransaction(mtxn.GetTxn(), verifySignature); err != nil { + if err := mp.checkTransactionSanity(mtxn.GetTxn()); err != nil { return errors.Wrapf(err, "PosMempool.AddTransaction: Problem verifying transaction") } @@ -503,7 +525,7 @@ func (mp *PosMempool) AddTransaction(mtxn *MempoolTransaction, verifySignature b return nil } -func (mp *PosMempool) validateTransaction(txn *MsgDeSoTxn, verifySignature bool) error { +func (mp *PosMempool) checkTransactionSanity(txn *MsgDeSoTxn) error { mp.RLock() defer mp.RUnlock() @@ -519,32 +541,27 @@ func (mp *PosMempool) validateTransaction(txn *MsgDeSoTxn, verifySignature bool) return errors.Wrapf(err, "PosMempool.AddTransaction: Problem validating transaction nonce") } - if !verifySignature { - return nil + return nil +} + +func (mp *PosMempool) updateTransactionValidatedStatus(txnHash *BlockHash, validated bool) { + mp.Lock() + defer mp.Unlock() + + if !mp.IsRunning() || txnHash == nil { + return } - // Check transaction signature. 
- if _, err := mp.readOnlyLatestBlockView.VerifySignature(txn, uint32(mp.latestBlockHeight)); err != nil { - return errors.Wrapf(err, "PosMempool.AddTransaction: Signature validation failed") + txn := mp.txnRegister.GetTransaction(txnHash) + if txn == nil { + return } - return nil + txn.SetValidated(validated) } func (mp *PosMempool) addTransactionNoLock(txn *MempoolTx, persistToDb bool) error { userPk := NewPublicKey(txn.Tx.PublicKey) - txnFee := txn.Tx.TxnFeeNanos - - // Validate that the user has enough balance to cover the transaction fees. - spendableBalanceNanos, err := mp.readOnlyLatestBlockView.GetSpendableDeSoBalanceNanosForPublicKey(userPk.ToBytes(), - uint32(mp.latestBlockHeight)) - if err != nil { - return errors.Wrapf(err, "PosMempool.addTransactionNoLock: Problem getting spendable balance") - } - if err := mp.ledger.CanIncreaseEntryWithLimit(*userPk, txnFee, spendableBalanceNanos); err != nil { - return errors.Wrapf(err, "PosMempool.addTransactionNoLock: Problem checking balance increase for transaction with"+ - "hash %v, fee %v", txn.Tx.Hash(), txnFee) - } // Check the nonceTracker to see if this transaction is meant to replace an existing one. existingTxn := mp.nonceTracker.GetTxnByPublicKeyNonce(*userPk, *txn.Tx.TxnNonce) @@ -553,9 +570,7 @@ func (mp *PosMempool) addTransactionNoLock(txn *MempoolTx, persistToDb bool) err "by higher fee failed. New transaction has lower fee.") } - // If we get here, it means that the transaction's sender has enough balance to cover transaction fees. Moreover, if - // this transaction is meant to replace an existing one, at this point we know the new txn has a sufficient fee to - // do so. We can now add the transaction to mempool. + // We can now add the transaction to the mempool. if err := mp.txnRegister.AddTransaction(txn); err != nil { return errors.Wrapf(err, "PosMempool.addTransactionNoLock: Problem adding txn to register") } @@ -569,8 +584,7 @@ func (mp *PosMempool) addTransactionNoLock(txn *MempoolTx, persistToDb bool) err } } - // At this point the transaction is in the mempool. We can now update the ledger and nonce tracker. - mp.ledger.IncreaseEntry(*userPk, txnFee) + // At this point the transaction is in the mempool. mp.nonceTracker.AddTxnByPublicKeyNonce(txn, *userPk, *txn.Tx.TxnNonce) // Emit an event for the newly added transaction. @@ -633,8 +647,7 @@ func (mp *PosMempool) removeTransactionNoLock(txn *MempoolTx, persistToDb bool) return errors.Wrapf(err, "PosMempool.removeTransactionNoLock: Problem removing txn from register") } - // Remove the txn from the balance ledger and the nonce tracker. - mp.ledger.DecreaseEntry(*userPk, txn.Fee) + // Remove the txn from the nonce tracker. mp.nonceTracker.RemoveTxnByPublicKeyNonce(*userPk, *txn.Tx.TxnNonce) // Emit an event for the removed transaction. @@ -663,7 +676,7 @@ func (mp *PosMempool) GetTransaction(txnHash *BlockHash) *MempoolTransaction { return nil } - return NewMempoolTransaction(txn.Tx, txn.Added) + return NewMempoolTransaction(txn.Tx, txn.Added, txn.IsValidated()) } // GetTransactions returns all transactions in the mempool ordered by the Fee-Time algorithm. This function is thread-safe. 
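The AddTransaction and MempoolTransaction changes above move signature and balance checks out of the synchronous add path and into the background validation routine, which connects transactions against a copy of the latest block view. A minimal usage sketch follows, assuming package lib context with the standard "time" import; the helper name is hypothetical and the snippet is illustrative only, not code from this diff.

// Hypothetical sketch: adding a transaction under the new API and checking its validation status later.
func exampleAddAndCheckTransaction(mp *PosMempool, txn *MsgDeSoTxn) error {
	// There is no verifySignature flag anymore; new transactions enter the pool unvalidated.
	mtxn := NewMempoolTransaction(txn, time.Now(), false)
	if err := mp.AddTransaction(mtxn); err != nil {
		return err
	}
	// The background validation routine later marks the transaction as validated,
	// or removes it if it fails to connect to the validation view.
	if fetched := mp.GetTransaction(txn.Hash()); fetched != nil && fetched.IsValidated() {
		// Safe to treat as validated, e.g. for relay or block production.
	}
	return nil
}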
@@ -682,7 +695,7 @@ func (mp *PosMempool) GetTransactions() []*MempoolTransaction { continue } - mtxn := NewMempoolTransaction(txn.Tx, txn.Added) + mtxn := NewMempoolTransaction(txn.Tx, txn.Added, txn.IsValidated()) mempoolTxns = append(mempoolTxns, mtxn) } return mempoolTxns @@ -713,61 +726,45 @@ func (mp *PosMempool) GetIterator() MempoolIterator { return NewPosMempoolIterator(mp.txnRegister.GetFeeTimeIterator()) } -// Refresh can be used to evict stale transactions from the mempool. However, it is a bit expensive and should be used -// sparingly. Upon being called, Refresh will create an in-memory temp PosMempool and populate it with transactions from -// the main mempool. The temp mempool will have the most up-to-date readOnlyLatestBlockView, Height, and globalParams. Any -// transaction that fails to add to the temp mempool will be removed from the main mempool. -func (mp *PosMempool) Refresh() error { - mp.Lock() - defer mp.Unlock() - +// validateTransactions connects the mempool's transactions to a copy of the latest block view to determine their +// validation status. Transactions that connect successfully are marked as validated; any transaction that fails to +// connect is removed from the mempool. +func (mp *PosMempool) validateTransactions() error { + mp.RLock() if !mp.IsRunning() { return nil } - if err := mp.refreshNoLock(); err != nil { - return errors.Wrapf(err, "PosMempool.Refresh: Problem refreshing mempool") - } - return nil -} + validationView := mp.readOnlyLatestBlockView + mempoolTxns := mp.getTransactionsNoLock() + mp.RUnlock() -func (mp *PosMempool) refreshNoLock() error { - // Create the temporary in-memory mempool with the most up-to-date readOnlyLatestBlockView, Height, and globalParams. - tempPool := NewPosMempool() - err := tempPool.Init( - mp.params, - mp.globalParams, - mp.readOnlyLatestBlockView, - mp.latestBlockHeight, - "", - true, - mp.maxMempoolPosSizeBytes, - mp.mempoolBackupIntervalMillis, - mp.feeEstimator.numMempoolBlocks, - mp.feeEstimator.cachedBlocks, - mp.feeEstimator.numPastBlocks, - mp.augmentedBlockViewRefreshIntervalMillis, - ) + var txns []*MsgDeSoTxn + var txHashes []*BlockHash + for _, txn := range mempoolTxns { + txns = append(txns, txn.Tx) + txHashes = append(txHashes, txn.Hash) + } + copyValidationView, err := validationView.CopyUtxoView() if err != nil { - return errors.Wrapf(err, "PosMempool.refreshNoLock: Problem initializing temp pool") + return errors.Wrapf(err, "PosMempool.validateTransactions: Problem copying utxo view") } - if err := tempPool.Start(); err != nil { - return errors.Wrapf(err, "PosMempool.refreshNoLock: Problem starting temp pool") + _, _, _, _, successFlags, err := copyValidationView.ConnectTransactionsWithLimit(txns, txHashes, uint32(mp.latestBlockHeight)+1, time.Now().UnixNano(), + false, false, true, mp.maxValidationViewConnects) + if err != nil { + return errors.Wrapf(err, "PosMempool.validateTransactions: Problem connecting transactions") } - defer tempPool.Stop() - // Add all transactions from the main mempool to the temp mempool. Skip signature verification. var txnsToRemove []*MempoolTx - txns := mp.getTransactionsNoLock() - for _, txn := range txns { - mtxn := NewMempoolTransaction(txn.Tx, txn.Added) - err := tempPool.AddTransaction(mtxn, false) - if err == nil { - continue + for ii, successFlag := range successFlags { + if ii >= len(mempoolTxns) { + break + } + if successFlag { + // Note: updateTransactionValidatedStatus acquires the mempool lock internally. + mp.updateTransactionValidatedStatus(mempoolTxns[ii].Hash, true) + } else { + txnsToRemove = append(txnsToRemove, mempoolTxns[ii]) } - - // If we've encountered an error while adding the transaction to the temp mempool, we add it to our txnsToRemove list.
- txnsToRemove = append(txnsToRemove, txn) } // Now remove all transactions from the txnsToRemove list from the main mempool. @@ -836,9 +833,6 @@ func (mp *PosMempool) UpdateGlobalParams(globalParams *GlobalParamsEntry) { } mp.globalParams = globalParams - if err := mp.refreshNoLock(); err != nil { - glog.Errorf("PosMempool.UpdateGlobalParams: Problem refreshing mempool: %v", err) - } } // Implementation of the Mempool interface @@ -911,8 +905,19 @@ func (mp *PosMempool) EstimateFee(txn *MsgDeSoTxn, pastBlocksCongestionFactorBasisPoints uint64, pastBlocksPriorityPercentileBasisPoints uint64, maxBlockSize uint64) (uint64, error) { - // TODO: replace MaxBasisPoints with variables configured by flags. return mp.feeEstimator.EstimateFee( txn, mempoolCongestionFactorBasisPoints, mempoolPriorityPercentileBasisPoints, pastBlocksCongestionFactorBasisPoints, pastBlocksPriorityPercentileBasisPoints, maxBlockSize) } + +func (mp *PosMempool) EstimateFeeRate( + _ uint64, + mempoolCongestionFactorBasisPoints uint64, + mempoolPriorityPercentileBasisPoints uint64, + pastBlocksCongestionFactorBasisPoints uint64, + pastBlocksPriorityPercentileBasisPoints uint64, + maxBlockSize uint64) (uint64, error) { + return mp.feeEstimator.EstimateFeeRateNanosPerKB( + mempoolCongestionFactorBasisPoints, mempoolPriorityPercentileBasisPoints, + pastBlocksCongestionFactorBasisPoints, pastBlocksPriorityPercentileBasisPoints, maxBlockSize) +} diff --git a/lib/pos_mempool_ledger.go b/lib/pos_mempool_ledger.go deleted file mode 100644 index e44096cd4..000000000 --- a/lib/pos_mempool_ledger.go +++ /dev/null @@ -1,93 +0,0 @@ -package lib - -import ( - "github.com/pkg/errors" - "math" - "sync" -) - -// BalanceLedger is a simple in-memory ledger of balances for user public keys. The balances in the ledger can be -// increased or decreased, as long as user's new balance doesn't exceed the user's total max balance. -type BalanceLedger struct { - sync.RWMutex - - // Map of public keys to balances. - balances map[PublicKey]uint64 -} - -func NewBalanceLedger() *BalanceLedger { - return &BalanceLedger{ - balances: make(map[PublicKey]uint64), - } -} - -// CanIncreaseEntryWithLimit checks if the user's ledger entry can be increased by delta. If the user's -// balance + delta is less or equal than the balanceLimit, the increase is allowed. Otherwise, an error is returned. -func (bl *BalanceLedger) CanIncreaseEntryWithLimit(publicKey PublicKey, delta uint64, balanceLimit uint64) error { - bl.RLock() - defer bl.RUnlock() - - balance, exists := bl.balances[publicKey] - - // Check for balance overflow. - if exists && delta > math.MaxUint64-balance { - return errors.Errorf("CanIncreaseEntryWithLimit: balance overflow") - } - - newBalance := balance + delta - if newBalance > balanceLimit { - return errors.Errorf("CanIncreaseEntryWithLimit: Balance + delta exceeds balance limit "+ - "(balance: %d, delta %v, balanceLimit: %d)", balance, delta, balanceLimit) - } - return nil -} - -// IncreaseEntry increases the user's ledger entry by delta. CanIncreaseEntryWithLimit should be called before -// calling this function to ensure the increase is allowed. -func (bl *BalanceLedger) IncreaseEntry(publicKey PublicKey, delta uint64) { - bl.Lock() - defer bl.Unlock() - - balance, _ := bl.balances[publicKey] - // Check for balance overflow. - if delta > math.MaxUint64-balance { - bl.balances[publicKey] = math.MaxUint64 - return - } - - bl.balances[publicKey] = balance + delta -} - -// DecreaseEntry decreases the user's ledger entry by delta. 
-func (bl *BalanceLedger) DecreaseEntry(publicKey PublicKey, delta uint64) { - bl.Lock() - defer bl.Unlock() - - balance, exists := bl.balances[publicKey] - if !exists { - return - } - // Check for balance underflow. - if delta > balance { - delete(bl.balances, publicKey) - return - } - - bl.balances[publicKey] = balance - delta -} - -// GetEntry returns the user's ledger entry. -func (bl *BalanceLedger) GetEntry(publicKey PublicKey) uint64 { - bl.RLock() - defer bl.RUnlock() - - balance, _ := bl.balances[publicKey] - return balance -} - -func (bl *BalanceLedger) Reset() { - bl.Lock() - defer bl.Unlock() - - bl.balances = make(map[PublicKey]uint64) -} diff --git a/lib/pos_mempool_ledger_test.go b/lib/pos_mempool_ledger_test.go deleted file mode 100644 index b909ce44b..000000000 --- a/lib/pos_mempool_ledger_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package lib - -import ( - "github.com/stretchr/testify/require" - "math" - "testing" -) - -func TestBalanceLedger(t *testing.T) { - require := require.New(t) - - pk0 := *NewPublicKey(m0PkBytes) - pk1 := *NewPublicKey(m1PkBytes) - pk2 := *NewPublicKey(m2PkBytes) - - // Sanity-check some balance increase and decreases for pk0 - balanceLedger := NewBalanceLedger() - require.NoError(balanceLedger.CanIncreaseEntryWithLimit(pk0, 100, 100)) - require.NoError(balanceLedger.CanIncreaseEntryWithLimit(pk0, 0, 100)) - balanceLedger.IncreaseEntry(pk0, 100) - require.Equal(uint64(100), balanceLedger.GetEntry(pk0)) - require.NoError(balanceLedger.CanIncreaseEntryWithLimit(pk0, 0, 100)) - require.Error(balanceLedger.CanIncreaseEntryWithLimit(pk0, 1, 100)) - require.Error(balanceLedger.CanIncreaseEntryWithLimit(pk0, 0, 99)) - require.Error(balanceLedger.CanIncreaseEntryWithLimit(pk0, math.MaxUint64, math.MaxUint64)) - balanceLedger.DecreaseEntry(pk0, 100) - require.Equal(uint64(0), balanceLedger.GetEntry(pk0)) - balanceLedger.IncreaseEntry(pk0, 10) - require.Equal(uint64(10), balanceLedger.GetEntry(pk0)) - balanceLedger.DecreaseEntry(pk0, 100) - require.Equal(uint64(0), balanceLedger.GetEntry(pk0)) - balanceLedger.IncreaseEntry(pk0, 100) - - // Increase balance for pk1 and pk2 a couple of times - balanceLedger.IncreaseEntry(pk1, 100) - balanceLedger.IncreaseEntry(pk2, 100) - balanceLedger.DecreaseEntry(pk1, 40) - balanceLedger.IncreaseEntry(pk2, 40) - require.Equal(uint64(100), balanceLedger.GetEntry(pk0)) - require.Equal(uint64(60), balanceLedger.GetEntry(pk1)) - require.Equal(uint64(140), balanceLedger.GetEntry(pk2)) - - // Test clearing balance ledger - balanceLedger.Reset() - require.Equal(uint64(0), balanceLedger.GetEntry(pk0)) - require.Equal(uint64(0), balanceLedger.GetEntry(pk1)) - require.Equal(uint64(0), balanceLedger.GetEntry(pk2)) -} diff --git a/lib/pos_mempool_test.go b/lib/pos_mempool_test.go index 9afd41694..c9fa1abbc 100644 --- a/lib/pos_mempool_test.go +++ b/lib/pos_mempool_test.go @@ -23,11 +23,12 @@ func TestPosMempoolStart(t *testing.T) { mempool := NewPosMempool() require.NoError(mempool.Init( - ¶ms, globalParams, nil, 0, dir, false, maxMempoolPosSizeBytes, mempoolBackupIntervalMillis, 1, nil, 1, 100, + ¶ms, globalParams, nil, 0, dir, false, maxMempoolPosSizeBytes, + mempoolBackupIntervalMillis, 1, nil, 1, 1000, 100, 100, )) require.NoError(mempool.Start()) require.True(mempool.IsRunning()) - require.NoError(mempool.Refresh()) + require.NoError(mempool.validateTransactions()) mempool.Stop() require.False(mempool.IsRunning()) } @@ -53,7 +54,7 @@ func TestPosMempoolRestartWithTransactions(t *testing.T) { mempool := NewPosMempool() 
require.NoError(mempool.Init( params, globalParams, latestBlockView, 2, dir, false, maxMempoolPosSizeBytes, mempoolBackupIntervalMillis, 1, - nil, 1, 100, + nil, 1, 1000, 100, 100, )) require.NoError(mempool.Start()) require.True(mempool.IsRunning()) @@ -65,20 +66,20 @@ func TestPosMempoolRestartWithTransactions(t *testing.T) { poolTxns := mempool.GetTransactions() require.Equal(2, len(poolTxns)) - require.NoError(mempool.Refresh()) + require.NoError(mempool.validateTransactions()) require.Equal(2, len(mempool.GetTransactions())) mempool.Stop() require.False(mempool.IsRunning()) newPool := NewPosMempool() require.NoError(newPool.Init(params, globalParams, latestBlockView, 2, dir, false, maxMempoolPosSizeBytes, - mempoolBackupIntervalMillis, 1, nil, 1, 100)) + mempoolBackupIntervalMillis, 1, nil, 1, 1000, 100, 100)) require.NoError(newPool.Start()) require.True(newPool.IsRunning()) newPoolTxns := newPool.GetTransactions() require.Equal(2, len(newPoolTxns)) require.Equal(len(newPool.GetTransactions()), len(newPool.nonceTracker.nonceMap)) - require.NoError(newPool.Refresh()) + require.NoError(newPool.validateTransactions()) require.Equal(2, len(newPool.GetTransactions())) _wrappedPosMempoolRemoveTransaction(t, newPool, txn1.Hash()) _wrappedPosMempoolRemoveTransaction(t, newPool, txn2.Hash()) @@ -108,7 +109,7 @@ func TestPosMempoolPrune(t *testing.T) { mempool := NewPosMempool() require.NoError(mempool.Init( params, globalParams, latestBlockView, 2, dir, false, maxMempoolPosSizeBytes, mempoolBackupIntervalMillis, 1, - nil, 1, 100, + nil, 1, 1000, 100, 100, )) require.NoError(mempool.Start()) require.True(mempool.IsRunning()) @@ -137,7 +138,7 @@ func TestPosMempoolPrune(t *testing.T) { // Remove one transaction. _wrappedPosMempoolRemoveTransaction(t, mempool, fetchedTxns[0].Hash()) - require.NoError(mempool.Refresh()) + require.NoError(mempool.validateTransactions()) require.Equal(2, len(mempool.GetTransactions())) mempool.Stop() require.False(mempool.IsRunning()) @@ -145,7 +146,7 @@ func TestPosMempoolPrune(t *testing.T) { newPool := NewPosMempool() require.NoError(newPool.Init( params, globalParams, latestBlockView, 2, dir, false, maxMempoolPosSizeBytes, mempoolBackupIntervalMillis, 1, - nil, 1, 100, + nil, 1, 1000, 100, 100, )) require.NoError(newPool.Start()) require.True(newPool.IsRunning()) @@ -174,7 +175,7 @@ func TestPosMempoolPrune(t *testing.T) { index++ } require.Equal(len(newPool.GetTransactions()), len(newPool.nonceTracker.nonceMap)) - require.NoError(newPool.Refresh()) + require.NoError(newPool.validateTransactions()) newTxns := newPool.GetTransactions() require.Equal(3, len(newTxns)) for _, txn := range newTxns { @@ -206,7 +207,7 @@ func TestPosMempoolUpdateGlobalParams(t *testing.T) { mempool := NewPosMempool() require.NoError(mempool.Init( params, globalParams, latestBlockView, 2, dir, false, maxMempoolPosSizeBytes, mempoolBackupIntervalMillis, 1, - nil, 1, 100, + nil, 1, 1000, 100, 100, )) require.NoError(mempool.Start()) require.True(mempool.IsRunning()) @@ -235,7 +236,7 @@ func TestPosMempoolUpdateGlobalParams(t *testing.T) { newPool := NewPosMempool() require.NoError(newPool.Init( params, newGlobalParams, latestBlockView, 2, dir, false, maxMempoolPosSizeBytes, mempoolBackupIntervalMillis, 1, - nil, 1, 100, + nil, 1, 1000, 100, 100, )) require.NoError(newPool.Start()) require.True(newPool.IsRunning()) @@ -268,7 +269,7 @@ func TestPosMempoolReplaceWithHigherFee(t *testing.T) { mempool := NewPosMempool() require.NoError(mempool.Init( params, globalParams, latestBlockView, 2, 
dir, false, maxMempoolPosSizeBytes, mempoolBackupIntervalMillis, 1, - nil, 1, 100, + nil, 1, 1000, 100, 100, )) require.NoError(mempool.Start()) require.True(mempool.IsRunning()) @@ -300,8 +301,8 @@ func TestPosMempoolReplaceWithHigherFee(t *testing.T) { *txn2Low.TxnNonce = *txn2.TxnNonce _signTxn(t, txn2Low, m1Priv) added2Low := time.Now() - mtxn2Low := NewMempoolTransaction(txn2Low, added2Low) - err = mempool.AddTransaction(mtxn2Low, true) + mtxn2Low := NewMempoolTransaction(txn2Low, added2Low, false) + err = mempool.AddTransaction(mtxn2Low) require.Contains(err.Error(), MempoolFailedReplaceByHigherFee) // Now generate a proper new transaction for m1, with same nonce, and higher fee. @@ -321,7 +322,7 @@ func TestPosMempoolReplaceWithHigherFee(t *testing.T) { require.Equal(txn1New, mempool.GetTransactions()[1].GetTxn()) require.Equal(len(mempool.GetTransactions()), len(mempool.nonceTracker.nonceMap)) - require.NoError(mempool.Refresh()) + require.NoError(mempool.validateTransactions()) require.Equal(2, len(mempool.GetTransactions())) mempool.Stop() require.False(mempool.IsRunning()) @@ -392,8 +393,8 @@ func _generateTestTxn(t *testing.T, rand *rand.Rand, feeMin uint64, feeMax uint6 func _wrappedPosMempoolAddTransaction(t *testing.T, mp *PosMempool, txn *MsgDeSoTxn) { added := time.Now() - mtxn := NewMempoolTransaction(txn, added) - require.NoError(t, mp.AddTransaction(mtxn, true)) + mtxn := NewMempoolTransaction(txn, added, false) + require.NoError(t, mp.AddTransaction(mtxn)) require.Equal(t, true, _checkPosMempoolIntegrity(t, mp)) } @@ -426,26 +427,5 @@ func _checkPosMempoolIntegrity(t *testing.T, mp *PosMempool) bool { } balances[*pk] += txn.TxnFeeNanos } - - if len(balances) > len(mp.ledger.balances) { - t.Errorf("PosMempool ledger is out of sync length balances (%v) > ledger (%v)", len(balances), len(mp.ledger.balances)) - return false - } - activeBalances := 0 - for pk, ledgerBalance := range mp.ledger.balances { - if ledgerBalance > 0 { - activeBalances++ - } else { - continue - } - if balance, exists := balances[pk]; !exists || ledgerBalance != balance { - t.Errorf("PosMempool ledger is out of sync pk %v", PkToStringTestnet(pk.ToBytes())) - return false - } - } - if len(balances) != activeBalances { - t.Errorf("PosMempool ledger is out of sync length") - return false - } return true } diff --git a/lib/pos_mempool_transaction.go b/lib/pos_mempool_transaction.go index ceef43aaa..e88a87593 100644 --- a/lib/pos_mempool_transaction.go +++ b/lib/pos_mempool_transaction.go @@ -26,6 +26,9 @@ type MempoolTx struct { // The time when the txn was added to the pool Added time.Time + // Whether this transaction has been validated by the mempool + Validated bool + // The block height when the txn was added to the pool. It's generally set // to tip+1. 
Height uint32 @@ -125,3 +128,11 @@ func (mempoolTx *MempoolTx) FromBytes(rr *bytes.Reader) error { *mempoolTx = *newTxn return nil } + +func (mempoolTx *MempoolTx) SetValidated(validated bool) { + mempoolTx.Validated = validated +} + +func (mempoolTx *MempoolTx) IsValidated() bool { + return mempoolTx.Validated +} diff --git a/lib/pos_network.go b/lib/pos_network.go index 891315937..08aae1bbc 100644 --- a/lib/pos_network.go +++ b/lib/pos_network.go @@ -123,6 +123,17 @@ func (msg *MsgDeSoValidatorVote) FromBytes(data []byte) error { return nil } +func (msg *MsgDeSoValidatorVote) ToString() string { + return fmt.Sprintf( + "{MsgVersion: %d, VotingPublicKey: %s, BlockHash: %v, ProposedInView: %d, VotePartialSignature: %v}", + msg.MsgVersion, + msg.VotingPublicKey.ToAbbreviatedString(), + msg.BlockHash, + msg.ProposedInView, + msg.VotePartialSignature.ToAbbreviatedString(), + ) +} + // ================================================================== // Proof of Stake Timeout Message // ================================================================== @@ -240,6 +251,18 @@ func (msg *MsgDeSoValidatorTimeout) FromBytes(data []byte) error { return nil } +func (msg *MsgDeSoValidatorTimeout) ToString() string { + return fmt.Sprintf( + "{MsgVersion: %d, VotingPublicKey: %s, TimedOutView: %d, HighQCView: %v, HighQCBlockHash: %v, TimeoutPartialSignature: %s}", + msg.MsgVersion, + msg.VotingPublicKey.ToAbbreviatedString(), + msg.TimedOutView, + msg.HighQC.ProposedInView, + msg.HighQC.BlockHash, + msg.TimeoutPartialSignature.ToAbbreviatedString(), + ) +} + // A QuorumCertificate contains an aggregated signature from 2/3rds of the validators // on the network, weighted by stake. The signatures are associated with a block hash // and a view, both of which are identified in the certificate. 
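The Validated flag added to MempoolTx above is what the relay path keys off of later in this diff: _relayTransactions skips any transaction whose IsValidated() still returns false. A minimal sketch of the intended usage, assuming a hypothetical helper and a caller-supplied connect function (neither is part of this change):

// Sketch only: hypothetical helper showing how SetValidated/IsValidated are meant to be used.
// connectTxn stands in for connecting the txn against the mempool's validation view.
func markValidatedTxnsSketch(txns []*MempoolTx, connectTxn func(*MempoolTx) error) {
	for _, mtxn := range txns {
		if mtxn.IsValidated() {
			// Already validated on an earlier pass; nothing to do.
			continue
		}
		if err := connectTxn(mtxn); err != nil {
			// Leave Validated == false; the txn will be skipped by _relayTransactions.
			continue
		}
		mtxn.SetValidated(true)
	}
}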
diff --git a/lib/pos_network_message_interface.go b/lib/pos_network_message_interface.go index 96d19a9cb..0593f52e2 100644 --- a/lib/pos_network_message_interface.go +++ b/lib/pos_network_message_interface.go @@ -155,6 +155,10 @@ func (validator *ValidatorEntry) GetStakeAmount() *uint256.Int { return validator.TotalStakeAmountNanos } +func (validator *ValidatorEntry) GetDomains() [][]byte { + return validator.Domains +} + func ValidatorEntriesToConsensusInterface(validatorEntries []*ValidatorEntry) []consensus.Validator { validatorInterfaces := make([]consensus.Validator, len(validatorEntries)) for idx, validatorEntry := range validatorEntries { diff --git a/lib/pos_server_regtest.go b/lib/pos_server_regtest.go index e4a339416..bd74780d6 100644 --- a/lib/pos_server_regtest.go +++ b/lib/pos_server_regtest.go @@ -25,7 +25,7 @@ func (srv *Server) submitRegtestValidatorRegistrationTxns(block *MsgDeSoBlock) { } txnMeta := RegisterAsValidatorMetadata{ - Domains: [][]byte{[]byte("https://deso.com")}, + Domains: [][]byte{[]byte("http://localhost:18000")}, DisableDelegatedStake: false, DelegatedStakeCommissionBasisPoints: 100, VotingPublicKey: blsSigner.GetPublicKey(), diff --git a/lib/remote_node.go b/lib/remote_node.go new file mode 100644 index 000000000..07610f8ec --- /dev/null +++ b/lib/remote_node.go @@ -0,0 +1,744 @@ +package lib + +import ( + "encoding/binary" + "fmt" + "github.com/btcsuite/btcd/wire" + "github.com/deso-protocol/core/bls" + "github.com/golang/glog" + "github.com/pkg/errors" + "golang.org/x/crypto/sha3" + "net" + "sync" + "time" +) + +type RemoteNodeStatus int + +const ( + RemoteNodeStatus_NotConnected RemoteNodeStatus = 0 + RemoteNodeStatus_Connected RemoteNodeStatus = 1 + RemoteNodeStatus_VersionSent RemoteNodeStatus = 2 + RemoteNodeStatus_VerackSent RemoteNodeStatus = 3 + RemoteNodeStatus_HandshakeCompleted RemoteNodeStatus = 4 + RemoteNodeStatus_Attempted RemoteNodeStatus = 5 + RemoteNodeStatus_Terminated RemoteNodeStatus = 6 +) + +type RemoteNodeId uint64 + +func NewRemoteNodeId(id uint64) RemoteNodeId { + return RemoteNodeId(id) +} + +func (id RemoteNodeId) ToUint64() uint64 { + return uint64(id) +} + +// RemoteNode is a chain-aware wrapper around the network Peer object. It is used to manage the lifecycle of a peer +// and to store blockchain-related metadata about the peer. The RemoteNode can wrap around either an inbound or outbound +// peer connection. For outbound peers, the RemoteNode is created prior to the connection being established. In this case, +// the RemoteNode will be first used to initiate an OutboundConnectionAttempt, and then store the resulting connected peer. +// For inbound peers, the RemoteNode is created after the connection is established in ConnectionManager. +// +// Once the RemoteNode's peer is set, the RemoteNode is used to manage the handshake with the peer. The handshake involves +// rounds of Version and Verack messages being sent between our node and the peer. The handshake is complete when both +// nodes have sent and received a Version and Verack message. Once the handshake is successful, the RemoteNode will +// emit a MsgDeSoPeerHandshakeComplete control message via the Server. +// +// In steady state, i.e. after the handshake is complete, the RemoteNode can be used to send a message to the peer, +// retrieve the peer's handshake metadata, and close the connection with the peer. The RemoteNode has a single-use +// lifecycle. 
Once the RemoteNode is terminated, it will be disposed of, and a new RemoteNode must be created if we +// wish to reconnect to the peer in the future. +type RemoteNode struct { + mtx sync.RWMutex + + peer *Peer + // The id is the unique identifier of this RemoteNode. For outbound connections, the id will be the same as the + // attemptId of the OutboundConnectionAttempt, and the subsequent id of the outbound peer. For inbound connections, + // the id will be the same as the inbound peer's id. + id RemoteNodeId + // validatorPublicKey is the BLS public key of the validator node. This is only set for validator nodes. For + // non-validator nodes, this will be nil. For outbound validators nodes, the validatorPublicKey will be set when + // the RemoteNode is instantiated. And for inbound validator nodes, the validatorPublicKey will be set when the + // handshake is completed. + validatorPublicKey *bls.PublicKey + // isPersistent identifies whether the RemoteNode is persistent or not. Persistent RemoteNodes is a sub-category of + // outbound RemoteNodes. They are different from non-persistent RemoteNodes from the very moment they are created. + // Initially, an outbound RemoteNode is in an "attempted" state, meaning we dial the connection to the peer. The + // non-persistent RemoteNode is terminated after the first failed dial, while a persistent RemoteNode will keep + // trying to dial the peer indefinitely until the connection is established, or the node stops. + isPersistent bool + + connectionStatus RemoteNodeStatus + + params *DeSoParams + srv *Server + cmgr *ConnectionManager + + // minTxFeeRateNanosPerKB is the minimum transaction fee rate in nanos per KB that our node will accept. + minTxFeeRateNanosPerKB uint64 + // latestBlockHeight is the block height of our node's block tip. + latestBlockHeight uint64 + // nodeServices is a bitfield that indicates the services supported by our node. + nodeServices ServiceFlag + + // handshakeMetadata is used to store the information received from the peer during the handshake. + handshakeMetadata *HandshakeMetadata + // keystore is a reference to the node's BLS private key storage. In the context of a RemoteNode, the keystore is + // used in the Verack message for validator nodes to prove ownership of the validator BLS public key. + keystore *BLSKeystore + + // versionTimeExpected is the latest time by which we expect to receive a Version message from the peer. + // If the Version message is not received by this time, the connection will be terminated. + versionTimeExpected *time.Time + // verackTimeExpected is the latest time by which we expect to receive a Verack message from the peer. + // If the Verack message is not received by this time, the connection will be terminated. + verackTimeExpected *time.Time +} + +// HandshakeMetadata stores the information received from the peer during the Version and Verack exchange. +type HandshakeMetadata struct { + // ### The following fields are populated during the MsgDeSoVersion exchange. + // versionNonceSent is the nonce sent in the Version message to the peer. + versionNonceSent uint64 + // versionNonceReceived is the nonce received in the Version message from the peer. + versionNonceReceived uint64 + // userAgent is a meta level label that can be used to analyze the network. + userAgent string + // serviceFlag is a bitfield that indicates the services supported by the peer. + serviceFlag ServiceFlag + // latestBlockHeight is the block height of the peer's block tip during the Version exchange. 
+ latestBlockHeight uint64 + // minTxFeeRateNanosPerKB is the minimum transaction fee rate in nanos per KB that the peer will accept. + minTxFeeRateNanosPerKB uint64 + // advertisedProtocolVersion is the protocol version advertised by the peer. + advertisedProtocolVersion ProtocolVersionType + // negotiatedProtocolVersion is the protocol version negotiated between the peer and our node. This is the minimum + // of the advertised protocol version and our node's protocol version. + negotiatedProtocolVersion ProtocolVersionType + // timeConnected is the unix timestamp of the peer, measured when the peer sent their Version message. + timeConnected *time.Time + // versionNegotiated is true if the peer passed the version negotiation step. + versionNegotiated bool + // timeOffsetSecs is the time offset between our node and the peer, measured by taking the difference between the + // peer's unix timestamp and our node's unix timestamp. + timeOffsetSecs uint64 + + // ### The following fields are populated during the MsgDeSoVerack exchange. + // validatorPublicKey is the BLS public key of the peer, if the peer is a validator node. + validatorPublicKey *bls.PublicKey +} + +func NewHandshakeMetadata() *HandshakeMetadata { + return &HandshakeMetadata{} +} + +func NewRemoteNode(id RemoteNodeId, validatorPublicKey *bls.PublicKey, isPersistent bool, srv *Server, + cmgr *ConnectionManager, keystore *BLSKeystore, params *DeSoParams, minTxFeeRateNanosPerKB uint64, + latestBlockHeight uint64, nodeServices ServiceFlag) *RemoteNode { + return &RemoteNode{ + id: id, + validatorPublicKey: validatorPublicKey, + isPersistent: isPersistent, + connectionStatus: RemoteNodeStatus_NotConnected, + handshakeMetadata: NewHandshakeMetadata(), + srv: srv, + cmgr: cmgr, + keystore: keystore, + params: params, + minTxFeeRateNanosPerKB: minTxFeeRateNanosPerKB, + latestBlockHeight: latestBlockHeight, + nodeServices: nodeServices, + } +} + +// setStatusHandshakeCompleted sets the connection status of the remote node to HandshakeCompleted. +func (rn *RemoteNode) setStatusHandshakeCompleted() { + rn.connectionStatus = RemoteNodeStatus_HandshakeCompleted +} + +// setStatusConnected sets the connection status of the remote node to connected. +func (rn *RemoteNode) setStatusConnected() { + rn.connectionStatus = RemoteNodeStatus_Connected +} + +// setStatusVersionSent sets the connection status of the remote node to version sent. +func (rn *RemoteNode) setStatusVersionSent() { + rn.connectionStatus = RemoteNodeStatus_VersionSent +} + +// setStatusVerackSent sets the connection status of the remote node to verack sent. +func (rn *RemoteNode) setStatusVerackSent() { + rn.connectionStatus = RemoteNodeStatus_VerackSent +} + +// setStatusTerminated sets the connection status of the remote node to terminated. +func (rn *RemoteNode) setStatusTerminated() { + rn.connectionStatus = RemoteNodeStatus_Terminated +} + +// setStatusAttempted sets the connection status of the remote node to attempted. 
+func (rn *RemoteNode) setStatusAttempted() { + rn.connectionStatus = RemoteNodeStatus_Attempted +} + +func (rn *RemoteNode) GetId() RemoteNodeId { + return rn.id +} + +func (rn *RemoteNode) GetPeer() *Peer { + return rn.peer +} + +func (rn *RemoteNode) GetNegotiatedProtocolVersion() ProtocolVersionType { + return rn.handshakeMetadata.negotiatedProtocolVersion +} + +func (rn *RemoteNode) GetValidatorPublicKey() *bls.PublicKey { + return rn.validatorPublicKey +} + +func (rn *RemoteNode) GetServiceFlag() ServiceFlag { + return rn.handshakeMetadata.serviceFlag +} + +func (rn *RemoteNode) GetLatestBlockHeight() uint64 { + return rn.handshakeMetadata.latestBlockHeight +} + +func (rn *RemoteNode) GetUserAgent() string { + return rn.handshakeMetadata.userAgent +} + +func (rn *RemoteNode) GetNetAddress() *wire.NetAddress { + if !rn.IsHandshakeCompleted() || rn.GetPeer() == nil { + return nil + } + return rn.GetPeer().NetAddress() +} + +func (rn *RemoteNode) IsInbound() bool { + return rn.peer != nil && !rn.peer.IsOutbound() +} + +func (rn *RemoteNode) IsOutbound() bool { + return rn.peer != nil && rn.peer.IsOutbound() +} + +func (rn *RemoteNode) IsPersistent() bool { + return rn.isPersistent +} + +func (rn *RemoteNode) IsNotConnected() bool { + return rn.connectionStatus == RemoteNodeStatus_NotConnected +} + +func (rn *RemoteNode) IsConnected() bool { + return rn.connectionStatus == RemoteNodeStatus_Connected +} + +func (rn *RemoteNode) IsVersionSent() bool { + return rn.connectionStatus == RemoteNodeStatus_VersionSent +} + +func (rn *RemoteNode) IsVerackSent() bool { + return rn.connectionStatus == RemoteNodeStatus_VerackSent +} + +func (rn *RemoteNode) IsHandshakeCompleted() bool { + return rn.connectionStatus == RemoteNodeStatus_HandshakeCompleted +} + +func (rn *RemoteNode) IsTerminated() bool { + return rn.connectionStatus == RemoteNodeStatus_Terminated +} + +func (rn *RemoteNode) IsValidator() bool { + if !rn.IsHandshakeCompleted() { + return false + } + return rn.hasValidatorServiceFlag() +} + +func (rn *RemoteNode) IsExpectedValidator() bool { + return rn.GetValidatorPublicKey() != nil +} + +func (rn *RemoteNode) hasValidatorServiceFlag() bool { + return rn.GetServiceFlag().HasService(SFPosValidator) +} + +// DialOutboundConnection dials an outbound connection to the provided netAddr. +func (rn *RemoteNode) DialOutboundConnection(netAddr *wire.NetAddress) error { + rn.mtx.Lock() + defer rn.mtx.Unlock() + + if !rn.IsNotConnected() { + return fmt.Errorf("RemoteNode.DialOutboundConnection: RemoteNode is not in the NotConnected state") + } + + rn.cmgr.DialOutboundConnection(netAddr, rn.GetId().ToUint64()) + rn.setStatusAttempted() + return nil +} + +// DialPersistentOutboundConnection dials a persistent outbound connection to the provided netAddr. +func (rn *RemoteNode) DialPersistentOutboundConnection(netAddr *wire.NetAddress) error { + rn.mtx.Lock() + defer rn.mtx.Unlock() + + if !rn.IsNotConnected() { + return fmt.Errorf("RemoteNode.DialPersistentOutboundConnection: RemoteNode is not in the NotConnected state") + } + + rn.cmgr.DialPersistentOutboundConnection(netAddr, rn.GetId().ToUint64()) + rn.setStatusAttempted() + return nil +} + +// AttachInboundConnection creates an inbound peer once a successful inbound connection has been established. 
+func (rn *RemoteNode) AttachInboundConnection(conn net.Conn, na *wire.NetAddress) error {
+	rn.mtx.Lock()
+	defer rn.mtx.Unlock()
+
+	if !rn.IsNotConnected() {
+		return fmt.Errorf("RemoteNode.AttachInboundConnection: RemoteNode is not in the NotConnected state")
+	}
+
+	id := rn.GetId().ToUint64()
+	rn.peer = rn.cmgr.ConnectPeer(id, conn, na, false, false)
+	versionTimeExpected := time.Now().Add(rn.params.VersionNegotiationTimeout)
+	rn.versionTimeExpected = &versionTimeExpected
+	rn.setStatusConnected()
+	return nil
+}
+
+// AttachOutboundConnection creates an outbound peer once a successful outbound connection has been established.
+func (rn *RemoteNode) AttachOutboundConnection(conn net.Conn, na *wire.NetAddress, isPersistent bool) error {
+	rn.mtx.Lock()
+	defer rn.mtx.Unlock()
+
+	if rn.connectionStatus != RemoteNodeStatus_Attempted {
+		return fmt.Errorf("RemoteNode.AttachOutboundConnection: RemoteNode is not in the Attempted state")
+	}
+
+	id := rn.GetId().ToUint64()
+	rn.peer = rn.cmgr.ConnectPeer(id, conn, na, true, isPersistent)
+	versionTimeExpected := time.Now().Add(rn.params.VersionNegotiationTimeout)
+	rn.versionTimeExpected = &versionTimeExpected
+	rn.setStatusConnected()
+	return nil
+}
+
+// Disconnect disconnects the remote node, closing the attempted connection or the established connection.
+func (rn *RemoteNode) Disconnect() {
+	rn.mtx.Lock()
+	defer rn.mtx.Unlock()
+
+	if rn.connectionStatus == RemoteNodeStatus_Terminated {
+		return
+	}
+	glog.V(2).Infof("RemoteNode.Disconnect: Disconnecting from peer (id= %d, status= %v)",
+		rn.id, rn.connectionStatus)
+
+	id := rn.GetId().ToUint64()
+	switch rn.connectionStatus {
+	case RemoteNodeStatus_Attempted:
+		rn.cmgr.CloseAttemptedConnection(id)
+	case RemoteNodeStatus_Connected, RemoteNodeStatus_VersionSent, RemoteNodeStatus_VerackSent,
+		RemoteNodeStatus_HandshakeCompleted:
+		rn.cmgr.CloseConnection(id)
+	}
+	rn.setStatusTerminated()
+}
+
+// SendMessage sends the message to the peer, but only once the handshake has completed.
+func (rn *RemoteNode) SendMessage(desoMsg DeSoMessage) error {
+	rn.mtx.RLock()
+	defer rn.mtx.RUnlock()
+
+	if rn.connectionStatus != RemoteNodeStatus_HandshakeCompleted {
+		return fmt.Errorf("SendMessage: Remote node is not connected")
+	}
+
+	return rn.sendMessage(desoMsg)
+}
+
+// sendMessage pushes the message to the ConnectionManager without checking the handshake status.
+func (rn *RemoteNode) sendMessage(desoMsg DeSoMessage) error {
+	if err := rn.cmgr.SendMessage(desoMsg, rn.GetId().ToUint64()); err != nil {
+		return fmt.Errorf("SendMessage: Problem sending message to peer (id= %d): %v", rn.id, err)
+	}
+	return nil
+}
+
+// InitiateHandshake is a starting point for a peer handshake. If the peer is outbound, a version message is sent
+// to the peer. If the peer is inbound, the peer is expected to send a version message to us first.
+func (rn *RemoteNode) InitiateHandshake(nonce uint64) error {
+	rn.mtx.Lock()
+	defer rn.mtx.Unlock()
+
+	if rn.connectionStatus != RemoteNodeStatus_Connected {
+		return fmt.Errorf("InitiateHandshake: Remote node is not connected")
+	}
+
+	if rn.GetPeer().IsOutbound() {
+		if err := rn.sendVersionMessage(nonce); err != nil {
+			return fmt.Errorf("InitiateHandshake: Problem sending version message to peer (id= %d): %v", rn.id, err)
+		}
+		rn.setStatusVersionSent()
+	}
+	return nil
+}
+
+// sendVersionMessage generates and sends a version message to a RemoteNode peer. The message will contain the nonce
+// that is passed in as an argument.
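Taken together, the methods above and the handlers below walk a RemoteNode through the status progression NotConnected -> Attempted -> Connected -> VersionSent -> VerackSent -> HandshakeCompleted (inbound connections skip the dial and reach Connected via AttachInboundConnection). A compressed sketch of the outbound path, assuming the caller supplies the connection, nonces, and peer messages that in practice arrive asynchronously through the ConnectionManager and Server:

// Sketch only: the outbound status progression in one place. In the real flow the conn and the
// peer's Version/Verack messages are delivered asynchronously, not passed in like this.
func outboundHandshakeSketch(rn *RemoteNode, netAddr *wire.NetAddress, conn net.Conn,
	nonce uint64, peerVersion *MsgDeSoVersion, responseNonce uint64, peerVerack *MsgDeSoVerack) error {
	if err := rn.DialOutboundConnection(netAddr); err != nil { // NotConnected -> Attempted
		return err
	}
	if err := rn.AttachOutboundConnection(conn, netAddr, false); err != nil { // Attempted -> Connected
		return err
	}
	if err := rn.InitiateHandshake(nonce); err != nil { // sends Version; Connected -> VersionSent
		return err
	}
	if err := rn.HandleVersionMessage(peerVersion, responseNonce); err != nil { // sends Verack; -> VerackSent
		return err
	}
	return rn.HandleVerackMessage(peerVerack) // VerackSent -> HandshakeCompleted
}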
+func (rn *RemoteNode) sendVersionMessage(nonce uint64) error { + verMsg := rn.newVersionMessage(nonce) + + // Record the nonce of this version message before we send it so we can + // detect self connections and so we can validate that the peer actually + // controls the IP she's supposedly communicating to us from. + rn.handshakeMetadata.versionNonceSent = nonce + + if err := rn.sendMessage(verMsg); err != nil { + return fmt.Errorf("sendVersionMessage: Problem sending version message to peer (id= %d): %v", rn.id, err) + } + return nil +} + +// newVersionMessage returns a new version message that can be sent to a RemoteNode. The message will contain the +// nonce that is passed in as an argument. +func (rn *RemoteNode) newVersionMessage(nonce uint64) *MsgDeSoVersion { + ver := NewMessage(MsgTypeVersion).(*MsgDeSoVersion) + + ver.Version = rn.params.ProtocolVersion.ToUint64() + // Set the services bitfield to indicate what services this node supports. + ver.Services = rn.nodeServices + + // We use an int64 instead of a uint64 for convenience. + ver.TstampSecs = time.Now().Unix() + + ver.Nonce = nonce + ver.UserAgent = rn.params.UserAgent + + // When a node asks you for what height you have, you should reply with the height of the latest actual block you + // have. This makes it so that peers who have up-to-date headers but missing blocks won't be considered for initial + // block download. + ver.LatestBlockHeight = rn.latestBlockHeight + + // Set the minimum fee rate the peer will accept. + ver.MinFeeRateNanosPerKB = rn.minTxFeeRateNanosPerKB + + return ver +} + +func (rn *RemoteNode) IsTimedOut() bool { + if rn.IsTerminated() { + return true + } + if rn.IsConnected() || rn.IsVersionSent() { + return rn.versionTimeExpected.Before(time.Now()) + } + if rn.IsVerackSent() { + return rn.verackTimeExpected.Before(time.Now()) + } + return false +} + +// HandleVersionMessage is called upon receiving a version message from the RemoteNode's peer. The peer may be the one +// initiating the handshake, in which case, we should respond with our own version message. To do this, we pass the +// responseNonce to this function, which we will use in our response version message. +func (rn *RemoteNode) HandleVersionMessage(verMsg *MsgDeSoVersion, responseNonce uint64) error { + rn.mtx.Lock() + defer rn.mtx.Unlock() + + if !rn.IsConnected() && !rn.IsVersionSent() { + return fmt.Errorf("HandleVersionMessage: RemoteNode is not connected or version exchange has already "+ + "been completed, connectionStatus: %v", rn.connectionStatus) + } + + // Verify that the peer's version matches our minimal supported version. + if verMsg.Version < rn.params.MinProtocolVersion { + return fmt.Errorf("RemoteNode.HandleVersionMessage: Requesting disconnect for id: (%v) "+ + "protocol version too low. Peer version: %v, min version: %v", rn.id, verMsg.Version, rn.params.MinProtocolVersion) + } + + // Verify that the peer's version message is sent within the version negotiation timeout. + if rn.versionTimeExpected.Before(time.Now()) { + return fmt.Errorf("RemoteNode.HandleVersionMessage: Requesting disconnect for id: (%v) "+ + "version timeout. Time expected: %v, now: %v", rn.id, rn.versionTimeExpected.UnixMicro(), time.Now().UnixMicro()) + } + + vMeta := rn.handshakeMetadata + // Record the version the peer is using. + vMeta.advertisedProtocolVersion = NewProtocolVersionType(verMsg.Version) + // Make sure the latest supported protocol version is ProtocolVersion2. 
+ if vMeta.advertisedProtocolVersion.After(ProtocolVersion2) { + return fmt.Errorf("RemoteNode.HandleVersionMessage: Requesting disconnect for id: (%v) "+ + "protocol version too high. Peer version: %v, max version: %v", rn.id, verMsg.Version, ProtocolVersion2) + } + + // Decide on the protocol version to use for this connection. + negotiatedVersion := rn.params.ProtocolVersion + if verMsg.Version < rn.params.ProtocolVersion.ToUint64() { + // In order to smoothly transition to the PoS fork, we prevent establishing new outbound connections with + // outdated nodes that run on ProtocolVersion1. This is because ProtocolVersion1 nodes will not be able to + // validate the PoS blocks and will be stuck on the PoW chain, unless they upgrade to ProtocolVersion2. + if rn.params.ProtocolVersion == ProtocolVersion2 && rn.IsOutbound() { + return fmt.Errorf("RemoteNode.HandleVersionMessage: Requesting disconnect for id: (%v). Version too low. "+ + "Outbound RemoteNodes must use at least ProtocolVersion2, instead received version: %v", rn.id, verMsg.Version) + } + + negotiatedVersion = NewProtocolVersionType(verMsg.Version) + } + + vMeta.negotiatedProtocolVersion = negotiatedVersion + + // Record the services the peer is advertising. + vMeta.serviceFlag = verMsg.Services + // If the RemoteNode was connected with an expectation of being a validator, make sure that its advertised ServiceFlag + // indicates that it is a validator. + if !rn.hasValidatorServiceFlag() && rn.validatorPublicKey != nil { + return fmt.Errorf("RemoteNode.HandleVersionMessage: Requesting disconnect for id: (%v). "+ + "Expected validator, but received invalid ServiceFlag: %v", rn.id, verMsg.Services) + } + // If the RemoteNode is on ProtocolVersion1, then it must not have the validator service flag set. + if rn.hasValidatorServiceFlag() && vMeta.advertisedProtocolVersion.Before(ProtocolVersion2) { + return fmt.Errorf("RemoteNode.HandleVersionMessage: Requesting disconnect for id: (%v). "+ + "RemoteNode has SFValidator service flag, but doesn't have ProtocolVersion2 or later", rn.id) + } + + // Record the tstamp sent by the peer and calculate the time offset. + timeConnected := time.Unix(verMsg.TstampSecs, 0) + vMeta.timeConnected = &timeConnected + currentTime := time.Now().Unix() + if currentTime > verMsg.TstampSecs { + vMeta.timeOffsetSecs = uint64(currentTime - verMsg.TstampSecs) + } else { + vMeta.timeOffsetSecs = uint64(verMsg.TstampSecs - currentTime) + } + + // Save the received version nonce so we can include it in our verack message. + vMeta.versionNonceReceived = verMsg.Nonce + + // Set the peer info-related fields. + vMeta.userAgent = verMsg.UserAgent + vMeta.latestBlockHeight = verMsg.LatestBlockHeight + vMeta.minTxFeeRateNanosPerKB = verMsg.MinFeeRateNanosPerKB + + // Respond to the version message if this is an inbound peer. + if rn.IsInbound() { + if err := rn.sendVersionMessage(responseNonce); err != nil { + return errors.Wrapf(err, "RemoteNode.HandleVersionMessage: Problem sending version message to peer (id= %d)", rn.id) + } + } + + // After sending and receiving a compatible version, send the verack message. Notice that we don't wait for the + // peer's verack message even if it is an inbound peer. Instead, we just send the verack message right away. + + // Set the latest time by which we should receive a verack message from the peer. 
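To make the negotiation rules above concrete: the negotiated version is effectively the lower of the two advertised versions, outbound peers must offer at least ProtocolVersion2 when our node runs ProtocolVersion2, and a peer advertising SFPosValidator must also be on ProtocolVersion2 or later. A standalone restatement of the version-selection rule, with assumed names (this helper is not part of the change):

// Sketch only: the version-selection rule above, extracted for illustration.
func negotiateProtocolVersionSketch(ourVersion ProtocolVersionType, peerVersion uint64, isOutbound bool) (ProtocolVersionType, error) {
	if peerVersion >= ourVersion.ToUint64() {
		// The peer is at least as new as we are; cap the connection at our own version.
		return ourVersion, nil
	}
	if ourVersion == ProtocolVersion2 && isOutbound {
		// Outbound connections from a ProtocolVersion2 node must not negotiate down.
		return ourVersion, fmt.Errorf("outbound peer version %d is below ProtocolVersion2", peerVersion)
	}
	// Inbound peers may negotiate down to their older advertised version.
	return NewProtocolVersionType(peerVersion), nil
}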
+ verackTimeExpected := time.Now().Add(rn.params.VerackNegotiationTimeout) + rn.verackTimeExpected = &verackTimeExpected + if err := rn.sendVerack(); err != nil { + return errors.Wrapf(err, "RemoteNode.HandleVersionMessage: Problem sending verack message to peer (id= %d)", rn.id) + } + + // Update the timeSource now that we've gotten a version message from the peer. + rn.cmgr.AddTimeSample(rn.peer.Address(), timeConnected) + rn.setStatusVerackSent() + return nil +} + +// sendVerack constructs and sends a verack message to the peer. +func (rn *RemoteNode) sendVerack() error { + verackMsg, err := rn.newVerackMessage() + if err != nil { + return err + } + + if err := rn.sendMessage(verackMsg); err != nil { + return errors.Wrapf(err, "RemoteNode.SendVerack: Problem sending verack message to peer (id= %d): %v", rn.id, err) + } + return nil +} + +// newVerackMessage constructs a verack message to be sent to the peer. +func (rn *RemoteNode) newVerackMessage() (*MsgDeSoVerack, error) { + verack := NewMessage(MsgTypeVerack).(*MsgDeSoVerack) + vMeta := rn.handshakeMetadata + + switch vMeta.negotiatedProtocolVersion { + case ProtocolVersion0, ProtocolVersion1: + // For protocol versions 0 and 1, we just send back the nonce we received from the peer in the version message. + verack.Version = VerackVersion0 + verack.NonceReceived = vMeta.versionNonceReceived + case ProtocolVersion2: + // For protocol version 2, we need to send the nonce we received from the peer in their version message. + // We also need to send our own nonce, which we generate for our version message. In addition, we need to + // send a current timestamp (in microseconds). We then sign the tuple of (nonceReceived, nonceSent, tstampMicro) + // using our validator BLS key, and send the signature along with our public key. + var err error + verack.Version = VerackVersion1 + verack.NonceReceived = vMeta.versionNonceReceived + verack.NonceSent = vMeta.versionNonceSent + tstampMicro := uint64(time.Now().UnixMicro()) + verack.TstampMicro = tstampMicro + // If the RemoteNode is not a validator, then we don't need to sign the verack message. + if !rn.nodeServices.HasService(SFPosValidator) { + break + } + verack.PublicKey = rn.keystore.GetSigner().GetPublicKey() + verack.Signature, err = rn.keystore.GetSigner().SignPoSValidatorHandshake(verack.NonceSent, verack.NonceReceived, tstampMicro) + if err != nil { + return nil, fmt.Errorf("RemoteNode.newVerackMessage: Problem signing verack message: %v", err) + } + } + return verack, nil +} + +// HandleVerackMessage handles a verack message received from the peer. +func (rn *RemoteNode) HandleVerackMessage(vrkMsg *MsgDeSoVerack) error { + rn.mtx.Lock() + defer rn.mtx.Unlock() + + if rn.connectionStatus != RemoteNodeStatus_VerackSent { + return fmt.Errorf("RemoteNode.HandleVerackMessage: Requesting disconnect for id: (%v) "+ + "verack received while in state: %v", rn.id, rn.connectionStatus) + } + + if rn.verackTimeExpected != nil && rn.verackTimeExpected.Before(time.Now()) { + return fmt.Errorf("RemoteNode.HandleVerackMessage: Requesting disconnect for id: (%v) "+ + "verack timeout. 
Time expected: %v, now: %v", rn.id, rn.verackTimeExpected.UnixMicro(), time.Now().UnixMicro()) + } + + var err error + vMeta := rn.handshakeMetadata + switch vMeta.negotiatedProtocolVersion { + case ProtocolVersion0, ProtocolVersion1: + err = rn.validateVerackPoW(vrkMsg) + case ProtocolVersion2: + err = rn.validateVerackPoS(vrkMsg) + } + + if err != nil { + return errors.Wrapf(err, "RemoteNode.HandleVerackMessage: Problem validating verack message from peer (id= %d)", rn.id) + } + + // If we get here then the peer has successfully completed the handshake. + vMeta.versionNegotiated = true + rn._logVersionSuccess() + rn.setStatusHandshakeCompleted() + + return nil +} + +func (rn *RemoteNode) validateVerackPoW(vrkMsg *MsgDeSoVerack) error { + vMeta := rn.handshakeMetadata + + // Verify that the verack message is formatted correctly according to the PoW standard. + if vrkMsg.Version != VerackVersion0 { + return fmt.Errorf("RemoteNode.validateVerackPoW: Requesting disconnect for id: (%v) "+ + "verack version mismatch; message: %v; expected: %v", rn.id, vrkMsg.Version, VerackVersion0) + } + + // If the verack message has a nonce that wasn't previously sent to us in the version message, return an error. + if vrkMsg.NonceReceived != vMeta.versionNonceSent { + return fmt.Errorf("RemoteNode.validateVerackPoW: Requesting disconnect for id: (%v) nonce mismatch; "+ + "message: %v; nonceSent: %v", rn.id, vrkMsg.NonceReceived, vMeta.versionNonceSent) + } + + return nil +} + +func (rn *RemoteNode) validateVerackPoS(vrkMsg *MsgDeSoVerack) error { + vMeta := rn.handshakeMetadata + + // Verify that the verack message is formatted correctly according to the PoS standard. + if vrkMsg.Version != VerackVersion1 { + return fmt.Errorf("RemoteNode.validateVerackPoS: Requesting disconnect for id: (%v) "+ + "verack version mismatch; message: %v; expected: %v", rn.id, vrkMsg.Version, VerackVersion1) + } + + // Verify that the counterparty's verack message's NonceReceived matches the NonceSent we sent. + if vrkMsg.NonceReceived != vMeta.versionNonceSent { + return fmt.Errorf("RemoteNode.validateVerackPoS: Requesting disconnect for id: (%v) nonce mismatch; "+ + "message: %v; nonceSent: %v", rn.id, vrkMsg.NonceReceived, vMeta.versionNonceSent) + } + + // Verify that the counterparty's verack message's NonceSent matches the NonceReceived we sent. + if vrkMsg.NonceSent != vMeta.versionNonceReceived { + return fmt.Errorf("RemoteNode.validateVerackPoS: Requesting disconnect for id: (%v) "+ + "verack nonce mismatch; message: %v; expected: %v", rn.id, vrkMsg.NonceSent, vMeta.versionNonceReceived) + } + + // Get the current time in microseconds and make sure the verack message's timestamp is within 15 minutes of it. + timeNowMicro := uint64(time.Now().UnixMicro()) + if vrkMsg.TstampMicro < timeNowMicro-rn.params.HandshakeTimeoutMicroSeconds { + return fmt.Errorf("RemoteNode.validateVerackPoS: Requesting disconnect for id: (%v) "+ + "verack timestamp too far in the past. Time now: %v, verack timestamp: %v", rn.id, timeNowMicro, vrkMsg.TstampMicro) + } + + // If the RemoteNode is not a validator, then we don't need to verify the verack message's signature. + if !rn.hasValidatorServiceFlag() { + return nil + } + + // Make sure the verack message's public key and signature are not nil. + if vrkMsg.PublicKey == nil || vrkMsg.Signature == nil { + return fmt.Errorf("RemoteNode.validateVerackPoS: Requesting disconnect for id: (%v) "+ + "verack public key or signature is nil", rn.id) + } + + // Verify the verack message's signature. 
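The signature verified here covers a digest like the one produced by GetVerackHandshakePayload, defined at the end of this file: SHA3-256 over a 24-byte preimage made of the two nonces and the microsecond timestamp, each encoded as 8 big-endian bytes. A minimal sketch of that layout written out field by field (names are illustrative):

// Sketch only: the 24-byte preimage behind the verack handshake digest, spelled out explicitly.
func verackPayloadSketch(nonceReceived, nonceSent, tstampMicro uint64) [32]byte {
	preimage := make([]byte, 0, 24)
	preimage = binary.BigEndian.AppendUint64(preimage, nonceReceived) // bytes 0..7
	preimage = binary.BigEndian.AppendUint64(preimage, nonceSent)     // bytes 8..15
	preimage = binary.BigEndian.AppendUint64(preimage, tstampMicro)   // bytes 16..23
	return sha3.Sum256(preimage) // same shape as GetVerackHandshakePayload below
}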
+ ok, err := BLSVerifyPoSValidatorHandshake(vrkMsg.NonceSent, vrkMsg.NonceReceived, vrkMsg.TstampMicro, + vrkMsg.Signature, vrkMsg.PublicKey) + if err != nil { + return errors.Wrapf(err, "RemoteNode.validateVerackPoS: Requesting disconnect for id: (%v) "+ + "verack signature verification failed with error", rn.id) + } + if !ok { + return fmt.Errorf("RemoteNode.validateVerackPoS: Requesting disconnect for id: (%v) "+ + "verack signature verification failed", rn.id) + } + + if rn.validatorPublicKey != nil && rn.validatorPublicKey.Serialize() != vrkMsg.PublicKey.Serialize() { + return fmt.Errorf("RemoteNode.validateVerackPoS: Requesting disconnect for id: (%v) "+ + "verack public key mismatch; message: %v; expected: %v", rn.id, vrkMsg.PublicKey, rn.validatorPublicKey) + } + + // If we get here then the verack message is valid. Set the validator public key on the peer. + vMeta.validatorPublicKey = vrkMsg.PublicKey + rn.validatorPublicKey = vrkMsg.PublicKey + return nil +} + +func (rn *RemoteNode) _logVersionSuccess() { + inboundStr := "INBOUND" + if rn.IsOutbound() { + inboundStr = "OUTBOUND" + } + persistentStr := "PERSISTENT" + if !rn.IsPersistent() { + persistentStr = "NON-PERSISTENT" + } + logStr := fmt.Sprintf("SUCCESS version negotiation for (%s) (%s) id=(%v).", inboundStr, persistentStr, rn.id.ToUint64()) + glog.V(1).Info(logStr) +} + +func GetVerackHandshakePayload(nonceReceived uint64, nonceSent uint64, tstampMicro uint64) [32]byte { + // The payload for the verack message is the two nonces concatenated together. + // We do this so that we can sign the nonces and verify the signature on the other side. + nonceReceivedBytes := make([]byte, 8) + binary.BigEndian.PutUint64(nonceReceivedBytes, nonceReceived) + + nonceSentBytes := make([]byte, 8) + binary.BigEndian.PutUint64(nonceSentBytes, nonceSent) + + tstampBytes := make([]byte, 8) + binary.BigEndian.PutUint64(tstampBytes, tstampMicro) + + payload := append(nonceReceivedBytes, nonceSentBytes...) + payload = append(payload, tstampBytes...) + + return sha3.Sum256(payload) +} diff --git a/lib/server.go b/lib/server.go index d0940cc03..7de67303c 100644 --- a/lib/server.go +++ b/lib/server.go @@ -12,15 +12,16 @@ import ( "sync/atomic" "time" + "github.com/btcsuite/btcd/wire" + "github.com/deso-protocol/core/consensus" + "github.com/decred/dcrd/lru" "github.com/DataDog/datadog-go/statsd" "github.com/btcsuite/btcd/addrmgr" chainlib "github.com/btcsuite/btcd/blockchain" - "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" - "github.com/deso-protocol/core/consensus" "github.com/deso-protocol/go-deadlock" "github.com/dgraph-io/badger/v3" "github.com/golang/glog" @@ -64,6 +65,9 @@ type Server struct { TxIndex *TXIndex params *DeSoParams + // fastHotStuffEventLoop consensus.FastHotStuffEventLoop + networkManager *NetworkManager + // posMempool *PosMemPool TODO: Add the mempool later fastHotStuffConsensus *FastHotStuffConsensus // All messages received from peers get sent from the ConnectionManager to the @@ -128,7 +132,9 @@ type Server struct { // It is organized in this way so that we can limit the number of addresses we // are distributing for a single peer to avoid a DOS attack. 
addrsToBroadcastLock deadlock.RWMutex - addrsToBroadcastt map[string][]*SingleAddr + addrsToBroadcast map[string][]*SingleAddr + + AddrMgr *addrmgr.AddrManager // When set to true, we disable the ConnectionManager DisableNetworking bool @@ -176,6 +182,10 @@ func (srv *Server) ResetRequestQueues() { srv.requestedTransactionsMap = make(map[BlockHash]*GetDataRequestInfo) } +func (srv *Server) GetNetworkManager() *NetworkManager { + return srv.networkManager +} + // dataLock must be acquired for writing before calling this function. func (srv *Server) _removeRequest(hash *BlockHash) { // Just be lazy and remove the hash from everything indiscriminately to @@ -235,7 +245,6 @@ func (srv *Server) GetBlockProducer() *DeSoBlockProducer { return srv.blockProducer } -// TODO: The hallmark of a messy non-law-of-demeter-following interface... func (srv *Server) GetConnectionManager() *ConnectionManager { return srv.cmgr } @@ -397,9 +406,12 @@ func NewServer( _mempoolBackupIntervalMillis uint64, _mempoolFeeEstimatorNumMempoolBlocks uint64, _mempoolFeeEstimatorNumPastBlocks uint64, + _mempoolMaxValidationViewConnects uint64, + _transactionValidationRefreshIntervalMillis uint64, _augmentedBlockViewRefreshIntervalMillis uint64, _posBlockProductionIntervalMilliseconds uint64, _posTimeoutBaseDurationMilliseconds uint64, + _stateSyncerMempoolTxnSyncLimit uint64, ) ( _srv *Server, _err error, @@ -413,7 +425,7 @@ func NewServer( if _stateChangeDir != "" { // Create the state change syncer to handle syncing state changes to disk, and assign some of its methods // to the event manager. - stateChangeSyncer = NewStateChangeSyncer(_stateChangeDir, _syncType) + stateChangeSyncer = NewStateChangeSyncer(_stateChangeDir, _syncType, _stateSyncerMempoolTxnSyncLimit) eventManager.OnStateSyncerOperation(stateChangeSyncer._handleStateSyncerOperation) eventManager.OnStateSyncerFlushed(stateChangeSyncer._handleStateSyncerFlush) } @@ -444,6 +456,7 @@ func NewServer( snapshot: _snapshot, nodeMessageChannel: _nodeMessageChan, forceChecksum: _forceChecksum, + AddrMgr: _desoAddrMgr, params: _params, } @@ -457,10 +470,9 @@ func NewServer( timesource := chainlib.NewMedianTime() // Create a new connection manager but note that it won't be initialized until Start(). 
- _incomingMessages := make(chan *ServerMessage, (_targetOutboundPeers+_maxInboundPeers)*3) + _incomingMessages := make(chan *ServerMessage, 100+(_targetOutboundPeers+_maxInboundPeers)*3) _cmgr := NewConnectionManager( - _params, _desoAddrMgr, _listeners, _connectIps, timesource, - _targetOutboundPeers, _maxInboundPeers, _limitOneInboundConnectionPerIP, + _params, _listeners, _connectIps, timesource, _hyperSync, _syncType, _stallTimeoutSeconds, _minFeeRateNanosPerKB, _incomingMessages, srv) @@ -487,13 +499,37 @@ func NewServer( return nil, errors.Wrapf(err, "NewServer: Problem initializing blockchain"), true } + headerCumWorkStr := "" + headerCumWork := BigintToHash(_chain.headerTip().CumWork) + if headerCumWork != nil { + headerCumWorkStr = hex.EncodeToString(headerCumWork[:]) + } + blockCumWorkStr := "" + blockCumWork := BigintToHash(_chain.blockTip().CumWork) + if blockCumWork != nil { + blockCumWorkStr = hex.EncodeToString(blockCumWork[:]) + } glog.V(1).Infof("Initialized chain: Best Header Height: %d, Header Hash: %s, Header CumWork: %s, Best Block Height: %d, Block Hash: %s, Block CumWork: %s", _chain.headerTip().Height, hex.EncodeToString(_chain.headerTip().Hash[:]), - hex.EncodeToString(BigintToHash(_chain.headerTip().CumWork)[:]), + headerCumWorkStr, _chain.blockTip().Height, hex.EncodeToString(_chain.blockTip().Hash[:]), - hex.EncodeToString(BigintToHash(_chain.blockTip().CumWork)[:])) + blockCumWorkStr) + + nodeServices := SFFullNodeDeprecated + if _hyperSync { + nodeServices |= SFHyperSync + } + if archivalMode { + nodeServices |= SFArchivalNode + } + if _blsKeystore != nil { + nodeServices |= SFPosValidator + } + srv.networkManager = NewNetworkManager(_params, srv, _chain, _cmgr, _blsKeystore, _desoAddrMgr, + _connectIps, _targetOutboundPeers, _maxInboundPeers, _limitOneInboundConnectionPerIP, + _minFeeRateNanosPerKB, nodeServices) if srv.stateChangeSyncer != nil { srv.stateChangeSyncer.BlockHeight = uint64(_chain.headerTip().Height) @@ -530,6 +566,8 @@ func NewServer( _mempoolFeeEstimatorNumMempoolBlocks, []*MsgDeSoBlock{latestBlock}, _mempoolFeeEstimatorNumPastBlocks, + _mempoolMaxValidationViewConnects, + _transactionValidationRefreshIntervalMillis, _augmentedBlockViewRefreshIntervalMillis, ) if err != nil { @@ -594,6 +632,7 @@ func NewServer( if _blsKeystore != nil { srv.fastHotStuffConsensus = NewFastHotStuffConsensus( _params, + srv.networkManager, _chain, _posMempool, _blsKeystore.GetSigner(), @@ -633,7 +672,7 @@ func NewServer( } // Initialize the addrs to broadcast map. - srv.addrsToBroadcastt = make(map[string][]*SingleAddr) + srv.addrsToBroadcast = make(map[string][]*SingleAddr) // This will initialize the request queues. srv.ResetRequestQueues() @@ -643,6 +682,10 @@ func NewServer( timer.Initialize() srv.timer = timer + if srv.stateChangeSyncer != nil { + srv.stateChangeSyncer.StartMempoolSyncRoutine(srv) + } + // If shouldRestart is true, it means that the state checksum is likely corrupted, and we need to enter a recovery mode. // This can happen if the node was terminated mid-operation last time it was running. The recovery process rolls back // blocks to the beginning of the current snapshot epoch and resets to the state checksum to the epoch checksum. @@ -770,11 +813,13 @@ func (srv *Server) GetSnapshot(pp *Peer) { } // If operationQueueSemaphore is full, we are already storing too many chunks in memory. Block the thread while // we wait for the queue to clear up. 
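operationQueueSemaphore is the usual Go counting semaphore built from a buffered channel; the change below moves the blocking acquire into a goroutine so GetSnapshot no longer stalls its caller, and _handleSnapshot frees the slot (FreeOperationQueueSemaphore) once a chunk has been processed. A standalone sketch of the pattern, with hypothetical names:

// Sketch only (hypothetical names): a counting semaphore from a buffered channel, mirroring
// how operationQueueSemaphore bounds the number of snapshot chunks held in memory at once.
type chunkSemaphore chan struct{}

func newChunkSemaphore(maxInFlight int) chunkSemaphore { return make(chunkSemaphore, maxInFlight) }
func (s chunkSemaphore) acquire()                      { s <- struct{}{} } // blocks while all maxInFlight slots are taken
func (s chunkSemaphore) release()                      { <-s }             // frees a slot after a chunk is processed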
- srv.snapshot.operationQueueSemaphore <- struct{}{} - // Now send a message to the peer to fetch the snapshot chunk. - pp.AddDeSoMessage(&MsgDeSoGetSnapshot{ - SnapshotStartKey: lastReceivedKey, - }, false) + go func() { + srv.snapshot.operationQueueSemaphore <- struct{}{} + // Now send a message to the peer to fetch the snapshot chunk. + pp.AddDeSoMessage(&MsgDeSoGetSnapshot{ + SnapshotStartKey: lastReceivedKey, + }, false) + }() glog.V(2).Infof("Server.GetSnapshot: Sending a GetSnapshot message to peer (%v) "+ "with Prefix (%v) and SnapshotStartEntry (%v)", pp, prefix, lastReceivedKey) @@ -876,8 +921,8 @@ func (srv *Server) GetBlocks(pp *Peer, maxHeight int) { func (srv *Server) _handleHeaderBundle(pp *Peer, msg *MsgDeSoHeaderBundle) { printHeight := pp.StartingBlockHeight() - if srv.blockchain.headerTip().Height > printHeight { - printHeight = srv.blockchain.headerTip().Height + if uint64(srv.blockchain.headerTip().Height) > printHeight { + printHeight = uint64(srv.blockchain.headerTip().Height) } glog.Infof(CLog(Yellow, fmt.Sprintf("Received header bundle with %v headers "+ "in state %s from peer %v. Downloaded ( %v / %v ) total headers", @@ -1232,6 +1277,8 @@ func (srv *Server) _handleSnapshot(pp *Peer, msg *MsgDeSoSnapshotData) { "<%v>, Last entry: <%v>), (number of entries: %v), metadata (%v), and isEmpty (%v), from Peer %v", msg.SnapshotChunk[0].Key, msg.SnapshotChunk[len(msg.SnapshotChunk)-1].Key, len(msg.SnapshotChunk), msg.SnapshotMetadata, msg.SnapshotChunk[0].IsEmpty(), pp))) + // Free up a slot in the operationQueueSemaphore, now that a chunk has been processed. + srv.snapshot.FreeOperationQueueSemaphore() // There is a possibility that during hypersync the network entered a new snapshot epoch. We handle this case by // restarting the node and starting hypersync from scratch. @@ -1615,6 +1662,7 @@ func (srv *Server) _startSync() { // Find a peer with StartingHeight bigger than our best header tip. var bestPeer *Peer for _, peer := range srv.cmgr.GetAllPeers() { + if !peer.IsSyncCandidate() { glog.Infof("Peer is not sync candidate: %v (isOutbound: %v)", peer, peer.isOutbound) continue @@ -1622,7 +1670,7 @@ func (srv *Server) _startSync() { // Choose the peer with the best height out of everyone who's a // valid sync candidate. - if peer.StartingBlockHeight() < bestHeight { + if peer.StartingBlockHeight() < uint64(bestHeight) { continue } @@ -1672,11 +1720,19 @@ func (srv *Server) _startSync() { } -func (srv *Server) _handleNewPeer(pp *Peer) { +func (srv *Server) HandleAcceptedPeer(rn *RemoteNode) { + if rn == nil || rn.GetPeer() == nil { + return + } + pp := rn.GetPeer() + pp.SetServiceFlag(rn.GetServiceFlag()) + pp.SetLatestBlockHeight(rn.GetLatestBlockHeight()) + isSyncCandidate := pp.IsSyncCandidate() isSyncing := srv.blockchain.isSyncing() chainState := srv.blockchain.chainState() - glog.V(1).Infof("Server._handleNewPeer: Processing NewPeer: (%v); IsSyncCandidate(%v), syncPeerIsNil=(%v), IsSyncing=(%v), ChainState=(%v)", + glog.V(1).Infof("Server.HandleAcceptedPeer: Processing NewPeer: (%v); IsSyncCandidate(%v), "+ + "syncPeerIsNil=(%v), IsSyncing=(%v), ChainState=(%v)", pp, isSyncCandidate, (srv.SyncPeer == nil), isSyncing, chainState) // Request a sync if we're ready @@ -1691,6 +1747,22 @@ func (srv *Server) _handleNewPeer(pp *Peer) { } } +func (srv *Server) maybeRequestAddresses(remoteNode *RemoteNode) { + if remoteNode == nil { + return + } + // If the address manager needs more addresses, then send a GetAddr message + // to the peer. This is best-effort. 
+ if !srv.AddrMgr.NeedMoreAddresses() { + return + } + + if err := remoteNode.SendMessage(&MsgDeSoGetAddr{}); err != nil { + glog.Errorf("Server.maybeRequestAddresses: Problem sending GetAddr message to "+ + "remoteNode (id= %v); err: %v", remoteNode, err) + } +} + func (srv *Server) _cleanupDonePeerState(pp *Peer) { // Grab the dataLock since we'll be modifying requestedBlocks srv.dataLock.Lock() @@ -1755,8 +1827,8 @@ func (srv *Server) _cleanupDonePeerState(pp *Peer) { }, false) } -func (srv *Server) _handleDonePeer(pp *Peer) { - glog.V(1).Infof("Server._handleDonePeer: Processing DonePeer: %v", pp) +func (srv *Server) _handleDisconnectedPeerMessage(pp *Peer) { + glog.V(1).Infof("Server._handleDisconnectedPeerMessage: Processing DonePeer: %v", pp) srv._cleanupDonePeerState(pp) @@ -1799,6 +1871,10 @@ func (srv *Server) _relayTransactions() { // for which the minimum fee is below what the Peer will allow. invMsg := &MsgDeSoInv{} for _, newTxn := range txnList { + if !newTxn.IsValidated() { + continue + } + invVect := &InvVect{ Type: InvTypeTx, Hash: *newTxn.Hash(), @@ -1852,7 +1928,7 @@ func (srv *Server) _addNewTxn( // Only attempt to add the transaction to the PoW mempool if we're on the // PoW protocol. If we're on the PoW protocol, then we use the PoW mempool's, // txn validity checks to signal whether the txn has been added or not. The PoW - // mempool has stricter txn validity checks than the PoW mempool, so this works + // mempool has stricter txn validity checks than the PoS mempool, so this works // out conveniently, as it allows us to always add a txn to the PoS mempool. if srv.params.IsPoWBlockHeight(tipHeight) { _, err := srv.mempool.ProcessTransaction( @@ -1866,8 +1942,8 @@ func (srv *Server) _addNewTxn( // Always add the txn to the PoS mempool. This should always succeed if the txn // addition into the PoW mempool succeeded above. - mempoolTxn := NewMempoolTransaction(txn, time.Now()) - if err := srv.posMempool.AddTransaction(mempoolTxn, true /*verifySignatures*/); err != nil { + mempoolTxn := NewMempoolTransaction(txn, time.Now(), false) + if err := srv.posMempool.AddTransaction(mempoolTxn); err != nil { return nil, errors.Wrapf(err, "Server._addNewTxn: problem adding txn to pos mempool") } @@ -2049,15 +2125,14 @@ func (srv *Server) _handleBlock(pp *Peer, blk *MsgDeSoBlock) { return } - if pp != nil { - if _, exists := pp.requestedBlocks[*blockHash]; !exists { - glog.Errorf("_handleBlock: Getting a block that we haven't requested before, "+ - "block hash (%v)", *blockHash) - } - delete(pp.requestedBlocks, *blockHash) - } else { - glog.Errorf("_handleBlock: Called with nil peer, this should never happen.") + // Log a warning if we receive a block we haven't requested yet. It is still possible to receive + // a block in this case if we're connected directly to the block producer and they send us a block + // directly. + if _, exists := pp.requestedBlocks[*blockHash]; !exists { + glog.Warningf("_handleBlock: Getting a block that we haven't requested before, "+ + "block hash (%v)", *blockHash) } + delete(pp.requestedBlocks, *blockHash) // Check that the mempool has not received a transaction that would forbid this block's signature pubkey. // This is a minimal check, a more thorough check is made in the ProcessBlock function. This check is @@ -2104,20 +2179,28 @@ func (srv *Server) _handleBlock(pp *Peer, blk *MsgDeSoBlock) { // headers comment above but in the future we should probably try and figure // out a way to be more strict about things. 
glog.Warningf("Got duplicate block %v from peer %v", blk, pp) + } else if strings.Contains(err.Error(), RuleErrorFailedSpamPreventionsCheck.Error()) { + // If the block fails the spam prevention check, then it must be signed by the + // bad block proposer signature or it has a bad QC. In either case, we should + // disconnect the peer. + srv._logAndDisconnectPeer(pp, blk, errors.Wrapf(err, "Error while processing block: ").Error()) + return } else { - srv._logAndDisconnectPeer( - pp, blk, - errors.Wrapf(err, "Error while processing block: ").Error()) + // For any other error, we log the error and continue. + glog.Errorf("Server._handleBlock: Error while processing block: %v", err) return } } + if isOrphan { - // We should generally never receive orphan blocks. It indicates something - // went wrong in our headers syncing. - glog.Errorf("ERROR: Received orphan block with hash %v height %v. "+ + // It's possible to receive an orphan block if we're connected directly to the + // block producer, and they are broadcasting blocks in the steady state. We log + // a warning in this case and move on. + glog.Warningf("ERROR: Received orphan block with hash %v height %v. "+ "This should never happen", blockHash, blk.Header.Height) return } + srv.timer.End("Server._handleBlock: Process Block") srv.timer.Print("Server._handleBlock: General") @@ -2125,9 +2208,7 @@ func (srv *Server) _handleBlock(pp *Peer, blk *MsgDeSoBlock) { // We shouldn't be receiving blocks while syncing headers. if srv.blockchain.chainState() == SyncStateSyncingHeaders { - srv._logAndDisconnectPeer( - pp, blk, - "We should never get blocks when we're syncing headers") + glog.Warningf("Server._handleBlock: Received block while syncing headers: %v", blk) return } @@ -2257,7 +2338,7 @@ func (srv *Server) ProcessSingleTxnWithChainLock(pp *Peer, txn *MsgDeSoTxn) ([]* // Regardless of the consensus protocol we're running (PoW or PoS), we use the PoS mempool's to house all // mempool txns. If a txn can't make it into the PoS mempool, which uses a looser unspent balance check for // the the transactor, then it must be invalid. - if err := srv.posMempool.AddTransaction(NewMempoolTransaction(txn, time.Now()), true); err != nil { + if err := srv.posMempool.AddTransaction(NewMempoolTransaction(txn, time.Now(), false)); err != nil { return nil, errors.Wrapf(err, "Server.ProcessSingleTxnWithChainLock: Problem adding transaction to PoS mempool: ") } @@ -2360,20 +2441,33 @@ func (srv *Server) StartStatsdReporter() { }() } -func (srv *Server) _handleAddrMessage(pp *Peer, msg *MsgDeSoAddr) { +func (srv *Server) _handleAddrMessage(pp *Peer, desoMsg DeSoMessage) { + if desoMsg.GetMsgType() != MsgTypeAddr { + return + } + + id := NewRemoteNodeId(pp.ID) + var msg *MsgDeSoAddr + var ok bool + if msg, ok = desoMsg.(*MsgDeSoAddr); !ok { + glog.Errorf("Server._handleAddrMessage: Problem decoding MsgDeSoAddr: %v", spew.Sdump(desoMsg)) + srv.networkManager.DisconnectById(id) + return + } + srv.addrsToBroadcastLock.Lock() defer srv.addrsToBroadcastLock.Unlock() - glog.V(1).Infof("Server._handleAddrMessage: Received Addr from peer %v with addrs %v", pp, spew.Sdump(msg.AddrList)) + glog.V(1).Infof("Server._handleAddrMessage: Received Addr from peer id=%v with addrs %v", pp.ID, spew.Sdump(msg.AddrList)) // If this addr message contains more than the maximum allowed number of addresses // then disconnect this peer. 
 	if len(msg.AddrList) > MaxAddrsPerAddrMsg {
 		glog.Errorf(fmt.Sprintf("Server._handleAddrMessage: Disconnecting "+
-			"Peer %v for sending us an addr message with %d transactions, which exceeds "+
+			"Peer id=%v for sending us an addr message with %d transactions, which exceeds "+
 			"the max allowed %d",
-			pp, len(msg.AddrList), MaxAddrsPerAddrMsg))
-		pp.Disconnect()
+			pp.ID, len(msg.AddrList), MaxAddrsPerAddrMsg))
+		srv.networkManager.DisconnectById(id)
 		return
 	}
@@ -2382,17 +2476,16 @@ func (srv *Server) _handleAddrMessage(pp *Peer, msg *MsgDeSoAddr) {
 	for _, addr := range msg.AddrList {
 		addrAsNetAddr := wire.NewNetAddressIPPort(addr.IP, addr.Port, (wire.ServiceFlag)(addr.Services))
 		if !addrmgr.IsRoutable(addrAsNetAddr) {
-			glog.V(1).Infof("Dropping address %v from peer %v because it is not routable", addr, pp)
+			glog.V(1).Infof("Server._handleAddrMessage: Dropping address %v from peer %v because it is not routable", addr, pp)
 			continue
 		}
 
 		netAddrsReceived = append(
 			netAddrsReceived, addrAsNetAddr)
 	}
-	srv.cmgr.AddrMgr.AddAddresses(netAddrsReceived, pp.netAddr)
+	srv.AddrMgr.AddAddresses(netAddrsReceived, pp.netAddr)
 
-	// If the message had <= 10 addrs in it, then queue all the addresses for relaying
-	// on the next cycle.
+	// If the message had <= 10 addrs in it, then queue all the addresses for relaying on the next cycle.
 	if len(msg.AddrList) <= 10 {
 		glog.V(1).Infof("Server._handleAddrMessage: Queueing %d addrs for forwarding from "+
 			"peer %v", len(msg.AddrList), pp)
@@ -2402,7 +2495,7 @@ func (srv *Server) _handleAddrMessage(pp *Peer, msg *MsgDeSoAddr) {
 			Port:     pp.netAddr.Port,
 			Services: pp.serviceFlags,
 		}
-		listToAddTo, hasSeenSource := srv.addrsToBroadcastt[sourceAddr.StringWithPort(false /*includePort*/)]
+		listToAddTo, hasSeenSource := srv.addrsToBroadcast[sourceAddr.StringWithPort(false /*includePort*/)]
 		if !hasSeenSource {
 			listToAddTo = []*SingleAddr{}
 		}
@@ -2412,15 +2505,30 @@ func (srv *Server) _handleAddrMessage(pp *Peer, msg *MsgDeSoAddr) {
 		listToAddTo = listToAddTo[:MaxAddrsPerAddrMsg/2]
 		}
 		listToAddTo = append(listToAddTo, msg.AddrList...)
-		srv.addrsToBroadcastt[sourceAddr.StringWithPort(false /*includePort*/)] = listToAddTo
+		srv.addrsToBroadcast[sourceAddr.StringWithPort(false /*includePort*/)] = listToAddTo
 	}
 }
 
-func (srv *Server) _handleGetAddrMessage(pp *Peer, msg *MsgDeSoGetAddr) {
+func (srv *Server) _handleGetAddrMessage(pp *Peer, desoMsg DeSoMessage) {
+	if desoMsg.GetMsgType() != MsgTypeGetAddr {
+		return
+	}
+
+	id := NewRemoteNodeId(pp.ID)
+	if _, ok := desoMsg.(*MsgDeSoGetAddr); !ok {
+		glog.Errorf("Server._handleGetAddrMessage: Problem decoding "+
+			"MsgDeSoGetAddr: %v", spew.Sdump(desoMsg))
+		srv.networkManager.DisconnectById(id)
+		return
+	}
+	glog.V(1).Infof("Server._handleGetAddrMessage: Received GetAddr from peer %v", pp)
 	// When we get a GetAddr message, choose MaxAddrsPerMsg from the AddrMgr
 	// and send them back to the peer.
- netAddrsFound := srv.cmgr.AddrMgr.AddressCache() + netAddrsFound := srv.AddrMgr.AddressCache() + if len(netAddrsFound) == 0 { + return + } if len(netAddrsFound) > MaxAddrsPerAddrMsg { netAddrsFound = netAddrsFound[:MaxAddrsPerAddrMsg] } @@ -2436,16 +2544,22 @@ func (srv *Server) _handleGetAddrMessage(pp *Peer, msg *MsgDeSoGetAddr) { } res.AddrList = append(res.AddrList, singleAddr) } - pp.AddDeSoMessage(res, false) + rn := srv.networkManager.GetRemoteNodeById(id) + if err := srv.networkManager.SendMessage(rn, res); err != nil { + glog.Errorf("Server._handleGetAddrMessage: Problem sending addr message to peer %v: %v", pp, err) + srv.networkManager.DisconnectById(id) + return + } } func (srv *Server) _handleControlMessages(serverMessage *ServerMessage) (_shouldQuit bool) { switch serverMessage.Msg.(type) { // Control messages used internally to signal to the server. - case *MsgDeSoNewPeer: - srv._handleNewPeer(serverMessage.Peer) - case *MsgDeSoDonePeer: - srv._handleDonePeer(serverMessage.Peer) + case *MsgDeSoDisconnectedPeer: + srv._handleDisconnectedPeerMessage(serverMessage.Peer) + srv.networkManager._handleDisconnectedPeerMessage(serverMessage.Peer, serverMessage.Msg) + case *MsgDeSoNewConnection: + srv.networkManager._handleNewConnectionMessage(serverMessage.Peer, serverMessage.Msg) case *MsgDeSoQuit: return true } @@ -2457,6 +2571,10 @@ func (srv *Server) _handlePeerMessages(serverMessage *ServerMessage) { // Handle all non-control message types from our Peers. switch msg := serverMessage.Msg.(type) { // Messages sent among peers. + case *MsgDeSoAddr: + srv._handleAddrMessage(serverMessage.Peer, serverMessage.Msg) + case *MsgDeSoGetAddr: + srv._handleGetAddrMessage(serverMessage.Peer, serverMessage.Msg) case *MsgDeSoGetHeaders: srv._handleGetHeaders(serverMessage.Peer, msg) case *MsgDeSoHeaderBundle: @@ -2479,6 +2597,10 @@ func (srv *Server) _handlePeerMessages(serverMessage *ServerMessage) { srv._handleMempool(serverMessage.Peer, msg) case *MsgDeSoInv: srv._handleInv(serverMessage.Peer, msg) + case *MsgDeSoVersion: + srv.networkManager._handleVersionMessage(serverMessage.Peer, serverMessage.Msg) + case *MsgDeSoVerack: + srv.networkManager._handleVerackMessage(serverMessage.Peer, serverMessage.Msg) case *MsgDeSoValidatorVote: srv._handleValidatorVote(serverMessage.Peer, msg) case *MsgDeSoValidatorTimeout: @@ -2559,7 +2681,7 @@ func (srv *Server) _startConsensus() { select { case consensusEvent := <-srv._getFastHotStuffConsensusEventChannel(): { - glog.Infof("Server._startConsensus: Received consensus event: %s", consensusEvent.ToString()) + glog.V(2).Infof("Server._startConsensus: Received consensus event: %s", consensusEvent.ToString()) srv._handleFastHostStuffConsensusEvent(consensusEvent) } @@ -2569,20 +2691,6 @@ func (srv *Server) _startConsensus() { glog.V(2).Infof("Server._startConsensus: Handling message of type %v from Peer %v", serverMessage.Msg.GetMsgType(), serverMessage.Peer) - - // If the message is an addr message we handle it independent of whether or - // not the BitcoinManager is synced. - if serverMessage.Msg.GetMsgType() == MsgTypeAddr { - srv._handleAddrMessage(serverMessage.Peer, serverMessage.Msg.(*MsgDeSoAddr)) - continue - } - // If the message is a GetAddr message we handle it independent of whether or - // not the BitcoinManager is synced. 
- if serverMessage.Msg.GetMsgType() == MsgTypeGetAddr { - srv._handleGetAddrMessage(serverMessage.Peer, serverMessage.Msg.(*MsgDeSoGetAddr)) - continue - } - srv._handlePeerMessages(serverMessage) // Always check for and handle control messages regardless of whether the @@ -2603,35 +2711,36 @@ func (srv *Server) _startConsensus() { glog.V(2).Info("Server.Start: Server done") } -func (srv *Server) _getAddrsToBroadcast() []*SingleAddr { +func (srv *Server) getAddrsToBroadcast() []*SingleAddr { srv.addrsToBroadcastLock.Lock() defer srv.addrsToBroadcastLock.Unlock() // If there's nothing in the map, return. - if len(srv.addrsToBroadcastt) == 0 { + if len(srv.addrsToBroadcast) == 0 { return []*SingleAddr{} } // If we get here then we have some addresses to broadcast. addrsToBroadcast := []*SingleAddr{} - for len(addrsToBroadcast) < 10 && len(srv.addrsToBroadcastt) > 0 { + for uint32(len(addrsToBroadcast)) < srv.params.MaxAddressesToBroadcast && + len(srv.addrsToBroadcast) > 0 { // Choose a key at random. This works because map iteration is random in golang. bucket := "" - for kk := range srv.addrsToBroadcastt { + for kk := range srv.addrsToBroadcast { bucket = kk break } // Remove the last element from the slice for the given bucket. - currentAddrList := srv.addrsToBroadcastt[bucket] + currentAddrList := srv.addrsToBroadcast[bucket] if len(currentAddrList) > 0 { lastIndex := len(currentAddrList) - 1 currentAddr := currentAddrList[lastIndex] currentAddrList = currentAddrList[:lastIndex] if len(currentAddrList) == 0 { - delete(srv.addrsToBroadcastt, bucket) + delete(srv.addrsToBroadcast, bucket) } else { - srv.addrsToBroadcastt[bucket] = currentAddrList + srv.addrsToBroadcast[bucket] = currentAddrList } addrsToBroadcast = append(addrsToBroadcast, currentAddr) @@ -2648,16 +2757,24 @@ func (srv *Server) _startAddressRelayer() { if atomic.LoadInt32(&srv.shutdown) >= 1 { break } - // For the first ten minutes after the server starts, relay our address to all + // For the first ten minutes after the connection controller starts, relay our address to all // peers. After the first ten minutes, do it once every 24 hours. 
- glog.V(1).Infof("Server.Start._startAddressRelayer: Relaying our own addr to peers") + glog.V(1).Infof("Server.startAddressRelayer: Relaying our own addr to peers") + remoteNodes := srv.networkManager.GetAllRemoteNodes().GetAll() if numMinutesPassed < 10 || numMinutesPassed%(RebroadcastNodeAddrIntervalMinutes) == 0 { - for _, pp := range srv.cmgr.GetAllPeers() { - bestAddress := srv.cmgr.AddrMgr.GetBestLocalAddress(pp.netAddr) + for _, rn := range remoteNodes { + if !rn.IsHandshakeCompleted() { + continue + } + netAddr := rn.GetNetAddress() + if netAddr == nil { + continue + } + bestAddress := srv.AddrMgr.GetBestLocalAddress(netAddr) if bestAddress != nil { - glog.V(2).Infof("Server.Start._startAddressRelayer: Relaying address %v to "+ - "peer %v", bestAddress.IP.String(), pp) - pp.AddDeSoMessage(&MsgDeSoAddr{ + glog.V(2).Infof("Server.startAddressRelayer: Relaying address %v to "+ + "RemoteNode (id= %v)", bestAddress.IP.String(), rn.GetId()) + addrMsg := &MsgDeSoAddr{ AddrList: []*SingleAddr{ { Timestamp: time.Now(), @@ -2666,27 +2783,38 @@ func (srv *Server) _startAddressRelayer() { Services: (ServiceFlag)(bestAddress.Services), }, }, - }, false) + } + if err := rn.SendMessage(addrMsg); err != nil { + glog.Errorf("Server.startAddressRelayer: Problem sending "+ + "MsgDeSoAddr to RemoteNode (id= %v): %v", rn.GetId(), err) + } } } } - glog.V(2).Infof("Server.Start._startAddressRelayer: Seeing if there are addrs to relay...") + glog.V(2).Infof("Server.startAddressRelayer: Seeing if there are addrs to relay...") // Broadcast the addrs we have to all of our peers. - addrsToBroadcast := srv._getAddrsToBroadcast() + addrsToBroadcast := srv.getAddrsToBroadcast() if len(addrsToBroadcast) == 0 { - glog.V(2).Infof("Server.Start._startAddressRelayer: No addrs to relay.") + glog.V(2).Infof("Server.startAddressRelayer: No addrs to relay.") time.Sleep(AddrRelayIntervalSeconds * time.Second) continue } - glog.V(2).Infof("Server.Start._startAddressRelayer: Found %d addrs to "+ + glog.V(2).Infof("Server.startAddressRelayer: Found %d addrs to "+ "relay: %v", len(addrsToBroadcast), spew.Sdump(addrsToBroadcast)) // Iterate over all our peers and broadcast the addrs to all of them. - for _, pp := range srv.cmgr.GetAllPeers() { - pp.AddDeSoMessage(&MsgDeSoAddr{ + for _, rn := range remoteNodes { + if !rn.IsHandshakeCompleted() { + continue + } + addrMsg := &MsgDeSoAddr{ AddrList: addrsToBroadcast, - }, false) + } + if err := rn.SendMessage(addrMsg); err != nil { + glog.Errorf("Server.startAddressRelayer: Problem sending "+ + "MsgDeSoAddr to RemoteNode (id= %v): %v", rn.GetId(), err) + } } time.Sleep(AddrRelayIntervalSeconds * time.Second) continue @@ -2723,6 +2851,9 @@ func (srv *Server) Stop() { srv.cmgr.Stop() glog.Infof(CLog(Yellow, "Server.Stop: Closed the ConnectionManger")) + srv.networkManager.Stop() + glog.Infof(CLog(Yellow, "Server.Stop: Closed the NetworkManager")) + // Stop the miner if we have one running. if srv.miner != nil { srv.miner.Stop() @@ -2816,6 +2947,8 @@ func (srv *Server) Start() { go srv.miner.Start() } + srv.networkManager.Start() + // On testnet, if the node is configured to be a PoW block producer, and it is configured // to be also a PoS validator, then we attach block mined listeners to the miner to kick // off the PoS consensus once the miner is done. 
diff --git a/lib/snapshot.go b/lib/snapshot.go index f0db0a511..6d0c8df9e 100644 --- a/lib/snapshot.go +++ b/lib/snapshot.go @@ -468,8 +468,6 @@ func (snap *Snapshot) Run() { operation.blockHeight); err != nil { glog.Errorf("Snapshot.Run: Problem adding snapshot chunk to the db") } - // Free up a slot in the operationQueueSemaphore, now that a chunk has been processed. - <-snap.operationQueueSemaphore case SnapshotOperationChecksumAdd: if err := snap.Checksum.AddOrRemoveBytesWithMigrations(operation.checksumKey, operation.checksumValue, @@ -1345,6 +1343,12 @@ func (snap *Snapshot) SetSnapshotChunk(mainDb *badger.DB, mainDbMutex *deadlock. return nil } +func (snap *Snapshot) FreeOperationQueueSemaphore() { + if len(snap.operationQueueSemaphore) > 0 { + <-snap.operationQueueSemaphore + } +} + // ------------------------------------------------------------------------------------- // StateChecksum // ------------------------------------------------------------------------------------- diff --git a/lib/state_change_syncer.go b/lib/state_change_syncer.go index a5a9f1777..d9d275716 100644 --- a/lib/state_change_syncer.go +++ b/lib/state_change_syncer.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "fmt" + "github.com/deso-protocol/core/collections" "github.com/deso-protocol/go-deadlock" "github.com/golang/glog" "github.com/google/uuid" @@ -258,6 +259,8 @@ type StateChangeSyncer struct { // of each entry, the consumer only has to sync the most recent version of each entry. // BlocksyncCompleteEntriesFlushed is used to track whether this one time flush has been completed. BlocksyncCompleteEntriesFlushed bool + + MempoolTxnSyncLimit uint64 } // Open a file, create if it doesn't exist. @@ -275,7 +278,8 @@ func openOrCreateLogFile(filePath string) (*os.File, error) { } // NewStateChangeSyncer initializes necessary log files and returns a StateChangeSyncer. -func NewStateChangeSyncer(stateChangeDir string, nodeSyncType NodeSyncType) *StateChangeSyncer { +func NewStateChangeSyncer(stateChangeDir string, nodeSyncType NodeSyncType, mempoolTxnSyncLimit uint64, +) *StateChangeSyncer { stateChangeFilePath := filepath.Join(stateChangeDir, StateChangeFileName) stateChangeIndexFilePath := filepath.Join(stateChangeDir, StateChangeIndexFileName) stateChangeMempoolFilePath := filepath.Join(stateChangeDir, StateChangeMempoolFileName) @@ -322,6 +326,7 @@ func NewStateChangeSyncer(stateChangeDir string, nodeSyncType NodeSyncType) *Sta StateSyncerMutex: &sync.Mutex{}, SyncType: nodeSyncType, BlocksyncCompleteEntriesFlushed: blocksyncCompleteEntriesFlushed, + MempoolTxnSyncLimit: mempoolTxnSyncLimit, } } @@ -624,10 +629,9 @@ func createMempoolTxKey(operationType StateSyncerOperationType, keyBytes []byte) // in the mempool state change file. It also loops through all unconnected transactions and their associated // utxo ops and adds them to the mempool state change file. func (stateChangeSyncer *StateChangeSyncer) SyncMempoolToStateSyncer(server *Server) (bool, error) { - originalCommittedFlushId := stateChangeSyncer.BlockSyncFlushId - if server.mempool.stopped { + if !server.GetMempool().IsRunning() { return true, nil } @@ -655,7 +659,9 @@ func (stateChangeSyncer *StateChangeSyncer) SyncMempoolToStateSyncer(server *Ser // Kill the snapshot so that it doesn't affect the original snapshot. 
mempoolUtxoView.Snapshot = nil + server.blockchain.ChainLock.RLock() mempoolUtxoView.TipHash = server.blockchain.bestChain[len(server.blockchain.bestChain)-1].Hash + server.blockchain.ChainLock.RUnlock() // A new transaction is created so that we can simulate writes to the db without actually writing to the db. // Using the transaction here rather than a stubbed badger db allows the process to query the db for any entries @@ -672,10 +678,10 @@ func (stateChangeSyncer *StateChangeSyncer) SyncMempoolToStateSyncer(server *Ser } // Loop through all the transactions in the mempool and connect them and their utxo ops to the mempool view. - server.mempool.mtx.RLock() - mempoolTxns, _, err := server.mempool._getTransactionsOrderedByTimeAdded() - server.mempool.mtx.RUnlock() + mempoolTxns := server.GetMempool().GetOrderedTransactions() + // Get the uncommitted blocks from the chain. + uncommittedBlocks, err := server.blockchain.GetUncommittedFullBlocks(mempoolUtxoView.TipHash) if err != nil { mempoolUtxoView.EventManager.stateSyncerFlushed(&StateSyncerFlushedEvent{ FlushId: uuid.Nil, @@ -685,13 +691,74 @@ func (stateChangeSyncer *StateChangeSyncer) SyncMempoolToStateSyncer(server *Ser return false, errors.Wrapf(err, "StateChangeSyncer.SyncMempoolToStateSyncer: ") } - currentTimestamp := time.Now().UnixNano() - for _, mempoolTx := range mempoolTxns { - utxoOpsForTxn, _, _, _, err := mempoolTxUtxoView.ConnectTransaction( - mempoolTx.Tx, mempoolTx.Hash, uint32(blockHeight+1), - currentTimestamp, false, false /*ignoreUtxos*/) + // First connect the uncommitted blocks to the mempool view. + for _, uncommittedBlock := range uncommittedBlocks { + var utxoOpsForBlock [][]*UtxoOperation + txHashes := collections.Transform(uncommittedBlock.Txns, func(txn *MsgDeSoTxn) *BlockHash { + return txn.Hash() + }) + // TODO: there is a slight performance enhancement we could make here + // by rewriting the ConnectBlock logic to avoid unnecessary UtxoView copying + // for failing transactions. However, we'd also need to rewrite the end-of-epoch + // logic here which would make this function a bit long. + // Connect this block to the mempoolTxUtxoView so we can get the utxo ops. + utxoOpsForBlock, err = mempoolTxUtxoView.ConnectBlock( + uncommittedBlock, txHashes, false, nil, uncommittedBlock.Header.Height) if err != nil { - return false, errors.Wrapf(err, "StateChangeSyncer.SyncMempoolToStateSyncer ConnectTransaction: ") + mempoolUtxoView.EventManager.stateSyncerFlushed(&StateSyncerFlushedEvent{ + FlushId: uuid.Nil, + Succeeded: false, + IsMempoolFlush: true, + }) + return false, errors.Wrapf(err, "StateChangeSyncer.SyncMempoolToStateSyncer ConnectBlock uncommitted block: ") + } + blockHash, _ := uncommittedBlock.Hash() + // Emit the UtxoOps event. 
+ mempoolUtxoView.EventManager.stateSyncerOperation(&StateSyncerOperationEvent{ + StateChangeEntry: &StateChangeEntry{ + OperationType: DbOperationTypeUpsert, + KeyBytes: _DbKeyForUtxoOps(blockHash), + EncoderBytes: EncodeToBytes(blockHeight, &UtxoOperationBundle{ + UtxoOpBundle: utxoOpsForBlock, + }, false), + Block: uncommittedBlock, + }, + FlushId: uuid.Nil, + IsMempoolTxn: true, + }) + } + + currentTimestamp := time.Now().UnixNano() + for ii, mempoolTx := range mempoolTxns { + if server.params.IsPoSBlockHeight(blockHeight) && uint64(ii) > stateChangeSyncer.MempoolTxnSyncLimit { + break + } + var utxoOpsForTxn []*UtxoOperation + if server.params.IsPoSBlockHeight(blockHeight + 1) { + // We need to create a copy of the view in the event that the transaction fails to + // connect. If it fails to connect, we need to reset the view to its original state. + // and try to connect it as a failing transaction. If that fails as well, we just continue + // and the mempoolTxUtxoView is unmodified. + var copiedView *UtxoView + copiedView, err = mempoolTxUtxoView.CopyUtxoView() + if err != nil { + return false, errors.Wrapf(err, "StateChangeSyncer.SyncMempoolToStateSyncer CopyUtxoView: ") + } + utxoOpsForTxn, _, _, _, err = copiedView.ConnectTransaction( + mempoolTx.Tx, mempoolTx.Hash, uint32(blockHeight+1), + currentTimestamp, false, false /*ignoreUtxos*/) + // If the transaction successfully connected, we update mempoolTxUtxoView to the copied view. + if err == nil { + mempoolTxUtxoView = copiedView + } + } else { + // For PoW block heights, we can just connect the transaction to the mempool view. + utxoOpsForTxn, _, _, _, err = mempoolTxUtxoView.ConnectTransaction( + mempoolTx.Tx, mempoolTx.Hash, uint32(blockHeight+1), + currentTimestamp, false, false /*ignoreUtxos*/) + if err != nil { + return false, errors.Wrapf(err, "StateChangeSyncer.SyncMempoolToStateSyncer ConnectTransaction: ") + } } // Emit transaction state change. @@ -747,7 +814,7 @@ func (stateChangeSyncer *StateChangeSyncer) SyncMempoolToStateSyncer(server *Ser func (stateChangeSyncer *StateChangeSyncer) StartMempoolSyncRoutine(server *Server) { go func() { // Wait for mempool to be initialized. - for server.mempool == nil || server.blockchain.chainState() != SyncStateFullyCurrent { + for server.GetMempool() == nil || server.blockchain.chainState() != SyncStateFullyCurrent { time.Sleep(15000 * time.Millisecond) } if !stateChangeSyncer.BlocksyncCompleteEntriesFlushed && stateChangeSyncer.SyncType == NodeSyncTypeBlockSync { @@ -757,7 +824,7 @@ func (stateChangeSyncer *StateChangeSyncer) StartMempoolSyncRoutine(server *Serv fmt.Printf("StateChangeSyncer.StartMempoolSyncRoutine: Error flushing all entries to file: %v", err) } } - mempoolClosed := server.mempool.stopped + mempoolClosed := !server.GetMempool().IsRunning() for !mempoolClosed { // Sleep for a short while to avoid a tight loop. 
time.Sleep(100 * time.Millisecond) diff --git a/lib/txindex.go b/lib/txindex.go index 61d0a7966..f1f805235 100644 --- a/lib/txindex.go +++ b/lib/txindex.go @@ -149,7 +149,11 @@ func NewTXIndex(coreChain *Blockchain, params *DeSoParams, dataDirectory string) } func (txi *TXIndex) FinishedSyncing() bool { - return txi.TXIndexChain.BlockTip().Height == txi.CoreChain.BlockTip().Height + committedTip, idx := txi.CoreChain.GetCommittedTip() + if idx == -1 { + return false + } + return txi.TXIndexChain.BlockTip().Height == committedTip.Height } func (txi *TXIndex) Start() { @@ -224,7 +228,7 @@ func (txi *TXIndex) GetTxindexUpdateBlockNodes() ( txindexTipNode := blockIndexByHashCopy[*txindexTipHash.Hash] // Get the committed tip. - committedTip, _ := txi.CoreChain.getCommittedTip() + committedTip, _ := txi.CoreChain.GetCommittedTip() if txindexTipNode == nil { glog.Info("GetTxindexUpdateBlockNodes: Txindex tip was not found; building txindex starting at genesis block") @@ -408,10 +412,16 @@ func (txi *TXIndex) Update() error { return fmt.Errorf( "Update: Error initializing UtxoView: %v", err) } + if blockToAttach.Header.PrevBlockHash != nil { + utxoView, err = txi.TXIndexChain.getUtxoViewAtBlockHash(*blockToAttach.Header.PrevBlockHash) + if err != nil { + return fmt.Errorf("Update: Problem getting UtxoView at block hash %v: %v", + blockToAttach.Header.PrevBlockHash, err) + } + } // Do each block update in a single transaction so we're safe in case the node // restarts. - blockHeight := uint64(txi.CoreChain.BlockTip().Height) err = txi.TXIndexChain.DB().Update(func(dbTxn *badger.Txn) error { // Iterate through each transaction in the block and do the following: @@ -421,13 +431,13 @@ func (txi *TXIndex) Update() error { for txnIndexInBlock, txn := range blockMsg.Txns { txnMeta, err := ConnectTxnAndComputeTransactionMetadata( txn, utxoView, blockToAttach.Hash, blockToAttach.Height, - int64(blockToAttach.Header.TstampNanoSecs), uint64(txnIndexInBlock)) + blockToAttach.Header.TstampNanoSecs, uint64(txnIndexInBlock)) if err != nil { return fmt.Errorf("Update: Problem connecting txn %v to txindex: %v", txn, err) } - err = DbPutTxindexTransactionMappingsWithTxn(dbTxn, nil, blockHeight, + err = DbPutTxindexTransactionMappingsWithTxn(dbTxn, nil, blockMsg.Header.Height, txn, txi.Params, txnMeta, txi.CoreChain.eventManager) if err != nil { return fmt.Errorf("Update: Problem adding txn %v to txindex: %v", @@ -454,3 +464,26 @@ func (txi *TXIndex) Update() error { return nil } + +func ConnectTxnAndComputeTransactionMetadata( + txn *MsgDeSoTxn, utxoView *UtxoView, blockHash *BlockHash, + blockHeight uint32, blockTimestampNanoSecs int64, txnIndexInBlock uint64) (*TransactionMetadata, error) { + + totalNanosPurchasedBefore := utxoView.NanosPurchased + usdCentsPerBitcoinBefore := utxoView.GetCurrentUSDCentsPerBitcoin() + + var utxoOps []*UtxoOperation + var totalInput, totalOutput, fees uint64 + var err error + utxoOps, totalInput, totalOutput, fees, err = utxoView._connectTransaction( + txn, txn.Hash(), blockHeight, blockTimestampNanoSecs, false, false, + ) + + if err != nil { + return nil, fmt.Errorf( + "UpdateTxindex: Error connecting txn to UtxoView: %v", err) + } + + return ComputeTransactionMetadata(txn, utxoView, blockHash, totalNanosPurchasedBefore, + usdCentsPerBitcoinBefore, totalInput, totalOutput, fees, txnIndexInBlock, utxoOps, uint64(blockHeight)), nil +}
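On the snapshot.go change above: the new FreeOperationQueueSemaphore releases a slot in the operation-queue semaphore only when one is currently held. Below is a standalone sketch of that behavior expressed with Go's non-blocking receive idiom (select with a default case), under the assumption that the semaphore is an ordinary buffered chan struct{} as in the diff.

package main

import "fmt"

// freeSlot performs a non-blocking receive on a buffered channel used as a
// counting semaphore: it releases one slot if any is held and returns
// immediately otherwise.
func freeSlot(sem chan struct{}) {
	select {
	case <-sem:
	default:
	}
}

func main() {
	sem := make(chan struct{}, 2)
	sem <- struct{}{} // acquire one slot
	freeSlot(sem)     // releases the held slot
	freeSlot(sem)     // nothing held; returns immediately instead of blocking
	fmt.Println("slots in use:", len(sem))
}
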
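On the state_change_syncer.go change above: for PoS block heights, SyncMempoolToStateSyncer connects each mempool transaction against a copy of the working view and only adopts the copy when the connect succeeds, so a failing transaction leaves the accumulated view untouched. A simplified standalone sketch of that copy-then-adopt pattern follows; toyView and its methods are illustrative stand-ins for UtxoView, CopyUtxoView, and ConnectTransaction.

package main

import (
	"errors"
	"fmt"
)

// toyView is an illustrative stand-in for a UtxoView: it only records which
// transactions have been applied to it.
type toyView struct{ applied []string }

func (v *toyView) clone() *toyView {
	cp := &toyView{applied: make([]string, len(v.applied))}
	copy(cp.applied, v.applied)
	return cp
}

func (v *toyView) connect(txn string) error {
	if txn == "bad-txn" {
		return errors.New("txn failed to connect")
	}
	v.applied = append(v.applied, txn)
	return nil
}

func main() {
	view := &toyView{}
	for _, txn := range []string{"txn-1", "bad-txn", "txn-2"} {
		// Connect against a copy and only adopt the copy on success, so a
		// failing transaction cannot leave the working view half-modified.
		copied := view.clone()
		if err := copied.connect(txn); err == nil {
			view = copied
		}
	}
	fmt.Println("applied:", view.applied) // applied: [txn-1 txn-2]
}
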
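On the txindex.go change above: FinishedSyncing now compares the txindex tip height against the core chain's committed tip from GetCommittedTip, and reports not-synced whenever the committed tip index is -1. A minimal sketch of that guard follows, using hypothetical stand-in types rather than the real BlockNode and Blockchain.

package main

import "fmt"

type blockNode struct{ Height uint64 }

// committedTip returns the committed tip and its index, or (nil, -1) when the
// chain has no committed blocks yet, matching how GetCommittedTip is used by
// the txindex in the diff above.
func committedTip(committed []*blockNode) (*blockNode, int) {
	if len(committed) == 0 {
		return nil, -1
	}
	idx := len(committed) - 1
	return committed[idx], idx
}

// finishedSyncing reports whether the txindex tip has caught up to the core
// chain's committed tip; it is false whenever no committed tip exists yet.
func finishedSyncing(txindexTipHeight uint64, committed []*blockNode) bool {
	tip, idx := committedTip(committed)
	if idx == -1 {
		return false
	}
	return txindexTipHeight == tip.Height
}

func main() {
	chain := []*blockNode{{Height: 0}, {Height: 1}, {Height: 2}}
	fmt.Println(finishedSyncing(2, chain)) // true
	fmt.Println(finishedSyncing(1, chain)) // false
	fmt.Println(finishedSyncing(0, nil))   // false: no committed tip yet
}
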