Skip to content

Commit

Permalink
evm: cap decimal trimming by destination decimals
Browse files Browse the repository at this point in the history
  • Loading branch information
kcsongor committed Feb 26, 2024
1 parent 4a2bb5d commit 1d7954a
Show file tree
Hide file tree
Showing 12 changed files with 174 additions and 71 deletions.
27 changes: 19 additions & 8 deletions evm/src/NttManager/NttManager.sol
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ contract NttManager is INttManager, NttManagerState {
revert InvalidTargetChain(nativeTokenTransfer.toChain, chainId);
}
TrimmedAmount memory nativeTransferAmount =
(nativeTokenTransfer.amount.untrim(tokenDecimals_)).trim(tokenDecimals_);
(nativeTokenTransfer.amount.untrim(tokenDecimals_)).trim(tokenDecimals_, tokenDecimals_);

address transferRecipient = fromWormholeFormat(nativeTokenTransfer.to);

Expand Down Expand Up @@ -227,6 +227,7 @@ contract NttManager is INttManager, NttManagerState {
) internal {
uint256 numEnabledTransceivers = enabledTransceivers.length;
mapping(address => TransceiverInfo) storage transceiverInfos = _getTransceiverInfosStorage();
bytes32 peerAddress = _getPeersStorage()[recipientChain].peerAddress;
// call into transceiver contracts to send the message
for (uint256 i = 0; i < numEnabledTransceivers; i++) {
address transceiverAddr = enabledTransceivers[i];
Expand All @@ -235,7 +236,7 @@ contract NttManager is INttManager, NttManagerState {
recipientChain,
transceiverInstructions[transceiverInfos[transceiverAddr].index],
nttManagerMessage,
getPeer(recipientChain)
peerAddress
);
}
}
Expand Down Expand Up @@ -305,14 +306,15 @@ contract NttManager is INttManager, NttManagerState {
}

// trim amount after burning to ensure transfer amount matches (amount - fee)
TrimmedAmount memory trimmedAmount = _trimTransferAmount(amount);
TrimmedAmount memory trimmedAmount = _trimTransferAmount(amount, recipientChain);
TrimmedAmount memory internalAmount = trimmedAmount.shift(tokenDecimals_);

// get the sequence for this transfer
uint64 sequence = _useMessageSequence();

{
// now check rate limits
bool isAmountRateLimited = _isOutboundAmountRateLimited(trimmedAmount);
bool isAmountRateLimited = _isOutboundAmountRateLimited(internalAmount);
if (!shouldQueue && isAmountRateLimited) {
revert NotEnoughCapacity(getCurrentOutboundCapacity(), amount);
}
Expand Down Expand Up @@ -341,10 +343,10 @@ contract NttManager is INttManager, NttManagerState {
}

// otherwise, consume the outbound amount
_consumeOutboundAmount(trimmedAmount);
_consumeOutboundAmount(internalAmount);
// When sending a transfer, we refill the inbound rate limit for
// that chain by the same amount (we call this "backflow")
_backfillInboundAmount(trimmedAmount, recipientChain);
_backfillInboundAmount(internalAmount, recipientChain);

return _transfer(
sequence, trimmedAmount, recipientChain, recipient, msg.sender, transceiverInstructions
Expand Down Expand Up @@ -448,10 +450,19 @@ contract NttManager is INttManager, NttManagerState {
}
}

function _trimTransferAmount(uint256 amount) internal view returns (TrimmedAmount memory) {
function _trimTransferAmount(
uint256 amount,
uint16 toChain
) internal view returns (TrimmedAmount memory) {
uint8 toDecimals = _getPeersStorage()[toChain].tokenDecimals;

if (toDecimals == 0) {
revert InvalidPeerDecimals();
}

TrimmedAmount memory trimmedAmount;
{
trimmedAmount = amount.trim(tokenDecimals_);
trimmedAmount = amount.trim(tokenDecimals_, toDecimals);
// don't deposit dust that can not be bridged due to the decimal shift
uint256 newAmount = trimmedAmount.untrim(tokenDecimals_);
if (amount != newAmount) {
Expand Down
28 changes: 19 additions & 9 deletions evm/src/NttManager/NttManagerState.sol
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,11 @@ abstract contract NttManagerState is
}
}

function _getPeersStorage() internal pure returns (mapping(uint16 => bytes32) storage $) {
function _getPeersStorage()
internal
pure
returns (mapping(uint16 => NttManagerPeer) storage $)
{
uint256 slot = uint256(PEERS_SLOT);
assembly ("memory-safe") {
$.slot := slot
Expand Down Expand Up @@ -157,7 +161,7 @@ abstract contract NttManagerState is
}

/// @inheritdoc INttManagerState
function getPeer(uint16 chainId_) public view returns (bytes32) {
function getPeer(uint16 chainId_) external view returns (NttManagerPeer memory) {
return _getPeersStorage()[chainId_];
}

Expand Down Expand Up @@ -251,29 +255,35 @@ abstract contract NttManagerState is
}

/// @inheritdoc INttManagerState
function setPeer(uint16 peerChainId, bytes32 peerContract) public onlyOwner {
function setPeer(uint16 peerChainId, bytes32 peerContract, uint8 decimals) public onlyOwner {
if (peerChainId == 0) {
revert InvalidPeerChainIdZero();
}
if (peerContract == bytes32(0)) {
revert InvalidPeerZeroAddress();
}
if (decimals == 0) {
revert InvalidPeerDecimals();
}

bytes32 oldPeerContract = _getPeersStorage()[peerChainId];
NttManagerPeer memory oldPeer = _getPeersStorage()[peerChainId];

_getPeersStorage()[peerChainId] = peerContract;
_getPeersStorage()[peerChainId].peerAddress = peerContract;
_getPeersStorage()[peerChainId].tokenDecimals = decimals;

emit PeerUpdated(peerChainId, oldPeerContract, peerContract);
emit PeerUpdated(
peerChainId, oldPeer.peerAddress, oldPeer.tokenDecimals, peerContract, decimals
);
}

/// @inheritdoc INttManagerState
function setOutboundLimit(uint256 limit) external onlyOwner {
_setOutboundLimit(limit.trim(tokenDecimals_));
_setOutboundLimit(limit.trim(tokenDecimals_, tokenDecimals_));
}

/// @inheritdoc INttManagerState
function setInboundLimit(uint256 limit, uint16 chainId_) external onlyOwner {
_setInboundLimit(limit.trim(tokenDecimals_), chainId_);
_setInboundLimit(limit.trim(tokenDecimals_, tokenDecimals_), chainId_);
}

// =============== Internal ==============================================================
Expand Down Expand Up @@ -306,7 +316,7 @@ abstract contract NttManagerState is

/// @dev Verify that the peer address saved for `sourceChainId` matches the `peerAddress`.
function _verifyPeer(uint16 sourceChainId, bytes32 peerAddress) internal view {
if (getPeer(sourceChainId) != peerAddress) {
if (_getPeersStorage()[sourceChainId].peerAddress != peerAddress) {
revert InvalidPeer(sourceChainId, peerAddress);
}
}
Expand Down
12 changes: 10 additions & 2 deletions evm/src/interfaces/INttManagerEvents.sol
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,19 @@ interface INttManagerEvents {

/// @notice Emitted when the peer contract is updated.
/// @dev Topic0
/// 0x51b8437a7e22240c473f4cbdb4ed3a4f4bf5a9e7b3c511d7cfe0197325735700.
/// 0x1456404e7f41f35c3daac941bb50bad417a66275c3040061b4287d787719599d.
/// @param chainId_ The chain ID of the peer contract.
/// @param oldPeerContract The old peer contract address.
/// @param oldPeerDecimals The old peer contract decimals.
/// @param peerContract The new peer contract address.
event PeerUpdated(uint16 indexed chainId_, bytes32 oldPeerContract, bytes32 peerContract);
/// @param peerDecimals The new peer contract decimals.
event PeerUpdated(
uint16 indexed chainId_,
bytes32 oldPeerContract,
uint8 oldPeerDecimals,
bytes32 peerContract,
uint8 peerDecimals
);

/// @notice Emitted when a message has been attested to.
/// @dev Topic0
Expand Down
14 changes: 12 additions & 2 deletions evm/src/interfaces/INttManagerState.sol
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ interface INttManagerState {
/// @notice Peer cannot be the zero address.
error InvalidPeerZeroAddress();

/// @notice Peer cannot have zero decimals.
error InvalidPeerDecimals();

/// @notice The number of thresholds should not be zero.
error ZeroThreshold();

Expand All @@ -30,6 +33,12 @@ interface INttManagerState {
error ThresholdTooHigh(uint256 threshold, uint256 transceivers);
error RetrievedIncorrectRegisteredTransceivers(uint256 retrieved, uint256 registered);

/// @dev The peer on another chain.
struct NttManagerPeer {
bytes32 peerAddress;
uint8 tokenDecimals;
}

/// @notice Sets the transceiver for the given chain.
/// @param transceiver The address of the transceiver.
/// @dev This method can only be executed by the `owner`.
Expand All @@ -48,13 +57,14 @@ interface INttManagerState {

/// @notice Returns registered peer contract for a given chain.
/// @param chainId_ chain ID.
function getPeer(uint16 chainId_) external view returns (bytes32);
function getPeer(uint16 chainId_) external view returns (NttManagerPeer memory);

/// @notice Sets the corresponding peer.
/// @dev The nttManager that executes the message sets the source nttManager as the peer.
/// @param peerChainId The chain ID of the peer.
/// @param peerContract The address of the peer nttManager contract.
function setPeer(uint16 peerChainId, bytes32 peerContract) external;
/// @param decimals The number of decimals of the token on the peer chain.
function setPeer(uint16 peerChainId, bytes32 peerContract, uint8 decimals) external;

/// @notice Checks if a message has been approved. The message should have at least
/// the minimum threshold of attestations from distinct endpoints.
Expand Down
31 changes: 27 additions & 4 deletions evm/src/libraries/TrimmedAmount.sol
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ library TrimmedAmountLib {
return a.amount < b.amount;
}

// TODO: is this needed? let's remove it
function isZero(TrimmedAmount memory a) internal pure returns (bool) {
return (a.amount == 0 && a.decimals == 0);
}
Expand Down Expand Up @@ -140,16 +141,38 @@ library TrimmedAmountLib {
}
}

function trim(uint256 amt, uint8 fromDecimals) internal pure returns (TrimmedAmount memory) {
uint8 toDecimals = minUint8(TRIMMED_DECIMALS, fromDecimals);
uint256 amountScaled = scale(amt, fromDecimals, toDecimals);
function shift(
TrimmedAmount memory amount,
uint8 toDecimals
) internal pure returns (TrimmedAmount memory) {
uint8 actualToDecimals = minUint8(TRIMMED_DECIMALS, toDecimals);
return TrimmedAmount(
uint64(scale(amount.amount, amount.decimals, actualToDecimals)), actualToDecimals
);
}

/// @dev trim the amount to target decimals.
/// The actual resulting decimals is the minimum of TRIMMED_DECIMALS,
/// fromDecimals, and toDecimals. This ensures that no dust is
/// destroyed on either side of the transfer.
/// @param amt the amount to be trimmed
/// @param fromDecimals the original decimals of the amount
/// @param toDecimals the target decimals of the amount
///
function trim(
uint256 amt,
uint8 fromDecimals,
uint8 toDecimals
) internal pure returns (TrimmedAmount memory) {
uint8 actualToDecimals = minUint8(minUint8(TRIMMED_DECIMALS, fromDecimals), toDecimals);
uint256 amountScaled = scale(amt, fromDecimals, actualToDecimals);

// NOTE: amt after trimming must fit into uint64 (that's the point of
// trimming, as Solana only supports uint64 for token amts)
if (amountScaled > type(uint64).max) {
revert AmountTooLarge(amt);
}
return TrimmedAmount(uint64(amountScaled), toDecimals);
return TrimmedAmount(uint64(amountScaled), actualToDecimals);
}

function untrim(TrimmedAmount memory amt, uint8 toDecimals) internal pure returns (uint256) {
Expand Down
16 changes: 8 additions & 8 deletions evm/test/IntegrationRelayer.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ contract TestEndToEndRelayer is
wormholeTransceiverChain2.setWormholePeer(
chainId1, bytes32(uint256(uint160(address(wormholeTransceiverChain1))))
);
nttManagerChain2.setPeer(chainId1, bytes32(uint256(uint160(address(nttManagerChain1)))));
nttManagerChain2.setPeer(chainId1, bytes32(uint256(uint160(address(nttManagerChain1)))), 9);
DummyToken token2 = DummyTokenMintAndBurn(nttManagerChain2.token());
wormholeTransceiverChain2.setIsWormholeRelayingEnabled(chainId1, true);
wormholeTransceiverChain2.setIsWormholeEvmChain(chainId1);
Expand All @@ -178,7 +178,7 @@ contract TestEndToEndRelayer is
wormholeTransceiverChain1.setWormholePeer(
chainId2, bytes32(uint256(uint160((address(wormholeTransceiverChain2)))))
);
nttManagerChain1.setPeer(chainId2, bytes32(uint256(uint160(address(nttManagerChain2)))));
nttManagerChain1.setPeer(chainId2, bytes32(uint256(uint160(address(nttManagerChain2)))), 7);

// Enable general relaying on the chain to transfer for the funds.
wormholeTransceiverChain1.setIsWormholeRelayingEnabled(chainId2, true);
Expand Down Expand Up @@ -265,14 +265,14 @@ contract TestEndToEndRelayer is
wormholeTransceiverChain2.setWormholePeer(
chainId1, bytes32(uint256(uint160(address(wormholeTransceiverChain1))))
);
nttManagerChain2.setPeer(chainId1, bytes32(uint256(uint160(address(nttManagerChain1)))));
nttManagerChain2.setPeer(chainId1, bytes32(uint256(uint160(address(nttManagerChain1)))), 9);
DummyToken token2 = DummyTokenMintAndBurn(nttManagerChain2.token());
wormholeTransceiverChain2.setIsWormholeRelayingEnabled(chainId1, true);
wormholeTransceiverChain2.setIsWormholeEvmChain(chainId1);

// Register peer contracts for the nttManager and transceiver. Transceivers and nttManager each have the concept of peers here.
vm.selectFork(sourceFork);
nttManagerChain1.setPeer(chainId2, bytes32(uint256(uint160(address(nttManagerChain2)))));
nttManagerChain1.setPeer(chainId2, bytes32(uint256(uint160(address(nttManagerChain2)))), 7);
wormholeTransceiverChain1.setWormholePeer(
chainId2, bytes32(uint256(uint160((address(wormholeTransceiverChain2)))))
);
Expand Down Expand Up @@ -495,8 +495,8 @@ contract TestRelayerEndToEndManual is
nttManagerChain2.setInboundLimit(type(uint64).max, chainId1);

// Register peer contracts for the nttManager and transceiver. Transceivers and nttManager each have the concept of peers here.
nttManagerChain1.setPeer(chainId2, bytes32(uint256(uint160(address(nttManagerChain2)))));
nttManagerChain2.setPeer(chainId1, bytes32(uint256(uint160(address(nttManagerChain1)))));
nttManagerChain1.setPeer(chainId2, bytes32(uint256(uint160(address(nttManagerChain2)))), 9);
nttManagerChain2.setPeer(chainId1, bytes32(uint256(uint160(address(nttManagerChain1)))), 7);
}

function test_relayerTransceiverAuth() public {
Expand Down Expand Up @@ -551,7 +551,7 @@ contract TestRelayerEndToEndManual is

bytes[] memory a;

nttManagerChain2.setPeer(chainId1, bytes32(uint256(uint160(address(0x1)))));
nttManagerChain2.setPeer(chainId1, bytes32(uint256(uint160(address(0x1)))), 9);
vm.startPrank(relayer);
vm.expectRevert(); // bad nttManager peer
wormholeTransceiverChain2.receiveWormholeMessages(
Expand All @@ -564,7 +564,7 @@ contract TestRelayerEndToEndManual is
vm.stopPrank();

// Wrong caller - aka not relayer contract
nttManagerChain2.setPeer(chainId1, bytes32(uint256(uint160(address(nttManagerChain1)))));
nttManagerChain2.setPeer(chainId1, bytes32(uint256(uint160(address(nttManagerChain1)))), 9);
vm.prank(userD);
vm.expectRevert(
abi.encodeWithSelector(IWormholeTransceiverState.CallerNotRelayer.selector, userD)
Expand Down
4 changes: 2 additions & 2 deletions evm/test/IntegrationStandalone.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ contract TestEndToEndBase is Test, INttManagerEvents, IRateLimiterEvents {
nttManagerChain2.setInboundLimit(type(uint64).max, chainId1);

// Register peer contracts for the nttManager and transceiver. Transceivers and nttManager each have the concept of peers here.
nttManagerChain1.setPeer(chainId2, bytes32(uint256(uint160(address(nttManagerChain2)))));
nttManagerChain2.setPeer(chainId1, bytes32(uint256(uint160(address(nttManagerChain1)))));
nttManagerChain1.setPeer(chainId2, bytes32(uint256(uint160(address(nttManagerChain2)))), 9);
nttManagerChain2.setPeer(chainId1, bytes32(uint256(uint160(address(nttManagerChain1)))), 7);

// Set peers for the transceivers
wormholeTransceiverChain1.setWormholePeer(
Expand Down
10 changes: 6 additions & 4 deletions evm/test/NttManager.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ contract TestNttManager is Test, INttManagerEvents, IRateLimiterEvents {
(DummyTransceiver e1,) = TransceiverHelpersLib.setup_transceivers(nttManagerOther);
nttManagerOther.removeTransceiver(address(e1));
bytes32 peer = toWormholeFormat(address(nttManager));
nttManagerOther.setPeer(TransceiverHelpersLib.SENDING_CHAIN_ID, peer);
nttManagerOther.setPeer(TransceiverHelpersLib.SENDING_CHAIN_ID, peer, 9);

bytes memory transceiverMessage;
(, transceiverMessage) = TransceiverHelpersLib.buildTransceiverMessageWithNttManagerPayload(
Expand Down Expand Up @@ -238,7 +238,7 @@ contract TestNttManager is Test, INttManagerEvents, IRateLimiterEvents {

// register nttManager peer
bytes32 peer = toWormholeFormat(address(nttManager));
nttManagerOther.setPeer(TransceiverHelpersLib.SENDING_CHAIN_ID, peer);
nttManagerOther.setPeer(TransceiverHelpersLib.SENDING_CHAIN_ID, peer, 9);

TransceiverStructs.NttManagerMessage memory nttManagerMessage;
bytes memory transceiverMessage;
Expand All @@ -261,7 +261,7 @@ contract TestNttManager is Test, INttManagerEvents, IRateLimiterEvents {

// register nttManager peer
bytes32 peer = toWormholeFormat(address(nttManager));
nttManagerOther.setPeer(TransceiverHelpersLib.SENDING_CHAIN_ID, peer);
nttManagerOther.setPeer(TransceiverHelpersLib.SENDING_CHAIN_ID, peer, 9);

TransceiverStructs.NttManagerMessage memory nttManagerMessage;
bytes memory transceiverMessage;
Expand Down Expand Up @@ -289,7 +289,7 @@ contract TestNttManager is Test, INttManagerEvents, IRateLimiterEvents {
nttManagerOther.setThreshold(2);

bytes32 peer = toWormholeFormat(address(nttManager));
nttManagerOther.setPeer(TransceiverHelpersLib.SENDING_CHAIN_ID, peer);
nttManagerOther.setPeer(TransceiverHelpersLib.SENDING_CHAIN_ID, peer, 9);

ITransceiverReceiver[] memory transceivers = new ITransceiverReceiver[](1);
transceivers[0] = e1;
Expand Down Expand Up @@ -326,6 +326,7 @@ contract TestNttManager is Test, INttManagerEvents, IRateLimiterEvents {

uint8 decimals = token.decimals();

nttManager.setPeer(chainId, toWormholeFormat(address(0x1)), 9);
nttManager.setOutboundLimit(TrimmedAmount(type(uint64).max, 8).untrim(decimals));

token.mintDummy(address(user_A), 5 * 10 ** decimals);
Expand Down Expand Up @@ -444,6 +445,7 @@ contract TestNttManager is Test, INttManagerEvents, IRateLimiterEvents {

uint256 maxAmount = 5 * 10 ** decimals;
token.mintDummy(from, maxAmount);
nttManager.setPeer(chainId, toWormholeFormat(address(0x1)), 9);
nttManager.setOutboundLimit(TrimmedAmount(type(uint64).max, 8).untrim(decimals));
nttManager.setInboundLimit(
TrimmedAmount(type(uint64).max, 8).untrim(decimals),
Expand Down
Loading

0 comments on commit 1d7954a

Please sign in to comment.