generated from amosproj/amos202Xss0Y-projname
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
History aware Retriever and multiple LLM answers and basic in/out
Signed-off-by: potthoffjan <[email protected]>
- Loading branch information
1 parent
a7253d4
commit 2add1c1
Showing
1 changed file
with
209 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,209 @@ | ||
import json | ||
import os | ||
import sys | ||
|
||
|
||
from dotenv import load_dotenv | ||
|
||
from astrapy import DataAPIClient | ||
from astrapy.db import AstraDB | ||
from langchain_astradb import AstraDBVectorStore | ||
|
||
#from langchain.embeddings import OpenAIEmbeddings | ||
#from langchain_community.embeddings import OpenAIEmbeddings | ||
from langchain_openai import OpenAIEmbeddings | ||
from langchain.schema import Document | ||
from langchain.text_splitter import RecursiveCharacterTextSplitter | ||
#from langchain_community.embeddings import HuggingFaceEmbeddings | ||
#from langchain_huggingface import HuggingFaceEmbeddings | ||
#from langchain_community.llms.openai import OpenAI | ||
from langchain_openai import OpenAI | ||
from langchain_google_genai import ChatGoogleGenerativeAI | ||
from langchain_community.vectorstores.chroma import Chroma | ||
|
||
from langchain_core.prompts import ChatPromptTemplate | ||
from langchain.chains import create_history_aware_retriever | ||
from langchain_core.prompts import MessagesPlaceholder | ||
from langchain.chains import create_retrieval_chain | ||
from langchain.chains.combine_documents import create_stuff_documents_chain | ||
|
||
from langchain_community.chat_message_histories import ChatMessageHistory | ||
from langchain_core.chat_history import BaseChatMessageHistory | ||
from langchain_core.runnables.history import RunnableWithMessageHistory | ||
from langchain_core.messages import HumanMessage, AIMessage | ||
from langchain_anthropic import AnthropicLLM | ||
from langchain_anthropic import ChatAnthropic | ||
|
||
"""
Script takes the following command-line args:
1) list of LLMs with which to retrieve, e.g. ['gpt-4', 'gemini', 'claude']
2) input string
3) chat history: one dict per message, separated by ';;', e.g.
{'gpt-4': "Hello, how can I help you?"};;{'user': "What do prisons and plants have in common?"}
"""
|
||
def custom_history(entire_history:list, llm_name:str):
    """Extract one model's view of a shared multi-model chat history.

    Each entry of ``entire_history`` is a dict keyed by speaker. The value
    under ``'user'`` becomes a ``HumanMessage`` and the value under
    ``llm_name`` becomes an ``AIMessage``; keys belonging to other models
    are ignored. Entry order is preserved, and within a single entry the
    user message is emitted before the model message.

    Args:
        entire_history: List of ``{speaker: text}`` dicts.
        llm_name: Key identifying the model whose replies should be kept.

    Returns:
        A list of ``HumanMessage`` / ``AIMessage`` objects.
    """
    # (key to look up, message class to wrap its text in), in emit order.
    speakers = (('user', HumanMessage), (llm_name, AIMessage))
    messages = []
    for entry in entire_history:
        for key, message_type in speakers:
            if key in entry:
                messages.append(message_type(content=entry[key]))
    return messages
|
||
|
||
# Usage text printed when too few command-line arguments are supplied.
_USAGE = (
    "Error: Please provide:\n"
    "1) [list of LLM models to use]\n"
    "(['gpt-4', 'gemini', 'claude'])\n"
    "2) 'input string'\n"
    "3) [{chat history}] in the following shape:\n"
    "[{'gpt-4': \"Hello, how can I help you?\"},\n"
    "{'user': \"What do prisons and plants have in common?\"}\n"
    "etc.]"
)

# Model used when the LLM-list argument contains no usable name.
_DEFAULT_LLM = 'gpt-4'

# Characters stripped from the LLM-list argument before splitting on commas.
_LIST_PUNCTUATION = str.maketrans('', '', "[]'\"")

# Zero-argument factories so a model client is only built when requested.
# NOTE(review): `OpenAI` is constructed without a model name, so the 'gpt-4'
# entry may not actually talk to GPT-4 -- confirm whether
# `ChatOpenAI(model="gpt-4")` was intended.
_LLM_FACTORIES = {
    'gpt-4': lambda: OpenAI(temperature=0.2),
    'gemini': lambda: ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest"),
    'claude': lambda: ChatAnthropic(model_name="claude-3-opus-20240229"),
}

_CONTEXTUALIZE_Q_SYSTEM_PROMPT = (
    "Given a chat history and the latest user question "
    "which might reference context in the chat history, formulate a standalone question "
    "which can be understood without the chat history. Do NOT answer the question, "
    "just reformulate it if needed and otherwise return it as is."
)

# The retrieved context is separated from the instructions by a blank line;
# previously it was glued directly onto "concise.".
_QA_SYSTEM_PROMPT = (
    "You are an assistant for question-answering tasks. "
    "Use the following pieces of retrieved context to answer the question. "
    "If you don't know the answer, just say that you don't know. "
    "Use three sentences maximum and keep the answer concise."
    "\n\n{context}"
)


def _parse_llm_list(raw):
    """Turn "['gpt-4', 'gemini']" (or "gpt-4,gemini") into a list of names."""
    names = [part.strip() for part in raw.translate(_LIST_PUNCTUATION).split(',')]
    # Drop empty fragments so an empty argument really falls back to the default.
    return [name for name in names if name] or [_DEFAULT_LLM]


def _parse_literal(text):
    """Parse one JSON or Python-literal fragment of the history argument."""
    import ast  # local import: the file's import block does not provide it

    try:
        return json.loads(text)
    except ValueError:
        pass
    try:
        # Handles single-quoted keys without corrupting apostrophes in text.
        return ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        # Legacy behaviour: blindly swap quote style and retry as JSON
        # (covers mixed input such as {'flag': true}); errors propagate.
        return json.loads(text.replace("'", '"'))


def _parse_history(raw):
    """Parse the chat-history argument into a flat list of message dicts."""
    history = []
    for chunk in raw.split(';;'):
        chunk = chunk.strip()
        if not chunk:
            continue
        parsed = _parse_literal(chunk)
        # Accept both ';;'-separated dicts and a whole list literal.
        entries = parsed if isinstance(parsed, list) else [parsed]
        for entry in entries:
            if not isinstance(entry, dict):
                raise ValueError(
                    f"Chat history entries must be dicts, got: {entry!r}"
                )
            history.append(entry)
    return history


def main():
    """Answer one question with each requested LLM using history-aware RAG.

    Command-line arguments:
        sys.argv[1]: LLM names, e.g. "['gpt-4', 'gemini', 'claude']".
            Falls back to 'gpt-4' when no name is given.
        sys.argv[2]: The user's input string.
        sys.argv[3]: Optional chat history, either ';;'-separated dicts or
            a single list of dicts, each mapping a speaker ('user' or an LLM
            name) to the message text. Missing or empty means no history.

    For each supported model the function prints the model-specific chat
    history and the answer to stdout. Unsupported model names are reported
    on stderr and skipped.

    Returns:
        dict mapping each answered LLM name to its answer string.

    Raises:
        SystemExit: With status 1 when fewer than two arguments are given.
        ValueError: When the chat history cannot be parsed.
    """
    if len(sys.argv) < 3:
        print(_USAGE, file=sys.stderr)
        sys.exit(1)

    llm_list = _parse_llm_list(sys.argv[1])
    input_string = sys.argv[2]
    message_history = _parse_history(sys.argv[3]) if len(sys.argv) > 3 else []

    load_dotenv()

    # Only the OpenAI key is passed explicitly (to the embeddings). The other
    # model clients are constructed without explicit keys.
    openai_api_key = os.environ.get('OPENAI_API_KEY')

    # AstraDB Section
    ASTRA_DB_API_ENDPOINT = os.environ.get('ASTRA_DB_API_ENDPOINT')
    ASTRA_DB_APPLICATION_TOKEN = os.environ.get('ASTRA_DB_APPLICATION_TOKEN')
    ASTRA_DB_NAMESPACE = 'test'
    ASTRA_DB_COLLECTION = 'test_collection_2'

    vstore = AstraDBVectorStore(
        embedding=OpenAIEmbeddings(openai_api_key=openai_api_key),
        collection_name=ASTRA_DB_COLLECTION,
        api_endpoint=ASTRA_DB_API_ENDPOINT,
        token=ASTRA_DB_APPLICATION_TOKEN,
        namespace=ASTRA_DB_NAMESPACE,
    )
    retriever = vstore.as_retriever(search_kwargs={"k": 3})

    # Both prompt templates are model-independent, so build them once
    # instead of once per loop iteration.
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", _CONTEXTUALIZE_Q_SYSTEM_PROMPT),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", _QA_SYSTEM_PROMPT),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )

    answers = {}
    for llm_name in llm_list:
        factory = _LLM_FACTORIES.get(llm_name)
        if factory is None:
            # Previously an unknown name raised NameError or silently reused
            # the model from the preceding iteration.
            print(
                f"Warning: unsupported LLM '{llm_name}' skipped "
                f"(supported: {', '.join(_LLM_FACTORIES)})",
                file=sys.stderr,
            )
            continue
        llm = factory()

        chat_history = custom_history(message_history, llm_name)
        print(chat_history)

        history_aware_retriever = create_history_aware_retriever(
            llm, retriever, contextualize_q_prompt
        )
        question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
        rag_chain = create_retrieval_chain(
            history_aware_retriever, question_answer_chain
        )

        msg = rag_chain.invoke(
            {"input": input_string, "chat_history": chat_history}
        )
        answers[llm_name] = msg['answer']
        print(msg['answer'])

    return answers
|
||
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|