From 2add1c1732e852a025da9ba87e64db509b5e9c32 Mon Sep 17 00:00:00 2001 From: potthoffjan Date: Wed, 10 Jul 2024 09:36:26 +0200 Subject: [PATCH] History aware Retriever and multiple LLM answers and basic in/out Signed-off-by: potthoffjan --- .../RAG/LangChain_Implementation/chain.py | 209 ++++++++++++++++++ 1 file changed, 209 insertions(+) create mode 100644 src/backend/RAG/LangChain_Implementation/chain.py diff --git a/src/backend/RAG/LangChain_Implementation/chain.py b/src/backend/RAG/LangChain_Implementation/chain.py new file mode 100644 index 0000000..b84c75a --- /dev/null +++ b/src/backend/RAG/LangChain_Implementation/chain.py @@ -0,0 +1,209 @@ +import json +import os +import sys + + +from dotenv import load_dotenv + +from astrapy import DataAPIClient +from astrapy.db import AstraDB +from langchain_astradb import AstraDBVectorStore + +#from langchain.embeddings import OpenAIEmbeddings +#from langchain_community.embeddings import OpenAIEmbeddings +from langchain_openai import OpenAIEmbeddings +from langchain.schema import Document +from langchain.text_splitter import RecursiveCharacterTextSplitter +#from langchain_community.embeddings import HuggingFaceEmbeddings +#from langchain_huggingface import HuggingFaceEmbeddings +#from langchain_community.llms.openai import OpenAI +from langchain_openai import OpenAI +from langchain_google_genai import ChatGoogleGenerativeAI +from langchain_community.vectorstores.chroma import Chroma + +from langchain_core.prompts import ChatPromptTemplate +from langchain.chains import create_history_aware_retriever +from langchain_core.prompts import MessagesPlaceholder +from langchain.chains import create_retrieval_chain +from langchain.chains.combine_documents import create_stuff_documents_chain + +from langchain_community.chat_message_histories import ChatMessageHistory +from langchain_core.chat_history import BaseChatMessageHistory +from langchain_core.runnables.history import RunnableWithMessageHistory +from langchain_core.messages import HumanMessage, AIMessage +from langchain_anthropic import AnthropicLLM +from langchain_anthropic import ChatAnthropic + +""" +Script takes in args: +0) list of LLMs with which to retrieve e.g. ['gpt-4', 'gemini', 'mistral'] +1) input string +2) chat history in following shape [ + {'gpt-4': "Hello, how can I help you?"}, + {'user': "What do prisons and plants have in common?"} +] +""" + +def custom_history(entire_history:list, llm_name:str): + chat_history = [] + for msg in entire_history: + if 'user' in msg: + chat_history.extend([HumanMessage(content=msg['user'])]) + if llm_name in msg: + chat_history.extend([AIMessage(content=msg[llm_name])]) + return chat_history + + +def main(): + + if len(sys.argv) < 3: + print("""Error: Please provide: + 1) [list of LLM models to use] + (['gpt-4', 'gemini', 'claude']) + 2) 'input string' + 3) [{chat history}] in the following shape: + [{'gpt-4': "Hello, how can I help you?"}, + {'user': "What do prisons and plants have in common?"} + etc.]""") + + # Arguments + llm_list = sys.argv[1] + llm_list = list(llm_list.replace('[', '').replace(']', '').replace("'", '').split(',')) + if not llm_list: + llm_list = ['gpt-4'] + #print(llm_list) + input_string = sys.argv[2] + #print(input_string) + message_history = sys.argv[3] + #print(message_history) + message_history = message_history.split(';;') + #print(message_history) + message_history = [json.loads(substring.replace("'", '"')) for substring in message_history] + #print(message_history) + + load_dotenv() + + # to be put into seperate function in order to invoke LLMs seperately + openai_api_key = os.environ.get('OPENAI_API_KEY') + GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY') + ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY") + + test_llm_list = ['gpt-4'] + #llm_list = test_llm_list + test_history = [ + {'gpt-4': "Hello, how can I help you?", + 'gemini': "Hello, how can I help you?"}, + {'user': "What do prisons and plants have in common?"}, + {'gpt-4': "They both have cell walls.", + 'gemini': "They have cell walls."}, + ] + # message_history = test_history + + test_query = "Ah, true. Thanks. What else do they have in common?" + # test_query = "How many corners does a heptagon have?" + # input_string = test_query + # test_follow_up = "How does one call a polygon with two more corners?" + + # AstraDB Section + ASTRA_DB_API_ENDPOINT = os.environ.get('ASTRA_DB_API_ENDPOINT') + ASTRA_DB_APPLICATION_TOKEN = os.environ.get('ASTRA_DB_APPLICATION_TOKEN') + ASTRA_DB_NAMESPACE = 'test' + ASTRA_DB_COLLECTION = 'test_collection_2' + + # LangChain Docs: ------------------------- + vstore = AstraDBVectorStore( + embedding=OpenAIEmbeddings(openai_api_key=openai_api_key), + collection_name=ASTRA_DB_COLLECTION, + api_endpoint=ASTRA_DB_API_ENDPOINT, + token=ASTRA_DB_APPLICATION_TOKEN, + namespace=ASTRA_DB_NAMESPACE, + ) + # ------------------------------------------ + + # For test purposes: ----------------------- + # import bs4 + # from langchain_chroma import Chroma + # from langchain_community.document_loaders import WebBaseLoader + + # loader = WebBaseLoader( + # web_paths=("https://lilianweng.github.io/posts/2023-06-23-agent/",), + # bs_kwargs=dict( + # parse_only=bs4.SoupStrainer( + # class_=("post-content", "post-title", "post-header") + # ) + # ), + # ) + # docs = loader.load() + + # text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200) + # splits = text_splitter.split_documents(docs) + # vectorstore = Chroma.from_documents(documents=splits, embedding=OpenAIEmbeddings()) + # retriever = vectorstore.as_retriever() + # # test end ---------------------------------- + + retriever = vstore.as_retriever(search_kwargs={"k": 3}) + + contextualize_q_system_prompt = """Given a chat history and the latest user question \ + which might reference context in the chat history, formulate a standalone question \ + which can be understood without the chat history. Do NOT answer the question, \ + just reformulate it if needed and otherwise return it as is.""" + + qa_system_prompt = """You are an assistant for question-answering tasks. \ + Use the following pieces of retrieved context to answer the question. \ + If you don't know the answer, just say that you don't know. \ + Use three sentences maximum and keep the answer concise.\ + + {context}""" + + answers = {} + for _llm in llm_list: + #print(_llm) + chat_history = custom_history(message_history, _llm) + if _llm == 'gpt-4': + llm = OpenAI(temperature=0.2) + elif _llm == 'gemini': + llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest") + elif _llm == 'claude': + llm = ChatAnthropic(model_name="claude-3-opus-20240229") + + print(chat_history) + contextualize_q_prompt = ChatPromptTemplate.from_messages( + [ + ("system", contextualize_q_system_prompt), + MessagesPlaceholder("chat_history"), + ("human", "{input}"), + ] + ) + history_aware_retriever = create_history_aware_retriever( + llm, retriever, contextualize_q_prompt + ) + + qa_prompt = ChatPromptTemplate.from_messages( + [ + ("system", qa_system_prompt), + MessagesPlaceholder("chat_history"), + ("human", "{input}"), + ] + ) + question_answer_chain = create_stuff_documents_chain(llm, qa_prompt) + ### Answer question ### + rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain) + msg = rag_chain.invoke({"input": input_string, "chat_history": chat_history}) + answers[_llm] = msg['answer'] + print(msg['answer']) + #print(answers) + + # chat_history.extend([HumanMessage(content=input_string), AIMessage(content=msg_1["answer"])]) + # print(msg_1['input']) + # print(msg_1['answer']) + # print(chat_history) + # msg_2 = rag_chain.invoke({"input": test_follow_up, "chat_history": chat_history}) + # chat_history.extend([HumanMessage(content=test_follow_up), AIMessage(content=msg_2["answer"])]) + # print(msg_2['input']) + # print(msg_2['answer']) + # print(chat_history) + return answers + +if __name__ == "__main__": + main() + \ No newline at end of file