Skip to content

Commit

Permalink
feat(contracts): add tally results
Browse files Browse the repository at this point in the history
- [x] Allow to add tally results and keep them onchain
- [x] Update cli commands
  • Loading branch information
0xmad committed Oct 2, 2024
1 parent 6d08c1a commit 09d620f
Show file tree
Hide file tree
Showing 9 changed files with 248 additions and 11 deletions.
1 change: 1 addition & 0 deletions packages/cli/tests/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ export const mergeSignupsArgs: Omit<MergeSignupsArgs, "signer"> = {

export const proveOnChainArgs: Omit<ProveOnChainArgs, "signer"> = {
pollId: 0n,
tallyFile: testTallyFilePath,
proofDir: testProofsDirPath,
};

Expand Down
25 changes: 23 additions & 2 deletions packages/cli/ts/commands/proveOnChain.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/* eslint-disable no-await-in-loop */
import { type BigNumberish } from "ethers";
import { type IVerifyingKeyStruct, formatProofForVerifierContract } from "maci-contracts";
import { type IVerifyingKeyStruct, TallyData, formatProofForVerifierContract } from "maci-contracts";
import {
MACI__factory as MACIFactory,
AccQueue__factory as AccQueueFactory,
Expand All @@ -11,7 +11,7 @@ import {
Verifier__factory as VerifierFactory,
} from "maci-contracts/typechain-types";
import { MESSAGE_TREE_ARITY, STATE_TREE_ARITY } from "maci-core";
import { G1Point, G2Point } from "maci-crypto";
import { G1Point, G2Point, genTreeProof } from "maci-crypto";
import { VerifyingKey } from "maci-domainobjs";

import fs from "fs";
Expand Down Expand Up @@ -42,6 +42,7 @@ export const proveOnChain = async ({
proofDir,
maciAddress,
signer,
tallyFile,
quiet = true,
}: ProveOnChainArgs): Promise<void> => {
banner(quiet);
Expand Down Expand Up @@ -368,4 +369,24 @@ export const proveOnChain = async ({
if (tallyBatchNum === totalTallyBatches) {
logGreen(quiet, success("All vote tallying proofs have been submitted."));
}

if (tallyFile) {
const tallyData = await fs.promises.readFile(tallyFile).then((res) => JSON.parse(res.toString()) as TallyData);

const tallyResults = tallyData.results.tally.map((t) => BigInt(t));
const tallyResultProofs = tallyData.results.tally.map((_, index) =>
genTreeProof(index, tallyResults, Number(treeDepths.voteOptionTreeDepth)),
);

await tallyContract
.addTallyResults(
tallyData.results.tally.map((_, index) => index),
tallyResults,
tallyResultProofs,
tallyData.results.salt,
tallyData.totalSpentVoiceCredits.commitment,
tallyData.perVOSpentVoiceCredits?.commitment ?? 0n,
)
.then((tx) => tx.wait());
}
};
5 changes: 5 additions & 0 deletions packages/cli/ts/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,10 @@ program
.command("proveOnChain")
.description("prove the results of a poll on chain")
.requiredOption("-o, --poll-id <pollId>", "the poll id", BigInt)
.option(
"-t, --tally-file <tallyFile>",
"the tally file with results, per vote option spent credits, spent voice credits total",
)
.option("-q, --quiet <quiet>", "whether to print values to the console", (value) => value === "true", false)
.option("-r, --rpc-provider <provider>", "the rpc provider URL")
.option("-x, --maci-address <maciAddress>", "the MACI contract address")
Expand All @@ -645,6 +649,7 @@ program

await proveOnChain({
pollId: cmdObj.pollId,
tallyFile: cmdObj.tallyFile,
proofDir: cmdObj.proofDir,
maciAddress: cmdObj.maciAddress,
quiet: cmdObj.quiet,
Expand Down
5 changes: 5 additions & 0 deletions packages/cli/ts/utils/interfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,11 @@ export interface ProveOnChainArgs {
*/
signer: Signer;

/**
* The tally file with results, per vote option spent credits, spent voice credits total
*/
tallyFile?: string;

/**
* The address of the MACI contract
*/
Expand Down
76 changes: 76 additions & 0 deletions packages/contracts/contracts/Tally.sol
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ contract Tally is Ownable, SnarkCommon, CommonUtilities, Hasher, DomainObjs, ITa
IMessageProcessor public immutable messageProcessor;
Mode public immutable mode;

// The tally results
mapping(uint256 => uint256) public tallyResults;

// The total tally results number
uint256 public totalTallyResults;

/// @notice custom errors
error ProcessingNotComplete();
error InvalidTallyVotesProof();
Expand All @@ -60,6 +66,7 @@ contract Tally is Ownable, SnarkCommon, CommonUtilities, Hasher, DomainObjs, ITa
error BatchStartIndexTooLarge();
error TallyBatchSizeTooLarge();
error NotSupported();
error VotesNotTallied();

/// @notice Create a new Tally contract
/// @param _verifier The Verifier contract
Expand Down Expand Up @@ -344,4 +351,73 @@ contract Tally is Ownable, SnarkCommon, CommonUtilities, Hasher, DomainObjs, ITa
isValid = hash2(tally) == tallyCommitment;
}
}

/**
* @notice Add and verify tally results by batch.
* @param _voteOptionIndices Vote option index.
* @param _tallyResults The results of vote tally for the recipients.
* @param _tallyResultProofs Proofs of correctness of the vote tally results.
* @param _tallyResultSalt the respective salt in the results object in the tally.json
* @param _spentVoiceCreditsHashes hashLeftRight(number of spent voice credits, spent salt)
* @param _perVOSpentVoiceCreditsHashes hashLeftRight(merkle root of the no spent voice credits per vote option, perVOSpentVoiceCredits salt)
*/
function addTallyResults(
uint256[] calldata _voteOptionIndices,

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

uint256[] calldata _tallyResults,

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

uint256[][][] calldata _tallyResultProofs,

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

uint256 _tallyResultSalt,

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

uint256 _spentVoiceCreditsHashes,

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

uint256 _perVOSpentVoiceCreditsHashes

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

) public virtual onlyOwner {
if (!isTallied()) {
revert VotesNotTallied();
}

for (uint256 i = 0; i < _voteOptionIndices.length; i++) {
addTallyResult(
_voteOptionIndices[i],
_tallyResults[i],
_tallyResultProofs[i],
_tallyResultSalt,
_spentVoiceCreditsHashes,
_perVOSpentVoiceCreditsHashes
);
}
}

/**
* @dev Add and verify tally votes and calculate sum of tally squares for alpha calculation.
* @param _voteOptionIndex Vote option index.
* @param _tallyResult The results of vote tally for the recipients.
* @param _tallyResultProof Proofs of correctness of the vote tally results.
* @param _tallyResultSalt the respective salt in the results object in the tally.json
* @param _spentVoiceCreditsHash hashLeftRight(number of spent voice credits, spent salt)
* @param _perVOSpentVoiceCreditsHash hashLeftRight(merkle root of the no spent voice credits per vote option, perVOSpentVoiceCredits salt)
*/
function addTallyResult(
uint256 _voteOptionIndex,

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

uint256 _tallyResult,

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

uint256[][] calldata _tallyResultProof,

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

uint256 _tallyResultSalt,

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

uint256 _spentVoiceCreditsHash,

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

uint256 _perVOSpentVoiceCreditsHash

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

) internal virtual {
(, , , uint8 voteOptionTreeDepth) = poll.treeDepths();
bool isValid = verifyTallyResult(
_voteOptionIndex,
_tallyResult,
_tallyResultProof,
_tallyResultSalt,
voteOptionTreeDepth,
_spentVoiceCreditsHash,
_perVOSpentVoiceCreditsHash
);

if (!isValid) {
revert InvalidTallyVotesProof();
}

tallyResults[_voteOptionIndex] = _tallyResult;
totalTallyResults++;
}

Check warning

Code scanning / Slither

Unused return Medium

Check notice

Code scanning / Slither

Calls inside a loop Low

Check warning

Code scanning / Slither

Costly operations inside a loop Warning

}
22 changes: 19 additions & 3 deletions packages/contracts/tasks/helpers/Prover.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/* eslint-disable no-console, no-await-in-loop */
import { STATE_TREE_ARITY, MESSAGE_TREE_ARITY } from "maci-core";
import { G1Point, G2Point } from "maci-crypto";
import { G1Point, G2Point, genTreeProof } from "maci-crypto";
import { VerifyingKey } from "maci-domainobjs";

import type { IVerifyingKeyStruct, Proof } from "../../ts/types";
Expand All @@ -9,7 +9,7 @@ import type { BigNumberish } from "ethers";

import { asHex, formatProofForVerifierContract } from "../../ts/utils";

import { IProverParams } from "./types";
import { IProverParams, TallyData } from "./types";

/**
* Prover class is designed to prove message processing and tally proofs on-chain.
Expand Down Expand Up @@ -219,7 +219,7 @@ export class Prover {
*
* @param proofs tally proofs
*/
async proveTally(proofs: Proof[]): Promise<void> {
async proveTally(proofs: Proof[], tallyData: TallyData): Promise<void> {
const [treeDepths, numSignUpsAndMessages, tallyBatchNumber, mode, stateTreeDepth] = await Promise.all([
this.pollContract.treeDepths(),
this.pollContract.numSignUpsAndMessages(),
Expand Down Expand Up @@ -310,6 +310,22 @@ export class Prover {
if (tallyBatchNum === totalTallyBatches) {
console.log("All vote tallying proofs have been submitted.");
}

const tallyResults = tallyData.results.tally.map((t) => BigInt(t));
const tallyResultProofs = tallyData.results.tally.map((_, index) =>
genTreeProof(index, tallyResults, Number(treeDepths.voteOptionTreeDepth)),
);

await this.tallyContract
.addTallyResults(
tallyData.results.tally.map((_, index) => index),
tallyResults,
tallyResultProofs,
tallyData.results.salt,
tallyData.totalSpentVoiceCredits.commitment,
tallyData.perVOSpentVoiceCredits?.commitment ?? 0n,
)
.then((tx) => tx.wait());
}

/**
Expand Down
5 changes: 3 additions & 2 deletions packages/contracts/tasks/runner/prove.ts
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,9 @@ task("prove", "Command to generate proof and prove the result of a poll on-chain
data.processProofs = await proofGenerator.generateMpProofs();
await prover.proveMessageProcessing(data.processProofs);

data.tallyProofs = await proofGenerator.generateTallyProofs(network).then(({ proofs }) => proofs);
await prover.proveTally(data.tallyProofs);
const { proofs: tallyProofs, tallyData } = await proofGenerator.generateTallyProofs(network);
data.tallyProofs = tallyProofs;
await prover.proveTally(data.tallyProofs, tallyData);

const endBalance = await signer.provider.getBalance(signer);

Expand Down
Loading

0 comments on commit 09d620f

Please sign in to comment.