From 26a52a14eb5019ef0d40a822c1fe46456d6c378b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 16 Jul 2024 12:16:39 +0200 Subject: [PATCH] add mamba --- README.md | 4 +- moe_one_file_ref.py | 64 ++--- one_file_ref.py | 56 ++-- src/mistral_inference/args.py | 49 ++++ src/mistral_inference/cache.py | 53 +--- src/mistral_inference/generate.py | 53 +++- src/mistral_inference/lora.py | 57 ++-- src/mistral_inference/main.py | 80 ++++-- src/mistral_inference/mamba.py | 83 ++++++ src/mistral_inference/model.py | 400 ++------------------------- src/mistral_inference/moe.py | 8 +- src/mistral_inference/transformer.py | 367 ++++++++++++++++++++++++ tests/test_generate.py | 58 ++-- tutorials/getting_started.ipynb | 2 +- 14 files changed, 720 insertions(+), 614 deletions(-) create mode 100644 src/mistral_inference/args.py create mode 100644 src/mistral_inference/mamba.py create mode 100644 src/mistral_inference/transformer.py diff --git a/README.md b/README.md index 79ec2f4..090dc45 100644 --- a/README.md +++ b/README.md @@ -155,7 +155,7 @@ You can continue chatting afterwards, *e.g.* with *"Translate it to Python"*. - *Instruction Following*: ```py -from mistral_inference.model import Transformer +from mistral_inference.transformer import Transformer from mistral_inference.generate import generate from mistral_common.tokens.tokenizers.mistral import MistralTokenizer @@ -228,7 +228,7 @@ pip install --upgrade mistral-common You can simulate a code completion in-filling as follows. ```py -from mistral_inference.model import Transformer +from mistral_inference.transformer import Transformer from mistral_inference.generate import generate from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.instruct.request import FIMRequest diff --git a/moe_one_file_ref.py b/moe_one_file_ref.py index aa9c4bb..542388e 100644 --- a/moe_one_file_ref.py +++ b/moe_one_file_ref.py @@ -22,7 +22,7 @@ class MoeArgs(Serializable): @dataclass -class ModelArgs(Serializable): +class TransformerArgs(Serializable): dim: int n_layers: int head_dim: int @@ -80,7 +80,7 @@ def apply_rotary_emb( class Attention(nn.Module): - def __init__(self, args: ModelArgs): + def __init__(self, args: TransformerArgs): super().__init__() self.args = args @@ -144,9 +144,7 @@ def forward( xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) # Update cache - scatter_pos = positions[None, :, None, None].repeat( - bsz, 1, self.n_kv_heads, self.args.head_dim - ) + scatter_pos = positions[None, :, None, None].repeat(bsz, 1, self.n_kv_heads, self.args.head_dim) cache_k[:bsz].scatter_(dim=1, index=scatter_pos, src=xk) cache_v[:bsz].scatter_(dim=1, index=scatter_pos, src=xv) @@ -179,7 +177,7 @@ def forward( class FeedForward(nn.Module): - def __init__(self, args: ModelArgs): + def __init__(self, args: TransformerArgs): super().__init__() self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False) self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False) @@ -214,9 +212,7 @@ def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs) def forward(self, inputs: torch.Tensor): inputs_squashed = inputs.view(-1, inputs.shape[-1]) gate_logits = self.gate(inputs_squashed) - weights, selected_experts = torch.topk( - gate_logits, self.args.num_experts_per_tok - ) + weights, selected_experts = torch.topk(gate_logits, self.args.num_experts_per_tok) weights = nn.functional.softmax( weights, dim=1, @@ -225,14 +221,12 @@ def forward(self, inputs: torch.Tensor): results = 
torch.zeros_like(inputs_squashed) for i, expert in enumerate(self.experts): batch_idx, nth_expert = torch.where(selected_experts == i) - results[batch_idx] += weights[batch_idx, nth_expert, None] * expert( - inputs_squashed[batch_idx] - ) + results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(inputs_squashed[batch_idx]) return results.view_as(inputs) class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): + def __init__(self, args: TransformerArgs): super().__init__() self.n_heads = args.n_heads self.dim = args.dim @@ -270,7 +264,7 @@ def precompute_freqs_cis(dim: int, end: int, theta: float) -> torch.Tensor: class Transformer(nn.Module): def __init__( self, - args: ModelArgs, + args: TransformerArgs, pipeline_rank: int = 0, num_pipeline_ranks: int = 1, ): @@ -316,13 +310,9 @@ def freqs_cis(self) -> torch.Tensor: # from the module's dtype means we cannot register it as a buffer if self._precomputed_freqs_cis is None: theta = self.args.rope_theta or 1000000.0 - self._precomputed_freqs_cis = precompute_freqs_cis( - self.args.head_dim, 128_000, theta - ) + self._precomputed_freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000, theta) if self._precomputed_freqs_cis.device != self.device: - self._precomputed_freqs_cis = self._precomputed_freqs_cis.to( - device=self.device - ) + self._precomputed_freqs_cis = self._precomputed_freqs_cis.to(device=self.device) return self._precomputed_freqs_cis def forward( @@ -341,9 +331,7 @@ def forward( assert h.shape == (bsz, seqlen, self.args.dim) assert h.dtype == self.dtype else: - h = torch.empty( - bsz, seqlen, self.args.dim, device=self.device, dtype=self.dtype - ) + h = torch.empty(bsz, seqlen, self.args.dim, device=self.device, dtype=self.dtype) torch.distributed.recv(h, src=self.pipeline_rank - 1) mask: Optional[torch.Tensor] = None @@ -361,9 +349,7 @@ def forward( if self.pipeline_rank < self.num_pipeline_ranks - 1: torch.distributed.send(h, dst=self.pipeline_rank + 1) - outs = torch.empty( - *h.shape[:-1], self.vocab_size, device=h.device, dtype=h.dtype - ) + outs = torch.empty(*h.shape[:-1], self.vocab_size, device=h.device, dtype=h.dtype) else: assert self.output is not None assert self.norm is not None @@ -422,7 +408,7 @@ def from_folder( dtype=torch.float16, ) -> "Transformer": with open(folder / "params.json", "r") as f: - model_args = ModelArgs.from_dict(json.load(f)) + model_args = TransformerArgs.from_dict(json.load(f)) model_args.max_batch_size = max_batch_size model_args.max_seq_len = max_seq_len if num_pipeline_ranks > 1: @@ -457,9 +443,7 @@ def from_folder( def load_tokenizer(model_path: Path) -> MistralTokenizer: - tokenizer = [ - f for f in os.listdir(Path(model_path)) if f.startswith("tokenizer.model") - ] + tokenizer = [f for f in os.listdir(Path(model_path)) if f.startswith("tokenizer.model")] assert ( len(tokenizer) == 1 ), f"Multiple tokenizers {', '.join(tokenizer)} found in `model_path`, make sure to only have one tokenizer" @@ -470,12 +454,8 @@ def load_tokenizer(model_path: Path) -> MistralTokenizer: @torch.no_grad() -def generate( - prompts: List[str], model: Transformer, tokenizer: Tokenizer, max_tokens: int -): - encoded_prompts = [ - tokenizer.encode(prompt, bos=True, eos=False) for prompt in prompts - ] +def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, max_tokens: int): + encoded_prompts = [tokenizer.encode(prompt, bos=True, eos=False) for prompt in prompts] prompt_lens = [len(x) for x in encoded_prompts] min_prompt_len = min(prompt_lens) max_prompt_len = 
max(prompt_lens) @@ -498,23 +478,17 @@ def generate( # decode generated = [] all_logprobs = [ - logprobs[:, :-1, :] - .gather(2, input_tokens[:, 1:min_prompt_len, None]) - .squeeze(-1), + logprobs[:, :-1, :].gather(2, input_tokens[:, 1:min_prompt_len, None]).squeeze(-1), ] for cur_pos in range(min_prompt_len, max_tokens): next_token = torch.argmax(logprobs[:, -1, :], dim=-1) if cur_pos < input_mask.shape[1]: - next_token = torch.where( - input_mask[:, cur_pos], input_tokens[:, cur_pos], next_token - ) + next_token = torch.where(input_mask[:, cur_pos], input_tokens[:, cur_pos], next_token) all_logprobs.append( logprobs[:, -1, :].gather(1, next_token[:, None]), ) generated.append(next_token[:, None]) - logits = model.forward( - next_token[:, None], torch.LongTensor([cur_pos]).to(next_token) - ) + logits = model.forward(next_token[:, None], torch.LongTensor([cur_pos]).to(next_token)) logprobs = nn.functional.log_softmax(logits, dim=-1) all_logprobs_merged = torch.cat(all_logprobs, 1) diff --git a/one_file_ref.py b/one_file_ref.py index b654e78..a848d73 100644 --- a/one_file_ref.py +++ b/one_file_ref.py @@ -14,7 +14,7 @@ @dataclass -class ModelArgs(Serializable): +class TransformerArgs(Serializable): dim: int n_layers: int head_dim: int @@ -31,9 +31,7 @@ class ModelArgs(Serializable): max_batch_size: int = 0 -def repeat_kv( - keys: torch.Tensor, values: torch.Tensor, repeats: int -) -> Tuple[torch.Tensor]: +def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int) -> Tuple[torch.Tensor]: keys = torch.repeat_interleave(keys, repeats=repeats, dim=2) values = torch.repeat_interleave(values, repeats=repeats, dim=2) return keys, values @@ -68,7 +66,7 @@ def apply_rotary_emb( class Attention(nn.Module): - def __init__(self, args: ModelArgs): + def __init__(self, args: TransformerArgs): super().__init__() self.args = args @@ -118,9 +116,7 @@ def forward( xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) # cache - scatter_pos = positions[None, :, None, None].repeat( - bsz, 1, self.n_kv_heads, self.args.head_dim - ) + scatter_pos = positions[None, :, None, None].repeat(bsz, 1, self.n_kv_heads, self.args.head_dim) self.cache_k[:bsz].scatter_(dim=1, index=scatter_pos, src=xk) self.cache_v[:bsz].scatter_(dim=1, index=scatter_pos, src=xv) @@ -152,7 +148,7 @@ def forward( class FeedForward(nn.Module): - def __init__(self, args: ModelArgs): + def __init__(self, args: TransformerArgs): super().__init__() self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False) @@ -178,7 +174,7 @@ def forward(self, x): class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): + def __init__(self, args: TransformerArgs): super().__init__() self.n_heads = args.n_heads self.dim = args.dim @@ -210,7 +206,7 @@ def precompute_freqs_cis(dim: int, end: int, theta: float) -> torch.Tensor: class Transformer(nn.Module): - def __init__(self, args: ModelArgs): + def __init__(self, args: TransformerArgs): super().__init__() self.args = args self.vocab_size = args.vocab_size @@ -219,18 +215,14 @@ def __init__(self, args: ModelArgs): self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) - self.layers = torch.nn.ModuleList( - [TransformerBlock(args=args) for _ in range(args.n_layers)] - ) + self.layers = torch.nn.ModuleList([TransformerBlock(args=args) for _ in range(args.n_layers)]) self.norm = RMSNorm(args.dim, eps=args.norm_eps) self.output = nn.Linear(args.dim, args.vocab_size, bias=False) theta = self.args.rope_theta or 1000000.0 - self.freqs_cis = precompute_freqs_cis(self.args.head_dim, 
128_000, theta).to( - "cuda" - ) + self.freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000, theta).to("cuda") def forward( self, @@ -259,11 +251,9 @@ def forward( return self.output(self.norm(h)).float() @staticmethod - def from_folder( - folder: Path, max_batch_size: int = 1, device="cuda", dtype=torch.float16 - ): + def from_folder(folder: Path, max_batch_size: int = 1, device="cuda", dtype=torch.float16): with open(Path(folder) / "params.json", "r") as f: - model_args = ModelArgs.from_dict(json.load(f)) + model_args = TransformerArgs.from_dict(json.load(f)) model_args.max_batch_size = max_batch_size model = Transformer(model_args) @@ -288,9 +278,7 @@ def from_folder( def load_tokenizer(model_path: Path) -> MistralTokenizer: - tokenizer = [ - f for f in os.listdir(Path(model_path)) if f.startswith("tokenizer.model") - ] + tokenizer = [f for f in os.listdir(Path(model_path)) if f.startswith("tokenizer.model")] assert ( len(tokenizer) > 0 ), f"No tokenizer found in {model_path}, make sure to place a `tokenizer.model.[v1,v2,v3]` file in {model_path}." @@ -304,12 +292,8 @@ def load_tokenizer(model_path: Path) -> MistralTokenizer: @torch.no_grad() -def generate( - prompts: List[str], model: Transformer, tokenizer: Tokenizer, max_tokens: int -): - encoded_prompts = [ - tokenizer.encode(prompt, bos=True, eos=False) for prompt in prompts - ] +def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, max_tokens: int): + encoded_prompts = [tokenizer.encode(prompt, bos=True, eos=False) for prompt in prompts] prompt_lens = [len(x) for x in encoded_prompts] min_prompt_len = min(prompt_lens) max_prompt_len = max(prompt_lens) @@ -333,24 +317,18 @@ def generate( # decode generated = [] all_logprobs = [ - logprobs[:, :-1, :] - .gather(2, input_tokens[:, 1:min_prompt_len, None]) - .squeeze(-1), + logprobs[:, :-1, :].gather(2, input_tokens[:, 1:min_prompt_len, None]).squeeze(-1), ] cur_pos = min_prompt_len for _ in range(max_tokens): next_token = torch.argmax(logprobs[:, -1, :], dim=-1) if cur_pos < input_mask.shape[1]: - next_token = torch.where( - input_mask[:, cur_pos], input_tokens[:, cur_pos], next_token - ) + next_token = torch.where(input_mask[:, cur_pos], input_tokens[:, cur_pos], next_token) all_logprobs.append( logprobs[:, -1, :].gather(1, next_token[:, None]), ) generated.append(next_token[:, None]) - logits = model.forward( - next_token[:, None], torch.LongTensor([cur_pos]).to(next_token) - ) + logits = model.forward(next_token[:, None], torch.LongTensor([cur_pos]).to(next_token)) logprobs = nn.functional.log_softmax(logits, dim=-1) cur_pos += 1 diff --git a/src/mistral_inference/args.py b/src/mistral_inference/args.py new file mode 100644 index 0000000..dbf9f81 --- /dev/null +++ b/src/mistral_inference/args.py @@ -0,0 +1,49 @@ +from dataclasses import dataclass +from typing import Optional + +from simple_parsing.helpers import Serializable + +from mistral_inference.lora import LoraArgs +from mistral_inference.moe import MoeArgs + + +@dataclass +class TransformerArgs(Serializable): + dim: int + n_layers: int + head_dim: int + hidden_dim: int + n_heads: int + n_kv_heads: int + norm_eps: float + vocab_size: int + + max_batch_size: int = 0 + + # For rotary embeddings. If not set, will be inferred + rope_theta: Optional[float] = None + # If this is set, we will use MoE layers instead of dense layers. + moe: Optional[MoeArgs] = None + # If this is set, we will load LoRA linear layers instead of linear layers. 
+ lora: Optional[LoraArgs] = None + model_type: str = "transformer" + + def __post_init__(self): + assert self.model_type == "transformer", self.model_type + + +@dataclass +class MambaArgs(Serializable): + dim: int + n_layers: int + vocab_size: int + n_groups: int + rms_norm: bool + residual_in_fp32: bool + fused_add_norm: bool + pad_vocab_size_multiple: int + tie_embeddings: bool + model_type: str = "mamba" + + def __post_init__(self): + assert self.model_type == "mamba", self.model_type diff --git a/src/mistral_inference/cache.py b/src/mistral_inference/cache.py index a3b4725..93cfb1c 100644 --- a/src/mistral_inference/cache.py +++ b/src/mistral_inference/cache.py @@ -24,9 +24,7 @@ class CacheInputMetadata: seqlens: List[int] -def interleave_list( - l1: List[torch.Tensor], l2: List[torch.Tensor] -) -> List[torch.Tensor]: +def interleave_list(l1: List[torch.Tensor], l2: List[torch.Tensor]) -> List[torch.Tensor]: assert len(l1) == len(l2) return [v for pair in zip(l1, l2) for v in pair] @@ -55,9 +53,7 @@ def update(self, xk: torch.Tensor, xv: torch.Tensor) -> None: flat_cache_k.index_copy_(0, self.metadata.cache_positions, xk) flat_cache_v.index_copy_(0, self.metadata.cache_positions, xv) - def interleave_kv( - self, xk: torch.Tensor, xv: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + def interleave_kv(self, xk: torch.Tensor, xv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ This is a naive implementation and not optimized for speed. """ @@ -71,17 +67,11 @@ def interleave_kv( # Make it a list of [(T, H, D)] xk: Tuple[torch.Tensor] = torch.split(xk, self.metadata.seqlens) # type: ignore xv: Tuple[torch.Tensor] = torch.split(xv, self.metadata.seqlens) # type: ignore - assert len(xk) == len( - self.kv_seqlens - ), f"Batch size is {len(self.kv_seqlens)}, got {len(xk)}" + assert len(xk) == len(self.kv_seqlens), f"Batch size is {len(self.kv_seqlens)}, got {len(xk)}" # Retrieve cache - cache_k = [ - cache_k[:seq_len] for cache_k, seq_len in zip(self.cache_k, self.kv_seqlens) - ] - cache_v = [ - cache_v[:seq_len] for cache_v, seq_len in zip(self.cache_v, self.kv_seqlens) - ] + cache_k = [cache_k[:seq_len] for cache_k, seq_len in zip(self.cache_k, self.kv_seqlens)] + cache_v = [cache_v[:seq_len] for cache_v, seq_len in zip(self.cache_v, self.kv_seqlens)] interleaved_k = interleave_list(cache_k, list(xk)) interleaved_v = interleave_list(cache_v, list(xv)) @@ -127,28 +117,20 @@ def __init__( self.n_kv_heads = n_kv_heads self.head_dim = head_dim - self.cache_k = torch.empty( - (n_layers, max_batch_size, max_seq_len, n_kv_heads, head_dim) - ) - self.cache_v = torch.empty( - (n_layers, max_batch_size, max_seq_len, n_kv_heads, head_dim) - ) + self.cache_k = torch.empty((n_layers, max_batch_size, max_seq_len, n_kv_heads, head_dim)) + self.cache_v = torch.empty((n_layers, max_batch_size, max_seq_len, n_kv_heads, head_dim)) # holds the valid length for each batch element in the cache self.kv_seqlens: Optional[torch.Tensor] = None def get_view(self, layer_id: int, metadata: CacheInputMetadata) -> CacheView: assert self.kv_seqlens is not None - return CacheView( - self.cache_k[layer_id], self.cache_v[layer_id], metadata, self.kv_seqlens - ) + return CacheView(self.cache_k[layer_id], self.cache_v[layer_id], metadata, self.kv_seqlens) def reset(self) -> None: self.kv_seqlens = None def init_kvseqlens(self, batch_size: int) -> None: - self.kv_seqlens = torch.zeros( - (batch_size,), device=self.device, dtype=torch.long - ) + self.kv_seqlens = torch.zeros((batch_size,), device=self.device, 
dtype=torch.long) @property def device(self) -> torch.device: @@ -180,9 +162,9 @@ def get_input_metadata(self, seqlens: List[int]) -> CacheInputMetadata: assert len(seqlens) > 0, seqlens cached_elements = torch.tensor(seqlens, device=self.device, dtype=torch.long) - positions = torch.cat( - [torch.arange(pos, pos + seqlen) for pos, seqlen in zip(seqpos, seqlens)] - ).to(device=self.device, dtype=torch.long) + positions = torch.cat([torch.arange(pos, pos + seqlen) for pos, seqlen in zip(seqpos, seqlens)]).to( + device=self.device, dtype=torch.long + ) batch_idx = torch.tensor( sum([[i] * seqlen for i, seqlen in enumerate(seqlens)], []), @@ -195,24 +177,19 @@ def get_input_metadata(self, seqlens: List[int]) -> CacheInputMetadata: subsequent_prefill = any(seqlen > 1 for seqlen in seqlens) if first_prefill: assert all([pos == 0 for pos in seqpos]), seqpos - mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention( - self.max_seq_len - ) + mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(self.max_seq_len) elif subsequent_prefill: mask = BlockDiagonalMask.from_seqlens( q_seqlen=seqlens, kv_seqlen=[ - s + cached_s.clamp(max=self.max_seq_len).item() - for (s, cached_s) in zip(seqlens, self.kv_seqlens) + s + cached_s.clamp(max=self.max_seq_len).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens) ], ).make_local_attention_from_bottomright(self.max_seq_len) else: mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( q_seqlen=seqlens, kv_padding=self.max_seq_len, - kv_seqlen=(self.kv_seqlens + cached_elements) - .clamp(max=self.max_seq_len) - .tolist(), + kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=self.max_seq_len).tolist(), ) return CacheInputMetadata( diff --git a/src/mistral_inference/generate.py b/src/mistral_inference/generate.py index c7f0bca..c9e35c5 100644 --- a/src/mistral_inference/generate.py +++ b/src/mistral_inference/generate.py @@ -3,7 +3,40 @@ import torch from mistral_inference.cache import BufferCache -from mistral_inference.model import Transformer +from mistral_inference.mamba import Mamba +from mistral_inference.transformer import Transformer + + +@torch.inference_mode() +def generate_mamba( + encoded_prompts: List[List[int]], + model: Mamba, + *, + max_tokens: int, + temperature: float, + chunk_size: Optional[int] = None, + eos_id: Optional[int] = None, +) -> Tuple[List[List[int]], List[List[float]]]: + input_ids = torch.tensor(encoded_prompts, device=model.device) + output = model.model.generate( + input_ids=input_ids, + max_length=input_ids.shape[-1] + max_tokens, + cg=True, + return_dict_in_generate=True, + output_scores=True, + enable_timing=False, + eos_token_id=eos_id, + temperature=temperature, + top_p=0.8, + ) + generated_tokens = output.sequences[:, input_ids.shape[-1] :].tolist() + + _logprobs: List[List[float]] = [[] for _ in range(len(generated_tokens))] + for seq_idx, batch_score in enumerate(output.scores): + for batch_idx, score in enumerate(batch_score.tolist()): + _logprobs[batch_idx].append(score[generated_tokens[batch_idx][seq_idx]]) + + return generated_tokens, _logprobs @torch.inference_mode() @@ -14,7 +47,7 @@ def generate( max_tokens: int, temperature: float, chunk_size: Optional[int] = None, - eos_id: Optional[int] = None + eos_id: Optional[int] = None, ) -> Tuple[List[List[int]], List[List[float]]]: model = model.eval() B, V = len(encoded_prompts), model.args.vocab_size @@ -57,26 +90,16 @@ def generate( # Pass > 1 last_token_logits = torch.log_softmax(last_token_prelogits, dim=-1) for 
i_seq in range(B): - logprobs[i_seq].append( - last_token_logits[i_seq, prompt_chunks[i_seq][0]].item() - ) + logprobs[i_seq].append(last_token_logits[i_seq, prompt_chunks[i_seq][0]].item()) offset = 0 for i_seq, sequence in enumerate(prompt_chunks): - logprobs[i_seq].extend( - [ - logits[offset + i, sequence[i + 1]].item() - for i in range(len(sequence) - 1) - ] - ) + logprobs[i_seq].extend([logits[offset + i, sequence[i + 1]].item() for i in range(len(sequence) - 1)]) offset += len(sequence) last_token_prelogits = prelogits.index_select( 0, - torch.tensor( - [len(p) for p in prompt_chunks], device=prelogits.device - ).cumsum(dim=0) - - 1, + torch.tensor([len(p) for p in prompt_chunks], device=prelogits.device).cumsum(dim=0) - 1, ) assert last_token_prelogits.shape == (B, V) diff --git a/src/mistral_inference/lora.py b/src/mistral_inference/lora.py index 2ab978c..3092429 100644 --- a/src/mistral_inference/lora.py +++ b/src/mistral_inference/lora.py @@ -1,7 +1,7 @@ import logging from dataclasses import dataclass from pathlib import Path -from typing import Dict, NamedTuple, Union +from typing import Any, Dict, NamedTuple, Union import safetensors.torch import torch @@ -14,7 +14,7 @@ class LoraArgs(Serializable): rank: int scaling: float - def __post_init__(self): + def __post_init__(self) -> None: assert self.rank > 0 assert self.scaling > 0.0 @@ -63,16 +63,17 @@ def __init__( self.linear = nn.Linear(self.in_features, self.out_features, bias=self.bias) # make sure no LoRA weights are marked as "missing" in load_state_dict - def ignore_missing_keys(m: nn.Module, incompatible_keys: NamedTuple): + def ignore_missing_keys(m: nn.Module, incompatible_keys: NamedTuple) -> None: incompatible_keys.missing_keys[:] = [] # type: ignore self.register_load_state_dict_post_hook(ignore_missing_keys) - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: lora = self.lora_B(self.lora_A(x)) - return self.linear(x) + lora * self.scaling + result: torch.Tensor = self.linear(x) + lora * self.scaling + return result - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + def _load_from_state_dict(self, state_dict: Dict[str, Any], prefix: str, *args, **kwargs) -> None: # type: ignore[no-untyped-def] key_name = prefix + "weight" # full checkpoint @@ -82,18 +83,14 @@ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): # load frozen weights state_dict = { "linear.weight": w_ref, - "lora_A.weight": torch.zeros_like( - self.lora_A.weight, device=w_ref.device, dtype=w_ref.dtype - ), - "lora_B.weight": torch.zeros_like( - self.lora_B.weight, device=w_ref.device, dtype=w_ref.dtype - ), + "lora_A.weight": torch.zeros_like(self.lora_A.weight, device=w_ref.device, dtype=w_ref.dtype), + "lora_B.weight": torch.zeros_like(self.lora_B.weight, device=w_ref.device, dtype=w_ref.dtype), } self.load_state_dict(state_dict, assign=True, strict=True) class LoRALoaderMixin: - def load_lora(self, lora_path: Union[Path, str], scaling: float = 2.0): + def load_lora(self, lora_path: Union[Path, str], scaling: float = 2.0) -> None: """Loads LoRA checkpoint""" lora_path = Path(lora_path) @@ -103,47 +100,39 @@ def load_lora(self, lora_path: Union[Path, str], scaling: float = 2.0): self._load_lora_state_dict(state_dict, scaling=scaling) - def _load_lora_state_dict( - self, lora_state_dict: Dict[str, torch.Tensor], scaling: float = 2.0 - ): + def _load_lora_state_dict(self, lora_state_dict: Dict[str, torch.Tensor], scaling: float = 2.0) -> None: """Loads LoRA state_dict""" 
- lora_dtypes = set([p.dtype for p in lora_state_dict.values()]) assert ( len(lora_dtypes) == 1 ), f"LoRA weights have multiple different dtypes {lora_dtypes}. All weights need to have the same dtype" lora_dtype = lora_dtypes.pop() - assert ( - lora_dtype == self.dtype - ), f"LoRA weights dtype differs from model's dtype {lora_dtype} != {self.dtype}" + assert lora_dtype == self.dtype, f"LoRA weights dtype differs from model's dtype {lora_dtype} != {self.dtype}" # type: ignore[attr-defined] assert all("lora" in key for key in lora_state_dict.keys()) # move tensors to device - lora_state_dict = {k: v.to(self.device) for k, v in lora_state_dict.items()} + lora_state_dict = {k: v.to(self.device) for k, v in lora_state_dict.items()} # type: ignore[attr-defined] - state_dict = self.state_dict() + state_dict = self.state_dict() # type: ignore[attr-defined] - if self.args.lora is None: + if self.args.lora is None: # type: ignore[attr-defined] logging.info("Loading and merging LoRA weights...") # replace every nn.Linear with a LoRALinear with 'meta' device except the output layer - named_modules = dict(self.named_modules()) + named_modules = dict(self.named_modules()) # type: ignore[attr-defined] for name, module in named_modules.items(): if isinstance(module, nn.Linear) and name != "output": layer_id = name.split(".")[1] - if layer_id not in self.layers: + if layer_id not in self.layers: # type: ignore[attr-defined] logging.debug( "Skipping parameter %s at pipeline rank %d", name, - self.pipeline_rank, + self.pipeline_rank, # type: ignore[attr-defined] ) - else: + elif (name + ".lora_B.weight") in lora_state_dict: weight = ( module.weight - + ( - lora_state_dict[name + ".lora_B.weight"] - @ lora_state_dict[name + ".lora_A.weight"] - ) + + (lora_state_dict[name + ".lora_B.weight"] @ lora_state_dict[name + ".lora_A.weight"]) * scaling ) @@ -154,13 +143,13 @@ def _load_lora_state_dict( state_dict.update(lora_state_dict) layer_id = k.split(".")[1] - if layer_id in self.layers: + if layer_id in self.layers: # type: ignore[attr-defined] state_dict[k] = v else: logging.debug( "Skipping parameter %s at pipeline rank %d", k, - self.pipeline_rank, + self.pipeline_rank, # type: ignore[attr-defined] ) - self.load_state_dict(state_dict, strict=True) + self.load_state_dict(state_dict, strict=True) # type: ignore[attr-defined] diff --git a/src/mistral_inference/main.py b/src/mistral_inference/main.py index a5ef3a0..9cafec8 100644 --- a/src/mistral_inference/main.py +++ b/src/mistral_inference/main.py @@ -1,7 +1,9 @@ +import json import logging import os +import warnings from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Type, Union import fire # type: ignore import torch @@ -11,8 +13,9 @@ from mistral_common.tokens.tokenizers.base import Tokenizer from mistral_common.tokens.tokenizers.mistral import MistralTokenizer -from mistral_inference.generate import generate -from mistral_inference.model import Transformer +from mistral_inference.generate import generate, generate_mamba +from mistral_inference.mamba import Mamba +from mistral_inference.transformer import Transformer def is_torchrun() -> bool: @@ -21,9 +24,7 @@ def is_torchrun() -> bool: def load_tokenizer(model_path: Path) -> MistralTokenizer: - tokenizer = [ - f for f in os.listdir(Path(model_path)) if f.startswith("tokenizer.model") - ] + tokenizer = [f for f in os.listdir(Path(model_path)) if f.startswith("tokenizer.model")] assert ( len(tokenizer) > 0 ), f"No tokenizer found in {model_path}, make sure to 
place a `tokenizer.model.[v1,v2,v3]` file in {model_path}." @@ -33,13 +34,28 @@ def load_tokenizer(model_path: Path) -> MistralTokenizer: mistral_tokenizer = MistralTokenizer.from_file(str(model_path / tokenizer[0])) - logging.info( - f"Loaded tokenizer of type {mistral_tokenizer.instruct_tokenizer.__class__}" - ) + logging.info(f"Loaded tokenizer of type {mistral_tokenizer.instruct_tokenizer.__class__}") return mistral_tokenizer +def get_model_cls(model_path: str) -> Union[Type[Mamba], Type[Transformer]]: + with open(Path(model_path) / "params.json", "r") as f: + args_dict = json.load(f) + + return {"mamba": Mamba, "transformer": Transformer}[args_dict.get("model_type", "transformer")] # type: ignore[return-value] + + +def pad_and_convert_to_tensor(list_of_lists: List[List[int]], pad_id: int) -> List[List[int]]: + # Determine the length of the longest list + max_len = max(len(lst) for lst in list_of_lists) + + # Left pad each list to the maximum length + padded_lists = [[pad_id] * (max_len - len(lst)) + lst for lst in list_of_lists] + + return padded_lists + + def interactive( model_path: str, max_tokens: int = 35, @@ -61,13 +77,12 @@ def interactive( mistral_tokenizer: MistralTokenizer = load_tokenizer(Path(model_path)) tokenizer: Tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer - transformer = Transformer.from_folder( - Path(model_path), max_batch_size=3, num_pipeline_ranks=num_pipeline_ranks - ) + model_cls = get_model_cls(model_path) + model = model_cls.from_folder(Path(model_path), max_batch_size=3, num_pipeline_ranks=num_pipeline_ranks) # load LoRA if lora_path is not None: - transformer.load_lora(Path(lora_path)) + model.load_lora(Path(lora_path)) prompt: str = "" messages: List[UserMessage | AssistantMessage] = [] @@ -80,9 +95,7 @@ def interactive( messages += [UserMessage(content=user_input)] chat_completion_request = ChatCompletionRequest(messages=messages) - tokens = mistral_tokenizer.encode_chat_completion( - chat_completion_request - ).tokens + tokens = mistral_tokenizer.encode_chat_completion(chat_completion_request).tokens else: prompt += user_input @@ -98,9 +111,10 @@ def interactive( if not should_print: tokens = int(length_tensor.item()) * [0] - generated_tokens, _ = generate( + generate_fn = generate if isinstance(model, Transformer) else generate_mamba + generated_tokens, _ = generate_fn( # type: ignore[operator] [tokens], - transformer, + model, max_tokens=max_tokens, temperature=temperature, eos_id=tokenizer.eos_id, @@ -134,12 +148,11 @@ def demo( should_print = True num_pipeline_ranks = 1 - transformer = Transformer.from_folder( - Path(model_path), max_batch_size=3, num_pipeline_ranks=num_pipeline_ranks - ) + model_cls = get_model_cls(model_path) + model = model_cls.from_folder(Path(model_path), max_batch_size=3, num_pipeline_ranks=num_pipeline_ranks) # load LoRA if lora_path is not None: - transformer.load_lora(Path(lora_path)) + model.load_lora(Path(lora_path)) mistral_tokenizer: MistralTokenizer = load_tokenizer(Path(model_path)) tokenizer: Tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer @@ -150,21 +163,30 @@ def demo( "This is a third test, mistral AI is very good at testing. 
", ] - encoded_prompts = [ - tokenizer.encode(prompt, bos=True, eos=False) for prompt in prompts - ] + encoded_prompts = [tokenizer.encode(prompt, bos=True, eos=False) for prompt in prompts] + + if isinstance(model, Transformer): + generate_fn = generate + else: + generate_fn = generate_mamba # type: ignore[assignment] + warnings.warn( + "Batched generation is not correctly supported at the moment and therefore might lead to worse results " + "as compared to non-batched generation. " + "See https://github.com/state-spaces/mamba/issues/66#issuecomment-1862349718 for more information." + ) + encoded_prompts = pad_and_convert_to_tensor(encoded_prompts, mistral_tokenizer.instruct_tokenizer.BOS) # type: ignore[attr-defined] - generated_tokens, _logprobs = generate( + generated_tokens, _logprobs = generate_fn( encoded_prompts, - transformer, + model, # type: ignore[arg-type] max_tokens=max_tokens, temperature=temperature, eos_id=tokenizer.eos_id, ) generated_words = [] - for i, x in enumerate(encoded_prompts): - generated_words.append(tokenizer.decode(x + generated_tokens[i])) + for i, x in enumerate(generated_tokens): + generated_words.append(tokenizer.decode(encoded_prompts[i] + x)) res = generated_words diff --git a/src/mistral_inference/mamba.py b/src/mistral_inference/mamba.py new file mode 100644 index 0000000..02745e3 --- /dev/null +++ b/src/mistral_inference/mamba.py @@ -0,0 +1,83 @@ +import json +from pathlib import Path +from typing import List, Optional, Union + +import safetensors +import torch +import torch.nn as nn + +from mistral_inference.args import MambaArgs +from mistral_inference.cache import BufferCache +from mistral_inference.model import ModelBase + +_is_mamba_installed = False +try: + from mamba_ssm.models.config_mamba import MambaConfig + from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel + + _is_mamba_installed = True +except ImportError: + _is_mamba_installed = False + + +class Mamba(ModelBase, nn.Module): + def __init__(self, args: MambaArgs): + super().__init__() + self.args = args + assert _is_mamba_installed, "Mamba is not installed. Please install it using `pip install mamba-ssm`." 
+ + # make sure naming is consistent with `mamba_ssm` + config = MambaConfig( + d_model=args.dim, + n_layer=args.n_layers, + vocab_size=args.vocab_size, + ssm_cfg={"ngroups": args.n_groups, "layer": "Mamba2"}, + attn_layer_idx=[], + attn_cfg={}, + rms_norm=args.rms_norm, + residual_in_fp32=args.residual_in_fp32, + fused_add_norm=args.fused_add_norm, + pad_vocab_size_multiple=args.pad_vocab_size_multiple, + tie_embeddings=args.tie_embeddings, + ) + self.model = MambaLMHeadModel(config) + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + def forward( + self, + input_ids: torch.Tensor, + seqlens: List[int], # not supported for now + cache: Optional[BufferCache] = None, # not supported for now + ) -> torch.Tensor: + lm_output = self.model(input_ids) + result: torch.Tensor = lm_output.logits + return result + + @staticmethod + def from_folder( + folder: Union[Path, str], + max_batch_size: int = 1, + num_pipeline_ranks: int = 1, + device: Union[torch.device, str] = "cuda", + dtype: Optional[torch.dtype] = None, + ) -> "Mamba": + with open(Path(folder) / "params.json", "r") as f: + model_args = MambaArgs.from_dict(json.load(f)) + + with torch.device("meta"): + model = Mamba(model_args) + + model_file = Path(folder) / "consolidated.safetensors" + + assert model_file.exists(), f"Make sure {model_file} exists." + loaded = safetensors.torch.load_file(str(model_file)) + + model.load_state_dict(loaded, assign=True, strict=True) + return model.to(device=device, dtype=dtype) diff --git a/src/mistral_inference/model.py b/src/mistral_inference/model.py index f5ebd7a..be41fcd 100644 --- a/src/mistral_inference/model.py +++ b/src/mistral_inference/model.py @@ -1,409 +1,43 @@ -import json -import logging -import math -from dataclasses import dataclass -from functools import partial +from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, List, Mapping, Optional, Tuple, Union +from typing import List, Optional, Union -import safetensors.torch import torch -from simple_parsing.helpers import Serializable -from torch import nn -from xformers.ops.fmha import memory_efficient_attention # type: ignore +import torch.nn as nn -from mistral_inference.cache import ( - BufferCache, - CacheInputMetadata, - CacheView, -) -from mistral_inference.lora import LoraArgs, LoRALinear, LoRALoaderMixin -from mistral_inference.moe import MoeArgs, MoeLayer -from mistral_inference.rope import apply_rotary_emb, precompute_freqs_cis +from mistral_inference.cache import BufferCache -@dataclass -class ModelArgs(Serializable): - dim: int - n_layers: int - head_dim: int - hidden_dim: int - n_heads: int - n_kv_heads: int - norm_eps: float - vocab_size: int - - max_batch_size: int = 0 - - # For rotary embeddings. If not set, will be inferred - rope_theta: Optional[float] = None - # If this is set, we will use MoE layers instead of dense layers. - moe: Optional[MoeArgs] = None - # If this is set, we will load LoRA linear layers instead of linear layers. 
- lora: Optional[LoraArgs] = None - - -@dataclass -class SimpleInputMetadata: - # rope absolute positions - positions: torch.Tensor - - @staticmethod - def from_seqlens(seqlens: List[int], device: torch.device) -> "SimpleInputMetadata": - return SimpleInputMetadata( - positions=torch.cat([torch.arange(0, seqlen) for seqlen in seqlens]).to( - device=device, dtype=torch.long - ) - ) - - -def repeat_kv( - keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int -) -> Tuple[torch.Tensor, torch.Tensor]: - keys = torch.repeat_interleave(keys, repeats=repeats, dim=dim) - values = torch.repeat_interleave(values, repeats=repeats, dim=dim) - return keys, values - - -def maybe_lora(args: ModelArgs) -> Union[nn.Linear, LoRALinear]: - if args.lora is None: - return nn.Linear - else: - return partial(LoRALinear, rank=args.lora.rank, scaling=args.lora.scaling) - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - - self.n_heads: int = args.n_heads - self.head_dim: int = args.head_dim - self.n_kv_heads: int = args.n_kv_heads - - self.repeats = self.n_heads // self.n_kv_heads - - self.scale = self.args.head_dim**-0.5 - - MaybeLora = maybe_lora(args) - self.wq = MaybeLora(args.dim, args.n_heads * args.head_dim, bias=False) - self.wk = MaybeLora(args.dim, args.n_kv_heads * args.head_dim, bias=False) - self.wv = MaybeLora(args.dim, args.n_kv_heads * args.head_dim, bias=False) - self.wo = MaybeLora(args.n_heads * args.head_dim, args.dim, bias=False) - - def forward( - self, - x: torch.Tensor, - freqs_cis: torch.Tensor, - cache: Optional[CacheView], - ) -> torch.Tensor: - seqlen_sum, _ = x.shape - - xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) - xq = xq.view(seqlen_sum, self.n_heads, self.head_dim) - xk = xk.view(seqlen_sum, self.n_kv_heads, self.head_dim) - xv = xv.view(seqlen_sum, self.n_kv_heads, self.head_dim) - xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) - - if cache is None: - key, val = xk, xv - elif cache.prefill: - key, val = cache.interleave_kv(xk, xv) - cache.update(xk, xv) - else: - cache.update(xk, xv) - key, val = cache.key, cache.value - key = key.view( - seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim - ) - val = val.view( - seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim - ) - - # Repeat keys and values to match number of query heads - key, val = repeat_kv(key, val, self.repeats, dim=1) - - # xformers requires (B=1, S, H, D) - xq, key, val = xq[None, ...], key[None, ...], val[None, ...] 
- output = memory_efficient_attention( - xq, key, val, None if cache is None else cache.mask - ) - output = output.view(seqlen_sum, self.n_heads * self.head_dim) - - assert isinstance(output, torch.Tensor) - - return self.wo(output) # type: ignore - - -class FeedForward(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - MaybeLora = maybe_lora(args) - self.w1 = MaybeLora(args.dim, args.hidden_dim, bias=False) - self.w2 = MaybeLora(args.hidden_dim, args.dim, bias=False) - self.w3 = MaybeLora(args.dim, args.hidden_dim, bias=False) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) # type: ignore - - -class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): +class ModelBase(nn.Module, ABC): + def __init__(self) -> None: super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x: torch.Tensor) -> torch.Tensor: - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - output = self._norm(x.float()).type_as(x) - return output * self.weight - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.n_heads = args.n_heads - self.dim = args.dim - self.attention = Attention(args) - self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) - self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) - self.args = args - - self.feed_forward: nn.Module - if args.moe is not None: - self.feed_forward = MoeLayer( - experts=[FeedForward(args=args) for _ in range(args.moe.num_experts)], - gate=nn.Linear(args.dim, args.moe.num_experts, bias=False), - moe_args=args.moe, - ) - else: - self.feed_forward = FeedForward(args=args) - - def forward( - self, x: torch.Tensor, freqs_cis: torch.Tensor, cache: Optional[CacheView] - ) -> torch.Tensor: - r = self.attention.forward(self.attention_norm(x), freqs_cis, cache) - h = x + r - r = self.feed_forward.forward(self.ffn_norm(h)) - out = h + r - return out - - -class Transformer(nn.Module, LoRALoaderMixin): - def __init__( - self, - args: ModelArgs, - pipeline_rank: int = 0, - num_pipeline_ranks: int = 1, - ): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - self.n_layers = args.n_layers - self._precomputed_freqs_cis: Optional[torch.Tensor] = None - assert self.vocab_size > 0 - assert pipeline_rank < num_pipeline_ranks, (pipeline_rank, num_pipeline_ranks) - self.pipeline_rank = pipeline_rank - self.num_pipeline_ranks = num_pipeline_ranks - # Modules specific to some ranks: - self.tok_embeddings: Optional[nn.Embedding] = None - self.norm: Optional[RMSNorm] = None - self.output: Optional[nn.Linear] = None - if pipeline_rank == 0: - self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) - if pipeline_rank == num_pipeline_ranks - 1: - self.norm = RMSNorm(args.dim, eps=args.norm_eps) - self.output = nn.Linear(args.dim, args.vocab_size, bias=False) - # Initialize all layers but slice off those not of this rank. 
- layers = [TransformerBlock(args=args) for _ in range(args.n_layers)] - num_layers_per_rank = math.ceil(self.n_layers / self.num_pipeline_ranks) - offset = self.pipeline_rank * num_layers_per_rank - end = min(self.n_layers, offset + num_layers_per_rank) - self.layers = nn.ModuleDict({str(i): layers[i] for i in range(offset, end)}) - self.n_local_layers = len(self.layers) @property + @abstractmethod def dtype(self) -> torch.dtype: - return next(self.parameters()).dtype + pass @property + @abstractmethod def device(self) -> torch.device: - return next(self.parameters()).device - - @property - def freqs_cis(self) -> torch.Tensor: - # We cache freqs_cis but need to take care that it is on the right device - # and has the right dtype (complex64). The fact that the dtype is different - # from the module's dtype means we cannot register it as a buffer - if self._precomputed_freqs_cis is None: - # default to 10**6 - theta = self.args.rope_theta or 1000000.0 - self._precomputed_freqs_cis = precompute_freqs_cis( - self.args.head_dim, 128_000, theta - ) - - if self._precomputed_freqs_cis.device != self.device: - self._precomputed_freqs_cis = self._precomputed_freqs_cis.to( - device=self.device - ) - return self._precomputed_freqs_cis - - def forward_partial( - self, - input_ids: torch.Tensor, - seqlens: List[int], - cache: Optional[BufferCache] = None, - ) -> torch.Tensor: - """Local forward pass. - - If doing pipeline parallelism, this will return the activations of the last layer of this stage. - For the last stage, this will return the normalized final embeddings. - """ - assert ( - len(seqlens) <= self.args.max_batch_size - ), f"Max batch size is {self.args.max_batch_size}, got batch size of {len(seqlens)}" - (num_toks,) = input_ids.shape - assert sum(seqlens) == num_toks, (sum(seqlens), num_toks) - - input_metadata: Union[CacheInputMetadata, SimpleInputMetadata] - - if cache is not None: - input_metadata = cache.get_input_metadata(seqlens) - else: - input_metadata = SimpleInputMetadata.from_seqlens(seqlens, self.device) - - if self.pipeline_rank == 0: - assert self.tok_embeddings is not None - h = self.tok_embeddings(input_ids) - else: - h = torch.empty( - num_toks, self.args.dim, device=self.device, dtype=self.dtype - ) - torch.distributed.recv(h, src=self.pipeline_rank - 1) - - freqs_cis = self.freqs_cis[input_metadata.positions] - - for local_layer_id, layer in enumerate(self.layers.values()): - if cache is not None: - assert input_metadata is not None - assert isinstance(input_metadata, CacheInputMetadata) - cache_view = cache.get_view(local_layer_id, input_metadata) - else: - cache_view = None - h = layer(h, freqs_cis, cache_view) - - if cache is not None: - cache.update_seqlens(seqlens) - if self.pipeline_rank < self.num_pipeline_ranks - 1: - torch.distributed.send(h, dst=self.pipeline_rank + 1) - return h # type: ignore - else: - # Last rank has a final normalization step. 
- assert self.norm is not None - return self.norm(h) # type: ignore + pass + @abstractmethod def forward( self, input_ids: torch.Tensor, - seqlens: List[int], - cache: Optional[BufferCache] = None, + seqlens: List[int], # not supported for now + cache: Optional[BufferCache] = None, # not supported for now ) -> torch.Tensor: - h = self.forward_partial(input_ids, seqlens, cache=cache) - if self.pipeline_rank < self.num_pipeline_ranks - 1: - # ignore the intermediate activations as we'll get the final output from - # the last stage - outs = torch.empty( - h.shape[0], self.vocab_size, device=h.device, dtype=h.dtype - ) - else: - assert self.output is not None - outs = self.output(h) - if self.num_pipeline_ranks > 1: - torch.distributed.broadcast(outs, src=self.num_pipeline_ranks - 1) - return outs.float() - - def load_state_dict( - self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False - ) -> None: - state_to_load = {} - skipped = set([]) - for k, v in state_dict.items(): - if k.startswith("tok_embeddings"): - if self.pipeline_rank == 0: - state_to_load[k] = v - else: - logging.debug( - "Skipping parameter %s at pipeline rank %d", - k, - self.pipeline_rank, - ) - skipped.add(k) - elif k.startswith("norm") or k.startswith("output"): - if self.pipeline_rank == self.num_pipeline_ranks - 1: - state_to_load[k] = v - else: - logging.debug( - "Skipping parameter %s at pipeline rank %d", - k, - self.pipeline_rank, - ) - skipped.add(k) - elif k.startswith("layers"): - layer_id = k.split(".")[1] - if layer_id in self.layers: - state_to_load[k] = v - else: - logging.debug( - "Skipping parameter %s at pipeline rank %d", - k, - self.pipeline_rank, - ) - skipped.add(k) - else: - raise ValueError(f"Unexpected key {k}") - assert set(state_dict.keys()) == skipped.union(set(state_to_load.keys())) - super().load_state_dict(state_to_load, strict=strict, assign=assign) + pass @staticmethod + @abstractmethod def from_folder( folder: Union[Path, str], max_batch_size: int = 1, num_pipeline_ranks: int = 1, device: Union[torch.device, str] = "cuda", dtype: Optional[torch.dtype] = None, - ) -> "Transformer": - with open(Path(folder) / "params.json", "r") as f: - model_args = ModelArgs.from_dict(json.load(f)) - model_args.max_batch_size = max_batch_size - if num_pipeline_ranks > 1: - pipeline_rank = torch.distributed.get_rank() - else: - pipeline_rank = 0 - with torch.device("meta"): - model = Transformer( - model_args, - pipeline_rank=pipeline_rank, - num_pipeline_ranks=num_pipeline_ranks, - ) - - pt_model_file = Path(folder) / "consolidated.00.pth" - safetensors_model_file = Path(folder) / "consolidated.safetensors" - - assert ( - pt_model_file.exists() or safetensors_model_file.exists() - ), f"Make sure either {pt_model_file} or {safetensors_model_file} exists" - assert not ( - pt_model_file.exists() and safetensors_model_file.exists() - ), f"Both {pt_model_file} and {safetensors_model_file} cannot exist" - - if pt_model_file.exists(): - loaded = torch.load(str(pt_model_file), mmap=True) - else: - loaded = safetensors.torch.load_file(str(safetensors_model_file)) - - model.load_state_dict(loaded, assign=True, strict=True) - - return model.to(device=device, dtype=dtype) + ) -> "ModelBase": + pass diff --git a/src/mistral_inference/moe.py b/src/mistral_inference/moe.py index 043776b..9ce8a8a 100644 --- a/src/mistral_inference/moe.py +++ b/src/mistral_inference/moe.py @@ -23,14 +23,10 @@ def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs) def forward(self, inputs: 
torch.Tensor) -> torch.Tensor: gate_logits = self.gate(inputs) - weights, selected_experts = torch.topk( - gate_logits, self.args.num_experts_per_tok - ) + weights, selected_experts = torch.topk(gate_logits, self.args.num_experts_per_tok) weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype) results = torch.zeros_like(inputs) for i, expert in enumerate(self.experts): batch_idx, nth_expert = torch.where(selected_experts == i) - results[batch_idx] += weights[batch_idx, nth_expert, None] * expert( - inputs[batch_idx] - ) + results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(inputs[batch_idx]) return results diff --git a/src/mistral_inference/transformer.py b/src/mistral_inference/transformer.py new file mode 100644 index 0000000..5754684 --- /dev/null +++ b/src/mistral_inference/transformer.py @@ -0,0 +1,367 @@ +import json +import logging +import math +from dataclasses import dataclass +from functools import partial +from pathlib import Path +from typing import Any, List, Mapping, Optional, Tuple, Type, Union + +import safetensors.torch +import torch +from torch import nn +from xformers.ops.fmha import memory_efficient_attention # type: ignore + +from mistral_inference.args import TransformerArgs +from mistral_inference.cache import ( + BufferCache, + CacheInputMetadata, + CacheView, +) +from mistral_inference.lora import LoRALinear, LoRALoaderMixin +from mistral_inference.model import ModelBase +from mistral_inference.moe import MoeLayer +from mistral_inference.rope import apply_rotary_emb, precompute_freqs_cis + + +@dataclass +class SimpleInputMetadata: + # rope absolute positions + positions: torch.Tensor + + @staticmethod + def from_seqlens(seqlens: List[int], device: torch.device) -> "SimpleInputMetadata": + return SimpleInputMetadata( + positions=torch.cat([torch.arange(0, seqlen) for seqlen in seqlens]).to(device=device, dtype=torch.long) + ) + + +def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]: + keys = torch.repeat_interleave(keys, repeats=repeats, dim=dim) + values = torch.repeat_interleave(values, repeats=repeats, dim=dim) + return keys, values + + +def maybe_lora(args: TransformerArgs) -> Union[Type[nn.Linear], partial[LoRALinear]]: + if args.lora is None: + return nn.Linear + else: + return partial(LoRALinear, rank=args.lora.rank, scaling=args.lora.scaling) + + +class Attention(nn.Module): + def __init__(self, args: TransformerArgs): + super().__init__() + self.args = args + + self.n_heads: int = args.n_heads + self.head_dim: int = args.head_dim + self.n_kv_heads: int = args.n_kv_heads + + self.repeats = self.n_heads // self.n_kv_heads + + self.scale = self.args.head_dim**-0.5 + + MaybeLora = maybe_lora(args) + self.wq = MaybeLora(args.dim, args.n_heads * args.head_dim, bias=False) + self.wk = MaybeLora(args.dim, args.n_kv_heads * args.head_dim, bias=False) + self.wv = MaybeLora(args.dim, args.n_kv_heads * args.head_dim, bias=False) + self.wo = MaybeLora(args.n_heads * args.head_dim, args.dim, bias=False) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + cache: Optional[CacheView], + ) -> torch.Tensor: + seqlen_sum, _ = x.shape + + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + xq = xq.view(seqlen_sum, self.n_heads, self.head_dim) + xk = xk.view(seqlen_sum, self.n_kv_heads, self.head_dim) + xv = xv.view(seqlen_sum, self.n_kv_heads, self.head_dim) + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + if cache is None: + key, val = xk, xv + elif 
cache.prefill: + key, val = cache.interleave_kv(xk, xv) + cache.update(xk, xv) + else: + cache.update(xk, xv) + key, val = cache.key, cache.value + key = key.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim) + val = val.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim) + + # Repeat keys and values to match number of query heads + key, val = repeat_kv(key, val, self.repeats, dim=1) + + # xformers requires (B=1, S, H, D) + xq, key, val = xq[None, ...], key[None, ...], val[None, ...] + output = memory_efficient_attention(xq, key, val, None if cache is None else cache.mask) + output = output.view(seqlen_sum, self.n_heads * self.head_dim) + + assert isinstance(output, torch.Tensor) + + return self.wo(output) # type: ignore + + +class FeedForward(nn.Module): + def __init__(self, args: TransformerArgs): + super().__init__() + + MaybeLora = maybe_lora(args) + self.w1 = MaybeLora(args.dim, args.hidden_dim, bias=False) + self.w2 = MaybeLora(args.hidden_dim, args.dim, bias=False) + self.w3 = MaybeLora(args.dim, args.hidden_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) # type: ignore + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x: torch.Tensor) -> torch.Tensor: + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +class TransformerBlock(nn.Module): + def __init__(self, args: TransformerArgs): + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.attention = Attention(args) + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.args = args + + self.feed_forward: nn.Module + if args.moe is not None: + self.feed_forward = MoeLayer( + experts=[FeedForward(args=args) for _ in range(args.moe.num_experts)], + gate=nn.Linear(args.dim, args.moe.num_experts, bias=False), + moe_args=args.moe, + ) + else: + self.feed_forward = FeedForward(args=args) + + def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor, cache: Optional[CacheView]) -> torch.Tensor: + r = self.attention.forward(self.attention_norm(x), freqs_cis, cache) + h = x + r + r = self.feed_forward.forward(self.ffn_norm(h)) + out = h + r + return out + + +class Transformer(ModelBase, LoRALoaderMixin): + def __init__( + self, + args: TransformerArgs, + pipeline_rank: int = 0, + num_pipeline_ranks: int = 1, + ): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.n_layers = args.n_layers + self._precomputed_freqs_cis: Optional[torch.Tensor] = None + assert self.vocab_size > 0 + assert pipeline_rank < num_pipeline_ranks, (pipeline_rank, num_pipeline_ranks) + self.pipeline_rank = pipeline_rank + self.num_pipeline_ranks = num_pipeline_ranks + # Modules specific to some ranks: + self.tok_embeddings: Optional[nn.Embedding] = None + self.norm: Optional[RMSNorm] = None + self.output: Optional[nn.Linear] = None + if pipeline_rank == 0: + self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) + if pipeline_rank == num_pipeline_ranks - 1: + self.norm = RMSNorm(args.dim, eps=args.norm_eps) + self.output = nn.Linear(args.dim, args.vocab_size, bias=False) + # Initialize all layers but slice off those not of this rank. 
+ layers = [TransformerBlock(args=args) for _ in range(args.n_layers)] + num_layers_per_rank = math.ceil(self.n_layers / self.num_pipeline_ranks) + offset = self.pipeline_rank * num_layers_per_rank + end = min(self.n_layers, offset + num_layers_per_rank) + self.layers = nn.ModuleDict({str(i): layers[i] for i in range(offset, end)}) + self.n_local_layers = len(self.layers) + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def freqs_cis(self) -> torch.Tensor: + # We cache freqs_cis but need to take care that it is on the right device + # and has the right dtype (complex64). The fact that the dtype is different + # from the module's dtype means we cannot register it as a buffer + if self._precomputed_freqs_cis is None: + # default to 10**6 + theta = self.args.rope_theta or 1000000.0 + self._precomputed_freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000, theta) + + if self._precomputed_freqs_cis.device != self.device: + self._precomputed_freqs_cis = self._precomputed_freqs_cis.to(device=self.device) + return self._precomputed_freqs_cis + + def forward_partial( + self, + input_ids: torch.Tensor, + seqlens: List[int], + cache: Optional[BufferCache] = None, + ) -> torch.Tensor: + """Local forward pass. + + If doing pipeline parallelism, this will return the activations of the last layer of this stage. + For the last stage, this will return the normalized final embeddings. + """ + assert ( + len(seqlens) <= self.args.max_batch_size + ), f"Max batch size is {self.args.max_batch_size}, got batch size of {len(seqlens)}" + (num_toks,) = input_ids.shape + assert sum(seqlens) == num_toks, (sum(seqlens), num_toks) + + input_metadata: Union[CacheInputMetadata, SimpleInputMetadata] + + if cache is not None: + input_metadata = cache.get_input_metadata(seqlens) + else: + input_metadata = SimpleInputMetadata.from_seqlens(seqlens, self.device) + + if self.pipeline_rank == 0: + assert self.tok_embeddings is not None + h = self.tok_embeddings(input_ids) + else: + h = torch.empty(num_toks, self.args.dim, device=self.device, dtype=self.dtype) + torch.distributed.recv(h, src=self.pipeline_rank - 1) + + freqs_cis = self.freqs_cis[input_metadata.positions] + + for local_layer_id, layer in enumerate(self.layers.values()): + if cache is not None: + assert input_metadata is not None + assert isinstance(input_metadata, CacheInputMetadata) + cache_view = cache.get_view(local_layer_id, input_metadata) + else: + cache_view = None + h = layer(h, freqs_cis, cache_view) + + if cache is not None: + cache.update_seqlens(seqlens) + if self.pipeline_rank < self.num_pipeline_ranks - 1: + torch.distributed.send(h, dst=self.pipeline_rank + 1) + return h # type: ignore + else: + # Last rank has a final normalization step. 
+            assert self.norm is not None
+            return self.norm(h)  # type: ignore
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        seqlens: List[int],
+        cache: Optional[BufferCache] = None,
+    ) -> torch.Tensor:
+        h = self.forward_partial(input_ids, seqlens, cache=cache)
+        if self.pipeline_rank < self.num_pipeline_ranks - 1:
+            # ignore the intermediate activations as we'll get the final output from
+            # the last stage
+            outs = torch.empty(h.shape[0], self.vocab_size, device=h.device, dtype=h.dtype)
+        else:
+            assert self.output is not None
+            outs = self.output(h)
+        if self.num_pipeline_ranks > 1:
+            torch.distributed.broadcast(outs, src=self.num_pipeline_ranks - 1)
+        return outs.float()
+
+    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False) -> None:
+        state_to_load = {}
+        skipped = set([])
+        for k, v in state_dict.items():
+            if k.startswith("tok_embeddings"):
+                if self.pipeline_rank == 0:
+                    state_to_load[k] = v
+                else:
+                    logging.debug(
+                        "Skipping parameter %s at pipeline rank %d",
+                        k,
+                        self.pipeline_rank,
+                    )
+                    skipped.add(k)
+            elif k.startswith("norm") or k.startswith("output"):
+                if self.pipeline_rank == self.num_pipeline_ranks - 1:
+                    state_to_load[k] = v
+                else:
+                    logging.debug(
+                        "Skipping parameter %s at pipeline rank %d",
+                        k,
+                        self.pipeline_rank,
+                    )
+                    skipped.add(k)
+            elif k.startswith("layers"):
+                layer_id = k.split(".")[1]
+                if layer_id in self.layers:
+                    state_to_load[k] = v
+                else:
+                    logging.debug(
+                        "Skipping parameter %s at pipeline rank %d",
+                        k,
+                        self.pipeline_rank,
+                    )
+                    skipped.add(k)
+            else:
+                raise ValueError(f"Unexpected key {k}")
+        assert set(state_dict.keys()) == skipped.union(set(state_to_load.keys()))
+        super().load_state_dict(state_to_load, strict=strict, assign=assign)
+
+    @staticmethod
+    def from_folder(
+        folder: Union[Path, str],
+        max_batch_size: int = 1,
+        num_pipeline_ranks: int = 1,
+        device: Union[torch.device, str] = "cuda",
+        dtype: Optional[torch.dtype] = None,
+    ) -> "Transformer":
+        with open(Path(folder) / "params.json", "r") as f:
+            model_args = TransformerArgs.from_dict(json.load(f))
+        model_args.max_batch_size = max_batch_size
+        if num_pipeline_ranks > 1:
+            pipeline_rank = torch.distributed.get_rank()
+        else:
+            pipeline_rank = 0
+        with torch.device("meta"):
+            model = Transformer(
+                model_args,
+                pipeline_rank=pipeline_rank,
+                num_pipeline_ranks=num_pipeline_ranks,
+            )
+
+        pt_model_file = Path(folder) / "consolidated.00.pth"
+        safetensors_model_file = Path(folder) / "consolidated.safetensors"
+
+        assert (
+            pt_model_file.exists() or safetensors_model_file.exists()
+        ), f"Make sure either {pt_model_file} or {safetensors_model_file} exists"
+        assert not (
+            pt_model_file.exists() and safetensors_model_file.exists()
+        ), f"Both {pt_model_file} and {safetensors_model_file} cannot exist"
+
+        if pt_model_file.exists():
+            loaded = torch.load(str(pt_model_file), mmap=True)
+        else:
+            loaded = safetensors.torch.load_file(str(safetensors_model_file))
+
+        model.load_state_dict(loaded, assign=True, strict=True)
+
+        return model.to(device=device, dtype=dtype)
diff --git a/tests/test_generate.py b/tests/test_generate.py
index bf9f117..00bfad4 100644
--- a/tests/test_generate.py
+++ b/tests/test_generate.py
@@ -1,8 +1,10 @@
 from typing import List
 
 import torch
+from mistral_inference.generate import generate_mamba
 from mistral_inference.main import generate
-from mistral_inference.model import ModelArgs, Transformer
+from mistral_inference.mamba import Mamba, MambaArgs
+from mistral_inference.transformer import Transformer, TransformerArgs
 
 
 class DebugTokenizer:
@@ -29,11 +31,11 @@ def decode(self, t: List[int]) -> str:
         return " ".join([str(x) for x in t])
 
 
-def test_generation():
+def test_generation_transformer():
     torch.manual_seed(42)
 
     sequences = ["1 2 3 4 5 6 7", "0 1 2", "12 13 14", "2 4 34"]
-    args = ModelArgs(
+    args = TransformerArgs(
         dim=512,
         n_layers=1,
         head_dim=128,
@@ -53,30 +55,51 @@
     # concat generated and prompt
     encoded = [e + t for e, t in zip(encoded, toks)]
 
-    generated, all_logprobs_new = generate(
-        encoded, model, temperature=0.0, max_tokens=0
-    )
+    generated, all_logprobs_new = generate(encoded, model, temperature=0.0, max_tokens=0)
 
     assert generated == []
 
     # Verify that logprobs are the same
     assert len(sequences) == len(all_logprobs_old) == len(all_logprobs_new)
     for lp_old, lp_new in zip(all_logprobs_old, all_logprobs_new):
-        assert all(
-            [abs(x - y) < 1e-5 for x, y in zip(lp_old, lp_new)]
-        ), f"\n{lp_old}\n{lp_new}"
+        assert all([abs(x - y) < 5e-4 for x, y in zip(lp_old, lp_new)]), f"\n{lp_old}\n{lp_new}"
 
     print("All tests passed.")
 
 
-def test_chunks():
+def test_generation_mamba():
+    torch.manual_seed(42)
+
+    sequences = ["1 2 3 4 5 6 7"]
+    args = MambaArgs(
+        dim=512,
+        n_layers=1,
+        n_groups=1,
+        rms_norm=True,
+        residual_in_fp32=True,
+        fused_add_norm=True,
+        pad_vocab_size_multiple=1,
+        tie_embeddings=False,
+        vocab_size=32768,
+    )
+    model = Mamba(args).to("cuda", dtype=torch.float32)
+    tokenizer = DebugTokenizer()
+
+    encoded = [tokenizer.encode(s, bos=True) for s in sequences]
+    toks, all_logprobs_old = generate_mamba(encoded, model, temperature=0.0, max_tokens=7)
+
+    assert len(toks[0]) == 7
+    assert toks == [[25574, 14821, 11843, 23698, 12735, 23522, 27542]]
+
+
+def test_chunks_transformer():
     torch.manual_seed(42)
 
     sequences = [
         " ".join([str(i) for i in range(7)]),
         " ".join([str(i) for i in range(9, 0, -1)]),
     ]
-    args = ModelArgs(
+    args = TransformerArgs(
         dim=512,
         n_layers=1,
         head_dim=128,
@@ -96,17 +119,8 @@
     # concat generated and prompt
     encoded = [e + t for e, t in zip(encoded, toks)]
 
-    generated, all_logprobs_new = generate(
-        encoded, model, temperature=0.0, max_tokens=0, chunk_size=5
-    )
+    generated, all_logprobs_new = generate(encoded, model, temperature=0.0, max_tokens=0, chunk_size=5)
 
     assert len(generated) == 0
     for lp_old, lp_new in zip(all_logprobs_old, all_logprobs_new):
-        assert all(
-            [abs(x - y) < 1e-5 for x, y in zip(lp_old, lp_new)]
-        ), f"\n{lp_old}\n{lp_new}"
-
-
-if __name__ == "__main__":
-    test_generation()
-    test_chunks()
+        assert all([abs(x - y) < 5e-4 for x, y in zip(lp_old, lp_new)]), f"\n{lp_old}\n{lp_new}"
diff --git a/tutorials/getting_started.ipynb b/tutorials/getting_started.ipynb
index a3e70b1..747b807 100644
--- a/tutorials/getting_started.ipynb
+++ b/tutorials/getting_started.ipynb
@@ -82,7 +82,7 @@
    "source": [
     "import os \n",
     "\n",
-    "from mistral_inference.model import Transformer\n",
+    "from mistral_inference.transformer import Transformer\n",
     "from mistral_inference.generate import generate\n",
     "\n",
     "from mistral_common.tokens.tokenizers.mistral import MistralTokenizer\n",
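
Note for reviewers: the snippet below is an editorial smoke-test sketch, not part of the patch. It only strings together calls that the diff itself exercises in `tests/test_generate.py` (`MambaArgs`, `Mamba`, `generate_mamba`); the random initialization, the toy prompt ids, and the CUDA/float32 placement are illustrative assumptions, and a real run would load trained weights and encode prompts with `MistralTokenizer` instead.

```py
import torch

from mistral_inference.generate import generate_mamba
from mistral_inference.mamba import Mamba, MambaArgs

# Same configuration the new test uses; weights are random, so outputs are arbitrary.
args = MambaArgs(
    dim=512,
    n_layers=1,
    n_groups=1,
    rms_norm=True,
    residual_in_fp32=True,
    fused_add_norm=True,
    pad_vocab_size_multiple=1,
    tie_embeddings=False,
    vocab_size=32768,
)
model = Mamba(args).to("cuda", dtype=torch.float32)  # placed on GPU, as in the test

# Toy prompt: any ids < vocab_size work for a smoke test (placeholder, not a real tokenization).
prompt_ids = [1, 2, 3, 4, 5, 6, 7]

tokens, logprobs = generate_mamba([prompt_ids], model, temperature=0.0, max_tokens=7)
print(tokens[0])  # seven generated token ids
```

For existing Transformer users, the only required change is the import path: `from mistral_inference.transformer import Transformer` replaces `from mistral_inference.model import Transformer`, as the notebook hunk above shows.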