
process_encoder_hidden_states function is applied on each denoising step #9625

Benkovichnikita opened this issue Oct 9, 2024 · 3 comments

Benkovichnikita commented Oct 9, 2024

def process_encoder_hidden_states(

Hey! I noticed that process_encoder_hidden_states is applied on each denoising step, which hurts inference performance. This function only needs to be executed once, before running UNet2DConditionModel.forward; its output could be combined with encoder_hidden_states and provided to UNet2DConditionModel.forward.
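A rough sketch of what I mean (hypothetical call sites; process_encoder_hidden_states is assumed to take encoder_hidden_states and added_cond_kwargs as in the current UNet2DConditionModel code, and the UNet would need a way to skip re-processing already-projected embeddings):

# Hypothetical: hoist the projection out of the denoising loop.
encoder_hidden_states = pipe.unet.process_encoder_hidden_states(
    encoder_hidden_states=prompt_embeds,
    added_cond_kwargs=added_cond_kwargs,
)

for t in timesteps:
    # The UNet forward would need a flag/path to skip re-projecting
    # the already-processed embeddings (not supported today).
    noise_pred = pipe.unet(
        latent_model_input,
        t,
        encoder_hidden_states=encoder_hidden_states,
        added_cond_kwargs=added_cond_kwargs,
    ).sample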

hlky (Contributor) commented Oct 9, 2024

Ditto for get_aug_embed and get_class_embed; these could both be called once from the pipeline rather than inside UNet2DConditionModel:

def get_aug_embed(
self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
) -> Optional[torch.Tensor]:

def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
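For illustration, a hedged sketch of hoisting these into the pipeline, using the signatures quoted above (time_emb is a hypothetical placeholder, and this assumes the aug-embed branch in use doesn't actually depend on the per-step time embedding emb):

# Hypothetical: both outputs are constant across timesteps for a given
# prompt/conditioning, so they could be computed once before the loop.
class_emb = pipe.unet.get_class_embed(sample=latents, class_labels=class_labels)
aug_emb = pipe.unet.get_aug_embed(
    emb=time_emb,  # hypothetical placeholder for the time embedding
    encoder_hidden_states=prompt_embeds,
    added_cond_kwargs=added_cond_kwargs,
)
# The UNet forward would then need to accept these precomputed embeddings.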

a-r-r-o-w (Member) commented Oct 10, 2024

These don't require recomputation, yes. We recently discussed this internally too, so I have some comments fresh in mind that I'd like to share.

This was tested in the past, and the speedup was found to be real but quite small (unless you're generating in bulk, for example as an image-generation service provider, in which case it makes sense to shave off as much overhead as you can). Since these layers live in the UNet, performing intermediate computations with them in the pipeline and passing the results back into the UNet might be confusing for newcomers. Ideally, we'd like the UNet to stay as simple as possible and behave like a black-box function where you don't have to prepare anything but the latents and prompt embeddings, even if that has a performance cost.

I did have a design idea in mind to provide modeling-level control, where the user could pass one or more layer identifiers/regexes whose computed values should be reused. Something like:

pipe.unet.enable_caching(["encoder_hid_proj", "add_embedding", "class_embedding" ...])
pipe.transformer.enable_caching(["pos_embed", "patch_embed.text_proj"])

During the first inference step, the values you'd like to reuse/cache would be computed; for the remaining inference steps, they would be reused. At the end of the denoising loop, the cache would be cleared so that new prompt/image embeddings can be used.
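A minimal sketch of how such an enable_caching could work under the hood (not an existing diffusers API; names are illustrative), memoizing each named submodule's first output and reusing it until the cache is cleared:

import torch

def enable_caching(model: torch.nn.Module, module_names):
    # Hypothetical sketch of the proposed API: wrap each named submodule's
    # forward so the result of the first call is reused on later calls.
    caches = {}
    for name in module_names:
        module = model.get_submodule(name)

        # Default args bind name and the original forward at definition time.
        def cached_forward(*args, _name=name, _orig=module.forward, **kwargs):
            if _name not in caches:
                caches[_name] = _orig(*args, **kwargs)
            return caches[_name]

        module.forward = cached_forward

    def clear_cache():
        # Called at the end of the denoising loop so that new prompt/image
        # embeddings trigger recomputation.
        caches.clear()

    return clear_cache

Usage in a pipeline would then look something like:

clear = enable_caching(pipe.unet, ["encoder_hid_proj", "add_embedding"])
# ... run the denoising loop ...
clear()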

Alternatively, we could add more parameters to the UNet forward, such as class_embedding, which would be used if passed; if not passed, class_labels would be used as before. I think this is a bit unclean, though: having two parameters per condition would make things even more confusing for a newcomer because of the extra bloat in both the pipeline and the model. If you have any design ideas that achieve the same effect without too many modifications to the pipelines themselves, please share! This could be useful across models and help quite a bit with speeding up bulk generation.
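For concreteness, the two-parameters-per-condition variant would look roughly like this inside the UNet forward (class_embedding is a hypothetical parameter name):

# Inside a hypothetical UNet2DConditionModel.forward:
if class_embedding is not None:
    # Caller supplied a precomputed embedding; reuse it directly.
    class_emb = class_embedding
else:
    # Fall back to computing it from class_labels, as today.
    class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)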

cc @yiyixuxu too because we were discussing this recently

Benkovichnikita (Author) commented Oct 13, 2024

Thank you @a-r-r-o-w! In my tests the overhead was around 2.5% for the IP-Adapter, which I don't think is a big deal for most users, though it depends on the encoder's complexity. Personally, I like your caching idea, especially if you don't want to make changes at the pipeline level.
