Skip to content

Commit

Permalink
feat(eval): support vector ops (#867)
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Chi Z <[email protected]>
Signed-off-by: Alex Chi <[email protected]>
  • Loading branch information
skyzh authored Jan 11, 2025
1 parent fe51708 commit a590c6a
Show file tree
Hide file tree
Showing 12 changed files with 186 additions and 16 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[build-system]
requires = ["maturin>=0.13,<0.14"]
requires = ["maturin>=1,<2"]
build-backend = "maturin"

[project]
Expand Down
66 changes: 66 additions & 0 deletions src/array/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,72 @@ impl ArrayImpl {
};
Ok(A::new_string(unary_op(a.as_ref(), |s| s.replace(from, to))))
}

pub fn vector_l2_distance(&self, other: &ArrayImpl) -> Result {
let ArrayImpl::Vector(a) = self else {
return Err(ConvertError::NoBinaryOp(
"vector_l2_distance".into(),
self.type_string(),
other.type_string(),
));
};
let ArrayImpl::Vector(b) = other else {
return Err(ConvertError::NoBinaryOp(
"vector_l2_distance".into(),
other.type_string(),
self.type_string(),
));
};
Ok(ArrayImpl::new_float64(binary_op(
a.as_ref(),
b.as_ref(),
|a, b| a.l2_distance(b),
)))
}

pub fn vector_cosine_distance(&self, other: &ArrayImpl) -> Result {
let ArrayImpl::Vector(a) = self else {
return Err(ConvertError::NoBinaryOp(
"vector_cosine_distance".into(),
self.type_string(),
other.type_string(),
));
};
let ArrayImpl::Vector(b) = other else {
return Err(ConvertError::NoBinaryOp(
"vector_cosine_distance".into(),
other.type_string(),
self.type_string(),
));
};
Ok(ArrayImpl::new_float64(binary_op(
a.as_ref(),
b.as_ref(),
|a, b| a.cosine_distance(b),
)))
}

pub fn vector_neg_inner_product(&self, other: &ArrayImpl) -> Result {
let ArrayImpl::Vector(a) = self else {
return Err(ConvertError::NoBinaryOp(
"vector_neg_inner_product".into(),
self.type_string(),
other.type_string(),
));
};
let ArrayImpl::Vector(b) = other else {
return Err(ConvertError::NoBinaryOp(
"vector_neg_inner_product".into(),
other.type_string(),
self.type_string(),
));
};
Ok(ArrayImpl::new_float64(binary_op(
a.as_ref(),
b.as_ref(),
|a, b| -a.dot_product(b),
)))
}
}

/// Implement aggregation functions.
Expand Down
6 changes: 6 additions & 0 deletions src/binder/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,12 @@ impl Binder {
And => Node::And([l, r]),
Or => Node::Or([l, r]),
Xor => Node::Xor([l, r]),
Spaceship => Node::VectorCosineDistance([l, r]),
Custom(name) => match name.as_str() {
"<->" => Node::VectorL2Distance([l, r]),
"<#>" => Node::VectorNegtiveInnerProduct([l, r]),
op => todo!("bind custom binary op: {:?}", op),
},
_ => todo!("bind binary op: {:?}", op),
};
Ok(self.egraph.add(node))
Expand Down
15 changes: 15 additions & 0 deletions src/executor/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,21 @@ impl<'a> Evaluator<'a> {
};
a.replace(from, to)
}
VectorL2Distance([a, b]) => {
let a = self.next(*a).eval(chunk)?;
let b = self.next(*b).eval(chunk)?;
a.vector_l2_distance(&b)
}
VectorCosineDistance([a, b]) => {
let a = self.next(*a).eval(chunk)?;
let b = self.next(*b).eval(chunk)?;
a.vector_cosine_distance(&b)
}
VectorNegtiveInnerProduct([a, b]) => {
let a = self.next(*a).eval(chunk)?;
let b = self.next(*b).eval(chunk)?;
a.vector_neg_inner_product(&b)
}
e => {
if let Some((op, a, b)) = e.binary_op() {
let left = self.next(a).eval(chunk)?;
Expand Down
14 changes: 0 additions & 14 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ pub mod types;
/// Utilities.
pub mod utils;

#[cfg(feature = "python")]
use python::open;
#[cfg(feature = "jemalloc")]
use tikv_jemallocator::Jemalloc;

Expand All @@ -63,15 +61,3 @@ pub use self::db::{Database, Error};
#[cfg(feature = "jemalloc")]
#[global_allocator]
static GLOBAL: Jemalloc = Jemalloc;

/// Python Extension
#[cfg(feature = "python")]
use pyo3::{prelude::*, wrap_pyfunction};

/// The entry point of python module must be in the lib.rs
#[cfg(feature = "python")]
#[pymodule]
fn risinglight(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(open, m)?)?;
Ok(())
}
23 changes: 23 additions & 0 deletions src/planner/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,29 @@ impl<'a> Explain<'a> {
],
),

// vector functions
VectorL2Distance([a, b]) => Pretty::childless_record(
"VectorL2Distance",
vec![
("lhs", self.expr(a).pretty()),
("rhs", self.expr(b).pretty()),
],
),
VectorCosineDistance([a, b]) => Pretty::childless_record(
"VectorCosineDistance",
vec![
("lhs", self.expr(a).pretty()),
("rhs", self.expr(b).pretty()),
],
),
VectorNegtiveInnerProduct([a, b]) => Pretty::childless_record(
"VectorDotProduct",
vec![
("lhs", self.expr(a).pretty()),
("rhs", self.expr(b).pretty()),
],
),

// aggregations
RowCount | RowNumber => enode.to_string().into(),
Max(a) | Min(a) | Sum(a) | Avg(a) | Count(a) | First(a) | Last(a)
Expand Down
5 changes: 5 additions & 0 deletions src/planner/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ define_language! {
"replace" = Replace([Id; 3]), // (replace expr pattern replacement)
"substring" = Substring([Id; 3]), // (substring expr start length)

// vector functions
"<->" = VectorL2Distance([Id; 2]),
"<#>" = VectorNegtiveInnerProduct([Id; 2]),
"<=>" = VectorCosineDistance([Id; 2]),

// aggregations
"max" = Max(Id),
"min" = Min(Id),
Expand Down
8 changes: 8 additions & 0 deletions src/planner/rules/type_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,14 @@ pub fn analyze_type(
(a == DataType::String && b == DataType::String).then_some(DataType::Bool)
}),

// vector ops
VectorL2Distance([a, b])
| VectorCosineDistance([a, b])
| VectorNegtiveInnerProduct([a, b]) => merge(enode, [x(a)?, x(b)?], |[a, b]| {
(matches!(a, DataType::Vector(_)) && matches!(b, DataType::Vector(_)))
.then_some(DataType::Float64)
}),

// bool ops
Not(a) => check(enode, x(a)?, |a| a == &DataType::Bool),
Gt([a, b]) | Lt([a, b]) | GtEq([a, b]) | LtEq([a, b]) | Eq([a, b]) | NotEq([a, b]) => {
Expand Down
8 changes: 8 additions & 0 deletions src/python/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use tokio::runtime::Runtime;

use crate::storage::SecondaryStorageOptions;
use crate::Database;

#[pyclass]
pub struct PythonDatabase {
runtime: Runtime,
Expand Down Expand Up @@ -61,6 +62,13 @@ pub fn open_in_memory() -> PyResult<PythonDatabase> {
Ok(PythonDatabase { runtime, database })
}

#[pymodule]
fn risinglight(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(open, m)?)?;
m.add_function(wrap_pyfunction!(open_in_memory, m)?)?;
Ok(())
}

use crate::types::DataValue;
/// Convert datachunk into Python List
pub fn datachunk_to_python_list(py: Python, chunk: &Chunk) -> Vec<Vec<PyObject>> {
Expand Down
2 changes: 1 addition & 1 deletion src/types/blob.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ impl fmt::Display for BlobRef {
/// A slice of a vector.
#[repr(transparent)]
#[derive(PartialEq, Eq, PartialOrd, Ord, RefCast, Hash)]
pub struct VectorRef([F64]);
pub struct VectorRef(pub(crate) [F64]);

impl VectorRef {
pub fn new(values: &[F64]) -> &Self {
Expand Down
38 changes: 38 additions & 0 deletions src/types/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,44 @@ impl Deref for Vector {
}
}

impl VectorRef {
pub fn norm_squared(&self) -> F64 {
let sum: f64 = self.0.iter().map(|a| a.powi(2)).sum();
F64::from(sum)
}

pub fn norm(&self) -> F64 {
F64::from(self.norm_squared().sqrt())
}

pub fn l2_distance(&self, other: &VectorRef) -> F64 {
let sum = self
.0
.iter()
.zip(other.0.iter())
.map(|(a, b)| (a.0 - b.0).powi(2))
.sum::<f64>();
F64::from(sum.sqrt())
}

pub fn cosine_distance(&self, other: &VectorRef) -> F64 {
let dot_product = self.dot_product(other);
let norm_self_squared = self.norm_squared();
let norm_other_squared = other.norm_squared();
F64::from(1.0) - dot_product / (norm_self_squared * norm_other_squared).sqrt()
}

pub fn dot_product(&self, other: &VectorRef) -> F64 {
let sum = self
.0
.iter()
.zip(other.0.iter())
.map(|(a, b)| a.0 * b.0)
.sum::<f64>();
F64::from(sum)
}
}

/// An error which can be returned when parsing a blob.
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
pub enum ParseVectorError {
Expand Down
15 changes: 15 additions & 0 deletions tests/sql/vector.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# vector
statement ok
create table t (a vector(3) not null);

statement ok
insert into t values ('[-1, -2.0, -3]'), ('[1, 2.0, 3]');

query RRR
select a <-> '[0, 0, 0]'::VECTOR(3), a <=> '[1, 1, 1]'::VECTOR(3), a <#> '[1, 1, 1]'::VECTOR(3) from t;
----
3.7416573867739413 1.9258200997725514 6
3.7416573867739413 0.07417990022744858 -6

statement ok
drop table t

0 comments on commit a590c6a

Please sign in to comment.