Skip to content

Commit

Permalink
Rough design for extension api/draft extension
Browse files Browse the repository at this point in the history
  • Loading branch information
PThorpe92 committed Jan 12, 2025
1 parent 1e86370 commit 3942582
Show file tree
Hide file tree
Showing 18 changed files with 489 additions and 348 deletions.
32 changes: 16 additions & 16 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ members = [
"sqlite3",
"core",
"simulator",
"test", "macros", "extension_api", "extensions/uuid",
"test", "macros", "limbo_extension", "extensions/uuid",
]
exclude = ["perf/latency/limbo"]

Expand Down
14 changes: 4 additions & 10 deletions cli/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,14 +270,8 @@ impl Limbo {
};
}

fn handle_load_extension(&mut self) -> Result<(), String> {
let mut args = self.input_buff.split_whitespace();
let _ = args.next();
let lib = args
.next()
.ok_or("No library specified")
.map_err(|e| e.to_string())?;
self.conn.load_extension(lib).map_err(|e| e.to_string())
fn handle_load_extension(&mut self, path: &str) -> Result<(), String> {
self.conn.load_extension(path).map_err(|e| e.to_string())
}

fn display_in_memory(&mut self) -> std::io::Result<()> {
Expand Down Expand Up @@ -511,8 +505,8 @@ impl Limbo {
};
}
Command::LoadExtension => {
if let Err(e) = self.handle_load_extension() {
let _ = self.writeln(e.to_string());
if let Err(e) = self.handle_load_extension(args[1]) {
let _ = self.writeln(&e);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ rustix = "0.38.34"
mimalloc = { version = "*", default-features = false }

[dependencies]
extension_api = { path = "../extension_api" }
limbo_extension = { path = "../limbo_extension" }
cfg_block = "0.1.1"
fallible-iterator = "0.3.0"
hex = "0.4.3"
Expand Down
93 changes: 33 additions & 60 deletions core/ext/mod.rs
Original file line number Diff line number Diff line change
@@ -1,68 +1,41 @@
#[cfg(feature = "uuid")]
mod uuid;
use crate::{function::ExternalFunc, Database};
use std::sync::Arc;

use extension_api::{AggregateFunction, ExtensionApi, Result, ScalarFunction, VirtualTable};
#[cfg(feature = "uuid")]
pub use uuid::{exec_ts_from_uuid7, exec_uuid, exec_uuidblob, exec_uuidstr, UuidFunc};

impl ExtensionApi for Database {
fn register_scalar_function(
&self,
name: &str,
func: Arc<dyn ScalarFunction>,
) -> extension_api::Result<()> {
let ext_func = ExternalFunc::new(name, func.clone());
self.syms
.borrow_mut()
.functions
.insert(name.to_string(), Arc::new(ext_func));
Ok(())
}

fn register_aggregate_function(
&self,
_name: &str,
_func: Arc<dyn AggregateFunction>,
) -> Result<()> {
todo!("implement aggregate function registration");
}

fn register_virtual_table(&self, _name: &str, _table: Arc<dyn VirtualTable>) -> Result<()> {
todo!("implement virtual table registration");
}
}

#[derive(Debug, Clone, PartialEq)]
pub enum ExtFunc {
#[cfg(feature = "uuid")]
Uuid(UuidFunc),
use limbo_extension::{ExtensionApi, ResultCode, ScalarFunction, RESULT_ERROR, RESULT_OK};
pub use limbo_extension::{Value as ExtValue, ValueType as ExtValueType};
use std::{
ffi::{c_char, c_void, CStr},
rc::Rc,
};

extern "C" fn register_scalar_function(
ctx: *mut c_void,
name: *const c_char,
func: ScalarFunction,
) -> ResultCode {
let c_str = unsafe { CStr::from_ptr(name) };
let name_str = match c_str.to_str() {
Ok(s) => s.to_string(),
Err(_) => return RESULT_ERROR,
};
let db = unsafe { &*(ctx as *const Database) };
db.register_scalar_function_impl(name_str, func)
}

#[allow(unreachable_patterns)] // TODO: remove when more extension funcs added
impl std::fmt::Display for ExtFunc {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
#[cfg(feature = "uuid")]
Self::Uuid(uuidfn) => write!(f, "{}", uuidfn),
_ => write!(f, "unknown"),
}
impl Database {
fn register_scalar_function_impl(&self, name: String, func: ScalarFunction) -> ResultCode {
self.syms.borrow_mut().functions.insert(
name.to_string(),
Rc::new(ExternalFunc {
name: name.to_string(),
func,
}),
);
RESULT_OK
}
}

#[allow(unreachable_patterns)]
impl ExtFunc {
pub fn resolve_function(name: &str, num_args: usize) -> Option<ExtFunc> {
match name {
#[cfg(feature = "uuid")]
name => UuidFunc::resolve_function(name, num_args),
_ => None,
pub fn build_limbo_extension(&self) -> ExtensionApi {
ExtensionApi {
ctx: self as *const _ as *mut c_void,
register_scalar_function,
}
}
}

//pub fn init(db: &mut crate::Database) {
// #[cfg(feature = "uuid")]
// uuid::init(db);
//}
18 changes: 7 additions & 11 deletions core/function.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use crate::ext::ExtFunc;
use std::fmt;
use std::fmt::{Debug, Display};
use std::sync::Arc;
use std::rc::Rc;

use limbo_extension::ScalarFunction;

pub struct ExternalFunc {
pub name: String,
pub func: Arc<dyn extension_api::ScalarFunction>,
pub func: ScalarFunction,
}

impl ExternalFunc {
pub fn new(name: &str, func: Arc<dyn extension_api::ScalarFunction>) -> Self {
pub fn new(name: &str, func: ScalarFunction) -> Self {
Self {
name: name.to_string(),
func,
Expand Down Expand Up @@ -306,8 +307,7 @@ pub enum Func {
Math(MathFunc),
#[cfg(feature = "json")]
Json(JsonFunc),
Extension(ExtFunc),
External(Arc<ExternalFunc>),
External(Rc<ExternalFunc>),
}

impl Display for Func {
Expand All @@ -318,7 +318,6 @@ impl Display for Func {
Self::Math(math_func) => write!(f, "{}", math_func),
#[cfg(feature = "json")]
Self::Json(json_func) => write!(f, "{}", json_func),
Self::Extension(ext_func) => write!(f, "{}", ext_func),
Self::External(generic_func) => write!(f, "{}", generic_func),
}
}
Expand Down Expand Up @@ -423,10 +422,7 @@ impl Func {
"tan" => Ok(Self::Math(MathFunc::Tan)),
"tanh" => Ok(Self::Math(MathFunc::Tanh)),
"trunc" => Ok(Self::Math(MathFunc::Trunc)),
_ => match ExtFunc::resolve_function(name, arg_count) {
Some(ext_func) => Ok(Self::Extension(ext_func)),
None => Err(()),
},
_ => Err(()),
}
}
}
45 changes: 27 additions & 18 deletions core/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ mod vdbe;
#[global_allocator]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;

use extension_api::{Extension, ExtensionApi};
use fallible_iterator::FallibleIterator;
use libloading::{Library, Symbol};
use limbo_extension::{ExtensionApi, ExtensionEntryPoint, RESULT_OK};
use log::trace;
use schema::Schema;
use sqlite3_parser::ast;
Expand Down Expand Up @@ -131,7 +131,6 @@ impl Database {
_shared_wal: shared_wal.clone(),
syms,
};
// ext::init(&mut db);
let db = Arc::new(db);
let conn = Rc::new(Connection {
db: db.clone(),
Expand Down Expand Up @@ -165,31 +164,37 @@ impl Database {
pub fn define_scalar_function<S: AsRef<str>>(
&self,
name: S,
func: Arc<dyn extension_api::ScalarFunction>,
func: limbo_extension::ScalarFunction,
) {
let func = function::ExternalFunc {
name: name.as_ref().to_string(),
func: func.clone(),
func,
};
self.syms
.borrow_mut()
.functions
.insert(name.as_ref().to_string(), Arc::new(func));
.insert(name.as_ref().to_string(), func.into());
}

pub fn load_extension(&self, path: &str) -> Result<()> {
let api = Box::new(self.build_limbo_extension());
let lib =
unsafe { Library::new(path).map_err(|e| LimboError::ExtensionError(e.to_string()))? };
unsafe {
let register: Symbol<unsafe extern "C" fn(&dyn ExtensionApi) -> Box<dyn Extension>> =
lib.get(b"register_extension")
.map_err(|e| LimboError::ExtensionError(e.to_string()))?;
let extension = register(self);
extension
.load()
.map_err(|e| LimboError::ExtensionError(e.to_string()))?;
let entry: Symbol<ExtensionEntryPoint> = unsafe {
lib.get(b"register_extension")
.map_err(|e| LimboError::ExtensionError(e.to_string()))?
};
let api_ptr: *const ExtensionApi = Box::into_raw(api);
let result_code = entry(api_ptr);
if result_code == RESULT_OK {
self.syms.borrow_mut().extensions.push((lib, api_ptr));
Ok(())
} else {
let _ = unsafe { Box::from_raw(api_ptr.cast_mut()) }; // own this again so we dont leak
Err(LimboError::ExtensionError(
"Extension registration failed".to_string(),
))
}
Ok(())
}
}

Expand Down Expand Up @@ -318,7 +323,11 @@ impl Connection {
Cmd::ExplainQueryPlan(stmt) => {
match stmt {
ast::Stmt::Select(select) => {
let mut plan = prepare_select_plan(&self.schema.borrow(), *select)?;
let mut plan = prepare_select_plan(
&self.schema.borrow(),
*select,
&self.db.syms.borrow(),
)?;
optimize_plan(&mut plan)?;
println!("{}", plan);
}
Expand Down Expand Up @@ -484,8 +493,8 @@ impl Rows {
}

pub(crate) struct SymbolTable {
pub functions: HashMap<String, Arc<crate::function::ExternalFunc>>,
extensions: Vec<Rc<dyn Extension>>,
pub functions: HashMap<String, Rc<crate::function::ExternalFunc>>,
extensions: Vec<(libloading::Library, *const ExtensionApi)>,
}

impl std::fmt::Debug for SymbolTable {
Expand All @@ -508,7 +517,7 @@ impl SymbolTable {
&self,
name: &str,
_arg_count: usize,
) -> Option<Arc<crate::function::ExternalFunc>> {
) -> Option<Rc<crate::function::ExternalFunc>> {
self.functions.get(name).cloned()
}
}
Loading

0 comments on commit 3942582

Please sign in to comment.