diff --git a/tests/integration/test_async_search_index.py b/tests/integration/test_async_search_index.py index e7ba1c3b..a5c937b6 100644 --- a/tests/integration/test_async_search_index.py +++ b/tests/integration/test_async_search_index.py @@ -152,9 +152,7 @@ async def test_search_index_client(async_client, index_schema): async def test_search_index_set_client(async_client, client, async_index): await async_index.set_client(async_client) assert async_index.client == async_client - # should not be able to set the sync client here - with pytest.raises(TypeError): - await async_index.set_client(client) + await async_index.set_client(client) async_index.disconnect() assert async_index.client == None diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index 2263b745..cbfa3e9c 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -6,7 +6,7 @@ from redis.exceptions import ConnectionError from redisvl.extensions.llmcache import SemanticCache -from redisvl.index.index import SearchIndex +from redisvl.index.index import AsyncSearchIndex, SearchIndex from redisvl.query.filter import Num, Tag, Text from redisvl.utils.vectorize import HFTextVectorizer @@ -89,6 +89,33 @@ def test_reset_ttl(cache): assert cache.ttl is None +def test_get_index(cache): + assert isinstance(cache.index, SearchIndex) + + +@pytest.mark.asyncio +async def test_get_async_index(cache): + aindex = await cache._get_async_index() + assert isinstance(aindex, AsyncSearchIndex) + + +@pytest.mark.asyncio +async def test_get_async_index_from_provided_client(cache_with_redis_client): + aindex = await cache_with_redis_client._get_async_index() + assert isinstance(aindex, AsyncSearchIndex) + + +def test_delete(cache_no_cleanup): + cache_no_cleanup.delete() + assert not cache_no_cleanup.index.exists() + + +@pytest.mark.asyncio +async def test_async_delete(cache_no_cleanup): + await cache_no_cleanup.adelete() + assert not cache_no_cleanup.index.exists() + + def test_store_and_check(cache, vectorizer): prompt = "This is a test prompt." response = "This is a test response." @@ -103,6 +130,21 @@ def test_store_and_check(cache, vectorizer): assert "metadata" not in check_result[0] +@pytest.mark.asyncio +async def test_async_store_and_check(cache, vectorizer): + prompt = "This is a test prompt." + response = "This is a test response." + vector = vectorizer.embed(prompt) + + await cache.astore(prompt, response, vector=vector) + check_result = await cache.acheck(vector=vector, distance_threshold=0.4) + + assert len(check_result) == 1 + print(check_result, flush=True) + assert response == check_result[0]["response"] + assert "metadata" not in check_result[0] + + def test_return_fields(cache, vectorizer): prompt = "This is a test prompt." response = "This is a test response." @@ -140,6 +182,44 @@ def test_return_fields(cache, vectorizer): assert set(check_result[0].keys()) == set(fields) +@pytest.mark.asyncio +async def test_async_return_fields(cache, vectorizer): + prompt = "This is a test prompt." + response = "This is a test response." + vector = vectorizer.embed(prompt) + + await cache.astore(prompt, response, vector=vector) + + # check default return fields + check_result = await cache.acheck(vector=vector) + assert set(check_result[0].keys()) == { + "key", + "entry_id", + "prompt", + "response", + "vector_distance", + "inserted_at", + "updated_at", + } + + # check specific return fields + fields = [ + "key", + "entry_id", + "prompt", + "response", + "vector_distance", + ] + check_result = await cache.acheck(vector=vector, return_fields=fields) + assert set(check_result[0].keys()) == set(fields) + + # check only some return fields + fields = ["inserted_at", "updated_at"] + check_result = await cache.acheck(vector=vector, return_fields=fields) + fields.append("key") + assert set(check_result[0].keys()) == set(fields) + + # Test clearing the cache def test_clear(cache, vectorizer): prompt = "This is a test prompt." @@ -153,6 +233,19 @@ def test_clear(cache, vectorizer): assert len(check_result) == 0 +@pytest.mark.asyncio +async def test_async_clear(cache, vectorizer): + prompt = "This is a test prompt." + response = "This is a test response." + vector = vectorizer.embed(prompt) + + await cache.astore(prompt, response, vector=vector) + await cache.aclear() + check_result = await cache.acheck(vector=vector) + + assert len(check_result) == 0 + + # Test TTL functionality def test_ttl_expiration(cache_with_ttl, vectorizer): prompt = "This is a test prompt." @@ -166,6 +259,19 @@ def test_ttl_expiration(cache_with_ttl, vectorizer): assert len(check_result) == 0 +@pytest.mark.asyncio +async def test_async_ttl_expiration(cache_with_ttl, vectorizer): + prompt = "This is a test prompt." + response = "This is a test response." + vector = vectorizer.embed(prompt) + + await cache_with_ttl.astore(prompt, response, vector=vector) + sleep(3) + + check_result = await cache_with_ttl.acheck(vector=vector) + assert len(check_result) == 0 + + def test_ttl_refresh(cache_with_ttl, vectorizer): prompt = "This is a test prompt." response = "This is a test response." @@ -180,6 +286,21 @@ def test_ttl_refresh(cache_with_ttl, vectorizer): assert len(check_result) == 1 +@pytest.mark.asyncio +async def test_async_ttl_refresh(cache_with_ttl, vectorizer): + prompt = "This is a test prompt." + response = "This is a test response." + vector = vectorizer.embed(prompt) + + await cache_with_ttl.astore(prompt, response, vector=vector) + + for _ in range(3): + sleep(1) + check_result = await cache_with_ttl.acheck(vector=vector) + + assert len(check_result) == 1 + + # Test manual expiration of single document def test_drop_document(cache, vectorizer): prompt = "This is a test prompt." @@ -194,6 +315,20 @@ def test_drop_document(cache, vectorizer): assert len(recheck_result) == 0 +@pytest.mark.asyncio +async def test_async_drop_document(cache, vectorizer): + prompt = "This is a test prompt." + response = "This is a test response." + vector = vectorizer.embed(prompt) + + await cache.astore(prompt, response, vector=vector) + check_result = await cache.acheck(vector=vector) + + await cache.adrop(ids=[check_result[0]["entry_id"]]) + recheck_result = await cache.acheck(vector=vector) + assert len(recheck_result) == 0 + + # Test manual expiration of multiple documents def test_drop_documents(cache, vectorizer): prompts = [ @@ -219,6 +354,31 @@ def test_drop_documents(cache, vectorizer): assert len(recheck_result) == 1 +@pytest.mark.asyncio +async def test_async_drop_documents(cache, vectorizer): + prompts = [ + "This is a test prompt.", + "This is also test prompt.", + "This is another test prompt.", + ] + responses = [ + "This is a test response.", + "This is also test response.", + "This is a another test response.", + ] + for prompt, response in zip(prompts, responses): + vector = vectorizer.embed(prompt) + await cache.astore(prompt, response, vector=vector) + + check_result = await cache.acheck(vector=vector, num_results=3) + print(check_result, flush=True) + ids = [r["entry_id"] for r in check_result[0:2]] # drop first 2 entries + await cache.adrop(ids=ids) + + recheck_result = await cache.acheck(vector=vector, num_results=3) + assert len(recheck_result) == 1 + + # Test updating document fields def test_updating_document(cache): prompt = "This is a test prompt." @@ -240,6 +400,27 @@ def test_updating_document(cache): assert updated_result[0]["updated_at"] > check_result[0]["updated_at"] +@pytest.mark.asyncio +async def test_async_updating_document(cache): + prompt = "This is a test prompt." + response = "This is a test response." + await cache.astore(prompt=prompt, response=response) + + check_result = await cache.acheck(prompt=prompt, return_fields=["updated_at"]) + key = check_result[0]["key"] + + sleep(1) + + metadata = {"foo": "bar"} + await cache.aupdate(key=key, metadata=metadata) + + updated_result = await cache.acheck( + prompt=prompt, return_fields=["updated_at", "metadata"] + ) + assert updated_result[0]["metadata"] == metadata + assert updated_result[0]["updated_at"] > check_result[0]["updated_at"] + + def test_ttl_expiration_after_update(cache_with_ttl, vectorizer): prompt = "This is a test prompt." response = "This is a test response." @@ -255,6 +436,22 @@ def test_ttl_expiration_after_update(cache_with_ttl, vectorizer): assert len(check_result) == 0 +@pytest.mark.asyncio +async def test_async_ttl_expiration_after_update(cache_with_ttl, vectorizer): + prompt = "This is a test prompt." + response = "This is a test response." + vector = vectorizer.embed(prompt) + cache_with_ttl.set_ttl(4) + + assert cache_with_ttl.ttl == 4 + + await cache_with_ttl.astore(prompt, response, vector=vector) + sleep(5) + + check_result = await cache_with_ttl.acheck(vector=vector) + assert len(check_result) == 0 + + # Test check behavior with no match def test_check_no_match(cache, vectorizer): vector = vectorizer.embed("Some random sentence.") @@ -270,6 +467,15 @@ def test_check_invalid_input(cache): cache.check(prompt="test", return_fields="bad value") +@pytest.mark.asyncio +async def test_async_check_invalid_input(cache): + with pytest.raises(ValueError): + await cache.acheck() + + with pytest.raises(TypeError): + await cache.acheck(prompt="test", return_fields="bad value") + + def test_bad_connection_info(vectorizer): with pytest.raises(ConnectionError): SemanticCache( @@ -357,10 +563,6 @@ def test_multiple_items(cache, vectorizer): assert "metadata" not in check_result[0] -def test_get_index(cache): - assert isinstance(cache.index, SearchIndex) - - def test_store_and_check_with_provided_client(cache_with_redis_client, vectorizer): prompt = "This is a test prompt." response = "This is a test response." @@ -375,9 +577,21 @@ def test_store_and_check_with_provided_client(cache_with_redis_client, vectorize assert "metadata" not in check_result[0] -def test_delete(cache_no_cleanup): - cache_no_cleanup.delete() - assert not cache_no_cleanup.index.exists() +@pytest.mark.asyncio +async def test_async_store_and_check_with_provided_client( + cache_with_redis_client, vectorizer +): + prompt = "This is a test prompt." + response = "This is a test response." + vector = vectorizer.embed(prompt) + + await cache_with_redis_client.astore(prompt, response, vector=vector) + check_result = await cache_with_redis_client.acheck(vector=vector) + + assert len(check_result) == 1 + print(check_result, flush=True) + assert response == check_result[0]["response"] + assert "metadata" not in check_result[0] def test_vector_size(cache, vectorizer): diff --git a/tests/integration/test_session_manager.py b/tests/integration/test_session_manager.py index 56943447..20c2955d 100644 --- a/tests/integration/test_session_manager.py +++ b/tests/integration/test_session_manager.py @@ -464,8 +464,7 @@ def test_semantic_add_and_get_relevant(semantic_session): default_context = semantic_session.get_relevant("list of fruits and vegetables") assert len(default_context) == 5 # 2 pairs of prompt:response, and system assert default_context == semantic_session.get_relevant( - "list of fruits and vegetables", - distance_threshold=0.5 + "list of fruits and vegetables", distance_threshold=0.5 ) # test tool calls can also be returned