From 2f441dc8e24a32e66c257cfe2161f6ad25f49539 Mon Sep 17 00:00:00 2001 From: "sammy.huang" Date: Fri, 5 Jan 2024 17:20:47 +0800 Subject: [PATCH 01/20] enhance: [skip e2e]increase timeout for image build (#29083) Signed-off-by: Sammy Huang --- ci/jenkins/PublishImages.groovy | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/jenkins/PublishImages.groovy b/ci/jenkins/PublishImages.groovy index 0d49ac0afed0b..f3a1ca2ba0cae 100644 --- a/ci/jenkins/PublishImages.groovy +++ b/ci/jenkins/PublishImages.groovy @@ -14,7 +14,7 @@ pipeline { options { timestamps() - timeout(time: 100, unit: 'MINUTES') + timeout(time: 200, unit: 'MINUTES') // parallelsAlwaysFailFast() disableConcurrentBuilds() } From 23183ffb0fd4698be0b81c6cd6c3c15d86e20240 Mon Sep 17 00:00:00 2001 From: "yihao.dai" Date: Fri, 5 Jan 2024 18:12:48 +0800 Subject: [PATCH 02/20] feat: Add import reader for json (#29252) This PR implements a new json reader for import. issue: https://github.com/milvus-io/milvus/issues/28521 --------- Signed-off-by: bigsheeper --- internal/storage/insert_data.go | 11 +- internal/util/importutilv2/json/reader.go | 143 ++++++ .../util/importutilv2/json/reader_test.go | 278 +++++++++++ internal/util/importutilv2/json/row_parser.go | 456 ++++++++++++++++++ pkg/util/typeutil/schema.go | 10 + 5 files changed, 892 insertions(+), 6 deletions(-) create mode 100644 internal/util/importutilv2/json/reader.go create mode 100644 internal/util/importutilv2/json/reader_test.go create mode 100644 internal/util/importutilv2/json/row_parser.go diff --git a/internal/storage/insert_data.go b/internal/storage/insert_data.go index 8beaa9be1452c..4fafbf160cbe8 100644 --- a/internal/storage/insert_data.go +++ b/internal/storage/insert_data.go @@ -78,13 +78,12 @@ func (i *InsertData) GetRowNum() int { if i.Data == nil || len(i.Data) == 0 { return 0 } - - data, ok := i.Data[common.RowIDField] - if !ok { - return 0 + var rowNum int + for _, data := range i.Data { + rowNum = data.RowNum() + break } - - return data.RowNum() + return rowNum } func (i *InsertData) GetMemorySize() int { diff --git a/internal/util/importutilv2/json/reader.go b/internal/util/importutilv2/json/reader.go new file mode 100644 index 0000000000000..8998335d22892 --- /dev/null +++ b/internal/util/importutilv2/json/reader.go @@ -0,0 +1,143 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
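+
+// The reader below accepts two JSON layouts: the legacy {"rows": [...]}
+// document and a plain top-level array of row objects. Rows are decoded one
+// at a time with json.Decoder, so the whole file never has to fit in memory;
+// Read returns once the accumulated InsertData reaches bufferSize, and a nil
+// InsertData with a nil error signals that all rows have been consumed.
+// A minimal usage sketch, assuming schema and jsonBytes hold the collection
+// schema and the import payload:
+//
+//	r, err := NewReader(bytes.NewReader(jsonBytes), schema, 64<<20)
+//	if err != nil {
+//		return err
+//	}
+//	for {
+//		data, err := r.Read()
+//		if err != nil {
+//			return err
+//		}
+//		if data == nil {
+//			break // all rows consumed
+//		}
+//		// consume the *storage.InsertData batch
+//	}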
+ +package json + +import ( + "encoding/json" + "fmt" + "io" + "strings" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/util/merr" +) + +const ( + RowRootNode = "rows" +) + +type Row = map[storage.FieldID]any + +type reader struct { + dec *json.Decoder + schema *schemapb.CollectionSchema + + bufferSize int + isOldFormat bool + + parser RowParser +} + +func NewReader(r io.Reader, schema *schemapb.CollectionSchema, bufferSize int) (*reader, error) { + reader := &reader{ + dec: json.NewDecoder(r), + schema: schema, + bufferSize: bufferSize, + } + var err error + reader.parser, err = NewRowParser(schema) + if err != nil { + return nil, err + } + err = reader.Init() + if err != nil { + return nil, err + } + return reader, nil +} + +func (j *reader) Init() error { + // Treat number value as a string instead of a float64. + // By default, json lib treat all number values as float64, + // but if an int64 value has more than 15 digits, + // the value would be incorrect after converting from float64. + j.dec.UseNumber() + t, err := j.dec.Token() + if err != nil { + return merr.WrapErrImportFailed(fmt.Sprintf("failed to decode JSON, error: %v", err)) + } + if t != json.Delim('{') && t != json.Delim('[') { + return merr.WrapErrImportFailed("invalid JSON format, the content should be started with '{' or '['") + } + j.isOldFormat = t == json.Delim('{') + return nil +} + +func (j *reader) Read() (*storage.InsertData, error) { + insertData, err := storage.NewInsertData(j.schema) + if err != nil { + return nil, err + } + if !j.dec.More() { + return nil, nil + } + if j.isOldFormat { + // read the key + t, err := j.dec.Token() + if err != nil { + return nil, merr.WrapErrImportFailed(fmt.Sprintf("failed to decode the JSON file, error: %v", err)) + } + key := t.(string) + keyLower := strings.ToLower(key) + // the root key should be RowRootNode + if keyLower != RowRootNode { + return nil, merr.WrapErrImportFailed(fmt.Sprintf("invalid JSON format, the root key should be '%s', but get '%s'", RowRootNode, key)) + } + + // started by '[' + t, err = j.dec.Token() + if err != nil { + return nil, merr.WrapErrImportFailed(fmt.Sprintf("failed to decode the JSON file, error: %v", err)) + } + + if t != json.Delim('[') { + return nil, merr.WrapErrImportFailed("invalid JSON format, rows list should begin with '['") + } + } + for j.dec.More() { + var value any + if err = j.dec.Decode(&value); err != nil { + return nil, merr.WrapErrImportFailed(fmt.Sprintf("failed to parse row, error: %v", err)) + } + row, err := j.parser.Parse(value) + if err != nil { + return nil, err + } + err = insertData.Append(row) + if err != nil { + return nil, err + } + if insertData.GetMemorySize() >= j.bufferSize { + break + } + } + + if !j.dec.More() { + t, err := j.dec.Token() + if err != nil { + return nil, merr.WrapErrImportFailed(fmt.Sprintf("failed to decode JSON, error: %v", err)) + } + if t != json.Delim(']') { + return nil, merr.WrapErrImportFailed("invalid JSON format, rows list should end with ']'") + } + } + + return insertData, nil +} + +func (j *reader) Close() {} diff --git a/internal/util/importutilv2/json/reader_test.go b/internal/util/importutilv2/json/reader_test.go new file mode 100644 index 0000000000000..915d17c36d7e0 --- /dev/null +++ b/internal/util/importutilv2/json/reader_test.go @@ -0,0 +1,278 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package json + +import ( + rand2 "crypto/rand" + "encoding/json" + "fmt" + "math" + "math/rand" + "strconv" + "strings" + "testing" + + "github.com/samber/lo" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + "golang.org/x/exp/slices" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type ReaderSuite struct { + suite.Suite + + numRows int + pkDataType schemapb.DataType + vecDataType schemapb.DataType +} + +func (suite *ReaderSuite) SetupSuite() { + paramtable.Get().Init(paramtable.NewBaseTable()) +} + +func (suite *ReaderSuite) SetupTest() { + // default suite params + suite.numRows = 100 + suite.pkDataType = schemapb.DataType_Int64 + suite.vecDataType = schemapb.DataType_FloatVector +} + +func createInsertData(t *testing.T, schema *schemapb.CollectionSchema, rowCount int) *storage.InsertData { + insertData, err := storage.NewInsertData(schema) + assert.NoError(t, err) + for _, field := range schema.GetFields() { + switch field.GetDataType() { + case schemapb.DataType_Bool: + boolData := make([]bool, 0) + for i := 0; i < rowCount; i++ { + boolData = append(boolData, i%3 != 0) + } + insertData.Data[field.GetFieldID()] = &storage.BoolFieldData{Data: boolData} + case schemapb.DataType_Float: + floatData := make([]float32, 0) + for i := 0; i < rowCount; i++ { + floatData = append(floatData, float32(i/2)) + } + insertData.Data[field.GetFieldID()] = &storage.FloatFieldData{Data: floatData} + case schemapb.DataType_Double: + doubleData := make([]float64, 0) + for i := 0; i < rowCount; i++ { + doubleData = append(doubleData, float64(i/5)) + } + insertData.Data[field.GetFieldID()] = &storage.DoubleFieldData{Data: doubleData} + case schemapb.DataType_Int8: + int8Data := make([]int8, 0) + for i := 0; i < rowCount; i++ { + int8Data = append(int8Data, int8(i%256)) + } + insertData.Data[field.GetFieldID()] = &storage.Int8FieldData{Data: int8Data} + case schemapb.DataType_Int16: + int16Data := make([]int16, 0) + for i := 0; i < rowCount; i++ { + int16Data = append(int16Data, int16(i%65536)) + } + insertData.Data[field.GetFieldID()] = &storage.Int16FieldData{Data: int16Data} + case schemapb.DataType_Int32: + int32Data := make([]int32, 0) + for i := 0; i < rowCount; i++ { + int32Data = append(int32Data, int32(i%1000)) + } + insertData.Data[field.GetFieldID()] = &storage.Int32FieldData{Data: int32Data} + case schemapb.DataType_Int64: + int64Data := make([]int64, 0) + for i := 0; i < rowCount; i++ { + int64Data = append(int64Data, int64(i)) + } + insertData.Data[field.GetFieldID()] = &storage.Int64FieldData{Data: int64Data} + 
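+			// The vector cases below flatten all rows into one contiguous
+			// slice: a binary vector packs 8 dimensions per byte
+			// (rowCount*dim/8 bytes in total), a float vector stores dim
+			// float32 values per row, and a float16 vector stores two
+			// random bytes per dimension.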
case schemapb.DataType_BinaryVector: + dim, err := typeutil.GetDim(field) + assert.NoError(t, err) + binVecData := make([]byte, 0) + total := rowCount * int(dim) / 8 + for i := 0; i < total; i++ { + binVecData = append(binVecData, byte(i%256)) + } + insertData.Data[field.GetFieldID()] = &storage.BinaryVectorFieldData{Data: binVecData, Dim: int(dim)} + case schemapb.DataType_FloatVector: + dim, err := typeutil.GetDim(field) + assert.NoError(t, err) + floatVecData := make([]float32, 0) + total := rowCount * int(dim) + for i := 0; i < total; i++ { + floatVecData = append(floatVecData, rand.Float32()) + } + insertData.Data[field.GetFieldID()] = &storage.FloatVectorFieldData{Data: floatVecData, Dim: int(dim)} + case schemapb.DataType_Float16Vector: + dim, err := typeutil.GetDim(field) + assert.NoError(t, err) + total := int64(rowCount) * dim * 2 + float16VecData := make([]byte, total) + _, err = rand2.Read(float16VecData) + assert.NoError(t, err) + insertData.Data[field.GetFieldID()] = &storage.Float16VectorFieldData{Data: float16VecData, Dim: int(dim)} + case schemapb.DataType_String, schemapb.DataType_VarChar: + varcharData := make([]string, 0) + for i := 0; i < rowCount; i++ { + varcharData = append(varcharData, strconv.Itoa(i)) + } + insertData.Data[field.GetFieldID()] = &storage.StringFieldData{Data: varcharData} + case schemapb.DataType_JSON: + jsonData := make([][]byte, 0) + for i := 0; i < rowCount; i++ { + jsonData = append(jsonData, []byte(fmt.Sprintf("{\"y\": %d}", i))) + } + insertData.Data[field.GetFieldID()] = &storage.JSONFieldData{Data: jsonData} + case schemapb.DataType_Array: + arrayData := make([]*schemapb.ScalarField, 0) + for i := 0; i < rowCount; i++ { + arrayData = append(arrayData, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: []int32{int32(i), int32(i + 1), int32(i + 2)}, + }, + }, + }) + } + insertData.Data[field.GetFieldID()] = &storage.ArrayFieldData{Data: arrayData} + default: + panic(fmt.Sprintf("unexpected data type: %s", field.GetDataType().String())) + } + } + return insertData +} + +func (suite *ReaderSuite) run(dt schemapb.DataType) { + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + Name: "pk", + IsPrimaryKey: true, + DataType: suite.pkDataType, + }, + { + FieldID: 101, + Name: "vec", + DataType: suite.vecDataType, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "8", + }, + }, + }, + { + FieldID: 102, + Name: dt.String(), + DataType: dt, + ElementType: schemapb.DataType_Int32, + }, + }, + } + insertData := createInsertData(suite.T(), schema, suite.numRows) + rows := make([]map[string]any, 0, suite.numRows) + fieldIDToField := lo.KeyBy(schema.GetFields(), func(field *schemapb.FieldSchema) int64 { + return field.GetFieldID() + }) + for i := 0; i < insertData.GetRowNum(); i++ { + data := make(map[int64]interface{}) + for fieldID, v := range insertData.Data { + dataType := fieldIDToField[fieldID].GetDataType() + if dataType == schemapb.DataType_Array { + data[fieldID] = v.GetRow(i).(*schemapb.ScalarField).GetIntData().GetData() + } else if dataType == schemapb.DataType_JSON { + data[fieldID] = string(v.GetRow(i).([]byte)) + } else if dataType == schemapb.DataType_BinaryVector || dataType == schemapb.DataType_Float16Vector { + bytes := v.GetRow(i).([]byte) + ints := make([]int, 0, len(bytes)) + for _, b := range bytes { + ints = append(ints, int(b)) + } + data[fieldID] = ints + } else { + data[fieldID] = v.GetRow(i) + } + } + row := 
lo.MapKeys(data, func(_ any, fieldID int64) string { + return fieldIDToField[fieldID].GetName() + }) + rows = append(rows, row) + } + + jsonBytes, err := json.Marshal(rows) + suite.NoError(err) + r := strings.NewReader(string(jsonBytes)) + reader, err := NewReader(r, schema, math.MaxInt) + suite.NoError(err) + + checkFn := func(actualInsertData *storage.InsertData, offsetBegin, expectRows int) { + expectInsertData := insertData + for fieldID, data := range actualInsertData.Data { + suite.Equal(expectRows, data.RowNum()) + fieldDataType := typeutil.GetField(schema, fieldID).GetDataType() + for i := 0; i < expectRows; i++ { + expect := expectInsertData.Data[fieldID].GetRow(i + offsetBegin) + actual := data.GetRow(i) + if fieldDataType == schemapb.DataType_Array { + suite.True(slices.Equal(expect.(*schemapb.ScalarField).GetIntData().GetData(), actual.(*schemapb.ScalarField).GetIntData().GetData())) + } else { + suite.Equal(expect, actual) + } + } + } + } + + res, err := reader.Read() + suite.NoError(err) + checkFn(res, 0, suite.numRows) +} + +func (suite *ReaderSuite) TestReadScalarFields() { + suite.run(schemapb.DataType_Bool) + suite.run(schemapb.DataType_Int8) + suite.run(schemapb.DataType_Int16) + suite.run(schemapb.DataType_Int32) + suite.run(schemapb.DataType_Int64) + suite.run(schemapb.DataType_Float) + suite.run(schemapb.DataType_Double) + suite.run(schemapb.DataType_VarChar) + suite.run(schemapb.DataType_Array) + suite.run(schemapb.DataType_JSON) +} + +func (suite *ReaderSuite) TestStringPK() { + suite.pkDataType = schemapb.DataType_VarChar + suite.run(schemapb.DataType_Int32) +} + +func (suite *ReaderSuite) TestBinaryAndFloat16Vector() { + suite.vecDataType = schemapb.DataType_BinaryVector + suite.run(schemapb.DataType_Int32) + suite.vecDataType = schemapb.DataType_Float16Vector + suite.run(schemapb.DataType_Int32) +} + +func TestUtil(t *testing.T) { + suite.Run(t, new(ReaderSuite)) +} diff --git a/internal/util/importutilv2/json/row_parser.go b/internal/util/importutilv2/json/row_parser.go new file mode 100644 index 0000000000000..def8f1d439482 --- /dev/null +++ b/internal/util/importutilv2/json/row_parser.go @@ -0,0 +1,456 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
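+
+// The parser below turns one decoded JSON object into a Row keyed by field
+// ID. Numbers arrive as json.Number (the reader enables UseNumber), so int64
+// values keep full precision, and every value is range-checked with strconv
+// before it is narrowed to the schema type. Keys that match no schema field
+// are accepted only when the collection has a dynamic field: they are
+// gathered and marshaled into that field as a single JSON object, as in
+// these cases (using the dynamic field name "$meta" from the comments in
+// combineDynamicRow):
+//
+//	{"id": 1, "x": 8}         -> row[$meta] = `{"x": 8}`
+//	{"id": 1}                 -> row[$meta] = "{}"
+//	{"id": 1, "$meta": {...}} -> rejected, "$meta" must not be set explicitly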
+ +package json + +import ( + "encoding/json" + "fmt" + "strconv" + + "github.com/cockroachdb/errors" + "github.com/samber/lo" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type RowParser interface { + Parse(raw any) (Row, error) +} + +type rowParser struct { + dim int + id2Field map[int64]*schemapb.FieldSchema + name2FieldID map[string]int64 + pkField *schemapb.FieldSchema + dynamicField *schemapb.FieldSchema +} + +func NewRowParser(schema *schemapb.CollectionSchema) (RowParser, error) { + id2Field := lo.KeyBy(schema.GetFields(), func(field *schemapb.FieldSchema) int64 { + return field.GetFieldID() + }) + vecField, err := typeutil.GetVectorFieldSchema(schema) + if err != nil { + return nil, err + } + dim, err := typeutil.GetDim(vecField) + if err != nil { + return nil, err + } + pkField, err := typeutil.GetPrimaryFieldSchema(schema) + if err != nil { + return nil, err + } + + name2FieldID := lo.SliceToMap(schema.GetFields(), + func(field *schemapb.FieldSchema) (string, int64) { + return field.GetName(), field.GetFieldID() + }) + + if pkField.GetAutoID() { + delete(name2FieldID, pkField.GetName()) + } + + dynamicField := typeutil.GetDynamicField(schema) + if dynamicField != nil { + delete(name2FieldID, dynamicField.GetName()) + } + return &rowParser{ + dim: int(dim), + id2Field: id2Field, + name2FieldID: name2FieldID, + pkField: pkField, + dynamicField: dynamicField, + }, nil +} + +func (r *rowParser) wrapTypeError(v any, fieldID int64) error { + field := r.id2Field[fieldID] + return merr.WrapErrImportFailed(fmt.Sprintf("expected type '%s' for field '%s', got type '%T' with value '%v'", + field.GetDataType().String(), field.GetName(), v, v)) +} + +func (r *rowParser) wrapDimError(actualDim int, fieldID int64) error { + field := r.id2Field[fieldID] + return merr.WrapErrImportFailed(fmt.Sprintf("expected dim '%d' for field '%s' with type '%s', got dim '%d'", + r.dim, field.GetName(), field.GetDataType().String(), actualDim)) +} + +func (r *rowParser) wrapArrayValueTypeError(v any, eleType schemapb.DataType) error { + return merr.WrapErrImportFailed(fmt.Sprintf("expected element type '%s' in array field, got type '%T' with value '%v'", + eleType.String(), v, v)) +} + +func (r *rowParser) Parse(raw any) (Row, error) { + stringMap, ok := raw.(map[string]any) + if !ok { + return nil, merr.WrapErrImportFailed("invalid JSON format, each row should be a key-value map") + } + if _, ok = stringMap[r.pkField.GetName()]; ok && r.pkField.GetAutoID() { + return nil, merr.WrapErrImportFailed( + fmt.Sprintf("the primary key '%s' is auto-generated, no need to provide", r.pkField.GetName())) + } + dynamicValues := make(map[string]any) + row := make(Row) + for key, value := range stringMap { + if fieldID, ok := r.name2FieldID[key]; ok { + data, err := r.parseEntity(fieldID, value) + if err != nil { + return nil, err + } + row[fieldID] = data + } else if r.dynamicField != nil { + if key == r.dynamicField.GetName() { + return nil, merr.WrapErrImportFailed( + fmt.Sprintf("dynamic field is enabled, explicit specification of '%s' is not allowed", key)) + } + // has dynamic field, put redundant pair to dynamicValues + dynamicValues[key] = value + } else { + return nil, merr.WrapErrImportFailed(fmt.Sprintf("the field '%s' is not defined in schema", key)) + } + } + for fieldName, fieldID := range r.name2FieldID { + if _, ok = row[fieldID]; !ok { + return nil, merr.WrapErrImportFailed(fmt.Sprintf("value 
of field '%s' is missed", fieldName)) + } + } + if r.dynamicField == nil { + return row, nil + } + // combine the redundant pairs into dynamic field(if it has) + err := r.combineDynamicRow(dynamicValues, row) + if err != nil { + return nil, err + } + return row, err +} + +func (r *rowParser) combineDynamicRow(dynamicValues map[string]any, row Row) error { + // Combine the dynamic field value + // invalid inputs: + // case 1: {"id": 1, "vector": [], "$meta": {"x": 8}} ==>> "$meta" is not allowed + // valid inputs: + // case 2: {"id": 1, "vector": [], "x": 8} ==>> {"id": 1, "vector": [], "$meta": "{\"x\": 8}"} + // case 3: {"id": 1, "vector": []} + dynamicFieldID := r.dynamicField.GetFieldID() + if len(dynamicValues) > 0 { + // case 2 + data, err := r.parseEntity(dynamicFieldID, dynamicValues) + if err != nil { + return err + } + row[dynamicFieldID] = data + } else { + // case 3 + row[dynamicFieldID] = "{}" + } + return nil +} + +func (r *rowParser) parseEntity(fieldID int64, obj any) (any, error) { + switch r.id2Field[fieldID].GetDataType() { + case schemapb.DataType_Bool: + b, ok := obj.(bool) + if !ok { + return nil, r.wrapTypeError(obj, fieldID) + } + return b, nil + case schemapb.DataType_Int8: + value, ok := obj.(json.Number) + if !ok { + return nil, r.wrapTypeError(obj, fieldID) + } + num, err := strconv.ParseInt(value.String(), 0, 8) + if err != nil { + return nil, err + } + return int8(num), nil + case schemapb.DataType_Int16: + value, ok := obj.(json.Number) + if !ok { + return nil, r.wrapTypeError(obj, fieldID) + } + num, err := strconv.ParseInt(value.String(), 0, 16) + if err != nil { + return nil, err + } + return int16(num), nil + case schemapb.DataType_Int32: + value, ok := obj.(json.Number) + if !ok { + return nil, r.wrapTypeError(obj, fieldID) + } + num, err := strconv.ParseInt(value.String(), 0, 32) + if err != nil { + return nil, err + } + return int32(num), nil + case schemapb.DataType_Int64: + value, ok := obj.(json.Number) + if !ok { + return nil, r.wrapTypeError(obj, fieldID) + } + num, err := strconv.ParseInt(value.String(), 0, 64) + if err != nil { + return nil, err + } + return num, nil + case schemapb.DataType_Float: + value, ok := obj.(json.Number) + if !ok { + return nil, r.wrapTypeError(obj, fieldID) + } + num, err := strconv.ParseFloat(value.String(), 32) + if err != nil { + return nil, err + } + return float32(num), nil + case schemapb.DataType_Double: + value, ok := obj.(json.Number) + if !ok { + return nil, r.wrapTypeError(obj, fieldID) + } + num, err := strconv.ParseFloat(value.String(), 64) + if err != nil { + return nil, err + } + return num, nil + case schemapb.DataType_BinaryVector: + arr, ok := obj.([]interface{}) + if !ok { + return nil, r.wrapTypeError(obj, fieldID) + } + if len(arr)*8 != r.dim { + return nil, r.wrapDimError(len(arr)*8, fieldID) + } + vec := make([]byte, len(arr)) + for i := 0; i < len(arr); i++ { + value, ok := arr[i].(json.Number) + if !ok { + return nil, r.wrapTypeError(arr[i], fieldID) + } + num, err := strconv.ParseUint(value.String(), 0, 8) + if err != nil { + return nil, err + } + vec[i] = byte(num) + } + return vec, nil + case schemapb.DataType_FloatVector: + arr, ok := obj.([]interface{}) + if !ok { + return nil, r.wrapTypeError(obj, fieldID) + } + if len(arr) != r.dim { + return nil, r.wrapDimError(len(arr), fieldID) + } + vec := make([]float32, len(arr)) + for i := 0; i < len(arr); i++ { + value, ok := arr[i].(json.Number) + if !ok { + return nil, r.wrapTypeError(arr[i], fieldID) + } + num, err := 
strconv.ParseFloat(value.String(), 32) + if err != nil { + return nil, err + } + vec[i] = float32(num) + } + return vec, nil + case schemapb.DataType_Float16Vector: + arr, ok := obj.([]interface{}) + if !ok { + return nil, r.wrapTypeError(obj, fieldID) + } + if len(arr)/2 != r.dim { + return nil, r.wrapDimError(len(arr)/2, fieldID) + } + vec := make([]byte, len(arr)) + for i := 0; i < len(arr); i++ { + value, ok := arr[i].(json.Number) + if !ok { + return nil, r.wrapTypeError(arr[i], fieldID) + } + num, err := strconv.ParseUint(value.String(), 0, 8) + if err != nil { + return nil, err + } + vec[i] = byte(num) + } + return vec, nil + case schemapb.DataType_String, schemapb.DataType_VarChar: + value, ok := obj.(string) + if !ok { + return nil, r.wrapTypeError(obj, fieldID) + } + return value, nil + case schemapb.DataType_JSON: + // for JSON data, we accept two kinds input: string and map[string]interface + // user can write JSON content as {"FieldJSON": "{\"x\": 8}"} or {"FieldJSON": {"x": 8}} + if value, ok := obj.(string); ok { + var dummy interface{} + err := json.Unmarshal([]byte(value), &dummy) + if err != nil { + return nil, err + } + return []byte(value), nil + } else if mp, ok := obj.(map[string]interface{}); ok { + bs, err := json.Marshal(mp) + if err != nil { + return nil, err + } + return bs, nil + } else { + return nil, r.wrapTypeError(obj, fieldID) + } + case schemapb.DataType_Array: + arr, ok := obj.([]interface{}) + if !ok { + return nil, r.wrapTypeError(obj, fieldID) + } + scalarFieldData, err := r.arrayToFieldData(arr, r.id2Field[fieldID].GetElementType()) + if err != nil { + return nil, err + } + return scalarFieldData, nil + default: + return nil, merr.WrapErrImportFailed(fmt.Sprintf("parse json failed, unsupport data type: %s", + r.id2Field[fieldID].GetDataType().String())) + } +} + +func (r *rowParser) arrayToFieldData(arr []interface{}, eleType schemapb.DataType) (*schemapb.ScalarField, error) { + switch eleType { + case schemapb.DataType_Bool: + values := make([]bool, 0) + for i := 0; i < len(arr); i++ { + value, ok := arr[i].(bool) + if !ok { + return nil, r.wrapArrayValueTypeError(arr, eleType) + } + values = append(values, value) + } + return &schemapb.ScalarField{ + Data: &schemapb.ScalarField_BoolData{ + BoolData: &schemapb.BoolArray{ + Data: values, + }, + }, + }, nil + case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32: + values := make([]int32, 0) + for i := 0; i < len(arr); i++ { + value, ok := arr[i].(json.Number) + if !ok { + return nil, r.wrapArrayValueTypeError(arr, eleType) + } + num, err := strconv.ParseInt(value.String(), 0, 32) + if err != nil { + return nil, err + } + values = append(values, int32(num)) + } + return &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: values, + }, + }, + }, nil + case schemapb.DataType_Int64: + values := make([]int64, 0) + for i := 0; i < len(arr); i++ { + value, ok := arr[i].(json.Number) + if !ok { + return nil, r.wrapArrayValueTypeError(arr, eleType) + } + num, err := strconv.ParseInt(value.String(), 0, 64) + if err != nil { + return nil, err + } + values = append(values, num) + } + return &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: values, + }, + }, + }, nil + case schemapb.DataType_Float: + values := make([]float32, 0) + for i := 0; i < len(arr); i++ { + value, ok := arr[i].(json.Number) + if !ok { + return nil, r.wrapArrayValueTypeError(arr, eleType) + } + num, err := 
strconv.ParseFloat(value.String(), 32) + if err != nil { + return nil, err + } + values = append(values, float32(num)) + } + return &schemapb.ScalarField{ + Data: &schemapb.ScalarField_FloatData{ + FloatData: &schemapb.FloatArray{ + Data: values, + }, + }, + }, nil + case schemapb.DataType_Double: + values := make([]float64, 0) + for i := 0; i < len(arr); i++ { + value, ok := arr[i].(json.Number) + if !ok { + return nil, r.wrapArrayValueTypeError(arr, eleType) + } + num, err := strconv.ParseFloat(value.String(), 64) + if err != nil { + return nil, err + } + values = append(values, num) + } + return &schemapb.ScalarField{ + Data: &schemapb.ScalarField_DoubleData{ + DoubleData: &schemapb.DoubleArray{ + Data: values, + }, + }, + }, nil + case schemapb.DataType_VarChar, schemapb.DataType_String: + values := make([]string, 0) + for i := 0; i < len(arr); i++ { + value, ok := arr[i].(string) + if !ok { + return nil, r.wrapArrayValueTypeError(arr, eleType) + } + values = append(values, value) + } + return &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: values, + }, + }, + }, nil + default: + return nil, errors.New(fmt.Sprintf("unsupported array data type '%s'", eleType.String())) + } +} diff --git a/pkg/util/typeutil/schema.go b/pkg/util/typeutil/schema.go index df0d5453a4c34..2e816d2b4a9fd 100644 --- a/pkg/util/typeutil/schema.go +++ b/pkg/util/typeutil/schema.go @@ -816,6 +816,16 @@ func GetPartitionKeyFieldSchema(schema *schemapb.CollectionSchema) (*schemapb.Fi return nil, errors.New("partition key field is not found") } +// GetDynamicField returns the dynamic field if it exists. +func GetDynamicField(schema *schemapb.CollectionSchema) *schemapb.FieldSchema { + for _, fieldSchema := range schema.GetFields() { + if fieldSchema.GetIsDynamic() { + return fieldSchema + } + } + return nil +} + // HasPartitionKey check if a collection schema has PartitionKey field func HasPartitionKey(schema *schemapb.CollectionSchema) bool { for _, fieldSchema := range schema.Fields { From a0cec4047a293d90af64d1364c076828bacfd2e8 Mon Sep 17 00:00:00 2001 From: yah01 Date: Fri, 5 Jan 2024 18:24:47 +0800 Subject: [PATCH 03/20] fix: make the entity num metric accurate (#29643) fix #29642 Signed-off-by: yah01 --- .../delegator/delegator_data_test.go | 2 +- internal/querynodev2/segments/manager.go | 35 +++++++++---------- 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/internal/querynodev2/delegator/delegator_data_test.go b/internal/querynodev2/delegator/delegator_data_test.go index 91265890a85c3..cfb64dc9e167b 100644 --- a/internal/querynodev2/delegator/delegator_data_test.go +++ b/internal/querynodev2/delegator/delegator_data_test.go @@ -911,7 +911,7 @@ func (s *DelegatorDataSuite) TestSyncTargetVersion() { ms.EXPECT().Type().Return(segments.SegmentTypeGrowing) ms.EXPECT().Collection().Return(1) ms.EXPECT().Partition().Return(1) - ms.EXPECT().RowNum().Return(0) + ms.EXPECT().InsertCount().Return(0) ms.EXPECT().Indexes().Return(nil) ms.EXPECT().Shard().Return(s.vchannelName) ms.EXPECT().Level().Return(datapb.SegmentLevel_L1) diff --git a/internal/querynodev2/segments/manager.go b/internal/querynodev2/segments/manager.go index a6d4f23bbb102..03c77174ff769 100644 --- a/internal/querynodev2/segments/manager.go +++ b/internal/querynodev2/segments/manager.go @@ -202,15 +202,14 @@ func (mgr *segmentManager) Put(segmentType SegmentType, segments ...Segment) { fmt.Sprint(len(segment.Indexes())), segment.Level().String(), ).Inc() - if segment.RowNum() 
> 0 { - metrics.QueryNodeNumEntities.WithLabelValues( - fmt.Sprint(paramtable.GetNodeID()), - fmt.Sprint(segment.Collection()), - fmt.Sprint(segment.Partition()), - segment.Type().String(), - fmt.Sprint(len(segment.Indexes())), - ).Add(float64(segment.RowNum())) - } + + metrics.QueryNodeNumEntities.WithLabelValues( + fmt.Sprint(paramtable.GetNodeID()), + fmt.Sprint(segment.Collection()), + fmt.Sprint(segment.Partition()), + segment.Type().String(), + fmt.Sprint(len(segment.Indexes())), + ).Add(float64(segment.InsertCount())) } mgr.updateMetric() @@ -556,7 +555,6 @@ func (mgr *segmentManager) updateMetric() { } func remove(segment Segment) bool { - rowNum := segment.RowNum() segment.Release() metrics.QueryNodeNumSegments.WithLabelValues( @@ -567,14 +565,13 @@ func remove(segment Segment) bool { fmt.Sprint(len(segment.Indexes())), segment.Level().String(), ).Dec() - if rowNum > 0 { - metrics.QueryNodeNumEntities.WithLabelValues( - fmt.Sprint(paramtable.GetNodeID()), - fmt.Sprint(segment.Collection()), - fmt.Sprint(segment.Partition()), - segment.Type().String(), - fmt.Sprint(len(segment.Indexes())), - ).Sub(float64(rowNum)) - } + + metrics.QueryNodeNumEntities.WithLabelValues( + fmt.Sprint(paramtable.GetNodeID()), + fmt.Sprint(segment.Collection()), + fmt.Sprint(segment.Partition()), + segment.Type().String(), + fmt.Sprint(len(segment.Indexes())), + ).Sub(float64(segment.InsertCount())) return true } From 5be909982d9d1655413ac5660de2adf26940445d Mon Sep 17 00:00:00 2001 From: congqixia Date: Fri, 5 Jan 2024 21:42:47 +0800 Subject: [PATCH 04/20] enhance: add MockSerializer generation command into Makefile (#29713) See also #27675 Signed-off-by: Congqi Xia --- Makefile | 1 + 1 file changed, 1 insertion(+) diff --git a/Makefile b/Makefile index aae2477f5aae0..4d6b4c53a18d1 100644 --- a/Makefile +++ b/Makefile @@ -447,6 +447,7 @@ generate-mockery-datanode: getdeps $(INSTALL_PATH)/mockery --name=MetaCache --dir=$(PWD)/internal/datanode/metacache --output=$(PWD)/internal/datanode/metacache --filename=mock_meta_cache.go --with-expecter --structname=MockMetaCache --outpkg=metacache --inpackage $(INSTALL_PATH)/mockery --name=SyncManager --dir=$(PWD)/internal/datanode/syncmgr --output=$(PWD)/internal/datanode/syncmgr --filename=mock_sync_manager.go --with-expecter --structname=MockSyncManager --outpkg=syncmgr --inpackage $(INSTALL_PATH)/mockery --name=MetaWriter --dir=$(PWD)/internal/datanode/syncmgr --output=$(PWD)/internal/datanode/syncmgr --filename=mock_meta_writer.go --with-expecter --structname=MockMetaWriter --outpkg=syncmgr --inpackage + $(INSTALL_PATH)/mockery --name=Serializer --dir=$(PWD)/internal/datanode/syncmgr --output=$(PWD)/internal/datanode/syncmgr --filename=mock_serializer.go --with-expecter --structname=MockSerializer --outpkg=syncmgr --inpackage $(INSTALL_PATH)/mockery --name=WriteBuffer --dir=$(PWD)/internal/datanode/writebuffer --output=$(PWD)/internal/datanode/writebuffer --filename=mock_write_buffer.go --with-expecter --structname=MockWriteBuffer --outpkg=writebuffer --inpackage $(INSTALL_PATH)/mockery --name=BufferManager --dir=$(PWD)/internal/datanode/writebuffer --output=$(PWD)/internal/datanode/writebuffer --filename=mock_mananger.go --with-expecter --structname=MockBufferManager --outpkg=writebuffer --inpackage $(INSTALL_PATH)/mockery --name=BinlogIO --dir=$(PWD)/internal/datanode/io --output=$(PWD)/internal/datanode/io --filename=mock_binlogio.go --with-expecter --structname=MockBinlogIO --outpkg=io --inpackage From b5f039a2215881c29801adf3ee8365e1cf61a846 Mon Sep 17 
00:00:00 2001 From: congqixia Date: Sun, 7 Jan 2024 15:54:47 +0800 Subject: [PATCH 05/20] fix: Assertion all async invocations in test case (#29737) Resolves: #29736 Signed-off-by: Congqi Xia --- .../querycoordv2/checkers/controller_test.go | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/internal/querycoordv2/checkers/controller_test.go b/internal/querycoordv2/checkers/controller_test.go index 6df196c9b8d09..9126a30f47de9 100644 --- a/internal/querycoordv2/checkers/controller_test.go +++ b/internal/querycoordv2/checkers/controller_test.go @@ -124,15 +124,23 @@ func (suite *CheckerControllerSuite) TestBasic() { suite.scheduler.EXPECT().GetSegmentTaskNum().Return(0).Maybe() suite.scheduler.EXPECT().GetChannelTaskNum().Return(0).Maybe() - suite.balancer.EXPECT().AssignSegment(mock.Anything, mock.Anything, mock.Anything).Return(nil) - suite.balancer.EXPECT().AssignChannel(mock.Anything, mock.Anything).Return(nil) + assignSegCounter := atomic.NewInt32(0) + assingChanCounter := atomic.NewInt32(0) + suite.balancer.EXPECT().AssignSegment(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(i1 int64, s []*meta.Segment, i2 []int64) []balance.SegmentAssignPlan { + assignSegCounter.Inc() + return nil + }) + suite.balancer.EXPECT().AssignChannel(mock.Anything, mock.Anything).RunAndReturn(func(dc []*meta.DmChannel, i []int64) []balance.ChannelAssignPlan { + assingChanCounter.Inc() + return nil + }) suite.controller.Start() defer suite.controller.Stop() suite.Eventually(func() bool { suite.controller.Check() - return counter.Load() > 0 - }, 5*time.Second, 1*time.Second) + return counter.Load() > 0 && assignSegCounter.Load() > 0 && assingChanCounter.Load() > 0 + }, 5*time.Second, 1*time.Millisecond) } func TestCheckControllerSuite(t *testing.T) { From 5dc300c4a9bc36e2a2e894b3ab4009e24a08484d Mon Sep 17 00:00:00 2001 From: "cai.zhang" Date: Sun, 7 Jan 2024 19:36:48 +0800 Subject: [PATCH 06/20] fix: Fix bug for pk index doesn't have raw data (#29711) issue: #29697 Signed-off-by: Cai Zhang --- .../core/src/segcore/SegmentSealedImpl.cpp | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/internal/core/src/segcore/SegmentSealedImpl.cpp b/internal/core/src/segcore/SegmentSealedImpl.cpp index a8f2d45f5e258..72f4fa09989db 100644 --- a/internal/core/src/segcore/SegmentSealedImpl.cpp +++ b/internal/core/src/segcore/SegmentSealedImpl.cpp @@ -150,21 +150,26 @@ SegmentSealedImpl::LoadScalarIndex(const LoadIndexInfo& info) { case DataType::INT64: { auto int64_index = dynamic_cast*>( scalar_indexings_[field_id].get()); - for (int i = 0; i < row_count; ++i) { - insert_record_.insert_pk(int64_index->Reverse_Lookup(i), i); + if (int64_index->HasRawData()) { + for (int i = 0; i < row_count; ++i) { + insert_record_.insert_pk(int64_index->Reverse_Lookup(i), + i); + } + insert_record_.seal_pks(); } - insert_record_.seal_pks(); break; } case DataType::VARCHAR: { auto string_index = dynamic_cast*>( scalar_indexings_[field_id].get()); - for (int i = 0; i < row_count; ++i) { - insert_record_.insert_pk(string_index->Reverse_Lookup(i), - i); + if (string_index->HasRawData()) { + for (int i = 0; i < row_count; ++i) { + insert_record_.insert_pk( + string_index->Reverse_Lookup(i), i); + } + insert_record_.seal_pks(); } - insert_record_.seal_pks(); break; } default: { From 156a0dd4501733e167c64015289ebabd3096e6ec Mon Sep 17 00:00:00 2001 From: "yihao.dai" Date: Sun, 7 Jan 2024 19:38:49 +0800 Subject: [PATCH 07/20] feat: Add import reader for Parquet 
(#29618) This PR implements a Parquet reader for import. issue: https://github.com/milvus-io/milvus/issues/28521 --------- Signed-off-by: bigsheeper --- .../util/importutilv2/parquet/field_reader.go | 568 ++++++++++++++++++ internal/util/importutilv2/parquet/reader.go | 113 ++++ .../util/importutilv2/parquet/reader_test.go | 447 ++++++++++++++ internal/util/importutilv2/parquet/util.go | 175 ++++++ 4 files changed, 1303 insertions(+) create mode 100644 internal/util/importutilv2/parquet/field_reader.go create mode 100644 internal/util/importutilv2/parquet/reader.go create mode 100644 internal/util/importutilv2/parquet/reader_test.go create mode 100644 internal/util/importutilv2/parquet/util.go diff --git a/internal/util/importutilv2/parquet/field_reader.go b/internal/util/importutilv2/parquet/field_reader.go new file mode 100644 index 0000000000000..162ea59e92ad3 --- /dev/null +++ b/internal/util/importutilv2/parquet/field_reader.go @@ -0,0 +1,568 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package parquet + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/apache/arrow/go/v12/arrow" + "github.com/apache/arrow/go/v12/arrow/array" + "github.com/apache/arrow/go/v12/parquet/pqarrow" + "github.com/samber/lo" + "golang.org/x/exp/constraints" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type FieldReader struct { + columnIndex int + columnReader *pqarrow.ColumnReader + + dim int + field *schemapb.FieldSchema +} + +func NewFieldReader(reader *pqarrow.FileReader, columnIndex int, field *schemapb.FieldSchema) (*FieldReader, error) { + columnReader, err := reader.GetColumn(context.Background(), columnIndex) // TODO: dyh, resolve context + if err != nil { + return nil, err + } + + var dim int64 = 1 + if typeutil.IsVectorType(field.GetDataType()) { + dim, err = typeutil.GetDim(field) + if err != nil { + return nil, err + } + } + + cr := &FieldReader{ + columnIndex: columnIndex, + columnReader: columnReader, + dim: int(dim), + field: field, + } + return cr, nil +} + +func (c *FieldReader) Next(count int64) (any, error) { + switch c.field.GetDataType() { + case schemapb.DataType_Bool: + return ReadBoolData(c, count) + case schemapb.DataType_Int8: + return ReadIntegerOrFloatData[int8](c, count) + case schemapb.DataType_Int16: + return ReadIntegerOrFloatData[int16](c, count) + case schemapb.DataType_Int32: + return ReadIntegerOrFloatData[int32](c, count) + case schemapb.DataType_Int64: + return ReadIntegerOrFloatData[int64](c, count) + case schemapb.DataType_Float: + data, err := ReadIntegerOrFloatData[float32](c, count) + if err != nil { + return nil, err + } + if data == nil { + return nil, nil + } + return data, 
typeutil.VerifyFloats32(data.([]float32)) + case schemapb.DataType_Double: + data, err := ReadIntegerOrFloatData[float64](c, count) + if err != nil { + return nil, err + } + if data == nil { + return nil, nil + } + return data, typeutil.VerifyFloats64(data.([]float64)) + case schemapb.DataType_VarChar, schemapb.DataType_String: + return ReadStringData(c, count) + case schemapb.DataType_JSON: + // JSON field read data from string array Parquet + data, err := ReadStringData(c, count) + if err != nil { + return nil, err + } + if data == nil { + return nil, nil + } + byteArr := make([][]byte, 0) + for _, str := range data.([]string) { + var dummy interface{} + err = json.Unmarshal([]byte(str), &dummy) + if err != nil { + return nil, err + } + byteArr = append(byteArr, []byte(str)) + } + return byteArr, nil + case schemapb.DataType_BinaryVector: + return ReadBinaryData(c, count) + case schemapb.DataType_FloatVector: + arrayData, err := ReadIntegerOrFloatArrayData[float32](c, count) + if err != nil { + return nil, err + } + if arrayData == nil { + return nil, nil + } + vectors := lo.Flatten(arrayData.([][]float32)) + return vectors, nil + case schemapb.DataType_Array: + data := make([]*schemapb.ScalarField, 0, count) + elementType := c.field.GetElementType() + switch elementType { + case schemapb.DataType_Bool: + boolArray, err := ReadBoolArrayData(c, count) + if err != nil { + return nil, err + } + if boolArray == nil { + return nil, nil + } + for _, elementArray := range boolArray.([][]bool) { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_BoolData{ + BoolData: &schemapb.BoolArray{ + Data: elementArray, + }, + }, + }) + } + + case schemapb.DataType_Int8: + int8Array, err := ReadIntegerOrFloatArrayData[int32](c, count) + if err != nil { + return nil, err + } + if int8Array == nil { + return nil, nil + } + for _, elementArray := range int8Array.([][]int32) { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: elementArray, + }, + }, + }) + } + + case schemapb.DataType_Int16: + int16Array, err := ReadIntegerOrFloatArrayData[int32](c, count) + if err != nil { + return nil, err + } + if int16Array == nil { + return nil, nil + } + for _, elementArray := range int16Array.([][]int32) { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: elementArray, + }, + }, + }) + } + + case schemapb.DataType_Int32: + int32Array, err := ReadIntegerOrFloatArrayData[int32](c, count) + if err != nil { + return nil, err + } + if int32Array == nil { + return nil, nil + } + for _, elementArray := range int32Array.([][]int32) { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: elementArray, + }, + }, + }) + } + + case schemapb.DataType_Int64: + int64Array, err := ReadIntegerOrFloatArrayData[int64](c, count) + if err != nil { + return nil, err + } + if int64Array == nil { + return nil, nil + } + for _, elementArray := range int64Array.([][]int64) { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: elementArray, + }, + }, + }) + } + + case schemapb.DataType_Float: + float32Array, err := ReadIntegerOrFloatArrayData[float32](c, count) + if err != nil { + return nil, err + } + if float32Array == nil { + return nil, nil + } + for _, elementArray := range float32Array.([][]float32) { + data = append(data, 
&schemapb.ScalarField{ + Data: &schemapb.ScalarField_FloatData{ + FloatData: &schemapb.FloatArray{ + Data: elementArray, + }, + }, + }) + } + + case schemapb.DataType_Double: + float64Array, err := ReadIntegerOrFloatArrayData[float64](c, count) + if err != nil { + return nil, err + } + if float64Array == nil { + return nil, nil + } + for _, elementArray := range float64Array.([][]float64) { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_DoubleData{ + DoubleData: &schemapb.DoubleArray{ + Data: elementArray, + }, + }, + }) + } + + case schemapb.DataType_VarChar, schemapb.DataType_String: + stringArray, err := ReadStringArrayData(c, count) + if err != nil { + return nil, err + } + if stringArray == nil { + return nil, nil + } + for _, elementArray := range stringArray.([][]string) { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: elementArray, + }, + }, + }) + } + default: + return nil, merr.WrapErrImportFailed(fmt.Sprintf("unsupported data type '%s' for array field '%s'", + elementType.String(), c.field.GetName())) + } + return data, nil + default: + return nil, merr.WrapErrImportFailed(fmt.Sprintf("unsupported data type '%s' for field '%s'", + c.field.GetDataType().String(), c.field.GetName())) + } +} + +func (c *FieldReader) Close() {} + +func ReadBoolData(pcr *FieldReader, count int64) (any, error) { + chunked, err := pcr.columnReader.NextBatch(count) + if err != nil { + return nil, err + } + data := make([]bool, 0, count) + for _, chunk := range chunked.Chunks() { + dataNums := chunk.Data().Len() + chunkData := make([]bool, dataNums) + boolReader, ok := chunk.(*array.Boolean) + if !ok { + return nil, WrapTypeErr("bool", chunk.DataType().Name(), pcr.field) + } + for i := 0; i < dataNums; i++ { + chunkData[i] = boolReader.Value(i) + } + data = append(data, chunkData...) + } + if len(data) == 0 { + return nil, nil + } + return data, nil +} + +func ReadIntegerOrFloatData[T constraints.Integer | constraints.Float](pcr *FieldReader, count int64) (any, error) { + chunked, err := pcr.columnReader.NextBatch(count) + if err != nil { + return nil, err + } + data := make([]T, 0, count) + for _, chunk := range chunked.Chunks() { + dataNums := chunk.Data().Len() + chunkData := make([]T, dataNums) + switch chunk.DataType().ID() { + case arrow.INT8: + int8Reader := chunk.(*array.Int8) + for i := 0; i < dataNums; i++ { + chunkData[i] = T(int8Reader.Value(i)) + } + case arrow.INT16: + int16Reader := chunk.(*array.Int16) + for i := 0; i < dataNums; i++ { + chunkData[i] = T(int16Reader.Value(i)) + } + case arrow.INT32: + int32Reader := chunk.(*array.Int32) + for i := 0; i < dataNums; i++ { + chunkData[i] = T(int32Reader.Value(i)) + } + case arrow.INT64: + int64Reader := chunk.(*array.Int64) + for i := 0; i < dataNums; i++ { + chunkData[i] = T(int64Reader.Value(i)) + } + case arrow.FLOAT32: + float32Reader := chunk.(*array.Float32) + for i := 0; i < dataNums; i++ { + chunkData[i] = T(float32Reader.Value(i)) + } + case arrow.FLOAT64: + float64Reader := chunk.(*array.Float64) + for i := 0; i < dataNums; i++ { + chunkData[i] = T(float64Reader.Value(i)) + } + default: + return nil, WrapTypeErr("integer|float", chunk.DataType().Name(), pcr.field) + } + data = append(data, chunkData...) 
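+		// NextBatch may deliver the requested rows split across several
+		// Arrow chunks; each chunk is widened to T above and concatenated so
+		// the caller sees one flat slice. An empty batch is reported as nil
+		// below, which callers treat as the end of the column.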
+ } + if len(data) == 0 { + return nil, nil + } + return data, nil +} + +func ReadStringData(pcr *FieldReader, count int64) (any, error) { + chunked, err := pcr.columnReader.NextBatch(count) + if err != nil { + return nil, err + } + data := make([]string, 0, count) + for _, chunk := range chunked.Chunks() { + dataNums := chunk.Data().Len() + chunkData := make([]string, dataNums) + stringReader, ok := chunk.(*array.String) + if !ok { + return nil, WrapTypeErr("string", chunk.DataType().Name(), pcr.field) + } + for i := 0; i < dataNums; i++ { + chunkData[i] = stringReader.Value(i) + } + data = append(data, chunkData...) + } + if len(data) == 0 { + return nil, nil + } + return data, nil +} + +func ReadBinaryData(pcr *FieldReader, count int64) (any, error) { + chunked, err := pcr.columnReader.NextBatch(count) + if err != nil { + return nil, err + } + data := make([]byte, 0, count) + for _, chunk := range chunked.Chunks() { + dataNums := chunk.Data().Len() + switch chunk.DataType().ID() { + case arrow.BINARY: + binaryReader := chunk.(*array.Binary) + for i := 0; i < dataNums; i++ { + data = append(data, binaryReader.Value(i)...) + } + case arrow.LIST: + listReader := chunk.(*array.List) + if !isRegularVector(listReader.Offsets(), pcr.dim, true) { + return nil, merr.WrapErrImportFailed("binary vector is irregular") + } + uint8Reader, ok := listReader.ListValues().(*array.Uint8) + if !ok { + return nil, WrapTypeErr("binary", listReader.ListValues().DataType().Name(), pcr.field) + } + for i := 0; i < uint8Reader.Len(); i++ { + data = append(data, uint8Reader.Value(i)) + } + default: + return nil, WrapTypeErr("binary", chunk.DataType().Name(), pcr.field) + } + } + if len(data) == 0 { + return nil, nil + } + return data, nil +} + +func isRegularVector(offsets []int32, dim int, isBinary bool) bool { + if len(offsets) < 1 { + return false + } + if isBinary { + dim = dim / 8 + } + for i := 1; i < len(offsets); i++ { + if offsets[i]-offsets[i-1] != int32(dim) { + return false + } + } + return true +} + +func ReadBoolArrayData(pcr *FieldReader, count int64) (any, error) { + chunked, err := pcr.columnReader.NextBatch(count) + if err != nil { + return nil, err + } + data := make([][]bool, 0, count) + for _, chunk := range chunked.Chunks() { + listReader, ok := chunk.(*array.List) + if !ok { + return nil, WrapTypeErr("list", chunk.DataType().Name(), pcr.field) + } + boolReader, ok := listReader.ListValues().(*array.Boolean) + if !ok { + return nil, WrapTypeErr("boolArray", chunk.DataType().Name(), pcr.field) + } + offsets := listReader.Offsets() + for i := 1; i < len(offsets); i++ { + start, end := offsets[i-1], offsets[i] + elementData := make([]bool, 0, end-start) + for j := start; j < end; j++ { + elementData = append(elementData, boolReader.Value(int(j))) + } + data = append(data, elementData) + } + } + if len(data) == 0 { + return nil, nil + } + return data, nil +} + +func ReadIntegerOrFloatArrayData[T constraints.Integer | constraints.Float](pcr *FieldReader, count int64) (any, error) { + chunked, err := pcr.columnReader.NextBatch(count) + if err != nil { + return nil, err + } + data := make([][]T, 0, count) + + getDataFunc := func(offsets []int32, getValue func(int) T) { + for i := 1; i < len(offsets); i++ { + start, end := offsets[i-1], offsets[i] + elementData := make([]T, 0, end-start) + for j := start; j < end; j++ { + elementData = append(elementData, getValue(int(j))) + } + data = append(data, elementData) + } + } + for _, chunk := range chunked.Chunks() { + listReader, ok := 
chunk.(*array.List) + if !ok { + return nil, WrapTypeErr("list", chunk.DataType().Name(), pcr.field) + } + offsets := listReader.Offsets() + if typeutil.IsVectorType(pcr.field.GetDataType()) && + !isRegularVector(offsets, pcr.dim, pcr.field.GetDataType() == schemapb.DataType_BinaryVector) { + return nil, merr.WrapErrImportFailed("float vector is irregular") + } + valueReader := listReader.ListValues() + switch valueReader.DataType().ID() { + case arrow.INT8: + int8Reader := valueReader.(*array.Int8) + getDataFunc(offsets, func(i int) T { + return T(int8Reader.Value(i)) + }) + case arrow.INT16: + int16Reader := valueReader.(*array.Int16) + getDataFunc(offsets, func(i int) T { + return T(int16Reader.Value(i)) + }) + case arrow.INT32: + int32Reader := valueReader.(*array.Int32) + getDataFunc(offsets, func(i int) T { + return T(int32Reader.Value(i)) + }) + case arrow.INT64: + int64Reader := valueReader.(*array.Int64) + getDataFunc(offsets, func(i int) T { + return T(int64Reader.Value(i)) + }) + case arrow.FLOAT32: + float32Reader := valueReader.(*array.Float32) + getDataFunc(offsets, func(i int) T { + return T(float32Reader.Value(i)) + }) + case arrow.FLOAT64: + float64Reader := valueReader.(*array.Float64) + getDataFunc(offsets, func(i int) T { + return T(float64Reader.Value(i)) + }) + default: + return nil, WrapTypeErr("integerArray|floatArray", chunk.DataType().Name(), pcr.field) + } + } + if len(data) == 0 { + return nil, nil + } + return data, nil +} + +func ReadStringArrayData(pcr *FieldReader, count int64) (any, error) { + chunked, err := pcr.columnReader.NextBatch(count) + if err != nil { + return nil, err + } + data := make([][]string, 0, count) + for _, chunk := range chunked.Chunks() { + listReader, ok := chunk.(*array.List) + if !ok { + return nil, WrapTypeErr("list", chunk.DataType().Name(), pcr.field) + } + stringReader, ok := listReader.ListValues().(*array.String) + if !ok { + return nil, WrapTypeErr("stringArray", chunk.DataType().Name(), pcr.field) + } + offsets := listReader.Offsets() + for i := 1; i < len(offsets); i++ { + start, end := offsets[i-1], offsets[i] + elementData := make([]string, 0, end-start) + for j := start; j < end; j++ { + elementData = append(elementData, stringReader.Value(int(j))) + } + data = append(data, elementData) + } + } + if len(data) == 0 { + return nil, nil + } + return data, nil +} diff --git a/internal/util/importutilv2/parquet/reader.go b/internal/util/importutilv2/parquet/reader.go new file mode 100644 index 0000000000000..ba2c4f9d21c28 --- /dev/null +++ b/internal/util/importutilv2/parquet/reader.go @@ -0,0 +1,113 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
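+
+// The reader below opens the Parquet file through a buffered stream and
+// creates one FieldReader (a pqarrow column reader) per schema field, keyed
+// by field ID. Read pulls rows column by column until the accumulated
+// InsertData reaches bufferSize; a nil InsertData with a nil error means
+// every column is exhausted. A minimal usage sketch, assuming cmReader is an
+// opened storage.FileReader whose Parquet column names match the collection
+// schema:
+//
+//	r, err := NewReader(schema, cmReader, 64<<20)
+//	if err != nil {
+//		return err
+//	}
+//	defer r.Close()
+//	for {
+//		data, err := r.Read()
+//		if err != nil {
+//			return err
+//		}
+//		if data == nil {
+//			break // all rows consumed
+//		}
+//		// consume the *storage.InsertData batch
+//	}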
+ +package parquet + +import ( + "fmt" + + "github.com/apache/arrow/go/v12/arrow/memory" + "github.com/apache/arrow/go/v12/parquet" + "github.com/apache/arrow/go/v12/parquet/file" + "github.com/apache/arrow/go/v12/parquet/pqarrow" + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/merr" +) + +type Reader struct { + reader *file.Reader + + bufferSize int + + schema *schemapb.CollectionSchema + frs map[int64]*FieldReader // fieldID -> FieldReader +} + +func NewReader(schema *schemapb.CollectionSchema, cmReader storage.FileReader, bufferSize int) (*Reader, error) { + const pqBufSize = 32 * 1024 * 1024 // TODO: dyh, make if configurable + size := calcBufferSize(pqBufSize, schema) + reader, err := file.NewParquetReader(cmReader, file.WithReadProps(&parquet.ReaderProperties{ + BufferSize: int64(size), + BufferedStreamEnabled: true, + })) + if err != nil { + return nil, merr.WrapErrImportFailed(fmt.Sprintf("new parquet reader failed, err=%v", err)) + } + log.Info("create parquet reader done", zap.Int("row group num", reader.NumRowGroups()), + zap.Int64("num rows", reader.NumRows())) + + fileReader, err := pqarrow.NewFileReader(reader, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator) + if err != nil { + return nil, merr.WrapErrImportFailed(fmt.Sprintf("new parquet file reader failed, err=%v", err)) + } + + crs, err := CreateFieldReaders(fileReader, schema) + if err != nil { + return nil, err + } + return &Reader{ + reader: reader, + bufferSize: bufferSize, + schema: schema, + frs: crs, + }, nil +} + +func (r *Reader) Read() (*storage.InsertData, error) { + insertData, err := storage.NewInsertData(r.schema) + if err != nil { + return nil, err + } +OUTER: + for { + for fieldID, cr := range r.frs { + data, err := cr.Next(1) + if err != nil { + return nil, err + } + if data == nil { + break OUTER + } + err = insertData.Data[fieldID].AppendRows(data) + if err != nil { + return nil, err + } + } + if insertData.GetMemorySize() >= r.bufferSize { + break + } + } + for fieldID := range r.frs { + if insertData.Data[fieldID].RowNum() == 0 { + return nil, nil + } + } + return insertData, nil +} + +func (r *Reader) Close() { + for _, cr := range r.frs { + cr.Close() + } + err := r.reader.Close() + if err != nil { + log.Warn("close parquet reader failed", zap.Error(err)) + } +} diff --git a/internal/util/importutilv2/parquet/reader_test.go b/internal/util/importutilv2/parquet/reader_test.go new file mode 100644 index 0000000000000..1b0e1a0639493 --- /dev/null +++ b/internal/util/importutilv2/parquet/reader_test.go @@ -0,0 +1,447 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
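+
+// The suite below mirrors a Milvus collection schema as an Arrow schema
+// (convertMilvusSchemaToArrowSchema), generates one Arrow array per field
+// with buildArrayData, and feeds the resulting Parquet data back through the
+// reader. Two on-disk layouts are covered for BinaryVector columns, a plain
+// Arrow binary column and a list of uint8, and JSON fields are written as
+// Arrow strings.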
+ +package parquet + +import ( + "context" + "fmt" + "io" + "math" + "math/rand" + "os" + "testing" + + "github.com/apache/arrow/go/v12/arrow" + "github.com/apache/arrow/go/v12/arrow/array" + "github.com/apache/arrow/go/v12/arrow/memory" + "github.com/apache/arrow/go/v12/parquet" + "github.com/apache/arrow/go/v12/parquet/pqarrow" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type ReaderSuite struct { + suite.Suite + + numRows int + pkDataType schemapb.DataType + vecDataType schemapb.DataType +} + +func (s *ReaderSuite) SetupSuite() { + paramtable.Get().Init(paramtable.NewBaseTable()) +} + +func (s *ReaderSuite) SetupTest() { + // default suite params + s.numRows = 100 + s.pkDataType = schemapb.DataType_Int64 + s.vecDataType = schemapb.DataType_FloatVector +} + +func milvusDataTypeToArrowType(dataType schemapb.DataType, isBinary bool) arrow.DataType { + switch dataType { + case schemapb.DataType_Bool: + return &arrow.BooleanType{} + case schemapb.DataType_Int8: + return &arrow.Int8Type{} + case schemapb.DataType_Int16: + return &arrow.Int16Type{} + case schemapb.DataType_Int32: + return &arrow.Int32Type{} + case schemapb.DataType_Int64: + return &arrow.Int64Type{} + case schemapb.DataType_Float: + return &arrow.Float32Type{} + case schemapb.DataType_Double: + return &arrow.Float64Type{} + case schemapb.DataType_VarChar, schemapb.DataType_String: + return &arrow.StringType{} + case schemapb.DataType_Array: + return &arrow.ListType{} + case schemapb.DataType_JSON: + return &arrow.StringType{} + case schemapb.DataType_FloatVector: + return arrow.ListOfField(arrow.Field{ + Name: "item", + Type: &arrow.Float32Type{}, + Nullable: true, + Metadata: arrow.Metadata{}, + }) + case schemapb.DataType_BinaryVector: + if isBinary { + return &arrow.BinaryType{} + } + return arrow.ListOfField(arrow.Field{ + Name: "item", + Type: &arrow.Uint8Type{}, + Nullable: true, + Metadata: arrow.Metadata{}, + }) + case schemapb.DataType_Float16Vector: + return arrow.ListOfField(arrow.Field{ + Name: "item", + Type: &arrow.Float16Type{}, + Nullable: true, + Metadata: arrow.Metadata{}, + }) + default: + panic("unsupported data type") + } +} + +func convertMilvusSchemaToArrowSchema(schema *schemapb.CollectionSchema) *arrow.Schema { + fields := make([]arrow.Field, 0) + for _, field := range schema.GetFields() { + if field.GetDataType() == schemapb.DataType_Array { + fields = append(fields, arrow.Field{ + Name: field.GetName(), + Type: arrow.ListOfField(arrow.Field{ + Name: "item", + Type: milvusDataTypeToArrowType(field.GetElementType(), false), + Nullable: true, + Metadata: arrow.Metadata{}, + }), + Nullable: true, + Metadata: arrow.Metadata{}, + }) + continue + } + fields = append(fields, arrow.Field{ + Name: field.GetName(), + Type: milvusDataTypeToArrowType(field.GetDataType(), field.Name == "FieldBinaryVector2"), + Nullable: true, + Metadata: arrow.Metadata{}, + }) + } + return arrow.NewSchema(fields, nil) +} + +func randomString(length int) string { + letterRunes := []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + b := make([]rune, length) + for i := range b { + b[i] = letterRunes[rand.Intn(len(letterRunes))] + } + return string(b) +} + +func 
buildArrayData(dataType, elementType schemapb.DataType, dim, rows int, isBinary bool) arrow.Array { + mem := memory.NewGoAllocator() + switch dataType { + case schemapb.DataType_Bool: + builder := array.NewBooleanBuilder(mem) + for i := 0; i < rows; i++ { + builder.Append(i%2 == 0) + } + return builder.NewBooleanArray() + case schemapb.DataType_Int8: + builder := array.NewInt8Builder(mem) + for i := 0; i < rows; i++ { + builder.Append(int8(i)) + } + return builder.NewInt8Array() + case schemapb.DataType_Int16: + builder := array.NewInt16Builder(mem) + for i := 0; i < rows; i++ { + builder.Append(int16(i)) + } + return builder.NewInt16Array() + case schemapb.DataType_Int32: + builder := array.NewInt32Builder(mem) + for i := 0; i < rows; i++ { + builder.Append(int32(i)) + } + return builder.NewInt32Array() + case schemapb.DataType_Int64: + builder := array.NewInt64Builder(mem) + for i := 0; i < rows; i++ { + builder.Append(int64(i)) + } + return builder.NewInt64Array() + case schemapb.DataType_Float: + builder := array.NewFloat32Builder(mem) + for i := 0; i < rows; i++ { + builder.Append(float32(i) * 0.1) + } + return builder.NewFloat32Array() + case schemapb.DataType_Double: + builder := array.NewFloat64Builder(mem) + for i := 0; i < rows; i++ { + builder.Append(float64(i) * 0.02) + } + return builder.NewFloat64Array() + case schemapb.DataType_VarChar, schemapb.DataType_String: + builder := array.NewStringBuilder(mem) + for i := 0; i < rows; i++ { + builder.Append(randomString(10)) + } + return builder.NewStringArray() + case schemapb.DataType_FloatVector: + builder := array.NewListBuilder(mem, &arrow.Float32Type{}) + offsets := make([]int32, 0, rows) + valid := make([]bool, 0, rows) + for i := 0; i < dim*rows; i++ { + builder.ValueBuilder().(*array.Float32Builder).Append(float32(i)) + } + for i := 0; i < rows; i++ { + offsets = append(offsets, int32(i*dim)) + valid = append(valid, true) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + case schemapb.DataType_BinaryVector: + if isBinary { + builder := array.NewBinaryBuilder(mem, &arrow.BinaryType{}) + for i := 0; i < rows; i++ { + element := make([]byte, dim/8) + for j := 0; j < dim/8; j++ { + element[j] = randomString(1)[0] + } + builder.Append(element) + } + return builder.NewBinaryArray() + } + builder := array.NewListBuilder(mem, &arrow.Uint8Type{}) + offsets := make([]int32, 0, rows) + valid := make([]bool, 0) + for i := 0; i < dim*rows/8; i++ { + builder.ValueBuilder().(*array.Uint8Builder).Append(uint8(i)) + } + for i := 0; i < rows; i++ { + offsets = append(offsets, int32(dim*i/8)) + valid = append(valid, true) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + case schemapb.DataType_JSON: + builder := array.NewStringBuilder(mem) + for i := 0; i < rows; i++ { + builder.Append(fmt.Sprintf("{\"a\": \"%s\", \"b\": %d}", randomString(3), i)) + } + return builder.NewStringArray() + case schemapb.DataType_Array: + offsets := make([]int32, 0, rows) + valid := make([]bool, 0, rows) + index := 0 + for i := 0; i < rows; i++ { + index += i % 10 + offsets = append(offsets, int32(index)) + valid = append(valid, true) + } + switch elementType { + case schemapb.DataType_Bool: + builder := array.NewListBuilder(mem, &arrow.BooleanType{}) + valueBuilder := builder.ValueBuilder().(*array.BooleanBuilder) + for i := 0; i < index; i++ { + valueBuilder.Append(i%2 == 0) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + case schemapb.DataType_Int8: + builder := 
array.NewListBuilder(mem, &arrow.Int8Type{}) + valueBuilder := builder.ValueBuilder().(*array.Int8Builder) + for i := 0; i < index; i++ { + valueBuilder.Append(int8(i)) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + case schemapb.DataType_Int16: + builder := array.NewListBuilder(mem, &arrow.Int16Type{}) + valueBuilder := builder.ValueBuilder().(*array.Int16Builder) + for i := 0; i < index; i++ { + valueBuilder.Append(int16(i)) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + case schemapb.DataType_Int32: + builder := array.NewListBuilder(mem, &arrow.Int32Type{}) + valueBuilder := builder.ValueBuilder().(*array.Int32Builder) + for i := 0; i < index; i++ { + valueBuilder.Append(int32(i)) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + case schemapb.DataType_Int64: + builder := array.NewListBuilder(mem, &arrow.Int64Type{}) + valueBuilder := builder.ValueBuilder().(*array.Int64Builder) + for i := 0; i < index; i++ { + valueBuilder.Append(int64(i)) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + case schemapb.DataType_Float: + builder := array.NewListBuilder(mem, &arrow.Float32Type{}) + valueBuilder := builder.ValueBuilder().(*array.Float32Builder) + for i := 0; i < index; i++ { + valueBuilder.Append(float32(i) * 0.1) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + case schemapb.DataType_Double: + builder := array.NewListBuilder(mem, &arrow.Float64Type{}) + valueBuilder := builder.ValueBuilder().(*array.Float64Builder) + for i := 0; i < index; i++ { + valueBuilder.Append(float64(i) * 0.02) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + case schemapb.DataType_VarChar, schemapb.DataType_String: + builder := array.NewListBuilder(mem, &arrow.StringType{}) + valueBuilder := builder.ValueBuilder().(*array.StringBuilder) + for i := 0; i < index; i++ { + valueBuilder.Append(randomString(5) + "-" + fmt.Sprintf("%d", i)) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + } + } + return nil +} + +func writeParquet(w io.Writer, schema *schemapb.CollectionSchema, numRows int) error { + pqSchema := convertMilvusSchemaToArrowSchema(schema) + fw, err := pqarrow.NewFileWriter(pqSchema, w, parquet.NewWriterProperties(parquet.WithMaxRowGroupLength(int64(numRows))), pqarrow.DefaultWriterProps()) + if err != nil { + return err + } + defer fw.Close() + + columns := make([]arrow.Array, 0, len(schema.Fields)) + for _, field := range schema.Fields { + var dim int64 = 1 + if typeutil.IsVectorType(field.GetDataType()) { + dim, err = typeutil.GetDim(field) + if err != nil { + return err + } + } + columnData := buildArrayData(field.DataType, field.ElementType, int(dim), numRows, field.Name == "FieldBinaryVector2") + columns = append(columns, columnData) + } + recordBatch := array.NewRecord(pqSchema, columns, int64(numRows)) + err = fw.Write(recordBatch) + if err != nil { + return err + } + + return nil +} + +func (s *ReaderSuite) run(dt schemapb.DataType) { + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + Name: "pk", + IsPrimaryKey: true, + DataType: s.pkDataType, + }, + { + FieldID: 101, + Name: "vec", + DataType: s.vecDataType, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "8", + }, + }, + }, + { + FieldID: 102, + Name: dt.String(), + DataType: dt, + ElementType: schemapb.DataType_Int32, + }, + }, + } + + filePath := 
fmt.Sprintf("/tmp/test_%d_reader.parquet", rand.Int()) + defer os.Remove(filePath) + wf, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE, 0o666) + assert.NoError(s.T(), err) + err = writeParquet(wf, schema, s.numRows) + assert.NoError(s.T(), err) + + ctx := context.Background() + f := storage.NewChunkManagerFactory("local", storage.RootPath("/tmp/milvus_test/test_parquet_reader/")) + cm, err := f.NewPersistentStorageChunkManager(ctx) + assert.NoError(s.T(), err) + cmReader, err := cm.Reader(ctx, filePath) + assert.NoError(s.T(), err) + reader, err := NewReader(schema, cmReader, math.MaxInt) + s.NoError(err) + + checkFn := func(actualInsertData *storage.InsertData, offsetBegin, expectRows int) { + // expectInsertData := insertData + for _, data := range actualInsertData.Data { + s.Equal(expectRows, data.RowNum()) + // TODO: dyh, check rows + // fieldDataType := typeutil.GetField(schema, fieldID).GetDataType() + // for i := 0; i < expectRows; i++ { + // expect := expectInsertData.Data[fieldID].GetRow(i + offsetBegin) + // actual := data.GetRow(i) + // if fieldDataType == schemapb.DataType_Array { + // s.True(slices.Equal(expect.(*schemapb.ScalarField).GetIntData().GetData(), actual.(*schemapb.ScalarField).GetIntData().GetData())) + // } else { + // s.Equal(expect, actual) + // } + // } + } + } + + res, err := reader.Read() + s.NoError(err) + checkFn(res, 0, s.numRows) +} + +func (s *ReaderSuite) TestReadScalarFields() { + s.run(schemapb.DataType_Bool) + s.run(schemapb.DataType_Int8) + s.run(schemapb.DataType_Int16) + s.run(schemapb.DataType_Int32) + s.run(schemapb.DataType_Int64) + s.run(schemapb.DataType_Float) + s.run(schemapb.DataType_Double) + s.run(schemapb.DataType_VarChar) + s.run(schemapb.DataType_Array) + s.run(schemapb.DataType_JSON) +} + +func (s *ReaderSuite) TestStringPK() { + s.pkDataType = schemapb.DataType_VarChar + s.run(schemapb.DataType_Int32) +} + +func (s *ReaderSuite) TestBinaryAndFloat16Vector() { + s.vecDataType = schemapb.DataType_BinaryVector + s.run(schemapb.DataType_Int32) + // s.vecDataType = schemapb.DataType_Float16Vector + // s.run(schemapb.DataType_Int32) // TODO: dyh, support float16 vector +} + +func TestUtil(t *testing.T) { + suite.Run(t, new(ReaderSuite)) +} diff --git a/internal/util/importutilv2/parquet/util.go b/internal/util/importutilv2/parquet/util.go new file mode 100644 index 0000000000000..3c2e4baac5673 --- /dev/null +++ b/internal/util/importutilv2/parquet/util.go @@ -0,0 +1,175 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
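+
+// Editor's note, not part of the original patch: calcBufferSize below divides
+// the reader's fixed 32 MiB read-ahead budget evenly across the schema's
+// fields, so the total buffered I/O stays roughly constant however wide the
+// collection is. A quick worked check of that split (values illustrative):
+//
+//	32 MiB /  1 field  = 33554432 bytes per column reader
+//	32 MiB /  4 fields =  8388608 bytes per column reader
+//	32 MiB / 16 fields =  2097152 bytes per column reader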
+
+package parquet
+
+import (
+	"fmt"
+
+	"github.com/apache/arrow/go/v12/arrow"
+	"github.com/apache/arrow/go/v12/parquet/pqarrow"
+	"github.com/samber/lo"
+
+	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+	"github.com/milvus-io/milvus/pkg/util/merr"
+	"github.com/milvus-io/milvus/pkg/util/typeutil"
+)
+
+func WrapTypeErr(expect string, actual string, field *schemapb.FieldSchema) error {
+	return merr.WrapErrImportFailed(
+		fmt.Sprintf("expect '%s' type for field '%s', but got '%s' type",
+			expect, field.GetName(), actual))
+}
+
+func calcBufferSize(blockSize int, schema *schemapb.CollectionSchema) int {
+	if len(schema.GetFields()) <= 0 {
+		return blockSize
+	}
+	return blockSize / len(schema.GetFields())
+}
+
+func CreateFieldReaders(fileReader *pqarrow.FileReader, schema *schemapb.CollectionSchema) (map[int64]*FieldReader, error) {
+	nameToField := lo.KeyBy(schema.GetFields(), func(field *schemapb.FieldSchema) string {
+		return field.GetName()
+	})
+
+	pqSchema, err := fileReader.Schema()
+	if err != nil {
+		return nil, merr.WrapErrImportFailed(fmt.Sprintf("get parquet schema failed, err=%v", err))
+	}
+
+	crs := make(map[int64]*FieldReader)
+	for i, pqField := range pqSchema.Fields() {
+		field, ok := nameToField[pqField.Name]
+		if !ok {
+			// TODO @cai.zhang: handle dynamic field
+			return nil, merr.WrapErrImportFailed(fmt.Sprintf("the field: %s is not in schema, "+
+				"if it's a dynamic field, please reformat data by bulk_writer", pqField.Name))
+		}
+		if field.GetIsPrimaryKey() && field.GetAutoID() {
+			return nil, merr.WrapErrImportFailed(
+				fmt.Sprintf("the primary key '%s' is auto-generated, no need to provide", field.GetName()))
+		}
+
+		arrowType, isList := convertArrowSchemaToDataType(pqField, false)
+		dataType := field.GetDataType()
+		if isList {
+			if !typeutil.IsVectorType(dataType) && dataType != schemapb.DataType_Array {
+				return nil, WrapTypeErr(dataType.String(), pqField.Type.Name(), field)
+			}
+			if dataType == schemapb.DataType_Array {
+				dataType = field.GetElementType()
+			}
+		}
+		if !isConvertible(arrowType, dataType, isList) {
+			return nil, WrapTypeErr(dataType.String(), pqField.Type.Name(), field)
+		}
+
+		cr, err := NewFieldReader(fileReader, i, field)
+		if err != nil {
+			return nil, err
+		}
+		if _, ok = crs[field.GetFieldID()]; ok {
+			return nil, merr.WrapErrImportFailed(
+				fmt.Sprintf("there are multiple fields with the name: %s", field.GetName()))
+		}
+		crs[field.GetFieldID()] = cr
+	}
+
+	for _, field := range nameToField {
+		if (field.GetIsPrimaryKey() && field.GetAutoID()) || field.GetIsDynamic() {
+			continue
+		}
+		if _, ok := crs[field.GetFieldID()]; !ok {
+			return nil, merr.WrapErrImportFailed(
+				fmt.Sprintf("no parquet field for milvus field '%s'", field.GetName()))
+		}
+	}
+	return crs, nil
+}
+
+func convertArrowSchemaToDataType(field arrow.Field, isList bool) (schemapb.DataType, bool) {
+	switch field.Type.ID() {
+	case arrow.BOOL:
+		return schemapb.DataType_Bool, false
+	case arrow.UINT8:
+		if isList {
+			return schemapb.DataType_BinaryVector, false
+		}
+		return schemapb.DataType_None, false
+	case arrow.INT8:
+		return schemapb.DataType_Int8, false
+	case arrow.INT16:
+		return schemapb.DataType_Int16, false
+	case arrow.INT32:
+		return schemapb.DataType_Int32, false
+	case arrow.INT64:
+		return schemapb.DataType_Int64, false
+	case arrow.FLOAT16:
+		if isList {
+			return schemapb.DataType_Float16Vector, false
+		}
+		return schemapb.DataType_None, false
+	case arrow.FLOAT32:
+		return schemapb.DataType_Float, false
+	case arrow.FLOAT64:
+		return
schemapb.DataType_Double, false + case arrow.STRING: + return schemapb.DataType_VarChar, false + case arrow.BINARY: + return schemapb.DataType_BinaryVector, false + case arrow.LIST: + elementType, _ := convertArrowSchemaToDataType(field.Type.(*arrow.ListType).ElemField(), true) + return elementType, true + default: + return schemapb.DataType_None, false + } +} + +func isConvertible(src, dst schemapb.DataType, isList bool) bool { + switch src { + case schemapb.DataType_Bool: + return typeutil.IsBoolType(dst) + case schemapb.DataType_Int8: + return typeutil.IsArithmetic(dst) + case schemapb.DataType_Int16: + return typeutil.IsArithmetic(dst) && dst != schemapb.DataType_Int8 + case schemapb.DataType_Int32: + return typeutil.IsArithmetic(dst) && dst != schemapb.DataType_Int8 && dst != schemapb.DataType_Int16 + case schemapb.DataType_Int64: + return typeutil.IsFloatingType(dst) || dst == schemapb.DataType_Int64 + case schemapb.DataType_Float: + if isList && dst == schemapb.DataType_FloatVector { + return true + } + return typeutil.IsFloatingType(dst) + case schemapb.DataType_Double: + if isList && dst == schemapb.DataType_FloatVector { + return true + } + return dst == schemapb.DataType_Double + case schemapb.DataType_String, schemapb.DataType_VarChar: + return typeutil.IsStringType(dst) || typeutil.IsJSONType(dst) + case schemapb.DataType_JSON: + return typeutil.IsJSONType(dst) + case schemapb.DataType_BinaryVector: + return dst == schemapb.DataType_BinaryVector + case schemapb.DataType_Float16Vector: + return dst == schemapb.DataType_Float16Vector + default: + return false + } +} From 635a7f777c0352dbe97ceccd01678c25344e64c4 Mon Sep 17 00:00:00 2001 From: wayblink Date: Sun, 7 Jan 2024 19:56:48 +0800 Subject: [PATCH 08/20] feat: add clustering key in create/describe collection (#29506) #28410 /kind feature Signed-off-by: wayblink --- go.mod | 2 +- go.sum | 4 +- internal/metastore/model/field.go | 105 +++++++++-------- internal/proxy/task.go | 60 ++++++++-- internal/proxy/task_test.go | 183 ++++++++++++++++++++++++++++++ pkg/util/merr/errors.go | 1 + pkg/util/merr/utils.go | 9 ++ 7 files changed, 299 insertions(+), 65 deletions(-) diff --git a/go.mod b/go.mod index d26e77871ce7e..218725be30525 100644 --- a/go.mod +++ b/go.mod @@ -23,7 +23,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 github.com/klauspost/compress v1.16.7 github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d - github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20231228051838-b5442d755fa4 + github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20231229025438-39bce6abb18f github.com/minio/minio-go/v7 v7.0.61 github.com/prometheus/client_golang v1.14.0 github.com/prometheus/client_model v0.3.0 diff --git a/go.sum b/go.sum index 95c831d51ad9f..2ca0615cd3183 100644 --- a/go.sum +++ b/go.sum @@ -583,8 +583,8 @@ github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/le github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b h1:TfeY0NxYxZzUfIfYe5qYDBzt4ZYRqzUjTR6CvUzjat8= github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b/go.mod h1:iwW+9cWfIzzDseEBCCeDSN5SD16Tidvy8cwQ7ZY8Qj4= -github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20231228051838-b5442d755fa4 h1:nxIohfJOCMbixFAC3q4Lclmv0xg/8q6D8T7D8l258To= -github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20231228051838-b5442d755fa4/go.mod h1:1OIl0v5PQeNxIJhCvY+K55CBUOYDZevw9g9380u1Wek= +github.com/milvus-io/milvus-proto/go-api/v2 
v2.3.4-0.20231229025438-39bce6abb18f h1:8lNcRqhQgUROtmtiIEdpQHGW82KMI5oASVKxkaZ/tBg= +github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20231229025438-39bce6abb18f/go.mod h1:1OIl0v5PQeNxIJhCvY+K55CBUOYDZevw9g9380u1Wek= github.com/milvus-io/milvus-storage/go v0.0.0-20231109072809-1cd7b0866092 h1:UYJ7JB+QlMOoFHNdd8mUa3/lV63t9dnBX7ILXmEEWPY= github.com/milvus-io/milvus-storage/go v0.0.0-20231109072809-1cd7b0866092/go.mod h1:GPETMcTZq1gLY1WA6Na5kiNAKnq8SEMMiVKUZrM3sho= github.com/milvus-io/pulsar-client-go v0.6.10 h1:eqpJjU+/QX0iIhEo3nhOqMNXL+TyInAs1IAHZCrCM/A= diff --git a/internal/metastore/model/field.go b/internal/metastore/model/field.go index 10d44604d2406..a4d906a24a546 100644 --- a/internal/metastore/model/field.go +++ b/internal/metastore/model/field.go @@ -7,19 +7,20 @@ import ( ) type Field struct { - FieldID int64 - Name string - IsPrimaryKey bool - Description string - DataType schemapb.DataType - TypeParams []*commonpb.KeyValuePair - IndexParams []*commonpb.KeyValuePair - AutoID bool - State schemapb.FieldState - IsDynamic bool - IsPartitionKey bool // partition key mode, multi logic partitions share a physical partition - DefaultValue *schemapb.ValueField - ElementType schemapb.DataType + FieldID int64 + Name string + IsPrimaryKey bool + Description string + DataType schemapb.DataType + TypeParams []*commonpb.KeyValuePair + IndexParams []*commonpb.KeyValuePair + AutoID bool + State schemapb.FieldState + IsDynamic bool + IsPartitionKey bool // partition key mode, multi logic partitions share a physical partition + IsClusteringKey bool + DefaultValue *schemapb.ValueField + ElementType schemapb.DataType } func (f *Field) Available() bool { @@ -28,19 +29,20 @@ func (f *Field) Available() bool { func (f *Field) Clone() *Field { return &Field{ - FieldID: f.FieldID, - Name: f.Name, - IsPrimaryKey: f.IsPrimaryKey, - Description: f.Description, - DataType: f.DataType, - TypeParams: common.CloneKeyValuePairs(f.TypeParams), - IndexParams: common.CloneKeyValuePairs(f.IndexParams), - AutoID: f.AutoID, - State: f.State, - IsDynamic: f.IsDynamic, - IsPartitionKey: f.IsPartitionKey, - DefaultValue: f.DefaultValue, - ElementType: f.ElementType, + FieldID: f.FieldID, + Name: f.Name, + IsPrimaryKey: f.IsPrimaryKey, + Description: f.Description, + DataType: f.DataType, + TypeParams: common.CloneKeyValuePairs(f.TypeParams), + IndexParams: common.CloneKeyValuePairs(f.IndexParams), + AutoID: f.AutoID, + State: f.State, + IsDynamic: f.IsDynamic, + IsPartitionKey: f.IsPartitionKey, + IsClusteringKey: f.IsClusteringKey, + DefaultValue: f.DefaultValue, + ElementType: f.ElementType, } } @@ -68,6 +70,7 @@ func (f *Field) Equal(other Field) bool { f.AutoID == other.AutoID && f.IsPartitionKey == other.IsPartitionKey && f.IsDynamic == other.IsDynamic && + f.IsClusteringKey == other.IsClusteringKey && f.DefaultValue == other.DefaultValue && f.ElementType == other.ElementType } @@ -91,18 +94,19 @@ func MarshalFieldModel(field *Field) *schemapb.FieldSchema { } return &schemapb.FieldSchema{ - FieldID: field.FieldID, - Name: field.Name, - IsPrimaryKey: field.IsPrimaryKey, - Description: field.Description, - DataType: field.DataType, - TypeParams: field.TypeParams, - IndexParams: field.IndexParams, - AutoID: field.AutoID, - IsDynamic: field.IsDynamic, - IsPartitionKey: field.IsPartitionKey, - DefaultValue: field.DefaultValue, - ElementType: field.ElementType, + FieldID: field.FieldID, + Name: field.Name, + IsPrimaryKey: field.IsPrimaryKey, + Description: field.Description, + DataType: field.DataType, + 
TypeParams: field.TypeParams, + IndexParams: field.IndexParams, + AutoID: field.AutoID, + IsDynamic: field.IsDynamic, + IsPartitionKey: field.IsPartitionKey, + IsClusteringKey: field.IsClusteringKey, + DefaultValue: field.DefaultValue, + ElementType: field.ElementType, } } @@ -124,18 +128,19 @@ func UnmarshalFieldModel(fieldSchema *schemapb.FieldSchema) *Field { } return &Field{ - FieldID: fieldSchema.FieldID, - Name: fieldSchema.Name, - IsPrimaryKey: fieldSchema.IsPrimaryKey, - Description: fieldSchema.Description, - DataType: fieldSchema.DataType, - TypeParams: fieldSchema.TypeParams, - IndexParams: fieldSchema.IndexParams, - AutoID: fieldSchema.AutoID, - IsDynamic: fieldSchema.IsDynamic, - IsPartitionKey: fieldSchema.IsPartitionKey, - DefaultValue: fieldSchema.DefaultValue, - ElementType: fieldSchema.ElementType, + FieldID: fieldSchema.FieldID, + Name: fieldSchema.Name, + IsPrimaryKey: fieldSchema.IsPrimaryKey, + Description: fieldSchema.Description, + DataType: fieldSchema.DataType, + TypeParams: fieldSchema.TypeParams, + IndexParams: fieldSchema.IndexParams, + AutoID: fieldSchema.AutoID, + IsDynamic: fieldSchema.IsDynamic, + IsPartitionKey: fieldSchema.IsPartitionKey, + IsClusteringKey: fieldSchema.IsClusteringKey, + DefaultValue: fieldSchema.DefaultValue, + ElementType: fieldSchema.ElementType, } } diff --git a/internal/proxy/task.go b/internal/proxy/task.go index a3f9bf16e39d3..bb57d377e4098 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -208,6 +208,36 @@ func (t *createCollectionTask) validatePartitionKey() error { return nil } +func (t *createCollectionTask) validateClusteringKey() error { + idx := -1 + for i, field := range t.schema.Fields { + if field.GetIsClusteringKey() { + if idx != -1 { + return merr.WrapErrCollectionIllegalSchema(t.CollectionName, + fmt.Sprintf("there are more than one clustering key, field name = %s, %s", t.schema.Fields[idx].Name, field.Name)) + } + + if field.GetIsPrimaryKey() { + return merr.WrapErrCollectionIllegalSchema(t.CollectionName, + fmt.Sprintf("the clustering key field must not be primary key field, field name = %s", field.Name)) + } + + if field.GetIsPartitionKey() { + return merr.WrapErrCollectionIllegalSchema(t.CollectionName, + fmt.Sprintf("the clustering key field must not be partition key field, field name = %s", field.Name)) + } + idx = i + } + } + + if idx != -1 { + log.Info("create collection with clustering key", + zap.String("collectionName", t.CollectionName), + zap.String("clusteringKeyField", t.schema.Fields[idx].Name)) + } + return nil +} + func (t *createCollectionTask) PreExecute(ctx context.Context) error { t.Base.MsgType = commonpb.MsgType_CreateCollection t.Base.SourceID = paramtable.GetNodeID() @@ -266,6 +296,11 @@ func (t *createCollectionTask) PreExecute(ctx context.Context) error { return err } + // validate clustering key + if err := t.validateClusteringKey(); err != nil { + return err + } + for _, field := range t.schema.Fields { // validate field name if err := validateFieldName(field.Name); err != nil { @@ -572,18 +607,19 @@ func (t *describeCollectionTask) Execute(ctx context.Context) error { } if field.FieldID >= common.StartOfUserFieldID { t.result.Schema.Fields = append(t.result.Schema.Fields, &schemapb.FieldSchema{ - FieldID: field.FieldID, - Name: field.Name, - IsPrimaryKey: field.IsPrimaryKey, - AutoID: field.AutoID, - Description: field.Description, - DataType: field.DataType, - TypeParams: field.TypeParams, - IndexParams: field.IndexParams, - IsDynamic: field.IsDynamic, - 
IsPartitionKey: field.IsPartitionKey, - DefaultValue: field.DefaultValue, - ElementType: field.ElementType, + FieldID: field.FieldID, + Name: field.Name, + IsPrimaryKey: field.IsPrimaryKey, + AutoID: field.AutoID, + Description: field.Description, + DataType: field.DataType, + TypeParams: field.TypeParams, + IndexParams: field.IndexParams, + IsDynamic: field.IsDynamic, + IsPartitionKey: field.IsPartitionKey, + IsClusteringKey: field.IsClusteringKey, + DefaultValue: field.DefaultValue, + ElementType: field.ElementType, }) } } diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index 4bb083ee5fd89..eb3acf48d7889 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -3493,3 +3493,186 @@ func TestPartitionKey(t *testing.T) { assert.Error(t, err) }) } + +func TestClusteringKey(t *testing.T) { + rc := NewRootCoordMock() + + defer rc.Close() + qc := getQueryCoordClient() + + ctx := context.Background() + + mgr := newShardClientMgr() + err := InitMetaCache(ctx, rc, qc, mgr) + assert.NoError(t, err) + + shardsNum := common.DefaultShardsNum + prefix := "TestClusteringKey" + collectionName := prefix + funcutil.GenRandomStr() + + t.Run("create collection normal", func(t *testing.T) { + fieldName2Type := make(map[string]schemapb.DataType) + fieldName2Type["int64_field"] = schemapb.DataType_Int64 + fieldName2Type["varChar_field"] = schemapb.DataType_VarChar + schema := constructCollectionSchemaByDataType(collectionName, fieldName2Type, "int64_field", false) + fieldName2Type["cluster_key_field"] = schemapb.DataType_Int64 + clusterKeyField := &schemapb.FieldSchema{ + Name: "cluster_key_field", + DataType: schemapb.DataType_Int64, + IsClusteringKey: true, + } + schema.Fields = append(schema.Fields, clusterKeyField) + vecField := &schemapb.FieldSchema{ + Name: "fvec_field", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: strconv.Itoa(testVecDim), + }, + }, + } + schema.Fields = append(schema.Fields, vecField) + marshaledSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + + createCollectionTask := &createCollectionTask{ + Condition: NewTaskCondition(ctx), + CreateCollectionRequest: &milvuspb.CreateCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgID: UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()), + Timestamp: Timestamp(time.Now().UnixNano()), + }, + DbName: "", + CollectionName: collectionName, + Schema: marshaledSchema, + ShardsNum: shardsNum, + }, + ctx: ctx, + rootCoord: rc, + result: nil, + schema: nil, + } + err = createCollectionTask.PreExecute(ctx) + assert.NoError(t, err) + err = createCollectionTask.Execute(ctx) + assert.NoError(t, err) + }) + + t.Run("create collection clustering key can not be partition key", func(t *testing.T) { + fieldName2Type := make(map[string]schemapb.DataType) + fieldName2Type["int64_field"] = schemapb.DataType_Int64 + fieldName2Type["varChar_field"] = schemapb.DataType_VarChar + fieldName2Type["fvec_field"] = schemapb.DataType_FloatVector + schema := constructCollectionSchemaByDataType(collectionName, fieldName2Type, "int64_field", false) + fieldName2Type["cluster_key_field"] = schemapb.DataType_Int64 + clusterKeyField := &schemapb.FieldSchema{ + Name: "cluster_key_field", + DataType: schemapb.DataType_Int64, + IsClusteringKey: true, + IsPartitionKey: true, + } + schema.Fields = append(schema.Fields, clusterKeyField) + marshaledSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + + createCollectionTask := 
&createCollectionTask{ + Condition: NewTaskCondition(ctx), + CreateCollectionRequest: &milvuspb.CreateCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgID: UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()), + Timestamp: Timestamp(time.Now().UnixNano()), + }, + DbName: "", + CollectionName: collectionName, + Schema: marshaledSchema, + ShardsNum: shardsNum, + }, + ctx: ctx, + rootCoord: rc, + result: nil, + schema: nil, + } + err = createCollectionTask.PreExecute(ctx) + assert.Error(t, err) + }) + + t.Run("create collection clustering key can not be primary key", func(t *testing.T) { + fieldName2Type := make(map[string]schemapb.DataType) + fieldName2Type["varChar_field"] = schemapb.DataType_VarChar + fieldName2Type["fvec_field"] = schemapb.DataType_FloatVector + schema := constructCollectionSchemaByDataType(collectionName, fieldName2Type, "int64_field", false) + fieldName2Type["cluster_key_field"] = schemapb.DataType_Int64 + clusterKeyField := &schemapb.FieldSchema{ + Name: "cluster_key_field", + DataType: schemapb.DataType_Int64, + IsClusteringKey: true, + IsPrimaryKey: true, + } + schema.Fields = append(schema.Fields, clusterKeyField) + marshaledSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + + createCollectionTask := &createCollectionTask{ + Condition: NewTaskCondition(ctx), + CreateCollectionRequest: &milvuspb.CreateCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgID: UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()), + Timestamp: Timestamp(time.Now().UnixNano()), + }, + DbName: "", + CollectionName: collectionName, + Schema: marshaledSchema, + ShardsNum: shardsNum, + }, + ctx: ctx, + rootCoord: rc, + result: nil, + schema: nil, + } + err = createCollectionTask.PreExecute(ctx) + assert.Error(t, err) + }) + + t.Run("create collection not support more than one clustering key", func(t *testing.T) { + fieldName2Type := make(map[string]schemapb.DataType) + fieldName2Type["int64_field"] = schemapb.DataType_Int64 + fieldName2Type["varChar_field"] = schemapb.DataType_VarChar + schema := constructCollectionSchemaByDataType(collectionName, fieldName2Type, "int64_field", false) + fieldName2Type["cluster_key_field"] = schemapb.DataType_Int64 + clusterKeyField := &schemapb.FieldSchema{ + Name: "cluster_key_field", + DataType: schemapb.DataType_Int64, + IsClusteringKey: true, + } + schema.Fields = append(schema.Fields, clusterKeyField) + clusterKeyField2 := &schemapb.FieldSchema{ + Name: "cluster_key_field2", + DataType: schemapb.DataType_Int64, + IsClusteringKey: true, + } + schema.Fields = append(schema.Fields, clusterKeyField2) + marshaledSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + + createCollectionTask := &createCollectionTask{ + Condition: NewTaskCondition(ctx), + CreateCollectionRequest: &milvuspb.CreateCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgID: UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()), + Timestamp: Timestamp(time.Now().UnixNano()), + }, + DbName: "", + CollectionName: collectionName, + Schema: marshaledSchema, + ShardsNum: shardsNum, + }, + ctx: ctx, + rootCoord: rc, + result: nil, + schema: nil, + } + err = createCollectionTask.PreExecute(ctx) + assert.Error(t, err) + }) +} diff --git a/pkg/util/merr/errors.go b/pkg/util/merr/errors.go index 3f3a3095f94f8..411e108749064 100644 --- a/pkg/util/merr/errors.go +++ b/pkg/util/merr/errors.go @@ -50,6 +50,7 @@ var ( ErrCollectionNumLimitExceeded = newMilvusError("exceeded the limit number of collections", 102, false) ErrCollectionNotFullyLoaded = 
newMilvusError("collection not fully loaded", 103, true) ErrCollectionLoaded = newMilvusError("collection already loaded", 104, false) + ErrCollectionIllegalSchema = newMilvusError("illegal collection schema", 105, false) // Partition related ErrPartitionNotFound = newMilvusError("partition not found", 200, false) diff --git a/pkg/util/merr/utils.go b/pkg/util/merr/utils.go index abe3178d4429f..9310edb0d5b45 100644 --- a/pkg/util/merr/utils.go +++ b/pkg/util/merr/utils.go @@ -465,6 +465,15 @@ func WrapErrCollectionLoaded(collection string, msgAndArgs ...any) error { return err } +func WrapErrCollectionIllegalSchema(collection string, msgAndArgs ...any) error { + err := wrapFields(ErrCollectionIllegalSchema, value("collection", collection)) + if len(msgAndArgs) > 0 { + msg := msgAndArgs[0].(string) + err = errors.Wrapf(err, msg, msgAndArgs[1:]...) + } + return err +} + func WrapErrAliasNotFound(db any, alias any, msg ...string) error { err := wrapFields(ErrAliasNotFound, value("database", db), From 271edc6669026ade76726d72b2a419aa9b30d683 Mon Sep 17 00:00:00 2001 From: foxspy Date: Sun, 7 Jan 2024 20:03:13 +0800 Subject: [PATCH 09/20] fix: throw exception when upload file failed for DiskIndex (#29627) related to : #29417 cardinal indexes upload index files in `Serialize` interface, and throw exception when the `Serialize` failed. Signed-off-by: xianliang --- internal/core/src/index/VectorDiskIndex.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/internal/core/src/index/VectorDiskIndex.cpp b/internal/core/src/index/VectorDiskIndex.cpp index e6a525453760e..6fd6267a7973f 100644 --- a/internal/core/src/index/VectorDiskIndex.cpp +++ b/internal/core/src/index/VectorDiskIndex.cpp @@ -134,7 +134,11 @@ template BinarySet VectorDiskAnnIndex::Upload(const Config& config) { BinarySet ret; - index_.Serialize(ret); + auto stat = index_.Serialize(ret); + if (stat != knowhere::Status::success) { + PanicInfo(ErrorCode::UnexpectedError, + "failed to serialize index, " + KnowhereStatusString(stat)); + } auto remote_paths_to_size = file_manager_->GetRemotePathsToFileSize(); for (auto& file : remote_paths_to_size) { ret.Append(file.first, nullptr, file.second); From 4b3de6473387d0bfddc1af510a25ba7f3a7666c1 Mon Sep 17 00:00:00 2001 From: Jiquan Long Date: Sun, 7 Jan 2024 20:04:12 +0800 Subject: [PATCH 10/20] enhance: add rust to install_dep.sh (#29586) fix: #29585 Signed-off-by: longjiquan --- scripts/install_deps.sh | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/scripts/install_deps.sh b/scripts/install_deps.sh index de6adb4b12160..608262f425c06 100755 --- a/scripts/install_deps.sh +++ b/scripts/install_deps.sh @@ -52,6 +52,15 @@ function install_linux_deps() { else echo "cmake version is $cmake_version" fi + # install rust + if command -v cargo >/dev/null 2>&1; then + echo "cargo exists" + rustup install 1.73 + rustup default 1.73 + else + bash -c "curl https://sh.rustup.rs -sSf | sh -s -- --default-toolchain=1.73 -y" || { echo 'rustup install failed'; exit 1;} + source $HOME/.cargo/env + fi } function install_mac_deps() { @@ -68,6 +77,15 @@ function install_mac_deps() { fi sudo ln -s "$(brew --prefix llvm@15)" "/usr/local/opt/llvm" + # install rust + if command -v cargo >/dev/null 2>&1; then + echo "cargo exists" + rustup install 1.73 + rustup default 1.73 + else + bash -c "curl https://sh.rustup.rs -sSf | sh -s -- --default-toolchain=1.73 -y" || { echo 'rustup install failed'; exit 1;} + source $HOME/.cargo/env + fi } if ! 
command -v go &> /dev/null From a3bae80b59c9b32a194b7c4c2979b395cfe951c7 Mon Sep 17 00:00:00 2001 From: Jiquan Long Date: Sun, 7 Jan 2024 20:09:49 +0800 Subject: [PATCH 11/20] enhance: print total memory when milvus starts (#29351) fix: #29349 --------- Signed-off-by: longjiquan --- cmd/milvus/run.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/cmd/milvus/run.go b/cmd/milvus/run.go index bbb19eb88ab7a..17f8e55ce2b35 100644 --- a/cmd/milvus/run.go +++ b/cmd/milvus/run.go @@ -11,6 +11,7 @@ import ( "go.uber.org/zap" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/hardware" "github.com/milvus-io/milvus/pkg/util/metricsinfo" ) @@ -36,6 +37,7 @@ func (c *run) execute(args []string, flags *flag.FlagSet) { c.printBanner(flags.Output()) c.injectVariablesToEnv() + c.printHardwareInfo(flags.Output()) lock, err := createPidFile(flags.Output(), filename, runtimeDir) if err != nil { panic(err) @@ -59,6 +61,14 @@ func (c *run) printBanner(w io.Writer) { fmt.Fprintln(w) } +func (c *run) printHardwareInfo(w io.Writer) { + totalMem := hardware.GetMemoryCount() + usedMem := hardware.GetUsedMemoryCount() + fmt.Fprintf(w, "TotalMem: %d\n", totalMem) + fmt.Fprintf(w, "UsedMem: %d\n", usedMem) + fmt.Fprintln(w) +} + func (c *run) injectVariablesToEnv() { // inject in need From d07197ab1a065ac57b2f659da9e913629b040fd8 Mon Sep 17 00:00:00 2001 From: zhagnlu <1542303831@qq.com> Date: Sun, 7 Jan 2024 20:20:57 +0800 Subject: [PATCH 12/20] enhance: add compare simd function (#29432) #26137 Signed-off-by: luzhang Co-authored-by: luzhang --- .../core/src/exec/expression/CompareExpr.h | 21 +- .../src/query/visitors/ExecExprVisitor.cpp | 21 - internal/core/src/simd/CMakeLists.txt | 5 +- internal/core/src/simd/avx2.cpp | 41 +- internal/core/src/simd/avx512.cpp | 725 +++++++++++++++++- internal/core/src/simd/avx512.h | 48 ++ internal/core/src/simd/common.h | 9 + internal/core/src/simd/hook.cpp | 489 ++++++++++-- internal/core/src/simd/hook.h | 205 +++-- internal/core/src/simd/interface.h | 264 +++++++ internal/core/src/simd/ref.h | 94 +++ internal/core/src/simd/sse2.cpp | 44 +- internal/core/src/simd/sse4.cpp | 4 +- internal/core/unittest/test_simd.cpp | 315 +++++++- 14 files changed, 2044 insertions(+), 241 deletions(-) create mode 100644 internal/core/src/simd/interface.h diff --git a/internal/core/src/exec/expression/CompareExpr.h b/internal/core/src/exec/expression/CompareExpr.h index 5b0497e0b8988..c05974eb5429c 100644 --- a/internal/core/src/exec/expression/CompareExpr.h +++ b/internal/core/src/exec/expression/CompareExpr.h @@ -24,6 +24,7 @@ #include "common/Vector.h" #include "exec/expression/Expr.h" #include "segcore/SegmentInterface.h" +#include "simd/interface.h" namespace milvus { namespace exec { @@ -41,7 +42,7 @@ using ChunkDataAccessor = std::function; template struct CompareElementFunc { void - operator()(const T* left, const U* right, size_t size, bool* res) { + operator_base(const T* left, const U* right, size_t size, bool* res) { for (int i = 0; i < size; ++i) { if constexpr (op == proto::plan::OpType::Equal) { res[i] = left[i] == right[i]; @@ -63,6 +64,24 @@ struct CompareElementFunc { } } } + + void + operator()(const T* left, const U* right, size_t size, bool* res) { +#if defined(USE_DYNAMIC_SIMD) + if constexpr (std::is_same_v) { + milvus::simd::compare_col_func( + static_cast(op), + left, + right, + size, + res); + } else { + operator_base(left, right, size, res); + } +#else + operator_base(left, right, size, res); +#endif + } }; class 
PhyCompareFilterExpr : public Expr { diff --git a/internal/core/src/query/visitors/ExecExprVisitor.cpp b/internal/core/src/query/visitors/ExecExprVisitor.cpp index 808f1758abe9f..e6a8ef901c663 100644 --- a/internal/core/src/query/visitors/ExecExprVisitor.cpp +++ b/internal/core/src/query/visitors/ExecExprVisitor.cpp @@ -2632,30 +2632,9 @@ ExecExprVisitor::ExecTermVisitorImplTemplate(TermExpr& expr_raw) -> BitsetType { return index->In(n, terms.data()); }; -#if defined(USE_DYNAMIC_SIMD) - std::function x)> elem_func; - if (n <= milvus::simd::TERM_EXPR_IN_SIZE_THREAD) { - elem_func = [&terms, &term_set, n](MayConstRef x) { - if constexpr (std::is_integral::value || - std::is_floating_point::value) { - return milvus::simd::find_term_func(terms.data(), n, x); - } else { - // For string type, simd performance not better than set mode - static_assert(std::is_same::value || - std::is_same::value); - return term_set.find(x) != term_set.end(); - } - }; - } else { - elem_func = [&term_set, n](MayConstRef x) { - return term_set.find(x) != term_set.end(); - }; - } -#else auto elem_func = [&term_set](MayConstRef x) { return term_set.find(x) != term_set.end(); }; -#endif auto default_skip_index_func = [&](const SkipIndex& skipIndex, FieldId fieldId, diff --git a/internal/core/src/simd/CMakeLists.txt b/internal/core/src/simd/CMakeLists.txt index ced8277197f20..632373da08ca6 100644 --- a/internal/core/src/simd/CMakeLists.txt +++ b/internal/core/src/simd/CMakeLists.txt @@ -25,7 +25,8 @@ if (${CMAKE_SYSTEM_PROCESSOR} STREQUAL "x86_64") ) set_source_files_properties(sse4.cpp PROPERTIES COMPILE_FLAGS "-msse4.2") set_source_files_properties(avx2.cpp PROPERTIES COMPILE_FLAGS "-mavx2") - set_source_files_properties(avx512.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512dq -mavx512bw") + set_source_files_properties(avx512.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512vl -mavx512dq -mavx512bw") + elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm*") # TODO: add arm cpu simd message ("simd using arm mode") @@ -37,4 +38,4 @@ endif() add_library(milvus_simd ${MILVUS_SIMD_SRCS}) # Link the milvus_simd library with other libraries as needed -target_link_libraries(milvus_simd milvus_log) \ No newline at end of file +target_link_libraries(milvus_simd milvus_log) diff --git a/internal/core/src/simd/avx2.cpp b/internal/core/src/simd/avx2.cpp index 08c6a2636d9f0..1ea51faccabb6 100644 --- a/internal/core/src/simd/avx2.cpp +++ b/internal/core/src/simd/avx2.cpp @@ -29,11 +29,12 @@ GetBitsetBlockAVX2(const bool* src) { // BitsetBlockType has 64 bits __m256i highbit = _mm256_set1_epi8(0x7F); uint32_t tmp[8]; - for (size_t i = 0; i < 2; i += 1) { - __m256i boolvec = _mm256_loadu_si256((__m256i*)&src[i * 32]); - __m256i highbits = _mm256_add_epi8(boolvec, highbit); - tmp[i] = _mm256_movemask_epi8(highbits); - } + __m256i boolvec = _mm256_loadu_si256((__m256i*)(src)); + __m256i highbits = _mm256_add_epi8(boolvec, highbit); + tmp[0] = _mm256_movemask_epi8(highbits); + boolvec = _mm256_loadu_si256((__m256i*)(src + 32)); + highbits = _mm256_add_epi8(boolvec, highbit); + tmp[1] = _mm256_movemask_epi8(highbits); __m256i tmpvec = _mm256_loadu_si256((__m256i*)tmp); BitsetBlockType res[4]; @@ -65,9 +66,9 @@ FindTermAVX2(const bool* src, size_t vec_size, bool val) { __m256i ymm_data; size_t num_chunks = vec_size / 32; - for (size_t i = 0; i < num_chunks; i++) { + for (size_t i = 0; i < 32 * num_chunks; i += 32) { ymm_data = - _mm256_loadu_si256(reinterpret_cast(src + 32 * i)); + _mm256_loadu_si256(reinterpret_cast(src + i)); __m256i ymm_match 
= _mm256_cmpeq_epi8(ymm_data, ymm_target); int mask = _mm256_movemask_epi8(ymm_match); if (mask != 0) { @@ -90,9 +91,9 @@ FindTermAVX2(const int8_t* src, size_t vec_size, int8_t val) { __m256i ymm_data; size_t num_chunks = vec_size / 32; - for (size_t i = 0; i < num_chunks; i++) { + for (size_t i = 0; i < 32 * num_chunks; i += 32) { ymm_data = - _mm256_loadu_si256(reinterpret_cast(src + 32 * i)); + _mm256_loadu_si256(reinterpret_cast(src + i)); __m256i ymm_match = _mm256_cmpeq_epi8(ymm_data, ymm_target); int mask = _mm256_movemask_epi8(ymm_match); if (mask != 0) { @@ -114,10 +115,9 @@ FindTermAVX2(const int16_t* src, size_t vec_size, int16_t val) { __m256i ymm_target = _mm256_set1_epi16(val); __m256i ymm_data; size_t num_chunks = vec_size / 16; - size_t remaining_size = vec_size % 16; - for (size_t i = 0; i < num_chunks; i++) { + for (size_t i = 0; i < 16 * num_chunks; i += 16) { ymm_data = - _mm256_loadu_si256(reinterpret_cast(src + 16 * i)); + _mm256_loadu_si256(reinterpret_cast(src + i)); __m256i ymm_match = _mm256_cmpeq_epi16(ymm_data, ymm_target); int mask = _mm256_movemask_epi8(ymm_match); if (mask != 0) { @@ -141,9 +141,9 @@ FindTermAVX2(const int32_t* src, size_t vec_size, int32_t val) { size_t num_chunks = vec_size / 8; size_t remaining_size = vec_size % 8; - for (size_t i = 0; i < num_chunks; i++) { + for (size_t i = 0; i < 8 * num_chunks; i += 8) { ymm_data = - _mm256_loadu_si256(reinterpret_cast(src + 8 * i)); + _mm256_loadu_si256(reinterpret_cast(src + i)); __m256i ymm_match = _mm256_cmpeq_epi32(ymm_data, ymm_target); int mask = _mm256_movemask_epi8(ymm_match); if (mask != 0) { @@ -163,11 +163,10 @@ FindTermAVX2(const int64_t* src, size_t vec_size, int64_t val) { __m256i ymm_target = _mm256_set1_epi64x(val); __m256i ymm_data; size_t num_chunks = vec_size / 4; - size_t remaining_size = vec_size % 4; - for (size_t i = 0; i < num_chunks; i++) { + for (size_t i = 0; i < 4 * num_chunks; i += 4) { ymm_data = - _mm256_loadu_si256(reinterpret_cast(src + 4 * i)); + _mm256_loadu_si256(reinterpret_cast(src + i)); __m256i ymm_match = _mm256_cmpeq_epi64(ymm_data, ymm_target); int mask = _mm256_movemask_epi8(ymm_match); if (mask != 0) { @@ -190,8 +189,8 @@ FindTermAVX2(const float* src, size_t vec_size, float val) { __m256 ymm_data; size_t num_chunks = vec_size / 8; - for (size_t i = 0; i < num_chunks; i++) { - ymm_data = _mm256_loadu_ps(src + 8 * i); + for (size_t i = 0; i < 8 * num_chunks; i += 8) { + ymm_data = _mm256_loadu_ps(src + i); __m256 ymm_match = _mm256_cmp_ps(ymm_data, ymm_target, _CMP_EQ_OQ); int mask = _mm256_movemask_ps(ymm_match); if (mask != 0) { @@ -214,8 +213,8 @@ FindTermAVX2(const double* src, size_t vec_size, double val) { __m256d ymm_data; size_t num_chunks = vec_size / 4; - for (size_t i = 0; i < num_chunks; i++) { - ymm_data = _mm256_loadu_pd(src + 8 * i); + for (size_t i = 0; i < 4 * num_chunks; i += 4) { + ymm_data = _mm256_loadu_pd(src + i); __m256d ymm_match = _mm256_cmp_pd(ymm_data, ymm_target, _CMP_EQ_OQ); int mask = _mm256_movemask_pd(ymm_match); if (mask != 0) { diff --git a/internal/core/src/simd/avx512.cpp b/internal/core/src/simd/avx512.cpp index 3df38319fdade..e1bc4da3ffe15 100644 --- a/internal/core/src/simd/avx512.cpp +++ b/internal/core/src/simd/avx512.cpp @@ -25,9 +25,9 @@ FindTermAVX512(const bool* src, size_t vec_size, bool val) { __m512i zmm_data; size_t num_chunks = vec_size / 64; - for (size_t i = 0; i < num_chunks; i++) { + for (size_t i = 0; i < 64 * num_chunks; i += 64) { zmm_data = - _mm512_loadu_si512(reinterpret_cast(src + 64 * i)); + 
_mm512_loadu_si512(reinterpret_cast(src + i)); __mmask64 mask = _mm512_cmpeq_epi8_mask(zmm_data, zmm_target); if (mask != 0) { return true; @@ -49,9 +49,9 @@ FindTermAVX512(const int8_t* src, size_t vec_size, int8_t val) { __m512i zmm_data; size_t num_chunks = vec_size / 64; - for (size_t i = 0; i < num_chunks; i++) { + for (size_t i = 0; i < 64 * num_chunks; i += 64) { zmm_data = - _mm512_loadu_si512(reinterpret_cast(src + 64 * i)); + _mm512_loadu_si512(reinterpret_cast(src + i)); __mmask64 mask = _mm512_cmpeq_epi8_mask(zmm_data, zmm_target); if (mask != 0) { return true; @@ -73,9 +73,9 @@ FindTermAVX512(const int16_t* src, size_t vec_size, int16_t val) { __m512i zmm_data; size_t num_chunks = vec_size / 32; - for (size_t i = 0; i < num_chunks; i++) { + for (size_t i = 0; i < 32 * num_chunks; i += 32) { zmm_data = - _mm512_loadu_si512(reinterpret_cast(src + 32 * i)); + _mm512_loadu_si512(reinterpret_cast(src + i)); __mmask32 mask = _mm512_cmpeq_epi16_mask(zmm_data, zmm_target); if (mask != 0) { return true; @@ -97,9 +97,9 @@ FindTermAVX512(const int32_t* src, size_t vec_size, int32_t val) { __m512i zmm_data; size_t num_chunks = vec_size / 16; - for (size_t i = 0; i < num_chunks; i++) { + for (size_t i = 0; i < 16 * num_chunks; i += 16) { zmm_data = - _mm512_loadu_si512(reinterpret_cast(src + 16 * i)); + _mm512_loadu_si512(reinterpret_cast(src + i)); __mmask16 mask = _mm512_cmpeq_epi32_mask(zmm_data, zmm_target); if (mask != 0) { return true; @@ -121,9 +121,9 @@ FindTermAVX512(const int64_t* src, size_t vec_size, int64_t val) { __m512i zmm_data; size_t num_chunks = vec_size / 8; - for (size_t i = 0; i < num_chunks; i++) { + for (size_t i = 0; i < 8 * num_chunks; i += 8) { zmm_data = - _mm512_loadu_si512(reinterpret_cast(src + 8 * i)); + _mm512_loadu_si512(reinterpret_cast(src + i)); __mmask8 mask = _mm512_cmpeq_epi64_mask(zmm_data, zmm_target); if (mask != 0) { return true; @@ -145,8 +145,8 @@ FindTermAVX512(const float* src, size_t vec_size, float val) { __m512 zmm_data; size_t num_chunks = vec_size / 16; - for (size_t i = 0; i < num_chunks; i++) { - zmm_data = _mm512_loadu_ps(src + 16 * i); + for (size_t i = 0; i < 16 * num_chunks; i += 16) { + zmm_data = _mm512_loadu_ps(src + i); __mmask16 mask = _mm512_cmp_ps_mask(zmm_data, zmm_target, _CMP_EQ_OQ); if (mask != 0) { return true; @@ -168,8 +168,8 @@ FindTermAVX512(const double* src, size_t vec_size, double val) { __m512d zmm_data; size_t num_chunks = vec_size / 8; - for (size_t i = 0; i < num_chunks; i++) { - zmm_data = _mm512_loadu_pd(src + 8 * i); + for (size_t i = 0; i < 8 * num_chunks; i += 8) { + zmm_data = _mm512_loadu_pd(src + i); __mmask8 mask = _mm512_cmp_pd_mask(zmm_data, zmm_target, _CMP_EQ_OQ); if (mask != 0) { return true; @@ -216,6 +216,703 @@ OrBoolAVX512(bool* left, bool* right, int64_t size) { } } +template +struct CompareOperator; + +template +struct CompareOperator { + static constexpr int ComparePredicate = + std::is_floating_point_v ? _CMP_EQ_OQ : _MM_CMPINT_EQ; + static constexpr bool + Op(T a, T b) { + return a == b; + } +}; + +template +struct CompareOperator { + static constexpr int ComparePredicate = + std::is_floating_point_v ? _CMP_NEQ_OQ : _MM_CMPINT_NE; + static constexpr bool + Op(T a, T b) { + return a != b; + } +}; + +template +struct CompareOperator { + static constexpr int ComparePredicate = + std::is_floating_point_v ? 
_CMP_LT_OQ : _MM_CMPINT_LT; + static constexpr bool + Op(T a, T b) { + return a < b; + } +}; + +template +struct CompareOperator { + static constexpr int ComparePredicate = + std::is_floating_point_v ? _CMP_LE_OQ : _MM_CMPINT_LE; + static constexpr bool + Op(T a, T b) { + return a <= b; + } +}; + +template +struct CompareOperator { + static constexpr int ComparePredicate = + std::is_floating_point_v ? _CMP_GT_OQ : _MM_CMPINT_NLE; + static constexpr bool + Op(T a, T b) { + return a > b; + } +}; + +template +struct CompareOperator { + static constexpr int ComparePredicate = + std::is_floating_point_v ? _CMP_GE_OQ : _MM_CMPINT_NLT; + static constexpr bool + Op(T a, T b) { + return a >= b; + } +}; + +template +struct CompareValAVX512Impl { + static void + Compare(const T* src, size_t size, T val, bool* res) { + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); + } +}; + +template +struct CompareValAVX512Impl { + static void + Compare(const int8_t* src, size_t size, int8_t val, bool* res) { + __m512i target = _mm512_set1_epi8(val); + + int middle = size / 64 * 64; + + for (size_t i = 0; i < middle; i += 64) { + __m512i data = + _mm512_loadu_si512(reinterpret_cast(src + i)); + + __mmask64 cmp_res_mask = _mm512_cmp_epi8_mask( + data, + target, + (CompareOperator::ComparePredicate)); + __m512i cmp_res = _mm512_maskz_set1_epi8(cmp_res_mask, 0x01); + _mm512_storeu_si512(res + i, cmp_res); + } + + for (size_t i = middle; i < size; ++i) { + res[i] = CompareOperator::Op(src[i], val); + } + } +}; + +template +struct CompareValAVX512Impl { + static void + Compare(const int16_t* src, size_t size, int16_t val, bool* res) { + __m512i target = _mm512_set1_epi16(val); + + int middle = size / 32 * 32; + + for (size_t i = 0; i < middle; i += 32) { + __m512i data = + _mm512_loadu_si512(reinterpret_cast(src + i)); + + __mmask32 cmp_res_mask = _mm512_cmp_epi16_mask( + data, + target, + (CompareOperator::ComparePredicate)); + __m256i cmp_res = _mm256_maskz_set1_epi8(cmp_res_mask, 0x01); + _mm256_storeu_si256((__m256i*)(res + i), cmp_res); + } + + for (size_t i = middle; i < size; ++i) { + res[i] = CompareOperator::Op(src[i], val); + } + } +}; + +template +struct CompareValAVX512Impl { + static void + Compare(const int32_t* src, size_t size, int32_t val, bool* res) { + __m512i target = _mm512_set1_epi32(val); + + int middle = size / 16 * 16; + + for (size_t i = 0; i < middle; i += 16) { + __m512i data = + _mm512_loadu_si512(reinterpret_cast(src + i)); + + __mmask16 cmp_res_mask = _mm512_cmp_epi32_mask( + data, + target, + (CompareOperator::ComparePredicate)); + __m128i cmp_res = _mm_maskz_set1_epi8(cmp_res_mask, 0x01); + _mm_storeu_si128((__m128i*)(res + i), cmp_res); + } + + for (size_t i = middle; i < size; ++i) { + res[i] = CompareOperator::Op(src[i], val); + } + } +}; + +template +struct CompareValAVX512Impl { + static void + Compare(const int64_t* src, size_t size, int64_t val, bool* res) { + __m512i target = _mm512_set1_epi64(val); + int middle = size / 8 * 8; + int index = 0; + for (size_t i = 0; i < middle; i += 8) { + __m512i data = + _mm512_loadu_si512(reinterpret_cast(src + i)); + __mmask8 mask = _mm512_cmp_epi64_mask( + data, + target, + (CompareOperator::ComparePredicate)); + __m128i cmp_res = _mm_maskz_set1_epi8(mask, 0x01); + _mm_storeu_si64((__m128i*)(res + i), cmp_res); + } + + for (size_t i = middle; i < size; ++i) { + res[i] = CompareOperator::Op(src[i], val); + } + } +}; + +template +struct CompareValAVX512Impl { + static void + 
Compare(const float* src, size_t size, float val, bool* res) { + __m512 target = _mm512_set1_ps(val); + + int middle = size / 16 * 16; + + for (size_t i = 0; i < middle; i += 16) { + __m512 data = _mm512_loadu_ps(src + i); + + __mmask16 cmp_res_mask = _mm512_cmp_ps_mask( + data, target, (CompareOperator::ComparePredicate)); + __m128i cmp_res = _mm_maskz_set1_epi8(cmp_res_mask, 0x01); + _mm_storeu_si128((__m128i*)(res + i), cmp_res); + } + + for (size_t i = middle; i < size; ++i) { + res[i] = CompareOperator::Op(src[i], val); + } + } +}; + +template +struct CompareValAVX512Impl { + static void + Compare(const double* src, size_t size, double val, bool* res) { + __m512d target = _mm512_set1_pd(val); + + int middle = size / 8 * 8; + + for (size_t i = 0; i < middle; i += 8) { + __m512d data = _mm512_loadu_pd(src + i); + + __mmask8 cmp_res_mask = _mm512_cmp_pd_mask( + data, + target, + (CompareOperator::ComparePredicate)); + __m128i cmp_res = _mm_maskz_set1_epi8(cmp_res_mask, 0x01); + _mm_storeu_si64((res + i), cmp_res); + } + + for (size_t i = middle; i < size; ++i) { + res[i] = CompareOperator::Op(src[i], val); + } + } +}; + +template +void +EqualValAVX512(const T* src, size_t size, T val, bool* res) { + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); + CompareValAVX512Impl::Compare(src, size, val, res); +}; +template void +EqualValAVX512(const int8_t* src, size_t size, int8_t val, bool* res); +template void +EqualValAVX512(const int16_t* src, size_t size, int16_t val, bool* res); +template void +EqualValAVX512(const int32_t* src, size_t size, int32_t val, bool* res); +template void +EqualValAVX512(const int64_t* src, size_t size, int64_t val, bool* res); +template void +EqualValAVX512(const float* src, size_t size, float val, bool* res); +template void +EqualValAVX512(const double* src, size_t size, double val, bool* res); + +template +void +LessValAVX512(const T* src, size_t size, T val, bool* res) { + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); + CompareValAVX512Impl::Compare(src, size, val, res); +}; +template void +LessValAVX512(const int8_t* src, size_t size, int8_t val, bool* res); +template void +LessValAVX512(const int16_t* src, size_t size, int16_t val, bool* res); +template void +LessValAVX512(const int32_t* src, size_t size, int32_t val, bool* res); +template void +LessValAVX512(const int64_t* src, size_t size, int64_t val, bool* res); +template void +LessValAVX512(const float* src, size_t size, float val, bool* res); +template void +LessValAVX512(const double* src, size_t size, double val, bool* res); + +template +void +GreaterValAVX512(const T* src, size_t size, T val, bool* res) { + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); + CompareValAVX512Impl::Compare(src, size, val, res); +}; +template void +GreaterValAVX512(const int8_t* src, size_t size, int8_t val, bool* res); +template void +GreaterValAVX512(const int16_t* src, size_t size, int16_t val, bool* res); +template void +GreaterValAVX512(const int32_t* src, size_t size, int32_t val, bool* res); +template void +GreaterValAVX512(const int64_t* src, size_t size, int64_t val, bool* res); +template void +GreaterValAVX512(const float* src, size_t size, float val, bool* res); +template void +GreaterValAVX512(const double* src, size_t size, double val, bool* res); + +template +void +NotEqualValAVX512(const T* src, size_t size, T val, bool* res) { 
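+    // Editorial note, not part of the original patch: every *ValAVX512 wrapper
+    // in this block forwards to the matching CompareValAVX512Impl::Compare,
+    // which runs the main loop in full 512-bit registers (64/32/16/8 lanes for
+    // 8/16/32/64-bit elements), widens each predicate-mask bit into a 0/1
+    // result byte with the maskz_set1_epi8 family, and finishes any remainder
+    // with the scalar CompareOperator::Op loop.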
+ static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); + CompareValAVX512Impl::Compare(src, size, val, res); +}; +template void +NotEqualValAVX512(const int8_t* src, size_t size, int8_t val, bool* res); +template void +NotEqualValAVX512(const int16_t* src, size_t size, int16_t val, bool* res); +template void +NotEqualValAVX512(const int32_t* src, size_t size, int32_t val, bool* res); +template void +NotEqualValAVX512(const int64_t* src, size_t size, int64_t val, bool* res); +template void +NotEqualValAVX512(const float* src, size_t size, float val, bool* res); +template void +NotEqualValAVX512(const double* src, size_t size, double val, bool* res); + +template +void +LessEqualValAVX512(const T* src, size_t size, T val, bool* res) { + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); + CompareValAVX512Impl::Compare(src, size, val, res); +}; +template void +LessEqualValAVX512(const int8_t* src, size_t size, int8_t val, bool* res); +template void +LessEqualValAVX512(const int16_t* src, size_t size, int16_t val, bool* res); +template void +LessEqualValAVX512(const int32_t* src, size_t size, int32_t val, bool* res); +template void +LessEqualValAVX512(const int64_t* src, size_t size, int64_t val, bool* res); +template void +LessEqualValAVX512(const float* src, size_t size, float val, bool* res); +template void +LessEqualValAVX512(const double* src, size_t size, double val, bool* res); + +template +void +GreaterEqualValAVX512(const T* src, size_t size, T val, bool* res) { + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); + CompareValAVX512Impl::Compare(src, size, val, res); +}; +template void +GreaterEqualValAVX512(const int8_t* src, size_t size, int8_t val, bool* res); +template void +GreaterEqualValAVX512(const int16_t* src, size_t size, int16_t val, bool* res); +template void +GreaterEqualValAVX512(const int32_t* src, size_t size, int32_t val, bool* res); +template void +GreaterEqualValAVX512(const int64_t* src, size_t size, int64_t val, bool* res); +template void +GreaterEqualValAVX512(const float* src, size_t size, float val, bool* res); +template void +GreaterEqualValAVX512(const double* src, size_t size, double val, bool* res); + +template +void +CompareColumnAVX512(const T* left, const T* right, size_t size, bool* res) { + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); +} + +template +struct CompareColumnAVX512Impl { + static void + Compare(const T* left, const T* right, size_t size, bool* res) { + static_assert(std::is_integral_v, "T must be integral type"); + + int batch_size = 512 / (sizeof(T) * 8); + int middle = size / batch_size * batch_size; + + for (size_t i = 0; i < middle; i += batch_size) { + __m512i left_reg = + _mm512_loadu_si512(reinterpret_cast(left + i)); + __m512i right_reg = + _mm512_loadu_si512(reinterpret_cast(right + i)); + + if constexpr (std::is_same_v) { + __mmask64 cmp_res_mask = _mm512_cmp_epi8_mask( + left_reg, + right_reg, + (CompareOperator::ComparePredicate)); + + __m512i cmp_res = _mm512_maskz_set1_epi8(cmp_res_mask, 0x01); + _mm512_storeu_si512(res + i, cmp_res); + } else if constexpr (std::is_same_v) { + __mmask32 cmp_res_mask = _mm512_cmp_epi16_mask( + left_reg, + right_reg, + (CompareOperator::ComparePredicate)); + + __m256i cmp_res = _mm256_maskz_set1_epi8(cmp_res_mask, 0x01); + _mm256_storeu_si256((__m256i*)(res + i), 
cmp_res); + } else if constexpr (std::is_same_v) { + __mmask16 cmp_res_mask = _mm512_cmp_epi32_mask( + left_reg, + right_reg, + (CompareOperator::ComparePredicate)); + + __m128i cmp_res = _mm_maskz_set1_epi8(cmp_res_mask, 0x01); + _mm_storeu_si128((__m128i*)(res + i), cmp_res); + } else if constexpr (std::is_same_v) { + __mmask8 mask = _mm512_cmp_epi64_mask( + left_reg, + right_reg, + (CompareOperator::ComparePredicate)); + + __m128i cmp_res = _mm_maskz_set1_epi8(mask, 0x01); + _mm_storeu_si64((__m128i*)(res + i), cmp_res); + } + } + + for (size_t i = middle; i < size; ++i) { + res[i] = CompareOperator::Op(left[i], right[i]); + } + } +}; + +template +struct CompareColumnAVX512Impl { + static void + Compare(const float* left, const float* right, size_t size, bool* res) { + int batch_size = 512 / (sizeof(float) * 8); + int middle = size / batch_size * batch_size; + + for (size_t i = 0; i < middle; i += batch_size) { + __m512 left_reg = + _mm512_loadu_ps(reinterpret_cast(left + i)); + __m512 right_reg = + _mm512_loadu_ps(reinterpret_cast(right + i)); + + __mmask16 cmp_res_mask = _mm512_cmp_ps_mask( + left_reg, + right_reg, + (CompareOperator::ComparePredicate)); + + __m128i cmp_res = _mm_maskz_set1_epi8(cmp_res_mask, 0x01); + _mm_storeu_si128((__m128i*)(res + i), cmp_res); + } + + for (size_t i = middle; i < size; ++i) { + res[i] = CompareOperator::Op(left[i], right[i]); + } + } +}; + +template +struct CompareColumnAVX512Impl { + static void + Compare(const double* left, const double* right, size_t size, bool* res) { + int batch_size = 512 / (sizeof(double) * 8); + int middle = size / batch_size * batch_size; + + for (size_t i = 0; i < middle; i += batch_size) { + __m512d left_reg = + _mm512_loadu_pd(reinterpret_cast(left + i)); + __m512d right_reg = + _mm512_loadu_pd(reinterpret_cast(right + i)); + + __mmask8 cmp_res_mask = _mm512_cmp_pd_mask( + left_reg, + right_reg, + (CompareOperator::ComparePredicate)); + + __m128i cmp_res = _mm_maskz_set1_epi8(cmp_res_mask, 0x01); + _mm_storeu_si64((res + i), cmp_res); + } + + for (size_t i = middle; i < size; ++i) { + res[i] = CompareOperator::Op(left[i], right[i]); + } + } +}; + +template +void +EqualColumnAVX512(const T* left, const T* right, size_t size, bool* res) { + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); + CompareColumnAVX512Impl::Compare( + left, right, size, res); +}; + +template void +EqualColumnAVX512(const int8_t* left, + const int8_t* right, + size_t size, + bool* res); +template void +EqualColumnAVX512(const int16_t* left, + const int16_t* right, + size_t size, + bool* res); +template void +EqualColumnAVX512(const int32_t* left, + const int32_t* right, + size_t size, + bool* res); +template void +EqualColumnAVX512(const int64_t* left, + const int64_t* right, + size_t size, + bool* res); +template void +EqualColumnAVX512(const float* left, + const float* right, + size_t size, + bool* res); +template void +EqualColumnAVX512(const double* left, + const double* right, + size_t size, + bool* res); + +template +void +LessColumnAVX512(const T* left, const T* right, size_t size, bool* res) { + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); + CompareColumnAVX512Impl::Compare( + left, right, size, res); +}; +template void +LessColumnAVX512(const int8_t* left, + const int8_t* right, + size_t size, + bool* res); +template void +LessColumnAVX512(const int16_t* left, + const int16_t* right, + size_t size, + bool* res); 
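The column kernels above rely on the same mask-expansion idiom as the value kernels: an AVX-512 compare produces a k-mask with one bit per lane, and a zero-masked byte broadcast turns that mask into contiguous bool bytes. A minimal standalone sketch of the float case follows; the function name less_than_16_floats and the _CMP_LT_OQ predicate choice are illustrative rather than taken from the patch, and it assumes a build with AVX512F, AVX512BW and AVX512VL enabled (the same features the updated cpu_support_avx512 probes, along with DQ).

#include <immintrin.h>
#include <cstddef>

// Compare 16 floats against a broadcast value and emit one bool per lane.
// _mm512_cmp_ps_mask yields a 16-bit k-mask; _mm_maskz_set1_epi8 expands
// each set bit to a 0x01 byte (0x00 otherwise), which matches
// sizeof(bool) == 1, so all 16 results land in one unaligned 128-bit store.
void less_than_16_floats(const float* src, float val, bool* res) {
    __m512 target = _mm512_set1_ps(val);
    __m512 data = _mm512_loadu_ps(src);
    __mmask16 mask = _mm512_cmp_ps_mask(data, target, _CMP_LT_OQ);
    __m128i bytes = _mm_maskz_set1_epi8(mask, 0x01);
    _mm_storeu_si128(reinterpret_cast<__m128i*>(res), bytes);
}

The integer specializations only vary the lane count (64, 32, 16 or 8 per register) and therefore the width of the mask and of the final store; the scalar tail loop handles sizes that are not a multiple of the lane count.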
+template void +LessColumnAVX512(const int32_t* left, + const int32_t* right, + size_t size, + bool* res); +template void +LessColumnAVX512(const int64_t* left, + const int64_t* right, + size_t size, + bool* res); +template void +LessColumnAVX512(const float* left, const float* right, size_t size, bool* res); +template void +LessColumnAVX512(const double* left, + const double* right, + size_t size, + bool* res); + +template +void +GreaterColumnAVX512(const T* left, const T* right, size_t size, bool* res) { + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); + CompareColumnAVX512Impl::Compare( + left, right, size, res); +}; +template void +GreaterColumnAVX512(const int8_t* left, + const int8_t* right, + size_t size, + bool* res); +template void +GreaterColumnAVX512(const int16_t* left, + const int16_t* right, + size_t size, + bool* res); +template void +GreaterColumnAVX512(const int32_t* left, + const int32_t* right, + size_t size, + bool* res); +template void +GreaterColumnAVX512(const int64_t* left, + const int64_t* right, + size_t size, + bool* res); +template void +GreaterColumnAVX512(const float* left, + const float* right, + size_t size, + bool* res); +template void +GreaterColumnAVX512(const double* left, + const double* right, + size_t size, + bool* res); + +template +void +LessEqualColumnAVX512(const T* left, const T* right, size_t size, bool* res) { + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); + CompareColumnAVX512Impl::Compare( + left, right, size, res); +}; +template void +LessEqualColumnAVX512(const int8_t* left, + const int8_t* right, + size_t size, + bool* res); +template void +LessEqualColumnAVX512(const int16_t* left, + const int16_t* right, + size_t size, + bool* res); +template void +LessEqualColumnAVX512(const int32_t* left, + const int32_t* right, + size_t size, + bool* res); +template void +LessEqualColumnAVX512(const int64_t* left, + const int64_t* right, + size_t size, + bool* res); +template void +LessEqualColumnAVX512(const float* left, + const float* right, + size_t size, + bool* res); +template void +LessEqualColumnAVX512(const double* left, + const double* right, + size_t size, + bool* res); + +template +void +GreaterEqualColumnAVX512(const T* left, + const T* right, + size_t size, + bool* res) { + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); + CompareColumnAVX512Impl::Compare( + left, right, size, res); +}; +template void +GreaterEqualColumnAVX512(const int8_t* left, + const int8_t* right, + size_t size, + bool* res); +template void +GreaterEqualColumnAVX512(const int16_t* left, + const int16_t* right, + size_t size, + bool* res); +template void +GreaterEqualColumnAVX512(const int32_t* left, + const int32_t* right, + size_t size, + bool* res); +template void +GreaterEqualColumnAVX512(const int64_t* left, + const int64_t* right, + size_t size, + bool* res); +template void +GreaterEqualColumnAVX512(const float* left, + const float* right, + size_t size, + bool* res); +template void +GreaterEqualColumnAVX512(const double* left, + const double* right, + size_t size, + bool* res); + +template +void +NotEqualColumnAVX512(const T* left, const T* right, size_t size, bool* res) { + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); + CompareColumnAVX512Impl::Compare( + left, right, size, res); +}; + +template void 
+NotEqualColumnAVX512(const int8_t* left, + const int8_t* right, + size_t size, + bool* res); +template void +NotEqualColumnAVX512(const int16_t* left, + const int16_t* right, + size_t size, + bool* res); +template void +NotEqualColumnAVX512(const int32_t* left, + const int32_t* right, + size_t size, + bool* res); +template void +NotEqualColumnAVX512(const int64_t* left, + const int64_t* right, + size_t size, + bool* res); +template void +NotEqualColumnAVX512(const float* left, + const float* right, + size_t size, + bool* res); +template void +NotEqualColumnAVX512(const double* left, + const double* right, + size_t size, + bool* res); + } // namespace simd } // namespace milvus #endif diff --git a/internal/core/src/simd/avx512.h b/internal/core/src/simd/avx512.h index fe24b00bb69bf..9b5c549d3d429 100644 --- a/internal/core/src/simd/avx512.h +++ b/internal/core/src/simd/avx512.h @@ -61,5 +61,53 @@ AndBoolAVX512(bool* left, bool* right, int64_t size); void OrBoolAVX512(bool* left, bool* right, int64_t size); +template +void +EqualValAVX512(const T* src, size_t size, T val, bool* res); + +template +void +LessValAVX512(const T* src, size_t size, T val, bool* res); + +template +void +GreaterValAVX512(const T* src, size_t size, T val, bool* res); + +template +void +NotEqualValAVX512(const T* src, size_t size, T val, bool* res); + +template +void +LessEqualValAVX512(const T* src, size_t size, T val, bool* res); + +template +void +GreaterEqualValAVX512(const T* src, size_t size, T val, bool* res); + +template +void +EqualColumnAVX512(const T* left, const T* right, size_t size, bool* res); + +template +void +LessColumnAVX512(const T* left, const T* right, size_t size, bool* res); + +template +void +LessEqualColumnAVX512(const T* left, const T* right, size_t size, bool* res); + +template +void +GreaterColumnAVX512(const T* left, const T* right, size_t size, bool* res); + +template +void +GreaterEqualColumnAVX512(const T* left, const T* right, size_t size, bool* res); + +template +void +NotEqualColumnAVX512(const T* left, const T* right, size_t size, bool* res); + } // namespace simd } // namespace milvus diff --git a/internal/core/src/simd/common.h b/internal/core/src/simd/common.h index 3cbe9c6e3e76b..f6e0c9e3c630e 100644 --- a/internal/core/src/simd/common.h +++ b/internal/core/src/simd/common.h @@ -40,5 +40,14 @@ const int TERM_EXPR_IN_SIZE_THREAD = 50; std::is_same::value || std::is_same::value, \ Message); +enum class CompareType { + GT = 1, + GE = 2, + LT = 3, + LE = 4, + EQ = 5, + NEQ = 6, +}; + } // namespace simd } // namespace milvus diff --git a/internal/core/src/simd/hook.cpp b/internal/core/src/simd/hook.cpp index 2fe688b85724b..89b5b300671b6 100644 --- a/internal/core/src/simd/hook.cpp +++ b/internal/core/src/simd/hook.cpp @@ -32,19 +32,6 @@ namespace milvus { namespace simd { -#if defined(__x86_64__) -bool use_avx512 = true; -bool use_avx2 = true; -bool use_sse4_2 = true; -bool use_sse2 = true; - -bool use_bitset_sse2; -bool use_find_term_sse2; -bool use_find_term_sse4_2; -bool use_find_term_avx2; -bool use_find_term_avx512; -#endif - decltype(get_bitset_block) get_bitset_block = GetBitsetBlockRef; decltype(all_false) all_false = AllFalseRef; decltype(all_true) all_true = AllTrueRef; @@ -52,20 +39,124 @@ decltype(invert_bool) invert_bool = InvertBoolRef; decltype(and_bool) and_bool = AndBoolRef; decltype(or_bool) or_bool = OrBoolRef; -FindTermPtr find_term_bool = FindTermRef; -FindTermPtr find_term_int8 = FindTermRef; -FindTermPtr find_term_int16 = FindTermRef; -FindTermPtr 
find_term_int32 = FindTermRef; -FindTermPtr find_term_int64 = FindTermRef; -FindTermPtr find_term_float = FindTermRef; -FindTermPtr find_term_double = FindTermRef; +#define DECLARE_FIND_TERM_PTR(type) \ + FindTermPtr find_term_##type = FindTermRef; +DECLARE_FIND_TERM_PTR(bool) +DECLARE_FIND_TERM_PTR(int8_t) +DECLARE_FIND_TERM_PTR(int16_t) +DECLARE_FIND_TERM_PTR(int32_t) +DECLARE_FIND_TERM_PTR(int64_t) +DECLARE_FIND_TERM_PTR(float) +DECLARE_FIND_TERM_PTR(double) + +#define DECLARE_COMPARE_VAL_PTR(prefix, RefFunc, type) \ + CompareValPtr prefix##_##type = RefFunc; + +DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, bool) +DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, int8_t) +DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, int16_t) +DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, int32_t) +DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, int64_t) +DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, float) +DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, double) + +DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, bool) +DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, int8_t) +DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, int16_t) +DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, int32_t) +DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, int64_t) +DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, float) +DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, double) + +DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, bool) +DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, int8_t) +DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, int16_t) +DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, int32_t) +DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, int64_t) +DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, float) +DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, double) + +DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, bool) +DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, int8_t) +DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, int16_t) +DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, int32_t) +DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, int64_t) +DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, float) +DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, double) + +DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, bool) +DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, int8_t) +DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, int16_t) +DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, int32_t) +DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, int64_t) +DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, float) +DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, double) + +DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, bool) +DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, int8_t) +DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, int16_t) +DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, int32_t) +DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, int64_t) +DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, float) +DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, double) + +#define DECLARE_COMPARE_COL_PTR(prefix, RefFunc, type) \ + CompareColPtr prefix##_##type = RefFunc; + +DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, bool) +DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, int8_t) +DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, int16_t) +DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, int32_t) 
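Each DECLARE_* macro above stamps out one global function pointer per (operation, element type) pair, initialised to the portable Ref implementation so every entry point is usable even when no SIMD hook fires. A toy reduction of the pattern, small enough to read in one piece:

#include <cstddef>

template <typename T>
using CompareValPtr = void (*)(const T* src, size_t size, T val, bool* res);

// Portable scalar fallback; correct on any CPU.
template <typename T>
void EqualValRef(const T* src, size_t size, T val, bool* res) {
    for (size_t i = 0; i < size; ++i) {
        res[i] = (src[i] == val);
    }
}

// One mutable global per type, defaulting to the scalar loop; a runtime
// hook may later overwrite it with an AVX-512 kernel.
#define DECLARE_COMPARE_VAL_PTR(prefix, RefFunc, type) \
    CompareValPtr<type> prefix##_##type = RefFunc<type>;

DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, float)
DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, double)

With this in place a call such as equal_val_float(data, n, 3.0f, out) always resolves through the pointer, so installing a faster kernel later changes behaviour without touching any call site.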
+DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, int64_t) +DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, float) +DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, double) + +DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, bool) +DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, int8_t) +DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, int16_t) +DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, int32_t) +DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, int64_t) +DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, float) +DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, double) + +DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, bool) +DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, int8_t) +DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, int16_t) +DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, int32_t) +DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, int64_t) +DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, float) +DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, double) + +DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, bool) +DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, int8_t) +DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, int16_t) +DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, int32_t) +DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, int64_t) +DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, float) +DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, double) + +DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, bool) +DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, int8_t) +DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, int16_t) +DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, int32_t) +DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, int64_t) +DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, float) +DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, double) + +DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, bool) +DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, int8_t) +DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, int16_t) +DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, int32_t) +DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, int64_t) +DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, float) +DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, double) #if defined(__x86_64__) bool cpu_support_avx512() { InstructionSet& instruction_set_inst = InstructionSet::GetInstance(); return (instruction_set_inst.AVX512F() && instruction_set_inst.AVX512DQ() && - instruction_set_inst.AVX512BW()); + instruction_set_inst.AVX512BW() && instruction_set_inst.AVX512VL()); } bool @@ -87,95 +178,77 @@ cpu_support_sse2() { } #endif -void +static void bitset_hook() { static std::mutex hook_mutex; std::lock_guard lock(hook_mutex); std::string simd_type = "REF"; #if defined(__x86_64__) - if (use_avx512 && cpu_support_avx512()) { - simd_type = "AVX512"; - // For now, sse2 has best performance - get_bitset_block = GetBitsetBlockSSE2; - use_bitset_sse2 = true; - } else if (use_avx2 && cpu_support_avx2()) { - simd_type = "AVX2"; - // For now, sse2 has best performance - get_bitset_block = GetBitsetBlockSSE2; - use_bitset_sse2 = true; - } else if (use_sse4_2 && cpu_support_sse4_2()) { - simd_type = "SSE4"; - get_bitset_block = GetBitsetBlockSSE2; - use_bitset_sse2 = true; - } else if (use_sse2 && 
cpu_support_sse2()) { + // SSE2 have best performance in test. + if (cpu_support_sse2()) { simd_type = "SSE2"; get_bitset_block = GetBitsetBlockSSE2; - use_bitset_sse2 = true; } #endif // TODO: support arm cpu LOG_INFO("bitset hook simd type: {}", simd_type); } -void +static void find_term_hook() { static std::mutex hook_mutex; std::lock_guard lock(hook_mutex); std::string simd_type = "REF"; #if defined(__x86_64__) - if (use_avx512 && cpu_support_avx512()) { + if (cpu_support_avx512()) { simd_type = "AVX512"; find_term_bool = FindTermAVX512; - find_term_int8 = FindTermAVX512; - find_term_int16 = FindTermAVX512; - find_term_int32 = FindTermAVX512; - find_term_int64 = FindTermAVX512; + find_term_int8_t = FindTermAVX512; + find_term_int16_t = FindTermAVX512; + find_term_int32_t = FindTermAVX512; + find_term_int64_t = FindTermAVX512; find_term_float = FindTermAVX512; find_term_double = FindTermAVX512; - use_find_term_avx512 = true; - } else if (use_avx2 && cpu_support_avx2()) { + } else if (cpu_support_avx2()) { simd_type = "AVX2"; find_term_bool = FindTermAVX2; - find_term_int8 = FindTermAVX2; - find_term_int16 = FindTermAVX2; - find_term_int32 = FindTermAVX2; - find_term_int64 = FindTermAVX2; + find_term_int8_t = FindTermAVX2; + find_term_int16_t = FindTermAVX2; + find_term_int32_t = FindTermAVX2; + find_term_int64_t = FindTermAVX2; find_term_float = FindTermAVX2; find_term_double = FindTermAVX2; - use_find_term_avx2 = true; - } else if (use_sse4_2 && cpu_support_sse4_2()) { + } else if (cpu_support_sse4_2()) { simd_type = "SSE4"; find_term_bool = FindTermSSE4; - find_term_int8 = FindTermSSE4; - find_term_int16 = FindTermSSE4; - find_term_int32 = FindTermSSE4; - find_term_int64 = FindTermSSE4; + find_term_int8_t = FindTermSSE4; + find_term_int16_t = FindTermSSE4; + find_term_int32_t = FindTermSSE4; + find_term_int64_t = FindTermSSE4; find_term_float = FindTermSSE4; find_term_double = FindTermSSE4; - use_find_term_sse4_2 = true; - } else if (use_sse2 && cpu_support_sse2()) { + } else if (cpu_support_sse2()) { simd_type = "SSE2"; find_term_bool = FindTermSSE2; - find_term_int8 = FindTermSSE2; - find_term_int16 = FindTermSSE2; - find_term_int32 = FindTermSSE2; - find_term_int64 = FindTermSSE2; + find_term_int8_t = FindTermSSE2; + find_term_int16_t = FindTermSSE2; + find_term_int32_t = FindTermSSE2; + find_term_int64_t = FindTermSSE2; find_term_float = FindTermSSE2; find_term_double = FindTermSSE2; - use_find_term_sse2 = true; } #endif // TODO: support arm cpu LOG_INFO("find term hook simd type: {}", simd_type); } -void +static void all_boolean_hook() { static std::mutex hook_mutex; std::lock_guard lock(hook_mutex); std::string simd_type = "REF"; #if defined(__x86_64__) - if (use_sse2 && cpu_support_sse2()) { + if (cpu_support_sse2()) { simd_type = "SSE2"; all_false = AllFalseSSE2; all_true = AllTrueSSE2; @@ -189,13 +262,13 @@ all_boolean_hook() { LOG_INFO("AllFalse/AllTrue hook simd type: {}", simd_type); } -void +static void invert_boolean_hook() { static std::mutex hook_mutex; std::lock_guard lock(hook_mutex); std::string simd_type = "REF"; #if defined(__x86_64__) - if (use_sse2 && cpu_support_sse2()) { + if (cpu_support_sse2()) { simd_type = "SSE2"; invert_bool = InvertBoolSSE2; } @@ -207,21 +280,21 @@ invert_boolean_hook() { LOG_INFO("InvertBoolean hook simd type: {}", simd_type); } -void +static void logical_boolean_hook() { static std::mutex hook_mutex; std::lock_guard lock(hook_mutex); std::string simd_type = "REF"; #if defined(__x86_64__) - if (use_avx512 && cpu_support_avx512()) { + if 
(cpu_support_avx512()) { simd_type = "AVX512"; and_bool = AndBoolAVX512; or_bool = OrBoolAVX512; - } else if (use_avx2 && cpu_support_avx2()) { + } else if (cpu_support_avx2()) { simd_type = "AVX2"; and_bool = AndBoolAVX2; or_bool = OrBoolAVX2; - } else if (use_sse2 && cpu_support_sse2()) { + } else if (cpu_support_sse2()) { simd_type = "SSE2"; and_bool = AndBoolSSE2; or_bool = OrBoolSSE2; @@ -234,17 +307,287 @@ logical_boolean_hook() { // TODO: support arm cpu LOG_INFO("InvertBoolean hook simd type: {}", simd_type); } -void + +static void boolean_hook() { all_boolean_hook(); invert_boolean_hook(); logical_boolean_hook(); } +static void +equal_val_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + // Only support avx512 for now + if (cpu_support_avx512()) { + simd_type = "AVX512"; + equal_val_int8_t = EqualValAVX512; + equal_val_int16_t = EqualValAVX512; + equal_val_int32_t = EqualValAVX512; + equal_val_int64_t = EqualValAVX512; + equal_val_float = EqualValAVX512; + equal_val_double = EqualValAVX512; + } +#endif + // TODO: support arm cpu + LOG_INFO("equal val hook simd type: {} ", simd_type); +} + +static void +less_val_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + // Only support avx512 for now + if (cpu_support_avx512()) { + simd_type = "AVX512"; + less_val_int8_t = LessValAVX512; + less_val_int16_t = LessValAVX512; + less_val_int32_t = LessValAVX512; + less_val_int64_t = LessValAVX512; + less_val_float = LessValAVX512; + less_val_double = LessValAVX512; + } +#endif + // TODO: support arm cpu + LOG_INFO("less than val hook simd type:{} ", simd_type); +} + +static void +greater_val_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + // Only support avx512 for now + if (cpu_support_avx512()) { + simd_type = "AVX512"; + greater_val_int8_t = GreaterValAVX512; + greater_val_int16_t = GreaterValAVX512; + greater_val_int32_t = GreaterValAVX512; + greater_val_int64_t = GreaterValAVX512; + greater_val_float = GreaterValAVX512; + greater_val_double = GreaterValAVX512; + } +#endif + // TODO: support arm cpu + LOG_INFO("greater than val hook simd type: {} ", simd_type); +} + +static void +less_equal_val_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + // Only support avx512 for now + if (cpu_support_avx512()) { + simd_type = "AVX512"; + less_equal_val_int8_t = LessEqualValAVX512; + less_equal_val_int16_t = LessEqualValAVX512; + less_equal_val_int32_t = LessEqualValAVX512; + less_equal_val_int64_t = LessEqualValAVX512; + less_equal_val_float = LessEqualValAVX512; + less_equal_val_double = LessEqualValAVX512; + } +#endif + // TODO: support arm cpu + LOG_INFO("less equal than val hook simd type: {} ", simd_type); +} + +static void +greater_equal_val_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + // Only support avx512 for now + if (cpu_support_avx512()) { + simd_type = "AVX512"; + greater_equal_val_int8_t = GreaterEqualValAVX512; + greater_equal_val_int16_t = GreaterEqualValAVX512; + greater_equal_val_int32_t = GreaterEqualValAVX512; + greater_equal_val_int64_t = GreaterEqualValAVX512; + greater_equal_val_float = GreaterEqualValAVX512; + greater_equal_val_double = 
GreaterEqualValAVX512; + } +#endif + // TODO: support arm cpu + LOG_INFO("greater equal than val hook simd type: {} ", simd_type); +} + +static void +not_equal_val_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + // Only support avx512 for now + if (cpu_support_avx512()) { + simd_type = "AVX512"; + not_equal_val_int8_t = NotEqualValAVX512; + not_equal_val_int16_t = NotEqualValAVX512; + not_equal_val_int32_t = NotEqualValAVX512; + not_equal_val_int64_t = NotEqualValAVX512; + not_equal_val_float = NotEqualValAVX512; + not_equal_val_double = NotEqualValAVX512; + } +#endif + // TODO: support arm cpu + LOG_INFO("not equal val hook simd type: {}", simd_type); +} + +static void +equal_col_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + // Only support avx512 for now + if (cpu_support_avx512()) { + simd_type = "AVX512"; + equal_col_int8_t = EqualColumnAVX512; + equal_col_int16_t = EqualColumnAVX512; + equal_col_int32_t = EqualColumnAVX512; + equal_col_int64_t = EqualColumnAVX512; + equal_col_float = EqualColumnAVX512; + equal_col_double = EqualColumnAVX512; + } +#endif + // TODO: support arm cpu + LOG_INFO("equal column hook simd type:{} ", simd_type); +} + +static void +less_col_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + // Only support avx512 for now + if (cpu_support_avx512()) { + simd_type = "AVX512"; + less_col_int8_t = LessColumnAVX512; + less_col_int16_t = LessColumnAVX512; + less_col_int32_t = LessColumnAVX512; + less_col_int64_t = LessColumnAVX512; + less_col_float = LessColumnAVX512; + less_col_double = LessColumnAVX512; + } +#endif + // TODO: support arm cpu + LOG_INFO("less than column hook simd type:{} ", simd_type); +} + +static void +greater_col_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + // Only support avx512 for now + if (cpu_support_avx512()) { + simd_type = "AVX512"; + greater_col_int8_t = GreaterColumnAVX512; + greater_col_int16_t = GreaterColumnAVX512; + greater_col_int32_t = GreaterColumnAVX512; + greater_col_int64_t = GreaterColumnAVX512; + greater_col_float = GreaterColumnAVX512; + greater_col_double = GreaterColumnAVX512; + } +#endif + // TODO: support arm cpu + LOG_INFO("greater than column hook simd type:{} ", simd_type); +} + +static void +less_equal_col_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + // Only support avx512 for now + if (cpu_support_avx512()) { + simd_type = "AVX512"; + less_equal_col_int8_t = LessEqualColumnAVX512; + less_equal_col_int16_t = LessEqualColumnAVX512; + less_equal_col_int32_t = LessEqualColumnAVX512; + less_equal_col_int64_t = LessEqualColumnAVX512; + less_equal_col_float = LessEqualColumnAVX512; + less_equal_col_double = LessEqualColumnAVX512; + } +#endif + // TODO: support arm cpu + LOG_INFO("less equal than column hook simd type: {}", simd_type); +} + +static void +greater_equal_col_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + // Only support avx512 for now + if (cpu_support_avx512()) { + simd_type = "AVX512"; + greater_equal_col_int8_t = GreaterEqualColumnAVX512; + greater_equal_col_int16_t = 
GreaterEqualColumnAVX512; + greater_equal_col_int32_t = GreaterEqualColumnAVX512; + greater_equal_col_int64_t = GreaterEqualColumnAVX512; + greater_equal_col_float = GreaterEqualColumnAVX512; + greater_equal_col_double = GreaterEqualColumnAVX512; + } +#endif + // TODO: support arm cpu + LOG_INFO("greater equal than column hook simd type:{} ", simd_type); +} + +static void +not_equal_col_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + // Only support avx512 for now + if (cpu_support_avx512()) { + simd_type = "AVX512"; + not_equal_col_int8_t = NotEqualColumnAVX512; + not_equal_col_int16_t = NotEqualColumnAVX512; + not_equal_col_int32_t = NotEqualColumnAVX512; + not_equal_col_int64_t = NotEqualColumnAVX512; + not_equal_col_float = NotEqualColumnAVX512; + not_equal_col_double = NotEqualColumnAVX512; + } +#endif + // TODO: support arm cpu + LOG_INFO("not equal column hook simd type: {}", simd_type); +} + +static void +compare_hook() { + equal_val_hook(); + less_val_hook(); + greater_val_hook(); + less_equal_val_hook(); + greater_equal_val_hook(); + not_equal_val_hook(); + equal_col_hook(); + less_col_hook(); + greater_col_hook(); + less_equal_col_hook(); + greater_equal_col_hook(); + not_equal_col_hook(); +} + static int init_hook_ = []() { bitset_hook(); - find_term_hook(); boolean_hook(); + find_term_hook(); + compare_hook(); return 0; }(); diff --git a/internal/core/src/simd/hook.h b/internal/core/src/simd/hook.h index 98e82853ae21d..2ffbbd81442d7 100644 --- a/internal/core/src/simd/hook.h +++ b/internal/core/src/simd/hook.h @@ -18,41 +18,6 @@ namespace milvus { namespace simd { -extern BitsetBlockType (*get_bitset_block)(const bool* src); -extern bool (*all_false)(const bool* src, int64_t size); -extern bool (*all_true)(const bool* src, int64_t size); -extern void (*invert_bool)(bool* src, int64_t size); -extern void (*and_bool)(bool* left, bool* right, int64_t size); -extern void (*or_bool)(bool* left, bool* right, int64_t size); - -template -using FindTermPtr = bool (*)(const T* src, size_t size, T val); - -extern FindTermPtr find_term_bool; -extern FindTermPtr find_term_int8; -extern FindTermPtr find_term_int16; -extern FindTermPtr find_term_int32; -extern FindTermPtr find_term_int64; -extern FindTermPtr find_term_float; -extern FindTermPtr find_term_double; - -#if defined(__x86_64__) -// Flags that indicate whether runtime can choose -// these simd type or not when hook starts. -extern bool use_avx512; -extern bool use_avx2; -extern bool use_sse4_2; -extern bool use_sse2; - -// Flags that indicate which kind of simd for -// different function when hook ends. 
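The extern use_* flags being deleted here were the old mechanism for forcing a code path at runtime; after this patch each *_hook() consults the CPUID probes directly, and everything is wired up once, before main, through the static initializer init_hook_ shown above. A condensed, self-contained sketch of that self-registration pattern, with hypothetical kernel names standing in for the real tables:

#include <mutex>

using Kernel = void (*)();

void RefKernel() { /* portable scalar fallback */ }
void Avx512Kernel() { /* vectorized fast path */ }

// Stand-in for the real probe; the patch requires AVX512F, DQ, BW and
// (newly) VL before the AVX-512 kernels are installed.
bool cpu_support_avx512() { return false; }

Kernel active_kernel = RefKernel;  // valid default before any hook runs

static void compare_hook() {
    static std::mutex hook_mutex;
    std::lock_guard<std::mutex> lock(hook_mutex);
    if (cpu_support_avx512()) {
        active_kernel = Avx512Kernel;
    }
}

// Evaluated during static initialization, so by the time main() starts,
// active_kernel already points at the best kernel for this CPU.
static int init_hook_ = []() {
    compare_hook();
    return 0;
}();

Because another translation unit's static initializer could call a kernel before init_hook_ runs, defaulting every pointer to the Ref implementation keeps that ordering hazard harmless: the worst case is a slower, but still correct, first call.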
-extern bool use_bitset_sse2; -extern bool use_find_term_sse2; -extern bool use_find_term_sse4_2; -extern bool use_find_term_avx2; -extern bool use_find_term_avx512; -#endif - #if defined(__x86_64__) bool cpu_support_avx512(); @@ -62,53 +27,135 @@ bool cpu_support_sse4_2(); #endif -void -bitset_hook(); - -void -find_term_hook(); - -void -boolean_hook(); - -void -all_boolean_hook(); - -void -invert_boolean_hook(); +extern BitsetBlockType (*get_bitset_block)(const bool* src); +extern bool (*all_false)(const bool* src, int64_t size); +extern bool (*all_true)(const bool* src, int64_t size); +extern void (*invert_bool)(bool* src, int64_t size); +extern void (*and_bool)(bool* left, bool* right, int64_t size); +extern void (*or_bool)(bool* left, bool* right, int64_t size); -void -logical_boolean_hook(); +template +using FindTermPtr = bool (*)(const T* src, size_t size, T val); +#define EXTERN_FIND_TERM_PTR(type) extern FindTermPtr find_term_##type; + +EXTERN_FIND_TERM_PTR(bool) +EXTERN_FIND_TERM_PTR(int8_t) +EXTERN_FIND_TERM_PTR(int16_t) +EXTERN_FIND_TERM_PTR(int32_t) +EXTERN_FIND_TERM_PTR(int64_t) +EXTERN_FIND_TERM_PTR(float) +EXTERN_FIND_TERM_PTR(double) + +// Compare val function register +// Such as A == 10, A < 10... +template +using CompareValPtr = void (*)(const T* src, size_t size, T val, bool* res); +#define EXTERN_COMPARE_VAL_PTR(prefix, type) \ + extern CompareValPtr prefix##_##type; +// Compare column function register +// Such as A == B, A < B... template -bool -find_term_func(const T* data, size_t size, T val) { - static_assert( - std::is_integral::value || std::is_floating_point::value, - "T must be integral or float/double type"); - - if constexpr (std::is_same_v) { - return milvus::simd::find_term_bool(data, size, val); - } - if constexpr (std::is_same_v) { - return milvus::simd::find_term_int8(data, size, val); - } - if constexpr (std::is_same_v) { - return milvus::simd::find_term_int16(data, size, val); - } - if constexpr (std::is_same_v) { - return milvus::simd::find_term_int32(data, size, val); - } - if constexpr (std::is_same_v) { - return milvus::simd::find_term_int64(data, size, val); - } - if constexpr (std::is_same_v) { - return milvus::simd::find_term_float(data, size, val); - } - if constexpr (std::is_same_v) { - return milvus::simd::find_term_double(data, size, val); - } -} +using CompareColPtr = + void (*)(const T* left, const T* right, size_t size, bool* res); +#define EXTERN_COMPARE_COL_PTR(prefix, type) \ + extern CompareColPtr prefix##_##type; + +EXTERN_COMPARE_VAL_PTR(equal_val, bool) +EXTERN_COMPARE_VAL_PTR(equal_val, int8_t) +EXTERN_COMPARE_VAL_PTR(equal_val, int16_t) +EXTERN_COMPARE_VAL_PTR(equal_val, int32_t) +EXTERN_COMPARE_VAL_PTR(equal_val, int64_t) +EXTERN_COMPARE_VAL_PTR(equal_val, float) +EXTERN_COMPARE_VAL_PTR(equal_val, double) + +EXTERN_COMPARE_VAL_PTR(less_val, bool) +EXTERN_COMPARE_VAL_PTR(less_val, int8_t) +EXTERN_COMPARE_VAL_PTR(less_val, int16_t) +EXTERN_COMPARE_VAL_PTR(less_val, int32_t) +EXTERN_COMPARE_VAL_PTR(less_val, int64_t) +EXTERN_COMPARE_VAL_PTR(less_val, float) +EXTERN_COMPARE_VAL_PTR(less_val, double) + +EXTERN_COMPARE_VAL_PTR(greater_val, bool) +EXTERN_COMPARE_VAL_PTR(greater_val, int8_t) +EXTERN_COMPARE_VAL_PTR(greater_val, int16_t) +EXTERN_COMPARE_VAL_PTR(greater_val, int32_t) +EXTERN_COMPARE_VAL_PTR(greater_val, int64_t) +EXTERN_COMPARE_VAL_PTR(greater_val, float) +EXTERN_COMPARE_VAL_PTR(greater_val, double) + +EXTERN_COMPARE_VAL_PTR(less_equal_val, bool) +EXTERN_COMPARE_VAL_PTR(less_equal_val, int8_t) 
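hook.h now only declares the typed pointers; the templated front-ends move to the new interface.h further below, where an if constexpr chain selects the matching pointer at compile time. The mechanism in isolation, using scalar stand-in bodies for the two pointers (the names mirror the patch, the bodies do not):

#include <cstddef>
#include <type_traits>

void equal_val_float(const float* src, size_t size, float val, bool* res) {
    for (size_t i = 0; i < size; ++i) res[i] = (src[i] == val);
}
void equal_val_double(const double* src, size_t size, double val, bool* res) {
    for (size_t i = 0; i < size; ++i) res[i] = (src[i] == val);
}

// if constexpr discards every non-matching branch during instantiation, so
// the generic front-end compiles down to a direct call on the one pointer
// for T, with no runtime type test.
template <typename T>
void equal_val_func(const T* data, size_t size, T val, bool* res) {
    static_assert(std::is_floating_point_v<T>,
                  "sketch covers float and double only");
    if constexpr (std::is_same_v<T, float>) {
        equal_val_float(data, size, val, res);
    } else {
        equal_val_double(data, size, val, res);
    }
}

The real dispatchers extend the same chain over bool and the four integer widths via the DISPATCH_* macros in interface.h.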
+EXTERN_COMPARE_VAL_PTR(less_equal_val, int16_t) +EXTERN_COMPARE_VAL_PTR(less_equal_val, int32_t) +EXTERN_COMPARE_VAL_PTR(less_equal_val, int64_t) +EXTERN_COMPARE_VAL_PTR(less_equal_val, float) +EXTERN_COMPARE_VAL_PTR(less_equal_val, double) + +EXTERN_COMPARE_VAL_PTR(greater_equal_val, bool) +EXTERN_COMPARE_VAL_PTR(greater_equal_val, int8_t) +EXTERN_COMPARE_VAL_PTR(greater_equal_val, int16_t) +EXTERN_COMPARE_VAL_PTR(greater_equal_val, int32_t) +EXTERN_COMPARE_VAL_PTR(greater_equal_val, int64_t) +EXTERN_COMPARE_VAL_PTR(greater_equal_val, float) +EXTERN_COMPARE_VAL_PTR(greater_equal_val, double) + +EXTERN_COMPARE_VAL_PTR(not_equal_val, bool) +EXTERN_COMPARE_VAL_PTR(not_equal_val, int8_t) +EXTERN_COMPARE_VAL_PTR(not_equal_val, int16_t) +EXTERN_COMPARE_VAL_PTR(not_equal_val, int32_t) +EXTERN_COMPARE_VAL_PTR(not_equal_val, int64_t) +EXTERN_COMPARE_VAL_PTR(not_equal_val, float) +EXTERN_COMPARE_VAL_PTR(not_equal_val, double) + +EXTERN_COMPARE_COL_PTR(equal_col, bool) +EXTERN_COMPARE_COL_PTR(equal_col, int8_t) +EXTERN_COMPARE_COL_PTR(equal_col, int16_t) +EXTERN_COMPARE_COL_PTR(equal_col, int32_t) +EXTERN_COMPARE_COL_PTR(equal_col, int64_t) +EXTERN_COMPARE_COL_PTR(equal_col, float) +EXTERN_COMPARE_COL_PTR(equal_col, double) + +EXTERN_COMPARE_COL_PTR(less_col, bool) +EXTERN_COMPARE_COL_PTR(less_col, int8_t) +EXTERN_COMPARE_COL_PTR(less_col, int16_t) +EXTERN_COMPARE_COL_PTR(less_col, int32_t) +EXTERN_COMPARE_COL_PTR(less_col, int64_t) +EXTERN_COMPARE_COL_PTR(less_col, float) +EXTERN_COMPARE_COL_PTR(less_col, double) + +EXTERN_COMPARE_COL_PTR(greater_col, bool) +EXTERN_COMPARE_COL_PTR(greater_col, int8_t) +EXTERN_COMPARE_COL_PTR(greater_col, int16_t) +EXTERN_COMPARE_COL_PTR(greater_col, int32_t) +EXTERN_COMPARE_COL_PTR(greater_col, int64_t) +EXTERN_COMPARE_COL_PTR(greater_col, float) +EXTERN_COMPARE_COL_PTR(greater_col, double) + +EXTERN_COMPARE_COL_PTR(less_equal_col, bool) +EXTERN_COMPARE_COL_PTR(less_equal_col, int8_t) +EXTERN_COMPARE_COL_PTR(less_equal_col, int16_t) +EXTERN_COMPARE_COL_PTR(less_equal_col, int32_t) +EXTERN_COMPARE_COL_PTR(less_equal_col, int64_t) +EXTERN_COMPARE_COL_PTR(less_equal_col, float) +EXTERN_COMPARE_COL_PTR(less_equal_col, double) + +EXTERN_COMPARE_COL_PTR(greater_equal_col, bool) +EXTERN_COMPARE_COL_PTR(greater_equal_col, int8_t) +EXTERN_COMPARE_COL_PTR(greater_equal_col, int16_t) +EXTERN_COMPARE_COL_PTR(greater_equal_col, int32_t) +EXTERN_COMPARE_COL_PTR(greater_equal_col, int64_t) +EXTERN_COMPARE_COL_PTR(greater_equal_col, float) +EXTERN_COMPARE_COL_PTR(greater_equal_col, double) + +EXTERN_COMPARE_COL_PTR(not_equal_col, bool) +EXTERN_COMPARE_COL_PTR(not_equal_col, int8_t) +EXTERN_COMPARE_COL_PTR(not_equal_col, int16_t) +EXTERN_COMPARE_COL_PTR(not_equal_col, int32_t) +EXTERN_COMPARE_COL_PTR(not_equal_col, int64_t) +EXTERN_COMPARE_COL_PTR(not_equal_col, float) +EXTERN_COMPARE_COL_PTR(not_equal_col, double) } // namespace simd } // namespace milvus diff --git a/internal/core/src/simd/interface.h b/internal/core/src/simd/interface.h new file mode 100644 index 0000000000000..e93a5c31dc94c --- /dev/null +++ b/internal/core/src/simd/interface.h @@ -0,0 +1,264 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include "hook.h" +namespace milvus { +namespace simd { + +#define DISPATCH_FIND_TERM_SIMD_FUNC(type) \ + if constexpr (std::is_same_v) { \ + return milvus::simd::find_term_##type(data, size, val); \ + } + +#define DISPATCH_COMPARE_VAL_SIMD_FUNC(prefix, type) \ + if constexpr (std::is_same_v) { \ + return milvus::simd::prefix##_##type(data, size, val, res); \ + } + +#define DISPATCH_COMPARE_COL_SIMD_FUNC(prefix, type) \ + if constexpr (std::is_same_v) { \ + return milvus::simd::prefix##_##type(left, right, size, res); \ + } + +template +bool +find_term_func(const T* data, size_t size, T val) { + static_assert( + std::is_integral::value || std::is_floating_point::value, + "T must be integral or float/double type"); + + DISPATCH_FIND_TERM_SIMD_FUNC(bool) + DISPATCH_FIND_TERM_SIMD_FUNC(int8_t) + DISPATCH_FIND_TERM_SIMD_FUNC(int16_t) + DISPATCH_FIND_TERM_SIMD_FUNC(int32_t) + DISPATCH_FIND_TERM_SIMD_FUNC(int64_t) + DISPATCH_FIND_TERM_SIMD_FUNC(float) + DISPATCH_FIND_TERM_SIMD_FUNC(double) +} + +template +void +equal_val_func(const T* data, int64_t size, T val, bool* res) { + static_assert( + std::is_integral::value || std::is_floating_point::value, + "T must be integral or float/double type"); + + DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, bool) + DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, int8_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, int16_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, int32_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, int64_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, float) + DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, double) +} + +template +void +less_val_func(const T* data, int64_t size, T val, bool* res) { + static_assert( + std::is_integral::value || std::is_floating_point::value, + "T must be integral or float/double type"); + + DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, bool) + DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, int8_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, int16_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, int32_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, int64_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, float) + DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, double) +} + +template +void +greater_val_func(const T* data, int64_t size, T val, bool* res) { + static_assert( + std::is_integral::value || std::is_floating_point::value, + "T must be integral or float/double type"); + + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, bool) + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, int8_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, int16_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, int32_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, int64_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, float) + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, double) +} + +template +void +less_equal_val_func(const T* data, int64_t size, T val, bool* res) { + static_assert( + std::is_integral::value || std::is_floating_point::value, + "T must be integral or float/double type"); + + DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, bool) + DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, int8_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, int16_t) + 
DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, int32_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, int64_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, float) + DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, double) +} + +template +void +greater_equal_val_func(const T* data, int64_t size, T val, bool* res) { + static_assert( + std::is_integral::value || std::is_floating_point::value, + "T must be integral or float/double type"); + + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, bool) + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, int8_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, int16_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, int32_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, int64_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, float) + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, double) +} + +template +void +not_equal_val_func(const T* data, int64_t size, T val, bool* res) { + static_assert( + std::is_integral::value || std::is_floating_point::value, + "T must be integral or float/double type"); + + DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, bool) + DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, int8_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, int16_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, int32_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, int64_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, float) + DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, double) +} + +template +void +equal_col_func(const T* left, const T* right, int64_t size, bool* res) { + static_assert( + std::is_integral::value || std::is_floating_point::value, + "T must be integral or float/double type"); + + DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, bool) + DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, int8_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, int16_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, int32_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, int64_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, float) + DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, double) +} + +template +void +less_col_func(const T* left, const T* right, int64_t size, bool* res) { + static_assert( + std::is_integral::value || std::is_floating_point::value, + "T must be integral or float/double type"); + + DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, bool) + DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, int8_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, int16_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, int32_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, int64_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, float) + DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, double) +} + +template +void +greater_col_func(const T* left, const T* right, int64_t size, bool* res) { + static_assert( + std::is_integral::value || std::is_floating_point::value, + "T must be integral or float/double type"); + + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, bool) + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, int8_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, int16_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, int32_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, int64_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, float) + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, double) +} + +template +void +less_equal_col_func(const T* left, const T* right, int64_t size, bool* res) { + static_assert( + std::is_integral::value || std::is_floating_point::value, + "T must be integral or float/double type"); + + DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, bool) + 
DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, int8_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, int16_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, int32_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, int64_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, float) + DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, double) +} + +template +void +greater_equal_col_func(const T* left, const T* right, int64_t size, bool* res) { + static_assert( + std::is_integral::value || std::is_floating_point::value, + "T must be integral or float/double type"); + + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, bool) + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, int8_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, int16_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, int32_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, int64_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, float) + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, double) +} + +template +void +not_equal_col_func(const T* left, const T* right, int64_t size, bool* res) { + static_assert( + std::is_integral::value || std::is_floating_point::value, + "T must be integral or float/double type"); + + DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, bool) + DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, int8_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, int16_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, int32_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, int64_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, float) + DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, double) +} + +template +void +compare_col_func(CompareType cmp_type, + const T* left, + const T* right, + int64_t size, + bool* res) { + if (cmp_type == CompareType::EQ) { + equal_col_func(left, right, size, res); + } else if (cmp_type == CompareType::NEQ) { + not_equal_col_func(left, right, size, res); + } else if (cmp_type == CompareType::GE) { + greater_equal_col_func(left, right, size, res); + } else if (cmp_type == CompareType::GT) { + greater_col_func(left, right, size, res); + } else if (cmp_type == CompareType::LE) { + less_equal_col_func(left, right, size, res); + } else if (cmp_type == CompareType::LT) { + less_col_func(left, right, size, res); + } +} + +} // namespace simd +} // namespace milvus diff --git a/internal/core/src/simd/ref.h b/internal/core/src/simd/ref.h index 6e90c7215a9a6..f3b7af1a0c621 100644 --- a/internal/core/src/simd/ref.h +++ b/internal/core/src/simd/ref.h @@ -45,5 +45,99 @@ FindTermRef(const T* src, size_t size, T val) { return false; } +template +void +EqualValRef(const T* src, size_t size, T val, bool* res) { + for (size_t i = 0; i < size; ++i) { + res[i] = src[i] == val; + } +} + +template +void +LessValRef(const T* src, size_t size, T val, bool* res) { + for (size_t i = 0; i < size; ++i) { + res[i] = src[i] < val; + } +} + +template +void +GreaterValRef(const T* src, size_t size, T val, bool* res) { + for (size_t i = 0; i < size; ++i) { + res[i] = src[i] > val; + } +} + +template +void +LessEqualValRef(const T* src, size_t size, T val, bool* res) { + for (size_t i = 0; i < size; ++i) { + res[i] = src[i] <= val; + } +} +template +void +GreaterEqualValRef(const T* src, size_t size, T val, bool* res) { + for (size_t i = 0; i < size; ++i) { + res[i] = src[i] >= val; + } +} +template +void +NotEqualValRef(const T* src, size_t size, T val, bool* res) { + for (size_t i = 0; i < size; ++i) { + res[i] = src[i] != val; + } +} + +template +void +EqualColumnRef(const T* left, const T* right, size_t 
size, bool* res) {
+    for (size_t i = 0; i < size; ++i) {
+        res[i] = left[i] == right[i];
+    }
+}
+
+template <typename T>
+void
+LessColumnRef(const T* left, const T* right, size_t size, bool* res) {
+    for (size_t i = 0; i < size; ++i) {
+        res[i] = left[i] < right[i];
+    }
+}
+
+template <typename T>
+void
+LessEqualColumnRef(const T* left, const T* right, size_t size, bool* res) {
+    for (size_t i = 0; i < size; ++i) {
+        res[i] = left[i] <= right[i];
+    }
+}
+
+template <typename T>
+void
+GreaterColumnRef(const T* left, const T* right, size_t size, bool* res) {
+    for (size_t i = 0; i < size; ++i) {
+        res[i] = left[i] > right[i];
+    }
+}
+
+template <typename T>
+void
+GreaterEqualColumnRef(const T* left, const T* right, size_t size, bool* res) {
+    for (size_t i = 0; i < size; ++i) {
+        res[i] = left[i] >= right[i];
+    }
+}
+
+template <typename T>
+void
+NotEqualColumnRef(const T* left, const T* right, size_t size, bool* res) {
+    for (size_t i = 0; i < size; ++i) {
+        res[i] = left[i] != right[i];
+    }
+}
+
 }  // namespace simd
 }  // namespace milvus
diff --git a/internal/core/src/simd/sse2.cpp b/internal/core/src/simd/sse2.cpp
index 40542bf22baca..9726aec946053 100644
--- a/internal/core/src/simd/sse2.cpp
+++ b/internal/core/src/simd/sse2.cpp
@@ -61,9 +61,8 @@ FindTermSSE2(const bool* src, size_t vec_size, bool val) {
     __m128i xmm_target = _mm_set1_epi8(val);
     __m128i xmm_data;
     size_t num_chunks = vec_size / 16;
-    for (size_t i = 0; i < num_chunks; i++) {
-        xmm_data =
-            _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + 16 * i));
+    for (size_t i = 0; i < num_chunks * 16; i += 16) {
+        xmm_data = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
         __m128i xmm_match = _mm_cmpeq_epi8(xmm_data, xmm_target);
         int mask = _mm_movemask_epi8(xmm_match);
         if (mask != 0) {
@@ -71,7 +70,7 @@ FindTermSSE2(const bool* src, size_t vec_size, bool val) {
         }
     }
 
-    for (size_t i = 16 * num_chunks; i < vec_size; ++i) {
+    for (size_t i = num_chunks * 16; i < vec_size; ++i) {
         if (src[i] == val) {
             return true;
         }
@@ -86,9 +85,8 @@ FindTermSSE2(const int8_t* src, size_t vec_size, int8_t val) {
     __m128i xmm_target = _mm_set1_epi8(val);
     __m128i xmm_data;
     size_t num_chunks = vec_size / 16;
-    for (size_t i = 0; i < num_chunks; i++) {
-        xmm_data =
-            _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + 16 * i));
+    for (size_t i = 0; i < num_chunks * 16; i += 16) {
+        xmm_data = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
         __m128i xmm_match = _mm_cmpeq_epi8(xmm_data, xmm_target);
         int mask = _mm_movemask_epi8(xmm_match);
         if (mask != 0) {
@@ -96,7 +94,7 @@ FindTermSSE2(const int8_t* src, size_t vec_size, int8_t val) {
         }
     }
 
-    for (size_t i = 16 * num_chunks; i < vec_size; ++i) {
+    for (size_t i = num_chunks * 16; i < vec_size; ++i) {
         if (src[i] == val) {
             return true;
         }
@@ -111,9 +109,8 @@ FindTermSSE2(const int16_t* src, size_t vec_size, int16_t val) {
     __m128i xmm_target = _mm_set1_epi16(val);
     __m128i xmm_data;
     size_t num_chunks = vec_size / 8;
-    for (size_t i = 0; i < num_chunks; i++) {
-        xmm_data =
-            _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i * 8));
+    for (size_t i = 0; i < num_chunks * 8; i += 8) {
+        xmm_data = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
         __m128i xmm_match = _mm_cmpeq_epi16(xmm_data, xmm_target);
         int mask = _mm_movemask_epi8(xmm_match);
         if (mask != 0) {
@@ -121,7 +118,7 @@ FindTermSSE2(const int16_t* src, size_t vec_size, int16_t val) {
         }
     }
 
-    for (size_t i = 8 * num_chunks; i < vec_size; ++i) {
+    for (size_t i = num_chunks * 8; i < vec_size; ++i) {
         if (src[i] == val) {
             return true;
         }
@@ -136,9 +133,9 @@ FindTermSSE2(const int32_t* src, size_t vec_size, int32_t val) {
     size_t remaining_size = vec_size % 4;
 
     __m128i xmm_target = _mm_set1_epi32(val);
-    for (size_t i = 0; i < num_chunk; ++i) {
+    for (size_t i = 0; i < num_chunk * 4; i += 4) {
         __m128i xmm_data =
-            _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i * 4));
+            _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
         __m128i xmm_match = _mm_cmpeq_epi32(xmm_data, xmm_target);
         int mask = _mm_movemask_epi8(xmm_match);
         if (mask != 0) {
@@ -180,9 +177,9 @@ FindTermSSE2(const int64_t* src, size_t vec_size, int64_t val) {
     size_t num_chunk = vec_size / 2;
     size_t remaining_size = vec_size % 2;
 
-    for (int64_t i = 0; i < num_chunk; i++) {
+    for (int64_t i = 0; i < num_chunk * 2; i += 2) {
         __m128i xmm_vec =
-            _mm_load_si128(reinterpret_cast<const __m128i*>(src + i * 2));
+            _mm_load_si128(reinterpret_cast<const __m128i*>(src + i));
 
         __m128i xmm_low = _mm_set1_epi32(low);
         __m128i xmm_high = _mm_set1_epi32(high);
@@ -203,13 +200,6 @@ FindTermSSE2(const int64_t* src, size_t vec_size, int64_t val) {
         }
     }
     return false;
-
-    // for (size_t i = 0; i < vec_size; ++i) {
-    //     if (src[i] == val) {
-    //         return true;
-    //     }
-    // }
-    // return false;
 }
 
 template <>
@@ -217,8 +207,8 @@ bool
 FindTermSSE2(const float* src, size_t vec_size, float val) {
     size_t num_chunks = vec_size / 4;
     __m128 xmm_target = _mm_set1_ps(val);
-    for (int i = 0; i < num_chunks; ++i) {
-        __m128 xmm_data = _mm_loadu_ps(src + 4 * i);
+    for (int i = 0; i < 4 * num_chunks; i += 4) {
+        __m128 xmm_data = _mm_loadu_ps(src + i);
         __m128 xmm_match = _mm_cmpeq_ps(xmm_data, xmm_target);
         int mask = _mm_movemask_ps(xmm_match);
         if (mask != 0) {
@@ -239,8 +229,8 @@ bool
 FindTermSSE2(const double* src, size_t vec_size, double val) {
     size_t num_chunks = vec_size / 2;
     __m128d xmm_target = _mm_set1_pd(val);
-    for (int i = 0; i < num_chunks; ++i) {
-        __m128d xmm_data = _mm_loadu_pd(src + 2 * i);
+    for (int i = 0; i < 2 * num_chunks; i += 2) {
+        __m128d xmm_data = _mm_loadu_pd(src + i);
         __m128d xmm_match = _mm_cmpeq_pd(xmm_data, xmm_target);
         int mask = _mm_movemask_pd(xmm_match);
         if (mask != 0) {
diff --git a/internal/core/src/simd/sse4.cpp b/internal/core/src/simd/sse4.cpp
index 8585f9c648af9..bf3d08c76bc71 100644
--- a/internal/core/src/simd/sse4.cpp
+++ b/internal/core/src/simd/sse4.cpp
@@ -32,9 +32,9 @@ FindTermSSE4(const int64_t* src, size_t vec_size, int64_t val) {
     size_t remaining_size = vec_size % 2;
 
     __m128i xmm_target = _mm_set1_epi64x(val);
-    for (size_t i = 0; i < num_chunk; ++i) {
+    for (size_t i = 0; i < num_chunk * 2; i += 2) {
         __m128i xmm_data =
-            _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i * 2));
+            _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
         __m128i xmm_match = _mm_cmpeq_epi64(xmm_data, xmm_target);
         int mask = _mm_movemask_epi8(xmm_match);
         if (mask != 0) {
diff --git a/internal/core/unittest/test_simd.cpp b/internal/core/unittest/test_simd.cpp
index edfc410c23c02..cb157436c50e7 100644
--- a/internal/core/unittest/test_simd.cpp
+++ b/internal/core/unittest/test_simd.cpp
@@ -38,6 +38,7 @@ using FixedVector = boost::container::vector<T>;
 #include "simd/sse4.h"
 #include "simd/avx2.h"
 #include "simd/avx512.h"
+#include "simd/ref.h"
 using namespace milvus::simd;
 
 TEST(GetBitSetBlock, base_test_sse) {
@@ -107,6 +108,30 @@ TEST(GetBitSetBlock, base_test_sse) {
     ASSERT_EQ(res, 0x1084210842108421);
 }
 
+TEST(GetBitsetBlockPerf, bitset) {
+    FixedVector<bool> srcs;
+    for (size_t i = 0; i < 100000000; ++i) {
+        srcs.push_back(i % 2 == 0);
+    }
+    std::cout << "start test" << std::endl;
+    auto start = std::chrono::steady_clock::now();
+    for (int i = 0; i < 10000000; ++i)
+        auto result = GetBitsetBlockSSE2(srcs.data() + i);
+    std::cout << "cost: "
+              << std::chrono::duration_cast<std::chrono::microseconds>(
+                     std::chrono::steady_clock::now() - start)
+                     .count()
+              << "us" << std::endl;
+    start = std::chrono::steady_clock::now();
+    for (int i = 0; i < 10000000; ++i)
+        auto result = GetBitsetBlockAVX2(srcs.data() + i);
+    std::cout << "cost: "
+              << std::chrono::duration_cast<std::chrono::microseconds>(
+                     std::chrono::steady_clock::now() - start)
+                     .count()
+              << "us" << std::endl;
+}
+
 TEST(GetBitSetBlock, base_test_avx2) {
     FixedVector<bool> src;
     for (int i = 0; i < 64; ++i) {
@@ -1214,10 +1239,298 @@ TEST(AllBooleanNeon, performance) {
     }
 }
 
+TEST(EqualVal, perf_int8) {
+    if (!cpu_support_avx512()) {
+        PRINT_SKPI_TEST
+        return;
+    }
+    std::vector<int8_t> srcs(1000000);
+    for (int i = 0; i < 1000000; ++i) {
+        srcs[i] = i % 128;
+    }
+    FixedVector<bool> res(1000000);
+    auto start = std::chrono::steady_clock::now();
+    EqualValRef(srcs.data(), 1000000, (int8_t)10, res.data());
+    std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
+                     std::chrono::steady_clock::now() - start)
+                     .count()
+              << std::endl;
+    start = std::chrono::steady_clock::now();
+    EqualValAVX512(srcs.data(), 1000000, (int8_t)10, res.data());
+    std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
+                     std::chrono::steady_clock::now() - start)
+                     .count()
+              << std::endl;
+}
+
+template <typename T>
+void
+TestCompareValAVX512Perf() {
+    if (!cpu_support_avx512()) {
+        PRINT_SKPI_TEST
+        return;
+    }
+    std::vector<T> srcs(1000000);
+    for (int i = 0; i < 1000000; ++i) {
+        srcs[i] = i;
+    }
+    FixedVector<bool> res(1000000);
+    T target = 10;
+    auto start = std::chrono::steady_clock::now();
+    EqualValRef(srcs.data(), 1000000, target, res.data());
+    std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
+                     std::chrono::steady_clock::now() - start)
+                     .count()
+              << std::endl;
+    start = std::chrono::steady_clock::now();
+    EqualValAVX512(srcs.data(), 1000000, target, res.data());
+    std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
+                     std::chrono::steady_clock::now() - start)
+                     .count()
+              << std::endl;
+}
+
+TEST(EqualVal, perf_int16) {
+    TestCompareValAVX512Perf<int16_t>();
+}
+
+TEST(EqualVal, perf_int32) {
+    TestCompareValAVX512Perf<int32_t>();
+}
+
+TEST(EqualVal, perf_int64) {
+    TestCompareValAVX512Perf<int64_t>();
+}
+
+TEST(EqualVal, perf_float) {
+    TestCompareValAVX512Perf<float>();
+}
+
+TEST(EqualVal, perf_double) {
+    TestCompareValAVX512Perf<double>();
+}
+
+template <typename T>
+void
+TestCompareValAVX512(int size, T target) {
+    if (!cpu_support_avx512()) {
+        PRINT_SKPI_TEST
+        return;
+    }
+    std::vector<T> vecs;
+    for (int i = 0; i < size; ++i) {
+        if constexpr (std::is_same_v<T, int8_t>) {
+            vecs.push_back(i % 127);
+        } else if constexpr (std::is_floating_point_v<T>) {
+            vecs.push_back(i + 0.01);
+        } else {
+            vecs.push_back(i);
+        }
+    }
+    FixedVector<bool> res(size);
+
+    EqualValAVX512(vecs.data(), size, target, res.data());
+    for (int i = 0; i < size; i++) {
+        ASSERT_EQ(res[i], vecs[i] == target) << i;
+    }
+    LessValAVX512(vecs.data(), size, target, res.data());
+    for (int i = 0; i < size; i++) {
+        ASSERT_EQ(res[i], vecs[i] < target) << i;
+    }
+    LessEqualValAVX512(vecs.data(), size, target, res.data());
+    for (int i = 0; i < size; i++) {
+        ASSERT_EQ(res[i], vecs[i] <= target) << i;
+    }
+    GreaterEqualValAVX512(vecs.data(), size, target, res.data());
+    for (int i = 0; i < size; i++) {
+        ASSERT_EQ(res[i], vecs[i] >= target) << i;
+    }
+    GreaterValAVX512(vecs.data(), size, target, res.data());
+    for (int i = 0; i < size; i++) {
+        ASSERT_EQ(res[i], vecs[i] > target) << i;
+    }
+    NotEqualValAVX512(vecs.data(), size, target, res.data());
+    for (int i = 0; i < size; i++) {
+        ASSERT_EQ(res[i], vecs[i] != target) << i;
+    }
+}
+
+TEST(CompareVal, avx512_int8) {
+    TestCompareValAVX512<int8_t>(1000, 9);
+    TestCompareValAVX512<int8_t>(1000, 99);
+    TestCompareValAVX512<int8_t>(1001, 127);
+}
+
+TEST(CompareVal, avx512_int16) {
+    TestCompareValAVX512<int16_t>(1000, 99);
+    TestCompareValAVX512<int16_t>(1000, 999);
+    TestCompareValAVX512<int16_t>(1001, 1000);
+}
+
+TEST(CompareVal, avx512_int32) {
+    TestCompareValAVX512<int32_t>(1000, 99);
+    TestCompareValAVX512<int32_t>(1000, 999);
+    TestCompareValAVX512<int32_t>(1001, 1000);
+}
+
+TEST(CompareVal, avx512_int64) {
+    TestCompareValAVX512<int64_t>(1000, 99);
+    TestCompareValAVX512<int64_t>(1000, 999);
+    TestCompareValAVX512<int64_t>(1001, 1000);
+}
+
+TEST(CompareVal, avx512_float) {
+    TestCompareValAVX512<float>(1000, 99.01);
+    TestCompareValAVX512<float>(1000, 999.01);
+    TestCompareValAVX512<float>(1001, 1000.01);
+}
+
+TEST(CompareVal, avx512_double) {
+    TestCompareValAVX512<double>(1000, 99.01);
+    TestCompareValAVX512<double>(1000, 999.01);
+    TestCompareValAVX512<double>(1001, 1000.01);
+}
+
+template <typename T>
+void
+TestCompareColumnAVX512Perf() {
+    if (!cpu_support_avx512()) {
+        PRINT_SKPI_TEST
+        return;
+    }
+    std::vector<T> lefts(1000000);
+    for (int i = 0; i < 1000000; ++i) {
+        lefts[i] = i;
+    }
+    std::vector<T> rights(1000000);
+    for (int i = 0; i < 1000000; ++i) {
+        rights[i] = i;
+    }
+    FixedVector<bool> res(1000000);
+    auto start = std::chrono::steady_clock::now();
+    LessColumnRef(lefts.data(), rights.data(), 1000000, res.data());
+    std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
+                     std::chrono::steady_clock::now() - start)
+                     .count()
+              << std::endl;
+    start = std::chrono::steady_clock::now();
+    LessColumnAVX512(lefts.data(), rights.data(), 1000000, res.data());
+    std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
+                     std::chrono::steady_clock::now() - start)
+                     .count()
+              << std::endl;
+}
+
+TEST(LessColumn, perf_int32) {
+    TestCompareColumnAVX512Perf<int32_t>();
+}
+
+TEST(LessColumn, perf_int64) {
+    TestCompareColumnAVX512Perf<int64_t>();
+}
+
+TEST(LessColumn, perf_float) {
+    TestCompareColumnAVX512Perf<float>();
+}
+
+TEST(LessColumn, perf_double) {
+    TestCompareColumnAVX512Perf<double>();
+}
+
+template <typename T>
+void
+TestCompareColumnAVX512(int size, T min_val, T max_val) {
+    if (!cpu_support_avx512()) {
+        PRINT_SKPI_TEST
+        return;
+    }
+    std::random_device rd;
+    std::mt19937 gen(rd());
+
+    std::vector<T> left;
+    std::vector<T> right;
+    if constexpr (std::is_same_v<T, float>) {
+        std::uniform_real_distribution<float> dis(min_val, max_val);
+        for (int i = 0; i < size; ++i) {
+            left.push_back(dis(gen));
+            right.push_back(dis(gen));
+        }
+    } else if constexpr (std::is_same_v<T, double>) {
+        std::uniform_real_distribution<double> dis(min_val, max_val);
+        for (int i = 0; i < size; ++i) {
+            left.push_back(dis(gen));
+            right.push_back(dis(gen));
+        }
+    } else {
+        std::uniform_int_distribution<> dis(min_val, max_val);
+        for (int i = 0; i < size; ++i) {
+            left.push_back(dis(gen));
+            right.push_back(dis(gen));
+        }
+    }
+
+    FixedVector<bool> res(size);
+
+    EqualColumnAVX512(left.data(), right.data(), size, res.data());
+    for (int i = 0; i < size; i++) {
+        ASSERT_EQ(res[i], left[i] == right[i]) << i;
+    }
+    LessColumnAVX512(left.data(), right.data(), size, res.data());
+    for (int i = 0; i < size; i++) {
+        ASSERT_EQ(res[i], left[i] < right[i]) << i;
+    }
+    GreaterColumnAVX512(left.data(), right.data(), size, res.data());
+    for (int i = 0; i < size; i++) {
+        ASSERT_EQ(res[i], left[i] > right[i]) << i;
+    }
+    LessEqualColumnAVX512(left.data(), right.data(), size, res.data());
+    for (int i = 0; i < size; i++) {
+        ASSERT_EQ(res[i], left[i] <= right[i]) << i;
+    }
+    GreaterEqualColumnAVX512(left.data(), right.data(), size, res.data());
+    for (int i = 0; i < size; i++) {
+        ASSERT_EQ(res[i], left[i] >= right[i]) << i;
+    }
+    NotEqualColumnAVX512(left.data(), right.data(), size, res.data());
+    for (int i = 0; i < size; i++) {
+        ASSERT_EQ(res[i], left[i] != right[i]) << i;
+    }
+}
+
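+// Note on the correctness cases below (descriptive comments, likely intent
+// inferred from the sizes): each case cross-checks an AVX512 kernel against
+// plain elementwise semantics on random input; the odd size (1001) forces
+// the scalar tail path that runs after the last full SIMD chunk
+// (64/32/16/8 lanes depending on the element width).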
+TEST(CompareColumn, avx512_int8) {
+    TestCompareColumnAVX512<int8_t>(1000, -128, 127);
+    TestCompareColumnAVX512<int8_t>(1001, -128, 127);
+}
+
+TEST(CompareColumn, avx512_int16) {
+    TestCompareColumnAVX512<int16_t>(1000, -1000, 1000);
+    TestCompareColumnAVX512<int16_t>(1001, -1000, 1000);
+}
+
+TEST(CompareColumn, avx512_int32) {
+    TestCompareColumnAVX512<int32_t>(1000, -1000, 1000);
+    TestCompareColumnAVX512<int32_t>(1001, -1000, 1000);
+}
+
+TEST(CompareColumn, avx512_int64) {
+    TestCompareColumnAVX512<int64_t>(1000, -1000, 1000);
+    TestCompareColumnAVX512<int64_t>(1001, -1000, 1000);
+}
+
+TEST(CompareColumn, avx512_float) {
+    TestCompareColumnAVX512<float>(1000, -1.0, 1.0);
+    TestCompareColumnAVX512<float>(1001, -1.0, 1.0);
+}
+
+TEST(CompareColumn, avx512_double) {
+    TestCompareColumnAVX512<double>(1000, -1.0, 1.0);
+    TestCompareColumnAVX512<double>(1001, -1.0, 1.0);
+}
+
 #endif
 
 int main(int argc, char* argv[]) {
     ::testing::InitGoogleTest(&argc, argv);
     return RUN_ALL_TESTS();
-}
\ No newline at end of file
+}

From 20fb8475214cb6759b6956ce471306088b4f70fc Mon Sep 17 00:00:00 2001
From: Jiquan Long
Date: Sun, 7 Jan 2024 20:22:48 +0800
Subject: [PATCH 13/20] enhance: load delta logs concurrently (#29623)

This PR makes Milvus load delta logs concurrently, which should decrease
the latency of loading a segment.

/kind improvement

---------

Signed-off-by: longjiquan
---
 .../querynodev2/segments/segment_loader.go | 31 +++++++++++++------
 1 file changed, 22 insertions(+), 9 deletions(-)

diff --git a/internal/querynodev2/segments/segment_loader.go b/internal/querynodev2/segments/segment_loader.go
index 16e443167a81d..ea8f2a214f8f1 100644
--- a/internal/querynodev2/segments/segment_loader.go
+++ b/internal/querynodev2/segments/segment_loader.go
@@ -40,6 +40,7 @@ import (
 	"github.com/milvus-io/milvus/pkg/common"
 	"github.com/milvus-io/milvus/pkg/log"
 	"github.com/milvus-io/milvus/pkg/metrics"
+	"github.com/milvus-io/milvus/pkg/util/conc"
 	"github.com/milvus-io/milvus/pkg/util/funcutil"
 	"github.com/milvus-io/milvus/pkg/util/hardware"
 	"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
@@ -797,23 +798,35 @@ func (loader *segmentLoader) LoadDeltaLogs(ctx context.Context, segment Segment,
 	)
 	dCodec := storage.DeleteCodec{}
 	var blobs []*storage.Blob
+	var futures []*conc.Future[any]
 	for _, deltaLog := range deltaLogs {
 		for _, bLog := range deltaLog.GetBinlogs() {
+			bLog := bLog
 			// the segment has applied the delta logs, skip it
 			if bLog.GetTimestampTo() > 0 && // this field may be missed in legacy versions
 				bLog.GetTimestampTo() < segment.LastDeltaTimestamp() {
 				continue
 			}
-			value, err := loader.cm.Read(ctx, bLog.GetLogPath())
-			if err != nil {
-				return err
-			}
-			blob := &storage.Blob{
-				Key:   bLog.GetLogPath(),
-				Value: value,
-			}
-			blobs = append(blobs, blob)
+			future := GetLoadPool().Submit(func() (any, error) {
+				value, err := loader.cm.Read(ctx, bLog.GetLogPath())
+				if err != nil {
+					return nil, err
+				}
+				blob := &storage.Blob{
+					Key:   bLog.GetLogPath(),
+					Value: value,
+				}
+				return blob, nil
+			})
+			futures = append(futures, future)
+		}
+	}
+	for _, future := range futures {
+		blob, err := future.Await()
+		if err != nil {
+			return err
 		}
+		blobs = append(blobs, blob.(*storage.Blob))
 	}
 	if len(blobs) == 0 {
 		log.Info("there are no delta logs saved with segment, skip loading delete record")

From e9f3df3626497b9b7cc60f1fa8fdec85c9a134da Mon Sep 17 00:00:00 2001
From: Jiquan Long
Date: Sun, 7 Jan 2024 20:26:49 +0800
Subject: [PATCH 14/20] fix: inverted index file not found (#29695)

issue: https://github.com/milvus-io/milvus/issues/29654

---------

Signed-off-by: longjiquan
---
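Note on the fix (the change itself is in Rust/C++ below): tantivy finishes an
index on background merging threads, so the upload step must not scan the index
directory until the writer has been fully consumed and its merging threads
joined; otherwise the directory listing can miss files and later reads fail
with "file not found". A minimal Go sketch of the ordering this patch enforces;
the function names here (uploadIndexDir, finish, addFile) are hypothetical and
not APIs from this repository:

package main

import (
	"fmt"
	"os"
	"path/filepath"
)

// uploadIndexDir sketches the safe lifecycle: finish the writer (which must
// block until merging is done) first, only then enumerate files for upload.
func uploadIndexDir(dir string, finish func() error, addFile func(string) error) error {
	if err := finish(); err != nil { // joins merge threads before listing
		return err
	}
	entries, err := os.ReadDir(dir)
	if err != nil {
		return err
	}
	for _, e := range entries {
		if e.IsDir() {
			continue // subdirectories are not index files
		}
		if err := addFile(filepath.Join(dir, e.Name())); err != nil {
			return err
		}
	}
	return nil
}

func main() {
	finish := func() error { return nil }
	addFile := func(p string) error { fmt.Println("add", p); return nil }
	_ = uploadIndexDir(os.TempDir(), finish, addFile)
}
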
internal/core/run_clang_format.sh | 2 +- internal/core/src/index/InvertedIndexTantivy.cpp | 16 ++++++++++++++-- .../tantivy-binding/include/tantivy-binding.h | 2 -- .../tantivy/tantivy-binding/src/index_writer.rs | 7 ++----- .../tantivy-binding/src/index_writer_c.rs | 15 +++------------ .../core/thirdparty/tantivy/tantivy-wrapper.h | 13 ++++++++++++- internal/core/unittest/test_c_api.cpp | 12 ++++++++---- .../test_utils/indexbuilder_test_utils.h | 3 ++- .../unittest/test_utils/storage_test_utils.h | 2 +- 9 files changed, 43 insertions(+), 29 deletions(-) diff --git a/internal/core/run_clang_format.sh b/internal/core/run_clang_format.sh index d4e7d7b58524a..2aa22c514473a 100755 --- a/internal/core/run_clang_format.sh +++ b/internal/core/run_clang_format.sh @@ -7,7 +7,7 @@ fi CorePath=$1 formatThis() { - find "$1" | grep -E "(*\.cpp|*\.h|*\.cc)$" | grep -v "gen_tools/templates" | grep -v "\.pb\." | xargs clang-format-10 -i + find "$1" | grep -E "(*\.cpp|*\.h|*\.cc)$" | grep -v "gen_tools/templates" | grep -v "\.pb\." | grep -v "tantivy-binding.h" | xargs clang-format-10 -i } formatThis "${CorePath}/src" diff --git a/internal/core/src/index/InvertedIndexTantivy.cpp b/internal/core/src/index/InvertedIndexTantivy.cpp index f03a7f91efd11..3b3d706893f50 100644 --- a/internal/core/src/index/InvertedIndexTantivy.cpp +++ b/internal/core/src/index/InvertedIndexTantivy.cpp @@ -71,8 +71,20 @@ BinarySet InvertedIndexTantivy::Upload(const Config& config) { finish(); - for (const auto& entry : std::filesystem::directory_iterator(path_)) { - disk_file_manager_->AddFile(entry.path()); + boost::filesystem::path p(path_); + boost::filesystem::directory_iterator end_iter; + + for (boost::filesystem::directory_iterator iter(p); iter != end_iter; + iter++) { + if (boost::filesystem::is_directory(*iter)) { + LOG_WARN("{} is a directory", iter->path().string()); + } else { + LOG_INFO("trying to add index file: {}", iter->path().string()); + AssertInfo(disk_file_manager_->AddFile(iter->path().string()), + "failed to add index file: {}", + iter->path().string()); + LOG_INFO("index file: {} added", iter->path().string()); + } } BinarySet ret; diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/include/tantivy-binding.h b/internal/core/thirdparty/tantivy/tantivy-binding/include/tantivy-binding.h index d402390552be5..0cabaf657182c 100644 --- a/internal/core/thirdparty/tantivy/tantivy-binding/include/tantivy-binding.h +++ b/internal/core/thirdparty/tantivy/tantivy-binding/include/tantivy-binding.h @@ -27,8 +27,6 @@ void tantivy_free_index_writer(void *ptr); void tantivy_finish_index(void *ptr); -void *tantivy_create_reader_for_index(void *ptr); - void tantivy_index_add_int8s(void *ptr, const int8_t *array, uintptr_t len); void tantivy_index_add_int16s(void *ptr, const int16_t *array, uintptr_t len); diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/index_writer.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/index_writer.rs index fe3d078d28619..4f8ed8e9df27d 100644 --- a/internal/core/thirdparty/tantivy/tantivy-binding/src/index_writer.rs +++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/index_writer.rs @@ -16,10 +16,6 @@ pub struct IndexWriterWrapper { } impl IndexWriterWrapper { - pub fn create_reader(&self) -> IndexReaderWrapper { - IndexReaderWrapper::new(&self.index, &self.field_name, self.field) - } - pub fn new(field_name: String, data_type: TantivyDataType, path: String) -> IndexWriterWrapper { let field: Field; let mut schema_builder = Schema::builder(); @@ -101,8 
+97,9 @@ impl IndexWriterWrapper {
             .unwrap();
     }
 
-    pub fn finish(&mut self) {
+    pub fn finish(mut self) {
         self.index_writer.commit().unwrap();
         block_on(self.index_writer.garbage_collect_files()).unwrap();
+        self.index_writer.wait_merging_threads().unwrap();
     }
 }
diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/index_writer_c.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/index_writer_c.rs
index c19ece5a83402..482011d305a57 100644
--- a/internal/core/thirdparty/tantivy/tantivy-binding/src/index_writer_c.rs
+++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/index_writer_c.rs
@@ -26,25 +26,16 @@ pub extern "C" fn tantivy_free_index_writer(ptr: *mut c_void) {
     free_binding::<IndexWriterWrapper>(ptr);
 }
 
+// tantivy_finish_index will finish the index writer, and the index writer can't be used any more.
+// After this was called, you should reset the pointer to null.
 #[no_mangle]
 pub extern "C" fn tantivy_finish_index(ptr: *mut c_void) {
     let real = ptr as *mut IndexWriterWrapper;
     unsafe {
-        (*real).finish();
+        Box::from_raw(real).finish()
     }
 }
 
-// should be only used for test
-#[no_mangle]
-pub extern "C" fn tantivy_create_reader_for_index(ptr: *mut c_void) -> *mut c_void{
-    let real = ptr as *mut IndexWriterWrapper;
-    unsafe {
-        let reader = (*real).create_reader();
-        create_binding(reader)
-    }
-}
-
-
 // -------------------------build--------------------
 #[no_mangle]
 pub extern "C" fn tantivy_index_add_int8s(ptr: *mut c_void, array: *const i8, len: usize) {
diff --git a/internal/core/thirdparty/tantivy/tantivy-wrapper.h b/internal/core/thirdparty/tantivy/tantivy-wrapper.h
index 2e2d057ffe085..9577429fde1c2 100644
--- a/internal/core/thirdparty/tantivy/tantivy-wrapper.h
+++ b/internal/core/thirdparty/tantivy/tantivy-wrapper.h
@@ -93,9 +93,11 @@ struct TantivyIndexWrapper {
         writer_ = other.writer_;
         reader_ = other.reader_;
         finished_ = other.finished_;
+        path_ = other.path_;
         other.writer_ = nullptr;
         other.reader_ = nullptr;
         other.finished_ = false;
+        other.path_ = "";
     }
 
     TantivyIndexWrapper&
@@ -104,10 +106,12 @@ struct TantivyIndexWrapper {
             free();
             writer_ = other.writer_;
             reader_ = other.reader_;
+            path_ = other.path_;
             finished_ = other.finished_;
             other.writer_ = nullptr;
             other.reader_ = nullptr;
             other.finished_ = false;
+            other.path_ = "";
         }
         return *this;
     }
@@ -116,11 +120,13 @@ struct TantivyIndexWrapper {
                         TantivyDataType data_type,
                         const char* path) {
         writer_ = tantivy_create_index(field_name, data_type, path);
+        path_ = std::string(path);
     }
 
     explicit TantivyIndexWrapper(const char* path) {
         assert(tantivy_index_exist(path));
        reader_ = tantivy_load_index(path);
+        path_ = std::string(path);
     }
 
     ~TantivyIndexWrapper() {
@@ -130,6 +136,8 @@ struct TantivyIndexWrapper {
     template <typename T>
     void
     add_data(const T* array, uintptr_t len) {
+        assert(!finished_);
+
         if constexpr (std::is_same_v<T, bool>) {
             tantivy_index_add_bools(writer_, array, len);
             return;
@@ -182,7 +190,9 @@ struct TantivyIndexWrapper {
     finish() {
         if (!finished_) {
             tantivy_finish_index(writer_);
-            reader_ = tantivy_create_reader_for_index(writer_);
+            writer_ = nullptr;
+            reader_ = tantivy_load_index(path_.c_str());
+
             finished_ = true;
         }
     }
@@ -358,5 +368,6 @@ struct TantivyIndexWrapper {
     bool finished_ = false;
     IndexWriter writer_ = nullptr;
     IndexReader reader_ = nullptr;
+    std::string path_;
 };
 }  // namespace milvus::tantivy
diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp
index 9bbac8da78e4c..83f9717555fc8 100644
--- a/internal/core/unittest/test_c_api.cpp
+++ b/internal/core/unittest/test_c_api.cpp
@@ -316,7 +316,7 @@ TEST(CApiTest, SegmentTest) {
     ASSERT_NE(status.error_code, Success);
     DeleteCollection(collection);
     DeleteSegment(segment);
-    free((char *)status.error_msg);
+    free((char*)status.error_msg);
 }
 
 TEST(CApiTest, CPlan) {
@@ -1579,7 +1579,10 @@ TEST(CApiTest, ReduceRemoveDuplicates) {
 }
 
 void
-testReduceSearchWithExpr(int N, int topK, int num_queries, bool filter_all = false) {
+testReduceSearchWithExpr(int N,
+                         int topK,
+                         int num_queries,
+                         bool filter_all = false) {
     std::cerr << "testReduceSearchWithExpr(" << N << ", " << topK << ", "
               << num_queries << ")" << std::endl;
 
@@ -1637,7 +1640,8 @@ testReduceSearchWithExpr(int N, int topK, int num_queries, bool filter_all = fal
                                             search_params: "{\"nprobe\": 10}"
                                           >
                                           placeholder_tag: "$0">
-                                          output_field_ids: 100)") %topK %N;
+                                          output_field_ids: 100)") %
+                    topK % N;
     }
     auto serialized_expr_plan = fmt.str();
     auto blob = generate_query_data(num_queries);
@@ -2305,7 +2309,7 @@ TEST(CApiTest, Indexing_Expr_With_float_Predicate_Range) {
         generate_collection_schema(knowhere::metric::L2, DIM, false);
     auto collection = NewCollection(schema_string.c_str());
     auto schema = ((segcore::Collection*)collection)->get_schema();
-    CSegmentInterface segment;
+    CSegmentInterface segment;
     auto status = NewSegment(collection, Growing, -1, &segment);
     ASSERT_EQ(status.error_code, Success);
 
diff --git a/internal/core/unittest/test_utils/indexbuilder_test_utils.h b/internal/core/unittest/test_utils/indexbuilder_test_utils.h
index 93a39b7bcdbf1..9acd12cd22eea 100644
--- a/internal/core/unittest/test_utils/indexbuilder_test_utils.h
+++ b/internal/core/unittest/test_utils/indexbuilder_test_utils.h
@@ -459,7 +459,8 @@ GetIndexTypesV2() {
 template <>
 inline std::vector<std::string>
 GetIndexTypesV2() {
-    return std::vector<std::string>{milvus::index::INVERTED_INDEX_TYPE, "marisa"};
+    return std::vector<std::string>{milvus::index::INVERTED_INDEX_TYPE,
+                                    "marisa"};
 }
 
 }  // namespace
diff --git a/internal/core/unittest/test_utils/storage_test_utils.h b/internal/core/unittest/test_utils/storage_test_utils.h
index 31e3b06d6b258..7eca359f3043d 100644
--- a/internal/core/unittest/test_utils/storage_test_utils.h
+++ b/internal/core/unittest/test_utils/storage_test_utils.h
@@ -25,11 +25,11 @@
 #include "storage/ThreadPools.h"
 
 using milvus::DataType;
-using milvus::storage::FieldDataMeta;
 using milvus::FieldDataPtr;
 using milvus::FieldId;
 using milvus::segcore::GeneratedData;
 using milvus::storage::ChunkManagerPtr;
+using milvus::storage::FieldDataMeta;
 using milvus::storage::InsertData;
 using milvus::storage::StorageConfig;

From cd34de7de52b1b76b74e4530df6ffa58c37ee17f Mon Sep 17 00:00:00 2001
From: "sammy.huang"
Date: Mon, 8 Jan 2024 10:10:47 +0800
Subject: [PATCH 15/20] enhance:[skip e2e] use docker plugin to do the same thing instead (#29667)

fix: #29663

Signed-off-by: Sammy Huang
---
 .github/workflows/publish-builder.yaml     | 42 +++++++++++--------
 .github/workflows/publish-gpu-builder.yaml | 49 ++++++++++++----------
 2 files changed, 52 insertions(+), 39 deletions(-)

diff --git a/.github/workflows/publish-builder.yaml b/.github/workflows/publish-builder.yaml
index 0a89a5f9e0a5f..d991b4c387a0c 100644
--- a/.github/workflows/publish-builder.yaml
+++ b/.github/workflows/publish-builder.yaml
@@ -50,7 +50,16 @@ jobs:
         id: extracter
         run: |
           echo "::set-output name=version::$(date +%Y%m%d)"
-          echo "::set-output name=sha_short::$(git rev-parse --short HEAD)"
+          echo "::set-output name=sha_short::$(git rev-parse --short=7 HEAD)"
+      - name: Docker meta
+        id: meta
+        uses: docker/metadata-action@v5
+        with:
+          images: |
milvusdb/milvus-env + tags: | + type=raw,enable=true,value=${{ matrix.os }}-{{date 'YYYYMMDD'}}-{{sha}} + type=raw,enable=true,value=${{ matrix.os }}-latest # - name: Setup upterm session # uses: lhotari/action-upterm@v1 - name: Set up QEMU @@ -59,25 +68,24 @@ jobs: platforms: arm64 - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - - name: Docker Build - if: success() && github.event_name == 'pull_request' && github.repository == 'milvus-io/milvus' - shell: bash - run: | - docker buildx ls - docker buildx build --platform linux/amd64,linux/arm64 -t milvusdb/milvus-env:${OS_NAME}-${{ steps.extracter.outputs.version }}-${{ steps.extracter.outputs.sha_short }} -f build/docker/builder/cpu/${OS_NAME}/Dockerfile . - - name: Docker Build&Push - if: success() && github.event_name == 'push' && github.repository == 'milvus-io/milvus' - shell: bash - run: | - docker buildx ls - docker login -u ${{ secrets.DOCKERHUB_USER }} \ - -p ${{ secrets.DOCKERHUB_TOKEN }} - docker buildx build --platform linux/amd64,linux/arm64 --push -t milvusdb/milvus-env:${OS_NAME}-${{ steps.extracter.outputs.version }}-${{ steps.extracter.outputs.sha_short }} -f build/docker/builder/cpu/${OS_NAME}/Dockerfile . - docker buildx build --platform linux/amd64,linux/arm64 --push -t milvusdb/milvus-env:${OS_NAME}-latest -f build/docker/builder/cpu/${OS_NAME}/Dockerfile . + - name: Login to Docker Hub + if: github.event_name != 'pull_request' + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USER }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + - name: Build and push + uses: docker/build-push-action@v5 + with: + context: . + platforms: linux/amd64,linux/arm64 + push: ${{ github.event_name != 'pull_request' }} + tags: ${{ steps.meta.outputs.tags }} + file: build/docker/builder/cpu/${{ matrix.os }}/Dockerfile - name: Bump Builder Version uses: ./.github/actions/bump-builder-version if: success() && github.event_name == 'push' && github.repository == 'milvus-io/milvus' && matrix.os == 'ubuntu20.04' with: tag: "${{ steps.extracter.outputs.version }}-${{ steps.extracter.outputs.sha_short }}" type: cpu - token: ${{ secrets.ALL_CONTRIBUTORS_TOKEN }} \ No newline at end of file + token: ${{ secrets.ALL_CONTRIBUTORS_TOKEN }} diff --git a/.github/workflows/publish-gpu-builder.yaml b/.github/workflows/publish-gpu-builder.yaml index 2587457b8fde7..a8dad78a3c966 100644 --- a/.github/workflows/publish-gpu-builder.yaml +++ b/.github/workflows/publish-gpu-builder.yaml @@ -50,33 +50,38 @@ jobs: id: extracter run: | echo "::set-output name=version::$(date +%Y%m%d)" - echo "::set-output name=sha_short::$(git rev-parse --short HEAD)" + echo "::set-output name=sha_short::$(git rev-parse --short=7 HEAD)" + - name: Docker meta + id: meta + uses: docker/metadata-action@v5 + with: + images: | + milvusdb/milvus-env + tags: | + type=raw,enable=true,value=${{ matrix.os }}-{{date 'YYYYMMDD'}}-{{sha}} + type=raw,enable=true,value=${{ matrix.os }}-latest # - name: Setup upterm session # uses: lhotari/action-upterm@v1 - - name: Docker Build - if: success() && github.event_name == 'pull_request' && github.repository == 'milvus-io/milvus' - shell: bash - run: | - docker info - docker build -t milvusdb/milvus-env:gpu-${OS_NAME}-${{ steps.extracter.outputs.version }}-${{ steps.extracter.outputs.sha_short }} -f build/docker/builder/gpu/${OS_NAME}/Dockerfile . 
-      - name: Docker Build&Push
-        if: success() && github.event_name == 'push' && github.repository == 'milvus-io/milvus'
-        shell: bash
-        run: |
-          docker info
-          docker login -u ${{ secrets.DOCKERHUB_USER }} \
-                       -p ${{ secrets.DOCKERHUB_TOKEN }}
-          # Building the first image
-          docker build -t milvusdb/milvus-env:gpu-${OS_NAME}-${{ steps.extracter.outputs.version }}-${{ steps.extracter.outputs.sha_short }} -f build/docker/builder/gpu/${OS_NAME}/Dockerfile .
-          docker push milvusdb/milvus-env:gpu-${OS_NAME}-${{ steps.extracter.outputs.version }}-${{ steps.extracter.outputs.sha_short }}
-
-          # Building the second image
-          docker build -t milvusdb/milvus-env:gpu-${OS_NAME}-latest -f build/docker/builder/gpu/${OS_NAME}/Dockerfile .
-          docker push milvusdb/milvus-env:gpu-${OS_NAME}-latest
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+      - name: Login to Docker Hub
+        if: github.event_name != 'pull_request'
+        uses: docker/login-action@v3
+        with:
+          username: ${{ secrets.DOCKERHUB_USER }}
+          password: ${{ secrets.DOCKERHUB_TOKEN }}
+      - name: Build and push
+        uses: docker/build-push-action@v5
+        with:
+          context: .
+          platforms: linux/amd64
+          push: ${{ github.event_name != 'pull_request' }}
+          tags: ${{ steps.meta.outputs.tags }}
+          file: build/docker/builder/gpu/${{ matrix.os }}/Dockerfile
       - name: Bump Builder Version
         if: success() && github.event_name == 'push' && github.repository == 'milvus-io/milvus' && matrix.os == 'ubuntu20.04'
         uses: ./.github/actions/bump-builder-version
         with:
           tag: "${{ steps.extracter.outputs.version }}-${{ steps.extracter.outputs.sha_short }}"
           type: gpu
-          token: ${{ secrets.ALL_CONTRIBUTORS_TOKEN }}
\ No newline at end of file
+          token: ${{ secrets.ALL_CONTRIBUTORS_TOKEN }}

From fe47deebf3e46158ac7852904b9921fe6993a095 Mon Sep 17 00:00:00 2001
From: congqixia
Date: Mon, 8 Jan 2024 14:16:48 +0800
Subject: [PATCH 16/20] fix: Set & Return correct SegmentLevel in querynode
 segment manager (#29740)

See also #27349
The segment level label in querynode used `Legacy` before segment level
was correctly passed in Load request. Now this attribute is still using
legacy, so the metrics do not look right. This PR adds a parameter to
`NewSegment` and passes correct values for each invocation.
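
In sketch form, with simplified stand-in types (the real code uses
datapb.SegmentLevel and several more constructor arguments; this condensed
version is illustrative only), the level now travels from the constructor
to the getter instead of being hard-coded:

package main

import "fmt"

// SegmentLevel is a simplified stand-in for datapb.SegmentLevel.
type SegmentLevel string

const (
	SegmentLevelLegacy SegmentLevel = "Legacy"
	SegmentLevelL0     SegmentLevel = "L0"
	SegmentLevelL1     SegmentLevel = "L1"
)

type baseSegment struct {
	level SegmentLevel
}

func newBaseSegment(level SegmentLevel) baseSegment {
	return baseSegment{level: level}
}

// Level reports what the caller passed in; before this fix the querynode
// getter returned the Legacy constant, so every metric labeled by Level()
// collapsed into a single "Legacy" series.
func (s baseSegment) Level() SegmentLevel {
	return s.level
}

func main() {
	fmt.Println(newBaseSegment(SegmentLevelL1).Level()) // L1, usable as a metric label
}
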
Signed-off-by: Congqi Xia --- internal/querynodev2/delegator/delegator_data.go | 2 +- internal/querynodev2/segments/segment.go | 12 ++++++++---- internal/querynodev2/segments/segment_l0.go | 2 +- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/internal/querynodev2/delegator/delegator_data.go b/internal/querynodev2/delegator/delegator_data.go index 69191dfbc4509..a6d1238230547 100644 --- a/internal/querynodev2/delegator/delegator_data.go +++ b/internal/querynodev2/delegator/delegator_data.go @@ -98,7 +98,7 @@ func (sd *shardDelegator) ProcessInsert(insertRecords map[int64]*InsertData) { 0, insertData.StartPosition, insertData.StartPosition, - datapb.SegmentLevel_Legacy, + datapb.SegmentLevel_L1, ) if err != nil { log.Error("failed to create new segment", diff --git a/internal/querynodev2/segments/segment.go b/internal/querynodev2/segments/segment.go index a7e3baed07069..efec5ddd2b97e 100644 --- a/internal/querynodev2/segments/segment.go +++ b/internal/querynodev2/segments/segment.go @@ -74,18 +74,20 @@ type baseSegment struct { shard string collectionID int64 typ SegmentType + level datapb.SegmentLevel version *atomic.Int64 startPosition *msgpb.MsgPosition // for growing segment release bloomFilterSet *pkoracle.BloomFilterSet } -func newBaseSegment(id, partitionID, collectionID int64, shard string, typ SegmentType, version int64, startPosition *msgpb.MsgPosition) baseSegment { +func newBaseSegment(id, partitionID, collectionID int64, shard string, typ SegmentType, level datapb.SegmentLevel, version int64, startPosition *msgpb.MsgPosition) baseSegment { return baseSegment{ segmentID: id, partitionID: partitionID, collectionID: collectionID, shard: shard, typ: typ, + level: level, version: atomic.NewInt64(version), startPosition: startPosition, bloomFilterSet: pkoracle.NewBloomFilterSet(id, partitionID, typ), @@ -114,7 +116,7 @@ func (s *baseSegment) Type() SegmentType { } func (s *baseSegment) Level() datapb.SegmentLevel { - return datapb.SegmentLevel_Legacy + return s.level } func (s *baseSegment) StartPosition() *msgpb.MsgPosition { @@ -205,10 +207,12 @@ func NewSegment(ctx context.Context, zap.Int64("collectionID", collectionID), zap.Int64("partitionID", partitionID), zap.Int64("segmentID", segmentID), - zap.String("segmentType", segmentType.String())) + zap.String("segmentType", segmentType.String()), + zap.String("level", level.String()), + ) segment := &LocalSegment{ - baseSegment: newBaseSegment(segmentID, partitionID, collectionID, shard, segmentType, version, startPosition), + baseSegment: newBaseSegment(segmentID, partitionID, collectionID, shard, segmentType, level, version, startPosition), ptr: newPtr, lastDeltaTimestamp: atomic.NewUint64(0), fieldIndexes: typeutil.NewConcurrentMap[int64, *IndexedFieldInfo](), diff --git a/internal/querynodev2/segments/segment_l0.go b/internal/querynodev2/segments/segment_l0.go index e0a6b7b90f698..125642e55752a 100644 --- a/internal/querynodev2/segments/segment_l0.go +++ b/internal/querynodev2/segments/segment_l0.go @@ -62,7 +62,7 @@ func NewL0Segment(collection *Collection, zap.String("segmentType", segmentType.String())) segment := &L0Segment{ - baseSegment: newBaseSegment(segmentID, partitionID, collectionID, shard, segmentType, version, startPosition), + baseSegment: newBaseSegment(segmentID, partitionID, collectionID, shard, segmentType, datapb.SegmentLevel_L0, version, startPosition), } return segment, nil From 7e6f73a12df0d5b41c44cce4f625e0ab8401bbba Mon Sep 17 00:00:00 2001 From: "zhenshan.cao" Date: Mon, 8 Jan 2024 
15:10:49 +0800 Subject: [PATCH 17/20] feat: Authorize users to query grant info of their roles (#29747) Once a role is granted to a user, the user should automatically possess the privilege information associated with that role. issue: #29710 Signed-off-by: zhenshan.cao --- internal/proxy/privilege_interceptor.go | 16 ++++++++++++++ tests/python_client/testcases/test_utility.py | 22 ++++++++++++++++++- 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/internal/proxy/privilege_interceptor.go b/internal/proxy/privilege_interceptor.go index 9eb3e8d77f4f9..10bec733c5ae7 100644 --- a/internal/proxy/privilege_interceptor.go +++ b/internal/proxy/privilege_interceptor.go @@ -15,6 +15,7 @@ import ( "google.golang.org/grpc/status" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/funcutil" @@ -113,6 +114,11 @@ func PrivilegeInterceptor(ctx context.Context, req interface{}) (context.Context if isCurUserObject(objectType, username, objectName) { return ctx, nil } + + if isSelectMyRoleGrants(req, roleNames) { + return ctx, nil + } + objectNameIndexs := privilegeExt.ObjectNameIndexs objectNames := funcutil.GetObjectNames(req, objectNameIndexs) objectPrivilege := privilegeExt.ObjectPrivilege.String() @@ -181,6 +187,16 @@ func isCurUserObject(objectType string, curUser string, object string) bool { return curUser == object } +func isSelectMyRoleGrants(req interface{}, roleNames []string) bool { + selectGrantReq, ok := req.(*milvuspb.SelectGrantRequest) + if !ok { + return false + } + filterGrantEntity := selectGrantReq.GetEntity() + roleName := filterGrantEntity.GetRole().GetName() + return funcutil.SliceContain(roleNames, roleName) +} + func DBMatchFunc(args ...interface{}) (interface{}, error) { name1 := args[0].(string) name2 := args[1].(string) diff --git a/tests/python_client/testcases/test_utility.py b/tests/python_client/testcases/test_utility.py index 71e2f30f4bd11..addf62dc6298f 100644 --- a/tests/python_client/testcases/test_utility.py +++ b/tests/python_client/testcases/test_utility.py @@ -3004,6 +3004,9 @@ def test_role_list_grants(self, host, port, with_db): r_name = cf.gen_unique_str(prefix) c_name = cf.gen_unique_str(prefix) u, _ = self.utility_wrap.create_user(user=user, password=password) + user2 = cf.gen_unique_str(prefix) + u2, _ = self.utility_wrap.create_user(user=user2, password=password) + self.utility_wrap.init_role(r_name) self.utility_wrap.create_role() @@ -3019,10 +3022,27 @@ def test_role_list_grants(self, host, port, with_db): self.utility_wrap.role_grant(grant_item["object"], grant_item["object_name"], grant_item["privilege"], **db_kwargs) - # list grants + # list grants with default user + g_list, _ = self.utility_wrap.role_list_grants(**db_kwargs) + assert len(g_list.groups) == len(grant_list) + + self.connection_wrap.disconnect(alias=DefaultConfig.DEFAULT_USING) + self.connection_wrap.connect(host=host, port=port, user=user, + password=password, check_task=ct.CheckTasks.ccr, **db_kwargs) + + # list grants with user g_list, _ = self.utility_wrap.role_list_grants(**db_kwargs) assert len(g_list.groups) == len(grant_list) + + self.connection_wrap.disconnect(alias=DefaultConfig.DEFAULT_USING) + self.connection_wrap.connect(host=host, port=port, user=user2, + password=password, check_task=ct.CheckTasks.ccr, **db_kwargs) + + # user2 can not list grants of role + 
self.utility_wrap.role_list_grants(**db_kwargs, + check_task=CheckTasks.check_permission_deny) + @pytest.mark.tags(CaseLabel.RBAC) def test_drop_role_which_bind_user(self, host, port): """ From 9702cef2b597daa6c6291e50cdaa78a8388f6545 Mon Sep 17 00:00:00 2001 From: xige-16 Date: Mon, 8 Jan 2024 15:34:48 +0800 Subject: [PATCH 18/20] feat: Support multiple vector search (#29433) issue #25639 Signed-off-by: xige-16 Signed-off-by: xige-16 --- .../core/src/segcore/SegmentGrowingImpl.h | 10 + .../core/src/segcore/SegmentSealedImpl.cpp | 11 + internal/proto/query_coord.proto | 2 +- internal/proxy/impl.go | 132 ++++- internal/proxy/proxy_test.go | 46 +- internal/proxy/reScorer.go | 157 ++++++ internal/proxy/reScorer_test.go | 55 +++ internal/proxy/task.go | 5 + internal/proxy/task_hybrid_search.go | 461 ++++++++++++++++++ internal/proxy/task_hybrid_search_test.go | 330 +++++++++++++ internal/proxy/task_search.go | 136 +++--- internal/proxy/task_test.go | 14 + internal/proxy/util.go | 4 + internal/querycoordv2/task/executor.go | 7 - internal/querycoordv2/task/utils.go | 25 +- internal/querycoordv2/task/utils_test.go | 52 -- internal/querynodev2/mock_data.go | 61 +-- internal/querynodev2/segments/collection.go | 11 +- internal/querynodev2/segments/mock_data.go | 154 ++++-- internal/querynodev2/segments/plan.go | 8 +- internal/querynodev2/services.go | 24 +- internal/querynodev2/services_test.go | 187 ++++--- pkg/metrics/metrics.go | 2 + pkg/util/typeutil/schema_test.go | 10 +- .../hybridsearch/hybridsearch_test.go | 225 +++++++++ tests/integration/util_index.go | 11 +- 26 files changed, 1773 insertions(+), 367 deletions(-) create mode 100644 internal/proxy/reScorer.go create mode 100644 internal/proxy/reScorer_test.go create mode 100644 internal/proxy/task_hybrid_search.go create mode 100644 internal/proxy/task_hybrid_search_test.go create mode 100644 tests/integration/hybridsearch/hybridsearch_test.go diff --git a/internal/core/src/segcore/SegmentGrowingImpl.h b/internal/core/src/segcore/SegmentGrowingImpl.h index 544ca154a5903..e23921801eb12 100644 --- a/internal/core/src/segcore/SegmentGrowingImpl.h +++ b/internal/core/src/segcore/SegmentGrowingImpl.h @@ -281,6 +281,16 @@ class SegmentGrowingImpl : public SegmentGrowing { void check_search(const query::Plan* plan) const override { Assert(plan); + auto& metric_str = plan->plan_node_->search_info_.metric_type_; + auto searched_field_id = plan->plan_node_->search_info_.field_id_; + auto index_meta = + index_meta_->GetFieldIndexMeta(FieldId(searched_field_id)); + if (metric_str.empty()) { + metric_str = index_meta.GeMetricType(); + } else { + AssertInfo(metric_str == index_meta.GeMetricType(), + "metric type not match"); + } } const ConcurrentVector& diff --git a/internal/core/src/segcore/SegmentSealedImpl.cpp b/internal/core/src/segcore/SegmentSealedImpl.cpp index 72f4fa09989db..a6c8773bab027 100644 --- a/internal/core/src/segcore/SegmentSealedImpl.cpp +++ b/internal/core/src/segcore/SegmentSealedImpl.cpp @@ -917,6 +917,17 @@ SegmentSealedImpl::check_search(const query::Plan* plan) const { AssertInfo(plan->extra_info_opt_.has_value(), "Extra info of search plan doesn't have value"); + auto& metric_str = plan->plan_node_->search_info_.metric_type_; + auto searched_field_id = plan->plan_node_->search_info_.field_id_; + auto index_meta = + col_index_meta_->GetFieldIndexMeta(FieldId(searched_field_id)); + if (metric_str.empty()) { + metric_str = index_meta.GeMetricType(); + } else { + AssertInfo(metric_str == index_meta.GeMetricType(), + "metric type 
not match"); + } + if (!is_system_field_ready()) { PanicInfo( FieldNotLoaded, diff --git a/internal/proto/query_coord.proto b/internal/proto/query_coord.proto index fdbd9dac78194..da69d840ea5b7 100644 --- a/internal/proto/query_coord.proto +++ b/internal/proto/query_coord.proto @@ -219,7 +219,7 @@ message LoadMetaInfo { LoadType load_type = 1; int64 collectionID = 2; repeated int64 partitionIDs = 3; - string metric_type = 4; + string metric_type = 4 [deprecated=true]; } message WatchDmChannelsRequest { diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index e3edb13bdc9a2..f4957527b1951 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -2726,9 +2726,135 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest) } func (node *Proxy) HybridSearch(ctx context.Context, request *milvuspb.HybridSearchRequest) (*milvuspb.SearchResults, error) { - return &milvuspb.SearchResults{ - Status: merr.Status(merr.WrapErrServiceInternal("unimplemented")), - }, nil + receiveSize := proto.Size(request) + metrics.ProxyReceiveBytes.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + metrics.HybridSearchLabel, + request.GetCollectionName(), + ).Add(float64(receiveSize)) + + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return &milvuspb.SearchResults{ + Status: merr.Status(err), + }, nil + } + + method := "HybridSearch" + tr := timerecord.NewTimeRecorder(method) + metrics.ProxyFunctionCall.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + metrics.TotalLabel, + ).Inc() + + ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch") + defer sp.End() + + qt := &hybridSearchTask{ + ctx: ctx, + Condition: NewTaskCondition(ctx), + request: request, + tr: timerecord.NewTimeRecorder(method), + qc: node.queryCoord, + node: node, + lb: node.lbPolicy, + } + + guaranteeTs := request.GuaranteeTimestamp + + log := log.Ctx(ctx).With( + zap.String("role", typeutil.ProxyRole), + zap.String("db", request.DbName), + zap.String("collection", request.CollectionName), + zap.Any("partitions", request.PartitionNames), + zap.Any("OutputFields", request.OutputFields), + zap.Uint64("guarantee_timestamp", guaranteeTs), + ) + + defer func() { + span := tr.ElapseSpan() + if span >= SlowReadSpan { + log.Info(rpcSlow(method), zap.Duration("duration", span)) + } + }() + + log.Debug(rpcReceived(method)) + + if err := node.sched.dqQueue.Enqueue(qt); err != nil { + log.Warn( + rpcFailedToEnqueue(method), + zap.Error(err), + ) + + metrics.ProxyFunctionCall.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + metrics.AbandonLabel, + ).Inc() + + return &milvuspb.SearchResults{ + Status: merr.Status(err), + }, nil + } + tr.CtxRecord(ctx, "hybrid search request enqueue") + + log.Debug( + rpcEnqueued(method), + zap.Uint64("timestamp", qt.request.Base.Timestamp), + ) + + if err := qt.WaitToFinish(); err != nil { + log.Warn( + rpcFailedToWaitToFinish(method), + zap.Error(err), + ) + + metrics.ProxyFunctionCall.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + metrics.FailLabel, + ).Inc() + + return &milvuspb.SearchResults{ + Status: merr.Status(err), + }, nil + } + + span := tr.CtxRecord(ctx, "wait hybrid search result") + metrics.ProxyWaitForSearchResultLatency.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + metrics.HybridSearchLabel, + ).Observe(float64(span.Milliseconds())) + + tr.CtxRecord(ctx, "wait hybrid search result") + log.Debug(rpcDone(method)) 
+ + metrics.ProxyFunctionCall.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + metrics.SuccessLabel, + ).Inc() + + metrics.ProxySearchVectors.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(len(qt.request.GetRequests()))) + + searchDur := tr.ElapseSpan().Milliseconds() + metrics.ProxySQLatency.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + metrics.HybridSearchLabel, + ).Observe(float64(searchDur)) + + metrics.ProxyCollectionSQLatency.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + metrics.HybridSearchLabel, + request.CollectionName, + ).Observe(float64(searchDur)) + + if qt.result != nil { + sentSize := proto.Size(qt.result) + metrics.ProxyReadReqSendBytes.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(sentSize)) + rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize)) + } + return qt.result, nil } func (node *Proxy) getVectorPlaceholderGroupForSearchByPks(ctx context.Context, request *milvuspb.SearchRequest) ([]byte, error) { diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 909cadefe9ef1..57fa36621d25d 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -1477,7 +1477,7 @@ func TestProxy(t *testing.T) { topk := 10 roundDecimal := 6 expr := fmt.Sprintf("%s > 0", int64Field) - constructVectorsPlaceholderGroup := func() *commonpb.PlaceholderGroup { + constructVectorsPlaceholderGroup := func(nq int) *commonpb.PlaceholderGroup { values := make([][]byte, 0, nq) for i := 0; i < nq; i++ { bs := make([]byte, 0, dim*4) @@ -1502,8 +1502,8 @@ func TestProxy(t *testing.T) { } } - constructSearchRequest := func() *milvuspb.SearchRequest { - plg := constructVectorsPlaceholderGroup() + constructSearchRequest := func(nq int) *milvuspb.SearchRequest { + plg := constructVectorsPlaceholderGroup(nq) plgBs, err := proto.Marshal(plg) assert.NoError(t, err) @@ -1538,13 +1538,51 @@ func TestProxy(t *testing.T) { wg.Add(1) t.Run("search", func(t *testing.T) { defer wg.Done() - req := constructSearchRequest() + req := constructSearchRequest(nq) resp, err := proxy.Search(ctx, req) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) }) + constructHybridSearchRequest := func(reqs []*milvuspb.SearchRequest) *milvuspb.HybridSearchRequest { + params := make(map[string]float64) + params[RRFParamsKey] = 60 + b, err := json.Marshal(params) + assert.NoError(t, err) + rankParams := []*commonpb.KeyValuePair{ + {Key: RankTypeKey, Value: "rrf"}, + {Key: RankParamsKey, Value: string(b)}, + {Key: LimitKey, Value: strconv.Itoa(topk)}, + {Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)}, + } + + return &milvuspb.HybridSearchRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + Requests: reqs, + PartitionNames: nil, + OutputFields: nil, + RankParams: rankParams, + TravelTimestamp: 0, + GuaranteeTimestamp: 0, + } + } + + wg.Add(1) + nq = 1 + t.Run("hybrid search", func(t *testing.T) { + defer wg.Done() + req1 := constructSearchRequest(nq) + req2 := constructSearchRequest(nq) + + resp, err := proxy.HybridSearch(ctx, constructHybridSearchRequest([]*milvuspb.SearchRequest{req1, req2})) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + }) + nq = 10 + constructPrimaryKeysPlaceholderGroup := func() *commonpb.PlaceholderGroup { expr := fmt.Sprintf("%v in [%v]", int64Field, insertedIds[0]) exprBytes := []byte(expr) diff --git 
a/internal/proxy/reScorer.go b/internal/proxy/reScorer.go new file mode 100644 index 0000000000000..264057ac3bbd1 --- /dev/null +++ b/internal/proxy/reScorer.go @@ -0,0 +1,157 @@ +package proxy + +import ( + "encoding/json" + "fmt" + "reflect" + + "github.com/cockroachdb/errors" + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" +) + +type rankType int + +const ( + invalidRankType rankType = iota // invalidRankType = 0 + rrfRankType // rrfRankType = 1 + weightedRankType // weightedRankType = 2 + udfExprRankType // udfExprRankType = 3 +) + +var rankTypeMap = map[string]rankType{ + "invalid": invalidRankType, + "rrf": rrfRankType, + "weighted": weightedRankType, + "expr": udfExprRankType, +} + +type reScorer interface { + name() string + scorerType() rankType + reScore(input *milvuspb.SearchResults) +} + +type baseScorer struct { + scorerName string +} + +func (bs *baseScorer) name() string { + return bs.scorerName +} + +type rrfScorer struct { + baseScorer + k float32 +} + +func (rs *rrfScorer) reScore(input *milvuspb.SearchResults) { + for i := range input.Results.GetScores() { + input.Results.Scores[i] = 1 / (rs.k + float32(i+1)) + } +} + +func (rs *rrfScorer) scorerType() rankType { + return rrfRankType +} + +type weightedScorer struct { + baseScorer + weight float32 +} + +func (ws *weightedScorer) reScore(input *milvuspb.SearchResults) { + for i, score := range input.Results.GetScores() { + input.Results.Scores[i] = ws.weight * score + } +} + +func (ws *weightedScorer) scorerType() rankType { + return weightedRankType +} + +func NewReScorer(reqs []*milvuspb.SearchRequest, rankParams []*commonpb.KeyValuePair) ([]reScorer, error) { + res := make([]reScorer, len(reqs)) + rankTypeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankTypeKey, rankParams) + if err != nil { + log.Info("rank strategy not specified, use rrf instead") + // if not set rank strategy, use rrf rank as default + for i := range reqs { + res[i] = &rrfScorer{ + baseScorer: baseScorer{ + scorerName: "rrf", + }, + k: float32(defaultRRFParamsValue), + } + } + return res, nil + } + + if _, ok := rankTypeMap[rankTypeStr]; !ok { + return nil, errors.Errorf("unsupported rank type %s", rankTypeStr) + } + + paramStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankParamsKey, rankParams) + if err != nil { + return nil, errors.New(RankParamsKey + " not found in rank_params") + } + + var params map[string]interface{} + err = json.Unmarshal([]byte(paramStr), ¶ms) + if err != nil { + return nil, err + } + + switch rankTypeMap[rankTypeStr] { + case rrfRankType: + k, ok := params[RRFParamsKey].(float64) + if !ok { + return nil, errors.New(RRFParamsKey + " not found in rank_params") + } + log.Debug("rrf params", zap.Float64("k", k)) + for i := range reqs { + res[i] = &rrfScorer{ + baseScorer: baseScorer{ + scorerName: "rrf", + }, + k: float32(k), + } + } + case weightedRankType: + if _, ok := params[WeightsParamsKey]; !ok { + return nil, errors.New(WeightsParamsKey + " not found in rank_params") + } + weights := make([]float32, 0) + switch reflect.TypeOf(params[WeightsParamsKey]).Kind() { + case reflect.Slice: + rs := reflect.ValueOf(params[WeightsParamsKey]) + for i := 0; i < rs.Len(); i++ { + weights = append(weights, float32(rs.Index(i).Interface().(float64))) + } + default: + return nil, errors.New("The weights param 
should be an array") + } + + log.Debug("weights params", zap.Any("weights", weights)) + if len(reqs) != len(weights) { + return nil, merr.WrapErrParameterInvalid(fmt.Sprint(len(reqs)), fmt.Sprint(len(weights)), "the length of weights param mismatch with ann search requests") + } + for i := range reqs { + res[i] = &weightedScorer{ + baseScorer: baseScorer{ + scorerName: "weighted", + }, + weight: weights[i], + } + } + default: + return nil, errors.Errorf("unsupported rank type %s", rankTypeStr) + } + + return res, nil +} diff --git a/internal/proxy/reScorer_test.go b/internal/proxy/reScorer_test.go new file mode 100644 index 0000000000000..7d12d5fe0ae81 --- /dev/null +++ b/internal/proxy/reScorer_test.go @@ -0,0 +1,55 @@ +package proxy + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" +) + +func TestRescorer(t *testing.T) { + t.Run("default scorer", func(t *testing.T) { + rescorers, err := NewReScorer([]*milvuspb.SearchRequest{{}, {}}, nil) + assert.NoError(t, err) + assert.Equal(t, 2, len(rescorers)) + assert.Equal(t, rrfRankType, rescorers[0].scorerType()) + }) + + t.Run("rrf", func(t *testing.T) { + params := make(map[string]float64) + params[RRFParamsKey] = 61 + b, err := json.Marshal(params) + assert.NoError(t, err) + rankParams := []*commonpb.KeyValuePair{ + {Key: RankTypeKey, Value: "rrf"}, + {Key: RankParamsKey, Value: string(b)}, + } + + rescorers, err := NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams) + assert.NoError(t, err) + assert.Equal(t, 2, len(rescorers)) + assert.Equal(t, rrfRankType, rescorers[0].scorerType()) + assert.Equal(t, float32(61), rescorers[0].(*rrfScorer).k) + }) + + t.Run("weights", func(t *testing.T) { + weights := []float64{0.5, 0.2} + params := make(map[string][]float64) + params[WeightsParamsKey] = weights + b, err := json.Marshal(params) + assert.NoError(t, err) + rankParams := []*commonpb.KeyValuePair{ + {Key: RankTypeKey, Value: "weighted"}, + {Key: RankParamsKey, Value: string(b)}, + } + + rescorers, err := NewReScorer([]*milvuspb.SearchRequest{{}, {}}, rankParams) + assert.NoError(t, err) + assert.Equal(t, 2, len(rescorers)) + assert.Equal(t, weightedRankType, rescorers[0].scorerType()) + assert.Equal(t, float32(weights[0]), rescorers[0].(*weightedScorer).weight) + }) +} diff --git a/internal/proxy/task.go b/internal/proxy/task.go index bb57d377e4098..4d37b89b14c45 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -88,6 +88,11 @@ const ( // minFloat32 minimum float. 
minFloat32 = -1 * float32(math.MaxFloat32) + + RankTypeKey = "strategy" + RankParamsKey = "params" + RRFParamsKey = "k" + WeightsParamsKey = "weights" ) type task interface { diff --git a/internal/proxy/task_hybrid_search.go b/internal/proxy/task_hybrid_search.go new file mode 100644 index 0000000000000..61da77861c8c6 --- /dev/null +++ b/internal/proxy/task_hybrid_search.go @@ -0,0 +1,461 @@ +package proxy + +import ( + "context" + "fmt" + "math" + "sort" + "strconv" + + "github.com/cockroachdb/errors" + "go.opentelemetry.io/otel" + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/util/conc" + "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/timerecord" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +const ( + HybridSearchTaskName = "HybridSearchTask" +) + +type hybridSearchTask struct { + Condition + ctx context.Context + + result *milvuspb.SearchResults + request *milvuspb.HybridSearchRequest + + tr *timerecord.TimeRecorder + schema *schemaInfo + requery bool + + userOutputFields []string + + qc types.QueryCoordClient + node types.ProxyComponent + lb LBPolicy + + collectionID UniqueID + + multipleRecallResults *typeutil.ConcurrentSet[*milvuspb.SearchResults] + reScorers []reScorer +} + +func (t *hybridSearchTask) PreExecute(ctx context.Context) error { + ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch-PreExecute") + defer sp.End() + + if len(t.request.Requests) <= 0 { + return errors.New("minimum of ann search requests is 1") + } + + if len(t.request.Requests) > defaultMaxSearchRequest { + return errors.New("maximum of ann search requests is 1024") + } + for _, req := range t.request.GetRequests() { + nq, err := getNq(req) + if err != nil { + log.Debug("failed to get nq", zap.Error(err)) + return err + } + if nq != 1 { + err = merr.WrapErrParameterInvalid("1", fmt.Sprint(nq), "nq should be equal to 1") + log.Debug(err.Error()) + return err + } + } + + collectionName := t.request.CollectionName + collID, err := globalMetaCache.GetCollectionID(ctx, t.request.GetDbName(), collectionName) + if err != nil { + return err + } + t.collectionID = collID + + log := log.Ctx(ctx).With(zap.Int64("collID", collID), zap.String("collName", collectionName)) + t.schema, err = globalMetaCache.GetCollectionSchema(ctx, t.request.GetDbName(), collectionName) + if err != nil { + log.Warn("get collection schema failed", zap.Error(err)) + return err + } + + partitionKeyMode, err := isPartitionKeyMode(ctx, t.request.GetDbName(), collectionName) + if err != nil { + log.Warn("is partition key mode failed", zap.Error(err)) + return err + } + if partitionKeyMode && len(t.request.GetPartitionNames()) != 0 { + return errors.New("not support manually specifying the partition names if partition key mode is used") + } + + t.request.OutputFields, t.userOutputFields, err = translateOutputFields(t.request.OutputFields, t.schema, false) + if err != nil { + log.Warn("translate output fields failed", zap.Error(err)) + return err + } + log.Debug("translate output fields", + zap.Strings("output fields", t.request.GetOutputFields())) + + if 
len(t.request.OutputFields) > 0 { + t.requery = true + } + + log.Debug("hybrid search preExecute done.", + zap.Uint64("guarantee_ts", t.request.GetGuaranteeTimestamp()), + zap.Bool("use_default_consistency", t.request.GetUseDefaultConsistency()), + zap.Any("consistency level", t.request.GetConsistencyLevel())) + + return nil +} + +func (t *hybridSearchTask) Execute(ctx context.Context) error { + ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch-Execute") + defer sp.End() + + log := log.Ctx(ctx).With(zap.Int64("collID", t.collectionID), zap.String("collName", t.request.GetCollectionName())) + tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute hybrid search %d", t.ID())) + defer tr.CtxElapse(ctx, "done") + + futures := make([]*conc.Future[*milvuspb.SearchResults], len(t.request.Requests)) + for index := range t.request.Requests { + searchReq := t.request.Requests[index] + future := conc.Go(func() (*milvuspb.SearchResults, error) { + searchReq.TravelTimestamp = t.request.GetTravelTimestamp() + searchReq.GuaranteeTimestamp = t.request.GetGuaranteeTimestamp() + searchReq.NotReturnAllMeta = t.request.GetNotReturnAllMeta() + searchReq.ConsistencyLevel = t.request.GetConsistencyLevel() + searchReq.UseDefaultConsistency = t.request.GetUseDefaultConsistency() + searchReq.OutputFields = nil + + return t.node.Search(ctx, searchReq) + }) + futures[index] = future + } + + err := conc.AwaitAll(futures...) + if err != nil { + return err + } + + t.reScorers, err = NewReScorer(t.request.GetRequests(), t.request.GetRankParams()) + if err != nil { + log.Info("generate reScorer failed", zap.Any("rank params", t.request.GetRankParams()), zap.Error(err)) + return err + } + t.multipleRecallResults = typeutil.NewConcurrentSet[*milvuspb.SearchResults]() + for i, future := range futures { + err = future.Err() + if err != nil { + log.Debug("QueryNode search result error", zap.Error(err)) + return err + } + result := futures[i].Value() + if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + log.Debug("QueryNode search result error", + zap.String("reason", result.GetStatus().GetReason())) + return merr.Error(result.GetStatus()) + } + + t.reScorers[i].reScore(result) + t.multipleRecallResults.Insert(result) + } + + log.Debug("hybrid search execute done.") + return nil +} + +type rankParams struct { + limit int64 + offset int64 + roundDecimal int64 +} + +// parseRankParams get limit and offset from rankParams, both are optional. +func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, error) { + var ( + limit int64 + offset int64 + roundDecimal int64 + err error + ) + + limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, rankParamsPair) + if err != nil { + return nil, errors.New(LimitKey + " not found in search_params") + } + limit, err = strconv.ParseInt(limitStr, 0, 64) + if err != nil { + return nil, fmt.Errorf("%s [%s] is invalid", LimitKey, limitStr) + } + + offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(OffsetKey, rankParamsPair) + if err == nil { + offset, err = strconv.ParseInt(offsetStr, 0, 64) + if err != nil { + return nil, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr) + } + } + + // validate max result window. 
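+	// For example, limit=10 with offset=5 means the fused ranking must keep a
+	// window of 15 candidates per query; offset and limit are validated
+	// together against the configured maximum result window.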
+ if err = validateMaxQueryResultWindow(offset, limit); err != nil { + return nil, fmt.Errorf("invalid max query result window, %w", err) + } + + roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, rankParamsPair) + if err != nil { + roundDecimalStr = "-1" + } + + roundDecimal, err = strconv.ParseInt(roundDecimalStr, 0, 64) + if err != nil { + return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr) + } + + if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) { + return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr) + } + + return &rankParams{ + limit: limit, + offset: offset, + roundDecimal: roundDecimal, + }, nil +} + +func (t *hybridSearchTask) PostExecute(ctx context.Context) error { + ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch-PostExecute") + defer sp.End() + + log := log.Ctx(ctx).With(zap.Int64("collID", t.collectionID), zap.String("collName", t.request.GetCollectionName())) + tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy postExecute hybrid search %d", t.ID())) + defer func() { + tr.CtxElapse(ctx, "done") + }() + + primaryFieldSchema, err := t.schema.GetPkField() + if err != nil { + log.Warn("failed to get primary field schema", zap.Error(err)) + return err + } + + rankParams, err := parseRankParams(t.request.GetRankParams()) + if err != nil { + return err + } + + t.result, err = rankSearchResultData(ctx, 1, + rankParams, + primaryFieldSchema.GetDataType(), + t.multipleRecallResults.Collect()) + if err != nil { + log.Warn("rank search result failed", zap.Error(err)) + return err + } + + t.result.CollectionName = t.request.GetCollectionName() + t.fillInFieldInfo() + + if t.requery { + err := t.Requery() + if err != nil { + log.Warn("failed to requery", zap.Error(err)) + return err + } + } + t.result.Results.OutputFields = t.userOutputFields + + log.Debug("hybrid search post execute done") + return nil +} + +func (t *hybridSearchTask) Requery() error { + queryReq := &milvuspb.QueryRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Retrieve, + }, + DbName: t.request.GetDbName(), + CollectionName: t.request.GetCollectionName(), + Expr: "", + OutputFields: t.request.GetOutputFields(), + PartitionNames: t.request.GetPartitionNames(), + GuaranteeTimestamp: t.request.GetGuaranteeTimestamp(), + TravelTimestamp: t.request.GetTravelTimestamp(), + NotReturnAllMeta: t.request.GetNotReturnAllMeta(), + ConsistencyLevel: t.request.GetConsistencyLevel(), + UseDefaultConsistency: t.request.GetUseDefaultConsistency(), + } + + return doRequery(t.ctx, t.collectionID, t.node, t.schema.CollectionSchema, queryReq, t.result) +} + +func rankSearchResultData(ctx context.Context, + nq int64, + params *rankParams, + pkType schemapb.DataType, + searchResults []*milvuspb.SearchResults, +) (*milvuspb.SearchResults, error) { + tr := timerecord.NewTimeRecorder("rankSearchResultData") + defer func() { + tr.CtxElapse(ctx, "done") + }() + + offset := params.offset + limit := params.limit + topk := limit + offset + roundDecimal := params.roundDecimal + log.Ctx(ctx).Debug("rankSearchResultData", + zap.Int("len(searchResults)", len(searchResults)), + zap.Int64("nq", nq), + zap.Int64("offset", offset), + zap.Int64("limit", limit)) + + ret := &milvuspb.SearchResults{ + Status: merr.Success(), + Results: &schemapb.SearchResultData{ + NumQueries: nq, + TopK: limit, + FieldsData: make([]*schemapb.FieldData, 0), 
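// Editorial aside on the merge below: topk = limit + offset, and any query whose
// fused candidate set holds at most `offset` distinct ids contributes a 0 entry
// to Topks. When roundDecimal != -1 the final scores are rounded half-up; for
// example, with roundDecimal = 2 the multiplier is 10^2 = 100, so a fused score
// of 0.12345 becomes float32(math.Floor(0.12345*100+0.5)/100) = 0.12.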
+ Scores: []float32{}, + Ids: &schemapb.IDs{}, + Topks: []int64{}, + }, + } + + switch pkType { + case schemapb.DataType_Int64: + ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: make([]int64, 0), + }, + } + case schemapb.DataType_VarChar: + ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{ + StrId: &schemapb.StringArray{ + Data: make([]string, 0), + }, + } + default: + return nil, errors.New("unsupported pk type") + } + + // []map[id]score + accumulatedScores := make([]map[interface{}]float32, nq) + for i := int64(0); i < nq; i++ { + accumulatedScores[i] = make(map[interface{}]float32) + } + + for _, result := range searchResults { + scores := result.GetResults().GetScores() + start := int64(0) + for i := int64(0); i < nq; i++ { + realTopk := result.GetResults().Topks[i] + for j := start; j < start+realTopk; j++ { + id := typeutil.GetPK(result.GetResults().GetIds(), j) + accumulatedScores[i][id] += scores[j] + } + start += realTopk + } + } + + for i := int64(0); i < nq; i++ { + idSet := accumulatedScores[i] + keys := make([]interface{}, 0) + for key := range idSet { + keys = append(keys, key) + } + + if int64(len(keys)) <= offset { + ret.Results.Topks = append(ret.Results.Topks, 0) + continue + } + + // sort id by score + sort.Slice(keys, func(i, j int) bool { + return idSet[keys[i]] >= idSet[keys[j]] + }) + + if int64(len(keys)) > topk { + keys = keys[:topk] + } + + // set real topk + ret.Results.Topks = append(ret.Results.Topks, int64(len(keys))-offset) + // append id and score + for index := offset; index < int64(len(keys)); index++ { + typeutil.AppendPKs(ret.Results.Ids, keys[index]) + score := idSet[keys[index]] + if roundDecimal != -1 { + multiplier := math.Pow(10.0, float64(roundDecimal)) + score = float32(math.Floor(float64(score)*multiplier+0.5) / multiplier) + } + ret.Results.Scores = append(ret.Results.Scores, score) + } + } + + return ret, nil +} + +func (t *hybridSearchTask) fillInFieldInfo() { + if len(t.request.OutputFields) != 0 && len(t.result.Results.FieldsData) != 0 { + for i, name := range t.request.OutputFields { + for _, field := range t.schema.Fields { + if t.result.Results.FieldsData[i] != nil && field.Name == name { + t.result.Results.FieldsData[i].FieldName = field.Name + t.result.Results.FieldsData[i].FieldId = field.FieldID + t.result.Results.FieldsData[i].Type = field.DataType + t.result.Results.FieldsData[i].IsDynamic = field.IsDynamic + } + } + } + } +} + +func (t *hybridSearchTask) TraceCtx() context.Context { + return t.ctx +} + +func (t *hybridSearchTask) ID() UniqueID { + return t.request.Base.MsgID +} + +func (t *hybridSearchTask) SetID(uid UniqueID) { + t.request.Base.MsgID = uid +} + +func (t *hybridSearchTask) Name() string { + return HybridSearchTaskName +} + +func (t *hybridSearchTask) Type() commonpb.MsgType { + return t.request.Base.MsgType +} + +func (t *hybridSearchTask) BeginTs() Timestamp { + return t.request.Base.Timestamp +} + +func (t *hybridSearchTask) EndTs() Timestamp { + return t.request.Base.Timestamp +} + +func (t *hybridSearchTask) SetTs(ts Timestamp) { + t.request.Base.Timestamp = ts +} + +func (t *hybridSearchTask) OnEnqueue() error { + t.request.Base = commonpbutil.NewMsgBase() + t.request.Base.MsgType = commonpb.MsgType_Search + t.request.Base.SourceID = paramtable.GetNodeID() + return nil +} diff --git a/internal/proxy/task_hybrid_search_test.go b/internal/proxy/task_hybrid_search_test.go new file mode 100644 index 0000000000000..0cee1f89db90a --- /dev/null +++ 
b/internal/proxy/task_hybrid_search_test.go @@ -0,0 +1,330 @@ +package proxy + +import ( + "context" + "strconv" + "testing" + "time" + + "github.com/cockroachdb/errors" + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/dependency" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/timerecord" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +func createCollWithMultiVecField(t *testing.T, name string, rc types.RootCoordClient) { + schema := genCollectionSchema(name) + marshaledSchema, err := proto.Marshal(schema) + require.NoError(t, err) + ctx := context.TODO() + + createColT := &createCollectionTask{ + Condition: NewTaskCondition(context.TODO()), + CreateCollectionRequest: &milvuspb.CreateCollectionRequest{ + CollectionName: name, + Schema: marshaledSchema, + ShardsNum: common.DefaultShardsNum, + }, + ctx: ctx, + rootCoord: rc, + } + + require.NoError(t, createColT.OnEnqueue()) + require.NoError(t, createColT.PreExecute(ctx)) + require.NoError(t, createColT.Execute(ctx)) + require.NoError(t, createColT.PostExecute(ctx)) +} + +func TestHybridSearchTask_PreExecute(t *testing.T) { + var err error + + var ( + rc = NewRootCoordMock() + qc = mocks.NewMockQueryCoordClient(t) + ctx = context.TODO() + ) + + defer rc.Close() + require.NoError(t, err) + mgr := newShardClientMgr() + err = InitMetaCache(ctx, rc, qc, mgr) + require.NoError(t, err) + + genHybridSearchTaskWithNq := func(t *testing.T, collName string, reqs []*milvuspb.SearchRequest) *hybridSearchTask { + task := &hybridSearchTask{ + ctx: ctx, + Condition: NewTaskCondition(ctx), + request: &milvuspb.HybridSearchRequest{ + CollectionName: collName, + Requests: reqs, + }, + qc: qc, + tr: timerecord.NewTimeRecorder("test-hybrid-search"), + } + require.NoError(t, task.OnEnqueue()) + return task + } + + t.Run("bad nq 0", func(t *testing.T) { + collName := "test_bad_nq0_error" + funcutil.GenRandomStr() + createCollWithMultiVecField(t, collName, rc) + // Nq must be 1. + task := genHybridSearchTaskWithNq(t, collName, []*milvuspb.SearchRequest{{Nq: 0}}) + err = task.PreExecute(ctx) + assert.Error(t, err) + }) + + t.Run("bad req num 0", func(t *testing.T) { + collName := "test_bad_req_num0_error" + funcutil.GenRandomStr() + createCollWithMultiVecField(t, collName, rc) + // num of reqs must be [1, 1024]. + task := genHybridSearchTaskWithNq(t, collName, nil) + err = task.PreExecute(ctx) + assert.Error(t, err) + }) + + t.Run("bad req num 1025", func(t *testing.T) { + collName := "test_bad_req_num1025_error" + funcutil.GenRandomStr() + createCollWithMultiVecField(t, collName, rc) + // num of reqs must be [1, 1024]. 
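// Editorial note: the loop below runs i = 0 .. defaultMaxSearchRequest inclusive,
// deliberately building defaultMaxSearchRequest+1 = 1025 sub-requests, one past
// the cap, so PreExecute must return an error.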
+ reqs := make([]*milvuspb.SearchRequest, 0) + for i := 0; i <= defaultMaxSearchRequest; i++ { + reqs = append(reqs, &milvuspb.SearchRequest{ + CollectionName: collName, + Nq: 1, + }) + } + task := genHybridSearchTaskWithNq(t, collName, reqs) + err = task.PreExecute(ctx) + assert.Error(t, err) + }) + + t.Run("collection not exist", func(t *testing.T) { + collName := "test_collection_not_exist" + funcutil.GenRandomStr() + task := genHybridSearchTaskWithNq(t, collName, []*milvuspb.SearchRequest{{Nq: 1}}) + err = task.PreExecute(ctx) + assert.Error(t, err) + }) + + t.Run("hybrid search with timeout", func(t *testing.T) { + collName := "hybrid_search_with_timeout" + funcutil.GenRandomStr() + createCollWithMultiVecField(t, collName, rc) + + task := genHybridSearchTaskWithNq(t, collName, []*milvuspb.SearchRequest{{Nq: 1}}) + + ctxTimeout, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + + task.ctx = ctxTimeout + task.request.OutputFields = []string{testFloatVecField} + assert.NoError(t, task.PreExecute(ctx)) + }) +} + +func TestHybridSearchTask_ErrExecute(t *testing.T) { + var ( + err error + ctx = context.TODO() + + rc = NewRootCoordMock() + qc = getQueryCoordClient() + qn = getQueryNodeClient() + + collectionName = t.Name() + funcutil.GenRandomStr() + ) + + qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() + + mgr := NewMockShardClientManager(t) + mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe() + mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe() + lb := NewLBPolicyImpl(mgr) + + factory := dependency.NewDefaultFactory(true) + node, err := NewProxy(ctx, factory) + assert.NoError(t, err) + node.UpdateStateCode(commonpb.StateCode_Healthy) + node.tsoAllocator = &timestampAllocator{ + tso: newMockTimestampAllocatorInterface(), + } + scheduler, err := newTaskScheduler(ctx, node.tsoAllocator, factory) + assert.NoError(t, err) + node.sched = scheduler + err = node.sched.Start() + assert.NoError(t, err) + err = node.initRateCollector() + assert.NoError(t, err) + node.rootCoord = rc + node.queryCoord = qc + + defer qc.Close() + + err = InitMetaCache(ctx, rc, qc, mgr) + assert.NoError(t, err) + + createCollWithMultiVecField(t, collectionName, rc) + + collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName) + assert.NoError(t, err) + + schema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName) + assert.NoError(t, err) + + successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} + qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(successStatus, nil) + qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{ + Status: successStatus, + Shards: []*querypb.ShardLeadersList{ + { + ChannelName: "channel-1", + NodeIds: []int64{1}, + NodeAddrs: []string{"localhost:9000"}, + }, + }, + }, nil) + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ + Status: successStatus, + CollectionIDs: []int64{collectionID}, + InMemoryPercentages: []int64{100}, + }, nil) + status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_LoadCollection, + SourceID: paramtable.GetNodeID(), + }, + CollectionID: collectionID, + }) + require.NoError(t, err) + require.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode) + + vectorFields := 
typeutil.GetVectorFieldSchemas(schema.CollectionSchema) + vectorFieldNames := make([]string, len(vectorFields)) + for i, field := range vectorFields { + vectorFieldNames[i] = field.GetName() + } + + // test begins + task := &hybridSearchTask{ + Condition: NewTaskCondition(ctx), + ctx: ctx, + result: &milvuspb.SearchResults{ + Status: merr.Success(), + }, + request: &milvuspb.HybridSearchRequest{ + CollectionName: collectionName, + Requests: []*milvuspb.SearchRequest{ + { + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Search, + SourceID: paramtable.GetNodeID(), + }, + CollectionName: collectionName, + Nq: 1, + DslType: commonpb.DslType_BoolExprV1, + SearchParams: []*commonpb.KeyValuePair{ + {Key: AnnsFieldKey, Value: testFloatVecField}, + {Key: TopKKey, Value: "10"}, + }, + }, + { + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Search, + SourceID: paramtable.GetNodeID(), + }, + CollectionName: collectionName, + Nq: 1, + DslType: commonpb.DslType_BoolExprV1, + SearchParams: []*commonpb.KeyValuePair{ + {Key: AnnsFieldKey, Value: testBinaryVecField}, + {Key: TopKKey, Value: "10"}, + }, + }, + }, + OutputFields: vectorFieldNames, + }, + qc: qc, + lb: lb, + node: node, + } + + assert.NoError(t, task.OnEnqueue()) + task.ctx = ctx + assert.NoError(t, task.PreExecute(ctx)) + + qn.EXPECT().Search(mock.Anything, mock.Anything).Return(nil, errors.New("mock error")) + assert.Error(t, task.Execute(ctx)) + + qn.ExpectedCalls = nil + qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() + qn.EXPECT().Search(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + }, + }, nil) + assert.Error(t, task.Execute(ctx)) +} + +func TestHybridSearchTask_PostExecute(t *testing.T) { + var ( + rc = NewRootCoordMock() + qc = getQueryCoordClient() + qn = getQueryNodeClient() + collectionName = t.Name() + funcutil.GenRandomStr() + ) + + defer rc.Close() + mgr := NewMockShardClientManager(t) + mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe() + mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe() + + t.Run("Test empty result", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + err := InitMetaCache(ctx, rc, qc, mgr) + assert.NoError(t, err) + createCollWithMultiVecField(t, collectionName, rc) + + schema, err := globalMetaCache.GetCollectionSchema(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName) + assert.NoError(t, err) + + rankParams := []*commonpb.KeyValuePair{ + {Key: LimitKey, Value: strconv.Itoa(3)}, + {Key: OffsetKey, Value: strconv.Itoa(2)}, + } + qt := &hybridSearchTask{ + ctx: ctx, + Condition: NewTaskCondition(context.TODO()), + qc: nil, + tr: timerecord.NewTimeRecorder("search"), + schema: schema, + request: &milvuspb.HybridSearchRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Search, + }, + CollectionName: collectionName, + RankParams: rankParams, + }, + multipleRecallResults: typeutil.NewConcurrentSet[*milvuspb.SearchResults](), + } + + err = qt.PostExecute(context.TODO()) + assert.NoError(t, err) + assert.Equal(t, qt.result.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) + }) +} diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index bf34500516996..b1ffd6a40b1cb 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -606,13 +606,6 @@ func (t *searchTask) estimateResultSize(nq int64, topK 
int64) (int64, error) { } func (t *searchTask) Requery() error { - pkField, err := t.schema.GetPkField() - if err != nil { - return err - } - ids := t.result.GetResults().GetIds() - plan := planparserv2.CreateRequeryPlan(pkField, ids) - queryReq := &milvuspb.QueryRequest{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_Retrieve, @@ -625,9 +618,72 @@ func (t *searchTask) Requery() error { GuaranteeTimestamp: t.request.GetGuaranteeTimestamp(), QueryParams: t.request.GetSearchParams(), } + + return doRequery(t.ctx, t.GetCollectionID(), t.node, t.schema.CollectionSchema, queryReq, t.result) +} + +func (t *searchTask) fillInEmptyResult(numQueries int64) { + t.result = &milvuspb.SearchResults{ + Status: merr.Success("search result is empty"), + CollectionName: t.collectionName, + Results: &schemapb.SearchResultData{ + NumQueries: numQueries, + Topks: make([]int64, numQueries), + }, + } +} + +func (t *searchTask) fillInFieldInfo() { + if len(t.request.OutputFields) != 0 && len(t.result.Results.FieldsData) != 0 { + for i, name := range t.request.OutputFields { + for _, field := range t.schema.Fields { + if t.result.Results.FieldsData[i] != nil && field.Name == name { + t.result.Results.FieldsData[i].FieldName = field.Name + t.result.Results.FieldsData[i].FieldId = field.FieldID + t.result.Results.FieldsData[i].Type = field.DataType + t.result.Results.FieldsData[i].IsDynamic = field.IsDynamic + } + } + } + } +} + +func (t *searchTask) collectSearchResults(ctx context.Context) ([]*internalpb.SearchResults, error) { + select { + case <-t.TraceCtx().Done(): + log.Ctx(ctx).Warn("search task wait to finish timeout!") + return nil, fmt.Errorf("search task wait to finish timeout, msgID=%d", t.ID()) + default: + toReduceResults := make([]*internalpb.SearchResults, 0) + log.Ctx(ctx).Debug("all searches are finished or canceled") + t.resultBuf.Range(func(res *internalpb.SearchResults) bool { + toReduceResults = append(toReduceResults, res) + log.Ctx(ctx).Debug("proxy receives one search result", + zap.Int64("sourceID", res.GetBase().GetSourceID())) + return true + }) + return toReduceResults, nil + } +} + +func doRequery(ctx context.Context, + collectionID int64, + node types.ProxyComponent, + schema *schemapb.CollectionSchema, + request *milvuspb.QueryRequest, + result *milvuspb.SearchResults, +) error { + outputFields := request.GetOutputFields() + pkField, err := typeutil.GetPrimaryFieldSchema(schema) + if err != nil { + return err + } + ids := result.GetResults().GetIds() + plan := planparserv2.CreateRequeryPlan(pkField, ids) + qt := &queryTask{ - ctx: t.ctx, - Condition: NewTaskCondition(t.ctx), + ctx: ctx, + Condition: NewTaskCondition(ctx), RetrieveRequest: &internalpb.RetrieveRequest{ Base: commonpbutil.NewMsgBase( commonpbutil.WithMsgType(commonpb.MsgType_Retrieve), @@ -635,12 +691,12 @@ func (t *searchTask) Requery() error { ), ReqID: paramtable.GetNodeID(), }, - request: queryReq, + request: request, plan: plan, - qc: t.node.(*Proxy).queryCoord, - lb: t.node.(*Proxy).lbPolicy, + qc: node.(*Proxy).queryCoord, + lb: node.(*Proxy).lbPolicy, } - queryResult, err := t.node.(*Proxy).query(t.ctx, qt) + queryResult, err := node.(*Proxy).query(ctx, qt) if err != nil { return err } @@ -672,68 +728,24 @@ func (t *searchTask) Requery() error { offsets[pk] = i } - t.result.Results.FieldsData = make([]*schemapb.FieldData, len(queryResult.GetFieldsData())) + result.Results.FieldsData = make([]*schemapb.FieldData, len(queryResult.GetFieldsData())) for i := 0; i < typeutil.GetSizeOfIDs(ids); i++ { id := 
typeutil.GetPK(ids, int64(i)) if _, ok := offsets[id]; !ok { return fmt.Errorf("incomplete query result, missing id %s, len(searchIDs) = %d, len(queryIDs) = %d, collection=%d", - id, typeutil.GetSizeOfIDs(ids), len(offsets), t.GetCollectionID()) + id, typeutil.GetSizeOfIDs(ids), len(offsets), collectionID) } - typeutil.AppendFieldData(t.result.Results.FieldsData, queryResult.GetFieldsData(), int64(offsets[id])) + typeutil.AppendFieldData(result.Results.FieldsData, queryResult.GetFieldsData(), int64(offsets[id])) } // filter id field out if it is not specified as output - t.result.Results.FieldsData = lo.Filter(t.result.Results.FieldsData, func(fieldData *schemapb.FieldData, i int) bool { - return lo.Contains(t.request.GetOutputFields(), fieldData.GetFieldName()) + result.Results.FieldsData = lo.Filter(result.Results.FieldsData, func(fieldData *schemapb.FieldData, i int) bool { + return lo.Contains(outputFields, fieldData.GetFieldName()) }) return nil } -func (t *searchTask) fillInEmptyResult(numQueries int64) { - t.result = &milvuspb.SearchResults{ - Status: merr.Success("search result is empty"), - CollectionName: t.collectionName, - Results: &schemapb.SearchResultData{ - NumQueries: numQueries, - Topks: make([]int64, numQueries), - }, - } -} - -func (t *searchTask) fillInFieldInfo() { - if len(t.request.OutputFields) != 0 && len(t.result.Results.FieldsData) != 0 { - for i, name := range t.request.OutputFields { - for _, field := range t.schema.Fields { - if t.result.Results.FieldsData[i] != nil && field.Name == name { - t.result.Results.FieldsData[i].FieldName = field.Name - t.result.Results.FieldsData[i].FieldId = field.FieldID - t.result.Results.FieldsData[i].Type = field.DataType - t.result.Results.FieldsData[i].IsDynamic = field.IsDynamic - } - } - } - } -} - -func (t *searchTask) collectSearchResults(ctx context.Context) ([]*internalpb.SearchResults, error) { - select { - case <-t.TraceCtx().Done(): - log.Ctx(ctx).Warn("search task wait to finish timeout!") - return nil, fmt.Errorf("search task wait to finish timeout, msgID=%d", t.ID()) - default: - toReduceResults := make([]*internalpb.SearchResults, 0) - log.Ctx(ctx).Debug("all searches are finished or canceled") - t.resultBuf.Range(func(res *internalpb.SearchResults) bool { - toReduceResults = append(toReduceResults, res) - log.Ctx(ctx).Debug("proxy receives one search result", - zap.Int64("sourceID", res.GetBase().GetSourceID())) - return true - }) - return toReduceResults, nil - } -} - func decodeSearchResults(ctx context.Context, searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) { tr := timerecord.NewTimeRecorder("decodeSearchResults") results := make([]*schemapb.SearchResultData, 0) diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index eb3acf48d7889..cfd9792178887 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -71,6 +71,20 @@ const ( testMaxVarCharLength = 100 ) +func genCollectionSchema(collectionName string) *schemapb.CollectionSchema { + return constructCollectionSchemaWithAllType( + testBoolField, + testInt32Field, + testInt64Field, + testFloatField, + testDoubleField, + testFloatVecField, + testBinaryVecField, + testFloat16VecField, + testVecDim, + collectionName) +} + func constructCollectionSchema( int64Field, floatVecField string, dim int, diff --git a/internal/proxy/util.go b/internal/proxy/util.go index df726c2d21bc1..c94d604450cf5 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -62,6 +62,8 @@ const ( 
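// Editorial note: the Requery logic above moved out of searchTask into the
// package-level doRequery so that searchTask and hybridSearchTask share one code
// path; building the requery plan (planparserv2.CreateRequeryPlan from the
// primary-key field and the ranked result ids) moved with it. The hunk below
// introduces two proxy-level constants: defaultMaxSearchRequest = 1024, the cap
// on ANN sub-requests per hybrid search, and defaultRRFParamsValue = 60,
// presumably the default for the RRF "k" parameter keyed by RRFParamsKey.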
defaultMaxArrayCapacity = 4096 + defaultMaxSearchRequest = 1024 + // DefaultArithmeticIndexType name of default index type for scalar field DefaultArithmeticIndexType = "STL_SORT" @@ -69,6 +71,8 @@ const ( DefaultStringIndexType = "Trie" InvertedIndexType = "INVERTED" + + defaultRRFParamsValue = 60 ) var logger = log.L().WithOptions(zap.Fields(zap.String("role", typeutil.ProxyRole))) diff --git a/internal/querycoordv2/task/executor.go b/internal/querycoordv2/task/executor.go index 76e5bfed5b157..cb4fed75ce5f5 100644 --- a/internal/querycoordv2/task/executor.go +++ b/internal/querycoordv2/task/executor.go @@ -175,7 +175,6 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error { loadMeta := packLoadMeta( ex.meta.GetLoadType(task.CollectionID()), - "", task.CollectionID(), partitions..., ) @@ -370,14 +369,8 @@ func (ex *Executor) subscribeChannel(task *ChannelTask, step int) error { log.Warn("fail to get index meta of collection") return err } - metricType, err := getMetricType(indexInfo, collectionInfo.GetSchema()) - if err != nil { - log.Warn("failed to get metric type", zap.Error(err)) - return err - } loadMeta := packLoadMeta( ex.meta.GetLoadType(task.CollectionID()), - metricType, task.CollectionID(), partitions..., ) diff --git a/internal/querycoordv2/task/utils.go b/internal/querycoordv2/task/utils.go index 84a8f4f3bcc1d..c0ecd97f068d7 100644 --- a/internal/querycoordv2/task/utils.go +++ b/internal/querycoordv2/task/utils.go @@ -21,8 +21,6 @@ import ( "fmt" "time" - "github.com/samber/lo" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" @@ -33,7 +31,6 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/utils" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/commonpbutil" - "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -162,12 +159,11 @@ func packReleaseSegmentRequest(task *SegmentTask, action *SegmentAction) *queryp } } -func packLoadMeta(loadType querypb.LoadType, metricType string, collectionID int64, partitions ...int64) *querypb.LoadMetaInfo { +func packLoadMeta(loadType querypb.LoadType, collectionID int64, partitions ...int64) *querypb.LoadMetaInfo { return &querypb.LoadMetaInfo{ LoadType: loadType, CollectionID: collectionID, PartitionIDs: partitions, - MetricType: metricType, } } @@ -241,22 +237,3 @@ func getShardLeader(replicaMgr *meta.ReplicaManager, distMgr *meta.DistributionM } return distMgr.GetShardLeader(replica, channel) } - -func getMetricType(indexInfos []*indexpb.IndexInfo, schema *schemapb.CollectionSchema) (string, error) { - vecField, err := typeutil.GetVectorFieldSchema(schema) - if err != nil { - return "", err - } - indexInfo, ok := lo.Find(indexInfos, func(info *indexpb.IndexInfo) bool { - return info.GetFieldID() == vecField.GetFieldID() - }) - if !ok || indexInfo == nil { - err = fmt.Errorf("cannot find index info for %s field", vecField.GetName()) - return "", err - } - metricType, err := funcutil.GetAttrByKeyFromRepeatedKV(common.MetricTypeKey, indexInfo.GetIndexParams()) - if err != nil { - return "", err - } - return metricType, nil -} diff --git a/internal/querycoordv2/task/utils_test.go b/internal/querycoordv2/task/utils_test.go index bd685344a0546..2504bf9ca4d06 100644 --- a/internal/querycoordv2/task/utils_test.go +++ b/internal/querycoordv2/task/utils_test.go @@ -26,7 +26,6 @@ import ( 
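// Editorial note: with the metric type now carried by index metadata rather than
// LoadMeta, packLoadMeta drops its metricType parameter. Call sites change from
//
//	packLoadMeta(ex.meta.GetLoadType(task.CollectionID()), metricType, task.CollectionID(), partitions...)
//
// to
//
//	packLoadMeta(ex.meta.GetLoadType(task.CollectionID()), task.CollectionID(), partitions...)
//
// and the now-unused getMetricType helper is deleted, along with its test in the
// utils_test.go hunk below.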
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/pkg/common" ) @@ -35,57 +34,6 @@ type UtilsSuite struct { suite.Suite } -func (s *UtilsSuite) TestGetMetricType() { - collection := int64(1) - schema := &schemapb.CollectionSchema{ - Name: "TestGetMetricType", - Fields: []*schemapb.FieldSchema{ - {FieldID: 100, Name: "vec", DataType: schemapb.DataType_FloatVector}, - }, - } - indexInfo := &indexpb.IndexInfo{ - CollectionID: collection, - FieldID: 100, - IndexParams: []*commonpb.KeyValuePair{ - { - Key: common.MetricTypeKey, - Value: "L2", - }, - }, - } - - indexInfo2 := &indexpb.IndexInfo{ - CollectionID: collection, - FieldID: 100, - } - - s.Run("test normal", func() { - metricType, err := getMetricType([]*indexpb.IndexInfo{indexInfo}, schema) - s.NoError(err) - s.Equal("L2", metricType) - }) - - s.Run("test get vec field failed", func() { - _, err := getMetricType([]*indexpb.IndexInfo{indexInfo}, &schemapb.CollectionSchema{ - Name: "TestGetMetricType", - }) - s.Error(err) - }) - s.Run("test field id mismatch", func() { - _, err := getMetricType([]*indexpb.IndexInfo{indexInfo}, &schemapb.CollectionSchema{ - Name: "TestGetMetricType", - Fields: []*schemapb.FieldSchema{ - {FieldID: -1, Name: "vec", DataType: schemapb.DataType_FloatVector}, - }, - }) - s.Error(err) - }) - s.Run("test no metric type", func() { - _, err := getMetricType([]*indexpb.IndexInfo{indexInfo2}, schema) - s.Error(err) - }) -} - func (s *UtilsSuite) TestPackLoadSegmentRequest() { ctx := context.Background() diff --git a/internal/querynodev2/mock_data.go b/internal/querynodev2/mock_data.go index ef884a3234270..fafc6bdf543fa 100644 --- a/internal/querynodev2/mock_data.go +++ b/internal/querynodev2/mock_data.go @@ -17,12 +17,10 @@ package querynodev2 import ( - "fmt" "math" "math/rand" "strconv" - "github.com/cockroachdb/errors" "github.com/golang/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" @@ -60,45 +58,32 @@ const ( // ---------- unittest util functions ---------- // functions of messages and requests -func genBruteForceDSL(schema *schemapb.CollectionSchema, topK int64, roundDecimal int64) (string, error) { - var vecFieldName string - var metricType string - topKStr := strconv.FormatInt(topK, 10) - nProbStr := strconv.Itoa(defaultNProb) - roundDecimalStr := strconv.FormatInt(roundDecimal, 10) - var fieldID int64 - for _, f := range schema.Fields { - if f.DataType == schemapb.DataType_FloatVector { - vecFieldName = f.Name - fieldID = f.FieldID - for _, p := range f.IndexParams { - if p.Key == metricTypeKey { - metricType = p.Value - } - } - } - } - if vecFieldName == "" || metricType == "" { - err := errors.New("invalid vector field name or metric type") - return "", err +func genSearchPlan(dataType schemapb.DataType, fieldID int64, metricType string) *planpb.PlanNode { + var vectorType planpb.VectorType + switch dataType { + case schemapb.DataType_FloatVector: + vectorType = planpb.VectorType_FloatVector + case schemapb.DataType_Float16Vector: + vectorType = planpb.VectorType_Float16Vector + case schemapb.DataType_BinaryVector: + vectorType = planpb.VectorType_BinaryVector } - return `vector_anns: < - field_id: ` + fmt.Sprintf("%d", fieldID) + ` - query_info: < - topk: ` + topKStr + ` - round_decimal: ` + roundDecimalStr + ` - metric_type: "` + 
metricType + `" - search_params: "{\"nprobe\": ` + nProbStr + `}" - > - placeholder_tag: "$0" - >`, nil -} -func genDSLByIndexType(schema *schemapb.CollectionSchema, indexType string) (string, error) { - if indexType == IndexFaissIDMap { // float vector - return genBruteForceDSL(schema, defaultTopK, defaultRoundDecimal) + return &planpb.PlanNode{ + Node: &planpb.PlanNode_VectorAnns{ + VectorAnns: &planpb.VectorANNS{ + VectorType: vectorType, + FieldId: fieldID, + QueryInfo: &planpb.QueryInfo{ + Topk: defaultTopK, + MetricType: metricType, + SearchParams: "{\"nprobe\":" + strconv.Itoa(defaultNProb) + "}", + RoundDecimal: defaultRoundDecimal, + }, + PlaceholderTag: "$0", + }, + }, } - return "", fmt.Errorf("Invalid indexType") } func genPlaceHolderGroup(nq int64) ([]byte, error) { diff --git a/internal/querynodev2/segments/collection.go b/internal/querynodev2/segments/collection.go index cd7ad2ce25ed5..89f12b463b083 100644 --- a/internal/querynodev2/segments/collection.go +++ b/internal/querynodev2/segments/collection.go @@ -83,7 +83,6 @@ func (m *collectionManager) PutOrRef(collectionID int64, schema *schemapb.Collec } collection := NewCollection(collectionID, schema, meta, loadMeta.GetLoadType()) - collection.metricType.Store(loadMeta.GetMetricType()) collection.AddPartition(loadMeta.GetPartitionIDs()...) collection.Ref(1) m.collections[collectionID] = collection @@ -125,7 +124,7 @@ type Collection struct { id int64 partitions *typeutil.ConcurrentSet[int64] loadType querypb.LoadType - metricType atomic.String + metricType atomic.String // deprecated schema atomic.Pointer[schemapb.CollectionSchema] isGpuIndex bool @@ -175,14 +174,6 @@ func (c *Collection) GetLoadType() querypb.LoadType { return c.loadType } -func (c *Collection) SetMetricType(metricType string) { - c.metricType.Store(metricType) -} - -func (c *Collection) GetMetricType() string { - return c.metricType.Load() -} - func (c *Collection) Ref(count uint32) uint32 { refCount := c.refCount.Add(count) log.Debug("collection ref increment", diff --git a/internal/querynodev2/segments/mock_data.go b/internal/querynodev2/segments/mock_data.go index 19a199ef75741..212af05914c5e 100644 --- a/internal/querynodev2/segments/mock_data.go +++ b/internal/querynodev2/segments/mock_data.go @@ -291,7 +291,56 @@ func GenTestCollectionSchema(collectionName string, pkType schemapb.DataType) *s return &schema } +func GenTestIndexInfoList(collectionID int64, schema *schemapb.CollectionSchema) []*indexpb.IndexInfo { + res := make([]*indexpb.IndexInfo, 0) + vectorFieldSchemas := typeutil.GetVectorFieldSchemas(schema) + for _, field := range vectorFieldSchemas { + index := &indexpb.IndexInfo{ + CollectionID: collectionID, + FieldID: field.GetFieldID(), + // For now, a field can only have one index + // using fieldID and fieldName as indexID and indexName, just make sure not repeated. 
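// Editorial sketch: genSearchPlan above replaces the old text-DSL helpers
// (genBruteForceDSL/genDSLByIndexType) by building a planpb.PlanNode directly.
// A typical use, mirroring genCSearchRequest later in this patch, marshals the
// plan into the serialized plan bytes of an internal search request; the field
// id 107 and "L2" metric below are placeholder values taken from the tests:
//
//	plan := genSearchPlan(schemapb.DataType_FloatVector, 107, "L2")
//	serializedPlan, err := proto.Marshal(plan)
//	if err != nil {
//		return nil, err
//	}
//	// serializedPlan is then set as the search request's serialized expression plan.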
+ IndexID: field.GetFieldID(), + IndexName: field.GetName(), + TypeParams: field.GetTypeParams(), + } + switch field.GetDataType() { + case schemapb.DataType_FloatVector, schemapb.DataType_Float16Vector: + { + index.IndexParams = []*commonpb.KeyValuePair{ + {Key: common.MetricTypeKey, Value: metric.L2}, + {Key: common.IndexTypeKey, Value: IndexFaissIVFFlat}, + {Key: "nlist", Value: "128"}, + } + } + case schemapb.DataType_BinaryVector: + { + index.IndexParams = []*commonpb.KeyValuePair{ + {Key: common.MetricTypeKey, Value: metric.JACCARD}, + {Key: common.IndexTypeKey, Value: IndexFaissBinIVFFlat}, + {Key: "nlist", Value: "128"}, + } + } + } + res = append(res, index) + } + return res +} + func GenTestIndexMeta(collectionID int64, schema *schemapb.CollectionSchema) *segcorepb.CollectionIndexMeta { + indexInfos := GenTestIndexInfoList(collectionID, schema) + fieldIndexMetas := make([]*segcorepb.FieldIndexMeta, 0) + for _, info := range indexInfos { + fieldIndexMetas = append(fieldIndexMetas, &segcorepb.FieldIndexMeta{ + CollectionID: info.GetCollectionID(), + FieldID: info.GetFieldID(), + IndexName: info.GetIndexName(), + TypeParams: info.GetTypeParams(), + IndexParams: info.GetIndexParams(), + IsAutoIndex: info.GetIsAutoIndex(), + UserIndexParams: info.GetUserIndexParams(), + }) + } sizePerRecord, err := typeutil.EstimateSizePerRecord(schema) maxIndexRecordPerSegment := int64(0) if err != nil || sizePerRecord == 0 { @@ -302,37 +351,6 @@ func GenTestIndexMeta(collectionID int64, schema *schemapb.CollectionSchema) *se maxIndexRecordPerSegment = int64(threshold * proportion / float64(sizePerRecord)) } - fieldIndexMetas := make([]*segcorepb.FieldIndexMeta, 0) - fieldIndexMetas = append(fieldIndexMetas, &segcorepb.FieldIndexMeta{ - CollectionID: collectionID, - FieldID: simpleFloatVecField.id, - IndexName: "querynode-test", - TypeParams: []*commonpb.KeyValuePair{ - { - Key: dimKey, - Value: strconv.Itoa(simpleFloatVecField.dim), - }, - }, - IndexParams: []*commonpb.KeyValuePair{ - { - Key: metricTypeKey, - Value: simpleFloatVecField.metricType, - }, - { - Key: common.IndexTypeKey, - Value: IndexFaissIVFFlat, - }, - { - Key: "nlist", - Value: "128", - }, - }, - IsAutoIndex: false, - UserIndexParams: []*commonpb.KeyValuePair{ - {}, - }, - }) - indexMeta := segcorepb.CollectionIndexMeta{ MaxIndexRowCount: maxIndexRecordPerSegment, IndexMetas: fieldIndexMetas, @@ -889,6 +907,80 @@ func SaveDeltaLog(collectionID int64, return fieldBinlog, cm.MultiWrite(context.Background(), kvs) } +func GenAndSaveIndexV2(collectionID, partitionID, segmentID, buildID int64, + fieldSchema *schemapb.FieldSchema, + indexInfo *indexpb.IndexInfo, + cm storage.ChunkManager, + msgLength int, +) (*querypb.FieldIndexInfo, error) { + typeParams := funcutil.KeyValuePair2Map(indexInfo.GetTypeParams()) + indexParams := funcutil.KeyValuePair2Map(indexInfo.GetIndexParams()) + + index, err := indexcgowrapper.NewCgoIndex(fieldSchema.GetDataType(), typeParams, indexParams) + if err != nil { + return nil, err + } + defer index.Delete() + + var dataset *indexcgowrapper.Dataset + switch fieldSchema.DataType { + case schemapb.DataType_BinaryVector: + dataset = indexcgowrapper.GenBinaryVecDataset(generateBinaryVectors(msgLength, defaultDim)) + case schemapb.DataType_FloatVector: + dataset = indexcgowrapper.GenFloatVecDataset(generateFloatVectors(msgLength, defaultDim)) + } + + err = index.Build(dataset) + if err != nil { + return nil, err + } + + // save index to minio + binarySet, err := index.Serialize() + if err != nil { + return nil, 
err + } + + // serialize index params + indexCodec := storage.NewIndexFileBinlogCodec() + serializedIndexBlobs, err := indexCodec.Serialize( + buildID, + 0, + collectionID, + partitionID, + segmentID, + fieldSchema.GetFieldID(), + indexParams, + indexInfo.GetIndexName(), + indexInfo.GetIndexID(), + binarySet, + ) + if err != nil { + return nil, err + } + + indexPaths := make([]string, 0) + for _, index := range serializedIndexBlobs { + indexPath := filepath.Join(cm.RootPath(), "index_files", + strconv.Itoa(int(segmentID)), index.Key) + indexPaths = append(indexPaths, indexPath) + err := cm.Write(context.Background(), indexPath, index.Value) + if err != nil { + return nil, err + } + } + _, cCurrentIndexVersion := getIndexEngineVersion() + + return &querypb.FieldIndexInfo{ + FieldID: fieldSchema.GetFieldID(), + EnableIndex: true, + IndexName: indexInfo.GetIndexName(), + IndexParams: indexInfo.GetIndexParams(), + IndexFilePaths: indexPaths, + CurrentIndexVersion: cCurrentIndexVersion, + }, nil +} + func GenAndSaveIndex(collectionID, partitionID, segmentID, fieldID int64, msgLength int, indexType, metricType string, cm storage.ChunkManager) (*querypb.FieldIndexInfo, error) { typeParams, indexParams := genIndexParams(indexType, metricType) diff --git a/internal/querynodev2/segments/plan.go b/internal/querynodev2/segments/plan.go index edc01d27e083c..3b85862d82d52 100644 --- a/internal/querynodev2/segments/plan.go +++ b/internal/querynodev2/segments/plan.go @@ -54,13 +54,7 @@ func createSearchPlanByExpr(ctx context.Context, col *Collection, expr []byte, m return nil, err1 } - newPlan := &SearchPlan{cSearchPlan: cPlan} - if len(metricType) != 0 { - newPlan.setMetricType(metricType) - } else { - newPlan.setMetricType(col.GetMetricType()) - } - return newPlan, nil + return &SearchPlan{cSearchPlan: cPlan}, nil } func (plan *SearchPlan) getTopK() int64 { diff --git a/internal/querynodev2/services.go b/internal/querynodev2/services.go index 3202ff66505a3..a1e7091da867c 100644 --- a/internal/querynodev2/services.go +++ b/internal/querynodev2/services.go @@ -205,7 +205,6 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm log.Info("received watch channel request", zap.Int64("version", req.GetVersion()), - zap.String("metricType", req.GetLoadMeta().GetMetricType()), ) // check node healthy @@ -219,12 +218,6 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm return merr.Status(err), nil } - // check metric type - if req.GetLoadMeta().GetMetricType() == "" { - err := fmt.Errorf("empty metric type, collection = %d", req.GetCollectionID()) - return merr.Status(err), nil - } - // check index if len(req.GetIndexInfoList()) == 0 { err := merr.WrapErrIndexNotFoundForCollection(req.GetSchema().GetName()) @@ -253,8 +246,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm node.manager.Collection.PutOrRef(req.GetCollectionID(), req.GetSchema(), node.composeIndexMeta(req.GetIndexInfoList(), req.Schema), req.GetLoadMeta()) - collection := node.manager.Collection.Get(req.GetCollectionID()) - collection.SetMetricType(req.GetLoadMeta().GetMetricType()) + delegator, err := delegator.NewShardDelegator( ctx, req.GetCollectionID(), @@ -769,20 +761,6 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) ( return resp, nil } - // Check if the metric type specified in search params matches the metric type in the index info. 
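// Editorial note: QueryNode no longer validates or defaults the request metric
// type against a per-collection value remembered from LoadMeta (the checks
// removed below). The metric type now travels with index metadata, which
// WatchDmChannels composes from the request's IndexInfoList, as in the hunk
// above:
//
//	node.manager.Collection.PutOrRef(req.GetCollectionID(), req.GetSchema(),
//		node.composeIndexMeta(req.GetIndexInfoList(), req.Schema), req.GetLoadMeta())
//
// A mismatched metric type now surfaces when the search plan is created; the
// updated service test expects the status reason to contain "metric type not
// match".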
- if !req.GetFromShardLeader() && req.GetReq().GetMetricType() != "" { - if req.GetReq().GetMetricType() != collection.GetMetricType() { - resp.Status = merr.Status(merr.WrapErrParameterInvalid(collection.GetMetricType(), req.GetReq().GetMetricType(), - fmt.Sprintf("collection:%d, metric type not match", collection.ID()))) - return resp, nil - } - } - - // Define the metric type when it has not been explicitly assigned by the user. - if !req.GetFromShardLeader() && req.GetReq().GetMetricType() == "" { - req.Req.MetricType = collection.GetMetricType() - } - toReduceResults := make([]*internalpb.SearchResults, len(req.GetDmlChannels())) runningGp, runningCtx := errgroup.WithContext(ctx) for i, ch := range req.GetDmlChannels() { diff --git a/internal/querynodev2/services_test.go b/internal/querynodev2/services_test.go index 2084f8374ff70..e74f723ae3819 100644 --- a/internal/querynodev2/services_test.go +++ b/internal/querynodev2/services_test.go @@ -39,7 +39,6 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querynodev2/delegator" "github.com/milvus-io/milvus/internal/querynodev2/segments" @@ -52,7 +51,6 @@ import ( "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/merr" - "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -257,6 +255,7 @@ func (suite *ServiceSuite) TestWatchDmChannelsInt64() { ctx := context.Background() // data + schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64) deltaLogs, err := segments.SaveDeltaLog(suite.collectionID, suite.partitionIDs[0], suite.flushedSegmentIDs[0], @@ -292,16 +291,14 @@ func (suite *ServiceSuite) TestWatchDmChannelsInt64() { Level: datapb.SegmentLevel_L0, }, }, - Schema: segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64), + Schema: schema, LoadMeta: &querypb.LoadMetaInfo{ LoadType: querypb.LoadType_LoadCollection, CollectionID: suite.collectionID, PartitionIDs: suite.partitionIDs, MetricType: defaultMetricType, }, - IndexInfoList: []*indexpb.IndexInfo{ - {}, - }, + IndexInfoList: segments.GenTestIndexInfoList(suite.collectionID, schema), } // mocks @@ -326,6 +323,7 @@ func (suite *ServiceSuite) TestWatchDmChannelsVarchar() { ctx := context.Background() // data + schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_VarChar) req := &querypb.WatchDmChannelsRequest{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_WatchDmChannels, @@ -344,16 +342,14 @@ func (suite *ServiceSuite) TestWatchDmChannelsVarchar() { DroppedSegmentIds: suite.droppedSegmentIDs, }, }, - Schema: segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_VarChar), + Schema: schema, LoadMeta: &querypb.LoadMetaInfo{ LoadType: querypb.LoadType_LoadCollection, CollectionID: suite.collectionID, PartitionIDs: suite.partitionIDs, MetricType: defaultMetricType, }, - IndexInfoList: []*indexpb.IndexInfo{ - {}, - }, + IndexInfoList: segments.GenTestIndexInfoList(suite.collectionID, schema), } // mocks @@ -378,6 +374,7 @@ func (suite *ServiceSuite) TestWatchDmChannels_Failed() { ctx := 
context.Background() // data + schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64) req := &querypb.WatchDmChannelsRequest{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_WatchDmChannels, @@ -396,13 +393,11 @@ func (suite *ServiceSuite) TestWatchDmChannels_Failed() { DroppedSegmentIds: suite.droppedSegmentIDs, }, }, - Schema: segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64), + Schema: schema, LoadMeta: &querypb.LoadMetaInfo{ MetricType: defaultMetricType, }, - IndexInfoList: []*indexpb.IndexInfo{ - {}, - }, + IndexInfoList: segments.GenTestIndexInfoList(suite.collectionID, schema), } // test channel is unsubscribing @@ -439,14 +434,6 @@ func (suite *ServiceSuite) TestWatchDmChannels_Failed() { status, err = suite.node.WatchDmChannels(ctx, req) suite.NoError(err) suite.Equal(commonpb.ErrorCode_NotReadyServe, status.GetErrorCode()) - - // empty metric type - req.LoadMeta.MetricType = "" - req.Base.TargetID = paramtable.GetNodeID() - suite.node.UpdateStateCode(commonpb.StateCode_Healthy) - status, err = suite.node.WatchDmChannels(ctx, req) - suite.NoError(err) - suite.Equal(commonpb.ErrorCode_UnexpectedError, status.ErrorCode) } func (suite *ServiceSuite) TestUnsubDmChannels_Normal() { @@ -502,22 +489,9 @@ func (suite *ServiceSuite) TestUnsubDmChannels_Failed() { suite.Equal(commonpb.ErrorCode_NotReadyServe, status.GetErrorCode()) } -func (suite *ServiceSuite) genSegmentIndexInfos(loadInfo []*querypb.SegmentLoadInfo) []*indexpb.IndexInfo { - indexInfoList := make([]*indexpb.IndexInfo, 0) - seg0LoadInfo := loadInfo[0] - fieldIndexInfos := seg0LoadInfo.IndexInfos - for _, info := range fieldIndexInfos { - indexInfoList = append(indexInfoList, &indexpb.IndexInfo{ - CollectionID: suite.collectionID, - FieldID: info.GetFieldID(), - IndexName: info.GetIndexName(), - IndexParams: info.GetIndexParams(), - }) - } - return indexInfoList -} - -func (suite *ServiceSuite) genSegmentLoadInfos(schema *schemapb.CollectionSchema) []*querypb.SegmentLoadInfo { +func (suite *ServiceSuite) genSegmentLoadInfos(schema *schemapb.CollectionSchema, + indexInfos []*indexpb.IndexInfo, +) []*querypb.SegmentLoadInfo { ctx := context.Background() segNum := len(suite.validSegmentIDs) @@ -534,18 +508,25 @@ func (suite *ServiceSuite) genSegmentLoadInfos(schema *schemapb.CollectionSchema ) suite.Require().NoError(err) - vecFieldIDs := funcutil.GetVecFieldIDs(schema) - indexes, err := segments.GenAndSaveIndex( - suite.collectionID, - suite.partitionIDs[i%partNum], - suite.validSegmentIDs[i], - vecFieldIDs[0], - 1000, - segments.IndexFaissIVFFlat, - metric.L2, - suite.node.chunkManager, - ) - suite.Require().NoError(err) + vectorFieldSchemas := typeutil.GetVectorFieldSchemas(schema) + indexes := make([]*querypb.FieldIndexInfo, 0) + for offset, field := range vectorFieldSchemas { + indexInfo := lo.FindOrElse(indexInfos, nil, func(info *indexpb.IndexInfo) bool { return info.FieldID == field.GetFieldID() }) + if indexInfo != nil { + index, err := segments.GenAndSaveIndexV2( + suite.collectionID, + suite.partitionIDs[i%partNum], + suite.validSegmentIDs[i], + int64(offset), + field, + indexInfo, + suite.node.chunkManager, + 1000, + ) + suite.Require().NoError(err) + indexes = append(indexes, index) + } + } info := &querypb.SegmentLoadInfo{ SegmentID: suite.validSegmentIDs[i], @@ -555,7 +536,7 @@ func (suite *ServiceSuite) genSegmentLoadInfos(schema *schemapb.CollectionSchema NumOfRows: 1000, BinlogPaths: binlogs, Statslogs: statslogs, - IndexInfos: 
[]*querypb.FieldIndexInfo{indexes}, + IndexInfos: indexes, StartPosition: &msgpb.MsgPosition{Timestamp: 20000}, DeltaPosition: &msgpb.MsgPosition{Timestamp: 20000}, } @@ -569,7 +550,8 @@ func (suite *ServiceSuite) TestLoadSegments_Int64() { suite.TestWatchDmChannelsInt64() // data schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64) - infos := suite.genSegmentLoadInfos(schema) + indexInfos := segments.GenTestIndexInfoList(suite.collectionID, schema) + infos := suite.genSegmentLoadInfos(schema, indexInfos) for _, info := range infos { req := &querypb.LoadSegmentsRequest{ Base: &commonpb.MsgBase{ @@ -582,9 +564,7 @@ func (suite *ServiceSuite) TestLoadSegments_Int64() { Schema: schema, DeltaPositions: []*msgpb.MsgPosition{{Timestamp: 20000}}, NeedTransfer: true, - IndexInfoList: []*indexpb.IndexInfo{ - {}, - }, + IndexInfoList: indexInfos, } // LoadSegment @@ -607,7 +587,7 @@ func (suite *ServiceSuite) TestLoadSegments_VarChar() { suite.node.manager.Collection = segments.NewCollectionManager() suite.node.manager.Collection.PutOrRef(suite.collectionID, schema, nil, loadMeta) - infos := suite.genSegmentLoadInfos(schema) + infos := suite.genSegmentLoadInfos(schema, nil) for _, info := range infos { req := &querypb.LoadSegmentsRequest{ Base: &commonpb.MsgBase{ @@ -643,7 +623,7 @@ func (suite *ServiceSuite) TestLoadDeltaInt64() { }, CollectionID: suite.collectionID, DstNodeID: suite.node.session.ServerID, - Infos: suite.genSegmentLoadInfos(schema), + Infos: suite.genSegmentLoadInfos(schema, nil), Schema: schema, NeedTransfer: true, LoadScope: querypb.LoadScope_Delta, @@ -668,7 +648,7 @@ func (suite *ServiceSuite) TestLoadDeltaVarchar() { }, CollectionID: suite.collectionID, DstNodeID: suite.node.session.ServerID, - Infos: suite.genSegmentLoadInfos(schema), + Infos: suite.genSegmentLoadInfos(schema, nil), Schema: schema, NeedTransfer: true, LoadScope: querypb.LoadScope_Delta, @@ -687,7 +667,8 @@ func (suite *ServiceSuite) TestLoadIndex_Success() { schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64) - infos := suite.genSegmentLoadInfos(schema) + indexInfos := segments.GenTestIndexInfoList(suite.collectionID, schema) + infos := suite.genSegmentLoadInfos(schema, indexInfos) infos = lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *querypb.SegmentLoadInfo { info.SegmentID = info.SegmentID + 1000 return info @@ -697,8 +678,7 @@ func (suite *ServiceSuite) TestLoadIndex_Success() { info.IndexInfos = nil return info }) - // generate indexinfos for setting index meta. 
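// Editorial note: the service tests now build index metadata once via
// segments.GenTestIndexInfoList(suite.collectionID, schema) and pass it both to
// genSegmentLoadInfos, which generates and saves a real index for every vector
// field through GenAndSaveIndexV2, and to the request's IndexInfoList, instead
// of reverse-engineering indexpb.IndexInfo entries from the first segment's
// load info as the removed genSegmentIndexInfos helper used to do (its last
// call site is deleted just below).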
- indexInfoList := suite.genSegmentIndexInfos(infos) + req := &querypb.LoadSegmentsRequest{ Base: &commonpb.MsgBase{ MsgID: rand.Int63(), @@ -710,7 +690,7 @@ func (suite *ServiceSuite) TestLoadIndex_Success() { Schema: schema, NeedTransfer: false, LoadScope: querypb.LoadScope_Full, - IndexInfoList: indexInfoList, + IndexInfoList: indexInfos, } // Load segment @@ -759,7 +739,8 @@ func (suite *ServiceSuite) TestLoadIndex_Failed() { schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64) suite.Run("load_non_exist_segment", func() { - infos := suite.genSegmentLoadInfos(schema) + indexInfos := segments.GenTestIndexInfoList(suite.collectionID, schema) + infos := suite.genSegmentLoadInfos(schema, indexInfos) infos = lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *querypb.SegmentLoadInfo { info.SegmentID = info.SegmentID + 1000 return info @@ -780,7 +761,7 @@ func (suite *ServiceSuite) TestLoadIndex_Failed() { Schema: schema, NeedTransfer: false, LoadScope: querypb.LoadScope_Index, - IndexInfoList: []*indexpb.IndexInfo{{}}, + IndexInfoList: indexInfos, } // Load segment @@ -801,7 +782,8 @@ func (suite *ServiceSuite) TestLoadIndex_Failed() { mockLoader.EXPECT().LoadIndex(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("mocked error")) - infos := suite.genSegmentLoadInfos(schema) + indexInfos := segments.GenTestIndexInfoList(suite.collectionID, schema) + infos := suite.genSegmentLoadInfos(schema, indexInfos) req := &querypb.LoadSegmentsRequest{ Base: &commonpb.MsgBase{ MsgID: rand.Int63(), @@ -813,7 +795,7 @@ func (suite *ServiceSuite) TestLoadIndex_Failed() { Schema: schema, NeedTransfer: false, LoadScope: querypb.LoadScope_Index, - IndexInfoList: []*indexpb.IndexInfo{{}}, + IndexInfoList: indexInfos, } // Load segment @@ -834,7 +816,7 @@ func (suite *ServiceSuite) TestLoadSegments_Failed() { }, CollectionID: suite.collectionID, DstNodeID: suite.node.session.ServerID, - Infos: suite.genSegmentLoadInfos(schema), + Infos: suite.genSegmentLoadInfos(schema, nil), Schema: schema, NeedTransfer: true, IndexInfoList: []*indexpb.IndexInfo{ @@ -886,7 +868,7 @@ func (suite *ServiceSuite) TestLoadSegments_Transfer() { }, CollectionID: suite.collectionID, DstNodeID: suite.node.session.ServerID, - Infos: suite.genSegmentLoadInfos(schema), + Infos: suite.genSegmentLoadInfos(schema, nil), Schema: schema, NeedTransfer: true, IndexInfoList: []*indexpb.IndexInfo{{}}, @@ -908,7 +890,7 @@ func (suite *ServiceSuite) TestLoadSegments_Transfer() { }, CollectionID: suite.collectionID, DstNodeID: suite.node.session.ServerID, - Infos: suite.genSegmentLoadInfos(schema), + Infos: suite.genSegmentLoadInfos(schema, nil), Schema: schema, NeedTransfer: true, IndexInfoList: []*indexpb.IndexInfo{{}}, @@ -935,7 +917,7 @@ func (suite *ServiceSuite) TestLoadSegments_Transfer() { }, CollectionID: suite.collectionID, DstNodeID: suite.node.session.ServerID, - Infos: suite.genSegmentLoadInfos(schema), + Infos: suite.genSegmentLoadInfos(schema, nil), Schema: schema, NeedTransfer: true, IndexInfoList: []*indexpb.IndexInfo{{}}, @@ -1139,18 +1121,14 @@ func (suite *ServiceSuite) TestGetSegmentInfo_Failed() { } // Test Search -func (suite *ServiceSuite) genCSearchRequest(nq int64, indexType string, schema *schemapb.CollectionSchema) (*internalpb.SearchRequest, error) { +func (suite *ServiceSuite) genCSearchRequest(nq int64, dataType schemapb.DataType, fieldID int64, metricType string) (*internalpb.SearchRequest, error) { placeHolder, err := genPlaceHolderGroup(nq) if err 
!= nil { return nil, err } - planStr, err := genDSLByIndexType(schema, indexType) - if err != nil { - return nil, err - } - var planpb planpb.PlanNode - proto.UnmarshalText(planStr, &planpb) - serializedPlan, err2 := proto.Marshal(&planpb) + + plan := genSearchPlan(dataType, fieldID, metricType) + serializedPlan, err2 := proto.Marshal(plan) if err2 != nil { return nil, err2 } @@ -1175,9 +1153,7 @@ func (suite *ServiceSuite) TestSearch_Normal() { suite.TestWatchDmChannelsInt64() suite.TestLoadSegments_Int64() - // data - schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64) - creq, err := suite.genCSearchRequest(10, IndexFaissIDMap, schema) + creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType) req := &querypb.SearchRequest{ Req: creq, FromShardLeader: false, @@ -1197,14 +1173,11 @@ func (suite *ServiceSuite) TestSearch_Concurrent() { suite.TestWatchDmChannelsInt64() suite.TestLoadSegments_Int64() - // data - schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64) - concurrency := 16 futures := make([]*conc.Future[*internalpb.SearchResults], 0, concurrency) for i := 0; i < concurrency; i++ { future := conc.Go(func() (*internalpb.SearchResults, error) { - creq, err := suite.genCSearchRequest(30, IndexFaissIDMap, schema) + creq, err := suite.genCSearchRequest(30, schemapb.DataType_FloatVector, 107, defaultMetricType) req := &querypb.SearchRequest{ Req: creq, FromShardLeader: false, @@ -1230,7 +1203,7 @@ func (suite *ServiceSuite) TestSearch_Failed() { // data schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64) - creq, err := suite.genCSearchRequest(10, IndexFaissIDMap, schema) + creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, "invalidMetricType") req := &querypb.SearchRequest{ Req: creq, FromShardLeader: false, @@ -1250,15 +1223,9 @@ func (suite *ServiceSuite) TestSearch_Failed() { LoadType: querypb.LoadType_LoadCollection, CollectionID: suite.collectionID, PartitionIDs: suite.partitionIDs, - MetricType: "L2", } - suite.node.manager.Collection.PutOrRef(suite.collectionID, schema, nil, LoadMeta) - req.GetReq().MetricType = "IP" - resp, err = suite.node.Search(ctx, req) - suite.NoError(err) - suite.ErrorIs(merr.Error(resp.GetStatus()), merr.ErrParameterInvalid) - suite.Contains(resp.GetStatus().GetReason(), merr.ErrParameterInvalid.Error()) - req.GetReq().MetricType = "L2" + indexMeta := suite.node.composeIndexMeta(segments.GenTestIndexInfoList(suite.collectionID, schema), schema) + suite.node.manager.Collection.PutOrRef(suite.collectionID, schema, indexMeta, LoadMeta) // Delegator not found resp, err = suite.node.Search(ctx, req) @@ -1268,6 +1235,34 @@ func (suite *ServiceSuite) TestSearch_Failed() { suite.TestWatchDmChannelsInt64() suite.TestLoadSegments_Int64() + // sync segment data + syncReq := &querypb.SyncDistributionRequest{ + Base: &commonpb.MsgBase{ + MsgID: rand.Int63(), + TargetID: suite.node.session.ServerID, + }, + CollectionID: suite.collectionID, + Channel: suite.vchannel, + } + + syncVersionAction := &querypb.SyncAction{ + Type: querypb.SyncType_UpdateVersion, + SealedInTarget: []int64{1, 2, 3, 4}, + TargetVersion: time.Now().UnixMilli(), + } + + syncReq.Actions = []*querypb.SyncAction{syncVersionAction} + status, err := suite.node.SyncDistribution(ctx, syncReq) + suite.NoError(err) + suite.Equal(commonpb.ErrorCode_Success, status.ErrorCode) + + // metric type not match + req.GetReq().MetricType = 
"IP" + resp, err = suite.node.Search(ctx, req) + suite.NoError(err) + suite.Contains(resp.GetStatus().GetReason(), "metric type not match") + req.GetReq().MetricType = "L2" + // target not match req.Req.Base.TargetID = -1 resp, err = suite.node.Search(ctx, req) @@ -1333,9 +1328,7 @@ func (suite *ServiceSuite) TestSearchSegments_Normal() { suite.TestWatchDmChannelsInt64() suite.TestLoadSegments_Int64() - // data - schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64) - creq, err := suite.genCSearchRequest(10, IndexFaissIDMap, schema) + creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType) req := &querypb.SearchRequest{ Req: creq, FromShardLeader: true, diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index 2f6bc9f1cca5e..8a460f89851a6 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -31,6 +31,8 @@ const ( FailLabel = "fail" TotalLabel = "total" + HybridSearchLabel = "hybrid_search" + InsertLabel = "insert" DeleteLabel = "delete" UpsertLabel = "upsert" diff --git a/pkg/util/typeutil/schema_test.go b/pkg/util/typeutil/schema_test.go index ba0f9eaf92936..5a209c00d11a3 100644 --- a/pkg/util/typeutil/schema_test.go +++ b/pkg/util/typeutil/schema_test.go @@ -239,9 +239,9 @@ func TestSchema_GetVectorFieldSchema(t *testing.T) { } t.Run("GetVectorFieldSchema", func(t *testing.T) { - fieldSchema, err := GetVectorFieldSchema(schemaNormal) - assert.Equal(t, "field_float_vector", fieldSchema.Name) - assert.NoError(t, err) + fieldSchema := GetVectorFieldSchemas(schemaNormal) + assert.Equal(t, 1, len(fieldSchema)) + assert.Equal(t, "field_float_vector", fieldSchema[0].Name) }) schemaInvalid := &schemapb.CollectionSchema{ @@ -260,8 +260,8 @@ func TestSchema_GetVectorFieldSchema(t *testing.T) { } t.Run("GetVectorFieldSchemaInvalid", func(t *testing.T) { - _, err := GetVectorFieldSchema(schemaInvalid) - assert.Error(t, err) + res := GetVectorFieldSchemas(schemaInvalid) + assert.Equal(t, 0, len(res)) }) } diff --git a/tests/integration/hybridsearch/hybridsearch_test.go b/tests/integration/hybridsearch/hybridsearch_test.go new file mode 100644 index 0000000000000..7abdf1410e2e7 --- /dev/null +++ b/tests/integration/hybridsearch/hybridsearch_test.go @@ -0,0 +1,225 @@ +package hybridsearch + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "testing" + + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/suite" + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/proxy" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/metric" + "github.com/milvus-io/milvus/tests/integration" +) + +type HybridSearchSuite struct { + integration.MiniClusterSuite +} + +func (s *HybridSearchSuite) TestHybridSearch() { + c := s.Cluster + ctx, cancel := context.WithCancel(c.GetContext()) + defer cancel() + + prefix := "TestHybridSearch" + dbName := "" + collectionName := prefix + funcutil.GenRandomStr() + dim := 128 + rowNum := 3000 + + schema := integration.ConstructSchema(collectionName, dim, true, + &schemapb.FieldSchema{Name: integration.Int64Field, DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true}, + &schemapb.FieldSchema{Name: 
integration.FloatVecField, DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "128"}}}, + &schemapb.FieldSchema{Name: integration.BinVecField, DataType: schemapb.DataType_BinaryVector, TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "128"}}}, + ) + marshaledSchema, err := proto.Marshal(schema) + s.NoError(err) + + createCollectionStatus, err := c.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + Schema: marshaledSchema, + ShardsNum: common.DefaultShardsNum, + }) + s.NoError(err) + + err = merr.Error(createCollectionStatus) + if err != nil { + log.Warn("createCollectionStatus fail reason", zap.Error(err)) + } + + log.Info("CreateCollection result", zap.Any("createCollectionStatus", createCollectionStatus)) + showCollectionsResp, err := c.Proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{}) + s.NoError(err) + s.True(merr.Ok(showCollectionsResp.GetStatus())) + log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp)) + + fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim) + bVecColumn := integration.NewBinaryVectorFieldData(integration.BinVecField, rowNum, dim) + hashKeys := integration.GenerateHashKeys(rowNum) + insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{ + DbName: dbName, + CollectionName: collectionName, + FieldsData: []*schemapb.FieldData{fVecColumn, bVecColumn}, + HashKeys: hashKeys, + NumRows: uint32(rowNum), + }) + s.NoError(err) + s.True(merr.Ok(insertResult.GetStatus())) + + // flush + flushResp, err := c.Proxy.Flush(ctx, &milvuspb.FlushRequest{ + DbName: dbName, + CollectionNames: []string{collectionName}, + }) + s.NoError(err) + segmentIDs, has := flushResp.GetCollSegIDs()[collectionName] + ids := segmentIDs.GetData() + s.Require().NotEmpty(segmentIDs) + s.Require().True(has) + flushTs, has := flushResp.GetCollFlushTs()[collectionName] + s.True(has) + + segments, err := c.MetaWatcher.ShowSegments() + s.NoError(err) + s.NotEmpty(segments) + for _, segment := range segments { + log.Info("ShowSegments result", zap.String("segment", segment.String())) + } + s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName) + + // load without index on vector fields + loadStatus, err := c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + }) + s.NoError(err) + s.Error(merr.Error(loadStatus)) + + // create index for float vector + createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{ + CollectionName: collectionName, + FieldName: integration.FloatVecField, + IndexName: "_default_float", + ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.L2), + }) + s.NoError(err) + err = merr.Error(createIndexStatus) + if err != nil { + log.Warn("createIndexStatus fail reason", zap.Error(err)) + } + s.WaitForIndexBuilt(ctx, collectionName, integration.FloatVecField) + + // load with index on partial vector fields + loadStatus, err = c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + }) + s.NoError(err) + s.Error(merr.Error(loadStatus)) + + // create index for binary vector + createIndexStatus, err = c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{ + CollectionName: collectionName, + FieldName: integration.BinVecField, + IndexName: "_default_binary", + ExtraParams: 
integration.ConstructIndexParam(dim, integration.IndexFaissBinIvfFlat, metric.JACCARD), + }) + s.NoError(err) + err = merr.Error(createIndexStatus) + if err != nil { + log.Warn("createIndexStatus fail reason", zap.Error(err)) + } + s.WaitForIndexBuiltWithIndexName(ctx, collectionName, integration.BinVecField, "_default_binary") + + // load with index on all vector fields + loadStatus, err = c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + }) + s.NoError(err) + err = merr.Error(loadStatus) + if err != nil { + log.Warn("LoadCollection fail reason", zap.Error(err)) + } + s.WaitForLoad(ctx, collectionName) + + // search + expr := fmt.Sprintf("%s > 0", integration.Int64Field) + nq := 1 + topk := 10 + roundDecimal := -1 + + fParams := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.L2) + bParams := integration.GetSearchParams(integration.IndexFaissBinIvfFlat, metric.L2) + fSearchReq := integration.ConstructSearchRequest("", collectionName, expr, + integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.L2, fParams, nq, dim, topk, roundDecimal) + + bSearchReq := integration.ConstructSearchRequest("", collectionName, expr, + integration.BinVecField, schemapb.DataType_BinaryVector, nil, metric.JACCARD, bParams, nq, dim, topk, roundDecimal) + + hSearchReq := &milvuspb.HybridSearchRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + PartitionNames: nil, + Requests: []*milvuspb.SearchRequest{fSearchReq, bSearchReq}, + OutputFields: []string{integration.FloatVecField, integration.BinVecField}, + } + + // rrf rank hybrid search + rrfParams := make(map[string]float64) + rrfParams[proxy.RRFParamsKey] = 60 + b, err := json.Marshal(rrfParams) + s.NoError(err) + hSearchReq.RankParams = []*commonpb.KeyValuePair{ + {Key: proxy.RankTypeKey, Value: "rrf"}, + {Key: proxy.RankParamsKey, Value: string(b)}, + {Key: proxy.LimitKey, Value: strconv.Itoa(topk)}, + {Key: proxy.RoundDecimalKey, Value: strconv.Itoa(roundDecimal)}, + } + + searchResult, err := c.Proxy.HybridSearch(ctx, hSearchReq) + + if searchResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + log.Warn("searchResult fail reason", zap.String("reason", searchResult.GetStatus().GetReason())) + } + s.NoError(err) + s.Equal(commonpb.ErrorCode_Success, searchResult.GetStatus().GetErrorCode()) + + // weighted rank hybrid search + weightsParams := make(map[string][]float64) + weightsParams[proxy.WeightsParamsKey] = []float64{0.5, 0.2} + b, err = json.Marshal(weightsParams) + s.NoError(err) + hSearchReq.RankParams = []*commonpb.KeyValuePair{ + {Key: proxy.RankTypeKey, Value: "weighted"}, + {Key: proxy.RankParamsKey, Value: string(b)}, + {Key: proxy.LimitKey, Value: strconv.Itoa(topk)}, + } + + searchResult, err = c.Proxy.HybridSearch(ctx, hSearchReq) + + if searchResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + log.Warn("searchResult fail reason", zap.String("reason", searchResult.GetStatus().GetReason())) + } + s.NoError(err) + s.Equal(commonpb.ErrorCode_Success, searchResult.GetStatus().GetErrorCode()) + + log.Info("TestHybridSearch succeed") +} + +func TestHybridSearch(t *testing.T) { + suite.Run(t, new(HybridSearchSuite)) +} diff --git a/tests/integration/util_index.go b/tests/integration/util_index.go index 602152d09a857..3bc821ccd73b2 100644 --- a/tests/integration/util_index.go +++ b/tests/integration/util_index.go @@ -44,19 +44,24 @@ const ( ) func (s *MiniClusterSuite) WaitForIndexBuiltWithDB(ctx 
context.Context, dbName, collection, field string) {
-	s.waitForIndexBuiltInternal(ctx, dbName, collection, field)
+	s.waitForIndexBuiltInternal(ctx, dbName, collection, field, "")
 }
 
 func (s *MiniClusterSuite) WaitForIndexBuilt(ctx context.Context, collection, field string) {
-	s.waitForIndexBuiltInternal(ctx, "", collection, field)
+	s.waitForIndexBuiltInternal(ctx, "", collection, field, "")
 }
 
-func (s *MiniClusterSuite) waitForIndexBuiltInternal(ctx context.Context, dbName, collection, field string) {
+func (s *MiniClusterSuite) WaitForIndexBuiltWithIndexName(ctx context.Context, collection, field, indexName string) {
+	s.waitForIndexBuiltInternal(ctx, "", collection, field, indexName)
+}
+
+func (s *MiniClusterSuite) waitForIndexBuiltInternal(ctx context.Context, dbName, collection, field, indexName string) {
 	getIndexBuilt := func() bool {
 		resp, err := s.Cluster.Proxy.DescribeIndex(ctx, &milvuspb.DescribeIndexRequest{
 			DbName:         dbName,
 			CollectionName: collection,
 			FieldName:      field,
+			IndexName:      indexName,
 		})
 		if err != nil {
 			s.FailNow("failed to describe index")

From 97e4ec5a6904c5abe7d068b9b87a26765ef333cd Mon Sep 17 00:00:00 2001
From: yah01
Date: Mon, 8 Jan 2024 15:58:48 +0800
Subject: [PATCH 19/20] enhance: use random root path for minio unit tests
 (#29753)

This avoids conflicts when multiple unit tests run concurrently.

Signed-off-by: yah01

---
 internal/storage/minio_chunk_manager_test.go | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/internal/storage/minio_chunk_manager_test.go b/internal/storage/minio_chunk_manager_test.go
index 3d0dd2f771288..aae848f4320c7 100644
--- a/internal/storage/minio_chunk_manager_test.go
+++ b/internal/storage/minio_chunk_manager_test.go
@@ -18,6 +18,7 @@ package storage
 
 import (
 	"context"
+	"fmt"
 	"io"
 	"math/rand"
 	"path"
@@ -85,7 +86,7 @@ func TestMinIOCM(t *testing.T) {
 
 	configRoot := Params.MinioCfg.RootPath.GetValue()
 
-	testMinIOKVRoot := path.Join(configRoot, "milvus-minio-ut-root")
+	testMinIOKVRoot := path.Join(configRoot, fmt.Sprintf("minio-ut-%d", rand.Int()))
 
 	t.Run("test load", func(t *testing.T) {
 		testLoadRoot := path.Join(testMinIOKVRoot, "test_load")

From b9d76f77d19810f8f7c8f760745bdb5c6495c61f Mon Sep 17 00:00:00 2001
From: "zhenshan.cao"
Date: Wed, 27 Dec 2023 15:12:13 +0800
Subject: [PATCH 20/20] Restore the MVCC functionality.

When the TimeTravel functionality was previously removed, it
inadvertently affected the MVCC functionality within the system. This PR
reintroduces the internal MVCC functionality as follows:
1. Add MvccTimestamp to the requests of Search/Query and the results of
   Search internally.
2. When the delegator receives a Query/Search request and there is no
   MVCC timestamp set in the request, set the delegator's current tsafe
   as the MVCC timestamp of the request. If the request already has an
   MVCC timestamp, do not modify it (see the sketch below).
3. When the Proxy handles Search and triggers the second-phase ReQuery,
   divide the ReQuery into different shards and pass the MVCC timestamp
   to the corresponding Query requests.
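
As an illustration of point 2, a minimal sketch in Go. This is not code
from this patch: the helper name and its parameters are hypothetical;
only the MvccTimestamp field itself comes from this change.

    // mvccTimestampFor picks the MVCC timestamp a request should carry.
    // requested is the MvccTimestamp already set on the request (0 if
    // unset); currentTsafe is the delegator's tsafe when it arrives.
    func mvccTimestampFor(requested, currentTsafe uint64) uint64 {
        if requested > 0 {
            // Keep a timestamp pinned upstream, e.g. by the proxy for
            // the second-phase ReQuery, so both phases share a snapshot.
            return requested
        }
        // Otherwise pin the delegator's current tsafe, so every segment
        // touched by this request observes the same consistent view.
        return currentTsafe
    }

Either way, one logical operation always evaluates against a single
timestamp, whether the proxy pinned it or the delegator chose it.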
Signed-off-by: zhenshan.cao --- .../core/src/segcore/SegmentInterface.cpp | 5 +- internal/core/src/segcore/SegmentInterface.h | 6 +- internal/core/src/segcore/segment_c.cpp | 3 +- internal/core/src/segcore/segment_c.h | 1 + internal/core/unittest/bench/bench_search.cpp | 12 +- internal/core/unittest/test_binlog_index.cpp | 7 +- internal/core/unittest/test_c_api.cpp | 169 +- internal/core/unittest/test_float16.cpp | 4 +- internal/core/unittest/test_group_by.cpp | 250 +- internal/core/unittest/test_growing_index.cpp | 7 +- internal/core/unittest/test_query.cpp | 35 +- internal/core/unittest/test_sealed.cpp | 75 +- internal/core/unittest/test_string_expr.cpp | 2 +- internal/core/unittest/test_utils/DataGen.h | 5 +- .../unittest/test_utils/c_api_test_utils.h | 136 +- internal/proto/internal.proto | 4 +- ...emove time travel ralted testcase (#26119) | 2132 +++++++++++++++++ internal/proxy/lb_policy.go | 2 +- internal/proxy/lb_policy_test.go | 18 +- internal/proxy/task_delete.go | 6 +- internal/proxy/task_delete_test.go | 16 +- internal/proxy/task_hybrid_search.go | 10 +- internal/proxy/task_query.go | 22 +- internal/proxy/task_search.go | 40 +- internal/proxy/task_search_test.go | 6 +- internal/proxy/task_statistic.go | 8 +- internal/querynodev2/delegator/delegator.go | 34 +- internal/querynodev2/handlers.go | 1 - internal/querynodev2/segments/plan.go | 2 + internal/querynodev2/segments/reduce_test.go | 2 + internal/querynodev2/segments/result.go | 8 +- internal/querynodev2/segments/segment.go | 1 + internal/querynodev2/services.go | 13 +- internal/querynodev2/services_test.go | 1 + internal/querynodev2/tasks/task.go | 8 + 35 files changed, 2722 insertions(+), 329 deletions(-) create mode 100644 internal/proto/internalpb/internal.pb.go~parent of ca1349708... 
Remove time travel ralted testcase (#26119) diff --git a/internal/core/src/segcore/SegmentInterface.cpp b/internal/core/src/segcore/SegmentInterface.cpp index 502bcd083ef0e..6779b56654647 100644 --- a/internal/core/src/segcore/SegmentInterface.cpp +++ b/internal/core/src/segcore/SegmentInterface.cpp @@ -67,11 +67,12 @@ SegmentInternalInterface::FillTargetEntry(const query::Plan* plan, std::unique_ptr SegmentInternalInterface::Search( const query::Plan* plan, - const query::PlaceholderGroup* placeholder_group) const { + const query::PlaceholderGroup* placeholder_group, + Timestamp timestamp) const { std::shared_lock lck(mutex_); milvus::tracer::AddEvent("obtained_segment_lock_mutex"); check_search(plan); - query::ExecPlanNodeVisitor visitor(*this, 1L << 63, placeholder_group); + query::ExecPlanNodeVisitor visitor(*this, timestamp, placeholder_group); auto results = std::make_unique(); *results = visitor.get_moved_result(*plan->plan_node_); results->segment_ = (void*)this; diff --git a/internal/core/src/segcore/SegmentInterface.h b/internal/core/src/segcore/SegmentInterface.h index b8d19de0d1640..a5fbc8014d602 100644 --- a/internal/core/src/segcore/SegmentInterface.h +++ b/internal/core/src/segcore/SegmentInterface.h @@ -54,7 +54,8 @@ class SegmentInterface { virtual std::unique_ptr Search(const query::Plan* Plan, - const query::PlaceholderGroup* placeholder_group) const = 0; + const query::PlaceholderGroup* placeholder_group, + Timestamp timestamp) const = 0; virtual std::unique_ptr Retrieve(const query::RetrievePlan* Plan, @@ -136,7 +137,8 @@ class SegmentInternalInterface : public SegmentInterface { std::unique_ptr Search(const query::Plan* Plan, - const query::PlaceholderGroup* placeholder_group) const override; + const query::PlaceholderGroup* placeholder_group, + Timestamp timestamp) const override; void FillPrimaryKeys(const query::Plan* plan, diff --git a/internal/core/src/segcore/segment_c.cpp b/internal/core/src/segcore/segment_c.cpp index e997eb5b88adc..814b504bc0eb4 100644 --- a/internal/core/src/segcore/segment_c.cpp +++ b/internal/core/src/segcore/segment_c.cpp @@ -80,6 +80,7 @@ Search(CSegmentInterface c_segment, CSearchPlan c_plan, CPlaceholderGroup c_placeholder_group, CTraceContext c_trace, + uint64_t timestamp, CSearchResult* result) { try { auto segment = (milvus::segcore::SegmentInterface*)c_segment; @@ -90,7 +91,7 @@ Search(CSegmentInterface c_segment, c_trace.traceID, c_trace.spanID, c_trace.flag}; auto span = milvus::tracer::StartSpan("SegCoreSearch", &ctx); milvus::tracer::SetRootSpan(span); - auto search_result = segment->Search(plan, phg_ptr); + auto search_result = segment->Search(plan, phg_ptr, timestamp); if (!milvus::PositivelyRelated( plan->plan_node_->search_info_.metric_type_)) { for (auto& dis : search_result->distances_) { diff --git a/internal/core/src/segcore/segment_c.h b/internal/core/src/segcore/segment_c.h index 118b69ff9c070..b638232ab2a46 100644 --- a/internal/core/src/segcore/segment_c.h +++ b/internal/core/src/segcore/segment_c.h @@ -45,6 +45,7 @@ Search(CSegmentInterface c_segment, CSearchPlan c_plan, CPlaceholderGroup c_placeholder_group, CTraceContext c_trace, + uint64_t timestamp, CSearchResult* result); void diff --git a/internal/core/unittest/bench/bench_search.cpp b/internal/core/unittest/bench/bench_search.cpp index f1334c7a3ed7e..fabfa38fc7c74 100644 --- a/internal/core/unittest/bench/bench_search.cpp +++ b/internal/core/unittest/bench/bench_search.cpp @@ -90,8 +90,10 @@ Search_GrowingIndex(benchmark::State& state) { 
dataset_.timestamps_.data(), dataset_.raw_); + Timestamp ts = 10000000; + for (auto _ : state) { - auto qr = segment->Search(search_plan.get(), ph_group.get()); + auto qr = segment->Search(search_plan.get(), ph_group.get(), ts); } } @@ -114,7 +116,8 @@ Search_Sealed(benchmark::State& state) { } else if (choice == 1) { // hnsw auto vec = dataset_.get_col(milvus::FieldId(100)); - auto indexing = GenVecIndexing(N, dim, vec.data(), knowhere::IndexEnum::INDEX_HNSW); + auto indexing = + GenVecIndexing(N, dim, vec.data(), knowhere::IndexEnum::INDEX_HNSW); segcore::LoadIndexInfo info; info.index = std::move(indexing); info.field_id = (*schema)[FieldName("fakevec")].get_id().get(); @@ -123,8 +126,11 @@ Search_Sealed(benchmark::State& state) { segment->DropFieldData(milvus::FieldId(100)); segment->LoadIndex(info); } + + Timestamp ts = 10000000; + for (auto _ : state) { - auto qr = segment->Search(search_plan.get(), ph_group.get()); + auto qr = segment->Search(search_plan.get(), ph_group.get(), ts); } } diff --git a/internal/core/unittest/test_binlog_index.cpp b/internal/core/unittest/test_binlog_index.cpp index 4d4c3faf23f43..7cd14c0b6deb4 100644 --- a/internal/core/unittest/test_binlog_index.cpp +++ b/internal/core/unittest/test_binlog_index.cpp @@ -191,7 +191,8 @@ TEST_P(BinlogIndexTest, Accuracy) { std::vector ph_group_arr = { ph_group.get()}; auto nlist = segcore_config.get_nlist(); - auto binlog_index_sr = segment->Search(plan.get(), ph_group.get()); + auto binlog_index_sr = + segment->Search(plan.get(), ph_group.get(), 1L << 63); ASSERT_EQ(binlog_index_sr->total_nq_, num_queries); EXPECT_EQ(binlog_index_sr->unity_topK_, topk); EXPECT_EQ(binlog_index_sr->distances_.size(), num_queries * topk); @@ -226,7 +227,7 @@ TEST_P(BinlogIndexTest, Accuracy) { EXPECT_TRUE(segment->HasIndex(vec_field_id)); EXPECT_EQ(segment->get_row_count(), data_n); EXPECT_FALSE(segment->HasFieldData(vec_field_id)); - auto ivf_sr = segment->Search(plan.get(), ph_group.get()); + auto ivf_sr = segment->Search(plan.get(), ph_group.get(), 1L << 63); auto similary = GetKnnSearchRecall(num_queries, binlog_index_sr->seg_offsets_.data(), topk, @@ -312,4 +313,4 @@ TEST_P(BinlogIndexTest, LoadBinlogWithoutIndexMeta) { EXPECT_FALSE(segment->HasIndex(vec_field_id)); EXPECT_EQ(segment->get_row_count(), data_n); EXPECT_TRUE(segment->HasFieldData(vec_field_id)); -} \ No newline at end of file +} diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp index 83f9717555fc8..452b19d60d26a 100644 --- a/internal/core/unittest/test_c_api.cpp +++ b/internal/core/unittest/test_c_api.cpp @@ -1075,11 +1075,13 @@ TEST(CApiTest, SearchTest) { placeholderGroups.push_back(placeholderGroup); CSearchResult search_result; - auto res = Search(segment, plan, placeholderGroup, {}, &search_result); + auto res = + Search(segment, plan, placeholderGroup, {}, ts_offset, &search_result); ASSERT_EQ(res.error_code, Success); CSearchResult search_result2; - auto res2 = Search(segment, plan, placeholderGroup, {}, &search_result2); + auto res2 = + Search(segment, plan, placeholderGroup, {}, ts_offset, &search_result2); ASSERT_EQ(res2.error_code, Success); DeleteSearchPlan(plan); @@ -1143,7 +1145,12 @@ TEST(CApiTest, SearchTestWithExpr) { dataset.timestamps_.push_back(1); CSearchResult search_result; - auto res = Search(segment, plan, placeholderGroup, {}, &search_result); + auto res = Search(segment, + plan, + placeholderGroup, + {}, + dataset.timestamps_[0], + &search_result); ASSERT_EQ(res.error_code, Success); 
DeleteSearchPlan(plan);
@@ -1427,7 +1434,7 @@ TEST(CApiTest, ReduceNullResult) {
     auto slice_topKs = std::vector{1};
     std::vector results;
     CSearchResult res;
-    status = Search(segment, plan, placeholderGroup, {}, &res);
+    status = Search(segment, plan, placeholderGroup, {}, 1L << 63, &res);
     ASSERT_EQ(status.error_code, Success);
     results.push_back(res);
     CSearchResultDataBlobs cSearchResultData;
@@ -1514,9 +1521,11 @@ TEST(CApiTest, ReduceRemoveDuplicates) {
     auto slice_topKs = std::vector{topK / 2, topK};
     std::vector results;
     CSearchResult res1, res2;
-    status = Search(segment, plan, placeholderGroup, {}, &res1);
+    status = Search(
+        segment, plan, placeholderGroup, {}, dataset.timestamps_[0], &res1);
     ASSERT_EQ(status.error_code, Success);
-    status = Search(segment, plan, placeholderGroup, {}, &res2);
+    status = Search(
+        segment, plan, placeholderGroup, {}, dataset.timestamps_[0], &res2);
     ASSERT_EQ(status.error_code, Success);
     results.push_back(res1);
     results.push_back(res2);
@@ -1545,11 +1554,14 @@
     auto slice_topKs = std::vector{topK / 2, topK, topK};
     std::vector results;
     CSearchResult res1, res2, res3;
-    status = Search(segment, plan, placeholderGroup, {}, &res1);
+    status = Search(
+        segment, plan, placeholderGroup, {}, dataset.timestamps_[0], &res1);
     ASSERT_EQ(status.error_code, Success);
-    status = Search(segment, plan, placeholderGroup, {}, &res2);
+    status = Search(
+        segment, plan, placeholderGroup, {}, dataset.timestamps_[0], &res2);
     ASSERT_EQ(status.error_code, Success);
-    status = Search(segment, plan, placeholderGroup, {}, &res3);
+    status = Search(
+        segment, plan, placeholderGroup, {}, dataset.timestamps_[0], &res3);
     ASSERT_EQ(status.error_code, Success);
     results.push_back(res1);
     results.push_back(res2);
@@ -1666,9 +1678,11 @@ testReduceSearchWithExpr(int N,
     std::vector results;
     CSearchResult res1;
    CSearchResult res2;
-    auto res = Search(segment, plan, placeholderGroup, {}, &res1);
+    auto res = Search(
+        segment, plan, placeholderGroup, {}, dataset.timestamps_[N - 1], &res1);
     ASSERT_EQ(res.error_code, Success);
-    res = Search(segment, plan, placeholderGroup, {}, &res2);
+    res = Search(
+        segment, plan, placeholderGroup, {}, dataset.timestamps_[N - 1], &res2);
     ASSERT_EQ(res.error_code, Success);
     results.push_back(res1);
     results.push_back(res2);
@@ -1900,9 +1914,15 @@ TEST(CApiTest, Indexing_Without_Predicate) {
     std::vector placeholderGroups;
     placeholderGroups.push_back(placeholderGroup);
 
+    Timestamp timestamp = 10000000;
+
     CSearchResult c_search_result_on_smallIndex;
-    auto res_before_load_index = Search(
-        segment, plan, placeholderGroup, {}, &c_search_result_on_smallIndex);
+    auto res_before_load_index = Search(segment,
+                                        plan,
+                                        placeholderGroup,
+                                        {},
+                                        timestamp,
+                                        &c_search_result_on_smallIndex);
     ASSERT_EQ(res_before_load_index.error_code, Success);
 
     // load index to segment
@@ -1962,6 +1982,7 @@ TEST(CApiTest, Indexing_Without_Predicate) {
                                        plan,
                                        placeholderGroup,
                                        {},
+                                       timestamp,
                                        &c_search_result_on_bigIndex);
     ASSERT_EQ(res_after_load_index.error_code, Success);
 
@@ -2044,9 +2065,15 @@ TEST(CApiTest, Indexing_Expr_Without_Predicate) {
     std::vector placeholderGroups;
     placeholderGroups.push_back(placeholderGroup);
 
+    Timestamp timestamp = 10000000;
+
     CSearchResult c_search_result_on_smallIndex;
-    auto res_before_load_index = Search(
-        segment, plan, placeholderGroup, {}, &c_search_result_on_smallIndex);
ASSERT_EQ(res_before_load_index.error_code, Success); // load index to segment @@ -2107,6 +2134,7 @@ TEST(CApiTest, Indexing_Expr_Without_Predicate) { plan, placeholderGroup, {}, + timestamp, &c_search_result_on_bigIndex); ASSERT_EQ(res_after_load_index.error_code, Success); @@ -2217,10 +2245,15 @@ TEST(CApiTest, Indexing_With_float_Predicate_Range) { std::vector placeholderGroups; placeholderGroups.push_back(placeholderGroup); + Timestamp timestamp = 10000000; CSearchResult c_search_result_on_smallIndex; - auto res_before_load_index = Search( - segment, plan, placeholderGroup, {}, &c_search_result_on_smallIndex); + auto res_before_load_index = Search(segment, + plan, + placeholderGroup, + {}, + timestamp, + &c_search_result_on_smallIndex); ASSERT_EQ(res_before_load_index.error_code, Success); // load index to segment @@ -2281,6 +2314,7 @@ TEST(CApiTest, Indexing_With_float_Predicate_Range) { plan, placeholderGroup, {}, + timestamp, &c_search_result_on_bigIndex); ASSERT_EQ(res_after_load_index.error_code, Success); @@ -2393,10 +2427,15 @@ TEST(CApiTest, Indexing_Expr_With_float_Predicate_Range) { std::vector placeholderGroups; placeholderGroups.push_back(placeholderGroup); + Timestamp timestamp = 10000000; CSearchResult c_search_result_on_smallIndex; - auto res_before_load_index = Search( - segment, plan, placeholderGroup, {}, &c_search_result_on_smallIndex); + auto res_before_load_index = Search(segment, + plan, + placeholderGroup, + {}, + timestamp, + &c_search_result_on_smallIndex); ASSERT_EQ(res_before_load_index.error_code, Success); // load index to segment @@ -2457,6 +2496,7 @@ TEST(CApiTest, Indexing_Expr_With_float_Predicate_Range) { plan, placeholderGroup, {}, + timestamp, &c_search_result_on_bigIndex); ASSERT_EQ(res_after_load_index.error_code, Success); @@ -2561,10 +2601,15 @@ TEST(CApiTest, Indexing_With_float_Predicate_Term) { std::vector placeholderGroups; placeholderGroups.push_back(placeholderGroup); + Timestamp timestamp = 10000000; CSearchResult c_search_result_on_smallIndex; - auto res_before_load_index = Search( - segment, plan, placeholderGroup, {}, &c_search_result_on_smallIndex); + auto res_before_load_index = Search(segment, + plan, + placeholderGroup, + {}, + timestamp, + &c_search_result_on_smallIndex); ASSERT_EQ(res_before_load_index.error_code, Success); // load index to segment @@ -2625,6 +2670,7 @@ TEST(CApiTest, Indexing_With_float_Predicate_Term) { plan, placeholderGroup, {}, + timestamp, &c_search_result_on_bigIndex); ASSERT_EQ(res_after_load_index.error_code, Success); @@ -2730,10 +2776,15 @@ TEST(CApiTest, Indexing_Expr_With_float_Predicate_Term) { std::vector placeholderGroups; placeholderGroups.push_back(placeholderGroup); + Timestamp timestamp = 10000000; CSearchResult c_search_result_on_smallIndex; - auto res_before_load_index = Search( - segment, plan, placeholderGroup, {}, &c_search_result_on_smallIndex); + auto res_before_load_index = Search(segment, + plan, + placeholderGroup, + {}, + timestamp, + &c_search_result_on_smallIndex); ASSERT_EQ(res_before_load_index.error_code, Success); // load index to segment @@ -2794,6 +2845,7 @@ TEST(CApiTest, Indexing_Expr_With_float_Predicate_Term) { plan, placeholderGroup, {}, + timestamp, &c_search_result_on_bigIndex); ASSERT_EQ(res_after_load_index.error_code, Success); @@ -2904,10 +2956,15 @@ TEST(CApiTest, Indexing_With_binary_Predicate_Range) { std::vector placeholderGroups; placeholderGroups.push_back(placeholderGroup); + Timestamp timestamp = 10000000; CSearchResult c_search_result_on_smallIndex; - auto 
res_before_load_index = Search( - segment, plan, placeholderGroup, {}, &c_search_result_on_smallIndex); + auto res_before_load_index = Search(segment, + plan, + placeholderGroup, + {}, + timestamp, + &c_search_result_on_smallIndex); ASSERT_EQ(res_before_load_index.error_code, Success); // load index to segment @@ -2969,6 +3026,7 @@ TEST(CApiTest, Indexing_With_binary_Predicate_Range) { plan, placeholderGroup, {}, + timestamp, &c_search_result_on_bigIndex); ASSERT_EQ(res_after_load_index.error_code, Success); @@ -3079,10 +3137,15 @@ TEST(CApiTest, Indexing_Expr_With_binary_Predicate_Range) { std::vector placeholderGroups; placeholderGroups.push_back(placeholderGroup); + Timestamp timestamp = 10000000; CSearchResult c_search_result_on_smallIndex; - auto res_before_load_index = Search( - segment, plan, placeholderGroup, {}, &c_search_result_on_smallIndex); + auto res_before_load_index = Search(segment, + plan, + placeholderGroup, + {}, + timestamp, + &c_search_result_on_smallIndex); ASSERT_TRUE(res_before_load_index.error_code == Success) << res_before_load_index.error_msg; @@ -3144,6 +3207,7 @@ TEST(CApiTest, Indexing_Expr_With_binary_Predicate_Range) { plan, placeholderGroup, {}, + timestamp, &c_search_result_on_bigIndex); ASSERT_EQ(res_after_load_index.error_code, Success); @@ -3249,10 +3313,15 @@ TEST(CApiTest, Indexing_With_binary_Predicate_Term) { std::vector placeholderGroups; placeholderGroups.push_back(placeholderGroup); + Timestamp timestamp = 10000000; CSearchResult c_search_result_on_smallIndex; - auto res_before_load_index = Search( - segment, plan, placeholderGroup, {}, &c_search_result_on_smallIndex); + auto res_before_load_index = Search(segment, + plan, + placeholderGroup, + {}, + timestamp, + &c_search_result_on_smallIndex); ASSERT_EQ(res_before_load_index.error_code, Success); // load index to segment @@ -3313,6 +3382,7 @@ TEST(CApiTest, Indexing_With_binary_Predicate_Term) { plan, placeholderGroup, {}, + timestamp, &c_search_result_on_bigIndex); ASSERT_EQ(res_after_load_index.error_code, Success); @@ -3440,11 +3510,15 @@ TEST(CApiTest, Indexing_Expr_With_binary_Predicate_Term) { std::vector placeholderGroups; placeholderGroups.push_back(placeholderGroup); - Timestamp time = 10000000; + Timestamp timestamp = 10000000; CSearchResult c_search_result_on_smallIndex; - auto res_before_load_index = Search( - segment, plan, placeholderGroup, {}, &c_search_result_on_smallIndex); + auto res_before_load_index = Search(segment, + plan, + placeholderGroup, + {}, + timestamp, + &c_search_result_on_smallIndex); ASSERT_EQ(res_before_load_index.error_code, Success); // load index to segment @@ -3505,6 +3579,7 @@ TEST(CApiTest, Indexing_Expr_With_binary_Predicate_Term) { plan, placeholderGroup, {}, + timestamp, &c_search_result_on_bigIndex); ASSERT_EQ(res_after_load_index.error_code, Success); @@ -3643,7 +3718,7 @@ TEST(CApiTest, SealedSegment_search_float_Predicate_Range) { std::vector placeholderGroups; placeholderGroups.push_back(placeholderGroup); - Timestamp time = 10000000; + Timestamp timestamp = 10000000; // load index to segment auto indexing = generate_index(vec_col.data(), @@ -3702,6 +3777,7 @@ TEST(CApiTest, SealedSegment_search_float_Predicate_Range) { plan, placeholderGroup, {}, + timestamp, &c_search_result_on_bigIndex); ASSERT_EQ(res_after_load_index.error_code, Success); @@ -3780,12 +3856,14 @@ TEST(CApiTest, SealedSegment_search_without_predicates) { std::vector placeholderGroups; placeholderGroups.push_back(placeholderGroup); CSearchResult search_result; - auto res = 
Search(segment, plan, placeholderGroup, {}, &search_result); + auto res = Search( + segment, plan, placeholderGroup, {}, N + ts_offset, &search_result); std::cout << res.error_msg << std::endl; ASSERT_EQ(res.error_code, Success); CSearchResult search_result2; - auto res2 = Search(segment, plan, placeholderGroup, {}, &search_result2); + auto res2 = Search( + segment, plan, placeholderGroup, {}, N + ts_offset, &search_result2); ASSERT_EQ(res2.error_code, Success); DeleteSearchPlan(plan); @@ -3874,6 +3952,7 @@ TEST(CApiTest, SealedSegment_search_float_With_Expr_Predicate_Range) { std::vector placeholderGroups; placeholderGroups.push_back(placeholderGroup); + Timestamp timestamp = 10000000; // load index to segment auto indexing = generate_index(vec_col.data(), @@ -3933,8 +4012,12 @@ TEST(CApiTest, SealedSegment_search_float_With_Expr_Predicate_Range) { } CSearchResult c_search_result_on_bigIndex; - auto res_after_load_index = Search( - segment, plan, placeholderGroup, {}, &c_search_result_on_bigIndex); + auto res_after_load_index = Search(segment, + plan, + placeholderGroup, + {}, + timestamp, + &c_search_result_on_bigIndex); ASSERT_EQ(res_after_load_index.error_code, Success); auto search_result_on_bigIndex = (SearchResult*)c_search_result_on_bigIndex; @@ -4230,7 +4313,8 @@ TEST(CApiTest, RANGE_SEARCH_WITH_RADIUS_WHEN_IP) { placeholderGroups.push_back(placeholderGroup); CSearchResult search_result; - auto res = Search(segment, plan, placeholderGroup, {}, &search_result); + auto res = + Search(segment, plan, placeholderGroup, {}, ts_offset, &search_result); ASSERT_EQ(res.error_code, Success); DeleteSearchPlan(plan); @@ -4293,7 +4377,8 @@ TEST(CApiTest, RANGE_SEARCH_WITH_RADIUS_AND_RANGE_FILTER_WHEN_IP) { placeholderGroups.push_back(placeholderGroup); CSearchResult search_result; - auto res = Search(segment, plan, placeholderGroup, {}, &search_result); + auto res = + Search(segment, plan, placeholderGroup, {}, ts_offset, &search_result); ASSERT_EQ(res.error_code, Success); DeleteSearchPlan(plan); @@ -4356,7 +4441,8 @@ TEST(CApiTest, RANGE_SEARCH_WITH_RADIUS_WHEN_L2) { placeholderGroups.push_back(placeholderGroup); CSearchResult search_result; - auto res = Search(segment, plan, placeholderGroup, {}, &search_result); + auto res = + Search(segment, plan, placeholderGroup, {}, ts_offset, &search_result); ASSERT_EQ(res.error_code, Success); DeleteSearchPlan(plan); @@ -4419,7 +4505,8 @@ TEST(CApiTest, RANGE_SEARCH_WITH_RADIUS_AND_RANGE_FILTER_WHEN_L2) { placeholderGroups.push_back(placeholderGroup); CSearchResult search_result; - auto res = Search(segment, plan, placeholderGroup, {}, &search_result); + auto res = + Search(segment, plan, placeholderGroup, {}, ts_offset, &search_result); ASSERT_EQ(res.error_code, Success); DeleteSearchPlan(plan); diff --git a/internal/core/unittest/test_float16.cpp b/internal/core/unittest/test_float16.cpp index f41f6f9cebbbb..81674884ef045 100644 --- a/internal/core/unittest/test_float16.cpp +++ b/internal/core/unittest/test_float16.cpp @@ -154,7 +154,7 @@ TEST(Float16, ExecWithoutPredicateFlat) { auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); - auto sr = segment->Search(plan.get(), ph_group.get()); + auto sr = segment->Search(plan.get(), ph_group.get(), 1L << 63); int topk = 5; query::Json json = SearchResultToJson(*sr); @@ -392,7 +392,7 @@ TEST(Float16, ExecWithPredicate) { auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); - auto sr = segment->Search(plan.get(), ph_group.get()); + auto sr = 
segment->Search(plan.get(), ph_group.get(), 1L << 63); int topk = 5; query::Json json = SearchResultToJson(*sr); diff --git a/internal/core/unittest/test_group_by.cpp b/internal/core/unittest/test_group_by.cpp index 3e8ef859edff8..50dc8c1ecf194 100644 --- a/internal/core/unittest/test_group_by.cpp +++ b/internal/core/unittest/test_group_by.cpp @@ -1,7 +1,17 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. // -// Created by zilliz on 2023/12/1. +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 // +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License +// +// Created by zilliz on 2023/12/1. +// #include #include "common/Schema.h" @@ -20,29 +30,29 @@ using namespace milvus::storage; const char* METRICS_TYPE = "metric_type"; - void prepareSegmentSystemFieldData(const std::unique_ptr& segment, size_t row_count, - GeneratedData& data_set){ + GeneratedData& data_set) { auto field_data = - std::make_shared>(DataType::INT64); + std::make_shared>(DataType::INT64); field_data->FillFieldData(data_set.row_ids_.data(), row_count); - auto field_data_info = FieldDataInfo{ - RowFieldID.get(), row_count, std::vector{field_data}}; + auto field_data_info = + FieldDataInfo{RowFieldID.get(), + row_count, + std::vector{field_data}}; segment->LoadFieldData(RowFieldID, field_data_info); - field_data = - std::make_shared>(DataType::INT64); + field_data = std::make_shared>(DataType::INT64); field_data->FillFieldData(data_set.timestamps_.data(), row_count); field_data_info = - FieldDataInfo{TimestampFieldID.get(), - row_count, - std::vector{field_data}}; + FieldDataInfo{TimestampFieldID.get(), + row_count, + std::vector{field_data}}; segment->LoadFieldData(TimestampFieldID, field_data_info); } -TEST(GroupBY, Normal2){ +TEST(GroupBY, Normal2) { using namespace milvus; using namespace milvus::query; using namespace milvus::segcore; @@ -51,7 +61,7 @@ TEST(GroupBY, Normal2){ int dim = 64; auto schema = std::make_shared(); auto vec_fid = schema->AddDebugField( - "fakevec", DataType::VECTOR_FLOAT, dim, knowhere::metric::L2); + "fakevec", DataType::VECTOR_FLOAT, dim, knowhere::metric::L2); auto int8_fid = schema->AddDebugField("int8", DataType::INT8); auto int16_fid = schema->AddDebugField("int16", DataType::INT16); auto int32_fid = schema->AddDebugField("int32", DataType::INT32); @@ -71,7 +81,7 @@ TEST(GroupBY, Normal2){ auto info = FieldDataInfo(field_data.field_id(), N); auto field_meta = fields.at(FieldId(field_id)); info.channel->push( - CreateFieldDataFromDataArray(N, &field_data, field_meta)); + CreateFieldDataFromDataArray(N, &field_data, field_meta)); info.channel->close(); segment->LoadFieldData(FieldId(field_id), info); @@ -80,7 +90,8 @@ TEST(GroupBY, Normal2){ //3. 
load index auto vector_data = raw_data.get_col(vec_fid); - auto indexing = GenVecIndexing(N, dim, vector_data.data(), knowhere::IndexEnum::INDEX_HNSW); + auto indexing = GenVecIndexing( + N, dim, vector_data.data(), knowhere::IndexEnum::INDEX_HNSW); LoadIndexInfo load_index_info; load_index_info.field_id = vec_fid.get(); load_index_info.index = std::move(indexing); @@ -102,26 +113,34 @@ TEST(GroupBY, Normal2){ >)"; auto plan_str = translate_text_plan_to_binary_plan(raw_plan); - auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); auto num_queries = 1; auto seed = 1024; auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); - auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); - auto search_result = segment->Search(plan.get(), ph_group.get()); + auto ph_group = + ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); + auto search_result = + segment->Search(plan.get(), ph_group.get(), 1L << 63); auto& group_by_values = search_result->group_by_values_; - ASSERT_EQ(search_result->group_by_values_.size(), search_result->seg_offsets_.size()); - ASSERT_EQ(search_result->distances_.size(), search_result->seg_offsets_.size()); + ASSERT_EQ(search_result->group_by_values_.size(), + search_result->seg_offsets_.size()); + ASSERT_EQ(search_result->distances_.size(), + search_result->seg_offsets_.size()); int size = group_by_values.size(); std::unordered_set i8_set; float lastDistance = 0.0; - for(size_t i = 0; i < size; i++){ - if(std::holds_alternative(group_by_values[i])){ + for (size_t i = 0; i < size; i++) { + if (std::holds_alternative(group_by_values[i])) { int8_t g_val = std::get(group_by_values[i]); - ASSERT_FALSE(i8_set.count(g_val)>0);//no repetition on groupBy field + ASSERT_FALSE(i8_set.count(g_val) > + 0); //no repetition on groupBy field i8_set.insert(g_val); auto distance = search_result->distances_.at(i); - ASSERT_TRUE(lastDistance<=distance);//distance should be decreased as metrics_type is L2 + ASSERT_TRUE( + lastDistance <= + distance); //distance should be decreased as metrics_type is L2 lastDistance = distance; } else { //check padding @@ -146,26 +165,34 @@ TEST(GroupBY, Normal2){ >)"; auto plan_str = translate_text_plan_to_binary_plan(raw_plan); - auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); auto num_queries = 1; auto seed = 1024; auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); - auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); - auto search_result = segment->Search(plan.get(), ph_group.get()); + auto ph_group = + ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); + auto search_result = + segment->Search(plan.get(), ph_group.get(), 1L << 63); auto& group_by_values = search_result->group_by_values_; - ASSERT_EQ(search_result->group_by_values_.size(), search_result->seg_offsets_.size()); - ASSERT_EQ(search_result->distances_.size(), search_result->seg_offsets_.size()); + ASSERT_EQ(search_result->group_by_values_.size(), + search_result->seg_offsets_.size()); + ASSERT_EQ(search_result->distances_.size(), + search_result->seg_offsets_.size()); int size = group_by_values.size(); std::unordered_set i16_set; float lastDistance = 0.0; - for(size_t i = 0; i < size; i++){ - if(std::holds_alternative(group_by_values[i])){ + for (size_t i 
= 0; i < size; i++) { + if (std::holds_alternative(group_by_values[i])) { int16_t g_val = std::get(group_by_values[i]); - ASSERT_FALSE(i16_set.count(g_val)>0);//no repetition on groupBy field + ASSERT_FALSE(i16_set.count(g_val) > + 0); //no repetition on groupBy field i16_set.insert(g_val); auto distance = search_result->distances_.at(i); - ASSERT_TRUE(lastDistance<=distance);//distance should be decreased as metrics_type is L2 + ASSERT_TRUE( + lastDistance <= + distance); //distance should be decreased as metrics_type is L2 lastDistance = distance; } else { //check padding @@ -190,26 +217,34 @@ TEST(GroupBY, Normal2){ >)"; auto plan_str = translate_text_plan_to_binary_plan(raw_plan); - auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); auto num_queries = 1; auto seed = 1024; auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); - auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); - auto search_result = segment->Search(plan.get(), ph_group.get()); + auto ph_group = + ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); + auto search_result = + segment->Search(plan.get(), ph_group.get(), 1L << 63); auto& group_by_values = search_result->group_by_values_; - ASSERT_EQ(search_result->group_by_values_.size(), search_result->seg_offsets_.size()); - ASSERT_EQ(search_result->distances_.size(), search_result->seg_offsets_.size()); + ASSERT_EQ(search_result->group_by_values_.size(), + search_result->seg_offsets_.size()); + ASSERT_EQ(search_result->distances_.size(), + search_result->seg_offsets_.size()); int size = group_by_values.size(); std::unordered_set i32_set; float lastDistance = 0.0; - for(size_t i = 0; i < size; i++){ - if(std::holds_alternative(group_by_values[i])){ + for (size_t i = 0; i < size; i++) { + if (std::holds_alternative(group_by_values[i])) { int16_t g_val = std::get(group_by_values[i]); - ASSERT_FALSE(i32_set.count(g_val)>0);//no repetition on groupBy field + ASSERT_FALSE(i32_set.count(g_val) > + 0); //no repetition on groupBy field i32_set.insert(g_val); auto distance = search_result->distances_.at(i); - ASSERT_TRUE(lastDistance<=distance);//distance should be decreased as metrics_type is L2 + ASSERT_TRUE( + lastDistance <= + distance); //distance should be decreased as metrics_type is L2 lastDistance = distance; } else { //check padding @@ -234,26 +269,34 @@ TEST(GroupBY, Normal2){ >)"; auto plan_str = translate_text_plan_to_binary_plan(raw_plan); - auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); auto num_queries = 1; auto seed = 1024; auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); - auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); - auto search_result = segment->Search(plan.get(), ph_group.get()); + auto ph_group = + ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); + auto search_result = + segment->Search(plan.get(), ph_group.get(), 1L << 63); auto& group_by_values = search_result->group_by_values_; - ASSERT_EQ(search_result->group_by_values_.size(), search_result->seg_offsets_.size()); - ASSERT_EQ(search_result->distances_.size(), search_result->seg_offsets_.size()); + ASSERT_EQ(search_result->group_by_values_.size(), + search_result->seg_offsets_.size()); + ASSERT_EQ(search_result->distances_.size(), + 
search_result->seg_offsets_.size()); int size = group_by_values.size(); std::unordered_set i64_set; float lastDistance = 0.0; - for(size_t i = 0; i < size; i++){ - if(std::holds_alternative(group_by_values[i])){ + for (size_t i = 0; i < size; i++) { + if (std::holds_alternative(group_by_values[i])) { int16_t g_val = std::get(group_by_values[i]); - ASSERT_FALSE(i64_set.count(g_val)>0);//no repetition on groupBy field + ASSERT_FALSE(i64_set.count(g_val) > + 0); //no repetition on groupBy field i64_set.insert(g_val); auto distance = search_result->distances_.at(i); - ASSERT_TRUE(lastDistance<=distance);//distance should be decreased as metrics_type is L2 + ASSERT_TRUE( + lastDistance <= + distance); //distance should be decreased as metrics_type is L2 lastDistance = distance; } else { //check padding @@ -278,26 +321,35 @@ TEST(GroupBY, Normal2){ >)"; auto plan_str = translate_text_plan_to_binary_plan(raw_plan); - auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); auto num_queries = 1; auto seed = 1024; auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); - auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); - auto search_result = segment->Search(plan.get(), ph_group.get()); + auto ph_group = + ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); + auto search_result = + segment->Search(plan.get(), ph_group.get(), 1L << 63); auto& group_by_values = search_result->group_by_values_; - ASSERT_EQ(search_result->group_by_values_.size(), search_result->seg_offsets_.size()); - ASSERT_EQ(search_result->distances_.size(), search_result->seg_offsets_.size()); + ASSERT_EQ(search_result->group_by_values_.size(), + search_result->seg_offsets_.size()); + ASSERT_EQ(search_result->distances_.size(), + search_result->seg_offsets_.size()); int size = group_by_values.size(); std::unordered_set strs_set; float lastDistance = 0.0; - for(size_t i = 0; i < size; i++){ - if(std::holds_alternative(group_by_values[i])){ - std::string_view g_val = std::get(group_by_values[i]); - ASSERT_FALSE(strs_set.count(g_val)>0);//no repetition on groupBy field + for (size_t i = 0; i < size; i++) { + if (std::holds_alternative(group_by_values[i])) { + std::string_view g_val = + std::get(group_by_values[i]); + ASSERT_FALSE(strs_set.count(g_val) > + 0); //no repetition on groupBy field strs_set.insert(g_val); auto distance = search_result->distances_.at(i); - ASSERT_TRUE(lastDistance<=distance);//distance should be decreased as metrics_type is L2 + ASSERT_TRUE( + lastDistance <= + distance); //distance should be decreased as metrics_type is L2 lastDistance = distance; } else { //check padding @@ -322,40 +374,48 @@ TEST(GroupBY, Normal2){ >)"; auto plan_str = translate_text_plan_to_binary_plan(raw_plan); - auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); auto num_queries = 1; auto seed = 1024; auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); - auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); - auto search_result = segment->Search(plan.get(), ph_group.get()); + auto ph_group = + ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); + auto search_result = + segment->Search(plan.get(), ph_group.get(), 1L << 63); auto& group_by_values = search_result->group_by_values_; - 
ASSERT_EQ(search_result->group_by_values_.size(), search_result->seg_offsets_.size()); - ASSERT_EQ(search_result->distances_.size(), search_result->seg_offsets_.size()); + ASSERT_EQ(search_result->group_by_values_.size(), + search_result->seg_offsets_.size()); + ASSERT_EQ(search_result->distances_.size(), + search_result->seg_offsets_.size()); int size = group_by_values.size(); std::unordered_set bools_set; int boolValCount = 0; float lastDistance = 0.0; - for(size_t i = 0; i < size; i++){ - if(std::holds_alternative(group_by_values[i])){ + for (size_t i = 0; i < size; i++) { + if (std::holds_alternative(group_by_values[i])) { bool g_val = std::get(group_by_values[i]); - ASSERT_FALSE(bools_set.count(g_val)>0);//no repetition on groupBy field + ASSERT_FALSE(bools_set.count(g_val) > + 0); //no repetition on groupBy field bools_set.insert(g_val); - boolValCount+=1; + boolValCount += 1; auto distance = search_result->distances_.at(i); - ASSERT_TRUE(lastDistance<=distance);//distance should be decreased as metrics_type is L2 + ASSERT_TRUE( + lastDistance <= + distance); //distance should be decreased as metrics_type is L2 lastDistance = distance; } else { //check padding ASSERT_EQ(search_result->seg_offsets_[i], INVALID_SEG_OFFSET); ASSERT_EQ(search_result->distances_[i], 0.0); } - ASSERT_TRUE(boolValCount<=2);//bool values cannot exceed two + ASSERT_TRUE(boolValCount <= 2); //bool values cannot exceed two } } } -TEST(GroupBY, Reduce){ +TEST(GroupBY, Reduce) { using namespace milvus; using namespace milvus::query; using namespace milvus::segcore; @@ -364,7 +424,7 @@ TEST(GroupBY, Reduce){ int dim = 64; auto schema = std::make_shared(); auto vec_fid = schema->AddDebugField( - "fakevec", DataType::VECTOR_FLOAT, dim, knowhere::metric::L2); + "fakevec", DataType::VECTOR_FLOAT, dim, knowhere::metric::L2); auto int64_fid = schema->AddDebugField("int64", DataType::INT64); schema->set_primary_field_id(int64_fid); auto segment1 = CreateSealedSegment(schema); @@ -386,7 +446,7 @@ TEST(GroupBY, Reduce){ auto info = FieldDataInfo(field_data.field_id(), N); auto field_meta = fields.at(FieldId(field_id)); info.channel->push( - CreateFieldDataFromDataArray(N, &field_data, field_meta)); + CreateFieldDataFromDataArray(N, &field_data, field_meta)); info.channel->close(); segment1->LoadFieldData(FieldId(field_id), info); } @@ -398,7 +458,7 @@ TEST(GroupBY, Reduce){ auto info = FieldDataInfo(field_data.field_id(), N); auto field_meta = fields.at(FieldId(field_id)); info.channel->push( - CreateFieldDataFromDataArray(N, &field_data, field_meta)); + CreateFieldDataFromDataArray(N, &field_data, field_meta)); info.channel->close(); segment2->LoadFieldData(FieldId(field_id), info); } @@ -406,7 +466,8 @@ TEST(GroupBY, Reduce){ //3. 
load index auto vector_data_1 = raw_data1.get_col(vec_fid); - auto indexing_1 = GenVecIndexing(N, dim, vector_data_1.data(), knowhere::IndexEnum::INDEX_HNSW); + auto indexing_1 = GenVecIndexing( + N, dim, vector_data_1.data(), knowhere::IndexEnum::INDEX_HNSW); LoadIndexInfo load_index_info_1; load_index_info_1.field_id = vec_fid.get(); load_index_info_1.index = std::move(indexing_1); @@ -414,14 +475,14 @@ TEST(GroupBY, Reduce){ segment1->LoadIndex(load_index_info_1); auto vector_data_2 = raw_data2.get_col(vec_fid); - auto indexing_2 = GenVecIndexing(N, dim, vector_data_2.data(), knowhere::IndexEnum::INDEX_HNSW); + auto indexing_2 = GenVecIndexing( + N, dim, vector_data_2.data(), knowhere::IndexEnum::INDEX_HNSW); LoadIndexInfo load_index_info_2; load_index_info_2.field_id = vec_fid.get(); load_index_info_2.index = std::move(indexing_2); load_index_info_2.index_params[METRICS_TYPE] = knowhere::metric::L2; segment2->LoadIndex(load_index_info_2); - //4. search group by respectively const char* raw_plan = R"(vector_anns: < field_id: 100 @@ -435,11 +496,13 @@ TEST(GroupBY, Reduce){ >)"; auto plan_str = translate_text_plan_to_binary_plan(raw_plan); - auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); auto num_queries = 10; auto topK = 100; auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); - auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); + auto ph_group = + ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); CPlaceholderGroup c_ph_group = ph_group.release(); CSearchPlan c_plan = plan.release(); @@ -447,9 +510,11 @@ TEST(GroupBY, Reduce){ CSegmentInterface c_segment_2 = segment2.release(); CSearchResult c_search_res_1; CSearchResult c_search_res_2; - auto status = Search(c_segment_1, c_plan, c_ph_group, {}, &c_search_res_1); + auto status = + Search(c_segment_1, c_plan, c_ph_group, {}, 1L << 63, &c_search_res_1); ASSERT_EQ(status.error_code, Success); - status = Search(c_segment_2, c_plan, c_ph_group, {}, &c_search_res_2); + status = + Search(c_segment_2, c_plan, c_ph_group, {}, 1L << 63, &c_search_res_2); ASSERT_EQ(status.error_code, Success); std::vector results; results.push_back(c_search_res_1); @@ -458,23 +523,20 @@ TEST(GroupBY, Reduce){ auto slice_nqs = std::vector{num_queries / 2, num_queries / 2}; auto slice_topKs = std::vector{topK / 2, topK}; CSearchResultDataBlobs cSearchResultData; - status = ReduceSearchResultsAndFillData( - &cSearchResultData, - c_plan, - results.data(), - results.size(), - slice_nqs.data(), - slice_topKs.data(), - slice_nqs.size() - ); + status = ReduceSearchResultsAndFillData(&cSearchResultData, + c_plan, + results.data(), + results.size(), + slice_nqs.data(), + slice_topKs.data(), + slice_nqs.size()); CheckSearchResultDuplicate(results); DeleteSearchResult(c_search_res_1); DeleteSearchResult(c_search_res_2); DeleteSearchResultDataBlobs(cSearchResultData); - DeleteSearchPlan(c_plan); DeletePlaceholderGroup(c_ph_group); DeleteSegment(c_segment_1); DeleteSegment(c_segment_2); -} \ No newline at end of file +} diff --git a/internal/core/unittest/test_growing_index.cpp b/internal/core/unittest/test_growing_index.cpp index 3666dc7cf0ddd..54f09c4474044 100644 --- a/internal/core/unittest/test_growing_index.cpp +++ b/internal/core/unittest/test_growing_index.cpp @@ -101,7 +101,9 @@ TEST(GrowingIndex, Correctness) { *schema, plan_str.data(), plan_str.size()); auto ph_group = 
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); - auto sr = segment->Search(plan.get(), ph_group.get()); + + Timestamp timestamp = 1000000; + auto sr = segment->Search(plan.get(), ph_group.get(), timestamp); EXPECT_EQ(sr->total_nq_, num_queries); EXPECT_EQ(sr->unity_topK_, top_k); EXPECT_EQ(sr->distances_.size(), num_queries * top_k); @@ -111,7 +113,8 @@ TEST(GrowingIndex, Correctness) { *schema, range_plan_str.data(), range_plan_str.size()); auto range_ph_group = ParsePlaceholderGroup( range_plan.get(), ph_group_raw.SerializeAsString()); - auto range_sr = segment->Search(range_plan.get(), range_ph_group.get()); + auto range_sr = + segment->Search(range_plan.get(), range_ph_group.get(), timestamp); ASSERT_EQ(range_sr->total_nq_, num_queries); EXPECT_EQ(sr->unity_topK_, top_k); EXPECT_EQ(sr->distances_.size(), num_queries * top_k); diff --git a/internal/core/unittest/test_query.cpp b/internal/core/unittest/test_query.cpp index 97fba1187ba31..2e0223126935b 100644 --- a/internal/core/unittest/test_query.cpp +++ b/internal/core/unittest/test_query.cpp @@ -128,8 +128,9 @@ TEST(Query, ExecWithPredicateLoader) { auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024); auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); + Timestamp timestamp = 1000000; - auto sr = segment->Search(plan.get(), ph_group.get()); + auto sr = segment->Search(plan.get(), ph_group.get(), timestamp); query::Json json = SearchResultToJson(*sr); #ifdef __linux__ @@ -212,7 +213,9 @@ TEST(Query, ExecWithPredicateSmallN) { auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); - auto sr = segment->Search(plan.get(), ph_group.get()); + Timestamp timestamp = 1000000; + + auto sr = segment->Search(plan.get(), ph_group.get(), timestamp); query::Json json = SearchResultToJson(*sr); std::cout << json.dump(2); @@ -270,8 +273,9 @@ TEST(Query, ExecWithPredicate) { auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024); auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); + Timestamp timestamp = 1000000; - auto sr = segment->Search(plan.get(), ph_group.get()); + auto sr = segment->Search(plan.get(), ph_group.get(), timestamp); query::Json json = SearchResultToJson(*sr); #ifdef __linux__ @@ -351,8 +355,9 @@ TEST(Query, ExecTerm) { auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024); auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); + Timestamp timestamp = 1000000; - auto sr = segment->Search(plan.get(), ph_group.get()); + auto sr = segment->Search(plan.get(), ph_group.get(), timestamp); int topk = 5; auto json = SearchResultToJson(*sr); ASSERT_EQ(sr->total_nq_, num_queries); @@ -386,7 +391,8 @@ TEST(Query, ExecEmpty) { auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); - auto sr = segment->Search(plan.get(), ph_group.get()); + Timestamp timestamp = 1000000; + auto sr = segment->Search(plan.get(), ph_group.get(), timestamp); std::cout << SearchResultToJson(*sr); ASSERT_EQ(sr->unity_topK_, 0); @@ -434,8 +440,8 @@ TEST(Query, ExecWithoutPredicateFlat) { auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024); auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); - - auto sr = segment->Search(plan.get(), ph_group.get()); + Timestamp timestamp = 1000000; + auto sr = segment->Search(plan.get(), ph_group.get(), timestamp); std::vector> results; auto json = SearchResultToJson(*sr); 
     std::cout << json.dump(2);
@@ -477,8 +483,9 @@ TEST(Query, ExecWithoutPredicate) {
     auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
     auto ph_group =
         ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
+    Timestamp timestamp = 1000000;
-    auto sr = segment->Search(plan.get(), ph_group.get());
+    auto sr = segment->Search(plan.get(), ph_group.get(), timestamp);
     assert_order(*sr, "l2");
     std::vector<std::vector<std::string>> results;
     auto json = SearchResultToJson(*sr);
@@ -546,7 +553,9 @@ TEST(Query, InnerProduct) {
         CreatePlaceholderGroupFromBlob(num_queries, 16, col.data());
     auto ph_group =
         ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
-    auto sr = segment->Search(plan.get(), ph_group.get());
+
+    Timestamp ts = N * 2;
+    auto sr = segment->Search(plan.get(), ph_group.get(), ts);
     assert_order(*sr, "ip");
 }
@@ -633,6 +642,8 @@ TEST(Query, FillSegment) {
         CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
     auto ph_proto = CreatePlaceholderGroup(10, 16, 443);
     auto ph = ParsePlaceholderGroup(plan.get(), ph_proto.SerializeAsString());
+    Timestamp ts = N * 2UL;
+
     auto topk = 5;
     auto num_queries = 10;
@@ -642,7 +653,7 @@ TEST(Query, FillSegment) {
             schema->get_field_id(FieldName("fakevec")));
         plan->target_entries_.push_back(
             schema->get_field_id(FieldName("the_value")));
-        auto result = segment->Search(plan.get(), ph.get());
+        auto result = segment->Search(plan.get(), ph.get(), ts);
         result->result_offsets_.resize(topk * num_queries);
         segment->FillTargetEntry(plan.get(), *result);
         segment->FillPrimaryKeys(plan.get(), *result);
@@ -746,7 +757,9 @@ TEST(Query, ExecWithPredicateBinary) {
         num_queries, 512, vec_ptr.data() + 1024 * 512 / 8);
     auto ph_group =
         ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
-    auto sr = segment->Search(plan.get(), ph_group.get());
+
+    Timestamp timestamp = 1000000;
+    auto sr = segment->Search(plan.get(), ph_group.get(), timestamp);
     query::Json json = SearchResultToJson(*sr);
     std::cout << json.dump(2);
diff --git a/internal/core/unittest/test_sealed.cpp b/internal/core/unittest/test_sealed.cpp
index e768dfe08a355..9361408385768 100644
--- a/internal/core/unittest/test_sealed.cpp
+++ b/internal/core/unittest/test_sealed.cpp
@@ -80,10 +80,11 @@ TEST(Sealed, without_predicate) {
         CreatePlaceholderGroupFromBlob(num_queries, 16, query_ptr);
     auto ph_group =
         ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
+    Timestamp timestamp = 1000000;
     std::vector<const PlaceholderGroup*> ph_group_arr = {ph_group.get()};
-    auto sr = segment->Search(plan.get(), ph_group.get());
+    auto sr = segment->Search(plan.get(), ph_group.get(), timestamp);
     auto pre_result = SearchResultToJson(*sr);
     milvus::index::CreateIndexInfo create_index_info;
     create_index_info.field_type = DataType::VECTOR_FLOAT;
@@ -127,7 +128,7 @@ TEST(Sealed, without_predicate) {
     sealed_segment->DropFieldData(fake_id);
     sealed_segment->LoadIndex(load_info);
-    sr = sealed_segment->Search(plan.get(), ph_group.get());
+    sr = sealed_segment->Search(plan.get(), ph_group.get(), timestamp);
     auto post_result = SearchResultToJson(*sr);
     std::cout << "ref_result" << std::endl;
@@ -135,6 +136,9 @@ TEST(Sealed, without_predicate) {
     std::cout << "post_result" << std::endl;
     std::cout << post_result.dump(1);
     // ASSERT_EQ(ref_result.dump(1), post_result.dump(1));
+
+    sr = sealed_segment->Search(plan.get(), ph_group.get(), 0);
+    EXPECT_EQ(sr->get_total_result_count(), 0);
 }
 TEST(Sealed, with_predicate) {
@@ -196,10 +200,11 @@ TEST(Sealed, with_predicate) {
         CreatePlaceholderGroupFromBlob(num_queries, 16, query_ptr);
     auto ph_group =
         ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
+    Timestamp timestamp = 1000000;
     std::vector<const PlaceholderGroup*> ph_group_arr = {ph_group.get()};
-    auto sr = segment->Search(plan.get(), ph_group.get());
+    auto sr = segment->Search(plan.get(), ph_group.get(), timestamp);
     milvus::index::CreateIndexInfo create_index_info;
     create_index_info.field_type = DataType::VECTOR_FLOAT;
     create_index_info.metric_type = knowhere::metric::L2;
@@ -242,7 +247,7 @@ TEST(Sealed, with_predicate) {
     sealed_segment->DropFieldData(fake_id);
     sealed_segment->LoadIndex(load_info);
-    sr = sealed_segment->Search(plan.get(), ph_group.get());
+    sr = sealed_segment->Search(plan.get(), ph_group.get(), timestamp);
     for (int i = 0; i < num_queries; ++i) {
         auto offset = i * topK;
@@ -303,6 +308,7 @@ TEST(Sealed, with_predicate_filter_all) {
         CreatePlaceholderGroupFromBlob(num_queries, 16, query_ptr);
     auto ph_group =
         ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
+    Timestamp timestamp = 1000000;
     std::vector<const PlaceholderGroup*> ph_group_arr = {ph_group.get()};
@@ -337,7 +343,7 @@ TEST(Sealed, with_predicate_filter_all) {
     ivf_sealed_segment->DropFieldData(fake_id);
     ivf_sealed_segment->LoadIndex(load_info);
-    auto sr = ivf_sealed_segment->Search(plan.get(), ph_group.get());
+    auto sr = ivf_sealed_segment->Search(plan.get(), ph_group.get(), timestamp);
     EXPECT_EQ(sr->unity_topK_, 0);
     EXPECT_EQ(sr->get_total_result_count(), 0);
@@ -372,7 +378,8 @@ TEST(Sealed, with_predicate_filter_all) {
     hnsw_sealed_segment->DropFieldData(fake_id);
     hnsw_sealed_segment->LoadIndex(hnsw_load_info);
-    auto sr2 = hnsw_sealed_segment->Search(plan.get(), ph_group.get());
+    auto sr2 =
+        hnsw_sealed_segment->Search(plan.get(), ph_group.get(), timestamp);
     EXPECT_EQ(sr2->unity_topK_, 0);
     EXPECT_EQ(sr2->get_total_result_count(), 0);
 }
@@ -400,7 +407,8 @@ TEST(Sealed, LoadFieldData) {
     auto fakevec = dataset.get_col<float>(fakevec_id);
-    auto indexing = GenVecIndexing(N, dim, fakevec.data(), knowhere::IndexEnum::INDEX_FAISS_IVFFLAT);
+    auto indexing = GenVecIndexing(
+        N, dim, fakevec.data(), knowhere::IndexEnum::INDEX_FAISS_IVFFLAT);
     auto segment = CreateSealedSegment(schema);
     // std::string dsl = R"({
@@ -456,7 +464,7 @@
                                      >
                                      placeholder_tag: "$0"
      >)";
-
+    Timestamp timestamp = 1000000;
     auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
     auto plan =
         CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
     auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
     auto ph_group =
         ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
-    ASSERT_ANY_THROW(segment->Search(plan.get(), ph_group.get()));
+    ASSERT_ANY_THROW(segment->Search(plan.get(), ph_group.get(), timestamp));
     SealedLoadFieldData(dataset, *segment);
-    segment->Search(plan.get(), ph_group.get());
+    segment->Search(plan.get(), ph_group.get(), timestamp);
     segment->DropFieldData(fakevec_id);
-    ASSERT_ANY_THROW(segment->Search(plan.get(), ph_group.get()));
+    ASSERT_ANY_THROW(segment->Search(plan.get(), ph_group.get(), timestamp));
     LoadIndexInfo vec_info;
     vec_info.field_id = fakevec_id.get();
@@ -494,12 +502,12 @@ TEST(Sealed, LoadFieldData) {
         ASSERT_EQ(chunk_span3[i], ref3[i]);
     }
-    auto sr = segment->Search(plan.get(), ph_group.get());
+    auto sr = segment->Search(plan.get(), ph_group.get(), timestamp);
     auto json = SearchResultToJson(*sr);
     std::cout << json.dump(1);
     segment->DropIndex(fakevec_id);
-    ASSERT_ANY_THROW(segment->Search(plan.get(), ph_group.get()));
+    ASSERT_ANY_THROW(segment->Search(plan.get(), ph_group.get(), timestamp));
 }
 TEST(Sealed, LoadFieldDataMmap) {
@@ -525,7 +533,8 @@
     auto fakevec = dataset.get_col<float>(fakevec_id);
-    auto indexing = GenVecIndexing(N, dim, fakevec.data(), knowhere::IndexEnum::INDEX_FAISS_IVFFLAT);
+    auto indexing = GenVecIndexing(
+        N, dim, fakevec.data(), knowhere::IndexEnum::INDEX_FAISS_IVFFLAT);
     auto segment = CreateSealedSegment(schema);
     const char* raw_plan = R"(vector_anns: <
@@ -554,7 +563,7 @@
                                      >
                                      placeholder_tag: "$0"
      >)";
-
+    Timestamp timestamp = 1000000;
     auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
     auto plan =
         CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
     auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
     auto ph_group =
         ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
-    ASSERT_ANY_THROW(segment->Search(plan.get(), ph_group.get()));
+    ASSERT_ANY_THROW(segment->Search(plan.get(), ph_group.get(), timestamp));
     SealedLoadFieldData(dataset, *segment, {}, true);
-    segment->Search(plan.get(), ph_group.get());
+    segment->Search(plan.get(), ph_group.get(), timestamp);
     segment->DropFieldData(fakevec_id);
-    ASSERT_ANY_THROW(segment->Search(plan.get(), ph_group.get()));
+    ASSERT_ANY_THROW(segment->Search(plan.get(), ph_group.get(), timestamp));
     LoadIndexInfo vec_info;
     vec_info.field_id = fakevec_id.get();
@@ -592,12 +601,12 @@ TEST(Sealed, LoadFieldDataMmap) {
         ASSERT_EQ(chunk_span3[i], ref3[i]);
     }
-    auto sr = segment->Search(plan.get(), ph_group.get());
+    auto sr = segment->Search(plan.get(), ph_group.get(), timestamp);
     auto json = SearchResultToJson(*sr);
     std::cout << json.dump(1);
     segment->DropIndex(fakevec_id);
-    ASSERT_ANY_THROW(segment->Search(plan.get(), ph_group.get()));
+    ASSERT_ANY_THROW(segment->Search(plan.get(), ph_group.get(), timestamp));
 }
 TEST(Sealed, LoadScalarIndex) {
@@ -616,7 +625,8 @@ TEST(Sealed, LoadScalarIndex) {
     auto fakevec = dataset.get_col<float>(fakevec_id);
-    auto indexing = GenVecIndexing(N, dim, fakevec.data(), knowhere::IndexEnum::INDEX_FAISS_IVFFLAT);
+    auto indexing = GenVecIndexing(
+        N, dim, fakevec.data(), knowhere::IndexEnum::INDEX_FAISS_IVFFLAT);
     auto segment = CreateSealedSegment(schema);
     // std::string dsl = R"({
@@ -672,7 +682,7 @@ TEST(Sealed, LoadScalarIndex) {
                                      >
                                      placeholder_tag: "$0"
      >)";
-
+    Timestamp timestamp = 1000000;
     auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
     auto plan =
         CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
@@ -731,7 +741,7 @@ TEST(Sealed, LoadScalarIndex) {
     nothing_index.index = GenScalarIndexing(N, nothing_data.data());
     segment->LoadIndex(nothing_index);
-    auto sr = segment->Search(plan.get(), ph_group.get());
+    auto sr = segment->Search(plan.get(), ph_group.get(), timestamp);
     auto json = SearchResultToJson(*sr);
     std::cout << json.dump(1);
 }
@@ -780,7 +790,7 @@ TEST(Sealed, Delete) {
                                      >
                                      placeholder_tag: "$0"
      >)";
-
+    Timestamp timestamp = 1000000;
     auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
     auto plan =
         CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
     auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
     auto ph_group =
         ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
-    ASSERT_ANY_THROW(segment->Search(plan.get(), ph_group.get()));
+    ASSERT_ANY_THROW(segment->Search(plan.get(), ph_group.get(), timestamp));
     SealedLoadFieldData(dataset, *segment);
@@ -864,7 +874,7 @@ TEST(Sealed, OverlapDelete) {
                                      >
                                      placeholder_tag: "$0"
      >)";
-
+    Timestamp timestamp = 1000000;
     auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
     auto plan =
         CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
@@ -873,7 +883,7 @@ TEST(Sealed, OverlapDelete) {
     auto ph_group =
         ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
-    ASSERT_ANY_THROW(segment->Search(plan.get(), ph_group.get()));
+    ASSERT_ANY_THROW(segment->Search(plan.get(), ph_group.get(), timestamp));
     SealedLoadFieldData(dataset, *segment);
@@ -991,7 +1001,7 @@ TEST(Sealed, BF) {
     auto ph_group =
         ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
-    auto result = segment->Search(plan.get(), ph_group.get());
+    auto result = segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP);
     auto ves = SearchResultToVector(*result);
     // first: offset, second: distance
     EXPECT_GE(ves[0].first, 0);
@@ -1045,7 +1055,7 @@ TEST(Sealed, BF_Overflow) {
     auto ph_group =
         ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
-    auto result = segment->Search(plan.get(), ph_group.get());
+    auto result = segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP);
     auto ves = SearchResultToVector(*result);
     for (int i = 0; i < num_queries; ++i) {
         EXPECT_EQ(ves[0].first, -1);
@@ -1135,7 +1145,8 @@ TEST(Sealed, GetVector) {
     auto fakevec = dataset.get_col<float>(fakevec_id);
-    auto indexing = GenVecIndexing(N, dim, fakevec.data(), knowhere::IndexEnum::INDEX_FAISS_IVFFLAT);
+    auto indexing = GenVecIndexing(
+        N, dim, fakevec.data(), knowhere::IndexEnum::INDEX_FAISS_IVFFLAT);
     auto segment_sealed = CreateSealedSegment(schema);
@@ -1322,7 +1333,7 @@ TEST(Sealed, LoadArrayFieldData) {
         ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
     SealedLoadFieldData(dataset, *segment);
-    segment->Search(plan.get(), ph_group.get());
+    segment->Search(plan.get(), ph_group.get(), 1L << 63);
     auto ids_ds = GenRandomIds(N);
     auto s = dynamic_cast<SegmentSealedImpl*>(segment.get());
@@ -1379,7 +1390,7 @@ TEST(Sealed, LoadArrayFieldDataWithMMap) {
         ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
     SealedLoadFieldData(dataset, *segment, {}, true);
-    segment->Search(plan.get(), ph_group.get());
+    segment->Search(plan.get(), ph_group.get(), 1L << 63);
 }
 TEST(Sealed, SkipIndexSkipUnaryRange) {
diff --git a/internal/core/unittest/test_string_expr.cpp b/internal/core/unittest/test_string_expr.cpp
index c20c2665e7083..127c63ff38598 100644
--- a/internal/core/unittest/test_string_expr.cpp
+++ b/internal/core/unittest/test_string_expr.cpp
@@ -726,7 +726,7 @@ TEST(AlwaysTrueStringPlan, SearchWithOutputFields) {
     auto sub_result = BruteForceSearch(
         search_dataset, vec_col.data(), N, knowhere::Json(), nullptr);
-    auto sr = segment->Search(plan.get(), ph_group.get());
+    auto sr = segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP);
     segment->FillPrimaryKeys(plan.get(), *sr);
     segment->FillTargetEntry(plan.get(), *sr);
     ASSERT_EQ(sr->pk_type_, DataType::VARCHAR);
diff --git a/internal/core/unittest/test_utils/DataGen.h b/internal/core/unittest/test_utils/DataGen.h
index 6b24b0aa2c978..0f5d13a3b62b3 100644
--- a/internal/core/unittest/test_utils/DataGen.h
+++ b/internal/core/unittest/test_utils/DataGen.h
@@ -904,7 +904,10 @@ SealedCreator(SchemaPtr schema, const GeneratedData& dataset) {
 }
 inline std::unique_ptr<milvus::index::VectorIndex>
-GenVecIndexing(int64_t N, int64_t dim, const float* vec, const char* index_type) {
+GenVecIndexing(int64_t N,
+               int64_t dim,
+               const float* vec,
+               const char* index_type) {
     auto conf =
         knowhere::Json{{knowhere::meta::METRIC_TYPE, knowhere::metric::L2},
                        {knowhere::meta::DIM, std::to_string(dim)},
diff --git a/internal/core/unittest/test_utils/c_api_test_utils.h b/internal/core/unittest/test_utils/c_api_test_utils.h
index 6c46fea2e92a9..e57cb2615eb6d 100644
--- a/internal/core/unittest/test_utils/c_api_test_utils.h
+++ b/internal/core/unittest/test_utils/c_api_test_utils.h
@@ -37,9 +37,9 @@ using namespace milvus;
 using namespace milvus::segcore;
 namespace {
-    const char*
-    get_default_schema_config() {
-        static std::string conf = R"(name: "default-collection"
+const char*
+get_default_schema_config() {
+    static std::string conf = R"(name: "default-collection"
 fields: <
   fieldID: 100
   name: "fakevec"
  data_type: FloatVector
   type_params: <
     key: "dim"
     value: "16"
   >
   index_params: <
     key: "metric_type"
     value: "L2"
   >
 >
 fields: <
   fieldID: 101
   name: "age"
   data_type: Int64
   is_primary_key: true
 >)";
-        static std::string fake_conf = "";
-        return conf.c_str();
-    }
+    static std::string fake_conf = "";
+    return conf.c_str();
+}
@@ -59,81 +59,81 @@
-    std::string
-    generate_max_float_query_data(int all_nq, int max_float_nq) {
-        assert(max_float_nq <= all_nq);
-        namespace ser = milvus::proto::common;
-        int dim = DIM;
-        ser::PlaceholderGroup raw_group;
-        auto value = raw_group.add_placeholders();
-        value->set_tag("$0");
-        value->set_type(ser::PlaceholderType::FloatVector);
-        for (int i = 0; i < all_nq; ++i) {
-            std::vector<float> vec;
-            if (i < max_float_nq) {
-                for (int d = 0; d < dim; ++d) {
-                    vec.push_back(std::numeric_limits<float>::max());
-                }
-            } else {
-                for (int d = 0; d < dim; ++d) {
-                    vec.push_back(1);
-                }
+std::string
+generate_max_float_query_data(int all_nq, int max_float_nq) {
+    assert(max_float_nq <= all_nq);
+    namespace ser = milvus::proto::common;
+    int dim = DIM;
+    ser::PlaceholderGroup raw_group;
+    auto value = raw_group.add_placeholders();
+    value->set_tag("$0");
+    value->set_type(ser::PlaceholderType::FloatVector);
+    for (int i = 0; i < all_nq; ++i) {
+        std::vector<float> vec;
+        if (i < max_float_nq) {
+            for (int d = 0; d < dim; ++d) {
+                vec.push_back(std::numeric_limits<float>::max());
             }
-            value->add_values(vec.data(), vec.size() * sizeof(float));
-        }
-        auto blob = raw_group.SerializeAsString();
-        return blob;
-    }
-    std::string
-    generate_query_data(int nq) {
-        namespace ser = milvus::proto::common;
-        std::default_random_engine e(67);
-        int dim = DIM;
-        std::normal_distribution<double> dis(0.0, 1.0);
-        ser::PlaceholderGroup raw_group;
-        auto value = raw_group.add_placeholders();
-        value->set_tag("$0");
-        value->set_type(ser::PlaceholderType::FloatVector);
-        for (int i = 0; i < nq; ++i) {
-            std::vector<float> vec;
+        } else {
             for (int d = 0; d < dim; ++d) {
-                vec.push_back(dis(e));
+                vec.push_back(1);
             }
-            value->add_values(vec.data(), vec.size() * sizeof(float));
         }
-        auto blob = raw_group.SerializeAsString();
-        return blob;
+        value->add_values(vec.data(), vec.size() * sizeof(float));
     }
-    void
-    CheckSearchResultDuplicate(const std::vector<CSearchResult>& results) {
-        auto nq = ((SearchResult*)results[0])->total_nq_;
+    auto blob = raw_group.SerializeAsString();
+    return blob;
+}
+std::string
+generate_query_data(int nq) {
+    namespace ser = milvus::proto::common;
+    std::default_random_engine e(67);
+    int dim = DIM;
+    std::normal_distribution<double> dis(0.0, 1.0);
+    ser::PlaceholderGroup raw_group;
+    auto value = raw_group.add_placeholders();
+    value->set_tag("$0");
+    value->set_type(ser::PlaceholderType::FloatVector);
+    for (int i = 0; i < nq; ++i) {
+        std::vector<float> vec;
+        for (int d = 0; d < dim; ++d) {
+            vec.push_back(dis(e));
+        }
+        value->add_values(vec.data(), vec.size() * sizeof(float));
+    }
+    auto blob = raw_group.SerializeAsString();
+    return blob;
+}
+void
+CheckSearchResultDuplicate(const std::vector<CSearchResult>& results) {
+    auto nq = ((SearchResult*)results[0])->total_nq_;
-        std::unordered_set<PkType> pk_set;
-        std::unordered_set<GroupByValueType> group_by_val_set;
-        for (int qi = 0; qi < nq; qi++) {
-            pk_set.clear();
-            group_by_val_set.clear();
-            for (size_t i = 0; i < results.size(); i++) {
-                auto search_result = (SearchResult*)results[i];
-                ASSERT_EQ(nq, search_result->total_nq_);
-                auto topk_beg = search_result->topk_per_nq_prefix_sum_[qi];
-                auto topk_end = search_result->topk_per_nq_prefix_sum_[qi + 1];
-                for (size_t ki = topk_beg; ki < topk_end; ki++) {
-                    ASSERT_NE(search_result->seg_offsets_[ki], INVALID_SEG_OFFSET);
-                    auto ret = pk_set.insert(search_result->primary_keys_[ki]);
-                    ASSERT_TRUE(ret.second);
+    std::unordered_set<PkType> pk_set;
+    std::unordered_set<GroupByValueType> group_by_val_set;
+    for (int qi = 0; qi < nq; qi++) {
+        pk_set.clear();
+        group_by_val_set.clear();
+        for (size_t i = 0; i < results.size(); i++) {
+            auto search_result = (SearchResult*)results[i];
+            ASSERT_EQ(nq, search_result->total_nq_);
+            auto topk_beg = search_result->topk_per_nq_prefix_sum_[qi];
+            auto topk_end = search_result->topk_per_nq_prefix_sum_[qi + 1];
+            for (size_t ki = topk_beg; ki < topk_end; ki++) {
+                ASSERT_NE(search_result->seg_offsets_[ki], INVALID_SEG_OFFSET);
+                auto ret = pk_set.insert(search_result->primary_keys_[ki]);
+                ASSERT_TRUE(ret.second);
-                    if(search_result->group_by_values_.size()>ki){
-                        auto group_by_val = search_result->group_by_values_[ki];
-                        ASSERT_TRUE(group_by_val_set.count(group_by_val)==0);
-                        group_by_val_set.insert(group_by_val);
-                    }
+                if (search_result->group_by_values_.size() > ki) {
+                    auto group_by_val = search_result->group_by_values_[ki];
+                    ASSERT_TRUE(group_by_val_set.count(group_by_val) == 0);
+                    group_by_val_set.insert(group_by_val);
                 }
             }
         }
     }
 }
+}  // namespace
diff --git a/internal/proto/internal.proto b/internal/proto/internal.proto
index 9f768a7796594..a3cee9652f840 100644
--- a/internal/proto/internal.proto
+++ b/internal/proto/internal.proto
@@ -94,6 +94,7 @@ message SearchRequest {
   common.DslType dsl_type = 8;
   bytes serialized_expr_plan = 9;
   repeated int64 output_fields_id = 10;
+  uint64 mvcc_timestamp = 11;
   uint64 guarantee_timestamp = 12;
   uint64 timeout_timestamp = 13;
   int64 nq = 14;
@@ -120,6 +121,7 @@ message SearchResults {
   // search request cost
   CostAggregation costAggregation = 13;
+  map<string, uint64> channels_mvcc = 14;
 }
 message CostAggregation {
@@ -160,7 +162,7 @@ message RetrieveResults {
   repeated int64 global_sealed_segmentIDs = 8;
   // query request cost
-  CostAggregation costAggregation = 13; 
+  CostAggregation costAggregation = 13;
 }
 message LoadIndex {
diff --git a/internal/proto/internalpb/internal.pb.go~parent of ca1349708... Remove time travel ralted testcase (#26119) b/internal/proto/internalpb/internal.pb.go~parent of ca1349708... Remove time travel ralted testcase (#26119)
new file mode 100644
index 0000000000000..65268d2d6e002
--- /dev/null
+++ b/internal/proto/internalpb/internal.pb.go~parent of ca1349708... Remove time travel ralted testcase (#26119)
@@ -0,0 +1,2132 @@
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// source: internal.proto
+
+package internalpb
+
+import (
+    fmt "fmt"
+    proto "github.com/golang/protobuf/proto"
+    commonpb "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
+    schemapb "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+    math "math"
+)
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ = proto.Marshal
+var _ = fmt.Errorf
+var _ = math.Inf
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the proto package it is being compiled against.
+// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type RateType int32 + +const ( + RateType_DDLCollection RateType = 0 + RateType_DDLPartition RateType = 1 + RateType_DDLIndex RateType = 2 + RateType_DDLFlush RateType = 3 + RateType_DDLCompaction RateType = 4 + RateType_DMLInsert RateType = 5 + RateType_DMLDelete RateType = 6 + RateType_DMLBulkLoad RateType = 7 + RateType_DQLSearch RateType = 8 + RateType_DQLQuery RateType = 9 + RateType_DMLUpsert RateType = 10 +) + +var RateType_name = map[int32]string{ + 0: "DDLCollection", + 1: "DDLPartition", + 2: "DDLIndex", + 3: "DDLFlush", + 4: "DDLCompaction", + 5: "DMLInsert", + 6: "DMLDelete", + 7: "DMLBulkLoad", + 8: "DQLSearch", + 9: "DQLQuery", + 10: "DMLUpsert", +} + +var RateType_value = map[string]int32{ + "DDLCollection": 0, + "DDLPartition": 1, + "DDLIndex": 2, + "DDLFlush": 3, + "DDLCompaction": 4, + "DMLInsert": 5, + "DMLDelete": 6, + "DMLBulkLoad": 7, + "DQLSearch": 8, + "DQLQuery": 9, + "DMLUpsert": 10, +} + +func (x RateType) String() string { + return proto.EnumName(RateType_name, int32(x)) +} + +func (RateType) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{0} +} + +type GetTimeTickChannelRequest struct { + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *GetTimeTickChannelRequest) Reset() { *m = GetTimeTickChannelRequest{} } +func (m *GetTimeTickChannelRequest) String() string { return proto.CompactTextString(m) } +func (*GetTimeTickChannelRequest) ProtoMessage() {} +func (*GetTimeTickChannelRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{0} +} + +func (m *GetTimeTickChannelRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_GetTimeTickChannelRequest.Unmarshal(m, b) +} +func (m *GetTimeTickChannelRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_GetTimeTickChannelRequest.Marshal(b, m, deterministic) +} +func (m *GetTimeTickChannelRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_GetTimeTickChannelRequest.Merge(m, src) +} +func (m *GetTimeTickChannelRequest) XXX_Size() int { + return xxx_messageInfo_GetTimeTickChannelRequest.Size(m) +} +func (m *GetTimeTickChannelRequest) XXX_DiscardUnknown() { + xxx_messageInfo_GetTimeTickChannelRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_GetTimeTickChannelRequest proto.InternalMessageInfo + +type GetStatisticsChannelRequest struct { + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *GetStatisticsChannelRequest) Reset() { *m = GetStatisticsChannelRequest{} } +func (m *GetStatisticsChannelRequest) String() string { return proto.CompactTextString(m) } +func (*GetStatisticsChannelRequest) ProtoMessage() {} +func (*GetStatisticsChannelRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{1} +} + +func (m *GetStatisticsChannelRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_GetStatisticsChannelRequest.Unmarshal(m, b) +} +func (m *GetStatisticsChannelRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_GetStatisticsChannelRequest.Marshal(b, m, deterministic) +} +func (m *GetStatisticsChannelRequest) XXX_Merge(src proto.Message) { + 
xxx_messageInfo_GetStatisticsChannelRequest.Merge(m, src) +} +func (m *GetStatisticsChannelRequest) XXX_Size() int { + return xxx_messageInfo_GetStatisticsChannelRequest.Size(m) +} +func (m *GetStatisticsChannelRequest) XXX_DiscardUnknown() { + xxx_messageInfo_GetStatisticsChannelRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_GetStatisticsChannelRequest proto.InternalMessageInfo + +type GetDdChannelRequest struct { + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *GetDdChannelRequest) Reset() { *m = GetDdChannelRequest{} } +func (m *GetDdChannelRequest) String() string { return proto.CompactTextString(m) } +func (*GetDdChannelRequest) ProtoMessage() {} +func (*GetDdChannelRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{2} +} + +func (m *GetDdChannelRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_GetDdChannelRequest.Unmarshal(m, b) +} +func (m *GetDdChannelRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_GetDdChannelRequest.Marshal(b, m, deterministic) +} +func (m *GetDdChannelRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_GetDdChannelRequest.Merge(m, src) +} +func (m *GetDdChannelRequest) XXX_Size() int { + return xxx_messageInfo_GetDdChannelRequest.Size(m) +} +func (m *GetDdChannelRequest) XXX_DiscardUnknown() { + xxx_messageInfo_GetDdChannelRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_GetDdChannelRequest proto.InternalMessageInfo + +type NodeInfo struct { + Address *commonpb.Address `protobuf:"bytes,1,opt,name=address,proto3" json:"address,omitempty"` + Role string `protobuf:"bytes,2,opt,name=role,proto3" json:"role,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *NodeInfo) Reset() { *m = NodeInfo{} } +func (m *NodeInfo) String() string { return proto.CompactTextString(m) } +func (*NodeInfo) ProtoMessage() {} +func (*NodeInfo) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{3} +} + +func (m *NodeInfo) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_NodeInfo.Unmarshal(m, b) +} +func (m *NodeInfo) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_NodeInfo.Marshal(b, m, deterministic) +} +func (m *NodeInfo) XXX_Merge(src proto.Message) { + xxx_messageInfo_NodeInfo.Merge(m, src) +} +func (m *NodeInfo) XXX_Size() int { + return xxx_messageInfo_NodeInfo.Size(m) +} +func (m *NodeInfo) XXX_DiscardUnknown() { + xxx_messageInfo_NodeInfo.DiscardUnknown(m) +} + +var xxx_messageInfo_NodeInfo proto.InternalMessageInfo + +func (m *NodeInfo) GetAddress() *commonpb.Address { + if m != nil { + return m.Address + } + return nil +} + +func (m *NodeInfo) GetRole() string { + if m != nil { + return m.Role + } + return "" +} + +type InitParams struct { + NodeID int64 `protobuf:"varint,1,opt,name=nodeID,proto3" json:"nodeID,omitempty"` + StartParams []*commonpb.KeyValuePair `protobuf:"bytes,2,rep,name=start_params,json=startParams,proto3" json:"start_params,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *InitParams) Reset() { *m = InitParams{} } +func (m *InitParams) String() string { return proto.CompactTextString(m) } +func (*InitParams) ProtoMessage() {} +func (*InitParams) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{4} 
+} + +func (m *InitParams) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_InitParams.Unmarshal(m, b) +} +func (m *InitParams) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_InitParams.Marshal(b, m, deterministic) +} +func (m *InitParams) XXX_Merge(src proto.Message) { + xxx_messageInfo_InitParams.Merge(m, src) +} +func (m *InitParams) XXX_Size() int { + return xxx_messageInfo_InitParams.Size(m) +} +func (m *InitParams) XXX_DiscardUnknown() { + xxx_messageInfo_InitParams.DiscardUnknown(m) +} + +var xxx_messageInfo_InitParams proto.InternalMessageInfo + +func (m *InitParams) GetNodeID() int64 { + if m != nil { + return m.NodeID + } + return 0 +} + +func (m *InitParams) GetStartParams() []*commonpb.KeyValuePair { + if m != nil { + return m.StartParams + } + return nil +} + +type StringList struct { + Values []string `protobuf:"bytes,1,rep,name=values,proto3" json:"values,omitempty"` + Status *commonpb.Status `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *StringList) Reset() { *m = StringList{} } +func (m *StringList) String() string { return proto.CompactTextString(m) } +func (*StringList) ProtoMessage() {} +func (*StringList) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{5} +} + +func (m *StringList) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_StringList.Unmarshal(m, b) +} +func (m *StringList) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_StringList.Marshal(b, m, deterministic) +} +func (m *StringList) XXX_Merge(src proto.Message) { + xxx_messageInfo_StringList.Merge(m, src) +} +func (m *StringList) XXX_Size() int { + return xxx_messageInfo_StringList.Size(m) +} +func (m *StringList) XXX_DiscardUnknown() { + xxx_messageInfo_StringList.DiscardUnknown(m) +} + +var xxx_messageInfo_StringList proto.InternalMessageInfo + +func (m *StringList) GetValues() []string { + if m != nil { + return m.Values + } + return nil +} + +func (m *StringList) GetStatus() *commonpb.Status { + if m != nil { + return m.Status + } + return nil +} + +type GetStatisticsRequest struct { + Base *commonpb.MsgBase `protobuf:"bytes,1,opt,name=base,proto3" json:"base,omitempty"` + // Not useful for now + DbID int64 `protobuf:"varint,2,opt,name=dbID,proto3" json:"dbID,omitempty"` + // The collection you want get statistics + CollectionID int64 `protobuf:"varint,3,opt,name=collectionID,proto3" json:"collectionID,omitempty"` + // The partitions you want get statistics + PartitionIDs []int64 `protobuf:"varint,4,rep,packed,name=partitionIDs,proto3" json:"partitionIDs,omitempty"` + // timestamp of the statistics + TravelTimestamp uint64 `protobuf:"varint,5,opt,name=travel_timestamp,json=travelTimestamp,proto3" json:"travel_timestamp,omitempty"` + GuaranteeTimestamp uint64 `protobuf:"varint,6,opt,name=guarantee_timestamp,json=guaranteeTimestamp,proto3" json:"guarantee_timestamp,omitempty"` + TimeoutTimestamp uint64 `protobuf:"varint,7,opt,name=timeout_timestamp,json=timeoutTimestamp,proto3" json:"timeout_timestamp,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *GetStatisticsRequest) Reset() { *m = GetStatisticsRequest{} } +func (m *GetStatisticsRequest) String() string { return proto.CompactTextString(m) } +func (*GetStatisticsRequest) ProtoMessage() {} +func 
(*GetStatisticsRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{6} +} + +func (m *GetStatisticsRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_GetStatisticsRequest.Unmarshal(m, b) +} +func (m *GetStatisticsRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_GetStatisticsRequest.Marshal(b, m, deterministic) +} +func (m *GetStatisticsRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_GetStatisticsRequest.Merge(m, src) +} +func (m *GetStatisticsRequest) XXX_Size() int { + return xxx_messageInfo_GetStatisticsRequest.Size(m) +} +func (m *GetStatisticsRequest) XXX_DiscardUnknown() { + xxx_messageInfo_GetStatisticsRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_GetStatisticsRequest proto.InternalMessageInfo + +func (m *GetStatisticsRequest) GetBase() *commonpb.MsgBase { + if m != nil { + return m.Base + } + return nil +} + +func (m *GetStatisticsRequest) GetDbID() int64 { + if m != nil { + return m.DbID + } + return 0 +} + +func (m *GetStatisticsRequest) GetCollectionID() int64 { + if m != nil { + return m.CollectionID + } + return 0 +} + +func (m *GetStatisticsRequest) GetPartitionIDs() []int64 { + if m != nil { + return m.PartitionIDs + } + return nil +} + +func (m *GetStatisticsRequest) GetTravelTimestamp() uint64 { + if m != nil { + return m.TravelTimestamp + } + return 0 +} + +func (m *GetStatisticsRequest) GetGuaranteeTimestamp() uint64 { + if m != nil { + return m.GuaranteeTimestamp + } + return 0 +} + +func (m *GetStatisticsRequest) GetTimeoutTimestamp() uint64 { + if m != nil { + return m.TimeoutTimestamp + } + return 0 +} + +type GetStatisticsResponse struct { + Base *commonpb.MsgBase `protobuf:"bytes,1,opt,name=base,proto3" json:"base,omitempty"` + // Contain error_code and reason + Status *commonpb.Status `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"` + // Collection statistics data. 
Contain pairs like {"row_count": "1"} + Stats []*commonpb.KeyValuePair `protobuf:"bytes,3,rep,name=stats,proto3" json:"stats,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *GetStatisticsResponse) Reset() { *m = GetStatisticsResponse{} } +func (m *GetStatisticsResponse) String() string { return proto.CompactTextString(m) } +func (*GetStatisticsResponse) ProtoMessage() {} +func (*GetStatisticsResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{7} +} + +func (m *GetStatisticsResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_GetStatisticsResponse.Unmarshal(m, b) +} +func (m *GetStatisticsResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_GetStatisticsResponse.Marshal(b, m, deterministic) +} +func (m *GetStatisticsResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_GetStatisticsResponse.Merge(m, src) +} +func (m *GetStatisticsResponse) XXX_Size() int { + return xxx_messageInfo_GetStatisticsResponse.Size(m) +} +func (m *GetStatisticsResponse) XXX_DiscardUnknown() { + xxx_messageInfo_GetStatisticsResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_GetStatisticsResponse proto.InternalMessageInfo + +func (m *GetStatisticsResponse) GetBase() *commonpb.MsgBase { + if m != nil { + return m.Base + } + return nil +} + +func (m *GetStatisticsResponse) GetStatus() *commonpb.Status { + if m != nil { + return m.Status + } + return nil +} + +func (m *GetStatisticsResponse) GetStats() []*commonpb.KeyValuePair { + if m != nil { + return m.Stats + } + return nil +} + +type CreateAliasRequest struct { + Base *commonpb.MsgBase `protobuf:"bytes,1,opt,name=base,proto3" json:"base,omitempty"` + DbName string `protobuf:"bytes,2,opt,name=db_name,json=dbName,proto3" json:"db_name,omitempty"` + CollectionName string `protobuf:"bytes,3,opt,name=collection_name,json=collectionName,proto3" json:"collection_name,omitempty"` + Alias string `protobuf:"bytes,4,opt,name=alias,proto3" json:"alias,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *CreateAliasRequest) Reset() { *m = CreateAliasRequest{} } +func (m *CreateAliasRequest) String() string { return proto.CompactTextString(m) } +func (*CreateAliasRequest) ProtoMessage() {} +func (*CreateAliasRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{8} +} + +func (m *CreateAliasRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_CreateAliasRequest.Unmarshal(m, b) +} +func (m *CreateAliasRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_CreateAliasRequest.Marshal(b, m, deterministic) +} +func (m *CreateAliasRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_CreateAliasRequest.Merge(m, src) +} +func (m *CreateAliasRequest) XXX_Size() int { + return xxx_messageInfo_CreateAliasRequest.Size(m) +} +func (m *CreateAliasRequest) XXX_DiscardUnknown() { + xxx_messageInfo_CreateAliasRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_CreateAliasRequest proto.InternalMessageInfo + +func (m *CreateAliasRequest) GetBase() *commonpb.MsgBase { + if m != nil { + return m.Base + } + return nil +} + +func (m *CreateAliasRequest) GetDbName() string { + if m != nil { + return m.DbName + } + return "" +} + +func (m *CreateAliasRequest) GetCollectionName() string { + if m != nil { + return m.CollectionName + } + return 
"" +} + +func (m *CreateAliasRequest) GetAlias() string { + if m != nil { + return m.Alias + } + return "" +} + +type DropAliasRequest struct { + Base *commonpb.MsgBase `protobuf:"bytes,1,opt,name=base,proto3" json:"base,omitempty"` + DbName string `protobuf:"bytes,2,opt,name=db_name,json=dbName,proto3" json:"db_name,omitempty"` + Alias string `protobuf:"bytes,3,opt,name=alias,proto3" json:"alias,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *DropAliasRequest) Reset() { *m = DropAliasRequest{} } +func (m *DropAliasRequest) String() string { return proto.CompactTextString(m) } +func (*DropAliasRequest) ProtoMessage() {} +func (*DropAliasRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{9} +} + +func (m *DropAliasRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_DropAliasRequest.Unmarshal(m, b) +} +func (m *DropAliasRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_DropAliasRequest.Marshal(b, m, deterministic) +} +func (m *DropAliasRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_DropAliasRequest.Merge(m, src) +} +func (m *DropAliasRequest) XXX_Size() int { + return xxx_messageInfo_DropAliasRequest.Size(m) +} +func (m *DropAliasRequest) XXX_DiscardUnknown() { + xxx_messageInfo_DropAliasRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_DropAliasRequest proto.InternalMessageInfo + +func (m *DropAliasRequest) GetBase() *commonpb.MsgBase { + if m != nil { + return m.Base + } + return nil +} + +func (m *DropAliasRequest) GetDbName() string { + if m != nil { + return m.DbName + } + return "" +} + +func (m *DropAliasRequest) GetAlias() string { + if m != nil { + return m.Alias + } + return "" +} + +type AlterAliasRequest struct { + Base *commonpb.MsgBase `protobuf:"bytes,1,opt,name=base,proto3" json:"base,omitempty"` + DbName string `protobuf:"bytes,2,opt,name=db_name,json=dbName,proto3" json:"db_name,omitempty"` + CollectionName string `protobuf:"bytes,3,opt,name=collection_name,json=collectionName,proto3" json:"collection_name,omitempty"` + Alias string `protobuf:"bytes,4,opt,name=alias,proto3" json:"alias,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *AlterAliasRequest) Reset() { *m = AlterAliasRequest{} } +func (m *AlterAliasRequest) String() string { return proto.CompactTextString(m) } +func (*AlterAliasRequest) ProtoMessage() {} +func (*AlterAliasRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{10} +} + +func (m *AlterAliasRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_AlterAliasRequest.Unmarshal(m, b) +} +func (m *AlterAliasRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_AlterAliasRequest.Marshal(b, m, deterministic) +} +func (m *AlterAliasRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_AlterAliasRequest.Merge(m, src) +} +func (m *AlterAliasRequest) XXX_Size() int { + return xxx_messageInfo_AlterAliasRequest.Size(m) +} +func (m *AlterAliasRequest) XXX_DiscardUnknown() { + xxx_messageInfo_AlterAliasRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_AlterAliasRequest proto.InternalMessageInfo + +func (m *AlterAliasRequest) GetBase() *commonpb.MsgBase { + if m != nil { + return m.Base + } + return nil +} + +func (m *AlterAliasRequest) GetDbName() string { + if m != nil { + return m.DbName + } 
+ return "" +} + +func (m *AlterAliasRequest) GetCollectionName() string { + if m != nil { + return m.CollectionName + } + return "" +} + +func (m *AlterAliasRequest) GetAlias() string { + if m != nil { + return m.Alias + } + return "" +} + +type CreateIndexRequest struct { + Base *commonpb.MsgBase `protobuf:"bytes,1,opt,name=base,proto3" json:"base,omitempty"` + DbName string `protobuf:"bytes,2,opt,name=db_name,json=dbName,proto3" json:"db_name,omitempty"` + CollectionName string `protobuf:"bytes,3,opt,name=collection_name,json=collectionName,proto3" json:"collection_name,omitempty"` + FieldName string `protobuf:"bytes,4,opt,name=field_name,json=fieldName,proto3" json:"field_name,omitempty"` + DbID int64 `protobuf:"varint,5,opt,name=dbID,proto3" json:"dbID,omitempty"` + CollectionID int64 `protobuf:"varint,6,opt,name=collectionID,proto3" json:"collectionID,omitempty"` + FieldID int64 `protobuf:"varint,7,opt,name=fieldID,proto3" json:"fieldID,omitempty"` + ExtraParams []*commonpb.KeyValuePair `protobuf:"bytes,8,rep,name=extra_params,json=extraParams,proto3" json:"extra_params,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *CreateIndexRequest) Reset() { *m = CreateIndexRequest{} } +func (m *CreateIndexRequest) String() string { return proto.CompactTextString(m) } +func (*CreateIndexRequest) ProtoMessage() {} +func (*CreateIndexRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{11} +} + +func (m *CreateIndexRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_CreateIndexRequest.Unmarshal(m, b) +} +func (m *CreateIndexRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_CreateIndexRequest.Marshal(b, m, deterministic) +} +func (m *CreateIndexRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_CreateIndexRequest.Merge(m, src) +} +func (m *CreateIndexRequest) XXX_Size() int { + return xxx_messageInfo_CreateIndexRequest.Size(m) +} +func (m *CreateIndexRequest) XXX_DiscardUnknown() { + xxx_messageInfo_CreateIndexRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_CreateIndexRequest proto.InternalMessageInfo + +func (m *CreateIndexRequest) GetBase() *commonpb.MsgBase { + if m != nil { + return m.Base + } + return nil +} + +func (m *CreateIndexRequest) GetDbName() string { + if m != nil { + return m.DbName + } + return "" +} + +func (m *CreateIndexRequest) GetCollectionName() string { + if m != nil { + return m.CollectionName + } + return "" +} + +func (m *CreateIndexRequest) GetFieldName() string { + if m != nil { + return m.FieldName + } + return "" +} + +func (m *CreateIndexRequest) GetDbID() int64 { + if m != nil { + return m.DbID + } + return 0 +} + +func (m *CreateIndexRequest) GetCollectionID() int64 { + if m != nil { + return m.CollectionID + } + return 0 +} + +func (m *CreateIndexRequest) GetFieldID() int64 { + if m != nil { + return m.FieldID + } + return 0 +} + +func (m *CreateIndexRequest) GetExtraParams() []*commonpb.KeyValuePair { + if m != nil { + return m.ExtraParams + } + return nil +} + +type SearchRequest struct { + Base *commonpb.MsgBase `protobuf:"bytes,1,opt,name=base,proto3" json:"base,omitempty"` + ReqID int64 `protobuf:"varint,2,opt,name=reqID,proto3" json:"reqID,omitempty"` + DbID int64 `protobuf:"varint,3,opt,name=dbID,proto3" json:"dbID,omitempty"` + CollectionID int64 `protobuf:"varint,4,opt,name=collectionID,proto3" json:"collectionID,omitempty"` + PartitionIDs []int64 
`protobuf:"varint,5,rep,packed,name=partitionIDs,proto3" json:"partitionIDs,omitempty"` + Dsl string `protobuf:"bytes,6,opt,name=dsl,proto3" json:"dsl,omitempty"` + // serialized `PlaceholderGroup` + PlaceholderGroup []byte `protobuf:"bytes,7,opt,name=placeholder_group,json=placeholderGroup,proto3" json:"placeholder_group,omitempty"` + DslType commonpb.DslType `protobuf:"varint,8,opt,name=dsl_type,json=dslType,proto3,enum=milvus.proto.common.DslType" json:"dsl_type,omitempty"` + SerializedExprPlan []byte `protobuf:"bytes,9,opt,name=serialized_expr_plan,json=serializedExprPlan,proto3" json:"serialized_expr_plan,omitempty"` + OutputFieldsId []int64 `protobuf:"varint,10,rep,packed,name=output_fields_id,json=outputFieldsId,proto3" json:"output_fields_id,omitempty"` + TravelTimestamp uint64 `protobuf:"varint,11,opt,name=travel_timestamp,json=travelTimestamp,proto3" json:"travel_timestamp,omitempty"` + GuaranteeTimestamp uint64 `protobuf:"varint,12,opt,name=guarantee_timestamp,json=guaranteeTimestamp,proto3" json:"guarantee_timestamp,omitempty"` + TimeoutTimestamp uint64 `protobuf:"varint,13,opt,name=timeout_timestamp,json=timeoutTimestamp,proto3" json:"timeout_timestamp,omitempty"` + Nq int64 `protobuf:"varint,14,opt,name=nq,proto3" json:"nq,omitempty"` + Topk int64 `protobuf:"varint,15,opt,name=topk,proto3" json:"topk,omitempty"` + MetricType string `protobuf:"bytes,16,opt,name=metricType,proto3" json:"metricType,omitempty"` + IgnoreGrowing bool `protobuf:"varint,17,opt,name=ignoreGrowing,proto3" json:"ignoreGrowing,omitempty"` + Username string `protobuf:"bytes,18,opt,name=username,proto3" json:"username,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *SearchRequest) Reset() { *m = SearchRequest{} } +func (m *SearchRequest) String() string { return proto.CompactTextString(m) } +func (*SearchRequest) ProtoMessage() {} +func (*SearchRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{12} +} + +func (m *SearchRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_SearchRequest.Unmarshal(m, b) +} +func (m *SearchRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_SearchRequest.Marshal(b, m, deterministic) +} +func (m *SearchRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_SearchRequest.Merge(m, src) +} +func (m *SearchRequest) XXX_Size() int { + return xxx_messageInfo_SearchRequest.Size(m) +} +func (m *SearchRequest) XXX_DiscardUnknown() { + xxx_messageInfo_SearchRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_SearchRequest proto.InternalMessageInfo + +func (m *SearchRequest) GetBase() *commonpb.MsgBase { + if m != nil { + return m.Base + } + return nil +} + +func (m *SearchRequest) GetReqID() int64 { + if m != nil { + return m.ReqID + } + return 0 +} + +func (m *SearchRequest) GetDbID() int64 { + if m != nil { + return m.DbID + } + return 0 +} + +func (m *SearchRequest) GetCollectionID() int64 { + if m != nil { + return m.CollectionID + } + return 0 +} + +func (m *SearchRequest) GetPartitionIDs() []int64 { + if m != nil { + return m.PartitionIDs + } + return nil +} + +func (m *SearchRequest) GetDsl() string { + if m != nil { + return m.Dsl + } + return "" +} + +func (m *SearchRequest) GetPlaceholderGroup() []byte { + if m != nil { + return m.PlaceholderGroup + } + return nil +} + +func (m *SearchRequest) GetDslType() commonpb.DslType { + if m != nil { + return m.DslType + } + return commonpb.DslType_Dsl +} 
+ +func (m *SearchRequest) GetSerializedExprPlan() []byte { + if m != nil { + return m.SerializedExprPlan + } + return nil +} + +func (m *SearchRequest) GetOutputFieldsId() []int64 { + if m != nil { + return m.OutputFieldsId + } + return nil +} + +func (m *SearchRequest) GetTravelTimestamp() uint64 { + if m != nil { + return m.TravelTimestamp + } + return 0 +} + +func (m *SearchRequest) GetGuaranteeTimestamp() uint64 { + if m != nil { + return m.GuaranteeTimestamp + } + return 0 +} + +func (m *SearchRequest) GetTimeoutTimestamp() uint64 { + if m != nil { + return m.TimeoutTimestamp + } + return 0 +} + +func (m *SearchRequest) GetNq() int64 { + if m != nil { + return m.Nq + } + return 0 +} + +func (m *SearchRequest) GetTopk() int64 { + if m != nil { + return m.Topk + } + return 0 +} + +func (m *SearchRequest) GetMetricType() string { + if m != nil { + return m.MetricType + } + return "" +} + +func (m *SearchRequest) GetIgnoreGrowing() bool { + if m != nil { + return m.IgnoreGrowing + } + return false +} + +func (m *SearchRequest) GetUsername() string { + if m != nil { + return m.Username + } + return "" +} + +type SearchResults struct { + Base *commonpb.MsgBase `protobuf:"bytes,1,opt,name=base,proto3" json:"base,omitempty"` + Status *commonpb.Status `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"` + ReqID int64 `protobuf:"varint,3,opt,name=reqID,proto3" json:"reqID,omitempty"` + MetricType string `protobuf:"bytes,4,opt,name=metric_type,json=metricType,proto3" json:"metric_type,omitempty"` + NumQueries int64 `protobuf:"varint,5,opt,name=num_queries,json=numQueries,proto3" json:"num_queries,omitempty"` + TopK int64 `protobuf:"varint,6,opt,name=top_k,json=topK,proto3" json:"top_k,omitempty"` + SealedSegmentIDsSearched []int64 `protobuf:"varint,7,rep,packed,name=sealed_segmentIDs_searched,json=sealedSegmentIDsSearched,proto3" json:"sealed_segmentIDs_searched,omitempty"` + ChannelIDsSearched []string `protobuf:"bytes,8,rep,name=channelIDs_searched,json=channelIDsSearched,proto3" json:"channelIDs_searched,omitempty"` + GlobalSealedSegmentIDs []int64 `protobuf:"varint,9,rep,packed,name=global_sealed_segmentIDs,json=globalSealedSegmentIDs,proto3" json:"global_sealed_segmentIDs,omitempty"` + // schema.SearchResultsData inside + SlicedBlob []byte `protobuf:"bytes,10,opt,name=sliced_blob,json=slicedBlob,proto3" json:"sliced_blob,omitempty"` + SlicedNumCount int64 `protobuf:"varint,11,opt,name=sliced_num_count,json=slicedNumCount,proto3" json:"sliced_num_count,omitempty"` + SlicedOffset int64 `protobuf:"varint,12,opt,name=sliced_offset,json=slicedOffset,proto3" json:"sliced_offset,omitempty"` + // search request cost + CostAggregation *CostAggregation `protobuf:"bytes,13,opt,name=costAggregation,proto3" json:"costAggregation,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *SearchResults) Reset() { *m = SearchResults{} } +func (m *SearchResults) String() string { return proto.CompactTextString(m) } +func (*SearchResults) ProtoMessage() {} +func (*SearchResults) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{13} +} + +func (m *SearchResults) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_SearchResults.Unmarshal(m, b) +} +func (m *SearchResults) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_SearchResults.Marshal(b, m, deterministic) +} +func (m *SearchResults) XXX_Merge(src proto.Message) { + 
xxx_messageInfo_SearchResults.Merge(m, src) +} +func (m *SearchResults) XXX_Size() int { + return xxx_messageInfo_SearchResults.Size(m) +} +func (m *SearchResults) XXX_DiscardUnknown() { + xxx_messageInfo_SearchResults.DiscardUnknown(m) +} + +var xxx_messageInfo_SearchResults proto.InternalMessageInfo + +func (m *SearchResults) GetBase() *commonpb.MsgBase { + if m != nil { + return m.Base + } + return nil +} + +func (m *SearchResults) GetStatus() *commonpb.Status { + if m != nil { + return m.Status + } + return nil +} + +func (m *SearchResults) GetReqID() int64 { + if m != nil { + return m.ReqID + } + return 0 +} + +func (m *SearchResults) GetMetricType() string { + if m != nil { + return m.MetricType + } + return "" +} + +func (m *SearchResults) GetNumQueries() int64 { + if m != nil { + return m.NumQueries + } + return 0 +} + +func (m *SearchResults) GetTopK() int64 { + if m != nil { + return m.TopK + } + return 0 +} + +func (m *SearchResults) GetSealedSegmentIDsSearched() []int64 { + if m != nil { + return m.SealedSegmentIDsSearched + } + return nil +} + +func (m *SearchResults) GetChannelIDsSearched() []string { + if m != nil { + return m.ChannelIDsSearched + } + return nil +} + +func (m *SearchResults) GetGlobalSealedSegmentIDs() []int64 { + if m != nil { + return m.GlobalSealedSegmentIDs + } + return nil +} + +func (m *SearchResults) GetSlicedBlob() []byte { + if m != nil { + return m.SlicedBlob + } + return nil +} + +func (m *SearchResults) GetSlicedNumCount() int64 { + if m != nil { + return m.SlicedNumCount + } + return 0 +} + +func (m *SearchResults) GetSlicedOffset() int64 { + if m != nil { + return m.SlicedOffset + } + return 0 +} + +func (m *SearchResults) GetCostAggregation() *CostAggregation { + if m != nil { + return m.CostAggregation + } + return nil +} + +type CostAggregation struct { + ResponseTime int64 `protobuf:"varint,1,opt,name=responseTime,proto3" json:"responseTime,omitempty"` + ServiceTime int64 `protobuf:"varint,2,opt,name=serviceTime,proto3" json:"serviceTime,omitempty"` + TotalNQ int64 `protobuf:"varint,3,opt,name=totalNQ,proto3" json:"totalNQ,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *CostAggregation) Reset() { *m = CostAggregation{} } +func (m *CostAggregation) String() string { return proto.CompactTextString(m) } +func (*CostAggregation) ProtoMessage() {} +func (*CostAggregation) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{14} +} + +func (m *CostAggregation) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_CostAggregation.Unmarshal(m, b) +} +func (m *CostAggregation) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_CostAggregation.Marshal(b, m, deterministic) +} +func (m *CostAggregation) XXX_Merge(src proto.Message) { + xxx_messageInfo_CostAggregation.Merge(m, src) +} +func (m *CostAggregation) XXX_Size() int { + return xxx_messageInfo_CostAggregation.Size(m) +} +func (m *CostAggregation) XXX_DiscardUnknown() { + xxx_messageInfo_CostAggregation.DiscardUnknown(m) +} + +var xxx_messageInfo_CostAggregation proto.InternalMessageInfo + +func (m *CostAggregation) GetResponseTime() int64 { + if m != nil { + return m.ResponseTime + } + return 0 +} + +func (m *CostAggregation) GetServiceTime() int64 { + if m != nil { + return m.ServiceTime + } + return 0 +} + +func (m *CostAggregation) GetTotalNQ() int64 { + if m != nil { + return m.TotalNQ + } + return 0 +} + +type RetrieveRequest struct { + 
Base *commonpb.MsgBase `protobuf:"bytes,1,opt,name=base,proto3" json:"base,omitempty"` + ReqID int64 `protobuf:"varint,2,opt,name=reqID,proto3" json:"reqID,omitempty"` + DbID int64 `protobuf:"varint,3,opt,name=dbID,proto3" json:"dbID,omitempty"` + CollectionID int64 `protobuf:"varint,4,opt,name=collectionID,proto3" json:"collectionID,omitempty"` + PartitionIDs []int64 `protobuf:"varint,5,rep,packed,name=partitionIDs,proto3" json:"partitionIDs,omitempty"` + SerializedExprPlan []byte `protobuf:"bytes,6,opt,name=serialized_expr_plan,json=serializedExprPlan,proto3" json:"serialized_expr_plan,omitempty"` + OutputFieldsId []int64 `protobuf:"varint,7,rep,packed,name=output_fields_id,json=outputFieldsId,proto3" json:"output_fields_id,omitempty"` + TravelTimestamp uint64 `protobuf:"varint,8,opt,name=travel_timestamp,json=travelTimestamp,proto3" json:"travel_timestamp,omitempty"` + GuaranteeTimestamp uint64 `protobuf:"varint,9,opt,name=guarantee_timestamp,json=guaranteeTimestamp,proto3" json:"guarantee_timestamp,omitempty"` + TimeoutTimestamp uint64 `protobuf:"varint,10,opt,name=timeout_timestamp,json=timeoutTimestamp,proto3" json:"timeout_timestamp,omitempty"` + Limit int64 `protobuf:"varint,11,opt,name=limit,proto3" json:"limit,omitempty"` + IgnoreGrowing bool `protobuf:"varint,12,opt,name=ignoreGrowing,proto3" json:"ignoreGrowing,omitempty"` + IsCount bool `protobuf:"varint,13,opt,name=is_count,json=isCount,proto3" json:"is_count,omitempty"` + IterationExtensionReduceRate int64 `protobuf:"varint,14,opt,name=iteration_extension_reduce_rate,json=iterationExtensionReduceRate,proto3" json:"iteration_extension_reduce_rate,omitempty"` + Username string `protobuf:"bytes,15,opt,name=username,proto3" json:"username,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *RetrieveRequest) Reset() { *m = RetrieveRequest{} } +func (m *RetrieveRequest) String() string { return proto.CompactTextString(m) } +func (*RetrieveRequest) ProtoMessage() {} +func (*RetrieveRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{15} +} + +func (m *RetrieveRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_RetrieveRequest.Unmarshal(m, b) +} +func (m *RetrieveRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_RetrieveRequest.Marshal(b, m, deterministic) +} +func (m *RetrieveRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_RetrieveRequest.Merge(m, src) +} +func (m *RetrieveRequest) XXX_Size() int { + return xxx_messageInfo_RetrieveRequest.Size(m) +} +func (m *RetrieveRequest) XXX_DiscardUnknown() { + xxx_messageInfo_RetrieveRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_RetrieveRequest proto.InternalMessageInfo + +func (m *RetrieveRequest) GetBase() *commonpb.MsgBase { + if m != nil { + return m.Base + } + return nil +} + +func (m *RetrieveRequest) GetReqID() int64 { + if m != nil { + return m.ReqID + } + return 0 +} + +func (m *RetrieveRequest) GetDbID() int64 { + if m != nil { + return m.DbID + } + return 0 +} + +func (m *RetrieveRequest) GetCollectionID() int64 { + if m != nil { + return m.CollectionID + } + return 0 +} + +func (m *RetrieveRequest) GetPartitionIDs() []int64 { + if m != nil { + return m.PartitionIDs + } + return nil +} + +func (m *RetrieveRequest) GetSerializedExprPlan() []byte { + if m != nil { + return m.SerializedExprPlan + } + return nil +} + +func (m *RetrieveRequest) GetOutputFieldsId() []int64 { + if m != nil { + 
return m.OutputFieldsId + } + return nil +} + +func (m *RetrieveRequest) GetTravelTimestamp() uint64 { + if m != nil { + return m.TravelTimestamp + } + return 0 +} + +func (m *RetrieveRequest) GetGuaranteeTimestamp() uint64 { + if m != nil { + return m.GuaranteeTimestamp + } + return 0 +} + +func (m *RetrieveRequest) GetTimeoutTimestamp() uint64 { + if m != nil { + return m.TimeoutTimestamp + } + return 0 +} + +func (m *RetrieveRequest) GetLimit() int64 { + if m != nil { + return m.Limit + } + return 0 +} + +func (m *RetrieveRequest) GetIgnoreGrowing() bool { + if m != nil { + return m.IgnoreGrowing + } + return false +} + +func (m *RetrieveRequest) GetIsCount() bool { + if m != nil { + return m.IsCount + } + return false +} + +func (m *RetrieveRequest) GetIterationExtensionReduceRate() int64 { + if m != nil { + return m.IterationExtensionReduceRate + } + return 0 +} + +func (m *RetrieveRequest) GetUsername() string { + if m != nil { + return m.Username + } + return "" +} + +type RetrieveResults struct { + Base *commonpb.MsgBase `protobuf:"bytes,1,opt,name=base,proto3" json:"base,omitempty"` + Status *commonpb.Status `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"` + ReqID int64 `protobuf:"varint,3,opt,name=reqID,proto3" json:"reqID,omitempty"` + Ids *schemapb.IDs `protobuf:"bytes,4,opt,name=ids,proto3" json:"ids,omitempty"` + FieldsData []*schemapb.FieldData `protobuf:"bytes,5,rep,name=fields_data,json=fieldsData,proto3" json:"fields_data,omitempty"` + SealedSegmentIDsRetrieved []int64 `protobuf:"varint,6,rep,packed,name=sealed_segmentIDs_retrieved,json=sealedSegmentIDsRetrieved,proto3" json:"sealed_segmentIDs_retrieved,omitempty"` + ChannelIDsRetrieved []string `protobuf:"bytes,7,rep,name=channelIDs_retrieved,json=channelIDsRetrieved,proto3" json:"channelIDs_retrieved,omitempty"` + GlobalSealedSegmentIDs []int64 `protobuf:"varint,8,rep,packed,name=global_sealed_segmentIDs,json=globalSealedSegmentIDs,proto3" json:"global_sealed_segmentIDs,omitempty"` + // query request cost + CostAggregation *CostAggregation `protobuf:"bytes,13,opt,name=costAggregation,proto3" json:"costAggregation,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *RetrieveResults) Reset() { *m = RetrieveResults{} } +func (m *RetrieveResults) String() string { return proto.CompactTextString(m) } +func (*RetrieveResults) ProtoMessage() {} +func (*RetrieveResults) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{16} +} + +func (m *RetrieveResults) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_RetrieveResults.Unmarshal(m, b) +} +func (m *RetrieveResults) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_RetrieveResults.Marshal(b, m, deterministic) +} +func (m *RetrieveResults) XXX_Merge(src proto.Message) { + xxx_messageInfo_RetrieveResults.Merge(m, src) +} +func (m *RetrieveResults) XXX_Size() int { + return xxx_messageInfo_RetrieveResults.Size(m) +} +func (m *RetrieveResults) XXX_DiscardUnknown() { + xxx_messageInfo_RetrieveResults.DiscardUnknown(m) +} + +var xxx_messageInfo_RetrieveResults proto.InternalMessageInfo + +func (m *RetrieveResults) GetBase() *commonpb.MsgBase { + if m != nil { + return m.Base + } + return nil +} + +func (m *RetrieveResults) GetStatus() *commonpb.Status { + if m != nil { + return m.Status + } + return nil +} + +func (m *RetrieveResults) GetReqID() int64 { + if m != nil { + return m.ReqID + } + return 0 +} + +func 
(m *RetrieveResults) GetIds() *schemapb.IDs { + if m != nil { + return m.Ids + } + return nil +} + +func (m *RetrieveResults) GetFieldsData() []*schemapb.FieldData { + if m != nil { + return m.FieldsData + } + return nil +} + +func (m *RetrieveResults) GetSealedSegmentIDsRetrieved() []int64 { + if m != nil { + return m.SealedSegmentIDsRetrieved + } + return nil +} + +func (m *RetrieveResults) GetChannelIDsRetrieved() []string { + if m != nil { + return m.ChannelIDsRetrieved + } + return nil +} + +func (m *RetrieveResults) GetGlobalSealedSegmentIDs() []int64 { + if m != nil { + return m.GlobalSealedSegmentIDs + } + return nil +} + +func (m *RetrieveResults) GetCostAggregation() *CostAggregation { + if m != nil { + return m.CostAggregation + } + return nil +} + +type LoadIndex struct { + Base *commonpb.MsgBase `protobuf:"bytes,1,opt,name=base,proto3" json:"base,omitempty"` + SegmentID int64 `protobuf:"varint,2,opt,name=segmentID,proto3" json:"segmentID,omitempty"` + FieldName string `protobuf:"bytes,3,opt,name=fieldName,proto3" json:"fieldName,omitempty"` + FieldID int64 `protobuf:"varint,4,opt,name=fieldID,proto3" json:"fieldID,omitempty"` + IndexPaths []string `protobuf:"bytes,5,rep,name=index_paths,json=indexPaths,proto3" json:"index_paths,omitempty"` + IndexParams []*commonpb.KeyValuePair `protobuf:"bytes,6,rep,name=index_params,json=indexParams,proto3" json:"index_params,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *LoadIndex) Reset() { *m = LoadIndex{} } +func (m *LoadIndex) String() string { return proto.CompactTextString(m) } +func (*LoadIndex) ProtoMessage() {} +func (*LoadIndex) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{17} +} + +func (m *LoadIndex) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_LoadIndex.Unmarshal(m, b) +} +func (m *LoadIndex) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_LoadIndex.Marshal(b, m, deterministic) +} +func (m *LoadIndex) XXX_Merge(src proto.Message) { + xxx_messageInfo_LoadIndex.Merge(m, src) +} +func (m *LoadIndex) XXX_Size() int { + return xxx_messageInfo_LoadIndex.Size(m) +} +func (m *LoadIndex) XXX_DiscardUnknown() { + xxx_messageInfo_LoadIndex.DiscardUnknown(m) +} + +var xxx_messageInfo_LoadIndex proto.InternalMessageInfo + +func (m *LoadIndex) GetBase() *commonpb.MsgBase { + if m != nil { + return m.Base + } + return nil +} + +func (m *LoadIndex) GetSegmentID() int64 { + if m != nil { + return m.SegmentID + } + return 0 +} + +func (m *LoadIndex) GetFieldName() string { + if m != nil { + return m.FieldName + } + return "" +} + +func (m *LoadIndex) GetFieldID() int64 { + if m != nil { + return m.FieldID + } + return 0 +} + +func (m *LoadIndex) GetIndexPaths() []string { + if m != nil { + return m.IndexPaths + } + return nil +} + +func (m *LoadIndex) GetIndexParams() []*commonpb.KeyValuePair { + if m != nil { + return m.IndexParams + } + return nil +} + +type IndexStats struct { + IndexParams []*commonpb.KeyValuePair `protobuf:"bytes,1,rep,name=index_params,json=indexParams,proto3" json:"index_params,omitempty"` + NumRelatedSegments int64 `protobuf:"varint,2,opt,name=num_related_segments,json=numRelatedSegments,proto3" json:"num_related_segments,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *IndexStats) Reset() { *m = IndexStats{} } +func (m *IndexStats) String() string { return 
proto.CompactTextString(m) } +func (*IndexStats) ProtoMessage() {} +func (*IndexStats) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{18} +} + +func (m *IndexStats) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_IndexStats.Unmarshal(m, b) +} +func (m *IndexStats) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_IndexStats.Marshal(b, m, deterministic) +} +func (m *IndexStats) XXX_Merge(src proto.Message) { + xxx_messageInfo_IndexStats.Merge(m, src) +} +func (m *IndexStats) XXX_Size() int { + return xxx_messageInfo_IndexStats.Size(m) +} +func (m *IndexStats) XXX_DiscardUnknown() { + xxx_messageInfo_IndexStats.DiscardUnknown(m) +} + +var xxx_messageInfo_IndexStats proto.InternalMessageInfo + +func (m *IndexStats) GetIndexParams() []*commonpb.KeyValuePair { + if m != nil { + return m.IndexParams + } + return nil +} + +func (m *IndexStats) GetNumRelatedSegments() int64 { + if m != nil { + return m.NumRelatedSegments + } + return 0 +} + +type FieldStats struct { + CollectionID int64 `protobuf:"varint,1,opt,name=collectionID,proto3" json:"collectionID,omitempty"` + FieldID int64 `protobuf:"varint,2,opt,name=fieldID,proto3" json:"fieldID,omitempty"` + IndexStats []*IndexStats `protobuf:"bytes,3,rep,name=index_stats,json=indexStats,proto3" json:"index_stats,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *FieldStats) Reset() { *m = FieldStats{} } +func (m *FieldStats) String() string { return proto.CompactTextString(m) } +func (*FieldStats) ProtoMessage() {} +func (*FieldStats) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{19} +} + +func (m *FieldStats) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_FieldStats.Unmarshal(m, b) +} +func (m *FieldStats) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_FieldStats.Marshal(b, m, deterministic) +} +func (m *FieldStats) XXX_Merge(src proto.Message) { + xxx_messageInfo_FieldStats.Merge(m, src) +} +func (m *FieldStats) XXX_Size() int { + return xxx_messageInfo_FieldStats.Size(m) +} +func (m *FieldStats) XXX_DiscardUnknown() { + xxx_messageInfo_FieldStats.DiscardUnknown(m) +} + +var xxx_messageInfo_FieldStats proto.InternalMessageInfo + +func (m *FieldStats) GetCollectionID() int64 { + if m != nil { + return m.CollectionID + } + return 0 +} + +func (m *FieldStats) GetFieldID() int64 { + if m != nil { + return m.FieldID + } + return 0 +} + +func (m *FieldStats) GetIndexStats() []*IndexStats { + if m != nil { + return m.IndexStats + } + return nil +} + +type SegmentStats struct { + SegmentID int64 `protobuf:"varint,1,opt,name=segmentID,proto3" json:"segmentID,omitempty"` + MemorySize int64 `protobuf:"varint,2,opt,name=memory_size,json=memorySize,proto3" json:"memory_size,omitempty"` + NumRows int64 `protobuf:"varint,3,opt,name=num_rows,json=numRows,proto3" json:"num_rows,omitempty"` + RecentlyModified bool `protobuf:"varint,4,opt,name=recently_modified,json=recentlyModified,proto3" json:"recently_modified,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *SegmentStats) Reset() { *m = SegmentStats{} } +func (m *SegmentStats) String() string { return proto.CompactTextString(m) } +func (*SegmentStats) ProtoMessage() {} +func (*SegmentStats) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{20} +} + +func 
(m *SegmentStats) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_SegmentStats.Unmarshal(m, b) +} +func (m *SegmentStats) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_SegmentStats.Marshal(b, m, deterministic) +} +func (m *SegmentStats) XXX_Merge(src proto.Message) { + xxx_messageInfo_SegmentStats.Merge(m, src) +} +func (m *SegmentStats) XXX_Size() int { + return xxx_messageInfo_SegmentStats.Size(m) +} +func (m *SegmentStats) XXX_DiscardUnknown() { + xxx_messageInfo_SegmentStats.DiscardUnknown(m) +} + +var xxx_messageInfo_SegmentStats proto.InternalMessageInfo + +func (m *SegmentStats) GetSegmentID() int64 { + if m != nil { + return m.SegmentID + } + return 0 +} + +func (m *SegmentStats) GetMemorySize() int64 { + if m != nil { + return m.MemorySize + } + return 0 +} + +func (m *SegmentStats) GetNumRows() int64 { + if m != nil { + return m.NumRows + } + return 0 +} + +func (m *SegmentStats) GetRecentlyModified() bool { + if m != nil { + return m.RecentlyModified + } + return false +} + +type ChannelTimeTickMsg struct { + Base *commonpb.MsgBase `protobuf:"bytes,1,opt,name=base,proto3" json:"base,omitempty"` + ChannelNames []string `protobuf:"bytes,2,rep,name=channelNames,proto3" json:"channelNames,omitempty"` + Timestamps []uint64 `protobuf:"varint,3,rep,packed,name=timestamps,proto3" json:"timestamps,omitempty"` + DefaultTimestamp uint64 `protobuf:"varint,4,opt,name=default_timestamp,json=defaultTimestamp,proto3" json:"default_timestamp,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *ChannelTimeTickMsg) Reset() { *m = ChannelTimeTickMsg{} } +func (m *ChannelTimeTickMsg) String() string { return proto.CompactTextString(m) } +func (*ChannelTimeTickMsg) ProtoMessage() {} +func (*ChannelTimeTickMsg) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{21} +} + +func (m *ChannelTimeTickMsg) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_ChannelTimeTickMsg.Unmarshal(m, b) +} +func (m *ChannelTimeTickMsg) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_ChannelTimeTickMsg.Marshal(b, m, deterministic) +} +func (m *ChannelTimeTickMsg) XXX_Merge(src proto.Message) { + xxx_messageInfo_ChannelTimeTickMsg.Merge(m, src) +} +func (m *ChannelTimeTickMsg) XXX_Size() int { + return xxx_messageInfo_ChannelTimeTickMsg.Size(m) +} +func (m *ChannelTimeTickMsg) XXX_DiscardUnknown() { + xxx_messageInfo_ChannelTimeTickMsg.DiscardUnknown(m) +} + +var xxx_messageInfo_ChannelTimeTickMsg proto.InternalMessageInfo + +func (m *ChannelTimeTickMsg) GetBase() *commonpb.MsgBase { + if m != nil { + return m.Base + } + return nil +} + +func (m *ChannelTimeTickMsg) GetChannelNames() []string { + if m != nil { + return m.ChannelNames + } + return nil +} + +func (m *ChannelTimeTickMsg) GetTimestamps() []uint64 { + if m != nil { + return m.Timestamps + } + return nil +} + +func (m *ChannelTimeTickMsg) GetDefaultTimestamp() uint64 { + if m != nil { + return m.DefaultTimestamp + } + return 0 +} + +type CredentialInfo struct { + Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"` + // encrypted by bcrypt (for higher security level) + EncryptedPassword string `protobuf:"bytes,2,opt,name=encrypted_password,json=encryptedPassword,proto3" json:"encrypted_password,omitempty"` + Tenant string `protobuf:"bytes,3,opt,name=tenant,proto3" json:"tenant,omitempty"` + IsSuper bool 
`protobuf:"varint,4,opt,name=is_super,json=isSuper,proto3" json:"is_super,omitempty"` + // encrypted by sha256 (for good performance in cache mapping) + Sha256Password string `protobuf:"bytes,5,opt,name=sha256_password,json=sha256Password,proto3" json:"sha256_password,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *CredentialInfo) Reset() { *m = CredentialInfo{} } +func (m *CredentialInfo) String() string { return proto.CompactTextString(m) } +func (*CredentialInfo) ProtoMessage() {} +func (*CredentialInfo) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{22} +} + +func (m *CredentialInfo) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_CredentialInfo.Unmarshal(m, b) +} +func (m *CredentialInfo) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_CredentialInfo.Marshal(b, m, deterministic) +} +func (m *CredentialInfo) XXX_Merge(src proto.Message) { + xxx_messageInfo_CredentialInfo.Merge(m, src) +} +func (m *CredentialInfo) XXX_Size() int { + return xxx_messageInfo_CredentialInfo.Size(m) +} +func (m *CredentialInfo) XXX_DiscardUnknown() { + xxx_messageInfo_CredentialInfo.DiscardUnknown(m) +} + +var xxx_messageInfo_CredentialInfo proto.InternalMessageInfo + +func (m *CredentialInfo) GetUsername() string { + if m != nil { + return m.Username + } + return "" +} + +func (m *CredentialInfo) GetEncryptedPassword() string { + if m != nil { + return m.EncryptedPassword + } + return "" +} + +func (m *CredentialInfo) GetTenant() string { + if m != nil { + return m.Tenant + } + return "" +} + +func (m *CredentialInfo) GetIsSuper() bool { + if m != nil { + return m.IsSuper + } + return false +} + +func (m *CredentialInfo) GetSha256Password() string { + if m != nil { + return m.Sha256Password + } + return "" +} + +type ListPolicyRequest struct { + // Not useful for now + Base *commonpb.MsgBase `protobuf:"bytes,1,opt,name=base,proto3" json:"base,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *ListPolicyRequest) Reset() { *m = ListPolicyRequest{} } +func (m *ListPolicyRequest) String() string { return proto.CompactTextString(m) } +func (*ListPolicyRequest) ProtoMessage() {} +func (*ListPolicyRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{23} +} + +func (m *ListPolicyRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_ListPolicyRequest.Unmarshal(m, b) +} +func (m *ListPolicyRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_ListPolicyRequest.Marshal(b, m, deterministic) +} +func (m *ListPolicyRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_ListPolicyRequest.Merge(m, src) +} +func (m *ListPolicyRequest) XXX_Size() int { + return xxx_messageInfo_ListPolicyRequest.Size(m) +} +func (m *ListPolicyRequest) XXX_DiscardUnknown() { + xxx_messageInfo_ListPolicyRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_ListPolicyRequest proto.InternalMessageInfo + +func (m *ListPolicyRequest) GetBase() *commonpb.MsgBase { + if m != nil { + return m.Base + } + return nil +} + +type ListPolicyResponse struct { + // Contain error_code and reason + Status *commonpb.Status `protobuf:"bytes,1,opt,name=status,proto3" json:"status,omitempty"` + PolicyInfos []string `protobuf:"bytes,2,rep,name=policy_infos,json=policyInfos,proto3" json:"policy_infos,omitempty"` + UserRoles 
[]string `protobuf:"bytes,3,rep,name=user_roles,json=userRoles,proto3" json:"user_roles,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *ListPolicyResponse) Reset() { *m = ListPolicyResponse{} } +func (m *ListPolicyResponse) String() string { return proto.CompactTextString(m) } +func (*ListPolicyResponse) ProtoMessage() {} +func (*ListPolicyResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{24} +} + +func (m *ListPolicyResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_ListPolicyResponse.Unmarshal(m, b) +} +func (m *ListPolicyResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_ListPolicyResponse.Marshal(b, m, deterministic) +} +func (m *ListPolicyResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_ListPolicyResponse.Merge(m, src) +} +func (m *ListPolicyResponse) XXX_Size() int { + return xxx_messageInfo_ListPolicyResponse.Size(m) +} +func (m *ListPolicyResponse) XXX_DiscardUnknown() { + xxx_messageInfo_ListPolicyResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_ListPolicyResponse proto.InternalMessageInfo + +func (m *ListPolicyResponse) GetStatus() *commonpb.Status { + if m != nil { + return m.Status + } + return nil +} + +func (m *ListPolicyResponse) GetPolicyInfos() []string { + if m != nil { + return m.PolicyInfos + } + return nil +} + +func (m *ListPolicyResponse) GetUserRoles() []string { + if m != nil { + return m.UserRoles + } + return nil +} + +type ShowConfigurationsRequest struct { + Base *commonpb.MsgBase `protobuf:"bytes,1,opt,name=base,proto3" json:"base,omitempty"` + Pattern string `protobuf:"bytes,2,opt,name=pattern,proto3" json:"pattern,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *ShowConfigurationsRequest) Reset() { *m = ShowConfigurationsRequest{} } +func (m *ShowConfigurationsRequest) String() string { return proto.CompactTextString(m) } +func (*ShowConfigurationsRequest) ProtoMessage() {} +func (*ShowConfigurationsRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{25} +} + +func (m *ShowConfigurationsRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_ShowConfigurationsRequest.Unmarshal(m, b) +} +func (m *ShowConfigurationsRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_ShowConfigurationsRequest.Marshal(b, m, deterministic) +} +func (m *ShowConfigurationsRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_ShowConfigurationsRequest.Merge(m, src) +} +func (m *ShowConfigurationsRequest) XXX_Size() int { + return xxx_messageInfo_ShowConfigurationsRequest.Size(m) +} +func (m *ShowConfigurationsRequest) XXX_DiscardUnknown() { + xxx_messageInfo_ShowConfigurationsRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_ShowConfigurationsRequest proto.InternalMessageInfo + +func (m *ShowConfigurationsRequest) GetBase() *commonpb.MsgBase { + if m != nil { + return m.Base + } + return nil +} + +func (m *ShowConfigurationsRequest) GetPattern() string { + if m != nil { + return m.Pattern + } + return "" +} + +type ShowConfigurationsResponse struct { + Status *commonpb.Status `protobuf:"bytes,1,opt,name=status,proto3" json:"status,omitempty"` + Configuations []*commonpb.KeyValuePair `protobuf:"bytes,2,rep,name=configuations,proto3" json:"configuations,omitempty"` + XXX_NoUnkeyedLiteral struct{} 
`json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *ShowConfigurationsResponse) Reset() { *m = ShowConfigurationsResponse{} } +func (m *ShowConfigurationsResponse) String() string { return proto.CompactTextString(m) } +func (*ShowConfigurationsResponse) ProtoMessage() {} +func (*ShowConfigurationsResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{26} +} + +func (m *ShowConfigurationsResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_ShowConfigurationsResponse.Unmarshal(m, b) +} +func (m *ShowConfigurationsResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_ShowConfigurationsResponse.Marshal(b, m, deterministic) +} +func (m *ShowConfigurationsResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_ShowConfigurationsResponse.Merge(m, src) +} +func (m *ShowConfigurationsResponse) XXX_Size() int { + return xxx_messageInfo_ShowConfigurationsResponse.Size(m) +} +func (m *ShowConfigurationsResponse) XXX_DiscardUnknown() { + xxx_messageInfo_ShowConfigurationsResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_ShowConfigurationsResponse proto.InternalMessageInfo + +func (m *ShowConfigurationsResponse) GetStatus() *commonpb.Status { + if m != nil { + return m.Status + } + return nil +} + +func (m *ShowConfigurationsResponse) GetConfiguations() []*commonpb.KeyValuePair { + if m != nil { + return m.Configuations + } + return nil +} + +type Rate struct { + Rt RateType `protobuf:"varint,1,opt,name=rt,proto3,enum=milvus.proto.internal.RateType" json:"rt,omitempty"` + R float64 `protobuf:"fixed64,2,opt,name=r,proto3" json:"r,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Rate) Reset() { *m = Rate{} } +func (m *Rate) String() string { return proto.CompactTextString(m) } +func (*Rate) ProtoMessage() {} +func (*Rate) Descriptor() ([]byte, []int) { + return fileDescriptor_41f4a519b878ee3b, []int{27} +} + +func (m *Rate) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Rate.Unmarshal(m, b) +} +func (m *Rate) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Rate.Marshal(b, m, deterministic) +} +func (m *Rate) XXX_Merge(src proto.Message) { + xxx_messageInfo_Rate.Merge(m, src) +} +func (m *Rate) XXX_Size() int { + return xxx_messageInfo_Rate.Size(m) +} +func (m *Rate) XXX_DiscardUnknown() { + xxx_messageInfo_Rate.DiscardUnknown(m) +} + +var xxx_messageInfo_Rate proto.InternalMessageInfo + +func (m *Rate) GetRt() RateType { + if m != nil { + return m.Rt + } + return RateType_DDLCollection +} + +func (m *Rate) GetR() float64 { + if m != nil { + return m.R + } + return 0 +} + +func init() { + proto.RegisterEnum("milvus.proto.internal.RateType", RateType_name, RateType_value) + proto.RegisterType((*GetTimeTickChannelRequest)(nil), "milvus.proto.internal.GetTimeTickChannelRequest") + proto.RegisterType((*GetStatisticsChannelRequest)(nil), "milvus.proto.internal.GetStatisticsChannelRequest") + proto.RegisterType((*GetDdChannelRequest)(nil), "milvus.proto.internal.GetDdChannelRequest") + proto.RegisterType((*NodeInfo)(nil), "milvus.proto.internal.NodeInfo") + proto.RegisterType((*InitParams)(nil), "milvus.proto.internal.InitParams") + proto.RegisterType((*StringList)(nil), "milvus.proto.internal.StringList") + proto.RegisterType((*GetStatisticsRequest)(nil), "milvus.proto.internal.GetStatisticsRequest") + 
proto.RegisterType((*GetStatisticsResponse)(nil), "milvus.proto.internal.GetStatisticsResponse") + proto.RegisterType((*CreateAliasRequest)(nil), "milvus.proto.internal.CreateAliasRequest") + proto.RegisterType((*DropAliasRequest)(nil), "milvus.proto.internal.DropAliasRequest") + proto.RegisterType((*AlterAliasRequest)(nil), "milvus.proto.internal.AlterAliasRequest") + proto.RegisterType((*CreateIndexRequest)(nil), "milvus.proto.internal.CreateIndexRequest") + proto.RegisterType((*SearchRequest)(nil), "milvus.proto.internal.SearchRequest") + proto.RegisterType((*SearchResults)(nil), "milvus.proto.internal.SearchResults") + proto.RegisterType((*CostAggregation)(nil), "milvus.proto.internal.CostAggregation") + proto.RegisterType((*RetrieveRequest)(nil), "milvus.proto.internal.RetrieveRequest") + proto.RegisterType((*RetrieveResults)(nil), "milvus.proto.internal.RetrieveResults") + proto.RegisterType((*LoadIndex)(nil), "milvus.proto.internal.LoadIndex") + proto.RegisterType((*IndexStats)(nil), "milvus.proto.internal.IndexStats") + proto.RegisterType((*FieldStats)(nil), "milvus.proto.internal.FieldStats") + proto.RegisterType((*SegmentStats)(nil), "milvus.proto.internal.SegmentStats") + proto.RegisterType((*ChannelTimeTickMsg)(nil), "milvus.proto.internal.ChannelTimeTickMsg") + proto.RegisterType((*CredentialInfo)(nil), "milvus.proto.internal.CredentialInfo") + proto.RegisterType((*ListPolicyRequest)(nil), "milvus.proto.internal.ListPolicyRequest") + proto.RegisterType((*ListPolicyResponse)(nil), "milvus.proto.internal.ListPolicyResponse") + proto.RegisterType((*ShowConfigurationsRequest)(nil), "milvus.proto.internal.ShowConfigurationsRequest") + proto.RegisterType((*ShowConfigurationsResponse)(nil), "milvus.proto.internal.ShowConfigurationsResponse") + proto.RegisterType((*Rate)(nil), "milvus.proto.internal.Rate") +} + +func init() { proto.RegisterFile("internal.proto", fileDescriptor_41f4a519b878ee3b) } + +var fileDescriptor_41f4a519b878ee3b = []byte{ + // 1927 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xd4, 0x58, 0x4b, 0x73, 0x1c, 0x49, + 0x11, 0xa6, 0xe7, 0x3d, 0x39, 0x23, 0x69, 0x54, 0x96, 0x4d, 0xfb, 0xb1, 0x6b, 0x6d, 0x43, 0x80, + 0x58, 0x62, 0xed, 0x45, 0x1b, 0xbb, 0xe6, 0x40, 0x40, 0xd8, 0x6a, 0xaf, 0x62, 0x62, 0xc7, 0x46, + 0xee, 0x31, 0x1b, 0x01, 0x97, 0x8e, 0x9a, 0xe9, 0xd4, 0xa8, 0x70, 0xbf, 0x54, 0x55, 0x6d, 0x49, + 0x3e, 0x73, 0x23, 0x82, 0x1b, 0x1c, 0x88, 0x80, 0x7f, 0xc0, 0x79, 0x83, 0x13, 0xff, 0x80, 0x13, + 0xbf, 0x66, 0x4f, 0x44, 0x3d, 0x7a, 0x5e, 0x1a, 0x2b, 0x24, 0x99, 0xc7, 0xee, 0xad, 0x33, 0xf3, + 0xab, 0xac, 0xaa, 0xcc, 0xac, 0xaf, 0xb2, 0x1a, 0xd6, 0x59, 0x2a, 0x91, 0xa7, 0x34, 0x7e, 0x90, + 0xf3, 0x4c, 0x66, 0xe4, 0x66, 0xc2, 0xe2, 0xd7, 0x85, 0x30, 0xd2, 0x83, 0xd2, 0x78, 0xa7, 0x3b, + 0xce, 0x92, 0x24, 0x4b, 0x8d, 0xfa, 0x4e, 0x57, 0x8c, 0x8f, 0x30, 0xa1, 0x46, 0xf2, 0xee, 0xc2, + 0xed, 0x7d, 0x94, 0x2f, 0x59, 0x82, 0x2f, 0xd9, 0xf8, 0xd5, 0xde, 0x11, 0x4d, 0x53, 0x8c, 0x03, + 0x3c, 0x2e, 0x50, 0x48, 0xef, 0x3d, 0xb8, 0xbb, 0x8f, 0x72, 0x28, 0xa9, 0x64, 0x42, 0xb2, 0xb1, + 0x58, 0x32, 0xdf, 0x84, 0x1b, 0xfb, 0x28, 0xfd, 0x68, 0x49, 0xfd, 0x25, 0xb4, 0x9e, 0x67, 0x11, + 0xf6, 0xd3, 0xc3, 0x8c, 0x7c, 0x06, 0x4d, 0x1a, 0x45, 0x1c, 0x85, 0x70, 0x9d, 0x6d, 0x67, 0xa7, + 0xb3, 0x7b, 0xef, 0xc1, 0xc2, 0x1a, 0xed, 0xca, 0x1e, 0x1b, 0x4c, 0x50, 0x82, 0x09, 0x81, 0x1a, + 0xcf, 0x62, 0x74, 0x2b, 0xdb, 0xce, 0x4e, 0x3b, 0xd0, 0xdf, 0xde, 0x6f, 0x01, 0xfa, 0x29, 0x93, + 0x07, 0x94, 0xd3, 0x44, 0x90, 0x5b, 0xd0, 0x48, 
0xd5, 0x2c, 0xbe, 0x76, 0x5c, 0x0d, 0xac, 0x44, + 0x7c, 0xe8, 0x0a, 0x49, 0xb9, 0x0c, 0x73, 0x8d, 0x73, 0x2b, 0xdb, 0xd5, 0x9d, 0xce, 0xee, 0x07, + 0x2b, 0xa7, 0xfd, 0x02, 0xcf, 0xbe, 0xa4, 0x71, 0x81, 0x07, 0x94, 0xf1, 0xa0, 0xa3, 0x87, 0x19, + 0xef, 0xde, 0xaf, 0x01, 0x86, 0x92, 0xb3, 0x74, 0x32, 0x60, 0x42, 0xaa, 0xb9, 0x5e, 0x2b, 0x9c, + 0xda, 0x44, 0x75, 0xa7, 0x1d, 0x58, 0x89, 0x7c, 0x02, 0x0d, 0x21, 0xa9, 0x2c, 0x84, 0x5e, 0x67, + 0x67, 0xf7, 0xee, 0xca, 0x59, 0x86, 0x1a, 0x12, 0x58, 0xa8, 0xf7, 0xb7, 0x0a, 0x6c, 0x2d, 0x44, + 0xd5, 0xc6, 0x8d, 0x7c, 0x0c, 0xb5, 0x11, 0x15, 0x78, 0x61, 0xa0, 0x9e, 0x89, 0xc9, 0x13, 0x2a, + 0x30, 0xd0, 0x48, 0x15, 0xa5, 0x68, 0xd4, 0xf7, 0xf5, 0xec, 0xd5, 0x40, 0x7f, 0x13, 0x0f, 0xba, + 0xe3, 0x2c, 0x8e, 0x71, 0x2c, 0x59, 0x96, 0xf6, 0x7d, 0xb7, 0xaa, 0x6d, 0x0b, 0x3a, 0x85, 0xc9, + 0x29, 0x97, 0xcc, 0x88, 0xc2, 0xad, 0x6d, 0x57, 0x15, 0x66, 0x5e, 0x47, 0x7e, 0x04, 0x3d, 0xc9, + 0xe9, 0x6b, 0x8c, 0x43, 0xc9, 0x12, 0x14, 0x92, 0x26, 0xb9, 0x5b, 0xdf, 0x76, 0x76, 0x6a, 0xc1, + 0x86, 0xd1, 0xbf, 0x2c, 0xd5, 0xe4, 0x21, 0xdc, 0x98, 0x14, 0x94, 0xd3, 0x54, 0x22, 0xce, 0xa1, + 0x1b, 0x1a, 0x4d, 0xa6, 0xa6, 0xd9, 0x80, 0x1f, 0xc3, 0xa6, 0x82, 0x65, 0x85, 0x9c, 0x83, 0x37, + 0x35, 0xbc, 0x67, 0x0d, 0x53, 0xb0, 0xf7, 0x95, 0x03, 0x37, 0x97, 0xe2, 0x25, 0xf2, 0x2c, 0x15, + 0x78, 0x8d, 0x80, 0x5d, 0x27, 0x61, 0xe4, 0x11, 0xd4, 0xd5, 0x97, 0x70, 0xab, 0x97, 0x2d, 0x25, + 0x83, 0xf7, 0xfe, 0xea, 0x00, 0xd9, 0xe3, 0x48, 0x25, 0x3e, 0x8e, 0x19, 0x7d, 0x87, 0x3c, 0x7f, + 0x17, 0x9a, 0xd1, 0x28, 0x4c, 0x69, 0x52, 0x1e, 0x88, 0x46, 0x34, 0x7a, 0x4e, 0x13, 0x24, 0x3f, + 0x84, 0x8d, 0x59, 0x62, 0x0d, 0xa0, 0xaa, 0x01, 0xeb, 0x33, 0xb5, 0x06, 0x6e, 0x41, 0x9d, 0xaa, + 0x35, 0xb8, 0x35, 0x6d, 0x36, 0x82, 0x27, 0xa0, 0xe7, 0xf3, 0x2c, 0xff, 0x6f, 0xad, 0x6e, 0x3a, + 0x69, 0x75, 0x7e, 0xd2, 0xbf, 0x38, 0xb0, 0xf9, 0x38, 0x96, 0xc8, 0xbf, 0xa1, 0x41, 0xf9, 0x47, + 0xa5, 0xcc, 0x5a, 0x3f, 0x8d, 0xf0, 0xf4, 0xff, 0xb9, 0xc0, 0xf7, 0x00, 0x0e, 0x19, 0xc6, 0x91, + 0xc1, 0x98, 0x55, 0xb6, 0xb5, 0x46, 0x9b, 0xcb, 0xe3, 0x5f, 0xbf, 0xe0, 0xf8, 0x37, 0x56, 0x1c, + 0x7f, 0x17, 0x9a, 0xda, 0x49, 0xdf, 0xd7, 0x87, 0xae, 0x1a, 0x94, 0xa2, 0x22, 0x4f, 0x3c, 0x95, + 0x9c, 0x96, 0xe4, 0xd9, 0xba, 0x34, 0x79, 0xea, 0x61, 0x96, 0x3c, 0xff, 0x54, 0x87, 0xb5, 0x21, + 0x52, 0x3e, 0x3e, 0xba, 0x7e, 0xf0, 0xb6, 0xa0, 0xce, 0xf1, 0x78, 0xca, 0x6d, 0x46, 0x98, 0xee, + 0xb8, 0x7a, 0xc1, 0x8e, 0x6b, 0x97, 0x20, 0xbc, 0xfa, 0x0a, 0xc2, 0xeb, 0x41, 0x35, 0x12, 0xb1, + 0x0e, 0x58, 0x3b, 0x50, 0x9f, 0x8a, 0xa6, 0xf2, 0x98, 0x8e, 0xf1, 0x28, 0x8b, 0x23, 0xe4, 0xe1, + 0x84, 0x67, 0x85, 0xa1, 0xa9, 0x6e, 0xd0, 0x9b, 0x33, 0xec, 0x2b, 0x3d, 0x79, 0x04, 0xad, 0x48, + 0xc4, 0xa1, 0x3c, 0xcb, 0xd1, 0x6d, 0x6d, 0x3b, 0x3b, 0xeb, 0x6f, 0xd9, 0xa6, 0x2f, 0xe2, 0x97, + 0x67, 0x39, 0x06, 0xcd, 0xc8, 0x7c, 0x90, 0x8f, 0x61, 0x4b, 0x20, 0x67, 0x34, 0x66, 0x6f, 0x30, + 0x0a, 0xf1, 0x34, 0xe7, 0x61, 0x1e, 0xd3, 0xd4, 0x6d, 0xeb, 0x89, 0xc8, 0xcc, 0xf6, 0xf4, 0x34, + 0xe7, 0x07, 0x31, 0x4d, 0xc9, 0x0e, 0xf4, 0xb2, 0x42, 0xe6, 0x85, 0x0c, 0x75, 0xde, 0x44, 0xc8, + 0x22, 0x17, 0xf4, 0x8e, 0xd6, 0x8d, 0xfe, 0x73, 0xad, 0xee, 0x47, 0x2b, 0x49, 0xbc, 0x73, 0x25, + 0x12, 0xef, 0x5e, 0x8d, 0xc4, 0xd7, 0x56, 0x93, 0x38, 0x59, 0x87, 0x4a, 0x7a, 0xec, 0xae, 0xeb, + 0xd4, 0x54, 0xd2, 0x63, 0x95, 0x48, 0x99, 0xe5, 0xaf, 0xdc, 0x0d, 0x93, 0x48, 0xf5, 0x4d, 0xde, + 0x07, 0x48, 0x50, 0x72, 0x36, 0x56, 0x61, 0x71, 0x7b, 0x3a, 0x0f, 0x73, 0x1a, 0xf2, 0x7d, 0x58, + 0x63, 0x93, 0x34, 0xe3, 0xb8, 0xcf, 0xb3, 0x13, 0x96, 0x4e, 0xdc, 0xcd, 
0x6d, 0x67, 0xa7, 0x15, + 0x2c, 0x2a, 0xc9, 0x1d, 0x68, 0x15, 0x42, 0xf5, 0x3d, 0x09, 0xba, 0x44, 0xfb, 0x98, 0xca, 0xde, + 0x3f, 0x6b, 0xb3, 0xc2, 0x14, 0x45, 0x2c, 0xc5, 0xff, 0xea, 0x0a, 0x99, 0x56, 0x73, 0x75, 0xbe, + 0x9a, 0xef, 0x43, 0xc7, 0x6c, 0xcf, 0x54, 0x4d, 0xed, 0xdc, 0x8e, 0xef, 0x43, 0x27, 0x2d, 0x92, + 0xf0, 0xb8, 0x40, 0xce, 0x50, 0xd8, 0x73, 0x0e, 0x69, 0x91, 0xbc, 0x30, 0x1a, 0x72, 0x03, 0xea, + 0x32, 0xcb, 0xc3, 0x57, 0xf6, 0x98, 0xab, 0x38, 0x7e, 0x41, 0x7e, 0x06, 0x77, 0x04, 0xd2, 0x18, + 0xa3, 0x50, 0xe0, 0x24, 0xc1, 0x54, 0xf6, 0x7d, 0x11, 0x0a, 0xbd, 0x6d, 0x8c, 0xdc, 0xa6, 0x2e, + 0x14, 0xd7, 0x20, 0x86, 0x53, 0xc0, 0xd0, 0xda, 0x55, 0x1d, 0x8c, 0x4d, 0x3f, 0xb7, 0x30, 0xac, + 0xa5, 0x1b, 0x1f, 0x32, 0x33, 0x4d, 0x07, 0xfc, 0x14, 0xdc, 0x49, 0x9c, 0x8d, 0x68, 0x1c, 0x9e, + 0x9b, 0xd5, 0x6d, 0xeb, 0xc9, 0x6e, 0x19, 0xfb, 0x70, 0x69, 0x4a, 0xb5, 0x3d, 0x11, 0xb3, 0x31, + 0x46, 0xe1, 0x28, 0xce, 0x46, 0x2e, 0xe8, 0x82, 0x07, 0xa3, 0x7a, 0x12, 0x67, 0x23, 0x55, 0xe8, + 0x16, 0xa0, 0xc2, 0x30, 0xce, 0x8a, 0x54, 0xea, 0xf2, 0xad, 0x06, 0xeb, 0x46, 0xff, 0xbc, 0x48, + 0xf6, 0x94, 0x96, 0x7c, 0x0f, 0xd6, 0x2c, 0x32, 0x3b, 0x3c, 0x14, 0x28, 0x75, 0xdd, 0x56, 0x83, + 0xae, 0x51, 0xfe, 0x52, 0xeb, 0xc8, 0x81, 0xe2, 0x5d, 0x21, 0x1f, 0x4f, 0x26, 0x1c, 0x27, 0x54, + 0x9d, 0x7b, 0x5d, 0xaf, 0x9d, 0xdd, 0x1f, 0x3c, 0x58, 0xd9, 0x38, 0x3f, 0xd8, 0x5b, 0x44, 0x07, + 0xcb, 0xc3, 0xbd, 0x63, 0xd8, 0x58, 0xc2, 0x28, 0xaa, 0xe1, 0xb6, 0x41, 0x51, 0xe5, 0x6f, 0xbb, + 0xd3, 0x05, 0x1d, 0xd9, 0x86, 0x8e, 0x40, 0xfe, 0x9a, 0x8d, 0x0d, 0xc4, 0x50, 0xdc, 0xbc, 0x4a, + 0x51, 0xb4, 0xcc, 0x24, 0x8d, 0x9f, 0xbf, 0xb0, 0x25, 0x53, 0x8a, 0xde, 0xbf, 0x6a, 0xb0, 0x11, + 0xa8, 0x12, 0xc1, 0xd7, 0xf8, 0x6d, 0xa2, 0xd7, 0xb7, 0xd1, 0x5c, 0xe3, 0x4a, 0x34, 0xd7, 0xbc, + 0x34, 0xcd, 0xb5, 0xae, 0x44, 0x73, 0xed, 0xab, 0xd1, 0x1c, 0xbc, 0x85, 0xe6, 0xb6, 0xa0, 0x1e, + 0xb3, 0x84, 0x95, 0x55, 0x6a, 0x84, 0xf3, 0xc4, 0xd5, 0x5d, 0x45, 0x5c, 0xb7, 0xa1, 0xc5, 0x84, + 0x2d, 0xf2, 0x35, 0x0d, 0x68, 0x32, 0x61, 0xaa, 0xfb, 0x29, 0xdc, 0x67, 0x12, 0xb9, 0x2e, 0xb0, + 0x10, 0x4f, 0x25, 0xa6, 0x42, 0x7d, 0x71, 0x8c, 0x8a, 0x31, 0x86, 0x9c, 0x4a, 0xb4, 0xd4, 0x7a, + 0x6f, 0x0a, 0x7b, 0x5a, 0xa2, 0x02, 0x0d, 0x0a, 0xa8, 0xc4, 0x05, 0x6a, 0xdc, 0x58, 0xa2, 0xc6, + 0xaf, 0xab, 0xf3, 0x65, 0xf5, 0x0d, 0x20, 0xc7, 0x0f, 0xa1, 0xca, 0x22, 0xd3, 0x9a, 0x75, 0x76, + 0xdd, 0x45, 0x3f, 0xf6, 0x05, 0xdb, 0xf7, 0x45, 0xa0, 0x40, 0xe4, 0x17, 0xd0, 0xb1, 0x25, 0x12, + 0x51, 0x49, 0x75, 0xf9, 0x75, 0x76, 0xdf, 0x5f, 0x39, 0x46, 0xd7, 0x8c, 0x4f, 0x25, 0x0d, 0x4c, + 0x6b, 0x25, 0xd4, 0x37, 0xf9, 0x39, 0xdc, 0x3d, 0x4f, 0x99, 0xdc, 0x86, 0x23, 0x72, 0x1b, 0xba, + 0xea, 0x6e, 0x2f, 0x73, 0x66, 0x19, 0xaf, 0x88, 0xfc, 0x04, 0xb6, 0xe6, 0x48, 0x73, 0x36, 0xb0, + 0xa9, 0x59, 0x73, 0x8e, 0x50, 0x67, 0x43, 0x2e, 0xa2, 0xcd, 0xd6, 0x85, 0xb4, 0xf9, 0x9f, 0xa7, + 0xb1, 0xaf, 0x1d, 0x68, 0x0f, 0x32, 0x1a, 0xe9, 0x86, 0xf7, 0x1a, 0x69, 0xbf, 0x07, 0xed, 0xe9, + 0xea, 0x2d, 0xa3, 0xcc, 0x14, 0xca, 0x3a, 0xed, 0x59, 0x6d, 0xa3, 0x3b, 0xd7, 0xc4, 0xce, 0x35, + 0xa3, 0xb5, 0xc5, 0x66, 0xf4, 0x3e, 0x74, 0x98, 0x5a, 0x50, 0x98, 0x53, 0x79, 0x64, 0x48, 0xa5, + 0x1d, 0x80, 0x56, 0x1d, 0x28, 0x8d, 0xea, 0x56, 0x4b, 0x80, 0xee, 0x56, 0x1b, 0x97, 0xee, 0x56, + 0xad, 0x13, 0xdd, 0xad, 0xfe, 0xce, 0x01, 0xd0, 0x1b, 0x57, 0x65, 0x79, 0xde, 0xa9, 0x73, 0x1d, + 0xa7, 0x8a, 0xed, 0xd4, 0x95, 0xc5, 0x31, 0xa6, 0x72, 0x96, 0x5b, 0x61, 0x83, 0x43, 0xd2, 0x22, + 0x09, 0x8c, 0xc9, 0xe6, 0x55, 0x78, 0x7f, 0x70, 0x00, 0x74, 0x71, 0x9a, 0x65, 0x2c, 0xd3, 0xae, + 
0x73, 0x71, 0x1f, 0x5f, 0x59, 0x0c, 0xdd, 0x93, 0x32, 0x74, 0x17, 0x3c, 0x5c, 0xa7, 0xe5, 0x31, + 0xdb, 0xbc, 0x8d, 0xae, 0xfe, 0xf6, 0xfe, 0xe8, 0x40, 0xd7, 0xae, 0xce, 0x2c, 0x69, 0x21, 0xcb, + 0xce, 0x72, 0x96, 0x75, 0x33, 0x93, 0x64, 0xfc, 0x2c, 0x14, 0xec, 0x4d, 0x79, 0xa7, 0x81, 0x51, + 0x0d, 0xd9, 0x1b, 0x54, 0xfc, 0xa6, 0x43, 0x92, 0x9d, 0x88, 0xf2, 0x4e, 0x53, 0x61, 0xc8, 0x4e, + 0x84, 0xe2, 0x58, 0x8e, 0x63, 0x4c, 0x65, 0x7c, 0x16, 0x26, 0x59, 0xc4, 0x0e, 0x19, 0x46, 0xba, + 0x1a, 0x5a, 0x41, 0xaf, 0x34, 0x3c, 0xb3, 0x7a, 0xef, 0x2b, 0xf5, 0xaa, 0x36, 0x07, 0xaa, 0xfc, + 0x6d, 0xf5, 0x4c, 0x4c, 0xae, 0x51, 0xb5, 0x2a, 0xc4, 0xc6, 0x8f, 0x2a, 0x44, 0xf3, 0xa7, 0xa8, + 0x1d, 0x2c, 0xe8, 0x54, 0x4f, 0x3a, 0x65, 0x7d, 0x13, 0xc7, 0x5a, 0x30, 0xa7, 0x51, 0x2b, 0x8f, + 0xf0, 0x90, 0x16, 0xf1, 0xfc, 0xed, 0x50, 0x33, 0xb7, 0x83, 0x35, 0x2c, 0xfc, 0xc9, 0x58, 0xdf, + 0xe3, 0x18, 0x61, 0x2a, 0x19, 0x8d, 0xf5, 0xff, 0xb1, 0x79, 0x4a, 0x76, 0x16, 0x29, 0x99, 0x7c, + 0x04, 0x04, 0xd3, 0x31, 0x3f, 0xcb, 0x55, 0x05, 0xe5, 0x54, 0x88, 0x93, 0x8c, 0x47, 0xf6, 0x29, + 0xb9, 0x39, 0xb5, 0x1c, 0x58, 0x03, 0xb9, 0x05, 0x0d, 0x89, 0x29, 0x4d, 0xa5, 0x3d, 0x63, 0x56, + 0xb2, 0xf7, 0x8a, 0x28, 0x72, 0xe4, 0x36, 0xa6, 0x4d, 0x26, 0x86, 0x4a, 0x54, 0x0f, 0x51, 0x71, + 0x44, 0x77, 0x3f, 0xfd, 0x6c, 0xe6, 0xbe, 0x6e, 0x1e, 0xa2, 0x46, 0x5d, 0xfa, 0xf6, 0x9e, 0xc2, + 0xe6, 0x80, 0x09, 0x79, 0x90, 0xc5, 0x6c, 0x7c, 0x76, 0xed, 0xae, 0xc3, 0xfb, 0xbd, 0x03, 0x64, + 0xde, 0x8f, 0xfd, 0x8f, 0x33, 0xbb, 0x35, 0x9c, 0xcb, 0xdf, 0x1a, 0x1f, 0x40, 0x37, 0xd7, 0x6e, + 0x42, 0x96, 0x1e, 0x66, 0x65, 0xf6, 0x3a, 0x46, 0xa7, 0x62, 0x2b, 0xd4, 0xf3, 0x59, 0x05, 0x33, + 0xe4, 0x59, 0x8c, 0x26, 0x79, 0xed, 0xa0, 0xad, 0x34, 0x81, 0x52, 0x78, 0x13, 0xb8, 0x3d, 0x3c, + 0xca, 0x4e, 0xf6, 0xb2, 0xf4, 0x90, 0x4d, 0x0a, 0x73, 0x6d, 0xbe, 0xc3, 0xff, 0x08, 0x17, 0x9a, + 0x39, 0x95, 0xea, 0x4c, 0xd9, 0x1c, 0x95, 0xa2, 0xf7, 0x67, 0x07, 0xee, 0xac, 0x9a, 0xe9, 0x5d, + 0xb6, 0xbf, 0x0f, 0x6b, 0x63, 0xe3, 0xce, 0x78, 0xbb, 0xfc, 0x7f, 0xce, 0xc5, 0x71, 0xde, 0x53, + 0xa8, 0xe9, 0xe6, 0xe0, 0x21, 0x54, 0xb8, 0xd4, 0x2b, 0x58, 0xdf, 0xbd, 0xff, 0x16, 0xa6, 0x50, + 0x40, 0xfd, 0x78, 0xad, 0x70, 0x49, 0xba, 0xe0, 0x70, 0xbd, 0x53, 0x27, 0x70, 0xf8, 0x87, 0x7f, + 0x77, 0xa0, 0x55, 0x9a, 0xc9, 0x26, 0xac, 0xf9, 0xfe, 0x60, 0x6f, 0xca, 0x55, 0xbd, 0xef, 0x90, + 0x1e, 0x74, 0x7d, 0x7f, 0x70, 0x50, 0x76, 0x84, 0x3d, 0x87, 0x74, 0xa1, 0xe5, 0xfb, 0x03, 0x4d, + 0x3e, 0xbd, 0x8a, 0x95, 0x3e, 0x8f, 0x0b, 0x71, 0xd4, 0xab, 0x4e, 0x1d, 0x24, 0x39, 0x35, 0x0e, + 0x6a, 0x64, 0x0d, 0xda, 0xfe, 0xb3, 0x41, 0x3f, 0x15, 0xc8, 0x65, 0xaf, 0x6e, 0x45, 0x1f, 0x63, + 0x94, 0xd8, 0x6b, 0x90, 0x0d, 0xe8, 0xf8, 0xcf, 0x06, 0x4f, 0x8a, 0xf8, 0x95, 0xba, 0xc7, 0x7a, + 0x4d, 0x6d, 0x7f, 0x31, 0x30, 0x8f, 0x94, 0x5e, 0x4b, 0xbb, 0x7f, 0x31, 0x50, 0xcf, 0xa6, 0xb3, + 0x5e, 0xdb, 0x0e, 0xfe, 0x55, 0xae, 0x7d, 0xc1, 0x93, 0x47, 0xbf, 0xf9, 0x74, 0xc2, 0xe4, 0x51, + 0x31, 0x52, 0xf1, 0x7a, 0x68, 0xb6, 0xfe, 0x11, 0xcb, 0xec, 0xd7, 0xc3, 0x72, 0xfb, 0x0f, 0x75, + 0x34, 0xa6, 0x62, 0x3e, 0x1a, 0x35, 0xb4, 0xe6, 0x93, 0x7f, 0x07, 0x00, 0x00, 0xff, 0xff, 0xc1, + 0x18, 0x48, 0x84, 0x88, 0x17, 0x00, 0x00, +} diff --git a/internal/proxy/lb_policy.go b/internal/proxy/lb_policy.go index 68617372f8795..7a5394863f761 100644 --- a/internal/proxy/lb_policy.go +++ b/internal/proxy/lb_policy.go @@ -32,7 +32,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/typeutil" ) -type executeFunc func(context.Context, UniqueID, types.QueryNodeClient, ...string) error +type executeFunc 
func(context.Context, UniqueID, types.QueryNodeClient, string) error type ChannelWorkload struct { db string diff --git a/internal/proxy/lb_policy_test.go b/internal/proxy/lb_policy_test.go index b3a89ef5f5d9e..bf3c32c896de3 100644 --- a/internal/proxy/lb_policy_test.go +++ b/internal/proxy/lb_policy_test.go @@ -248,7 +248,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { channel: s.channels[0], shardLeaders: s.nodes, nq: 1, - exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, s ...string) error { + exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error { return nil }, retryTimes: 1, @@ -265,7 +265,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { channel: s.channels[0], shardLeaders: s.nodes, nq: 1, - exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, s ...string) error { + exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error { return nil }, retryTimes: 1, @@ -285,7 +285,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { channel: s.channels[0], shardLeaders: s.nodes, nq: 1, - exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, s ...string) error { + exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error { return nil }, retryTimes: 1, @@ -303,7 +303,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { channel: s.channels[0], shardLeaders: s.nodes, nq: 1, - exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, s ...string) error { + exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error { return nil }, retryTimes: 2, @@ -324,7 +324,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { channel: s.channels[0], shardLeaders: s.nodes, nq: 1, - exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, s ...string) error { + exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error { counter++ if counter == 1 { return errors.New("fake error") @@ -349,7 +349,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { channel: s.channels[0], shardLeaders: s.nodes, nq: 1, - exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, s ...string) error { + exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error { _, err := qn.Search(ctx, nil) return err }, @@ -370,7 +370,7 @@ func (s *LBPolicySuite) TestExecute() { collectionName: s.collectionName, collectionID: s.collectionID, nq: 1, - exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, s ...string) error { + exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error { return nil }, }) @@ -383,7 +383,7 @@ func (s *LBPolicySuite) TestExecute() { collectionName: s.collectionName, collectionID: s.collectionID, nq: 1, - exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, s ...string) error { + exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error { // succeed in first execute if counter.Add(1) == 1 { return nil @@ -404,7 +404,7 @@ func (s *LBPolicySuite) TestExecute() { collectionName: s.collectionName, collectionID: s.collectionID, nq: 1, - exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, s ...string) error { + exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error { return nil }, }) diff --git a/internal/proxy/task_delete.go b/internal/proxy/task_delete.go 
index 2a0df673d8a38..fda28b75137de 100644 --- a/internal/proxy/task_delete.go +++ b/internal/proxy/task_delete.go @@ -350,7 +350,7 @@ func (dr *deleteRunner) produce(ctx context.Context, primaryKeys *schemapb.IDs) // getStreamingQueryAndDelteFunc return query function used by LBPolicy // make sure it concurrent safe func (dr *deleteRunner) getStreamingQueryAndDelteFunc(plan *planpb.PlanNode) executeFunc { - return func(ctx context.Context, nodeID int64, qn types.QueryNodeClient, channelIDs ...string) error { + return func(ctx context.Context, nodeID int64, qn types.QueryNodeClient, channel string) error { var partitionIDs []int64 // optimize query when partitionKey on @@ -375,7 +375,7 @@ func (dr *deleteRunner) getStreamingQueryAndDelteFunc(plan *planpb.PlanNode) exe log := log.Ctx(ctx).With( zap.Int64("collectionID", dr.collectionID), zap.Int64s("partitionIDs", partitionIDs), - zap.Strings("channels", channelIDs), + zap.String("channel", channel), zap.Int64("nodeID", nodeID)) // set plan @@ -405,7 +405,7 @@ func (dr *deleteRunner) getStreamingQueryAndDelteFunc(plan *planpb.PlanNode) exe OutputFieldsId: outputFieldIDs, GuaranteeTimestamp: parseGuaranteeTsFromConsistency(dr.ts, dr.ts, dr.req.GetConsistencyLevel()), }, - DmlChannels: channelIDs, + DmlChannels: []string{channel}, Scope: querypb.DataScope_All, } diff --git a/internal/proxy/task_delete_test.go b/internal/proxy/task_delete_test.go index 816b6aaa903c5..ca40dcea581e4 100644 --- a/internal/proxy/task_delete_test.go +++ b/internal/proxy/task_delete_test.go @@ -546,7 +546,7 @@ func TestDeleteRunner_Run(t *testing.T) { }, } lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error { - return workload.exec(ctx, 1, qn) + return workload.exec(ctx, 1, qn, "") }) qn.EXPECT().QueryStream(mock.Anything, mock.Anything).Return(nil, errors.New("mock error")) @@ -591,7 +591,7 @@ func TestDeleteRunner_Run(t *testing.T) { stream.EXPECT().Produce(mock.Anything).Return(nil) lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error { - return workload.exec(ctx, 1, qn) + return workload.exec(ctx, 1, qn, "") }) qn.EXPECT().QueryStream(mock.Anything, mock.Anything).Call.Return( @@ -654,7 +654,7 @@ func TestDeleteRunner_Run(t *testing.T) { mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil) mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil) lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error { - return workload.exec(ctx, 1, qn) + return workload.exec(ctx, 1, qn, "") }) qn.EXPECT().QueryStream(mock.Anything, mock.Anything).Call.Return( @@ -716,7 +716,7 @@ func TestDeleteRunner_Run(t *testing.T) { mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil) mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil) lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error { - return workload.exec(ctx, 1, qn) + return workload.exec(ctx, 1, qn, "") }) qn.EXPECT().QueryStream(mock.Anything, mock.Anything).Call.Return( @@ -797,7 +797,7 @@ func TestDeleteRunner_Run(t *testing.T) { mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil) mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil) lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error { - return 
workload.exec(ctx, 1, qn) + return workload.exec(ctx, 1, qn, "") }) qn.EXPECT().QueryStream(mock.Anything, mock.Anything).Call.Return( @@ -899,7 +899,7 @@ func TestDeleteRunner_StreamingQueryAndDelteFunc(t *testing.T) { qn := mocks.NewMockQueryNodeClient(t) // witho out plan queryFunc := dr.getStreamingQueryAndDelteFunc(nil) - assert.Error(t, queryFunc(ctx, 1, qn)) + assert.Error(t, queryFunc(ctx, 1, qn, "")) }) t.Run("partitionKey mode get meta failed", func(t *testing.T) { @@ -938,7 +938,7 @@ func TestDeleteRunner_StreamingQueryAndDelteFunc(t *testing.T) { plan, err := planparserv2.CreateRetrievePlan(dr.schema.CollectionSchema, dr.req.Expr) assert.NoError(t, err) queryFunc := dr.getStreamingQueryAndDelteFunc(plan) - assert.Error(t, queryFunc(ctx, 1, qn)) + assert.Error(t, queryFunc(ctx, 1, qn, "")) }) t.Run("partitionKey mode get partition ID failed", func(t *testing.T) { @@ -981,6 +981,6 @@ func TestDeleteRunner_StreamingQueryAndDelteFunc(t *testing.T) { plan, err := planparserv2.CreateRetrievePlan(dr.schema.CollectionSchema, dr.req.Expr) assert.NoError(t, err) queryFunc := dr.getStreamingQueryAndDelteFunc(plan) - assert.Error(t, queryFunc(ctx, 1, qn)) + assert.Error(t, queryFunc(ctx, 1, qn, "")) }) } diff --git a/internal/proxy/task_hybrid_search.go b/internal/proxy/task_hybrid_search.go index 61da77861c8c6..7090bbdfb8c07 100644 --- a/internal/proxy/task_hybrid_search.go +++ b/internal/proxy/task_hybrid_search.go @@ -42,9 +42,10 @@ type hybridSearchTask struct { userOutputFields []string - qc types.QueryCoordClient - node types.ProxyComponent - lb LBPolicy + qc types.QueryCoordClient + node types.ProxyComponent + lb LBPolicy + queryChannelsTs map[string]Timestamp collectionID UniqueID @@ -296,7 +297,8 @@ func (t *hybridSearchTask) Requery() error { UseDefaultConsistency: t.request.GetUseDefaultConsistency(), } - return doRequery(t.ctx, t.collectionID, t.node, t.schema.CollectionSchema, queryReq, t.result) + // TODO:Xige-16 refine the mvcc functionality of hybrid search + return doRequery(t.ctx, t.collectionID, t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs) } func rankSearchResultData(ctx context.Context, diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index 47d59d6174125..30a471168aa16 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -61,6 +61,8 @@ type queryTask struct { plan *planpb.PlanNode partitionKeyMode bool lb LBPolicy + channelsMvcc map[string]Timestamp + fastSkip bool } type queryParams struct { @@ -467,19 +469,33 @@ func (t *queryTask) PostExecute(ctx context.Context) error { return nil } -func (t *queryTask) queryShard(ctx context.Context, nodeID int64, qn types.QueryNodeClient, channelIDs ...string) error { +func (t *queryTask) queryShard(ctx context.Context, nodeID int64, qn types.QueryNodeClient, channel string) error { + needOverrideMvcc := false + mvccTs := t.MvccTimestamp + if len(t.channelsMvcc) > 0 { + mvccTs, needOverrideMvcc = t.channelsMvcc[channel] + // In fast mode, if there is no corresponding channel in channelsMvcc, quickly skip this query. 
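+	// channelsMvcc is only populated on the requery path: doRequery (see
+	// task_search.go below) copies the per-channel ChannelsMvcc timestamps
+	// collected from the preceding search, so the requery reads each channel
+	// at the same MVCC snapshot the search observed instead of a fresh tsafe.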
+ if !needOverrideMvcc && t.fastSkip { + return nil + } + } + retrieveReq := typeutil.Clone(t.RetrieveRequest) retrieveReq.GetBase().TargetID = nodeID + if needOverrideMvcc && mvccTs > 0 { + retrieveReq.MvccTimestamp = mvccTs + } + req := &querypb.QueryRequest{ Req: retrieveReq, - DmlChannels: channelIDs, + DmlChannels: []string{channel}, Scope: querypb.DataScope_All, } log := log.Ctx(ctx).With(zap.Int64("collection", t.GetCollectionID()), zap.Int64s("partitionIDs", t.GetPartitionIDs()), zap.Int64("nodeID", nodeID), - zap.Strings("channels", channelIDs)) + zap.String("channel", channel)) result, err := qn.Query(ctx, req) if err != nil { diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index b1ffd6a40b1cb..e1ea686c1b7bc 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -63,9 +63,10 @@ type searchTask struct { offset int64 resultBuf *typeutil.ConcurrentSet[*internalpb.SearchResults] - qc types.QueryCoordClient - node types.ProxyComponent - lb LBPolicy + qc types.QueryCoordClient + node types.ProxyComponent + lb LBPolicy + queryChannelsTs map[string]Timestamp } func getPartitionIDs(ctx context.Context, dbName string, collectionName string, partitionNames []string) (partitionIDs []UniqueID, err error) { @@ -488,6 +489,13 @@ func (t *searchTask) PostExecute(ctx context.Context) error { return err } + t.queryChannelsTs = make(map[string]uint64) + for _, r := range toReduceResults { + for ch, ts := range r.GetChannelsMvcc() { + t.queryChannelsTs[ch] = ts + } + } + if len(toReduceResults) >= 1 { MetricType = toReduceResults[0].GetMetricType() } @@ -545,20 +553,20 @@ func (t *searchTask) PostExecute(ctx context.Context) error { return nil } -func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.QueryNodeClient, channelIDs ...string) error { +func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.QueryNodeClient, channel string) error { searchReq := typeutil.Clone(t.SearchRequest) searchReq.GetBase().TargetID = nodeID req := &querypb.SearchRequest{ Req: searchReq, - DmlChannels: channelIDs, + DmlChannels: []string{channel}, Scope: querypb.DataScope_All, - TotalChannelNum: int32(len(channelIDs)), + TotalChannelNum: int32(1), } log := log.Ctx(ctx).With(zap.Int64("collection", t.GetCollectionID()), zap.Int64s("partitionIDs", t.GetPartitionIDs()), zap.Int64("nodeID", nodeID), - zap.Strings("channels", channelIDs)) + zap.String("channel", channel)) var result *internalpb.SearchResults var err error @@ -619,7 +627,7 @@ func (t *searchTask) Requery() error { QueryParams: t.request.GetSearchParams(), } - return doRequery(t.ctx, t.GetCollectionID(), t.node, t.schema.CollectionSchema, queryReq, t.result) + return doRequery(t.ctx, t.GetCollectionID(), t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs) } func (t *searchTask) fillInEmptyResult(numQueries int64) { @@ -672,6 +680,7 @@ func doRequery(ctx context.Context, schema *schemapb.CollectionSchema, request *milvuspb.QueryRequest, result *milvuspb.SearchResults, + queryChannelsTs map[string]Timestamp, ) error { outputFields := request.GetOutputFields() pkField, err := typeutil.GetPrimaryFieldSchema(schema) @@ -680,7 +689,10 @@ func doRequery(ctx context.Context, } ids := result.GetResults().GetIds() plan := planparserv2.CreateRequeryPlan(pkField, ids) - + channelsMvcc := make(map[string]Timestamp) + for k, v := range queryChannelsTs { + channelsMvcc[k] = v + } qt := &queryTask{ ctx: ctx, Condition: NewTaskCondition(ctx), @@ -691,10 
+703,12 @@ func doRequery(ctx context.Context, ), ReqID: paramtable.GetNodeID(), }, - request: request, - plan: plan, - qc: node.(*Proxy).queryCoord, - lb: node.(*Proxy).lbPolicy, + request: request, + plan: plan, + qc: node.(*Proxy).queryCoord, + lb: node.(*Proxy).lbPolicy, + channelsMvcc: channelsMvcc, + fastSkip: true, } queryResult, err := node.(*Proxy).query(ctx, qt) if err != nil { diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 4da23d81763a3..f514dcd0aa89c 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -2061,7 +2061,7 @@ func TestSearchTask_Requery(t *testing.T) { lb := NewMockLBPolicy(t) lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload CollectionWorkLoad) { - err = workload.exec(ctx, 0, qn) + err = workload.exec(ctx, 0, qn, "") assert.NoError(t, err) }).Return(nil) lb.EXPECT().UpdateCostMetrics(mock.Anything, mock.Anything).Return() @@ -2141,7 +2141,7 @@ func TestSearchTask_Requery(t *testing.T) { lb := NewMockLBPolicy(t) lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload CollectionWorkLoad) { - _ = workload.exec(ctx, 0, qn) + _ = workload.exec(ctx, 0, qn, "") }).Return(fmt.Errorf("mock err 1")) node.lbPolicy = lb @@ -2175,7 +2175,7 @@ func TestSearchTask_Requery(t *testing.T) { lb := NewMockLBPolicy(t) lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload CollectionWorkLoad) { - _ = workload.exec(ctx, 0, qn) + _ = workload.exec(ctx, 0, qn, "") }).Return(fmt.Errorf("mock err 1")) node.lbPolicy = lb diff --git a/internal/proxy/task_statistic.go b/internal/proxy/task_statistic.go index e423068829831..ec50b0a4daae4 100644 --- a/internal/proxy/task_statistic.go +++ b/internal/proxy/task_statistic.go @@ -273,19 +273,19 @@ func (g *getStatisticsTask) getStatisticsFromQueryNode(ctx context.Context) erro return nil } -func (g *getStatisticsTask) getStatisticsShard(ctx context.Context, nodeID int64, qn types.QueryNodeClient, channelIDs ...string) error { +func (g *getStatisticsTask) getStatisticsShard(ctx context.Context, nodeID int64, qn types.QueryNodeClient, channel string) error { nodeReq := proto.Clone(g.GetStatisticsRequest).(*internalpb.GetStatisticsRequest) nodeReq.Base.TargetID = nodeID req := &querypb.GetStatisticsRequest{ Req: nodeReq, - DmlChannels: channelIDs, + DmlChannels: []string{channel}, Scope: querypb.DataScope_All, } result, err := qn.GetStatistics(ctx, req) if err != nil { log.Warn("QueryNode statistic return error", zap.Int64("nodeID", nodeID), - zap.Strings("channels", channelIDs), + zap.String("channel", channel), zap.Error(err)) globalMetaCache.DeprecateShardCache(g.request.GetDbName(), g.collectionName) return err @@ -293,7 +293,7 @@ func (g *getStatisticsTask) getStatisticsShard(ctx context.Context, nodeID int64 if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader { log.Warn("QueryNode is not shardLeader", zap.Int64("nodeID", nodeID), - zap.Strings("channels", channelIDs)) + zap.String("channel", channel)) globalMetaCache.DeprecateShardCache(g.request.GetDbName(), g.collectionName) return errInvalidShardLeaders } diff --git a/internal/querynodev2/delegator/delegator.go b/internal/querynodev2/delegator/delegator.go index c13f16e762dc1..c4ef2ef6b84fe 100644 --- a/internal/querynodev2/delegator/delegator.go +++ b/internal/querynodev2/delegator/delegator.go @@ -204,11 +204,14 @@ func (sd *shardDelegator) Search(ctx context.Context, req 
*querypb.SearchRequest // wait tsafe waitTr := timerecord.NewTimeRecorder("wait tSafe") - err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp) + tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp) if err != nil { log.Warn("delegator search failed to wait tsafe", zap.Error(err)) return nil, err } + if req.GetReq().GetMvccTimestamp() == 0 { + req.Req.MvccTimestamp = tSafe + } metrics.QueryNodeSQLatencyWaitTSafe.WithLabelValues( fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel). Observe(float64(waitTr.ElapseSpan().Milliseconds())) @@ -279,11 +282,14 @@ func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryReq // wait tsafe waitTr := timerecord.NewTimeRecorder("wait tSafe") - err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp) + tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp) if err != nil { log.Warn("delegator query failed to wait tsafe", zap.Error(err)) return err } + if req.GetReq().GetMvccTimestamp() == 0 { + req.Req.MvccTimestamp = tSafe + } metrics.QueryNodeSQLatencyWaitTSafe.WithLabelValues( fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel). Observe(float64(waitTr.ElapseSpan().Milliseconds())) @@ -347,11 +353,14 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest) // wait tsafe waitTr := timerecord.NewTimeRecorder("wait tSafe") - err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp) + tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp) if err != nil { log.Warn("delegator query failed to wait tsafe", zap.Error(err)) return nil, err } + if req.GetReq().GetMvccTimestamp() == 0 { + req.Req.MvccTimestamp = tSafe + } metrics.QueryNodeSQLatencyWaitTSafe.WithLabelValues( fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel). Observe(float64(waitTr.ElapseSpan().Milliseconds())) @@ -410,7 +419,7 @@ func (sd *shardDelegator) GetStatistics(ctx context.Context, req *querypb.GetSta } // wait tsafe - err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp) + _, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp) if err != nil { log.Warn("delegator GetStatistics failed to wait tsafe", zap.Error(err)) return nil, err @@ -552,14 +561,15 @@ func executeSubTasks[T any, R interface { } // waitTSafe returns when tsafe listener notifies a timestamp which meet the guarantee ts. 
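+// It now also returns the tsafe value that satisfied the wait, so the
+// callers above (Search, Query, QueryStream) can stamp it into the
+// request's MvccTimestamp when the request does not already carry one.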
diff --git a/internal/querynodev2/handlers.go b/internal/querynodev2/handlers.go
index 194e9af151d84..c27dd7a64ea3b 100644
--- a/internal/querynodev2/handlers.go
+++ b/internal/querynodev2/handlers.go
@@ -398,7 +398,6 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq
 	metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel, metrics.Leader).Inc()
 	metrics.QueryNodeSearchNQ.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(req.Req.GetNq()))
 	metrics.QueryNodeSearchTopK.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(req.Req.GetTopk()))
-
 	return resp, nil
 }
 
diff --git a/internal/querynodev2/segments/plan.go b/internal/querynodev2/segments/plan.go
index 3b85862d82d52..08825618c62fb 100644
--- a/internal/querynodev2/segments/plan.go
+++ b/internal/querynodev2/segments/plan.go
@@ -84,6 +84,7 @@ type SearchRequest struct {
 	cPlaceholderGroup C.CPlaceholderGroup
 	msgID             UniqueID
 	searchFieldID     UniqueID
+	mvccTimestamp     Timestamp
 }
 
 func NewSearchRequest(ctx context.Context, collection *Collection, req *querypb.SearchRequest, placeholderGrp []byte) (*SearchRequest, error) {
@@ -123,6 +124,7 @@ func NewSearchRequest(ctx context.Context, collection *Collection, req *querypb.
 		cPlaceholderGroup: cPlaceholderGroup,
 		msgID:             req.GetReq().GetBase().GetMsgID(),
 		searchFieldID:     int64(fieldID),
+		mvccTimestamp:     req.GetReq().GetMvccTimestamp(),
 	}
 
 	return ret, nil
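With mvccTimestamp now carried on the segcore-facing SearchRequest (and passed into the CGO search call later in this patch), a segment search can hide entries written after the pinned snapshot. A sketch of that visibility rule, assuming the usual MVCC convention that a row stamped ts is visible iff ts <= mvcc; the row type and filter are illustrative, not the actual segcore implementation:

    package main

    import "fmt"

    type row struct {
    	ID int64
    	Ts uint64 // write timestamp assigned when the row was inserted
    }

    // visible keeps only the rows that existed at the pinned snapshot.
    func visible(rows []row, mvcc uint64) []row {
    	var out []row
    	for _, r := range rows {
    		if r.Ts <= mvcc { // rows written after the snapshot are skipped
    			out = append(out, r)
    		}
    	}
    	return out
    }

    func main() {
    	rows := []row{{1, 10}, {2, 20}, {3, 30}}
    	fmt.Println(visible(rows, 20)) // [{1 10} {2 20}]
    }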
diff --git a/internal/querynodev2/segments/reduce_test.go b/internal/querynodev2/segments/reduce_test.go
index 2381cdd54e803..f6a587821b5fc 100644
--- a/internal/querynodev2/segments/reduce_test.go
+++ b/internal/querynodev2/segments/reduce_test.go
@@ -35,6 +35,7 @@ import (
 	"github.com/milvus-io/milvus/pkg/common"
 	"github.com/milvus-io/milvus/pkg/util/funcutil"
 	"github.com/milvus-io/milvus/pkg/util/paramtable"
+	"github.com/milvus-io/milvus/pkg/util/typeutil"
 )
 
 type ReduceSuite struct {
@@ -168,6 +169,7 @@ func (suite *ReduceSuite) TestReduceAllFunc() {
 	plan, err := createSearchPlanByExpr(context.Background(), suite.collection, serializedPlan, "")
 	suite.NoError(err)
 	searchReq, err := parseSearchRequest(context.Background(), plan, placeGroupByte)
+	searchReq.mvccTimestamp = typeutil.MaxTimestamp
 	suite.NoError(err)
 
 	defer searchReq.Delete()
diff --git a/internal/querynodev2/segments/result.go b/internal/querynodev2/segments/result.go
index 26a1197625c4a..1ed53aa8e0a1b 100644
--- a/internal/querynodev2/segments/result.go
+++ b/internal/querynodev2/segments/result.go
@@ -49,6 +49,12 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult
 		return results[0], nil
 	}
 
+	channelsMvcc := make(map[string]uint64)
+	for _, r := range results {
+		for ch, ts := range r.GetChannelsMvcc() {
+			channelsMvcc[ch] = ts
+		}
+	}
 	log := log.Ctx(ctx)
 
 	searchResultData, err := DecodeSearchResults(results)
@@ -88,7 +94,7 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult
 		return nil, false
 	})
 	searchResults.CostAggregation = mergeRequestCost(requestCosts)
-
+	searchResults.ChannelsMvcc = channelsMvcc
 	return searchResults, nil
 }
diff --git a/internal/querynodev2/segments/segment.go b/internal/querynodev2/segments/segment.go
index efec5ddd2b97e..e66ee8419534e 100644
--- a/internal/querynodev2/segments/segment.go
+++ b/internal/querynodev2/segments/segment.go
@@ -388,6 +388,7 @@ func (s *LocalSegment) Search(ctx context.Context, searchReq *SearchRequest) (*S
 		searchReq.plan.cSearchPlan,
 		searchReq.cPlaceholderGroup,
 		traceCtx,
+		C.uint64_t(searchReq.mvccTimestamp),
 		&searchResult.cSearchResult,
 	)
 	metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
diff --git a/internal/querynodev2/services.go b/internal/querynodev2/services.go
index a1e7091da867c..a963622604e9d 100644
--- a/internal/querynodev2/services.go
+++ b/internal/querynodev2/services.go
@@ -656,8 +656,13 @@ func (node *QueryNode) SearchSegments(ctx context.Context, req *querypb.SearchRe
 		zap.String("channel", channel),
 		zap.String("scope", req.GetScope().String()),
 	)
-
-	resp := &internalpb.SearchResults{}
+	channelsMvcc := make(map[string]uint64)
+	for _, ch := range req.GetDmlChannels() {
+		channelsMvcc[ch] = req.GetReq().GetMvccTimestamp()
+	}
+	resp := &internalpb.SearchResults{
+		ChannelsMvcc: channelsMvcc,
+	}
 	if err := node.lifetime.Add(merr.IsHealthy); err != nil {
 		resp.Status = merr.Status(err)
 		return resp, nil
@@ -733,7 +738,8 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
 
 	log.Debug("Received SearchRequest",
 		zap.Int64s("segmentIDs", req.GetSegmentIDs()),
-		zap.Uint64("guaranteeTimestamp", req.GetReq().GetGuaranteeTimestamp()))
+		zap.Uint64("guaranteeTimestamp", req.GetReq().GetGuaranteeTimestamp()),
+		zap.Uint64("mvccTimestamp", req.GetReq().GetMvccTimestamp()))
 
 	tr := timerecord.NewTimeRecorderWithTrace(ctx, "SearchRequest")
 
@@ -763,6 +769,7 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
 
 	toReduceResults := make([]*internalpb.SearchResults, len(req.GetDmlChannels()))
 	runningGp, runningCtx := errgroup.WithContext(ctx)
+
 	for i, ch := range req.GetDmlChannels() {
 		ch := ch
 		req := &querypb.SearchRequest{
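On the querynode side, SearchSegments now echoes the request's MVCC timestamp for every DML channel it serves, and ReduceSearchResults merges those per-result maps so the proxy receives the union across shards. A sketch of just the merge step, assuming (as the last-write-wins assignment above implies) that each shard result reports a disjoint set of channels:

    package main

    import "fmt"

    // mergeChannelsMvcc unions the per-channel MVCC maps reported by the
    // partial results; the merged map travels up to the proxy for requery.
    func mergeChannelsMvcc(partials []map[string]uint64) map[string]uint64 {
    	merged := make(map[string]uint64)
    	for _, p := range partials {
    		for ch, ts := range p {
    			merged[ch] = ts // channels are assumed disjoint across shards
    		}
    	}
    	return merged
    }

    func main() {
    	fmt.Println(mergeChannelsMvcc([]map[string]uint64{
    		{"dml-ch-0": 449296},
    		{"dml-ch-1": 449301},
    	})) // map[dml-ch-0:449296 dml-ch-1:449301]
    }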
diff --git a/internal/querynodev2/services_test.go b/internal/querynodev2/services_test.go
index e74f723ae3819..c8048ebca5aa7 100644
--- a/internal/querynodev2/services_test.go
+++ b/internal/querynodev2/services_test.go
@@ -1144,6 +1144,7 @@ func (suite *ServiceSuite) genCSearchRequest(nq int64, dataType schemapb.DataTyp
 		PlaceholderGroup: placeHolder,
 		DslType:          commonpb.DslType_BoolExprV1,
 		Nq:               nq,
+		MvccTimestamp:    typeutil.MaxTimestamp,
 	}, nil
 }
 
diff --git a/internal/querynodev2/tasks/task.go b/internal/querynodev2/tasks/task.go
index 9fbf12545a464..83498157a2764 100644
--- a/internal/querynodev2/tasks/task.go
+++ b/internal/querynodev2/tasks/task.go
@@ -257,6 +257,7 @@ func (t *SearchTask) Merge(other *SearchTask) bool {
 	// Check mergeable
 	if t.req.GetReq().GetDbID() != other.req.GetReq().GetDbID() ||
 		t.req.GetReq().GetCollectionID() != other.req.GetReq().GetCollectionID() ||
+		t.req.GetReq().GetMvccTimestamp() != other.req.GetReq().GetMvccTimestamp() ||
 		t.req.GetReq().GetDslType() != other.req.GetReq().GetDslType() ||
 		t.req.GetDmlChannels()[0] != other.req.GetDmlChannels()[0] ||
 		nq+otherNq > paramtable.Get().QueryNodeCfg.MaxGroupNQ.GetAsInt64() ||
@@ -300,6 +301,13 @@ func (t *SearchTask) Wait() error {
 }
 
 func (t *SearchTask) Result() *internalpb.SearchResults {
+	if t.result != nil {
+		channelsMvcc := make(map[string]uint64)
+		for _, ch := range t.req.GetDmlChannels() {
+			channelsMvcc[ch] = t.req.GetReq().GetMvccTimestamp()
+		}
+		t.result.ChannelsMvcc = channelsMvcc
+	}
 	return t.result
 }