diff --git a/docs/user_guide/hybrid_queries_02.ipynb b/docs/user_guide/hybrid_queries_02.ipynb
index d09be7db..32ac1d30 100644
--- a/docs/user_guide/hybrid_queries_02.ipynb
+++ b/docs/user_guide/hybrid_queries_02.ipynb
@@ -76,7 +76,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -94,15 +94,15 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "\u001b[32m14:06:19\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n",
- "\u001b[32m14:06:19\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_queries\n"
+ "\u001b[32m14:16:51\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n",
+ "\u001b[32m14:16:51\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_queries\n"
]
}
],
@@ -113,7 +113,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -142,7 +142,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 26,
"metadata": {},
"outputs": [
{
@@ -164,11 +164,12 @@
"\n",
"t = Tag(\"credit_score\") == \"high\"\n",
"\n",
- "v = VectorQuery([0.1, 0.1, 0.5],\n",
- " \"user_embedding\",\n",
- " return_fields=[\"user\", \"credit_score\", \"age\", \"job\", \"office_location\"],\n",
- " filter_expression=t)\n",
- "\n",
+ "v = VectorQuery(\n",
+ " vector=[0.1, 0.1, 0.5],\n",
+ " vector_field_name=\"user_embedding\",\n",
+ " return_fields=[\"user\", \"credit_score\", \"age\", \"job\", \"office_location\"],\n",
+ " filter_expression=t\n",
+ ")\n",
"\n",
"results = index.query(v)\n",
"result_print(results)"
@@ -176,7 +177,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -202,7 +203,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
@@ -228,7 +229,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 9,
"metadata": {},
"outputs": [
{
@@ -265,7 +266,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 10,
"metadata": {},
"outputs": [
{
@@ -300,7 +301,7 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 11,
"metadata": {},
"outputs": [
{
@@ -327,7 +328,7 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 12,
"metadata": {},
"outputs": [
{
@@ -353,7 +354,7 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 13,
"metadata": {},
"outputs": [
{
@@ -388,7 +389,7 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 14,
"metadata": {},
"outputs": [
{
@@ -416,7 +417,7 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 15,
"metadata": {},
"outputs": [
{
@@ -442,7 +443,7 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 16,
"metadata": {},
"outputs": [
{
@@ -468,7 +469,7 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 17,
"metadata": {},
"outputs": [
{
@@ -494,7 +495,7 @@
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": 18,
"metadata": {},
"outputs": [
{
@@ -520,7 +521,7 @@
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": 19,
"metadata": {},
"outputs": [
{
@@ -544,6 +545,91 @@
"result_print(index.query(v))"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
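+ "You can also pass a raw Redis query string as the filter. Below, the `~` prefix marks the full-text clause as optional: documents that do not match it are still returned, while matching documents receive a higher score. We also select the BM25 scorer and return the document scores along with the results."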
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[{'id': 'user_queries_docs:409ff48274724984ba14865db0495fc5',\n",
+ " 'score': 0.9090908893868948,\n",
+ " 'vector_distance': '0',\n",
+ " 'user': 'john',\n",
+ " 'credit_score': 'high',\n",
+ " 'age': '18',\n",
+ " 'job': 'engineer',\n",
+ " 'office_location': '-122.4194,37.7749'},\n",
+ " {'id': 'user_queries_docs:69cb262c303a4147b213dfdec8bd4b01',\n",
+ " 'score': 0.0,\n",
+ " 'vector_distance': '0',\n",
+ " 'user': 'derrick',\n",
+ " 'credit_score': 'low',\n",
+ " 'age': '14',\n",
+ " 'job': 'doctor',\n",
+ " 'office_location': '-122.4194,37.7749'},\n",
+ " {'id': 'user_queries_docs:562263669ff74a0295c515018d151d7b',\n",
+ " 'score': 0.9090908893868948,\n",
+ " 'vector_distance': '0.109129190445',\n",
+ " 'user': 'tyler',\n",
+ " 'credit_score': 'high',\n",
+ " 'age': '100',\n",
+ " 'job': 'engineer',\n",
+ " 'office_location': '-122.0839,37.3861'},\n",
+ " {'id': 'user_queries_docs:94176145f9de4e288ca2460cd5d1188e',\n",
+ " 'score': 0.0,\n",
+ " 'vector_distance': '0.158808946609',\n",
+ " 'user': 'tim',\n",
+ " 'credit_score': 'high',\n",
+ " 'age': '12',\n",
+ " 'job': 'dermatologist',\n",
+ " 'office_location': '-122.0839,37.3861'},\n",
+ " {'id': 'user_queries_docs:d0bcf6842862410583901004b6b3aeba',\n",
+ " 'score': 0.0,\n",
+ " 'vector_distance': '0.217882037163',\n",
+ " 'user': 'taimur',\n",
+ " 'credit_score': 'low',\n",
+ " 'age': '15',\n",
+ " 'job': 'CEO',\n",
+ " 'office_location': '-122.0839,37.3861'},\n",
+ " {'id': 'user_queries_docs:3dec0e9f2db04e19bff224c5a2a0ba3c',\n",
+ " 'score': 0.0,\n",
+ " 'vector_distance': '0.266666650772',\n",
+ " 'user': 'nancy',\n",
+ " 'credit_score': 'high',\n",
+ " 'age': '94',\n",
+ " 'job': 'doctor',\n",
+ " 'office_location': '-122.4194,37.7749'},\n",
+ " {'id': 'user_queries_docs:93ee6c0e4ccb42f6b7af7858ea6a6408',\n",
+ " 'score': 0.0,\n",
+ " 'vector_distance': '0.653301358223',\n",
+ " 'user': 'joe',\n",
+ " 'credit_score': 'medium',\n",
+ " 'age': '35',\n",
+ " 'job': 'dentist',\n",
+ " 'office_location': '-122.0839,37.3861'}]"
+ ]
+ },
+ "execution_count": 32,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "v.set_filter(\"(~(@job:engineer))\")\n",
+ "v.scorer(\"BM25\").with_scores()\n",
+ "\n",
+ "index.query(v)"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {},
@@ -555,13 +641,13 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "
vector_distance | user | credit_score | age | job | office_location |
---|
0 | john | high | 18 | engineer | -122.4194,37.7749 |
0 | derrick | low | 14 | doctor | -122.4194,37.7749 |
0.266666650772 | nancy | high | 94 | doctor | -122.4194,37.7749 |
"
+ "score | vector_distance | user | credit_score | age | job | office_location |
---|
0.4545454446934474 | 0 | john | high | 18 | engineer | -122.4194,37.7749 |
0.4545454446934474 | 0 | derrick | low | 14 | doctor | -122.4194,37.7749 |
0.4545454446934474 | 0.266666650772 | nancy | high | 94 | doctor | -122.4194,37.7749 |
"
],
"text/plain": [
""
@@ -583,13 +669,13 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "vector_distance | user | credit_score | age | job | office_location |
---|
0 | john | high | 18 | engineer | -122.4194,37.7749 |
0 | derrick | low | 14 | doctor | -122.4194,37.7749 |
0.109129190445 | tyler | high | 100 | engineer | -122.0839,37.3861 |
0.158808946609 | tim | high | 12 | dermatologist | -122.0839,37.3861 |
0.217882037163 | taimur | low | 15 | CEO | -122.0839,37.3861 |
0.266666650772 | nancy | high | 94 | doctor | -122.4194,37.7749 |
0.653301358223 | joe | medium | 35 | dentist | -122.0839,37.3861 |
"
+ "score | vector_distance | user | credit_score | age | job | office_location |
---|
0.4545454446934474 | 0 | john | high | 18 | engineer | -122.4194,37.7749 |
0.4545454446934474 | 0 | derrick | low | 14 | doctor | -122.4194,37.7749 |
0.4545454446934474 | 0.109129190445 | tyler | high | 100 | engineer | -122.0839,37.3861 |
0.4545454446934474 | 0.158808946609 | tim | high | 12 | dermatologist | -122.0839,37.3861 |
0.4545454446934474 | 0.217882037163 | taimur | low | 15 | CEO | -122.0839,37.3861 |
0.4545454446934474 | 0.266666650772 | nancy | high | 94 | doctor | -122.4194,37.7749 |
0.4545454446934474 | 0.653301358223 | joe | medium | 35 | dentist | -122.0839,37.3861 |
"
],
"text/plain": [
""
@@ -609,13 +695,13 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "vector_distance | user | credit_score | age | job | office_location |
---|
0.109129190445 | tyler | high | 100 | engineer | -122.0839,37.3861 |
0.158808946609 | tim | high | 12 | dermatologist | -122.0839,37.3861 |
0.217882037163 | taimur | low | 15 | CEO | -122.0839,37.3861 |
0.653301358223 | joe | medium | 35 | dentist | -122.0839,37.3861 |
"
+ "score | vector_distance | user | credit_score | age | job | office_location |
---|
0.0 | 0.109129190445 | tyler | high | 100 | engineer | -122.0839,37.3861 |
0.0 | 0.158808946609 | tim | high | 12 | dermatologist | -122.0839,37.3861 |
0.0 | 0.217882037163 | taimur | low | 15 | CEO | -122.0839,37.3861 |
0.0 | 0.653301358223 | joe | medium | 35 | dentist | -122.0839,37.3861 |
"
],
"text/plain": [
""
@@ -646,7 +732,7 @@
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 36,
"metadata": {},
"outputs": [
{
@@ -689,7 +775,7 @@
},
{
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": 37,
"metadata": {},
"outputs": [
{
@@ -733,7 +819,7 @@
},
{
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
@@ -748,7 +834,7 @@
},
{
"cell_type": "code",
- "execution_count": 28,
+ "execution_count": 39,
"metadata": {},
"outputs": [
{
@@ -773,7 +859,7 @@
},
{
"cell_type": "code",
- "execution_count": 29,
+ "execution_count": 40,
"metadata": {},
"outputs": [
{
@@ -798,7 +884,7 @@
},
{
"cell_type": "code",
- "execution_count": 30,
+ "execution_count": 41,
"metadata": {},
"outputs": [
{
@@ -823,7 +909,7 @@
},
{
"cell_type": "code",
- "execution_count": 31,
+ "execution_count": 42,
"metadata": {},
"outputs": [
{
@@ -850,14 +936,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "## Filter Queries\n",
+ "## Non-vector Queries\n",
"\n",
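- "In some cases, you may not want to run a vector query, but just use a ``FilterExpression`` similar to a SQL query. The ``FilterQuery`` class enable this functionality. It is similar to the ``VectorQuery`` class but soley takes a ``FilterExpression``."
+ "In some cases, you may not want to run a vector query, but just use a ``FilterExpression`` similar to a SQL query. The ``FilterQuery`` class enables this functionality. It is similar to the ``VectorQuery`` class but solely takes a ``FilterExpression``."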
]
},
{
"cell_type": "code",
- "execution_count": 32,
+ "execution_count": 43,
"metadata": {},
"outputs": [
{
@@ -899,7 +985,7 @@
},
{
"cell_type": "code",
- "execution_count": 33,
+ "execution_count": 44,
"metadata": {},
"outputs": [
{
@@ -933,7 +1019,7 @@
},
{
"cell_type": "code",
- "execution_count": 34,
+ "execution_count": 45,
"metadata": {},
"outputs": [
{
@@ -974,7 +1060,7 @@
},
{
"cell_type": "code",
- "execution_count": 35,
+ "execution_count": 46,
"metadata": {},
"outputs": [
{
@@ -1005,7 +1091,7 @@
},
{
"cell_type": "code",
- "execution_count": 36,
+ "execution_count": 47,
"metadata": {},
"outputs": [
{
@@ -1045,7 +1131,7 @@
},
{
"cell_type": "code",
- "execution_count": 45,
+ "execution_count": 48,
"metadata": {},
"outputs": [
{
@@ -1086,7 +1172,7 @@
},
{
"cell_type": "code",
- "execution_count": 47,
+ "execution_count": 49,
"metadata": {},
"outputs": [
{
@@ -1095,7 +1181,7 @@
"'@job:(\"engineer\")=>[KNN 5 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY age DESC DIALECT 3 LIMIT 0 5'"
]
},
- "execution_count": 47,
+ "execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
@@ -1107,7 +1193,7 @@
},
{
"cell_type": "code",
- "execution_count": 36,
+ "execution_count": 50,
"metadata": {},
"outputs": [
{
@@ -1116,7 +1202,7 @@
"'@credit_score:{high}'"
]
},
- "execution_count": 36,
+ "execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
@@ -1129,7 +1215,7 @@
},
{
"cell_type": "code",
- "execution_count": 48,
+ "execution_count": 51,
"metadata": {},
"outputs": [
{
@@ -1138,7 +1224,7 @@
"'((@credit_score:{high} @age:[18 +inf]) @age:[-inf 100])'"
]
},
- "execution_count": 48,
+ "execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
@@ -1163,17 +1249,17 @@
},
{
"cell_type": "code",
- "execution_count": 49,
+ "execution_count": 52,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "{'id': 'user_queries_docs:43dc726b8a9541a6ab40ddedc8e48657', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '==\\x00\\x00\\x00?'}\n",
- "{'id': 'user_queries_docs:93fdc65248a64fd390ed77aa3c248c23', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n",
- "{'id': 'user_queries_docs:f1d1f69e5e6c41cb9b7ae70ed8f75da5', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n",
- "{'id': 'user_queries_docs:5dc68e47ef6d4a0f885c67368f0710b7', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n"
+ "{'id': 'user_queries_docs:409ff48274724984ba14865db0495fc5', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '==\\x00\\x00\\x00?'}\n",
+ "{'id': 'user_queries_docs:3dec0e9f2db04e19bff224c5a2a0ba3c', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n",
+ "{'id': 'user_queries_docs:562263669ff74a0295c515018d151d7b', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n",
+ "{'id': 'user_queries_docs:94176145f9de4e288ca2460cd5d1188e', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n"
]
}
],
@@ -1185,7 +1271,7 @@
},
{
"cell_type": "code",
- "execution_count": 50,
+ "execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
@@ -1210,7 +1296,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.9"
+ "version": "3.11.10"
},
"orig_nbformat": 4,
"vscode": {
diff --git a/redisvl/query/query.py b/redisvl/query/query.py
index 856a2572..d99a080e 100644
--- a/redisvl/query/query.py
+++ b/redisvl/query/query.py
@@ -10,7 +10,7 @@ class BaseQuery(RedisQuery):
"""Base query class used to subclass many query types."""
_params: Dict[str, Any] = {}
- _filter_expression: FilterExpression = FilterExpression("*")
+ _filter_expression: Union[str, FilterExpression] = FilterExpression("*")
def __init__(self, query_string: str = "*"):
"""
@@ -29,30 +29,33 @@ def _build_query_string(self) -> str:
"""Build the full Redis query string."""
raise NotImplementedError("Must be implemented by subclasses")
- def set_filter(self, filter_expression: Optional[FilterExpression] = None):
+ def set_filter(
+ self, filter_expression: Optional[Union[str, FilterExpression]] = None
+ ):
"""Set the filter expression for the query.
Args:
- filter_expression (Optional[FilterExpression], optional): The filter to apply to the query.
+ filter_expression (Optional[Union[str, FilterExpression]], optional): The filter
+ expression or query string to use on the query.
Raises:
- TypeError: If filter_expression is not of type redisvl.query.FilterExpression
+ TypeError: If filter_expression is not a valid FilterExpression or string.
"""
if filter_expression is None:
# Default filter to match everything
self._filter_expression = FilterExpression("*")
- elif isinstance(filter_expression, FilterExpression):
+ elif isinstance(filter_expression, (FilterExpression, str)):
self._filter_expression = filter_expression
else:
raise TypeError(
- "filter_expression must be of type FilterExpression or None"
+ "filter_expression must be of type FilterExpression or string or None"
)
# Reset the query string
self._query_string = self._build_query_string()
@property
- def filter(self) -> FilterExpression:
+ def filter(self) -> Union[str, FilterExpression]:
"""The filter expression for the query."""
return self._filter_expression
@@ -70,7 +73,7 @@ def params(self) -> Dict[str, Any]:
class FilterQuery(BaseQuery):
def __init__(
self,
- filter_expression: Optional[FilterExpression] = None,
+ filter_expression: Optional[Union[str, FilterExpression]] = None,
return_fields: Optional[List[str]] = None,
num_results: int = 10,
dialect: int = 2,
@@ -81,7 +84,7 @@ def __init__(
"""A query for running a filtered search with a filter expression.
Args:
- filter_expression (Optional[FilterExpression]): The optional filter
+ filter_expression (Optional[Union[str, FilterExpression]]): The optional filter
expression to query with. Defaults to '*'.
return_fields (Optional[List[str]], optional): The fields to return.
num_results (Optional[int], optional): The number of results to return. Defaults to 10.
@@ -93,8 +96,8 @@ def __init__(
Raises:
- TypeError: If filter_expression is not of type redisvl.query.FilterExpression
+ TypeError: If filter_expression is not a valid FilterExpression or string.
"""
- if filter_expression:
- self._filter_expression = filter_expression
+ self.set_filter(filter_expression)
+
if params:
self._params = params
@@ -117,22 +120,22 @@ def __init__(
def _build_query_string(self) -> str:
"""Build the full query string based on the filter and other components."""
- # Example logic to build the full query string from filter and other parts
- # This can be customized in child classes for more complex queries
- return str(self._filter_expression)
+ if isinstance(self._filter_expression, FilterExpression):
+ return str(self._filter_expression)
+ return self._filter_expression
class CountQuery(BaseQuery):
def __init__(
self,
- filter_expression: Optional[FilterExpression] = None,
+ filter_expression: Optional[Union[str, FilterExpression]] = None,
dialect: int = 2,
params: Optional[Dict[str, Any]] = None,
):
"""A query for a simple count operation provided some filter expression.
Args:
- filter_expression (Optional[FilterExpression]): The filter expression to query with. Defaults to None.
+ filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to query with. Defaults to None.
params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None.
Raises:
@@ -148,8 +151,8 @@ def __init__(
count = index.query(query)
"""
- if filter_expression:
- self._filter_expression = filter_expression
+ self.set_filter(filter_expression)
+
if params:
self._params = params
@@ -162,9 +165,9 @@ def __init__(
def _build_query_string(self) -> str:
"""Build the full query string based on the filter and other components."""
- # Example logic to build the full query string from filter and other parts
- # This can be customized in child classes for more complex queries
- return str(self._filter_expression)
+ if isinstance(self._filter_expression, FilterExpression):
+ return str(self._filter_expression)
+ return self._filter_expression
class BaseVectorQuery:
@@ -178,7 +181,7 @@ def __init__(
vector: Union[List[float], bytes],
vector_field_name: str,
return_fields: Optional[List[str]] = None,
- filter_expression: Optional[FilterExpression] = None,
+ filter_expression: Optional[Union[str, FilterExpression]] = None,
dtype: str = "float32",
num_results: int = 10,
return_score: bool = True,
@@ -195,7 +198,7 @@ def __init__(
against in the database.
return_fields (List[str]): The declared fields to return with search
results.
- filter_expression (FilterExpression, optional): A filter to apply
+ filter_expression (Union[str, FilterExpression], optional): A filter to apply
along with the vector search. Defaults to None.
dtype (str, optional): The dtype of the vector. Defaults to
"float32".
@@ -217,15 +220,13 @@ def __init__(
Note:
Learn more about vector queries in Redis: https://redis.io/docs/interact/search-and-query/search/vectors/#knn-search
"""
- if filter_expression:
- self._filter_expression = filter_expression
-
self._vector = vector
self._vector_field_name = vector_field_name
self._dtype = dtype
self._num_results = num_results
-
+ self.set_filter(filter_expression)
query_string = self._build_query_string()
+
super().__init__(query_string)
# Handle query modifiers
@@ -247,7 +248,10 @@ def __init__(
def _build_query_string(self) -> str:
"""Build the full query string for vector search with optional filtering."""
- return f"{str(self._filter_expression)}=>[KNN {self._num_results} @{self._vector_field_name} ${self.VECTOR_PARAM} AS {self.DISTANCE_ID}]"
+ filter_expression = self._filter_expression
+ if isinstance(filter_expression, FilterExpression):
+ filter_expression = str(filter_expression)
+ return f"{filter_expression}=>[KNN {self._num_results} @{self._vector_field_name} ${self.VECTOR_PARAM} AS {self.DISTANCE_ID}]"
@property
def params(self) -> Dict[str, Any]:
@@ -272,7 +276,7 @@ def __init__(
vector: Union[List[float], bytes],
vector_field_name: str,
return_fields: Optional[List[str]] = None,
- filter_expression: Optional[FilterExpression] = None,
+ filter_expression: Optional[Union[str, FilterExpression]] = None,
dtype: str = "float32",
distance_threshold: float = 0.2,
num_results: int = 10,
@@ -290,7 +294,7 @@ def __init__(
against in the database.
return_fields (List[str]): The declared fields to return with search
results.
- filter_expression (FilterExpression, optional): A filter to apply
+ filter_expression (Union[str, FilterExpression], optional): A filter to apply
along with the range query. Defaults to None.
dtype (str, optional): The dtype of the vector. Defaults to
"float32".
@@ -316,16 +320,14 @@ def __init__(
Learn more about vector range queries: https://redis.io/docs/interact/search-and-query/search/vectors/#range-query
"""
- if filter_expression:
- self._filter_expression = filter_expression
-
self._vector = vector
self._vector_field_name = vector_field_name
self._dtype = dtype
self._num_results = num_results
self.set_distance_threshold(distance_threshold)
-
+ self.set_filter(filter_expression)
query_string = self._build_query_string()
+
super().__init__(query_string)
# Handle query modifiers
@@ -348,13 +350,14 @@ def __init__(
def _build_query_string(self) -> str:
"""Build the full query string for vector range queries with optional filtering"""
base_query = f"@{self._vector_field_name}:[VECTOR_RANGE ${self.DISTANCE_THRESHOLD_PARAM} ${self.VECTOR_PARAM}]"
- _filter = str(self._filter_expression)
- if _filter != "*":
- return (
- f"({base_query}=>{{$yield_distance_as: {self.DISTANCE_ID}}} {_filter})"
- )
- else:
+
+ filter_expression = self._filter_expression
+ if isinstance(filter_expression, FilterExpression):
+ filter_expression = str(filter_expression)
+
+ if filter_expression == "*":
return f"{base_query}=>{{$yield_distance_as: {self.DISTANCE_ID}}}"
+ return f"({base_query}=>{{$yield_distance_as: {self.DISTANCE_ID}}} {filter_expression})"
def set_distance_threshold(self, distance_threshold: float):
"""Set the distance threshold for the query.
diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py
index 752d5fe7..2c6dc376 100644
--- a/tests/integration/test_query.py
+++ b/tests/integration/test_query.py
@@ -312,8 +312,52 @@ def test_filters(index, query):
t = Text("job") % ""
search(query, index, t, 7)
- t = Text("job") % None
- search(query, index, t, 7)
+
+def test_manual_string_filters(index, query):
+ # Simple Tag Filter
+ t = "@credit_score:{high}"
+ search(query, index, t, 4, credit_check="high")
+
+ # Multiple Tags
+ t = "@credit_score:{high|low}"
+ search(query, index, t, 6)
+
+ # Simple Numeric Filter
+ n1 = "@age:[18 +inf]"
+ search(query, index, n1, 4, age_range=(18, 100))
+
+ # intersection of rules
+ n2 = "@age:[18 (100]"
+ search(query, index, n2, 3, age_range=(18, 99))
+
+ n3 = "(@age:[-inf (18] | @age:[(94 +inf])"
+ search(query, index, n3, 4, age_range=(95, 17))
+
+ n4 = "(-@age:[18 18])"
+ search(query, index, n4, 6, age_range=(0, 0, 18))
+
+ # Geographic filters
+ g = "@location:[-122.4194 37.7749 1 m]"
+ search(query, index, g, 3, location="-122.4194,37.7749")
+
+ g = "(-@location:[-122.4194 37.7749 1 m])"
+ search(query, index, g, 4, location="-110.0839,37.3861")
+
+ # Text filters
+ t = "@job:engineer"
+ search(query, index, t, 2)
+
+ t = "(-@job:engineer)"
+ search(query, index, t, 5)
+
+ t = "@job:enginee*"
+ search(query, index, t, 2)
+
+ t = "@job:(engine*|doctor)"
+ search(query, index, t, 4)
+
+ t = "@job:*engine*"
+ search(query, index, t, 2)
def test_filter_combinations(index, query):
diff --git a/tests/unit/test_query_types.py b/tests/unit/test_query_types.py
index 17426868..1e9fdb08 100644
--- a/tests/unit/test_query_types.py
+++ b/tests/unit/test_query_types.py
@@ -252,3 +252,28 @@ def test_query_modifiers(query):
assert query._no_stopwords
assert query._with_scores
assert query._fields == ("test",)
+
+
+@pytest.mark.parametrize(
+ "query",
+ [
+ CountQuery(),
+ FilterQuery(),
+ VectorQuery(vector=[1, 2, 3], vector_field_name="vector"),
+ RangeQuery(vector=[1, 2, 3], vector_field_name="vector"),
+ ],
+)
+def test_string_filter_expressions(query):
+ # No filter
+ query.set_filter("*")
+ assert query._filter_expression == "*"
+
+ # Simple full text search
+ query.set_filter("hello world")
+ assert query._filter_expression == "hello world"
+ assert query.query_string().__contains__("hello world")
+
+ # Optional flag
+ query.set_filter("~(@desciption:(hello | world))")
+ assert query._filter_expression == "~(@desciption:(hello | world))"
+ assert query.query_string().__contains__("~(@desciption:(hello | world))")