diff --git a/internal/core/src/common/ChunkWriter.cpp b/internal/core/src/common/ChunkWriter.cpp index 634cf4c99db31..1d05b8ddcffa5 100644 --- a/internal/core/src/common/ChunkWriter.cpp +++ b/internal/core/src/common/ChunkWriter.cpp @@ -356,13 +356,14 @@ create_chunk(const FieldMeta& field_meta, } case milvus::DataType::VECTOR_FLOAT: { w = std::make_shared< - ChunkWriter>(dim, nullable); + ChunkWriter>( + dim, nullable); break; } case milvus::DataType::VECTOR_BINARY: { w = std::make_shared< - ChunkWriter>(dim / 8, - nullable); + ChunkWriter>( + dim / 8, nullable); break; } case milvus::DataType::VECTOR_FLOAT16: { @@ -377,6 +378,12 @@ create_chunk(const FieldMeta& field_meta, dim, nullable); break; } + case milvus::DataType::VECTOR_INT8: { + w = std::make_shared< + ChunkWriter>( + dim, nullable); + break; + } case milvus::DataType::VARCHAR: case milvus::DataType::STRING: { w = std::make_shared(nullable); @@ -450,13 +457,13 @@ create_chunk(const FieldMeta& field_meta, } case milvus::DataType::VECTOR_FLOAT: { w = std::make_shared< - ChunkWriter>( + ChunkWriter>( dim, file, file_offset, nullable); break; } case milvus::DataType::VECTOR_BINARY: { w = std::make_shared< - ChunkWriter>( + ChunkWriter>( dim / 8, file, file_offset, nullable); break; } @@ -472,6 +479,12 @@ create_chunk(const FieldMeta& field_meta, dim, file, file_offset, nullable); break; } + case milvus::DataType::VECTOR_INT8: { + w = std::make_shared< + ChunkWriter>( + dim, file, file_offset, nullable); + break; + } case milvus::DataType::VARCHAR: case milvus::DataType::STRING: { w = std::make_shared( diff --git a/internal/core/src/common/FieldData.cpp b/internal/core/src/common/FieldData.cpp index 69015dd2743c0..470f9e9420322 100644 --- a/internal/core/src/common/FieldData.cpp +++ b/internal/core/src/common/FieldData.cpp @@ -238,6 +238,7 @@ FieldDataImpl::FillFieldData( case DataType::VECTOR_FLOAT: case DataType::VECTOR_FLOAT16: case DataType::VECTOR_BFLOAT16: + case DataType::VECTOR_INT8: case DataType::VECTOR_BINARY: { auto array_info = GetDataInfoFromArray : public FieldDataSparseVectorImpl { } }; +template <> +class FieldData : public FieldDataImpl { + public: + explicit FieldData(int64_t dim, + DataType data_type, + int64_t buffered_num_rows = 0) + : FieldDataImpl::FieldDataImpl( + dim, data_type, false, buffered_num_rows) { + } +}; + using FieldDataPtr = std::shared_ptr; using FieldDataChannel = Channel; using FieldDataChannelPtr = std::shared_ptr; diff --git a/internal/core/src/common/Types.h b/internal/core/src/common/Types.h index d97399a31d870..52b7a3325849b 100644 --- a/internal/core/src/common/Types.h +++ b/internal/core/src/common/Types.h @@ -125,12 +125,12 @@ GetDataTypeSize(DataType data_type, int dim = 1) { AssertInfo(dim % 8 == 0, "dim={}", dim); return dim / 8; } - case DataType::VECTOR_FLOAT16: { + case DataType::VECTOR_FLOAT16: return sizeof(float16) * dim; - } - case DataType::VECTOR_BFLOAT16: { + case DataType::VECTOR_BFLOAT16: return sizeof(bfloat16) * dim; - } + case DataType::VECTOR_INT8: + return sizeof(int8) * dim; // Not supporting variable length types(such as VECTOR_SPARSE_FLOAT and // VARCHAR) here intentionally. We can't easily estimate the size of // them. Caller of this method must handle this case themselves and must @@ -192,6 +192,8 @@ GetDataTypeName(DataType data_type) { return "vector_bfloat16"; case DataType::VECTOR_SPARSE_FLOAT: return "vector_sparse_float"; + case DataType::VECTOR_INT8: + return "vector_int8"; default: PanicInfo(DataTypeInvalid, "Unsupported DataType({})", data_type); } @@ -325,7 +327,7 @@ IsSparseFloatVectorDataType(DataType data_type) { } inline bool -IsInt8VectorDataType(DataType data_type) { +IsIntVectorDataType(DataType data_type) { return data_type == DataType::VECTOR_INT8; } @@ -338,7 +340,7 @@ IsFloatVectorDataType(DataType data_type) { inline bool IsVectorDataType(DataType data_type) { return IsBinaryVectorDataType(data_type) || - IsFloatVectorDataType(data_type) || IsInt8VectorDataType(data_type); + IsFloatVectorDataType(data_type) || IsIntVectorDataType(data_type); } inline bool @@ -642,6 +644,9 @@ struct fmt::formatter : formatter { case milvus::DataType::VECTOR_SPARSE_FLOAT: name = "VECTOR_SPARSE_FLOAT"; break; + case milvus::DataType::VECTOR_INT8: + name = "VECTOR_INT8"; + break; } return formatter::format(name, ctx); } diff --git a/internal/core/src/common/Utils.h b/internal/core/src/common/Utils.h index 0e52db367ad52..0bdade059d10d 100644 --- a/internal/core/src/common/Utils.h +++ b/internal/core/src/common/Utils.h @@ -156,14 +156,6 @@ IsMetricType(const std::string_view str, return !strcasecmp(str.data(), metric_type.c_str()); } -inline bool -IsFloatMetricType(const knowhere::MetricType& metric_type) { - return IsMetricType(metric_type, knowhere::metric::L2) || - IsMetricType(metric_type, knowhere::metric::IP) || - IsMetricType(metric_type, knowhere::metric::COSINE) || - IsMetricType(metric_type, knowhere::metric::BM25); -} - inline bool PositivelyRelated(const knowhere::MetricType& metric_type) { return IsMetricType(metric_type, knowhere::metric::IP) || diff --git a/internal/core/src/common/VectorTrait.h b/internal/core/src/common/VectorTrait.h index 19c112737a670..4e9734177a6cb 100644 --- a/internal/core/src/common/VectorTrait.h +++ b/internal/core/src/common/VectorTrait.h @@ -30,25 +30,30 @@ namespace milvus { #define GET_ELEM_TYPE_FOR_VECTOR_TRAIT \ using elem_type = std::conditional_t< \ - std::is_same_v, \ - BinaryVector::embedded_type, \ + std::is_same_v, \ + milvus::FloatVector::embedded_type, \ std::conditional_t< \ std::is_same_v, \ - Float16Vector::embedded_type, \ + milvus::Float16Vector::embedded_type, \ std::conditional_t< \ std::is_same_v, \ - BFloat16Vector::embedded_type, \ - FloatVector::embedded_type>>>; + milvus::BFloat16Vector::embedded_type, \ + std::conditional_t< \ + std::is_same_v, \ + milvus::Int8Vector::embedded_type, \ + milvus::BinaryVector::embedded_type>>>>; #define GET_SCHEMA_DATA_TYPE_FOR_VECTOR_TRAIT \ auto schema_data_type = \ std::is_same_v \ - ? FloatVector::schema_data_type \ + ? milvus::FloatVector::schema_data_type \ : std::is_same_v \ - ? Float16Vector::schema_data_type \ + ? milvus::Float16Vector::schema_data_type \ : std::is_same_v \ - ? BFloat16Vector::schema_data_type \ - : BinaryVector::schema_data_type; + ? milvus::BFloat16Vector::schema_data_type \ + : std::is_same_v \ + ? milvus::Int8Vector::schema_data_type \ + : milvus::BinaryVector::schema_data_type; class VectorTrait {}; @@ -118,6 +123,19 @@ class SparseFloatVector : public VectorTrait { proto::common::PlaceholderType::SparseFloatVector; }; +class Int8Vector : public VectorTrait { + public: + using embedded_type = int8; + static constexpr int32_t dim_factor = 1; + static constexpr auto data_type = DataType::VECTOR_INT8; + static constexpr auto c_data_type = CDataType::Int8Vector; + static constexpr auto schema_data_type = + proto::schema::DataType::Int8Vector; + static constexpr auto vector_type = proto::plan::VectorType::Int8Vector; + static constexpr auto placeholder_type = + proto::common::PlaceholderType::Int8Vector; +}; + template constexpr bool IsVector = std::is_base_of_v; diff --git a/internal/core/src/common/type_c.h b/internal/core/src/common/type_c.h index 77bc563698933..e6c6c8e8f6811 100644 --- a/internal/core/src/common/type_c.h +++ b/internal/core/src/common/type_c.h @@ -55,6 +55,7 @@ enum CDataType { Float16Vector = 102, BFloat16Vector = 103, SparseFloatVector = 104, + Int8Vector = 105, }; typedef enum CDataType CDataType; diff --git a/internal/core/src/index/IndexFactory.cpp b/internal/core/src/index/IndexFactory.cpp index 6d88c596a76bb..e7bbfcfc0041e 100644 --- a/internal/core/src/index/IndexFactory.cpp +++ b/internal/core/src/index/IndexFactory.cpp @@ -179,6 +179,16 @@ IndexFactory::VecIndexLoadResource( knowhere::IndexStaticFaced::HasRawData( index_type, index_version, config); break; + case milvus::DataType::VECTOR_INT8: + resource = knowhere::IndexStaticFaced< + knowhere::int8>::EstimateLoadResource(index_type, + index_version, + index_size_gb, + config); + has_raw_data = + knowhere::IndexStaticFaced::HasRawData( + index_type, index_version, config); + break; default: LOG_ERROR("invalid data type to estimate index load resource: {}", field_type); @@ -426,6 +436,9 @@ IndexFactory::CreateVectorIndex( return std::make_unique>( index_type, metric_type, version, file_manager_context); } + case DataType::VECTOR_INT8: { + // TODO caiyd, not support yet + } default: PanicInfo( DataTypeInvalid, diff --git a/internal/core/src/query/ExecPlanNodeVisitor.cpp b/internal/core/src/query/ExecPlanNodeVisitor.cpp index d00488bf94b4e..a28e37aa70770 100644 --- a/internal/core/src/query/ExecPlanNodeVisitor.cpp +++ b/internal/core/src/query/ExecPlanNodeVisitor.cpp @@ -222,4 +222,9 @@ ExecPlanNodeVisitor::visit(SparseFloatVectorANNS& node) { VectorVisitorImpl(node); } +void +ExecPlanNodeVisitor::visit(Int8VectorANNS& node) { + VectorVisitorImpl(node); +} + } // namespace milvus::query diff --git a/internal/core/src/query/ExecPlanNodeVisitor.h b/internal/core/src/query/ExecPlanNodeVisitor.h index 1d9f160da94df..bfaebef15179e 100644 --- a/internal/core/src/query/ExecPlanNodeVisitor.h +++ b/internal/core/src/query/ExecPlanNodeVisitor.h @@ -37,6 +37,9 @@ class ExecPlanNodeVisitor : public PlanNodeVisitor { void visit(SparseFloatVectorANNS& node) override; + void + visit(Int8VectorANNS& node) override; + void visit(RetrievePlanNode& node) override; diff --git a/internal/core/src/query/PlanNode.cpp b/internal/core/src/query/PlanNode.cpp index 540ad68aa925f..65214706bde5c 100644 --- a/internal/core/src/query/PlanNode.cpp +++ b/internal/core/src/query/PlanNode.cpp @@ -40,6 +40,11 @@ SparseFloatVectorANNS::accept(PlanNodeVisitor& visitor) { visitor.visit(*this); } +void +Int8VectorANNS::accept(PlanNodeVisitor& visitor) { + visitor.visit(*this); +} + void RetrievePlanNode::accept(PlanNodeVisitor& visitor) { visitor.visit(*this); diff --git a/internal/core/src/query/PlanNode.h b/internal/core/src/query/PlanNode.h index 5f771f40aa88d..af5b11323173e 100644 --- a/internal/core/src/query/PlanNode.h +++ b/internal/core/src/query/PlanNode.h @@ -71,6 +71,12 @@ struct SparseFloatVectorANNS : VectorPlanNode { accept(PlanNodeVisitor&) override; }; +struct Int8VectorANNS : VectorPlanNode { + public: + void + accept(PlanNodeVisitor&) override; +}; + struct RetrievePlanNode : PlanNode { public: void diff --git a/internal/core/src/query/PlanNodeVisitor.h b/internal/core/src/query/PlanNodeVisitor.h index 60dda9c3eb7fe..9f4620ef47f74 100644 --- a/internal/core/src/query/PlanNodeVisitor.h +++ b/internal/core/src/query/PlanNodeVisitor.h @@ -34,6 +34,9 @@ class PlanNodeVisitor { virtual void visit(SparseFloatVectorANNS&) = 0; + virtual void + visit(Int8VectorANNS&) = 0; + virtual void visit(RetrievePlanNode&) = 0; }; diff --git a/internal/core/src/query/PlanProto.cpp b/internal/core/src/query/PlanProto.cpp index 0790a841b2ccb..8c13f17272a44 100644 --- a/internal/core/src/query/PlanProto.cpp +++ b/internal/core/src/query/PlanProto.cpp @@ -123,6 +123,9 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) { } else if (anns_proto.vector_type() == milvus::proto::plan::VectorType::SparseFloatVector) { return std::make_unique(); + } else if (anns_proto.vector_type() == + milvus::proto::plan::VectorType::Int8Vector) { + return std::make_unique(); } else { return std::make_unique(); } diff --git a/internal/core/src/query/SearchBruteForce.cpp b/internal/core/src/query/SearchBruteForce.cpp index 38887a150909e..02efb604c8bf2 100644 --- a/internal/core/src/query/SearchBruteForce.cpp +++ b/internal/core/src/query/SearchBruteForce.cpp @@ -35,10 +35,19 @@ CheckBruteForceSearchParam(const FieldMeta& field, AssertInfo(IsVectorDataType(data_type), "[BruteForceSearch] Data type isn't vector type"); - bool is_float_vec_data_type = IsFloatVectorDataType(data_type); - bool is_float_metric_type = IsFloatMetricType(metric_type); - AssertInfo(is_float_vec_data_type == is_float_metric_type, - "[BruteForceSearch] Data type and metric type miss-match"); + if (IsBinaryVectorDataType(data_type)) { + AssertInfo(IsBinaryVectorMetricType(metric_type), + "[BruteForceSearch] Binary vector, data type and metric type miss-match"); + } else if (IsFloatVectorDataType(data_type)) { + AssertInfo(IsFloatVectorMetricType(metric_type), + "[BruteForceSearch] Float vector, data type and metric type miss-match"); + } else if (IsIntVectorDataType(data_type)) { + AssertInfo(IsIntVectorMetricType(metric_type), + "[BruteForceSearch] Int vector, data type and metric type miss-match"); + } else { + AssertInfo(IsVectorDataType(data_type), + "[BruteForceSearch] Unsupported vector data type"); + } } knowhere::Json @@ -94,6 +103,12 @@ PrepareBFDataSet(const dataset::SearchDataset& query_ds, knowhere::ConvertFromDataTypeIfNeeded(base_dataset); query_dataset = knowhere::ConvertFromDataTypeIfNeeded(query_dataset); + } else if (data_type == DataType::VECTOR_INT8) { + // TODO caiyd: if knowhere support real int8 bf, remove this + base_dataset = + knowhere::ConvertFromDataTypeIfNeeded(base_dataset); + query_dataset = + knowhere::ConvertFromDataTypeIfNeeded(query_dataset); } base_dataset->SetTensorBeginId(raw_ds.begin_id); return std::make_pair(query_dataset, base_dataset); @@ -147,6 +162,10 @@ BruteForceSearch(const dataset::SearchDataset& query_ds, res = knowhere::BruteForce::RangeSearch< knowhere::sparse::SparseRow>( base_dataset, query_dataset, search_cfg, bitset); + } else if (data_type == DataType::VECTOR_INT8) { + // TODO caiyd: if knowhere support real int8 bf, change it + res = knowhere::BruteForce::RangeSearch( + base_dataset, query_dataset, search_cfg, bitset); } else { PanicInfo( ErrorCode::Unsupported, @@ -211,6 +230,15 @@ BruteForceSearch(const dataset::SearchDataset& query_ds, sub_result.mutable_distances().data(), search_cfg, bitset); + } else if (data_type == DataType::VECTOR_INT8) { + // TODO caiyd: if knowhere support real int8 bf, change it + stat = knowhere::BruteForce::SearchWithBuf( + base_dataset, + query_dataset, + sub_result.mutable_seg_offsets().data(), + sub_result.mutable_distances().data(), + search_cfg, + bitset); } else { PanicInfo(ErrorCode::Unsupported, "Unsupported dataType for chunk brute force search:{}", @@ -236,22 +264,22 @@ DispatchBruteForceIteratorByDataType(const knowhere::DataSetPtr& base_dataset, case DataType::VECTOR_FLOAT: return knowhere::BruteForce::AnnIterator( base_dataset, query_dataset, config, bitset); - break; case DataType::VECTOR_FLOAT16: //todo: if knowhere support real fp16/bf16 bf, change it return knowhere::BruteForce::AnnIterator( base_dataset, query_dataset, config, bitset); - break; case DataType::VECTOR_BFLOAT16: //todo: if knowhere support real fp16/bf16 bf, change it return knowhere::BruteForce::AnnIterator( base_dataset, query_dataset, config, bitset); - break; case DataType::VECTOR_SPARSE_FLOAT: return knowhere::BruteForce::AnnIterator< knowhere::sparse::SparseRow>( base_dataset, query_dataset, config, bitset); - break; + case DataType::VECTOR_INT8: + // TODO caiyd: if knowhere support real int8 bf, change it + return knowhere::BruteForce::AnnIterator( + base_dataset, query_dataset, config, bitset); default: PanicInfo(ErrorCode::Unsupported, "Unsupported dataType for chunk brute force iterator:{}", diff --git a/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp b/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp index 20f8f8860a9b6..976d9566c3c37 100644 --- a/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp +++ b/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp @@ -1722,7 +1722,15 @@ ChunkedSegmentSealedImpl::get_raw_data(FieldId field_id, ret->mutable_vectors()->set_dim(dst->dim()); break; } - + case DataType::VECTOR_INT8: { + bulk_subscript_impl( + field_meta.get_sizeof(), + column.get(), + seg_offsets, + count, + ret->mutable_vectors()->mutable_int8_vector()->data()); + break; + } default: { PanicInfo(DataTypeInvalid, fmt::format("unsupported data type {}", diff --git a/internal/core/src/segcore/ConcurrentVector.cpp b/internal/core/src/segcore/ConcurrentVector.cpp index 0fc665d303ab6..ba6154f065457 100644 --- a/internal/core/src/segcore/ConcurrentVector.cpp +++ b/internal/core/src/segcore/ConcurrentVector.cpp @@ -44,6 +44,9 @@ VectorBase::set_data_raw(ssize_t element_offset, data->vectors().sparse_float_vector().contents()) .get(), element_count); + } else if (field_meta.get_data_type() == DataType::VECTOR_INT8) { + return set_data_raw( + element_offset, VEC_FIELD_DATA(data, int8), element_count); } else { PanicInfo(DataTypeInvalid, "unsupported vector type"); } diff --git a/internal/core/src/segcore/ConcurrentVector.h b/internal/core/src/segcore/ConcurrentVector.h index 3ddd3c27484bc..14108a2f615c2 100644 --- a/internal/core/src/segcore/ConcurrentVector.h +++ b/internal/core/src/segcore/ConcurrentVector.h @@ -167,9 +167,13 @@ class ConcurrentVectorImpl : public VectorBase { std::conditional_t< std::is_same_v, Float16Vector, - std::conditional_t, - BFloat16Vector, - BinaryVector>>>>; + std::conditional_t< + std::is_same_v, + BFloat16Vector, + std::conditional_t< + std::is_same_v, + Int8Vector, + BinaryVector>>>>>; public: explicit ConcurrentVectorImpl( @@ -541,4 +545,16 @@ class ConcurrentVector } }; +template <> +class ConcurrentVector + : public ConcurrentVectorImpl { + public: + ConcurrentVector(int64_t dim, + int64_t size_per_chunk, + storage::MmapChunkDescriptorPtr mmap_descriptor = nullptr) + : ConcurrentVectorImpl::ConcurrentVectorImpl( + dim, size_per_chunk, std::move(mmap_descriptor)) { + } +}; + } // namespace milvus::segcore diff --git a/internal/core/src/segcore/FieldIndexing.cpp b/internal/core/src/segcore/FieldIndexing.cpp index 8c924e24ba01e..89eb12a33562b 100644 --- a/internal/core/src/segcore/FieldIndexing.cpp +++ b/internal/core/src/segcore/FieldIndexing.cpp @@ -324,6 +324,7 @@ CreateIndex(const FieldMeta& field_meta, if (field_meta.get_data_type() == DataType::VECTOR_FLOAT || field_meta.get_data_type() == DataType::VECTOR_FLOAT16 || field_meta.get_data_type() == DataType::VECTOR_BFLOAT16 || + field_meta.get_data_type() == DataType::VECTOR_INT8 || field_meta.get_data_type() == DataType::VECTOR_SPARSE_FLOAT) { return std::make_unique(field_meta, field_index_meta, diff --git a/internal/core/src/segcore/InsertRecord.h b/internal/core/src/segcore/InsertRecord.h index 3f8bb5a4d3738..cdea9a60e0e10 100644 --- a/internal/core/src/segcore/InsertRecord.h +++ b/internal/core/src/segcore/InsertRecord.h @@ -412,6 +412,11 @@ struct InsertRecord { this->append_data(field_id, size_per_chunk); continue; + } else if (field_meta.get_data_type() == + DataType::VECTOR_INT8) { + this->append_data( + field_id, field_meta.get_dim(), size_per_chunk); + continue; } else { PanicInfo(DataTypeInvalid, fmt::format("unsupported vector type", diff --git a/internal/core/src/segcore/SegmentGrowingImpl.cpp b/internal/core/src/segcore/SegmentGrowingImpl.cpp index 4817df64eaa69..c4715f931f5c2 100644 --- a/internal/core/src/segcore/SegmentGrowingImpl.cpp +++ b/internal/core/src/segcore/SegmentGrowingImpl.cpp @@ -486,6 +486,14 @@ SegmentGrowingImpl::bulk_subscript(FieldId field_id, result->mutable_vectors()->mutable_sparse_float_vector()); result->mutable_vectors()->set_dim( result->vectors().sparse_float_vector().dim()); + } else if (field_meta.get_data_type() == DataType::VECTOR_INT8) { + bulk_subscript_impl( + field_id, + field_meta.get_sizeof(), + vec_ptr, + seg_offsets, + count, + result->mutable_vectors()->mutable_int8_vector()->data()); } else { PanicInfo(DataTypeInvalid, "logical error"); } diff --git a/internal/core/src/segcore/SegmentSealedImpl.cpp b/internal/core/src/segcore/SegmentSealedImpl.cpp index d5edeb23c676d..c901a3a47f906 100644 --- a/internal/core/src/segcore/SegmentSealedImpl.cpp +++ b/internal/core/src/segcore/SegmentSealedImpl.cpp @@ -1542,7 +1542,15 @@ SegmentSealedImpl::get_raw_data(FieldId field_id, ret->mutable_vectors()->set_dim(dst->dim()); break; } - + case DataType::VECTOR_INT8: { + bulk_subscript_impl( + field_meta.get_sizeof(), + column->Data(0), + seg_offsets, + count, + ret->mutable_vectors()->mutable_int8_vector()->data()); + break; + } default: { PanicInfo(DataTypeInvalid, fmt::format("unsupported data type {}", diff --git a/internal/core/src/segcore/Utils.cpp b/internal/core/src/segcore/Utils.cpp index 30b01caa86a4d..4a8853f0a5050 100644 --- a/internal/core/src/segcore/Utils.cpp +++ b/internal/core/src/segcore/Utils.cpp @@ -356,6 +356,12 @@ CreateVectorDataArray(int64_t count, const FieldMeta& field_meta) { // does nothing here break; } + case DataType::VECTOR_INT8: { + auto length = count * dim; + auto obj = vector_array->mutable_int8_vector(); + obj->resize(length * sizeof(int8)); + break; + } default: { PanicInfo(DataTypeInvalid, fmt::format("unsupported datatype {}", data_type)); @@ -519,6 +525,13 @@ CreateVectorDataArrayFrom(const void* data_raw, vector_array->set_dim(vector_array->sparse_float_vector().dim()); break; } + case DataType::VECTOR_INT8: { + auto length = count * dim; + auto data = reinterpret_cast(data_raw); + auto obj = vector_array->mutable_int8_vector(); + obj->assign(data, length * sizeof(int8)); + break; + } default: { PanicInfo(DataTypeInvalid, fmt::format("unsupported datatype {}", data_type)); @@ -596,6 +609,11 @@ MergeDataArray(std::vector& merge_bases, } vector_array->set_dim(dst->dim()); *dst->mutable_contents() = src.contents(); + } else if (field_meta.get_data_type() == + DataType::VECTOR_INT8) { + auto data = VEC_FIELD_DATA(src_field_data, int8); + auto obj = vector_array->mutable_int8_vector(); + obj->assign(data, dim * sizeof(int8)); } else { PanicInfo(DataTypeInvalid, fmt::format("unsupported datatype {}", data_type)); diff --git a/internal/core/src/storage/Util.cpp b/internal/core/src/storage/Util.cpp index 8ac8430e3b299..2c812cf42934f 100644 --- a/internal/core/src/storage/Util.cpp +++ b/internal/core/src/storage/Util.cpp @@ -196,6 +196,7 @@ AddPayloadToArrowBuilder(std::shared_ptr builder, case DataType::VECTOR_FLOAT16: case DataType::VECTOR_BFLOAT16: case DataType::VECTOR_BINARY: + case DataType::VECTOR_INT8: case DataType::VECTOR_FLOAT: { add_vector_payload(builder, const_cast(raw_data), length); break; @@ -312,6 +313,11 @@ CreateArrowBuilder(DataType data_type, int dim) { return std::make_shared( arrow::fixed_size_binary(dim * sizeof(bfloat16))); } + case DataType::VECTOR_INT8: { + AssertInfo(dim > 0, "invalid dim value"); + return std::make_shared( + arrow::fixed_size_binary(dim * sizeof(int8))); + } default: { PanicInfo( DataTypeInvalid, "unsupported vector data type {}", data_type); @@ -405,6 +411,13 @@ CreateArrowSchema(DataType data_type, int dim, bool nullable) { return arrow::schema( {arrow::field("val", arrow::binary(), nullable)}); } + case DataType::VECTOR_INT8: { + AssertInfo(dim > 0, "invalid dim value"); + return arrow::schema( + {arrow::field("val", + arrow::fixed_size_binary(dim * sizeof(int8)), + nullable)}); + } default: { PanicInfo( DataTypeInvalid, "unsupported vector data type {}", data_type); @@ -433,6 +446,9 @@ GetDimensionFromFileMetaData(const parquet::ColumnDescriptor* schema, fmt::format("GetDimensionFromFileMetaData should not be " "called for sparse vector")); } + case DataType::VECTOR_INT8: { + return schema->type_length() / sizeof(int8); + } default: PanicInfo(DataTypeInvalid, "unsupported data type {}", data_type); } @@ -478,6 +494,15 @@ GetDimensionFromArrowArray(std::shared_ptr data, std::dynamic_pointer_cast(data); return array->byte_width() / sizeof(bfloat16); } + case DataType::VECTOR_INT8: { + AssertInfo( + data->type()->id() == arrow::Type::type::FIXED_SIZE_BINARY, + "inconsistent data type: {}", + data->type_id()); + auto array = + std::dynamic_pointer_cast(data); + return array->byte_width() / sizeof(int8); + } default: PanicInfo(DataTypeInvalid, "unsupported data type {}", data_type); } @@ -810,6 +835,9 @@ CreateFieldData(const DataType& type, case DataType::VECTOR_SPARSE_FLOAT: return std::make_shared>( type, total_num_rows); + case DataType::VECTOR_INT8: + return std::make_shared>( + dim, type, total_num_rows); default: PanicInfo(DataTypeInvalid, "CreateFieldData not support data type " + diff --git a/internal/core/thirdparty/knowhere/CMakeLists.txt b/internal/core/thirdparty/knowhere/CMakeLists.txt index 942aacc56c25d..89d30e2a36fe8 100644 --- a/internal/core/thirdparty/knowhere/CMakeLists.txt +++ b/internal/core/thirdparty/knowhere/CMakeLists.txt @@ -14,7 +14,7 @@ # Update KNOWHERE_VERSION for the first occurrence milvus_add_pkg_config("knowhere") set_property(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES "") -set( KNOWHERE_VERSION 45d757c9 ) +set( KNOWHERE_VERSION 7dc867d3 ) set( GIT_REPOSITORY "https://github.com/zilliztech/knowhere.git") message(STATUS "Knowhere repo: ${GIT_REPOSITORY}") message(STATUS "Knowhere version: ${KNOWHERE_VERSION}") diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp index 6f7e4e7f2c75f..8dedd82f0be14 100644 --- a/internal/core/unittest/test_c_api.cpp +++ b/internal/core/unittest/test_c_api.cpp @@ -43,7 +43,6 @@ #include "test_utils/GenExprProto.h" #include "expr/ITypeExpr.h" #include "plan/PlanNode.h" -#include "exec/expression/Expr.h" #include "segcore/load_index_c.h" #include "test_utils/c_api_test_utils.h" #include "segcore/vector_index_c.h" @@ -174,7 +173,7 @@ template std::string generate_collection_schema(std::string metric_type, int dim) { namespace schema = milvus::proto::schema; - GET_SCHEMA_DATA_TYPE_FOR_VECTOR_TRAIT; + GET_SCHEMA_DATA_TYPE_FOR_VECTOR_TRAIT schema::CollectionSchema collection_schema; collection_schema.set_name("collection_test"); @@ -326,6 +325,7 @@ TEST(CApiTest, CPlan) { Test_CPlan(knowhere::metric::L2); Test_CPlan(knowhere::metric::L2); Test_CPlan(knowhere::metric::L2); + Test_CPlan(knowhere::metric::L2); } TEST(CApiTest, InsertTest) { @@ -1783,11 +1783,14 @@ TEST(CApiTest, ReduceSearchWithExpr) { testReduceSearchWithExpr(10000, 1, 1); testReduceSearchWithExpr(10000, 10, 10); // float16 - testReduceSearchWithExpr(2, 10, 10, false); - testReduceSearchWithExpr(100, 10, 10, false); + testReduceSearchWithExpr(2, 10, 10); + testReduceSearchWithExpr(100, 10, 10); // bfloat16 - testReduceSearchWithExpr(2, 10, 10, false); - testReduceSearchWithExpr(100, 10, 10, false); + testReduceSearchWithExpr(2, 10, 10); + testReduceSearchWithExpr(100, 10, 10); + // int8 + testReduceSearchWithExpr(2, 10, 10); + testReduceSearchWithExpr(100, 10, 10); } TEST(CApiTest, ReduceSearchWithExprFilterAll) { @@ -1796,8 +1799,13 @@ TEST(CApiTest, ReduceSearchWithExprFilterAll) { testReduceSearchWithExpr(2, 10, 10, true); // float16 testReduceSearchWithExpr(2, 1, 1, true); + testReduceSearchWithExpr(2, 10, 10, true); // bfloat16 testReduceSearchWithExpr(2, 1, 1, true); + testReduceSearchWithExpr(2, 10, 10, true); + // int8 + testReduceSearchWithExpr(2, 1, 1, true); + testReduceSearchWithExpr(2, 10, 10, true); } TEST(CApiTest, LoadIndexInfo) { @@ -2053,6 +2061,7 @@ TEST(CApiTest, Indexing_Without_Predicate) { Test_Indexing_Without_Predicate(); Test_Indexing_Without_Predicate(); Test_Indexing_Without_Predicate(); + Test_Indexing_Without_Predicate(); } TEST(CApiTest, Indexing_Expr_Without_Predicate) { @@ -4373,6 +4382,7 @@ TEST(CApiTest, Range_Search_With_Radius_And_Range_Filter) { Test_Range_Search_With_Radius_And_Range_Filter(); Test_Range_Search_With_Radius_And_Range_Filter(); Test_Range_Search_With_Radius_And_Range_Filter(); + Test_Range_Search_With_Radius_And_Range_Filter(); } std::vector diff --git a/internal/core/unittest/test_index_c_api.cpp b/internal/core/unittest/test_index_c_api.cpp index ed969fee19efe..edc95c89ead2e 100644 --- a/internal/core/unittest/test_index_c_api.cpp +++ b/internal/core/unittest/test_index_c_api.cpp @@ -83,6 +83,9 @@ TestVecIndex() { } else if (std::is_same_v) { auto xb_data = dataset.template get_col(milvus::FieldId(100)); status = BuildBFloat16VecIndex(index, NB * DIM, xb_data.data()); + } else if (std::is_same_v) { + auto xb_data = dataset.template get_col(milvus::FieldId(100)); + status = BuildInt8VecIndex(index, NB * DIM, xb_data.data()); } ASSERT_EQ(milvus::Success, status.error_code); @@ -111,6 +114,7 @@ TEST(VecIndex, All) { TestVecIndex(); TestVecIndex(); TestVecIndex(); + TestVecIndex(); } TEST(CBoolIndexTest, All) { diff --git a/internal/core/unittest/test_utils/DataGen.h b/internal/core/unittest/test_utils/DataGen.h index 6f943a315eb08..24a82ffe944f2 100644 --- a/internal/core/unittest/test_utils/DataGen.h +++ b/internal/core/unittest/test_utils/DataGen.h @@ -126,6 +126,13 @@ struct GeneratedData { auto src_data = reinterpret_cast( target_field_data.vectors().bfloat16_vector().data()); std::copy_n(src_data, len, ret.data()); + } else if (field_meta.get_data_type() == + DataType::VECTOR_INT8) { + int len = raw_->num_rows() * field_meta.get_dim(); + ret.resize(len); + auto src_data = reinterpret_cast( + target_field_data.vectors().int8_vector().data()); + std::copy_n(src_data, len, ret.data()); } else { PanicInfo(Unsupported, "unsupported"); } @@ -410,7 +417,6 @@ inline GeneratedData DataGen(SchemaPtr schema, array.release()); break; } - case DataType::VECTOR_BFLOAT16: { auto dim = field_meta.get_dim(); vector final(dim * N); @@ -420,6 +426,15 @@ inline GeneratedData DataGen(SchemaPtr schema, insert_cols(final, N, field_meta, random_valid); break; } + case DataType::VECTOR_INT8: { + auto dim = field_meta.get_dim(); + vector final(dim * N); + for (auto& x : final) { + x = int8_t(rand() % 256 - 128); + } + insert_cols(final, N, field_meta, random_valid); + break; + } case DataType::BOOL: { FixedVector data(N); for (int i = 0; i < N; ++i) { @@ -834,6 +849,46 @@ CreateSparseFloatPlaceholderGroup(int64_t num_queries, int64_t seed = 42) { return raw_group; } +inline auto +CreateInt8PlaceholderGroup(int64_t num_queries, + int64_t dim, + int64_t seed = 42) { + namespace ser = milvus::proto::common; + ser::PlaceholderGroup raw_group; + auto value = raw_group.add_placeholders(); + value->set_tag("$0"); + value->set_type(ser::PlaceholderType::Int8Vector); + std::default_random_engine e(seed); + for (int i = 0; i < num_queries; ++i) { + std::vector vec; + for (int d = 0; d < dim; ++d) { + vec.push_back(e()); + } + value->add_values(vec.data(), vec.size() * sizeof(int8)); + } + return raw_group; +} + +inline auto +CreateInt8PlaceholderGroupFromBlob(int64_t num_queries, + int64_t dim, + const int8* ptr) { + namespace ser = milvus::proto::common; + ser::PlaceholderGroup raw_group; + auto value = raw_group.add_placeholders(); + value->set_tag("$0"); + value->set_type(ser::PlaceholderType::Int8Vector); + for (int i = 0; i < num_queries; ++i) { + std::vector vec; + for (int d = 0; d < dim; ++d) { + vec.push_back(*ptr); + ++ptr; + } + value->add_values(vec.data(), vec.size() * sizeof(int8)); + } + return raw_group; +} + inline auto SearchResultToVector(const SearchResult& sr) { int64_t num_queries = sr.total_nq_; @@ -934,6 +989,12 @@ CreateFieldDataFromDataArray(ssize_t raw_count, createFieldData(rows.get(), DataType::VECTOR_SPARSE_FLOAT, 0); break; } + case DataType::VECTOR_INT8: { + auto raw_data = data->vectors().int8_vector().data(); + dim = field_meta.get_dim(); + createFieldData(raw_data, DataType::VECTOR_INT8, dim); + break; + } default: { PanicInfo(Unsupported, "unsupported"); } diff --git a/pkg/proto/plan.proto b/pkg/proto/plan.proto index e8abed2846d99..289d9244cc05f 100644 --- a/pkg/proto/plan.proto +++ b/pkg/proto/plan.proto @@ -38,6 +38,7 @@ enum VectorType { Float16Vector = 2; BFloat16Vector = 3; SparseFloatVector = 4; + Int8Vector = 5; }; message GenericValue { diff --git a/pkg/proto/planpb/plan.pb.go b/pkg/proto/planpb/plan.pb.go index 655c3f07621d6..0adbaa4e5db63 100644 --- a/pkg/proto/planpb/plan.pb.go +++ b/pkg/proto/planpb/plan.pb.go @@ -175,6 +175,7 @@ const ( VectorType_Float16Vector VectorType = 2 VectorType_BFloat16Vector VectorType = 3 VectorType_SparseFloatVector VectorType = 4 + VectorType_Int8Vector VectorType = 5 ) // Enum value maps for VectorType. @@ -185,6 +186,7 @@ var ( 2: "Float16Vector", 3: "BFloat16Vector", 4: "SparseFloatVector", + 5: "Int8Vector", } VectorType_value = map[string]int32{ "BinaryVector": 0, @@ -192,6 +194,7 @@ var ( "Float16Vector": 2, "BFloat16Vector": 3, "SparseFloatVector": 4, + "Int8Vector": 5, } ) @@ -2897,14 +2900,15 @@ var file_plan_proto_rawDesc = []byte{ 0x12, 0x07, 0x0a, 0x03, 0x41, 0x64, 0x64, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x53, 0x75, 0x62, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x4d, 0x75, 0x6c, 0x10, 0x03, 0x12, 0x07, 0x0a, 0x03, 0x44, 0x69, 0x76, 0x10, 0x04, 0x12, 0x07, 0x0a, 0x03, 0x4d, 0x6f, 0x64, 0x10, 0x05, 0x12, 0x0f, 0x0a, - 0x0b, 0x41, 0x72, 0x72, 0x61, 0x79, 0x4c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x10, 0x06, 0x2a, 0x6d, + 0x0b, 0x41, 0x72, 0x72, 0x61, 0x79, 0x4c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x10, 0x06, 0x2a, 0x7d, 0x0a, 0x0a, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x54, 0x79, 0x70, 0x65, 0x12, 0x10, 0x0a, 0x0c, 0x42, 0x69, 0x6e, 0x61, 0x72, 0x79, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x00, 0x12, 0x0f, 0x0a, 0x0b, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x01, 0x12, 0x11, 0x0a, 0x0d, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x31, 0x36, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x02, 0x12, 0x12, 0x0a, 0x0e, 0x42, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x31, 0x36, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x03, 0x12, 0x15, 0x0a, 0x11, 0x53, 0x70, 0x61, 0x72, 0x73, 0x65, - 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x04, 0x42, 0x2e, 0x5a, + 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x04, 0x12, 0x0e, 0x0a, + 0x0a, 0x49, 0x6e, 0x74, 0x38, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x05, 0x42, 0x2e, 0x5a, 0x2c, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2d, 0x69, 0x6f, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x70, 0x6c, 0x61, 0x6e, 0x70, 0x62, 0x62, 0x06, 0x70,