Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam Partee committed Nov 22, 2023
1 parent f0bab76 commit fbe1891
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 24 deletions.
26 changes: 6 additions & 20 deletions redisvl/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,7 @@ def _key(key_value: str, prefix: str, key_separator: str) -> str:
else:
return f"{prefix}{key_separator}{key_value}"

def _create_key(
self,
obj: Dict[str, Any],
key_field: Optional[str] = None
) -> str:
def _create_key(self, obj: Dict[str, Any], key_field: Optional[str] = None) -> str:
"""Construct a Redis key for a given object, optionally using a
specified field from the object as the key.
Expand Down Expand Up @@ -77,9 +73,7 @@ def _create_key(
)

@staticmethod
def _preprocess(
obj: Any, preprocess: Optional[Callable] = None
) -> Dict[str, Any]:
def _preprocess(obj: Any, preprocess: Optional[Callable] = None) -> Dict[str, Any]:
"""Apply a preprocessing function to the object if provided.
Args:
Expand Down Expand Up @@ -203,9 +197,7 @@ def write(
length of objects.
"""
if keys and len(keys) != len(objects): # type: ignore
raise ValueError(
"Length of keys does not match the length of objects"
)
raise ValueError("Length of keys does not match the length of objects")

if batch_size is None:
batch_size = (
Expand Down Expand Up @@ -267,9 +259,7 @@ async def awrite(
length of objects.
"""
if keys and len(keys) != len(objects): # type: ignore
raise ValueError(
"Length of keys does not match the length of objects"
)
raise ValueError("Length of keys does not match the length of objects")

if not concurrency:
concurrency = self.DEFAULT_WRITE_CONCURRENCY
Expand All @@ -289,19 +279,15 @@ async def _load(obj: Dict[str, Any], key: Optional[str] = None) -> None:

if keys_iterator:
tasks = [
asyncio.create_task(
_load(obj, next(keys_iterator))) for obj in objects
asyncio.create_task(_load(obj, next(keys_iterator))) for obj in objects
]
else:
tasks = [asyncio.create_task(_load(obj)) for obj in objects]

await asyncio.gather(*tasks)

def get(
self,
redis_client: Redis,
keys: Iterable[str],
batch_size: Optional[int] = None
self, redis_client: Redis, keys: Iterable[str], batch_size: Optional[int] = None
) -> List[Dict[str, Any]]:
"""Retrieve objects from Redis by keys.
Expand Down
7 changes: 3 additions & 4 deletions tests/unit/test_query.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
import pytest

from redis.commands.search.document import Document
from redis.commands.search.result import Result
from redis.commands.search.query import Query
from redis.commands.search.result import Result

from redisvl.query import CountQuery, FilterQuery, VectorQuery
from redisvl.index import process_results
from redisvl.query import CountQuery, FilterQuery, VectorQuery
from redisvl.query.filter import FilterExpression, Tag


# Sample data for testing
sample_vector = [0.1, 0.2, 0.3, 0.4]


# Test Cases


def test_count_query():
# Create a filter expression
filter_expression = Tag("brand") == "Nike"
Expand Down

0 comments on commit fbe1891

Please sign in to comment.