Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Support for S^2 #1376

Merged
merged 2 commits into from
Apr 18, 2024
Merged

Add Support for S^2 #1376

merged 2 commits into from
Apr 18, 2024

Conversation

bfshi
Copy link
Contributor

@bfshi bfshi commented Apr 6, 2024

What does this PR do?

This PR integrates S2 into LLaVA-NeXT.

What is S2?

S2 is a method to extract multi-scale features from an image. For example, given an image of 336x336, S2 interpolates the image to multiple scales such as 336x336, 672x672, 1008x1008, extracts features at each scale and merge the features into a multi-scale feature map. The multi-scale features contain more detailed information about an image which is beneficial for Multimodal LLMs. Meanwhile, S2 ensures the number of visual token sent to LLM is the same as the regular single-scale features such that no computational overhead on LLM is incurred.

Please find more details in the S2 paper and GitHub repo.

What does this PR contain?

There are two changes in this PR.

  • A new class of CLIPVisionTowerS2 is defined in llava/model/multimodal_encoder/clip_encoder.py which augments a clip model with S2. This class is the same as the original CLIPVisionTower class except that it will return a multi-scale feature map instead of a single-scale feature map when forward is called. The multi-scale feature map has the same shape as the single-scale one, except on the channel dimension where multi-scale features have num_scales * original_dim dimensions (as defined in self.hidden_size).
  • llava/model/multimodal_encoder/builder.py is modified so that it will build CLIPVisionTowerS2 instead of CLIPVisionTower when S2 is used.

How to train LLaVA with S2?

First install s2wrapper through pip:

pip install git+https://github.com/bfshi/scaling_on_scales.git

This package only has one dependency of einops, so installing it shouldn't interfere with your environment.

Training configurations should be the same as training a regular LLaVA without anyres (i.e., image_aspect_ratio="pad" and mm_patch_merge_type="flat"), except for two new model configs:

  • s2=True. This turns on the usage of S2.
  • s2_scales="336,672,1008". This specifies the image scales S2 will extract features on.

Copy link
Owner

@haotian-liu haotian-liu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the contribution, and congrats on $S^2$ release! It's great to bring additional efficiency to high res perception. Minor thing in comments. Thanks!

llava/model/multimodal_encoder/clip_encoder.py Outdated Show resolved Hide resolved
self.s2_split_size = self.s2_scales[0]
self.s2_image_size = self.s2_scales[-1]

# change resize/crop size in preprocessing to the largest image size in s2_scale
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we can move the import here
from s2wrapper import forward as multiscale_forward
and self.multiscale_forward = multiscale_forward

Copy link
Owner

@haotian-liu haotian-liu Apr 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and maybe you can add a exception handling on importerror to prompt the user to install s2wrapper

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Haotian, thanks for the suggestions! I've made the changes accordingly. Please take a look when you get a chance. Thanks!

llava/model/multimodal_encoder/clip_encoder.py Outdated Show resolved Hide resolved
@haotian-liu haotian-liu merged commit dcda07f into haotian-liu:main Apr 18, 2024
@haotian-liu
Copy link
Owner

Thanks!

@lightmatmul
Copy link

should not the model arguments be also changed to include s2 and s2_scales?

@bfshi
Copy link
Contributor Author

bfshi commented Apr 23, 2024

Hi @lightingvector,

Yes that's correct. Currently this PR only supports running and evaluating an already pre-trained LLaVA with S2. If you want to train LLaVA with S2 you need to add s2 and s2_scales in ModelArguments (here). You also need to store these two arguments into model.args here so that they will be saved to the checkpoint.

Thanks for pointing it out!

@diridiri
Copy link

Thanks for great works done, @bfshi

I've got few questions,
Since the output dimension (hidden_size) of CLIPVisionTowerS2 will differ in channels as we apply s2's multiple_forward (ex. 16 x 16 x 768 -> 16 x 16 x 1536 or more),
I guess we need to start from scratch to train mm_projector (stage 1) with new hidden_size, and then go for stage 2. Is that correct?

And I'm curious about effect of S2 on LLaVA regarding benchmarks other than V*, if you have any results or trained weights, can you share some?

@bfshi
Copy link
Contributor Author

bfshi commented Apr 25, 2024

Hi @diridiri,

Thanks for the interest! Yes, you need to train a mm_projector with the new hidden size. For results of S2 on other benchmarks, please refer to Table 11 in Appendix D in the paper. We are planning to release the checkpoint for LLaVA1.5 with S2 soon. For LLaVA1.6, since the training code is not released yet, we don't have the checkpoints currently.

@diridiri
Copy link

diridiri commented Apr 26, 2024

Appreciate your guidance,

I have one more question about your implementation :), since I encountered error in Training with S2.

In current implementation, the code below will call super class (CLIPVisionTower)'s constructor first rather than initializing its member attributes (self.s2_scales, self.s2_split_size, self.s2_image_size) first.

class CLIPVisionTowerS2(CLIPVisionTower):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__(vision_tower, args, delay_load)

        self.s2_scales = getattr(args, 's2_scales', '336,672,1008')
        self.s2_scales = list(map(int, self.s2_scales.split(',')))
        self.s2_scales.sort()
        self.s2_split_size = self.s2_scales[0]
        self.s2_image_size = self.s2_scales[-1]

        try:
            from s2wrapper import forward as multiscale_forward
        except ImportError:
            raise ImportError('Package s2wrapper not found! Please install by running: \npip install git+https://github.com/bfshi/scaling_on_scales.git')
        self.multiscale_forward = multiscale_forward

        # change resize/crop size in preprocessing to the largest image size in s2_scale
        if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
            self.image_processor.size['shortest_edge'] = self.s2_image_size
            self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size

Then the constructor of CLIPVisionTower below will call self.load_model() method in case of delay_load == False, which will call self.load_model() method of CLIPVisionTowerS2, in accordance with Python MRO.

class CLIPVisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        if not delay_load:
            self.load_model()
        elif getattr(args, 'unfreeze_mm_vision_tower', False):
            self.load_model()
        else:
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)

And here I encounter error AttributeError: 'CLIPVisionTowerS2' object has no attribute 's2_image_size' because of accessing attribute not defined (which was scheduled to initialize after super class's constructor). In load_model method of CLIPVisionTowerS2, below.

    def load_model(self, device_map=None):
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        self.vision_tower.requires_grad_(False)

        self.image_processor.size['shortest_edge'] = self.s2_image_size
        self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size

        self.is_loaded = True

Switching the order of initialization of attributes (self.s2_scales, self.s2_split_size, self.s2_image_size) and super().__init__(vision_tower, args, delay_load) will do, but I wonder if it might cause other errors on your current/later releases

@bfshi
Copy link
Contributor Author

bfshi commented Apr 26, 2024

Thanks for pointing this out! Yeah, probably the simplest fix is defining self.s2_scales, self.s2_split_size, self.s2_image_size before calling super().__init__(vision_tower, args, delay_load). In the meantime this range can be deleted. Will update this.

@baichuanzhou
Copy link

I also found this minor bug, and I submitted this PR for a quick fix.

@VldmrB
Copy link

VldmrB commented Apr 30, 2024

Hello,

Is this meant to support quantization? When I load a model with load_8bit=True (bitsandbytes), I get this error on inference:

File E:\llava1.6-ui\venv\Lib\site-packages\bitsandbytes\functional.py:2290, in igemmlt(A, B, SA, SB, out, Sout, dtype)
   2288 assert SB[1] in ["col_turing", "col_ampere"]
   2289 assert Sout[1] == "col32"
-> 2290 assert (
   2291     shapeA[-1] == shapeB[-1]
   2292 ), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}"
   2293 formatB = SB[1]
   2294 prev_device = A.device

AssertionError: Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = torch.Size([576, 3072]) @ torch.Size([7168, 1024])

Whereas CLIPVisionTower works fine

Full stack trace
Cell In[3], line 163, in Context.generate(self, print_output, print_context)
    161 print('Inferring...')
    162 with torch.inference_mode():
--> 163     output_ids = model.generate(self.ready_tokens,
    164                                 images=self.image,
    165                                 do_sample=False,
    166                                 # temperature=0.0,
    167                                 max_new_tokens=16384,
    168                                 use_cache=True,
    169                                 # stopping_criteria=stopping_criteria
    170                                 )
    172 output_str = (tokenizer.decode(output_ids[0][:]).strip()
    173               .removeprefix('<|startoftext|>').lstrip())
    175 if self.context and self.context[-1].preset_reply:
    176     # Reset its unfinished state

File E:\llava1.6-ui\venv\Lib\site-packages\torch\utils\_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File E:\llava1.6-ui\llava\model\language_model\llava_llama.py:125, in LlavaLlamaForCausalLM.generate(self, inputs, images, image_sizes, **kwargs)
    115     raise NotImplementedError("`inputs_embeds` is not supported")
    117 if images is not None:
    118     (
    119         inputs,
    120         position_ids,
    121         attention_mask,
    122         _,
    123         inputs_embeds,
    124         _
--> 125     ) = self.prepare_inputs_labels_for_multimodal(
    126         inputs,
    127         position_ids,
    128         attention_mask,
    129         None,
    130         None,
    131         images,
    132         image_sizes=image_sizes
    133     )
    134 else:
    135     inputs_embeds = self.get_model().embed_tokens(inputs)

File E:\llava1.6-ui\llava\model\llava_arch.py:202, in LlavaMetaForCausalLM.prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
    200         raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
    201 else:
--> 202     image_features = self.encode_images(images)
    204 # TODO: image start / end is not implemented here to support pretraining.
    205 if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):

File E:\llava1.6-ui\llava\model\llava_arch.py:142, in LlavaMetaForCausalLM.encode_images(self, images)
    140 def encode_images(self, images):
    141     image_features = self.get_model().get_vision_tower()(images)
--> 142     image_features = self.get_model().mm_projector(image_features)
    143     return image_features

File E:\llava1.6-ui\venv\Lib\site-packages\torch\nn\modules\module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File E:\llava1.6-ui\venv\Lib\site-packages\torch\nn\modules\module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File E:\llava1.6-ui\venv\Lib\site-packages\accelerate\hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File E:\llava1.6-ui\venv\Lib\site-packages\torch\nn\modules\container.py:217, in Sequential.forward(self, input)
    215 def forward(self, input):
    216     for module in self:
--> 217         input = module(input)
    218     return input

File E:\llava1.6-ui\venv\Lib\site-packages\torch\nn\modules\module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File E:\llava1.6-ui\venv\Lib\site-packages\torch\nn\modules\module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File E:\llava1.6-ui\venv\Lib\site-packages\accelerate\hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File E:\llava1.6-ui\venv\Lib\site-packages\bitsandbytes\nn\modules.py:797, in Linear8bitLt.forward(self, x)
    794 if self.bias is not None and self.bias.dtype != x.dtype:
    795     self.bias.data = self.bias.data.to(x.dtype)
--> 797 out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
    799 if not self.state.has_fp16_weights:
    800     if self.state.CB is not None and self.state.CxB is not None:
    801         # we converted 8-bit row major to turing/ampere format in the first inference pass
    802         # we no longer need the row-major weight

File E:\llava1.6-ui\venv\Lib\site-packages\bitsandbytes\autograd\_functions.py:556, in matmul(A, B, out, state, threshold, bias)
    554 if threshold > 0.0:
    555     state.threshold = threshold
--> 556 return MatMul8bitLt.apply(A, B, out, bias, state)

File E:\llava1.6-ui\venv\Lib\site-packages\torch\autograd\function.py:553, in Function.apply(cls, *args, **kwargs)
    550 if not torch._C._are_functorch_transforms_active():
    551     # See NOTE: [functorch vjp and autograd interaction]
    552     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 553     return super().apply(*args, **kwargs)  # type: ignore[misc]
    555 if not is_setup_ctx_defined:
    556     raise RuntimeError(
    557         "In order to use an autograd.Function with functorch transforms "
    558         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    559         "staticmethod. For more details, please see "
    560         "https://pytorch.org/docs/master/notes/extending.func.html"
    561     )

File E:\llava1.6-ui\venv\Lib\site-packages\bitsandbytes\autograd\_functions.py:395, in MatMul8bitLt.forward(ctx, A, B, out, bias, state)
    393 if using_igemmlt:
    394     C32A, SA = F.transform(CA, "col32")
--> 395     out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
    396     if bias is None or bias.dtype == torch.float16:
    397         # we apply the fused bias here
    398         output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)

File E:\llava1.6-ui\venv\Lib\site-packages\bitsandbytes\functional.py:2290, in igemmlt(A, B, SA, SB, out, Sout, dtype)
   2288 assert SB[1] in ["col_turing", "col_ampere"]
   2289 assert Sout[1] == "col32"
-> 2290 assert (
   2291     shapeA[-1] == shapeB[-1]
   2292 ), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}"
   2293 formatB = SB[1]
   2294 prev_device = A.device

AssertionError: Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = torch.Size([576, 3072]) @ torch.Size([7168, 1024])

@bfshi
Copy link
Contributor Author

bfshi commented Apr 30, 2024

Hi @VldmrB, are you evaluating with CLIPVisionTowerS2? Did you train a model with S2 or did you use the LLaVA checkpoint trained without S2?

@VldmrB
Copy link

VldmrB commented May 1, 2024

Hi @VldmrB, are you evaluating with CLIPVisionTowerS2? Did you train a model with S2 or did you use the LLaVA checkpoint trained without S2?

Hello
I'm only evaluating/inferring with S2, I did not train any model with it. Oh...

Yes that's correct. Currently this PR only supports running and evaluating an already pre-trained LLaVA with S2

I just realized that you already mentioned that it will only work with a model trained with S2. My apologies, I misread it earlier. Thanks for doing this, in any case

@yanbai1993
Copy link

Hi @bfshi ,Thaks for your great work. Has the checkpoint for LLaVA1.5 with S2 been released yet?

@bfshi
Copy link
Contributor Author

bfshi commented May 15, 2024

Hi @yanbai1993,

Yes! Please see here in the S2 repo.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants