From 52ae8efbc03bbaf3395d4b3f2750b0a9eaf7dcc6 Mon Sep 17 00:00:00 2001 From: cairo Date: Sat, 4 Jan 2025 20:12:21 +0100 Subject: [PATCH] Improve custom curve tests --- src/base/BaseCustomCurve.sol | 22 +- test/base/BaseCustomCurve.t.sol | 351 ++++++++++++++++++++++++++++- test/mocks/BaseCustomCurveMock.sol | 8 +- 3 files changed, 366 insertions(+), 15 deletions(-) diff --git a/src/base/BaseCustomCurve.sol b/src/base/BaseCustomCurve.sol index 1c96f91..7ba83e2 100644 --- a/src/base/BaseCustomCurve.sol +++ b/src/base/BaseCustomCurve.sol @@ -12,6 +12,7 @@ import {SafeCast} from "v4-core/src/libraries/SafeCast.sol"; import {CurrencySettler} from "v4-core/test/utils/CurrencySettler.sol"; import {BeforeSwapDelta, toBeforeSwapDelta} from "v4-core/src/types/BeforeSwapDelta.sol"; import {BalanceDelta, toBalanceDelta} from "v4-core/src/types/BalanceDelta.sol"; +import {console2 as console} from "forge-std/console2.sol"; /** * @dev Base implementation for custom curves. @@ -75,7 +76,7 @@ abstract contract BaseCustomCurve is BaseCustomAccounting { returns (bytes memory, uint256) { (uint256 amount0, uint256 amount1, uint256 liquidity) = _calculateIn(params); - return (abi.encode(amount0, amount1), liquidity); + return (abi.encode(amount0.toInt128(), amount1.toInt128()), liquidity); } function _getRemoveLiquidity(RemoveLiquidityParams memory params) @@ -85,7 +86,7 @@ abstract contract BaseCustomCurve is BaseCustomAccounting { returns (bytes memory, uint256) { (uint256 amount0, uint256 amount1, uint256 liquidity) = _calculateOut(params); - return (abi.encode(amount0, amount1), liquidity); + return (abi.encode(-amount0.toInt128(), -amount1.toInt128()), liquidity); } function _beforeSwap(address, PoolKey calldata key, IPoolManager.SwapParams calldata params, bytes calldata) @@ -141,6 +142,9 @@ abstract contract BaseCustomCurve is BaseCustomAccounting { function _unlockCallback(bytes calldata rawData) internal virtual override returns (bytes memory) { CallbackDataCustom memory data = abi.decode(rawData, (CallbackDataCustom)); + console.log(data.amount0); + console.log(data.amount1); + int128 amount0; int128 amount1; @@ -149,27 +153,27 @@ abstract contract BaseCustomCurve is BaseCustomAccounting { // When adding liquidity, mint ERC-6909 claim tokens and transfer tokens from receiver to pool. if (data.amount0 < 0) { - poolKey.currency0.settle(poolManager, address(this), uint256(int256(data.amount0)), true); - poolKey.currency0.take(poolManager, data.sender, uint256(int256(data.amount0)), false); + poolKey.currency0.settle(poolManager, address(this), uint256(int256(-data.amount0)), true); + poolKey.currency0.take(poolManager, data.sender, uint256(int256(-data.amount0)), false); amount0 = data.amount0; } if (data.amount1 < 0) { - poolKey.currency1.settle(poolManager, address(this), uint256(int256(data.amount1)), true); - poolKey.currency1.take(poolManager, data.sender, uint256(int256(data.amount1)), false); + poolKey.currency1.settle(poolManager, address(this), uint256(int256(-data.amount1)), true); + poolKey.currency1.take(poolManager, data.sender, uint256(int256(-data.amount1)), false); amount1 = data.amount1; } if (data.amount0 > 0) { - poolKey.currency0.settle(poolManager, data.sender, uint256(int256(-data.amount0)), false); + poolKey.currency0.settle(poolManager, data.sender, uint256(int256(data.amount0)), false); poolKey.currency0.take(poolManager, address(this), uint256(int256(data.amount0)), true); amount0 = -data.amount0; } if (data.amount1 > 0) { - poolKey.currency1.settle(poolManager, data.sender, uint256(int256(-data.amount1)), false); + poolKey.currency1.settle(poolManager, data.sender, uint256(int256(data.amount1)), false); poolKey.currency1.take(poolManager, address(this), uint256(int256(data.amount1)), true); - amount1 = data.amount1; + amount1 = -data.amount1; } return abi.encode(toBalanceDelta(amount0, amount1)); diff --git a/test/base/BaseCustomCurve.t.sol b/test/base/BaseCustomCurve.t.sol index 04c3a91..8b3581e 100644 --- a/test/base/BaseCustomCurve.t.sol +++ b/test/base/BaseCustomCurve.t.sol @@ -89,11 +89,358 @@ contract BaseCustomCurveTest is Test, Deployers { uint256 liquidityTokenBal = hook.balanceOf(address(this)); - assertEq(manager.getLiquidity(id), liquidityTokenBal); - assertEq(key.currency0.balanceOf(address(this)), prevBalance0 - 10 ether); assertEq(key.currency1.balanceOf(address(this)), prevBalance1 - 10 ether); assertEq(liquidityTokenBal, 10 ether); } + + function test_addLiquidity_fuzz_succeeds(uint112 amount) public { + vm.assume(amount > 0); + + hook.addLiquidity( + BaseCustomAccounting.AddLiquidityParams( + amount, amount, 0, 0, address(this), MAX_DEADLINE, MIN_TICK, MAX_TICK + ) + ); + + uint256 liquidityTokenBal = hook.balanceOf(address(this)); + assertEq(liquidityTokenBal, amount); + } + + function test_addLiquidity_swapThenAdd_succeeds() public { + uint256 prevBalance0 = key.currency0.balanceOf(address(this)); + uint256 prevBalance1 = key.currency1.balanceOf(address(this)); + + hook.addLiquidity( + BaseCustomAccounting.AddLiquidityParams( + 10 ether, 10 ether, 9 ether, 9 ether, address(this), MAX_DEADLINE, MIN_TICK, MAX_TICK + ) + ); + + uint256 liquidityTokenBal = hook.balanceOf(address(this)); + + assertEq(liquidityTokenBal, 10 ether); + assertEq(key.currency0.balanceOf(address(this)), prevBalance0 - 10 ether); + assertEq(key.currency1.balanceOf(address(this)), prevBalance1 - 10 ether); + + vm.expectEmit(true, true, true, true, address(manager)); + emit Swap(id, address(swapRouter), 0, 0, 79228162514264337593543950336, 0, 0, 0); + + IPoolManager.SwapParams memory params = + IPoolManager.SwapParams({zeroForOne: true, amountSpecified: -1 ether, sqrtPriceLimitX96: SQRT_PRICE_1_2}); + PoolSwapTest.TestSettings memory settings = + PoolSwapTest.TestSettings({takeClaims: false, settleUsingBurn: false}); + + swapRouter.swap(key, params, settings, ZERO_BYTES); + + assertEq(key.currency0.balanceOf(address(this)), prevBalance0 - 10 ether - 1 ether); + assertEq(key.currency1.balanceOf(address(this)), prevBalance1 - 9 ether); + + hook.addLiquidity( + BaseCustomAccounting.AddLiquidityParams( + 5 ether, 5 ether, 4 ether, 4 ether, address(this), MAX_DEADLINE, MIN_TICK, MAX_TICK + ) + ); + + liquidityTokenBal = hook.balanceOf(address(this)); + + assertEq(liquidityTokenBal, 15 ether); + assertEq(liquidityTokenBal, 15 ether); + } + + function test_addLiquidity_expired_revert() public { + vm.expectRevert(BaseCustomAccounting.ExpiredPastDeadline.selector); + hook.addLiquidity( + BaseCustomAccounting.AddLiquidityParams(0, 0, 0, 0, address(this), block.timestamp - 1, MIN_TICK, MAX_TICK) + ); + } + + function test_addLiquidity_tooMuchSlippage_reverts() public { + vm.expectRevert(BaseCustomAccounting.TooMuchSlippage.selector); + hook.addLiquidity( + BaseCustomAccounting.AddLiquidityParams( + 10 ether, 10 ether, 100000 ether, 100000 ether, address(this), MAX_DEADLINE, MIN_TICK, MAX_TICK + ) + ); + } + + function test_swap_twoSwaps_succeeds() public { + hook.addLiquidity( + BaseCustomAccounting.AddLiquidityParams( + 2 ether, 2 ether, 0, 0, address(this), MAX_DEADLINE, MIN_TICK, MAX_TICK + ) + ); + + IPoolManager.SwapParams memory params = + IPoolManager.SwapParams({zeroForOne: true, amountSpecified: 1 ether, sqrtPriceLimitX96: MIN_PRICE_LIMIT}); + PoolSwapTest.TestSettings memory settings = + PoolSwapTest.TestSettings({takeClaims: false, settleUsingBurn: false}); + + swapRouter.swap(key, params, settings, ZERO_BYTES); + swapRouter.swap(key, params, settings, ZERO_BYTES); + } + + function test_removeLiquidity_initialRemove_succeeds() public { + hook.addLiquidity( + BaseCustomAccounting.AddLiquidityParams( + 100 ether, 100 ether, 99 ether, 99 ether, address(this), MAX_DEADLINE, MIN_TICK, MAX_TICK + ) + ); + + uint256 prevBalance0 = key.currency0.balanceOf(address(this)); + uint256 prevBalance1 = key.currency1.balanceOf(address(this)); + + hook.approve(address(hook), type(uint256).max); + + BaseCustomAccounting.RemoveLiquidityParams memory removeLiquidityParams = + BaseCustomAccounting.RemoveLiquidityParams(1 ether, MAX_DEADLINE, MIN_TICK, MAX_TICK); + + hook.removeLiquidity(removeLiquidityParams); + + uint256 liquidityTokenBal = hook.balanceOf(address(this)); + assertEq(liquidityTokenBal, 99 ether); + assertEq(key.currency0.balanceOf(address(this)), prevBalance0 + 0.5 ether); + assertEq(key.currency1.balanceOf(address(this)), prevBalance1 + 0.5 ether); + } + + function test_removeLiquidity_fuzz_succeeds(uint256 amount) public { + vm.assume(amount > 0); + + if (amount > hook.balanceOf(address(this))) { + vm.expectRevert(); + hook.removeLiquidity(BaseCustomAccounting.RemoveLiquidityParams(amount, MAX_DEADLINE, MIN_TICK, MAX_TICK)); + } else { + uint256 prevLiquidityTokenBal = hook.balanceOf(address(this)); + hook.removeLiquidity(BaseCustomAccounting.RemoveLiquidityParams(amount, MAX_DEADLINE, MIN_TICK, MAX_TICK)); + + uint256 liquidityTokenBal = hook.balanceOf(address(this)); + + assertEq(prevLiquidityTokenBal - liquidityTokenBal, amount); + assertEq(manager.getLiquidity(id), liquidityTokenBal); + } + } + + function test_removeLiquidity_noLiquidity_reverts() public { + vm.expectRevert(); + hook.removeLiquidity( + BaseCustomAccounting.RemoveLiquidityParams(1000000 ether, MAX_DEADLINE, MIN_TICK, MAX_TICK) + ); + } + + function test_removeLiquidity_partial_succeeds() public { + uint256 prevBalance0 = key.currency0.balanceOf(address(this)); + uint256 prevBalance1 = key.currency1.balanceOf(address(this)); + + hook.addLiquidity( + BaseCustomAccounting.AddLiquidityParams( + 10 ether, 10 ether, 9 ether, 9 ether, address(this), MAX_DEADLINE, MIN_TICK, MAX_TICK + ) + ); + + assertEq(hook.balanceOf(address(this)), 10 ether); + assertEq(key.currency0.balanceOfSelf(), prevBalance0 - 10 ether); + assertEq(key.currency1.balanceOfSelf(), prevBalance1 - 10 ether); + + hook.removeLiquidity(BaseCustomAccounting.RemoveLiquidityParams(5 ether, MAX_DEADLINE, MIN_TICK, MAX_TICK)); + + uint256 liquidityTokenBal = hook.balanceOf(address(this)); + assertEq(liquidityTokenBal, 5 ether); + assertEq(key.currency0.balanceOf(address(this)), prevBalance0 - 7.5 ether); + assertEq(key.currency1.balanceOf(address(this)), prevBalance1 - 7.5 ether); + } + + function test_removeLiquidity_diffRatios_succeeds() public { + uint256 prevBalance0 = key.currency0.balanceOf(address(this)); + uint256 prevBalance1 = key.currency1.balanceOf(address(this)); + + hook.addLiquidity( + BaseCustomAccounting.AddLiquidityParams( + 10 ether, 10 ether, 9 ether, 9 ether, address(this), MAX_DEADLINE, MIN_TICK, MAX_TICK + ) + ); + + assertEq(key.currency0.balanceOf(address(this)), prevBalance0 - 10 ether); + assertEq(key.currency1.balanceOf(address(this)), prevBalance1 - 10 ether); + assertEq(hook.balanceOf(address(this)), 10 ether); + + hook.addLiquidity( + BaseCustomAccounting.AddLiquidityParams( + 5 ether, 2.5 ether, 2 ether, 2 ether, address(this), MAX_DEADLINE, MIN_TICK, MAX_TICK + ) + ); + + assertEq(key.currency0.balanceOf(address(this)), prevBalance0 - 15 ether); + assertEq(key.currency1.balanceOf(address(this)), prevBalance1 - 12.5 ether); + assertEq(hook.balanceOf(address(this)), 13.75 ether); + + hook.removeLiquidity(BaseCustomAccounting.RemoveLiquidityParams(5 ether, MAX_DEADLINE, MIN_TICK, MAX_TICK)); + + uint256 liquidityTokenBal = hook.balanceOf(address(this)); + assertEq(liquidityTokenBal, 8.75 ether); + assertEq(key.currency0.balanceOf(address(this)), prevBalance0 - 12.5 ether); + assertEq(key.currency1.balanceOf(address(this)), prevBalance1 - 10 ether); + } + + function test_removeLiquidity_allFuzz_succeeds(uint112 amount) public { + vm.assume(amount > 0); + + hook.addLiquidity( + BaseCustomAccounting.AddLiquidityParams( + amount, amount, 0, 0, address(this), MAX_DEADLINE, MIN_TICK, MAX_TICK + ) + ); + + uint256 liquidityTokenBal = hook.balanceOf(address(this)); + + hook.removeLiquidity( + BaseCustomAccounting.RemoveLiquidityParams(liquidityTokenBal, MAX_DEADLINE, MIN_TICK, MAX_TICK) + ); + + assertEq(manager.getLiquidity(id), 0); + } + + function test_removeLiquidity_multiple_succeeds() public { + // Mint tokens for dummy addresses + deal(Currency.unwrap(currency0), address(1), 2 ** 128); + deal(Currency.unwrap(currency1), address(1), 2 ** 128); + deal(Currency.unwrap(currency0), address(2), 2 ** 128); + deal(Currency.unwrap(currency1), address(2), 2 ** 128); + + // Approve the hook + vm.prank(address(1)); + ERC20(Currency.unwrap(currency0)).approve(address(hook), type(uint256).max); + vm.prank(address(1)); + ERC20(Currency.unwrap(currency1)).approve(address(hook), type(uint256).max); + + vm.prank(address(2)); + ERC20(Currency.unwrap(currency0)).approve(address(hook), type(uint256).max); + vm.prank(address(2)); + ERC20(Currency.unwrap(currency1)).approve(address(hook), type(uint256).max); + + // address(1) adds liquidity + vm.prank(address(1)); + hook.addLiquidity( + BaseCustomAccounting.AddLiquidityParams( + 100 ether, 100 ether, 99 ether, 99 ether, address(this), MAX_DEADLINE, MIN_TICK, MAX_TICK + ) + ); + + // address(2) adds liquidity + vm.prank(address(2)); + hook.addLiquidity( + BaseCustomAccounting.AddLiquidityParams( + 100 ether, 100 ether, 99 ether, 99 ether, address(this), MAX_DEADLINE, MIN_TICK, MAX_TICK + ) + ); + + IPoolManager.SwapParams memory params = + IPoolManager.SwapParams({zeroForOne: true, amountSpecified: 100 ether, sqrtPriceLimitX96: SQRT_PRICE_1_4}); + + PoolSwapTest.TestSettings memory testSettings = + PoolSwapTest.TestSettings({takeClaims: false, settleUsingBurn: false}); + + swapRouter.swap(key, params, testSettings, ZERO_BYTES); + + // Test contract removes liquidity, succeeds + hook.removeLiquidity( + BaseCustomAccounting.RemoveLiquidityParams(hook.balanceOf(address(this)), MAX_DEADLINE, MIN_TICK, MAX_TICK) + ); + + // PoolManager does not have any liquidity left over + assertEq(manager.getLiquidity(id), 0); + } + + function test_removeLiquidity_swapRemoveAllFuzz_succeeds(uint112 amount) public { + vm.assume(amount > 4); + + hook.addLiquidity( + BaseCustomAccounting.AddLiquidityParams( + amount, amount, 0, 0, address(this), MAX_DEADLINE, MIN_TICK, MAX_TICK + ) + ); + + IPoolManager.SwapParams memory params = IPoolManager.SwapParams({ + zeroForOne: true, + amountSpecified: (FullMath.mulDiv(amount, 1, 4)).toInt256(), + sqrtPriceLimitX96: SQRT_PRICE_1_4 + }); + + PoolSwapTest.TestSettings memory testSettings = + PoolSwapTest.TestSettings({takeClaims: false, settleUsingBurn: false}); + + swapRouter.swap(key, params, testSettings, ZERO_BYTES); + + uint256 liquidityTokenBal = hook.balanceOf(address(this)); + + hook.removeLiquidity( + BaseCustomAccounting.RemoveLiquidityParams(liquidityTokenBal, MAX_DEADLINE, MIN_TICK, MAX_TICK) + ); + + assertEq(manager.getLiquidity(id), 0); + } + + function test_removeLiquidity_notInitialized_reverts() public { + BaseCustomCurveMock uninitializedHook = BaseCustomCurveMock(0x1000000000000000000000000000000000002088); + deployCodeTo( + "test/mocks/BaseCustomCurveMock.sol:BaseCustomCurveMock", abi.encode(manager), address(uninitializedHook) + ); + + vm.expectRevert(BaseCustomAccounting.PoolNotInitialized.selector); + uninitializedHook.removeLiquidity( + BaseCustomAccounting.RemoveLiquidityParams(1 ether, MAX_DEADLINE, MIN_TICK, MAX_TICK) + ); + } + + function test_addLiquidity_notInitialized_reverts() public { + BaseCustomCurveMock uninitializedHook = BaseCustomCurveMock(0x1000000000000000000000000000000000002088); + deployCodeTo( + "test/mocks/BaseCustomCurveMock.sol:BaseCustomCurveMock", abi.encode(manager), address(uninitializedHook) + ); + + vm.expectRevert(BaseCustomAccounting.PoolNotInitialized.selector); + uninitializedHook.addLiquidity( + BaseCustomAccounting.AddLiquidityParams( + 1 ether, 1 ether, 0, 0, address(this), MAX_DEADLINE, MIN_TICK, MAX_TICK + ) + ); + } + + function test_swap_addThenRemove_succeeds() public { + uint256 prevBalance0 = key.currency0.balanceOf(address(this)); + uint256 prevBalance1 = key.currency1.balanceOf(address(this)); + + hook.addLiquidity( + BaseCustomAccounting.AddLiquidityParams(0, 1 ether, 0, 0, address(this), MAX_DEADLINE, MIN_TICK, MAX_TICK) + ); + + uint256 liquidityTokenBal = hook.balanceOf(address(this)); + + assertEq(liquidityTokenBal, 0.5 ether); + assertEq(key.currency0.balanceOf(address(this)), prevBalance0); + assertEq(key.currency1.balanceOf(address(this)), prevBalance1 - 1 ether); + + vm.expectEmit(true, true, true, true, address(manager)); + emit Swap(id, address(swapRouter), 0, 0, 79228162514264337593543950336, 0, 0, 0); + + IPoolManager.SwapParams memory params = + IPoolManager.SwapParams({zeroForOne: true, amountSpecified: -0.5 ether, sqrtPriceLimitX96: SQRT_PRICE_1_2}); + PoolSwapTest.TestSettings memory settings = + PoolSwapTest.TestSettings({takeClaims: false, settleUsingBurn: false}); + + swapRouter.swap(key, params, settings, ZERO_BYTES); + + assertEq(key.currency0.balanceOf(address(this)), prevBalance0 - 0.5 ether); + assertEq(key.currency1.balanceOf(address(this)), prevBalance1 - 0.5 ether); + + hook.removeLiquidity( + BaseCustomAccounting.RemoveLiquidityParams(hook.balanceOf(address(this)), MAX_DEADLINE, MIN_TICK, MAX_TICK) + ); + + liquidityTokenBal = hook.balanceOf(address(this)); + + assertEq(liquidityTokenBal, 0); + assertEq(key.currency0.balanceOf(address(this)), prevBalance0 - 0.25 ether); + assertEq(key.currency1.balanceOf(address(this)), prevBalance1 - 0.25 ether); + } } diff --git a/test/mocks/BaseCustomCurveMock.sol b/test/mocks/BaseCustomCurveMock.sol index 4900b06..aaf0f1a 100644 --- a/test/mocks/BaseCustomCurveMock.sol +++ b/test/mocks/BaseCustomCurveMock.sol @@ -39,7 +39,7 @@ contract BaseCustomCurveMock is BaseCustomCurve { { amount0 = params.amount0Desired; amount1 = params.amount1Desired; - liquidity = amount0 + amount1; + liquidity = (amount0 + amount1) / 2; } function _calculateOut(RemoveLiquidityParams memory params) @@ -47,9 +47,9 @@ contract BaseCustomCurveMock is BaseCustomCurve { override returns (uint256 amount0, uint256 amount1, uint256 liquidity) { - amount0 = liquidity / 2; - amount1 = liquidity / 2; - liquidity = amount0 + amount1; + amount0 = params.liquidity / 2; + amount1 = params.liquidity / 2; + liquidity = params.liquidity; } // Exclude from coverage report