Skip to content

Commit

Permalink
formatting and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerhutcherson committed Sep 6, 2024
1 parent a3ff11d commit 2e2004a
Show file tree
Hide file tree
Showing 3 changed files with 224 additions and 13 deletions.
4 changes: 1 addition & 3 deletions tests/integration/test_async_search_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,7 @@ async def test_search_index_client(async_client, index_schema):
async def test_search_index_set_client(async_client, client, async_index):
await async_index.set_client(async_client)
assert async_index.client == async_client
# should not be able to set the sync client here
with pytest.raises(TypeError):
await async_index.set_client(client)
await async_index.set_client(client)

async_index.disconnect()
assert async_index.client == None
Expand Down
230 changes: 222 additions & 8 deletions tests/integration/test_llmcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from redis.exceptions import ConnectionError

from redisvl.extensions.llmcache import SemanticCache
from redisvl.index.index import SearchIndex
from redisvl.index.index import AsyncSearchIndex, SearchIndex
from redisvl.query.filter import Num, Tag, Text
from redisvl.utils.vectorize import HFTextVectorizer

Expand Down Expand Up @@ -89,6 +89,33 @@ def test_reset_ttl(cache):
assert cache.ttl is None


def test_get_index(cache):
    """The cache's ``index`` attribute is a ``SearchIndex`` instance."""
    backing_index = cache.index
    assert isinstance(backing_index, SearchIndex)


@pytest.mark.asyncio
async def test_get_async_index(cache):
    """Awaiting ``_get_async_index`` yields an ``AsyncSearchIndex`` instance."""
    assert isinstance(await cache._get_async_index(), AsyncSearchIndex)


@pytest.mark.asyncio
async def test_get_async_index_from_provided_client(cache_with_redis_client):
    """``_get_async_index`` also yields an ``AsyncSearchIndex`` for this fixture.

    The cache under test comes from the ``cache_with_redis_client`` fixture.
    """
    async_index = await cache_with_redis_client._get_async_index()
    assert isinstance(async_index, AsyncSearchIndex)


def test_delete(cache_no_cleanup):
    """After ``delete()`` the cache's index no longer reports that it exists."""
    cache_no_cleanup.delete()
    index_still_exists = cache_no_cleanup.index.exists()
    assert not index_still_exists


@pytest.mark.asyncio
async def test_async_delete(cache_no_cleanup):
    """After awaiting ``adelete()`` the cache's index no longer exists."""
    await cache_no_cleanup.adelete()
    index_still_exists = cache_no_cleanup.index.exists()
    assert not index_still_exists


def test_store_and_check(cache, vectorizer):
prompt = "This is a test prompt."
response = "This is a test response."
Expand All @@ -103,6 +130,21 @@ def test_store_and_check(cache, vectorizer):
assert "metadata" not in check_result[0]


@pytest.mark.asyncio
async def test_async_store_and_check(cache, vectorizer):
    """Store one entry with ``astore`` and read it back with ``acheck``."""
    prompt, response = "This is a test prompt.", "This is a test response."
    embedding = vectorizer.embed(prompt)

    await cache.astore(prompt, response, vector=embedding)
    hits = await cache.acheck(vector=embedding, distance_threshold=0.4)

    # Exactly one hit, carrying the stored response and no metadata field.
    assert len(hits) == 1
    print(hits, flush=True)
    top_hit = hits[0]
    assert response == top_hit["response"]
    assert "metadata" not in top_hit


def test_return_fields(cache, vectorizer):
prompt = "This is a test prompt."
response = "This is a test response."
Expand Down Expand Up @@ -140,6 +182,44 @@ def test_return_fields(cache, vectorizer):
assert set(check_result[0].keys()) == set(fields)


@pytest.mark.asyncio
async def test_async_return_fields(cache, vectorizer):
    """``acheck`` returns the default fields, or those named in ``return_fields``.

    The final case asserts that ``key`` is present in the result even though
    it was not part of the requested fields.
    """
    prompt, response = "This is a test prompt.", "This is a test response."
    embedding = vectorizer.embed(prompt)
    await cache.astore(prompt, response, vector=embedding)

    # No ``return_fields`` argument: the full default field set is expected.
    default_fields = {
        "key",
        "entry_id",
        "prompt",
        "response",
        "vector_distance",
        "inserted_at",
        "updated_at",
    }
    hits = await cache.acheck(vector=embedding)
    assert set(hits[0].keys()) == default_fields

    # An explicit list of fields is returned exactly as requested.
    requested = [
        "key",
        "entry_id",
        "prompt",
        "response",
        "vector_distance",
    ]
    hits = await cache.acheck(vector=embedding, return_fields=requested)
    assert set(hits[0].keys()) == set(requested)

    # A partial list: the result holds the requested fields plus ``key``.
    requested = ["inserted_at", "updated_at"]
    hits = await cache.acheck(vector=embedding, return_fields=requested)
    assert set(hits[0].keys()) == {*requested, "key"}


# Test clearing the cache
def test_clear(cache, vectorizer):
prompt = "This is a test prompt."
Expand All @@ -153,6 +233,19 @@ def test_clear(cache, vectorizer):
assert len(check_result) == 0


@pytest.mark.asyncio
async def test_async_clear(cache, vectorizer):
    """An entry stored with ``astore`` is not returned after ``aclear``."""
    prompt, response = "This is a test prompt.", "This is a test response."
    embedding = vectorizer.embed(prompt)

    await cache.astore(prompt, response, vector=embedding)
    await cache.aclear()

    hits_after_clear = await cache.acheck(vector=embedding)
    assert len(hits_after_clear) == 0


# Test TTL functionality
def test_ttl_expiration(cache_with_ttl, vectorizer):
prompt = "This is a test prompt."
Expand All @@ -166,6 +259,19 @@ def test_ttl_expiration(cache_with_ttl, vectorizer):
assert len(check_result) == 0


@pytest.mark.asyncio
async def test_async_ttl_expiration(cache_with_ttl, vectorizer):
    """An entry stored via ``astore`` is no longer returned after a 3 s wait.

    NOTE(review): the 3 second wait is presumably longer than the TTL
    configured by the ``cache_with_ttl`` fixture -- confirm against the
    fixture definition.
    """
    # Function-scope import so this test does not depend on a module-level
    # ``asyncio`` import being present.
    import asyncio

    prompt = "This is a test prompt."
    response = "This is a test response."
    vector = vectorizer.embed(prompt)

    await cache_with_ttl.astore(prompt, response, vector=vector)
    # Yield to the event loop while waiting; a blocking ``time.sleep`` here
    # would stall every other task scheduled on the loop.
    await asyncio.sleep(3)

    check_result = await cache_with_ttl.acheck(vector=vector)
    assert len(check_result) == 0


def test_ttl_refresh(cache_with_ttl, vectorizer):
prompt = "This is a test prompt."
response = "This is a test response."
Expand All @@ -180,6 +286,21 @@ def test_ttl_refresh(cache_with_ttl, vectorizer):
assert len(check_result) == 1


@pytest.mark.asyncio
async def test_async_ttl_refresh(cache_with_ttl, vectorizer):
    """An entry checked once per second for 3 s is still returned at the end.

    NOTE(review): this presumably relies on ``acheck`` refreshing the entry's
    TTL on every hit -- confirm against the cache implementation.
    """
    # Function-scope import so this test does not depend on a module-level
    # ``asyncio`` import being present.
    import asyncio

    prompt = "This is a test prompt."
    response = "This is a test response."
    vector = vectorizer.embed(prompt)

    await cache_with_ttl.astore(prompt, response, vector=vector)

    for _ in range(3):
        # Non-blocking wait: keeps the event loop free between checks.
        await asyncio.sleep(1)
        check_result = await cache_with_ttl.acheck(vector=vector)

    # Only the result of the last check is asserted.
    assert len(check_result) == 1


# Test manual expiration of single document
def test_drop_document(cache, vectorizer):
prompt = "This is a test prompt."
Expand All @@ -194,6 +315,20 @@ def test_drop_document(cache, vectorizer):
assert len(recheck_result) == 0


@pytest.mark.asyncio
async def test_async_drop_document(cache, vectorizer):
    """Dropping an entry by ``entry_id`` via ``adrop`` removes it from results."""
    prompt, response = "This is a test prompt.", "This is a test response."
    embedding = vectorizer.embed(prompt)

    await cache.astore(prompt, response, vector=embedding)
    stored = await cache.acheck(vector=embedding)

    # Drop the single stored entry using the id reported by the first check.
    entry_id = stored[0]["entry_id"]
    await cache.adrop(ids=[entry_id])

    remaining = await cache.acheck(vector=embedding)
    assert len(remaining) == 0


# Test manual expiration of multiple documents
def test_drop_documents(cache, vectorizer):
prompts = [
Expand All @@ -219,6 +354,31 @@ def test_drop_documents(cache, vectorizer):
assert len(recheck_result) == 1


@pytest.mark.asyncio
async def test_async_drop_documents(cache, vectorizer):
    """Store three entries, drop two by id via ``adrop``, and expect one left."""
    entries = [
        ("This is a test prompt.", "This is a test response."),
        ("This is also test prompt.", "This is also test response."),
        ("This is another test prompt.", "This is a another test response."),
    ]
    for prompt, response in entries:
        vector = vectorizer.embed(prompt)
        await cache.astore(prompt, response, vector=vector)

    # ``vector`` still holds the embedding of the last prompt from the loop.
    hits = await cache.acheck(vector=vector, num_results=3)
    print(hits, flush=True)

    # Drop the first two entries that the check returned.
    doomed_ids = [hit["entry_id"] for hit in hits[:2]]
    await cache.adrop(ids=doomed_ids)

    leftover = await cache.acheck(vector=vector, num_results=3)
    assert len(leftover) == 1


# Test updating document fields
def test_updating_document(cache):
prompt = "This is a test prompt."
Expand All @@ -240,6 +400,27 @@ def test_updating_document(cache):
assert updated_result[0]["updated_at"] > check_result[0]["updated_at"]


@pytest.mark.asyncio
async def test_async_updating_document(cache):
    """``aupdate`` sets ``metadata`` on an entry and advances its ``updated_at``."""
    # Function-scope import so this test does not depend on a module-level
    # ``asyncio`` import being present.
    import asyncio

    prompt = "This is a test prompt."
    response = "This is a test response."
    await cache.astore(prompt=prompt, response=response)

    # Only ``updated_at`` is requested; ``key`` is read from the same result.
    check_result = await cache.acheck(prompt=prompt, return_fields=["updated_at"])
    key = check_result[0]["key"]

    # Let time pass so the ``updated_at`` comparison below can be strict.
    # Non-blocking wait: a ``time.sleep`` here would stall the event loop.
    await asyncio.sleep(1)

    metadata = {"foo": "bar"}
    await cache.aupdate(key=key, metadata=metadata)

    updated_result = await cache.acheck(
        prompt=prompt, return_fields=["updated_at", "metadata"]
    )
    assert updated_result[0]["metadata"] == metadata
    assert updated_result[0]["updated_at"] > check_result[0]["updated_at"]


def test_ttl_expiration_after_update(cache_with_ttl, vectorizer):
prompt = "This is a test prompt."
response = "This is a test response."
Expand All @@ -255,6 +436,22 @@ def test_ttl_expiration_after_update(cache_with_ttl, vectorizer):
assert len(check_result) == 0


@pytest.mark.asyncio
async def test_async_ttl_expiration_after_update(cache_with_ttl, vectorizer):
    """After ``set_ttl(4)``, a stored entry is not returned following a 5 s wait."""
    # Function-scope import so this test does not depend on a module-level
    # ``asyncio`` import being present.
    import asyncio

    prompt = "This is a test prompt."
    response = "This is a test response."
    vector = vectorizer.embed(prompt)

    # Override the fixture's TTL and confirm the new value is reported.
    cache_with_ttl.set_ttl(4)
    assert cache_with_ttl.ttl == 4

    await cache_with_ttl.astore(prompt, response, vector=vector)
    # Wait one second longer than the TTL set above. Non-blocking wait: a
    # ``time.sleep`` here would stall the event loop.
    await asyncio.sleep(5)

    check_result = await cache_with_ttl.acheck(vector=vector)
    assert len(check_result) == 0


# Test check behavior with no match
def test_check_no_match(cache, vectorizer):
vector = vectorizer.embed("Some random sentence.")
Expand All @@ -270,6 +467,15 @@ def test_check_invalid_input(cache):
cache.check(prompt="test", return_fields="bad value")


@pytest.mark.asyncio
async def test_async_check_invalid_input(cache):
    """``acheck`` raises on missing arguments and on a string ``return_fields``."""
    # Called with no arguments at all -> ValueError.
    with pytest.raises(ValueError):
        await cache.acheck()

    # ``return_fields`` passed as a plain string -> TypeError.
    bad_return_fields = "bad value"
    with pytest.raises(TypeError):
        await cache.acheck(prompt="test", return_fields=bad_return_fields)


def test_bad_connection_info(vectorizer):
with pytest.raises(ConnectionError):
SemanticCache(
Expand Down Expand Up @@ -357,10 +563,6 @@ def test_multiple_items(cache, vectorizer):
assert "metadata" not in check_result[0]


def test_get_index(cache):
    """The cache's ``index`` attribute is a ``SearchIndex`` instance."""
    exposed_index = cache.index
    assert isinstance(exposed_index, SearchIndex)


def test_store_and_check_with_provided_client(cache_with_redis_client, vectorizer):
prompt = "This is a test prompt."
response = "This is a test response."
Expand All @@ -375,9 +577,21 @@ def test_store_and_check_with_provided_client(cache_with_redis_client, vectorize
assert "metadata" not in check_result[0]


def test_delete(cache_no_cleanup):
    """Calling ``delete()`` leaves the cache's index reporting non-existence."""
    cache_no_cleanup.delete()
    exists_after_delete = cache_no_cleanup.index.exists()
    assert not exists_after_delete
@pytest.mark.asyncio
async def test_async_store_and_check_with_provided_client(
    cache_with_redis_client, vectorizer
):
    """``astore`` then ``acheck`` round-trips one entry on this fixture's cache.

    The cache under test comes from the ``cache_with_redis_client`` fixture.
    """
    prompt, response = "This is a test prompt.", "This is a test response."
    embedding = vectorizer.embed(prompt)

    await cache_with_redis_client.astore(prompt, response, vector=embedding)
    hits = await cache_with_redis_client.acheck(vector=embedding)

    # Exactly one hit, carrying the stored response and no metadata field.
    assert len(hits) == 1
    print(hits, flush=True)
    top_hit = hits[0]
    assert response == top_hit["response"]
    assert "metadata" not in top_hit


def test_vector_size(cache, vectorizer):
Expand Down
3 changes: 1 addition & 2 deletions tests/integration/test_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,8 +464,7 @@ def test_semantic_add_and_get_relevant(semantic_session):
default_context = semantic_session.get_relevant("list of fruits and vegetables")
assert len(default_context) == 5 # 2 pairs of prompt:response, and system
assert default_context == semantic_session.get_relevant(
"list of fruits and vegetables",
distance_threshold=0.5
"list of fruits and vegetables", distance_threshold=0.5
)

# test tool calls can also be returned
Expand Down

0 comments on commit 2e2004a

Please sign in to comment.