From fbed550f3d12abd9870d625dbbda83c6edb5210b Mon Sep 17 00:00:00 2001 From: Oleg Komendant <44612825+Hrom131@users.noreply.github.com> Date: Thu, 8 Aug 2024 18:50:38 +0300 Subject: [PATCH] Feature/proof testing (#3) * Add base CircuitZKit support * Update Chai types * Add strict property * update readme * Add addition checks --------- Co-authored-by: Artem Chystiakov --- README.md | 7 +- package-lock.json | 4 +- package.json | 2 +- src/types.ts | 27 ++++++-- src/utils.ts | 43 ++++++++++-- src/witness.ts | 97 +++++++++++++------------- test/chai-zkit.test.ts | 152 +++++++++++++++++++++++++---------------- 7 files changed, 210 insertions(+), 122 deletions(-) diff --git a/README.md b/README.md index 68ab88d..a5d4402 100644 --- a/README.md +++ b/README.md @@ -39,12 +39,12 @@ After installing the package, you may use the following assertions: ```ts const matrix = await zkit.getCircuit("Matrix"); -// strict assertion, all the outputs must be present -await expect(matrix).with.witnessInputs({ a, b, c }).to.have.witnessOutputsStrict({ d, e, f }); - // partial output assertion await expect(matrix).with.witnessInputs({ a, b, c }).to.have.witnessOutputs({ d }); +// strict assertion, all the outputs must be present +await expect(matrix).with.witnessInputs({ a, b, c }).to.have.strict.witnessOutputs({ d, e, f }); + // provided output `e` doesn't match the obtained one await expect(expect(matrix).with.witnessInputs({ a, b, c }).to.have.witnessOutputs({ e })).to.be.rejectedWith( `Expected output "e" to be "[[2,5,0],[17,26,0],[0,0,0]]", but got "[[1,4,0],[16,25,0],[0,0,0]]"`, @@ -58,4 +58,5 @@ await expect( ## Known limitations +- Do not use `not` chai negation prior `witnessInputs` call, this will break the typization. - Temporarily, only the witness `input <> output` signals testing is supported. diff --git a/package-lock.json b/package-lock.json index e37b5c9..647b792 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@solarity/chai-zkit", - "version": "0.0.1", + "version": "0.1.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@solarity/chai-zkit", - "version": "0.0.1", + "version": "0.1.0", "license": "MIT", "dependencies": { "@solarity/zkit": "0.2.4", diff --git a/package.json b/package.json index f73e1fd..a2e0885 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@solarity/chai-zkit", - "version": "0.0.1", + "version": "0.1.0", "license": "MIT", "author": "Distributed Lab", "readme": "README.md", diff --git a/src/types.ts b/src/types.ts index 66e2d63..cfdc6dd 100644 --- a/src/types.ts +++ b/src/types.ts @@ -17,10 +17,31 @@ declare global { } interface ExpectStatic { - (val: T, message?: string): Assertion; (val: T, message?: string): Assertion; } + interface AsyncAssertion extends Promise { + not: AsyncAssertion; + strict: AsyncAssertion; + to: AsyncAssertion; + be: AsyncAssertion; + been: AsyncAssertion; + is: AsyncAssertion; + that: AsyncAssertion; + which: AsyncAssertion; + and: AsyncAssertion; + has: AsyncAssertion; + have: AsyncAssertion; + with: AsyncAssertion; + at: AsyncAssertion; + of: AsyncAssertion; + same: AsyncAssertion; + but: AsyncAssertion; + does: AsyncAssertion; + witnessInputs(inputs: T extends Circuit ? ExtractInputs : never): AsyncAssertion; + witnessOutputs(outputs: T extends Circuit ? Partial> : never): AsyncAssertion; + } + interface Assertion { to: Assertion; be: Assertion; @@ -37,12 +58,8 @@ declare global { same: Assertion; but: Assertion; does: Assertion; - witnessInputs(inputs: T extends Circuit ? ExtractInputs : never): AsyncAssertion; - witnessOutputsStrict(outputs: T extends Circuit ? ExtractOutputs : never): AsyncAssertion; witnessOutputs(outputs: T extends Circuit ? Partial> : never): AsyncAssertion; } - - interface AsyncAssertion extends Assertion, Promise {} } } diff --git a/src/utils.ts b/src/utils.ts index 5b6d69f..13deb13 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,6 +1,6 @@ import * as fs from "fs"; -import { CircuitZKit, Signal, Signals } from "@solarity/zkit"; +import { CircuitZKit, NumberLike, Signal, Signals } from "@solarity/zkit"; export function loadOutputs(zkit: CircuitZKit, witness: bigint[], inputs: Signals): Signals { const signalToIndex = loadSym(zkit); @@ -24,6 +24,37 @@ export function loadOutputs(zkit: CircuitZKit, witness: bigint[], inputs: Signal return parseOutputSignals(witness, signals.slice(0, minInputIndex), signalToIndex); } +export function flattenSignals(signals: Signals): NumberLike[] { + let flattenSignalsArr: NumberLike[] = []; + + for (const output of Object.keys(signals)) { + flattenSignalsArr.push(...flattenSignal(signals[output])); + } + + return flattenSignalsArr; +} + +export function compareSignals(actualSignal: Signal, expectedSignal: Signal): boolean { + const actualSignalValues: NumberLike[] = flattenSignal(actualSignal); + const expectedSignalValues: NumberLike[] = flattenSignal(expectedSignal); + + if (actualSignalValues.length !== expectedSignalValues.length) { + return false; + } + + for (let i = 0; i < actualSignalValues.length; i++) { + if (BigInt(actualSignalValues[i]) !== BigInt(expectedSignalValues[i])) { + return false; + } + } + + return true; +} + +export function stringifySignal(signal: Signal): string { + return JSON.stringify(signal, (_, v) => (typeof v === "bigint" ? v.toString() : v)).replaceAll(`"`, ""); +} + function loadSym(zkit: CircuitZKit): Map { const symFile = zkit.mustGetArtifactsFilePath("sym"); const signals = new Map(); @@ -42,10 +73,6 @@ function loadSym(zkit: CircuitZKit): Map { return signals; } -export function stringifySignal(signal: Signal): string { - return JSON.stringify(signal, (_, v) => (typeof v === "bigint" ? v.toString() : v)).replaceAll(`"`, ""); -} - function parseOutputSignals(witness: bigint[], signals: string[], signalToIndex: Map): Signals { let outputSignals: Signals = {}; @@ -110,3 +137,9 @@ function countSignalDimensions(signal: Signal): number { return countSignalDimensions(signal[0]) + 1; } + +function flattenSignal(signal: Signal): NumberLike[] { + const flatValue = Array.isArray(signal) ? signal.flatMap((signal) => flattenSignal(signal)) : signal; + + return Array.isArray(flatValue) ? flatValue : [flatValue]; +} diff --git a/src/witness.ts b/src/witness.ts index 56c1092..30f7f58 100644 --- a/src/witness.ts +++ b/src/witness.ts @@ -1,8 +1,14 @@ -import { CircuitZKit, Signals } from "@solarity/zkit"; +import { CircuitZKit, NumberLike, Signals } from "@solarity/zkit"; -import { loadOutputs, stringifySignal } from "./utils"; +import { compareSignals, flattenSignals, loadOutputs, stringifySignal } from "./utils"; export function witness(chai: Chai.ChaiStatic, utils: Chai.ChaiUtils): void { + chai.Assertion.addProperty("strict", function (this: any) { + utils.flag(this, "strict", true); + + return this; + }); + chai.Assertion.addMethod("witnessInputs", function (this: any, inputs: Signals) { const obj = utils.flag(this, "object"); @@ -23,11 +29,14 @@ export function witness(chai: Chai.ChaiStatic, utils: Chai.ChaiUtils): void { return this; }); - chai.Assertion.addMethod("witnessOutputsStrict", function (this: any, outputs: Signals) { + chai.Assertion.addChainableMethod; + + chai.Assertion.addMethod("witnessOutputs", function (this: any, outputs: Signals | NumberLike[]) { const obj = utils.flag(this, "object"); + const isStrict = utils.flag(this, "strict"); if (!(obj instanceof CircuitZKit)) { - throw new Error("`witnessOutputsStrict` is expected to be called on `CircuitZKit`"); + throw new Error("`witnessOutputs` is expected to be called on `CircuitZKit`"); } const promise = (this.then === undefined ? Promise.resolve() : this).then(async () => { @@ -35,7 +44,7 @@ export function witness(chai: Chai.ChaiStatic, utils: Chai.ChaiUtils): void { const inputs = utils.flag(this, "inputs"); if (!witness) { - throw new Error("`witnessOutputsStrict` is expected to be called after `witnessInputs`"); + throw new Error("`witnessOutputs` is expected to be called after `witnessInputs`"); } if (Object.keys(inputs).length === 0) { @@ -44,17 +53,7 @@ export function witness(chai: Chai.ChaiStatic, utils: Chai.ChaiUtils): void { const actual = loadOutputs(obj as CircuitZKit, witness, inputs); - if (Object.keys(actual).length !== Object.keys(outputs).length) { - throw new Error(`Expected ${Object.keys(outputs).length} outputs, but got ${Object.keys(actual).length}`); - } - - for (const output of Object.keys(outputs)) { - this.assert( - stringifySignal(actual[output]) === stringifySignal(outputs[output]), - `Expected output "${output}" to be "${stringifySignal(outputs[output])}", but got "${stringifySignal(actual[output])}"`, - `Expected output "${output}" NOT to be "${stringifySignal(outputs[output])}", but it is"`, - ); - } + witnessOutputsCompare(this, actual, outputs, isStrict); }); this.then = promise.then.bind(promise); @@ -62,40 +61,44 @@ export function witness(chai: Chai.ChaiStatic, utils: Chai.ChaiUtils): void { return this; }); +} - chai.Assertion.addMethod("witnessOutputs", function (this: any, outputs: Signals) { - const obj = utils.flag(this, "object"); - - if (!(obj instanceof CircuitZKit)) { - throw new Error("`witnessOutputs` is expected to be called on `CircuitZKit`"); +function witnessOutputsCompare( + instance: any, + actualOutputs: Signals, + expectedOutputs: Signals | NumberLike[], + isStrict?: boolean, +) { + if (Array.isArray(expectedOutputs)) { + const actualOutputsArr: NumberLike[] = flattenSignals(actualOutputs); + + if ( + (isStrict && actualOutputsArr.length !== expectedOutputs.length) || + actualOutputsArr.length < expectedOutputs.length + ) { + throw new Error(`Expected ${actualOutputsArr.length} outputs, but got ${expectedOutputs.length}`); } - const promise = (this.then === undefined ? Promise.resolve() : this).then(async () => { - const witness = utils.flag(this, "witness"); - const inputs = utils.flag(this, "inputs"); - - if (!witness) { - throw new Error("`witnessOutputs` is expected to be called after `witnessInputs`"); - } - - if (Object.keys(inputs).length === 0) { - throw new Error("Circuit must have at least one input to extract outputs"); - } - - const actual = loadOutputs(obj as CircuitZKit, witness, inputs); - - for (const output of Object.keys(outputs)) { - this.assert( - stringifySignal(actual[output]) === stringifySignal(outputs[output]), - `Expected output "${output}" to be "${stringifySignal(outputs[output])}", but got "${stringifySignal(actual[output])}"`, - `Expected output "${output}" NOT to be "${stringifySignal(outputs[output])}", but it is"`, - ); - } + expectedOutputs.forEach((output: NumberLike, index: number) => { + instance.assert( + BigInt(output) === BigInt(actualOutputsArr[index]), + `Expected output with index "${index}" to be "${output}", but got "${actualOutputsArr[index]}"`, + `Expected output "${output}" NOT to be "${output}", but it is"`, + ); }); + } else { + if (isStrict && Object.keys(actualOutputs).length !== Object.keys(expectedOutputs).length) { + throw new Error( + `Expected ${Object.keys(expectedOutputs).length} outputs, but got ${Object.keys(actualOutputs).length}`, + ); + } - this.then = promise.then.bind(promise); - this.catch = promise.catch.bind(promise); - - return this; - }); + for (const output of Object.keys(expectedOutputs)) { + instance.assert( + compareSignals(actualOutputs[output], expectedOutputs[output]), + `Expected output "${output}" to be "${stringifySignal(expectedOutputs[output])}", but got "${stringifySignal(actualOutputs[output])}"`, + `Expected output "${output}" NOT to be "${stringifySignal(expectedOutputs[output])}", but it is"`, + ); + } + } } diff --git a/test/chai-zkit.test.ts b/test/chai-zkit.test.ts index 8c3e3fc..5fbc01c 100644 --- a/test/chai-zkit.test.ts +++ b/test/chai-zkit.test.ts @@ -3,7 +3,7 @@ import { expect } from "chai"; import * as fs from "fs"; import path from "path"; -import { NumberLike } from "@solarity/zkit"; +import { NumberLike, CircuitZKit } from "@solarity/zkit"; import { useFixtureProject } from "./helpers"; @@ -20,6 +20,7 @@ describe("chai-zkit", () => { return path.join(process.cwd(), "contracts", "verifiers"); } + let baseMatrix: CircuitZKit; let matrix: Matrix; useFixtureProject("complex-circuits"); @@ -29,6 +30,7 @@ describe("chai-zkit", () => { const circuitArtifactsPath = getArtifactsFullPath(`${circuitName}.circom`); const verifierDirPath = getVerifiersDirFullPath(); + baseMatrix = new CircuitZKit({ circuitName, circuitArtifactsPath, verifierDirPath }); matrix = new Matrix({ circuitName, circuitArtifactsPath, verifierDirPath }); }); @@ -77,64 +79,15 @@ describe("chai-zkit", () => { await expect(matrix) .with.witnessInputs({ a, b, c: "1337" }) .with.witnessInputs({ a, b, c }) - .to.have.witnessOutputsStrict({ d, e, f }); - }); - }); - - describe("witnessOutputsStrict", () => { - it("should not pass if not called on zkit", async () => { - /// @ts-ignore - expect(() => expect(1).to.have.witnessOutputsStrict({ d, e })).to.throw( - "`witnessOutputsStrict` is expected to be called on `CircuitZKit`", - ); - }); - - it("should not pass if called not before witnessInputs", async () => { - await expect(expect(matrix).to.have.witnessOutputsStrict({ d, e, f })).to.be.rejectedWith( - "`witnessOutputsStrict` is expected to be called after `witnessInputs`", - ); - }); - - it("should not pass if no inputs", async () => { - const circuitName = "NoInputs"; - const circuitArtifactsPath = getArtifactsFullPath(`${circuitName}.circom`); - const verifierDirPath = getVerifiersDirFullPath(); - - const noInputs = new NoInputs({ circuitName, circuitArtifactsPath, verifierDirPath }); - - await expect( - expect(noInputs) - .with.witnessInputs({} as any) - .to.have.witnessOutputsStrict({ c: "1337" }), - ).to.be.rejectedWith("Circuit must have at least one input to extract outputs"); - }); - - it("should not pass if outputs are incorrect for given inputs", async () => { - e = d; - - await expect( - expect(matrix).with.witnessInputs({ a, b, c }).to.have.witnessOutputsStrict({ d, e, f }), - ).to.be.rejectedWith( - `Expected output "e" to be "[[2,5,0],[17,26,0],[0,0,0]]", but got "[[1,4,0],[16,25,0],[0,0,0]]"`, - ); - }); - - it("should not pass if not the same amount of outputs", async () => { - await expect( - expect(matrix) - .with.witnessInputs({ a, b, c }) - .to.have.witnessOutputsStrict({ d } as unknown as any), - ).to.be.rejectedWith("Expected 1 outputs, but got 3"); - }); - - it("should not pass if negated but outputs are correct", async () => { - await expect( - expect(matrix).with.witnessInputs({ a, b, c }).to.not.have.witnessOutputsStrict({ d, e, f }), - ).to.be.rejectedWith(`Expected output "d" NOT to be "[[2,5,0],[17,26,0],[0,0,0]]", but it is"`); - }); - - it("should pass if outputs are correct for given inputs", async () => { - await expect(matrix).with.witnessInputs({ a, b, c }).to.have.witnessOutputsStrict({ d, e, f }); + .to.have.witnessOutputs({ + d, + e: [ + ["1", "4", "0"], + ["16", "25", "0"], + ["0", "0", "0"], + ], + f: "0x1", + }); }); }); @@ -192,6 +145,61 @@ describe("chai-zkit", () => { ).to.be.rejectedWith("Sym file is missing input signals"); }); + it("should not pass if not the same amount of outputs and strict", async () => { + await expect( + expect(matrix).with.witnessInputs({ a, b, c }).to.have.strict.witnessOutputs({ d }), + ).to.be.rejectedWith("Expected 1 outputs, but got 3"); + }); + + it("should not pass if pass output arr with invalid length", async () => { + await expect( + expect(matrix) + .with.witnessInputs({ a, b, c }) + .to.have.witnessOutputs({ d: [["123"]] }), + ).to.be.rejectedWith(`Expected output "d" to be "[[123]]", but got "[[2,5,0],[17,26,0],[0,0,0]]"`); + }); + + it("should not pass if not the same amount of outputs and strict and base CircuitZKit object", async () => { + const wrongOutputs: string[] = ["2", "0x5", "0", "17", "26", "0"]; + + await expect( + expect(baseMatrix).with.witnessInputs({ a, b, c }).to.have.strict.witnessOutputs(wrongOutputs), + ).to.be.rejectedWith(`Expected 19 outputs, but got ${wrongOutputs.length}`); + }); + + it("should not pass for base CircuitZKit object and expected outputs length bigger than actual", async () => { + const wrongOutputs: string[] = [ + "2", + "0x5", + "0", + "17", + "26", + "0", + "0", + "0", + "0", + "1", + "4", + "0", + "16", + "25", + "0", + "0", + "0", + "0", + "1", + "1", + ]; + + await expect( + expect(baseMatrix).with.witnessInputs({ a, b, c }).to.have.witnessOutputs(wrongOutputs), + ).to.be.rejectedWith(`Expected 19 outputs, but got ${wrongOutputs.length}`); + }); + + it("should pass if not the same amount of outputs and not strict", async () => { + await expect(matrix).with.witnessInputs({ a, b, c }).to.have.witnessOutputs({ d }); + }); + it("should pass if outputs are correct for given inputs", async () => { await expect(matrix).with.witnessInputs({ a, b, c }).to.have.witnessOutputs({ d, e }); }); @@ -199,6 +207,32 @@ describe("chai-zkit", () => { it("should pass if outputs are correct for given inputs and not all outputs passed", async () => { await expect(matrix).with.witnessInputs({ a, b, c }).to.have.witnessOutputs({ d }); }); + + it("should pass for base CircuitZKit object", async () => { + await expect(baseMatrix) + .with.witnessInputs({ a, b, c }) + .to.have.witnessOutputs([ + "2", + "0x5", + "0", + "17", + "26", + "0", + "0", + "0", + "0", + "1", + "4", + "0", + "16", + "25", + "0", + "0", + "0", + "0", + "1", + ]); + }); }); }); });