Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mappers Part 1 #167

Merged
merged 3 commits into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions marsha/.time.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import time

from marsha.llm import prettify_time_delta
from marsha.utils import prettify_time_delta

from mistletoe import Document, ast_renderer

Expand Down Expand Up @@ -42,7 +42,7 @@
run_stats_file = open('stats.md', 'r')
run_stats = run_stats_file.read()
run_stats_file.close()
except Exception as e:
except Exception:
raise Exception('Error reading stats file. Maybe something went run while running Marsha and the stats were not generated?')
try:
ast = ast_renderer.get_ast(Document(run_stats))
Expand Down
25 changes: 12 additions & 13 deletions marsha/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import traceback
import sys

from marsha.llm import gpt_can_func_python, gpt_improve_func, gpt_func_to_python, lint_and_fix_files, test_and_fix_files, prettify_time_delta
from marsha.llm import gpt_can_func_python, gpt_improve_func, gpt_func_to_python, lint_and_fix_files, test_and_fix_files
from marsha.parse import extract_functions_and_types, extract_type_name, write_files_from_markdown, is_defined_from_file, extract_type_filename
from marsha.stats import MarshaStats
from marsha.utils import read_file, autoformat_files, copy_file, get_filename_from_path, add_helper, copy_tree
from marsha.stats import stats
from marsha.utils import read_file, autoformat_files, copy_file, get_filename_from_path, add_helper, copy_tree, prettify_time_delta

# Set up OpenAI
openai.organization = os.getenv('OPENAI_ORG')
Expand Down Expand Up @@ -40,7 +40,6 @@

async def main():
t1 = time.time()
stats = MarshaStats()
input_file = args.source
# Name without extension
marsha_file_dirname = os.path.dirname(input_file)
Expand All @@ -66,7 +65,7 @@ async def main():
# First stage: generate code for functions and classes
try:
mds = await generate_python_code(
marsha_filename, functions, types_defined, void_funcs, n_results, debug, stats)
marsha_filename, functions, types_defined, void_funcs, n_results, debug)
except Exception:
continue
# Early exit if quick and dirty
Expand Down Expand Up @@ -94,7 +93,7 @@ async def main():
tasks = []
for file_group in file_groups:
tasks.append(asyncio.create_task(
review_and_fix(marsha_filename, file_group, functions, types_defined, void_funcs, stats, debug), name=file_group[0]))
review_and_fix(marsha_filename, file_group, functions, types_defined, void_funcs, debug), name=file_group[0]))
try:
done_task_name = await run_parallel_tasks(tasks)
print('Writing generated code to files...')
Expand Down Expand Up @@ -140,16 +139,16 @@ async def main():
f'{marsha_filename} done! Total time elapsed: {prettify_time_delta(t2 - t1)}. Total cost: {round(stats.total_cost, 2)}.')


async def generate_python_code(marsha_filename: str, functions: list[str], types_defined: list[str], void_funcs: list[str], n_results: int, debug: bool, stats: MarshaStats) -> list[str]:
async def generate_python_code(marsha_filename: str, functions: list[str], types_defined: list[str], void_funcs: list[str], n_results: int, debug: bool) -> list[str]:
t1 = time.time()
print('Generating Python code...')
mds = None
try:
if not args.exclude_sanity_check:
if not await gpt_can_func_python(marsha_filename, functions, types_defined, void_funcs, n_results, stats):
await gpt_improve_func(marsha_filename, functions, types_defined, void_funcs, stats)
if not await gpt_can_func_python(marsha_filename, functions, types_defined, void_funcs, n_results):
await gpt_improve_func(marsha_filename, functions, types_defined, void_funcs)
sys.exit(1)
mds = await gpt_func_to_python(marsha_filename, functions, types_defined, void_funcs, n_results, stats, debug=debug)
mds = await gpt_func_to_python(marsha_filename, functions, types_defined, void_funcs, n_results, debug=debug)
except Exception as e:
print('First stage failure')
print(e)
Expand Down Expand Up @@ -188,11 +187,11 @@ async def process_types(raw_types: list[str], dirname: str) -> list[str]:
return types_defined


async def review_and_fix(marsha_filename: str, files: list[str], functions: list[str], defined_types: list[str], void_functions: list[str], stats: MarshaStats, debug: bool = False):
async def review_and_fix(marsha_filename: str, files: list[str], functions: list[str], defined_types: list[str], void_functions: list[str], debug: bool = False):
t_ssi = time.time()
print('Parsing generated code...')
try:
await lint_and_fix_files(marsha_filename, files, stats, debug=debug)
await lint_and_fix_files(marsha_filename, files, debug=debug)
except Exception as e:
print('Second stage failure')
print(e)
Expand All @@ -207,7 +206,7 @@ async def review_and_fix(marsha_filename: str, files: list[str], functions: list
t_tsi = time.time()
print('Verifying and correcting generated code...')
try:
await test_and_fix_files(marsha_filename, functions, defined_types, void_functions, files, stats, debug=debug)
await test_and_fix_files(marsha_filename, functions, defined_types, void_functions, files, debug=debug)
except Exception as e:
print('Third stage failure')
print(e)
Expand Down
31 changes: 31 additions & 0 deletions marsha/basemapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
class BaseMapper():
    """Semi-abstract base for 'mappers' in Marsha.

    A mapper turns an input into an output via ``transform`` and then
    validates the result via ``check``. Subclasses must override
    ``transform``; overriding ``check`` is optional. ``run`` drives the
    whole transform-then-check cycle, retrying the check a fixed number
    of times before giving up.
    """

    def __init__(self):
        # How many times `check` may fail before `run` gives up.
        self.check_retries = 3
        # Result of the most recent `transform` call (None until `run`).
        self.output = None

    async def transform(self, i):
        """Produce this mapper's output from input ``i``. Must be overridden.

        Raises:
            NotImplementedError: always, in the base class.
        """
        # NotImplementedError subclasses Exception, so existing callers
        # catching Exception still work.
        raise NotImplementedError('Not implemented')

    async def check(self):
        """Validate and/or select the output; raise to trigger a retry.

        The default implementation accepts whatever ``transform`` produced.
        """
        # Define a check if you want, but not necessary
        return self.output

    async def run(self, i):
        """Transform ``i`` and retry ``check`` until it passes or retries run out.

        Returns whatever ``check`` returns; raises if ``transform`` fails or
        ``check`` never succeeds within ``self.check_retries`` attempts.
        """
        try:
            self.output = await self.transform(i)
        except Exception:
            # TODO: Log the error before re-raise?
            raise  # bare raise preserves the original traceback

        for _ in range(self.check_retries):
            try:
                return await self.check()
            except Exception:
                # Using the exception here as flow control
                continue

        raise Exception('Transformer failed to converge')
69 changes: 69 additions & 0 deletions marsha/chatgptmapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import asyncio
import time

import openai

from marsha.basemapper import BaseMapper
from marsha.stats import stats
from marsha.utils import prettify_time_delta

# Get time at startup to make human legible "start times" in the logs
# (log lines report elapsed time relative to module import).
t0 = time.time()


async def retry_chat_completion(query, model='gpt-3.5-turbo', max_tries=3, n_results=1):
    """Call the OpenAI chat completion API, retrying transient failures.

    Mutates ``query`` in place to include the model and result count, then
    attempts the completion up to ``max_tries`` times with a short,
    shrinking backoff between attempts. A context-length error triggers an
    upgrade to the larger ``gpt-4`` model before the retry.

    Returns the raw API response object; re-raises the last API error once
    the retries are exhausted.
    """
    t1 = time.time()
    query['model'] = model
    query['n'] = n_results
    while True:
        try:
            out = await openai.ChatCompletion.acreate(**query)
            t2 = time.time()
            print(
                f'''Chat query took {prettify_time_delta(t2 - t1)}, started at {prettify_time_delta(t1 - t0)}, ms/chars = {(t2 - t1) * 1000 / out.get('usage', {}).get('total_tokens', 9001)}''')
            return out
        except openai.error.InvalidRequestError as e:
            if e.code == 'context_length_exceeded':
                # Try to cover up this error by choosing the bigger, more expensive model
                query['model'] = 'gpt-4'
            max_tries = max_tries - 1
            if max_tries == 0:
                raise e
            # Non-blocking backoff: time.sleep() here would stall the whole
            # asyncio event loop and freeze every other in-flight task.
            await asyncio.sleep(3 / max_tries)
        except Exception as e:
            max_tries = max_tries - 1
            if max_tries == 0:
                raise e
            await asyncio.sleep(3 / max_tries)
    # (The loop above only exits via return or raise, so nothing follows it.)


class ChatGPTMapper(BaseMapper):
    """ChatGPT-based mapper class.

    Configured once with a system prompt and model settings; each call to
    ``transform`` sends the user's request alongside that system prompt and
    returns the model's reply (or replies, when ``n_results`` > 1).
    """

    def __init__(self, system, model='gpt-3.5-turbo', max_tokens=None, max_retries=3, n_results=1, stats_stage=None):
        super().__init__()
        # Prompt configuration
        self.system = system
        self.model = model
        self.max_tokens = max_tokens
        # Retry / fan-out behavior
        self.max_retries = max_retries
        self.n_results = n_results
        # Optional stage name for usage/cost accounting
        self.stats_stage = stats_stage

    async def transform(self, user_request):
        """Send the system prompt plus ``user_request`` to the chat API."""
        messages = [
            {'role': 'system', 'content': self.system},
            {'role': 'user', 'content': user_request},
        ]
        request = {'messages': messages}
        if self.max_tokens is not None:
            request['max_tokens'] = self.max_tokens

        res = await retry_chat_completion(
            request, self.model, self.max_retries, self.n_results)

        # Record token usage/cost against the configured stage, if any.
        if self.stats_stage is not None:
            stats.stage_update(self.stats_stage, [res])

        if self.n_results > 1:
            return [choice.message.content for choice in res.choices]
        return res.choices[0].message.content
Loading
Loading