From 3e61c95c7b455bbff0ff485c43f7aa48a90960fa Mon Sep 17 00:00:00 2001 From: Jarkko Sakkinen Date: Mon, 11 Nov 2024 07:59:55 +0200 Subject: [PATCH] Protect VmCtx and ID_COUNTER with RwLock AtomicU64 causes compilation issues when used as a subcrate of another crate when the toolchain configurations conflict. The problem was observed with ID_COUNTER when Polkadot SDK was compiled for riscv32emac, which could not be found from core. The documentation also states that: "This type is only available on platforms that support atomic loads and stores of u64." Address this by protecting ID_COUNTER with `spin::RwLock` and `VmCtx` with `std::sync::RwLock`. Signed-off-by: Jarkko Sakkinen --- Cargo.lock | 26 ++ Cargo.toml | 1 + crates/polkavm-common/Cargo.toml | 1 + crates/polkavm-common/src/program.rs | 13 +- crates/polkavm-common/src/zygote.rs | 144 ++++++------ crates/polkavm-zygote/Cargo.lock | 33 +++ crates/polkavm-zygote/Cargo.toml | 1 + crates/polkavm-zygote/src/main.rs | 144 +++++------- crates/polkavm/src/sandbox/linux.rs | 339 +++++++++++++++------------ 9 files changed, 394 insertions(+), 308 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 53bedfee..6ff8e273 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -458,6 +458,16 @@ version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + [[package]] name = "log" version = "0.4.20" @@ -584,6 +594,7 @@ dependencies = [ "log", "polkavm-assembler", "proptest", + "spin", ] [[package]] @@ -868,6 +879,12 @@ dependencies = [ "hashbrown 0.13.2", ] +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + [[package]] name = "sdl2" version = "0.35.2" @@ -956,6 +973,15 @@ dependencies = [ "serde_json", ] +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" diff --git a/Cargo.toml b/Cargo.toml index 01ab5657..03b45ac0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,6 +60,7 @@ ruzstd = { version = "0.4.0", default-features = false } schnellru = { version = "0.2.3" } serde = { version = "1.0.203", features = ["derive"] } serde_json = { version = "1.0.117" } +spin = { version = "0.9.8", default-features = false, features = ["lock_api", "spin_mutex", "rwlock", "lazy"] } syn = "2.0.25" yansi = "0.5.1" diff --git a/crates/polkavm-common/Cargo.toml b/crates/polkavm-common/Cargo.toml index 612afc0c..2efd5f8e 100644 --- a/crates/polkavm-common/Cargo.toml +++ b/crates/polkavm-common/Cargo.toml @@ -12,6 +12,7 @@ description = "The common crate for PolkaVM" log = { workspace = true, optional = true } polkavm-assembler = { workspace = true, optional = true } blake3 = { workspace = true, optional = true } +spin = { workspace = true } [features] default = [] diff --git a/crates/polkavm-common/src/program.rs b/crates/polkavm-common/src/program.rs index d10f2fc2..a179a4a7 100644 --- a/crates/polkavm-common/src/program.rs +++ b/crates/polkavm-common/src/program.rs @@ -4,6 +4,13 @@ use crate::varint::{read_simple_varint, read_varint, write_simple_varint, write_ use core::fmt::Write; use core::ops::Range; +#[cfg(feature = "unique-id")] +use spin::RwLock; +#[cfg(feature = "unique-id")] +struct UniqueId(u64); +#[cfg(feature = "unique-id")] +static ID_COUNTER: RwLock = RwLock::new(UniqueId(0)); + #[derive(Copy, Clone)] #[repr(transparent)] pub struct RawReg(u32); @@ -3980,8 +3987,10 @@ impl 
ProgramBlob { #[cfg(feature = "unique-id")] { - static ID_COUNTER: core::sync::atomic::AtomicU64 = core::sync::atomic::AtomicU64::new(0); - blob.unique_id = ID_COUNTER.fetch_add(1, core::sync::atomic::Ordering::Relaxed); + let mut counter = ID_COUNTER.write(); + *counter = UniqueId(counter.0 + 1); + blob.unique_id = counter.0; + // The lock is dropped here. } Ok(blob) diff --git a/crates/polkavm-common/src/zygote.rs b/crates/polkavm-common/src/zygote.rs index 28b51c0a..edf58e7e 100644 --- a/crates/polkavm-common/src/zygote.rs +++ b/crates/polkavm-common/src/zygote.rs @@ -3,8 +3,7 @@ //! In general everything here can be modified at will, provided the zygote //! is recompiled. -use core::cell::UnsafeCell; -use core::sync::atomic::{AtomicBool, AtomicI64, AtomicU32, AtomicU64}; +use core::sync::atomic::{AtomicBool, AtomicU32}; // Due to the limitations of Rust's compile time constant evaluation machinery // we need to define this struct multiple times. @@ -137,24 +136,24 @@ pub const VM_SANDBOX_MAXIMUM_NATIVE_CODE_SIZE: u32 = 2176 * 1024 * 1024 - 1; #[repr(C)] pub struct JmpBuf { - pub rip: AtomicU64, - pub rbx: AtomicU64, - pub rsp: AtomicU64, - pub rbp: AtomicU64, - pub r12: AtomicU64, - pub r13: AtomicU64, - pub r14: AtomicU64, - pub r15: AtomicU64, + pub rip: u64, + pub rbx: u64, + pub rsp: u64, + pub rbp: u64, + pub r12: u64, + pub r13: u64, + pub r14: u64, + pub r15: u64, } #[repr(C)] pub struct VmInit { - pub stack_address: AtomicU64, - pub stack_length: AtomicU64, - pub vdso_address: AtomicU64, - pub vdso_length: AtomicU64, - pub vvar_address: AtomicU64, - pub vvar_length: AtomicU64, + pub stack_address: u64, + pub stack_length: u64, + pub vdso_address: u64, + pub vdso_length: u64, + pub vvar_address: u64, + pub vvar_length: u64, /// Whether userfaultfd-based memory management is available.
pub uffd_available: AtomicBool, @@ -190,16 +189,16 @@ impl core::ops::DerefMut for CacheAligned { #[repr(C)] pub struct VmCtxHeapInfo { - pub heap_top: UnsafeCell, - pub heap_threshold: UnsafeCell, + pub heap_top: u64, + pub heap_threshold: u64, } const REG_COUNT: usize = crate::program::Reg::ALL.len(); #[repr(C)] pub struct VmCtxCounters { - pub syscall_wait_loop_start: UnsafeCell, - pub syscall_futex_wait: UnsafeCell, + pub syscall_wait_loop_start: u64, + pub syscall_futex_wait: u64, } #[repr(C)] @@ -230,7 +229,7 @@ pub struct VmCtx { _align_1: CacheAligned<()>, /// The current gas counter. - pub gas: AtomicI64, + pub gas: i64, _align_2: CacheAligned<()>, @@ -238,7 +237,7 @@ pub struct VmCtx { pub futex: AtomicU32, /// Address to which to jump to. - pub jump_into: AtomicU64, + pub jump_into: u64, /// The address of the instruction currently being executed. pub program_counter: AtomicU32, @@ -253,10 +252,10 @@ pub struct VmCtx { pub arg: AtomicU32, /// A dump of all of the registers of the VM. - pub regs: [AtomicU64; REG_COUNT], + pub regs: [u64; REG_COUNT], /// The address of the native code to call inside of the VM process, if non-zero. - pub next_native_program_counter: AtomicU64, + pub next_native_program_counter: u64, /// The state of the program's heap. pub heap_info: VmCtxHeapInfo, @@ -264,35 +263,35 @@ pub struct VmCtx { pub arg2: AtomicU32, /// Offset in shared memory to this sandbox's memory map. - pub shm_memory_map_offset: AtomicU64, + pub shm_memory_map_offset: u64, /// Number of maps to map. - pub shm_memory_map_count: AtomicU64, + pub shm_memory_map_count: u64, /// Offset in shared memory to this sandbox's code. - pub shm_code_offset: AtomicU64, + pub shm_code_offset: u64, /// Length this sandbox's code. - pub shm_code_length: AtomicU64, + pub shm_code_length: u64, /// Offset in shared memory to this sandbox's jump table. - pub shm_jump_table_offset: AtomicU64, + pub shm_jump_table_offset: u64, /// Length of sandbox's jump table, in bytes. 
- pub shm_jump_table_length: AtomicU64, + pub shm_jump_table_length: u64, /// Address of the sysreturn routine. - pub sysreturn_address: AtomicU64, + pub sysreturn_address: u64, /// Whether userfaultfd-based memory management is enabled. pub uffd_enabled: AtomicBool, /// Address to the base of the heap. - pub heap_base: UnsafeCell, + pub heap_base: u32, /// The initial heap growth threshold. - pub heap_initial_threshold: UnsafeCell, + pub heap_initial_threshold: u32, /// The maximum heap size. - pub heap_max_size: UnsafeCell, + pub heap_max_size: u32, /// The page size. - pub page_size: UnsafeCell, + pub page_size: u32, /// Performance counters. Only for debugging. pub counters: CacheAligned, @@ -301,9 +300,10 @@ pub struct VmCtx { pub init: VmInit, /// Length of the message in the message buffer. - pub message_length: UnsafeCell, + pub message_length: u32, + /// A buffer used to marshal error messages. - pub message_buffer: UnsafeCell<[u8; MESSAGE_BUFFER_SIZE]>, + pub message_buffer: [u8; MESSAGE_BUFFER_SIZE], } // Make sure it fits within a single page on amd64. 
@@ -328,7 +328,7 @@ pub const VMCTX_FUTEX_GUEST_SIGNAL: u32 = VMCTX_FUTEX_IDLE | (3 << 1); pub const VMCTX_FUTEX_GUEST_STEP: u32 = VMCTX_FUTEX_IDLE | (4 << 1); #[allow(clippy::declare_interior_mutable_const)] -const ATOMIC_U64_ZERO: AtomicU64 = AtomicU64::new(0); +const ATOMIC_U64_ZERO: u64 = 0; #[allow(clippy::new_without_default)] impl VmCtx { @@ -338,64 +338,64 @@ impl VmCtx { _align_1: CacheAligned(()), _align_2: CacheAligned(()), - gas: AtomicI64::new(0), + gas: 0, program_counter: AtomicU32::new(0), next_program_counter: AtomicU32::new(0), arg: AtomicU32::new(0), arg2: AtomicU32::new(0), regs: [ATOMIC_U64_ZERO; REG_COUNT], - jump_into: AtomicU64::new(0), - next_native_program_counter: AtomicU64::new(0), + jump_into: 0, + next_native_program_counter: 0, futex: AtomicU32::new(VMCTX_FUTEX_BUSY), - shm_memory_map_offset: AtomicU64::new(0), - shm_memory_map_count: AtomicU64::new(0), - shm_code_offset: AtomicU64::new(0), - shm_code_length: AtomicU64::new(0), - shm_jump_table_offset: AtomicU64::new(0), - shm_jump_table_length: AtomicU64::new(0), + shm_memory_map_offset: 0, + shm_memory_map_count: 0, + shm_code_offset: 0, + shm_code_length: 0, + shm_jump_table_offset: 0, + shm_jump_table_length: 0, uffd_enabled: AtomicBool::new(false), - sysreturn_address: AtomicU64::new(0), - heap_base: UnsafeCell::new(0), - heap_initial_threshold: UnsafeCell::new(0), - heap_max_size: UnsafeCell::new(0), - page_size: UnsafeCell::new(0), + sysreturn_address: 0, + heap_base: 0, + heap_initial_threshold: 0, + heap_max_size: 0, + page_size: 0, heap_info: VmCtxHeapInfo { - heap_top: UnsafeCell::new(0), - heap_threshold: UnsafeCell::new(0), + heap_top: 0, + heap_threshold: 0, }, counters: CacheAligned(VmCtxCounters { - syscall_wait_loop_start: UnsafeCell::new(0), - syscall_futex_wait: UnsafeCell::new(0), + syscall_wait_loop_start: 0, + syscall_futex_wait: 0, }), init: VmInit { - stack_address: AtomicU64::new(0), - stack_length: AtomicU64::new(0), - vdso_address: AtomicU64::new(0), - 
vdso_length: AtomicU64::new(0), - vvar_address: AtomicU64::new(0), - vvar_length: AtomicU64::new(0), + stack_address: 0, + stack_length: 0, + vdso_address: 0, + vdso_length: 0, + vvar_address: 0, + vvar_length: 0, uffd_available: AtomicBool::new(false), sandbox_disabled: AtomicBool::new(false), logging_enabled: AtomicBool::new(false), idle_regs: JmpBuf { - rip: AtomicU64::new(0), - rbx: AtomicU64::new(0), - rsp: AtomicU64::new(0), - rbp: AtomicU64::new(0), - r12: AtomicU64::new(0), - r13: AtomicU64::new(0), - r14: AtomicU64::new(0), - r15: AtomicU64::new(0), + rip: 0, + rbx: 0, + rsp: 0, + rbp: 0, + r12: 0, + r13: 0, + r14: 0, + r15: 0, }, }, - message_length: UnsafeCell::new(0), - message_buffer: UnsafeCell::new([0; MESSAGE_BUFFER_SIZE]), + message_length: 0, + message_buffer: [0; MESSAGE_BUFFER_SIZE], } } diff --git a/crates/polkavm-zygote/Cargo.lock b/crates/polkavm-zygote/Cargo.lock index e37eb010..e017df55 100644 --- a/crates/polkavm-zygote/Cargo.lock +++ b/crates/polkavm-zygote/Cargo.lock @@ -2,6 +2,22 @@ # It is not intended for manual editing. 
version = 3 +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + [[package]] name = "log" version = "0.4.22" @@ -20,6 +36,7 @@ name = "polkavm-common" version = "0.15.0" dependencies = [ "polkavm-assembler", + "spin", ] [[package]] @@ -32,4 +49,20 @@ version = "0.1.0" dependencies = [ "polkavm-common", "polkavm-linux-raw", + "spin", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", ] diff --git a/crates/polkavm-zygote/Cargo.toml b/crates/polkavm-zygote/Cargo.toml index 38c4e5d5..b4faefe3 100644 --- a/crates/polkavm-zygote/Cargo.toml +++ b/crates/polkavm-zygote/Cargo.toml @@ -9,6 +9,7 @@ publish = false [dependencies] polkavm-linux-raw = { path = "../polkavm-linux-raw" } polkavm-common = { path = "../polkavm-common", features = ["regmap"] } +spin = { version = "0.9.8", default-features = false, features = ["lock_api", "spin_mutex", "rwlock", "lazy"] } [build-dependencies] polkavm-common = { path = "../polkavm-common" } diff --git a/crates/polkavm-zygote/src/main.rs b/crates/polkavm-zygote/src/main.rs index cfc9e647..0336a8b9 100644 --- a/crates/polkavm-zygote/src/main.rs +++ b/crates/polkavm-zygote/src/main.rs @@ -4,14 +4,15 @@ use core::ptr::addr_of_mut; use core::sync::atomic::Ordering; -use 
core::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize}; +use core::sync::atomic::{AtomicBool, AtomicUsize}; +use spin::{Lazy, RwLock}; #[rustfmt::skip] use polkavm_common::{ utils::align_to_next_page_usize, zygote::{ self, - AddressTableRaw, ExtTableRaw, VmCtx as VmCtxInner, + AddressTableRaw, ExtTableRaw, VmCtx, VmMap, VmFd, JmpBuf, VM_ADDR_JUMP_TABLE_RETURN_TO_HOST, VM_ADDR_JUMP_TABLE, @@ -144,25 +145,11 @@ macro_rules! trace { }}; } -#[repr(transparent)] -pub struct VmCtx(VmCtxInner); - -unsafe impl Sync for VmCtx {} - -impl core::ops::Deref for VmCtx { - type Target = VmCtxInner; - - #[inline(always)] - fn deref(&self) -> &Self::Target { - &self.0 - } -} - #[no_mangle] #[link_section = ".vmctx"] #[used] // Use the `zeroed` constructor to make sure this doesn't take up any space in the executable. -pub static VMCTX: VmCtx = VmCtx(VmCtxInner::zeroed()); +static VMCTX: Lazy> = Lazy::new(|| RwLock::new(VmCtx::zeroed())); #[panic_handler] fn panic(_: &core::panic::PanicInfo) -> ! { @@ -192,9 +179,7 @@ unsafe fn memcpy(dst: *mut u8, src: *const u8, size: usize) -> *mut u8 { } fn reset_message() { - unsafe { - *VMCTX.message_length.get() = 0; - } + VMCTX.write().message_length = 0; } #[inline] @@ -202,8 +187,8 @@ fn append_to_message<'a, 'b>(mut input: &[u8]) where 'a: 'b, { - let message_length = unsafe { &mut *VMCTX.message_length.get() }; - let message_buffer = &mut unsafe { &mut *VMCTX.message_buffer.get() }[..]; + let message_length = &mut VMCTX.write().message_length; + let message_buffer = &mut VMCTX.write().message_buffer[..]; while !input.is_empty() && (*message_length as usize) < message_buffer.len() { message_buffer[*message_length as usize] = input[0]; @@ -319,7 +304,7 @@ unsafe extern "C" fn signal_handler(signal: u32, _info: &linux_raw::siginfo_t, c Hex(context.uc_mcontext.r15) ); - if rip < VM_ADDR_NATIVE_CODE || rip > VM_ADDR_NATIVE_CODE + VMCTX.shm_code_length.load(Ordering::Relaxed) { + if rip < VM_ADDR_NATIVE_CODE || rip > VM_ADDR_NATIVE_CODE + 
VMCTX.write().shm_code_length { abort_with_message("segmentation fault") } @@ -344,23 +329,23 @@ unsafe extern "C" fn signal_handler(signal: u32, _info: &linux_raw::siginfo_t, c r15 => context.uc_mcontext.r15, }; - VMCTX.regs[reg as usize].store(value, Ordering::Relaxed); + VMCTX.write().regs[reg as usize] = value; } - VMCTX.next_native_program_counter.store(rip, Ordering::Relaxed); + VMCTX.write().next_native_program_counter = rip; signal_host_and_longjmp(VMCTX_FUTEX_GUEST_SIGNAL); } static mut RESUME_MAIN_LOOP_JMPBUF: JmpBuf = JmpBuf { - rip: AtomicU64::new(0), - rbx: AtomicU64::new(0), - rsp: AtomicU64::new(0), - rbp: AtomicU64::new(0), - r12: AtomicU64::new(0), - r13: AtomicU64::new(0), - r14: AtomicU64::new(0), - r15: AtomicU64::new(0), + rip: 0, + rbx: 0, + rsp: 0, + rbp: 0, + r12: 0, + r13: 0, + r14: 0, + r15: 0, }; extern "C" { @@ -446,7 +431,7 @@ unsafe fn initialize(mut stack: *mut usize) { let vmctx_memfd = linux_raw::Fd::from_raw_unchecked(zygote::FD_VMCTX); linux_raw::sys_mmap( - &VMCTX as *const VmCtx as *mut core::ffi::c_void, + &*VMCTX.write() as *const VmCtx as *mut core::ffi::c_void, page_size, linux_raw::PROT_READ | linux_raw::PROT_WRITE, linux_raw::MAP_FIXED | linux_raw::MAP_SHARED, @@ -484,29 +469,29 @@ unsafe fn initialize(mut stack: *mut usize) { .unwrap_or_else(|error| abort_with_error("failed to mmap shared memory", error)); // Wait for the host to fill out vmctx. - VMCTX.futex.store(VMCTX_FUTEX_IDLE, Ordering::Release); + VMCTX.write().futex.store(VMCTX_FUTEX_IDLE, Ordering::Release); futex_wait_until(VMCTX_FUTEX_BUSY); // Unmap the original stack. 
linux_raw::sys_munmap( - VMCTX.init.stack_address.load(Ordering::Relaxed) as *mut core::ffi::c_void, - VMCTX.init.stack_length.load(Ordering::Relaxed) as usize, + VMCTX.write().init.stack_address as *mut core::ffi::c_void, + VMCTX.write().init.stack_length as usize, ) .unwrap_or_else(|error| abort_with_error("failed to unmap kernel-provided stack", error)); // We don't need the VDSO, so just unmap it. - if VMCTX.init.vdso_length.load(Ordering::Relaxed) != 0 { + if VMCTX.write().init.vdso_length != 0 { linux_raw::sys_munmap( - VMCTX.init.vdso_address.load(Ordering::Relaxed) as *mut core::ffi::c_void, - VMCTX.init.vdso_length.load(Ordering::Relaxed) as usize, + VMCTX.write().init.vdso_address as *mut core::ffi::c_void, + VMCTX.write().init.vdso_length as usize, ) .unwrap_or_else(|error| abort_with_error("failed to unmap [vdso]", error)); } - if VMCTX.init.vvar_length.load(Ordering::Relaxed) != 0 { + if VMCTX.write().init.vvar_length != 0 { linux_raw::sys_munmap( - VMCTX.init.vvar_address.load(Ordering::Relaxed) as *mut core::ffi::c_void, - VMCTX.init.vvar_length.load(Ordering::Relaxed) as usize, + VMCTX.write().init.vvar_address as *mut core::ffi::c_void, + VMCTX.write().init.vvar_length as usize, ) .unwrap_or_else(|error| abort_with_error("failed to unmap [vvar]", error)); } @@ -523,7 +508,7 @@ unsafe fn initialize(mut stack: *mut usize) { ) .unwrap_or_else(|error| abort_with_error("failed to make sure the jump table address space is unmapped", error)); - if VMCTX.init.uffd_available.load(Ordering::Relaxed) { + if VMCTX.write().init.uffd_available.load(Ordering::Relaxed) { // Set up and send the userfaultfd to the host. 
let userfaultfd = linux_raw::sys_userfaultfd(linux_raw::O_CLOEXEC) .unwrap_or_else(|error| abort_with_error("failed to create an userfaultfd", error)); @@ -624,7 +609,7 @@ unsafe fn initialize(mut stack: *mut usize) { .close() .unwrap_or_else(|error| abort_with_error("failed to close dummy stdin", error)); - if !VMCTX.init.logging_enabled.load(Ordering::Relaxed) { + if !VMCTX.write().init.logging_enabled.load(Ordering::Relaxed) { linux_raw::Fd::from_raw_unchecked(zygote::FD_LOGGER_STDOUT) .close() .unwrap_or_else(|error| abort_with_error("failed to close stdout logger", error)); @@ -634,7 +619,7 @@ unsafe fn initialize(mut stack: *mut usize) { .unwrap_or_else(|error| abort_with_error("failed to close stdin logger", error)); } - if !VMCTX.init.sandbox_disabled.load(Ordering::Relaxed) { + if !VMCTX.write().init.sandbox_disabled.load(Ordering::Relaxed) { linux_raw::sys_setrlimit(linux_raw::RLIMIT_NOFILE, &linux_raw::rlimit { rlim_cur: 0, rlim_max: 0 }) .unwrap_or_else(|error| abort_with_error("failed to set RLIMIT_NOFILE", error)); @@ -713,14 +698,14 @@ unsafe fn initialize(mut stack: *mut usize) { .unwrap_or_else(|error| abort_with_error("failed to set seccomp filter", error)); } - VMCTX.futex.store(VMCTX_FUTEX_IDLE, Ordering::Release); - linux_raw::sys_futex_wake_one(&VMCTX.futex) + VMCTX.write().futex.store(VMCTX_FUTEX_IDLE, Ordering::Release); + linux_raw::sys_futex_wake_one(&VMCTX.write().futex) .unwrap_or_else(|error| abort_with_error("failed to wake up the host process on initialization", error)); } #[inline] fn futex_wait_until(target_state: u32) { - let mut state = VMCTX.futex.load(Ordering::Relaxed); + let mut state = VMCTX.write().futex.load(Ordering::Relaxed); 'main_loop: loop { if state == target_state { break; @@ -730,13 +715,13 @@ fn futex_wait_until(target_state: u32) { for _ in 0..core::hint::black_box(20) { let _ = linux_raw::sys_sched_yield(); - state = VMCTX.futex.load(Ordering::Relaxed); + state = VMCTX.write().futex.load(Ordering::Relaxed); if 
state == target_state { break 'main_loop; } } - match linux_raw::sys_futex_wait(&VMCTX.futex, state, None) { + match linux_raw::sys_futex_wait(&VMCTX.write().futex, state, None) { Ok(()) => continue, Err(error) if error.errno() == linux_raw::EAGAIN || error.errno() == linux_raw::EINTR => continue, Err(error) => { @@ -756,7 +741,7 @@ unsafe fn main_loop() -> ! { futex_wait_until(VMCTX_FUTEX_BUSY); - let address = VMCTX.jump_into.load(Ordering::Relaxed); + let address = VMCTX.write().jump_into; trace!("Jumping into: ", Hex(address as usize)); let callback: extern "C" fn() -> ! = core::mem::transmute(address); @@ -766,9 +751,9 @@ unsafe fn main_loop() -> ! { pub unsafe extern "C" fn ext_sbrk() -> ! { trace!("Entry point: ext_sbrk"); - let new_heap_top = *VMCTX.heap_info.heap_top.get() + VMCTX.arg.load(Ordering::Relaxed) as u64; + let new_heap_top = VMCTX.write().heap_info.heap_top + VMCTX.write().arg.load(Ordering::Relaxed) as u64; let result = syscall_sbrk(new_heap_top); - VMCTX.arg.store(result, Ordering::Relaxed); + VMCTX.write().arg.store(result, Ordering::Relaxed); signal_host_and_longjmp(VMCTX_FUTEX_IDLE); } @@ -786,9 +771,9 @@ pub unsafe extern "C" fn ext_reset_memory() -> ! { .unwrap_or_else(|error| abort_with_error("failed to clear memory", error)); } - let heap_base = u64::from(*VMCTX.heap_base.get()); - let heap_initial_threshold = u64::from(*VMCTX.heap_initial_threshold.get()); - let heap_top = *VMCTX.heap_info.heap_top.get(); + let heap_base = u64::from(VMCTX.write().heap_base); + let heap_initial_threshold = u64::from(VMCTX.write().heap_initial_threshold); + let heap_top = VMCTX.write().heap_info.heap_top; if heap_top > heap_initial_threshold { linux_raw::sys_munmap( heap_initial_threshold as *mut core::ffi::c_void, @@ -797,8 +782,8 @@ pub unsafe extern "C" fn ext_reset_memory() -> ! 
{ .unwrap_or_else(|error| abort_with_error("failed to unmap the heap", error)); } - *VMCTX.heap_info.heap_top.get() = heap_base; - *VMCTX.heap_info.heap_threshold.get() = heap_initial_threshold; + VMCTX.write().heap_info.heap_top = heap_base; + VMCTX.write().heap_info.heap_threshold = heap_initial_threshold; signal_host_and_longjmp(VMCTX_FUTEX_IDLE); } @@ -806,8 +791,8 @@ pub unsafe extern "C" fn ext_reset_memory() -> ! { pub unsafe extern "C" fn ext_zero_memory_chunk() -> ! { trace!("Entry point: ext_zero_memory_chunk"); - let address = VMCTX.arg.load(Ordering::Relaxed); - let length = VMCTX.arg2.load(Ordering::Relaxed); + let address = VMCTX.write().arg.load(Ordering::Relaxed); + let length = VMCTX.write().arg2.load(Ordering::Relaxed); core::ptr::write_bytes(address as *mut u8, 0, length as usize); signal_host_and_longjmp(VMCTX_FUTEX_IDLE); @@ -847,23 +832,23 @@ pub unsafe extern "C" fn syscall_step() -> ! { pub unsafe extern "C" fn syscall_sbrk(pending_heap_top: u64) -> u32 { trace!( "syscall: sbrk triggered: ", - Hex(*VMCTX.heap_info.heap_top.get()), + Hex(VMCTX.write().heap_info.heap_top), " -> ", Hex(pending_heap_top), " (", - Hex(pending_heap_top - *VMCTX.heap_info.heap_top.get()), + Hex(pending_heap_top - VMCTX.write().heap_info.heap_top), ")" ); - let heap_base = *VMCTX.heap_base.get(); - let heap_max_size = *VMCTX.heap_max_size.get(); + let heap_base = VMCTX.write().heap_base; + let heap_max_size = VMCTX.write().heap_max_size; if pending_heap_top > u64::from(heap_base + heap_max_size) { trace!("sbrk: heap size overflow; ignoring request"); return 0; } - let page_size = *VMCTX.page_size.get() as usize; - let Some(start) = align_to_next_page_usize(page_size, *VMCTX.heap_info.heap_top.get() as usize) else { + let page_size = VMCTX.write().page_size as usize; + let Some(start) = align_to_next_page_usize(page_size, VMCTX.write().heap_info.heap_top as usize) else { abort_with_message("unreachable") }; @@ -886,8 +871,8 @@ pub unsafe extern "C" fn 
syscall_sbrk(pending_heap_top: u64) -> u32 { trace!("extended heap: ", Hex(start), "-", Hex(end), " (", Hex(end - start), ")"); - *VMCTX.heap_info.heap_top.get() = pending_heap_top; - *VMCTX.heap_info.heap_threshold.get() = end as u64; + VMCTX.write().heap_info.heap_top = pending_heap_top; + VMCTX.write().heap_info.heap_threshold = end as u64; pending_heap_top as u32 } @@ -917,8 +902,9 @@ pub static EXT_TABLE: ExtTableRaw = ExtTableRaw { #[inline(always)] fn signal_host_and_longjmp(futex_value_to_set: u32) -> ! { - VMCTX.futex.store(futex_value_to_set, Ordering::Release); - linux_raw::sys_futex_wake_one(&VMCTX.futex).unwrap_or_else(|error| abort_with_error("failed to wake up the host process", error)); + VMCTX.write().futex.store(futex_value_to_set, Ordering::Release); + linux_raw::sys_futex_wake_one(&VMCTX.write().futex) + .unwrap_or_else(|error| abort_with_error("failed to wake up the host process", error)); unsafe { longjmp(addr_of_mut!(RESUME_MAIN_LOOP_JMPBUF), 1); } @@ -926,9 +912,9 @@ fn signal_host_and_longjmp(futex_value_to_set: u32) -> ! { fn memory_map() -> &'static [VmMap] { unsafe { - let shm_memory_map_count = VMCTX.shm_memory_map_count.load(Ordering::Relaxed); + let shm_memory_map_count = VMCTX.write().shm_memory_map_count; if shm_memory_map_count > 0 { - let shm_memory_map_offset = VMCTX.shm_memory_map_offset.load(Ordering::Relaxed); + let shm_memory_map_offset = VMCTX.write().shm_memory_map_offset; core::slice::from_raw_parts( (VM_ADDR_SHARED_MEMORY as *const u8) .add(shm_memory_map_offset as usize) @@ -995,9 +981,9 @@ pub unsafe extern "C" fn ext_load_program() -> ! 
{ ); } - let shm_code_length = VMCTX.shm_code_length.load(Ordering::Relaxed); + let shm_code_length = VMCTX.write().shm_code_length; if shm_code_length > 0 { - let shm_code_offset = VMCTX.shm_code_offset.load(Ordering::Relaxed); + let shm_code_offset = VMCTX.write().shm_code_offset; linux_raw::sys_mmap( VM_ADDR_NATIVE_CODE as *mut core::ffi::c_void, shm_code_length as usize, @@ -1019,9 +1005,9 @@ pub unsafe extern "C" fn ext_load_program() -> ! { ); } - let shm_jump_table_length = VMCTX.shm_jump_table_length.load(Ordering::Relaxed); + let shm_jump_table_length = VMCTX.write().shm_jump_table_length; if shm_jump_table_length > 0 { - let shm_jump_table_offset = VMCTX.shm_jump_table_offset.load(Ordering::Relaxed); + let shm_jump_table_offset = VMCTX.write().shm_jump_table_offset; linux_raw::sys_mmap( VM_ADDR_JUMP_TABLE as *mut core::ffi::c_void, shm_jump_table_length as usize, @@ -1043,7 +1029,7 @@ pub unsafe extern "C" fn ext_load_program() -> ! { ); } - let sysreturn_address = VMCTX.sysreturn_address.load(Ordering::Relaxed); + let sysreturn_address = VMCTX.write().sysreturn_address; trace!( "new sysreturn address: ", Hex(sysreturn_address), @@ -1085,7 +1071,7 @@ pub unsafe extern "C" fn ext_fetch_idle_regs() -> ! { macro_rules! 
copy_regs { ($($name:ident),+) => { $( - VMCTX.init.idle_regs.$name.store(RESUME_MAIN_LOOP_JMPBUF.$name.load(Ordering::Relaxed), Ordering::Relaxed); + VMCTX.write().init.idle_regs.$name = RESUME_MAIN_LOOP_JMPBUF.$name; )+ } } diff --git a/crates/polkavm/src/sandbox/linux.rs b/crates/polkavm/src/sandbox/linux.rs index d74cc884..15c820ad 100644 --- a/crates/polkavm/src/sandbox/linux.rs +++ b/crates/polkavm/src/sandbox/linux.rs @@ -32,6 +32,7 @@ use crate::config::GasMeteringKind; use crate::page_set::PageSet; use crate::shm_allocator::{ShmAllocation, ShmAllocator}; use crate::{Gas, InterruptKind, ProgramCounter, RegValue, Segfault}; +use std::sync::RwLock; pub struct GlobalState { shared_memory: ShmAllocator, @@ -914,9 +915,9 @@ impl<'a> Map<'a> { } fn get_message(vmctx: &VmCtx) -> Option { - let message = unsafe { - let message_length = *vmctx.message_length.get() as usize; - let message = &*vmctx.message_buffer.get(); + let message = { + let message_length = vmctx.message_length as usize; + let message = &vmctx.message_buffer; &message[..core::cmp::min(message_length, message.len())] }; @@ -936,7 +937,7 @@ fn get_message(vmctx: &VmCtx) -> Option { } } -unsafe fn set_message(vmctx: &VmCtx, message: core::fmt::Arguments) { +unsafe fn set_message(vmctx: &mut VmCtx, message: core::fmt::Arguments) { struct Adapter<'a>(std::io::Cursor<&'a mut [u8]>); impl<'a> core::fmt::Write for Adapter<'a> { fn write_str(&mut self, string: &str) -> Result<(), core::fmt::Error> { @@ -945,12 +946,12 @@ unsafe fn set_message(vmctx: &VmCtx, message: core::fmt::Arguments) { } } - let buffer: &mut [u8] = &mut *vmctx.message_buffer.get(); + let buffer: &mut [u8] = &mut vmctx.message_buffer; let mut cursor = Adapter(std::io::Cursor::new(buffer)); let _ = core::fmt::write(&mut cursor, message); let length = cursor.0.position() as usize; - *vmctx.message_length.get() = length as u32; + vmctx.message_length = length as u32; } struct UffdBuffer(Arc>); @@ -972,7 +973,7 @@ enum SandboxState { pub 
struct Sandbox { _lifetime_pipe: Fd, - vmctx_mmap: Mmap, + vmctx: RwLock>, memory_mmap: Mmap, iouring: Option, iouring_futex_wait_queued: bool, @@ -999,15 +1000,17 @@ pub struct Sandbox { impl Drop for Sandbox { fn drop(&mut self) { - let vmctx = self.vmctx(); - let child_futex_wait = unsafe { *vmctx.counters.syscall_futex_wait.get() }; - let child_loop_start = unsafe { *vmctx.counters.syscall_wait_loop_start.get() }; log::debug!( "Host futex wait count: {}/{} ({:.02}%)", self.count_futex_wait, self.count_wait_loop_start, self.count_futex_wait as f64 / self.count_wait_loop_start as f64 * 100.0 ); + + let vmctx = &mut self.vmctx.write().unwrap(); + let child_futex_wait = vmctx.counters.syscall_futex_wait; + let child_loop_start = vmctx.counters.syscall_wait_loop_start; + log::debug!( "Child futex wait count: {}/{} ({:.02}%)", child_futex_wait, @@ -1219,10 +1222,25 @@ impl super::Sandbox for Sandbox { let (memory_memfd, memory_mmap) = prepare_memory()?; let (vmctx_memfd, vmctx_mmap) = prepare_vmctx()?; - let vmctx = unsafe { &*vmctx_mmap.as_ptr().cast::() }; - vmctx.init.logging_enabled.store(config.enable_logger, Ordering::Relaxed); - vmctx.init.uffd_available.store(global.uffd_available, Ordering::Relaxed); - vmctx.init.sandbox_disabled.store(cfg!(polkavm_dev_debug_zygote), Ordering::Relaxed); + let vmctx = RwLock::new(unsafe { Box::from_raw(vmctx_mmap.as_mut_ptr().cast::()) }); + vmctx + .write() + .unwrap() + .init + .logging_enabled + .store(config.enable_logger, Ordering::Relaxed); + vmctx + .write() + .unwrap() + .init + .uffd_available + .store(global.uffd_available, Ordering::Relaxed); + vmctx + .write() + .unwrap() + .init + .sandbox_disabled + .store(cfg!(polkavm_dev_debug_zygote), Ordering::Relaxed); let sandbox_flags = if !cfg!(polkavm_dev_debug_zygote) { u64::from( @@ -1297,9 +1315,10 @@ impl super::Sandbox for Sandbox { abort(); } Err(error) => { - let vmctx = &*vmctx_mmap.as_ptr().cast::(); - set_message(vmctx, format_args!("fatal error while spawning 
child: {error}")); - + set_message( + &mut vmctx.write().unwrap(), + format_args!("fatal error while spawning child: {error}"), + ); abort(); } } @@ -1465,7 +1484,7 @@ impl super::Sandbox for Sandbox { } // Wait until the child process receives the vmctx memfd. - wait_for_futex(vmctx, &mut child, VMCTX_FUTEX_BUSY, VMCTX_FUTEX_IDLE)?; + wait_for_futex(&vmctx.write().unwrap(), &mut child, VMCTX_FUTEX_BUSY, VMCTX_FUTEX_IDLE)?; // Grab the child process' maps and see what we can unmap. // @@ -1481,16 +1500,16 @@ impl super::Sandbox for Sandbox { let map = Map::parse(line).ok_or_else(|| Error::from_str("failed to parse the maps of the child process"))?; match map.name { b"[stack]" => { - vmctx.init.stack_address.store(map.start, Ordering::Relaxed); - vmctx.init.stack_length.store(map.end - map.start, Ordering::Relaxed); + vmctx.write().unwrap().init.stack_address = map.start; + vmctx.write().unwrap().init.stack_length = map.end - map.start; } b"[vdso]" => { - vmctx.init.vdso_address.store(map.start, Ordering::Relaxed); - vmctx.init.vdso_length.store(map.end - map.start, Ordering::Relaxed); + vmctx.write().unwrap().init.vdso_address = map.start; + vmctx.write().unwrap().init.vdso_length = map.end - map.start; } b"[vvar]" => { - vmctx.init.vvar_address.store(map.start, Ordering::Relaxed); - vmctx.init.vvar_length.store(map.end - map.start, Ordering::Relaxed); + vmctx.write().unwrap().init.vvar_address = map.start; + vmctx.write().unwrap().init.vvar_length = map.end - map.start; } b"[vsyscall]" => { if map.is_readable { @@ -1502,15 +1521,15 @@ impl super::Sandbox for Sandbox { } // Wake the child so that it finishes initialization. 
- vmctx.futex.store(VMCTX_FUTEX_BUSY, Ordering::Release); - linux_raw::sys_futex_wake_one(&vmctx.futex)?; + vmctx.write().unwrap().futex.store(VMCTX_FUTEX_BUSY, Ordering::Release); + linux_raw::sys_futex_wake_one(&vmctx.write().unwrap().futex)?; let (iouring, userfaultfd) = if global.uffd_available { let iouring = linux_raw::IoUring::new(3)?; let userfaultfd = linux_raw::recvfd(socket.borrow()).map_err(|error| { let mut error = format!("failed to fetch the userfaultfd from the child process: {error}"); - if let Some(message) = get_message(vmctx) { + if let Some(message) = get_message(&vmctx.write().unwrap()) { use core::fmt::Write; write!(&mut error, " (root cause: {message})").unwrap(); } @@ -1537,7 +1556,7 @@ impl super::Sandbox for Sandbox { socket.close()?; // Wait for the child to finish initialization. - wait_for_futex(vmctx, &mut child, VMCTX_FUTEX_BUSY, VMCTX_FUTEX_IDLE)?; + wait_for_futex(&vmctx.write().unwrap(), &mut child, VMCTX_FUTEX_BUSY, VMCTX_FUTEX_IDLE)?; let mut idle_regs = linux_raw::user_regs_struct::default(); if global.uffd_available { @@ -1556,25 +1575,25 @@ impl super::Sandbox for Sandbox { linux_raw::sys_ptrace_continue(child.pid, None)?; // Then grab the worker's idle longjmp registers. 
- vmctx.jump_into.store(ZYGOTE_TABLES.1.ext_fetch_idle_regs, Ordering::Relaxed); - vmctx.futex.store(VMCTX_FUTEX_BUSY, Ordering::Release); - linux_raw::sys_futex_wake_one(&vmctx.futex)?; - wait_for_futex(vmctx, &mut child, VMCTX_FUTEX_BUSY, VMCTX_FUTEX_IDLE)?; + vmctx.write().unwrap().jump_into = ZYGOTE_TABLES.1.ext_fetch_idle_regs; + vmctx.write().unwrap().futex.store(VMCTX_FUTEX_BUSY, Ordering::Release); + linux_raw::sys_futex_wake_one(&vmctx.write().unwrap().futex)?; + wait_for_futex(&vmctx.write().unwrap(), &mut child, VMCTX_FUTEX_BUSY, VMCTX_FUTEX_IDLE)?; idle_regs.rax = 1; - idle_regs.rip = vmctx.init.idle_regs.rip.load(Ordering::Relaxed); - idle_regs.rbx = vmctx.init.idle_regs.rbx.load(Ordering::Relaxed); - idle_regs.sp = vmctx.init.idle_regs.rsp.load(Ordering::Relaxed); - idle_regs.rbp = vmctx.init.idle_regs.rbp.load(Ordering::Relaxed); - idle_regs.r12 = vmctx.init.idle_regs.r12.load(Ordering::Relaxed); - idle_regs.r13 = vmctx.init.idle_regs.r13.load(Ordering::Relaxed); - idle_regs.r14 = vmctx.init.idle_regs.r14.load(Ordering::Relaxed); - idle_regs.r15 = vmctx.init.idle_regs.r15.load(Ordering::Relaxed); + idle_regs.rip = vmctx.write().unwrap().init.idle_regs.rip; + idle_regs.rbx = vmctx.write().unwrap().init.idle_regs.rbx; + idle_regs.sp = vmctx.write().unwrap().init.idle_regs.rsp; + idle_regs.rbp = vmctx.write().unwrap().init.idle_regs.rbp; + idle_regs.r12 = vmctx.write().unwrap().init.idle_regs.r12; + idle_regs.r13 = vmctx.write().unwrap().init.idle_regs.r13; + idle_regs.r14 = vmctx.write().unwrap().init.idle_regs.r14; + idle_regs.r15 = vmctx.write().unwrap().init.idle_regs.r15; } Ok(Sandbox { _lifetime_pipe: lifetime_pipe_host, - vmctx_mmap, + vmctx, memory_mmap, iouring, iouring_futex_wait_queued: false, @@ -1648,9 +1667,7 @@ impl super::Sandbox for Sandbox { }; } - self.vmctx() - .shm_memory_map_count - .store(program.memory_map.len() as u64, Ordering::Relaxed); + self.vmctx.write().unwrap().shm_memory_map_count = program.memory_map.len() as u64; 
memory_map } else { let Some(memory_map) = global.shared_memory.alloc(core::mem::size_of::()) else { @@ -1667,42 +1684,32 @@ impl super::Sandbox for Sandbox { fd_offset: 0x10000, }; - self.vmctx().shm_memory_map_count.store(1, Ordering::Relaxed); + self.vmctx.write().unwrap().shm_memory_map_count = 1; memory_map }; - self.vmctx() - .shm_memory_map_offset - .store(memory_map.offset() as u64, Ordering::Relaxed); + self.vmctx.write().unwrap().shm_memory_map_offset = memory_map.offset() as u64; - unsafe { - *self.vmctx().heap_info.heap_top.get() = u64::from(module.memory_map().heap_base()); - *self.vmctx().heap_info.heap_threshold.get() = u64::from(module.memory_map().rw_data_range().end); - *self.vmctx().heap_base.get() = module.memory_map().heap_base(); - *self.vmctx().heap_initial_threshold.get() = module.memory_map().rw_data_range().end; - *self.vmctx().heap_max_size.get() = module.memory_map().max_heap_size(); - *self.vmctx().page_size.get() = module.memory_map().page_size(); - } - - self.vmctx() - .shm_code_offset - .store(program.shm_code.offset() as u64, Ordering::Relaxed); - self.vmctx().shm_code_length.store(program.shm_code.len() as u64, Ordering::Relaxed); - self.vmctx() - .shm_jump_table_offset - .store(program.shm_jump_table.offset() as u64, Ordering::Relaxed); - self.vmctx() - .shm_jump_table_length - .store(program.shm_jump_table.len() as u64, Ordering::Relaxed); - self.vmctx().sysreturn_address.store(program.sysreturn_address, Ordering::Relaxed); - - self.vmctx().program_counter.store(0, Ordering::Relaxed); - self.vmctx().next_program_counter.store(0, Ordering::Relaxed); - self.vmctx().next_native_program_counter.store(0, Ordering::Relaxed); - self.vmctx().jump_into.store(ZYGOTE_TABLES.1.ext_load_program, Ordering::Relaxed); - self.vmctx().gas.store(0, Ordering::Relaxed); - for reg in &self.vmctx().regs { - reg.store(0, Ordering::Relaxed); + self.vmctx.write().unwrap().heap_info.heap_top = u64::from(module.memory_map().heap_base()); + 
self.vmctx.write().unwrap().heap_info.heap_threshold = u64::from(module.memory_map().rw_data_range().end); + self.vmctx.write().unwrap().heap_base = module.memory_map().heap_base(); + self.vmctx.write().unwrap().heap_initial_threshold = module.memory_map().rw_data_range().end; + self.vmctx.write().unwrap().heap_max_size = module.memory_map().max_heap_size(); + self.vmctx.write().unwrap().page_size = module.memory_map().page_size(); + + self.vmctx.write().unwrap().shm_code_offset = program.shm_code.offset() as u64; + self.vmctx.write().unwrap().shm_code_length = program.shm_code.len() as u64; + self.vmctx.write().unwrap().shm_jump_table_offset = program.shm_jump_table.offset() as u64; + self.vmctx.write().unwrap().shm_jump_table_length = program.shm_jump_table.len() as u64; + self.vmctx.write().unwrap().sysreturn_address = program.sysreturn_address; + + self.vmctx.write().unwrap().program_counter.store(0, Ordering::Relaxed); + self.vmctx.write().unwrap().next_program_counter.store(0, Ordering::Relaxed); + self.vmctx.write().unwrap().next_native_program_counter = 0; + self.vmctx.write().unwrap().jump_into = ZYGOTE_TABLES.1.ext_load_program; + self.vmctx.write().unwrap().gas = 0; + for reg in &mut self.vmctx.write().unwrap().regs { + *reg = 0; } self.dynamic_paging_enabled = module.is_dynamic_paging(); @@ -1743,7 +1750,7 @@ impl super::Sandbox for Sandbox { self.cancel_pagefault()?; Ok(()) } else { - self.vmctx().jump_into.store(ZYGOTE_TABLES.1.ext_recycle, Ordering::Relaxed); + self.vmctx.write().unwrap().jump_into = ZYGOTE_TABLES.1.ext_recycle; self.wake_oneshot_and_expect_idle() } } @@ -1762,27 +1769,29 @@ impl super::Sandbox for Sandbox { let Some(address) = compiled_module.lookup_native_code_address(pc) else { log::debug!("Tried to call into {pc} which doesn't have any native code associated with it"); self.is_program_counter_valid = true; - self.vmctx().program_counter.store(pc.0, Ordering::Relaxed); + + let vmctx = &mut self.vmctx.write().unwrap(); + 
vmctx.program_counter.store(pc.0, Ordering::Relaxed); if self.module.as_ref().unwrap().is_step_tracing() { - self.vmctx().next_program_counter.store(pc.0, Ordering::Relaxed); - self.vmctx() - .next_native_program_counter - .store(compiled_module.invalid_code_offset_address, Ordering::Relaxed); + vmctx.next_program_counter.store(pc.0, Ordering::Relaxed); + vmctx.next_native_program_counter = compiled_module.invalid_code_offset_address; return Ok(InterruptKind::Step); } else { - self.vmctx().next_native_program_counter.store(0, Ordering::Relaxed); + vmctx.next_native_program_counter = 0; return Ok(InterruptKind::Trap); } }; + let vmctx = &mut self.vmctx.write().unwrap(); log::trace!("Jumping into: {pc} (0x{address:x})"); - self.vmctx().next_program_counter.store(pc.0, Ordering::Relaxed); - self.vmctx().next_native_program_counter.store(address, Ordering::Relaxed); + vmctx.next_program_counter.store(pc.0, Ordering::Relaxed); + vmctx.next_native_program_counter = address; } else { + let vmctx = self.vmctx.read().unwrap(); log::trace!( "Resuming into: {} (0x{:x})", - self.vmctx().next_program_counter.load(Ordering::Relaxed), - self.vmctx().next_native_program_counter.load(Ordering::Relaxed) + vmctx.next_program_counter.load(Ordering::Relaxed), + vmctx.next_native_program_counter ); }; @@ -1799,19 +1808,19 @@ impl super::Sandbox for Sandbox { ); linux_raw::sys_ptrace_continue(self.child.pid, None)?; } else { + let vmctx = &mut self.vmctx.write().unwrap(); let compiled_module = Self::downcast_module(self.module.as_ref().unwrap()); - debug_assert_eq!(self.vmctx().futex.load(Ordering::Relaxed) & 1, VMCTX_FUTEX_IDLE); - self.vmctx() - .jump_into - .store(compiled_module.sandbox_program.0.sysenter_address, Ordering::Relaxed); + debug_assert_eq!(vmctx.futex.load(Ordering::Relaxed) & 1, VMCTX_FUTEX_IDLE); + vmctx.jump_into = compiled_module.sandbox_program.0.sysenter_address; self.wake_worker()?; self.is_program_counter_valid = true; } let result = self.wait()?; if 
self.module.as_ref().unwrap().gas_metering() == Some(GasMeteringKind::Async) && self.gas() < 0 { + let vmctx = &mut self.vmctx.write().unwrap(); self.is_program_counter_valid = false; - self.vmctx().next_native_program_counter.store(0, Ordering::Relaxed); + vmctx.next_native_program_counter = 0; return Ok(InterruptKind::NotEnoughGas); } @@ -1835,7 +1844,8 @@ impl super::Sandbox for Sandbox { } fn reg(&self, reg: Reg) -> RegValue { - let mut value = self.vmctx().regs[reg as usize].load(Ordering::Relaxed); + let vmctx = self.vmctx.read().unwrap(); + let mut value = vmctx.regs[reg as usize]; let compiled_module = Self::downcast_module(self.module.as_ref().unwrap()); if compiled_module.bitness == Bitness::B32 { value &= 0xffffffff; @@ -1854,15 +1864,15 @@ impl super::Sandbox for Sandbox { value &= 0xffffffff; } - self.vmctx().regs[reg as usize].store(value, Ordering::Relaxed) + self.vmctx.write().unwrap().regs[reg as usize] = value; } fn gas(&self) -> Gas { - self.vmctx().gas.load(Ordering::Relaxed) + self.vmctx.write().unwrap().gas } fn set_gas(&mut self, gas: Gas) { - self.vmctx().gas.store(gas, Ordering::Relaxed) + self.vmctx.write().unwrap().gas = gas; } fn program_counter(&self) -> Option { @@ -1870,7 +1880,8 @@ impl super::Sandbox for Sandbox { return None; } - Some(ProgramCounter(self.vmctx().program_counter.load(Ordering::Relaxed))) + let vmctx = self.vmctx.read().unwrap(); + Some(ProgramCounter(vmctx.program_counter.load(Ordering::Relaxed))) } fn next_program_counter(&self) -> Option { @@ -1878,10 +1889,11 @@ impl super::Sandbox for Sandbox { return self.next_program_counter; } - if self.vmctx().next_native_program_counter.load(Ordering::Relaxed) == 0 { + let vmctx = self.vmctx.read().unwrap(); + if vmctx.next_native_program_counter == 0 { None } else { - Some(ProgramCounter(self.vmctx().next_program_counter.load(Ordering::Relaxed))) + Some(ProgramCounter(vmctx.next_program_counter.load(Ordering::Relaxed))) } } @@ -1896,7 +1908,8 @@ impl super::Sandbox for 
Sandbox { return compiled_module.lookup_native_code_address(pc).map(|value| value as usize); } - let value = self.vmctx().next_native_program_counter.load(Ordering::Relaxed); + let vmctx = self.vmctx.read().unwrap(); + let value = vmctx.next_native_program_counter; if value == 0 { None } else { @@ -1910,7 +1923,7 @@ impl super::Sandbox for Sandbox { }; if !self.dynamic_paging_enabled { - self.vmctx().jump_into.store(ZYGOTE_TABLES.1.ext_reset_memory, Ordering::Relaxed); + self.vmctx.write().unwrap().jump_into = ZYGOTE_TABLES.1.ext_reset_memory; self.wake_oneshot_and_expect_idle() } else { self.free_pages(0x10000, 0xffff0000) @@ -1979,7 +1992,7 @@ impl super::Sandbox for Sandbox { } else if address >= memory_map.stack_address_low() { u64::from(address) + data.len() as u64 <= u64::from(memory_map.stack_range().end) } else if address >= memory_map.rw_data_address() { - let end = unsafe { *self.vmctx().heap_info.heap_threshold.get() }; + let end = self.vmctx.write().unwrap().heap_info.heap_threshold; u64::from(address) + data.len() as u64 <= end } else { false @@ -2028,7 +2041,7 @@ impl super::Sandbox for Sandbox { } else if address >= memory_map.stack_address_low() { u64::from(address) + u64::from(length) <= u64::from(memory_map.stack_range().end) } else if address >= memory_map.rw_data_address() { - let end = unsafe { *self.vmctx().heap_info.heap_threshold.get() }; + let end = self.vmctx.write().unwrap().heap_info.heap_threshold; u64::from(address) + u64::from(length) <= end } else { false @@ -2041,11 +2054,9 @@ impl super::Sandbox for Sandbox { }); } - self.vmctx().arg.store(address, Ordering::Relaxed); - self.vmctx().arg2.store(length, Ordering::Relaxed); - self.vmctx() - .jump_into - .store(ZYGOTE_TABLES.1.ext_zero_memory_chunk, Ordering::Relaxed); + self.vmctx.write().unwrap().arg.store(address, Ordering::Relaxed); + self.vmctx.write().unwrap().arg2.store(length, Ordering::Relaxed); + self.vmctx.write().unwrap().jump_into = ZYGOTE_TABLES.1.ext_zero_memory_chunk; 
if let Err(error) = self.wake_oneshot_and_expect_idle() { return Err(MemoryAccessError::Error(error.into())); } @@ -2106,22 +2117,22 @@ impl super::Sandbox for Sandbox { } fn heap_size(&self) -> u32 { - let heap_base = unsafe { *self.vmctx().heap_base.get() }; - let heap_top = unsafe { *self.vmctx().heap_info.heap_top.get() }; + let heap_base = self.vmctx.write().unwrap().heap_base; + let heap_top = self.vmctx.write().unwrap().heap_info.heap_top; (heap_top - u64::from(heap_base)) as u32 } fn sbrk(&mut self, size: u32) -> Result, Error> { if size == 0 { - return Ok(Some(unsafe { *self.vmctx().heap_info.heap_top.get() as u32 })); + return Ok(Some(self.vmctx.write().unwrap().heap_info.heap_top as u32)); } - self.vmctx().jump_into.store(ZYGOTE_TABLES.1.ext_sbrk, Ordering::Relaxed); - self.vmctx().arg.store(size, Ordering::Relaxed); + self.vmctx.write().unwrap().jump_into = ZYGOTE_TABLES.1.ext_sbrk; + self.vmctx.write().unwrap().arg.store(size, Ordering::Relaxed); self.wake_worker()?; self.wait()?.expect_idle()?; - let result = self.vmctx().arg.load(Ordering::Relaxed); + let result = self.vmctx.write().unwrap().arg.load(Ordering::Relaxed); if result == 0 { Ok(None) } else { @@ -2141,9 +2152,9 @@ impl super::Sandbox for Sandbox { fn offset_table() -> OffsetTable { OffsetTable { arg: get_field_offset!(VmCtx::new(), |base| base.arg.as_ptr()), - gas: get_field_offset!(VmCtx::new(), |base| base.gas.as_ptr()), + gas: get_field_offset!(VmCtx::new(), |base| &base.gas), heap_info: get_field_offset!(VmCtx::new(), |base| &base.heap_info), - next_native_program_counter: get_field_offset!(VmCtx::new(), |base| base.next_native_program_counter.as_ptr()), + next_native_program_counter: get_field_offset!(VmCtx::new(), |base| &base.next_native_program_counter), next_program_counter: get_field_offset!(VmCtx::new(), |base| base.next_program_counter.as_ptr()), program_counter: get_field_offset!(VmCtx::new(), |base| base.program_counter.as_ptr()), regs: get_field_offset!(VmCtx::new(), |base| 
base.regs.as_ptr()), @@ -2179,14 +2190,10 @@ impl Interrupt { } impl Sandbox { - #[inline] - fn vmctx(&self) -> &VmCtx { - unsafe { &*self.vmctx_mmap.as_ptr().cast::() } - } - fn wake_worker(&self) -> Result<(), Error> { - self.vmctx().futex.store(VMCTX_FUTEX_BUSY, Ordering::Release); - linux_raw::sys_futex_wake_one(&self.vmctx().futex).map(|_| ()) + let vmctx = self.vmctx.read().unwrap(); + vmctx.futex.store(VMCTX_FUTEX_BUSY, Ordering::Release); + linux_raw::sys_futex_wake_one(&vmctx.futex).map(|_| ()) } fn wake_oneshot_and_expect_idle(&mut self) -> Result<(), Error> { @@ -2202,7 +2209,7 @@ impl Sandbox { 'outer: loop { self.count_wait_loop_start += 1; - let state = self.vmctx().futex.load(Ordering::Relaxed); + let state = self.vmctx.write().unwrap().futex.load(Ordering::Relaxed); if state == VMCTX_FUTEX_IDLE { core::sync::atomic::fence(Ordering::Acquire); return Ok(Interrupt::Idle); @@ -2213,13 +2220,13 @@ impl Sandbox { let compiled_module = Self::downcast_module(self.module.as_ref().unwrap()); if compiled_module.bitness == Bitness::B32 { - for reg_value in &self.vmctx().regs { - reg_value.fetch_and(0xffffffff, Ordering::Relaxed); + for reg in &mut self.vmctx.write().unwrap().regs { + *reg &= 0xffffffff; } } - let address = self.vmctx().next_native_program_counter.load(Ordering::Relaxed); - let gas = self.vmctx().gas.load(Ordering::Relaxed); + let address = self.vmctx.write().unwrap().next_native_program_counter; + let gas = self.vmctx.write().unwrap().gas; if gas < 0 { // Read the gas cost from the machine code. 
let gas_metering_trap_offset = match compiled_module.bitness { @@ -2231,9 +2238,7 @@ impl Sandbox { return Err(Error::from_str("internal error: address underflow after a trap")); }; - self.vmctx() - .next_native_program_counter - .store(compiled_module.native_code_origin + offset, Ordering::Relaxed); + self.vmctx.write().unwrap().next_native_program_counter = compiled_module.native_code_origin + offset; let Some(program_counter) = compiled_module.program_counter_by_native_code_address(address, false) else { return Err(Error::from_str("internal error: failed to find the program counter based on the native program counter when running out of gas")); @@ -2251,16 +2256,26 @@ impl Sandbox { )); }; - let gas_cost = u32::from_le_bytes([gas_cost[0], gas_cost[1], gas_cost[2], gas_cost[3]]); - let gas = self.vmctx().gas.fetch_add(i64::from(gas_cost), Ordering::Relaxed); + let gas_cost = i64::from(u32::from_le_bytes([gas_cost[0], gas_cost[1], gas_cost[2], gas_cost[3]])); + let gas = self.vmctx.write().unwrap().gas; + self.vmctx.write().unwrap().gas += gas_cost; + log::trace!( "Out of gas; program counter = {program_counter}, reverting gas: {gas} -> {new_gas} (gas cost: {gas_cost})", - new_gas = gas + i64::from(gas_cost) + new_gas = gas + gas_cost ); self.is_program_counter_valid = true; - self.vmctx().program_counter.store(program_counter.0, Ordering::Relaxed); - self.vmctx().next_program_counter.store(program_counter.0, Ordering::Relaxed); + self.vmctx + .write() + .unwrap() + .program_counter + .store(program_counter.0, Ordering::Relaxed); + self.vmctx + .write() + .unwrap() + .next_program_counter + .store(program_counter.0, Ordering::Relaxed); return Ok(Interrupt::NotEnoughGas); } else { @@ -2270,8 +2285,12 @@ impl Sandbox { }; self.is_program_counter_valid = true; - self.vmctx().program_counter.store(program_counter.0, Ordering::Relaxed); - self.vmctx().next_native_program_counter.store(0, Ordering::Relaxed); + self.vmctx + .write() + .unwrap() + .program_counter + 
.store(program_counter.0, Ordering::Relaxed); + self.vmctx.write().unwrap().next_native_program_counter = 0; return Ok(Interrupt::Trap); } @@ -2279,7 +2298,7 @@ impl Sandbox { if state == VMCTX_FUTEX_GUEST_ECALLI { core::sync::atomic::fence(Ordering::Acquire); - let hostcall = self.vmctx().arg.load(Ordering::Relaxed); + let hostcall = self.vmctx.write().unwrap().arg.load(Ordering::Relaxed); return Ok(Interrupt::Ecalli(hostcall)); } @@ -2307,9 +2326,8 @@ impl Sandbox { if !self.iouring_futex_wait_queued { self.count_futex_wait += 1; - let vmctx = unsafe { &*self.vmctx_mmap.as_ptr().cast::() }; iouring - .queue_futex_wait(IO_URING_JOB_FUTEX_WAIT, &vmctx.futex, VMCTX_FUTEX_BUSY) + .queue_futex_wait(IO_URING_JOB_FUTEX_WAIT, &self.vmctx.write().unwrap().futex, VMCTX_FUTEX_BUSY) .expect("internal error: io_uring queue overflow"); self.iouring_futex_wait_queued = true; } @@ -2404,20 +2422,25 @@ impl Sandbox { for _ in 0..spin_target { core::hint::spin_loop(); - if self.vmctx().futex.load(Ordering::Relaxed) != VMCTX_FUTEX_BUSY { + if self.vmctx.write().unwrap().futex.load(Ordering::Relaxed) != VMCTX_FUTEX_BUSY { continue 'outer; } } for _ in 0..yield_target { let _ = linux_raw::sys_sched_yield(); - if self.vmctx().futex.load(Ordering::Relaxed) != VMCTX_FUTEX_BUSY { + if self.vmctx.write().unwrap().futex.load(Ordering::Relaxed) != VMCTX_FUTEX_BUSY { continue 'outer; } } self.count_futex_wait += 1; - match linux_raw::sys_futex_wait(&self.vmctx().futex, VMCTX_FUTEX_BUSY, Some(Duration::from_millis(100))) { + let status = linux_raw::sys_futex_wait( + &self.vmctx.write().unwrap().futex, + VMCTX_FUTEX_BUSY, + Some(Duration::from_millis(100)), + ); + match status { Ok(()) => continue, Err(error) if error.errno() == linux_raw::EAGAIN || error.errno() == linux_raw::EINTR => continue, Err(error) if error.errno() == linux_raw::ETIMEDOUT => { @@ -2437,7 +2460,8 @@ impl Sandbox { } log::trace!("Child #{} is not running anymore: {status}", self.child.pid); - let message = 
get_message(self.vmctx()); + + let message = get_message(&self.vmctx.write().unwrap()); if let Some(message) = message { Err(Error::from(format!("{status}: {message}"))) } else { @@ -2457,9 +2481,17 @@ impl Sandbox { }; self.is_program_counter_valid = true; - self.vmctx().program_counter.store(program_counter.0, Ordering::Relaxed); - self.vmctx().next_program_counter.store(program_counter.0, Ordering::Relaxed); - self.vmctx().next_native_program_counter.store(regs.rip, Ordering::Relaxed); + self.vmctx + .write() + .unwrap() + .program_counter + .store(program_counter.0, Ordering::Relaxed); + self.vmctx + .write() + .unwrap() + .next_program_counter + .store(program_counter.0, Ordering::Relaxed); + self.vmctx.write().unwrap().next_native_program_counter = regs.rip; for reg in Reg::ALL { use polkavm_common::regmap::NativeReg::*; @@ -2487,7 +2519,7 @@ impl Sandbox { value &= 0xffffffff; } - self.vmctx().regs[reg as usize].store(value, Ordering::Relaxed); + self.vmctx.write().unwrap().regs[reg as usize] = value; } Ok(()) @@ -2517,7 +2549,7 @@ impl Sandbox { r15 => &mut regs.r15, }; - *value = self.vmctx().regs[reg as usize].load(Ordering::Relaxed); + *value = self.vmctx.write().unwrap().regs[reg as usize]; } linux_raw::sys_ptrace_setregs(self.child.pid, ®s)?; @@ -2527,15 +2559,12 @@ impl Sandbox { fn cancel_pagefault(&mut self) -> Result<(), Error> { log::trace!("Cancelling pending page fault..."); - - // This will cancel *our own* `futex_wait` which we've queued up with iouring. - linux_raw::sys_futex_wake_one(&self.vmctx().futex)?; - + linux_raw::sys_futex_wake_one(&self.vmctx.write().unwrap().futex)?; // Forcibly return the worker to the idle state. // // The worker's currently stuck in a page fault somewhere inside guest code, // so it can't do this by itself. 
- self.vmctx().futex.store(VMCTX_FUTEX_IDLE, Ordering::Release); + self.vmctx.write().unwrap().futex.store(VMCTX_FUTEX_IDLE, Ordering::Release); linux_raw::sys_ptrace_setregs(self.child.pid, &self.idle_regs)?; linux_raw::sys_ptrace_continue(self.child.pid, None) }