From 57422cb2ed7b763d83ca77a13e4e9f031a694d96 Mon Sep 17 00:00:00 2001 From: zhuwenxing Date: Mon, 2 Sep 2024 17:49:02 +0800 Subject: [PATCH 1/3] test: add array inverted index function test (#35874) /kind improvement --------- Signed-off-by: zhuwenxing --- tests/python_client/common/common_func.py | 142 +++++++++++++++++++ tests/python_client/testcases/test_query.py | 70 +++++++-- tests/python_client/testcases/test_search.py | 56 +++++++- 3 files changed, 254 insertions(+), 14 deletions(-) diff --git a/tests/python_client/common/common_func.py b/tests/python_client/common/common_func.py index d147635f4472c..fbe89d171511c 100644 --- a/tests/python_client/common/common_func.py +++ b/tests/python_client/common/common_func.py @@ -63,6 +63,148 @@ def prepare_param_info(self, host, port, handler, replica_num, user, password, s param_info = ParamInfo() +def generate_array_dataset(size, array_length, hit_probabilities, target_values): + dataset = [] + target_array_length = target_values.get('array_length_field', None) + target_array_access = target_values.get('array_access', None) + all_target_values = set( + val for sublist in target_values.values() for val in (sublist if isinstance(sublist, list) else [sublist])) + for i in range(size): + entry = {"id": i} + + # Generate random arrays for each condition + for condition in hit_probabilities.keys(): + available_values = [val for val in range(1, 100) if val not in all_target_values] + array = random.sample(available_values, array_length) + + # Ensure the array meets the condition based on its probability + if random.random() < hit_probabilities[condition]: + if condition == 'contains': + if target_values[condition] not in array: + array[random.randint(0, array_length - 1)] = target_values[condition] + elif condition == 'contains_any': + if not any(val in array for val in target_values[condition]): + array[random.randint(0, array_length - 1)] = random.choice(target_values[condition]) + elif condition == 'contains_all': + indices = random.sample(range(array_length), len(target_values[condition])) + for idx, val in zip(indices, target_values[condition]): + array[idx] = val + elif condition == 'equals': + array = target_values[condition][:] + elif condition == 'array_length_field': + array = [random.randint(0, 10) for _ in range(target_array_length)] + elif condition == 'array_access': + array = [random.randint(0, 10) for _ in range(random.randint(10, 20))] + array[target_array_access[0]] = target_array_access[1] + else: + raise ValueError(f"Unknown condition: {condition}") + + entry[condition] = array + + dataset.append(entry) + + return dataset + +def prepare_array_test_data(data_size, hit_rate=0.005, dim=128): + size = data_size # Number of arrays in the dataset + array_length = 10 # Length of each array + + # Probabilities that an array hits the target condition + hit_probabilities = { + 'contains': hit_rate, + 'contains_any': hit_rate, + 'contains_all': hit_rate, + 'equals': hit_rate, + 'array_length_field': hit_rate, + 'array_access': hit_rate + } + + # Target values for each condition + target_values = { + 'contains': 42, + 'contains_any': [21, 37, 42], + 'contains_all': [15, 30], + 'equals': [1,2,3,4,5], + 'array_length_field': 5, # array length == 5 + 'array_access': [0, 5] # index=0, and value == 5 + } + + # Generate dataset + dataset = generate_array_dataset(size, array_length, hit_probabilities, target_values) + data = { + "id": pd.Series([x["id"] for x in dataset]), + "contains": pd.Series([x["contains"] for x in dataset]), + "contains_any": pd.Series([x["contains_any"] for x in dataset]), + "contains_all": pd.Series([x["contains_all"] for x in dataset]), + "equals": pd.Series([x["equals"] for x in dataset]), + "array_length_field": pd.Series([x["array_length_field"] for x in dataset]), + "array_access": pd.Series([x["array_access"] for x in dataset]), + "emb": pd.Series([np.array([random.random() for j in range(dim)], dtype=np.dtype("float32")) for _ in + range(size)]) + } + # Define testing conditions + contains_value = target_values['contains'] + contains_any_values = target_values['contains_any'] + contains_all_values = target_values['contains_all'] + equals_array = target_values['equals'] + + # Perform tests + contains_result = [d for d in dataset if contains_value in d["contains"]] + contains_any_result = [d for d in dataset if any(val in d["contains_any"] for val in contains_any_values)] + contains_all_result = [d for d in dataset if all(val in d["contains_all"] for val in contains_all_values)] + equals_result = [d for d in dataset if d["equals"] == equals_array] + array_length_result = [d for d in dataset if len(d["array_length_field"]) == target_values['array_length_field']] + array_access_result = [d for d in dataset if d["array_access"][0] == target_values['array_access'][1]] + # Calculate and log.info proportions + contains_ratio = len(contains_result) / size + contains_any_ratio = len(contains_any_result) / size + contains_all_ratio = len(contains_all_result) / size + equals_ratio = len(equals_result) / size + array_length_ratio = len(array_length_result) / size + array_access_ratio = len(array_access_result) / size + + log.info(f"\nProportion of arrays that contain the value: {contains_ratio}") + log.info(f"Proportion of arrays that contain any of the values: {contains_any_ratio}") + log.info(f"Proportion of arrays that contain all of the values: {contains_all_ratio}") + log.info(f"Proportion of arrays that equal the target array: {equals_ratio}") + log.info(f"Proportion of arrays that have the target array length: {array_length_ratio}") + log.info(f"Proportion of arrays that have the target array access: {array_access_ratio}") + + + + train_df = pd.DataFrame(data) + + target_id = { + "contains": [r["id"] for r in contains_result], + "contains_any": [r["id"] for r in contains_any_result], + "contains_all": [r["id"] for r in contains_all_result], + "equals": [r["id"] for r in equals_result], + "array_length": [r["id"] for r in array_length_result], + "array_access": [r["id"] for r in array_access_result] + } + target_id_list = [target_id[key] for key in ["contains", "contains_any", "contains_all", "equals", "array_length", "array_access"]] + + + filters = [ + "array_contains(contains, 42)", + "array_contains_any(contains_any, [21, 37, 42])", + "array_contains_all(contains_all, [15, 30])", + "equals == [1,2,3,4,5]", + "array_length(array_length_field) == 5", + "array_access[0] == 5" + + ] + query_expr = [] + for i in range(len(filters)): + item = { + "expr": filters[i], + "ground_truth": target_id_list[i], + } + query_expr.append(item) + return train_df, query_expr + + + def gen_unique_str(str_value=None): prefix = "".join(random.choice(string.ascii_letters + string.digits) for _ in range(8)) return "test_" + prefix if str_value is None else str_value + "_" + prefix diff --git a/tests/python_client/testcases/test_query.py b/tests/python_client/testcases/test_query.py index 435483e0e0268..41f15b08b402b 100644 --- a/tests/python_client/testcases/test_query.py +++ b/tests/python_client/testcases/test_query.py @@ -7,6 +7,10 @@ from common.code_mapping import ConnectionErrorMessage as cem from base.client_base import TestcaseBase from pymilvus.orm.types import CONSISTENCY_STRONG, CONSISTENCY_BOUNDED, CONSISTENCY_EVENTUALLY +from pymilvus import ( + FieldSchema, CollectionSchema, DataType, + Collection +) import threading from pymilvus import DefaultConfig from datetime import datetime @@ -1520,7 +1524,7 @@ def test_query_invalid_output_fields(self): def test_query_output_fields_simple_wildcard(self): """ target: test query output_fields with simple wildcard (* and %) - method: specify output_fields as "*" + method: specify output_fields as "*" expected: output all scale field; output all fields """ # init collection with fields: int64, float, float_vec, float_vector1 @@ -2566,7 +2570,7 @@ def test_query_multi_logical_exprs(self): """ target: test the scenario which query with many logical expressions method: 1. create collection - 3. query the expr that like: int64 == 0 || int64 == 1 ........ + 3. query the expr that like: int64 == 0 || int64 == 1 ........ expected: run successfully """ c_name = cf.gen_unique_str(prefix) @@ -2577,14 +2581,14 @@ def test_query_multi_logical_exprs(self): collection_w.load() multi_exprs = " || ".join(f'{default_int_field_name} == {i}' for i in range(60)) _, check_res = collection_w.query(multi_exprs, output_fields=[f'{default_int_field_name}']) - assert(check_res == True) + assert(check_res == True) @pytest.mark.tags(CaseLabel.L0) def test_search_multi_logical_exprs(self): """ target: test the scenario which search with many logical expressions method: 1. create collection - 3. search with the expr that like: int64 == 0 || int64 == 1 ........ + 3. search with the expr that like: int64 == 0 || int64 == 1 ........ expected: run successfully """ c_name = cf.gen_unique_str(prefix) @@ -2593,15 +2597,15 @@ def test_search_multi_logical_exprs(self): collection_w.insert(df) collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() - + multi_exprs = " || ".join(f'{default_int_field_name} == {i}' for i in range(60)) - + collection_w.load() vectors_s = [[random.random() for _ in range(ct.default_dim)] for _ in range(ct.default_nq)] limit = 1000 _, check_res = collection_w.search(vectors_s[:ct.default_nq], ct.default_float_vec_field_name, ct.default_search_params, limit, multi_exprs) - assert(check_res == True) + assert(check_res == True) class TestQueryString(TestcaseBase): """ @@ -2947,8 +2951,8 @@ def test_query_string_field_not_primary_is_empty(self): @pytest.mark.tags(CaseLabel.L2) def test_query_with_create_diskann_index(self): """ - target: test query after create diskann index - method: create a collection and build diskann index + target: test query after create diskann index + method: create a collection and build diskann index expected: verify query result """ collection_w, vectors = self.init_collection_general(prefix, insert_data=True, is_index=False)[0:2] @@ -2968,8 +2972,8 @@ def test_query_with_create_diskann_index(self): @pytest.mark.tags(CaseLabel.L2) def test_query_with_create_diskann_with_string_pk(self): """ - target: test query after create diskann index - method: create a collection with string pk and build diskann index + target: test query after create diskann index + method: create a collection with string pk and build diskann index expected: verify query result """ collection_w, vectors = self.init_collection_general(prefix, insert_data=True, @@ -2986,7 +2990,7 @@ def test_query_with_create_diskann_with_string_pk(self): @pytest.mark.tags(CaseLabel.L1) def test_query_with_scalar_field(self): """ - target: test query with Scalar field + target: test query with Scalar field method: create collection , string field is primary collection load and insert empty data with string field collection query uses string expr in string field @@ -3015,6 +3019,48 @@ def test_query_with_scalar_field(self): res, _ = collection_w.query(expr, output_fields=output_fields) assert len(res) == 4 +class TestQueryArray(TestcaseBase): + + @pytest.mark.tags(CaseLabel.L1) + @pytest.mark.parametrize("array_element_data_type", [DataType.INT64]) + def test_query_array_with_inverted_index(self, array_element_data_type): + # create collection + additional_params = {"max_length": 1000} if array_element_data_type == DataType.VARCHAR else {} + fields = [ + FieldSchema(name="id", dtype=DataType.INT64, is_primary=True), + FieldSchema(name="contains", dtype=DataType.ARRAY, element_type=array_element_data_type, max_capacity=2000, **additional_params), + FieldSchema(name="contains_any", dtype=DataType.ARRAY, element_type=array_element_data_type, + max_capacity=2000, **additional_params), + FieldSchema(name="contains_all", dtype=DataType.ARRAY, element_type=array_element_data_type, + max_capacity=2000, **additional_params), + FieldSchema(name="equals", dtype=DataType.ARRAY, element_type=array_element_data_type, max_capacity=2000, **additional_params), + FieldSchema(name="array_length_field", dtype=DataType.ARRAY, element_type=array_element_data_type, + max_capacity=2000, **additional_params), + FieldSchema(name="array_access", dtype=DataType.ARRAY, element_type=array_element_data_type, + max_capacity=2000, **additional_params), + FieldSchema(name="emb", dtype=DataType.FLOAT_VECTOR, dim=128) + ] + schema = CollectionSchema(fields=fields, description="test collection", enable_dynamic_field=True) + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix), schema=schema) + # insert data + train_data, query_expr = cf.prepare_array_test_data(3000, hit_rate=0.05) + collection_w.insert(train_data) + index_params = {"metric_type": "L2", "index_type": "HNSW", "params": {"M": 48, "efConstruction": 500}} + collection_w.create_index("emb", index_params=index_params) + for f in ["contains", "contains_any", "contains_all", "equals", "array_length_field", "array_access"]: + collection_w.create_index(f, {"index_type": "INVERTED"}) + collection_w.load() + + for item in query_expr: + expr = item["expr"] + ground_truth = item["ground_truth"] + res, _ = collection_w.query( + expr=expr, + output_fields=["*"], + ) + assert len(res) == len(ground_truth) + for i in range(len(res)): + assert res[i]["id"] == ground_truth[i] class TestQueryCount(TestcaseBase): diff --git a/tests/python_client/testcases/test_search.py b/tests/python_client/testcases/test_search.py index 16f03a1a58729..9a8e9bb183ff3 100644 --- a/tests/python_client/testcases/test_search.py +++ b/tests/python_client/testcases/test_search.py @@ -1,6 +1,10 @@ import numpy as np from pymilvus.orm.types import CONSISTENCY_STRONG, CONSISTENCY_BOUNDED, CONSISTENCY_SESSION, CONSISTENCY_EVENTUALLY from pymilvus import AnnSearchRequest, RRFRanker, WeightedRanker +from pymilvus import ( + FieldSchema, CollectionSchema, DataType, + Collection +) from common.constants import * from utils.util_pymilvus import * from common.common_type import CaseLabel, CheckTasks @@ -5237,7 +5241,7 @@ def test_enable_mmap_search_for_binary_indexes(self, index): class TestSearchDSL(TestcaseBase): @pytest.mark.tags(CaseLabel.L0) - def test_query_vector_only(self): + def test_search_vector_only(self): """ target: test search normal scenario method: search vector only @@ -5254,6 +5258,54 @@ def test_query_vector_only(self): check_items={"nq": nq, "ids": insert_ids, "limit": ct.default_top_k}) +class TestSearchArray(TestcaseBase): + + @pytest.mark.tags(CaseLabel.L1) + @pytest.mark.parametrize("array_element_data_type", [DataType.INT64]) + def test_search_array_with_inverted_index(self, array_element_data_type): + # create collection + additional_params = {"max_length": 1000} if array_element_data_type == DataType.VARCHAR else {} + fields = [ + FieldSchema(name="id", dtype=DataType.INT64, is_primary=True), + FieldSchema(name="contains", dtype=DataType.ARRAY, element_type=array_element_data_type, max_capacity=2000, **additional_params), + FieldSchema(name="contains_any", dtype=DataType.ARRAY, element_type=array_element_data_type, + max_capacity=2000, **additional_params), + FieldSchema(name="contains_all", dtype=DataType.ARRAY, element_type=array_element_data_type, + max_capacity=2000, **additional_params), + FieldSchema(name="equals", dtype=DataType.ARRAY, element_type=array_element_data_type, max_capacity=2000, **additional_params), + FieldSchema(name="array_length_field", dtype=DataType.ARRAY, element_type=array_element_data_type, + max_capacity=2000, **additional_params), + FieldSchema(name="array_access", dtype=DataType.ARRAY, element_type=array_element_data_type, + max_capacity=2000, **additional_params), + FieldSchema(name="emb", dtype=DataType.FLOAT_VECTOR, dim=128) + ] + schema = CollectionSchema(fields=fields, description="test collection", enable_dynamic_field=True) + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix), schema=schema) + # insert data + train_data, query_expr = cf.prepare_array_test_data(3000, hit_rate=0.05) + collection_w.insert(train_data) + index_params = {"metric_type": "L2", "index_type": "HNSW", "params": {"M": 48, "efConstruction": 500}} + collection_w.create_index("emb", index_params=index_params) + for f in ["contains", "contains_any", "contains_all", "equals", "array_length_field", "array_access"]: + collection_w.create_index(f, {"index_type": "INVERTED"}) + collection_w.load() + + for item in query_expr: + expr = item["expr"] + ground_truth_candidate = item["ground_truth"] + res, _ = collection_w.search( + data = [np.array([random.random() for j in range(128)], dtype=np.dtype("float32"))], + anns_field="emb", + param={"metric_type": "L2", "params": {"M": 32, "efConstruction": 360}}, + limit=10, + expr=expr, + output_fields=["*"], + ) + assert len(res) == 1 + for i in range(len(res)): + assert len(res[i]) == 10 + for hit in res[i]: + assert hit.id in ground_truth_candidate class TestSearchString(TestcaseBase): @@ -12869,4 +12921,4 @@ def test_sparse_vector_search_iterator(self, index): collection_w.search_iterator(data[-1][-1:], ct.default_sparse_vec_field_name, ct.default_sparse_search_params, batch_size, check_task=CheckTasks.check_search_iterator, - check_items={"batch_size": batch_size}) \ No newline at end of file + check_items={"batch_size": batch_size}) From 4641fd9195d71d0a67c3a61f2b1d3bc6404c4cbf Mon Sep 17 00:00:00 2001 From: Chun Han <116052805+MrPresent-Han@users.noreply.github.com> Date: Mon, 2 Sep 2024 18:25:03 +0800 Subject: [PATCH 2/3] enhance: make search groupby stop when reaching topk groups (#35814) related: #33544 Signed-off-by: MrPresent-Han Co-authored-by: MrPresent-Han --- internal/core/src/common/QueryInfo.h | 1 + internal/core/src/query/PlanProto.cpp | 1 + .../query/groupby/SearchGroupByOperator.cpp | 15 ++++++++-- .../src/query/groupby/SearchGroupByOperator.h | 28 ++++++++++++++----- internal/core/unittest/test_group_by.cpp | 2 ++ internal/proto/plan.proto | 1 + internal/proxy/search_util.go | 24 ++++++++++++---- internal/proxy/task.go | 1 + 8 files changed, 57 insertions(+), 16 deletions(-) diff --git a/internal/core/src/common/QueryInfo.h b/internal/core/src/common/QueryInfo.h index 31785ea365183..440194d33c9f7 100644 --- a/internal/core/src/common/QueryInfo.h +++ b/internal/core/src/common/QueryInfo.h @@ -27,6 +27,7 @@ namespace milvus { struct SearchInfo { int64_t topk_{0}; int64_t group_size_{1}; + bool group_strict_size_{false}; int64_t round_decimal_{0}; FieldId field_id_; MetricType metric_type_; diff --git a/internal/core/src/query/PlanProto.cpp b/internal/core/src/query/PlanProto.cpp index 170b0d120c85b..7964c8df9c214 100644 --- a/internal/core/src/query/PlanProto.cpp +++ b/internal/core/src/query/PlanProto.cpp @@ -212,6 +212,7 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) { search_info.group_size_ = query_info_proto.group_size() > 0 ? query_info_proto.group_size() : 1; + search_info.group_strict_size_ = query_info_proto.group_strict_size(); } auto plan_node = [&]() -> std::unique_ptr { diff --git a/internal/core/src/query/groupby/SearchGroupByOperator.cpp b/internal/core/src/query/groupby/SearchGroupByOperator.cpp index 7b04f9cd2faff..1650e55a8e173 100644 --- a/internal/core/src/query/groupby/SearchGroupByOperator.cpp +++ b/internal/core/src/query/groupby/SearchGroupByOperator.cpp @@ -44,6 +44,7 @@ SearchGroupBy(const std::vector>& iterators, GroupIteratorsByType(iterators, search_info.topk_, search_info.group_size_, + search_info.group_strict_size_, *dataGetter, group_by_values, seg_offsets, @@ -58,6 +59,7 @@ SearchGroupBy(const std::vector>& iterators, GroupIteratorsByType(iterators, search_info.topk_, search_info.group_size_, + search_info.group_strict_size_, *dataGetter, group_by_values, seg_offsets, @@ -72,6 +74,7 @@ SearchGroupBy(const std::vector>& iterators, GroupIteratorsByType(iterators, search_info.topk_, search_info.group_size_, + search_info.group_strict_size_, *dataGetter, group_by_values, seg_offsets, @@ -86,6 +89,7 @@ SearchGroupBy(const std::vector>& iterators, GroupIteratorsByType(iterators, search_info.topk_, search_info.group_size_, + search_info.group_strict_size_, *dataGetter, group_by_values, seg_offsets, @@ -99,6 +103,7 @@ SearchGroupBy(const std::vector>& iterators, GroupIteratorsByType(iterators, search_info.topk_, search_info.group_size_, + search_info.group_strict_size_, *dataGetter, group_by_values, seg_offsets, @@ -113,6 +118,7 @@ SearchGroupBy(const std::vector>& iterators, GroupIteratorsByType(iterators, search_info.topk_, search_info.group_size_, + search_info.group_strict_size_, *dataGetter, group_by_values, seg_offsets, @@ -136,6 +142,7 @@ GroupIteratorsByType( const std::vector>& iterators, int64_t topK, int64_t group_size, + bool group_strict_size, const DataGetter& data_getter, std::vector& group_by_values, std::vector& seg_offsets, @@ -147,6 +154,7 @@ GroupIteratorsByType( GroupIteratorResult(iterator, topK, group_size, + group_strict_size, data_getter, group_by_values, seg_offsets, @@ -161,13 +169,14 @@ void GroupIteratorResult(const std::shared_ptr& iterator, int64_t topK, int64_t group_size, + bool group_strict_size, const DataGetter& data_getter, std::vector& group_by_values, std::vector& offsets, std::vector& distances, const knowhere::MetricType& metrics_type) { //1. - GroupByMap groupMap(topK, group_size); + GroupByMap groupMap(topK, group_size, group_strict_size); //2. do iteration until fill the whole map or run out of all data //note it may enumerate all data inside a segment and can block following @@ -195,8 +204,8 @@ GroupIteratorResult(const std::shared_ptr& iterator, //4. save groupBy results for (auto iter = res.cbegin(); iter != res.cend(); iter++) { - offsets.push_back(std::get<0>(*iter)); - distances.push_back(std::get<1>(*iter)); + offsets.emplace_back(std::get<0>(*iter)); + distances.emplace_back(std::get<1>(*iter)); group_by_values.emplace_back(std::move(std::get<2>(*iter))); } } diff --git a/internal/core/src/query/groupby/SearchGroupByOperator.h b/internal/core/src/query/groupby/SearchGroupByOperator.h index dfc51d318ebc6..f3513ab882bd4 100644 --- a/internal/core/src/query/groupby/SearchGroupByOperator.h +++ b/internal/core/src/query/groupby/SearchGroupByOperator.h @@ -182,6 +182,7 @@ GroupIteratorsByType( const std::vector>& iterators, int64_t topK, int64_t group_size, + bool group_strict_size, const DataGetter& data_getter, std::vector& group_by_values, std::vector& seg_offsets, @@ -195,19 +196,31 @@ struct GroupByMap { std::unordered_map group_map_{}; int group_capacity_{0}; int group_size_{0}; - int enough_group_count{0}; + int enough_group_count_{0}; + bool strict_group_size_{false}; public: - GroupByMap(int group_capacity, int group_size) - : group_capacity_(group_capacity), group_size_(group_size){}; + GroupByMap(int group_capacity, + int group_size, + bool strict_group_size = false) + : group_capacity_(group_capacity), + group_size_(group_size), + strict_group_size_(strict_group_size){}; bool IsGroupResEnough() { - return group_map_.size() == group_capacity_ && - enough_group_count == group_capacity_; + bool enough = false; + if (strict_group_size_) { + enough = group_map_.size() == group_capacity_ && + enough_group_count_ == group_capacity_; + } else { + enough = group_map_.size() == group_capacity_; + } + return enough; } bool Push(const T& t) { - if (group_map_.size() >= group_capacity_ && group_map_[t] == 0) { + if (group_map_.size() >= group_capacity_ && + group_map_.find(t) == group_map_.end()) { return false; } if (group_map_[t] >= group_size_) { @@ -218,7 +231,7 @@ struct GroupByMap { } group_map_[t] += 1; if (group_map_[t] >= group_size_) { - enough_group_count += 1; + enough_group_count_ += 1; } return true; } @@ -229,6 +242,7 @@ void GroupIteratorResult(const std::shared_ptr& iterator, int64_t topK, int64_t group_size, + bool group_strict_size, const DataGetter& data_getter, std::vector& group_by_values, std::vector& offsets, diff --git a/internal/core/unittest/test_group_by.cpp b/internal/core/unittest/test_group_by.cpp index c06ff6e558c69..0d334cbe51382 100644 --- a/internal/core/unittest/test_group_by.cpp +++ b/internal/core/unittest/test_group_by.cpp @@ -474,6 +474,7 @@ TEST(GroupBY, SealedData) { search_params: "{\"ef\": 10}" group_by_field_id: 101, group_size: 5, + group_strict_size: true, > placeholder_tag: "$0" @@ -796,6 +797,7 @@ TEST(GroupBY, GrowingIndex) { search_params: "{\"ef\": 10}" group_by_field_id: 101 group_size: 3 + group_strict_size: true > placeholder_tag: "$0" diff --git a/internal/proto/plan.proto b/internal/proto/plan.proto index e551a242b50b0..e9d19f0193eb0 100644 --- a/internal/proto/plan.proto +++ b/internal/proto/plan.proto @@ -62,6 +62,7 @@ message QueryInfo { int64 group_by_field_id = 6; bool materialized_view_involved = 7; int64 group_size = 8; + bool group_strict_size = 9; } message ColumnInfo { diff --git a/internal/proxy/search_util.go b/internal/proxy/search_util.go index 06f2ff4a0ae95..ad6a90dd08857 100644 --- a/internal/proxy/search_util.go +++ b/internal/proxy/search_util.go @@ -129,6 +129,17 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb } } + var groupStrictSize bool + groupStrictSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupStrictSize, searchParamsPair) + if err != nil { + groupStrictSize = false + } else { + groupStrictSize, err = strconv.ParseBool(groupStrictSizeStr) + if err != nil { + groupStrictSize = false + } + } + // 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search if isIterator == "True" && groupByFieldId > 0 { return nil, 0, merr.WrapErrParameterInvalid("", "", @@ -140,12 +151,13 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb } return &planpb.QueryInfo{ - Topk: queryTopK, - MetricType: metricType, - SearchParams: searchParamStr, - RoundDecimal: roundDecimal, - GroupByFieldId: groupByFieldId, - GroupSize: groupSize, + Topk: queryTopK, + MetricType: metricType, + SearchParams: searchParamStr, + RoundDecimal: roundDecimal, + GroupByFieldId: groupByFieldId, + GroupSize: groupSize, + GroupStrictSize: groupStrictSize, }, offset, nil } diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 71bb162789a95..2b4fad685a153 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -48,6 +48,7 @@ const ( IteratorField = "iterator" GroupByFieldKey = "group_by_field" GroupSizeKey = "group_size" + GroupStrictSize = "group_strict_size" AnnsFieldKey = "anns_field" TopKKey = "topk" NQKey = "nq" From 3698c53a72473e69a7496b31b6602a9c0a5b3d92 Mon Sep 17 00:00:00 2001 From: congqixia Date: Mon, 2 Sep 2024 18:39:03 +0800 Subject: [PATCH 3/3] enhance: Check load fields for previous loaded collection (#35905) Related to #35415 This PR make querycoord report error when load request tries to update load fields list, which is currently not supported. Signed-off-by: Congqi Xia --- internal/querycoordv2/job/job_load.go | 17 +++++++++ internal/querycoordv2/job/job_test.go | 54 +++++++++++++++++++++++++++ 2 files changed, 71 insertions(+) diff --git a/internal/querycoordv2/job/job_load.go b/internal/querycoordv2/job/job_load.go index 4ade22ee48c73..f17d64ea8be28 100644 --- a/internal/querycoordv2/job/job_load.go +++ b/internal/querycoordv2/job/job_load.go @@ -19,6 +19,7 @@ package job import ( "context" "fmt" + "reflect" "time" "github.com/cockroachdb/errors" @@ -104,6 +105,14 @@ func (job *LoadCollectionJob) PreExecute() error { return merr.WrapErrParameterInvalid(collection.GetReplicaNumber(), req.GetReplicaNumber(), "can't change the replica number for loaded collection") } + if !reflect.DeepEqual(collection.GetLoadFields(), req.GetLoadFields()) { + log.Warn("collection with different load field list exists, release this collection first before chaning its replica number", + zap.Int64s("loadedFieldIDs", collection.GetLoadFields()), + zap.Int64s("reqFieldIDs", req.GetLoadFields()), + ) + return merr.WrapErrParameterInvalid(collection.GetLoadFields(), req.GetLoadFields(), "can't change the load field list for loaded collection") + } + return nil } @@ -289,6 +298,14 @@ func (job *LoadPartitionJob) PreExecute() error { return merr.WrapErrParameterInvalid(collection.GetReplicaNumber(), req.GetReplicaNumber(), "can't change the replica number for loaded partitions") } + if !reflect.DeepEqual(collection.GetLoadFields(), req.GetLoadFields()) { + log.Warn("collection with different load field list exists, release this collection first before chaning its replica number", + zap.Int64s("loadedFieldIDs", collection.GetLoadFields()), + zap.Int64s("reqFieldIDs", req.GetLoadFields()), + ) + return merr.WrapErrParameterInvalid(collection.GetLoadFields(), req.GetLoadFields(), "can't change the load field list for loaded collection") + } + return nil } diff --git a/internal/querycoordv2/job/job_test.go b/internal/querycoordv2/job/job_test.go index 276234990de1b..e919d28b7a240 100644 --- a/internal/querycoordv2/job/job_test.go +++ b/internal/querycoordv2/job/job_test.go @@ -307,6 +307,32 @@ func (suite *JobSuite) TestLoadCollection() { suite.ErrorIs(err, merr.ErrParameterInvalid) } + // Test load existed collection with different load fields + for _, collection := range suite.collections { + if suite.loadTypes[collection] != querypb.LoadType_LoadCollection { + continue + } + req := &querypb.LoadCollectionRequest{ + CollectionID: collection, + LoadFields: []int64{100, 101}, + } + job := NewLoadCollectionJob( + ctx, + req, + suite.dist, + suite.meta, + suite.broker, + suite.cluster, + suite.targetMgr, + suite.targetObserver, + suite.collectionObserver, + suite.nodeMgr, + ) + suite.scheduler.Add(job) + err := job.Wait() + suite.ErrorIs(err, merr.ErrParameterInvalid) + } + // Test load partition while collection exists for _, collection := range suite.collections { if suite.loadTypes[collection] != querypb.LoadType_LoadCollection { @@ -514,6 +540,34 @@ func (suite *JobSuite) TestLoadPartition() { suite.ErrorIs(err, merr.ErrParameterInvalid) } + // Test load partition with different load fields + for _, collection := range suite.collections { + if suite.loadTypes[collection] != querypb.LoadType_LoadPartition { + continue + } + + req := &querypb.LoadPartitionsRequest{ + CollectionID: collection, + PartitionIDs: suite.partitions[collection], + LoadFields: []int64{100, 101}, + } + job := NewLoadPartitionJob( + ctx, + req, + suite.dist, + suite.meta, + suite.broker, + suite.cluster, + suite.targetMgr, + suite.targetObserver, + suite.collectionObserver, + suite.nodeMgr, + ) + suite.scheduler.Add(job) + err := job.Wait() + suite.ErrorIs(err, merr.ErrParameterInvalid) + } + // Test load partition with more partition for _, collection := range suite.collections { if suite.loadTypes[collection] != querypb.LoadType_LoadPartition {