diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/chroma.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/chroma.adoc index 6b2d1b7ae3..f9f3b365ea 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/chroma.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/chroma.adoc @@ -156,6 +156,22 @@ CALL apoc.vectordb.chroma.query($host, '', ---- +[NOTE] +==== +To optimize performances, we can choose what to `YIELD` with the apoc.vectordb.chroma.query and the `apoc.vectordb.chroma.get` procedures. +For example, by executing a `CALL apoc.vectordb.chroma.query(...) YIELD metadata, score, id`, the RestAPI request will have an {"include": ["metadatas", "documents", "distances"]}, +so that we do not return the other values that we do not need. +==== + +It is possible to execute vector db procedures together with the xref::ml/rag.adoc[apoc.ml.rag] as follow: + +[source,cypher] +---- +CALL apoc.vectordb.chroma.getAndUpdate($host, $collection, [, ], $conf) YIELD node, metadata, id, vector +WITH collect(node) as paths +CALL apoc.ml.rag(paths, $attributes, $question, $confPrompt) YIELD value +RETURN value +---- which returns a string that answers the `$question` by leveraging the embeddings of the db vector. diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/qdrant.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/qdrant.adoc index bd671cd3b5..df9d0fbc3a 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/qdrant.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/qdrant.adoc @@ -1,4 +1,3 @@ - == Qdrant Here is a list of all available Qdrant procedures, @@ -218,7 +217,15 @@ For example, by executing a `CALL apoc.vectordb.qdrant.query(...) YIELD metadata so that we do not return the other values that we do not need. ==== +It is possible to execute vector db procedures together with the xref::ml/rag.adoc[apoc.ml.rag] as follow: +[source,cypher] +---- +CALL apoc.vectordb.qdrant.getAndUpdate($host, $collection, [, ], $conf) YIELD node, metadata, id, vector +WITH collect(node) as paths +CALL apoc.ml.rag(paths, $attributes, $question, $confPrompt) YIELD value +RETURN value +---- which returns a string that answers the `$question` by leveraging the embeddings of the db vector. diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/weaviate.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/weaviate.adoc index f88f839be2..66e1b83d73 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/weaviate.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/weaviate.adoc @@ -1,4 +1,3 @@ - == Weaviate Here is a list of all available Weaviate procedures, @@ -235,7 +234,15 @@ For example, by executing a `CALL apoc.vectordb.weaviate.query(...) YIELD metada so that we do not return the other values that we do not need. ==== +It is possible to execute vector db procedures together with the xref::ml/rag.adoc[apoc.ml.rag] as follow: +[source,cypher] +---- +CALL apoc.vectordb.weaviate.getAndUpdate($host, $collection, [, ], $conf) YIELD score, node, metadata, id, vector +WITH collect(node) as paths +CALL apoc.ml.rag(paths, $attributes, $question, $confPrompt) YIELD value +RETURN value +---- which returns a string that answers the `$question` by leveraging the embeddings of the db vector. diff --git a/full-it/src/test/java/apoc/full/it/vectordb/ChromaDbTest.java b/full-it/src/test/java/apoc/full/it/vectordb/ChromaDbTest.java index 4c60de3ecd..9221b0112f 100644 --- a/full-it/src/test/java/apoc/full/it/vectordb/ChromaDbTest.java +++ b/full-it/src/test/java/apoc/full/it/vectordb/ChromaDbTest.java @@ -1,5 +1,6 @@ package apoc.full.it.vectordb; +import static apoc.ml.Prompt.API_KEY_CONF; import static apoc.util.MapUtil.map; import static apoc.util.TestUtil.testCall; import static apoc.util.TestUtil.testResult; @@ -10,6 +11,7 @@ import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated; import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated; import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll; +import static apoc.vectordb.VectorDbTestUtil.ragSetup; import static apoc.vectordb.VectorDbUtil.ERROR_READONLY_MAPPING; import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; @@ -21,9 +23,11 @@ import static org.neo4j.configuration.GraphDatabaseSettings.DEFAULT_DATABASE_NAME; import static org.neo4j.configuration.GraphDatabaseSettings.SYSTEM_DATABASE_NAME; +import apoc.ml.Prompt; import apoc.util.TestUtil; import apoc.vectordb.ChromaDb; import apoc.vectordb.VectorDb; +import apoc.vectordb.VectorDbTestUtil; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicReference; @@ -60,9 +64,9 @@ public static void setUp() throws Exception { sysDb = databaseManagementService.database(SYSTEM_DATABASE_NAME); CHROMA_CONTAINER.start(); - HOST = CHROMA_CONTAINER.getEndpoint(); - TestUtil.registerProcedure(db, ChromaDb.class, VectorDb.class); + HOST = CHROMA_CONTAINER.getEndpoint(); + TestUtil.registerProcedure(db, ChromaDb.class, VectorDb.class, Prompt.class); testCall( db, @@ -452,4 +456,26 @@ public void queryVectorsWithSystemDbStorage() { assertNodesCreated(db); } + + @Test + public void queryVectorsWithRag() { + String openAIKey = ragSetup(db); + + Map conf = map( + ALL_RESULTS_KEY, true, MAPPING_KEY, map(NODE_LABEL, "Rag", ENTITY_KEY, "readID", METADATA_KEY, "foo")); + + testResult( + db, + "CALL apoc.vectordb.chroma.getAndUpdate($host, $collection, ['1', '2'], $conf) YIELD node, metadata, id, vector\n" + + "WITH collect(node) as paths\n" + + "CALL apoc.ml.rag(paths, $attributes, \"Which city has foo equals to one?\", $confPrompt) YIELD value\n" + + "RETURN value", + map( + "host", HOST, + "conf", conf, + "collection", COLL_ID.get(), + "confPrompt", map(API_KEY_CONF, openAIKey), + "attributes", List.of("city", "foo")), + VectorDbTestUtil::assertRagWithVectors); + } } diff --git a/full-it/src/test/java/apoc/full/it/vectordb/QdrantTest.java b/full-it/src/test/java/apoc/full/it/vectordb/QdrantTest.java index 9ec5890ad3..1c1370dd57 100644 --- a/full-it/src/test/java/apoc/full/it/vectordb/QdrantTest.java +++ b/full-it/src/test/java/apoc/full/it/vectordb/QdrantTest.java @@ -1,5 +1,6 @@ package apoc.full.it.vectordb; +import static apoc.ml.Prompt.API_KEY_CONF; import static apoc.ml.RestAPIConfig.HEADERS_KEY; import static apoc.util.MapUtil.map; import static apoc.util.TestUtil.testCall; @@ -22,17 +23,21 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.neo4j.configuration.GraphDatabaseSettings.*; +import apoc.ml.Prompt; import apoc.util.TestUtil; import apoc.util.Util; import apoc.vectordb.Qdrant; import apoc.vectordb.VectorDb; import apoc.vectordb.VectorDbTestUtil; +import java.util.List; import java.util.Map; import org.assertj.core.api.Assertions; import org.junit.AfterClass; +import org.junit.Assume; import org.junit.Before; import org.junit.BeforeClass; import org.junit.ClassRule; @@ -72,9 +77,9 @@ public static void setUp() throws Exception { sysDb = databaseManagementService.database(SYSTEM_DATABASE_NAME); QDRANT_CONTAINER.start(); - HOST = QDRANT_CONTAINER.getHost() + ":" + QDRANT_CONTAINER.getMappedPort(6333); - TestUtil.registerProcedure(db, Qdrant.class, VectorDb.class); + HOST = QDRANT_CONTAINER.getHost() + ":" + QDRANT_CONTAINER.getMappedPort(6333); + TestUtil.registerProcedure(db, Qdrant.class, VectorDb.class, Prompt.class); testCall( db, @@ -203,6 +208,44 @@ public void queryVectors() { }); } + @Test + public void queryVectorsWithRag() { + String openAIKey = System.getenv("OPENAI_KEY"); + ; + Assume.assumeNotNull("No OPENAI_KEY environment configured", openAIKey); + + db.executeTransactionally("CREATE (:Rag {readID: 'one'}), (:Rag {readID: 'two'})"); + + Map conf = map( + ALL_RESULTS_KEY, + true, + HEADERS_KEY, + READONLY_AUTHORIZATION, + MAPPING_KEY, + map(NODE_LABEL, "Rag", ENTITY_KEY, "readID", METADATA_KEY, "foo")); + + testResult( + db, + "CALL apoc.vectordb.qdrant.getAndUpdate($host, 'test_collection', [1, 2], $conf) YIELD node, metadata, id, vector\n" + + "WITH collect(node) as paths\n" + + "CALL apoc.ml.rag(paths, $attributes, \"Which city has foo equals to one?\", $confPrompt) YIELD value\n" + + "RETURN value", + map( + "host", + HOST, + "conf", + conf, + "confPrompt", + map(API_KEY_CONF, openAIKey), + "attributes", + List.of("city", "foo")), + r -> { + Map row = r.next(); + Object value = row.get("value"); + assertTrue("The actual value is: " + value, value.toString().contains("Berlin")); + }); + } + @Test public void queryVectorsWithoutVectorResult() { testResult( diff --git a/full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java b/full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java index d29902f585..77e8fe10c9 100644 --- a/full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java +++ b/full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java @@ -1,5 +1,6 @@ package apoc.full.it.vectordb; +import static apoc.ml.Prompt.API_KEY_CONF; import static apoc.ml.RestAPIConfig.HEADERS_KEY; import static apoc.util.TestUtil.testCall; import static apoc.util.TestUtil.testCallEmpty; @@ -8,6 +9,16 @@ import static apoc.vectordb.VectorDbHandler.Type.WEAVIATE; import static apoc.vectordb.VectorDbTestUtil.*; import static apoc.vectordb.VectorDbTestUtil.EntityType.*; +import static apoc.vectordb.VectorDbTestUtil.EntityType.FALSE; +import static apoc.vectordb.VectorDbTestUtil.EntityType.NODE; +import static apoc.vectordb.VectorDbTestUtil.EntityType.REL; +import static apoc.vectordb.VectorDbTestUtil.assertBerlinResult; +import static apoc.vectordb.VectorDbTestUtil.assertLondonResult; +import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated; +import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated; +import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll; +import static apoc.vectordb.VectorDbTestUtil.getAuthHeader; +import static apoc.vectordb.VectorDbTestUtil.ragSetup; import static apoc.vectordb.VectorDbUtil.ERROR_READONLY_MAPPING; import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY; @@ -23,6 +34,7 @@ import static org.neo4j.configuration.GraphDatabaseSettings.DEFAULT_DATABASE_NAME; import static org.neo4j.configuration.GraphDatabaseSettings.SYSTEM_DATABASE_NAME; +import apoc.ml.Prompt; import apoc.util.MapUtil; import apoc.util.TestUtil; import apoc.vectordb.VectorDb; @@ -84,7 +96,7 @@ public static void setUp() throws Exception { WEAVIATE_CONTAINER.start(); HOST = WEAVIATE_CONTAINER.getHttpHostAddress(); - TestUtil.registerProcedure(db, Weaviate.class, VectorDb.class); + TestUtil.registerProcedure(db, Weaviate.class, VectorDb.class, Prompt.class); testCall( db, @@ -593,4 +605,33 @@ private static void assertQueryVectorsWithSystemDbStorage(String keyConfig, Stri }); assertNodesCreated(db); } + + @Test + public void queryVectorsWithRag() { + String openAIKey = ragSetup(db); + + Map conf = MapUtil.map( + FIELDS_KEY, + FIELDS, + ALL_RESULTS_KEY, + true, + HEADERS_KEY, + READONLY_AUTHORIZATION, + MAPPING_KEY, + MapUtil.map(EMBEDDING_KEY, "vect", NODE_LABEL, "Rag", ENTITY_KEY, "readID", METADATA_KEY, "foo")); + + testResult( + db, + "CALL apoc.vectordb.weaviate.getAndUpdate($host, 'TestCollection', [$id1], $conf) YIELD score, node, metadata, id, vector\n" + + "WITH collect(node) as paths\n" + + "CALL apoc.ml.rag(paths, $attributes, \"Which city has foo equals to one?\", $confPrompt) YIELD value\n" + + "RETURN value", + MapUtil.map( + "host", HOST, + "id1", ID_1, + "conf", conf, + "confPrompt", MapUtil.map(API_KEY_CONF, openAIKey), + "attributes", List.of("city", "foo")), + VectorDbTestUtil::assertRagWithVectors); + } } diff --git a/full/src/test/java/apoc/vectordb/VectorDbTestUtil.java b/full/src/test/java/apoc/vectordb/VectorDbTestUtil.java index d949adf64d..c430ecb19c 100644 --- a/full/src/test/java/apoc/vectordb/VectorDbTestUtil.java +++ b/full/src/test/java/apoc/vectordb/VectorDbTestUtil.java @@ -4,9 +4,12 @@ import static apoc.util.Util.map; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; +import apoc.util.MapUtil; import java.util.Map; +import org.junit.Assume; import org.neo4j.graphdb.Entity; import org.neo4j.graphdb.GraphDatabaseService; import org.neo4j.graphdb.ResourceIterator; @@ -89,4 +92,34 @@ private static void assertBerlinProperties(Map props) { public static Map getAuthHeader(String key) { return map("Authorization", "Bearer " + key); } + + public static void assertReadOnlyProcWithMappingResults(Result r, String node) { + Map row = r.next(); + Map props = ((Entity) row.get(node)).getAllProperties(); + assertEquals(MapUtil.map("readID", "one"), props); + assertNotNull(row.get("vector")); + assertNotNull(row.get("id")); + + row = r.next(); + props = ((Entity) row.get(node)).getAllProperties(); + assertEquals(MapUtil.map("readID", "two"), props); + assertNotNull(row.get("vector")); + assertNotNull(row.get("id")); + + assertFalse(r.hasNext()); + } + + public static void assertRagWithVectors(Result r) { + Map row = r.next(); + Object value = row.get("value"); + assertTrue("The actual value is: " + value, value.toString().contains("Berlin")); + } + + public static String ragSetup(GraphDatabaseService db) { + String openAIKey = System.getenv("OPENAI_KEY"); + ; + Assume.assumeNotNull("No OPENAI_KEY environment configured", openAIKey); + db.executeTransactionally("CREATE (:Rag {readID: 'one'}), (:Rag {readID: 'two'})"); + return openAIKey; + } }