From ffcbf9c0e9292c7628297eab69f1027b3c37de20 Mon Sep 17 00:00:00 2001 From: PThorpe92 Date: Sun, 12 Jan 2025 17:32:20 -0500 Subject: [PATCH] Add tests for first extension --- Makefile | 9 ++- cli/app.rs | 5 +- core/lib.rs | 5 ++ core/types.rs | 25 +++---- extensions/uuid/src/lib.rs | 19 +++-- limbo_extension/src/lib.rs | 62 ++++++----------- testing/extensions.py | 139 +++++++++++++++++++++++++++++++++++++ 7 files changed, 194 insertions(+), 70 deletions(-) create mode 100755 testing/extensions.py diff --git a/Makefile b/Makefile index 30d84bcc..109a3f14 100644 --- a/Makefile +++ b/Makefile @@ -62,10 +62,15 @@ limbo-wasm: cargo build --package limbo-wasm --target wasm32-wasi .PHONY: limbo-wasm -test: limbo test-compat test-sqlite3 test-shell +test: limbo test-compat test-sqlite3 test-shell test-extensions .PHONY: test -test-shell: limbo +test-extensions: limbo + cargo build --package limbo_uuid + ./testing/extensions.py +.PHONY: test-extensions + +test-shell: limbo SQLITE_EXEC=$(SQLITE_EXEC) ./testing/shelltests.py .PHONY: test-shell diff --git a/cli/app.rs b/cli/app.rs index f3c5eb9e..5d14c0e3 100644 --- a/cli/app.rs +++ b/cli/app.rs @@ -270,6 +270,7 @@ impl Limbo { }; } + #[cfg(not(target_family = "wasm"))] fn handle_load_extension(&mut self, path: &str) -> Result<(), String> { self.conn.load_extension(path).map_err(|e| e.to_string()) } @@ -504,7 +505,9 @@ impl Limbo { let _ = self.writeln(e.to_string()); }; } - Command::LoadExtension => { + Command::LoadExtension => + { + #[cfg(not(target_family = "wasm"))] if let Err(e) = self.handle_load_extension(args[1]) { let _ = self.writeln(&e); } diff --git a/core/lib.rs b/core/lib.rs index edc5cf3a..c2dfb7ad 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -18,7 +18,9 @@ mod vdbe; static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; use fallible_iterator::FallibleIterator; +#[cfg(not(target_family = "wasm"))] use libloading::{Library, Symbol}; +#[cfg(not(target_family = "wasm"))] use limbo_extension::{ExtensionApi, ExtensionEntryPoint, RESULT_OK}; use log::trace; use schema::Schema; @@ -176,6 +178,7 @@ impl Database { .insert(name.as_ref().to_string(), func.into()); } + #[cfg(not(target_family = "wasm"))] pub fn load_extension(&self, path: &str) -> Result<()> { let api = Box::new(self.build_limbo_extension()); let lib = @@ -394,6 +397,7 @@ impl Connection { Ok(()) } + #[cfg(not(target_family = "wasm"))] pub fn load_extension(&self, path: &str) -> Result<()> { Database::load_extension(self.db.as_ref(), path) } @@ -496,6 +500,7 @@ impl Rows { pub(crate) struct SymbolTable { pub functions: HashMap>, + #[cfg(not(target_family = "wasm"))] extensions: Vec<(libloading::Library, *const ExtensionApi)>, } diff --git a/core/types.rs b/core/types.rs index 75f28a96..fdf6793f 100644 --- a/core/types.rs +++ b/core/types.rs @@ -107,35 +107,30 @@ impl OwnedValue { } pub fn from_ffi(v: &ExtValue) -> Self { - if v.value.is_null() { - return OwnedValue::Null; - } - match v.value_type { + match v.value_type() { ExtValueType::Null => OwnedValue::Null, ExtValueType::Integer => { - let int_ptr = v.value as *mut i64; - let integer = unsafe { *int_ptr }; - OwnedValue::Integer(integer) + let Some(int) = v.to_integer() else { + return OwnedValue::Null; + }; + OwnedValue::Integer(int) } ExtValueType::Float => { - let float_ptr = v.value as *mut f64; - let float = unsafe { *float_ptr }; + let Some(float) = v.to_float() else { + return OwnedValue::Null; + }; OwnedValue::Float(float) } ExtValueType::Text => { let Some(text) = v.to_text() else { return OwnedValue::Null; }; - OwnedValue::build_text(std::rc::Rc::new(unsafe { text.as_str().to_string() })) + OwnedValue::build_text(std::rc::Rc::new(text)) } ExtValueType::Blob => { - let Some(blob_ptr) = v.to_blob() else { + let Some(blob) = v.to_blob() else { return OwnedValue::Null; }; - let blob = unsafe { - let slice = std::slice::from_raw_parts(blob_ptr.data, blob_ptr.size as usize); - slice.to_vec() - }; OwnedValue::Blob(std::rc::Rc::new(blob)) } } diff --git a/extensions/uuid/src/lib.rs b/extensions/uuid/src/lib.rs index 92e9d5d4..e6a3a4f9 100644 --- a/extensions/uuid/src/lib.rs +++ b/extensions/uuid/src/lib.rs @@ -35,7 +35,7 @@ declare_scalar_functions! { uuid::Timestamp::now(ctx) } else { let arg = &args[0]; - match arg.value_type { + match arg.value_type() { ValueType::Integer => { let ctx = uuid::ContextV7::new(); let Some(int) = arg.to_integer() else { @@ -47,8 +47,7 @@ declare_scalar_functions! { let Some(text) = arg.to_text() else { return Value::null(); }; - let parsed = unsafe{text.as_str()}.parse::(); - match parsed { + match text.parse::() { Ok(unix) => { if unix <= 0 { return Value::null(); @@ -70,7 +69,7 @@ declare_scalar_functions! { let timestamp = if args.is_empty() { let ctx = uuid::ContextV7::new(); uuid::Timestamp::now(ctx) - } else if args[0].value_type == limbo_extension::ValueType::Integer { + } else if args[0].value_type() == limbo_extension::ValueType::Integer { let ctx = uuid::ContextV7::new(); let Some(int) = args[0].to_integer() else { return Value::null(); @@ -86,13 +85,12 @@ declare_scalar_functions! { #[args(1)] fn exec_ts_from_uuid7(args: &[Value]) -> Value { - match args[0].value_type { + match args[0].value_type() { ValueType::Blob => { let Some(blob) = &args[0].to_blob() else { return Value::null(); }; - let slice = unsafe{ std::slice::from_raw_parts(blob.data, blob.size as usize)}; - let uuid = uuid::Uuid::from_slice(slice).unwrap(); + let uuid = uuid::Uuid::from_slice(blob.as_slice()).unwrap(); let unix = uuid_to_unix(uuid.as_bytes()); Value::from_integer(unix as i64) } @@ -100,7 +98,7 @@ declare_scalar_functions! { let Some(text) = args[0].to_text() else { return Value::null(); }; - let Ok(uuid) = uuid::Uuid::parse_str(unsafe {text.as_str()}) else { + let Ok(uuid) = uuid::Uuid::parse_str(&text) else { return Value::null(); }; let unix = uuid_to_unix(uuid.as_bytes()); @@ -115,8 +113,7 @@ declare_scalar_functions! { let Some(blob) = args[0].to_blob() else { return Value::null(); }; - let slice = unsafe{ std::slice::from_raw_parts(blob.data, blob.size as usize)}; - let parsed = uuid::Uuid::from_slice(slice).ok().map(|u| u.to_string()); + let parsed = uuid::Uuid::from_slice(blob.as_slice()).ok().map(|u| u.to_string()); match parsed { Some(s) => Value::from_text(s), None => Value::null() @@ -128,7 +125,7 @@ declare_scalar_functions! { let Some(text) = args[0].to_text() else { return Value::null(); }; - match uuid::Uuid::parse_str(unsafe {text.as_str()}) { + match uuid::Uuid::parse_str(&text) { Ok(uuid) => { Value::from_blob(uuid.as_bytes().to_vec()) } diff --git a/limbo_extension/src/lib.rs b/limbo_extension/src/lib.rs index d07cc8ea..0666c588 100644 --- a/limbo_extension/src/lib.rs +++ b/limbo_extension/src/lib.rs @@ -50,19 +50,6 @@ macro_rules! register_scalar_functions { } } -/// Provide a cleaner interface to define scalar functions to extension authors -/// . e.g. -/// ``` -/// #[args(1)] -/// fn scalar_double(args: &[Value]) -> Value { -/// Value::from_integer(args[0].integer * 2) -/// } -/// -/// #[args(0..=2)] -/// fn scalar_sum(args: &[Value]) -> Value { -/// Value::from_integer(args.iter().map(|v| v.integer).sum()) -/// ``` -/// #[macro_export] macro_rules! declare_scalar_functions { ( @@ -100,7 +87,7 @@ macro_rules! declare_scalar_functions { } #[repr(C)] -#[derive(PartialEq, Eq)] +#[derive(PartialEq, Eq, Clone, Copy)] pub enum ValueType { Null, Integer, @@ -111,8 +98,8 @@ pub enum ValueType { #[repr(C)] pub struct Value { - pub value_type: ValueType, - pub value: *mut c_void, + value_type: ValueType, + value: *mut c_void, } impl std::fmt::Debug for Value { @@ -161,41 +148,27 @@ impl Default for TextValue { } impl TextValue { - pub fn new(text: *const u8, len: usize) -> Self { + pub(crate) fn new(text: *const u8, len: usize) -> Self { Self { text, len: len as u32, } } - /// # Safety - /// Safe to call if the pointer is null, returns None - /// if the value is not a text type or if the value is null - pub unsafe fn from_value(value: &Value) -> Option<&Self> { - if value.value_type != ValueType::Text { - return None; - } - if value.value.is_null() { - return None; - } - Some(&*(value.value as *const TextValue)) - } - - /// # Safety - /// If self.text is null we safely return an empty string but - /// the caller must ensure that the underlying value is valid utf8 - pub unsafe fn as_str(&self) -> &str { + fn as_str(&self) -> &str { if self.text.is_null() { return ""; } - std::str::from_utf8_unchecked(std::slice::from_raw_parts(self.text, self.len as usize)) + unsafe { + std::str::from_utf8_unchecked(std::slice::from_raw_parts(self.text, self.len as usize)) + } } } #[repr(C)] pub struct Blob { - pub data: *const u8, - pub size: u64, + data: *const u8, + size: u64, } impl std::fmt::Debug for Blob { @@ -218,6 +191,10 @@ impl Value { } } + pub fn value_type(&self) -> ValueType { + self.value_type + } + pub fn to_float(&self) -> Option { if self.value_type != ValueType::Float { return None; @@ -228,24 +205,27 @@ impl Value { Some(unsafe { *(self.value as *const f64) }) } - pub fn to_text(&self) -> Option<&TextValue> { + pub fn to_text(&self) -> Option { if self.value_type != ValueType::Text { return None; } if self.value.is_null() { return None; } - unsafe { Some(&*(self.value as *const TextValue)) } + let txt = unsafe { &*(self.value as *const TextValue) }; + Some(String::from(txt.as_str())) } - pub fn to_blob(&self) -> Option<&Blob> { + pub fn to_blob(&self) -> Option> { if self.value_type != ValueType::Blob { return None; } if self.value.is_null() { return None; } - unsafe { Some(&*(self.value as *const Blob)) } + let blob = unsafe { &*(self.value as *const Blob) }; + let slice = unsafe { std::slice::from_raw_parts(blob.data, blob.size as usize) }; + Some(slice.to_vec()) } pub fn to_integer(&self) -> Option { diff --git a/testing/extensions.py b/testing/extensions.py new file mode 100755 index 00000000..74383be9 --- /dev/null +++ b/testing/extensions.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 +import os +import subprocess +import select +import time +import uuid + +sqlite_exec = "./target/debug/limbo" +sqlite_flags = os.getenv("SQLITE_FLAGS", "-q").split(" ") + + +def init_limbo(): + pipe = subprocess.Popen( + [sqlite_exec, *sqlite_flags], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + bufsize=0, + ) + return pipe + + +def execute_sql(pipe, sql): + end_suffix = "END_OF_RESULT" + write_to_pipe(pipe, sql) + write_to_pipe(pipe, f"SELECT '{end_suffix}';\n") + stdout = pipe.stdout + stderr = pipe.stderr + output = "" + while True: + ready_to_read, _, error_in_pipe = select.select( + [stdout, stderr], [], [stdout, stderr] + ) + ready_to_read_or_err = set(ready_to_read + error_in_pipe) + if stderr in ready_to_read_or_err: + exit_on_error(stderr) + + if stdout in ready_to_read_or_err: + fragment = stdout.read(select.PIPE_BUF) + output += fragment.decode() + if output.rstrip().endswith(end_suffix): + output = output.rstrip().removesuffix(end_suffix) + break + output = strip_each_line(output) + return output + + +def strip_each_line(lines: str) -> str: + lines = lines.split("\n") + lines = [line.strip() for line in lines if line != ""] + return "\n".join(lines) + + +def write_to_pipe(pipe, command): + if pipe.stdin is None: + raise RuntimeError("Failed to write to shell") + pipe.stdin.write((command + "\n").encode()) + pipe.stdin.flush() + + +def exit_on_error(stderr): + while True: + ready_to_read, _, _ = select.select([stderr], [], []) + if not ready_to_read: + break + print(stderr.read().decode(), end="") + exit(1) + + +def run_test(pipe, sql, validator=None): + print(f"Running test: {sql}") + result = execute_sql(pipe, sql) + if validator is not None: + if not validator(result): + print(f"Test FAILED: {sql}") + print(f"Returned: {result}") + raise Exception("Validation failed") + print("Test PASSED") + + +def validate_blob(result): + # HACK: blobs are difficult to test because the shell + # tries to return them as utf8 strings, so we call hex + # and assert they are valid hex digits + return int(result, 16) is not None + + +def validate_string_uuid(result): + return len(result) == 36 and result.count("-") == 4 + + +def returns_null(result): + return result == "" or result == b"\n" or result == b"" + + +def assert_now_unixtime(result): + return result == str(int(time.time())) + + +def assert_specific_time(result): + return result == "1736720789" + + +def main(): + specific_time = "01945ca0-3189-76c0-9a8f-caf310fc8b8e" + extension_path = "./target/debug/liblimbo_uuid.so" + pipe = init_limbo() + try: + # before extension loads, assert no function + run_test(pipe, "SELECT uuid4();", returns_null) + run_test(pipe, "SELECT uuid4_str();", returns_null) + run_test(pipe, f".load {extension_path}", returns_null) + print("Extension loaded successfully.") + run_test(pipe, "SELECT hex(uuid4());", validate_blob) + run_test(pipe, "SELECT uuid4_str();", validate_string_uuid) + run_test(pipe, "SELECT hex(uuid7());", validate_blob) + run_test( + pipe, + "SELECT uuid7_timestamp_ms(uuid7()) / 1000;", + ) + run_test(pipe, "SELECT uuid7_str();", validate_string_uuid) + run_test(pipe, "SELECT uuid_str(uuid7());", validate_string_uuid) + run_test(pipe, "SELECT hex(uuid_blob(uuid7_str()));", validate_blob) + run_test(pipe, "SELECT uuid_str(uuid_blob(uuid7_str()));", validate_string_uuid) + run_test( + pipe, + f"SELECT uuid7_timestamp_ms('{specific_time}') / 1000;", + assert_specific_time, + ) + except Exception as e: + print(f"Test FAILED: {e}") + pipe.terminate() + exit(1) + pipe.terminate() + print("All tests passed successfully.") + + +if __name__ == "__main__": + main()