Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Use correct type for binary vector in ivf training #2088

Merged
merged 1 commit into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env,
void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x);

// Train a binary index with data provided
void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x);
void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const uint8_t* x);

// Converts the int FilterIds to Faiss ids type array.
void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds);
Expand Down Expand Up @@ -286,7 +286,7 @@ void knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInter
auto *inputVectors = reinterpret_cast<std::vector<uint8_t>*>(vectorsAddressJ);
int dim = (int)dimJ;
if (dim % 8 != 0) {
throw std::runtime_error("Dimensions should be multiply of 8");
throw std::runtime_error("Dimensions should be multiple of 8");
}
int numVectors = (int) (inputVectors->size() / (uint64_t) (dim / 8));
int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ);
Expand Down Expand Up @@ -848,8 +848,12 @@ jbyteArray knn_jni::faiss_wrapper::TrainBinaryIndex(knn_jni::JNIUtilInterface *
}

// Train index if needed
auto *trainingVectorsPointerCpp = reinterpret_cast<std::vector<float>*>(trainVectorsPointerJ);
int numVectors = trainingVectorsPointerCpp->size()/(int) dimensionJ;
int dim = (int)dimensionJ;
if (dim % 8 != 0) {
throw std::runtime_error("Dimensions should be multiple of 8");
}
auto *trainingVectorsPointerCpp = reinterpret_cast<std::vector<uint8_t>*>(trainVectorsPointerJ);
int numVectors = (int) (trainingVectorsPointerCpp->size() / (dim / 8));
if(!indexWriter->is_trained) {
InternalTrainBinaryIndex(indexWriter.get(), numVectors, trainingVectorsPointerCpp->data());
}
Expand Down Expand Up @@ -997,12 +1001,12 @@ void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) {
}
}

void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x) {
void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const uint8_t* x) {
if (auto * indexIvf = dynamic_cast<faiss::IndexBinaryIVF*>(index)) {
indexIvf->make_direct_map();
}
if (!index->is_trained) {
index->train(n, reinterpret_cast<const uint8_t*>(x));
index->train(n, x);
}
}

Expand Down
1 change: 1 addition & 0 deletions release-notes/opensearch-knn.release-notes-2.17.0.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Compatible with OpenSearch 2.17.0
* Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844)
* Disallow a vector field to have an invalid character for a physical file name. [#1936](https://github.com/opensearch-project/k-NN/pull/1936)
* Fix memory overflow caused by cache behavior [#2015](https://github.com/opensearch-project/k-NN/pull/2015)
* Use correct type for binary vector in ivf training [#2086](https://github.com/opensearch-project/k-NN/pull/2086)
### Infrastructure
* Parallelize make to reduce build time [#2006] (https://github.com/opensearch-project/k-NN/pull/2006)
### Maintenance
Expand Down
Loading