diff --git a/internal/core/src/index/ScalarIndex.h b/internal/core/src/index/ScalarIndex.h index f3b602c8821df..bdea066fdc8d3 100644 --- a/internal/core/src/index/ScalarIndex.h +++ b/internal/core/src/index/ScalarIndex.h @@ -135,8 +135,8 @@ class ScalarIndex : public IndexBase { virtual bool IsMmapSupported() const { - return index_type_ == milvus::index::BITMAP_INDEX_TYPE || - index_type_ == milvus::index::HYBRID_INDEX_TYPE; + return index_type_ == milvus::index::BITMAP_INDEX_TYPE || + index_type_ == milvus::index::HYBRID_INDEX_TYPE; } virtual int64_t diff --git a/internal/core/src/index/VectorIndex.h b/internal/core/src/index/VectorIndex.h index 95655db9e544a..a615956f01275 100644 --- a/internal/core/src/index/VectorIndex.h +++ b/internal/core/src/index/VectorIndex.h @@ -117,7 +117,8 @@ class VectorIndex : public IndexBase { virtual bool IsMmapSupported() const { - return knowhere::IndexFactory::Instance().FeatureCheck(index_type_, knowhere::feature::MMAP); + return knowhere::IndexFactory::Instance().FeatureCheck( + index_type_, knowhere::feature::MMAP); } knowhere::Json diff --git a/internal/core/src/segcore/vector_index_c.cpp b/internal/core/src/segcore/vector_index_c.cpp index e8b1af0eb4cbc..df072d7bf79d7 100644 --- a/internal/core/src/segcore/vector_index_c.cpp +++ b/internal/core/src/segcore/vector_index_c.cpp @@ -21,11 +21,15 @@ #include "pb/index_cgo_msg.pb.h" CStatus -ValidateIndexParams(const char* index_type, enum CDataType data_type, const uint8_t* serialized_index_params, const uint64_t length) { +ValidateIndexParams(const char* index_type, + enum CDataType data_type, + const uint8_t* serialized_index_params, + const uint64_t length) { try { auto index_params = - std::make_unique(); - auto res = index_params->ParseFromArray(serialized_index_params, length); + std::make_unique(); + auto res = + index_params->ParseFromArray(serialized_index_params, length); AssertInfo(res, "Unmarshall index params failed"); knowhere::Json json; @@ -40,15 +44,40 @@ ValidateIndexParams(const char* index_type, enum CDataType data_type, const uint knowhere::Status status; std::string error_msg; if (dataType == milvus::DataType::VECTOR_BINARY) { - status = knowhere::CheckConfig(index_type, knowhere::Version::GetCurrentVersion().VersionNumber(), json, knowhere::PARAM_TYPE::TRAIN, error_msg); + status = knowhere::CheckConfig( + index_type, + knowhere::Version::GetCurrentVersion().VersionNumber(), + json, + knowhere::PARAM_TYPE::TRAIN, + error_msg); } else if (dataType == milvus::DataType::VECTOR_FLOAT) { - status = knowhere::CheckConfig(index_type, knowhere::Version::GetCurrentVersion().VersionNumber(), json, knowhere::PARAM_TYPE::TRAIN, error_msg); + status = knowhere::CheckConfig( + index_type, + knowhere::Version::GetCurrentVersion().VersionNumber(), + json, + knowhere::PARAM_TYPE::TRAIN, + error_msg); } else if (dataType == milvus::DataType::VECTOR_BFLOAT16) { - status = knowhere::CheckConfig(index_type, knowhere::Version::GetCurrentVersion().VersionNumber(), json, knowhere::PARAM_TYPE::TRAIN, error_msg); + status = knowhere::CheckConfig( + index_type, + knowhere::Version::GetCurrentVersion().VersionNumber(), + json, + knowhere::PARAM_TYPE::TRAIN, + error_msg); } else if (dataType == milvus::DataType::VECTOR_FLOAT16) { - status = knowhere::CheckConfig(index_type, knowhere::Version::GetCurrentVersion().VersionNumber(), json, knowhere::PARAM_TYPE::TRAIN, error_msg); + status = knowhere::CheckConfig( + index_type, + knowhere::Version::GetCurrentVersion().VersionNumber(), + json, + knowhere::PARAM_TYPE::TRAIN, + error_msg); } else if (dataType == milvus::DataType::VECTOR_SPARSE_FLOAT) { - status = knowhere::CheckConfig(index_type, knowhere::Version::GetCurrentVersion().VersionNumber(), json, knowhere::PARAM_TYPE::TRAIN, error_msg); + status = knowhere::CheckConfig( + index_type, + knowhere::Version::GetCurrentVersion().VersionNumber(), + json, + knowhere::PARAM_TYPE::TRAIN, + error_msg); } else { status = knowhere::Status::invalid_args; } @@ -87,4 +116,3 @@ GetIndexFeatures(void* index_key_list, uint64_t* index_feature_list) { idx++; } } - diff --git a/internal/core/src/segcore/vector_index_c.h b/internal/core/src/segcore/vector_index_c.h index 06160e0bd6f9e..7e9b8f52391b9 100644 --- a/internal/core/src/segcore/vector_index_c.h +++ b/internal/core/src/segcore/vector_index_c.h @@ -18,7 +18,10 @@ extern "C" { #include "common/type_c.h" CStatus -ValidateIndexParams(const char* index_type, enum CDataType data_type, const uint8_t* index_params, const uint64_t length); +ValidateIndexParams(const char* index_type, + enum CDataType data_type, + const uint8_t* index_params, + const uint64_t length); int GetIndexListSize(); @@ -26,7 +29,6 @@ GetIndexListSize(); void GetIndexFeatures(void* index_key_list, uint64_t* index_feature_list); - #ifdef __cplusplus } #endif diff --git a/internal/datacoord/index_meta.go b/internal/datacoord/index_meta.go index c65946e7ac40f..be3a7f7000817 100644 --- a/internal/datacoord/index_meta.go +++ b/internal/datacoord/index_meta.go @@ -33,8 +33,8 @@ import ( "github.com/milvus-io/milvus/internal/metastore" "github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/proto/indexpb" - "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/internal/proto/workerpb" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" diff --git a/internal/util/hookutil/mock_hook.go b/internal/util/hookutil/mock_hook.go index f9dd39b830326..0808080c669f8 100644 --- a/internal/util/hookutil/mock_hook.go +++ b/internal/util/hookutil/mock_hook.go @@ -1,3 +1,6 @@ +//go:build test +// +build test + /* * Licensed to the LF AI & Data foundation under one * or more contributor license agreements. See the NOTICE file diff --git a/internal/util/indexparamcheck/vector_index_checker.go b/internal/util/indexparamcheck/vector_index_checker.go index 31c146ca204b2..8b5085671976a 100644 --- a/internal/util/indexparamcheck/vector_index_checker.go +++ b/internal/util/indexparamcheck/vector_index_checker.go @@ -79,7 +79,7 @@ func (c vecIndexChecker) CheckTrain(dataType schemapb.DataType, params map[strin func (c vecIndexChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { if !typeutil.IsVectorType(field.GetDataType()) { - return fmt.Errorf("index %s only support vector data type", indexType) + return fmt.Errorf("index %s only supports vector data type", indexType) } if !GetVecIndexMgrInstance().IsDataTypeSupport(indexType, field.GetDataType()) { return fmt.Errorf("index %s do not support data type: %s", indexType, schemapb.DataType_name[int32(field.GetDataType())]) diff --git a/pkg/config/config.go b/pkg/config/config.go index bd8bf83627ece..fc93c086f74d4 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -20,10 +20,7 @@ import ( "strings" "github.com/cockroachdb/errors" - "github.com/spf13/cast" - "go.uber.org/zap" - "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -33,10 +30,6 @@ var ( ErrKeyNotFound = errors.New("key not found") ) -const ( - NotFormatPrefix = "knowhere." -) - func Init(opts ...Option) (*Manager, error) { o := &Options{} for _, opt := range opts { @@ -63,9 +56,6 @@ func Init(opts ...Option) (*Manager, error) { var formattedKeys = typeutil.NewConcurrentMap[string, string]() func formatKey(key string) string { - if strings.HasPrefix(key, NotFormatPrefix) { - return key - } cached, ok := formattedKeys.Get(key) if ok { return cached @@ -74,41 +64,3 @@ func formatKey(key string) string { formattedKeys.Insert(key, result) return result } - -func parseConfig(prefix string, m map[string]interface{}, result map[string]string) { - for k, v := range m { - fullKey := k - if prefix != "" { - fullKey = prefix + "." + k - } - - switch val := v.(type) { - case map[string]interface{}: - parseConfig(fullKey, val, result) - case []interface{}: - str := "" - for i, item := range val { - itemStr, err := cast.ToStringE(item) - if err != nil { - log.Warn("cast to string failed", zap.Any("item", item)) - continue - } - if i == 0 { - str = itemStr - } else { - str = str + "," + itemStr - } - } - result[strings.ToLower(fullKey)] = str - result[formatKey(fullKey)] = str - default: - str, err := cast.ToStringE(val) - if err != nil { - log.Warn("cast to string failed", zap.Any("val", val)) - continue - } - result[strings.ToLower(fullKey)] = str - result[formatKey(fullKey)] = str - } - } -} diff --git a/pkg/config/file_source.go b/pkg/config/file_source.go index 43a629e3dad3b..e8402efe6b6ad 100644 --- a/pkg/config/file_source.go +++ b/pkg/config/file_source.go @@ -17,14 +17,14 @@ package config import ( - "github.com/spf13/cast" "os" "sync" "github.com/cockroachdb/errors" "github.com/samber/lo" + "github.com/spf13/cast" + "github.com/spf13/viper" "go.uber.org/zap" - "gopkg.in/yaml.v3" "github.com/milvus-io/milvus/pkg/log" ) @@ -114,60 +114,8 @@ func (fs *FileSource) UpdateOptions(opts Options) { fs.files = opts.FileInfo.Files } -func (fs *FileSource) extractConfigFromNode(node *yaml.Node, config map[string]string, prefix string) error { - for i := 0; i < len(node.Content); i += 2 { - keyNode := node.Content[i] - valueNode := node.Content[i+1] - - // Assuming keys are always strings - key := keyNode.Value - fullKey := key - if prefix != "" { - fullKey = prefix + "." + key - } - - switch valueNode.Kind { - case yaml.ScalarNode: - // If it's a scalar, just cast it to string - str, err := cast.ToStringE(valueNode.Value) - if err != nil { - return err - } - config[fullKey] = str - config[formatKey(fullKey)] = str - - case yaml.SequenceNode: - // Handle a list of values - var combinedStr string - for _, item := range valueNode.Content { - str, err := cast.ToStringE(item.Value) - if err != nil { - zap.L().Warn("cast to string failed", zap.Any("value", item.Value)) - continue - } - if combinedStr == "" { - combinedStr = str - } else { - combinedStr = combinedStr + "," + str - } - } - config[fullKey] = combinedStr - config[formatKey(fullKey)] = combinedStr - - case yaml.MappingNode: - // Recursively process nested mappings - if err := fs.extractConfigFromNode(valueNode, config, fullKey); err != nil { - return err - } - - default: - zap.L().Warn("Unhandled YAML node type", zap.Any("node", valueNode)) - } - } - return nil -} - func (fs *FileSource) loadFromFile() error { + yamlReader := viper.New() newConfig := make(map[string]string) var configFiles []string @@ -180,28 +128,37 @@ func (fs *FileSource) loadFromFile() error { continue } - data, err := os.ReadFile(configFile) - if err != nil { - return errors.Wrap(err, "Failed to read config file: "+configFile) - } - - var rootNode yaml.Node - if err := yaml.Unmarshal(data, &rootNode); err != nil { - return errors.Wrap(err, "Failed to unmarshal YAML: "+configFile) - } - - if rootNode.Kind != yaml.DocumentNode || len(rootNode.Content) == 0 { - return errors.New("Invalid YAML structure in file: " + configFile) - } - - // Assuming the top-level node is a map - if rootNode.Content[0].Kind != yaml.MappingNode { - return errors.New("YAML content is not a map in file: " + configFile) + yamlReader.SetConfigFile(configFile) + if err := yamlReader.ReadInConfig(); err != nil { + return errors.Wrap(err, "Read config failed: "+configFile) } - err = fs.extractConfigFromNode(rootNode.Content[0], newConfig, "") - if err != nil { - return errors.Wrap(err, "Failed to extract config: "+configFile) + for _, key := range yamlReader.AllKeys() { + val := yamlReader.Get(key) + str, err := cast.ToStringE(val) + if err != nil { + switch val := val.(type) { + case []any: + str = str[:0] + for _, v := range val { + ss, err := cast.ToStringE(v) + if err != nil { + log.Warn("cast to string failed", zap.Any("value", v)) + } + if str == "" { + str = ss + } else { + str = str + "," + ss + } + } + + default: + log.Warn("val is not a slice", zap.Any("value", val)) + continue + } + } + newConfig[key] = str + newConfig[formatKey(key)] = str } } diff --git a/tests/go_client/testcases/index_test.go b/tests/go_client/testcases/index_test.go index bfa4fcf502e2c..ce080b29e8e37 100644 --- a/tests/go_client/testcases/index_test.go +++ b/tests/go_client/testcases/index_test.go @@ -653,7 +653,7 @@ func TestCreateUnsupportedIndexArrayField(t *testing.T) { if field.DataType == entity.FieldTypeArray { // create vector index _, err1 := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, vectorIdx).WithIndexName("vector_index")) - common.CheckErr(t, err1, false, "data type should be FloatVector, Float16Vector or BFloat16Vector") + common.CheckErr(t, err1, false, "index SCANN only supports vector data type") // create scalar index _, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, idxErr.idx)) @@ -840,11 +840,11 @@ func TestCreateSparseIndexInvalidParams(t *testing.T) { for _, drb := range []float64{-0.3, 1.3} { idxInverted := index.NewSparseInvertedIndex(entity.IP, drb) _, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idxInverted)) - common.CheckErr(t, err, false, "must be in range [0, 1)") + common.CheckErr(t, err, false, "param drop_ratio_build out of range") idxWand := index.NewSparseWANDIndex(entity.IP, drb) _, err1 := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idxWand)) - common.CheckErr(t, err1, false, "must be in range [0, 1)") + common.CheckErr(t, err1, false, "param drop_ratio_build out of range") } } @@ -972,7 +972,7 @@ func TestCreateIndexInvalidParams(t *testing.T) { _, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, cp, hp.TNewFieldsOption(), hp.TNewSchemaOption().TWithEnableDynamicField(true)) // invalid IvfFlat nlist [1, 65536] - errMsg := "nlist out of range: [1, 65536]" + errMsg := "param nlist out of range" for _, invalidNlist := range []int{0, -1, 65536 + 1} { // IvfFlat idxIvfFlat := index.NewIvfFlatIndex(entity.L2, invalidNlist) @@ -997,7 +997,7 @@ func TestCreateIndexInvalidParams(t *testing.T) { // IvfFlat idxIvfPq := index.NewIvfPQIndex(entity.L2, 128, 8, invalidNBits) _, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultFloatVecFieldName, idxIvfPq)) - common.CheckErr(t, err, false, "parameter `nbits` out of range, expect range [1,64]") + common.CheckErr(t, err, false, "param nlist out of range") } idxIvfPq := index.NewIvfPQIndex(entity.L2, 128, 7, 8)