From de340331ff34ad8b78cdb0089d2f54d7f16f752f Mon Sep 17 00:00:00 2001
From: Faisal Alquaddoomi
Date: Wed, 25 Sep 2024 11:54:18 -0600
Subject: [PATCH 01/11] Replaced openai dependency with langchain_openai

---
 setup.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/setup.py b/setup.py
index b933cf9..3e89151 100644
--- a/setup.py
+++ b/setup.py
@@ -11,7 +11,7 @@
 setuptools.setup(
     name="manubot-ai-editor",
-    version="0.5.2",
+    version="0.5.3",
     author="Milton Pividori",
     author_email="miltondp@gmail.com",
     description="A Manubot plugin to revise a manuscript using GPT-3",
@@ -25,7 +25,8 @@
     ],
     python_requires=">=3.10",
     install_requires=[
-        "openai==0.28",
+        # "openai==0.28",
+        "langchain-openai==0.2.0",
         "pyyaml",
     ],
     classifiers=[

From ae19171eee44e7a64e6f72954af62efa37bd93af Mon Sep 17 00:00:00 2001
From: Faisal Alquaddoomi
Date: Wed, 25 Sep 2024 12:24:58 -0600
Subject: [PATCH 02/11] For posterity, updates to latest release openai, updates code accordingly

---
 libs/manubot_ai_editor/models.py | 22 +++++++++++++---------
 setup.py                         |  3 ++-
 tests/test_model_basics.py       | 16 ++++++----------
 3 files changed, 21 insertions(+), 20 deletions(-)

diff --git a/libs/manubot_ai_editor/models.py b/libs/manubot_ai_editor/models.py
index c25dd68..a608542 100644
--- a/libs/manubot_ai_editor/models.py
+++ b/libs/manubot_ai_editor/models.py
@@ -5,7 +5,7 @@
 import time
 import json
 
-import openai
+from openai import OpenAI
 
 from manubot_ai_editor import env_vars
 
@@ -141,18 +141,22 @@ def __init__(
         super().__init__()
 
         # make sure the OpenAI API key is set
-        openai.api_key = openai_api_key
+        if openai_api_key is None:
+            # attempt to get the OpenAI API key from the environment, since one
+            # wasn't specified as an argument
+            openai_api_key = os.environ.get(env_vars.OPENAI_API_KEY, None)
 
-        if openai.api_key is None:
-            openai.api_key = os.environ.get(env_vars.OPENAI_API_KEY, None)
-
-        if openai.api_key is None or openai.api_key.strip() == "":
+        # if it's *still* not set, bail
+        if openai_api_key is None or openai_api_key.strip() == "":
             raise ValueError(
                 f"OpenAI API key not found. Please provide it as parameter "
                 f"or set it as an the environment variable "
                 f"{env_vars.OPENAI_API_KEY}"
             )
 
+        # construct the OpenAI client
+        self.client = OpenAI(api_key=openai_api_key)
+
         if env_vars.LANGUAGE_MODEL in os.environ:
             val = os.environ[env_vars.LANGUAGE_MODEL]
             if val.strip() != "":
@@ -527,11 +531,11 @@ def revise_paragraph(self, paragraph_text: str, section_name: str = None, resolv
         )
 
         if self.endpoint == "edits":
-            completions = openai.Edit.create(**params)
+            completions = self.client.edits.create(**params)
         elif self.endpoint == "chat":
-            completions = openai.ChatCompletion.create(**params)
+            completions = self.client.chat.completions.create(**params)
         else:
-            completions = openai.Completion.create(**params)
+            completions = self.client.completions.create(**params)
 
         if self.endpoint == "chat":
             message = completions.choices[0].message.content.strip()
diff --git a/setup.py b/setup.py
index 3e89151..ef67c0b 100644
--- a/setup.py
+++ b/setup.py
@@ -26,7 +26,8 @@
     python_requires=">=3.10",
     install_requires=[
         # "openai==0.28",
-        "langchain-openai==0.2.0",
+        "openai==1.48.0",
+        # "langchain-openai==0.2.0",
         "pyyaml",
     ],
     classifiers=[
diff --git a/tests/test_model_basics.py b/tests/test_model_basics.py
index 8242ca9..02545b5 100644
--- a/tests/test_model_basics.py
+++ b/tests/test_model_basics.py
@@ -32,12 +32,12 @@ def test_model_object_init_without_openai_api_key():
 
 @mock.patch.dict("os.environ", {env_vars.OPENAI_API_KEY: "env_var_test_value"})
 def test_model_object_init_with_openai_api_key_as_environment_variable():
-    GPT3CompletionModel(
+    model = GPT3CompletionModel(
         title="Test title",
         keywords=["test", "keywords"],
     )
 
-    assert models.openai.api_key == "env_var_test_value"
+    assert model.client.api_key == "env_var_test_value"
 
 
 def test_model_object_init_with_openai_api_key_as_parameter():
@@ -46,30 +46,26 @@ def test_model_object_init_with_openai_api_key_as_parameter():
         if env_vars.OPENAI_API_KEY in os.environ:
             os.environ.pop(env_vars.OPENAI_API_KEY)
 
-        GPT3CompletionModel(
+        model = GPT3CompletionModel(
             title="Test title",
             keywords=["test", "keywords"],
             openai_api_key="test_value",
         )
 
-        from manubot_ai_editor import models
-
-        assert models.openai.api_key == "test_value"
+        assert model.client.api_key == "test_value"
     finally:
         os.environ = _environ
 
 
 @mock.patch.dict("os.environ", {env_vars.OPENAI_API_KEY: "env_var_test_value"})
 def test_model_object_init_with_openai_api_key_as_parameter_has_higher_priority():
-    GPT3CompletionModel(
+    model = GPT3CompletionModel(
         title="Test title",
         keywords=["test", "keywords"],
         openai_api_key="test_value",
     )
 
-    from manubot_ai_editor import models
-
-    assert models.openai.api_key == "test_value"
+    assert model.client.api_key == "test_value"
 
 
 def test_model_object_init_default_language_model():

From 33d0adcddfb101d5bdb60351cf2ad9bfefc069cd Mon Sep 17 00:00:00 2001
From: Faisal Alquaddoomi
Date: Wed, 25 Sep 2024 13:43:36 -0600
Subject: [PATCH 03/11] Switches from openai to langchain_openai, updates tests accordingly.
---
 libs/manubot_ai_editor/models.py | 71 ++++++++++++++++++++++++++------
 setup.py                         |  4 +-
 tests/test_model_basics.py       |  7 ++--
 3 files changed, 62 insertions(+), 20 deletions(-)

diff --git a/libs/manubot_ai_editor/models.py b/libs/manubot_ai_editor/models.py
index a608542..54714dd 100644
--- a/libs/manubot_ai_editor/models.py
+++ b/libs/manubot_ai_editor/models.py
@@ -5,7 +5,8 @@
 import time
 import json
 
-from openai import OpenAI
+from langchain_openai import OpenAI, ChatOpenAI
+from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
 
 from manubot_ai_editor import env_vars
 
@@ -154,9 +155,6 @@ def __init__(
                 f"{env_vars.OPENAI_API_KEY}"
             )
 
-        # construct the OpenAI client
-        self.client = OpenAI(api_key=openai_api_key)
-
         if env_vars.LANGUAGE_MODEL in os.environ:
             val = os.environ[env_vars.LANGUAGE_MODEL]
             if val.strip() != "":
@@ -257,6 +255,22 @@ def __init__(
 
         self.several_spaces_pattern = re.compile(r"\s+")
 
+        if self.endpoint == "edits":
+            # FIXME: what's the "edits" equivalent in langchain?
+            client_cls = OpenAI
+        elif self.endpoint == "chat":
+            client_cls = ChatOpenAI
+        else:
+            client_cls = OpenAI
+
+        # construct the OpenAI client after all the rest of
+        # the settings above have been processed
+        self.client = client_cls(
+            api_key=openai_api_key,
+            **self.model_parameters,
+        )
+
+
     def get_prompt(
         self, paragraph_text: str, section_name: str = None, resolved_prompt: str = None
     ) -> str | tuple[str, str]:
@@ -530,18 +544,49 @@ def revise_paragraph(self, paragraph_text: str, section_name: str = None, resolv
             flush=True,
         )
 
-        if self.endpoint == "edits":
-            completions = self.client.edits.create(**params)
-        elif self.endpoint == "chat":
-            completions = self.client.chat.completions.create(**params)
+        # FIXME: 'params' contains a lot of fields that we're not
+        # currently passing to the langchain client. i need to figure
+        # out where they're supposed to be given, e.g. in the client
+        # init or with each request.
+
+        # map the prompt to langchain's prompt types, based on what
+        # kind of endpoint we're using
+        if "messages" in params:
+            # map the messages to langchain's message types
+            # based on the 'role' field
+            prompts = [
+                HumanMessage(content=msg["content"])
+                if msg["role"] == "user" else
+                SystemMessage(content=msg["content"])
+                for msg in params["messages"]
+            ]
+        elif "instruction" in params:
+            # since we don't know how to use the edits endpoint, we'll just
+            # concatenate the instruction and input and use the regular
+            # completion endpoint
+            # FIXME: there's probably a langchain equivalent for
+            # "edits", so we should change this to use that
+            prompts = [
+                HumanMessage(content=params["instruction"]),
+                HumanMessage(content=params["input"]),
+            ]
+        elif "prompt" in params:
+            prompts = [HumanMessage(content=params["prompt"])]
+
+        response = self.client.invoke(prompts)
+
+        if isinstance(response, BaseMessage):
+            message = response.content.strip()
         else:
-            completions = self.client.completions.create(**params)
+            message = response.strip()
+
+        # FIXME: the prior code retrieved the first of the 'choices'
+        # response from the openai client. now, we only get one
+        # response from the langchain client, but i should check
+        # if that's really how langchain works or if there is a way
+        # to get multiple 'choices' back from the backend.
 
-        if self.endpoint == "chat":
-            message = completions.choices[0].message.content.strip()
-        else:
-            message = completions.choices[0].text.strip()
         except Exception as e:
             error_message = str(e)
             print(f"Error: {error_message}")
diff --git a/setup.py b/setup.py
index ef67c0b..b17d981 100644
--- a/setup.py
+++ b/setup.py
@@ -25,9 +25,7 @@
     ],
     python_requires=">=3.10",
     install_requires=[
-        # "openai==0.28",
-        "openai==1.48.0",
-        # "langchain-openai==0.2.0",
+        "langchain-openai==0.2.0",
         "pyyaml",
     ],
     classifiers=[
diff --git a/tests/test_model_basics.py b/tests/test_model_basics.py
index 02545b5..54ec074 100644
--- a/tests/test_model_basics.py
+++ b/tests/test_model_basics.py
@@ -9,7 +9,6 @@
 import pytest
 
 from manubot_ai_editor.editor import ManuscriptEditor, env_vars
-from manubot_ai_editor import models
 from manubot_ai_editor.models import GPT3CompletionModel, RandomManuscriptRevisionModel
 
 MANUSCRIPTS_DIR = Path(__file__).parent / "manuscripts"
@@ -37,7 +36,7 @@ def test_model_object_init_with_openai_api_key_as_environment_variable():
         keywords=["test", "keywords"],
     )
 
-    assert model.client.api_key == "env_var_test_value"
+    assert model.client.openai_api_key.get_secret_value() == "env_var_test_value"
 
 
 def test_model_object_init_with_openai_api_key_as_parameter():
@@ -52,7 +51,7 @@ def test_model_object_init_with_openai_api_key_as_parameter():
             openai_api_key="test_value",
         )
 
-        assert model.client.api_key == "test_value"
+        assert model.client.openai_api_key.get_secret_value() == "test_value"
     finally:
         os.environ = _environ
 
@@ -65,7 +64,7 @@ def test_model_object_init_with_openai_api_key_as_parameter_has_higher_priority(
         openai_api_key="test_value",
     )
 
-    assert model.client.api_key == "test_value"
+    assert model.client.openai_api_key.get_secret_value() == "test_value"
 
 
 def test_model_object_init_default_language_model():

From 8e75813b5a1f34db02e609665d9d3b66c9315fba Mon Sep 17 00:00:00 2001
From: Faisal Alquaddoomi
Date: Thu, 26 Sep 2024 08:39:29 -0600
Subject: [PATCH 04/11] Passed the remaining important(?) params, max_tokens and stop, to invoke()

---
 libs/manubot_ai_editor/models.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/libs/manubot_ai_editor/models.py b/libs/manubot_ai_editor/models.py
index 54714dd..7191e9f 100644
--- a/libs/manubot_ai_editor/models.py
+++ b/libs/manubot_ai_editor/models.py
@@ -544,17 +544,12 @@ def revise_paragraph(self, paragraph_text: str, section_name: str = None, resolv
             flush=True,
         )
 
-        # FIXME: 'params' contains a lot of fields that we're not
-        # currently passing to the langchain client. i need to figure
-        # out where they're supposed to be given, e.g. in the client
-        # init or with each request.
-
         # map the prompt to langchain's prompt types, based on what
         # kind of endpoint we're using
         if "messages" in params:
             # map the messages to langchain's message types
             # based on the 'role' field
-            prompts = [
+            prompt = [
                 HumanMessage(content=msg["content"])
                 if msg["role"] == "user" else
                 SystemMessage(content=msg["content"])
@@ -566,14 +561,18 @@ def revise_paragraph(self, paragraph_text: str, section_name: str = None, resolv
             # completion endpoint
             # FIXME: there's probably a langchain equivalent for
             # "edits", so we should change this to use that
-            prompts = [
+            prompt = [
                 HumanMessage(content=params["instruction"]),
                 HumanMessage(content=params["input"]),
             ]
         elif "prompt" in params:
-            prompts = [HumanMessage(content=params["prompt"])]
+            prompt = [HumanMessage(content=params["prompt"])]
 
-        response = self.client.invoke(prompts)
+        response = self.client.invoke(
+            input=prompt,
+            max_tokens=params.get("max_tokens"),
+            stop=params.get("stop"),
+        )
 
         if isinstance(response, BaseMessage):
             message = response.content.strip()

From e4da6f071d025e6d9397e455e826c9724c7c1f7a Mon Sep 17 00:00:00 2001
From: Faisal Alquaddoomi
Date: Thu, 26 Sep 2024 08:59:43 -0600
Subject: [PATCH 05/11] Changed langchain-openai dep to 0.2.x, so we get patch releases

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index b17d981..7f7e140 100644
--- a/setup.py
+++ b/setup.py
@@ -25,7 +25,7 @@
     ],
     python_requires=">=3.10",
     install_requires=[
-        "langchain-openai==0.2.0",
+        "langchain-openai~=0.2.0",
         "pyyaml",
     ],
     classifiers=[

From cdca3b862255d4a072a3cc3d361318bef678a381 Mon Sep 17 00:00:00 2001
From: Faisal Alquaddoomi
Date: Thu, 26 Sep 2024 09:04:09 -0600
Subject: [PATCH 06/11] Added direct dependency on langchain-core~=0.3.6, since we use it directly in the code

---
 setup.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/setup.py b/setup.py
index 7f7e140..9b90a31 100644
--- a/setup.py
+++ b/setup.py
@@ -25,6 +25,7 @@
     ],
     python_requires=">=3.10",
     install_requires=[
+        "langchain-core~=0.3.6",
         "langchain-openai~=0.2.0",
         "pyyaml",
     ],
     classifiers=[

From 9a2d11b7977b12ff50c6e51f0c32d0800229cd77 Mon Sep 17 00:00:00 2001
From: Faisal Alquaddoomi
Date: Wed, 23 Oct 2024 13:19:47 -0600
Subject: [PATCH 07/11] Ran black on models.py as suggested

---
 libs/manubot_ai_editor/models.py | 21 ++++++++++++---------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/libs/manubot_ai_editor/models.py b/libs/manubot_ai_editor/models.py
index 7191e9f..4096969 100644
--- a/libs/manubot_ai_editor/models.py
+++ b/libs/manubot_ai_editor/models.py
@@ -270,7 +270,6 @@ def __init__(
             **self.model_parameters,
         )
 
-
     def get_prompt(
         self, paragraph_text: str, section_name: str = None, resolved_prompt: str = None
     ) -> str | tuple[str, str]:
@@ -520,7 +519,9 @@ def get_params(self, paragraph_text, section_name, resolved_prompt=None):
 
         return params
 
-    def revise_paragraph(self, paragraph_text: str, section_name: str = None, resolved_prompt=None):
+    def revise_paragraph(
+        self, paragraph_text: str, section_name: str = None, resolved_prompt=None
+    ):
         """
         It revises a paragraph using GPT-3 completion model.
@@ -550,9 +551,11 @@ def revise_paragraph(
             # map the messages to langchain's message types
             # based on the 'role' field
             prompt = [
-                HumanMessage(content=msg["content"])
-                if msg["role"] == "user" else
-                SystemMessage(content=msg["content"])
+                (
+                    HumanMessage(content=msg["content"])
+                    if msg["role"] == "user"
+                    else SystemMessage(content=msg["content"])
+                )
                 for msg in params["messages"]
             ]
         elif "instruction" in params:
@@ -631,10 +634,10 @@ class DebuggingManuscriptRevisionModel(GPT3CompletionModel):
     """
 
     def __init__(self, *args, **kwargs):
-        if 'title' not in kwargs or kwargs['title'] is None:
-            kwargs['title'] = "Debugging Title"
-        if 'keywords' not in kwargs or kwargs['keywords'] is None:
-            kwargs['keywords'] = ["debugging", "keywords"]
+        if "title" not in kwargs or kwargs["title"] is None:
+            kwargs["title"] = "Debugging Title"
+        if "keywords" not in kwargs or kwargs["keywords"] is None:
+            kwargs["keywords"] = ["debugging", "keywords"]
 
         super().__init__(*args, **kwargs)

From bfacb162da1969c0865e85f085541d646a7dcb38 Mon Sep 17 00:00:00 2001
From: Faisal Alquaddoomi
Date: Wed, 20 Nov 2024 14:38:14 -0700
Subject: [PATCH 08/11] Removes defunct 'edits' endpoint code, tests, and mentions in docs.

---
 libs/manubot_ai_editor/env_vars.py |  2 +-
 libs/manubot_ai_editor/models.py   | 47 +++++-------------------------
 tests/test_model_get_prompt.py     | 30 -------------------
 3 files changed, 8 insertions(+), 71 deletions(-)

diff --git a/libs/manubot_ai_editor/env_vars.py b/libs/manubot_ai_editor/env_vars.py
index 0ed9a06..a888bc6 100644
--- a/libs/manubot_ai_editor/env_vars.py
+++ b/libs/manubot_ai_editor/env_vars.py
@@ -16,7 +16,7 @@
 OPENAI_API_KEY = "OPENAI_API_KEY"
 
 # Language model to use. For example, "text-davinci-003", "gpt-3.5-turbo", "gpt-3.5-turbo-0301", etc
-# The tool currently supports the "chat/completions", "completions", and "edits" endpoints, and you can check
+# The tool currently supports the "chat/completions" and "completions" endpoints, and you can check
 # compatible models here: https://platform.openai.com/docs/models/model-endpoint-compatibility
 LANGUAGE_MODEL = "AI_EDITOR_LANGUAGE_MODEL"
 
diff --git a/libs/manubot_ai_editor/models.py b/libs/manubot_ai_editor/models.py
index 4096969..4a3c138 100644
--- a/libs/manubot_ai_editor/models.py
+++ b/libs/manubot_ai_editor/models.py
@@ -223,7 +223,7 @@ def __init__(
         self.title = title
         self.keywords = keywords if keywords is not None else []
 
-        # adjust options if edits or chat endpoint was selected
+        # adjust options if chat endpoint was selected
         self.endpoint = "chat"
 
         if model_engine.startswith(
@@ -231,9 +231,6 @@ def __init__(
         ):
             self.endpoint = "completions"
 
-        if "-edit-" in model_engine:
-            self.endpoint = "edits"
-
         print(f"Language model: {model_engine}")
         print(f"Model endpoint used: {self.endpoint}")
 
@@ -255,10 +252,7 @@ def __init__(
 
         self.several_spaces_pattern = re.compile(r"\s+")
 
-        if self.endpoint == "edits":
-            # FIXME: what's the "edits" equivalent in langchain?
-            client_cls = OpenAI
-        elif self.endpoint == "chat":
+        if self.endpoint == "chat":
             client_cls = ChatOpenAI
         else:
             client_cls = OpenAI
@@ -285,13 +279,9 @@ def get_prompt(
             resolved_prompt: prompt resolved via ai-revision config, if available
 
         Returns:
-            If self.endpoint != "edits", then returns a string with the prompt to be used by the model for the revision of the paragraph.
+            A string with the prompt to be used by the model for the revision of the paragraph.
             It contains two paragraphs of text: the command for the model ("Revise...") and the paragraph to revise.
-
-            If self.endpoint == "edits", then returns a tuple with two strings:
-            1) the instructions to be used by the model for the revision of the paragraph,
-            2) the paragraph to revise.
         """
 
         # prompts are resolved in the following order, with the first satisfied
@@ -327,8 +317,6 @@ def get_prompt(
                 f"Using custom prompt from environment variable '{env_vars.CUSTOM_PROMPT}'"
             )
 
-            # FIXME: if {paragraph_text} is in the prompt, this won't work for the edits endpoint
-            # a simple workaround is to remove {paragraph_text} from the prompt
             prompt = custom_prompt.format(**placeholders)
         elif resolved_prompt:
             # use the resolved prompt from the ai-revision config files, if available
@@ -401,14 +389,10 @@ def get_prompt(
         if custom_prompt is None:
             prompt = self.several_spaces_pattern.sub(" ", prompt).strip()
 
-        if self.endpoint != "edits":
-            if custom_prompt is not None and "{paragraph_text}" in custom_prompt:
-                return prompt
+        if custom_prompt is not None and "{paragraph_text}" in custom_prompt:
+            return prompt
 
-            return f"{prompt}.\n\n{paragraph_text.strip()}"
-        else:
-            prompt = prompt.replace("the following paragraph", "this paragraph")
-            return f"{prompt}.", paragraph_text.strip()
+        return f"{prompt}.\n\n{paragraph_text.strip()}"
 
     def get_max_tokens(self, paragraph_text: str, fraction: float = 2.0) -> int:
         """
@@ -489,14 +473,7 @@ def get_params(self, paragraph_text, section_name, resolved_prompt=None):
             "n": 1,
         }
 
-        if self.endpoint == "edits":
-            params.update(
-                {
-                    "instruction": prompt[0],
-                    "input": prompt[1],
-                }
-            )
-        elif self.endpoint == "chat":
+        if self.endpoint == "chat":
             params.update(
                 {
                     "messages": [
@@ -558,16 +535,6 @@ def revise_paragraph(
                 )
                 for msg in params["messages"]
             ]
-        elif "instruction" in params:
-            # since we don't know how to use the edits endpoint, we'll just
-            # concatenate the instruction and input and use the regular
-            # completion endpoint
-            # FIXME: there's probably a langchain equivalent for
-            # "edits", so we should change this to use that
-            prompt = [
-                HumanMessage(content=params["instruction"]),
-                HumanMessage(content=params["input"]),
-            ]
         elif "prompt" in params:
             prompt = [HumanMessage(content=params["prompt"])]
 
diff --git a/tests/test_model_get_prompt.py b/tests/test_model_get_prompt.py
index 538e654..438a16d 100644
--- a/tests/test_model_get_prompt.py
+++ b/tests/test_model_get_prompt.py
@@ -28,36 +28,6 @@ def test_get_prompt_for_abstract():
     assert "  " not in prompt
 
 
-def test_get_prompt_for_abstract_edit_endpoint():
-    manuscript_title = "Title of the manuscript to be revised"
-    manuscript_keywords = ["keyword0", "keyword1", "keyword2"]
-
-    model = GPT3CompletionModel(
-        title=manuscript_title,
-        keywords=manuscript_keywords,
-        model_engine="text-davinci-edit-001",
-    )
-
-    paragraph_text = "Text of the abstract. "
-
-    instruction, paragraph = model.get_prompt(paragraph_text, "abstract")
-    assert instruction is not None
-    assert isinstance(instruction, str)
-    assert paragraph is not None
-    assert isinstance(paragraph, str)
-
-    assert "this paragraph" in instruction
-    assert "abstract" in instruction
-    assert f"'{manuscript_title}'" in instruction
-    assert f"{manuscript_keywords[0]}" in instruction
-    assert f"{manuscript_keywords[1]}" in instruction
-    assert f"{manuscript_keywords[2]}" in instruction
-    assert "  " not in instruction
-    assert instruction.startswith("Revise")
-
-    assert paragraph_text.strip() == paragraph
-
-
 def test_get_prompt_for_introduction():
     manuscript_title = "Title of the manuscript to be revised"
     manuscript_keywords = ["keyword0", "keyword1", "keyword2"]

From 3bafcf6bd0b79460a41b4758e63ebbc748b3b408 Mon Sep 17 00:00:00 2001
From: Faisal Alquaddoomi
Date: Wed, 20 Nov 2024 14:38:57 -0700
Subject: [PATCH 09/11] Adds docstring to get_params(), other small comment and consistency tweaks

---
 libs/manubot_ai_editor/models.py | 24 +++++++++++++++++++++---
 1 file changed, 21 insertions(+), 3 deletions(-)

diff --git a/libs/manubot_ai_editor/models.py b/libs/manubot_ai_editor/models.py
index 4a3c138..0944c7f 100644
--- a/libs/manubot_ai_editor/models.py
+++ b/libs/manubot_ai_editor/models.py
@@ -466,6 +466,22 @@ def get_max_tokens_from_error_message(error_message: str) -> dict[str, int] | No
         }
 
     def get_params(self, paragraph_text, section_name, resolved_prompt=None):
+        """
+        Given the paragraph text and section name, produces parameters that are
+        used when invoking an LLM via an API.
+
+        The specific parameters vary depending on the endpoint being used, which
+        is determined by the model that was chosen when GPT3CompletionModel was
+        instantiated.
+
+        Args:
+            paragraph_text: The text of the paragraph to be revised.
+            section_name: The name of the section the paragraph belongs to.
+            resolved_prompt: The prompt resolved via ai-revision config files, if available.
+
+        Returns:
+            A dictionary of parameters to be used when invoking an LLM API.
+        """
         max_tokens = self.get_max_tokens(paragraph_text)
         prompt = self.get_prompt(paragraph_text, section_name, resolved_prompt)
 
@@ -504,13 +520,15 @@ def revise_paragraph(
 
         Arguments:
             paragraph_text (str): Paragraph text to revise.
-            section_name (str): Section name of the paragraph.
-            throw_error (bool): If True, it throws an error if the API call fails.
-                If False, it returns the original paragraph text.
+            section_name (str): Section name of the paragrap
+            resolved_prompt (str): Prompt resolved via ai-revision config files, if available.
 
         Returns:
             Revised paragraph text.
         """
+
+        # based on the paragraph text to revise and the section to which it
+        # belongs, constructs parameters that we'll use to query the LLM's API
         params = self.get_params(paragraph_text, section_name, resolved_prompt)
 
         retry_count = 0

From a98f69ae9e6222dae2dce7b296cfd5bfe1aac931 Mon Sep 17 00:00:00 2001
From: Faisal Alquaddoomi
Date: Wed, 20 Nov 2024 14:39:19 -0700
Subject: [PATCH 10/11] Formatted code w/black

---
 libs/manubot_ai_editor/prompt_config.py | 20 +++---
 tests/test_editor.py                    | 20 +++---
 tests/test_prompt_config.py             | 94 +++++++++++++++++--------
 3 files changed, 85 insertions(+), 49 deletions(-)

diff --git a/libs/manubot_ai_editor/prompt_config.py b/libs/manubot_ai_editor/prompt_config.py
index d2d9f6e..695952e 100644
--- a/libs/manubot_ai_editor/prompt_config.py
+++ b/libs/manubot_ai_editor/prompt_config.py
@@ -47,9 +47,9 @@ def __init__(self, config_dir: str | Path, title: str, keywords: str) -> None:
         # specify filename-to-prompt mappings; if both are present, we use
         # self.config.files, but warn the user that they should only use one
         if (
-            self.prompts_files is not None and
-            self.config is not None and
-            self.config.get('files', {}).get('matchings') is not None
+            self.prompts_files is not None
+            and self.config is not None
+            and self.config.get("files", {}).get("matchings") is not None
         ):
             print(
                 "WARNING: Both 'ai-revision-config.yaml' and 'ai-revision-prompts.yaml' specify filename-to-prompt mappings. "
@@ -93,7 +93,7 @@ def _load_custom_prompts(self) -> tuple[dict, dict]:
         # same as _load_config, if no config folder was specified, we just
         if self.config_dir is None:
             return (None, None)
-        
+
         prompt_file_path = os.path.join(self.config_dir, "ai-revision-prompts.yaml")
 
         try:
@@ -150,7 +150,7 @@ def get_prompt_for_filename(
         # ai-revision-prompts.yaml specifies prompts_files, then files.matchings
         # takes precedence.
# (the user is notified of this in a validation warning in __init__) - + # then, consult ai-revision-config.yaml's 'matchings' collection if a # match is found, use the prompt ai-revision-prompts.yaml for entry in get_obj_path(self.config, ("files", "matchings"), missing=[]): @@ -169,7 +169,10 @@ def get_prompt_for_filename( if resolved_prompt is not None: resolved_prompt = resolved_prompt.strip() - return ( resolved_prompt, m, ) + return ( + resolved_prompt, + m, + ) # since we haven't found a match yet, consult ai-revision-prompts.yaml's # 'prompts_files' collection @@ -185,11 +188,10 @@ def get_prompt_for_filename( resolved_default_prompt = None if use_default and self.prompts is not None: resolved_default_prompt = self.prompts.get( - get_obj_path(self.config, ("files", "default_prompt")), - None + get_obj_path(self.config, ("files", "default_prompt")), None ) if resolved_default_prompt is not None: resolved_default_prompt = resolved_default_prompt.strip() - + return (resolved_default_prompt, None) diff --git a/tests/test_editor.py b/tests/test_editor.py index 60bb778..1068a6a 100644 --- a/tests/test_editor.py +++ b/tests/test_editor.py @@ -610,9 +610,7 @@ def test_revise_methods_with_equation_that_was_alrady_revised( # GPT3CompletionModel(None, None), ], ) -def test_revise_methods_mutator_epistasis_paper( - tmp_path, model, filename -): +def test_revise_methods_mutator_epistasis_paper(tmp_path, model, filename): """ This papers has several test cases: - it ends with multiple blank lines @@ -635,7 +633,7 @@ def test_revise_methods_mutator_epistasis_paper( ) assert ( - r""" + r""" %%% PARAGRAPH START %%% Briefly, we identified private single-nucleotide mutations in each BXD that were absent from all other BXDs, as well as from the C57BL/6J and DBA/2J parents. We required each private variant to be meet the following criteria: @@ -651,11 +649,11 @@ def test_revise_methods_mutator_epistasis_paper( * must occur on a parental haplotype that was inherited by at least one other BXD at the same locus; these other BXDs must be homozygous for the reference allele at the variant site %%% PARAGRAPH END %%% """.strip() - in open(tmp_path / filename).read() + in open(tmp_path / filename).read() ) - + assert ( - r""" + r""" ### Extracting mutation signatures We used SigProfilerExtractor (v.1.1.21) [@PMID:30371878] to extract mutation signatures from the BXD mutation data. @@ -678,11 +676,11 @@ def test_revise_methods_mutator_epistasis_paper( ### Comparing mutation spectra between Mouse Genomes Project strains """.strip() - in open(tmp_path / filename).read() + in open(tmp_path / filename).read() ) - + assert ( - r""" + r""" %%% PARAGRAPH START %%% We investigated the region implicated by our aggregate mutation spectrum distance approach on chromosome 6 by subsetting the joint-genotyped BXD VCF file (European Nucleotide Archive accession PRJEB45429 [@url:https://www.ebi.ac.uk/ena/browser/view/PRJEB45429]) using `bcftools` [@PMID:33590861]. We defined the candidate interval surrounding the cosine distance peak on chromosome 6 as the 90% bootstrap confidence interval (extending from approximately 95 Mbp to 114 Mbp). 
@@ -693,7 +691,7 @@ def test_revise_methods_mutator_epistasis_paper( java -Xmx16g -jar /path/to/snpeff/jarfile GRCm38.75 /path/to/bxd/vcf > /path/to/uncompressed/output/vcf ``` """.strip() - in open(tmp_path / filename).read() + in open(tmp_path / filename).read() ) diff --git a/tests/test_prompt_config.py b/tests/test_prompt_config.py index 7f68702..4d32b6a 100644 --- a/tests/test_prompt_config.py +++ b/tests/test_prompt_config.py @@ -5,7 +5,7 @@ from manubot_ai_editor.models import ( GPT3CompletionModel, RandomManuscriptRevisionModel, - DebuggingManuscriptRevisionModel + DebuggingManuscriptRevisionModel, ) from manubot_ai_editor.prompt_config import IGNORE_FILE import pytest @@ -13,7 +13,9 @@ from utils.dir_union import mock_unify_open MANUSCRIPTS_DIR = Path(__file__).parent / "manuscripts" / "phenoplier_full" / "content" -MANUSCRIPTS_CONFIG_DIR = Path(__file__).parent / "manuscripts" / "phenoplier_full" / "ci" +MANUSCRIPTS_CONFIG_DIR = ( + Path(__file__).parent / "manuscripts" / "phenoplier_full" / "ci" +) # check that this path exists and resolve it @@ -42,7 +44,9 @@ def test_create_manuscript_editor(): # check that we can resolve a file to a prompt, and that it's the correct prompt -@mock.patch("builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, PHENOPLIER_PROMPTS_DIR)) +@mock.patch( + "builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, PHENOPLIER_PROMPTS_DIR) +) def test_resolve_prompt(): content_dir = MANUSCRIPTS_DIR.resolve(strict=True) config_dir = MANUSCRIPTS_CONFIG_DIR.resolve(strict=True) @@ -100,7 +104,9 @@ def test_resolve_prompt(): # test that we get the default prompt with a None match object for a # file we don't recognize -@mock.patch("builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, PHENOPLIER_PROMPTS_DIR)) +@mock.patch( + "builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, PHENOPLIER_PROMPTS_DIR) +) def test_resolve_default_prompt_unknown_file(): content_dir = MANUSCRIPTS_DIR.resolve(strict=True) config_dir = MANUSCRIPTS_CONFIG_DIR.resolve(strict=True) @@ -114,7 +120,9 @@ def test_resolve_default_prompt_unknown_file(): # check that a file we don't recognize gets match==None and the 'default' prompt # from the ai-revision-config.yaml file -@mock.patch("builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, PHENOPLIER_PROMPTS_DIR)) +@mock.patch( + "builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, PHENOPLIER_PROMPTS_DIR) +) def test_unresolved_gets_default_prompt(): content_dir = MANUSCRIPTS_DIR.resolve(strict=True) config_dir = MANUSCRIPTS_CONFIG_DIR.resolve(strict=True) @@ -150,7 +158,9 @@ def test_unresolved_gets_default_prompt(): # - Both ai-revision-config.yaml and ai-revision-prompts.yaml specify filename matchings # (conflicting_promptsfiles_matchings) CONFLICTING_PROMPTSFILES_MATCHINGS_DIR = ( - Path(__file__).parent / "config_loader_fixtures" / "conflicting_promptsfiles_matchings" + Path(__file__).parent + / "config_loader_fixtures" + / "conflicting_promptsfiles_matchings" ) # --- # test ManuscriptEditor.prompt_config sub-attributes are set correctly @@ -178,7 +188,9 @@ def test_no_config_unloaded(): assert editor.prompt_config.config is None -@mock.patch("builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, ONLY_REV_PROMPTS_DIR)) +@mock.patch( + "builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, ONLY_REV_PROMPTS_DIR) +) def test_only_rev_prompts_loaded(): editor = get_editor() @@ -188,7 +200,9 @@ def test_only_rev_prompts_loaded(): assert editor.prompt_config.config is None -@mock.patch("builtins.open", 
mock_unify_open(MANUSCRIPTS_CONFIG_DIR, BOTH_PROMPTS_CONFIG_DIR)) +@mock.patch( + "builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, BOTH_PROMPTS_CONFIG_DIR) +) def test_both_prompts_loaded(): editor = get_editor() @@ -211,7 +225,8 @@ def test_single_generic_loaded(): @mock.patch( - "builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, CONFLICTING_PROMPTSFILES_MATCHINGS_DIR) + "builtins.open", + mock_unify_open(MANUSCRIPTS_CONFIG_DIR, CONFLICTING_PROMPTSFILES_MATCHINGS_DIR), ) def test_conflicting_sources_warning(capfd): """ @@ -234,7 +249,7 @@ def test_conflicting_sources_warning(capfd): # for this test, we define both prompts_files and files.matchings which # creates a conflict that produces the warning we're looking for assert editor.prompt_config.prompts_files is not None - assert editor.prompt_config.config['files']['matchings'] is not None + assert editor.prompt_config.config["files"]["matchings"] is not None expected_warning = ( "WARNING: Both 'ai-revision-config.yaml' and " @@ -262,11 +277,13 @@ def test_conflicting_sources_warning(capfd): RandomManuscriptRevisionModel(), DebuggingManuscriptRevisionModel( title="Test title", keywords=["test", "keywords"] - ) + ), # GPT3CompletionModel(None, None), ], ) -@mock.patch("builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, BOTH_PROMPTS_CONFIG_DIR)) +@mock.patch( + "builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, BOTH_PROMPTS_CONFIG_DIR) +) def test_revise_entire_manuscript(tmp_path, model): print(f"\n{str(tmp_path)}\n") me = get_editor() @@ -284,7 +301,9 @@ def test_revise_entire_manuscript(tmp_path, model): assert len(output_md_files) == 9 -@mock.patch("builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, BOTH_PROMPTS_CONFIG_DIR)) +@mock.patch( + "builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, BOTH_PROMPTS_CONFIG_DIR) +) def test_revise_entire_manuscript_includes_title_keywords(tmp_path): from os.path import basename @@ -317,8 +336,12 @@ def test_revise_entire_manuscript_includes_title_keywords(tmp_path): with open(output_md_file, "r") as f: content = f.read() - assert me.title in content, f"not found in filename: {basename(output_md_file)}" - assert ", ".join(me.keywords) in content, f"not found in filename: {basename(output_md_file)}" + assert ( + me.title in content + ), f"not found in filename: {basename(output_md_file)}" + assert ( + ", ".join(me.keywords) in content + ), f"not found in filename: {basename(output_md_file)}" # ============================================================================== @@ -329,7 +352,11 @@ def test_revise_entire_manuscript_includes_title_keywords(tmp_path): Path(__file__).parent / "config_loader_fixtures" / "prompt_propogation" ) -@mock.patch("builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, PROMPT_PROPOGATION_CONFIG_DIR)) + +@mock.patch( + "builtins.open", + mock_unify_open(MANUSCRIPTS_CONFIG_DIR, PROMPT_PROPOGATION_CONFIG_DIR), +) def test_prompts_in_final_result(tmp_path): """ Tests that the prompts are making it into the final resulting .md files. 
@@ -348,9 +375,7 @@ def test_prompts_in_final_result(tmp_path): """ me = get_editor() - model = DebuggingManuscriptRevisionModel( - title=me.title, keywords=me.keywords - ) + model = DebuggingManuscriptRevisionModel(title=me.title, keywords=me.keywords) output_folder = tmp_path assert output_folder.exists() @@ -361,7 +386,8 @@ def test_prompts_in_final_result(tmp_path): files_to_prompts = { "00.front-matter.md": "This is the front-matter prompt.", "01.abstract.md": "This is the abstract prompt", - "02.introduction.md": "This is the introduction prompt for the paper titled '%s'." % me.title, + "02.introduction.md": "This is the introduction prompt for the paper titled '%s'." + % me.title, # "04.00.results.md": "This is the results prompt", "04.05.00.results_framework.md": "This is the results_framework prompt", "04.05.01.crispr.md": "This is the crispr prompt", @@ -389,15 +415,26 @@ def test_prompts_in_final_result(tmp_path): # to save on time/cost, we use a version of the phenoplier manuscript that only # contains the first paragraph of each section -BRIEF_MANUSCRIPTS_DIR = Path(__file__).parent / "manuscripts" / "phenoplier_full_only_first_para" / "content" -BRIEF_MANUSCRIPTS_CONFIG_DIR = Path(__file__).parent / "manuscripts" / "phenoplier_full_only_first_para" / "ci" +BRIEF_MANUSCRIPTS_DIR = ( + Path(__file__).parent + / "manuscripts" + / "phenoplier_full_only_first_para" + / "content" +) +BRIEF_MANUSCRIPTS_CONFIG_DIR = ( + Path(__file__).parent / "manuscripts" / "phenoplier_full_only_first_para" / "ci" +) PROMPT_PROPOGATION_CONFIG_DIR = ( Path(__file__).parent / "config_loader_fixtures" / "prompt_gpt3_e2e" ) + @pytest.mark.cost -@mock.patch("builtins.open", mock_unify_open(BRIEF_MANUSCRIPTS_CONFIG_DIR, PROMPT_PROPOGATION_CONFIG_DIR)) +@mock.patch( + "builtins.open", + mock_unify_open(BRIEF_MANUSCRIPTS_CONFIG_DIR, PROMPT_PROPOGATION_CONFIG_DIR), +) def test_prompts_apply_gpt3(tmp_path): """ Tests that the custom prompts are applied when actually applying @@ -408,16 +445,15 @@ def test_prompts_apply_gpt3(tmp_path): this test is marked 'cost' and requires the --runcost argument to be run, e.g. to run just this test: `pytest --runcost -k test_prompts_apply_gpt3`. - As with test_prompts_in_final_result above, files that have no input and + As with test_prompts_in_final_result above, files that have no input and thus no applied prompt are ignored. """ - me = get_editor(content_dir=BRIEF_MANUSCRIPTS_DIR, config_dir=BRIEF_MANUSCRIPTS_CONFIG_DIR) - - model = GPT3CompletionModel( - title=me.title, - keywords=me.keywords + me = get_editor( + content_dir=BRIEF_MANUSCRIPTS_DIR, config_dir=BRIEF_MANUSCRIPTS_CONFIG_DIR ) + model = GPT3CompletionModel(title=me.title, keywords=me.keywords) + output_folder = tmp_path assert output_folder.exists() From 4c66a58de9e1bc1d5271cdc5e3697f62c33e9b10 Mon Sep 17 00:00:00 2001 From: Faisal Alquaddoomi Date: Tue, 3 Dec 2024 14:13:23 -0700 Subject: [PATCH 11/11] Removed FIXME; we'll always use a number of responses, aka n, of 1 so no need to use call() over invoke() --- libs/manubot_ai_editor/models.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/libs/manubot_ai_editor/models.py b/libs/manubot_ai_editor/models.py index 0944c7f..80a2dcc 100644 --- a/libs/manubot_ai_editor/models.py +++ b/libs/manubot_ai_editor/models.py @@ -567,12 +567,6 @@ def revise_paragraph( else: message = response.strip() - # FIXME: the prior code retrieved the first of the 'choices' - # response from the openai client. 
now, we only get one
-        # response from the langchain client, but i should check
-        # if that's really how langchain works or if there is a way
-        # to get multiple 'choices' back from the backend.
-
         except Exception as e:
             error_message = str(e)
             print(f"Error: {error_message}")
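
The net effect of the series is that GPT3CompletionModel talks to the model through langchain-openai rather than the openai SDK. A minimal usage sketch of that call path follows, assuming langchain-openai ~=0.2 and a chat-capable model; the API key, model name, prompt text, and token limit below are illustrative placeholders, not values taken from the patches:

from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI

# the client is constructed once, as __init__ now does after commit 03
client = ChatOpenAI(api_key="sk-...", model="gpt-3.5-turbo")  # placeholder key/model

# revise_paragraph() builds a list of langchain messages from its params and,
# after commit 04, passes per-request options such as max_tokens and stop to invoke()
response = client.invoke(
    input=[
        SystemMessage(content="You are an expert copy editor."),
        HumanMessage(content="Revise the following paragraph: ..."),
    ],
    max_tokens=512,
)

# chat models return a single message object; its content is the revised text,
# which is why the old choices[0] handling could be dropped (commit 11)
print(response.content.strip())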