-
Notifications
You must be signed in to change notification settings - Fork 208
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make it work in new Forge (Gradio 4) #215
base: master
Are you sure you want to change the base?
Changes from 4 commits
fe0c82a
433c6eb
61dba99
ac4db78
470b0f8
e25c240
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -10,7 +10,7 @@ | |||||||||||||||||||||||||||||||
from scipy.ndimage import binary_dilation | ||||||||||||||||||||||||||||||||
from modules import scripts, shared, script_callbacks | ||||||||||||||||||||||||||||||||
from modules.ui import gr_show | ||||||||||||||||||||||||||||||||
from modules.ui_components import FormRow | ||||||||||||||||||||||||||||||||
from modules.ui_components import FormRow, ToolButton | ||||||||||||||||||||||||||||||||
from modules.safe import unsafe_torch_load, load | ||||||||||||||||||||||||||||||||
from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessing | ||||||||||||||||||||||||||||||||
from modules.devices import device, torch_gc, cpu | ||||||||||||||||||||||||||||||||
|
@@ -21,6 +21,18 @@ | |||||||||||||||||||||||||||||||
from scripts.auto import clear_sem_sam_cache, register_auto_sam, semantic_segmentation, sem_sam_garbage_collect, image_layer_internal, categorical_mask_image | ||||||||||||||||||||||||||||||||
from scripts.process_params import SAMProcessUnit, max_cn_num | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
import importlib.metadata | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
def get_gradio_version(): | ||||||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||||||
version = importlib.metadata.version("gradio") | ||||||||||||||||||||||||||||||||
return version | ||||||||||||||||||||||||||||||||
except importlib.metadata.PackageNotFoundError: | ||||||||||||||||||||||||||||||||
return None | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
# Example usage | ||||||||||||||||||||||||||||||||
gradio_version = get_gradio_version() | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
refresh_symbol = '\U0001f504' # 🔄 | ||||||||||||||||||||||||||||||||
sam_model_cache = OrderedDict() | ||||||||||||||||||||||||||||||||
|
@@ -35,16 +47,6 @@ | |||||||||||||||||||||||||||||||
txt2img_height: gr.Slider = None | ||||||||||||||||||||||||||||||||
img2img_width: gr.Slider = None | ||||||||||||||||||||||||||||||||
img2img_height: gr.Slider = None | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
class ToolButton(gr.Button, gr.components.FormComponent): | ||||||||||||||||||||||||||||||||
"""Small button with single emoji as text, fits inside gradio forms""" | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
def __init__(self, **kwargs): | ||||||||||||||||||||||||||||||||
super().__init__(variant="tool", **kwargs) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
def get_block_name(self): | ||||||||||||||||||||||||||||||||
return "button" | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
def show_masks(image_np, masks: np.ndarray, alpha=0.5): | ||||||||||||||||||||||||||||||||
|
@@ -56,7 +58,38 @@ def show_masks(image_np, masks: np.ndarray, alpha=0.5): | |||||||||||||||||||||||||||||||
return image.astype(np.uint8) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
def update_mask(mask_gallery, chosen_mask, dilation_amt, input_image): | ||||||||||||||||||||||||||||||||
def update_mask_gr_four(mask_gallery, chosen_mask, dilation_amt, input_image): | ||||||||||||||||||||||||||||||||
print("Dilation Amount: ", dilation_amt) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
# Check if mask_gallery is a list | ||||||||||||||||||||||||||||||||
if isinstance(mask_gallery, list): | ||||||||||||||||||||||||||||||||
# Extract the image from the (image, caption) tuple | ||||||||||||||||||||||||||||||||
image_tuple = mask_gallery[chosen_mask + 3] | ||||||||||||||||||||||||||||||||
if isinstance(image_tuple, tuple) and len(image_tuple) == 2: | ||||||||||||||||||||||||||||||||
mask_image = image_tuple[0] # Get the image part of the tuple | ||||||||||||||||||||||||||||||||
if isinstance(mask_image, str): # If it's a file path | ||||||||||||||||||||||||||||||||
mask_image = Image.open(mask_image) | ||||||||||||||||||||||||||||||||
elif isinstance(mask_image, np.ndarray): # If it's an ndarray | ||||||||||||||||||||||||||||||||
mask_image = Image.fromarray(mask_image) | ||||||||||||||||||||||||||||||||
# Else, it is already a PIL.Image | ||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||
raise TypeError("Expected a tuple (image, caption) in mask_gallery") | ||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||
# If not a list, use the provided mask_gallery as-is | ||||||||||||||||||||||||||||||||
mask_image = mask_gallery | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
binary_img = np.array(mask_image.convert('1')) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
if dilation_amt: | ||||||||||||||||||||||||||||||||
mask_image, binary_img = dilate_mask(binary_img, dilation_amt) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
blended_image = Image.fromarray(show_masks(np.array(input_image), binary_img.astype(np.bool_)[None, ...])) | ||||||||||||||||||||||||||||||||
matted_image = np.array(input_image) | ||||||||||||||||||||||||||||||||
matted_image[~binary_img] = np.array([0, 0, 0, 0]) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
return [blended_image, mask_image, Image.fromarray(matted_image)] | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
def update_mask_gr_three(mask_gallery, chosen_mask, dilation_amt, input_image): | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In theory we could just use
Suggested change
This implies that we would crash if np.ndarray or PIL images are passed from a Gradio 4 gallery, but I actually haven't seen this case while using the extension. |
||||||||||||||||||||||||||||||||
print("Dilation Amount: ", dilation_amt) | ||||||||||||||||||||||||||||||||
if isinstance(mask_gallery, list): | ||||||||||||||||||||||||||||||||
mask_image = Image.open(mask_gallery[chosen_mask + 3]['name']) | ||||||||||||||||||||||||||||||||
|
@@ -70,6 +103,12 @@ def update_mask(mask_gallery, chosen_mask, dilation_amt, input_image): | |||||||||||||||||||||||||||||||
matted_image[~binary_img] = np.array([0, 0, 0, 0]) | ||||||||||||||||||||||||||||||||
return [blended_image, mask_image, Image.fromarray(matted_image)] | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
def update_mask(mask_gallery, chosen_mask, dilation_amt, input_image): | ||||||||||||||||||||||||||||||||
if gradio_version and gradio_version < "4.0.0": | ||||||||||||||||||||||||||||||||
return update_mask_gr_three(mask_gallery, chosen_mask, dilation_amt, input_image) | ||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||
return update_mask_gr_four(mask_gallery, chosen_mask, dilation_amt, input_image) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
def load_sam_model(sam_checkpoint): | ||||||||||||||||||||||||||||||||
model_type = sam_checkpoint.split('.')[0] | ||||||||||||||||||||||||||||||||
|
@@ -183,9 +222,16 @@ def create_mask_batch_output( | |||||||||||||||||||||||||||||||
output_blend.save(os.path.join(dino_batch_dest_dir, f"{filename}_{idx}_blend{ext}")) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
def parse_points(points): | ||||||||||||||||||||||||||||||||
if isinstance(points, str): | ||||||||||||||||||||||||||||||||
points = eval(points) # Convert string representation of list back to list | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there any specific reason for the
Suggested change
|
||||||||||||||||||||||||||||||||
return points | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
def sam_predict(sam_model_name, input_image, positive_points, negative_points, | ||||||||||||||||||||||||||||||||
dino_checkbox, dino_model_name, text_prompt, box_threshold, | ||||||||||||||||||||||||||||||||
dino_preview_checkbox, dino_preview_boxes_selection): | ||||||||||||||||||||||||||||||||
positive_points = parse_points(positive_points) | ||||||||||||||||||||||||||||||||
negative_points = parse_points(negative_points) | ||||||||||||||||||||||||||||||||
print("Start SAM Processing") | ||||||||||||||||||||||||||||||||
if sam_model_name is None: | ||||||||||||||||||||||||||||||||
return [], "SAM model not found. Please download SAM model from extension README." | ||||||||||||||||||||||||||||||||
|
@@ -554,7 +600,7 @@ def ui(self, is_img2img): | |||||||||||||||||||||||||||||||
with gr.Column(scale=10): | ||||||||||||||||||||||||||||||||
with gr.Row(): | ||||||||||||||||||||||||||||||||
sam_model_name = gr.Dropdown(label="SAM Model", choices=sam_model_list, value=sam_model_list[0] if len(sam_model_list) > 0 else None) | ||||||||||||||||||||||||||||||||
sam_refresh_models = ToolButton(value=refresh_symbol) | ||||||||||||||||||||||||||||||||
sam_refresh_models = ToolButton(value=refresh_symbol, variant="tool") | ||||||||||||||||||||||||||||||||
sam_refresh_models.click(refresh_sam_models, sam_model_name, sam_model_name) | ||||||||||||||||||||||||||||||||
with gr.Column(scale=1): | ||||||||||||||||||||||||||||||||
sam_use_cpu = gr.Checkbox(value=False, label="Use CPU for SAM") | ||||||||||||||||||||||||||||||||
|
@@ -567,7 +613,7 @@ def change_sam_device(use_cpu=False): | |||||||||||||||||||||||||||||||
gr.HTML(value="<p>Left click the image to add one positive point (black dot). Right click the image to add one negative point (red dot). Left click the point to remove it.</p>") | ||||||||||||||||||||||||||||||||
sam_input_image = gr.Image(label="Image for Segment Anything", elem_id=f"{tab_prefix}input_image", source="upload", type="pil", image_mode="RGBA") | ||||||||||||||||||||||||||||||||
sam_remove_dots = gr.Button(value="Remove all point prompts") | ||||||||||||||||||||||||||||||||
sam_dummy_component = gr.Label(visible=False) | ||||||||||||||||||||||||||||||||
sam_dummy_component = gr.Textbox(visible=False) | ||||||||||||||||||||||||||||||||
sam_remove_dots.click( | ||||||||||||||||||||||||||||||||
fn=lambda _: None, | ||||||||||||||||||||||||||||||||
_js="samRemoveDots", | ||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have you seen a case where while using this extension we're stumbling upon
np.ndarray
or PIL images ? I've only filename strings passed here.