-
Notifications
You must be signed in to change notification settings - Fork 40
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
use nn.Sequential to remove python control flow from autoencoder up/downsampling #33
base: main
Are you sure you want to change the base?
Conversation
31687a3
to
2a6dd2e
Compare
|
@yorickvP interesting that there's no graph breaks; wonder if the compiler is smart enough to realize that the if statements will always evaluate to true/false depending on loop iteration. agreed that we should get perf numbers & test that outputs are unchanged |
unfortunately this changes the state_dict keys: Got 180 missing keys:
encoder.down.0.norm2.weight
encoder.down.0.norm2.bias
encoder.down.0.conv2.weight
encoder.down.0.conv2.bias
encoder.down.1.norm1.weight
encoder.down.1.norm1.bias
encoder.down.1.conv1.weight
encoder.down.1.conv1.bias
encoder.down.1.norm2.weight
encoder.down.1.norm2.bias
encoder.down.1.conv2.weight
encoder.down.1.conv2.bias
encoder.down.2.conv.weight
encoder.down.2.conv.bias
encoder.down.3.norm1.weight
encoder.down.3.norm1.bias
encoder.down.3.conv1.weight
encoder.down.3.conv1.bias
encoder.down.3.norm2.weight
encoder.down.3.norm2.bias
encoder.down.3.conv2.weight
encoder.down.3.conv2.bias
encoder.down.3.nin_shortcut.weight
encoder.down.3.nin_shortcut.bias
encoder.down.4.norm1.weight
encoder.down.4.norm1.bias
encoder.down.4.conv1.weight
encoder.down.4.conv1.bias
encoder.down.4.norm2.weight
encoder.down.4.norm2.bias
encoder.down.4.conv2.weight
encoder.down.4.conv2.bias
encoder.down.5.conv.weight
encoder.down.5.conv.bias
encoder.down.6.norm1.weight
encoder.down.6.norm1.bias
encoder.down.6.conv1.weight
encoder.down.6.conv1.bias
encoder.down.6.norm2.weight
encoder.down.6.norm2.bias
encoder.down.6.conv2.weight
encoder.down.6.conv2.bias
encoder.down.6.nin_shortcut.weight
encoder.down.6.nin_shortcut.bias
encoder.down.7.norm1.weight
encoder.down.7.norm1.bias
encoder.down.7.conv1.weight
encoder.down.7.conv1.bias
encoder.down.7.norm2.weight
encoder.down.7.norm2.bias
encoder.down.7.conv2.weight
encoder.down.7.conv2.bias
encoder.down.8.conv.weight
encoder.down.8.conv.bias
encoder.down.9.norm1.weight
encoder.down.9.norm1.bias
encoder.down.9.conv1.weight
encoder.down.9.conv1.bias
encoder.down.9.norm2.weight
encoder.down.9.norm2.bias
encoder.down.9.conv2.weight
encoder.down.9.conv2.bias
encoder.down.10.norm1.weight
encoder.down.10.norm1.bias
encoder.down.10.conv1.weight
encoder.down.10.conv1.bias
encoder.down.10.norm2.weight
encoder.down.10.norm2.bias
encoder.down.10.conv2.weight
encoder.down.10.conv2.bias
decoder.up.0.norm1.weight
decoder.up.0.norm1.bias
decoder.up.0.conv1.weight
decoder.up.0.conv1.bias
decoder.up.0.norm2.weight
decoder.up.0.norm2.bias
decoder.up.0.conv2.weight
decoder.up.0.conv2.bias
decoder.up.0.nin_shortcut.weight
decoder.up.0.nin_shortcut.bias
decoder.up.1.norm1.weight
decoder.up.1.norm1.bias
decoder.up.1.conv1.weight
decoder.up.1.conv1.bias
decoder.up.1.norm2.weight
decoder.up.1.norm2.bias
decoder.up.1.conv2.weight
decoder.up.1.conv2.bias
decoder.up.2.norm1.weight
decoder.up.2.norm1.bias
decoder.up.2.conv1.weight
decoder.up.2.conv1.bias
decoder.up.2.norm2.weight
decoder.up.2.norm2.bias
decoder.up.2.conv2.weight
decoder.up.2.conv2.bias
decoder.up.3.norm1.weight
decoder.up.3.norm1.bias
decoder.up.3.conv1.weight
decoder.up.3.conv1.bias
decoder.up.3.norm2.weight
decoder.up.3.norm2.bias
decoder.up.3.conv2.weight
decoder.up.3.conv2.bias
decoder.up.3.nin_shortcut.weight
decoder.up.3.nin_shortcut.bias
decoder.up.4.norm1.weight
decoder.up.4.norm1.bias
decoder.up.4.conv1.weight
decoder.up.4.conv1.bias
decoder.up.4.norm2.weight
decoder.up.4.norm2.bias
decoder.up.4.conv2.weight
decoder.up.4.conv2.bias
decoder.up.5.norm1.weight
decoder.up.5.norm1.bias
decoder.up.5.conv1.weight
decoder.up.5.conv1.bias
decoder.up.5.norm2.weight
decoder.up.5.norm2.bias
decoder.up.5.conv2.weight
decoder.up.5.conv2.bias
decoder.up.6.conv.weight
decoder.up.6.conv.bias
decoder.up.7.norm1.weight
decoder.up.7.norm1.bias
decoder.up.7.conv1.weight
decoder.up.7.conv1.bias
decoder.up.7.norm2.weight
decoder.up.7.norm2.bias
decoder.up.7.conv2.weight
decoder.up.7.conv2.bias
decoder.up.8.norm1.weight
decoder.up.8.norm1.bias
decoder.up.8.conv1.weight
decoder.up.8.conv1.bias
decoder.up.8.norm2.weight
decoder.up.8.norm2.bias
decoder.up.8.conv2.weight
decoder.up.8.conv2.bias
decoder.up.9.norm1.weight
decoder.up.9.norm1.bias
decoder.up.9.conv1.weight
decoder.up.9.conv1.bias
decoder.up.9.norm2.weight
decoder.up.9.norm2.bias
decoder.up.9.conv2.weight
decoder.up.9.conv2.bias
decoder.up.10.conv.weight
decoder.up.10.conv.bias
decoder.up.11.norm1.weight
decoder.up.11.norm1.bias
decoder.up.11.conv1.weight
decoder.up.11.conv1.bias
decoder.up.11.norm2.weight
decoder.up.11.norm2.bias
decoder.up.11.conv2.weight
decoder.up.11.conv2.bias
decoder.up.12.norm1.weight
decoder.up.12.norm1.bias
decoder.up.12.conv1.weight
decoder.up.12.conv1.bias
decoder.up.12.norm2.weight
decoder.up.12.norm2.bias
decoder.up.12.conv2.weight
decoder.up.12.conv2.bias
decoder.up.13.norm1.weight
decoder.up.13.norm1.bias
decoder.up.13.conv1.weight
decoder.up.13.conv1.bias
decoder.up.13.norm2.weight
decoder.up.13.norm2.bias
decoder.up.13.conv2.weight
decoder.up.13.conv2.bias
decoder.up.14.conv.weight
decoder.up.14.conv.bias
Got 180 unexpected keys:
encoder.down.0.block.0.norm1.bias
encoder.down.0.block.0.norm1.weight
encoder.down.0.block.0.norm2.bias
encoder.down.0.block.0.norm2.weight
encoder.down.0.block.1.conv1.bias
encoder.down.0.block.1.conv1.weight
encoder.down.0.block.1.conv2.bias
encoder.down.0.block.1.conv2.weight
encoder.down.0.block.1.norm1.bias
encoder.down.0.block.1.norm1.weight
encoder.down.0.block.1.norm2.bias
encoder.down.0.block.1.norm2.weight
encoder.down.0.downsample.conv.bias
encoder.down.0.downsample.conv.weight
encoder.down.1.block.0.conv1.bias
encoder.down.1.block.0.conv1.weight
encoder.down.1.block.0.conv2.bias
encoder.down.1.block.0.conv2.weight
encoder.down.1.block.0.nin_shortcut.bias
encoder.down.1.block.0.nin_shortcut.weight
encoder.down.1.block.0.norm1.bias
encoder.down.1.block.0.norm1.weight
encoder.down.1.block.0.norm2.bias
encoder.down.1.block.0.norm2.weight
encoder.down.1.block.1.conv1.bias
encoder.down.1.block.1.conv1.weight
encoder.down.1.block.1.conv2.bias
encoder.down.1.block.1.conv2.weight
encoder.down.1.block.1.norm1.bias
encoder.down.1.block.1.norm1.weight
encoder.down.1.block.1.norm2.bias
encoder.down.1.block.1.norm2.weight
encoder.down.1.downsample.conv.bias
encoder.down.1.downsample.conv.weight
encoder.down.2.block.0.conv1.bias
encoder.down.2.block.0.conv1.weight
encoder.down.2.block.0.conv2.bias
encoder.down.2.block.0.conv2.weight
encoder.down.2.block.0.nin_shortcut.bias
encoder.down.2.block.0.nin_shortcut.weight
encoder.down.2.block.0.norm1.bias
encoder.down.2.block.0.norm1.weight
encoder.down.2.block.0.norm2.bias
encoder.down.2.block.0.norm2.weight
encoder.down.2.block.1.conv1.bias
encoder.down.2.block.1.conv1.weight
encoder.down.2.block.1.conv2.bias
encoder.down.2.block.1.conv2.weight
encoder.down.2.block.1.norm1.bias
encoder.down.2.block.1.norm1.weight
encoder.down.2.block.1.norm2.bias
encoder.down.2.block.1.norm2.weight
encoder.down.2.downsample.conv.bias
encoder.down.2.downsample.conv.weight
encoder.down.3.block.0.conv1.bias
encoder.down.3.block.0.conv1.weight
encoder.down.3.block.0.conv2.bias
encoder.down.3.block.0.conv2.weight
encoder.down.3.block.0.norm1.bias
encoder.down.3.block.0.norm1.weight
encoder.down.3.block.0.norm2.bias
encoder.down.3.block.0.norm2.weight
encoder.down.3.block.1.conv1.bias
encoder.down.3.block.1.conv1.weight
encoder.down.3.block.1.conv2.bias
encoder.down.3.block.1.conv2.weight
encoder.down.3.block.1.norm1.bias
encoder.down.3.block.1.norm1.weight
encoder.down.3.block.1.norm2.bias
encoder.down.3.block.1.norm2.weight
decoder.up.0.block.0.conv1.bias
decoder.up.0.block.0.conv1.weight
decoder.up.0.block.0.conv2.bias
decoder.up.0.block.0.conv2.weight
decoder.up.0.block.0.nin_shortcut.bias
decoder.up.0.block.0.nin_shortcut.weight
decoder.up.0.block.0.norm1.bias
decoder.up.0.block.0.norm1.weight
decoder.up.0.block.0.norm2.bias
decoder.up.0.block.0.norm2.weight
decoder.up.0.block.1.conv1.bias
decoder.up.0.block.1.conv1.weight
decoder.up.0.block.1.conv2.bias
decoder.up.0.block.1.conv2.weight
decoder.up.0.block.1.norm1.bias
decoder.up.0.block.1.norm1.weight
decoder.up.0.block.1.norm2.bias
decoder.up.0.block.1.norm2.weight
decoder.up.0.block.2.conv1.bias
decoder.up.0.block.2.conv1.weight
decoder.up.0.block.2.conv2.bias
decoder.up.0.block.2.conv2.weight
decoder.up.0.block.2.norm1.bias
decoder.up.0.block.2.norm1.weight
decoder.up.0.block.2.norm2.bias
decoder.up.0.block.2.norm2.weight
decoder.up.1.block.0.conv1.bias
decoder.up.1.block.0.conv1.weight
decoder.up.1.block.0.conv2.bias
decoder.up.1.block.0.conv2.weight
decoder.up.1.block.0.nin_shortcut.bias
decoder.up.1.block.0.nin_shortcut.weight
decoder.up.1.block.0.norm1.bias
decoder.up.1.block.0.norm1.weight
decoder.up.1.block.0.norm2.bias
decoder.up.1.block.0.norm2.weight
decoder.up.1.block.1.conv1.bias
decoder.up.1.block.1.conv1.weight
decoder.up.1.block.1.conv2.bias
decoder.up.1.block.1.conv2.weight
decoder.up.1.block.1.norm1.bias
decoder.up.1.block.1.norm1.weight
decoder.up.1.block.1.norm2.bias
decoder.up.1.block.1.norm2.weight
decoder.up.1.block.2.conv1.bias
decoder.up.1.block.2.conv1.weight
decoder.up.1.block.2.conv2.bias
decoder.up.1.block.2.conv2.weight
decoder.up.1.block.2.norm1.bias
decoder.up.1.block.2.norm1.weight
decoder.up.1.block.2.norm2.bias
decoder.up.1.block.2.norm2.weight
decoder.up.1.upsample.conv.bias
decoder.up.1.upsample.conv.weight
decoder.up.2.block.0.conv1.bias
decoder.up.2.block.0.conv1.weight
decoder.up.2.block.0.conv2.bias
decoder.up.2.block.0.conv2.weight
decoder.up.2.block.0.norm1.bias
decoder.up.2.block.0.norm1.weight
decoder.up.2.block.0.norm2.bias
decoder.up.2.block.0.norm2.weight
decoder.up.2.block.1.conv1.bias
decoder.up.2.block.1.conv1.weight
decoder.up.2.block.1.conv2.bias
decoder.up.2.block.1.conv2.weight
decoder.up.2.block.1.norm1.bias
decoder.up.2.block.1.norm1.weight
decoder.up.2.block.1.norm2.bias
decoder.up.2.block.1.norm2.weight
decoder.up.2.block.2.conv1.bias
decoder.up.2.block.2.conv1.weight
decoder.up.2.block.2.conv2.bias
decoder.up.2.block.2.conv2.weight
decoder.up.2.block.2.norm1.bias
decoder.up.2.block.2.norm1.weight
decoder.up.2.block.2.norm2.bias
decoder.up.2.block.2.norm2.weight
decoder.up.2.upsample.conv.bias
decoder.up.2.upsample.conv.weight
decoder.up.3.block.0.conv1.bias
decoder.up.3.block.0.conv1.weight
decoder.up.3.block.0.conv2.bias
decoder.up.3.block.0.conv2.weight
decoder.up.3.block.0.norm1.bias
decoder.up.3.block.0.norm1.weight
decoder.up.3.block.0.norm2.bias
decoder.up.3.block.0.norm2.weight
decoder.up.3.block.1.conv1.bias
decoder.up.3.block.1.conv1.weight
decoder.up.3.block.1.conv2.bias
decoder.up.3.block.1.conv2.weight
decoder.up.3.block.1.norm1.bias
decoder.up.3.block.1.norm1.weight
decoder.up.3.block.1.norm2.bias
decoder.up.3.block.1.norm2.weight
decoder.up.3.block.2.conv1.bias
decoder.up.3.block.2.conv1.weight
decoder.up.3.block.2.conv2.bias
decoder.up.3.block.2.conv2.weight
decoder.up.3.block.2.norm1.bias
decoder.up.3.block.2.norm1.weight
decoder.up.3.block.2.norm2.bias
decoder.up.3.block.2.norm2.weight
decoder.up.3.upsample.conv.bias
decoder.up.3.upsample.conv.weight
The unexpected keys are the one in the ae.sft file, the expected keys are the ones from using flat |
92c664c
to
35b343d
Compare
35b343d
to
7333f34
Compare
It might be nicer to instead override |
the current autoencoder implementation causes graph breaks, likely due to python control flow. the
hs
list constructed in the encoder is also egregious.performance improvement numbers TBDthis does not make an improvement with normal eager torch, if we decide to compile for the autoencoder it will fix the graph breaks with tensorrt, but torch.compile may see through this structure.