Skip to content

Commit

Permalink
feat: support phrase match query (#38869)
Browse files Browse the repository at this point in the history
The relevant issue: #38930

---------

Signed-off-by: SpadeA-Tang <[email protected]>
  • Loading branch information
SpadeA-Tang authored Jan 12, 2025
1 parent a8a6564 commit 032292a
Show file tree
Hide file tree
Showing 29 changed files with 2,001 additions and 1,380 deletions.
1 change: 1 addition & 0 deletions internal/core/src/common/EasyAssert.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ enum ErrorCode {
OutOfRange = 2039,
GcpNativeError = 2040,
TextIndexNotFound = 2041,
InvalidParameter = 2042,

KnowhereError = 2099
};
Expand Down
34 changes: 30 additions & 4 deletions internal/core/src/exec/expression/UnaryExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -796,11 +796,12 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson(OffsetVector* input) {
template <typename T>
VectorPtr
PhyUnaryRangeFilterExpr::ExecRangeVisitorImpl(OffsetVector* input) {
if (expr_->op_type_ == proto::plan::OpType::TextMatch) {
if (expr_->op_type_ == proto::plan::OpType::TextMatch ||
expr_->op_type_ == proto::plan::OpType::PhraseMatch) {
if (has_offset_input_) {
PanicInfo(
OpTypeInvalid,
fmt::format("text match does not support iterative filter"));
fmt::format("match query does not support iterative filter"));
}
return ExecTextMatch();
}
Expand Down Expand Up @@ -1089,8 +1090,33 @@ VectorPtr
PhyUnaryRangeFilterExpr::ExecTextMatch() {
using Index = index::TextMatchIndex;
auto query = GetValueFromProto<std::string>(expr_->val_);
auto func = [](Index* index, const std::string& query) -> TargetBitmap {
return index->MatchQuery(query);
int64_t slop = 0;
if (expr_->op_type_ == proto::plan::PhraseMatch) {
// It should be larger than 0 in normal cases. Check it incase of receiving old version proto.
if (expr_->extra_values_.size() > 0) {
slop = GetValueFromProto<int64_t>(expr_->extra_values_[0]);
}
if (slop < 0 || slop > std::numeric_limits<uint32_t>::max()) {
throw SegcoreError(
ErrorCode::InvalidParameter,
fmt::format(
"Slop {} is invalid in phrase match query. Should be "
"within [0, UINT32_MAX].",
slop));
}
}
auto op_type = expr_->op_type_;
auto func = [op_type, slop](Index* index,
const std::string& query) -> TargetBitmap {
if (op_type == proto::plan::OpType::TextMatch) {
return index->MatchQuery(query);
} else if (op_type == proto::plan::OpType::PhraseMatch) {
return index->PhraseMatchQuery(query, slop);
} else {
PanicInfo(OpTypeInvalid,
"unsupported operator type for match query: {}",
op_type);
}
};
auto res = ProcessTextMatchIndex(func, query);
return res;
Expand Down
26 changes: 21 additions & 5 deletions internal/core/src/expr/ITypeExpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -349,18 +349,33 @@ class ValueExpr : public ITypeExpr {

class UnaryRangeFilterExpr : public ITypeFilterExpr {
public:
explicit UnaryRangeFilterExpr(const ColumnInfo& column,
proto::plan::OpType op_type,
const proto::plan::GenericValue& val)
: ITypeFilterExpr(), column_(column), op_type_(op_type), val_(val) {
explicit UnaryRangeFilterExpr(
const ColumnInfo& column,
proto::plan::OpType op_type,
const proto::plan::GenericValue& val,
const std::vector<proto::plan::GenericValue>& extra_values)
: ITypeFilterExpr(),
column_(column),
op_type_(op_type),
val_(val),
extra_values_(extra_values) {
}

std::string
ToString() const override {
std::stringstream ss;
ss << "UnaryRangeFilterExpr: {columnInfo:" << column_.ToString()
<< " op_type:" << milvus::proto::plan::OpType_Name(op_type_)
<< " val:" << val_.DebugString() << "}";
<< " val:" << val_.DebugString() << " extra_values: [";

for (size_t i = 0; i < extra_values_.size(); i++) {
ss << extra_values_[i].DebugString();
if (i != extra_values_.size() - 1) {
ss << ", ";
}
}

ss << "]}";
return ss.str();
}

Expand Down Expand Up @@ -393,6 +408,7 @@ class UnaryRangeFilterExpr : public ITypeFilterExpr {
const ColumnInfo column_;
const proto::plan::OpType op_type_;
const proto::plan::GenericValue val_;
const std::vector<proto::plan::GenericValue> extra_values_;
};

class AlwaysTrueExpr : public ITypeFilterExpr {
Expand Down
21 changes: 21 additions & 0 deletions internal/core/src/index/TextMatchIndex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,4 +291,25 @@ TextMatchIndex::MatchQuery(const std::string& query) {
apply_hits(bitset, hits, true);
return bitset;
}

TargetBitmap
TextMatchIndex::PhraseMatchQuery(const std::string& query, uint32_t slop) {
if (shouldTriggerCommit()) {
Commit();
Reload();
}

// The count opeartion of tantivy may be get older cnt if the index is committed with new tantivy segment.
// So we cannot use the count operation to get the total count for bitmap.
// Just use the maximum offset of hits to get the total count for bitmap here.
auto hits = wrapper_->phrase_match_query(query, slop);
auto cnt = should_allocate_bitset_size(hits);
TargetBitmap bitset(cnt);
if (bitset.empty()) {
return bitset;
}
apply_hits(bitset, hits, true);
return bitset;
}

} // namespace milvus::index
3 changes: 3 additions & 0 deletions internal/core/src/index/TextMatchIndex.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ class TextMatchIndex : public InvertedIndexTantivy<std::string> {
TargetBitmap
MatchQuery(const std::string& query);

TargetBitmap
PhraseMatchQuery(const std::string& query, uint32_t slop);

private:
bool
shouldTriggerCommit();
Expand Down
9 changes: 8 additions & 1 deletion internal/core/src/query/PlanProto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,15 @@ ProtoParser::ParseUnaryRangeExprs(const proto::plan::UnaryRangeExpr& expr_pb) {
auto field_id = FieldId(column_info.field_id());
auto data_type = schema[field_id].get_data_type();
Assert(data_type == static_cast<DataType>(column_info.data_type()));
std::vector<::milvus::proto::plan::GenericValue> extra_values;
for (auto val : expr_pb.extra_values()) {
extra_values.emplace_back(val);
}
return std::make_shared<milvus::expr::UnaryRangeFilterExpr>(
expr::ColumnInfo(column_info), expr_pb.op(), expr_pb.value());
expr::ColumnInfo(column_info),
expr_pb.op(),
expr_pb.value(),
extra_values);
}

expr::TypedExprPtr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ RustResult tantivy_regex_query(void *ptr, const char *pattern);

RustResult tantivy_match_query(void *ptr, const char *query);

RustResult tantivy_phrase_match_query(void *ptr, const char *query, uint32_t slop);

RustResult tantivy_register_tokenizer(void *ptr,
const char *tokenizer_name,
const char *analyzer_params);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
use tantivy::{
query::BooleanQuery,
query::{BooleanQuery, PhraseQuery},
tokenizer::{TextAnalyzer, TokenStream},
Term,
};

use crate::error::Result;
use crate::error::{Result, TantivyBindingError};
use crate::{index_reader::IndexReaderWrapper, tokenizer::standard_analyzer};

impl IndexReaderWrapper {
// split the query string into multiple tokens using index's default tokenizer,
// and then execute the disconjunction of term query.
pub(crate) fn match_query(&self, q: &str) -> Result<Vec<u32>> {
// clone the tokenizer to make `match_query` thread-safe.
let mut tokenizer = self
.index
.tokenizer_for_field(self.field)
Expand All @@ -27,6 +26,31 @@ impl IndexReaderWrapper {
self.search(&query)
}

// split the query string into multiple tokens using index's default tokenizer,
// and then execute the disconjunction of term query.
pub(crate) fn phrase_match_query(&self, q: &str, slop: u32) -> Result<Vec<u32>> {
// clone the tokenizer to make `match_query` thread-safe.
let mut tokenizer = self
.index
.tokenizer_for_field(self.field)
.unwrap_or(standard_analyzer(vec![]))
.clone();
let mut token_stream = tokenizer.token_stream(q);
let mut terms: Vec<Term> = Vec::new();
while token_stream.advance() {
let token = token_stream.token();
terms.push(Term::from_field_text(self.field, &token.text));
}
if terms.len() <= 1 {
// tantivy will panic when terms.len() <= 1, so we forward to text match instead.
let query = BooleanQuery::new_multiterms_query(terms);
return self.search(&query);
}
let terms = terms.into_iter().enumerate().collect();
let phrase_query = PhraseQuery::new_with_offset_and_slop(terms, slop);
self.search(&phrase_query)
}

pub(crate) fn register_tokenizer(&self, tokenizer_name: String, tokenizer: TextAnalyzer) {
self.index.tokenizers().register(&tokenizer_name, tokenizer)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,19 @@ pub extern "C" fn tantivy_match_query(ptr: *mut c_void, query: *const c_char) ->
}
}

#[no_mangle]
pub extern "C" fn tantivy_phrase_match_query(
ptr: *mut c_void,
query: *const c_char,
slop: u32,
) -> RustResult {
let real = ptr as *mut IndexReaderWrapper;
unsafe {
let query = cstr_to_str!(query);
(*real).phrase_match_query(query, slop).into()
}
}

#[no_mangle]
pub extern "C" fn tantivy_register_tokenizer(
ptr: *mut c_void,
Expand Down
13 changes: 13 additions & 0 deletions internal/core/thirdparty/tantivy/tantivy-wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,19 @@ struct TantivyIndexWrapper {
return RustArrayWrapper(std::move(res.result_->value.rust_array._0));
}

RustArrayWrapper
phrase_match_query(const std::string& query, uint32_t slop) {
auto array = tantivy_phrase_match_query(reader_, query.c_str(), slop);
auto res = RustResultWrapper(array);
AssertInfo(res.result_->success,
"TantivyIndexWrapper.phrase_match_query: {}",
res.result_->error);
AssertInfo(
res.result_->value.tag == Value::Tag::RustArray,
"TantivyIndexWrapper.phrase_match_query: invalid result type");
return RustArrayWrapper(std::move(res.result_->value.rust_array._0));
}

public:
inline IndexWriter
get_writer() {
Expand Down
3 changes: 2 additions & 1 deletion internal/core/unittest/test_array_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2461,7 +2461,8 @@ TEST(Expr, TestArrayStringMatch) {
milvus::expr::ColumnInfo(
string_array_fid, DataType::ARRAY, testcase.nested_path),
testcase.op_type,
value);
value,
std::vector<proto::plan::GenericValue>{});
BitsetType final;
auto plan =
std::make_shared<plan::FilterBitsNode>(DEFAULT_PLANNODE_ID, expr);
Expand Down
Loading

0 comments on commit 032292a

Please sign in to comment.