Skip to content
This repository has been archived by the owner on Oct 31, 2024. It is now read-only.

Create State and Transition traits #1

Merged
merged 10 commits into from
Feb 27, 2024
233 changes: 189 additions & 44 deletions evm_arithmetization/src/cpu/kernel/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,18 @@ use crate::generation::rlp::all_rlp_prover_inputs_reversed;
use crate::generation::state::{
all_withdrawals_prover_inputs_reversed, GenerationState, GenerationStateCheckpoint,
};
use crate::generation::{run_cpu, GenerationInputs, State};
use crate::generation::{state::State, GenerationInputs};
use crate::memory::segments::Segment;
use crate::util::h2u;
use crate::witness::errors::ProgramError;
use crate::witness::memory::{
MemoryAddress, MemoryContextState, MemoryOp, MemoryOpKind, MemorySegmentState,
};
use crate::witness::operation::Operation;
use crate::witness::state::RegistersState;
use crate::witness::transition::{
decode, fill_op_flag, get_op_special_length, log_kernel_instruction, Transition,
};

type F = GoldilocksField;

Expand All @@ -45,11 +49,6 @@ pub(crate) struct Interpreter<F: Field> {
/// Counts the number of appearances of each opcode. For debugging purposes.
pub(crate) opcode_count: [usize; 0x100],
jumpdest_table: HashMap<usize, BTreeSet<usize>>,
/// Segments that can be preinitialized: they are not stored in the
/// interpreter memory unless they are read/written during the execution.
/// When the values are first read, they are read from this `HashMap` (and
/// the value is then written in memory).
pub(crate) preinitialized_segments: HashMap<Segment, MemorySegmentState>,
/// `true` if the we are currently carrying out a jumpdest analysis.
pub(crate) is_jumpdest_analysis: bool,
/// Holds the value of the clock: the clock counts the number of operations
Expand Down Expand Up @@ -160,7 +159,6 @@ impl<F: Field> Interpreter<F> {
halt_context: None,
opcode_count: [0; 256],
jumpdest_table: HashMap::new(),
preinitialized_segments: HashMap::default(),
is_jumpdest_analysis: false,
clock: 0,
};
Expand Down Expand Up @@ -190,7 +188,6 @@ impl<F: Field> Interpreter<F> {
halt_context: Some(halt_context),
opcode_count: [0; 256],
jumpdest_table: HashMap::new(),
preinitialized_segments: HashMap::new(),
is_jumpdest_analysis: true,
clock: 0,
}
Expand All @@ -215,8 +212,7 @@ impl<F: Field> Interpreter<F> {
let preinit_trie_data_segment = MemorySegmentState {
content: trie_data.iter().map(|&elt| Some(elt)).collect::<Vec<_>>(),
};
self.preinitialized_segments
.insert(Segment::TrieData, preinit_trie_data_segment);
self.insert_preinitialized_segment(Segment::TrieData, preinit_trie_data_segment);

// Update the RLP and withdrawal prover inputs.
let rlp_prover_inputs =
Expand Down Expand Up @@ -332,11 +328,7 @@ impl<F: Field> Interpreter<F> {
match kind {
MemoryOpKind::Read => {
if self.generation_state.memory.get(address).is_none() {
if !self
.preinitialized_segments
.contains_key(&Segment::all()[address.segment])
&& !value.is_zero()
{
if !self.is_preinitialized_segment(address.segment) && !value.is_zero() {
return Err(anyhow!("The initial value {:?} at address {:?} should be zero, because it is not preinitialized.", value, address));
}
self.generation_state.memory.set(address, value);
Expand All @@ -349,19 +341,8 @@ impl<F: Field> Interpreter<F> {
Ok(())
}

/// Returns a `GenerationStateCheckpoint` to save the current registers and
/// reset memory operations to the empty vector.
pub(crate) fn checkpoint(&mut self) -> GenerationStateCheckpoint {
self.generation_state.traces.memory_ops = vec![];
GenerationStateCheckpoint {
registers: self.generation_state.registers,
traces: self.generation_state.traces.checkpoint(),
}
}

pub(crate) fn run(&mut self) -> Result<(), anyhow::Error> {
let mut state = State::Interpreter(self);
run_cpu(&mut state)?;
self.run_cpu()?;

#[cfg(debug_assertions)]
{
Expand Down Expand Up @@ -458,7 +439,7 @@ impl<F: Field> Interpreter<F> {
}

pub(crate) fn get_memory_segment(&self, segment: Segment) -> Vec<U256> {
if self.preinitialized_segments.contains_key(&segment) {
if self.is_preinitialized_segment(segment.unscale()) {
let total_len = self.generation_state.memory.contexts[0].segments[segment.unscale()]
.content
.len();
Expand All @@ -473,8 +454,9 @@ impl<F: Field> Interpreter<F> {
};
let mut res = get_vals(
&self
.preinitialized_segments
.get(&segment)
.generation_state
.memory
.get_preinitialized_segment(segment)
.expect("The segment should be in the preinitialized segments.")
.content,
);
Expand Down Expand Up @@ -607,6 +589,7 @@ impl<F: Field> Interpreter<F> {
}
}
}

fn stack_segment_mut(&mut self) -> &mut Vec<Option<U256>> {
let context = self.context();
&mut self.generation_state.memory.contexts[context].segments[Segment::Stack.unscale()]
Expand All @@ -616,11 +599,10 @@ impl<F: Field> Interpreter<F> {
pub(crate) fn extract_kernel_memory(self, segment: Segment, range: Range<usize>) -> Vec<U256> {
let mut output: Vec<U256> = Vec::with_capacity(range.end);
for i in range {
let term = self.generation_state.memory.get_with_init(
MemoryAddress::new(0, segment, i),
true,
&self.preinitialized_segments,
);
let term = self
.generation_state
.memory
.get_with_init(MemoryAddress::new(0, segment, i));
output.push(term);
}
output
Expand Down Expand Up @@ -682,15 +664,11 @@ impl<F: Field> Interpreter<F> {
// Even though we are in the interpreter, `JumpdestBits` is not part of the
// preinitialized segments, so we don't need to carry out the additional checks
// when get the value from memory.
self.generation_state.memory.get_with_init(
MemoryAddress {
context: self.context(),
segment: Segment::JumpdestBits.unscale(),
virt: offset,
},
false,
&HashMap::default(),
)
self.generation_state.memory.get_with_init(MemoryAddress {
context: self.context(),
segment: Segment::JumpdestBits.unscale(),
virt: offset,
})
} else {
0.into()
}
Expand Down Expand Up @@ -757,6 +735,173 @@ impl<F: Field> Interpreter<F> {
}
}

impl<F: Field> State<F> for Interpreter<F> {
//// Returns a `GenerationStateCheckpoint` to save the current registers and
/// reset memory operations to the empty vector.
fn checkpoint(&mut self) -> GenerationStateCheckpoint {
self.generation_state.traces.memory_ops = vec![];
GenerationStateCheckpoint {
registers: self.generation_state.registers,
traces: self.generation_state.traces.checkpoint(),
}
}

fn is_generation(&mut self) -> bool {
false
}

fn insert_preinitialized_segment(&mut self, segment: Segment, values: MemorySegmentState) {
self.generation_state
.memory
.insert_preinitialized_segment(segment, values);
}

fn is_preinitialized_segment(&self, segment: usize) -> bool {
self.generation_state
.memory
.is_preinitialized_segment(segment)
}

fn incr_gas(&mut self, n: u64) {
self.generation_state.incr_gas(n);
}

fn incr_pc(&mut self, n: usize) {
self.generation_state.incr_pc(n);
}

fn get_registers(&self) -> RegistersState {
self.generation_state.get_registers()
}

fn get_mut_registers(&mut self) -> &mut RegistersState {
self.generation_state.get_mut_registers()
}

fn get_from_memory(&mut self, address: MemoryAddress) -> U256 {
self.generation_state.memory.get_with_init(address)
}

fn get_mut_generation_state(&mut self) -> &mut GenerationState<F> {
&mut self.generation_state
}

fn incr_interpreter_clock(&mut self) {
self.clock += 1
}

fn get_clock(&mut self) -> usize {
self.clock
}

fn rollback(&mut self, checkpoint: GenerationStateCheckpoint) {
self.generation_state.rollback(checkpoint)
}

fn get_context(&mut self) -> usize {
self.context()
}

fn get_halt_context(&mut self) -> Option<usize> {
self.halt_context
}

fn mem_get_kernel_content(&self) -> Vec<Option<U256>> {
self.generation_state.memory.contexts[0].segments[Segment::KernelGeneral.unscale()]
.content
.clone()
}

fn apply_ops(&mut self, checkpoint: GenerationStateCheckpoint) {
self.apply_memops();
}

fn get_stack(&mut self) -> Vec<U256> {
self.stack().clone()
}

fn get_halt_offsets(&self) -> Vec<usize> {
self.halt_offsets.clone()
}

fn try_perform_instruction(&mut self) -> Result<Operation, ProgramError> {
let registers = self.generation_state.registers;
let (mut row, opcode) = self.base_row();

let op = decode(registers, opcode)?;

fill_op_flag(op, &mut row);

self.fill_stack_fields(&mut row)?;

let generation_state = self.get_mut_generation_state();
if registers.is_kernel {
log_kernel_instruction(generation_state, op);
} else {
log::debug!("User instruction: {:?}", op);
}

// Might write in general CPU columns when it shouldn't, but the correct values
// will overwrite these ones during the op generation.
if let Some(special_len) = get_op_special_length(op) {
let special_len_f = F::from_canonical_usize(special_len);
let diff = row.stack_len - special_len_f;
if (generation_state.stack().len() != special_len) {
// If the `State` is an interpreter, we cannot rely on the row to carry out the
// check.
generation_state.registers.is_stack_top_read = true;
}
} else if let Some(inv) = row.stack_len.try_inverse() {
row.general.stack_mut().stack_inv = inv;
row.general.stack_mut().stack_inv_aux = F::ONE;
}

self.perform_state_op(opcode, op, row)
}
}

impl<F: Field> Transition<F> for Interpreter<F> {
fn generate_jumpdest_analysis(&mut self, dst: usize) -> bool {
if self.is_jumpdest_analysis && !self.generation_state.registers.is_kernel {
self.add_jumpdest_offset(dst);
true
} else {
false
}
}

fn skip_if_necessary(&mut self, op: Operation) -> Result<Operation, ProgramError> {
if self.is_kernel()
&& self.is_jumpdest_analysis
&& self.generation_state.registers.program_counter
== KERNEL.global_labels["jumpdest_analysis"]
{
self.generation_state.registers.program_counter =
KERNEL.global_labels["jumpdest_analysis_end"];
self.generation_state
.set_jumpdest_bits(&self.generation_state.get_current_code()?);
let opcode = self
.code()
.get(self.generation_state.registers.program_counter)
.byte(0);

decode(self.generation_state.registers, opcode)
} else {
Ok(op)
}
}

fn fill_stack_fields(
&mut self,
row: &mut crate::cpu::columns::CpuColumnsView<F>,
) -> Result<(), ProgramError> {
self.generation_state.registers.is_stack_top_read = false;
self.generation_state.registers.check_overflow = false;

Ok(())
}
}

fn get_mnemonic(opcode: u8) -> &'static str {
match opcode {
0x00 => "STOP",
Expand Down
Loading