diff --git a/llava/model/multimodal_encoder/clip_encoder.py b/llava/model/multimodal_encoder/clip_encoder.py index 2c81415cd..6bfbc3a9e 100644 --- a/llava/model/multimodal_encoder/clip_encoder.py +++ b/llava/model/multimodal_encoder/clip_encoder.py @@ -91,14 +91,13 @@ def num_patches(self): class CLIPVisionTowerS2(CLIPVisionTower): def __init__(self, vision_tower, args, delay_load=False): - super().__init__(vision_tower, args, delay_load) - self.s2_scales = getattr(args, 's2_scales', '336,672,1008') self.s2_scales = list(map(int, self.s2_scales.split(','))) self.s2_scales.sort() self.s2_split_size = self.s2_scales[0] self.s2_image_size = self.s2_scales[-1] - + super().__init__(vision_tower, args, delay_load) + try: from s2wrapper import forward as multiscale_forward except ImportError: