Merge branch 'master' of https://github.com/milvus-io/milvus into 2408-skip-bf
bigsheeper committed Sep 2, 2024
2 parents 29e2578 + 3698c53 commit 1abff33
Showing 13 changed files with 382 additions and 30 deletions.
1 change: 1 addition & 0 deletions internal/core/src/common/QueryInfo.h
@@ -27,6 +27,7 @@ namespace milvus {
struct SearchInfo {
int64_t topk_{0};
int64_t group_size_{1};
bool group_strict_size_{false};
int64_t round_decimal_{0};
FieldId field_id_;
MetricType metric_type_;
1 change: 1 addition & 0 deletions internal/core/src/query/PlanProto.cpp
@@ -212,6 +212,7 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
search_info.group_size_ = query_info_proto.group_size() > 0
? query_info_proto.group_size()
: 1;
search_info.group_strict_size_ = query_info_proto.group_strict_size();
}

auto plan_node = [&]() -> std::unique_ptr<VectorPlanNode> {
15 changes: 12 additions & 3 deletions internal/core/src/query/groupby/SearchGroupByOperator.cpp
@@ -44,6 +44,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GroupIteratorsByType<int8_t>(iterators,
search_info.topk_,
search_info.group_size_,
search_info.group_strict_size_,
*dataGetter,
group_by_values,
seg_offsets,
@@ -58,6 +59,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GroupIteratorsByType<int16_t>(iterators,
search_info.topk_,
search_info.group_size_,
search_info.group_strict_size_,
*dataGetter,
group_by_values,
seg_offsets,
@@ -72,6 +74,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GroupIteratorsByType<int32_t>(iterators,
search_info.topk_,
search_info.group_size_,
search_info.group_strict_size_,
*dataGetter,
group_by_values,
seg_offsets,
@@ -86,6 +89,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GroupIteratorsByType<int64_t>(iterators,
search_info.topk_,
search_info.group_size_,
search_info.group_strict_size_,
*dataGetter,
group_by_values,
seg_offsets,
@@ -99,6 +103,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GroupIteratorsByType<bool>(iterators,
search_info.topk_,
search_info.group_size_,
search_info.group_strict_size_,
*dataGetter,
group_by_values,
seg_offsets,
@@ -113,6 +118,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GroupIteratorsByType<std::string>(iterators,
search_info.topk_,
search_info.group_size_,
search_info.group_strict_size_,
*dataGetter,
group_by_values,
seg_offsets,
@@ -136,6 +142,7 @@ GroupIteratorsByType(
const std::vector<std::shared_ptr<VectorIterator>>& iterators,
int64_t topK,
int64_t group_size,
bool group_strict_size,
const DataGetter<T>& data_getter,
std::vector<GroupByValueType>& group_by_values,
std::vector<int64_t>& seg_offsets,
@@ -147,6 +154,7 @@
GroupIteratorResult<T>(iterator,
topK,
group_size,
group_strict_size,
data_getter,
group_by_values,
seg_offsets,
@@ -161,13 +169,14 @@ void
GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
int64_t topK,
int64_t group_size,
bool group_strict_size,
const DataGetter<T>& data_getter,
std::vector<GroupByValueType>& group_by_values,
std::vector<int64_t>& offsets,
std::vector<float>& distances,
const knowhere::MetricType& metrics_type) {
//1.
GroupByMap<T> groupMap(topK, group_size);
GroupByMap<T> groupMap(topK, group_size, group_strict_size);

//2. do iteration until fill the whole map or run out of all data
//note it may enumerate all data inside a segment and can block following
@@ -195,8 +204,8 @@ GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,

//4. save groupBy results
for (auto iter = res.cbegin(); iter != res.cend(); iter++) {
offsets.push_back(std::get<0>(*iter));
distances.push_back(std::get<1>(*iter));
offsets.emplace_back(std::get<0>(*iter));
distances.emplace_back(std::get<1>(*iter));
group_by_values.emplace_back(std::move(std::get<2>(*iter)));
}
}
28 changes: 21 additions & 7 deletions internal/core/src/query/groupby/SearchGroupByOperator.h
@@ -182,6 +182,7 @@ GroupIteratorsByType(
const std::vector<std::shared_ptr<VectorIterator>>& iterators,
int64_t topK,
int64_t group_size,
bool group_strict_size,
const DataGetter<T>& data_getter,
std::vector<GroupByValueType>& group_by_values,
std::vector<int64_t>& seg_offsets,
@@ -195,19 +196,31 @@ struct GroupByMap {
std::unordered_map<T, int> group_map_{};
int group_capacity_{0};
int group_size_{0};
int enough_group_count{0};
int enough_group_count_{0};
bool strict_group_size_{false};

public:
GroupByMap(int group_capacity, int group_size)
: group_capacity_(group_capacity), group_size_(group_size){};
GroupByMap(int group_capacity,
int group_size,
bool strict_group_size = false)
: group_capacity_(group_capacity),
group_size_(group_size),
strict_group_size_(strict_group_size){};
bool
IsGroupResEnough() {
return group_map_.size() == group_capacity_ &&
enough_group_count == group_capacity_;
bool enough = false;
if (strict_group_size_) {
enough = group_map_.size() == group_capacity_ &&
enough_group_count_ == group_capacity_;
} else {
enough = group_map_.size() == group_capacity_;
}
return enough;
}
bool
Push(const T& t) {
if (group_map_.size() >= group_capacity_ && group_map_[t] == 0) {
if (group_map_.size() >= group_capacity_ &&
group_map_.find(t) == group_map_.end()) {
return false;
}
if (group_map_[t] >= group_size_) {
@@ -218,7 +231,7 @@ struct GroupByMap {
}
group_map_[t] += 1;
if (group_map_[t] >= group_size_) {
enough_group_count += 1;
enough_group_count_ += 1;
}
return true;
}
@@ -229,6 +242,7 @@ void
GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
int64_t topK,
int64_t group_size,
bool group_strict_size,
const DataGetter<T>& data_getter,
std::vector<GroupByValueType>& group_by_values,
std::vector<int64_t>& offsets,
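The heart of this change is GroupByMap::IsGroupResEnough: in the default (relaxed) mode, iteration may stop as soon as topk distinct groups have been seen; with group_strict_size enabled, every group must also accumulate group_size hits before the search stops. Push is additionally fixed to probe with find() instead of operator[], since operator[] inserts a zero-count entry for an unseen key and could fill the map with groups that were never accepted. A minimal, self-contained C++ sketch of these semantics (hypothetical GroupByMapSketch, not the Milvus class):

#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>

// Hypothetical illustration of the strict/relaxed stopping rule above;
// names and layout are simplified, not the Milvus implementation.
template <typename T>
struct GroupByMapSketch {
    std::unordered_map<T, int> counts_{};
    std::size_t capacity_;      // desired number of distinct groups (topk)
    int group_size_;            // desired hits per group
    bool strict_;               // require every group to be full
    std::size_t full_groups_{0};

    GroupByMapSketch(std::size_t capacity, int group_size, bool strict)
        : capacity_(capacity), group_size_(group_size), strict_(strict) {
    }

    bool IsGroupResEnough() const {
        if (strict_) {
            return counts_.size() == capacity_ && full_groups_ == capacity_;
        }
        return counts_.size() == capacity_;
    }

    bool Push(const T& key) {
        // find() rather than operator[]: a lookup on a brand-new key must
        // not insert a default-constructed zero counter into the map.
        auto it = counts_.find(key);
        if (counts_.size() >= capacity_ && it == counts_.end()) {
            return false;  // map is full and this key opens a new group
        }
        int& count = counts_[key];
        if (count >= group_size_) {
            return false;  // this group already has enough hits
        }
        if (++count >= group_size_) {
            full_groups_ += 1;
        }
        return true;
    }
};

int main() {
    GroupByMapSketch<std::string> relaxed(2, 2, /*strict=*/false);
    relaxed.Push("a");
    relaxed.Push("b");
    assert(relaxed.IsGroupResEnough());  // two groups seen; half-full is fine

    GroupByMapSketch<std::string> strict(2, 2, /*strict=*/true);
    strict.Push("a");
    strict.Push("b");
    assert(!strict.IsGroupResEnough());  // both groups exist, neither is full
    strict.Push("a");
    strict.Push("b");
    assert(strict.IsGroupResEnough());
    return 0;
}

Keeping relaxed mode as the default preserves the previous early-exit behavior, so existing group-by searches do not scan more data than before.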
2 changes: 2 additions & 0 deletions internal/core/unittest/test_group_by.cpp
@@ -474,6 +474,7 @@ TEST(GroupBY, SealedData) {
search_params: "{\"ef\": 10}"
group_by_field_id: 101,
group_size: 5,
group_strict_size: true,
>
placeholder_tag: "$0"
@@ -796,6 +797,7 @@ TEST(GroupBY, GrowingIndex) {
search_params: "{\"ef\": 10}"
group_by_field_id: 101
group_size: 3
group_strict_size: true
>
placeholder_tag: "$0"
1 change: 1 addition & 0 deletions internal/proto/plan.proto
@@ -62,6 +62,7 @@ message QueryInfo {
int64 group_by_field_id = 6;
bool materialized_view_involved = 7;
int64 group_size = 8;
bool group_strict_size = 9;
}

message ColumnInfo {
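End to end, the flag travels as a plain proto3 bool: the proxy writes it into planpb.QueryInfo (see the search_util.go hunk below), and the segcore plan parser copies it into SearchInfo, as in the PlanProto.cpp hunk above. A hedged sketch of that hand-off, assuming the usual proto3-generated accessors and Milvus include paths (both may differ):

#include "common/QueryInfo.h"  // milvus::SearchInfo (assumed path)
#include "pb/plan.pb.h"        // generated QueryInfo message (assumed path)

// Copy the group-by knobs from the wire-level plan into the in-memory
// SearchInfo, mirroring ProtoParser::PlanNodeFromProto above.
milvus::SearchInfo
ExtractGroupByInfo(const milvus::proto::plan::QueryInfo& query_info) {
    milvus::SearchInfo search_info;
    search_info.topk_ = query_info.topk();
    // group_size falls back to 1 when the request leaves it unset (0)
    search_info.group_size_ =
        query_info.group_size() > 0 ? query_info.group_size() : 1;
    // proto3 bool defaults to false, so omitting group_strict_size keeps
    // the old relaxed behavior
    search_info.group_strict_size_ = query_info.group_strict_size();
    return search_info;
}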
24 changes: 18 additions & 6 deletions internal/proxy/search_util.go
@@ -129,6 +129,17 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
}
}

var groupStrictSize bool
groupStrictSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupStrictSize, searchParamsPair)
if err != nil {
groupStrictSize = false
} else {
groupStrictSize, err = strconv.ParseBool(groupStrictSizeStr)
if err != nil {
groupStrictSize = false
}
}

// 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
if isIterator == "True" && groupByFieldId > 0 {
return nil, 0, merr.WrapErrParameterInvalid("", "",
@@ -140,12 +151,13 @@
}

return &planpb.QueryInfo{
Topk: queryTopK,
MetricType: metricType,
SearchParams: searchParamStr,
RoundDecimal: roundDecimal,
GroupByFieldId: groupByFieldId,
GroupSize: groupSize,
Topk: queryTopK,
MetricType: metricType,
SearchParams: searchParamStr,
RoundDecimal: roundDecimal,
GroupByFieldId: groupByFieldId,
GroupSize: groupSize,
GroupStrictSize: groupStrictSize,
}, offset, nil
}

1 change: 1 addition & 0 deletions internal/proxy/task.go
@@ -48,6 +48,7 @@ const (
IteratorField = "iterator"
GroupByFieldKey = "group_by_field"
GroupSizeKey = "group_size"
GroupStrictSize = "group_strict_size"
AnnsFieldKey = "anns_field"
TopKKey = "topk"
NQKey = "nq"
17 changes: 17 additions & 0 deletions internal/querycoordv2/job/job_load.go
@@ -19,6 +19,7 @@ package job
import (
"context"
"fmt"
"reflect"
"time"

"github.com/cockroachdb/errors"
@@ -104,6 +105,14 @@ func (job *LoadCollectionJob) PreExecute() error {
return merr.WrapErrParameterInvalid(collection.GetReplicaNumber(), req.GetReplicaNumber(), "can't change the replica number for loaded collection")
}

if !reflect.DeepEqual(collection.GetLoadFields(), req.GetLoadFields()) {
log.Warn("collection with different load field list exists, release this collection first before chaning its replica number",
zap.Int64s("loadedFieldIDs", collection.GetLoadFields()),
zap.Int64s("reqFieldIDs", req.GetLoadFields()),
)
return merr.WrapErrParameterInvalid(collection.GetLoadFields(), req.GetLoadFields(), "can't change the load field list for loaded collection")
}

return nil
}

@@ -289,6 +298,14 @@ func (job *LoadPartitionJob) PreExecute() error {
return merr.WrapErrParameterInvalid(collection.GetReplicaNumber(), req.GetReplicaNumber(), "can't change the replica number for loaded partitions")
}

if !reflect.DeepEqual(collection.GetLoadFields(), req.GetLoadFields()) {
log.Warn("collection with different load field list exists, release this collection first before chaning its replica number",
zap.Int64s("loadedFieldIDs", collection.GetLoadFields()),
zap.Int64s("reqFieldIDs", req.GetLoadFields()),
)
return merr.WrapErrParameterInvalid(collection.GetLoadFields(), req.GetLoadFields(), "can't change the load field list for loaded collection")
}

return nil
}

54 changes: 54 additions & 0 deletions internal/querycoordv2/job/job_test.go
@@ -307,6 +307,32 @@ func (suite *JobSuite) TestLoadCollection() {
suite.ErrorIs(err, merr.ErrParameterInvalid)
}

// Test load existed collection with different load fields
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadCollection {
continue
}
req := &querypb.LoadCollectionRequest{
CollectionID: collection,
LoadFields: []int64{100, 101},
}
job := NewLoadCollectionJob(
ctx,
req,
suite.dist,
suite.meta,
suite.broker,
suite.cluster,
suite.targetMgr,
suite.targetObserver,
suite.collectionObserver,
suite.nodeMgr,
)
suite.scheduler.Add(job)
err := job.Wait()
suite.ErrorIs(err, merr.ErrParameterInvalid)
}

// Test load partition while collection exists
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadCollection {
@@ -514,6 +540,34 @@ func (suite *JobSuite) TestLoadPartition() {
suite.ErrorIs(err, merr.ErrParameterInvalid)
}

// Test load partition with different load fields
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadPartition {
continue
}

req := &querypb.LoadPartitionsRequest{
CollectionID: collection,
PartitionIDs: suite.partitions[collection],
LoadFields: []int64{100, 101},
}
job := NewLoadPartitionJob(
ctx,
req,
suite.dist,
suite.meta,
suite.broker,
suite.cluster,
suite.targetMgr,
suite.targetObserver,
suite.collectionObserver,
suite.nodeMgr,
)
suite.scheduler.Add(job)
err := job.Wait()
suite.ErrorIs(err, merr.ErrParameterInvalid)
}

// Test load partition with more partition
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadPartition {
(Diffs for the remaining 3 changed files are not shown.)
