Skip to content

Commit

Permalink
Bitwise Xor Struct
Browse files Browse the repository at this point in the history
  • Loading branch information
Gali-StarkWare committed Jan 30, 2025
1 parent c74c0e9 commit b8dd76c
Showing 1 changed file with 88 additions and 1 deletion.
89 changes: 88 additions & 1 deletion stwo_cairo_prover/crates/prover/src/cairo_air/preprocessed.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::simd::{u32x16, Simd};

use itertools::{chain, Itertools};
use prover_types::simd::LOG_N_LANES;
use stwo_prover::constraint_framework::preprocessed_columns::{IsFirst, PreProcessedColumnId};
Expand Down Expand Up @@ -90,14 +92,61 @@ impl Seq {

pub fn id(&self) -> PreProcessedColumnId {
PreProcessedColumnId {
id: format!("preprocessed_seq_{}", self.log_size).to_string(),
id: format!("seq_{}", self.log_size).to_string(),
}
}
}

/// A table of a,b,c, where a,b,c are integers and a ^ b = c.
///
/// # Attributes
///
/// - `n_bits`: The number of bits in each integer.
/// - `col_index`: The column index in the preprocessed table.
#[derive(Debug)]
pub struct BitwiseXor {
n_bits: u32,
col_index: usize,
}
impl BitwiseXor {
pub const fn new(n_bits: u32, col_index: usize) -> Self {
assert!(col_index < 3, "col_index must be in range 0..=2");
Self { n_bits, col_index }
}

pub fn id(&self) -> PreProcessedColumnId {
PreProcessedColumnId {
id: format!("bitwise_xor_{}_{}", self.n_bits, self.col_index),
}
}

pub const fn log_size(&self) -> u32 {
2 * self.n_bits
}

pub fn packed_at(&self, vec_row: usize) -> PackedM31 {
let lhs = || -> u32x16 {
(SIMD_ENUMERATION_0 + Simd::splat((vec_row * N_LANES) as u32)) >> self.n_bits
};
let rhs = || -> u32x16 {
(SIMD_ENUMERATION_0 + Simd::splat((vec_row * N_LANES) as u32))
& Simd::splat((1 << self.n_bits) - 1)
};
let simd = match self.col_index {
0 => lhs(),
1 => rhs(),
2 => lhs() ^ rhs(),
_ => unreachable!(),
};
unsafe { PackedM31::from_simd_unchecked(simd) }
}
}

#[cfg(test)]
mod tests {
use super::*;
const LOG_SIZE: u32 = 8;
use stwo_prover::core::backend::Column;

#[test]
fn test_columns_are_in_decending_order() {
Expand All @@ -107,4 +156,42 @@ mod tests {
.windows(2)
.all(|w| w[0].log_size() >= w[1].log_size()));
}

#[test]
fn test_gen_seq() {
let seq = Seq::new(LOG_SIZE).gen_column_simd();
for i in 0..(1 << LOG_SIZE) {
assert_eq!(seq.at(i), BaseField::from_u32_unchecked(i as u32));
}
}

#[test]
fn test_packed_at_seq() {
let seq = Seq::new(LOG_SIZE);
let expected_seq: [_; 1 << LOG_SIZE] = std::array::from_fn(|i| M31::from(i as u32));
let packed_seq = std::array::from_fn::<_, { (1 << LOG_SIZE) / N_LANES }, _>(|i| {
seq.packed_at(i).to_array()
})
.concat();
assert_eq!(packed_seq, expected_seq);
}

#[test]
fn test_packed_at_bitwise_xor() {
let bitwise_a = BitwiseXor::new(LOG_SIZE, 0);
let bitwise_b = BitwiseXor::new(LOG_SIZE, 1);
let bitwise_xor = BitwiseXor::new(LOG_SIZE, 2);
let index: usize = 1000;
let a = index / (1 << LOG_SIZE);
let b = index % (1 << LOG_SIZE);
let expected_xor = a ^ b;

let res_a = bitwise_a.packed_at(index / N_LANES).to_array()[index % N_LANES];
let res_b = bitwise_b.packed_at(index / N_LANES).to_array()[index % N_LANES];
let res_xor = bitwise_xor.packed_at(index / N_LANES).to_array()[index % N_LANES];

assert_eq!(res_a.0, a as u32);
assert_eq!(res_b.0, b as u32);
assert_eq!(res_xor.0, expected_xor as u32);
}
}

0 comments on commit b8dd76c

Please sign in to comment.