diff --git a/Cargo.lock b/Cargo.lock index ae544266..dadbd7b9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -564,7 +564,7 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef552e6f588e446098f6ba40d89ac146c8c7b64aade83c051ee00bb5d2bc18d" dependencies = [ - "uuid 1.11.0", + "uuid", ] [[package]] @@ -694,10 +694,6 @@ dependencies = [ "str-buf", ] -[[package]] -name = "extension_api" -version = "0.0.11" - [[package]] name = "fallible-iterator" version = "0.2.0" @@ -1218,7 +1214,6 @@ dependencies = [ "cfg_block", "chrono", "criterion", - "extension_api", "fallible-iterator 0.3.0", "getrandom", "hex", @@ -1228,6 +1223,7 @@ dependencies = [ "julian_day_converter", "libc", "libloading", + "limbo_extension", "limbo_macros", "log", "miette", @@ -1249,9 +1245,13 @@ dependencies = [ "sqlite3-parser", "tempfile", "thiserror 1.0.69", - "uuid 1.11.0", + "uuid", ] +[[package]] +name = "limbo_extension" +version = "0.0.11" + [[package]] name = "limbo_macros" version = "0.0.11" @@ -1280,6 +1280,14 @@ dependencies = [ "log", ] +[[package]] +name = "limbo_uuid" +version = "0.0.11" +dependencies = [ + "limbo_extension", + "uuid", +] + [[package]] name = "linux-raw-sys" version = "0.4.14" @@ -2238,7 +2246,7 @@ dependencies = [ "debugid", "memmap2", "stable_deref_trait", - "uuid 1.11.0", + "uuid", ] [[package]] @@ -2451,14 +2459,6 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" -[[package]] -name = "uuid" -version = "0.0.11" -dependencies = [ - "extension_api", - "uuid 1.11.0", -] - [[package]] name = "uuid" version = "1.11.0" diff --git a/Cargo.toml b/Cargo.toml index a6e1af7d..d9e67193 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ members = [ "sqlite3", "core", "simulator", - "test", "macros", "extension_api", "extensions/uuid", + "test", "macros", "limbo_extension", "extensions/uuid", ] exclude = ["perf/latency/limbo"] diff --git a/cli/app.rs b/cli/app.rs index de2c642a..f3c5eb9e 100644 --- a/cli/app.rs +++ b/cli/app.rs @@ -270,14 +270,8 @@ impl Limbo { }; } - fn handle_load_extension(&mut self) -> Result<(), String> { - let mut args = self.input_buff.split_whitespace(); - let _ = args.next(); - let lib = args - .next() - .ok_or("No library specified") - .map_err(|e| e.to_string())?; - self.conn.load_extension(lib).map_err(|e| e.to_string()) + fn handle_load_extension(&mut self, path: &str) -> Result<(), String> { + self.conn.load_extension(path).map_err(|e| e.to_string()) } fn display_in_memory(&mut self) -> std::io::Result<()> { @@ -511,8 +505,8 @@ impl Limbo { }; } Command::LoadExtension => { - if let Err(e) = self.handle_load_extension() { - let _ = self.writeln(e.to_string()); + if let Err(e) = self.handle_load_extension(args[1]) { + let _ = self.writeln(&e); } } } diff --git a/core/Cargo.toml b/core/Cargo.toml index a905cd13..b4a58435 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -35,7 +35,7 @@ rustix = "0.38.34" mimalloc = { version = "*", default-features = false } [dependencies] -extension_api = { path = "../extension_api" } +limbo_extension = { path = "../limbo_extension" } cfg_block = "0.1.1" fallible-iterator = "0.3.0" hex = "0.4.3" diff --git a/core/ext/mod.rs b/core/ext/mod.rs index c1718dcc..79936079 100644 --- a/core/ext/mod.rs +++ b/core/ext/mod.rs @@ -1,68 +1,41 @@ -#[cfg(feature = "uuid")] -mod uuid; use crate::{function::ExternalFunc, Database}; -use std::sync::Arc; - -use extension_api::{AggregateFunction, ExtensionApi, Result, ScalarFunction, VirtualTable}; -#[cfg(feature = "uuid")] -pub use uuid::{exec_ts_from_uuid7, exec_uuid, exec_uuidblob, exec_uuidstr, UuidFunc}; - -impl ExtensionApi for Database { - fn register_scalar_function( - &self, - name: &str, - func: Arc, - ) -> extension_api::Result<()> { - let ext_func = ExternalFunc::new(name, func.clone()); - self.syms - .borrow_mut() - .functions - .insert(name.to_string(), Arc::new(ext_func)); - Ok(()) - } - - fn register_aggregate_function( - &self, - _name: &str, - _func: Arc, - ) -> Result<()> { - todo!("implement aggregate function registration"); - } - - fn register_virtual_table(&self, _name: &str, _table: Arc) -> Result<()> { - todo!("implement virtual table registration"); - } -} - -#[derive(Debug, Clone, PartialEq)] -pub enum ExtFunc { - #[cfg(feature = "uuid")] - Uuid(UuidFunc), +use limbo_extension::{ExtensionApi, ResultCode, ScalarFunction, RESULT_ERROR, RESULT_OK}; +pub use limbo_extension::{Value as ExtValue, ValueType as ExtValueType}; +use std::{ + ffi::{c_char, c_void, CStr}, + rc::Rc, +}; + +extern "C" fn register_scalar_function( + ctx: *mut c_void, + name: *const c_char, + func: ScalarFunction, +) -> ResultCode { + let c_str = unsafe { CStr::from_ptr(name) }; + let name_str = match c_str.to_str() { + Ok(s) => s.to_string(), + Err(_) => return RESULT_ERROR, + }; + let db = unsafe { &*(ctx as *const Database) }; + db.register_scalar_function_impl(name_str, func) } -#[allow(unreachable_patterns)] // TODO: remove when more extension funcs added -impl std::fmt::Display for ExtFunc { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - #[cfg(feature = "uuid")] - Self::Uuid(uuidfn) => write!(f, "{}", uuidfn), - _ => write!(f, "unknown"), - } +impl Database { + fn register_scalar_function_impl(&self, name: String, func: ScalarFunction) -> ResultCode { + self.syms.borrow_mut().functions.insert( + name.to_string(), + Rc::new(ExternalFunc { + name: name.to_string(), + func, + }), + ); + RESULT_OK } -} -#[allow(unreachable_patterns)] -impl ExtFunc { - pub fn resolve_function(name: &str, num_args: usize) -> Option { - match name { - #[cfg(feature = "uuid")] - name => UuidFunc::resolve_function(name, num_args), - _ => None, + pub fn build_limbo_extension(&self) -> ExtensionApi { + ExtensionApi { + ctx: self as *const _ as *mut c_void, + register_scalar_function, } } } - -//pub fn init(db: &mut crate::Database) { -// #[cfg(feature = "uuid")] -// uuid::init(db); -//} diff --git a/core/function.rs b/core/function.rs index 70bc4292..ac4d2c27 100644 --- a/core/function.rs +++ b/core/function.rs @@ -1,15 +1,16 @@ -use crate::ext::ExtFunc; use std::fmt; use std::fmt::{Debug, Display}; -use std::sync::Arc; +use std::rc::Rc; + +use limbo_extension::ScalarFunction; pub struct ExternalFunc { pub name: String, - pub func: Arc, + pub func: ScalarFunction, } impl ExternalFunc { - pub fn new(name: &str, func: Arc) -> Self { + pub fn new(name: &str, func: ScalarFunction) -> Self { Self { name: name.to_string(), func, @@ -306,8 +307,7 @@ pub enum Func { Math(MathFunc), #[cfg(feature = "json")] Json(JsonFunc), - Extension(ExtFunc), - External(Arc), + External(Rc), } impl Display for Func { @@ -318,7 +318,6 @@ impl Display for Func { Self::Math(math_func) => write!(f, "{}", math_func), #[cfg(feature = "json")] Self::Json(json_func) => write!(f, "{}", json_func), - Self::Extension(ext_func) => write!(f, "{}", ext_func), Self::External(generic_func) => write!(f, "{}", generic_func), } } @@ -423,10 +422,7 @@ impl Func { "tan" => Ok(Self::Math(MathFunc::Tan)), "tanh" => Ok(Self::Math(MathFunc::Tanh)), "trunc" => Ok(Self::Math(MathFunc::Trunc)), - _ => match ExtFunc::resolve_function(name, arg_count) { - Some(ext_func) => Ok(Self::Extension(ext_func)), - None => Err(()), - }, + _ => Err(()), } } } diff --git a/core/lib.rs b/core/lib.rs index aa6d19c5..96cba717 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -17,9 +17,9 @@ mod vdbe; #[global_allocator] static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; -use extension_api::{Extension, ExtensionApi}; use fallible_iterator::FallibleIterator; use libloading::{Library, Symbol}; +use limbo_extension::{ExtensionApi, ExtensionEntryPoint, RESULT_OK}; use log::trace; use schema::Schema; use sqlite3_parser::ast; @@ -131,7 +131,6 @@ impl Database { _shared_wal: shared_wal.clone(), syms, }; - // ext::init(&mut db); let db = Arc::new(db); let conn = Rc::new(Connection { db: db.clone(), @@ -165,31 +164,37 @@ impl Database { pub fn define_scalar_function>( &self, name: S, - func: Arc, + func: limbo_extension::ScalarFunction, ) { let func = function::ExternalFunc { name: name.as_ref().to_string(), - func: func.clone(), + func, }; self.syms .borrow_mut() .functions - .insert(name.as_ref().to_string(), Arc::new(func)); + .insert(name.as_ref().to_string(), func.into()); } pub fn load_extension(&self, path: &str) -> Result<()> { + let api = Box::new(self.build_limbo_extension()); let lib = unsafe { Library::new(path).map_err(|e| LimboError::ExtensionError(e.to_string()))? }; - unsafe { - let register: Symbol Box> = - lib.get(b"register_extension") - .map_err(|e| LimboError::ExtensionError(e.to_string()))?; - let extension = register(self); - extension - .load() - .map_err(|e| LimboError::ExtensionError(e.to_string()))?; + let entry: Symbol = unsafe { + lib.get(b"register_extension") + .map_err(|e| LimboError::ExtensionError(e.to_string()))? + }; + let api_ptr: *const ExtensionApi = Box::into_raw(api); + let result_code = entry(api_ptr); + if result_code == RESULT_OK { + self.syms.borrow_mut().extensions.push((lib, api_ptr)); + Ok(()) + } else { + let _ = unsafe { Box::from_raw(api_ptr.cast_mut()) }; // own this again so we dont leak + Err(LimboError::ExtensionError( + "Extension registration failed".to_string(), + )) } - Ok(()) } } @@ -318,7 +323,11 @@ impl Connection { Cmd::ExplainQueryPlan(stmt) => { match stmt { ast::Stmt::Select(select) => { - let mut plan = prepare_select_plan(&self.schema.borrow(), *select)?; + let mut plan = prepare_select_plan( + &self.schema.borrow(), + *select, + &self.db.syms.borrow(), + )?; optimize_plan(&mut plan)?; println!("{}", plan); } @@ -484,8 +493,8 @@ impl Rows { } pub(crate) struct SymbolTable { - pub functions: HashMap>, - extensions: Vec>, + pub functions: HashMap>, + extensions: Vec<(libloading::Library, *const ExtensionApi)>, } impl std::fmt::Debug for SymbolTable { @@ -508,7 +517,7 @@ impl SymbolTable { &self, name: &str, _arg_count: usize, - ) -> Option> { + ) -> Option> { self.functions.get(name).cloned() } } diff --git a/core/translate/emitter.rs b/core/translate/emitter.rs index 8fdff8cd..8819b220 100644 --- a/core/translate/emitter.rs +++ b/core/translate/emitter.rs @@ -38,14 +38,13 @@ impl<'a> Resolver<'a> { } pub fn resolve_function(&self, func_name: &str, arg_count: usize) -> Option { - let func_type = match Func::resolve_function(&func_name, arg_count).ok() { + match Func::resolve_function(func_name, arg_count).ok() { Some(func) => Some(func), None => self .symbol_table - .resolve_function(&func_name, arg_count) - .map(|func| Func::External(func)), - }; - func_type + .resolve_function(func_name, arg_count) + .map(|arg| Func::External(arg.clone())), + } } pub fn resolve_cached_expr_reg(&self, expr: &ast::Expr) -> Option { diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 2f88c9a0..c1e9ecc1 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -1,7 +1,5 @@ use sqlite3_parser::ast::{self, UnaryOperator}; -#[cfg(feature = "uuid")] -use crate::ext::{ExtFunc, UuidFunc}; #[cfg(feature = "json")] use crate::function::JsonFunc; use crate::function::{Func, FuncCtx, MathFuncArity, ScalarFunc}; @@ -1403,60 +1401,6 @@ pub fn translate_expr( } } } - Func::Extension(ext_func) => match ext_func { - #[cfg(feature = "uuid")] - ExtFunc::Uuid(ref uuid_fn) => match uuid_fn { - UuidFunc::UuidStr | UuidFunc::UuidBlob | UuidFunc::Uuid7TS => { - let args = expect_arguments_exact!(args, 1, ext_func); - let regs = program.alloc_register(); - translate_expr(program, referenced_tables, &args[0], regs, resolver)?; - program.emit_insn(Insn::Function { - constant_mask: 0, - start_reg: regs, - dest: target_register, - func: func_ctx, - }); - Ok(target_register) - } - UuidFunc::Uuid4Str => { - if args.is_some() { - crate::bail_parse_error!( - "{} function with arguments", - ext_func.to_string() - ); - } - let regs = program.alloc_register(); - program.emit_insn(Insn::Function { - constant_mask: 0, - start_reg: regs, - dest: target_register, - func: func_ctx, - }); - Ok(target_register) - } - UuidFunc::Uuid7 => { - let args = expect_arguments_max!(args, 1, ext_func); - let mut start_reg = None; - if let Some(arg) = args.first() { - start_reg = Some(translate_and_mark( - program, - referenced_tables, - arg, - resolver, - )?); - } - program.emit_insn(Insn::Function { - constant_mask: 0, - start_reg: start_reg.unwrap_or(target_register), - dest: target_register, - func: func_ctx, - }); - Ok(target_register) - } - }, - #[allow(unreachable_patterns)] - _ => unreachable!("{ext_func} not implemented yet"), - }, Func::Math(math_func) => match math_func.arity() { MathFuncArity::Nullary => { if args.is_some() { diff --git a/core/translate/planner.rs b/core/translate/planner.rs index 64a0ffb0..f5c835f8 100644 --- a/core/translate/planner.rs +++ b/core/translate/planner.rs @@ -1,6 +1,7 @@ use super::{ plan::{Aggregate, Plan, SelectQueryType, SourceOperator, TableReference, TableReferenceType}, select::prepare_select_plan, + SymbolTable, }; use crate::{ function::Func, @@ -259,6 +260,7 @@ fn parse_from_clause_table( table: ast::SelectTable, operator_id_counter: &mut OperatorIdCounter, cur_table_index: usize, + syms: &SymbolTable, ) -> Result<(TableReference, SourceOperator)> { match table { ast::SelectTable::Table(qualified_name, maybe_alias, _) => { @@ -289,7 +291,7 @@ fn parse_from_clause_table( )) } ast::SelectTable::Select(subselect, maybe_alias) => { - let Plan::Select(mut subplan) = prepare_select_plan(schema, *subselect)? else { + let Plan::Select(mut subplan) = prepare_select_plan(schema, *subselect, syms)? else { unreachable!(); }; subplan.query_type = SelectQueryType::Subquery { @@ -322,6 +324,7 @@ pub fn parse_from( schema: &Schema, mut from: Option, operator_id_counter: &mut OperatorIdCounter, + syms: &SymbolTable, ) -> Result<(SourceOperator, Vec)> { if from.as_ref().and_then(|f| f.select.as_ref()).is_none() { return Ok(( @@ -339,7 +342,7 @@ pub fn parse_from( let select_owned = *std::mem::take(&mut from_owned.select).unwrap(); let joins_owned = std::mem::take(&mut from_owned.joins).unwrap_or_default(); let (table_reference, mut operator) = - parse_from_clause_table(schema, select_owned, operator_id_counter, table_index)?; + parse_from_clause_table(schema, select_owned, operator_id_counter, table_index, syms)?; tables.push(table_reference); table_index += 1; @@ -350,7 +353,14 @@ pub fn parse_from( is_outer_join: outer, using, predicates, - } = parse_join(schema, join, operator_id_counter, &mut tables, table_index)?; + } = parse_join( + schema, + join, + operator_id_counter, + &mut tables, + table_index, + syms, + )?; operator = SourceOperator::Join { left: Box::new(operator), right: Box::new(right), @@ -394,6 +404,7 @@ fn parse_join( operator_id_counter: &mut OperatorIdCounter, tables: &mut Vec, table_index: usize, + syms: &SymbolTable, ) -> Result { let ast::JoinedSelectTable { operator: join_operator, @@ -402,7 +413,7 @@ fn parse_join( } = join; let (table_reference, source_operator) = - parse_from_clause_table(schema, table, operator_id_counter, table_index)?; + parse_from_clause_table(schema, table, operator_id_counter, table_index, syms)?; tables.push(table_reference); diff --git a/core/translate/select.rs b/core/translate/select.rs index 44dcb528..a19987d7 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -25,12 +25,16 @@ pub fn translate_select( connection: Weak, syms: &SymbolTable, ) -> Result { - let mut select_plan = prepare_select_plan(schema, select)?; + let mut select_plan = prepare_select_plan(schema, select, syms)?; optimize_plan(&mut select_plan)?; emit_program(database_header, select_plan, connection, syms) } -pub fn prepare_select_plan(schema: &Schema, select: ast::Select) -> Result { +pub fn prepare_select_plan( + schema: &Schema, + select: ast::Select, + syms: &SymbolTable, +) -> Result { match *select.body.select { ast::OneSelect::Select { mut columns, @@ -47,7 +51,8 @@ pub fn prepare_select_plan(schema: &Schema, select: ast::Select) -> Result let mut operator_id_counter = OperatorIdCounter::new(); // Parse the FROM clause - let (source, referenced_tables) = parse_from(schema, from, &mut operator_id_counter)?; + let (source, referenced_tables) = + parse_from(schema, from, &mut operator_id_counter, syms)?; let mut plan = SelectPlan { source, @@ -147,7 +152,25 @@ pub fn prepare_select_plan(schema: &Schema, select: ast::Select) -> Result contains_aggregates, }); } - _ => {} + Err(_) => { + if syms.functions.contains_key(&name.0) { + // TODO: future extensions can be aggregate functions + log::debug!( + "Resolving {} function from symbol table", + name.0 + ); + plan.result_columns.push(ResultSetColumn { + name: get_name( + maybe_alias.as_ref(), + expr, + &plan.referenced_tables, + || format!("expr_{}", result_column_idx), + ), + expr: expr.clone(), + contains_aggregates: false, + }); + } + } } } ast::Expr::FunctionCallStar { diff --git a/core/types.rs b/core/types.rs index 0b2f691c..346297d4 100644 --- a/core/types.rs +++ b/core/types.rs @@ -1,11 +1,10 @@ use crate::error::LimboError; +use crate::ext::{ExtValue, ExtValueType}; +use crate::storage::sqlite3_ondisk::write_varint; use crate::Result; -use extension_api::Value as ExtValue; use std::fmt::Display; use std::{cell::Ref, rc::Rc}; -use crate::storage::sqlite3_ondisk::write_varint; - #[derive(Debug, Clone, PartialEq)] pub enum Value<'a> { Null, @@ -15,45 +14,6 @@ pub enum Value<'a> { Blob(&'a Vec), } -impl From<&OwnedValue> for extension_api::Value { - fn from(value: &OwnedValue) -> Self { - match value { - OwnedValue::Null => extension_api::Value::Null, - OwnedValue::Integer(i) => extension_api::Value::Integer(*i), - OwnedValue::Float(f) => extension_api::Value::Float(*f), - OwnedValue::Text(text) => extension_api::Value::Text(text.value.to_string()), - OwnedValue::Blob(blob) => extension_api::Value::Blob(blob.to_vec()), - OwnedValue::Agg(_) => { - panic!("Cannot convert Aggregate context to extension_api::Value") - } // Handle appropriately - OwnedValue::Record(_) => panic!("Cannot convert Record to extension_api::Value"), // Handle appropriately - } - } -} -impl From for OwnedValue { - fn from(value: ExtValue) -> Self { - match value { - ExtValue::Null => OwnedValue::Null, - ExtValue::Integer(i) => OwnedValue::Integer(i), - ExtValue::Float(f) => OwnedValue::Float(f), - ExtValue::Text(text) => OwnedValue::Text(LimboText::new(Rc::new(text.to_string()))), - ExtValue::Blob(blob) => OwnedValue::Blob(Rc::new(blob.to_vec())), - } - } -} - -impl<'a> From<&'a crate::Value<'a>> for ExtValue { - fn from(value: &'a crate::Value<'a>) -> Self { - match value { - crate::Value::Null => extension_api::Value::Null, - crate::Value::Integer(i) => extension_api::Value::Integer(*i), - crate::Value::Float(f) => extension_api::Value::Float(*f), - crate::Value::Text(t) => extension_api::Value::Text(t.to_string()), - crate::Value::Blob(b) => extension_api::Value::Blob(b.to_vec()), - } - } -} - impl Display for Value<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -132,6 +92,41 @@ impl Display for OwnedValue { } } } +impl OwnedValue { + pub fn to_ffi(&self) -> ExtValue { + match self { + Self::Null => ExtValue::null(), + Self::Integer(i) => ExtValue::from_integer(*i), + Self::Float(fl) => ExtValue::from_float(*fl), + Self::Text(s) => ExtValue::from_text(s.value.to_string()), + Self::Blob(b) => ExtValue::from_blob(b), + Self::Agg(_) => todo!(), + Self::Record(_) => todo!(), + } + } + pub fn from_ffi(v: &ExtValue) -> Self { + match v.value_type { + ExtValueType::Null => OwnedValue::Null, + ExtValueType::Integer => OwnedValue::Integer(v.integer), + ExtValueType::Float => OwnedValue::Float(v.float), + ExtValueType::Text => { + if v.text.is_null() { + OwnedValue::Null + } else { + OwnedValue::build_text(std::rc::Rc::new(v.text.to_string())) + } + } + ExtValueType::Blob => { + if v.blob.data.is_null() { + OwnedValue::Null + } else { + let bytes = unsafe { std::slice::from_raw_parts(v.blob.data, v.blob.size) }; + OwnedValue::Blob(std::rc::Rc::new(bytes.to_vec())) + } + } + } + } +} #[derive(Debug, Clone, PartialEq)] pub enum AggContext { diff --git a/core/vdbe/mod.rs b/core/vdbe/mod.rs index 357d4ed5..e20991ac 100644 --- a/core/vdbe/mod.rs +++ b/core/vdbe/mod.rs @@ -25,8 +25,7 @@ pub mod likeop; pub mod sorter; use crate::error::{LimboError, SQLITE_CONSTRAINT_PRIMARYKEY}; -#[cfg(feature = "uuid")] -use crate::ext::{exec_ts_from_uuid7, exec_uuid, exec_uuidblob, exec_uuidstr, ExtFunc, UuidFunc}; +use crate::ext::ExtValue; use crate::function::{AggFunc, FuncCtx, MathFunc, MathFuncArity, ScalarFunc}; use crate::pseudo::PseudoCursor; use crate::result::LimboResult; @@ -53,9 +52,10 @@ use likeop::{construct_like_escape_arg, exec_glob, exec_like_with_escape}; use rand::distributions::{Distribution, Uniform}; use rand::{thread_rng, Rng}; use regex::{Regex, RegexBuilder}; -use std::borrow::{Borrow, BorrowMut}; +use std::borrow::BorrowMut; use std::cell::RefCell; use std::collections::{BTreeMap, HashMap}; +use std::os::raw::c_void; use std::rc::{Rc, Weak}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -1720,47 +1720,20 @@ impl Program { state.registers[*dest] = exec_replace(source, pattern, replacement); } }, - #[allow(unreachable_patterns)] - crate::function::Func::Extension(extfn) => match extfn { - #[cfg(feature = "uuid")] - ExtFunc::Uuid(uuidfn) => match uuidfn { - UuidFunc::Uuid4Str => { - state.registers[*dest] = exec_uuid(uuidfn, None)? - } - UuidFunc::Uuid7 => match arg_count { - 0 => { - state.registers[*dest] = - exec_uuid(uuidfn, None).unwrap_or(OwnedValue::Null); - } - 1 => { - let reg_value = state.registers[*start_reg].borrow(); - state.registers[*dest] = exec_uuid(uuidfn, Some(reg_value)) - .unwrap_or(OwnedValue::Null); - } - _ => unreachable!(), - }, - _ => { - // remaining accept 1 arg - let reg_value = state.registers[*start_reg].borrow(); - state.registers[*dest] = match uuidfn { - UuidFunc::Uuid7TS => Some(exec_ts_from_uuid7(reg_value)), - UuidFunc::UuidStr => exec_uuidstr(reg_value).ok(), - UuidFunc::UuidBlob => exec_uuidblob(reg_value).ok(), - _ => unreachable!(), - } - .unwrap_or(OwnedValue::Null); - } - }, - _ => unreachable!(), // when more extension types are added - }, crate::function::Func::External(f) => { let values = &state.registers[*start_reg..*start_reg + arg_count]; - let args: Vec<_> = values.into_iter().map(|v| v.into()).collect(); - let result = f - .func - .execute(args.as_slice()) - .map_err(|e| LimboError::ExtensionError(e.to_string()))?; - state.registers[*dest] = result.into(); + let c_values: Vec<*const c_void> = values + .iter() + .map(|ov| &ov.to_ffi() as *const _ as *const c_void) + .collect(); + let argv_ptr = if c_values.is_empty() { + std::ptr::null() + } else { + c_values.as_ptr() + }; + let result_c_value: ExtValue = (f.func)(arg_count as i32, argv_ptr); + let result_ov = OwnedValue::from_ffi(&result_c_value); + state.registers[*dest] = result_ov; } crate::function::Func::Math(math_func) => match math_func.arity() { MathFuncArity::Nullary => match math_func { diff --git a/extension_api/src/lib.rs b/extension_api/src/lib.rs deleted file mode 100644 index 7ee26e32..00000000 --- a/extension_api/src/lib.rs +++ /dev/null @@ -1,75 +0,0 @@ -use std::any::Any; -use std::rc::Rc; -use std::sync::Arc; - -pub type Result = std::result::Result; - -pub trait Extension { - fn load(&self) -> Result<()>; -} - -#[derive(Debug)] -pub enum LimboApiError { - ConnectionError(String), - RegisterFunctionError(String), - ValueError(String), - VTableError(String), -} - -impl From for LimboApiError { - fn from(e: std::io::Error) -> Self { - Self::ConnectionError(e.to_string()) - } -} - -impl std::fmt::Display for LimboApiError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - Self::ConnectionError(e) => write!(f, "Connection error: {e}"), - Self::RegisterFunctionError(e) => write!(f, "Register function error: {e}"), - Self::ValueError(e) => write!(f, "Value error: {e}"), - Self::VTableError(e) => write!(f, "VTable error: {e}"), - } - } -} - -pub trait ExtensionApi { - fn register_scalar_function(&self, name: &str, func: Arc) -> Result<()>; - fn register_aggregate_function( - &self, - name: &str, - func: Arc, - ) -> Result<()>; - fn register_virtual_table(&self, name: &str, table: Arc) -> Result<()>; -} - -pub trait ScalarFunction { - fn execute(&self, args: &[Value]) -> Result; -} - -pub trait AggregateFunction { - fn init(&self) -> Box; - fn step(&self, state: &mut dyn Any, args: &[Value]) -> Result<()>; - fn finalize(&self, state: Box) -> Result; -} - -pub trait VirtualTable { - fn schema(&self) -> &'static str; - fn create_cursor(&self) -> Box; -} - -pub trait Cursor { - fn next(&mut self) -> Result>; -} - -pub struct Row { - pub values: Vec, -} - -pub enum Value { - Text(String), - Blob(Vec), - Integer(i64), - Float(f64), - Null, -} diff --git a/extensions/uuid/Cargo.toml b/extensions/uuid/Cargo.toml index c6ae90bd..1b9eb10a 100644 --- a/extensions/uuid/Cargo.toml +++ b/extensions/uuid/Cargo.toml @@ -1,11 +1,15 @@ [package] -name = "uuid" +name = "limbo_uuid" version.workspace = true authors.workspace = true edition.workspace = true license.workspace = true repository.workspace = true +[lib] +crate-type = ["cdylib"] + + [dependencies] -extension_api = { path = "../../extension_api"} +limbo_extension = { path = "../../limbo_extension"} uuid = { version = "1.11.0", features = ["v4", "v7"] } diff --git a/extensions/uuid/src/lib.rs b/extensions/uuid/src/lib.rs index e69de29b..c9950be8 100644 --- a/extensions/uuid/src/lib.rs +++ b/extensions/uuid/src/lib.rs @@ -0,0 +1,62 @@ +use limbo_extension::{ + declare_scalar_functions, register_extension, register_scalar_functions, Value, +}; + +register_extension! { + scalars: { + "uuid4_str" => uuid4_str, + "uuid4" => uuid4_blob, + "uuid_str" => uuid_str, + "uuid_blob" => uuid_blob, + }, +} + +declare_scalar_functions! { + #[args(min = 0, max = 0)] + fn uuid4_str(_args: &[Value]) -> Value { + let uuid = uuid::Uuid::new_v4().to_string(); + Value::from_text(uuid) + } + #[args(min = 0, max = 0)] + fn uuid4_blob(_args: &[Value]) -> Value { + let uuid = uuid::Uuid::new_v4(); + let bytes = uuid.as_bytes(); + Value::from_blob(bytes) + } + + #[args(min = 1, max = 1)] + fn uuid_str(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::null(); + } + if args[0].value_type != limbo_extension::ValueType::Blob { + return Value::null(); + } + let data_ptr = args[0].blob.data; + let size = args[0].blob.size; + if data_ptr.is_null() || size != 16 { + return Value::null(); + } + let slice = unsafe{ std::slice::from_raw_parts(data_ptr, size)}; + let parsed = uuid::Uuid::from_slice(slice).ok().map(|u| u.to_string()); + match parsed { + Some(s) => Value::from_text(s), + None => Value::null() + } + } + + #[args(min = 1, max = 1)] + fn uuid_blob(args: &[Value]) -> Value { + if args.len() != 1 { + return Value::null(); + } + if args[0].value_type != limbo_extension::ValueType::Text { + return Value::null(); + } + let text = args[0].text.to_string(); + match uuid::Uuid::parse_str(&text) { + Ok(uuid) => Value::from_blob(uuid.as_bytes()), + Err(_) => Value::null() + } + } +} diff --git a/extension_api/Cargo.toml b/limbo_extension/Cargo.toml similarity index 86% rename from extension_api/Cargo.toml rename to limbo_extension/Cargo.toml index 73056af3..d3ac246d 100644 --- a/extension_api/Cargo.toml +++ b/limbo_extension/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "extension_api" +name = "limbo_extension" version.workspace = true authors.workspace = true edition.workspace = true diff --git a/limbo_extension/src/lib.rs b/limbo_extension/src/lib.rs new file mode 100644 index 00000000..6704f63e --- /dev/null +++ b/limbo_extension/src/lib.rs @@ -0,0 +1,233 @@ +use std::ffi::CString; +use std::os::raw::{c_char, c_void}; + +pub type ResultCode = i32; + +pub const RESULT_OK: ResultCode = 0; +pub const RESULT_ERROR: ResultCode = 1; +// TODO: more error types + +pub type ExtensionEntryPoint = extern "C" fn(api: *const ExtensionApi) -> ResultCode; +pub type ScalarFunction = extern "C" fn(argc: i32, *const *const c_void) -> Value; + +#[repr(C)] +pub struct ExtensionApi { + pub ctx: *mut c_void, + pub register_scalar_function: + extern "C" fn(ctx: *mut c_void, name: *const c_char, func: ScalarFunction) -> ResultCode, +} + +#[macro_export] +macro_rules! register_extension { + ( + scalars: { $( $scalar_name:expr => $scalar_func:ident ),* $(,)? }, + //aggregates: { $( $agg_name:expr => ($step_func:ident, $finalize_func:ident) ),* $(,)? }, + //virtual_tables: { $( $vt_name:expr => $vt_impl:expr ),* $(,)? } + ) => { + #[no_mangle] + pub unsafe extern "C" fn register_extension(api: *const $crate::ExtensionApi) -> $crate::ResultCode { + if api.is_null() { + return $crate::RESULT_ERROR; + } + + register_scalar_functions! { api, $( $scalar_name => $scalar_func ),* } + // TODO: + //register_aggregate_functions! { $( $agg_name => ($step_func, $finalize_func) ),* } + //register_virtual_tables! { $( $vt_name => $vt_impl ),* } + $crate::RESULT_OK + } + } +} + +#[macro_export] +macro_rules! register_scalar_functions { + ( $api:expr, $( $fname:expr => $fptr:ident ),* ) => { + unsafe { + $( + let cname = std::ffi::CString::new($fname).unwrap(); + ((*$api).register_scalar_function)((*$api).ctx, cname.as_ptr(), $fptr); + )* + } + } +} + +/// Provide a cleaner interface to define scalar functions to extension authors +/// . e.g. +/// ``` +/// fn scalar_func(args: &[Value]) -> Value { +/// if args.len() != 1 { +/// return Value::null(); +/// } +/// Value::from_integer(args[0].integer * 2) +/// } +/// ``` +/// +#[macro_export] +macro_rules! declare_scalar_functions { + ( + $( + #[args(min = $min_args:literal, max = $max_args:literal)] + fn $func_name:ident ($args:ident : &[Value]) -> Value $body:block + )* + ) => { + $( + extern "C" fn $func_name( + argc: i32, + argv: *const *const std::os::raw::c_void + ) -> $crate::Value { + if !($min_args..=$max_args).contains(&argc) { + println!("{}: Invalid argument count", stringify!($func_name)); + return $crate::Value::null();// TODO: error code + } + if argc == 0 || argv.is_null() { + let $args: &[$crate::Value] = &[]; + $body + } else { + unsafe { + let ptr_slice = std::slice::from_raw_parts(argv, argc as usize); + let mut values = Vec::with_capacity(argc as usize); + for &ptr in ptr_slice { + let val_ptr = ptr as *const $crate::Value; + if val_ptr.is_null() { + values.push($crate::Value::null()); + } else { + values.push(std::ptr::read(val_ptr)); + } + } + let $args: &[$crate::Value] = &values[..]; + $body + } + } + } + )* + }; +} + +#[derive(PartialEq, Eq)] +#[repr(C)] +pub enum ValueType { + Null, + Integer, + Float, + Text, + Blob, +} + +// TODO: perf, these can be better expressed +#[repr(C)] +pub struct Value { + pub value_type: ValueType, + pub integer: i64, + pub float: f64, + pub text: TextValue, + pub blob: Blob, +} + +#[repr(C)] +pub struct TextValue { + text: *const c_char, + len: usize, +} + +impl std::fmt::Display for TextValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.text.is_null() { + return write!(f, ""); + } + let slice = unsafe { std::slice::from_raw_parts(self.text as *const u8, self.len) }; + match std::str::from_utf8(slice) { + Ok(s) => write!(f, "{}", s), + Err(e) => write!(f, "", e), + } + } +} + +impl TextValue { + pub fn is_null(&self) -> bool { + self.text.is_null() + } + + pub fn new(text: *const c_char, len: usize) -> Self { + Self { text, len } + } + + pub fn null() -> Self { + Self { + text: std::ptr::null(), + len: 0, + } + } +} + +#[repr(C)] +pub struct Blob { + pub data: *const u8, + pub size: usize, +} + +impl Blob { + pub fn new(data: *const u8, size: usize) -> Self { + Self { data, size } + } + pub fn null() -> Self { + Self { + data: std::ptr::null(), + size: 0, + } + } +} + +impl Value { + pub fn null() -> Self { + Self { + value_type: ValueType::Null, + integer: 0, + float: 0.0, + text: TextValue::null(), + blob: Blob::null(), + } + } + + pub fn from_integer(value: i64) -> Self { + Self { + value_type: ValueType::Integer, + integer: value, + float: 0.0, + text: TextValue::null(), + blob: Blob::null(), + } + } + pub fn from_float(value: f64) -> Self { + Self { + value_type: ValueType::Float, + integer: 0, + float: value, + text: TextValue::null(), + blob: Blob::null(), + } + } + + pub fn from_text(value: String) -> Self { + let cstr = CString::new(&*value).unwrap(); + let ptr = cstr.as_ptr(); + let len = value.len(); + std::mem::forget(cstr); + Self { + value_type: ValueType::Text, + integer: 0, + float: 0.0, + text: TextValue::new(ptr, len), + blob: Blob::null(), + } + } + + pub fn from_blob(value: &[u8]) -> Self { + Self { + value_type: ValueType::Blob, + integer: 0, + float: 0.0, + text: TextValue::null(), + blob: Blob::new(value.as_ptr(), value.len()), + } + } +}