From 19dedcbe94eeab6f6f4aa0c1dbf3083642fb48f9 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Tue, 8 Oct 2024 16:52:39 -0400 Subject: [PATCH] Expose aggregation API from SearchIndex (#230) In order to support more advanced queries, we expose the `aggregate` method to pass through to the core Redis FT.AGGREGATE API. This PR also simplifies and standardizes error handling for Redis searches/aggregations on the index. --- redisvl/exceptions.py | 4 + redisvl/extensions/router/semantic.py | 12 +- redisvl/index/index.py | 112 +++++++++---------- tests/integration/test_async_search_index.py | 7 +- tests/integration/test_search_index.py | 7 +- 5 files changed, 73 insertions(+), 69 deletions(-) diff --git a/redisvl/exceptions.py b/redisvl/exceptions.py index 79b165e3..e645e3e2 100644 --- a/redisvl/exceptions.py +++ b/redisvl/exceptions.py @@ -4,3 +4,7 @@ class RedisVLException(Exception): class RedisModuleVersionError(RedisVLException): """Invalid module versions installed""" + + +class RedisSearchError(RedisVLException): + """Error while performing a search or aggregate request""" diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index 52ea321f..f17e5ab3 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -256,9 +256,9 @@ def _classify_route( ) try: - aggregation_result: AggregateResult = self._index.client.ft( # type: ignore - self._index.name - ).aggregate(aggregate_request, vector_range_query.params) + aggregation_result: AggregateResult = self._index.aggregate( + aggregate_request, vector_range_query.params + ) except ResponseError as e: if "VSS is not yet supported on FT.AGGREGATE" in str(e): raise RuntimeError( @@ -308,9 +308,9 @@ def _classify_multi_route( ) try: - aggregation_result: AggregateResult = self._index.client.ft( # type: ignore - self._index.name - ).aggregate(aggregate_request, vector_range_query.params) + aggregation_result: AggregateResult = self._index.aggregate( + aggregate_request, vector_range_query.params + ) except ResponseError as e: if "VSS is not yet supported on FT.AGGREGATE" in str(e): raise RuntimeError( diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 0c4d3e2a..025d6ce9 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -17,6 +17,7 @@ ) if TYPE_CHECKING: + from redis.commands.search.aggregation import AggregateResult from redis.commands.search.document import Document from redis.commands.search.result import Result from redisvl.query.query import BaseQuery @@ -25,7 +26,7 @@ import redis.asyncio as aredis from redis.commands.search.indexDefinition import IndexDefinition -from redisvl.exceptions import RedisModuleVersionError +from redisvl.exceptions import RedisModuleVersionError, RedisSearchError from redisvl.index.storage import BaseStorage, HashStorage, JsonStorage from redisvl.query import BaseQuery, CountQuery, FilterQuery from redisvl.query.filter import FilterExpression @@ -123,36 +124,6 @@ async def wrapper(self, *args, **kwargs): return decorator -def check_index_exists(): - def decorator(func): - @wraps(func) - def wrapper(self, *args, **kwargs): - if not self.exists(): - raise RuntimeError( - f"Index has not been created. Must be created before calling {func.__name__}" - ) - return func(self, *args, **kwargs) - - return wrapper - - return decorator - - -def check_async_index_exists(): - def decorator(func): - @wraps(func) - async def wrapper(self, *args, **kwargs): - if not await self.exists(): - raise ValueError( - f"Index has not been created. Must be created before calling {func.__name__}" - ) - return await func(self, *args, **kwargs) - - return wrapper - - return decorator - - class BaseSearchIndex: """Base search engine class""" @@ -486,7 +457,6 @@ def create(self, overwrite: bool = False, drop: bool = False) -> None: logger.exception("Error while trying to create the index") raise - @check_index_exists() def delete(self, drop: bool = True): """Delete the search index while optionally dropping all keys associated with the index. @@ -502,8 +472,8 @@ def delete(self, drop: bool = True): self._redis_client.ft(self.schema.index.name).dropindex( # type: ignore delete_documents=drop ) - except: - logger.exception("Error while deleting index") + except Exception as e: + raise RedisSearchError(f"Error while deleting index: {str(e)}") from e def clear(self) -> int: """Clear all keys in Redis associated with the index, leaving the index @@ -629,13 +599,29 @@ def fetch(self, id: str) -> Optional[Dict[str, Any]]: return convert_bytes(obj[0]) return None - @check_index_exists() + def aggregate(self, *args, **kwargs) -> "AggregateResult": + """Perform an aggregation operation against the index. + + Wrapper around the aggregation API that adds the index name + to the query and passes along the rest of the arguments + to the redis-py ft().aggregate() method. + + Returns: + Result: Raw Redis aggregation results. + """ + try: + return self._redis_client.ft(self.schema.index.name).aggregate( # type: ignore + *args, **kwargs + ) + except Exception as e: + raise RedisSearchError(f"Error while aggregating: {str(e)}") from e + def search(self, *args, **kwargs) -> "Result": """Perform a search against the index. - Wrapper around redis.search.Search that adds the index name - to the search query and passes along the rest of the arguments - to the redis-py ft.search() method. + Wrapper around the search API that adds the index name + to the query and passes along the rest of the arguments + to the redis-py ft().search() method. Returns: Result: Raw Redis search results. @@ -644,9 +630,8 @@ def search(self, *args, **kwargs) -> "Result": return self._redis_client.ft(self.schema.index.name).search( # type: ignore *args, **kwargs ) - except: - logger.exception("Error while searching") - raise + except Exception as e: + raise RedisSearchError(f"Error while searching: {str(e)}") from e def _query(self, query: BaseQuery) -> List[Dict[str, Any]]: """Execute a query and process results.""" @@ -752,11 +737,11 @@ def _info(name: str, redis_client: redis.Redis) -> Dict[str, Any]: """Run FT.INFO to fetch information about the index.""" try: return convert_bytes(redis_client.ft(name).info()) # type: ignore - except: - logger.exception(f"Error while fetching {name} index info") - raise + except Exception as e: + raise RedisSearchError( + f"Error while fetching {name} index info: {str(e)}" + ) from e - @check_index_exists() def info(self, name: Optional[str] = None) -> Dict[str, Any]: """Get information about the index. @@ -1010,7 +995,6 @@ async def create(self, overwrite: bool = False, drop: bool = False) -> None: logger.exception("Error while trying to create the index") raise - @check_async_index_exists() async def delete(self, drop: bool = True): """Delete the search index. @@ -1025,9 +1009,8 @@ async def delete(self, drop: bool = True): await self._redis_client.ft(self.schema.index.name).dropindex( # type: ignore delete_documents=drop ) - except: - logger.exception("Error while deleting index") - raise + except Exception as e: + raise RedisSearchError(f"Error while deleting index: {str(e)}") from e async def clear(self) -> int: """Clear all keys in Redis associated with the index, leaving the index @@ -1152,7 +1135,23 @@ async def fetch(self, id: str) -> Optional[Dict[str, Any]]: return convert_bytes(obj[0]) return None - @check_async_index_exists() + async def aggregate(self, *args, **kwargs) -> "AggregateResult": + """Perform an aggregation operation against the index. + + Wrapper around the aggregation API that adds the index name + to the query and passes along the rest of the arguments + to the redis-py ft().aggregate() method. + + Returns: + Result: Raw Redis aggregation results. + """ + try: + return await self._redis_client.ft(self.schema.index.name).aggregate( # type: ignore + *args, **kwargs + ) + except Exception as e: + raise RedisSearchError(f"Error while aggregating: {str(e)}") from e + async def search(self, *args, **kwargs) -> "Result": """Perform a search on this index. @@ -1167,9 +1166,8 @@ async def search(self, *args, **kwargs) -> "Result": return await self._redis_client.ft(self.schema.index.name).search( # type: ignore *args, **kwargs ) - except: - logger.exception("Error while searching") - raise + except Exception as e: + raise RedisSearchError(f"Error while searching: {str(e)}") from e async def _query(self, query: BaseQuery) -> List[Dict[str, Any]]: """Asynchronously execute a query and process results.""" @@ -1275,11 +1273,11 @@ async def exists(self) -> bool: async def _info(name: str, redis_client: aredis.Redis) -> Dict[str, Any]: try: return convert_bytes(await redis_client.ft(name).info()) # type: ignore - except: - logger.exception(f"Error while fetching {name} index info") - raise + except Exception as e: + raise RedisSearchError( + f"Error while fetching {name} index info: {str(e)}" + ) from e - @check_async_index_exists() async def info(self, name: Optional[str] = None) -> Dict[str, Any]: """Get information about the index. diff --git a/tests/integration/test_async_search_index.py b/tests/integration/test_async_search_index.py index a5c937b6..9dc8460d 100644 --- a/tests/integration/test_async_search_index.py +++ b/tests/integration/test_async_search_index.py @@ -1,5 +1,6 @@ import pytest +from redisvl.exceptions import RedisSearchError from redisvl.index import AsyncSearchIndex from redisvl.query import VectorQuery from redisvl.redis.utils import convert_bytes @@ -291,7 +292,7 @@ async def test_check_index_exists_before_delete(async_client, async_index): await async_index.set_client(async_client) await async_index.create(overwrite=True, drop=True) await async_index.delete(drop=True) - with pytest.raises(ValueError): + with pytest.raises(RedisSearchError): await async_index.delete() @@ -307,7 +308,7 @@ async def test_check_index_exists_before_search(async_client, async_index): return_fields=["user", "credit_score", "age", "job", "location"], num_results=7, ) - with pytest.raises(ValueError): + with pytest.raises(RedisSearchError): await async_index.search(query.query, query_params=query.params) @@ -317,5 +318,5 @@ async def test_check_index_exists_before_info(async_client, async_index): await async_index.create(overwrite=True, drop=True) await async_index.delete(drop=True) - with pytest.raises(ValueError): + with pytest.raises(RedisSearchError): await async_index.info() diff --git a/tests/integration/test_search_index.py b/tests/integration/test_search_index.py index 574cec91..36781a28 100644 --- a/tests/integration/test_search_index.py +++ b/tests/integration/test_search_index.py @@ -1,5 +1,6 @@ import pytest +from redisvl.exceptions import RedisSearchError from redisvl.index import SearchIndex from redisvl.query import VectorQuery from redisvl.redis.connection import RedisConnectionFactory, validate_modules @@ -251,7 +252,7 @@ def test_check_index_exists_before_delete(client, index): index.set_client(client) index.create(overwrite=True, drop=True) index.delete(drop=True) - with pytest.raises(RuntimeError): + with pytest.raises(RedisSearchError): index.delete() @@ -266,7 +267,7 @@ def test_check_index_exists_before_search(client, index): return_fields=["user", "credit_score", "age", "job", "location"], num_results=7, ) - with pytest.raises(RuntimeError): + with pytest.raises(RedisSearchError): index.search(query.query, query_params=query.params) @@ -275,7 +276,7 @@ def test_check_index_exists_before_info(client, index): index.create(overwrite=True, drop=True) index.delete(drop=True) - with pytest.raises(RuntimeError): + with pytest.raises(RedisSearchError): index.info()