Skip to content

Commit

Permalink
revert config
Browse files Browse the repository at this point in the history
Signed-off-by: xianliang.li <[email protected]>
  • Loading branch information
foxspy committed Oct 8, 2024
1 parent 75d2f0e commit 91ea18d
Show file tree
Hide file tree
Showing 10 changed files with 87 additions and 144 deletions.
4 changes: 2 additions & 2 deletions internal/core/src/index/ScalarIndex.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ class ScalarIndex : public IndexBase {

virtual bool
IsMmapSupported() const {
return index_type_ == milvus::index::BITMAP_INDEX_TYPE ||
index_type_ == milvus::index::HYBRID_INDEX_TYPE;
return index_type_ == milvus::index::BITMAP_INDEX_TYPE ||
index_type_ == milvus::index::HYBRID_INDEX_TYPE;
}

virtual int64_t
Expand Down
3 changes: 2 additions & 1 deletion internal/core/src/index/VectorIndex.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ class VectorIndex : public IndexBase {

virtual bool
IsMmapSupported() const {
return knowhere::IndexFactory::Instance().FeatureCheck(index_type_, knowhere::feature::MMAP);
return knowhere::IndexFactory::Instance().FeatureCheck(
index_type_, knowhere::feature::MMAP);
}

knowhere::Json
Expand Down
46 changes: 37 additions & 9 deletions internal/core/src/segcore/vector_index_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@
#include "pb/index_cgo_msg.pb.h"

CStatus
ValidateIndexParams(const char* index_type, enum CDataType data_type, const uint8_t* serialized_index_params, const uint64_t length) {
ValidateIndexParams(const char* index_type,
enum CDataType data_type,
const uint8_t* serialized_index_params,
const uint64_t length) {
try {
auto index_params =
std::make_unique<milvus::proto::indexcgo::IndexParams>();
auto res = index_params->ParseFromArray(serialized_index_params, length);
std::make_unique<milvus::proto::indexcgo::IndexParams>();
auto res =
index_params->ParseFromArray(serialized_index_params, length);
AssertInfo(res, "Unmarshall index params failed");

knowhere::Json json;
Expand All @@ -40,15 +44,40 @@ ValidateIndexParams(const char* index_type, enum CDataType data_type, const uint
knowhere::Status status;
std::string error_msg;
if (dataType == milvus::DataType::VECTOR_BINARY) {
status = knowhere::CheckConfig<knowhere::bin1>(index_type, knowhere::Version::GetCurrentVersion().VersionNumber(), json, knowhere::PARAM_TYPE::TRAIN, error_msg);
status = knowhere::CheckConfig<knowhere::bin1>(
index_type,
knowhere::Version::GetCurrentVersion().VersionNumber(),
json,
knowhere::PARAM_TYPE::TRAIN,
error_msg);
} else if (dataType == milvus::DataType::VECTOR_FLOAT) {
status = knowhere::CheckConfig<knowhere::fp32>(index_type, knowhere::Version::GetCurrentVersion().VersionNumber(), json, knowhere::PARAM_TYPE::TRAIN, error_msg);
status = knowhere::CheckConfig<knowhere::fp32>(
index_type,
knowhere::Version::GetCurrentVersion().VersionNumber(),
json,
knowhere::PARAM_TYPE::TRAIN,
error_msg);
} else if (dataType == milvus::DataType::VECTOR_BFLOAT16) {
status = knowhere::CheckConfig<knowhere::bf16>(index_type, knowhere::Version::GetCurrentVersion().VersionNumber(), json, knowhere::PARAM_TYPE::TRAIN, error_msg);
status = knowhere::CheckConfig<knowhere::bf16>(
index_type,
knowhere::Version::GetCurrentVersion().VersionNumber(),
json,
knowhere::PARAM_TYPE::TRAIN,
error_msg);
} else if (dataType == milvus::DataType::VECTOR_FLOAT16) {
status = knowhere::CheckConfig<knowhere::fp16>(index_type, knowhere::Version::GetCurrentVersion().VersionNumber(), json, knowhere::PARAM_TYPE::TRAIN, error_msg);
status = knowhere::CheckConfig<knowhere::fp16>(
index_type,
knowhere::Version::GetCurrentVersion().VersionNumber(),
json,
knowhere::PARAM_TYPE::TRAIN,
error_msg);
} else if (dataType == milvus::DataType::VECTOR_SPARSE_FLOAT) {
status = knowhere::CheckConfig<knowhere::fp32>(index_type, knowhere::Version::GetCurrentVersion().VersionNumber(), json, knowhere::PARAM_TYPE::TRAIN, error_msg);
status = knowhere::CheckConfig<knowhere::fp32>(
index_type,
knowhere::Version::GetCurrentVersion().VersionNumber(),
json,
knowhere::PARAM_TYPE::TRAIN,
error_msg);
} else {
status = knowhere::Status::invalid_args;
}
Expand Down Expand Up @@ -87,4 +116,3 @@ GetIndexFeatures(void* index_key_list, uint64_t* index_feature_list) {
idx++;
}
}

6 changes: 4 additions & 2 deletions internal/core/src/segcore/vector_index_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,17 @@ extern "C" {
#include "common/type_c.h"

CStatus
ValidateIndexParams(const char* index_type, enum CDataType data_type, const uint8_t* index_params, const uint64_t length);
ValidateIndexParams(const char* index_type,
enum CDataType data_type,
const uint8_t* index_params,
const uint64_t length);

int
GetIndexListSize();

void
GetIndexFeatures(void* index_key_list, uint64_t* index_feature_list);


#ifdef __cplusplus
}
#endif
2 changes: 1 addition & 1 deletion internal/datacoord/index_meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ import (
"github.com/milvus-io/milvus/internal/metastore"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
"github.com/milvus-io/milvus/internal/proto/workerpb"
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
Expand Down
3 changes: 3 additions & 0 deletions internal/util/hookutil/mock_hook.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
//go:build test
// +build test

/*
* Licensed to the LF AI & Data foundation under one
* or more contributor license agreements. See the NOTICE file
Expand Down
2 changes: 1 addition & 1 deletion internal/util/indexparamcheck/vector_index_checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func (c vecIndexChecker) CheckTrain(dataType schemapb.DataType, params map[strin

func (c vecIndexChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
if !typeutil.IsVectorType(field.GetDataType()) {
return fmt.Errorf("index %s only support vector data type", indexType)
return fmt.Errorf("index %s only supports vector data type", indexType)
}
if !GetVecIndexMgrInstance().IsDataTypeSupport(indexType, field.GetDataType()) {
return fmt.Errorf("index %s do not support data type: %s", indexType, schemapb.DataType_name[int32(field.GetDataType())])
Expand Down
48 changes: 0 additions & 48 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@ import (
"strings"

"github.com/cockroachdb/errors"
"github.com/spf13/cast"
"go.uber.org/zap"

"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)

Expand All @@ -33,10 +30,6 @@ var (
ErrKeyNotFound = errors.New("key not found")
)

const (
NotFormatPrefix = "knowhere."
)

func Init(opts ...Option) (*Manager, error) {
o := &Options{}
for _, opt := range opts {
Expand All @@ -63,9 +56,6 @@ func Init(opts ...Option) (*Manager, error) {
var formattedKeys = typeutil.NewConcurrentMap[string, string]()

func formatKey(key string) string {
if strings.HasPrefix(key, NotFormatPrefix) {
return key
}
cached, ok := formattedKeys.Get(key)
if ok {
return cached
Expand All @@ -74,41 +64,3 @@ func formatKey(key string) string {
formattedKeys.Insert(key, result)
return result
}

func parseConfig(prefix string, m map[string]interface{}, result map[string]string) {
for k, v := range m {
fullKey := k
if prefix != "" {
fullKey = prefix + "." + k
}

switch val := v.(type) {
case map[string]interface{}:
parseConfig(fullKey, val, result)
case []interface{}:
str := ""
for i, item := range val {
itemStr, err := cast.ToStringE(item)
if err != nil {
log.Warn("cast to string failed", zap.Any("item", item))
continue
}
if i == 0 {
str = itemStr
} else {
str = str + "," + itemStr
}
}
result[strings.ToLower(fullKey)] = str
result[formatKey(fullKey)] = str
default:
str, err := cast.ToStringE(val)
if err != nil {
log.Warn("cast to string failed", zap.Any("val", val))
continue
}
result[strings.ToLower(fullKey)] = str
result[formatKey(fullKey)] = str
}
}
}
107 changes: 32 additions & 75 deletions pkg/config/file_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
package config

import (
"github.com/spf13/cast"
"os"
"sync"

"github.com/cockroachdb/errors"
"github.com/samber/lo"
"github.com/spf13/cast"
"github.com/spf13/viper"
"go.uber.org/zap"
"gopkg.in/yaml.v3"

"github.com/milvus-io/milvus/pkg/log"
)
Expand Down Expand Up @@ -114,60 +114,8 @@ func (fs *FileSource) UpdateOptions(opts Options) {
fs.files = opts.FileInfo.Files
}

func (fs *FileSource) extractConfigFromNode(node *yaml.Node, config map[string]string, prefix string) error {
for i := 0; i < len(node.Content); i += 2 {
keyNode := node.Content[i]
valueNode := node.Content[i+1]

// Assuming keys are always strings
key := keyNode.Value
fullKey := key
if prefix != "" {
fullKey = prefix + "." + key
}

switch valueNode.Kind {
case yaml.ScalarNode:
// If it's a scalar, just cast it to string
str, err := cast.ToStringE(valueNode.Value)
if err != nil {
return err
}
config[fullKey] = str
config[formatKey(fullKey)] = str

case yaml.SequenceNode:
// Handle a list of values
var combinedStr string
for _, item := range valueNode.Content {
str, err := cast.ToStringE(item.Value)
if err != nil {
zap.L().Warn("cast to string failed", zap.Any("value", item.Value))
continue
}
if combinedStr == "" {
combinedStr = str
} else {
combinedStr = combinedStr + "," + str
}
}
config[fullKey] = combinedStr
config[formatKey(fullKey)] = combinedStr

case yaml.MappingNode:
// Recursively process nested mappings
if err := fs.extractConfigFromNode(valueNode, config, fullKey); err != nil {
return err
}

default:
zap.L().Warn("Unhandled YAML node type", zap.Any("node", valueNode))
}
}
return nil
}

func (fs *FileSource) loadFromFile() error {
yamlReader := viper.New()
newConfig := make(map[string]string)
var configFiles []string

Expand All @@ -180,28 +128,37 @@ func (fs *FileSource) loadFromFile() error {
continue
}

data, err := os.ReadFile(configFile)
if err != nil {
return errors.Wrap(err, "Failed to read config file: "+configFile)
}

var rootNode yaml.Node
if err := yaml.Unmarshal(data, &rootNode); err != nil {
return errors.Wrap(err, "Failed to unmarshal YAML: "+configFile)
}

if rootNode.Kind != yaml.DocumentNode || len(rootNode.Content) == 0 {
return errors.New("Invalid YAML structure in file: " + configFile)
}

// Assuming the top-level node is a map
if rootNode.Content[0].Kind != yaml.MappingNode {
return errors.New("YAML content is not a map in file: " + configFile)
yamlReader.SetConfigFile(configFile)
if err := yamlReader.ReadInConfig(); err != nil {
return errors.Wrap(err, "Read config failed: "+configFile)
}

err = fs.extractConfigFromNode(rootNode.Content[0], newConfig, "")
if err != nil {
return errors.Wrap(err, "Failed to extract config: "+configFile)
for _, key := range yamlReader.AllKeys() {
val := yamlReader.Get(key)
str, err := cast.ToStringE(val)
if err != nil {
switch val := val.(type) {
case []any:
str = str[:0]
for _, v := range val {
ss, err := cast.ToStringE(v)
if err != nil {
log.Warn("cast to string failed", zap.Any("value", v))
}
if str == "" {
str = ss
} else {
str = str + "," + ss
}
}

default:
log.Warn("val is not a slice", zap.Any("value", val))
continue
}
}
newConfig[key] = str
newConfig[formatKey(key)] = str
}
}

Expand Down
10 changes: 5 additions & 5 deletions tests/go_client/testcases/index_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ func TestCreateUnsupportedIndexArrayField(t *testing.T) {
if field.DataType == entity.FieldTypeArray {
// create vector index
_, err1 := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, vectorIdx).WithIndexName("vector_index"))
common.CheckErr(t, err1, false, "data type should be FloatVector, Float16Vector or BFloat16Vector")
common.CheckErr(t, err1, false, "index SCANN only supports vector data type")

// create scalar index
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, idxErr.idx))
Expand Down Expand Up @@ -840,11 +840,11 @@ func TestCreateSparseIndexInvalidParams(t *testing.T) {
for _, drb := range []float64{-0.3, 1.3} {
idxInverted := index.NewSparseInvertedIndex(entity.IP, drb)
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idxInverted))
common.CheckErr(t, err, false, "must be in range [0, 1)")
common.CheckErr(t, err, false, "param drop_ratio_build out of range")

idxWand := index.NewSparseWANDIndex(entity.IP, drb)
_, err1 := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idxWand))
common.CheckErr(t, err1, false, "must be in range [0, 1)")
common.CheckErr(t, err1, false, "param drop_ratio_build out of range")
}
}

Expand Down Expand Up @@ -972,7 +972,7 @@ func TestCreateIndexInvalidParams(t *testing.T) {
_, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, cp, hp.TNewFieldsOption(), hp.TNewSchemaOption().TWithEnableDynamicField(true))

// invalid IvfFlat nlist [1, 65536]
errMsg := "nlist out of range: [1, 65536]"
errMsg := "param nlist out of range"
for _, invalidNlist := range []int{0, -1, 65536 + 1} {
// IvfFlat
idxIvfFlat := index.NewIvfFlatIndex(entity.L2, invalidNlist)
Expand All @@ -997,7 +997,7 @@ func TestCreateIndexInvalidParams(t *testing.T) {
// IvfFlat
idxIvfPq := index.NewIvfPQIndex(entity.L2, 128, 8, invalidNBits)
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultFloatVecFieldName, idxIvfPq))
common.CheckErr(t, err, false, "parameter `nbits` out of range, expect range [1,64]")
common.CheckErr(t, err, false, "param nlist out of range")
}

idxIvfPq := index.NewIvfPQIndex(entity.L2, 128, 7, 8)
Expand Down

0 comments on commit 91ea18d

Please sign in to comment.