diff --git a/docs/_extension/gallery_directive.py b/docs/_extension/gallery_directive.py
index 1e058970..54692158 100644
--- a/docs/_extension/gallery_directive.py
+++ b/docs/_extension/gallery_directive.py
@@ -1,12 +1,12 @@
"""A directive to generate a gallery of images from structured data.
-Generating a gallery of images that are all the same size is a common
-pattern in documentation, and this can be cumbersome if the gallery is
-generated programmatically. This directive wraps this particular use-case
-in a helper-directive to generate it with a single YAML configuration file.
+Generating a gallery of images that are all the same size is a common pattern in
+documentation, and this can be cumbersome if the gallery is generated
+programmatically. This directive wraps this particular use-case in a
+helper-directive to generate it with a single YAML configuration file.
-It currently exists for maintainers of the pydata-sphinx-theme,
-but might be abstracted into a standalone package if it proves useful.
+It currently exists for maintainers of the pydata-sphinx-theme, but might be
+abstracted into a standalone package if it proves useful.
"""
from pathlib import Path
from typing import Any, Dict, List
diff --git a/docs/api/searchindex.rst b/docs/api/searchindex.rst
index f889489d..972dd251 100644
--- a/docs/api/searchindex.rst
+++ b/docs/api/searchindex.rst
@@ -13,18 +13,22 @@ SearchIndex
.. autosummary::
SearchIndex.__init__
+ SearchIndex.client
+ SearchIndex.name
+ SearchIndex.prefix
+ SearchIndex.key_separator
+ SearchIndex.storage_type
SearchIndex.from_yaml
SearchIndex.from_dict
SearchIndex.from_existing
+ SearchIndex.connect
+ SearchIndex.create
+ SearchIndex.load
SearchIndex.search
SearchIndex.query
- SearchIndex.create
SearchIndex.delete
- SearchIndex.load
- SearchIndex.client
- SearchIndex.connect
- SearchIndex.disconnect
SearchIndex.info
+ SearchIndex.disconnect
@@ -44,17 +48,22 @@ AsyncSearchIndex
.. autosummary::
AsyncSearchIndex.__init__
+ AsyncSearchIndex.client
+ AsyncSearchIndex.name
+ AsyncSearchIndex.prefix
+ AsyncSearchIndex.key_separator
+ AsyncSearchIndex.storage_type
AsyncSearchIndex.from_yaml
AsyncSearchIndex.from_dict
AsyncSearchIndex.from_existing
+ AsyncSearchIndex.connect
+ AsyncSearchIndex.create
+ AsyncSearchIndex.load
AsyncSearchIndex.search
AsyncSearchIndex.query
- AsyncSearchIndex.create
AsyncSearchIndex.delete
- AsyncSearchIndex.load
- AsyncSearchIndex.connect
- AsyncSearchIndex.disconnect
AsyncSearchIndex.info
+ AsyncSearchIndex.disconnect
diff --git a/docs/examples/openai_qna.ipynb b/docs/examples/openai_qna.ipynb
index 51e86c21..10141ade 100644
--- a/docs/examples/openai_qna.ipynb
+++ b/docs/examples/openai_qna.ipynb
@@ -46,7 +46,7 @@
"source": [
"# first we need to install a few things\n",
"\n",
- "!pip install pandas wget tenacity tiktoken openai"
+ "!pip install pandas wget tenacity tiktoken openai==0.28.1"
]
},
{
diff --git a/docs/user_guide/getting_started_01.ipynb b/docs/user_guide/getting_started_01.ipynb
index de670bfb..6cc0a47c 100644
--- a/docs/user_guide/getting_started_01.ipynb
+++ b/docs/user_guide/getting_started_01.ipynb
@@ -127,6 +127,8 @@
"index:\n",
" name: user_index\n",
" prefix: user\n",
+ " storage_type: hash\n",
+ " key_separator: ':'\n",
"\n",
"fields:\n",
" # define tag fields\n",
@@ -162,6 +164,8 @@
" \"index\": {\n",
" \"name\": \"user_index\",\n",
" \"prefix\": \"user\",\n",
+ " \"storage_type\": \"hash\",\n",
+ " \"key_separator\": \":\"\n",
" },\n",
" \"fields\": {\n",
" \"tag\": [{\"name\": \"credit_score\"}],\n",
@@ -217,8 +221,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "\u001b[32m16:03:01\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n",
- "\u001b[32m16:03:01\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_index\n"
+ "\u001b[32m22:49:46\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n",
+ "\u001b[32m22:49:46\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_index\n"
]
}
],
@@ -465,7 +469,7 @@
],
"source": [
"# create a new SearchIndex instance from an existing index\n",
- "existing_index = SearchIndex.from_existing(\"user_index\", \"redis://localhost:6379\")\n",
+ "existing_index = SearchIndex.from_existing(name=\"user_index\", redis_url=\"redis://localhost:6379\")\n",
"\n",
"# run the same query\n",
"results = existing_index.query(query)\n",
@@ -583,7 +587,10 @@
{
"data": {
"text/plain": [
- "{'index': {'name': 'user_index', 'prefix': 'user'},\n",
+ "{'index': {'name': 'user_index',\n",
+ " 'prefix': 'user',\n",
+ " 'storage_type': 'hash',\n",
+ " 'key_separator': ':'},\n",
" 'fields': {'tag': [{'name': 'credit_score'}],\n",
" 'text': [{'name': 'job'}],\n",
" 'numeric': [{'name': 'age'}],\n",
@@ -612,7 +619,10 @@
{
"data": {
"text/plain": [
- "{'index': {'name': 'user_index', 'prefix': 'user'},\n",
+ "{'index': {'name': 'user_index',\n",
+ " 'prefix': 'user',\n",
+ " 'storage_type': 'hash',\n",
+ " 'key_separator': ':'},\n",
" 'fields': {'tag': [{'name': 'credit_score'}, {'name': 'job'}],\n",
" 'text': [],\n",
" 'numeric': [{'name': 'age'}],\n",
@@ -725,7 +735,7 @@
"│ offsets_per_term_avg │ 0 │\n",
"│ records_per_doc_avg │ 4 │\n",
"│ sortable_values_size_mb │ 0 │\n",
- "│ total_indexing_time │ 0.59 │\n",
+ "│ total_indexing_time │ 1.738 │\n",
"│ total_inverted_index_blocks │ 7 │\n",
"│ vector_index_sz_mb │ 0.235603 │\n",
"╰─────────────────────────────┴─────────────╯\n"
diff --git a/docs/user_guide/hash_vs_json_05.ipynb b/docs/user_guide/hash_vs_json_05.ipynb
new file mode 100644
index 00000000..3e2daac3
--- /dev/null
+++ b/docs/user_guide/hash_vs_json_05.ipynb
@@ -0,0 +1,519 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Hash vs JSON Storage\n",
+ "\n",
+ "\n",
+ "Out of the box, Redis provides a [variety of data structures](https://redis.com/redis-enterprise/data-structures/) that can adapt to your domain specific applications and use cases.\n",
+ "In this notebook, we will demonstrate how to use RedisVL with both [Hash](https://redis.io/docs/data-types/hashes/) and [JSON](https://redis.io/docs/data-types/json/) data.\n",
+ "\n",
+ "\n",
+ "Before running this notebook, be sure to\n",
+ "1. Have installed ``redisvl`` and have that environment active for this notebook.\n",
+ "2. Have a running Redis Stack or Redis Enterprise instance with RediSearch > 2.4 activated.\n",
+ "\n",
+ "For example, you can run Redis Stack locally with Docker:\n",
+ "\n",
+ "```bash\n",
+ "docker run -d -p 6379:6379 -p 8001:8001 redis/redis-stack:latest\n",
+ "```\n",
+ "\n",
+ "Or create a [FREE Redis Enterprise instance.](https://redis.com/try-free).\n",
+ "\n",
+ "This example will assume a local Redis is running on port 6379 and RedisInsight at 8001."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# import necessary modules\n",
+ "import pickle\n",
+ "from jupyterutils import table_print, result_print\n",
+ "from redisvl.index import SearchIndex\n",
+ "\n",
+ "\n",
+ "# load in the example data and printing utils\n",
+ "data = pickle.load(open(\"hybrid_example_data.pkl\", \"rb\"))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+       "<table><tr><th>user</th><th>age</th><th>job</th><th>credit_score</th><th>office_location</th><th>user_embedding</th></tr><tr><td>john</td><td>18</td><td>engineer</td><td>high</td><td>-122.4194,37.7749</td><td>b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'</td></tr><tr><td>derrick</td><td>14</td><td>doctor</td><td>low</td><td>-122.4194,37.7749</td><td>b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'</td></tr><tr><td>nancy</td><td>94</td><td>doctor</td><td>high</td><td>-122.4194,37.7749</td><td>b'333?\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'</td></tr><tr><td>tyler</td><td>100</td><td>engineer</td><td>high</td><td>-122.0839,37.3861</td><td>b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc>\\x00\\x00\\x00?'</td></tr><tr><td>tim</td><td>12</td><td>dermatologist</td><td>high</td><td>-122.0839,37.3861</td><td>b'\\xcd\\xcc\\xcc>\\xcd\\xcc\\xcc>\\x00\\x00\\x00?'</td></tr><tr><td>taimur</td><td>15</td><td>CEO</td><td>low</td><td>-122.0839,37.3861</td><td>b'\\x9a\\x99\\x19?\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'</td></tr><tr><td>joe</td><td>35</td><td>dentist</td><td>medium</td><td>-122.0839,37.3861</td><td>b'fff?fff?\\xcd\\xcc\\xcc='</td></tr></table>"
+ ],
+ "text/plain": [
+       "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "table_print(data)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Hash or JSON -- how to choose?\n",
+ "Both storage options offer a variety of features and tradeoffs. Below we will work through a dummy dataset to learn when and how to use both."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Working with Hashes\n",
+ "Hashes in Redis are simple collections of field-value pairs. Think of it like a mutable single-level dictionary contains multiple \"rows\":\n",
+ "\n",
+ "\n",
+ "```python\n",
+ "{\n",
+ " \"model\": \"Deimos\",\n",
+ " \"brand\": \"Ergonom\",\n",
+ " \"type\": \"Enduro bikes\",\n",
+ " \"price\": 4972,\n",
+ "}\n",
+ "```\n",
+ "\n",
+ "Hashes are best suited for use cases with the following characteristics:\n",
+ "- Performance (speed) and storage space (memory consumption) are top concerns\n",
+ "- Data can be easily normalized and modeled as a single-level dict\n",
+ "\n",
+ "> Hashes are typically the default recommendation."
+ ]
+ },
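+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For intuition, a record like the one above maps onto a single redis-py `hset` call. Below is a minimal sketch, assuming a local Redis at the default port and a hypothetical `bike:1` key:\n",
+    "\n",
+    "```python\n",
+    "import redis\n",
+    "\n",
+    "r = redis.Redis.from_url(\"redis://localhost:6379\")\n",
+    "\n",
+    "# each field-value pair becomes a field on the hash stored at the key\n",
+    "r.hset(\"bike:1\", mapping={\n",
+    "    \"model\": \"Deimos\",\n",
+    "    \"brand\": \"Ergonom\",\n",
+    "    \"type\": \"Enduro bikes\",\n",
+    "    \"price\": 4972,\n",
+    "})\n",
+    "```"
+   ]
+  },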
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# define the hash index schema\n",
+ "hash_schema = {\n",
+ " \"index\": {\n",
+ " \"name\": \"user-hashes\",\n",
+ " \"storage_type\": \"hash\", # default setting\n",
+ " \"prefix\": \"hash\",\n",
+ " \"key_separator\": \":\",\n",
+ " },\n",
+ " \"fields\": {\n",
+ " \"tag\": [{\"name\": \"credit_score\"}, {\"name\": \"user\"}],\n",
+ " \"text\": [{\"name\": \"job\"}],\n",
+ " \"numeric\": [{\"name\": \"age\"}],\n",
+ " \"geo\": [{\"name\": \"office_location\"}],\n",
+ " \"vector\": [{\n",
+ " \"name\": \"user_embedding\",\n",
+ " \"dims\": 3,\n",
+ " \"distance_metric\": \"cosine\",\n",
+ " \"algorithm\": \"flat\",\n",
+ " \"datatype\": \"float32\"}\n",
+ " ]\n",
+ " },\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# construct a search index from the hash schema\n",
+ "hindex = SearchIndex.from_dict(hash_schema)\n",
+ "\n",
+ "# connect to local redis instance\n",
+ "hindex.connect(\"redis://localhost:6379\")\n",
+ "\n",
+ "# create the index (no data yet)\n",
+ "hindex.create(overwrite=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+       "<StorageType.HASH: 'hash'>"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# show the underlying storage type\n",
+ "hindex.storage_type"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Vectors as byte strings\n",
+ "One nuance when working with Hashes in Redis, is that all vectorized data must be passed as a byte string (for efficient storage, indexing, and processing). An example of that can be seen below:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'user': 'john',\n",
+ " 'age': 18,\n",
+ " 'job': 'engineer',\n",
+ " 'credit_score': 'high',\n",
+ " 'office_location': '-122.4194,37.7749',\n",
+ " 'user_embedding': b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'}"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# show a single entry from the data that will be loaded\n",
+ "data[0]"
+ ]
+ },
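+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "If your vectors start out as Python lists or numpy arrays, `numpy` can produce the required byte string format. A minimal sketch, assuming float32 vectors to match the schema above:\n",
+    "\n",
+    "```python\n",
+    "import numpy as np\n",
+    "\n",
+    "# serialize a float32 vector to the byte string format hashes require\n",
+    "vector = np.array([0.1, 0.1, 0.5], dtype=np.float32)\n",
+    "byte_vector = vector.tobytes()  # b'\\\\xcd\\\\xcc\\\\xcc=\\\\xcd\\\\xcc\\\\xcc=\\\\x00\\\\x00\\\\x00?'\n",
+    "```"
+   ]
+  },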
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# load hash data\n",
+ "hindex.load(data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Statistics:\n",
+ "╭─────────────────────────────┬─────────────╮\n",
+ "│ Stat Key │ Value │\n",
+ "├─────────────────────────────┼─────────────┤\n",
+ "│ num_docs │ 7 │\n",
+ "│ num_terms │ 6 │\n",
+ "│ max_doc_id │ 7 │\n",
+ "│ num_records │ 44 │\n",
+ "│ percent_indexed │ 1 │\n",
+ "│ hash_indexing_failures │ 0 │\n",
+ "│ number_of_uses │ 2 │\n",
+ "│ bytes_per_record_avg │ 3.40909 │\n",
+ "│ doc_table_size_mb │ 0.000700951 │\n",
+ "│ inverted_sz_mb │ 0.000143051 │\n",
+ "│ key_table_size_mb │ 0.000221252 │\n",
+ "│ offset_bits_per_record_avg │ 8 │\n",
+ "│ offset_vectors_sz_mb │ 8.58307e-06 │\n",
+ "│ offsets_per_term_avg │ 0.204545 │\n",
+ "│ records_per_doc_avg │ 6.28571 │\n",
+ "│ sortable_values_size_mb │ 0 │\n",
+ "│ total_indexing_time │ 0.335 │\n",
+ "│ total_inverted_index_blocks │ 18 │\n",
+ "│ vector_index_sz_mb │ 0.0202332 │\n",
+ "╰─────────────────────────────┴─────────────╯\n"
+ ]
+ }
+ ],
+ "source": [
+ "!rvl stats -i user-hashes"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Performing Queries\n",
+ "Once our index is created and data is loaded into the right format, we can run queries against the index with RedisVL:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "vector_distance | user | credit_score | age | job | office_location |
---|
0 | john | high | 18 | engineer | -122.4194,37.7749 |
0.109129190445 | tyler | high | 100 | engineer | -122.0839,37.3861 |
"
+ ],
+ "text/plain": [
+       "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from redisvl.query import VectorQuery\n",
+ "from redisvl.query.filter import Tag, Text, Num\n",
+ "\n",
+ "t = (Tag(\"credit_score\") == \"high\") & (Text(\"job\") % \"enginee*\") & (Num(\"age\") > 17)\n",
+ "\n",
+ "v = VectorQuery([0.1, 0.1, 0.5],\n",
+ " \"user_embedding\",\n",
+ " return_fields=[\"user\", \"credit_score\", \"age\", \"job\", \"office_location\"],\n",
+ " filter_expression=t)\n",
+ "\n",
+ "\n",
+ "results = hindex.query(v)\n",
+ "result_print(results)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Working with JSON\n",
+ "Redis also supports native **JSON** objects. These can be multi-level (nested) objects, with full JSONPath support for updating/retrieving sub elements:\n",
+ "\n",
+ "```python\n",
+ "{\n",
+ " \"name\": \"bike\",\n",
+ " \"metadata\": {\n",
+ " \"model\": \"Deimos\",\n",
+ " \"brand\": \"Ergonom\",\n",
+ " \"type\": \"Enduro bikes\",\n",
+ " \"price\": 4972,\n",
+ " }\n",
+ "}\n",
+ "```\n",
+ "\n",
+ "JSON is best suited for use cases with the following characteristics:\n",
+ "- Ease of use and data model flexibility are top concerns\n",
+ "- Application data is already native JSON\n",
+ "- Replacing another document storage/db solution"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Full JSON Path support\n",
+ "Because RedisJSON enables full path support, when creating an index schema, elements need to be indexed and selected by their path with the `name` param and aliased using the `as_name` param as shown below."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# define the json index schema\n",
+ "json_schema = {\n",
+ " \"index\": {\n",
+ " \"name\": \"user-json\",\n",
+ " \"storage_type\": \"json\", # updated storage_type option\n",
+ " \"prefix\": \"json\",\n",
+ " \"key_separator\": \":\",\n",
+ " },\n",
+ " \"fields\": {\n",
+ " \"tag\": [{\"name\": \"$.credit_score\", \"as_name\": \"credit_score\"}, {\"name\": \"$.user\", \"as_name\": \"user\"}],\n",
+ " \"text\": [{\"name\": \"$.job\", \"as_name\": \"job\"}],\n",
+ " \"numeric\": [{\"name\": \"$.age\", \"as_name\": \"age\"}],\n",
+ " \"geo\": [{\"name\": \"$.office_location\", \"as_name\": \"office_location\"}],\n",
+ " \"vector\": [{\n",
+ " \"name\": \"$.user_embedding\",\n",
+ " \"as_name\": \"user_embedding\",\n",
+ " \"dims\": 3,\n",
+ " \"distance_metric\": \"cosine\",\n",
+ " \"algorithm\": \"flat\",\n",
+ " \"datatype\": \"float32\"}\n",
+ " ]\n",
+ " },\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# construct a search index from the json schema\n",
+ "jindex = SearchIndex.from_dict(json_schema)\n",
+ "\n",
+ "# connect to local redis instance\n",
+ "jindex.connect(\"redis://localhost:6379\")\n",
+ "\n",
+ "# create the index (no data yet)\n",
+ "jindex.create(overwrite=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[32m22:50:47\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n",
+ "\u001b[32m22:50:47\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user-hashes\n",
+ "\u001b[32m22:50:47\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 2. user-json\n"
+ ]
+ }
+ ],
+ "source": [
+ "# note the multiple indices in the same database\n",
+ "!rvl index listall"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Vectors as float arrays\n",
+ "Vectorized data stored in JSON must be stored as a pure array (python list) of floats. We will modify our sample data to account for this below:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "json_data = data.copy()\n",
+ "\n",
+ "for d in json_data:\n",
+ " d['user_embedding'] = np.frombuffer(d['user_embedding'], dtype=np.float32).tolist()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'user': 'john',\n",
+ " 'age': 18,\n",
+ " 'job': 'engineer',\n",
+ " 'credit_score': 'high',\n",
+ " 'office_location': '-122.4194,37.7749',\n",
+ " 'user_embedding': [0.10000000149011612, 0.10000000149011612, 0.5]}"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# inspect a single JSON record\n",
+ "json_data[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "jindex.load(json_data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "vector_distance | user | credit_score | age | job | office_location |
---|
0 | john | high | 18 | engineer | -122.4194,37.7749 |
0.109129190445 | tyler | high | 100 | engineer | -122.0839,37.3861 |
"
+ ],
+ "text/plain": [
+       "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# we can now run the exact same query as above\n",
+ "result_print(jindex.query(v))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Cleanup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "hindex.delete()\n",
+ "jindex.delete()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.8.13 ('redisvl2')",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.12"
+ },
+ "orig_nbformat": 4,
+ "vscode": {
+ "interpreter": {
+ "hash": "9b1e6e9c2967143209c2f955cb869d1d3234f92dc4787f49f155f3abbdfb1316"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/user_guide/hybrid_queries_02.ipynb b/docs/user_guide/hybrid_queries_02.ipynb
index 9443cd53..3daa48ef 100644
--- a/docs/user_guide/hybrid_queries_02.ipynb
+++ b/docs/user_guide/hybrid_queries_02.ipynb
@@ -51,6 +51,8 @@
" \"index\": {\n",
" \"name\": \"user_index\",\n",
" \"prefix\": \"v1\",\n",
+ " \"storage_type\": \"hash\",\n",
+ " \"key_separator\": \":\"\n",
" },\n",
" \"fields\": {\n",
" \"tag\": [{\"name\": \"credit_score\"}],\n",
@@ -95,8 +97,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "\u001b[32m16:03:26\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n",
- "\u001b[32m16:03:26\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_index\n"
+ "\u001b[32m22:51:08\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n",
+ "\u001b[32m22:51:08\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_index\n"
]
}
],
@@ -731,7 +733,6 @@
"metadata": {},
"outputs": [],
"source": [
- "#\n",
"def make_filter(age=None, credit=None, job=None):\n",
" flexible_filter = (\n",
" (Num(\"age\") > age) &\n",
@@ -1107,10 +1108,10 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "{'id': 'v1:54b273392e4d4fa2af424caca095d2d4', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '==\\x00\\x00\\x00?'}\n",
- "{'id': 'v1:abdab0c48bed49bea9a79d9eb3f247fa', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n",
- "{'id': 'v1:81ea678467be4ca1bd8efaec27766d10', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n",
- "{'id': 'v1:44741013d4d5469dad4b95f70cedc0bb', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n"
+ "{'id': 'v1:13dbcb6b63e6416187a8c9ee1ab6eae7', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '==\\x00\\x00\\x00?'}\n",
+ "{'id': 'v1:02d544f7543a40c780dee81116dd5610', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n",
+ "{'id': 'v1:d521d5c1778842e98d8ad50d837a60a4', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n",
+ "{'id': 'v1:2efe1220f62a4f8fb94055de526ff8f6', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n"
]
}
],
diff --git a/docs/user_guide/index.md b/docs/user_guide/index.md
index d5d88ce7..beb46718 100644
--- a/docs/user_guide/index.md
+++ b/docs/user_guide/index.md
@@ -13,7 +13,8 @@ myst:
getting_started_01
hybrid_queries_02
-vectorizers_03
llmcache_03
+vectorizers_04
+hash_vs_json_05
```
diff --git a/docs/user_guide/llmcache_03.ipynb b/docs/user_guide/llmcache_03.ipynb
index 557b8d10..5d4d236b 100644
--- a/docs/user_guide/llmcache_03.ipynb
+++ b/docs/user_guide/llmcache_03.ipynb
@@ -20,15 +20,15 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import openai\n",
"import getpass\n",
- "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"False\"\n",
"\n",
+ "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"False\"\n",
"\n",
"api_key = os.getenv(\"OPENAI_API_KEY\") or getpass.getpass(\"Enter your OpenAI API key: \")\n",
"\n",
@@ -45,14 +45,14 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Paris\n"
+ "Paris.\n"
]
}
],
@@ -72,7 +72,7 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -80,12 +80,12 @@
"cache = SemanticCache(\n",
" redis_url=\"redis://localhost:6379\",\n",
" threshold=0.9, # semantic similarity threshold\n",
- " )"
+ ")"
]
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -116,7 +116,7 @@
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -125,7 +125,7 @@
"[]"
]
},
- "execution_count": 20,
+ "execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -137,7 +137,7 @@
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -147,7 +147,7 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -156,7 +156,7 @@
"['Paris']"
]
},
- "execution_count": 22,
+ "execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@@ -168,7 +168,7 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
@@ -177,7 +177,7 @@
"[]"
]
},
- "execution_count": 23,
+ "execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@@ -189,7 +189,7 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 9,
"metadata": {},
"outputs": [
{
@@ -198,7 +198,7 @@
"['Paris']"
]
},
- "execution_count": 24,
+ "execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@@ -211,7 +211,7 @@
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 10,
"metadata": {},
"outputs": [
{
@@ -220,7 +220,7 @@
"[]"
]
},
- "execution_count": 25,
+ "execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
@@ -232,7 +232,7 @@
},
{
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -250,7 +250,7 @@
},
{
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -266,14 +266,14 @@
},
{
"cell_type": "code",
- "execution_count": 28,
+ "execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Time taken without cache 0.8732700347900391\n"
+ "Time taken without cache 0.574105978012085\n"
]
}
],
@@ -287,15 +287,15 @@
},
{
"cell_type": "code",
- "execution_count": 29,
+ "execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Time Taken with cache: 0.04746699333190918\n",
- "Percentage of time saved: 94.56%\n"
+ "Time Taken with cache: 0.09868717193603516\n",
+ "Percentage of time saved: 82.81%\n"
]
}
],
@@ -309,7 +309,7 @@
},
{
"cell_type": "code",
- "execution_count": 30,
+ "execution_count": 15,
"metadata": {},
"outputs": [
{
@@ -337,9 +337,9 @@
"│ offsets_per_term_avg │ 0 │\n",
"│ records_per_doc_avg │ 2 │\n",
"│ sortable_values_size_mb │ 0 │\n",
- "│ total_indexing_time │ 0.211 │\n",
+ "│ total_indexing_time │ 0.087 │\n",
"│ total_inverted_index_blocks │ 11 │\n",
- "│ vector_index_sz_mb │ 3.00814 │\n",
+ "│ vector_index_sz_mb │ 3.0161 │\n",
"╰─────────────────────────────┴─────────────╯\n"
]
}
@@ -351,7 +351,7 @@
},
{
"cell_type": "code",
- "execution_count": 31,
+ "execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
@@ -376,7 +376,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.10"
+ "version": "3.9.12"
},
"orig_nbformat": 4
},
diff --git a/docs/user_guide/schema.yaml b/docs/user_guide/schema.yaml
index c96ac23f..ed6f9f9b 100644
--- a/docs/user_guide/schema.yaml
+++ b/docs/user_guide/schema.yaml
@@ -1,8 +1,8 @@
-
index:
name: providers
prefix: rvl
storage_type: hash
+ key_separator: ':'
fields:
text:
diff --git a/docs/user_guide/vectorizers_03.ipynb b/docs/user_guide/vectorizers_04.ipynb
similarity index 99%
rename from docs/user_guide/vectorizers_03.ipynb
rename to docs/user_guide/vectorizers_04.ipynb
index 183cbc58..4caf1964 100644
--- a/docs/user_guide/vectorizers_03.ipynb
+++ b/docs/user_guide/vectorizers_04.ipynb
@@ -271,6 +271,8 @@
"index:\n",
" name: providers\n",
" prefix: rvl\n",
+ " storage_type: hash\n",
+ " key_separator: ':'\n",
"\n",
"fields:\n",
" text:\n",
diff --git a/redisvl/cli/index.py b/redisvl/cli/index.py
index 3e795224..f1f5c16d 100644
--- a/redisvl/cli/index.py
+++ b/redisvl/cli/index.py
@@ -51,7 +51,7 @@ def __init__(self):
exit(0)
def create(self, args: Namespace):
- """Create an index
+ """Create an index.
Usage:
rvl index create -i <index_name> | -s <schema_path>
@@ -59,13 +59,13 @@ def create(self, args: Namespace):
if not args.schema:
logger.error("Schema must be provided to create an index")
index = SearchIndex.from_yaml(args.schema)
- url = create_redis_url(args)
- index.connect(url)
+ redis_url = create_redis_url(args)
+ index.connect(redis_url)
index.create()
logger.info("Index created successfully")
def info(self, args: Namespace):
- """Obtain information about an index
+ """Obtain information about an index.
Usage:
rvl index info -i <index_name> | -s <schema_path>
@@ -74,20 +74,20 @@ def info(self, args: Namespace):
_display_in_table(index.info(), output_format=args.format)
def listall(self, args: Namespace):
- """List all indices
+ """List all indices.
Usage:
rvl index listall
"""
- url = create_redis_url(args)
- conn = get_redis_connection(url)
+ redis_url = create_redis_url(args)
+ conn = get_redis_connection(redis_url)
indices = convert_bytes(conn.execute_command("FT._LIST"))
logger.info("Indices:")
for i, index in enumerate(indices):
logger.info(str(i + 1) + ". " + index)
def delete(self, args: Namespace, drop=False):
- """Delete an index
+ """Delete an index.
Usage:
rvl index delete -i <index_name> | -s <schema_path>
@@ -97,7 +97,7 @@ def delete(self, args: Namespace, drop=False):
logger.info("Index deleted successfully")
def destroy(self, args: Namespace):
- """Delete an index and the documents within it
+ """Delete an index and the documents within it.
Usage:
rvl index destroy -i <index_name> | -s <schema_path>
@@ -107,8 +107,8 @@ def destroy(self, args: Namespace):
def _connect_to_index(self, args: Namespace) -> SearchIndex:
# connect to redis
try:
- url = create_redis_url(args)
- conn = get_redis_connection(url=url)
+ redis_url = create_redis_url(args)
+ conn = get_redis_connection(url=redis_url)
except ValueError:
logger.error(
"Must set REDIS_URL environment variable or provide host and port"
@@ -116,7 +116,7 @@ def _connect_to_index(self, args: Namespace) -> SearchIndex:
exit(0)
if args.index:
- index = SearchIndex.from_existing(name=args.index, url=url)
+ index = SearchIndex.from_existing(name=args.index, redis_url=redis_url)
elif args.schema:
index = SearchIndex.from_yaml(args.schema)
index.set_client(conn)
diff --git a/redisvl/cli/log.py b/redisvl/cli/log.py
index 145a76c5..41d5fcb1 100644
--- a/redisvl/cli/log.py
+++ b/redisvl/cli/log.py
@@ -9,7 +9,7 @@
def get_logger(name, log_level="info", fmt=None):
- """Return a logger instance"""
+ """Return a logger instance."""
# Use file name if logger is in debug mode
name = "RedisVL" if log_level == "debug" else name
diff --git a/redisvl/cli/stats.py b/redisvl/cli/stats.py
index 9861f9e3..27a2b9de 100644
--- a/redisvl/cli/stats.py
+++ b/redisvl/cli/stats.py
@@ -56,7 +56,7 @@ def __init__(self):
exit(0)
def stats(self, args: Namespace):
- """Obtain stats about an index
+ """Obtain stats about an index.
Usage:
rvl stats -i <index_name> | -s <schema_path>
@@ -67,8 +67,8 @@ def stats(self, args: Namespace):
def _connect_to_index(self, args: Namespace) -> SearchIndex:
# connect to redis
try:
- url = create_redis_url(args)
- conn = get_redis_connection(url=url)
+ redis_url = create_redis_url(args)
+ conn = get_redis_connection(url=redis_url)
except ValueError:
logger.error(
"Must set REDIS_ADDRESS environment variable or provide host and port"
@@ -76,7 +76,7 @@ def _connect_to_index(self, args: Namespace) -> SearchIndex:
exit(0)
if args.index:
- index = SearchIndex.from_existing(name=args.index, url=url)
+ index = SearchIndex.from_existing(name=args.index, redis_url=redis_url)
elif args.schema:
index = SearchIndex.from_yaml(args.schema)
index.set_client(conn)
diff --git a/redisvl/index.py b/redisvl/index.py
index 61eef279..03109d01 100644
--- a/redisvl/index.py
+++ b/redisvl/index.py
@@ -1,91 +1,243 @@
-import asyncio
+import json
+from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Union
-from uuid import uuid4
if TYPE_CHECKING:
from redis.commands.search.field import Field
+ from redis.commands.search.document import Document
from redis.commands.search.result import Result
from redisvl.query.query import BaseQuery
import redis
-from redis.commands.search.indexDefinition import IndexDefinition, IndexType
-
-from redisvl.query.query import CountQuery
-from redisvl.schema import SchemaModel, read_schema
-from redisvl.utils.connection import (
- check_connected,
- get_async_redis_connection,
- get_redis_connection,
-)
+from redis.commands.search.indexDefinition import IndexDefinition
+
+from redisvl.query.query import BaseQuery, CountQuery, FilterQuery
+from redisvl.schema import SchemaModel, StorageType, read_schema
+from redisvl.storage import HashStorage, JsonStorage
+from redisvl.utils.connection import get_async_redis_connection, get_redis_connection
from redisvl.utils.utils import (
+ check_async_redis_modules_exist,
check_redis_modules_exist,
convert_bytes,
make_dict,
- process_results,
)
+def process_results(
+ results: "Result", query: BaseQuery, storage_type: StorageType
+) -> List[Dict[str, Any]]:
+ """Convert a list of search Result objects into a list of document
+ dictionaries.
+
+ This function processes results from Redis, handling different storage
+ types and query types. For JSON storage with empty return fields, it
+ unpacks the JSON object while retaining the document ID. The 'payload'
+ field is also removed from all resulting documents for consistency.
+
+ Args:
+ results (Result): The search results from Redis.
+ query (BaseQuery): The query object used for the search.
+ storage_type (StorageType): The storage type of the search
+ index (json or hash).
+
+ Returns:
+ List[Dict[str, Any]]: A list of processed document dictionaries.
+ """
+ # Handle count queries
+ if isinstance(query, CountQuery):
+ return results.total
+
+ # Determine if unpacking JSON is needed
+ unpack_json = (
+ (storage_type == StorageType.JSON)
+ and isinstance(query, FilterQuery)
+ and not query._return_fields
+ )
+
+ # Process records
+ def _process(doc: "Document") -> Dict[str, Any]:
+ doc_dict = doc.__dict__
+
+ # Unpack and Project JSON fields properly
+ if unpack_json and "json" in doc_dict:
+ json_data = doc_dict.get("json", {})
+ if isinstance(json_data, str):
+ json_data = json.loads(json_data)
+ if isinstance(json_data, dict):
+ return {"id": doc_dict.get("id"), **json_data}
+ raise ValueError(f"Unable to parse json data from Redis {json_data}")
+
+ # Remove 'payload' if present
+ doc_dict.pop("payload", None)
+
+ return doc_dict
+
+ return [_process(doc) for doc in results.docs]
+
+
+def check_modules_present(client_variable_name: str):
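+    """Decorator that verifies the required Redis modules are installed
+    on the client referenced by `client_variable_name` before calling
+    the wrapped method."""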
+ def decorator(func):
+ @wraps(func)
+ def wrapper(self, *args, **kwargs):
+ client = getattr(self, client_variable_name)
+ check_redis_modules_exist(client)
+ return func(self, *args, **kwargs)
+
+ return wrapper
+
+ return decorator
+
+
+def check_async_modules_present(client_variable_name: str):
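+    """Async variant of `check_modules_present`: awaits the Redis module
+    check before calling the wrapped coroutine."""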
+ def decorator(func):
+ @wraps(func)
+ async def wrapper(self, *args, **kwargs):
+ client = getattr(self, client_variable_name)
+ await check_async_redis_modules_exist(client)
+ return await func(self, *args, **kwargs)
+
+ return wrapper
+
+ return decorator
+
+
+def check_index_exists():
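+    """Decorator that raises a ValueError if the index does not yet exist
+    in Redis when the wrapped method is called."""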
+ def decorator(func):
+ @wraps(func)
+ def wrapper(self, *args, **kwargs):
+ if not self.exists():
+ raise ValueError(
+ f"Index has not been created. Must be created before calling {func.__name__}"
+ )
+ return func(self, *args, **kwargs)
+
+ return wrapper
+
+ return decorator
+
+
+def check_async_index_exists():
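+    """Async variant of `check_index_exists`: awaits the existence check
+    before calling the wrapped coroutine."""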
+ def decorator(func):
+ @wraps(func)
+ async def wrapper(self, *args, **kwargs):
+ if not await self.exists():
+ raise ValueError(
+ f"Index has not been created. Must be created before calling {func.__name__}"
+ )
+ return await func(self, *args, **kwargs)
+
+ return wrapper
+
+ return decorator
+
+
+def check_connected(client_variable_name: str):
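+    """Decorator that raises a ValueError if the client referenced by
+    `client_variable_name` is not set, i.e. connect() was never called."""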
+ def decorator(func):
+ @wraps(func)
+ def wrapper(self, *args, **kwargs):
+ if getattr(self, client_variable_name) is None:
+ raise ValueError(
+ f"SearchIndex.connect() must be called before calling {func.__name__}"
+ )
+ return func(self, *args, **kwargs)
+
+ return wrapper
+
+ return decorator
+
+
+def check_async_connected(client_variable_name: str):
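+    """Async variant of `check_connected` for coroutine methods."""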
+ def decorator(func):
+ @wraps(func)
+ async def wrapper(self, *args, **kwargs):
+ if getattr(self, client_variable_name) is None:
+ raise ValueError(
+ f"SearchIndex.connect() must be called before calling {func.__name__}"
+ )
+ return await func(self, *args, **kwargs)
+
+ return wrapper
+
+ return decorator
+
+
class SearchIndexBase:
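+    # map each supported storage type to its concrete storage implementation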
+ STORAGE_MAP = {
+ StorageType.HASH: HashStorage,
+ StorageType.JSON: JsonStorage,
+ }
+
def __init__(
self,
name: str,
prefix: str = "rvl",
storage_type: str = "hash",
+ key_separator: str = ":",
fields: Optional[List["Field"]] = None,
+ **kwargs,
):
+ """Initialize the RedisVL search index class.
+
+ Args:
+ name (str): Index name.
+ prefix (str, optional): Key prefix associated with the index.
+ Defaults to "rvl".
+ storage_type (str, optional): Underlying Redis storage type (hash
+ or json). Defaults to "hash".
+            key_separator (str, optional): Separator character to combine
+ prefix and key value for constructing redis keys.
+ Defaults to ":".
+ fields (Optional[List[Field]], optional): List of Redis fields to
+ index. Defaults to None.
+ """
self._name = name
self._prefix = prefix
- self._storage = storage_type
+ self._key_separator = key_separator
+ self._storage_type = StorageType(storage_type)
self._fields = fields
+
+ # configure storage layer
+ self._storage = self.STORAGE_MAP[self._storage_type]( # type: ignore
+ self._prefix, self._key_separator
+ )
+
+ # init empty redis conn
self._redis_conn: Optional[redis.Redis] = None
+ if "redis_url" in kwargs:
+ redis_url = kwargs.pop("redis_url")
+ self.connect(redis_url, **kwargs)
- def set_client(self, client: redis.Redis):
+ def set_client(self, client: redis.Redis) -> None:
+ """Set the Redis client object for the search index."""
self._redis_conn = client
@property
- @check_connected("_redis_conn")
- def client(self) -> redis.Redis:
- """The redis-py client object.
-
- Returns:
- redis.Redis: The redis-py client object
- """
- return self._redis_conn # type: ignore
+ def name(self) -> str:
+ """The name of the Redis search index."""
+ return self._name
- @check_connected("_redis_conn")
- def search(self, *args, **kwargs) -> Union["Result", Any]:
- """Perform a search on this index.
+ @property
+ def prefix(self) -> str:
+ """The optional key prefix that comes before a unique key value in
+ forming a Redis key."""
+ return self._prefix
- Wrapper around redis.search.Search that adds the index name
- to the search query and passes along the rest of the arguments
- to the redis-py ft.search() method.
+ @property
+ def key_separator(self) -> str:
+ """The optional separator between a defined prefix and key value in
+ forming a Redis key."""
+ return self._key_separator
- Returns:
- Union["Result", Any]: Search results.
- """
- results = self._redis_conn.ft(self._name).search( # type: ignore
- *args, **kwargs
- )
- return results
+ @property
+ def storage_type(self) -> StorageType:
+ """The underlying storage type for the search index: hash or json."""
+ return self._storage_type
+ @property
@check_connected("_redis_conn")
- def query(self, query: "BaseQuery") -> List[Dict[str, Any]]:
- """Run a query on this index.
-
- This is similar to the search method, but takes a BaseQuery
- object directly (does not allow for the usage of a raw
- redis query string) and post-processes results of the search.
-
- Args:
- query (BaseQuery): The query to run.
-
- Returns:
- List[Result]: A list of search results.
- """
- results = self.search(query.query, query_params=query.params)
- if isinstance(query, CountQuery):
- return results.total
- return process_results(results)
+ def client(self) -> redis.Redis:
+ """The underlying redis-py client object."""
+ return self._redis_conn # type: ignore
@classmethod
def from_yaml(cls, schema_path: str):
@@ -94,6 +246,12 @@ def from_yaml(cls, schema_path: str):
Args:
schema_path (str): Path to the YAML schema file.
+ Example:
+ >>> from redisvl.index import SearchIndex
+ >>> index = SearchIndex.from_yaml("schema.yaml")
+ >>> index.connect("redis://localhost:6379")
+ >>> index.create(overwrite=True)
+
Returns:
SearchIndex: A SearchIndex object.
"""
@@ -107,6 +265,22 @@ def from_dict(cls, schema_dict: Dict[str, Any]):
Args:
schema_dict (Dict[str, Any]): A dictionary containing the schema.
+ Example:
+ >>> from redisvl.index import SearchIndex
+ >>> index = SearchIndex.from_dict({
+ >>> "index": {
+ >>> "name": "my-index",
+ >>> "prefix": "rvl",
+ >>> "storage_type": "hash",
+ >>> "key_separator": ":"
+ >>> },
+ >>> "fields": {
+ >>> "tag": [{"name": "doc-id"}]
+ >>> }
+ >>> })
+ >>> index.connect("redis://localhost:6379")
+ >>> index.create(overwrite=True)
+
Returns:
SearchIndex: A SearchIndex object.
"""
@@ -117,132 +291,77 @@ def from_dict(cls, schema_dict: Dict[str, Any]):
def from_existing(
cls,
name: str,
- url: Optional[str] = None,
+ redis_url: Optional[str] = None,
+ key_separator: str = ":",
fields: Optional[List["Field"]] = None,
**kwargs,
):
- """Create a SearchIndex from an existing index in Redis.
-
- Args:
- name (str): Index name.
- url (Optional[str], optional): Redis URL. REDIS_URL env var
- is used if not provided. Defaults to None.
- fields (Optional[List[Field]], optional): List of Redis search
- fields to include in the schema. Defaults to None.
-
- Returns:
- SearchIndex: A SearchIndex object.
+ raise NotImplementedError
- Raises:
- redis.exceptions.ResponseError: If the index does not exist.
- ValueError: If the REDIS_URL env var is not set and url is not provided.
- """
+ @check_connected("_redis_conn")
+ @check_modules_present("_redis_conn")
+ @check_index_exists()
+ def search(self, *args, **kwargs) -> Union["Result", Any]:
raise NotImplementedError
- def connect(self, url: str, **kwargs):
- """Connect to a Redis instance.
+ @check_connected("_redis_conn")
+ @check_modules_present("_redis_conn")
+ @check_index_exists()
+ def query(self, query: "BaseQuery") -> List[Dict[str, Any]]:
+ raise NotImplementedError
- Args:
- url (str): Redis URL. REDIS_URL env var is used if not provided.
- """
+ def connect(self, redis_url: Optional[str] = None, **kwargs):
+ """Connect to a Redis instance."""
raise NotImplementedError
def disconnect(self):
- """Disconnect from the Redis instance"""
+ """Disconnect from the Redis instance."""
self._redis_conn = None
return self
def key(self, key_value: str) -> str:
- """
- Create a redis key as a combination of an index key prefix (optional) and specified key value.
- The key value is typically a unique identifier, created at random, or derived from
- some specified metadata.
+ """Create a redis key as a combination of an index key prefix (optional)
+ and specified key value. The key value is typically a unique identifier,
+ created at random, or derived from some specified metadata.
Args:
- key_value (str): The specified unique identifier for a particular document
- indexed in Redis.
+ key_value (str): The specified unique identifier for a particular
+ document indexed in Redis.
Returns:
str: The full Redis key including key prefix and value as a string.
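+
+        Example:
+            >>> # with the default prefix "rvl" and key separator ":"
+            >>> index.key("doc1")
+            'rvl:doc1'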
"""
- return f"{self._prefix}:{key_value}" if self._prefix else key_value
-
- def _create_key(
- self, record: Dict[str, Any], key_field: Optional[str] = None
- ) -> str:
- """Construct the Redis HASH top level key.
-
- Args:
- record (Dict[str, Any]): A dictionary containing the record to be indexed.
- key_field (Optional[str], optional): A field within the record
- to use in the Redis hash key.
-
- Returns:
- str: The key to be used for a given record in Redis.
-
- Raises:
- ValueError: If the key field is not found in the record.
- """
- if key_field is None:
- key_value = uuid4().hex
- else:
- try:
- key_value = record[key_field] # type: ignore
- except KeyError:
- raise ValueError(f"Key field {key_field} not found in record {record}")
- return self.key(key_value)
+ return self._storage._key(key_value, self._prefix, self._key_separator)
@check_connected("_redis_conn")
+ @check_modules_present("_redis_conn")
+ @check_index_exists()
def info(self) -> Dict[str, Any]:
- """Get information about the index.
-
- Returns:
- dict: A dictionary containing the information about the index.
- """
- return convert_bytes(self._redis_conn.ft(self._name).info()) # type: ignore
-
- def create(self, overwrite: Optional[bool] = False):
- """Create an index in Redis from this SearchIndex object.
-
- Args:
- overwrite (bool, optional): Overwrite the index if it already exists. Defaults to False.
+ raise NotImplementedError
- Raises:
- redis.exceptions.ResponseError: If the index already exists.
- """
+ @check_connected("_redis_conn")
+ @check_modules_present("_redis_conn")
+ def create(self, overwrite: bool = False):
raise NotImplementedError
+ @check_connected("_redis_conn")
+ @check_modules_present("_redis_conn")
+ @check_index_exists()
def delete(self, drop: bool = True):
- """Delete the search index.
-
- Args:
- drop (bool, optional): Delete the documents in the index. Defaults to True.
-
- Raises:
- redis.exceptions.ResponseError: If the index does not exist.
- """
raise NotImplementedError
+ @check_connected("_redis_conn")
+ @check_modules_present("_redis_conn")
def load(
self,
- data: Iterable[Dict[str, Any]],
+ data: Iterable[Any],
key_field: Optional[str] = None,
+ keys: Optional[Iterable[str]] = None,
+ ttl: Optional[int] = None,
preprocess: Optional[Callable] = None,
+ concurrency: Optional[int] = None,
**kwargs,
):
- """Load data into Redis and index using this SearchIndex object.
-
- Args:
- data (Iterable[Dict[str, Any]]): An iterable of dictionaries
- containing the data to be indexed.
- key_field (Optional[str], optional): A field within the record
- to use in the Redis hash key.
- preprocess (Optional[Callabl], optional): An optional preprocessor function
- that mutates the individual record before writing to redis.
-
- Raises:
- redis.exceptions.ResponseError: If the index does not exist.
- """
raise NotImplementedError
@@ -255,24 +374,17 @@ class SearchIndex(SearchIndexBase):
Example:
>>> from redisvl.index import SearchIndex
>>> index = SearchIndex.from_yaml("schema.yaml")
+ >>> index.connect("redis://localhost:6379")
>>> index.create(overwrite=True)
>>> index.load(data) # data is an iterable of dictionaries
"""
- def __init__(
- self,
- name: str,
- prefix: str = "rvl",
- storage_type: str = "hash",
- fields: Optional[List["Field"]] = None,
- ):
- super().__init__(name, prefix, storage_type, fields)
-
@classmethod
def from_existing(
cls,
name: str,
- url: Optional[str] = None,
+ redis_url: Optional[str] = None,
+ key_separator: str = ":",
fields: Optional[List["Field"]] = None,
**kwargs,
):
@@ -280,8 +392,10 @@ def from_existing(
Args:
name (str): Index name.
- url (Optional[str], optional): Redis URL. REDIS_URL env var
+ redis_url (Optional[str], optional): Redis URL. REDIS_URL env var
is used if not provided. Defaults to None.
+ key_separator (str, optional): Separator char to combine prefix and
+ key value for constructing redis keys. Defaults to ":".
fields (Optional[List[Field]], optional): List of Redis search
fields to include in the schema. Defaults to None.
@@ -290,10 +404,9 @@ def from_existing(
Raises:
redis.exceptions.ResponseError: If the index does not exist.
- ValueError: If the REDIS_URL env var is not set and url is not provided.
-
+ ValueError: If the redis url is not accessible.
"""
- client = get_redis_connection(url, **kwargs)
+ client = get_redis_connection(redis_url, **kwargs)
info = convert_bytes(client.ft(name).info())
index_definition = make_dict(info["index_definition"])
storage_type = index_definition["key_type"].lower()
@@ -302,36 +415,40 @@ def from_existing(
name=name,
storage_type=storage_type,
prefix=prefix,
+ key_separator=key_separator,
fields=fields,
)
instance.set_client(client)
return instance
- def connect(self, url: Optional[str] = None, **kwargs):
+ def connect(self, redis_url: Optional[str] = None, **kwargs):
"""Connect to a Redis instance.
Args:
- url (str): Redis URL. REDIS_URL env var is used if not provided.
+ redis_url (Optional[str], optional): Redis URL. REDIS_URL env var is
+ used if not provided.
Raises:
redis.exceptions.ConnectionError: If the connection to Redis fails.
- ValueError: If the REDIS_URL env var is not set and url is not provided.
+ ValueError: If the redis url is not accessible.
"""
- self._redis_conn = get_redis_connection(url, **kwargs)
+ self._redis_conn = get_redis_connection(redis_url, **kwargs)
return self
@check_connected("_redis_conn")
- def create(self, overwrite: Optional[bool] = False):
+ @check_modules_present("_redis_conn")
+ def create(self, overwrite: bool = False) -> None:
"""Create an index in Redis from this SearchIndex object.
Args:
- overwrite (bool, optional): Overwrite the index if it already exists. Defaults to False.
+ overwrite (bool, optional): Whether to overwrite the index if it
+ already exists. Defaults to False.
Raises:
- redis.exceptions.ResponseError: If the index already exists.
+ RuntimeError: If the index already exists and 'overwrite' is False.
+ ValueError: If no fields are defined for the index.
"""
- check_redis_modules_exist(self._redis_conn)
-
+ # Check that fields are defined.
if not self._fields:
raise ValueError("No fields defined for index")
if not isinstance(overwrite, bool):
@@ -344,25 +461,23 @@ def create(self, overwrite: Optional[bool] = False):
print("Index already exists, overwriting.")
self.delete()
- # set storage_type, default to hash
- storage_type = IndexType.HASH
- # TODO - enable JSON support
- # if self._storage.lower() == "json":
- # storage_type = IndexType.JSON
-
- # Create Index
- # will raise correct response error if index already exists
+ # Create the index with the specified fields and settings.
self._redis_conn.ft(self._name).create_index( # type: ignore
fields=self._fields,
- definition=IndexDefinition(prefix=[self._prefix], index_type=storage_type),
+ definition=IndexDefinition(
+ prefix=[self._prefix], index_type=self._storage.type
+ ),
)
@check_connected("_redis_conn")
+ @check_modules_present("_redis_conn")
+ @check_index_exists()
def delete(self, drop: bool = True):
"""Delete the search index.
Args:
- drop (bool, optional): Delete the documents in the index. Defaults to True.
+ drop (bool, optional): Delete the documents in the index.
+ Defaults to True.
raises:
redis.exceptions.ResponseError: If the index does not exist.
@@ -371,61 +486,94 @@ def delete(self, drop: bool = True):
self._redis_conn.ft(self._name).dropindex(delete_documents=drop) # type: ignore
@check_connected("_redis_conn")
+ @check_modules_present("_redis_conn")
def load(
self,
- data: Iterable[Dict[str, Any]],
+ data: Iterable[Any],
key_field: Optional[str] = None,
+ keys: Optional[Iterable[str]] = None,
+ ttl: Optional[int] = None,
preprocess: Optional[Callable] = None,
+ batch_size: Optional[int] = None,
**kwargs,
):
- """Load data into Redis and index using this SearchIndex object.
+ """Load a batch of objects to Redis.
Args:
- data (Iterable[Dict[str, Any]]): An iterable of dictionaries
- containing the data to be indexed.
- key_field (Optional[str], optional): A field within the record to
- use in the Redis hash key.
- preprocess (Optional[Callable], optional): An optional preprocessor function
- that mutates the individual record before writing to redis.
+ data (Iterable[Any]): An iterable of objects to store.
+ key_field (Optional[str], optional): Field used as the key for each
+ object. Defaults to None.
+ keys (Optional[Iterable[str]], optional): Optional iterable of keys.
+ Must match the length of objects if provided. Defaults to None.
+ ttl (Optional[int], optional): Time-to-live in seconds for each key.
+ Defaults to None.
+ preprocess (Optional[Callable], optional): A function to preprocess
+ objects before storage. Defaults to None.
+ batch_size (Optional[int], optional): Number of objects to write in
+ a single Redis pipeline execution. Defaults to class's
+ default batch size.
- raises:
- redis.exceptions.ResponseError: If the index does not exist.
+ Raises:
+ ValueError: If the length of provided keys does not match the length
+ of objects.
Example:
>>> data = [{"foo": "bar"}, {"test": "values"}]
- >>> def func(record: dict): record["new"]="value";return record
+ >>> def func(record: dict):
+ >>> record["new"] = "value"
+ >>> return record
>>> index.load(data, preprocess=func)
"""
- # TODO -- should we return a count of the upserts? or some kind of metadata?
- if data:
- if not isinstance(data, Iterable):
- if not isinstance(data[0], dict):
- raise TypeError("data must be an iterable of dictionaries")
-
- # Check if outer interface passes in TTL on load
- ttl = kwargs.get("ttl")
- with self._redis_conn.pipeline(transaction=False) as pipe: # type: ignore
- for record in data:
- key = self._create_key(record, key_field)
- # Optionally preprocess the record and validate type
- if preprocess:
- try:
- record = preprocess(record)
- except Exception as e:
- raise RuntimeError(
- "Error while preprocessing records on load"
- ) from e
- if not isinstance(record, dict):
- raise TypeError(
- f"Individual records must be of type dict, got type {type(record)}"
- )
- # Write the record to Redis
- pipe.hset(key, mapping=record) # type: ignore
- if ttl:
- pipe.expire(key, ttl)
- pipe.execute()
+ self._storage.write(
+ self.client,
+ objects=data,
+ key_field=key_field,
+ keys=keys,
+ ttl=ttl,
+ preprocess=preprocess,
+ batch_size=batch_size,
+ )
@check_connected("_redis_conn")
+ @check_modules_present("_redis_conn")
+ @check_index_exists()
+ def search(self, *args, **kwargs) -> Union["Result", Any]:
+ """Perform a search on this index.
+
+ Wrapper around redis.search.Search that adds the index name
+ to the search query and passes along the rest of the arguments
+ to the redis-py ft.search() method.
+
+ Returns:
+ Union["Result", Any]: Search results.
+ """
+ results = self._redis_conn.ft(self._name).search( # type: ignore
+ *args, **kwargs
+ )
+ return results
+
+ @check_connected("_redis_conn")
+ @check_modules_present("_redis_conn")
+ @check_index_exists()
+ def query(self, query: "BaseQuery") -> List[Dict[str, Any]]:
+ """Run a query on this index.
+
+ This is similar to the search method, but takes a BaseQuery
+ object directly (does not allow for the usage of a raw
+ redis query string) and post-processes results of the search.
+
+ Args:
+ query (BaseQuery): The query to run.
+
+ Returns:
+ List[Result]: A list of search results.
+ """
+ results = self.search(query.query, query_params=query.params)
+ # post process the results
+ return process_results(results, query=query, storage_type=self._storage_type)
+
+ @check_connected("_redis_conn")
+ @check_modules_present("_redis_conn")
def exists(self) -> bool:
"""Check if the index exists in Redis.
@@ -435,6 +583,17 @@ def exists(self) -> bool:
indices = convert_bytes(self._redis_conn.execute_command("FT._LIST")) # type: ignore
return self._name in indices
+ @check_connected("_redis_conn")
+ @check_modules_present("_redis_conn")
+ @check_index_exists()
+ def info(self) -> Dict[str, Any]:
+ """Get information about the index.
+
+ Returns:
+ dict: A dictionary containing the information about the index.
+ """
+ return convert_bytes(self._redis_conn.ft(self._name).info()) # type: ignore
+
class AsyncSearchIndex(SearchIndexBase):
"""A class for interacting with Redis as a vector database asynchronously.
@@ -445,24 +604,17 @@ class AsyncSearchIndex(SearchIndexBase):
Example:
>>> from redisvl.index import AsyncSearchIndex
>>> index = AsyncSearchIndex.from_yaml("schema.yaml")
+ >>> index.connect("redis://localhost:6379")
>>> await index.create(overwrite=True)
>>> await index.load(data) # data is an iterable of dictionaries
"""
- def __init__(
- self,
- name: str,
- prefix: str = "rvl",
- storage_type: str = "hash",
- fields: Optional[List["Field"]] = None,
- ):
- super().__init__(name, prefix, storage_type, fields)
-
@classmethod
async def from_existing(
cls,
name: str,
- url: Optional[str] = None,
+ redis_url: Optional[str] = None,
+ key_separator: str = ":",
fields: Optional[List["Field"]] = None,
**kwargs,
):
@@ -470,20 +622,21 @@ async def from_existing(
Args:
name (str): Index name.
- url (Optional[str], optional): Redis URL. REDIS_URL env var
+ redis_url (Optional[str], optional): Redis URL. REDIS_URL env var
is used if not provided. Defaults to None.
+ key_separator (str, optional): Separator char to combine prefix and
+                key value when constructing Redis keys. Defaults to ":".
fields (Optional[List[Field]], optional): List of Redis search
fields to include in the schema. Defaults to None.
Returns:
- SearchIndex: A SearchIndex object.
+ AsyncSearchIndex: An AsyncSearchIndex object.
Raises:
redis.exceptions.ResponseError: If the index does not exist.
- ValueError: If the REDIS_URL env var is not set and url is not provided.
-
+ ValueError: If the Redis URL is not accessible.
"""
- client = get_async_redis_connection(url, **kwargs)
+ client = get_async_redis_connection(redis_url, **kwargs)
info = convert_bytes(await client.ft(name).info())
index_definition = make_dict(info["index_definition"])
storage_type = index_definition["key_type"].lower()
@@ -492,37 +645,38 @@ async def from_existing(
name=name,
storage_type=storage_type,
prefix=prefix,
+ key_separator=key_separator,
fields=fields,
)
instance.set_client(client)
return instance
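A sketch of the renamed async constructor as it reads after this change; the index name and URL are placeholders.

```python
import asyncio
from redisvl.index import AsyncSearchIndex

async def main():
    # note the rename: url -> redis_url, plus the new key_separator option
    index = await AsyncSearchIndex.from_existing(
        name="user_index",
        redis_url="redis://localhost:6379",
        key_separator=":",
    )
    print(await index.info())

asyncio.run(main())
```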
- def connect(self, url: Optional[str] = None, **kwargs):
+ def connect(self, redis_url: Optional[str] = None, **kwargs):
"""Connect to a Redis instance.
Args:
- url (str): Redis URL. REDIS_URL env var is used if not provided.
+ redis_url (Optional[str], optional): Redis URL. REDIS_URL env var is
+ used if not provided.
Raises:
redis.exceptions.ConnectionError: If the connection to Redis fails.
- ValueError: If no Redis URL is provided and REDIS_URL env var is not set.
+ ValueError: If the Redis URL is not accessible.
"""
- self._redis_conn = get_async_redis_connection(url, **kwargs)
+ self._redis_conn = get_async_redis_connection(redis_url, **kwargs)
return self
- @check_connected("_redis_conn")
- async def create(self, overwrite: Optional[bool] = False):
- """Create an index in Redis from this SearchIndex object.
+ @check_async_connected("_redis_conn")
+ @check_async_modules_present("_redis_conn")
+ async def create(self, overwrite: bool = False) -> None:
+ """Asynchronously create an index in Redis from this SearchIndex object.
Args:
- overwrite (bool, optional): Overwrite the index if it already exists. Defaults to False.
+ overwrite (bool, optional): Whether to overwrite the index if it
+ already exists. Defaults to False.
Raises:
- redis.exceptions.ResponseError: If the index already exists.
+ RuntimeError: If the index already exists and 'overwrite' is False.
"""
- # TODO - enable async version of this
- # check_redis_modules_exist(self._redis_conn)
-
if not self._fields:
raise ValueError("No fields defined for index")
if not isinstance(overwrite, bool):
@@ -535,24 +689,23 @@ async def create(self, overwrite: Optional[bool] = False):
print("Index already exists, overwriting.")
await self.delete()
- # set storage_type, default to hash
- storage_type = IndexType.HASH
- # TODO - enable JSON support
- # if self._storage.lower() == "json":
- # storage_type = IndexType.JSON
-
- # Create Index
+ # Create Index with proper IndexType
await self._redis_conn.ft(self._name).create_index( # type: ignore
fields=self._fields,
- definition=IndexDefinition(prefix=[self._prefix], index_type=storage_type),
+ definition=IndexDefinition(
+ prefix=[self._prefix], index_type=self._storage.type
+ ),
)
- @check_connected("_redis_conn")
+ @check_async_connected("_redis_conn")
+ @check_async_modules_present("_redis_conn")
+ @check_async_index_exists()
async def delete(self, drop: bool = True):
"""Delete the search index.
Args:
- drop (bool, optional): Delete the documents in the index. Defaults to True.
+ drop (bool, optional): Delete the documents in the index.
+ Defaults to True.
Raises:
redis.exceptions.ResponseError: If the index does not exist.
@@ -560,61 +713,58 @@ async def delete(self, drop: bool = True):
# Delete the search index
await self._redis_conn.ft(self._name).dropindex(delete_documents=drop) # type: ignore
- @check_connected("_redis_conn")
+ @check_async_connected("_redis_conn")
+ @check_async_modules_present("_redis_conn")
async def load(
self,
- data: Iterable[Dict[str, Any]],
- concurrency: int = 10,
+ data: Iterable[Any],
key_field: Optional[str] = None,
+ keys: Optional[Iterable[str]] = None,
+ ttl: Optional[int] = None,
preprocess: Optional[Callable] = None,
+ concurrency: Optional[int] = None,
**kwargs,
):
- """Load data into Redis and index using this SearchIndex object.
+ """Asynchronously load objects to Redis with concurrency control.
Args:
- data (Iterable[Dict[str, Any]]): An iterable of dictionaries
- containing the data to be indexed.
- concurrency (int, optional): Number of concurrent tasks to run. Defaults to 10.
- key_field (Optional[str], optional): A field within the record to
- use in the Redis hash key.
- preprocess (Optional[Callable], optional): An optional preprocessor function
- that mutates the individual record before writing to redis.
+ data (Iterable[Any]): An iterable of objects to store.
+ key_field (Optional[str], optional): Field used as the key for each
+ object. Defaults to None.
+ keys (Optional[Iterable[str]], optional): Optional iterable of keys.
+ Must match the length of objects if provided. Defaults to None.
+ ttl (Optional[int], optional): Time-to-live in seconds for each key.
+ Defaults to None.
+ preprocess (Optional[Callable], optional): An async function to
+ preprocess objects before storage. Defaults to None.
+ concurrency (Optional[int], optional): The maximum number of
+                concurrent write operations. Defaults to the class-level
+                default concurrency.
Raises:
- redis.exceptions.ResponseError: If the index does not exist.
+ ValueError: If the length of provided keys does not match the
+ length of objects.
Example:
>>> data = [{"foo": "bar"}, {"test": "values"}]
- >>> def func(record: dict): record["new"]="value";return record
+ >>> async def func(record: dict):
+ >>> record["new"] = "value"
+ >>> return record
>>> await index.load(data, preprocess=func)
"""
- ttl = kwargs.get("ttl")
- semaphore = asyncio.Semaphore(concurrency)
-
- async def _load(record: dict):
- async with semaphore:
- key = self._create_key(record, key_field)
- # Optionally preprocess the record and validate type
- if preprocess:
- try:
- record = preprocess(record)
- except Exception as e:
- raise RuntimeError(
- "Error while preprocessing records on load"
- ) from e
- if not isinstance(record, dict):
- raise TypeError(
- f"Individual records must be of type dict, got type {type(record)}"
- )
- # Write the record to Redis
- await self._redis_conn.hset(key, mapping=record) # type: ignore
- if ttl:
- await self._redis_conn.expire(key, ttl) # type: ignore
-
- # Gather with concurrency
- await asyncio.gather(*[_load(record) for record in data])
+ await self._storage.awrite(
+ self.client,
+ objects=data,
+ key_field=key_field,
+ keys=keys,
+ ttl=ttl,
+ preprocess=preprocess,
+ concurrency=concurrency,
+ )
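A sketch of the async load path under the new signature; note that `preprocess` must now be an async callable. The records and URL are assumptions for illustration.

```python
import asyncio
from redisvl.index import AsyncSearchIndex

async def main():
    index = AsyncSearchIndex.from_yaml("schema.yaml")
    index.connect("redis://localhost:6379")
    await index.create(overwrite=True)

    async def add_flag(record: dict) -> dict:
        record["flag"] = "1"
        return record

    data = [{"user": "john"}, {"user": "mary"}]
    # concurrency caps the number of in-flight writes via a semaphore
    await index.load(data, preprocess=add_flag, concurrency=5)

asyncio.run(main())
```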
- @check_connected("_redis_conn")
+ @check_async_connected("_redis_conn")
+ @check_async_modules_present("_redis_conn")
+ @check_async_index_exists()
async def search(self, *args, **kwargs) -> Union["Result", Any]:
"""Perform a search on this index.
@@ -625,11 +775,12 @@ async def search(self, *args, **kwargs) -> Union["Result", Any]:
Returns:
Union["Result", Any]: Search results.
"""
- results = await self._redis_conn.ft(self._name).search( # type: ignore
- *args, **kwargs
- )
+ results = await self._redis_conn.ft(self._name).search(*args, **kwargs) # type: ignore
return results
+ @check_async_connected("_redis_conn")
+ @check_async_modules_present("_redis_conn")
+ @check_async_index_exists()
async def query(self, query: "BaseQuery") -> List[Dict[str, Any]]:
"""Run a query on this index.
@@ -644,11 +795,11 @@ async def query(self, query: "BaseQuery") -> List[Dict[str, Any]]:
List[Result]: A list of search results.
"""
results = await self.search(query.query, query_params=query.params)
- if isinstance(query, CountQuery):
- return results.total
- return process_results(results)
+ # post process the results
+ return process_results(results, query=query, storage_type=self._storage_type)
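The CountQuery special case now lives inside process_results rather than in query() itself; a sketch of the resulting call pattern, with an illustrative tag field.

```python
from redisvl.query import CountQuery
from redisvl.query.filter import Tag

async def count_high_scores(index) -> int:
    count_query = CountQuery(Tag("credit_score") == "high")
    # process_results() unwraps the numeric total for CountQuery,
    # so query() returns a count rather than a list of documents
    return await index.query(count_query)
```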
- @check_connected("_redis_conn")
+ @check_async_connected("_redis_conn")
+ @check_async_modules_present("_redis_conn")
async def exists(self) -> bool:
"""Check if the index exists in Redis.
@@ -657,3 +808,16 @@ async def exists(self) -> bool:
"""
indices = await self._redis_conn.execute_command("FT._LIST") # type: ignore
return self._name in convert_bytes(indices)
+
+ @check_async_connected("_redis_conn")
+ @check_async_modules_present("_redis_conn")
+ @check_async_index_exists()
+ async def info(self) -> Dict[str, Any]:
+ """Get information about the index.
+
+ Returns:
+ dict: A dictionary containing the information about the index.
+ """
+ return convert_bytes(
+ await self._redis_conn.ft(self._name).info() # type: ignore
+ )
diff --git a/redisvl/llmcache/base.py b/redisvl/llmcache/base.py
index fb6248c7..574c114f 100644
--- a/redisvl/llmcache/base.py
+++ b/redisvl/llmcache/base.py
@@ -26,7 +26,8 @@ def store(
vector: Optional[List[float]] = None,
metadata: Optional[dict] = {},
) -> None:
- """Stores the specified key-value pair in the cache along with metadata."""
+ """Stores the specified key-value pair in the cache along with
+ metadata."""
raise NotImplementedError
def _refresh_ttl(self, key: str):
diff --git a/redisvl/llmcache/semantic.py b/redisvl/llmcache/semantic.py
index 60f2452b..84b0e38e 100644
--- a/redisvl/llmcache/semantic.py
+++ b/redisvl/llmcache/semantic.py
@@ -63,7 +63,7 @@ def __init__(
index = SearchIndex(
name=index_name, prefix=prefix, fields=self._default_fields
)
- index.connect(url=redis_url, **connection_args)
+ index.connect(redis_url=redis_url, **connection_args)
else:
raise ValueError(
"Index name and prefix must be provided if not constructing from an existing index."
@@ -135,11 +135,12 @@ def set_threshold(self, threshold: float):
self._threshold = float(threshold)
def clear(self):
- """Clear the LLMCache of all keys in the index"""
+ """Clear the LLMCache of all keys in the index."""
client = self._index.client
+ prefix = self._index.prefix
if client:
with client.pipeline(transaction=False) as pipe:
- for key in client.scan_iter(match=f"{self._index._prefix}:*"):
+ for key in client.scan_iter(match=f"{prefix}:*"):
pipe.delete(key)
pipe.execute()
else:
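The deletion pattern used by clear() above, as a standalone sketch; the prefix and URL are assumptions.

```python
import redis

client = redis.Redis.from_url("redis://localhost:6379")
prefix = "llmcache"  # hypothetical index prefix

# SCAN-based iteration avoids blocking Redis on large keyspaces, and the
# non-transactional pipeline batches the DELETE round trips
with client.pipeline(transaction=False) as pipe:
    for key in client.scan_iter(match=f"{prefix}:*"):
        pipe.delete(key)
    pipe.execute()
```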
diff --git a/redisvl/query/filter.py b/redisvl/query/filter.py
index e2d46834..aec14454 100644
--- a/redisvl/query/filter.py
+++ b/redisvl/query/filter.py
@@ -95,7 +95,7 @@ class Tag(FilterField):
SUPPORTED_VAL_TYPES = (list, set, tuple, str, type(None))
def __init__(self, field: str):
- """Create a Tag FilterField
+ """Create a Tag FilterField.
Args:
field (str): The name of the tag field in the index to be queried against
@@ -121,7 +121,7 @@ def _set_tag_value(
@check_operator_misuse
def __eq__(self, other: Union[List[str], str]) -> "FilterExpression":
- """Create a Tag equality filter expression
+ """Create a Tag equality filter expression.
Args:
other (Union[List[str], str]): The tag(s) to filter on.
@@ -135,7 +135,7 @@ def __eq__(self, other: Union[List[str], str]) -> "FilterExpression":
@check_operator_misuse
def __ne__(self, other) -> "FilterExpression":
- """Create a Tag inequality filter expression
+ """Create a Tag inequality filter expression.
Args:
other (Union[List[str], str]): The tag(s) to filter on.
@@ -152,7 +152,7 @@ def _formatted_tag_value(self) -> str:
return "|".join([self.escaper.escape(tag) for tag in self._value])
def __str__(self) -> str:
- """Return the Redis Query syntax for a Tag filter expression"""
+ """Return the Redis Query syntax for a Tag filter expression."""
if not self._value:
return "*"
@@ -175,7 +175,7 @@ def __init__(self, longitude: float, latitude: float, unit: str = "km"):
class GeoRadius(GeoSpec):
- """A GeoRadius is a GeoSpec representing a geographic radius"""
+ """A GeoRadius is a GeoSpec representing a geographic radius."""
def __init__(
self,
@@ -194,7 +194,6 @@ def __init__(
Raises:
ValueError: If the unit is not one of "m", "km", "mi", or "ft".
-
"""
super().__init__(longitude, latitude, unit)
self._radius = radius
@@ -204,24 +203,22 @@ def get_args(self) -> List[Union[float, int, str]]:
class Geo(FilterField):
- """A Geo is a FilterField representing a geographic (lat/lon)
- field in a Redis index.
-
- """
+ """A Geo is a FilterField representing a geographic (lat/lon) field in a
+ Redis index."""
OPERATORS: Dict[FilterOperator, str] = {
FilterOperator.EQ: "==",
FilterOperator.NE: "!=",
}
OPERATOR_MAP: Dict[FilterOperator, str] = {
- FilterOperator.EQ: "@%s:[%f %f %i %s]",
- FilterOperator.NE: "(-@%s:[%f %f %i %s])",
+ FilterOperator.EQ: "@%s:[%s %s %i %s]",
+ FilterOperator.NE: "(-@%s:[%s %s %i %s])",
}
SUPPORTED_VAL_TYPES = (GeoSpec, type(None))
@check_operator_misuse
def __eq__(self, other) -> "FilterExpression":
- """Create a Geographic equality filter expression
+ """Create a Geographic equality filter expression.
Args:
other (GeoSpec): The geographic spec to filter on.
@@ -235,7 +232,7 @@ def __eq__(self, other) -> "FilterExpression":
@check_operator_misuse
def __ne__(self, other) -> "FilterExpression":
- """Create a Geographic inequality filter expression
+ """Create a Geographic inequality filter expression.
Args:
other (GeoSpec): The geographic spec to filter on.
@@ -248,7 +245,7 @@ def __ne__(self, other) -> "FilterExpression":
return FilterExpression(str(self))
def __str__(self) -> str:
- """Return the Redis Query syntax for a Geographic filter expression"""
+ """Return the Redis Query syntax for a Geographic filter expression."""
if not self._value:
return "*"
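With %f swapped for %s above, coordinates render exactly as supplied instead of as zero-padded floats; a quick check (the field name is illustrative).

```python
from redisvl.query.filter import Geo, GeoRadius

# renders as: @location:[-122.4 37.7 10 km]
# (the old %f formatting produced -122.400000-style output)
geo_filter = Geo("location") == GeoRadius(-122.4, 37.7, 10, "km")
print(str(geo_filter))
```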
@@ -270,17 +267,17 @@ class Num(FilterField):
FilterOperator.GE: ">=",
}
OPERATOR_MAP: Dict[FilterOperator, str] = {
- FilterOperator.EQ: "@%s:[%i %i]",
- FilterOperator.NE: "(-@%s:[%i %i])",
- FilterOperator.GT: "@%s:[(%i +inf]",
- FilterOperator.LT: "@%s:[-inf (%i]",
- FilterOperator.GE: "@%s:[%i +inf]",
- FilterOperator.LE: "@%s:[-inf %i]",
+ FilterOperator.EQ: "@%s:[%s %s]",
+ FilterOperator.NE: "(-@%s:[%s %s])",
+ FilterOperator.GT: "@%s:[(%s +inf]",
+ FilterOperator.LT: "@%s:[-inf (%s]",
+ FilterOperator.GE: "@%s:[%s +inf]",
+ FilterOperator.LE: "@%s:[-inf %s]",
}
SUPPORTED_VAL_TYPES = (int, float, type(None))
def __eq__(self, other: int) -> "FilterExpression":
- """Create a Numeric equality filter expression
+ """Create a Numeric equality filter expression.
Args:
other (int): The value to filter on.
@@ -293,7 +290,7 @@ def __eq__(self, other: int) -> "FilterExpression":
return FilterExpression(str(self))
def __ne__(self, other: int) -> "FilterExpression":
- """Create a Numeric inequality filter expression
+ """Create a Numeric inequality filter expression.
Args:
other (int): The value to filter on.
@@ -306,7 +303,7 @@ def __ne__(self, other: int) -> "FilterExpression":
return FilterExpression(str(self))
def __gt__(self, other: int) -> "FilterExpression":
- """Create a Numeric greater than filter expression
+ """Create a Numeric greater than filter expression.
Args:
other (int): The value to filter on.
@@ -319,7 +316,7 @@ def __gt__(self, other: int) -> "FilterExpression":
return FilterExpression(str(self))
def __lt__(self, other: int) -> "FilterExpression":
- """Create a Numeric less than filter expression
+ """Create a Numeric less than filter expression.
Args:
other (int): The value to filter on.
@@ -332,7 +329,7 @@ def __lt__(self, other: int) -> "FilterExpression":
return FilterExpression(str(self))
def __ge__(self, other: int) -> "FilterExpression":
- """Create a Numeric greater than or equal to filter expression
+ """Create a Numeric greater than or equal to filter expression.
Args:
other (int): The value to filter on.
@@ -345,7 +342,7 @@ def __ge__(self, other: int) -> "FilterExpression":
return FilterExpression(str(self))
def __le__(self, other: int) -> "FilterExpression":
- """Create a Numeric less than or equal to filter expression
+ """Create a Numeric less than or equal to filter expression.
Args:
other (int): The value to filter on.
@@ -358,7 +355,7 @@ def __le__(self, other: int) -> "FilterExpression":
return FilterExpression(str(self))
def __str__(self) -> str:
- """Return the Redis Query syntax for a Numeric filter expression"""
+ """Return the Redis Query syntax for a Numeric filter expression."""
if not self._value:
return "*"
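The analogous %i to %s change for Num means float thresholds are no longer truncated to integers; for example:

```python
from redisvl.query.filter import Num

# %s preserves the float; %i would have truncated 5.5 to 5
assert str(Num("age") > 5.5) == "@age:[(5.5 +inf]"
assert str(Num("age") <= 100) == "@age:[-inf 100]"
```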
@@ -389,8 +386,8 @@ class Text(FilterField):
@check_operator_misuse
def __eq__(self, other: str) -> "FilterExpression":
- """Create a Text equality filter expression. These expressions
- yield filters that enforce an exact match on the supplied term(s).
+ """Create a Text equality filter expression. These expressions yield
+ filters that enforce an exact match on the supplied term(s).
Args:
other (str): The text value to filter on.
@@ -405,8 +402,8 @@ def __eq__(self, other: str) -> "FilterExpression":
@check_operator_misuse
def __ne__(self, other: str) -> "FilterExpression":
"""Create a Text inequality filter expression. These expressions yield
- negated filters on exact matches on the supplied term(s). Opposite of
- an equality filter expression.
+ negated filters on exact matches on the supplied term(s). Opposite of an
+ equality filter expression.
Args:
other (str): The text value to filter on.
@@ -420,8 +417,9 @@ def __ne__(self, other: str) -> "FilterExpression":
def __mod__(self, other: str) -> "FilterExpression":
"""Create a Text "LIKE" filter expression. A flexible expression that
- yields filters that can use a variety of additional operators like
- wildcards (*), fuzzy matches (%%), or combinatorics (|) of the supplied term(s).
+ yields filters that can use a variety of additional operators like
+ wildcards (*), fuzzy matches (%%), or combinatorics (|) of the supplied
+ term(s).
Args:
other (str): The text value to filter on.
diff --git a/redisvl/query/query.py b/redisvl/query/query.py
index 36efc014..c24ad385 100644
--- a/redisvl/query/query.py
+++ b/redisvl/query/query.py
@@ -210,7 +210,6 @@ def __init__(
Raises:
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
-
"""
super().__init__(
vector,
@@ -290,7 +289,6 @@ def __init__(
Raises:
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
-
"""
super().__init__(
vector,
diff --git a/redisvl/schema.py b/redisvl/schema.py
index 55dbad7f..eebf7318 100644
--- a/redisvl/schema.py
+++ b/redisvl/schema.py
@@ -1,6 +1,6 @@
+from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
-from uuid import uuid4
import yaml
from pydantic import BaseModel, Field, validator
@@ -17,6 +17,7 @@
class BaseField(BaseModel):
name: str = Field(...)
sortable: Optional[bool] = False
+ as_name: Optional[str] = None
class TextFieldSchema(BaseField):
@@ -32,6 +33,7 @@ def as_field(self):
no_stem=self.no_stem,
phonetic_matcher=self.phonetic_matcher,
sortable=self.sortable,
+ as_name=self.as_name,
)
@@ -45,17 +47,18 @@ def as_field(self):
separator=self.separator,
case_sensitive=self.case_sensitive,
sortable=self.sortable,
+ as_name=self.as_name,
)
class NumericFieldSchema(BaseField):
def as_field(self):
- return NumericField(self.name, sortable=self.sortable)
+ return NumericField(self.name, sortable=self.sortable, as_name=self.as_name)
class GeoFieldSchema(BaseField):
def as_field(self):
- return GeoField(self.name, sortable=self.sortable)
+ return GeoField(self.name, sortable=self.sortable, as_name=self.as_name)
class BaseVectorField(BaseModel):
@@ -65,6 +68,7 @@ class BaseVectorField(BaseModel):
datatype: str = Field(default="FLOAT32")
distance_metric: str = Field(default="COSINE")
initial_cap: Optional[int] = None
+ as_name: Optional[str] = None
@validator("algorithm", "datatype", "distance_metric", pre=True)
def uppercase_strings(cls, v):
@@ -90,7 +94,7 @@ def as_field(self):
field_data = super().as_field()
if self.block_size is not None:
field_data["BLOCK_SIZE"] = self.block_size
- return VectorField(self.name, self.algorithm, field_data)
+ return VectorField(self.name, self.algorithm, field_data, as_name=self.as_name)
class HNSWVectorField(BaseVectorField):
@@ -111,13 +115,22 @@ def as_field(self):
"EPSILON": self.epsilon,
}
)
- return VectorField(self.name, self.algorithm, field_data)
+ return VectorField(self.name, self.algorithm, field_data, as_name=self.as_name)
+
+
+class StorageType(Enum):
+ HASH = "hash"
+ JSON = "json"
class IndexModel(BaseModel):
- name: str = Field(...)
- prefix: Optional[str] = Field(default="")
- storage_type: Optional[str] = Field(default="hash")
+    """Represents the schema for an index, including its name, optional
+    prefix, key separator, and storage type."""
+
+ name: str
+ prefix: str = "rvl"
+ key_separator: str = ":"
+ storage_type: StorageType = StorageType.HASH
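A sketch of the new model defaults; pydantic coerces plain strings into the StorageType enum, which replaces the removed storage_type validator.

```python
from redisvl.schema import IndexModel, StorageType

index = IndexModel(name="my_index")
assert index.prefix == "rvl"
assert index.key_separator == ":"
assert index.storage_type == StorageType.HASH

# string values are coerced into the enum by pydantic
json_index = IndexModel(name="my_index", storage_type="json")
assert json_index.storage_type == StorageType.JSON
```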
class FieldsModel(BaseModel):
@@ -132,12 +145,6 @@ class SchemaModel(BaseModel):
index: IndexModel = Field(...)
fields: FieldsModel = Field(...)
- @validator("index")
- def validate_index(cls, v):
- if v.storage_type not in ["hash", "json"]:
- raise ValueError(f"Storage type {v.storage_type} not supported")
- return v
-
@property
def index_fields(self):
redis_fields = []
@@ -160,21 +167,12 @@ def read_schema(file_path: str):
return SchemaModel(**schema)
-class MetadataSchemaGenerator:
- """
- A class to generate a schema for metadata, categorizing fields into text, numeric, and tag types.
- """
+class SchemaGenerator:
+ """A class to generate a schema for metadata, categorizing fields into text,
+ numeric, and tag types."""
def _test_numeric(self, value) -> bool:
- """
- Test if the given value can be represented as a numeric value.
-
- Args:
- value: The value to test.
-
- Returns:
- bool: True if the value can be converted to float, False otherwise.
- """
+ """Test if a value is numeric."""
try:
float(value)
return True
@@ -182,72 +180,62 @@ def _test_numeric(self, value) -> bool:
return False
def _infer_type(self, value) -> Optional[str]:
- """
- Infer the type of the given value.
-
- Args:
- value: The value to infer the type of.
-
- Returns:
- Optional[str]: The inferred type of the value, or None if the type is unrecognized or the value is empty.
- """
- if value is None or value == "":
+ """Infer the type of a value."""
+ if value in [None, ""]:
return None
- elif self._test_numeric(value):
+ if self._test_numeric(value):
return "numeric"
- elif isinstance(value, (list, set, tuple)) and all(
+ if isinstance(value, (list, set, tuple)) and all(
isinstance(v, str) for v in value
):
return "tag"
- elif isinstance(value, str):
- return "text"
- else:
- return "unknown"
+ return "text" if isinstance(value, str) else "unknown"
def generate(
- self, metadata: Dict[str, Any], strict: Optional[bool] = False
+ self, metadata: Dict[str, Any], strict: bool = False
) -> Dict[str, List[Dict[str, Any]]]:
- """
- Generate a schema from the provided metadata.
-
- This method categorizes each metadata field into text, numeric, or tag types based on the field values.
- It also allows forcing strict type determination by raising an exception if a type cannot be inferred.
+ """Generate a schema from metadata.
Args:
- metadata: The metadata dictionary to generate the schema from.
- strict: If True, the method will raise an exception for fields where the type cannot be determined.
-
- Returns:
- Dict[str, List[Dict[str, Any]]]: A dictionary with keys 'text', 'numeric', and 'tag', each mapping to a list of field schemas.
+ metadata (Dict[str, Any]): Metadata object to validate and
+ generate schema.
+ strict (bool, optional): Whether to generate schema in strict
+ mode. Defaults to False.
Raises:
- ValueError: If the force parameter is True and a field's type cannot be determined.
+ ValueError: Unable to determine schema field type for a
+ key-value pair.
+
+ Returns:
+ Dict[str, List[Dict[str, Any]]]: Output metadata schema.
"""
result: Dict[str, List[Dict[str, Any]]] = {"text": [], "numeric": [], "tag": []}
+ field_classes = {
+ "text": TextFieldSchema,
+ "tag": TagFieldSchema,
+ "numeric": NumericFieldSchema,
+ }
for key, value in metadata.items():
field_type = self._infer_type(value)
- if field_type in ["unknown", None]:
+ if field_type is None or field_type == "unknown":
if strict:
raise ValueError(
- f"Unable to determine field type for key '{key}' with value '{value}'"
+ f"Unable to determine field type for key '{key}' with"
+ f" value '{value}'"
)
print(
- f"Warning: Unable to determine field type for key '{key}' with value '{value}'"
+ f"Warning: Unable to determine field type for key '{key}'"
+ f" with value '{value}'"
)
continue
- # Extract the field class with defaults
- field_class = {
- "text": TextFieldSchema,
- "tag": TagFieldSchema,
- "numeric": NumericFieldSchema,
- }.get(
- field_type # type: ignore
- )
-
- if field_class:
- result[field_type].append(field_class(name=key).dict(exclude_none=True)) # type: ignore
+ if isinstance(field_type, str):
+ field_class = field_classes.get(field_type)
+ if field_class:
+ result[field_type].append(
+ field_class(name=key).dict(exclude_none=True)
+ )
return result
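A sketch of the renamed generator on a small metadata dict, with values chosen to hit each inferred type.

```python
from redisvl.schema import SchemaGenerator

generator = SchemaGenerator()
schema = generator.generate(
    {
        "age": 42,                  # castable to float -> numeric
        "bio": "hello world",       # plain string -> text
        "topics": ["redis", "ai"],  # iterable of strings -> tag
    }
)
# -> {"text": [...], "numeric": [...], "tag": [...]}
print(schema)
```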
diff --git a/redisvl/storage.py b/redisvl/storage.py
new file mode 100644
index 00000000..a1b9daba
--- /dev/null
+++ b/redisvl/storage.py
@@ -0,0 +1,496 @@
+import asyncio
+import uuid
+from typing import Any, Callable, Dict, Iterable, List, Optional
+
+from redis import Redis
+from redis.asyncio import Redis as AsyncRedis
+from redis.commands.search.indexDefinition import IndexType
+
+from redisvl.utils.utils import convert_bytes
+
+
+class BaseStorage:
+ type: IndexType
+ DEFAULT_BATCH_SIZE: int = 200
+ DEFAULT_WRITE_CONCURRENCY: int = 20
+
+ def __init__(self, prefix: str, key_separator: str):
+ """Initialize the BaseStorage with a specific prefix and key separator
+ for Redis keys.
+
+ Args:
+ prefix (str): The prefix to prepend to each Redis key.
+ key_separator (str): The separator to use between the prefix and
+ the key value.
+ """
+ self._prefix = prefix
+ self._key_separator = key_separator
+
+ @staticmethod
+ def _key(key_value: str, prefix: str, key_separator: str) -> str:
+ """Create a Redis key using a combination of a prefix, separator, and
+ the key value.
+
+ Args:
+ key_value (str): The unique identifier for the Redis entry.
+            prefix (str): A prefix to prepend to the key value.
+ key_separator (str): A separator to insert between prefix
+ and key value.
+
+ Returns:
+ str: The fully formed Redis key.
+ """
+ if not prefix:
+ return key_value
+ else:
+ return f"{prefix}{key_separator}{key_value}"
+
+ def _create_key(self, obj: Dict[str, Any], key_field: Optional[str] = None) -> str:
+ """Construct a Redis key for a given object, optionally using a
+ specified field from the object as the key.
+
+ Args:
+ obj (Dict[str, Any]): The object from which to construct the key.
+ key_field (Optional[str], optional): The field to use as the
+ key, if provided.
+
+ Returns:
+ str: The constructed Redis key for the object.
+
+ Raises:
+ ValueError: If the key_field is not found in the object.
+ """
+ if key_field is None:
+ key_value = uuid.uuid4().hex
+ else:
+ try:
+ key_value = obj[key_field] # type: ignore
+ except KeyError:
+ raise ValueError(f"Key field {key_field} not found in record {obj}")
+
+ return self._key(
+ key_value, prefix=self._prefix, key_separator=self._key_separator
+ )
+
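How the key pieces compose, assuming the default separator; the prefix and field values are illustrative, and the unit tests added later in this patch exercise the same private helper.

```python
from redisvl.storage import HashStorage

storage = HashStorage(prefix="user", key_separator=":")

# explicit key field -> "user:123"
key = storage._create_key({"id": "123"}, key_field="id")
assert key == "user:123"

# no key field -> a uuid4 hex is generated, e.g. "user:9f0c..."
auto_key = storage._create_key({"name": "john"})
assert auto_key.startswith("user:")
```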
+ @staticmethod
+ def _preprocess(obj: Any, preprocess: Optional[Callable] = None) -> Dict[str, Any]:
+ """Apply a preprocessing function to the object if provided.
+
+ Args:
+            obj (Any): Object to preprocess.
+            preprocess (Optional[Callable], optional): Function to
+                process the object.
+
+ Returns:
+ Dict[str, Any]: Processed object as a dictionary.
+ """
+ # optionally preprocess object
+ if preprocess:
+ obj = preprocess(obj)
+ return obj
+
+ @staticmethod
+ async def _apreprocess(
+ obj: Any, preprocess: Optional[Callable] = None
+ ) -> Dict[str, Any]:
+ """Asynchronously apply a preprocessing function to the object if
+ provided.
+
+ Args:
+            obj (Any): Object to preprocess.
+            preprocess (Optional[Callable], optional): Async function to
+                process the object.
+
+ Returns:
+ Dict[str, Any]: Processed object as a dictionary.
+ """
+ # optionally async preprocess object
+ if preprocess:
+ obj = await preprocess(obj)
+ return obj
+
+ def _validate(self, obj: Dict[str, Any]):
+ """Validate the object before writing to Redis. This method should be
+ implemented by subclasses.
+
+ Args:
+ obj (Dict[str, Any]): The object to validate.
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ def _set(client: Redis, key: str, obj: Dict[str, Any]):
+ """Synchronously set the value in Redis for the given key.
+
+ Args:
+ client (Redis): The Redis client instance.
+ key (str): The key under which to store the object.
+ obj (Dict[str, Any]): The object to store in Redis.
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ async def _aset(client: AsyncRedis, key: str, obj: Dict[str, Any]):
+ """Asynchronously set the value in Redis for the given key.
+
+ Args:
+ client (AsyncRedis): The Redis client instance.
+ key (str): The key under which to store the object.
+ obj (Dict[str, Any]): The object to store in Redis.
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ def _get(client: Redis, key: str) -> Dict[str, Any]:
+ """Synchronously get the value from Redis for the given key.
+
+ Args:
+ client (Redis): The Redis client instance.
+ key (str): The key for which to retrieve the object.
+
+ Returns:
+ Dict[str, Any]: The retrieved object from Redis.
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ async def _aget(client: AsyncRedis, key: str) -> Dict[str, Any]:
+ """Asynchronously get the value from Redis for the given key.
+
+ Args:
+ client (AsyncRedis): The Redis client instance.
+ key (str): The key for which to retrieve the object.
+
+ Returns:
+ Dict[str, Any]: The retrieved object from Redis.
+ """
+ raise NotImplementedError
+
+ def write(
+ self,
+ redis_client: Redis,
+ objects: Iterable[Any],
+ key_field: Optional[str] = None,
+ keys: Optional[Iterable[str]] = None,
+ ttl: Optional[int] = None,
+ preprocess: Optional[Callable] = None,
+ batch_size: Optional[int] = None,
+ ):
+        """Write a batch of objects to Redis.
+
+ Args:
+ redis_client (Redis): A Redis client used for writing data.
+ objects (Iterable[Any]): An iterable of objects to store.
+ key_field (Optional[str], optional): Field used as the key for
+ each object. Defaults to None.
+ keys (Optional[Iterable[str]], optional): Optional iterable of
+ keys, must match the length of objects if provided.
+ ttl (Optional[int], optional): Time-to-live in seconds for each
+ key. Defaults to None.
+ preprocess (Optional[Callable], optional): A function to preprocess
+ objects before storage. Defaults to None.
+ batch_size (Optional[int], optional): Number of objects to write
+ in a single Redis pipeline execution.
+
+ Raises:
+ ValueError: If the length of provided keys does not match the
+ length of objects.
+ """
+ if keys and len(keys) != len(objects): # type: ignore
+ raise ValueError("Length of keys does not match the length of objects")
+
+        if batch_size is None:
+            # fall back to the class-level default batch size
+            batch_size = self.DEFAULT_BATCH_SIZE
+
+ keys_iterator = iter(keys) if keys else None
+
+ with redis_client.pipeline(transaction=False) as pipe:
+ for i, obj in enumerate(objects, start=1):
+ key = (
+ next(keys_iterator)
+ if keys_iterator
+ else self._create_key(obj, key_field)
+ )
+ obj = self._preprocess(obj, preprocess)
+ self._validate(obj)
+ self._set(pipe, key, obj)
+ if ttl:
+ pipe.expire(key, ttl) # Set TTL if provided
+ # execute mini batch
+ if i % batch_size == 0:
+ pipe.execute()
+ # clean up batches if needed
+ if i % batch_size != 0:
+ pipe.execute()
+
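A minimal sketch of driving the batch writer directly; SearchIndex.load is the normal entry point, and the client URL and records here are assumptions.

```python
import redis
from redisvl.storage import HashStorage

client = redis.Redis.from_url("redis://localhost:6379")
storage = HashStorage(prefix="doc", key_separator=":")

records = [{"id": str(i), "body": f"text {i}"} for i in range(500)]

# 500 records at batch_size=200 -> two full pipeline executes,
# then one final flush for the remaining 100
storage.write(client, objects=records, key_field="id", batch_size=200)
```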
+ async def awrite(
+ self,
+ redis_client: AsyncRedis,
+ objects: Iterable[Any],
+ key_field: Optional[str] = None,
+ keys: Optional[Iterable[str]] = None,
+ ttl: Optional[int] = None,
+ preprocess: Optional[Callable] = None,
+ concurrency: Optional[int] = None,
+ ):
+        """Asynchronously write objects to Redis with concurrency
+        control.
+
+ Args:
+ redis_client (AsyncRedis): An asynchronous Redis client used
+ for writing data.
+ objects (Iterable[Any]): An iterable of objects to store.
+ key_field (Optional[str], optional): Field used as the key for each
+ object. Defaults to None.
+ keys (Optional[Iterable[str]], optional): Optional iterable of keys.
+ Must match the length of objects if provided.
+ ttl (Optional[int], optional): Time-to-live in seconds for each key.
+ Defaults to None.
+ preprocess (Optional[Callable], optional): An async function to
+ preprocess objects before storage. Defaults to None.
+ concurrency (Optional[int], optional): The maximum number of
+                concurrent write operations. Defaults to the class-level
+                default concurrency.
+
+ Raises:
+ ValueError: If the length of provided keys does not match the
+ length of objects.
+ """
+ if keys and len(keys) != len(objects): # type: ignore
+ raise ValueError("Length of keys does not match the length of objects")
+
+ if not concurrency:
+ concurrency = self.DEFAULT_WRITE_CONCURRENCY
+
+ semaphore = asyncio.Semaphore(concurrency)
+ keys_iterator = iter(keys) if keys else None
+
+ async def _load(obj: Dict[str, Any], key: Optional[str] = None) -> None:
+ async with semaphore:
+ if key is None:
+ key = self._create_key(obj, key_field)
+ obj = await self._apreprocess(obj, preprocess)
+ self._validate(obj)
+ await self._aset(redis_client, key, obj)
+ if ttl:
+ await redis_client.expire(key, ttl)
+
+ if keys_iterator:
+ tasks = [
+ asyncio.create_task(_load(obj, next(keys_iterator))) for obj in objects
+ ]
+ else:
+ tasks = [asyncio.create_task(_load(obj)) for obj in objects]
+
+ await asyncio.gather(*tasks)
+
+ def get(
+ self, redis_client: Redis, keys: Iterable[str], batch_size: Optional[int] = None
+ ) -> List[Dict[str, Any]]:
+ """Retrieve objects from Redis by keys.
+
+ Args:
+ redis_client (Redis): Synchronous Redis client.
+ keys (Iterable[str]): Keys to retrieve from Redis.
+ batch_size (Optional[int], optional): Number of objects to write
+                in a single Redis pipeline execution. Defaults to the
+                class-level default batch size.
+
+ Returns:
+            List[Dict[str, Any]]: List of objects pulled from Redis.
+ """
+ results: List = []
+
+ if not isinstance(keys, Iterable): # type: ignore
+ raise TypeError("Keys must be an iterable of strings")
+
+ if len(keys) == 0: # type: ignore
+ return []
+
+        if batch_size is None:
+            # fall back to the class-level default batch size
+            batch_size = self.DEFAULT_BATCH_SIZE
+
+ # Use a pipeline to batch the retrieval
+ with redis_client.pipeline(transaction=False) as pipe:
+ for i, key in enumerate(keys, start=1):
+ self._get(pipe, key)
+ if i % batch_size == 0:
+ results.extend(pipe.execute())
+ if i % batch_size != 0:
+ results.extend(pipe.execute())
+
+ # Process results
+ return convert_bytes(results)
+
+ async def aget(
+ self,
+ redis_client: AsyncRedis,
+ keys: Iterable[str],
+ concurrency: Optional[int] = None,
+ ) -> List[Dict[str, Any]]:
+ """Asynchronously retrieve objects from Redis by keys, with concurrency
+ control.
+
+ Args:
+ redis_client (AsyncRedis): Asynchronous Redis client.
+ keys (Iterable[str]): Keys to retrieve from Redis.
+ concurrency (Optional[int], optional): The number of concurrent
+ requests to make.
+
+ Returns:
+            List[Dict[str, Any]]: List of objects pulled from Redis.
+ """
+ if not isinstance(keys, Iterable): # type: ignore
+ raise TypeError("Keys must be an iterable of strings")
+
+ if len(keys) == 0: # type: ignore
+ return []
+
+ if not concurrency:
+ concurrency = self.DEFAULT_WRITE_CONCURRENCY
+
+ semaphore = asyncio.Semaphore(concurrency)
+
+ async def _get(key: str) -> Dict[str, Any]:
+ async with semaphore:
+ result = await self._aget(redis_client, key)
+ return result
+
+ tasks = [asyncio.create_task(_get(key)) for key in keys]
+ results = await asyncio.gather(*tasks)
+ return convert_bytes(results)
+
+
+class HashStorage(BaseStorage):
+ type: IndexType = IndexType.HASH
+
+ def _validate(self, obj: Dict[str, Any]):
+ """Validate that the given object is a dictionary, suitable for storage
+ as a Redis hash.
+
+ Args:
+ obj (Dict[str, Any]): The object to validate.
+
+ Raises:
+ TypeError: If the object is not a dictionary.
+ """
+ if not isinstance(obj, dict):
+ raise TypeError("Object must be a dictionary.")
+
+ @staticmethod
+ def _set(client: Redis, key: str, obj: Dict[str, Any]):
+ """Synchronously set a hash value in Redis for the given key.
+
+ Args:
+ client (Redis): The Redis client instance.
+ key (str): The key under which to store the hash.
+ obj (Dict[str, Any]): The hash to store in Redis.
+ """
+ client.hset(name=key, mapping=obj) # type: ignore
+
+ @staticmethod
+ async def _aset(client: AsyncRedis, key: str, obj: Dict[str, Any]):
+ """Asynchronously set a hash value in Redis for the given key.
+
+ Args:
+ client (AsyncRedis): The Redis client instance.
+ key (str): The key under which to store the hash.
+ obj (Dict[str, Any]): The hash to store in Redis.
+ """
+ await client.hset(name=key, mapping=obj) # type: ignore
+
+ @staticmethod
+ def _get(client: Redis, key: str) -> Dict[str, Any]:
+ """Synchronously retrieve a hash value from Redis for the given key.
+
+ Args:
+ client (Redis): The Redis client instance.
+ key (str): The key for which to retrieve the hash.
+
+ Returns:
+ Dict[str, Any]: The retrieved hash from Redis.
+ """
+ return client.hgetall(key)
+
+ @staticmethod
+ async def _aget(client: AsyncRedis, key: str) -> Dict[str, Any]:
+ """Asynchronously retrieve a hash value from Redis for the given key.
+
+ Args:
+ client (AsyncRedis): The Redis client instance.
+ key (str): The key for which to retrieve the hash.
+
+ Returns:
+ Dict[str, Any]: The retrieved hash from Redis.
+ """
+ return await client.hgetall(key)
+
+
+class JsonStorage(BaseStorage):
+ type: IndexType = IndexType.JSON
+
+ def _validate(self, obj: Dict[str, Any]):
+ """Validate that the given object is a dictionary, suitable for JSON
+ serialization.
+
+ Args:
+ obj (Dict[str, Any]): The object to validate.
+
+ Raises:
+ TypeError: If the object is not a dictionary.
+ """
+ if not isinstance(obj, dict):
+ raise TypeError("Object must be a dictionary.")
+
+ @staticmethod
+ def _set(client: Redis, key: str, obj: Dict[str, Any]):
+ """Synchronously set a JSON obj in Redis for the given key.
+
+ Args:
+            client (Redis): The Redis client instance.
+ key (str): The key under which to store the JSON obj.
+ obj (Dict[str, Any]): The JSON obj to store in Redis.
+ """
+ client.json().set(key, "$", obj)
+
+ @staticmethod
+ async def _aset(client: AsyncRedis, key: str, obj: Dict[str, Any]):
+ """Asynchronously set a JSON obj in Redis for the given key.
+
+ Args:
+ client (AsyncRedis): The Redis client instance.
+ key (str): The key under which to store the JSON obj.
+ obj (Dict[str, Any]): The JSON obj to store in Redis.
+ """
+ await client.json().set(key, "$", obj)
+
+ @staticmethod
+ def _get(client: Redis, key: str) -> Dict[str, Any]:
+ """Synchronously retrieve a JSON obj from Redis for the given key.
+
+ Args:
+            client (Redis): The Redis client instance.
+ key (str): The key for which to retrieve the JSON obj.
+
+ Returns:
+ Dict[str, Any]: The retrieved JSON obj from Redis.
+ """
+ return client.json().get(key)
+
+ @staticmethod
+ async def _aget(client: AsyncRedis, key: str) -> Dict[str, Any]:
+ """Asynchronously retrieve a JSON obj from Redis for the given key.
+
+ Args:
+ client (AsyncRedis): The Redis client instance.
+ key (str): The key for which to retrieve the JSON obj.
+
+ Returns:
+ Dict[str, Any]: The retrieved JSON obj from Redis.
+ """
+ return await client.json().get(key)
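A JSON-storage sketch to contrast with hashes: nested values round-trip because RedisJSON stores the document whole. This assumes the RedisJSON module is loaded, and the values are illustrative.

```python
import redis
from redisvl.storage import JsonStorage

client = redis.Redis.from_url("redis://localhost:6379")
storage = JsonStorage(prefix="user", key_separator=":")

# nested structures are fine here, unlike flat hash mappings
storage.write(
    client,
    objects=[{"id": "1", "profile": {"name": "john", "scores": [1, 2]}}],
    key_field="id",
)
print(storage.get(client, ["user:1"]))
```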
diff --git a/redisvl/utils/connection.py b/redisvl/utils/connection.py
index e9a04647..a0f9da4a 100644
--- a/redisvl/utils/connection.py
+++ b/redisvl/utils/connection.py
@@ -1,5 +1,4 @@
import os
-from functools import wraps
from typing import Optional
# TODO: handle connection errors.
@@ -32,7 +31,7 @@ def get_async_redis_connection(url: Optional[str] = None, **kwargs):
def get_address_from_env():
- """Get a redis connection from environment variables
+ """Get a redis connection from environment variables.
Returns:
str: Redis URL
@@ -41,18 +40,3 @@ def get_address_from_env():
if not addr:
raise ValueError("REDIS_URL env var not set")
return addr
-
-
-def check_connected(client_variable_name: str):
- def decorator(func):
- @wraps(func)
- def wrapper(self, *args, **kwargs):
- if getattr(self, client_variable_name) is None:
- raise ValueError(
- f"SearchIndex.connect() must be called before calling {func.__name__}"
- )
- return func(self, *args, **kwargs)
-
- return wrapper
-
- return decorator
diff --git a/redisvl/utils/token_escaper.py b/redisvl/utils/token_escaper.py
index 10260866..53e47a73 100644
--- a/redisvl/utils/token_escaper.py
+++ b/redisvl/utils/token_escaper.py
@@ -3,8 +3,9 @@
class TokenEscaper:
- """
- Escape punctuation within an input string. Adapted from RedisOM Python.
+ """Escape punctuation within an input string.
+
+ Adapted from RedisOM Python.
"""
# Characters that RediSearch requires us to escape during queries.
diff --git a/redisvl/utils/utils.py b/redisvl/utils/utils.py
index f9757f73..267359ed 100644
--- a/redisvl/utils/utils.py
+++ b/redisvl/utils/utils.py
@@ -1,8 +1,4 @@
-from typing import TYPE_CHECKING, Any, Dict, List
-
-if TYPE_CHECKING:
- from redis.commands.search.result import Result
- from redis.commands.search.document import Document
+from typing import Any, List
import numpy as np
@@ -57,18 +53,27 @@ def check_redis_modules_exist(client) -> None:
raise ValueError(error_message)
+async def check_async_redis_modules_exist(client) -> None:
+ """Check if the correct Redis modules are installed."""
+ installed_modules = await client.module_list()
+ installed_modules = {
+ module[b"name"].decode("utf-8"): module for module in installed_modules
+ }
+ for module in REDIS_REQUIRED_MODULES:
+ if module["name"] in installed_modules and int(
+ installed_modules[module["name"]][b"ver"]
+ ) >= int(
+ module["ver"]
+ ): # type: ignore[call-overload]
+ return
+ # otherwise raise error
+ error_message = (
+ "You must add the RediSearch (>= 2.4) module from Redis Stack. "
+ "Please refer to Redis Stack docs: https://redis.io/docs/stack/"
+ )
+ raise ValueError(error_message)
+
+
def array_to_buffer(array: List[float], dtype: Any = np.float32) -> bytes:
"""Convert a list of floats into a numpy byte string."""
return np.array(array).astype(dtype).tobytes()
-
-
-def process_results(results: "Result") -> List[Dict[str, Any]]:
- """Convert a list of search Result objects into a list of document dicts"""
-
- def _process(doc: "Document") -> Dict[str, Any]:
- d = doc.__dict__
- if "payload" in d:
- del d["payload"]
- return d
-
- return [_process(doc) for doc in results.docs]
diff --git a/redisvl/vectorize/text/huggingface.py b/redisvl/vectorize/text/huggingface.py
index 53db328f..50755716 100644
--- a/redisvl/vectorize/text/huggingface.py
+++ b/redisvl/vectorize/text/huggingface.py
@@ -61,8 +61,8 @@ def embed_many(
batch_size: int = 1000,
as_buffer: bool = False,
) -> List[List[float]]:
- """Asynchronously embed many chunks of texts using the Hugging Face sentence
- transformer.
+        """Embed many chunks of texts using the Hugging Face sentence
+        transformer.
Args:
texts (List[str]): List of text chunks to embed.
diff --git a/redisvl/vectorize/text/openai.py b/redisvl/vectorize/text/openai.py
index b9a83162..368e760a 100644
--- a/redisvl/vectorize/text/openai.py
+++ b/redisvl/vectorize/text/openai.py
@@ -10,10 +10,11 @@
class OpenAITextVectorizer(BaseVectorizer):
- """OpenAI text vectorizer
+ """OpenAI text vectorizer.
- This vectorizer uses the OpenAI API to create embeddings for text. It requires an
- API key to be passed in the api_config dictionary. The API key can be obtained from
+ This vectorizer uses the OpenAI API to create embeddings for text. It
+ requires an API key to be passed in the api_config dictionary. The API key
+ can be obtained from
https://api.openai.com/.
"""
diff --git a/redisvl/vectorize/text/vertexai.py b/redisvl/vectorize/text/vertexai.py
index a96d3aa9..7dbc2785 100644
--- a/redisvl/vectorize/text/vertexai.py
+++ b/redisvl/vectorize/text/vertexai.py
@@ -7,10 +7,11 @@
class VertexAITextVectorizer(BaseVectorizer):
- """VertexAI text vectorizer
+ """VertexAI text vectorizer.
- This vectorizer uses the VertexAI Palm 2 embedding model API to create embeddings for text. It requires an
- active GCP project, location, and application credentials.
+ This vectorizer uses the VertexAI Palm 2 embedding model API to create
+ embeddings for text. It requires an active GCP project, location, and
+ application credentials.
"""
def __init__(
diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py
index 8436a5a5..49cad43e 100644
--- a/tests/integration/test_query.py
+++ b/tests/integration/test_query.py
@@ -211,7 +211,7 @@ def filter_test(
location=None,
distance_threshold=0.2,
):
- """Utility function to test filters"""
+ """Utility function to test filters."""
# set the new filter
query.set_filter(_filter)
diff --git a/tests/integration/test_simple.py b/tests/integration/test_simple.py
index 95aaf629..437519e4 100644
--- a/tests/integration/test_simple.py
+++ b/tests/integration/test_simple.py
@@ -1,9 +1,12 @@
from pprint import pprint
import numpy as np
+import pytest
from redisvl.index import SearchIndex
from redisvl.query import VectorQuery
+from redisvl.schema import StorageType
+from redisvl.utils.utils import array_to_buffer
data = [
{
@@ -12,7 +15,7 @@
"age": 1,
"job": "engineer",
"credit_score": "high",
- "user_embedding": np.array([0.1, 0.1, 0.5], dtype=np.float32).tobytes(),
+ "user_embedding": [0.1, 0.1, 0.5],
},
{
"id": 2,
@@ -20,7 +23,7 @@
"age": 2,
"job": "doctor",
"credit_score": "low",
- "user_embedding": np.array([0.1, 0.1, 0.5], dtype=np.float32).tobytes(),
+ "user_embedding": [0.1, 0.1, 0.5],
},
{
"id": 3,
@@ -28,14 +31,14 @@
"age": 3,
"job": "dentist",
"credit_score": "medium",
- "user_embedding": np.array([0.9, 0.9, 0.1], dtype=np.float32).tobytes(),
+ "user_embedding": [0.9, 0.9, 0.1],
},
]
-schema = {
+hash_schema = {
"index": {
- "name": "user_index",
- "prefix": "users",
+ "name": "user_index_hash",
+ "prefix": "users_hash",
"storage_type": "hash",
},
"fields": {
@@ -54,16 +57,51 @@
},
}
+json_schema = {
+ "index": {
+ "name": "user_index_json",
+ "prefix": "users_json",
+ "storage_type": "json",
+ },
+ "fields": {
+ "tag": [
+ {"name": "$.credit_score", "as_name": "credit_score"},
+ {"name": "$.user", "as_name": "user"},
+ ],
+ "text": [{"name": "$.job", "as_name": "job"}],
+ "numeric": [{"name": "$.age", "as_name": "age"}],
+ "vector": [
+ {
+ "name": "$.user_embedding",
+ "as_name": "user_embedding",
+ "dims": 3,
+ "distance_metric": "cosine",
+ "algorithm": "flat",
+ "datatype": "float32",
+ }
+ ],
+ },
+}
-def test_simple(client):
+
+@pytest.mark.parametrize("schema", [hash_schema, json_schema])
+def test_simple(client, schema):
index = SearchIndex.from_dict(schema)
# assign client (only for testing)
index.set_client(client)
# create the index
index.create(overwrite=True)
- # load data into the index in Redis
- index.load(data)
+ # Prepare and load the data based on storage type
+ def hash_preprocess(item: dict) -> dict:
+ return {**item, "user_embedding": array_to_buffer(item["user_embedding"])}
+
+ if index.storage_type == StorageType.HASH:
+ index.load(data, preprocess=hash_preprocess)
+    else:
+        # JSON storage accepts the raw records directly
+        index.load(data)
query = VectorQuery(
vector=[0.1, 0.1, 0.5],
@@ -80,6 +118,7 @@ def test_simple(client):
# users = list(results.docs)
# print(len(users))
users = [doc for doc in results.docs]
+ pprint(users)
assert users[0].user in ["john", "mary"]
assert users[1].user in ["john", "mary"]
diff --git a/tests/sample_hash_schema.yaml b/tests/sample_hash_schema.yaml
new file mode 100644
index 00000000..c4a603c3
--- /dev/null
+++ b/tests/sample_hash_schema.yaml
@@ -0,0 +1,14 @@
+index:
+ name: hash-test
+ prefix: hash
+ key_separator: ':'
+ storage_type: hash
+
+fields:
+ text:
+ - name: sentence
+ vector:
+ - name: embedding
+ dims: 768
+ algorithm: flat
+ distance_metric: cosine
\ No newline at end of file
diff --git a/tests/sample_json_schema.yaml b/tests/sample_json_schema.yaml
new file mode 100644
index 00000000..8f9fd564
--- /dev/null
+++ b/tests/sample_json_schema.yaml
@@ -0,0 +1,16 @@
+index:
+ name: json-test
+ prefix: json
+ key_separator: ':'
+ storage_type: json
+
+fields:
+ text:
+ - name: '$.sentence'
+ as_name: sentence
+ vector:
+ - name: '$.embedding'
+ as_name: embedding
+ dims: 768
+ algorithm: flat
+ distance_metric: cosine
\ No newline at end of file
diff --git a/tests/unit/test_filter.py b/tests/unit/test_filter.py
index 5fef85a5..42088bde 100644
--- a/tests/unit/test_filter.py
+++ b/tests/unit/test_filter.py
@@ -98,6 +98,9 @@ def test_numeric_filter():
nf = Num("numeric_field") <= 5
assert str(nf) == "@numeric_field:[-inf 5]"
+    nf = Num("numeric_field") > 5.5
+    assert str(nf) == "@numeric_field:[(5.5 +inf]"
+
nf = Num("numeric_field") <= None
assert str(nf) == "*"
@@ -130,10 +133,10 @@ def test_text_filter():
def test_geo_filter():
geo_f = Geo("geo_field") == GeoRadius(1.0, 2.0, 3, "km")
- assert str(geo_f) == "@geo_field:[1.000000 2.000000 3 km]"
+ assert str(geo_f) == "@geo_field:[1.0 2.0 3 km]"
geo_f = Geo("geo_field") != GeoRadius(1.0, 2.0, 3, "km")
- assert str(geo_f) != "(-@geo_field:[1.000000 2.000000 3 m])"
+ assert str(geo_f) != "(-@geo_field:[1.0 2.0 3 m])"
@pytest.mark.parametrize(
@@ -215,8 +218,8 @@ def test_text_filter(operation, value, expected):
@pytest.mark.parametrize(
"operation, expected",
[
- ("__eq__", "@geo_field:[1.000000 2.000000 3 km]"),
- ("__ne__", "(-@geo_field:[1.000000 2.000000 3 km])"),
+ ("__eq__", "@geo_field:[1.0 2.0 3 km]"),
+ ("__ne__", "(-@geo_field:[1.0 2.0 3 km])"),
],
ids=["eq", "ne"],
)
diff --git a/tests/unit/test_index.py b/tests/unit/test_index.py
index fb6a474f..40505a23 100644
--- a/tests/unit/test_index.py
+++ b/tests/unit/test_index.py
@@ -11,18 +11,18 @@
def test_search_index_get_key():
si = SearchIndex("my_index", fields=fields)
key = si.key("foo")
- assert key.startswith(si._prefix)
+ assert key.startswith(si.prefix)
assert "foo" in key
- key = si._create_key({"id": "foo"})
- assert key.startswith(si._prefix)
+ key = si._storage._create_key({"id": "foo"})
+ assert key.startswith(si.prefix)
assert "foo" not in key
def test_search_index_no_prefix():
# specify None as the prefix...
- si = SearchIndex("my_index", prefix=None, fields=fields)
+ si = SearchIndex("my_index", prefix="", fields=fields)
key = si.key("foo")
- assert not si._prefix
+ assert not si.prefix
assert key == "foo"
@@ -40,7 +40,7 @@ def test_search_index_create(client, redis_url):
assert si.exists()
assert "my_index" in convert_bytes(si.client.execute_command("FT._LIST"))
- s1_2 = SearchIndex.from_existing("my_index", url=redis_url)
+ s1_2 = SearchIndex.from_existing("my_index", redis_url=redis_url)
assert s1_2.info()["index_name"] == si.info()["index_name"]
si.create(overwrite=False)
@@ -133,7 +133,7 @@ async def test_async_search_index_load(async_client):
def test_search_index_delete_nonexistent(client):
si = SearchIndex("my_index", fields=fields)
si.set_client(client)
- with pytest.raises(redis.exceptions.ResponseError):
+ with pytest.raises(ValueError):
si.delete()
@@ -141,7 +141,7 @@ def test_search_index_delete_nonexistent(client):
async def test_async_search_index_delete_nonexistent(async_client):
asi = AsyncSearchIndex("my_index", fields=fields)
asi.set_client(async_client)
- with pytest.raises(redis.exceptions.ResponseError):
+ with pytest.raises(ValueError):
await asi.delete()
diff --git a/tests/unit/test_query_types.py b/tests/unit/test_query_types.py
new file mode 100644
index 00000000..8fc380fe
--- /dev/null
+++ b/tests/unit/test_query_types.py
@@ -0,0 +1,56 @@
+import pytest
+from redis.commands.search.document import Document
+from redis.commands.search.query import Query
+from redis.commands.search.result import Result
+
+from redisvl.index import process_results
+from redisvl.query import CountQuery, FilterQuery, VectorQuery
+from redisvl.query.filter import FilterExpression, Tag
+
+# Sample data for testing
+sample_vector = [0.1, 0.2, 0.3, 0.4]
+
+
+# Test Cases
+
+
+def test_count_query():
+ # Create a filter expression
+ filter_expression = Tag("brand") == "Nike"
+ count_query = CountQuery(filter_expression)
+
+ # Check properties
+ assert isinstance(count_query.query, Query)
+ assert isinstance(count_query.params, dict)
+ assert count_query.params == {}
+
+ fake_result = Result([2], "")
+ assert process_results(fake_result, count_query, "json") == 2
+
+
+def test_filter_query():
+ # Create a filter expression
+ filter_expression = Tag("brand") == "Nike"
+ return_fields = ["brand", "price"]
+ filter_query = FilterQuery(return_fields, filter_expression, 10)
+
+ # Check properties
+ assert filter_query._return_fields == return_fields
+ assert filter_query._num_results == 10
+ assert filter_query.get_filter() == filter_expression
+ assert isinstance(filter_query.query, Query)
+ assert isinstance(filter_query.params, dict)
+ assert filter_query.params == {}
+
+
+def test_vector_query():
+ # Create a vector query
+ vector_query = VectorQuery(sample_vector, "vector_field", ["field1", "field2"])
+
+ # Check properties
+ assert vector_query._vector == sample_vector
+ assert vector_query._field == "vector_field"
+ assert "field1" in vector_query._return_fields
+ assert isinstance(vector_query.query, Query)
+ assert isinstance(vector_query.params, dict)
+ assert vector_query.params != {}
diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py
index 560fc08a..2e71a769 100644
--- a/tests/unit/test_schema.py
+++ b/tests/unit/test_schema.py
@@ -1,3 +1,5 @@
+import pathlib
+
import pytest
from pydantic import ValidationError
from redis.commands.search.field import (
@@ -12,15 +14,21 @@
FlatVectorField,
GeoFieldSchema,
HNSWVectorField,
- MetadataSchemaGenerator,
+ IndexModel,
NumericFieldSchema,
+ SchemaGenerator,
SchemaModel,
+ StorageType,
TagFieldSchema,
TextFieldSchema,
read_schema,
)
+def get_base_path():
+ return pathlib.Path(__file__).parent.resolve()
+
+
# Utility functions to create schema instances with default values
def create_text_field_schema(**kwargs):
defaults = {"name": "example_textfield", "sortable": False, "weight": 1.0}
@@ -143,15 +151,43 @@ def test_flat_vector_field_block_size_not_set():
assert "INITIAL_CAP" not in field_exported.args
-# Test for schema model validation
-def test_schema_model_validation_success():
- valid_index = {"name": "test_index", "storage_type": "hash"}
- valid_fields = {"text": [create_text_field_schema()]}
- schema_model = SchemaModel(index=valid_index, fields=valid_fields)
+# Tests for IndexModel
+
+
+def test_index_model_defaults():
+ index = IndexModel(name="test_index")
+ assert index.name == "test_index"
+ assert index.prefix == "rvl"
+ assert index.key_separator == ":"
+ assert index.storage_type == StorageType.HASH
+
+
+def test_index_model_custom_settings():
+ index = IndexModel(
+ name="test_index", prefix="custom", key_separator="_", storage_type="json"
+ )
+ assert index.name == "test_index"
+ assert index.prefix == "custom"
+ assert index.key_separator == "_"
+ assert index.storage_type == StorageType.JSON
+
+
+def test_index_model_validation_errors():
+ # Missing required field
+ with pytest.raises(ValueError):
+ IndexModel()
+
+ # Invalid type
+ with pytest.raises(ValidationError):
+ IndexModel(name="test_index", prefix=None)
+
+ # Invalid type
+ with pytest.raises(ValidationError):
+ IndexModel(name="test_index", key_separator=None)
- assert schema_model.index.name == "test_index"
- assert schema_model.index.storage_type == "hash"
- assert len(schema_model.fields.text) == 1
+ # Invalid type
+ with pytest.raises(ValidationError):
+ IndexModel(name="test_index", storage_type=None)
def test_schema_model_validation_failures():
@@ -165,6 +201,20 @@ def test_schema_model_validation_failures():
SchemaModel(index={}, fields={})
+def test_read_hash_schema():
+ hash_schema = read_schema(
+ str(get_base_path().joinpath("../sample_hash_schema.yaml"))
+ )
+ assert hash_schema.index.name == "hash-test"
+
+
+def test_read_json_schema():
+ json_schema = read_schema(
+ str(get_base_path().joinpath("../sample_json_schema.yaml"))
+ )
+ assert json_schema.index.name == "json-test"
+
+
def test_read_schema_file_not_found():
with pytest.raises(FileNotFoundError):
read_schema("non_existent_file.yaml")
@@ -173,7 +223,7 @@ def test_read_schema_file_not_found():
# Fixture for the generator instance
@pytest.fixture
def schema_generator():
- return MetadataSchemaGenerator()
+ return SchemaGenerator()
# Test cases for _test_numeric
diff --git a/tests/unit/test_storage.py b/tests/unit/test_storage.py
new file mode 100644
index 00000000..d0da9dbc
--- /dev/null
+++ b/tests/unit/test_storage.py
@@ -0,0 +1,80 @@
+import pytest
+
+from redisvl.storage import BaseStorage, HashStorage, JsonStorage
+
+
+@pytest.fixture(params=[JsonStorage, HashStorage])
+def storage_instance(request):
+ StorageClass = request.param
+ instance = StorageClass(prefix="test", key_separator=":")
+ return instance
+
+
+def test_key_formatting(storage_instance):
+ key = "1234"
+ generated_key = storage_instance._key(key, "", "")
+ assert generated_key == key, "The generated key does not match the expected format."
+ generated_key = storage_instance._key(key, "", ":")
+ assert generated_key == key, "The generated key does not match the expected format."
+ generated_key = storage_instance._key(key, "test", ":")
+ assert (
+ generated_key == f"test:{key}"
+ ), "The generated key does not match the expected format."
+
+
+def test_create_key(storage_instance):
+ key_field = "id"
+ obj = {key_field: "1234"}
+ expected_key = (
+ f"{storage_instance._prefix}{storage_instance._key_separator}{obj[key_field]}"
+ )
+ generated_key = storage_instance._create_key(obj, key_field)
+ assert (
+ generated_key == expected_key
+ ), "The generated key does not match the expected format."
+
+
+def test_validate_success(storage_instance):
+ data = {"foo": "bar"}
+ try:
+ storage_instance._validate(data)
+ except Exception as e:
+ pytest.fail(f"_validate should not raise an exception here, but raised {e}")
+
+
+def test_validate_failure(storage_instance):
+ data = "Some invalid data type"
+ with pytest.raises(TypeError):
+ storage_instance._validate(data)
+ data = 12345
+ with pytest.raises(TypeError):
+ storage_instance._validate(data)
+
+
+def test_preprocess(storage_instance):
+ data = {"key": "value"}
+ preprocessed_data = storage_instance._preprocess(preprocess=None, obj=data)
+ assert preprocessed_data == data
+
+ def fn(d):
+ d["foo"] = "bar"
+ return d
+
+ preprocessed_data = storage_instance._preprocess(fn, data)
+ assert "foo" in preprocessed_data
+ assert preprocessed_data["foo"] == "bar"
+
+
+@pytest.mark.asyncio
+async def test_apreprocess(storage_instance):
+ data = {"key": "value"}
+ preprocessed_data = await storage_instance._apreprocess(preprocess=None, obj=data)
+ assert preprocessed_data == data
+
+ async def fn(d):
+ d["foo"] = "bar"
+ return d
+
+ preprocessed_data = await storage_instance._apreprocess(data, fn)
+ assert "foo" in preprocessed_data
+ assert preprocessed_data["foo"] == "bar"