From 46079854cefe1ca7e3267efb40e016fa2fae3189 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sun, 12 Jan 2025 11:35:02 -0500 Subject: [PATCH] Adjust types in extension API --- Cargo.lock | 4 ++ core/ext/mod.rs | 4 +- core/types.rs | 50 ++++++++++---- extensions/uuid/Cargo.toml | 3 +- extensions/uuid/src/lib.rs | 35 +++++----- limbo_extension/Cargo.toml | 1 + limbo_extension/src/lib.rs | 133 ++++++++++++++++++++----------------- 7 files changed, 139 insertions(+), 91 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index dadbd7b9..b17c82e4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1251,6 +1251,9 @@ dependencies = [ [[package]] name = "limbo_extension" version = "0.0.11" +dependencies = [ + "log", +] [[package]] name = "limbo_macros" @@ -1285,6 +1288,7 @@ name = "limbo_uuid" version = "0.0.11" dependencies = [ "limbo_extension", + "log", "uuid", ] diff --git a/core/ext/mod.rs b/core/ext/mod.rs index 79936079..179b3e7c 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -1,6 +1,8 @@ use crate::{function::ExternalFunc, Database}; +pub use limbo_extension::{ + Blob as ExtBlob, TextValue as ExtTextValue, Value as ExtValue, ValueType as ExtValueType, +}; use limbo_extension::{ExtensionApi, ResultCode, ScalarFunction, RESULT_ERROR, RESULT_OK}; -pub use limbo_extension::{Value as ExtValue, ValueType as ExtValueType}; use std::{ ffi::{c_char, c_void, CStr}, rc::Rc, diff --git a/core/types.rs b/core/types.rs index 346297d4..c903833f 100644 --- a/core/types.rs +++ b/core/types.rs @@ -1,5 +1,5 @@ use crate::error::LimboError; -use crate::ext::{ExtValue, ExtValueType}; +use crate::ext::{ExtBlob, ExtTextValue, ExtValue, ExtValueType}; use crate::storage::sqlite3_ondisk::write_varint; use crate::Result; use std::fmt::Display; @@ -92,36 +92,62 @@ impl Display for OwnedValue { } } } + impl OwnedValue { pub fn to_ffi(&self) -> ExtValue { match self { Self::Null => ExtValue::null(), Self::Integer(i) => ExtValue::from_integer(*i), Self::Float(fl) => ExtValue::from_float(*fl), - Self::Text(s) => ExtValue::from_text(s.value.to_string()), - Self::Blob(b) => ExtValue::from_blob(b), - Self::Agg(_) => todo!(), - Self::Record(_) => todo!(), + Self::Text(text) => ExtValue::from_text(text.value.to_string()), + Self::Blob(blob) => ExtValue::from_blob(blob.to_vec()), + Self::Agg(_) => todo!("Aggregate values not yet supported"), + Self::Record(_) => todo!("Record values not yet supported"), } } + pub fn from_ffi(v: &ExtValue) -> Self { match v.value_type { ExtValueType::Null => OwnedValue::Null, - ExtValueType::Integer => OwnedValue::Integer(v.integer), - ExtValueType::Float => OwnedValue::Float(v.float), + ExtValueType::Integer => { + if v.value.is_null() { + OwnedValue::Null + } else { + let int_ptr = v.value as *mut i64; + let integer = unsafe { *int_ptr }; + OwnedValue::Integer(integer) + } + } + ExtValueType::Float => { + if v.value.is_null() { + OwnedValue::Null + } else { + let float_ptr = v.value as *mut f64; + let float = unsafe { *float_ptr }; + OwnedValue::Float(float) + } + } ExtValueType::Text => { - if v.text.is_null() { + if v.value.is_null() { OwnedValue::Null } else { - OwnedValue::build_text(std::rc::Rc::new(v.text.to_string())) + let Some(text) = ExtTextValue::from_value(v) else { + return OwnedValue::Null; + }; + OwnedValue::build_text(std::rc::Rc::new(unsafe { text.as_str().to_string() })) } } ExtValueType::Blob => { - if v.blob.data.is_null() { + if v.value.is_null() { OwnedValue::Null } else { - let bytes = unsafe { std::slice::from_raw_parts(v.blob.data, v.blob.size) }; - OwnedValue::Blob(std::rc::Rc::new(bytes.to_vec())) + let blob_ptr = v.value as *mut ExtBlob; + let blob = unsafe { + let slice = + std::slice::from_raw_parts((*blob_ptr).data, (*blob_ptr).size as usize); + slice.to_vec() + }; + OwnedValue::Blob(std::rc::Rc::new(blob)) } } } diff --git a/extensions/uuid/Cargo.toml b/extensions/uuid/Cargo.toml index 1b9eb10a..ed2c43e8 100644 --- a/extensions/uuid/Cargo.toml +++ b/extensions/uuid/Cargo.toml @@ -7,9 +7,10 @@ license.workspace = true repository.workspace = true [lib] -crate-type = ["cdylib"] +crate-type = ["cdylib", "lib"] [dependencies] limbo_extension = { path = "../../limbo_extension"} uuid = { version = "1.11.0", features = ["v4", "v7"] } +log = "0.4.20" diff --git a/extensions/uuid/src/lib.rs b/extensions/uuid/src/lib.rs index c9950be8..94065815 100644 --- a/extensions/uuid/src/lib.rs +++ b/extensions/uuid/src/lib.rs @@ -1,5 +1,5 @@ use limbo_extension::{ - declare_scalar_functions, register_extension, register_scalar_functions, Value, + declare_scalar_functions, register_extension, register_scalar_functions, Blob, TextValue, Value, }; register_extension! { @@ -17,46 +17,47 @@ declare_scalar_functions! { let uuid = uuid::Uuid::new_v4().to_string(); Value::from_text(uuid) } + #[args(min = 0, max = 0)] fn uuid4_blob(_args: &[Value]) -> Value { let uuid = uuid::Uuid::new_v4(); let bytes = uuid.as_bytes(); - Value::from_blob(bytes) + Value::from_blob(bytes.to_vec()) } #[args(min = 1, max = 1)] fn uuid_str(args: &[Value]) -> Value { - if args.len() != 1 { - return Value::null(); - } if args[0].value_type != limbo_extension::ValueType::Blob { + log::debug!("uuid_str was passed a non-blob arg"); return Value::null(); } - let data_ptr = args[0].blob.data; - let size = args[0].blob.size; - if data_ptr.is_null() || size != 16 { - return Value::null(); - } - let slice = unsafe{ std::slice::from_raw_parts(data_ptr, size)}; + if let Some(blob) = Blob::from_value(&args[0]) { + let slice = unsafe{ std::slice::from_raw_parts(blob.data, blob.size as usize)}; let parsed = uuid::Uuid::from_slice(slice).ok().map(|u| u.to_string()); match parsed { Some(s) => Value::from_text(s), None => Value::null() } + } else { + Value::null() + } } #[args(min = 1, max = 1)] fn uuid_blob(args: &[Value]) -> Value { - if args.len() != 1 { - return Value::null(); - } if args[0].value_type != limbo_extension::ValueType::Text { + log::debug!("uuid_blob was passed a non-text arg"); return Value::null(); } - let text = args[0].text.to_string(); - match uuid::Uuid::parse_str(&text) { - Ok(uuid) => Value::from_blob(uuid.as_bytes()), + if let Some(text) = TextValue::from_value(&args[0]) { + match uuid::Uuid::parse_str(unsafe {text.as_str()}) { + Ok(uuid) => { + Value::from_blob(uuid.as_bytes().to_vec()) + } Err(_) => Value::null() } + } else { + Value::null() + } } } diff --git a/limbo_extension/Cargo.toml b/limbo_extension/Cargo.toml index d3ac246d..2928ed85 100644 --- a/limbo_extension/Cargo.toml +++ b/limbo_extension/Cargo.toml @@ -7,3 +7,4 @@ license.workspace = true repository.workspace = true [dependencies] +log = "0.4.20" diff --git a/limbo_extension/src/lib.rs b/limbo_extension/src/lib.rs index 6704f63e..3daf9975 100644 --- a/limbo_extension/src/lib.rs +++ b/limbo_extension/src/lib.rs @@ -1,4 +1,3 @@ -use std::ffi::CString; use std::os::raw::{c_char, c_void}; pub type ResultCode = i32; @@ -76,8 +75,7 @@ macro_rules! declare_scalar_functions { argv: *const *const std::os::raw::c_void ) -> $crate::Value { if !($min_args..=$max_args).contains(&argc) { - println!("{}: Invalid argument count", stringify!($func_name)); - return $crate::Value::null();// TODO: error code + return $crate::Value::null(); } if argc == 0 || argv.is_null() { let $args: &[$crate::Value] = &[]; @@ -103,8 +101,8 @@ macro_rules! declare_scalar_functions { }; } -#[derive(PartialEq, Eq)] #[repr(C)] +#[derive(PartialEq, Eq)] pub enum ValueType { Null, Integer, @@ -113,48 +111,50 @@ pub enum ValueType { Blob, } -// TODO: perf, these can be better expressed #[repr(C)] pub struct Value { pub value_type: ValueType, - pub integer: i64, - pub float: f64, - pub text: TextValue, - pub blob: Blob, + pub value: *mut c_void, } #[repr(C)] pub struct TextValue { - text: *const c_char, - len: usize, + pub text: *const u8, + pub len: u32, } -impl std::fmt::Display for TextValue { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.text.is_null() { - return write!(f, ""); - } - let slice = unsafe { std::slice::from_raw_parts(self.text as *const u8, self.len) }; - match std::str::from_utf8(slice) { - Ok(s) => write!(f, "{}", s), - Err(e) => write!(f, "", e), +impl Default for TextValue { + fn default() -> Self { + Self { + text: std::ptr::null(), + len: 0, } } } impl TextValue { - pub fn is_null(&self) -> bool { - self.text.is_null() + pub fn new(text: *const u8, len: usize) -> Self { + Self { + text, + len: len as u32, + } } - pub fn new(text: *const c_char, len: usize) -> Self { - Self { text, len } + pub fn from_value(value: &Value) -> Option<&Self> { + if value.value_type != ValueType::Text { + return None; + } + unsafe { Some(&*(value.value as *const TextValue)) } } - pub fn null() -> Self { - Self { - text: std::ptr::null(), - len: 0, + /// # Safety + /// The caller must ensure that the text is a valid UTF-8 string + pub unsafe fn as_str(&self) -> &str { + if self.text.is_null() { + return ""; + } + unsafe { + std::str::from_utf8_unchecked(std::slice::from_raw_parts(self.text, self.len as usize)) } } } @@ -162,18 +162,19 @@ impl TextValue { #[repr(C)] pub struct Blob { pub data: *const u8, - pub size: usize, + pub size: u64, } impl Blob { - pub fn new(data: *const u8, size: usize) -> Self { + pub fn new(data: *const u8, size: u64) -> Self { Self { data, size } } - pub fn null() -> Self { - Self { - data: std::ptr::null(), - size: 0, + + pub fn from_value(value: &Value) -> Option<&Self> { + if value.value_type != ValueType::Blob { + return None; } + unsafe { Some(&*(value.value as *const Blob)) } } } @@ -181,53 +182,65 @@ impl Value { pub fn null() -> Self { Self { value_type: ValueType::Null, - integer: 0, - float: 0.0, - text: TextValue::null(), - blob: Blob::null(), + value: std::ptr::null_mut(), } } pub fn from_integer(value: i64) -> Self { + let boxed = Box::new(value); Self { value_type: ValueType::Integer, - integer: value, - float: 0.0, - text: TextValue::null(), - blob: Blob::null(), + value: Box::into_raw(boxed) as *mut c_void, } } + pub fn from_float(value: f64) -> Self { + let boxed = Box::new(value); Self { value_type: ValueType::Float, - integer: 0, - float: value, - text: TextValue::null(), - blob: Blob::null(), + value: Box::into_raw(boxed) as *mut c_void, } } - pub fn from_text(value: String) -> Self { - let cstr = CString::new(&*value).unwrap(); - let ptr = cstr.as_ptr(); - let len = value.len(); - std::mem::forget(cstr); + pub fn from_text(s: String) -> Self { + let text_value = TextValue::new(s.as_ptr(), s.len()); + let boxed_text = Box::new(text_value); + std::mem::forget(s); Self { value_type: ValueType::Text, - integer: 0, - float: 0.0, - text: TextValue::new(ptr, len), - blob: Blob::null(), + value: Box::into_raw(boxed_text) as *mut c_void, } } - pub fn from_blob(value: &[u8]) -> Self { + pub fn from_blob(value: Vec) -> Self { + let boxed = Box::new(Blob::new(value.as_ptr(), value.len() as u64)); + std::mem::forget(value); Self { value_type: ValueType::Blob, - integer: 0, - float: 0.0, - text: TextValue::null(), - blob: Blob::new(value.as_ptr(), value.len()), + value: Box::into_raw(boxed) as *mut c_void, } } + + pub unsafe fn free(&mut self) { + if self.value.is_null() { + return; + } + match self.value_type { + ValueType::Integer => { + let _ = Box::from_raw(self.value as *mut i64); + } + ValueType::Float => { + let _ = Box::from_raw(self.value as *mut f64); + } + ValueType::Text => { + let _ = Box::from_raw(self.value as *mut TextValue); + } + ValueType::Blob => { + let _ = Box::from_raw(self.value as *mut Blob); + } + ValueType::Null => {} + } + + self.value = std::ptr::null_mut(); + } }