diff --git a/scripts/sam.py b/scripts/sam.py index 775bff2..620d174 100644 --- a/scripts/sam.py +++ b/scripts/sam.py @@ -8,9 +8,10 @@ import gradio as gr from collections import OrderedDict from scipy.ndimage import binary_dilation +import json from modules import scripts, shared, script_callbacks from modules.ui import gr_show -from modules.ui_components import FormRow +from modules.ui_components import FormRow, ToolButton from modules.safe import unsafe_torch_load, load from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessing from modules.devices import device, torch_gc, cpu @@ -21,6 +22,18 @@ from scripts.auto import clear_sem_sam_cache, register_auto_sam, semantic_segmentation, sem_sam_garbage_collect, image_layer_internal, categorical_mask_image from scripts.process_params import SAMProcessUnit, max_cn_num +import importlib.metadata + +def get_gradio_version(): + try: + version = importlib.metadata.version("gradio") + return version + except importlib.metadata.PackageNotFoundError: + return None + +# Example usage +gradio_version = get_gradio_version() + refresh_symbol = '\U0001f504' # 🔄 sam_model_cache = OrderedDict() @@ -35,16 +48,6 @@ txt2img_height: gr.Slider = None img2img_width: gr.Slider = None img2img_height: gr.Slider = None - - -class ToolButton(gr.Button, gr.components.FormComponent): - """Small button with single emoji as text, fits inside gradio forms""" - - def __init__(self, **kwargs): - super().__init__(variant="tool", **kwargs) - - def get_block_name(self): - return "button" def show_masks(image_np, masks: np.ndarray, alpha=0.5): @@ -59,7 +62,15 @@ def show_masks(image_np, masks: np.ndarray, alpha=0.5): def update_mask(mask_gallery, chosen_mask, dilation_amt, input_image): print("Dilation Amount: ", dilation_amt) if isinstance(mask_gallery, list): - mask_image = Image.open(mask_gallery[chosen_mask + 3]['name']) + image_data = mask_gallery[chosen_mask + 3] + # In Gradio 4 or above, this is an (image, caption) tuple + if isinstance(image_data, tuple) and len(image_data) == 2: + mask_image = Image.open(image_data[0]) + # In Gradio 3 this is a dict with 'name', 'data', 'is_file' keys + elif isinstance(image_data, dict) and 'name' in image_data: + mask_image = Image.open(image_data['name']) + else: + raise TypeError("Cannot locate mask image while expanding mask") else: mask_image = mask_gallery binary_img = np.array(mask_image.convert('1')) @@ -183,9 +194,16 @@ def create_mask_batch_output( output_blend.save(os.path.join(dino_batch_dest_dir, f"{filename}_{idx}_blend{ext}")) +def parse_points(points): + if isinstance(points, str): + points = json.loads(points) # Convert string representation of list back to list + return points + def sam_predict(sam_model_name, input_image, positive_points, negative_points, dino_checkbox, dino_model_name, text_prompt, box_threshold, dino_preview_checkbox, dino_preview_boxes_selection): + positive_points = parse_points(positive_points) + negative_points = parse_points(negative_points) print("Start SAM Processing") if sam_model_name is None: return [], "SAM model not found. Please download SAM model from extension README." @@ -554,7 +572,7 @@ def ui(self, is_img2img): with gr.Column(scale=10): with gr.Row(): sam_model_name = gr.Dropdown(label="SAM Model", choices=sam_model_list, value=sam_model_list[0] if len(sam_model_list) > 0 else None) - sam_refresh_models = ToolButton(value=refresh_symbol) + sam_refresh_models = ToolButton(value=refresh_symbol, variant="tool") sam_refresh_models.click(refresh_sam_models, sam_model_name, sam_model_name) with gr.Column(scale=1): sam_use_cpu = gr.Checkbox(value=False, label="Use CPU for SAM") @@ -567,7 +585,7 @@ def change_sam_device(use_cpu=False): gr.HTML(value="
Left click the image to add one positive point (black dot). Right click the image to add one negative point (red dot). Left click the point to remove it.
") sam_input_image = gr.Image(label="Image for Segment Anything", elem_id=f"{tab_prefix}input_image", source="upload", type="pil", image_mode="RGBA") sam_remove_dots = gr.Button(value="Remove all point prompts") - sam_dummy_component = gr.Label(visible=False) + sam_dummy_component = gr.Textbox(visible=False) sam_remove_dots.click( fn=lambda _: None, _js="samRemoveDots",