-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathRetriever_Agent_youtube.py
86 lines (72 loc) · 2.47 KB
/
Retriever_Agent_youtube.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os

from langchain.agents import AgentType, initialize_agent
from langchain.chains import LLMChain, RetrievalQA
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.memory import ReadOnlySharedMemory
from langchain.prompts import PromptTemplate
from langchain.tools import Tool, YouTubeSearchTool
from langchain.vectorstores.neo4j_vector import Neo4jVector

from keys import openkey
# OpenAI credential, re-bound to the constant name every OpenAI client below uses.
OPENAI_API_KEY = openkey

# A single chat model instance, reused by the chat chain, the retrieval QA
# chain and the agent.
llm = ChatOpenAI(openai_api_key=OPENAI_API_KEY)

# YouTube search tool (exposed to the agent for finding trailers).
youtube = YouTubeSearchTool()

# Conversation history. The key `chat_history` matches the placeholder in the
# chat prompt; return_messages=True yields message objects instead of a
# single concatenated string.
memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True,
)
# Prompt for the free-form "chat about movies" tool. It is a plain string
# template, so the history supplied by `memory` (configured with
# return_messages=True) is interpolated through its str() form.
prompt = PromptTemplate(
    template="""
You are a movie expert. You find movies from a genre or plot.
ChatHistory:{chat_history}
Question:{input}
""",
    input_variables=["chat_history", "input"],
)

# The chain may READ the conversation history but must not WRITE to it.
# `memory` is also given to the agent, which already records each user turn
# and its final answer; a writable shared memory here would additionally log
# every chat-tool call, duplicating entries in `chat_history`.
chat_chain = LLMChain(
    llm=llm,
    prompt=prompt,
    memory=ReadOnlySharedMemory(memory=memory),
    verbose=True,
)
# Embedding model used to vectorise incoming queries. Presumably this must be
# the same model that produced the stored `embedding` node property — TODO confirm.
embedding_provider = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)

# Connection settings for the Neo4j instance that holds the `moviePlots`
# vector index; each can be overridden from the environment.
# NOTE(review): the fallback values are live credentials committed to source
# control. They should be rotated and supplied only via the environment.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://54.174.85.163:7687")
NEO4J_USERNAME = os.environ.get("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "rejections-gyroscopes-street")

# Attach to an already-built vector index: nodes carry their vector in
# `embedding` and the text returned to the retriever in `plot`.
movie_plot_vector = Neo4jVector.from_existing_index(
    embedding_provider,
    url=NEO4J_URI,
    username=NEO4J_USERNAME,
    password=NEO4J_PASSWORD,
    index_name="moviePlots",
    embedding_node_property="embedding",
    text_node_property="plot",
)
# Question-answering chain over the movie-plot index. The retrieved source
# documents are returned alongside the answer, which makes the chain
# multi-output (see run_retriever).
_plot_retriever = movie_plot_vector.as_retriever()
retrievalQA = RetrievalQA.from_llm(
    llm=llm,
    retriever=_plot_retriever,
    return_source_documents=True,
    verbose=True,
)
def run_retriever(query):
    """Answer *query* from the movie-plot vector index.

    Args:
        query: Free-text question or plot description passed in by the agent.

    Returns:
        The str() of the chain's full result dict (the answer together with
        the source documents). ``retrievalQA`` is built with
        ``return_source_documents=True``, so it has more than one output key
        and cannot be driven through ``.run()``; calling it and stringifying
        the dict gives the agent tool the single string it needs.
    """
    results = retrievalQA({"query": query})
    return str(results)
def _direct_tool(name, description, func):
    """Wrap *func* as an agent tool whose result is returned directly to the user."""
    return Tool.from_function(
        func=func,
        name=name,
        description=description,
        return_direct=True,
    )


# Tools the agent can choose between. All are return_direct, so whichever one
# the agent picks produces the final answer verbatim.
tools = [
    _direct_tool(
        "ChatOpenAI",
        "For when you need to chat about movies, genres or plots. The question will be a string. Return a string.",
        chat_chain.run,
    ),
    _direct_tool(
        "YouTubeSearchTool",
        "For when you need a link to a movie trailer. The question will be a string. Return a link to a YouTube video.",
        youtube.run,
    ),
    _direct_tool(
        "PlotRetrieval",
        "For when you need to compare a plot to a movie. The question will be a string. Return a string.",
        run_retriever,
    ),
]
# Conversational ReAct agent for chat models. It uses the module-level
# `memory` as its conversation history, and feeds output-parsing errors back
# to the model instead of raising them.
agent = initialize_agent(
    tools,
    llm,
    agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
    memory=memory,
    handle_parsing_errors=True,
    verbose=True,
)
def _repl():
    """Prompt for questions on stdin and print the agent's answers.

    Runs until end-of-input (Ctrl-D) or Ctrl-C at the prompt, both of which
    end the session cleanly instead of printing a traceback.
    """
    while True:
        try:
            q = input(">")
        except (EOFError, KeyboardInterrupt):
            # Move past the dangling prompt before exiting.
            print()
            break
        if not q.strip():
            # Nothing was asked; skip the LLM round-trip.
            continue
        print(agent.run(q))


# Only start the interactive loop when run as a script, so importing this
# module (e.g. to reuse `agent` or `tools`) does not block on stdin.
if __name__ == "__main__":
    _repl()