diff --git a/Carrot-Assistant/omop/OMOP_match.py b/Carrot-Assistant/omop/OMOP_match.py index 79cd5e3..ce5f847 100644 --- a/Carrot-Assistant/omop/OMOP_match.py +++ b/Carrot-Assistant/omop/OMOP_match.py @@ -19,6 +19,27 @@ class OMOPMatcher: """ This class retrieves matches from an OMOP database and returns the best """ + _instance = None + + @classmethod + def get_instance(cls, logger: Optional[Logger] = None): + """ + This method returns the singleton instance of the OMOPMatcher class + and creates it if it does not exist. + + Parameters + ---------- + logger: Logger + A logger for logging runs of the tool + + Returns + ------- + OMOPMatcher + The singleton instance of the OMOPMatcher class + """ + if cls._instance is None: + cls._instance = cls(logger) + return cls._instance def __init__(self, logger: Optional[Logger] = None): # Connect to database @@ -27,34 +48,37 @@ def __init__(self, logger: Optional[Logger] = None): self.logger = logger load_dotenv() - try: - self.logger.info( - "Initialize the PostgreSQL connection based on the environment variables" - ) - DB_HOST = environ["DB_HOST"] - DB_USER = environ["DB_USER"] - DB_PASSWORD = quote_plus(environ["DB_PASSWORD"]) - DB_NAME = environ["DB_NAME"] - DB_PORT = environ["DB_PORT"] - DB_SCHEMA = environ["DB_SCHEMA"] - - connection_string = ( - f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}" - ) - engine = create_engine(connection_string) - logger.info(f"Connected to PostgreSQL database {DB_NAME} on {DB_HOST}") + if not hasattr(self, 'engine'): + + try: + self.logger.info( + "Initialize the PostgreSQL connection based on the environment variables" + ) + DB_HOST = environ["DB_HOST"] + DB_USER = environ["DB_USER"] + DB_PASSWORD = quote_plus(environ["DB_PASSWORD"]) + DB_NAME = environ["DB_NAME"] + DB_PORT = environ["DB_PORT"] + DB_SCHEMA = environ["DB_SCHEMA"] + + connection_string = ( + f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}" + ) + engine = 
create_engine(connection_string) + logger.info(f"Connected to PostgreSQL database {DB_NAME} on {DB_HOST}") - except Exception as e: - logger.error(f"Failed to connect to PostgreSQL: {e}") - raise ValueError(f"Failed to connect to PostgreSQL: {e}") + except Exception as e: + logger.error(f"Failed to connect to PostgreSQL: {e}") + raise ValueError(f"Failed to connect to PostgreSQL: {e}") - self.engine = engine - self.schema = DB_SCHEMA + self.engine = engine + self.schema = DB_SCHEMA def close(self): """Close the engine connection.""" - self.engine.dispose() - self.logger.info("PostgreSQL connection closed.") + if hasattr(self, 'engine'): + self.engine.dispose() + self.logger.info("PostgreSQL connection closed.") def calculate_best_matches( self, @@ -193,7 +217,7 @@ def fetch_OMOP_concepts( session = Session() results = session.execute(query).fetchall() results = pd.DataFrame(results) - session.close() + if not results.empty: # Define a function to calculate similarity score using the provided logic def calculate_similarity(row): @@ -481,5 +505,5 @@ def run(opt: argparse.Namespace, search_term:str, logger: Logger): max_separation_descendant, max_separation_ancestor, ) - omop_matcher.close() + return res diff --git a/Carrot-Assistant/routers/pipeline_routes.py b/Carrot-Assistant/routers/pipeline_routes.py index 72ea062..683bc7f 100644 --- a/Carrot-Assistant/routers/pipeline_routes.py +++ b/Carrot-Assistant/routers/pipeline_routes.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter +from fastapi import APIRouter, Request import asyncio from collections.abc import AsyncGenerator import json @@ -9,6 +9,7 @@ import assistant from omop import OMOP_match +from omop.OMOP_match import OMOPMatcher from options.base_options import BaseOptions from components.embeddings import Embeddings from options.pipeline_options import PipelineOptions, parse_pipeline_args @@ -18,6 +19,7 @@ logger = Logger().make_logger() + class PipelineRequest(BaseModel): """ This class takes the format of a 
request to the API @@ -34,43 +36,29 @@ class PipelineRequest(BaseModel): pipeline_options: PipelineOptions = Field(default_factory=PipelineOptions) - - -async def generate_events(request: PipelineRequest) -> AsyncGenerator[str]: +async def generate_events( + request: PipelineRequest, use_llm: bool, end_session: bool +) -> AsyncGenerator[str]: """ - Generate LLM output and OMOP results for a list of informal names + Generate LLM output and OMOP results for a list of informal names. - Parameters + Parameters ---------- request: PipelineRequest The request containing the list of informal names. - Workflow - -------- - For each informal name: - The first event is to Query the OMOP database for a match - The second event is to fetches relevant concepts from the OMOP database - Finally,The function yields results as they become available, - allowing for real-time streaming. - - Conditions - ---------- - If the OMOP database returns a match, the LLM is not queried - - If the OMOP database does not return a match, - the LLM is used to find the formal name and the OMOP database is - queried for the LLM output. - - Finally, the function yields the results for real-time streaming. + use_llm: bool + A flag to determine whether to use LLM to find the formal name. + end_session: bool + A flag to determine whether to end the session. Yields ------ str - JSON encoded strings of the event results. Two types are yielded: - 1. "llm_output": The result from the language model processing. - 2. "omop_output": The result from the OMOP database matching. + JSON encoded strings of the event results.
""" + informal_names = request.names opt = BaseOptions() opt.initialize() @@ -78,67 +66,121 @@ async def generate_events(request: PipelineRequest) -> AsyncGenerator[str]: opt = opt.parse() print("Received informal names:", informal_names) + print(f"use_llm flag is set to: {use_llm}") + print(f"end_session flag is set to: {end_session}") - # Query OMOP for each informal name - - for informal_name in informal_names: - print(f"Querying OMOP for informal name: {informal_name}") - omop_output = OMOP_match.run(opt=opt, search_term=informal_name, logger=logger) - - if omop_output and any(concept["CONCEPT"] for concept in omop_output): - print(f"OMOP match found for {informal_name}: {omop_output}") - output = {"event": "omop_output", "data": omop_output} - yield json.dumps(output) - continue - - else: - print("No satisfactory OMOP results found for {informal_name}, using LLM...") - - # Use LLM to find the formal name and query OMOP for the LLM output - - llm_outputs = assistant.run(opt=opt, informal_names=informal_names, logger=logger) - for llm_output in llm_outputs: - - - print("LLM output for", llm_output["informal_name"], ":", llm_output["reply"]) - - print("Querying OMOP for LLM output:", llm_output["reply"]) - - output = {"event": "llm_output", "data": llm_output} - yield json.dumps(output) - - # Simulate some delay before sending the next part - await asyncio.sleep(2) - - omop_output = OMOP_match.run( - opt=opt, search_term=llm_output["reply"], logger=logger - ) - - print("OMOP output for", llm_output["reply"], ":", omop_output) - - output = {"event": "omop_output", "data": omop_output} + + # If the user chooses to end the session, close the database connection + if end_session: + print("Final API call. 
Closing the database connection....") + output = {"event": "session_ended", "message": "Session has ended."} yield json.dumps(output) + OMOPMatcher.get_instance().close() + return + + no_match_names = [] + + try: + if informal_names: + + # Query OMOP for the informal names + if not use_llm: + for informal_name in informal_names: + print(f"Querying OMOP for informal name: {informal_name}") + omop_output = OMOP_match.run( + opt=opt, search_term=informal_name, logger=logger + ) + + # If a match is found, yield the OMOP output + if omop_output and any( + concept["CONCEPT"] for concept in omop_output + ): + print(f"OMOP match found for {informal_name}: {omop_output}") + output = {"event": "omop_output", "data": omop_output} + yield json.dumps(output) + + # If no match is found, yield a message and add the name to the no_match_names list + else: + print(f"No satisfactory OMOP results found for {informal_name}") + output = { + "event": "omop_output", + "data": omop_output, + "message": f"No match found in OMOP database for {informal_name}.", + } + yield json.dumps(output) + no_match_names.append(informal_name) + print(f"\nno_match_names: {no_match_names}\n") + else: + no_match_names = informal_names + + # Use LLM to find the formal name and query OMOP for the LLM output + if no_match_names and use_llm: + llm_outputs = assistant.run( + opt=opt, informal_names=no_match_names, logger=logger + ) + + for llm_output in llm_outputs: + print( + "LLM output for", + llm_output["informal_name"], + ":", + llm_output["reply"], + ) + + output = {"event": "llm_output", "data": llm_output} + yield json.dumps(output) + + finally: + + # Ensure database connection is closed at the end of processing + if not no_match_names: + print( + "no matches found. Closing the database connection..." 
+ ) + OMOPMatcher.get_instance().close() + + else: + print("\nDatabase connection remains open.") @router.post("/") -async def run_pipeline(request: PipelineRequest) -> EventSourceResponse: +async def run_pipeline(request: Request) -> EventSourceResponse: """ - Call generate_events to run the pipeline + This function runs the pipeline for a list of informal names. Parameters ---------- - request: PipelineRequest - The request containing a list of informal names + request: Request + The request containing the list of informal names. + + Workflow + -------- + The function generates events for each informal name in the list. + + use_llm: bool + A flag to determine whether to use LLM to find the formal name. + Returns ------- EventSourceResponse - The response containing the events + The response containing the results of the pipeline. """ - return EventSourceResponse(generate_events(request)) + body = await request.json() + pipeline_request = PipelineRequest(**body) + + use_llm = body.get("use_llm", False) + end_session = body.get("end_session", False) + + print( + f"Running pipeline with use_llm: {use_llm} and end_session: {end_session}" + ) + return EventSourceResponse( + generate_events(pipeline_request, use_llm, end_session) + ) @router.post("/db") -async def run_db(request: PipelineRequest) -> List[Dict[str,Any]]: +async def run_db(request: PipelineRequest) -> List[Dict[str, Any]]: """ Fetch OMOP concepts for a name @@ -166,7 +208,8 @@ async def run_db(request: PipelineRequest) -> List[Dict[str,Any]]: omop_outputs.append({"event": "omop_output", "content": omop_output}) return omop_outputs - + + @router.post("/vector_search") async def run_vector_search(request: PipelineRequest): """ @@ -187,12 +230,10 @@ async def run_vector_search(request: PipelineRequest): """ search_terms = request.names embeddings = Embeddings( - embeddings_path=request.pipeline_options.embeddings_path, - force_rebuild=request.pipeline_options.force_rebuild, - 
embed_vocab=request.pipeline_options.embed_vocab, - model_name=request.pipeline_options.embedding_model, - search_kwargs=request.pipeline_options.embedding_search_kwargs, - ) - return {'event': 'vector_search_output', 'content': embeddings.search(search_terms)} - - + embeddings_path=request.pipeline_options.embeddings_path, + force_rebuild=request.pipeline_options.force_rebuild, + embed_vocab=request.pipeline_options.embed_vocab, + model_name=request.pipeline_options.embedding_model, + search_kwargs=request.pipeline_options.embedding_search_kwargs, + ) + return {"event": "vector_search_output", "content": embeddings.search(search_terms)} diff --git a/Carrot-Assistant/text_input.py b/Carrot-Assistant/text_input.py index 31edec1..22a8570 100644 --- a/Carrot-Assistant/text_input.py +++ b/Carrot-Assistant/text_input.py @@ -1,9 +1,24 @@ import json -from ui_utilities import display_concept_info, stream_message, capitalize_words, make_api_call -import sseclient +from ui_utilities import ( + display_concept_info, + stream_message, + capitalize_words, + make_api_call, +) import streamlit as st + +# ---> Process + +# 1. User enters informal names of medications. +# 2. Send the informal names to the OMOP database. +# 3. If no matches are found, ask the user if they want to try the LLM. +# 4. If the user agrees, send the unmatched informal names to the LLM. +# 5. Use the LLM-predicted names to query the OMOP database. +# 6. Display the results to the user. + # Page configuration + st.set_page_config(page_title="Lettuce", page_icon="🥬", layout="wide") st.markdown( "

Lettuce 🥬

", @@ -20,44 +35,173 @@ ) with st.expander("Search options"): - skip_llm = st.checkbox("Ask the LLM first?", value=True) vocab_id = st.selectbox(label="Vocabulary ID", options=["RxNorm", "UK Biobank"]) +# Initialize session state + +# The session state is used to keep track of the state of the session. + +# If the results of the lettuce are displayed or the user doesn't want to use LLM, the session is ended. + +if "session_ended" not in st.session_state: + st.session_state["session_ended"] = False + +if st.session_state["session_ended"]: + st.write("Session has ended. Thank you for using Carrot!") + st.stop() + if st.button("Send"): if informal_names: names_list = [ capitalize_words(name.strip()) for name in informal_names.split(",") ] with st.spinner("Processing..."): - result_stream: sseclient.SSEClient = make_api_call(names_list, skip_llm, vocab_id) + # Step 1: Query OMOP database with the initial list of names + result_stream = make_api_call(names_list, use_llm=False, end_session=False, vocab_id=vocab_id) + no_match_names = [] - # Stream the results for event in result_stream.events(): response = json.loads(event.data) event_type = response["event"] + message = response.get("message", "") + + if message: + stream_message(f"

{message}

") + + # Display the results from the OMOP database + + if event_type == "omop_output": + for omop_data in response["data"]: + search_term = omop_data.get("search_term", "") + if not omop_data["CONCEPT"]: + stream_message( + f"

No concepts found for {search_term}.

" + ) + no_match_names.append(search_term) + else: + for i, concept in enumerate(omop_data["CONCEPT"], 1): + with st.expander( + f"Concept {i}: {concept['concept_name']}", + expanded=True, + ): + display_concept_info(concept) + + # Save unmatched names in session state if no match was found + if no_match_names: + st.session_state["no_match_names"] = no_match_names + st.session_state["vocab_id"] = vocab_id + st.session_state["llm_requested"] = False - # Stream the LLM output - if event_type == "llm_output": +# Ask the user if they want to try the LLM if no matches are found in the OMOP database + +if "no_match_names" in st.session_state and not st.session_state.get( + "llm_requested", False +): + user_choice = st.radio( + "No match found in OMOP database. Would you like to try the LLM?", + ["Select an option", "Yes", "No"], + key="llm_option", + ) + + # If the user chooses "Yes", set the session state to request LLM predictions. + # If the user chooses "No", end the session. + + if user_choice == "No": + with st.spinner("Ending session..."): + + # Make a final API call with to close the database connection. + result_stream = make_api_call( + st.session_state["no_match_names"], + use_llm=False, + end_session=True, # Set end_session flag to True + vocab_id=vocab_id + ) + + for event in result_stream.events(): + response = json.loads(event.data) + message = response.get("message", "") + if message: + stream_message(f"

{message}

") + + # Display final message before ending the session + stream_message("

Thank you for using Carrot! 🥕

") + + st.session_state["session_ended"] = True + st.stop() + + elif user_choice == "Yes": + st.session_state["llm_requested"] = True + st.rerun() + + if st.session_state["session_ended"]: + if st.button("End Session"): + st.write( + "Thank you for using Carrot. Feel free to ask me more about informal names!" + ) + st.stop() + +# Process LLM predictions + +if st.session_state.get("llm_requested", False): + no_match_names = st.session_state["no_match_names"] + vocab_id = st.session_state["vocab_id"] + + if "llm_processed_names" not in st.session_state: + st.session_state["llm_processed_names"] = [] + + llm_results = [] + + # Processing LLM predictions + with st.spinner("Processing with LLM..."): + + # Step 2: Query LLM for unmatched names + result_stream = make_api_call(no_match_names, use_llm=True, end_session=False, vocab_id=vocab_id) + + for event in result_stream.events(): + response = json.loads(event.data) + event_type = response["event"] + message = response.get("message", "") + + if message: + stream_message(f"

{message}

") + + if event_type == "llm_output": + llm_output = response["data"] + informal_name = llm_output["informal_name"] + formal_name = llm_output["reply"] + + if informal_name not in st.session_state["llm_processed_names"]: stream_message( - f'

I found {response["data"]["reply"]} as the formal name for {response["data"]["informal_name"]}

' + f'

I found {formal_name} as the formal name for {informal_name}

' ) + st.session_state["llm_processed_names"].append(informal_name) + llm_results.append(formal_name) + + # Step 3: Re-query OMOP database with the LLM-predicted names + if llm_results: + with st.spinner("Processing final OMOP query..."): + new_result_stream = make_api_call( + llm_results, use_llm=False, end_session=True, vocab_id=vocab_id + ) + + for new_event in new_result_stream.events(): + new_response = json.loads(new_event.data) + new_event_type = new_response["event"] - # Stream the OMOP output - elif event_type == "omop_output": - for j, omop_data in enumerate(response["data"]): - if ( - omop_data["CONCEPT"] is None - or len(omop_data["CONCEPT"]) == 0 - ): + if new_event_type == "omop_output": + for new_omop_data in new_response["data"]: + if not new_omop_data["CONCEPT"]: stream_message( - f"

No concepts found for {omop_data['search_term']}.

" + f"

No concepts found for {new_omop_data['search_term']}.

" ) else: - for i, concept in enumerate(omop_data["CONCEPT"], 1): + for i, concept in enumerate(new_omop_data["CONCEPT"], 1): with st.expander( f"Concept {i}: {concept['concept_name']}", expanded=True, ): display_concept_info(concept) - else: - st.warning("Please enter an informal name before sending.") + + # Stop any further API calls after processing + st.session_state["session_ended"] = True + st.stop() \ No newline at end of file diff --git a/Carrot-Assistant/ui_utilities.py b/Carrot-Assistant/ui_utilities.py index ea61f1e..0e445e3 100644 --- a/Carrot-Assistant/ui_utilities.py +++ b/Carrot-Assistant/ui_utilities.py @@ -3,6 +3,8 @@ import time import requests from options.pipeline_options import PipelineOptions +from typing import List, Union + def display_concept_info(concept: dict) -> None: """ @@ -45,6 +47,7 @@ def display_concept_info(concept: dict) -> None: for relationship in concept["CONCEPT_RELATIONSHIP"]: stream_message(f"

- {relationship}

") + def stream_message(message: str) -> None: """ Stream a message to the user, rendering HTML with a typewriter effect @@ -81,27 +84,43 @@ def capitalize_words(s: str) -> str: return " ".join(capitalized_words) -def make_api_call(names: list[str], skip_llm: bool, vocab_id: str | None) -> sseclient.SSEClient: +def make_api_call( + names: List[str], + use_llm: bool, + vocab_id: Union[str, None], + end_session: bool, +) -> sseclient.SSEClient: """ - Make a call to the Lettuce API to retrieve OMOP concepts. + This function makes an API call to the backend server to process the input names. Parameters ---------- - names: list[str] - The informal names to send to the API + names: List[str] + The list of names to process + + use_llm: bool + Whether to use the LLM model for processing + + vocab_id: Union[str, None] + The vocabulary ID to use for processing + + end_session: bool + Whether to end the session after processing Returns ------- sseclient.SSEClient - The stream of events from the API + The server-sent event client to stream the results """ url = "http://127.0.0.1:8000/pipeline/" - if not skip_llm: - url = url + "db" headers = {"Content-Type": "application/json"} pipe_opts = PipelineOptions(vocabulary_id=vocab_id) - data = {"names": names, "pipeline_options": pipe_opts.model_dump()} + data = { + "names": names, + "pipeline_options": pipe_opts.model_dump(), + "use_llm": use_llm, + "end_session": end_session, + } + print("Making API call with data:", data) response = requests.post(url, headers=headers, json=data, stream=True) return sseclient.SSEClient(response) - -