Skip to content

Commit

Permalink
[NOID] Fixes #4232: The apoc.vectordb.configure(WEAVIATE', ..) proced…
Browse files Browse the repository at this point in the history
…ure should append /v1 to url (#4248) (#4276)
  • Loading branch information
vga91 authored Dec 6, 2024
1 parent 2b74d67 commit dc72f76
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 19 deletions.
58 changes: 40 additions & 18 deletions full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,24 @@ public void queryVectorsWithCreateRelWithoutVectorResult() {
public void queryVectorsWithSystemDbStorage() {
String keyConfig = "weaviate-config-foo";
String baseUrl = "http://" + HOST + "/v1";
assertQueryVectorsWithSystemDbStorage(keyConfig, baseUrl, false);
}

@Test
public void queryVectorsWithSystemDbStorageWithUrlWithoutVersion() {
String keyConfig = "weaviate-config-foo";
String baseUrl = "http://" + HOST;
assertQueryVectorsWithSystemDbStorage(keyConfig, baseUrl, false);
}

@Test
public void queryVectorsWithSystemDbStorageWithUrlV3Version() {
String keyConfig = "weaviate-config-foo";
String baseUrl = "http://" + HOST + "/v3";
assertQueryVectorsWithSystemDbStorage(keyConfig, baseUrl, true);
}

private static void assertQueryVectorsWithSystemDbStorage(String keyConfig, String baseUrl, boolean fails) {
Map<String, String> mapping =
map(EMBEDDING_KEY, "vect", NODE_LABEL, "Test", ENTITY_KEY, "myId", METADATA_KEY, "foo");
sysDb.executeTransactionally(
Expand All @@ -550,25 +568,29 @@ public void queryVectorsWithSystemDbStorage() {
"host", baseUrl,
"credentials", ADMIN_KEY,
"mapping", mapping)));

db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})");

testResult(
db,
"CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)",
map("host", keyConfig, "conf", map("fields", FIELDS, ALL_RESULTS_KEY, true)),
r -> {
Map<String, Object> row = r.next();
assertBerlinResult(row, ID_1, NODE);
assertNotNull(row.get("score"));
assertNotNull(row.get("vector"));

row = r.next();
assertLondonResult(row, ID_2, NODE);
assertNotNull(row.get("score"));
assertNotNull(row.get("vector"));
});

String query =
"CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)";
Map<String, Object> params = map("host", keyConfig, "conf", map(FIELDS_KEY, FIELDS, ALL_RESULTS_KEY, true));
if (fails) {
try {
testCall(db, query, params, r -> fail());
} catch (Exception e) {
String message = e.getMessage();
assertTrue(message.contains("java.io.FileNotFoundException"));
}
return;
}
testResult(db, query, params, r -> {
Map<String, Object> row = r.next();
assertBerlinResult(row, ID_1, NODE);
assertNotNull(row.get("score"));
assertNotNull(row.get("vector"));
row = r.next();
assertLondonResult(row, ID_2, NODE);
assertNotNull(row.get("score"));
assertNotNull(row.get("vector"));
});
assertNodesCreated(db);
}
}
3 changes: 2 additions & 1 deletion full/src/main/java/apoc/vectordb/VectorDb.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import static apoc.util.Util.listOfNumbersToFloatArray;
import static apoc.util.Util.setProperties;
import static apoc.vectordb.VectorDbUtil.*;
import static apoc.vectordb.VectorDbUtil.appendVersionUrlIfNeeded;
import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY;
import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY;

Expand Down Expand Up @@ -267,7 +268,7 @@ public void vectordb(
Node node = Util.mergeNode(transaction, label, null, Pair.of(SystemPropertyKeys.name.name(), configKey));

Map mapping = (Map) config.get("mapping");
String host = (String) config.get("host");
String host = appendVersionUrlIfNeeded(type, (String) config.get("host"));
Object credentials = config.get("credentials");

if (host != null) {
Expand Down
14 changes: 14 additions & 0 deletions full/src/main/java/apoc/vectordb/VectorDbUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,18 @@ public static void checkMappingConf(Map<String, Object> configuration, String pr
ERROR_READONLY_MAPPING + "\n" + "Try the equivalent procedure, which is the " + procName);
}
}

/**
* If the vectorDb is WEAVIATE and endpoint doesn't end with `/vN`, where N is a number,
* then add `/v1` to the endpoint
*/
public static String appendVersionUrlIfNeeded(VectorDbHandler.Type type, String host) {
if (VectorDbHandler.Type.WEAVIATE == type) {
String regex = ".*(/v\\d+)$";
if (!host.matches(regex)) {
host = host + "/v1";
}
}
return host;
}
}

0 comments on commit dc72f76

Please sign in to comment.