From b524413e6c027b2e6cbbdae13475cb7e5bb5a79c Mon Sep 17 00:00:00 2001 From: "xianliang.li" Date: Wed, 21 Aug 2024 18:04:12 +0800 Subject: [PATCH] add index config check Signed-off-by: xianliang.li --- cmd/tools/migration/mmap/mmap_230_240.go | 2 +- configs/milvus.yaml | 20 +++ internal/core/src/segcore/vector_index_c.cpp | 52 ++++++++ internal/core/src/segcore/vector_index_c.h | 4 + .../core/thirdparty/knowhere/CMakeLists.txt | 2 + internal/datacoord/compaction_trigger_v2.go | 2 +- internal/datacoord/index_meta.go | 2 +- internal/datacoord/index_service.go | 12 +- internal/datacoord/index_service_test.go | 7 +- internal/datacoord/task_index.go | 12 ++ internal/datacoord/util.go | 2 +- internal/indexnode/task_index.go | 9 +- internal/proxy/task_index.go | 34 ++--- internal/proxy/task_index_test.go | 2 +- internal/proxy/util.go | 2 +- internal/querynodev2/segments/collection.go | 2 +- .../querynodev2/segments/index_attr_cache.go | 6 +- .../segments/index_attr_cache_test.go | 3 +- internal/querynodev2/segments/segment.go | 2 +- .../segments/segment_loader_test.go | 2 +- internal/querynodev2/segments/utils.go | 2 +- internal/querynodev2/segments/utils_test.go | 3 +- .../indexparamcheck/auto_index_checker.go | 4 +- .../util/indexparamcheck/base_checker.go | 27 ++-- .../util/indexparamcheck/base_checker_test.go | 6 +- .../util/indexparamcheck/bin_flat_checker.go | 19 +++ .../indexparamcheck/bin_flat_checker_test.go | 4 +- .../indexparamcheck/bin_ivf_flat_checker.go | 9 +- .../bin_ivf_flat_checker_test.go | 4 +- .../binary_vector_base_checker.go | 8 +- .../binary_vector_base_checker_test.go | 2 +- .../indexparamcheck/bitmap_checker_test.go | 32 +++++ .../indexparamcheck/bitmap_index_checker.go | 6 +- .../util/indexparamcheck/cagra_checker.go | 7 +- .../indexparamcheck/cagra_checker_test.go | 2 +- .../util/indexparamcheck/conf_adapter_mgr.go | 23 +--- .../indexparamcheck/conf_adapter_mgr_test.go | 34 ++--- .../util/indexparamcheck/constraints.go | 0 .../util/indexparamcheck/diskann_checker.go | 4 +- .../indexparamcheck/diskann_checker_test.go | 4 +- .../util/indexparamcheck/flat_checker.go | 4 +- .../util/indexparamcheck/flat_checker_test.go | 5 +- .../float_vector_base_checker.go | 8 +- .../float_vector_base_checker_test.go | 2 +- .../util/indexparamcheck/hnsw_checker.go | 12 +- .../util/indexparamcheck/hnsw_checker_test.go | 6 +- .../indexparamcheck/hybrid_checker_test.go | 37 ++++++ .../indexparamcheck/hybrid_index_checker.go | 6 +- .../util/indexparamcheck/index_checker.go | 8 +- .../indexparamcheck/index_checker_test.go | 0 .../util/indexparamcheck/index_type.go | 18 +-- .../util/indexparamcheck/index_type_test.go | 0 .../util/indexparamcheck/inverted_checker.go | 6 +- .../indexparamcheck/inverted_checker_test.go | 25 ++++ .../util/indexparamcheck/ivf_base_checker.go | 28 +++++ .../indexparamcheck/ivf_base_checker_test.go | 4 +- .../util/indexparamcheck/ivf_pq_checker.go | 5 +- .../indexparamcheck/ivf_pq_checker_test.go | 4 +- .../util/indexparamcheck/ivf_sq_checker.go | 5 +- .../indexparamcheck/ivf_sq_checker_test.go | 4 +- .../raft_brute_force_checker.go | 9 +- .../raft_brute_force_checker_test.go | 2 +- .../indexparamcheck/raft_ivf_flat_checker.go | 9 +- .../raft_ivf_flat_checker_test.go | 4 +- .../indexparamcheck/raft_ivf_pq_checker.go | 5 +- .../raft_ivf_pq_checker_test.go | 4 +- .../indexparamcheck/scalar_index_checker.go | 11 ++ .../scalar_index_checker_test.go | 2 +- .../util/indexparamcheck/scann_checker.go | 5 +- .../indexparamcheck/scann_checker_test.go | 4 +- .../sparse_float_vector_base_checker.go | 8 +- .../sparse_inverted_index_checker.go | 0 .../util/indexparamcheck/stl_sort_checker.go | 6 +- .../indexparamcheck/stl_sort_checker_test.go | 22 ++++ .../util/indexparamcheck/trie_checker.go | 6 +- .../util/indexparamcheck/trie_checker_test.go | 23 ++++ .../util/indexparamcheck/utils.go | 0 .../util/indexparamcheck/utils_test.go | 0 .../indexparamcheck/vector_index_checker.go | 95 ++++++++++++++ .../util/indexparamcheck/vector_index_mgr.go | 72 ++++++++++- .../indexparamcheck/vector_index_mgr_test.go | 0 pkg/config/config.go | 48 +++++++ pkg/config/file_source.go | 41 ++---- pkg/util/indexparamcheck/bin_flat_checker.go | 17 --- .../indexparamcheck/hybrid_checker_test.go | 37 ------ .../indexparamcheck/inverted_checker_test.go | 25 ---- pkg/util/indexparamcheck/ivf_base_checker.go | 26 ---- .../indexparamcheck/scalar_index_checker.go | 9 -- .../indexparamcheck/stl_sort_checker_test.go | 22 ---- pkg/util/indexparamcheck/trie_checker_test.go | 23 ---- pkg/util/paramtable/autoindex_param.go | 6 +- pkg/util/paramtable/autoindex_param_test.go | 2 +- pkg/util/paramtable/component_param.go | 2 + pkg/util/paramtable/indexengine_param.go | 117 ++++++++++++++++++ pkg/util/paramtable/indexengine_param_test.go | 20 +++ tests/integration/import/import_test.go | 3 +- 96 files changed, 849 insertions(+), 402 deletions(-) rename {pkg => internal}/util/indexparamcheck/auto_index_checker.go (59%) rename {pkg => internal}/util/indexparamcheck/base_checker.go (67%) rename {pkg => internal}/util/indexparamcheck/base_checker_test.go (93%) create mode 100644 internal/util/indexparamcheck/bin_flat_checker.go rename {pkg => internal}/util/indexparamcheck/bin_flat_checker_test.go (96%) rename {pkg => internal}/util/indexparamcheck/bin_ivf_flat_checker.go (55%) rename {pkg => internal}/util/indexparamcheck/bin_ivf_flat_checker_test.go (97%) rename {pkg => internal}/util/indexparamcheck/binary_vector_base_checker.go (72%) rename {pkg => internal}/util/indexparamcheck/binary_vector_base_checker_test.go (96%) create mode 100644 internal/util/indexparamcheck/bitmap_checker_test.go rename {pkg => internal}/util/indexparamcheck/bitmap_index_checker.go (79%) rename {pkg => internal}/util/indexparamcheck/cagra_checker.go (84%) rename {pkg => internal}/util/indexparamcheck/cagra_checker_test.go (98%) rename {pkg => internal}/util/indexparamcheck/conf_adapter_mgr.go (71%) rename {pkg => internal}/util/indexparamcheck/conf_adapter_mgr_test.go (78%) rename {pkg => internal}/util/indexparamcheck/constraints.go (100%) rename {pkg => internal}/util/indexparamcheck/diskann_checker.go (59%) rename {pkg => internal}/util/indexparamcheck/diskann_checker_test.go (95%) rename {pkg => internal}/util/indexparamcheck/flat_checker.go (52%) rename {pkg => internal}/util/indexparamcheck/flat_checker_test.go (91%) rename {pkg => internal}/util/indexparamcheck/float_vector_base_checker.go (69%) rename {pkg => internal}/util/indexparamcheck/float_vector_base_checker_test.go (94%) rename {pkg => internal}/util/indexparamcheck/hnsw_checker.go (72%) rename {pkg => internal}/util/indexparamcheck/hnsw_checker_test.go (96%) create mode 100644 internal/util/indexparamcheck/hybrid_checker_test.go rename {pkg => internal}/util/indexparamcheck/hybrid_index_checker.go (83%) rename {pkg => internal}/util/indexparamcheck/index_checker.go (77%) rename {pkg => internal}/util/indexparamcheck/index_checker_test.go (100%) rename {pkg => internal}/util/indexparamcheck/index_type.go (78%) rename {pkg => internal}/util/indexparamcheck/index_type_test.go (100%) rename {pkg => internal}/util/indexparamcheck/inverted_checker.go (70%) create mode 100644 internal/util/indexparamcheck/inverted_checker_test.go create mode 100644 internal/util/indexparamcheck/ivf_base_checker.go rename {pkg => internal}/util/indexparamcheck/ivf_base_checker_test.go (95%) rename {pkg => internal}/util/indexparamcheck/ivf_pq_checker.go (87%) rename {pkg => internal}/util/indexparamcheck/ivf_pq_checker_test.go (97%) rename {pkg => internal}/util/indexparamcheck/ivf_sq_checker.go (78%) rename {pkg => internal}/util/indexparamcheck/ivf_sq_checker_test.go (96%) rename {pkg => internal}/util/indexparamcheck/raft_brute_force_checker.go (62%) rename {pkg => internal}/util/indexparamcheck/raft_brute_force_checker_test.go (96%) rename {pkg => internal}/util/indexparamcheck/raft_ivf_flat_checker.go (75%) rename {pkg => internal}/util/indexparamcheck/raft_ivf_flat_checker_test.go (96%) rename {pkg => internal}/util/indexparamcheck/raft_ivf_pq_checker.go (88%) rename {pkg => internal}/util/indexparamcheck/raft_ivf_pq_checker_test.go (97%) create mode 100644 internal/util/indexparamcheck/scalar_index_checker.go rename {pkg => internal}/util/indexparamcheck/scalar_index_checker_test.go (70%) rename {pkg => internal}/util/indexparamcheck/scann_checker.go (78%) rename {pkg => internal}/util/indexparamcheck/scann_checker_test.go (96%) rename {pkg => internal}/util/indexparamcheck/sparse_float_vector_base_checker.go (75%) rename {pkg => internal}/util/indexparamcheck/sparse_inverted_index_checker.go (100%) rename {pkg => internal}/util/indexparamcheck/stl_sort_checker.go (65%) create mode 100644 internal/util/indexparamcheck/stl_sort_checker_test.go rename {pkg => internal}/util/indexparamcheck/trie_checker.go (64%) create mode 100644 internal/util/indexparamcheck/trie_checker_test.go rename {pkg => internal}/util/indexparamcheck/utils.go (100%) rename {pkg => internal}/util/indexparamcheck/utils_test.go (100%) create mode 100644 internal/util/indexparamcheck/vector_index_checker.go rename {pkg => internal}/util/indexparamcheck/vector_index_mgr.go (69%) rename {pkg => internal}/util/indexparamcheck/vector_index_mgr_test.go (100%) delete mode 100644 pkg/util/indexparamcheck/bin_flat_checker.go delete mode 100644 pkg/util/indexparamcheck/hybrid_checker_test.go delete mode 100644 pkg/util/indexparamcheck/inverted_checker_test.go delete mode 100644 pkg/util/indexparamcheck/ivf_base_checker.go delete mode 100644 pkg/util/indexparamcheck/scalar_index_checker.go delete mode 100644 pkg/util/indexparamcheck/stl_sort_checker_test.go delete mode 100644 pkg/util/indexparamcheck/trie_checker_test.go create mode 100644 pkg/util/paramtable/indexengine_param.go create mode 100644 pkg/util/paramtable/indexengine_param_test.go diff --git a/cmd/tools/migration/mmap/mmap_230_240.go b/cmd/tools/migration/mmap/mmap_230_240.go index 8994551d02d7a..d25136b2b1943 100644 --- a/cmd/tools/migration/mmap/mmap_230_240.go +++ b/cmd/tools/migration/mmap/mmap_230_240.go @@ -3,6 +3,7 @@ package mmap import ( "context" "fmt" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/metastore" @@ -10,7 +11,6 @@ import ( "github.com/milvus-io/milvus/internal/rootcoord" "github.com/milvus-io/milvus/internal/tso" "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" ) // In Milvus 2.3.x, querynode.MmapDirPath is used to enable mmap and save mmap files. diff --git a/configs/milvus.yaml b/configs/milvus.yaml index f931e3cbde1dc..82997519b7ae4 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -1044,3 +1044,23 @@ streamingNode: serverMaxRecvSize: 268435456 # The maximum size of each RPC request that the streamingNode can receive, unit: byte clientMaxSendSize: 268435456 # The maximum size of each RPC request that the clients on streamingNode can send, unit: byte clientMaxRecvSize: 268435456 # The maximum size of each RPC request that the clients on streamingNode can receive, unit: byte + +knowhere: + enable: true + HNSW: + build: + efConstruction : 360 + M: 30 + search: + ef: 30 + DISKANN: + build: + max_degree: 56 + search_list_size: 100 + pq_code_budget_gb_ratio: 0.125 + search_cache_budget_gb_ratio: 0.1 + beam_width_ratio: 4 + load: + cacheRatio: 0.1 + search: + beamRatio: 4.0 \ No newline at end of file diff --git a/internal/core/src/segcore/vector_index_c.cpp b/internal/core/src/segcore/vector_index_c.cpp index b45a684d57903..558282bd3eb4a 100644 --- a/internal/core/src/segcore/vector_index_c.cpp +++ b/internal/core/src/segcore/vector_index_c.cpp @@ -11,9 +11,61 @@ #include "segcore/vector_index_c.h" +#include "common/Types.h" +#include "common/EasyAssert.h" #include "knowhere/utils.h" +#include "knowhere/config.h" +#include "knowhere/version.h" #include "index/Meta.h" #include "index/IndexFactory.h" +#include "pb/index_cgo_msg.pb.h" + +CStatus +ValidateIndexParams(const char* index_type, enum CDataType data_type, const uint8_t* serialized_index_params, const uint64_t length) { + try { + auto index_params = + std::make_unique(); + auto res = index_params->ParseFromArray(serialized_index_params, length); + AssertInfo(res, "Unmarshall index params failed"); + + knowhere::Json json; + + for (size_t i = 0; i < index_params->params_size(); i++) { + auto& param = index_params->params(i); + json[param.key()] = param.value(); + } + + milvus::DataType dataType(static_cast(data_type)); + + knowhere::Status status; + std::string error_msg; + if (dataType == milvus::DataType::VECTOR_BINARY) { + status = knowhere::CheckConfig(index_type, knowhere::Version::GetCurrentVersion().VersionNumber(), json, knowhere::PARAM_TYPE::TRAIN, error_msg); + } else if (dataType == milvus::DataType::VECTOR_FLOAT) { + status = knowhere::CheckConfig(index_type, knowhere::Version::GetCurrentVersion().VersionNumber(), json, knowhere::PARAM_TYPE::TRAIN, error_msg); + } else if (dataType == milvus::DataType::VECTOR_BFLOAT16) { + status = knowhere::CheckConfig(index_type, knowhere::Version::GetCurrentVersion().VersionNumber(), json, knowhere::PARAM_TYPE::TRAIN, error_msg); + } else if (dataType == milvus::DataType::VECTOR_FLOAT16) { + status = knowhere::CheckConfig(index_type, knowhere::Version::GetCurrentVersion().VersionNumber(), json, knowhere::PARAM_TYPE::TRAIN, error_msg); + } else if (dataType == milvus::DataType::VECTOR_SPARSE_FLOAT) { + status = knowhere::CheckConfig(index_type, knowhere::Version::GetCurrentVersion().VersionNumber(), json, knowhere::PARAM_TYPE::TRAIN, error_msg); + } else { + status = knowhere::Status::invalid_args; + } + CStatus cStatus; + if (status == knowhere::Status::success) { + cStatus.error_code = milvus::Success; + cStatus.error_msg = ""; + } else { + cStatus.error_code = milvus::ConfigInvalid; + cStatus.error_msg = error_msg.c_str(); + } + } catch (std::exception& e) { + auto status = CStatus(); + status.error_code = milvus::UnexpectedError; + status.error_msg = strdup(e.what()); + } +} int GetIndexListSize() { diff --git a/internal/core/src/segcore/vector_index_c.h b/internal/core/src/segcore/vector_index_c.h index d38a06b7a1e15..06160e0bd6f9e 100644 --- a/internal/core/src/segcore/vector_index_c.h +++ b/internal/core/src/segcore/vector_index_c.h @@ -15,6 +15,10 @@ extern "C" { #endif #include +#include "common/type_c.h" + +CStatus +ValidateIndexParams(const char* index_type, enum CDataType data_type, const uint8_t* index_params, const uint64_t length); int GetIndexListSize(); diff --git a/internal/core/thirdparty/knowhere/CMakeLists.txt b/internal/core/thirdparty/knowhere/CMakeLists.txt index 3585ee2f7b3cd..c8304e840f2a7 100644 --- a/internal/core/thirdparty/knowhere/CMakeLists.txt +++ b/internal/core/thirdparty/knowhere/CMakeLists.txt @@ -60,3 +60,5 @@ endif() # get prometheus COMPILE_OPTIONS get_property( var DIRECTORY "${knowhere_SOURCE_DIR}" PROPERTY COMPILE_OPTIONS ) message( STATUS "knowhere src compile options: ${var}" ) + +set( KNOWHERE_INCLUDE_DIR ${knowhere_SOURCE_DIR}/include CACHE INTERNAL "Path to knowhere include directory" ) diff --git a/internal/datacoord/compaction_trigger_v2.go b/internal/datacoord/compaction_trigger_v2.go index add3ccdb3162a..3aec22c21626d 100644 --- a/internal/datacoord/compaction_trigger_v2.go +++ b/internal/datacoord/compaction_trigger_v2.go @@ -18,6 +18,7 @@ package datacoord import ( "context" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "sync" "time" @@ -29,7 +30,6 @@ import ( "github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/lock" "github.com/milvus-io/milvus/pkg/util/logutil" "github.com/milvus-io/milvus/pkg/util/typeutil" diff --git a/internal/datacoord/index_meta.go b/internal/datacoord/index_meta.go index 41685db08e805..c6d93851f4574 100644 --- a/internal/datacoord/index_meta.go +++ b/internal/datacoord/index_meta.go @@ -20,6 +20,7 @@ package datacoord import ( "context" "fmt" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "strconv" "sync" @@ -36,7 +37,6 @@ import ( "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/timerecord" "github.com/milvus-io/milvus/pkg/util/typeutil" ) diff --git a/internal/datacoord/index_service.go b/internal/datacoord/index_service.go index dea94dd9a49a2..567588b8fc287 100644 --- a/internal/datacoord/index_service.go +++ b/internal/datacoord/index_service.go @@ -19,6 +19,7 @@ package datacoord import ( "context" "fmt" + indexparamcheck2 "github.com/milvus-io/milvus/internal/util/indexparamcheck" "time" "github.com/samber/lo" @@ -30,7 +31,6 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metautil" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -222,7 +222,7 @@ func (s *Server) CreateIndex(ctx context.Context, req *indexpb.CreateIndexReques metrics.IndexRequestCounter.WithLabelValues(metrics.FailLabel).Inc() return merr.Status(err), nil } - if indexparamcheck.GetVecIndexMgrInstance().IsDiskANN(GetIndexType(req.IndexParams)) && !s.indexNodeManager.ClientSupportDisk() { + if indexparamcheck2.GetVecIndexMgrInstance().IsDiskANN(GetIndexType(req.IndexParams)) && !s.indexNodeManager.ClientSupportDisk() { errMsg := "all IndexNodes do not support disk indexes, please verify" log.Warn(errMsg) err = merr.WrapErrIndexNotSupported(GetIndexType(req.IndexParams)) @@ -273,16 +273,16 @@ func ValidateIndexParams(index *model.Index) error { indexType := GetIndexType(index.IndexParams) indexParams := funcutil.KeyValuePair2Map(index.IndexParams) userIndexParams := funcutil.KeyValuePair2Map(index.UserIndexParams) - if err := indexparamcheck.ValidateMmapIndexParams(indexType, indexParams); err != nil { + if err := indexparamcheck2.ValidateMmapIndexParams(indexType, indexParams); err != nil { return merr.WrapErrParameterInvalidMsg("invalid mmap index params", err.Error()) } - if err := indexparamcheck.ValidateMmapIndexParams(indexType, userIndexParams); err != nil { + if err := indexparamcheck2.ValidateMmapIndexParams(indexType, userIndexParams); err != nil { return merr.WrapErrParameterInvalidMsg("invalid mmap user index params", err.Error()) } - if err := indexparamcheck.ValidateOffsetCacheIndexParams(indexType, indexParams); err != nil { + if err := indexparamcheck2.ValidateOffsetCacheIndexParams(indexType, indexParams); err != nil { return merr.WrapErrParameterInvalidMsg("invalid offset cache index params", err.Error()) } - if err := indexparamcheck.ValidateOffsetCacheIndexParams(indexType, userIndexParams); err != nil { + if err := indexparamcheck2.ValidateOffsetCacheIndexParams(indexType, userIndexParams); err != nil { return merr.WrapErrParameterInvalidMsg("invalid offset cache index params", err.Error()) } return nil diff --git a/internal/datacoord/index_service_test.go b/internal/datacoord/index_service_test.go index 75257243d5700..859e85e783b5e 100644 --- a/internal/datacoord/index_service_test.go +++ b/internal/datacoord/index_service_test.go @@ -19,6 +19,7 @@ package datacoord import ( "context" "fmt" + indexparamcheck2 "github.com/milvus-io/milvus/internal/util/indexparamcheck" "testing" "time" @@ -2427,7 +2428,7 @@ func TestValidateIndexParams(t *testing.T) { IndexParams: []*commonpb.KeyValuePair{ { Key: common.IndexTypeKey, - Value: indexparamcheck.AutoIndex, + Value: indexparamcheck2.AutoIndex, }, { Key: common.MmapEnabledKey, @@ -2444,7 +2445,7 @@ func TestValidateIndexParams(t *testing.T) { IndexParams: []*commonpb.KeyValuePair{ { Key: common.IndexTypeKey, - Value: indexparamcheck.AutoIndex, + Value: indexparamcheck2.AutoIndex, }, { Key: common.MmapEnabledKey, @@ -2461,7 +2462,7 @@ func TestValidateIndexParams(t *testing.T) { IndexParams: []*commonpb.KeyValuePair{ { Key: common.IndexTypeKey, - Value: indexparamcheck.AutoIndex, + Value: indexparamcheck2.AutoIndex, }, }, UserIndexParams: []*commonpb.KeyValuePair{ diff --git a/internal/datacoord/task_index.go b/internal/datacoord/task_index.go index 7cf782322ef4d..1c4858955a429 100644 --- a/internal/datacoord/task_index.go +++ b/internal/datacoord/task_index.go @@ -18,6 +18,7 @@ package datacoord import ( "context" + "github.com/milvus-io/milvus/pkg/util/paramtable" "path" "time" @@ -168,6 +169,17 @@ func (it *indexBuildTask) PreCheck(ctx context.Context, dependency *taskSchedule fieldID := dependency.meta.indexMeta.GetFieldIDByIndexID(segIndex.CollectionID, segIndex.IndexID) binlogIDs := getBinLogIDs(segment, fieldID) + if Params.IndexEngineConfig.Enable.GetAsBool() { + var ret error + indexParams, ret = Params.IndexEngineConfig.MergeRequestParam(GetIndexType(indexParams), paramtable.BuildStage, indexParams) + + if ret != nil { + log.Ctx(ctx).Warn("failed to construct index build params", zap.Int64("taskID", it.taskID), zap.Error(ret)) + it.SetState(indexpb.JobState_JobStateInit, ret.Error()) + return true + } + } + if isDiskANNIndex(GetIndexType(indexParams)) { var err error indexParams, err = indexparams.UpdateDiskIndexBuildParams(Params, indexParams) diff --git a/internal/datacoord/util.go b/internal/datacoord/util.go index 164727ddc746a..ab0384d2d5b3d 100644 --- a/internal/datacoord/util.go +++ b/internal/datacoord/util.go @@ -19,6 +19,7 @@ package datacoord import ( "context" "fmt" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "strconv" "strings" "time" @@ -33,7 +34,6 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/tsoutil" diff --git a/internal/indexnode/task_index.go b/internal/indexnode/task_index.go index e926eb3c6a73e..2284c8e41e41d 100644 --- a/internal/indexnode/task_index.go +++ b/internal/indexnode/task_index.go @@ -19,6 +19,7 @@ package indexnode import ( "context" "fmt" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "strconv" "strings" "time" @@ -35,7 +36,6 @@ import ( "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/indexparams" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metautil" @@ -208,6 +208,7 @@ func (it *indexBuildTask) Execute(ctx context.Context) error { zap.Int32("currentIndexVersion", it.req.GetCurrentIndexVersion())) indexType := it.newIndexParams[common.IndexTypeKey] + var fieldDataSize uint64 if indexparamcheck.GetVecIndexMgrInstance().IsDiskANN(indexType) { // check index node support disk index if !Params.IndexNodeCfg.EnableDisk.GetAsBool() { @@ -223,7 +224,7 @@ func (it *indexBuildTask) Execute(ctx context.Context) error { log.Warn("IndexNode get local used size failed") return err } - fieldDataSize, err := estimateFieldDataSize(it.req.GetDim(), it.req.GetNumRows(), it.req.GetField().GetDataType()) + fieldDataSize, err = estimateFieldDataSize(it.req.GetDim(), it.req.GetNumRows(), it.req.GetField().GetDataType()) if err != nil { log.Warn("IndexNode get local used size failed") return err @@ -245,6 +246,10 @@ func (it *indexBuildTask) Execute(ctx context.Context) error { } } + if Params.IndexEngineConfig.Enable.GetAsBool() { + it.newIndexParams, _ = Params.IndexEngineConfig.MergeWithResource(fieldDataSize, it.newIndexParams) + } + storageConfig := &indexcgopb.StorageConfig{ Address: it.req.GetStorageConfig().GetAddress(), AccessKeyID: it.req.GetStorageConfig().GetAccessKeyID(), diff --git a/internal/proxy/task_index.go b/internal/proxy/task_index.go index 994912d9e292f..258c794d65c74 100644 --- a/internal/proxy/task_index.go +++ b/internal/proxy/task_index.go @@ -19,6 +19,7 @@ package proxy import ( "context" "fmt" + indexparamcheck2 "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/cockroachdb/errors" "go.uber.org/zap" @@ -33,7 +34,6 @@ import ( "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/indexparams" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metric" @@ -155,13 +155,13 @@ func (cit *createIndexTask) parseIndexParams() error { specifyIndexType, exist := indexParamsMap[common.IndexTypeKey] if exist && specifyIndexType != "" { - if err := indexparamcheck.ValidateMmapIndexParams(specifyIndexType, indexParamsMap); err != nil { + if err := indexparamcheck2.ValidateMmapIndexParams(specifyIndexType, indexParamsMap); err != nil { log.Ctx(cit.ctx).Warn("Invalid mmap type params", zap.String(common.IndexTypeKey, specifyIndexType), zap.Error(err)) return merr.WrapErrParameterInvalidMsg("invalid mmap type params", err.Error()) } - checker, err := indexparamcheck.GetIndexCheckerMgrInstance().GetChecker(specifyIndexType) + checker, err := indexparamcheck2.GetIndexCheckerMgrInstance().GetChecker(specifyIndexType) // not enable hybrid index for user, used in milvus internally - if err != nil || indexparamcheck.IsHYBRIDChecker(checker) { + if err != nil || indexparamcheck2.IsHYBRIDChecker(checker) { log.Ctx(cit.ctx).Warn("Failed to get index checker", zap.String(common.IndexTypeKey, specifyIndexType)) return merr.WrapErrParameterInvalid("valid index", fmt.Sprintf("invalid index type: %s", specifyIndexType)) } @@ -297,7 +297,14 @@ func (cit *createIndexTask) parseIndexParams() error { if !exist { return fmt.Errorf("IndexType not specified") } - if indexparamcheck.GetVecIndexMgrInstance().IsDiskANN(indexType) { + if Params.IndexEngineConfig.Enable.GetAsBool() { + var err error + indexParamsMap, err = Params.IndexEngineConfig.MergeRequestMapParam(indexType, paramtable.BuildStage, indexParamsMap) + if err != nil { + return err + } + } + if indexparamcheck2.GetVecIndexMgrInstance().IsDiskANN(indexType) { err := indexparams.FillDiskIndexParams(Params, indexParamsMap) if err != nil { return err @@ -308,7 +315,7 @@ func (cit *createIndexTask) parseIndexParams() error { return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "metric type not set for vector index") } if typeutil.IsDenseFloatVectorType(cit.fieldSchema.DataType) { - if !funcutil.SliceContain(indexparamcheck.FloatVectorMetrics, metricType) { + if !funcutil.SliceContain(indexparamcheck2.FloatVectorMetrics, metricType) { return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "float vector index does not support metric type: "+metricType) } } else if typeutil.IsSparseFloatVectorType(cit.fieldSchema.DataType) { @@ -316,7 +323,7 @@ func (cit *createIndexTask) parseIndexParams() error { return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "only IP is the supported metric type for sparse index") } } else if typeutil.IsBinaryVectorType(cit.fieldSchema.DataType) { - if !funcutil.SliceContain(indexparamcheck.BinaryVectorMetrics, metricType) { + if !funcutil.SliceContain(indexparamcheck2.BinaryVectorMetrics, metricType) { return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "binary vector index does not support metric type: "+metricType) } } @@ -391,19 +398,19 @@ func fillDimension(field *schemapb.FieldSchema, indexParams map[string]string) e func checkTrain(field *schemapb.FieldSchema, indexParams map[string]string) error { indexType := indexParams[common.IndexTypeKey] - if indexType == indexparamcheck.IndexHybrid { + if indexType == indexparamcheck2.IndexHybrid { _, exist := indexParams[common.BitmapCardinalityLimitKey] if !exist { indexParams[common.BitmapCardinalityLimitKey] = paramtable.Get().CommonCfg.BitmapIndexCardinalityBound.GetValue() } } - checker, err := indexparamcheck.GetIndexCheckerMgrInstance().GetChecker(indexType) + checker, err := indexparamcheck2.GetIndexCheckerMgrInstance().GetChecker(indexType) if err != nil { log.Warn("Failed to get index checker", zap.String(common.IndexTypeKey, indexType)) return fmt.Errorf("invalid index type: %s", indexType) } - if typeutil.IsVectorType(field.DataType) && indexType != indexparamcheck.AutoIndex { + if typeutil.IsVectorType(field.DataType) && indexType != indexparamcheck2.AutoIndex { exist := CheckVecIndexWithDataTypeExist(indexType, field.DataType) if !exist { return fmt.Errorf("data type %d can't build with this index %s", field.DataType, indexType) @@ -416,17 +423,14 @@ func checkTrain(field *schemapb.FieldSchema, indexParams map[string]string) erro if err := fillDimension(field, indexParams); err != nil { return err } - } else { - // used only for checker, should be deleted after checking - indexParams[IsSparseKey] = "true" } - if err := checker.CheckValidDataType(field); err != nil { + if err := checker.CheckValidDataType(indexType, field); err != nil { log.Info("create index with invalid data type", zap.Error(err), zap.String("data_type", field.GetDataType().String())) return err } - if err := checker.CheckTrain(indexParams); err != nil { + if err := checker.CheckTrain(field.DataType, indexParams); err != nil { log.Info("create index with invalid parameters", zap.Error(err)) return err } diff --git a/internal/proxy/task_index_test.go b/internal/proxy/task_index_test.go index 8c89d403decf7..4525ef8840709 100644 --- a/internal/proxy/task_index_test.go +++ b/internal/proxy/task_index_test.go @@ -19,6 +19,7 @@ package proxy import ( "context" "encoding/json" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "os" "sort" "strconv" @@ -37,7 +38,6 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" diff --git a/internal/proxy/util.go b/internal/proxy/util.go index d369247331edf..0288a6dc322cd 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -20,6 +20,7 @@ import ( "context" "encoding/json" "fmt" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "strconv" "strings" "time" @@ -47,7 +48,6 @@ import ( "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/contextutil" "github.com/milvus-io/milvus/pkg/util/crypto" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/paramtable" diff --git a/internal/querynodev2/segments/collection.go b/internal/querynodev2/segments/collection.go index 5a5679c0c9d4d..01515355186d4 100644 --- a/internal/querynodev2/segments/collection.go +++ b/internal/querynodev2/segments/collection.go @@ -25,6 +25,7 @@ package segments import "C" import ( + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "sync" "unsafe" @@ -40,7 +41,6 @@ import ( "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) diff --git a/internal/querynodev2/segments/index_attr_cache.go b/internal/querynodev2/segments/index_attr_cache.go index 393fdc6ce6f20..e16d5e39d0119 100644 --- a/internal/querynodev2/segments/index_attr_cache.go +++ b/internal/querynodev2/segments/index_attr_cache.go @@ -25,6 +25,7 @@ import "C" import ( "fmt" + indexparamcheck2 "github.com/milvus-io/milvus/internal/util/indexparamcheck" "unsafe" "github.com/milvus-io/milvus/internal/proto/datapb" @@ -32,7 +33,6 @@ import ( "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/conc" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -60,12 +60,12 @@ func (c *IndexAttrCache) GetIndexResourceUsage(indexInfo *querypb.FieldIndexInfo if err != nil { return 0, 0, fmt.Errorf("index type not exist in index params") } - if indexparamcheck.GetVecIndexMgrInstance().IsDiskANN(indexType) { + if indexparamcheck2.GetVecIndexMgrInstance().IsDiskANN(indexType) { neededMemSize := indexInfo.IndexSize / UsedDiskMemoryRatio neededDiskSize := indexInfo.IndexSize - neededMemSize return uint64(neededMemSize), uint64(neededDiskSize), nil } - if indexType == indexparamcheck.IndexINVERTED { + if indexType == indexparamcheck2.IndexINVERTED { neededMemSize := 0 // we will mmap the binlog if the index type is inverted index. neededDiskSize := indexInfo.IndexSize + getBinlogDataDiskSize(fieldBinlog) diff --git a/internal/querynodev2/segments/index_attr_cache_test.go b/internal/querynodev2/segments/index_attr_cache_test.go index 55d3f705bfb93..5320878a5b63d 100644 --- a/internal/querynodev2/segments/index_attr_cache_test.go +++ b/internal/querynodev2/segments/index_attr_cache_test.go @@ -17,6 +17,7 @@ package segments import ( + indexparamcheck2 "github.com/milvus-io/milvus/internal/util/indexparamcheck" "testing" "github.com/stretchr/testify/suite" @@ -81,7 +82,7 @@ func (s *IndexAttrCacheSuite) TestDiskANN() { func (s *IndexAttrCacheSuite) TestInvertedIndex() { info := &querypb.FieldIndexInfo{ IndexParams: []*commonpb.KeyValuePair{ - {Key: common.IndexTypeKey, Value: indexparamcheck.IndexINVERTED}, + {Key: common.IndexTypeKey, Value: indexparamcheck2.IndexINVERTED}, }, CurrentIndexVersion: 0, IndexSize: 50, diff --git a/internal/querynodev2/segments/segment.go b/internal/querynodev2/segments/segment.go index d342c63697bf5..256619984ef2b 100644 --- a/internal/querynodev2/segments/segment.go +++ b/internal/querynodev2/segments/segment.go @@ -29,6 +29,7 @@ import "C" import ( "context" "fmt" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "runtime" "strings" "unsafe" @@ -55,7 +56,6 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/indexparams" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metautil" diff --git a/internal/querynodev2/segments/segment_loader_test.go b/internal/querynodev2/segments/segment_loader_test.go index ebb282d50e795..59f72fd1351f7 100644 --- a/internal/querynodev2/segments/segment_loader_test.go +++ b/internal/querynodev2/segments/segment_loader_test.go @@ -19,6 +19,7 @@ package segments import ( "context" "fmt" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "math/rand" "testing" "time" @@ -37,7 +38,6 @@ import ( "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/contextutil" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/paramtable" diff --git a/internal/querynodev2/segments/utils.go b/internal/querynodev2/segments/utils.go index 14ddedd2d844e..3765778f853b8 100644 --- a/internal/querynodev2/segments/utils.go +++ b/internal/querynodev2/segments/utils.go @@ -16,6 +16,7 @@ import ( "context" "encoding/binary" "fmt" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "io" "strconv" "time" @@ -33,7 +34,6 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/contextutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" diff --git a/internal/querynodev2/segments/utils_test.go b/internal/querynodev2/segments/utils_test.go index 51d8733ca6114..ca33e4ee89b48 100644 --- a/internal/querynodev2/segments/utils_test.go +++ b/internal/querynodev2/segments/utils_test.go @@ -1,6 +1,7 @@ package segments import ( + indexparamcheck2 "github.com/milvus-io/milvus/internal/util/indexparamcheck" "testing" "github.com/stretchr/testify/assert" @@ -147,7 +148,7 @@ func TestIsIndexMmapEnable(t *testing.T) { IndexParams: []*commonpb.KeyValuePair{ { Key: common.IndexTypeKey, - Value: indexparamcheck.IndexINVERTED, + Value: indexparamcheck2.IndexINVERTED, }, }, }) diff --git a/pkg/util/indexparamcheck/auto_index_checker.go b/internal/util/indexparamcheck/auto_index_checker.go similarity index 59% rename from pkg/util/indexparamcheck/auto_index_checker.go rename to internal/util/indexparamcheck/auto_index_checker.go index cc83f196d2e0c..f56a2887b17ae 100644 --- a/pkg/util/indexparamcheck/auto_index_checker.go +++ b/internal/util/indexparamcheck/auto_index_checker.go @@ -9,11 +9,11 @@ type AUTOINDEXChecker struct { baseChecker } -func (c *AUTOINDEXChecker) CheckTrain(params map[string]string) error { +func (c *AUTOINDEXChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { return nil } -func (c *AUTOINDEXChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c *AUTOINDEXChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { return nil } diff --git a/pkg/util/indexparamcheck/base_checker.go b/internal/util/indexparamcheck/base_checker.go similarity index 67% rename from pkg/util/indexparamcheck/base_checker.go rename to internal/util/indexparamcheck/base_checker.go index 6ea600ba4003d..5f44ea513a20e 100644 --- a/pkg/util/indexparamcheck/base_checker.go +++ b/internal/util/indexparamcheck/base_checker.go @@ -18,30 +18,17 @@ package indexparamcheck import ( "fmt" - "math" - "strings" - "github.com/cockroachdb/errors" + "github.com/milvus-io/milvus/pkg/util/typeutil" + "math" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/pkg/common" ) type baseChecker struct{} -func (c baseChecker) CheckTrain(params map[string]string) error { - // vector dimension should be checked on collection creation. this is just some basic check - isSparse := false - if val, exist := params[common.IsSparseKey]; exist { - val = strings.ToLower(val) - if val != "true" && val != "false" { - return fmt.Errorf("invalid is_sparse value: %s, must be true or false", val) - } - if val == "true" { - isSparse = true - } - } - if isSparse { +func (c baseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if typeutil.IsSparseFloatVectorType(dataType) { if !CheckStrByValues(params, Metric, SparseMetrics) { return fmt.Errorf("metric type not found or not supported for sparse float vectors, supported: %v", SparseMetrics) } @@ -55,13 +42,13 @@ func (c baseChecker) CheckTrain(params map[string]string) error { } // CheckValidDataType check whether the field data type is supported for the index type -func (c baseChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c baseChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { return nil } -func (c baseChecker) SetDefaultMetricTypeIfNotExist(m map[string]string, dType schemapb.DataType) {} +func (c baseChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, m map[string]string) {} -func (c baseChecker) StaticCheck(params map[string]string) error { +func (c baseChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { return errors.New("unsupported index type") } diff --git a/pkg/util/indexparamcheck/base_checker_test.go b/internal/util/indexparamcheck/base_checker_test.go similarity index 93% rename from pkg/util/indexparamcheck/base_checker_test.go rename to internal/util/indexparamcheck/base_checker_test.go index 59a0969d18d4d..c9ceea90af18d 100644 --- a/pkg/util/indexparamcheck/base_checker_test.go +++ b/internal/util/indexparamcheck/base_checker_test.go @@ -44,7 +44,7 @@ func Test_baseChecker_CheckTrain(t *testing.T) { c := newBaseChecker() for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -115,7 +115,7 @@ func Test_baseChecker_CheckValidDataType(t *testing.T) { c := newBaseChecker() for _, test := range cases { fieldSchema := &schemapb.FieldSchema{DataType: test.dType} - err := c.CheckValidDataType(fieldSchema) + err := c.CheckValidDataType("FLAT", fieldSchema) if test.errIsNil { assert.NoError(t, err) } else { @@ -126,5 +126,5 @@ func Test_baseChecker_CheckValidDataType(t *testing.T) { func Test_baseChecker_StaticCheck(t *testing.T) { // TODO - assert.Error(t, newBaseChecker().StaticCheck(nil)) + assert.Error(t, newBaseChecker().StaticCheck(schemapb.DataType_FloatVector, nil)) } diff --git a/internal/util/indexparamcheck/bin_flat_checker.go b/internal/util/indexparamcheck/bin_flat_checker.go new file mode 100644 index 0000000000000..647e3827ff012 --- /dev/null +++ b/internal/util/indexparamcheck/bin_flat_checker.go @@ -0,0 +1,19 @@ +package indexparamcheck + +import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + +type binFlatChecker struct { + binaryVectorBaseChecker +} + +func (c binFlatChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + return c.binaryVectorBaseChecker.CheckTrain(0, params) +} + +func (c binFlatChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { + return c.staticCheck(params) +} + +func newBinFlatChecker() IndexChecker { + return &binFlatChecker{} +} diff --git a/pkg/util/indexparamcheck/bin_flat_checker_test.go b/internal/util/indexparamcheck/bin_flat_checker_test.go similarity index 96% rename from pkg/util/indexparamcheck/bin_flat_checker_test.go rename to internal/util/indexparamcheck/bin_flat_checker_test.go index 9cf4f39394515..c482a05bc5800 100644 --- a/pkg/util/indexparamcheck/bin_flat_checker_test.go +++ b/internal/util/indexparamcheck/bin_flat_checker_test.go @@ -66,7 +66,7 @@ func Test_binFlatChecker_CheckTrain(t *testing.T) { c := newBinFlatChecker() for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -137,7 +137,7 @@ func Test_binFlatChecker_CheckValidDataType(t *testing.T) { c := newBinFlatChecker() for _, test := range cases { fieldSchema := &schemapb.FieldSchema{DataType: test.dType} - err := c.CheckValidDataType(fieldSchema) + err := c.CheckValidDataType("BINFLAT", fieldSchema) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/bin_ivf_flat_checker.go b/internal/util/indexparamcheck/bin_ivf_flat_checker.go similarity index 55% rename from pkg/util/indexparamcheck/bin_ivf_flat_checker.go rename to internal/util/indexparamcheck/bin_ivf_flat_checker.go index c36bc41c1c32e..b53e6ea1c5b0f 100644 --- a/pkg/util/indexparamcheck/bin_ivf_flat_checker.go +++ b/internal/util/indexparamcheck/bin_ivf_flat_checker.go @@ -2,13 +2,14 @@ package indexparamcheck import ( "fmt" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) type binIVFFlatChecker struct { binaryVectorBaseChecker } -func (c binIVFFlatChecker) StaticCheck(params map[string]string) error { +func (c binIVFFlatChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { if !CheckStrByValues(params, Metric, BinIvfMetrics) { return fmt.Errorf("metric type %s not found or not supported, supported: %v", params[Metric], BinIvfMetrics) } @@ -20,12 +21,12 @@ func (c binIVFFlatChecker) StaticCheck(params map[string]string) error { return nil } -func (c binIVFFlatChecker) CheckTrain(params map[string]string) error { - if err := c.binaryVectorBaseChecker.CheckTrain(params); err != nil { +func (c binIVFFlatChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.binaryVectorBaseChecker.CheckTrain(0, params); err != nil { return err } - return c.StaticCheck(params) + return c.StaticCheck(schemapb.DataType_BinaryVector, params) } func newBinIVFFlatChecker() IndexChecker { diff --git a/pkg/util/indexparamcheck/bin_ivf_flat_checker_test.go b/internal/util/indexparamcheck/bin_ivf_flat_checker_test.go similarity index 97% rename from pkg/util/indexparamcheck/bin_ivf_flat_checker_test.go rename to internal/util/indexparamcheck/bin_ivf_flat_checker_test.go index 77bda3bb016b1..ab41fed98b00c 100644 --- a/pkg/util/indexparamcheck/bin_ivf_flat_checker_test.go +++ b/internal/util/indexparamcheck/bin_ivf_flat_checker_test.go @@ -117,7 +117,7 @@ func Test_binIVFFlatChecker_CheckTrain(t *testing.T) { c := newBinIVFFlatChecker() for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -188,7 +188,7 @@ func Test_binIVFFlatChecker_CheckValidDataType(t *testing.T) { c := newBinIVFFlatChecker() for _, test := range cases { fieldSchema := &schemapb.FieldSchema{DataType: test.dType} - err := c.CheckValidDataType(fieldSchema) + err := c.CheckValidDataType("", fieldSchema) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/binary_vector_base_checker.go b/internal/util/indexparamcheck/binary_vector_base_checker.go similarity index 72% rename from pkg/util/indexparamcheck/binary_vector_base_checker.go rename to internal/util/indexparamcheck/binary_vector_base_checker.go index e73bd8b62e40a..d7b1ca891ae30 100644 --- a/pkg/util/indexparamcheck/binary_vector_base_checker.go +++ b/internal/util/indexparamcheck/binary_vector_base_checker.go @@ -19,22 +19,22 @@ func (c binaryVectorBaseChecker) staticCheck(params map[string]string) error { return nil } -func (c binaryVectorBaseChecker) CheckTrain(params map[string]string) error { - if err := c.baseChecker.CheckTrain(params); err != nil { +func (c binaryVectorBaseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.baseChecker.CheckTrain(0, params); err != nil { return err } return c.staticCheck(params) } -func (c binaryVectorBaseChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c binaryVectorBaseChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { if field.GetDataType() != schemapb.DataType_BinaryVector { return fmt.Errorf("binary vector is only supported") } return nil } -func (c binaryVectorBaseChecker) SetDefaultMetricTypeIfNotExist(params map[string]string, dType schemapb.DataType) { +func (c binaryVectorBaseChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) { setDefaultIfNotExist(params, common.MetricTypeKey, BinaryVectorDefaultMetricType) } diff --git a/pkg/util/indexparamcheck/binary_vector_base_checker_test.go b/internal/util/indexparamcheck/binary_vector_base_checker_test.go similarity index 96% rename from pkg/util/indexparamcheck/binary_vector_base_checker_test.go rename to internal/util/indexparamcheck/binary_vector_base_checker_test.go index b52648f79355e..fb47b7f122014 100644 --- a/pkg/util/indexparamcheck/binary_vector_base_checker_test.go +++ b/internal/util/indexparamcheck/binary_vector_base_checker_test.go @@ -70,7 +70,7 @@ func Test_binaryVectorBaseChecker_CheckValidDataType(t *testing.T) { c := newBinaryVectorBaseChecker() for _, test := range cases { fieldSchema := &schemapb.FieldSchema{DataType: test.dType} - err := c.CheckValidDataType(fieldSchema) + err := c.CheckValidDataType("", fieldSchema) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/internal/util/indexparamcheck/bitmap_checker_test.go b/internal/util/indexparamcheck/bitmap_checker_test.go new file mode 100644 index 0000000000000..3c97c14a02a8c --- /dev/null +++ b/internal/util/indexparamcheck/bitmap_checker_test.go @@ -0,0 +1,32 @@ +package indexparamcheck + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func Test_BitmapIndexChecker(t *testing.T) { + c := newBITMAPChecker() + + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Int8})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Int16})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Int32})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_String})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Bool})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int8})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int16})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_String})) + + assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) + assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Float})) + assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Double})) + assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float})) + assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double})) +} diff --git a/pkg/util/indexparamcheck/bitmap_index_checker.go b/internal/util/indexparamcheck/bitmap_index_checker.go similarity index 79% rename from pkg/util/indexparamcheck/bitmap_index_checker.go rename to internal/util/indexparamcheck/bitmap_index_checker.go index f19943a50ea93..55375d93771e0 100644 --- a/pkg/util/indexparamcheck/bitmap_index_checker.go +++ b/internal/util/indexparamcheck/bitmap_index_checker.go @@ -11,11 +11,11 @@ type BITMAPChecker struct { scalarIndexChecker } -func (c *BITMAPChecker) CheckTrain(params map[string]string) error { - return c.scalarIndexChecker.CheckTrain(params) +func (c *BITMAPChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + return c.scalarIndexChecker.CheckTrain(0, params) } -func (c *BITMAPChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c *BITMAPChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { if field.IsPrimaryKey { return fmt.Errorf("create bitmap index on primary key not supported") } diff --git a/pkg/util/indexparamcheck/cagra_checker.go b/internal/util/indexparamcheck/cagra_checker.go similarity index 84% rename from pkg/util/indexparamcheck/cagra_checker.go rename to internal/util/indexparamcheck/cagra_checker.go index 8f52a1605d775..45cd24b57c9e7 100644 --- a/pkg/util/indexparamcheck/cagra_checker.go +++ b/internal/util/indexparamcheck/cagra_checker.go @@ -2,6 +2,7 @@ package indexparamcheck import ( "fmt" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "strconv" ) @@ -10,8 +11,8 @@ type cagraChecker struct { floatVectorBaseChecker } -func (c *cagraChecker) CheckTrain(params map[string]string) error { - err := c.baseChecker.CheckTrain(params) +func (c *cagraChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + err := c.baseChecker.CheckTrain(0, params) if err != nil { return err } @@ -54,7 +55,7 @@ func (c *cagraChecker) CheckTrain(params map[string]string) error { return nil } -func (c cagraChecker) StaticCheck(params map[string]string) error { +func (c cagraChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { return c.staticCheck(params) } diff --git a/pkg/util/indexparamcheck/cagra_checker_test.go b/internal/util/indexparamcheck/cagra_checker_test.go similarity index 98% rename from pkg/util/indexparamcheck/cagra_checker_test.go rename to internal/util/indexparamcheck/cagra_checker_test.go index 23a931a12ef01..e3b058187d866 100644 --- a/pkg/util/indexparamcheck/cagra_checker_test.go +++ b/internal/util/indexparamcheck/cagra_checker_test.go @@ -103,7 +103,7 @@ func Test_cagraChecker_CheckTrain(t *testing.T) { c := newCagraChecker() for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/conf_adapter_mgr.go b/internal/util/indexparamcheck/conf_adapter_mgr.go similarity index 71% rename from pkg/util/indexparamcheck/conf_adapter_mgr.go rename to internal/util/indexparamcheck/conf_adapter_mgr.go index 2ff7320c9b3a2..45d0549e05cbd 100644 --- a/pkg/util/indexparamcheck/conf_adapter_mgr.go +++ b/internal/util/indexparamcheck/conf_adapter_mgr.go @@ -34,7 +34,10 @@ type indexCheckerMgrImpl struct { func (mgr *indexCheckerMgrImpl) GetChecker(indexType string) (IndexChecker, error) { mgr.once.Do(mgr.registerIndexChecker) - + // Unify the vector index checker + if GetVecIndexMgrInstance().IsVecIndex(indexType) { + return mgr.checkers[IndexVector], nil + } adapter, ok := mgr.checkers[indexType] if ok { return adapter, nil @@ -43,23 +46,7 @@ func (mgr *indexCheckerMgrImpl) GetChecker(indexType string) (IndexChecker, erro } func (mgr *indexCheckerMgrImpl) registerIndexChecker() { - mgr.checkers[IndexRaftIvfFlat] = newRaftIVFFlatChecker() - mgr.checkers[IndexRaftIvfPQ] = newRaftIVFPQChecker() - mgr.checkers[IndexRaftCagra] = newCagraChecker() - mgr.checkers[IndexRaftBruteForce] = newRaftBruteForceChecker() - mgr.checkers[IndexFaissIDMap] = newFlatChecker() - mgr.checkers[IndexFaissIvfFlat] = newIVFBaseChecker() - mgr.checkers[IndexFaissIvfPQ] = newIVFPQChecker() - mgr.checkers[IndexScaNN] = newScaNNChecker() - mgr.checkers[IndexFaissIvfSQ8] = newIVFSQChecker() - mgr.checkers[IndexFaissBinIDMap] = newBinFlatChecker() - mgr.checkers[IndexFaissBinIvfFlat] = newBinIVFFlatChecker() - mgr.checkers[IndexHNSW] = newHnswChecker() - mgr.checkers[IndexDISKANN] = newDiskannChecker() - mgr.checkers[IndexSparseInverted] = newSparseInvertedIndexChecker() - // WAND doesn't have more index params than sparse inverted index, thus - // using the same checker. - mgr.checkers[IndexSparseWand] = newSparseInvertedIndexChecker() + mgr.checkers[IndexVector] = newVecIndexChecker() mgr.checkers[IndexINVERTED] = newINVERTEDChecker() mgr.checkers[IndexSTLSORT] = newSTLSORTChecker() mgr.checkers["Asceneding"] = newSTLSORTChecker() diff --git a/pkg/util/indexparamcheck/conf_adapter_mgr_test.go b/internal/util/indexparamcheck/conf_adapter_mgr_test.go similarity index 78% rename from pkg/util/indexparamcheck/conf_adapter_mgr_test.go rename to internal/util/indexparamcheck/conf_adapter_mgr_test.go index 6ab9469ee501d..746f6cd09454d 100644 --- a/pkg/util/indexparamcheck/conf_adapter_mgr_test.go +++ b/internal/util/indexparamcheck/conf_adapter_mgr_test.go @@ -29,49 +29,49 @@ func Test_GetConfAdapterMgrInstance(t *testing.T) { assert.NotEqual(t, nil, err) assert.Equal(t, nil, adapter) - adapter, err = adapterMgr.GetChecker(IndexFaissIDMap) + adapter, err = adapterMgr.GetChecker("FLAT") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*flatChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissIvfFlat) + adapter, err = adapterMgr.GetChecker("IVF_FLAT") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*ivfBaseChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexScaNN) + adapter, err = adapterMgr.GetChecker("SCANN") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*scaNNChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissIvfPQ) + adapter, err = adapterMgr.GetChecker("IVF_PQ") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*ivfPQChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissIvfSQ8) + adapter, err = adapterMgr.GetChecker("IVF_SQ8") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*ivfSQChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissBinIDMap) + adapter, err = adapterMgr.GetChecker("BIN_FLAT") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*binFlatChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissBinIvfFlat) + adapter, err = adapterMgr.GetChecker("BIN_IVF_FLAT") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*binIVFFlatChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexHNSW) + adapter, err = adapterMgr.GetChecker("HNSW") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*hnswChecker) @@ -89,49 +89,49 @@ func TestConfAdapterMgrImpl_GetAdapter(t *testing.T) { assert.NotEqual(t, nil, err) assert.Equal(t, nil, adapter) - adapter, err = adapterMgr.GetChecker(IndexFaissIDMap) + adapter, err = adapterMgr.GetChecker("FLAT") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*flatChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissIvfFlat) + adapter, err = adapterMgr.GetChecker("IVF_FLAT") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*ivfBaseChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexScaNN) + adapter, err = adapterMgr.GetChecker("SCANN") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*scaNNChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissIvfPQ) + adapter, err = adapterMgr.GetChecker("IVF_PQ") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*ivfPQChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissIvfSQ8) + adapter, err = adapterMgr.GetChecker("IVF_SQ8") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*ivfSQChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissBinIDMap) + adapter, err = adapterMgr.GetChecker("BIN_FLAT") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*binFlatChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissBinIvfFlat) + adapter, err = adapterMgr.GetChecker("BIN_IVF_FLAT") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*binIVFFlatChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexHNSW) + adapter, err = adapterMgr.GetChecker("HNSW") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*hnswChecker) @@ -146,7 +146,7 @@ func TestConfAdapterMgrImpl_GetAdapter_multiple_threads(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - adapter, err := mgr.GetChecker(IndexHNSW) + adapter, err := mgr.GetChecker("HNSW") assert.NoError(t, err) assert.NotNil(t, adapter) }() diff --git a/pkg/util/indexparamcheck/constraints.go b/internal/util/indexparamcheck/constraints.go similarity index 100% rename from pkg/util/indexparamcheck/constraints.go rename to internal/util/indexparamcheck/constraints.go diff --git a/pkg/util/indexparamcheck/diskann_checker.go b/internal/util/indexparamcheck/diskann_checker.go similarity index 59% rename from pkg/util/indexparamcheck/diskann_checker.go rename to internal/util/indexparamcheck/diskann_checker.go index 3f2401851e961..323859b6f05a5 100644 --- a/pkg/util/indexparamcheck/diskann_checker.go +++ b/internal/util/indexparamcheck/diskann_checker.go @@ -1,11 +1,13 @@ package indexparamcheck +import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + // diskannChecker checks if an diskann index can be built. type diskannChecker struct { floatVectorBaseChecker } -func (c diskannChecker) StaticCheck(params map[string]string) error { +func (c diskannChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { return c.staticCheck(params) } diff --git a/pkg/util/indexparamcheck/diskann_checker_test.go b/internal/util/indexparamcheck/diskann_checker_test.go similarity index 95% rename from pkg/util/indexparamcheck/diskann_checker_test.go rename to internal/util/indexparamcheck/diskann_checker_test.go index 4fcfdbf019aa7..e46d2dc09bdcf 100644 --- a/pkg/util/indexparamcheck/diskann_checker_test.go +++ b/internal/util/indexparamcheck/diskann_checker_test.go @@ -74,7 +74,7 @@ func Test_diskannChecker_CheckTrain(t *testing.T) { c := newDiskannChecker() for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -144,7 +144,7 @@ func Test_diskannChecker_CheckValidDataType(t *testing.T) { c := newDiskannChecker() for _, test := range cases { - err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) + err := c.CheckValidDataType("DISKANN", &schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/flat_checker.go b/internal/util/indexparamcheck/flat_checker.go similarity index 52% rename from pkg/util/indexparamcheck/flat_checker.go rename to internal/util/indexparamcheck/flat_checker.go index d98db449206b4..8fe6d59f25ba0 100644 --- a/pkg/util/indexparamcheck/flat_checker.go +++ b/internal/util/indexparamcheck/flat_checker.go @@ -1,10 +1,12 @@ package indexparamcheck +import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + type flatChecker struct { floatVectorBaseChecker } -func (c flatChecker) StaticCheck(m map[string]string) error { +func (c flatChecker) StaticCheck(dataType schemapb.DataType, m map[string]string) error { return c.staticCheck(m) } diff --git a/pkg/util/indexparamcheck/flat_checker_test.go b/internal/util/indexparamcheck/flat_checker_test.go similarity index 91% rename from pkg/util/indexparamcheck/flat_checker_test.go rename to internal/util/indexparamcheck/flat_checker_test.go index c22432bc6f17c..74cc786b2e99d 100644 --- a/pkg/util/indexparamcheck/flat_checker_test.go +++ b/internal/util/indexparamcheck/flat_checker_test.go @@ -1,6 +1,7 @@ package indexparamcheck import ( + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "strconv" "testing" @@ -54,7 +55,7 @@ func Test_flatChecker_CheckTrain(t *testing.T) { c := newFlatChecker() for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -91,7 +92,7 @@ func Test_flatChecker_StaticCheck(t *testing.T) { c := newFlatChecker() for _, test := range cases { - err := c.StaticCheck(test.params) + err := c.StaticCheck(schemapb.DataType_FloatVector, test.params) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/float_vector_base_checker.go b/internal/util/indexparamcheck/float_vector_base_checker.go similarity index 69% rename from pkg/util/indexparamcheck/float_vector_base_checker.go rename to internal/util/indexparamcheck/float_vector_base_checker.go index 710dfb3a18a38..c8ffe4254a128 100644 --- a/pkg/util/indexparamcheck/float_vector_base_checker.go +++ b/internal/util/indexparamcheck/float_vector_base_checker.go @@ -20,22 +20,22 @@ func (c floatVectorBaseChecker) staticCheck(params map[string]string) error { return nil } -func (c floatVectorBaseChecker) CheckTrain(params map[string]string) error { - if err := c.baseChecker.CheckTrain(params); err != nil { +func (c floatVectorBaseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.baseChecker.CheckTrain(0, params); err != nil { return err } return c.staticCheck(params) } -func (c floatVectorBaseChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c floatVectorBaseChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { if !typeutil.IsDenseFloatVectorType(field.GetDataType()) { return fmt.Errorf("data type should be FloatVector, Float16Vector or BFloat16Vector") } return nil } -func (c floatVectorBaseChecker) SetDefaultMetricTypeIfNotExist(params map[string]string, dType schemapb.DataType) { +func (c floatVectorBaseChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) { setDefaultIfNotExist(params, common.MetricTypeKey, FloatVectorDefaultMetricType) } diff --git a/pkg/util/indexparamcheck/float_vector_base_checker_test.go b/internal/util/indexparamcheck/float_vector_base_checker_test.go similarity index 94% rename from pkg/util/indexparamcheck/float_vector_base_checker_test.go rename to internal/util/indexparamcheck/float_vector_base_checker_test.go index 7eb0a97d36c6c..bef2d4efbb32a 100644 --- a/pkg/util/indexparamcheck/float_vector_base_checker_test.go +++ b/internal/util/indexparamcheck/float_vector_base_checker_test.go @@ -69,7 +69,7 @@ func Test_floatVectorBaseChecker_CheckValidDataType(t *testing.T) { c := newFloatVectorBaseChecker() for _, test := range cases { - err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) + err := c.CheckValidDataType("HNSW", &schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/hnsw_checker.go b/internal/util/indexparamcheck/hnsw_checker.go similarity index 72% rename from pkg/util/indexparamcheck/hnsw_checker.go rename to internal/util/indexparamcheck/hnsw_checker.go index b5f9e1f2b77e1..78a733a0082ca 100644 --- a/pkg/util/indexparamcheck/hnsw_checker.go +++ b/internal/util/indexparamcheck/hnsw_checker.go @@ -12,7 +12,7 @@ type hnswChecker struct { baseChecker } -func (c hnswChecker) StaticCheck(params map[string]string) error { +func (c hnswChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { if !CheckIntByRange(params, EFConstruction, HNSWMinEfConstruction, HNSWMaxEfConstruction) { return errOutOfRange(EFConstruction, HNSWMinEfConstruction, HNSWMaxEfConstruction) } @@ -25,21 +25,21 @@ func (c hnswChecker) StaticCheck(params map[string]string) error { return nil } -func (c hnswChecker) CheckTrain(params map[string]string) error { - if err := c.StaticCheck(params); err != nil { +func (c hnswChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.StaticCheck(dataType, params); err != nil { return err } - return c.baseChecker.CheckTrain(params) + return c.baseChecker.CheckTrain(0, params) } -func (c hnswChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c hnswChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { if !typeutil.IsVectorType(field.GetDataType()) { return fmt.Errorf("can't build hnsw in not vector type") } return nil } -func (c hnswChecker) SetDefaultMetricTypeIfNotExist(params map[string]string, dType schemapb.DataType) { +func (c hnswChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) { if typeutil.IsDenseFloatVectorType(dType) { setDefaultIfNotExist(params, common.MetricTypeKey, FloatVectorDefaultMetricType) } else if typeutil.IsSparseFloatVectorType(dType) { diff --git a/pkg/util/indexparamcheck/hnsw_checker_test.go b/internal/util/indexparamcheck/hnsw_checker_test.go similarity index 96% rename from pkg/util/indexparamcheck/hnsw_checker_test.go rename to internal/util/indexparamcheck/hnsw_checker_test.go index b9118125407e9..21ded66bb8b74 100644 --- a/pkg/util/indexparamcheck/hnsw_checker_test.go +++ b/internal/util/indexparamcheck/hnsw_checker_test.go @@ -94,7 +94,7 @@ func Test_hnswChecker_CheckTrain(t *testing.T) { c := newHnswChecker() for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -164,7 +164,7 @@ func Test_hnswChecker_CheckValidDataType(t *testing.T) { c := newHnswChecker() for _, test := range cases { - err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) + err := c.CheckValidDataType("HNSW", &schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { @@ -207,7 +207,7 @@ func Test_hnswChecker_SetDefaultMetricType(t *testing.T) { HNSWM: strconv.Itoa(16), EFConstruction: strconv.Itoa(200), } - c.SetDefaultMetricTypeIfNotExist(p, test.dType) + c.SetDefaultMetricTypeIfNotExist(test.dType, p) assert.Equal(t, p[Metric], test.metricType) } } diff --git a/internal/util/indexparamcheck/hybrid_checker_test.go b/internal/util/indexparamcheck/hybrid_checker_test.go new file mode 100644 index 0000000000000..fdfa507ab9e02 --- /dev/null +++ b/internal/util/indexparamcheck/hybrid_checker_test.go @@ -0,0 +1,37 @@ +package indexparamcheck + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func Test_HybridIndexChecker(t *testing.T) { + c := newHYBRIDChecker() + + assert.NoError(t, c.CheckTrain(0, map[string]string{"bitmap_cardinality_limit": "100"})) + + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Int8})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Int16})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Int32})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_String})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Bool})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int8})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int16})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_String})) + + assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) + assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Float})) + assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Double})) + assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float})) + assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double})) + assert.Error(t, c.CheckTrain(0, map[string]string{})) + assert.Error(t, c.CheckTrain(0, map[string]string{"bitmap_cardinality_limit": "0"})) + assert.Error(t, c.CheckTrain(0, map[string]string{"bitmap_cardinality_limit": "2000"})) +} diff --git a/pkg/util/indexparamcheck/hybrid_index_checker.go b/internal/util/indexparamcheck/hybrid_index_checker.go similarity index 83% rename from pkg/util/indexparamcheck/hybrid_index_checker.go rename to internal/util/indexparamcheck/hybrid_index_checker.go index 9493bccd91f6d..891c14192793b 100644 --- a/pkg/util/indexparamcheck/hybrid_index_checker.go +++ b/internal/util/indexparamcheck/hybrid_index_checker.go @@ -12,15 +12,15 @@ type HYBRIDChecker struct { scalarIndexChecker } -func (c *HYBRIDChecker) CheckTrain(params map[string]string) error { +func (c *HYBRIDChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { if !CheckIntByRange(params, common.BitmapCardinalityLimitKey, 1, MaxBitmapCardinalityLimit) { return fmt.Errorf("failed to check bitmap cardinality limit, should be larger than 0 and smaller than %d", MaxBitmapCardinalityLimit) } - return c.scalarIndexChecker.CheckTrain(params) + return c.scalarIndexChecker.CheckTrain(0, params) } -func (c *HYBRIDChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c *HYBRIDChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { mainType := field.GetDataType() elemType := field.GetElementType() if !typeutil.IsBoolType(mainType) && !typeutil.IsIntegerType(mainType) && diff --git a/pkg/util/indexparamcheck/index_checker.go b/internal/util/indexparamcheck/index_checker.go similarity index 77% rename from pkg/util/indexparamcheck/index_checker.go rename to internal/util/indexparamcheck/index_checker.go index 1c11280898394..610ddffc2cd9b 100644 --- a/pkg/util/indexparamcheck/index_checker.go +++ b/internal/util/indexparamcheck/index_checker.go @@ -21,8 +21,8 @@ import ( ) type IndexChecker interface { - CheckTrain(map[string]string) error - CheckValidDataType(field *schemapb.FieldSchema) error - SetDefaultMetricTypeIfNotExist(map[string]string, schemapb.DataType) - StaticCheck(map[string]string) error + CheckTrain(schemapb.DataType, map[string]string) error + CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error + SetDefaultMetricTypeIfNotExist(schemapb.DataType, map[string]string) + StaticCheck(schemapb.DataType, map[string]string) error } diff --git a/pkg/util/indexparamcheck/index_checker_test.go b/internal/util/indexparamcheck/index_checker_test.go similarity index 100% rename from pkg/util/indexparamcheck/index_checker_test.go rename to internal/util/indexparamcheck/index_checker_test.go diff --git a/pkg/util/indexparamcheck/index_type.go b/internal/util/indexparamcheck/index_type.go similarity index 78% rename from pkg/util/indexparamcheck/index_type.go rename to internal/util/indexparamcheck/index_type.go index d0d57c35d9980..e487970c7236f 100644 --- a/pkg/util/indexparamcheck/index_type.go +++ b/internal/util/indexparamcheck/index_type.go @@ -23,23 +23,7 @@ type IndexType = string // IndexType definitions const ( - // vector index - IndexGpuBF IndexType = "GPU_BRUTE_FORCE" - IndexRaftIvfFlat IndexType = "GPU_IVF_FLAT" - IndexRaftIvfPQ IndexType = "GPU_IVF_PQ" - IndexRaftCagra IndexType = "GPU_CAGRA" - IndexRaftBruteForce IndexType = "GPU_BRUTE_FORCE" - IndexFaissIDMap IndexType = "FLAT" // no index is built. - IndexFaissIvfFlat IndexType = "IVF_FLAT" - IndexFaissIvfPQ IndexType = "IVF_PQ" - IndexScaNN IndexType = "SCANN" - IndexFaissIvfSQ8 IndexType = "IVF_SQ8" - IndexFaissBinIDMap IndexType = "BIN_FLAT" - IndexFaissBinIvfFlat IndexType = "BIN_IVF_FLAT" - IndexHNSW IndexType = "HNSW" - IndexDISKANN IndexType = "DISKANN" - IndexSparseInverted IndexType = "SPARSE_INVERTED_INDEX" - IndexSparseWand IndexType = "SPARSE_WAND" + IndexVector IndexType = "VECINDEX" // scalar index IndexSTLSORT IndexType = "STL_SORT" diff --git a/pkg/util/indexparamcheck/index_type_test.go b/internal/util/indexparamcheck/index_type_test.go similarity index 100% rename from pkg/util/indexparamcheck/index_type_test.go rename to internal/util/indexparamcheck/index_type_test.go diff --git a/pkg/util/indexparamcheck/inverted_checker.go b/internal/util/indexparamcheck/inverted_checker.go similarity index 70% rename from pkg/util/indexparamcheck/inverted_checker.go rename to internal/util/indexparamcheck/inverted_checker.go index 8d6893c10085a..85d904825f3e8 100644 --- a/pkg/util/indexparamcheck/inverted_checker.go +++ b/internal/util/indexparamcheck/inverted_checker.go @@ -12,11 +12,11 @@ type INVERTEDChecker struct { scalarIndexChecker } -func (c *INVERTEDChecker) CheckTrain(params map[string]string) error { - return c.scalarIndexChecker.CheckTrain(params) +func (c *INVERTEDChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + return c.scalarIndexChecker.CheckTrain(0, params) } -func (c *INVERTEDChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c *INVERTEDChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { dType := field.GetDataType() if !typeutil.IsBoolType(dType) && !typeutil.IsArithmetic(dType) && !typeutil.IsStringType(dType) && !typeutil.IsArrayType(dType) { diff --git a/internal/util/indexparamcheck/inverted_checker_test.go b/internal/util/indexparamcheck/inverted_checker_test.go new file mode 100644 index 0000000000000..52659f129457f --- /dev/null +++ b/internal/util/indexparamcheck/inverted_checker_test.go @@ -0,0 +1,25 @@ +package indexparamcheck + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func Test_INVERTEDIndexChecker(t *testing.T) { + c := newINVERTEDChecker() + + assert.NoError(t, c.CheckTrain(0, map[string]string{})) + + assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar})) + assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_String})) + assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) + assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) + assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_Float})) + assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_Array})) + + assert.Error(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) + assert.Error(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector})) +} diff --git a/internal/util/indexparamcheck/ivf_base_checker.go b/internal/util/indexparamcheck/ivf_base_checker.go new file mode 100644 index 0000000000000..93b0255c35dd5 --- /dev/null +++ b/internal/util/indexparamcheck/ivf_base_checker.go @@ -0,0 +1,28 @@ +package indexparamcheck + +import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + +type ivfBaseChecker struct { + floatVectorBaseChecker +} + +func (c ivfBaseChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { + if !CheckIntByRange(params, NLIST, MinNList, MaxNList) { + return errOutOfRange(NLIST, MinNList, MaxNList) + } + + // skip check number of rows + + return c.floatVectorBaseChecker.staticCheck(params) +} + +func (c ivfBaseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.StaticCheck(dataType, params); err != nil { + return err + } + return c.floatVectorBaseChecker.CheckTrain(0, params) +} + +func newIVFBaseChecker() IndexChecker { + return &ivfBaseChecker{} +} diff --git a/pkg/util/indexparamcheck/ivf_base_checker_test.go b/internal/util/indexparamcheck/ivf_base_checker_test.go similarity index 95% rename from pkg/util/indexparamcheck/ivf_base_checker_test.go rename to internal/util/indexparamcheck/ivf_base_checker_test.go index 4a379038dde33..62e54a010f0b2 100644 --- a/pkg/util/indexparamcheck/ivf_base_checker_test.go +++ b/internal/util/indexparamcheck/ivf_base_checker_test.go @@ -72,7 +72,7 @@ func Test_ivfBaseChecker_CheckTrain(t *testing.T) { c := newIVFBaseChecker() for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -142,7 +142,7 @@ func Test_ivfBaseChecker_CheckValidDataType(t *testing.T) { c := newIVFBaseChecker() for _, test := range cases { - err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) + err := c.CheckValidDataType("IVF_FLAT", &schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/ivf_pq_checker.go b/internal/util/indexparamcheck/ivf_pq_checker.go similarity index 87% rename from pkg/util/indexparamcheck/ivf_pq_checker.go rename to internal/util/indexparamcheck/ivf_pq_checker.go index 4c35f193c4689..b0ad8192911a7 100644 --- a/pkg/util/indexparamcheck/ivf_pq_checker.go +++ b/internal/util/indexparamcheck/ivf_pq_checker.go @@ -2,6 +2,7 @@ package indexparamcheck import ( "fmt" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "strconv" ) @@ -11,8 +12,8 @@ type ivfPQChecker struct { } // CheckTrain checks if ivf-pq index can be built with the specific index parameters. -func (c *ivfPQChecker) CheckTrain(params map[string]string) error { - if err := c.ivfBaseChecker.CheckTrain(params); err != nil { +func (c *ivfPQChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.ivfBaseChecker.CheckTrain(0, params); err != nil { return err } diff --git a/pkg/util/indexparamcheck/ivf_pq_checker_test.go b/internal/util/indexparamcheck/ivf_pq_checker_test.go similarity index 97% rename from pkg/util/indexparamcheck/ivf_pq_checker_test.go rename to internal/util/indexparamcheck/ivf_pq_checker_test.go index 4a22d45542b20..ff95c3fb5cddf 100644 --- a/pkg/util/indexparamcheck/ivf_pq_checker_test.go +++ b/internal/util/indexparamcheck/ivf_pq_checker_test.go @@ -143,7 +143,7 @@ func Test_ivfPQChecker_CheckTrain(t *testing.T) { c := newIVFPQChecker() for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -213,7 +213,7 @@ func Test_ivfPQChecker_CheckValidDataType(t *testing.T) { c := newIVFPQChecker() for _, test := range cases { - err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) + err := c.CheckValidDataType("IVF_PQ", &schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/ivf_sq_checker.go b/internal/util/indexparamcheck/ivf_sq_checker.go similarity index 78% rename from pkg/util/indexparamcheck/ivf_sq_checker.go rename to internal/util/indexparamcheck/ivf_sq_checker.go index fc1a2204f5562..9dce89b1c06e7 100644 --- a/pkg/util/indexparamcheck/ivf_sq_checker.go +++ b/internal/util/indexparamcheck/ivf_sq_checker.go @@ -2,6 +2,7 @@ package indexparamcheck import ( "fmt" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) // ivfSQChecker checks if a IVF_SQ index can be built. @@ -22,11 +23,11 @@ func (c *ivfSQChecker) checkNBits(params map[string]string) error { } // CheckTrain returns true if the index can be built with the specific index parameters. -func (c *ivfSQChecker) CheckTrain(params map[string]string) error { +func (c *ivfSQChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { if err := c.checkNBits(params); err != nil { return err } - return c.ivfBaseChecker.CheckTrain(params) + return c.ivfBaseChecker.CheckTrain(0, params) } func newIVFSQChecker() IndexChecker { diff --git a/pkg/util/indexparamcheck/ivf_sq_checker_test.go b/internal/util/indexparamcheck/ivf_sq_checker_test.go similarity index 96% rename from pkg/util/indexparamcheck/ivf_sq_checker_test.go rename to internal/util/indexparamcheck/ivf_sq_checker_test.go index 9478623fe89e3..e9dfc5d14c8a0 100644 --- a/pkg/util/indexparamcheck/ivf_sq_checker_test.go +++ b/internal/util/indexparamcheck/ivf_sq_checker_test.go @@ -92,7 +92,7 @@ func Test_ivfSQChecker_CheckTrain(t *testing.T) { c := newIVFSQChecker() for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -162,7 +162,7 @@ func Test_ivfSQChecker_CheckValidDataType(t *testing.T) { c := newIVFSQChecker() for _, test := range cases { - err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) + err := c.CheckValidDataType("IVF_SQ8", &schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/raft_brute_force_checker.go b/internal/util/indexparamcheck/raft_brute_force_checker.go similarity index 62% rename from pkg/util/indexparamcheck/raft_brute_force_checker.go rename to internal/util/indexparamcheck/raft_brute_force_checker.go index 38872da7ec773..50444e243a4d4 100644 --- a/pkg/util/indexparamcheck/raft_brute_force_checker.go +++ b/internal/util/indexparamcheck/raft_brute_force_checker.go @@ -1,14 +1,17 @@ package indexparamcheck -import "fmt" +import ( + "fmt" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) type raftBruteForceChecker struct { floatVectorBaseChecker } // raftBrustForceChecker checks if a Brute_Force index can be built. -func (c raftBruteForceChecker) CheckTrain(params map[string]string) error { - if err := c.floatVectorBaseChecker.CheckTrain(params); err != nil { +func (c raftBruteForceChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.floatVectorBaseChecker.CheckTrain(0, params); err != nil { return err } if !CheckStrByValues(params, Metric, RaftMetrics) { diff --git a/pkg/util/indexparamcheck/raft_brute_force_checker_test.go b/internal/util/indexparamcheck/raft_brute_force_checker_test.go similarity index 96% rename from pkg/util/indexparamcheck/raft_brute_force_checker_test.go rename to internal/util/indexparamcheck/raft_brute_force_checker_test.go index ce037bc4dcb9c..64b4d8275bd60 100644 --- a/pkg/util/indexparamcheck/raft_brute_force_checker_test.go +++ b/internal/util/indexparamcheck/raft_brute_force_checker_test.go @@ -54,7 +54,7 @@ func Test_raftbfChecker_CheckTrain(t *testing.T) { c := newRaftBruteForceChecker() for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/raft_ivf_flat_checker.go b/internal/util/indexparamcheck/raft_ivf_flat_checker.go similarity index 75% rename from pkg/util/indexparamcheck/raft_ivf_flat_checker.go rename to internal/util/indexparamcheck/raft_ivf_flat_checker.go index 9f11803e9b17d..36e2a9bfc7b2b 100644 --- a/pkg/util/indexparamcheck/raft_ivf_flat_checker.go +++ b/internal/util/indexparamcheck/raft_ivf_flat_checker.go @@ -1,6 +1,9 @@ package indexparamcheck -import "fmt" +import ( + "fmt" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) // raftIVFChecker checks if a RAFT_IVF_Flat index can be built. type raftIVFFlatChecker struct { @@ -8,8 +11,8 @@ type raftIVFFlatChecker struct { } // CheckTrain checks if ivf-flat index can be built with the specific index parameters. -func (c *raftIVFFlatChecker) CheckTrain(params map[string]string) error { - if err := c.ivfBaseChecker.CheckTrain(params); err != nil { +func (c *raftIVFFlatChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.ivfBaseChecker.CheckTrain(0, params); err != nil { return err } if !CheckStrByValues(params, Metric, RaftMetrics) { diff --git a/pkg/util/indexparamcheck/raft_ivf_flat_checker_test.go b/internal/util/indexparamcheck/raft_ivf_flat_checker_test.go similarity index 96% rename from pkg/util/indexparamcheck/raft_ivf_flat_checker_test.go rename to internal/util/indexparamcheck/raft_ivf_flat_checker_test.go index 3d64f830392f4..bafec248a5b9c 100644 --- a/pkg/util/indexparamcheck/raft_ivf_flat_checker_test.go +++ b/internal/util/indexparamcheck/raft_ivf_flat_checker_test.go @@ -86,7 +86,7 @@ func Test_raftIvfFlatChecker_CheckTrain(t *testing.T) { c := newRaftIVFFlatChecker() for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -156,7 +156,7 @@ func Test_raftIvfFlatChecker_CheckValidDataType(t *testing.T) { c := newRaftIVFFlatChecker() for _, test := range cases { - err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) + err := c.CheckValidDataType("GPU_IVF_FLAT", &schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/raft_ivf_pq_checker.go b/internal/util/indexparamcheck/raft_ivf_pq_checker.go similarity index 88% rename from pkg/util/indexparamcheck/raft_ivf_pq_checker.go rename to internal/util/indexparamcheck/raft_ivf_pq_checker.go index 2457619118070..0b510093253a2 100644 --- a/pkg/util/indexparamcheck/raft_ivf_pq_checker.go +++ b/internal/util/indexparamcheck/raft_ivf_pq_checker.go @@ -2,6 +2,7 @@ package indexparamcheck import ( "fmt" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "strconv" ) @@ -11,8 +12,8 @@ type raftIVFPQChecker struct { } // CheckTrain checks if ivf-pq index can be built with the specific index parameters. -func (c *raftIVFPQChecker) CheckTrain(params map[string]string) error { - if err := c.ivfBaseChecker.CheckTrain(params); err != nil { +func (c *raftIVFPQChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.ivfBaseChecker.CheckTrain(0, params); err != nil { return err } if !CheckStrByValues(params, Metric, RaftMetrics) { diff --git a/pkg/util/indexparamcheck/raft_ivf_pq_checker_test.go b/internal/util/indexparamcheck/raft_ivf_pq_checker_test.go similarity index 97% rename from pkg/util/indexparamcheck/raft_ivf_pq_checker_test.go rename to internal/util/indexparamcheck/raft_ivf_pq_checker_test.go index 8c882900e9ef1..04c63ae2cd988 100644 --- a/pkg/util/indexparamcheck/raft_ivf_pq_checker_test.go +++ b/internal/util/indexparamcheck/raft_ivf_pq_checker_test.go @@ -146,7 +146,7 @@ func Test_raftIVFPQChecker_CheckTrain(t *testing.T) { c := newRaftIVFPQChecker() for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -216,7 +216,7 @@ func Test_raftIVFPQChecker_CheckValidDataType(t *testing.T) { c := newRaftIVFPQChecker() for _, test := range cases { - err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) + err := c.CheckValidDataType("GPU_IVF_PQ", &schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/internal/util/indexparamcheck/scalar_index_checker.go b/internal/util/indexparamcheck/scalar_index_checker.go new file mode 100644 index 0000000000000..a1272ae3880bc --- /dev/null +++ b/internal/util/indexparamcheck/scalar_index_checker.go @@ -0,0 +1,11 @@ +package indexparamcheck + +import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + +type scalarIndexChecker struct { + baseChecker +} + +func (c scalarIndexChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + return nil +} diff --git a/pkg/util/indexparamcheck/scalar_index_checker_test.go b/internal/util/indexparamcheck/scalar_index_checker_test.go similarity index 70% rename from pkg/util/indexparamcheck/scalar_index_checker_test.go rename to internal/util/indexparamcheck/scalar_index_checker_test.go index eb3ae669e2891..423ddb1f66984 100644 --- a/pkg/util/indexparamcheck/scalar_index_checker_test.go +++ b/internal/util/indexparamcheck/scalar_index_checker_test.go @@ -8,5 +8,5 @@ import ( func TestCheckIndexValid(t *testing.T) { scalarIndexChecker := &scalarIndexChecker{} - assert.NoError(t, scalarIndexChecker.CheckTrain(map[string]string{})) + assert.NoError(t, scalarIndexChecker.CheckTrain(0, map[string]string{})) } diff --git a/pkg/util/indexparamcheck/scann_checker.go b/internal/util/indexparamcheck/scann_checker.go similarity index 78% rename from pkg/util/indexparamcheck/scann_checker.go rename to internal/util/indexparamcheck/scann_checker.go index eecf2ded64bbf..418a647c09d48 100644 --- a/pkg/util/indexparamcheck/scann_checker.go +++ b/internal/util/indexparamcheck/scann_checker.go @@ -2,6 +2,7 @@ package indexparamcheck import ( "fmt" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "strconv" ) @@ -11,8 +12,8 @@ type scaNNChecker struct { } // CheckTrain checks if SCANN index can be built with the specific index parameters. -func (c *scaNNChecker) CheckTrain(params map[string]string) error { - if err := c.ivfBaseChecker.CheckTrain(params); err != nil { +func (c *scaNNChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.ivfBaseChecker.CheckTrain(0, params); err != nil { return err } diff --git a/pkg/util/indexparamcheck/scann_checker_test.go b/internal/util/indexparamcheck/scann_checker_test.go similarity index 96% rename from pkg/util/indexparamcheck/scann_checker_test.go rename to internal/util/indexparamcheck/scann_checker_test.go index 4f7014c6fde53..13a54597b0716 100644 --- a/pkg/util/indexparamcheck/scann_checker_test.go +++ b/internal/util/indexparamcheck/scann_checker_test.go @@ -89,7 +89,7 @@ func Test_scaNNChecker_CheckTrain(t *testing.T) { c := newScaNNChecker() for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -159,7 +159,7 @@ func Test_scaNNChecker_CheckValidDataType(t *testing.T) { c := newScaNNChecker() for _, test := range cases { - err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) + err := c.CheckValidDataType("SCANN", &schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/sparse_float_vector_base_checker.go b/internal/util/indexparamcheck/sparse_float_vector_base_checker.go similarity index 75% rename from pkg/util/indexparamcheck/sparse_float_vector_base_checker.go rename to internal/util/indexparamcheck/sparse_float_vector_base_checker.go index 218d2d3e03a3e..edb91157fa1e2 100644 --- a/pkg/util/indexparamcheck/sparse_float_vector_base_checker.go +++ b/internal/util/indexparamcheck/sparse_float_vector_base_checker.go @@ -12,7 +12,7 @@ import ( // sparse vector don't check for dim, but baseChecker does, thus not including baseChecker type sparseFloatVectorBaseChecker struct{} -func (c sparseFloatVectorBaseChecker) StaticCheck(params map[string]string) error { +func (c sparseFloatVectorBaseChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { if !CheckStrByValues(params, Metric, SparseMetrics) { return fmt.Errorf("metric type not found or not supported, supported: %v", SparseMetrics) } @@ -20,7 +20,7 @@ func (c sparseFloatVectorBaseChecker) StaticCheck(params map[string]string) erro return nil } -func (c sparseFloatVectorBaseChecker) CheckTrain(params map[string]string) error { +func (c sparseFloatVectorBaseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { dropRatioBuildStr, exist := params[SparseDropRatioBuild] if exist { dropRatioBuild, err := strconv.ParseFloat(dropRatioBuildStr, 64) @@ -32,14 +32,14 @@ func (c sparseFloatVectorBaseChecker) CheckTrain(params map[string]string) error return nil } -func (c sparseFloatVectorBaseChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c sparseFloatVectorBaseChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { if !typeutil.IsSparseFloatVectorType(field.GetDataType()) { return fmt.Errorf("only sparse float vector is supported for the specified index tpye") } return nil } -func (c sparseFloatVectorBaseChecker) SetDefaultMetricTypeIfNotExist(params map[string]string, dType schemapb.DataType) { +func (c sparseFloatVectorBaseChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) { setDefaultIfNotExist(params, common.MetricTypeKey, SparseFloatVectorDefaultMetricType) } diff --git a/pkg/util/indexparamcheck/sparse_inverted_index_checker.go b/internal/util/indexparamcheck/sparse_inverted_index_checker.go similarity index 100% rename from pkg/util/indexparamcheck/sparse_inverted_index_checker.go rename to internal/util/indexparamcheck/sparse_inverted_index_checker.go diff --git a/pkg/util/indexparamcheck/stl_sort_checker.go b/internal/util/indexparamcheck/stl_sort_checker.go similarity index 65% rename from pkg/util/indexparamcheck/stl_sort_checker.go rename to internal/util/indexparamcheck/stl_sort_checker.go index 4b3441ad6dfcf..6e3339c0ef5bd 100644 --- a/pkg/util/indexparamcheck/stl_sort_checker.go +++ b/internal/util/indexparamcheck/stl_sort_checker.go @@ -12,11 +12,11 @@ type STLSORTChecker struct { scalarIndexChecker } -func (c *STLSORTChecker) CheckTrain(params map[string]string) error { - return c.scalarIndexChecker.CheckTrain(params) +func (c *STLSORTChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + return c.scalarIndexChecker.CheckTrain(0, params) } -func (c *STLSORTChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c *STLSORTChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { if !typeutil.IsArithmetic(field.GetDataType()) { return fmt.Errorf("STL_SORT are only supported on numeric field") } diff --git a/internal/util/indexparamcheck/stl_sort_checker_test.go b/internal/util/indexparamcheck/stl_sort_checker_test.go new file mode 100644 index 0000000000000..3e42c2810db65 --- /dev/null +++ b/internal/util/indexparamcheck/stl_sort_checker_test.go @@ -0,0 +1,22 @@ +package indexparamcheck + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func Test_STLSORTIndexChecker(t *testing.T) { + c := newSTLSORTChecker() + + assert.NoError(t, c.CheckTrain(0, map[string]string{})) + + assert.NoError(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) + assert.NoError(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_Float})) + + assert.Error(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar})) + assert.Error(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) + assert.Error(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) +} diff --git a/pkg/util/indexparamcheck/trie_checker.go b/internal/util/indexparamcheck/trie_checker.go similarity index 64% rename from pkg/util/indexparamcheck/trie_checker.go rename to internal/util/indexparamcheck/trie_checker.go index 002014e42022c..2af351860f59f 100644 --- a/pkg/util/indexparamcheck/trie_checker.go +++ b/internal/util/indexparamcheck/trie_checker.go @@ -12,11 +12,11 @@ type TRIEChecker struct { scalarIndexChecker } -func (c *TRIEChecker) CheckTrain(params map[string]string) error { - return c.scalarIndexChecker.CheckTrain(params) +func (c *TRIEChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + return c.scalarIndexChecker.CheckTrain(0, params) } -func (c *TRIEChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c *TRIEChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { if !typeutil.IsStringType(field.GetDataType()) { return fmt.Errorf("TRIE are only supported on varchar field") } diff --git a/internal/util/indexparamcheck/trie_checker_test.go b/internal/util/indexparamcheck/trie_checker_test.go new file mode 100644 index 0000000000000..81a98b664832b --- /dev/null +++ b/internal/util/indexparamcheck/trie_checker_test.go @@ -0,0 +1,23 @@ +package indexparamcheck + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func Test_TrieIndexChecker(t *testing.T) { + c := newTRIEChecker() + + assert.NoError(t, c.CheckTrain(0, map[string]string{})) + + assert.NoError(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar})) + assert.NoError(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_String})) + + assert.Error(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) + assert.Error(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) + assert.Error(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_Float})) + assert.Error(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) +} diff --git a/pkg/util/indexparamcheck/utils.go b/internal/util/indexparamcheck/utils.go similarity index 100% rename from pkg/util/indexparamcheck/utils.go rename to internal/util/indexparamcheck/utils.go diff --git a/pkg/util/indexparamcheck/utils_test.go b/internal/util/indexparamcheck/utils_test.go similarity index 100% rename from pkg/util/indexparamcheck/utils_test.go rename to internal/util/indexparamcheck/utils_test.go diff --git a/internal/util/indexparamcheck/vector_index_checker.go b/internal/util/indexparamcheck/vector_index_checker.go new file mode 100644 index 0000000000000..1d65dc1eb1a02 --- /dev/null +++ b/internal/util/indexparamcheck/vector_index_checker.go @@ -0,0 +1,95 @@ +package indexparamcheck + +/* +#cgo pkg-config: milvus_core + +#include // free +#include "segcore/vector_index_c.h" +*/ +import "C" +import ( + "fmt" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/proto/indexcgopb" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/typeutil" + "google.golang.org/protobuf/proto" + "unsafe" +) + +type vecIndexChecker struct { + baseChecker +} + +// HandleCStatus deals with the error returned from CGO +func HandleCStatus(status *C.CStatus) error { + if status.error_code == 0 { + return nil + } + errorCode := status.error_code + errorMsg := C.GoString(status.error_msg) + defer C.free(unsafe.Pointer(status.error_msg)) + + return fmt.Errorf("code %s, msg %s", errorCode, errorMsg) +} + +func (c vecIndexChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { + indexType, exist := params[common.IndexTypeKey] + + if !exist { + return fmt.Errorf("no indexType is specified") + } + + if GetVecIndexMgrInstance().IsVecIndex(indexType) { + return fmt.Errorf("indexType %s is not supported", indexType) + } + + protoIndexParams := &indexcgopb.IndexParams{ + Params: make([]*commonpb.KeyValuePair, 0), + } + indexParamsBlob, err := proto.Marshal(protoIndexParams) + if err != nil { + return fmt.Errorf("failed to marshal index params: %s", err) + } + + var status C.CStatus + + cIndexType := C.CString(indexType) + cDataType := uint32(dataType) + status = C.ValidateIndexParams(cIndexType, cDataType, (*C.uint8_t)(unsafe.Pointer(&indexParamsBlob[0])), (C.uint64_t)(len(indexParamsBlob))) + C.free(unsafe.Pointer(cIndexType)) + + return HandleCStatus(&status) +} + +func (c vecIndexChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.StaticCheck(dataType, params); err != nil { + return err + } + return c.baseChecker.CheckTrain(0, params) +} + +func (c vecIndexChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { + if !typeutil.IsVectorType(field.GetDataType()) { + return fmt.Errorf("can not create vector index for an invalid datatype") + } + if GetVecIndexMgrInstance().IsDataTypeSupport(indexType, field.GetDataType()) { + return fmt.Errorf("vector index %s do not support data type: %s", indexType, schemapb.DataType_name[int32(field.GetDataType())]) + } + return nil +} + +func (c vecIndexChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) { + if typeutil.IsDenseFloatVectorType(dType) { + setDefaultIfNotExist(params, common.MetricTypeKey, FloatVectorDefaultMetricType) + } else if typeutil.IsSparseFloatVectorType(dType) { + setDefaultIfNotExist(params, common.MetricTypeKey, SparseFloatVectorDefaultMetricType) + } else if typeutil.IsBinaryVectorType(dType) { + setDefaultIfNotExist(params, common.MetricTypeKey, BinaryVectorDefaultMetricType) + } +} + +func newVecIndexChecker() IndexChecker { + return &vecIndexChecker{} +} diff --git a/pkg/util/indexparamcheck/vector_index_mgr.go b/internal/util/indexparamcheck/vector_index_mgr.go similarity index 69% rename from pkg/util/indexparamcheck/vector_index_mgr.go rename to internal/util/indexparamcheck/vector_index_mgr.go index c6d9e1bcc22d9..bc62c2dd586c0 100644 --- a/pkg/util/indexparamcheck/vector_index_mgr.go +++ b/internal/util/indexparamcheck/vector_index_mgr.go @@ -17,7 +17,7 @@ package indexparamcheck /* -#cgo pkg-config: milvus_segcore +#cgo pkg-config: milvus_core #include // free #include "segcore/vector_index_c.h" @@ -27,12 +27,19 @@ import "C" import ( "bytes" "fmt" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/log" "sync" "unsafe" ) const ( + BinaryFlag uint64 = 1 << 0 + Float32Flag uint64 = 1 << 1 + Float16Flag uint64 = 1 << 2 + BFloat16Flag uint64 = 1 << 3 + SparseFloat32Flag uint64 = 1 << 4 + // BFFlag This flag indicates that there is no need to create any index structure BFFlag uint64 = 1 << 16 // KNNFlag This flag indicates that the index defaults to KNN search, meaning the recall rate is 100% @@ -49,6 +56,14 @@ const ( type VecIndexMgr interface { init() error + + IsBinarySupport(indexType IndexType) bool + IsFlat32Support(indexType IndexType) bool + IsFlat16Support(indexType IndexType) bool + IsBFlat16Support(indexType IndexType) bool + IsSparseFloat32Support(indexType IndexType) bool + IsDataTypeSupport(indexType IndexType, dataType schemapb.DataType) bool + IsFlatVecIndex(indexType IndexType) bool IsBruteForce(indexType IndexType) bool IsVecIndex(indexType IndexType) bool @@ -98,6 +113,61 @@ func (mgr *VecIndexMgrImpl) init() error { return nil } +func (mgr *VecIndexMgrImpl) IsBinarySupport(indexType IndexType) bool { + feature, ok := mgr.features[indexType] + if !ok { + return false + } + return (feature & BinaryFlag) == BinaryFlag +} + +func (mgr *VecIndexMgrImpl) IsFlat32Support(indexType IndexType) bool { + feature, ok := mgr.features[indexType] + if !ok { + return false + } + return (feature & Float32Flag) == Float32Flag +} + +func (mgr *VecIndexMgrImpl) IsFlat16Support(indexType IndexType) bool { + feature, ok := mgr.features[indexType] + if !ok { + return false + } + return (feature & Float16Flag) == Float16Flag +} +func (mgr *VecIndexMgrImpl) IsBFlat16Support(indexType IndexType) bool { + feature, ok := mgr.features[indexType] + if !ok { + return false + } + return (feature & BFloat16Flag) == BFloat16Flag +} + +func (mgr *VecIndexMgrImpl) IsSparseFloat32Support(indexType IndexType) bool { + feature, ok := mgr.features[indexType] + if !ok { + return false + } + return (feature & SparseFloat32Flag) == SparseFloat32Flag +} + +func (mgr *VecIndexMgrImpl) IsDataTypeSupport(indexType IndexType, dataType schemapb.DataType) bool { + if dataType == schemapb.DataType_BinaryVector { + return mgr.IsBinarySupport(indexType) + } else if dataType == schemapb.DataType_FloatVector { + return mgr.IsFlat32Support(indexType) + } else if dataType == schemapb.DataType_BFloat16Vector { + return mgr.IsBFlat16Support(indexType) + } else if dataType == schemapb.DataType_Float16Vector { + return mgr.IsFlat16Support(indexType) + } else if dataType == schemapb.DataType_SparseFloatVector { + return mgr.IsSparseFloat32Support(indexType) + } else { + return false + } +} + func (mgr *VecIndexMgrImpl) IsFlatVecIndex(indexType IndexType) bool { feature, ok := mgr.features[indexType] if !ok { diff --git a/pkg/util/indexparamcheck/vector_index_mgr_test.go b/internal/util/indexparamcheck/vector_index_mgr_test.go similarity index 100% rename from pkg/util/indexparamcheck/vector_index_mgr_test.go rename to internal/util/indexparamcheck/vector_index_mgr_test.go diff --git a/pkg/config/config.go b/pkg/config/config.go index fc93c086f74d4..6c7ee1d8e9aa4 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -17,6 +17,9 @@ package config import ( + "github.com/milvus-io/milvus/pkg/log" + "github.com/spf13/cast" + "go.uber.org/zap" "strings" "github.com/cockroachdb/errors" @@ -30,6 +33,10 @@ var ( ErrKeyNotFound = errors.New("key not found") ) +const ( + NotFormatPrefix = "knowhere." +) + func Init(opts ...Option) (*Manager, error) { o := &Options{} for _, opt := range opts { @@ -56,6 +63,9 @@ func Init(opts ...Option) (*Manager, error) { var formattedKeys = typeutil.NewConcurrentMap[string, string]() func formatKey(key string) string { + if strings.HasPrefix(key, NotFormatPrefix) { + return key + } cached, ok := formattedKeys.Get(key) if ok { return cached @@ -64,3 +74,41 @@ func formatKey(key string) string { formattedKeys.Insert(key, result) return result } + +func parseConfig(prefix string, m map[string]interface{}, result map[string]string) { + for k, v := range m { + fullKey := k + if prefix != "" { + fullKey = prefix + "." + k + } + + switch val := v.(type) { + case map[string]interface{}: + parseConfig(fullKey, val, result) + case []interface{}: + str := "" + for i, item := range val { + itemStr, err := cast.ToStringE(item) + if err != nil { + log.Warn("cast to string failed", zap.Any("item", item)) + continue + } + if i == 0 { + str = itemStr + } else { + str = str + "," + itemStr + } + } + result[fullKey] = str + result[formatKey(fullKey)] = str + default: + str, err := cast.ToStringE(val) + if err != nil { + log.Warn("cast to string failed", zap.Any("val", val)) + continue + } + result[fullKey] = str + result[formatKey(fullKey)] = str + } + } +} diff --git a/pkg/config/file_source.go b/pkg/config/file_source.go index e8402efe6b6ad..953de511ff436 100644 --- a/pkg/config/file_source.go +++ b/pkg/config/file_source.go @@ -17,13 +17,12 @@ package config import ( + "gopkg.in/yaml.v3" "os" "sync" "github.com/cockroachdb/errors" "github.com/samber/lo" - "github.com/spf13/cast" - "github.com/spf13/viper" "go.uber.org/zap" "github.com/milvus-io/milvus/pkg/log" @@ -115,7 +114,6 @@ func (fs *FileSource) UpdateOptions(opts Options) { } func (fs *FileSource) loadFromFile() error { - yamlReader := viper.New() newConfig := make(map[string]string) var configFiles []string @@ -128,38 +126,19 @@ func (fs *FileSource) loadFromFile() error { continue } - yamlReader.SetConfigFile(configFile) - if err := yamlReader.ReadInConfig(); err != nil { + data, err := os.ReadFile(configFile) + if err != nil { return errors.Wrap(err, "Read config failed: "+configFile) } - for _, key := range yamlReader.AllKeys() { - val := yamlReader.Get(key) - str, err := cast.ToStringE(val) - if err != nil { - switch val := val.(type) { - case []any: - str = str[:0] - for _, v := range val { - ss, err := cast.ToStringE(v) - if err != nil { - log.Warn("cast to string failed", zap.Any("value", v)) - } - if str == "" { - str = ss - } else { - str = str + "," + ss - } - } - - default: - log.Warn("val is not a slice", zap.Any("value", val)) - continue - } - } - newConfig[key] = str - newConfig[formatKey(key)] = str + var config map[string]interface{} + + err = yaml.Unmarshal(data, &config) + if err != nil { + return errors.Wrap(err, "Unmarshal config failed: "+configFile) } + + parseConfig("", config, newConfig) } return fs.update(newConfig) diff --git a/pkg/util/indexparamcheck/bin_flat_checker.go b/pkg/util/indexparamcheck/bin_flat_checker.go deleted file mode 100644 index 2e0b813c38402..0000000000000 --- a/pkg/util/indexparamcheck/bin_flat_checker.go +++ /dev/null @@ -1,17 +0,0 @@ -package indexparamcheck - -type binFlatChecker struct { - binaryVectorBaseChecker -} - -func (c binFlatChecker) CheckTrain(params map[string]string) error { - return c.binaryVectorBaseChecker.CheckTrain(params) -} - -func (c binFlatChecker) StaticCheck(params map[string]string) error { - return c.staticCheck(params) -} - -func newBinFlatChecker() IndexChecker { - return &binFlatChecker{} -} diff --git a/pkg/util/indexparamcheck/hybrid_checker_test.go b/pkg/util/indexparamcheck/hybrid_checker_test.go deleted file mode 100644 index 733adc2922804..0000000000000 --- a/pkg/util/indexparamcheck/hybrid_checker_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package indexparamcheck - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" -) - -func Test_HybridIndexChecker(t *testing.T) { - c := newHYBRIDChecker() - - assert.NoError(t, c.CheckTrain(map[string]string{"bitmap_cardinality_limit": "100"})) - - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int8})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int16})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int32})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_String})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Bool})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int8})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int16})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_String})) - - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Double})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double})) - assert.Error(t, c.CheckTrain(map[string]string{})) - assert.Error(t, c.CheckTrain(map[string]string{"bitmap_cardinality_limit": "0"})) - assert.Error(t, c.CheckTrain(map[string]string{"bitmap_cardinality_limit": "2000"})) -} diff --git a/pkg/util/indexparamcheck/inverted_checker_test.go b/pkg/util/indexparamcheck/inverted_checker_test.go deleted file mode 100644 index baecd97dd1766..0000000000000 --- a/pkg/util/indexparamcheck/inverted_checker_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package indexparamcheck - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" -) - -func Test_INVERTEDIndexChecker(t *testing.T) { - c := newINVERTEDChecker() - - assert.NoError(t, c.CheckTrain(map[string]string{})) - - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_VarChar})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_String})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array})) - - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector})) -} diff --git a/pkg/util/indexparamcheck/ivf_base_checker.go b/pkg/util/indexparamcheck/ivf_base_checker.go deleted file mode 100644 index 9b8a3e2e045a0..0000000000000 --- a/pkg/util/indexparamcheck/ivf_base_checker.go +++ /dev/null @@ -1,26 +0,0 @@ -package indexparamcheck - -type ivfBaseChecker struct { - floatVectorBaseChecker -} - -func (c ivfBaseChecker) StaticCheck(params map[string]string) error { - if !CheckIntByRange(params, NLIST, MinNList, MaxNList) { - return errOutOfRange(NLIST, MinNList, MaxNList) - } - - // skip check number of rows - - return c.floatVectorBaseChecker.staticCheck(params) -} - -func (c ivfBaseChecker) CheckTrain(params map[string]string) error { - if err := c.StaticCheck(params); err != nil { - return err - } - return c.floatVectorBaseChecker.CheckTrain(params) -} - -func newIVFBaseChecker() IndexChecker { - return &ivfBaseChecker{} -} diff --git a/pkg/util/indexparamcheck/scalar_index_checker.go b/pkg/util/indexparamcheck/scalar_index_checker.go deleted file mode 100644 index 9c372f4034c10..0000000000000 --- a/pkg/util/indexparamcheck/scalar_index_checker.go +++ /dev/null @@ -1,9 +0,0 @@ -package indexparamcheck - -type scalarIndexChecker struct { - baseChecker -} - -func (c scalarIndexChecker) CheckTrain(params map[string]string) error { - return nil -} diff --git a/pkg/util/indexparamcheck/stl_sort_checker_test.go b/pkg/util/indexparamcheck/stl_sort_checker_test.go deleted file mode 100644 index 771a51cd32f68..0000000000000 --- a/pkg/util/indexparamcheck/stl_sort_checker_test.go +++ /dev/null @@ -1,22 +0,0 @@ -package indexparamcheck - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" -) - -func Test_STLSORTIndexChecker(t *testing.T) { - c := newSTLSORTChecker() - - assert.NoError(t, c.CheckTrain(map[string]string{})) - - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float})) - - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_VarChar})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) -} diff --git a/pkg/util/indexparamcheck/trie_checker_test.go b/pkg/util/indexparamcheck/trie_checker_test.go deleted file mode 100644 index 3e1eaea1c5890..0000000000000 --- a/pkg/util/indexparamcheck/trie_checker_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package indexparamcheck - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" -) - -func Test_TrieIndexChecker(t *testing.T) { - c := newTRIEChecker() - - assert.NoError(t, c.CheckTrain(map[string]string{})) - - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_VarChar})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_String})) - - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) -} diff --git a/pkg/util/paramtable/autoindex_param.go b/pkg/util/paramtable/autoindex_param.go index 31df71a4a358d..514a0fb4e93b8 100644 --- a/pkg/util/paramtable/autoindex_param.go +++ b/pkg/util/paramtable/autoindex_param.go @@ -18,12 +18,12 @@ package paramtable import ( "fmt" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/config" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" ) // ///////////////////////////////////////////////////////////////////////////// @@ -241,9 +241,9 @@ func (p *autoIndexConfig) panicIfNotValidAndSetDefaultMetricTypeHelper(key strin panic(fmt.Sprintf("%s invalid, unsupported index type: %s", key, indexType)) } - checker.SetDefaultMetricTypeIfNotExist(m, dtype) + checker.SetDefaultMetricTypeIfNotExist(dtype, m) - if err := checker.StaticCheck(m); err != nil { + if err := checker.StaticCheck(dtype, m); err != nil { panic(fmt.Sprintf("%s invalid, parameters invalid, error: %s", key, err.Error())) } diff --git a/pkg/util/paramtable/autoindex_param_test.go b/pkg/util/paramtable/autoindex_param_test.go index 231c8377e7a9d..71d174ffc73fd 100644 --- a/pkg/util/paramtable/autoindex_param_test.go +++ b/pkg/util/paramtable/autoindex_param_test.go @@ -18,6 +18,7 @@ package paramtable import ( "encoding/json" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "strconv" "testing" @@ -26,7 +27,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/config" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" ) const ( diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index b51b1e2bfbc6c..eb1e114ba4edd 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -77,6 +77,7 @@ type ComponentParam struct { DataCoordCfg dataCoordConfig DataNodeCfg dataNodeConfig IndexNodeCfg indexNodeConfig + IndexEngineConfig indexEngineConfig HTTPCfg httpConfig LogCfg logConfig RoleCfg roleConfig @@ -139,6 +140,7 @@ func (p *ComponentParam) init(bt *BaseTable) { p.GpuConfig.init(bt) p.StreamingCoordCfg.init(bt) p.StreamingNodeCfg.init(bt) + p.IndexEngineConfig.init(bt) p.RootCoordGrpcServerCfg.Init("rootCoord", bt) p.ProxyGrpcServerCfg.Init("proxy", bt) diff --git a/pkg/util/paramtable/indexengine_param.go b/pkg/util/paramtable/indexengine_param.go new file mode 100644 index 0000000000000..b1e74f062ac93 --- /dev/null +++ b/pkg/util/paramtable/indexengine_param.go @@ -0,0 +1,117 @@ +package paramtable + +import ( + "fmt" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/pkg/util/hardware" + "strconv" + "strings" +) + +type indexEngineConfig struct { + Enable ParamItem `refreshable:"true"` + IndexParam ParamGroup `refreshable:"true"` +} + +const ( + BuildStage = "build" + LoadStage = "load" + SearchStage = "search" +) + +const ( + BuildDramBudgetKey = "build_dram_budget_gb" + NumBuildThreadKey = "num_build_thread" + VecFieldSizeKey = "vec_field_size_gb" +) + +func (p *indexEngineConfig) init(base *BaseTable) { + p.IndexParam = ParamGroup{ + KeyPrefix: "knowhere.", + Version: "2.5.0", + } + p.IndexParam.Init(base.mgr) + + p.Enable = ParamItem{ + Key: "knowhere.enable", + Version: "2.5.0", + DefaultValue: "false", + } + p.Enable.Init(base.mgr) +} + +func (p *indexEngineConfig) getIndexParam(indexType string, stage string) map[string]string { + matchedParam := make(map[string]string) + + params := p.IndexParam.GetValue() + prefix := indexType + "." + stage + "." + + for k, v := range params { + if strings.HasPrefix(k, prefix) { + matchedParam[strings.TrimPrefix(k, prefix)] = v + } + } + + return matchedParam +} + +func GetKeyFromSlice(indexParams []*commonpb.KeyValuePair, key string) string { + for _, param := range indexParams { + if param.Key == key { + return param.Value + } + } + return "" +} + +func (p *indexEngineConfig) GetRuntimeParam(stage string) (map[string]string, error) { + params := make(map[string]string) + + if stage == BuildStage { + params[BuildDramBudgetKey] = fmt.Sprintf("%f", float32(hardware.GetFreeMemoryCount())/(1<<30)) + params[NumBuildThreadKey] = strconv.Itoa(int(float32(hardware.GetCPUNum()))) + } + + return params, nil +} + +func (p *indexEngineConfig) MergeRequestParam(indexType string, stage string, indexParams []*commonpb.KeyValuePair) ([]*commonpb.KeyValuePair, error) { + defaultParams := p.getIndexParam(indexType, stage) + + for key, val := range defaultParams { + if GetKeyFromSlice(indexParams, key) == "" { + indexParams = append(indexParams, + &commonpb.KeyValuePair{ + Key: key, + Value: val, + }) + } + } + + return indexParams, nil +} + +func (p *indexEngineConfig) MergeWithResource(vecFieldSize uint64, indexParam map[string]string) (map[string]string, error) { + param, _ := p.GetRuntimeParam(BuildStage) + + for key, val := range param { + indexParam[key] = val + } + + indexParam[VecFieldSizeKey] = fmt.Sprintf("%f", float32(vecFieldSize)/(1<<30)) + + return indexParam, nil +} + +func (p *indexEngineConfig) MergeRequestMapParam(indexType string, stage string, indexParam map[string]string) (map[string]string, error) { + defaultParams := p.getIndexParam(indexType, stage) + + for key, val := range defaultParams { + _, existed := indexParam[key] + if !existed { + indexParam[key] = val + } + } + + return indexParam, nil +} diff --git a/pkg/util/paramtable/indexengine_param_test.go b/pkg/util/paramtable/indexengine_param_test.go new file mode 100644 index 0000000000000..80ae21241125b --- /dev/null +++ b/pkg/util/paramtable/indexengine_param_test.go @@ -0,0 +1,20 @@ +package paramtable + +import "testing" + +func TestIndexEngineConfig_Init(t *testing.T) { + params := ComponentParam{} + params.Init(NewBaseTable(SkipRemote(true))) + + cfg := ¶ms.IndexEngineConfig + print(cfg) +} + +func TestIndexEngineConfig_Get(t *testing.T) { + params := ComponentParam{} + params.Init(NewBaseTable(SkipRemote(true))) + + cfg := ¶ms.IndexEngineConfig + diskANNbuild := cfg.getIndexParam("DISKANN", BuildStage) + print(diskANNbuild) +} diff --git a/tests/integration/import/import_test.go b/tests/integration/import/import_test.go index 0eb43b55f8e3f..663f0bae22df1 100644 --- a/tests/integration/import/import_test.go +++ b/tests/integration/import/import_test.go @@ -19,6 +19,7 @@ package importv2 import ( "context" "fmt" + indexparamcheck2 "github.com/milvus-io/milvus/internal/util/indexparamcheck" "math/rand" "os" "testing" @@ -53,7 +54,7 @@ type BulkInsertSuite struct { fileType importutilv2.FileType vecType schemapb.DataType - indexType indexparamcheck.IndexType + indexType indexparamcheck2.IndexType metricType metric.MetricType }