
Commit

Add olmoe.
joelburget committed Sep 9, 2024
1 parent db1a7f5 commit 74204c3
Showing 8 changed files with 1,757 additions and 1,542 deletions.
3,204 changes: 1,665 additions & 1,539 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -34,7 +34,7 @@
     {platform="linux", version=">=1.10"}, # We can use any torch version on Linux (e.g colab)
 ]
 tqdm=">=4.64.1"
-transformers=">=4.37.2"
+transformers={ git = "https://github.com/huggingface/transformers.git" }
 typing-extensions="*"
 wandb=">=0.13.5"

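Note: moving the transformers dependency from a pinned release to the git main branch suggests OLMoE support had not yet shipped in a tagged release at the time of this commit. A quick way to check whether an installed build exposes it (this assumes the class is exported as transformers.OlmoeForCausalLM, matching the "OlmoeForCausalLM" architecture string handled later in this diff):

# Check whether the installed transformers build includes OLMoE support.
# Assumes the model class is exported as transformers.OlmoeForCausalLM.
import transformers

try:
    from transformers import OlmoeForCausalLM  # noqa: F401
    print(f"transformers {transformers.__version__}: OLMoE available")
except ImportError:
    print(f"transformers {transformers.__version__}: no OLMoE; install from git main")
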
2 changes: 1 addition & 1 deletion transformer_lens/HookedTransformer.py
@@ -145,7 +145,7 @@ def __init__(
             self.set_tokenizer(
                 AutoTokenizer.from_pretrained(
                     self.cfg.tokenizer_name,
-                    add_bos_token=True,
+                    # add_bos_token=True,
                     trust_remote_code=self.cfg.trust_remote_code,
                     use_fast=use_fast,
                     token=huggingface_token,
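With add_bos_token=True commented out, the tokenizer is loaded with whatever BOS behaviour it ships with by default. A small sketch for inspecting that behaviour for the checkpoint added in this commit, using only standard AutoTokenizer calls:

from transformers import AutoTokenizer

# Load the tokenizer the same way HookedTransformer now does, without forcing
# add_bos_token=True, and check whether a BOS token is prepended by default.
tok = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")
ids = tok("Hello world")["input_ids"]
print("bos_token_id:", tok.bos_token_id)
print("starts with BOS:", len(ids) > 0 and ids[0] == tok.bos_token_id)
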
3 changes: 3 additions & 0 deletions transformer_lens/components/mlps/moe.py
@@ -63,6 +63,7 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):

         self.num_experts: int = self.cfg.num_experts
         self.experts_per_token: int = self.cfg.experts_per_token
+        # self.norm_topk_prob: bool = self.cfg.norm_topk_prob

         assert (
             self.cfg.experts_per_token <= self.cfg.num_experts
@@ -88,6 +89,8 @@ def forward(
         # both are [batch, pos, experts_per_token]
         weights = self.hook_expert_weights(F.softmax(gate_logits, dim=1, dtype=torch.float))
         weights, expert_indices = torch.topk(weights, self.experts_per_token, dim=-1)
+        # if self.norm_topk_prob:
+        #     weights /= weights.sum(dim=-1, keepdim=True)
         weights /= weights.sum(dim=-1, keepdim=True)
         expert_indices = self.hook_expert_indices(expert_indices)
         weights = weights.to(x.dtype)
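The commented-out lines sketch OLMoE's norm_topk_prob option, which controls whether the top-k routing weights are renormalized to sum to one; the code as committed always renormalizes. A self-contained sketch of that routing step, separate from the MoE class above, with norm_topk_prob treated as an ordinary boolean flag rather than a config field that exists in TransformerLens:

import torch
import torch.nn.functional as F

def route_tokens(gate_logits: torch.Tensor, experts_per_token: int, norm_topk_prob: bool):
    # gate_logits: [n_tokens, num_experts] router scores for each token.
    weights = F.softmax(gate_logits, dim=-1, dtype=torch.float)
    # Keep only the top-k experts per token.
    weights, expert_indices = torch.topk(weights, experts_per_token, dim=-1)
    if norm_topk_prob:
        # Renormalize the surviving probabilities so they sum to 1 per token.
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights, expert_indices

# Example: 4 tokens routed over 8 experts, 2 experts per token.
logits = torch.randn(4, 8)
w, idx = route_tokens(logits, experts_per_token=2, norm_topk_prob=True)
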
32 changes: 32 additions & 0 deletions transformer_lens/loading_from_pretrained.py
@@ -35,6 +35,7 @@
     convert_neel_solu_old_weights,
     convert_neo_weights,
     convert_neox_weights,
+    convert_olmoe_weights,
     convert_opt_weights,
     convert_phi3_weights,
     convert_phi_weights,
@@ -225,6 +226,7 @@
     "google-t5/t5-base",
     "google-t5/t5-large",
     "ai-forever/mGPT",
+    "allenai/OLMoE-1B-7B-0924",
 ]
 """Official model names for models on HuggingFace."""

@@ -1329,6 +1331,34 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "use_attn_scale": False,
             "tie_word_embeddings": hf_config.tie_word_embeddings,
         }
+    elif architecture == "OlmoeForCausalLM":
+        cfg_dict = {
+            "d_model": hf_config.hidden_size,
+            "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
+            "n_heads": hf_config.num_attention_heads,
+            "d_mlp": hf_config.intermediate_size,
+            "n_layers": hf_config.num_hidden_layers,
+            "n_ctx": hf_config.max_position_embeddings,
+            "eps": hf_config.rms_norm_eps,
+            "d_vocab": hf_config.vocab_size,
+            "act_fn": hf_config.hidden_act,
+            "num_experts": hf_config.num_experts,
+            "experts_per_token": hf_config.num_experts_per_tok,
+            # TODO: implement!
+            # "router_aux_loss_coef": hf_config.router_aux_loss_coef,
+            # "router_z_loss_coef": hf_config.router_z_loss_coef,
+            # "norm_topk_prob": hf_config.norm_topk_prob,
+            # end
+            "n_key_value_heads": hf_config.num_key_value_heads,
+            "rotary_base": hf_config.rope_theta,
+            "tie_word_embeddings": hf_config.tie_word_embeddings,
+            "initializer_range": hf_config.initializer_range,
+            "positional_embedding_type": "rotary",
+            "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
+            "final_rms": True,
+            "gated_mlp": True,
+            "normalization_type": "RMS",
+        }
     else:
         raise NotImplementedError(f"{architecture} is not currently supported.")
     # All of these models use LayerNorm
@@ -1714,6 +1744,8 @@ def get_pretrained_state_dict(
             state_dict = convert_gemma_weights(hf_model, cfg)
         elif cfg.original_architecture == "Gemma2ForCausalLM":
             state_dict = convert_gemma_weights(hf_model, cfg)
+        elif cfg.original_architecture == "OlmoeForCausalLM":
+            state_dict = convert_olmoe_weights(hf_model, cfg)
         else:
             raise ValueError(
                 f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
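With the config mapping and the state-dict dispatch above, the new checkpoint should load through the usual entry point. A minimal sketch of what that would look like with this branch installed (nothing beyond the standard from_pretrained / run_with_cache API is assumed):

from transformer_lens import HookedTransformer

# Load the OLMoE checkpoint registered above; weights are pulled from
# HuggingFace and mapped through convert_olmoe_weights.
model = HookedTransformer.from_pretrained("allenai/OLMoE-1B-7B-0924")

logits, cache = model.run_with_cache("The capital of France is")
print(logits.shape)          # [batch, pos, d_vocab]
print(model.cfg.num_experts, model.cfg.experts_per_token)
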
1 change: 1 addition & 0 deletions transformer_lens/pretrained/weight_conversions/__init__.py
@@ -18,3 +18,4 @@
 from .nanogpt import convert_nanogpt_weights
 from .t5 import convert_t5_weights
 from .neel_solu_old import convert_neel_solu_old_weights
+from .olmoe import convert_olmoe_weights
53 changes: 53 additions & 0 deletions transformer_lens/pretrained/weight_conversions/olmoe.py
@@ -0,0 +1,53 @@
import einops
import torch

from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

def convert_olmoe_weights(olmoe, cfg: HookedTransformerConfig):
    state_dict = {}

    assert cfg.n_key_value_heads is not None
    assert cfg.d_mlp is not None
    assert cfg.num_experts is not None

    state_dict["embed.W_E"] = olmoe.model.embed_tokens.weight

    for l in range(cfg.n_layers):
        olmoe_layer = olmoe.model.layers[l]
        state_dict[f"blocks.{l}.ln1.w"] = olmoe_layer.input_layernorm.weight

        W_Q = olmoe.model.layers[l].self_attn.q_proj.weight
        W_K = olmoe.model.layers[l].self_attn.k_proj.weight
        W_V = olmoe.model.layers[l].self_attn.v_proj.weight
        W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
        W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_key_value_heads)
        W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_key_value_heads)
        state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
        state_dict[f"blocks.{l}.attn._W_K"] = W_K
        state_dict[f"blocks.{l}.attn._W_V"] = W_V

        state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
        state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros(cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype)
        state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros(cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype)

        W_O = olmoe_layer.self_attn.o_proj.weight
        W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
        state_dict[f"blocks.{l}.attn.W_O"] = W_O

        state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)

        state_dict[f"blocks.{l}.ln2.w"] = olmoe_layer.post_attention_layernorm.weight

        state_dict[f"blocks.{l}.mlp.W_gate.weight"] = olmoe_layer.mlp.gate.weight

        for e in range(cfg.num_experts):
            state_dict[f"blocks.{l}.mlp.experts.{e}.W_in.weight"] = olmoe_layer.mlp.experts[e].up_proj.weight
            state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate.weight"] = olmoe_layer.mlp.experts[e].gate_proj.weight
            state_dict[f"blocks.{l}.mlp.experts.{e}.W_out.weight"] = olmoe_layer.mlp.experts[e].down_proj.weight

    state_dict["ln_final.w"] = olmoe.model.norm.weight

    state_dict["unembed.W_U"] = olmoe.lm_head.weight.T
    state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)

    return state_dict
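A sketch of exercising the new converter directly, outside of get_pretrained_state_dict; it assumes the existing get_pretrained_model_config helper in loading_from_pretrained builds the matching HookedTransformerConfig for this model name:

from transformers import AutoModelForCausalLM

import transformer_lens.loading_from_pretrained as loading
from transformer_lens.pretrained.weight_conversions import convert_olmoe_weights

model_name = "allenai/OLMoE-1B-7B-0924"
hf_model = AutoModelForCausalLM.from_pretrained(model_name)   # HF reference weights
cfg = loading.get_pretrained_model_config(model_name)         # TransformerLens config

state_dict = convert_olmoe_weights(hf_model, cfg)

# Spot-check a few of the keys produced by the loop above.
print(state_dict["embed.W_E"].shape)                           # [d_vocab, d_model]
print(state_dict["blocks.0.attn.W_Q"].shape)                   # [n_heads, d_model, d_head]
print(state_dict["blocks.0.mlp.experts.0.W_in.weight"].shape)  # [d_mlp, d_model]
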
2 changes: 1 addition & 1 deletion transformer_lens/utils.py
@@ -1171,7 +1171,7 @@ def get_tokenizer_with_bos(tokenizer):
         huggingface_token = os.environ.get("HF_TOKEN", None)
         tokenizer_with_bos = AutoTokenizer.from_pretrained(
             pretrained_model_name_or_path,
-            add_bos_token=True,
+            # add_bos_token=True,
             token=huggingface_token,
             **init_kwargs,
         )
