Skip to content

Commit

Permalink
History aware Retriever and multiple LLM answers and basic in/out
Browse files Browse the repository at this point in the history
Signed-off-by: potthoffjan <[email protected]>
  • Loading branch information
potthoffjan committed Jul 10, 2024
1 parent a7253d4 commit 2add1c1
Showing 1 changed file with 209 additions and 0 deletions.
209 changes: 209 additions & 0 deletions src/backend/RAG/LangChain_Implementation/chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
import json
import os
import sys


from dotenv import load_dotenv

from astrapy import DataAPIClient
from astrapy.db import AstraDB
from langchain_astradb import AstraDBVectorStore

#from langchain.embeddings import OpenAIEmbeddings
#from langchain_community.embeddings import OpenAIEmbeddings
from langchain_openai import OpenAIEmbeddings
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
#from langchain_community.embeddings import HuggingFaceEmbeddings
#from langchain_huggingface import HuggingFaceEmbeddings
#from langchain_community.llms.openai import OpenAI
from langchain_openai import OpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.vectorstores.chroma import Chroma

from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_history_aware_retriever
from langchain_core.prompts import MessagesPlaceholder
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain

from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.messages import HumanMessage, AIMessage
from langchain_anthropic import AnthropicLLM
from langchain_anthropic import ChatAnthropic

"""
Script takes in args:
0) list of LLMs with which to retrieve e.g. ['gpt-4', 'gemini', 'mistral']
1) input string
2) chat history in following shape [
{'gpt-4': "Hello, how can I help you?"},
{'user': "What do prisons and plants have in common?"}
]
"""

def custom_history(entire_history:list, llm_name:str):
chat_history = []
for msg in entire_history:
if 'user' in msg:
chat_history.extend([HumanMessage(content=msg['user'])])
if llm_name in msg:
chat_history.extend([AIMessage(content=msg[llm_name])])
return chat_history


def main():

if len(sys.argv) < 3:
print("""Error: Please provide:
1) [list of LLM models to use]
(['gpt-4', 'gemini', 'claude'])
2) 'input string'
3) [{chat history}] in the following shape:
[{'gpt-4': "Hello, how can I help you?"},
{'user': "What do prisons and plants have in common?"}
etc.]""")

# Arguments
llm_list = sys.argv[1]
llm_list = list(llm_list.replace('[', '').replace(']', '').replace("'", '').split(','))
if not llm_list:
llm_list = ['gpt-4']
#print(llm_list)
input_string = sys.argv[2]
#print(input_string)
message_history = sys.argv[3]
#print(message_history)
message_history = message_history.split(';;')
#print(message_history)
message_history = [json.loads(substring.replace("'", '"')) for substring in message_history]
#print(message_history)

load_dotenv()

# to be put into seperate function in order to invoke LLMs seperately
openai_api_key = os.environ.get('OPENAI_API_KEY')
GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY")

test_llm_list = ['gpt-4']
#llm_list = test_llm_list
test_history = [
{'gpt-4': "Hello, how can I help you?",
'gemini': "Hello, how can I help you?"},
{'user': "What do prisons and plants have in common?"},
{'gpt-4': "They both have cell walls.",
'gemini': "They have cell walls."},
]
# message_history = test_history

test_query = "Ah, true. Thanks. What else do they have in common?"
# test_query = "How many corners does a heptagon have?"
# input_string = test_query
# test_follow_up = "How does one call a polygon with two more corners?"

# AstraDB Section
ASTRA_DB_API_ENDPOINT = os.environ.get('ASTRA_DB_API_ENDPOINT')
ASTRA_DB_APPLICATION_TOKEN = os.environ.get('ASTRA_DB_APPLICATION_TOKEN')
ASTRA_DB_NAMESPACE = 'test'
ASTRA_DB_COLLECTION = 'test_collection_2'

# LangChain Docs: -------------------------
vstore = AstraDBVectorStore(
embedding=OpenAIEmbeddings(openai_api_key=openai_api_key),
collection_name=ASTRA_DB_COLLECTION,
api_endpoint=ASTRA_DB_API_ENDPOINT,
token=ASTRA_DB_APPLICATION_TOKEN,
namespace=ASTRA_DB_NAMESPACE,
)
# ------------------------------------------

# For test purposes: -----------------------
# import bs4
# from langchain_chroma import Chroma
# from langchain_community.document_loaders import WebBaseLoader

# loader = WebBaseLoader(
# web_paths=("https://lilianweng.github.io/posts/2023-06-23-agent/",),
# bs_kwargs=dict(
# parse_only=bs4.SoupStrainer(
# class_=("post-content", "post-title", "post-header")
# )
# ),
# )
# docs = loader.load()

# text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
# splits = text_splitter.split_documents(docs)
# vectorstore = Chroma.from_documents(documents=splits, embedding=OpenAIEmbeddings())
# retriever = vectorstore.as_retriever()
# # test end ----------------------------------

retriever = vstore.as_retriever(search_kwargs={"k": 3})

contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""

qa_system_prompt = """You are an assistant for question-answering tasks. \
Use the following pieces of retrieved context to answer the question. \
If you don't know the answer, just say that you don't know. \
Use three sentences maximum and keep the answer concise.\
{context}"""

answers = {}
for _llm in llm_list:
#print(_llm)
chat_history = custom_history(message_history, _llm)
if _llm == 'gpt-4':
llm = OpenAI(temperature=0.2)
elif _llm == 'gemini':
llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest")
elif _llm == 'claude':
llm = ChatAnthropic(model_name="claude-3-opus-20240229")

print(chat_history)
contextualize_q_prompt = ChatPromptTemplate.from_messages(
[
("system", contextualize_q_system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
]
)
history_aware_retriever = create_history_aware_retriever(
llm, retriever, contextualize_q_prompt
)

qa_prompt = ChatPromptTemplate.from_messages(
[
("system", qa_system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
]
)
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
### Answer question ###
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
msg = rag_chain.invoke({"input": input_string, "chat_history": chat_history})
answers[_llm] = msg['answer']
print(msg['answer'])
#print(answers)

# chat_history.extend([HumanMessage(content=input_string), AIMessage(content=msg_1["answer"])])
# print(msg_1['input'])
# print(msg_1['answer'])
# print(chat_history)
# msg_2 = rag_chain.invoke({"input": test_follow_up, "chat_history": chat_history})
# chat_history.extend([HumanMessage(content=test_follow_up), AIMessage(content=msg_2["answer"])])
# print(msg_2['input'])
# print(msg_2['answer'])
# print(chat_history)
return answers

if __name__ == "__main__":
main()

0 comments on commit 2add1c1

Please sign in to comment.