diff --git a/core/ext/mod.rs b/core/ext/mod.rs index 179b3e7c6..f1758324b 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -1,8 +1,6 @@ use crate::{function::ExternalFunc, Database}; -pub use limbo_extension::{ - Blob as ExtBlob, TextValue as ExtTextValue, Value as ExtValue, ValueType as ExtValueType, -}; use limbo_extension::{ExtensionApi, ResultCode, ScalarFunction, RESULT_ERROR, RESULT_OK}; +pub use limbo_extension::{Value as ExtValue, ValueType as ExtValueType}; use std::{ ffi::{c_char, c_void, CStr}, rc::Rc, @@ -18,6 +16,9 @@ extern "C" fn register_scalar_function( Ok(s) => s.to_string(), Err(_) => return RESULT_ERROR, }; + if ctx.is_null() { + return RESULT_ERROR; + } let db = unsafe { &*(ctx as *const Database) }; db.register_scalar_function_impl(name_str, func) } diff --git a/core/lib.rs b/core/lib.rs index 96cba7170..edc5cf3a4 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -123,7 +123,7 @@ impl Database { let header = db_header; let schema = Rc::new(RefCell::new(Schema::new())); let syms = Rc::new(RefCell::new(SymbolTable::new())); - let mut db = Database { + let db = Database { pager: pager.clone(), schema: schema.clone(), header: header.clone(), @@ -190,7 +190,9 @@ impl Database { self.syms.borrow_mut().extensions.push((lib, api_ptr)); Ok(()) } else { - let _ = unsafe { Box::from_raw(api_ptr.cast_mut()) }; // own this again so we dont leak + if !api_ptr.is_null() { + let _ = unsafe { Box::from_raw(api_ptr.cast_mut()) }; + } Err(LimboError::ExtensionError( "Extension registration failed".to_string(), )) diff --git a/core/types.rs b/core/types.rs index f95762b8c..75f28a96e 100644 --- a/core/types.rs +++ b/core/types.rs @@ -1,5 +1,5 @@ use crate::error::LimboError; -use crate::ext::{ExtBlob, ExtTextValue, ExtValue, ExtValueType}; +use crate::ext::{ExtValue, ExtValueType}; use crate::storage::sqlite3_ondisk::write_varint; use crate::Result; use std::fmt::Display; @@ -123,16 +123,17 @@ impl OwnedValue { OwnedValue::Float(float) } ExtValueType::Text => { - let Some(text) = (unsafe { ExtTextValue::from_value(v) }) else { + let Some(text) = v.to_text() else { return OwnedValue::Null; }; OwnedValue::build_text(std::rc::Rc::new(unsafe { text.as_str().to_string() })) } ExtValueType::Blob => { - let blob_ptr = v.value as *mut ExtBlob; + let Some(blob_ptr) = v.to_blob() else { + return OwnedValue::Null; + }; let blob = unsafe { - let slice = - std::slice::from_raw_parts((*blob_ptr).data, (*blob_ptr).size as usize); + let slice = std::slice::from_raw_parts(blob_ptr.data, blob_ptr.size as usize); slice.to_vec() }; OwnedValue::Blob(std::rc::Rc::new(blob)) diff --git a/extensions/uuid/src/lib.rs b/extensions/uuid/src/lib.rs index df8581955..d88f2f887 100644 --- a/extensions/uuid/src/lib.rs +++ b/extensions/uuid/src/lib.rs @@ -1,6 +1,5 @@ use limbo_extension::{ - declare_scalar_functions, register_extension, register_scalar_functions, Blob, TextValue, - Value, ValueType, + declare_scalar_functions, register_extension, register_scalar_functions, Value, ValueType, }; register_extension! { @@ -11,7 +10,7 @@ register_extension! { "uuid7" => uuid7_blob, "uuid_str" => uuid_str, "uuid_blob" => uuid_blob, - "exec_ts_from_uuid7" => exec_ts_from_uuid7, + "uuid7_timestamp_ms" => exec_ts_from_uuid7, }, } @@ -32,14 +31,35 @@ declare_scalar_functions! { #[args(0..=1)] fn uuid7_str(args: &[Value]) -> Value { let timestamp = if args.is_empty() { - let ctx = uuid::ContextV7::new(); + let ctx = uuid::ContextV7::new(); uuid::Timestamp::now(ctx) - } else if args[0].value_type == limbo_extension::ValueType::Integer { + } else { + let arg = &args[0]; + match arg.value_type { + ValueType::Integer => { let ctx = uuid::ContextV7::new(); - let int = args[0].value as i64; + let Some(int) = arg.to_integer() else { + return Value::null(); + }; uuid::Timestamp::from_unix(ctx, int as u64, 0) - } else { - return Value::null(); + } + ValueType::Text => { + let Some(text) = arg.to_text() else { + return Value::null(); + }; + let parsed = unsafe{text.as_str()}.parse::(); + match parsed { + Ok(unix) => { + if unix <= 0 { + return Value::null(); + } + uuid::Timestamp::from_unix(uuid::ContextV7::new(), unix as u64, 0) + } + Err(_) => return Value::null(), + } + } + _ => return Value::null(), + } }; let uuid = uuid::Uuid::new_v7(timestamp); Value::from_text(uuid.to_string()) @@ -52,7 +72,9 @@ declare_scalar_functions! { uuid::Timestamp::now(ctx) } else if args[0].value_type == limbo_extension::ValueType::Integer { let ctx = uuid::ContextV7::new(); - let int = args[0].value as i64; + let Some(int) = args[0].to_integer() else { + return Value::null(); + }; uuid::Timestamp::from_unix(ctx, int as u64, 0) } else { return Value::null(); @@ -66,14 +88,16 @@ declare_scalar_functions! { fn exec_ts_from_uuid7(args: &[Value]) -> Value { match args[0].value_type { ValueType::Blob => { - let blob = Blob::from_value(&args[0]).unwrap(); + let Some(blob) = &args[0].to_blob() else { + return Value::null(); + }; let slice = unsafe{ std::slice::from_raw_parts(blob.data, blob.size as usize)}; let uuid = uuid::Uuid::from_slice(slice).unwrap(); let unix = uuid_to_unix(uuid.as_bytes()); Value::from_integer(unix as i64) } ValueType::Text => { - let Some(text) = (unsafe {TextValue::from_value(&args[0])}) else { + let Some(text) = args[0].to_text() else { return Value::null(); }; let uuid = uuid::Uuid::parse_str(unsafe {text.as_str()}).unwrap(); @@ -86,29 +110,20 @@ declare_scalar_functions! { #[args(1)] fn uuid_str(args: &[Value]) -> Value { - if args[0].value_type != limbo_extension::ValueType::Blob { - log::debug!("uuid_str was passed a non-blob arg"); + let Some(blob) = args[0].to_blob() else { return Value::null(); - } - if let Some(blob) = Blob::from_value(&args[0]) { + }; let slice = unsafe{ std::slice::from_raw_parts(blob.data, blob.size as usize)}; let parsed = uuid::Uuid::from_slice(slice).ok().map(|u| u.to_string()); match parsed { Some(s) => Value::from_text(s), None => Value::null() } - } else { - Value::null() - } } #[args(1)] fn uuid_blob(args: &[Value]) -> Value { - if args[0].value_type != limbo_extension::ValueType::Text { - log::debug!("uuid_blob was passed a non-text arg"); - return Value::null(); - } - let Some(text) = (unsafe { TextValue::from_value(&args[0])}) else { + let Some(text) = args[0].to_text() else { return Value::null(); }; match uuid::Uuid::parse_str(unsafe {text.as_str()}) { diff --git a/limbo_extension/src/lib.rs b/limbo_extension/src/lib.rs index f8eb19dcf..d07cc8ea7 100644 --- a/limbo_extension/src/lib.rs +++ b/limbo_extension/src/lib.rs @@ -208,13 +208,6 @@ impl Blob { pub fn new(data: *const u8, size: u64) -> Self { Self { data, size } } - - pub fn from_value(value: &Value) -> Option<&Self> { - if value.value_type != ValueType::Blob { - return None; - } - unsafe { Some(&*(value.value as *const Blob)) } - } } impl Value { @@ -225,6 +218,46 @@ impl Value { } } + pub fn to_float(&self) -> Option { + if self.value_type != ValueType::Float { + return None; + } + if self.value.is_null() { + return None; + } + Some(unsafe { *(self.value as *const f64) }) + } + + pub fn to_text(&self) -> Option<&TextValue> { + if self.value_type != ValueType::Text { + return None; + } + if self.value.is_null() { + return None; + } + unsafe { Some(&*(self.value as *const TextValue)) } + } + + pub fn to_blob(&self) -> Option<&Blob> { + if self.value_type != ValueType::Blob { + return None; + } + if self.value.is_null() { + return None; + } + unsafe { Some(&*(self.value as *const Blob)) } + } + + pub fn to_integer(&self) -> Option { + if self.value_type != ValueType::Integer { + return None; + } + if self.value.is_null() { + return None; + } + Some(unsafe { *(self.value as *const i64) }) + } + pub fn from_integer(value: i64) -> Self { let boxed = Box::new(value); Self {