Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make it work in new Forge (Gradio 4) #215

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 60 additions & 14 deletions scripts/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from scipy.ndimage import binary_dilation
from modules import scripts, shared, script_callbacks
from modules.ui import gr_show
from modules.ui_components import FormRow
from modules.ui_components import FormRow, ToolButton
from modules.safe import unsafe_torch_load, load
from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessing
from modules.devices import device, torch_gc, cpu
Expand All @@ -21,6 +21,18 @@
from scripts.auto import clear_sem_sam_cache, register_auto_sam, semantic_segmentation, sem_sam_garbage_collect, image_layer_internal, categorical_mask_image
from scripts.process_params import SAMProcessUnit, max_cn_num

import importlib.metadata

def get_gradio_version():
try:
version = importlib.metadata.version("gradio")
return version
except importlib.metadata.PackageNotFoundError:
return None

# Example usage
gradio_version = get_gradio_version()


refresh_symbol = '\U0001f504' # 🔄
sam_model_cache = OrderedDict()
Expand All @@ -35,16 +47,6 @@
txt2img_height: gr.Slider = None
img2img_width: gr.Slider = None
img2img_height: gr.Slider = None


class ToolButton(gr.Button, gr.components.FormComponent):
"""Small button with single emoji as text, fits inside gradio forms"""

def __init__(self, **kwargs):
super().__init__(variant="tool", **kwargs)

def get_block_name(self):
return "button"


def show_masks(image_np, masks: np.ndarray, alpha=0.5):
Expand All @@ -56,7 +58,38 @@ def show_masks(image_np, masks: np.ndarray, alpha=0.5):
return image.astype(np.uint8)


def update_mask(mask_gallery, chosen_mask, dilation_amt, input_image):
def update_mask_gr_four(mask_gallery, chosen_mask, dilation_amt, input_image):
print("Dilation Amount: ", dilation_amt)

# Check if mask_gallery is a list
if isinstance(mask_gallery, list):
# Extract the image from the (image, caption) tuple
image_tuple = mask_gallery[chosen_mask + 3]
if isinstance(image_tuple, tuple) and len(image_tuple) == 2:
mask_image = image_tuple[0] # Get the image part of the tuple
if isinstance(mask_image, str): # If it's a file path
mask_image = Image.open(mask_image)
elif isinstance(mask_image, np.ndarray): # If it's an ndarray
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you seen a case where while using this extension we're stumbling upon np.ndarray or PIL images ? I've only filename strings passed here.

mask_image = Image.fromarray(mask_image)
# Else, it is already a PIL.Image
else:
raise TypeError("Expected a tuple (image, caption) in mask_gallery")
else:
# If not a list, use the provided mask_gallery as-is
mask_image = mask_gallery

binary_img = np.array(mask_image.convert('1'))

if dilation_amt:
mask_image, binary_img = dilate_mask(binary_img, dilation_amt)

blended_image = Image.fromarray(show_masks(np.array(input_image), binary_img.astype(np.bool_)[None, ...]))
matted_image = np.array(input_image)
matted_image[~binary_img] = np.array([0, 0, 0, 0])

return [blended_image, mask_image, Image.fromarray(matted_image)]

def update_mask_gr_three(mask_gallery, chosen_mask, dilation_amt, input_image):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In theory we could just use update_mask without splitting the functions, or the additional get_gradio_version code. Something like this.

Suggested change
def update_mask_gr_three(mask_gallery, chosen_mask, dilation_amt, input_image):
def update_mask(mask_gallery, chosen_mask, dilation_amt, input_image):
print("Dilation Amount: ", dilation_amt)
if isinstance(mask_gallery, list):
image_data = mask_gallery[chosen_mask + 3]
# In Gradio 4 or above, this is an (image, caption) tuple
if isinstance(image_data, tuple) and len(image_data) == 2:
mask_image = Image.open(image_data[0])
# In Gradio 3 this is a dict with 'name', 'data', 'is_file' keys
elif isinstance(image_data, dict) and 'name' in image_data:
mask_image = Image.open(image_data['name'])
else:
raise TypeError("Cannot locate mask image while expanding mask")
else:
mask_image = mask_gallery

This implies that we would crash if np.ndarray or PIL images are passed from a Gradio 4 gallery, but I actually haven't seen this case while using the extension.

print("Dilation Amount: ", dilation_amt)
if isinstance(mask_gallery, list):
mask_image = Image.open(mask_gallery[chosen_mask + 3]['name'])
Expand All @@ -70,6 +103,12 @@ def update_mask(mask_gallery, chosen_mask, dilation_amt, input_image):
matted_image[~binary_img] = np.array([0, 0, 0, 0])
return [blended_image, mask_image, Image.fromarray(matted_image)]

def update_mask(mask_gallery, chosen_mask, dilation_amt, input_image):
if gradio_version and gradio_version < "4.0.0":
return update_mask_gr_three(mask_gallery, chosen_mask, dilation_amt, input_image)
else:
return update_mask_gr_four(mask_gallery, chosen_mask, dilation_amt, input_image)


def load_sam_model(sam_checkpoint):
model_type = sam_checkpoint.split('.')[0]
Expand Down Expand Up @@ -183,9 +222,16 @@ def create_mask_batch_output(
output_blend.save(os.path.join(dino_batch_dest_dir, f"{filename}_{idx}_blend{ext}"))


def parse_points(points):
if isinstance(points, str):
points = eval(points) # Convert string representation of list back to list
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any specific reason for the eval ? It's not great for security reasons. In theory this should be json compliant and json.loads should do the trick.

Suggested change
points = eval(points) # Convert string representation of list back to list
points = json.loads(points) # Convert string representation of list back to list

return points

def sam_predict(sam_model_name, input_image, positive_points, negative_points,
dino_checkbox, dino_model_name, text_prompt, box_threshold,
dino_preview_checkbox, dino_preview_boxes_selection):
positive_points = parse_points(positive_points)
negative_points = parse_points(negative_points)
print("Start SAM Processing")
if sam_model_name is None:
return [], "SAM model not found. Please download SAM model from extension README."
Expand Down Expand Up @@ -554,7 +600,7 @@ def ui(self, is_img2img):
with gr.Column(scale=10):
with gr.Row():
sam_model_name = gr.Dropdown(label="SAM Model", choices=sam_model_list, value=sam_model_list[0] if len(sam_model_list) > 0 else None)
sam_refresh_models = ToolButton(value=refresh_symbol)
sam_refresh_models = ToolButton(value=refresh_symbol, variant="tool")
sam_refresh_models.click(refresh_sam_models, sam_model_name, sam_model_name)
with gr.Column(scale=1):
sam_use_cpu = gr.Checkbox(value=False, label="Use CPU for SAM")
Expand All @@ -567,7 +613,7 @@ def change_sam_device(use_cpu=False):
gr.HTML(value="<p>Left click the image to add one positive point (black dot). Right click the image to add one negative point (red dot). Left click the point to remove it.</p>")
sam_input_image = gr.Image(label="Image for Segment Anything", elem_id=f"{tab_prefix}input_image", source="upload", type="pil", image_mode="RGBA")
sam_remove_dots = gr.Button(value="Remove all point prompts")
sam_dummy_component = gr.Label(visible=False)
sam_dummy_component = gr.Textbox(visible=False)
sam_remove_dots.click(
fn=lambda _: None,
_js="samRemoveDots",
Expand Down