From b8dd76cc3bf8995d33007d2abfcb9aef3a489509 Mon Sep 17 00:00:00 2001 From: Gali Michlevich Date: Tue, 21 Jan 2025 11:42:59 +0200 Subject: [PATCH] Bitwise Xor Struct --- .../prover/src/cairo_air/preprocessed.rs | 89 ++++++++++++++++++- 1 file changed, 88 insertions(+), 1 deletion(-) diff --git a/stwo_cairo_prover/crates/prover/src/cairo_air/preprocessed.rs b/stwo_cairo_prover/crates/prover/src/cairo_air/preprocessed.rs index 049c8dea..d21d0a00 100644 --- a/stwo_cairo_prover/crates/prover/src/cairo_air/preprocessed.rs +++ b/stwo_cairo_prover/crates/prover/src/cairo_air/preprocessed.rs @@ -1,3 +1,5 @@ +use std::simd::{u32x16, Simd}; + use itertools::{chain, Itertools}; use prover_types::simd::LOG_N_LANES; use stwo_prover::constraint_framework::preprocessed_columns::{IsFirst, PreProcessedColumnId}; @@ -90,14 +92,61 @@ impl Seq { pub fn id(&self) -> PreProcessedColumnId { PreProcessedColumnId { - id: format!("preprocessed_seq_{}", self.log_size).to_string(), + id: format!("seq_{}", self.log_size).to_string(), } } } +/// A table of a,b,c, where a,b,c are integers and a ^ b = c. +/// +/// # Attributes +/// +/// - `n_bits`: The number of bits in each integer. +/// - `col_index`: The column index in the preprocessed table. +#[derive(Debug)] +pub struct BitwiseXor { + n_bits: u32, + col_index: usize, +} +impl BitwiseXor { + pub const fn new(n_bits: u32, col_index: usize) -> Self { + assert!(col_index < 3, "col_index must be in range 0..=2"); + Self { n_bits, col_index } + } + + pub fn id(&self) -> PreProcessedColumnId { + PreProcessedColumnId { + id: format!("bitwise_xor_{}_{}", self.n_bits, self.col_index), + } + } + + pub const fn log_size(&self) -> u32 { + 2 * self.n_bits + } + + pub fn packed_at(&self, vec_row: usize) -> PackedM31 { + let lhs = || -> u32x16 { + (SIMD_ENUMERATION_0 + Simd::splat((vec_row * N_LANES) as u32)) >> self.n_bits + }; + let rhs = || -> u32x16 { + (SIMD_ENUMERATION_0 + Simd::splat((vec_row * N_LANES) as u32)) + & Simd::splat((1 << self.n_bits) - 1) + }; + let simd = match self.col_index { + 0 => lhs(), + 1 => rhs(), + 2 => lhs() ^ rhs(), + _ => unreachable!(), + }; + unsafe { PackedM31::from_simd_unchecked(simd) } + } +} + #[cfg(test)] mod tests { use super::*; + const LOG_SIZE: u32 = 8; + use stwo_prover::core::backend::Column; #[test] fn test_columns_are_in_decending_order() { @@ -107,4 +156,42 @@ mod tests { .windows(2) .all(|w| w[0].log_size() >= w[1].log_size())); } + + #[test] + fn test_gen_seq() { + let seq = Seq::new(LOG_SIZE).gen_column_simd(); + for i in 0..(1 << LOG_SIZE) { + assert_eq!(seq.at(i), BaseField::from_u32_unchecked(i as u32)); + } + } + + #[test] + fn test_packed_at_seq() { + let seq = Seq::new(LOG_SIZE); + let expected_seq: [_; 1 << LOG_SIZE] = std::array::from_fn(|i| M31::from(i as u32)); + let packed_seq = std::array::from_fn::<_, { (1 << LOG_SIZE) / N_LANES }, _>(|i| { + seq.packed_at(i).to_array() + }) + .concat(); + assert_eq!(packed_seq, expected_seq); + } + + #[test] + fn test_packed_at_bitwise_xor() { + let bitwise_a = BitwiseXor::new(LOG_SIZE, 0); + let bitwise_b = BitwiseXor::new(LOG_SIZE, 1); + let bitwise_xor = BitwiseXor::new(LOG_SIZE, 2); + let index: usize = 1000; + let a = index / (1 << LOG_SIZE); + let b = index % (1 << LOG_SIZE); + let expected_xor = a ^ b; + + let res_a = bitwise_a.packed_at(index / N_LANES).to_array()[index % N_LANES]; + let res_b = bitwise_b.packed_at(index / N_LANES).to_array()[index % N_LANES]; + let res_xor = bitwise_xor.packed_at(index / N_LANES).to_array()[index % N_LANES]; + + assert_eq!(res_a.0, a as u32); + assert_eq!(res_b.0, b as u32); + assert_eq!(res_xor.0, expected_xor as u32); + } }