diff --git a/pyproject.toml b/pyproject.toml index ea9ecc51..6212f366 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["maturin>=0.13,<0.14"] +requires = ["maturin>=1,<2"] build-backend = "maturin" [project] diff --git a/src/array/ops.rs b/src/array/ops.rs index d10673d9..6e2869e9 100644 --- a/src/array/ops.rs +++ b/src/array/ops.rs @@ -634,6 +634,72 @@ impl ArrayImpl { }; Ok(A::new_string(unary_op(a.as_ref(), |s| s.replace(from, to)))) } + + pub fn vector_l2_distance(&self, other: &ArrayImpl) -> Result { + let ArrayImpl::Vector(a) = self else { + return Err(ConvertError::NoBinaryOp( + "vector_l2_distance".into(), + self.type_string(), + other.type_string(), + )); + }; + let ArrayImpl::Vector(b) = other else { + return Err(ConvertError::NoBinaryOp( + "vector_l2_distance".into(), + other.type_string(), + self.type_string(), + )); + }; + Ok(ArrayImpl::new_float64(binary_op( + a.as_ref(), + b.as_ref(), + |a, b| a.l2_distance(b), + ))) + } + + pub fn vector_cosine_distance(&self, other: &ArrayImpl) -> Result { + let ArrayImpl::Vector(a) = self else { + return Err(ConvertError::NoBinaryOp( + "vector_cosine_distance".into(), + self.type_string(), + other.type_string(), + )); + }; + let ArrayImpl::Vector(b) = other else { + return Err(ConvertError::NoBinaryOp( + "vector_cosine_distance".into(), + other.type_string(), + self.type_string(), + )); + }; + Ok(ArrayImpl::new_float64(binary_op( + a.as_ref(), + b.as_ref(), + |a, b| a.cosine_distance(b), + ))) + } + + pub fn vector_neg_inner_product(&self, other: &ArrayImpl) -> Result { + let ArrayImpl::Vector(a) = self else { + return Err(ConvertError::NoBinaryOp( + "vector_neg_inner_product".into(), + self.type_string(), + other.type_string(), + )); + }; + let ArrayImpl::Vector(b) = other else { + return Err(ConvertError::NoBinaryOp( + "vector_neg_inner_product".into(), + other.type_string(), + self.type_string(), + )); + }; + Ok(ArrayImpl::new_float64(binary_op( + a.as_ref(), + b.as_ref(), + |a, b| -a.dot_product(b), + ))) + } } /// Implement aggregation functions. diff --git a/src/binder/expr.rs b/src/binder/expr.rs index fa1627f6..df2c6bd6 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -142,6 +142,12 @@ impl Binder { And => Node::And([l, r]), Or => Node::Or([l, r]), Xor => Node::Xor([l, r]), + Spaceship => Node::VectorCosineDistance([l, r]), + Custom(name) => match name.as_str() { + "<->" => Node::VectorL2Distance([l, r]), + "<#>" => Node::VectorNegtiveInnerProduct([l, r]), + op => todo!("bind custom binary op: {:?}", op), + }, _ => todo!("bind binary op: {:?}", op), }; Ok(self.egraph.add(node)) diff --git a/src/executor/evaluator.rs b/src/executor/evaluator.rs index f39690d0..cc1259fd 100644 --- a/src/executor/evaluator.rs +++ b/src/executor/evaluator.rs @@ -132,6 +132,21 @@ impl<'a> Evaluator<'a> { }; a.replace(from, to) } + VectorL2Distance([a, b]) => { + let a = self.next(*a).eval(chunk)?; + let b = self.next(*b).eval(chunk)?; + a.vector_l2_distance(&b) + } + VectorCosineDistance([a, b]) => { + let a = self.next(*a).eval(chunk)?; + let b = self.next(*b).eval(chunk)?; + a.vector_cosine_distance(&b) + } + VectorNegtiveInnerProduct([a, b]) => { + let a = self.next(*a).eval(chunk)?; + let b = self.next(*b).eval(chunk)?; + a.vector_neg_inner_product(&b) + } e => { if let Some((op, a, b)) = e.binary_op() { let left = self.next(a).eval(chunk)?; diff --git a/src/lib.rs b/src/lib.rs index a9ab380f..665fd784 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -52,8 +52,6 @@ pub mod types; /// Utilities. pub mod utils; -#[cfg(feature = "python")] -use python::open; #[cfg(feature = "jemalloc")] use tikv_jemallocator::Jemalloc; @@ -63,15 +61,3 @@ pub use self::db::{Database, Error}; #[cfg(feature = "jemalloc")] #[global_allocator] static GLOBAL: Jemalloc = Jemalloc; - -/// Python Extension -#[cfg(feature = "python")] -use pyo3::{prelude::*, wrap_pyfunction}; - -/// The entry point of python module must be in the lib.rs -#[cfg(feature = "python")] -#[pymodule] -fn risinglight(m: &Bound<'_, PyModule>) -> PyResult<()> { - m.add_function(wrap_pyfunction!(open, m)?)?; - Ok(()) -} diff --git a/src/planner/explain.rs b/src/planner/explain.rs index efabd09f..3b28d944 100644 --- a/src/planner/explain.rs +++ b/src/planner/explain.rs @@ -185,6 +185,29 @@ impl<'a> Explain<'a> { ], ), + // vector functions + VectorL2Distance([a, b]) => Pretty::childless_record( + "VectorL2Distance", + vec![ + ("lhs", self.expr(a).pretty()), + ("rhs", self.expr(b).pretty()), + ], + ), + VectorCosineDistance([a, b]) => Pretty::childless_record( + "VectorCosineDistance", + vec![ + ("lhs", self.expr(a).pretty()), + ("rhs", self.expr(b).pretty()), + ], + ), + VectorNegtiveInnerProduct([a, b]) => Pretty::childless_record( + "VectorDotProduct", + vec![ + ("lhs", self.expr(a).pretty()), + ("rhs", self.expr(b).pretty()), + ], + ), + // aggregations RowCount | RowNumber => enode.to_string().into(), Max(a) | Min(a) | Sum(a) | Avg(a) | Count(a) | First(a) | Last(a) diff --git a/src/planner/mod.rs b/src/planner/mod.rs index e00cd261..2866b61f 100644 --- a/src/planner/mod.rs +++ b/src/planner/mod.rs @@ -69,6 +69,11 @@ define_language! { "replace" = Replace([Id; 3]), // (replace expr pattern replacement) "substring" = Substring([Id; 3]), // (substring expr start length) + // vector functions + "<->" = VectorL2Distance([Id; 2]), + "<#>" = VectorNegtiveInnerProduct([Id; 2]), + "<=>" = VectorCosineDistance([Id; 2]), + // aggregations "max" = Max(Id), "min" = Min(Id), diff --git a/src/planner/rules/type_.rs b/src/planner/rules/type_.rs index 21cd4c2f..54fdb876 100644 --- a/src/planner/rules/type_.rs +++ b/src/planner/rules/type_.rs @@ -77,6 +77,14 @@ pub fn analyze_type( (a == DataType::String && b == DataType::String).then_some(DataType::Bool) }), + // vector ops + VectorL2Distance([a, b]) + | VectorCosineDistance([a, b]) + | VectorNegtiveInnerProduct([a, b]) => merge(enode, [x(a)?, x(b)?], |[a, b]| { + (matches!(a, DataType::Vector(_)) && matches!(b, DataType::Vector(_))) + .then_some(DataType::Float64) + }), + // bool ops Not(a) => check(enode, x(a)?, |a| a == &DataType::Bool), Gt([a, b]) | Lt([a, b]) | GtEq([a, b]) | LtEq([a, b]) | Eq([a, b]) | NotEq([a, b]) => { diff --git a/src/python/mod.rs b/src/python/mod.rs index e8059f09..e4310af3 100644 --- a/src/python/mod.rs +++ b/src/python/mod.rs @@ -7,6 +7,7 @@ use tokio::runtime::Runtime; use crate::storage::SecondaryStorageOptions; use crate::Database; + #[pyclass] pub struct PythonDatabase { runtime: Runtime, @@ -61,6 +62,13 @@ pub fn open_in_memory() -> PyResult { Ok(PythonDatabase { runtime, database }) } +#[pymodule] +fn risinglight(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_function(wrap_pyfunction!(open, m)?)?; + m.add_function(wrap_pyfunction!(open_in_memory, m)?)?; + Ok(()) +} + use crate::types::DataValue; /// Convert datachunk into Python List pub fn datachunk_to_python_list(py: Python, chunk: &Chunk) -> Vec> { diff --git a/src/types/blob.rs b/src/types/blob.rs index 4ffdf214..3df48aa1 100644 --- a/src/types/blob.rs +++ b/src/types/blob.rs @@ -159,7 +159,7 @@ impl fmt::Display for BlobRef { /// A slice of a vector. #[repr(transparent)] #[derive(PartialEq, Eq, PartialOrd, Ord, RefCast, Hash)] -pub struct VectorRef([F64]); +pub struct VectorRef(pub(crate) [F64]); impl VectorRef { pub fn new(values: &[F64]) -> &Self { diff --git a/src/types/vector.rs b/src/types/vector.rs index 1f733810..14922918 100644 --- a/src/types/vector.rs +++ b/src/types/vector.rs @@ -62,6 +62,44 @@ impl Deref for Vector { } } +impl VectorRef { + pub fn norm_squared(&self) -> F64 { + let sum: f64 = self.0.iter().map(|a| a.powi(2)).sum(); + F64::from(sum) + } + + pub fn norm(&self) -> F64 { + F64::from(self.norm_squared().sqrt()) + } + + pub fn l2_distance(&self, other: &VectorRef) -> F64 { + let sum = self + .0 + .iter() + .zip(other.0.iter()) + .map(|(a, b)| (a.0 - b.0).powi(2)) + .sum::(); + F64::from(sum.sqrt()) + } + + pub fn cosine_distance(&self, other: &VectorRef) -> F64 { + let dot_product = self.dot_product(other); + let norm_self_squared = self.norm_squared(); + let norm_other_squared = other.norm_squared(); + F64::from(1.0) - dot_product / (norm_self_squared * norm_other_squared).sqrt() + } + + pub fn dot_product(&self, other: &VectorRef) -> F64 { + let sum = self + .0 + .iter() + .zip(other.0.iter()) + .map(|(a, b)| a.0 * b.0) + .sum::(); + F64::from(sum) + } +} + /// An error which can be returned when parsing a blob. #[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)] pub enum ParseVectorError { diff --git a/tests/sql/vector.slt b/tests/sql/vector.slt new file mode 100644 index 00000000..a016fc7f --- /dev/null +++ b/tests/sql/vector.slt @@ -0,0 +1,15 @@ +# vector +statement ok +create table t (a vector(3) not null); + +statement ok +insert into t values ('[-1, -2.0, -3]'), ('[1, 2.0, 3]'); + +query RRR +select a <-> '[0, 0, 0]'::VECTOR(3), a <=> '[1, 1, 1]'::VECTOR(3), a <#> '[1, 1, 1]'::VECTOR(3) from t; +---- +3.7416573867739413 1.9258200997725514 6 +3.7416573867739413 0.07417990022744858 -6 + +statement ok +drop table t