diff --git a/evm/src/NttManager/NttManager.sol b/evm/src/NttManager/NttManager.sol index 06d7e558b..eef8c6606 100644 --- a/evm/src/NttManager/NttManager.sol +++ b/evm/src/NttManager/NttManager.sol @@ -138,7 +138,7 @@ contract NttManager is INttManager, NttManagerState { revert InvalidTargetChain(nativeTokenTransfer.toChain, chainId); } TrimmedAmount memory nativeTransferAmount = - (nativeTokenTransfer.amount.untrim(tokenDecimals_)).trim(tokenDecimals_); + (nativeTokenTransfer.amount.untrim(tokenDecimals_)).trim(tokenDecimals_, tokenDecimals_); address transferRecipient = fromWormholeFormat(nativeTokenTransfer.to); @@ -227,6 +227,7 @@ contract NttManager is INttManager, NttManagerState { ) internal { uint256 numEnabledTransceivers = enabledTransceivers.length; mapping(address => TransceiverInfo) storage transceiverInfos = _getTransceiverInfosStorage(); + bytes32 peerAddress = _getPeersStorage()[recipientChain].peerAddress; // call into transceiver contracts to send the message for (uint256 i = 0; i < numEnabledTransceivers; i++) { address transceiverAddr = enabledTransceivers[i]; @@ -235,7 +236,7 @@ contract NttManager is INttManager, NttManagerState { recipientChain, transceiverInstructions[transceiverInfos[transceiverAddr].index], nttManagerMessage, - getPeer(recipientChain) + peerAddress ); } } @@ -305,14 +306,15 @@ contract NttManager is INttManager, NttManagerState { } // trim amount after burning to ensure transfer amount matches (amount - fee) - TrimmedAmount memory trimmedAmount = _trimTransferAmount(amount); + TrimmedAmount memory trimmedAmount = _trimTransferAmount(amount, recipientChain); + TrimmedAmount memory internalAmount = trimmedAmount.shift(tokenDecimals_); // get the sequence for this transfer uint64 sequence = _useMessageSequence(); { // now check rate limits - bool isAmountRateLimited = _isOutboundAmountRateLimited(trimmedAmount); + bool isAmountRateLimited = _isOutboundAmountRateLimited(internalAmount); if (!shouldQueue && isAmountRateLimited) { revert NotEnoughCapacity(getCurrentOutboundCapacity(), amount); } @@ -341,10 +343,10 @@ contract NttManager is INttManager, NttManagerState { } // otherwise, consume the outbound amount - _consumeOutboundAmount(trimmedAmount); + _consumeOutboundAmount(internalAmount); // When sending a transfer, we refill the inbound rate limit for // that chain by the same amount (we call this "backflow") - _backfillInboundAmount(trimmedAmount, recipientChain); + _backfillInboundAmount(internalAmount, recipientChain); return _transfer( sequence, trimmedAmount, recipientChain, recipient, msg.sender, transceiverInstructions @@ -448,10 +450,19 @@ contract NttManager is INttManager, NttManagerState { } } - function _trimTransferAmount(uint256 amount) internal view returns (TrimmedAmount memory) { + function _trimTransferAmount( + uint256 amount, + uint16 toChain + ) internal view returns (TrimmedAmount memory) { + uint8 toDecimals = _getPeersStorage()[toChain].tokenDecimals; + + if (toDecimals == 0) { + revert InvalidPeerDecimals(); + } + TrimmedAmount memory trimmedAmount; { - trimmedAmount = amount.trim(tokenDecimals_); + trimmedAmount = amount.trim(tokenDecimals_, toDecimals); // don't deposit dust that can not be bridged due to the decimal shift uint256 newAmount = trimmedAmount.untrim(tokenDecimals_); if (amount != newAmount) { diff --git a/evm/src/NttManager/NttManagerState.sol b/evm/src/NttManager/NttManagerState.sol index 470716e01..fa2f807da 100644 --- a/evm/src/NttManager/NttManagerState.sol +++ b/evm/src/NttManager/NttManagerState.sol @@ -121,7 +121,11 @@ abstract contract NttManagerState is } } - function _getPeersStorage() internal pure returns (mapping(uint16 => bytes32) storage $) { + function _getPeersStorage() + internal + pure + returns (mapping(uint16 => NttManagerPeer) storage $) + { uint256 slot = uint256(PEERS_SLOT); assembly ("memory-safe") { $.slot := slot @@ -157,7 +161,7 @@ abstract contract NttManagerState is } /// @inheritdoc INttManagerState - function getPeer(uint16 chainId_) public view returns (bytes32) { + function getPeer(uint16 chainId_) external view returns (NttManagerPeer memory) { return _getPeersStorage()[chainId_]; } @@ -251,29 +255,35 @@ abstract contract NttManagerState is } /// @inheritdoc INttManagerState - function setPeer(uint16 peerChainId, bytes32 peerContract) public onlyOwner { + function setPeer(uint16 peerChainId, bytes32 peerContract, uint8 decimals) public onlyOwner { if (peerChainId == 0) { revert InvalidPeerChainIdZero(); } if (peerContract == bytes32(0)) { revert InvalidPeerZeroAddress(); } + if (decimals == 0) { + revert InvalidPeerDecimals(); + } - bytes32 oldPeerContract = _getPeersStorage()[peerChainId]; + NttManagerPeer memory oldPeer = _getPeersStorage()[peerChainId]; - _getPeersStorage()[peerChainId] = peerContract; + _getPeersStorage()[peerChainId].peerAddress = peerContract; + _getPeersStorage()[peerChainId].tokenDecimals = decimals; - emit PeerUpdated(peerChainId, oldPeerContract, peerContract); + emit PeerUpdated( + peerChainId, oldPeer.peerAddress, oldPeer.tokenDecimals, peerContract, decimals + ); } /// @inheritdoc INttManagerState function setOutboundLimit(uint256 limit) external onlyOwner { - _setOutboundLimit(limit.trim(tokenDecimals_)); + _setOutboundLimit(limit.trim(tokenDecimals_, tokenDecimals_)); } /// @inheritdoc INttManagerState function setInboundLimit(uint256 limit, uint16 chainId_) external onlyOwner { - _setInboundLimit(limit.trim(tokenDecimals_), chainId_); + _setInboundLimit(limit.trim(tokenDecimals_, tokenDecimals_), chainId_); } // =============== Internal ============================================================== @@ -306,7 +316,7 @@ abstract contract NttManagerState is /// @dev Verify that the peer address saved for `sourceChainId` matches the `peerAddress`. function _verifyPeer(uint16 sourceChainId, bytes32 peerAddress) internal view { - if (getPeer(sourceChainId) != peerAddress) { + if (_getPeersStorage()[sourceChainId].peerAddress != peerAddress) { revert InvalidPeer(sourceChainId, peerAddress); } } diff --git a/evm/src/interfaces/INttManagerEvents.sol b/evm/src/interfaces/INttManagerEvents.sol index 4fceb34ab..b518eb81c 100644 --- a/evm/src/interfaces/INttManagerEvents.sol +++ b/evm/src/interfaces/INttManagerEvents.sol @@ -18,11 +18,19 @@ interface INttManagerEvents { /// @notice Emitted when the peer contract is updated. /// @dev Topic0 - /// 0x51b8437a7e22240c473f4cbdb4ed3a4f4bf5a9e7b3c511d7cfe0197325735700. + /// 0x1456404e7f41f35c3daac941bb50bad417a66275c3040061b4287d787719599d. /// @param chainId_ The chain ID of the peer contract. /// @param oldPeerContract The old peer contract address. + /// @param oldPeerDecimals The old peer contract decimals. /// @param peerContract The new peer contract address. - event PeerUpdated(uint16 indexed chainId_, bytes32 oldPeerContract, bytes32 peerContract); + /// @param peerDecimals The new peer contract decimals. + event PeerUpdated( + uint16 indexed chainId_, + bytes32 oldPeerContract, + uint8 oldPeerDecimals, + bytes32 peerContract, + uint8 peerDecimals + ); /// @notice Emitted when a message has been attested to. /// @dev Topic0 diff --git a/evm/src/interfaces/INttManagerState.sol b/evm/src/interfaces/INttManagerState.sol index a789ba965..a9a8ab6c6 100644 --- a/evm/src/interfaces/INttManagerState.sol +++ b/evm/src/interfaces/INttManagerState.sol @@ -21,6 +21,9 @@ interface INttManagerState { /// @notice Peer cannot be the zero address. error InvalidPeerZeroAddress(); + /// @notice Peer cannot have zero decimals. + error InvalidPeerDecimals(); + /// @notice The number of thresholds should not be zero. error ZeroThreshold(); @@ -30,6 +33,12 @@ interface INttManagerState { error ThresholdTooHigh(uint256 threshold, uint256 transceivers); error RetrievedIncorrectRegisteredTransceivers(uint256 retrieved, uint256 registered); + /// @dev The peer on another chain. + struct NttManagerPeer { + bytes32 peerAddress; + uint8 tokenDecimals; + } + /// @notice Sets the transceiver for the given chain. /// @param transceiver The address of the transceiver. /// @dev This method can only be executed by the `owner`. @@ -48,13 +57,14 @@ interface INttManagerState { /// @notice Returns registered peer contract for a given chain. /// @param chainId_ chain ID. - function getPeer(uint16 chainId_) external view returns (bytes32); + function getPeer(uint16 chainId_) external view returns (NttManagerPeer memory); /// @notice Sets the corresponding peer. /// @dev The nttManager that executes the message sets the source nttManager as the peer. /// @param peerChainId The chain ID of the peer. /// @param peerContract The address of the peer nttManager contract. - function setPeer(uint16 peerChainId, bytes32 peerContract) external; + /// @param decimals The number of decimals of the token on the peer chain. + function setPeer(uint16 peerChainId, bytes32 peerContract, uint8 decimals) external; /// @notice Checks if a message has been approved. The message should have at least /// the minimum threshold of attestations from distinct endpoints. diff --git a/evm/src/libraries/TrimmedAmount.sol b/evm/src/libraries/TrimmedAmount.sol index 9d8eecfec..d546a8731 100644 --- a/evm/src/libraries/TrimmedAmount.sol +++ b/evm/src/libraries/TrimmedAmount.sol @@ -65,6 +65,7 @@ library TrimmedAmountLib { return a.amount < b.amount; } + // TODO: is this needed? let's remove it function isZero(TrimmedAmount memory a) internal pure returns (bool) { return (a.amount == 0 && a.decimals == 0); } @@ -140,16 +141,38 @@ library TrimmedAmountLib { } } - function trim(uint256 amt, uint8 fromDecimals) internal pure returns (TrimmedAmount memory) { - uint8 toDecimals = minUint8(TRIMMED_DECIMALS, fromDecimals); - uint256 amountScaled = scale(amt, fromDecimals, toDecimals); + function shift( + TrimmedAmount memory amount, + uint8 toDecimals + ) internal pure returns (TrimmedAmount memory) { + uint8 actualToDecimals = minUint8(TRIMMED_DECIMALS, toDecimals); + return TrimmedAmount( + uint64(scale(amount.amount, amount.decimals, actualToDecimals)), actualToDecimals + ); + } + + /// @dev trim the amount to target decimals. + /// The actual resulting decimals is the minimum of TRIMMED_DECIMALS, + /// fromDecimals, and toDecimals. This ensures that no dust is + /// destroyed on either side of the transfer. + /// @param amt the amount to be trimmed + /// @param fromDecimals the original decimals of the amount + /// @param toDecimals the target decimals of the amount + /// + function trim( + uint256 amt, + uint8 fromDecimals, + uint8 toDecimals + ) internal pure returns (TrimmedAmount memory) { + uint8 actualToDecimals = minUint8(minUint8(TRIMMED_DECIMALS, fromDecimals), toDecimals); + uint256 amountScaled = scale(amt, fromDecimals, actualToDecimals); // NOTE: amt after trimming must fit into uint64 (that's the point of // trimming, as Solana only supports uint64 for token amts) if (amountScaled > type(uint64).max) { revert AmountTooLarge(amt); } - return TrimmedAmount(uint64(amountScaled), toDecimals); + return TrimmedAmount(uint64(amountScaled), actualToDecimals); } function untrim(TrimmedAmount memory amt, uint8 toDecimals) internal pure returns (uint256) { diff --git a/evm/test/IntegrationRelayer.t.sol b/evm/test/IntegrationRelayer.t.sol index 5c63aeebc..6cc9bc7c7 100755 --- a/evm/test/IntegrationRelayer.t.sol +++ b/evm/test/IntegrationRelayer.t.sol @@ -167,7 +167,7 @@ contract TestEndToEndRelayer is wormholeTransceiverChain2.setWormholePeer( chainId1, bytes32(uint256(uint160(address(wormholeTransceiverChain1)))) ); - nttManagerChain2.setPeer(chainId1, bytes32(uint256(uint160(address(nttManagerChain1))))); + nttManagerChain2.setPeer(chainId1, bytes32(uint256(uint160(address(nttManagerChain1)))), 9); DummyToken token2 = DummyTokenMintAndBurn(nttManagerChain2.token()); wormholeTransceiverChain2.setIsWormholeRelayingEnabled(chainId1, true); wormholeTransceiverChain2.setIsWormholeEvmChain(chainId1); @@ -178,7 +178,7 @@ contract TestEndToEndRelayer is wormholeTransceiverChain1.setWormholePeer( chainId2, bytes32(uint256(uint160((address(wormholeTransceiverChain2))))) ); - nttManagerChain1.setPeer(chainId2, bytes32(uint256(uint160(address(nttManagerChain2))))); + nttManagerChain1.setPeer(chainId2, bytes32(uint256(uint160(address(nttManagerChain2)))), 7); // Enable general relaying on the chain to transfer for the funds. wormholeTransceiverChain1.setIsWormholeRelayingEnabled(chainId2, true); @@ -265,14 +265,14 @@ contract TestEndToEndRelayer is wormholeTransceiverChain2.setWormholePeer( chainId1, bytes32(uint256(uint160(address(wormholeTransceiverChain1)))) ); - nttManagerChain2.setPeer(chainId1, bytes32(uint256(uint160(address(nttManagerChain1))))); + nttManagerChain2.setPeer(chainId1, bytes32(uint256(uint160(address(nttManagerChain1)))), 9); DummyToken token2 = DummyTokenMintAndBurn(nttManagerChain2.token()); wormholeTransceiverChain2.setIsWormholeRelayingEnabled(chainId1, true); wormholeTransceiverChain2.setIsWormholeEvmChain(chainId1); // Register peer contracts for the nttManager and transceiver. Transceivers and nttManager each have the concept of peers here. vm.selectFork(sourceFork); - nttManagerChain1.setPeer(chainId2, bytes32(uint256(uint160(address(nttManagerChain2))))); + nttManagerChain1.setPeer(chainId2, bytes32(uint256(uint160(address(nttManagerChain2)))), 7); wormholeTransceiverChain1.setWormholePeer( chainId2, bytes32(uint256(uint160((address(wormholeTransceiverChain2))))) ); @@ -495,8 +495,8 @@ contract TestRelayerEndToEndManual is nttManagerChain2.setInboundLimit(type(uint64).max, chainId1); // Register peer contracts for the nttManager and transceiver. Transceivers and nttManager each have the concept of peers here. - nttManagerChain1.setPeer(chainId2, bytes32(uint256(uint160(address(nttManagerChain2))))); - nttManagerChain2.setPeer(chainId1, bytes32(uint256(uint160(address(nttManagerChain1))))); + nttManagerChain1.setPeer(chainId2, bytes32(uint256(uint160(address(nttManagerChain2)))), 9); + nttManagerChain2.setPeer(chainId1, bytes32(uint256(uint160(address(nttManagerChain1)))), 7); } function test_relayerTransceiverAuth() public { @@ -551,7 +551,7 @@ contract TestRelayerEndToEndManual is bytes[] memory a; - nttManagerChain2.setPeer(chainId1, bytes32(uint256(uint160(address(0x1))))); + nttManagerChain2.setPeer(chainId1, bytes32(uint256(uint160(address(0x1)))), 9); vm.startPrank(relayer); vm.expectRevert(); // bad nttManager peer wormholeTransceiverChain2.receiveWormholeMessages( @@ -564,7 +564,7 @@ contract TestRelayerEndToEndManual is vm.stopPrank(); // Wrong caller - aka not relayer contract - nttManagerChain2.setPeer(chainId1, bytes32(uint256(uint160(address(nttManagerChain1))))); + nttManagerChain2.setPeer(chainId1, bytes32(uint256(uint160(address(nttManagerChain1)))), 9); vm.prank(userD); vm.expectRevert( abi.encodeWithSelector(IWormholeTransceiverState.CallerNotRelayer.selector, userD) diff --git a/evm/test/IntegrationStandalone.t.sol b/evm/test/IntegrationStandalone.t.sol index 5cbb09924..e371ec9b1 100755 --- a/evm/test/IntegrationStandalone.t.sol +++ b/evm/test/IntegrationStandalone.t.sol @@ -114,8 +114,8 @@ contract TestEndToEndBase is Test, INttManagerEvents, IRateLimiterEvents { nttManagerChain2.setInboundLimit(type(uint64).max, chainId1); // Register peer contracts for the nttManager and transceiver. Transceivers and nttManager each have the concept of peers here. - nttManagerChain1.setPeer(chainId2, bytes32(uint256(uint160(address(nttManagerChain2))))); - nttManagerChain2.setPeer(chainId1, bytes32(uint256(uint160(address(nttManagerChain1))))); + nttManagerChain1.setPeer(chainId2, bytes32(uint256(uint160(address(nttManagerChain2)))), 9); + nttManagerChain2.setPeer(chainId1, bytes32(uint256(uint160(address(nttManagerChain1)))), 7); // Set peers for the transceivers wormholeTransceiverChain1.setWormholePeer( diff --git a/evm/test/NttManager.t.sol b/evm/test/NttManager.t.sol index 7ee503b85..a6209dbb6 100644 --- a/evm/test/NttManager.t.sol +++ b/evm/test/NttManager.t.sol @@ -200,7 +200,7 @@ contract TestNttManager is Test, INttManagerEvents, IRateLimiterEvents { (DummyTransceiver e1,) = TransceiverHelpersLib.setup_transceivers(nttManagerOther); nttManagerOther.removeTransceiver(address(e1)); bytes32 peer = toWormholeFormat(address(nttManager)); - nttManagerOther.setPeer(TransceiverHelpersLib.SENDING_CHAIN_ID, peer); + nttManagerOther.setPeer(TransceiverHelpersLib.SENDING_CHAIN_ID, peer, 9); bytes memory transceiverMessage; (, transceiverMessage) = TransceiverHelpersLib.buildTransceiverMessageWithNttManagerPayload( @@ -238,7 +238,7 @@ contract TestNttManager is Test, INttManagerEvents, IRateLimiterEvents { // register nttManager peer bytes32 peer = toWormholeFormat(address(nttManager)); - nttManagerOther.setPeer(TransceiverHelpersLib.SENDING_CHAIN_ID, peer); + nttManagerOther.setPeer(TransceiverHelpersLib.SENDING_CHAIN_ID, peer, 9); TransceiverStructs.NttManagerMessage memory nttManagerMessage; bytes memory transceiverMessage; @@ -261,7 +261,7 @@ contract TestNttManager is Test, INttManagerEvents, IRateLimiterEvents { // register nttManager peer bytes32 peer = toWormholeFormat(address(nttManager)); - nttManagerOther.setPeer(TransceiverHelpersLib.SENDING_CHAIN_ID, peer); + nttManagerOther.setPeer(TransceiverHelpersLib.SENDING_CHAIN_ID, peer, 9); TransceiverStructs.NttManagerMessage memory nttManagerMessage; bytes memory transceiverMessage; @@ -289,7 +289,7 @@ contract TestNttManager is Test, INttManagerEvents, IRateLimiterEvents { nttManagerOther.setThreshold(2); bytes32 peer = toWormholeFormat(address(nttManager)); - nttManagerOther.setPeer(TransceiverHelpersLib.SENDING_CHAIN_ID, peer); + nttManagerOther.setPeer(TransceiverHelpersLib.SENDING_CHAIN_ID, peer, 9); ITransceiverReceiver[] memory transceivers = new ITransceiverReceiver[](1); transceivers[0] = e1; @@ -326,6 +326,7 @@ contract TestNttManager is Test, INttManagerEvents, IRateLimiterEvents { uint8 decimals = token.decimals(); + nttManager.setPeer(chainId, toWormholeFormat(address(0x1)), 9); nttManager.setOutboundLimit(TrimmedAmount(type(uint64).max, 8).untrim(decimals)); token.mintDummy(address(user_A), 5 * 10 ** decimals); @@ -444,6 +445,7 @@ contract TestNttManager is Test, INttManagerEvents, IRateLimiterEvents { uint256 maxAmount = 5 * 10 ** decimals; token.mintDummy(from, maxAmount); + nttManager.setPeer(chainId, toWormholeFormat(address(0x1)), 9); nttManager.setOutboundLimit(TrimmedAmount(type(uint64).max, 8).untrim(decimals)); nttManager.setInboundLimit( TrimmedAmount(type(uint64).max, 8).untrim(decimals), diff --git a/evm/test/RateLimit.t.sol b/evm/test/RateLimit.t.sol index 484dc01bf..913df44b6 100644 --- a/evm/test/RateLimit.t.sol +++ b/evm/test/RateLimit.t.sol @@ -40,6 +40,8 @@ contract TestRateLimit is Test, IRateLimiterEvents { nttManager = MockNttManagerContract(address(new ERC1967Proxy(address(implementation), ""))); nttManager.initialize(); + + nttManager.setPeer(chainId, toWormholeFormat(address(0x1)), 9); } function test_outboundRateLimit_setLimitSimple() public { @@ -52,8 +54,11 @@ contract TestRateLimit is Test, IRateLimiterEvents { IRateLimiter.RateLimitParams memory outboundLimitParams = nttManager.getOutboundLimitParams(); - assertEq(outboundLimitParams.limit.getAmount(), limit.trim(decimals).getAmount()); - assertEq(outboundLimitParams.currentCapacity.getAmount(), limit.trim(decimals).getAmount()); + assertEq(outboundLimitParams.limit.getAmount(), limit.trim(decimals, decimals).getAmount()); + assertEq( + outboundLimitParams.currentCapacity.getAmount(), + limit.trim(decimals, decimals).getAmount() + ); assertEq(outboundLimitParams.lastTxTimestamp, initialBlockTimestamp); } @@ -83,7 +88,7 @@ contract TestRateLimit is Test, IRateLimiterEvents { nttManager.getOutboundLimitParams(); assertEq( outboundLimitParams.currentCapacity.getAmount(), - (outboundLimit - transferAmount).trim(decimals).getAmount() + (outboundLimit - transferAmount).trim(decimals, decimals).getAmount() ); assertEq(outboundLimitParams.lastTxTimestamp, initialBlockTimestamp); @@ -127,11 +132,13 @@ contract TestRateLimit is Test, IRateLimiterEvents { IRateLimiter.RateLimitParams memory outboundLimitParams = nttManager.getOutboundLimitParams(); - assertEq(outboundLimitParams.limit.getAmount(), higherLimit.trim(decimals).getAmount()); + assertEq( + outboundLimitParams.limit.getAmount(), higherLimit.trim(decimals, decimals).getAmount() + ); assertEq(outboundLimitParams.lastTxTimestamp, initialBlockTimestamp); assertEq( outboundLimitParams.currentCapacity.getAmount(), - (2 * 10 ** decimals).trim(decimals).getAmount() + (2 * 10 ** decimals).trim(decimals, decimals).getAmount() ); } @@ -204,7 +211,9 @@ contract TestRateLimit is Test, IRateLimiterEvents { IRateLimiter.RateLimitParams memory outboundLimitParams = nttManager.getOutboundLimitParams(); - assertEq(outboundLimitParams.limit.getAmount(), higherLimit.trim(decimals).getAmount()); + assertEq( + outboundLimitParams.limit.getAmount(), higherLimit.trim(decimals, decimals).getAmount() + ); assertEq(outboundLimitParams.lastTxTimestamp, sixHoursLater); // capacity should be: // difference in limits + remaining capacity after t1 + the amount that's refreshed (based on the old rps) @@ -213,7 +222,7 @@ contract TestRateLimit is Test, IRateLimiterEvents { ( (1 * 10 ** decimals) + (1 * 10 ** decimals) + (outboundLimit * (6 hours)) / nttManager.rateLimitDuration() - ).trim(decimals).getAmount() + ).trim(decimals, decimals).getAmount() ); } @@ -251,7 +260,9 @@ contract TestRateLimit is Test, IRateLimiterEvents { IRateLimiter.RateLimitParams memory outboundLimitParams = nttManager.getOutboundLimitParams(); - assertEq(outboundLimitParams.limit.getAmount(), lowerLimit.trim(decimals).getAmount()); + assertEq( + outboundLimitParams.limit.getAmount(), lowerLimit.trim(decimals, decimals).getAmount() + ); assertEq(outboundLimitParams.lastTxTimestamp, sixHoursLater); // capacity should be: 0 assertEq(outboundLimitParams.currentCapacity.getAmount(), 0); @@ -292,7 +303,9 @@ contract TestRateLimit is Test, IRateLimiterEvents { IRateLimiter.RateLimitParams memory outboundLimitParams = nttManager.getOutboundLimitParams(); - assertEq(outboundLimitParams.limit.getAmount(), lowerLimit.trim(decimals).getAmount()); + assertEq( + outboundLimitParams.limit.getAmount(), lowerLimit.trim(decimals, decimals).getAmount() + ); assertEq(outboundLimitParams.lastTxTimestamp, sixHoursLater); // capacity should be: // remaining capacity after t1 - difference in limits + the amount that's refreshed (based on the old rps) @@ -301,7 +314,7 @@ contract TestRateLimit is Test, IRateLimiterEvents { ( (3 * 10 ** decimals) - (1 * 10 ** decimals) + (outboundLimit * (6 hours)) / nttManager.rateLimitDuration() - ).trim(decimals).getAmount() + ).trim(decimals, decimals).getAmount() ); } @@ -351,7 +364,7 @@ contract TestRateLimit is Test, IRateLimiterEvents { // assert currentCapacity is updated TrimmedAmount memory newCapacity = - outboundLimit.trim(decimals).sub(transferAmount.trim(decimals)); + outboundLimit.trim(decimals, decimals).sub(transferAmount.trim(decimals, decimals)); assertEq(nttManager.getCurrentOutboundCapacity(), newCapacity.untrim(decimals)); uint256 badTransferAmount = 2 * 10 ** decimals; @@ -396,7 +409,7 @@ contract TestRateLimit is Test, IRateLimiterEvents { // assert that the transfer got queued up assertEq(qSeq, 0); IRateLimiter.OutboundQueuedTransfer memory qt = nttManager.getOutboundQueuedTransfer(0); - assertEq(qt.amount.getAmount(), transferAmount.trim(decimals).getAmount()); + assertEq(qt.amount.getAmount(), transferAmount.trim(decimals, decimals).getAmount()); assertEq(qt.recipientChain, chainId); assertEq(qt.recipient, toWormholeFormat(user_B)); assertEq(qt.txTimestamp, initialBlockTimestamp); @@ -488,7 +501,7 @@ contract TestRateLimit is Test, IRateLimiterEvents { nttManager, nttManager, TrimmedAmount(50, 8), - uint256(5).trim(token.decimals()), + uint256(5).trim(token.decimals(), token.decimals()), transceivers ); encodedEm = TransceiverStructs.encodeTransceiverMessage( @@ -572,6 +585,7 @@ contract TestRateLimit is Test, IRateLimiterEvents { DummyToken token = DummyToken(nttManager.token()); uint8 decimals = token.decimals(); + assertEq(decimals, 18); TrimmedAmount memory mintAmount = TrimmedAmount(50, 8); token.mintDummy(address(user_A), mintAmount.untrim(decimals)); diff --git a/evm/test/TrimmedAmount.t.sol b/evm/test/TrimmedAmount.t.sol index 613fc1b54..9d7b1b3d5 100644 --- a/evm/test/TrimmedAmount.t.sol +++ b/evm/test/TrimmedAmount.t.sol @@ -12,13 +12,25 @@ contract TrimmingTest is Test { function testTrimmingRoundTrip() public { uint8 decimals = 18; uint256 amount = 50 * 10 ** decimals; - TrimmedAmount memory trimmed = amount.trim(decimals); + TrimmedAmount memory trimmed = amount.trim(decimals, 8); uint256 roundTrip = trimmed.untrim(decimals); uint256 expectedAmount = 50 * 10 ** decimals; assertEq(expectedAmount, roundTrip); } + function testTrimLessThan8() public { + uint8 decimals = 7; + uint8 targetDecimals = 3; + uint256 amount = 9123412342342; + TrimmedAmount memory trimmed = amount.trim(decimals, targetDecimals); + + uint64 expectedAmount = 912341234; + uint8 expectedDecimals = targetDecimals; + assertEq(trimmed.amount, expectedAmount); + assertEq(trimmed.decimals, expectedDecimals); + } + function testAddOperatorNonZero() public pure { uint8[2] memory decimals = [18, 3]; uint8[2] memory expectedDecimals = [8, 3]; @@ -26,8 +38,8 @@ contract TrimmingTest is Test { for (uint8 i = 0; i < decimals.length; i++) { uint256 amount = 5 * 10 ** decimals[i]; uint256 amountOther = 2 * 10 ** decimals[i]; - TrimmedAmount memory trimmedAmount = amount.trim(decimals[i]); - TrimmedAmount memory trimmedAmountOther = amountOther.trim(decimals[i]); + TrimmedAmount memory trimmedAmount = amount.trim(decimals[i], 8); + TrimmedAmount memory trimmedAmountOther = amountOther.trim(decimals[i], 8); TrimmedAmount memory trimmedSum = trimmedAmount.add(trimmedAmountOther); TrimmedAmount memory expectedTrimmedSum = @@ -43,8 +55,8 @@ contract TrimmingTest is Test { for (uint8 i = 0; i < decimals.length; i++) { uint256 amount = 5 * 10 ** decimals[i]; uint256 amountOther = 0; - TrimmedAmount memory trimmedAmount = amount.trim(decimals[i]); - TrimmedAmount memory trimmedAmountOther = amountOther.trim(decimals[i]); + TrimmedAmount memory trimmedAmount = amount.trim(decimals[i], 8); + TrimmedAmount memory trimmedAmountOther = amountOther.trim(decimals[i], 8); TrimmedAmount memory trimmedSum = trimmedAmount.add(trimmedAmountOther); TrimmedAmount memory expectedTrimmedSum = @@ -59,8 +71,8 @@ contract TrimmingTest is Test { uint256 amount = 5 * 10 ** decimals; uint256 amountOther = 2 * 10 ** decimalsOther; - TrimmedAmount memory trimmedAmount = amount.trim(decimals); - TrimmedAmount memory trimmedAmountOther = amountOther.trim(decimalsOther); + TrimmedAmount memory trimmedAmount = amount.trim(decimals, 8); + TrimmedAmount memory trimmedAmountOther = amountOther.trim(decimalsOther, 8); vm.expectRevert(); trimmedAmount.add(trimmedAmountOther); @@ -73,8 +85,8 @@ contract TrimmingTest is Test { for (uint8 i = 0; i < decimals.length; i++) { uint256 amount = 5 * 10 ** decimals[i]; uint256 amountOther = 2 * 10 ** 9; - TrimmedAmount memory trimmedAmount = amount.trim(decimals[i]); - TrimmedAmount memory trimmedAmountOther = amountOther.trim(9); + TrimmedAmount memory trimmedAmount = amount.trim(decimals[i], 8); + TrimmedAmount memory trimmedAmountOther = amountOther.trim(9, 8); TrimmedAmount memory trimmedSum = trimmedAmount.add(trimmedAmountOther); TrimmedAmount memory expectedTrimmedSum = @@ -90,8 +102,8 @@ contract TrimmingTest is Test { for (uint8 i = 0; i < decimals.length; i++) { uint256 amount = 5 * 10 ** decimals[i]; uint256 amountOther = 2 * 10 ** decimals[i]; - TrimmedAmount memory trimmedAmount = amount.trim(decimals[i]); - TrimmedAmount memory trimmedAmountOther = amountOther.trim(decimals[i]); + TrimmedAmount memory trimmedAmount = amount.trim(decimals[i], 8); + TrimmedAmount memory trimmedAmountOther = amountOther.trim(decimals[i], 8); TrimmedAmount memory trimmedSub = trimmedAmount.sub(trimmedAmountOther); TrimmedAmount memory expectedTrimmedSub = @@ -107,8 +119,8 @@ contract TrimmingTest is Test { for (uint8 i = 0; i < decimals.length; i++) { uint256 amount = 5 * 10 ** decimals[i]; uint256 amountOther = 0; - TrimmedAmount memory trimmedAmount = amount.trim(decimals[i]); - TrimmedAmount memory trimmedAmountOther = amountOther.trim(decimals[i]); + TrimmedAmount memory trimmedAmount = amount.trim(decimals[i], 8); + TrimmedAmount memory trimmedAmountOther = amountOther.trim(decimals[i], 8); TrimmedAmount memory trimmedSub = trimmedAmount.sub(trimmedAmountOther); TrimmedAmount memory expectedTrimmedSub = @@ -123,8 +135,8 @@ contract TrimmingTest is Test { for (uint8 i = 0; i < decimals.length; i++) { uint256 amount = 5 * 10 ** decimals[i]; uint256 amountOther = 6 * 10 ** decimals[i]; - TrimmedAmount memory trimmedAmount = amount.trim(decimals[i]); - TrimmedAmount memory trimmedAmountOther = amountOther.trim(decimals[i]); + TrimmedAmount memory trimmedAmount = amount.trim(decimals[i], 8); + TrimmedAmount memory trimmedAmountOther = amountOther.trim(decimals[i], 8); vm.expectRevert(); trimmedAmount.sub(trimmedAmountOther); @@ -136,7 +148,7 @@ contract TrimmingTest is Test { uint8 targetDecimals = 6; uint256 amount = 5 * 10 ** sourceDecimals; - TrimmedAmount memory trimmedAmount = amount.trim(sourceDecimals); + TrimmedAmount memory trimmedAmount = amount.trim(sourceDecimals, 8); // trimmed to 8 uint256 amountRoundTrip = trimmedAmount.untrim(targetDecimals); // untrim to 6 diff --git a/evm/test/Upgrades.t.sol b/evm/test/Upgrades.t.sol index 7dafa5c9c..a60ff2050 100644 --- a/evm/test/Upgrades.t.sol +++ b/evm/test/Upgrades.t.sol @@ -114,8 +114,16 @@ contract TestUpgrades is Test, INttManagerEvents, IRateLimiterEvents { nttManagerChain2.setInboundLimit(type(uint64).max, chainId1); // Register peer contracts for the nttManager and transceiver. Transceivers and nttManager each have the concept of peers here. - nttManagerChain1.setPeer(chainId2, bytes32(uint256(uint160(address(nttManagerChain2))))); - nttManagerChain2.setPeer(chainId1, bytes32(uint256(uint160(address(nttManagerChain1))))); + nttManagerChain1.setPeer( + chainId2, + bytes32(uint256(uint160(address(nttManagerChain2)))), + DummyToken(nttManagerChain2.token()).decimals() + ); + nttManagerChain2.setPeer( + chainId1, + bytes32(uint256(uint160(address(nttManagerChain1)))), + DummyToken(nttManagerChain1.token()).decimals() + ); wormholeTransceiverChain1.setWormholePeer( chainId2, bytes32(uint256(uint160((address(wormholeTransceiverChain2))))) diff --git a/evm/test/libraries/NttManagerHelpers.sol b/evm/test/libraries/NttManagerHelpers.sol index bc424c81f..7276f5b2d 100644 --- a/evm/test/libraries/NttManagerHelpers.sol +++ b/evm/test/libraries/NttManagerHelpers.sol @@ -16,7 +16,12 @@ library NttManagerHelpersLib { NttManager recipientNttManager, uint8 decimals ) internal { - recipientNttManager.setPeer(SENDING_CHAIN_ID, toWormholeFormat(address(nttManager))); + (, bytes memory queriedDecimals) = + address(nttManager.token()).staticcall(abi.encodeWithSignature("decimals()")); + uint8 tokenDecimals = abi.decode(queriedDecimals, (uint8)); + recipientNttManager.setPeer( + SENDING_CHAIN_ID, toWormholeFormat(address(nttManager)), tokenDecimals + ); recipientNttManager.setInboundLimit(inboundLimit.untrim(decimals), SENDING_CHAIN_ID); } }