segfault when latent height or width is not divisible by 4 #159

Open

Birch-san opened this issue Aug 21, 2024 · 0 comments

Birch-san commented Aug 21, 2024

Problem

If I trace my model like so (weird shape then nice shape), everything is fine:

model(randn(1, 4, 150, 157))
model(randn(1, 4, 64, 64))

whereas if I trace my model with the nice shape first, it segfaults (backtrace here):

model(randn(1, 4, 64, 64))
model(randn(1, 4, 150, 157))

Note: this problem only reproduces for me on inpainting models (i.e. where conv_in has 8 channels). I don't know why.

Possibly related:
#153 (comment)

Proximal cause

I can get a better error message if I set enable_jit_freeze = False in my CompilationConfig:

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 76 but got size 75 for tensor number 1 in the list.
sfast/jit/overrides.py(21): __torch_function__
diffusers/models/unet_2d_blocks.py(2521): forward
diffusers/models/unet_2d_condition.py(1281): forward
sfast/jit/trace_helper.py(89): forward
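
(For reference, the tweak; the import path and the Default() constructor are from memory of stable-fast's README and may differ between versions:)

from sfast.compilers.diffusion_pipeline_compiler import CompilationConfig  # older releases name this module stable_diffusion_pipeline_compiler

config = CompilationConfig.Default()
config.enable_jit_freeze = False  # keep the traced graph unfrozen so the real shape-mismatch error surfaces instead of a segfault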

That traceback points to this line of code:
https://github.com/huggingface/diffusers/blob/c2916175186e2b6d9c2d09b13a753cc47f5d9e19/src/diffusers/models/unets/unet_2d_blocks.py#L2521

CrossAttnUpBlock2D#forward

hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

The decoder's upsampled hidden_states are not the same size as the encoder's residual hidden states.

Ordinarily (stable-fast disabled), we expect both hidden states to have a height of 75:

    hidden_states = torch.Size([2, 1280, 75, 79])
res_hidden_states = torch.Size([2, 640, 75, 79])

    hidden_states = torch.Size([2, 640, 75, 79])
res_hidden_states = torch.Size([2, 640, 75, 79])

    hidden_states = torch.Size([2, 640, 75, 79])
res_hidden_states = torch.Size([2, 320, 75, 79])

But the JIT has caused the decoder's upsampled hidden states to have a height of 76 instead of 75.
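
That's exactly the mismatch torch.cat rejects; a standalone repro using the shapes reported above (channel counts are illustrative):

import torch

hidden_states = torch.randn(2, 1280, 76, 79)      # decoder side: upsampled by a fixed 2x in the stale trace
res_hidden_states = torch.randn(2, 640, 75, 79)   # encoder residual at the true resolution
torch.cat([hidden_states, res_hidden_states], dim=1)
# RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 76 but got size 75 for tensor number 1 in the list.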

Root cause

It turns out stable-diffusion has two upsampling algorithms.
Control flow determines which one runs.

https://github.com/huggingface/diffusers/blob/c2916175186e2b6d9c2d09b13a753cc47f5d9e19/src/diffusers/models/upsampling.py#L167-L171

Upsample2D#forward

if self.interpolate:
    if output_size is None:
        hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
    else:
        hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

Sometimes you do not upsample by 2x.
Sometimes you upsample to a target resolution instead!

"Upsample to a target resolution" happens when the latent height or width is indivisible by 4:

https://github.com/huggingface/diffusers/blob/c2916175186e2b6d9c2d09b13a753cc47f5d9e19/src/diffusers/models/unets/unet_2d_condition.py#L1106-L1113

UNet2DConditionModel#forward

# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None

for dim in sample.shape[-2:]:
    if dim % default_overall_up_factor != 0:
        # Forward upsample size to force interpolation output size.
        forward_upsample_size = True
        break
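
A standalone sketch of why the two branches differ (sizes chosen to match the 75x79 example above; nothing here is stable-fast specific):

import torch
import torch.nn.functional as F

# a feature map that was downsampled from 75x79: stride-2 convs floor the odd dims to 38x40
x = torch.randn(1, 16, 38, 40)

print(F.interpolate(x, scale_factor=2.0, mode="nearest").shape)  # torch.Size([1, 16, 76, 80]) -- overshoots the residual
print(F.interpolate(x, size=(75, 79), mode="nearest").shape)     # torch.Size([1, 16, 75, 79]) -- matches the residual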

That's why 64x64 behaves differently to 150x157.
It's also why tracing at 150x157 first generalizes, but tracing at 64x64 first does not.
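
The mechanism is plain trace specialization: torch.jit.trace records only the branch taken for the example input, so a trace made at 64x64 never sees the target-size path. A minimal sketch (not stable-fast specific; upsample here is a stand-in for Upsample2D):

import torch
import torch.nn.functional as F

def upsample(x, size=None):
    if size is None:
        return F.interpolate(x, scale_factor=2.0, mode="nearest")
    return F.interpolate(x, size=size, mode="nearest")

def forward(x):
    size = None
    if x.shape[-2] % 4 != 0 or x.shape[-1] % 4 != 0:
        # stand-in for forwarding the encoder residual's exact spatial size
        size = (x.shape[-2], x.shape[-1])
    return upsample(x, size)

print(forward(torch.randn(1, 4, 75, 79)).shape)               # torch.Size([1, 4, 75, 79]): eager mode takes the size branch
traced = torch.jit.trace(forward, torch.randn(1, 4, 64, 64))  # 64 is divisible by 4, so only the scale_factor branch is recorded
print(traced(torch.randn(1, 4, 75, 79)).shape)                # torch.Size([1, 4, 150, 158]): the size branch is gone from the trace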

Evidence

We can print ScriptModule#code to see the difference in the traced model:

attach_script_module_clear_hook(script_module._c)

# dump the generated TorchScript source of every submodule to a file
s = ''
for name, mod in script_module.named_modules():
    if hasattr(mod, 'code'):
        s += f'\n{name} {mod.__class__.__name__} code:\n{mod.code}\n'
with open('code.py', 'w') as f:
    f.write(s)

Or if you prefer the IR representation:

# same dump, but with the TorchScript IR graph (some submodules raise when no graph is available)
s = ''
for name, mod in script_module.named_modules():
    try:
        if hasattr(mod, 'graph'):
            s += f'\n{name} {mod.__class__.__name__} graph:\n{mod.graph}\n'
        else:
            s += f'\n{name} {mod.__class__.__name__} graph:\n(no graph)\n'
    except RuntimeError:
        s += f'\n{name} {mod.__class__.__name__} graph:\nRuntimeError\n'
with open('graph.txt', 'w') as f:
    f.write(s)

There are only a couple of differences in the codegen!

UNet2DConditionModel#forward passes more arguments when the tensor shape is weird (it's passing upsample_size!)
Left = weird tensor shape
Right = simple tensor shape
[screenshot]

The up-blocks only pass the upsample_size argument through if the tensor shape is weird.
Left = weird tensor shape
Right = simple tensor shape
[screenshots]

Upsample2D#forward only invokes upsample_nearest2d with a target size if the tensor shape is weird.
Left = weird tensor shape
Right = simple tensor shape
[screenshot]

Proposed solution

Maybe it's fine to just require users to trace the model with a weird tensor shape first, so that you always go down the "give upsample_nearest2d a target size" codepath?

Or modify UNet2DConditionModel#forward to always set forward_upsample_size = True. That would achieve the same outcome more cheaply.

I don't know whether torch.upsample_nearest2d() returns the same result when you pass scale_factor as when you pass a target size. I'm optimistic that it probably does, at least for scale_factor=2.
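
A quick sanity check for the even-size case (this doesn't prove anything about other scale factors):

import torch
import torch.nn.functional as F

x = torch.randn(2, 320, 64, 64)
by_scale = F.interpolate(x, scale_factor=2.0, mode="nearest")
by_size = F.interpolate(x, size=(128, 128), mode="nearest")
print(torch.equal(by_scale, by_size))  # I expect True: nearest-neighbour with an exact 2x target picks the same source pixels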

Or maybe this section of UNet2DConditionModel#forward needs to be JITed in script-mode, to enable control flow:

https://github.com/huggingface/diffusers/blob/c2916175186e2b6d9c2d09b13a753cc47f5d9e19/src/diffusers/models/unets/unet_2d_condition.py#L1106-L1113

# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None

for dim in sample.shape[-2:]:
    if dim % default_overall_up_factor != 0:
        # Forward upsample size to force interpolation output size.
        forward_upsample_size = True
        break

Maybe something like this? I haven't tried running this code.

from torch.jit import script_if_tracing
from torch import Tensor

# …

@script_if_tracing
def should_fwd_upsample_size(
    sample: Tensor,
    default_overall_up_factor: int
) -> bool:
    # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
    for dim in sample.shape[-2:]:
        if dim % default_overall_up_factor != 0:
            # Forward upsample size to force interpolation output size.
            return True
    return False

forward_upsample_size: bool = should_fwd_upsample_size(sample, default_overall_up_factor)
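
If this approach works, the appeal of torch.jit.script_if_tracing is that the helper gets compiled by TorchScript the moment it's called under tracing, so the loop and the early return survive as real control flow in the traced graph rather than being specialized to whatever example shape the trace happened to see.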