diff --git a/ai_ta_backend/beam/nomic_logging.py b/ai_ta_backend/beam/nomic_logging.py index 92db8a62..d15c616e 100644 --- a/ai_ta_backend/beam/nomic_logging.py +++ b/ai_ta_backend/beam/nomic_logging.py @@ -1,438 +1,438 @@ -import datetime -import os - -import nomic -import numpy as np -import pandas as pd -import sentry_sdk -import supabase -from langchain.embeddings import OpenAIEmbeddings -from nomic import AtlasProject, atlas - -OPENAI_API_TYPE = "azure" - -SUPABASE_CLIENT = supabase.create_client( # type: ignore - supabase_url=os.getenv('SUPABASE_URL'), # type: ignore - supabase_key=os.getenv('SUPABASE_API_KEY')) # type: ignore - -NOMIC_MAP_NAME_PREFIX = 'Document Map for ' - -## -------------------------------- DOCUMENT MAP FUNCTIONS --------------------------------- ## - -def create_document_map(course_name: str): - """ - This is a function which creates a document map for a given course from scratch - 1. Gets count of documents for the course - 2. If less than 20, returns a message that a map cannot be created - 3. If greater than 20, iteratively fetches documents in batches of 25 - 4. Prepares metadata and embeddings for nomic upload - 5. Creates a new map and uploads the data - - Args: - course_name: str - Returns: - str: success or failed - """ - print("in create_document_map()") - nomic.login(os.getenv('NOMIC_API_KEY')) +# import datetime +# import os + +# import nomic +# import numpy as np +# import pandas as pd +# import sentry_sdk +# import supabase +# from langchain.embeddings import OpenAIEmbeddings +# from nomic import AtlasProject, atlas + +# OPENAI_API_TYPE = "azure" + +# SUPABASE_CLIENT = supabase.create_client( # type: ignore +# supabase_url=os.getenv('SUPABASE_URL'), # type: ignore +# supabase_key=os.getenv('SUPABASE_API_KEY')) # type: ignore + +# NOMIC_MAP_NAME_PREFIX = 'Document Map for ' + +# ## -------------------------------- DOCUMENT MAP FUNCTIONS --------------------------------- ## + +# def create_document_map(course_name: str): +# """ +# This is a function which creates a document map for a given course from scratch +# 1. Gets count of documents for the course +# 2. If less than 20, returns a message that a map cannot be created +# 3. If greater than 20, iteratively fetches documents in batches of 25 +# 4. Prepares metadata and embeddings for nomic upload +# 5. Creates a new map and uploads the data + +# Args: +# course_name: str +# Returns: +# str: success or failed +# """ +# print("in create_document_map()") +# nomic.login(os.getenv('NOMIC_API_KEY')) - try: - # check if map exists - response = SUPABASE_CLIENT.table("projects").select("doc_map_id").eq("course_name", course_name).execute() - if response.data: - if response.data[0]['doc_map_id']: - return "Map already exists for this course." - - # fetch relevant document data from Supabase - response = SUPABASE_CLIENT.table("documents").select("id", - count="exact").eq("course_name", - course_name).order('id', - desc=False).execute() - if not response.count: - return "No documents found for this course." +# try: +# # check if map exists +# response = SUPABASE_CLIENT.table("projects").select("doc_map_id").eq("course_name", course_name).execute() +# if response.data: +# if response.data[0]['doc_map_id']: +# return "Map already exists for this course." 
+ +# # fetch relevant document data from Supabase +# response = SUPABASE_CLIENT.table("documents").select("id", +# count="exact").eq("course_name", +# course_name).order('id', +# desc=False).execute() +# if not response.count: +# return "No documents found for this course." - total_doc_count = response.count - print("Total number of documents in Supabase: ", total_doc_count) +# total_doc_count = response.count +# print("Total number of documents in Supabase: ", total_doc_count) - # minimum 20 docs needed to create map - if total_doc_count < 20: - return "Cannot create a map because there are less than 20 documents in the course." +# # minimum 20 docs needed to create map +# if total_doc_count < 20: +# return "Cannot create a map because there are less than 20 documents in the course." - first_id = response.data[0]['id'] +# first_id = response.data[0]['id'] - combined_dfs = [] - curr_total_doc_count = 0 - doc_count = 0 - first_batch = True +# combined_dfs = [] +# curr_total_doc_count = 0 +# doc_count = 0 +# first_batch = True - # iteratively query in batches of 25 - while curr_total_doc_count < total_doc_count: +# # iteratively query in batches of 25 +# while curr_total_doc_count < total_doc_count: - response = SUPABASE_CLIENT.table("documents").select( - "id, created_at, s3_path, url, base_url, readable_filename, contexts").eq("course_name", course_name).gte( - 'id', first_id).order('id', desc=False).limit(25).execute() - df = pd.DataFrame(response.data) - combined_dfs.append(df) # list of dfs - - curr_total_doc_count += len(response.data) - doc_count += len(response.data) - - if doc_count >= 1000: # upload to Nomic in batches of 1000 - - # concat all dfs from the combined_dfs list - final_df = pd.concat(combined_dfs, ignore_index=True) - - # prep data for nomic upload - embeddings, metadata = data_prep_for_doc_map(final_df) - - if first_batch: - # create a new map - print("Creating new map...") - project_name = NOMIC_MAP_NAME_PREFIX + course_name - index_name = course_name + "_doc_index" - topic_label_field = "text" - colorable_fields = ["readable_filename", "text", "base_url", "created_at"] - result = create_map(embeddings, metadata, project_name, index_name, topic_label_field, colorable_fields) +# response = SUPABASE_CLIENT.table("documents").select( +# "id, created_at, s3_path, url, base_url, readable_filename, contexts").eq("course_name", course_name).gte( +# 'id', first_id).order('id', desc=False).limit(25).execute() +# df = pd.DataFrame(response.data) +# combined_dfs.append(df) # list of dfs + +# curr_total_doc_count += len(response.data) +# doc_count += len(response.data) + +# if doc_count >= 1000: # upload to Nomic in batches of 1000 + +# # concat all dfs from the combined_dfs list +# final_df = pd.concat(combined_dfs, ignore_index=True) + +# # prep data for nomic upload +# embeddings, metadata = data_prep_for_doc_map(final_df) + +# if first_batch: +# # create a new map +# print("Creating new map...") +# project_name = NOMIC_MAP_NAME_PREFIX + course_name +# index_name = course_name + "_doc_index" +# topic_label_field = "text" +# colorable_fields = ["readable_filename", "text", "base_url", "created_at"] +# result = create_map(embeddings, metadata, project_name, index_name, topic_label_field, colorable_fields) - if result == "success": - # update flag - first_batch = False - # log project info to supabase - project = AtlasProject(name=project_name, add_datums_if_exists=True) - project_id = project.id - last_id = int(final_df['id'].iloc[-1]) - project_info = {'course_name': 
course_name, 'doc_map_id': project_id, 'last_uploaded_doc_id': last_id} - project_response = SUPABASE_CLIENT.table("projects").select("*").eq("course_name", course_name).execute() - if project_response.data: - update_response = SUPABASE_CLIENT.table("projects").update(project_info).eq("course_name", course_name).execute() - print("Response from supabase: ", update_response) - else: - insert_response = SUPABASE_CLIENT.table("projects").insert(project_info).execute() - print("Insert Response from supabase: ", insert_response) +# if result == "success": +# # update flag +# first_batch = False +# # log project info to supabase +# project = AtlasProject(name=project_name, add_datums_if_exists=True) +# project_id = project.id +# last_id = int(final_df['id'].iloc[-1]) +# project_info = {'course_name': course_name, 'doc_map_id': project_id, 'last_uploaded_doc_id': last_id} +# project_response = SUPABASE_CLIENT.table("projects").select("*").eq("course_name", course_name).execute() +# if project_response.data: +# update_response = SUPABASE_CLIENT.table("projects").update(project_info).eq("course_name", course_name).execute() +# print("Response from supabase: ", update_response) +# else: +# insert_response = SUPABASE_CLIENT.table("projects").insert(project_info).execute() +# print("Insert Response from supabase: ", insert_response) - else: - # append to existing map - print("Appending data to existing map...") - project_name = NOMIC_MAP_NAME_PREFIX + course_name - # add project lock logic here - result = append_to_map(embeddings, metadata, project_name) - if result == "success": - # update the last uploaded id in supabase - last_id = int(final_df['id'].iloc[-1]) - info = {'last_uploaded_doc_id': last_id} - update_response = SUPABASE_CLIENT.table("projects").update(info).eq("course_name", course_name).execute() - print("Response from supabase: ", update_response) +# else: +# # append to existing map +# print("Appending data to existing map...") +# project_name = NOMIC_MAP_NAME_PREFIX + course_name +# # add project lock logic here +# result = append_to_map(embeddings, metadata, project_name) +# if result == "success": +# # update the last uploaded id in supabase +# last_id = int(final_df['id'].iloc[-1]) +# info = {'last_uploaded_doc_id': last_id} +# update_response = SUPABASE_CLIENT.table("projects").update(info).eq("course_name", course_name).execute() +# print("Response from supabase: ", update_response) - # reset variables - combined_dfs = [] - doc_count = 0 - print("Records uploaded: ", curr_total_doc_count) - - # set first_id for next iteration - first_id = response.data[-1]['id'] + 1 - - # upload last set of docs - if doc_count > 0: - final_df = pd.concat(combined_dfs, ignore_index=True) - embeddings, metadata = data_prep_for_doc_map(final_df) - project_name = NOMIC_MAP_NAME_PREFIX + course_name - if first_batch: - index_name = course_name + "_doc_index" - topic_label_field = "text" - colorable_fields = ["readable_filename", "text", "base_url", "created_at"] - result = create_map(embeddings, metadata, project_name, index_name, topic_label_field, colorable_fields) - else: - result = append_to_map(embeddings, metadata, project_name) - - # update the last uploaded id in supabase - if result == "success": - # update the last uploaded id in supabase - last_id = int(final_df['id'].iloc[-1]) - project = AtlasProject(name=project_name, add_datums_if_exists=True) - project_id = project.id - project_info = {'course_name': course_name, 'doc_map_id': project_id, 'last_uploaded_doc_id': last_id} - 
print("project_info: ", project_info) - project_response = SUPABASE_CLIENT.table("projects").select("*").eq("course_name", course_name).execute() - if project_response.data: - update_response = SUPABASE_CLIENT.table("projects").update(project_info).eq("course_name", course_name).execute() - print("Response from supabase: ", update_response) - else: - insert_response = SUPABASE_CLIENT.table("projects").insert(project_info).execute() - print("Insert Response from supabase: ", insert_response) +# # reset variables +# combined_dfs = [] +# doc_count = 0 +# print("Records uploaded: ", curr_total_doc_count) + +# # set first_id for next iteration +# first_id = response.data[-1]['id'] + 1 + +# # upload last set of docs +# if doc_count > 0: +# final_df = pd.concat(combined_dfs, ignore_index=True) +# embeddings, metadata = data_prep_for_doc_map(final_df) +# project_name = NOMIC_MAP_NAME_PREFIX + course_name +# if first_batch: +# index_name = course_name + "_doc_index" +# topic_label_field = "text" +# colorable_fields = ["readable_filename", "text", "base_url", "created_at"] +# result = create_map(embeddings, metadata, project_name, index_name, topic_label_field, colorable_fields) +# else: +# result = append_to_map(embeddings, metadata, project_name) + +# # update the last uploaded id in supabase +# if result == "success": +# # update the last uploaded id in supabase +# last_id = int(final_df['id'].iloc[-1]) +# project = AtlasProject(name=project_name, add_datums_if_exists=True) +# project_id = project.id +# project_info = {'course_name': course_name, 'doc_map_id': project_id, 'last_uploaded_doc_id': last_id} +# print("project_info: ", project_info) +# project_response = SUPABASE_CLIENT.table("projects").select("*").eq("course_name", course_name).execute() +# if project_response.data: +# update_response = SUPABASE_CLIENT.table("projects").update(project_info).eq("course_name", course_name).execute() +# print("Response from supabase: ", update_response) +# else: +# insert_response = SUPABASE_CLIENT.table("projects").insert(project_info).execute() +# print("Insert Response from supabase: ", insert_response) - # rebuild the map - rebuild_map(course_name, "document") +# # rebuild the map +# rebuild_map(course_name, "document") - except Exception as e: - print(e) - sentry_sdk.capture_exception(e) - return "failed" - -def delete_from_document_map(course_name: str, ids: list): - """ - This function is used to delete datapoints from a document map. - Currently used within the delete_data() function in vector_database.py - Args: - course_name: str - ids: list of str - """ - print("in delete_from_document_map()") - - try: - # check if project exists - response = SUPABASE_CLIENT.table("projects").select("doc_map_id").eq("course_name", course_name).execute() - if response.data: - project_id = response.data[0]['doc_map_id'] - else: - return "No document map found for this course" - - # fetch project from Nomic - project = AtlasProject(project_id=project_id, add_datums_if_exists=True) - - # delete the ids from Nomic - print("Deleting point from document map:", project.delete_data(ids)) - with project.wait_for_project_lock(): - project.rebuild_maps() - return "success" - except Exception as e: - print(e) - sentry_sdk.capture_exception(e) - return "Error in deleting from document map: {e}" - - -def log_to_document_map(course_name: str): - """ - This is a function which appends new documents to an existing document map. It's called - at the end of split_and_upload() after inserting data to Supabase. 
- Args: - data: dict - the response data from Supabase insertion - """ - print("in add_to_document_map()") - - try: - # check if map exists - response = SUPABASE_CLIENT.table("projects").select("doc_map_id, last_uploaded_doc_id").eq("course_name", course_name).execute() - if response.data: - if response.data[0]['doc_map_id']: - project_id = response.data[0]['doc_map_id'] - last_uploaded_doc_id = response.data[0]['last_uploaded_doc_id'] - else: - # entry present in supabase, but doc map not present - create_document_map(course_name) - return "Document map not present, triggering map creation." - - else: - # create a map - create_document_map(course_name) - return "Document map not present, triggering map creation." +# except Exception as e: +# print(e) +# sentry_sdk.capture_exception(e) +# return "failed" + +# def delete_from_document_map(course_name: str, ids: list): +# """ +# This function is used to delete datapoints from a document map. +# Currently used within the delete_data() function in vector_database.py +# Args: +# course_name: str +# ids: list of str +# """ +# print("in delete_from_document_map()") + +# try: +# # check if project exists +# response = SUPABASE_CLIENT.table("projects").select("doc_map_id").eq("course_name", course_name).execute() +# if response.data: +# project_id = response.data[0]['doc_map_id'] +# else: +# return "No document map found for this course" + +# # fetch project from Nomic +# project = AtlasProject(project_id=project_id, add_datums_if_exists=True) + +# # delete the ids from Nomic +# print("Deleting point from document map:", project.delete_data(ids)) +# with project.wait_for_project_lock(): +# project.rebuild_maps() +# return "success" +# except Exception as e: +# print(e) +# sentry_sdk.capture_exception(e) +# return "Error in deleting from document map: {e}" + + +# def log_to_document_map(course_name: str): +# """ +# This is a function which appends new documents to an existing document map. It's called +# at the end of split_and_upload() after inserting data to Supabase. +# Args: +# data: dict - the response data from Supabase insertion +# """ +# print("in add_to_document_map()") + +# try: +# # check if map exists +# response = SUPABASE_CLIENT.table("projects").select("doc_map_id, last_uploaded_doc_id").eq("course_name", course_name).execute() +# if response.data: +# if response.data[0]['doc_map_id']: +# project_id = response.data[0]['doc_map_id'] +# last_uploaded_doc_id = response.data[0]['last_uploaded_doc_id'] +# else: +# # entry present in supabase, but doc map not present +# create_document_map(course_name) +# return "Document map not present, triggering map creation." + +# else: +# # create a map +# create_document_map(course_name) +# return "Document map not present, triggering map creation." - project = AtlasProject(project_id=project_id, add_datums_if_exists=True) - project_name = "Document Map for " + course_name +# project = AtlasProject(project_id=project_id, add_datums_if_exists=True) +# project_name = "Document Map for " + course_name - # check if project is LOCKED, if yes -> skip logging - if not project.is_accepting_data: - return "Skipping Nomic logging because project is locked." +# # check if project is LOCKED, if yes -> skip logging +# if not project.is_accepting_data: +# return "Skipping Nomic logging because project is locked." 
- # fetch count of records greater than last_uploaded_doc_id - print("last uploaded doc id: ", last_uploaded_doc_id) - response = SUPABASE_CLIENT.table("documents").select("id", count="exact").eq("course_name", course_name).gt("id", last_uploaded_doc_id).execute() - print("Number of new documents: ", response.count) - - total_doc_count = response.count - current_doc_count = 0 - combined_dfs = [] - doc_count = 0 - first_id = last_uploaded_doc_id - while current_doc_count < total_doc_count: - # fetch all records from supabase greater than last_uploaded_doc_id - response = SUPABASE_CLIENT.table("documents").select("id, created_at, s3_path, url, base_url, readable_filename, contexts").eq("course_name", course_name).gt("id", first_id).limit(25).execute() - df = pd.DataFrame(response.data) - combined_dfs.append(df) # list of dfs - - current_doc_count += len(response.data) - doc_count += len(response.data) - - if doc_count >= 1000: # upload to Nomic in batches of 1000 - # concat all dfs from the combined_dfs list - final_df = pd.concat(combined_dfs, ignore_index=True) - # prep data for nomic upload - embeddings, metadata = data_prep_for_doc_map(final_df) - - # append to existing map - print("Appending data to existing map...") +# # fetch count of records greater than last_uploaded_doc_id +# print("last uploaded doc id: ", last_uploaded_doc_id) +# response = SUPABASE_CLIENT.table("documents").select("id", count="exact").eq("course_name", course_name).gt("id", last_uploaded_doc_id).execute() +# print("Number of new documents: ", response.count) + +# total_doc_count = response.count +# current_doc_count = 0 +# combined_dfs = [] +# doc_count = 0 +# first_id = last_uploaded_doc_id +# while current_doc_count < total_doc_count: +# # fetch all records from supabase greater than last_uploaded_doc_id +# response = SUPABASE_CLIENT.table("documents").select("id, created_at, s3_path, url, base_url, readable_filename, contexts").eq("course_name", course_name).gt("id", first_id).limit(25).execute() +# df = pd.DataFrame(response.data) +# combined_dfs.append(df) # list of dfs + +# current_doc_count += len(response.data) +# doc_count += len(response.data) + +# if doc_count >= 1000: # upload to Nomic in batches of 1000 +# # concat all dfs from the combined_dfs list +# final_df = pd.concat(combined_dfs, ignore_index=True) +# # prep data for nomic upload +# embeddings, metadata = data_prep_for_doc_map(final_df) + +# # append to existing map +# print("Appending data to existing map...") - result = append_to_map(embeddings, metadata, project_name) - if result == "success": - # update the last uploaded id in supabase - last_id = int(final_df['id'].iloc[-1]) - info = {'last_uploaded_doc_id': last_id} - update_response = SUPABASE_CLIENT.table("projects").update(info).eq("course_name", course_name).execute() - print("Response from supabase: ", update_response) +# result = append_to_map(embeddings, metadata, project_name) +# if result == "success": +# # update the last uploaded id in supabase +# last_id = int(final_df['id'].iloc[-1]) +# info = {'last_uploaded_doc_id': last_id} +# update_response = SUPABASE_CLIENT.table("projects").update(info).eq("course_name", course_name).execute() +# print("Response from supabase: ", update_response) - # reset variables - combined_dfs = [] - doc_count = 0 - print("Records uploaded: ", current_doc_count) +# # reset variables +# combined_dfs = [] +# doc_count = 0 +# print("Records uploaded: ", current_doc_count) - # set first_id for next iteration - first_id = response.data[-1]['id'] + 1 
+# # set first_id for next iteration +# first_id = response.data[-1]['id'] + 1 - # upload last set of docs - if doc_count > 0: - final_df = pd.concat(combined_dfs, ignore_index=True) - embeddings, metadata = data_prep_for_doc_map(final_df) - result = append_to_map(embeddings, metadata, project_name) - - # update the last uploaded id in supabase - if result == "success": - # update the last uploaded id in supabase - last_id = int(final_df['id'].iloc[-1]) - project_info = {'last_uploaded_doc_id': last_id} - update_response = SUPABASE_CLIENT.table("projects").update(project_info).eq("course_name", course_name).execute() - print("Response from supabase: ", update_response) +# # upload last set of docs +# if doc_count > 0: +# final_df = pd.concat(combined_dfs, ignore_index=True) +# embeddings, metadata = data_prep_for_doc_map(final_df) +# result = append_to_map(embeddings, metadata, project_name) + +# # update the last uploaded id in supabase +# if result == "success": +# # update the last uploaded id in supabase +# last_id = int(final_df['id'].iloc[-1]) +# project_info = {'last_uploaded_doc_id': last_id} +# update_response = SUPABASE_CLIENT.table("projects").update(project_info).eq("course_name", course_name).execute() +# print("Response from supabase: ", update_response) - return "success" - except Exception as e: - print(e) - return "failed" +# return "success" +# except Exception as e: +# print(e) +# return "failed" -def create_map(embeddings, metadata, map_name, index_name, topic_label_field, colorable_fields): - """ - Generic function to create a Nomic map from given parameters. - Args: - embeddings: np.array of embeddings - metadata: pd.DataFrame of metadata - map_name: str - index_name: str - topic_label_field: str - colorable_fields: list of str - """ - nomic.login(os.getenv('NOMIC_API_KEY')) - try: - project = atlas.map_embeddings(embeddings=embeddings, - data=metadata, - id_field="id", - build_topic_model=True, - topic_label_field=topic_label_field, - name=map_name, - colorable_fields=colorable_fields, - add_datums_if_exists=True) - project.create_index(name=index_name, build_topic_model=True) - return "success" - except Exception as e: - print(e) - return "Error in creating map: {e}" - -def append_to_map(embeddings, metadata, map_name): - """ - Generic function to append new data to an existing Nomic map. - Args: - embeddings: np.array of embeddings - metadata: pd.DataFrame of Nomic upload metadata - map_name: str - """ - - nomic.login(os.getenv('NOMIC_API_KEY')) - try: - project = atlas.AtlasProject(name=map_name, add_datums_if_exists=True) - with project.wait_for_project_lock(): - project.add_embeddings(embeddings=embeddings, data=metadata) - return "success" - except Exception as e: - print(e) - return "Error in appending to map: {e}" - -def data_prep_for_doc_map(df: pd.DataFrame): - """ - This function prepares embeddings and metadata for nomic upload in document map creation. 
- Args: - df: pd.DataFrame - the dataframe of documents from Supabase - Returns: - embeddings: np.array of embeddings - metadata: pd.DataFrame of metadata - """ - print("in data_prep_for_doc_map()") - - metadata = [] - embeddings = [] - - texts = [] - - for index, row in df.iterrows(): - current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") - created_at = datetime.datetime.strptime(row['created_at'], "%Y-%m-%dT%H:%M:%S.%f%z").strftime("%Y-%m-%d %H:%M:%S") - if row['url'] == None: - row['url'] = "" - if row['base_url'] == None: - row['base_url'] = "" - # iterate through all contexts and create separate entries for each - context_count = 0 - for context in row['contexts']: - context_count += 1 - text_row = context['text'] - embeddings_row = context['embedding'] - - meta_row = { - "id": str(row['id']) + "_" + str(context_count), - "created_at": created_at, - "s3_path": row['s3_path'], - "url": row['url'], - "base_url": row['base_url'], - "readable_filename": row['readable_filename'], - "modified_at": current_time, - "text": text_row - } - - embeddings.append(embeddings_row) - metadata.append(meta_row) - texts.append(text_row) - - embeddings_np = np.array(embeddings, dtype=object) - print("Shape of embeddings: ", embeddings_np.shape) - - # check dimension if embeddings_np is (n, 1536) - if len(embeddings_np.shape) < 2: - print("Creating new embeddings...") - - embeddings_model = OpenAIEmbeddings(openai_api_type="openai", - openai_api_base="https://api.openai.com/v1/", - openai_api_key=os.getenv('VLADS_OPENAI_KEY')) # type: ignore - embeddings = embeddings_model.embed_documents(texts) - - metadata = pd.DataFrame(metadata) - embeddings = np.array(embeddings) - - return embeddings, metadata - -def rebuild_map(course_name:str, map_type:str): - """ - This function rebuilds a given map in Nomic. - """ - print("in rebuild_map()") - nomic.login(os.getenv('NOMIC_API_KEY')) - if map_type.lower() == 'document': - NOMIC_MAP_NAME_PREFIX = 'Document Map for ' - else: - NOMIC_MAP_NAME_PREFIX = 'Conversation Map for ' - - try: - # fetch project from Nomic - project_name = NOMIC_MAP_NAME_PREFIX + course_name - project = AtlasProject(name=project_name, add_datums_if_exists=True) - - if project.is_accepting_data: # temporary fix - will skip rebuilding if project is locked - project.rebuild_maps() - return "success" - except Exception as e: - print(e) - sentry_sdk.capture_exception(e) - return "Error in rebuilding map: {e}" - - - -if __name__ == '__main__': - pass +# def create_map(embeddings, metadata, map_name, index_name, topic_label_field, colorable_fields): +# """ +# Generic function to create a Nomic map from given parameters. +# Args: +# embeddings: np.array of embeddings +# metadata: pd.DataFrame of metadata +# map_name: str +# index_name: str +# topic_label_field: str +# colorable_fields: list of str +# """ +# nomic.login(os.getenv('NOMIC_API_KEY')) +# try: +# project = atlas.map_embeddings(embeddings=embeddings, +# data=metadata, +# id_field="id", +# build_topic_model=True, +# topic_label_field=topic_label_field, +# name=map_name, +# colorable_fields=colorable_fields, +# add_datums_if_exists=True) +# project.create_index(name=index_name, build_topic_model=True) +# return "success" +# except Exception as e: +# print(e) +# return "Error in creating map: {e}" + +# def append_to_map(embeddings, metadata, map_name): +# """ +# Generic function to append new data to an existing Nomic map. 
+# Args: +# embeddings: np.array of embeddings +# metadata: pd.DataFrame of Nomic upload metadata +# map_name: str +# """ + +# nomic.login(os.getenv('NOMIC_API_KEY')) +# try: +# project = atlas.AtlasProject(name=map_name, add_datums_if_exists=True) +# with project.wait_for_project_lock(): +# project.add_embeddings(embeddings=embeddings, data=metadata) +# return "success" +# except Exception as e: +# print(e) +# return "Error in appending to map: {e}" + +# def data_prep_for_doc_map(df: pd.DataFrame): +# """ +# This function prepares embeddings and metadata for nomic upload in document map creation. +# Args: +# df: pd.DataFrame - the dataframe of documents from Supabase +# Returns: +# embeddings: np.array of embeddings +# metadata: pd.DataFrame of metadata +# """ +# print("in data_prep_for_doc_map()") + +# metadata = [] +# embeddings = [] + +# texts = [] + +# for index, row in df.iterrows(): +# current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") +# created_at = datetime.datetime.strptime(row['created_at'], "%Y-%m-%dT%H:%M:%S.%f%z").strftime("%Y-%m-%d %H:%M:%S") +# if row['url'] == None: +# row['url'] = "" +# if row['base_url'] == None: +# row['base_url'] = "" +# # iterate through all contexts and create separate entries for each +# context_count = 0 +# for context in row['contexts']: +# context_count += 1 +# text_row = context['text'] +# embeddings_row = context['embedding'] + +# meta_row = { +# "id": str(row['id']) + "_" + str(context_count), +# "created_at": created_at, +# "s3_path": row['s3_path'], +# "url": row['url'], +# "base_url": row['base_url'], +# "readable_filename": row['readable_filename'], +# "modified_at": current_time, +# "text": text_row +# } + +# embeddings.append(embeddings_row) +# metadata.append(meta_row) +# texts.append(text_row) + +# embeddings_np = np.array(embeddings, dtype=object) +# print("Shape of embeddings: ", embeddings_np.shape) + +# # check dimension if embeddings_np is (n, 1536) +# if len(embeddings_np.shape) < 2: +# print("Creating new embeddings...") + +# embeddings_model = OpenAIEmbeddings(openai_api_type="openai", +# openai_api_base="https://api.openai.com/v1/", +# openai_api_key=os.getenv('VLADS_OPENAI_KEY')) # type: ignore +# embeddings = embeddings_model.embed_documents(texts) + +# metadata = pd.DataFrame(metadata) +# embeddings = np.array(embeddings) + +# return embeddings, metadata + +# def rebuild_map(course_name:str, map_type:str): +# """ +# This function rebuilds a given map in Nomic. 
+#   """
+#   print("in rebuild_map()")
+#   nomic.login(os.getenv('NOMIC_API_KEY'))
+#   if map_type.lower() == 'document':
+#     NOMIC_MAP_NAME_PREFIX = 'Document Map for '
+#   else:
+#     NOMIC_MAP_NAME_PREFIX = 'Conversation Map for '
+
+#   try:
+#     # fetch project from Nomic
+#     project_name = NOMIC_MAP_NAME_PREFIX + course_name
+#     project = AtlasProject(name=project_name, add_datums_if_exists=True)
+
+#     if project.is_accepting_data: # temporary fix - will skip rebuilding if project is locked
+#       project.rebuild_maps()
+#     return "success"
+#   except Exception as e:
+#     print(e)
+#     sentry_sdk.capture_exception(e)
+#     return "Error in rebuilding map: {e}"
+
+
+
+# if __name__ == '__main__':
+#   pass
diff --git a/ai_ta_backend/main.py b/ai_ta_backend/main.py
index a11e9c91..79ced2b7 100644
--- a/ai_ta_backend/main.py
+++ b/ai_ta_backend/main.py
@@ -19,7 +19,7 @@
 from flask_injector import FlaskInjector, RequestScope
 from injector import Binder, SingletonScope
 
-from ai_ta_backend.beam.nomic_logging import create_document_map
+#from ai_ta_backend.beam.nomic_logging import create_document_map
 from ai_ta_backend.database.aws import AWSStorage
 from ai_ta_backend.database.sql import SQLDatabase
 from ai_ta_backend.database.vector import VectorDatabase
@@ -36,13 +36,16 @@
     ThreadPoolExecutorInterface,
 )
 from ai_ta_backend.service.export_service import ExportService
-from ai_ta_backend.service.nomic_service import NomicService
+#from ai_ta_backend.service.nomic_service import NomicService
 from ai_ta_backend.service.posthog_service import PosthogService
 from ai_ta_backend.service.project_service import ProjectService
 from ai_ta_backend.service.retrieval_service import RetrievalService
 from ai_ta_backend.service.sentry_service import SentryService
 from ai_ta_backend.service.workflow_service import WorkflowService
 
+from ai_ta_backend.utils.pubmed_extraction import extractPubmedData
+
+
 app = Flask(__name__)
 CORS(app)
 executor = Executor(app)
@@ -176,21 +179,21 @@ def delete(service: RetrievalService, flaskExecutor: ExecutorInterface):
   return response
 
 
-@app.route('/getNomicMap', methods=['GET'])
-def nomic_map(service: NomicService):
-  course_name: str = request.args.get('course_name', default='', type=str)
-  map_type: str = request.args.get('map_type', default='conversation', type=str)
+# @app.route('/getNomicMap', methods=['GET'])
+# def nomic_map(service: NomicService):
+#   course_name: str = request.args.get('course_name', default='', type=str)
+#   map_type: str = request.args.get('map_type', default='conversation', type=str)
 
-  if course_name == '':
-    # proper web error "400 Bad request"
-    abort(400, description=f"Missing required parameter: 'course_name' must be provided. Course name: `{course_name}`")
+#   if course_name == '':
+#     # proper web error "400 Bad request"
+#     abort(400, description=f"Missing required parameter: 'course_name' must be provided. Course name: `{course_name}`")
 
-  map_id = service.get_nomic_map(course_name, map_type)
-  print("nomic map\n", map_id)
+#   map_id = service.get_nomic_map(course_name, map_type)
+#   print("nomic map\n", map_id)
 
-  response = jsonify(map_id)
-  response.headers.add('Access-Control-Allow-Origin', '*')
-  return response
+#   response = jsonify(map_id)
+#   response.headers.add('Access-Control-Allow-Origin', '*')
+#   return response
 
 
 # @app.route('/createDocumentMap', methods=['GET'])
@@ -237,27 +240,27 @@ def nomic_map(service: NomicService):
 #   return response
 
 
-@app.route('/onResponseCompletion', methods=['POST'])
-def logToNomic(service: NomicService, flaskExecutor: ExecutorInterface):
-  data = request.get_json()
-  course_name = data['course_name']
-  conversation = data['conversation']
+# @app.route('/onResponseCompletion', methods=['POST'])
+# def logToNomic(service: NomicService, flaskExecutor: ExecutorInterface):
+#   data = request.get_json()
+#   course_name = data['course_name']
+#   conversation = data['conversation']
 
-  if course_name == '' or conversation == '':
-    # proper web error "400 Bad request"
-    abort(
-        400,
-        description=
-        f"Missing one or more required parameters: 'course_name' and 'conversation' must be provided. Course name: `{course_name}`, Conversation: `{conversation}`"
-    )
-  print(f"In /onResponseCompletion for course: {course_name}")
-
-  # background execution of tasks!!
-  #response = flaskExecutor.submit(service.log_convo_to_nomic, course_name, data)
-  #result = flaskExecutor.submit(service.log_to_conversation_map, course_name, conversation).result()
-  response = jsonify({'outcome': 'success'})
-  response.headers.add('Access-Control-Allow-Origin', '*')
-  return response
+#   if course_name == '' or conversation == '':
+#     # proper web error "400 Bad request"
+#     abort(
+#         400,
+#         description=
+#         f"Missing one or more required parameters: 'course_name' and 'conversation' must be provided. Course name: `{course_name}`, Conversation: `{conversation}`"
+#     )
+#   print(f"In /onResponseCompletion for course: {course_name}")
+
+#   # background execution of tasks!!
+#   #response = flaskExecutor.submit(service.log_convo_to_nomic, course_name, data)
+#   #result = flaskExecutor.submit(service.log_to_conversation_map, course_name, conversation).result()
+#   response = jsonify({'outcome': 'success'})
+#   response.headers.add('Access-Control-Allow-Origin', '*')
+#   return response
 
 
 @app.route('/export-convo-history-csv', methods=['GET'])
@@ -602,6 +605,17 @@ def createProject(service: ProjectService, flaskExecutor: ExecutorInterface) ->
   response.headers.add('Access-Control-Allow-Origin', '*')
   return response
 
+@app.route('/pubmedExtraction', methods=['GET'])
+def pubmedExtraction():
+  """
+  Extracts metadata and download papers from PubMed.
+  """
+  result = extractPubmedData()
+
+  response = jsonify(result)
+  response.headers.add('Access-Control-Allow-Origin', '*')
+  return response
+
 
 def configure(binder: Binder) -> None:
   binder.bind(ThreadPoolExecutorInterface, to=ThreadPoolExecutorAdapter(max_workers=10), scope=SingletonScope)
@@ -609,7 +623,7 @@ def configure(binder: Binder) -> None:
   binder.bind(RetrievalService, to=RetrievalService, scope=RequestScope)
   binder.bind(PosthogService, to=PosthogService, scope=SingletonScope)
   binder.bind(SentryService, to=SentryService, scope=SingletonScope)
-  binder.bind(NomicService, to=NomicService, scope=SingletonScope)
+  #binder.bind(NomicService, to=NomicService, scope=SingletonScope)
   binder.bind(ExportService, to=ExportService, scope=SingletonScope)
   binder.bind(WorkflowService, to=WorkflowService, scope=SingletonScope)
   binder.bind(VectorDatabase, to=VectorDatabase, scope=SingletonScope)
diff --git a/ai_ta_backend/service/nomic_service.py b/ai_ta_backend/service/nomic_service.py
index 80ca86ca..57e29d4d 100644
--- a/ai_ta_backend/service/nomic_service.py
+++ b/ai_ta_backend/service/nomic_service.py
@@ -1,562 +1,562 @@
-import datetime
-import os
-import time
-from typing import Union
-
-import backoff
-import nomic
-import numpy as np
-import pandas as pd
-from injector import inject
-from langchain.embeddings.openai import OpenAIEmbeddings
-from nomic import AtlasProject, atlas
-
-from ai_ta_backend.database.sql import SQLDatabase
-from ai_ta_backend.service.sentry_service import SentryService
-
-LOCK_EXCEPTIONS = [
-    'Project is locked for state access! Please wait until the project is unlocked to access embeddings.',
-    'Project is locked for state access! Please wait until the project is unlocked to access data.',
-    'Project is currently indexing and cannot ingest new datums. Try again later.'
-]
-
-
-class NomicService():
-
-  @inject
-  def __init__(self, sentry: SentryService, sql: SQLDatabase):
-    nomic.login(os.environ['NOMIC_API_KEY'])
-    self.sentry = sentry
-    self.sql = sql
-
-  def get_nomic_map(self, course_name: str, type: str):
-    """
-    Returns the variables necessary to construct an iframe of the Nomic map given a course name.
-    We just need the ID and URL.
-    Example values:
-      map link: https://atlas.nomic.ai/map/ed222613-97d9-46a9-8755-12bbc8a06e3a/f4967ad7-ff37-4098-ad06-7e1e1a93dd93
-      map id: f4967ad7-ff37-4098-ad06-7e1e1a93dd93
-    """
-    # nomic.login(os.getenv('NOMIC_API_KEY')) # login during start of flask app
-    if type.lower() == 'document':
-      NOMIC_MAP_NAME_PREFIX = 'Document Map for '
-    else:
-      NOMIC_MAP_NAME_PREFIX = 'Conversation Map for '
-
-    project_name = NOMIC_MAP_NAME_PREFIX + course_name
-    start_time = time.monotonic()
-
-    try:
-      project = atlas.AtlasProject(name=project_name, add_datums_if_exists=True)
-      map = project.get_map(project_name)
-
-      print(f"⏰ Nomic Full Map Retrieval: {(time.monotonic() - start_time):.2f} seconds")
-      return {"map_id": f"iframe{map.id}", "map_link": map.map_link}
-    except Exception as e:
-      # Error: ValueError: You must specify a unique_id_field when creating a new project.
- if str(e) == 'You must specify a unique_id_field when creating a new project.': # type: ignore - print( - "Nomic map does not exist yet, probably because you have less than 20 queries/documents on your project: ", - e) - else: - print("ERROR in get_nomic_map():", e) - self.sentry.capture_exception(e) - return {"map_id": None, "map_link": None} - - - def log_to_conversation_map(self, course_name: str, conversation): - """ - This function logs new conversations to existing nomic maps. - 1. Check if nomic map exists - 2. If no, create it - 3. If yes, fetch all conversations since last upload and log it - """ - nomic.login(os.getenv('NOMIC_API_KEY')) - NOMIC_MAP_NAME_PREFIX = 'Conversation Map for ' - try: - # check if map exists - response = self.sql.getConvoMapFromProjects(course_name) - print("Response from supabase: ", response.data) - - # entry not present in projects table - if not response.data: - print("Map does not exist for this course. Redirecting to map creation...") - return self.create_conversation_map(course_name) +# import datetime +# import os +# import time +# from typing import Union + +# import backoff +# import nomic +# import numpy as np +# import pandas as pd +# from injector import inject +# from langchain.embeddings.openai import OpenAIEmbeddings +# from nomic import AtlasProject, atlas + +# from ai_ta_backend.database.sql import SQLDatabase +# from ai_ta_backend.service.sentry_service import SentryService + +# LOCK_EXCEPTIONS = [ +# 'Project is locked for state access! Please wait until the project is unlocked to access embeddings.', +# 'Project is locked for state access! Please wait until the project is unlocked to access data.', +# 'Project is currently indexing and cannot ingest new datums. Try again later.' +# ] + + +# class NomicService(): + +# @inject +# def __init__(self, sentry: SentryService, sql: SQLDatabase): +# nomic.login(os.environ['NOMIC_API_KEY']) +# self.sentry = sentry +# self.sql = sql + +# def get_nomic_map(self, course_name: str, type: str): +# """ +# Returns the variables necessary to construct an iframe of the Nomic map given a course name. +# We just need the ID and URL. +# Example values: +# map link: https://atlas.nomic.ai/map/ed222613-97d9-46a9-8755-12bbc8a06e3a/f4967ad7-ff37-4098-ad06-7e1e1a93dd93 +# map id: f4967ad7-ff37-4098-ad06-7e1e1a93dd93 +# """ +# # nomic.login(os.getenv('NOMIC_API_KEY')) # login during start of flask app +# if type.lower() == 'document': +# NOMIC_MAP_NAME_PREFIX = 'Document Map for ' +# else: +# NOMIC_MAP_NAME_PREFIX = 'Conversation Map for ' + +# project_name = NOMIC_MAP_NAME_PREFIX + course_name +# start_time = time.monotonic() + +# try: +# project = atlas.AtlasProject(name=project_name, add_datums_if_exists=True) +# map = project.get_map(project_name) + +# print(f"⏰ Nomic Full Map Retrieval: {(time.monotonic() - start_time):.2f} seconds") +# return {"map_id": f"iframe{map.id}", "map_link": map.map_link} +# except Exception as e: +# # Error: ValueError: You must specify a unique_id_field when creating a new project. +# if str(e) == 'You must specify a unique_id_field when creating a new project.': # type: ignore +# print( +# "Nomic map does not exist yet, probably because you have less than 20 queries/documents on your project: ", +# e) +# else: +# print("ERROR in get_nomic_map():", e) +# self.sentry.capture_exception(e) +# return {"map_id": None, "map_link": None} + + +# def log_to_conversation_map(self, course_name: str, conversation): +# """ +# This function logs new conversations to existing nomic maps. 
+# 1. Check if nomic map exists +# 2. If no, create it +# 3. If yes, fetch all conversations since last upload and log it +# """ +# nomic.login(os.getenv('NOMIC_API_KEY')) +# NOMIC_MAP_NAME_PREFIX = 'Conversation Map for ' +# try: +# # check if map exists +# response = self.sql.getConvoMapFromProjects(course_name) +# print("Response from supabase: ", response.data) + +# # entry not present in projects table +# if not response.data: +# print("Map does not exist for this course. Redirecting to map creation...") +# return self.create_conversation_map(course_name) - # entry present for doc map, but not convo map - elif not response.data[0]['convo_map_id']: - print("Map does not exist for this course. Redirecting to map creation...") - return self.create_conversation_map(course_name) +# # entry present for doc map, but not convo map +# elif not response.data[0]['convo_map_id']: +# print("Map does not exist for this course. Redirecting to map creation...") +# return self.create_conversation_map(course_name) - project_id = response.data[0]['convo_map_id'] - last_uploaded_convo_id = response.data[0]['last_uploaded_convo_id'] - - # check if project is accepting data - project = AtlasProject(project_id=project_id, add_datums_if_exists=True) - if not project.is_accepting_data: - return "Project is currently indexing and cannot ingest new datums. Try again later." - - # fetch count of conversations since last upload - response = self.sql.getCountFromLLMConvoMonitor(course_name, last_id=last_uploaded_convo_id) - total_convo_count = response.count - print("Total number of unlogged conversations in Supabase: ", total_convo_count) - - if total_convo_count == 0: - # log to an existing conversation - existing_convo = self.log_to_existing_conversation(course_name, conversation) - return existing_convo - - first_id = last_uploaded_convo_id - combined_dfs = [] - current_convo_count = 0 - convo_count = 0 - - while current_convo_count < total_convo_count: - response = self.sql.getAllConversationsBetweenIds(course_name, first_id, 0, 100) - print("Response count: ", len(response.data)) - if len(response.data) == 0: - break - df = pd.DataFrame(response.data) - combined_dfs.append(df) - current_convo_count += len(response.data) - convo_count += len(response.data) - print(current_convo_count) - - if convo_count >= 500: - # concat all dfs from the combined_dfs list - final_df = pd.concat(combined_dfs, ignore_index=True) - # prep data for nomic upload - embeddings, metadata = self.data_prep_for_convo_map(final_df) - # append to existing map - print("Appending data to existing map...") - result = self.append_to_map(embeddings, metadata, NOMIC_MAP_NAME_PREFIX + course_name) - if result == "success": - last_id = int(final_df['id'].iloc[-1]) - project_info = {'course_name': course_name, 'convo_map_id': project_id, 'last_uploaded_convo_id': last_id} - project_response = self.sql.updateProjects(course_name, project_info) - print("Update response from supabase: ", project_response) - # reset variables - combined_dfs = [] - convo_count = 0 - print("Records uploaded: ", current_convo_count) - - # set first_id for next iteration - first_id = response.data[-1]['id'] + 1 - - # upload last set of convos - if convo_count > 0: - print("Uploading last set of conversations...") - final_df = pd.concat(combined_dfs, ignore_index=True) - embeddings, metadata = self.data_prep_for_convo_map(final_df) - result = self.append_to_map(embeddings, metadata, NOMIC_MAP_NAME_PREFIX + course_name) - if result == "success": - last_id = 
int(final_df['id'].iloc[-1]) - project_info = {'course_name': course_name, 'convo_map_id': project_id, 'last_uploaded_convo_id': last_id} - project_response = self.sql.updateProjects(course_name, project_info) - print("Update response from supabase: ", project_response) +# project_id = response.data[0]['convo_map_id'] +# last_uploaded_convo_id = response.data[0]['last_uploaded_convo_id'] + +# # check if project is accepting data +# project = AtlasProject(project_id=project_id, add_datums_if_exists=True) +# if not project.is_accepting_data: +# return "Project is currently indexing and cannot ingest new datums. Try again later." + +# # fetch count of conversations since last upload +# response = self.sql.getCountFromLLMConvoMonitor(course_name, last_id=last_uploaded_convo_id) +# total_convo_count = response.count +# print("Total number of unlogged conversations in Supabase: ", total_convo_count) + +# if total_convo_count == 0: +# # log to an existing conversation +# existing_convo = self.log_to_existing_conversation(course_name, conversation) +# return existing_convo + +# first_id = last_uploaded_convo_id +# combined_dfs = [] +# current_convo_count = 0 +# convo_count = 0 + +# while current_convo_count < total_convo_count: +# response = self.sql.getAllConversationsBetweenIds(course_name, first_id, 0, 100) +# print("Response count: ", len(response.data)) +# if len(response.data) == 0: +# break +# df = pd.DataFrame(response.data) +# combined_dfs.append(df) +# current_convo_count += len(response.data) +# convo_count += len(response.data) +# print(current_convo_count) + +# if convo_count >= 500: +# # concat all dfs from the combined_dfs list +# final_df = pd.concat(combined_dfs, ignore_index=True) +# # prep data for nomic upload +# embeddings, metadata = self.data_prep_for_convo_map(final_df) +# # append to existing map +# print("Appending data to existing map...") +# result = self.append_to_map(embeddings, metadata, NOMIC_MAP_NAME_PREFIX + course_name) +# if result == "success": +# last_id = int(final_df['id'].iloc[-1]) +# project_info = {'course_name': course_name, 'convo_map_id': project_id, 'last_uploaded_convo_id': last_id} +# project_response = self.sql.updateProjects(course_name, project_info) +# print("Update response from supabase: ", project_response) +# # reset variables +# combined_dfs = [] +# convo_count = 0 +# print("Records uploaded: ", current_convo_count) + +# # set first_id for next iteration +# first_id = response.data[-1]['id'] + 1 + +# # upload last set of convos +# if convo_count > 0: +# print("Uploading last set of conversations...") +# final_df = pd.concat(combined_dfs, ignore_index=True) +# embeddings, metadata = self.data_prep_for_convo_map(final_df) +# result = self.append_to_map(embeddings, metadata, NOMIC_MAP_NAME_PREFIX + course_name) +# if result == "success": +# last_id = int(final_df['id'].iloc[-1]) +# project_info = {'course_name': course_name, 'convo_map_id': project_id, 'last_uploaded_convo_id': last_id} +# project_response = self.sql.updateProjects(course_name, project_info) +# print("Update response from supabase: ", project_response) - # rebuild the map - self.rebuild_map(course_name, "conversation") - return "success" +# # rebuild the map +# self.rebuild_map(course_name, "conversation") +# return "success" - except Exception as e: - print(e) - self.sentry.capture_exception(e) - return "Error in logging to conversation map: {e}" +# except Exception as e: +# print(e) +# self.sentry.capture_exception(e) +# return "Error in logging to conversation map: {e}" - 
def log_to_existing_conversation(self, course_name: str, conversation): - """ - This function logs follow-up questions to existing conversations in the map. - """ - print(f"in log_to_existing_conversation() for course: {course_name}") +# def log_to_existing_conversation(self, course_name: str, conversation): +# """ +# This function logs follow-up questions to existing conversations in the map. +# """ +# print(f"in log_to_existing_conversation() for course: {course_name}") - try: - conversation_id = conversation['id'] +# try: +# conversation_id = conversation['id'] - # fetch id from supabase - incoming_id_response = self.sql.getConversation(course_name, key="convo_id", value=conversation_id) +# # fetch id from supabase +# incoming_id_response = self.sql.getConversation(course_name, key="convo_id", value=conversation_id) - project_name = 'Conversation Map for ' + course_name - project = AtlasProject(name=project_name, add_datums_if_exists=True) +# project_name = 'Conversation Map for ' + course_name +# project = AtlasProject(name=project_name, add_datums_if_exists=True) - prev_id = incoming_id_response.data[0]['id'] - uploaded_data = project.get_data(ids=[prev_id]) # fetch data point from nomic - prev_convo = uploaded_data[0]['conversation'] +# prev_id = incoming_id_response.data[0]['id'] +# uploaded_data = project.get_data(ids=[prev_id]) # fetch data point from nomic +# prev_convo = uploaded_data[0]['conversation'] - # update conversation - messages = conversation['messages'] - messages_to_be_logged = messages[-2:] +# # update conversation +# messages = conversation['messages'] +# messages_to_be_logged = messages[-2:] - for message in messages_to_be_logged: - if message['role'] == 'user': - emoji = "🙋 " - else: - emoji = "🤖 " - - if isinstance(message['content'], list): - text = message['content'][0]['text'] - else: - text = message['content'] - - prev_convo += "\n>>> " + emoji + message['role'] + ": " + text + "\n" +# for message in messages_to_be_logged: +# if message['role'] == 'user': +# emoji = "🙋 " +# else: +# emoji = "🤖 " + +# if isinstance(message['content'], list): +# text = message['content'][0]['text'] +# else: +# text = message['content'] + +# prev_convo += "\n>>> " + emoji + message['role'] + ": " + text + "\n" - # create embeddings of first query - embeddings_model = OpenAIEmbeddings(openai_api_type="openai", - openai_api_base="https://api.openai.com/v1/", - openai_api_key=os.environ['VLADS_OPENAI_KEY'], - openai_api_version="2020-11-07") - embeddings = embeddings_model.embed_documents([uploaded_data[0]['first_query']]) +# # create embeddings of first query +# embeddings_model = OpenAIEmbeddings(openai_api_type="openai", +# openai_api_base="https://api.openai.com/v1/", +# openai_api_key=os.environ['VLADS_OPENAI_KEY'], +# openai_api_version="2020-11-07") +# embeddings = embeddings_model.embed_documents([uploaded_data[0]['first_query']]) - # modified timestamp - current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") - uploaded_data[0]['conversation'] = prev_convo - uploaded_data[0]['modified_at'] = current_time +# # modified timestamp +# current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") +# uploaded_data[0]['conversation'] = prev_convo +# uploaded_data[0]['modified_at'] = current_time - metadata = pd.DataFrame(uploaded_data) - embeddings = np.array(embeddings) +# metadata = pd.DataFrame(uploaded_data) +# embeddings = np.array(embeddings) - print("Metadata shape:", metadata.shape) - print("Embeddings shape:", embeddings.shape) +# print("Metadata 
shape:", metadata.shape) +# print("Embeddings shape:", embeddings.shape) - # deleting existing map - print("Deleting point from nomic:", project.delete_data([prev_id])) +# # deleting existing map +# print("Deleting point from nomic:", project.delete_data([prev_id])) - # re-build map to reflect deletion - project.rebuild_maps() +# # re-build map to reflect deletion +# project.rebuild_maps() - # re-insert updated conversation - result = self.append_to_map(embeddings, metadata, project_name) - print("Result of appending to existing map:", result) +# # re-insert updated conversation +# result = self.append_to_map(embeddings, metadata, project_name) +# print("Result of appending to existing map:", result) - return "success" - - except Exception as e: - print("Error in log_to_existing_conversation():", e) - self.sentry.capture_exception(e) - return "Error in logging to existing conversation: {e}" - - - def create_conversation_map(self, course_name: str): - """ - This function creates a conversation map for a given course from scratch. - """ - nomic.login(os.getenv('NOMIC_API_KEY')) - NOMIC_MAP_NAME_PREFIX = 'Conversation Map for ' - try: - # check if map exists - response = self.sql.getConvoMapFromProjects(course_name) - print("Response from supabase: ", response.data) - if response.data: - if response.data[0]['convo_map_id']: - return "Map already exists for this course." - - # if no, fetch total count of records - response = self.sql.getCountFromLLMConvoMonitor(course_name, last_id=0) - - # if <20, return message that map cannot be created - if not response.count: - return "No conversations found for this course." - elif response.count < 20: - return "Cannot create a map because there are less than 20 conversations in the course." - - # if >20, iteratively fetch records in batches of 100 - total_convo_count = response.count - print("Total number of conversations in Supabase: ", total_convo_count) - - first_id = response.data[0]['id'] - 1 - combined_dfs = [] - current_convo_count = 0 - convo_count = 0 - first_batch = True - project_name = NOMIC_MAP_NAME_PREFIX + course_name - - # iteratively query in batches of 50 - while current_convo_count < total_convo_count: - response = self.sql.getAllConversationsBetweenIds(course_name, first_id, 0, 100) - print("Response count: ", len(response.data)) - if len(response.data) == 0: - break - df = pd.DataFrame(response.data) - combined_dfs.append(df) - current_convo_count += len(response.data) - convo_count += len(response.data) - print(current_convo_count) - - if convo_count >= 500: - # concat all dfs from the combined_dfs list - final_df = pd.concat(combined_dfs, ignore_index=True) - # prep data for nomic upload - embeddings, metadata = self.data_prep_for_convo_map(final_df) - - if first_batch: - # create a new map - print("Creating new map...") - index_name = course_name + "_convo_index" - topic_label_field = "first_query" - colorable_fields = ["user_email", "first_query", "conversation_id", "created_at"] - result = self.create_map(embeddings, metadata, project_name, index_name, topic_label_field, - colorable_fields) - - if result == "success": - # update flag - first_batch = False - # log project info to supabase - project = AtlasProject(name=project_name, add_datums_if_exists=True) - project_id = project.id - last_id = int(final_df['id'].iloc[-1]) - project_info = {'course_name': course_name, 'convo_map_id': project_id, 'last_uploaded_convo_id': last_id} - # if entry already exists, update it - projects_record = 
self.sql.getConvoMapFromProjects(course_name) - if projects_record.data: - project_response = self.sql.updateProjects(course_name, project_info) - else: - project_response = self.sql.insertProjectInfo(project_info) - print("Update response from supabase: ", project_response) - else: - # append to existing map - print("Appending data to existing map...") - project = AtlasProject(name=project_name, add_datums_if_exists=True) - result = self.append_to_map(embeddings, metadata, project_name) - if result == "success": - print("map append successful") - last_id = int(final_df['id'].iloc[-1]) - project_info = {'last_uploaded_convo_id': last_id} - project_response = self.sql.updateProjects(course_name, project_info) - print("Update response from supabase: ", project_response) - - # reset variables - combined_dfs = [] - convo_count = 0 - print("Records uploaded: ", current_convo_count) - - # set first_id for next iteration - try: - print("response: ", response.data[-1]['id']) - except: - print("response: ", response.data) - first_id = response.data[-1]['id'] + 1 - - print("Convo count: ", convo_count) - # upload last set of convos - if convo_count > 0: - print("Uploading last set of conversations...") - final_df = pd.concat(combined_dfs, ignore_index=True) - embeddings, metadata = self.data_prep_for_convo_map(final_df) - if first_batch: - # create map - index_name = course_name + "_convo_index" - topic_label_field = "first_query" - colorable_fields = ["user_email", "first_query", "conversation_id", "created_at"] - result = self.create_map(embeddings, metadata, project_name, index_name, topic_label_field, colorable_fields) - - else: - # append to map - print("in map append") - result = self.append_to_map(embeddings, metadata, project_name) - - if result == "success": - print("last map append successful") - last_id = int(final_df['id'].iloc[-1]) - project = AtlasProject(name=project_name, add_datums_if_exists=True) - project_id = project.id - project_info = {'course_name': course_name, 'convo_map_id': project_id, 'last_uploaded_convo_id': last_id} - print("Project info: ", project_info) - # if entry already exists, update it - projects_record = self.sql.getConvoMapFromProjects(course_name) - if projects_record.data: - project_response = self.sql.updateProjects(course_name, project_info) - else: - project_response = self.sql.insertProjectInfo(project_info) - print("Response from supabase: ", project_response) - - - # rebuild the map - self.rebuild_map(course_name, "conversation") - return "success" - except Exception as e: - print(e) - self.sentry.capture_exception(e) - return "Error in creating conversation map:" + str(e) - - ## -------------------------------- SUPPLEMENTARY MAP FUNCTIONS --------------------------------- ## - - def rebuild_map(self, course_name: str, map_type: str): - """ - This function rebuilds a given map in Nomic. 
- """ - print("in rebuild_map()") - nomic.login(os.getenv('NOMIC_API_KEY')) - - if map_type.lower() == 'document': - NOMIC_MAP_NAME_PREFIX = 'Document Map for ' - else: - NOMIC_MAP_NAME_PREFIX = 'Conversation Map for ' - - try: - # fetch project from Nomic - project_name = NOMIC_MAP_NAME_PREFIX + course_name - project = AtlasProject(name=project_name, add_datums_if_exists=True) - - if project.is_accepting_data: - project.rebuild_maps() - return "success" - except Exception as e: - print(e) - self.sentry.capture_exception(e) - return "Error in rebuilding map: {e}" - - def create_map(self, embeddings, metadata, map_name, index_name, topic_label_field, colorable_fields): - """ - Generic function to create a Nomic map from given parameters. - Args: - embeddings: np.array of embeddings - metadata: pd.DataFrame of metadata - map_name: str - index_name: str - topic_label_field: str - colorable_fields: list of str - """ - nomic.login(os.environ['NOMIC_API_KEY']) - print("in create_map()") - try: - project = atlas.map_embeddings(embeddings=embeddings, - data=metadata, - id_field="id", - build_topic_model=True, - name=map_name, - topic_label_field=topic_label_field, - colorable_fields=colorable_fields, - add_datums_if_exists=True) - project.create_index(index_name, build_topic_model=True) - return "success" - except Exception as e: - print(e) - return "Error in creating map: {e}" - - def append_to_map(self, embeddings, metadata, map_name): - """ - Generic function to append new data to an existing Nomic map. - Args: - embeddings: np.array of embeddings - metadata: pd.DataFrame of Nomic upload metadata - map_name: str - """ - nomic.login(os.environ['NOMIC_API_KEY']) - try: - project = atlas.AtlasProject(name=map_name, add_datums_if_exists=True) - with project.wait_for_project_lock(): - project.add_embeddings(embeddings=embeddings, data=metadata) - return "success" - except Exception as e: - print(e) - return "Error in appending to map: {e}" - - def data_prep_for_convo_map(self, df: pd.DataFrame): - """ - This function prepares embeddings and metadata for nomic upload in conversation map creation. 
- Args: - df: pd.DataFrame - the dataframe of documents from Supabase - Returns: - embeddings: np.array of embeddings - metadata: pd.DataFrame of metadata - """ - print("in data_prep_for_convo_map()") - - try: - metadata = [] - embeddings = [] - user_queries = [] - - for _index, row in df.iterrows(): - current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") - created_at = datetime.datetime.strptime(row['created_at'], "%Y-%m-%dT%H:%M:%S.%f%z").strftime("%Y-%m-%d %H:%M:%S") - conversation_exists = False - conversation = "" - emoji = "" - - if row['user_email'] is None: - user_email = "" - else: - user_email = row['user_email'] - - messages = row['convo']['messages'] - - # some conversations include images, so the data structure is different - if isinstance(messages[0]['content'], list): - if 'text' in messages[0]['content'][0]: - first_message = messages[0]['content'][0]['text'] - #print("First message:", first_message) - else: - first_message = messages[0]['content'] - user_queries.append(first_message) - - # construct metadata for multi-turn conversation - for message in messages: - if message['role'] == 'user': - emoji = "🙋 " - else: - emoji = "🤖 " - - if isinstance(message['content'], list): - - if 'text' in message['content'][0]: - text = message['content'][0]['text'] - else: - text = message['content'] - - conversation += "\n>>> " + emoji + message['role'] + ": " + text + "\n" - - meta_row = { - "course": row['course_name'], - "conversation": conversation, - "conversation_id": row['convo']['id'], - "id": row['id'], - "user_email": user_email, - "first_query": first_message, - "created_at": created_at, - "modified_at": current_time - } - #print("Metadata row:", meta_row) - metadata.append(meta_row) - - embeddings_model = OpenAIEmbeddings(openai_api_type="openai", - openai_api_base="https://api.openai.com/v1/", - openai_api_key=os.environ['VLADS_OPENAI_KEY'], - openai_api_version="2020-11-07") - embeddings = embeddings_model.embed_documents(user_queries) - - metadata = pd.DataFrame(metadata) - embeddings = np.array(embeddings) - print("Metadata shape:", metadata.shape) - print("Embeddings shape:", embeddings.shape) - return embeddings, metadata - - except Exception as e: - print("Error in data_prep_for_convo_map():", e) - self.sentry.capture_exception(e) - return None, None - - def delete_from_document_map(self, project_id: str, ids: list): - """ - This function is used to delete datapoints from a document map. - Currently used within the delete_data() function in vector_database.py - Args: - course_name: str - ids: list of str - """ - print("in delete_from_document_map()") - - try: - # fetch project from Nomic - project = AtlasProject(project_id=project_id, add_datums_if_exists=True) - - # delete the ids from Nomic - print("Deleting point from document map:", project.delete_data(ids)) - with project.wait_for_project_lock(): - project.rebuild_maps() - return "Successfully deleted from Nomic map" - except Exception as e: - print(e) - self.sentry.capture_exception(e) - return "Error in deleting from document map: {e}" +# return "success" + +# except Exception as e: +# print("Error in log_to_existing_conversation():", e) +# self.sentry.capture_exception(e) +# return "Error in logging to existing conversation: {e}" + + +# def create_conversation_map(self, course_name: str): +# """ +# This function creates a conversation map for a given course from scratch. 
+# """ +# nomic.login(os.getenv('NOMIC_API_KEY')) +# NOMIC_MAP_NAME_PREFIX = 'Conversation Map for ' +# try: +# # check if map exists +# response = self.sql.getConvoMapFromProjects(course_name) +# print("Response from supabase: ", response.data) +# if response.data: +# if response.data[0]['convo_map_id']: +# return "Map already exists for this course." + +# # if no, fetch total count of records +# response = self.sql.getCountFromLLMConvoMonitor(course_name, last_id=0) + +# # if <20, return message that map cannot be created +# if not response.count: +# return "No conversations found for this course." +# elif response.count < 20: +# return "Cannot create a map because there are less than 20 conversations in the course." + +# # if >20, iteratively fetch records in batches of 100 +# total_convo_count = response.count +# print("Total number of conversations in Supabase: ", total_convo_count) + +# first_id = response.data[0]['id'] - 1 +# combined_dfs = [] +# current_convo_count = 0 +# convo_count = 0 +# first_batch = True +# project_name = NOMIC_MAP_NAME_PREFIX + course_name + +# # iteratively query in batches of 50 +# while current_convo_count < total_convo_count: +# response = self.sql.getAllConversationsBetweenIds(course_name, first_id, 0, 100) +# print("Response count: ", len(response.data)) +# if len(response.data) == 0: +# break +# df = pd.DataFrame(response.data) +# combined_dfs.append(df) +# current_convo_count += len(response.data) +# convo_count += len(response.data) +# print(current_convo_count) + +# if convo_count >= 500: +# # concat all dfs from the combined_dfs list +# final_df = pd.concat(combined_dfs, ignore_index=True) +# # prep data for nomic upload +# embeddings, metadata = self.data_prep_for_convo_map(final_df) + +# if first_batch: +# # create a new map +# print("Creating new map...") +# index_name = course_name + "_convo_index" +# topic_label_field = "first_query" +# colorable_fields = ["user_email", "first_query", "conversation_id", "created_at"] +# result = self.create_map(embeddings, metadata, project_name, index_name, topic_label_field, +# colorable_fields) + +# if result == "success": +# # update flag +# first_batch = False +# # log project info to supabase +# project = AtlasProject(name=project_name, add_datums_if_exists=True) +# project_id = project.id +# last_id = int(final_df['id'].iloc[-1]) +# project_info = {'course_name': course_name, 'convo_map_id': project_id, 'last_uploaded_convo_id': last_id} +# # if entry already exists, update it +# projects_record = self.sql.getConvoMapFromProjects(course_name) +# if projects_record.data: +# project_response = self.sql.updateProjects(course_name, project_info) +# else: +# project_response = self.sql.insertProjectInfo(project_info) +# print("Update response from supabase: ", project_response) +# else: +# # append to existing map +# print("Appending data to existing map...") +# project = AtlasProject(name=project_name, add_datums_if_exists=True) +# result = self.append_to_map(embeddings, metadata, project_name) +# if result == "success": +# print("map append successful") +# last_id = int(final_df['id'].iloc[-1]) +# project_info = {'last_uploaded_convo_id': last_id} +# project_response = self.sql.updateProjects(course_name, project_info) +# print("Update response from supabase: ", project_response) + +# # reset variables +# combined_dfs = [] +# convo_count = 0 +# print("Records uploaded: ", current_convo_count) + +# # set first_id for next iteration +# try: +# print("response: ", response.data[-1]['id']) +# except: +# 
print("response: ", response.data) +# first_id = response.data[-1]['id'] + 1 + +# print("Convo count: ", convo_count) +# # upload last set of convos +# if convo_count > 0: +# print("Uploading last set of conversations...") +# final_df = pd.concat(combined_dfs, ignore_index=True) +# embeddings, metadata = self.data_prep_for_convo_map(final_df) +# if first_batch: +# # create map +# index_name = course_name + "_convo_index" +# topic_label_field = "first_query" +# colorable_fields = ["user_email", "first_query", "conversation_id", "created_at"] +# result = self.create_map(embeddings, metadata, project_name, index_name, topic_label_field, colorable_fields) + +# else: +# # append to map +# print("in map append") +# result = self.append_to_map(embeddings, metadata, project_name) + +# if result == "success": +# print("last map append successful") +# last_id = int(final_df['id'].iloc[-1]) +# project = AtlasProject(name=project_name, add_datums_if_exists=True) +# project_id = project.id +# project_info = {'course_name': course_name, 'convo_map_id': project_id, 'last_uploaded_convo_id': last_id} +# print("Project info: ", project_info) +# # if entry already exists, update it +# projects_record = self.sql.getConvoMapFromProjects(course_name) +# if projects_record.data: +# project_response = self.sql.updateProjects(course_name, project_info) +# else: +# project_response = self.sql.insertProjectInfo(project_info) +# print("Response from supabase: ", project_response) + + +# # rebuild the map +# self.rebuild_map(course_name, "conversation") +# return "success" +# except Exception as e: +# print(e) +# self.sentry.capture_exception(e) +# return "Error in creating conversation map:" + str(e) + +# ## -------------------------------- SUPPLEMENTARY MAP FUNCTIONS --------------------------------- ## + +# def rebuild_map(self, course_name: str, map_type: str): +# """ +# This function rebuilds a given map in Nomic. +# """ +# print("in rebuild_map()") +# nomic.login(os.getenv('NOMIC_API_KEY')) + +# if map_type.lower() == 'document': +# NOMIC_MAP_NAME_PREFIX = 'Document Map for ' +# else: +# NOMIC_MAP_NAME_PREFIX = 'Conversation Map for ' + +# try: +# # fetch project from Nomic +# project_name = NOMIC_MAP_NAME_PREFIX + course_name +# project = AtlasProject(name=project_name, add_datums_if_exists=True) + +# if project.is_accepting_data: +# project.rebuild_maps() +# return "success" +# except Exception as e: +# print(e) +# self.sentry.capture_exception(e) +# return "Error in rebuilding map: {e}" + +# def create_map(self, embeddings, metadata, map_name, index_name, topic_label_field, colorable_fields): +# """ +# Generic function to create a Nomic map from given parameters. +# Args: +# embeddings: np.array of embeddings +# metadata: pd.DataFrame of metadata +# map_name: str +# index_name: str +# topic_label_field: str +# colorable_fields: list of str +# """ +# nomic.login(os.environ['NOMIC_API_KEY']) +# print("in create_map()") +# try: +# project = atlas.map_embeddings(embeddings=embeddings, +# data=metadata, +# id_field="id", +# build_topic_model=True, +# name=map_name, +# topic_label_field=topic_label_field, +# colorable_fields=colorable_fields, +# add_datums_if_exists=True) +# project.create_index(index_name, build_topic_model=True) +# return "success" +# except Exception as e: +# print(e) +# return "Error in creating map: {e}" + +# def append_to_map(self, embeddings, metadata, map_name): +# """ +# Generic function to append new data to an existing Nomic map. 
+# Args: +# embeddings: np.array of embeddings +# metadata: pd.DataFrame of Nomic upload metadata +# map_name: str +# """ +# nomic.login(os.environ['NOMIC_API_KEY']) +# try: +# project = atlas.AtlasProject(name=map_name, add_datums_if_exists=True) +# with project.wait_for_project_lock(): +# project.add_embeddings(embeddings=embeddings, data=metadata) +# return "success" +# except Exception as e: +# print(e) +# return "Error in appending to map: {e}" + +# def data_prep_for_convo_map(self, df: pd.DataFrame): +# """ +# This function prepares embeddings and metadata for nomic upload in conversation map creation. +# Args: +# df: pd.DataFrame - the dataframe of documents from Supabase +# Returns: +# embeddings: np.array of embeddings +# metadata: pd.DataFrame of metadata +# """ +# print("in data_prep_for_convo_map()") + +# try: +# metadata = [] +# embeddings = [] +# user_queries = [] + +# for _index, row in df.iterrows(): +# current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") +# created_at = datetime.datetime.strptime(row['created_at'], "%Y-%m-%dT%H:%M:%S.%f%z").strftime("%Y-%m-%d %H:%M:%S") +# conversation_exists = False +# conversation = "" +# emoji = "" + +# if row['user_email'] is None: +# user_email = "" +# else: +# user_email = row['user_email'] + +# messages = row['convo']['messages'] + +# # some conversations include images, so the data structure is different +# if isinstance(messages[0]['content'], list): +# if 'text' in messages[0]['content'][0]: +# first_message = messages[0]['content'][0]['text'] +# #print("First message:", first_message) +# else: +# first_message = messages[0]['content'] +# user_queries.append(first_message) + +# # construct metadata for multi-turn conversation +# for message in messages: +# if message['role'] == 'user': +# emoji = "🙋 " +# else: +# emoji = "🤖 " + +# if isinstance(message['content'], list): + +# if 'text' in message['content'][0]: +# text = message['content'][0]['text'] +# else: +# text = message['content'] + +# conversation += "\n>>> " + emoji + message['role'] + ": " + text + "\n" + +# meta_row = { +# "course": row['course_name'], +# "conversation": conversation, +# "conversation_id": row['convo']['id'], +# "id": row['id'], +# "user_email": user_email, +# "first_query": first_message, +# "created_at": created_at, +# "modified_at": current_time +# } +# #print("Metadata row:", meta_row) +# metadata.append(meta_row) + +# embeddings_model = OpenAIEmbeddings(openai_api_type="openai", +# openai_api_base="https://api.openai.com/v1/", +# openai_api_key=os.environ['VLADS_OPENAI_KEY'], +# openai_api_version="2020-11-07") +# embeddings = embeddings_model.embed_documents(user_queries) + +# metadata = pd.DataFrame(metadata) +# embeddings = np.array(embeddings) +# print("Metadata shape:", metadata.shape) +# print("Embeddings shape:", embeddings.shape) +# return embeddings, metadata + +# except Exception as e: +# print("Error in data_prep_for_convo_map():", e) +# self.sentry.capture_exception(e) +# return None, None + +# def delete_from_document_map(self, project_id: str, ids: list): +# """ +# This function is used to delete datapoints from a document map. 
+# Currently used within the delete_data() function in vector_database.py +# Args: +# course_name: str +# ids: list of str +# """ +# print("in delete_from_document_map()") + +# try: +# # fetch project from Nomic +# project = AtlasProject(project_id=project_id, add_datums_if_exists=True) + +# # delete the ids from Nomic +# print("Deleting point from document map:", project.delete_data(ids)) +# with project.wait_for_project_lock(): +# project.rebuild_maps() +# return "Successfully deleted from Nomic map" +# except Exception as e: +# print(e) +# self.sentry.capture_exception(e) +# return "Error in deleting from document map: {e}" diff --git a/ai_ta_backend/service/retrieval_service.py b/ai_ta_backend/service/retrieval_service.py index 506d4af1..8fc4a260 100644 --- a/ai_ta_backend/service/retrieval_service.py +++ b/ai_ta_backend/service/retrieval_service.py @@ -18,7 +18,7 @@ from ai_ta_backend.database.sql import SQLDatabase from ai_ta_backend.database.vector import VectorDatabase from ai_ta_backend.executors.thread_pool_executor import ThreadPoolExecutorAdapter -from ai_ta_backend.service.nomic_service import NomicService +# from ai_ta_backend.service.nomic_service import NomicService from ai_ta_backend.service.posthog_service import PosthogService from ai_ta_backend.service.sentry_service import SentryService from ai_ta_backend.utils.utils_tokenization import count_tokens_and_cost @@ -31,13 +31,13 @@ class RetrievalService: @inject def __init__(self, vdb: VectorDatabase, sqlDb: SQLDatabase, aws: AWSStorage, posthog: PosthogService, - sentry: SentryService, nomicService: NomicService, thread_pool_executor: ThreadPoolExecutorAdapter): + sentry: SentryService, thread_pool_executor: ThreadPoolExecutorAdapter): self.vdb = vdb self.sqlDb = sqlDb self.aws = aws self.sentry = sentry self.posthog = posthog - self.nomicService = nomicService + #self.nomicService = nomicService self.thread_pool_executor = thread_pool_executor openai.api_key = os.environ["VLADS_OPENAI_KEY"] diff --git a/ai_ta_backend/utils/pubmed_extraction.py b/ai_ta_backend/utils/pubmed_extraction.py new file mode 100644 index 00000000..2cfc2ea0 --- /dev/null +++ b/ai_ta_backend/utils/pubmed_extraction.py @@ -0,0 +1,832 @@ +import os +import requests +import shutil +import xml.etree.ElementTree as ET +import ftplib +import supabase +import gzip +import concurrent.futures +from urllib.parse import urlparse +import tarfile +import os +import shutil +from minio import Minio +import time +from multiprocessing import Manager +import pandas as pd +import threading +import json +from functools import partial +from posthog import Posthog +import asyncio + +POSTHOG = Posthog(sync_mode=False, project_api_key=os.environ['POSTHOG_API_KEY'], host="https://app.posthog.com") + +SUPBASE_CLIENT = supabase.create_client( # type: ignore + supabase_url=os.getenv('SUPABASE_URL'), # type: ignore + supabase_key=os.getenv('SUPABASE_API_KEY') # type: ignore +) + +MINIO_CLIENT = Minio(os.environ['MINIO_URL'], + access_key=os.environ['MINIO_ACCESS_KEY'], + secret_key=os.environ['MINIO_SECRET_KEY'], + secure=True +) + +def extractPubmedData(): + """ + Main function to extract metadata and articles from the PubMed baseline folder. 
+ """ + start_time = time.monotonic() + + ftp_address = "ftp.ncbi.nlm.nih.gov" + #ftp_path = "pubmed/baseline" + ftp_path = "pubmed/updatefiles" + file_list = getFileList(ftp_address, ftp_path, ".gz") + print("Total files: ", len(file_list)) + + # with concurrent.futures.ProcessPoolExecutor() as executor: + # futures = [executor.submit(processPubmedXML, file, ftp_address, ftp_path) for file in file_list[131:133]] + # for future in concurrent.futures.as_completed(futures): + # try: + # future.result() + # except Exception as e: + # print("Error processing file: ", e) + + files_to_process = getFilesToProcess(file_list) + + + for file in files_to_process: + status = processPubmedXML(file, ftp_address, ftp_path) + print("Status: ", status) + + end_time = time.monotonic() + + POSTHOG.capture(distinct_id = "pubmed_extraction", + event = "total_pubmed_extraction_runtime", + properties = { + "total_runtime": end_time - start_time, + } + ) + + return "success" + +def getFilesToProcess(file_list: list): + last_processed_response = SUPBASE_CLIENT.table("pubmed_daily_update").select("*").order("created_at", desc=True).limit(1).execute() # type: ignore + last_processed_file = last_processed_response.data[0]['last_xml_file'] + print("Last processed file: ", last_processed_file) + files_to_process = [] + + for file in file_list: + if file == last_processed_file: + break + files_to_process.append(file) + + return files_to_process + +def processPubmedXML(file:str, ftp_address:str, ftp_path:str): + """ + Main function to extract metadata and articles from the PubMed baseline folder. + """ + start_time = time.monotonic() + try: + print("Processing file: ", file) + gz_filepath = downloadXML(ftp_address, ftp_path, file, "pubmed") + gz_file_download_time = time.time() + + # extract the XML file + if not gz_filepath: + return "failure" + xml_filepath = extractXMLFile(gz_filepath) + + xml_id = xml_filepath[7:-4].replace(".", "_") + destination_dir = xml_id + "_papers" + csv_filepath = xml_id + "_metadata.csv" + error_log = xml_id + "_errors.txt" + + for i, metadata in enumerate(extractMetadataFromXML(xml_filepath, destination_dir, error_log)): + metadata_extract_start_time = time.time() + + batch_dir = os.path.join(destination_dir, f"batch_{i+1}") + os.makedirs(batch_dir, exist_ok=True) + + # find PMC ID and DOI for all articles + metadata_with_ids = getArticleIDs(metadata, error_log) + metadata_update_time = time.time() + print("Time taken to get PMC ID and DOI for 100 articles: ", round(metadata_update_time - metadata_extract_start_time, 2), "seconds") + + # download the articles + complete_metadata = downloadArticles(metadata_with_ids, batch_dir, error_log) + print("Time taken to download 100 articles: ", round(time.time() - metadata_update_time, 2), "seconds") + + # store metadata in csv file + df = pd.DataFrame(complete_metadata) + + # add a column for the XML file path + df['xml_filename'] = os.path.basename(xml_filepath) + + if os.path.isfile(csv_filepath): + df.to_csv(csv_filepath, mode='a', header=False, index=False) + else: + df.to_csv(csv_filepath, index=False) + + before_upload = time.time() + # upload current batch to minio + print(f"Starting async upload for batch {i+1}...") + #asyncio.run(uploadToStorage(batch_dir, error_log)) + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(uploadToStorage, batch_dir, error_log) + after_upload = time.time() + + print("Time elapsed between upload call: ", after_upload - before_upload) + POSTHOG.capture(distinct_id = 
"pubmed_extraction", + event = "uploadToMinio", + properties = { + "description": "upload files to minio", + "total_runtime": after_upload - before_upload, + } + ) + + print("Time taken to download 100 articles: ", round(time.time() - metadata_extract_start_time, 2), "seconds") + + post_download_time_1 = time.monotonic() + + # upload metadata to SQL DB + df = pd.read_csv(csv_filepath) + complete_metadata = df.to_dict('records') + final_metadata = [] + unique_pmids = [] + for item in complete_metadata: + for key, value in item.items(): + if pd.isna(value): # Or: math.isnan(value) + item[key] = None + + # check for duplicates + if item['pmid'] not in unique_pmids: + final_metadata.append(item) + unique_pmids.append(item['pmid']) + print("Final metadata: ", len(final_metadata)) + + try: + response = SUPBASE_CLIENT.table("publications").upsert(final_metadata).execute() # type: ignore + print("Uploaded metadata to SQL DB.") + except Exception as e: + print("Error in uploading to Supabase: ", e) + # log the supabase error + with open(error_log, 'a') as f: + f.write("Error in Supabase upsert: " + str(e) + "\n") + + post_download_time_2 = time.monotonic() + + POSTHOG.capture(distinct_id = "pubmed_extraction", + event = "uploadToSupabase", + properties = { + "description": "Process and upload metadata file to Supabase", + "total_runtime": post_download_time_2 - post_download_time_1, + } + ) + + # upload txt articles to bucket + # print("Uploading articles to storage...") + # #destination_dir = "/home/avd6/chatbotai/asmita/ai-ta-backend/papers" + # #error_log = "all_errors.txt" + # #asyncio.run(uploadToStorage(destination_dir, error_log)) + # with concurrent.futures.ThreadPoolExecutor() as executor: + # future = executor.submit(uploadToStorage, destination_dir, error_log) + # print("Upload started asynchronously.") + + post_download_time_3 = time.monotonic() + + # delete files + # os.remove(csv_filepath) + # os.remove(xml_filepath) + # os.remove(gz_filepath) + print("Finished file: ", file) + + except Exception as e: + print("Error processing XML file: ", e) + with open("errors.txt", 'a') as f: + f.write("Error in proccesing XML file: " + file + str(e) + "\n") + + end_time = time.monotonic() + + POSTHOG.capture(distinct_id = "pubmed_extraction", + event = "processPubmedXML", + properties = { + "total_runtime": end_time - start_time, + } + ) + + +def downloadXML(ftp_address: str, ftp_path: str, file: str, local_dir: str): + """ + Downloads a .gz XML file from the FTP baseline folder and stores it in the local directory. + Args: + ftp_address: FTP server address. + ftp_path: Path to the FTP folder. + file: File to download. + local_dir: Local directory to store the downloaded file. + Returns: + local_filepath: Path to the downloaded file. + """ + try: + # create local directory if it doesn't exist + os.makedirs(local_dir, exist_ok=True) + + # connect to the FTP server + ftp = ftplib.FTP(ftp_address) + ftp.login() + ftp.cwd(ftp_path) + + local_filepath = os.path.join(local_dir, file) + with open(local_filepath, 'wb') as f: + ftp.retrbinary('RETR ' + file, f.write) + + # print(f"Downloaded {file} to {local_filepath}") + + ftp.quit() + return local_filepath + except Exception as e: + print("Error downloading file: ", e) + return None + +def getFileList(ftp_address: str, ftp_path: str, extension: str = ".gz"): + """ + Returns a list of .gz files in the FTP baseline folder. + Args: + ftp_address: FTP server address. + ftp_path: Path to the FTP folder. + extension: File extension to filter for. 
+ Returns: + gz_files: List of .gz files in the FTP folder. + """ + try: + # connect to the FTP server + ftp = ftplib.FTP(ftp_address) + ftp.login() + + # Change directory to the specified path + ftp.cwd(ftp_path) + + # Get list of file entries + file_listing = ftp.nlst() + + ftp.quit() + + # Filter for files with the specified extension + gz_files = [entry for entry in file_listing if entry.endswith(extension)] + gz_files.sort(reverse=True) + print(f"Found {len(gz_files)} files on {ftp_address}/{ftp_path}") + + return gz_files + except Exception as e: + print("Error getting file list: ", e) + return [] + +def extractXMLFile(gz_filepath: str): + """ + Extracts the XML file from the .gz file. + Args: + gz_filepath: Path to the .gz file. + Returns: + xml_filepath: Path to the extracted XML file. + """ + start_time = time.monotonic() + try: + # print("Downloaded .gz file path: ", gz_filepath) + xml_filepath = gz_filepath.replace(".gz", "") + with gzip.open(gz_filepath, 'rb') as f_in: + with open(xml_filepath, 'wb') as f_out: + shutil.copyfileobj(f_in, f_out) + + POSTHOG.capture(distinct_id = "pubmed_extraction", + event = "extractXMLFile", + properties = { + "total_runtime": time.monotonic() - start_time, + } + ) + + return xml_filepath + except Exception as e: + print("Error extracting XML file: ", e) + return None + +def extractMetadataFromXML(xml_filepath: str, dir: str, error_file: str): + """ + Extracts article details from the XML file and stores it in a dictionary. + Details extracted: PMID, PMCID, DOI, ISSN, journal title, article title, + last revised date, published date, abstract. + Args: + xml_filepath: Path to the XML file. + Returns: + metadata: List of dictionaries containing metadata for each article. + """ + print("inside extractMetadataFromXML()") + start_time = time.monotonic() + try: + # create a directory to store papers + os.makedirs(dir, exist_ok=True) + + tree = ET.parse(xml_filepath) + root = tree.getroot() + metadata = [] + + + with concurrent.futures.ProcessPoolExecutor() as executor: + futures = [] + article_items = list(item for item in root.iter('PubmedArticle')) # Convert generator to list + + for item in article_items: + future = executor.submit(processArticleItem, item, dir, error_file) + article_data = future.result() + metadata.append(article_data) + + if len(metadata) == 100: + # print("collected 100 articles") + yield metadata + metadata = [] # reset metadata for next batch + + if metadata: + yield metadata + + # print("Metadata extraction complete.") + POSTHOG.capture(distinct_id = "pubmed_extraction", + event = "extractMetadataFromXML", + properties = { + "no_of_articles": 100, + "total_runtime": time.monotonic() - start_time, + } + ) + except Exception as e: + #print("Error extracting metadata: ", e) + with open(error_file, 'a') as f: + f.write("Error in main metadata extraction function: " + str(e) + "\n") + return [] + + +def processArticleItem(item: ET.Element, directory: str, error_file: str): + """ + Extracts article details from a single PubmedArticle XML element. This is used in the process pool executor. + Args: + item: PubmedArticle XML element. + Returns: + article_data: Dictionary containing metadata for the article. 
+ """ + try: + article_data = {} + + medline_citation = item.find('MedlineCitation') + article = medline_citation.find('Article') + journal = article.find('Journal') + issue = journal.find('JournalIssue') + + if medline_citation.find('PMID') is not None: + article_data['pmid'] = medline_citation.find('PMID').text + article_data['pmcid'] = None + article_data['doi'] = None + else: + return article_data + + if journal.find('ISSN') is not None: + article_data['issn'] = journal.find('ISSN').text + else: + article_data['issn'] = None + + if journal.find('Title') is not None: + article_data['journal_title'] = journal.find('Title').text + else: + article_data['journal_title'] = None + + # some articles don't have an article title + article_title = article.find('ArticleTitle') + if article_title is not None and article_title.text is not None: + article_data['article_title'] = article_title.text.replace('[', '').replace(']', '') + else: + article_data['article_title'] = None + + article_data['last_revised'] = f"{medline_citation.find('DateRevised/Year').text}-{medline_citation.find('DateRevised/Month').text}-{medline_citation.find('DateRevised/Day').text}" + + # some articles don't have all fields present for publication date + if issue.find('PubDate/Year') is not None and issue.find('PubDate/Month') is not None and issue.find('PubDate/Day') is not None: + article_data['published'] = f"{issue.find('PubDate/Year').text}-{issue.find('PubDate/Month').text}-{issue.find('PubDate/Day').text}" + elif issue.find('PubDate/Year') is not None and issue.find('PubDate/Month') is not None: + article_data['published'] = f"{issue.find('PubDate/Year').text}-{issue.find('PubDate/Month').text}-01" + elif issue.find('PubDate/Year') is not None: + article_data['published'] = f"{issue.find('PubDate/Year').text}-01-01" + else: + article_data['published'] = None + + # extract and store abstract in a text file + abstract = article.find('Abstract') + abstract_filename = None + if abstract is not None: + abstract_text = "" + for abstract_text_element in abstract.iter('AbstractText'): + # if labels (objective, methods, etc.) are present, add them to the text (e.g. "OBJECTIVE: ") + if abstract_text_element.attrib.get('Label') is not None: + abstract_text += abstract_text_element.attrib.get('Label') + ": " + if abstract_text_element.text is not None: + abstract_text += abstract_text_element.text + "\n" + + # save abstract to a text file + abstract_filename = directory + "/" + article_data['pmid'] + ".txt" + with open(abstract_filename, 'w') as f: + if article_data['journal_title']: + f.write("Journal title: " + article_data['journal_title'] + "\n\n") + if article_data['article_title']: + f.write("Article title: " + article_data['article_title'] + "\n\n") + f.write("Abstract: " + abstract_text) + + # some articles are listed, but not released yet. Adding fields for such articles to maintain uniformity. + article_data['live'] = True + article_data['release_date'] = None + article_data['license'] = None + article_data['pubmed_ftp_link'] = None + article_data['filepath'] = abstract_filename + + return article_data + except Exception as e: + with open(error_file, 'a') as f: + f.write("Error in metadata extraction subprocess for PMID " + article_data['pmid'] + ": " + str(e) + "\n") + return {'error': str(e)} + + +def getArticleIDs(metadata: list, error_file: str): + """ + Uses the PubMed ID converter API to get PMCID and DOI for each article. + Queries the API in batches of 200 articles at a time. 
+ Also updates the metadata with the release date and live status - some articles are yet to be released. + Args: + metadata: List of dictionaries containing metadata for each article. + Returns: + metadata: Updated metadata with PMCID, DOI, release date, and live status information. + """ + # print("In getArticleIDs()") + + start_time = time.monotonic() + + base_url = "https://www.ncbi.nlm.nih.gov/pmc/utils/idconv/v1.0/" + app_details = "?tool=ncsa_uiuc&email=caiincsa@gmail.com&format=json" + + batch_size = 200 # maximum number of articles API can process in one request + + for i in range(0, len(metadata), batch_size): + batch = metadata[i:i + batch_size] + ids = ",".join([article['pmid'] for article in batch]) + try: + response = requests.get(base_url + app_details + "&ids=" + ids) + data = response.json() + records = data['records'] + + # PARALLELIZE THIS FOR LOOP - UPDATES ADDITIONAL FIELDS FOR ALL ARTICLES AT ONCE + with Manager() as manager: + shared_metadata = manager.dict() # Use a shared dictionary + with concurrent.futures.ProcessPoolExecutor() as executor: + futures = { + executor.submit(updateArticleMetadata, shared_metadata, record): record + for record in records + } + concurrent.futures.wait(futures) + for future in concurrent.futures.as_completed(futures): + record = futures[future] + try: + future.result() + except Exception as exc: + print('%r generated an exception: %s' % (record, exc)) + with open(error_file, 'a') as f: + f.write(f"Record: {record}\t") + f.write(f"Exception: {type(exc).__name__} - {exc}\n") + + # Update original metadata after loop + for article in metadata: + if article['pmid'] in shared_metadata: + # print("Shared metadata: ", shared_metadata[article['pmid']]) + if 'errmsg' in shared_metadata[article['pmid']]: + article['live'] = False + else: + article['pmcid'] = shared_metadata[article['pmid']]['pmcid'] + article['doi'] = shared_metadata[article['pmid']]['doi'] + article['live'] = shared_metadata[article['pmid']]['live'] + article['release_date'] = shared_metadata[article['pmid']]['release_date'] + #print("Updated metadata: ", article) + except Exception as e: + #print("Error: ", e) + with open(error_file, 'a') as f: + f.write("Error in getArticleIds(): " + str(e) + "\n") + #print("Length of metadata after ID conversion: ", len(metadata)) + + POSTHOG.capture(distinct_id = "pubmed_extraction", + event = "getArticleIDs", + properties = { + "description": "Converting PMIDs to PMCIDs and update the ID in main metadata", + "no_of_articles": 100, + "total_runtime": time.monotonic() - start_time, + } + ) + return metadata + + +def updateArticleMetadata(shared_metadata, record): + """ + Updates metadata with PMCID, DOI, release date, and live status information for given article. + Used within getArticleIDs() function. 
+ """ + if 'errmsg' in record: + #print("Error: ", record['errmsg']) + shared_metadata[record['pmid']] = { + **record, # Create a copy with record data + 'live': False + } + else: + # Update shared dictionary with pmid as key and updated article data as value + shared_metadata[record['pmid']] = { + **record, # Create a copy with record data + 'pmcid': record['pmcid'], + 'doi': record.get('doi', ''), + 'live': False if 'live' in record and record['live'] == "false" else True, + 'release_date': record['release-date'] if 'release-date' in record else None, + } + + # POSTHOG.capture(distinct_id = "pubmed_extraction", + # event_name = "updateArticleMetadata", + # properties = { + # "description": "Updating PMCID and DOI in main metadata" + # "no_of_articles": 1, + # "total_runtime": round(time.time() - start_time, 2), + # } + # ) + + + + +def downloadArticles(metadata: list, dir: str, error_file: str): + """ + Downloads articles from PMC and stores them in local directory. + Args: + metadata: List of dictionaries containing metadata for each article. + Returns: + metadata: Updated metadata with license, FTP link, and downloaded filepath information. + """ + # print("In downloadArticles()") + start_time = time.monotonic() + try: + base_url = "https://www.ncbi.nlm.nih.gov/pmc/utils/oa/oa.fcgi?" + + updated_articles = {} + + # Use ThreadPoolExecutor to run download_article for each article in parallel + download_article_partial = partial(download_article, api_url=base_url, dir=dir, error_file=error_file) + with concurrent.futures.ProcessPoolExecutor() as executor: + futures = [executor.submit(download_article_partial, article) for article in metadata] + for future in concurrent.futures.as_completed(futures): + try: + # print("Starting new download...") + updated_article = future.result(timeout=15*60) # Check result without blocking + if updated_article: + updated_articles[updated_article['pmid']] = updated_article + # print("Updated article: ", updated_article) + except Exception as e: + print("Error downloading article:", e) + with open(error_file, 'a') as f: + f.write("Error in downloadArticles(): " + str(e) + "\n") + + # Update original metadata with updated articles + for article in metadata: + if article['pmid'] in updated_articles: + article.update(updated_articles[article['pmid']]) + + # print("Updated metadata after download: ", metadata) + + POSTHOG.capture(distinct_id = "pubmed_extraction", + event = "downloadArticles", + properties = { + "description": "Download articles and update metadata", + "no_of_articles": 100, + "total_runtime": time.monotonic() - start_time, + } + ) + + return metadata + + except Exception as e: + #print("Error downloading articles: ", e) + with open(error_file, 'a') as f: + f.write("Error in downloadArticles(): " + str(e) + "\n") + return metadata + +def download_article(article, api_url, dir, error_file): + """ + Downloads the article from given FTP link and updates metadata with license, FTP link, and downloaded filepath information. + This function is used within downloadArticles() function. + Args: + article: Dictionary containing metadata for the article. + api_url: URL for the article download API. + ftp: FTP connection object. + Returns: + article: Updated metadata for the article. 
+ """ + + # print("Downloading articles...") + try: + if not article['live'] or article['pmcid'] is None: + return + + # Proceed with download + # Connect to FTP server anonymously + ftp = ftplib.FTP("ftp.ncbi.nlm.nih.gov", timeout=15*60) + ftp.login() + + if article['pmcid']: + final_url = api_url + "id=" + article['pmcid'] + # print("\nDownload URL: ", final_url) + + xml_response = requests.get(final_url) + extracted_data = extractArticleData(xml_response.text, error_file) + # print("Extracted license and link data: ", extracted_data) + + if not extracted_data: + article['live'] = False + return + + article['license'] = extracted_data[0]['license'] + article['pubmed_ftp_link'] = extracted_data[0]['href'] if 'href' in extracted_data[0] else None + + ftp_url = urlparse(extracted_data[0]['href']) + ftp_path = ftp_url.path[1:] + # print("FTP path: ", ftp_path) + + filename = article['pmcid'] + "_" + ftp_path.split("/")[-1] + local_file = os.path.join(dir, filename) + + try: + with open(local_file, 'wb') as f: + ftp.retrbinary('RETR ' + ftp_path, f.write) # Download directly to file + + # print("Downloaded FTP file: ", local_file) + article['filepath'] = local_file + + if filename.endswith(".tar.gz"): + extracted_pdf_paths = extractPDF(local_file, dir, error_file, article['pmcid']) + #print("Extracted PDFs from .tar.gz file: ", extracted_pdf_paths) + article['filepath'] = ",".join(extracted_pdf_paths) + os.remove(local_file) + + except concurrent.futures.TimeoutError: + print("Download timeout reached.") + + ftp.quit() + + # print("\nUpdated metadata after download: ", article) + return article + except Exception as e: + #print("Error in article download subprocess: ", e) + with open(error_file, 'a') as f: + f.write("Error in download_article() PMID " + article['pmid'] + ": " + str(e) + "\n") + return None + +def extractPDF(tar_gz_filepath: str, dest_directory: str, error_file: str, pmcid: str): + """ + Extracts PDF files from the downloaded .tar.gz file. The zipped folder contains other supplementary + materials like images, etc. which are not extracted. + Args: + tar_gz_filepath: Path to the .tar.gz file. + Returns: + extracted_paths: List of paths to the extracted PDF files. + """ + try: + # print("Extracting PDF from: ", tar_gz_filepath) + extracted_paths = [] + with tarfile.open(tar_gz_filepath, "r:gz") as tar: + for member in tar: + if member.isreg() and member.name.endswith(".pdf"): + tar.extract(member, path=dest_directory) + #print("Extracted: ", member.name) + original_path = os.path.join(dest_directory, member.name) + new_filename = pmcid + "_" + os.path.basename(member.name) + new_path = os.path.join(dest_directory, new_filename) + #print("New path: ", new_path) + os.rename(original_path, new_path) + extracted_paths.append(new_path) + + return extracted_paths + except Exception as e: + #print("Error extracting PDF: ", e) + with open(error_file, 'a') as f: + f.write("Error in extractPDF() PMCID - " + pmcid + ": " + str(e) + "\n") + return [] + +def extractArticleData(xml_string: str, error_file: str): + """ + Extracts license information and article download link from the XML response. + This function process XML response for single article. + Args: + xml_string: XML response from PMC download API. + Returns: + extracted_data: List of dictionaries containing license and download link for the article. 
+ """ + # print("In extractArticleData") + + try: + root = ET.fromstring(xml_string) + # if there is an errors (article not open-access), return empty list (skip article) + if root.find(".//error") is not None: + return [] + + records = root.findall(".//record") + extracted_data = [] + href = None + + for record in records: + record_id = record.get("id") # pmcid + license = record.get("license") + links = record.findall(".//link") + + for link in links: + if link.get("format") == "pdf": + href = link.get("href") + break + # if PDF link not found, use the available tgz link + if not href: + href = links[0].get("href") + + extracted_data.append({ + "record_id": record_id, + "license": license, + "href": href + }) + + return extracted_data + except Exception as e: + #print("Error extracting article data: ", e) + with open(error_file, 'a') as f: + f.write("Error in extractArticleData(): " + str(e) + "\n") + f.write("XML String: " + xml_string + "\n") + return [] + +def upload_file(client, bucket_name, file_path, object_name, error_file, upload_log): + """ + Uploads a single file to the Minio bucket. + """ + try: + client.fput_object(bucket_name, object_name, file_path) + print(f"Uploaded: {object_name}") + with open(upload_log, 'a') as f: + f.write("uploaded: " + file_path + "\n") + os.remove(file_path) + except Exception as e: + #print(f"Error uploading {object_name}: {e}") + with open(error_file, 'a') as f: + f.write("Error in upload_file(): " + str(e) + "\n") + +def uploadToStorage(filepath: str, error_file: str): + """ + Uploads all files present under given filepath to Minio bucket in parallel. + """ + # print("in uploadToStorage()") + try: + bucket_name = "pubmed" + + found = MINIO_CLIENT.bucket_exists(bucket_name) + if not found: + MINIO_CLIENT.make_bucket(bucket_name) + print("Created bucket", bucket_name) + # else: + # print("Bucket", bucket_name, "already exists") + #upload_log = error_file.split("_")[0] + ".txt" + upload_log = "all_papers.txt" + # Get all files to upload + files = [] + for root, _, files_ in os.walk(filepath): + for file in files_: + file_path = os.path.join(root, file) + object_name = file_path.split("/")[-1] + files.append((MINIO_CLIENT, bucket_name, file_path, object_name, error_file, upload_log)) + + # Use concurrent.futures ThreadPoolExecutor with limited pool size + with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: + # Submit files in batches of 10 + for i in range(0, len(files), 10): + batch_files = files[i:i+10] + futures = [executor.submit(upload_file, *args) for args in batch_files] + done, not_done = concurrent.futures.wait(futures, timeout=180) + + for future in not_done: + future.cancel() # Cancel the future if it is not done within the timeout + + for future in done: + try: + future.result() # This will raise any exceptions from upload_file + except Exception as e: + with open(error_file, 'a') as f: + f.write("Error in upload_file(): " + str(e) + "\n") + + # for future in futures: + # future.result() # This will raise any exceptions from upload_file + + return "success" + except Exception as e: + #print("Error uploading to storage: ", e) + with open(error_file, 'a') as f: + f.write("Error in uploadToStorage(): " + str(e) + "\n") + return "failure" +