diff --git a/GPT-Gradio-Agent.py b/GPT-Gradio-Agent.py index 4ac5ca6..7d25970 100644 --- a/GPT-Gradio-Agent.py +++ b/GPT-Gradio-Agent.py @@ -13,7 +13,6 @@ from langchain.chat_models import AzureChatOpenAI from langchain.document_loaders import DirectoryLoader,PyPDFLoader,UnstructuredFileLoader from langchain.chains import RetrievalQA -from langchain.chains.summarize import load_summarize_chain load_dotenv() @@ -158,21 +157,6 @@ def file_ask_stream(file_ask_history_list:list[list],file_answer:list): time.sleep(0.02) yield file_ask_history_list -def summarize_file(split_docs,chatbot,model_choice,sum_type): - llm = AzureChatOpenAI(model=model_choice, - openai_api_type="azure", - deployment_name=model_choice, # <----------设置选择模型的时候修改这里 - temperature=0.7) - # 创建总结链 - chain = load_summarize_chain(llm, chain_type=sum_type, verbose=True) - - # 执行总结链 - summarize_result = chain.run(split_docs[-1]) - - # 构造 chatbox 格式 - chatbot.append(["Please summarize the file for me.",None]) - return summarize_result,chatbot - def sum_stream(summarize_result,chatbot): ''' Used to make summarized result be outputed as stream. @@ -201,7 +185,7 @@ def rst_mem(chat_his:list): usr_msg = gr.State() chat_his = gr.State([]) with gr.Row(): - with gr.Column(scale=1.8): + with gr.Column(scale=2): model_choice = gr.Radio(choices=["gpt-35-turbo","gpt-35-turbo-16k","gpt-4"], value="gpt-35-turbo", label="Model",info="支持模型选择,立即生效") @@ -210,9 +194,12 @@ def rst_mem(chat_his:list): bubble_full_width=False) message = gr.Textbox(label="Input your prompt", info="'Shift + Enter' to begin an new line. 
Press 'Enter' can also send your Prompt to the LLM.") - with gr.Row(scale=0.1): + with gr.Row(): clear = gr.ClearButton([message, chat_bot,chat_his],scale=1,size="sm") send = gr.Button("Send",scale=2) + with gr.Row(): + chat_with_file = gr.Button(value="Chat with file (Valid for knowledge base)") + summarize = gr.Button(value="Summarize (Valid only for uploaded file)") with gr.Column(): with gr.Tab("Chat"): @@ -247,9 +234,10 @@ def rst_mem(chat_his:list): vector_path = gr.Text(label="Knowledge base save path", info="Choose the folder you want to save, and PASTE THE ABSOLUTE PATH here") with gr.Row(): - vector_content = gr.DataFrame(label="Knowledge Base Document Catalog", - interactive=False, - ) + vector_content = gr.DataFrame(#label="Knowledge Base Document Catalog", + value = pd.DataFrame(columns=['文件名称']), + interactive=False, + ) file_list = gr.Dropdown(interactive=True, # allow_custom_value=True, label="File list") @@ -257,15 +245,17 @@ def rst_mem(chat_his:list): create_vec_but = gr.Button(value="Create a new knowledge base") load_vec = gr.Button(value="Load your knowledge base") with gr.Row(): - add_file = gr.Button(value="Add it(The file uploaded) to knowledge base") - delete_file = gr.Button(value="Delete it(selected in dropdown) from knowledge base") - sum_type = gr.Radio(choices=[("小文件(file with few words)","stuff"),("大文件(file with a large word count)","refine")], - value="stuff", - label="Choose the type of file to be summarized", - info="如果待总结字数较多,请选择“大文件”(选小文件可能导致超出 GPT 的最大 Token )") - with gr.Row(): - chat_with_file = gr.Button(value="Chat with file") - summarize = gr.Button(value="Summarize") + add_file = gr.Button(value="Add it (The file uploaded) to knowledge base") + delete_file = gr.Button(value="Delete it (Selected in dropdown) from knowledge base") + with gr.Accordion("File chat setting"): + filter_choice = gr.Radio(choices=["All", "Selected file"], + value="All", + label="Search scope", + info="“All” means whole knowledge base;“Selected 
file” means the file selected in dropdown") + sum_type = gr.Radio(choices=[("small file","stuff"),("large file","refine")], + value="stuff", + label="File size type", + info="也作用于“Summarize”。如果待总结字数较多,请选择“large size”(选“small size”可能导致超出 GPT 的最大 Token )") # Merge all handles that require input and output. input_param = [message, model_choice, chat_his, chat_bot, System_Prompt, @@ -283,7 +273,7 @@ def rst_mem(chat_his:list): # chat_file button event file.upload(upload_file,inputs=[file,split_tmp],outputs=[split_tmp,file],show_progress="full") - chat_with_file.click(ask_file,inputs=[chat_bot,message,file_answer,model_choice,sum_type,vector_path,file_list],outputs=[chat_bot,file_answer]).then(file_ask_stream,[chat_bot,file_answer],[chat_bot]) + chat_with_file.click(ask_file,inputs=[chat_bot,message,file_answer,model_choice,sum_type,vector_path,file_list,filter_choice],outputs=[chat_bot,file_answer]).then(file_ask_stream,[chat_bot,file_answer],[chat_bot]) summarize.click(summarize_file,inputs=[split_tmp,chat_bot,model_choice,sum_type],outputs=[sum_result,chat_bot]).then(sum_stream,[sum_result,chat_bot],[chat_bot]) chat_with_file.click(lambda: gr.update(value=''), [],[message]) diff --git a/README.md b/README.md index e87d903..303f990 100644 --- a/README.md +++ b/README.md @@ -50,9 +50,11 @@ Then use `pip install -r requirements.txt` on the Command Prompt to install the - [x] Local knowledge base management - - [ ] Chat with whole knowledge base - -- [ ] Local storage of data + - [x] Chat with whole knowledge base + + - [ ] List citation sources + + - [ ] Estimated cost of embedding files - [ ] Import and export chat history diff --git a/README_zh-cn.md b/README_zh-cn.md index f6214a4..5218304 100644 --- a/README_zh-cn.md +++ b/README_zh-cn.md @@ -43,13 +43,13 @@ - [x] 文件全文总结 - - [x] 知识库本地存储 - - [x] 知识库本地管理 - - [ ] 多文件对话 - -- [ ] 数据本地存储 + - [x] 知识库全局检索与对话 + + - [ ] 显示引用来源 + + - [ ] 预估嵌入文件的费用 - [ ] 聊天记录导入、导出 diff --git a/vecstore/vecstore.py b/vecstore/vecstore.py index
456dfc6..78b33c9 100644 --- a/vecstore/vecstore.py +++ b/vecstore/vecstore.py @@ -1,4 +1,4 @@ - +from langchain.chains.summarize import load_summarize_chain from langchain.vectorstores import Chroma from langchain.embeddings.openai import OpenAIEmbeddings from langchain.chat_models import AzureChatOpenAI @@ -148,26 +148,14 @@ def refresh_file_list(df): gr.Info('Successfully update kowledge base.') return gr.Dropdown.update(choices=file_list) -def find_source_paths(filenames, data): - ''' - Retrieve file paths in a vector database based on file name and remove duplicate paths - ''' - paths = [] - for metadata in data['metadatas']: - source = metadata.get('source') - if source: - for filename in filenames: - if filename in source and source not in paths: - paths.append(source) - return paths - def ask_file(file_ask_history_list:list, question_prompt: str, file_answer:list, model_choice:str, sum_type:str, persist_vec_path, - file_list + file_list, + filter_type:str, ): ''' send splitted file to LLM @@ -178,11 +166,11 @@ def ask_file(file_ask_history_list:list, temperature=0.7) source_data = vectorstore.get() - filter_goal = find_source_paths(filenames=file_list,data=source_data) + filter_goal = find_source_paths(file_list,source_data) if persist_vec_path != None: - # docsearch = Chroma.from_documents(split_docs[-1], embeddings) - if file_list == "Unselect file(s)" or file_list != None: + # Codes here in "if" may be deleted or modified later + if filter_type == "All": # unselect file: retrieve whole knowledge base try: qa = RetrievalQA.from_chain_type(llm=llm, chain_type=sum_type, @@ -191,7 +179,7 @@ def ask_file(file_ask_history_list:list, result = qa({"query": question_prompt}) except (NameError): raise gr.Error("You have not load kownledge base yet.") - else: + elif filter_type == "Selected file": # only selected one file # Retrieve the specified knowledge base with filter qa = RetrievalQA.from_chain_type(llm=llm, chain_type=sum_type, @@ -211,15 +199,29 @@ def 
ask_file(file_ask_history_list:list, file_ask_history_list.append([usr_prob,None]) return file_ask_history_list,file_answer -def find_source_paths(filenames:list, data): +def summarize_file(split_docs,chatbot,model_choice,sum_type): + llm = AzureChatOpenAI(model=model_choice, + openai_api_type="azure", + deployment_name=model_choice, # <----------设置选择模型的时候修改这里 + temperature=0.7) + # 创建总结链 + chain = load_summarize_chain(llm, chain_type=sum_type, verbose=True) + + # 执行总结链 + summarize_result = chain.run(split_docs[-1]) + + # 构造 chatbox 格式 + chatbot.append(["Please summarize the file for me.",None]) + return summarize_result,chatbot + +def find_source_paths(filename:str, data:dict): """ Find the source paths of the files in the knowledge base. + return --> list """ paths = [] for metadata in data['metadatas']: source = metadata.get('source') - if source: - for filename in filenames: - if filename in source and source not in paths: - paths.append(source) + if source and filename in source and source not in paths: + paths.append(source) return paths \ No newline at end of file