Skip to content

Commit

Permalink
use index method wrappers and add async support
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerhutcherson committed Nov 20, 2023
1 parent 74eed62 commit 77777ad
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 11 deletions.
57 changes: 46 additions & 11 deletions redisvl/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,18 @@
)
from redisvl.storage import BaseStorage, HashStorage, JsonStorage
from redisvl.utils.connection import (
check_async_index_exists,
check_connected,
check_index_exists,
get_async_redis_connection,
get_redis_connection,
)
from redisvl.utils.utils import check_redis_modules_exist, convert_bytes, make_dict
from redisvl.utils.utils import (
check_async_modules_present,
check_modules_present,
convert_bytes,
make_dict,
)


def process_results(
Expand Down Expand Up @@ -211,14 +218,18 @@ def from_existing(
raise NotImplementedError

@check_connected("_redis_conn")
@check_modules_present("_redis_conn")
@check_index_exists()
def search(self, *args, **kwargs) -> Union["Result", Any]:
raise NotImplementedError

@check_connected("_redis_conn")
@check_modules_present("_redis_conn")
@check_index_exists()
def query(self, query: "BaseQuery") -> List[Dict[str, Any]]:
raise NotImplementedError

def connect(self, redis_url: str, **kwargs):
def connect(self, url: str, **kwargs):
"""Connect to a Redis instance."""
raise NotImplementedError

Expand All @@ -244,15 +255,24 @@ def key(self, key_value: str) -> str:
)

@check_connected("_redis_conn")
@check_modules_present("_redis_conn")
@check_index_exists()
def info(self) -> Dict[str, Any]:
raise NotImplementedError

@check_connected("_redis_conn")
@check_modules_present("_redis_conn")
def create(self, overwrite: Optional[bool] = False):
raise NotImplementedError

@check_connected("_redis_conn")
@check_modules_present("_redis_conn")
@check_index_exists()
def delete(self, drop: bool = True):
raise NotImplementedError

@check_connected("_redis_conn")
@check_modules_present("_redis_conn")
def load(
self,
data: Iterable[Dict[str, Any]],
Expand Down Expand Up @@ -332,6 +352,7 @@ def connect(self, redis_url: Optional[str] = None, **kwargs):
return self

@check_connected("_redis_conn")
@check_modules_present("_redis_conn")
def create(self, overwrite: Optional[bool] = False) -> None:
"""Create an index in Redis from this SearchIndex object.
Expand All @@ -342,9 +363,6 @@ def create(self, overwrite: Optional[bool] = False) -> None:
RuntimeError: If the index already exists and 'overwrite' is False.
ValueError: If no fields are defined for the index.
"""
# Ensure that the Redis connection has the necessary modules.
check_redis_modules_exist(self._redis_conn)

# Check that fields are defined.
fields = self._schema.index_fields
if not fields:
Expand All @@ -368,6 +386,8 @@ def create(self, overwrite: Optional[bool] = False) -> None:
)

@check_connected("_redis_conn")
@check_modules_present("_redis_conn")
@check_index_exists()
def delete(self, drop: bool = True):
"""Delete the search index.
Expand All @@ -378,11 +398,10 @@ def delete(self, drop: bool = True):
redis.exceptions.ResponseError: If the index does not exist.
"""
# Delete the search index
self._redis_conn.ft(self._schema.index.name).dropindex(
delete_documents=drop
) # type: ignore
self._redis_conn.ft(self._schema.index.name).dropindex(delete_documents=drop)

@check_connected("_redis_conn")
@check_modules_present("_redis_conn")
def load(
self,
data: Iterable[Any],
Expand Down Expand Up @@ -423,6 +442,8 @@ def load(
)

@check_connected("_redis_conn")
@check_modules_present("_redis_conn")
@check_index_exists()
def search(self, *args, **kwargs) -> Union["Result", Any]:
"""Perform a search on this index.
Expand All @@ -439,6 +460,8 @@ def search(self, *args, **kwargs) -> Union["Result", Any]:
return results

@check_connected("_redis_conn")
@check_modules_present("_redis_conn")
@check_index_exists()
def query(self, query: "BaseQuery") -> List[Dict[str, Any]]:
"""Run a query on this index.
Expand All @@ -459,6 +482,7 @@ def query(self, query: "BaseQuery") -> List[Dict[str, Any]]:
)

@check_connected("_redis_conn")
@check_modules_present("_redis_conn")
def exists(self) -> bool:
"""Check if the index exists in Redis.
Expand All @@ -469,6 +493,8 @@ def exists(self) -> bool:
return self._schema.index.name in indices

@check_connected("_redis_conn")
@check_modules_present("_redis_conn")
@check_index_exists()
def info(self) -> Dict[str, Any]:
"""Get information about the index.
Expand Down Expand Up @@ -549,6 +575,7 @@ def connect(self, redis_url: Optional[str] = None, **kwargs):
return self

@check_connected("_redis_conn")
@check_async_modules_present("_redis_conn")
async def create(self, overwrite: Optional[bool] = False) -> None:
"""Asynchronously create an index in Redis from this SearchIndex object.
Expand All @@ -558,9 +585,6 @@ async def create(self, overwrite: Optional[bool] = False) -> None:
Raises:
RuntimeError: If the index already exists and 'overwrite' is False.
"""
# TODO - enable async version of this
# check_redis_modules_exist(self._redis_conn)

fields = self._schema.index_fields
if not fields:
raise ValueError("No fields defined for index")
Expand All @@ -583,6 +607,8 @@ async def create(self, overwrite: Optional[bool] = False) -> None:
)

@check_connected("_redis_conn")
@check_async_modules_present("_redis_conn")
@check_async_index_exists()
async def delete(self, drop: bool = True):
"""Delete the search index.
Expand All @@ -598,6 +624,7 @@ async def delete(self, drop: bool = True):
) # type: ignore

@check_connected("_redis_conn")
@check_async_modules_present("_redis_conn")
async def load(
self,
data: Iterable[Any],
Expand Down Expand Up @@ -638,6 +665,8 @@ async def load(
)

@check_connected("_redis_conn")
@check_async_modules_present("_redis_conn")
@check_async_index_exists()
async def search(self, *args, **kwargs) -> Union["Result", Any]:
"""Perform a search on this index.
Expand All @@ -653,6 +682,9 @@ async def search(self, *args, **kwargs) -> Union["Result", Any]:
) # type: ignore
return results

@check_connected("_redis_conn")
@check_async_modules_present("_redis_conn")
@check_async_index_exists()
async def query(self, query: "BaseQuery") -> List[Dict[str, Any]]:
"""Run a query on this index.
Expand All @@ -673,6 +705,7 @@ async def query(self, query: "BaseQuery") -> List[Dict[str, Any]]:
)

@check_connected("_redis_conn")
@check_async_modules_present("_redis_conn")
async def exists(self) -> bool:
"""Check if the index exists in Redis.
Expand All @@ -683,6 +716,8 @@ async def exists(self) -> bool:
return self._schema.index.name in convert_bytes(indices)

@check_connected("_redis_conn")
@check_async_modules_present("_redis_conn")
@check_async_index_exists()
async def info(self) -> Dict[str, Any]:
"""Get information about the index.
Expand Down
30 changes: 30 additions & 0 deletions redisvl/utils/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,36 @@ def get_address_from_env():
return addr


def check_index_exists():
def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if not self.exists():
raise ValueError(
f"Index has not been created. Must be created before calling {func.__name__}"
)
return func(self, *args, **kwargs)

return wrapper

return decorator


async def check_async_index_exists():
async def decorator(func):
@wraps(func)
async def wrapper(self, *args, **kwargs):
if not await self.exists():
raise ValueError(
f"Index has not been created. Must be created before calling {func.__name__}"
)
return func(self, *args, **kwargs)

return wrapper

return decorator


def check_connected(client_variable_name: str):
def decorator(func):
@wraps(func)
Expand Down
48 changes: 48 additions & 0 deletions redisvl/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import wraps
from typing import Any, List

import numpy as np
Expand Down Expand Up @@ -53,6 +54,53 @@ def check_redis_modules_exist(client) -> None:
raise ValueError(error_message)


async def check_async_redis_modules_exist(client) -> None:
"""Check if the correct Redis modules are installed."""
installed_modules = await client.module_list()
installed_modules = {
module[b"name"].decode("utf-8"): module for module in installed_modules
}
for module in REDIS_REQUIRED_MODULES:
if module["name"] in installed_modules and int(
installed_modules[module["name"]][b"ver"]
) >= int(
module["ver"]
): # type: ignore[call-overload]
return
# otherwise raise error
error_message = (
"You must add the RediSearch (>= 2.4) module from Redis Stack. "
"Please refer to Redis Stack docs: https://redis.io/docs/stack/"
)
raise ValueError(error_message)


def check_modules_present(client_variable_name: str):
def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
client = getattr(self, client_variable_name)
check_redis_modules_exist(client)
return func(self, *args, **kwargs)

return wrapper

return decorator


async def check_async_modules_present(client_variable_name: str):
async def decorator(func):
@wraps(func)
async def wrapper(self, *args, **kwargs):
client = getattr(self, client_variable_name)
await check_redis_modules_exist(client)
return func(self, *args, **kwargs)

return wrapper

return decorator


def array_to_buffer(array: List[float], dtype: Any = np.float32) -> bytes:
"""Convert a list of floats into a numpy byte string."""
return np.array(array).astype(dtype).tobytes()

0 comments on commit 77777ad

Please sign in to comment.