Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cuvs integration main #14111

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build-tools/build-infra/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ plugins {
}

repositories {
mavenLocal()
mavenCentral()
}

Expand Down
1 change: 1 addition & 0 deletions gradle/globals.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ allprojects {

// Repositories to fetch dependencies from.
repositories {
mavenLocal()
mavenCentral()
}

Expand Down
7 changes: 7 additions & 0 deletions lucene/sandbox/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,16 @@ apply plugin: 'java-library'

description = 'Various third party contributions and new ideas'

repositories {
mavenLocal()
}


dependencies {
moduleApi project(':lucene:core')
moduleApi project(':lucene:queries')
moduleApi project(':lucene:facet')
moduleTestImplementation project(':lucene:test-framework')
moduleImplementation deps.commons.lang3
moduleImplementation deps.cuvs
}
5 changes: 4 additions & 1 deletion lucene/sandbox/src/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
requires org.apache.lucene.core;
requires org.apache.lucene.queries;
requires org.apache.lucene.facet;

requires java.logging;
requires com.nvidia.cuvs;
requires org.apache.commons.lang3;

exports org.apache.lucene.payloads;
exports org.apache.lucene.sandbox.codecs.idversion;
exports org.apache.lucene.sandbox.codecs.quantization;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package org.apache.lucene.sandbox.vectorsearch;

import java.io.IOException;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.index.FieldInfo;

public class CagraFieldVectorsWriter extends KnnFieldVectorsWriter<float[]> {

public final String fieldName;
public final ConcurrentHashMap<Integer, float[]> vectors = new ConcurrentHashMap<Integer, float[]>();
public int fieldVectorDimension = -1;

public CagraFieldVectorsWriter(FieldInfo fieldInfo) {
this.fieldName = fieldInfo.getName();
this.fieldVectorDimension = fieldInfo.getVectorDimension();
}

@Override
public long ramBytesUsed() {
return fieldName.getBytes().length + Integer.BYTES + (vectors.size() * fieldVectorDimension * Float.BYTES);
}

@Override
public void addValue(int docID, float[] vectorValue) throws IOException {
vectors.put(docID, vectorValue);
}

@Override
public float[] copyValue(float[] vectorValue) {
throw new UnsupportedOperationException();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package org.apache.lucene.sandbox.vectorsearch;

import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene101.Lucene101Codec;
import org.apache.lucene.sandbox.vectorsearch.CuVSVectorsWriter.MergeStrategy;


public class CuVSCodec extends FilterCodec {

public CuVSCodec() {
this("CuVSCodec", new Lucene101Codec());
}

public CuVSCodec(String name, Codec delegate) {
super(name, delegate);
setKnnFormat(new CuVSVectorsFormat(1, 128, 64, MergeStrategy.NON_TRIVIAL_MERGE));
}

KnnVectorsFormat knnFormat = null;

@Override
public KnnVectorsFormat knnVectorsFormat() {
return knnFormat;
}

public void setKnnFormat(KnnVectorsFormat format) {
this.knnFormat = format;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package org.apache.lucene.sandbox.vectorsearch;

import java.util.List;
import java.util.Objects;

import com.nvidia.cuvs.BruteForceIndex;
import com.nvidia.cuvs.CagraIndex;

public class CuVSIndex {
private final CagraIndex cagraIndex;
private final BruteForceIndex bruteforceIndex;
private final List<Integer> mapping;
private final List<float[]> vectors;
private final int maxDocs;

private final String fieldName;
private final String segmentName;

public CuVSIndex(String segmentName, String fieldName, CagraIndex cagraIndex, List<Integer> mapping, List<float[]> vectors, int maxDocs, BruteForceIndex bruteforceIndex) {
this.cagraIndex = Objects.requireNonNull(cagraIndex);
this.bruteforceIndex = Objects.requireNonNull(bruteforceIndex);
this.mapping = Objects.requireNonNull(mapping);
this.vectors = Objects.requireNonNull(vectors);
this.fieldName = Objects.requireNonNull(fieldName);
this.segmentName = Objects.requireNonNull(segmentName);
this.maxDocs = Objects.requireNonNull(maxDocs);
}

public CagraIndex getCagraIndex() {
return cagraIndex;
}

public BruteForceIndex getBruteforceIndex() {
return bruteforceIndex;
}

public List<Integer> getMapping() {
return mapping;
}

public String getFieldName() {
return fieldName;
}

public List<float[]> getVectors() {
return vectors;
}

public String getSegmentName() {
return segmentName;
}

public int getMaxDocs() {
return maxDocs;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package org.apache.lucene.sandbox.vectorsearch;

import java.io.IOException;

import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.Bits;

public class CuVSKnnFloatVectorQuery extends KnnFloatVectorQuery {

final private int iTopK;
final private int searchWidth;

public CuVSKnnFloatVectorQuery(String field, float[] target, int k, int iTopK, int searchWidth) {
super(field, target, k);
this.iTopK = iTopK;
this.searchWidth = searchWidth;
}

@Override
protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException {

PerLeafCuVSKnnCollector results = new PerLeafCuVSKnnCollector(k, iTopK, searchWidth);

LeafReader reader = context.reader();
reader.searchNearestVectors(field, this.getTargetCopy(), results, null);
return results.topDocs();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package org.apache.lucene.sandbox.vectorsearch;

import java.io.File;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.logging.Logger;
import java.util.zip.Deflater;
import java.util.zip.ZipEntry;
import java.util.zip.ZipOutputStream;

public class CuVSSegmentFile implements AutoCloseable{
final private ZipOutputStream zos;

private Set<String> filesAdded = new HashSet<String>();

public CuVSSegmentFile(OutputStream out) {
zos = new ZipOutputStream(out);
zos.setLevel(Deflater.NO_COMPRESSION);
}

protected Logger log = Logger.getLogger(getClass().getName());

public void addFile(String name, byte[] bytes) throws IOException {
log.info("Writing the file: " + name + ", size="+bytes.length + ", space remaining: "+new File("/").getFreeSpace());
ZipEntry indexFileZipEntry = new ZipEntry(name);
zos.putNextEntry(indexFileZipEntry);
zos.write(bytes, 0, bytes.length);
zos.closeEntry();
filesAdded.add(name);
}

public Set<String> getFilesAdded() {
return Collections.unmodifiableSet(filesAdded);
}

@Override
public void close() throws IOException {
zos.close();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package org.apache.lucene.sandbox.vectorsearch;

import java.io.IOException;

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.sandbox.vectorsearch.CuVSVectorsWriter.MergeStrategy;

import com.nvidia.cuvs.CuVSResources;

public class CuVSVectorsFormat extends KnnVectorsFormat {

public static final String VECTOR_DATA_CODEC_NAME = "Lucene99CagraVectorsFormatData";
public static final String VECTOR_DATA_EXTENSION = "cag";
public static final String META_EXTENSION = "cagmf";
public static final int VERSION_CURRENT = 0;
public final int maxDimensions = 4096;
public final int cuvsWriterThreads;
public final int intGraphDegree;
public final int graphDegree;
public MergeStrategy mergeStrategy;
public static CuVSResources resources;

public CuVSVectorsFormat() {
super("CuVSVectorsFormat");
this.cuvsWriterThreads = 1;
this.intGraphDegree = 128;
this.graphDegree = 64;
try {
resources = new CuVSResources();
} catch (Throwable e) {
e.printStackTrace();
}
}

public CuVSVectorsFormat(int cuvsWriterThreads, int intGraphDegree, int graphDegree, MergeStrategy mergeStrategy) {
super("CuVSVectorsFormat");
this.mergeStrategy = mergeStrategy;
this.cuvsWriterThreads = cuvsWriterThreads;
this.intGraphDegree = intGraphDegree;
this.graphDegree = graphDegree;
try {
resources = new CuVSResources();
} catch (Throwable e) {
e.printStackTrace();
}
}

@Override
public CuVSVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new CuVSVectorsWriter(state, cuvsWriterThreads, intGraphDegree, graphDegree, mergeStrategy, resources);
}

@Override
public CuVSVectorsReader fieldsReader(SegmentReadState state) throws IOException {
try {
return new CuVSVectorsReader(state, resources);
} catch (Throwable e) {
e.printStackTrace();
}
return null;
}

@Override
public int getMaxDimensions(String fieldName) {
return maxDimensions;
}

}
Loading
Loading