diff --git a/cmd/tools/migration/mmap/mmap_230_240.go b/cmd/tools/migration/mmap/mmap_230_240.go index 8994551d02d7a..3e0a1f7cd6075 100644 --- a/cmd/tools/migration/mmap/mmap_230_240.go +++ b/cmd/tools/migration/mmap/mmap_230_240.go @@ -9,8 +9,8 @@ import ( "github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/rootcoord" "github.com/milvus-io/milvus/internal/tso" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" ) // In Milvus 2.3.x, querynode.MmapDirPath is used to enable mmap and save mmap files. diff --git a/configs/milvus.yaml b/configs/milvus.yaml index f931e3cbde1dc..82997519b7ae4 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -1044,3 +1044,23 @@ streamingNode: serverMaxRecvSize: 268435456 # The maximum size of each RPC request that the streamingNode can receive, unit: byte clientMaxSendSize: 268435456 # The maximum size of each RPC request that the clients on streamingNode can send, unit: byte clientMaxRecvSize: 268435456 # The maximum size of each RPC request that the clients on streamingNode can receive, unit: byte + +knowhere: + enable: true + HNSW: + build: + efConstruction : 360 + M: 30 + search: + ef: 30 + DISKANN: + build: + max_degree: 56 + search_list_size: 100 + pq_code_budget_gb_ratio: 0.125 + search_cache_budget_gb_ratio: 0.1 + beam_width_ratio: 4 + load: + cacheRatio: 0.1 + search: + beamRatio: 4.0 \ No newline at end of file diff --git a/internal/core/src/index/Index.h b/internal/core/src/index/Index.h index 87cb5ae683cb6..bb136a10757b1 100644 --- a/internal/core/src/index/Index.h +++ b/internal/core/src/index/Index.h @@ -63,13 +63,8 @@ class IndexBase { virtual const bool HasRawData() const = 0; - bool - IsMmapSupported() const { - return knowhere::IndexFactory::Instance().FeatureCheck(index_type_, knowhere::feature::MMAP) || - // support mmap for bitmap/hybrid index - index_type_ == milvus::index::BITMAP_INDEX_TYPE || - index_type_ == milvus::index::HYBRID_INDEX_TYPE; - } + virtual bool + IsMmapSupported() const = 0; const IndexType& Type() const { diff --git a/internal/core/src/index/ScalarIndex.h b/internal/core/src/index/ScalarIndex.h index 6105ce4afb980..f3b602c8821df 100644 --- a/internal/core/src/index/ScalarIndex.h +++ b/internal/core/src/index/ScalarIndex.h @@ -133,6 +133,12 @@ class ScalarIndex : public IndexBase { PanicInfo(Unsupported, "pattern match is not supported"); } + virtual bool + IsMmapSupported() const { + return index_type_ == milvus::index::BITMAP_INDEX_TYPE || + index_type_ == milvus::index::HYBRID_INDEX_TYPE; + } + virtual int64_t Size() = 0; diff --git a/internal/core/src/index/VectorIndex.h b/internal/core/src/index/VectorIndex.h index 540b93d4a7e78..95655db9e544a 100644 --- a/internal/core/src/index/VectorIndex.h +++ b/internal/core/src/index/VectorIndex.h @@ -115,6 +115,11 @@ class VectorIndex : public IndexBase { err_msg); } + virtual bool + IsMmapSupported() const { + return knowhere::IndexFactory::Instance().FeatureCheck(index_type_, knowhere::feature::MMAP); + } + knowhere::Json PrepareSearchParams(const SearchInfo& search_info) const { knowhere::Json search_cfg = search_info.search_params_; diff --git a/internal/core/src/segcore/vector_index_c.cpp b/internal/core/src/segcore/vector_index_c.cpp index b45a684d57903..84f203e0fdee1 100644 --- a/internal/core/src/segcore/vector_index_c.cpp +++ b/internal/core/src/segcore/vector_index_c.cpp @@ -11,9 +11,63 @@ #include "segcore/vector_index_c.h" +#include "common/Types.h" +#include "common/EasyAssert.h" #include "knowhere/utils.h" +#include "knowhere/config.h" +#include "knowhere/version.h" #include "index/Meta.h" #include "index/IndexFactory.h" +#include "pb/index_cgo_msg.pb.h" + +CStatus +ValidateIndexParams(const char* index_type, enum CDataType data_type, const uint8_t* serialized_index_params, const uint64_t length) { + try { + auto index_params = + std::make_unique(); + auto res = index_params->ParseFromArray(serialized_index_params, length); + AssertInfo(res, "Unmarshall index params failed"); + + knowhere::Json json; + + for (size_t i = 0; i < index_params->params_size(); i++) { + auto& param = index_params->params(i); + json[param.key()] = param.value(); + } + + milvus::DataType dataType(static_cast(data_type)); + + knowhere::Status status; + std::string error_msg; + if (dataType == milvus::DataType::VECTOR_BINARY) { + status = knowhere::CheckConfig(index_type, knowhere::Version::GetCurrentVersion().VersionNumber(), json, knowhere::PARAM_TYPE::TRAIN, error_msg); + } else if (dataType == milvus::DataType::VECTOR_FLOAT) { + status = knowhere::CheckConfig(index_type, knowhere::Version::GetCurrentVersion().VersionNumber(), json, knowhere::PARAM_TYPE::TRAIN, error_msg); + } else if (dataType == milvus::DataType::VECTOR_BFLOAT16) { + status = knowhere::CheckConfig(index_type, knowhere::Version::GetCurrentVersion().VersionNumber(), json, knowhere::PARAM_TYPE::TRAIN, error_msg); + } else if (dataType == milvus::DataType::VECTOR_FLOAT16) { + status = knowhere::CheckConfig(index_type, knowhere::Version::GetCurrentVersion().VersionNumber(), json, knowhere::PARAM_TYPE::TRAIN, error_msg); + } else if (dataType == milvus::DataType::VECTOR_SPARSE_FLOAT) { + status = knowhere::CheckConfig(index_type, knowhere::Version::GetCurrentVersion().VersionNumber(), json, knowhere::PARAM_TYPE::TRAIN, error_msg); + } else { + status = knowhere::Status::invalid_args; + } + CStatus cStatus; + if (status == knowhere::Status::success) { + cStatus.error_code = milvus::Success; + cStatus.error_msg = ""; + } else { + cStatus.error_code = milvus::ConfigInvalid; + cStatus.error_msg = error_msg.c_str(); + } + return cStatus; + } catch (std::exception& e) { + auto cStatus = CStatus(); + cStatus.error_code = milvus::UnexpectedError; + cStatus.error_msg = strdup(e.what()); + return cStatus; + } +} int GetIndexListSize() { diff --git a/internal/core/src/segcore/vector_index_c.h b/internal/core/src/segcore/vector_index_c.h index d38a06b7a1e15..06160e0bd6f9e 100644 --- a/internal/core/src/segcore/vector_index_c.h +++ b/internal/core/src/segcore/vector_index_c.h @@ -15,6 +15,10 @@ extern "C" { #endif #include +#include "common/type_c.h" + +CStatus +ValidateIndexParams(const char* index_type, enum CDataType data_type, const uint8_t* index_params, const uint64_t length); int GetIndexListSize(); diff --git a/internal/core/thirdparty/knowhere/CMakeLists.txt b/internal/core/thirdparty/knowhere/CMakeLists.txt index 3585ee2f7b3cd..c8304e840f2a7 100644 --- a/internal/core/thirdparty/knowhere/CMakeLists.txt +++ b/internal/core/thirdparty/knowhere/CMakeLists.txt @@ -60,3 +60,5 @@ endif() # get prometheus COMPILE_OPTIONS get_property( var DIRECTORY "${knowhere_SOURCE_DIR}" PROPERTY COMPILE_OPTIONS ) message( STATUS "knowhere src compile options: ${var}" ) + +set( KNOWHERE_INCLUDE_DIR ${knowhere_SOURCE_DIR}/include CACHE INTERNAL "Path to knowhere include directory" ) diff --git a/internal/datacoord/compaction_trigger_v2.go b/internal/datacoord/compaction_trigger_v2.go index 150e572df6ae3..7fa39a37687c5 100644 --- a/internal/datacoord/compaction_trigger_v2.go +++ b/internal/datacoord/compaction_trigger_v2.go @@ -26,6 +26,7 @@ import ( "github.com/milvus-io/milvus/internal/datacoord/allocator" "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/lock" "github.com/milvus-io/milvus/pkg/util/logutil" diff --git a/internal/datacoord/index_meta.go b/internal/datacoord/index_meta.go index 3ff4fe1fae5cf..ac8aecf9b79e7 100644 --- a/internal/datacoord/index_meta.go +++ b/internal/datacoord/index_meta.go @@ -33,11 +33,11 @@ import ( "github.com/milvus-io/milvus/internal/metastore" "github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/proto/indexpb" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/internal/proto/workerpb" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/timerecord" "github.com/milvus-io/milvus/pkg/util/typeutil" ) diff --git a/internal/datacoord/index_service.go b/internal/datacoord/index_service.go index b0d5458e51d34..9102d2f438a76 100644 --- a/internal/datacoord/index_service.go +++ b/internal/datacoord/index_service.go @@ -28,10 +28,10 @@ import ( "github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/indexpb" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metautil" "github.com/milvus-io/milvus/pkg/util/paramtable" diff --git a/internal/datacoord/index_service_test.go b/internal/datacoord/index_service_test.go index e0bedad961fde..2cc4dab44e86f 100644 --- a/internal/datacoord/index_service_test.go +++ b/internal/datacoord/index_service_test.go @@ -42,9 +42,9 @@ import ( "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/workerpb" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/merr" ) @@ -620,13 +620,13 @@ func TestServer_AlterIndex(t *testing.T) { s.stateCode.Store(commonpb.StateCode_Healthy) t.Run("mmap_unsupported", func(t *testing.T) { - indexParams[0].Value = indexparamcheck.IndexRaftCagra + indexParams[0].Value = "GPU_CAGRA" resp, err := s.AlterIndex(ctx, req) assert.NoError(t, err) assert.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrParameterInvalid) - indexParams[0].Value = indexparamcheck.IndexFaissIvfFlat + indexParams[0].Value = "IVF_FLAT" }) t.Run("param_value_invalied", func(t *testing.T) { diff --git a/internal/datacoord/task_index.go b/internal/datacoord/task_index.go index 243775cbda3bc..f316a7065cd69 100644 --- a/internal/datacoord/task_index.go +++ b/internal/datacoord/task_index.go @@ -33,6 +33,7 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/indexparams" "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -156,6 +157,17 @@ func (it *indexBuildTask) PreCheck(ctx context.Context, dependency *taskSchedule fieldID := dependency.meta.indexMeta.GetFieldIDByIndexID(segIndex.CollectionID, segIndex.IndexID) binlogIDs := getBinLogIDs(segment, fieldID) + if Params.IndexEngineConfig.Enable.GetAsBool() { + var ret error + indexParams, ret = Params.IndexEngineConfig.MergeRequestParam(GetIndexType(indexParams), paramtable.BuildStage, indexParams) + + if ret != nil { + log.Ctx(ctx).Warn("failed to construct index build params", zap.Int64("taskID", it.taskID), zap.Error(ret)) + it.SetState(indexpb.JobState_JobStateInit, ret.Error()) + return true + } + } + if isDiskANNIndex(GetIndexType(indexParams)) { var err error indexParams, err = indexparams.UpdateDiskIndexBuildParams(Params, indexParams) diff --git a/internal/datacoord/task_scheduler_test.go b/internal/datacoord/task_scheduler_test.go index ba47290a738be..f5c261c6c9199 100644 --- a/internal/datacoord/task_scheduler_test.go +++ b/internal/datacoord/task_scheduler_test.go @@ -39,7 +39,6 @@ import ( "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/workerpb" "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -1432,7 +1431,7 @@ func (s *taskSchedulerSuite) Test_indexTaskWithMvOptionalScalarField() { }, { Key: common.IndexTypeKey, - Value: indexparamcheck.IndexHNSW, + Value: "HNSW", }, }, }, @@ -1485,7 +1484,7 @@ func (s *taskSchedulerSuite) Test_indexTaskWithMvOptionalScalarField() { }, { Key: common.IndexTypeKey, - Value: indexparamcheck.IndexHNSW, + Value: "HNSW", }, }, }, @@ -1582,7 +1581,7 @@ func (s *taskSchedulerSuite) Test_indexTaskWithMvOptionalScalarField() { resetMetaFunc := func() { mt.indexMeta.buildID2SegmentIndex[buildID].IndexState = commonpb.IndexState_Unissued mt.indexMeta.segmentIndexes[segID][indexID].IndexState = commonpb.IndexState_Unissued - mt.indexMeta.indexes[collID][indexID].IndexParams[1].Value = indexparamcheck.IndexHNSW + mt.indexMeta.indexes[collID][indexID].IndexParams[1].Value = "HNSW" mt.collections[collID].Schema.Fields[0].DataType = schemapb.DataType_FloatVector mt.collections[collID].Schema.Fields[1].IsPartitionKey = true mt.collections[collID].Schema.Fields[1].DataType = schemapb.DataType_VarChar diff --git a/internal/datacoord/util.go b/internal/datacoord/util.go index e06e720b4a07f..c910776303eb3 100644 --- a/internal/datacoord/util.go +++ b/internal/datacoord/util.go @@ -30,11 +30,11 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/indexpb" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/tsoutil" diff --git a/internal/indexnode/task_index.go b/internal/indexnode/task_index.go index 4cb8bcee56e36..e4a84f9f7ad7e 100644 --- a/internal/indexnode/task_index.go +++ b/internal/indexnode/task_index.go @@ -33,10 +33,10 @@ import ( "github.com/milvus-io/milvus/internal/proto/workerpb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/indexcgowrapper" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/indexparams" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metautil" @@ -210,6 +210,7 @@ func (it *indexBuildTask) Execute(ctx context.Context) error { zap.Int32("currentIndexVersion", it.req.GetCurrentIndexVersion())) indexType := it.newIndexParams[common.IndexTypeKey] + var fieldDataSize uint64 if indexparamcheck.GetVecIndexMgrInstance().IsDiskANN(indexType) { // check index node support disk index if !Params.IndexNodeCfg.EnableDisk.GetAsBool() { @@ -225,7 +226,7 @@ func (it *indexBuildTask) Execute(ctx context.Context) error { log.Warn("IndexNode get local used size failed") return err } - fieldDataSize, err := estimateFieldDataSize(it.req.GetDim(), it.req.GetNumRows(), it.req.GetField().GetDataType()) + fieldDataSize, err = estimateFieldDataSize(it.req.GetDim(), it.req.GetNumRows(), it.req.GetField().GetDataType()) if err != nil { log.Warn("IndexNode get local used size failed") return err @@ -247,6 +248,10 @@ func (it *indexBuildTask) Execute(ctx context.Context) error { } } + if Params.IndexEngineConfig.Enable.GetAsBool() { + it.newIndexParams, _ = Params.IndexEngineConfig.MergeWithResource(fieldDataSize, it.newIndexParams) + } + storageConfig := &indexcgopb.StorageConfig{ Address: it.req.GetStorageConfig().GetAddress(), AccessKeyID: it.req.GetStorageConfig().GetAccessKeyID(), diff --git a/internal/proxy/cgo_util_test.go b/internal/proxy/cgo_util_test.go index 363ee644f9027..da588c6b6bb4d 100644 --- a/internal/proxy/cgo_util_test.go +++ b/internal/proxy/cgo_util_test.go @@ -20,7 +20,6 @@ import ( "testing" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" ) func Test_CheckVecIndexWithDataTypeExist(t *testing.T) { @@ -29,25 +28,25 @@ func Test_CheckVecIndexWithDataTypeExist(t *testing.T) { dataType schemapb.DataType want bool }{ - {indexparamcheck.IndexHNSW, schemapb.DataType_FloatVector, true}, - {indexparamcheck.IndexHNSW, schemapb.DataType_BinaryVector, false}, - {indexparamcheck.IndexHNSW, schemapb.DataType_Float16Vector, true}, + {"HNSW", schemapb.DataType_FloatVector, true}, + {"HNSW", schemapb.DataType_BinaryVector, false}, + {"HNSW", schemapb.DataType_Float16Vector, true}, - {indexparamcheck.IndexSparseWand, schemapb.DataType_SparseFloatVector, true}, - {indexparamcheck.IndexSparseWand, schemapb.DataType_FloatVector, false}, - {indexparamcheck.IndexSparseWand, schemapb.DataType_Float16Vector, false}, + {"SPARSE_WAND", schemapb.DataType_SparseFloatVector, true}, + {"SPARSE_WAND", schemapb.DataType_FloatVector, false}, + {"SPARSE_WAND", schemapb.DataType_Float16Vector, false}, - {indexparamcheck.IndexGpuBF, schemapb.DataType_FloatVector, true}, - {indexparamcheck.IndexGpuBF, schemapb.DataType_Float16Vector, false}, - {indexparamcheck.IndexGpuBF, schemapb.DataType_BinaryVector, false}, + {"GPU_BRUTE_FORCE", schemapb.DataType_FloatVector, true}, + {"GPU_BRUTE_FORCE", schemapb.DataType_Float16Vector, false}, + {"GPU_BRUTE_FORCE", schemapb.DataType_BinaryVector, false}, - {indexparamcheck.IndexFaissBinIvfFlat, schemapb.DataType_BinaryVector, true}, - {indexparamcheck.IndexFaissBinIvfFlat, schemapb.DataType_FloatVector, false}, + {"BIN_IVF_FLAT", schemapb.DataType_BinaryVector, true}, + {"BIN_IVF_FLAT", schemapb.DataType_FloatVector, false}, - {indexparamcheck.IndexDISKANN, schemapb.DataType_FloatVector, true}, - {indexparamcheck.IndexDISKANN, schemapb.DataType_Float16Vector, true}, - {indexparamcheck.IndexDISKANN, schemapb.DataType_BFloat16Vector, true}, - {indexparamcheck.IndexDISKANN, schemapb.DataType_BinaryVector, false}, + {"DISKANN", schemapb.DataType_FloatVector, true}, + {"DISKANN", schemapb.DataType_Float16Vector, true}, + {"DISKANN", schemapb.DataType_BFloat16Vector, true}, + {"DISKANN", schemapb.DataType_BinaryVector, false}, } for _, test := range cases { diff --git a/internal/proxy/task_index.go b/internal/proxy/task_index.go index 87b077134b8c6..f095324b4e806 100644 --- a/internal/proxy/task_index.go +++ b/internal/proxy/task_index.go @@ -28,12 +28,12 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/indexparams" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metric" @@ -298,6 +298,13 @@ func (cit *createIndexTask) parseIndexParams() error { if !exist { return fmt.Errorf("IndexType not specified") } + if Params.IndexEngineConfig.Enable.GetAsBool() { + var err error + indexParamsMap, err = Params.IndexEngineConfig.MergeRequestMapParam(indexType, paramtable.BuildStage, indexParamsMap) + if err != nil { + return err + } + } if indexparamcheck.GetVecIndexMgrInstance().IsDiskANN(indexType) { err := indexparams.FillDiskIndexParams(Params, indexParamsMap) if err != nil { @@ -417,17 +424,14 @@ func checkTrain(field *schemapb.FieldSchema, indexParams map[string]string) erro if err := fillDimension(field, indexParams); err != nil { return err } - } else { - // used only for checker, should be deleted after checking - indexParams[IsSparseKey] = "true" } - if err := checker.CheckValidDataType(field); err != nil { + if err := checker.CheckValidDataType(indexType, field); err != nil { log.Info("create index with invalid data type", zap.Error(err), zap.String("data_type", field.GetDataType().String())) return err } - if err := checker.CheckTrain(indexParams); err != nil { + if err := checker.CheckTrain(field.DataType, indexParams); err != nil { log.Info("create index with invalid parameters", zap.Error(err)) return err } diff --git a/internal/proxy/task_index_test.go b/internal/proxy/task_index_test.go index 8c89d403decf7..13275d9359556 100644 --- a/internal/proxy/task_index_test.go +++ b/internal/proxy/task_index_test.go @@ -35,9 +35,9 @@ import ( "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" diff --git a/internal/proxy/util.go b/internal/proxy/util.go index fe90239b25eac..e2bfbfd19744a 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -39,6 +39,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/hookutil" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" @@ -47,7 +48,6 @@ import ( "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/contextutil" "github.com/milvus-io/milvus/pkg/util/crypto" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/paramtable" diff --git a/internal/querynodev2/segments/collection.go b/internal/querynodev2/segments/collection.go index 5a5679c0c9d4d..4a241595f0f6a 100644 --- a/internal/querynodev2/segments/collection.go +++ b/internal/querynodev2/segments/collection.go @@ -37,10 +37,10 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/segcorepb" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) diff --git a/internal/querynodev2/segments/index_attr_cache.go b/internal/querynodev2/segments/index_attr_cache.go index 393fdc6ce6f20..ccb2b0b842c9d 100644 --- a/internal/querynodev2/segments/index_attr_cache.go +++ b/internal/querynodev2/segments/index_attr_cache.go @@ -29,10 +29,10 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/conc" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/typeutil" ) diff --git a/internal/querynodev2/segments/index_attr_cache_test.go b/internal/querynodev2/segments/index_attr_cache_test.go index 55d3f705bfb93..dc4a11a70ce49 100644 --- a/internal/querynodev2/segments/index_attr_cache_test.go +++ b/internal/querynodev2/segments/index_attr_cache_test.go @@ -24,8 +24,8 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -62,7 +62,7 @@ func (s *IndexAttrCacheSuite) TestCacheMissing() { func (s *IndexAttrCacheSuite) TestDiskANN() { info := &querypb.FieldIndexInfo{ IndexParams: []*commonpb.KeyValuePair{ - {Key: common.IndexTypeKey, Value: indexparamcheck.IndexDISKANN}, + {Key: common.IndexTypeKey, Value: "DISKANN"}, }, CurrentIndexVersion: 0, IndexSize: 100, @@ -71,7 +71,7 @@ func (s *IndexAttrCacheSuite) TestDiskANN() { memory, disk, err := s.c.GetIndexResourceUsage(info, paramtable.Get().QueryNodeCfg.MemoryIndexLoadPredictMemoryUsageFactor.GetAsFloat(), nil) s.Require().NoError(err) - _, has := s.c.loadWithDisk.Get(typeutil.NewPair[string, int32](indexparamcheck.IndexDISKANN, 0)) + _, has := s.c.loadWithDisk.Get(typeutil.NewPair[string, int32]("DISKANN", 0)) s.False(has, "DiskANN shall never be checked load with disk") s.EqualValues(25, memory) diff --git a/internal/querynodev2/segments/segment.go b/internal/querynodev2/segments/segment.go index 4fe4f5b75dfc4..d92d142fdfb2d 100644 --- a/internal/querynodev2/segments/segment.go +++ b/internal/querynodev2/segments/segment.go @@ -51,11 +51,11 @@ import ( "github.com/milvus-io/milvus/internal/querynodev2/segments/state" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/cgo" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/indexparams" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metautil" diff --git a/internal/querynodev2/segments/segment_loader_test.go b/internal/querynodev2/segments/segment_loader_test.go index ebb282d50e795..e038c0f2262a8 100644 --- a/internal/querynodev2/segments/segment_loader_test.go +++ b/internal/querynodev2/segments/segment_loader_test.go @@ -33,11 +33,11 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/internal/util/initcore" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/contextutil" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/paramtable" diff --git a/internal/querynodev2/segments/utils.go b/internal/querynodev2/segments/utils.go index 14ddedd2d844e..e7eae8c6b67d3 100644 --- a/internal/querynodev2/segments/utils.go +++ b/internal/querynodev2/segments/utils.go @@ -29,11 +29,11 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/querynodev2/segments/metricsutil" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/contextutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" diff --git a/internal/querynodev2/segments/utils_test.go b/internal/querynodev2/segments/utils_test.go index 51d8733ca6114..881068eb3eb19 100644 --- a/internal/querynodev2/segments/utils_test.go +++ b/internal/querynodev2/segments/utils_test.go @@ -10,7 +10,6 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -131,7 +130,7 @@ func TestIsIndexMmapEnable(t *testing.T) { IndexParams: []*commonpb.KeyValuePair{ { Key: common.IndexTypeKey, - Value: indexparamcheck.IndexFaissIvfFlat, + Value: "IVF_FLAT", }, }, }) @@ -147,7 +146,7 @@ func TestIsIndexMmapEnable(t *testing.T) { IndexParams: []*commonpb.KeyValuePair{ { Key: common.IndexTypeKey, - Value: indexparamcheck.IndexINVERTED, + Value: "INVERTED", }, }, }) diff --git a/pkg/util/indexparamcheck/auto_index_checker.go b/internal/util/indexparamcheck/auto_index_checker.go similarity index 59% rename from pkg/util/indexparamcheck/auto_index_checker.go rename to internal/util/indexparamcheck/auto_index_checker.go index cc83f196d2e0c..f56a2887b17ae 100644 --- a/pkg/util/indexparamcheck/auto_index_checker.go +++ b/internal/util/indexparamcheck/auto_index_checker.go @@ -9,11 +9,11 @@ type AUTOINDEXChecker struct { baseChecker } -func (c *AUTOINDEXChecker) CheckTrain(params map[string]string) error { +func (c *AUTOINDEXChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { return nil } -func (c *AUTOINDEXChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c *AUTOINDEXChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { return nil } diff --git a/pkg/util/indexparamcheck/base_checker.go b/internal/util/indexparamcheck/base_checker.go similarity index 68% rename from pkg/util/indexparamcheck/base_checker.go rename to internal/util/indexparamcheck/base_checker.go index 6ea600ba4003d..ed52d320ddad7 100644 --- a/pkg/util/indexparamcheck/base_checker.go +++ b/internal/util/indexparamcheck/base_checker.go @@ -19,29 +19,17 @@ package indexparamcheck import ( "fmt" "math" - "strings" "github.com/cockroachdb/errors" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) type baseChecker struct{} -func (c baseChecker) CheckTrain(params map[string]string) error { - // vector dimension should be checked on collection creation. this is just some basic check - isSparse := false - if val, exist := params[common.IsSparseKey]; exist { - val = strings.ToLower(val) - if val != "true" && val != "false" { - return fmt.Errorf("invalid is_sparse value: %s, must be true or false", val) - } - if val == "true" { - isSparse = true - } - } - if isSparse { +func (c baseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if typeutil.IsSparseFloatVectorType(dataType) { if !CheckStrByValues(params, Metric, SparseMetrics) { return fmt.Errorf("metric type not found or not supported for sparse float vectors, supported: %v", SparseMetrics) } @@ -55,13 +43,13 @@ func (c baseChecker) CheckTrain(params map[string]string) error { } // CheckValidDataType check whether the field data type is supported for the index type -func (c baseChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c baseChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { return nil } -func (c baseChecker) SetDefaultMetricTypeIfNotExist(m map[string]string, dType schemapb.DataType) {} +func (c baseChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, m map[string]string) {} -func (c baseChecker) StaticCheck(params map[string]string) error { +func (c baseChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { return errors.New("unsupported index type") } diff --git a/pkg/util/indexparamcheck/base_checker_test.go b/internal/util/indexparamcheck/base_checker_test.go similarity index 93% rename from pkg/util/indexparamcheck/base_checker_test.go rename to internal/util/indexparamcheck/base_checker_test.go index 59a0969d18d4d..c9ceea90af18d 100644 --- a/pkg/util/indexparamcheck/base_checker_test.go +++ b/internal/util/indexparamcheck/base_checker_test.go @@ -44,7 +44,7 @@ func Test_baseChecker_CheckTrain(t *testing.T) { c := newBaseChecker() for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -115,7 +115,7 @@ func Test_baseChecker_CheckValidDataType(t *testing.T) { c := newBaseChecker() for _, test := range cases { fieldSchema := &schemapb.FieldSchema{DataType: test.dType} - err := c.CheckValidDataType(fieldSchema) + err := c.CheckValidDataType("FLAT", fieldSchema) if test.errIsNil { assert.NoError(t, err) } else { @@ -126,5 +126,5 @@ func Test_baseChecker_CheckValidDataType(t *testing.T) { func Test_baseChecker_StaticCheck(t *testing.T) { // TODO - assert.Error(t, newBaseChecker().StaticCheck(nil)) + assert.Error(t, newBaseChecker().StaticCheck(schemapb.DataType_FloatVector, nil)) } diff --git a/internal/util/indexparamcheck/bin_flat_checker.go b/internal/util/indexparamcheck/bin_flat_checker.go new file mode 100644 index 0000000000000..647e3827ff012 --- /dev/null +++ b/internal/util/indexparamcheck/bin_flat_checker.go @@ -0,0 +1,19 @@ +package indexparamcheck + +import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + +type binFlatChecker struct { + binaryVectorBaseChecker +} + +func (c binFlatChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + return c.binaryVectorBaseChecker.CheckTrain(0, params) +} + +func (c binFlatChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { + return c.staticCheck(params) +} + +func newBinFlatChecker() IndexChecker { + return &binFlatChecker{} +} diff --git a/pkg/util/indexparamcheck/bin_flat_checker_test.go b/internal/util/indexparamcheck/bin_flat_checker_test.go similarity index 89% rename from pkg/util/indexparamcheck/bin_flat_checker_test.go rename to internal/util/indexparamcheck/bin_flat_checker_test.go index 9cf4f39394515..b402294ad2311 100644 --- a/pkg/util/indexparamcheck/bin_flat_checker_test.go +++ b/internal/util/indexparamcheck/bin_flat_checker_test.go @@ -1,6 +1,7 @@ package indexparamcheck import ( + "github.com/milvus-io/milvus/pkg/common" "strconv" "testing" @@ -64,9 +65,10 @@ func Test_binFlatChecker_CheckTrain(t *testing.T) { {p7, true}, } - c := newBinFlatChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("BINFLAT") for _, test := range cases { - err := c.CheckTrain(test.params) + test.params[common.IndexTypeKey] = "BINFLAT" + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -134,10 +136,10 @@ func Test_binFlatChecker_CheckValidDataType(t *testing.T) { }, } - c := newBinFlatChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("BINFLAT") for _, test := range cases { fieldSchema := &schemapb.FieldSchema{DataType: test.dType} - err := c.CheckValidDataType(fieldSchema) + err := c.CheckValidDataType("BINFLAT", fieldSchema) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/bin_ivf_flat_checker.go b/internal/util/indexparamcheck/bin_ivf_flat_checker.go similarity index 55% rename from pkg/util/indexparamcheck/bin_ivf_flat_checker.go rename to internal/util/indexparamcheck/bin_ivf_flat_checker.go index c36bc41c1c32e..018bc521baa7d 100644 --- a/pkg/util/indexparamcheck/bin_ivf_flat_checker.go +++ b/internal/util/indexparamcheck/bin_ivf_flat_checker.go @@ -2,13 +2,15 @@ package indexparamcheck import ( "fmt" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) type binIVFFlatChecker struct { binaryVectorBaseChecker } -func (c binIVFFlatChecker) StaticCheck(params map[string]string) error { +func (c binIVFFlatChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { if !CheckStrByValues(params, Metric, BinIvfMetrics) { return fmt.Errorf("metric type %s not found or not supported, supported: %v", params[Metric], BinIvfMetrics) } @@ -20,12 +22,12 @@ func (c binIVFFlatChecker) StaticCheck(params map[string]string) error { return nil } -func (c binIVFFlatChecker) CheckTrain(params map[string]string) error { - if err := c.binaryVectorBaseChecker.CheckTrain(params); err != nil { +func (c binIVFFlatChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.binaryVectorBaseChecker.CheckTrain(0, params); err != nil { return err } - return c.StaticCheck(params) + return c.StaticCheck(schemapb.DataType_BinaryVector, params) } func newBinIVFFlatChecker() IndexChecker { diff --git a/pkg/util/indexparamcheck/bin_ivf_flat_checker_test.go b/internal/util/indexparamcheck/bin_ivf_flat_checker_test.go similarity index 94% rename from pkg/util/indexparamcheck/bin_ivf_flat_checker_test.go rename to internal/util/indexparamcheck/bin_ivf_flat_checker_test.go index 77bda3bb016b1..7c0e773ffe8ea 100644 --- a/pkg/util/indexparamcheck/bin_ivf_flat_checker_test.go +++ b/internal/util/indexparamcheck/bin_ivf_flat_checker_test.go @@ -115,9 +115,9 @@ func Test_binIVFFlatChecker_CheckTrain(t *testing.T) { {p7, false}, } - c := newBinIVFFlatChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("BINIVFFLAT") for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -185,10 +185,10 @@ func Test_binIVFFlatChecker_CheckValidDataType(t *testing.T) { }, } - c := newBinIVFFlatChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("BINIVFFLAT") for _, test := range cases { fieldSchema := &schemapb.FieldSchema{DataType: test.dType} - err := c.CheckValidDataType(fieldSchema) + err := c.CheckValidDataType("BINIVFFLAT", fieldSchema) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/binary_vector_base_checker.go b/internal/util/indexparamcheck/binary_vector_base_checker.go similarity index 72% rename from pkg/util/indexparamcheck/binary_vector_base_checker.go rename to internal/util/indexparamcheck/binary_vector_base_checker.go index e73bd8b62e40a..d7b1ca891ae30 100644 --- a/pkg/util/indexparamcheck/binary_vector_base_checker.go +++ b/internal/util/indexparamcheck/binary_vector_base_checker.go @@ -19,22 +19,22 @@ func (c binaryVectorBaseChecker) staticCheck(params map[string]string) error { return nil } -func (c binaryVectorBaseChecker) CheckTrain(params map[string]string) error { - if err := c.baseChecker.CheckTrain(params); err != nil { +func (c binaryVectorBaseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.baseChecker.CheckTrain(0, params); err != nil { return err } return c.staticCheck(params) } -func (c binaryVectorBaseChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c binaryVectorBaseChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { if field.GetDataType() != schemapb.DataType_BinaryVector { return fmt.Errorf("binary vector is only supported") } return nil } -func (c binaryVectorBaseChecker) SetDefaultMetricTypeIfNotExist(params map[string]string, dType schemapb.DataType) { +func (c binaryVectorBaseChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) { setDefaultIfNotExist(params, common.MetricTypeKey, BinaryVectorDefaultMetricType) } diff --git a/pkg/util/indexparamcheck/binary_vector_base_checker_test.go b/internal/util/indexparamcheck/binary_vector_base_checker_test.go similarity index 92% rename from pkg/util/indexparamcheck/binary_vector_base_checker_test.go rename to internal/util/indexparamcheck/binary_vector_base_checker_test.go index b52648f79355e..85942a3fc1dc7 100644 --- a/pkg/util/indexparamcheck/binary_vector_base_checker_test.go +++ b/internal/util/indexparamcheck/binary_vector_base_checker_test.go @@ -67,10 +67,10 @@ func Test_binaryVectorBaseChecker_CheckValidDataType(t *testing.T) { }, } - c := newBinaryVectorBaseChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("BINFLAT") for _, test := range cases { fieldSchema := &schemapb.FieldSchema{DataType: test.dType} - err := c.CheckValidDataType(fieldSchema) + err := c.CheckValidDataType("BINFLAT", fieldSchema) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/internal/util/indexparamcheck/bitmap_checker_test.go b/internal/util/indexparamcheck/bitmap_checker_test.go new file mode 100644 index 0000000000000..09180fdbb5e1f --- /dev/null +++ b/internal/util/indexparamcheck/bitmap_checker_test.go @@ -0,0 +1,33 @@ +package indexparamcheck + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func Test_BitmapIndexChecker(t *testing.T) { + c := newBITMAPChecker() + + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Int8})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Int16})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Int32})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_String})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Bool})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int8})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int16})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_String})) + + assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) + assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Float})) + assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Double})) + assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float})) + assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double})) + assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Double, IsPrimaryKey: true})) +} diff --git a/pkg/util/indexparamcheck/bitmap_index_checker.go b/internal/util/indexparamcheck/bitmap_index_checker.go similarity index 79% rename from pkg/util/indexparamcheck/bitmap_index_checker.go rename to internal/util/indexparamcheck/bitmap_index_checker.go index f19943a50ea93..55375d93771e0 100644 --- a/pkg/util/indexparamcheck/bitmap_index_checker.go +++ b/internal/util/indexparamcheck/bitmap_index_checker.go @@ -11,11 +11,11 @@ type BITMAPChecker struct { scalarIndexChecker } -func (c *BITMAPChecker) CheckTrain(params map[string]string) error { - return c.scalarIndexChecker.CheckTrain(params) +func (c *BITMAPChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + return c.scalarIndexChecker.CheckTrain(0, params) } -func (c *BITMAPChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c *BITMAPChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { if field.IsPrimaryKey { return fmt.Errorf("create bitmap index on primary key not supported") } diff --git a/pkg/util/indexparamcheck/cagra_checker.go b/internal/util/indexparamcheck/cagra_checker.go similarity index 84% rename from pkg/util/indexparamcheck/cagra_checker.go rename to internal/util/indexparamcheck/cagra_checker.go index 8f52a1605d775..36151d6220f3d 100644 --- a/pkg/util/indexparamcheck/cagra_checker.go +++ b/internal/util/indexparamcheck/cagra_checker.go @@ -3,6 +3,8 @@ package indexparamcheck import ( "fmt" "strconv" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) // diskannChecker checks if an diskann index can be built. @@ -10,8 +12,8 @@ type cagraChecker struct { floatVectorBaseChecker } -func (c *cagraChecker) CheckTrain(params map[string]string) error { - err := c.baseChecker.CheckTrain(params) +func (c *cagraChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + err := c.baseChecker.CheckTrain(0, params) if err != nil { return err } @@ -54,7 +56,7 @@ func (c *cagraChecker) CheckTrain(params map[string]string) error { return nil } -func (c cagraChecker) StaticCheck(params map[string]string) error { +func (c cagraChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { return c.staticCheck(params) } diff --git a/pkg/util/indexparamcheck/cagra_checker_test.go b/internal/util/indexparamcheck/cagra_checker_test.go similarity index 96% rename from pkg/util/indexparamcheck/cagra_checker_test.go rename to internal/util/indexparamcheck/cagra_checker_test.go index 23a931a12ef01..212dde4662dad 100644 --- a/pkg/util/indexparamcheck/cagra_checker_test.go +++ b/internal/util/indexparamcheck/cagra_checker_test.go @@ -101,9 +101,9 @@ func Test_cagraChecker_CheckTrain(t *testing.T) { {p14, false}, } - c := newCagraChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_CAGRA") for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/conf_adapter_mgr.go b/internal/util/indexparamcheck/conf_adapter_mgr.go similarity index 71% rename from pkg/util/indexparamcheck/conf_adapter_mgr.go rename to internal/util/indexparamcheck/conf_adapter_mgr.go index 2ff7320c9b3a2..45d0549e05cbd 100644 --- a/pkg/util/indexparamcheck/conf_adapter_mgr.go +++ b/internal/util/indexparamcheck/conf_adapter_mgr.go @@ -34,7 +34,10 @@ type indexCheckerMgrImpl struct { func (mgr *indexCheckerMgrImpl) GetChecker(indexType string) (IndexChecker, error) { mgr.once.Do(mgr.registerIndexChecker) - + // Unify the vector index checker + if GetVecIndexMgrInstance().IsVecIndex(indexType) { + return mgr.checkers[IndexVector], nil + } adapter, ok := mgr.checkers[indexType] if ok { return adapter, nil @@ -43,23 +46,7 @@ func (mgr *indexCheckerMgrImpl) GetChecker(indexType string) (IndexChecker, erro } func (mgr *indexCheckerMgrImpl) registerIndexChecker() { - mgr.checkers[IndexRaftIvfFlat] = newRaftIVFFlatChecker() - mgr.checkers[IndexRaftIvfPQ] = newRaftIVFPQChecker() - mgr.checkers[IndexRaftCagra] = newCagraChecker() - mgr.checkers[IndexRaftBruteForce] = newRaftBruteForceChecker() - mgr.checkers[IndexFaissIDMap] = newFlatChecker() - mgr.checkers[IndexFaissIvfFlat] = newIVFBaseChecker() - mgr.checkers[IndexFaissIvfPQ] = newIVFPQChecker() - mgr.checkers[IndexScaNN] = newScaNNChecker() - mgr.checkers[IndexFaissIvfSQ8] = newIVFSQChecker() - mgr.checkers[IndexFaissBinIDMap] = newBinFlatChecker() - mgr.checkers[IndexFaissBinIvfFlat] = newBinIVFFlatChecker() - mgr.checkers[IndexHNSW] = newHnswChecker() - mgr.checkers[IndexDISKANN] = newDiskannChecker() - mgr.checkers[IndexSparseInverted] = newSparseInvertedIndexChecker() - // WAND doesn't have more index params than sparse inverted index, thus - // using the same checker. - mgr.checkers[IndexSparseWand] = newSparseInvertedIndexChecker() + mgr.checkers[IndexVector] = newVecIndexChecker() mgr.checkers[IndexINVERTED] = newINVERTEDChecker() mgr.checkers[IndexSTLSORT] = newSTLSORTChecker() mgr.checkers["Asceneding"] = newSTLSORTChecker() diff --git a/pkg/util/indexparamcheck/conf_adapter_mgr_test.go b/internal/util/indexparamcheck/conf_adapter_mgr_test.go similarity index 78% rename from pkg/util/indexparamcheck/conf_adapter_mgr_test.go rename to internal/util/indexparamcheck/conf_adapter_mgr_test.go index 6ab9469ee501d..746f6cd09454d 100644 --- a/pkg/util/indexparamcheck/conf_adapter_mgr_test.go +++ b/internal/util/indexparamcheck/conf_adapter_mgr_test.go @@ -29,49 +29,49 @@ func Test_GetConfAdapterMgrInstance(t *testing.T) { assert.NotEqual(t, nil, err) assert.Equal(t, nil, adapter) - adapter, err = adapterMgr.GetChecker(IndexFaissIDMap) + adapter, err = adapterMgr.GetChecker("FLAT") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*flatChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissIvfFlat) + adapter, err = adapterMgr.GetChecker("IVF_FLAT") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*ivfBaseChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexScaNN) + adapter, err = adapterMgr.GetChecker("SCANN") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*scaNNChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissIvfPQ) + adapter, err = adapterMgr.GetChecker("IVF_PQ") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*ivfPQChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissIvfSQ8) + adapter, err = adapterMgr.GetChecker("IVF_SQ8") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*ivfSQChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissBinIDMap) + adapter, err = adapterMgr.GetChecker("BIN_FLAT") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*binFlatChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissBinIvfFlat) + adapter, err = adapterMgr.GetChecker("BIN_IVF_FLAT") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*binIVFFlatChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexHNSW) + adapter, err = adapterMgr.GetChecker("HNSW") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*hnswChecker) @@ -89,49 +89,49 @@ func TestConfAdapterMgrImpl_GetAdapter(t *testing.T) { assert.NotEqual(t, nil, err) assert.Equal(t, nil, adapter) - adapter, err = adapterMgr.GetChecker(IndexFaissIDMap) + adapter, err = adapterMgr.GetChecker("FLAT") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*flatChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissIvfFlat) + adapter, err = adapterMgr.GetChecker("IVF_FLAT") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*ivfBaseChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexScaNN) + adapter, err = adapterMgr.GetChecker("SCANN") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*scaNNChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissIvfPQ) + adapter, err = adapterMgr.GetChecker("IVF_PQ") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*ivfPQChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissIvfSQ8) + adapter, err = adapterMgr.GetChecker("IVF_SQ8") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*ivfSQChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissBinIDMap) + adapter, err = adapterMgr.GetChecker("BIN_FLAT") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*binFlatChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissBinIvfFlat) + adapter, err = adapterMgr.GetChecker("BIN_IVF_FLAT") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*binIVFFlatChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexHNSW) + adapter, err = adapterMgr.GetChecker("HNSW") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) _, ok = adapter.(*hnswChecker) @@ -146,7 +146,7 @@ func TestConfAdapterMgrImpl_GetAdapter_multiple_threads(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - adapter, err := mgr.GetChecker(IndexHNSW) + adapter, err := mgr.GetChecker("HNSW") assert.NoError(t, err) assert.NotNil(t, adapter) }() diff --git a/pkg/util/indexparamcheck/constraints.go b/internal/util/indexparamcheck/constraints.go similarity index 100% rename from pkg/util/indexparamcheck/constraints.go rename to internal/util/indexparamcheck/constraints.go diff --git a/pkg/util/indexparamcheck/diskann_checker.go b/internal/util/indexparamcheck/diskann_checker.go similarity index 59% rename from pkg/util/indexparamcheck/diskann_checker.go rename to internal/util/indexparamcheck/diskann_checker.go index 3f2401851e961..323859b6f05a5 100644 --- a/pkg/util/indexparamcheck/diskann_checker.go +++ b/internal/util/indexparamcheck/diskann_checker.go @@ -1,11 +1,13 @@ package indexparamcheck +import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + // diskannChecker checks if an diskann index can be built. type diskannChecker struct { floatVectorBaseChecker } -func (c diskannChecker) StaticCheck(params map[string]string) error { +func (c diskannChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { return c.staticCheck(params) } diff --git a/pkg/util/indexparamcheck/diskann_checker_test.go b/internal/util/indexparamcheck/diskann_checker_test.go similarity index 91% rename from pkg/util/indexparamcheck/diskann_checker_test.go rename to internal/util/indexparamcheck/diskann_checker_test.go index 4fcfdbf019aa7..8a644e7d4d6c7 100644 --- a/pkg/util/indexparamcheck/diskann_checker_test.go +++ b/internal/util/indexparamcheck/diskann_checker_test.go @@ -72,9 +72,9 @@ func Test_diskannChecker_CheckTrain(t *testing.T) { {p7, false}, } - c := newDiskannChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("DISKANN") for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -142,9 +142,9 @@ func Test_diskannChecker_CheckValidDataType(t *testing.T) { }, } - c := newDiskannChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("DISKANN") for _, test := range cases { - err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) + err := c.CheckValidDataType("DISKANN", &schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/flat_checker.go b/internal/util/indexparamcheck/flat_checker.go similarity index 52% rename from pkg/util/indexparamcheck/flat_checker.go rename to internal/util/indexparamcheck/flat_checker.go index d98db449206b4..8fe6d59f25ba0 100644 --- a/pkg/util/indexparamcheck/flat_checker.go +++ b/internal/util/indexparamcheck/flat_checker.go @@ -1,10 +1,12 @@ package indexparamcheck +import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + type flatChecker struct { floatVectorBaseChecker } -func (c flatChecker) StaticCheck(m map[string]string) error { +func (c flatChecker) StaticCheck(dataType schemapb.DataType, m map[string]string) error { return c.staticCheck(m) } diff --git a/pkg/util/indexparamcheck/flat_checker_test.go b/internal/util/indexparamcheck/flat_checker_test.go similarity index 85% rename from pkg/util/indexparamcheck/flat_checker_test.go rename to internal/util/indexparamcheck/flat_checker_test.go index c22432bc6f17c..bd7604f7889e7 100644 --- a/pkg/util/indexparamcheck/flat_checker_test.go +++ b/internal/util/indexparamcheck/flat_checker_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/util/metric" ) @@ -52,9 +53,9 @@ func Test_flatChecker_CheckTrain(t *testing.T) { {p7, false}, } - c := newFlatChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("FLAT") for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -89,9 +90,9 @@ func Test_flatChecker_StaticCheck(t *testing.T) { }, } - c := newFlatChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("FLAT") for _, test := range cases { - err := c.StaticCheck(test.params) + err := c.StaticCheck(schemapb.DataType_FloatVector, test.params) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/float_vector_base_checker.go b/internal/util/indexparamcheck/float_vector_base_checker.go similarity index 69% rename from pkg/util/indexparamcheck/float_vector_base_checker.go rename to internal/util/indexparamcheck/float_vector_base_checker.go index 710dfb3a18a38..c8ffe4254a128 100644 --- a/pkg/util/indexparamcheck/float_vector_base_checker.go +++ b/internal/util/indexparamcheck/float_vector_base_checker.go @@ -20,22 +20,22 @@ func (c floatVectorBaseChecker) staticCheck(params map[string]string) error { return nil } -func (c floatVectorBaseChecker) CheckTrain(params map[string]string) error { - if err := c.baseChecker.CheckTrain(params); err != nil { +func (c floatVectorBaseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.baseChecker.CheckTrain(0, params); err != nil { return err } return c.staticCheck(params) } -func (c floatVectorBaseChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c floatVectorBaseChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { if !typeutil.IsDenseFloatVectorType(field.GetDataType()) { return fmt.Errorf("data type should be FloatVector, Float16Vector or BFloat16Vector") } return nil } -func (c floatVectorBaseChecker) SetDefaultMetricTypeIfNotExist(params map[string]string, dType schemapb.DataType) { +func (c floatVectorBaseChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) { setDefaultIfNotExist(params, common.MetricTypeKey, FloatVectorDefaultMetricType) } diff --git a/pkg/util/indexparamcheck/float_vector_base_checker_test.go b/internal/util/indexparamcheck/float_vector_base_checker_test.go similarity index 90% rename from pkg/util/indexparamcheck/float_vector_base_checker_test.go rename to internal/util/indexparamcheck/float_vector_base_checker_test.go index 7eb0a97d36c6c..7306db92b0df2 100644 --- a/pkg/util/indexparamcheck/float_vector_base_checker_test.go +++ b/internal/util/indexparamcheck/float_vector_base_checker_test.go @@ -67,9 +67,9 @@ func Test_floatVectorBaseChecker_CheckValidDataType(t *testing.T) { }, } - c := newFloatVectorBaseChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("HNSW") for _, test := range cases { - err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) + err := c.CheckValidDataType("HNSW", &schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/hnsw_checker.go b/internal/util/indexparamcheck/hnsw_checker.go similarity index 72% rename from pkg/util/indexparamcheck/hnsw_checker.go rename to internal/util/indexparamcheck/hnsw_checker.go index b5f9e1f2b77e1..78a733a0082ca 100644 --- a/pkg/util/indexparamcheck/hnsw_checker.go +++ b/internal/util/indexparamcheck/hnsw_checker.go @@ -12,7 +12,7 @@ type hnswChecker struct { baseChecker } -func (c hnswChecker) StaticCheck(params map[string]string) error { +func (c hnswChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { if !CheckIntByRange(params, EFConstruction, HNSWMinEfConstruction, HNSWMaxEfConstruction) { return errOutOfRange(EFConstruction, HNSWMinEfConstruction, HNSWMaxEfConstruction) } @@ -25,21 +25,21 @@ func (c hnswChecker) StaticCheck(params map[string]string) error { return nil } -func (c hnswChecker) CheckTrain(params map[string]string) error { - if err := c.StaticCheck(params); err != nil { +func (c hnswChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.StaticCheck(dataType, params); err != nil { return err } - return c.baseChecker.CheckTrain(params) + return c.baseChecker.CheckTrain(0, params) } -func (c hnswChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c hnswChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { if !typeutil.IsVectorType(field.GetDataType()) { return fmt.Errorf("can't build hnsw in not vector type") } return nil } -func (c hnswChecker) SetDefaultMetricTypeIfNotExist(params map[string]string, dType schemapb.DataType) { +func (c hnswChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) { if typeutil.IsDenseFloatVectorType(dType) { setDefaultIfNotExist(params, common.MetricTypeKey, FloatVectorDefaultMetricType) } else if typeutil.IsSparseFloatVectorType(dType) { diff --git a/pkg/util/indexparamcheck/hnsw_checker_test.go b/internal/util/indexparamcheck/hnsw_checker_test.go similarity index 92% rename from pkg/util/indexparamcheck/hnsw_checker_test.go rename to internal/util/indexparamcheck/hnsw_checker_test.go index b9118125407e9..d80c5207adf72 100644 --- a/pkg/util/indexparamcheck/hnsw_checker_test.go +++ b/internal/util/indexparamcheck/hnsw_checker_test.go @@ -92,9 +92,9 @@ func Test_hnswChecker_CheckTrain(t *testing.T) { {p7, false}, } - c := newHnswChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("HNSW") for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -162,9 +162,9 @@ func Test_hnswChecker_CheckValidDataType(t *testing.T) { }, } - c := newHnswChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("HNSW") for _, test := range cases { - err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) + err := c.CheckValidDataType("HNSW", &schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { @@ -200,14 +200,14 @@ func Test_hnswChecker_SetDefaultMetricType(t *testing.T) { }, } - c := newHnswChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("HNSW") for _, test := range cases { p := map[string]string{ DIM: strconv.Itoa(128), HNSWM: strconv.Itoa(16), EFConstruction: strconv.Itoa(200), } - c.SetDefaultMetricTypeIfNotExist(p, test.dType) + c.SetDefaultMetricTypeIfNotExist(test.dType, p) assert.Equal(t, p[Metric], test.metricType) } } diff --git a/internal/util/indexparamcheck/hybrid_checker_test.go b/internal/util/indexparamcheck/hybrid_checker_test.go new file mode 100644 index 0000000000000..fdfa507ab9e02 --- /dev/null +++ b/internal/util/indexparamcheck/hybrid_checker_test.go @@ -0,0 +1,37 @@ +package indexparamcheck + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func Test_HybridIndexChecker(t *testing.T) { + c := newHYBRIDChecker() + + assert.NoError(t, c.CheckTrain(0, map[string]string{"bitmap_cardinality_limit": "100"})) + + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Int8})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Int16})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Int32})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_String})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Bool})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int8})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int16})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_String})) + + assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) + assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Float})) + assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Double})) + assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float})) + assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double})) + assert.Error(t, c.CheckTrain(0, map[string]string{})) + assert.Error(t, c.CheckTrain(0, map[string]string{"bitmap_cardinality_limit": "0"})) + assert.Error(t, c.CheckTrain(0, map[string]string{"bitmap_cardinality_limit": "2000"})) +} diff --git a/pkg/util/indexparamcheck/hybrid_index_checker.go b/internal/util/indexparamcheck/hybrid_index_checker.go similarity index 83% rename from pkg/util/indexparamcheck/hybrid_index_checker.go rename to internal/util/indexparamcheck/hybrid_index_checker.go index 9493bccd91f6d..891c14192793b 100644 --- a/pkg/util/indexparamcheck/hybrid_index_checker.go +++ b/internal/util/indexparamcheck/hybrid_index_checker.go @@ -12,15 +12,15 @@ type HYBRIDChecker struct { scalarIndexChecker } -func (c *HYBRIDChecker) CheckTrain(params map[string]string) error { +func (c *HYBRIDChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { if !CheckIntByRange(params, common.BitmapCardinalityLimitKey, 1, MaxBitmapCardinalityLimit) { return fmt.Errorf("failed to check bitmap cardinality limit, should be larger than 0 and smaller than %d", MaxBitmapCardinalityLimit) } - return c.scalarIndexChecker.CheckTrain(params) + return c.scalarIndexChecker.CheckTrain(0, params) } -func (c *HYBRIDChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c *HYBRIDChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { mainType := field.GetDataType() elemType := field.GetElementType() if !typeutil.IsBoolType(mainType) && !typeutil.IsIntegerType(mainType) && diff --git a/pkg/util/indexparamcheck/index_checker.go b/internal/util/indexparamcheck/index_checker.go similarity index 77% rename from pkg/util/indexparamcheck/index_checker.go rename to internal/util/indexparamcheck/index_checker.go index 1c11280898394..610ddffc2cd9b 100644 --- a/pkg/util/indexparamcheck/index_checker.go +++ b/internal/util/indexparamcheck/index_checker.go @@ -21,8 +21,8 @@ import ( ) type IndexChecker interface { - CheckTrain(map[string]string) error - CheckValidDataType(field *schemapb.FieldSchema) error - SetDefaultMetricTypeIfNotExist(map[string]string, schemapb.DataType) - StaticCheck(map[string]string) error + CheckTrain(schemapb.DataType, map[string]string) error + CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error + SetDefaultMetricTypeIfNotExist(schemapb.DataType, map[string]string) + StaticCheck(schemapb.DataType, map[string]string) error } diff --git a/pkg/util/indexparamcheck/index_checker_test.go b/internal/util/indexparamcheck/index_checker_test.go similarity index 100% rename from pkg/util/indexparamcheck/index_checker_test.go rename to internal/util/indexparamcheck/index_checker_test.go diff --git a/pkg/util/indexparamcheck/index_type.go b/internal/util/indexparamcheck/index_type.go similarity index 79% rename from pkg/util/indexparamcheck/index_type.go rename to internal/util/indexparamcheck/index_type.go index 977d7d7d5d0fd..c2db0e5aefd6c 100644 --- a/pkg/util/indexparamcheck/index_type.go +++ b/internal/util/indexparamcheck/index_type.go @@ -23,23 +23,7 @@ type IndexType = string // IndexType definitions const ( - // vector index - IndexGpuBF IndexType = "GPU_BRUTE_FORCE" - IndexRaftIvfFlat IndexType = "GPU_IVF_FLAT" - IndexRaftIvfPQ IndexType = "GPU_IVF_PQ" - IndexRaftCagra IndexType = "GPU_CAGRA" - IndexRaftBruteForce IndexType = "GPU_BRUTE_FORCE" - IndexFaissIDMap IndexType = "FLAT" // no index is built. - IndexFaissIvfFlat IndexType = "IVF_FLAT" - IndexFaissIvfPQ IndexType = "IVF_PQ" - IndexScaNN IndexType = "SCANN" - IndexFaissIvfSQ8 IndexType = "IVF_SQ8" - IndexFaissBinIDMap IndexType = "BIN_FLAT" - IndexFaissBinIvfFlat IndexType = "BIN_IVF_FLAT" - IndexHNSW IndexType = "HNSW" - IndexDISKANN IndexType = "DISKANN" - IndexSparseInverted IndexType = "SPARSE_INVERTED_INDEX" - IndexSparseWand IndexType = "SPARSE_WAND" + IndexVector IndexType = "VECINDEX" // scalar index IndexSTLSORT IndexType = "STL_SORT" diff --git a/pkg/util/indexparamcheck/index_type_test.go b/internal/util/indexparamcheck/index_type_test.go similarity index 94% rename from pkg/util/indexparamcheck/index_type_test.go rename to internal/util/indexparamcheck/index_type_test.go index 29d77eace5488..d350ee9121eca 100644 --- a/pkg/util/indexparamcheck/index_type_test.go +++ b/internal/util/indexparamcheck/index_type_test.go @@ -34,7 +34,7 @@ func TestIsScalarMmapIndex(t *testing.T) { func TestIsVectorMmapIndex(t *testing.T) { t.Run("vector index", func(t *testing.T) { - assert.True(t, IsVectorMmapIndex(IndexFaissIDMap)) + assert.True(t, IsVectorMmapIndex("FLAT")) assert.False(t, IsVectorMmapIndex(IndexINVERTED)) }) } @@ -60,7 +60,7 @@ func TestValidateMmapTypeParams(t *testing.T) { }) t.Run("invalid mmap enable type", func(t *testing.T) { - err := ValidateMmapIndexParams(IndexGpuBF, map[string]string{ + err := ValidateMmapIndexParams("GPU_BRUTE_FORCE", map[string]string{ common.MmapEnabledKey: "true", }) assert.Error(t, err) diff --git a/pkg/util/indexparamcheck/inverted_checker.go b/internal/util/indexparamcheck/inverted_checker.go similarity index 70% rename from pkg/util/indexparamcheck/inverted_checker.go rename to internal/util/indexparamcheck/inverted_checker.go index 8d6893c10085a..85d904825f3e8 100644 --- a/pkg/util/indexparamcheck/inverted_checker.go +++ b/internal/util/indexparamcheck/inverted_checker.go @@ -12,11 +12,11 @@ type INVERTEDChecker struct { scalarIndexChecker } -func (c *INVERTEDChecker) CheckTrain(params map[string]string) error { - return c.scalarIndexChecker.CheckTrain(params) +func (c *INVERTEDChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + return c.scalarIndexChecker.CheckTrain(0, params) } -func (c *INVERTEDChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c *INVERTEDChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { dType := field.GetDataType() if !typeutil.IsBoolType(dType) && !typeutil.IsArithmetic(dType) && !typeutil.IsStringType(dType) && !typeutil.IsArrayType(dType) { diff --git a/internal/util/indexparamcheck/inverted_checker_test.go b/internal/util/indexparamcheck/inverted_checker_test.go new file mode 100644 index 0000000000000..52659f129457f --- /dev/null +++ b/internal/util/indexparamcheck/inverted_checker_test.go @@ -0,0 +1,25 @@ +package indexparamcheck + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func Test_INVERTEDIndexChecker(t *testing.T) { + c := newINVERTEDChecker() + + assert.NoError(t, c.CheckTrain(0, map[string]string{})) + + assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar})) + assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_String})) + assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) + assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) + assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_Float})) + assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_Array})) + + assert.Error(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) + assert.Error(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector})) +} diff --git a/internal/util/indexparamcheck/ivf_base_checker.go b/internal/util/indexparamcheck/ivf_base_checker.go new file mode 100644 index 0000000000000..93b0255c35dd5 --- /dev/null +++ b/internal/util/indexparamcheck/ivf_base_checker.go @@ -0,0 +1,28 @@ +package indexparamcheck + +import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + +type ivfBaseChecker struct { + floatVectorBaseChecker +} + +func (c ivfBaseChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { + if !CheckIntByRange(params, NLIST, MinNList, MaxNList) { + return errOutOfRange(NLIST, MinNList, MaxNList) + } + + // skip check number of rows + + return c.floatVectorBaseChecker.staticCheck(params) +} + +func (c ivfBaseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.StaticCheck(dataType, params); err != nil { + return err + } + return c.floatVectorBaseChecker.CheckTrain(0, params) +} + +func newIVFBaseChecker() IndexChecker { + return &ivfBaseChecker{} +} diff --git a/pkg/util/indexparamcheck/ivf_base_checker_test.go b/internal/util/indexparamcheck/ivf_base_checker_test.go similarity index 91% rename from pkg/util/indexparamcheck/ivf_base_checker_test.go rename to internal/util/indexparamcheck/ivf_base_checker_test.go index 4a379038dde33..7486ce2ea245e 100644 --- a/pkg/util/indexparamcheck/ivf_base_checker_test.go +++ b/internal/util/indexparamcheck/ivf_base_checker_test.go @@ -70,9 +70,9 @@ func Test_ivfBaseChecker_CheckTrain(t *testing.T) { {p7, false}, } - c := newIVFBaseChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_FLAT") for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -140,9 +140,9 @@ func Test_ivfBaseChecker_CheckValidDataType(t *testing.T) { }, } - c := newIVFBaseChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_FLAT") for _, test := range cases { - err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) + err := c.CheckValidDataType("IVF_FLAT", &schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/ivf_pq_checker.go b/internal/util/indexparamcheck/ivf_pq_checker.go similarity index 87% rename from pkg/util/indexparamcheck/ivf_pq_checker.go rename to internal/util/indexparamcheck/ivf_pq_checker.go index 4c35f193c4689..9a11d4559356c 100644 --- a/pkg/util/indexparamcheck/ivf_pq_checker.go +++ b/internal/util/indexparamcheck/ivf_pq_checker.go @@ -3,6 +3,8 @@ package indexparamcheck import ( "fmt" "strconv" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) // ivfPQChecker checks if a IVF_PQ index can be built. @@ -11,8 +13,8 @@ type ivfPQChecker struct { } // CheckTrain checks if ivf-pq index can be built with the specific index parameters. -func (c *ivfPQChecker) CheckTrain(params map[string]string) error { - if err := c.ivfBaseChecker.CheckTrain(params); err != nil { +func (c *ivfPQChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.ivfBaseChecker.CheckTrain(0, params); err != nil { return err } diff --git a/pkg/util/indexparamcheck/ivf_pq_checker_test.go b/internal/util/indexparamcheck/ivf_pq_checker_test.go similarity index 95% rename from pkg/util/indexparamcheck/ivf_pq_checker_test.go rename to internal/util/indexparamcheck/ivf_pq_checker_test.go index 4a22d45542b20..c4e142342d255 100644 --- a/pkg/util/indexparamcheck/ivf_pq_checker_test.go +++ b/internal/util/indexparamcheck/ivf_pq_checker_test.go @@ -141,9 +141,9 @@ func Test_ivfPQChecker_CheckTrain(t *testing.T) { {p7, false}, } - c := newIVFPQChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_PQ") for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -211,9 +211,9 @@ func Test_ivfPQChecker_CheckValidDataType(t *testing.T) { }, } - c := newIVFPQChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_PQ") for _, test := range cases { - err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) + err := c.CheckValidDataType("IVF_PQ", &schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/ivf_sq_checker.go b/internal/util/indexparamcheck/ivf_sq_checker.go similarity index 78% rename from pkg/util/indexparamcheck/ivf_sq_checker.go rename to internal/util/indexparamcheck/ivf_sq_checker.go index fc1a2204f5562..79a58ca77040e 100644 --- a/pkg/util/indexparamcheck/ivf_sq_checker.go +++ b/internal/util/indexparamcheck/ivf_sq_checker.go @@ -2,6 +2,8 @@ package indexparamcheck import ( "fmt" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) // ivfSQChecker checks if a IVF_SQ index can be built. @@ -22,11 +24,11 @@ func (c *ivfSQChecker) checkNBits(params map[string]string) error { } // CheckTrain returns true if the index can be built with the specific index parameters. -func (c *ivfSQChecker) CheckTrain(params map[string]string) error { +func (c *ivfSQChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { if err := c.checkNBits(params); err != nil { return err } - return c.ivfBaseChecker.CheckTrain(params) + return c.ivfBaseChecker.CheckTrain(0, params) } func newIVFSQChecker() IndexChecker { diff --git a/pkg/util/indexparamcheck/ivf_sq_checker_test.go b/internal/util/indexparamcheck/ivf_sq_checker_test.go similarity index 93% rename from pkg/util/indexparamcheck/ivf_sq_checker_test.go rename to internal/util/indexparamcheck/ivf_sq_checker_test.go index 9478623fe89e3..714f0abe288db 100644 --- a/pkg/util/indexparamcheck/ivf_sq_checker_test.go +++ b/internal/util/indexparamcheck/ivf_sq_checker_test.go @@ -90,9 +90,9 @@ func Test_ivfSQChecker_CheckTrain(t *testing.T) { {p7, false}, } - c := newIVFSQChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_SQ") for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -160,9 +160,9 @@ func Test_ivfSQChecker_CheckValidDataType(t *testing.T) { }, } - c := newIVFSQChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_SQ") for _, test := range cases { - err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) + err := c.CheckValidDataType("IVF_SQ8", &schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/raft_brute_force_checker.go b/internal/util/indexparamcheck/raft_brute_force_checker.go similarity index 62% rename from pkg/util/indexparamcheck/raft_brute_force_checker.go rename to internal/util/indexparamcheck/raft_brute_force_checker.go index 38872da7ec773..13d82a8b67e6c 100644 --- a/pkg/util/indexparamcheck/raft_brute_force_checker.go +++ b/internal/util/indexparamcheck/raft_brute_force_checker.go @@ -1,14 +1,18 @@ package indexparamcheck -import "fmt" +import ( + "fmt" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) type raftBruteForceChecker struct { floatVectorBaseChecker } // raftBrustForceChecker checks if a Brute_Force index can be built. -func (c raftBruteForceChecker) CheckTrain(params map[string]string) error { - if err := c.floatVectorBaseChecker.CheckTrain(params); err != nil { +func (c raftBruteForceChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.floatVectorBaseChecker.CheckTrain(0, params); err != nil { return err } if !CheckStrByValues(params, Metric, RaftMetrics) { diff --git a/pkg/util/indexparamcheck/raft_brute_force_checker_test.go b/internal/util/indexparamcheck/raft_brute_force_checker_test.go similarity index 91% rename from pkg/util/indexparamcheck/raft_brute_force_checker_test.go rename to internal/util/indexparamcheck/raft_brute_force_checker_test.go index ce037bc4dcb9c..b7d2dc830bc5b 100644 --- a/pkg/util/indexparamcheck/raft_brute_force_checker_test.go +++ b/internal/util/indexparamcheck/raft_brute_force_checker_test.go @@ -52,9 +52,9 @@ func Test_raftbfChecker_CheckTrain(t *testing.T) { {p7, false}, } - c := newRaftBruteForceChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_BRUTE_FORCE") for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/raft_ivf_flat_checker.go b/internal/util/indexparamcheck/raft_ivf_flat_checker.go similarity index 75% rename from pkg/util/indexparamcheck/raft_ivf_flat_checker.go rename to internal/util/indexparamcheck/raft_ivf_flat_checker.go index 9f11803e9b17d..e429ec7405659 100644 --- a/pkg/util/indexparamcheck/raft_ivf_flat_checker.go +++ b/internal/util/indexparamcheck/raft_ivf_flat_checker.go @@ -1,6 +1,10 @@ package indexparamcheck -import "fmt" +import ( + "fmt" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) // raftIVFChecker checks if a RAFT_IVF_Flat index can be built. type raftIVFFlatChecker struct { @@ -8,8 +12,8 @@ type raftIVFFlatChecker struct { } // CheckTrain checks if ivf-flat index can be built with the specific index parameters. -func (c *raftIVFFlatChecker) CheckTrain(params map[string]string) error { - if err := c.ivfBaseChecker.CheckTrain(params); err != nil { +func (c *raftIVFFlatChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.ivfBaseChecker.CheckTrain(0, params); err != nil { return err } if !CheckStrByValues(params, Metric, RaftMetrics) { diff --git a/pkg/util/indexparamcheck/raft_ivf_flat_checker_test.go b/internal/util/indexparamcheck/raft_ivf_flat_checker_test.go similarity index 92% rename from pkg/util/indexparamcheck/raft_ivf_flat_checker_test.go rename to internal/util/indexparamcheck/raft_ivf_flat_checker_test.go index 3d64f830392f4..b9cfcf0f8b72f 100644 --- a/pkg/util/indexparamcheck/raft_ivf_flat_checker_test.go +++ b/internal/util/indexparamcheck/raft_ivf_flat_checker_test.go @@ -84,9 +84,9 @@ func Test_raftIvfFlatChecker_CheckTrain(t *testing.T) { {p9, false}, } - c := newRaftIVFFlatChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_IVF_FLAT") for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -154,9 +154,9 @@ func Test_raftIvfFlatChecker_CheckValidDataType(t *testing.T) { }, } - c := newRaftIVFFlatChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_IVF_FLAT") for _, test := range cases { - err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) + err := c.CheckValidDataType("GPU_IVF_FLAT", &schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/raft_ivf_pq_checker.go b/internal/util/indexparamcheck/raft_ivf_pq_checker.go similarity index 88% rename from pkg/util/indexparamcheck/raft_ivf_pq_checker.go rename to internal/util/indexparamcheck/raft_ivf_pq_checker.go index 2457619118070..fe1e4a463efd0 100644 --- a/pkg/util/indexparamcheck/raft_ivf_pq_checker.go +++ b/internal/util/indexparamcheck/raft_ivf_pq_checker.go @@ -3,6 +3,8 @@ package indexparamcheck import ( "fmt" "strconv" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) // raftIVFPQChecker checks if a RAFT_IVF_PQ index can be built. @@ -11,8 +13,8 @@ type raftIVFPQChecker struct { } // CheckTrain checks if ivf-pq index can be built with the specific index parameters. -func (c *raftIVFPQChecker) CheckTrain(params map[string]string) error { - if err := c.ivfBaseChecker.CheckTrain(params); err != nil { +func (c *raftIVFPQChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.ivfBaseChecker.CheckTrain(0, params); err != nil { return err } if !CheckStrByValues(params, Metric, RaftMetrics) { diff --git a/pkg/util/indexparamcheck/raft_ivf_pq_checker_test.go b/internal/util/indexparamcheck/raft_ivf_pq_checker_test.go similarity index 94% rename from pkg/util/indexparamcheck/raft_ivf_pq_checker_test.go rename to internal/util/indexparamcheck/raft_ivf_pq_checker_test.go index 8c882900e9ef1..d9cd4d9952cd2 100644 --- a/pkg/util/indexparamcheck/raft_ivf_pq_checker_test.go +++ b/internal/util/indexparamcheck/raft_ivf_pq_checker_test.go @@ -144,9 +144,9 @@ func Test_raftIVFPQChecker_CheckTrain(t *testing.T) { {p9, false}, } - c := newRaftIVFPQChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_IVF_PQ") for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -214,9 +214,9 @@ func Test_raftIVFPQChecker_CheckValidDataType(t *testing.T) { }, } - c := newRaftIVFPQChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_IVF_PQ") for _, test := range cases { - err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) + err := c.CheckValidDataType("GPU_IVF_PQ", &schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/internal/util/indexparamcheck/scalar_index_checker.go b/internal/util/indexparamcheck/scalar_index_checker.go new file mode 100644 index 0000000000000..a1272ae3880bc --- /dev/null +++ b/internal/util/indexparamcheck/scalar_index_checker.go @@ -0,0 +1,11 @@ +package indexparamcheck + +import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + +type scalarIndexChecker struct { + baseChecker +} + +func (c scalarIndexChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + return nil +} diff --git a/pkg/util/indexparamcheck/scalar_index_checker_test.go b/internal/util/indexparamcheck/scalar_index_checker_test.go similarity index 70% rename from pkg/util/indexparamcheck/scalar_index_checker_test.go rename to internal/util/indexparamcheck/scalar_index_checker_test.go index eb3ae669e2891..423ddb1f66984 100644 --- a/pkg/util/indexparamcheck/scalar_index_checker_test.go +++ b/internal/util/indexparamcheck/scalar_index_checker_test.go @@ -8,5 +8,5 @@ import ( func TestCheckIndexValid(t *testing.T) { scalarIndexChecker := &scalarIndexChecker{} - assert.NoError(t, scalarIndexChecker.CheckTrain(map[string]string{})) + assert.NoError(t, scalarIndexChecker.CheckTrain(0, map[string]string{})) } diff --git a/pkg/util/indexparamcheck/scann_checker.go b/internal/util/indexparamcheck/scann_checker.go similarity index 78% rename from pkg/util/indexparamcheck/scann_checker.go rename to internal/util/indexparamcheck/scann_checker.go index eecf2ded64bbf..ffe4c25cc59f8 100644 --- a/pkg/util/indexparamcheck/scann_checker.go +++ b/internal/util/indexparamcheck/scann_checker.go @@ -3,6 +3,8 @@ package indexparamcheck import ( "fmt" "strconv" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) // scaNNChecker checks if a SCANN index can be built. @@ -11,8 +13,8 @@ type scaNNChecker struct { } // CheckTrain checks if SCANN index can be built with the specific index parameters. -func (c *scaNNChecker) CheckTrain(params map[string]string) error { - if err := c.ivfBaseChecker.CheckTrain(params); err != nil { +func (c *scaNNChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.ivfBaseChecker.CheckTrain(0, params); err != nil { return err } diff --git a/pkg/util/indexparamcheck/scann_checker_test.go b/internal/util/indexparamcheck/scann_checker_test.go similarity index 96% rename from pkg/util/indexparamcheck/scann_checker_test.go rename to internal/util/indexparamcheck/scann_checker_test.go index 4f7014c6fde53..13a54597b0716 100644 --- a/pkg/util/indexparamcheck/scann_checker_test.go +++ b/internal/util/indexparamcheck/scann_checker_test.go @@ -89,7 +89,7 @@ func Test_scaNNChecker_CheckTrain(t *testing.T) { c := newScaNNChecker() for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(0, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -159,7 +159,7 @@ func Test_scaNNChecker_CheckValidDataType(t *testing.T) { c := newScaNNChecker() for _, test := range cases { - err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) + err := c.CheckValidDataType("SCANN", &schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/sparse_float_vector_base_checker.go b/internal/util/indexparamcheck/sparse_float_vector_base_checker.go similarity index 75% rename from pkg/util/indexparamcheck/sparse_float_vector_base_checker.go rename to internal/util/indexparamcheck/sparse_float_vector_base_checker.go index 218d2d3e03a3e..edb91157fa1e2 100644 --- a/pkg/util/indexparamcheck/sparse_float_vector_base_checker.go +++ b/internal/util/indexparamcheck/sparse_float_vector_base_checker.go @@ -12,7 +12,7 @@ import ( // sparse vector don't check for dim, but baseChecker does, thus not including baseChecker type sparseFloatVectorBaseChecker struct{} -func (c sparseFloatVectorBaseChecker) StaticCheck(params map[string]string) error { +func (c sparseFloatVectorBaseChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { if !CheckStrByValues(params, Metric, SparseMetrics) { return fmt.Errorf("metric type not found or not supported, supported: %v", SparseMetrics) } @@ -20,7 +20,7 @@ func (c sparseFloatVectorBaseChecker) StaticCheck(params map[string]string) erro return nil } -func (c sparseFloatVectorBaseChecker) CheckTrain(params map[string]string) error { +func (c sparseFloatVectorBaseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { dropRatioBuildStr, exist := params[SparseDropRatioBuild] if exist { dropRatioBuild, err := strconv.ParseFloat(dropRatioBuildStr, 64) @@ -32,14 +32,14 @@ func (c sparseFloatVectorBaseChecker) CheckTrain(params map[string]string) error return nil } -func (c sparseFloatVectorBaseChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c sparseFloatVectorBaseChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { if !typeutil.IsSparseFloatVectorType(field.GetDataType()) { return fmt.Errorf("only sparse float vector is supported for the specified index tpye") } return nil } -func (c sparseFloatVectorBaseChecker) SetDefaultMetricTypeIfNotExist(params map[string]string, dType schemapb.DataType) { +func (c sparseFloatVectorBaseChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) { setDefaultIfNotExist(params, common.MetricTypeKey, SparseFloatVectorDefaultMetricType) } diff --git a/pkg/util/indexparamcheck/sparse_inverted_index_checker.go b/internal/util/indexparamcheck/sparse_inverted_index_checker.go similarity index 100% rename from pkg/util/indexparamcheck/sparse_inverted_index_checker.go rename to internal/util/indexparamcheck/sparse_inverted_index_checker.go diff --git a/pkg/util/indexparamcheck/stl_sort_checker.go b/internal/util/indexparamcheck/stl_sort_checker.go similarity index 65% rename from pkg/util/indexparamcheck/stl_sort_checker.go rename to internal/util/indexparamcheck/stl_sort_checker.go index 4b3441ad6dfcf..6e3339c0ef5bd 100644 --- a/pkg/util/indexparamcheck/stl_sort_checker.go +++ b/internal/util/indexparamcheck/stl_sort_checker.go @@ -12,11 +12,11 @@ type STLSORTChecker struct { scalarIndexChecker } -func (c *STLSORTChecker) CheckTrain(params map[string]string) error { - return c.scalarIndexChecker.CheckTrain(params) +func (c *STLSORTChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + return c.scalarIndexChecker.CheckTrain(0, params) } -func (c *STLSORTChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c *STLSORTChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { if !typeutil.IsArithmetic(field.GetDataType()) { return fmt.Errorf("STL_SORT are only supported on numeric field") } diff --git a/internal/util/indexparamcheck/stl_sort_checker_test.go b/internal/util/indexparamcheck/stl_sort_checker_test.go new file mode 100644 index 0000000000000..3e42c2810db65 --- /dev/null +++ b/internal/util/indexparamcheck/stl_sort_checker_test.go @@ -0,0 +1,22 @@ +package indexparamcheck + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func Test_STLSORTIndexChecker(t *testing.T) { + c := newSTLSORTChecker() + + assert.NoError(t, c.CheckTrain(0, map[string]string{})) + + assert.NoError(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) + assert.NoError(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_Float})) + + assert.Error(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar})) + assert.Error(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) + assert.Error(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) +} diff --git a/pkg/util/indexparamcheck/trie_checker.go b/internal/util/indexparamcheck/trie_checker.go similarity index 64% rename from pkg/util/indexparamcheck/trie_checker.go rename to internal/util/indexparamcheck/trie_checker.go index 002014e42022c..2af351860f59f 100644 --- a/pkg/util/indexparamcheck/trie_checker.go +++ b/internal/util/indexparamcheck/trie_checker.go @@ -12,11 +12,11 @@ type TRIEChecker struct { scalarIndexChecker } -func (c *TRIEChecker) CheckTrain(params map[string]string) error { - return c.scalarIndexChecker.CheckTrain(params) +func (c *TRIEChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + return c.scalarIndexChecker.CheckTrain(0, params) } -func (c *TRIEChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c *TRIEChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { if !typeutil.IsStringType(field.GetDataType()) { return fmt.Errorf("TRIE are only supported on varchar field") } diff --git a/internal/util/indexparamcheck/trie_checker_test.go b/internal/util/indexparamcheck/trie_checker_test.go new file mode 100644 index 0000000000000..81a98b664832b --- /dev/null +++ b/internal/util/indexparamcheck/trie_checker_test.go @@ -0,0 +1,23 @@ +package indexparamcheck + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func Test_TrieIndexChecker(t *testing.T) { + c := newTRIEChecker() + + assert.NoError(t, c.CheckTrain(0, map[string]string{})) + + assert.NoError(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar})) + assert.NoError(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_String})) + + assert.Error(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) + assert.Error(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) + assert.Error(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_Float})) + assert.Error(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) +} diff --git a/pkg/util/indexparamcheck/utils.go b/internal/util/indexparamcheck/utils.go similarity index 100% rename from pkg/util/indexparamcheck/utils.go rename to internal/util/indexparamcheck/utils.go diff --git a/pkg/util/indexparamcheck/utils_test.go b/internal/util/indexparamcheck/utils_test.go similarity index 100% rename from pkg/util/indexparamcheck/utils_test.go rename to internal/util/indexparamcheck/utils_test.go diff --git a/internal/util/indexparamcheck/vector_index_checker.go b/internal/util/indexparamcheck/vector_index_checker.go new file mode 100644 index 0000000000000..006067c6827b4 --- /dev/null +++ b/internal/util/indexparamcheck/vector_index_checker.go @@ -0,0 +1,98 @@ +package indexparamcheck + +/* +#cgo pkg-config: milvus_core + +#include // free +#include "segcore/vector_index_c.h" +*/ +import "C" + +import ( + "fmt" + "unsafe" + + "google.golang.org/protobuf/proto" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/proto/indexcgopb" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type vecIndexChecker struct { + baseChecker +} + +// HandleCStatus deals with the error returned from CGO +func HandleCStatus(status *C.CStatus) error { + if status.error_code == 0 { + return nil + } + errorCode := status.error_code + errorMsg := C.GoString(status.error_msg) + defer C.free(unsafe.Pointer(status.error_msg)) + + return fmt.Errorf("code %d, msg %s", errorCode, errorMsg) +} + +func (c vecIndexChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { + indexType, exist := params[common.IndexTypeKey] + + if !exist { + return fmt.Errorf("no indexType is specified") + } + + if !GetVecIndexMgrInstance().IsVecIndex(indexType) { + return fmt.Errorf("indexType %s is not supported", indexType) + } + + protoIndexParams := &indexcgopb.IndexParams{ + Params: make([]*commonpb.KeyValuePair, 0), + } + indexParamsBlob, err := proto.Marshal(protoIndexParams) + if err != nil { + return fmt.Errorf("failed to marshal index params: %s", err) + } + + var status C.CStatus + + cIndexType := C.CString(indexType) + cDataType := uint32(dataType) + status = C.ValidateIndexParams(cIndexType, cDataType, (*C.uint8_t)(unsafe.Pointer(&indexParamsBlob[0])), (C.uint64_t)(len(indexParamsBlob))) + C.free(unsafe.Pointer(cIndexType)) + + return HandleCStatus(&status) +} + +func (c vecIndexChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.StaticCheck(dataType, params); err != nil { + return err + } + return c.baseChecker.CheckTrain(0, params) +} + +func (c vecIndexChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { + if !typeutil.IsVectorType(field.GetDataType()) { + return fmt.Errorf("can not create vector index for an invalid datatype") + } + if GetVecIndexMgrInstance().IsDataTypeSupport(indexType, field.GetDataType()) { + return fmt.Errorf("vector index %s do not support data type: %s", indexType, schemapb.DataType_name[int32(field.GetDataType())]) + } + return nil +} + +func (c vecIndexChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) { + if typeutil.IsDenseFloatVectorType(dType) { + setDefaultIfNotExist(params, common.MetricTypeKey, FloatVectorDefaultMetricType) + } else if typeutil.IsSparseFloatVectorType(dType) { + setDefaultIfNotExist(params, common.MetricTypeKey, SparseFloatVectorDefaultMetricType) + } else if typeutil.IsBinaryVectorType(dType) { + setDefaultIfNotExist(params, common.MetricTypeKey, BinaryVectorDefaultMetricType) + } +} + +func newVecIndexChecker() IndexChecker { + return &vecIndexChecker{} +} diff --git a/pkg/util/indexparamcheck/vector_index_mgr.go b/internal/util/indexparamcheck/vector_index_mgr.go similarity index 69% rename from pkg/util/indexparamcheck/vector_index_mgr.go rename to internal/util/indexparamcheck/vector_index_mgr.go index c6d9e1bcc22d9..b268113f62f73 100644 --- a/pkg/util/indexparamcheck/vector_index_mgr.go +++ b/internal/util/indexparamcheck/vector_index_mgr.go @@ -17,7 +17,7 @@ package indexparamcheck /* -#cgo pkg-config: milvus_segcore +#cgo pkg-config: milvus_core #include // free #include "segcore/vector_index_c.h" @@ -27,12 +27,20 @@ import "C" import ( "bytes" "fmt" - "github.com/milvus-io/milvus/pkg/log" "sync" "unsafe" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/log" ) const ( + BinaryFlag uint64 = 1 << 0 + Float32Flag uint64 = 1 << 1 + Float16Flag uint64 = 1 << 2 + BFloat16Flag uint64 = 1 << 3 + SparseFloat32Flag uint64 = 1 << 4 + // BFFlag This flag indicates that there is no need to create any index structure BFFlag uint64 = 1 << 16 // KNNFlag This flag indicates that the index defaults to KNN search, meaning the recall rate is 100% @@ -49,6 +57,14 @@ const ( type VecIndexMgr interface { init() error + + IsBinarySupport(indexType IndexType) bool + IsFlat32Support(indexType IndexType) bool + IsFlat16Support(indexType IndexType) bool + IsBFlat16Support(indexType IndexType) bool + IsSparseFloat32Support(indexType IndexType) bool + IsDataTypeSupport(indexType IndexType, dataType schemapb.DataType) bool + IsFlatVecIndex(indexType IndexType) bool IsBruteForce(indexType IndexType) bool IsVecIndex(indexType IndexType) bool @@ -98,6 +114,62 @@ func (mgr *VecIndexMgrImpl) init() error { return nil } +func (mgr *VecIndexMgrImpl) IsBinarySupport(indexType IndexType) bool { + feature, ok := mgr.features[indexType] + if !ok { + return false + } + return (feature & BinaryFlag) == BinaryFlag +} + +func (mgr *VecIndexMgrImpl) IsFlat32Support(indexType IndexType) bool { + feature, ok := mgr.features[indexType] + if !ok { + return false + } + return (feature & Float32Flag) == Float32Flag +} + +func (mgr *VecIndexMgrImpl) IsFlat16Support(indexType IndexType) bool { + feature, ok := mgr.features[indexType] + if !ok { + return false + } + return (feature & Float16Flag) == Float16Flag +} + +func (mgr *VecIndexMgrImpl) IsBFlat16Support(indexType IndexType) bool { + feature, ok := mgr.features[indexType] + if !ok { + return false + } + return (feature & BFloat16Flag) == BFloat16Flag +} + +func (mgr *VecIndexMgrImpl) IsSparseFloat32Support(indexType IndexType) bool { + feature, ok := mgr.features[indexType] + if !ok { + return false + } + return (feature & SparseFloat32Flag) == SparseFloat32Flag +} + +func (mgr *VecIndexMgrImpl) IsDataTypeSupport(indexType IndexType, dataType schemapb.DataType) bool { + if dataType == schemapb.DataType_BinaryVector { + return mgr.IsBinarySupport(indexType) + } else if dataType == schemapb.DataType_FloatVector { + return mgr.IsFlat32Support(indexType) + } else if dataType == schemapb.DataType_BFloat16Vector { + return mgr.IsBFlat16Support(indexType) + } else if dataType == schemapb.DataType_Float16Vector { + return mgr.IsFlat16Support(indexType) + } else if dataType == schemapb.DataType_SparseFloatVector { + return mgr.IsSparseFloat32Support(indexType) + } else { + return false + } +} + func (mgr *VecIndexMgrImpl) IsFlatVecIndex(indexType IndexType) bool { feature, ok := mgr.features[indexType] if !ok { diff --git a/pkg/util/indexparamcheck/vector_index_mgr_test.go b/internal/util/indexparamcheck/vector_index_mgr_test.go similarity index 100% rename from pkg/util/indexparamcheck/vector_index_mgr_test.go rename to internal/util/indexparamcheck/vector_index_mgr_test.go diff --git a/pkg/config/config.go b/pkg/config/config.go index fc93c086f74d4..1307a191874a5 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -20,7 +20,10 @@ import ( "strings" "github.com/cockroachdb/errors" + "github.com/spf13/cast" + "go.uber.org/zap" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -30,6 +33,10 @@ var ( ErrKeyNotFound = errors.New("key not found") ) +const ( + NotFormatPrefix = "knowhere." +) + func Init(opts ...Option) (*Manager, error) { o := &Options{} for _, opt := range opts { @@ -56,6 +63,9 @@ func Init(opts ...Option) (*Manager, error) { var formattedKeys = typeutil.NewConcurrentMap[string, string]() func formatKey(key string) string { + if strings.HasPrefix(key, NotFormatPrefix) { + return key + } cached, ok := formattedKeys.Get(key) if ok { return cached @@ -64,3 +74,41 @@ func formatKey(key string) string { formattedKeys.Insert(key, result) return result } + +func parseConfig(prefix string, m map[string]interface{}, result map[string]string) { + for k, v := range m { + fullKey := k + if prefix != "" { + fullKey = prefix + "." + k + } + + switch val := v.(type) { + case map[string]interface{}: + parseConfig(fullKey, val, result) + case []interface{}: + str := "" + for i, item := range val { + itemStr, err := cast.ToStringE(item) + if err != nil { + log.Warn("cast to string failed", zap.Any("item", item)) + continue + } + if i == 0 { + str = itemStr + } else { + str = str + "," + itemStr + } + } + result[fullKey] = str + result[formatKey(fullKey)] = str + default: + str, err := cast.ToStringE(val) + if err != nil { + log.Warn("cast to string failed", zap.Any("val", val)) + continue + } + result[strings.ToLower(fullKey)] = str + result[formatKey(fullKey)] = str + } + } +} diff --git a/pkg/config/file_source.go b/pkg/config/file_source.go index e8402efe6b6ad..386ee20a434a4 100644 --- a/pkg/config/file_source.go +++ b/pkg/config/file_source.go @@ -22,9 +22,8 @@ import ( "github.com/cockroachdb/errors" "github.com/samber/lo" - "github.com/spf13/cast" - "github.com/spf13/viper" "go.uber.org/zap" + "gopkg.in/yaml.v3" "github.com/milvus-io/milvus/pkg/log" ) @@ -115,7 +114,6 @@ func (fs *FileSource) UpdateOptions(opts Options) { } func (fs *FileSource) loadFromFile() error { - yamlReader := viper.New() newConfig := make(map[string]string) var configFiles []string @@ -128,38 +126,19 @@ func (fs *FileSource) loadFromFile() error { continue } - yamlReader.SetConfigFile(configFile) - if err := yamlReader.ReadInConfig(); err != nil { + data, err := os.ReadFile(configFile) + if err != nil { return errors.Wrap(err, "Read config failed: "+configFile) } - for _, key := range yamlReader.AllKeys() { - val := yamlReader.Get(key) - str, err := cast.ToStringE(val) - if err != nil { - switch val := val.(type) { - case []any: - str = str[:0] - for _, v := range val { - ss, err := cast.ToStringE(v) - if err != nil { - log.Warn("cast to string failed", zap.Any("value", v)) - } - if str == "" { - str = ss - } else { - str = str + "," + ss - } - } - - default: - log.Warn("val is not a slice", zap.Any("value", val)) - continue - } - } - newConfig[key] = str - newConfig[formatKey(key)] = str + var config map[string]interface{} + + err = yaml.Unmarshal(data, &config) + if err != nil { + return errors.Wrap(err, "Unmarshal config failed: "+configFile) } + + parseConfig("", config, newConfig) } return fs.update(newConfig) diff --git a/pkg/util/indexparamcheck/bin_flat_checker.go b/pkg/util/indexparamcheck/bin_flat_checker.go deleted file mode 100644 index 2e0b813c38402..0000000000000 --- a/pkg/util/indexparamcheck/bin_flat_checker.go +++ /dev/null @@ -1,17 +0,0 @@ -package indexparamcheck - -type binFlatChecker struct { - binaryVectorBaseChecker -} - -func (c binFlatChecker) CheckTrain(params map[string]string) error { - return c.binaryVectorBaseChecker.CheckTrain(params) -} - -func (c binFlatChecker) StaticCheck(params map[string]string) error { - return c.staticCheck(params) -} - -func newBinFlatChecker() IndexChecker { - return &binFlatChecker{} -} diff --git a/pkg/util/indexparamcheck/bitmap_checker_test.go b/pkg/util/indexparamcheck/bitmap_checker_test.go deleted file mode 100644 index 95d74f85bc2dd..0000000000000 --- a/pkg/util/indexparamcheck/bitmap_checker_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package indexparamcheck - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" -) - -func Test_BitmapIndexChecker(t *testing.T) { - c := newBITMAPChecker() - - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int8})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int16})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int32})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_String})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Bool})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int8})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int16})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_String})) - - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Double})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Double, IsPrimaryKey: true})) -} diff --git a/pkg/util/indexparamcheck/hybrid_checker_test.go b/pkg/util/indexparamcheck/hybrid_checker_test.go deleted file mode 100644 index 733adc2922804..0000000000000 --- a/pkg/util/indexparamcheck/hybrid_checker_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package indexparamcheck - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" -) - -func Test_HybridIndexChecker(t *testing.T) { - c := newHYBRIDChecker() - - assert.NoError(t, c.CheckTrain(map[string]string{"bitmap_cardinality_limit": "100"})) - - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int8})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int16})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int32})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_String})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Bool})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int8})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int16})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_String})) - - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Double})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double})) - assert.Error(t, c.CheckTrain(map[string]string{})) - assert.Error(t, c.CheckTrain(map[string]string{"bitmap_cardinality_limit": "0"})) - assert.Error(t, c.CheckTrain(map[string]string{"bitmap_cardinality_limit": "2000"})) -} diff --git a/pkg/util/indexparamcheck/inverted_checker_test.go b/pkg/util/indexparamcheck/inverted_checker_test.go deleted file mode 100644 index baecd97dd1766..0000000000000 --- a/pkg/util/indexparamcheck/inverted_checker_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package indexparamcheck - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" -) - -func Test_INVERTEDIndexChecker(t *testing.T) { - c := newINVERTEDChecker() - - assert.NoError(t, c.CheckTrain(map[string]string{})) - - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_VarChar})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_String})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array})) - - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector})) -} diff --git a/pkg/util/indexparamcheck/ivf_base_checker.go b/pkg/util/indexparamcheck/ivf_base_checker.go deleted file mode 100644 index 9b8a3e2e045a0..0000000000000 --- a/pkg/util/indexparamcheck/ivf_base_checker.go +++ /dev/null @@ -1,26 +0,0 @@ -package indexparamcheck - -type ivfBaseChecker struct { - floatVectorBaseChecker -} - -func (c ivfBaseChecker) StaticCheck(params map[string]string) error { - if !CheckIntByRange(params, NLIST, MinNList, MaxNList) { - return errOutOfRange(NLIST, MinNList, MaxNList) - } - - // skip check number of rows - - return c.floatVectorBaseChecker.staticCheck(params) -} - -func (c ivfBaseChecker) CheckTrain(params map[string]string) error { - if err := c.StaticCheck(params); err != nil { - return err - } - return c.floatVectorBaseChecker.CheckTrain(params) -} - -func newIVFBaseChecker() IndexChecker { - return &ivfBaseChecker{} -} diff --git a/pkg/util/indexparamcheck/scalar_index_checker.go b/pkg/util/indexparamcheck/scalar_index_checker.go deleted file mode 100644 index 9c372f4034c10..0000000000000 --- a/pkg/util/indexparamcheck/scalar_index_checker.go +++ /dev/null @@ -1,9 +0,0 @@ -package indexparamcheck - -type scalarIndexChecker struct { - baseChecker -} - -func (c scalarIndexChecker) CheckTrain(params map[string]string) error { - return nil -} diff --git a/pkg/util/indexparamcheck/stl_sort_checker_test.go b/pkg/util/indexparamcheck/stl_sort_checker_test.go deleted file mode 100644 index 771a51cd32f68..0000000000000 --- a/pkg/util/indexparamcheck/stl_sort_checker_test.go +++ /dev/null @@ -1,22 +0,0 @@ -package indexparamcheck - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" -) - -func Test_STLSORTIndexChecker(t *testing.T) { - c := newSTLSORTChecker() - - assert.NoError(t, c.CheckTrain(map[string]string{})) - - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float})) - - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_VarChar})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) -} diff --git a/pkg/util/indexparamcheck/trie_checker_test.go b/pkg/util/indexparamcheck/trie_checker_test.go deleted file mode 100644 index 3e1eaea1c5890..0000000000000 --- a/pkg/util/indexparamcheck/trie_checker_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package indexparamcheck - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" -) - -func Test_TrieIndexChecker(t *testing.T) { - c := newTRIEChecker() - - assert.NoError(t, c.CheckTrain(map[string]string{})) - - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_VarChar})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_String})) - - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) -} diff --git a/pkg/util/paramtable/autoindex_param.go b/pkg/util/paramtable/autoindex_param.go index 31df71a4a358d..1e43f3243fce7 100644 --- a/pkg/util/paramtable/autoindex_param.go +++ b/pkg/util/paramtable/autoindex_param.go @@ -23,7 +23,6 @@ import ( "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/config" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" ) // ///////////////////////////////////////////////////////////////////////////// @@ -231,23 +230,10 @@ func (p *autoIndexConfig) panicIfNotValidAndSetDefaultMetricTypeHelper(key strin panic(fmt.Sprintf("%s invalid, should be json format", key)) } - indexType, ok := m[common.IndexTypeKey] + _, ok := m[common.IndexTypeKey] if !ok { panic(fmt.Sprintf("%s invalid, index type not found", key)) } - - checker, err := indexparamcheck.GetIndexCheckerMgrInstance().GetChecker(indexType) - if err != nil { - panic(fmt.Sprintf("%s invalid, unsupported index type: %s", key, indexType)) - } - - checker.SetDefaultMetricTypeIfNotExist(m, dtype) - - if err := checker.StaticCheck(m); err != nil { - panic(fmt.Sprintf("%s invalid, parameters invalid, error: %s", key, err.Error())) - } - - p.reset(key, m, mgr) } func (p *autoIndexConfig) reset(key string, m map[string]string, mgr *config.Manager) { diff --git a/pkg/util/paramtable/autoindex_param_test.go b/pkg/util/paramtable/autoindex_param_test.go index 231c8377e7a9d..e81fdb64266dc 100644 --- a/pkg/util/paramtable/autoindex_param_test.go +++ b/pkg/util/paramtable/autoindex_param_test.go @@ -18,6 +18,7 @@ package paramtable import ( "encoding/json" + "github.com/milvus-io/milvus/pkg/util/metric" "strconv" "testing" @@ -26,7 +27,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/config" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" ) const ( @@ -187,7 +187,7 @@ func Test_autoIndexConfig_panicIfNotValid(t *testing.T) { }) metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey] assert.True(t, exist) - assert.Equal(t, indexparamcheck.FloatVectorDefaultMetricType, metricType) + assert.Equal(t, metric.COSINE, metricType) }) t.Run("normal case, binary vector", func(t *testing.T) { @@ -204,7 +204,7 @@ func Test_autoIndexConfig_panicIfNotValid(t *testing.T) { }) metricType, exist := p.BinaryIndexParams.GetAsJSONMap()[common.MetricTypeKey] assert.True(t, exist) - assert.Equal(t, indexparamcheck.BinaryVectorDefaultMetricType, metricType) + assert.Equal(t, metric.HAMMING, metricType) }) t.Run("normal case, sparse vector", func(t *testing.T) { @@ -221,7 +221,7 @@ func Test_autoIndexConfig_panicIfNotValid(t *testing.T) { }) metricType, exist := p.SparseIndexParams.GetAsJSONMap()[common.MetricTypeKey] assert.True(t, exist) - assert.Equal(t, indexparamcheck.SparseFloatVectorDefaultMetricType, metricType) + assert.Equal(t, metric.IP, metricType) }) t.Run("normal case, ivf flat", func(t *testing.T) { @@ -238,7 +238,7 @@ func Test_autoIndexConfig_panicIfNotValid(t *testing.T) { }) metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey] assert.True(t, exist) - assert.Equal(t, indexparamcheck.FloatVectorDefaultMetricType, metricType) + assert.Equal(t, metric.COSINE, metricType) }) t.Run("normal case, ivf flat", func(t *testing.T) { @@ -255,7 +255,7 @@ func Test_autoIndexConfig_panicIfNotValid(t *testing.T) { }) metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey] assert.True(t, exist) - assert.Equal(t, indexparamcheck.FloatVectorDefaultMetricType, metricType) + assert.Equal(t, metric.COSINE, metricType) }) t.Run("normal case, diskann", func(t *testing.T) { @@ -272,7 +272,7 @@ func Test_autoIndexConfig_panicIfNotValid(t *testing.T) { }) metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey] assert.True(t, exist) - assert.Equal(t, indexparamcheck.FloatVectorDefaultMetricType, metricType) + assert.Equal(t, metric.COSINE, metricType) }) t.Run("normal case, bin flat", func(t *testing.T) { @@ -289,7 +289,7 @@ func Test_autoIndexConfig_panicIfNotValid(t *testing.T) { }) metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey] assert.True(t, exist) - assert.Equal(t, indexparamcheck.BinaryVectorDefaultMetricType, metricType) + assert.Equal(t, metric.HAMMING, metricType) }) t.Run("normal case, bin ivf flat", func(t *testing.T) { @@ -306,7 +306,7 @@ func Test_autoIndexConfig_panicIfNotValid(t *testing.T) { }) metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey] assert.True(t, exist) - assert.Equal(t, indexparamcheck.BinaryVectorDefaultMetricType, metricType) + assert.Equal(t, metric.HAMMING, metricType) }) } diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index 7cc9162d74d25..47dfd9368fd19 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -77,6 +77,7 @@ type ComponentParam struct { DataCoordCfg dataCoordConfig DataNodeCfg dataNodeConfig IndexNodeCfg indexNodeConfig + IndexEngineConfig indexEngineConfig HTTPCfg httpConfig LogCfg logConfig RoleCfg roleConfig @@ -139,6 +140,7 @@ func (p *ComponentParam) init(bt *BaseTable) { p.GpuConfig.init(bt) p.StreamingCoordCfg.init(bt) p.StreamingNodeCfg.init(bt) + p.IndexEngineConfig.init(bt) p.RootCoordGrpcServerCfg.Init("rootCoord", bt) p.ProxyGrpcServerCfg.Init("proxy", bt) diff --git a/pkg/util/paramtable/indexengine_param.go b/pkg/util/paramtable/indexengine_param.go new file mode 100644 index 0000000000000..fdd28e26433cb --- /dev/null +++ b/pkg/util/paramtable/indexengine_param.go @@ -0,0 +1,118 @@ +package paramtable + +import ( + "fmt" + "strconv" + "strings" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/pkg/util/hardware" +) + +type indexEngineConfig struct { + Enable ParamItem `refreshable:"true"` + IndexParam ParamGroup `refreshable:"true"` +} + +const ( + BuildStage = "build" + LoadStage = "load" + SearchStage = "search" +) + +const ( + BuildDramBudgetKey = "build_dram_budget_gb" + NumBuildThreadKey = "num_build_thread" + VecFieldSizeKey = "vec_field_size_gb" +) + +func (p *indexEngineConfig) init(base *BaseTable) { + p.IndexParam = ParamGroup{ + KeyPrefix: "knowhere.", + Version: "2.5.0", + } + p.IndexParam.Init(base.mgr) + + p.Enable = ParamItem{ + Key: "knowhere.enable", + Version: "2.5.0", + DefaultValue: "false", + } + p.Enable.Init(base.mgr) +} + +func (p *indexEngineConfig) getIndexParam(indexType string, stage string) map[string]string { + matchedParam := make(map[string]string) + + params := p.IndexParam.GetValue() + prefix := indexType + "." + stage + "." + + for k, v := range params { + if strings.HasPrefix(k, prefix) { + matchedParam[strings.TrimPrefix(k, prefix)] = v + } + } + + return matchedParam +} + +func GetKeyFromSlice(indexParams []*commonpb.KeyValuePair, key string) string { + for _, param := range indexParams { + if param.Key == key { + return param.Value + } + } + return "" +} + +func (p *indexEngineConfig) GetRuntimeParam(stage string) (map[string]string, error) { + params := make(map[string]string) + + if stage == BuildStage { + params[BuildDramBudgetKey] = fmt.Sprintf("%f", float32(hardware.GetFreeMemoryCount())/(1<<30)) + params[NumBuildThreadKey] = strconv.Itoa(int(float32(hardware.GetCPUNum()))) + } + + return params, nil +} + +func (p *indexEngineConfig) MergeRequestParam(indexType string, stage string, indexParams []*commonpb.KeyValuePair) ([]*commonpb.KeyValuePair, error) { + defaultParams := p.getIndexParam(indexType, stage) + + for key, val := range defaultParams { + if GetKeyFromSlice(indexParams, key) == "" { + indexParams = append(indexParams, + &commonpb.KeyValuePair{ + Key: key, + Value: val, + }) + } + } + + return indexParams, nil +} + +func (p *indexEngineConfig) MergeWithResource(vecFieldSize uint64, indexParam map[string]string) (map[string]string, error) { + param, _ := p.GetRuntimeParam(BuildStage) + + for key, val := range param { + indexParam[key] = val + } + + indexParam[VecFieldSizeKey] = fmt.Sprintf("%f", float32(vecFieldSize)/(1<<30)) + + return indexParam, nil +} + +func (p *indexEngineConfig) MergeRequestMapParam(indexType string, stage string, indexParam map[string]string) (map[string]string, error) { + defaultParams := p.getIndexParam(indexType, stage) + + for key, val := range defaultParams { + _, existed := indexParam[key] + if !existed { + indexParam[key] = val + } + } + + return indexParam, nil +} diff --git a/pkg/util/paramtable/indexengine_param_test.go b/pkg/util/paramtable/indexengine_param_test.go new file mode 100644 index 0000000000000..80ae21241125b --- /dev/null +++ b/pkg/util/paramtable/indexengine_param_test.go @@ -0,0 +1,20 @@ +package paramtable + +import "testing" + +func TestIndexEngineConfig_Init(t *testing.T) { + params := ComponentParam{} + params.Init(NewBaseTable(SkipRemote(true))) + + cfg := ¶ms.IndexEngineConfig + print(cfg) +} + +func TestIndexEngineConfig_Get(t *testing.T) { + params := ComponentParam{} + params.Init(NewBaseTable(SkipRemote(true))) + + cfg := ¶ms.IndexEngineConfig + diskANNbuild := cfg.getIndexParam("DISKANN", BuildStage) + print(diskANNbuild) +} diff --git a/tests/integration/import/import_test.go b/tests/integration/import/import_test.go index 0eb43b55f8e3f..594607c083e08 100644 --- a/tests/integration/import/import_test.go +++ b/tests/integration/import/import_test.go @@ -33,10 +33,10 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/util/importutilv2" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/tests/integration" @@ -66,7 +66,7 @@ func (s *BulkInsertSuite) SetupTest() { s.autoID = false s.vecType = schemapb.DataType_FloatVector - s.indexType = indexparamcheck.IndexHNSW + s.indexType = "HNSW" s.metricType = metric.L2 } @@ -225,29 +225,29 @@ func (s *BulkInsertSuite) TestMultiFileTypes() { s.fileType = fileType s.vecType = schemapb.DataType_BinaryVector - s.indexType = indexparamcheck.IndexFaissBinIvfFlat + s.indexType = "BIN_IVF_FLAT" s.metricType = metric.HAMMING s.run() s.vecType = schemapb.DataType_FloatVector - s.indexType = indexparamcheck.IndexHNSW + s.indexType = "HNSW" s.metricType = metric.L2 s.run() s.vecType = schemapb.DataType_Float16Vector - s.indexType = indexparamcheck.IndexHNSW + s.indexType = "HNSW" s.metricType = metric.L2 s.run() s.vecType = schemapb.DataType_BFloat16Vector - s.indexType = indexparamcheck.IndexHNSW + s.indexType = "HNSW" s.metricType = metric.L2 s.run() // TODO: not support numpy for SparseFloatVector by now if fileType != importutilv2.Numpy { s.vecType = schemapb.DataType_SparseFloatVector - s.indexType = indexparamcheck.IndexSparseWand + s.indexType = "SPARSE_WAND" s.metricType = metric.IP s.run() } diff --git a/tests/integration/util_index.go b/tests/integration/util_index.go index 666cc2d15ac7b..15da7d7885aba 100644 --- a/tests/integration/util_index.go +++ b/tests/integration/util_index.go @@ -26,23 +26,22 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" ) const ( - IndexRaftIvfFlat = indexparamcheck.IndexRaftIvfFlat - IndexRaftIvfPQ = indexparamcheck.IndexRaftIvfPQ - IndexFaissIDMap = indexparamcheck.IndexFaissIDMap - IndexFaissIvfFlat = indexparamcheck.IndexFaissIvfFlat - IndexFaissIvfPQ = indexparamcheck.IndexFaissIvfPQ - IndexScaNN = indexparamcheck.IndexScaNN - IndexFaissIvfSQ8 = indexparamcheck.IndexFaissIvfSQ8 - IndexFaissBinIDMap = indexparamcheck.IndexFaissBinIDMap - IndexFaissBinIvfFlat = indexparamcheck.IndexFaissBinIvfFlat - IndexHNSW = indexparamcheck.IndexHNSW - IndexDISKANN = indexparamcheck.IndexDISKANN - IndexSparseInvertedIndex = indexparamcheck.IndexSparseInverted - IndexSparseWand = indexparamcheck.IndexSparseWand + IndexRaftIvfFlat = "GPU_IVF_FLAT" + IndexRaftIvfPQ = "GPU_IVF_PQ" + IndexFaissIDMap = "FLAT" + IndexFaissIvfFlat = "IVF_FLAT" + IndexFaissIvfPQ = "IVF_PQ" + IndexScaNN = "SCANN" + IndexFaissIvfSQ8 = "IVF_SQ8" + IndexFaissBinIDMap = "BIN_FLAT" + IndexFaissBinIvfFlat = "BIN_IVF_FLAT" + IndexHNSW = "HNSW" + IndexDISKANN = "DISKANN" + IndexSparseInvertedIndex = "SPARSE_INVERTED_INDEX" + IndexSparseWand = "SPARSE_WAND" ) func (s *MiniClusterSuite) WaitForIndexBuiltWithDB(ctx context.Context, dbName, collection, field string) {