From b3fcb47cdcd49c55a219b7e7f15e9fdf57102568 Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Fri, 23 Feb 2024 18:06:11 +0100 Subject: [PATCH 1/9] Add traits State and Transition --- .../src/cpu/kernel/interpreter.rs | 195 +++++- evm_arithmetization/src/generation/mod.rs | 196 +----- .../src/generation/prover_input.rs | 1 + evm_arithmetization/src/generation/state.rs | 397 +++++++++++- evm_arithmetization/src/witness/operation.rs | 221 ++----- evm_arithmetization/src/witness/transition.rs | 570 +++++++++--------- 6 files changed, 883 insertions(+), 697 deletions(-) diff --git a/evm_arithmetization/src/cpu/kernel/interpreter.rs b/evm_arithmetization/src/cpu/kernel/interpreter.rs index 4ff283c7a..cc65aeb4c 100644 --- a/evm_arithmetization/src/cpu/kernel/interpreter.rs +++ b/evm_arithmetization/src/cpu/kernel/interpreter.rs @@ -13,13 +13,13 @@ use plonky2::field::types::Field; use super::assembler::BYTES_PER_OFFSET; use super::utils::u256_from_bool; -use crate::cpu::halt; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::assembler::Kernel; use crate::cpu::kernel::constants::context_metadata::ContextMetadata; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::cpu::kernel::constants::txn_fields::NormalizedTxnField; use crate::cpu::stack::MAX_USER_STACK_SIZE; +use crate::cpu::{decode, halt}; use crate::extension_tower::BN_BASE; use crate::generation::mpt::load_all_mpts; use crate::generation::prover_input::ProverInputFn; @@ -27,7 +27,7 @@ use crate::generation::rlp::all_rlp_prover_inputs_reversed; use crate::generation::state::{ self, all_withdrawals_prover_inputs_reversed, GenerationState, GenerationStateCheckpoint, }; -use crate::generation::{run_cpu, GenerationInputs, State}; +use crate::generation::{state::State, GenerationInputs}; use crate::memory::segments::{Segment, SEGMENT_SCALING_FACTOR}; use crate::util::{h2u, u256_to_u8, u256_to_usize}; use crate::witness::errors::{ProgramError, ProverInputError}; @@ -37,6 +37,10 @@ use crate::witness::memory::{ }; use crate::witness::operation::{Operation, CONTEXT_SCALING_FACTOR}; use crate::witness::state::RegistersState; +use crate::witness::transition::{ + decode, fill_op_flag, get_op_special_length, log_kernel_instruction, perform_state_op, + Transition, +}; use crate::witness::util::{push_no_write, stack_peek}; use crate::{arithmetic, logic}; @@ -375,19 +379,8 @@ impl Interpreter { Ok(()) } - /// Returns a `GenerationStateCheckpoint` to save the current registers and - /// reset memory operations to the empty vector. - pub(crate) fn checkpoint(&mut self) -> GenerationStateCheckpoint { - self.generation_state.traces.memory_ops = vec![]; - GenerationStateCheckpoint { - registers: self.generation_state.registers, - traces: self.generation_state.traces.checkpoint(), - } - } - pub(crate) fn run(&mut self) -> Result<(), anyhow::Error> { - let mut state = State::Interpreter(self); - run_cpu(&mut state, false)?; + self.run_cpu(false)?; #[cfg(debug_assertions)] { @@ -551,18 +544,6 @@ impl Interpreter { self.set_memory_segment_bytes(Segment::RlpRaw, rlp) } - /// Clears all traces of the interpreter's `GenerationState` except for - /// memory_ops, which are necessary to apply operations. - pub(crate) fn clear_traces(&mut self) { - self.generation_state.traces.arithmetic_ops = vec![]; - self.generation_state.traces.arithmetic_ops = vec![]; - self.generation_state.traces.byte_packing_ops = vec![]; - self.generation_state.traces.cpu = vec![]; - self.generation_state.traces.logic_ops = vec![]; - self.generation_state.traces.keccak_inputs = vec![]; - self.generation_state.traces.keccak_sponge_ops = vec![]; - } - pub(crate) fn set_code(&mut self, context: usize, code: Vec) { assert_ne!(context, 0, "Can't modify kernel code."); while self.generation_state.memory.contexts.len() <= context { @@ -627,6 +608,7 @@ impl Interpreter { } } } + fn stack_segment_mut(&mut self) -> &mut Vec> { let context = self.context(); &mut self.generation_state.memory.contexts[context].segments[Segment::Stack.unscale()] @@ -772,6 +754,167 @@ impl Interpreter { } } +impl State for Interpreter { + //// Returns a `GenerationStateCheckpoint` to save the current registers and + /// reset memory operations to the empty vector. + fn checkpoint(&mut self) -> GenerationStateCheckpoint { + self.generation_state.traces.memory_ops = vec![]; + GenerationStateCheckpoint { + registers: self.generation_state.registers, + traces: self.generation_state.traces.checkpoint(), + } + } + + fn incr_gas(&mut self, n: u64) { + self.generation_state.registers.gas_used += n; + } + + fn incr_pc(&mut self, n: usize) { + self.generation_state.registers.program_counter += n; + } + + fn get_registers(&self) -> RegistersState { + self.generation_state.registers + } + + fn get_mut_registers(&mut self) -> &mut RegistersState { + &mut self.generation_state.registers + } + + fn get_from_memory(&mut self, address: MemoryAddress) -> U256 { + self.generation_state + .memory + .get(address, true, &self.preinitialized_segments) + } + + fn get_mut_generation_state(&mut self) -> &mut GenerationState { + &mut self.generation_state + } + + fn is_generation_state(&mut self) -> bool { + false + } + + fn incr_interpreter_clock(&mut self) { + self.clock += 1 + } + + fn get_clock(&mut self) -> usize { + self.clock + } + + fn rollback(&mut self, checkpoint: GenerationStateCheckpoint) { + self.generation_state.rollback(checkpoint) + } + + fn get_context(&mut self) -> usize { + self.context() + } + + fn get_halt_context(&mut self) -> Option { + self.halt_context + } + + fn mem_get_kernel_content(&self) -> Vec> { + self.generation_state.memory.contexts[0].segments[Segment::KernelGeneral.unscale()] + .content + .clone() + } + + fn apply_ops(&mut self, checkpoint: GenerationStateCheckpoint) { + self.apply_memops(); + } + + fn get_stack(&mut self) -> Vec { + self.stack().clone() + } + + fn get_halt_offsets(&self) -> Vec { + self.halt_offsets.clone() + } + + fn try_perform_instruction(&mut self) -> Result { + let registers = self.generation_state.registers; + let (mut row, opcode) = self.base_row(); + + let op = decode(registers, opcode)?; + + fill_op_flag(op, &mut row); + + self.fill_stack_fields(&mut row)?; + + let generation_state = self.get_mut_generation_state(); + if registers.is_kernel { + log_kernel_instruction(generation_state, op); + } else { + log::debug!("User instruction: {:?}", op); + } + + // Might write in general CPU columns when it shouldn't, but the correct values + // will overwrite these ones during the op generation. + if let Some(special_len) = get_op_special_length(op) { + let special_len_f = F::from_canonical_usize(special_len); + let diff = row.stack_len - special_len_f; + if (generation_state.stack().len() != special_len) { + // If the `State` is an interpreter, we cannot rely on the row to carry out the + // check. + generation_state.registers.is_stack_top_read = true; + } + } else if let Some(inv) = row.stack_len.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + } + + perform_state_op(self, opcode, op, row) + } +} + +impl Transition for Interpreter { + fn generate_jumpdest_analysis(&mut self, dst: usize) -> bool { + if self.is_jumpdest_analysis && !self.generation_state.registers.is_kernel { + self.add_jumpdest_offset(dst as usize); + true + } else { + false + } + } + + fn skip_if_necessary(&mut self, op: Operation) -> Result { + if self.is_kernel() + && self.is_jumpdest_analysis + && self.generation_state.registers.program_counter + == KERNEL.global_labels["jumpdest_analysis"] + { + self.generation_state.registers.program_counter = + KERNEL.global_labels["jumpdest_analysis_end"]; + self.generation_state + .set_jumpdest_bits(&self.generation_state.get_current_code()?); + let opcode = self + .code() + .get(self.generation_state.registers.program_counter) + .byte(0); + + decode(self.generation_state.registers, opcode) + } else { + Ok(op) + } + } + + fn get_preinitialized_segments(&self) -> HashMap { + self.preinitialized_segments.clone() + } + + fn fill_stack_fields( + &mut self, + row: &mut crate::cpu::columns::CpuColumnsView, + ) -> Result<(), ProgramError> { + self.generation_state.registers.is_stack_top_read = false; + self.generation_state.registers.check_overflow = false; + + Ok(()) + } +} + fn get_mnemonic(opcode: u8) -> &'static str { match opcode { 0x00 => "STOP", diff --git a/evm_arithmetization/src/generation/mod.rs b/evm_arithmetization/src/generation/mod.rs index 765c37cd9..1ced63da0 100644 --- a/evm_arithmetization/src/generation/mod.rs +++ b/evm_arithmetization/src/generation/mod.rs @@ -28,7 +28,6 @@ use crate::proof::{BlockHashes, BlockMetadata, ExtraBlockData, PublicValues, Tri use crate::util::{h2u, u256_to_u8, u256_to_usize}; use crate::witness::memory::{MemoryAddress, MemoryChannel, MemoryOp, MemoryOpKind}; use crate::witness::state::RegistersState; -use crate::witness::transition::transition; pub mod mpt; pub(crate) mod prover_input; @@ -36,7 +35,7 @@ pub(crate) mod rlp; pub(crate) mod state; mod trie_extractor; -use self::state::GenerationStateCheckpoint; +use self::state::{GenerationStateCheckpoint, State}; use crate::witness::util::{mem_write_log, stack_peek}; /// Inputs needed for trace generation. @@ -312,199 +311,8 @@ pub fn generate_traces, const D: usize>( Ok((tables, public_values)) } -/// A State is either an `Interpreter` (used for tests and jumpdest analysis) or -/// a `GenerationState`. -pub(crate) enum State<'a, F: Field> { - Generation(&'a mut GenerationState), - Interpreter(&'a mut Interpreter), -} - -impl<'a, F: Field> State<'a, F> { - /// Returns a `State`'s `Checkpoint`. - pub(crate) fn checkpoint(&mut self) -> GenerationStateCheckpoint { - match self { - Self::Generation(state) => state.checkpoint(), - Self::Interpreter(interpreter) => interpreter.checkpoint(), - } - } - - /// Increments the `gas_used` register by a value `n`. - pub(crate) fn incr_gas(&mut self, n: u64) { - match self { - Self::Generation(state) => state.registers.gas_used += n, - Self::Interpreter(interpreter) => interpreter.generation_state.registers.gas_used += n, - } - } - - /// Increments the `program_counter` register by a value `n`. - pub(crate) fn incr_pc(&mut self, n: usize) { - match self { - Self::Generation(state) => state.registers.program_counter += n, - Self::Interpreter(interpreter) => { - interpreter.generation_state.registers.program_counter += n - } - } - } - - /// Returns a `State`'s registers. - pub(crate) fn get_registers(&self) -> RegistersState { - match self { - Self::Generation(state) => state.registers, - Self::Interpreter(interpreter) => interpreter.generation_state.registers, - } - } - - /// Returns a `State`'s mutable registers. - pub(crate) fn get_mut_registers(&mut self) -> &mut RegistersState { - match self { - Self::Generation(state) => &mut state.registers, - Self::Interpreter(interpreter) => &mut interpreter.generation_state.registers, - } - } - - /// Returns the value stored at address `address` in a `State`. - pub(crate) fn get_from_memory(&mut self, address: MemoryAddress) -> U256 { - match self { - Self::Generation(state) => state.memory.get(address, false, &HashMap::default()), - Self::Interpreter(interpreter) => interpreter.generation_state.memory.get( - address, - true, - &interpreter.preinitialized_segments, - ), - } - } - - /// Returns a mutable `GenerationState` from a `State`. - pub(crate) fn get_mut_generation_state(&mut self) -> &mut GenerationState { - match self { - Self::Generation(state) => state, - Self::Interpreter(interpreter) => &mut interpreter.generation_state, - } - } - - /// Returns true if a `State` is a `GenerationState` and false otherwise. - pub(crate) fn is_generation_state(&mut self) -> bool { - match self { - Self::Generation(state) => true, - Self::Interpreter(interpreter) => false, - } - } - - /// Increments the clock of an `Interpreter`'s clock. - pub(crate) fn incr_interpreter_clock(&mut self) { - match self { - Self::Generation(state) => {} - Self::Interpreter(interpreter) => interpreter.clock += 1, - } - } - - /// Returns the value of a `State`'s clock. - pub(crate) fn get_clock(&mut self) -> usize { - match self { - Self::Generation(state) => state.traces.clock(), - Self::Interpreter(interpreter) => interpreter.clock, - } - } - - /// Rolls back a `State`. - pub(crate) fn rollback(&mut self, checkpoint: GenerationStateCheckpoint) { - match self { - Self::Generation(state) => state.rollback(checkpoint), - Self::Interpreter(interpreter) => interpreter.generation_state.rollback(checkpoint), - } - } - - /// Returns a `State`'s stack. - pub(crate) fn get_stack(&mut self) -> Vec { - match self { - Self::Generation(state) => state.stack(), - Self::Interpreter(interpreter) => interpreter.stack(), - } - } - - fn get_context(&mut self) -> usize { - match self { - Self::Generation(state) => state.registers.context, - Self::Interpreter(interpreter) => interpreter.context(), - } - } - - fn get_halt_context(&mut self) -> Option { - match self { - Self::Generation(state) => None, - Self::Interpreter(interpreter) => interpreter.halt_context, - } - } - - /// Returns the content of a the `KernelGeneral` segment of a `State`. - pub(crate) fn mem_get_kernel_content(&self) -> Vec> { - match self { - Self::Generation(state) => state.memory.contexts[0].segments - [Segment::KernelGeneral.unscale()] - .content - .clone(), - Self::Interpreter(interpreter) => interpreter.generation_state.memory.contexts[0] - .segments[Segment::KernelGeneral.unscale()] - .content - .clone(), - } - } - - /// Applies a `State`'s operations since a checkpoint. - pub(crate) fn apply_ops(&mut self, checkpoint: GenerationStateCheckpoint) { - match self { - Self::Generation(state) => state - .memory - .apply_ops(state.traces.mem_ops_since(checkpoint.traces)), - Self::Interpreter(interpreter) => { - // An interpreter `checkpoint()` clears all operations before the checkpoint. - interpreter.apply_memops(); - } - } - } -} - -/// Simulates a CPU. It only generates the traces if the `State` is a -/// `GenerationState`. Otherwise, it simply simulates all ooperations. -pub(crate) fn run_cpu( - any_state: &mut State, - is_generation: bool, -) -> anyhow::Result<()> { - let halt_offsets = match any_state { - State::Generation(state) => vec![KERNEL.global_labels["halt"]], - State::Interpreter(interpreter) => interpreter.halt_offsets.clone(), - }; - - loop { - // If we've reached the kernel's halt routine. - let registers = any_state.get_registers(); - let pc = registers.program_counter; - - let halt = registers.is_kernel && halt_offsets.contains(&pc); - - if halt { - if let Some(halt_context) = any_state.get_halt_context() { - if registers.context == halt_context { - // Only happens during jumpdest analysis. - return Ok(()); - } - } else { - if is_generation { - log::info!("CPU halted after {} cycles", any_state.get_clock()); - } - return Ok(()); - } - } - - transition(any_state)?; - any_state.incr_interpreter_clock(); - } - - Ok(()) -} - fn simulate_cpu(state: &mut GenerationState) -> anyhow::Result<()> { - run_cpu(&mut State::Generation(state), true)?; + state.run_cpu(true)?; let pc = state.registers.program_counter; // Padding diff --git a/evm_arithmetization/src/generation/prover_input.rs b/evm_arithmetization/src/generation/prover_input.rs index 361de2fb5..2a2126589 100644 --- a/evm_arithmetization/src/generation/prover_input.rs +++ b/evm_arithmetization/src/generation/prover_input.rs @@ -9,6 +9,7 @@ use num_bigint::BigUint; use plonky2::field::types::Field; use serde::{Deserialize, Serialize}; +use super::state::State; use crate::cpu::kernel::constants::context_metadata::ContextMetadata; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::cpu::kernel::interpreter::simulate_cpu_and_get_user_jumps; diff --git a/evm_arithmetization/src/generation/state.rs b/evm_arithmetization/src/generation/state.rs index 8fb5baacd..ca7121eef 100644 --- a/evm_arithmetization/src/generation/state.rs +++ b/evm_arithmetization/src/generation/state.rs @@ -1,26 +1,217 @@ use std::collections::HashMap; +use anyhow::bail; use ethereum_types::{Address, BigEndianHash, H160, H256, U256}; use keccak_hash::keccak; +use log::log_enabled; use plonky2::field::types::Field; use super::mpt::{load_all_mpts, TrieRootPtrs}; use super::TrieInputs; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::context_metadata::ContextMetadata; +use crate::cpu::membus::NUM_GP_CHANNELS; +use crate::cpu::stack::MAX_USER_STACK_SIZE; use crate::generation::rlp::all_rlp_prover_inputs_reversed; +use crate::generation::CpuColumnsView; use crate::generation::GenerationInputs; use crate::memory::segments::Segment; use crate::util::u256_to_usize; use crate::witness::errors::ProgramError; -use crate::witness::memory::{MemoryAddress, MemoryState}; +use crate::witness::memory::MemoryChannel::GeneralPurpose; +use crate::witness::memory::MemoryOpKind; +use crate::witness::memory::{MemoryAddress, MemoryOp, MemoryState}; +use crate::witness::operation::{generate_exception, Operation}; use crate::witness::state::RegistersState; use crate::witness::traces::{TraceCheckpoint, Traces}; -use crate::witness::util::stack_peek; +use crate::witness::transition::{ + decode, fill_op_flag, get_op_special_length, log_kernel_instruction, might_overflow_op, + perform_state_op, read_code_memory, Transition, +}; +use crate::witness::util::{ + fill_channel_with_value, mem_read_gp_with_log_and_fill, stack_peek, stack_pop_with_log_and_fill, +}; -pub(crate) struct GenerationStateCheckpoint { - pub(crate) registers: RegistersState, - pub(crate) traces: TraceCheckpoint, +/// A State is either an `Interpreter` (used for tests and jumpdest analysis) or +/// a `GenerationState`. +pub(crate) trait State { + /// Returns a `State`'s `Checkpoint`. + fn checkpoint(&mut self) -> GenerationStateCheckpoint; + + /// Increments the `gas_used` register by a value `n`. + fn incr_gas(&mut self, n: u64); + + /// Increments the `program_counter` register by a value `n`. + fn incr_pc(&mut self, n: usize); + + /// Returns a `State`'s registers. + fn get_registers(&self) -> RegistersState; + + /// Returns a `State`'s mutable registers. + fn get_mut_registers(&mut self) -> &mut RegistersState; + + /// Returns the value stored at address `address` in a `State`. + fn get_from_memory(&mut self, address: MemoryAddress) -> U256; + + /// Returns a mutable `GenerationState` from a `State`. + fn get_mut_generation_state(&mut self) -> &mut GenerationState; + + /// Returns true if a `State` is a `GenerationState` and false otherwise. + fn is_generation_state(&mut self) -> bool; + + /// Increments the clock of an `Interpreter`'s clock. + fn incr_interpreter_clock(&mut self); + + /// Returns the value of a `State`'s clock. + fn get_clock(&mut self) -> usize; + + /// Rolls back a `State`. + fn rollback(&mut self, checkpoint: GenerationStateCheckpoint); + + /// Returns a `State`'s stack. + fn get_stack(&mut self) -> Vec; + + /// Returns the current context. + fn get_context(&mut self) -> usize; + + fn get_halt_context(&mut self) -> Option; + + /// Returns the content of a the `KernelGeneral` segment of a `State`. + fn mem_get_kernel_content(&self) -> Vec>; + + /// Applies a `State`'s operations since a checkpoint. + fn apply_ops(&mut self, checkpoint: GenerationStateCheckpoint); + + /// Return the offsets at which execution must halt + fn get_halt_offsets(&self) -> Vec; + + /// Simulates a CPU. It only generates the traces if the `State` is a + /// `GenerationState`. Otherwise, it simply simulates all ooperations. + fn run_cpu(&mut self, is_generation: bool) -> anyhow::Result<()> { + let halt_offsets = self.get_halt_offsets(); + + loop { + // If we've reached the kernel's halt routine. + let registers = self.get_registers(); + let pc = registers.program_counter; + + let halt = registers.is_kernel && halt_offsets.contains(&pc); + + if halt { + if let Some(halt_context) = self.get_halt_context() { + if registers.context == halt_context { + // Only happens during jumpdest analysis. + return Ok(()); + } + } else { + if is_generation { + log::info!("CPU halted after {} cycles", self.get_clock()); + } + return Ok(()); + } + } + + self.transition()?; + self.incr_interpreter_clock(); + } + + Ok(()) + } + + fn handle_error(&mut self, err: ProgramError) -> anyhow::Result<()> { + let exc_code: u8 = match err { + ProgramError::OutOfGas => 0, + ProgramError::InvalidOpcode => 1, + ProgramError::StackUnderflow => 2, + ProgramError::InvalidJumpDestination => 3, + ProgramError::InvalidJumpiDestination => 4, + ProgramError::StackOverflow => 5, + _ => bail!("TODO: figure out what to do with this..."), + }; + + let (checkpoint, is_generation) = (self.checkpoint(), self.is_generation_state()); + + let (row, _) = self.base_row(); + generate_exception( + exc_code, + self.get_mut_generation_state(), + row, + is_generation, + ); + + // We only clear traces for the interpreter + if !self.is_generation_state() { + self.clear_traces() + } + + self.apply_ops(checkpoint); + + Ok(()) + } + + fn transition(&mut self) -> anyhow::Result<()> { + let checkpoint = self.checkpoint(); + let result = self.try_perform_instruction(); + + match result { + Ok(op) => { + self.apply_ops(checkpoint); + + if might_overflow_op(op) { + self.get_mut_registers().check_overflow = true; + } + Ok(()) + } + Err(e) => { + if self.get_registers().is_kernel { + let offset_name = KERNEL.offset_name(self.get_registers().program_counter); + bail!( + "{:?} in kernel at pc={}, stack={:?}, memory={:?}", + e, + offset_name, + self.get_stack(), + self.mem_get_kernel_content(), + ); + } + self.rollback(checkpoint); + self.handle_error(e) + } + } + } + + /// Clears all traces from `GenerationState` except for + /// memory_ops, which are necessary to apply operations. + fn clear_traces(&mut self) { + let generation_state = self.get_mut_generation_state(); + generation_state.traces.arithmetic_ops = vec![]; + generation_state.traces.arithmetic_ops = vec![]; + generation_state.traces.byte_packing_ops = vec![]; + generation_state.traces.cpu = vec![]; + generation_state.traces.logic_ops = vec![]; + generation_state.traces.keccak_inputs = vec![]; + generation_state.traces.keccak_sponge_ops = vec![]; + } + + fn try_perform_instruction(&mut self) -> Result; + + /// Row that has the correct values for system registers and the code + /// channel, but is otherwise blank. It fulfills the constraints that + /// are common to successful operations and the exception operation. It + /// also returns the opcode + fn base_row(&mut self) -> (CpuColumnsView, u8) { + let generation_state = self.get_mut_generation_state(); + let mut row: CpuColumnsView = CpuColumnsView::default(); + row.clock = F::from_canonical_usize(generation_state.traces.clock()); + row.context = F::from_canonical_usize(generation_state.registers.context); + row.program_counter = F::from_canonical_usize(generation_state.registers.program_counter); + row.is_kernel_mode = F::from_bool(generation_state.registers.is_kernel); + row.gas = F::from_canonical_u64(generation_state.registers.gas_used); + row.stack_len = F::from_canonical_usize(generation_state.registers.stack_len); + fill_channel_with_value(&mut row, 0, generation_state.registers.stack_top); + + let opcode = read_code_memory(generation_state, &mut row); + (row, opcode) + } } #[derive(Debug)] @@ -161,13 +352,6 @@ impl GenerationState { Ok(()) } - pub(crate) fn checkpoint(&self) -> GenerationStateCheckpoint { - GenerationStateCheckpoint { - registers: self.registers, - traces: self.traces.checkpoint(), - } - } - pub(crate) fn rollback(&mut self, checkpoint: GenerationStateCheckpoint) { self.registers = checkpoint.registers; self.traces.rollback(checkpoint.traces); @@ -201,6 +385,195 @@ impl GenerationState { } } +impl State for GenerationState { + fn checkpoint(&mut self) -> GenerationStateCheckpoint { + GenerationStateCheckpoint { + registers: self.registers, + traces: self.traces.checkpoint(), + } + } + + /// Increments the `gas_used` register by a value `n`. + fn incr_gas(&mut self, n: u64) { + self.registers.gas_used += n; + } + + /// Increments the `program_counter` register by a value `n`. + fn incr_pc(&mut self, n: usize) { + self.registers.program_counter += n; + } + + /// Returns a `State`'s registers. + fn get_registers(&self) -> RegistersState { + self.registers + } + + /// Returns a `State`'s mutable registers. + fn get_mut_registers(&mut self) -> &mut RegistersState { + &mut self.registers + } + + /// Returns the value stored at address `address` in a `State`. + fn get_from_memory(&mut self, address: MemoryAddress) -> U256 { + self.memory.get(address, false, &HashMap::default()) + } + + /// Returns a mutable `GenerationState` from a `State`. + fn get_mut_generation_state(&mut self) -> &mut GenerationState { + self + } + + /// Returns true if a `State` is a `GenerationState` and false otherwise. + fn is_generation_state(&mut self) -> bool { + true + } + + /// Increments the clock of an `Interpreter`'s clock. + fn incr_interpreter_clock(&mut self) {} + + /// Returns the value of a `State`'s clock. + fn get_clock(&mut self) -> usize { + self.traces.clock() + } + + /// Rolls back a `State`. + fn rollback(&mut self, checkpoint: GenerationStateCheckpoint) { + self.rollback(checkpoint) + } + + /// Returns a `State`'s stack. + fn get_stack(&mut self) -> Vec { + self.stack() + } + + fn get_context(&mut self) -> usize { + self.registers.context + } + + fn get_halt_context(&mut self) -> Option { + None + } + + /// Returns the content of a the `KernelGeneral` segment of a `State`. + fn mem_get_kernel_content(&self) -> Vec> { + self.memory.contexts[0].segments[Segment::KernelGeneral.unscale()] + .content + .clone() + } + + /// Applies a `State`'s operations since a checkpoint. + fn apply_ops(&mut self, checkpoint: GenerationStateCheckpoint) { + self.memory + .apply_ops(self.traces.mem_ops_since(checkpoint.traces)) + } + + fn get_halt_offsets(&self) -> Vec { + vec![KERNEL.global_labels["halt"]] + } + + fn try_perform_instruction(&mut self) -> Result { + let registers = self.registers; + let (mut row, opcode) = self.base_row(); + + let op = decode(registers, opcode)?; + + if registers.is_kernel { + log_kernel_instruction(self, op); + } else { + log::debug!("User instruction: {:?}", op); + } + fill_op_flag(op, &mut row); + + self.fill_stack_fields(&mut row)?; + + // Might write in general CPU columns when it shouldn't, but the correct values + // will overwrite these ones during the op generation. + if let Some(special_len) = get_op_special_length(op) { + let special_len_f = F::from_canonical_usize(special_len); + let diff = row.stack_len - special_len_f; + if let Some(inv) = diff.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + self.registers.is_stack_top_read = true; + } else if (self.stack().len() != special_len) { + // If the `State` is an interpreter, we cannot rely on the row to carry out the + // check. + self.registers.is_stack_top_read = true; + } + } else if let Some(inv) = row.stack_len.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + } + + perform_state_op(self, opcode, op, row) + } +} + +impl Transition for GenerationState { + fn skip_if_necessary(&mut self, op: Operation) -> Result { + Ok(op) + } + + fn get_preinitialized_segments( + &self, + ) -> HashMap { + HashMap::default() + } + + fn generate_jumpdest_analysis(&mut self, dst: usize) -> bool { + false + } + + fn fill_stack_fields(&mut self, row: &mut CpuColumnsView) -> Result<(), ProgramError> { + if self.registers.is_stack_top_read { + let channel = &mut row.mem_channels[0]; + channel.used = F::ONE; + channel.is_read = F::ONE; + channel.addr_context = F::from_canonical_usize(self.registers.context); + channel.addr_segment = F::from_canonical_usize(Segment::Stack.unscale()); + channel.addr_virtual = F::from_canonical_usize(self.registers.stack_len - 1); + + let address = MemoryAddress::new( + self.registers.context, + Segment::Stack, + self.registers.stack_len - 1, + ); + + let mem_op = MemoryOp::new( + GeneralPurpose(0), + self.traces.clock(), + address, + MemoryOpKind::Read, + self.registers.stack_top, + ); + self.traces.push_memory(mem_op); + } + self.registers.is_stack_top_read = false; + + if self.registers.check_overflow { + if self.registers.is_kernel { + row.general.stack_mut().stack_len_bounds_aux = F::ZERO; + } else { + let clock = self.traces.clock(); + let last_row = &mut self.traces.cpu[clock - 1]; + let disallowed_len = F::from_canonical_usize(MAX_USER_STACK_SIZE + 1); + let diff = row.stack_len - disallowed_len; + if let Some(inv) = diff.try_inverse() { + last_row.general.stack_mut().stack_len_bounds_aux = inv; + } + } + } + self.registers.check_overflow = false; + + Ok(()) + } +} + +pub(crate) struct GenerationStateCheckpoint { + pub(crate) registers: RegistersState, + pub(crate) traces: TraceCheckpoint, +} + /// Withdrawals prover input array is of the form `[addr0, amount0, ..., addrN, /// amountN, U256::MAX, U256::MAX]`. Returns the reversed array. pub(crate) fn all_withdrawals_prover_inputs_reversed(withdrawals: &[(Address, U256)]) -> Vec { diff --git a/evm_arithmetization/src/witness/operation.rs b/evm_arithmetization/src/witness/operation.rs index b820887ab..19a0fd0ae 100644 --- a/evm_arithmetization/src/witness/operation.rs +++ b/evm_arithmetization/src/witness/operation.rs @@ -7,6 +7,7 @@ use keccak_hash::keccak; use plonky2::field::types::Field; use super::memory::MemorySegmentState; +use super::transition::Transition; use super::util::{ byte_packing_log, byte_unpacking_log, mem_read_with_log, mem_write_log, mem_write_partial_log_and_fill, push_no_write, push_with_write, @@ -21,14 +22,13 @@ use crate::cpu::simple_logic::eq_iszero::generate_pinv_diff; use crate::cpu::stack::MAX_USER_STACK_SIZE; use crate::extension_tower::BN_BASE; use crate::generation::state::GenerationState; -use crate::generation::State; +use crate::generation::state::State; use crate::memory::segments::Segment; use crate::util::u256_to_usize; use crate::witness::errors::MemoryError::VirtTooLarge; use crate::witness::errors::ProgramError; use crate::witness::memory::{MemoryAddress, MemoryChannel, MemoryOp, MemoryOpKind}; use crate::witness::operation::MemoryChannel::GeneralPurpose; -use crate::witness::transition::fill_stack_fields; use crate::witness::util::{ keccak_sponge_log, mem_read_gp_with_log_and_fill, mem_write_gp_log_and_fill, stack_pop_with_log_and_fill, @@ -217,160 +217,6 @@ pub(crate) fn generate_pop( Ok(()) } -pub(crate) fn generate_jump( - state: &mut State, - mut row: CpuColumnsView, - is_jumpdest_analysis: bool, -) -> Result<(), ProgramError> { - let [(dst, _)] = - stack_pop_with_log_and_fill::<1, _>(state.get_mut_generation_state(), &mut row)?; - - let dst: u32 = dst - .try_into() - .map_err(|_| ProgramError::InvalidJumpDestination)?; - - if is_jumpdest_analysis { - match state { - State::Generation(state) => { - panic!("Cannot carry out jumpdest analysis with a `GenerationState.") - } - State::Interpreter(interpreter) => { - if !interpreter.generation_state.registers.is_kernel { - interpreter.add_jumpdest_offset(dst as usize); - } - } - } - } else { - let gen_state = state.get_mut_generation_state(); - // Even though we might be in the interpreter, `JumpdestBits` is not part of the - // preinitialized segments, so we don't need to carry out the additional checks - // when get the value from memory. - let (jumpdest_bit, jumpdest_bit_log) = mem_read_gp_with_log_and_fill( - NUM_GP_CHANNELS - 1, - MemoryAddress::new( - gen_state.registers.context, - Segment::JumpdestBits, - dst as usize, - ), - gen_state, - &mut row, - false, - &HashMap::default(), - ); - - row.mem_channels[1].value[0] = F::ONE; - - if gen_state.registers.is_kernel { - // Don't actually do the read, just set the address, etc. - let channel = &mut row.mem_channels[NUM_GP_CHANNELS - 1]; - channel.used = F::ZERO; - channel.value[0] = F::ONE; - } else { - if jumpdest_bit != U256::one() { - return Err(ProgramError::InvalidJumpDestination); - } - gen_state.traces.push_memory(jumpdest_bit_log); - } - - // Extra fields required by the constraints. - row.general.jumps_mut().should_jump = F::ONE; - row.general.jumps_mut().cond_sum_pinv = F::ONE; - - let diff = row.stack_len - F::ONE; - if let Some(inv) = diff.try_inverse() { - row.general.stack_mut().stack_inv = inv; - row.general.stack_mut().stack_inv_aux = F::ONE; - } else { - row.general.stack_mut().stack_inv = F::ZERO; - row.general.stack_mut().stack_inv_aux = F::ZERO; - } - - gen_state.traces.push_cpu(row); - } - state.get_mut_generation_state().jump_to(dst as usize)?; - Ok(()) -} - -pub(crate) fn generate_jumpi( - state: &mut State, - mut row: CpuColumnsView, - is_jumpdest_analysis: bool, -) -> Result<(), ProgramError> { - let [(dst, _), (cond, log_cond)] = - stack_pop_with_log_and_fill::<2, _>(state.get_mut_generation_state(), &mut row)?; - - let should_jump = !cond.is_zero(); - if should_jump { - let dst: u32 = dst - .try_into() - .map_err(|_| ProgramError::InvalidJumpiDestination)?; - let is_kernel = state.get_registers().is_kernel; - if is_jumpdest_analysis && !is_kernel { - match state { - State::Generation(state) => { - panic!("Cannot carry out jumpdest analysis with a `GenerationState`.") - } - State::Interpreter(interpreter) => interpreter.add_jumpdest_offset(dst as usize), - } - } else { - row.general.jumps_mut().should_jump = F::ONE; - let cond_sum_u64 = cond - .0 - .into_iter() - .map(|limb| ((limb as u32) as u64) + (limb >> 32)) - .sum(); - let cond_sum = F::from_canonical_u64(cond_sum_u64); - row.general.jumps_mut().cond_sum_pinv = cond_sum.inverse(); - } - state.get_mut_generation_state().jump_to(dst as usize)?; - } else { - row.general.jumps_mut().should_jump = F::ZERO; - row.general.jumps_mut().cond_sum_pinv = F::ZERO; - state.incr_pc(1); - } - - let gen_state = state.get_mut_generation_state(); - // Even though we might be in the interpreter, `JumpdestBits` is not part of the - // preinitialized segments, so we don't need to carry out the additional checks - // when get the value from memory. - let (jumpdest_bit, jumpdest_bit_log) = mem_read_gp_with_log_and_fill( - NUM_GP_CHANNELS - 1, - MemoryAddress::new( - gen_state.registers.context, - Segment::JumpdestBits, - dst.low_u32() as usize, - ), - gen_state, - &mut row, - false, - &HashMap::default(), - ); - if !should_jump || gen_state.registers.is_kernel { - // Don't actually do the read, just set the address, etc. - let channel = &mut row.mem_channels[NUM_GP_CHANNELS - 1]; - channel.used = F::ZERO; - channel.value[0] = F::ONE; - } else { - if jumpdest_bit != U256::one() { - return Err(ProgramError::InvalidJumpiDestination); - } - gen_state.traces.push_memory(jumpdest_bit_log); - } - - let diff = row.stack_len - F::TWO; - if let Some(inv) = diff.try_inverse() { - row.general.stack_mut().stack_inv = inv; - row.general.stack_mut().stack_inv_aux = F::ONE; - } else { - row.general.stack_mut().stack_inv = F::ZERO; - row.general.stack_mut().stack_inv_aux = F::ZERO; - } - - gen_state.traces.push_memory(log_cond); - gen_state.traces.push_cpu(row); - Ok(()) -} - pub(crate) fn generate_pc( state: &mut GenerationState, mut row: CpuColumnsView, @@ -1009,32 +855,37 @@ pub(crate) fn generate_mstore_general( Ok(()) } -pub(crate) fn generate_mstore_32bytes( +pub(crate) fn generate_mstore_32bytes>( n: u8, - state: &mut GenerationState, + state: &mut S, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(addr, _), (val, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; + let generation_state = state.get_mut_generation_state(); + let [(addr, _), (val, log_in1)] = + stack_pop_with_log_and_fill::<2, _>(generation_state, &mut row)?; let base_address = MemoryAddress::new_bundle(addr)?; - byte_unpacking_log(state, base_address, val, n as usize); + byte_unpacking_log(generation_state, base_address, val, n as usize); let new_addr = addr + n; - push_no_write(state, new_addr); + push_no_write(generation_state, new_addr); - state.traces.push_memory(log_in1); - state.traces.push_cpu(row); + generation_state.traces.push_memory(log_in1); + generation_state.traces.push_cpu(row); Ok(()) } -pub(crate) fn generate_exception( +pub(crate) fn generate_exception>( exc_code: u8, - state: &mut GenerationState, + state: &mut T, mut row: CpuColumnsView, is_generation: bool, ) -> Result<(), ProgramError> { - if TryInto::::try_into(state.registers.gas_used).is_err() { + // TDO: se fue arriba + state.fill_stack_fields(&mut row)?; + let generation_state = state.get_mut_generation_state(); + if TryInto::::try_into(generation_state.registers.gas_used).is_err() { return Err(ProgramError::GasLimitError); } @@ -1045,8 +896,6 @@ pub(crate) fn generate_exception( row.general.stack_mut().stack_inv_aux = F::ONE; } - fill_stack_fields(state, &mut row, is_generation)?; - row.general.exception_mut().exc_code_bits = [ F::from_bool(exc_code & 1 != 0), F::from_bool(exc_code & 2 != 0), @@ -1067,7 +916,9 @@ pub(crate) fn generate_exception( // Even though we might be in the interpreter, `Code` is not part of the // preinitialized segments, so we don't need to carry out the additional checks // when get the value from memory. - let val = state.memory.get(address, false, &HashMap::default()); + let val = generation_state + .memory + .get(address, false, &HashMap::default()); val.low_u32() as u8 }) .collect_vec(); @@ -1081,20 +932,26 @@ pub(crate) fn generate_exception( jumptable_channel.addr_virtual = F::from_canonical_usize(handler_addr_addr); jumptable_channel.value[0] = F::from_canonical_usize(u256_to_usize(packed_int)?); - byte_packing_log(state, base_address, bytes); + byte_packing_log(generation_state, base_address, bytes); let new_program_counter = u256_to_usize(packed_int)?; - let gas = U256::from(state.registers.gas_used); + let gas = U256::from(generation_state.registers.gas_used); - let exc_info = U256::from(state.registers.program_counter) + (gas << 192); + let exc_info = U256::from(generation_state.registers.program_counter) + (gas << 192); // Get the opcode so we can provide it to the range_check operation. - let code_context = state.registers.code_context(); - let address = MemoryAddress::new(code_context, Segment::Code, state.registers.program_counter); + let code_context = generation_state.registers.code_context(); + let address = MemoryAddress::new( + code_context, + Segment::Code, + generation_state.registers.program_counter, + ); // Even though we might be in the interpreter, `Code` is not part of the // preinitialized segments, so we don't need to carry out the additional checks // when get the value from memory. - let opcode = state.memory.get(address, false, &HashMap::default()); + let opcode = generation_state + .memory + .get(address, false, &HashMap::default()); // `ArithmeticStark` range checks `mem_channels[0]`, which contains // the top of the stack, `mem_channels[1]`, which contains the new PC, @@ -1103,7 +960,7 @@ pub(crate) fn generate_exception( // Our goal here is to range-check the gas, contained in syscall_info, // stored in the next stack top. let range_check_op = arithmetic::Operation::range_check( - state.registers.stack_top, + generation_state.registers.stack_top, packed_int, U256::from(0), opcode, @@ -1113,15 +970,15 @@ pub(crate) fn generate_exception( // kernel mode so we can't incorrectly trigger a stack overflow. However, // note that we have to do it _after_ we make `exc_info`, which should // contain the old values. - state.registers.program_counter = new_program_counter; - state.registers.is_kernel = true; - state.registers.gas_used = 0; + generation_state.registers.program_counter = new_program_counter; + generation_state.registers.is_kernel = true; + generation_state.registers.gas_used = 0; - push_with_write(state, &mut row, exc_info)?; + push_with_write(generation_state, &mut row, exc_info)?; log::debug!("Exception to {}", KERNEL.offset_name(new_program_counter)); - state.traces.push_arithmetic(range_check_op); - state.traces.push_cpu(row); + generation_state.traces.push_arithmetic(range_check_op); + generation_state.traces.push_cpu(row); Ok(()) } diff --git a/evm_arithmetization/src/witness/transition.rs b/evm_arithmetization/src/witness/transition.rs index 5fcc0dca5..f8fa23787 100644 --- a/evm_arithmetization/src/witness/transition.rs +++ b/evm_arithmetization/src/witness/transition.rs @@ -1,21 +1,27 @@ use std::collections::HashMap; +use std::hash::RandomState; use anyhow::bail; +use ethereum_types::U256; use log::log_enabled; use plonky2::field::types::Field; -use super::memory::{MemoryOp, MemoryOpKind}; -use super::util::fill_channel_with_value; +use super::memory::{MemoryOp, MemoryOpKind, MemorySegmentState}; +use super::util::{ + fill_channel_with_value, mem_read_gp_with_log_and_fill, push_no_write, + stack_pop_with_log_and_fill, +}; use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::context_metadata::ContextMetadata; use crate::cpu::kernel::interpreter::InterpreterMemOpKind; +use crate::cpu::membus::NUM_GP_CHANNELS; use crate::cpu::stack::{ EQ_STACK_BEHAVIOR, IS_ZERO_STACK_BEHAVIOR, JUMPI_OP, JUMP_OP, MAX_USER_STACK_SIZE, MIGHT_OVERFLOW, STACK_BEHAVIORS, }; -use crate::generation::state::{GenerationState, GenerationStateCheckpoint}; -use crate::generation::State; +use crate::extension_tower::BN_BASE; +use crate::generation::state::{GenerationState, GenerationStateCheckpoint, State}; use crate::memory::segments::Segment; use crate::witness::errors::ProgramError; use crate::witness::gas::gas_to_charge; @@ -26,7 +32,10 @@ use crate::witness::state::RegistersState; use crate::witness::util::mem_read_code_with_log_and_fill; use crate::{arithmetic, logic}; -fn read_code_memory(state: &mut GenerationState, row: &mut CpuColumnsView) -> u8 { +pub(crate) fn read_code_memory( + state: &mut GenerationState, + row: &mut CpuColumnsView, +) -> u8 { let code_context = state.registers.code_context(); row.code_context = F::from_canonical_usize(code_context); @@ -163,7 +172,7 @@ pub(crate) fn decode(registers: RegistersState, opcode: u8) -> Result(op: Operation, row: &mut CpuColumnsView) { +pub(crate) fn fill_op_flag(op: Operation, row: &mut CpuColumnsView) { let flags = &mut row.op; *match op { Operation::Dup(_) | Operation::Swap(_) => &mut flags.dup_swap, @@ -191,7 +200,7 @@ fn fill_op_flag(op: Operation, row: &mut CpuColumnsView) { // Equal to the number of pops if an operation pops without pushing, and `None` // otherwise. -const fn get_op_special_length(op: Operation) -> Option { +pub(crate) const fn get_op_special_length(op: Operation) -> Option { let behavior_opt = match op { Operation::Push(0) | Operation::Pc => STACK_BEHAVIORS.pc_push0, Operation::Push(1..) | Operation::ProverInput => STACK_BEHAVIORS.push_prover_input, @@ -232,7 +241,7 @@ const fn get_op_special_length(op: Operation) -> Option { // These operations might trigger a stack overflow, typically those pushing // without popping. Kernel-only pushing instructions aren't considered; they // can't overflow. -const fn might_overflow_op(op: Operation) -> bool { +pub(crate) const fn might_overflow_op(op: Operation) -> bool { match op { Operation::Push(1..) | Operation::ProverInput => MIGHT_OVERFLOW.push_prover_input, Operation::Dup(_) | Operation::Swap(_) => MIGHT_OVERFLOW.dup_swap, @@ -259,116 +268,13 @@ const fn might_overflow_op(op: Operation) -> bool { } } -fn perform_op( - any_state: &mut State, - op: Operation, - opcode: u8, - row: CpuColumnsView, -) -> Result<(), ProgramError> { - let (op, is_interpreter, preinitialized_segments, is_jumpdest_analysis) = match any_state { - State::Generation(state) => (op, false, HashMap::default(), false), - State::Interpreter(interpreter) => { - // Jumpdest analysis is performed natively by the interpreter and not - // using the non-deterministic Kernel assembly code. - let op = if interpreter.is_kernel() - && interpreter.is_jumpdest_analysis - && interpreter.generation_state.registers.program_counter - == KERNEL.global_labels["jumpdest_analysis"] - { - interpreter.generation_state.registers.program_counter = - KERNEL.global_labels["jumpdest_analysis_end"]; - interpreter - .generation_state - .set_jumpdest_bits(&interpreter.generation_state.get_current_code()?); - let opcode = interpreter - .code() - .get(interpreter.generation_state.registers.program_counter) - .byte(0); - - decode(interpreter.generation_state.registers, opcode)? - } else { - op - }; - ( - op, - true, - interpreter.preinitialized_segments.clone(), - interpreter.is_jumpdest_analysis, - ) - } - }; - - #[cfg(debug_assertions)] - if !any_state.get_registers().is_kernel { - log::debug!( - "User instruction {:?}, stack = {:?}, ctx = {}", - op, - { - let mut stack = any_state.get_stack(); - stack.reverse(); - stack - }, - any_state.get_registers().context - ); - } - - let state = any_state.get_mut_generation_state(); - - match op { - Operation::Push(n) => generate_push(n, state, row)?, - Operation::Dup(n) => generate_dup(n, state, row)?, - Operation::Swap(n) => generate_swap(n, state, row)?, - Operation::Iszero => generate_iszero(state, row)?, - Operation::Not => generate_not(state, row)?, - Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shl) => generate_shl(state, row)?, - Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shr) => generate_shr(state, row)?, - Operation::Syscall(opcode, stack_values_read, stack_len_increased) => { - generate_syscall(opcode, stack_values_read, stack_len_increased, state, row)? - } - Operation::Eq => generate_eq(state, row)?, - Operation::BinaryLogic(binary_logic_op) => { - generate_binary_logic_op(binary_logic_op, state, row)? - } - Operation::BinaryArithmetic(op) => generate_binary_arithmetic_op(op, state, row)?, - Operation::TernaryArithmetic(op) => generate_ternary_arithmetic_op(op, state, row)?, - Operation::KeccakGeneral => { - generate_keccak_general(state, row, is_interpreter, &preinitialized_segments)? - } - Operation::ProverInput => generate_prover_input(state, row)?, - Operation::Pop => generate_pop(state, row)?, - Operation::Jump => generate_jump(any_state, row, is_jumpdest_analysis)?, - Operation::Jumpi => generate_jumpi(any_state, row, is_jumpdest_analysis)?, - Operation::Pc => generate_pc(state, row)?, - Operation::Jumpdest => generate_jumpdest(state, row)?, - Operation::GetContext => generate_get_context(state, row)?, - Operation::SetContext => generate_set_context(state, row)?, - Operation::Mload32Bytes => { - generate_mload_32bytes(state, row, is_interpreter, &preinitialized_segments)? - } - Operation::Mstore32Bytes(n) => generate_mstore_32bytes(n, state, row)?, - Operation::ExitKernel => generate_exit_kernel(state, row)?, - Operation::MloadGeneral => { - generate_mload_general(state, row, is_interpreter, &preinitialized_segments)? - } - Operation::MstoreGeneral => generate_mstore_general(state, row)?, - }; - match any_state { - State::Generation(state) => {} - State::Interpreter(interpreter) => { - interpreter.clear_traces(); - interpreter.opcode_count[opcode as usize] += 1; - } - } - Ok(()) -} - -fn perform_state_op( - any_state: &mut State, +pub(crate) fn perform_state_op>( + any_state: &mut T, opcode: u8, op: Operation, row: CpuColumnsView, ) -> Result { - perform_op(any_state, op, opcode, row)?; + any_state.perform_op(op, opcode, row)?; any_state.incr_pc(match op { Operation::Syscall(_, _, _) | Operation::ExitKernel => 0, Operation::Push(n) => n as usize + 1, @@ -399,121 +305,7 @@ fn perform_state_op( Ok(op) } -/// Row that has the correct values for system registers and the code channel, -/// but is otherwise blank. It fulfills the constraints that are common to -/// successful operations and the exception operation. It also returns the -/// opcode. -fn base_row(state: &mut GenerationState) -> (CpuColumnsView, u8) { - let mut row: CpuColumnsView = CpuColumnsView::default(); - row.clock = F::from_canonical_usize(state.traces.clock()); - row.context = F::from_canonical_usize(state.registers.context); - row.program_counter = F::from_canonical_usize(state.registers.program_counter); - row.is_kernel_mode = F::from_bool(state.registers.is_kernel); - row.gas = F::from_canonical_u64(state.registers.gas_used); - row.stack_len = F::from_canonical_usize(state.registers.stack_len); - fill_channel_with_value(&mut row, 0, state.registers.stack_top); - - let opcode = read_code_memory(state, &mut row); - (row, opcode) -} - -pub(crate) fn fill_stack_fields( - state: &mut GenerationState, - row: &mut CpuColumnsView, - is_generation: bool, -) -> Result<(), ProgramError> { - if state.registers.is_stack_top_read && is_generation { - let channel = &mut row.mem_channels[0]; - channel.used = F::ONE; - channel.is_read = F::ONE; - channel.addr_context = F::from_canonical_usize(state.registers.context); - channel.addr_segment = F::from_canonical_usize(Segment::Stack.unscale()); - channel.addr_virtual = F::from_canonical_usize(state.registers.stack_len - 1); - - let address = MemoryAddress::new( - state.registers.context, - Segment::Stack, - state.registers.stack_len - 1, - ); - - let mem_op = MemoryOp::new( - GeneralPurpose(0), - state.traces.clock(), - address, - MemoryOpKind::Read, - state.registers.stack_top, - ); - state.traces.push_memory(mem_op); - } - state.registers.is_stack_top_read = false; - - if state.registers.check_overflow && is_generation { - if state.registers.is_kernel { - row.general.stack_mut().stack_len_bounds_aux = F::ZERO; - } else { - let clock = state.traces.clock(); - let last_row = &mut state.traces.cpu[clock - 1]; - let disallowed_len = F::from_canonical_usize(MAX_USER_STACK_SIZE + 1); - let diff = row.stack_len - disallowed_len; - if let Some(inv) = diff.try_inverse() { - last_row.general.stack_mut().stack_len_bounds_aux = inv; - } else { - // This is a stack overflow that should have been caught earlier. - return Err(ProgramError::InterpreterError); - } - } - } - state.registers.check_overflow = false; - - Ok(()) -} - -fn try_perform_instruction(any_state: &mut State) -> Result { - let is_generation = any_state.is_generation_state(); - let registers = any_state.get_registers(); - - let state = any_state.get_mut_generation_state(); - let (mut row, opcode) = base_row(state); - - let op = decode(registers, opcode)?; - - if registers.is_kernel { - log_kernel_instruction(state, op); - } else { - log::debug!("User instruction: {:?}", op); - } - - if is_generation { - fill_op_flag(op, &mut row); - } - - fill_stack_fields(state, &mut row, is_generation)?; - - // Might write in general CPU columns when it shouldn't, but the correct values - // will overwrite these ones during the op generation. - if let Some(special_len) = get_op_special_length(op) { - let special_len_f = F::from_canonical_usize(special_len); - let diff = row.stack_len - special_len_f; - if let Some(inv) = diff.try_inverse() - && is_generation - { - row.general.stack_mut().stack_inv = inv; - row.general.stack_mut().stack_inv_aux = F::ONE; - state.registers.is_stack_top_read = true; - } else if !is_generation && (state.stack().len() != special_len) { - // If the `State` is an interpreter, we cannot rely on the row to carry out the - // check. - state.registers.is_stack_top_read = true; - } - } else if let Some(inv) = row.stack_len.try_inverse() { - row.general.stack_mut().stack_inv = inv; - row.general.stack_mut().stack_inv_aux = F::ONE; - } - - perform_state_op(any_state, opcode, op, row) -} - -fn log_kernel_instruction(state: &GenerationState, op: Operation) { +pub(crate) fn log_kernel_instruction(state: &mut GenerationState, op: Operation) { // The logic below is a bit costly, so skip it if debug logs aren't enabled. if !log_enabled!(log::Level::Debug) { return; @@ -542,68 +334,280 @@ fn log_kernel_instruction(state: &GenerationState, op: Operation) { assert!(pc < KERNEL.code.len(), "Kernel PC is out of range: {}", pc); } -fn handle_error(any_state: &mut State, err: ProgramError) -> anyhow::Result<()> { - let exc_code: u8 = match err { - ProgramError::OutOfGas => 0, - ProgramError::InvalidOpcode => 1, - ProgramError::StackUnderflow => 2, - ProgramError::InvalidJumpDestination => 3, - ProgramError::InvalidJumpiDestination => 4, - ProgramError::StackOverflow => 5, - _ => bail!("TODO: figure out what to do with this..."), - }; +pub(crate) trait Transition: State { + fn generate_jumpdest_analysis(&mut self, dst: usize) -> bool; + + fn generate_jump(&mut self, mut row: CpuColumnsView) -> Result<(), ProgramError> { + let [(dst, _)] = + stack_pop_with_log_and_fill::<1, _>(self.get_mut_generation_state(), &mut row)?; + + let dst: u32 = dst + .try_into() + .map_err(|_| ProgramError::InvalidJumpDestination)?; + + if !self.generate_jumpdest_analysis(dst as usize) { + let gen_state = self.get_mut_generation_state(); + // Even though we might be in the interpreter, `JumpdestBits` is not part + // preinitialized segments, so we don't need to carry out the additional checks + // when get the value from memory. + let (jumpdest_bit, jumpdest_bit_log) = mem_read_gp_with_log_and_fill( + NUM_GP_CHANNELS - 1, + MemoryAddress::new( + gen_state.registers.context, + Segment::JumpdestBits, + dst as usize, + ), + gen_state, + &mut row, + false, + &HashMap::default(), + ); - let (checkpoint, is_generation, state): ( - GenerationStateCheckpoint, - bool, - &mut GenerationState<_>, - ) = match any_state { - State::Generation(state) => (state.checkpoint(), true, state), - State::Interpreter(interpreter) => ( - interpreter.checkpoint(), - false, - &mut interpreter.generation_state, - ), - }; - let (row, _) = base_row(state); - generate_exception(exc_code, state, row, is_generation); - match any_state { - State::Generation(state) => {} - // If we are in the interpreter, we do not need all the traces. We only need the memory - // traces of the current operation. - State::Interpreter(interpreter) => interpreter.clear_traces(), - } - any_state.apply_ops(checkpoint); + row.mem_channels[1].value[0] = F::ONE; - Ok(()) -} + if gen_state.registers.is_kernel { + // Don't actually do the read, just set the address, etc. + let channel = &mut row.mem_channels[NUM_GP_CHANNELS - 1]; + channel.used = F::ZERO; + channel.value[0] = F::ONE; + } else { + if jumpdest_bit != U256::one() { + return Err(ProgramError::InvalidJumpDestination); + } + gen_state.traces.push_memory(jumpdest_bit_log); + } -pub(crate) fn transition(any_state: &mut State) -> anyhow::Result<()> { - let checkpoint = any_state.checkpoint(); - let result = try_perform_instruction(any_state); + // Extra fields required by the constraints. + row.general.jumps_mut().should_jump = F::ONE; + row.general.jumps_mut().cond_sum_pinv = F::ONE; + + let diff = row.stack_len - F::ONE; + if let Some(inv) = diff.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + } else { + row.general.stack_mut().stack_inv = F::ZERO; + row.general.stack_mut().stack_inv_aux = F::ZERO; + } + + gen_state.traces.push_cpu(row); + } + self.get_mut_generation_state().jump_to(dst as usize)?; + Ok(()) + } - match result { - Ok(op) => { - any_state.apply_ops(checkpoint); + fn generate_jumpi(&mut self, mut row: CpuColumnsView) -> Result<(), ProgramError> { + let [(dst, _), (cond, log_cond)] = + stack_pop_with_log_and_fill::<2, _>(self.get_mut_generation_state(), &mut row)?; + + let should_jump = !cond.is_zero(); + if should_jump { + let dst: u32 = dst + .try_into() + .map_err(|_| ProgramError::InvalidJumpiDestination)?; + let is_kernel = self.get_registers().is_kernel; + if !self.generate_jumpdest_analysis(dst as usize) { + row.general.jumps_mut().should_jump = F::ONE; + let cond_sum_u64 = cond + .0 + .into_iter() + .map(|limb| ((limb as u32) as u64) + (limb >> 32)) + .sum(); + let cond_sum = F::from_canonical_u64(cond_sum_u64); + row.general.jumps_mut().cond_sum_pinv = cond_sum.inverse(); + } + self.get_mut_generation_state().jump_to(dst as usize)?; + } else { + row.general.jumps_mut().should_jump = F::ZERO; + row.general.jumps_mut().cond_sum_pinv = F::ZERO; + self.incr_pc(1); + } - if might_overflow_op(op) { - any_state.get_mut_registers().check_overflow = true; + let gen_state = self.get_mut_generation_state(); + // Even though we might be in the interpreter, `JumpdestBits` is not part of the + // preinitialized segments, so we don't need to carry out the additional checks + // when get the value from memory. + let (jumpdest_bit, jumpdest_bit_log) = mem_read_gp_with_log_and_fill( + NUM_GP_CHANNELS - 1, + MemoryAddress::new( + gen_state.registers.context, + Segment::JumpdestBits, + dst.low_u32() as usize, + ), + gen_state, + &mut row, + false, + &HashMap::default(), + ); + if !should_jump || gen_state.registers.is_kernel { + // Don't actually do the read, just set the address, etc. + let channel = &mut row.mem_channels[NUM_GP_CHANNELS - 1]; + channel.used = F::ZERO; + channel.value[0] = F::ONE; + } else { + if jumpdest_bit != U256::one() { + return Err(ProgramError::InvalidJumpiDestination); } - Ok(()) + gen_state.traces.push_memory(jumpdest_bit_log); } - Err(e) => { - if any_state.get_registers().is_kernel { - let offset_name = KERNEL.offset_name(any_state.get_registers().program_counter); - bail!( - "{:?} in kernel at pc={}, stack={:?}, memory={:?}", - e, - offset_name, - any_state.get_stack(), - any_state.mem_get_kernel_content(), - ); + + let diff = row.stack_len - F::TWO; + if let Some(inv) = diff.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + } else { + row.general.stack_mut().stack_inv = F::ZERO; + row.general.stack_mut().stack_inv_aux = F::ZERO; + } + + gen_state.traces.push_memory(log_cond); + gen_state.traces.push_cpu(row); + Ok(()) + } + + /// Skips the following instructions for some specific labels + fn skip_if_necessary(&mut self, op: Operation) -> Result; + + fn get_preinitialized_segments(&self) -> HashMap; + + fn perform_op( + &mut self, + op: Operation, + opcode: u8, + row: CpuColumnsView, + ) -> Result<(), ProgramError> { + /// It may + let op = self.skip_if_necessary(op)?; + let is_interpreter = self.is_generation_state(); + let preinitialized_segments = self.get_preinitialized_segments(); + + #[cfg(debug_assertions)] + if !self.get_registers().is_kernel { + log::debug!( + "User instruction {:?}, stack = {:?}, ctx = {}", + op, + { + let mut stack = self.get_stack(); + stack.reverse(); + stack + }, + self.get_registers().context + ); + } + + let generation_state = self.get_mut_generation_state(); + + match op { + Operation::Push(n) => generate_push(n, generation_state, row)?, + Operation::Dup(n) => generate_dup(n, generation_state, row)?, + Operation::Swap(n) => generate_swap(n, generation_state, row)?, + Operation::Iszero => generate_iszero(generation_state, row)?, + Operation::Not => generate_not(generation_state, row)?, + Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shl) => { + generate_shl(generation_state, row)? + } + Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shr) => { + generate_shr(generation_state, row)? + } + Operation::Syscall(opcode, stack_values_read, stack_len_increased) => generate_syscall( + opcode, + stack_values_read, + stack_len_increased, + generation_state, + row, + )?, + Operation::Eq => generate_eq(generation_state, row)?, + Operation::BinaryLogic(binary_logic_op) => { + generate_binary_logic_op(binary_logic_op, generation_state, row)? } - any_state.rollback(checkpoint); - handle_error(any_state, e) + Operation::BinaryArithmetic(op) => { + generate_binary_arithmetic_op(op, generation_state, row)? + } + Operation::TernaryArithmetic(op) => { + generate_ternary_arithmetic_op(op, generation_state, row)? + } + Operation::KeccakGeneral => generate_keccak_general( + generation_state, + row, + is_interpreter, + &preinitialized_segments, + )?, + Operation::ProverInput => generate_prover_input(generation_state, row)?, + Operation::Pop => generate_pop(generation_state, row)?, + Operation::Jump => self.generate_jump(row)?, + Operation::Jumpi => self.generate_jumpi(row)?, + Operation::Pc => generate_pc(generation_state, row)?, + Operation::Jumpdest => generate_jumpdest(generation_state, row)?, + Operation::GetContext => generate_get_context(generation_state, row)?, + Operation::SetContext => generate_set_context(generation_state, row)?, + Operation::Mload32Bytes => generate_mload_32bytes( + generation_state, + row, + is_interpreter, + &preinitialized_segments, + )?, + Operation::Mstore32Bytes(n) => generate_mstore_32bytes(n, generation_state, row)?, + Operation::ExitKernel => generate_exit_kernel(generation_state, row)?, + Operation::MloadGeneral => generate_mload_general( + generation_state, + row, + is_interpreter, + &preinitialized_segments, + )?, + Operation::MstoreGeneral => generate_mstore_general(generation_state, row)?, + }; + + if !self.is_generation_state() { + self.clear_traces(); } + Ok(()) } + + fn fill_stack_fields(&mut self, row: &mut CpuColumnsView) -> Result<(), ProgramError>; + // { + // if state.registers.is_stack_top_read && is_generation { + // let channel = &mut row.mem_channels[0]; + // channel.used = F::ONE; + // channel.is_read = F::ONE; + // channel.addr_context = + // F::from_canonical_usize(state.registers.context); channel. + // addr_segment = F::from_canonical_usize(Segment::Stack.unscale()); + // channel.addr_virtual = + // F::from_canonical_usize(state.registers.stack_len - 1); + + // let address = MemoryAddress::new( + // state.registers.context, + // Segment::Stack, + // state.registers.stack_len - 1, + // ); + + // let mem_op = MemoryOp::new( + // GeneralPurpose(0), + // state.traces.clock(), + // address, + // MemoryOpKind::Read, + // state.registers.stack_top, + // ); + // state.traces.push_memory(mem_op); + // } + // state.registers.is_stack_top_read = false; + + // if state.registers.check_overflow && is_generation { + // if state.registers.is_kernel { + // row.general.stack_mut().stack_len_bounds_aux = F::ZERO; + // } else { + // let clock = state.traces.clock(); + // let last_row = &mut state.traces.cpu[clock - 1]; + // let disallowed_len = F::from_canonical_usize(MAX_USER_STACK_SIZE + // + 1); let diff = row.stack_len - disallowed_len; if let Some(inv) + // = diff.try_inverse() { last_row.general.stack_mut().stack_len_bounds_aux = + // inv; } else { // This is a stack overflow that should have been caught + // earlier. return Err(ProgramError::InterpreterError); + // } + // } + // } + // state.registers.check_overflow = false; + + // Ok(()) + // } } From 0a824fbf5f0af948cca8eaa0b65c9f6d53c2faa8 Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Mon, 26 Feb 2024 11:48:14 +0100 Subject: [PATCH 2/9] Move perform_state_op to Transition --- .../src/cpu/kernel/interpreter.rs | 5 +- evm_arithmetization/src/generation/state.rs | 4 +- evm_arithmetization/src/witness/transition.rs | 120 ++++++------------ 3 files changed, 41 insertions(+), 88 deletions(-) diff --git a/evm_arithmetization/src/cpu/kernel/interpreter.rs b/evm_arithmetization/src/cpu/kernel/interpreter.rs index cc65aeb4c..71643b5f6 100644 --- a/evm_arithmetization/src/cpu/kernel/interpreter.rs +++ b/evm_arithmetization/src/cpu/kernel/interpreter.rs @@ -38,8 +38,7 @@ use crate::witness::memory::{ use crate::witness::operation::{Operation, CONTEXT_SCALING_FACTOR}; use crate::witness::state::RegistersState; use crate::witness::transition::{ - decode, fill_op_flag, get_op_special_length, log_kernel_instruction, perform_state_op, - Transition, + decode, fill_op_flag, get_op_special_length, log_kernel_instruction, Transition, }; use crate::witness::util::{push_no_write, stack_peek}; use crate::{arithmetic, logic}; @@ -865,7 +864,7 @@ impl State for Interpreter { row.general.stack_mut().stack_inv_aux = F::ONE; } - perform_state_op(self, opcode, op, row) + self.perform_state_op(opcode, op, row) } } diff --git a/evm_arithmetization/src/generation/state.rs b/evm_arithmetization/src/generation/state.rs index ca7121eef..aa431d71d 100644 --- a/evm_arithmetization/src/generation/state.rs +++ b/evm_arithmetization/src/generation/state.rs @@ -26,7 +26,7 @@ use crate::witness::state::RegistersState; use crate::witness::traces::{TraceCheckpoint, Traces}; use crate::witness::transition::{ decode, fill_op_flag, get_op_special_length, log_kernel_instruction, might_overflow_op, - perform_state_op, read_code_memory, Transition, + read_code_memory, Transition, }; use crate::witness::util::{ fill_channel_with_value, mem_read_gp_with_log_and_fill, stack_peek, stack_pop_with_log_and_fill, @@ -505,7 +505,7 @@ impl State for GenerationState { row.general.stack_mut().stack_inv_aux = F::ONE; } - perform_state_op(self, opcode, op, row) + self.perform_state_op(opcode, op, row) } } diff --git a/evm_arithmetization/src/witness/transition.rs b/evm_arithmetization/src/witness/transition.rs index f8fa23787..ab19163f8 100644 --- a/evm_arithmetization/src/witness/transition.rs +++ b/evm_arithmetization/src/witness/transition.rs @@ -268,43 +268,6 @@ pub(crate) const fn might_overflow_op(op: Operation) -> bool { } } -pub(crate) fn perform_state_op>( - any_state: &mut T, - opcode: u8, - op: Operation, - row: CpuColumnsView, -) -> Result { - any_state.perform_op(op, opcode, row)?; - any_state.incr_pc(match op { - Operation::Syscall(_, _, _) | Operation::ExitKernel => 0, - Operation::Push(n) => n as usize + 1, - Operation::Jump | Operation::Jumpi => 0, - _ => 1, - }); - - any_state.incr_gas(gas_to_charge(op)); - let registers = any_state.get_registers(); - let gas_limit_address = MemoryAddress::new( - registers.context, - Segment::ContextMetadata, - ContextMetadata::GasLimit.unscale(), // context offsets are already scaled - ); - - if !registers.is_kernel { - let gas_limit = TryInto::::try_into(any_state.get_from_memory(gas_limit_address)); - match gas_limit { - Ok(limit) => { - if registers.gas_used > limit { - return Err(ProgramError::OutOfGas); - } - } - Err(_) => return Err(ProgramError::IntegerTooLarge), - } - } - - Ok(op) -} - pub(crate) fn log_kernel_instruction(state: &mut GenerationState, op: Operation) { // The logic below is a bit costly, so skip it if debug logs aren't enabled. if !log_enabled!(log::Level::Debug) { @@ -337,6 +300,43 @@ pub(crate) fn log_kernel_instruction(state: &mut GenerationState, o pub(crate) trait Transition: State { fn generate_jumpdest_analysis(&mut self, dst: usize) -> bool; + fn perform_state_op( + &mut self, + opcode: u8, + op: Operation, + row: CpuColumnsView, + ) -> Result { + self.perform_op(op, opcode, row)?; + self.incr_pc(match op { + Operation::Syscall(_, _, _) | Operation::ExitKernel => 0, + Operation::Push(n) => n as usize + 1, + Operation::Jump | Operation::Jumpi => 0, + _ => 1, + }); + + self.incr_gas(gas_to_charge(op)); + let registers = self.get_registers(); + let gas_limit_address = MemoryAddress::new( + registers.context, + Segment::ContextMetadata, + ContextMetadata::GasLimit.unscale(), // context offsets are already scaled + ); + + if !registers.is_kernel { + let gas_limit = TryInto::::try_into(self.get_from_memory(gas_limit_address)); + match gas_limit { + Ok(limit) => { + if registers.gas_used > limit { + return Err(ProgramError::OutOfGas); + } + } + Err(_) => return Err(ProgramError::IntegerTooLarge), + } + } + + Ok(op) + } + fn generate_jump(&mut self, mut row: CpuColumnsView) -> Result<(), ProgramError> { let [(dst, _)] = stack_pop_with_log_and_fill::<1, _>(self.get_mut_generation_state(), &mut row)?; @@ -564,50 +564,4 @@ pub(crate) trait Transition: State { } fn fill_stack_fields(&mut self, row: &mut CpuColumnsView) -> Result<(), ProgramError>; - // { - // if state.registers.is_stack_top_read && is_generation { - // let channel = &mut row.mem_channels[0]; - // channel.used = F::ONE; - // channel.is_read = F::ONE; - // channel.addr_context = - // F::from_canonical_usize(state.registers.context); channel. - // addr_segment = F::from_canonical_usize(Segment::Stack.unscale()); - // channel.addr_virtual = - // F::from_canonical_usize(state.registers.stack_len - 1); - - // let address = MemoryAddress::new( - // state.registers.context, - // Segment::Stack, - // state.registers.stack_len - 1, - // ); - - // let mem_op = MemoryOp::new( - // GeneralPurpose(0), - // state.traces.clock(), - // address, - // MemoryOpKind::Read, - // state.registers.stack_top, - // ); - // state.traces.push_memory(mem_op); - // } - // state.registers.is_stack_top_read = false; - - // if state.registers.check_overflow && is_generation { - // if state.registers.is_kernel { - // row.general.stack_mut().stack_len_bounds_aux = F::ZERO; - // } else { - // let clock = state.traces.clock(); - // let last_row = &mut state.traces.cpu[clock - 1]; - // let disallowed_len = F::from_canonical_usize(MAX_USER_STACK_SIZE - // + 1); let diff = row.stack_len - disallowed_len; if let Some(inv) - // = diff.try_inverse() { last_row.general.stack_mut().stack_len_bounds_aux = - // inv; } else { // This is a stack overflow that should have been caught - // earlier. return Err(ProgramError::InterpreterError); - // } - // } - // } - // state.registers.check_overflow = false; - - // Ok(()) - // } } From 66762000c92db47dec28688fd635753f2ef98af2 Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Mon, 26 Feb 2024 12:26:39 +0100 Subject: [PATCH 3/9] Fix bug --- evm_arithmetization/src/witness/transition.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evm_arithmetization/src/witness/transition.rs b/evm_arithmetization/src/witness/transition.rs index ab19163f8..1e414e0fa 100644 --- a/evm_arithmetization/src/witness/transition.rs +++ b/evm_arithmetization/src/witness/transition.rs @@ -478,7 +478,7 @@ pub(crate) trait Transition: State { ) -> Result<(), ProgramError> { /// It may let op = self.skip_if_necessary(op)?; - let is_interpreter = self.is_generation_state(); + let is_interpreter = !self.is_generation_state(); let preinitialized_segments = self.get_preinitialized_segments(); #[cfg(debug_assertions)] From dac96fa4aedfd35a894f4070b142b029fff65560 Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Mon, 26 Feb 2024 14:45:19 +0100 Subject: [PATCH 4/9] Get rid of is_generation_state + add reviews --- .../src/cpu/kernel/interpreter.rs | 29 ++++++----- .../src/cpu/kernel/tests/core/access_lists.rs | 11 ---- .../src/cpu/kernel/tests/ecc/curve_ops.rs | 1 - .../src/generation/prover_input.rs | 4 -- evm_arithmetization/src/generation/state.rs | 46 +++++------------ evm_arithmetization/src/witness/memory.rs | 2 - evm_arithmetization/src/witness/operation.rs | 51 ++++--------------- evm_arithmetization/src/witness/transition.rs | 34 ++++--------- evm_arithmetization/src/witness/util.rs | 19 ++----- 9 files changed, 53 insertions(+), 144 deletions(-) diff --git a/evm_arithmetization/src/cpu/kernel/interpreter.rs b/evm_arithmetization/src/cpu/kernel/interpreter.rs index 71643b5f6..48b80398a 100644 --- a/evm_arithmetization/src/cpu/kernel/interpreter.rs +++ b/evm_arithmetization/src/cpu/kernel/interpreter.rs @@ -619,7 +619,6 @@ impl Interpreter { for i in range { let term = self.generation_state.memory.get( MemoryAddress::new(0, segment, i), - true, &self.preinitialized_segments, ); output.push(term); @@ -683,7 +682,6 @@ impl Interpreter { segment: Segment::JumpdestBits.unscale(), virt: offset, }, - false, &HashMap::default(), ) } else { @@ -765,35 +763,31 @@ impl State for Interpreter { } fn incr_gas(&mut self, n: u64) { - self.generation_state.registers.gas_used += n; + self.generation_state.incr_gas(n); } fn incr_pc(&mut self, n: usize) { - self.generation_state.registers.program_counter += n; + self.generation_state.incr_pc(n); } fn get_registers(&self) -> RegistersState { - self.generation_state.registers + self.generation_state.get_registers() } fn get_mut_registers(&mut self) -> &mut RegistersState { - &mut self.generation_state.registers + self.generation_state.get_mut_registers() } fn get_from_memory(&mut self, address: MemoryAddress) -> U256 { self.generation_state .memory - .get(address, true, &self.preinitialized_segments) + .get(address, &self.preinitialized_segments) } fn get_mut_generation_state(&mut self) -> &mut GenerationState { &mut self.generation_state } - fn is_generation_state(&mut self) -> bool { - false - } - fn incr_interpreter_clock(&mut self) { self.clock += 1 } @@ -866,12 +860,23 @@ impl State for Interpreter { self.perform_state_op(opcode, op, row) } + + fn clear_traces(&mut self) { + let generation_state = self.get_mut_generation_state(); + generation_state.traces.arithmetic_ops = vec![]; + generation_state.traces.arithmetic_ops = vec![]; + generation_state.traces.byte_packing_ops = vec![]; + generation_state.traces.cpu = vec![]; + generation_state.traces.logic_ops = vec![]; + generation_state.traces.keccak_inputs = vec![]; + generation_state.traces.keccak_sponge_ops = vec![]; + } } impl Transition for Interpreter { fn generate_jumpdest_analysis(&mut self, dst: usize) -> bool { if self.is_jumpdest_analysis && !self.generation_state.registers.is_kernel { - self.add_jumpdest_offset(dst as usize); + self.add_jumpdest_offset(dst); true } else { false diff --git a/evm_arithmetization/src/cpu/kernel/tests/core/access_lists.rs b/evm_arithmetization/src/cpu/kernel/tests/core/access_lists.rs index 3a1b9903f..f5e3168a8 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/core/access_lists.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/core/access_lists.rs @@ -30,7 +30,6 @@ fn test_init_access_lists() -> Result<()> { .map(|i| { interpreter.generation_state.memory.get( MemoryAddress::new(0, Segment::AccessedAddresses, i), - false, &HashMap::default(), ) }) @@ -44,7 +43,6 @@ fn test_init_access_lists() -> Result<()> { .map(|i| { interpreter.generation_state.memory.get( MemoryAddress::new(0, Segment::AccessedStorageKeys, i), - false, &HashMap::default(), ) }) @@ -115,7 +113,6 @@ fn test_insert_address() -> Result<()> { assert_eq!( interpreter.generation_state.memory.get( MemoryAddress::new_bundle(U256::from(AccessedAddressesLen as usize)).unwrap(), - false, &HashMap::default(), ), U256::from(Segment::AccessedAddresses as usize + 4) @@ -170,7 +167,6 @@ fn test_insert_accessed_addresses() -> Result<()> { assert_eq!( interpreter.generation_state.memory.get( MemoryAddress::new_bundle(U256::from(AccessedAddressesLen as usize)).unwrap(), - false, &HashMap::default(), ), U256::from(offset + 2 * (n + 1)) @@ -187,7 +183,6 @@ fn test_insert_accessed_addresses() -> Result<()> { assert_eq!( interpreter.generation_state.memory.get( MemoryAddress::new_bundle(U256::from(AccessedAddressesLen as usize)).unwrap(), - false, &HashMap::default(), ), U256::from(offset + 2 * (n + 2)) @@ -195,7 +190,6 @@ fn test_insert_accessed_addresses() -> Result<()> { assert_eq!( interpreter.generation_state.memory.get( MemoryAddress::new(0, AccessedAddresses, 2 * (n + 1)), - false, &HashMap::default(), ), U256::from(addr_not_in_list.0.as_slice()) @@ -259,7 +253,6 @@ fn test_insert_accessed_storage_keys() -> Result<()> { assert_eq!( interpreter.generation_state.memory.get( MemoryAddress::new_bundle(U256::from(AccessedStorageKeysLen as usize)).unwrap(), - false, &HashMap::default(), ), U256::from(offset + 4 * (n + 1)) @@ -281,7 +274,6 @@ fn test_insert_accessed_storage_keys() -> Result<()> { assert_eq!( interpreter.generation_state.memory.get( MemoryAddress::new_bundle(U256::from(AccessedStorageKeysLen as usize)).unwrap(), - false, &HashMap::default() ), U256::from(offset + 4 * (n + 2)) @@ -289,7 +281,6 @@ fn test_insert_accessed_storage_keys() -> Result<()> { assert_eq!( interpreter.generation_state.memory.get( MemoryAddress::new(0, AccessedStorageKeys, 4 * (n + 1)), - false, &HashMap::default(), ), U256::from(storage_key_not_in_list.0 .0.as_slice()) @@ -297,7 +288,6 @@ fn test_insert_accessed_storage_keys() -> Result<()> { assert_eq!( interpreter.generation_state.memory.get( MemoryAddress::new(0, AccessedStorageKeys, 4 * (n + 1) + 1), - false, &HashMap::default() ), storage_key_not_in_list.1 @@ -305,7 +295,6 @@ fn test_insert_accessed_storage_keys() -> Result<()> { assert_eq!( interpreter.generation_state.memory.get( MemoryAddress::new(0, AccessedStorageKeys, 4 * (n + 1) + 2), - false, &HashMap::default() ), storage_key_not_in_list.2 diff --git a/evm_arithmetization/src/cpu/kernel/tests/ecc/curve_ops.rs b/evm_arithmetization/src/cpu/kernel/tests/ecc/curve_ops.rs index 90ccca15b..4540be61e 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/ecc/curve_ops.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/ecc/curve_ops.rs @@ -216,7 +216,6 @@ mod bn { segment: Segment::BnTableQ.unscale(), virt: i, }, - false, &HashMap::default(), )); } diff --git a/evm_arithmetization/src/generation/prover_input.rs b/evm_arithmetization/src/generation/prover_input.rs index 2a2126589..005fa4b12 100644 --- a/evm_arithmetization/src/generation/prover_input.rs +++ b/evm_arithmetization/src/generation/prover_input.rs @@ -423,7 +423,6 @@ impl GenerationState { .map(|i| { u256_to_u8(self.memory.get( MemoryAddress::new(context, Segment::Code, i), - false, &HashMap::default(), )) }) @@ -438,7 +437,6 @@ impl GenerationState { Segment::ContextMetadata, ContextMetadata::CodeSize.unscale(), ), - false, &HashMap::default(), ))?; Ok(code_len) @@ -477,7 +475,6 @@ impl GenerationState { fn get_global_metadata(&self, data: GlobalMetadata) -> U256 { self.memory.get( MemoryAddress::new(0, Segment::GlobalMetadata, data.unscale()), - false, &HashMap::default(), ) } @@ -493,7 +490,6 @@ impl GenerationState { Segment::GlobalMetadata, GlobalMetadata::AccessedStorageKeysLen.unscale(), ), - false, &HashMap::default(), ) - Segment::AccessedStorageKeys as usize, )?; diff --git a/evm_arithmetization/src/generation/state.rs b/evm_arithmetization/src/generation/state.rs index aa431d71d..0de7ac527 100644 --- a/evm_arithmetization/src/generation/state.rs +++ b/evm_arithmetization/src/generation/state.rs @@ -56,8 +56,8 @@ pub(crate) trait State { /// Returns a mutable `GenerationState` from a `State`. fn get_mut_generation_state(&mut self) -> &mut GenerationState; - /// Returns true if a `State` is a `GenerationState` and false otherwise. - fn is_generation_state(&mut self) -> bool; + // /// Returns true if a `State` is a `GenerationState` and false otherwise. + // fn is_generation_state(&mut self) -> bool; /// Increments the clock of an `Interpreter`'s clock. fn incr_interpreter_clock(&mut self); @@ -129,20 +129,12 @@ pub(crate) trait State { _ => bail!("TODO: figure out what to do with this..."), }; - let (checkpoint, is_generation) = (self.checkpoint(), self.is_generation_state()); + let checkpoint = self.checkpoint(); let (row, _) = self.base_row(); - generate_exception( - exc_code, - self.get_mut_generation_state(), - row, - is_generation, - ); + generate_exception(exc_code, self.get_mut_generation_state(), row); - // We only clear traces for the interpreter - if !self.is_generation_state() { - self.clear_traces() - } + self.clear_traces(); self.apply_ops(checkpoint); @@ -181,16 +173,7 @@ pub(crate) trait State { /// Clears all traces from `GenerationState` except for /// memory_ops, which are necessary to apply operations. - fn clear_traces(&mut self) { - let generation_state = self.get_mut_generation_state(); - generation_state.traces.arithmetic_ops = vec![]; - generation_state.traces.arithmetic_ops = vec![]; - generation_state.traces.byte_packing_ops = vec![]; - generation_state.traces.cpu = vec![]; - generation_state.traces.logic_ops = vec![]; - generation_state.traces.keccak_inputs = vec![]; - generation_state.traces.keccak_sponge_ops = vec![]; - } + fn clear_traces(&mut self) {} fn try_perform_instruction(&mut self) -> Result; @@ -335,11 +318,8 @@ impl GenerationState { let returndata_offset = ContextMetadata::ReturndataSize.unscale(); let returndata_size_addr = MemoryAddress::new(ctx, Segment::ContextMetadata, returndata_offset); - let returndata_size = u256_to_usize(self.memory.get( - returndata_size_addr, - false, - &HashMap::default(), - ))?; + let returndata_size = + u256_to_usize(self.memory.get(returndata_size_addr, &HashMap::default()))?; let code = self.memory.contexts[ctx].segments[Segment::Returndata.unscale()].content [..returndata_size] .iter() @@ -415,7 +395,7 @@ impl State for GenerationState { /// Returns the value stored at address `address` in a `State`. fn get_from_memory(&mut self, address: MemoryAddress) -> U256 { - self.memory.get(address, false, &HashMap::default()) + self.memory.get(address, &HashMap::default()) } /// Returns a mutable `GenerationState` from a `State`. @@ -423,10 +403,10 @@ impl State for GenerationState { self } - /// Returns true if a `State` is a `GenerationState` and false otherwise. - fn is_generation_state(&mut self) -> bool { - true - } + // /// Returns true if a `State` is a `GenerationState` and false otherwise. + // fn is_generation_state(&mut self) -> bool { + // true + // } /// Increments the clock of an `Interpreter`'s clock. fn incr_interpreter_clock(&mut self) {} diff --git a/evm_arithmetization/src/witness/memory.rs b/evm_arithmetization/src/witness/memory.rs index 2e509fe53..c83caf197 100644 --- a/evm_arithmetization/src/witness/memory.rs +++ b/evm_arithmetization/src/witness/memory.rs @@ -223,7 +223,6 @@ impl MemoryState { pub(crate) fn get( &self, address: MemoryAddress, - is_interpreter: bool, preinitialized_segments: &HashMap, ) -> U256 { match self.get_option(address) { @@ -271,7 +270,6 @@ impl MemoryState { pub(crate) fn read_global_metadata(&self, field: GlobalMetadata) -> U256 { self.get( MemoryAddress::new_bundle(U256::from(field as usize)).unwrap(), - false, &HashMap::default(), ) } diff --git a/evm_arithmetization/src/witness/operation.rs b/evm_arithmetization/src/witness/operation.rs index 19a0fd0ae..becfa81ec 100644 --- a/evm_arithmetization/src/witness/operation.rs +++ b/evm_arithmetization/src/witness/operation.rs @@ -137,7 +137,6 @@ pub(crate) fn generate_ternary_arithmetic_op( pub(crate) fn generate_keccak_general( state: &mut GenerationState, mut row: CpuColumnsView, - is_interpreter: bool, preinitialized_segments: &HashMap, ) -> Result<(), ProgramError> { let [(addr, _), (len, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; @@ -150,9 +149,7 @@ pub(crate) fn generate_keccak_general( virt: base_address.virt.saturating_add(i), ..base_address }; - let val = state - .memory - .get(address, is_interpreter, preinitialized_segments); + let val = state.memory.get(address, preinitialized_segments); val.low_u32() as u8 }) .collect_vec(); @@ -297,13 +294,7 @@ pub(crate) fn generate_set_context( // Even though we might be in the interpreter, `Stack` is not part of the // preinitialized segments, so we don't need to carry out the additional checks // when get the value from memory. - mem_read_with_log( - GeneralPurpose(2), - new_sp_addr, - state, - false, - &HashMap::default(), - ) + mem_read_with_log(GeneralPurpose(2), new_sp_addr, state, &HashMap::default()) }; // If the new stack isn't empty, read stack_top from memory. @@ -325,14 +316,8 @@ pub(crate) fn generate_set_context( // Even though we might be in the interpreter, `Stack` is not part of the // preinitialized segments, so we don't need to carry out the additional checks // when get the value from memory. - let (new_top, log_read_new_top) = mem_read_gp_with_log_and_fill( - 2, - new_top_addr, - state, - &mut row, - false, - &HashMap::default(), - ); + let (new_top, log_read_new_top) = + mem_read_gp_with_log_and_fill(2, new_top_addr, state, &mut row, &HashMap::default()); state.registers.stack_top = new_top; state.traces.push_memory(log_read_new_top); } else { @@ -374,11 +359,6 @@ pub(crate) fn generate_push( virt: base_address.virt + i, ..base_address }, - // Even though we might be in the interpreter, `Code` is not part of - // the preinitialized segments, so we don't need to carry - // out the additional checks when get the value from - // memory. - false, &HashMap::default(), ) .low_u32() as u8 @@ -458,7 +438,7 @@ pub(crate) fn generate_dup( // Even though we might be in the interpreter, `Stack` is not part of the // preinitialized segments, so we don't need to carry out the additional checks // when get the value from memory. - mem_read_gp_with_log_and_fill(2, other_addr, state, &mut row, false, &HashMap::default()) + mem_read_gp_with_log_and_fill(2, other_addr, state, &mut row, &HashMap::default()) }; push_no_write(state, val); @@ -484,7 +464,7 @@ pub(crate) fn generate_swap( // preinitialized segments, so we don't need to carry out the additional checks // when get the value from memory. let (in1, log_in1) = - mem_read_gp_with_log_and_fill(1, other_addr, state, &mut row, false, &HashMap::default()); + mem_read_gp_with_log_and_fill(1, other_addr, state, &mut row, &HashMap::default()); let log_out0 = mem_write_gp_log_and_fill(2, other_addr, state, &mut row, in0); push_no_write(state, in1); @@ -555,7 +535,6 @@ fn append_shift( lookup_addr, state, &mut row, - false, &HashMap::default(), ); state.traces.push_memory(read); @@ -649,7 +628,7 @@ pub(crate) fn generate_syscall( // Even though we might be in the interpreter, `Code` is not part of the // preinitialized segments, so we don't need to carry out the additional checks // when get the value from memory. - let val = state.memory.get(address, false, &HashMap::default()); + let val = state.memory.get(address, &HashMap::default()); val.low_u32() as u8 }) .collect_vec(); @@ -752,7 +731,6 @@ pub(crate) fn generate_exit_kernel( pub(crate) fn generate_mload_general( state: &mut GenerationState, mut row: CpuColumnsView, - is_interpreter: bool, preinitialized_segments: &HashMap, ) -> Result<(), ProgramError> { let [(addr, _)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; @@ -762,7 +740,6 @@ pub(crate) fn generate_mload_general( MemoryAddress::new_bundle(addr)?, state, &mut row, - is_interpreter, preinitialized_segments, ); push_no_write(state, val); @@ -788,7 +765,6 @@ pub(crate) fn generate_mload_general( pub(crate) fn generate_mload_32bytes( state: &mut GenerationState, mut row: CpuColumnsView, - is_interpreter: bool, preinitialized_segments: &HashMap, ) -> Result<(), ProgramError> { let [(addr, _), (len, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; @@ -810,9 +786,7 @@ pub(crate) fn generate_mload_32bytes( virt: base_address.virt + i, ..base_address }; - let val = state - .memory - .get(address, is_interpreter, preinitialized_segments); + let val = state.memory.get(address, preinitialized_segments); val.low_u32() as u8 }) .collect_vec(); @@ -880,7 +854,6 @@ pub(crate) fn generate_exception>( exc_code: u8, state: &mut T, mut row: CpuColumnsView, - is_generation: bool, ) -> Result<(), ProgramError> { // TDO: se fue arriba state.fill_stack_fields(&mut row)?; @@ -916,9 +889,7 @@ pub(crate) fn generate_exception>( // Even though we might be in the interpreter, `Code` is not part of the // preinitialized segments, so we don't need to carry out the additional checks // when get the value from memory. - let val = generation_state - .memory - .get(address, false, &HashMap::default()); + let val = generation_state.memory.get(address, &HashMap::default()); val.low_u32() as u8 }) .collect_vec(); @@ -949,9 +920,7 @@ pub(crate) fn generate_exception>( // Even though we might be in the interpreter, `Code` is not part of the // preinitialized segments, so we don't need to carry out the additional checks // when get the value from memory. - let opcode = generation_state - .memory - .get(address, false, &HashMap::default()); + let opcode = generation_state.memory.get(address, &HashMap::default()); // `ArithmeticStark` range checks `mem_channels[0]`, which contains // the top of the stack, `mem_channels[1]`, which contains the new PC, diff --git a/evm_arithmetization/src/witness/transition.rs b/evm_arithmetization/src/witness/transition.rs index 1e414e0fa..e73a96cbf 100644 --- a/evm_arithmetization/src/witness/transition.rs +++ b/evm_arithmetization/src/witness/transition.rs @@ -359,7 +359,6 @@ pub(crate) trait Transition: State { ), gen_state, &mut row, - false, &HashMap::default(), ); @@ -436,7 +435,6 @@ pub(crate) trait Transition: State { ), gen_state, &mut row, - false, &HashMap::default(), ); if !should_jump || gen_state.registers.is_kernel { @@ -478,7 +476,6 @@ pub(crate) trait Transition: State { ) -> Result<(), ProgramError> { /// It may let op = self.skip_if_necessary(op)?; - let is_interpreter = !self.is_generation_state(); let preinitialized_segments = self.get_preinitialized_segments(); #[cfg(debug_assertions)] @@ -526,12 +523,9 @@ pub(crate) trait Transition: State { Operation::TernaryArithmetic(op) => { generate_ternary_arithmetic_op(op, generation_state, row)? } - Operation::KeccakGeneral => generate_keccak_general( - generation_state, - row, - is_interpreter, - &preinitialized_segments, - )?, + Operation::KeccakGeneral => { + generate_keccak_general(generation_state, row, &preinitialized_segments)? + } Operation::ProverInput => generate_prover_input(generation_state, row)?, Operation::Pop => generate_pop(generation_state, row)?, Operation::Jump => self.generate_jump(row)?, @@ -540,26 +534,18 @@ pub(crate) trait Transition: State { Operation::Jumpdest => generate_jumpdest(generation_state, row)?, Operation::GetContext => generate_get_context(generation_state, row)?, Operation::SetContext => generate_set_context(generation_state, row)?, - Operation::Mload32Bytes => generate_mload_32bytes( - generation_state, - row, - is_interpreter, - &preinitialized_segments, - )?, + Operation::Mload32Bytes => { + generate_mload_32bytes(generation_state, row, &preinitialized_segments)? + } Operation::Mstore32Bytes(n) => generate_mstore_32bytes(n, generation_state, row)?, Operation::ExitKernel => generate_exit_kernel(generation_state, row)?, - Operation::MloadGeneral => generate_mload_general( - generation_state, - row, - is_interpreter, - &preinitialized_segments, - )?, + Operation::MloadGeneral => { + generate_mload_general(generation_state, row, &preinitialized_segments)? + } Operation::MstoreGeneral => generate_mstore_general(generation_state, row)?, }; - if !self.is_generation_state() { - self.clear_traces(); - } + self.clear_traces(); Ok(()) } diff --git a/evm_arithmetization/src/witness/util.rs b/evm_arithmetization/src/witness/util.rs index 36e1b3825..4bd589feb 100644 --- a/evm_arithmetization/src/witness/util.rs +++ b/evm_arithmetization/src/witness/util.rs @@ -50,7 +50,6 @@ pub(crate) fn stack_peek( Segment::Stack, state.registers.stack_len - 1 - i, ), - false, &HashMap::default(), )) } @@ -66,7 +65,6 @@ pub(crate) fn current_context_peek( let context = state.registers.context; state.memory.get( MemoryAddress::new(context, segment, virt), - is_interpreter, preinitialized_segments, ) } @@ -121,12 +119,9 @@ pub(crate) fn mem_read_with_log( channel: MemoryChannel, address: MemoryAddress, state: &GenerationState, - is_interpreter: bool, preinitialized_segments: &HashMap, ) -> (U256, MemoryOp) { - let val = state - .memory - .get(address, is_interpreter, preinitialized_segments); + let val = state.memory.get(address, preinitialized_segments); let op = MemoryOp::new( channel, state.traces.clock(), @@ -159,13 +154,7 @@ pub(crate) fn mem_read_code_with_log_and_fill( is_interpreter: bool, preinitialized_segments: &HashMap, ) -> (u8, MemoryOp) { - let (val, op) = mem_read_with_log( - MemoryChannel::Code, - address, - state, - is_interpreter, - preinitialized_segments, - ); + let (val, op) = mem_read_with_log(MemoryChannel::Code, address, state, preinitialized_segments); let val_u8 = to_byte_checked(val); row.opcode_bits = to_bits_le(val_u8); @@ -178,14 +167,12 @@ pub(crate) fn mem_read_gp_with_log_and_fill( address: MemoryAddress, state: &GenerationState, row: &mut CpuColumnsView, - is_interpreter: bool, preinitialized_segments: &HashMap, ) -> (U256, MemoryOp) { let (val, op) = mem_read_with_log( MemoryChannel::GeneralPurpose(n), address, state, - is_interpreter, preinitialized_segments, ); let val_limbs: [u64; 4] = val.0; @@ -276,7 +263,7 @@ pub(crate) fn stack_pop_with_log_and_fill( state.registers.stack_len - 1 - i, ); - mem_read_gp_with_log_and_fill(i, address, state, row, false, &HashMap::default()) + mem_read_gp_with_log_and_fill(i, address, state, row, &HashMap::default()) } }); From 25910cf97784f98844e3c827eb79eb7da4f6f11f Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Mon, 26 Feb 2024 14:49:08 +0100 Subject: [PATCH 5/9] Remove TDO --- evm_arithmetization/src/witness/operation.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/evm_arithmetization/src/witness/operation.rs b/evm_arithmetization/src/witness/operation.rs index becfa81ec..0c34cafd6 100644 --- a/evm_arithmetization/src/witness/operation.rs +++ b/evm_arithmetization/src/witness/operation.rs @@ -855,7 +855,6 @@ pub(crate) fn generate_exception>( state: &mut T, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - // TDO: se fue arriba state.fill_stack_fields(&mut row)?; let generation_state = state.get_mut_generation_state(); if TryInto::::try_into(generation_state.registers.gas_used).is_err() { From 3e311154936c2eaaf673bc2ccee04c16650c9727 Mon Sep 17 00:00:00 2001 From: Linda Guiga Date: Tue, 27 Feb 2024 18:19:45 +0100 Subject: [PATCH 6/9] Fix handle_error --- .../src/cpu/kernel/interpreter.rs | 2 +- evm_arithmetization/src/generation/mod.rs | 2 +- evm_arithmetization/src/generation/state.rs | 24 +++++++++++++------ evm_arithmetization/src/witness/transition.rs | 1 - 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/evm_arithmetization/src/cpu/kernel/interpreter.rs b/evm_arithmetization/src/cpu/kernel/interpreter.rs index 48b80398a..4245cde04 100644 --- a/evm_arithmetization/src/cpu/kernel/interpreter.rs +++ b/evm_arithmetization/src/cpu/kernel/interpreter.rs @@ -379,7 +379,7 @@ impl Interpreter { } pub(crate) fn run(&mut self) -> Result<(), anyhow::Error> { - self.run_cpu(false)?; + self.run_cpu()?; #[cfg(debug_assertions)] { diff --git a/evm_arithmetization/src/generation/mod.rs b/evm_arithmetization/src/generation/mod.rs index 1ced63da0..2ae8d25b1 100644 --- a/evm_arithmetization/src/generation/mod.rs +++ b/evm_arithmetization/src/generation/mod.rs @@ -312,7 +312,7 @@ pub fn generate_traces, const D: usize>( } fn simulate_cpu(state: &mut GenerationState) -> anyhow::Result<()> { - state.run_cpu(true)?; + state.run_cpu()?; let pc = state.registers.program_counter; // Padding diff --git a/evm_arithmetization/src/generation/state.rs b/evm_arithmetization/src/generation/state.rs index 0de7ac527..c1e8ac404 100644 --- a/evm_arithmetization/src/generation/state.rs +++ b/evm_arithmetization/src/generation/state.rs @@ -87,7 +87,11 @@ pub(crate) trait State { /// Simulates a CPU. It only generates the traces if the `State` is a /// `GenerationState`. Otherwise, it simply simulates all ooperations. - fn run_cpu(&mut self, is_generation: bool) -> anyhow::Result<()> { + fn run_cpu(&mut self) -> anyhow::Result<()> + where + Self: Transition, + Self: Sized, + { let halt_offsets = self.get_halt_offsets(); loop { @@ -104,9 +108,7 @@ pub(crate) trait State { return Ok(()); } } else { - if is_generation { - log::info!("CPU halted after {} cycles", self.get_clock()); - } + log::info!("CPU halted after {} cycles", self.get_clock()); return Ok(()); } } @@ -118,7 +120,11 @@ pub(crate) trait State { Ok(()) } - fn handle_error(&mut self, err: ProgramError) -> anyhow::Result<()> { + fn handle_error(&mut self, err: ProgramError) -> anyhow::Result<()> + where + Self: Transition, + Self: Sized, + { let exc_code: u8 = match err { ProgramError::OutOfGas => 0, ProgramError::InvalidOpcode => 1, @@ -132,7 +138,7 @@ pub(crate) trait State { let checkpoint = self.checkpoint(); let (row, _) = self.base_row(); - generate_exception(exc_code, self.get_mut_generation_state(), row); + generate_exception(exc_code, self, row); self.clear_traces(); @@ -141,7 +147,11 @@ pub(crate) trait State { Ok(()) } - fn transition(&mut self) -> anyhow::Result<()> { + fn transition(&mut self) -> anyhow::Result<()> + where + Self: Transition, + Self: Sized, + { let checkpoint = self.checkpoint(); let result = self.try_perform_instruction(); diff --git a/evm_arithmetization/src/witness/transition.rs b/evm_arithmetization/src/witness/transition.rs index e73a96cbf..1ac3fe6db 100644 --- a/evm_arithmetization/src/witness/transition.rs +++ b/evm_arithmetization/src/witness/transition.rs @@ -474,7 +474,6 @@ pub(crate) trait Transition: State { opcode: u8, row: CpuColumnsView, ) -> Result<(), ProgramError> { - /// It may let op = self.skip_if_necessary(op)?; let preinitialized_segments = self.get_preinitialized_segments(); From bc1dc1d590a0c6c6bf5a94ac91d130317077102c Mon Sep 17 00:00:00 2001 From: Linda Guiga Date: Tue, 27 Feb 2024 19:09:20 +0100 Subject: [PATCH 7/9] Move preinitialized_segments from the interpreter to MemoryState --- .../src/cpu/kernel/interpreter.rs | 63 ++++++++++--------- .../src/cpu/kernel/tests/core/access_lists.rs | 57 +++++++++-------- .../src/cpu/kernel/tests/ecc/curve_ops.rs | 13 ++-- .../src/generation/prover_input.rs | 49 +++++++-------- evm_arithmetization/src/generation/state.rs | 34 +++++----- evm_arithmetization/src/witness/memory.rs | 52 +++++++++++---- evm_arithmetization/src/witness/operation.rs | 50 +++++---------- evm_arithmetization/src/witness/transition.rs | 26 ++------ evm_arithmetization/src/witness/util.rs | 35 +++-------- 9 files changed, 177 insertions(+), 202 deletions(-) diff --git a/evm_arithmetization/src/cpu/kernel/interpreter.rs b/evm_arithmetization/src/cpu/kernel/interpreter.rs index 4245cde04..ae62dd062 100644 --- a/evm_arithmetization/src/cpu/kernel/interpreter.rs +++ b/evm_arithmetization/src/cpu/kernel/interpreter.rs @@ -57,7 +57,6 @@ pub(crate) struct Interpreter { pub(crate) opcode_count: [usize; 0x100], memops: Vec, jumpdest_table: HashMap>, - pub(crate) preinitialized_segments: HashMap, pub(crate) is_jumpdest_analysis: bool, pub(crate) clock: usize, } @@ -174,7 +173,6 @@ impl Interpreter { opcode_count: [0; 256], memops: vec![], jumpdest_table: HashMap::new(), - preinitialized_segments: HashMap::default(), is_jumpdest_analysis: false, clock: 0, }; @@ -206,7 +204,6 @@ impl Interpreter { opcode_count: [0; 256], memops: vec![], jumpdest_table: HashMap::new(), - preinitialized_segments: HashMap::new(), is_jumpdest_analysis: true, clock: 0, } @@ -240,8 +237,7 @@ impl Interpreter { let preinit_trie_data_segment = MemorySegmentState { content: trie_data.iter().map(|&elt| Some(elt)).collect::>(), }; - self.preinitialized_segments - .insert(Segment::TrieData, preinit_trie_data_segment); + self.insert_preinitialized_segment(Segment::TrieData, preinit_trie_data_segment); // Update the RLP and withdrawal prover inputs. let rlp_prover_inputs = @@ -356,10 +352,7 @@ impl Interpreter { match kind { MemoryOpKind::Read => { if self.generation_state.memory.get_option(address).is_none() { - if !self - .preinitialized_segments - .contains_key(&Segment::all()[address.segment]) - { + if !self.is_preinitialized_segment(address.segment) { assert_eq!( value, 0.into(), @@ -477,7 +470,7 @@ impl Interpreter { } pub(crate) fn get_memory_segment(&self, segment: Segment) -> Vec { - if self.preinitialized_segments.contains_key(&segment) { + if self.is_preinitialized_segment(segment.unscale()) { let total_len = self.generation_state.memory.contexts[0].segments[segment.unscale()] .content .len(); @@ -490,7 +483,14 @@ impl Interpreter { }) .collect::>() }; - let mut res = get_vals(&self.preinitialized_segments.get(&segment).unwrap().content); + let mut res = get_vals( + &self + .generation_state + .memory + .get_preinitialized_segment(segment) + .unwrap() + .content, + ); let init_len = res.len(); res.extend(&get_vals( &self.generation_state.memory.contexts[0].segments[segment.unscale()].content @@ -617,10 +617,10 @@ impl Interpreter { pub(crate) fn extract_kernel_memory(self, segment: Segment, range: Range) -> Vec { let mut output: Vec = Vec::with_capacity(range.end); for i in range { - let term = self.generation_state.memory.get( - MemoryAddress::new(0, segment, i), - &self.preinitialized_segments, - ); + let term = self + .generation_state + .memory + .get(MemoryAddress::new(0, segment, i)); output.push(term); } output @@ -676,14 +676,11 @@ impl Interpreter { // Even though we are in the interpreter, `JumpdestBits` is not part of the // preinitialized segments, so we don't need to carry out the additional checks // when get the value from memory. - self.generation_state.memory.get( - MemoryAddress { - context: self.context(), - segment: Segment::JumpdestBits.unscale(), - virt: offset, - }, - &HashMap::default(), - ) + self.generation_state.memory.get(MemoryAddress { + context: self.context(), + segment: Segment::JumpdestBits.unscale(), + virt: offset, + }) } else { 0.into() } @@ -762,6 +759,18 @@ impl State for Interpreter { } } + fn insert_preinitialized_segment(&mut self, segment: Segment, values: MemorySegmentState) { + self.generation_state + .memory + .insert_preinitialized_segment(segment, values); + } + + fn is_preinitialized_segment(&self, segment: usize) -> bool { + self.generation_state + .memory + .is_preinitialized_segment(segment) + } + fn incr_gas(&mut self, n: u64) { self.generation_state.incr_gas(n); } @@ -779,9 +788,7 @@ impl State for Interpreter { } fn get_from_memory(&mut self, address: MemoryAddress) -> U256 { - self.generation_state - .memory - .get(address, &self.preinitialized_segments) + self.generation_state.memory.get(address) } fn get_mut_generation_state(&mut self) -> &mut GenerationState { @@ -904,10 +911,6 @@ impl Transition for Interpreter { } } - fn get_preinitialized_segments(&self) -> HashMap { - self.preinitialized_segments.clone() - } - fn fill_stack_fields( &mut self, row: &mut crate::cpu::columns::CpuColumnsView, diff --git a/evm_arithmetization/src/cpu/kernel/tests/core/access_lists.rs b/evm_arithmetization/src/cpu/kernel/tests/core/access_lists.rs index f5e3168a8..329dcd1ed 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/core/access_lists.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/core/access_lists.rs @@ -28,10 +28,11 @@ fn test_init_access_lists() -> Result<()> { let acc_addr_list: Vec = (0..2) .map(|i| { - interpreter.generation_state.memory.get( - MemoryAddress::new(0, Segment::AccessedAddresses, i), - &HashMap::default(), - ) + interpreter.generation_state.memory.get(MemoryAddress::new( + 0, + Segment::AccessedAddresses, + i, + )) }) .collect(); assert_eq!( @@ -41,10 +42,11 @@ fn test_init_access_lists() -> Result<()> { let acc_storage_keys: Vec = (0..4) .map(|i| { - interpreter.generation_state.memory.get( - MemoryAddress::new(0, Segment::AccessedStorageKeys, i), - &HashMap::default(), - ) + interpreter.generation_state.memory.get(MemoryAddress::new( + 0, + Segment::AccessedStorageKeys, + i, + )) }) .collect(); @@ -111,10 +113,10 @@ fn test_insert_address() -> Result<()> { interpreter.run()?; assert_eq!(interpreter.stack(), &[U256::one()]); assert_eq!( - interpreter.generation_state.memory.get( - MemoryAddress::new_bundle(U256::from(AccessedAddressesLen as usize)).unwrap(), - &HashMap::default(), - ), + interpreter + .generation_state + .memory + .get(MemoryAddress::new_bundle(U256::from(AccessedAddressesLen as usize)).unwrap(),), U256::from(Segment::AccessedAddresses as usize + 4) ); @@ -167,7 +169,7 @@ fn test_insert_accessed_addresses() -> Result<()> { assert_eq!( interpreter.generation_state.memory.get( MemoryAddress::new_bundle(U256::from(AccessedAddressesLen as usize)).unwrap(), - &HashMap::default(), + ), U256::from(offset + 2 * (n + 1)) ); @@ -181,17 +183,18 @@ fn test_insert_accessed_addresses() -> Result<()> { interpreter.run()?; assert_eq!(interpreter.stack(), &[U256::one()]); assert_eq!( - interpreter.generation_state.memory.get( - MemoryAddress::new_bundle(U256::from(AccessedAddressesLen as usize)).unwrap(), - &HashMap::default(), - ), + interpreter + .generation_state + .memory + .get(MemoryAddress::new_bundle(U256::from(AccessedAddressesLen as usize)).unwrap(),), U256::from(offset + 2 * (n + 2)) ); assert_eq!( - interpreter.generation_state.memory.get( - MemoryAddress::new(0, AccessedAddresses, 2 * (n + 1)), - &HashMap::default(), - ), + interpreter.generation_state.memory.get(MemoryAddress::new( + 0, + AccessedAddresses, + 2 * (n + 1) + ),), U256::from(addr_not_in_list.0.as_slice()) ); @@ -252,8 +255,7 @@ fn test_insert_accessed_storage_keys() -> Result<()> { assert_eq!(interpreter.pop().unwrap(), value); assert_eq!( interpreter.generation_state.memory.get( - MemoryAddress::new_bundle(U256::from(AccessedStorageKeysLen as usize)).unwrap(), - &HashMap::default(), + MemoryAddress::new_bundle(U256::from(AccessedStorageKeysLen as usize)).unwrap(), ), U256::from(offset + 4 * (n + 1)) ); @@ -273,29 +275,26 @@ fn test_insert_accessed_storage_keys() -> Result<()> { ); assert_eq!( interpreter.generation_state.memory.get( - MemoryAddress::new_bundle(U256::from(AccessedStorageKeysLen as usize)).unwrap(), - &HashMap::default() + MemoryAddress::new_bundle(U256::from(AccessedStorageKeysLen as usize)).unwrap(), ), U256::from(offset + 4 * (n + 2)) ); assert_eq!( interpreter.generation_state.memory.get( - MemoryAddress::new(0, AccessedStorageKeys, 4 * (n + 1)), - &HashMap::default(), + MemoryAddress::new(0, AccessedStorageKeys, 4 * (n + 1)), ), U256::from(storage_key_not_in_list.0 .0.as_slice()) ); assert_eq!( interpreter.generation_state.memory.get( MemoryAddress::new(0, AccessedStorageKeys, 4 * (n + 1) + 1), - &HashMap::default() + ), storage_key_not_in_list.1 ); assert_eq!( interpreter.generation_state.memory.get( MemoryAddress::new(0, AccessedStorageKeys, 4 * (n + 1) + 2), - &HashMap::default() ), storage_key_not_in_list.2 ); diff --git a/evm_arithmetization/src/cpu/kernel/tests/ecc/curve_ops.rs b/evm_arithmetization/src/cpu/kernel/tests/ecc/curve_ops.rs index 4540be61e..5bf79f2f2 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/ecc/curve_ops.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/ecc/curve_ops.rs @@ -210,14 +210,11 @@ mod bn { let mut computed_table = Vec::new(); for i in 0..32 { - computed_table.push(int.generation_state.memory.get( - MemoryAddress { - context: 0, - segment: Segment::BnTableQ.unscale(), - virt: i, - }, - &HashMap::default(), - )); + computed_table.push(int.generation_state.memory.get(MemoryAddress { + context: 0, + segment: Segment::BnTableQ.unscale(), + virt: i, + })); } let table = u256ify([ diff --git a/evm_arithmetization/src/generation/prover_input.rs b/evm_arithmetization/src/generation/prover_input.rs index 005fa4b12..41c0454ee 100644 --- a/evm_arithmetization/src/generation/prover_input.rs +++ b/evm_arithmetization/src/generation/prover_input.rs @@ -124,9 +124,9 @@ impl GenerationState { let ptr = stack_peek(self, 11 - n).map(u256_to_usize)??; let f: [U256; 12] = match field { - Bn254Base => std::array::from_fn(|i| { - current_context_peek(self, BnPairing, ptr + i, false, &HashMap::default()) - }), + Bn254Base => { + std::array::from_fn(|i| current_context_peek(self, BnPairing, ptr + i, false)) + } _ => todo!(), }; Ok(field.field_extension_inverse(n, f)) @@ -421,24 +421,21 @@ impl GenerationState { let code_len = self.get_code_len(context)?; let code = (0..code_len) .map(|i| { - u256_to_u8(self.memory.get( - MemoryAddress::new(context, Segment::Code, i), - &HashMap::default(), - )) + u256_to_u8( + self.memory + .get(MemoryAddress::new(context, Segment::Code, i)), + ) }) .collect::, _>>()?; Ok(code) } fn get_code_len(&self, context: usize) -> Result { - let code_len = u256_to_usize(self.memory.get( - MemoryAddress::new( - context, - Segment::ContextMetadata, - ContextMetadata::CodeSize.unscale(), - ), - &HashMap::default(), - ))?; + let code_len = u256_to_usize(self.memory.get(MemoryAddress::new( + context, + Segment::ContextMetadata, + ContextMetadata::CodeSize.unscale(), + )))?; Ok(code_len) } @@ -473,10 +470,11 @@ impl GenerationState { } fn get_global_metadata(&self, data: GlobalMetadata) -> U256 { - self.memory.get( - MemoryAddress::new(0, Segment::GlobalMetadata, data.unscale()), - &HashMap::default(), - ) + self.memory.get(MemoryAddress::new( + 0, + Segment::GlobalMetadata, + data.unscale(), + )) } pub(crate) fn get_storage_keys_access_list(&self) -> Result { @@ -484,14 +482,11 @@ impl GenerationState { // virtual address in the segment. In order to get the length we need // to substract Segment::AccessedStorageKeys as usize let acc_storage_len = u256_to_usize( - self.memory.get( - MemoryAddress::new( - 0, - Segment::GlobalMetadata, - GlobalMetadata::AccessedStorageKeysLen.unscale(), - ), - &HashMap::default(), - ) - Segment::AccessedStorageKeys as usize, + self.memory.get(MemoryAddress::new( + 0, + Segment::GlobalMetadata, + GlobalMetadata::AccessedStorageKeysLen.unscale(), + )) - Segment::AccessedStorageKeys as usize, )?; AccList::from_mem_and_segment( &self.memory.contexts[0].segments[Segment::AccessedStorageKeys.unscale()].content diff --git a/evm_arithmetization/src/generation/state.rs b/evm_arithmetization/src/generation/state.rs index c1e8ac404..8e832fe1b 100644 --- a/evm_arithmetization/src/generation/state.rs +++ b/evm_arithmetization/src/generation/state.rs @@ -19,8 +19,8 @@ use crate::memory::segments::Segment; use crate::util::u256_to_usize; use crate::witness::errors::ProgramError; use crate::witness::memory::MemoryChannel::GeneralPurpose; -use crate::witness::memory::MemoryOpKind; use crate::witness::memory::{MemoryAddress, MemoryOp, MemoryState}; +use crate::witness::memory::{MemoryOpKind, MemorySegmentState}; use crate::witness::operation::{generate_exception, Operation}; use crate::witness::state::RegistersState; use crate::witness::traces::{TraceCheckpoint, Traces}; @@ -85,6 +85,12 @@ pub(crate) trait State { /// Return the offsets at which execution must halt fn get_halt_offsets(&self) -> Vec; + /// Inserts a preinitialized segment, given as a [Segment], + /// into the `preinitialized_segments` memory field. + fn insert_preinitialized_segment(&mut self, segment: Segment, values: MemorySegmentState); + + fn is_preinitialized_segment(&self, segment: usize) -> bool; + /// Simulates a CPU. It only generates the traces if the `State` is a /// `GenerationState`. Otherwise, it simply simulates all ooperations. fn run_cpu(&mut self) -> anyhow::Result<()> @@ -328,8 +334,7 @@ impl GenerationState { let returndata_offset = ContextMetadata::ReturndataSize.unscale(); let returndata_size_addr = MemoryAddress::new(ctx, Segment::ContextMetadata, returndata_offset); - let returndata_size = - u256_to_usize(self.memory.get(returndata_size_addr, &HashMap::default()))?; + let returndata_size = u256_to_usize(self.memory.get(returndata_size_addr))?; let code = self.memory.contexts[ctx].segments[Segment::Returndata.unscale()].content [..returndata_size] .iter() @@ -383,6 +388,16 @@ impl State for GenerationState { } } + fn insert_preinitialized_segment(&mut self, segment: Segment, values: MemorySegmentState) { + panic!( + "A `GenerationState` cannot have a nonempty `preinitialized_segment` field in memory." + ) + } + + fn is_preinitialized_segment(&self, segment: usize) -> bool { + false + } + /// Increments the `gas_used` register by a value `n`. fn incr_gas(&mut self, n: u64) { self.registers.gas_used += n; @@ -405,7 +420,7 @@ impl State for GenerationState { /// Returns the value stored at address `address` in a `State`. fn get_from_memory(&mut self, address: MemoryAddress) -> U256 { - self.memory.get(address, &HashMap::default()) + self.memory.get(address) } /// Returns a mutable `GenerationState` from a `State`. @@ -413,11 +428,6 @@ impl State for GenerationState { self } - // /// Returns true if a `State` is a `GenerationState` and false otherwise. - // fn is_generation_state(&mut self) -> bool { - // true - // } - /// Increments the clock of an `Interpreter`'s clock. fn incr_interpreter_clock(&mut self) {} @@ -504,12 +514,6 @@ impl Transition for GenerationState { Ok(op) } - fn get_preinitialized_segments( - &self, - ) -> HashMap { - HashMap::default() - } - fn generate_jumpdest_analysis(&mut self, dst: usize) -> bool { false } diff --git a/evm_arithmetization/src/witness/memory.rs b/evm_arithmetization/src/witness/memory.rs index c83caf197..4f5dc8851 100644 --- a/evm_arithmetization/src/witness/memory.rs +++ b/evm_arithmetization/src/witness/memory.rs @@ -165,6 +165,7 @@ impl MemoryOp { #[derive(Clone, Debug)] pub(crate) struct MemoryState { pub(crate) contexts: Vec, + preinitialized_segments: HashMap, } impl MemoryState { @@ -220,20 +221,22 @@ impl MemoryState { Some(val) } - pub(crate) fn get( - &self, - address: MemoryAddress, - preinitialized_segments: &HashMap, - ) -> U256 { + pub(crate) fn get(&self, address: MemoryAddress) -> U256 { match self.get_option(address) { Some(val) => val, None => { let segment = Segment::all()[address.segment]; let offset = address.virt; - if preinitialized_segments.contains_key(&segment) - && offset < preinitialized_segments.get(&segment).unwrap().content.len() + if self.preinitialized_segments.contains_key(&segment) + && offset + < self + .preinitialized_segments + .get(&segment) + .unwrap() + .content + .len() { - preinitialized_segments.get(&segment).unwrap().content[offset].unwrap() + self.preinitialized_segments.get(&segment).unwrap().content[offset].unwrap() } else { 0.into() } @@ -268,10 +271,34 @@ impl MemoryState { // These fields are already scaled by their respective segment. pub(crate) fn read_global_metadata(&self, field: GlobalMetadata) -> U256 { - self.get( - MemoryAddress::new_bundle(U256::from(field as usize)).unwrap(), - &HashMap::default(), - ) + self.get(MemoryAddress::new_bundle(U256::from(field as usize)).unwrap()) + } + + /// Inserts a segment and its preinitialized values in + /// `preinitialized_segments`. + pub(crate) fn insert_preinitialized_segment( + &mut self, + segment: Segment, + values: MemorySegmentState, + ) { + self.preinitialized_segments.insert(segment, values); + } + + /// Returns a boolean which indicates whether a segment (given as a usize) + /// is part of the `preinitialize_segments`. + pub(crate) fn is_preinitialized_segment(&self, segment: usize) -> bool { + if let Some(seg) = Segment::all().get(segment) { + self.preinitialized_segments.contains_key(&seg) + } else { + false + } + } + + pub(crate) fn get_preinitialized_segment( + &self, + segment: Segment, + ) -> Option<&MemorySegmentState> { + self.preinitialized_segments.get(&segment) } } @@ -280,6 +307,7 @@ impl Default for MemoryState { Self { // We start with an initial context for the kernel. contexts: vec![MemoryContextState::default()], + preinitialized_segments: HashMap::default(), } } } diff --git a/evm_arithmetization/src/witness/operation.rs b/evm_arithmetization/src/witness/operation.rs index 0c34cafd6..cdb7b3113 100644 --- a/evm_arithmetization/src/witness/operation.rs +++ b/evm_arithmetization/src/witness/operation.rs @@ -137,7 +137,6 @@ pub(crate) fn generate_ternary_arithmetic_op( pub(crate) fn generate_keccak_general( state: &mut GenerationState, mut row: CpuColumnsView, - preinitialized_segments: &HashMap, ) -> Result<(), ProgramError> { let [(addr, _), (len, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; let len = u256_to_usize(len)?; @@ -149,7 +148,7 @@ pub(crate) fn generate_keccak_general( virt: base_address.virt.saturating_add(i), ..base_address }; - let val = state.memory.get(address, preinitialized_segments); + let val = state.memory.get(address); val.low_u32() as u8 }) .collect_vec(); @@ -294,7 +293,7 @@ pub(crate) fn generate_set_context( // Even though we might be in the interpreter, `Stack` is not part of the // preinitialized segments, so we don't need to carry out the additional checks // when get the value from memory. - mem_read_with_log(GeneralPurpose(2), new_sp_addr, state, &HashMap::default()) + mem_read_with_log(GeneralPurpose(2), new_sp_addr, state) }; // If the new stack isn't empty, read stack_top from memory. @@ -317,7 +316,7 @@ pub(crate) fn generate_set_context( // preinitialized segments, so we don't need to carry out the additional checks // when get the value from memory. let (new_top, log_read_new_top) = - mem_read_gp_with_log_and_fill(2, new_top_addr, state, &mut row, &HashMap::default()); + mem_read_gp_with_log_and_fill(2, new_top_addr, state, &mut row); state.registers.stack_top = new_top; state.traces.push_memory(log_read_new_top); } else { @@ -354,13 +353,10 @@ pub(crate) fn generate_push( .map(|i| { state .memory - .get( - MemoryAddress { - virt: base_address.virt + i, - ..base_address - }, - &HashMap::default(), - ) + .get(MemoryAddress { + virt: base_address.virt + i, + ..base_address + }) .low_u32() as u8 }) .collect_vec(); @@ -438,7 +434,7 @@ pub(crate) fn generate_dup( // Even though we might be in the interpreter, `Stack` is not part of the // preinitialized segments, so we don't need to carry out the additional checks // when get the value from memory. - mem_read_gp_with_log_and_fill(2, other_addr, state, &mut row, &HashMap::default()) + mem_read_gp_with_log_and_fill(2, other_addr, state, &mut row) }; push_no_write(state, val); @@ -463,8 +459,7 @@ pub(crate) fn generate_swap( // Even though we might be in the interpreter, `Stack` is not part of the // preinitialized segments, so we don't need to carry out the additional checks // when get the value from memory. - let (in1, log_in1) = - mem_read_gp_with_log_and_fill(1, other_addr, state, &mut row, &HashMap::default()); + let (in1, log_in1) = mem_read_gp_with_log_and_fill(1, other_addr, state, &mut row); let log_out0 = mem_write_gp_log_and_fill(2, other_addr, state, &mut row, in0); push_no_write(state, in1); @@ -530,13 +525,7 @@ fn append_shift( // Even though we might be in the interpreter, `ShiftTable` is not part of the // preinitialized segments, so we don't need to carry out the additional checks // when get the value from memory. - let (_, read) = mem_read_gp_with_log_and_fill( - LOOKUP_CHANNEL, - lookup_addr, - state, - &mut row, - &HashMap::default(), - ); + let (_, read) = mem_read_gp_with_log_and_fill(LOOKUP_CHANNEL, lookup_addr, state, &mut row); state.traces.push_memory(read); } else { // The shift constraints still expect the address to be set, even though no read @@ -628,7 +617,7 @@ pub(crate) fn generate_syscall( // Even though we might be in the interpreter, `Code` is not part of the // preinitialized segments, so we don't need to carry out the additional checks // when get the value from memory. - let val = state.memory.get(address, &HashMap::default()); + let val = state.memory.get(address); val.low_u32() as u8 }) .collect_vec(); @@ -731,17 +720,11 @@ pub(crate) fn generate_exit_kernel( pub(crate) fn generate_mload_general( state: &mut GenerationState, mut row: CpuColumnsView, - preinitialized_segments: &HashMap, ) -> Result<(), ProgramError> { let [(addr, _)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; - let (val, log_read) = mem_read_gp_with_log_and_fill( - 1, - MemoryAddress::new_bundle(addr)?, - state, - &mut row, - preinitialized_segments, - ); + let (val, log_read) = + mem_read_gp_with_log_and_fill(1, MemoryAddress::new_bundle(addr)?, state, &mut row); push_no_write(state, val); // Because MLOAD_GENERAL performs 1 pop and 1 push, it does not make use of the @@ -765,7 +748,6 @@ pub(crate) fn generate_mload_general( pub(crate) fn generate_mload_32bytes( state: &mut GenerationState, mut row: CpuColumnsView, - preinitialized_segments: &HashMap, ) -> Result<(), ProgramError> { let [(addr, _), (len, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; let len = u256_to_usize(len)?; @@ -786,7 +768,7 @@ pub(crate) fn generate_mload_32bytes( virt: base_address.virt + i, ..base_address }; - let val = state.memory.get(address, preinitialized_segments); + let val = state.memory.get(address); val.low_u32() as u8 }) .collect_vec(); @@ -888,7 +870,7 @@ pub(crate) fn generate_exception>( // Even though we might be in the interpreter, `Code` is not part of the // preinitialized segments, so we don't need to carry out the additional checks // when get the value from memory. - let val = generation_state.memory.get(address, &HashMap::default()); + let val = generation_state.memory.get(address); val.low_u32() as u8 }) .collect_vec(); @@ -919,7 +901,7 @@ pub(crate) fn generate_exception>( // Even though we might be in the interpreter, `Code` is not part of the // preinitialized segments, so we don't need to carry out the additional checks // when get the value from memory. - let opcode = generation_state.memory.get(address, &HashMap::default()); + let opcode = generation_state.memory.get(address); // `ArithmeticStark` range checks `mem_channels[0]`, which contains // the top of the stack, `mem_channels[1]`, which contains the new PC, diff --git a/evm_arithmetization/src/witness/transition.rs b/evm_arithmetization/src/witness/transition.rs index 1ac3fe6db..cb375fadd 100644 --- a/evm_arithmetization/src/witness/transition.rs +++ b/evm_arithmetization/src/witness/transition.rs @@ -40,8 +40,7 @@ pub(crate) fn read_code_memory( row.code_context = F::from_canonical_usize(code_context); let address = MemoryAddress::new(code_context, Segment::Code, state.registers.program_counter); - let (opcode, mem_log) = - mem_read_code_with_log_and_fill(address, state, row, false, &HashMap::default()); + let (opcode, mem_log) = mem_read_code_with_log_and_fill(address, state, row, false); state.traces.push_memory(mem_log); @@ -347,9 +346,6 @@ pub(crate) trait Transition: State { if !self.generate_jumpdest_analysis(dst as usize) { let gen_state = self.get_mut_generation_state(); - // Even though we might be in the interpreter, `JumpdestBits` is not part - // preinitialized segments, so we don't need to carry out the additional checks - // when get the value from memory. let (jumpdest_bit, jumpdest_bit_log) = mem_read_gp_with_log_and_fill( NUM_GP_CHANNELS - 1, MemoryAddress::new( @@ -359,7 +355,6 @@ pub(crate) trait Transition: State { ), gen_state, &mut row, - &HashMap::default(), ); row.mem_channels[1].value[0] = F::ONE; @@ -423,9 +418,6 @@ pub(crate) trait Transition: State { } let gen_state = self.get_mut_generation_state(); - // Even though we might be in the interpreter, `JumpdestBits` is not part of the - // preinitialized segments, so we don't need to carry out the additional checks - // when get the value from memory. let (jumpdest_bit, jumpdest_bit_log) = mem_read_gp_with_log_and_fill( NUM_GP_CHANNELS - 1, MemoryAddress::new( @@ -435,7 +427,6 @@ pub(crate) trait Transition: State { ), gen_state, &mut row, - &HashMap::default(), ); if !should_jump || gen_state.registers.is_kernel { // Don't actually do the read, just set the address, etc. @@ -466,8 +457,6 @@ pub(crate) trait Transition: State { /// Skips the following instructions for some specific labels fn skip_if_necessary(&mut self, op: Operation) -> Result; - fn get_preinitialized_segments(&self) -> HashMap; - fn perform_op( &mut self, op: Operation, @@ -475,7 +464,6 @@ pub(crate) trait Transition: State { row: CpuColumnsView, ) -> Result<(), ProgramError> { let op = self.skip_if_necessary(op)?; - let preinitialized_segments = self.get_preinitialized_segments(); #[cfg(debug_assertions)] if !self.get_registers().is_kernel { @@ -522,9 +510,7 @@ pub(crate) trait Transition: State { Operation::TernaryArithmetic(op) => { generate_ternary_arithmetic_op(op, generation_state, row)? } - Operation::KeccakGeneral => { - generate_keccak_general(generation_state, row, &preinitialized_segments)? - } + Operation::KeccakGeneral => generate_keccak_general(generation_state, row)?, Operation::ProverInput => generate_prover_input(generation_state, row)?, Operation::Pop => generate_pop(generation_state, row)?, Operation::Jump => self.generate_jump(row)?, @@ -533,14 +519,10 @@ pub(crate) trait Transition: State { Operation::Jumpdest => generate_jumpdest(generation_state, row)?, Operation::GetContext => generate_get_context(generation_state, row)?, Operation::SetContext => generate_set_context(generation_state, row)?, - Operation::Mload32Bytes => { - generate_mload_32bytes(generation_state, row, &preinitialized_segments)? - } + Operation::Mload32Bytes => generate_mload_32bytes(generation_state, row)?, Operation::Mstore32Bytes(n) => generate_mstore_32bytes(n, generation_state, row)?, Operation::ExitKernel => generate_exit_kernel(generation_state, row)?, - Operation::MloadGeneral => { - generate_mload_general(generation_state, row, &preinitialized_segments)? - } + Operation::MloadGeneral => generate_mload_general(generation_state, row)?, Operation::MstoreGeneral => generate_mstore_general(generation_state, row)?, }; diff --git a/evm_arithmetization/src/witness/util.rs b/evm_arithmetization/src/witness/util.rs index 4bd589feb..c29cb280c 100644 --- a/evm_arithmetization/src/witness/util.rs +++ b/evm_arithmetization/src/witness/util.rs @@ -44,14 +44,11 @@ pub(crate) fn stack_peek( return Ok(state.registers.stack_top); } - Ok(state.memory.get( - MemoryAddress::new( - state.registers.context, - Segment::Stack, - state.registers.stack_len - 1 - i, - ), - &HashMap::default(), - )) + Ok(state.memory.get(MemoryAddress::new( + state.registers.context, + Segment::Stack, + state.registers.stack_len - 1 - i, + ))) } /// Peek at kernel at specified segment and address @@ -60,13 +57,9 @@ pub(crate) fn current_context_peek( segment: Segment, virt: usize, is_interpreter: bool, - preinitialized_segments: &HashMap, ) -> U256 { let context = state.registers.context; - state.memory.get( - MemoryAddress::new(context, segment, virt), - preinitialized_segments, - ) + state.memory.get(MemoryAddress::new(context, segment, virt)) } pub(crate) fn fill_channel_with_value(row: &mut CpuColumnsView, n: usize, val: U256) { @@ -119,9 +112,8 @@ pub(crate) fn mem_read_with_log( channel: MemoryChannel, address: MemoryAddress, state: &GenerationState, - preinitialized_segments: &HashMap, ) -> (U256, MemoryOp) { - let val = state.memory.get(address, preinitialized_segments); + let val = state.memory.get(address); let op = MemoryOp::new( channel, state.traces.clock(), @@ -152,9 +144,8 @@ pub(crate) fn mem_read_code_with_log_and_fill( state: &GenerationState, row: &mut CpuColumnsView, is_interpreter: bool, - preinitialized_segments: &HashMap, ) -> (u8, MemoryOp) { - let (val, op) = mem_read_with_log(MemoryChannel::Code, address, state, preinitialized_segments); + let (val, op) = mem_read_with_log(MemoryChannel::Code, address, state); let val_u8 = to_byte_checked(val); row.opcode_bits = to_bits_le(val_u8); @@ -167,14 +158,8 @@ pub(crate) fn mem_read_gp_with_log_and_fill( address: MemoryAddress, state: &GenerationState, row: &mut CpuColumnsView, - preinitialized_segments: &HashMap, ) -> (U256, MemoryOp) { - let (val, op) = mem_read_with_log( - MemoryChannel::GeneralPurpose(n), - address, - state, - preinitialized_segments, - ); + let (val, op) = mem_read_with_log(MemoryChannel::GeneralPurpose(n), address, state); let val_limbs: [u64; 4] = val.0; let channel = &mut row.mem_channels[n]; @@ -263,7 +248,7 @@ pub(crate) fn stack_pop_with_log_and_fill( state.registers.stack_len - 1 - i, ); - mem_read_gp_with_log_and_fill(i, address, state, row, &HashMap::default()) + mem_read_gp_with_log_and_fill(i, address, state, row) } }); From a1d1fe082a108a7c617f86ab73077c83967631ae Mon Sep 17 00:00:00 2001 From: Linda Guiga Date: Tue, 27 Feb 2024 19:11:24 +0100 Subject: [PATCH 8/9] Clippy --- evm_arithmetization/src/witness/memory.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evm_arithmetization/src/witness/memory.rs b/evm_arithmetization/src/witness/memory.rs index 4f5dc8851..04123addd 100644 --- a/evm_arithmetization/src/witness/memory.rs +++ b/evm_arithmetization/src/witness/memory.rs @@ -288,7 +288,7 @@ impl MemoryState { /// is part of the `preinitialize_segments`. pub(crate) fn is_preinitialized_segment(&self, segment: usize) -> bool { if let Some(seg) = Segment::all().get(segment) { - self.preinitialized_segments.contains_key(&seg) + self.preinitialized_segments.contains_key(seg) } else { false } From 7bfae89dbafe785d9c4c1858db4bd9336eccd548 Mon Sep 17 00:00:00 2001 From: Linda Guiga Date: Tue, 27 Feb 2024 19:18:16 +0100 Subject: [PATCH 9/9] Apply get_halt_context comment --- evm_arithmetization/src/generation/state.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/evm_arithmetization/src/generation/state.rs b/evm_arithmetization/src/generation/state.rs index 8e832fe1b..cc4c3adbc 100644 --- a/evm_arithmetization/src/generation/state.rs +++ b/evm_arithmetization/src/generation/state.rs @@ -74,7 +74,9 @@ pub(crate) trait State { /// Returns the current context. fn get_context(&mut self) -> usize; - fn get_halt_context(&mut self) -> Option; + fn get_halt_context(&mut self) -> Option { + None + } /// Returns the content of a the `KernelGeneral` segment of a `State`. fn mem_get_kernel_content(&self) -> Vec>; @@ -450,10 +452,6 @@ impl State for GenerationState { self.registers.context } - fn get_halt_context(&mut self) -> Option { - None - } - /// Returns the content of a the `KernelGeneral` segment of a `State`. fn mem_get_kernel_content(&self) -> Vec> { self.memory.contexts[0].segments[Segment::KernelGeneral.unscale()]