Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[example] Add support for DBRX #174

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,12 @@ python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int
To generate int4 version of model
```bash
# Spits out model at checkpoints/$MODEL_REPO/model_int4.g32.$DEVICE.pth
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4 --groupsize 32 --device $DEVICE
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4 --groupsize 32
```

To run with int4, just pass the int4 checkpoint to generate.py.
```bash
python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.$DEVICE.pth --compile --device $DEVICE
python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --compile
```

## Speculative Sampling
Expand Down
8 changes: 3 additions & 5 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def generate(
seq = empty
input_pos = torch.arange(0, T, device=device)

next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs)
next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs).clone()
if is_speculative:
prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs)
seq[T] = next_token
Expand Down Expand Up @@ -225,12 +225,10 @@ def _load_model(checkpoint_path, device, precision, use_tp):
if "int4" in str(checkpoint_path):
print("Using int4 weight-only quantization!")
path_comps = checkpoint_path.name.split(".")
assert path_comps[-3].startswith("g")
assert path_comps[-2] in device, "weight packed format mismatch, please rerun quantize.py!"
groupsize = int(path_comps[-3][1:])
groupsize = int(path_comps[-2][1:])
from quantize import WeightOnlyInt4QuantHandler
simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
model = simple_quantizer.convert_for_runtime(use_cuda)
model = simple_quantizer.convert_for_runtime()

checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
if "model" in checkpoint and "stories" in str(checkpoint_path):
Expand Down
17 changes: 17 additions & 0 deletions mixtral-moe/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@

## Downloading Weights

Models tested/supported
```text
Mixtral-8x7B-v0.1
databricks/dbrx-base
```

```bash
export MODEL_REPO=mistralai/Mixtral-8x7B-v0.1
python scripts/download.py --repo_id $MODEL_REPO
Expand All @@ -12,11 +18,22 @@ python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
## Benchmarks
Benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh topology. Note that all benchmarks are run at *batch size=1*, making the reported tokens/s numbers equivalent to "tokens/s/user". In addition, they are run with a very small prompt length (just 5 tokens).

### Mixtral-8x7B
Mixtral has 46.7B total parameters but only uses 12.9B parameters per token, 8 experts and chooses 2.

| | 1 GPU | 2 GPU | 4 GPU | 8 GPU |
|------------------|---------|-----------|--------|------------|
|baseline(bfloat16)| OOM | 96.67 | 155.35 | 227.82 |
| int8 | 97.92 | 155.03 | 216.87 | 279.35 |

### dbrx-base
DBRX has 132B total parameters of which 36B parameters are active on any input, 16 experts and chooses 4.

| | 1 GPU | 2 GPU | 4 GPU | 8 GPU |
|------------------|---------|-----------|--------|------------|
|baseline(bfloat16)| OOM | OOM | 59.53 | 100.51 |
| int8 | OOM | 66.72 | 91.21 | 146.86 |


## Generate Text

Expand Down
5 changes: 2 additions & 3 deletions mixtral-moe/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ def device_sync(device):
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from sentencepiece import SentencePieceProcessor

from model import Transformer
from tokenizer import get_tokenizer
from tp import maybe_init_dist


Expand Down Expand Up @@ -175,7 +175,6 @@ def main(
assert checkpoint_path.is_file(), checkpoint_path

tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), str(tokenizer_path)

global print
rank = maybe_init_dist()
Expand All @@ -196,7 +195,7 @@ def main(
device_sync(device=device) # MKG
print(f"Time to load model: {time.time() - t0:.02f} seconds")

tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
prompt_length = encoded.size(0)

Expand Down
58 changes: 41 additions & 17 deletions mixtral-moe/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class ModelArgs:
norm_eps: float = 1e-5
num_experts: int = 8
num_activated_experts: int = 2
clip_qkv: Optional[float] = None

def __post_init__(self):
if self.n_local_heads == -1:
Expand All @@ -53,8 +54,16 @@ def from_name(cls, name: str):

transformer_configs = {
"Mixtral-8x7B-v0.1": dict(block_size=32768, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, rope_base=1000000.0, num_experts=8, num_activated_experts=2),
"dbrx-base": dict(block_size=32768, n_layer=40, n_head=48, n_local_heads=8, dim=6144, intermediate_size=10752, rope_base=500000.0, num_experts=16, num_activated_experts=4, vocab_size=100352, clip_qkv=8.0),
"dbrx-instruct": dict(block_size=32768, n_layer=40, n_head=48, n_local_heads=8, dim=6144, intermediate_size=10752, rope_base=500000.0, num_experts=16, num_activated_experts=4, vocab_size=100352, clip_qkv=8.0),
}

def is_dbrx(config: ModelArgs):
if config.n_layer == 40 and config.rope_base == 500000.0:
return True
else:
return False

class KVCache(nn.Module):
def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
super().__init__()
Expand All @@ -80,7 +89,10 @@ def __init__(self, config: ModelArgs) -> None:

self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
if is_dbrx(config):
self.norm = nn.LayerNorm(config.dim, eps=config.norm_eps, bias=False)
else:
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

self.freqs_cis: Optional[Tensor] = None
Expand Down Expand Up @@ -123,8 +135,12 @@ def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.attention = Attention(config)
self.block_sparse_moe = MOEFeedForward(config)
self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
self.attention_norm = RMSNorm(config.dim, config.norm_eps)
if is_dbrx(config):
self.ffn_norm = nn.LayerNorm(config.dim, config.norm_eps, bias=False)
self.attention_norm = nn.LayerNorm(config.dim, config.norm_eps, bias=False)
else:
self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
self.attention_norm = RMSNorm(config.dim, config.norm_eps)

def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
Expand All @@ -147,6 +163,7 @@ def __init__(self, config: ModelArgs):
self.head_dim = config.head_dim
self.n_local_heads = config.n_local_heads
self.dim = config.dim
self.clip_qkv = config.clip_qkv
self._register_load_state_dict_pre_hook(self.load_hook)

def load_hook(self, state_dict, prefix, *args):
Expand All @@ -160,7 +177,10 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona
bsz, seqlen, _ = x.shape

kv_size = self.n_local_heads * self.head_dim
q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
qkv_states = self.wqkv(x)
if self.clip_qkv is not None:
qkv_states = qkv_states.clamp(min = -self.clip_qkv, max = self.clip_qkv)
q, k, v = qkv_states.split([self.dim, kv_size, kv_size], dim=-1)

q = q.view(bsz, seqlen, self.n_head, self.head_dim)
k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
Expand Down Expand Up @@ -215,7 +235,7 @@ def forward(self, x: Tensor) -> Tensor:
scores = self.gate(x) # [T, E]
expert_weights = F.softmax(scores, dim=-1)
expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1) # [T, A], [T, A]
expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A]
expert_weights = expert_weights / torch.norm(expert_weights, p=1, dim=-1, keepdim=True)
expert_outs = self.cond_ffn(x, expert_indices)
return torch.einsum('tai,ta -> ti', expert_outs, expert_weights)

Expand Down Expand Up @@ -245,16 +265,20 @@ def precompute_freqs_cis(
return cache.to(dtype=torch.bfloat16)


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
x_out2 = torch.stack(
[
xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
],
-1,
)

x_out2 = x_out2.flatten(3)
return x_out2.type_as(x)
fc_shape = freqs_cis.shape
freqs_cis = freqs_cis.view(1, fc_shape[0], 1, fc_shape[1], fc_shape[2])
cos, sin = freqs_cis.split([1, 1], dim=-1)
cos = cos.squeeze(-1)
sin = sin.squeeze(-1)
cos = torch.cat((cos, cos), dim=-1)
sin = torch.cat((sin, sin), dim=-1)
z = (x * cos) + (rotate_half(x)) * sin
return z
78 changes: 72 additions & 6 deletions mixtral-moe/scripts/convert_hf_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import re
import sys
from pathlib import Path
from safetensors.torch import load
from typing import Optional

import torch
Expand All @@ -18,9 +19,8 @@

from model import ModelArgs


@torch.inference_mode()
def convert_hf_checkpoint(
def _convert_mixtral(
*,
checkpoint_dir: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1"),
model_name: Optional[str] = None,
Expand Down Expand Up @@ -87,14 +87,80 @@ def convert_hf_checkpoint(
torch.save(final_result, checkpoint_dir / "model.pth")


@torch.inference_mode()
def _convert_dbrx(
*,
checkpoint_dir: Path = Path("checkpoints/databricks/dbrx-base"),
model_name: Optional[str] = None,
) -> None:
if model_name is None:
model_name = checkpoint_dir.name

config = ModelArgs.from_name(model_name)
print(f"Model config {config.__dict__}")

weight_map = {
"transformer.wte.weight": "tok_embeddings.weight",
"transformer.blocks.{}.norm_attn_norm.attn.Wqkv.weight": "layers.{}.attention.wqkv.weight",
"transformer.blocks.{}.norm_attn_norm.attn.out_proj.weight": "layers.{}.attention.wo.weight",
"transformer.blocks.{}.ffn.experts.mlp.w1": "layers.{}.block_sparse_moe.cond_ffn.w1",
"transformer.blocks.{}.ffn.experts.mlp.w2": "layers.{}.block_sparse_moe.cond_ffn.w2",
"transformer.blocks.{}.ffn.experts.mlp.v1": "layers.{}.block_sparse_moe.cond_ffn.w3",
"transformer.blocks.{}.ffn.router.layer.weight": "layers.{}.block_sparse_moe.gate.weight",
"transformer.blocks.{}.norm_attn_norm.norm_1.weight": "layers.{}.attention_norm.weight",
"transformer.blocks.{}.norm_attn_norm.norm_2.weight": "layers.{}.ffn_norm.weight",
"transformer.norm_f.weight": "norm.weight",
"lm_head.weight": "output.weight",
}

st_files = glob.glob(str(checkpoint_dir / "*.safetensors"))

merged_result = {}
for file in sorted(st_files):
with open(file, "rb") as f:
data = f.read()
state_dict = load(data)
merged_result.update(state_dict)
final_result = {}
for key, value in merged_result.items():
if "blocks" in key:
abstract_key = re.sub(r'.(\d+).', '.{}.', key, count=1)
layer_num = re.search(r'\d+', key).group(0)
new_key = weight_map[abstract_key]
if new_key is None:
continue
new_key = new_key.format(layer_num)
else:
new_key = weight_map[key]

final_result[new_key] = value

for key in tuple(final_result.keys()):
if "w1" in key or "w3" in key:
final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).contiguous()
elif "w2" in key:
final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).permute(0, 2, 1).contiguous()
elif "gate" in key:
final_result[key] = final_result[key].contiguous()

print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
torch.save(final_result, checkpoint_dir / "model.pth")


if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint.')
parser.add_argument('--checkpoint_dir', type=Path, default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf"))
parser.add_argument('--model_name', type=str, default=None)

args = parser.parse_args()
convert_hf_checkpoint(
checkpoint_dir=args.checkpoint_dir,
model_name=args.model_name,
)
checkpoint_dir=args.checkpoint_dir
model_name=args.model_name
if model_name is None:
model_name = checkpoint_dir.name

if "Mixtral-8x7B" in model_name:
_convert_mixtral(checkpoint_dir=checkpoint_dir, model_name=model_name)
else:
assert "dbrx" in model_name, f"Unknown model name {model_name}"
_convert_dbrx(checkpoint_dir=checkpoint_dir, model_name=model_name)
2 changes: 1 addition & 1 deletion mixtral-moe/scripts/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -
from huggingface_hub import snapshot_download
os.makedirs(f"checkpoints/{repo_id}", exist_ok=True)
try:
snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token, ignore_patterns="*.safetensors")
snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token)
except HTTPError as e:
if e.response.status_code == 401:
print("You need to pass a valid `--hf_token=...` to download private checkpoints.")
Expand Down
Loading