diff --git a/rskj-core/src/main/java/co/rsk/RskContext.java b/rskj-core/src/main/java/co/rsk/RskContext.java index 1ac4727b8d..7d6ecd7aec 100644 --- a/rskj-core/src/main/java/co/rsk/RskContext.java +++ b/rskj-core/src/main/java/co/rsk/RskContext.java @@ -2089,6 +2089,7 @@ private SnapshotProcessor getSnapshotProcessor() { new ValidGasUsedRule() ), getRskSystemProperties().getSnapshotChunkSize(), + getRskSystemProperties().checkHistoricalHeaders(), getRskSystemProperties().isSnapshotParallelEnabled() ); } diff --git a/rskj-core/src/main/java/co/rsk/config/RskSystemProperties.java b/rskj-core/src/main/java/co/rsk/config/RskSystemProperties.java index 9e2da036ee..a868152b51 100644 --- a/rskj-core/src/main/java/co/rsk/config/RskSystemProperties.java +++ b/rskj-core/src/main/java/co/rsk/config/RskSystemProperties.java @@ -73,6 +73,7 @@ public class RskSystemProperties extends SystemProperties { public static final String USE_PEERS_FROM_LAST_SESSION = "peer.discovery.usePeersFromLastSession"; public static final String PROPERTY_SNAP_CLIENT_ENABLED = "sync.snapshot.client.enabled"; + public static final String PROPERTY_SNAP_CLIENT_CHECK_HISTORICAL_HEADERS = "sync.snapshot.client.checkHistoricalHeaders"; public static final String PROPERTY_SNAP_NODES = "sync.snapshot.client.snapBootNodes"; //TODO: REMOVE THIS WHEN THE LocalBLockTests starts working with REMASC @@ -429,6 +430,8 @@ public int getLongSyncLimit() { public boolean isServerSnapshotSyncEnabled() { return configFromFiles.getBoolean("sync.snapshot.server.enabled");} public boolean isClientSnapshotSyncEnabled() { return configFromFiles.getBoolean(PROPERTY_SNAP_CLIENT_ENABLED);} + public boolean checkHistoricalHeaders() { return configFromFiles.getBoolean(PROPERTY_SNAP_CLIENT_CHECK_HISTORICAL_HEADERS);} + public boolean isSnapshotParallelEnabled() { return configFromFiles.getBoolean("sync.snapshot.client.parallel");} public int getSnapshotChunkSize() { return configFromFiles.getInt("sync.snapshot.client.chunkSize");} @@ -512,10 +515,14 @@ public boolean fastBlockPropagation() { return configFromFiles.getBoolean("peer.fastBlockPropagation"); } - public Integer getMessageQueueMaxSize() { + public int getMessageQueueMaxSize() { return configFromFiles.getInt("peer.messageQueue.maxSizePerPeer"); } + public int getMessageQueuePerMinuteThreshold() { + return configFromFiles.getInt("peer.messageQueue.thresholdPerMinutePerPeer"); + } + public boolean rpcZeroSignatureIfRemasc() { return configFromFiles.getBoolean("rpc.zeroSignatureIfRemasc"); } diff --git a/rskj-core/src/main/java/co/rsk/net/NodeMessageHandler.java b/rskj-core/src/main/java/co/rsk/net/NodeMessageHandler.java index 31825bd67d..66d0c15fd2 100644 --- a/rskj-core/src/main/java/co/rsk/net/NodeMessageHandler.java +++ b/rskj-core/src/main/java/co/rsk/net/NodeMessageHandler.java @@ -211,7 +211,7 @@ private void tryAddMessage(Peer sender, Message message, NodeMsgTraceInfo nodeMs */ private boolean controlMessageIngress(Peer sender, Message message, double score) { return - allowByScore(score) && + allowByScore(sender, message, score) && allowByMessageCount(sender) && allowByMinerNotBanned(sender, message) && allowByMessageUniqueness(sender, message); // prevent repeated is the most expensive and MUST be the last @@ -221,8 +221,13 @@ private boolean controlMessageIngress(Peer sender, Message message, double score /** * assert score is acceptable */ - private boolean allowByScore(double score) { - return score >= 0; + private boolean allowByScore(Peer sender, Message message, double score) { + boolean allow = score >= 0; + if (!allow) { + logger.debug("Message: [{}] from: [{}] with score: [{}] was not allowed", message.getMessageType(), sender, score); + } + + return allow; } /** diff --git a/rskj-core/src/main/java/co/rsk/net/SnapshotProcessor.java b/rskj-core/src/main/java/co/rsk/net/SnapshotProcessor.java index dddb79f1fe..6471779a73 100644 --- a/rskj-core/src/main/java/co/rsk/net/SnapshotProcessor.java +++ b/rskj-core/src/main/java/co/rsk/net/SnapshotProcessor.java @@ -53,9 +53,10 @@ import java.util.*; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.atomic.AtomicLong; import java.util.stream.Collectors; +import static co.rsk.net.sync.SnapSyncRequestManager.PeerSelector; + /** * Snapshot Synchronization consist in 3 steps: * 1. Status: exchange message with the server, to know which block we are going to sync and what the size of the Unitrie of that block is. @@ -86,8 +87,7 @@ public class SnapshotProcessor implements InternalService { private final BlockHeaderParentDependantValidationRule blockHeaderParentValidator; private final BlockHeaderValidationRule blockHeaderValidator; - private final AtomicLong messageId = new AtomicLong(0); - + private final boolean checkHistoricalHeaders; // flag for parallel requests private final boolean parallel; @@ -106,10 +106,11 @@ public SnapshotProcessor(Blockchain blockchain, BlockHeaderParentDependantValidationRule blockHeaderParentValidator, BlockHeaderValidationRule blockHeaderValidator, int chunkSize, + boolean checkHistoricalHeaders, boolean isParallelEnabled) { this(blockchain, trieStore, peersInformation, blockStore, transactionPool, blockParentValidator, blockValidator, blockHeaderParentValidator, blockHeaderValidator, - chunkSize, isParallelEnabled, null); + chunkSize, checkHistoricalHeaders, isParallelEnabled, null); } @VisibleForTesting @@ -123,6 +124,7 @@ public SnapshotProcessor(Blockchain blockchain, BlockHeaderParentDependantValidationRule blockHeaderParentValidator, BlockHeaderValidationRule blockHeaderValidator, int chunkSize, + boolean checkHistoricalHeaders, boolean isParallelEnabled, @Nullable SyncMessageHandler.Listener listener) { this.blockchain = blockchain; @@ -138,6 +140,7 @@ public SnapshotProcessor(Blockchain blockchain, this.blockHeaderParentValidator = blockHeaderParentValidator; this.blockHeaderValidator = blockHeaderValidator; + this.checkHistoricalHeaders = checkHistoricalHeaders; this.parallel = isParallelEnabled; this.thread = new Thread(new SyncMessageHandler("SNAP/server", requestQueue, listener) { @@ -157,7 +160,7 @@ public void startSyncing(SnapSyncState state) { } logger.info("Starting Snap sync"); - requestSnapStatus(bestPeerOpt.get()); + requestSnapStatus(state, bestPeerOpt.get()); } private void completeSyncing(SnapSyncState state) { @@ -177,9 +180,22 @@ private void failSyncing(SnapSyncState state, Peer peer, EventType eventType, St /** * STATUS */ - private void requestSnapStatus(Peer peer) { - SnapStatusRequestMessage message = new SnapStatusRequestMessage(); - peer.sendMessage(message); + private void requestSnapStatus(SnapSyncState state, Peer peer) { + state.submitRequest(snapPeerSelector(peer), SnapStatusRequestMessage::new); + } + + private PeerSelector peerSelector(@Nullable Peer peer) { + return PeerSelector.builder() + .withDefaultPeer(() -> peer) + .withAltPeer(peersInformation::getBestPeer) + .build(); + } + + private PeerSelector snapPeerSelector(@Nullable Peer snapPeer) { + return PeerSelector.builder() + .withDefaultPeer(() -> snapPeer) + .withAltPeer(peersInformation::getBestSnapPeer) + .build(); } public void processSnapStatusRequest(Peer sender, SnapStatusRequestMessage requestMessage) { @@ -201,7 +217,7 @@ public void run() { } } - void processSnapStatusRequestInternal(Peer sender, SnapStatusRequestMessage ignoredRequestMessage) { + void processSnapStatusRequestInternal(Peer sender, SnapStatusRequestMessage requestMessage) { long bestBlockNumber = blockchain.getBestBlock().getNumber(); long checkpointBlockNumber = bestBlockNumber - (bestBlockNumber % BLOCK_NUMBER_CHECKPOINT); logger.debug("Processing snapshot status request, checkpointBlockNumber: {}, bestBlockNumber: {}", checkpointBlockNumber, bestBlockNumber); @@ -226,7 +242,7 @@ void processSnapStatusRequestInternal(Peer sender, SnapStatusRequestMessage igno long trieSize = opt.get().getTotalSize(); logger.debug("Processing snapshot status request - rootHash: {} trieSize: {}", rootHash, trieSize); - SnapStatusResponseMessage responseMessage = new SnapStatusResponseMessage(blocks, difficulties, trieSize); + SnapStatusResponseMessage responseMessage = new SnapStatusResponseMessage(requestMessage.getId(), blocks, difficulties, trieSize); sender.sendMessage(responseMessage); } @@ -261,7 +277,7 @@ public void processSnapStatusResponse(SnapSyncState state, Peer sender, SnapStat generateChunkRequestTasks(state); startRequestingChunks(state); } else { - requestBlocksChunk(sender, blocksFromResponse.get(0).getNumber()); + requestBlocksChunk(state, blocksFromResponse.get(0).getNumber()); } } @@ -318,9 +334,11 @@ private boolean areBlockPairsValid(Pair blockPair, @Null /** * BLOCK CHUNK */ - private void requestBlocksChunk(Peer sender, long blockNumber) { - logger.debug("Requesting block chunk to node {} - block {}", sender.getPeerNodeID(), blockNumber); - sender.sendMessage(new SnapBlocksRequestMessage(blockNumber)); + private void requestBlocksChunk(SnapSyncState state, long blockNumber) { + state.submitRequest( + peerSelector(null), + messageId -> new SnapBlocksRequestMessage(messageId, blockNumber) + ); } public void processBlockHeaderChunk(SnapSyncState state, Peer sender, List chunk) { @@ -405,7 +423,10 @@ private void requestNextBlockHeadersChunk(SnapSyncState state, Peer sender) { logger.debug("Requesting block header chunk to node {} - block [{}/{}]", peer.getPeerNodeID(), lastVerifiedBlockHeader.getNumber() - 1, parentHash); - state.getSyncEventsHandler().sendBlockHeadersRequest(peer, new ChunkDescriptor(parentHash.getBytes(), (int) count)); + state.submitRequest( + peerSelector(sender), + messageId -> new BlockHeadersRequestMessage(messageId, parentHash.getBytes(), (int) count) + ); } public void processSnapBlocksRequest(Peer sender, SnapBlocksRequestMessage requestMessage) { @@ -447,7 +468,7 @@ void processSnapBlocksRequestInternal(Peer sender, SnapBlocksRequestMessage requ difficulties.add(blockStore.getTotalDifficultyForHash(block.getHash().getBytes())); } logger.debug("Sending snap blocks response. From block {} to block {} - chunksize {}", blocks.get(0).getNumber(), blocks.get(blocks.size() - 1).getNumber(), BLOCK_CHUNK_SIZE); - SnapBlocksResponseMessage responseMessage = new SnapBlocksResponseMessage(blocks, difficulties); + SnapBlocksResponseMessage responseMessage = new SnapBlocksResponseMessage(requestMessage.getId(), blocks, difficulties); sender.sendMessage(responseMessage); } @@ -474,7 +495,12 @@ public void processSnapBlocksResponse(SnapSyncState state, Peer sender, SnapBloc generateChunkRequestTasks(state); startRequestingChunks(state); } else if (nextChunk > lastRequiredBlock) { - requestBlocksChunk(sender, nextChunk); + requestBlocksChunk(state, nextChunk); + } else if (!this.checkHistoricalHeaders) { + logger.info("Finished Snap blocks request sending. Start requesting state chunks without historical headers check"); + + generateChunkRequestTasks(state); + startRequestingChunks(state); } else { logger.info("Finished Snap blocks request sending. Start requesting state chunks and block headers"); @@ -488,10 +514,11 @@ public void processSnapBlocksResponse(SnapSyncState state, Peer sender, SnapBloc /** * STATE CHUNK */ - private void requestStateChunk(Peer peer, long from, long blockNumber, int chunkSize) { - logger.debug("Requesting state chunk to node {} - block {} - chunkNumber {}", peer.getPeerNodeID(), blockNumber, from / chunkSize); - SnapStateChunkRequestMessage message = new SnapStateChunkRequestMessage(messageId.getAndIncrement(), blockNumber, from, chunkSize); - peer.sendMessage(message); + private void requestStateChunk(SnapSyncState state, Peer peer, long from, long blockNumber, int chunkSize) { + state.submitRequest( + snapPeerSelector(peer), + messageId -> new SnapStateChunkRequestMessage(messageId, blockNumber, from, chunkSize) + ); } public void processStateChunkRequest(Peer sender, SnapStateChunkRequestMessage requestMessage) { @@ -573,7 +600,7 @@ public void processStateChunkResponse(SnapSyncState state, Peer peer, SnapStateC state.setNextExpectedFrom(nextExpectedFrom + chunkSize * CHUNK_ITEM_SIZE); } catch (Exception e) { logger.error("Error while processing chunk response. {}", e.getMessage(), e); - onStateChunkResponseError(peer, nextMessage); + onStateChunkResponseError(state, peer, nextMessage); } } else { break; @@ -587,19 +614,18 @@ public void processStateChunkResponse(SnapSyncState state, Peer peer, SnapStateC } @VisibleForTesting - void onStateChunkResponseError(Peer peer, SnapStateChunkResponseMessage responseMessage) { + void onStateChunkResponseError(SnapSyncState state, Peer peer, SnapStateChunkResponseMessage responseMessage) { logger.error("Error while processing chunk response from {} of peer {}. Asking for chunk again.", responseMessage.getFrom(), peer.getPeerNodeID()); Peer alternativePeer = peersInformation.getBestSnapPeerCandidates().stream() .filter(listedPeer -> !listedPeer.getPeerNodeID().equals(peer.getPeerNodeID())) .findFirst() .orElse(peer); logger.debug("Requesting state chunk \"from\" {} to peer {}", responseMessage.getFrom(), peer.getPeerNodeID()); - requestStateChunk(alternativePeer, responseMessage.getFrom(), responseMessage.getBlockNumber(), chunkSize); + requestStateChunk(state, alternativePeer, responseMessage.getFrom(), responseMessage.getBlockNumber(), chunkSize); } private void processOrderedStateChunkResponse(SnapSyncState state, Peer peer, SnapStateChunkResponseMessage message) throws Exception { logger.debug("Processing State chunk received from {} to {}", message.getFrom(), message.getTo()); - peersInformation.getOrRegisterPeer(peer); RLPList nodeLists = RLP.decodeList(message.getChunkOfTrieKeyValue()); final RLPList preRootElements = RLP.decodeList(nodeLists.get(0).getRLPData()); @@ -648,10 +674,8 @@ private void processOrderedStateChunkResponse(SnapSyncState state, Peer peer, Sn state.getAllNodes().addAll(nodes); state.setStateSize(state.getStateSize().add(BigInteger.valueOf(trieElements.size()))); state.setStateChunkSize(state.getStateChunkSize().add(BigInteger.valueOf(message.getChunkOfTrieKeyValue().length))); - if (!message.isComplete()) { - executeNextChunkRequestTask(state, peer); - } else { - if (blocksVerified(state)) { + if (message.isComplete()) { + if (!this.checkHistoricalHeaders || blocksVerified(state)) { completeSyncing(state); } else { state.setStateFetched(); @@ -716,7 +740,7 @@ private void executeNextChunkRequestTask(SnapSyncState state, Peer peer) { if (!taskQueue.isEmpty()) { ChunkTask task = taskQueue.poll(); - requestStateChunk(peer, task.getFrom(), task.getBlockNumber(), chunkSize); + requestStateChunk(state, peer, task.getFrom(), task.getBlockNumber(), chunkSize); } else { logger.warn("No more chunk request tasks."); } diff --git a/rskj-core/src/main/java/co/rsk/net/SyncProcessor.java b/rskj-core/src/main/java/co/rsk/net/SyncProcessor.java index 0c5c2477ea..8fa931ec2e 100644 --- a/rskj-core/src/main/java/co/rsk/net/SyncProcessor.java +++ b/rskj-core/src/main/java/co/rsk/net/SyncProcessor.java @@ -41,6 +41,7 @@ import java.time.Duration; import java.util.*; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; /** * This class' methods are executed one at a time because NodeMessageHandler is synchronized. @@ -69,6 +70,7 @@ public class SyncProcessor implements SyncEventsHandler { private final Map pendingMessages; private final AtomicBoolean isSyncing = new AtomicBoolean(); private final SnapshotProcessor snapshotProcessor; + private final AtomicLong lastRequestId = new AtomicLong(); private volatile long initialBlockNumber; private volatile long highestBlockNumber; @@ -76,7 +78,6 @@ public class SyncProcessor implements SyncEventsHandler { private volatile long lastDelayWarn = System.currentTimeMillis(); private SyncState syncState; - private long lastRequestId; @VisibleForTesting public SyncProcessor(Blockchain blockchain, @@ -119,7 +120,7 @@ public SyncProcessor(Blockchain blockchain, this.difficultyRule = new DifficultyRule(difficultyCalculator); this.genesis = genesis; this.ethereumListener = ethereumListener; - this.pendingMessages = new LinkedHashMap() { + this.pendingMessages = Collections.synchronizedMap(new LinkedHashMap<>() { @Override protected boolean removeEldestEntry(Map.Entry eldest) { boolean shouldDiscard = size() > MAX_PENDING_MESSAGES; @@ -128,7 +129,7 @@ protected boolean removeEldestEntry(Map.Entry eldest) { } return shouldDiscard; } - }; + }); this.peersInformation = peersInformation; this.snapshotProcessor = snapshotProcessor; @@ -178,7 +179,7 @@ public void processBlockHeadersResponse(Peer peer, BlockHeadersResponseMessage m MessageType messageType = message.getMessageType(); if (isPending(messageId, messageType)) { removePendingMessage(messageId, messageType); - syncState.newBlockHeaders(peer, message.getBlockHeaders()); + syncState.newBlockHeaders(peer, message); } else { notifyUnexpectedMessageToPeerScoring(peer, "block headers"); } @@ -205,7 +206,7 @@ public void processNewBlockHash(Peer peer, NewBlockHashMessage message) { if (syncState instanceof PeerAndModeDecidingSyncState && blockSyncService.getBlockFromStoreOrBlockchain(hash) == null) { peersInformation.getOrRegisterPeer(peer); - sendMessage(peer, new BlockRequestMessage(++lastRequestId, hash)); + sendMessage(peer, new BlockRequestMessage(nextMessageId(), hash)); } } @@ -224,29 +225,56 @@ public void processBlockResponse(Peer peer, BlockResponseMessage message) { } } - public void processSnapStatusResponse(Peer sender, SnapStatusResponseMessage responseMessage) { - syncState.onSnapStatus(sender, responseMessage); + public void processSnapStatusResponse(Peer peer, SnapStatusResponseMessage responseMessage) { + peersInformation.getOrRegisterPeer(peer); + + long messageId = responseMessage.getId(); + MessageType messageType = responseMessage.getMessageType(); + if (isPending(messageId, messageType)) { + removePendingMessage(messageId, messageType); + syncState.onSnapStatus(peer, responseMessage); + } else { + notifyUnexpectedMessageToPeerScoring(peer, "snap status"); + } } - public void processSnapBlocksResponse(Peer sender, SnapBlocksResponseMessage responseMessage) { - syncState.onSnapBlocks(sender, responseMessage); + public void processSnapBlocksResponse(Peer peer, SnapBlocksResponseMessage responseMessage) { + peersInformation.getOrRegisterPeer(peer); + + long messageId = responseMessage.getId(); + MessageType messageType = responseMessage.getMessageType(); + if (isPending(messageId, messageType)) { + removePendingMessage(messageId, messageType); + syncState.onSnapBlocks(peer, responseMessage); + } else { + notifyUnexpectedMessageToPeerScoring(peer, "snap blocks"); + } } public void processStateChunkResponse(Peer peer, SnapStateChunkResponseMessage responseMessage) { - syncState.onSnapStateChunk(peer, responseMessage); + peersInformation.getOrRegisterPeer(peer); + + long messageId = responseMessage.getId(); + MessageType messageType = responseMessage.getMessageType(); + if (isPending(messageId, messageType)) { + removePendingMessage(messageId, messageType); + syncState.onSnapStateChunk(peer, responseMessage); + } else { + notifyUnexpectedMessageToPeerScoring(peer, "snap state chunk"); + } } @Override public void sendSkeletonRequest(Peer peer, long height) { logger.debug("Send skeleton request to node {} height {}", peer.getPeerNodeID(), height); - MessageWithId message = new SkeletonRequestMessage(++lastRequestId, height); + MessageWithId message = new SkeletonRequestMessage(nextMessageId(), height); sendMessage(peer, message); } @Override public void sendBlockHashRequest(Peer peer, long height) { logger.debug("Send hash request to node {} height {}", peer.getPeerNodeID(), height); - BlockHashRequestMessage message = new BlockHashRequestMessage(++lastRequestId, height); + BlockHashRequestMessage message = new BlockHashRequestMessage(nextMessageId(), height); sendMessage(peer, message); } @@ -255,7 +283,7 @@ public void sendBlockHeadersRequest(Peer peer, ChunkDescriptor chunk) { logger.debug("Send headers request to node {}", peer.getPeerNodeID()); BlockHeadersRequestMessage message = - new BlockHeadersRequestMessage(++lastRequestId, chunk.getHash(), chunk.getCount()); + new BlockHeadersRequestMessage(nextMessageId(), chunk.getHash(), chunk.getCount()); sendMessage(peer, message); } @@ -264,7 +292,7 @@ public long sendBodyRequest(Peer peer, @Nonnull BlockHeader header) { logger.debug("Send body request block {} hash {} to peer {}", header.getNumber(), HashUtil.toPrintableHash(header.getHash().getBytes()), peer.getPeerNodeID()); - BodyRequestMessage message = new BodyRequestMessage(++lastRequestId, header.getHash().getBytes()); + BodyRequestMessage message = new BodyRequestMessage(nextMessageId(), header.getHash().getBytes()); sendMessage(peer, message); return message.getId(); } @@ -464,6 +492,15 @@ int getNoAdvancedPeers() { return this.peersInformation.countIf(s -> chainStatus.hasLowerTotalDifficultyThan(s.getStatus())); } + public long nextMessageId() { + return lastRequestId.incrementAndGet(); + } + + @Override + public void registerPendingMessage(@Nonnull MessageWithId message) { + pendingMessages.put(message.getId(), new MessageInfo(message.getResponseMessageType())); + } + @VisibleForTesting public void registerExpectedMessage(MessageWithId message) { pendingMessages.put(message.getId(), new MessageInfo(message.getMessageType())); diff --git a/rskj-core/src/main/java/co/rsk/net/messages/MessageType.java b/rskj-core/src/main/java/co/rsk/net/messages/MessageType.java index 2b8da7a78d..bf7b902b95 100644 --- a/rskj-core/src/main/java/co/rsk/net/messages/MessageType.java +++ b/rskj-core/src/main/java/co/rsk/net/messages/MessageType.java @@ -21,11 +21,11 @@ import co.rsk.core.BlockDifficulty; import co.rsk.net.Status; import co.rsk.remasc.RemascTransaction; +import org.bouncycastle.util.BigIntegers; import org.ethereum.core.*; import org.ethereum.util.RLP; import org.ethereum.util.RLPElement; import org.ethereum.util.RLPList; -import org.bouncycastle.util.BigIntegers; import java.util.ArrayList; import java.util.List; @@ -260,19 +260,19 @@ public Message createMessage(BlockFactory blockFactory, RLPList list) { SNAP_STATE_CHUNK_REQUEST_MESSAGE(20) { @Override public Message createMessage(BlockFactory blockFactory, RLPList list) { - return SnapStateChunkRequestMessage.create(blockFactory, list); + return SnapStateChunkRequestMessage.decodeMessage(blockFactory, list); } }, SNAP_STATE_CHUNK_RESPONSE_MESSAGE(21) { @Override public Message createMessage(BlockFactory blockFactory, RLPList list) { - return SnapStateChunkResponseMessage.create(blockFactory, list); + return SnapStateChunkResponseMessage.decodeMessage(blockFactory, list); } }, SNAP_STATUS_REQUEST_MESSAGE(22) { @Override public Message createMessage(BlockFactory blockFactory, RLPList list) { - return new SnapStatusRequestMessage(); + return SnapStatusRequestMessage.decodeMessage(blockFactory, list); } }, SNAP_STATUS_RESPONSE_MESSAGE(23) { diff --git a/rskj-core/src/main/java/co/rsk/net/messages/SnapBlocksRequestMessage.java b/rskj-core/src/main/java/co/rsk/net/messages/SnapBlocksRequestMessage.java index b543747738..8cd5869888 100644 --- a/rskj-core/src/main/java/co/rsk/net/messages/SnapBlocksRequestMessage.java +++ b/rskj-core/src/main/java/co/rsk/net/messages/SnapBlocksRequestMessage.java @@ -25,10 +25,13 @@ import java.math.BigInteger; -public class SnapBlocksRequestMessage extends Message { +public class SnapBlocksRequestMessage extends MessageWithId { + private final long id; + private final long blockNumber; - public SnapBlocksRequestMessage(long blockNumber) { + public SnapBlocksRequestMessage(long id, long blockNumber) { + this.id = id; this.blockNumber = blockNumber; } @@ -38,17 +41,30 @@ public MessageType getMessageType() { } @Override - public byte[] getEncodedMessage() { + public MessageType getResponseMessageType() { + return MessageType.SNAP_BLOCKS_RESPONSE_MESSAGE; + } + + @Override + public long getId() { + return this.id; + } + + @Override + protected byte[] getEncodedMessageWithoutId() { byte[] encodedBlockNumber = RLP.encodeBigInteger(BigInteger.valueOf(blockNumber)); return RLP.encodeList(encodedBlockNumber); } public static Message decodeMessage(BlockFactory blockFactory, RLPList list) { - byte[] rlpBlockNumber = list.get(0).getRLPData(); + byte[] rlpId = list.get(0).getRLPData(); + long id = rlpId == null ? 0 : BigIntegers.fromUnsignedByteArray(rlpId).longValue(); + RLPList message = (RLPList)RLP.decode2(list.get(1).getRLPData()).get(0); + byte[] rlpBlockNumber = message.get(0).getRLPData(); long blockNumber = rlpBlockNumber == null ? 0 : BigIntegers.fromUnsignedByteArray(rlpBlockNumber).longValue(); - return new SnapBlocksRequestMessage(blockNumber); + return new SnapBlocksRequestMessage(id, blockNumber); } public long getBlockNumber() { diff --git a/rskj-core/src/main/java/co/rsk/net/messages/SnapBlocksResponseMessage.java b/rskj-core/src/main/java/co/rsk/net/messages/SnapBlocksResponseMessage.java index df0185c458..b3b4c73634 100644 --- a/rskj-core/src/main/java/co/rsk/net/messages/SnapBlocksResponseMessage.java +++ b/rskj-core/src/main/java/co/rsk/net/messages/SnapBlocksResponseMessage.java @@ -20,6 +20,7 @@ import co.rsk.core.BlockDifficulty; import com.google.common.collect.Lists; +import org.bouncycastle.util.BigIntegers; import org.ethereum.core.Block; import org.ethereum.core.BlockFactory; import org.ethereum.util.RLP; @@ -29,11 +30,14 @@ import java.util.List; import java.util.stream.Collectors; -public class SnapBlocksResponseMessage extends Message { +public class SnapBlocksResponseMessage extends MessageWithId { + private final long id; + private final List blocks; private final List difficulties; - public SnapBlocksResponseMessage(List blocks, List difficulties) { + public SnapBlocksResponseMessage(long id, List blocks, List difficulties) { + this.id = id; this.blocks = blocks; this.difficulties = difficulties; } @@ -52,7 +56,12 @@ public List getBlocks() { } @Override - public byte[] getEncodedMessage() { + public long getId() { + return this.id; + } + + @Override + protected byte[] getEncodedMessageWithoutId() { List rlpBlocks = this.blocks.stream().map(Block::getEncoded).map(RLP::encode).collect(Collectors.toList()); List rlpDifficulties = this.difficulties.stream().map(BlockDifficulty::getBytes).map(RLP::encode).collect(Collectors.toList()); return RLP.encodeList(RLP.encodeList(rlpBlocks.toArray(new byte[][]{})), @@ -60,17 +69,21 @@ public byte[] getEncodedMessage() { } public static Message decodeMessage(BlockFactory blockFactory, RLPList list) { + byte[] rlpId = list.get(0).getRLPData(); + long id = rlpId == null ? 0 : BigIntegers.fromUnsignedByteArray(rlpId).longValue(); + + RLPList message = (RLPList)RLP.decode2(list.get(1).getRLPData()).get(0); List blocks = Lists.newArrayList(); List blockDifficulties = Lists.newArrayList(); - RLPList blocksRLP = RLP.decodeList(list.get(0).getRLPData()); + RLPList blocksRLP = RLP.decodeList(message.get(0).getRLPData()); for (int i = 0; i < blocksRLP.size(); i++) { blocks.add(blockFactory.decodeBlock(blocksRLP.get(i).getRLPData())); } - RLPList difficultiesRLP = RLP.decodeList(list.get(1).getRLPData()); + RLPList difficultiesRLP = RLP.decodeList(message.get(1).getRLPData()); for (int i = 0; i < difficultiesRLP.size(); i++) { blockDifficulties.add(new BlockDifficulty(new BigInteger(difficultiesRLP.get(i).getRLPData()))); } - return new SnapBlocksResponseMessage(blocks, blockDifficulties); + return new SnapBlocksResponseMessage(id, blocks, blockDifficulties); } @Override diff --git a/rskj-core/src/main/java/co/rsk/net/messages/SnapStateChunkRequestMessage.java b/rskj-core/src/main/java/co/rsk/net/messages/SnapStateChunkRequestMessage.java index 3195ff22aa..fefc467dad 100644 --- a/rskj-core/src/main/java/co/rsk/net/messages/SnapStateChunkRequestMessage.java +++ b/rskj-core/src/main/java/co/rsk/net/messages/SnapStateChunkRequestMessage.java @@ -43,6 +43,11 @@ public MessageType getMessageType() { return MessageType.SNAP_STATE_CHUNK_REQUEST_MESSAGE; } + @Override + public MessageType getResponseMessageType() { + return MessageType.SNAP_STATE_CHUNK_RESPONSE_MESSAGE; + } + @Override public void accept(MessageVisitor v) { v.apply(this); @@ -61,21 +66,18 @@ protected byte[] getEncodedMessageWithoutId() { return RLP.encodeList(rlpBlockNumber, rlpFrom, rlpChunkSize); } - public static Message create(BlockFactory blockFactory, RLPList list) { - try { - byte[] rlpId = list.get(0).getRLPData(); - RLPList message = (RLPList) RLP.decode2(list.get(1).getRLPData()).get(0); - byte[] rlpBlockNumber = message.get(0).getRLPData(); - byte[] rlpFrom = message.get(1).getRLPData(); - byte[] rlpChunkSize = message.get(2).getRLPData(); - long id = rlpId == null ? 0 : BigIntegers.fromUnsignedByteArray(rlpId).longValue(); - long blockNumber = rlpBlockNumber == null ? 0 : BigIntegers.fromUnsignedByteArray(rlpBlockNumber).longValue(); - long from = rlpFrom == null ? 0 : BigIntegers.fromUnsignedByteArray(rlpFrom).longValue(); - long chunkSize = rlpChunkSize == null ? 0 : BigIntegers.fromUnsignedByteArray(rlpChunkSize).longValue(); - return new SnapStateChunkRequestMessage(id, blockNumber, from, chunkSize); - } catch (Exception e) { - throw e; - } + public static Message decodeMessage(BlockFactory blockFactory, RLPList list) { + byte[] rlpId = list.get(0).getRLPData(); + long id = rlpId == null ? 0 : BigIntegers.fromUnsignedByteArray(rlpId).longValue(); + + RLPList message = (RLPList) RLP.decode2(list.get(1).getRLPData()).get(0); + byte[] rlpBlockNumber = message.get(0).getRLPData(); + byte[] rlpFrom = message.get(1).getRLPData(); + byte[] rlpChunkSize = message.get(2).getRLPData(); + long blockNumber = rlpBlockNumber == null ? 0 : BigIntegers.fromUnsignedByteArray(rlpBlockNumber).longValue(); + long from = rlpFrom == null ? 0 : BigIntegers.fromUnsignedByteArray(rlpFrom).longValue(); + long chunkSize = rlpChunkSize == null ? 0 : BigIntegers.fromUnsignedByteArray(rlpChunkSize).longValue(); + return new SnapStateChunkRequestMessage(id, blockNumber, from, chunkSize); } public long getFrom() { diff --git a/rskj-core/src/main/java/co/rsk/net/messages/SnapStateChunkResponseMessage.java b/rskj-core/src/main/java/co/rsk/net/messages/SnapStateChunkResponseMessage.java index f0af95538c..23262f2463 100644 --- a/rskj-core/src/main/java/co/rsk/net/messages/SnapStateChunkResponseMessage.java +++ b/rskj-core/src/main/java/co/rsk/net/messages/SnapStateChunkResponseMessage.java @@ -26,8 +26,9 @@ import java.math.BigInteger; public class SnapStateChunkResponseMessage extends MessageWithId { - private final long to; private final long id; + + private final long to; private final byte[] chunkOfTrieKeyValue; private final long from; @@ -59,38 +60,34 @@ public long getId() { return this.id; } - @Override protected byte[] getEncodedMessageWithoutId() { - try { - byte[] rlpBlockNumber = RLP.encodeBigInteger(BigInteger.valueOf(this.blockNumber)); - byte[] rlpFrom = RLP.encodeBigInteger(BigInteger.valueOf(this.from)); - byte[] rlpTo = RLP.encodeBigInteger(BigInteger.valueOf(this.to)); - byte[] rlpComplete = new byte[]{this.complete ? (byte) 1 : (byte) 0}; - return RLP.encodeList(chunkOfTrieKeyValue, rlpBlockNumber, rlpFrom, rlpTo, rlpComplete); - } catch (Exception e) { - throw e; - } + byte[] rlpChunkOfTrieKeyValue = RLP.encodeElement(chunkOfTrieKeyValue); + byte[] rlpBlockNumber = RLP.encodeBigInteger(BigInteger.valueOf(this.blockNumber)); + byte[] rlpFrom = RLP.encodeBigInteger(BigInteger.valueOf(this.from)); + byte[] rlpTo = RLP.encodeBigInteger(BigInteger.valueOf(this.to)); + byte[] rlpComplete = RLP.encodeInt(this.complete ? 1 : 0); + + return RLP.encodeList(rlpChunkOfTrieKeyValue, rlpBlockNumber, rlpFrom, rlpTo, rlpComplete); } - public static Message create(BlockFactory blockFactory, RLPList list) { - try { - byte[] rlpId = list.get(0).getRLPData(); - RLPList message = (RLPList) RLP.decode2(list.get(1).getRLPData()).get(0); - byte[] chunkOfTrieKeys = message.get(0).getRLPData(); - byte[] rlpBlockNumber = message.get(1).getRLPData(); - byte[] rlpFrom = message.get(2).getRLPData(); - byte[] rlpTo = message.get(3).getRLPData(); - byte[] rlpComplete = message.get(4).getRLPData(); - long id = rlpId == null ? 0 : BigIntegers.fromUnsignedByteArray(rlpId).longValue(); - long blockNumber = rlpBlockNumber == null ? 0 : BigIntegers.fromUnsignedByteArray(rlpBlockNumber).longValue(); - long from = rlpFrom == null ? 0 : BigIntegers.fromUnsignedByteArray(rlpFrom).longValue(); - long to = rlpTo == null ? 0 : BigIntegers.fromUnsignedByteArray(rlpTo).longValue(); - boolean complete = rlpComplete == null ? Boolean.FALSE : rlpComplete[0] != 0; - return new SnapStateChunkResponseMessage(id, chunkOfTrieKeys, blockNumber, from, to, complete); - } catch (Exception e) { - throw e; - } + public static Message decodeMessage(BlockFactory blockFactory, RLPList list) { + byte[] rlpId = list.get(0).getRLPData(); + long id = rlpId == null ? 0 : BigIntegers.fromUnsignedByteArray(rlpId).longValue(); + + RLPList message = (RLPList) RLP.decode2(list.get(1).getRLPData()).get(0); + byte[] rlpChunkOfTrieKeys = message.get(0).getRLPData(); + byte[] rlpBlockNumber = message.get(1).getRLPData(); + byte[] rlpFrom = message.get(2).getRLPData(); + byte[] rlpTo = message.get(3).getRLPData(); + byte[] rlpComplete = message.get(4).getRLPData(); + + byte[] chunkOfTrieKeys = rlpChunkOfTrieKeys == null ? new byte[0] : rlpChunkOfTrieKeys; + long blockNumber = rlpBlockNumber == null ? 0 : BigIntegers.fromUnsignedByteArray(rlpBlockNumber).longValue(); + long from = rlpFrom == null ? 0 : BigIntegers.fromUnsignedByteArray(rlpFrom).longValue(); + long to = rlpTo == null ? 0 : BigIntegers.fromUnsignedByteArray(rlpTo).longValue(); + boolean complete = rlpComplete != null && rlpComplete.length != 0 && rlpComplete[0] != 0; + return new SnapStateChunkResponseMessage(id, chunkOfTrieKeys, blockNumber, from, to, complete); } public byte[] getChunkOfTrieKeyValue() { diff --git a/rskj-core/src/main/java/co/rsk/net/messages/SnapStatusRequestMessage.java b/rskj-core/src/main/java/co/rsk/net/messages/SnapStatusRequestMessage.java index 39e822ba78..25bd84dac8 100644 --- a/rskj-core/src/main/java/co/rsk/net/messages/SnapStatusRequestMessage.java +++ b/rskj-core/src/main/java/co/rsk/net/messages/SnapStatusRequestMessage.java @@ -18,11 +18,17 @@ package co.rsk.net.messages; +import org.bouncycastle.util.BigIntegers; +import org.ethereum.core.BlockFactory; import org.ethereum.util.RLP; +import org.ethereum.util.RLPList; -public class SnapStatusRequestMessage extends Message { +public class SnapStatusRequestMessage extends MessageWithId { - public SnapStatusRequestMessage() { + private final long id; + + public SnapStatusRequestMessage(long id) { + this.id = id; } @Override @@ -31,7 +37,17 @@ public MessageType getMessageType() { } @Override - public byte[] getEncodedMessage() { + public MessageType getResponseMessageType() { + return MessageType.SNAP_STATUS_RESPONSE_MESSAGE; + } + + @Override + public long getId() { + return this.id; + } + + @Override + protected byte[] getEncodedMessageWithoutId() { return RLP.encodedEmptyList(); } @@ -39,4 +55,11 @@ public byte[] getEncodedMessage() { public void accept(MessageVisitor v) { v.apply(this); } + + public static Message decodeMessage(BlockFactory blockFactory, RLPList list) { + byte[] rlpId = list.get(0).getRLPData(); + long id = rlpId == null ? 0 : BigIntegers.fromUnsignedByteArray(rlpId).longValue(); + + return new SnapStatusRequestMessage(id); + } } diff --git a/rskj-core/src/main/java/co/rsk/net/messages/SnapStatusResponseMessage.java b/rskj-core/src/main/java/co/rsk/net/messages/SnapStatusResponseMessage.java index e70851edc6..520e2e4fdf 100644 --- a/rskj-core/src/main/java/co/rsk/net/messages/SnapStatusResponseMessage.java +++ b/rskj-core/src/main/java/co/rsk/net/messages/SnapStatusResponseMessage.java @@ -30,7 +30,8 @@ import java.util.List; import java.util.stream.Collectors; -public class SnapStatusResponseMessage extends Message { +public class SnapStatusResponseMessage extends MessageWithId { + private final long id; private final List blocks; private final List difficulties; private final long trieSize; @@ -43,7 +44,8 @@ public long getTrieSize() { return this.trieSize; } - public SnapStatusResponseMessage(List blocks, List difficulties, long trieSize) { + public SnapStatusResponseMessage(long id, List blocks, List difficulties, long trieSize) { + this.id = id; this.blocks = blocks; this.difficulties = difficulties; this.trieSize = trieSize; @@ -59,7 +61,12 @@ public List getDifficulties() { } @Override - public byte[] getEncodedMessage() { + public long getId() { + return this.id; + } + + @Override + protected byte[] getEncodedMessageWithoutId() { List rlpBlocks = this.blocks.stream().map(Block::getEncoded).map(RLP::encode).collect(Collectors.toList()); List rlpDifficulties = this.difficulties.stream().map(BlockDifficulty::getBytes).map(RLP::encode).collect(Collectors.toList()); byte[] rlpTrieSize = RLP.encodeBigInteger(BigInteger.valueOf(this.trieSize)); @@ -68,8 +75,12 @@ public byte[] getEncodedMessage() { } public static Message decodeMessage(BlockFactory blockFactory, RLPList list) { - RLPList rlpBlocks = RLP.decodeList(list.get(0).getRLPData()); - RLPList rlpDifficulties = RLP.decodeList(list.get(1).getRLPData()); + byte[] rlpId = list.get(0).getRLPData(); + long id = rlpId == null ? 0 : BigIntegers.fromUnsignedByteArray(rlpId).longValue(); + + RLPList message = (RLPList)RLP.decode2(list.get(1).getRLPData()).get(0); + RLPList rlpBlocks = RLP.decodeList(message.get(0).getRLPData()); + RLPList rlpDifficulties = RLP.decodeList(message.get(1).getRLPData()); List blocks = Lists.newArrayList(); List difficulties = Lists.newArrayList(); for (int i = 0; i < rlpBlocks.size(); i++) { @@ -79,10 +90,10 @@ public static Message decodeMessage(BlockFactory blockFactory, RLPList list) { difficulties.add(new BlockDifficulty(new BigInteger(rlpDifficulties.get(i).getRLPData()))); } - byte[] rlpTrieSize = list.get(2).getRLPData(); + byte[] rlpTrieSize = message.get(2).getRLPData(); long trieSize = rlpTrieSize == null ? 0 : BigIntegers.fromUnsignedByteArray(rlpTrieSize).longValue(); - return new SnapStatusResponseMessage(blocks, difficulties, trieSize); + return new SnapStatusResponseMessage(id, blocks, difficulties, trieSize); } @Override diff --git a/rskj-core/src/main/java/co/rsk/net/sync/PeersInformation.java b/rskj-core/src/main/java/co/rsk/net/sync/PeersInformation.java index 99ebc82149..e1dc4f1aee 100644 --- a/rskj-core/src/main/java/co/rsk/net/sync/PeersInformation.java +++ b/rskj-core/src/main/java/co/rsk/net/sync/PeersInformation.java @@ -209,6 +209,15 @@ public Optional getBestPeer(Set exclude) { return getBestPeer(getBestCandidatesStream().filter(p -> !exclude.contains(p.getKey().getPeerNodeID()))); } + @Override + public Optional getBestSnapPeer(Set exclude) { + return getBestPeer( + getBestCandidatesStream() + .filter(this::isSnapPeerCandidateOrCapable) + .filter(p -> !exclude.contains(p.getKey().getPeerNodeID())) + ); + } + public Set knownNodeIds() { return peerStatuses.keySet().stream() .map(Peer::getPeerNodeID) diff --git a/rskj-core/src/main/java/co/rsk/net/sync/SnapSyncRequestManager.java b/rskj-core/src/main/java/co/rsk/net/sync/SnapSyncRequestManager.java new file mode 100644 index 0000000000..bce97735db --- /dev/null +++ b/rskj-core/src/main/java/co/rsk/net/sync/SnapSyncRequestManager.java @@ -0,0 +1,197 @@ +/* + * This file is part of RskJ + * Copyright (C) 2024 RSK Labs Ltd. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with this program. If not, see . + */ + +package co.rsk.net.sync; + +import co.rsk.net.NodeID; +import co.rsk.net.Peer; +import co.rsk.net.messages.MessageWithId; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.*; +import java.util.function.Function; +import java.util.function.Supplier; + +public class SnapSyncRequestManager { + + private static final Logger logger = LoggerFactory.getLogger("snapshotprocessor"); + + private static final long MAX_RETRY_NUM = 2; + + private final SyncConfiguration syncConfiguration; + private final SyncEventsHandler syncEventsHandler; + + private final Map pendingRequests = new HashMap<>(); + + public SnapSyncRequestManager(@Nonnull SyncConfiguration syncConfiguration, @Nonnull SyncEventsHandler syncEventsHandler) { + this.syncConfiguration = Objects.requireNonNull(syncConfiguration); + this.syncEventsHandler = Objects.requireNonNull(syncEventsHandler); + } + + synchronized void submitRequest(@Nonnull PeerSelector peerSelector, @Nonnull RequestFactory requestFactory) throws SendRequestException { + long messageId = syncEventsHandler.nextMessageId(); + PendingRequest pendingRequest = new PendingRequest(peerSelector, requestFactory); + pendingRequests.put(messageId, pendingRequest); + + MessageWithId messageWithId = pendingRequest.send(messageId, System.currentTimeMillis()); + syncEventsHandler.registerPendingMessage(messageWithId); + } + + synchronized boolean processResponse(@Nonnull MessageWithId responseMessage) { + return pendingRequests.remove(responseMessage.getId()) != null; + } + + synchronized void resendExpiredRequests() throws SendRequestException { + long requestTimeout = syncConfiguration.getTimeoutWaitingRequest().toMillis(); + long now = System.currentTimeMillis(); + long exp = now - requestTimeout; + Map resentRequests = null; + + for (Iterator> iter = pendingRequests.entrySet().iterator(); iter.hasNext(); ) { + Map.Entry msgEntry = iter.next(); + PendingRequest pendingRequest = msgEntry.getValue(); + if (pendingRequest.isExpired(exp)) { + iter.remove(); + + long messageId = syncEventsHandler.nextMessageId(); + MessageWithId messageWithId = pendingRequest.reSend(msgEntry.getKey(), messageId, now); + syncEventsHandler.registerPendingMessage(messageWithId); + if (resentRequests == null) { + resentRequests = new HashMap<>(); + } + resentRequests.put(messageId, pendingRequest); + } + } + + if (resentRequests != null) { + pendingRequests.putAll(resentRequests); + } + } + + @FunctionalInterface + public interface RequestFactory { + MessageWithId createRequest(long messageId); + } + + @FunctionalInterface + public interface PeerSelector { + Optional selectPeer(@Nullable NodeID failedPeerIds); + + static Builder builder() { + return new Builder(); + } + + class Builder { + private Supplier> defaultPeerSupplier = Optional::empty; + private Function, Optional> altPeerSupplier = failedPeerIds -> Optional.empty(); + + public Builder withDefaultPeerOption(Supplier> defaultPeerOptionSupplier) { + this.defaultPeerSupplier = Objects.requireNonNull(defaultPeerOptionSupplier); + return this; + } + + public Builder withDefaultPeer(Supplier defaultPeerSupplier) { + Objects.requireNonNull(defaultPeerSupplier); + this.defaultPeerSupplier = () -> Optional.ofNullable(defaultPeerSupplier.get()); + return this; + } + + public Builder withAltPeer(Function, Optional> altPeerSupplier) { + this.altPeerSupplier = Objects.requireNonNull(altPeerSupplier); + return this; + } + + public PeerSelector build() { + return failedPeerId -> Optional.ofNullable(failedPeerId) + .flatMap(peerId -> altPeerSupplier.apply(Collections.singleton(peerId))) + .or(defaultPeerSupplier) + .or(() -> altPeerSupplier.apply(Collections.emptySet())); + } + } + } + + private static class PendingRequest { + private final PeerSelector peerSelector; + private final RequestFactory requestFactory; + + private Peer selectedPeer; + private long started; + private int retries; + + PendingRequest(@Nonnull PeerSelector peerSelector, @Nonnull RequestFactory requestFactory) { + this.peerSelector = Objects.requireNonNull(peerSelector); + this.requestFactory = Objects.requireNonNull(requestFactory); + } + + MessageWithId send(long messageId, long now) throws SendRequestException { + this.started = now; + + Optional selectedPeerOpt = this.peerSelector.selectPeer(null); + if (selectedPeerOpt.isEmpty()) { + throw new SendRequestException("Failed to send request - no peer available"); + } + + this.selectedPeer = selectedPeerOpt.get(); + + MessageWithId msg = requestFactory.createRequest(messageId); + + logger.debug("Sending request: [{}] with id: [{}] to: [{}]", msg.getMessageType(), msg.getId(), this.selectedPeer.getPeerNodeID()); + selectedPeer.sendMessage(msg); + + return msg; + } + + MessageWithId reSend(long previousMessageId, long newMessageId, long now) throws SendRequestException { + this.started = now; + + if (this.retries >= MAX_RETRY_NUM) { + throw new SendRequestException("Failed to re-send expired request with previous messageId: [" + previousMessageId + "] - max retries reached"); + } + + Optional selectedPeerOpt = this.peerSelector.selectPeer(this.selectedPeer.getPeerNodeID()); + if (selectedPeerOpt.isEmpty()) { + throw new SendRequestException("Failed to re-send expired request with previous messageId: [" + previousMessageId + "] - no peer available"); + } + + this.selectedPeer = selectedPeerOpt.get(); + + MessageWithId msg = this.requestFactory.createRequest(newMessageId); + + logger.debug("Re-sending expired request: [{}] with old id: [{}] to: [{}] with new id: [{}]", msg.getMessageType(), previousMessageId, this.selectedPeer.getPeerNodeID(), msg.getId()); + this.selectedPeer.sendMessage(msg); + + this.retries++; + + return msg; + } + + boolean isExpired(long exp) { + return started <= exp; + } + } + + public static class SendRequestException extends Exception { + + public SendRequestException(String message) { + super(message); + } + } +} diff --git a/rskj-core/src/main/java/co/rsk/net/sync/SnapSyncState.java b/rskj-core/src/main/java/co/rsk/net/sync/SnapSyncState.java index 694ec16652..2234187c56 100644 --- a/rskj-core/src/main/java/co/rsk/net/sync/SnapSyncState.java +++ b/rskj-core/src/main/java/co/rsk/net/sync/SnapSyncState.java @@ -21,10 +21,7 @@ import co.rsk.core.BlockDifficulty; import co.rsk.net.Peer; import co.rsk.net.SnapshotProcessor; -import co.rsk.net.messages.MessageType; -import co.rsk.net.messages.SnapBlocksResponseMessage; -import co.rsk.net.messages.SnapStateChunkResponseMessage; -import co.rsk.net.messages.SnapStatusResponseMessage; +import co.rsk.net.messages.*; import co.rsk.scoring.EventType; import co.rsk.trie.TrieDTO; import com.google.common.annotations.VisibleForTesting; @@ -35,17 +32,24 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.annotation.Nonnull; import javax.annotation.Nullable; import java.math.BigInteger; +import java.time.Duration; import java.util.*; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; +import static co.rsk.net.sync.SnapSyncRequestManager.PeerSelector; +import static co.rsk.net.sync.SnapSyncRequestManager.RequestFactory; +import static co.rsk.net.sync.SnapSyncRequestManager.SendRequestException; + public class SnapSyncState extends BaseSyncState { private static final Logger logger = LoggerFactory.getLogger("SnapSyncState"); private final SnapshotProcessor snapshotProcessor; + private final SnapSyncRequestManager snapRequestManager; // queue for processing of SNAP responses private final BlockingQueue responseQueue = new LinkedBlockingQueue<>(); @@ -78,14 +82,16 @@ public class SnapSyncState extends BaseSyncState { private final Thread thread; public SnapSyncState(SyncEventsHandler syncEventsHandler, SnapshotProcessor snapshotProcessor, SyncConfiguration syncConfiguration) { - this(syncEventsHandler, snapshotProcessor, syncConfiguration, null); + this(syncEventsHandler, snapshotProcessor, new SnapSyncRequestManager(syncConfiguration, syncEventsHandler), syncConfiguration, null); } @VisibleForTesting SnapSyncState(SyncEventsHandler syncEventsHandler, SnapshotProcessor snapshotProcessor, - SyncConfiguration syncConfiguration, @Nullable SyncMessageHandler.Listener listener) { + SnapSyncRequestManager snapRequestManager, SyncConfiguration syncConfiguration, + @Nullable SyncMessageHandler.Listener listener) { super(syncEventsHandler, syncConfiguration); - this.snapshotProcessor = snapshotProcessor; // TODO(snap-poc) code in SnapshotProcessor should be moved here probably + this.snapshotProcessor = snapshotProcessor; + this.snapRequestManager = snapRequestManager; this.allNodes = Lists.newArrayList(); this.blocks = Lists.newArrayList(); this.thread = new Thread(new SyncMessageHandler("SNAP/client", responseQueue, listener) { @@ -110,8 +116,12 @@ public void onEnter() { @Override public void onSnapStatus(Peer sender, SnapStatusResponseMessage responseMessage) { + if (!snapRequestManager.processResponse(responseMessage)) { + logger.warn("Unexpected response: [{}] received with id: [{}]. Ignoring", responseMessage.getMessageType(), responseMessage.getId()); + return; + } + try { - resetTimeElapsed(); responseQueue.put(new SyncMessageHandler.Job(sender, responseMessage) { @Override public void run() { @@ -126,8 +136,12 @@ public void run() { @Override public void onSnapBlocks(Peer sender, SnapBlocksResponseMessage responseMessage) { + if (!snapRequestManager.processResponse(responseMessage)) { + logger.warn("Unexpected response: [{}] received with id: [{}]. Ignoring", responseMessage.getMessageType(), responseMessage.getId()); + return; + } + try { - resetTimeElapsed(); responseQueue.put(new SyncMessageHandler.Job(sender, responseMessage) { @Override public void run() { @@ -142,8 +156,12 @@ public void run() { @Override public void onSnapStateChunk(Peer sender, SnapStateChunkResponseMessage responseMessage) { + if (!snapRequestManager.processResponse(responseMessage)) { + logger.warn("Unexpected response: [{}] received with id: [{}]. Ignoring", responseMessage.getMessageType(), responseMessage.getId()); + return; + } + try { - resetTimeElapsed(); responseQueue.put(new SyncMessageHandler.Job(sender, responseMessage) { @Override public void run() { @@ -157,13 +175,17 @@ public void run() { } @Override - public void newBlockHeaders(Peer peer, List chunk) { + public void newBlockHeaders(Peer sender, BlockHeadersResponseMessage responseMessage) { + if (!snapRequestManager.processResponse(responseMessage)) { + logger.warn("Unexpected response: [{}] received with id: [{}]. Ignoring", responseMessage.getMessageType(), responseMessage.getId()); + return; + } + try { - resetTimeElapsed(); - responseQueue.put(new SyncMessageHandler.Job(peer, MessageType.BLOCK_HEADERS_RESPONSE_MESSAGE) { + responseQueue.put(new SyncMessageHandler.Job(sender, responseMessage) { @Override public void run() { - snapshotProcessor.processBlockHeaderChunk(SnapSyncState.this, peer, chunk); + snapshotProcessor.processBlockHeaderChunk(SnapSyncState.this, sender, responseMessage.getBlockHeaders()); } }); } catch (InterruptedException e) { @@ -176,9 +198,23 @@ public SyncEventsHandler getSyncEventsHandler() { return this.syncEventsHandler; } + public synchronized void submitRequest(@Nonnull PeerSelector peerSelector, @Nonnull RequestFactory requestFactory) { + try { + snapRequestManager.submitRequest(peerSelector, requestFactory); + } catch (SendRequestException e) { + logger.warn("Failed to submit expired requests. Stopping snap syncing", e); + finish(); + } + } + @Override - protected void onMessageTimeOut() { - fail(getLastBlockSender(), EventType.TIMEOUT_MESSAGE, "Snap sync timed out"); + public void tick(Duration duration) { + try { + this.snapRequestManager.resendExpiredRequests(); + } catch (SendRequestException e) { + logger.warn("Failed to re-submit expired requests. Stopping snap syncing", e); + finish(); + } } public Block getLastBlock() { diff --git a/rskj-core/src/main/java/co/rsk/net/sync/SnapshotPeersInformation.java b/rskj-core/src/main/java/co/rsk/net/sync/SnapshotPeersInformation.java index 1a9fad26a8..992842d149 100644 --- a/rskj-core/src/main/java/co/rsk/net/sync/SnapshotPeersInformation.java +++ b/rskj-core/src/main/java/co/rsk/net/sync/SnapshotPeersInformation.java @@ -30,8 +30,10 @@ * things such as the underlying communication channel. */ public interface SnapshotPeersInformation { + Optional getBestPeer(); Optional getBestSnapPeer(); List getBestSnapPeerCandidates(); Optional getBestPeer(Set exclude); + Optional getBestSnapPeer(Set exclude); SyncPeerStatus getOrRegisterPeer(Peer peer); } diff --git a/rskj-core/src/main/java/co/rsk/net/sync/SyncConfiguration.java b/rskj-core/src/main/java/co/rsk/net/sync/SyncConfiguration.java index 722f1329a3..ecf104b4f6 100644 --- a/rskj-core/src/main/java/co/rsk/net/sync/SyncConfiguration.java +++ b/rskj-core/src/main/java/co/rsk/net/sync/SyncConfiguration.java @@ -30,10 +30,10 @@ @Immutable public final class SyncConfiguration { @VisibleForTesting - public static final SyncConfiguration DEFAULT = new SyncConfiguration(5, 60, 30, 5, 20, 192, 20, 10, 0, false, false, 60, 0); + public static final SyncConfiguration DEFAULT = new SyncConfiguration(5, 60, 30, 5, 20, 192, 20, 10, 0, false, false, 0); @VisibleForTesting - public static final SyncConfiguration IMMEDIATE_FOR_TESTING = new SyncConfiguration(1, 1, 3, 1, 5, 192, 20, 10, 0, false, false, 60, 0); + public static final SyncConfiguration IMMEDIATE_FOR_TESTING = new SyncConfiguration(1, 1, 3, 1, 5, 192, 20, 10, 0, false, false, 0); private final int expectedPeers; private final Duration timeoutWaitingPeers; @@ -62,7 +62,6 @@ public final class SyncConfiguration { * @param topBest % of top best nodes that will be considered for random selection. * @param isServerSnapSyncEnabled Flag that indicates if server-side snap sync is enabled * @param isClientSnapSyncEnabled Flag that indicates if client-side snap sync is enabled - * @param timeoutWaitingSnapChunk Specific request timeout for snap sync * @param snapshotSyncLimit Distance to the tip of the peer's blockchain to enable snap synchronization. */ public SyncConfiguration( @@ -77,7 +76,6 @@ public SyncConfiguration( double topBest, boolean isServerSnapSyncEnabled, boolean isClientSnapSyncEnabled, - int timeoutWaitingSnapChunk, int snapshotSyncLimit) { this(expectedPeers, timeoutWaitingPeers, @@ -127,27 +125,27 @@ public SyncConfiguration( .collect(Collectors.toMap(peer -> peer.getId().toString(), peer -> peer))); } - public final int getExpectedPeers() { + public int getExpectedPeers() { return expectedPeers; } - public final int getMaxSkeletonChunks() { + public int getMaxSkeletonChunks() { return maxSkeletonChunks; } - public final Duration getTimeoutWaitingPeers() { + public Duration getTimeoutWaitingPeers() { return timeoutWaitingPeers; } - public final Duration getTimeoutWaitingRequest() { + public Duration getTimeoutWaitingRequest() { return timeoutWaitingRequest; } - public final Duration getExpirationTimePeerStatus() { + public Duration getExpirationTimePeerStatus() { return expirationTimePeerStatus; } - public final int getChunkSize() { + public int getChunkSize() { return chunkSize; } diff --git a/rskj-core/src/main/java/co/rsk/net/sync/SyncEventsHandler.java b/rskj-core/src/main/java/co/rsk/net/sync/SyncEventsHandler.java index e0bd7be4e9..a79905b762 100644 --- a/rskj-core/src/main/java/co/rsk/net/sync/SyncEventsHandler.java +++ b/rskj-core/src/main/java/co/rsk/net/sync/SyncEventsHandler.java @@ -18,11 +18,13 @@ package co.rsk.net.sync; import co.rsk.net.Peer; +import co.rsk.net.messages.MessageWithId; import co.rsk.scoring.EventType; import org.ethereum.core.Block; import org.ethereum.core.BlockHeader; import org.ethereum.core.BlockIdentifier; +import javax.annotation.Nonnull; import javax.annotation.Nullable; import java.util.Deque; import java.util.List; @@ -60,4 +62,8 @@ public interface SyncEventsHandler { void backwardSyncing(Peer peer); void startSnapSync(Peer peer); + + void registerPendingMessage(@Nonnull MessageWithId message); + + long nextMessageId(); } diff --git a/rskj-core/src/main/java/co/rsk/net/sync/SyncState.java b/rskj-core/src/main/java/co/rsk/net/sync/SyncState.java index 4d0436d45b..f0dcbe67ca 100644 --- a/rskj-core/src/main/java/co/rsk/net/sync/SyncState.java +++ b/rskj-core/src/main/java/co/rsk/net/sync/SyncState.java @@ -18,10 +18,7 @@ package co.rsk.net.sync; import co.rsk.net.Peer; -import co.rsk.net.messages.BodyResponseMessage; -import co.rsk.net.messages.SnapBlocksResponseMessage; -import co.rsk.net.messages.SnapStateChunkResponseMessage; -import co.rsk.net.messages.SnapStatusResponseMessage; +import co.rsk.net.messages.*; import org.ethereum.core.BlockHeader; import org.ethereum.core.BlockIdentifier; @@ -31,6 +28,10 @@ public interface SyncState { void newBlockHeaders(Peer peer, List chunk); + default void newBlockHeaders(Peer peer, BlockHeadersResponseMessage message) { + newBlockHeaders(peer, message.getBlockHeaders()); + } + // TODO(mc) don't receive a full message void newBody(BodyResponseMessage message, Peer peer); diff --git a/rskj-core/src/main/java/org/ethereum/net/server/Channel.java b/rskj-core/src/main/java/org/ethereum/net/server/Channel.java index c390be3318..fbe8261656 100644 --- a/rskj-core/src/main/java/org/ethereum/net/server/Channel.java +++ b/rskj-core/src/main/java/org/ethereum/net/server/Channel.java @@ -25,6 +25,7 @@ import co.rsk.net.eth.RskWireProtocol; import co.rsk.net.messages.Message; import co.rsk.net.messages.MessageType; +import com.google.common.annotations.VisibleForTesting; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; import org.ethereum.net.MessageQueue; @@ -74,9 +75,10 @@ public class Channel implements Peer { private final PeerStatistics peerStats = new PeerStatistics(); - private Stats stats; + private final Stats stats; private boolean isSnapCapable; + @VisibleForTesting public Channel(MessageQueue msgQueue, MessageCodec messageCodec, NodeManager nodeManager, @@ -84,6 +86,17 @@ public Channel(MessageQueue msgQueue, Eth62MessageFactory eth62MessageFactory, StaticMessages staticMessages, String remoteId) { + this(msgQueue, messageCodec, nodeManager, rskWireProtocolFactory, eth62MessageFactory, staticMessages, remoteId, new Stats()); + } + + public Channel(MessageQueue msgQueue, + MessageCodec messageCodec, + NodeManager nodeManager, + RskWireProtocol.Factory rskWireProtocolFactory, + Eth62MessageFactory eth62MessageFactory, + StaticMessages staticMessages, + String remoteId, + Stats stats) { this.msgQueue = msgQueue; this.messageCodec = messageCodec; this.nodeManager = nodeManager; @@ -91,7 +104,7 @@ public Channel(MessageQueue msgQueue, this.eth62MessageFactory = eth62MessageFactory; this.staticMessages = staticMessages; this.isActive = remoteId != null && !remoteId.isEmpty(); - this.stats = new Stats(); + this.stats = stats; } public void sendHelloMessage(ChannelHandlerContext ctx, FrameCodec frameCodec, String nodeId, diff --git a/rskj-core/src/main/java/org/ethereum/net/server/EthereumChannelInitializer.java b/rskj-core/src/main/java/org/ethereum/net/server/EthereumChannelInitializer.java index 7d0f478f7f..ae25fcc1f6 100644 --- a/rskj-core/src/main/java/org/ethereum/net/server/EthereumChannelInitializer.java +++ b/rskj-core/src/main/java/org/ethereum/net/server/EthereumChannelInitializer.java @@ -111,7 +111,8 @@ public void initChannel(NioSocketChannel ch) { P2pHandler p2pHandler = new P2pHandler(ethereumListener, messageQueue, config.getPeerP2PPingInterval()); MessageCodec messageCodec = new MessageCodec(ethereumListener, config); HandshakeHandler handshakeHandler = new HandshakeHandler(config, peerScoringManager, p2pHandler, messageCodec, configCapabilities); - Channel channel = new Channel(messageQueue, messageCodec, nodeManager, rskWireProtocolFactory, eth62MessageFactory, staticMessages, remoteId); + Stats stats = new Stats(config.getMessageQueuePerMinuteThreshold()); + Channel channel = new Channel(messageQueue, messageCodec, nodeManager, rskWireProtocolFactory, eth62MessageFactory, staticMessages, remoteId, stats); ch.pipeline().addLast("readTimeoutHandler", new ReadTimeoutHandler(config.peerChannelReadTimeout(), TimeUnit.SECONDS)); ch.pipeline().addLast("handshakeHandler", handshakeHandler); diff --git a/rskj-core/src/main/java/org/ethereum/net/server/Stats.java b/rskj-core/src/main/java/org/ethereum/net/server/Stats.java index 1cdcd9929f..9dc5c2b967 100644 --- a/rskj-core/src/main/java/org/ethereum/net/server/Stats.java +++ b/rskj-core/src/main/java/org/ethereum/net/server/Stats.java @@ -29,9 +29,10 @@ public class Stats { // Current minute messages counter private long minute; - // Reject messages over this treshold + + // Reject messages over this threshold // Is calculated using Exponential moving average - private long perMinuteThreshold; + private final long perMinuteThreshold; // events counters // 100% heuristics @@ -44,26 +45,29 @@ public class Stats { private double avg; //in ms // how fast avg and mpm update - private double alpha_m; - private double alpha_a; + private final double alpha_m; + private final double alpha_a; // scores for blocks and others - private double maxBlock; - private double maxOther; + private final double maxBlock; + private final double maxOther; public Stats() { + this(1000); + } + + public Stats(long perMinuteThreshold) { avg = 500; alpha_m = 0.3; alpha_a = 0.03; - perMinuteThreshold = 1000; + this.perMinuteThreshold = perMinuteThreshold; maxBlock = 200; maxOther = 100; mpm = 1; } - public synchronized double update(long timestamp, MessageType type) { long min = timestamp / 60000; long delta = timestamp - lastMessage; @@ -166,6 +170,7 @@ private double priority(MessageType type) { return 0.0; } } + public synchronized void imported(boolean best) { if (best) { importedBest++; @@ -174,7 +179,6 @@ public synchronized void imported(boolean best) { } } - @Override public String toString() { return "Stats{" + @@ -197,7 +201,6 @@ public double getMpm() { return mpm; } - @VisibleForTesting public long getMinute() { return minute; @@ -217,5 +220,4 @@ public void setImportedBest(int importedBest) { public void setImportedNotBest(int importedNotBest) { this.importedNotBest = importedNotBest; } - } diff --git a/rskj-core/src/main/resources/expected.conf b/rskj-core/src/main/resources/expected.conf index a62a7c7a5e..79d221be2f 100644 --- a/rskj-core/src/main/resources/expected.conf +++ b/rskj-core/src/main/resources/expected.conf @@ -162,6 +162,7 @@ peer = { bannedPeerIDs = [] bannedMiners = [] messageQueue.maxSizePerPeer = + messageQueue.thresholdPerMinutePerPeer = } genesis = genesis_constants.federationPublicKeys = [] @@ -283,6 +284,7 @@ sync = { } client = { enabled = + checkHistoricalHeaders = parallel = chunkSize = limit = diff --git a/rskj-core/src/main/resources/reference.conf b/rskj-core/src/main/resources/reference.conf index a81984b8b7..6bd24d8fca 100644 --- a/rskj-core/src/main/resources/reference.conf +++ b/rskj-core/src/main/resources/reference.conf @@ -178,6 +178,9 @@ peer { # Max number of pending messages that will be allowed per peer messageQueue.maxSizePerPeer = 2000 + # Reject peer's messages over this threshold + # It's calculated using exponential moving average + messageQueue.thresholdPerMinutePerPeer = 1000 } miner { @@ -382,6 +385,8 @@ sync { client = { # Client / snapshot sync enabled enabled = false + # Flat that determines if the client should check the historical headers + checkHistoricalHeaders = true # Server / chunk size chunkSize = 50 # Distance to the tip of the blockchain to start snapshot sync diff --git a/rskj-core/src/test/java/co/rsk/net/NodeMessageHandlerTest.java b/rskj-core/src/test/java/co/rsk/net/NodeMessageHandlerTest.java index 09cbf6f5cc..54aa38f0f1 100644 --- a/rskj-core/src/test/java/co/rsk/net/NodeMessageHandlerTest.java +++ b/rskj-core/src/test/java/co/rsk/net/NodeMessageHandlerTest.java @@ -993,8 +993,7 @@ void fillMessageQueue_thenBlockNewMessages() { } // assert that the surplus was not added - Assertions.assertEquals(config.getMessageQueueMaxSize(), (Integer) handler.getMessageQueueSize(sender)); - + Assertions.assertEquals(config.getMessageQueueMaxSize(), handler.getMessageQueueSize(sender)); } @Test diff --git a/rskj-core/src/test/java/co/rsk/net/SnapshotProcessorTest.java b/rskj-core/src/test/java/co/rsk/net/SnapshotProcessorTest.java index 58068feafd..6457d070f3 100644 --- a/rskj-core/src/test/java/co/rsk/net/SnapshotProcessorTest.java +++ b/rskj-core/src/test/java/co/rsk/net/SnapshotProcessorTest.java @@ -45,6 +45,8 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import static co.rsk.net.sync.SnapSyncRequestManager.PeerSelector; +import static co.rsk.net.sync.SnapSyncRequestManager.RequestFactory; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.*; @@ -95,12 +97,13 @@ void givenStartSyncingIsCalled_thenSnapStatusStartToBeRequestedFromPeer() { blockHeaderParentValidator, blockHeaderValidator, TEST_CHUNK_SIZE, + true, false); doReturn(Optional.of(peer)).when(peersInformation).getBestSnapPeer(); //when underTest.startSyncing(snapSyncState); //then - verify(peer).sendMessage(any(SnapStatusRequestMessage.class)); + verify(snapSyncState).submitRequest(any(PeerSelector.class), any(RequestFactory.class)); } @Test @@ -120,6 +123,7 @@ void givenSnapStatusResponseCalled_thenSnapChunkRequestsAreMade() { blockHeaderParentValidator, blockHeaderValidator, TEST_CHUNK_SIZE, + true, false); for (long blockNumber = 0; blockNumber < blockchain.getSize(); blockNumber++) { @@ -128,7 +132,7 @@ void givenSnapStatusResponseCalled_thenSnapChunkRequestsAreMade() { difficulties.add(blockStore.getTotalDifficultyForHash(currentBlock.getHash().getBytes())); } - SnapStatusResponseMessage snapStatusResponseMessage = new SnapStatusResponseMessage(blocks, difficulties, 100000L); + SnapStatusResponseMessage snapStatusResponseMessage = new SnapStatusResponseMessage(1, blocks, difficulties, 100000L); doReturn(blocks.get(blocks.size() - 1)).when(snapSyncState).getLastBlock(); doReturn(snapStatusResponseMessage.getTrieSize()).when(snapSyncState).getRemoteTrieSize(); @@ -144,7 +148,7 @@ void givenSnapStatusResponseCalled_thenSnapChunkRequestsAreMade() { underTest.processSnapStatusResponse(snapSyncState, peer, snapStatusResponseMessage); //then - verify(peer, times(2)).sendMessage(any()); // 1 for SnapStatusRequestMessage, 1 for SnapBlocksRequestMessage and 1 for SnapStateChunkRequestMessage + verify(snapSyncState, times(2)).submitRequest(any(PeerSelector.class), any(RequestFactory.class)); // 1 for SnapStatusRequestMessage, 1 for SnapBlocksRequestMessage and 1 for SnapStateChunkRequestMessage verify(peersInformation, times(1)).getBestSnapPeer(); } @@ -163,6 +167,7 @@ void givenSnapStatusRequestReceived_thenSnapStatusResponseIsSent() { blockHeaderParentValidator, blockHeaderValidator, TEST_CHUNK_SIZE, + true, false); //when underTest.processSnapStatusRequestInternal(peer, mock(SnapStatusRequestMessage.class)); @@ -186,9 +191,10 @@ void givenSnapBlockRequestReceived_thenSnapBlocksResponseMessageIsSent() { blockHeaderParentValidator, blockHeaderValidator, TEST_CHUNK_SIZE, + true, false); - SnapBlocksRequestMessage snapBlocksRequestMessage = new SnapBlocksRequestMessage(460); + SnapBlocksRequestMessage snapBlocksRequestMessage = new SnapBlocksRequestMessage(1, 460); //when underTest.processSnapBlocksRequestInternal(peer, snapBlocksRequestMessage); @@ -213,6 +219,7 @@ void givenSnapBlocksResponseReceived_thenSnapBlocksRequestMessageIsSent() { blockHeaderParentValidator, blockHeaderValidator, 200, + true, false); for (long blockNumber = 0; blockNumber < blockchain.getSize(); blockNumber++) { @@ -221,7 +228,7 @@ void givenSnapBlocksResponseReceived_thenSnapBlocksRequestMessageIsSent() { difficulties.add(blockStore.getTotalDifficultyForHash(currentBlock.getHash().getBytes())); } - SnapStatusResponseMessage snapStatusResponseMessage = new SnapStatusResponseMessage(blocks, difficulties, 100000L); + SnapStatusResponseMessage snapStatusResponseMessage = new SnapStatusResponseMessage(1, blocks, difficulties, 100000L); doReturn(true).when(snapSyncState).isRunning(); doReturn(true).when(blockValidator).isValid(any()); doReturn(true).when(blockParentValidator).isValid(any(), any()); @@ -230,14 +237,14 @@ void givenSnapBlocksResponseReceived_thenSnapBlocksRequestMessageIsSent() { underTest.startSyncing(snapSyncState); underTest.processSnapStatusResponse(snapSyncState, peer, snapStatusResponseMessage); - SnapBlocksResponseMessage snapBlocksResponseMessage = new SnapBlocksResponseMessage(blocks, difficulties); + SnapBlocksResponseMessage snapBlocksResponseMessage = new SnapBlocksResponseMessage(1, blocks, difficulties); when(snapSyncState.getLastBlock()).thenReturn(blocks.get(blocks.size() - 1)); //when underTest.processSnapBlocksResponse(snapSyncState, peer, snapBlocksResponseMessage); //then - verify(peer, atLeast(2)).sendMessage(any(SnapBlocksRequestMessage.class)); + verify(snapSyncState, atLeast(2)).submitRequest(any(PeerSelector.class), any(RequestFactory.class)); } @Test @@ -255,6 +262,7 @@ void givenSnapStateChunkRequest_thenSnapStateChunkResponseMessageIsSent() { blockHeaderParentValidator, blockHeaderValidator, TEST_CHUNK_SIZE, + true, false); SnapStateChunkRequestMessage snapStateChunkRequestMessage = new SnapStateChunkRequestMessage(1L, 1L, 1, TEST_CHUNK_SIZE); @@ -284,6 +292,7 @@ void givenProcessSnapStatusRequestIsCalled_thenInternalOneIsCalledLater() throws blockHeaderParentValidator, blockHeaderValidator, TEST_CHUNK_SIZE, + true, false, listener) { @Override @@ -324,6 +333,7 @@ void givenProcessSnapBlocksRequestIsCalled_thenInternalOneIsCalledLater() throws blockHeaderParentValidator, blockHeaderValidator, TEST_CHUNK_SIZE, + true, false, listener) { @Override @@ -364,6 +374,7 @@ void givenProcessStateChunkRequestIsCalled_thenInternalOneIsCalledLater() throws blockHeaderParentValidator, blockHeaderValidator, TEST_CHUNK_SIZE, + true, false, listener) { @Override @@ -399,6 +410,7 @@ void givenErrorRLPData_thenOnStateChunkErrorIsCalled() { blockHeaderParentValidator, blockHeaderValidator, TEST_CHUNK_SIZE, + true, false); PriorityQueue queue = new PriorityQueue<>( @@ -416,8 +428,8 @@ void givenErrorRLPData_thenOnStateChunkErrorIsCalled() { underTest.processStateChunkResponse(snapSyncState, peer, responseMessage); - verify(underTest, times(1)).onStateChunkResponseError(peer, responseMessage); - verify(peer, times(1)).sendMessage(any(SnapStateChunkRequestMessage.class)); + verify(underTest, times(1)).onStateChunkResponseError(snapSyncState, peer, responseMessage); + verify(snapSyncState, times(1)).submitRequest(any(PeerSelector.class), any(RequestFactory.class)); } private void initializeBlockchainWithAmountOfBlocks(int numberOfBlocks) { diff --git a/rskj-core/src/test/java/co/rsk/net/ThreeAsyncNodeUsingSyncProcessorTest.java b/rskj-core/src/test/java/co/rsk/net/ThreeAsyncNodeUsingSyncProcessorTest.java index 1472821f78..beaca29f2c 100644 --- a/rskj-core/src/test/java/co/rsk/net/ThreeAsyncNodeUsingSyncProcessorTest.java +++ b/rskj-core/src/test/java/co/rsk/net/ThreeAsyncNodeUsingSyncProcessorTest.java @@ -190,7 +190,7 @@ public void synchronizeNewNodeWithTwoPeersDefault() { SimpleAsyncNode node1 = SimpleAsyncNode.createDefaultNode(b1); SimpleAsyncNode node2 = SimpleAsyncNode.createDefaultNode(b1); - SyncConfiguration syncConfiguration = new SyncConfiguration(2,1,1,1,20,192, 20, 10, 0, false, false, 60, 0); + SyncConfiguration syncConfiguration = new SyncConfiguration(2,1,1,1,20,192, 20, 10, 0, false, false, 0); SimpleAsyncNode node3 = SimpleAsyncNode.createNode(b2, syncConfiguration); Assertions.assertEquals(50, node1.getBestBlock().getNumber()); @@ -231,7 +231,7 @@ public void synchronizeNewNodeWithTwoPeers200Default() { SimpleAsyncNode node1 = SimpleAsyncNode.createDefaultNode(b1); SimpleAsyncNode node2 = SimpleAsyncNode.createDefaultNode(b1); - SyncConfiguration syncConfiguration = new SyncConfiguration(2,1,1,1,20,192, 20, 10, 0, false, false, 60, 0); + SyncConfiguration syncConfiguration = new SyncConfiguration(2,1,1,1,20,192, 20, 10, 0, false, false, 0); SimpleAsyncNode node3 = SimpleAsyncNode.createNode(b2, syncConfiguration); Assertions.assertEquals(200, node1.getBestBlock().getNumber()); @@ -272,7 +272,7 @@ public void synchronizeWithTwoPeers200AndOneFails() { SimpleAsyncNode node1 = SimpleAsyncNode.createDefaultNode(b1); SimpleAsyncNode node2 = SimpleAsyncNode.createDefaultNode(b1); - SyncConfiguration syncConfiguration = new SyncConfiguration(2,1,0,1,20,192, 20, 10, 0, false, false, 60, 0); + SyncConfiguration syncConfiguration = new SyncConfiguration(2,1,0,1,20,192, 20, 10, 0, false, false, 0); SimpleAsyncNode node3 = SimpleAsyncNode.createNode(b2, syncConfiguration); Assertions.assertEquals(200, node1.getBestBlock().getNumber()); @@ -319,7 +319,7 @@ public void synchronizeNewNodeWithTwoPeers200Different() { SimpleAsyncNode node1 = SimpleAsyncNode.createDefaultNode(b1); SimpleAsyncNode node2 = SimpleAsyncNode.createDefaultNode(b2); - SyncConfiguration syncConfiguration = new SyncConfiguration(2,1,1,1,20,192, 20, 10, 0, false, false, 60, 0); + SyncConfiguration syncConfiguration = new SyncConfiguration(2,1,1,1,20,192, 20, 10, 0, false, false, 0); SimpleAsyncNode node3 = SimpleAsyncNode.createNode(b3, syncConfiguration); Assertions.assertEquals(193, node1.getBestBlock().getNumber()); @@ -363,7 +363,7 @@ public void synchronizeNewNodeWithThreePeers400Different() { SimpleAsyncNode node1 = SimpleAsyncNode.createDefaultNode(b2); SimpleAsyncNode node2 = SimpleAsyncNode.createDefaultNode(b2); SimpleAsyncNode node3 = SimpleAsyncNode.createDefaultNode(b3); - SyncConfiguration syncConfiguration = new SyncConfiguration(3,1,10,100,20,192, 20, 10, 0, false, false, 60, 0); + SyncConfiguration syncConfiguration = new SyncConfiguration(3,1,10,100,20,192, 20, 10, 0, false, false, 0); SimpleAsyncNode node4 = SimpleAsyncNode.createNode(b1, syncConfiguration); Assertions.assertEquals(200, node1.getBestBlock().getNumber()); diff --git a/rskj-core/src/test/java/co/rsk/net/messages/SnapBlocksRequestMessageTest.java b/rskj-core/src/test/java/co/rsk/net/messages/SnapBlocksRequestMessageTest.java index e5443746b8..21f3f86d52 100644 --- a/rskj-core/src/test/java/co/rsk/net/messages/SnapBlocksRequestMessageTest.java +++ b/rskj-core/src/test/java/co/rsk/net/messages/SnapBlocksRequestMessageTest.java @@ -19,8 +19,11 @@ package co.rsk.net.messages; import co.rsk.blockchain.utils.BlockGenerator; +import co.rsk.config.TestSystemProperties; import org.ethereum.core.Block; +import org.ethereum.core.BlockFactory; import org.ethereum.util.RLP; +import org.ethereum.util.RLPList; import org.junit.jupiter.api.Test; import java.math.BigInteger; @@ -28,12 +31,15 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.mockito.Mockito.*; class SnapBlocksRequestMessageTest { + private final TestSystemProperties config = new TestSystemProperties(); + private final BlockFactory blockFactory = new BlockFactory(config.getActivationConfig()); private final Block block4Test = new BlockGenerator().getBlock(1); - private final SnapBlocksRequestMessage underTest = new SnapBlocksRequestMessage(block4Test.getNumber()); + private final SnapBlocksRequestMessage underTest = new SnapBlocksRequestMessage(1, block4Test.getNumber()); @Test @@ -52,8 +58,26 @@ void getEncodedMessage_returnExpectedByteArray() { //when byte[] encodedMessage = underTest.getEncodedMessage(); + byte[] expectedEncodedMessage = RLP.encodeList( + RLP.encodeBigInteger(BigInteger.valueOf(underTest.getId())), + RLP.encodeList(RLP.encodeBigInteger(BigInteger.ONE))); + + //then + assertThat(encodedMessage, equalTo(expectedEncodedMessage)); + } + + @Test + void decodeMessage_returnExpectedMessage() { + //given default block 4 test + RLPList encodedRLPList = (RLPList) RLP.decode2(underTest.getEncodedMessage()).get(0); + + //when + Message decodedMessage = SnapBlocksRequestMessage.decodeMessage(blockFactory, encodedRLPList); + //then - assertThat(encodedMessage, equalTo(RLP.encodeList(RLP.encodeBigInteger(BigInteger.ONE)))); + assertInstanceOf(SnapBlocksRequestMessage.class, decodedMessage); + assertThat(underTest.getId(), equalTo(((SnapBlocksRequestMessage) decodedMessage).getId())); + assertEquals(1, ((SnapBlocksRequestMessage) decodedMessage).getBlockNumber()); } @Test @@ -68,10 +92,10 @@ void getBlockNumber_returnTheExpectedValue() { } @Test - void givenAcceptIsCalled_messageVisitorIsAppliedFormessage() { + void givenAcceptIsCalled_messageVisitorIsAppliedForMessage() { //given Block block = new BlockGenerator().getBlock(1); - SnapBlocksRequestMessage message = new SnapBlocksRequestMessage(block.getNumber()); + SnapBlocksRequestMessage message = new SnapBlocksRequestMessage(1, block.getNumber()); MessageVisitor visitor = mock(MessageVisitor.class); //when @@ -80,4 +104,4 @@ void givenAcceptIsCalled_messageVisitorIsAppliedFormessage() { //then verify(visitor, times(1)).apply(message); } -} \ No newline at end of file +} diff --git a/rskj-core/src/test/java/co/rsk/net/messages/SnapBlocksResponseMessageTest.java b/rskj-core/src/test/java/co/rsk/net/messages/SnapBlocksResponseMessageTest.java index ca2268576d..0e2d2aa0ef 100644 --- a/rskj-core/src/test/java/co/rsk/net/messages/SnapBlocksResponseMessageTest.java +++ b/rskj-core/src/test/java/co/rsk/net/messages/SnapBlocksResponseMessageTest.java @@ -28,14 +28,18 @@ import org.ethereum.db.BlockStore; import org.ethereum.db.IndexedBlockStore; import org.ethereum.util.RLP; +import org.ethereum.util.RLPList; import org.junit.jupiter.api.Test; +import java.math.BigInteger; import java.util.Collections; import java.util.List; +import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.mockito.Mockito.*; class SnapBlocksResponseMessageTest { @@ -46,7 +50,7 @@ class SnapBlocksResponseMessageTest { private final Block block4Test = new BlockGenerator().getBlock(1); private final List blockList = Collections.singletonList(new BlockGenerator().getBlock(1)); private final List blockDifficulties = Collections.singletonList(indexedBlockStore.getTotalDifficultyForHash(block4Test.getHash().getBytes())); - private final SnapBlocksResponseMessage underTest = new SnapBlocksResponseMessage(blockList, blockDifficulties); + private final SnapBlocksResponseMessage underTest = new SnapBlocksResponseMessage(1, blockList, blockDifficulties); @Test @@ -62,8 +66,10 @@ void getMessageType_returnCorrectMessageType() { void getEncodedMessage_returnExpectedByteArray() { //given default block 4 test byte[] expectedEncodedMessage = RLP.encodeList( - RLP.encodeList(RLP.encode(block4Test.getEncoded())), - RLP.encodeList(RLP.encode(blockDifficulties.get(0).getBytes()))); + RLP.encodeBigInteger(BigInteger.valueOf(underTest.getId())), + RLP.encodeList( + RLP.encodeList(RLP.encode(block4Test.getEncoded())), + RLP.encodeList(RLP.encode(blockDifficulties.get(0).getBytes())))); //when byte[] encodedMessage = underTest.getEncodedMessage(); @@ -71,6 +77,23 @@ void getEncodedMessage_returnExpectedByteArray() { assertThat(encodedMessage, equalTo(expectedEncodedMessage)); } + @Test + void decodeMessage_returnExpectedMessage() { + //given default block 4 test + RLPList encodedRLPList = (RLPList) RLP.decode2(underTest.getEncodedMessage()).get(0); + + //when + Message decodedMessage = SnapBlocksResponseMessage.decodeMessage(blockFactory, encodedRLPList); + + //then + assertInstanceOf(SnapBlocksResponseMessage.class, decodedMessage); + assertThat(underTest.getId(), equalTo(((SnapBlocksResponseMessage) decodedMessage).getId())); + assertThat(1, is(((SnapBlocksResponseMessage) decodedMessage).getBlocks().size())); + assertThat(block4Test.getHash(), is(((SnapBlocksResponseMessage) decodedMessage).getBlocks().get(0).getHash())); + assertThat(1, is(((SnapBlocksResponseMessage) decodedMessage).getDifficulties().size())); + assertThat(blockDifficulties.get(0), is(((SnapBlocksResponseMessage) decodedMessage).getDifficulties().get(0))); + } + @Test void getDifficulties_returnTheExpectedValue() { //given default block 4 test @@ -94,7 +117,7 @@ void getBlocks_returnTheExpectedValue() { @Test void givenAcceptIsCalled_messageVisitorIsAppliedForMessage() { //given - SnapBlocksResponseMessage message = new SnapBlocksResponseMessage(blockList, blockDifficulties); + SnapBlocksResponseMessage message = new SnapBlocksResponseMessage(1, blockList, blockDifficulties); MessageVisitor visitor = mock(MessageVisitor.class); //when diff --git a/rskj-core/src/test/java/co/rsk/net/messages/SnapStateChunkRequestMessageTest.java b/rskj-core/src/test/java/co/rsk/net/messages/SnapStateChunkRequestMessageTest.java index e8f5280560..7629c82dd1 100644 --- a/rskj-core/src/test/java/co/rsk/net/messages/SnapStateChunkRequestMessageTest.java +++ b/rskj-core/src/test/java/co/rsk/net/messages/SnapStateChunkRequestMessageTest.java @@ -19,8 +19,11 @@ package co.rsk.net.messages; import co.rsk.blockchain.utils.BlockGenerator; +import co.rsk.config.TestSystemProperties; import org.ethereum.core.Block; +import org.ethereum.core.BlockFactory; import org.ethereum.util.RLP; +import org.ethereum.util.RLPList; import org.junit.jupiter.api.Test; import java.math.BigInteger; @@ -28,10 +31,14 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.mockito.Mockito.*; public class SnapStateChunkRequestMessageTest { + private final TestSystemProperties config = new TestSystemProperties(); + private final BlockFactory blockFactory = new BlockFactory(config.getActivationConfig()); + @Test void getMessageType_returnCorrectMessageType() { //given @@ -45,6 +52,7 @@ void getMessageType_returnCorrectMessageType() { //then assertThat(messageType, equalTo(MessageType.SNAP_STATE_CHUNK_REQUEST_MESSAGE)); } + @Test void givenParameters4Test_assureExpectedValues() { //given @@ -104,6 +112,28 @@ void getEncodedMessageWithId_returnExpectedByteArray() { assertThat(encodedMessage, equalTo(expectedEncodedMessage)); } + @Test + void decodeMessage_returnExpectedMessage() { + //given default block 4 test + long blockNumber = 1L; + long id4Test = 42L; + long from = 1L; + long chunkSize = 20L; + + SnapStateChunkRequestMessage message = new SnapStateChunkRequestMessage(id4Test, blockNumber, from, chunkSize); + RLPList encodedRLPList = (RLPList) RLP.decode2(message.getEncodedMessage()).get(0); + + //when + Message decodedMessage = SnapStateChunkRequestMessage.decodeMessage(blockFactory, encodedRLPList); + + //then + assertInstanceOf(SnapStateChunkRequestMessage.class, decodedMessage); + assertEquals(id4Test,((SnapStateChunkRequestMessage) decodedMessage).getId()); + assertEquals(from,((SnapStateChunkRequestMessage) decodedMessage).getFrom()); + assertEquals(blockNumber,((SnapStateChunkRequestMessage) decodedMessage).getBlockNumber()); + assertEquals(chunkSize,((SnapStateChunkRequestMessage) decodedMessage).getChunkSize()); + } + @Test void givenAcceptIsCalled_messageVisitorIsAppliedForMessage() { //given diff --git a/rskj-core/src/test/java/co/rsk/net/messages/SnapStateChunkResponseMessageTest.java b/rskj-core/src/test/java/co/rsk/net/messages/SnapStateChunkResponseMessageTest.java index 2ce50f7115..fe8f651ab2 100644 --- a/rskj-core/src/test/java/co/rsk/net/messages/SnapStateChunkResponseMessageTest.java +++ b/rskj-core/src/test/java/co/rsk/net/messages/SnapStateChunkResponseMessageTest.java @@ -19,19 +19,26 @@ package co.rsk.net.messages; import co.rsk.blockchain.utils.BlockGenerator; +import co.rsk.config.TestSystemProperties; import org.ethereum.core.Block; +import org.ethereum.core.BlockFactory; import org.ethereum.util.RLP; +import org.ethereum.util.RLPList; import org.junit.jupiter.api.Test; import java.math.BigInteger; +import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; -import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.*; public class SnapStateChunkResponseMessageTest { + private final TestSystemProperties config = new TestSystemProperties(); + private final BlockFactory blockFactory = new BlockFactory(config.getActivationConfig()); + @Test void getMessageType_returnCorrectMessageType() { //given @@ -81,11 +88,11 @@ void getEncodedMessageWithoutId_returnExpectedByteArray() { boolean complete = true; byte[] expectedEncodedMessage = RLP.encodeList( - trieValueBytes, + RLP.encodeElement(trieValueBytes), RLP.encodeBigInteger(BigInteger.valueOf(blockNumber)), RLP.encodeBigInteger(BigInteger.valueOf(from)), RLP.encodeBigInteger(BigInteger.valueOf(to)), - new byte[]{(byte) 1}); + RLP.encodeInt(complete ? 1 : 0)); SnapStateChunkResponseMessage message = new SnapStateChunkResponseMessage(id4Test, trieValueBytes, blockNumber, from, to, complete); @@ -114,7 +121,33 @@ void getEncodedMessageWithId_returnExpectedByteArray() { byte[] encodedMessage = message.getEncodedMessage(); //then - assertThat(encodedMessage, equalTo(expectedEncodedMessage)); + assertArrayEquals(encodedMessage, expectedEncodedMessage); + } + + @Test + void decodeMessage_returnExpectedMessage() { + //given default block 4 test + long blockNumber = 111L; + long id4Test = 42L; + byte[] trieValueBytes = "any random data".getBytes(); + long from = 5L; + long to = 20L; + boolean complete = false; + + SnapStateChunkResponseMessage message = new SnapStateChunkResponseMessage(id4Test, trieValueBytes, blockNumber, from, to, complete); + RLPList encodedRLPList = (RLPList) RLP.decode2(message.getEncodedMessage()).get(0); + + //when + Message decodedMessage = SnapStateChunkResponseMessage.decodeMessage(blockFactory, encodedRLPList); + + //then + assertInstanceOf(SnapStateChunkResponseMessage.class, decodedMessage); + assertEquals(id4Test,((SnapStateChunkResponseMessage) decodedMessage).getId()); + assertEquals(from,((SnapStateChunkResponseMessage) decodedMessage).getFrom()); + assertEquals(to,((SnapStateChunkResponseMessage) decodedMessage).getTo()); + assertEquals(blockNumber,((SnapStateChunkResponseMessage) decodedMessage).getBlockNumber()); + assertEquals(complete, ((SnapStateChunkResponseMessage) decodedMessage).isComplete()); + assertThat(trieValueBytes, is(((SnapStateChunkResponseMessage) decodedMessage).getChunkOfTrieKeyValue())); } @Test diff --git a/rskj-core/src/test/java/co/rsk/net/messages/SnapStatusRequestMessageTest.java b/rskj-core/src/test/java/co/rsk/net/messages/SnapStatusRequestMessageTest.java index ff34d0a463..967406dd48 100644 --- a/rskj-core/src/test/java/co/rsk/net/messages/SnapStatusRequestMessageTest.java +++ b/rskj-core/src/test/java/co/rsk/net/messages/SnapStatusRequestMessageTest.java @@ -18,18 +18,29 @@ */ package co.rsk.net.messages; +import co.rsk.config.TestSystemProperties; +import org.ethereum.core.BlockFactory; import org.ethereum.util.RLP; +import org.ethereum.util.RLPList; import org.junit.jupiter.api.Test; +import java.math.BigInteger; + import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.mockito.Mockito.*; class SnapStatusRequestMessageTest { + + private final TestSystemProperties config = new TestSystemProperties(); + private final BlockFactory blockFactory = new BlockFactory(config.getActivationConfig()); + @Test void getMessageType_returnCorrectMessageType() { //given - SnapStatusRequestMessage message = new SnapStatusRequestMessage(); + SnapStatusRequestMessage message = new SnapStatusRequestMessage(1); //when MessageType messageType = message.getMessageType(); @@ -41,8 +52,8 @@ void getMessageType_returnCorrectMessageType() { @Test void getEncodedMessage_returnExpectedByteArray() { //given - SnapStatusRequestMessage message = new SnapStatusRequestMessage(); - byte[] expectedEncodedMessage = RLP.encodedEmptyList(); + SnapStatusRequestMessage message = new SnapStatusRequestMessage(1); + byte[] expectedEncodedMessage = RLP.encodeList(RLP.encodeBigInteger(BigInteger.valueOf(1)), RLP.encodedEmptyList()); //when byte[] encodedMessage = message.getEncodedMessage(); @@ -50,10 +61,24 @@ void getEncodedMessage_returnExpectedByteArray() { assertThat(encodedMessage, equalTo(expectedEncodedMessage)); } + @Test + void decodeMessage_returnExpectedMessage() { + //given default block 4 test + SnapStatusRequestMessage message = new SnapStatusRequestMessage(111); + RLPList encodedRLPList = (RLPList) RLP.decode2(message.getEncodedMessage()).get(0); + + //when + Message decodedMessage = SnapStatusRequestMessage.decodeMessage(blockFactory, encodedRLPList); + + //then + assertInstanceOf(SnapStatusRequestMessage.class, decodedMessage); + assertEquals(111, ((SnapStatusRequestMessage) decodedMessage).getId()); + } + @Test void givenAcceptIsCalled_messageVisitorIsAppliedForMessage() { //given - SnapStatusRequestMessage message = new SnapStatusRequestMessage(); + SnapStatusRequestMessage message = new SnapStatusRequestMessage(1); MessageVisitor visitor = mock(MessageVisitor.class); //when diff --git a/rskj-core/src/test/java/co/rsk/net/messages/SnapStatusResponseMessageTest.java b/rskj-core/src/test/java/co/rsk/net/messages/SnapStatusResponseMessageTest.java index 623d0c4232..25e308edb0 100644 --- a/rskj-core/src/test/java/co/rsk/net/messages/SnapStatusResponseMessageTest.java +++ b/rskj-core/src/test/java/co/rsk/net/messages/SnapStatusResponseMessageTest.java @@ -28,6 +28,7 @@ import org.ethereum.db.BlockStore; import org.ethereum.db.IndexedBlockStore; import org.ethereum.util.RLP; +import org.ethereum.util.RLPList; import org.junit.jupiter.api.Test; import java.math.BigInteger; @@ -36,7 +37,8 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.mockito.Mockito.*; class SnapStatusResponseMessageTest { @@ -48,7 +50,7 @@ class SnapStatusResponseMessageTest { private final List blockList = Collections.singletonList(new BlockGenerator().getBlock(1)); private final List blockDifficulties = Collections.singletonList(indexedBlockStore.getTotalDifficultyForHash(block4Test.getHash().getBytes())); private final long trieSize = 1L; - private final SnapStatusResponseMessage underTest = new SnapStatusResponseMessage(blockList, blockDifficulties, trieSize); + private final SnapStatusResponseMessage underTest = new SnapStatusResponseMessage(1, blockList, blockDifficulties, trieSize); @Test @@ -64,9 +66,11 @@ void getMessageType_returnCorrectMessageType() { void getEncodedMessage_returnExpectedByteArray() { //given default block 4 test byte[] expectedEncodedMessage = RLP.encodeList( - RLP.encodeList(RLP.encode(block4Test.getEncoded())), - RLP.encodeList(RLP.encode(blockDifficulties.get(0).getBytes())), - RLP.encodeBigInteger(BigInteger.valueOf(this.trieSize))); + RLP.encodeBigInteger(BigInteger.valueOf(underTest.getId())), + RLP.encodeList( + RLP.encodeList(RLP.encode(block4Test.getEncoded())), + RLP.encodeList(RLP.encode(blockDifficulties.get(0).getBytes())), + RLP.encodeBigInteger(BigInteger.valueOf(this.trieSize)))); //when byte[] encodedMessage = underTest.getEncodedMessage(); @@ -74,6 +78,23 @@ void getEncodedMessage_returnExpectedByteArray() { assertThat(encodedMessage, equalTo(expectedEncodedMessage)); } + @Test + void decodeMessage_returnExpectedMessage() { + //given default block 4 test + RLPList encodedRLPList = (RLPList) RLP.decode2(underTest.getEncodedMessage()).get(0); + + //when + Message decodedMessage = SnapStatusResponseMessage.decodeMessage(blockFactory, encodedRLPList); + + //then + assertInstanceOf(SnapStatusResponseMessage.class, decodedMessage); + assertEquals(underTest.getId(), ((SnapStatusResponseMessage) decodedMessage).getId()); + assertEquals(1, ((SnapStatusResponseMessage) decodedMessage).getBlocks().size()); + assertEquals(underTest.getBlocks().get(0).getHash(), ((SnapStatusResponseMessage) decodedMessage).getBlocks().get(0).getHash()); + assertEquals(1, ((SnapStatusResponseMessage) decodedMessage).getDifficulties().size()); + assertEquals(underTest.getDifficulties().get(0), ((SnapStatusResponseMessage) decodedMessage).getDifficulties().get(0)); + } + @Test void getDifficulties_returnTheExpectedValue() { //given default block 4 test @@ -107,7 +128,7 @@ void getTrieSize_returnTheExpectedValue() { @Test void givenAcceptIsCalled_messageVisitorIsAppliedForMessage() { //given - SnapStatusResponseMessage message = new SnapStatusResponseMessage(blockList, blockDifficulties, trieSize); + SnapStatusResponseMessage message = new SnapStatusResponseMessage(1, blockList, blockDifficulties, trieSize); MessageVisitor visitor = mock(MessageVisitor.class); //when diff --git a/rskj-core/src/test/java/co/rsk/net/sync/SimpleSyncEventsHandler.java b/rskj-core/src/test/java/co/rsk/net/sync/SimpleSyncEventsHandler.java index d6694ecebe..b8a735fe02 100644 --- a/rskj-core/src/test/java/co/rsk/net/sync/SimpleSyncEventsHandler.java +++ b/rskj-core/src/test/java/co/rsk/net/sync/SimpleSyncEventsHandler.java @@ -19,11 +19,13 @@ package co.rsk.net.sync; import co.rsk.net.Peer; +import co.rsk.net.messages.MessageWithId; import co.rsk.scoring.EventType; import org.ethereum.core.Block; import org.ethereum.core.BlockHeader; import org.ethereum.core.BlockIdentifier; +import javax.annotation.Nonnull; import javax.annotation.Nullable; import java.util.Deque; import java.util.List; @@ -106,4 +108,10 @@ public boolean stopSyncingWasCalled() { @Override public void startSnapSync(Peer peer) { } + + @Override + public long nextMessageId() { return 0; } + + @Override + public void registerPendingMessage(@Nonnull MessageWithId message) { } } diff --git a/rskj-core/src/test/java/co/rsk/net/sync/SnapSyncStateTest.java b/rskj-core/src/test/java/co/rsk/net/sync/SnapSyncStateTest.java index f12e508048..a028c02304 100644 --- a/rskj-core/src/test/java/co/rsk/net/sync/SnapSyncStateTest.java +++ b/rskj-core/src/test/java/co/rsk/net/sync/SnapSyncStateTest.java @@ -21,29 +21,22 @@ import co.rsk.core.BlockDifficulty; import co.rsk.net.Peer; import co.rsk.net.SnapshotProcessor; -import co.rsk.net.messages.MessageType; -import co.rsk.net.messages.SnapBlocksResponseMessage; -import co.rsk.net.messages.SnapStateChunkResponseMessage; -import co.rsk.net.messages.SnapStatusResponseMessage; -import co.rsk.scoring.EventType; +import co.rsk.net.messages.*; import org.apache.commons.lang3.tuple.Pair; import org.ethereum.core.Block; -import org.ethereum.core.BlockHeader; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; import java.math.BigInteger; -import java.time.Duration; +import java.util.Collections; import java.util.List; import java.util.PriorityQueue; import java.util.Queue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.equalTo; import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.*; @@ -55,9 +48,10 @@ class SnapSyncStateTest { private final SyncEventsHandler syncEventsHandler = mock(SyncEventsHandler.class); private final SnapshotPeersInformation peersInformation = mock(SnapshotPeersInformation.class); private final SnapshotProcessor snapshotProcessor = mock(SnapshotProcessor.class); + private final SnapSyncRequestManager snapRequestManager = mock(SnapSyncRequestManager.class); private final SyncMessageHandler.Listener listener = mock(SyncMessageHandler.Listener.class); - private final SnapSyncState underTest = new SnapSyncState(syncEventsHandler, snapshotProcessor, syncConfiguration, listener); + private final SnapSyncState underTest = new SnapSyncState(syncEventsHandler, snapshotProcessor, snapRequestManager, syncConfiguration, listener); @BeforeEach void setUp() { @@ -96,19 +90,6 @@ void givenOnEnterWasCalledTwice_thenSyncingStartsOnlyOnce() { verify(snapshotProcessor, times(1)).startSyncing(underTest); } - @Test - void givenTickIsCalledBeforeTimeout_thenTimerIsUpdated_andNoTimeoutHappens() { - //given - Duration elapsedTime = Duration.ofMillis(10); - underTest.timeElapsed = Duration.ZERO; - // when - underTest.tick(elapsedTime); - //then - assertThat(underTest.timeElapsed, equalTo(elapsedTime)); - verify(syncEventsHandler, never()).stopSyncing(); - verify(syncEventsHandler, never()).onErrorSyncing(any(), any(), any(), any()); - } - @Test void givenFinishIsCalled_thenSyncEventHandlerStopsSync() { //given-when @@ -122,9 +103,10 @@ void givenFinishIsCalled_thenSyncEventHandlerStopsSync() { void givenOnSnapStatusIsCalled_thenJobIsAddedAndRun() throws InterruptedException { //given Peer peer = mock(Peer.class); - SnapStatusResponseMessage msg = mock(SnapStatusResponseMessage.class); + SnapStatusResponseMessage msg = new SnapStatusResponseMessage(1, Collections.emptyList(), Collections.emptyList(), 1); CountDownLatch latch = new CountDownLatch(1); doCountDownOnQueueEmpty(listener, latch); + doReturn(true).when(snapRequestManager).processResponse(any()); underTest.onEnter(); //when @@ -144,9 +126,10 @@ void givenOnSnapStatusIsCalled_thenJobIsAddedAndRun() throws InterruptedExceptio void givenOnSnapBlocksIsCalled_thenJobIsAddedAndRun() throws InterruptedException { //given Peer peer = mock(Peer.class); - SnapBlocksResponseMessage msg = mock(SnapBlocksResponseMessage.class); + SnapBlocksResponseMessage msg = new SnapBlocksResponseMessage(1, Collections.emptyList(), Collections.emptyList()); CountDownLatch latch = new CountDownLatch(1); doCountDownOnQueueEmpty(listener, latch); + doReturn(true).when(snapRequestManager).processResponse(any()); underTest.onEnter(); //when @@ -166,10 +149,10 @@ void givenOnSnapBlocksIsCalled_thenJobIsAddedAndRun() throws InterruptedExceptio void givenNewBlockHeadersIsCalled_thenJobIsAddedAndRun() throws InterruptedException { //given Peer peer = mock(Peer.class); - //noinspection unchecked - List msg = mock(List.class); + BlockHeadersResponseMessage msg = new BlockHeadersResponseMessage(1, Collections.emptyList()); CountDownLatch latch = new CountDownLatch(1); doCountDownOnQueueEmpty(listener, latch); + doReturn(true).when(snapRequestManager).processResponse(any()); underTest.onEnter(); //when @@ -189,9 +172,10 @@ void givenNewBlockHeadersIsCalled_thenJobIsAddedAndRun() throws InterruptedExcep void givenOnSnapStateChunkIsCalled_thenJobIsAddedAndRun() throws InterruptedException { //given Peer peer = mock(Peer.class); - SnapStateChunkResponseMessage msg = mock(SnapStateChunkResponseMessage.class); + SnapStateChunkResponseMessage msg = new SnapStateChunkResponseMessage(1, new byte[0], 1, 1, 1, true); CountDownLatch latch = new CountDownLatch(1); doCountDownOnQueueEmpty(listener, latch); + doReturn(true).when(snapRequestManager).processResponse(any()); underTest.onEnter(); //when @@ -207,20 +191,6 @@ void givenOnSnapStateChunkIsCalled_thenJobIsAddedAndRun() throws InterruptedExce assertEquals(msg.getMessageType(), jobArg.getValue().getMsgType()); } - @Test - void givenOnMessageTimeOut_thenShouldFail() throws InterruptedException { - //given - Peer peer = mock(Peer.class); - underTest.setLastBlock(mock(Block.class), mock(BlockDifficulty.class), peer); - underTest.setRunning(); - - //when - underTest.onMessageTimeOut(); - - //then - verify(syncEventsHandler, times(1)).onErrorSyncing(eq(peer), eq(EventType.TIMEOUT_MESSAGE), any()); - } - @Test void testSetAndGetLastBlock() { Block mockBlock = mock(Block.class);