Skip to content

Commit

Permalink
fix: couple swapper addresses to the chain ID (Finding 5) (#40)
Browse files Browse the repository at this point in the history
* fix: couple swapper addresses to the chain ID

`<T>Swapper` constructors now require the current chain ID as their second parameter, which they `assert()` as correct because the `<T>SwapperDeployer` always sends the correct ID. Despite the redundant behaviour, the swapper address is now coupled to the particular chain because the ID is used in predicting the CREATE2 address.

Note: this still requires testing of the unhappy path.

* test: changing chain IDs results in different swapper addresses, all else equal
  • Loading branch information
ARR4N authored Aug 11, 2024
1 parent 4ece5be commit cbbade1
Show file tree
Hide file tree
Showing 9 changed files with 84 additions and 24 deletions.
21 changes: 14 additions & 7 deletions src/SWAP2.sol
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,26 @@ abstract contract SWAP2ProposerBase is
MultiERC721ForERC20SwapperProposer
{}

/// @notice A standalone SWAP2 proposer for an immutable deployer address.
/// @notice A standalone SWAP2 proposer for an immutable deployer address and chain ID.
contract SWAP2Proposer is SWAP2ProposerBase {
/// @notice The SWAP2Deployer for which this contract proposes swaps.
address public immutable deployer;

/// @param deployer_ Address of the SWAP2Deployer for which this contract proposes swaps.
constructor(address deployer_) {
/// @notice The chain on which proposed swaps will be executed.
uint256 public immutable chainId;

/**
* @param deployer_ Address of the SWAP2Deployer for which this contract proposes swaps.
* @param chainId_ Chain on which proposed swaps will be executed.
*/
constructor(address deployer_, uint256 chainId_) {
deployer = deployer_;
chainId = chainId_;
}

/// @dev The immutable `deployer` is the swapper deployer for all types.
function _swapperDeployer() internal view override returns (address) {
return deployer;
function _swapperDeployer() internal view override returns (address, uint256) {
return (deployer, chainId);
}
}

Expand All @@ -123,7 +130,7 @@ contract SWAP2 is SWAP2Deployer, SWAP2ProposerBase {
{}

/// @dev The current contract is the swapper deployer for all types.
function _swapperDeployer() internal view override returns (address) {
return address(this);
function _swapperDeployer() internal view override returns (address, uint256) {
return (address(this), _currentChainId());
}
}
6 changes: 6 additions & 0 deletions src/SwapperDeployerBase.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
pragma solidity 0.8.25;

import {IEscrow} from "./Escrow.sol";
import {currentChainId} from "./TypesAndConstants.sol";

/// @dev Abstract base contract for all <T>SwapperDeployer implementations.
abstract contract SwapperDeployerBase {
Expand All @@ -17,4 +18,9 @@ abstract contract SwapperDeployerBase {
* buyer.
*/
function _escrow() internal view virtual returns (IEscrow);

/// @return The current chain ID.
function _currentChainId() internal view returns (uint256) {
return currentChainId();
}
}
4 changes: 2 additions & 2 deletions src/SwapperProposerBase.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ pragma solidity 0.8.25;

/// @dev Abstract base contract for all <T>SwapperProposer implementations.
abstract contract SwapperProposerBase {
/// @dev Returns the address of the factory contract that deploys swappers, regardless of swap type.
function _swapperDeployer() internal view virtual returns (address);
/// @dev Returns the address and chain ID of the factory contract that deploys swappers, regardless of swap type.
function _swapperDeployer() internal view virtual returns (address, uint256 chainId);
}
2 changes: 1 addition & 1 deletion src/TMPL/ForERC20Swapper.tmpl.sol
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ import {TMPLSwapperBase} from "./TMPLSwapperBase.tmpl.sol";

/// @notice Executes the TMPLSwap received in the constructor.
contract TMPLSwapper is TMPLSwapperBase {
constructor(TMPLSwap memory swap) TMPLSwapperBase(swap) {}
constructor(TMPLSwap memory swap, uint256 currentChainId) TMPLSwapperBase(swap, currentChainId) {}
}
2 changes: 1 addition & 1 deletion src/TMPL/ForNativeSwapper.tmpl.sol
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ import {TMPLSwapperBase} from "./TMPLSwapperBase.tmpl.sol";

/// @notice Executes the TMPLSwap received in the constructor.
contract TMPLSwapper is TMPLSwapperBase {
constructor(TMPLSwap memory swap) payable TMPLSwapperBase(swap) {}
constructor(TMPLSwap memory swap, uint256 currentChainId) payable TMPLSwapperBase(swap, currentChainId) {}
}
10 changes: 8 additions & 2 deletions src/TMPL/TMPLSwapperBase.tmpl.sol
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ import {
CANCEL,
FILLED_ARTIFACT,
CANCELLED_ARTIFACT,
ExcessPlatformFee
ExcessPlatformFee,
currentChainId
} from "../TypesAndConstants.sol";

import {Math} from "@openzeppelin/contracts/utils/math/Math.sol";
Expand All @@ -29,7 +30,12 @@ contract TMPLSwapperBase is SwapperBase {
using ActionMessageLib for Message;
using ConsiderationLib for *;

constructor(TMPLSwap memory swap) {
constructor(TMPLSwap memory swap, uint256 currentChainId_) {
// The TMPLSwapperDeployer always uses its current chain ID, so this is an invariant, and hence the use of
// assert. Inclusion of the `currentChainId_` argument is purely to couple the CREATE2 address of this contract
// to the specific chain.
assert(currentChainId() == currentChainId_);

Message message = IETHome(msg.sender).etMessage();
Action action = message.action();

Expand Down
31 changes: 20 additions & 11 deletions src/TMPL/TMPLSwapperDeployer.tmpl.sol
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@ import {OnlyPartyCanCancel, ActionMessageLib, CANCEL_MSG, ISwapperEvents} from "

/// @dev Predictor of TMPLSwapper contract addresses.
contract TMPLSwapperPredictor {
function _swapper(TMPLSwap calldata swap, bytes32 salt, address deployer) internal pure returns (address) {
return ETPredictor.deploymentAddress(_bytecode(swap), salt, deployer);
function _swapper(TMPLSwap calldata swap, bytes32 salt, address deployer, uint256 chainId)
internal
pure
returns (address)
{
return ETPredictor.deploymentAddress(_bytecode(swap, chainId), salt, deployer);
}

function _bytecode(TMPLSwap calldata swap) internal pure returns (bytes memory) {
return abi.encodePacked(type(TMPLSwapper).creationCode, abi.encode(swap));
function _bytecode(TMPLSwap calldata swap, uint256 chainId) internal pure returns (bytes memory) {
return abi.encodePacked(type(TMPLSwapper).creationCode, abi.encode(swap, chainId));
}
}

Expand All @@ -29,8 +33,12 @@ abstract contract TMPLSwapperDeployer is TMPLSwapperPredictor, ETDeployer, Swapp
/// @dev Execute the `TMPLSwap`, transferring all assets between the parties.
function fillTMPL(TMPLSwap calldata swap, bytes32 salt) external payable returns (address) {
(address payable feeRecipient, uint16 basisPoints) = _platformFeeConfig();
address a =
_deploy(_bytecode(swap), msg.value, salt, ActionMessageLib.fillWithFeeConfig(feeRecipient, basisPoints));
address a = _deploy(
_bytecode(swap, _currentChainId()),
msg.value,
salt,
ActionMessageLib.fillWithFeeConfig(feeRecipient, basisPoints)
);
emit Filled(a);
return a;
}
Expand All @@ -40,7 +48,7 @@ abstract contract TMPLSwapperDeployer is TMPLSwapperPredictor, ETDeployer, Swapp
if (msg.sender != swap.parties.seller && msg.sender != swap.parties.buyer) {
revert OnlyPartyCanCancel();
}
address a = _deploy(_bytecode(swap), 0, salt, ActionMessageLib.cancelWithEscrow(_escrow()));
address a = _deploy(_bytecode(swap, _currentChainId()), 0, salt, ActionMessageLib.cancelWithEscrow(_escrow()));
emit Cancelled(a);
return a;
}
Expand All @@ -50,7 +58,7 @@ abstract contract TMPLSwapperDeployer is TMPLSwapperPredictor, ETDeployer, Swapp
* @dev Important: see `TMPLSwapperProposer.propose()` as an alternative.
*/
function swapperOfTMPL(TMPLSwap calldata swap, bytes32 salt) external view returns (address) {
return _swapper(swap, salt, address(this));
return _swapper(swap, salt, address(this), _currentChainId());
}
}

Expand All @@ -77,12 +85,13 @@ abstract contract TMPLSwapperProposer is TMPLSwapperPredictor, SwapperProposerBa
// malleability of block hashes is too low and their rate of production too slow for an attack based on
// discarding undesirable salts.
bytes32 salt = blockhash(block.number - 1);
address swapper_ = _swapper(swap, salt, _swapperDeployer(swap));
(address deployer, uint256 chainId) = _swapperDeployer();
address swapper_ = _swapper(swap, salt, deployer, chainId);
emit Proposal(swapper_, swap.parties.seller, swap.parties.buyer, swap, salt);
return (salt, swapper_);
}

function _swapperDeployer(TMPLSwap calldata) internal virtual returns (address) {
function _swapperDeployer(TMPLSwap calldata) internal virtual returns (address, uint256 chainId) {
return _swapperDeployer();
}
}
Expand All @@ -94,5 +103,5 @@ abstract contract TMPLSwapperProposer is TMPLSwapperPredictor, SwapperProposerBa
function _enforceTMPLSwapperCtorSig() {
assert(false);
TMPLSwap memory s;
new TMPLSwapper(s);
new TMPLSwapper(s, uint256(0));
}
9 changes: 9 additions & 0 deletions src/TypesAndConstants.sol
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,15 @@ function swapStatus(address swapper) view returns (SwapStatus) {
return SwapStatus.Invalid;
}

// The current chain ID is effectively a constant so it belongs in this file.

/// @return id The current chain ID.
function currentChainId() view returns (uint256 id) {
assembly ("memory-safe") {
id := chainid()
}
}

/**
* ======
*
Expand Down
23 changes: 23 additions & 0 deletions test/ERC721ForXTest.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,29 @@ abstract contract ERC721ForXTest is SwapperTestBase {
escrow.withdraw(test.buyer());
}

function testChainIdCoupling(ERC721TestCase memory t, uint64 chainId0, uint64 chainId1)
external
// While the specific assumptions are irrelevant, general assumptions about `t` must be made for it to be valid
// otherwise we'll get out-of-bounds errors.
assumeValidTest(t.base, Assumptions({sufficientPayment: true, validPlatformFee: true, approving: true}))
{
vm.chainId(chainId0);
address swapperOnChain0 = _swapper(t);

vm.chainId(chainId1);
address swapperOnChain1 = _swapper(t);

emit log_named_uint("chain ID 0", chainId0);
emit log_named_uint("chain ID 1", chainId1);
emit log_named_address("swapper on chain 0", swapperOnChain0);
emit log_named_address("swapper on chain 1", swapperOnChain1);
assertEq(
chainId0 == chainId1,
swapperOnChain0 == swapperOnChain1,
"different chain IDs <=> different swapper addresses"
);
}

function testGas() external {
Disbursement[5] memory thirdParty;
uint128 total = 1 ether;
Expand Down

0 comments on commit cbbade1

Please sign in to comment.