diff --git a/mixtral-moe/README.md b/mixtral-moe/README.md
index cf5e9d9..602a4bc 100644
--- a/mixtral-moe/README.md
+++ b/mixtral-moe/README.md
@@ -3,6 +3,12 @@
 
 ## Downloading Weights
 
+Models tested/supported:
+```text
+Mixtral-8x7B-v0.1
+databricks/dbrx-base
+```
+
 ```bash
 export MODEL_REPO=mistralai/Mixtral-8x7B-v0.1
 python scripts/download.py --repo_id $MODEL_REPO
@@ -12,11 +18,22 @@ python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
 
 ## Benchmarks
 Benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh topology. Note that all benchmarks are run at *batch size=1*, making the reported tokens/s numbers equivalent to "tokens/s/user". In addition, they are run with a very small prompt length (just 5 tokens).
 
+### Mixtral-8x7B
+Mixtral has 46.7B total parameters but only uses 12.9B per token: it has 8 experts and activates 2 of them per token.
+
 |                  | 1 GPU   | 2 GPU     | 4 GPU  | 8 GPU      |
 |------------------|---------|-----------|--------|------------|
 |baseline(bfloat16)| OOM     | 96.67     | 155.35 | 227.82     |
 | int8             | 97.92   | 155.03    | 216.87 | 279.35     |
 
+### dbrx-base
+DBRX has 132B total parameters, of which 36B are active for any given input: it has 16 experts and activates 4 of them per token.
+
+|                  | 1 GPU   | 2 GPU     | 4 GPU  | 8 GPU      |
+|------------------|---------|-----------|--------|------------|
+|baseline(bfloat16)| OOM     | OOM       | 59.53  | 100.51     |
+| int8             | OOM     | 66.72     | 91.21  | 146.86     |
+
 ## Generate Text
 
diff --git a/mixtral-moe/generate.py b/mixtral-moe/generate.py
index 9aa076b..fe5003d 100644
--- a/mixtral-moe/generate.py
+++ b/mixtral-moe/generate.py
@@ -31,9 +31,9 @@ def device_sync(device):
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
 
-from sentencepiece import SentencePieceProcessor
 
 from model import Transformer
+from tokenizer import get_tokenizer
 from tp import maybe_init_dist
 
@@ -175,7 +175,6 @@ def main(
     assert checkpoint_path.is_file(), checkpoint_path
 
     tokenizer_path = checkpoint_path.parent / "tokenizer.model"
-    assert tokenizer_path.is_file(), str(tokenizer_path)
 
     global print
     rank = maybe_init_dist()
@@ -196,7 +195,7 @@ def main(
     device_sync(device=device) # MKG
     print(f"Time to load model: {time.time() - t0:.02f} seconds")
 
-    tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
+    tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
 
     encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
     prompt_length = encoded.size(0)
diff --git a/mixtral-moe/model.py b/mixtral-moe/model.py
index 9249ac9..aa5938e 100644
--- a/mixtral-moe/model.py
+++ b/mixtral-moe/model.py
@@ -31,6 +31,7 @@ class ModelArgs:
     norm_eps: float = 1e-5
     num_experts: int = 8
     num_activated_experts: int = 2
+    clip_qkv: Optional[float] = None
 
     def __post_init__(self):
         if self.n_local_heads == -1:
@@ -53,8 +54,16 @@ def from_name(cls, name: str):
 
 transformer_configs = {
     "Mixtral-8x7B-v0.1": dict(block_size=32768, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, rope_base=1000000.0, num_experts=8, num_activated_experts=2),
+    "dbrx-base": dict(block_size=32768, n_layer=40, n_head=48, n_local_heads=8, dim=6144, intermediate_size=10752, rope_base=500000.0, num_experts=16, num_activated_experts=4, vocab_size=100352, clip_qkv=8.0),
+    "dbrx-instruct": dict(block_size=32768, n_layer=40, n_head=48, n_local_heads=8, dim=6144, intermediate_size=10752, rope_base=500000.0, num_experts=16, num_activated_experts=4, vocab_size=100352, clip_qkv=8.0),
 }
 
+def is_dbrx(config: ModelArgs):
+    if config.n_layer == 40 and config.rope_base == 500000.0:
+        return True
+    else:
+        return False
+
 class KVCache(nn.Module):
     def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
         super().__init__()
@@ -80,7 +89,10 @@ def __init__(self, config: ModelArgs) -> None:
 
         self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
         self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
-        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+        if is_dbrx(config):
+            self.norm = nn.LayerNorm(config.dim, eps=config.norm_eps, bias=False)
+        else:
+            self.norm = RMSNorm(config.dim, eps=config.norm_eps)
         self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
 
         self.freqs_cis: Optional[Tensor] = None
@@ -123,8 +135,12 @@ def __init__(self, config: ModelArgs) -> None:
         super().__init__()
         self.attention = Attention(config)
         self.block_sparse_moe = MOEFeedForward(config)
-        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
-        self.attention_norm = RMSNorm(config.dim, config.norm_eps)
+        if is_dbrx(config):
+            self.ffn_norm = nn.LayerNorm(config.dim, config.norm_eps, bias=False)
+            self.attention_norm = nn.LayerNorm(config.dim, config.norm_eps, bias=False)
+        else:
+            self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
+            self.attention_norm = RMSNorm(config.dim, config.norm_eps)
 
     def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
         h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
@@ -147,6 +163,7 @@ def __init__(self, config: ModelArgs):
         self.head_dim = config.head_dim
         self.n_local_heads = config.n_local_heads
         self.dim = config.dim
+        self.clip_qkv = config.clip_qkv
         self._register_load_state_dict_pre_hook(self.load_hook)
 
     def load_hook(self, state_dict, prefix, *args):
@@ -160,7 +177,10 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona
         bsz, seqlen, _ = x.shape
 
         kv_size = self.n_local_heads * self.head_dim
-        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+        qkv_states = self.wqkv(x)
+        if self.clip_qkv is not None:
+            qkv_states = qkv_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
+        q, k, v = qkv_states.split([self.dim, kv_size, kv_size], dim=-1)
 
         q = q.view(bsz, seqlen, self.n_head, self.head_dim)
         k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
@@ -215,7 +235,7 @@ def forward(self, x: Tensor) -> Tensor:
         scores = self.gate(x) # [T, E]
         expert_weights = F.softmax(scores, dim=-1)
         expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1) # [T, A], [T, A]
-        expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A]
+        expert_weights = expert_weights / torch.norm(expert_weights, p=1, dim=-1, keepdim=True)
         expert_outs = self.cond_ffn(x, expert_indices)
         return torch.einsum('tai,ta -> ti', expert_outs, expert_weights)
 
@@ -245,16 +265,20 @@ def precompute_freqs_cis(
     return cache.to(dtype=torch.bfloat16)
 
 
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
 def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
-    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
-    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
-    x_out2 = torch.stack(
-        [
-            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
-            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
-        ],
-        -1,
-    )
-
-    x_out2 = x_out2.flatten(3)
-    return x_out2.type_as(x)
+    fc_shape = freqs_cis.shape
+    freqs_cis = freqs_cis.view(1, fc_shape[0], 1, fc_shape[1], fc_shape[2])
+    cos, sin = freqs_cis.split([1, 1], dim=-1)
+    cos = cos.squeeze(-1)
+    sin = sin.squeeze(-1)
+    cos = torch.cat((cos, cos), dim=-1)
+    sin = torch.cat((sin, sin), dim=-1)
+    z = (x * cos) + (rotate_half(x) * sin)
+    return z
diff --git a/mixtral-moe/scripts/convert_hf_checkpoint.py b/mixtral-moe/scripts/convert_hf_checkpoint.py
index e659931..f7f4f9a 100644
--- a/mixtral-moe/scripts/convert_hf_checkpoint.py
+++ b/mixtral-moe/scripts/convert_hf_checkpoint.py
@@ -8,6 +8,7 @@
 import re
 import sys
 from pathlib import Path
+from safetensors.torch import load
 from typing import Optional
 
 import torch
@@ -18,9 +19,8 @@
 from model import ModelArgs
 
-
 @torch.inference_mode()
-def convert_hf_checkpoint(
+def _convert_mixtral(
     *,
     checkpoint_dir: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1"),
     model_name: Optional[str] = None,
@@ -87,6 +87,66 @@ def convert_hf_checkpoint(
     torch.save(final_result, checkpoint_dir / "model.pth")
 
 
+@torch.inference_mode()
+def _convert_dbrx(
+    *,
+    checkpoint_dir: Path = Path("checkpoints/databricks/dbrx-base"),
+    model_name: Optional[str] = None,
+) -> None:
+    if model_name is None:
+        model_name = checkpoint_dir.name
+
+    config = ModelArgs.from_name(model_name)
+    print(f"Model config {config.__dict__}")
+
+    weight_map = {
+        "transformer.wte.weight": "tok_embeddings.weight",
+        "transformer.blocks.{}.norm_attn_norm.attn.Wqkv.weight": "layers.{}.attention.wqkv.weight",
+        "transformer.blocks.{}.norm_attn_norm.attn.out_proj.weight": "layers.{}.attention.wo.weight",
+        "transformer.blocks.{}.ffn.experts.mlp.w1": "layers.{}.block_sparse_moe.cond_ffn.w1",
+        "transformer.blocks.{}.ffn.experts.mlp.w2": "layers.{}.block_sparse_moe.cond_ffn.w2",
+        "transformer.blocks.{}.ffn.experts.mlp.v1": "layers.{}.block_sparse_moe.cond_ffn.w3",
+        "transformer.blocks.{}.ffn.router.layer.weight": "layers.{}.block_sparse_moe.gate.weight",
+        "transformer.blocks.{}.norm_attn_norm.norm_1.weight": "layers.{}.attention_norm.weight",
+        "transformer.blocks.{}.norm_attn_norm.norm_2.weight": "layers.{}.ffn_norm.weight",
+        "transformer.norm_f.weight": "norm.weight",
+        "lm_head.weight": "output.weight",
+    }
+
+    st_files = glob.glob(str(checkpoint_dir / "*.safetensors"))
+
+    merged_result = {}
+    for file in sorted(st_files):
+        with open(file, "rb") as f:
+            data = f.read()
+        state_dict = load(data)
+        merged_result.update(state_dict)
+    final_result = {}
+    for key, value in merged_result.items():
+        if "blocks" in key:
+            abstract_key = re.sub(r'.(\d+).', '.{}.', key, count=1)
+            layer_num = re.search(r'\d+', key).group(0)
+            new_key = weight_map[abstract_key]
+            if new_key is None:
+                continue
+            new_key = new_key.format(layer_num)
+        else:
+            new_key = weight_map[key]
+
+        final_result[new_key] = value
+
+    for key in tuple(final_result.keys()):
+        if "w1" in key or "w3" in key:
+            final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).contiguous()
+        elif "w2" in key:
+            final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).permute(0, 2, 1).contiguous()
+        elif "gate" in key:
+            final_result[key] = final_result[key].contiguous()
+
+    print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
+    torch.save(final_result, checkpoint_dir / "model.pth")
+
+
 if __name__ == '__main__':
     import argparse
     parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint.')
@@ -94,7 +154,13 @@ def convert_hf_checkpoint(
     parser.add_argument('--model_name', type=str, default=None)
 
     args = parser.parse_args()
-    convert_hf_checkpoint(
-        checkpoint_dir=args.checkpoint_dir,
-        model_name=args.model_name,
-    )
+    checkpoint_dir = args.checkpoint_dir
+    model_name = args.model_name
+    if model_name is None:
+        model_name = checkpoint_dir.name
+
+    if "Mixtral-8x7B" in model_name:
+        _convert_mixtral(checkpoint_dir=checkpoint_dir, model_name=model_name)
+    else:
+        assert "dbrx" in model_name, f"Unknown model name {model_name}"
+        _convert_dbrx(checkpoint_dir=checkpoint_dir, model_name=model_name)
diff --git a/mixtral-moe/scripts/download.py b/mixtral-moe/scripts/download.py
index d1505ef..a968cf3 100644
--- a/mixtral-moe/scripts/download.py
+++ b/mixtral-moe/scripts/download.py
@@ -13,7 +13,7 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -
     from huggingface_hub import snapshot_download
     os.makedirs(f"checkpoints/{repo_id}", exist_ok=True)
     try:
-        snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token, ignore_patterns="*.safetensors")
+        snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token)
     except HTTPError as e:
         if e.response.status_code == 401:
             print("You need to pass a valid `--hf_token=...` to download private checkpoints.")
diff --git a/mixtral-moe/tokenizer.py b/mixtral-moe/tokenizer.py
new file mode 100644
index 0000000..ac87470
--- /dev/null
+++ b/mixtral-moe/tokenizer.py
@@ -0,0 +1,124 @@
+import os
+import sentencepiece as spm
+import tiktoken
+from tiktoken.load import load_tiktoken_bpe
+from pathlib import Path
+from typing import Dict
+from transformers import GPT2TokenizerFast
+
+class TokenizerInterface:
+    def __init__(self, model_path):
+        self.model_path = model_path
+
+    def encode(self, text):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
+    def decode(self, tokens):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
+    def bos_id(self):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
+    def eos_id(self):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
+class SentencePieceWrapper(TokenizerInterface):
+    def __init__(self, model_path):
+        super().__init__(model_path)
+        assert os.path.isfile(self.model_path), str(self.model_path)
+        self.processor = spm.SentencePieceProcessor(str(self.model_path))
+
+    def encode(self, text):
+        return self.processor.EncodeAsIds(text)
+
+    def decode(self, tokens):
+        return self.processor.DecodeIds(tokens)
+
+    def bos_id(self):
+        return self.processor.bos_id()
+
+    def eos_id(self):
+        return self.processor.eos_id()
+
+class DBRXTokenizeWrapper(TokenizerInterface):
+    def __init__(self, model_path):
+        super().__init__(model_path)
+        vocab_file = os.path.join(model_path.parent, "vocab.json")
+        merges_file = os.path.join(model_path.parent, "merges.txt")
+        tokenizer_file = os.path.join(model_path.parent, "tokenizer.json")
+        self.processor = GPT2TokenizerFast(vocab_file, merges_file, tokenizer_file)
+
+    def encode(self, text):
+        return self.processor.encode(text)
+
+    def decode(self, tokens):
+        return self.processor.decode(tokens)
+
+    def bos_id(self):
+        return self.processor.bos_token_id
+
+    def eos_id(self):
+        return self.processor.eos_token_id
+
+class TiktokenWrapper(TokenizerInterface):
+    """
+    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
+    """
+
+    special_tokens: Dict[str, int]
+
+    num_reserved_special_tokens = 256
+
+    pat_str = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""  # noqa: E501
+
+    def __init__(self, model_path):
+        super().__init__(model_path)
+        assert os.path.isfile(self.model_path), str(self.model_path)
+        mergeable_ranks = load_tiktoken_bpe(str(self.model_path))
+        num_base_tokens = len(mergeable_ranks)
+        special_tokens = [
+            "<|endoftext|>",
+            "<|pad|>",
+        ]
+        self.special_tokens = {
+            token: num_base_tokens + i for i, token in enumerate(special_tokens)
+        }
+        self.model = tiktoken.Encoding(
+            name=Path(self.model_path).name,
+            pat_str=self.pat_str,
+            mergeable_ranks=mergeable_ranks,
+            special_tokens=self.special_tokens,
+        )
+        # BOS / EOS token IDs
+        self._bos_id: int = self.special_tokens["<|endoftext|>"]
+        self._eos_id: int = self.special_tokens["<|endoftext|>"]
+
+    def encode(self, text):
+        return self.model.encode(text)
+
+    def decode(self, tokens):
+        return self.model.decode(tokens)
+
+    def bos_id(self):
+        return self._bos_id
+
+    def eos_id(self):
+        return self._eos_id
+
+def get_tokenizer(tokenizer_model_path, model_name):
+    """
+    Factory function to get the appropriate tokenizer based on the model name.
+
+    Args:
+    - tokenizer_model_path (str): The file path to the tokenizer model.
+    - model_name (str): The name of the model, used to determine the tokenizer type.
+
+    Returns:
+    - TokenizerInterface: An instance of a tokenizer.
+    """
+    if "Llama-3" in str(model_name):
+        return TiktokenWrapper(tokenizer_model_path)
+    elif "dbrx" in str(model_name):
+        return DBRXTokenizeWrapper(tokenizer_model_path)
+    else:
+        return SentencePieceWrapper(tokenizer_model_path)
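
For reference, a minimal sketch of how the new `get_tokenizer` factory is exercised, mirroring the call added in `generate.py`. The checkpoint path below is illustrative and assumes the DBRX checkpoint has already been downloaded and converted, so the tokenizer files (`vocab.json` / `merges.txt` / `tokenizer.json`) sit next to `model.pth`.

```python
# Sketch only: exercises the tokenizer factory the same way generate.py does.
from pathlib import Path

from tokenizer import get_tokenizer

checkpoint_path = Path("checkpoints/databricks/dbrx-base/model.pth")  # illustrative path
tokenizer_path = checkpoint_path.parent / "tokenizer.model"

# "dbrx" in the checkpoint path routes to DBRXTokenizeWrapper (GPT2TokenizerFast underneath);
# a Mixtral path would fall through to SentencePieceWrapper instead.
tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

tokens = tokenizer.encode("Hello, world!")
print(tokens)
print(tokenizer.decode(tokens))
print(tokenizer.bos_id(), tokenizer.eos_id())
```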