Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make it work in new Forge (Gradio 4) #215

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 32 additions & 14 deletions scripts/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
import gradio as gr
from collections import OrderedDict
from scipy.ndimage import binary_dilation
import json
from modules import scripts, shared, script_callbacks
from modules.ui import gr_show
from modules.ui_components import FormRow
from modules.ui_components import FormRow, ToolButton
from modules.safe import unsafe_torch_load, load
from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessing
from modules.devices import device, torch_gc, cpu
Expand All @@ -21,6 +22,18 @@
from scripts.auto import clear_sem_sam_cache, register_auto_sam, semantic_segmentation, sem_sam_garbage_collect, image_layer_internal, categorical_mask_image
from scripts.process_params import SAMProcessUnit, max_cn_num

import importlib.metadata

def get_gradio_version():
try:
version = importlib.metadata.version("gradio")
return version
except importlib.metadata.PackageNotFoundError:
return None

# Example usage
gradio_version = get_gradio_version()


refresh_symbol = '\U0001f504' # 🔄
sam_model_cache = OrderedDict()
Expand All @@ -35,16 +48,6 @@
txt2img_height: gr.Slider = None
img2img_width: gr.Slider = None
img2img_height: gr.Slider = None


class ToolButton(gr.Button, gr.components.FormComponent):
"""Small button with single emoji as text, fits inside gradio forms"""

def __init__(self, **kwargs):
super().__init__(variant="tool", **kwargs)

def get_block_name(self):
return "button"


def show_masks(image_np, masks: np.ndarray, alpha=0.5):
Expand All @@ -59,7 +62,15 @@ def show_masks(image_np, masks: np.ndarray, alpha=0.5):
def update_mask(mask_gallery, chosen_mask, dilation_amt, input_image):
print("Dilation Amount: ", dilation_amt)
if isinstance(mask_gallery, list):
mask_image = Image.open(mask_gallery[chosen_mask + 3]['name'])
image_data = mask_gallery[chosen_mask + 3]
# In Gradio 4 or above, this is an (image, caption) tuple
if isinstance(image_data, tuple) and len(image_data) == 2:
mask_image = Image.open(image_data[0])
# In Gradio 3 this is a dict with 'name', 'data', 'is_file' keys
elif isinstance(image_data, dict) and 'name' in image_data:
mask_image = Image.open(image_data['name'])
else:
raise TypeError("Cannot locate mask image while expanding mask")
else:
mask_image = mask_gallery
binary_img = np.array(mask_image.convert('1'))
Expand Down Expand Up @@ -183,9 +194,16 @@ def create_mask_batch_output(
output_blend.save(os.path.join(dino_batch_dest_dir, f"{filename}_{idx}_blend{ext}"))


def parse_points(points):
if isinstance(points, str):
points = json.loads(points) # Convert string representation of list back to list
return points

def sam_predict(sam_model_name, input_image, positive_points, negative_points,
dino_checkbox, dino_model_name, text_prompt, box_threshold,
dino_preview_checkbox, dino_preview_boxes_selection):
positive_points = parse_points(positive_points)
negative_points = parse_points(negative_points)
print("Start SAM Processing")
if sam_model_name is None:
return [], "SAM model not found. Please download SAM model from extension README."
Expand Down Expand Up @@ -554,7 +572,7 @@ def ui(self, is_img2img):
with gr.Column(scale=10):
with gr.Row():
sam_model_name = gr.Dropdown(label="SAM Model", choices=sam_model_list, value=sam_model_list[0] if len(sam_model_list) > 0 else None)
sam_refresh_models = ToolButton(value=refresh_symbol)
sam_refresh_models = ToolButton(value=refresh_symbol, variant="tool")
sam_refresh_models.click(refresh_sam_models, sam_model_name, sam_model_name)
with gr.Column(scale=1):
sam_use_cpu = gr.Checkbox(value=False, label="Use CPU for SAM")
Expand All @@ -567,7 +585,7 @@ def change_sam_device(use_cpu=False):
gr.HTML(value="<p>Left click the image to add one positive point (black dot). Right click the image to add one negative point (red dot). Left click the point to remove it.</p>")
sam_input_image = gr.Image(label="Image for Segment Anything", elem_id=f"{tab_prefix}input_image", source="upload", type="pil", image_mode="RGBA")
sam_remove_dots = gr.Button(value="Remove all point prompts")
sam_dummy_component = gr.Label(visible=False)
sam_dummy_component = gr.Textbox(visible=False)
sam_remove_dots.click(
fn=lambda _: None,
_js="samRemoveDots",
Expand Down