diff --git a/circuits/statetransition/circuit_test.go b/circuits/statetransition/circuit_test.go index 8b8e70e..c89c726 100644 --- a/circuits/statetransition/circuit_test.go +++ b/circuits/statetransition/circuit_test.go @@ -17,7 +17,6 @@ import ( "github.com/consensys/gnark/test" "github.com/rs/zerolog" "github.com/vocdoni/vocdoni-z-sandbox/circuits/statetransition" - "github.com/vocdoni/vocdoni-z-sandbox/crypto/ecc/curves" "github.com/vocdoni/vocdoni-z-sandbox/crypto/elgamal" "github.com/vocdoni/vocdoni-z-sandbox/state" @@ -254,12 +253,12 @@ func newMockVote(nullifier, amount uint64) state.Vote { big.NewInt(int64(nullifier)+int64(state.KeyNullifiersOffset))) // mock // generate a public mocked key - publicKey, _, err := elgamal.GenerateKey(curves.New(state.CurveType)) + publicKey, _, err := elgamal.GenerateKey(state.Curve) if err != nil { panic(fmt.Errorf("error generating public key: %v", err)) } - c, err := elgamal.NewCiphertext(state.CurveType).Encrypt(big.NewInt(int64(amount)), publicKey, nil) + c, err := elgamal.NewCiphertext(publicKey).Encrypt(big.NewInt(int64(amount)), publicKey, nil) if err != nil { panic(fmt.Errorf("error encrypting: %v", err)) } diff --git a/crypto/elgamal/ciphertext.go b/crypto/elgamal/ciphertext.go index 921685c..b1e4a86 100644 --- a/crypto/elgamal/ciphertext.go +++ b/crypto/elgamal/ciphertext.go @@ -10,26 +10,21 @@ import ( "github.com/vocdoni/arbo" gelgamal "github.com/vocdoni/gnark-crypto-primitives/elgamal" "github.com/vocdoni/vocdoni-z-sandbox/crypto/ecc" - "github.com/vocdoni/vocdoni-z-sandbox/crypto/ecc/curves" "github.com/vocdoni/vocdoni-z-sandbox/crypto/ecc/format" ) // Ciphertext represents an ElGamal encrypted message with homomorphic properties. // It is a wrapper for convenience of the elGamal ciphersystem that encapsulates the two points of a ciphertext. type Ciphertext struct { - CurveType string `json:"curveType"` - C1 ecc.Point `json:"c1"` - C2 ecc.Point `json:"c2"` + C1 ecc.Point `json:"c1"` + C2 ecc.Point `json:"c2"` } -// NewCiphertext creates a new Ciphertext with the given curve type. -// The curve type must be one of the supported curves by crypto/ecc/curves package. -func NewCiphertext(curveType string) *Ciphertext { - return &Ciphertext{ - C1: curves.New(curveType).New(), - C2: curves.New(curveType).New(), - CurveType: curveType, - } +// NewCiphertext creates a new Ciphertext on the same curve as the given Point. +// The Point must be one on of the supported curves by crypto/ecc/curves package, +// can be easily created with curves.New(type) +func NewCiphertext(curve ecc.Point) *Ciphertext { + return &Ciphertext{C1: curve.New(), C2: curve.New()} } // Encrypt encrypts a message using the public key provided as elliptic curve point. diff --git a/crypto/elgamal/ciphertext_test.go b/crypto/elgamal/ciphertext_test.go index 572c2f7..cbffdd6 100644 --- a/crypto/elgamal/ciphertext_test.go +++ b/crypto/elgamal/ciphertext_test.go @@ -11,7 +11,7 @@ import ( func TestNewCiphertext(t *testing.T) { c := qt.New(t) - cipher := NewCiphertext(curves.CurveTypeBN254) + cipher := NewCiphertext(curves.New(curves.CurveTypeBN254)) c.Assert(cipher, qt.Not(qt.IsNil)) c.Assert(cipher.C1, qt.Not(qt.IsNil)) c.Assert(cipher.C2, qt.Not(qt.IsNil)) @@ -29,7 +29,7 @@ func TestCiphertext_Encrypt(t *testing.T) { msg := big.NewInt(42) // Test with nil k (random k generation) - cipher := NewCiphertext(curves.CurveTypeBN254) + cipher := NewCiphertext(publicKey) encrypted, err := cipher.Encrypt(msg, publicKey, nil) c.Assert(err, qt.IsNil) c.Assert(encrypted, qt.Not(qt.IsNil)) @@ -57,16 +57,16 @@ func TestCiphertext_Add(t *testing.T) { k1 := big.NewInt(789) k2 := big.NewInt(987) - cipher1 := NewCiphertext(curves.CurveTypeBN254) + cipher1 := NewCiphertext(publicKey) encrypted1, err := cipher1.Encrypt(msg1, publicKey, k1) c.Assert(err, qt.IsNil) - cipher2 := NewCiphertext(curves.CurveTypeBN254) + cipher2 := NewCiphertext(publicKey) encrypted2, err := cipher2.Encrypt(msg2, publicKey, k2) c.Assert(err, qt.IsNil) // Test addition - result := NewCiphertext(curves.CurveTypeBN254) + result := NewCiphertext(publicKey) // Initialize result points with the first ciphertext's values result.C1 = encrypted1.C1 result.C2 = encrypted1.C2 @@ -88,7 +88,7 @@ func TestCiphertext_SerializeDeserialize(t *testing.T) { msg := big.NewInt(42) k := big.NewInt(789) - cipher := NewCiphertext(curves.CurveTypeBN254) + cipher := NewCiphertext(publicKey) encrypted, err := cipher.Encrypt(msg, publicKey, k) c.Assert(err, qt.IsNil) @@ -98,7 +98,7 @@ func TestCiphertext_SerializeDeserialize(t *testing.T) { c.Assert(len(serialized), qt.Equals, 128) // 4 * 32 bytes // Test deserialization - deserialized := NewCiphertext(curves.CurveTypeBN254) + deserialized := NewCiphertext(publicKey) deserialized.Deserialize(serialized) // Compare points @@ -124,7 +124,7 @@ func TestCiphertext_MarshalUnmarshal(t *testing.T) { msg := big.NewInt(42) k := big.NewInt(789) - cipher := NewCiphertext(curves.CurveTypeBN254) + cipher := NewCiphertext(publicKey) encrypted, err := cipher.Encrypt(msg, publicKey, k) c.Assert(err, qt.IsNil) @@ -134,7 +134,7 @@ func TestCiphertext_MarshalUnmarshal(t *testing.T) { c.Assert(marshaled, qt.Not(qt.IsNil)) // Test unmarshaling - unmarshaled := NewCiphertext(curves.CurveTypeBN254) + unmarshaled := NewCiphertext(publicKey) err = unmarshaled.Unmarshal(marshaled) c.Assert(err, qt.IsNil) @@ -161,7 +161,7 @@ func TestCiphertext_String(t *testing.T) { msg := big.NewInt(42) k := big.NewInt(789) - cipher := NewCiphertext(curves.CurveTypeBN254) + cipher := NewCiphertext(publicKey) encrypted, err := cipher.Encrypt(msg, publicKey, k) c.Assert(err, qt.IsNil) @@ -174,7 +174,7 @@ func TestCiphertext_String(t *testing.T) { func TestCiphertext_DeserializePanic(t *testing.T) { c := qt.New(t) - cipher := NewCiphertext(curves.CurveTypeBN254) + cipher := NewCiphertext(curves.New(curves.CurveTypeBN254)) // Test with invalid length, should panic c.Assert(func() { diff --git a/state/merkleproof.go b/state/merkleproof.go index d567fb8..481c71e 100644 --- a/state/merkleproof.go +++ b/state/merkleproof.go @@ -203,7 +203,7 @@ func (o *State) MerkleTransitionFromAddOrUpdate(k []byte, v []byte) (MerkleTrans } mp := MerkleTransitionFromArboProofPair(mpBefore, mpAfter) - oldCiphertext, newCiphertext := elgamal.NewCiphertext(CurveType), elgamal.NewCiphertext(CurveType) + oldCiphertext, newCiphertext := elgamal.NewCiphertext(Curve), elgamal.NewCiphertext(Curve) if len(mpBefore.Value) > 32 { oldCiphertext.Deserialize(mpBefore.Value) mp.IsOldElGamal = 1 diff --git a/state/state.go b/state/state.go index fd058a5..06dd66d 100644 --- a/state/state.go +++ b/state/state.go @@ -18,12 +18,14 @@ const ( MaxKeyLen = (MaxLevels + 7) / 8 // votes that were processed in AggregatedProof VoteBatchSize = 10 - // CurveType is the curve type used for the encryption - CurveType = curves.CurveTypeBabyJubJubGnark ) -// hashFunc is the hash function used in the state tree. -var hashFunc = arbo.HashMiMC_BN254{} +var ( + // HashFunc is the hash function used in the state tree. + HashFunc = arbo.HashFunctionMiMC_BN254 + // Curve is the curve used for the encryption + Curve = curves.New(curves.CurveTypeBabyJubJubGnark) +) var ( KeyProcessID = []byte{0x00} @@ -60,7 +62,7 @@ func New(db db.Database, processId []byte) (*State, error) { pdb := prefixeddb.NewPrefixedDatabase(db, processId) tree, err := arbo.NewTree(arbo.Config{ Database: pdb, MaxLevels: MaxLevels, - HashFunction: hashFunc, + HashFunction: HashFunc, }) if err != nil { return nil, err @@ -89,10 +91,10 @@ func (o *State) Initialize(censusRoot, ballotMode, encryptionKey []byte) error { if err := o.tree.Add(KeyEncryptionKey, encryptionKey); err != nil { return err } - if err := o.tree.Add(KeyResultsAdd, encrypt.NewCiphertext(CurveType).Serialize()); err != nil { + if err := o.tree.Add(KeyResultsAdd, encrypt.NewCiphertext(Curve).Serialize()); err != nil { return err } - if err := o.tree.Add(KeyResultsSub, encrypt.NewCiphertext(CurveType).Serialize()); err != nil { + if err := o.tree.Add(KeyResultsSub, encrypt.NewCiphertext(Curve).Serialize()); err != nil { return err } return nil @@ -108,10 +110,10 @@ func (o *State) Close() error { func (o *State) StartBatch() error { o.dbTx = o.db.WriteTx() if o.ResultsAdd == nil { - o.ResultsAdd = elgamal.NewCiphertext(CurveType) + o.ResultsAdd = elgamal.NewCiphertext(Curve) } if o.ResultsSub == nil { - o.ResultsSub = elgamal.NewCiphertext(CurveType) + o.ResultsSub = elgamal.NewCiphertext(Curve) } { @@ -129,8 +131,8 @@ func (o *State) StartBatch() error { o.ResultsSub.Deserialize(v) } - o.BallotSum = elgamal.NewCiphertext(CurveType) - o.OverwriteSum = elgamal.NewCiphertext(CurveType) + o.BallotSum = elgamal.NewCiphertext(Curve) + o.OverwriteSum = elgamal.NewCiphertext(Curve) o.ballotCount = 0 o.overwriteCount = 0 o.votes = []Vote{} diff --git a/state/vote.go b/state/vote.go index da459e6..74bc130 100644 --- a/state/vote.go +++ b/state/vote.go @@ -27,7 +27,7 @@ func (o *State) AddVote(v Vote) error { // if nullifier exists, it's a vote overwrite, need to count the overwritten vote // so it's later added to circuit.ResultsSub if _, value, err := o.tree.Get(v.Nullifier); err == nil { - oldVote := elgamal.NewCiphertext(CurveType) + oldVote := elgamal.NewCiphertext(Curve) oldVote.Deserialize(value) o.OverwriteSum.Add(o.OverwriteSum, oldVote) o.overwriteCount++