From d08f90c08abfbdeeb68ad571b9402e74c257ba23 Mon Sep 17 00:00:00 2001 From: Quinten Steenhuis Date: Thu, 30 Nov 2023 10:48:33 -0500 Subject: [PATCH] Should be good, unable to test yet --- formfyxer/keys/openai_key.txt | 1 - formfyxer/keys/openai_org.txt | 1 - formfyxer/lit_explorer.py | 64 +++++++++++++++++++++++++---------- 3 files changed, 46 insertions(+), 20 deletions(-) diff --git a/formfyxer/keys/openai_key.txt b/formfyxer/keys/openai_key.txt index fd07277..e69de29 100644 --- a/formfyxer/keys/openai_key.txt +++ b/formfyxer/keys/openai_key.txt @@ -1 +0,0 @@ -your_OPENAI_API_key goes here diff --git a/formfyxer/keys/openai_org.txt b/formfyxer/keys/openai_org.txt index e0c3e24..e69de29 100644 --- a/formfyxer/keys/openai_org.txt +++ b/formfyxer/keys/openai_org.txt @@ -1 +0,0 @@ -your_OPENAI_API_org goes here diff --git a/formfyxer/lit_explorer.py b/formfyxer/lit_explorer.py index f5eea77..851ecec 100644 --- a/formfyxer/lit_explorer.py +++ b/formfyxer/lit_explorer.py @@ -65,6 +65,8 @@ ) import openai +from openai import OpenAI + from transformers import GPT2TokenizerFast tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") @@ -125,14 +127,27 @@ os.path.join(os.path.dirname(__file__), "keys", "spot_token.txt"), "r" ) as in_file: default_spot_token = in_file.read().rstrip() -with open( - os.path.join(os.path.dirname(__file__), "keys", "openai_org.txt"), "r" -) as in_file: - openai.organization = in_file.read().rstrip() -with open( - os.path.join(os.path.dirname(__file__), "keys", "openai_key.txt"), "r" -) as in_file: - openai.api_key = in_file.read().rstrip() +try: + with open( + os.path.join(os.path.dirname(__file__), "keys", "openai_key.txt"), "r" + ) as in_file: + default_key = OpenAI(api_key=in_file.read().rstrip()) +except: + default_key = None +try: + with open( + os.path.join(os.path.dirname(__file__), "keys", "openai_org.txt"), "r" + ) as in_file: + default_org = in_file.read().rstrip() +except: + default_org = None +if default_key: + client = OpenAI(api_key=default_key, organization=default_org or None) +elif os.getenv("OPENAI_API_KEY"): + client = OpenAI() +else: + client = None + # TODO(brycew): remove by retraining the model to work with random_state=4. NEEDS_STABILITY = True if os.getenv("ISUNITTEST") else False @@ -801,22 +816,35 @@ class OpenAiCreds(TypedDict): key: str -def text_complete(prompt, max_tokens=500, creds: Optional[OpenAiCreds] = None) -> str: - if creds: - openai.organization = creds["org"].strip() or "" - openai.api_key = creds["key"].strip() or "" +def text_complete(prompt:str, max_tokens:int=500, creds: Optional[OpenAiCreds] = None, temperature:float=0) -> str: + """Run a prompt via openAI's API and return the result. + Args: + prompt (str): The prompt to send to the API. + max_tokens (int, optional): The number of tokens to generate. Defaults to 500. + creds (Optional[OpenAiCreds], optional): The credentials to use. Defaults to None. + temperature (float, optional): The temperature to use. Defaults to 0. + """ + if creds: + openai_client = OpenAI(api_key=creds["key"], organization=creds["org"]) + else: + openai_client = client try: - response = openai.Completion.create( - model="text-davinci-003", - prompt=prompt, - temperature=0, + response = openai_client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[ + { + "role": "system", + "content": prompt + }, + ], + temperature=temperature, max_tokens=max_tokens, top_p=1.0, frequency_penalty=0.0, - presence_penalty=0.0, + presence_penalty=0.0 ) - return str(response["choices"][0]["text"].strip()) + return str(response.choices[0].message.content.strip()) except Exception as ex: print(f"{ex}") return "ApiError"