Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to upload files to GradioUI #138

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ wandb
# Data
data
outputs
data/

# Apple
.DS_Store
Expand Down
11 changes: 11 additions & 0 deletions examples/gradio_upload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from smolagents import (
CodeAgent,
HfApiModel,
GradioUI
)

agent = CodeAgent(
tools=[], model=HfApiModel(), max_steps=4, verbose=True
)

GradioUI(agent, UPLOAD_FOLDER='./data').launch()
58 changes: 56 additions & 2 deletions src/smolagents/gradio_ui.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -15,6 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gradio as gr
import shutil
import os
import mimetypes
import re

from .agents import ActionStep, AgentStep, MultiStepAgent
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
Expand Down Expand Up @@ -82,8 +85,12 @@ def stream_to_gradio(
class GradioUI:
"""A one-line interface to launch your agent in Gradio"""

def __init__(self, agent: MultiStepAgent):
def __init__(self, agent: MultiStepAgent, UPLOAD_FOLDER: str | None=None):
self.agent = agent
self.UPLOAD_FOLDER = UPLOAD_FOLDER
if self.UPLOAD_FOLDER is not None:
if not os.path.exists(UPLOAD_FOLDER):
os.mkdir(UPLOAD_FOLDER)

def interact_with_agent(self, prompt, messages):
messages.append(gr.ChatMessage(role="user", content=prompt))
Expand All @@ -93,6 +100,45 @@ def interact_with_agent(self, prompt, messages):
yield messages
yield messages

def upload_file(self, file, allowed_file_types=["application/pdf", "application/vnd.openxmlformats-officedocument.wordprocessingml.document", "text/plain"]):
"""
Handle file uploads, default allowed types are pdf, docx, and .txt
"""

# Check if file is uploaded
if file is None:
return "No file uploaded"

# Check if file is in allowed filetypes
name = os.path.basename(file.name)
try:
mime_type, _ = mimetypes.guess_type(file.name)
except Exception as e:
return f"Error: {e}"

if mime_type not in allowed_file_types:
return "File type disallowed"

# Sanitize file name
original_name = os.path.basename(file.name)
sanitized_name = re.sub(r'[^\w\-.]', '_', original_name) # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores

type_to_ext = {}
for ext, t in mimetypes.types_map.items():
if t not in type_to_ext:
type_to_ext[t] = ext

# Ensure the extension correlates to the mime type
sanitized_name = sanitized_name.split(".")[:-1]
sanitized_name.append("" + type_to_ext[mime_type])
sanitized_name = "".join(sanitized_name)

# Save the uploaded file to the specified folder
file_path = os.path.join(self.UPLOAD_FOLDER, os.path.basename(sanitized_name))
shutil.copy(file.name, file_path)

return f"File uploaded successfully to {self.UPLOAD_FOLDER}"

def launch(self):
with gr.Blocks() as demo:
stored_message = gr.State([])
Expand All @@ -104,6 +150,14 @@ def launch(self):
"https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
),
)
# If an upload folder is provided, enable the upload feature
if self.UPLOAD_FOLDER is not None:
upload_file = gr.File(label="Upload a file")
upload_status = gr.Textbox(label="Upload Status", interactive=False)

upload_file.change(
self.upload_file, [upload_file], [upload_status]
)
text_input = gr.Textbox(lines=1, label="Chat Message")
text_input.submit(
lambda s: (s, ""), [text_input], [stored_message, text_input]
Expand Down