Skip to content

Commit

Permalink
evm: Wormhole endpoint registration
Browse files Browse the repository at this point in the history
  • Loading branch information
djb15 committed Feb 23, 2024
1 parent 2650cb7 commit d0b1233
Show file tree
Hide file tree
Showing 9 changed files with 202 additions and 36 deletions.
2 changes: 1 addition & 1 deletion evm/src/Endpoint.sol
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ abstract contract Endpoint is

/// =============== ADMIN ===============================================

function _initialize() internal override {
function _initialize() internal virtual override {
__ReentrancyGuard_init();
// owner of the endpoint is set to the owner of the manager
__PausedOwnable_init(msg.sender, getManagerOwner());
Expand Down
8 changes: 6 additions & 2 deletions evm/src/Manager.sol
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ contract Manager is
emit ThresholdChanged(oldThreshold, threshold);
}

function getMode() public view returns (uint8) {
return uint8(mode);
}

/// @notice Returns the number of Endpoints that must attest to a msgId for
/// it to be considered valid and acted upon.
function getThreshold() public view returns (uint8) {
Expand Down Expand Up @@ -760,8 +764,8 @@ contract Manager is
return countSetBits(_getMessageAttestations(digest));
}

function _tokenDecimals() internal view override returns (uint8) {
return tokenDecimals;
function tokenDecimals() public view override(IManager, RateLimiter) returns (uint8) {
return tokenDecimals_;
}

// @dev Count the number of set bits in a uint64
Expand Down
8 changes: 4 additions & 4 deletions evm/src/NttNormalizer.sol
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ abstract contract NttNormalizer {
using NormalizedAmountLib for uint256;
using NormalizedAmountLib for NormalizedAmount;

uint8 immutable tokenDecimals;
uint8 internal immutable tokenDecimals_;

constructor(address _token) {
tokenDecimals = _tokenDecimals(_token);
tokenDecimals_ = _tokenDecimals(_token);
}

function _tokenDecimals(address token) internal view returns (uint8) {
Expand All @@ -19,11 +19,11 @@ abstract contract NttNormalizer {
}

function nttNormalize(uint256 amount) public view returns (NormalizedAmount memory) {
return amount.normalize(tokenDecimals);
return amount.normalize(tokenDecimals_);
}

function nttDenormalize(NormalizedAmount memory amount) public view returns (uint256) {
return amount.denormalize(tokenDecimals);
return amount.denormalize(tokenDecimals_);
}

/// @dev Shift decimals of `amount` to match the token decimals
Expand Down
46 changes: 45 additions & 1 deletion evm/src/WormholeEndpoint.sol
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ import "wormhole-solidity-sdk/libraries/BytesParsing.sol";
import "wormhole-solidity-sdk/interfaces/IWormhole.sol";

import "./libraries/EndpointHelpers.sol";
import "./libraries/EndpointStructs.sol";
import "./interfaces/IWormholeEndpoint.sol";
import "./interfaces/ISpecialRelayer.sol";
import "./interfaces/IManager.sol";
import "./Endpoint.sol";

contract WormholeEndpoint is Endpoint, IWormholeEndpoint, IWormholeReceiver {
Expand All @@ -22,6 +24,14 @@ contract WormholeEndpoint is Endpoint, IWormholeEndpoint, IWormholeReceiver {
/// Note that this is not a security critical field. It's meant to be used by messaging providers to identify which messages are Endpoint-related.
bytes4 constant WH_ENDPOINT_PAYLOAD_PREFIX = 0x9945FF10;

/// @dev Prefix for all Wormhole endpoint initialisation payloads
/// This is bytes4(keccak256("WormholeEndpointInit"))
bytes4 constant WH_ENDPOINT_INIT_PREFIX = 0xc83e3d2e;

/// @dev Prefix for all Wormhole sibling registration payloads
/// This is bytes4(keccak256("WormholeSiblingRegistration"))
bytes4 constant WH_SIBLING_REGISTRATION_PREFIX = 0xd0d292f1;

IWormhole public immutable wormhole;
IWormholeRelayer public immutable wormholeRelayer;
ISpecialRelayer public immutable specialRelayer;
Expand Down Expand Up @@ -126,6 +136,22 @@ contract WormholeEndpoint is Endpoint, IWormholeEndpoint, IWormholeReceiver {
consistencyLevel = _consistencyLevel;
}

function _initialize() internal override {
super._initialize();
_initializeEndpoint();
}

function _initializeEndpoint() internal {
EndpointStructs.EndpointInit memory init = EndpointStructs.EndpointInit({
endpointIdentifier: WH_ENDPOINT_INIT_PREFIX,
managerAddress: toWormholeFormat(manager),
managerMode: IManager(manager).getMode(),
tokenAddress: toWormholeFormat(managerToken),
tokenDecimals: IManager(manager).tokenDecimals()
});
wormhole.publishMessage(0, EndpointStructs.encodeEndpointInit(init), consistencyLevel);
}

function _checkInvalidRelayingConfig(uint16 chainId) internal view returns (bool) {
return isWormholeRelayingEnabled(chainId) && !isWormholeEvmChain(chainId);
}
Expand Down Expand Up @@ -328,9 +354,27 @@ contract WormholeEndpoint is Endpoint, IWormholeEndpoint, IWormholeReceiver {

bytes32 oldSiblingContract = _getWormholeSiblingsStorage()[chainId];

// We don't want to allow updating a sibling since this adds complexity in the accountant
// If the owner makes a mistake with sibling registration they should deploy a new Wormhole
// endpoint and register this new endpoint with the Manager
if (oldSiblingContract != bytes32(0)) {
revert SiblingAlreadySet(chainId, oldSiblingContract);
}

_getWormholeSiblingsStorage()[chainId] = siblingContract;

emit SetWormholeSibling(chainId, oldSiblingContract, siblingContract);
// Publish a message for this endpoint registration
EndpointStructs.EndpointRegistration memory registration = EndpointStructs
.EndpointRegistration({
endpointIdentifier: WH_SIBLING_REGISTRATION_PREFIX,
endpointChainId: chainId,
endpointAddress: siblingContract
});
wormhole.publishMessage(
0, EndpointStructs.encodeEndpointRegistration(registration), consistencyLevel
);

emit SetWormholeSibling(chainId, siblingContract);
}

function isWormholeRelayingEnabled(uint16 chainId) public view returns (bool) {
Expand Down
8 changes: 8 additions & 0 deletions evm/src/interfaces/IManager.sol
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,12 @@ interface IManager {
///
/// @param newImplementation The address of the new implementation.
function upgrade(address newImplementation) external;

/// @notice Returns the mode (locking or burning) of the Manager.
/// @return mode A uint8 corresponding to the mode
function getMode() external view returns (uint8);

/// @notice Returns the number of decimals of the token managed by the Manager.
/// @return decimals The number of decimals of the token.
function tokenDecimals() external view returns (uint8);
}
3 changes: 2 additions & 1 deletion evm/src/interfaces/IWormholeEndpoint.sol
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ interface IWormholeEndpoint {
);

event SendEndpointMessage(uint16 recipientChain, EndpointStructs.EndpointMessage message);
event SetWormholeSibling(uint16 chainId, bytes32 oldSiblingContract, bytes32 siblingContract);
event SetWormholeSibling(uint16 chainId, bytes32 siblingContract);
event SetIsWormholeRelayingEnabled(uint16 chainId, bool isRelayingEnabled);
event SetIsSpecialRelayingEnabled(uint16 chainId, bool isRelayingEnabled);
event SetIsWormholeEvmChain(uint16 chainId);
Expand All @@ -20,6 +20,7 @@ interface IWormholeEndpoint {
error UnexpectedAdditionalMessages();
error InvalidVaa(string reason);
error InvalidWormholeSibling(uint16 chainId, bytes32 siblingAddress);
error SiblingAlreadySet(uint16 chainId, bytes32 siblingAddress);
error TransferAlreadyCompleted(bytes32 vaaHash);
error InvalidWormholeSiblingZeroAddress();
error InvalidWormholeChainIdZero();
Expand Down
62 changes: 62 additions & 0 deletions evm/src/libraries/EndpointStructs.sol
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,68 @@ library EndpointStructs {
return instructions;
}

struct EndpointInit {
bytes4 endpointIdentifier;
bytes32 managerAddress;
uint8 managerMode;
bytes32 tokenAddress;
uint8 tokenDecimals;
}

function encodeEndpointInit(EndpointInit memory init) public pure returns (bytes memory) {
return abi.encodePacked(
init.endpointIdentifier,
init.managerAddress,
init.managerMode,
init.tokenAddress,
init.tokenDecimals
);
}

function decodeEndpointInit(bytes memory encoded)
public
pure
returns (EndpointInit memory init)
{
uint256 offset = 0;
(init.endpointIdentifier, offset) = encoded.asBytes4Unchecked(offset);
(init.managerAddress, offset) = encoded.asBytes32Unchecked(offset);
(init.managerMode, offset) = encoded.asUint8Unchecked(offset);
(init.tokenAddress, offset) = encoded.asBytes32Unchecked(offset);
(init.tokenDecimals, offset) = encoded.asUint8Unchecked(offset);
encoded.checkLength(offset);
}

struct EndpointRegistration {
bytes4 endpointIdentifier;
uint16 endpointChainId;
bytes32 endpointAddress;
}

function encodeEndpointRegistration(EndpointRegistration memory registration)
public
pure
returns (bytes memory)
{
return abi.encodePacked(
registration.endpointIdentifier,
registration.endpointChainId,
registration.endpointAddress
);
}

function decodeEndpointRegistration(bytes memory encoded)
public
pure
returns (EndpointRegistration memory registration)
{
uint256 offset = 0;
(registration.endpointIdentifier, offset) = encoded.asBytes4Unchecked(offset);
(registration.endpointChainId, offset) = encoded.asUint16Unchecked(offset);
(registration.endpointAddress, offset) = encoded.asBytes32Unchecked(offset);
encoded.checkLength(offset);
}

/*
* @dev This function takes a list of EndpointInstructions and expands them to a 256-length list,
* inserting each instruction into the expanded list based on `instruction.index`.
Expand Down
6 changes: 3 additions & 3 deletions evm/src/libraries/RateLimiter.sol
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ abstract contract RateLimiter is IRateLimiter, IRateLimiterEvents {

function getCurrentOutboundCapacity() public view returns (uint256) {
NormalizedAmount memory normalizedCapacity = _getCurrentCapacity(getOutboundLimitParams());
uint8 decimals = _tokenDecimals();
uint8 decimals = tokenDecimals();
return normalizedCapacity.denormalize(decimals);
}

Expand All @@ -119,7 +119,7 @@ abstract contract RateLimiter is IRateLimiter, IRateLimiterEvents {
function getCurrentInboundCapacity(uint16 chainId_) public view returns (uint256) {
NormalizedAmount memory normalizedCapacity =
_getCurrentCapacity(getInboundLimitParams(chainId_));
uint8 decimals = _tokenDecimals();
uint8 decimals = tokenDecimals();
return normalizedCapacity.denormalize(decimals);
}

Expand Down Expand Up @@ -295,5 +295,5 @@ abstract contract RateLimiter is IRateLimiter, IRateLimiterEvents {
emit InboundTransferQueued(digest);
}

function _tokenDecimals() internal view virtual returns (uint8);
function tokenDecimals() public view virtual returns (uint8);
}
95 changes: 71 additions & 24 deletions evm/test/IntegrationRelayer.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -483,16 +483,17 @@ contract TestRelayerEndToEndManual is
// Register sibling contracts for the manager and endpoint. Endpoints and manager each have the concept of siblings here.
managerChain1.setSibling(chainId2, bytes32(uint256(uint160(address(managerChain2)))));
managerChain2.setSibling(chainId1, bytes32(uint256(uint160(address(managerChain1)))));
}

function test_relayerEndpointAuth() public {
// Set up sensible WH endpoint siblings
wormholeEndpointChain1.setWormholeSibling(
chainId2, bytes32(uint256(uint160((address(wormholeEndpointChain2)))))
);
wormholeEndpointChain2.setWormholeSibling(
chainId1, bytes32(uint256(uint160(address(wormholeEndpointChain1))))
);
}

function test_relayerEndpointAuth() public {
vm.recordLogs();
vm.chainId(chainId1);

Expand Down Expand Up @@ -534,30 +535,8 @@ contract TestRelayerEndToEndManual is
vm.stopPrank();
vm.chainId(chainId2);

// Caller is not proper who to receive messages from
bytes[] memory a;
wormholeEndpointChain2.setWormholeSibling(chainId1, bytes32(uint256(uint160(address(0x1)))));
vm.startPrank(relayer);
vm.expectRevert(
abi.encodeWithSelector(
IWormholeEndpoint.InvalidWormholeSibling.selector,
chainId1,
address(wormholeEndpointChain1)
)
);
wormholeEndpointChain2.receiveWormholeMessages(
vaa.payload,
a,
bytes32(uint256(uint160(address(wormholeEndpointChain1)))),
vaa.emitterChainId,
vaa.hash
);
vm.stopPrank();

// Bad manager sibling calling
wormholeEndpointChain2.setWormholeSibling(
chainId1, bytes32(uint256(uint160(address(wormholeEndpointChain1))))
);
managerChain2.setSibling(chainId1, bytes32(uint256(uint160(address(0x1)))));
vm.startPrank(relayer);
vm.expectRevert(); // bad manager sibling
Expand Down Expand Up @@ -624,4 +603,72 @@ contract TestRelayerEndToEndManual is
vaa.hash // Hash of the VAA being used
);
}

function test_relayerWithInvalidWHEndpoint() public {
// Set up dodgy wormhole endpoint siblings
wormholeEndpointChain2.setWormholeSibling(chainId1, bytes32(uint256(uint160(address(0x1)))));
wormholeEndpointChain1.setWormholeSibling(
chainId2, bytes32(uint256(uint160(address(wormholeEndpointChain2))))
);

vm.recordLogs();
vm.chainId(chainId1);

// Setting up the transfer
DummyToken token1 = DummyToken(managerChain1.token());

uint8 decimals = token1.decimals();
uint256 sendingAmount = 5 * 10 ** decimals;
token1.mintDummy(address(userA), 5 * 10 ** decimals);
vm.startPrank(userA);
token1.approve(address(managerChain1), sendingAmount);

// Send token through the relayer
{
vm.deal(userA, 1 ether);
managerChain1.transfer{
value: wormholeEndpointChain1.quoteDeliveryPrice(
chainId2, buildEndpointInstruction(false)
)
}(
sendingAmount,
chainId2,
bytes32(uint256(uint160(userB))),
false,
encodeEndpointInstruction(false)
);
}

// Get the messages from the logs for the sender
vm.chainId(chainId2);
Vm.Log[] memory entries = guardian.fetchWormholeMessageFromLog(vm.getRecordedLogs());
bytes[] memory encodedVMs = new bytes[](entries.length);
for (uint256 i = 0; i < encodedVMs.length; i++) {
encodedVMs[i] = guardian.fetchSignedMessageFromLogs(entries[i], chainId1);
}

IWormhole.VM memory vaa = wormhole.parseVM(encodedVMs[0]);

vm.stopPrank();
vm.chainId(chainId2);

// Caller is not proper who to receive messages from
bytes[] memory a;
vm.startPrank(relayer);
vm.expectRevert(
abi.encodeWithSelector(
IWormholeEndpoint.InvalidWormholeSibling.selector,
chainId1,
address(wormholeEndpointChain1)
)
);
wormholeEndpointChain2.receiveWormholeMessages(
vaa.payload,
a,
bytes32(uint256(uint160(address(wormholeEndpointChain1)))),
vaa.emitterChainId,
vaa.hash
);
vm.stopPrank();
}
}

0 comments on commit d0b1233

Please sign in to comment.