Skip to content

Commit

Permalink
Update stable_diffusion_pipeline_compiler.py
Browse files Browse the repository at this point in the history
Do not trace vae for pytorch 2
  • Loading branch information
chengzeyi authored Nov 6, 2023
1 parent 3058d09 commit f255ed8
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions sfast/compilers/stable_diffusion_pipeline_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,14 @@ def unet_forward_wrapper(sample, t, *args, **kwargs):

m.unet.forward = unet_forward_wrapper

# if packaging.version.parse(
# torch.__version__) < packaging.version.parse('2.0.0'):
m.vae.decode = lazy_trace_(to_module(m.vae.decode))
# For img2img
m.vae.encoder.forward = lazy_trace_(
to_module(m.vae.encoder.forward))
m.vae.quant_conv.forward = lazy_trace_(
to_module(m.vae.quant_conv.forward))
if packaging.version.parse(
torch.__version__) < packaging.version.parse('2.0.0'):
m.vae.decode = lazy_trace_(to_module(m.vae.decode))
# For img2img
m.vae.encoder.forward = lazy_trace_(
to_module(m.vae.encoder.forward))
m.vae.quant_conv.forward = lazy_trace_(
to_module(m.vae.quant_conv.forward))

if config.trace_scheduler:
m.scheduler.scale_model_input = lazy_trace_(
Expand Down

0 comments on commit f255ed8

Please sign in to comment.