Skip to content

Commit

Permalink
Refactor Storage Access to be 4337 Compatible (#132)
Browse files Browse the repository at this point in the history
* Refactor Storage Access to be 4337 Compatible

* Make all storage ERC-4337 compatible

* Add Comment Explaining Mapping Setup
  • Loading branch information
nlordell authored Nov 9, 2023
1 parent 0f3ee32 commit 71dff1d
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 31 deletions.
4 changes: 2 additions & 2 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ If you are submitting a bug report, please:
Please send a [GitHub Pull Request to safe-core-protocol repository](https://github.com/safe-global/safe-core-protocol) with a clear description of the proposed changes. Each pull request should be associated with an issue and should be made against the `main` branch.

Branch naming convention:

- For a new feature, use `feature-<issue-number>-short-description`
- For a bug fix, use `fix-<issue-number>-short-description`

Expand All @@ -40,4 +40,4 @@ Steps to be taken before submitting a pull request to be considered for review:
- Make sure there are no linting errors

Thanks,
Safe team
Safe team
29 changes: 15 additions & 14 deletions contracts/SafeProtocolManager.sol
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ contract SafeProtocolManager is ISafeProtocolManager, RegistryManager, HooksMana

/**
* @notice Mapping of a mapping what stores information about plugins that are enabled per account.
* address (Account address) => address (module address) => EnabledPluginInfo
* address (module address) => address (account address) => EnabledPluginInfo
* @dev The key of the inner-most mapping is the account address, which is required for 4337-compatibility.
*/
mapping(address => mapping(address => PluginAccessInfo)) public enabledPlugins;
struct PluginAccessInfo {
Expand Down Expand Up @@ -180,8 +181,8 @@ contract SafeProtocolManager is ISafeProtocolManager, RegistryManager, HooksMana
if (!ISafeProtocolPlugin(plugin).supportsInterface(type(ISafeProtocolPlugin).interfaceId))
revert ContractDoesNotImplementValidInterfaceId(plugin);

PluginAccessInfo storage senderSentinelPlugin = enabledPlugins[msg.sender][SENTINEL_MODULES];
PluginAccessInfo storage senderPlugin = enabledPlugins[msg.sender][plugin];
PluginAccessInfo storage senderSentinelPlugin = enabledPlugins[SENTINEL_MODULES][msg.sender];
PluginAccessInfo storage senderPlugin = enabledPlugins[plugin][msg.sender];

if (senderPlugin.nextPluginPointer != address(0)) {
revert PluginAlreadyEnabled(msg.sender, plugin);
Expand All @@ -208,8 +209,8 @@ contract SafeProtocolManager is ISafeProtocolManager, RegistryManager, HooksMana
* @param plugin Plugin to be disabled
*/
function disablePlugin(address prevPlugin, address plugin) external noZeroOrSentinelPlugin(plugin) onlyAccount {
PluginAccessInfo storage prevPluginInfo = enabledPlugins[msg.sender][prevPlugin];
PluginAccessInfo storage pluginInfo = enabledPlugins[msg.sender][plugin];
PluginAccessInfo storage prevPluginInfo = enabledPlugins[prevPlugin][msg.sender];
PluginAccessInfo storage pluginInfo = enabledPlugins[plugin][msg.sender];

if (prevPluginInfo.nextPluginPointer != plugin) {
revert InvalidPrevPluginAddress(prevPlugin);
Expand All @@ -229,7 +230,7 @@ contract SafeProtocolManager is ISafeProtocolManager, RegistryManager, HooksMana
* @param plugin Address of a plugin
*/
function getPluginInfo(address account, address plugin) external view returns (PluginAccessInfo memory enabled) {
return enabledPlugins[account][plugin];
return enabledPlugins[plugin][account];
}

/**
Expand All @@ -239,7 +240,7 @@ contract SafeProtocolManager is ISafeProtocolManager, RegistryManager, HooksMana
* @return True if the plugin is enabled
*/
function isPluginEnabled(address account, address plugin) public view returns (bool) {
return SENTINEL_MODULES != plugin && enabledPlugins[account][plugin].nextPluginPointer != address(0);
return SENTINEL_MODULES != plugin && enabledPlugins[plugin][account].nextPluginPointer != address(0);
}

/**
Expand Down Expand Up @@ -268,10 +269,10 @@ contract SafeProtocolManager is ISafeProtocolManager, RegistryManager, HooksMana

// Populate return array
uint256 pluginCount = 0;
next = enabledPlugins[account][start].nextPluginPointer;
next = enabledPlugins[start][account].nextPluginPointer;
while (next != address(0) && next != SENTINEL_MODULES && pluginCount < pageSize) {
array[pluginCount] = next;
next = enabledPlugins[account][next].nextPluginPointer;
next = enabledPlugins[next][account].nextPluginPointer;
pluginCount++;
}

Expand All @@ -282,10 +283,10 @@ contract SafeProtocolManager is ISafeProtocolManager, RegistryManager, HooksMana

/**
Because of the argument validation, we can assume that the loop will always iterate over the valid plugin list values
and the `next` variable will either be an enabled plugin or a sentinel address (signalling the end).
and the `next` variable will either be an enabled plugin or a sentinel address (signalling the end).
If we haven't reached the end inside the loop, we need to set the next pointer to the last element of the plugins array
because the `next` variable (which is a plugin by itself) acting as a pointer to the start of the next page is neither
because the `next` variable (which is a plugin by itself) acting as a pointer to the start of the next page is neither
included to the current page, nor will it be included in the next one if you pass it as a start.
*/
if (next != SENTINEL_MODULES && pluginCount != 0) {
Expand Down Expand Up @@ -436,7 +437,7 @@ contract SafeProtocolManager is ISafeProtocolManager, RegistryManager, HooksMana
}

function checkOnlyEnabledPlugin(address account) private view {
if (enabledPlugins[account][msg.sender].nextPluginPointer == address(0)) {
if (enabledPlugins[msg.sender][account].nextPluginPointer == address(0)) {
revert PluginNotEnabled(msg.sender);
}
}
Expand All @@ -457,7 +458,7 @@ contract SafeProtocolManager is ISafeProtocolManager, RegistryManager, HooksMana
*/
function checkPermission(address account, uint8 permission) private view {
// For each action, Manager will read storage and call plugin's requiresPermissions().
uint8 givenPermissions = enabledPlugins[account][msg.sender].permissions;
uint8 givenPermissions = enabledPlugins[msg.sender][account].permissions;
uint8 requiresPermissions = ISafeProtocolPlugin(msg.sender).requiresPermissions();

if ((requiresPermissions & givenPermissions & permission) != permission) {
Expand Down
9 changes: 5 additions & 4 deletions contracts/SignatureValidatorManager.sol
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ contract SignatureValidatorManager is RegistryManager, ISafeProtocolFunctionHand

// Storage
/**
* @notice Mapping to account address => domain separator => signature validator contract
* @notice Mapping to domain separator => account address => signature validator contract
* @dev The key of the inner-most mapping is the account address, which is required for 4337-compatibility.
*/
mapping(address => mapping(bytes32 => address)) public signatureValidators;
mapping(bytes32 => mapping(address => address)) public signatureValidators;

/**
* @notice Mapping to account address => signature validator hooks contract
Expand Down Expand Up @@ -72,7 +73,7 @@ contract SignatureValidatorManager is RegistryManager, ISafeProtocolFunctionHand
if (!ISafeProtocolSignatureValidator(signatureValidator).supportsInterface(type(ISafeProtocolSignatureValidator).interfaceId))
revert ContractDoesNotImplementValidInterfaceId(signatureValidator);
}
signatureValidators[msg.sender][domainSeparator] = signatureValidator;
signatureValidators[domainSeparator][msg.sender] = signatureValidator;

emit SignatureValidatorChanged(msg.sender, domainSeparator, signatureValidator);
}
Expand Down Expand Up @@ -193,7 +194,7 @@ contract SignatureValidatorManager is RegistryManager, ISafeProtocolFunctionHand
revert InvalidMessageHash(messageHash);
}

address signatureValidator = signatureValidators[account][domainSeparator];
address signatureValidator = signatureValidators[domainSeparator][account];
if (signatureValidator == address(0)) {
revert SignatureValidatorNotSet(account);
}
Expand Down
12 changes: 7 additions & 5 deletions contracts/base/FunctionHandlerManager.sol
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ import {MODULE_TYPE_FUNCTION_HANDLER} from "../common/Constants.sol";
*/
abstract contract FunctionHandlerManager is RegistryManager {
// Storage
/** @dev Mapping that stores information about an account, function selector, and address of the account.
/**
* @notice Mapping that stores information about an account, function selector, and address of the account.
* @dev The key of the inner-most mapping is the account address, which is required for 4337-compatibility.
*/
mapping(address => mapping(bytes4 => address)) public functionHandlers;
mapping(bytes4 => mapping(address => address)) public functionHandlers;

// Events
event FunctionHandlerChanged(address indexed account, bytes4 indexed selector, address indexed functionHandler);
Expand All @@ -31,7 +33,7 @@ abstract contract FunctionHandlerManager is RegistryManager {
* @return functionHandler Address of the contract to be set as a function handler
*/
function getFunctionHandler(address account, bytes4 selector) external view returns (address functionHandler) {
functionHandler = functionHandlers[account][selector];
functionHandler = functionHandlers[selector][account];
}

/**
Expand All @@ -48,7 +50,7 @@ abstract contract FunctionHandlerManager is RegistryManager {
}

// No need to check if functionHandler implements expected interfaceId as check will be done when adding to registry.
functionHandlers[msg.sender][selector] = functionHandler;
functionHandlers[selector][msg.sender] = functionHandler;
emit FunctionHandlerChanged(msg.sender, selector, functionHandler);
}

Expand All @@ -63,7 +65,7 @@ abstract contract FunctionHandlerManager is RegistryManager {
address account = msg.sender;
bytes4 functionSelector = bytes4(msg.data);

address functionHandler = functionHandlers[account][functionSelector];
address functionHandler = functionHandlers[functionSelector][account];

// Revert if functionHandler is not set
if (functionHandler == address(0)) {
Expand Down
4 changes: 2 additions & 2 deletions docs/execution_flows.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ end
subgraph SafeProtocolManager
ExamplePlugin1 -->|Execute tx for an Account through Plugin| Execute_Transaction(Execute transaction from a Plugin) --> Validate_ExecuteFromPluginFlow{Is Plugin Enabled?<br>Call SafeProtocolRegistry<br>and validate if Plugin trusted}
Validate_ExecuteFromPluginFlow -- No ----> E(Revert transaction)
Validate_ExecuteFromPluginFlow -- No ----> E(Revert transaction)
end
```

Expand Down Expand Up @@ -118,4 +118,4 @@ SafeProtocolManager --> isValidSignature{isValidSignature}
isValidSignature --> |Yes| ExecuteTx(Continue transaction execution)
User("`Users(s)`") --> |Generate an Account signature| Sign_Transaction
```
```
2 changes: 1 addition & 1 deletion src/tasks/generate_deployments_markdown.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ task("generate:deployments", "Generate markdown file with deployed contract addr
console.error("No deployments file found. Please run the deployment script first.");
return;
}

const {default: deployments} = await import("../../deployments");
const markdownFile = "./docs/deployments.md";

Expand Down
2 changes: 1 addition & 1 deletion src/utils/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ export const MODULE_TYPE_SIGNATURE_VALIDATOR: number = 16;

// solidity: bytes4(keccak256("Account712Signature(bytes32,bytes32,bytes)"));
// javascript: hre.ethers.keccak256(toUtf8Bytes("Account712Signature(bytes32,bytes32,bytes)")).slice(0, 10);
export const SIGNATURE_VALIDATOR_SELECTOR = "0xb5c726cb";
export const SIGNATURE_VALIDATOR_SELECTOR = "0xb5c726cb";
4 changes: 2 additions & 2 deletions test/SignatureValidatorManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ describe("SignatureValidatorManager", () => {

await account.executeCallViaMock(safeProtocolSignatureValidatorManager.target, 0, dataSetValidator, MaxUint256);

expect(await safeProtocolSignatureValidatorManager.signatureValidators(account.target, domainSeparator)).to.be.equal(
expect(await safeProtocolSignatureValidatorManager.signatureValidators(domainSeparator, account.target)).to.be.equal(
mockContract.target,
);

Expand All @@ -89,7 +89,7 @@ describe("SignatureValidatorManager", () => {
]);

await account.executeCallViaMock(safeProtocolSignatureValidatorManager.target, 0, dataResetValidator, MaxUint256);
expect(await safeProtocolSignatureValidatorManager.signatureValidators(account.target, domainSeparator)).to.be.equal(ZeroAddress);
expect(await safeProtocolSignatureValidatorManager.signatureValidators(domainSeparator, account.target)).to.be.equal(ZeroAddress);
});

it("should revert when enabling a signature validator hooks not implementing ISafeProtocolSignatureValidatorHooks interface", async () => {
Expand Down

0 comments on commit 71dff1d

Please sign in to comment.