Commit
Estimate the embedding cost before paying.
1. It is now possible to view the estimated cost of a document before embedding it into the knowledge base.
2. Improve stability.
Wannabeasmartguy committed Oct 21, 2023
1 parent a426764 commit 07350a9
Showing 4 changed files with 54 additions and 15 deletions.
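In essence, the new estimate runs the split document chunks through tiktoken and multiplies the token count by the text-embedding-ada-002 price that the commit hard-codes ($0.002 per 1K tokens). A minimal standalone sketch of that calculation, under those assumptions (the helper name below is illustrative, not the one added in vecstore.py):

```python
# Minimal sketch of the cost estimate this commit wires up.
# Assumes the $0.002 per 1K tokens rate for text-embedding-ada-002 hard-coded in vecstore.py below.
import tiktoken

def estimate_embedding_cost(texts, model_name="text-embedding-ada-002"):
    encoder = tiktoken.encoding_for_model(model_name)   # tokenizer matching the embedding model
    total_tokens = sum(len(encoder.encode(t)) for t in texts)
    return (total_tokens / 1000) * 0.002                # USD, at the assumed per-1K-token rate

# Example: cost for two small chunks of text
print("Estimated cost: $%.5f" % estimate_embedding_cost(["first chunk", "second chunk"]))
```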
31 changes: 19 additions & 12 deletions GPT-Gradio-Agent.py
@@ -127,18 +127,15 @@ def upload_file(file_obj,
         loader = UnstructuredFileLoader(file_obj.name)
         document = loader.load()
         progress(0.3, desc="Loading the file...")
-    except FileNotFoundError:
-        raise gr.Error("File upload failed. Please try again.")
+    except (FileNotFoundError,PDFInfoNotInstalledError):
+        raise gr.Error("File upload failed. This may be due to formatting issues (non-standard formats)")
 
     # initialize splitter
-    try:
-        text_splitter = CharacterTextSplitter(chunk_size=150, chunk_overlap=10)
-        split_docs = text_splitter.split_documents(document)
-        split_tmp.append(split_docs)
-        progress(1, desc="Dealing...")
-        gr.Info("Processing completed.")
-    except (PDFInfoNotInstalledError,FileNotFoundError):
-        raise gr.Error("PDF dealing error.This may be due to formatting issues (non-standard formats)")
+    text_splitter = CharacterTextSplitter(chunk_size=150, chunk_overlap=10)
+    split_docs = text_splitter.split_documents(document)
+    split_tmp.append(split_docs)
+    progress(1, desc="Dealing...")
+    gr.Info("Processing completed.")
 
     return split_tmp,gr.File(label="The file you want to chat with")

@@ -229,7 +226,15 @@ def rst_mem(chat_his:list):
         file_answer = gr.State(['0'])
 
         with gr.Column():
-            file = gr.File(label="The file you want to chat with")
+            with gr.Group():
+                file = gr.File(label="The file you want to chat with")
+                with gr.Row():
+                    estimate_cost = gr.Text(label="Estimated cost:",
+                                            info="Estimated cost of embed file",
+                                            scale=2)
+                    refresh_file_cost = gr.Button(value="Refresh file and estimate cost",
+                                                  scale=1)
+
             with gr.Group():
                 vector_path = gr.Text(label="Knowledge base save path",
                                       info="Choose the folder you want to save, and PASTE THE ABSOLUTE PATH here")
@@ -272,7 +277,9 @@ def rst_mem(chat_his:list):
         send.click(lambda: gr.update(value=''), [],[message])
 
         # chat_file button event
-        file.upload(upload_file,inputs=[file,split_tmp],outputs=[split_tmp,file],show_progress="full")
+        file.upload(upload_file,inputs=[file,split_tmp],outputs=[split_tmp,file],show_progress="full").then(cal_token_cost,[split_tmp],[estimate_cost])
+        file.clear(lambda:gr.update(value=''),[],[estimate_cost])
+        refresh_file_cost.click(lambda:gr.Text("预计消耗费用:To be calculated"),[],[estimate_cost]).then(lambda:gr.File(),[],[file]).then(lambda:gr.Text(),[],[estimate_cost])
         chat_with_file.click(ask_file,inputs=[chat_bot,message,file_answer,model_choice,sum_type,vector_path,file_list,filter_choice],outputs=[chat_bot,file_answer]).then(file_ask_stream,[chat_bot,file_answer],[chat_bot])
         summarize.click(summarize_file,inputs=[split_tmp,chat_bot,model_choice,sum_type],outputs=[sum_result,chat_bot]).then(sum_stream,[sum_result,chat_bot],[chat_bot])

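The wiring above relies on Gradio's event chaining: `.then()` fires only after the previous handler has returned, so the cost estimate is computed from `split_tmp` right after the upload finishes. A stripped-down sketch of that pattern with placeholder handlers (these names are illustrative, not the app's real components):

```python
# Stripped-down sketch of the upload -> cost-estimate chaining used above.
# Handler and component names here are placeholders, not the app's real ones.
import gradio as gr

def handle_upload(chunks):                  # stand-in for upload_file
    return chunks + ["chunked text from the uploaded file"]

def show_cost(chunks):                      # stand-in for cal_token_cost
    return "Estimated cost: $%.5f" % (len(chunks) * 0.00001)

with gr.Blocks() as demo:
    f = gr.File()
    cost = gr.Text(label="Estimated cost:")
    state = gr.State([])
    # the cost estimate runs only after the upload handler has updated the state
    f.upload(handle_upload, inputs=[state], outputs=[state]).then(show_cost, [state], [cost])
    f.clear(lambda: gr.update(value=''), [], [cost])

# demo.launch()
```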
2 changes: 1 addition & 1 deletion README.md
@@ -54,7 +54,7 @@ Then use `pip install -r requirements.txt` on the Command Prompt to install the
 
 - [ ] List citation sources
 
-- [ ] Estimated cost of embedding files
+- [x] Estimated cost of embedding files
 
 - [ ] Import and export chat history

2 changes: 1 addition & 1 deletion README_zh-cn.md
@@ -49,7 +49,7 @@
 
 - [ ] 显示引用来源
 
-- [ ] 预估嵌入文件的费用
+- [x] 预估嵌入文件的费用
 
 - [ ] 聊天记录导入、导出

34 changes: 33 additions & 1 deletion vecstore/vecstore.py
@@ -5,6 +5,7 @@
 from langchain.chains import RetrievalQA
 import gradio as gr
 import pandas as pd
+import tiktoken
 
 def _init():
     global vec_store
@@ -224,4 +225,35 @@ def find_source_paths(filename:str, data:dict):
         source = metadata.get('source')
         if source and filename in source and source not in paths:
             paths.append(source)
-    return paths
+    return paths
+
+def calculate_and_display_token_count(input_text:str,model_name:str):
+    '''
+    Calculate the token that embedding needs to be consumed, for being called
+    '''
+    # model name or encode type
+    encoder = tiktoken.encoding_for_model(model_name) # model name
+    # encoder = tiktoken.get_encoding("cl100k_base") # encode type
+
+    encoded_text = encoder.encode(input_text)
+    token_count = len(encoded_text)
+    pay_for_token = (token_count/1000) * 0.002
+
+    # print(f"输入的文本: '{input_text}'")
+    # print(f"对应的编码: {encoded_text}")
+    # print(f"Token数量: {token_count}")
+    # print("预计消耗费用: $ %0.5f\n"%pay_for_token)
+    return pay_for_token
+
+def cal_token_cost(split_docs,model_name="text-embedding-ada-002"):
+    '''
+    Calculate the token that embedding needs to be consumed, for operation
+    '''
+    cost = 0
+    try:
+        for i in split_docs[-1]:
+            paid_per_doc = calculate_and_display_token_count(input_text=i.page_content,model_name=model_name)
+            cost += paid_per_doc
+        return gr.Text("预计消耗费用: $ %0.5f"%cost)
+    except AttributeError:
+        raise gr.Error("Cost calculating failed")
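For reference, a rough usage sketch of the new helpers: `cal_token_cost` reads `.page_content` from each chunk in the last element of `split_tmp` and returns a `gr.Text` whose value reads roughly "预计消耗费用: $ 0.00001" ("estimated cost"). The `Document` stub below only mimics that attribute and is not LangChain's real class; the sketch assumes the two functions above are in scope.

```python
# Rough usage sketch of cal_token_cost, assuming the functions above are importable.
# The Document stub only mimics the .page_content attribute that cal_token_cost reads.
from dataclasses import dataclass

@dataclass
class Document:
    page_content: str

split_tmp = [[Document("first chunk of text"), Document("second chunk of text")]]

cost_box = cal_token_cost(split_tmp)   # model_name defaults to text-embedding-ada-002
print(cost_box.value)                  # e.g. "预计消耗费用: $ 0.00001"
```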
