From fbe1891385e192d9d1c61d7f8eb0f465cf028e32 Mon Sep 17 00:00:00 2001 From: Sam Partee Date: Tue, 21 Nov 2023 18:41:07 -0800 Subject: [PATCH] formatting --- redisvl/storage.py | 26 ++++++-------------------- tests/unit/test_query.py | 7 +++---- 2 files changed, 9 insertions(+), 24 deletions(-) diff --git a/redisvl/storage.py b/redisvl/storage.py index 2e92ca40..a1b9daba 100644 --- a/redisvl/storage.py +++ b/redisvl/storage.py @@ -45,11 +45,7 @@ def _key(key_value: str, prefix: str, key_separator: str) -> str: else: return f"{prefix}{key_separator}{key_value}" - def _create_key( - self, - obj: Dict[str, Any], - key_field: Optional[str] = None - ) -> str: + def _create_key(self, obj: Dict[str, Any], key_field: Optional[str] = None) -> str: """Construct a Redis key for a given object, optionally using a specified field from the object as the key. @@ -77,9 +73,7 @@ def _create_key( ) @staticmethod - def _preprocess( - obj: Any, preprocess: Optional[Callable] = None - ) -> Dict[str, Any]: + def _preprocess(obj: Any, preprocess: Optional[Callable] = None) -> Dict[str, Any]: """Apply a preprocessing function to the object if provided. Args: @@ -203,9 +197,7 @@ def write( length of objects. """ if keys and len(keys) != len(objects): # type: ignore - raise ValueError( - "Length of keys does not match the length of objects" - ) + raise ValueError("Length of keys does not match the length of objects") if batch_size is None: batch_size = ( @@ -267,9 +259,7 @@ async def awrite( length of objects. """ if keys and len(keys) != len(objects): # type: ignore - raise ValueError( - "Length of keys does not match the length of objects" - ) + raise ValueError("Length of keys does not match the length of objects") if not concurrency: concurrency = self.DEFAULT_WRITE_CONCURRENCY @@ -289,8 +279,7 @@ async def _load(obj: Dict[str, Any], key: Optional[str] = None) -> None: if keys_iterator: tasks = [ - asyncio.create_task( - _load(obj, next(keys_iterator))) for obj in objects + asyncio.create_task(_load(obj, next(keys_iterator))) for obj in objects ] else: tasks = [asyncio.create_task(_load(obj)) for obj in objects] @@ -298,10 +287,7 @@ async def _load(obj: Dict[str, Any], key: Optional[str] = None) -> None: await asyncio.gather(*tasks) def get( - self, - redis_client: Redis, - keys: Iterable[str], - batch_size: Optional[int] = None + self, redis_client: Redis, keys: Iterable[str], batch_size: Optional[int] = None ) -> List[Dict[str, Any]]: """Retrieve objects from Redis by keys. diff --git a/tests/unit/test_query.py b/tests/unit/test_query.py index 9d865e0a..8fc380fe 100644 --- a/tests/unit/test_query.py +++ b/tests/unit/test_query.py @@ -1,20 +1,19 @@ import pytest - from redis.commands.search.document import Document -from redis.commands.search.result import Result from redis.commands.search.query import Query +from redis.commands.search.result import Result -from redisvl.query import CountQuery, FilterQuery, VectorQuery from redisvl.index import process_results +from redisvl.query import CountQuery, FilterQuery, VectorQuery from redisvl.query.filter import FilterExpression, Tag - # Sample data for testing sample_vector = [0.1, 0.2, 0.3, 0.4] # Test Cases + def test_count_query(): # Create a filter expression filter_expression = Tag("brand") == "Nike"