From 510f7ae8e81c164a02cf2d18513588be67138b31 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 13 Sep 2024 12:28:56 +0000
Subject: [PATCH] pixtral

---
 src/mistral_inference/transformer.py          |  2 +-
 ...sformer_utils.py => transformer_layers.py} |  0
 src/mistral_inference/vision_encoder.py       | 60 ++++++++++---------
 3 files changed, 32 insertions(+), 30 deletions(-)
 rename src/mistral_inference/{transformer_utils.py => transformer_layers.py} (100%)

diff --git a/src/mistral_inference/transformer.py b/src/mistral_inference/transformer.py
index 2894c07..9c9aebe 100644
--- a/src/mistral_inference/transformer.py
+++ b/src/mistral_inference/transformer.py
@@ -14,7 +14,7 @@
 from mistral_inference.lora import LoRALoaderMixin
 from mistral_inference.model import ModelBase
 from mistral_inference.rope import precompute_freqs_cis
-from mistral_inference.transformer_utils import RMSNorm, TransformerBlock
+from mistral_inference.transformer_layers import RMSNorm, TransformerBlock
 from mistral_inference.vision_encoder import VisionLanguageAdapter, VisionTransformer
 
 
diff --git a/src/mistral_inference/transformer_utils.py b/src/mistral_inference/transformer_layers.py
similarity index 100%
rename from src/mistral_inference/transformer_utils.py
rename to src/mistral_inference/transformer_layers.py
diff --git a/src/mistral_inference/vision_encoder.py b/src/mistral_inference/vision_encoder.py
index fcca530..833cbb6 100644
--- a/src/mistral_inference/vision_encoder.py
+++ b/src/mistral_inference/vision_encoder.py
@@ -6,34 +6,7 @@
 
 from mistral_inference.args import VisionEncoderArgs
 from mistral_inference.rope import precompute_freqs_cis_2d
-from mistral_inference.transformer_utils import RMSNorm, TransformerBlock
-
-
-class Transformer(nn.Module):
-    def __init__(self, args: VisionEncoderArgs):
-        super().__init__()
-        self.layers = torch.nn.ModuleList()
-        for _ in range(args.num_hidden_layers):
-            self.layers.append(
-                TransformerBlock(
-                    dim=args.hidden_size,
-                    hidden_dim=args.intermediate_size,
-                    n_heads=args.num_attention_heads,
-                    n_kv_heads=args.num_attention_heads,
-                    head_dim=args.hidden_size // args.num_attention_heads,
-                    norm_eps=1e-5,
-                )
-            )
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        mask: BlockDiagonalMask,
-        freqs_cis: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-        for layer in self.layers:
-            x = layer(x, mask=mask, freqs_cis=freqs_cis)
-        return x
+from mistral_inference.transformer_layers import RMSNorm, TransformerBlock
 
 
 def position_meshgrid(
@@ -67,7 +40,7 @@ def __init__(self, args: VisionEncoderArgs):
             bias=False,
         )
         self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
-        self.transformer = Transformer(args)
+        self.transformer = VisionTransformerBlocks(args)
 
         head_dim = self.args.hidden_size // self.args.num_attention_heads
         assert head_dim % 2 == 0, "ROPE requires even head_dim"
@@ -142,3 +115,32 @@ def __init__(self, in_dim: int, out_dim: int):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.w_out(self.gelu(self.w_in(x)))  # type: ignore[no-any-return]
+
+
+class VisionTransformerBlocks(nn.Module):
+    def __init__(self, args: VisionEncoderArgs):
+        super().__init__()
+        self.layers = torch.nn.ModuleList()
+        for _ in range(args.num_hidden_layers):
+            self.layers.append(
+                TransformerBlock(
+                    dim=args.hidden_size,
+                    hidden_dim=args.intermediate_size,
+                    n_heads=args.num_attention_heads,
+                    n_kv_heads=args.num_attention_heads,
+                    head_dim=args.hidden_size // args.num_attention_heads,
+                    norm_eps=1e-5,
+                )
+            )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        mask: BlockDiagonalMask,
+        freqs_cis: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        for layer in self.layers:
+            x = layer(x, mask=mask, freqs_cis=freqs_cis)
+        return x
+
+