Skip to content

Commit

Permalink
add mamba
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickvonplaten committed Jul 16, 2024
1 parent e3a64e4 commit 26a52a1
Show file tree
Hide file tree
Showing 14 changed files with 720 additions and 614 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ You can continue chatting afterwards, *e.g.* with *"Translate it to Python"*.
- *Instruction Following*:

```py
from mistral_inference.model import Transformer
from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate

from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
Expand Down Expand Up @@ -228,7 +228,7 @@ pip install --upgrade mistral-common
You can simulate code-completion in-filling as follows.

```py
from mistral_inference.model import Transformer
from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.instruct.request import FIMRequest
Expand Down
64 changes: 19 additions & 45 deletions moe_one_file_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class MoeArgs(Serializable):


@dataclass
class ModelArgs(Serializable):
class TransformerArgs(Serializable):
dim: int
n_layers: int
head_dim: int
Expand Down Expand Up @@ -80,7 +80,7 @@ def apply_rotary_emb(


class Attention(nn.Module):
def __init__(self, args: ModelArgs):
def __init__(self, args: TransformerArgs):
super().__init__()
self.args = args

Expand Down Expand Up @@ -144,9 +144,7 @@ def forward(
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

# Update cache
scatter_pos = positions[None, :, None, None].repeat(
bsz, 1, self.n_kv_heads, self.args.head_dim
)
scatter_pos = positions[None, :, None, None].repeat(bsz, 1, self.n_kv_heads, self.args.head_dim)
cache_k[:bsz].scatter_(dim=1, index=scatter_pos, src=xk)
cache_v[:bsz].scatter_(dim=1, index=scatter_pos, src=xv)

Expand Down Expand Up @@ -179,7 +177,7 @@ def forward(


class FeedForward(nn.Module):
def __init__(self, args: ModelArgs):
def __init__(self, args: TransformerArgs):
super().__init__()
self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
Expand Down Expand Up @@ -214,9 +212,7 @@ def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs)
def forward(self, inputs: torch.Tensor):
inputs_squashed = inputs.view(-1, inputs.shape[-1])
gate_logits = self.gate(inputs_squashed)
weights, selected_experts = torch.topk(
gate_logits, self.args.num_experts_per_tok
)
weights, selected_experts = torch.topk(gate_logits, self.args.num_experts_per_tok)
weights = nn.functional.softmax(
weights,
dim=1,
Expand All @@ -225,14 +221,12 @@ def forward(self, inputs: torch.Tensor):
results = torch.zeros_like(inputs_squashed)
for i, expert in enumerate(self.experts):
batch_idx, nth_expert = torch.where(selected_experts == i)
results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(
inputs_squashed[batch_idx]
)
results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(inputs_squashed[batch_idx])
return results.view_as(inputs)


class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
def __init__(self, args: TransformerArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
Expand Down Expand Up @@ -270,7 +264,7 @@ def precompute_freqs_cis(dim: int, end: int, theta: float) -> torch.Tensor:
class Transformer(nn.Module):
def __init__(
self,
args: ModelArgs,
args: TransformerArgs,
pipeline_rank: int = 0,
num_pipeline_ranks: int = 1,
):
Expand Down Expand Up @@ -316,13 +310,9 @@ def freqs_cis(self) -> torch.Tensor:
# from the module's dtype means we cannot register it as a buffer
if self._precomputed_freqs_cis is None:
theta = self.args.rope_theta or 1000000.0
self._precomputed_freqs_cis = precompute_freqs_cis(
self.args.head_dim, 128_000, theta
)
self._precomputed_freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000, theta)
if self._precomputed_freqs_cis.device != self.device:
self._precomputed_freqs_cis = self._precomputed_freqs_cis.to(
device=self.device
)
self._precomputed_freqs_cis = self._precomputed_freqs_cis.to(device=self.device)
return self._precomputed_freqs_cis

def forward(
Expand All @@ -341,9 +331,7 @@ def forward(
assert h.shape == (bsz, seqlen, self.args.dim)
assert h.dtype == self.dtype
else:
h = torch.empty(
bsz, seqlen, self.args.dim, device=self.device, dtype=self.dtype
)
h = torch.empty(bsz, seqlen, self.args.dim, device=self.device, dtype=self.dtype)
torch.distributed.recv(h, src=self.pipeline_rank - 1)

mask: Optional[torch.Tensor] = None
Expand All @@ -361,9 +349,7 @@ def forward(

if self.pipeline_rank < self.num_pipeline_ranks - 1:
torch.distributed.send(h, dst=self.pipeline_rank + 1)
outs = torch.empty(
*h.shape[:-1], self.vocab_size, device=h.device, dtype=h.dtype
)
outs = torch.empty(*h.shape[:-1], self.vocab_size, device=h.device, dtype=h.dtype)
else:
assert self.output is not None
assert self.norm is not None
Expand Down Expand Up @@ -422,7 +408,7 @@ def from_folder(
dtype=torch.float16,
) -> "Transformer":
with open(folder / "params.json", "r") as f:
model_args = ModelArgs.from_dict(json.load(f))
model_args = TransformerArgs.from_dict(json.load(f))
model_args.max_batch_size = max_batch_size
model_args.max_seq_len = max_seq_len
if num_pipeline_ranks > 1:
Expand Down Expand Up @@ -457,9 +443,7 @@ def from_folder(


def load_tokenizer(model_path: Path) -> MistralTokenizer:
tokenizer = [
f for f in os.listdir(Path(model_path)) if f.startswith("tokenizer.model")
]
tokenizer = [f for f in os.listdir(Path(model_path)) if f.startswith("tokenizer.model")]
assert (
len(tokenizer) == 1
), f"Multiple tokenizers {', '.join(tokenizer)} found in `model_path`, make sure to only have one tokenizer"
Expand All @@ -470,12 +454,8 @@ def load_tokenizer(model_path: Path) -> MistralTokenizer:


@torch.no_grad()
def generate(
prompts: List[str], model: Transformer, tokenizer: Tokenizer, max_tokens: int
):
encoded_prompts = [
tokenizer.encode(prompt, bos=True, eos=False) for prompt in prompts
]
def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, max_tokens: int):
encoded_prompts = [tokenizer.encode(prompt, bos=True, eos=False) for prompt in prompts]
prompt_lens = [len(x) for x in encoded_prompts]
min_prompt_len = min(prompt_lens)
max_prompt_len = max(prompt_lens)
Expand All @@ -498,23 +478,17 @@ def generate(
# decode
generated = []
all_logprobs = [
logprobs[:, :-1, :]
.gather(2, input_tokens[:, 1:min_prompt_len, None])
.squeeze(-1),
logprobs[:, :-1, :].gather(2, input_tokens[:, 1:min_prompt_len, None]).squeeze(-1),
]
for cur_pos in range(min_prompt_len, max_tokens):
next_token = torch.argmax(logprobs[:, -1, :], dim=-1)
if cur_pos < input_mask.shape[1]:
next_token = torch.where(
input_mask[:, cur_pos], input_tokens[:, cur_pos], next_token
)
next_token = torch.where(input_mask[:, cur_pos], input_tokens[:, cur_pos], next_token)
all_logprobs.append(
logprobs[:, -1, :].gather(1, next_token[:, None]),
)
generated.append(next_token[:, None])
logits = model.forward(
next_token[:, None], torch.LongTensor([cur_pos]).to(next_token)
)
logits = model.forward(next_token[:, None], torch.LongTensor([cur_pos]).to(next_token))
logprobs = nn.functional.log_softmax(logits, dim=-1)

all_logprobs_merged = torch.cat(all_logprobs, 1)
Expand Down
56 changes: 17 additions & 39 deletions one_file_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


@dataclass
class ModelArgs(Serializable):
class TransformerArgs(Serializable):
dim: int
n_layers: int
head_dim: int
Expand All @@ -31,9 +31,7 @@ class ModelArgs(Serializable):
max_batch_size: int = 0


def repeat_kv(
keys: torch.Tensor, values: torch.Tensor, repeats: int
) -> Tuple[torch.Tensor]:
def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Repeat every slice of ``keys`` and ``values`` along dimension 2.

    Each entry along dim 2 is duplicated ``repeats`` times in place
    (``[a, b] -> [a, a, b, b]`` for ``repeats=2``), so the size of dim 2 in
    both outputs is ``repeats`` times the input size.

    Args:
        keys: Key tensor with at least three dimensions.
        values: Value tensor with at least three dimensions.
        repeats: Number of consecutive copies of each dim-2 slice.

    Returns:
        A ``(keys, values)`` pair of expanded tensors.
    """
    # NOTE(review): dim 2 is presumably the KV-head axis, expanded so it lines
    # up with the query heads -- confirm against the caller.
    keys = torch.repeat_interleave(keys, repeats=repeats, dim=2)
    values = torch.repeat_interleave(values, repeats=repeats, dim=2)
    return keys, values
Expand Down Expand Up @@ -68,7 +66,7 @@ def apply_rotary_emb(


class Attention(nn.Module):
def __init__(self, args: ModelArgs):
def __init__(self, args: TransformerArgs):
super().__init__()
self.args = args

Expand Down Expand Up @@ -118,9 +116,7 @@ def forward(
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

# cache
scatter_pos = positions[None, :, None, None].repeat(
bsz, 1, self.n_kv_heads, self.args.head_dim
)
scatter_pos = positions[None, :, None, None].repeat(bsz, 1, self.n_kv_heads, self.args.head_dim)
self.cache_k[:bsz].scatter_(dim=1, index=scatter_pos, src=xk)
self.cache_v[:bsz].scatter_(dim=1, index=scatter_pos, src=xv)

Expand Down Expand Up @@ -152,7 +148,7 @@ def forward(


class FeedForward(nn.Module):
def __init__(self, args: ModelArgs):
def __init__(self, args: TransformerArgs):
super().__init__()

self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
Expand All @@ -178,7 +174,7 @@ def forward(self, x):


class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
def __init__(self, args: TransformerArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
Expand Down Expand Up @@ -210,7 +206,7 @@ def precompute_freqs_cis(dim: int, end: int, theta: float) -> torch.Tensor:


class Transformer(nn.Module):
def __init__(self, args: ModelArgs):
def __init__(self, args: TransformerArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
Expand All @@ -219,18 +215,14 @@ def __init__(self, args: ModelArgs):

self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)

self.layers = torch.nn.ModuleList(
[TransformerBlock(args=args) for _ in range(args.n_layers)]
)
self.layers = torch.nn.ModuleList([TransformerBlock(args=args) for _ in range(args.n_layers)])

self.norm = RMSNorm(args.dim, eps=args.norm_eps)

self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

theta = self.args.rope_theta or 1000000.0
self.freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000, theta).to(
"cuda"
)
self.freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000, theta).to("cuda")

def forward(
self,
Expand Down Expand Up @@ -259,11 +251,9 @@ def forward(
return self.output(self.norm(h)).float()

@staticmethod
def from_folder(
folder: Path, max_batch_size: int = 1, device="cuda", dtype=torch.float16
):
def from_folder(folder: Path, max_batch_size: int = 1, device="cuda", dtype=torch.float16):
with open(Path(folder) / "params.json", "r") as f:
model_args = ModelArgs.from_dict(json.load(f))
model_args = TransformerArgs.from_dict(json.load(f))
model_args.max_batch_size = max_batch_size

model = Transformer(model_args)
Expand All @@ -288,9 +278,7 @@ def from_folder(


def load_tokenizer(model_path: Path) -> MistralTokenizer:
tokenizer = [
f for f in os.listdir(Path(model_path)) if f.startswith("tokenizer.model")
]
tokenizer = [f for f in os.listdir(Path(model_path)) if f.startswith("tokenizer.model")]
assert (
len(tokenizer) > 0
), f"No tokenizer found in {model_path}, make sure to place a `tokenizer.model.[v1,v2,v3]` file in {model_path}."
Expand All @@ -304,12 +292,8 @@ def load_tokenizer(model_path: Path) -> MistralTokenizer:


@torch.no_grad()
def generate(
prompts: List[str], model: Transformer, tokenizer: Tokenizer, max_tokens: int
):
encoded_prompts = [
tokenizer.encode(prompt, bos=True, eos=False) for prompt in prompts
]
def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, max_tokens: int):
encoded_prompts = [tokenizer.encode(prompt, bos=True, eos=False) for prompt in prompts]
prompt_lens = [len(x) for x in encoded_prompts]
min_prompt_len = min(prompt_lens)
max_prompt_len = max(prompt_lens)
Expand All @@ -333,24 +317,18 @@ def generate(
# decode
generated = []
all_logprobs = [
logprobs[:, :-1, :]
.gather(2, input_tokens[:, 1:min_prompt_len, None])
.squeeze(-1),
logprobs[:, :-1, :].gather(2, input_tokens[:, 1:min_prompt_len, None]).squeeze(-1),
]
cur_pos = min_prompt_len
for _ in range(max_tokens):
next_token = torch.argmax(logprobs[:, -1, :], dim=-1)
if cur_pos < input_mask.shape[1]:
next_token = torch.where(
input_mask[:, cur_pos], input_tokens[:, cur_pos], next_token
)
next_token = torch.where(input_mask[:, cur_pos], input_tokens[:, cur_pos], next_token)
all_logprobs.append(
logprobs[:, -1, :].gather(1, next_token[:, None]),
)
generated.append(next_token[:, None])
logits = model.forward(
next_token[:, None], torch.LongTensor([cur_pos]).to(next_token)
)
logits = model.forward(next_token[:, None], torch.LongTensor([cur_pos]).to(next_token))
logprobs = nn.functional.log_softmax(logits, dim=-1)
cur_pos += 1

Expand Down
49 changes: 49 additions & 0 deletions src/mistral_inference/args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from dataclasses import dataclass
from typing import Optional

from simple_parsing.helpers import Serializable

from mistral_inference.lora import LoraArgs
from mistral_inference.moe import MoeArgs


@dataclass
class TransformerArgs(Serializable):
    """Serializable configuration record for a transformer model.

    The required fields carry no defaults and must be supplied; the remaining
    fields are optional switches. ``model_type`` acts as a tag and is checked
    on construction so that a config meant for a different model family is
    rejected.

    Raises:
        AssertionError: if ``model_type`` is not ``"transformer"``.
    """

    dim: int
    n_layers: int
    head_dim: int
    hidden_dim: int
    n_heads: int
    n_kv_heads: int
    norm_eps: float
    vocab_size: int

    max_batch_size: int = 0

    # For rotary embeddings. If not set, will be inferred
    rope_theta: Optional[float] = None
    # If this is set, we will use MoE layers instead of dense layers.
    moe: Optional[MoeArgs] = None
    # If this is set, we will load LoRA linear layers instead of linear layers.
    lora: Optional[LoraArgs] = None
    model_type: str = "transformer"

    def __post_init__(self) -> None:
        # Raise explicitly rather than via `assert`: assert statements are
        # removed under `python -O`, which would silently skip this check.
        # AssertionError is kept so existing callers see the same exception.
        if self.model_type != "transformer":
            raise AssertionError(self.model_type)


@dataclass
class MambaArgs(Serializable):
    """Serializable configuration record for a Mamba model.

    All fields except ``model_type`` are required. ``model_type`` acts as a
    tag and is checked on construction so that a config meant for a different
    model family is rejected.

    Raises:
        AssertionError: if ``model_type`` is not ``"mamba"``.
    """

    dim: int
    n_layers: int
    vocab_size: int
    n_groups: int
    rms_norm: bool
    residual_in_fp32: bool
    fused_add_norm: bool
    pad_vocab_size_multiple: int
    tie_embeddings: bool
    model_type: str = "mamba"

    def __post_init__(self) -> None:
        # Raise explicitly rather than via `assert`: assert statements are
        # removed under `python -O`, which would silently skip this check.
        # AssertionError is kept so existing callers see the same exception.
        if self.model_type != "mamba":
            raise AssertionError(self.model_type)
Loading

0 comments on commit 26a52a1

Please sign in to comment.