diff --git a/docs/user_guide/hybrid_queries_02.ipynb b/docs/user_guide/hybrid_queries_02.ipynb index d09be7db..32ac1d30 100644 --- a/docs/user_guide/hybrid_queries_02.ipynb +++ b/docs/user_guide/hybrid_queries_02.ipynb @@ -76,7 +76,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -94,15 +94,15 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m14:06:19\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", - "\u001b[32m14:06:19\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_queries\n" + "\u001b[32m14:16:51\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", + "\u001b[32m14:16:51\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_queries\n" ] } ], @@ -113,7 +113,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -142,7 +142,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 26, "metadata": {}, "outputs": [ { @@ -164,11 +164,12 @@ "\n", "t = Tag(\"credit_score\") == \"high\"\n", "\n", - "v = VectorQuery([0.1, 0.1, 0.5],\n", - " \"user_embedding\",\n", - " return_fields=[\"user\", \"credit_score\", \"age\", \"job\", \"office_location\"],\n", - " filter_expression=t)\n", - "\n", + "v = VectorQuery(\n", + " vector=[0.1, 0.1, 0.5],\n", + " vector_field_name=\"user_embedding\",\n", + " return_fields=[\"user\", \"credit_score\", \"age\", \"job\", \"office_location\"],\n", + " filter_expression=t\n", + ")\n", "\n", "results = index.query(v)\n", "result_print(results)" @@ -176,7 +177,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -202,7 +203,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -228,7 +229,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -265,7 +266,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -300,7 +301,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -327,7 +328,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -353,7 +354,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -388,7 +389,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -416,7 +417,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -442,7 +443,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -468,7 +469,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -494,7 +495,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -520,7 +521,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -544,6 +545,91 @@ "result_print(index.query(v))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Use raw query strings as input. Below we use the `~` flag to indicate that the full text query is optional. We also choose the BM25 scorer and return document scores along with the result." + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'id': 'user_queries_docs:409ff48274724984ba14865db0495fc5',\n", + " 'score': 0.9090908893868948,\n", + " 'vector_distance': '0',\n", + " 'user': 'john',\n", + " 'credit_score': 'high',\n", + " 'age': '18',\n", + " 'job': 'engineer',\n", + " 'office_location': '-122.4194,37.7749'},\n", + " {'id': 'user_queries_docs:69cb262c303a4147b213dfdec8bd4b01',\n", + " 'score': 0.0,\n", + " 'vector_distance': '0',\n", + " 'user': 'derrick',\n", + " 'credit_score': 'low',\n", + " 'age': '14',\n", + " 'job': 'doctor',\n", + " 'office_location': '-122.4194,37.7749'},\n", + " {'id': 'user_queries_docs:562263669ff74a0295c515018d151d7b',\n", + " 'score': 0.9090908893868948,\n", + " 'vector_distance': '0.109129190445',\n", + " 'user': 'tyler',\n", + " 'credit_score': 'high',\n", + " 'age': '100',\n", + " 'job': 'engineer',\n", + " 'office_location': '-122.0839,37.3861'},\n", + " {'id': 'user_queries_docs:94176145f9de4e288ca2460cd5d1188e',\n", + " 'score': 0.0,\n", + " 'vector_distance': '0.158808946609',\n", + " 'user': 'tim',\n", + " 'credit_score': 'high',\n", + " 'age': '12',\n", + " 'job': 'dermatologist',\n", + " 'office_location': '-122.0839,37.3861'},\n", + " {'id': 'user_queries_docs:d0bcf6842862410583901004b6b3aeba',\n", + " 'score': 0.0,\n", + " 'vector_distance': '0.217882037163',\n", + " 'user': 'taimur',\n", + " 'credit_score': 'low',\n", + " 'age': '15',\n", + " 'job': 'CEO',\n", + " 'office_location': '-122.0839,37.3861'},\n", + " {'id': 'user_queries_docs:3dec0e9f2db04e19bff224c5a2a0ba3c',\n", + " 'score': 0.0,\n", + " 'vector_distance': '0.266666650772',\n", + " 'user': 'nancy',\n", + " 'credit_score': 'high',\n", + " 'age': '94',\n", + " 'job': 'doctor',\n", + " 'office_location': '-122.4194,37.7749'},\n", + " {'id': 'user_queries_docs:93ee6c0e4ccb42f6b7af7858ea6a6408',\n", + " 'score': 0.0,\n", + " 'vector_distance': '0.653301358223',\n", + " 'user': 'joe',\n", + " 'credit_score': 'medium',\n", + " 'age': '35',\n", + " 'job': 'dentist',\n", + " 'office_location': '-122.0839,37.3861'}]" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "v.set_filter(\"(~(@job:engineer))\")\n", + "v.scorer(\"BM25\").with_scores()\n", + "\n", + "index.query(v)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -555,13 +641,13 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0derricklow14doctor-122.4194,37.7749
0.266666650772nancyhigh94doctor-122.4194,37.7749
" + "
scorevector_distanceusercredit_scoreagejoboffice_location
0.45454544469344740johnhigh18engineer-122.4194,37.7749
0.45454544469344740derricklow14doctor-122.4194,37.7749
0.45454544469344740.266666650772nancyhigh94doctor-122.4194,37.7749
" ], "text/plain": [ "" @@ -583,13 +669,13 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0derricklow14doctor-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.217882037163taimurlow15CEO-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.653301358223joemedium35dentist-122.0839,37.3861
" + "
scorevector_distanceusercredit_scoreagejoboffice_location
0.45454544469344740johnhigh18engineer-122.4194,37.7749
0.45454544469344740derricklow14doctor-122.4194,37.7749
0.45454544469344740.109129190445tylerhigh100engineer-122.0839,37.3861
0.45454544469344740.158808946609timhigh12dermatologist-122.0839,37.3861
0.45454544469344740.217882037163taimurlow15CEO-122.0839,37.3861
0.45454544469344740.266666650772nancyhigh94doctor-122.4194,37.7749
0.45454544469344740.653301358223joemedium35dentist-122.0839,37.3861
" ], "text/plain": [ "" @@ -609,13 +695,13 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
vector_distanceusercredit_scoreagejoboffice_location
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158808946609timhigh12dermatologist-122.0839,37.3861
0.217882037163taimurlow15CEO-122.0839,37.3861
0.653301358223joemedium35dentist-122.0839,37.3861
" + "
scorevector_distanceusercredit_scoreagejoboffice_location
0.00.109129190445tylerhigh100engineer-122.0839,37.3861
0.00.158808946609timhigh12dermatologist-122.0839,37.3861
0.00.217882037163taimurlow15CEO-122.0839,37.3861
0.00.653301358223joemedium35dentist-122.0839,37.3861
" ], "text/plain": [ "" @@ -646,7 +732,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 36, "metadata": {}, "outputs": [ { @@ -689,7 +775,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 37, "metadata": {}, "outputs": [ { @@ -733,7 +819,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 38, "metadata": {}, "outputs": [], "source": [ @@ -748,7 +834,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 39, "metadata": {}, "outputs": [ { @@ -773,7 +859,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 40, "metadata": {}, "outputs": [ { @@ -798,7 +884,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 41, "metadata": {}, "outputs": [ { @@ -823,7 +909,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 42, "metadata": {}, "outputs": [ { @@ -850,14 +936,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Filter Queries\n", + "## Non-vector Queries\n", "\n", "In some cases, you may not want to run a vector query, but just use a ``FilterExpression`` similar to a SQL query. The ``FilterQuery`` class enable this functionality. It is similar to the ``VectorQuery`` class but soley takes a ``FilterExpression``." ] }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 43, "metadata": {}, "outputs": [ { @@ -899,7 +985,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 44, "metadata": {}, "outputs": [ { @@ -933,7 +1019,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 45, "metadata": {}, "outputs": [ { @@ -974,7 +1060,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 46, "metadata": {}, "outputs": [ { @@ -1005,7 +1091,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 47, "metadata": {}, "outputs": [ { @@ -1045,7 +1131,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 48, "metadata": {}, "outputs": [ { @@ -1086,7 +1172,7 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 49, "metadata": {}, "outputs": [ { @@ -1095,7 +1181,7 @@ "'@job:(\"engineer\")=>[KNN 5 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY age DESC DIALECT 3 LIMIT 0 5'" ] }, - "execution_count": 47, + "execution_count": 49, "metadata": {}, "output_type": "execute_result" } @@ -1107,7 +1193,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 50, "metadata": {}, "outputs": [ { @@ -1116,7 +1202,7 @@ "'@credit_score:{high}'" ] }, - "execution_count": 36, + "execution_count": 50, "metadata": {}, "output_type": "execute_result" } @@ -1129,7 +1215,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 51, "metadata": {}, "outputs": [ { @@ -1138,7 +1224,7 @@ "'((@credit_score:{high} @age:[18 +inf]) @age:[-inf 100])'" ] }, - "execution_count": 48, + "execution_count": 51, "metadata": {}, "output_type": "execute_result" } @@ -1163,17 +1249,17 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 52, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "{'id': 'user_queries_docs:43dc726b8a9541a6ab40ddedc8e48657', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '==\\x00\\x00\\x00?'}\n", - "{'id': 'user_queries_docs:93fdc65248a64fd390ed77aa3c248c23', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n", - "{'id': 'user_queries_docs:f1d1f69e5e6c41cb9b7ae70ed8f75da5', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n", - "{'id': 'user_queries_docs:5dc68e47ef6d4a0f885c67368f0710b7', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n" + "{'id': 'user_queries_docs:409ff48274724984ba14865db0495fc5', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '==\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:3dec0e9f2db04e19bff224c5a2a0ba3c', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:562263669ff74a0295c515018d151d7b', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:94176145f9de4e288ca2460cd5d1188e', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n" ] } ], @@ -1185,7 +1271,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 53, "metadata": {}, "outputs": [], "source": [ @@ -1210,7 +1296,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.11.10" }, "orig_nbformat": 4, "vscode": { diff --git a/redisvl/query/query.py b/redisvl/query/query.py index 856a2572..d99a080e 100644 --- a/redisvl/query/query.py +++ b/redisvl/query/query.py @@ -10,7 +10,7 @@ class BaseQuery(RedisQuery): """Base query class used to subclass many query types.""" _params: Dict[str, Any] = {} - _filter_expression: FilterExpression = FilterExpression("*") + _filter_expression: Union[str, FilterExpression] = FilterExpression("*") def __init__(self, query_string: str = "*"): """ @@ -29,30 +29,33 @@ def _build_query_string(self) -> str: """Build the full Redis query string.""" raise NotImplementedError("Must be implemented by subclasses") - def set_filter(self, filter_expression: Optional[FilterExpression] = None): + def set_filter( + self, filter_expression: Optional[Union[str, FilterExpression]] = None + ): """Set the filter expression for the query. Args: - filter_expression (Optional[FilterExpression], optional): The filter to apply to the query. + filter_expression (Optional[Union[str, FilterExpression]], optional): The filter + expression or query string to use on the query. Raises: - TypeError: If filter_expression is not of type redisvl.query.FilterExpression + TypeError: If filter_expression is not a valid FilterExpression or string. """ if filter_expression is None: # Default filter to match everything self._filter_expression = FilterExpression("*") - elif isinstance(filter_expression, FilterExpression): + elif isinstance(filter_expression, (FilterExpression, str)): self._filter_expression = filter_expression else: raise TypeError( - "filter_expression must be of type FilterExpression or None" + "filter_expression must be of type FilterExpression or string or None" ) # Reset the query string self._query_string = self._build_query_string() @property - def filter(self) -> FilterExpression: + def filter(self) -> Union[str, FilterExpression]: """The filter expression for the query.""" return self._filter_expression @@ -70,7 +73,7 @@ def params(self) -> Dict[str, Any]: class FilterQuery(BaseQuery): def __init__( self, - filter_expression: Optional[FilterExpression] = None, + filter_expression: Optional[Union[str, FilterExpression]] = None, return_fields: Optional[List[str]] = None, num_results: int = 10, dialect: int = 2, @@ -81,7 +84,7 @@ def __init__( """A query for running a filtered search with a filter expression. Args: - filter_expression (Optional[FilterExpression]): The optional filter + filter_expression (Optional[Union[str, FilterExpression]]): The optional filter expression to query with. Defaults to '*'. return_fields (Optional[List[str]], optional): The fields to return. num_results (Optional[int], optional): The number of results to return. Defaults to 10. @@ -93,8 +96,8 @@ def __init__( Raises: TypeError: If filter_expression is not of type redisvl.query.FilterExpression """ - if filter_expression: - self._filter_expression = filter_expression + self.set_filter(filter_expression) + if params: self._params = params @@ -117,22 +120,22 @@ def __init__( def _build_query_string(self) -> str: """Build the full query string based on the filter and other components.""" - # Example logic to build the full query string from filter and other parts - # This can be customized in child classes for more complex queries - return str(self._filter_expression) + if isinstance(self._filter_expression, FilterExpression): + return str(self._filter_expression) + return self._filter_expression class CountQuery(BaseQuery): def __init__( self, - filter_expression: Optional[FilterExpression] = None, + filter_expression: Optional[Union[str, FilterExpression]] = None, dialect: int = 2, params: Optional[Dict[str, Any]] = None, ): """A query for a simple count operation provided some filter expression. Args: - filter_expression (Optional[FilterExpression]): The filter expression to query with. Defaults to None. + filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to query with. Defaults to None. params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None. Raises: @@ -148,8 +151,8 @@ def __init__( count = index.query(query) """ - if filter_expression: - self._filter_expression = filter_expression + self.set_filter(filter_expression) + if params: self._params = params @@ -162,9 +165,9 @@ def __init__( def _build_query_string(self) -> str: """Build the full query string based on the filter and other components.""" - # Example logic to build the full query string from filter and other parts - # This can be customized in child classes for more complex queries - return str(self._filter_expression) + if isinstance(self._filter_expression, FilterExpression): + return str(self._filter_expression) + return self._filter_expression class BaseVectorQuery: @@ -178,7 +181,7 @@ def __init__( vector: Union[List[float], bytes], vector_field_name: str, return_fields: Optional[List[str]] = None, - filter_expression: Optional[FilterExpression] = None, + filter_expression: Optional[Union[str, FilterExpression]] = None, dtype: str = "float32", num_results: int = 10, return_score: bool = True, @@ -195,7 +198,7 @@ def __init__( against in the database. return_fields (List[str]): The declared fields to return with search results. - filter_expression (FilterExpression, optional): A filter to apply + filter_expression (Union[str, FilterExpression], optional): A filter to apply along with the vector search. Defaults to None. dtype (str, optional): The dtype of the vector. Defaults to "float32". @@ -217,15 +220,13 @@ def __init__( Note: Learn more about vector queries in Redis: https://redis.io/docs/interact/search-and-query/search/vectors/#knn-search """ - if filter_expression: - self._filter_expression = filter_expression - self._vector = vector self._vector_field_name = vector_field_name self._dtype = dtype self._num_results = num_results - + self.set_filter(filter_expression) query_string = self._build_query_string() + super().__init__(query_string) # Handle query modifiers @@ -247,7 +248,10 @@ def __init__( def _build_query_string(self) -> str: """Build the full query string for vector search with optional filtering.""" - return f"{str(self._filter_expression)}=>[KNN {self._num_results} @{self._vector_field_name} ${self.VECTOR_PARAM} AS {self.DISTANCE_ID}]" + filter_expression = self._filter_expression + if isinstance(filter_expression, FilterExpression): + filter_expression = str(filter_expression) + return f"{filter_expression}=>[KNN {self._num_results} @{self._vector_field_name} ${self.VECTOR_PARAM} AS {self.DISTANCE_ID}]" @property def params(self) -> Dict[str, Any]: @@ -272,7 +276,7 @@ def __init__( vector: Union[List[float], bytes], vector_field_name: str, return_fields: Optional[List[str]] = None, - filter_expression: Optional[FilterExpression] = None, + filter_expression: Optional[Union[str, FilterExpression]] = None, dtype: str = "float32", distance_threshold: float = 0.2, num_results: int = 10, @@ -290,7 +294,7 @@ def __init__( against in the database. return_fields (List[str]): The declared fields to return with search results. - filter_expression (FilterExpression, optional): A filter to apply + filter_expression (Union[str, FilterExpression], optional): A filter to apply along with the range query. Defaults to None. dtype (str, optional): The dtype of the vector. Defaults to "float32". @@ -316,16 +320,14 @@ def __init__( Learn more about vector range queries: https://redis.io/docs/interact/search-and-query/search/vectors/#range-query """ - if filter_expression: - self._filter_expression = filter_expression - self._vector = vector self._vector_field_name = vector_field_name self._dtype = dtype self._num_results = num_results self.set_distance_threshold(distance_threshold) - + self.set_filter(filter_expression) query_string = self._build_query_string() + super().__init__(query_string) # Handle query modifiers @@ -348,13 +350,14 @@ def __init__( def _build_query_string(self) -> str: """Build the full query string for vector range queries with optional filtering""" base_query = f"@{self._vector_field_name}:[VECTOR_RANGE ${self.DISTANCE_THRESHOLD_PARAM} ${self.VECTOR_PARAM}]" - _filter = str(self._filter_expression) - if _filter != "*": - return ( - f"({base_query}=>{{$yield_distance_as: {self.DISTANCE_ID}}} {_filter})" - ) - else: + + filter_expression = self._filter_expression + if isinstance(filter_expression, FilterExpression): + filter_expression = str(filter_expression) + + if filter_expression == "*": return f"{base_query}=>{{$yield_distance_as: {self.DISTANCE_ID}}}" + return f"({base_query}=>{{$yield_distance_as: {self.DISTANCE_ID}}} {filter_expression})" def set_distance_threshold(self, distance_threshold: float): """Set the distance threshold for the query. diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py index 752d5fe7..2c6dc376 100644 --- a/tests/integration/test_query.py +++ b/tests/integration/test_query.py @@ -312,8 +312,52 @@ def test_filters(index, query): t = Text("job") % "" search(query, index, t, 7) - t = Text("job") % None - search(query, index, t, 7) + +def test_manual_string_filters(index, query): + # Simple Tag Filter + t = "@credit_score:{high}" + search(query, index, t, 4, credit_check="high") + + # Multiple Tags + t = "@credit_score:{high|low}" + search(query, index, t, 6) + + # Simple Numeric Filter + n1 = "@age:[18 +inf]" + search(query, index, n1, 4, age_range=(18, 100)) + + # intersection of rules + n2 = "@age:[18 (100]" + search(query, index, n2, 3, age_range=(18, 99)) + + n3 = "(@age:[-inf (18] | @age:[(94 +inf])" + search(query, index, n3, 4, age_range=(95, 17)) + + n4 = "(-@age:[18 18])" + search(query, index, n4, 6, age_range=(0, 0, 18)) + + # Geographic filters + g = "@location:[-122.4194 37.7749 1 m]" + search(query, index, g, 3, location="-122.4194,37.7749") + + g = "(-@location:[-122.4194 37.7749 1 m])" + search(query, index, g, 4, location="-110.0839,37.3861") + + # Text filters + t = "@job:engineer" + search(query, index, t, 2) + + t = "(-@job:engineer)" + search(query, index, t, 5) + + t = "@job:enginee*" + search(query, index, t, 2) + + t = "@job:(engine*|doctor)" + search(query, index, t, 4) + + t = "@job:*engine*" + search(query, index, t, 2) def test_filter_combinations(index, query): diff --git a/tests/unit/test_query_types.py b/tests/unit/test_query_types.py index 17426868..1e9fdb08 100644 --- a/tests/unit/test_query_types.py +++ b/tests/unit/test_query_types.py @@ -252,3 +252,28 @@ def test_query_modifiers(query): assert query._no_stopwords assert query._with_scores assert query._fields == ("test",) + + +@pytest.mark.parametrize( + "query", + [ + CountQuery(), + FilterQuery(), + VectorQuery(vector=[1, 2, 3], vector_field_name="vector"), + RangeQuery(vector=[1, 2, 3], vector_field_name="vector"), + ], +) +def test_string_filter_expressions(query): + # No filter + query.set_filter("*") + assert query._filter_expression == "*" + + # Simple full text search + query.set_filter("hello world") + assert query._filter_expression == "hello world" + assert query.query_string().__contains__("hello world") + + # Optional flag + query.set_filter("~(@desciption:(hello | world))") + assert query._filter_expression == "~(@desciption:(hello | world))" + assert query.query_string().__contains__("~(@desciption:(hello | world))")