From b8a1162738e87ea75cbc0fd186f29f4738d44a55 Mon Sep 17 00:00:00 2001 From: Ishan Chattopadhyaya Date: Tue, 7 Jan 2025 21:17:26 +0530 Subject: [PATCH 1/3] Initial cut of CuVS into Lucene as a Codec in sandbox --- build-tools/build-infra/build.gradle | 1 + gradle/globals.gradle | 1 + lucene/sandbox/build.gradle | 7 + lucene/sandbox/src/java/module-info.java | 5 +- .../vectorsearch/CagraFieldVectorsWriter.java | 35 ++ .../sandbox/vectorsearch/CuVSCodec.java | 31 ++ .../sandbox/vectorsearch/CuVSIndex.java | 56 +++ .../vectorsearch/CuVSKnnFloatVectorQuery.java | 33 ++ .../sandbox/vectorsearch/CuVSSegmentFile.java | 43 +++ .../vectorsearch/CuVSVectorsFormat.java | 70 ++++ .../vectorsearch/CuVSVectorsReader.java | 310 ++++++++++++++++ .../vectorsearch/CuVSVectorsWriter.java | 339 ++++++++++++++++++ .../vectorsearch/PerLeafCuVSKnnCollector.java | 74 ++++ .../vectorsearch/SegmentInputStream.java | 90 +++++ .../lucene/sandbox/vectorsearch/Util.java | 142 ++++++++ .../sandbox/vectorsearch/package-info.java | 1 + .../sandbox/vectorsearch/IntegrationTest.java | 201 +++++++++++ versions.toml | 6 + 18 files changed, 1444 insertions(+), 1 deletion(-) create mode 100644 lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CagraFieldVectorsWriter.java create mode 100644 lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSCodec.java create mode 100644 lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSIndex.java create mode 100644 lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSKnnFloatVectorQuery.java create mode 100644 lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSSegmentFile.java create mode 100644 lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsFormat.java create mode 100644 lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsReader.java create mode 100644 lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsWriter.java create mode 100644 lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/PerLeafCuVSKnnCollector.java create mode 100644 lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/SegmentInputStream.java create mode 100644 lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/Util.java create mode 100644 lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/package-info.java create mode 100644 lucene/sandbox/src/test/org/apache/lucene/sandbox/vectorsearch/IntegrationTest.java diff --git a/build-tools/build-infra/build.gradle b/build-tools/build-infra/build.gradle index 5cb1426cba97..34d71f7509d3 100644 --- a/build-tools/build-infra/build.gradle +++ b/build-tools/build-infra/build.gradle @@ -22,6 +22,7 @@ plugins { } repositories { + mavenLocal() mavenCentral() } diff --git a/gradle/globals.gradle b/gradle/globals.gradle index bcab6461ea91..25bfddc9bebf 100644 --- a/gradle/globals.gradle +++ b/gradle/globals.gradle @@ -22,6 +22,7 @@ allprojects { // Repositories to fetch dependencies from. 
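+  // mavenLocal() is listed first, presumably so that a locally built cuvs-java
+  // artifact can be resolved while it is not yet published to Maven Central.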
   repositories {
+    mavenLocal()
     mavenCentral()
   }
 
diff --git a/lucene/sandbox/build.gradle b/lucene/sandbox/build.gradle
index 72762fe1c3d2..6d225fd78ba4 100644
--- a/lucene/sandbox/build.gradle
+++ b/lucene/sandbox/build.gradle
@@ -19,9 +19,16 @@ apply plugin: 'java-library'
 
 description = 'Various third party contributions and new ideas'
 
+repositories {
+  mavenLocal()
+}
+
+
 dependencies {
   moduleApi project(':lucene:core')
   moduleApi project(':lucene:queries')
   moduleApi project(':lucene:facet')
   moduleTestImplementation project(':lucene:test-framework')
+  moduleImplementation deps.commons.lang3
+  moduleImplementation deps.cuvs
 }
diff --git a/lucene/sandbox/src/java/module-info.java b/lucene/sandbox/src/java/module-info.java
index f40a05af433a..b2d45adf4d30 100644
--- a/lucene/sandbox/src/java/module-info.java
+++ b/lucene/sandbox/src/java/module-info.java
@@ -20,7 +20,10 @@
   requires org.apache.lucene.core;
   requires org.apache.lucene.queries;
   requires org.apache.lucene.facet;
-
+  requires java.logging;
+  requires com.nvidia.cuvs;
+  requires org.apache.commons.lang3;
+
   exports org.apache.lucene.payloads;
   exports org.apache.lucene.sandbox.codecs.idversion;
   exports org.apache.lucene.sandbox.codecs.quantization;
diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CagraFieldVectorsWriter.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CagraFieldVectorsWriter.java
new file mode 100644
index 000000000000..21c088bd84f8
--- /dev/null
+++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CagraFieldVectorsWriter.java
@@ -0,0 +1,35 @@
+package org.apache.lucene.sandbox.vectorsearch;
+
+import java.io.IOException;
+import java.util.concurrent.ConcurrentHashMap;
+
+import org.apache.lucene.codecs.KnnFieldVectorsWriter;
+import org.apache.lucene.index.FieldInfo;
+
+public class CagraFieldVectorsWriter extends KnnFieldVectorsWriter<float[]> {
+
+  public final String fieldName;
+  public final ConcurrentHashMap<Integer, float[]> vectors = new ConcurrentHashMap<>();
+  public int fieldVectorDimension = -1;
+
+  public CagraFieldVectorsWriter(FieldInfo fieldInfo) {
+    this.fieldName = fieldInfo.getName();
+    this.fieldVectorDimension = fieldInfo.getVectorDimension();
+  }
+
+  @Override
+  public long ramBytesUsed() {
+    return fieldName.getBytes().length + Integer.BYTES + (vectors.size() * fieldVectorDimension * Float.BYTES);
+  }
+
+  @Override
+  public void addValue(int docID, float[] vectorValue) throws IOException {
+    vectors.put(docID, vectorValue);
+  }
+
+  @Override
+  public float[] copyValue(float[] vectorValue) {
+    throw new UnsupportedOperationException();
+  }
+
+}
diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSCodec.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSCodec.java
new file mode 100644
index 000000000000..448803bb7fc4
--- /dev/null
+++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSCodec.java
@@ -0,0 +1,31 @@
+package org.apache.lucene.sandbox.vectorsearch;
+
+import org.apache.lucene.codecs.Codec;
+import org.apache.lucene.codecs.FilterCodec;
+import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.codecs.lucene101.Lucene101Codec;
+import org.apache.lucene.sandbox.vectorsearch.CuVSVectorsWriter.MergeStrategy;
+
+
+public class CuVSCodec extends FilterCodec {
+
+  public CuVSCodec() {
+    this("CuVSCodec", new Lucene101Codec());
+  }
+
+  public CuVSCodec(String name, Codec delegate) {
+    super(name, delegate);
+    setKnnFormat(new CuVSVectorsFormat(1, 128, 64,
MergeStrategy.NON_TRIVIAL_MERGE)); + } + + KnnVectorsFormat knnFormat = null; + + @Override + public KnnVectorsFormat knnVectorsFormat() { + return knnFormat; + } + + public void setKnnFormat(KnnVectorsFormat format) { + this.knnFormat = format; + } +} \ No newline at end of file diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSIndex.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSIndex.java new file mode 100644 index 000000000000..1878b6c236bc --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSIndex.java @@ -0,0 +1,56 @@ +package org.apache.lucene.sandbox.vectorsearch; + +import java.util.List; +import java.util.Objects; + +import com.nvidia.cuvs.BruteForceIndex; +import com.nvidia.cuvs.CagraIndex; + +public class CuVSIndex { + private final CagraIndex cagraIndex; + private final BruteForceIndex bruteforceIndex; + private final List mapping; + private final List vectors; + private final int maxDocs; + + private final String fieldName; + private final String segmentName; + + public CuVSIndex(String segmentName, String fieldName, CagraIndex cagraIndex, List mapping, List vectors, int maxDocs, BruteForceIndex bruteforceIndex) { + this.cagraIndex = Objects.requireNonNull(cagraIndex); + this.bruteforceIndex = Objects.requireNonNull(bruteforceIndex); + this.mapping = Objects.requireNonNull(mapping); + this.vectors = Objects.requireNonNull(vectors); + this.fieldName = Objects.requireNonNull(fieldName); + this.segmentName = Objects.requireNonNull(segmentName); + this.maxDocs = Objects.requireNonNull(maxDocs); + } + + public CagraIndex getCagraIndex() { + return cagraIndex; + } + + public BruteForceIndex getBruteforceIndex() { + return bruteforceIndex; + } + + public List getMapping() { + return mapping; + } + + public String getFieldName() { + return fieldName; + } + + public List getVectors() { + return vectors; + } + + public String getSegmentName() { + return segmentName; + } + + public int getMaxDocs() { + return maxDocs; + } +} \ No newline at end of file diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSKnnFloatVectorQuery.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSKnnFloatVectorQuery.java new file mode 100644 index 000000000000..1bbae88c5630 --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSKnnFloatVectorQuery.java @@ -0,0 +1,33 @@ +package org.apache.lucene.sandbox.vectorsearch; + +import java.io.IOException; + +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.knn.KnnCollectorManager; +import org.apache.lucene.util.Bits; + +public class CuVSKnnFloatVectorQuery extends KnnFloatVectorQuery { + + final private int iTopK; + final private int searchWidth; + + public CuVSKnnFloatVectorQuery(String field, float[] target, int k, int iTopK, int searchWidth) { + super(field, target, k); + this.iTopK = iTopK; + this.searchWidth = searchWidth; + } + + @Override + protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException { + + PerLeafCuVSKnnCollector results = new PerLeafCuVSKnnCollector(k, iTopK, searchWidth); + + LeafReader reader = context.reader(); + reader.searchNearestVectors(field, this.getTargetCopy(), results, null); + return 
results.topDocs(); + } + +} diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSSegmentFile.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSSegmentFile.java new file mode 100644 index 000000000000..9ca0d63ba087 --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSSegmentFile.java @@ -0,0 +1,43 @@ +package org.apache.lucene.sandbox.vectorsearch; + +import java.io.File; +import java.io.IOException; +import java.io.OutputStream; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; +import java.util.logging.Logger; +import java.util.zip.Deflater; +import java.util.zip.ZipEntry; +import java.util.zip.ZipOutputStream; + +public class CuVSSegmentFile implements AutoCloseable{ + final private ZipOutputStream zos; + + private Set filesAdded = new HashSet(); + + public CuVSSegmentFile(OutputStream out) { + zos = new ZipOutputStream(out); + zos.setLevel(Deflater.NO_COMPRESSION); + } + + protected Logger log = Logger.getLogger(getClass().getName()); + + public void addFile(String name, byte[] bytes) throws IOException { + log.info("Writing the file: " + name + ", size="+bytes.length + ", space remaining: "+new File("/").getFreeSpace()); + ZipEntry indexFileZipEntry = new ZipEntry(name); + zos.putNextEntry(indexFileZipEntry); + zos.write(bytes, 0, bytes.length); + zos.closeEntry(); + filesAdded.add(name); + } + + public Set getFilesAdded() { + return Collections.unmodifiableSet(filesAdded); + } + + @Override + public void close() throws IOException { + zos.close(); + } +} diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsFormat.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsFormat.java new file mode 100644 index 000000000000..c17b5258c9d5 --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsFormat.java @@ -0,0 +1,70 @@ +package org.apache.lucene.sandbox.vectorsearch; + +import java.io.IOException; + +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.sandbox.vectorsearch.CuVSVectorsWriter.MergeStrategy; + +import com.nvidia.cuvs.CuVSResources; + +public class CuVSVectorsFormat extends KnnVectorsFormat { + + public static final String VECTOR_DATA_CODEC_NAME = "Lucene99CagraVectorsFormatData"; + public static final String VECTOR_DATA_EXTENSION = "cag"; + public static final String META_EXTENSION = "cagmf"; + public static final int VERSION_CURRENT = 0; + public final int maxDimensions = 4096; + public final int cuvsWriterThreads; + public final int intGraphDegree; + public final int graphDegree; + public MergeStrategy mergeStrategy; + public static CuVSResources resources; + + public CuVSVectorsFormat() { + super("CuVSVectorsFormat"); + this.cuvsWriterThreads = 1; + this.intGraphDegree = 128; + this.graphDegree = 64; + try { + resources = new CuVSResources(); + } catch (Throwable e) { + e.printStackTrace(); + } + } + + public CuVSVectorsFormat(int cuvsWriterThreads, int intGraphDegree, int graphDegree, MergeStrategy mergeStrategy) { + super("CuVSVectorsFormat"); + this.mergeStrategy = mergeStrategy; + this.cuvsWriterThreads = cuvsWriterThreads; + this.intGraphDegree = intGraphDegree; + this.graphDegree = graphDegree; + try { + resources = new CuVSResources(); + } catch (Throwable e) { + e.printStackTrace(); + } + } + + @Override + public CuVSVectorsWriter 
fieldsWriter(SegmentWriteState state) throws IOException { + return new CuVSVectorsWriter(state, cuvsWriterThreads, intGraphDegree, graphDegree, mergeStrategy, resources); + } + + @Override + public CuVSVectorsReader fieldsReader(SegmentReadState state) throws IOException { + try { + return new CuVSVectorsReader(state, resources); + } catch (Throwable e) { + e.printStackTrace(); + } + return null; + } + + @Override + public int getMaxDimensions(String fieldName) { + return maxDimensions; + } + +} diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsReader.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsReader.java new file mode 100644 index 000000000000..cac870afec6c --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsReader.java @@ -0,0 +1,310 @@ +package org.apache.lucene.sandbox.vectorsearch; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.lang.StackWalker.StackFrame; +import java.lang.invoke.MethodHandles; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.logging.Logger; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import java.util.zip.ZipEntry; +import java.util.zip.ZipInputStream; + +import org.apache.commons.lang3.SerializationUtils; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.TopKnnCollector; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.FixedBitSet; +import org.apache.lucene.util.IOUtils; + +import com.nvidia.cuvs.BruteForceIndex; +import com.nvidia.cuvs.BruteForceQuery; +import com.nvidia.cuvs.CagraIndex; +import com.nvidia.cuvs.CagraQuery; +import com.nvidia.cuvs.CagraSearchParams; +import com.nvidia.cuvs.CuVSResources; +import com.nvidia.cuvs.HnswIndex; +import com.nvidia.cuvs.HnswIndexParams; + +public class CuVSVectorsReader extends KnnVectorsReader { + + protected Logger log = Logger.getLogger(getClass().getName()); + + IndexInput vectorDataReader = null; + public String fileName = null; + public byte[] indexFileBytes; + public int[] docIds; + public float[] vectors; + public SegmentReadState segmentState = null; + public int indexFilePayloadSize = 0; + public long initialFilePointerLoc = 0; + public SegmentInputStream segmentInputStream; + + // Field to List of Indexes + public Map> cuvsIndexes; + + private CuVSResources resources; + + public CuVSVectorsReader(SegmentReadState state, CuVSResources resources) throws Throwable { + + segmentState = state; + this.resources = resources; + + fileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, + CuVSVectorsFormat.VECTOR_DATA_EXTENSION); + + vectorDataReader = segmentState.directory.openInput(fileName, segmentState.context); + CodecUtil.readIndexHeader(vectorDataReader); + + initialFilePointerLoc = vectorDataReader.getFilePointer(); + indexFilePayloadSize = (int)vectorDataReader.length() - (int)initialFilePointerLoc; //vectorMetaReader.readInt(); + 
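+    // Payload layout: between the codec header (read above) and the footer, the .cag
+    // file holds a ZIP archive written by CuVSVectorsWriter, containing per-field
+    // serialized CAGRA/brute-force/HNSW indexes plus the raw vectors and doc-id mapping.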
segmentInputStream = new SegmentInputStream(vectorDataReader, indexFilePayloadSize, initialFilePointerLoc);
+    log.info("payloadSize: " + indexFilePayloadSize);
+    log.info("initialFilePointerLoc: " + initialFilePointerLoc);
+
+    List<StackFrame> stackTrace = StackWalker.getInstance().walk(this::getStackTrace);
+
+    boolean isMergeCase = false;
+    for (StackFrame s : stackTrace) {
+      if (s.toString().startsWith("org.apache.lucene.index.IndexWriter.merge")) {
+        isMergeCase = true;
+        log.info("Reader opening on merge call");
+        break;
+      }
+    }
+
+    log.info("Source of this segment "+segmentState.segmentSuffix+" is " + segmentState.segmentInfo.getDiagnostics().get(IndexWriter.SOURCE));
+    log.info("Loading for " + segmentState.segmentInfo.name + ", mergeCase? " + isMergeCase);
+    //if (!isMergeCase) { nocommit: TODO: don't load the cagra index for merge case.
+    log.info("Not the merge case, hence loading for " + segmentState.segmentInfo.name);
+    this.cuvsIndexes = loadCuVSIndex(getIndexInputStream(), isMergeCase);
+    //}
+  }
+
+  @SuppressWarnings({"unchecked"})
+  private Map<String, List<CuVSIndex>> loadCuVSIndex(ZipInputStream zis, boolean isMergeCase) throws Throwable {
+    Map<String, List<CuVSIndex>> ret = new HashMap<String, List<CuVSIndex>>();
+    Map<String, CagraIndex> cagraIndexes = new HashMap<String, CagraIndex>();
+    Map<String, BruteForceIndex> bruteforceIndexes = new HashMap<String, BruteForceIndex>();
+    Map<String, HnswIndex> hnswIndexes = new HashMap<String, HnswIndex>();
+    Map<String, List<Integer>> mappings = new HashMap<String, List<Integer>>();
+    Map<String, List<float[]>> vectors = new HashMap<String, List<float[]>>();
+
+    Map<String, Integer> maxDocs = null; // map of segment, maxDocs
+    ZipEntry ze;
+    while ((ze = zis.getNextEntry()) != null) {
+      String entry = ze.getName();
+
+      String segmentField = entry.split("\\.")[0];
+      String extension = entry.split("\\.")[1];
+
+      ByteArrayOutputStream baos = new ByteArrayOutputStream();
+      byte[] buffer = new byte[1024];
+      int len = 0;
+      while ((len = zis.read(buffer)) != -1) {
+        baos.write(buffer, 0, len);
+      }
+
+      switch (extension) {
+        case "meta": {
+          maxDocs = (Map<String, Integer>) SerializationUtils.deserialize(baos.toByteArray()); // nocommit use IOUtils
+          break;
+        }
+        case "vec": {
+          vectors.put(segmentField, (List<float[]>) SerializationUtils.deserialize(baos.toByteArray())); // nocommit use IOUtils
+          break;
+        }
+        case "map": {
+          List<Integer> map = (List<Integer>) SerializationUtils.deserialize(baos.toByteArray()); // nocommit use IOUtils
+          mappings.put(segmentField, map);
+          break;
+        }
+        case "cag": {
+          cagraIndexes.put(segmentField, new CagraIndex.Builder(resources)
+              .from(new ByteArrayInputStream(baos.toByteArray()))
+              .build());
+          break;
+        }
+        case "bf": {
+          bruteforceIndexes.put(segmentField, new BruteForceIndex.Builder(resources)
+              .from(new ByteArrayInputStream(baos.toByteArray()))
+              .build());
+          break;
+        }
+        case "hnsw": {
+          HnswIndexParams indexParams = new HnswIndexParams.Builder(resources)
+              .build();
+          hnswIndexes.put(segmentField, new HnswIndex.Builder(resources)
+              .from(new ByteArrayInputStream(baos.toByteArray()))
+              .withIndexParams(indexParams)
+              .build());
+          break;
+        }
+      }
+    }
+
+    log.info("Loading cuvsIndexes from segment: " + segmentState.segmentInfo.name);
+    log.info("Diagnostics for this segment: " + segmentState.segmentInfo.getDiagnostics());
+    log.info("Loading map of cagraIndexes: " + cagraIndexes);
+    log.info("Loading vectors: " + vectors);
+    log.info("Loading mapping: " + mappings);
+
+    for (String segmentField: cagraIndexes.keySet()) {
+      log.info("Loading segmentField: " + segmentField);
+      String segment = segmentField.split("/")[0];
+      String field = segmentField.split("/")[1];
+      CuVSIndex cuvsIndex = new CuVSIndex(segment, field, cagraIndexes.get(segmentField), mappings.get(segmentField), vectors.get(segmentField), maxDocs.get(segment),
bruteforceIndexes.get(segmentField)); + List listOfIndexes = ret.containsKey(field)? ret.get(field): new ArrayList(); + listOfIndexes.add(cuvsIndex); + ret.put(field, listOfIndexes); + } + return ret; + } + + public List getStackTrace(Stream stackFrameStream) { + return stackFrameStream.collect(Collectors.toList()); + } + + public ZipInputStream getIndexInputStream() throws IOException { + segmentInputStream.reset(); + return new ZipInputStream(segmentInputStream); + } + + @Override + public void close() throws IOException { + IOUtils.close(vectorDataReader); + } + + @Override + public void checkIntegrity() throws IOException { + // TODO: Pending implementation + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + throw new UnsupportedOperationException(); + /*return new FloatVectorValues() { + + int pos = -1; + + @Override + public int nextDoc() throws IOException { + pos++; + int size = cuvsIndexes.get(field).get(0).getMapping().size(); + if (pos >= size) return FloatVectorValues.NO_MORE_DOCS; + return cuvsIndexes.get(field).get(0).getMapping().get(pos); + } + + @Override + public int docID() { + return cuvsIndexes.get(field).get(0).getMapping().get(pos); + } + + @Override + public int advance(int target) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public float[] vectorValue() throws IOException { + return cuvsIndexes.get(field).get(0).getVectors().get(pos); + + } + + @Override + public int size() { + return cuvsIndexes.get(field).get(0).getVectors().size(); + } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + // TODO Auto-generated method stub + return null; + } + + @Override + public int dimension() { + // TODO Auto-generated method stub + return cuvsIndexes.get(field).get(0).getVectors().get(0).length; + } + };*/ + } + + @Override + public ByteVectorValues getByteVectorValues(String field) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + PerLeafCuVSKnnCollector cuvsCollector = knnCollector instanceof PerLeafCuVSKnnCollector? ((PerLeafCuVSKnnCollector)knnCollector): new PerLeafCuVSKnnCollector(knnCollector.k(), knnCollector.k(), 1); + TopKnnCollector defaultCollector = knnCollector instanceof TopKnnCollector? ((TopKnnCollector)knnCollector): null; + + int prevDocCount = 0; + + // log.debug("Will try to search all the indexes for segment "+segmentState.segmentInfo.name+", field "+field+": "+cuvsIndexes); + for (CuVSIndex cuvsIndex: cuvsIndexes.get(field)) { + try { + Map result = new HashMap(); + if (cuvsCollector.k() <= 1024) { + CagraSearchParams searchParams = new CagraSearchParams.Builder(resources) + .withItopkSize(cuvsCollector.iTopK) + .withSearchWidth(cuvsCollector.searchWidth) + .build(); + + CagraQuery query = new CagraQuery.Builder() + .withTopK(cuvsCollector.k()) + .withSearchParams(searchParams) + .withMapping(cuvsIndex.getMapping()) + .withQueryVectors(new float[][] {target}) + .build(); + + CagraIndex cagraIndex = cuvsIndex.getCagraIndex(); + assert (cagraIndex != null); + log.info("k is " + cuvsCollector.k()); + result = cagraIndex.search(query).getResults().get(0); // List expected to have only one entry because of single query "target". 
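+          // The result map is docId -> distance for the single query vector;
+          // PerLeafCuVSKnnCollector scores each hit as 1/distance when collecting below.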
+ log.info("INTERMEDIATE RESULT FROM CUVS: " + result + ", prevDocCount=" + prevDocCount); + } else { + BruteForceQuery bruteforceQuery = new BruteForceQuery.Builder() + .withQueryVectors(new float[][] { target }) + .withPrefilter(((FixedBitSet)acceptDocs).getBits()) + .withTopK(cuvsCollector.k()) + .build(); + + BruteForceIndex bruteforceIndex = cuvsIndex.getBruteforceIndex(); + result = bruteforceIndex.search(bruteforceQuery).getResults().get(0); + } + + for(Entry kv : result.entrySet()) { + if (defaultCollector != null) { + defaultCollector.collect(prevDocCount + kv.getKey(), kv.getValue()); + } + cuvsCollector.collect(prevDocCount + kv.getKey(), kv.getValue()); + } + + } catch (Throwable e) { + e.printStackTrace(); + } + prevDocCount += cuvsIndex.getMaxDocs(); + } + } + + @Override + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + throw new UnsupportedOperationException(); + } +} diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsWriter.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsWriter.java new file mode 100644 index 000000000000..1da7ca0f9e6c --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsWriter.java @@ -0,0 +1,339 @@ +package org.apache.lucene.sandbox.vectorsearch; + +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.IOException; +import java.io.OutputStream; +import java.lang.invoke.MethodHandles; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.logging.Logger; + +import org.apache.commons.lang3.SerializationUtils; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnFieldVectorsWriter; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter.DocMap; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.IOUtils; + +import com.nvidia.cuvs.BruteForceIndex; +import com.nvidia.cuvs.BruteForceIndexParams; +import com.nvidia.cuvs.CagraIndex; +import com.nvidia.cuvs.CagraIndexParams; +import com.nvidia.cuvs.CagraIndexParams.CagraGraphBuildAlgo; +import com.nvidia.cuvs.CuVSResources; + +public class CuVSVectorsWriter extends KnnVectorsWriter { + + protected Logger log = Logger.getLogger(getClass().getName()); + + private List fieldVectorWriters = new ArrayList<>(); + private IndexOutput cuVSIndex = null; + private SegmentWriteState segmentWriteState = null; + private String cuVSDataFilename = null; + + private CagraIndex cagraIndex; + private CagraIndex cagraIndexForHnsw; + + private int cuvsWriterThreads; + private int intGraphDegree; + private int graphDegree; + private MergeStrategy mergeStrategy; + private CuVSResources resources; + + public enum MergeStrategy { + TRIVIAL_MERGE, NON_TRIVIAL_MERGE + }; + + public CuVSVectorsWriter(SegmentWriteState state, int cuvsWriterThreads, int intGraphDegree, int graphDegree, MergeStrategy mergeStrategy, CuVSResources resources) + throws IOException { + super(); + this.segmentWriteState = state; + this.mergeStrategy = mergeStrategy; + this.cuvsWriterThreads = cuvsWriterThreads; + this.intGraphDegree = intGraphDegree; + this.graphDegree = graphDegree; + this.resources = resources; + + 
cuVSDataFilename = IndexFileNames.segmentFileName(this.segmentWriteState.segmentInfo.name, this.segmentWriteState.segmentSuffix, CuVSVectorsFormat.VECTOR_DATA_EXTENSION); + } + + @Override + public long ramBytesUsed() { + return 0; + } + + @Override + public void close() throws IOException { + IOUtils.close(cuVSIndex); + cuVSIndex = null; + fieldVectorWriters.clear(); + fieldVectorWriters = null; + } + + @Override + public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { + CagraFieldVectorsWriter cagraFieldVectorWriter = new CagraFieldVectorsWriter(fieldInfo); + fieldVectorWriters.add(cagraFieldVectorWriter); + return cagraFieldVectorWriter; + } + + private byte[] createCagraIndex(float[][] vectors, List mapping) throws Throwable { + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources) + .withNumWriterThreads(cuvsWriterThreads) + .withIntermediateGraphDegree(intGraphDegree) + .withGraphDegree(graphDegree) + .withCagraGraphBuildAlgo(CagraGraphBuildAlgo.NN_DESCENT) + .build(); + + log.info("Indexing started: " + System.currentTimeMillis()); + cagraIndex = new CagraIndex.Builder(resources) + .withDataset(vectors) + .withIndexParams(indexParams) + .build(); + log.info("Indexing done: " + System.currentTimeMillis() + "ms, documents: " + vectors.length); + + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + File tmpFile = File.createTempFile("tmpindex", "cag"); // TODO: Should we make this a file with random names? + cagraIndex.serialize(baos, tmpFile); + return baos.toByteArray(); + } + + private byte[] createBruteForceIndex(float[][] vectors) throws Throwable { + BruteForceIndexParams indexParams = new BruteForceIndexParams.Builder() + .withNumWriterThreads(32) // TODO: Make this configurable later. + .build(); + + log.info("Indexing started: " + System.currentTimeMillis()); + BruteForceIndex index = new BruteForceIndex.Builder(resources) + .withIndexParams(indexParams) + .withDataset(vectors) + .build(); + + log.info("Indexing done: " + System.currentTimeMillis()); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + index.serialize(baos); + return baos.toByteArray(); + } + + private byte[] createHnswIndex(float[][] vectors) throws Throwable { + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources) + .withNumWriterThreads(cuvsWriterThreads) + .withIntermediateGraphDegree(intGraphDegree) + .withGraphDegree(graphDegree) + .withCagraGraphBuildAlgo(CagraGraphBuildAlgo.NN_DESCENT) + .build(); + + log.info("Indexing started: " + System.currentTimeMillis()); + cagraIndexForHnsw = new CagraIndex.Builder(resources) + .withDataset(vectors) + .withIndexParams(indexParams) + .build(); + log.info("Indexing done: " + System.currentTimeMillis() + "ms, documents: " + vectors.length); + + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + File tmpFile = File.createTempFile("tmpindex", "hnsw"); + cagraIndexForHnsw.serializeToHNSW(baos, tmpFile); + return baos.toByteArray(); + } + + @SuppressWarnings({"resource", "rawtypes", "unchecked"}) + @Override + public void flush(int maxDoc, DocMap sortMap) throws IOException { + cuVSIndex = this.segmentWriteState.directory.createOutput(cuVSDataFilename, this.segmentWriteState.context); + CodecUtil.writeIndexHeader(cuVSIndex, CuVSVectorsFormat.VECTOR_DATA_CODEC_NAME, CuVSVectorsFormat.VERSION_CURRENT, this.segmentWriteState.segmentInfo.getId(), this.segmentWriteState.segmentSuffix); + + + CuVSSegmentFile cuVSFile = new CuVSSegmentFile(new SegmentOutputStream(cuVSIndex, 100000)); + + 
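+    // ZIP entries written below, per field: <segment>/<field>.cag (CAGRA index),
+    // <segment>/<field>.bf (brute-force index), <segment>/<field>.hnsw (HNSW index),
+    // <segment>/<field>.vec (serialized vectors), <segment>/<field>.map (docid mapping),
+    // plus a single <segment>.meta entry holding the segment -> maxDoc map.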
LinkedHashMap metaMap = new LinkedHashMap(); + + for (CagraFieldVectorsWriter field : fieldVectorWriters) { + long start = System.currentTimeMillis(); + + byte[] cagraIndexBytes = null; + byte[] bruteForceIndexBytes = null; + byte[] hnswIndexBytes = null; + try { + log.info("Starting CAGRA indexing, space remaining: "+new File("/").getFreeSpace()); + log.info("Starting CAGRA indexing, docs: " + field.vectors.size()); + + float vectors[][] = new float[field.vectors.size()][field.vectors.get(0).length]; + for (int i = 0; i < vectors.length; i++) { + for (int j = 0; j < vectors[i].length; j++) { + vectors[i][j] = field.vectors.get(i)[j]; + } + } + + cagraIndexBytes = createCagraIndex(vectors, new ArrayList(field.vectors.keySet())); // nocommit + bruteForceIndexBytes = createBruteForceIndex(vectors); + hnswIndexBytes = createHnswIndex(vectors); + } catch (Throwable e) { + e.printStackTrace(); + } + + start = System.currentTimeMillis(); + cuVSFile.addFile(segmentWriteState.segmentInfo.name + "/" + field.fieldName + ".cag", cagraIndexBytes); + log.info("time for writing CAGRA index bytes to zip: " + (System.currentTimeMillis() - start)); + + start = System.currentTimeMillis(); + cuVSFile.addFile(segmentWriteState.segmentInfo.name + "/" + field.fieldName + ".bf", bruteForceIndexBytes); + log.info("time for writing BRUTEFORCE index bytes to zip: " + (System.currentTimeMillis() - start)); + + start = System.currentTimeMillis(); + cuVSFile.addFile(segmentWriteState.segmentInfo.name + "/" + field.fieldName + ".hnsw", hnswIndexBytes); + log.info("time for writing HNSW index bytes to zip: " + (System.currentTimeMillis() - start)); + + start = System.currentTimeMillis(); + cuVSFile.addFile(segmentWriteState.segmentInfo.name + "/" + field.fieldName + ".vec", SerializationUtils.serialize(new ArrayList(field.vectors.values()))); + cuVSFile.addFile(segmentWriteState.segmentInfo.name + "/" + field.fieldName + ".map", SerializationUtils.serialize(new ArrayList(field.vectors.keySet()))); + log.info("list serializing and writing: " + (System.currentTimeMillis() - start)); + field.vectors.clear(); + } + + metaMap.put(segmentWriteState.segmentInfo.name, maxDoc); + cuVSFile.addFile(segmentWriteState.segmentInfo.name + ".meta", SerializationUtils.serialize(metaMap)); + cuVSFile.close(); + + CodecUtil.writeFooter(cuVSIndex); + } + + SegmentOutputStream mergeOutputStream = null; + CuVSSegmentFile mergedIndexFile = null; + + @SuppressWarnings("resource") + @Override + public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + List segInputStreams = new ArrayList(); + List readers = new ArrayList(); + + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + CuVSVectorsReader reader = (CuVSVectorsReader) mergeState.knnVectorsReaders[i]; + segInputStreams.add(reader.segmentInputStream); + readers.add(reader); + } + + log.info("Merging one field for segment: " + segmentWriteState.segmentInfo.name); + log.info("Segment files? 
" + Arrays.toString(segmentWriteState.directory.listAll())); + + if (!List.of(segmentWriteState.directory.listAll()).contains(cuVSDataFilename)) { + IndexOutput mergedVectorIndex = segmentWriteState.directory.createOutput(cuVSDataFilename, segmentWriteState.context); + CodecUtil.writeIndexHeader(mergedVectorIndex, CuVSVectorsFormat.VECTOR_DATA_CODEC_NAME, + CuVSVectorsFormat.VERSION_CURRENT, segmentWriteState.segmentInfo.getId(), segmentWriteState.segmentSuffix); + this.mergeOutputStream = new SegmentOutputStream(mergedVectorIndex, 100000); + mergedIndexFile = new CuVSSegmentFile(this.mergeOutputStream); + } + + log.info("Segment files? " + Arrays.toString(segmentWriteState.directory.listAll())); + + if (mergeStrategy.equals(MergeStrategy.TRIVIAL_MERGE)) { + Util.getMergedArchiveCOS(segInputStreams, segmentWriteState.segmentInfo.name, this.mergeOutputStream + ); + } else if (mergeStrategy.equals(MergeStrategy.NON_TRIVIAL_MERGE)) { + // nocommit: this doesn't merge all the fields + log.info("Readers: "+segInputStreams.size()+", deocMaps: "+mergeState.docMaps.length); + ArrayList docMapList = new ArrayList(); + + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + CuVSVectorsReader reader = (CuVSVectorsReader) mergeState.knnVectorsReaders[i]; + for (CuVSIndex index: reader.cuvsIndexes.get(fieldInfo.name)) { + log.info("Mapping for segment ("+reader.fileName+"): " + index.getMapping()); + log.info("Mapping for segment ("+reader.fileName+"): " + index.getMapping().size()); + for (int id=0; id mergedVectors = Util.getMergedVectors(segInputStreams, fieldInfo.name, segmentWriteState.segmentInfo.name); + log.info("Final mapping: " + docMapList); + log.info("Final mapping: " + docMapList.size()); + log.info("Merged vectors: " + mergedVectors.size()); + LinkedHashMap metaMap = new LinkedHashMap(); + byte[] cagraIndexBytes = null; + byte[] bruteForceIndexBytes = null; + byte[] hnswIndexBytes = null; + try { + float vectors[][] = new float[mergedVectors.size()][mergedVectors.get(0).length]; + for (int i = 0; i < vectors.length; i++) { + for (int j = 0; j < vectors[i].length; j++) { + vectors[i][j] = mergedVectors.get(i)[j]; + } + } + cagraIndexBytes = createCagraIndex(vectors, new ArrayList()); + bruteForceIndexBytes = createBruteForceIndex(vectors); + hnswIndexBytes = createHnswIndex(vectors); + } catch (Throwable e) { + e.printStackTrace(); + } + mergedIndexFile.addFile(segmentWriteState.segmentInfo.name + "/" + fieldInfo.getName() + ".cag", cagraIndexBytes); + mergedIndexFile.addFile(segmentWriteState.segmentInfo.name + "/" + fieldInfo.getName() + ".bf", bruteForceIndexBytes); + mergedIndexFile.addFile(segmentWriteState.segmentInfo.name + "/" + fieldInfo.getName() + ".hnsw", hnswIndexBytes); + mergedIndexFile.addFile(segmentWriteState.segmentInfo.name + "/" + fieldInfo.getName() + ".vec", SerializationUtils.serialize(mergedVectors)); + mergedIndexFile.addFile(segmentWriteState.segmentInfo.name + "/" + fieldInfo.getName() + ".map", SerializationUtils.serialize(docMapList)); + metaMap.put(segmentWriteState.segmentInfo.name, mergedVectors.size()); + if (mergedIndexFile.getFilesAdded().contains(segmentWriteState.segmentInfo.name + ".meta") == false) { + mergedIndexFile.addFile(segmentWriteState.segmentInfo.name + ".meta", SerializationUtils.serialize(metaMap)); + } + log.info("DocMaps: "+Arrays.toString(mergeState.docMaps)); + + metaMap.clear(); + } + } + + + @Override + public void finish() throws IOException { + if (this.mergeOutputStream!=null) { + mergedIndexFile.close(); + 
CodecUtil.writeFooter(mergeOutputStream.out);
+      IOUtils.close(mergeOutputStream.out);
+      this.mergeOutputStream = null;
+      this.mergedIndexFile = null;
+    }
+  }
+
+  public class SegmentOutputStream extends OutputStream {
+
+    IndexOutput out;
+    int bufferSize;
+    byte[] buffer;
+    int p;
+
+    public SegmentOutputStream(IndexOutput out, int bufferSize) throws IOException {
+      super();
+      this.out = out;
+      this.bufferSize = bufferSize;
+      this.buffer = new byte[this.bufferSize];
+    }
+
+    @Override
+    public void write(int b) throws IOException {
+      buffer[p] = (byte) b;
+      p += 1;
+      if (p == bufferSize) {
+        flush();
+      }
+    }
+
+    @Override
+    public void flush() throws IOException {
+      out.writeBytes(buffer, p);
+      p = 0;
+    }
+
+    @Override
+    public void close() throws IOException {
+      this.flush();
+    }
+
+  }
+}
diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/PerLeafCuVSKnnCollector.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/PerLeafCuVSKnnCollector.java
new file mode 100644
index 000000000000..d4d19fad7041
--- /dev/null
+++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/PerLeafCuVSKnnCollector.java
@@ -0,0 +1,74 @@
+package org.apache.lucene.sandbox.vectorsearch;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.lucene.search.KnnCollector;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
+
+public class PerLeafCuVSKnnCollector implements KnnCollector {
+
+  public List<ScoreDoc> scoreDocs;
+  public int topK = 0;
+  public int iTopK = topK; // TODO getter, no setter
+  public int searchWidth = 1; // TODO getter, no setter
+  public int results = 0;
+
+  public PerLeafCuVSKnnCollector(int topK, int iTopK, int searchWidth) {
+    super();
+    this.topK = topK;
+    this.iTopK = iTopK;
+    this.searchWidth = searchWidth;
+    scoreDocs = new ArrayList<ScoreDoc>();
+  }
+
+  @Override
+  public boolean earlyTerminated() {
+    // TODO: may need implementation
+    return false;
+  }
+
+  @Override
+  public void incVisitedCount(int count) {
+    // TODO: may need implementation
+  }
+
+  @Override
+  public long visitedCount() {
+    // TODO: may need implementation
+    return 0;
+  }
+
+  @Override
+  public long visitLimit() {
+    // TODO: may need implementation
+    return 0;
+  }
+
+  @Override
+  public int k() {
+    return topK;
+  }
+
+  @Override
+  @SuppressWarnings("cast")
+  public boolean collect(int docId, float similarity) {
+    scoreDocs.add(new ScoreDoc(docId, 1f/(float)(similarity)));
+    return true;
+  }
+
+  @Override
+  public float minCompetitiveSimilarity() {
+    // TODO: may need implementation
+    return 0;
+  }
+
+  @Override
+  public TopDocs topDocs() {
+    return new TopDocs(new TotalHits(scoreDocs.size(), TotalHits.Relation.EQUAL_TO),
+        scoreDocs.toArray(new ScoreDoc[scoreDocs.size()]));
+  }
+
+}
diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/SegmentInputStream.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/SegmentInputStream.java
new file mode 100644
index 000000000000..a352269fbb1b
--- /dev/null
+++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/SegmentInputStream.java
@@ -0,0 +1,90 @@
+package org.apache.lucene.sandbox.vectorsearch;
+
+import java.io.IOException;
+import java.io.InputStream;
+
+import org.apache.lucene.store.IndexInput;
+
+public class SegmentInputStream extends InputStream {
+
+  /**
+   *
+   */
+  private final IndexInput indexInput;
+  public final long initialFilePointerPosition;
+  public final long limit;
+  public
long pos = 0; + + // TODO: This input stream needs to be modified to enable buffering. + public SegmentInputStream(IndexInput indexInput, long limit, long initialFilePointerPosition) throws IOException { + super(); + this.indexInput = indexInput; + this.initialFilePointerPosition = initialFilePointerPosition; + this.limit = limit; + + this.indexInput.seek(initialFilePointerPosition); + } + + @Override + public int read() throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public int read(byte[] b, int off, int len) { + try { + long avail = limit - pos; + if (pos >= limit) { + return -1; + } + if (len > avail) { + len = (int) avail; + } + if (len <= 0) { + return 0; + } + indexInput.readBytes(b, off, len); + pos += len; + return len; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public int read(byte[] b) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public void reset() throws IOException { + indexInput.seek(initialFilePointerPosition); + pos = 0; + } + + @Override + public long skip(long n) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean markSupported() { + return true; + } + + @Override + public void mark(int readlimit) { + throw new UnsupportedOperationException(); + } + + @Override + public void close() { + // Do nothing for now. + } + + @Override + public int available() { + throw new UnsupportedOperationException(); + } + +} \ No newline at end of file diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/Util.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/Util.java new file mode 100644 index 000000000000..a8200e7b897b --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/Util.java @@ -0,0 +1,142 @@ +package org.apache.lucene.sandbox.vectorsearch; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.OutputStream; +import java.lang.invoke.MethodHandles; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.logging.Logger; +import java.util.zip.Deflater; +import java.util.zip.ZipEntry; +import java.util.zip.ZipInputStream; +import java.util.zip.ZipOutputStream; + +public class Util { + + public static ByteArrayOutputStream getZipEntryBAOS(String fileName, SegmentInputStream segInputStream) + throws IOException { + segInputStream.reset(); + ZipInputStream zipInputStream = new ZipInputStream(segInputStream); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + boolean fileFound = false; + ZipEntry zipEntry; + while (zipInputStream.available() == 1 && ((zipEntry = zipInputStream.getNextEntry()) != null)) { + if (zipEntry.getName().equals(fileName)) { + fileFound = true; + byte[] buffer = new byte[1024]; + int length; + while ((length = zipInputStream.read(buffer)) != -1) { + baos.write(buffer, 0, length); + } + } + } + if (!fileFound) throw new FileNotFoundException(); + return baos; + } + + private static final Logger log = Logger.getLogger(Util.class.getName()); + + public static ArrayList getMergedVectors(List segInputStreams, String fieldName, String mergedSegmentName) + throws IOException { + ZipEntry zs; + ArrayList mergedVectors = new ArrayList(); + log.info("Getting mergedVectors..."); + for 
(SegmentInputStream segInputStream : segInputStreams) { + segInputStream.reset(); + ZipInputStream zipStream = new ZipInputStream(segInputStream); + while ((zs = zipStream.getNextEntry()) != null) { + log.info("Getting mergedVectors... " + zs.getName()); + byte[] buffer = new byte[1024]; + int length; + if (zs.getName().endsWith(".vec")) { + String field = zs.getName().split("\\.")[0].split("/")[1]; + if (fieldName.equals(field)) { + ByteArrayOutputStream baosM = new ByteArrayOutputStream(); + while ((length = zipStream.read(buffer)) != -1) { + baosM.write(buffer, 0, length); + } + List m = deSerializeListInMemory(baosM.toByteArray()); + mergedVectors.addAll(m); + } + } + } + } + return mergedVectors; + } + + public static void getMergedArchiveCOS(List segInputStreams, String mergedSegmentName, + OutputStream os) throws IOException { + ZipOutputStream zos = new ZipOutputStream(os); + ZipEntry zs; + Map mergedMetaMap = new LinkedHashMap(); + for (SegmentInputStream segInputStream : segInputStreams) { + segInputStream.reset(); + ZipInputStream zipStream = new ZipInputStream(segInputStream); + while ((zs = zipStream.getNextEntry()) != null) { + byte[] buffer = new byte[1024]; + int length; + if (zs.getName().endsWith(".meta")) { + ByteArrayOutputStream baosM = new ByteArrayOutputStream(); + while ((length = zipStream.read(buffer)) != -1) { + baosM.write(buffer, 0, length); + } + Map m = deSerializeMapInMemory(baosM.toByteArray()); + mergedMetaMap.putAll(m); + } else { + ZipEntry zipEntry = new ZipEntry(zs.getName()); + zos.putNextEntry(zipEntry); + zos.setLevel(Deflater.NO_COMPRESSION); + while ((length = zipStream.read(buffer)) != -1) { + zos.write(buffer, 0, length); + } + zos.closeEntry(); + } + } + } + // Finally put the merged meta file + ZipEntry mergedMetaZipEntry = new ZipEntry(mergedSegmentName + ".meta"); + zos.putNextEntry(mergedMetaZipEntry); + zos.setLevel(Deflater.NO_COMPRESSION); + new ObjectOutputStream(zos).writeObject(mergedMetaMap); // Java serialization should be avoided + zos.closeEntry(); + zos.close(); + } + + @SuppressWarnings("unchecked") + public static Map deSerializeMapInMemory(byte[] bytes) { + Map map = null; + ObjectInputStream ois = null; + try { + ois = new ObjectInputStream(new ByteArrayInputStream(bytes)); + map = (Map) ois.readObject(); + ois.close(); + } catch (Exception e) { + e.printStackTrace(); + } + + return map; + } + + @SuppressWarnings("unchecked") + public static List deSerializeListInMemory(byte[] bytes) { + List map = null; + ObjectInputStream ois = null; + try { + ois = new ObjectInputStream(new ByteArrayInputStream(bytes)); + map = (List) ois.readObject(); + ois.close(); + } catch (Exception e) { + e.printStackTrace(); + } + + return map; + } + +} diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/package-info.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/package-info.java new file mode 100644 index 000000000000..67199edca2f6 --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/package-info.java @@ -0,0 +1 @@ +package org.apache.lucene.sandbox.vectorsearch; diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/vectorsearch/IntegrationTest.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/vectorsearch/IntegrationTest.java new file mode 100644 index 000000000000..89ee9a3879ba --- /dev/null +++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/vectorsearch/IntegrationTest.java @@ -0,0 +1,201 @@ +package org.apache.lucene.sandbox.vectorsearch; + 
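+// End-to-end test: indexes random vectors through CuVSCodec and compares CuVS search
+// results against exhaustively computed nearest neighbours.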
+import java.io.IOException; +import java.lang.invoke.MethodHandles; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.TreeMap; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.document.StringField; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.analysis.MockAnalyzer; +import org.apache.lucene.tests.analysis.MockTokenizer; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.util.English; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.tests.util.LuceneTestCase.SuppressSysoutChecks; +import org.apache.lucene.tests.util.TestUtil; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@SuppressSysoutChecks(bugUrl = "prints info from within cuvs") +public class IntegrationTest extends LuceneTestCase { + + private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + private static IndexSearcher searcher; + private static IndexReader reader; + private static Directory directory; + + public static int DATASET_SIZE_LIMIT = 1000; + public static int DIMENSIONS_LIMIT = 2048; + public static int NUM_QUERIES_LIMIT = 10; + public static int TOP_K_LIMIT = 64; // nocommit This fails beyond 64 + + public static float[][] dataset = null; + + @BeforeClass + public static void beforeClass() throws Exception { + directory = newDirectory(); + + Codec codec = new CuVSCodec(); + + RandomIndexWriter writer = + new RandomIndexWriter( + random(), + directory, + newIndexWriterConfig(new MockAnalyzer(random(), MockTokenizer.SIMPLE, true)) + .setMaxBufferedDocs(TestUtil.nextInt(random(), 100, 1000)) + .setCodec(codec) + .setMergePolicy(newTieredMergePolicy())); + + log.info("Merge Policy: " + writer.w.getConfig().getMergePolicy()); + + Random random = random(); + int datasetSize = random.nextInt(DATASET_SIZE_LIMIT) + 1; + int dimensions = random.nextInt(DIMENSIONS_LIMIT) + 1; + dataset = generateDataset(random, datasetSize, dimensions); + for (int i = 0; i < datasetSize; i++) { + Document doc = new Document(); + doc.add(new StringField("id", String.valueOf(i), Field.Store.YES)); + doc.add(newTextField("field", English.intToEnglish(i), Field.Store.YES)); + boolean skipVector = random.nextInt(10) < 0; // nocommit disable testing with holes for now, there's some bug. 
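+      // NOTE: nextInt(10) < 0 is never true, so no document is skipped while this is disabled.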
+ if (!skipVector || datasetSize<100) { // about 10th of the documents shouldn't have a single vector + doc.add(new KnnFloatVectorField("vector", dataset[i], VectorSimilarityFunction.EUCLIDEAN)); + doc.add(new KnnFloatVectorField("vector2", dataset[i], VectorSimilarityFunction.EUCLIDEAN)); + } + + writer.addDocument(doc); + } + + reader = writer.getReader(); + searcher = newSearcher(reader); + writer.close(); + } + + @AfterClass + public static void afterClass() throws Exception { + // nocommit This fails until flat vectors are implemented + reader.close(); + directory.close(); + searcher = null; + reader = null; + directory = null; + log.info("Test finished"); + } + + @Test + public void testVectorSearch() throws IOException { + Random random = random(); + int numQueries = random.nextInt(NUM_QUERIES_LIMIT) + 1; + int topK = Math.min(random.nextInt(TOP_K_LIMIT) + 1, dataset.length); + + if(dataset.length < topK) topK = dataset.length; + + float[][] queries = generateQueries(random, dataset[0].length, numQueries); + List> expected = generateExpectedResults(topK, dataset, queries); + + debugPrintDatasetAndQueries(dataset, queries); + + log.info("Dataset size: {}x{}", dataset.length, dataset[0].length); + log.info("Query size: {}x{}", numQueries, queries[0].length); + log.info("TopK: {}", topK); + + Query query = new CuVSKnnFloatVectorQuery("vector", queries[0], topK, topK, 1); + int correct[] = new int[topK]; + for (int i=0; i> generateExpectedResults(int topK, float[][] dataset, float[][] queries) { + List> neighborsResult = new ArrayList<>(); + int dimensions = dataset[0].length; + + for (float[] query : queries) { + Map distances = new TreeMap<>(); + for (int j = 0; j < dataset.length; j++) { + double distance = 0; + for (int k = 0; k < dimensions; k++) { + distance += (query[k] - dataset[j][k]) * (query[k] - dataset[j][k]); + } + distances.put(j, (distance)); + } + + Map sorted = new TreeMap(distances); + log.info("EXPECTED: " + sorted); + + // Sort by distance and select the topK nearest neighbors + List neighbors = distances.entrySet().stream() + .sorted(Map.Entry.comparingByValue()) + .map(Map.Entry::getKey) + .toList(); + neighborsResult.add(neighbors.subList(0, Math.min(topK * 3, dataset.length))); // generate double the topK results in the expected array + } + + log.info("Expected results generated successfully."); + return neighborsResult; + } +} diff --git a/versions.toml b/versions.toml index 80dc51f39bf2..327848fd10d4 100644 --- a/versions.toml +++ b/versions.toml @@ -4,6 +4,8 @@ asm = "9.6" assertj = "3.21.0" commons-codec = "1.13" commons-compress = "1.19" +commons-lang3 = "3.17.0" +cuvs = "25.02" ecj = "3.36.0" errorprone = "2.18.0" flexmark = "0.61.24" @@ -33,6 +35,7 @@ s2-geometry = "1.0.0" spatial4j = "0.8" xerces = "2.12.0" zstd = "1.5.5-11" +jackson-core = "2.18.2" [libraries] antlr-core = { module = "org.antlr:antlr4", version.ref = "antlr" } @@ -42,6 +45,8 @@ asm-core = { module = "org.ow2.asm:asm", version.ref = "asm" } assertj = { module = "org.assertj:assertj-core", version.ref = "assertj" } commons-codec = { module = "commons-codec:commons-codec", version.ref = "commons-codec" } commons-compress = { module = "org.apache.commons:commons-compress", version.ref = "commons-compress" } +commons-lang3 = { module = "org.apache.commons:commons-lang3", version.ref = "commons-lang3" } +cuvs = { module = "com.nvidia.cuvs:cuvs-java", version.ref = "cuvs" } ecj = { module = "org.eclipse.jdt:ecj", version.ref = "ecj" } errorprone = { module = 
"com.google.errorprone:error_prone_core", version.ref = "errorprone" } flexmark-core = { module = "com.vladsch.flexmark:flexmark", version.ref = "flexmark" } @@ -52,6 +57,7 @@ flexmark-ext-tables = { module = "com.vladsch.flexmark:flexmark-ext-tables", ver groovy = { module = "org.apache.groovy:groovy-all", version.ref = "groovy" } hamcrest = { module = "org.hamcrest:hamcrest", version.ref = "hamcrest" } icu4j = { module = "com.ibm.icu:icu4j", version.ref = "icu4j" } +jackson-core = { module = "com.fasterxml.jackson.core:jackson-core", version.ref = "jackson-core" } javacc = { module = "net.java.dev.javacc:javacc", version.ref = "javacc" } jflex = { module = "de.jflex:jflex", version.ref = "jflex" } jgit = { module = "org.eclipse.jgit:org.eclipse.jgit", version.ref = "jgit" } From 0e9f6d4bc9a98eb33d594409ce8e4b3a6b4b1a06 Mon Sep 17 00:00:00 2001 From: Ishan Chattopadhyaya Date: Tue, 7 Jan 2025 21:28:17 +0530 Subject: [PATCH 2/3] Test fixes --- .../services/org.apache.lucene.codecs.Codec | 1 + .../org.apache.lucene.codecs.KnnVectorsFormat | 16 ++++++++++++++++ .../{IntegrationTest.java => TestCuVS.java} | 2 +- 3 files changed, 18 insertions(+), 1 deletion(-) create mode 100644 lucene/sandbox/src/resources/META-INF/services/org.apache.lucene.codecs.Codec create mode 100644 lucene/sandbox/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat rename lucene/sandbox/src/test/org/apache/lucene/sandbox/vectorsearch/{IntegrationTest.java => TestCuVS.java} (99%) diff --git a/lucene/sandbox/src/resources/META-INF/services/org.apache.lucene.codecs.Codec b/lucene/sandbox/src/resources/META-INF/services/org.apache.lucene.codecs.Codec new file mode 100644 index 000000000000..38b31884377d --- /dev/null +++ b/lucene/sandbox/src/resources/META-INF/services/org.apache.lucene.codecs.Codec @@ -0,0 +1 @@ +org.apache.lucene.sandbox.vectorsearch.CuVSCodec \ No newline at end of file diff --git a/lucene/sandbox/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/lucene/sandbox/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat new file mode 100644 index 000000000000..666ee726f986 --- /dev/null +++ b/lucene/sandbox/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +org.apache.lucene.sandbox.vectorsearch.CuVSVectorsFormat diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/vectorsearch/IntegrationTest.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/vectorsearch/TestCuVS.java similarity index 99% rename from lucene/sandbox/src/test/org/apache/lucene/sandbox/vectorsearch/IntegrationTest.java rename to lucene/sandbox/src/test/org/apache/lucene/sandbox/vectorsearch/TestCuVS.java index 89ee9a3879ba..15a023d6fbd3 100644 --- a/lucene/sandbox/src/test/org/apache/lucene/sandbox/vectorsearch/IntegrationTest.java +++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/vectorsearch/TestCuVS.java @@ -34,7 +34,7 @@ import org.slf4j.LoggerFactory; @SuppressSysoutChecks(bugUrl = "prints info from within cuvs") -public class IntegrationTest extends LuceneTestCase { +public class TestCuVS extends LuceneTestCase { private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); From a95f084e1d5a9d16128bd133e0631b193eed8709 Mon Sep 17 00:00:00 2001 From: Vivek Narang Date: Tue, 7 Jan 2025 12:32:57 -0500 Subject: [PATCH 3/3] fix for getFloatVectorValues --- .../vectorsearch/CuVSVectorsReader.java | 40 ++++--------------- 1 file changed, 8 insertions(+), 32 deletions(-) diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsReader.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsReader.java index cac870afec6c..837a9229d061 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsReader.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsReader.java @@ -196,52 +196,28 @@ public void checkIntegrity() throws IOException { @Override public FloatVectorValues getFloatVectorValues(String field) throws IOException { - throw new UnsupportedOperationException(); - /*return new FloatVectorValues() { - - int pos = -1; - - @Override - public int nextDoc() throws IOException { - pos++; - int size = cuvsIndexes.get(field).get(0).getMapping().size(); - if (pos >= size) return FloatVectorValues.NO_MORE_DOCS; - return cuvsIndexes.get(field).get(0).getMapping().get(pos); - } + return new FloatVectorValues() { @Override - public int docID() { - return cuvsIndexes.get(field).get(0).getMapping().get(pos); + public int size() { + return cuvsIndexes.get(field).get(0).getVectors().size(); } @Override - public int advance(int target) throws IOException { - throw new UnsupportedOperationException(); + public int dimension() { + return cuvsIndexes.get(field).get(0).getVectors().get(0).length; } @Override - public float[] vectorValue() throws IOException { + public float[] vectorValue(int pos) throws IOException { return cuvsIndexes.get(field).get(0).getVectors().get(pos); - } @Override - public int size() { - return cuvsIndexes.get(field).get(0).getVectors().size(); - } - - @Override - public VectorScorer scorer(float[] query) throws IOException { - // TODO Auto-generated method stub + public FloatVectorValues copy() throws IOException { return null; } - - @Override - public int dimension() { - // TODO Auto-generated method stub - return cuvsIndexes.get(field).get(0).getVectors().get(0).length; - } - };*/ + }; } @Override
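For reviewers who want to exercise the codec end to end, here is a minimal usage sketch. It is not part of the patch: it assumes a CUDA-capable GPU with the cuvs-java artifact resolvable (e.g. from mavenLocal()); the directory choice, field name and vector values are illustrative only.

import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.sandbox.vectorsearch.CuVSCodec;
import org.apache.lucene.sandbox.vectorsearch.CuVSKnnFloatVectorQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.Directory;

public class CuVSCodecDemo {
  public static void main(String[] args) throws Exception {
    try (Directory dir = new ByteBuffersDirectory()) {
      // Route all KNN vector data of this index through the CuVS codec.
      IndexWriterConfig config = new IndexWriterConfig().setCodec(new CuVSCodec());
      try (IndexWriter writer = new IndexWriter(dir, config)) {
        Document doc = new Document();
        doc.add(new KnnFloatVectorField("vector", new float[] {0.1f, 0.2f, 0.3f}, VectorSimilarityFunction.EUCLIDEAN));
        writer.addDocument(doc);
      }
      try (DirectoryReader reader = DirectoryReader.open(dir)) {
        IndexSearcher searcher = new IndexSearcher(reader);
        // k = 1; the last two arguments are CAGRA's iTopK and searchWidth search parameters.
        TopDocs hits = searcher.search(new CuVSKnnFloatVectorQuery("vector", new float[] {0.1f, 0.2f, 0.3f}, 1, 1, 1), 1);
        System.out.println("totalHits: " + hits.totalHits);
      }
    }
  }
}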