From 07350a90d65bb296c6abdef4b4a8144e38cef7ab Mon Sep 17 00:00:00 2001
From: Wannabeasmartguy <997139385@qq.com>
Date: Sat, 21 Oct 2023 15:59:51 +0800
Subject: [PATCH] Now, before paying, estimate the embedding cost.

1. It is now possible to view the estimated cost of a document before
Embedding it into the knowledge base.
2. Enhance stability
---
 GPT-Gradio-Agent.py  | 31 +++++++++++++++++++------------
 README.md            |  2 +-
 README_zh-cn.md      |  2 +-
 vecstore/vecstore.py | 34 +++++++++++++++++++++++++++++++++-
 4 files changed, 54 insertions(+), 15 deletions(-)

diff --git a/GPT-Gradio-Agent.py b/GPT-Gradio-Agent.py
index 7d25970..762053d 100644
--- a/GPT-Gradio-Agent.py
+++ b/GPT-Gradio-Agent.py
@@ -127,18 +127,15 @@ def upload_file(file_obj,
         loader = UnstructuredFileLoader(file_obj.name)
         document = loader.load()
         progress(0.3, desc="Loading the file...")
-    except FileNotFoundError:
-        raise gr.Error("File upload failed. Please try again.")
+    except (FileNotFoundError,PDFInfoNotInstalledError):
+        raise gr.Error("File upload failed. This may be due to formatting issues (non-standard formats)")
 
     # initialize splitter
-    try:
-        text_splitter = CharacterTextSplitter(chunk_size=150, chunk_overlap=10)
-        split_docs = text_splitter.split_documents(document)
-        split_tmp.append(split_docs)
-        progress(1, desc="Dealing...")
-        gr.Info("Processing completed.")
-    except (PDFInfoNotInstalledError,FileNotFoundError):
-        raise gr.Error("PDF dealing error.This may be due to formatting issues (non-standard formats)")
+    text_splitter = CharacterTextSplitter(chunk_size=150, chunk_overlap=10)
+    split_docs = text_splitter.split_documents(document)
+    split_tmp.append(split_docs)
+    progress(1, desc="Dealing...")
+    gr.Info("Processing completed.")
 
     return split_tmp,gr.File(label="The file you want to chat with")
 
@@ -229,7 +226,15 @@ def rst_mem(chat_his:list):
     file_answer = gr.State(['0'])
 
     with gr.Column():
-        file = gr.File(label="The file you want to chat with")
+        with gr.Group():
+            file = gr.File(label="The file you want to chat with")
+            with gr.Row():
+                estimate_cost = gr.Text(label="Estimated cost:",
+                                        info="Estimated cost of embed file",
+                                        scale=2)
+                refresh_file_cost = gr.Button(value="Refresh file and estimate cost",
+                                              scale=1)
+        with gr.Group():
             vector_path = gr.Text(label="Knowledge base save path",
                                   info="Choose the folder you want to save, and PASTE THE ABSOLUTE PATH here")
 
@@ -272,7 +277,9 @@ def rst_mem(chat_his:list):
         send.click(lambda: gr.update(value=''), [],[message])
 
         # chat_file button event
-        file.upload(upload_file,inputs=[file,split_tmp],outputs=[split_tmp,file],show_progress="full")
+        file.upload(upload_file,inputs=[file,split_tmp],outputs=[split_tmp,file],show_progress="full").then(cal_token_cost,[split_tmp],[estimate_cost])
+        file.clear(lambda:gr.update(value=''),[],[estimate_cost])
+        refresh_file_cost.click(lambda:gr.Text("预计消耗费用:To be calculated"),[],[estimate_cost]).then(lambda:gr.File(),[],[file]).then(lambda:gr.Text(),[],[estimate_cost])
 
         chat_with_file.click(ask_file,inputs=[chat_bot,message,file_answer,model_choice,sum_type,vector_path,file_list,filter_choice],outputs=[chat_bot,file_answer]).then(file_ask_stream,[chat_bot,file_answer],[chat_bot])
         summarize.click(summarize_file,inputs=[split_tmp,chat_bot,model_choice,sum_type],outputs=[sum_result,chat_bot]).then(sum_stream,[sum_result,chat_bot],[chat_bot])
diff --git a/README.md b/README.md
index 303f990..589c452 100644
--- a/README.md
+++ b/README.md
@@ -54,7 +54,7 @@ Then use `pip install -r requirements.txt` on the Command Prompt to install the
 
 - [ ] List citation sources
 
-- [ ] Estimated cost of embedding files
+- [x] Estimated cost of embedding files
 
 - [ ] Import and export chat history
 
diff --git a/README_zh-cn.md b/README_zh-cn.md
index 5218304..27dc44f 100644
--- a/README_zh-cn.md
+++ b/README_zh-cn.md
@@ -49,7 +49,7 @@
 
 - [ ] 显示引用来源
 
-- [ ] 预估嵌入文件的费用
+- [x] 预估嵌入文件的费用
 
 - [ ] 聊天记录导入、导出
 
diff --git a/vecstore/vecstore.py b/vecstore/vecstore.py
index 78b33c9..24c3af8 100644
--- a/vecstore/vecstore.py
+++ b/vecstore/vecstore.py
@@ -5,6 +5,7 @@ from langchain.chains import RetrievalQA
 import gradio as gr
 import pandas as pd
+import tiktoken
 
 def _init():
     global vec_store
 
@@ -224,4 +225,35 @@ def find_source_paths(filename:str, data:dict):
         source = metadata.get('source')
         if source and filename in source and source not in paths:
             paths.append(source)
-    return paths
\ No newline at end of file
+    return paths
+
+def calculate_and_display_token_count(input_text:str,model_name:str):
+    '''
+    Calculate the token that embedding needs to be consumed, for being called
+    '''
+    # model name or encode type
+    encoder = tiktoken.encoding_for_model(model_name) # model name
+    # encoder = tiktoken.get_encoding("cl100k_base") # encode type
+
+    encoded_text = encoder.encode(input_text)
+    token_count = len(encoded_text)
+    pay_for_token = (token_count/1000) * 0.002
+
+    # print(f"输入的文本: '{input_text}'")
+    # print(f"对应的编码: {encoded_text}")
+    # print(f"Token数量: {token_count}")
+    # print("预计消耗费用: $ %0.5f\n"%pay_for_token)
+    return pay_for_token
+
+def cal_token_cost(split_docs,model_name="text-embedding-ada-002"):
+    '''
+    Calculate the token that embedding needs to be consumed, for operation
+    '''
+    cost = 0
+    try:
+        for i in split_docs[-1]:
+            paid_per_doc = calculate_and_display_token_count(input_text=i.page_content,model_name=model_name)
+            cost += paid_per_doc
+        return gr.Text("预计消耗费用: $ %0.5f"%cost)
+    except AttributeError:
+        raise gr.Error("Cost calculating failed")
\ No newline at end of file