diff --git a/evm_arithmetization/src/cpu/kernel/interpreter.rs b/evm_arithmetization/src/cpu/kernel/interpreter.rs index 0c16898b1..76958c7b6 100644 --- a/evm_arithmetization/src/cpu/kernel/interpreter.rs +++ b/evm_arithmetization/src/cpu/kernel/interpreter.rs @@ -19,14 +19,18 @@ use crate::generation::rlp::all_rlp_prover_inputs_reversed; use crate::generation::state::{ all_withdrawals_prover_inputs_reversed, GenerationState, GenerationStateCheckpoint, }; -use crate::generation::{run_cpu, GenerationInputs, State}; +use crate::generation::{state::State, GenerationInputs}; use crate::memory::segments::Segment; use crate::util::h2u; use crate::witness::errors::ProgramError; use crate::witness::memory::{ MemoryAddress, MemoryContextState, MemoryOp, MemoryOpKind, MemorySegmentState, }; +use crate::witness::operation::Operation; use crate::witness::state::RegistersState; +use crate::witness::transition::{ + decode, fill_op_flag, get_op_special_length, log_kernel_instruction, Transition, +}; type F = GoldilocksField; @@ -45,11 +49,6 @@ pub(crate) struct Interpreter { /// Counts the number of appearances of each opcode. For debugging purposes. pub(crate) opcode_count: [usize; 0x100], jumpdest_table: HashMap>, - /// Segments that can be preinitialized: they are not stored in the - /// interpreter memory unless they are read/written during the execution. - /// When the values are first read, they are read from this `HashMap` (and - /// the value is then written in memory). - pub(crate) preinitialized_segments: HashMap, /// `true` if the we are currently carrying out a jumpdest analysis. 
pub(crate) is_jumpdest_analysis: bool, /// Holds the value of the clock: the clock counts the number of operations @@ -160,7 +159,6 @@ impl Interpreter { halt_context: None, opcode_count: [0; 256], jumpdest_table: HashMap::new(), - preinitialized_segments: HashMap::default(), is_jumpdest_analysis: false, clock: 0, }; @@ -190,7 +188,6 @@ impl Interpreter { halt_context: Some(halt_context), opcode_count: [0; 256], jumpdest_table: HashMap::new(), - preinitialized_segments: HashMap::new(), is_jumpdest_analysis: true, clock: 0, } @@ -215,8 +212,7 @@ impl Interpreter { let preinit_trie_data_segment = MemorySegmentState { content: trie_data.iter().map(|&elt| Some(elt)).collect::>(), }; - self.preinitialized_segments - .insert(Segment::TrieData, preinit_trie_data_segment); + self.insert_preinitialized_segment(Segment::TrieData, preinit_trie_data_segment); // Update the RLP and withdrawal prover inputs. let rlp_prover_inputs = @@ -332,11 +328,7 @@ impl Interpreter { match kind { MemoryOpKind::Read => { if self.generation_state.memory.get(address).is_none() { - if !self - .preinitialized_segments - .contains_key(&Segment::all()[address.segment]) - && !value.is_zero() - { + if !self.is_preinitialized_segment(address.segment) && !value.is_zero() { return Err(anyhow!("The initial value {:?} at address {:?} should be zero, because it is not preinitialized.", value, address)); } self.generation_state.memory.set(address, value); @@ -349,19 +341,8 @@ impl Interpreter { Ok(()) } - /// Returns a `GenerationStateCheckpoint` to save the current registers and - /// reset memory operations to the empty vector. 
- pub(crate) fn checkpoint(&mut self) -> GenerationStateCheckpoint { - self.generation_state.traces.memory_ops = vec![]; - GenerationStateCheckpoint { - registers: self.generation_state.registers, - traces: self.generation_state.traces.checkpoint(), - } - } - pub(crate) fn run(&mut self) -> Result<(), anyhow::Error> { - let mut state = State::Interpreter(self); - run_cpu(&mut state)?; + self.run_cpu()?; #[cfg(debug_assertions)] { @@ -458,7 +439,7 @@ impl Interpreter { } pub(crate) fn get_memory_segment(&self, segment: Segment) -> Vec { - if self.preinitialized_segments.contains_key(&segment) { + if self.is_preinitialized_segment(segment.unscale()) { let total_len = self.generation_state.memory.contexts[0].segments[segment.unscale()] .content .len(); @@ -473,8 +454,9 @@ impl Interpreter { }; let mut res = get_vals( &self - .preinitialized_segments - .get(&segment) + .generation_state + .memory + .get_preinitialized_segment(segment) .expect("The segment should be in the preinitialized segments.") .content, ); @@ -607,6 +589,7 @@ impl Interpreter { } } } + fn stack_segment_mut(&mut self) -> &mut Vec> { let context = self.context(); &mut self.generation_state.memory.contexts[context].segments[Segment::Stack.unscale()] @@ -616,11 +599,10 @@ impl Interpreter { pub(crate) fn extract_kernel_memory(self, segment: Segment, range: Range) -> Vec { let mut output: Vec = Vec::with_capacity(range.end); for i in range { - let term = self.generation_state.memory.get_with_init( - MemoryAddress::new(0, segment, i), - true, - &self.preinitialized_segments, - ); + let term = self + .generation_state + .memory + .get_with_init(MemoryAddress::new(0, segment, i)); output.push(term); } output @@ -682,15 +664,11 @@ impl Interpreter { // Even though we are in the interpreter, `JumpdestBits` is not part of the // preinitialized segments, so we don't need to carry out the additional checks // when get the value from memory. 
- self.generation_state.memory.get_with_init( - MemoryAddress { - context: self.context(), - segment: Segment::JumpdestBits.unscale(), - virt: offset, - }, - false, - &HashMap::default(), - ) + self.generation_state.memory.get_with_init(MemoryAddress { + context: self.context(), + segment: Segment::JumpdestBits.unscale(), + virt: offset, + }) } else { 0.into() } @@ -757,6 +735,173 @@ impl Interpreter { } } +impl State for Interpreter { + /// Returns a `GenerationStateCheckpoint` to save the current registers and + /// reset memory operations to the empty vector. + fn checkpoint(&mut self) -> GenerationStateCheckpoint { + self.generation_state.traces.memory_ops = vec![]; + GenerationStateCheckpoint { + registers: self.generation_state.registers, + traces: self.generation_state.traces.checkpoint(), + } + } + + fn is_generation(&mut self) -> bool { + false + } + + fn insert_preinitialized_segment(&mut self, segment: Segment, values: MemorySegmentState) { + self.generation_state + .memory + .insert_preinitialized_segment(segment, values); + } + + fn is_preinitialized_segment(&self, segment: usize) -> bool { + self.generation_state + .memory + .is_preinitialized_segment(segment) + } + + fn incr_gas(&mut self, n: u64) { + self.generation_state.incr_gas(n); + } + + fn incr_pc(&mut self, n: usize) { + self.generation_state.incr_pc(n); + } + + fn get_registers(&self) -> RegistersState { + self.generation_state.get_registers() + } + + fn get_mut_registers(&mut self) -> &mut RegistersState { + self.generation_state.get_mut_registers() + } + + fn get_from_memory(&mut self, address: MemoryAddress) -> U256 { + self.generation_state.memory.get_with_init(address) + } + + fn get_mut_generation_state(&mut self) -> &mut GenerationState { + &mut self.generation_state + } + + fn incr_interpreter_clock(&mut self) { + self.clock += 1 + } + + fn get_clock(&mut self) -> usize { + self.clock + } + + fn rollback(&mut self, checkpoint: GenerationStateCheckpoint) { +
self.generation_state.rollback(checkpoint) + } + + fn get_context(&mut self) -> usize { + self.context() + } + + fn get_halt_context(&mut self) -> Option { + self.halt_context + } + + fn mem_get_kernel_content(&self) -> Vec> { + self.generation_state.memory.contexts[0].segments[Segment::KernelGeneral.unscale()] + .content + .clone() + } + + fn apply_ops(&mut self, checkpoint: GenerationStateCheckpoint) { + self.apply_memops(); + } + + fn get_stack(&mut self) -> Vec { + self.stack().clone() + } + + fn get_halt_offsets(&self) -> Vec { + self.halt_offsets.clone() + } + + fn try_perform_instruction(&mut self) -> Result { + let registers = self.generation_state.registers; + let (mut row, opcode) = self.base_row(); + + let op = decode(registers, opcode)?; + + fill_op_flag(op, &mut row); + + self.fill_stack_fields(&mut row)?; + + let generation_state = self.get_mut_generation_state(); + if registers.is_kernel { + log_kernel_instruction(generation_state, op); + } else { + log::debug!("User instruction: {:?}", op); + } + + // Might write in general CPU columns when it shouldn't, but the correct values + // will overwrite these ones during the op generation. + if let Some(special_len) = get_op_special_length(op) { + let special_len_f = F::from_canonical_usize(special_len); + let diff = row.stack_len - special_len_f; + if (generation_state.stack().len() != special_len) { + // If the `State` is an interpreter, we cannot rely on the row to carry out the + // check. 
+ generation_state.registers.is_stack_top_read = true; + } + } else if let Some(inv) = row.stack_len.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + } + + self.perform_state_op(opcode, op, row) + } +} + +impl Transition for Interpreter { + fn generate_jumpdest_analysis(&mut self, dst: usize) -> bool { + if self.is_jumpdest_analysis && !self.generation_state.registers.is_kernel { + self.add_jumpdest_offset(dst); + true + } else { + false + } + } + + fn skip_if_necessary(&mut self, op: Operation) -> Result { + if self.is_kernel() + && self.is_jumpdest_analysis + && self.generation_state.registers.program_counter + == KERNEL.global_labels["jumpdest_analysis"] + { + self.generation_state.registers.program_counter = + KERNEL.global_labels["jumpdest_analysis_end"]; + self.generation_state + .set_jumpdest_bits(&self.generation_state.get_current_code()?); + let opcode = self + .code() + .get(self.generation_state.registers.program_counter) + .byte(0); + + decode(self.generation_state.registers, opcode) + } else { + Ok(op) + } + } + + fn fill_stack_fields( + &mut self, + row: &mut crate::cpu::columns::CpuColumnsView, + ) -> Result<(), ProgramError> { + self.generation_state.registers.is_stack_top_read = false; + self.generation_state.registers.check_overflow = false; + + Ok(()) + } +} + fn get_mnemonic(opcode: u8) -> &'static str { match opcode { 0x00 => "STOP", diff --git a/evm_arithmetization/src/cpu/kernel/tests/core/access_lists.rs b/evm_arithmetization/src/cpu/kernel/tests/core/access_lists.rs index bb75da567..fab799886 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/core/access_lists.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/core/access_lists.rs @@ -28,11 +28,10 @@ fn test_init_access_lists() -> Result<()> { let acc_addr_list: Vec = (0..2) .map(|i| { - interpreter.generation_state.memory.get_with_init( - MemoryAddress::new(0, Segment::AccessedAddresses, i), - false, - &HashMap::default(), - ) 
+ interpreter + .generation_state + .memory + .get_with_init(MemoryAddress::new(0, Segment::AccessedAddresses, i)) }) .collect(); assert_eq!( @@ -42,11 +41,10 @@ fn test_init_access_lists() -> Result<()> { let acc_storage_keys: Vec = (0..4) .map(|i| { - interpreter.generation_state.memory.get_with_init( - MemoryAddress::new(0, Segment::AccessedStorageKeys, i), - false, - &HashMap::default(), - ) + interpreter + .generation_state + .memory + .get_with_init(MemoryAddress::new(0, Segment::AccessedStorageKeys, i)) }) .collect(); @@ -115,8 +113,6 @@ fn test_insert_address() -> Result<()> { assert_eq!( interpreter.generation_state.memory.get_with_init( MemoryAddress::new_bundle(U256::from(AccessedAddressesLen as usize)).unwrap(), - false, - &HashMap::default(), ), U256::from(Segment::AccessedAddresses as usize + 4) ); @@ -170,8 +166,6 @@ fn test_insert_accessed_addresses() -> Result<()> { assert_eq!( interpreter.generation_state.memory.get_with_init( MemoryAddress::new_bundle(U256::from(AccessedAddressesLen as usize)).unwrap(), - false, - &HashMap::default(), ), U256::from(offset + 2 * (n + 1)) ); @@ -187,17 +181,14 @@ fn test_insert_accessed_addresses() -> Result<()> { assert_eq!( interpreter.generation_state.memory.get_with_init( MemoryAddress::new_bundle(U256::from(AccessedAddressesLen as usize)).unwrap(), - false, - &HashMap::default(), ), U256::from(offset + 2 * (n + 2)) ); assert_eq!( - interpreter.generation_state.memory.get_with_init( - MemoryAddress::new(0, AccessedAddresses, 2 * (n + 1)), - false, - &HashMap::default(), - ), + interpreter + .generation_state + .memory + .get_with_init(MemoryAddress::new(0, AccessedAddresses, 2 * (n + 1)),), U256::from(addr_not_in_list.0.as_slice()) ); @@ -259,8 +250,6 @@ fn test_insert_accessed_storage_keys() -> Result<()> { assert_eq!( interpreter.generation_state.memory.get_with_init( MemoryAddress::new_bundle(U256::from(AccessedStorageKeysLen as usize)).unwrap(), - false, - &HashMap::default(), ), U256::from(offset + 4 * (n 
+ 1)) ); @@ -281,33 +270,28 @@ fn test_insert_accessed_storage_keys() -> Result<()> { assert_eq!( interpreter.generation_state.memory.get_with_init( MemoryAddress::new_bundle(U256::from(AccessedStorageKeysLen as usize)).unwrap(), - false, - &HashMap::default() ), U256::from(offset + 4 * (n + 2)) ); assert_eq!( - interpreter.generation_state.memory.get_with_init( - MemoryAddress::new(0, AccessedStorageKeys, 4 * (n + 1)), - false, - &HashMap::default(), - ), + interpreter + .generation_state + .memory + .get_with_init(MemoryAddress::new(0, AccessedStorageKeys, 4 * (n + 1)),), U256::from(storage_key_not_in_list.0 .0.as_slice()) ); assert_eq!( - interpreter.generation_state.memory.get_with_init( - MemoryAddress::new(0, AccessedStorageKeys, 4 * (n + 1) + 1), - false, - &HashMap::default() - ), + interpreter + .generation_state + .memory + .get_with_init(MemoryAddress::new(0, AccessedStorageKeys, 4 * (n + 1) + 1),), storage_key_not_in_list.1 ); assert_eq!( - interpreter.generation_state.memory.get_with_init( - MemoryAddress::new(0, AccessedStorageKeys, 4 * (n + 1) + 2), - false, - &HashMap::default() - ), + interpreter + .generation_state + .memory + .get_with_init(MemoryAddress::new(0, AccessedStorageKeys, 4 * (n + 1) + 2),), storage_key_not_in_list.2 ); diff --git a/evm_arithmetization/src/cpu/kernel/tests/ecc/curve_ops.rs b/evm_arithmetization/src/cpu/kernel/tests/ecc/curve_ops.rs index cdda1c5b7..77ab3fc6e 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/ecc/curve_ops.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/ecc/curve_ops.rs @@ -210,15 +210,11 @@ mod bn { let mut computed_table = Vec::new(); for i in 0..32 { - computed_table.push(int.generation_state.memory.get_with_init( - MemoryAddress { - context: 0, - segment: Segment::BnTableQ.unscale(), - virt: i, - }, - false, - &HashMap::default(), - )); + computed_table.push(int.generation_state.memory.get_with_init(MemoryAddress { + context: 0, + segment: Segment::BnTableQ.unscale(), + virt: i, + })); } let 
table = u256ify([ diff --git a/evm_arithmetization/src/generation/mod.rs b/evm_arithmetization/src/generation/mod.rs index 83c7e2236..18ada6311 100644 --- a/evm_arithmetization/src/generation/mod.rs +++ b/evm_arithmetization/src/generation/mod.rs @@ -28,7 +28,6 @@ use crate::proof::{BlockHashes, BlockMetadata, ExtraBlockData, PublicValues, Tri use crate::util::{h2u, u256_to_u8, u256_to_usize}; use crate::witness::memory::{MemoryAddress, MemoryChannel, MemoryOp, MemoryOpKind}; use crate::witness::state::RegistersState; -use crate::witness::transition::transition; pub mod mpt; pub(crate) mod prover_input; @@ -36,7 +35,7 @@ pub(crate) mod rlp; pub(crate) mod state; mod trie_extractor; -use self::state::GenerationStateCheckpoint; +use self::state::{GenerationStateCheckpoint, State}; use crate::witness::util::{mem_write_log, stack_peek}; /// Inputs needed for trace generation. @@ -312,197 +311,8 @@ pub fn generate_traces, const D: usize>( Ok((tables, public_values)) } -/// A State is either an `Interpreter` (used for tests and jumpdest analysis) or -/// a `GenerationState`. -pub(crate) enum State<'a, F: Field> { - Generation(&'a mut GenerationState), - Interpreter(&'a mut Interpreter), -} - -impl<'a, F: Field> State<'a, F> { - /// Returns a `State`'s `Checkpoint`. - pub(crate) fn checkpoint(&mut self) -> GenerationStateCheckpoint { - match self { - Self::Generation(state) => state.checkpoint(), - Self::Interpreter(interpreter) => interpreter.checkpoint(), - } - } - - /// Increments the `gas_used` register by a value `n`. - pub(crate) fn incr_gas(&mut self, n: u64) { - match self { - Self::Generation(state) => state.registers.gas_used += n, - Self::Interpreter(interpreter) => interpreter.generation_state.registers.gas_used += n, - } - } - - /// Increments the `program_counter` register by a value `n`. 
- pub(crate) fn incr_pc(&mut self, n: usize) { - match self { - Self::Generation(state) => state.registers.program_counter += n, - Self::Interpreter(interpreter) => { - interpreter.generation_state.registers.program_counter += n - } - } - } - - /// Returns a `State`'s registers. - pub(crate) fn get_registers(&self) -> RegistersState { - match self { - Self::Generation(state) => state.registers, - Self::Interpreter(interpreter) => interpreter.generation_state.registers, - } - } - - /// Returns a `State`'s mutable registers. - pub(crate) fn get_mut_registers(&mut self) -> &mut RegistersState { - match self { - Self::Generation(state) => &mut state.registers, - Self::Interpreter(interpreter) => &mut interpreter.generation_state.registers, - } - } - - /// Returns the value stored at address `address` in a `State`. - pub(crate) fn get_from_memory(&mut self, address: MemoryAddress) -> U256 { - match self { - Self::Generation(state) => { - state - .memory - .get_with_init(address, false, &HashMap::default()) - } - Self::Interpreter(interpreter) => interpreter.generation_state.memory.get_with_init( - address, - true, - &interpreter.preinitialized_segments, - ), - } - } - - /// Returns a mutable `GenerationState` from a `State`. - pub(crate) fn get_mut_generation_state(&mut self) -> &mut GenerationState { - match self { - Self::Generation(state) => state, - Self::Interpreter(interpreter) => &mut interpreter.generation_state, - } - } - - /// Returns true if a `State` is a `GenerationState` and false otherwise. - pub(crate) fn is_generation_state(&mut self) -> bool { - match self { - Self::Generation(state) => true, - Self::Interpreter(interpreter) => false, - } - } - - /// Increments the clock of an `Interpreter`'s clock. - pub(crate) fn incr_interpreter_clock(&mut self) { - match self { - Self::Generation(state) => {} - Self::Interpreter(interpreter) => interpreter.clock += 1, - } - } - - /// Returns the value of a `State`'s clock. 
- pub(crate) fn get_clock(&mut self) -> usize { - match self { - Self::Generation(state) => state.traces.clock(), - Self::Interpreter(interpreter) => interpreter.clock, - } - } - - /// Rolls back a `State`. - pub(crate) fn rollback(&mut self, checkpoint: GenerationStateCheckpoint) { - match self { - Self::Generation(state) => state.rollback(checkpoint), - Self::Interpreter(interpreter) => interpreter.generation_state.rollback(checkpoint), - } - } - - /// Returns a `State`'s stack. - pub(crate) fn get_stack(&mut self) -> Vec { - match self { - Self::Generation(state) => state.stack(), - Self::Interpreter(interpreter) => interpreter.stack(), - } - } - - fn get_context(&mut self) -> usize { - match self { - Self::Generation(state) => state.registers.context, - Self::Interpreter(interpreter) => interpreter.context(), - } - } - - fn get_halt_context(&mut self) -> Option { - match self { - Self::Generation(state) => None, - Self::Interpreter(interpreter) => interpreter.halt_context, - } - } - - /// Returns the content of a the `KernelGeneral` segment of a `State`. - pub(crate) fn mem_get_kernel_content(&self) -> Vec> { - match self { - Self::Generation(state) => state.memory.contexts[0].segments - [Segment::KernelGeneral.unscale()] - .content - .clone(), - Self::Interpreter(interpreter) => interpreter.generation_state.memory.contexts[0] - .segments[Segment::KernelGeneral.unscale()] - .content - .clone(), - } - } - - /// Applies a `State`'s operations since a checkpoint. - pub(crate) fn apply_ops(&mut self, checkpoint: GenerationStateCheckpoint) { - match self { - Self::Generation(state) => state - .memory - .apply_ops(state.traces.mem_ops_since(checkpoint.traces)), - Self::Interpreter(interpreter) => { - // An interpreter `checkpoint()` clears all operations before the checkpoint. - interpreter.apply_memops(); - } - } - } -} - -/// Simulates a CPU. It only generates the traces if the `State` is a -/// `GenerationState`. Otherwise, it simply simulates all operations. 
-pub(crate) fn run_cpu(any_state: &mut State) -> anyhow::Result<()> { - let halt_offsets = match any_state { - State::Generation(_state) => vec![KERNEL.global_labels["halt"]], - State::Interpreter(interpreter) => interpreter.halt_offsets.clone(), - }; - - loop { - // If we've reached the kernel's halt routine. - let registers = any_state.get_registers(); - let pc = registers.program_counter; - - let halt = registers.is_kernel && halt_offsets.contains(&pc); - - if halt { - if let Some(halt_context) = any_state.get_halt_context() { - if registers.context == halt_context { - // Only happens during jumpdest analysis. - return Ok(()); - } - } else { - log::info!("CPU halted after {} cycles", any_state.get_clock()); - - return Ok(()); - } - } - - transition(any_state)?; - any_state.incr_interpreter_clock(); - } -} - fn simulate_cpu(state: &mut GenerationState) -> anyhow::Result<()> { - run_cpu(&mut State::Generation(state))?; + state.run_cpu()?; let pc = state.registers.program_counter; // Padding diff --git a/evm_arithmetization/src/generation/prover_input.rs b/evm_arithmetization/src/generation/prover_input.rs index f5d635e6b..460eb4aeb 100644 --- a/evm_arithmetization/src/generation/prover_input.rs +++ b/evm_arithmetization/src/generation/prover_input.rs @@ -9,6 +9,7 @@ use num_bigint::BigUint; use plonky2::field::types::Field; use serde::{Deserialize, Serialize}; +use super::state::State; use crate::cpu::kernel::constants::context_metadata::ContextMetadata; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::cpu::kernel::interpreter::simulate_cpu_and_get_user_jumps; @@ -122,9 +123,7 @@ impl GenerationState { let ptr = stack_peek(self, 11 - n).map(u256_to_usize)??; let f: [U256; 12] = match field { - Bn254Base => std::array::from_fn(|i| { - current_context_peek(self, BnPairing, ptr + i, false, &HashMap::default()) - }), + Bn254Base => std::array::from_fn(|i| current_context_peek(self, BnPairing, ptr + i)), _ => todo!(), }; 
Ok(field.field_extension_inverse(n, f)) @@ -434,26 +433,21 @@ impl GenerationState { let code_len = self.get_code_len(context)?; let code = (0..code_len) .map(|i| { - u256_to_u8(self.memory.get_with_init( - MemoryAddress::new(context, Segment::Code, i), - false, - &HashMap::default(), - )) + u256_to_u8( + self.memory + .get_with_init(MemoryAddress::new(context, Segment::Code, i)), + ) }) .collect::, _>>()?; Ok(code) } fn get_code_len(&self, context: usize) -> Result { - let code_len = u256_to_usize(self.memory.get_with_init( - MemoryAddress::new( - context, - Segment::ContextMetadata, - ContextMetadata::CodeSize.unscale(), - ), - false, - &HashMap::default(), - ))?; + let code_len = u256_to_usize(self.memory.get_with_init(MemoryAddress::new( + context, + Segment::ContextMetadata, + ContextMetadata::CodeSize.unscale(), + )))?; Ok(code_len) } @@ -488,11 +482,11 @@ impl GenerationState { } fn get_global_metadata(&self, data: GlobalMetadata) -> U256 { - self.memory.get_with_init( - MemoryAddress::new(0, Segment::GlobalMetadata, data.unscale()), - false, - &HashMap::default(), - ) + self.memory.get_with_init(MemoryAddress::new( + 0, + Segment::GlobalMetadata, + data.unscale(), + )) } pub(crate) fn get_storage_keys_access_list(&self) -> Result { @@ -500,15 +494,11 @@ impl GenerationState { // virtual address in the segment. 
In order to get the length we need // to substract Segment::AccessedStorageKeys as usize let acc_storage_len = u256_to_usize( - self.memory.get_with_init( - MemoryAddress::new( - 0, - Segment::GlobalMetadata, - GlobalMetadata::AccessedStorageKeysLen.unscale(), - ), - false, - &HashMap::default(), - ) - Segment::AccessedStorageKeys as usize, + self.memory.get_with_init(MemoryAddress::new( + 0, + Segment::GlobalMetadata, + GlobalMetadata::AccessedStorageKeysLen.unscale(), + )) - Segment::AccessedStorageKeys as usize, )?; AccList::from_mem_and_segment( &self.memory.contexts[0].segments[Segment::AccessedStorageKeys.unscale()].content diff --git a/evm_arithmetization/src/generation/state.rs b/evm_arithmetization/src/generation/state.rs index 820ce6a4e..e5a733832 100644 --- a/evm_arithmetization/src/generation/state.rs +++ b/evm_arithmetization/src/generation/state.rs @@ -1,26 +1,216 @@ use std::collections::HashMap; +use anyhow::bail; use ethereum_types::{Address, BigEndianHash, H160, H256, U256}; use keccak_hash::keccak; +use log::log_enabled; use plonky2::field::types::Field; use super::mpt::{load_all_mpts, TrieRootPtrs}; use super::TrieInputs; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::context_metadata::ContextMetadata; +use crate::cpu::membus::NUM_GP_CHANNELS; +use crate::cpu::stack::MAX_USER_STACK_SIZE; use crate::generation::rlp::all_rlp_prover_inputs_reversed; +use crate::generation::CpuColumnsView; use crate::generation::GenerationInputs; use crate::memory::segments::Segment; use crate::util::u256_to_usize; use crate::witness::errors::ProgramError; -use crate::witness::memory::{MemoryAddress, MemoryState}; +use crate::witness::memory::MemoryChannel::GeneralPurpose; +use crate::witness::memory::{MemoryAddress, MemoryOp, MemoryState}; +use crate::witness::memory::{MemoryOpKind, MemorySegmentState}; +use crate::witness::operation::{generate_exception, Operation}; use crate::witness::state::RegistersState; use 
crate::witness::traces::{TraceCheckpoint, Traces}; -use crate::witness::util::stack_peek; +use crate::witness::transition::{ + decode, fill_op_flag, get_op_special_length, log_kernel_instruction, might_overflow_op, + read_code_memory, Transition, +}; +use crate::witness::util::{ + fill_channel_with_value, mem_read_gp_with_log_and_fill, stack_peek, stack_pop_with_log_and_fill, +}; -pub(crate) struct GenerationStateCheckpoint { - pub(crate) registers: RegistersState, - pub(crate) traces: TraceCheckpoint, +/// A State is either an `Interpreter` (used for tests and jumpdest analysis) or +/// a `GenerationState`. +pub(crate) trait State { + /// Returns a `State`'s `Checkpoint`. + fn checkpoint(&mut self) -> GenerationStateCheckpoint; + + fn is_generation(&mut self) -> bool { + true + } + /// Increments the `gas_used` register by a value `n`. + fn incr_gas(&mut self, n: u64); + + /// Increments the `program_counter` register by a value `n`. + fn incr_pc(&mut self, n: usize); + + /// Returns a `State`'s registers. + fn get_registers(&self) -> RegistersState; + + /// Returns a `State`'s mutable registers. + fn get_mut_registers(&mut self) -> &mut RegistersState; + + /// Returns the value stored at address `address` in a `State`. + fn get_from_memory(&mut self, address: MemoryAddress) -> U256; + + /// Returns a mutable `GenerationState` from a `State`. + fn get_mut_generation_state(&mut self) -> &mut GenerationState; + + // /// Returns true if a `State` is a `GenerationState` and false otherwise. + // fn is_generation_state(&mut self) -> bool; + + /// Increments the clock of an `Interpreter`'s clock. + fn incr_interpreter_clock(&mut self); + + /// Returns the value of a `State`'s clock. + fn get_clock(&mut self) -> usize; + + /// Rolls back a `State`. + fn rollback(&mut self, checkpoint: GenerationStateCheckpoint); + + /// Returns a `State`'s stack. + fn get_stack(&mut self) -> Vec; + + /// Returns the current context. 
+ fn get_context(&mut self) -> usize; + + fn get_halt_context(&mut self) -> Option { + None + } + + /// Returns the content of the `KernelGeneral` segment of a `State`. + fn mem_get_kernel_content(&self) -> Vec>; + + /// Applies a `State`'s operations since a checkpoint. + fn apply_ops(&mut self, checkpoint: GenerationStateCheckpoint); + + /// Return the offsets at which execution must halt + fn get_halt_offsets(&self) -> Vec; + + /// Inserts a preinitialized segment, given as a [Segment], + /// into the `preinitialized_segments` memory field. + fn insert_preinitialized_segment(&mut self, segment: Segment, values: MemorySegmentState); + + fn is_preinitialized_segment(&self, segment: usize) -> bool; + + /// Simulates a CPU. It only generates the traces if the `State` is a + /// `GenerationState`. Otherwise, it simply simulates all operations. + fn run_cpu(&mut self) -> anyhow::Result<()> + where + Self: Transition, + Self: Sized, + { + let halt_offsets = self.get_halt_offsets(); + + loop { + // If we've reached the kernel's halt routine. + let registers = self.get_registers(); + let pc = registers.program_counter; + + let halt = registers.is_kernel && halt_offsets.contains(&pc); + + if halt { + if let Some(halt_context) = self.get_halt_context() { + if registers.context == halt_context { + // Only happens during jumpdest analysis.
+ return Ok(()); + } + } else { + log::info!("CPU halted after {} cycles", self.get_clock()); + return Ok(()); + } + } + + self.transition()?; + self.incr_interpreter_clock(); + } + + Ok(()) + } + + fn handle_error(&mut self, err: ProgramError) -> anyhow::Result<()> + where + Self: Transition, + Self: Sized, + { + let exc_code: u8 = match err { + ProgramError::OutOfGas => 0, + ProgramError::InvalidOpcode => 1, + ProgramError::StackUnderflow => 2, + ProgramError::InvalidJumpDestination => 3, + ProgramError::InvalidJumpiDestination => 4, + ProgramError::StackOverflow => 5, + _ => bail!("TODO: figure out what to do with this..."), + }; + + let checkpoint = self.checkpoint(); + + let (row, _) = self.base_row(); + let is_generation = self.is_generation(); + generate_exception(exc_code, self, row, is_generation); + + self.apply_ops(checkpoint); + + Ok(()) + } + + fn transition(&mut self) -> anyhow::Result<()> + where + Self: Transition, + Self: Sized, + { + let checkpoint = self.checkpoint(); + let result = self.try_perform_instruction(); + + match result { + Ok(op) => { + self.apply_ops(checkpoint); + + if might_overflow_op(op) { + self.get_mut_registers().check_overflow = true; + } + Ok(()) + } + Err(e) => { + if self.get_registers().is_kernel { + let offset_name = KERNEL.offset_name(self.get_registers().program_counter); + bail!( + "{:?} in kernel at pc={}, stack={:?}, memory={:?}", + e, + offset_name, + self.get_stack(), + self.mem_get_kernel_content(), + ); + } + self.rollback(checkpoint); + self.handle_error(e) + } + } + } + + fn try_perform_instruction(&mut self) -> Result; + + /// Row that has the correct values for system registers and the code + /// channel, but is otherwise blank. It fulfills the constraints that + /// are common to successful operations and the exception operation. 
It + /// also returns the opcode + fn base_row(&mut self) -> (CpuColumnsView, u8) { + let generation_state = self.get_mut_generation_state(); + let mut row: CpuColumnsView = CpuColumnsView::default(); + row.clock = F::from_canonical_usize(generation_state.traces.clock()); + row.context = F::from_canonical_usize(generation_state.registers.context); + row.program_counter = F::from_canonical_usize(generation_state.registers.program_counter); + row.is_kernel_mode = F::from_bool(generation_state.registers.is_kernel); + row.gas = F::from_canonical_u64(generation_state.registers.gas_used); + row.stack_len = F::from_canonical_usize(generation_state.registers.stack_len); + fill_channel_with_value(&mut row, 0, generation_state.registers.stack_top); + + let opcode = read_code_memory(generation_state, &mut row); + (row, opcode) + } } #[derive(Debug)] @@ -144,11 +334,7 @@ impl GenerationState { let returndata_offset = ContextMetadata::ReturndataSize.unscale(); let returndata_size_addr = MemoryAddress::new(ctx, Segment::ContextMetadata, returndata_offset); - let returndata_size = u256_to_usize(self.memory.get_with_init( - returndata_size_addr, - false, - &HashMap::default(), - ))?; + let returndata_size = u256_to_usize(self.memory.get_with_init(returndata_size_addr))?; let code = self.memory.contexts[ctx].segments[Segment::Returndata.unscale()].content [..returndata_size] .iter() @@ -161,13 +347,6 @@ impl GenerationState { Ok(()) } - pub(crate) fn checkpoint(&self) -> GenerationStateCheckpoint { - GenerationStateCheckpoint { - registers: self.registers, - traces: self.traces.checkpoint(), - } - } - pub(crate) fn rollback(&mut self, checkpoint: GenerationStateCheckpoint) { self.registers = checkpoint.registers; self.traces.rollback(checkpoint.traces); @@ -201,6 +380,190 @@ impl GenerationState { } } +impl State for GenerationState { + fn checkpoint(&mut self) -> GenerationStateCheckpoint { + GenerationStateCheckpoint { + registers: self.registers, + traces: 
self.traces.checkpoint(), + } + } + + fn insert_preinitialized_segment(&mut self, segment: Segment, values: MemorySegmentState) { + panic!( + "A `GenerationState` cannot have a nonempty `preinitialized_segment` field in memory." + ) + } + + fn is_preinitialized_segment(&self, segment: usize) -> bool { + false + } + + /// Increments the `gas_used` register by a value `n`. + fn incr_gas(&mut self, n: u64) { + self.registers.gas_used += n; + } + + /// Increments the `program_counter` register by a value `n`. + fn incr_pc(&mut self, n: usize) { + self.registers.program_counter += n; + } + + /// Returns a `State`'s registers. + fn get_registers(&self) -> RegistersState { + self.registers + } + + /// Returns a `State`'s mutable registers. + fn get_mut_registers(&mut self) -> &mut RegistersState { + &mut self.registers + } + + /// Returns the value stored at address `address` in a `State`. + fn get_from_memory(&mut self, address: MemoryAddress) -> U256 { + self.memory.get_with_init(address) + } + + /// Returns a mutable `GenerationState` from a `State`. + fn get_mut_generation_state(&mut self) -> &mut GenerationState { + self + } + + /// Increments the clock of an `Interpreter`'s clock. + fn incr_interpreter_clock(&mut self) {} + + /// Returns the value of a `State`'s clock. + fn get_clock(&mut self) -> usize { + self.traces.clock() + } + + /// Rolls back a `State`. + fn rollback(&mut self, checkpoint: GenerationStateCheckpoint) { + self.rollback(checkpoint) + } + + /// Returns a `State`'s stack. + fn get_stack(&mut self) -> Vec { + self.stack() + } + + fn get_context(&mut self) -> usize { + self.registers.context + } + + /// Returns the content of a the `KernelGeneral` segment of a `State`. + fn mem_get_kernel_content(&self) -> Vec> { + self.memory.contexts[0].segments[Segment::KernelGeneral.unscale()] + .content + .clone() + } + + /// Applies a `State`'s operations since a checkpoint. 
+ fn apply_ops(&mut self, checkpoint: GenerationStateCheckpoint) { + self.memory + .apply_ops(self.traces.mem_ops_since(checkpoint.traces)) + } + + fn get_halt_offsets(&self) -> Vec { + vec![KERNEL.global_labels["halt"]] + } + + fn try_perform_instruction(&mut self) -> Result { + let registers = self.registers; + let (mut row, opcode) = self.base_row(); + + let op = decode(registers, opcode)?; + + if registers.is_kernel { + log_kernel_instruction(self, op); + } else { + log::debug!("User instruction: {:?}", op); + } + fill_op_flag(op, &mut row); + + self.fill_stack_fields(&mut row)?; + + // Might write in general CPU columns when it shouldn't, but the correct values + // will overwrite these ones during the op generation. + if let Some(special_len) = get_op_special_length(op) { + let special_len_f = F::from_canonical_usize(special_len); + let diff = row.stack_len - special_len_f; + if let Some(inv) = diff.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + self.registers.is_stack_top_read = true; + } else if (self.stack().len() != special_len) { + // If the `State` is an interpreter, we cannot rely on the row to carry out the + // check. 
+ self.registers.is_stack_top_read = true; + } + } else if let Some(inv) = row.stack_len.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + } + + self.perform_state_op(opcode, op, row) + } +} + +impl Transition for GenerationState { + fn skip_if_necessary(&mut self, op: Operation) -> Result { + Ok(op) + } + + fn generate_jumpdest_analysis(&mut self, dst: usize) -> bool { + false + } + + fn fill_stack_fields(&mut self, row: &mut CpuColumnsView) -> Result<(), ProgramError> { + if self.registers.is_stack_top_read { + let channel = &mut row.mem_channels[0]; + channel.used = F::ONE; + channel.is_read = F::ONE; + channel.addr_context = F::from_canonical_usize(self.registers.context); + channel.addr_segment = F::from_canonical_usize(Segment::Stack.unscale()); + channel.addr_virtual = F::from_canonical_usize(self.registers.stack_len - 1); + + let address = MemoryAddress::new( + self.registers.context, + Segment::Stack, + self.registers.stack_len - 1, + ); + + let mem_op = MemoryOp::new( + GeneralPurpose(0), + self.traces.clock(), + address, + MemoryOpKind::Read, + self.registers.stack_top, + ); + self.traces.push_memory(mem_op); + } + self.registers.is_stack_top_read = false; + + if self.registers.check_overflow { + if self.registers.is_kernel { + row.general.stack_mut().stack_len_bounds_aux = F::ZERO; + } else { + let clock = self.traces.clock(); + let last_row = &mut self.traces.cpu[clock - 1]; + let disallowed_len = F::from_canonical_usize(MAX_USER_STACK_SIZE + 1); + let diff = row.stack_len - disallowed_len; + if let Some(inv) = diff.try_inverse() { + last_row.general.stack_mut().stack_len_bounds_aux = inv; + } + } + } + self.registers.check_overflow = false; + + Ok(()) + } +} + +pub(crate) struct GenerationStateCheckpoint { + pub(crate) registers: RegistersState, + pub(crate) traces: TraceCheckpoint, +} + /// Withdrawals prover input array is of the form `[addr0, amount0, ..., addrN, /// amountN, U256::MAX, 
U256::MAX]`. Returns the reversed array. pub(crate) fn all_withdrawals_prover_inputs_reversed(withdrawals: &[(Address, U256)]) -> Vec { diff --git a/evm_arithmetization/src/witness/memory.rs b/evm_arithmetization/src/witness/memory.rs index a1ac91604..42f49d080 100644 --- a/evm_arithmetization/src/witness/memory.rs +++ b/evm_arithmetization/src/witness/memory.rs @@ -163,6 +163,7 @@ impl MemoryOp { #[derive(Clone, Debug)] pub(crate) struct MemoryState { pub(crate) contexts: Vec, + preinitialized_segments: HashMap, } impl MemoryState { @@ -210,22 +211,22 @@ impl MemoryState { Some(val) } - pub(crate) fn get_with_init( - &self, - address: MemoryAddress, - is_interpreter: bool, - preinitialized_segments: &HashMap, - ) -> U256 { + pub(crate) fn get_with_init(&self, address: MemoryAddress) -> U256 { match self.get(address) { Some(val) => val, None => { let segment = Segment::all()[address.segment]; let offset = address.virt; - if is_interpreter - && preinitialized_segments.contains_key(&segment) - && offset < preinitialized_segments.get(&segment).unwrap().content.len() + if self.preinitialized_segments.contains_key(&segment) + && offset + < self + .preinitialized_segments + .get(&segment) + .unwrap() + .content + .len() { - preinitialized_segments.get(&segment).unwrap().content[offset] + self.preinitialized_segments.get(&segment).unwrap().content[offset] .expect("We checked that the offset is not out of bounds.") } else { 0.into() @@ -253,11 +254,34 @@ impl MemoryState { // These fields are already scaled by their respective segment. pub(crate) fn read_global_metadata(&self, field: GlobalMetadata) -> U256 { - self.get_with_init( - MemoryAddress::new_bundle(U256::from(field as usize)).unwrap(), - false, - &HashMap::default(), - ) + self.get_with_init(MemoryAddress::new_bundle(U256::from(field as usize)).unwrap()) + } + + /// Inserts a segment and its preinitialized values in + /// `preinitialized_segments`. 
+ pub(crate) fn insert_preinitialized_segment( + &mut self, + segment: Segment, + values: MemorySegmentState, + ) { + self.preinitialized_segments.insert(segment, values); + } + + /// Returns a boolean which indicates whether a segment (given as a usize) + /// is part of the `preinitialize_segments`. + pub(crate) fn is_preinitialized_segment(&self, segment: usize) -> bool { + if let Some(seg) = Segment::all().get(segment) { + self.preinitialized_segments.contains_key(seg) + } else { + false + } + } + + pub(crate) fn get_preinitialized_segment( + &self, + segment: Segment, + ) -> Option<&MemorySegmentState> { + self.preinitialized_segments.get(&segment) } } @@ -266,6 +290,7 @@ impl Default for MemoryState { Self { // We start with an initial context for the kernel. contexts: vec![MemoryContextState::default()], + preinitialized_segments: HashMap::default(), } } } diff --git a/evm_arithmetization/src/witness/operation.rs b/evm_arithmetization/src/witness/operation.rs index 0fad5072a..469651d03 100644 --- a/evm_arithmetization/src/witness/operation.rs +++ b/evm_arithmetization/src/witness/operation.rs @@ -6,6 +6,7 @@ use keccak_hash::keccak; use plonky2::field::types::Field; use super::memory::MemorySegmentState; +use super::transition::Transition; use super::util::{ byte_packing_log, byte_unpacking_log, mem_read_with_log, mem_write_log, mem_write_partial_log_and_fill, push_no_write, push_with_write, @@ -20,14 +21,13 @@ use crate::cpu::simple_logic::eq_iszero::generate_pinv_diff; use crate::cpu::stack::MAX_USER_STACK_SIZE; use crate::extension_tower::BN_BASE; use crate::generation::state::GenerationState; -use crate::generation::State; +use crate::generation::state::State; use crate::memory::segments::Segment; use crate::util::u256_to_usize; use crate::witness::errors::MemoryError::VirtTooLarge; use crate::witness::errors::ProgramError; use crate::witness::memory::{MemoryAddress, MemoryChannel, MemoryOp, MemoryOpKind}; use 
crate::witness::operation::MemoryChannel::GeneralPurpose; -use crate::witness::transition::fill_stack_fields; use crate::witness::util::{ keccak_sponge_log, mem_read_gp_with_log_and_fill, mem_write_gp_log_and_fill, stack_pop_with_log_and_fill, @@ -140,7 +140,6 @@ pub(crate) fn generate_keccak_general( state: &mut GenerationState, mut row: CpuColumnsView, is_generation: bool, - preinitialized_segments: &HashMap, ) -> Result<(), ProgramError> { let [(addr, _), (len, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; let len = u256_to_usize(len)?; @@ -152,9 +151,7 @@ pub(crate) fn generate_keccak_general( virt: base_address.virt.saturating_add(i), ..base_address }; - let val = state - .memory - .get_with_init(address, !is_generation, preinitialized_segments); + let val = state.memory.get_with_init(address); val.low_u32() as u8 }) .collect_vec(); @@ -221,167 +218,6 @@ pub(crate) fn generate_pop( Ok(()) } -pub(crate) fn generate_jump( - state: &mut State, - mut row: CpuColumnsView, - is_generation: bool, - is_jumpdest_analysis: bool, -) -> Result<(), ProgramError> { - let [(dst, _)] = - stack_pop_with_log_and_fill::<1, _>(state.get_mut_generation_state(), &mut row)?; - - let dst: u32 = dst - .try_into() - .map_err(|_| ProgramError::InvalidJumpDestination)?; - - if is_jumpdest_analysis { - match state { - State::Generation(_state) => { - panic!("Cannot carry out jumpdest analysis with a `GenerationState.") - } - State::Interpreter(interpreter) => { - if !interpreter.generation_state.registers.is_kernel { - interpreter.add_jumpdest_offset(dst as usize); - } - } - } - } else { - let gen_state = state.get_mut_generation_state(); - // Even though we might be in the interpreter, `JumpdestBits` is not part of the - // preinitialized segments, so we don't need to carry out the additional checks - // when get the value from memory. 
- let (jumpdest_bit, jumpdest_bit_log) = mem_read_gp_with_log_and_fill( - NUM_GP_CHANNELS - 1, - MemoryAddress::new( - gen_state.registers.context, - Segment::JumpdestBits, - dst as usize, - ), - gen_state, - &mut row, - false, - &HashMap::default(), - ); - - row.mem_channels[1].value[0] = F::ONE; - - if gen_state.registers.is_kernel { - // Don't actually do the read, just set the address, etc. - let channel = &mut row.mem_channels[NUM_GP_CHANNELS - 1]; - channel.used = F::ZERO; - channel.value[0] = F::ONE; - } else { - if jumpdest_bit != U256::one() { - return Err(ProgramError::InvalidJumpDestination); - } - gen_state.traces.push_memory(jumpdest_bit_log); - } - - // Extra fields required by the constraints. - row.general.jumps_mut().should_jump = F::ONE; - row.general.jumps_mut().cond_sum_pinv = F::ONE; - - let diff = row.stack_len - F::ONE; - if let Some(inv) = diff.try_inverse() { - row.general.stack_mut().stack_inv = inv; - row.general.stack_mut().stack_inv_aux = F::ONE; - } else { - row.general.stack_mut().stack_inv = F::ZERO; - row.general.stack_mut().stack_inv_aux = F::ZERO; - } - - gen_state.traces.push_cpu(is_generation, row); - } - state.get_mut_generation_state().jump_to(dst as usize)?; - Ok(()) -} - -pub(crate) fn generate_jumpi( - state: &mut State, - mut row: CpuColumnsView, - is_generation: bool, - is_jumpdest_analysis: bool, -) -> Result<(), ProgramError> { - let [(dst, _), (cond, log_cond)] = - stack_pop_with_log_and_fill::<2, _>(state.get_mut_generation_state(), &mut row)?; - - let should_jump = !cond.is_zero(); - if should_jump { - let dst: u32 = dst - .try_into() - .map_err(|_| ProgramError::InvalidJumpiDestination)?; - - if is_jumpdest_analysis { - match state { - State::Generation(_state) => { - panic!("Cannot carry out jumpdest analysis with a `GenerationState`.") - } - State::Interpreter(interpreter) => { - let is_kernel = interpreter.generation_state.registers.is_kernel; - if !is_kernel { - interpreter.add_jumpdest_offset(dst as usize) - } 
- } - } - } else { - row.general.jumps_mut().should_jump = F::ONE; - let cond_sum_u64 = cond - .0 - .into_iter() - .map(|limb| ((limb as u32) as u64) + (limb >> 32)) - .sum(); - let cond_sum = F::from_canonical_u64(cond_sum_u64); - row.general.jumps_mut().cond_sum_pinv = cond_sum.inverse(); - } - state.get_mut_generation_state().jump_to(dst as usize)?; - } else { - row.general.jumps_mut().should_jump = F::ZERO; - row.general.jumps_mut().cond_sum_pinv = F::ZERO; - state.incr_pc(1); - } - - let gen_state = state.get_mut_generation_state(); - // Even though we might be in the interpreter, `JumpdestBits` is not part of the - // preinitialized segments, so we don't need to carry out the additional checks - // when get the value from memory. - let (jumpdest_bit, jumpdest_bit_log) = mem_read_gp_with_log_and_fill( - NUM_GP_CHANNELS - 1, - MemoryAddress::new( - gen_state.registers.context, - Segment::JumpdestBits, - dst.low_u32() as usize, - ), - gen_state, - &mut row, - false, - &HashMap::default(), - ); - if !should_jump || gen_state.registers.is_kernel { - // Don't actually do the read, just set the address, etc. 
- let channel = &mut row.mem_channels[NUM_GP_CHANNELS - 1]; - channel.used = F::ZERO; - channel.value[0] = F::ONE; - } else { - if jumpdest_bit != U256::one() { - return Err(ProgramError::InvalidJumpiDestination); - } - gen_state.traces.push_memory(jumpdest_bit_log); - } - - let diff = row.stack_len - F::TWO; - if let Some(inv) = diff.try_inverse() { - row.general.stack_mut().stack_inv = inv; - row.general.stack_mut().stack_inv_aux = F::ONE; - } else { - row.general.stack_mut().stack_inv = F::ZERO; - row.general.stack_mut().stack_inv_aux = F::ZERO; - } - - gen_state.traces.push_memory(log_cond); - gen_state.traces.push_cpu(is_generation, row); - Ok(()) -} - pub(crate) fn generate_pc( state: &mut GenerationState, mut row: CpuColumnsView, @@ -466,13 +302,7 @@ pub(crate) fn generate_set_context( // Even though we might be in the interpreter, `Stack` is not part of the // preinitialized segments, so we don't need to carry out the additional checks // when get the value from memory. - mem_read_with_log( - GeneralPurpose(2), - new_sp_addr, - state, - false, - &HashMap::default(), - ) + mem_read_with_log(GeneralPurpose(2), new_sp_addr, state) }; // If the new stack isn't empty, read stack_top from memory. @@ -494,14 +324,8 @@ pub(crate) fn generate_set_context( // Even though we might be in the interpreter, `Stack` is not part of the // preinitialized segments, so we don't need to carry out the additional checks // when get the value from memory. 
- let (new_top, log_read_new_top) = mem_read_gp_with_log_and_fill( - 2, - new_top_addr, - state, - &mut row, - false, - &HashMap::default(), - ); + let (new_top, log_read_new_top) = + mem_read_gp_with_log_and_fill(2, new_top_addr, state, &mut row); state.registers.stack_top = new_top; state.traces.push_memory(log_read_new_top); } else { @@ -539,18 +363,10 @@ pub(crate) fn generate_push( .map(|i| { state .memory - .get_with_init( - MemoryAddress { - virt: base_address.virt + i, - ..base_address - }, - // Even though we might be in the interpreter, `Code` is not part of - // the preinitialized segments, so we don't need to carry - // out the additional checks when get the value from - // memory. - false, - &HashMap::default(), - ) + .get_with_init(MemoryAddress { + virt: base_address.virt + i, + ..base_address + }) .low_u32() as u8 }) .collect_vec(); @@ -629,7 +445,7 @@ pub(crate) fn generate_dup( // Even though we might be in the interpreter, `Stack` is not part of the // preinitialized segments, so we don't need to carry out the additional checks // when get the value from memory. - mem_read_gp_with_log_and_fill(2, other_addr, state, &mut row, false, &HashMap::default()) + mem_read_gp_with_log_and_fill(2, other_addr, state, &mut row) }; push_no_write(state, val); @@ -655,8 +471,7 @@ pub(crate) fn generate_swap( // Even though we might be in the interpreter, `Stack` is not part of the // preinitialized segments, so we don't need to carry out the additional checks // when get the value from memory. 
- let (in1, log_in1) = - mem_read_gp_with_log_and_fill(1, other_addr, state, &mut row, false, &HashMap::default()); + let (in1, log_in1) = mem_read_gp_with_log_and_fill(1, other_addr, state, &mut row); let log_out0 = mem_write_gp_log_and_fill(2, other_addr, state, &mut row, in0); push_no_write(state, in1); @@ -725,14 +540,7 @@ fn append_shift( // Even though we might be in the interpreter, `ShiftTable` is not part of the // preinitialized segments, so we don't need to carry out the additional checks // when get the value from memory. - let (_, read) = mem_read_gp_with_log_and_fill( - LOOKUP_CHANNEL, - lookup_addr, - state, - &mut row, - false, - &HashMap::default(), - ); + let (_, read) = mem_read_gp_with_log_and_fill(LOOKUP_CHANNEL, lookup_addr, state, &mut row); state.traces.push_memory(read); } else { // The shift constraints still expect the address to be set, even though no read @@ -845,9 +653,7 @@ pub(crate) fn generate_syscall( // Even though we might be in the interpreter, `Code` is not part of the // preinitialized segments, so we don't need to carry out the additional checks // when get the value from memory. 
- let val = state - .memory - .get_with_init(address, false, &HashMap::default()); + let val = state.memory.get_with_init(address); val.low_u32() as u8 }) .collect_vec(); @@ -953,18 +759,11 @@ pub(crate) fn generate_mload_general( state: &mut GenerationState, mut row: CpuColumnsView, is_generation: bool, - preinitialized_segments: &HashMap, ) -> Result<(), ProgramError> { let [(addr, _)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; - let (val, log_read) = mem_read_gp_with_log_and_fill( - 1, - MemoryAddress::new_bundle(addr)?, - state, - &mut row, - !is_generation, - preinitialized_segments, - ); + let (val, log_read) = + mem_read_gp_with_log_and_fill(1, MemoryAddress::new_bundle(addr)?, state, &mut row); push_no_write(state, val); // Because MLOAD_GENERAL performs 1 pop and 1 push, it does not make use of the @@ -989,7 +788,6 @@ pub(crate) fn generate_mload_32bytes( state: &mut GenerationState, mut row: CpuColumnsView, is_generation: bool, - preinitialized_segments: &HashMap, ) -> Result<(), ProgramError> { let [(addr, _), (len, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; let len = u256_to_usize(len)?; @@ -1010,9 +808,7 @@ pub(crate) fn generate_mload_32bytes( virt: base_address.virt + i, ..base_address }; - let val = state - .memory - .get_with_init(address, !is_generation, preinitialized_segments); + let val = state.memory.get_with_init(address); val.low_u32() as u8 }) .collect_vec(); @@ -1056,33 +852,43 @@ pub(crate) fn generate_mstore_general( Ok(()) } -pub(crate) fn generate_mstore_32bytes( +pub(crate) fn generate_mstore_32bytes>( n: u8, - state: &mut GenerationState, + state: &mut S, mut row: CpuColumnsView, is_generation: bool, ) -> Result<(), ProgramError> { - let [(addr, _), (val, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; + let generation_state = state.get_mut_generation_state(); + let [(addr, _), (val, log_in1)] = + stack_pop_with_log_and_fill::<2, _>(generation_state, &mut row)?; let 
base_address = MemoryAddress::new_bundle(addr)?; - byte_unpacking_log(state, is_generation, base_address, val, n as usize); + byte_unpacking_log( + generation_state, + is_generation, + base_address, + val, + n as usize, + ); let new_addr = addr + n; - push_no_write(state, new_addr); + push_no_write(generation_state, new_addr); - state.traces.push_memory(log_in1); - state.traces.push_cpu(is_generation, row); + generation_state.traces.push_memory(log_in1); + generation_state.traces.push_cpu(is_generation, row); Ok(()) } -pub(crate) fn generate_exception( +pub(crate) fn generate_exception>( exc_code: u8, - state: &mut GenerationState, + state: &mut T, mut row: CpuColumnsView, is_generation: bool, ) -> Result<(), ProgramError> { - if TryInto::::try_into(state.registers.gas_used).is_err() { + state.fill_stack_fields(&mut row)?; + let generation_state = state.get_mut_generation_state(); + if TryInto::::try_into(generation_state.registers.gas_used).is_err() { return Err(ProgramError::GasLimitError); } @@ -1093,8 +899,6 @@ pub(crate) fn generate_exception( row.general.stack_mut().stack_inv_aux = F::ONE; } - fill_stack_fields(state, &mut row, is_generation)?; - row.general.exception_mut().exc_code_bits = [ F::from_bool(exc_code & 1 != 0), F::from_bool(exc_code & 2 != 0), @@ -1112,12 +916,7 @@ pub(crate) fn generate_exception( virt: base_address.virt + i, ..base_address }; - // Even though we might be in the interpreter, `Code` is not part of the - // preinitialized segments, so we don't need to carry out the additional checks - // when get the value from memory. 
- let val = state - .memory - .get_with_init(address, false, &HashMap::default()); + let val = generation_state.memory.get_with_init(address); val.low_u32() as u8 }) .collect_vec(); @@ -1131,22 +930,24 @@ pub(crate) fn generate_exception( jumptable_channel.addr_virtual = F::from_canonical_usize(handler_addr_addr); jumptable_channel.value[0] = F::from_canonical_usize(u256_to_usize(packed_int)?); - byte_packing_log(state, is_generation, base_address, bytes); + byte_packing_log(generation_state, is_generation, base_address, bytes); let new_program_counter = u256_to_usize(packed_int)?; - let gas = U256::from(state.registers.gas_used); + let gas = U256::from(generation_state.registers.gas_used); - let exc_info = U256::from(state.registers.program_counter) + (gas << 192); + let exc_info = U256::from(generation_state.registers.program_counter) + (gas << 192); // Get the opcode so we can provide it to the range_check operation. - let code_context = state.registers.code_context(); - let address = MemoryAddress::new(code_context, Segment::Code, state.registers.program_counter); + let code_context = generation_state.registers.code_context(); + let address = MemoryAddress::new( + code_context, + Segment::Code, + generation_state.registers.program_counter, + ); // Even though we might be in the interpreter, `Code` is not part of the // preinitialized segments, so we don't need to carry out the additional checks // when get the value from memory. - let opcode = state - .memory - .get_with_init(address, false, &HashMap::default()); + let opcode = generation_state.memory.get_with_init(address); // `ArithmeticStark` range checks `mem_channels[0]`, which contains // the top of the stack, `mem_channels[1]`, which contains the new PC, @@ -1155,7 +956,7 @@ pub(crate) fn generate_exception( // Our goal here is to range-check the gas, contained in syscall_info, // stored in the next stack top. 
let range_check_op = arithmetic::Operation::range_check( - state.registers.stack_top, + generation_state.registers.stack_top, packed_int, U256::from(0), opcode, @@ -1165,15 +966,17 @@ pub(crate) fn generate_exception( // kernel mode so we can't incorrectly trigger a stack overflow. However, // note that we have to do it _after_ we make `exc_info`, which should // contain the old values. - state.registers.program_counter = new_program_counter; - state.registers.is_kernel = true; - state.registers.gas_used = 0; + generation_state.registers.program_counter = new_program_counter; + generation_state.registers.is_kernel = true; + generation_state.registers.gas_used = 0; - push_with_write(state, &mut row, exc_info)?; + push_with_write(generation_state, &mut row, exc_info)?; log::debug!("Exception to {}", KERNEL.offset_name(new_program_counter)); - state.traces.push_arithmetic(is_generation, range_check_op); - state.traces.push_cpu(is_generation, row); + generation_state + .traces + .push_arithmetic(is_generation, range_check_op); + generation_state.traces.push_cpu(is_generation, row); Ok(()) } diff --git a/evm_arithmetization/src/witness/transition.rs b/evm_arithmetization/src/witness/transition.rs index 8c15671d8..b6c9c1826 100644 --- a/evm_arithmetization/src/witness/transition.rs +++ b/evm_arithmetization/src/witness/transition.rs @@ -1,20 +1,26 @@ use std::collections::HashMap; +use std::hash::RandomState; use anyhow::bail; +use ethereum_types::U256; use log::log_enabled; use plonky2::field::types::Field; -use super::memory::{MemoryOp, MemoryOpKind}; -use super::util::fill_channel_with_value; +use super::memory::{MemoryOp, MemoryOpKind, MemorySegmentState}; +use super::util::{ + fill_channel_with_value, mem_read_gp_with_log_and_fill, push_no_write, + stack_pop_with_log_and_fill, +}; use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::context_metadata::ContextMetadata; +use 
crate::cpu::membus::NUM_GP_CHANNELS; use crate::cpu::stack::{ EQ_STACK_BEHAVIOR, IS_ZERO_STACK_BEHAVIOR, JUMPI_OP, JUMP_OP, MAX_USER_STACK_SIZE, MIGHT_OVERFLOW, STACK_BEHAVIORS, }; -use crate::generation::state::{GenerationState, GenerationStateCheckpoint}; -use crate::generation::State; +use crate::extension_tower::BN_BASE; +use crate::generation::state::{GenerationState, GenerationStateCheckpoint, State}; use crate::memory::segments::Segment; use crate::witness::errors::ProgramError; use crate::witness::gas::gas_to_charge; @@ -25,13 +31,15 @@ use crate::witness::state::RegistersState; use crate::witness::util::mem_read_code_with_log_and_fill; use crate::{arithmetic, logic}; -fn read_code_memory(state: &mut GenerationState, row: &mut CpuColumnsView) -> u8 { +pub(crate) fn read_code_memory( + state: &mut GenerationState, + row: &mut CpuColumnsView, +) -> u8 { let code_context = state.registers.code_context(); row.code_context = F::from_canonical_usize(code_context); let address = MemoryAddress::new(code_context, Segment::Code, state.registers.program_counter); - let (opcode, mem_log) = - mem_read_code_with_log_and_fill(address, state, row, false, &HashMap::default()); + let (opcode, mem_log) = mem_read_code_with_log_and_fill(address, state, row); state.traces.push_memory(mem_log); @@ -162,7 +170,7 @@ pub(crate) fn decode(registers: RegistersState, opcode: u8) -> Result(op: Operation, row: &mut CpuColumnsView) { +pub(crate) fn fill_op_flag(op: Operation, row: &mut CpuColumnsView) { let flags = &mut row.op; *match op { Operation::Dup(_) | Operation::Swap(_) => &mut flags.dup_swap, @@ -190,7 +198,7 @@ fn fill_op_flag(op: Operation, row: &mut CpuColumnsView) { // Equal to the number of pops if an operation pops without pushing, and `None` // otherwise. 
-const fn get_op_special_length(op: Operation) -> Option { +pub(crate) const fn get_op_special_length(op: Operation) -> Option { let behavior_opt = match op { Operation::Push(0) | Operation::Pc => STACK_BEHAVIORS.pc_push0, Operation::Push(1..) | Operation::ProverInput => STACK_BEHAVIORS.push_prover_input, @@ -231,7 +239,7 @@ const fn get_op_special_length(op: Operation) -> Option { // These operations might trigger a stack overflow, typically those pushing // without popping. Kernel-only pushing instructions aren't considered; they // can't overflow. -const fn might_overflow_op(op: Operation) -> bool { +pub(crate) const fn might_overflow_op(op: Operation) -> bool { match op { Operation::Push(1..) | Operation::ProverInput => MIGHT_OVERFLOW.push_prover_input, Operation::Dup(_) | Operation::Swap(_) => MIGHT_OVERFLOW.dup_swap, @@ -258,273 +266,7 @@ const fn might_overflow_op(op: Operation) -> bool { } } -fn perform_op( - any_state: &mut State, - op: Operation, - opcode: u8, - row: CpuColumnsView, -) -> Result<(), ProgramError> { - let (op, is_generation, preinitialized_segments, is_jumpdest_analysis) = match any_state { - State::Generation(_state) => (op, true, HashMap::default(), false), - State::Interpreter(interpreter) => { - // Jumpdest analysis is performed natively by the interpreter and not - // using the non-deterministic Kernel assembly code. - let op = if interpreter.is_kernel() - && interpreter.is_jumpdest_analysis - && interpreter.generation_state.registers.program_counter - == KERNEL.global_labels["jumpdest_analysis"] - { - interpreter.generation_state.registers.program_counter = - KERNEL.global_labels["jumpdest_analysis_end"]; - interpreter - .generation_state - .set_jumpdest_bits(&interpreter.generation_state.get_current_code()?); - let opcode = interpreter - .code() - .get(interpreter.generation_state.registers.program_counter) - .byte(0); - - decode(interpreter.generation_state.registers, opcode)? 
- } else { - op - }; - ( - op, - false, - interpreter.preinitialized_segments.clone(), - interpreter.is_jumpdest_analysis, - ) - } - }; - - #[cfg(debug_assertions)] - if !any_state.get_registers().is_kernel { - log::debug!( - "User instruction {:?}, stack = {:?}, ctx = {}", - op, - { - let mut stack = any_state.get_stack(); - stack.reverse(); - stack - }, - any_state.get_registers().context - ); - } - - let state = any_state.get_mut_generation_state(); - - match op { - Operation::Push(n) => generate_push(n, state, row, is_generation)?, - Operation::Dup(n) => generate_dup(n, state, row, is_generation)?, - Operation::Swap(n) => generate_swap(n, state, row, is_generation)?, - Operation::Iszero => generate_iszero(state, row, is_generation)?, - Operation::Not => generate_not(state, row, is_generation)?, - Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shl) => { - generate_shl(state, row, is_generation)? - } - Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shr) => { - generate_shr(state, row, is_generation)? - } - Operation::Syscall(opcode, stack_values_read, stack_len_increased) => generate_syscall( - opcode, - stack_values_read, - stack_len_increased, - state, - row, - is_generation, - )?, - Operation::Eq => generate_eq(state, row, is_generation)?, - Operation::BinaryLogic(binary_logic_op) => { - generate_binary_logic_op(binary_logic_op, state, row, is_generation)? - } - Operation::BinaryArithmetic(op) => { - generate_binary_arithmetic_op(op, state, row, is_generation)? - } - Operation::TernaryArithmetic(op) => { - generate_ternary_arithmetic_op(op, state, row, is_generation)? - } - Operation::KeccakGeneral => { - generate_keccak_general(state, row, is_generation, &preinitialized_segments)? 
- } - Operation::ProverInput => generate_prover_input(state, row, is_generation)?, - Operation::Pop => generate_pop(state, row, is_generation)?, - Operation::Jump => generate_jump(any_state, row, is_generation, is_jumpdest_analysis)?, - Operation::Jumpi => generate_jumpi(any_state, row, is_generation, is_jumpdest_analysis)?, - Operation::Pc => generate_pc(state, row, is_generation)?, - Operation::Jumpdest => generate_jumpdest(state, row, is_generation)?, - Operation::GetContext => generate_get_context(state, row, is_generation)?, - Operation::SetContext => generate_set_context(state, row, is_generation)?, - Operation::Mload32Bytes => { - generate_mload_32bytes(state, row, is_generation, &preinitialized_segments)? - } - Operation::Mstore32Bytes(n) => generate_mstore_32bytes(n, state, row, is_generation)?, - Operation::ExitKernel => generate_exit_kernel(state, row, is_generation)?, - Operation::MloadGeneral => { - generate_mload_general(state, row, is_generation, &preinitialized_segments)? 
- } - Operation::MstoreGeneral => generate_mstore_general(state, row, is_generation)?, - }; - match any_state { - State::Generation(_state) => {} - State::Interpreter(interpreter) => { - interpreter.opcode_count[opcode as usize] += 1; - } - } - Ok(()) -} - -fn perform_state_op( - any_state: &mut State, - opcode: u8, - op: Operation, - row: CpuColumnsView, -) -> Result { - perform_op(any_state, op, opcode, row)?; - any_state.incr_pc(match op { - Operation::Syscall(_, _, _) | Operation::ExitKernel => 0, - Operation::Push(n) => n as usize + 1, - Operation::Jump | Operation::Jumpi => 0, - _ => 1, - }); - - any_state.incr_gas(gas_to_charge(op)); - let registers = any_state.get_registers(); - let gas_limit_address = MemoryAddress::new( - registers.context, - Segment::ContextMetadata, - ContextMetadata::GasLimit.unscale(), // context offsets are already scaled - ); - - if !registers.is_kernel { - let gas_limit = TryInto::::try_into(any_state.get_from_memory(gas_limit_address)); - match gas_limit { - Ok(limit) => { - if registers.gas_used > limit { - return Err(ProgramError::OutOfGas); - } - } - Err(_) => return Err(ProgramError::IntegerTooLarge), - } - } - - Ok(op) -} - -/// Row that has the correct values for system registers and the code channel, -/// but is otherwise blank. It fulfills the constraints that are common to -/// successful operations and the exception operation. It also returns the -/// opcode. 
-fn base_row(state: &mut GenerationState) -> (CpuColumnsView, u8) { - let mut row: CpuColumnsView = CpuColumnsView::default(); - row.clock = F::from_canonical_usize(state.traces.clock()); - row.context = F::from_canonical_usize(state.registers.context); - row.program_counter = F::from_canonical_usize(state.registers.program_counter); - row.is_kernel_mode = F::from_bool(state.registers.is_kernel); - row.gas = F::from_canonical_u64(state.registers.gas_used); - row.stack_len = F::from_canonical_usize(state.registers.stack_len); - fill_channel_with_value(&mut row, 0, state.registers.stack_top); - - let opcode = read_code_memory(state, &mut row); - (row, opcode) -} - -pub(crate) fn fill_stack_fields( - state: &mut GenerationState, - row: &mut CpuColumnsView, - is_generation: bool, -) -> Result<(), ProgramError> { - if state.registers.is_stack_top_read && is_generation { - let channel = &mut row.mem_channels[0]; - channel.used = F::ONE; - channel.is_read = F::ONE; - channel.addr_context = F::from_canonical_usize(state.registers.context); - channel.addr_segment = F::from_canonical_usize(Segment::Stack.unscale()); - channel.addr_virtual = F::from_canonical_usize(state.registers.stack_len - 1); - - let address = MemoryAddress::new( - state.registers.context, - Segment::Stack, - state.registers.stack_len - 1, - ); - - let mem_op = MemoryOp::new( - GeneralPurpose(0), - state.traces.clock(), - address, - MemoryOpKind::Read, - state.registers.stack_top, - ); - state.traces.push_memory(mem_op); - } - state.registers.is_stack_top_read = false; - - if state.registers.check_overflow && is_generation { - if state.registers.is_kernel { - row.general.stack_mut().stack_len_bounds_aux = F::ZERO; - } else { - let clock = state.traces.clock(); - let last_row = &mut state.traces.cpu[clock - 1]; - let disallowed_len = F::from_canonical_usize(MAX_USER_STACK_SIZE + 1); - let diff = row.stack_len - disallowed_len; - if let Some(inv) = diff.try_inverse() { - 
last_row.general.stack_mut().stack_len_bounds_aux = inv; - } else { - // This is a stack overflow that should have been caught earlier. - return Err(ProgramError::InterpreterError); - } - } - } - state.registers.check_overflow = false; - - Ok(()) -} - -fn try_perform_instruction(any_state: &mut State) -> Result { - let is_generation = any_state.is_generation_state(); - let registers = any_state.get_registers(); - - let state = any_state.get_mut_generation_state(); - let (mut row, opcode) = base_row(state); - - let op = decode(registers, opcode)?; - - if registers.is_kernel { - log_kernel_instruction(state, op); - } else { - log::debug!("User instruction: {:?}", op); - } - - if is_generation { - fill_op_flag(op, &mut row); - } - - fill_stack_fields(state, &mut row, is_generation)?; - - // Might write in general CPU columns when it shouldn't, but the correct values - // will overwrite these ones during the op generation. - if let Some(special_len) = get_op_special_length(op) { - let special_len_f = F::from_canonical_usize(special_len); - let diff = row.stack_len - special_len_f; - if let Some(inv) = diff.try_inverse() - && is_generation - { - row.general.stack_mut().stack_inv = inv; - row.general.stack_mut().stack_inv_aux = F::ONE; - state.registers.is_stack_top_read = true; - } else if !is_generation && (state.stack().len() != special_len) { - // If the `State` is an interpreter, we cannot rely on the row to carry out the - // check. - state.registers.is_stack_top_read = true; - } - } else if let Some(inv) = row.stack_len.try_inverse() { - row.general.stack_mut().stack_inv = inv; - row.general.stack_mut().stack_inv_aux = F::ONE; - } - - perform_state_op(any_state, opcode, op, row) -} - -fn log_kernel_instruction(state: &GenerationState, op: Operation) { +pub(crate) fn log_kernel_instruction(state: &mut GenerationState, op: Operation) { // The logic below is a bit costly, so skip it if debug logs aren't enabled. 
if !log_enabled!(log::Level::Debug) { return; @@ -553,63 +295,258 @@ fn log_kernel_instruction(state: &GenerationState, op: Operation) { assert!(pc < KERNEL.code.len(), "Kernel PC is out of range: {}", pc); } -fn handle_error(any_state: &mut State, err: ProgramError) -> anyhow::Result<()> { - let exc_code: u8 = match err { - ProgramError::OutOfGas => 0, - ProgramError::InvalidOpcode => 1, - ProgramError::StackUnderflow => 2, - ProgramError::InvalidJumpDestination => 3, - ProgramError::InvalidJumpiDestination => 4, - ProgramError::StackOverflow => 5, - _ => bail!("TODO: figure out what to do with this..."), - }; +pub(crate) trait Transition: State { + fn generate_jumpdest_analysis(&mut self, dst: usize) -> bool; + + fn perform_state_op( + &mut self, + opcode: u8, + op: Operation, + row: CpuColumnsView, + ) -> Result { + self.perform_op(op, opcode, row)?; + self.incr_pc(match op { + Operation::Syscall(_, _, _) | Operation::ExitKernel => 0, + Operation::Push(n) => n as usize + 1, + Operation::Jump | Operation::Jumpi => 0, + _ => 1, + }); + + self.incr_gas(gas_to_charge(op)); + let registers = self.get_registers(); + let gas_limit_address = MemoryAddress::new( + registers.context, + Segment::ContextMetadata, + ContextMetadata::GasLimit.unscale(), // context offsets are already scaled + ); - let (checkpoint, is_generation, state): ( - GenerationStateCheckpoint, - bool, - &mut GenerationState<_>, - ) = match any_state { - State::Generation(state) => (state.checkpoint(), true, state), - State::Interpreter(interpreter) => ( - interpreter.checkpoint(), - false, - &mut interpreter.generation_state, - ), - }; - let (row, _) = base_row(state); - generate_exception(exc_code, state, row, is_generation) - .map_err(|_| anyhow::Error::msg("error handling errored..."))?; - any_state.apply_ops(checkpoint); + if !registers.is_kernel { + let gas_limit = TryInto::::try_into(self.get_from_memory(gas_limit_address)); + match gas_limit { + Ok(limit) => { + if registers.gas_used > limit { + 
return Err(ProgramError::OutOfGas); + } + } + Err(_) => return Err(ProgramError::IntegerTooLarge), + } + } - Ok(()) -} + Ok(op) + } -pub(crate) fn transition(any_state: &mut State) -> anyhow::Result<()> { - let checkpoint = any_state.checkpoint(); - let result = try_perform_instruction(any_state); + fn generate_jump( + &mut self, + mut row: CpuColumnsView, + is_generation: bool, + ) -> Result<(), ProgramError> { + let [(dst, _)] = + stack_pop_with_log_and_fill::<1, _>(self.get_mut_generation_state(), &mut row)?; + + let dst: u32 = dst + .try_into() + .map_err(|_| ProgramError::InvalidJumpDestination)?; + + if !self.generate_jumpdest_analysis(dst as usize) { + let gen_state = self.get_mut_generation_state(); + let (jumpdest_bit, jumpdest_bit_log) = mem_read_gp_with_log_and_fill( + NUM_GP_CHANNELS - 1, + MemoryAddress::new( + gen_state.registers.context, + Segment::JumpdestBits, + dst as usize, + ), + gen_state, + &mut row, + ); - match result { - Ok(op) => { - any_state.apply_ops(checkpoint); + row.mem_channels[1].value[0] = F::ONE; - if might_overflow_op(op) { - any_state.get_mut_registers().check_overflow = true; + if gen_state.registers.is_kernel { + // Don't actually do the read, just set the address, etc. + let channel = &mut row.mem_channels[NUM_GP_CHANNELS - 1]; + channel.used = F::ZERO; + channel.value[0] = F::ONE; + } else { + if jumpdest_bit != U256::one() { + return Err(ProgramError::InvalidJumpDestination); + } + gen_state.traces.push_memory(jumpdest_bit_log); } - Ok(()) + + // Extra fields required by the constraints. 
+ row.general.jumps_mut().should_jump = F::ONE; + row.general.jumps_mut().cond_sum_pinv = F::ONE; + + let diff = row.stack_len - F::ONE; + if let Some(inv) = diff.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + } else { + row.general.stack_mut().stack_inv = F::ZERO; + row.general.stack_mut().stack_inv_aux = F::ZERO; + } + + gen_state.traces.push_cpu(is_generation, row); } - Err(e) => { - if any_state.get_registers().is_kernel { - let offset_name = KERNEL.offset_name(any_state.get_registers().program_counter); - bail!( - "{:?} in kernel at pc={}, stack={:?}, memory={:?}", - e, - offset_name, - any_state.get_stack(), - any_state.mem_get_kernel_content(), - ); + self.get_mut_generation_state().jump_to(dst as usize)?; + Ok(()) + } + + fn generate_jumpi( + &mut self, + mut row: CpuColumnsView, + is_generation: bool, + ) -> Result<(), ProgramError> { + let [(dst, _), (cond, log_cond)] = + stack_pop_with_log_and_fill::<2, _>(self.get_mut_generation_state(), &mut row)?; + + let should_jump = !cond.is_zero(); + if should_jump { + let dst: u32 = dst + .try_into() + .map_err(|_| ProgramError::InvalidJumpiDestination)?; + let is_kernel = self.get_registers().is_kernel; + if !self.generate_jumpdest_analysis(dst as usize) { + row.general.jumps_mut().should_jump = F::ONE; + let cond_sum_u64 = cond + .0 + .into_iter() + .map(|limb| ((limb as u32) as u64) + (limb >> 32)) + .sum(); + let cond_sum = F::from_canonical_u64(cond_sum_u64); + row.general.jumps_mut().cond_sum_pinv = cond_sum.inverse(); } - any_state.rollback(checkpoint); - handle_error(any_state, e) + self.get_mut_generation_state().jump_to(dst as usize)?; + } else { + row.general.jumps_mut().should_jump = F::ZERO; + row.general.jumps_mut().cond_sum_pinv = F::ZERO; + self.incr_pc(1); } + + let gen_state = self.get_mut_generation_state(); + let (jumpdest_bit, jumpdest_bit_log) = mem_read_gp_with_log_and_fill( + NUM_GP_CHANNELS - 1, + MemoryAddress::new( + 
gen_state.registers.context, + Segment::JumpdestBits, + dst.low_u32() as usize, + ), + gen_state, + &mut row, + ); + if !should_jump || gen_state.registers.is_kernel { + // Don't actually do the read, just set the address, etc. + let channel = &mut row.mem_channels[NUM_GP_CHANNELS - 1]; + channel.used = F::ZERO; + channel.value[0] = F::ONE; + } else { + if jumpdest_bit != U256::one() { + return Err(ProgramError::InvalidJumpiDestination); + } + gen_state.traces.push_memory(jumpdest_bit_log); + } + + let diff = row.stack_len - F::TWO; + if let Some(inv) = diff.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + } else { + row.general.stack_mut().stack_inv = F::ZERO; + row.general.stack_mut().stack_inv_aux = F::ZERO; + } + + gen_state.traces.push_memory(log_cond); + gen_state.traces.push_cpu(is_generation, row); + Ok(()) } + + /// Skips the following instructions for some specific labels + fn skip_if_necessary(&mut self, op: Operation) -> Result; + + fn perform_op( + &mut self, + op: Operation, + opcode: u8, + row: CpuColumnsView, + ) -> Result<(), ProgramError> { + let op = self.skip_if_necessary(op)?; + + #[cfg(debug_assertions)] + if !self.get_registers().is_kernel { + log::debug!( + "User instruction {:?}, stack = {:?}, ctx = {}", + op, + { + let mut stack = self.get_stack(); + stack.reverse(); + stack + }, + self.get_registers().context + ); + } + + let is_generation = self.is_generation(); + let generation_state = self.get_mut_generation_state(); + + match op { + Operation::Push(n) => generate_push(n, generation_state, row, is_generation)?, + Operation::Dup(n) => generate_dup(n, generation_state, row, is_generation)?, + Operation::Swap(n) => generate_swap(n, generation_state, row, is_generation)?, + Operation::Iszero => generate_iszero(generation_state, row, is_generation)?, + Operation::Not => generate_not(generation_state, row, is_generation)?, + 
Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shl) => { + generate_shl(generation_state, row, is_generation)? + } + Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shr) => { + generate_shr(generation_state, row, is_generation)? + } + Operation::Syscall(opcode, stack_values_read, stack_len_increased) => generate_syscall( + opcode, + stack_values_read, + stack_len_increased, + generation_state, + row, + is_generation, + )?, + Operation::Eq => generate_eq(generation_state, row, is_generation)?, + Operation::BinaryLogic(binary_logic_op) => { + generate_binary_logic_op(binary_logic_op, generation_state, row, is_generation)? + } + Operation::BinaryArithmetic(op) => { + generate_binary_arithmetic_op(op, generation_state, row, is_generation)? + } + Operation::TernaryArithmetic(op) => { + generate_ternary_arithmetic_op(op, generation_state, row, is_generation)? + } + Operation::KeccakGeneral => { + generate_keccak_general(generation_state, row, is_generation)? + } + Operation::ProverInput => generate_prover_input(generation_state, row, is_generation)?, + Operation::Pop => generate_pop(generation_state, row, is_generation)?, + Operation::Jump => self.generate_jump(row, is_generation)?, + Operation::Jumpi => self.generate_jumpi(row, is_generation)?, + Operation::Pc => generate_pc(generation_state, row, is_generation)?, + Operation::Jumpdest => generate_jumpdest(generation_state, row, is_generation)?, + Operation::GetContext => generate_get_context(generation_state, row, is_generation)?, + Operation::SetContext => generate_set_context(generation_state, row, is_generation)?, + Operation::Mload32Bytes => { + generate_mload_32bytes(generation_state, row, is_generation)? + } + Operation::Mstore32Bytes(n) => { + generate_mstore_32bytes(n, generation_state, row, is_generation)? + } + Operation::ExitKernel => generate_exit_kernel(generation_state, row, is_generation)?, + Operation::MloadGeneral => { + generate_mload_general(generation_state, row, is_generation)? 
+ } + Operation::MstoreGeneral => { + generate_mstore_general(generation_state, row, is_generation)? + } + }; + + Ok(()) + } + + fn fill_stack_fields(&mut self, row: &mut CpuColumnsView) -> Result<(), ProgramError>; } diff --git a/evm_arithmetization/src/witness/util.rs b/evm_arithmetization/src/witness/util.rs index d6790164c..0f3d44edb 100644 --- a/evm_arithmetization/src/witness/util.rs +++ b/evm_arithmetization/src/witness/util.rs @@ -43,15 +43,11 @@ pub(crate) fn stack_peek( return Ok(state.registers.stack_top); } - Ok(state.memory.get_with_init( - MemoryAddress::new( - state.registers.context, - Segment::Stack, - state.registers.stack_len - 1 - i, - ), - false, - &HashMap::default(), - )) + Ok(state.memory.get_with_init(MemoryAddress::new( + state.registers.context, + Segment::Stack, + state.registers.stack_len - 1 - i, + ))) } /// Peek at kernel at specified segment and address @@ -59,15 +55,11 @@ pub(crate) fn current_context_peek( state: &GenerationState, segment: Segment, virt: usize, - is_interpreter: bool, - preinitialized_segments: &HashMap, ) -> U256 { let context = state.registers.context; - state.memory.get_with_init( - MemoryAddress::new(context, segment, virt), - is_interpreter, - preinitialized_segments, - ) + state + .memory + .get_with_init(MemoryAddress::new(context, segment, virt)) } pub(crate) fn fill_channel_with_value(row: &mut CpuColumnsView, n: usize, val: U256) { @@ -120,12 +112,8 @@ pub(crate) fn mem_read_with_log( channel: MemoryChannel, address: MemoryAddress, state: &GenerationState, - is_interpreter: bool, - preinitialized_segments: &HashMap, ) -> (U256, MemoryOp) { - let val = state - .memory - .get_with_init(address, is_interpreter, preinitialized_segments); + let val = state.memory.get_with_init(address); let op = MemoryOp::new( channel, state.traces.clock(), @@ -155,16 +143,8 @@ pub(crate) fn mem_read_code_with_log_and_fill( address: MemoryAddress, state: &GenerationState, row: &mut CpuColumnsView, - is_interpreter: bool, - 
preinitialized_segments: &HashMap, ) -> (u8, MemoryOp) { - let (val, op) = mem_read_with_log( - MemoryChannel::Code, - address, - state, - is_interpreter, - preinitialized_segments, - ); + let (val, op) = mem_read_with_log(MemoryChannel::Code, address, state); let val_u8 = to_byte_checked(val); row.opcode_bits = to_bits_le(val_u8); @@ -177,16 +157,8 @@ pub(crate) fn mem_read_gp_with_log_and_fill( address: MemoryAddress, state: &GenerationState, row: &mut CpuColumnsView, - is_interpreter: bool, - preinitialized_segments: &HashMap, ) -> (U256, MemoryOp) { - let (val, op) = mem_read_with_log( - MemoryChannel::GeneralPurpose(n), - address, - state, - is_interpreter, - preinitialized_segments, - ); + let (val, op) = mem_read_with_log(MemoryChannel::GeneralPurpose(n), address, state); let val_limbs: [u64; 4] = val.0; let channel = &mut row.mem_channels[n]; @@ -275,7 +247,7 @@ pub(crate) fn stack_pop_with_log_and_fill( state.registers.stack_len - 1 - i, ); - mem_read_gp_with_log_and_fill(i, address, state, row, false, &HashMap::default()) + mem_read_gp_with_log_and_fill(i, address, state, row) } });