-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Query replay #45
base: main
Are you sure you want to change the base?
Query replay #45
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from typing import Optional | ||
|
||
from interlab.context import Context | ||
from interlab.context.context import ContextState | ||
|
||
|
||
class Replay:
    """Cache of finished "query" results recorded in a previous Context tree.

    Scans the given context tree for all FINISHED contexts of kind "query"
    and stores their results keyed by (configuration, prompt).  Repeated
    identical queries are replayed in the order in which they originally
    ran (FIFO), so a re-run consumes the recording one answer at a time.
    """

    def __init__(self, context: Context):
        # Maps (frozenset(conf.items()), prompt) -> list of results.
        # Lists are reversed after collection so that .pop() (O(1), from
        # the end) yields results in their original execution order.
        self.replays = {}
        # NOTE: renamed loop variable (was `context`, shadowing the
        # constructor parameter) to keep the tree root distinct.
        for ctx in context.find_contexts(lambda c: c.kind == "query"):
            # Only finished queries carry a usable result.
            if ctx.state != ContextState.FINISHED:
                continue
            conf = ctx.inputs.get("conf")
            prompt = ctx.inputs.get("prompt")
            if conf is None or prompt is None:
                continue
            # Assumes conf values are hashable (frozenset of items) --
            # TODO confirm against what query_model records as "conf".
            key = (frozenset(conf.items()), prompt)
            self.replays.setdefault(key, []).append(ctx.result)
        for results in self.replays.values():
            results.reverse()

    def get_cached_response(self, conf: dict, prompt: str) -> Optional[str]:
        """Return the next recorded result for (conf, prompt), or None.

        Each call consumes one cached result for the key; once the
        recording is exhausted (or the key was never recorded), None is
        returned and the caller should issue a fresh query.
        """
        key = (frozenset(conf.items()), prompt)
        results = self.replays.get(key)
        if not results:
            return None
        return results.pop()
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"id": "742b8b3a-92a1-442b-8dc6-9ebd69dfeb34", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"The autoreload extension is already loaded. To reload it, use:\n", | ||
" %reload_ext autoreload\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"False" | ||
] | ||
}, | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"%load_ext autoreload\n", | ||
"%autoreload 2\n", | ||
"\n", | ||
"import dotenv\n", | ||
"from interlab.context import Context\n", | ||
"from interlab.lang_models import OpenAiChatModel, Replay, query_model\n", | ||
"\n", | ||
"dotenv.load_dotenv()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "46ec5d16-dcc6-4629-a582-5fd84d27387a", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"id": "4fa0add9-52d5-4de4-81ac-7ca9a860a730", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Without caching\n", | ||
"\n", | ||
"model = OpenAiChatModel()\n", | ||
"\n", | ||
"with Context(\"root\") as my_context:\n", | ||
" query_model(model, \"How are you?\")\n", | ||
" query_model(model, \"How are you?\")\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 12, | ||
"id": "6849b6a4-d0d7-4d2e-a916-41571896fd50", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# With caching\n", | ||
"\n", | ||
"replay = Replay(my_context)\n", | ||
"\n", | ||
"with Context(\"root\") as my_context:\n", | ||
" # This goes from cache\n", | ||
" query_model(model, \"How are you?\", replay=replay)\n", | ||
" # This goes from cache\n", | ||
" query_model(model, \"How are you?\", replay=replay)\n", | ||
" # This is fresh call as it is not in cache\n", | ||
" query_model(model, \"How are you?\", replay=replay)\n", | ||
" " | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from unittest.mock import patch | ||
|
||
from interlab.context import Context | ||
from interlab.lang_models import OpenAiChatModel, Replay, query_model | ||
|
||
|
||
@patch("interlab.lang_models.openai._make_openai_chat_query")
def test_replay_model(query_mock):
    """Replay serves recorded answers FIFO per (model config, prompt) and
    falls through to a fresh call once a key's recording is exhausted."""
    model = OpenAiChatModel(api_key="xxx", api_org="xxx")

    # Record four live (mocked) queries under a root context.
    recorded = [
        ("Answer 1", "How are you?"),
        ("Answer 2", "How are you?"),
        ("Answer 3", "What is your name?"),
        ("Answer 4", "How are you?"),
    ]
    with Context("root") as root:
        for expected_calls, (answer, prompt) in enumerate(recorded, start=1):
            query_mock.return_value = answer
            query_model(model, prompt)
            assert query_mock.call_count == expected_calls

    replay = Replay(root)

    # Cache hit: first recorded answer for this (conf, prompt) pair,
    # even though the mock would now return something else.
    query_mock.return_value = "Answer 5"
    r = query_model(model, "How are you?", replay=replay)
    assert r == "Answer 1"

    # Different configuration -> cache miss -> fresh (mocked) call.
    model2 = OpenAiChatModel(api_key="xxx", api_org="xxx", temperature=0.0123)
    r = query_model(model2, "What is your name?", replay=replay)
    assert r == "Answer 5"

    # Remaining cached answers replay in original order; once the key's
    # recording runs out, the fresh mocked answer comes back again.
    r = query_model(model, "How are you?", replay=replay)
    assert r == "Answer 2"
    r = query_model(model, "How are you?", replay=replay)
    assert r == "Answer 4"
    r = query_model(model, "How are you?", replay=replay)
    assert r == "Answer 5"
    r = query_model(model, "What is your name?", replay=replay)
    assert r == "Answer 3"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typing should be `Optional[Replay]` or `Replay | None`