Skip to content

Commit

Permalink
Refactor AIR (#937)
Browse files Browse the repository at this point in the history
* refactor

* fix clippy

* rm old tests

* rm comments

* rm comments

---------

Co-authored-by: Diego K <[email protected]>
  • Loading branch information
ColoCarletti and diegokingston authored Nov 13, 2024
1 parent e650e0f commit d016a73
Show file tree
Hide file tree
Showing 20 changed files with 408 additions and 1,133 deletions.
2 changes: 1 addition & 1 deletion provers/stark/src/constraints/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::trace::LDETraceTable;
use crate::traits::AIR;
use crate::{frame::Frame, prover::evaluate_polynomial_on_lde_domain};
use itertools::Itertools;
#[cfg(all(debug_assertions, not(feature = "parallel")))]
#[cfg(not(feature = "parallel"))]
use lambdaworks_math::polynomial::Polynomial;
use lambdaworks_math::{fft::errors::FFTError, field::element::FieldElement, traits::AsBytes};
#[cfg(feature = "parallel")]
Expand Down
3 changes: 2 additions & 1 deletion provers/stark/src/constraints/transition.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::ops::Div;

use crate::domain::Domain;
use crate::frame::Frame;
use crate::prover::evaluate_polynomial_on_lde_domain;
Expand All @@ -6,7 +8,6 @@ use lambdaworks_math::field::element::FieldElement;
use lambdaworks_math::field::traits::{IsFFTField, IsField, IsSubFieldOf};
use lambdaworks_math::polynomial::Polynomial;
use num_integer::Integer;
use std::ops::Div;
/// TransitionConstraint represents the behaviour that a transition constraint
/// over the computation that wants to be proven must comply with.
pub trait TransitionConstraint<F, E>: Send + Sync
Expand Down
13 changes: 0 additions & 13 deletions provers/stark/src/context.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::collections::HashSet;

use super::proof::options::ProofOptions;

#[derive(Clone, Debug)]
Expand All @@ -13,22 +11,11 @@ pub struct AirContext {
/// offsets that are needed to compute EVERY transition constraint, even if some
/// constraints don't use all of the indexes in said offsets.
pub transition_offsets: Vec<usize>,
pub transition_exemptions: Vec<usize>,
pub num_transition_constraints: usize,
}

impl AirContext {
pub fn num_transition_constraints(&self) -> usize {
self.num_transition_constraints
}

/// Returns the number of non-trivial different
/// transition exemptions.
pub fn num_transition_exemptions(&self) -> usize {
self.transition_exemptions
.iter()
.filter(|&x| *x != 0)
.collect::<HashSet<_>>()
.len()
}
}
10 changes: 4 additions & 6 deletions provers/stark/src/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,10 @@ pub fn validate_trace<A: AIR>(

// --------- VALIDATE TRANSITION CONSTRAINTS -----------
let n_transition_constraints = air.context().num_transition_constraints();
let transition_exemptions = &air.context().transition_exemptions;

let exemption_steps: Vec<usize> = vec![lde_trace.num_rows(); n_transition_constraints]
.iter()
.zip(transition_exemptions)
.map(|(trace_steps, exemptions)| trace_steps - exemptions)
let exemption_steps: Vec<usize> = std::iter::repeat(lde_trace.num_steps())
.take(n_transition_constraints)
.zip(air.transition_constraints())
.map(|(trace_steps, constraint)| trace_steps - constraint.end_exemptions())
.collect();

// Iterate over trace and compute transitions
Expand Down
20 changes: 8 additions & 12 deletions provers/stark/src/examples/bit_flags.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,6 @@ impl TransitionConstraint<StarkField, StarkField> for ZeroFlagConstraint {
16
}

fn offset(&self) -> usize {
15
}

fn evaluate(
&self,
frame: &Frame<StarkField, StarkField>,
Expand Down Expand Up @@ -130,17 +126,12 @@ impl AIR for BitFlagsAIR {
let flag_constraint = Box::new(ZeroFlagConstraint::new());
let constraints: Vec<Box<dyn TransitionConstraint<Self::Field, Self::FieldExtension>>> =
vec![bit_constraint, flag_constraint];
// vec![flag_constraint];
// vec![bit_constraint];

let num_transition_constraints = constraints.len();
let transition_exemptions: Vec<_> =
constraints.iter().map(|c| c.end_exemptions()).collect();

let context = AirContext {
proof_options: proof_options.clone(),
trace_columns: 1,
transition_exemptions,
trace_columns: 2,
transition_offsets: vec![0],
num_transition_constraints,
};
Expand Down Expand Up @@ -195,7 +186,7 @@ impl AIR for BitFlagsAIR {
}
}

pub fn bit_prefix_flag_trace(num_steps: usize) -> TraceTable<StarkField> {
pub fn bit_prefix_flag_trace(num_steps: usize) -> TraceTable<StarkField, StarkField> {
debug_assert!(num_steps.is_power_of_two());
let step: Vec<Felt252> = [
1031u64, 515, 257, 128, 64, 32, 16, 8, 4, 2, 1, 0, 0, 0, 0, 0,
Expand All @@ -207,5 +198,10 @@ pub fn bit_prefix_flag_trace(num_steps: usize) -> TraceTable<StarkField> {
let mut data: Vec<Felt252> = iter::repeat(step).take(num_steps).flatten().collect();
data[0] = Felt252::from(1030);

TraceTable::new(data, 1, 0, 16)
let mut dummy_column = (0..16).map(Felt252::from).collect();
dummy_column = iter::repeat(dummy_column)
.take(num_steps)
.flatten()
.collect();
TraceTable::from_columns_main(vec![data, dummy_column], 16)
}
9 changes: 2 additions & 7 deletions provers/stark/src/examples/dummy_air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ impl AIR for DummyAIR {
let context = AirContext {
proof_options: proof_options.clone(),
trace_columns: 2,
transition_exemptions: vec![0, 2],
transition_offsets: vec![0, 1, 2],
num_transition_constraints: 2,
};
Expand Down Expand Up @@ -198,7 +197,7 @@ impl AIR for DummyAIR {
}
}

pub fn dummy_trace<F: IsFFTField>(trace_length: usize) -> TraceTable<F> {
pub fn dummy_trace<F: IsFFTField>(trace_length: usize) -> TraceTable<F, F> {
let mut ret: Vec<FieldElement<F>> = vec![];

let a0 = FieldElement::one();
Expand All @@ -211,9 +210,5 @@ pub fn dummy_trace<F: IsFFTField>(trace_length: usize) -> TraceTable<F> {
ret.push(ret[i - 1].clone() + ret[i - 2].clone());
}

TraceTable::from_columns(
vec![vec![FieldElement::<F>::one(); trace_length], ret],
2,
1,
)
TraceTable::from_columns_main(vec![vec![FieldElement::<F>::one(); trace_length], ret], 1)
}
46 changes: 4 additions & 42 deletions provers/stark/src/examples/fibonacci_2_cols_shifted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ where

let context = AirContext {
proof_options: proof_options.clone(),
transition_exemptions: vec![1, 1],
transition_offsets: vec![0, 1],
num_transition_constraints: 2,
trace_columns: 2,
Expand Down Expand Up @@ -238,7 +237,7 @@ where
pub fn compute_trace<F: IsFFTField>(
initial_value: FieldElement<F>,
trace_length: usize,
) -> TraceTable<F> {
) -> TraceTable<F, F> {
let mut x = FieldElement::one();
let mut y = initial_value;
let mut col0 = vec![x.clone()];
Expand All @@ -250,7 +249,7 @@ pub fn compute_trace<F: IsFFTField>(
col1.push(y.clone());
}

TraceTable::from_columns(vec![col0, col1], 2, 1)
TraceTable::from_columns_main(vec![col0, col1], 1)
}

#[cfg(test)]
Expand All @@ -264,46 +263,9 @@ mod tests {
#[test]
fn trace_has_expected_rows() {
let trace = compute_trace(FieldElement::<Stark252PrimeField>::one(), 8);
assert_eq!(trace.n_rows(), 8);
assert_eq!(trace.num_rows(), 8);

let trace = compute_trace(FieldElement::<Stark252PrimeField>::one(), 64);
assert_eq!(trace.n_rows(), 64);
}

#[test]
fn trace_of_8_rows_is_correctly_calculated() {
let trace = compute_trace(FieldElement::<Stark252PrimeField>::one(), 8);
assert_eq!(
trace.get_row(0),
vec![FieldElement::one(), FieldElement::one()]
);
assert_eq!(
trace.get_row(1),
vec![FieldElement::one(), FieldElement::from(2)]
);
assert_eq!(
trace.get_row(2),
vec![FieldElement::from(2), FieldElement::from(3)]
);
assert_eq!(
trace.get_row(3),
vec![FieldElement::from(3), FieldElement::from(5)]
);
assert_eq!(
trace.get_row(4),
vec![FieldElement::from(5), FieldElement::from(8)]
);
assert_eq!(
trace.get_row(5),
vec![FieldElement::from(8), FieldElement::from(13)]
);
assert_eq!(
trace.get_row(6),
vec![FieldElement::from(13), FieldElement::from(21)]
);
assert_eq!(
trace.get_row(7),
vec![FieldElement::from(21), FieldElement::from(34)]
);
assert_eq!(trace.num_rows(), 64);
}
}
5 changes: 2 additions & 3 deletions provers/stark/src/examples/fibonacci_2_columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ where

let context = AirContext {
proof_options: proof_options.clone(),
transition_exemptions: vec![1, 1],
transition_offsets: vec![0, 1],
num_transition_constraints: constraints.len(),
trace_columns: 2,
Expand Down Expand Up @@ -209,7 +208,7 @@ where
pub fn compute_trace<F: IsFFTField>(
initial_values: [FieldElement<F>; 2],
trace_length: usize,
) -> TraceTable<F> {
) -> TraceTable<F, F> {
let mut ret1: Vec<FieldElement<F>> = vec![];
let mut ret2: Vec<FieldElement<F>> = vec![];

Expand All @@ -222,5 +221,5 @@ pub fn compute_trace<F: IsFFTField>(
ret2.push(new_val + ret2[i - 1].clone());
}

TraceTable::from_columns(vec![ret1, ret2], 2, 1)
TraceTable::from_columns_main(vec![ret1, ret2], 1)
}
32 changes: 18 additions & 14 deletions provers/stark/src/examples/fibonacci_rap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,10 @@ where
Box::new(PermutationConstraint::new()),
];

let exemptions = 3 + trace_length - pub_inputs.steps - 1;

let context = AirContext {
proof_options: proof_options.clone(),
trace_columns: 3,
transition_offsets: vec![0, 1, 2],
transition_exemptions: vec![exemptions, 1],
num_transition_constraints: transition_constraints.len(),
};

Expand All @@ -186,15 +183,15 @@ where

fn build_auxiliary_trace(
&self,
main_trace: &TraceTable<Self::Field>,
trace: &mut TraceTable<Self::Field, Self::FieldExtension>,
challenges: &[FieldElement<F>],
) -> TraceTable<Self::Field> {
let main_segment_cols = main_trace.columns();
) {
let main_segment_cols = trace.columns_main();
let not_perm = &main_segment_cols[0];
let perm = &main_segment_cols[1];
let gamma = &challenges[0];

let trace_len = main_trace.n_rows();
let trace_len = trace.num_rows();

let mut aux_col = Vec::new();
for i in 0..trace_len {
Expand All @@ -208,7 +205,10 @@ where
aux_col.push(z_i * n_p_term.div(p_term));
}
}
TraceTable::from_columns(vec![aux_col], 0, 1)

for (i, aux_elem) in aux_col.iter().enumerate().take(trace.num_rows()) {
trace.set_aux(i, 0, aux_elem.clone())
}
}

fn build_rap_challenges(
Expand Down Expand Up @@ -236,7 +236,6 @@ where
let a0_aux = BoundaryConstraint::new_aux(0, 0, FieldElement::<Self::FieldExtension>::one());

BoundaryConstraints::from_constraints(vec![a0, a1, a0_aux])
// BoundaryConstraints::from_constraints(vec![a0, a1])
}

fn transition_constraints(
Expand Down Expand Up @@ -274,9 +273,8 @@ where
pub fn fibonacci_rap_trace<F: IsFFTField>(
initial_values: [FieldElement<F>; 2],
trace_length: usize,
) -> TraceTable<F> {
) -> TraceTable<F, F> {
let mut fib_seq: Vec<FieldElement<F>> = vec![];

fib_seq.push(initial_values[0].clone());
fib_seq.push(initial_values[1].clone());

Expand All @@ -294,7 +292,13 @@ pub fn fibonacci_rap_trace<F: IsFFTField>(
let mut trace_cols = vec![fib_seq, fib_permuted];
resize_to_next_power_of_two(&mut trace_cols);

TraceTable::from_columns(trace_cols, 2, 1)
let mut trace = TraceTable::allocate_with_zeros(trace_cols[0].len(), 2, 1, 1);
for i in 0..trace.num_rows() {
trace.set_main(i, 0, trace_cols[0][i].clone());
trace.set_main(i, 1, trace_cols[1][i].clone());
}

trace
}

#[cfg(test)]
Expand Down Expand Up @@ -337,13 +341,13 @@ mod test {
];
resize_to_next_power_of_two(&mut expected_trace);

assert_eq!(trace.columns(), expected_trace);
assert_eq!(trace.columns_main(), expected_trace);
}

#[test]
fn aux_col() {
let trace = fibonacci_rap_trace([FE17::from(1), FE17::from(1)], 64);
let trace_cols = trace.columns();
let trace_cols = trace.columns_main();

let not_perm = trace_cols[0].clone();
let perm = trace_cols[1].clone();
Expand Down
5 changes: 2 additions & 3 deletions provers/stark/src/examples/quadratic_air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ where
let context = AirContext {
proof_options: proof_options.clone(),
trace_columns: 1,
transition_exemptions: vec![1],
transition_offsets: vec![0, 1],
num_transition_constraints: constraints.len(),
};
Expand Down Expand Up @@ -161,7 +160,7 @@ where
pub fn quadratic_trace<F: IsFFTField>(
initial_value: FieldElement<F>,
trace_length: usize,
) -> TraceTable<F> {
) -> TraceTable<F, F> {
let mut ret: Vec<FieldElement<F>> = vec![];

ret.push(initial_value);
Expand All @@ -170,5 +169,5 @@ pub fn quadratic_trace<F: IsFFTField>(
ret.push(ret[i - 1].clone() * ret[i - 1].clone());
}

TraceTable::from_columns(vec![ret], 1, 1)
TraceTable::from_columns_main(vec![ret], 1)
}
5 changes: 2 additions & 3 deletions provers/stark/src/examples/simple_fibonacci.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ where
let context = AirContext {
proof_options: proof_options.clone(),
trace_columns: 1,
transition_exemptions: vec![2],
transition_offsets: vec![0, 1, 2],
num_transition_constraints: constraints.len(),
};
Expand Down Expand Up @@ -162,7 +161,7 @@ where
pub fn fibonacci_trace<F: IsFFTField>(
initial_values: [FieldElement<F>; 2],
trace_length: usize,
) -> TraceTable<F> {
) -> TraceTable<F, F> {
let mut ret: Vec<FieldElement<F>> = vec![];

ret.push(initial_values[0].clone());
Expand All @@ -172,5 +171,5 @@ pub fn fibonacci_trace<F: IsFFTField>(
ret.push(ret[i - 1].clone() + ret[i - 2].clone());
}

TraceTable::from_columns(vec![ret], 1, 1)
TraceTable::from_columns_main(vec![ret], 1)
}
Loading

0 comments on commit d016a73

Please sign in to comment.