Skip to content

Commit

Permalink
Fix poseidon api (#851)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonas-lj authored Jan 9, 2025
1 parent 6ce2022 commit 4e4dfe7
Showing 1 changed file with 25 additions and 28 deletions.
53 changes: 25 additions & 28 deletions fastcrypto-zkp/src/bn254/poseidon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@ use crate::FrRepr;
use ark_bn254::Fr;
use ark_ff::{BigInteger, PrimeField};
use byte_slice_cast::AsByteSlice;
use fastcrypto::error::FastCryptoError::{InputTooLong, InvalidInput};
use fastcrypto::error::FastCryptoError::InvalidInput;
use fastcrypto::error::{FastCryptoError, FastCryptoResult};
use ff::PrimeField as OtherPrimeField;
use neptune::poseidon::HashMode::OptimizedStatic;
use neptune::Poseidon;
use std::cmp::Ordering;

/// The output of the Poseidon hash function is a field element in BN254 which is 254 bits long, so
/// we need 32 bytes to represent it as an integer.
Expand Down Expand Up @@ -90,12 +89,12 @@ pub fn poseidon_merkle_tree(inputs: &[FieldElement]) -> FastCryptoResult<FieldEl
}

/// Calculate the poseidon hash of an array of inputs. Each input is interpreted as a BN254 field
/// element assuming a little-endian encoding. The field elements are then hashed using the poseidon
/// hash function ([poseidon_merkle_tree]) and the result is serialized as a little-endian integer (32
/// bytes).
/// element assuming a little-endian encoding and must be 32 bytes long.
/// The field elements are then hashed using the poseidon hash function ([poseidon_merkle_tree])
/// and the result is serialized as a little-endian integer (32 bytes).
///
/// If one of the inputs is in non-canonical form, e.g. it represents an integer greater than the
/// field size or is longer than 32 bytes, an error is returned.
/// field size or is not exactly 32 bytes, an [InvalidInput] error is returned.
///
/// This function is used as an interface to the poseidon hash function in the sui-framework.
pub fn poseidon_bytes(inputs: &[Vec<u8>]) -> FastCryptoResult<[u8; FIELD_ELEMENT_SIZE_IN_BYTES]> {
Expand All @@ -111,25 +110,21 @@ pub fn poseidon_bytes(inputs: &[Vec<u8>]) -> FastCryptoResult<[u8; FIELD_ELEMENT

/// Given a binary representation of a BN254 field element as an integer in little-endian encoding,
/// this function returns the corresponding field element. If the field element is not canonical (is
/// larger than the field size as an integer), an `FastCryptoError::InvalidInput` is returned.
/// larger than the field size as an integer), an [InvalidInput] is returned.
///
/// If more than 32 bytes is given, an `FastCryptoError::InputTooLong` is returned.
/// If the input is not exactly 32 bytes long, an [InvalidInput] is returned.
fn canonical_le_bytes_to_field_element(bytes: &[u8]) -> FastCryptoResult<FieldElement> {
match bytes.len().cmp(&FIELD_ELEMENT_SIZE_IN_BYTES) {
Ordering::Less => Ok(Fr::from_le_bytes_mod_order(bytes)),
Ordering::Equal => {
let field_element = Fr::from_le_bytes_mod_order(bytes);
// Unfortunately, there doesn't seem to be a nice way to check if a modular reduction
// happened without doing the extra work of serializing the field element again.
let reduced_bytes = field_element.into_bigint().to_bytes_le();
if reduced_bytes != bytes {
return Err(InvalidInput);
}
Ok(field_element)
}
Ordering::Greater => Err(InputTooLong(FIELD_ELEMENT_SIZE_IN_BYTES)),
if bytes.len() != FIELD_ELEMENT_SIZE_IN_BYTES {
return Err(InvalidInput);
}
let field_element = Fr::from_le_bytes_mod_order(bytes);
// Unfortunately, there doesn't seem to be a nice way to check if a modular reduction
// happened without doing the extra work of serializing the field element again.
let reduced_bytes = field_element.into_bigint().to_bytes_le();
if reduced_bytes != bytes {
return Err(InvalidInput);
}
.map(FieldElement)
Ok(FieldElement(field_element))
}

/// Convert a BN254 field element to a byte array as the little-endian representation of the
Expand Down Expand Up @@ -247,16 +242,22 @@ mod test {

#[test]
fn test_hash_to_bytes() {
let inputs: Vec<Vec<u8>> = vec![vec![1u8]];
let mut one = vec![0; 32];
one[0] = 1;

let inputs: Vec<Vec<u8>> = vec![one.clone()];
let hash = poseidon_bytes(&inputs).unwrap();
// 18586133768512220936620570745912940619677854269274689475585506675881198879027 in decimal
let expected =
hex::decode("33018202c57d898b84338b16d1a4960e133c6a4d656cfec1bd62a9ea00611729")
.unwrap();
assert_eq!(hash.as_slice(), &expected);

let mut two = vec![0; 32];
two[0] = 2;

// 7853200120776062878684798364095072458815029376092732009249414926327459813530 in decimal
let inputs: Vec<Vec<u8>> = vec![vec![1u8], vec![2u8]];
let inputs: Vec<Vec<u8>> = vec![one.clone(), two.clone()];
let hash = poseidon_bytes(&inputs).unwrap();
let expected =
hex::decode("9a1817447a60199e51453274f217362acfe962966b4cf63d4190d6e7f5c05c11")
Expand All @@ -266,10 +267,6 @@ mod test {
// Input larger than the modulus
let inputs = vec![vec![255; 32]];
assert!(poseidon_bytes(&inputs).is_err());

// Input smaller than the modulus
let inputs = vec![vec![255; 31]];
assert!(poseidon_bytes(&inputs).is_ok());
}

#[cfg(test)]
Expand Down

0 comments on commit 4e4dfe7

Please sign in to comment.