diff --git a/crates/polkavm-common/src/utils.rs b/crates/polkavm-common/src/utils.rs index e79eed4c..357634bb 100644 --- a/crates/polkavm-common/src/utils.rs +++ b/crates/polkavm-common/src/utils.rs @@ -268,6 +268,8 @@ pub trait Access<'a> { /// if the program ended up consuming *exactly* the amount of gas that it was provided with! fn gas_remaining(&self) -> Option; + fn consume_gas(&mut self, gas: u64); + #[cfg(feature = "alloc")] fn read_memory_into_new_vec(&self, address: u32, length: u32) -> Result, Self::Error> { let mut buffer = Vec::new(); diff --git a/crates/polkavm-common/src/zygote.rs b/crates/polkavm-common/src/zygote.rs index a56b275e..33202e9c 100644 --- a/crates/polkavm-common/src/zygote.rs +++ b/crates/polkavm-common/src/zygote.rs @@ -254,7 +254,7 @@ const REG_COUNT: usize = crate::program::Reg::ALL_NON_ZERO.len(); pub struct VmCtxSyscall { // NOTE: The order of fields here can matter for performance! /// The current gas counter. - pub gas: UnsafeCell, + pub gas: UnsafeCell, /// The hostcall number that was triggered. pub hostcall: UnsafeCell, /// A dump of all of the registers of the VM. @@ -379,7 +379,7 @@ impl VmCtx { // when we shuffle things around in the structure. #[inline(always)] - pub const fn gas(&self) -> &UnsafeCell { + pub const fn gas(&self) -> &UnsafeCell { &self.syscall_ffi.0.gas } diff --git a/crates/polkavm/src/api.rs b/crates/polkavm/src/api.rs index 5d5a8d84..290b03f5 100644 --- a/crates/polkavm/src/api.rs +++ b/crates/polkavm/src/api.rs @@ -1767,6 +1767,10 @@ impl<'a> Access<'a> for BackendAccess<'a> { fn gas_remaining(&self) -> Option { access_backend!(self, |access| access.gas_remaining()) } + + fn consume_gas(&mut self, gas: u64) { + access_backend!(self, |access| access.consume_gas(gas)) + } } struct InstancePrivateMut { diff --git a/crates/polkavm/src/caller.rs b/crates/polkavm/src/caller.rs index 2e50c1e1..f15f4612 100644 --- a/crates/polkavm/src/caller.rs +++ b/crates/polkavm/src/caller.rs @@ -129,6 +129,11 @@ impl CallerRaw { // SAFETY: The caller will make sure that the invariants hold. unsafe { self.access() }.gas_remaining() } + + unsafe fn consume_gas(&mut self, gas: u64) { + // SAFETY: The caller will make sure that the invariants hold. + unsafe { self.access_mut() }.consume_gas(gas) + } } /// A handle used to access the execution context. @@ -246,6 +251,11 @@ impl<'a, T> Caller<'a, T> { // SAFETY: This can only be called from inside of `Caller::wrap` so this is always valid. unsafe { self.raw.gas_remaining() } } + + pub fn consume_gas(&mut self, gas: u64) { + // SAFETY: This can only be called from inside of `Caller::wrap` so this is always valid. + unsafe { self.raw.consume_gas(gas) } + } } /// A handle used to access the execution context, with erased lifetimes for convenience. @@ -328,6 +338,13 @@ impl CallerRef { // SAFETY: We've made sure the lifetime is valid. unsafe { (*self.raw).gas_remaining() } } + + pub fn consume_gas(&mut self, gas: u64) { + self.check_lifetime_or_panic(); + + // SAFETY: We've made sure the lifetime is valid. + unsafe { (*self.raw).consume_gas(gas) } + } } // Source: https://users.rust-lang.org/t/a-macro-to-assert-that-a-type-does-not-implement-trait-bounds/31179 diff --git a/crates/polkavm/src/interpreter.rs b/crates/polkavm/src/interpreter.rs index bd641b8f..b6422d25 100644 --- a/crates/polkavm/src/interpreter.rs +++ b/crates/polkavm/src/interpreter.rs @@ -248,6 +248,16 @@ impl InterpretedInstance { Ok(()) } + + fn check_gas(&mut self) -> Result<(), ExecutionError> { + if let Some(ref mut gas_remaining) = self.gas_remaining { + if *gas_remaining < 0 { + return Err(ExecutionError::OutOfGas); + } + } + + Ok(()) + } } pub struct InterpretedAccess<'a> { @@ -314,6 +324,12 @@ impl<'a> Access<'a> for InterpretedAccess<'a> { let gas = self.instance.gas_remaining?; Some(Gas::new(gas as u64).unwrap_or(Gas::MIN)) } + + fn consume_gas(&mut self, gas: u64) { + if let Some(ref mut gas_remaining) = self.instance.gas_remaining { + *gas_remaining = gas_remaining.checked_sub_unsigned(gas).unwrap_or(-1); + } + } } struct Visitor<'a, 'b> { @@ -542,6 +558,7 @@ impl<'a, 'b> InstructionVisitor for Visitor<'a, 'b> { let access = BackendAccess::Interpreted(self.inner.access()); (on_hostcall)(imm, access).map_err(ExecutionError::Trap)?; self.inner.nth_instruction += 1; + self.inner.check_gas()?; Ok(()) } else { log::debug!("Hostcall called without any hostcall handler set!"); diff --git a/crates/polkavm/src/sandbox.rs b/crates/polkavm/src/sandbox.rs index 608737b0..0e41accc 100644 --- a/crates/polkavm/src/sandbox.rs +++ b/crates/polkavm/src/sandbox.rs @@ -65,22 +65,6 @@ pub(crate) fn assert_native_page_size() { #[derive(Copy, Clone, PartialEq, Eq, Debug)] pub(crate) struct OutOfGas; -fn get_gas_remaining(raw_gas: u64) -> Result { - Gas::new(raw_gas).ok_or(OutOfGas) -} - -#[test] -fn test_get_gas_remaining() { - assert_eq!(get_gas_remaining(0), Ok(Gas::new(0).unwrap())); - assert_eq!(get_gas_remaining(1), Ok(Gas::new(1).unwrap())); - assert_eq!(get_gas_remaining((-1_i64) as u64), Err(OutOfGas)); - assert_eq!(get_gas_remaining(Gas::MIN.get()), Ok(Gas::MIN)); - assert_eq!(get_gas_remaining(Gas::MAX.get()), Ok(Gas::MAX)); - - // We should never have such gas values, but test it anyway. - assert_eq!(get_gas_remaining(Gas::MAX.get() + 1), Err(OutOfGas)); -} - pub trait SandboxConfig: Default { fn enable_logger(&mut self, value: bool); } @@ -246,7 +230,7 @@ impl<'a, T> ExecuteArgs<'a, T> where T: Sandbox { self.is_async = value; } - fn get_gas(&self, gas_metering: Option) -> Option { + fn get_gas(&self, gas_metering: Option) -> Option { if self.program.is_none() && self.gas.is_none() && gas_metering.is_some() { // Keep whatever value was set there previously. return None; @@ -254,7 +238,7 @@ impl<'a, T> ExecuteArgs<'a, T> where T: Sandbox { let gas = self.gas.unwrap_or(Gas::MIN); if gas_metering.is_some() { - Some(gas.get()) + Some(gas.get() as i64) } else { Some(0) } diff --git a/crates/polkavm/src/sandbox/generic.rs b/crates/polkavm/src/sandbox/generic.rs index d6370163..7114cec8 100644 --- a/crates/polkavm/src/sandbox/generic.rs +++ b/crates/polkavm/src/sandbox/generic.rs @@ -425,7 +425,7 @@ struct VmCtx { return_address: usize, return_stack_pointer: usize, - gas: u64, + gas: i64, program_range: Range, trap_triggered: bool, @@ -1053,8 +1053,8 @@ impl super::Sandbox for Sandbox { fn gas_remaining_impl(&self) -> Result, super::OutOfGas> { let Some(program) = self.program.as_ref() else { return Ok(None) }; if program.0.gas_metering.is_none() { return Ok(None) }; - let value = self.vmctx().gas; - super::get_gas_remaining(value).map(Some) + let raw_gas = self.vmctx().gas; + Gas::from_i64(raw_gas).ok_or(super::OutOfGas).map(Some) } fn sync(&mut self) -> Result<(), Self::Error> { @@ -1164,4 +1164,13 @@ impl<'a> Access<'a> for SandboxAccess<'a> { use super::Sandbox; self.sandbox.gas_remaining_impl().ok().unwrap_or(Some(Gas::MIN)) } + + fn consume_gas(&mut self, gas: u64) { + if self.sandbox.program.as_ref().and_then(|program| program.0.gas_metering).is_none() { + return; + } + + let gas_remaining = &mut self.sandbox.vmctx_mut().gas; + *gas_remaining = gas_remaining.checked_sub_unsigned(gas).unwrap_or(-1); + } } diff --git a/crates/polkavm/src/sandbox/linux.rs b/crates/polkavm/src/sandbox/linux.rs index d6521c8e..6a1a0897 100644 --- a/crates/polkavm/src/sandbox/linux.rs +++ b/crates/polkavm/src/sandbox/linux.rs @@ -1187,8 +1187,8 @@ impl super::Sandbox for Sandbox { fn gas_remaining_impl(&self) -> Result, super::OutOfGas> { if self.gas_metering.is_none() { return Ok(None) }; - let value = unsafe { *self.vmctx().gas().get() }; - super::get_gas_remaining(value).map(Some) + let raw_gas = unsafe { *self.vmctx().gas().get() }; + Gas::from_i64(raw_gas).ok_or(super::OutOfGas).map(Some) } fn sync(&mut self) -> Result<(), Self::Error> { @@ -1465,4 +1465,11 @@ impl<'a> Access<'a> for SandboxAccess<'a> { use super::Sandbox; self.sandbox.gas_remaining_impl().ok().unwrap_or(Some(Gas::MIN)) } + + fn consume_gas(&mut self, gas: u64) { + if self.sandbox.gas_metering.is_none() { return } + let gas_remaining = unsafe { &mut *self.sandbox.vmctx().gas().get() }; + *gas_remaining = gas_remaining.checked_sub_unsigned(gas).unwrap_or(-1); + + } } diff --git a/crates/polkavm/src/tests.rs b/crates/polkavm/src/tests.rs index e5fa3c7f..3297a98f 100644 --- a/crates/polkavm/src/tests.rs +++ b/crates/polkavm/src/tests.rs @@ -534,6 +534,68 @@ fn basic_gas_metering_async(config: Config) { basic_gas_metering(config, GasMeteringKind::Async); } +fn consume_gas_in_host_function(config: Config, gas_metering_kind: GasMeteringKind) { + let _ = env_logger::try_init(); + + let mut builder = ProgramBlobBuilder::new(); + builder.add_export(0, &FnMetadata::new("main", &[], Some(I32))); + builder.add_import(0, &FnMetadata::new("hostfn", &[], Some(I32))); + builder.set_code(&[asm::ecalli(0), asm::ret()]); + + let blob = ProgramBlob::parse(builder.into_vec()).unwrap(); + let engine = Engine::new(&config).unwrap(); + let mut module_config = ModuleConfig::default(); + module_config.set_gas_metering(Some(gas_metering_kind)); + + let module = Module::from_blob(&engine, &module_config, &blob).unwrap(); + let mut linker = Linker::new(&engine); + linker + .func_wrap("hostfn", |mut caller: Caller| -> u32 { + assert_eq!(caller.gas_remaining().unwrap().get(), 1); + caller.consume_gas(*caller.data()); + 666 + }) + .unwrap(); + + let instance_pre = linker.instantiate_pre(&module).unwrap(); + let instance = instance_pre.instantiate().unwrap(); + + { + let mut config = ExecutionConfig::default(); + config.set_gas(Gas::new(3).unwrap()); + + let result = instance.get_typed_func::<(), i32>("main").unwrap().call_ex(&mut 0, (), config); + assert!(matches!(result, Ok(666)), "unexpected result: {result:?}"); + assert_eq!(instance.gas_remaining().unwrap(), Gas::new(1).unwrap()); + } + + { + let mut config = ExecutionConfig::default(); + config.set_gas(Gas::new(3).unwrap()); + + let result = instance.get_typed_func::<(), i32>("main").unwrap().call_ex(&mut 1, (), config); + assert!(matches!(result, Ok(666)), "unexpected result: {result:?}"); + assert_eq!(instance.gas_remaining().unwrap(), Gas::new(0).unwrap()); + } + + { + let mut config = ExecutionConfig::default(); + config.set_gas(Gas::new(3).unwrap()); + + let result = instance.get_typed_func::<(), i32>("main").unwrap().call_ex(&mut 2, (), config); + assert_eq!(instance.gas_remaining().unwrap(), Gas::new(0).unwrap()); + assert!(matches!(result, Err(ExecutionError::OutOfGas)), "unexpected result: {result:?}"); + } +} + +fn consume_gas_in_host_function_sync(config: Config) { + consume_gas_in_host_function(config, GasMeteringKind::Sync); +} + +fn consume_gas_in_host_function_async(config: Config) { + consume_gas_in_host_function(config, GasMeteringKind::Async); +} + run_tests! { caller_and_caller_ref_work caller_split_works @@ -546,6 +608,8 @@ run_tests! { basic_gas_metering_sync basic_gas_metering_async + consume_gas_in_host_function_sync + consume_gas_in_host_function_async } // Source: https://users.rust-lang.org/t/a-macro-to-assert-that-a-type-does-not-implement-trait-bounds/31179