diff --git a/interlab/lang_models/__init__.py b/interlab/lang_models/__init__.py
index 3efbb8e..d137165 100644
--- a/interlab/lang_models/__init__.py
+++ b/interlab/lang_models/__init__.py
@@ -2,4 +2,5 @@
 from .base import LangModelBase  # noqa: F401
 from .openai import OpenAiChatModel  # noqa: F401
 from .query_model import query_model  # noqa: F401
+from .replay import Replay  # noqa: F401
 from .web_console import WebConsoleModel  # noqa: F401
diff --git a/interlab/lang_models/query_model.py b/interlab/lang_models/query_model.py
index 943b4a5..0817aa0 100644
--- a/interlab/lang_models/query_model.py
+++ b/interlab/lang_models/query_model.py
@@ -2,6 +2,7 @@
 
 from ..context import Context
 from .base import LangModelBase
+from .replay import Replay
 
 
 def _prepare_model(model: any, model_kwargs: dict = None, call_async: bool = False):
@@ -51,11 +52,22 @@ def _prepare_model(model: any, model_kwargs: dict = None, call_async: bool = Fal
 
 
 def query_model(
-    model: any, prompt: str | FormatStr, kwargs: dict = None, with_context=True
+    model: any,
+    prompt: str | FormatStr,
+    kwargs: dict = None,
+    with_context=True,
+    replay: Replay = None,
 ) -> str:
     if not isinstance(prompt, (str, FormatStr)):
         raise TypeError("query_model accepts only str and FormatStr as prompt")
     name, conf, call = _prepare_model(model, model_kwargs=kwargs, call_async=False)
+
+    if replay is not None:
+        # Serve a recorded response instead of calling the model, if one is left.
+        cached_result = replay.get_cached_response(conf, prompt)
+        if cached_result is not None:
+            name = "(Cached) " + name
+            call = lambda _prompt: cached_result  # noqa: E731
     if with_context:
         with Context(name, kind="query", inputs=dict(prompt=prompt, conf=conf)) as c:
             if isinstance(prompt, FormatStr):
diff --git a/interlab/lang_models/replay.py b/interlab/lang_models/replay.py
new file mode 100644
index 0000000..a0ba609
--- /dev/null
+++ b/interlab/lang_models/replay.py
@@ -0,0 +1,35 @@
+from typing import Optional
+
+from interlab.context import Context
+from interlab.context.context import ContextState
+
+
+class Replay:
+    """Replay cache built from the finished "query" contexts of a previous run.
+
+    Responses are replayed per (model config, prompt) key, in the order in
+    which they were originally produced; once a key's responses run out,
+    get_cached_response returns None and callers fall back to a fresh call.
+    """
+
+    def __init__(self, context: Context):
+        self.replays = {}
+        for query_ctx in context.find_contexts(lambda ctx: ctx.kind == "query"):
+            if query_ctx.state != ContextState.FINISHED:
+                continue
+            conf = query_ctx.inputs.get("conf")
+            prompt = query_ctx.inputs.get("prompt")
+            if conf is None or prompt is None:
+                continue
+            key = (frozenset(conf.items()), prompt)
+            self.replays.setdefault(key, []).append(query_ctx.result)
+        # Reverse so that pop() below yields responses in their original order.
+        for replay in self.replays.values():
+            replay.reverse()
+
+    def get_cached_response(self, conf: dict, prompt: str) -> Optional[str]:
+        key = (frozenset(conf.items()), prompt)
+        replay = self.replays.get(key)
+        if not replay:
+            return None
+        return replay.pop()
diff --git a/notebooks/lang_models.ipynb b/notebooks/lang_models.ipynb
new file mode 100644
index 0000000..27f1675
--- /dev/null
+++ b/notebooks/lang_models.ipynb
@@ -0,0 +1,107 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "742b8b3a-92a1-442b-8dc6-9ebd69dfeb34",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "The autoreload extension is already loaded. To reload it, use:\n",
+      "  %reload_ext autoreload\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "False"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2\n",
+    "\n",
+    "import dotenv\n",
+    "from interlab.context import Context\n",
+    "from interlab.lang_models import OpenAiChatModel, Replay, query_model\n",
+    "\n",
+    "dotenv.load_dotenv()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "46ec5d16-dcc6-4629-a582-5fd84d27387a",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "4fa0add9-52d5-4de4-81ac-7ca9a860a730",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Without caching\n",
+    "\n",
+    "model = OpenAiChatModel()\n",
+    "\n",
+    "with Context(\"root\") as my_context:\n",
+    "    query_model(model, \"How are you?\")\n",
+    "    query_model(model, \"How are you?\")\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "6849b6a4-d0d7-4d2e-a916-41571896fd50",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# With caching\n",
+    "\n",
+    "replay = Replay(my_context)\n",
+    "\n",
+    "with Context(\"root\") as my_context:\n",
+    "    # This is served from the cache\n",
+    "    query_model(model, \"How are you?\", replay=replay)\n",
+    "    # This is served from the cache\n",
+    "    query_model(model, \"How are you?\", replay=replay)\n",
+    "    # This is a fresh call: the two cached responses are used up\n",
+    "    query_model(model, \"How are you?\", replay=replay)\n",
+    "    "
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/tests/lang_models/test_replay.py b/tests/lang_models/test_replay.py
new file mode 100644
index 0000000..0bbbe66
--- /dev/null
+++ b/tests/lang_models/test_replay.py
@@ -0,0 +1,44 @@
+from unittest.mock import patch
+
+from interlab.context import Context
+from interlab.lang_models import OpenAiChatModel, Replay, query_model
+
+
+@patch("interlab.lang_models.openai._make_openai_chat_query")
+def test_replay_model(_make_openai_chat_query):
+    model = OpenAiChatModel(api_key="xxx", api_org="xxx")
+    with Context("root") as root:
+        _make_openai_chat_query.return_value = "Answer 1"
+        query_model(model, "How are you?")
+        assert _make_openai_chat_query.call_count == 1
+        _make_openai_chat_query.return_value = "Answer 2"
+        query_model(model, "How are you?")
+        assert _make_openai_chat_query.call_count == 2
+        _make_openai_chat_query.return_value = "Answer 3"
+        query_model(model, "What is your name?")
+        assert _make_openai_chat_query.call_count == 3
+        _make_openai_chat_query.return_value = "Answer 4"
+        query_model(model, "How are you?")
+        assert _make_openai_chat_query.call_count == 4
+
+    replay = Replay(root)
+
+    # Cached responses replay in their original order for each (conf, prompt).
+    _make_openai_chat_query.return_value = "Answer 5"
+    r = query_model(model, "How are you?", replay=replay)
+    assert r == "Answer 1"
+
+    # A different model config misses the cache and triggers a fresh call.
+    model2 = OpenAiChatModel(api_key="xxx", api_org="xxx", temperature=0.0123)
+    r = query_model(model2, "What is your name?", replay=replay)
+    assert r == "Answer 5"
+
+    r = query_model(model, "How are you?", replay=replay)
+    assert r == "Answer 2"
+    r = query_model(model, "How are you?", replay=replay)
+    assert r == "Answer 4"
+    # The "How are you?" cache is now exhausted, so the mock is called again.
+    r = query_model(model, "How are you?", replay=replay)
+    assert r == "Answer 5"
+    r = query_model(model, "What is your name?", replay=replay)
+    assert r == "Answer 3"
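
For reference, a minimal end-to-end sketch of the record-and-replay flow this change introduces, condensed from the notebook and test above. It uses only APIs that appear in this diff; the context names and prompts are illustrative, and OpenAiChatModel() assumes an OpenAI key is available via dotenv, as in the notebook.

    import dotenv

    from interlab.context import Context
    from interlab.lang_models import OpenAiChatModel, Replay, query_model

    dotenv.load_dotenv()  # expects OPENAI_API_KEY in the environment or .env
    model = OpenAiChatModel()

    # Record: each query_model call is captured as a "query" child context.
    with Context("recorded-run") as recorded:
        first = query_model(model, "How are you?")
        second = query_model(model, "How are you?")

    # Replay: responses are keyed by (model config, prompt) and served in
    # their original order; once a key is exhausted, query_model calls the
    # API again.
    replay = Replay(recorded)
    with Context("replayed-run"):
        assert query_model(model, "How are you?", replay=replay) == first
        assert query_model(model, "How are you?", replay=replay) == second
        query_model(model, "How are you?", replay=replay)  # fresh API call

Because the cache key includes the full model configuration, changing any model parameter (for example temperature, as the test does) deliberately bypasses the recorded responses.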