Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query replay #45

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions interlab/lang_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
from .base import LangModelBase # noqa: F401
from .openai import OpenAiChatModel # noqa: F401
from .query_model import query_model # noqa: F401
from .replay import Replay # noqa: F401
from .web_console import WebConsoleModel # noqa: F401
13 changes: 12 additions & 1 deletion interlab/lang_models/query_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from ..context import Context
from .base import LangModelBase
from .replay import Replay


def _prepare_model(model: any, model_kwargs: dict = None, call_async: bool = False):
Expand Down Expand Up @@ -51,11 +52,21 @@ def _prepare_model(model: any, model_kwargs: dict = None, call_async: bool = Fal


def query_model(
model: any, prompt: str | FormatStr, kwargs: dict = None, with_context=True
model: any,
prompt: str | FormatStr,
kwargs: dict = None,
with_context=True,
replay: Replay = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typing should be Optional[Replay] or Replay | None

) -> str:
if not isinstance(prompt, (str, FormatStr)):
raise TypeError("query_model accepts only str and FormatStr as prompt")
name, conf, call = _prepare_model(model, model_kwargs=kwargs, call_async=False)

if replay is not None:
cached_result = replay.get_cached_response(conf, prompt)
if cached_result:
name = "(Cached) " + name
call = lambda _prompt: cached_result # noqa: E731
if with_context:
with Context(name, kind="query", inputs=dict(prompt=prompt, conf=conf)) as c:
if isinstance(prompt, FormatStr):
Expand Down
28 changes: 28 additions & 0 deletions interlab/lang_models/replay.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Optional

from interlab.context import Context
from interlab.context.context import ContextState


class Replay:
    """Cache of model responses recovered from a finished Context tree.

    Responses are keyed by (model configuration, prompt) and replayed in
    the order in which the original queries were recorded, so a re-run of
    the same sequence of queries sees the same answers.
    """

    def __init__(self, context: Context):
        """Collect results of all finished "query" contexts under *context*.

        Args:
            context: Root context whose subtree is scanned for queries.
        """
        # Map (frozenset(conf.items()), prompt) -> list of recorded results.
        self.replays = {}
        # NOTE: do not reuse the name `context` for the loop variable — it
        # would shadow the constructor parameter.
        for query_ctx in context.find_contexts(lambda ctx: ctx.kind == "query"):
            # Only finished queries carry a trustworthy result.
            if query_ctx.state != ContextState.FINISHED:
                continue
            conf = query_ctx.inputs.get("conf")
            prompt = query_ctx.inputs.get("prompt")
            if conf is None or prompt is None:
                continue
            key = (frozenset(conf.items()), prompt)
            self.replays.setdefault(key, []).append(query_ctx.result)
        # Results were appended in recording order; reverse each list so
        # that .pop() (used by get_cached_response) yields oldest-first.
        # NOTE(review): this assumes find_contexts enumerates contexts in
        # recording order — confirm and document on StorageBase.list.
        for responses in self.replays.values():
            responses.reverse()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the logic here assumes that future implementations of StorageBase.list adhere to a certain order of listing contexts. This should be specified in the doc string of StorageBase.list and tested in a unit test (if not there already)


def get_cached_response(self, conf: dict, prompt: str) -> Optional[str]:
key = (frozenset(conf.items()), prompt)
replay = self.replays.get(key)
if not replay:
return None
return replay.pop()
107 changes: 107 additions & 0 deletions notebooks/lang_models.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 8,
"id": "742b8b3a-92a1-442b-8dc6-9ebd69dfeb34",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
},
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import dotenv\n",
"from interlab.context import Context\n",
"from interlab.lang_models import OpenAiChatModel, Replay, query_model\n",
"\n",
"dotenv.load_dotenv()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "46ec5d16-dcc6-4629-a582-5fd84d27387a",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 10,
"id": "4fa0add9-52d5-4de4-81ac-7ca9a860a730",
"metadata": {},
"outputs": [],
"source": [
"# Without caching\n",
"\n",
"model = OpenAiChatModel()\n",
"\n",
"with Context(\"root\") as my_context:\n",
" query_model(model, \"How are you?\")\n",
" query_model(model, \"How are you?\")\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "6849b6a4-d0d7-4d2e-a916-41571896fd50",
"metadata": {},
"outputs": [],
"source": [
"# With caching\n",
"\n",
"replay = Replay(my_context)\n",
"\n",
"with Context(\"root\") as my_context:\n",
" # This goes from cache\n",
" query_model(model, \"How are you?\", replay=replay)\n",
" # This goes from cache\n",
" query_model(model, \"How are you?\", replay=replay)\n",
" # This is fresh call as it is not in cache\n",
" query_model(model, \"How are you?\", replay=replay)\n",
" "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
41 changes: 41 additions & 0 deletions tests/lang_models/test_replay.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from unittest.mock import patch

from interlab.context import Context
from interlab.lang_models import OpenAiChatModel, Replay, query_model


@patch("interlab.lang_models.openai._make_openai_chat_query")
def test_replay_model(_make_openai_chat_query):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice test!

model = OpenAiChatModel(api_key="xxx", api_org="xxx")
with Context("root") as root:
_make_openai_chat_query.return_value = "Answer 1"
query_model(model, "How are you?")
assert _make_openai_chat_query.call_count == 1
_make_openai_chat_query.return_value = "Answer 2"
query_model(model, "How are you?")
assert _make_openai_chat_query.call_count == 2
_make_openai_chat_query.return_value = "Answer 3"
query_model(model, "What is your name?")
assert _make_openai_chat_query.call_count == 3
_make_openai_chat_query.return_value = "Answer 4"
query_model(model, "How are you?")
assert _make_openai_chat_query.call_count == 4

replay = Replay(root)

_make_openai_chat_query.return_value = "Answer 5"
r = query_model(model, "How are you?", replay=replay)
assert r == "Answer 1"

model2 = OpenAiChatModel(api_key="xxx", api_org="xxx", temperature=0.0123)
r = query_model(model2, "What is your name?", replay=replay)
assert r == "Answer 5"

r = query_model(model, "How are you?", replay=replay)
assert r == "Answer 2"
r = query_model(model, "How are you?", replay=replay)
assert r == "Answer 4"
r = query_model(model, "How are you?", replay=replay)
assert r == "Answer 5"
r = query_model(model, "What is your name?", replay=replay)
assert r == "Answer 3"
Loading