diff --git a/crates/polkavm-assembler/src/amd64.rs b/crates/polkavm-assembler/src/amd64.rs index 7ea5668e..a6a37e1c 100644 --- a/crates/polkavm-assembler/src/amd64.rs +++ b/crates/polkavm-assembler/src/amd64.rs @@ -2157,7 +2157,7 @@ mod tests { } self.disassembly_1.pop(); - disassemble_into(code, &mut self.disassembly_2); + disassemble_into(&code, &mut self.disassembly_2); assert_eq!(self.disassembly_1, self.disassembly_2, "broken encoding for: {inst:?}"); } } @@ -2244,7 +2244,7 @@ mod tests { let mut asm = crate::Assembler::new(); let label = asm.forward_declare_label(); asm.push_with_label(label, jmp_label8(label)); - let disassembly = disassemble(asm.finalize()); + let disassembly = disassemble(&asm.finalize()); assert_eq!(disassembly, "00000000 ebfe jmp short 0x0"); } @@ -2254,7 +2254,7 @@ mod tests { let mut asm = crate::Assembler::new(); let label = asm.forward_declare_label(); asm.push_with_label(label, jmp_label32(label)); - let disassembly = disassemble(asm.finalize()); + let disassembly = disassemble(&asm.finalize()); assert_eq!(disassembly, "00000000 e9fbffffff jmp 0x0"); } @@ -2264,7 +2264,7 @@ mod tests { let mut asm = crate::Assembler::new(); let label = asm.forward_declare_label(); asm.push_with_label(label, call_label32(label)); - let disassembly = disassemble(asm.finalize()); + let disassembly = disassemble(&asm.finalize()); assert_eq!(disassembly, "00000000 e8fbffffff call 0x0"); } @@ -2275,7 +2275,7 @@ mod tests { let mut asm = crate::Assembler::new(); let label = asm.forward_declare_label(); asm.push_with_label(label, jcc_label8(cond, label)); - let disassembly = disassemble(asm.finalize()); + let disassembly = disassemble(&asm.finalize()); assert_eq!( disassembly, format!("00000000 {:02x}fe j{} short 0x0", 0x70 + cond as u8, cond.suffix()) @@ -2291,7 +2291,7 @@ mod tests { let label = asm.forward_declare_label(); asm.push(jcc_label8(cond, label)); asm.push_with_label(label, nop()); - let disassembly = disassemble(asm.finalize()); + let 
disassembly = disassemble(&asm.finalize()); assert_eq!( disassembly, format!( @@ -2311,7 +2311,7 @@ mod tests { let label = asm.forward_declare_label(); asm.push_with_label(label, nop()); asm.push(jcc_label8(cond, label)); - let disassembly = disassemble(asm.finalize()); + let disassembly = disassemble(&asm.finalize()); assert_eq!( disassembly, format!( @@ -2331,7 +2331,7 @@ mod tests { let label = asm.forward_declare_label(); asm.push(jcc_label32(cond, label)); asm.push_with_label(label, nop()); - let disassembly = disassemble(asm.finalize()); + let disassembly = disassemble(&asm.finalize()); assert_eq!( disassembly, format!( @@ -2349,7 +2349,7 @@ mod tests { let mut asm = crate::Assembler::new(); let label = asm.forward_declare_label(); asm.push_with_label(label, lea_rip_label(super::Reg::rax, label)); - let disassembly = disassemble(asm.finalize()); + let disassembly = disassemble(&asm.finalize()); assert_eq!(disassembly, "00000000 488d05f9ffffff lea rax, [rip-0x7]"); } @@ -2360,7 +2360,7 @@ mod tests { let label = asm.forward_declare_label(); asm.push(lea_rip_label(super::Reg::rax, label)); asm.push_with_label(label, nop()); - let disassembly = disassemble(asm.finalize()); + let disassembly = disassemble(&asm.finalize()); assert_eq!(disassembly, "00000000 488d0500000000 lea rax, [rip]\n00000007 90 nop"); } } diff --git a/crates/polkavm-assembler/src/assembler.rs b/crates/polkavm-assembler/src/assembler.rs index f4756695..a63d42ef 100644 --- a/crates/polkavm-assembler/src/assembler.rs +++ b/crates/polkavm-assembler/src/assembler.rs @@ -23,6 +23,30 @@ impl Default for Assembler { } } +#[repr(transparent)] +pub struct AssembledCode<'a>(&'a mut Assembler); + +impl<'a> core::ops::Deref for AssembledCode<'a> { + type Target = [u8]; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.0.code + } +} + +impl<'a> From> for Vec { + fn from(code: AssembledCode<'a>) -> Vec { + core::mem::take(&mut code.0.code) + } +} + +impl<'a> Drop for AssembledCode<'a> { + fn 
drop(&mut self) { + self.0.clear(); + } +} + impl Assembler { pub const fn new() -> Self { Assembler { @@ -136,7 +160,7 @@ impl Assembler { self } - pub fn finalize(&mut self) -> &[u8] { + pub fn finalize(&mut self) -> AssembledCode { for fixup in self.fixups.drain(..) { let origin = fixup.instruction_offset + fixup.instruction_length as usize; let target_absolute = self.labels[fixup.target_label.0 as usize]; @@ -157,7 +181,8 @@ impl Assembler { unreachable!() } } - &self.code + + AssembledCode(self) } pub fn is_empty(&self) -> bool { @@ -168,6 +193,10 @@ impl Assembler { self.code.len() } + pub fn code_mut(&mut self) -> &mut [u8] { + &mut self.code + } + pub fn spare_capacity(&self) -> usize { self.code.capacity() - self.code.len() } diff --git a/crates/polkavm-common/src/abi.rs b/crates/polkavm-common/src/abi.rs index 657b296a..755839cb 100644 --- a/crates/polkavm-common/src/abi.rs +++ b/crates/polkavm-common/src/abi.rs @@ -1,40 +1,46 @@ //! Everything in this module affects the ABI of the guest programs, either by affecting //! their observable behavior (no matter how obscure), or changing which programs are accepted by the VM. -use crate::utils::align_to_next_page_u64; +use crate::utils::{align_to_next_page_u32, align_to_next_page_u64}; use core::ops::Range; +const ADDRESS_SPACE_SIZE: u64 = 0x100000000_u64; + /// The page size of the VM. /// /// This is the minimum granularity with which the VM can allocate memory. pub const VM_PAGE_SIZE: u32 = 0x4000; +/// The maximum page size of the VM. +pub const VM_MAX_PAGE_SIZE: u32 = 0x10000; + +static_assert!(VM_PAGE_SIZE <= VM_MAX_PAGE_SIZE); +static_assert!(VM_MAX_PAGE_SIZE % VM_PAGE_SIZE == 0); + /// The address at which the program's memory starts inside of the VM. /// /// This is directly accessible by the program running inside of the VM. -pub const VM_ADDR_USER_MEMORY: u32 = 0x00010000; +pub const VM_ADDR_USER_MEMORY: u32 = VM_MAX_PAGE_SIZE; /// The address at which the program's stack starts inside of the VM. 
/// /// This is directly accessible by the program running inside of the VM. -pub const VM_ADDR_USER_STACK_HIGH: u32 = 0xffffc000; -static_assert!(0xffffffff - VM_PAGE_SIZE as u64 + 1 == VM_ADDR_USER_STACK_HIGH as u64); +pub const VM_ADDR_USER_STACK_HIGH: u32 = (ADDRESS_SPACE_SIZE - VM_MAX_PAGE_SIZE as u64) as u32; /// The address which, when jumped to, will return to the host. /// /// There isn't actually anything there; it's just a virtual address. -pub const VM_ADDR_RETURN_TO_HOST: u32 = 0xffffc000; +pub const VM_ADDR_RETURN_TO_HOST: u32 = 0xffff0000; static_assert!(VM_ADDR_RETURN_TO_HOST & 0b11 == 0); /// The total maximum amount of memory a program can use. /// /// This is the whole 32-bit address space, except: -/// * the guard pages at the start, +/// * the guard page at the start, /// * the guard page between read-only data and read-write data /// * the guard page between the heap and the stack, /// * and the guard page at the end. -pub const VM_MAXIMUM_MEMORY_SIZE: u32 = 0xfffe4000; -static_assert!(VM_MAXIMUM_MEMORY_SIZE as u64 == (1_u64 << 32) - VM_ADDR_USER_MEMORY as u64 - VM_PAGE_SIZE as u64 * 3); +pub const VM_MAXIMUM_MEMORY_SIZE: u32 = (ADDRESS_SPACE_SIZE - VM_MAX_PAGE_SIZE as u64 * 4) as u32; /// The maximum number of VM instructions a program can be composed of. pub const VM_MAXIMUM_INSTRUCTION_COUNT: u32 = 2 * 1024 * 1024; @@ -94,7 +100,7 @@ impl GuestMemoryConfig { // We already checked that these are less than the maximum memory size, so these cannot fail // because the maximum memory size is going to be vastly smaller than what an u64 can hold. 
const _: () = { - assert!(VM_MAXIMUM_MEMORY_SIZE as u64 + VM_PAGE_SIZE as u64 <= u32::MAX as u64); + assert!(VM_MAXIMUM_MEMORY_SIZE as u64 + VM_MAX_PAGE_SIZE as u64 <= u32::MAX as u64); }; let ro_data_size = match align_to_next_page_u64(VM_PAGE_SIZE as u64, ro_data_size) { @@ -204,7 +210,10 @@ impl GuestMemoryConfig { if self.ro_data_size == 0 { self.user_memory_region_address() } else { - self.ro_data_address() + self.ro_data_size + VM_PAGE_SIZE + match align_to_next_page_u32(VM_MAX_PAGE_SIZE, self.ro_data_address() + self.ro_data_size) { + Some(offset) => offset + VM_MAX_PAGE_SIZE, + None => unreachable!(), + } } } diff --git a/crates/polkavm-common/src/program.rs b/crates/polkavm-common/src/program.rs index c22431e9..770ba6c2 100644 --- a/crates/polkavm-common/src/program.rs +++ b/crates/polkavm-common/src/program.rs @@ -107,6 +107,75 @@ impl core::fmt::Display for Reg { } } +#[doc(hidden)] +pub struct VisitorHelper<'a, T> { + pub visitor: T, + reader: Reader<'a>, +} + +impl<'a, T> VisitorHelper<'a, T> { + #[inline] + pub fn run( + blob: &'a ProgramBlob<'a>, + visitor: T, + decode_table: &[fn(&mut Self) -> ::ReturnTy; 256], + ) -> (T, ::ReturnTy) + where + T: ParsingVisitor, + { + let mut state = VisitorHelper { + visitor, + reader: blob.get_section_reader(blob.code.clone()), + }; + + let mut result = Ok(()); + loop { + let Ok(opcode) = state.reader.read_byte() else { break }; + result = state.visitor.on_pre_visit(state.reader.position - 1, opcode); + if result.is_err() { + break; + } + + result = decode_table[opcode as usize](&mut state); + if result.is_err() { + break; + } + + result = state.visitor.on_post_visit(); + if result.is_err() { + break; + } + } + + (state.visitor, result) + } + + #[cold] + pub fn unknown_opcode(&mut self) -> ::ReturnTy + where + T: InstructionVisitor>, + E: From, + { + let error = ProgramParseError::unexpected_instruction(self.reader.origin + self.reader.position - 1); + Err(error.into()) + } + + #[inline(always)] + pub fn 
read_varint(&mut self) -> Result { + self.reader.read_varint() + } + + #[inline(always)] + pub fn read_reg(&mut self) -> Result { + self.reader.read_reg() + } + + #[inline(always)] + pub fn read_regs2(&mut self) -> Result<(Reg, Reg), ProgramParseError> { + self.reader.read_regs2() + } +} + macro_rules! define_opcodes { (@impl_shared $($name:ident = $value:expr,)+) => { #[allow(non_camel_case_types)] @@ -163,78 +232,227 @@ macro_rules! define_opcodes { }; ( + $d:tt [$($name_argless:ident = $value_argless:expr,)+] [$($name_with_imm:ident = $value_with_imm:expr,)+] [$($name_with_regs3:ident = $value_with_regs3:expr,)+] [$($name_with_regs2_imm:ident = $value_with_regs2_imm:expr,)+] ) => { + pub trait ParsingVisitor: InstructionVisitor> /*where E: From,*/ { + fn on_pre_visit(&mut self, _offset: usize, _opcode: u8) -> Self::ReturnTy { + Ok(()) + } + + fn on_post_visit(&mut self) -> Self::ReturnTy { + Ok(()) + } + } + pub trait InstructionVisitor { type ReturnTy; + $(fn $name_argless(&mut self) -> Self::ReturnTy;)+ $(fn $name_with_imm(&mut self, imm: u32) -> Self::ReturnTy;)+ $(fn $name_with_regs3(&mut self, reg1: Reg, reg2: Reg, reg3: Reg) -> Self::ReturnTy;)+ $(fn $name_with_regs2_imm(&mut self, reg1: Reg, reg2: Reg, imm: u32) -> Self::ReturnTy;)+ } - impl RawInstruction { + #[macro_export] + macro_rules! 
implement_instruction_visitor { + (impl<$d($visitor_ty_params:tt),*> $visitor_ty:ty, $method:ident) => { + impl<$d($visitor_ty_params),*> polkavm_common::program::InstructionVisitor for $visitor_ty { + type ReturnTy = (); + + $(fn $name_argless(&mut self) -> Self::ReturnTy { + self.$method(polkavm_common::program::Instruction::$name_argless); + })+ + $(fn $name_with_imm(&mut self, imm: u32) -> Self::ReturnTy { + self.$method(polkavm_common::program::Instruction::$name_with_imm(imm)); + })+ + $(fn $name_with_regs3(&mut self, reg1: Reg, reg2: Reg, reg3: Reg) -> Self::ReturnTy { + self.$method(polkavm_common::program::Instruction::$name_with_regs3(reg1, reg2, reg3)); + })+ + $(fn $name_with_regs2_imm(&mut self, reg1: Reg, reg2: Reg, imm: u32) -> Self::ReturnTy { + self.$method(polkavm_common::program::Instruction::$name_with_regs2_imm(reg1, reg2, imm)); + })+ + } + } + } + + pub use implement_instruction_visitor; + + #[derive(Copy, Clone, PartialEq, Eq, Debug)] + #[allow(non_camel_case_types)] + pub enum Instruction { + $($name_argless,)+ + $($name_with_imm(u32),)+ + $($name_with_regs3(Reg, Reg, Reg),)+ + $($name_with_regs2_imm(Reg, Reg, u32),)+ + } + + impl Instruction { pub fn visit(self, visitor: &mut T) -> T::ReturnTy where T: InstructionVisitor { - // SAFETY: If a given opcode is set then we have a guarantee the arguments are also - // properly set, in which case calling the `*_unchecked` methods is safe. 
- #[allow(unsafe_code)] - unsafe { - match self.op as usize { - $($value_argless => visitor.$name_argless(),)+ - $($value_with_imm => visitor.$name_with_imm(self.imm_or_reg),)+ - $($value_with_regs3 => visitor.$name_with_regs3(self.reg1_unchecked(), self.reg2_unchecked(), self.reg3_unchecked()),)+ - $($value_with_regs2_imm => visitor.$name_with_regs2_imm(self.reg1_unchecked(), self.reg2_unchecked(), self.imm_or_reg),)+ - _ => unreachable!() - } + match self { + $(Self::$name_argless => visitor.$name_argless(),)+ + $(Self::$name_with_imm(imm) => visitor.$name_with_imm(imm),)+ + $(Self::$name_with_regs3(reg1, reg2, reg3) => visitor.$name_with_regs3(reg1, reg2, reg3),)+ + $(Self::$name_with_regs2_imm(reg1, reg2, imm) => visitor.$name_with_regs2_imm(reg1, reg2, imm),)+ + } + } + + pub fn serialize_into(self, buffer: &mut [u8]) -> usize { + match self { + $(Self::$name_argless => Self::serialize_argless(buffer, Opcode::$name_argless),)+ + $(Self::$name_with_imm(imm) => Self::serialize_with_imm(buffer, Opcode::$name_with_imm, imm),)+ + $(Self::$name_with_regs3(reg1, reg2, reg3) => Self::serialize_with_regs3(buffer, Opcode::$name_with_regs3, reg1, reg2, reg3),)+ + $(Self::$name_with_regs2_imm(reg1, reg2, imm) => Self::serialize_with_regs2_imm(buffer, Opcode::$name_with_regs2_imm, reg1, reg2, imm),)+ + } + } + + pub fn opcode(self) -> Opcode { + match self { + $(Self::$name_argless => Opcode::$name_argless,)+ + $(Self::$name_with_imm(..) => Opcode::$name_with_imm,)+ + $(Self::$name_with_regs3(..) => Opcode::$name_with_regs3,)+ + $(Self::$name_with_regs2_imm(..) 
=> Opcode::$name_with_regs2_imm,)+ } } } - impl core::fmt::Display for RawInstruction { + impl core::fmt::Display for Instruction { fn fmt(&self, fmt: &mut core::fmt::Formatter) -> core::fmt::Result { self.visit(fmt) } } + fn parse_instruction_impl(opcode: u8, reader: &mut Reader) -> Result { + Ok(match opcode { + $($value_argless => Instruction::$name_argless,)+ + $($value_with_imm => Instruction::$name_with_imm(reader.read_varint()?),)+ + $($value_with_regs3 => { + let (reg1, reg2) = reader.read_regs2()?; + let reg3 = reader.read_reg()?; + Instruction::$name_with_regs3(reg1, reg2, reg3) + },)+ + $($value_with_regs2_imm => { + let (reg1, reg2) = reader.read_regs2()?; + let imm = reader.read_varint()?; + Instruction::$name_with_regs2_imm(reg1, reg2, imm) + },)+ + _ => return Err(ProgramParseError::unexpected_instruction(reader.origin + reader.position - 1)) + }) + } + pub mod asm { - use super::{RawInstruction, Opcode, Reg}; + use super::{Instruction, Reg}; $( - pub fn $name_argless() -> RawInstruction { - RawInstruction::new_argless(Opcode::$name_argless) + pub fn $name_argless() -> Instruction { + Instruction::$name_argless } )+ $( - pub fn $name_with_imm(imm: u32) -> RawInstruction { - RawInstruction::new_with_imm(Opcode::$name_with_imm, imm) + pub fn $name_with_imm(imm: u32) -> Instruction { + Instruction::$name_with_imm(imm) } )+ $( - pub fn $name_with_regs3(reg1: Reg, reg2: Reg, reg3: Reg) -> RawInstruction { - RawInstruction::new_with_regs3(Opcode::$name_with_regs3, reg1, reg2, reg3) + pub fn $name_with_regs3(reg1: Reg, reg2: Reg, reg3: Reg) -> Instruction { + Instruction::$name_with_regs3(reg1, reg2, reg3) } )+ $( - pub fn $name_with_regs2_imm(reg1: Reg, reg2: Reg, imm: u32) -> RawInstruction { - RawInstruction::new_with_regs2_imm(Opcode::$name_with_regs2_imm, reg1, reg2, imm) + pub fn $name_with_regs2_imm(reg1: Reg, reg2: Reg, imm: u32) -> Instruction { + Instruction::$name_with_regs2_imm(reg1, reg2, imm) } )+ - pub fn ret() -> RawInstruction { + pub fn 
ret() -> Instruction { jump_and_link_register(Reg::Zero, Reg::RA, 0) } - pub fn load_imm(dst: Reg, value: u32) -> RawInstruction { + pub fn load_imm(dst: Reg, value: u32) -> Instruction { add_imm(dst, Reg::Zero, value) } } + #[macro_export] + macro_rules! prepare_visitor { + ($visitor_ty:ident<$d($visitor_ty_params:tt),*>) => {{ + use polkavm_common::program::{ + InstructionVisitor, + VisitorHelper, + }; + + type ReturnTy<$d($visitor_ty_params),*> = <$visitor_ty<$d($visitor_ty_params),*> as InstructionVisitor>::ReturnTy; + type VisitFn<'_code, $d($visitor_ty_params),*> = fn(state: &mut VisitorHelper<'_code, $visitor_ty<$d($visitor_ty_params),*>>) -> ReturnTy<$d($visitor_ty_params),*>; + + static DECODE_TABLE: [VisitFn; 256] = { + let mut table = [VisitorHelper::unknown_opcode as VisitFn; 256]; + $({ + fn $name_argless<'_code, $d($visitor_ty_params),*>(state: &mut VisitorHelper<'_code, $visitor_ty<$d($visitor_ty_params),*>>) -> ReturnTy<$d($visitor_ty_params),*>{ + state.visitor.$name_argless() + } + + table[$value_argless] = $name_argless; + })* + $({ + fn $name_with_imm<'_code, $d($visitor_ty_params),*>(state: &mut VisitorHelper<'_code, $visitor_ty<$d($visitor_ty_params),*>>) -> ReturnTy<$d($visitor_ty_params),*>{ + let imm = state.read_varint()?; + state.visitor.$name_with_imm(imm) + } + + table[$value_with_imm] = $name_with_imm; + })* + + $({ + fn $name_with_regs3<'_code, $d($visitor_ty_params),*>(state: &mut VisitorHelper<'_code, $visitor_ty<$d($visitor_ty_params),*>>) -> ReturnTy<$d($visitor_ty_params),*>{ + let (reg1, reg2) = state.read_regs2()?; + let reg3 = state.read_reg()?; + + state.visitor.$name_with_regs3(reg1, reg2, reg3) + } + + table[$value_with_regs3] = $name_with_regs3; + })* + + $({ + fn $name_with_regs2_imm<'_code, $d($visitor_ty_params),*>(state: &mut VisitorHelper<'_code, $visitor_ty<$d($visitor_ty_params),*>>) -> ReturnTy<$d($visitor_ty_params),*>{ + let (reg1, reg2) = state.read_regs2()?; + let imm = state.read_varint()?; + + 
state.visitor.$name_with_regs2_imm(reg1, reg2, imm) + } + + table[$value_with_regs2_imm] = $name_with_regs2_imm; + })* + + table + }; + + #[inline] + fn run<$d($visitor_ty_params),*>( + blob: &ProgramBlob, + visitor: $visitor_ty<$d($visitor_ty_params),*>, + ) + -> ($visitor_ty<$d($visitor_ty_params),*>, <$visitor_ty<$d($visitor_ty_params),*> as InstructionVisitor>::ReturnTy) + { + let decode_table: &[VisitFn; 256] = &DECODE_TABLE; + // SAFETY: Here we transmute the lifetimes which were unnecessarily extended to be 'static due to the table here being a `static`. + let decode_table: &[VisitFn; 256] = unsafe { core::mem::transmute(decode_table) }; + + VisitorHelper::run(blob, visitor, decode_table) + } + + run + }}; + } + + pub use prepare_visitor; + define_opcodes!( @impl_shared $($name_argless = $value_argless,)+ @@ -246,6 +464,7 @@ macro_rules! define_opcodes { } define_opcodes! { + $ // 1 byte instructions // Instructions with no args. [ @@ -317,21 +536,45 @@ define_opcodes! { ] } -pub const MAX_INSTRUCTION_LENGTH: usize = MAX_VARINT_LENGTH + 2; +impl Instruction { + pub fn deserialize(input: &[u8]) -> Option<(usize, Self)> { + let mut reader = Reader { + blob: input, + origin: 0, + position: 0, + }; -#[derive(Copy, Clone, PartialEq, Eq)] -pub struct RawInstruction { - op: Opcode, - regs: u8, - imm_or_reg: u32, -} + let opcode = reader.read_byte().ok()?; + let instruction = parse_instruction_impl(opcode, &mut reader).ok()?; + Some((reader.position, instruction)) + } -impl core::fmt::Debug for RawInstruction { - fn fmt(&self, fmt: &mut core::fmt::Formatter) -> core::fmt::Result { - write!(fmt, "({:02x} {:02x} {:08x}) {}", self.op as u8, self.regs, self.imm_or_reg, self) + fn serialize_argless(buffer: &mut [u8], opcode: Opcode) -> usize { + buffer[0] = opcode as u8; + 1 + } + + fn serialize_with_imm(buffer: &mut [u8], opcode: Opcode, imm: u32) -> usize { + buffer[0] = opcode as u8; + write_varint(imm, &mut buffer[1..]) + 1 + } + + fn serialize_with_regs3(buffer: 
&mut [u8], opcode: Opcode, reg1: Reg, reg2: Reg, reg3: Reg) -> usize { + buffer[0] = opcode as u8; + buffer[1] = reg1 as u8 | (reg2 as u8) << 4; + buffer[2] = reg3 as u8; + 3 + } + + fn serialize_with_regs2_imm(buffer: &mut [u8], opcode: Opcode, reg1: Reg, reg2: Reg, imm: u32) -> usize { + buffer[0] = opcode as u8; + buffer[1] = reg1 as u8 | (reg2 as u8) << 4; + write_varint(imm, &mut buffer[2..]) + 2 } } +pub const MAX_INSTRUCTION_LENGTH: usize = MAX_VARINT_LENGTH + 2; + impl<'a> InstructionVisitor for core::fmt::Formatter<'a> { type ReturnTy = core::fmt::Result; @@ -607,211 +850,6 @@ impl<'a> InstructionVisitor for core::fmt::Formatter<'a> { } } -impl RawInstruction { - #[inline] - pub fn new_argless(op: Opcode) -> Self { - assert_eq!(op as u8 & 0b11_000000, 0b00_000000); - RawInstruction { - op, - regs: 0, - imm_or_reg: 0, - } - } - - #[inline] - pub fn new_with_imm(op: Opcode, imm: u32) -> Self { - assert_eq!(op as u8 & 0b11_000000, 0b01_000000); - RawInstruction { - op, - regs: 0, - imm_or_reg: imm, - } - } - - #[inline] - pub fn new_with_regs3(op: Opcode, reg1: Reg, reg2: Reg, reg3: Reg) -> Self { - assert_eq!(op as u8 & 0b11_000000, 0b10_000000); - RawInstruction { - op, - regs: reg1 as u8 | (reg2 as u8) << 4, - imm_or_reg: reg3 as u32, - } - } - - #[inline] - pub fn new_with_regs2_imm(op: Opcode, reg1: Reg, reg2: Reg, imm: u32) -> Self { - assert_eq!(op as u8 & 0b11_000000, 0b11_000000); - RawInstruction { - op, - regs: reg1 as u8 | (reg2 as u8) << 4, - imm_or_reg: imm, - } - } - - #[inline] - pub fn op(self) -> Opcode { - self.op - } - - #[inline] - pub fn reg1(self) -> Reg { - Reg::from_u8(self.regs & 0b00001111).unwrap_or_else(|| unreachable!()) - } - - #[inline] - pub fn reg2(self) -> Reg { - Reg::from_u8(self.regs >> 4).unwrap_or_else(|| unreachable!()) - } - - #[inline] - pub fn reg3(self) -> Reg { - Reg::from_u8(self.imm_or_reg as u8).unwrap_or_else(|| unreachable!()) - } - - #[inline] - #[allow(unsafe_code)] - unsafe fn reg1_unchecked(self) -> Reg { 
- core::mem::transmute(self.raw_reg1()) - } - - #[inline] - #[allow(unsafe_code)] - unsafe fn reg2_unchecked(self) -> Reg { - core::mem::transmute(self.raw_reg2()) - } - - #[inline] - #[allow(unsafe_code)] - unsafe fn reg3_unchecked(self) -> Reg { - core::mem::transmute(self.raw_reg3()) - } - - #[inline] - pub fn raw_reg1(self) -> u8 { - self.regs & 0b00001111 - } - - #[inline] - pub fn raw_reg2(self) -> u8 { - self.regs >> 4 - } - - #[inline] - pub fn raw_reg3(self) -> u8 { - self.imm_or_reg as u8 - } - - #[inline] - pub fn raw_imm_or_reg3(self) -> u32 { - self.imm_or_reg - } - - pub fn deserialize(input: &[u8]) -> Option<(usize, Self)> { - let op = Opcode::from_u8(*input.get(0)?)?; - - let mut position = 1; - let mut output = RawInstruction { - op, - regs: 0, - imm_or_reg: 0, - }; - - // Should we load the registers mask? - if op as u8 & 0b10000000 != 0 { - output.regs = *input.get(position)?; - if matches!(output.regs & 0b1111, 14 | 15) || matches!(output.regs >> 4, 14 | 15) { - // Invalid register. - return None; - } - position += 1; - } - - // Is there at least another byte to load? - if op as u8 & 0b11000000 != 0 { - let first_byte = *input.get(position)?; - position += 1; - - if op as u8 & 0b11_000000 == 0b10_000000 { - // It's the third register. - if first_byte > 13 { - // Invalid register. - return None; - } - - output.imm_or_reg = first_byte as u32; - } else { - // It's an immediate. - let (length, imm_or_reg) = read_varint(&input[position..], first_byte)?; - position += length; - output.imm_or_reg = imm_or_reg; - } - } - - Some((position, output)) - } - - #[inline] - pub fn serialize_into(self, buffer: &mut [u8]) -> usize { - assert!(buffer.len() >= MAX_INSTRUCTION_LENGTH); - buffer[0] = self.op as u8; - - let mut length = 1; - if self.op as u8 & 0b10000000 != 0 { - buffer[1] = self.regs; - length += 1; - } - - if self.op as u8 & 0b11000000 != 0 { - length += write_varint(self.imm_or_reg, &mut buffer[length..]); - } - - length - } -} - -macro_rules! 
test_serde { - ($($serialized:expr => $deserialized:expr,)+) => { - #[test] - fn test_deserialize_raw_instruction() { - $( - assert_eq!( - RawInstruction::deserialize(&$serialized).unwrap(), - ($serialized.len(), $deserialized), - "failed to deserialize: {:?}", $serialized - ); - )+ - } - - #[test] - fn test_serialize_raw_instruction() { - $( - { - let mut buffer = [0; MAX_INSTRUCTION_LENGTH]; - let byte_count = $deserialized.serialize_into(&mut buffer); - assert_eq!(byte_count, $serialized.len()); - assert_eq!(&buffer[..byte_count], $serialized); - assert!(buffer[byte_count..].iter().all(|&byte| byte == 0)); - } - )+ - } - }; -} - -test_serde! { - [0b01_111111, 0b01111111] => RawInstruction { op: Opcode::ecalli, regs: 0, imm_or_reg: 0b01111111 }, - [0b01_111111, 0b10111111, 0b00000000] => RawInstruction { op: Opcode::ecalli, regs: 0, imm_or_reg: 0b00111111_00000000 }, - [0b01_111111, 0b10111111, 0b10101010] => RawInstruction { op: Opcode::ecalli, regs: 0, imm_or_reg: 0b00111111_10101010 }, - [0b01_111111, 0b10111111, 0b01010101] => RawInstruction { op: Opcode::ecalli, regs: 0, imm_or_reg: 0b00111111_01010101 }, - [0b01_111111, 0b10000001, 0b11111111] => RawInstruction { op: Opcode::ecalli, regs: 0, imm_or_reg: 0b00000001_11111111 }, - [0b01_111111, 0b11000001, 0b10101010, 0b01010101] => RawInstruction { op: Opcode::ecalli, regs: 0, imm_or_reg: 0b00000001_01010101_10101010 }, - - [0b00_000000] => RawInstruction { op: Opcode::trap, regs: 0, imm_or_reg: 0 }, - - [0b10_000000, 0b00100001, 0b00000100] => RawInstruction { op: Opcode::set_less_than_unsigned, regs: 0b00100001, imm_or_reg: 0b00000100 }, - - [0b11_000000, 0b00100001, 0b10111111, 0b00000000] => RawInstruction { op: Opcode::set_less_than_unsigned_imm, regs: 0b00100001, imm_or_reg: 0b00111111_00000000 }, -} - #[derive(Debug)] pub struct ProgramParseError(ProgramParseErrorKind); @@ -823,6 +861,9 @@ enum ProgramParseErrorKind { FailedToReadStringNonUtf { offset: usize, }, + FailedToReadInstructionArguments { + 
offset: usize, + }, UnexpectedSection { offset: usize, section: u8, @@ -841,6 +882,32 @@ enum ProgramParseErrorKind { Other(&'static str), } +impl ProgramParseError { + #[cold] + fn unexpected_instruction(offset: usize) -> ProgramParseError { + ProgramParseError(ProgramParseErrorKind::UnexpectedInstruction { offset }) + } + + #[cold] + fn failed_to_read_instruction_arguments(offset: usize) -> ProgramParseError { + ProgramParseError(ProgramParseErrorKind::FailedToReadInstructionArguments { offset }) + } + + #[cold] + fn failed_to_read_varint(offset: usize) -> ProgramParseError { + ProgramParseError(ProgramParseErrorKind::FailedToReadVarint { offset }) + } + + #[cold] + fn unexpected_end_of_file(offset: usize, expected_count: usize, actual_count: usize) -> ProgramParseError { + ProgramParseError(ProgramParseErrorKind::UnexpectedEnd { + offset, + expected_count, + actual_count, + }) + } +} + impl core::fmt::Display for ProgramParseError { fn fmt(&self, fmt: &mut core::fmt::Formatter) -> core::fmt::Result { match self.0 { @@ -858,6 +925,13 @@ impl core::fmt::Display for ProgramParseError { offset ) } + ProgramParseErrorKind::FailedToReadInstructionArguments { offset } => { + write!( + fmt, + "failed to parse program blob: failed to parse instruction arguments at offset 0x{:x}", + offset + ) + } ProgramParseErrorKind::UnexpectedSection { offset, section } => { write!( fmt, @@ -1022,11 +1096,15 @@ pub struct ProgramBlob<'a> { debug_strings: Range, debug_line_program_ranges: Range, debug_line_programs: Range, + + instruction_count: u32, + basic_block_count: u32, } #[derive(Clone)] struct Reader<'a> { blob: &'a [u8], + origin: usize, position: usize, } @@ -1035,19 +1113,53 @@ impl<'a> Reader<'a> { self.read_slice_as_range(count).map(|_| ()) } + fn finish(&mut self) { + self.position = self.blob.len(); + } + + #[inline(always)] fn read_byte(&mut self) -> Result { Ok(self.blob[self.read_slice_as_range(1)?][0]) } + #[inline(always)] fn read_varint(&mut self) -> Result { let 
offset = self.position; let first_byte = self.read_byte()?; let (length, value) = read_varint(&self.blob[self.position..], first_byte) - .ok_or(ProgramParseError(ProgramParseErrorKind::FailedToReadVarint { offset }))?; + .ok_or_else(|| ProgramParseError::failed_to_read_varint(self.origin + offset))?; self.position += length; Ok(value) } + #[inline(always)] + fn read_reg(&mut self) -> Result { + let reg = self.read_byte()?; + if let Some(reg) = Reg::from_u8(reg) { + return Ok(reg); + } + + Err(ProgramParseError::failed_to_read_instruction_arguments( + self.origin + self.position - 1, + )) + } + + #[inline(always)] + fn read_regs2(&mut self) -> Result<(Reg, Reg), ProgramParseError> { + let regs = self.read_byte()?; + let reg1 = regs & 0b1111; + let reg2 = regs >> 4; + if let Some(reg1) = Reg::from_u8(reg1) { + if let Some(reg2) = Reg::from_u8(reg2) { + return Ok((reg1, reg2)); + } + } + + Err(ProgramParseError::failed_to_read_instruction_arguments( + self.origin + self.position - 1, + )) + } + fn read_string_with_length(&mut self) -> Result<&'a str, ProgramParseError> { let offset = self.position; let length = self.read_varint()?; @@ -1055,17 +1167,19 @@ impl<'a> Reader<'a> { let slice = &self.blob[range]; core::str::from_utf8(slice) .ok() - .ok_or(ProgramParseError(ProgramParseErrorKind::FailedToReadStringNonUtf { offset })) + .ok_or(ProgramParseError(ProgramParseErrorKind::FailedToReadStringNonUtf { + offset: self.origin + offset, + })) } fn read_slice_as_range(&mut self, count: u32) -> Result, ProgramParseError> { let range = self.position..self.position + count as usize; if self.blob.get(range.clone()).is_none() { - return Err(ProgramParseError(ProgramParseErrorKind::UnexpectedEnd { - offset: self.position, - expected_count: count as usize, - actual_count: self.blob.len() - self.position, - })); + return Err(ProgramParseError::unexpected_end_of_file( + self.origin + self.position, + count as usize, + self.blob.len() - self.position, + )); }; self.position = 
range.end; @@ -1153,6 +1267,7 @@ impl<'a> ProgramBlob<'a> { let mut reader = Reader { blob: &program.blob, + origin: 0, position: BLOB_MAGIC.len(), }; @@ -1182,7 +1297,32 @@ impl<'a> ProgramBlob<'a> { reader.read_section_range_into(&mut section, &mut program.imports, SECTION_IMPORTS)?; reader.read_section_range_into(&mut section, &mut program.exports, SECTION_EXPORTS)?; reader.read_section_range_into(&mut section, &mut program.jump_table, SECTION_JUMP_TABLE)?; - reader.read_section_range_into(&mut section, &mut program.code, SECTION_CODE)?; + + if section == SECTION_CODE { + let section_length = reader.read_varint()?; + let initial_position = reader.position; + let instruction_count = reader.read_varint()?; + let basic_block_count = reader.read_varint()?; + let header_size = (reader.position - initial_position) as u32; + if section_length < header_size { + return Err(ProgramParseError(ProgramParseErrorKind::Other("the code section is too short"))); + } + + let body_length = section_length - header_size; + if instruction_count > body_length { + return Err(ProgramParseError(ProgramParseErrorKind::Other("invalid instruction count"))); + } + + if basic_block_count > body_length { + return Err(ProgramParseError(ProgramParseErrorKind::Other("invalid basic block count"))); + } + + program.instruction_count = instruction_count; + program.basic_block_count = basic_block_count; + program.code = reader.read_slice_as_range(body_length)?; + section = reader.read_byte()?; + } + reader.read_section_range_into(&mut section, &mut program.debug_strings, SECTION_OPT_DEBUG_STRINGS)?; reader.read_section_range_into(&mut section, &mut program.debug_line_programs, SECTION_OPT_DEBUG_LINE_PROGRAMS)?; reader.read_section_range_into( @@ -1235,10 +1375,29 @@ impl<'a> ProgramBlob<'a> { &self.blob[self.code.clone()] } + /// Returns the number of instructions the code section should contain. 
+ /// + /// NOTE: It is safe to preallocate memory based on this value as we make sure + /// that it is no larger than the physical size of the code section, however + /// we do not verify that it is actually true, so it should *not* be blindly trusted! + pub fn instruction_count(&self) -> u32 { + self.instruction_count + } + + /// Returns the number of basic blocks the code section should contain. + /// + /// NOTE: It is safe to preallocate memory based on this value as we make sure + /// that it is no larger than the physical size of the code section, however + /// we do not verify that it is actually true, so it should *not* be blindly trusted! + pub fn basic_block_count(&self) -> u32 { + self.basic_block_count + } + fn get_section_reader(&self, range: Range) -> Reader { Reader { - blob: &self.blob[..range.end], - position: range.start, + blob: &self.blob[range.start..range.end], + origin: range.start, + position: 0, } } @@ -1361,38 +1520,34 @@ impl<'a> ProgramBlob<'a> { } /// Returns an iterator over program instructions. 
- pub fn instructions(&'_ self) -> impl Iterator> + Clone + '_ { + pub fn instructions(&'_ self) -> impl Iterator> + Clone + '_ { #[derive(Clone)] struct CodeIterator<'a> { - code_section_position: usize, - position: usize, - code: &'a [u8], + reader: Reader<'a>, } impl<'a> Iterator for CodeIterator<'a> { - type Item = Result; + type Item = Result; fn next(&mut self) -> Option { - let slice = &self.code[self.position..]; - if slice.is_empty() { + if self.reader.is_eof() { return None; } - if let Some((bytes_consumed, instruction)) = RawInstruction::deserialize(slice) { - self.position += bytes_consumed; - return Some(Ok(instruction)); - } + let result = (|| -> Result { + let opcode = self.reader.read_byte()?; + parse_instruction_impl(opcode, &mut self.reader) + })(); - let offset = self.code_section_position + self.position; - self.position = self.code.len(); + if result.is_err() { + self.reader.finish(); + } - Some(Err(ProgramParseError(ProgramParseErrorKind::UnexpectedInstruction { offset }))) + Some(result) } } CodeIterator { - code_section_position: self.code.start, - position: 0, - code: self.code(), + reader: self.get_section_reader(self.code.clone()), } } @@ -1514,6 +1669,9 @@ impl<'a> ProgramBlob<'a> { debug_strings: self.debug_strings, debug_line_program_ranges: self.debug_line_program_ranges, debug_line_programs: self.debug_line_programs, + + instruction_count: self.instruction_count, + basic_block_count: self.basic_block_count, } } } diff --git a/crates/polkavm-common/src/utils.rs b/crates/polkavm-common/src/utils.rs index 357634bb..5424a794 100644 --- a/crates/polkavm-common/src/utils.rs +++ b/crates/polkavm-common/src/utils.rs @@ -131,6 +131,7 @@ macro_rules! 
define_align_to_next_page { }; } +define_align_to_next_page!(align_to_next_page_u32, u32); define_align_to_next_page!(align_to_next_page_u64, u64); define_align_to_next_page!(align_to_next_page_usize, usize); diff --git a/crates/polkavm-common/src/writer.rs b/crates/polkavm-common/src/writer.rs index 3440519e..9d0da50c 100644 --- a/crates/polkavm-common/src/writer.rs +++ b/crates/polkavm-common/src/writer.rs @@ -1,5 +1,5 @@ use crate::elf::FnMetadata; -use crate::program::{self, RawInstruction}; +use crate::program::{self, Instruction}; use alloc::vec::Vec; use core::ops::Range; @@ -14,6 +14,8 @@ pub struct ProgramBlobBuilder { jump_table: Vec, code: Vec, custom: Vec<(u8, Vec)>, + instruction_count: u32, + basic_block_count: u32, } impl ProgramBlobBuilder { @@ -53,11 +55,30 @@ impl ProgramBlobBuilder { } } - pub fn set_code(&mut self, code: &[RawInstruction]) { + pub fn set_code(&mut self, code: &[Instruction]) { + self.instruction_count = 0; + self.basic_block_count = 0; for instruction in code { let mut buffer = [0; program::MAX_INSTRUCTION_LENGTH]; let length = instruction.serialize_into(&mut buffer); self.code.extend_from_slice(&buffer[..length]); + self.instruction_count += 1; + + use crate::program::Opcode as O; + match instruction.opcode() { + O::trap + | O::fallthrough + | O::jump_and_link_register + | O::branch_less_unsigned + | O::branch_less_signed + | O::branch_greater_or_equal_unsigned + | O::branch_greater_or_equal_signed + | O::branch_eq + | O::branch_not_eq => { + self.basic_block_count += 1; + } + _ => {} + } } } @@ -102,7 +123,11 @@ impl ProgramBlobBuilder { } writer.push_section(program::SECTION_JUMP_TABLE, &self.jump_table); - writer.push_section(program::SECTION_CODE, &self.code); + writer.push_section_inplace(program::SECTION_CODE, |writer| { + writer.push_varint(self.instruction_count); + writer.push_varint(self.basic_block_count); + writer.push_raw_bytes(&self.code); + }); for (section, contents) in self.custom { writer.push_section(section, 
&contents); diff --git a/crates/polkavm-linker/src/program_from_elf.rs b/crates/polkavm-linker/src/program_from_elf.rs index 9dd0ca74..ee531d82 100644 --- a/crates/polkavm-linker/src/program_from_elf.rs +++ b/crates/polkavm-linker/src/program_from_elf.rs @@ -1,7 +1,7 @@ -use polkavm_common::abi::{GuestMemoryConfig, VM_ADDR_USER_MEMORY, VM_CODE_ADDRESS_ALIGNMENT, VM_PAGE_SIZE}; +use polkavm_common::abi::{GuestMemoryConfig, VM_ADDR_USER_MEMORY, VM_CODE_ADDRESS_ALIGNMENT, VM_MAX_PAGE_SIZE, VM_PAGE_SIZE}; use polkavm_common::elf::{FnMetadata, ImportMetadata, INSTRUCTION_ECALLI}; use polkavm_common::program::Reg as PReg; -use polkavm_common::program::{self, FrameKind, LineProgramOp, Opcode, ProgramBlob, RawInstruction}; +use polkavm_common::program::{self, FrameKind, Instruction, LineProgramOp, ProgramBlob}; use polkavm_common::utils::align_to_next_page_u64; use polkavm_common::varint; use polkavm_common::writer::{ProgramBlobBuilder, Writer}; @@ -679,7 +679,7 @@ fn extract_memory_config( } } - assert_eq!(memory_end % VM_PAGE_SIZE as u64, 0); + assert_eq!(memory_end % VM_MAX_PAGE_SIZE as u64, 0); let ro_data_address = memory_end; for §ion_index in sections_ro_data { @@ -724,10 +724,11 @@ fn extract_memory_config( } assert_eq!(memory_end % VM_PAGE_SIZE as u64, 0); + memory_end = align_to_next_page_u64(VM_MAX_PAGE_SIZE as u64, memory_end).unwrap(); if ro_data_size > 0 { // Add a guard page between read-only data and read-write data. 
- memory_end += u64::from(VM_PAGE_SIZE); + memory_end += u64::from(VM_MAX_PAGE_SIZE); } let mut rw_data = Vec::new(); @@ -3675,7 +3676,7 @@ fn emit_code( used_blocks: &[BlockTarget], used_imports: &HashSet, jump_target_for_block: &[Option], -) -> Result, ProgramFromElfError> { +) -> Result, ProgramFromElfError> { let can_fallthrough_to_next_block = calculate_whether_can_fallthrough(all_blocks, used_blocks); let get_data_address = |target: SectionTarget| -> Result { if let Some(base_address) = base_address_for_section.get(&target.section_index) { @@ -3702,7 +3703,7 @@ fn emit_code( }; let mut basic_block_delimited = true; - let mut code: Vec<(SourceStack, RawInstruction)> = Vec::new(); + let mut code: Vec<(SourceStack, Instruction)> = Vec::new(); for block_target in used_blocks { let block = &all_blocks[block_target.index()]; @@ -3714,41 +3715,76 @@ fn emit_code( offset_range: (block.source.offset_range.start..block.source.offset_range.start + 4).into(), } .into(), - RawInstruction::new_argless(Opcode::fallthrough), + Instruction::fallthrough, )); } - fn conv_load_kind(kind: LoadKind) -> Opcode { - match kind { - LoadKind::I8 => Opcode::load_i8, - LoadKind::I16 => Opcode::load_i16, - LoadKind::U32 => Opcode::load_u32, - LoadKind::U8 => Opcode::load_u8, - LoadKind::U16 => Opcode::load_u16, - } - } + macro_rules! codegen { + ( + args = $args:tt, + kind = $kind:expr, - fn conv_store_kind(kind: StoreKind) -> Opcode { - match kind { - StoreKind::U32 => Opcode::store_u32, - StoreKind::U8 => Opcode::store_u8, - StoreKind::U16 => Opcode::store_u16, + { + $($p:pat => $inst:ident,)+ + } + ) => { + match $kind { + $( + $p => Instruction::$inst $args + ),+ + } } } for (source, op) in &block.ops { let op = match *op { BasicInst::LoadAbsolute { kind, dst, target } => { - RawInstruction::new_with_regs2_imm(conv_load_kind(kind), cast_reg(dst), cast_reg(Reg::Zero), get_data_address(target)?) + codegen! 
{ + args = (cast_reg(dst), cast_reg(Reg::Zero), get_data_address(target)?), + kind = kind, + { + LoadKind::I8 => load_i8, + LoadKind::I16 => load_i16, + LoadKind::U32 => load_u32, + LoadKind::U8 => load_u8, + LoadKind::U16 => load_u16, + } + } } BasicInst::StoreAbsolute { kind, src, target } => { - RawInstruction::new_with_regs2_imm(conv_store_kind(kind), cast_reg(src), cast_reg(Reg::Zero), get_data_address(target)?) + codegen! { + args = (cast_reg(src), cast_reg(Reg::Zero), get_data_address(target)?), + kind = kind, + { + StoreKind::U32 => store_u32, + StoreKind::U16 => store_u16, + StoreKind::U8 => store_u8, + } + } } BasicInst::LoadIndirect { kind, dst, base, offset } => { - RawInstruction::new_with_regs2_imm(conv_load_kind(kind), cast_reg(dst), cast_reg(base), offset as u32) + codegen! { + args = (cast_reg(dst), cast_reg(base), offset as u32), + kind = kind, + { + LoadKind::I8 => load_i8, + LoadKind::I16 => load_i16, + LoadKind::U32 => load_u32, + LoadKind::U8 => load_u8, + LoadKind::U16 => load_u16, + } + } } BasicInst::StoreIndirect { kind, src, base, offset } => { - RawInstruction::new_with_regs2_imm(conv_store_kind(kind), cast_reg(src), cast_reg(base), offset as u32) + codegen! 
{ + args = (cast_reg(src), cast_reg(base), offset as u32), + kind = kind, + { + StoreKind::U32 => store_u32, + StoreKind::U16 => store_u16, + StoreKind::U8 => store_u8, + } + } } BasicInst::LoadAddress { dst, target } => { let value = match target { @@ -3762,7 +3798,7 @@ fn emit_code( AnyTarget::Data(target) => get_data_address(target)?, }; - RawInstruction::new_with_regs2_imm(Opcode::add_imm, cast_reg(dst), cast_reg(Reg::Zero), value) + Instruction::add_imm(cast_reg(dst), cast_reg(Reg::Zero), value) } BasicInst::LoadAddressIndirect { dst, target } => { let Some(&offset) = target_to_got_offset.get(&target) else { @@ -3777,61 +3813,62 @@ fn emit_code( }; let value = get_data_address(target)?; - RawInstruction::new_with_regs2_imm(conv_load_kind(LoadKind::U32), cast_reg(dst), cast_reg(Reg::Zero), value) + Instruction::load_u32(cast_reg(dst), cast_reg(Reg::Zero), value) } BasicInst::RegImm { kind, dst, src, imm } => { - let kind = match kind { - RegImmKind::Add => Opcode::add_imm, - RegImmKind::SetLessThanSigned => Opcode::set_less_than_signed_imm, - RegImmKind::SetLessThanUnsigned => Opcode::set_less_than_unsigned_imm, - RegImmKind::Xor => Opcode::xor_imm, - RegImmKind::Or => Opcode::or_imm, - RegImmKind::And => Opcode::and_imm, - RegImmKind::ShiftLogicalLeft => Opcode::shift_logical_left_imm, - RegImmKind::ShiftLogicalRight => Opcode::shift_logical_right_imm, - RegImmKind::ShiftArithmeticRight => Opcode::shift_arithmetic_right_imm, - }; - RawInstruction::new_with_regs2_imm(kind, cast_reg(dst), cast_reg(src), imm as u32) + codegen! 
{ + args = (cast_reg(dst), cast_reg(src), imm as u32), + kind = kind, + { + RegImmKind::Add => add_imm, + RegImmKind::SetLessThanSigned => set_less_than_signed_imm, + RegImmKind::SetLessThanUnsigned => set_less_than_unsigned_imm, + RegImmKind::Xor => xor_imm, + RegImmKind::Or => or_imm, + RegImmKind::And => and_imm, + RegImmKind::ShiftLogicalLeft => shift_logical_left_imm, + RegImmKind::ShiftLogicalRight => shift_logical_right_imm, + RegImmKind::ShiftArithmeticRight => shift_arithmetic_right_imm, + } + } } BasicInst::RegReg { kind, dst, src1, src2 } => { - let kind = match kind { - RegRegKind::Add => Opcode::add, - RegRegKind::Sub => Opcode::sub, - RegRegKind::ShiftLogicalLeft => Opcode::shift_logical_left, - RegRegKind::SetLessThanSigned => Opcode::set_less_than_signed, - RegRegKind::SetLessThanUnsigned => Opcode::set_less_than_unsigned, - RegRegKind::Xor => Opcode::xor, - RegRegKind::ShiftLogicalRight => Opcode::shift_logical_right, - RegRegKind::ShiftArithmeticRight => Opcode::shift_arithmetic_right, - RegRegKind::Or => Opcode::or, - RegRegKind::And => Opcode::and, - RegRegKind::Mul => Opcode::mul, - RegRegKind::MulUpperSignedSigned => Opcode::mul_upper_signed_signed, - RegRegKind::MulUpperUnsignedUnsigned => Opcode::mul_upper_unsigned_unsigned, - RegRegKind::MulUpperSignedUnsigned => Opcode::mul_upper_signed_unsigned, - RegRegKind::Div => Opcode::div_signed, - RegRegKind::DivUnsigned => Opcode::div_unsigned, - RegRegKind::Rem => Opcode::rem_signed, - RegRegKind::RemUnsigned => Opcode::rem_unsigned, - }; - RawInstruction::new_with_regs3(kind, cast_reg(dst), cast_reg(src1), cast_reg(src2)) + codegen! 
{ + args = (cast_reg(dst), cast_reg(src1), cast_reg(src2)), + kind = kind, + { + RegRegKind::Add => add, + RegRegKind::Sub => sub, + RegRegKind::ShiftLogicalLeft => shift_logical_left, + RegRegKind::SetLessThanSigned => set_less_than_signed, + RegRegKind::SetLessThanUnsigned => set_less_than_unsigned, + RegRegKind::Xor => xor, + RegRegKind::ShiftLogicalRight => shift_logical_right, + RegRegKind::ShiftArithmeticRight => shift_arithmetic_right, + RegRegKind::Or => or, + RegRegKind::And => and, + RegRegKind::Mul => mul, + RegRegKind::MulUpperSignedSigned => mul_upper_signed_signed, + RegRegKind::MulUpperUnsignedUnsigned => mul_upper_unsigned_unsigned, + RegRegKind::MulUpperSignedUnsigned => mul_upper_signed_unsigned, + RegRegKind::Div => div_signed, + RegRegKind::DivUnsigned => div_unsigned, + RegRegKind::Rem => rem_signed, + RegRegKind::RemUnsigned => rem_unsigned, + } + } } BasicInst::Ecalli { syscall } => { assert!(used_imports.contains(&syscall)); - RawInstruction::new_with_imm(Opcode::ecalli, syscall) + Instruction::ecalli(syscall) } }; code.push((source.clone(), op)); } - fn unconditional_jump(target: JumpTarget) -> RawInstruction { - RawInstruction::new_with_regs2_imm( - Opcode::jump_and_link_register, - cast_reg(Reg::Zero), - cast_reg(Reg::Zero), - target.static_target, - ) + fn unconditional_jump(target: JumpTarget) -> Instruction { + Instruction::jump_and_link_register(cast_reg(Reg::Zero), cast_reg(Reg::Zero), target.static_target) } match block.next.instruction { @@ -3852,18 +3889,13 @@ fn emit_code( code.push(( block.next.source.clone(), - RawInstruction::new_with_regs2_imm( - Opcode::jump_and_link_register, - cast_reg(ra), - cast_reg(Reg::Zero), - target.static_target, - ), + Instruction::jump_and_link_register(cast_reg(ra), cast_reg(Reg::Zero), target.static_target), )); } ControlInst::JumpIndirect { base, offset } => { code.push(( block.next.source.clone(), - RawInstruction::new_with_regs2_imm(Opcode::jump_and_link_register, cast_reg(Reg::Zero), 
cast_reg(base), offset as u32), + Instruction::jump_and_link_register(cast_reg(Reg::Zero), cast_reg(base), offset as u32), )); } ControlInst::CallIndirect { @@ -3876,7 +3908,7 @@ fn emit_code( get_jump_target(target_return)?; code.push(( block.next.source.clone(), - RawInstruction::new_with_regs2_imm(Opcode::jump_and_link_register, cast_reg(ra), cast_reg(base), offset as u32), + Instruction::jump_and_link_register(cast_reg(ra), cast_reg(base), offset as u32), )); } ControlInst::Branch { @@ -3891,22 +3923,23 @@ fn emit_code( let target_true = get_jump_target(target_true)?; get_jump_target(target_false)?; - let kind = match kind { - BranchKind::Eq => Opcode::branch_eq, - BranchKind::NotEq => Opcode::branch_not_eq, - BranchKind::LessSigned => Opcode::branch_less_signed, - BranchKind::GreaterOrEqualSigned => Opcode::branch_greater_or_equal_signed, - BranchKind::LessUnsigned => Opcode::branch_less_unsigned, - BranchKind::GreaterOrEqualUnsigned => Opcode::branch_greater_or_equal_unsigned, + let instruction = codegen! 
{ + args = (cast_reg(src1), cast_reg(src2), target_true.static_target), + kind = kind, + { + BranchKind::Eq => branch_eq, + BranchKind::NotEq => branch_not_eq, + BranchKind::LessSigned => branch_less_signed, + BranchKind::GreaterOrEqualSigned => branch_greater_or_equal_signed, + BranchKind::LessUnsigned => branch_less_unsigned, + BranchKind::GreaterOrEqualUnsigned => branch_greater_or_equal_unsigned, + } }; - code.push(( - block.next.source.clone(), - RawInstruction::new_with_regs2_imm(kind, cast_reg(src1), cast_reg(src2), target_true.static_target), - )); + code.push((block.next.source.clone(), instruction)); } ControlInst::Unimplemented => { - code.push((block.next.source.clone(), RawInstruction::new_argless(Opcode::trap))); + code.push((block.next.source.clone(), Instruction::trap)); } } } diff --git a/crates/polkavm/src/api.rs b/crates/polkavm/src/api.rs index 290b03f5..1612ff41 100644 --- a/crates/polkavm/src/api.rs +++ b/crates/polkavm/src/api.rs @@ -12,7 +12,7 @@ use polkavm_common::abi::{VM_ADDR_RETURN_TO_HOST, VM_ADDR_USER_STACK_HIGH}; use polkavm_common::error::Trap; use polkavm_common::init::GuestProgramInit; use polkavm_common::program::{ExternFnPrototype, ExternTy, ProgramBlob, ProgramExport, ProgramImport}; -use polkavm_common::program::{FrameKind, Opcode, RawInstruction, Reg}; +use polkavm_common::program::{FrameKind, Instruction, InstructionVisitor, Reg}; use polkavm_common::utils::{Access, AsUninitSliceMut, Gas}; use crate::caller::{Caller, CallerRaw}; @@ -333,13 +333,12 @@ struct ModulePrivate { exports: Vec>, imports: BTreeMap>, export_index_by_name: HashMap, - instructions: Vec, instruction_by_basic_block: Vec, jump_table_index_by_basic_block: HashMap, basic_block_by_jump_table_index: Vec, - blob: Option>, + blob: ProgramBlob<'static>, compiled_module: CompiledModuleKind, interpreted_module: Option, memory_config: GuestMemoryConfig, @@ -350,13 +349,440 @@ struct ModulePrivate { #[derive(Clone)] pub struct Module(Arc); +pub(crate) trait 
BackendModule: Sized { + type BackendVisitor<'a>; + type Aux; + + #[allow(clippy::too_many_arguments)] + fn create_visitor<'a>( + config: &'a ModuleConfig, + exports: &'a [ProgramExport], + basic_block_by_jump_table_index: &'a [u32], + jump_table_index_by_basic_block: &'a HashMap, + init: GuestProgramInit<'a>, + instruction_count: usize, + basic_block_count: usize, + debug_trace_execution: bool, + ) -> Result<(Self::BackendVisitor<'a>, Self::Aux), Error>; + + fn finish_compilation<'a>(wrapper: VisitorWrapper<'a, Self::BackendVisitor<'a>>, aux: Self::Aux) -> Result<(Common<'a>, Self), Error>; +} + +pub(crate) trait BackendVisitor: InstructionVisitor { + fn before_instruction(&mut self); + fn after_instruction(&mut self); +} + +polkavm_common::program::implement_instruction_visitor!(impl<'a> VisitorWrapper<'a, Vec>, push); + +impl<'a> BackendVisitor for VisitorWrapper<'a, Vec> { + fn before_instruction(&mut self) {} + fn after_instruction(&mut self) {} +} + +pub(crate) struct Common<'a> { + pub(crate) code: &'a [u8], + pub(crate) config: &'a ModuleConfig, + pub(crate) imports: &'a BTreeMap>, + pub(crate) jump_table_index_by_basic_block: &'a HashMap, + pub(crate) instruction_by_basic_block: Vec, + pub(crate) gas_cost_for_basic_block: Vec, + pub(crate) maximum_seen_jump_target: u32, + pub(crate) nth_instruction: usize, + pub(crate) instruction_count: usize, + pub(crate) basic_block_count: usize, + pub(crate) block_in_progress: bool, + pub(crate) current_instruction_offset: usize, +} + +impl<'a> Common<'a> { + pub(crate) fn is_last_instruction(&self) -> bool { + self.nth_instruction + 1 == self.instruction_count + } +} + +pub(crate) struct VisitorWrapper<'a, T> { + pub(crate) common: Common<'a>, + pub(crate) visitor: T, +} + +impl<'a, T> core::ops::Deref for VisitorWrapper<'a, T> { + type Target = T; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.visitor + } +} + +impl<'a, T> core::ops::DerefMut for VisitorWrapper<'a, T> { + #[inline] + fn deref_mut(&mut 
self) -> &mut Self::Target { + &mut self.visitor + } +} + +#[repr(transparent)] +pub(crate) struct CommonVisitor<'a, T>(VisitorWrapper<'a, T>); + +impl<'a, T> core::ops::Deref for CommonVisitor<'a, T> { + type Target = Common<'a>; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.0.common + } +} + +impl<'a, T> core::ops::DerefMut for CommonVisitor<'a, T> { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0.common + } +} + +impl<'a, T> CommonVisitor<'a, T> +where + VisitorWrapper<'a, T>: BackendVisitor, +{ + fn start_new_basic_block(&mut self) -> Result<(), Error> { + if !self.is_last_instruction() { + let nth = (self.nth_instruction + 1) as u32; + self.instruction_by_basic_block.push(nth); + } + + if self.instruction_by_basic_block.len() > self.basic_block_count { + bail!("program contains an invalid basic block count"); + } + + self.block_in_progress = false; + Ok(()) + } + + fn branch(&mut self, jump_target: u32, cb: impl FnOnce(&mut VisitorWrapper<'a, T>)) -> Result<(), Error> { + self.maximum_seen_jump_target = core::cmp::max(self.maximum_seen_jump_target, jump_target); + + self.start_new_basic_block()?; + self.0.before_instruction(); + cb(&mut self.0); + Ok(()) + } +} + +impl<'a, T> polkavm_common::program::ParsingVisitor for CommonVisitor<'a, T> +where + VisitorWrapper<'a, T>: BackendVisitor, +{ + #[inline] + fn on_pre_visit(&mut self, offset: usize, _opcode: u8) -> Self::ReturnTy { + if self.config.gas_metering.is_some() { + // TODO: Come up with a better cost model. 
+ *self.gas_cost_for_basic_block.last_mut().unwrap() += 1; + } + + self.current_instruction_offset = offset; + self.block_in_progress = true; + Ok(()) + } + + #[inline] + fn on_post_visit(&mut self) -> Self::ReturnTy { + self.0.after_instruction(); + self.nth_instruction += 1; + Ok(()) + } +} + +impl<'a, T> polkavm_common::program::InstructionVisitor for CommonVisitor<'a, T> +where + VisitorWrapper<'a, T>: BackendVisitor, +{ + type ReturnTy = Result<(), Error>; + + fn trap(&mut self) -> Self::ReturnTy { + self.start_new_basic_block()?; + self.0.before_instruction(); + self.0.trap(); + Ok(()) + } + + fn fallthrough(&mut self) -> Self::ReturnTy { + self.start_new_basic_block()?; + self.0.before_instruction(); + self.0.fallthrough(); + Ok(()) + } + + fn ecalli(&mut self, imm: u32) -> Self::ReturnTy { + if self.imports.get(&imm).is_none() { + bail!("found an unrecognized ecall number: {imm:}"); + } + + self.0.before_instruction(); + self.0.ecalli(imm); + Ok(()) + } + + fn set_less_than_unsigned(&mut self, d: Reg, s1: Reg, s2: Reg) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.set_less_than_unsigned(d, s1, s2); + Ok(()) + } + + fn set_less_than_signed(&mut self, d: Reg, s1: Reg, s2: Reg) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.set_less_than_signed(d, s1, s2); + Ok(()) + } + + fn shift_logical_right(&mut self, d: Reg, s1: Reg, s2: Reg) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.shift_logical_right(d, s1, s2); + Ok(()) + } + + fn shift_arithmetic_right(&mut self, d: Reg, s1: Reg, s2: Reg) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.shift_arithmetic_right(d, s1, s2); + Ok(()) + } + + fn shift_logical_left(&mut self, d: Reg, s1: Reg, s2: Reg) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.shift_logical_left(d, s1, s2); + Ok(()) + } + + fn xor(&mut self, d: Reg, s1: Reg, s2: Reg) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.xor(d, s1, s2); + Ok(()) + } + + fn and(&mut self, d: Reg, 
s1: Reg, s2: Reg) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.and(d, s1, s2); + Ok(()) + } + + fn or(&mut self, d: Reg, s1: Reg, s2: Reg) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.or(d, s1, s2); + Ok(()) + } + + fn add(&mut self, d: Reg, s1: Reg, s2: Reg) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.add(d, s1, s2); + Ok(()) + } + + fn sub(&mut self, d: Reg, s1: Reg, s2: Reg) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.sub(d, s1, s2); + Ok(()) + } + + fn mul(&mut self, d: Reg, s1: Reg, s2: Reg) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.mul(d, s1, s2); + Ok(()) + } + + fn mul_upper_signed_signed(&mut self, d: Reg, s1: Reg, s2: Reg) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.mul_upper_signed_signed(d, s1, s2); + Ok(()) + } + + fn mul_upper_unsigned_unsigned(&mut self, d: Reg, s1: Reg, s2: Reg) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.mul_upper_unsigned_unsigned(d, s1, s2); + Ok(()) + } + + fn mul_upper_signed_unsigned(&mut self, d: Reg, s1: Reg, s2: Reg) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.mul_upper_signed_unsigned(d, s1, s2); + Ok(()) + } + + fn div_unsigned(&mut self, d: Reg, s1: Reg, s2: Reg) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.div_unsigned(d, s1, s2); + Ok(()) + } + + fn div_signed(&mut self, d: Reg, s1: Reg, s2: Reg) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.div_signed(d, s1, s2); + Ok(()) + } + + fn rem_unsigned(&mut self, d: Reg, s1: Reg, s2: Reg) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.rem_unsigned(d, s1, s2); + Ok(()) + } + + fn rem_signed(&mut self, d: Reg, s1: Reg, s2: Reg) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.rem_signed(d, s1, s2); + Ok(()) + } + + fn set_less_than_unsigned_imm(&mut self, d: Reg, s: Reg, imm: u32) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.set_less_than_unsigned_imm(d, s, imm); + Ok(()) + } + + fn 
set_less_than_signed_imm(&mut self, d: Reg, s: Reg, imm: u32) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.set_less_than_signed_imm(d, s, imm); + Ok(()) + } + + fn shift_logical_right_imm(&mut self, d: Reg, s: Reg, imm: u32) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.shift_logical_right_imm(d, s, imm); + Ok(()) + } + + fn shift_arithmetic_right_imm(&mut self, d: Reg, s: Reg, imm: u32) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.shift_arithmetic_right_imm(d, s, imm); + Ok(()) + } + + fn shift_logical_left_imm(&mut self, d: Reg, s: Reg, imm: u32) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.shift_logical_left_imm(d, s, imm); + Ok(()) + } + + fn or_imm(&mut self, d: Reg, s: Reg, imm: u32) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.or_imm(d, s, imm); + Ok(()) + } + + fn and_imm(&mut self, d: Reg, s: Reg, imm: u32) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.and_imm(d, s, imm); + Ok(()) + } + + fn xor_imm(&mut self, d: Reg, s: Reg, imm: u32) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.xor_imm(d, s, imm); + Ok(()) + } + + fn add_imm(&mut self, d: Reg, s: Reg, imm: u32) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.add_imm(d, s, imm); + Ok(()) + } + + fn store_u8(&mut self, src: Reg, base: Reg, offset: u32) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.store_u8(src, base, offset); + Ok(()) + } + + fn store_u16(&mut self, src: Reg, base: Reg, offset: u32) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.store_u16(src, base, offset); + Ok(()) + } + + fn store_u32(&mut self, src: Reg, base: Reg, offset: u32) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.store_u32(src, base, offset); + Ok(()) + } + + fn load_u8(&mut self, dst: Reg, base: Reg, offset: u32) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.load_u8(dst, base, offset); + Ok(()) + } + + fn load_i8(&mut self, dst: Reg, base: Reg, offset: u32) 
-> Self::ReturnTy { + self.0.before_instruction(); + self.0.load_i8(dst, base, offset); + Ok(()) + } + + fn load_u16(&mut self, dst: Reg, base: Reg, offset: u32) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.load_u16(dst, base, offset); + Ok(()) + } + + fn load_i16(&mut self, dst: Reg, base: Reg, offset: u32) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.load_i16(dst, base, offset); + Ok(()) + } + + fn load_u32(&mut self, dst: Reg, base: Reg, offset: u32) -> Self::ReturnTy { + self.0.before_instruction(); + self.0.load_u32(dst, base, offset); + Ok(()) + } + + fn branch_less_unsigned(&mut self, s1: Reg, s2: Reg, imm: u32) -> Self::ReturnTy { + self.branch(imm, move |backend| backend.branch_less_unsigned(s1, s2, imm)) + } + + fn branch_less_signed(&mut self, s1: Reg, s2: Reg, imm: u32) -> Self::ReturnTy { + self.branch(imm, move |backend| backend.branch_less_signed(s1, s2, imm)) + } + + fn branch_greater_or_equal_unsigned(&mut self, s1: Reg, s2: Reg, imm: u32) -> Self::ReturnTy { + self.branch(imm, move |backend| backend.branch_greater_or_equal_unsigned(s1, s2, imm)) + } + + fn branch_greater_or_equal_signed(&mut self, s1: Reg, s2: Reg, imm: u32) -> Self::ReturnTy { + self.branch(imm, move |backend| backend.branch_greater_or_equal_signed(s1, s2, imm)) + } + + fn branch_eq(&mut self, s1: Reg, s2: Reg, imm: u32) -> Self::ReturnTy { + self.branch(imm, move |backend| backend.branch_eq(s1, s2, imm)) + } + + fn branch_not_eq(&mut self, s1: Reg, s2: Reg, imm: u32) -> Self::ReturnTy { + self.branch(imm, move |backend| backend.branch_not_eq(s1, s2, imm)) + } + + fn jump_and_link_register(&mut self, ra: Reg, base: Reg, offset: u32) -> Self::ReturnTy { + if ra != Reg::Zero { + let return_basic_block = self.instruction_by_basic_block.len() as u32; + if !self.jump_table_index_by_basic_block.contains_key(&return_basic_block) { + bail!("found a call instruction where the next basic block is not part of the jump table"); + } + } + + if base == Reg::Zero { 
+ self.maximum_seen_jump_target = core::cmp::max(self.maximum_seen_jump_target, offset); + } + + self.start_new_basic_block()?; + self.0.before_instruction(); + self.0.jump_and_link_register(ra, base, offset); + Ok(()) + } +} + impl Module { pub(crate) fn is_debug_trace_execution_enabled(&self) -> bool { self.0.debug_trace_execution } - pub(crate) fn instructions(&self) -> &[RawInstruction] { - &self.0.instructions + pub(crate) fn instructions(&self) -> &[Instruction] { + &self.interpreted_module().unwrap().instructions } pub(crate) fn compiled_module(&self) -> &CompiledModuleKind { @@ -367,8 +793,8 @@ impl Module { self.0.interpreted_module.as_ref() } - pub(crate) fn blob(&self) -> Option<&ProgramBlob<'static>> { - self.0.blob.as_ref() + pub(crate) fn blob(&self) -> &ProgramBlob<'static> { + &self.0.blob } pub(crate) fn get_export(&self, export_index: usize) -> Option<&ProgramExport> { @@ -409,187 +835,77 @@ impl Module { /// Creates a new module from a deserialized program `blob`. pub fn from_blob(engine: &Engine, config: &ModuleConfig, blob: &ProgramBlob) -> Result { - log::trace!("Parsing imports..."); - let mut imports = BTreeMap::new(); - for import in blob.imports() { - let import = import.map_err(Error::from_display)?; - if import.index() & (1 << 31) != 0 { - bail!("out of range import index"); - } - - if imports.insert(import.index(), import).is_some() { - bail!("duplicate import index"); - } - - if imports.len() > VM_MAXIMUM_IMPORT_COUNT as usize { - bail!( - "too many imports; the program contains more than {} imports", - VM_MAXIMUM_IMPORT_COUNT - ); - } - } - - log::trace!("Parsing jump table..."); - let mut basic_block_by_jump_table_index = Vec::with_capacity(blob.jump_table_upper_bound() + 1); - - // The very first entry is always invalid. 
- basic_block_by_jump_table_index.push(u32::MAX); - - let mut maximum_seen_jump_target = 0; - for nth_basic_block in blob.jump_table() { - let nth_basic_block = nth_basic_block.map_err(Error::from_display)?; - maximum_seen_jump_target = core::cmp::max(maximum_seen_jump_target, nth_basic_block); - basic_block_by_jump_table_index.push(nth_basic_block); - } - - basic_block_by_jump_table_index.shrink_to_fit(); - - let jump_table_index_by_basic_block: HashMap<_, _> = basic_block_by_jump_table_index - .iter() - .copied() - .enumerate() - .map(|(jump_table_index, nth_basic_block)| (nth_basic_block, jump_table_index as u32)) - .collect(); - - let mut gas_cost_for_basic_block: Vec = Vec::new(); - let mut last_instruction = None; - let mut jump_count = 0; - - log::trace!("Parsing code..."); - let (instructions, instruction_by_basic_block) = { - let mut instruction_by_basic_block = Vec::with_capacity(blob.code().len() / 3); - instruction_by_basic_block.push(0); - if config.gas_metering.is_some() { - gas_cost_for_basic_block.push(0); - } - - let mut instructions = Vec::with_capacity(blob.code().len() / 3); - for (nth_instruction, instruction) in blob.instructions().enumerate() { - let nth_instruction = nth_instruction as u32; - let instruction = instruction.map_err(Error::from_display)?; - last_instruction = Some(instruction); + // Do an early check for memory config validity. + GuestMemoryConfig::new( + blob.ro_data().len() as u64, + blob.rw_data().len() as u64, + blob.bss_size() as u64, + blob.stack_size() as u64, + ) + .map_err(Error::from_static_str)?; - if config.gas_metering.is_some() { - // TODO: Come up with a better cost model. 
- *gas_cost_for_basic_block.last_mut().unwrap() += 1; + let imports = { + log::trace!("Parsing imports..."); + let mut imports = BTreeMap::new(); + for import in blob.imports() { + let import = import.map_err(Error::from_display)?; + if import.index() & (1 << 31) != 0 { + bail!("out of range import index"); } - match instruction.op() { - Opcode::fallthrough => { - instruction_by_basic_block.push(nth_instruction + 1); - if config.gas_metering.is_some() { - gas_cost_for_basic_block.push(0); - } - } - Opcode::jump_and_link_register => { - let ra = instruction.reg1(); - if ra != Reg::Zero { - let return_basic_block = instruction_by_basic_block.len() as u32; - if !jump_table_index_by_basic_block.contains_key(&return_basic_block) { - bail!("found a call instruction where the next basic block is not part of the jump table"); - } - } - - let base = instruction.reg2(); - if base == Reg::Zero { - maximum_seen_jump_target = core::cmp::max(maximum_seen_jump_target, instruction.raw_imm_or_reg3()); - } - - instruction_by_basic_block.push(nth_instruction + 1); - if config.gas_metering.is_some() { - gas_cost_for_basic_block.push(0); - } - - jump_count += 1; - } - Opcode::trap => { - instruction_by_basic_block.push(nth_instruction + 1); - if config.gas_metering.is_some() { - gas_cost_for_basic_block.push(0); - } - } - Opcode::branch_less_unsigned - | Opcode::branch_less_signed - | Opcode::branch_greater_or_equal_unsigned - | Opcode::branch_greater_or_equal_signed - | Opcode::branch_eq - | Opcode::branch_not_eq => { - instruction_by_basic_block.push(nth_instruction + 1); - maximum_seen_jump_target = core::cmp::max(maximum_seen_jump_target, instruction.raw_imm_or_reg3()); - - if config.gas_metering.is_some() { - gas_cost_for_basic_block.push(0); - } + if imports.insert(import.index(), import).is_some() { + bail!("duplicate import index"); + } - jump_count += 1; - } - Opcode::ecalli => { - let nr = instruction.raw_imm_or_reg3(); - if imports.get(&nr).is_none() { - bail!("found an 
unrecognized ecall number: {nr:}"); - } - } - _ => {} + if imports.len() > VM_MAXIMUM_IMPORT_COUNT as usize { + bail!( + "too many imports; the program contains more than {} imports", + VM_MAXIMUM_IMPORT_COUNT + ); } - instructions.push(instruction); } - - instruction_by_basic_block.shrink_to_fit(); - gas_cost_for_basic_block.shrink_to_fit(); - (instructions, instruction_by_basic_block) + imports }; - let basic_block_count = instruction_by_basic_block.len(); + let (initial_maximum_seen_jump_target, basic_block_by_jump_table_index, jump_table_index_by_basic_block) = { + log::trace!("Parsing jump table..."); + let mut basic_block_by_jump_table_index = Vec::with_capacity(blob.jump_table_upper_bound() + 1); - { - let Some(last_instruction) = last_instruction else { - bail!("the module contains no code"); - }; + // The very first entry is always invalid. + basic_block_by_jump_table_index.push(u32::MAX); - match last_instruction.op() { - Opcode::fallthrough - | Opcode::jump_and_link_register - | Opcode::trap - | Opcode::branch_less_unsigned - | Opcode::branch_less_signed - | Opcode::branch_greater_or_equal_unsigned - | Opcode::branch_greater_or_equal_signed - | Opcode::branch_eq - | Opcode::branch_not_eq => {} - _ => { - bail!("code doesn't end with a control flow instruction") - } + let mut maximum_seen_jump_target = 0; + for nth_basic_block in blob.jump_table() { + let nth_basic_block = nth_basic_block.map_err(Error::from_display)?; + maximum_seen_jump_target = core::cmp::max(maximum_seen_jump_target, nth_basic_block); + basic_block_by_jump_table_index.push(nth_basic_block); } - } - if instructions.len() > VM_MAXIMUM_INSTRUCTION_COUNT as usize { - bail!( - "too many instructions; the program contains more than {} instructions", - VM_MAXIMUM_INSTRUCTION_COUNT - ); - } + basic_block_by_jump_table_index.shrink_to_fit(); - debug_assert!(!instruction_by_basic_block.is_empty()); - let maximum_valid_jump_target = (instruction_by_basic_block.len() - 1) as u32; - if 
maximum_seen_jump_target > maximum_valid_jump_target { - bail!("out of range jump found; found a jump to @{maximum_seen_jump_target:x}, while the very last valid jump target is @{maximum_valid_jump_target:x}"); - } + let jump_table_index_by_basic_block: HashMap<_, _> = basic_block_by_jump_table_index + .iter() + .copied() + .enumerate() + .map(|(jump_table_index, nth_basic_block)| (nth_basic_block, jump_table_index as u32)) + .collect(); + + ( + maximum_seen_jump_target, + basic_block_by_jump_table_index, + jump_table_index_by_basic_block, + ) + }; - log::trace!("Parsing exports..."); - let exports = { + let (maximum_export_jump_target, exports) = { + log::trace!("Parsing exports..."); + let mut maximum_export_jump_target = 0; let mut exports = Vec::with_capacity(1); for export in blob.exports() { let export = export.map_err(Error::from_display)?; - if export.address() > maximum_valid_jump_target { - bail!( - "out of range export found; export '{}' points to @{:x}, while the very last valid jump target is @{maximum_valid_jump_target:x}", - export.prototype().name(), - export.address() - ); - } + maximum_export_jump_target = core::cmp::max(maximum_export_jump_target, export.address()); exports.push(export); - if exports.len() > VM_MAXIMUM_EXPORT_COUNT as usize { bail!( "too many exports; the program contains more than {} exports", @@ -597,85 +913,196 @@ impl Module { ); } } - exports + (maximum_export_jump_target, exports) }; - log::trace!("Parsing finished!"); - - // Do an early check for memory config validity. - GuestMemoryConfig::new( - blob.ro_data().len() as u64, - blob.rw_data().len() as u64, - blob.bss_size() as u64, - blob.stack_size() as u64, - ) - .map_err(Error::from_static_str)?; - let init = GuestProgramInit::new() .with_ro_data(blob.ro_data()) .with_rw_data(blob.rw_data()) .with_bss(blob.bss_size()) .with_stack(blob.stack_size()); - let compiled_module = if_compiler_is_supported! { + macro_rules! 
new_common { + () => {{ + let mut common = Common { + code: blob.code(), + config, + imports: &imports, + jump_table_index_by_basic_block: &jump_table_index_by_basic_block, + instruction_by_basic_block: Vec::new(), + gas_cost_for_basic_block: Vec::new(), + maximum_seen_jump_target: initial_maximum_seen_jump_target, + nth_instruction: 0, + instruction_count: blob.instruction_count() as usize, + basic_block_count: blob.basic_block_count() as usize, + block_in_progress: false, + current_instruction_offset: 0, + }; + + common.instruction_by_basic_block.reserve(common.basic_block_count + 1); + common.instruction_by_basic_block.push(0); + if config.gas_metering.is_some() { + common.gas_cost_for_basic_block.resize(common.basic_block_count, 0); + } + + common + }}; + } + + #[allow(unused_macros)] + macro_rules! compile_module { + ($sandbox_kind:ident, $module_kind:ident) => {{ + let (visitor, aux) = CompiledModule::<$sandbox_kind>::create_visitor( + config, + &exports, + &basic_block_by_jump_table_index, + &jump_table_index_by_basic_block, + init, + blob.instruction_count() as usize, + blob.basic_block_count() as usize, + engine.debug_trace_execution, + )?; + + let common = new_common!(); + type VisitorTy<'a> = CommonVisitor<'a, crate::compiler::Compiler<'a>>; + let visitor: VisitorTy = CommonVisitor(VisitorWrapper { common, visitor }); + let run = polkavm_common::program::prepare_visitor!(VisitorTy<'a>); + let (visitor, result) = run(blob, visitor); + result?; + + let (common, module) = CompiledModule::<$sandbox_kind>::finish_compilation(visitor.0, aux)?; + Some((common, CompiledModuleKind::$module_kind(module))) + }}; + } + + let compiled: Option<(Common, CompiledModuleKind)> = if_compiler_is_supported! 
{ { match engine.selected_sandbox { + _ if engine.selected_backend != BackendKind::Compiler => None, Some(SandboxKind::Linux) => { #[cfg(target_os = "linux")] { - let module = CompiledModule::new( - config, - &instructions, - &exports, - &basic_block_by_jump_table_index, - &jump_table_index_by_basic_block, - &gas_cost_for_basic_block, - init, - engine.debug_trace_execution, - jump_count, - basic_block_count, - )?; - CompiledModuleKind::Linux(module) + compile_module!(SandboxLinux, Linux) } #[cfg(not(target_os = "linux"))] { log::debug!("Selected sandbox unavailable!"); - CompiledModuleKind::Unavailable + None } }, Some(SandboxKind::Generic) => { - let module = CompiledModule::new( - config, - &instructions, - &exports, - &basic_block_by_jump_table_index, - &jump_table_index_by_basic_block, - &gas_cost_for_basic_block, - init, - engine.debug_trace_execution, - jump_count, - basic_block_count, - )?; - CompiledModuleKind::Generic(module) - } - None => CompiledModuleKind::Unavailable + compile_module!(SandboxGeneric, Generic) + }, + None => None } } else {{ - let _ = jump_count; - let _ = basic_block_count; - CompiledModuleKind::Unavailable + None }} }; - let interpreted_module = if engine.interpreter_enabled { - Some(InterpretedModule::new(init, gas_cost_for_basic_block)?) 
+ let interpreted: Option<(Common, InterpretedModule)> = if engine.interpreter_enabled { + let common = new_common!(); + type VisitorTy<'a> = CommonVisitor<'a, Vec>; + let instructions = Vec::with_capacity(blob.instruction_count() as usize); + let visitor: VisitorTy = CommonVisitor(VisitorWrapper { + common, + visitor: instructions, + }); + let run = polkavm_common::program::prepare_visitor!(VisitorTy<'a>); + let (visitor, result) = run(blob, visitor); + result?; + + let CommonVisitor(VisitorWrapper { + mut common, + visitor: instructions, + }) = visitor; + + let module = InterpretedModule::new(init, core::mem::take(&mut common.gas_cost_for_basic_block), instructions)?; + Some((common, module)) } else { None }; - assert!(compiled_module.is_some() || interpreted_module.is_some()); + let mut common = None; + let compiled_module = if let Some((compiled_common, compiled_module)) = compiled { + common = Some(compiled_common); + compiled_module + } else { + CompiledModuleKind::Unavailable + }; + let interpreted_module = if let Some((interpreted_common, interpreted_module)) = interpreted { + if common.is_none() { + common = Some(interpreted_common); + } + Some(interpreted_module) + } else { + None + }; + + let common = common.unwrap(); + if common.nth_instruction == 0 { + bail!("the module contains no code"); + } + + if common.block_in_progress { + bail!("code doesn't end with a control flow instruction"); + } + + if common.nth_instruction > VM_MAXIMUM_INSTRUCTION_COUNT as usize { + bail!( + "too many instructions; the program contains more than {} instructions", + VM_MAXIMUM_INSTRUCTION_COUNT + ); + } + + if common.nth_instruction != common.instruction_count { + bail!( + "program contains an invalid instruction count (expected {}, found {})", + common.instruction_count, + common.nth_instruction + ); + } + + if common.instruction_by_basic_block.len() != common.basic_block_count { + bail!( + "program contains an invalid basic block count (expected {}, found {})", + 
common.basic_block_count, + common.instruction_by_basic_block.len() + ); + } + + debug_assert!(!common.instruction_by_basic_block.is_empty()); + let maximum_valid_jump_target = (common.instruction_by_basic_block.len() - 1) as u32; + if common.maximum_seen_jump_target > maximum_valid_jump_target { + bail!( + "out of range jump found; found a jump to @{:x}, while the very last valid jump target is @{maximum_valid_jump_target:x}", + common.maximum_seen_jump_target + ); + } + + if maximum_export_jump_target > maximum_valid_jump_target { + let export = exports + .iter() + .find(|export| export.address() == maximum_export_jump_target) + .unwrap(); + bail!( + "out of range export found; export '{}' points to @{:x}, while the very last valid jump target is @{maximum_valid_jump_target:x}", + export.prototype().name(), + export.address(), + ); + } + + let instruction_by_basic_block = { + let mut vec = common.instruction_by_basic_block; + vec.shrink_to_fit(); + vec + }; + + log::trace!("Processing finished!"); + + assert!(compiled_module.is_some() || interpreted_module.is_some()); if compiled_module.is_some() { log::debug!("Backend used: 'compiled'"); } else { @@ -710,7 +1137,6 @@ impl Module { Ok(Module(Arc::new(ModulePrivate { debug_trace_execution: engine.debug_trace_execution, - instructions, exports, imports, export_index_by_name, @@ -719,11 +1145,8 @@ impl Module { jump_table_index_by_basic_block, basic_block_by_jump_table_index, - blob: if engine.debug_trace_execution || engine.selected_backend == BackendKind::Interpreter { - Some(blob.clone().into_owned()) - } else { - None - }, + // TODO: Remove the clone. 
+ blob: blob.clone().into_owned(), compiled_module, interpreted_module, memory_config, @@ -808,11 +1231,7 @@ impl Module { pub(crate) fn debug_print_location(&self, log_level: log::Level, pc: u32) { log::log!(log_level, " At #{pc}:"); - let Some(blob) = self.blob() else { - log::log!(log_level, " (no location available)"); - return; - }; - + let blob = self.blob(); let Ok(Some(mut line_program)) = blob.get_debug_line_program_at(pc) else { log::log!(log_level, " (no location available)"); return; diff --git a/crates/polkavm/src/compiler.rs b/crates/polkavm/src/compiler.rs index e9f6ab45..d6750a05 100644 --- a/crates/polkavm/src/compiler.rs +++ b/crates/polkavm/src/compiler.rs @@ -5,13 +5,13 @@ use std::sync::Arc; use polkavm_assembler::{Assembler, Label}; use polkavm_common::error::{ExecutionError, Trap}; use polkavm_common::init::GuestProgramInit; -use polkavm_common::program::{InstructionVisitor, ProgramExport, RawInstruction}; +use polkavm_common::program::{ProgramExport, Instruction}; use polkavm_common::zygote::{ AddressTable, VM_COMPILER_MAXIMUM_EPILOGUE_LENGTH, VM_COMPILER_MAXIMUM_INSTRUCTION_LENGTH, }; use polkavm_common::abi::VM_CODE_ADDRESS_ALIGNMENT; -use crate::api::{BackendAccess, EngineState, ExecutionConfig, Module, OnHostcall, SandboxExt}; +use crate::api::{BackendAccess, EngineState, ExecutionConfig, Module, OnHostcall, SandboxExt, VisitorWrapper}; use crate::error::{bail, Error}; use crate::sandbox::{Sandbox, SandboxProgram, SandboxProgramInit, ExecuteArgs}; @@ -20,15 +20,14 @@ use crate::config::{GasMeteringKind, ModuleConfig, SandboxKind}; #[cfg(target_arch = "x86_64")] mod amd64; -struct Compiler<'a> { +pub(crate) struct Compiler<'a> { asm: Assembler, exports: &'a [ProgramExport<'a>], - instructions: &'a [RawInstruction], basic_block_by_jump_table_index: &'a [u32], jump_table_index_by_basic_block: &'a HashMap, - gas_cost_for_basic_block: &'a [u32], nth_basic_block_to_label: Vec