Skip to content

Commit

Permalink
Feature/proof testing (#3)
Browse files Browse the repository at this point in the history
* Add base CircuitZKit support

* Update Chai types

* Add strict property

* update readme

* Add addition checks

---------

Co-authored-by: Artem Chystiakov <[email protected]>
  • Loading branch information
Hrom131 and Arvolear authored Aug 8, 2024
1 parent 1647a09 commit fbed550
Show file tree
Hide file tree
Showing 7 changed files with 210 additions and 122 deletions.
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ After installing the package, you may use the following assertions:
```ts
const matrix = await zkit.getCircuit("Matrix");

// strict assertion, all the outputs must be present
await expect(matrix).with.witnessInputs({ a, b, c }).to.have.witnessOutputsStrict({ d, e, f });

// partial output assertion
await expect(matrix).with.witnessInputs({ a, b, c }).to.have.witnessOutputs({ d });

// strict assertion, all the outputs must be present
await expect(matrix).with.witnessInputs({ a, b, c }).to.have.strict.witnessOutputs({ d, e, f });

// provided output `e` doesn't match the obtained one
await expect(expect(matrix).with.witnessInputs({ a, b, c }).to.have.witnessOutputs({ e })).to.be.rejectedWith(
`Expected output "e" to be "[[2,5,0],[17,26,0],[0,0,0]]", but got "[[1,4,0],[16,25,0],[0,0,0]]"`,
Expand All @@ -58,4 +58,5 @@ await expect(

## Known limitations

- Do not use `not` chai negation prior `witnessInputs` call, this will break the typization.
- Temporarily, only the witness `input <> output` signals testing is supported.
4 changes: 2 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@solarity/chai-zkit",
"version": "0.0.1",
"version": "0.1.0",
"license": "MIT",
"author": "Distributed Lab",
"readme": "README.md",
Expand Down
27 changes: 22 additions & 5 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,31 @@ declare global {
}

interface ExpectStatic {
<T extends Circuit>(val: T, message?: string): Assertion<T>;
<T>(val: T, message?: string): Assertion<T>;
}

interface AsyncAssertion<T = any> extends Promise<void> {
not: AsyncAssertion<T>;
strict: AsyncAssertion<T>;
to: AsyncAssertion<T>;
be: AsyncAssertion<T>;
been: AsyncAssertion<T>;
is: AsyncAssertion<T>;
that: AsyncAssertion<T>;
which: AsyncAssertion<T>;
and: AsyncAssertion<T>;
has: AsyncAssertion<T>;
have: AsyncAssertion<T>;
with: AsyncAssertion<T>;
at: AsyncAssertion<T>;
of: AsyncAssertion<T>;
same: AsyncAssertion<T>;
but: AsyncAssertion<T>;
does: AsyncAssertion<T>;
witnessInputs(inputs: T extends Circuit ? ExtractInputs<T> : never): AsyncAssertion<T>;
witnessOutputs(outputs: T extends Circuit ? Partial<ExtractOutputs<T>> : never): AsyncAssertion<T>;
}

interface Assertion<T = any> {
to: Assertion<T>;
be: Assertion<T>;
Expand All @@ -37,12 +58,8 @@ declare global {
same: Assertion<T>;
but: Assertion<T>;
does: Assertion<T>;

witnessInputs(inputs: T extends Circuit ? ExtractInputs<T> : never): AsyncAssertion<T>;
witnessOutputsStrict(outputs: T extends Circuit ? ExtractOutputs<T> : never): AsyncAssertion<T>;
witnessOutputs(outputs: T extends Circuit ? Partial<ExtractOutputs<T>> : never): AsyncAssertion<T>;
}

interface AsyncAssertion<T = any> extends Assertion<T>, Promise<void> {}
}
}
43 changes: 38 additions & 5 deletions src/utils.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import * as fs from "fs";

import { CircuitZKit, Signal, Signals } from "@solarity/zkit";
import { CircuitZKit, NumberLike, Signal, Signals } from "@solarity/zkit";

export function loadOutputs(zkit: CircuitZKit, witness: bigint[], inputs: Signals): Signals {
const signalToIndex = loadSym(zkit);
Expand All @@ -24,6 +24,37 @@ export function loadOutputs(zkit: CircuitZKit, witness: bigint[], inputs: Signal
return parseOutputSignals(witness, signals.slice(0, minInputIndex), signalToIndex);
}

export function flattenSignals(signals: Signals): NumberLike[] {
let flattenSignalsArr: NumberLike[] = [];

for (const output of Object.keys(signals)) {
flattenSignalsArr.push(...flattenSignal(signals[output]));
}

return flattenSignalsArr;
}

export function compareSignals(actualSignal: Signal, expectedSignal: Signal): boolean {
const actualSignalValues: NumberLike[] = flattenSignal(actualSignal);
const expectedSignalValues: NumberLike[] = flattenSignal(expectedSignal);

if (actualSignalValues.length !== expectedSignalValues.length) {
return false;
}

for (let i = 0; i < actualSignalValues.length; i++) {
if (BigInt(actualSignalValues[i]) !== BigInt(expectedSignalValues[i])) {
return false;
}
}

return true;
}

export function stringifySignal(signal: Signal): string {
return JSON.stringify(signal, (_, v) => (typeof v === "bigint" ? v.toString() : v)).replaceAll(`"`, "");
}

function loadSym(zkit: CircuitZKit): Map<string, number> {
const symFile = zkit.mustGetArtifactsFilePath("sym");
const signals = new Map<string, number>();
Expand All @@ -42,10 +73,6 @@ function loadSym(zkit: CircuitZKit): Map<string, number> {
return signals;
}

export function stringifySignal(signal: Signal): string {
return JSON.stringify(signal, (_, v) => (typeof v === "bigint" ? v.toString() : v)).replaceAll(`"`, "");
}

function parseOutputSignals(witness: bigint[], signals: string[], signalToIndex: Map<string, number>): Signals {
let outputSignals: Signals = {};

Expand Down Expand Up @@ -110,3 +137,9 @@ function countSignalDimensions(signal: Signal): number {

return countSignalDimensions(signal[0]) + 1;
}

function flattenSignal(signal: Signal): NumberLike[] {
const flatValue = Array.isArray(signal) ? signal.flatMap((signal) => flattenSignal(signal)) : signal;

return Array.isArray(flatValue) ? flatValue : [flatValue];
}
97 changes: 50 additions & 47 deletions src/witness.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import { CircuitZKit, Signals } from "@solarity/zkit";
import { CircuitZKit, NumberLike, Signals } from "@solarity/zkit";

import { loadOutputs, stringifySignal } from "./utils";
import { compareSignals, flattenSignals, loadOutputs, stringifySignal } from "./utils";

export function witness(chai: Chai.ChaiStatic, utils: Chai.ChaiUtils): void {
chai.Assertion.addProperty("strict", function (this: any) {
utils.flag(this, "strict", true);

return this;
});

chai.Assertion.addMethod("witnessInputs", function (this: any, inputs: Signals) {
const obj = utils.flag(this, "object");

Expand All @@ -23,19 +29,22 @@ export function witness(chai: Chai.ChaiStatic, utils: Chai.ChaiUtils): void {
return this;
});

chai.Assertion.addMethod("witnessOutputsStrict", function (this: any, outputs: Signals) {
chai.Assertion.addChainableMethod;

chai.Assertion.addMethod("witnessOutputs", function (this: any, outputs: Signals | NumberLike[]) {
const obj = utils.flag(this, "object");
const isStrict = utils.flag(this, "strict");

if (!(obj instanceof CircuitZKit)) {
throw new Error("`witnessOutputsStrict` is expected to be called on `CircuitZKit`");
throw new Error("`witnessOutputs` is expected to be called on `CircuitZKit`");
}

const promise = (this.then === undefined ? Promise.resolve() : this).then(async () => {
const witness = utils.flag(this, "witness");
const inputs = utils.flag(this, "inputs");

if (!witness) {
throw new Error("`witnessOutputsStrict` is expected to be called after `witnessInputs`");
throw new Error("`witnessOutputs` is expected to be called after `witnessInputs`");
}

if (Object.keys(inputs).length === 0) {
Expand All @@ -44,58 +53,52 @@ export function witness(chai: Chai.ChaiStatic, utils: Chai.ChaiUtils): void {

const actual = loadOutputs(obj as CircuitZKit, witness, inputs);

if (Object.keys(actual).length !== Object.keys(outputs).length) {
throw new Error(`Expected ${Object.keys(outputs).length} outputs, but got ${Object.keys(actual).length}`);
}

for (const output of Object.keys(outputs)) {
this.assert(
stringifySignal(actual[output]) === stringifySignal(outputs[output]),
`Expected output "${output}" to be "${stringifySignal(outputs[output])}", but got "${stringifySignal(actual[output])}"`,
`Expected output "${output}" NOT to be "${stringifySignal(outputs[output])}", but it is"`,
);
}
witnessOutputsCompare(this, actual, outputs, isStrict);
});

this.then = promise.then.bind(promise);
this.catch = promise.catch.bind(promise);

return this;
});
}

chai.Assertion.addMethod("witnessOutputs", function (this: any, outputs: Signals) {
const obj = utils.flag(this, "object");

if (!(obj instanceof CircuitZKit)) {
throw new Error("`witnessOutputs` is expected to be called on `CircuitZKit`");
function witnessOutputsCompare(
instance: any,
actualOutputs: Signals,
expectedOutputs: Signals | NumberLike[],
isStrict?: boolean,
) {
if (Array.isArray(expectedOutputs)) {
const actualOutputsArr: NumberLike[] = flattenSignals(actualOutputs);

if (
(isStrict && actualOutputsArr.length !== expectedOutputs.length) ||
actualOutputsArr.length < expectedOutputs.length
) {
throw new Error(`Expected ${actualOutputsArr.length} outputs, but got ${expectedOutputs.length}`);
}

const promise = (this.then === undefined ? Promise.resolve() : this).then(async () => {
const witness = utils.flag(this, "witness");
const inputs = utils.flag(this, "inputs");

if (!witness) {
throw new Error("`witnessOutputs` is expected to be called after `witnessInputs`");
}

if (Object.keys(inputs).length === 0) {
throw new Error("Circuit must have at least one input to extract outputs");
}

const actual = loadOutputs(obj as CircuitZKit, witness, inputs);

for (const output of Object.keys(outputs)) {
this.assert(
stringifySignal(actual[output]) === stringifySignal(outputs[output]),
`Expected output "${output}" to be "${stringifySignal(outputs[output])}", but got "${stringifySignal(actual[output])}"`,
`Expected output "${output}" NOT to be "${stringifySignal(outputs[output])}", but it is"`,
);
}
expectedOutputs.forEach((output: NumberLike, index: number) => {
instance.assert(
BigInt(output) === BigInt(actualOutputsArr[index]),
`Expected output with index "${index}" to be "${output}", but got "${actualOutputsArr[index]}"`,
`Expected output "${output}" NOT to be "${output}", but it is"`,
);
});
} else {
if (isStrict && Object.keys(actualOutputs).length !== Object.keys(expectedOutputs).length) {
throw new Error(
`Expected ${Object.keys(expectedOutputs).length} outputs, but got ${Object.keys(actualOutputs).length}`,
);
}

this.then = promise.then.bind(promise);
this.catch = promise.catch.bind(promise);

return this;
});
for (const output of Object.keys(expectedOutputs)) {
instance.assert(
compareSignals(actualOutputs[output], expectedOutputs[output]),
`Expected output "${output}" to be "${stringifySignal(expectedOutputs[output])}", but got "${stringifySignal(actualOutputs[output])}"`,
`Expected output "${output}" NOT to be "${stringifySignal(expectedOutputs[output])}", but it is"`,
);
}
}
}
Loading

0 comments on commit fbed550

Please sign in to comment.