diff --git a/circuits/aggregator/dummy.go b/circuits/aggregator/dummy.go deleted file mode 100644 index c908bf6..0000000 --- a/circuits/aggregator/dummy.go +++ /dev/null @@ -1,44 +0,0 @@ -package aggregator - -import ( - "errors" - - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/frontend" -) - -type dummyCircuit struct { - nbConstraints int - SecretInput frontend.Variable `gnark:",secret"` - PublicInputs frontend.Variable `gnark:",public"` -} - -func (c *dummyCircuit) Define(api frontend.API) error { - cmtr, ok := api.(frontend.Committer) - if !ok { - return errors.New("api is not a commiter") - } - secret, err := cmtr.Commit(c.SecretInput) - if err != nil { - return err - } - api.AssertIsDifferent(secret, 0) - - res := api.Mul(c.SecretInput, c.SecretInput) - for i := 2; i < c.nbConstraints; i++ { - res = api.Mul(res, c.SecretInput) - } - api.AssertIsEqual(c.PublicInputs, res) - return nil -} - -// DummyPlaceholder function returns the placeholder of a dummy circtuit for -// the constraint.ConstraintSystem provided. -func DummyPlaceholder(mainCircuit constraint.ConstraintSystem) *dummyCircuit { - return &dummyCircuit{nbConstraints: mainCircuit.GetNbConstraints()} -} - -// DummyPlaceholder function returns the assigment of a dummy circtuit. -func DummyAssigment() *dummyCircuit { - return &dummyCircuit{PublicInputs: 1, SecretInput: 1} -} diff --git a/circuits/aggregator/dummy_helpers.go b/circuits/aggregator/helpers.go similarity index 68% rename from circuits/aggregator/dummy_helpers.go rename to circuits/aggregator/helpers.go index 801cd86..7ba5695 100644 --- a/circuits/aggregator/dummy_helpers.go +++ b/circuits/aggregator/helpers.go @@ -5,13 +5,11 @@ import ( "math/big" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/groth16" - "github.com/consensys/gnark/backend/witness" "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/std/algebra/native/sw_bls12377" stdgroth16 "github.com/consensys/gnark/std/recursion/groth16" + "github.com/vocdoni/vocdoni-z-sandbox/circuits/dummy" ) // EncodeProofsSelector function returns a number that its base2 representation @@ -39,8 +37,8 @@ func EncodeProofsSelector(nValidProofs int) *big.Int { // something fails. func FillWithDummyFixed(placeholder, assigments *AggregatorCircuit, main constraint.ConstraintSystem, fromIdx int) error { // compile the dummy circuit for the main - dummyCCS, pubWitness, proof, vk, err := compileAndVerifyCircuit( - DummyPlaceholder(main), DummyAssigment(), + dummyCCS, pubWitness, proof, vk, err := dummy.Prove( + dummy.Placeholder(main), dummy.Assignment(1), ecc.BW6_761.ScalarField(), ecc.BLS12_377.ScalarField()) if err != nil { return err @@ -82,30 +80,3 @@ func FillWithDummyFixed(placeholder, assigments *AggregatorCircuit, main constra } return nil } - -func compileAndVerifyCircuit(placeholder, assigment frontend.Circuit, outer *big.Int, field *big.Int) (constraint.ConstraintSystem, witness.Witness, groth16.Proof, groth16.VerifyingKey, error) { - ccs, err := frontend.Compile(field, r1cs.NewBuilder, placeholder) - if err != nil { - return nil, nil, nil, nil, fmt.Errorf("compile error: %w", err) - } - pk, vk, err := groth16.Setup(ccs) - if err != nil { - return nil, nil, nil, nil, fmt.Errorf("setup error: %w", err) - } - fullWitness, err := frontend.NewWitness(assigment, field) - if err != nil { - return nil, nil, nil, nil, fmt.Errorf("full witness error: %w", err) - } - proof, err := groth16.Prove(ccs, pk, fullWitness, stdgroth16.GetNativeProverOptions(outer, field)) - if err != nil { - return nil, nil, nil, nil, fmt.Errorf("proof error: %w", err) - } - publicWitness, err := fullWitness.Public() - if err != nil { - return nil, nil, nil, nil, fmt.Errorf("pub witness error: %w", err) - } - if err = groth16.Verify(proof, vk, publicWitness, stdgroth16.GetNativeVerifierOptions(outer, field)); err != nil { - return nil, nil, nil, nil, fmt.Errorf("verify error: %w", err) - } - return ccs, publicWitness, proof, vk, nil -} diff --git a/circuits/dummy/dummy.go b/circuits/dummy/dummy.go new file mode 100644 index 0000000..ca41900 --- /dev/null +++ b/circuits/dummy/dummy.go @@ -0,0 +1,50 @@ +package dummy + +import ( + "errors" + + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/frontend" +) + +type Circuit struct { + nbConstraints int + SecretInput frontend.Variable `gnark:",secret"` + PublicInputs frontend.Variable `gnark:",public"` +} + +func (c *Circuit) Define(api frontend.API) error { + cmtr, ok := api.(frontend.Committer) + if !ok { + return errors.New("api is not a commiter") + } + secret, err := cmtr.Commit(c.SecretInput) + if err != nil { + return err + } + api.AssertIsDifferent(secret, 0) + + res := api.Mul(c.SecretInput, c.SecretInput) + for i := 2; i < c.nbConstraints; i++ { + res = api.Mul(res, c.SecretInput) + } + api.AssertIsEqual(c.PublicInputs, c.PublicInputs) + return nil +} + +// Placeholder function returns the placeholder of a dummy circuit for +// the constraint.ConstraintSystem provided. +func Placeholder(mainCircuit constraint.ConstraintSystem) *Circuit { + return &Circuit{nbConstraints: mainCircuit.GetNbConstraints()} +} + +// PlaceholderWithConstraints returns the placeholder of a dummy circuit +// with the desired number of constraints. +func PlaceholderWithConstraints(nbConstraints int) *Circuit { + return &Circuit{nbConstraints: nbConstraints} +} + +// Assignment returns the assignment of a dummy circuit. +func Assignment(publicInput frontend.Variable) *Circuit { + return &Circuit{PublicInputs: publicInput, SecretInput: 1} +} diff --git a/circuits/dummy/dummy_helpers.go b/circuits/dummy/dummy_helpers.go new file mode 100644 index 0000000..a13025f --- /dev/null +++ b/circuits/dummy/dummy_helpers.go @@ -0,0 +1,40 @@ +package dummy + +import ( + "fmt" + "math/big" + + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + stdgroth16 "github.com/consensys/gnark/std/recursion/groth16" +) + +func Prove(placeholder, assigment frontend.Circuit, outer *big.Int, field *big.Int) (constraint.ConstraintSystem, witness.Witness, groth16.Proof, groth16.VerifyingKey, error) { + ccs, err := frontend.Compile(field, r1cs.NewBuilder, placeholder) + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("compile error: %w", err) + } + pk, vk, err := groth16.Setup(ccs) + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("setup error: %w", err) + } + fullWitness, err := frontend.NewWitness(assigment, field) + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("full witness error: %w", err) + } + proof, err := groth16.Prove(ccs, pk, fullWitness, stdgroth16.GetNativeProverOptions(outer, field)) + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("proof error: %w", err) + } + publicWitness, err := fullWitness.Public() + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("pub witness error: %w", err) + } + if err = groth16.Verify(proof, vk, publicWitness, stdgroth16.GetNativeVerifierOptions(outer, field)); err != nil { + return nil, nil, nil, nil, fmt.Errorf("verify error: %w", err) + } + return ccs, publicWitness, proof, vk, nil +} diff --git a/circuits/aggregator/dummy_test.go b/circuits/dummy/dummy_test.go similarity index 97% rename from circuits/aggregator/dummy_test.go rename to circuits/dummy/dummy_test.go index 8de296c..477b4fb 100644 --- a/circuits/aggregator/dummy_test.go +++ b/circuits/dummy/dummy_test.go @@ -1,4 +1,4 @@ -package aggregator +package dummy import ( "testing" @@ -28,7 +28,7 @@ func TestSameCircuitsInfo(t *testing.T) { c.Assert(err, qt.IsNil) mainVk := stdgroth16.PlaceholderVerifyingKey[sw_bls12377.G1Affine, sw_bls12377.G2Affine, sw_bls12377.GT](mainCCS) - dummyCCS, err := frontend.Compile(ecc.BLS12_377.ScalarField(), r1cs.NewBuilder, DummyPlaceholder(mainCCS)) + dummyCCS, err := frontend.Compile(ecc.BLS12_377.ScalarField(), r1cs.NewBuilder, Placeholder(mainCCS)) c.Assert(err, qt.IsNil) dummyVk := stdgroth16.PlaceholderVerifyingKey[sw_bls12377.G1Affine, sw_bls12377.G2Affine, sw_bls12377.GT](dummyCCS) diff --git a/circuits/statetransition/circuit.go b/circuits/statetransition/circuit.go index 9cd4414..a4ea661 100644 --- a/circuits/statetransition/circuit.go +++ b/circuits/statetransition/circuit.go @@ -1,9 +1,15 @@ package statetransition import ( + "fmt" + + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/sw_bw6761" + "github.com/consensys/gnark/std/recursion/groth16" "github.com/vocdoni/gnark-crypto-primitives/elgamal" "github.com/vocdoni/gnark-crypto-primitives/utils" + "github.com/vocdoni/vocdoni-z-sandbox/circuits/dummy" "github.com/vocdoni/vocdoni-z-sandbox/state" "github.com/vocdoni/vocdoni-z-sandbox/util" ) @@ -19,7 +25,6 @@ type Circuit struct { // --------------------------------------------------------------------------------------------- // PUBLIC INPUTS - // list of root hashes RootHashBefore frontend.Variable `gnark:",public"` RootHashAfter frontend.Variable `gnark:",public"` NumNewVotes frontend.Variable `gnark:",public"` @@ -28,7 +33,9 @@ type Circuit struct { // --------------------------------------------------------------------------------------------- // SECRET INPUTS - AggregatedProof frontend.Variable // mock, this should be a zkProof + AggregatedProof groth16.Proof[sw_bw6761.G1Affine, sw_bw6761.G2Affine] + AggregatedProofWitness groth16.Witness[sw_bw6761.ScalarField] + AggregatedProofVK groth16.VerifyingKey[sw_bw6761.G1Affine, sw_bw6761.G2Affine, sw_bw6761.GTEl] `gnark:"-"` ProcessID state.MerkleProof CensusRoot state.MerkleProof @@ -42,14 +49,19 @@ type Circuit struct { // Define declares the circuit's constraints func (circuit Circuit) Define(api frontend.API) error { - circuit.VerifyAggregatedZKProof(api) + if err := circuit.VerifyAggregatedWitness(api, HashFn); err != nil { + return err + } + if err := circuit.VerifyAggregatedZKProof(api); err != nil { + return err + } circuit.VerifyMerkleProofs(api, HashFn) circuit.VerifyMerkleTransitions(api, HashFn) circuit.VerifyBallots(api) return nil } -func (circuit Circuit) VerifyAggregatedZKProof(api frontend.API) { +func (circuit Circuit) VerifyAggregatedWitness(api frontend.API, hFn utils.Hasher) error { // all of the following values compose the preimage that is hashed // to produce the public input needed to verify AggregatedProof. // they are extracted from the MerkleProofs: @@ -62,29 +74,52 @@ func (circuit Circuit) VerifyAggregatedZKProof(api frontend.API) { // Addressess := circuit.Commitment[i].NewKey // Commitments := circuit.Commitment[i].NewValue - api.Println("verify AggregatedZKProof mock:", circuit.AggregatedProof) // mock - - packedInputs := func() frontend.Variable { - for i, p := range []state.MerkleProof{ - circuit.ProcessID, - circuit.CensusRoot, - circuit.BallotMode, - circuit.EncryptionKey, - } { - api.Println("packInputs mock", i, p.Value) // mock - } - for i := range circuit.Ballot { - api.Println("packInputs mock nullifier", i, circuit.Ballot[i].NewKey) // mock - api.Println("packInputs mock ballot", i, circuit.Ballot[i].NewValue) // mock - } - for i := range circuit.Commitment { - api.Println("packInputs mock address", i, circuit.Commitment[i].NewKey) // mock - api.Println("packInputs mock commitment", i, circuit.Commitment[i].NewValue) // mock - } - return 1 // mock, should return hash of packed inputs + inputs := []frontend.Variable{ + circuit.ProcessID.Value, + circuit.CensusRoot.Value, + circuit.BallotMode.Value, + circuit.EncryptionKey.Value, } + // for _, mt := range circuit.Ballot { + // inputs = append(inputs, mt.NewKey) // Nullifier + // } + // for _, mt := range circuit.Ballot { + // inputs = append(inputs, mt.NewValue) // Ballot + // } + // for _, mt := range circuit.Commitment { + // inputs = append(inputs, mt.NewKey) // Address + // } + // for _, mt := range circuit.Commitment { + // inputs = append(inputs, mt.NewValue) // Commitment + // } + // hash the inputs + hash, err := hFn(api, inputs...) + if err != nil { + return fmt.Errorf("failed to hash: %w", err) + } + api.Println("hash:", inputs) + api.Println("hashed", len(inputs), "inputs, hash =", util.PrettyHex(hash)) - api.AssertIsEqual(packedInputs(), 1) // TODO: mock, should actually verify AggregatedZKProof + api.AssertIsEqual(len(circuit.AggregatedProofWitness.Public), 1) + publicInput, err := utils.PackScalarToVar(api, &circuit.AggregatedProofWitness.Public[0]) + if err != nil { + return fmt.Errorf("failed to pack scalar to var: %w", err) + } + api.AssertIsEqual(hash, publicInput) + return nil +} + +func (circuit Circuit) VerifyAggregatedZKProof(api frontend.API) error { + // initialize the verifier + verifier, err := groth16.NewVerifier[sw_bw6761.ScalarField, sw_bw6761.G1Affine, sw_bw6761.G2Affine, sw_bw6761.GTEl](api) + if err != nil { + return fmt.Errorf("failed to create bw6761 verifier: %w", err) + } + // verify the proof with the hash as input and the fixed verification key + if err := verifier.AssertProof(circuit.AggregatedProofVK, circuit.AggregatedProof, circuit.AggregatedProofWitness); err != nil { + return fmt.Errorf("failed to verify aggregated proof: %w", err) + } + return nil } func (circuit Circuit) VerifyMerkleProofs(api frontend.API, hFn utils.Hasher) { @@ -136,3 +171,44 @@ func (circuit Circuit) VerifyBallots(api frontend.API) { api.AssertIsEqual(circuit.NumNewVotes, ballotCount) api.AssertIsEqual(circuit.NumOverwrites, overwrittenCount) } + +func CircuitPlaceholder() *Circuit { + _, ph, err := WitnessAndCircuitPlaceholder(0) + if err != nil { + panic(err) + } + return ph +} + +func WitnessAndCircuitPlaceholder(inputsHash frontend.Variable) (*Circuit, *Circuit, error) { + _, witness, proof, vk, err := dummy.Prove( + dummy.PlaceholderWithConstraints(10), dummy.Assignment(inputsHash), + ecc.BN254.ScalarField(), ecc.BW6_761.ScalarField()) + if err != nil { + return nil, nil, err + } + // parse dummy proof and witness + dummyProof, err := groth16.ValueOfProof[sw_bw6761.G1Affine, sw_bw6761.G2Affine](proof) + if err != nil { + return nil, nil, fmt.Errorf("dummy proof value error: %w", err) + } + dummyWitness, err := groth16.ValueOfWitness[sw_bw6761.ScalarField](witness) + if err != nil { + return nil, nil, fmt.Errorf("dummy witness value error: %w", err) + } + // set fixed dummy vk in the placeholders + dummyVK, err := groth16.ValueOfVerifyingKeyFixed[sw_bw6761.G1Affine, sw_bw6761.G2Affine, sw_bw6761.GTEl](vk) + if err != nil { + return nil, nil, fmt.Errorf("fix dummy vk error: %w", err) + } + + return &Circuit{ + AggregatedProof: dummyProof, + AggregatedProofWitness: dummyWitness, + AggregatedProofVK: dummyVK, + }, &Circuit{ + AggregatedProof: dummyProof, + AggregatedProofWitness: dummyWitness, + AggregatedProofVK: dummyVK, + }, nil +} diff --git a/circuits/statetransition/circuit_test.go b/circuits/statetransition/circuit_test.go index 80144cb..2a31840 100644 --- a/circuits/statetransition/circuit_test.go +++ b/circuits/statetransition/circuit_test.go @@ -29,7 +29,7 @@ func TestCircuitCompile(t *testing.T) { // enable log to see nbConstraints logger.Set(zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: "15:04:05"}).With().Timestamp().Logger()) - _, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &statetransition.Circuit{}) + _, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, statetransition.CircuitPlaceholder()) if err != nil { panic(err) } @@ -60,45 +60,59 @@ func TestCircuitProve(t *testing.T) { } assert := test.NewAssert(t) - assert.ProverSucceeded( - &statetransition.Circuit{}, - witness, - test.WithCurves(ecc.BN254), - test.WithBackends(backend.GROTH16)) - - debugLog(t, witness) - - // second batch - if err := s.StartBatch(); err != nil { - t.Fatal(err) - } - if err := s.AddVote(newMockVote(1, 100)); err != nil { // overwrite vote 1 - t.Fatal(err) - } - if err := s.AddVote(newMockVote(3, 30)); err != nil { // add vote 3 - t.Fatal(err) - } - if err := s.AddVote(newMockVote(4, 30)); err != nil { // add vote 4 - t.Fatal(err) - } - witness, err = GenerateWitnesses(s) + inputsHash, err := s.AggregatorProofInput() if err != nil { t.Fatal(err) } - if err := s.EndBatch(); err != nil { + + witnessProof, placeholder, err := statetransition.WitnessAndCircuitPlaceholder(arbo.BytesToBigInt(inputsHash)) + if err != nil { t.Fatal(err) } - // expected results: - // ResultsAdd: 16+17+10+100 = 143 - // ResultsSub: 16 = 16 - // Final: 16+17-16+10+100 = 127 + witness.AggregatedProof = witnessProof.AggregatedProof + witness.AggregatedProofWitness = witnessProof.AggregatedProofWitness assert.ProverSucceeded( - &statetransition.Circuit{}, + placeholder, witness, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16)) debugLog(t, witness) + + // // second batch + // if err := s.StartBatch(); err != nil { + // t.Fatal(err) + // } + // if err := s.AddVote(newMockVote(1, 100)); err != nil { // overwrite vote 1 + // t.Fatal(err) + // } + // if err := s.AddVote(newMockVote(3, 30)); err != nil { // add vote 3 + // t.Fatal(err) + // } + // if err := s.AddVote(newMockVote(4, 30)); err != nil { // add vote 4 + // t.Fatal(err) + // } + // witness, err = GenerateWitnesses(s) + // if err != nil { + // t.Fatal(err) + // } + // if err := s.EndBatch(); err != nil { + // t.Fatal(err) + // } + // // expected results: + // // ResultsAdd: 16+17+10+100 = 143 + // // ResultsSub: 16 = 16 + // // Final: 16+17-16+10+100 = 127 + // witness.AggregatedProof = witness1.AggregatedProof + // witness.AggregatedProofWitness = witness1.AggregatedProofWitness + // witness.AggregatedProofVK = witness1.AggregatedProofVK + // assert.ProverSucceeded( + // placeholder, + // witness, + // test.WithCurves(ecc.BN254), + // test.WithBackends(backend.GROTH16)) + + // debugLog(t, witness) } func debugLog(t *testing.T, witness *statetransition.Circuit) { @@ -141,6 +155,76 @@ func debugLog(t *testing.T, witness *statetransition.Circuit) { } } +type CircuitAggregatedWitness struct { + statetransition.Circuit +} + +func (circuit CircuitAggregatedWitness) Define(api frontend.API) error { + if err := circuit.VerifyAggregatedWitness(api, statetransition.HashFn); err != nil { + return err + } + return nil +} + +func TestCircuitAggregatedWitnessCompile(t *testing.T) { + if os.Getenv("RUN_CIRCUIT_TESTS") == "" || os.Getenv("RUN_CIRCUIT_TESTS") == "false" { + t.Skip("skipping circuit tests...") + } + // enable log to see nbConstraints + logger.Set(zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: "15:04:05"}).With().Timestamp().Logger()) + + _, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &CircuitAggregatedWitness{}) + if err != nil { + panic(err) + } +} + +func TestCircuitAggregatedWitnessProve(t *testing.T) { + if os.Getenv("RUN_CIRCUIT_TESTS") == "" || os.Getenv("RUN_CIRCUIT_TESTS") == "false" { + t.Skip("skipping circuit tests...") + } + s := newMockState(t) + + if err := s.StartBatch(); err != nil { + t.Fatal(err) + } + + if err := s.AddVote(newMockVote(1, 10)); err != nil { // new vote 1 + t.Fatal(err) + } + + witness, err := GenerateWitnesses(s) + if err != nil { + t.Fatal(err) + } + + if err := s.EndBatch(); err != nil { // expected result: 16+17=33 + t.Fatal(err) + } + assert := test.NewAssert(t) + inputsHash, err := s.AggregatorProofInput() + if err != nil { + t.Fatal(err) + } + + witnessProof, ph, err := statetransition.WitnessAndCircuitPlaceholder(arbo.BytesToBigInt(inputsHash)) + if err != nil { + t.Fatal(err) + } + witness.AggregatedProof = witnessProof.AggregatedProof + witness.AggregatedProofWitness = witnessProof.AggregatedProofWitness + + placeholder := &CircuitAggregatedWitness{} + placeholder.AggregatedProof = ph.AggregatedProof + placeholder.AggregatedProofWitness = ph.AggregatedProofWitness + placeholder.AggregatedProofVK = ph.AggregatedProofVK + assert.ProverSucceeded( + placeholder, + witness, + test.WithCurves(ecc.BN254), + test.WithBackends(backend.GROTH16)) +} + type CircuitBallots struct { statetransition.Circuit } diff --git a/circuits/statetransition/witness_test.go b/circuits/statetransition/witness_test.go index f37bd0c..ebcffc7 100644 --- a/circuits/statetransition/witness_test.go +++ b/circuits/statetransition/witness_test.go @@ -12,9 +12,6 @@ func GenerateWitnesses(o *state.State) (*statetransition.Circuit, error) { var err error witness := &statetransition.Circuit{} - // TODO: mock, replace by actual AggregatedProof - witness.AggregatedProof = 0 - // RootHashBefore witness.RootHashBefore, err = o.RootAsBigInt() if err != nil { diff --git a/circuits/test/statetransition/statetransition_inputs.go b/circuits/test/statetransition/statetransition_inputs.go new file mode 100644 index 0000000..bf8a292 --- /dev/null +++ b/circuits/test/statetransition/statetransition_inputs.go @@ -0,0 +1,192 @@ +package statetransitiontest + +import ( + "bytes" + "fmt" + "math/big" + "os" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/algebra/emulated/sw_bw6761" + stdgroth16 "github.com/consensys/gnark/std/recursion/groth16" + "github.com/vocdoni/arbo" + "github.com/vocdoni/vocdoni-z-sandbox/circuits" + "github.com/vocdoni/vocdoni-z-sandbox/circuits/aggregator" + "github.com/vocdoni/vocdoni-z-sandbox/circuits/statetransition" + aggregatortest "github.com/vocdoni/vocdoni-z-sandbox/circuits/test/aggregator" + ballottest "github.com/vocdoni/vocdoni-z-sandbox/circuits/test/ballotproof" + "github.com/vocdoni/vocdoni-z-sandbox/crypto/elgamal" + "github.com/vocdoni/vocdoni-z-sandbox/state" + "go.vocdoni.io/dvote/db/metadb" +) + +// StateTransitionTestResults struct includes relevant data after StateTransitionCircuit +// inputs generation +type StateTransitionTestResults struct { + ProcessId []byte + CensusRoot *big.Int + EncryptionPubKey [2]*big.Int + Nullifiers []*big.Int + Commitments []*big.Int + Addresses []*big.Int + EncryptedBallots [][ballottest.NFields][2][2]*big.Int + PlainEncryptedBallots []*big.Int +} + +// StateTransitionInputsForTest returns the StateTransitionTestResults, the placeholder +// and the assigments of a StateTransitionCircuit for the processId provided +// generating nValidVoters. If something fails it returns an error. +func StateTransitionInputsForTest(processId []byte, nValidVoters int) ( + *StateTransitionTestResults, *statetransition.Circuit, *statetransition.Circuit, error, +) { + // generate aggregator circuit and inputs + agInputs, agPlaceholder, agWitness, err := aggregatortest.AggregarorInputsForTest(processId, nValidVoters) + if err != nil { + return nil, nil, nil, err + } + // compile aggregoar circuit + agCCS, err := frontend.Compile(ecc.BLS12_377.ScalarField(), r1cs.NewBuilder, agPlaceholder) + if err != nil { + return nil, nil, nil, err + } + agPk, agVk, err := groth16.Setup(agCCS) + if err != nil { + return nil, nil, nil, err + } + // parse the witness to the circuit + fullWitness, err := frontend.NewWitness(agWitness, ecc.BLS12_377.ScalarField()) + if err != nil { + return nil, nil, nil, err + } + // generate the proof + proof, err := groth16.Prove(agCCS, agPk, fullWitness, stdgroth16.GetNativeProverOptions(ecc.BW6_761.ScalarField(), ecc.BLS12_377.ScalarField())) + if err != nil { + return nil, nil, nil, fmt.Errorf("err proving proof: %w", err) + } + // convert the proof to the circuit proof type + proofInBLS12377, err := stdgroth16.ValueOfProof[sw_bw6761.G1Affine, sw_bw6761.G2Affine](proof) + if err != nil { + return nil, nil, nil, err + } + // convert the public inputs to the circuit public inputs type + publicWitness, err := fullWitness.Public() + if err != nil { + return nil, nil, nil, err + } + err = groth16.Verify(proof, agVk, publicWitness, stdgroth16.GetNativeVerifierOptions(ecc.BW6_761.ScalarField(), ecc.BLS12_377.ScalarField())) + if err != nil { + return nil, nil, nil, err + } + agPublicInputs, err := stdgroth16.ValueOfWitness[sw_bw6761.ScalarField](publicWitness) + if err != nil { + return nil, nil, nil, err + } + + // pad voters inputs (nullifiers, commitments, addresses, plain EncryptedBallots) + nullifiers := circuits.BigIntArrayToN(agInputs.Nullifiers, aggregator.MaxVotes) + commitments := circuits.BigIntArrayToN(agInputs.Commitments, aggregator.MaxVotes) + addresses := circuits.BigIntArrayToN(agInputs.Addresses, aggregator.MaxVotes) + plainEncryptedBallots := circuits.BigIntArrayToN(agInputs.PlainEncryptedBallots, aggregator.MaxVotes*ballottest.NFields*4) + + // init final assigments stuff + s := newState( + processId, + agInputs.CensusRoot.Bytes(), + ballotMode().Bytes(), + pubkeyToBytes(agInputs.EncryptionPubKey)) + + if err := s.StartBatch(); err != nil { + return nil, nil, nil, err + } + for i := range agInputs.EncryptedBallots { + if err := s.AddVote(&state.Vote{ + Nullifier: arbo.BigIntToBytes(32, agInputs.Nullifiers[i]), + Ballot: toBallot(agInputs.EncryptedBallots[i]), + Address: arbo.BigIntToBytes(32, agInputs.Addresses[i]), + Commitment: agInputs.Commitments[i], + }); err != nil { + return nil, nil, nil, err + } + } + witness, err := GenerateWitnesses(s) + if err != nil { + return nil, nil, nil, err + } + if err := s.EndBatch(); err != nil { + return nil, nil, nil, err + } + + witness.AggregatedProof = proofInBLS12377 + witness.AggregatedProofWitness = agPublicInputs + + // create final placeholder + circuitPlaceholder := &statetransition.Circuit{ + // AggregatedProofWitness: stdgroth16.Witness[sw_bw6761.ScalarField]{}, + // AggregatedProof: stdgroth16.Proof[sw_bw6761.G1Affine, sw_bw6761.G2Affine]{}, + AggregatedProofVK: stdgroth16.VerifyingKey[sw_bw6761.G1Affine, sw_bw6761.G2Affine, sw_bw6761.GTEl]{}, + } + // fix the vote verifier verification key + fixedVk, err := stdgroth16.ValueOfVerifyingKeyFixed[sw_bw6761.G1Affine, sw_bw6761.G2Affine, sw_bw6761.GTEl](agVk) + if err != nil { + return nil, nil, nil, err + } + circuitPlaceholder.AggregatedProofVK = fixedVk + // // fill placeholder and witness with dummy circuits + // if err := aggregator.FillWithDummyFixed(finalPlaceholder, finalAssigments, agCCS, nValidVoters); err != nil { + // return nil, nil, nil, err + // } + return &StateTransitionTestResults{ + ProcessId: agInputs.ProcessId, + CensusRoot: agInputs.CensusRoot, + EncryptionPubKey: agInputs.EncryptionPubKey, + Nullifiers: nullifiers, + Commitments: commitments, + Addresses: addresses, + EncryptedBallots: agInputs.EncryptedBallots, + PlainEncryptedBallots: plainEncryptedBallots, + }, circuitPlaceholder, witness, nil +} + +func newState(processId, censusRoot, ballotMode, encryptionKey []byte) *state.State { + dir, err := os.MkdirTemp(os.TempDir(), "statetransition") + if err != nil { + panic(err) + } + db, err := metadb.New("pebble", dir) + if err != nil { + panic(err) + } + s, err := state.New(db, processId) + if err != nil { + panic(err) + } + + if err := s.Initialize( + censusRoot, + ballotMode, + encryptionKey, + ); err != nil { + panic(err) + } + + return s +} + +func toBallot(x [8][2][2]*big.Int) *elgamal.Ciphertexts { + z := elgamal.NewCiphertexts(state.Curve) + for i := range x { + z[i].C1.SetPoint(x[i][0][0], x[i][0][1]) + z[i].C2.SetPoint(x[i][1][0], x[i][1][1]) + } + return z +} + +func pubkeyToBytes(pubkey [2]*big.Int) []byte { + buf := bytes.Buffer{} + buf.Write(arbo.BigIntToBytes(32, pubkey[0])) + buf.Write(arbo.BigIntToBytes(32, pubkey[1])) + return buf.Bytes() +} diff --git a/circuits/test/statetransition/statetransition_test.go b/circuits/test/statetransition/statetransition_test.go new file mode 100644 index 0000000..b5090d1 --- /dev/null +++ b/circuits/test/statetransition/statetransition_test.go @@ -0,0 +1,34 @@ +package statetransitiontest + +import ( + "os" + "testing" + "time" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + stdgroth16 "github.com/consensys/gnark/std/recursion/groth16" + "github.com/consensys/gnark/test" + qt "github.com/frankban/quicktest" + "github.com/vocdoni/vocdoni-z-sandbox/util" +) + +func TestStateTransitionCircuit(t *testing.T) { + if os.Getenv("RUN_CIRCUIT_TESTS") == "" || os.Getenv("RUN_CIRCUIT_TESTS") == "false" { + t.Skip("skipping circuit tests...") + } + c := qt.New(t) + // inputs generation + now := time.Now() + processId := util.RandomBytes(20) + _, placeholder, witness, err := StateTransitionInputsForTest(processId, 3) + c.Assert(err, qt.IsNil) + c.Logf("inputs generation took %s", time.Since(now).String()) + // proving + now = time.Now() + assert := test.NewAssert(t) + assert.SolvingSucceeded(placeholder, witness, + test.WithCurves(ecc.BW6_761), test.WithBackends(backend.GROTH16), + test.WithProverOpts(stdgroth16.GetNativeProverOptions(ecc.BN254.ScalarField(), ecc.BW6_761.ScalarField()))) + c.Logf("proving took %s", time.Since(now).String()) +} diff --git a/circuits/test/statetransition/statetransition_witness.go b/circuits/test/statetransition/statetransition_witness.go new file mode 100644 index 0000000..3c54845 --- /dev/null +++ b/circuits/test/statetransition/statetransition_witness.go @@ -0,0 +1,111 @@ +package statetransitiontest + +import ( + "fmt" + "math" + + "github.com/consensys/gnark/frontend" + "github.com/vocdoni/vocdoni-z-sandbox/circuits" + ballottest "github.com/vocdoni/vocdoni-z-sandbox/circuits/test/ballotproof" + + "github.com/consensys/gnark/std/algebra/emulated/sw_bw6761" + "github.com/consensys/gnark/std/recursion/groth16" + "github.com/vocdoni/arbo" + "github.com/vocdoni/vocdoni-z-sandbox/circuits/statetransition" + "github.com/vocdoni/vocdoni-z-sandbox/state" +) + +func ballotMode() circuits.BallotMode[frontend.Variable] { + return circuits.BallotMode[frontend.Variable]{ + MaxCount: ballottest.MaxCount, + ForceUniqueness: ballottest.ForceUniqueness, + MaxValue: ballottest.MaxValue, + MinValue: ballottest.MinValue, + MaxTotalCost: int(math.Pow(float64(ballottest.MaxValue), float64(ballottest.CostExp))) * ballottest.MaxCount, + MinTotalCost: ballottest.MaxCount, + CostExp: ballottest.CostExp, + CostFromWeight: ballottest.CostFromWeight, + } +} + +func GenerateWitnesses(o *state.State) (*statetransition.Circuit, error) { + var err error + witness := &statetransition.Circuit{} + + // TODO: mock, replace by actual AggregatedProof + witness.AggregatedProof = groth16.Proof[sw_bw6761.G1Affine, sw_bw6761.G2Affine]{} + + // RootHashBefore + witness.RootHashBefore, err = o.RootAsBigInt() + if err != nil { + return nil, err + } + + // first get MerkleProofs, since they need to belong to RootHashBefore, i.e. before MerkleTransitions + if witness.ProcessID, err = o.GenMerkleProof(state.KeyProcessID); err != nil { + return nil, err + } + if witness.CensusRoot, err = o.GenMerkleProof(state.KeyCensusRoot); err != nil { + return nil, err + } + if witness.BallotMode, err = o.GenMerkleProof(state.KeyBallotMode); err != nil { + return nil, err + } + if witness.EncryptionKey, err = o.GenMerkleProof(state.KeyEncryptionKey); err != nil { + return nil, err + } + + // now build ordered chain of MerkleTransitions + + // add Ballots + for i := range witness.Ballot { + if i < len(o.Votes()) { + witness.Ballot[i], err = o.MerkleTransitionFromAddOrUpdate( + o.Votes()[i].Nullifier, o.Votes()[i].Ballot.Serialize()) + } else { + witness.Ballot[i], err = o.MerkleTransitionFromNoop() + } + if err != nil { + return nil, err + } + } + + // add Commitments + for i := range witness.Commitment { + if i < len(o.Votes()) { + witness.Commitment[i], err = o.MerkleTransitionFromAddOrUpdate( + o.Votes()[i].Address, arbo.BigIntToBytes(32, o.Votes()[i].Commitment)) + } else { + witness.Commitment[i], err = o.MerkleTransitionFromNoop() + } + if err != nil { + return nil, err + } + } + + // update ResultsAdd + witness.ResultsAdd, err = o.MerkleTransitionFromAddOrUpdate( + state.KeyResultsAdd, o.ResultsAdd.Add(o.ResultsAdd, o.BallotSum).Serialize()) + if err != nil { + return nil, fmt.Errorf("ResultsAdd: %w", err) + } + + // update ResultsSub + witness.ResultsSub, err = o.MerkleTransitionFromAddOrUpdate( + state.KeyResultsSub, o.ResultsSub.Add(o.ResultsSub, o.OverwriteSum).Serialize()) + if err != nil { + return nil, fmt.Errorf("ResultsSub: %w", err) + } + + // update stats + witness.NumNewVotes = o.BallotCount() + witness.NumOverwrites = o.OverwriteCount() + + // RootHashAfter + witness.RootHashAfter, err = o.RootAsBigInt() + if err != nil { + return nil, err + } + + return witness, nil +} diff --git a/circuits/types.go b/circuits/types.go index f2103f2..5157cd8 100644 --- a/circuits/types.go +++ b/circuits/types.go @@ -14,3 +14,7 @@ type BallotMode[T any] struct { CostFromWeight T EncryptionPubKey [2]T } + +func (bm BallotMode[T]) Bytes() []byte { + return []byte{0x00} +} diff --git a/state/state.go b/state/state.go index eacbd01..3d85f3d 100644 --- a/state/state.go +++ b/state/state.go @@ -7,6 +7,7 @@ import ( "github.com/vocdoni/arbo" "github.com/vocdoni/vocdoni-z-sandbox/crypto/ecc/curves" "github.com/vocdoni/vocdoni-z-sandbox/crypto/elgamal" + "github.com/vocdoni/vocdoni-z-sandbox/util" "go.vocdoni.io/dvote/db" "go.vocdoni.io/dvote/db/prefixeddb" ) @@ -163,3 +164,73 @@ func (o *State) OverwriteCount() int { func (o *State) Votes() []*Vote { return o.votes } + +func (o *State) ProcessID() []byte { + _, v, err := o.tree.Get(KeyProcessID) + if err != nil { + panic(err) + } + return v +} + +func (o *State) CensusRoot() []byte { + _, v, err := o.tree.Get(KeyCensusRoot) + if err != nil { + panic(err) + } + return v +} + +func (o *State) BallotMode() []byte { + _, v, err := o.tree.Get(KeyBallotMode) + if err != nil { + panic(err) + } + return v +} + +func (o *State) EncryptionKey() []byte { + _, v, err := o.tree.Get(KeyEncryptionKey) + if err != nil { + panic(err) + } + return v +} + +func (o *State) AggregatorProofInput() ([]byte, error) { + // ProcessID := circuit.ProcessID.Value + // CensusRoot := circuit.CensusRoot.Value + // BallotMode := circuit.BallotMode.Value + // EncryptionKey := circuit.EncryptionKey.Value + // Nullifiers := circuit.Ballot[i].NewKey + // Ballots := circuit.Ballot[i].NewValue + // Addressess := circuit.Commitment[i].NewKey + // Commitments := circuit.Commitment[i].NewValue + + inputs := [][]byte{ + o.ProcessID(), + o.CensusRoot(), + o.BallotMode(), + o.EncryptionKey(), + } + // for _, v := range o.votes { + // inputs = append(inputs, v.Nullifier) + // } + // for _, v := range o.votes { + // inputs = append(inputs, v.Ballot.Serialize()) + // } + // for _, v := range o.votes { + // inputs = append(inputs, v.Address) + // } + // for _, v := range o.votes { + // inputs = append(inputs, v.Commitment.Bytes()) + // } + // hash the inputs + hash, err := HashFunc.Hash(inputs...) + if err != nil { + return nil, err + } + fmt.Println("hash:", inputs) + fmt.Println("hashed", len(inputs), "inputs, hash =", util.PrettyHex(hash)) + return hash, nil +}