From d6284ca11bf80a935d71d9e4441b9f84f2ffc300 Mon Sep 17 00:00:00 2001 From: JulienDarve Date: Tue, 30 Jul 2024 11:25:20 -0700 Subject: [PATCH 01/29] [WIP] Implements Roberta Model --- src/levanter/models/roberta.py | 852 ++++++++++++++++++++++++++++++ src/levanter/models/testing.ipynb | 563 ++++++++++++++++++++ 2 files changed, 1415 insertions(+) create mode 100644 src/levanter/models/roberta.py create mode 100644 src/levanter/models/testing.ipynb diff --git a/src/levanter/models/roberta.py b/src/levanter/models/roberta.py new file mode 100644 index 000000000..2db3c9bd4 --- /dev/null +++ b/src/levanter/models/roberta.py @@ -0,0 +1,852 @@ +import dataclasses +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Tuple, Type, Union + +import equinox as eqx +import jax +import jax.numpy as jnp +import jax.random as jrandom +from jaxtyping import PRNGKeyArray + +import haliax as hax +import haliax.nn as hnn +from haliax import Axis, AxisSpec, NamedArray +from haliax.jax_utils import maybe_rng_split, named_call, shaped_rng_split +from haliax.nn.scan import Stacked + +from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig +from levanter.compat.torch_serialization import ( + StateDict, + StateDictSerializationMixin, + apply_prefix, + flatten_linear_layers, + stack_state_dict, + unflatten_linear_layers, + unstack_state_dict, +) +from levanter.logging import silence_transformer_nag +from levanter.models.attention import AttentionBackend, AttentionMask, simple_attention_with_dropout +from levanter.models.gpt2 import ACT2FN +from levanter.models.lm_model import LmConfig, LmHeadModel +from levanter.types import BlockFoldable +from levanter.utils.flop_utils import lm_flops_per_token + +silence_transformer_nag() +from transformers import PretrainedConfig as HfConfig +from transformers import RobertaConfig as HfRobertaConfig + + + +@LmConfig.register_subclass("roberta") +@dataclass(frozen=True) +class RobertaConfig(HFCompatConfig): + r""" + + Adapted from HuggingFace RobertaConfig, description below + + + This is the configuration class to store the configuration of a [`RobertaModel`] or a [`TFRobertaModel`]. It is + used to instantiate a RoBERTa model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the RoBERTa + [FacebookAI/roberta-base](https://huggingface.co/FacebookAI/roberta-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50265): + Vocabulary size of the RoBERTa model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. 
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658). + is_decoder (`bool`, *optional*, defaults to `False`): + Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + classifier_dropout (`float`, *optional*): + The dropout ratio for the classification head. 
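+        scan_layers (`bool`, *optional*, defaults to `True`):
+            Levanter-specific: whether to build the encoder layers as a scanned `Stacked` block; if `False`, a
+            `BlockSeq` is used instead.
+        gradient_checkpointing (`bool`, *optional*, defaults to `True`):
+            Levanter-specific: whether to enable gradient checkpointing for the stacked encoder layers.
+        reference_checkpoint (`str`, *optional*, defaults to `"FacebookAI/roberta-base"`):
+            Levanter-specific: HuggingFace checkpoint used by `hf_checkpoint_converter` to load the config and tokenizer.
+        tokenizer (`str`, *optional*):
+            Levanter-specific: tokenizer name or path; falls back to `reference_checkpoint` when `None`.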
+ + Examples: + + ```python + >>> from transformers import RobertaConfig, RobertaModel + + >>> # Initializing a RoBERTa configuration + >>> configuration = RobertaConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = RobertaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + vocab_size: int = 50265 + hidden_size: int = 768 + num_hidden_layers: int = 12 + num_attention_heads: int = 12 + intermediate_size: int = 3072 + hidden_act: str = "gelu" + hidden_dropout_prob: float = 0.1 + attention_probs_dropout_prob: float = 0.1 + max_position_embeddings: int = 512 + type_vocab_size: int = 2 + initializer_range: float = 0.02 + layer_norm_eps: float = 1e-12 + pad_token_id: int = 1 + bos_token_id: int = 0 + eos_token_id: int = 2 + position_embedding_type: Optional[str] = "absolute" + use_cache: bool = False + classifier_dropout: Optional[float] = None + + scan_layers: bool = True + gradient_checkpointing: bool = True + + reference_checkpoint: str = "FacebookAI/roberta-base" + tokenizer: Optional[str] = None + + # Axis + Pos = property(lambda self: Axis(name="position", size=self.max_position_embeddings)) + KeyPos = property(lambda self: self.Pos.alias("key_position")) + Embed = property(lambda self: Axis(name="embed", size=self.hidden_size)) + EmbedAtt = property(lambda self: self.Embed.alias("embed_att")) + FinalEmbed = property(lambda self: self.Embed.alias("final_embed")) + Heads = property(lambda self: Axis(name="heads", size=self.num_attention_heads)) + Layers = property(lambda self: Axis(name="layers", size=self.num_hidden_layers)) + Mlp = property(lambda self: Axis(name="mlp", size=self.intermediate_size)) + HeadSize = property(lambda self: Axis(name="head_size", size=self.hidden_size // self.num_attention_heads)) + + def __post_init__(self): + # TODO + pass + + @classmethod + def from_hf_config(cls, hf_config: HfConfig) -> "RobertaConfig": + return RobertaConfig( + vocab_size = hf_config.vocab_size, + hidden_size = hf_config.hidden_size, + num_hidden_layers = hf_config.num_hidden_layers, + num_attention_heads = hf_config.num_attention_heads, + intermediate_size = hf_config.intermediate_size, + hidden_act = hf_config.hidden_act, + hidden_dropout_prob= hf_config.hidden_dropout_prob, + attention_probs_dropout_prob = hf_config.attention_probs_dropout_prob, + max_position_embeddings = hf_config.max_position_embeddings, + type_vocab_size = hf_config.type_vocab_size, + initializer_range = hf_config.initializer_range, + layer_norm_eps = hf_config.layer_norm_eps, + pad_token_id = hf_config.pad_token_id, + bos_token_id = hf_config.bos_token_id, + eos_token_id = hf_config.eos_token_id, + position_embedding_type = hf_config.position_embedding_type, + use_cache = hf_config.use_cache, + classifier_dropout = hf_config.classifier_dropout, + ) + + def hf_checkpoint_converter(self) -> HFCheckpointConverter["RobertaConfig"]: # type: ignore + return HFCheckpointConverter( + self.__class__, + reference_checkpoint=self.reference_checkpoint, + trust_remote_code=True, + tokenizer=self.tokenizer if self.tokenizer else self.reference_checkpoint, + HfConfigClass=HfRobertaConfig, + ) + + def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None) -> HfRobertaConfig: + """Convert to HuggingFace's LlamaConfig + + Args: + vocab_size (int, optional): Vocabulary size of the tokenizer. Defaults to 32000. + config_overrides (dict, optional): Overrides for the config. Defaults to None. 
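+
+        Example (minimal sketch; uses only the `RobertaConfig` methods defined in this file):
+
+            >>> cfg = RobertaConfig()
+            >>> hf_cfg = cfg.to_hf_config(vocab_size=cfg.vocab_size)
+            >>> cfg_roundtrip = RobertaConfig.from_hf_config(hf_cfg)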
+ + Returns: + HfRobertaConfig: HuggingFace's RobertaConfig + """ + + if config_overrides is None: + config_overrides = {} + + return HfRobertaConfig( + vocab_size = vocab_size, + hidden_size = self.hidden_size, + num_hidden_layers = self.num_hidden_layers, + num_attention_heads = self.num_attention_heads, + intermediate_size = self.intermediate_size, + hidden_act = self.hidden_act, + hidden_dropout_prob = self.hidden_dropout_prob, + attention_probs_dropout_prob = self.attention_probs_dropout_prob, + max_position_embeddings = self.max_position_embeddings, + type_vocab_size = self.type_vocab_size, + initializer_range = self.initializer_range, + layer_norm_eps = self.layer_norm_eps, + pad_token_id = self.pad_token_id, + bos_token_id = self.bos_token_id, + eos_token_id = self.eos_token_id, + position_embedding_type = self.position_embedding_type, + use_cache = self.use_cache, + classifier_dropout = self.classifier_dropout, + ) + + @property + def model_type(self) -> Type["RobertaModel"]: + return RobertaModel + + def flops_per_token(self, vocab_size: int): + return lm_flops_per_token( + hidden_dim=self.hidden_size, + intermediate_dim=self.intermediate_size, + num_layers=self.num_hidden_layers, + num_kv_heads=self.num_attention_heads, + num_heads=self.num_attention_heads, + seq_len=self.max_position_embeddings, + vocab_size=vocab_size, + glu=True, + ) + +class RobertaSelfAttention(eqx.Module, StateDictSerializationMixin): + + config: RobertaConfig + Heads: Axis + HeadSize: Axis + EmbedAtt: Axis + + q_proj: hnn.Linear + k_proj: hnn.Linear + v_proj: hnn.Linear + + dropout: hnn.Dropout + position_embedding_type: Optional[str] + + Pos: Axis + KeyPos: Axis + distance_embedding: Optional[hnn.Embedding] + + @staticmethod + def init(config: RobertaConfig, *, key) -> "RobertaSelfAttention": + Embed = config.Embed + EmbedAtt = config.EmbedAtt + + k_q, k_k, k_v, k_e = jrandom.split(key, 4) + q_proj = hnn.Linear.init(In=Embed, Out=EmbedAtt, key=k_q, out_first=True) + k_proj = hnn.Linear.init(In=Embed, Out=EmbedAtt, key=k_k, out_first=True) + v_proj = hnn.Linear.init(In=Embed, Out=EmbedAtt, key=k_v, out_first=True) + + dropout = hnn.Dropout(config.attention_probs_dropout_prob) + + distance_embedding = None + position_embedding_type = config.position_embedding_type + + if position_embedding_type == "relative_key" or position_embedding_type == "relative_key_query": + RelPos = Axis("rel_pos", 2 * config.max_position_embeddings - 1) + distance_embedding = hnn.Embedding.init(RelPos, config.HeadSize, k_e) + + return RobertaSelfAttention(config, config.Heads, config.HeadSize, EmbedAtt, + q_proj, k_proj, v_proj, + dropout, position_embedding_type, + config.Pos, config.KeyPos, distance_embedding, + ) + + def _state_dict_key_map(self) -> Dict[str, Optional[str]]: + return {"q_proj": "query", "k_proj": "key", "v_proj": "value"} + + def _rope_scale_factor(self) -> float: + # hasattr for gemma and I'm feeling lazy + if hasattr(self.config, "rope_scaling") and self.config.rope_scaling is not None: + assert self.config.rope_scaling["type"] == "linear" + return self.config.rope_scaling["factor"] + return 1.0 + + def transpose_for_scores(self, x: NamedArray) -> NamedArray: + # Makes sure to have the correct output order as well + y = hax.rearrange(x, "... position (embed_att: heads head_size) -> ... 
heads position head_size", heads=self.Heads, head_size=self.HeadSize) + return y + + @named_call + def __call__( + self, + hidden_states: NamedArray, + attention_mask: Optional[NamedArray] = None, + *, + key = None + ) -> Tuple[NamedArray]: + + query_layer = self.transpose_for_scores(self.q_proj(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(hidden_states)) + value_layer = self.transpose_for_scores(self.v_proj(hidden_states)) + + if self.position_embedding_type == "rope": + cos, sin = llama_rotary_pos_emb( + self.config.HeadSize, hidden_states.resolve_axis("position"), scale=self._rope_scale_factor() + ) + query_layer, key_layer = _apply_rotary_pos_emb(query_layer, key_layer, cos, sin) + + key_layer = key_layer.rename({"position": "key_position"}) + value_layer = value_layer.rename({"position": "key_position"}) + + attention_scores = hax.dot(query_layer, key_layer, axis=self.HeadSize) # aka hax.einsum("bhld, bhrd -> bhlr") + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + Left = self.Pos # Queries + Right = self.KeyPos # Keys + + position_ids_l = hax.arange(Left).broadcast_to((Left,Right)) + position_ids_r = hax.arange(Right).broadcast_to((Left,Right)) + + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.Pos.size) + + if self.position_embedding_type == "relative_key": + relative_position_scores = hax.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = hax.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = hax.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores /= jnp.sqrt(self.HeadSize.size) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function) + # Attention_mask should have shape Batch Pos, so it should broadcast to shape Batch Heads Pos KeyPos for summation + attention_scores = attention_scores + attention_mask + + attention_probs = hnn.softmax(attention_scores, axis=self.KeyPos) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs, key=key) + + hax.dot(query_layer, key_layer, axis=self.HeadSize) + + context_layer = hax.dot(attention_probs, value_layer, axis=self.KeyPos) + + outputs = hax.rearrange(context_layer, ("... heads position head_size -> ... 
position (embed_att: heads head_size)"), heads=self.Heads, head_size=self.HeadSize) + + return outputs + +class RobertaSelfOutput(eqx.Module, StateDictSerializationMixin): + dense: hnn.Linear + LayerNorm: hnn.LayerNorm + dropout: hnn.Dropout + + @staticmethod + def init(config: RobertaConfig, *, key) -> "RobertaSelfOutput": + Embed = config.Embed + EmbedAtt = config.EmbedAtt + dense = hnn.Linear.init(In=EmbedAtt, Out=Embed, key=key, out_first=True) + LayerNorm = hnn.LayerNorm.init(axis=Embed, eps=config.layer_norm_eps) + dropout = hnn.Dropout(config.hidden_dropout_prob) + return RobertaSelfOutput(dense, LayerNorm, dropout) + + @named_call + def __call__(self, hidden_states: NamedArray, input: NamedArray,*, key) -> NamedArray: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, key=key) + hidden_states = self.LayerNorm(hidden_states + input) + return hidden_states + +class RobertaAttention(eqx.Module, StateDictSerializationMixin): + self_attn: RobertaSelfAttention + output: RobertaSelfOutput + + @staticmethod + def init(config: RobertaConfig, *, key) -> "RobertaAttention": + k_a, k_o = jrandom.split(key, 2) + + self_attn = RobertaSelfAttention.init(config, key=k_a) + output = RobertaSelfOutput.init(config, key=k_o) + + return RobertaAttention(self_attn, output) + + def _state_dict_key_map(self) -> Dict[str, Optional[str]]: + return {"self_attn": "self"} + + @named_call + def __call__( + self, + hidden_states: NamedArray, + attention_mask: Optional[NamedArray] = None, + *, + key + ) -> Tuple[NamedArray]: + k_a, k_o = maybe_rng_split(key, 2) + + self_outputs = self.self_attn( + hidden_states, + attention_mask, + key=k_a + ) + attention_output = self.output(self_outputs, hidden_states, key=k_o) + return attention_output + +class RobertaIntermediate(eqx.Module, StateDictSerializationMixin): + dense: hnn.Linear + intermediate_act_fn: Callable = eqx.static_field() + + @staticmethod + def init(config, *, key) -> "RobertaIntermediate": + dense = hnn.Linear.init(config.Embed, config.Mlp, key=key, out_first=True) + if isinstance(config.hidden_act, str): + intermediate_act_fn = ACT2FN[config.hidden_act] + else: + intermediate_act_fn = config.hidden_act + + return RobertaIntermediate(dense, intermediate_act_fn) + + @named_call + def __call__(self, hidden_states: NamedArray, *, key = None) -> NamedArray: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + +class RobertaOutput(eqx.Module, StateDictSerializationMixin): + dense: hnn.Linear + LayerNorm: hnn.LayerNorm + dropout: hnn.Dropout + + @staticmethod + def init(config: RobertaConfig, *, key) -> "RobertaSelfOutput": + Embed = config.Embed + dense = hnn.Linear.init(In=config.Mlp, Out=Embed, key=key, out_first=True) + LayerNorm = hnn.LayerNorm.init(axis=Embed, eps=config.layer_norm_eps) + dropout = hnn.Dropout(config.hidden_dropout_prob) + return RobertaSelfOutput(dense, LayerNorm, dropout) + + @named_call + def __call__(self, hidden_states: NamedArray, input: NamedArray, *, key) -> NamedArray: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, key=key) + hidden_states = self.LayerNorm(hidden_states + input) + return hidden_states + +class RobertaLayer(eqx.Module, StateDictSerializationMixin): + attention: RobertaAttention + intermediate: RobertaIntermediate + output: RobertaOutput + + @staticmethod + def init(config: RobertaConfig, *, key) -> "RobertaLayer": + k_a, k_i, k_o = jrandom.split(key, 3) + + 
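+        # one RNG sub-key per sub-module: k_a for the attention block, k_i for the intermediate (MLP) projection, k_o for the output block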
attention = RobertaAttention.init(config, key=k_a) + intermediate = RobertaIntermediate.init(config, key=k_i) + output = RobertaOutput.init(config, key=k_o) + + return RobertaLayer(attention, intermediate, output) + + @named_call + def __call__( + self, + hidden_states: NamedArray, + attention_mask: Optional[NamedArray] = None, + *, + key + ) -> Tuple[NamedArray]: + k_a, k_o = maybe_rng_split(key, 2) + + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + key=k_a, + ) + + attention_output = self_attention_outputs + + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output, key=k_o) + + return layer_output + + +class RobertaEncoder(eqx.Module, StateDictSerializationMixin): + config: RobertaConfig + layer: BlockFoldable[RobertaLayer] + + @staticmethod + def init(config: RobertaConfig, *, key) -> "RobertaEncoder": + S = Stacked + if not config.scan_layers: + from haliax.nn.scan import BlockSeq + + S = BlockSeq + + layer = S.init(config.Layers, RobertaLayer, gradient_checkpointing=config.gradient_checkpointing)( + config, + key=shaped_rng_split(key, config.num_hidden_layers), #TODO: config.gradient_checkpointing + ) + + return RobertaEncoder(config, layer) + + @named_call + def __call__( + self, + hidden_states: NamedArray, + attention_mask: Optional[NamedArray] = None, + *, + key + ) -> Tuple[NamedArray]: + + keys = maybe_rng_split(key, self.config.num_hidden_layers) if key is not None else None + x = self.layer.fold(hidden_states, attention_mask, key=keys) + + return x + + def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): + if isinstance(self.layer, Stacked): + state_dict = stack_state_dict(state_dict, prefix=apply_prefix(prefix, "layer")) + + out = super().from_state_dict(state_dict, prefix=prefix) + return out + + def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: + my_state_dict: StateDict = {} + super().update_state_dict(my_state_dict, prefix=prefix) + + if isinstance(self.layer, Stacked): + stacked_dict = unstack_state_dict(my_state_dict, prefix=apply_prefix(prefix, "layer")) + state_dict.update(stacked_dict) + else: + state_dict.update(my_state_dict) + + return state_dict + +class RobertaEmbedding(eqx.Module, StateDictSerializationMixin): + Vocab: Axis = eqx.static_field() + Pos: Axis = eqx.static_field() + + word_embeddings: hnn.Embedding + position_embeddings: hnn.Embedding + token_type_embeddings: Optional[hnn.Embedding] + padding_idx: NamedArray + + LayerNorm: hnn.LayerNorm + dropout: hnn.Dropout + position_embedding_type: Optional[str] + + @staticmethod + def init(Vocab: Axis, config: RobertaConfig, *, key) -> "RobertaEmbedding": + key_w, key_p, key_t = jrandom.split(key, 3) + + padding_idx = config.pad_token_id + + word_embeddings = hnn.Embedding.init(Vocab, config.Embed, key=key_w) # padding_idx not specified + position_embeddings = hnn.Embedding.init(config.Pos, config.Embed, key=key_p) + + Token = hax.Axis("token", config.type_vocab_size) + + token_type_embeddings = hnn.Embedding.init(Token, config.Embed, key=key_t) + + LayerNorm = hnn.LayerNorm.init(config.Embed, config.layer_norm_eps) + dropout = hnn.Dropout(config.hidden_dropout_prob) + + return RobertaEmbedding(Vocab, config.Pos, word_embeddings, position_embeddings, token_type_embeddings, padding_idx, LayerNorm, dropout, config.position_embedding_type) + + def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0): + mask = 
hax.not_equal(input_ids, self.padding_idx) * 1 + incremental_indices = (hax.cumsum(mask, axis=self.Pos).astype(mask) + past_key_values_length) * mask + return incremental_indices + self.padding_idx + + def create_position_ids_from_inputs_embeds(self, input_axes, PosInput): + position_ids = hax.arange(axis = PosInput, start = self.padding_idx + 1, dtype=jnp.int32) + return hax.broadcast_to(position_ids, input_axes) + + @named_call + def embed(self, input_ids=None, token_type_ids=None, position_ids=None, input_embeds=None, past_key_values_length=0, *, key = None): + """ + Note: When inputting your own embeds into input_embeds, make sure that the embeds axis has the name "embed" + for compatibility with the position_id creation function. Make sures its length is not equal to + """ + + # Get Axes + if input_ids is not None: + input_axes = input_ids.axes + else: + input_axes = hax.eliminate_axes(input_embeds.axes, "embed") + + # Get position_ids + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = self.create_position_ids_from_input_ids(input_ids, past_key_values_length) + else: + position_ids = self.create_position_ids_from_inputs_embeds(input_axes, input_embeds.resolve_axis("position")) + + # Get token_type_ids + if token_type_ids is None: + token_type_ids = hax.zeros(input_axes, dtype=jnp.int32) + + if input_embeds is None: + input_embeds = self.word_embeddings(input_ids) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = input_embeds + token_type_embeddings + + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings, key=key) + return embeddings + +class RobertaPooler(eqx.Module, StateDictSerializationMixin): + dense: hnn.Linear + config: RobertaConfig + + @staticmethod + def init(config: RobertaConfig, *, key): + dense = hnn.Linear.init(In=config.Embed, Out=config.FinalEmbed, key=key, out_first=True) + + return RobertaPooler(dense, config) + + @named_call + def __call__(self, hidden_states: NamedArray, *, key=None) -> NamedArray: + first_token = hidden_states[{"position" : 0}] + x = self.dense(first_token, key=key).rename({self.config.Embed: self.config.FinalEmbed}) + x = hax.tanh(x) + return x + + +class RobertaModel(eqx.Module, StateDictSerializationMixin): + encoder: RobertaEncoder + embeddings: RobertaEmbedding + pooler : Optional[RobertaPooler] + + @staticmethod + def init(Vocab: Axis, config: RobertaConfig, add_pooling_layer: bool = True, *, key) -> "RobertaModel": + k_t, k_emb, k_p = jrandom.split(key, 3) + encoder = RobertaEncoder.init(config=config, key=k_t) + embeddings = RobertaEmbedding.init(Vocab, config, key=k_emb) + + pooler = RobertaPooler.init(config, key=k_p) if add_pooling_layer else None + return RobertaModel(encoder, embeddings, pooler) + + @property + def config(self): + return self.encoder.config + + @property + def vocab_size(self) -> int: + return self.Vocab.size + + @property + def Vocab(self) -> Axis: + return self.embeddings.Vocab + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + @named_call + def __call__( + self, + input_ids: Optional[NamedArray] = None, + token_type_ids: Optional[NamedArray] = None, + position_ids: Optional[NamedArray] = 
None, + input_embeds: Optional[NamedArray] = None, + attention_mask: Optional[NamedArray] = None, + *, + key, + ) -> Tuple[NamedArray]: + """ + Not Used: meant to be used to improve performance in decoder implementations + + head_mask: Optional[NamedArray] = None, + encoder_hidden_states: Optional[NamedArray] = None, + encoder_attention_mask: Optional[NamedArray] = None, + past_key_values_length = 0, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + """ + k_emb, k_e, k_p = maybe_rng_split(key, 3) + + if input_ids is not None and input_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_axes = input_ids.axes + elif input_embeds is not None: + input_axes = hax.eliminate_axes(input_embeds.axes, "embed") + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if attention_mask is None: + attention_mask = hax.ones(input_axes) + + # Attention mask from mask to real numbers + attention_mask = (attention_mask == 0) * -1e9 + + embedding_output = self.embeddings.embed(input_ids, token_type_ids, position_ids, input_embeds, key=k_emb) + sequence_output = self.encoder(embedding_output, attention_mask=attention_mask, key=k_e) + + pooled_output = self.pooler(sequence_output, key=k_p) if self.pooler is not None else None + + return (sequence_output, pooled_output) + +class RobertaLMHead(eqx.Module, StateDictSerializationMixin): + """Roberta Head for masked language modeling.""" + + dense: hnn.Linear + layer_norm: hnn.LayerNorm + decoder: hnn.Linear + + @staticmethod + def init(Vocab: Axis, config: RobertaConfig, *, key): + k_dense, k_decoder = jrandom.split(key, 2) + Embed = config.Embed + + dense = hnn.Linear.init(Embed, config.FinalEmbed, key=k_dense, out_first=True) + layer_norm = hnn.LayerNorm.init(axis=Embed, eps=config.layer_norm_eps) + + decoder = hnn.Linear.init(Embed, Vocab, key=k_decoder, out_first=True) + + # idk what this is: TODO + # self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + # self.decoder.bias = self.bias + + return RobertaLMHead(dense, layer_norm, decoder) + + @named_call + def __call__(self, features: NamedArray, *, key=None) -> NamedArray: + x = self.dense(features).rename({self.config.Embed: self.config.FinalEmbed}) + x = hnn.gelu(x) + x = self.layer_norm(x) + + # project back to size of vocabulary with bias + x = self.decoder(x) + + return x + +class RobertaForMaskedLM(eqx.Module, StateDictSerializationMixin): + roberta: RobertaModel + lm_head: RobertaLMHead + Vocab: Axis + + @classmethod + def init(self, Vocab: Axis, config: RobertaConfig, *, key): + + # if config.is_decoder: + # raise AttributeError("Model is being run as a MaskedLM aka an encoder model, but is_decoder is true") + + k_rob, key_head = jrandom.split(key, 2) + roberta = RobertaModel.init(Vocab, config, add_pooling_layer=False, key=k_rob) + lm_head = RobertaLMHead.init(Vocab, config, key=key_head) + + return RobertaForMaskedLM(roberta, lm_head, Vocab) + + def get_output_embeddings(self): + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + + def forward( + self, + input_ids: Optional[NamedArray] = None, + attention_mask: Optional[NamedArray] = None, + token_type_ids: Optional[NamedArray] = None, + position_ids: Optional[NamedArray] = None, + inputs_embeds: Optional[NamedArray] = None, + 
labels: Optional[NamedArray] = None, + ) -> Tuple[NamedArray]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + kwargs (`Dict[str, any]`, optional, defaults to *{}*): + Used to hide legacy arguments that have been deprecated. + """ + + outputs = self.roberta( + input_ids, + attn_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds + ) + + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + masked_lm_loss = None + if labels is not None: + masked_lm_loss = hnn.cross_entropy_loss(logits=prediction_scores, Label=self.Vocab, targets=labels) + + output = (prediction_scores,) + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + +def _rotate_half(x: NamedArray) -> NamedArray: + """Rotates half of the hidden dims of the input and concatenates them.""" + HeadSize = x.axes[-1] + x1 = x[HeadSize, : HeadSize.size // 2] + x2 = x[HeadSize, HeadSize.size // 2 :] + out = hax.concatenate(HeadSize, (-x2, x1)) + return out + + +def _apply_rotary_pos_emb( + q: NamedArray, # [batch, position, kv_heads, q_heads_per_group, head_size] + k: NamedArray, # [batch, position, kv_heads, head_size] + cos: NamedArray, # [position, head_size] + sin: NamedArray, # [position, head_size] +) -> Tuple[NamedArray, NamedArray]: + """Applies rotary position embedding to q and k.""" + q_embed = q * cos + _rotate_half(q) * sin + k_embed = k * cos + _rotate_half(k) * sin + return q_embed, k_embed + + +def llama_rotary_pos_emb( + HeadSize: Axis, Pos: Axis, base: float = 10000, scale: float = 1.0 +) -> Tuple[NamedArray, NamedArray]: + with jax.ensure_compile_time_eval(): + HeadHalfSize = HeadSize.resize(HeadSize.size // 2) + inv_freq: NamedArray = 1.0 / (base ** (hax.arange(HeadHalfSize, step=2) / HeadSize.size)) + + position_ids: NamedArray = hax.arange(Pos) / scale + + freqs = position_ids * inv_freq.broadcast_axis(Pos) + # This is different from the paper but aligns with HF implementation: + # It uses a different permutation in order to obtain the same calculation + emb = hax.concatenate(HeadSize, (freqs, freqs)) + cos = hax.cos(emb) + sin = hax.sin(emb) + # This is different from the paper but aligns with HF implementation: + return cos, sin diff --git a/src/levanter/models/testing.ipynb b/src/levanter/models/testing.ipynb new file mode 100644 index 000000000..d8e0b4531 --- /dev/null +++ b/src/levanter/models/testing.ipynb @@ -0,0 +1,563 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "2024-07-30 10:42:02,733\tINFO util.py:154 -- Missing packages: ['ipywidgets']. 
Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n" + ] + } + ], + "source": [ + "import roberta as my_roberta\n", + "from transformers.models.roberta import modeling_roberta as hf_roberta\n", + "\n", + "import torch\n", + "import haliax as hax\n", + "import jax.random as jrandom\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoConfig\n", + "\n", + "hf_model_str = \"FacebookAI/roberta-base\"\n", + "\n", + "hf_config = AutoConfig.from_pretrained(hf_model_str)\n", + "hf_config.hidden_dropout_prob = 0\n", + "hf_config.attention_probs_dropout_prob = 0\n", + "my_config = my_roberta.RobertaConfig.from_hf_config(hf_config)\n", + "\n", + "key = jrandom.PRNGKey(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "EmbedAtt = my_config.EmbedAtt\n", + "Embed = my_config.Embed\n", + "Mlp = my_config.Mlp\n", + "Pos = my_config.Pos\n", + "KeyPos = my_config.KeyPos\n", + "Batch = hax.Axis(\"batch\", 2)\n", + "\n", + "x_embed = hax.random.normal(key, (Batch, Pos, Embed))\n", + "x_embed_att = hax.random.normal(key, (Batch, Pos, EmbedAtt))\n", + "x_mlp = hax.random.normal(key, (Batch, Pos, Mlp))\n", + "# x = x[{\"position\": slice(Pos.size-2)}]\n", + "\n", + "x_embed_torch = torch.from_numpy(np.array(x_embed.array))\n", + "x_embed_att_torch = torch.from_numpy(np.array(x_embed_att.array))\n", + "x_mlp_torch = torch.from_numpy(np.array(x_mlp.array))\n", + "\n", + "mask = hax.ones((Batch, Pos))[{\"position\": slice(0,-2)}]\n", + "mask_torch = torch.from_numpy(np.array(mask.array))\n", + "# mask_torch = torch.ones((2, hf_config.num_attention_heads, hf_config.max_position_embeddings -2, hf_config.max_position_embeddings -2))\n", + "\n", + "Vocab = hax.Axis(\"vocab\", my_config.vocab_size)\n", + "\n", + "input_embeds = hax.random.uniform(key, (Batch, Pos, Embed))\n", + "input_embeds = input_embeds[{\"position\": slice(0,-2)}]\n", + "\n", + "input_ids = hax.random.randint(key, (Batch, Pos), minval = 0, maxval = my_config.vocab_size)\n", + "input_ids = input_ids[{\"position\": slice(0,-2)}]\n", + "\n", + "input_embeds_torch = torch.from_numpy(np.array(input_embeds.array))\n", + "input_ids_torch = torch.from_numpy(np.array(input_ids.array))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# from transformers import AutoTokenizer\n", + "\n", + "# tokenizer = AutoTokenizer.from_pretrained(hf_model_str)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "def check(my_output, hf_output, p=False, pp=False, ppp=True, precision=1e-4):\n", + " \n", + " print(my_output.shape)\n", + " print(hf_output.shape)\n", + "\n", + " success = np.isclose(hf_output, np.array(my_output), rtol=precision, atol=precision).all()\n", + "\n", + " if success:\n", + " print(\"Success!!!\")\n", + " else:\n", + " print(\"Fail :((((\")\n", + " \n", + " if ppp:\n", + " acc_prev = None\n", + " for i in range(15):\n", + " prec = 10 ** (-1*i)\n", + " acc = np.isclose(hf_output, np.array(my_output), rtol=precision, atol=prec).mean()\n", + " if acc != acc_prev:\n", + " print(f\"Iteration {i}, Precision {prec}:\\t{acc}\")\n", + " acc_prev = acc\n", + "\n", + " if p: \n", + " acc = np.isclose(hf_output, np.array(my_output), rtol=precision, atol=precision).mean()\n", + " print(f\"Accuracy: {acc}\")\n", + " 
print(f\"Jax:\\n{torch.tensor(np.array(my_output))}\\nTorch:\\n{hf_output}\")\n", + "\n", + " if pp:\n", + " diff = torch.tensor(np.array(my_output)) - hf_output\n", + " print(f\"Mean: {diff.abs().mean()}\")\n", + " print(f\"Stdev: {diff.std()}\")\n", + " print(f\"Difference:\\n{diff}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'# Testing RobertaSelfOutput\\n\\nmy_self_output = my_roberta.RobertaSelfOutput.init(my_config, key=key)\\nstate = my_self_output.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nprint(state.keys())\\n\\nhf_self_output = hf_roberta.RobertaSelfOutput(hf_config)\\nhf_self_output.load_state_dict(state, strict=True)\\n\\nmy_output = my_self_output(x_embed_att, x_embed, key=key)\\nhf_output = hf_self_output(x_embed_att_torch, x_embed_torch)\\n\\ncheck(my_output.array, hf_output.detach(), ppp=True)'" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "'''# Testing RobertaSelfOutput\n", + "\n", + "my_self_output = my_roberta.RobertaSelfOutput.init(my_config, key=key)\n", + "state = my_self_output.to_state_dict()\n", + "\n", + "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + "print(state.keys())\n", + "\n", + "hf_self_output = hf_roberta.RobertaSelfOutput(hf_config)\n", + "hf_self_output.load_state_dict(state, strict=True)\n", + "\n", + "my_output = my_self_output(x_embed_att, x_embed, key=key)\n", + "hf_output = hf_self_output(x_embed_att_torch, x_embed_torch)\n", + "\n", + "check(my_output.array, hf_output.detach(), ppp=True)'''" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'# Testing RobertaSelfAttention\\n\\nmy_attn_output = my_roberta.RobertaSelfAttention.init(my_config, key=key)\\nstate = my_attn_output.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nprint(state.keys())\\n\\nhf_attn_output = hf_roberta.RobertaSelfAttention(hf_config)\\nhf_attn_output.load_state_dict(state, strict=True)\\n\\nmy_output = my_attn_output(x_embed, mask, key=key)\\nhf_output = hf_attn_output(x_embed_torch, mask_torch)\\n\\ncheck(my_output.array, hf_output[0].detach(), ppp=True)'" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "'''# Testing RobertaSelfAttention\n", + "\n", + "my_attn_output = my_roberta.RobertaSelfAttention.init(my_config, key=key)\n", + "state = my_attn_output.to_state_dict()\n", + "\n", + "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + "print(state.keys())\n", + "\n", + "hf_attn_output = hf_roberta.RobertaSelfAttention(hf_config)\n", + "hf_attn_output.load_state_dict(state, strict=True)\n", + "\n", + "my_output = my_attn_output(x_embed, mask, key=key)\n", + "hf_output = hf_attn_output(x_embed_torch, mask_torch)\n", + "\n", + "check(my_output.array, hf_output[0].detach(), ppp=True)'''" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'# Testing RobertaAttention\\n\\nmy_func = my_roberta.RobertaAttention.init(my_config, key=key)\\nstate = my_func.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nprint(state.keys())\\n\\nhf_func = 
hf_roberta.RobertaAttention(hf_config)\\nhf_func.load_state_dict(state, strict=True)\\n\\nmy_output = my_func(hidden_states=x_embed, attention_mask=mask, key=key)\\nhf_output = hf_func(hidden_states=x_embed_torch, attention_mask=mask_torch)\\n\\ncheck(my_output.array, hf_output[0].detach(), ppp=True)'" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "'''# Testing RobertaAttention\n", + "\n", + "my_func = my_roberta.RobertaAttention.init(my_config, key=key)\n", + "state = my_func.to_state_dict()\n", + "\n", + "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + "print(state.keys())\n", + "\n", + "hf_func = hf_roberta.RobertaAttention(hf_config)\n", + "hf_func.load_state_dict(state, strict=True)\n", + "\n", + "my_output = my_func(hidden_states=x_embed, attention_mask=mask, key=key)\n", + "hf_output = hf_func(hidden_states=x_embed_torch, attention_mask=mask_torch)\n", + "\n", + "check(my_output.array, hf_output[0].detach(), ppp=True)'''" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'# Testing RobertaIntermediate\\n\\nmy_int_output = my_roberta.RobertaIntermediate.init(my_config, key=key)\\nstate = my_int_output.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nprint(state.keys())\\n\\nhf_int_output = hf_roberta.RobertaIntermediate(hf_config)\\nhf_int_output.load_state_dict(state, strict=True)\\n\\nmy_output = my_int_output(x_embed, key=key)\\nhf_output = hf_int_output(x_embed_torch)\\n\\ncheck(my_output.array, hf_output.detach(), ppp=True)'" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "'''# Testing RobertaIntermediate\n", + "\n", + "my_int_output = my_roberta.RobertaIntermediate.init(my_config, key=key)\n", + "state = my_int_output.to_state_dict()\n", + "\n", + "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + "print(state.keys())\n", + "\n", + "hf_int_output = hf_roberta.RobertaIntermediate(hf_config)\n", + "hf_int_output.load_state_dict(state, strict=True)\n", + "\n", + "my_output = my_int_output(x_embed, key=key)\n", + "hf_output = hf_int_output(x_embed_torch)\n", + "\n", + "check(my_output.array, hf_output.detach(), ppp=True)'''" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'# Testing RobertaOutput\\n\\nmy_int_output = my_roberta.RobertaOutput.init(my_config, key=key)\\nstate = my_int_output.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nprint(state.keys())\\n\\nhf_int_output = hf_roberta.RobertaOutput(hf_config)\\nhf_int_output.load_state_dict(state, strict=True)\\n\\nmy_output = my_int_output(x_mlp, x_embed, key=key)\\nhf_output = hf_int_output(x_mlp_torch, x_embed_torch)\\n\\ncheck(my_output.array, hf_output.detach(), ppp=True)'" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "'''# Testing RobertaOutput\n", + "\n", + "my_int_output = my_roberta.RobertaOutput.init(my_config, key=key)\n", + "state = my_int_output.to_state_dict()\n", + "\n", + "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + "print(state.keys())\n", + "\n", + "hf_int_output = hf_roberta.RobertaOutput(hf_config)\n", + "hf_int_output.load_state_dict(state, strict=True)\n", + 
"\n", + "my_output = my_int_output(x_mlp, x_embed, key=key)\n", + "hf_output = hf_int_output(x_mlp_torch, x_embed_torch)\n", + "\n", + "check(my_output.array, hf_output.detach(), ppp=True)'''" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'# Testing RobertaLayer\\n\\nmy_func = my_roberta.RobertaLayer.init(my_config, key=key)\\nstate = my_func.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nprint(state.keys())\\n\\nhf_func = hf_roberta.RobertaLayer(hf_config)\\nhf_func.load_state_dict(state, strict=True)\\n\\nmy_output = my_func(hidden_states=x_embed, attention_mask=mask, key=key)\\nhf_output = hf_func(hidden_states=x_embed_torch, attention_mask=mask_torch)\\n\\ncheck(my_output.array, hf_output[0].detach(), False, False, True, 1e-4)'" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "'''# Testing RobertaLayer\n", + "\n", + "my_func = my_roberta.RobertaLayer.init(my_config, key=key)\n", + "state = my_func.to_state_dict()\n", + "\n", + "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + "print(state.keys())\n", + "\n", + "hf_func = hf_roberta.RobertaLayer(hf_config)\n", + "hf_func.load_state_dict(state, strict=True)\n", + "\n", + "my_output = my_func(hidden_states=x_embed, attention_mask=mask, key=key)\n", + "hf_output = hf_func(hidden_states=x_embed_torch, attention_mask=mask_torch)\n", + "\n", + "check(my_output.array, hf_output[0].detach(), False, False, True, 1e-4)'''" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'# Testing RobertaEncoder\\n\\nmy_func = my_roberta.RobertaEncoder.init(my_config, key=key)\\nstate = my_func.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nprint(state.keys())\\n\\nhf_func = hf_roberta.RobertaEncoder(hf_config)\\nhf_func.load_state_dict(state, strict=True)\\n\\nmy_output = my_func(hidden_states=x_embed, attention_mask=mask, key=key)\\nhf_output = hf_func(hidden_states=x_embed_torch, attention_mask=mask_torch)\\n\\ncheck(my_output.array, hf_output[0].detach(), False, False, True, 1e-4)'" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "'''# Testing RobertaEncoder\n", + "\n", + "my_func = my_roberta.RobertaEncoder.init(my_config, key=key)\n", + "state = my_func.to_state_dict()\n", + "\n", + "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + "print(state.keys())\n", + "\n", + "hf_func = hf_roberta.RobertaEncoder(hf_config)\n", + "hf_func.load_state_dict(state, strict=True)\n", + "\n", + "my_output = my_func(hidden_states=x_embed, attention_mask=mask, key=key)\n", + "hf_output = hf_func(hidden_states=x_embed_torch, attention_mask=mask_torch)\n", + "\n", + "check(my_output.array, hf_output[0].detach(), False, False, True, 1e-4)'''" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'# Testing RobertaEmbedding\\n\\nVocab = hax.Axis(\"vocab\", my_config.vocab_size)\\n\\ninput_embeds = hax.random.uniform(key, (Batch, Pos, Embed))\\ninput_embeds = input_embeds[{\"position\": slice(0,-2)}]\\ninput_ids = hax.random.randint(key, (Batch, Pos), minval = 0, maxval = my_config.vocab_size)\\ninput_ids = input_ids[{\"position\": 
slice(0,-2)}]\\n\\ninput_embeds_torch = torch.from_numpy(np.array(input_embeds.array))\\ninput_ids_torch = torch.from_numpy(np.array(input_ids.array))\\n\\nmy_func = my_roberta.RobertaEmbedding.init(Vocab, my_config, key=key)\\nstate = my_func.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nprint(state.keys())\\n\\nhf_func = hf_roberta.RobertaEmbeddings(hf_config)\\nhf_func.load_state_dict(state, strict=True)\\n\\n# my_output = my_func.embed(input_embeds=input_embeds, key=key)\\n# hf_output = hf_func(inputs_embeds=input_embeds_torch)\\n\\nmy_output = my_func.embed(input_ids=input_ids, key=key)\\nhf_output = hf_func(input_ids=input_ids_torch)\\n\\ncheck(my_output.array, hf_output.detach())'" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "'''# Testing RobertaEmbedding\n", + "\n", + "Vocab = hax.Axis(\"vocab\", my_config.vocab_size)\n", + "\n", + "input_embeds = hax.random.uniform(key, (Batch, Pos, Embed))\n", + "input_embeds = input_embeds[{\"position\": slice(0,-2)}]\n", + "input_ids = hax.random.randint(key, (Batch, Pos), minval = 0, maxval = my_config.vocab_size)\n", + "input_ids = input_ids[{\"position\": slice(0,-2)}]\n", + "\n", + "input_embeds_torch = torch.from_numpy(np.array(input_embeds.array))\n", + "input_ids_torch = torch.from_numpy(np.array(input_ids.array))\n", + "\n", + "my_func = my_roberta.RobertaEmbedding.init(Vocab, my_config, key=key)\n", + "state = my_func.to_state_dict()\n", + "\n", + "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + "print(state.keys())\n", + "\n", + "hf_func = hf_roberta.RobertaEmbeddings(hf_config)\n", + "hf_func.load_state_dict(state, strict=True)\n", + "\n", + "# my_output = my_func.embed(input_embeds=input_embeds, key=key)\n", + "# hf_output = hf_func(inputs_embeds=input_embeds_torch)\n", + "\n", + "my_output = my_func.embed(input_ids=input_ids, key=key)\n", + "hf_output = hf_func(input_ids=input_ids_torch)\n", + "\n", + "check(my_output.array, hf_output.detach())'''" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'# Testing RobertaPooler\\n\\nmy_pool = my_roberta.RobertaPooler.init(my_config, key=key)\\nstate = my_pool.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nprint(state.keys())\\n\\nhf_pool = hf_roberta.RobertaPooler(hf_config)\\nhf_pool.load_state_dict(state, strict=True)\\n\\nmy_output = my_pool(x_embed, key=key)\\nhf_output = hf_pool(x_embed_torch)\\n\\ncheck(my_output.array, hf_output.detach(), ppp=True)'" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "'''# Testing RobertaPooler\n", + "\n", + "my_pool = my_roberta.RobertaPooler.init(my_config, key=key)\n", + "state = my_pool.to_state_dict()\n", + "\n", + "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + "print(state.keys())\n", + "\n", + "hf_pool = hf_roberta.RobertaPooler(hf_config)\n", + "hf_pool.load_state_dict(state, strict=True)\n", + "\n", + "my_output = my_pool(x_embed, key=key)\n", + "hf_output = hf_pool(x_embed_torch)\n", + "\n", + "check(my_output.array, hf_output.detach(), ppp=True)'''" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'# Testing RobertaModel\\n\\nmy_pool = my_roberta.RobertaModel.init(Vocab, my_config, 
add_pooling_layer=True, key=key)\\nstate = my_pool.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nprint(state.keys())\\n\\nhf_pool = hf_roberta.RobertaModel(hf_config, add_pooling_layer=True)\\nhf_pool.load_state_dict(state, strict=True)\\n\\nmy_output = my_pool(input_ids = input_ids, attention_mask=mask, key=key)\\nhf_output = hf_pool(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\\n\\ncheck(my_output[0].array, hf_output[0].detach(), ppp=True)'" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "'''# Testing RobertaModel\n", + "\n", + "my_pool = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=True, key=key)\n", + "state = my_pool.to_state_dict()\n", + "\n", + "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + "print(state.keys())\n", + "\n", + "hf_pool = hf_roberta.RobertaModel(hf_config, add_pooling_layer=True)\n", + "hf_pool.load_state_dict(state, strict=True)\n", + "\n", + "my_output = my_pool(input_ids = input_ids, attention_mask=mask, key=key)\n", + "hf_output = hf_pool(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", + "\n", + "check(my_output[0].array, hf_output[0].detach(), ppp=True)'''" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dict_keys(['roberta.encoder.layer.0.attention.self.query.weight', 'roberta.encoder.layer.1.attention.self.query.weight', 'roberta.encoder.layer.2.attention.self.query.weight', 'roberta.encoder.layer.3.attention.self.query.weight', 'roberta.encoder.layer.4.attention.self.query.weight', 'roberta.encoder.layer.5.attention.self.query.weight', 'roberta.encoder.layer.6.attention.self.query.weight', 'roberta.encoder.layer.7.attention.self.query.weight', 'roberta.encoder.layer.8.attention.self.query.weight', 'roberta.encoder.layer.9.attention.self.query.weight', 'roberta.encoder.layer.10.attention.self.query.weight', 'roberta.encoder.layer.11.attention.self.query.weight', 'roberta.encoder.layer.0.attention.self.query.bias', 'roberta.encoder.layer.1.attention.self.query.bias', 'roberta.encoder.layer.2.attention.self.query.bias', 'roberta.encoder.layer.3.attention.self.query.bias', 'roberta.encoder.layer.4.attention.self.query.bias', 'roberta.encoder.layer.5.attention.self.query.bias', 'roberta.encoder.layer.6.attention.self.query.bias', 'roberta.encoder.layer.7.attention.self.query.bias', 'roberta.encoder.layer.8.attention.self.query.bias', 'roberta.encoder.layer.9.attention.self.query.bias', 'roberta.encoder.layer.10.attention.self.query.bias', 'roberta.encoder.layer.11.attention.self.query.bias', 'roberta.encoder.layer.0.attention.self.key.weight', 'roberta.encoder.layer.1.attention.self.key.weight', 'roberta.encoder.layer.2.attention.self.key.weight', 'roberta.encoder.layer.3.attention.self.key.weight', 'roberta.encoder.layer.4.attention.self.key.weight', 'roberta.encoder.layer.5.attention.self.key.weight', 'roberta.encoder.layer.6.attention.self.key.weight', 'roberta.encoder.layer.7.attention.self.key.weight', 'roberta.encoder.layer.8.attention.self.key.weight', 'roberta.encoder.layer.9.attention.self.key.weight', 'roberta.encoder.layer.10.attention.self.key.weight', 'roberta.encoder.layer.11.attention.self.key.weight', 'roberta.encoder.layer.0.attention.self.key.bias', 'roberta.encoder.layer.1.attention.self.key.bias', 
'roberta.encoder.layer.2.attention.self.key.bias', 'roberta.encoder.layer.3.attention.self.key.bias', 'roberta.encoder.layer.4.attention.self.key.bias', 'roberta.encoder.layer.5.attention.self.key.bias', 'roberta.encoder.layer.6.attention.self.key.bias', 'roberta.encoder.layer.7.attention.self.key.bias', 'roberta.encoder.layer.8.attention.self.key.bias', 'roberta.encoder.layer.9.attention.self.key.bias', 'roberta.encoder.layer.10.attention.self.key.bias', 'roberta.encoder.layer.11.attention.self.key.bias', 'roberta.encoder.layer.0.attention.self.value.weight', 'roberta.encoder.layer.1.attention.self.value.weight', 'roberta.encoder.layer.2.attention.self.value.weight', 'roberta.encoder.layer.3.attention.self.value.weight', 'roberta.encoder.layer.4.attention.self.value.weight', 'roberta.encoder.layer.5.attention.self.value.weight', 'roberta.encoder.layer.6.attention.self.value.weight', 'roberta.encoder.layer.7.attention.self.value.weight', 'roberta.encoder.layer.8.attention.self.value.weight', 'roberta.encoder.layer.9.attention.self.value.weight', 'roberta.encoder.layer.10.attention.self.value.weight', 'roberta.encoder.layer.11.attention.self.value.weight', 'roberta.encoder.layer.0.attention.self.value.bias', 'roberta.encoder.layer.1.attention.self.value.bias', 'roberta.encoder.layer.2.attention.self.value.bias', 'roberta.encoder.layer.3.attention.self.value.bias', 'roberta.encoder.layer.4.attention.self.value.bias', 'roberta.encoder.layer.5.attention.self.value.bias', 'roberta.encoder.layer.6.attention.self.value.bias', 'roberta.encoder.layer.7.attention.self.value.bias', 'roberta.encoder.layer.8.attention.self.value.bias', 'roberta.encoder.layer.9.attention.self.value.bias', 'roberta.encoder.layer.10.attention.self.value.bias', 'roberta.encoder.layer.11.attention.self.value.bias', 'roberta.encoder.layer.0.attention.output.dense.weight', 'roberta.encoder.layer.1.attention.output.dense.weight', 'roberta.encoder.layer.2.attention.output.dense.weight', 'roberta.encoder.layer.3.attention.output.dense.weight', 'roberta.encoder.layer.4.attention.output.dense.weight', 'roberta.encoder.layer.5.attention.output.dense.weight', 'roberta.encoder.layer.6.attention.output.dense.weight', 'roberta.encoder.layer.7.attention.output.dense.weight', 'roberta.encoder.layer.8.attention.output.dense.weight', 'roberta.encoder.layer.9.attention.output.dense.weight', 'roberta.encoder.layer.10.attention.output.dense.weight', 'roberta.encoder.layer.11.attention.output.dense.weight', 'roberta.encoder.layer.0.attention.output.dense.bias', 'roberta.encoder.layer.1.attention.output.dense.bias', 'roberta.encoder.layer.2.attention.output.dense.bias', 'roberta.encoder.layer.3.attention.output.dense.bias', 'roberta.encoder.layer.4.attention.output.dense.bias', 'roberta.encoder.layer.5.attention.output.dense.bias', 'roberta.encoder.layer.6.attention.output.dense.bias', 'roberta.encoder.layer.7.attention.output.dense.bias', 'roberta.encoder.layer.8.attention.output.dense.bias', 'roberta.encoder.layer.9.attention.output.dense.bias', 'roberta.encoder.layer.10.attention.output.dense.bias', 'roberta.encoder.layer.11.attention.output.dense.bias', 'roberta.encoder.layer.0.attention.output.LayerNorm.weight', 'roberta.encoder.layer.1.attention.output.LayerNorm.weight', 'roberta.encoder.layer.2.attention.output.LayerNorm.weight', 'roberta.encoder.layer.3.attention.output.LayerNorm.weight', 'roberta.encoder.layer.4.attention.output.LayerNorm.weight', 'roberta.encoder.layer.5.attention.output.LayerNorm.weight', 
'roberta.encoder.layer.6.attention.output.LayerNorm.weight', 'roberta.encoder.layer.7.attention.output.LayerNorm.weight', 'roberta.encoder.layer.8.attention.output.LayerNorm.weight', 'roberta.encoder.layer.9.attention.output.LayerNorm.weight', 'roberta.encoder.layer.10.attention.output.LayerNorm.weight', 'roberta.encoder.layer.11.attention.output.LayerNorm.weight', 'roberta.encoder.layer.0.attention.output.LayerNorm.bias', 'roberta.encoder.layer.1.attention.output.LayerNorm.bias', 'roberta.encoder.layer.2.attention.output.LayerNorm.bias', 'roberta.encoder.layer.3.attention.output.LayerNorm.bias', 'roberta.encoder.layer.4.attention.output.LayerNorm.bias', 'roberta.encoder.layer.5.attention.output.LayerNorm.bias', 'roberta.encoder.layer.6.attention.output.LayerNorm.bias', 'roberta.encoder.layer.7.attention.output.LayerNorm.bias', 'roberta.encoder.layer.8.attention.output.LayerNorm.bias', 'roberta.encoder.layer.9.attention.output.LayerNorm.bias', 'roberta.encoder.layer.10.attention.output.LayerNorm.bias', 'roberta.encoder.layer.11.attention.output.LayerNorm.bias', 'roberta.encoder.layer.0.intermediate.dense.weight', 'roberta.encoder.layer.1.intermediate.dense.weight', 'roberta.encoder.layer.2.intermediate.dense.weight', 'roberta.encoder.layer.3.intermediate.dense.weight', 'roberta.encoder.layer.4.intermediate.dense.weight', 'roberta.encoder.layer.5.intermediate.dense.weight', 'roberta.encoder.layer.6.intermediate.dense.weight', 'roberta.encoder.layer.7.intermediate.dense.weight', 'roberta.encoder.layer.8.intermediate.dense.weight', 'roberta.encoder.layer.9.intermediate.dense.weight', 'roberta.encoder.layer.10.intermediate.dense.weight', 'roberta.encoder.layer.11.intermediate.dense.weight', 'roberta.encoder.layer.0.intermediate.dense.bias', 'roberta.encoder.layer.1.intermediate.dense.bias', 'roberta.encoder.layer.2.intermediate.dense.bias', 'roberta.encoder.layer.3.intermediate.dense.bias', 'roberta.encoder.layer.4.intermediate.dense.bias', 'roberta.encoder.layer.5.intermediate.dense.bias', 'roberta.encoder.layer.6.intermediate.dense.bias', 'roberta.encoder.layer.7.intermediate.dense.bias', 'roberta.encoder.layer.8.intermediate.dense.bias', 'roberta.encoder.layer.9.intermediate.dense.bias', 'roberta.encoder.layer.10.intermediate.dense.bias', 'roberta.encoder.layer.11.intermediate.dense.bias', 'roberta.encoder.layer.0.output.dense.weight', 'roberta.encoder.layer.1.output.dense.weight', 'roberta.encoder.layer.2.output.dense.weight', 'roberta.encoder.layer.3.output.dense.weight', 'roberta.encoder.layer.4.output.dense.weight', 'roberta.encoder.layer.5.output.dense.weight', 'roberta.encoder.layer.6.output.dense.weight', 'roberta.encoder.layer.7.output.dense.weight', 'roberta.encoder.layer.8.output.dense.weight', 'roberta.encoder.layer.9.output.dense.weight', 'roberta.encoder.layer.10.output.dense.weight', 'roberta.encoder.layer.11.output.dense.weight', 'roberta.encoder.layer.0.output.dense.bias', 'roberta.encoder.layer.1.output.dense.bias', 'roberta.encoder.layer.2.output.dense.bias', 'roberta.encoder.layer.3.output.dense.bias', 'roberta.encoder.layer.4.output.dense.bias', 'roberta.encoder.layer.5.output.dense.bias', 'roberta.encoder.layer.6.output.dense.bias', 'roberta.encoder.layer.7.output.dense.bias', 'roberta.encoder.layer.8.output.dense.bias', 'roberta.encoder.layer.9.output.dense.bias', 'roberta.encoder.layer.10.output.dense.bias', 'roberta.encoder.layer.11.output.dense.bias', 'roberta.encoder.layer.0.output.LayerNorm.weight', 'roberta.encoder.layer.1.output.LayerNorm.weight', 
'roberta.encoder.layer.2.output.LayerNorm.weight', 'roberta.encoder.layer.3.output.LayerNorm.weight', 'roberta.encoder.layer.4.output.LayerNorm.weight', 'roberta.encoder.layer.5.output.LayerNorm.weight', 'roberta.encoder.layer.6.output.LayerNorm.weight', 'roberta.encoder.layer.7.output.LayerNorm.weight', 'roberta.encoder.layer.8.output.LayerNorm.weight', 'roberta.encoder.layer.9.output.LayerNorm.weight', 'roberta.encoder.layer.10.output.LayerNorm.weight', 'roberta.encoder.layer.11.output.LayerNorm.weight', 'roberta.encoder.layer.0.output.LayerNorm.bias', 'roberta.encoder.layer.1.output.LayerNorm.bias', 'roberta.encoder.layer.2.output.LayerNorm.bias', 'roberta.encoder.layer.3.output.LayerNorm.bias', 'roberta.encoder.layer.4.output.LayerNorm.bias', 'roberta.encoder.layer.5.output.LayerNorm.bias', 'roberta.encoder.layer.6.output.LayerNorm.bias', 'roberta.encoder.layer.7.output.LayerNorm.bias', 'roberta.encoder.layer.8.output.LayerNorm.bias', 'roberta.encoder.layer.9.output.LayerNorm.bias', 'roberta.encoder.layer.10.output.LayerNorm.bias', 'roberta.encoder.layer.11.output.LayerNorm.bias', 'roberta.embeddings.word_embeddings.weight', 'roberta.embeddings.position_embeddings.weight', 'roberta.embeddings.token_type_embeddings.weight', 'roberta.embeddings.LayerNorm.weight', 'roberta.embeddings.LayerNorm.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.decoder.bias'])\n" + ] + }, + { + "ename": "RuntimeError", + "evalue": "Error(s) in loading state_dict for RobertaForMaskedLM:\n\tMissing key(s) in state_dict: \"lm_head.bias\". ", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[16], line 11\u001b[0m\n\u001b[0;32m 8\u001b[0m \u001b[38;5;28mprint\u001b[39m(state\u001b[38;5;241m.\u001b[39mkeys())\n\u001b[0;32m 10\u001b[0m hf_pool \u001b[38;5;241m=\u001b[39m hf_roberta\u001b[38;5;241m.\u001b[39mRobertaForMaskedLM(hf_config)\n\u001b[1;32m---> 11\u001b[0m \u001b[43mhf_pool\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_state_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstrict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 13\u001b[0m my_output \u001b[38;5;241m=\u001b[39m my_pool(input_ids \u001b[38;5;241m=\u001b[39m input_ids, attention_mask\u001b[38;5;241m=\u001b[39mmask, key\u001b[38;5;241m=\u001b[39mkey)\n\u001b[0;32m 14\u001b[0m hf_output \u001b[38;5;241m=\u001b[39m hf_pool(input_ids \u001b[38;5;241m=\u001b[39m input_ids_torch, attention_mask\u001b[38;5;241m=\u001b[39mmask_torch, return_dict\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n", + "File \u001b[1;32mc:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\torch\\nn\\modules\\module.py:2189\u001b[0m, in \u001b[0;36mModule.load_state_dict\u001b[1;34m(self, state_dict, strict, assign)\u001b[0m\n\u001b[0;32m 2184\u001b[0m error_msgs\u001b[38;5;241m.\u001b[39minsert(\n\u001b[0;32m 2185\u001b[0m \u001b[38;5;241m0\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mMissing key(s) in state_dict: \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m. 
\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[0;32m 2186\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mk\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m missing_keys)))\n\u001b[0;32m 2188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(error_msgs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m-> 2189\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mError(s) in loading state_dict for \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[0;32m 2190\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(error_msgs)))\n\u001b[0;32m 2191\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _IncompatibleKeys(missing_keys, unexpected_keys)\n", + "\u001b[1;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for RobertaForMaskedLM:\n\tMissing key(s) in state_dict: \"lm_head.bias\". " + ] + } + ], + "source": [ + "# Testing RobertaForMaskedLM\n", + "\n", + "my_pool = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=key)\n", + "state = my_pool.to_state_dict()\n", + "\n", + "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + "print(state.keys())\n", + "\n", + "hf_pool = hf_roberta.RobertaForMaskedLM(hf_config)\n", + "hf_pool.load_state_dict(state, strict=True)\n", + "\n", + "my_output = my_pool(input_ids = input_ids, attention_mask=mask, key=key)\n", + "hf_output = hf_pool(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", + "\n", + "check(my_output[0].array, hf_output[0].detach(), ppp=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 8f7402ef2cc63848bd7a8d81ed419c2a91a5451e Mon Sep 17 00:00:00 2001 From: prady-saligram Date: Tue, 30 Jul 2024 11:33:46 -0700 Subject: [PATCH 02/29] Implements dynamic masking objective --- src/levanter/main/train_mlm.py | 200 +++++++++++++++++++++++++++++++++ 1 file changed, 200 insertions(+) create mode 100644 src/levanter/main/train_mlm.py diff --git a/src/levanter/main/train_mlm.py b/src/levanter/main/train_mlm.py new file mode 100644 index 000000000..208131310 --- /dev/null +++ b/src/levanter/main/train_mlm.py @@ -0,0 +1,200 @@ +# train_mlm.py + +import dataclasses +import gc +import logging +import os +from dataclasses import dataclass, field +from typing import Optional, Union + +import jax.random as jrandom 
+ +import haliax as hax +from haliax import Axis +from haliax.partitioning import named_jit, round_axis_for_partitioning + +import levanter +from levanter import callbacks +from levanter.compat.hf_checkpoints import HFCompatConfig, save_hf_checkpoint_callback +from levanter.data.text import MaskedLmDataset, LMDatasetConfig, LMMixtureDatasetConfig +from levanter.models.gpt2 import Gpt2Config +from levanter.models.llama import LlamaConfig +from levanter.models.lm_model import LmConfig +from levanter.optim import AdamConfig, OptimizerConfig +from levanter.trainer import Trainer, TrainerConfig +from levanter.utils.jax_utils import parameter_count + +logger = logging.getLogger(__name__) + +@dataclass +class TrainMlmConfig: + data: Union[LMDatasetConfig, LMMixtureDatasetConfig] = field(default_factory=LMDatasetConfig) + trainer: TrainerConfig = field(default_factory=TrainerConfig) + model: LmConfig = field(default_factory=LlamaConfig) + optimizer: OptimizerConfig = field(default_factory=AdamConfig) + + # config related to continued pretraining + initialize_from_hf: Union[bool, str] = False + """if provided, this will override the model config in the config. if true, use the default hf checkpoint for this model class""" + use_hf_model_config: bool = False # if true, replace the model config with the hf config from the checkpoint + + # TODO: atm we don't support loading from a checkpoint that has a different tokenizer. this is a bit annoying + # TODO: atm you have to at least specify a levanter model config with the same type as the hf checkpoint + + mlm_prob: float = 0.15 # masking probability for MLM + hf_save_path: Optional[str] = None + hf_upload: Optional[str] = None + hf_save_steps: int = 10000 + + update_hessian_steps: int = 10 + data_seed: Optional[int] = None # if provided, will override the data seed from the trainer + +def main(config: TrainMlmConfig): + tokenizer = config.data.the_tokenizer + + # this is some unpleasant code to allow us to initialize from a hf checkpoint. If this is your first read through, + # I recommend skipping it for now + if config.initialize_from_hf: + if config.trainer.initialize_from is not None: + raise ValueError("Cannot specify both initialize_from_hf and initialize_from") + + assert isinstance(config.model, HFCompatConfig) + converter = config.model.hf_checkpoint_converter() + if hasattr(tokenizer, "vocab") and tokenizer.vocab != converter.tokenizer.vocab: + logger.warning("The tokenizers appear to be different. You may want to check this.") + + if isinstance(config.initialize_from_hf, str): + converter = converter.replaced(reference_checkpoint=config.initialize_from_hf, tokenizer=tokenizer) + else: + converter = converter.replaced(tokenizer=tokenizer) + + if config.use_hf_model_config: + # TODO: log diff of old and new config + # NB: gross mutability + config.model = converter.config_from_hf_config(converter.default_hf_config) + elif isinstance(config.model, HFCompatConfig): + converter = config.model.hf_checkpoint_converter() + converter = converter.replaced(tokenizer=tokenizer) + else: + converter = None + + levanter.initialize(config) + optimizer = config.optimizer.build(config.trainer.num_train_steps) + + # Using the trainer as a context manager does 3 things: + # 1. Sets the device mesh + # 2. Sets the axis mapping (for fsdp) + # 3. 
Sets the global metrics tracker + with Trainer(config.trainer, optimizer) as trainer: + # randomness in jax is tightly controlled by "keys" which are the states of the random number generators + # this makes deterministic training pretty easy + seed = config.trainer.seed + data_key, loader_key, model_key, training_key = jrandom.split(jrandom.PRNGKey(seed), 4) + + if config.data_seed is not None: + logger.info(f"Overriding data seed with {config.data_seed}") + data_key = jrandom.PRNGKey(config.data_seed) + + # We have two axis_mappings: one for storing the model and optimizer states, and one for compute + # This allows Zero-3-style parameter sharding, where we shard the parameters and optimizer state across the mesh + compute_axis_mapping = trainer.compute_axis_mapping + parameter_axis_mapping = trainer.parameter_axis_mapping + + # some axes we need + Batch = config.trainer.TrainBatch + EvalBatch = config.trainer.EvalBatch + Pos = config.model.Pos + KeyPos = config.model.KeyPos + + tagged_eval_datasets = config.data.tagged_eval_sets(Pos.size) + train_dataset = MaskedLmDataset( + config.data.train_set(Pos.size, key=data_key), Pos, KeyPos, mask_prob=config.mlm_prob, key=data_key, ignore_index=config.data.ignore_token_id + ) + + # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to + # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of + # tokens: gpt-2 has 50257, for example. So we round up. + vocab_size = len(tokenizer) + Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), parameter_axis_mapping) + if vocab_size != Vocab.size: + logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning") + + state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key)) + + if int(state.step) == 0: + # TODO: I don't love that we init the model twice, but it's not a big deal i think? + if config.initialize_from_hf: + # initialize from an hf pretrained model + logger.info( + "No training checkpoint found. Initializing model from HF checkpoint" + f" '{converter.reference_checkpoint}'" + ) + # this is a bit gross, but we want to free up the memory from the model we just built + state = dataclasses.replace(state, model=None) + gc.collect() + model = converter.load_pretrained( + config.model.model_type, + config.model, + axis_mapping=parameter_axis_mapping, + dtype=trainer.mp.compute_dtype, + ) + model = named_jit(trainer.mp.cast_to_param, parameter_axis_mapping)(model) + state = dataclasses.replace(state, model=model) + else: + logger.info("No checkpoint found. 
Starting from scratch.") + + levanter.tracker.log_summary({"parameter_count": parameter_count(state.model)}) + + if len(tagged_eval_datasets) == 0: + logger.warning("No evaluation datasets provided.") + else: + masked_datasets = [ + (MaskedLmDataset(ds, Pos, KeyPos, mask_prob=config.mlm_prob, key=data_key, ignore_index=config.data.ignore_token_id), tags) + for ds, tags in tagged_eval_datasets + ] + max_eval_examples_per_ds = config.trainer.max_eval_batches + if max_eval_examples_per_ds is not None: + max_eval_examples_per_ds *= config.trainer.eval_batch_size + + cb = levanter.eval.cb_tagged_lm_evaluate( + EvalBatch, masked_datasets, trainer.device_mesh, compute_axis_mapping, max_eval_examples_per_ds + ) + trainer.add_hook(cb, every=config.trainer.steps_per_eval) + + flops_per_token = config.model.flops_per_token(vocab_size) + flops_per_example = 3 * flops_per_token * Pos.size if flops_per_token is not None else None + trainer.add_hook( + callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size, flops_per_example), every=1 + ) + if config.hf_save_path is not None: + full_save_path = os.path.join(config.hf_save_path, trainer.run_id) + + trainer.add_hook( + save_hf_checkpoint_callback(full_save_path, converter, upload_to_hf=config.hf_upload or False), + every=config.hf_save_steps, + ) + + # visualize log probs + @named_jit( + in_axis_resources=parameter_axis_mapping, + axis_resources=compute_axis_mapping, + out_axis_resources=compute_axis_mapping, + ) + def compute_log_probs(model, example): + model = trainer.mp.cast_to_compute(model) + logprobs = model.compute_loss(example, key=None, reduction=None) + # roll forward to get the loss for each predicted token + logprobs = hax.roll(logprobs, 1, Pos) + return logprobs.rearrange((EvalBatch, Pos)).array + + train_loader = iter(trainer.sharded_loader(train_dataset, Batch)) + + if int(state.step) > 0: + import tqdm + for _ in tqdm.tqdm(range(state.step), desc="seeking data for resume"): + next(train_loader) + + trainer.train(state, train_loader) + +if __name__ == "__main__": + levanter.config.main(main)() From 670b053761c6806b39787bfd91626fea8be6876c Mon Sep 17 00:00:00 2001 From: prady-saligram Date: Tue, 30 Jul 2024 12:08:08 -0700 Subject: [PATCH 03/29] Implements dynamic masked dataset --- src/levanter/data/text.py | 58 +++++++++++++++++++++------------------ 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 484a98bf6..d0dcc95d7 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -14,6 +14,7 @@ import equinox as eqx import fsspec import jax +import jax.numpy as jnp import numpy as np import pyarrow as pa import regex @@ -64,30 +65,29 @@ DEFAULT_IGNORE_INDEX = -100 # Mirrors pytorch's default ignore index - -class CausalLmDataset(ShardableDataset[LmExample]): +class MaskedLmDataset(ShardableDataset[LmExample]): def __init__( self, dataset: ShardableDataset[np.ndarray], QPos: Axis, KPos: Axis, - fcm_prob: float = 0.0, + mask_prob: float = 0.15, key: Optional[PRNGKeyArray] = None, - ignore_index: Optional[int] = None, + ignore_index: Optional[int] = DEFAULT_IGNORE_INDEX, ): self.dataset = dataset self.QPos = QPos self.KPos = KPos - self.fcm_prob = fcm_prob + self.mask_prob = mask_prob self.key = key - self.ignore_id = ignore_index + self.ignore_id = ignore_index if ignore_index is not None else DEFAULT_IGNORE_INDEX - if self.fcm_prob > 0.0 and self.key is None: - raise ValueError("must provide key if fcm_prob > 0.0") + if self.mask_prob > 
0.0 and self.key is None: + raise ValueError("must provide key if mask_prob > 0.0") - def shard(self, shard_id: int, num_shards: int) -> "CausalLmDataset": - return CausalLmDataset( - self.dataset.shard(shard_id, num_shards), self.QPos, self.KPos, self.fcm_prob, self.key, self.ignore_id + def shard(self, shard_id: int, num_shards: int) -> "MaskedLmDataset": + return MaskedLmDataset( + self.dataset.shard(shard_id, num_shards), self.QPos, self.KPos, self.mask_prob, self.key, self.ignore_id ) def __iter__(self) -> Iterator[LmExample]: @@ -95,31 +95,37 @@ def __iter__(self) -> Iterator[LmExample]: sharding = jax.sharding.SingleDeviceSharding(jax.local_devices(backend="cpu")[0]) with use_cpu_device(): - @functools.partial(eqx.filter_jit, out_shardings=sharding) - def _create_lm_example(tokens, key): - tokens = hax.named(tokens, self.QPos) - + def _create_mlm_example(tokens, key): + tokens_array = tokens.array + example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id) - - if self.fcm_prob > 0: - # masks for attention - # We support forgetful causal masking (FCM) which is a technique that improves training speed by - # randomly masking out some of the context. This is a bit like dropout, but it's applied to the attention - # mask instead of the activations. It's described in https://arxiv.org/abs/2210.13432 - assert self.key is not None + + if self.mask_prob > 0: this_key, key = jax.random.split(key) - fcm_mask = hax.nn.attention.forgetful_causal_mask(self.KPos, self.fcm_prob, key=this_key) - attn_mask = example.attn_mask & AttentionMask.explicit(fcm_mask) - example = dataclasses.replace(example, attn_mask=attn_mask) + mask_shape = tokens_array.shape + mask = jax.random.bernoulli(this_key, self.mask_prob, mask_shape) + + # Create a mask for 80% [MASK], 10% random, 10% original token + rand = jax.random.uniform(this_key, mask_shape) + mask_token = jnp.where(rand < 0.8, self.ignore_id, tokens_array) + mask_token = jnp.where((rand >= 0.8) & (rand < 0.9), tokens_array, mask_token) + random_tokens = jax.random.randint(this_key, mask_shape, 0, tokens_array.max() + 1) + masked_tokens = jnp.where(mask, mask_token, random_tokens) + + masked_tokens_named = hax.named(masked_tokens, self.QPos) + example = dataclasses.replace(example, tokens=masked_tokens_named) return example for tokens in self.dataset: - example = _create_lm_example(tokens, key) + tokens_array = jnp.array(tokens) + tokens_named = hax.named(tokens_array, self.QPos) + example = _create_mlm_example(tokens_named, key) yield example + class TokenSeqDataset(ShardableDataset[np.ndarray]): """ A dataset that yields sequences of tokens of fixed length from a TokenizedDocumentCache. 
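A note on the masking hunk above: its comment describes the usual 80% [MASK] / 10% random / 10% unchanged split, but the three random draws all reuse `this_key`, and the final `jnp.where(mask, mask_token, random_tokens)` pulls from `random_tokens` at the unselected positions. For reference, below is a minimal standalone sketch of the conventional BERT/RoBERTa dynamic-masking step. It is an illustration only, not part of this patch: the function name `dynamic_mlm_mask`, its signature, and the assumption that a real `mask_token_id` is available from the tokenizer (the hunk stands in `self.ignore_id` for it) are all hypothetical.

    import jax
    import jax.numpy as jnp

    def dynamic_mlm_mask(tokens, key, mask_token_id, vocab_size, mask_prob=0.15, ignore_id=-100):
        # Independent keys for the three random draws (selection, 80/10/10 action, random replacement).
        k_select, k_action, k_random = jax.random.split(key, 3)

        # Choose which positions participate in the MLM objective.
        selected = jax.random.bernoulli(k_select, mask_prob, tokens.shape)

        # Labels carry the original ids at selected positions and ignore_id elsewhere,
        # so the loss is computed only over selected positions.
        labels = jnp.where(selected, tokens, ignore_id)

        # Of the selected positions: 80% -> [MASK], 10% -> random token, 10% -> left unchanged.
        action = jax.random.uniform(k_action, tokens.shape)
        random_tokens = jax.random.randint(k_random, tokens.shape, 0, vocab_size)
        corrupted = jnp.where(action < 0.8, mask_token_id, tokens)
        corrupted = jnp.where((action >= 0.8) & (action < 0.9), random_tokens, corrupted)

        # Unselected positions always keep their original token.
        inputs = jnp.where(selected, corrupted, tokens)
        return inputs, labels

Under the usual Hugging Face convention, positions where `labels == ignore_id` (-100) are dropped from the cross-entropy, so the 10% of selected tokens that are left unchanged are still predicted against their original ids.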
From 42f54042fe9a2b089c79607ab273fd46fd393b18 Mon Sep 17 00:00:00 2001 From: prady-saligram Date: Tue, 30 Jul 2024 14:17:16 -0700 Subject: [PATCH 04/29] Reintroduced accidentally deleted CausalLMDataset class --- src/levanter/data/text.py | 56 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index d0dcc95d7..4e625ee8b 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -125,6 +125,62 @@ def _create_mlm_example(tokens, key): yield example +class CausalLmDataset(ShardableDataset[LmExample]): + def __init__( + self, + dataset: ShardableDataset[np.ndarray], + QPos: Axis, + KPos: Axis, + fcm_prob: float = 0.0, + key: Optional[PRNGKeyArray] = None, + ignore_index: Optional[int] = None, + ): + self.dataset = dataset + self.QPos = QPos + self.KPos = KPos + self.fcm_prob = fcm_prob + self.key = key + self.ignore_id = ignore_index + + if self.fcm_prob > 0.0 and self.key is None: + raise ValueError("must provide key if fcm_prob > 0.0") + + def shard(self, shard_id: int, num_shards: int) -> "CausalLmDataset": + return CausalLmDataset( + self.dataset.shard(shard_id, num_shards), self.QPos, self.KPos, self.fcm_prob, self.key, self.ignore_id + ) + + def __iter__(self) -> Iterator[LmExample]: + key = self.key + sharding = jax.sharding.SingleDeviceSharding(jax.local_devices(backend="cpu")[0]) + + with use_cpu_device(): + + @functools.partial(eqx.filter_jit, out_shardings=sharding) + def _create_lm_example(tokens, key): + tokens = hax.named(tokens, self.QPos) + + example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id) + + if self.fcm_prob > 0: + # masks for attention + # We support forgetful causal masking (FCM) which is a technique that improves training speed by + # randomly masking out some of the context. This is a bit like dropout, but it's applied to the attention + # mask instead of the activations. 
It's described in https://arxiv.org/abs/2210.13432 + assert self.key is not None + this_key, key = jax.random.split(key) + fcm_mask = hax.nn.attention.forgetful_causal_mask(self.KPos, self.fcm_prob, key=this_key) + attn_mask = example.attn_mask & AttentionMask.explicit(fcm_mask) + example = dataclasses.replace(example, attn_mask=attn_mask) + + return example + + for tokens in self.dataset: + example = _create_lm_example(tokens, key) + yield example + + + class TokenSeqDataset(ShardableDataset[np.ndarray]): """ From 9ad06af91481f869e53bf5720f4d54efbc9066ff Mon Sep 17 00:00:00 2001 From: JulienDarve Date: Thu, 1 Aug 2024 14:43:35 -0700 Subject: [PATCH 05/29] Everything works except stuck on the final method, RobertaForMaskedLM --- src/levanter/models/roberta.py | 49 +++-- src/levanter/models/testing.ipynb | 338 +++++++++++++++++++++++++++--- 2 files changed, 334 insertions(+), 53 deletions(-) diff --git a/src/levanter/models/roberta.py b/src/levanter/models/roberta.py index 2db3c9bd4..90f808fa0 100644 --- a/src/levanter/models/roberta.py +++ b/src/levanter/models/roberta.py @@ -400,7 +400,7 @@ def __call__( attention_mask: Optional[NamedArray] = None, *, key - ) -> Tuple[NamedArray]: + ) -> NamedArray: k_a, k_o = maybe_rng_split(key, 2) self_outputs = self.self_attn( @@ -476,14 +476,12 @@ def __call__( ) -> Tuple[NamedArray]: k_a, k_o = maybe_rng_split(key, 2) - self_attention_outputs = self.attention( + attention_output = self.attention( hidden_states, attention_mask, key=k_a, ) - attention_output = self_attention_outputs - intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output, key=k_o) @@ -634,7 +632,7 @@ def init(config: RobertaConfig, *, key): @named_call def __call__(self, hidden_states: NamedArray, *, key=None) -> NamedArray: first_token = hidden_states[{"position" : 0}] - x = self.dense(first_token, key=key).rename({self.config.Embed: self.config.FinalEmbed}) + x = self.dense(first_token, key=key).rename({self.config.FinalEmbed: self.config.Embed}) x = hax.tanh(x) return x @@ -708,7 +706,7 @@ def __call__( if attention_mask is None: attention_mask = hax.ones(input_axes) - # Attention mask from mask to real numbers + # Attention mask from mask to actual numbers attention_mask = (attention_mask == 0) * -1e9 embedding_output = self.embeddings.embed(input_ids, token_type_ids, position_ids, input_embeds, key=k_emb) @@ -724,13 +722,14 @@ class RobertaLMHead(eqx.Module, StateDictSerializationMixin): dense: hnn.Linear layer_norm: hnn.LayerNorm decoder: hnn.Linear + config: RobertaConfig @staticmethod def init(Vocab: Axis, config: RobertaConfig, *, key): k_dense, k_decoder = jrandom.split(key, 2) Embed = config.Embed - dense = hnn.Linear.init(Embed, config.FinalEmbed, key=k_dense, out_first=True) + dense = hnn.Linear.init(In=Embed, Out=config.FinalEmbed, key=k_dense, out_first=True) layer_norm = hnn.LayerNorm.init(axis=Embed, eps=config.layer_norm_eps) decoder = hnn.Linear.init(Embed, Vocab, key=k_decoder, out_first=True) @@ -739,12 +738,12 @@ def init(Vocab: Axis, config: RobertaConfig, *, key): # self.bias = nn.Parameter(torch.zeros(config.vocab_size)) # self.decoder.bias = self.bias - return RobertaLMHead(dense, layer_norm, decoder) + return RobertaLMHead(dense, layer_norm, decoder, config) @named_call def __call__(self, features: NamedArray, *, key=None) -> NamedArray: - x = self.dense(features).rename({self.config.Embed: self.config.FinalEmbed}) - x = hnn.gelu(x) + x = self.dense(features).rename({self.config.FinalEmbed: 
self.config.Embed}) + x = hnn.gelu(x, approximate=False) x = self.layer_norm(x) # project back to size of vocabulary with bias @@ -763,8 +762,8 @@ def init(self, Vocab: Axis, config: RobertaConfig, *, key): # if config.is_decoder: # raise AttributeError("Model is being run as a MaskedLM aka an encoder model, but is_decoder is true") - k_rob, key_head = jrandom.split(key, 2) - roberta = RobertaModel.init(Vocab, config, add_pooling_layer=False, key=k_rob) + key_rob, key_head = jrandom.split(key, 2) + roberta = RobertaModel.init(Vocab, config, add_pooling_layer=False, key=key_rob) lm_head = RobertaLMHead.init(Vocab, config, key=key_head) return RobertaForMaskedLM(roberta, lm_head, Vocab) @@ -775,14 +774,17 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.lm_head.decoder = new_embeddings - def forward( + @named_call + def __call__( self, input_ids: Optional[NamedArray] = None, attention_mask: Optional[NamedArray] = None, token_type_ids: Optional[NamedArray] = None, position_ids: Optional[NamedArray] = None, - inputs_embeds: Optional[NamedArray] = None, + input_embeds: Optional[NamedArray] = None, labels: Optional[NamedArray] = None, + *, + key=None ) -> Tuple[NamedArray]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -793,23 +795,20 @@ def forward( Used to hide legacy arguments that have been deprecated. """ + k_rob, k_lm = maybe_rng_split(key, 2) + outputs = self.roberta( - input_ids, - attn_mask=attention_mask, + input_ids=input_ids, + attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, - inputs_embeds=inputs_embeds + input_embeds=input_embeds, + key=k_rob ) - sequence_output = outputs[0] - prediction_scores = self.lm_head(sequence_output) - - masked_lm_loss = None - if labels is not None: - masked_lm_loss = hnn.cross_entropy_loss(logits=prediction_scores, Label=self.Vocab, targets=labels) + prediction_scores = self.lm_head(outputs[0], key=k_lm) - output = (prediction_scores,) - return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + return prediction_scores def _rotate_half(x: NamedArray) -> NamedArray: diff --git a/src/levanter/models/testing.ipynb b/src/levanter/models/testing.ipynb index d8e0b4531..f16e0fcdd 100644 --- a/src/levanter/models/testing.ipynb +++ b/src/levanter/models/testing.ipynb @@ -11,7 +11,7 @@ "text": [ "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", - "2024-07-30 10:42:02,733\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n" + "2024-08-01 14:42:02,038\tINFO util.py:154 -- Missing packages: ['ipywidgets']. 
Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n" ] } ], @@ -78,7 +78,10 @@ "input_ids = input_ids[{\"position\": slice(0,-2)}]\n", "\n", "input_embeds_torch = torch.from_numpy(np.array(input_embeds.array))\n", - "input_ids_torch = torch.from_numpy(np.array(input_ids.array))" + "input_ids_torch = torch.from_numpy(np.array(input_ids.array))\n", + "\n", + "features = hax.random.normal(key, (Batch, Embed))\n", + "features_torch = torch.from_numpy(np.array(features.array))" ] }, { @@ -98,10 +101,11 @@ "metadata": {}, "outputs": [], "source": [ - "def check(my_output, hf_output, p=False, pp=False, ppp=True, precision=1e-4):\n", + "def check(my_output, hf_output, p=False, pp=False, ppp=False, pppp=True, precision=1e-4):\n", " \n", - " print(my_output.shape)\n", - " print(hf_output.shape)\n", + " assert (np.array(my_output.shape) == np.array(hf_output.shape)).all()\n", + " # print(my_output.shape)\n", + " # print(hf_output.shape)\n", "\n", " success = np.isclose(hf_output, np.array(my_output), rtol=precision, atol=precision).all()\n", "\n", @@ -115,10 +119,18 @@ " for i in range(15):\n", " prec = 10 ** (-1*i)\n", " acc = np.isclose(hf_output, np.array(my_output), rtol=precision, atol=prec).mean()\n", - " if acc != acc_prev:\n", + " if acc_prev is None:\n", " print(f\"Iteration {i}, Precision {prec}:\\t{acc}\")\n", + " else:\n", + " if np.abs(acc - acc_prev) > 1e-4:\n", + " print(f\"Iteration {i}, Precision {prec}:\\t{acc}\")\n", " acc_prev = acc\n", + " print(f\"Iteration {i}, Precision {prec}:\\t{acc}\")\n", "\n", + " if pppp:\n", + " acc = np.isclose(hf_output, np.array(my_output), rtol=precision, atol=precision).mean()\n", + " print(f\"Accuracy: {acc}\")\n", + " \n", " if p: \n", " acc = np.isclose(hf_output, np.array(my_output), rtol=precision, atol=precision).mean()\n", " print(f\"Accuracy: {acc}\")\n", @@ -467,7 +479,7 @@ { "data": { "text/plain": [ - "'# Testing RobertaModel\\n\\nmy_pool = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=True, key=key)\\nstate = my_pool.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nprint(state.keys())\\n\\nhf_pool = hf_roberta.RobertaModel(hf_config, add_pooling_layer=True)\\nhf_pool.load_state_dict(state, strict=True)\\n\\nmy_output = my_pool(input_ids = input_ids, attention_mask=mask, key=key)\\nhf_output = hf_pool(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\\n\\ncheck(my_output[0].array, hf_output[0].detach(), ppp=True)'" + "'# Testing RobertaModel\\n\\nmy_pool = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=True, key=key)\\nstate = my_pool.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nprint(state.keys())\\n\\nhf_pool = hf_roberta.RobertaModel(hf_config, add_pooling_layer=True)\\nhf_pool.load_state_dict(state, strict=True)\\n\\n# my_output = my_pool(input_ids = input_ids, attention_mask=mask, key=key)\\n# hf_output = hf_pool(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\\n\\nmy_output = my_pool(input_embeds = input_embeds, attention_mask=mask, key=key)\\nhf_output = hf_pool(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\\n\\ncheck(my_output[1].array, hf_output[1].detach(), ppp=True)'" ] }, "execution_count": 15, @@ -488,54 +500,324 @@ "hf_pool = hf_roberta.RobertaModel(hf_config, add_pooling_layer=True)\n", "hf_pool.load_state_dict(state, strict=True)\n", "\n", - "my_output = 
my_pool(input_ids = input_ids, attention_mask=mask, key=key)\n", - "hf_output = hf_pool(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", + "# my_output = my_pool(input_ids = input_ids, attention_mask=mask, key=key)\n", + "# hf_output = hf_pool(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", "\n", - "check(my_output[0].array, hf_output[0].detach(), ppp=True)'''" + "my_output = my_pool(input_embeds = input_embeds, attention_mask=mask, key=key)\n", + "hf_output = hf_pool(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n", + "\n", + "check(my_output[1].array, hf_output[1].detach(), ppp=True)'''" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'# Testing RobertaLMHead\\n\\nmy_pool = my_roberta.RobertaLMHead.init(Vocab, my_config, key=key)\\nstate = my_pool.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nstate[\"bias\"] = torch.zeros(hf_config.vocab_size)\\n\\nprint(state.keys())\\n\\nhf_pool = hf_roberta.RobertaLMHead(hf_config)\\nhf_pool.load_state_dict(state, strict=True)\\n\\nmy_output = my_pool(features, key=key)\\nhf_output = hf_pool(features_torch)\\n\\ncheck(my_output.array, hf_output.detach(), ppp=True)'" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "'''# Testing RobertaLMHead\n", + "\n", + "my_pool = my_roberta.RobertaLMHead.init(Vocab, my_config, key=key)\n", + "state = my_pool.to_state_dict()\n", + "\n", + "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + "state[\"bias\"] = torch.zeros(hf_config.vocab_size)\n", + "\n", + "print(state.keys())\n", + "\n", + "hf_pool = hf_roberta.RobertaLMHead(hf_config)\n", + "hf_pool.load_state_dict(state, strict=True)\n", + "\n", + "my_output = my_pool(features, key=key)\n", + "hf_output = hf_pool(features_torch)\n", + "\n", + "check(my_output.array, hf_output.detach(), ppp=True)'''" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'# Testing RobertaForMaskedLM\\n\\nmy_pool = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=key)\\nstate = my_pool.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nstate[\"lm_head.bias\"] = torch.zeros(hf_config.vocab_size)\\n\\nprint(state.keys())\\n\\nhf_pool = hf_roberta.RobertaForMaskedLM(hf_config)\\nhf_pool.load_state_dict(state, strict=True)\\n\\nmy_output_MLM = my_pool(input_ids = input_ids, attention_mask=mask, key=key)\\nhf_output_MLM = hf_pool(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\\n\\n# my_output = my_pool(input_embeds = input_embeds, attention_mask=mask, key=key)\\n# hf_output = hf_pool(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\\n\\ncheck(my_output_MLM[0].array, hf_output_MLM[0].detach(), ppp=True)'" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "'''# Testing RobertaForMaskedLM\n", + "\n", + "my_pool = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=key)\n", + "state = my_pool.to_state_dict()\n", + "\n", + "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + "state[\"lm_head.bias\"] = torch.zeros(hf_config.vocab_size)\n", + "\n", + "print(state.keys())\n", + "\n", + "hf_pool = 
hf_roberta.RobertaForMaskedLM(hf_config)\n", + "hf_pool.load_state_dict(state, strict=True)\n", + "\n", + "my_output_MLM = my_pool(input_ids = input_ids, attention_mask=mask, key=key)\n", + "hf_output_MLM = hf_pool(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", + "\n", + "# my_output = my_pool(input_embeds = input_embeds, attention_mask=mask, key=key)\n", + "# hf_output = hf_pool(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n", + "\n", + "check(my_output_MLM[0].array, hf_output_MLM[0].detach(), ppp=True)'''" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'# Testing RobertaModel\\n\\nmy_model = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=False, key=key)\\nstate = my_model.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nprint(state.keys())\\n\\nhf_model = hf_roberta.RobertaModel(hf_config, add_pooling_layer=False)\\nhf_model.load_state_dict(state, strict=True)\\n\\nmy_output = my_model(input_ids = input_ids, attention_mask=mask, key=key)\\nhf_output = hf_model(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\\n\\ncheck(my_output[0].array, hf_output[0].detach(), ppp=True)\\n\\n# Testing RobertaLMHead\\n\\nmy_head = my_roberta.RobertaLMHead.init(Vocab, my_config, key=key)\\nstate = my_head.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nstate[\"bias\"] = torch.zeros(hf_config.vocab_size)\\n\\nprint(state.keys())\\n\\nhf_head = hf_roberta.RobertaLMHead(hf_config)\\nhf_head.load_state_dict(state, strict=True)\\n\\nmy_output = my_head(my_output[0], key=key)\\nhf_output = hf_head(hf_output[0])\\n\\ncheck(my_output.array, hf_output.detach(), ppp=True)'" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "'''# Testing RobertaModel\n", + "\n", + "my_model = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=False, key=key)\n", + "state = my_model.to_state_dict()\n", + "\n", + "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + "print(state.keys())\n", + "\n", + "hf_model = hf_roberta.RobertaModel(hf_config, add_pooling_layer=False)\n", + "hf_model.load_state_dict(state, strict=True)\n", + "\n", + "my_output = my_model(input_ids = input_ids, attention_mask=mask, key=key)\n", + "hf_output = hf_model(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", + "\n", + "check(my_output[0].array, hf_output[0].detach(), ppp=True)\n", + "\n", + "# Testing RobertaLMHead\n", + "\n", + "my_head = my_roberta.RobertaLMHead.init(Vocab, my_config, key=key)\n", + "state = my_head.to_state_dict()\n", + "\n", + "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + "state[\"bias\"] = torch.zeros(hf_config.vocab_size)\n", + "\n", + "print(state.keys())\n", + "\n", + "hf_head = hf_roberta.RobertaLMHead(hf_config)\n", + "hf_head.load_state_dict(state, strict=True)\n", + "\n", + "my_output = my_head(my_output[0], key=key)\n", + "hf_output = hf_head(hf_output[0])\n", + "\n", + "check(my_output.array, hf_output.detach(), ppp=True)'''" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "dict_keys(['roberta.encoder.layer.0.attention.self.query.weight', 
'roberta.encoder.layer.1.attention.self.query.weight', 'roberta.encoder.layer.2.attention.self.query.weight', 'roberta.encoder.layer.3.attention.self.query.weight', 'roberta.encoder.layer.4.attention.self.query.weight', 'roberta.encoder.layer.5.attention.self.query.weight', 'roberta.encoder.layer.6.attention.self.query.weight', 'roberta.encoder.layer.7.attention.self.query.weight', 'roberta.encoder.layer.8.attention.self.query.weight', 'roberta.encoder.layer.9.attention.self.query.weight', 'roberta.encoder.layer.10.attention.self.query.weight', 'roberta.encoder.layer.11.attention.self.query.weight', 'roberta.encoder.layer.0.attention.self.query.bias', 'roberta.encoder.layer.1.attention.self.query.bias', 'roberta.encoder.layer.2.attention.self.query.bias', 'roberta.encoder.layer.3.attention.self.query.bias', 'roberta.encoder.layer.4.attention.self.query.bias', 'roberta.encoder.layer.5.attention.self.query.bias', 'roberta.encoder.layer.6.attention.self.query.bias', 'roberta.encoder.layer.7.attention.self.query.bias', 'roberta.encoder.layer.8.attention.self.query.bias', 'roberta.encoder.layer.9.attention.self.query.bias', 'roberta.encoder.layer.10.attention.self.query.bias', 'roberta.encoder.layer.11.attention.self.query.bias', 'roberta.encoder.layer.0.attention.self.key.weight', 'roberta.encoder.layer.1.attention.self.key.weight', 'roberta.encoder.layer.2.attention.self.key.weight', 'roberta.encoder.layer.3.attention.self.key.weight', 'roberta.encoder.layer.4.attention.self.key.weight', 'roberta.encoder.layer.5.attention.self.key.weight', 'roberta.encoder.layer.6.attention.self.key.weight', 'roberta.encoder.layer.7.attention.self.key.weight', 'roberta.encoder.layer.8.attention.self.key.weight', 'roberta.encoder.layer.9.attention.self.key.weight', 'roberta.encoder.layer.10.attention.self.key.weight', 'roberta.encoder.layer.11.attention.self.key.weight', 'roberta.encoder.layer.0.attention.self.key.bias', 'roberta.encoder.layer.1.attention.self.key.bias', 'roberta.encoder.layer.2.attention.self.key.bias', 'roberta.encoder.layer.3.attention.self.key.bias', 'roberta.encoder.layer.4.attention.self.key.bias', 'roberta.encoder.layer.5.attention.self.key.bias', 'roberta.encoder.layer.6.attention.self.key.bias', 'roberta.encoder.layer.7.attention.self.key.bias', 'roberta.encoder.layer.8.attention.self.key.bias', 'roberta.encoder.layer.9.attention.self.key.bias', 'roberta.encoder.layer.10.attention.self.key.bias', 'roberta.encoder.layer.11.attention.self.key.bias', 'roberta.encoder.layer.0.attention.self.value.weight', 'roberta.encoder.layer.1.attention.self.value.weight', 'roberta.encoder.layer.2.attention.self.value.weight', 'roberta.encoder.layer.3.attention.self.value.weight', 'roberta.encoder.layer.4.attention.self.value.weight', 'roberta.encoder.layer.5.attention.self.value.weight', 'roberta.encoder.layer.6.attention.self.value.weight', 'roberta.encoder.layer.7.attention.self.value.weight', 'roberta.encoder.layer.8.attention.self.value.weight', 'roberta.encoder.layer.9.attention.self.value.weight', 'roberta.encoder.layer.10.attention.self.value.weight', 'roberta.encoder.layer.11.attention.self.value.weight', 'roberta.encoder.layer.0.attention.self.value.bias', 'roberta.encoder.layer.1.attention.self.value.bias', 'roberta.encoder.layer.2.attention.self.value.bias', 'roberta.encoder.layer.3.attention.self.value.bias', 'roberta.encoder.layer.4.attention.self.value.bias', 'roberta.encoder.layer.5.attention.self.value.bias', 'roberta.encoder.layer.6.attention.self.value.bias', 
'roberta.encoder.layer.7.attention.self.value.bias', 'roberta.encoder.layer.8.attention.self.value.bias', 'roberta.encoder.layer.9.attention.self.value.bias', 'roberta.encoder.layer.10.attention.self.value.bias', 'roberta.encoder.layer.11.attention.self.value.bias', 'roberta.encoder.layer.0.attention.output.dense.weight', 'roberta.encoder.layer.1.attention.output.dense.weight', 'roberta.encoder.layer.2.attention.output.dense.weight', 'roberta.encoder.layer.3.attention.output.dense.weight', 'roberta.encoder.layer.4.attention.output.dense.weight', 'roberta.encoder.layer.5.attention.output.dense.weight', 'roberta.encoder.layer.6.attention.output.dense.weight', 'roberta.encoder.layer.7.attention.output.dense.weight', 'roberta.encoder.layer.8.attention.output.dense.weight', 'roberta.encoder.layer.9.attention.output.dense.weight', 'roberta.encoder.layer.10.attention.output.dense.weight', 'roberta.encoder.layer.11.attention.output.dense.weight', 'roberta.encoder.layer.0.attention.output.dense.bias', 'roberta.encoder.layer.1.attention.output.dense.bias', 'roberta.encoder.layer.2.attention.output.dense.bias', 'roberta.encoder.layer.3.attention.output.dense.bias', 'roberta.encoder.layer.4.attention.output.dense.bias', 'roberta.encoder.layer.5.attention.output.dense.bias', 'roberta.encoder.layer.6.attention.output.dense.bias', 'roberta.encoder.layer.7.attention.output.dense.bias', 'roberta.encoder.layer.8.attention.output.dense.bias', 'roberta.encoder.layer.9.attention.output.dense.bias', 'roberta.encoder.layer.10.attention.output.dense.bias', 'roberta.encoder.layer.11.attention.output.dense.bias', 'roberta.encoder.layer.0.attention.output.LayerNorm.weight', 'roberta.encoder.layer.1.attention.output.LayerNorm.weight', 'roberta.encoder.layer.2.attention.output.LayerNorm.weight', 'roberta.encoder.layer.3.attention.output.LayerNorm.weight', 'roberta.encoder.layer.4.attention.output.LayerNorm.weight', 'roberta.encoder.layer.5.attention.output.LayerNorm.weight', 'roberta.encoder.layer.6.attention.output.LayerNorm.weight', 'roberta.encoder.layer.7.attention.output.LayerNorm.weight', 'roberta.encoder.layer.8.attention.output.LayerNorm.weight', 'roberta.encoder.layer.9.attention.output.LayerNorm.weight', 'roberta.encoder.layer.10.attention.output.LayerNorm.weight', 'roberta.encoder.layer.11.attention.output.LayerNorm.weight', 'roberta.encoder.layer.0.attention.output.LayerNorm.bias', 'roberta.encoder.layer.1.attention.output.LayerNorm.bias', 'roberta.encoder.layer.2.attention.output.LayerNorm.bias', 'roberta.encoder.layer.3.attention.output.LayerNorm.bias', 'roberta.encoder.layer.4.attention.output.LayerNorm.bias', 'roberta.encoder.layer.5.attention.output.LayerNorm.bias', 'roberta.encoder.layer.6.attention.output.LayerNorm.bias', 'roberta.encoder.layer.7.attention.output.LayerNorm.bias', 'roberta.encoder.layer.8.attention.output.LayerNorm.bias', 'roberta.encoder.layer.9.attention.output.LayerNorm.bias', 'roberta.encoder.layer.10.attention.output.LayerNorm.bias', 'roberta.encoder.layer.11.attention.output.LayerNorm.bias', 'roberta.encoder.layer.0.intermediate.dense.weight', 'roberta.encoder.layer.1.intermediate.dense.weight', 'roberta.encoder.layer.2.intermediate.dense.weight', 'roberta.encoder.layer.3.intermediate.dense.weight', 'roberta.encoder.layer.4.intermediate.dense.weight', 'roberta.encoder.layer.5.intermediate.dense.weight', 'roberta.encoder.layer.6.intermediate.dense.weight', 'roberta.encoder.layer.7.intermediate.dense.weight', 'roberta.encoder.layer.8.intermediate.dense.weight', 
'roberta.encoder.layer.9.intermediate.dense.weight', 'roberta.encoder.layer.10.intermediate.dense.weight', 'roberta.encoder.layer.11.intermediate.dense.weight', 'roberta.encoder.layer.0.intermediate.dense.bias', 'roberta.encoder.layer.1.intermediate.dense.bias', 'roberta.encoder.layer.2.intermediate.dense.bias', 'roberta.encoder.layer.3.intermediate.dense.bias', 'roberta.encoder.layer.4.intermediate.dense.bias', 'roberta.encoder.layer.5.intermediate.dense.bias', 'roberta.encoder.layer.6.intermediate.dense.bias', 'roberta.encoder.layer.7.intermediate.dense.bias', 'roberta.encoder.layer.8.intermediate.dense.bias', 'roberta.encoder.layer.9.intermediate.dense.bias', 'roberta.encoder.layer.10.intermediate.dense.bias', 'roberta.encoder.layer.11.intermediate.dense.bias', 'roberta.encoder.layer.0.output.dense.weight', 'roberta.encoder.layer.1.output.dense.weight', 'roberta.encoder.layer.2.output.dense.weight', 'roberta.encoder.layer.3.output.dense.weight', 'roberta.encoder.layer.4.output.dense.weight', 'roberta.encoder.layer.5.output.dense.weight', 'roberta.encoder.layer.6.output.dense.weight', 'roberta.encoder.layer.7.output.dense.weight', 'roberta.encoder.layer.8.output.dense.weight', 'roberta.encoder.layer.9.output.dense.weight', 'roberta.encoder.layer.10.output.dense.weight', 'roberta.encoder.layer.11.output.dense.weight', 'roberta.encoder.layer.0.output.dense.bias', 'roberta.encoder.layer.1.output.dense.bias', 'roberta.encoder.layer.2.output.dense.bias', 'roberta.encoder.layer.3.output.dense.bias', 'roberta.encoder.layer.4.output.dense.bias', 'roberta.encoder.layer.5.output.dense.bias', 'roberta.encoder.layer.6.output.dense.bias', 'roberta.encoder.layer.7.output.dense.bias', 'roberta.encoder.layer.8.output.dense.bias', 'roberta.encoder.layer.9.output.dense.bias', 'roberta.encoder.layer.10.output.dense.bias', 'roberta.encoder.layer.11.output.dense.bias', 'roberta.encoder.layer.0.output.LayerNorm.weight', 'roberta.encoder.layer.1.output.LayerNorm.weight', 'roberta.encoder.layer.2.output.LayerNorm.weight', 'roberta.encoder.layer.3.output.LayerNorm.weight', 'roberta.encoder.layer.4.output.LayerNorm.weight', 'roberta.encoder.layer.5.output.LayerNorm.weight', 'roberta.encoder.layer.6.output.LayerNorm.weight', 'roberta.encoder.layer.7.output.LayerNorm.weight', 'roberta.encoder.layer.8.output.LayerNorm.weight', 'roberta.encoder.layer.9.output.LayerNorm.weight', 'roberta.encoder.layer.10.output.LayerNorm.weight', 'roberta.encoder.layer.11.output.LayerNorm.weight', 'roberta.encoder.layer.0.output.LayerNorm.bias', 'roberta.encoder.layer.1.output.LayerNorm.bias', 'roberta.encoder.layer.2.output.LayerNorm.bias', 'roberta.encoder.layer.3.output.LayerNorm.bias', 'roberta.encoder.layer.4.output.LayerNorm.bias', 'roberta.encoder.layer.5.output.LayerNorm.bias', 'roberta.encoder.layer.6.output.LayerNorm.bias', 'roberta.encoder.layer.7.output.LayerNorm.bias', 'roberta.encoder.layer.8.output.LayerNorm.bias', 'roberta.encoder.layer.9.output.LayerNorm.bias', 'roberta.encoder.layer.10.output.LayerNorm.bias', 'roberta.encoder.layer.11.output.LayerNorm.bias', 'roberta.embeddings.word_embeddings.weight', 'roberta.embeddings.position_embeddings.weight', 'roberta.embeddings.token_type_embeddings.weight', 'roberta.embeddings.LayerNorm.weight', 'roberta.embeddings.LayerNorm.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.decoder.bias'])\n" + "dict_keys(['roberta.encoder.layer.0.attention.self.query.weight', 
'roberta.encoder.layer.1.attention.self.query.weight', 'roberta.encoder.layer.2.attention.self.query.weight', 'roberta.encoder.layer.3.attention.self.query.weight', 'roberta.encoder.layer.4.attention.self.query.weight', 'roberta.encoder.layer.5.attention.self.query.weight', 'roberta.encoder.layer.6.attention.self.query.weight', 'roberta.encoder.layer.7.attention.self.query.weight', 'roberta.encoder.layer.8.attention.self.query.weight', 'roberta.encoder.layer.9.attention.self.query.weight', 'roberta.encoder.layer.10.attention.self.query.weight', 'roberta.encoder.layer.11.attention.self.query.weight', 'roberta.encoder.layer.0.attention.self.query.bias', 'roberta.encoder.layer.1.attention.self.query.bias', 'roberta.encoder.layer.2.attention.self.query.bias', 'roberta.encoder.layer.3.attention.self.query.bias', 'roberta.encoder.layer.4.attention.self.query.bias', 'roberta.encoder.layer.5.attention.self.query.bias', 'roberta.encoder.layer.6.attention.self.query.bias', 'roberta.encoder.layer.7.attention.self.query.bias', 'roberta.encoder.layer.8.attention.self.query.bias', 'roberta.encoder.layer.9.attention.self.query.bias', 'roberta.encoder.layer.10.attention.self.query.bias', 'roberta.encoder.layer.11.attention.self.query.bias', 'roberta.encoder.layer.0.attention.self.key.weight', 'roberta.encoder.layer.1.attention.self.key.weight', 'roberta.encoder.layer.2.attention.self.key.weight', 'roberta.encoder.layer.3.attention.self.key.weight', 'roberta.encoder.layer.4.attention.self.key.weight', 'roberta.encoder.layer.5.attention.self.key.weight', 'roberta.encoder.layer.6.attention.self.key.weight', 'roberta.encoder.layer.7.attention.self.key.weight', 'roberta.encoder.layer.8.attention.self.key.weight', 'roberta.encoder.layer.9.attention.self.key.weight', 'roberta.encoder.layer.10.attention.self.key.weight', 'roberta.encoder.layer.11.attention.self.key.weight', 'roberta.encoder.layer.0.attention.self.key.bias', 'roberta.encoder.layer.1.attention.self.key.bias', 'roberta.encoder.layer.2.attention.self.key.bias', 'roberta.encoder.layer.3.attention.self.key.bias', 'roberta.encoder.layer.4.attention.self.key.bias', 'roberta.encoder.layer.5.attention.self.key.bias', 'roberta.encoder.layer.6.attention.self.key.bias', 'roberta.encoder.layer.7.attention.self.key.bias', 'roberta.encoder.layer.8.attention.self.key.bias', 'roberta.encoder.layer.9.attention.self.key.bias', 'roberta.encoder.layer.10.attention.self.key.bias', 'roberta.encoder.layer.11.attention.self.key.bias', 'roberta.encoder.layer.0.attention.self.value.weight', 'roberta.encoder.layer.1.attention.self.value.weight', 'roberta.encoder.layer.2.attention.self.value.weight', 'roberta.encoder.layer.3.attention.self.value.weight', 'roberta.encoder.layer.4.attention.self.value.weight', 'roberta.encoder.layer.5.attention.self.value.weight', 'roberta.encoder.layer.6.attention.self.value.weight', 'roberta.encoder.layer.7.attention.self.value.weight', 'roberta.encoder.layer.8.attention.self.value.weight', 'roberta.encoder.layer.9.attention.self.value.weight', 'roberta.encoder.layer.10.attention.self.value.weight', 'roberta.encoder.layer.11.attention.self.value.weight', 'roberta.encoder.layer.0.attention.self.value.bias', 'roberta.encoder.layer.1.attention.self.value.bias', 'roberta.encoder.layer.2.attention.self.value.bias', 'roberta.encoder.layer.3.attention.self.value.bias', 'roberta.encoder.layer.4.attention.self.value.bias', 'roberta.encoder.layer.5.attention.self.value.bias', 'roberta.encoder.layer.6.attention.self.value.bias', 
'roberta.encoder.layer.7.attention.self.value.bias', 'roberta.encoder.layer.8.attention.self.value.bias', 'roberta.encoder.layer.9.attention.self.value.bias', 'roberta.encoder.layer.10.attention.self.value.bias', 'roberta.encoder.layer.11.attention.self.value.bias', 'roberta.encoder.layer.0.attention.output.dense.weight', 'roberta.encoder.layer.1.attention.output.dense.weight', 'roberta.encoder.layer.2.attention.output.dense.weight', 'roberta.encoder.layer.3.attention.output.dense.weight', 'roberta.encoder.layer.4.attention.output.dense.weight', 'roberta.encoder.layer.5.attention.output.dense.weight', 'roberta.encoder.layer.6.attention.output.dense.weight', 'roberta.encoder.layer.7.attention.output.dense.weight', 'roberta.encoder.layer.8.attention.output.dense.weight', 'roberta.encoder.layer.9.attention.output.dense.weight', 'roberta.encoder.layer.10.attention.output.dense.weight', 'roberta.encoder.layer.11.attention.output.dense.weight', 'roberta.encoder.layer.0.attention.output.dense.bias', 'roberta.encoder.layer.1.attention.output.dense.bias', 'roberta.encoder.layer.2.attention.output.dense.bias', 'roberta.encoder.layer.3.attention.output.dense.bias', 'roberta.encoder.layer.4.attention.output.dense.bias', 'roberta.encoder.layer.5.attention.output.dense.bias', 'roberta.encoder.layer.6.attention.output.dense.bias', 'roberta.encoder.layer.7.attention.output.dense.bias', 'roberta.encoder.layer.8.attention.output.dense.bias', 'roberta.encoder.layer.9.attention.output.dense.bias', 'roberta.encoder.layer.10.attention.output.dense.bias', 'roberta.encoder.layer.11.attention.output.dense.bias', 'roberta.encoder.layer.0.attention.output.LayerNorm.weight', 'roberta.encoder.layer.1.attention.output.LayerNorm.weight', 'roberta.encoder.layer.2.attention.output.LayerNorm.weight', 'roberta.encoder.layer.3.attention.output.LayerNorm.weight', 'roberta.encoder.layer.4.attention.output.LayerNorm.weight', 'roberta.encoder.layer.5.attention.output.LayerNorm.weight', 'roberta.encoder.layer.6.attention.output.LayerNorm.weight', 'roberta.encoder.layer.7.attention.output.LayerNorm.weight', 'roberta.encoder.layer.8.attention.output.LayerNorm.weight', 'roberta.encoder.layer.9.attention.output.LayerNorm.weight', 'roberta.encoder.layer.10.attention.output.LayerNorm.weight', 'roberta.encoder.layer.11.attention.output.LayerNorm.weight', 'roberta.encoder.layer.0.attention.output.LayerNorm.bias', 'roberta.encoder.layer.1.attention.output.LayerNorm.bias', 'roberta.encoder.layer.2.attention.output.LayerNorm.bias', 'roberta.encoder.layer.3.attention.output.LayerNorm.bias', 'roberta.encoder.layer.4.attention.output.LayerNorm.bias', 'roberta.encoder.layer.5.attention.output.LayerNorm.bias', 'roberta.encoder.layer.6.attention.output.LayerNorm.bias', 'roberta.encoder.layer.7.attention.output.LayerNorm.bias', 'roberta.encoder.layer.8.attention.output.LayerNorm.bias', 'roberta.encoder.layer.9.attention.output.LayerNorm.bias', 'roberta.encoder.layer.10.attention.output.LayerNorm.bias', 'roberta.encoder.layer.11.attention.output.LayerNorm.bias', 'roberta.encoder.layer.0.intermediate.dense.weight', 'roberta.encoder.layer.1.intermediate.dense.weight', 'roberta.encoder.layer.2.intermediate.dense.weight', 'roberta.encoder.layer.3.intermediate.dense.weight', 'roberta.encoder.layer.4.intermediate.dense.weight', 'roberta.encoder.layer.5.intermediate.dense.weight', 'roberta.encoder.layer.6.intermediate.dense.weight', 'roberta.encoder.layer.7.intermediate.dense.weight', 'roberta.encoder.layer.8.intermediate.dense.weight', 
'roberta.encoder.layer.9.intermediate.dense.weight', 'roberta.encoder.layer.10.intermediate.dense.weight', 'roberta.encoder.layer.11.intermediate.dense.weight', 'roberta.encoder.layer.0.intermediate.dense.bias', 'roberta.encoder.layer.1.intermediate.dense.bias', 'roberta.encoder.layer.2.intermediate.dense.bias', 'roberta.encoder.layer.3.intermediate.dense.bias', 'roberta.encoder.layer.4.intermediate.dense.bias', 'roberta.encoder.layer.5.intermediate.dense.bias', 'roberta.encoder.layer.6.intermediate.dense.bias', 'roberta.encoder.layer.7.intermediate.dense.bias', 'roberta.encoder.layer.8.intermediate.dense.bias', 'roberta.encoder.layer.9.intermediate.dense.bias', 'roberta.encoder.layer.10.intermediate.dense.bias', 'roberta.encoder.layer.11.intermediate.dense.bias', 'roberta.encoder.layer.0.output.dense.weight', 'roberta.encoder.layer.1.output.dense.weight', 'roberta.encoder.layer.2.output.dense.weight', 'roberta.encoder.layer.3.output.dense.weight', 'roberta.encoder.layer.4.output.dense.weight', 'roberta.encoder.layer.5.output.dense.weight', 'roberta.encoder.layer.6.output.dense.weight', 'roberta.encoder.layer.7.output.dense.weight', 'roberta.encoder.layer.8.output.dense.weight', 'roberta.encoder.layer.9.output.dense.weight', 'roberta.encoder.layer.10.output.dense.weight', 'roberta.encoder.layer.11.output.dense.weight', 'roberta.encoder.layer.0.output.dense.bias', 'roberta.encoder.layer.1.output.dense.bias', 'roberta.encoder.layer.2.output.dense.bias', 'roberta.encoder.layer.3.output.dense.bias', 'roberta.encoder.layer.4.output.dense.bias', 'roberta.encoder.layer.5.output.dense.bias', 'roberta.encoder.layer.6.output.dense.bias', 'roberta.encoder.layer.7.output.dense.bias', 'roberta.encoder.layer.8.output.dense.bias', 'roberta.encoder.layer.9.output.dense.bias', 'roberta.encoder.layer.10.output.dense.bias', 'roberta.encoder.layer.11.output.dense.bias', 'roberta.encoder.layer.0.output.LayerNorm.weight', 'roberta.encoder.layer.1.output.LayerNorm.weight', 'roberta.encoder.layer.2.output.LayerNorm.weight', 'roberta.encoder.layer.3.output.LayerNorm.weight', 'roberta.encoder.layer.4.output.LayerNorm.weight', 'roberta.encoder.layer.5.output.LayerNorm.weight', 'roberta.encoder.layer.6.output.LayerNorm.weight', 'roberta.encoder.layer.7.output.LayerNorm.weight', 'roberta.encoder.layer.8.output.LayerNorm.weight', 'roberta.encoder.layer.9.output.LayerNorm.weight', 'roberta.encoder.layer.10.output.LayerNorm.weight', 'roberta.encoder.layer.11.output.LayerNorm.weight', 'roberta.encoder.layer.0.output.LayerNorm.bias', 'roberta.encoder.layer.1.output.LayerNorm.bias', 'roberta.encoder.layer.2.output.LayerNorm.bias', 'roberta.encoder.layer.3.output.LayerNorm.bias', 'roberta.encoder.layer.4.output.LayerNorm.bias', 'roberta.encoder.layer.5.output.LayerNorm.bias', 'roberta.encoder.layer.6.output.LayerNorm.bias', 'roberta.encoder.layer.7.output.LayerNorm.bias', 'roberta.encoder.layer.8.output.LayerNorm.bias', 'roberta.encoder.layer.9.output.LayerNorm.bias', 'roberta.encoder.layer.10.output.LayerNorm.bias', 'roberta.encoder.layer.11.output.LayerNorm.bias', 'roberta.embeddings.word_embeddings.weight', 'roberta.embeddings.position_embeddings.weight', 'roberta.embeddings.token_type_embeddings.weight', 'roberta.embeddings.LayerNorm.weight', 'roberta.embeddings.LayerNorm.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.decoder.bias', 'lm_head.bias'])\n", + "dict_keys(['encoder.layer.0.attention.self.query.weight', 
'encoder.layer.1.attention.self.query.weight', 'encoder.layer.2.attention.self.query.weight', 'encoder.layer.3.attention.self.query.weight', 'encoder.layer.4.attention.self.query.weight', 'encoder.layer.5.attention.self.query.weight', 'encoder.layer.6.attention.self.query.weight', 'encoder.layer.7.attention.self.query.weight', 'encoder.layer.8.attention.self.query.weight', 'encoder.layer.9.attention.self.query.weight', 'encoder.layer.10.attention.self.query.weight', 'encoder.layer.11.attention.self.query.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.1.attention.self.query.bias', 'encoder.layer.2.attention.self.query.bias', 'encoder.layer.3.attention.self.query.bias', 'encoder.layer.4.attention.self.query.bias', 'encoder.layer.5.attention.self.query.bias', 'encoder.layer.6.attention.self.query.bias', 'encoder.layer.7.attention.self.query.bias', 'encoder.layer.8.attention.self.query.bias', 'encoder.layer.9.attention.self.query.bias', 'encoder.layer.10.attention.self.query.bias', 'encoder.layer.11.attention.self.query.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.1.attention.self.key.weight', 'encoder.layer.2.attention.self.key.weight', 'encoder.layer.3.attention.self.key.weight', 'encoder.layer.4.attention.self.key.weight', 'encoder.layer.5.attention.self.key.weight', 'encoder.layer.6.attention.self.key.weight', 'encoder.layer.7.attention.self.key.weight', 'encoder.layer.8.attention.self.key.weight', 'encoder.layer.9.attention.self.key.weight', 'encoder.layer.10.attention.self.key.weight', 'encoder.layer.11.attention.self.key.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.1.attention.self.key.bias', 'encoder.layer.2.attention.self.key.bias', 'encoder.layer.3.attention.self.key.bias', 'encoder.layer.4.attention.self.key.bias', 'encoder.layer.5.attention.self.key.bias', 'encoder.layer.6.attention.self.key.bias', 'encoder.layer.7.attention.self.key.bias', 'encoder.layer.8.attention.self.key.bias', 'encoder.layer.9.attention.self.key.bias', 'encoder.layer.10.attention.self.key.bias', 'encoder.layer.11.attention.self.key.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.1.attention.self.value.weight', 'encoder.layer.2.attention.self.value.weight', 'encoder.layer.3.attention.self.value.weight', 'encoder.layer.4.attention.self.value.weight', 'encoder.layer.5.attention.self.value.weight', 'encoder.layer.6.attention.self.value.weight', 'encoder.layer.7.attention.self.value.weight', 'encoder.layer.8.attention.self.value.weight', 'encoder.layer.9.attention.self.value.weight', 'encoder.layer.10.attention.self.value.weight', 'encoder.layer.11.attention.self.value.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.1.attention.self.value.bias', 'encoder.layer.2.attention.self.value.bias', 'encoder.layer.3.attention.self.value.bias', 'encoder.layer.4.attention.self.value.bias', 'encoder.layer.5.attention.self.value.bias', 'encoder.layer.6.attention.self.value.bias', 'encoder.layer.7.attention.self.value.bias', 'encoder.layer.8.attention.self.value.bias', 'encoder.layer.9.attention.self.value.bias', 'encoder.layer.10.attention.self.value.bias', 'encoder.layer.11.attention.self.value.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.1.attention.output.dense.weight', 'encoder.layer.2.attention.output.dense.weight', 'encoder.layer.3.attention.output.dense.weight', 'encoder.layer.4.attention.output.dense.weight', 'encoder.layer.5.attention.output.dense.weight', 
'encoder.layer.6.attention.output.dense.weight', 'encoder.layer.7.attention.output.dense.weight', 'encoder.layer.8.attention.output.dense.weight', 'encoder.layer.9.attention.output.dense.weight', 'encoder.layer.10.attention.output.dense.weight', 'encoder.layer.11.attention.output.dense.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.1.attention.output.dense.bias', 'encoder.layer.2.attention.output.dense.bias', 'encoder.layer.3.attention.output.dense.bias', 'encoder.layer.4.attention.output.dense.bias', 'encoder.layer.5.attention.output.dense.bias', 'encoder.layer.6.attention.output.dense.bias', 'encoder.layer.7.attention.output.dense.bias', 'encoder.layer.8.attention.output.dense.bias', 'encoder.layer.9.attention.output.dense.bias', 'encoder.layer.10.attention.output.dense.bias', 'encoder.layer.11.attention.output.dense.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.1.attention.output.LayerNorm.weight', 'encoder.layer.2.attention.output.LayerNorm.weight', 'encoder.layer.3.attention.output.LayerNorm.weight', 'encoder.layer.4.attention.output.LayerNorm.weight', 'encoder.layer.5.attention.output.LayerNorm.weight', 'encoder.layer.6.attention.output.LayerNorm.weight', 'encoder.layer.7.attention.output.LayerNorm.weight', 'encoder.layer.8.attention.output.LayerNorm.weight', 'encoder.layer.9.attention.output.LayerNorm.weight', 'encoder.layer.10.attention.output.LayerNorm.weight', 'encoder.layer.11.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.1.attention.output.LayerNorm.bias', 'encoder.layer.2.attention.output.LayerNorm.bias', 'encoder.layer.3.attention.output.LayerNorm.bias', 'encoder.layer.4.attention.output.LayerNorm.bias', 'encoder.layer.5.attention.output.LayerNorm.bias', 'encoder.layer.6.attention.output.LayerNorm.bias', 'encoder.layer.7.attention.output.LayerNorm.bias', 'encoder.layer.8.attention.output.LayerNorm.bias', 'encoder.layer.9.attention.output.LayerNorm.bias', 'encoder.layer.10.attention.output.LayerNorm.bias', 'encoder.layer.11.attention.output.LayerNorm.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.1.intermediate.dense.weight', 'encoder.layer.2.intermediate.dense.weight', 'encoder.layer.3.intermediate.dense.weight', 'encoder.layer.4.intermediate.dense.weight', 'encoder.layer.5.intermediate.dense.weight', 'encoder.layer.6.intermediate.dense.weight', 'encoder.layer.7.intermediate.dense.weight', 'encoder.layer.8.intermediate.dense.weight', 'encoder.layer.9.intermediate.dense.weight', 'encoder.layer.10.intermediate.dense.weight', 'encoder.layer.11.intermediate.dense.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.1.intermediate.dense.bias', 'encoder.layer.2.intermediate.dense.bias', 'encoder.layer.3.intermediate.dense.bias', 'encoder.layer.4.intermediate.dense.bias', 'encoder.layer.5.intermediate.dense.bias', 'encoder.layer.6.intermediate.dense.bias', 'encoder.layer.7.intermediate.dense.bias', 'encoder.layer.8.intermediate.dense.bias', 'encoder.layer.9.intermediate.dense.bias', 'encoder.layer.10.intermediate.dense.bias', 'encoder.layer.11.intermediate.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.1.output.dense.weight', 'encoder.layer.2.output.dense.weight', 'encoder.layer.3.output.dense.weight', 'encoder.layer.4.output.dense.weight', 'encoder.layer.5.output.dense.weight', 'encoder.layer.6.output.dense.weight', 'encoder.layer.7.output.dense.weight', 'encoder.layer.8.output.dense.weight', 
'encoder.layer.9.output.dense.weight', 'encoder.layer.10.output.dense.weight', 'encoder.layer.11.output.dense.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.1.output.dense.bias', 'encoder.layer.2.output.dense.bias', 'encoder.layer.3.output.dense.bias', 'encoder.layer.4.output.dense.bias', 'encoder.layer.5.output.dense.bias', 'encoder.layer.6.output.dense.bias', 'encoder.layer.7.output.dense.bias', 'encoder.layer.8.output.dense.bias', 'encoder.layer.9.output.dense.bias', 'encoder.layer.10.output.dense.bias', 'encoder.layer.11.output.dense.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.1.output.LayerNorm.weight', 'encoder.layer.2.output.LayerNorm.weight', 'encoder.layer.3.output.LayerNorm.weight', 'encoder.layer.4.output.LayerNorm.weight', 'encoder.layer.5.output.LayerNorm.weight', 'encoder.layer.6.output.LayerNorm.weight', 'encoder.layer.7.output.LayerNorm.weight', 'encoder.layer.8.output.LayerNorm.weight', 'encoder.layer.9.output.LayerNorm.weight', 'encoder.layer.10.output.LayerNorm.weight', 'encoder.layer.11.output.LayerNorm.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.1.output.LayerNorm.bias', 'encoder.layer.2.output.LayerNorm.bias', 'encoder.layer.3.output.LayerNorm.bias', 'encoder.layer.4.output.LayerNorm.bias', 'encoder.layer.5.output.LayerNorm.bias', 'encoder.layer.6.output.LayerNorm.bias', 'encoder.layer.7.output.LayerNorm.bias', 'encoder.layer.8.output.LayerNorm.bias', 'encoder.layer.9.output.LayerNorm.bias', 'encoder.layer.10.output.LayerNorm.bias', 'encoder.layer.11.output.LayerNorm.bias', 'embeddings.word_embeddings.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.LayerNorm.weight', 'embeddings.LayerNorm.bias'])\n", + "dict_keys(['dense.weight', 'dense.bias', 'layer_norm.weight', 'layer_norm.bias', 'decoder.weight', 'decoder.bias', 'bias'])\n" ] }, { - "ename": "RuntimeError", - "evalue": "Error(s) in loading state_dict for RobertaForMaskedLM:\n\tMissing key(s) in state_dict: \"lm_head.bias\". 
", - "output_type": "error", - "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[1;32mIn[16], line 11\u001b[0m\n\u001b[0;32m 8\u001b[0m \u001b[38;5;28mprint\u001b[39m(state\u001b[38;5;241m.\u001b[39mkeys())\n\u001b[0;32m 10\u001b[0m hf_pool \u001b[38;5;241m=\u001b[39m hf_roberta\u001b[38;5;241m.\u001b[39mRobertaForMaskedLM(hf_config)\n\u001b[1;32m---> 11\u001b[0m \u001b[43mhf_pool\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_state_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstrict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 13\u001b[0m my_output \u001b[38;5;241m=\u001b[39m my_pool(input_ids \u001b[38;5;241m=\u001b[39m input_ids, attention_mask\u001b[38;5;241m=\u001b[39mmask, key\u001b[38;5;241m=\u001b[39mkey)\n\u001b[0;32m 14\u001b[0m hf_output \u001b[38;5;241m=\u001b[39m hf_pool(input_ids \u001b[38;5;241m=\u001b[39m input_ids_torch, attention_mask\u001b[38;5;241m=\u001b[39mmask_torch, return_dict\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n", - "File \u001b[1;32mc:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\torch\\nn\\modules\\module.py:2189\u001b[0m, in \u001b[0;36mModule.load_state_dict\u001b[1;34m(self, state_dict, strict, assign)\u001b[0m\n\u001b[0;32m 2184\u001b[0m error_msgs\u001b[38;5;241m.\u001b[39minsert(\n\u001b[0;32m 2185\u001b[0m \u001b[38;5;241m0\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mMissing key(s) in state_dict: \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m. \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[0;32m 2186\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mk\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m missing_keys)))\n\u001b[0;32m 2188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(error_msgs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m-> 2189\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mError(s) in loading state_dict for \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[0;32m 2190\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(error_msgs)))\n\u001b[0;32m 2191\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _IncompatibleKeys(missing_keys, unexpected_keys)\n", - "\u001b[1;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for RobertaForMaskedLM:\n\tMissing key(s) in state_dict: \"lm_head.bias\". 
" - ] + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ "# Testing RobertaForMaskedLM\n", "\n", - "my_pool = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=key)\n", - "state = my_pool.to_state_dict()\n", + "my_mlm = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=key)\n", + "state = my_mlm.to_state_dict()\n", + "\n", + "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + "state[\"lm_head.bias\"] = torch.zeros(hf_config.vocab_size)\n", + "\n", + "print(state.keys())\n", + "\n", + "hf_mlm = hf_roberta.RobertaForMaskedLM(hf_config)\n", + "hf_mlm.load_state_dict(state, strict=True)\n", + "\n", + "# Testing RobertaModel\n", + "\n", + "key_rob, key_head = jrandom.split(key, 2)\n", + "\n", + "my_model = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=False, key=key_rob)\n", + "state = my_model.to_state_dict()\n", "\n", "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", "print(state.keys())\n", "\n", - "hf_pool = hf_roberta.RobertaForMaskedLM(hf_config)\n", - "hf_pool.load_state_dict(state, strict=True)\n", + "hf_model = hf_roberta.RobertaModel(hf_config, add_pooling_layer=False)\n", + "hf_model.load_state_dict(state, strict=True)\n", + "\n", + "# Testing RobertaLMHead\n", + "\n", + "my_head = my_roberta.RobertaLMHead.init(Vocab, my_config, key=key_head)\n", + "state = my_head.to_state_dict()\n", + "\n", + "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + "state[\"bias\"] = torch.zeros(hf_config.vocab_size)\n", + "\n", + "print(state.keys())\n", + "\n", + "hf_head = hf_roberta.RobertaLMHead(hf_config)\n", + "hf_head.load_state_dict(state, strict=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Checking RobertaModel\n", + "Success!!!\n", + "Iteration 0, Precision 1:\t1.0\n", + "Iteration 7, Precision 1e-07:\t0.9986305236816406\n", + "Iteration 8, Precision 1e-08:\t0.9982020060221354\n", + "Iteration 14, Precision 1e-14:\t0.9981346130371094\n", + "Accuracy: 1.0\n", + "Checking Roberta Model + LM head\n", + "Success!!!\n", + "Iteration 0, Precision 1:\t1.0\n", + "Iteration 6, Precision 1e-06:\t0.9996785785337711\n", + "Iteration 7, Precision 1e-07:\t0.9964101978265194\n", + "Iteration 8, Precision 1e-08:\t0.995718628767532\n", + "Iteration 14, Precision 
1e-14:\t0.9956369328496468\n", + "Accuracy: 1.0\n", + "Checking MLM\n", + "Fail :((((\n", + "Iteration 0, Precision 1:\t1.0\n", + "Iteration 1, Precision 0.1:\t0.6762516280898737\n", + "Iteration 2, Precision 0.01:\t0.0813977132137173\n", + "Iteration 3, Precision 0.001:\t0.00872603715930568\n", + "Iteration 4, Precision 0.0001:\t0.0014499908298517856\n", + "Iteration 5, Precision 1e-05:\t0.000723081729334527\n", + "Iteration 14, Precision 1e-14:\t0.0006434452091415498\n", + "Accuracy: 0.0014499908298517856\n", + "Checking my RobertaModel + LM head and MLM\n", + "Success!!!\n", + "Iteration 0, Precision 1:\t1.0\n", + "Iteration 14, Precision 1e-14:\t1.0\n", + "Accuracy: 1.0\n", + "Checking hf RobertaModel + LM head and MLM\n", + "Fail :((((\n", + "Iteration 0, Precision 1:\t1.0\n", + "Iteration 1, Precision 0.1:\t0.6762513949505122\n", + "Iteration 2, Precision 0.01:\t0.0813977132137173\n", + "Iteration 3, Precision 0.001:\t0.008726678292549488\n", + "Iteration 4, Precision 0.0001:\t0.0014501462560927087\n", + "Iteration 5, Precision 1e-05:\t0.0007229263030936039\n" + ] + } + ], + "source": [ + "k_rob, k_lm = jrandom.split(key, 2)\n", + "\n", + "# MLM\n", + "\n", + "my_output_mlm = my_mlm(input_ids = input_ids, attention_mask=mask, key=key)\n", + "hf_output_mlm = hf_mlm(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", + "\n", + "# Model + LM\n", + "\n", + "my_output_model = my_model(input_ids = input_ids, attention_mask=mask, key=k_rob)\n", + "my_output = my_head(my_output_model[0], key=k_lm)\n", + "\n", + "hf_output_model = hf_model(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", + "hf_output = hf_head(hf_output_model[0])\n", + "\n", + "# # MLM\n", + "# my_output_mlm = my_mlm(input_embeds = input_embeds, attention_mask=mask, key=key)\n", + "# hf_output_mlm = hf_mlm(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n", + "\n", + "# # Model + LM\n", + "# my_output_model = my_model(input_embeds = input_embeds, attention_mask=mask, key=k_rob)\n", + "# my_output = my_head(my_output_model[0], key=k_lm)\n", + "\n", + "# hf_output_model = hf_model(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n", + "# hf_output = hf_head(hf_output_model[0])\n", + "\n", + "# Checks\n", + "\n", + "print(\"Checking RobertaModel\")\n", + "check(my_output_model[0].array, hf_output_model[0].detach(), pppp=True)\n", + "print(\"Checking Roberta Model + LM head\")\n", + "check(my_output.array, hf_output.detach(), pppp=True)\n", + "print(\"Checking MLM\")\n", + "check(my_output_mlm.array, hf_output_mlm[0].detach(), pppp=True)\n", + "\n", + "print(\"Checking my RobertaModel + LM head and MLM\")\n", + "check(my_output.array, my_output_mlm.array, pppp=True)\n", + "print(\"Checking hf RobertaModel + LM head and MLM\")\n", + "check(hf_output.detach(), hf_output_mlm[0].detach(), pppp=True)\n", "\n", - "my_output = my_pool(input_ids = input_ids, attention_mask=mask, key=key)\n", - "hf_output = hf_pool(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", "\n", - "check(my_output[0].array, hf_output[0].detach(), ppp=True)" + "# Notes\n", + "# embeds works between hf and my, and within model->head and mlm \n", + "# ids does not work between hf and my for mlm or within hf for model->head and mlm - so hf mlm is doing something weird." 
] } ], From 53fd8d23061acfd0e936c9b86006f386874e051f Mon Sep 17 00:00:00 2001 From: prady-saligram Date: Mon, 5 Aug 2024 14:51:47 -0700 Subject: [PATCH 06/29] [WIP] Re-implements MLM training objective --- src/levanter/data/text.py | 56 ++++++++++++++----------- src/levanter/models/lm_model.py | 74 +++++++++++++++++++++++---------- 2 files changed, 83 insertions(+), 47 deletions(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 4e625ee8b..dfc7df4ea 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -26,13 +26,11 @@ from levanter.data.mixture import MixtureDataset, StopStrategy -# intercept the logging nonsense here from levanter.logging import silence_transformer_nag # noqa from levanter.models.attention import AttentionMask -from levanter.models.lm_model import LmExample +from levanter.models.lm_model import MaskedLmExample, LmExample from levanter.utils.hf_utils import num_cpus_used_by_tokenizer - silence_transformer_nag() # noqa from transformers import BatchEncoding, PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast # noqa @@ -54,7 +52,6 @@ from levanter.shapes import NamedShapeSpec, ShapeSpec # noqa from levanter.utils.jax_utils import use_cpu_device # noqa - logger = logging.getLogger("levanter.data.text") # TASKS: @@ -65,13 +62,14 @@ DEFAULT_IGNORE_INDEX = -100 # Mirrors pytorch's default ignore index -class MaskedLmDataset(ShardableDataset[LmExample]): +class MaskedLmDataset(ShardableDataset[MaskedLmExample]): def __init__( self, dataset: ShardableDataset[np.ndarray], QPos: Axis, KPos: Axis, mask_prob: float = 0.15, + noise_prob: float = 0.1, key: Optional[PRNGKeyArray] = None, ignore_index: Optional[int] = DEFAULT_IGNORE_INDEX, ): @@ -79,6 +77,7 @@ def __init__( self.QPos = QPos self.KPos = KPos self.mask_prob = mask_prob + self.noise_prob = noise_prob self.key = key self.ignore_id = ignore_index if ignore_index is not None else DEFAULT_IGNORE_INDEX @@ -87,10 +86,10 @@ def __init__( def shard(self, shard_id: int, num_shards: int) -> "MaskedLmDataset": return MaskedLmDataset( - self.dataset.shard(shard_id, num_shards), self.QPos, self.KPos, self.mask_prob, self.key, self.ignore_id + self.dataset.shard(shard_id, num_shards), self.QPos, self.KPos, self.mask_prob, self.noise_prob, self.key, self.ignore_id ) - def __iter__(self) -> Iterator[LmExample]: + def __iter__(self) -> Iterator[MaskedLmExample]: key = self.key sharding = jax.sharding.SingleDeviceSharding(jax.local_devices(backend="cpu")[0]) @@ -98,31 +97,44 @@ def __iter__(self) -> Iterator[LmExample]: @functools.partial(eqx.filter_jit, out_shardings=sharding) def _create_mlm_example(tokens, key): tokens_array = tokens.array - - example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id) - + targets = tokens_array.copy() + if self.mask_prob > 0: this_key, key = jax.random.split(key) mask_shape = tokens_array.shape mask = jax.random.bernoulli(this_key, self.mask_prob, mask_shape) - # Create a mask for 80% [MASK], 10% random, 10% original token rand = jax.random.uniform(this_key, mask_shape) mask_token = jnp.where(rand < 0.8, self.ignore_id, tokens_array) - mask_token = jnp.where((rand >= 0.8) & (rand < 0.9), tokens_array, mask_token) random_tokens = jax.random.randint(this_key, mask_shape, 0, tokens_array.max() + 1) - masked_tokens = jnp.where(mask, mask_token, random_tokens) + mask_token = jnp.where((rand >= 0.8) & (rand < 0.8 + self.noise_prob), random_tokens, mask_token) + masked_tokens = jnp.where(mask, mask_token, tokens_array) + + # Set targets to 
the original tokens where mask is True, otherwise set to ignore_id + targets = jnp.where(mask, tokens_array, self.ignore_id) masked_tokens_named = hax.named(masked_tokens, self.QPos) - example = dataclasses.replace(example, tokens=masked_tokens_named) + targets_named = hax.named(targets, self.QPos) + + attn_mask_shape = (tokens_array.shape[0], tokens_array.shape[0]) + attn_mask = hax.named(jnp.ones(attn_mask_shape, dtype=jnp.bool_), (self.QPos, self.KPos)) + + example = MaskedLmExample.masked_lm(tokens=masked_tokens_named, targets=targets_named, ignore_id=self.ignore_id, attn_mask=attn_mask) + else: + targets_named = hax.named(targets, self.QPos) + attn_mask_shape = (tokens_array.shape[0], tokens_array.shape[0]) + attn_mask = hax.named(jnp.ones(attn_mask_shape, dtype=jnp.bool_), (self.QPos, self.KPos)) + + example = MaskedLmExample.masked_lm(tokens=tokens, targets=targets_named, ignore_id=self.ignore_id, attn_mask=attn_mask) return example - for tokens in self.dataset: - tokens_array = jnp.array(tokens) - tokens_named = hax.named(tokens_array, self.QPos) - example = _create_mlm_example(tokens_named, key) - yield example + for tokens in self.dataset: + tokens_array = jnp.array(tokens) + tokens_named = hax.named(tokens_array, self.QPos) + example = _create_mlm_example(tokens_named, key) + yield example + class CausalLmDataset(ShardableDataset[LmExample]): @@ -155,7 +167,6 @@ def __iter__(self) -> Iterator[LmExample]: sharding = jax.sharding.SingleDeviceSharding(jax.local_devices(backend="cpu")[0]) with use_cpu_device(): - @functools.partial(eqx.filter_jit, out_shardings=sharding) def _create_lm_example(tokens, key): tokens = hax.named(tokens, self.QPos) @@ -163,10 +174,6 @@ def _create_lm_example(tokens, key): example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id) if self.fcm_prob > 0: - # masks for attention - # We support forgetful causal masking (FCM) which is a technique that improves training speed by - # randomly masking out some of the context. This is a bit like dropout, but it's applied to the attention - # mask instead of the activations. It's described in https://arxiv.org/abs/2210.13432 assert self.key is not None this_key, key = jax.random.split(key) fcm_mask = hax.nn.attention.forgetful_causal_mask(self.KPos, self.fcm_prob, key=this_key) @@ -181,7 +188,6 @@ def _create_lm_example(tokens, key): - class TokenSeqDataset(ShardableDataset[np.ndarray]): """ A dataset that yields sequences of tokens of fixed length from a TokenizedDocumentCache. 
diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index 543c6a5ca..edcbb59f9 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -3,6 +3,7 @@ import draccus import equinox as eqx +import jax import jax.numpy as jnp from jax.random import PRNGKey @@ -12,15 +13,36 @@ from levanter.models.attention import AttentionMask - LmConfigT = TypeVar("LmConfigT", bound="LmConfig") LmT = TypeVar("LmT", bound="LmHeadModel") +class MaskedLmExample(eqx.Module): + tokens: hax.NamedArray + loss_mask: hax.NamedArray + attn_mask: hax.NamedArray + targets: Optional[hax.NamedArray] = None + + @staticmethod + def masked_lm( + tokens: hax.NamedArray, targets: hax.NamedArray, attn_mask: hax.NamedArray, ignore_id: Optional[int] = None + ) -> "MaskedLmExample": + Pos = tokens.axes[0] + + mask = tokens.array != targets.array + loss_mask = hax.named(mask.astype(jnp.float32), Pos) + + if ignore_id is not None: + ignore_mask = targets.array != ignore_id + loss_mask = loss_mask * hax.named(ignore_mask.astype(jnp.float32), Pos) + + return MaskedLmExample(tokens=tokens, targets=targets, loss_mask=loss_mask, attn_mask=attn_mask) + class LmExample(eqx.Module): tokens: hax.NamedArray loss_mask: hax.NamedArray - attn_mask: AttentionMask | NamedArray = AttentionMask.causal() + attn_mask: hax.NamedArray + targets: Optional[hax.NamedArray] = None @staticmethod def causal( @@ -34,20 +56,38 @@ def causal( Pos = tokens.axes[0] - # don't predict the last token. if loss_mask is None: loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32) if ignore_id is not None: - # we don't compute loss for any tokens matching the ignore index ignore_mask = hax.roll(tokens, -1, Pos) != ignore_id loss_mask = loss_mask * ignore_mask attn_mask = AttentionMask.causal() return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask) + @staticmethod + def masked_lm( + tokens: hax.NamedArray, targets: hax.NamedArray, attn_mask: hax.NamedArray, ignore_id: Optional[int] = None + ) -> "LmExample": + Pos = tokens.axes[0] + + mask = tokens.array != targets.array + loss_mask = mask.astype(jnp.float32) + + if ignore_id is not None: + ignore_mask = targets.array != ignore_id + loss_mask = loss_mask * ignore_mask.astype(jnp.float32) + + print(f"tokens shape: {tokens.shape}") + print(f"targets shape: {targets.shape}") + print(f"loss_mask shape: {loss_mask.shape}") + print(f"attn_mask shape: {attn_mask.shape}") + + return LmExample(tokens=tokens, targets=targets, loss_mask=loss_mask, attn_mask=attn_mask) + + -# TODO: for some reason, mypy doesn't like the discover_packages_path argument? class LmConfig(draccus.PluginRegistry, abc.ABC, Generic[LmT], discover_packages_path="levanter.models"): # type: ignore @property @abc.abstractmethod @@ -70,12 +110,7 @@ def flops_per_token(self, vocab_size: int) -> Optional[float]: def build(self, Vocab: Axis, *, key: PRNGKey) -> "LmT": return self.model_type.init(Vocab, self, key=key) # type: ignore - class LmHeadModel(Generic[LmConfigT], abc.ABC): - """ - Superclass for models with a language modeling head. 
- """ - @property @abc.abstractmethod def config(self) -> LmConfigT: @@ -103,14 +138,11 @@ def init(cls, Vocab: Axis, config: LmConfigT, *, key: PRNGKey) -> "LmHeadModel[L def __call__( self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None ) -> NamedArray: - pass + print(f"input_ids shape: {input_ids.shape}") + print(f"attn_mask shape: {attn_mask.shape}") @abc.abstractmethod def resize_vocab(self, new_size: int, key: Optional[PRNGKey] = None) -> "LmHeadModel[LmConfigT]": - """ - Resizes the vocabulary of the model. Key may be provided to use random initialization, otherwise, there - should be some deterministic initialization of any new parameters. - """ pass def compute_loss( @@ -121,15 +153,13 @@ def compute_loss( reduction: Optional[hax.ReductionFunction] = hax.mean, reduction_axis: Optional[hax.AxisSelection] = None, ) -> jnp.ndarray | NamedArray: - """ - Computes the cross-entropy loss for a language modeling example. If reduction is not None, the loss is reduced - across the reduction axis (with reduction_axis=None meaning all axes). If reduction is None, the loss is not - reduced, and the result is a named array with axes (*batch axes, sequence_length). - """ logits = self(example.tokens, example.attn_mask, key=key) - # TODO: would be nice if we made the dtype configurable logits = logits.astype(jnp.float32) - targets = hax.roll(example.tokens, -1, axis=self.Pos.name) + if example.targets is not None: + targets = example.targets + else: + targets = hax.roll(example.tokens, -1, axis=self.Pos.name) + target_y = hax.nn.one_hot(targets, self.Vocab, dtype=logits.dtype) loss = cross_entropy_loss( logits, self.Vocab, target_y, reduction, reduction_axis=reduction_axis, where=example.loss_mask From dcd45b209b81946efa6a57253545c03a618db28d Mon Sep 17 00:00:00 2001 From: prady-saligram Date: Tue, 6 Aug 2024 11:38:09 -0700 Subject: [PATCH 07/29] Adds error handling and reverts LmExample class to original --- src/levanter/models/lm_model.py | 33 ++++++++++----------------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index edcbb59f9..c36e0e622 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -26,6 +26,15 @@ class MaskedLmExample(eqx.Module): def masked_lm( tokens: hax.NamedArray, targets: hax.NamedArray, attn_mask: hax.NamedArray, ignore_id: Optional[int] = None ) -> "MaskedLmExample": + if tokens.ndim != 1: + raise ValueError("tokens must be a 1D array") + + if not jnp.issubdtype(tokens.dtype, jnp.integer): + raise ValueError("tokens must be an integer array") + + if tokens.shape != targets.shape: + raise ValueError("tokens and targets must have the same shape") + Pos = tokens.axes[0] mask = tokens.array != targets.array @@ -41,8 +50,7 @@ def masked_lm( class LmExample(eqx.Module): tokens: hax.NamedArray loss_mask: hax.NamedArray - attn_mask: hax.NamedArray - targets: Optional[hax.NamedArray] = None + attn_mask: AttentionMask | NamedArray = AttentionMask.causal() @staticmethod def causal( @@ -66,27 +74,6 @@ def causal( attn_mask = AttentionMask.causal() return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask) - @staticmethod - def masked_lm( - tokens: hax.NamedArray, targets: hax.NamedArray, attn_mask: hax.NamedArray, ignore_id: Optional[int] = None - ) -> "LmExample": - Pos = tokens.axes[0] - - mask = tokens.array != targets.array - loss_mask = mask.astype(jnp.float32) - - if ignore_id is not None: - 
ignore_mask = targets.array != ignore_id - loss_mask = loss_mask * ignore_mask.astype(jnp.float32) - - print(f"tokens shape: {tokens.shape}") - print(f"targets shape: {targets.shape}") - print(f"loss_mask shape: {loss_mask.shape}") - print(f"attn_mask shape: {attn_mask.shape}") - - return LmExample(tokens=tokens, targets=targets, loss_mask=loss_mask, attn_mask=attn_mask) - - class LmConfig(draccus.PluginRegistry, abc.ABC, Generic[LmT], discover_packages_path="levanter.models"): # type: ignore @property From 6f21e0db1e99d2b4febea4cb81bee1b2d3fba38b Mon Sep 17 00:00:00 2001 From: JulienDarve Date: Mon, 12 Aug 2024 17:22:03 -0700 Subject: [PATCH 08/29] Testing Modifications --- src/levanter/models/roberta.py | 12 +- src/levanter/models/testing.ipynb | 1041 +++++++++++++++-------------- 2 files changed, 552 insertions(+), 501 deletions(-) diff --git a/src/levanter/models/roberta.py b/src/levanter/models/roberta.py index 90f808fa0..3db771350 100644 --- a/src/levanter/models/roberta.py +++ b/src/levanter/models/roberta.py @@ -136,7 +136,7 @@ class RobertaConfig(HFCompatConfig): reference_checkpoint: str = "FacebookAI/roberta-base" tokenizer: Optional[str] = None - # Axis + # Axes Pos = property(lambda self: Axis(name="position", size=self.max_position_embeddings)) KeyPos = property(lambda self: self.Pos.alias("key_position")) Embed = property(lambda self: Axis(name="embed", size=self.hidden_size)) @@ -147,10 +147,7 @@ class RobertaConfig(HFCompatConfig): Mlp = property(lambda self: Axis(name="mlp", size=self.intermediate_size)) HeadSize = property(lambda self: Axis(name="head_size", size=self.hidden_size // self.num_attention_heads)) - def __post_init__(self): - # TODO - pass - + @classmethod def from_hf_config(cls, hf_config: HfConfig) -> "RobertaConfig": return RobertaConfig( @@ -707,7 +704,7 @@ def __call__( attention_mask = hax.ones(input_axes) # Attention mask from mask to actual numbers - attention_mask = (attention_mask == 0) * -1e9 + attention_mask = (attention_mask == 0) * -jnp.inf embedding_output = self.embeddings.embed(input_ids, token_type_ids, position_ids, input_embeds, key=k_emb) sequence_output = self.encoder(embedding_output, attention_mask=attention_mask, key=k_e) @@ -782,7 +779,6 @@ def __call__( token_type_ids: Optional[NamedArray] = None, position_ids: Optional[NamedArray] = None, input_embeds: Optional[NamedArray] = None, - labels: Optional[NamedArray] = None, *, key=None ) -> Tuple[NamedArray]: @@ -798,7 +794,7 @@ def __call__( k_rob, k_lm = maybe_rng_split(key, 2) outputs = self.roberta( - input_ids=input_ids, + input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, diff --git a/src/levanter/models/testing.ipynb b/src/levanter/models/testing.ipynb index f16e0fcdd..1b9e231af 100644 --- a/src/levanter/models/testing.ipynb +++ b/src/levanter/models/testing.ipynb @@ -11,7 +11,7 @@ "text": [ "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", - "2024-08-01 14:42:02,038\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n" + "2024-08-12 14:00:57,074\tINFO util.py:154 -- Missing packages: ['ipywidgets']. 
Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n" ] } ], @@ -22,6 +22,7 @@ "import torch\n", "import haliax as hax\n", "import jax.random as jrandom\n", + "import jax.numpy as jnp\n", "import numpy as np" ] }, @@ -32,72 +33,77 @@ "outputs": [], "source": [ "from transformers import AutoConfig\n", + "from time import time\n", "\n", "hf_model_str = \"FacebookAI/roberta-base\"\n", "\n", "hf_config = AutoConfig.from_pretrained(hf_model_str)\n", "hf_config.hidden_dropout_prob = 0\n", "hf_config.attention_probs_dropout_prob = 0\n", - "my_config = my_roberta.RobertaConfig.from_hf_config(hf_config)\n", - "\n", - "key = jrandom.PRNGKey(0)" + "hf_config.pad_token_id = -1\n", + "my_config = my_roberta.RobertaConfig.from_hf_config(hf_config)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "seed: 1723496463\n" + ] + } + ], "source": [ + "seed = time()\n", + "print(f\"seed: {int(seed)}\")\n", + "key_vars = jrandom.PRNGKey(int(seed))\n", + "\n", "EmbedAtt = my_config.EmbedAtt\n", "Embed = my_config.Embed\n", "Mlp = my_config.Mlp\n", "Pos = my_config.Pos\n", "KeyPos = my_config.KeyPos\n", + "Heads = my_config.Heads\n", + "\n", "Batch = hax.Axis(\"batch\", 2)\n", + "Vocab = hax.Axis(\"vocab\", my_config.vocab_size)\n", "\n", - "x_embed = hax.random.normal(key, (Batch, Pos, Embed))\n", - "x_embed_att = hax.random.normal(key, (Batch, Pos, EmbedAtt))\n", - "x_mlp = hax.random.normal(key, (Batch, Pos, Mlp))\n", - "# x = x[{\"position\": slice(Pos.size-2)}]\n", + "keys = jrandom.split(key_vars, 6)\n", "\n", - "x_embed_torch = torch.from_numpy(np.array(x_embed.array))\n", - "x_embed_att_torch = torch.from_numpy(np.array(x_embed_att.array))\n", - "x_mlp_torch = torch.from_numpy(np.array(x_mlp.array))\n", + "input_ids = hax.random.randint(keys[0], (Batch, Pos), minval = 3, maxval = my_config.vocab_size)\n", + "input_ids_torch = torch.from_numpy(np.array(input_ids.array))\n", + "input_embeds = hax.random.normal(keys[1], (Batch, Pos, Embed))\n", + "input_embeds_torch = torch.from_numpy(np.array(input_embeds.array))\n", "\n", - "mask = hax.ones((Batch, Pos))[{\"position\": slice(0,-2)}]\n", + "# mask = hax.random.randint(keys[2], (Batch, Pos), minval = 0, maxval = 2)\n", + "mask = hax.ones((Batch, Pos))\n", "mask_torch = torch.from_numpy(np.array(mask.array))\n", - "# mask_torch = torch.ones((2, hf_config.num_attention_heads, hf_config.max_position_embeddings -2, hf_config.max_position_embeddings -2))\n", - "\n", - "Vocab = hax.Axis(\"vocab\", my_config.vocab_size)\n", + "mask_torch_materialized = torch.ones((2, hf_config.num_attention_heads, hf_config.max_position_embeddings, hf_config.max_position_embeddings))\n", "\n", - "input_embeds = hax.random.uniform(key, (Batch, Pos, Embed))\n", - "input_embeds = input_embeds[{\"position\": slice(0,-2)}]\n", + "features = input_embeds[{\"position\": 0}]\n", + "features_torch = torch.from_numpy(np.array(features.array))\n", "\n", - "input_ids = hax.random.randint(key, (Batch, Pos), minval = 0, maxval = my_config.vocab_size)\n", - "input_ids = input_ids[{\"position\": slice(0,-2)}]\n", - "\n", - "input_embeds_torch = torch.from_numpy(np.array(input_embeds.array))\n", - "input_ids_torch = torch.from_numpy(np.array(input_ids.array))\n", - "\n", - "features = hax.random.normal(key, (Batch, Embed))\n", - "features_torch = torch.from_numpy(np.array(features.array))" + "x_embed_att = input_embeds.rename({\"embed\": 
\"embed_att\"})\n", + "x_embed_att_torch = torch.from_numpy(np.array(x_embed_att.array))\n", + "x_mlp = hax.random.normal(keys[5], (Batch, Pos, Mlp))\n", + "x_mlp_torch = torch.from_numpy(np.array(x_mlp.array))" ] }, { - "cell_type": "code", - "execution_count": 4, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "# from transformers import AutoTokenizer\n", - "\n", - "# tokenizer = AutoTokenizer.from_pretrained(hf_model_str)" + "Notes:\n", + "- Random Mask causes RobertaModel to have different output" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 30, "metadata": {}, "outputs": [], "source": [ @@ -107,13 +113,21 @@ " # print(my_output.shape)\n", " # print(hf_output.shape)\n", "\n", - " success = np.isclose(hf_output, np.array(my_output), rtol=precision, atol=precision).all()\n", + " acc = np.isclose(hf_output, np.array(my_output), rtol=precision, atol=precision).mean()\n", "\n", - " if success:\n", - " print(\"Success!!!\")\n", - " else:\n", - " print(\"Fail :((((\")\n", + " stats = (torch.tensor(np.array(my_output)).abs().mean(), torch.tensor(np.array(hf_output)).abs().mean())\n", " \n", + " if p: \n", + " acc = np.isclose(hf_output, np.array(my_output), rtol=precision, atol=precision).mean()\n", + " print(f\"Accuracy: {acc}\")\n", + " print(f\"Jax:\\n{torch.tensor(np.array(my_output))}\\nTorch:\\n{hf_output}\")\n", + "\n", + " if pp:\n", + " diff = torch.tensor(np.array(my_output)) - hf_output\n", + " print(f\"Mean: {diff.abs().mean()}\")\n", + " print(f\"Stdev: {diff.std()}\")\n", + " print(f\"Difference:\\n{diff}\")\n", + "\n", " if ppp:\n", " acc_prev = None\n", " for i in range(15):\n", @@ -126,424 +140,461 @@ " print(f\"Iteration {i}, Precision {prec}:\\t{acc}\")\n", " acc_prev = acc\n", " print(f\"Iteration {i}, Precision {prec}:\\t{acc}\")\n", - "\n", - " if pppp:\n", - " acc = np.isclose(hf_output, np.array(my_output), rtol=precision, atol=precision).mean()\n", - " print(f\"Accuracy: {acc}\")\n", " \n", - " if p: \n", - " acc = np.isclose(hf_output, np.array(my_output), rtol=precision, atol=precision).mean()\n", - " print(f\"Accuracy: {acc}\")\n", - " print(f\"Jax:\\n{torch.tensor(np.array(my_output))}\\nTorch:\\n{hf_output}\")\n", - "\n", - " if pp:\n", - " diff = torch.tensor(np.array(my_output)) - hf_output\n", - " print(f\"Mean: {diff.abs().mean()}\")\n", - " print(f\"Stdev: {diff.std()}\")\n", - " print(f\"Difference:\\n{diff}\")" + " return acc, stats" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'# Testing RobertaSelfOutput\\n\\nmy_self_output = my_roberta.RobertaSelfOutput.init(my_config, key=key)\\nstate = my_self_output.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nprint(state.keys())\\n\\nhf_self_output = hf_roberta.RobertaSelfOutput(hf_config)\\nhf_self_output.load_state_dict(state, strict=True)\\n\\nmy_output = my_self_output(x_embed_att, x_embed, key=key)\\nhf_output = hf_self_output(x_embed_att_torch, x_embed_torch)\\n\\ncheck(my_output.array, hf_output.detach(), ppp=True)'" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "'''# Testing RobertaSelfOutput\n", + "# Testing RobertaSelfOutput\n", "\n", - "my_self_output = my_roberta.RobertaSelfOutput.init(my_config, key=key)\n", - "state = my_self_output.to_state_dict()\n", + "def test_RobertaSelfOutput(key):\n", + " k_1, k_2 = jrandom.split(key, 
2)\n", + " my_func = my_roberta.RobertaSelfOutput.init(my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", "\n", - "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - "print(state.keys())\n", + " # # print(state.keys())\n", "\n", - "hf_self_output = hf_roberta.RobertaSelfOutput(hf_config)\n", - "hf_self_output.load_state_dict(state, strict=True)\n", + " hf_func = hf_roberta.RobertaSelfOutput(hf_config)\n", + " hf_func.load_state_dict(state, strict=True)\n", "\n", - "my_output = my_self_output(x_embed_att, x_embed, key=key)\n", - "hf_output = hf_self_output(x_embed_att_torch, x_embed_torch)\n", + " my_output = my_func(x_embed_att, input_embeds, key=k_2)\n", + " hf_output = hf_func(x_embed_att_torch, input_embeds_torch)\n", "\n", - "check(my_output.array, hf_output.detach(), ppp=True)'''" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'# Testing RobertaSelfAttention\\n\\nmy_attn_output = my_roberta.RobertaSelfAttention.init(my_config, key=key)\\nstate = my_attn_output.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nprint(state.keys())\\n\\nhf_attn_output = hf_roberta.RobertaSelfAttention(hf_config)\\nhf_attn_output.load_state_dict(state, strict=True)\\n\\nmy_output = my_attn_output(x_embed, mask, key=key)\\nhf_output = hf_attn_output(x_embed_torch, mask_torch)\\n\\ncheck(my_output.array, hf_output[0].detach(), ppp=True)'" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "'''# Testing RobertaSelfAttention\n", + " return check(my_output.array, hf_output.detach())\n", "\n", - "my_attn_output = my_roberta.RobertaSelfAttention.init(my_config, key=key)\n", - "state = my_attn_output.to_state_dict()\n", + "# Testing RobertaSelfAttention\n", "\n", - "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "def test_RobertaSelfAttention(key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", "\n", - "print(state.keys())\n", + " my_func = my_roberta.RobertaSelfAttention.init(my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", "\n", - "hf_attn_output = hf_roberta.RobertaSelfAttention(hf_config)\n", - "hf_attn_output.load_state_dict(state, strict=True)\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - "my_output = my_attn_output(x_embed, mask, key=key)\n", - "hf_output = hf_attn_output(x_embed_torch, mask_torch)\n", + " # print(state.keys())\n", "\n", - "check(my_output.array, hf_output[0].detach(), ppp=True)'''" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'# Testing RobertaAttention\\n\\nmy_func = my_roberta.RobertaAttention.init(my_config, key=key)\\nstate = my_func.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nprint(state.keys())\\n\\nhf_func = hf_roberta.RobertaAttention(hf_config)\\nhf_func.load_state_dict(state, strict=True)\\n\\nmy_output = my_func(hidden_states=x_embed, attention_mask=mask, key=key)\\nhf_output = hf_func(hidden_states=x_embed_torch, attention_mask=mask_torch)\\n\\ncheck(my_output.array, hf_output[0].detach(), ppp=True)'" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "'''# Testing RobertaAttention\n", + " hf_func = 
hf_roberta.RobertaSelfAttention(hf_config)\n", + " hf_func.load_state_dict(state, strict=True)\n", "\n", - "my_func = my_roberta.RobertaAttention.init(my_config, key=key)\n", - "state = my_func.to_state_dict()\n", + " my_output = my_func(input_embeds, mask, key=k_2)\n", + " hf_output = hf_func(input_embeds_torch, mask_torch_materialized)\n", "\n", - "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + " return check(my_output.array, hf_output[0].detach())\n", "\n", - "print(state.keys())\n", + "# Testing RobertaAttention\n", "\n", - "hf_func = hf_roberta.RobertaAttention(hf_config)\n", - "hf_func.load_state_dict(state, strict=True)\n", + "def test_RobertaAttention(key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaAttention.init(my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", "\n", - "my_output = my_func(hidden_states=x_embed, attention_mask=mask, key=key)\n", - "hf_output = hf_func(hidden_states=x_embed_torch, attention_mask=mask_torch)\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - "check(my_output.array, hf_output[0].detach(), ppp=True)'''" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'# Testing RobertaIntermediate\\n\\nmy_int_output = my_roberta.RobertaIntermediate.init(my_config, key=key)\\nstate = my_int_output.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nprint(state.keys())\\n\\nhf_int_output = hf_roberta.RobertaIntermediate(hf_config)\\nhf_int_output.load_state_dict(state, strict=True)\\n\\nmy_output = my_int_output(x_embed, key=key)\\nhf_output = hf_int_output(x_embed_torch)\\n\\ncheck(my_output.array, hf_output.detach(), ppp=True)'" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "'''# Testing RobertaIntermediate\n", + " # print(state.keys())\n", "\n", - "my_int_output = my_roberta.RobertaIntermediate.init(my_config, key=key)\n", - "state = my_int_output.to_state_dict()\n", + " hf_func = hf_roberta.RobertaAttention(hf_config)\n", + " hf_func.load_state_dict(state, strict=True)\n", "\n", - "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + " my_output = my_func(hidden_states=input_embeds, attention_mask=mask, key=k_2)\n", + " hf_output = hf_func(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)\n", "\n", - "print(state.keys())\n", + " return check(my_output.array, hf_output[0].detach())\n", "\n", - "hf_int_output = hf_roberta.RobertaIntermediate(hf_config)\n", - "hf_int_output.load_state_dict(state, strict=True)\n", + "# Testing RobertaIntermediate\n", "\n", - "my_output = my_int_output(x_embed, key=key)\n", - "hf_output = hf_int_output(x_embed_torch)\n", + "def test_RobertaIntermediate(key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaIntermediate.init(my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", "\n", - "check(my_output.array, hf_output.detach(), ppp=True)'''" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'# Testing RobertaOutput\\n\\nmy_int_output = my_roberta.RobertaOutput.init(my_config, key=key)\\nstate = my_int_output.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nprint(state.keys())\\n\\nhf_int_output = 
hf_roberta.RobertaOutput(hf_config)\\nhf_int_output.load_state_dict(state, strict=True)\\n\\nmy_output = my_int_output(x_mlp, x_embed, key=key)\\nhf_output = hf_int_output(x_mlp_torch, x_embed_torch)\\n\\ncheck(my_output.array, hf_output.detach(), ppp=True)'" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "'''# Testing RobertaOutput\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - "my_int_output = my_roberta.RobertaOutput.init(my_config, key=key)\n", - "state = my_int_output.to_state_dict()\n", + " # print(state.keys())\n", "\n", - "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + " hf_func = hf_roberta.RobertaIntermediate(hf_config)\n", + " hf_func.load_state_dict(state, strict=True)\n", "\n", - "print(state.keys())\n", + " my_output = my_func(input_embeds, key=k_2)\n", + " hf_output = hf_func(input_embeds_torch)\n", "\n", - "hf_int_output = hf_roberta.RobertaOutput(hf_config)\n", - "hf_int_output.load_state_dict(state, strict=True)\n", + " return check(my_output.array, hf_output.detach())\n", "\n", - "my_output = my_int_output(x_mlp, x_embed, key=key)\n", - "hf_output = hf_int_output(x_mlp_torch, x_embed_torch)\n", + "# Testing RobertaOutput\n", "\n", - "check(my_output.array, hf_output.detach(), ppp=True)'''" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'# Testing RobertaLayer\\n\\nmy_func = my_roberta.RobertaLayer.init(my_config, key=key)\\nstate = my_func.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nprint(state.keys())\\n\\nhf_func = hf_roberta.RobertaLayer(hf_config)\\nhf_func.load_state_dict(state, strict=True)\\n\\nmy_output = my_func(hidden_states=x_embed, attention_mask=mask, key=key)\\nhf_output = hf_func(hidden_states=x_embed_torch, attention_mask=mask_torch)\\n\\ncheck(my_output.array, hf_output[0].detach(), False, False, True, 1e-4)'" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "'''# Testing RobertaLayer\n", + "def test_RobertaOutput(key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", "\n", - "my_func = my_roberta.RobertaLayer.init(my_config, key=key)\n", - "state = my_func.to_state_dict()\n", + " my_func = my_roberta.RobertaOutput.init(my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", "\n", - "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - "print(state.keys())\n", + " # print(state.keys())\n", "\n", - "hf_func = hf_roberta.RobertaLayer(hf_config)\n", - "hf_func.load_state_dict(state, strict=True)\n", + " hf_func = hf_roberta.RobertaOutput(hf_config)\n", + " hf_func.load_state_dict(state, strict=True)\n", "\n", - "my_output = my_func(hidden_states=x_embed, attention_mask=mask, key=key)\n", - "hf_output = hf_func(hidden_states=x_embed_torch, attention_mask=mask_torch)\n", + " my_output = my_func(x_mlp, input_embeds, key=k_2)\n", + " hf_output = hf_func(x_mlp_torch, input_embeds_torch)\n", "\n", - "check(my_output.array, hf_output[0].detach(), False, False, True, 1e-4)'''" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'# Testing RobertaEncoder\\n\\nmy_func = my_roberta.RobertaEncoder.init(my_config, key=key)\\nstate = 
my_func.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nprint(state.keys())\\n\\nhf_func = hf_roberta.RobertaEncoder(hf_config)\\nhf_func.load_state_dict(state, strict=True)\\n\\nmy_output = my_func(hidden_states=x_embed, attention_mask=mask, key=key)\\nhf_output = hf_func(hidden_states=x_embed_torch, attention_mask=mask_torch)\\n\\ncheck(my_output.array, hf_output[0].detach(), False, False, True, 1e-4)'" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "'''# Testing RobertaEncoder\n", + " return check(my_output.array, hf_output.detach())\n", "\n", - "my_func = my_roberta.RobertaEncoder.init(my_config, key=key)\n", - "state = my_func.to_state_dict()\n", + "# Testing RobertaLayer\n", "\n", - "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "def test_RobertaLayer(key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaLayer.init(my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", "\n", - "print(state.keys())\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - "hf_func = hf_roberta.RobertaEncoder(hf_config)\n", - "hf_func.load_state_dict(state, strict=True)\n", + " # print(state.keys())\n", "\n", - "my_output = my_func(hidden_states=x_embed, attention_mask=mask, key=key)\n", - "hf_output = hf_func(hidden_states=x_embed_torch, attention_mask=mask_torch)\n", + " hf_func = hf_roberta.RobertaLayer(hf_config)\n", + " hf_func.load_state_dict(state, strict=True)\n", "\n", - "check(my_output.array, hf_output[0].detach(), False, False, True, 1e-4)'''" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'# Testing RobertaEmbedding\\n\\nVocab = hax.Axis(\"vocab\", my_config.vocab_size)\\n\\ninput_embeds = hax.random.uniform(key, (Batch, Pos, Embed))\\ninput_embeds = input_embeds[{\"position\": slice(0,-2)}]\\ninput_ids = hax.random.randint(key, (Batch, Pos), minval = 0, maxval = my_config.vocab_size)\\ninput_ids = input_ids[{\"position\": slice(0,-2)}]\\n\\ninput_embeds_torch = torch.from_numpy(np.array(input_embeds.array))\\ninput_ids_torch = torch.from_numpy(np.array(input_ids.array))\\n\\nmy_func = my_roberta.RobertaEmbedding.init(Vocab, my_config, key=key)\\nstate = my_func.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nprint(state.keys())\\n\\nhf_func = hf_roberta.RobertaEmbeddings(hf_config)\\nhf_func.load_state_dict(state, strict=True)\\n\\n# my_output = my_func.embed(input_embeds=input_embeds, key=key)\\n# hf_output = hf_func(inputs_embeds=input_embeds_torch)\\n\\nmy_output = my_func.embed(input_ids=input_ids, key=key)\\nhf_output = hf_func(input_ids=input_ids_torch)\\n\\ncheck(my_output.array, hf_output.detach())'" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "'''# Testing RobertaEmbedding\n", + " my_output = my_func(hidden_states=input_embeds, attention_mask=mask, key=k_2)\n", + " hf_output = hf_func(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)\n", "\n", - "Vocab = hax.Axis(\"vocab\", my_config.vocab_size)\n", + " return check(my_output.array, hf_output[0].detach())\n", "\n", - "input_embeds = hax.random.uniform(key, (Batch, Pos, Embed))\n", - "input_embeds = input_embeds[{\"position\": slice(0,-2)}]\n", - "input_ids = hax.random.randint(key, (Batch, Pos), minval = 0, 
maxval = my_config.vocab_size)\n", - "input_ids = input_ids[{\"position\": slice(0,-2)}]\n", + "# Testing RobertaEncoder\n", "\n", - "input_embeds_torch = torch.from_numpy(np.array(input_embeds.array))\n", - "input_ids_torch = torch.from_numpy(np.array(input_ids.array))\n", + "def test_RobertaEncoder(key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaEncoder.init(my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", "\n", - "my_func = my_roberta.RobertaEmbedding.init(Vocab, my_config, key=key)\n", - "state = my_func.to_state_dict()\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + " # print(state.keys())\n", "\n", - "print(state.keys())\n", + " hf_func = hf_roberta.RobertaEncoder(hf_config)\n", + " hf_func.load_state_dict(state, strict=True)\n", "\n", - "hf_func = hf_roberta.RobertaEmbeddings(hf_config)\n", - "hf_func.load_state_dict(state, strict=True)\n", + " my_output = my_func(hidden_states=input_embeds, attention_mask=mask, key=k_2)\n", + " hf_output = hf_func(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)\n", "\n", - "# my_output = my_func.embed(input_embeds=input_embeds, key=key)\n", - "# hf_output = hf_func(inputs_embeds=input_embeds_torch)\n", + " return check(my_output.array, hf_output[0].detach())\n", "\n", - "my_output = my_func.embed(input_ids=input_ids, key=key)\n", - "hf_output = hf_func(input_ids=input_ids_torch)\n", + "# Testing RobertaEmbedding\n", "\n", - "check(my_output.array, hf_output.detach())'''" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'# Testing RobertaPooler\\n\\nmy_pool = my_roberta.RobertaPooler.init(my_config, key=key)\\nstate = my_pool.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nprint(state.keys())\\n\\nhf_pool = hf_roberta.RobertaPooler(hf_config)\\nhf_pool.load_state_dict(state, strict=True)\\n\\nmy_output = my_pool(x_embed, key=key)\\nhf_output = hf_pool(x_embed_torch)\\n\\ncheck(my_output.array, hf_output.detach(), ppp=True)'" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "'''# Testing RobertaPooler\n", + "def test_RobertaEmbedding(key, ids = True):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaEmbedding.init(Vocab, my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", "\n", - "my_pool = my_roberta.RobertaPooler.init(my_config, key=key)\n", - "state = my_pool.to_state_dict()\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + " # print(state.keys())\n", "\n", - "print(state.keys())\n", + " hf_func = hf_roberta.RobertaEmbeddings(hf_config)\n", + " hf_func.load_state_dict(state, strict=True)\n", "\n", - "hf_pool = hf_roberta.RobertaPooler(hf_config)\n", - "hf_pool.load_state_dict(state, strict=True)\n", + " if ids:\n", + " my_output = my_func.embed(input_ids=input_ids, key=k_2)\n", + " hf_output = hf_func(input_ids=input_ids_torch)\n", + " else: \n", + " my_output = my_func.embed(input_embeds=input_embeds, key=k_2)\n", + " hf_output = hf_func(inputs_embeds=input_embeds_torch)\n", "\n", - "my_output = my_pool(x_embed, key=key)\n", - "hf_output = hf_pool(x_embed_torch)\n", + " return check(my_output.array, 
hf_output.detach())\n", "\n", - "check(my_output.array, hf_output.detach(), ppp=True)'''" + "# Testing RobertaPooler\n", + "\n", + "def test_RobertaPooler(key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaPooler.init(my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", + "\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + " # print(state.keys())\n", + "\n", + " hf_func = hf_roberta.RobertaPooler(hf_config)\n", + " hf_func.load_state_dict(state, strict=True)\n", + "\n", + " my_output = my_func(input_embeds, key=k_2)\n", + " hf_output = hf_func(input_embeds_torch)\n", + "\n", + " return check(my_output.array, hf_output.detach())\n", + "\n", + "# Testing RobertaModel\n", + "\n", + "def test_RobertaModel(key, ids = True, pool = True):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=pool, key=k_1)\n", + " state = my_func.to_state_dict()\n", + "\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + " # print(state.keys())\n", + "\n", + " hf_func = hf_roberta.RobertaModel(hf_config, add_pooling_layer=pool)\n", + " hf_func.load_state_dict(state, strict=True)\n", + "\n", + " if ids:\n", + " my_output = my_func(input_ids = input_ids, attention_mask=mask, key=k_2)\n", + " hf_output = hf_func(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", + " else:\n", + " my_output = my_func(input_embeds = input_embeds, attention_mask=mask, key=k_2)\n", + " hf_output = hf_func(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n", + "\n", + " if pool:\n", + " return check(my_output[1].array, hf_output[1].detach())\n", + " else:\n", + " return check(my_output[0].array, hf_output[0].detach())\n", + "\n", + "# Testing RobertaLMHead\n", + "\n", + "def test_RobertaLMHead(key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaLMHead.init(Vocab, my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", + "\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + " state[\"bias\"] = torch.zeros(hf_config.vocab_size)\n", + "\n", + " # print(state.keys())\n", + "\n", + " hf_func = hf_roberta.RobertaLMHead(hf_config)\n", + " hf_func.load_state_dict(state, strict=True)\n", + "\n", + " my_output = my_func(features, key=k_2)\n", + " hf_output = hf_func(features_torch)\n", + "\n", + " return check(my_output.array, hf_output.detach())\n", + "\n", + "# Testing RobertaForMaskedLM\n", + "\n", + "def test_RobertaForMaskedLM(key, ids = True):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_pool = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=k_1)\n", + " state = my_pool.to_state_dict()\n", + "\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + " state[\"lm_head.bias\"] = torch.zeros(hf_config.vocab_size)\n", + "\n", + " # print(state.keys())\n", + "\n", + " hf_pool = hf_roberta.RobertaForMaskedLM(hf_config)\n", + " hf_pool.load_state_dict(state, strict=True)\n", + "\n", + " if ids:\n", + " my_output = my_pool(input_ids = input_ids, attention_mask=mask, key=k_2)\n", + " hf_output = hf_pool(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", + " else:\n", + " my_output = my_pool(input_embeds = input_embeds, attention_mask=mask, key=k_2)\n", + " hf_output = hf_pool(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, 
return_dict=False)\n", + "\n", + " return check(my_output.array, hf_output[0].detach())" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 6, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'# Testing RobertaModel\\n\\nmy_pool = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=True, key=key)\\nstate = my_pool.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nprint(state.keys())\\n\\nhf_pool = hf_roberta.RobertaModel(hf_config, add_pooling_layer=True)\\nhf_pool.load_state_dict(state, strict=True)\\n\\n# my_output = my_pool(input_ids = input_ids, attention_mask=mask, key=key)\\n# hf_output = hf_pool(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\\n\\nmy_output = my_pool(input_embeds = input_embeds, attention_mask=mask, key=key)\\nhf_output = hf_pool(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\\n\\ncheck(my_output[1].array, hf_output[1].detach(), ppp=True)'" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "'''# Testing RobertaModel\n", + "def out_func(input):\n", + " acc, stats = input\n", + " if acc < 1:\n", + " return str(acc) + \"\\t<---- here\"\n", + " else:\n", + " return str(acc)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# seed = time() + 20\n", + "# print(f\"seed: {int(seed)}\")\n", + "# key_vars = jrandom.PRNGKey(int(seed))\n", + "# keys = jrandom.split(key_vars, 15)\n", + "\n", + "# print(f\"test_RobertaSelfOutput: {out_func(test_RobertaSelfOutput(keys[0]))}\")\n", + "# print(f\"test_RobertaSelfAttention: {out_func(test_RobertaSelfAttention(keys[1]))}\")\n", + "# print(f\"test_RobertaAttention: {out_func(test_RobertaAttention(keys[2]))}\")\n", + "# print(f\"test_RobertaIntermediate: {out_func(test_RobertaIntermediate(keys[3]))}\")\n", + "# print(f\"test_RobertaOutput: {out_func(test_RobertaOutput(keys[4]))}\")\n", + "# print(f\"test_RobertaEmbedding(ids = True): {out_func(test_RobertaEmbedding(keys[7], ids = True))}\")\n", + "# print(f\"test_RobertaEmbedding(ids = False): {out_func(test_RobertaEmbedding(keys[8], ids = False))}\")\n", + "# print(f\"test_RobertaModel(ids = True, pool = True): {out_func(test_RobertaModel(keys[9], ids = True, pool = True))}\")\n", + "# print(f\"test_RobertaModel(ids = False, pool = False): {out_func(test_RobertaModel(keys[10], ids = False, pool = False))}\")\n", + "# print(f\"test_RobertaModel(ids = True, pool = True): {out_func(test_RobertaModel(keys[9], ids = True, pool = True))}\")\n", + "# print(f\"test_RobertaModel(ids = False, pool = False): {out_func(test_RobertaModel(keys[10], ids = False, pool = False))}\")\n", + "# print(f\"test_RobertaPooler: {out_func(test_RobertaPooler(keys[11]))}\")\n", + "# print(f\"test_RobertaLMHead: {out_func(test_RobertaLMHead(keys[12]))}\")\n", + "# print(f\"test_RobertaForMaskedLM(ids = True): {out_func(test_RobertaForMaskedLM(keys[13], ids = True))}\")\n", + "# print(f\"test_RobertaForMaskedLM(ids = False): {out_func(test_RobertaForMaskedLM(keys[14], ids = False))}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "def get_output_RobertaEmbedding(input, key, ids = True):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaEmbedding.init(Vocab, my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", "\n", 
- "my_pool = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=True, key=key)\n", - "state = my_pool.to_state_dict()\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + " # print(state.keys())\n", + "\n", + " hf_func = hf_roberta.RobertaEmbeddings(hf_config)\n", + " hf_func.load_state_dict(state, strict=True)\n", + "\n", + " input_torch = torch.from_numpy(np.array(input.array))\n", + "\n", + " if ids:\n", + " my_output = my_func.embed(input_ids=input, key=k_2)\n", + " hf_output = hf_func(input_ids=input_torch)\n", + " else: \n", + " my_output = my_func.embed(input_embeds=input, key=k_2)\n", + " hf_output = hf_func(inputs_embeds=input_torch)\n", + "\n", + " return check(my_output.array, hf_output.detach()), (my_output, hf_output)\n", + "\n", + "def get_output_RobertaEncoder(input, key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaEncoder.init(my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", + "\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + " # print(state.keys())\n", "\n", - "print(state.keys())\n", + " hf_func = hf_roberta.RobertaEncoder(hf_config)\n", + " hf_func.load_state_dict(state, strict=True)\n", "\n", - "hf_pool = hf_roberta.RobertaModel(hf_config, add_pooling_layer=True)\n", - "hf_pool.load_state_dict(state, strict=True)\n", + " input_torch = torch.from_numpy(np.array(input.array))\n", "\n", - "# my_output = my_pool(input_ids = input_ids, attention_mask=mask, key=key)\n", - "# hf_output = hf_pool(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", + " attention_mask = hax.ones((Batch, Heads, Pos, KeyPos)) * -jnp.inf\n", + " attention_mask_torch = torch.from_numpy(np.array(attention_mask.array))\n", "\n", - "my_output = my_pool(input_embeds = input_embeds, attention_mask=mask, key=key)\n", - "hf_output = hf_pool(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n", + " my_output = my_func(hidden_states=input, attention_mask=attention_mask, key=k_2)\n", + " hf_output = hf_func(hidden_states=input_torch, attention_mask=attention_mask_torch)\n", "\n", - "check(my_output[1].array, hf_output[1].detach(), ppp=True)'''" + " return check(my_output.array, hf_output[0].detach()), (my_output, hf_output)\n", + "\n", + "def get_output_RobertaPooler(input, key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaPooler.init(my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", + "\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + " # print(state.keys())\n", + "\n", + " hf_func = hf_roberta.RobertaPooler(hf_config)\n", + " hf_func.load_state_dict(state, strict=True)\n", + "\n", + " input_torch = torch.from_numpy(np.array(input.array))\n", + "\n", + " my_output = my_func(input, key=k_2)\n", + " hf_output = hf_func(input_torch)\n", + "\n", + " return check(my_output.array, hf_output.detach()), (my_output, hf_output)\n", + "\n", + "# Testing RobertaModel\n", + "\n", + "def get_output_RobertaModel(input, key, ids = True, pool = True):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=pool, key=k_1)\n", + " state = my_func.to_state_dict()\n", + "\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + " # print(state.keys())\n", + "\n", + " 
hf_func = hf_roberta.RobertaModel(hf_config, add_pooling_layer=pool)\n", + " hf_func.load_state_dict(state, strict=True)\n", + " \n", + " input_torch = torch.from_numpy(np.array(input.array))\n", + " \n", + " # attention_mask = hax.ones((Batch, Heads, Pos, KeyPos)) * -jnp.inf\n", + " # attention_mask_torch = torch.from_numpy(np.array(attention_mask.array))\n", + "\n", + " if ids:\n", + " my_output = my_func(input_ids = input, attention_mask=mask, key=k_2)\n", + " hf_output = hf_func(input_ids = input_torch, attention_mask=mask_torch, return_dict=False)\n", + " else:\n", + " my_output = my_func(input_embeds = input, attention_mask=mask, key=k_2)\n", + " hf_output = hf_func(inputs_embeds = input_torch, attention_mask=mask_torch, return_dict=False)\n", + "\n", + " if pool:\n", + " return check(my_output[1].array, hf_output[1].detach()), (my_output, hf_output)\n", + " else:\n", + " return check(my_output[0].array, hf_output[0].detach()), (my_output, hf_output)\n" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 33, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "'# Testing RobertaLMHead\\n\\nmy_pool = my_roberta.RobertaLMHead.init(Vocab, my_config, key=key)\\nstate = my_pool.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nstate[\"bias\"] = torch.zeros(hf_config.vocab_size)\\n\\nprint(state.keys())\\n\\nhf_pool = hf_roberta.RobertaLMHead(hf_config)\\nhf_pool.load_state_dict(state, strict=True)\\n\\nmy_output = my_pool(features, key=key)\\nhf_output = hf_pool(features_torch)\\n\\ncheck(my_output.array, hf_output.detach(), ppp=True)'" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "seed: 1723505034\n", + "{'batch': 2, 'position': 514, 'embed': 768}\n", + "(tensor(0.7984), tensor(0.7984))\n", + "(tensor(nan), tensor(nan))\n", + "(tensor(nan), tensor(nan))\n", + "(tensor(0.5408), tensor(0.5408))\n", + "acc_embeds: 1.0\n", + "acc_enc: 0.0\n", + "acc_pool: 0.0\n", + "acc_model: 1.0\n", + "my comparison pool: (0.0, (tensor(nan), tensor(0.5408)))\n", + "my comparison no pool: (0.0, (tensor(nan), tensor(0.7873)))\n", + "hf comparison: (0.0, (tensor(nan), tensor(0.5408)))\n" + ] } ], "source": [ - "'''# Testing RobertaLMHead\n", - "\n", - "my_pool = my_roberta.RobertaLMHead.init(Vocab, my_config, key=key)\n", - "state = my_pool.to_state_dict()\n", - "\n", - "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", - "\n", - "state[\"bias\"] = torch.zeros(hf_config.vocab_size)\n", - "\n", - "print(state.keys())\n", - "\n", - "hf_pool = hf_roberta.RobertaLMHead(hf_config)\n", - "hf_pool.load_state_dict(state, strict=True)\n", - "\n", - "my_output = my_pool(features, key=key)\n", - "hf_output = hf_pool(features_torch)\n", - "\n", - "check(my_output.array, hf_output.detach(), ppp=True)'''" + "seed = time() + 30\n", + "print(f\"seed: {int(seed)}\")\n", + "key = jrandom.PRNGKey(int(seed))\n", + "\n", + "k_t, k_emb, k_p = jrandom.split(key, 3)\n", + "\n", + "input = input_embeds\n", + "\n", + "(acc_embeds, stats_embed), (my_out_embeds, hf_out_embeds) = get_output_RobertaEmbedding(input, k_t, ids = False)\n", + "print(stats_embed)\n", + "(acc_enc, stats_enc), (my_out_enc, hf_out_enc) = get_output_RobertaEncoder(my_out_embeds, k_emb)\n", + "print(stats_enc)\n", + "(acc_pool, stats_pool), (my_out_pool, hf_out_pool) = get_output_RobertaPooler(my_out_enc, k_p)\n", + "print(stats_pool)\n", + "\n", + "(acc_model, 
stats_model), (my_out_model, hf_out_model) = get_output_RobertaModel(input, key, ids = False, pool = True)\n", + "print(stats_model)\n", + "\n", + "print(f\"acc_embeds: {acc_embeds}\")\n", + "print(f\"acc_enc: {acc_enc}\")\n", + "print(f\"acc_pool: {acc_pool}\")\n", + "print(f\"acc_model: {acc_model}\")\n", + "print(f\"my comparison pool: {check(my_out_pool.array, my_out_model[1].array)}\")\n", + "print(f\"my comparison no pool: {check(my_out_enc.array, my_out_model[0].array)}\")\n", + "print(f\"hf comparison: {check(hf_out_pool.detach(), hf_out_model[1].detach())}\")" ] }, { @@ -552,52 +603,108 @@ "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "'# Testing RobertaForMaskedLM\\n\\nmy_pool = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=key)\\nstate = my_pool.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nstate[\"lm_head.bias\"] = torch.zeros(hf_config.vocab_size)\\n\\nprint(state.keys())\\n\\nhf_pool = hf_roberta.RobertaForMaskedLM(hf_config)\\nhf_pool.load_state_dict(state, strict=True)\\n\\nmy_output_MLM = my_pool(input_ids = input_ids, attention_mask=mask, key=key)\\nhf_output_MLM = hf_pool(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\\n\\n# my_output = my_pool(input_embeds = input_embeds, attention_mask=mask, key=key)\\n# hf_output = hf_pool(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\\n\\ncheck(my_output_MLM[0].array, hf_output_MLM[0].detach(), ppp=True)'" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "NamedArray(float32{'batch': 2, 'embed': 768},\n", + "[[nan nan nan ... nan nan nan]\n", + " [nan nan nan ... nan nan nan]])\n", + "NamedArray(float32{'batch': 2, 'position': 514, 'embed': 768},\n", + "[[[nan nan nan ... nan nan nan]\n", + " [nan nan nan ... nan nan nan]\n", + " [nan nan nan ... nan nan nan]\n", + " ...\n", + " [nan nan nan ... nan nan nan]\n", + " [nan nan nan ... nan nan nan]\n", + " [nan nan nan ... nan nan nan]]\n", + "\n", + " [[nan nan nan ... nan nan nan]\n", + " [nan nan nan ... nan nan nan]\n", + " [nan nan nan ... nan nan nan]\n", + " ...\n", + " [nan nan nan ... nan nan nan]\n", + " [nan nan nan ... nan nan nan]\n", + " [nan nan nan ... nan nan nan]]])\n", + "NamedArray(float32{'batch': 2, 'position': 514, 'embed': 768},\n", + "[[[ 0.47864354 1.2938721 1.0534003 ... -1.4044254 0.829634\n", + " 0.00428176]\n", + " [-0.98862374 -0.943986 -1.0448135 ... -0.64666593 0.12967904\n", + " -1.0975188 ]\n", + " [ 0.43071395 -0.60738516 -1.7641208 ... -1.1334671 0.9041689\n", + " 0.9875958 ]\n", + " ...\n", + " [ 1.287216 0.507795 0.23451686 ... -0.9582702 -0.3576718\n", + " 0.6565546 ]\n", + " [ 0.33264828 -0.68922603 0.41440547 ... 0.4528543 -0.6819962\n", + " 0.4289952 ]\n", + " [ 0.725326 1.9756228 1.1881577 ... 0.5643402 0.5135605\n", + " 0.92514485]]\n", + "\n", + " [[ 1.0026733 1.1235753 -1.017235 ... -1.8810284 -0.29097554\n", + " -0.63098675]\n", + " [-0.47498116 1.9341669 0.23969549 ... -0.45160082 -0.955768\n", + " -1.4716814 ]\n", + " [-0.2948639 0.25138515 -1.3983693 ... -0.96624637 0.44848248\n", + " -0.71705264]\n", + " ...\n", + " [-0.4495727 0.07491604 0.919175 ... 0.565745 -0.34500855\n", + " -1.9166113 ]\n", + " [-0.3112308 -0.21019831 0.2379393 ... 1.3521733 0.1243041\n", + " -1.3730545 ]\n", + " [-0.9592797 -1.1558015 -1.3304269 ... 
1.4129258 0.69931823\n", + " 0.24171986]]])\n", + "(NamedArray(array=Array([[[-0.18572666, -0.80562836, -0.98205453, ..., 0.9597818 ,\n", + " 1.5924176 , 0.3848163 ],\n", + " [-0.19698945, -0.7534842 , -0.81553775, ..., 0.8907304 ,\n", + " 1.6596173 , 0.37368602],\n", + " [-0.14524579, -0.6914802 , -0.91773754, ..., 1.024012 ,\n", + " 1.6283392 , 0.34325117],\n", + " ...,\n", + " [-0.12597106, -0.76580876, -0.8392121 , ..., 0.9352241 ,\n", + " 1.5550641 , 0.46660298],\n", + " [-0.17409518, -0.749031 , -1.056306 , ..., 0.9757236 ,\n", + " 1.633118 , 0.5897971 ],\n", + " [-0.30729672, -0.69016093, -0.87607175, ..., 0.874229 ,\n", + " 1.6674999 , 0.38814685]],\n", + "\n", + " [[-0.2995884 , -0.92027843, -0.78937566, ..., 0.43273145,\n", + " 1.0177305 , 0.4611196 ],\n", + " [-0.3179962 , -0.8442052 , -0.75374943, ..., 0.7022148 ,\n", + " 1.0696493 , 0.3404984 ],\n", + " [-0.33889046, -0.90215874, -0.65796405, ..., 0.5069377 ,\n", + " 1.0210142 , 0.30466717],\n", + " ...,\n", + " [-0.47294238, -0.8416684 , -0.7532904 , ..., 0.46125498,\n", + " 1.1491499 , 0.41495347],\n", + " [-0.28819865, -0.8842407 , -0.69517034, ..., 0.49842533,\n", + " 1.0367949 , 0.58008623],\n", + " [-0.31019112, -0.90532184, -0.7528029 , ..., 0.5191115 ,\n", + " 1.299418 , 0.43962964]]], dtype=float32), axes=(Axis(name='batch', size=2), Axis(name='position', size=514), Axis(name='embed', size=768))), NamedArray(array=Array([[ 0.60376894, 0.34222102, -0.01021165, ..., -0.90009135,\n", + " -0.6602305 , 0.14456 ],\n", + " [ 0.8103463 , 0.4916969 , -0.01769697, ..., -0.94779646,\n", + " -0.8301279 , 0.29279563]], dtype=float32), axes=(Axis(name='batch', size=2), Axis(name='embed', size=768))))\n" + ] } ], "source": [ - "'''# Testing RobertaForMaskedLM\n", - "\n", - "my_pool = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=key)\n", - "state = my_pool.to_state_dict()\n", - "\n", - "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", - "\n", - "state[\"lm_head.bias\"] = torch.zeros(hf_config.vocab_size)\n", - "\n", - "print(state.keys())\n", - "\n", - "hf_pool = hf_roberta.RobertaForMaskedLM(hf_config)\n", - "hf_pool.load_state_dict(state, strict=True)\n", - "\n", - "my_output_MLM = my_pool(input_ids = input_ids, attention_mask=mask, key=key)\n", - "hf_output_MLM = hf_pool(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", - "\n", - "# my_output = my_pool(input_embeds = input_embeds, attention_mask=mask, key=key)\n", - "# hf_output = hf_pool(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n", - "\n", - "check(my_output_MLM[0].array, hf_output_MLM[0].detach(), ppp=True)'''" + "print(my_out_pool)\n", + "print(my_out_enc)\n", + "print(my_out_embeds)\n", + "print(my_out_model)" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'# Testing RobertaModel\\n\\nmy_model = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=False, key=key)\\nstate = my_model.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nprint(state.keys())\\n\\nhf_model = hf_roberta.RobertaModel(hf_config, add_pooling_layer=False)\\nhf_model.load_state_dict(state, strict=True)\\n\\nmy_output = my_model(input_ids = input_ids, attention_mask=mask, key=key)\\nhf_output = hf_model(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\\n\\ncheck(my_output[0].array, hf_output[0].detach(), ppp=True)\\n\\n# Testing 
RobertaLMHead\\n\\nmy_head = my_roberta.RobertaLMHead.init(Vocab, my_config, key=key)\\nstate = my_head.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nstate[\"bias\"] = torch.zeros(hf_config.vocab_size)\\n\\nprint(state.keys())\\n\\nhf_head = hf_roberta.RobertaLMHead(hf_config)\\nhf_head.load_state_dict(state, strict=True)\\n\\nmy_output = my_head(my_output[0], key=key)\\nhf_output = hf_head(hf_output[0])\\n\\ncheck(my_output.array, hf_output.detach(), ppp=True)'" + "'# Testing RobertaModel\\n\\nmy_model = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=False, key=key)\\nstate = my_model.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\n# print(state.keys())\\n\\nhf_model = hf_roberta.RobertaModel(hf_config, add_pooling_layer=False)\\nhf_model.load_state_dict(state, strict=True)\\n\\nmy_output = my_model(input_ids = input_ids, attention_mask=mask, key=key)\\nhf_output = hf_model(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\\n\\ncheck(my_output[0].array, hf_output[0].detach(), ppp=True)\\n\\n# Testing RobertaLMHead\\n\\nmy_head = my_roberta.RobertaLMHead.init(Vocab, my_config, key=key)\\nstate = my_head.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nstate[\"bias\"] = torch.zeros(hf_config.vocab_size)\\n\\n# print(state.keys())\\n\\nhf_head = hf_roberta.RobertaLMHead(hf_config)\\nhf_head.load_state_dict(state, strict=True)\\n\\nmy_output = my_head(my_output[0], key=key)\\nhf_output = hf_head(hf_output[0])\\n\\ncheck(my_output.array, hf_output.detach(), ppp=True)'" ] }, - "execution_count": 18, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -610,7 +717,7 @@ "\n", "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - "print(state.keys())\n", + "# print(state.keys())\n", "\n", "hf_model = hf_roberta.RobertaModel(hf_config, add_pooling_layer=False)\n", "hf_model.load_state_dict(state, strict=True)\n", @@ -629,7 +736,7 @@ "\n", "state[\"bias\"] = torch.zeros(hf_config.vocab_size)\n", "\n", - "print(state.keys())\n", + "# print(state.keys())\n", "\n", "hf_head = hf_roberta.RobertaLMHead(hf_config)\n", "hf_head.load_state_dict(state, strict=True)\n", @@ -642,139 +749,87 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 11, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "dict_keys(['roberta.encoder.layer.0.attention.self.query.weight', 'roberta.encoder.layer.1.attention.self.query.weight', 'roberta.encoder.layer.2.attention.self.query.weight', 'roberta.encoder.layer.3.attention.self.query.weight', 'roberta.encoder.layer.4.attention.self.query.weight', 'roberta.encoder.layer.5.attention.self.query.weight', 'roberta.encoder.layer.6.attention.self.query.weight', 'roberta.encoder.layer.7.attention.self.query.weight', 'roberta.encoder.layer.8.attention.self.query.weight', 'roberta.encoder.layer.9.attention.self.query.weight', 'roberta.encoder.layer.10.attention.self.query.weight', 'roberta.encoder.layer.11.attention.self.query.weight', 'roberta.encoder.layer.0.attention.self.query.bias', 'roberta.encoder.layer.1.attention.self.query.bias', 'roberta.encoder.layer.2.attention.self.query.bias', 'roberta.encoder.layer.3.attention.self.query.bias', 'roberta.encoder.layer.4.attention.self.query.bias', 'roberta.encoder.layer.5.attention.self.query.bias', 
'roberta.encoder.layer.6.attention.self.query.bias', 'roberta.encoder.layer.7.attention.self.query.bias', 'roberta.encoder.layer.8.attention.self.query.bias', 'roberta.encoder.layer.9.attention.self.query.bias', 'roberta.encoder.layer.10.attention.self.query.bias', 'roberta.encoder.layer.11.attention.self.query.bias', 'roberta.encoder.layer.0.attention.self.key.weight', 'roberta.encoder.layer.1.attention.self.key.weight', 'roberta.encoder.layer.2.attention.self.key.weight', 'roberta.encoder.layer.3.attention.self.key.weight', 'roberta.encoder.layer.4.attention.self.key.weight', 'roberta.encoder.layer.5.attention.self.key.weight', 'roberta.encoder.layer.6.attention.self.key.weight', 'roberta.encoder.layer.7.attention.self.key.weight', 'roberta.encoder.layer.8.attention.self.key.weight', 'roberta.encoder.layer.9.attention.self.key.weight', 'roberta.encoder.layer.10.attention.self.key.weight', 'roberta.encoder.layer.11.attention.self.key.weight', 'roberta.encoder.layer.0.attention.self.key.bias', 'roberta.encoder.layer.1.attention.self.key.bias', 'roberta.encoder.layer.2.attention.self.key.bias', 'roberta.encoder.layer.3.attention.self.key.bias', 'roberta.encoder.layer.4.attention.self.key.bias', 'roberta.encoder.layer.5.attention.self.key.bias', 'roberta.encoder.layer.6.attention.self.key.bias', 'roberta.encoder.layer.7.attention.self.key.bias', 'roberta.encoder.layer.8.attention.self.key.bias', 'roberta.encoder.layer.9.attention.self.key.bias', 'roberta.encoder.layer.10.attention.self.key.bias', 'roberta.encoder.layer.11.attention.self.key.bias', 'roberta.encoder.layer.0.attention.self.value.weight', 'roberta.encoder.layer.1.attention.self.value.weight', 'roberta.encoder.layer.2.attention.self.value.weight', 'roberta.encoder.layer.3.attention.self.value.weight', 'roberta.encoder.layer.4.attention.self.value.weight', 'roberta.encoder.layer.5.attention.self.value.weight', 'roberta.encoder.layer.6.attention.self.value.weight', 'roberta.encoder.layer.7.attention.self.value.weight', 'roberta.encoder.layer.8.attention.self.value.weight', 'roberta.encoder.layer.9.attention.self.value.weight', 'roberta.encoder.layer.10.attention.self.value.weight', 'roberta.encoder.layer.11.attention.self.value.weight', 'roberta.encoder.layer.0.attention.self.value.bias', 'roberta.encoder.layer.1.attention.self.value.bias', 'roberta.encoder.layer.2.attention.self.value.bias', 'roberta.encoder.layer.3.attention.self.value.bias', 'roberta.encoder.layer.4.attention.self.value.bias', 'roberta.encoder.layer.5.attention.self.value.bias', 'roberta.encoder.layer.6.attention.self.value.bias', 'roberta.encoder.layer.7.attention.self.value.bias', 'roberta.encoder.layer.8.attention.self.value.bias', 'roberta.encoder.layer.9.attention.self.value.bias', 'roberta.encoder.layer.10.attention.self.value.bias', 'roberta.encoder.layer.11.attention.self.value.bias', 'roberta.encoder.layer.0.attention.output.dense.weight', 'roberta.encoder.layer.1.attention.output.dense.weight', 'roberta.encoder.layer.2.attention.output.dense.weight', 'roberta.encoder.layer.3.attention.output.dense.weight', 'roberta.encoder.layer.4.attention.output.dense.weight', 'roberta.encoder.layer.5.attention.output.dense.weight', 'roberta.encoder.layer.6.attention.output.dense.weight', 'roberta.encoder.layer.7.attention.output.dense.weight', 'roberta.encoder.layer.8.attention.output.dense.weight', 'roberta.encoder.layer.9.attention.output.dense.weight', 'roberta.encoder.layer.10.attention.output.dense.weight', 
'roberta.encoder.layer.11.attention.output.dense.weight', 'roberta.encoder.layer.0.attention.output.dense.bias', 'roberta.encoder.layer.1.attention.output.dense.bias', 'roberta.encoder.layer.2.attention.output.dense.bias', 'roberta.encoder.layer.3.attention.output.dense.bias', 'roberta.encoder.layer.4.attention.output.dense.bias', 'roberta.encoder.layer.5.attention.output.dense.bias', 'roberta.encoder.layer.6.attention.output.dense.bias', 'roberta.encoder.layer.7.attention.output.dense.bias', 'roberta.encoder.layer.8.attention.output.dense.bias', 'roberta.encoder.layer.9.attention.output.dense.bias', 'roberta.encoder.layer.10.attention.output.dense.bias', 'roberta.encoder.layer.11.attention.output.dense.bias', 'roberta.encoder.layer.0.attention.output.LayerNorm.weight', 'roberta.encoder.layer.1.attention.output.LayerNorm.weight', 'roberta.encoder.layer.2.attention.output.LayerNorm.weight', 'roberta.encoder.layer.3.attention.output.LayerNorm.weight', 'roberta.encoder.layer.4.attention.output.LayerNorm.weight', 'roberta.encoder.layer.5.attention.output.LayerNorm.weight', 'roberta.encoder.layer.6.attention.output.LayerNorm.weight', 'roberta.encoder.layer.7.attention.output.LayerNorm.weight', 'roberta.encoder.layer.8.attention.output.LayerNorm.weight', 'roberta.encoder.layer.9.attention.output.LayerNorm.weight', 'roberta.encoder.layer.10.attention.output.LayerNorm.weight', 'roberta.encoder.layer.11.attention.output.LayerNorm.weight', 'roberta.encoder.layer.0.attention.output.LayerNorm.bias', 'roberta.encoder.layer.1.attention.output.LayerNorm.bias', 'roberta.encoder.layer.2.attention.output.LayerNorm.bias', 'roberta.encoder.layer.3.attention.output.LayerNorm.bias', 'roberta.encoder.layer.4.attention.output.LayerNorm.bias', 'roberta.encoder.layer.5.attention.output.LayerNorm.bias', 'roberta.encoder.layer.6.attention.output.LayerNorm.bias', 'roberta.encoder.layer.7.attention.output.LayerNorm.bias', 'roberta.encoder.layer.8.attention.output.LayerNorm.bias', 'roberta.encoder.layer.9.attention.output.LayerNorm.bias', 'roberta.encoder.layer.10.attention.output.LayerNorm.bias', 'roberta.encoder.layer.11.attention.output.LayerNorm.bias', 'roberta.encoder.layer.0.intermediate.dense.weight', 'roberta.encoder.layer.1.intermediate.dense.weight', 'roberta.encoder.layer.2.intermediate.dense.weight', 'roberta.encoder.layer.3.intermediate.dense.weight', 'roberta.encoder.layer.4.intermediate.dense.weight', 'roberta.encoder.layer.5.intermediate.dense.weight', 'roberta.encoder.layer.6.intermediate.dense.weight', 'roberta.encoder.layer.7.intermediate.dense.weight', 'roberta.encoder.layer.8.intermediate.dense.weight', 'roberta.encoder.layer.9.intermediate.dense.weight', 'roberta.encoder.layer.10.intermediate.dense.weight', 'roberta.encoder.layer.11.intermediate.dense.weight', 'roberta.encoder.layer.0.intermediate.dense.bias', 'roberta.encoder.layer.1.intermediate.dense.bias', 'roberta.encoder.layer.2.intermediate.dense.bias', 'roberta.encoder.layer.3.intermediate.dense.bias', 'roberta.encoder.layer.4.intermediate.dense.bias', 'roberta.encoder.layer.5.intermediate.dense.bias', 'roberta.encoder.layer.6.intermediate.dense.bias', 'roberta.encoder.layer.7.intermediate.dense.bias', 'roberta.encoder.layer.8.intermediate.dense.bias', 'roberta.encoder.layer.9.intermediate.dense.bias', 'roberta.encoder.layer.10.intermediate.dense.bias', 'roberta.encoder.layer.11.intermediate.dense.bias', 'roberta.encoder.layer.0.output.dense.weight', 'roberta.encoder.layer.1.output.dense.weight', 
'roberta.encoder.layer.2.output.dense.weight', 'roberta.encoder.layer.3.output.dense.weight', 'roberta.encoder.layer.4.output.dense.weight', 'roberta.encoder.layer.5.output.dense.weight', 'roberta.encoder.layer.6.output.dense.weight', 'roberta.encoder.layer.7.output.dense.weight', 'roberta.encoder.layer.8.output.dense.weight', 'roberta.encoder.layer.9.output.dense.weight', 'roberta.encoder.layer.10.output.dense.weight', 'roberta.encoder.layer.11.output.dense.weight', 'roberta.encoder.layer.0.output.dense.bias', 'roberta.encoder.layer.1.output.dense.bias', 'roberta.encoder.layer.2.output.dense.bias', 'roberta.encoder.layer.3.output.dense.bias', 'roberta.encoder.layer.4.output.dense.bias', 'roberta.encoder.layer.5.output.dense.bias', 'roberta.encoder.layer.6.output.dense.bias', 'roberta.encoder.layer.7.output.dense.bias', 'roberta.encoder.layer.8.output.dense.bias', 'roberta.encoder.layer.9.output.dense.bias', 'roberta.encoder.layer.10.output.dense.bias', 'roberta.encoder.layer.11.output.dense.bias', 'roberta.encoder.layer.0.output.LayerNorm.weight', 'roberta.encoder.layer.1.output.LayerNorm.weight', 'roberta.encoder.layer.2.output.LayerNorm.weight', 'roberta.encoder.layer.3.output.LayerNorm.weight', 'roberta.encoder.layer.4.output.LayerNorm.weight', 'roberta.encoder.layer.5.output.LayerNorm.weight', 'roberta.encoder.layer.6.output.LayerNorm.weight', 'roberta.encoder.layer.7.output.LayerNorm.weight', 'roberta.encoder.layer.8.output.LayerNorm.weight', 'roberta.encoder.layer.9.output.LayerNorm.weight', 'roberta.encoder.layer.10.output.LayerNorm.weight', 'roberta.encoder.layer.11.output.LayerNorm.weight', 'roberta.encoder.layer.0.output.LayerNorm.bias', 'roberta.encoder.layer.1.output.LayerNorm.bias', 'roberta.encoder.layer.2.output.LayerNorm.bias', 'roberta.encoder.layer.3.output.LayerNorm.bias', 'roberta.encoder.layer.4.output.LayerNorm.bias', 'roberta.encoder.layer.5.output.LayerNorm.bias', 'roberta.encoder.layer.6.output.LayerNorm.bias', 'roberta.encoder.layer.7.output.LayerNorm.bias', 'roberta.encoder.layer.8.output.LayerNorm.bias', 'roberta.encoder.layer.9.output.LayerNorm.bias', 'roberta.encoder.layer.10.output.LayerNorm.bias', 'roberta.encoder.layer.11.output.LayerNorm.bias', 'roberta.embeddings.word_embeddings.weight', 'roberta.embeddings.position_embeddings.weight', 'roberta.embeddings.token_type_embeddings.weight', 'roberta.embeddings.LayerNorm.weight', 'roberta.embeddings.LayerNorm.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.decoder.bias', 'lm_head.bias'])\n", - "dict_keys(['encoder.layer.0.attention.self.query.weight', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.2.attention.self.query.weight', 'encoder.layer.3.attention.self.query.weight', 'encoder.layer.4.attention.self.query.weight', 'encoder.layer.5.attention.self.query.weight', 'encoder.layer.6.attention.self.query.weight', 'encoder.layer.7.attention.self.query.weight', 'encoder.layer.8.attention.self.query.weight', 'encoder.layer.9.attention.self.query.weight', 'encoder.layer.10.attention.self.query.weight', 'encoder.layer.11.attention.self.query.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.1.attention.self.query.bias', 'encoder.layer.2.attention.self.query.bias', 'encoder.layer.3.attention.self.query.bias', 'encoder.layer.4.attention.self.query.bias', 'encoder.layer.5.attention.self.query.bias', 'encoder.layer.6.attention.self.query.bias', 'encoder.layer.7.attention.self.query.bias', 
'encoder.layer.8.attention.self.query.bias', 'encoder.layer.9.attention.self.query.bias', 'encoder.layer.10.attention.self.query.bias', 'encoder.layer.11.attention.self.query.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.1.attention.self.key.weight', 'encoder.layer.2.attention.self.key.weight', 'encoder.layer.3.attention.self.key.weight', 'encoder.layer.4.attention.self.key.weight', 'encoder.layer.5.attention.self.key.weight', 'encoder.layer.6.attention.self.key.weight', 'encoder.layer.7.attention.self.key.weight', 'encoder.layer.8.attention.self.key.weight', 'encoder.layer.9.attention.self.key.weight', 'encoder.layer.10.attention.self.key.weight', 'encoder.layer.11.attention.self.key.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.1.attention.self.key.bias', 'encoder.layer.2.attention.self.key.bias', 'encoder.layer.3.attention.self.key.bias', 'encoder.layer.4.attention.self.key.bias', 'encoder.layer.5.attention.self.key.bias', 'encoder.layer.6.attention.self.key.bias', 'encoder.layer.7.attention.self.key.bias', 'encoder.layer.8.attention.self.key.bias', 'encoder.layer.9.attention.self.key.bias', 'encoder.layer.10.attention.self.key.bias', 'encoder.layer.11.attention.self.key.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.1.attention.self.value.weight', 'encoder.layer.2.attention.self.value.weight', 'encoder.layer.3.attention.self.value.weight', 'encoder.layer.4.attention.self.value.weight', 'encoder.layer.5.attention.self.value.weight', 'encoder.layer.6.attention.self.value.weight', 'encoder.layer.7.attention.self.value.weight', 'encoder.layer.8.attention.self.value.weight', 'encoder.layer.9.attention.self.value.weight', 'encoder.layer.10.attention.self.value.weight', 'encoder.layer.11.attention.self.value.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.1.attention.self.value.bias', 'encoder.layer.2.attention.self.value.bias', 'encoder.layer.3.attention.self.value.bias', 'encoder.layer.4.attention.self.value.bias', 'encoder.layer.5.attention.self.value.bias', 'encoder.layer.6.attention.self.value.bias', 'encoder.layer.7.attention.self.value.bias', 'encoder.layer.8.attention.self.value.bias', 'encoder.layer.9.attention.self.value.bias', 'encoder.layer.10.attention.self.value.bias', 'encoder.layer.11.attention.self.value.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.1.attention.output.dense.weight', 'encoder.layer.2.attention.output.dense.weight', 'encoder.layer.3.attention.output.dense.weight', 'encoder.layer.4.attention.output.dense.weight', 'encoder.layer.5.attention.output.dense.weight', 'encoder.layer.6.attention.output.dense.weight', 'encoder.layer.7.attention.output.dense.weight', 'encoder.layer.8.attention.output.dense.weight', 'encoder.layer.9.attention.output.dense.weight', 'encoder.layer.10.attention.output.dense.weight', 'encoder.layer.11.attention.output.dense.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.1.attention.output.dense.bias', 'encoder.layer.2.attention.output.dense.bias', 'encoder.layer.3.attention.output.dense.bias', 'encoder.layer.4.attention.output.dense.bias', 'encoder.layer.5.attention.output.dense.bias', 'encoder.layer.6.attention.output.dense.bias', 'encoder.layer.7.attention.output.dense.bias', 'encoder.layer.8.attention.output.dense.bias', 'encoder.layer.9.attention.output.dense.bias', 'encoder.layer.10.attention.output.dense.bias', 'encoder.layer.11.attention.output.dense.bias', 
'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.1.attention.output.LayerNorm.weight', 'encoder.layer.2.attention.output.LayerNorm.weight', 'encoder.layer.3.attention.output.LayerNorm.weight', 'encoder.layer.4.attention.output.LayerNorm.weight', 'encoder.layer.5.attention.output.LayerNorm.weight', 'encoder.layer.6.attention.output.LayerNorm.weight', 'encoder.layer.7.attention.output.LayerNorm.weight', 'encoder.layer.8.attention.output.LayerNorm.weight', 'encoder.layer.9.attention.output.LayerNorm.weight', 'encoder.layer.10.attention.output.LayerNorm.weight', 'encoder.layer.11.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.1.attention.output.LayerNorm.bias', 'encoder.layer.2.attention.output.LayerNorm.bias', 'encoder.layer.3.attention.output.LayerNorm.bias', 'encoder.layer.4.attention.output.LayerNorm.bias', 'encoder.layer.5.attention.output.LayerNorm.bias', 'encoder.layer.6.attention.output.LayerNorm.bias', 'encoder.layer.7.attention.output.LayerNorm.bias', 'encoder.layer.8.attention.output.LayerNorm.bias', 'encoder.layer.9.attention.output.LayerNorm.bias', 'encoder.layer.10.attention.output.LayerNorm.bias', 'encoder.layer.11.attention.output.LayerNorm.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.1.intermediate.dense.weight', 'encoder.layer.2.intermediate.dense.weight', 'encoder.layer.3.intermediate.dense.weight', 'encoder.layer.4.intermediate.dense.weight', 'encoder.layer.5.intermediate.dense.weight', 'encoder.layer.6.intermediate.dense.weight', 'encoder.layer.7.intermediate.dense.weight', 'encoder.layer.8.intermediate.dense.weight', 'encoder.layer.9.intermediate.dense.weight', 'encoder.layer.10.intermediate.dense.weight', 'encoder.layer.11.intermediate.dense.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.1.intermediate.dense.bias', 'encoder.layer.2.intermediate.dense.bias', 'encoder.layer.3.intermediate.dense.bias', 'encoder.layer.4.intermediate.dense.bias', 'encoder.layer.5.intermediate.dense.bias', 'encoder.layer.6.intermediate.dense.bias', 'encoder.layer.7.intermediate.dense.bias', 'encoder.layer.8.intermediate.dense.bias', 'encoder.layer.9.intermediate.dense.bias', 'encoder.layer.10.intermediate.dense.bias', 'encoder.layer.11.intermediate.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.1.output.dense.weight', 'encoder.layer.2.output.dense.weight', 'encoder.layer.3.output.dense.weight', 'encoder.layer.4.output.dense.weight', 'encoder.layer.5.output.dense.weight', 'encoder.layer.6.output.dense.weight', 'encoder.layer.7.output.dense.weight', 'encoder.layer.8.output.dense.weight', 'encoder.layer.9.output.dense.weight', 'encoder.layer.10.output.dense.weight', 'encoder.layer.11.output.dense.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.1.output.dense.bias', 'encoder.layer.2.output.dense.bias', 'encoder.layer.3.output.dense.bias', 'encoder.layer.4.output.dense.bias', 'encoder.layer.5.output.dense.bias', 'encoder.layer.6.output.dense.bias', 'encoder.layer.7.output.dense.bias', 'encoder.layer.8.output.dense.bias', 'encoder.layer.9.output.dense.bias', 'encoder.layer.10.output.dense.bias', 'encoder.layer.11.output.dense.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.1.output.LayerNorm.weight', 'encoder.layer.2.output.LayerNorm.weight', 'encoder.layer.3.output.LayerNorm.weight', 'encoder.layer.4.output.LayerNorm.weight', 'encoder.layer.5.output.LayerNorm.weight', 'encoder.layer.6.output.LayerNorm.weight', 
'encoder.layer.7.output.LayerNorm.weight', 'encoder.layer.8.output.LayerNorm.weight', 'encoder.layer.9.output.LayerNorm.weight', 'encoder.layer.10.output.LayerNorm.weight', 'encoder.layer.11.output.LayerNorm.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.1.output.LayerNorm.bias', 'encoder.layer.2.output.LayerNorm.bias', 'encoder.layer.3.output.LayerNorm.bias', 'encoder.layer.4.output.LayerNorm.bias', 'encoder.layer.5.output.LayerNorm.bias', 'encoder.layer.6.output.LayerNorm.bias', 'encoder.layer.7.output.LayerNorm.bias', 'encoder.layer.8.output.LayerNorm.bias', 'encoder.layer.9.output.LayerNorm.bias', 'encoder.layer.10.output.LayerNorm.bias', 'encoder.layer.11.output.LayerNorm.bias', 'embeddings.word_embeddings.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.LayerNorm.weight', 'embeddings.LayerNorm.bias'])\n", - "dict_keys(['dense.weight', 'dense.bias', 'layer_norm.weight', 'layer_norm.bias', 'decoder.weight', 'decoder.bias', 'bias'])\n" - ] - }, { "data": { "text/plain": [ - "" + "'# Testing RobertaForMaskedLM\\n\\nmy_mlm = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=key)\\nstate_mlm = my_mlm.to_state_dict()\\n\\nstate_mlm = {k: torch.from_numpy(np.array(v)) for k, v in state_mlm.items()}\\n\\n# if \"lm_head.decoder.bias\" in state:\\n# print(state[\"lm_head.decoder.bias\"])\\n# else:\\n# print(f\"RobertaForMaskedLM, {state.keys()}\")\\n\\nstate_mlm[\"lm_head.bias\"] = torch.zeros(hf_config.vocab_size)\\n\\nprint(state_mlm.keys())\\n\\nhf_mlm = hf_roberta.RobertaForMaskedLM(hf_config)\\nhf_mlm.load_state_dict(state_mlm, strict=True)\\n\\n# Testing RobertaModel\\n\\nkey_rob, key_head = jrandom.split(key, 2)\\n\\nmy_model = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=False, key=key_rob)\\nstate_model = my_model.to_state_dict()\\n\\nstate_model = {k: torch.from_numpy(np.array(v)) for k, v in state_model.items()}\\n\\nprint(state_model.keys())\\n\\nhf_model = hf_roberta.RobertaModel(hf_config, add_pooling_layer=False)\\nhf_model.load_state_dict(state_model, strict=True)\\n\\n# Testing RobertaLMHead\\n\\nmy_head = my_roberta.RobertaLMHead.init(Vocab, my_config, key=key_head)\\nstate_head = my_head.to_state_dict()\\n\\nstate_head = {k: torch.from_numpy(np.array(v)) for k, v in state_head.items()}\\n\\nstate_head[\"bias\"] = torch.zeros(hf_config.vocab_size)\\n\\nprint(state_head.keys())\\n\\nhf_head = hf_roberta.RobertaLMHead(hf_config)\\nhf_head.load_state_dict(state_head, strict=True)'" ] }, - "execution_count": 19, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "# Testing RobertaForMaskedLM\n", + "'''# Testing RobertaForMaskedLM\n", "\n", "my_mlm = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=key)\n", - "state = my_mlm.to_state_dict()\n", + "state_mlm = my_mlm.to_state_dict()\n", "\n", - "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "state_mlm = {k: torch.from_numpy(np.array(v)) for k, v in state_mlm.items()}\n", + "\n", + "# if \"lm_head.decoder.bias\" in state:\n", + "# print(state[\"lm_head.decoder.bias\"])\n", + "# else:\n", + "# print(f\"RobertaForMaskedLM, {state.keys()}\")\n", "\n", - "state[\"lm_head.bias\"] = torch.zeros(hf_config.vocab_size)\n", + "state_mlm[\"lm_head.bias\"] = torch.zeros(hf_config.vocab_size)\n", "\n", - "print(state.keys())\n", + "print(state_mlm.keys())\n", "\n", "hf_mlm = hf_roberta.RobertaForMaskedLM(hf_config)\n", - "hf_mlm.load_state_dict(state, strict=True)\n", + 
"hf_mlm.load_state_dict(state_mlm, strict=True)\n", "\n", "# Testing RobertaModel\n", "\n", "key_rob, key_head = jrandom.split(key, 2)\n", "\n", "my_model = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=False, key=key_rob)\n", - "state = my_model.to_state_dict()\n", + "state_model = my_model.to_state_dict()\n", "\n", - "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "state_model = {k: torch.from_numpy(np.array(v)) for k, v in state_model.items()}\n", "\n", - "print(state.keys())\n", + "print(state_model.keys())\n", "\n", "hf_model = hf_roberta.RobertaModel(hf_config, add_pooling_layer=False)\n", - "hf_model.load_state_dict(state, strict=True)\n", + "hf_model.load_state_dict(state_model, strict=True)\n", "\n", "# Testing RobertaLMHead\n", "\n", "my_head = my_roberta.RobertaLMHead.init(Vocab, my_config, key=key_head)\n", - "state = my_head.to_state_dict()\n", + "state_head = my_head.to_state_dict()\n", "\n", - "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "state_head = {k: torch.from_numpy(np.array(v)) for k, v in state_head.items()}\n", "\n", - "state[\"bias\"] = torch.zeros(hf_config.vocab_size)\n", + "state_head[\"bias\"] = torch.zeros(hf_config.vocab_size)\n", "\n", - "print(state.keys())\n", + "print(state_head.keys())\n", "\n", "hf_head = hf_roberta.RobertaLMHead(hf_config)\n", - "hf_head.load_state_dict(state, strict=True)" + "hf_head.load_state_dict(state_head, strict=True)'''" ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 12, "metadata": {}, "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", - " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", - "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", - " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", - "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", - " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", - "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", - " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Checking RobertaModel\n", - "Success!!!\n", - "Iteration 0, Precision 1:\t1.0\n", - "Iteration 7, Precision 1e-07:\t0.9986305236816406\n", - "Iteration 8, Precision 1e-08:\t0.9982020060221354\n", - "Iteration 14, Precision 1e-14:\t0.9981346130371094\n", - "Accuracy: 1.0\n", - "Checking Roberta Model + LM head\n", - "Success!!!\n", - "Iteration 0, Precision 1:\t1.0\n", - "Iteration 6, Precision 1e-06:\t0.9996785785337711\n", - "Iteration 7, Precision 1e-07:\t0.9964101978265194\n", - "Iteration 8, Precision 1e-08:\t0.995718628767532\n", - "Iteration 14, Precision 1e-14:\t0.9956369328496468\n", - "Accuracy: 1.0\n", - "Checking MLM\n", - "Fail :((((\n", - "Iteration 0, Precision 1:\t1.0\n", - "Iteration 1, Precision 0.1:\t0.6762516280898737\n", - "Iteration 2, Precision 0.01:\t0.0813977132137173\n", - "Iteration 3, Precision 0.001:\t0.00872603715930568\n", - 
"Iteration 4, Precision 0.0001:\t0.0014499908298517856\n", - "Iteration 5, Precision 1e-05:\t0.000723081729334527\n", - "Iteration 14, Precision 1e-14:\t0.0006434452091415498\n", - "Accuracy: 0.0014499908298517856\n", - "Checking my RobertaModel + LM head and MLM\n", - "Success!!!\n", - "Iteration 0, Precision 1:\t1.0\n", - "Iteration 14, Precision 1e-14:\t1.0\n", - "Accuracy: 1.0\n", - "Checking hf RobertaModel + LM head and MLM\n", - "Fail :((((\n", - "Iteration 0, Precision 1:\t1.0\n", - "Iteration 1, Precision 0.1:\t0.6762513949505122\n", - "Iteration 2, Precision 0.01:\t0.0813977132137173\n", - "Iteration 3, Precision 0.001:\t0.008726678292549488\n", - "Iteration 4, Precision 0.0001:\t0.0014501462560927087\n", - "Iteration 5, Precision 1e-05:\t0.0007229263030936039\n" - ] + "data": { + "text/plain": [ + "'k_rob, k_lm = jrandom.split(key, 2)\\n\\n# MLM\\n\\nmy_output_mlm = my_mlm(input_ids = input_ids, attention_mask=mask, key=key)\\nhf_output_mlm = hf_mlm(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\\n\\n# Model + LM\\n\\nmy_output_model = my_model(input_ids = input_ids, attention_mask=mask, key=k_rob)\\nmy_output = my_head(my_output_model[0], key=k_lm)\\n\\nhf_output_model = hf_model(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\\nhf_output = hf_head(hf_output_model[0])\\n\\n# # MLM\\n# my_output_mlm = my_mlm(input_embeds = input_embeds, attention_mask=mask, key=key)\\n# hf_output_mlm = hf_mlm(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\\n\\n# # Model + LM\\n# my_output_model = my_model(input_embeds = input_embeds, attention_mask=mask, key=k_rob)\\n# my_output = my_head(my_output_model[0], key=k_lm)\\n\\n# hf_output_model = hf_model(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\\n# hf_output = hf_head(hf_output_model[0])\\n\\n# Checks\\n\\nprint(\"\\nChecking RobertaModel\")\\ncheck(my_output_model[0].array, hf_output_model[0].detach(), pppp=True)\\nprint(\"\\nChecking Roberta Model + LM head\")\\ncheck(my_output.array, hf_output.detach(), pppp=True)\\nprint(\"\\nChecking MLM\")\\ncheck(my_output_mlm.array, hf_output_mlm[0].detach(), pppp=True)\\n\\nprint(\"\\nChecking my RobertaModel + LM head and MLM\")\\ncheck(my_output.array, my_output_mlm.array, pppp=True)\\nprint(\"\\nChecking hf RobertaModel + LM head and MLM\")\\ncheck(hf_output.detach(), hf_output_mlm[0].detach(), pppp=True)\\n\\n\\n# Notes\\n# embeds works between hf and my, and within model->head and mlm \\n# ids does not work between hf and my for mlm or within hf for model->head and mlm - so hf mlm is doing something weird.'" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "k_rob, k_lm = jrandom.split(key, 2)\n", + "'''k_rob, k_lm = jrandom.split(key, 2)\n", "\n", "# MLM\n", "\n", @@ -802,22 +857,22 @@ "\n", "# Checks\n", "\n", - "print(\"Checking RobertaModel\")\n", + "print(\"\\nChecking RobertaModel\")\n", "check(my_output_model[0].array, hf_output_model[0].detach(), pppp=True)\n", - "print(\"Checking Roberta Model + LM head\")\n", + "print(\"\\nChecking Roberta Model + LM head\")\n", "check(my_output.array, hf_output.detach(), pppp=True)\n", - "print(\"Checking MLM\")\n", + "print(\"\\nChecking MLM\")\n", "check(my_output_mlm.array, hf_output_mlm[0].detach(), pppp=True)\n", "\n", - "print(\"Checking my RobertaModel + LM head and MLM\")\n", + "print(\"\\nChecking my RobertaModel + LM head and MLM\")\n", "check(my_output.array, 
my_output_mlm.array, pppp=True)\n", - "print(\"Checking hf RobertaModel + LM head and MLM\")\n", + "print(\"\\nChecking hf RobertaModel + LM head and MLM\")\n", "check(hf_output.detach(), hf_output_mlm[0].detach(), pppp=True)\n", "\n", "\n", "# Notes\n", "# embeds works between hf and my, and within model->head and mlm \n", - "# ids does not work between hf and my for mlm or within hf for model->head and mlm - so hf mlm is doing something weird." + "# ids does not work between hf and my for mlm or within hf for model->head and mlm - so hf mlm is doing something weird.'''" ] } ], From 027b17623913017408345bd5817bc443d8295961 Mon Sep 17 00:00:00 2001 From: Prady Saligram Date: Mon, 26 Aug 2024 15:03:35 -0700 Subject: [PATCH 09/29] Sets RobertaConfig as model architecture and creates default config file --- config/roberta.yaml | 38 ++ src/levanter/main/train_mlm.py | 3 +- src/levanter/models/roberta.py | 847 +++++++++++++++++++++++++++++++++ 3 files changed, 887 insertions(+), 1 deletion(-) create mode 100644 config/roberta.yaml create mode 100644 src/levanter/models/roberta.py diff --git a/config/roberta.yaml b/config/roberta.yaml new file mode 100644 index 000000000..81f5d4d35 --- /dev/null +++ b/config/roberta.yaml @@ -0,0 +1,38 @@ +data: + train_urls: + - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" + validation_urls: + - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" + cache_dir: "gs://levanter-data/tokenized/openwebtext_roberta/" + tokenizer: "roberta-base" + +model: + type: roberta + vocab_size: 50265 + hidden_size: 768 + intermediate_size: 3072 + num_hidden_layers: 12 + num_attention_heads: 12 + max_position_embeddings: 512 + hidden_act: "gelu" + hidden_dropout_prob: 0.1 + attention_probs_dropout_prob: 0.1 + gradient_checkpointing: true + +trainer: + tracker: + - type: wandb + project: "levanter" + tags: ["openwebtext", "roberta", "itest"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + per_device_parallelism: -1 + + train_batch_size: 32 + num_train_steps: 20000 + +optimizer: + learning_rate: 1E-3 + weight_decay: 0.1 + warmup: 0.01 diff --git a/src/levanter/main/train_mlm.py b/src/levanter/main/train_mlm.py index 208131310..80e941d5a 100644 --- a/src/levanter/main/train_mlm.py +++ b/src/levanter/main/train_mlm.py @@ -20,6 +20,7 @@ from levanter.models.gpt2 import Gpt2Config from levanter.models.llama import LlamaConfig from levanter.models.lm_model import LmConfig +from levanter.models.roberta import RobertaConfig from levanter.optim import AdamConfig, OptimizerConfig from levanter.trainer import Trainer, TrainerConfig from levanter.utils.jax_utils import parameter_count @@ -30,7 +31,7 @@ class TrainMlmConfig: data: Union[LMDatasetConfig, LMMixtureDatasetConfig] = field(default_factory=LMDatasetConfig) trainer: TrainerConfig = field(default_factory=TrainerConfig) - model: LmConfig = field(default_factory=LlamaConfig) + model: LmConfig = field(default_factory=RobertaConfig) optimizer: OptimizerConfig = field(default_factory=AdamConfig) # config related to continued pretraining diff --git a/src/levanter/models/roberta.py b/src/levanter/models/roberta.py new file mode 100644 index 000000000..f51771eff --- /dev/null +++ b/src/levanter/models/roberta.py @@ -0,0 +1,847 @@ +import dataclasses +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Tuple, Type, Union + +import equinox as eqx +import jax +import jax.numpy as jnp +import jax.random as jrandom +from jaxtyping import PRNGKeyArray 
+ +import haliax as hax +import haliax.nn as hnn +from haliax import Axis, AxisSpec, NamedArray +from haliax.jax_utils import maybe_rng_split, named_call, shaped_rng_split +from haliax.nn.scan import Stacked + +from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig +from levanter.compat.torch_serialization import ( + StateDict, + StateDictSerializationMixin, + apply_prefix, + flatten_linear_layers, + stack_state_dict, + unflatten_linear_layers, + unstack_state_dict, +) +from levanter.logging import silence_transformer_nag +from levanter.models.attention import AttentionBackend, AttentionMask, simple_attention_with_dropout +from levanter.models.gpt2 import ACT2FN +from levanter.models.lm_model import LmConfig, LmHeadModel +from levanter.types import BlockFoldable +from levanter.utils.flop_utils import lm_flops_per_token + +silence_transformer_nag() +from transformers import PretrainedConfig as HfConfig +from transformers import RobertaConfig as HfRobertaConfig + + + +@LmConfig.register_subclass("roberta") +@dataclass(frozen=True) +class RobertaConfig(HFCompatConfig): + r""" + + Adapted from HuggingFace RobertaConfig, description below + + + This is the configuration class to store the configuration of a [`RobertaModel`] or a [`TFRobertaModel`]. It is + used to instantiate a RoBERTa model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the RoBERTa + [FacebookAI/roberta-base](https://huggingface.co/FacebookAI/roberta-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50265): + Vocabulary size of the RoBERTa model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`]. 
+ initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658). + is_decoder (`bool`, *optional*, defaults to `False`): + Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + classifier_dropout (`float`, *optional*): + The dropout ratio for the classification head. + + Examples: + + ```python + >>> from transformers import RobertaConfig, RobertaModel + + >>> # Initializing a RoBERTa configuration + >>> configuration = RobertaConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = RobertaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + vocab_size: int = 50265 + hidden_size: int = 768 + num_hidden_layers: int = 12 + num_attention_heads: int = 12 + intermediate_size: int = 3072 + hidden_act: str = "gelu" + hidden_dropout_prob: float = 0.1 + attention_probs_dropout_prob: float = 0.1 + max_position_embeddings: int = 512 + type_vocab_size: int = 2 + initializer_range: float = 0.02 + layer_norm_eps: float = 1e-12 + pad_token_id: int = 1 + bos_token_id: int = 0 + eos_token_id: int = 2 + position_embedding_type: Optional[str] = "absolute" + use_cache: bool = False + classifier_dropout: Optional[float] = None + + scan_layers: bool = True + gradient_checkpointing: bool = True + + reference_checkpoint: str = "FacebookAI/roberta-base" + tokenizer: Optional[str] = None + + # Axes + Pos = property(lambda self: Axis(name="position", size=self.max_position_embeddings)) + KeyPos = property(lambda self: self.Pos.alias("key_position")) + Embed = property(lambda self: Axis(name="embed", size=self.hidden_size)) + EmbedAtt = property(lambda self: self.Embed.alias("embed_att")) + FinalEmbed = property(lambda self: self.Embed.alias("final_embed")) + Heads = property(lambda self: Axis(name="heads", size=self.num_attention_heads)) + Layers = property(lambda self: Axis(name="layers", size=self.num_hidden_layers)) + Mlp = property(lambda self: Axis(name="mlp", size=self.intermediate_size)) + HeadSize = property(lambda self: Axis(name="head_size", size=self.hidden_size // self.num_attention_heads)) + + + @classmethod + def from_hf_config(cls, hf_config: HfConfig) -> "RobertaConfig": + return RobertaConfig( + vocab_size = hf_config.vocab_size, + hidden_size = hf_config.hidden_size, + num_hidden_layers = hf_config.num_hidden_layers, + num_attention_heads = hf_config.num_attention_heads, + intermediate_size = hf_config.intermediate_size, + hidden_act = hf_config.hidden_act, + hidden_dropout_prob= 
hf_config.hidden_dropout_prob, + attention_probs_dropout_prob = hf_config.attention_probs_dropout_prob, + max_position_embeddings = hf_config.max_position_embeddings, + type_vocab_size = hf_config.type_vocab_size, + initializer_range = hf_config.initializer_range, + layer_norm_eps = hf_config.layer_norm_eps, + pad_token_id = hf_config.pad_token_id, + bos_token_id = hf_config.bos_token_id, + eos_token_id = hf_config.eos_token_id, + position_embedding_type = hf_config.position_embedding_type, + use_cache = hf_config.use_cache, + classifier_dropout = hf_config.classifier_dropout, + ) + + def hf_checkpoint_converter(self) -> HFCheckpointConverter["RobertaConfig"]: # type: ignore + return HFCheckpointConverter( + self.__class__, + reference_checkpoint=self.reference_checkpoint, + trust_remote_code=True, + tokenizer=self.tokenizer if self.tokenizer else self.reference_checkpoint, + HfConfigClass=HfRobertaConfig, + ) + + def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None) -> HfRobertaConfig: + """Convert to HuggingFace's LlamaConfig + + Args: + vocab_size (int, optional): Vocabulary size of the tokenizer. Defaults to 32000. + config_overrides (dict, optional): Overrides for the config. Defaults to None. + + Returns: + HfRobertaConfig: HuggingFace's RobertaConfig + """ + + if config_overrides is None: + config_overrides = {} + + return HfRobertaConfig( + vocab_size = vocab_size, + hidden_size = self.hidden_size, + num_hidden_layers = self.num_hidden_layers, + num_attention_heads = self.num_attention_heads, + intermediate_size = self.intermediate_size, + hidden_act = self.hidden_act, + hidden_dropout_prob = self.hidden_dropout_prob, + attention_probs_dropout_prob = self.attention_probs_dropout_prob, + max_position_embeddings = self.max_position_embeddings, + type_vocab_size = self.type_vocab_size, + initializer_range = self.initializer_range, + layer_norm_eps = self.layer_norm_eps, + pad_token_id = self.pad_token_id, + bos_token_id = self.bos_token_id, + eos_token_id = self.eos_token_id, + position_embedding_type = self.position_embedding_type, + use_cache = self.use_cache, + classifier_dropout = self.classifier_dropout, + ) + + @property + def model_type(self) -> Type["RobertaModel"]: + return RobertaModel + + def flops_per_token(self, vocab_size: int): + return lm_flops_per_token( + hidden_dim=self.hidden_size, + intermediate_dim=self.intermediate_size, + num_layers=self.num_hidden_layers, + num_kv_heads=self.num_attention_heads, + num_heads=self.num_attention_heads, + seq_len=self.max_position_embeddings, + vocab_size=vocab_size, + glu=True, + ) + +class RobertaSelfAttention(eqx.Module, StateDictSerializationMixin): + + config: RobertaConfig + Heads: Axis + HeadSize: Axis + EmbedAtt: Axis + + q_proj: hnn.Linear + k_proj: hnn.Linear + v_proj: hnn.Linear + + dropout: hnn.Dropout + position_embedding_type: Optional[str] + + Pos: Axis + KeyPos: Axis + distance_embedding: Optional[hnn.Embedding] + + @staticmethod + def init(config: RobertaConfig, *, key) -> "RobertaSelfAttention": + Embed = config.Embed + EmbedAtt = config.EmbedAtt + + k_q, k_k, k_v, k_e = jrandom.split(key, 4) + q_proj = hnn.Linear.init(In=Embed, Out=EmbedAtt, key=k_q, out_first=True) + k_proj = hnn.Linear.init(In=Embed, Out=EmbedAtt, key=k_k, out_first=True) + v_proj = hnn.Linear.init(In=Embed, Out=EmbedAtt, key=k_v, out_first=True) + + dropout = hnn.Dropout(config.attention_probs_dropout_prob) + + distance_embedding = None + position_embedding_type = config.position_embedding_type + + if 
position_embedding_type == "relative_key" or position_embedding_type == "relative_key_query": + RelPos = Axis("rel_pos", 2 * config.max_position_embeddings - 1) + distance_embedding = hnn.Embedding.init(RelPos, config.HeadSize, k_e) + + return RobertaSelfAttention(config, config.Heads, config.HeadSize, EmbedAtt, + q_proj, k_proj, v_proj, + dropout, position_embedding_type, + config.Pos, config.KeyPos, distance_embedding, + ) + + def _state_dict_key_map(self) -> Dict[str, Optional[str]]: + return {"q_proj": "query", "k_proj": "key", "v_proj": "value"} + + def _rope_scale_factor(self) -> float: + # hasattr for gemma and I'm feeling lazy + if hasattr(self.config, "rope_scaling") and self.config.rope_scaling is not None: + assert self.config.rope_scaling["type"] == "linear" + return self.config.rope_scaling["factor"] + return 1.0 + + def transpose_for_scores(self, x: NamedArray) -> NamedArray: + # Makes sure to have the correct output order as well + y = hax.rearrange(x, "... position (embed_att: heads head_size) -> ... heads position head_size", heads=self.Heads, head_size=self.HeadSize) + return y + + @named_call + def __call__( + self, + hidden_states: NamedArray, + attention_mask: Optional[NamedArray] = None, + *, + key = None + ) -> Tuple[NamedArray]: + + query_layer = self.transpose_for_scores(self.q_proj(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(hidden_states)) + value_layer = self.transpose_for_scores(self.v_proj(hidden_states)) + + if self.position_embedding_type == "rope": + cos, sin = llama_rotary_pos_emb( + self.config.HeadSize, hidden_states.resolve_axis("position"), scale=self._rope_scale_factor() + ) + query_layer, key_layer = _apply_rotary_pos_emb(query_layer, key_layer, cos, sin) + + key_layer = key_layer.rename({"position": "key_position"}) + value_layer = value_layer.rename({"position": "key_position"}) + + attention_scores = hax.dot(query_layer, key_layer, axis=self.HeadSize) # aka hax.einsum("bhld, bhrd -> bhlr") + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + Left = self.Pos # Queries + Right = self.KeyPos # Keys + + position_ids_l = hax.arange(Left).broadcast_to((Left,Right)) + position_ids_r = hax.arange(Right).broadcast_to((Left,Right)) + + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.Pos.size) + + if self.position_embedding_type == "relative_key": + relative_position_scores = hax.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = hax.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = hax.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores /= jnp.sqrt(self.HeadSize.size) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function) + # Attention_mask should have shape Batch Pos, so it should broadcast to shape Batch Heads Pos KeyPos for summation + attention_scores = attention_scores + attention_mask + + attention_probs = hnn.softmax(attention_scores, axis=self.KeyPos) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer 
paper. + attention_probs = self.dropout(attention_probs, key=key) + + hax.dot(query_layer, key_layer, axis=self.HeadSize) + + context_layer = hax.dot(attention_probs, value_layer, axis=self.KeyPos) + + outputs = hax.rearrange(context_layer, ("... heads position head_size -> ... position (embed_att: heads head_size)"), heads=self.Heads, head_size=self.HeadSize) + + return outputs + +class RobertaSelfOutput(eqx.Module, StateDictSerializationMixin): + dense: hnn.Linear + LayerNorm: hnn.LayerNorm + dropout: hnn.Dropout + + @staticmethod + def init(config: RobertaConfig, *, key) -> "RobertaSelfOutput": + Embed = config.Embed + EmbedAtt = config.EmbedAtt + dense = hnn.Linear.init(In=EmbedAtt, Out=Embed, key=key, out_first=True) + LayerNorm = hnn.LayerNorm.init(axis=Embed, eps=config.layer_norm_eps) + dropout = hnn.Dropout(config.hidden_dropout_prob) + return RobertaSelfOutput(dense, LayerNorm, dropout) + + @named_call + def __call__(self, hidden_states: NamedArray, input: NamedArray,*, key) -> NamedArray: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, key=key) + hidden_states = self.LayerNorm(hidden_states + input) + return hidden_states + +class RobertaAttention(eqx.Module, StateDictSerializationMixin): + self_attn: RobertaSelfAttention + output: RobertaSelfOutput + + @staticmethod + def init(config: RobertaConfig, *, key) -> "RobertaAttention": + k_a, k_o = jrandom.split(key, 2) + + self_attn = RobertaSelfAttention.init(config, key=k_a) + output = RobertaSelfOutput.init(config, key=k_o) + + return RobertaAttention(self_attn, output) + + def _state_dict_key_map(self) -> Dict[str, Optional[str]]: + return {"self_attn": "self"} + + @named_call + def __call__( + self, + hidden_states: NamedArray, + attention_mask: Optional[NamedArray] = None, + *, + key + ) -> NamedArray: + k_a, k_o = maybe_rng_split(key, 2) + + self_outputs = self.self_attn( + hidden_states, + attention_mask, + key=k_a + ) + attention_output = self.output(self_outputs, hidden_states, key=k_o) + return attention_output + +class RobertaIntermediate(eqx.Module, StateDictSerializationMixin): + dense: hnn.Linear + intermediate_act_fn: Callable = eqx.static_field() + + @staticmethod + def init(config, *, key) -> "RobertaIntermediate": + dense = hnn.Linear.init(config.Embed, config.Mlp, key=key, out_first=True) + if isinstance(config.hidden_act, str): + intermediate_act_fn = ACT2FN[config.hidden_act] + else: + intermediate_act_fn = config.hidden_act + + return RobertaIntermediate(dense, intermediate_act_fn) + + @named_call + def __call__(self, hidden_states: NamedArray, *, key = None) -> NamedArray: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + +class RobertaOutput(eqx.Module, StateDictSerializationMixin): + dense: hnn.Linear + LayerNorm: hnn.LayerNorm + dropout: hnn.Dropout + + @staticmethod + def init(config: RobertaConfig, *, key) -> "RobertaSelfOutput": + Embed = config.Embed + dense = hnn.Linear.init(In=config.Mlp, Out=Embed, key=key, out_first=True) + LayerNorm = hnn.LayerNorm.init(axis=Embed, eps=config.layer_norm_eps) + dropout = hnn.Dropout(config.hidden_dropout_prob) + return RobertaSelfOutput(dense, LayerNorm, dropout) + + @named_call + def __call__(self, hidden_states: NamedArray, input: NamedArray, *, key) -> NamedArray: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, key=key) + hidden_states = self.LayerNorm(hidden_states + input) + return hidden_states + 
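+# A minimal usage sketch for the attention stack above (hypothetical variable
+# names `cfg` and `k`, not part of this file's API): given a RobertaConfig and a
+# PRNG key, the attention block consumes hidden states along (position, embed)
+# together with an *additive* attention mask (0.0 where attention is allowed, a
+# very negative value where it is masked), which RobertaModel.__call__
+# precomputes once for all layers further below.
+#
+#   k_init, k_drop = jrandom.split(k)
+#   attn = RobertaAttention.init(cfg, key=k_init)
+#   x = hax.random.normal(k, (cfg.Pos, cfg.Embed))
+#   additive_mask = hax.zeros((cfg.KeyPos,))                # nothing masked out
+#   y = attn(x, attention_mask=additive_mask, key=k_drop)   # same axes as x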
+class RobertaLayer(eqx.Module, StateDictSerializationMixin): + attention: RobertaAttention + intermediate: RobertaIntermediate + output: RobertaOutput + + @staticmethod + def init(config: RobertaConfig, *, key) -> "RobertaLayer": + k_a, k_i, k_o = jrandom.split(key, 3) + + attention = RobertaAttention.init(config, key=k_a) + intermediate = RobertaIntermediate.init(config, key=k_i) + output = RobertaOutput.init(config, key=k_o) + + return RobertaLayer(attention, intermediate, output) + + @named_call + def __call__( + self, + hidden_states: NamedArray, + attention_mask: Optional[NamedArray] = None, + *, + key + ) -> Tuple[NamedArray]: + k_a, k_o = maybe_rng_split(key, 2) + + attention_output = self.attention( + hidden_states, + attention_mask, + key=k_a, + ) + + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output, key=k_o) + + return layer_output + + +class RobertaEncoder(eqx.Module, StateDictSerializationMixin): + config: RobertaConfig + layer: BlockFoldable[RobertaLayer] + + @staticmethod + def init(config: RobertaConfig, *, key) -> "RobertaEncoder": + S = Stacked + if not config.scan_layers: + from haliax.nn.scan import BlockSeq + + S = BlockSeq + + layer = S.init(config.Layers, RobertaLayer, gradient_checkpointing=config.gradient_checkpointing)( + config, + key=shaped_rng_split(key, config.num_hidden_layers), #TODO: config.gradient_checkpointing + ) + + return RobertaEncoder(config, layer) + + @named_call + def __call__( + self, + hidden_states: NamedArray, + attention_mask: Optional[NamedArray] = None, + *, + key + ) -> Tuple[NamedArray]: + + keys = maybe_rng_split(key, self.config.num_hidden_layers) if key is not None else None + x = self.layer.fold(hidden_states, attention_mask, key=keys) + + return x + + def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): + if isinstance(self.layer, Stacked): + state_dict = stack_state_dict(state_dict, prefix=apply_prefix(prefix, "layer")) + + out = super().from_state_dict(state_dict, prefix=prefix) + return out + + def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: + my_state_dict: StateDict = {} + super().update_state_dict(my_state_dict, prefix=prefix) + + if isinstance(self.layer, Stacked): + stacked_dict = unstack_state_dict(my_state_dict, prefix=apply_prefix(prefix, "layer")) + state_dict.update(stacked_dict) + else: + state_dict.update(my_state_dict) + + return state_dict + +class RobertaEmbedding(eqx.Module, StateDictSerializationMixin): + Vocab: Axis = eqx.static_field() + Pos: Axis = eqx.static_field() + + word_embeddings: hnn.Embedding + position_embeddings: hnn.Embedding + token_type_embeddings: Optional[hnn.Embedding] + padding_idx: NamedArray + + LayerNorm: hnn.LayerNorm + dropout: hnn.Dropout + position_embedding_type: Optional[str] + + @staticmethod + def init(Vocab: Axis, config: RobertaConfig, *, key) -> "RobertaEmbedding": + key_w, key_p, key_t = jrandom.split(key, 3) + + padding_idx = config.pad_token_id + + word_embeddings = hnn.Embedding.init(Vocab, config.Embed, key=key_w) # padding_idx not specified + position_embeddings = hnn.Embedding.init(config.Pos, config.Embed, key=key_p) + + Token = hax.Axis("token", config.type_vocab_size) + + token_type_embeddings = hnn.Embedding.init(Token, config.Embed, key=key_t) + + LayerNorm = hnn.LayerNorm.init(config.Embed, config.layer_norm_eps) + dropout = hnn.Dropout(config.hidden_dropout_prob) + + return RobertaEmbedding(Vocab, config.Pos, 
word_embeddings, position_embeddings, token_type_embeddings, padding_idx, LayerNorm, dropout, config.position_embedding_type) + + def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0): + mask = hax.not_equal(input_ids, self.padding_idx) * 1 + incremental_indices = (hax.cumsum(mask, axis=self.Pos).astype(mask) + past_key_values_length) * mask + return incremental_indices + self.padding_idx + + def create_position_ids_from_inputs_embeds(self, input_axes, PosInput): + position_ids = hax.arange(axis = PosInput, start = self.padding_idx + 1, dtype=jnp.int32) + return hax.broadcast_to(position_ids, input_axes) + + @named_call + def embed(self, input_ids=None, token_type_ids=None, position_ids=None, input_embeds=None, past_key_values_length=0, *, key = None): + """ + Note: When inputting your own embeds into input_embeds, make sure that the embeds axis has the name "embed" + for compatibility with the position_id creation function. Make sures its length is not equal to + """ + + # Get Axes + if input_ids is not None: + input_axes = input_ids.axes + else: + input_axes = hax.eliminate_axes(input_embeds.axes, "embed") + + # Get position_ids + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = self.create_position_ids_from_input_ids(input_ids, past_key_values_length) + else: + position_ids = self.create_position_ids_from_inputs_embeds(input_axes, input_embeds.resolve_axis("position")) + + # Get token_type_ids + if token_type_ids is None: + token_type_ids = hax.zeros(input_axes, dtype=jnp.int32) + + if input_embeds is None: + input_embeds = self.word_embeddings(input_ids) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = input_embeds + token_type_embeddings + + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings, key=key) + return embeddings + +class RobertaPooler(eqx.Module, StateDictSerializationMixin): + dense: hnn.Linear + config: RobertaConfig + + @staticmethod + def init(config: RobertaConfig, *, key): + dense = hnn.Linear.init(In=config.Embed, Out=config.FinalEmbed, key=key, out_first=True) + + return RobertaPooler(dense, config) + + @named_call + def __call__(self, hidden_states: NamedArray, *, key=None) -> NamedArray: + first_token = hidden_states[{"position" : 0}] + x = self.dense(first_token, key=key).rename({self.config.FinalEmbed: self.config.Embed}) + x = hax.tanh(x) + return x + + +class RobertaModel(eqx.Module, StateDictSerializationMixin): + encoder: RobertaEncoder + embeddings: RobertaEmbedding + pooler : Optional[RobertaPooler] + + @staticmethod + def init(Vocab: Axis, config: RobertaConfig, add_pooling_layer: bool = True, *, key) -> "RobertaModel": + k_t, k_emb, k_p = jrandom.split(key, 3) + encoder = RobertaEncoder.init(config=config, key=k_t) + embeddings = RobertaEmbedding.init(Vocab, config, key=k_emb) + + pooler = RobertaPooler.init(config, key=k_p) if add_pooling_layer else None + return RobertaModel(encoder, embeddings, pooler) + + @property + def config(self): + return self.encoder.config + + @property + def vocab_size(self) -> int: + return self.Vocab.size + + @property + def Vocab(self) -> Axis: + return self.embeddings.Vocab + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def 
set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + @named_call + def __call__( + self, + input_ids: Optional[NamedArray] = None, + token_type_ids: Optional[NamedArray] = None, + position_ids: Optional[NamedArray] = None, + input_embeds: Optional[NamedArray] = None, + attention_mask: Optional[NamedArray] = None, + *, + key, + ) -> Tuple[NamedArray]: + """ + Not Used: meant to be used to improve performance in decoder implementations + + head_mask: Optional[NamedArray] = None, + encoder_hidden_states: Optional[NamedArray] = None, + encoder_attention_mask: Optional[NamedArray] = None, + past_key_values_length = 0, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + """ + k_emb, k_e, k_p = maybe_rng_split(key, 3) + + if input_ids is not None and input_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_axes = input_ids.axes + elif input_embeds is not None: + input_axes = hax.eliminate_axes(input_embeds.axes, "embed") + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if attention_mask is None: + attention_mask = hax.ones(input_axes) + + # Attention mask from mask to actual numbers + attention_mask = (attention_mask == 0) * -jnp.inf + + embedding_output = self.embeddings.embed(input_ids, token_type_ids, position_ids, input_embeds, key=k_emb) + sequence_output = self.encoder(embedding_output, attention_mask=attention_mask, key=k_e) + + pooled_output = self.pooler(sequence_output, key=k_p) if self.pooler is not None else None + + return (sequence_output, pooled_output) + +class RobertaLMHead(eqx.Module, StateDictSerializationMixin): + """Roberta Head for masked language modeling.""" + + dense: hnn.Linear + layer_norm: hnn.LayerNorm + decoder: hnn.Linear + config: RobertaConfig + + @staticmethod + def init(Vocab: Axis, config: RobertaConfig, *, key): + k_dense, k_decoder = jrandom.split(key, 2) + Embed = config.Embed + + dense = hnn.Linear.init(In=Embed, Out=config.FinalEmbed, key=k_dense, out_first=True) + layer_norm = hnn.LayerNorm.init(axis=Embed, eps=config.layer_norm_eps) + + decoder = hnn.Linear.init(Embed, Vocab, key=k_decoder, out_first=True) + + # idk what this is: TODO + # self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + # self.decoder.bias = self.bias + + return RobertaLMHead(dense, layer_norm, decoder, config) + + @named_call + def __call__(self, features: NamedArray, *, key=None) -> NamedArray: + x = self.dense(features).rename({self.config.FinalEmbed: self.config.Embed}) + x = hnn.gelu(x, approximate=False) + x = self.layer_norm(x) + + # project back to size of vocabulary with bias + x = self.decoder(x) + + return x + +class RobertaForMaskedLM(eqx.Module, StateDictSerializationMixin): + roberta: RobertaModel + lm_head: RobertaLMHead + Vocab: Axis + + @classmethod + def init(self, Vocab: Axis, config: RobertaConfig, *, key): + + # if config.is_decoder: + # raise AttributeError("Model is being run as a MaskedLM aka an encoder model, but is_decoder is true") + + key_rob, key_head = jrandom.split(key, 2) + roberta = RobertaModel.init(Vocab, config, add_pooling_layer=False, key=key_rob) + lm_head = RobertaLMHead.init(Vocab, config, key=key_head) + + return RobertaForMaskedLM(roberta, lm_head, Vocab) + + def get_output_embeddings(self): + return self.lm_head.decoder + + def set_output_embeddings(self, 
new_embeddings): + self.lm_head.decoder = new_embeddings + + @named_call + def __call__( + self, + input_ids: Optional[NamedArray] = None, + attention_mask: Optional[NamedArray] = None, + token_type_ids: Optional[NamedArray] = None, + position_ids: Optional[NamedArray] = None, + input_embeds: Optional[NamedArray] = None, + *, + key=None + ) -> Tuple[NamedArray]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + kwargs (`Dict[str, any]`, optional, defaults to *{}*): + Used to hide legacy arguments that have been deprecated. + """ + + k_rob, k_lm = maybe_rng_split(key, 2) + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + input_embeds=input_embeds, + key=k_rob + ) + + prediction_scores = self.lm_head(outputs[0], key=k_lm) + + return prediction_scores + + +def _rotate_half(x: NamedArray) -> NamedArray: + """Rotates half of the hidden dims of the input and concatenates them.""" + HeadSize = x.axes[-1] + x1 = x[HeadSize, : HeadSize.size // 2] + x2 = x[HeadSize, HeadSize.size // 2 :] + out = hax.concatenate(HeadSize, (-x2, x1)) + return out + + +def _apply_rotary_pos_emb( + q: NamedArray, # [batch, position, kv_heads, q_heads_per_group, head_size] + k: NamedArray, # [batch, position, kv_heads, head_size] + cos: NamedArray, # [position, head_size] + sin: NamedArray, # [position, head_size] +) -> Tuple[NamedArray, NamedArray]: + """Applies rotary position embedding to q and k.""" + q_embed = q * cos + _rotate_half(q) * sin + k_embed = k * cos + _rotate_half(k) * sin + return q_embed, k_embed + + +def llama_rotary_pos_emb( + HeadSize: Axis, Pos: Axis, base: float = 10000, scale: float = 1.0 +) -> Tuple[NamedArray, NamedArray]: + with jax.ensure_compile_time_eval(): + HeadHalfSize = HeadSize.resize(HeadSize.size // 2) + inv_freq: NamedArray = 1.0 / (base ** (hax.arange(HeadHalfSize, step=2) / HeadSize.size)) + + position_ids: NamedArray = hax.arange(Pos) / scale + + freqs = position_ids * inv_freq.broadcast_axis(Pos) + # This is different from the paper but aligns with HF implementation: + # It uses a different permutation in order to obtain the same calculation + emb = hax.concatenate(HeadSize, (freqs, freqs)) + cos = hax.cos(emb) + sin = hax.sin(emb) + # This is different from the paper but aligns with HF implementation: + return cos, sin \ No newline at end of file From 399e08c5024a7d4f80e8d90e4375b19c7ad13366 Mon Sep 17 00:00:00 2001 From: Prady Saligram Date: Sat, 31 Aug 2024 17:23:09 -0700 Subject: [PATCH 10/29] Adds compute_loss to roberta and changes positional ids to begin from 0 --- config/roberta-tiny.yaml | 39 +++++++++++++++++++++++++++++++++ config/roberta.yaml | 2 +- src/levanter/data/text.py | 21 ++++++++++-------- src/levanter/main/train_mlm.py | 12 ++++++---- src/levanter/models/lm_model.py | 8 +++---- src/levanter/models/roberta.py | 2 +- 6 files changed, 65 insertions(+), 19 deletions(-) create mode 100644 config/roberta-tiny.yaml diff --git a/config/roberta-tiny.yaml b/config/roberta-tiny.yaml new file mode 100644 index 000000000..4b61ff7e4 --- /dev/null +++ b/config/roberta-tiny.yaml @@ -0,0 +1,39 @@ +data: + id: dlwh/wikitext_103_detokenized +# train_urls: +# - 
"gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" +# validation_urls: +# - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" + cache_dir: "cache/roberta-tiny" + tokenizer: "roberta-base" + +model: + type: roberta + vocab_size: 50265 + hidden_size: 32 + intermediate_size: 64 + num_hidden_layers: 4 + num_attention_heads: 2 + max_position_embeddings: 512 + hidden_act: "gelu" + hidden_dropout_prob: 0.1 + attention_probs_dropout_prob: 0.1 + gradient_checkpointing: true + +trainer: + tracker: + - type: wandb + project: "levanter" + tags: ["openwebtext", "roberta", "itest"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + per_device_parallelism: -1 + + train_batch_size: 32 + num_train_steps: 20000 + +optimizer: + learning_rate: 1E-3 + weight_decay: 0.1 + warmup: 0.01 \ No newline at end of file diff --git a/config/roberta.yaml b/config/roberta.yaml index 81f5d4d35..c854f8109 100644 --- a/config/roberta.yaml +++ b/config/roberta.yaml @@ -35,4 +35,4 @@ trainer: optimizer: learning_rate: 1E-3 weight_decay: 0.1 - warmup: 0.01 + warmup: 0.01 \ No newline at end of file diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index dfc7df4ea..d0898f2f0 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -68,10 +68,11 @@ def __init__( dataset: ShardableDataset[np.ndarray], QPos: Axis, KPos: Axis, + mask_token_id: int, mask_prob: float = 0.15, noise_prob: float = 0.1, key: Optional[PRNGKeyArray] = None, - ignore_index: Optional[int] = DEFAULT_IGNORE_INDEX, + # ignore_index: Optional[int] = DEFAULT_IGNORE_INDEX, ): self.dataset = dataset self.QPos = QPos @@ -79,14 +80,16 @@ def __init__( self.mask_prob = mask_prob self.noise_prob = noise_prob self.key = key - self.ignore_id = ignore_index if ignore_index is not None else DEFAULT_IGNORE_INDEX + self.mask_token_id = mask_token_id if self.mask_prob > 0.0 and self.key is None: raise ValueError("must provide key if mask_prob > 0.0") def shard(self, shard_id: int, num_shards: int) -> "MaskedLmDataset": return MaskedLmDataset( - self.dataset.shard(shard_id, num_shards), self.QPos, self.KPos, self.mask_prob, self.noise_prob, self.key, self.ignore_id + self.dataset.shard(shard_id, num_shards), self.QPos, self.KPos, + self.mask_token_id, + self.mask_prob, self.noise_prob, self.key ) def __iter__(self) -> Iterator[MaskedLmExample]: @@ -105,13 +108,13 @@ def _create_mlm_example(tokens, key): mask = jax.random.bernoulli(this_key, self.mask_prob, mask_shape) rand = jax.random.uniform(this_key, mask_shape) - mask_token = jnp.where(rand < 0.8, self.ignore_id, tokens_array) + mask_token = jnp.where(rand < 0.8, self.mask_token_id, tokens_array) random_tokens = jax.random.randint(this_key, mask_shape, 0, tokens_array.max() + 1) mask_token = jnp.where((rand >= 0.8) & (rand < 0.8 + self.noise_prob), random_tokens, mask_token) masked_tokens = jnp.where(mask, mask_token, tokens_array) - # Set targets to the original tokens where mask is True, otherwise set to ignore_id - targets = jnp.where(mask, tokens_array, self.ignore_id) + # Set targets to the original tokens where mask is True, otherwise set to mask_token_id + targets = jnp.where(mask, tokens_array, self.mask_token_id) masked_tokens_named = hax.named(masked_tokens, self.QPos) targets_named = hax.named(targets, self.QPos) @@ -119,13 +122,13 @@ def _create_mlm_example(tokens, key): attn_mask_shape = (tokens_array.shape[0], tokens_array.shape[0]) attn_mask = hax.named(jnp.ones(attn_mask_shape, dtype=jnp.bool_), (self.QPos, 
self.KPos)) - example = MaskedLmExample.masked_lm(tokens=masked_tokens_named, targets=targets_named, ignore_id=self.ignore_id, attn_mask=attn_mask) + example = MaskedLmExample.masked_lm(tokens=masked_tokens_named, targets=targets_named, mask_token_id=self.mask_token_id, attn_mask=attn_mask) else: targets_named = hax.named(targets, self.QPos) attn_mask_shape = (tokens_array.shape[0], tokens_array.shape[0]) attn_mask = hax.named(jnp.ones(attn_mask_shape, dtype=jnp.bool_), (self.QPos, self.KPos)) - example = MaskedLmExample.masked_lm(tokens=tokens, targets=targets_named, ignore_id=self.ignore_id, attn_mask=attn_mask) + example = MaskedLmExample.masked_lm(tokens=tokens, targets=targets_named, mask_token_id=self.mask_token_id, attn_mask=attn_mask) return example @@ -900,4 +903,4 @@ def build_caches( @property def sources(self) -> dict[str, LMDatasetSourceConfig]: - return self.configs + return self.configs \ No newline at end of file diff --git a/src/levanter/main/train_mlm.py b/src/levanter/main/train_mlm.py index 80e941d5a..435abe5bf 100644 --- a/src/levanter/main/train_mlm.py +++ b/src/levanter/main/train_mlm.py @@ -7,6 +7,7 @@ from dataclasses import dataclass, field from typing import Optional, Union +import jax import jax.random as jrandom import haliax as hax @@ -86,7 +87,7 @@ def main(config: TrainMlmConfig): # 1. Sets the device mesh # 2. Sets the axis mapping (for fsdp) # 3. Sets the global metrics tracker - with Trainer(config.trainer, optimizer) as trainer: + with Trainer(config.trainer, optimizer) as trainer, jax.disable_jit(True): # randomness in jax is tightly controlled by "keys" which are the states of the random number generators # this makes deterministic training pretty easy seed = config.trainer.seed @@ -108,8 +109,11 @@ def main(config: TrainMlmConfig): KeyPos = config.model.KeyPos tagged_eval_datasets = config.data.tagged_eval_sets(Pos.size) + mask_id = tokenizer.mask_token_id train_dataset = MaskedLmDataset( - config.data.train_set(Pos.size, key=data_key), Pos, KeyPos, mask_prob=config.mlm_prob, key=data_key, ignore_index=config.data.ignore_token_id + config.data.train_set(Pos.size, key=data_key), Pos, KeyPos, + mask_token_id=mask_id, + mask_prob=config.mlm_prob, key=data_key, #ignore_index=config.data.ignore_token_id ) # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to @@ -150,7 +154,7 @@ def main(config: TrainMlmConfig): logger.warning("No evaluation datasets provided.") else: masked_datasets = [ - (MaskedLmDataset(ds, Pos, KeyPos, mask_prob=config.mlm_prob, key=data_key, ignore_index=config.data.ignore_token_id), tags) + (MaskedLmDataset(ds, Pos, KeyPos, mask_prob=config.mlm_prob, key=data_key, mask_token_id=mask_id), tags) for ds, tags in tagged_eval_datasets ] max_eval_examples_per_ds = config.trainer.max_eval_batches @@ -198,4 +202,4 @@ def compute_log_probs(model, example): trainer.train(state, train_loader) if __name__ == "__main__": - levanter.config.main(main)() + levanter.config.main(main)() \ No newline at end of file diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index c36e0e622..01e1252a8 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -24,7 +24,7 @@ class MaskedLmExample(eqx.Module): @staticmethod def masked_lm( - tokens: hax.NamedArray, targets: hax.NamedArray, attn_mask: hax.NamedArray, ignore_id: Optional[int] = None + tokens: hax.NamedArray, targets: hax.NamedArray, attn_mask: hax.NamedArray, mask_token_id: Optional[int] = 
None ) -> "MaskedLmExample": if tokens.ndim != 1: raise ValueError("tokens must be a 1D array") @@ -40,8 +40,8 @@ def masked_lm( mask = tokens.array != targets.array loss_mask = hax.named(mask.astype(jnp.float32), Pos) - if ignore_id is not None: - ignore_mask = targets.array != ignore_id + if mask_token_id is not None: + ignore_mask = targets.array != mask_token_id loss_mask = loss_mask * hax.named(ignore_mask.astype(jnp.float32), Pos) return MaskedLmExample(tokens=tokens, targets=targets, loss_mask=loss_mask, attn_mask=attn_mask) @@ -156,4 +156,4 @@ def compute_loss( @property def vocab_size(self) -> int: - return self.Vocab.size + return self.Vocab.size \ No newline at end of file diff --git a/src/levanter/models/roberta.py b/src/levanter/models/roberta.py index f51771eff..ed23708f7 100644 --- a/src/levanter/models/roberta.py +++ b/src/levanter/models/roberta.py @@ -574,7 +574,7 @@ def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0 return incremental_indices + self.padding_idx def create_position_ids_from_inputs_embeds(self, input_axes, PosInput): - position_ids = hax.arange(axis = PosInput, start = self.padding_idx + 1, dtype=jnp.int32) + position_ids = hax.arange(axis = PosInput, start = 0, dtype=jnp.int32) return hax.broadcast_to(position_ids, input_axes) @named_call From cd4118c77e0a16e0a5a69713dc9fb9d8efb1f751 Mon Sep 17 00:00:00 2001 From: JulienDarve Date: Wed, 4 Sep 2024 10:42:34 -0700 Subject: [PATCH 11/29] Investingating precision loss over time within the model using output_hidden_states implementation in jax model --- src/levanter/models/roberta.py | 57 +- src/levanter/models/testing.ipynb | 939 +++++++++++++++--------------- 2 files changed, 487 insertions(+), 509 deletions(-) diff --git a/src/levanter/models/roberta.py b/src/levanter/models/roberta.py index 3db771350..cbaaf6754 100644 --- a/src/levanter/models/roberta.py +++ b/src/levanter/models/roberta.py @@ -12,7 +12,7 @@ import haliax.nn as hnn from haliax import Axis, AxisSpec, NamedArray from haliax.jax_utils import maybe_rng_split, named_call, shaped_rng_split -from haliax.nn.scan import Stacked +from haliax.nn.scan import BlockSeq from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig from levanter.compat.torch_serialization import ( @@ -482,27 +482,26 @@ def __call__( intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output, key=k_o) - return layer_output + # jax.debug.print("{layer_output}", layer_output=layer_output) + + return (layer_output, layer_output) class RobertaEncoder(eqx.Module, StateDictSerializationMixin): config: RobertaConfig layer: BlockFoldable[RobertaLayer] + output_hidden_states: bool @staticmethod - def init(config: RobertaConfig, *, key) -> "RobertaEncoder": - S = Stacked - if not config.scan_layers: - from haliax.nn.scan import BlockSeq - - S = BlockSeq + def init(config: RobertaConfig, output_hidden_states: bool, *, key) -> "RobertaEncoder": + S = BlockSeq layer = S.init(config.Layers, RobertaLayer, gradient_checkpointing=config.gradient_checkpointing)( config, key=shaped_rng_split(key, config.num_hidden_layers), #TODO: config.gradient_checkpointing ) - return RobertaEncoder(config, layer) + return RobertaEncoder(config, layer, output_hidden_states) @named_call def __call__( @@ -514,14 +513,15 @@ def __call__( ) -> Tuple[NamedArray]: keys = maybe_rng_split(key, self.config.num_hidden_layers) if key is not None else None - x = self.layer.fold(hidden_states, 
attention_mask, key=keys) - return x + x, intermediates = self.layer.scan(hidden_states, attention_mask, key=keys) + + if not self.output_hidden_states: + return x, None + else: + return x, intermediates def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): - if isinstance(self.layer, Stacked): - state_dict = stack_state_dict(state_dict, prefix=apply_prefix(prefix, "layer")) - out = super().from_state_dict(state_dict, prefix=prefix) return out @@ -529,11 +529,7 @@ def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) my_state_dict: StateDict = {} super().update_state_dict(my_state_dict, prefix=prefix) - if isinstance(self.layer, Stacked): - stacked_dict = unstack_state_dict(my_state_dict, prefix=apply_prefix(prefix, "layer")) - state_dict.update(stacked_dict) - else: - state_dict.update(my_state_dict) + state_dict.update(my_state_dict) return state_dict @@ -638,15 +634,16 @@ class RobertaModel(eqx.Module, StateDictSerializationMixin): encoder: RobertaEncoder embeddings: RobertaEmbedding pooler : Optional[RobertaPooler] + output_hidden_states: bool @staticmethod - def init(Vocab: Axis, config: RobertaConfig, add_pooling_layer: bool = True, *, key) -> "RobertaModel": + def init(Vocab: Axis, config: RobertaConfig, add_pooling_layer: bool = True, output_hidden_states: bool = False, *, key) -> "RobertaModel": k_t, k_emb, k_p = jrandom.split(key, 3) - encoder = RobertaEncoder.init(config=config, key=k_t) + encoder = RobertaEncoder.init(config=config, output_hidden_states=output_hidden_states, key=k_t) embeddings = RobertaEmbedding.init(Vocab, config, key=k_emb) pooler = RobertaPooler.init(config, key=k_p) if add_pooling_layer else None - return RobertaModel(encoder, embeddings, pooler) + return RobertaModel(encoder, embeddings, pooler, output_hidden_states) @property def config(self): @@ -707,11 +704,14 @@ def __call__( attention_mask = (attention_mask == 0) * -jnp.inf embedding_output = self.embeddings.embed(input_ids, token_type_ids, position_ids, input_embeds, key=k_emb) - sequence_output = self.encoder(embedding_output, attention_mask=attention_mask, key=k_e) + + encoder_outputs = self.encoder(embedding_output, attention_mask=attention_mask, key=k_e) + + sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output, key=k_p) if self.pooler is not None else None - return (sequence_output, pooled_output) + return (sequence_output, pooled_output) + encoder_outputs[1:] if self.output_hidden_states else (sequence_output, pooled_output) class RobertaLMHead(eqx.Module, StateDictSerializationMixin): """Roberta Head for masked language modeling.""" @@ -752,18 +752,19 @@ class RobertaForMaskedLM(eqx.Module, StateDictSerializationMixin): roberta: RobertaModel lm_head: RobertaLMHead Vocab: Axis + output_hidden_states: bool @classmethod - def init(self, Vocab: Axis, config: RobertaConfig, *, key): + def init(self, Vocab: Axis, config: RobertaConfig, output_hidden_states: bool = False, *, key): # if config.is_decoder: # raise AttributeError("Model is being run as a MaskedLM aka an encoder model, but is_decoder is true") key_rob, key_head = jrandom.split(key, 2) - roberta = RobertaModel.init(Vocab, config, add_pooling_layer=False, key=key_rob) + roberta = RobertaModel.init(Vocab, config, add_pooling_layer=False, output_hidden_states=output_hidden_states, key=key_rob) lm_head = RobertaLMHead.init(Vocab, config, key=key_head) - return RobertaForMaskedLM(roberta, lm_head, Vocab) + return RobertaForMaskedLM(roberta, lm_head, Vocab, 
output_hidden_states) def get_output_embeddings(self): return self.lm_head.decoder @@ -804,7 +805,7 @@ def __call__( prediction_scores = self.lm_head(outputs[0], key=k_lm) - return prediction_scores + return (prediction_scores,) + outputs[2:] def _rotate_half(x: NamedArray) -> NamedArray: diff --git a/src/levanter/models/testing.ipynb b/src/levanter/models/testing.ipynb index 1b9e231af..bc340fc1f 100644 --- a/src/levanter/models/testing.ipynb +++ b/src/levanter/models/testing.ipynb @@ -11,7 +11,7 @@ "text": [ "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", - "2024-08-12 14:00:57,074\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n" + "2024-09-04 10:40:36,597\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n" ] } ], @@ -53,7 +53,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "seed: 1723496463\n" + "seed: 1725471642\n" ] } ], @@ -69,45 +69,65 @@ "KeyPos = my_config.KeyPos\n", "Heads = my_config.Heads\n", "\n", + "cut_end_for_bounds = False \n", + "\n", "Batch = hax.Axis(\"batch\", 2)\n", "Vocab = hax.Axis(\"vocab\", my_config.vocab_size)\n", "\n", "keys = jrandom.split(key_vars, 6)\n", "\n", "input_ids = hax.random.randint(keys[0], (Batch, Pos), minval = 3, maxval = my_config.vocab_size)\n", + "\n", + "if cut_end_for_bounds:\n", + " input_ids = input_ids[{\"position\": slice(0,-2)}]\n", + "\n", "input_ids_torch = torch.from_numpy(np.array(input_ids.array))\n", "input_embeds = hax.random.normal(keys[1], (Batch, Pos, Embed))\n", + "\n", + "if cut_end_for_bounds:\n", + " input_embeds = input_embeds[{\"position\": slice(0,-2)}]\n", + "\n", "input_embeds_torch = torch.from_numpy(np.array(input_embeds.array))\n", "\n", "# mask = hax.random.randint(keys[2], (Batch, Pos), minval = 0, maxval = 2)\n", "mask = hax.ones((Batch, Pos))\n", + "\n", + "if cut_end_for_bounds:\n", + " mask = mask[{\"position\": slice(0,-2)}]\n", + "\n", "mask_torch = torch.from_numpy(np.array(mask.array))\n", - "mask_torch_materialized = torch.ones((2, hf_config.num_attention_heads, hf_config.max_position_embeddings, hf_config.max_position_embeddings))\n", + "\n", + "mask_materialized = hax.ones((Batch, Heads, Pos, KeyPos)) * -jnp.inf\n", + "\n", + "if cut_end_for_bounds:\n", + " mask_materialized = mask_materialized[{\"position\": slice(0,-2), \"key_position\": slice(0,-2)}]\n", + "\n", + "mask_materialized_torch = torch.from_numpy(np.array(mask_materialized.array))\n", "\n", "features = input_embeds[{\"position\": 0}]\n", "features_torch = torch.from_numpy(np.array(features.array))\n", "\n", "x_embed_att = input_embeds.rename({\"embed\": \"embed_att\"})\n", + "\n", + "if cut_end_for_bounds:\n", + " x_embed_att = x_embed_att[{\"position\": slice(0,-2)}]\n", + "\n", "x_embed_att_torch = torch.from_numpy(np.array(x_embed_att.array))\n", "x_mlp = hax.random.normal(keys[5], (Batch, Pos, Mlp))\n", + "\n", + "if cut_end_for_bounds:\n", + " x_mlp = x_mlp[{\"position\": slice(0,-2)}]\n", + " \n", "x_mlp_torch = torch.from_numpy(np.array(x_mlp.array))" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Notes:\n", - "- Random Mask causes RobertaModel to have different output" - 
] - }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ - "def check(my_output, hf_output, p=False, pp=False, ppp=False, pppp=True, precision=1e-4):\n", + "def check(my_output, hf_output, precision=1e-4):\n", " \n", " assert (np.array(my_output.shape) == np.array(hf_output.shape)).all()\n", " # print(my_output.shape)\n", @@ -115,33 +135,16 @@ "\n", " acc = np.isclose(hf_output, np.array(my_output), rtol=precision, atol=precision).mean()\n", "\n", - " stats = (torch.tensor(np.array(my_output)).abs().mean(), torch.tensor(np.array(hf_output)).abs().mean())\n", + " # stats = (torch.tensor(np.array(my_output)).abs().mean(), torch.tensor(np.array(hf_output)).abs().mean()) \n", + " stats = (torch.linalg.norm(torch.tensor(np.array(my_output))), torch.linalg.norm(torch.tensor(np.array(hf_output))))\n", " \n", - " if p: \n", - " acc = np.isclose(hf_output, np.array(my_output), rtol=precision, atol=precision).mean()\n", - " print(f\"Accuracy: {acc}\")\n", - " print(f\"Jax:\\n{torch.tensor(np.array(my_output))}\\nTorch:\\n{hf_output}\")\n", - "\n", - " if pp:\n", - " diff = torch.tensor(np.array(my_output)) - hf_output\n", - " print(f\"Mean: {diff.abs().mean()}\")\n", - " print(f\"Stdev: {diff.std()}\")\n", - " print(f\"Difference:\\n{diff}\")\n", - "\n", - " if ppp:\n", - " acc_prev = None\n", - " for i in range(15):\n", - " prec = 10 ** (-1*i)\n", - " acc = np.isclose(hf_output, np.array(my_output), rtol=precision, atol=prec).mean()\n", - " if acc_prev is None:\n", - " print(f\"Iteration {i}, Precision {prec}:\\t{acc}\")\n", - " else:\n", - " if np.abs(acc - acc_prev) > 1e-4:\n", - " print(f\"Iteration {i}, Precision {prec}:\\t{acc}\")\n", - " acc_prev = acc\n", - " print(f\"Iteration {i}, Precision {prec}:\\t{acc}\")\n", + " difference = torch.tensor(np.array(my_output)) - torch.tensor(np.array(hf_output))\n", + "\n", + " diffs = difference.abs().mean()\n", + "\n", + " to_print = f\"acc: {acc} \\t norms: {stats} \\t diffs: {diffs}\"\n", " \n", - " return acc, stats" + " return acc, stats, diffs, to_print" ] }, { @@ -150,254 +153,254 @@ "metadata": {}, "outputs": [], "source": [ - "# Testing RobertaSelfOutput\n", + "# # Testing RobertaSelfOutput\n", "\n", - "def test_RobertaSelfOutput(key):\n", - " k_1, k_2 = jrandom.split(key, 2)\n", - " my_func = my_roberta.RobertaSelfOutput.init(my_config, key=k_1)\n", - " state = my_func.to_state_dict()\n", + "# def test_RobertaSelfOutput(key):\n", + "# k_1, k_2 = jrandom.split(key, 2)\n", + "# my_func = my_roberta.RobertaSelfOutput.init(my_config, key=k_1)\n", + "# state = my_func.to_state_dict()\n", "\n", - " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "# state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - " # # print(state.keys())\n", + "# # # print(state.keys())\n", "\n", - " hf_func = hf_roberta.RobertaSelfOutput(hf_config)\n", - " hf_func.load_state_dict(state, strict=True)\n", + "# hf_func = hf_roberta.RobertaSelfOutput(hf_config)\n", + "# hf_func.load_state_dict(state, strict=True)\n", "\n", - " my_output = my_func(x_embed_att, input_embeds, key=k_2)\n", - " hf_output = hf_func(x_embed_att_torch, input_embeds_torch)\n", + "# my_output = my_func(x_embed_att, input_embeds, key=k_2)\n", + "# hf_output = hf_func(x_embed_att_torch, input_embeds_torch)\n", "\n", - " return check(my_output.array, hf_output.detach())\n", + "# return check(my_output.array, hf_output.detach())\n", "\n", - "# Testing RobertaSelfAttention\n", + "# # Testing 
RobertaSelfAttention\n", "\n", - "def test_RobertaSelfAttention(key):\n", - " k_1, k_2 = jrandom.split(key, 2)\n", + "# def test_RobertaSelfAttention(key):\n", + "# k_1, k_2 = jrandom.split(key, 2)\n", "\n", - " my_func = my_roberta.RobertaSelfAttention.init(my_config, key=k_1)\n", - " state = my_func.to_state_dict()\n", + "# my_func = my_roberta.RobertaSelfAttention.init(my_config, key=k_1)\n", + "# state = my_func.to_state_dict()\n", "\n", - " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "# state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - " # print(state.keys())\n", + "# # print(state.keys())\n", "\n", - " hf_func = hf_roberta.RobertaSelfAttention(hf_config)\n", - " hf_func.load_state_dict(state, strict=True)\n", + "# hf_func = hf_roberta.RobertaSelfAttention(hf_config)\n", + "# hf_func.load_state_dict(state, strict=True)\n", "\n", - " my_output = my_func(input_embeds, mask, key=k_2)\n", - " hf_output = hf_func(input_embeds_torch, mask_torch_materialized)\n", + "# my_output = my_func(input_embeds, mask, key=k_2)\n", + "# hf_output = hf_func(input_embeds_torch, mask_torch_materialized)\n", "\n", - " return check(my_output.array, hf_output[0].detach())\n", + "# return check(my_output.array, hf_output[0].detach())\n", "\n", - "# Testing RobertaAttention\n", + "# # Testing RobertaAttention\n", "\n", - "def test_RobertaAttention(key):\n", - " k_1, k_2 = jrandom.split(key, 2)\n", - " my_func = my_roberta.RobertaAttention.init(my_config, key=k_1)\n", - " state = my_func.to_state_dict()\n", + "# def test_RobertaAttention(key):\n", + "# k_1, k_2 = jrandom.split(key, 2)\n", + "# my_func = my_roberta.RobertaAttention.init(my_config, key=k_1)\n", + "# state = my_func.to_state_dict()\n", "\n", - " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "# state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - " # print(state.keys())\n", + "# # print(state.keys())\n", "\n", - " hf_func = hf_roberta.RobertaAttention(hf_config)\n", - " hf_func.load_state_dict(state, strict=True)\n", + "# hf_func = hf_roberta.RobertaAttention(hf_config)\n", + "# hf_func.load_state_dict(state, strict=True)\n", "\n", - " my_output = my_func(hidden_states=input_embeds, attention_mask=mask, key=k_2)\n", - " hf_output = hf_func(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)\n", + "# my_output = my_func(hidden_states=input_embeds, attention_mask=mask, key=k_2)\n", + "# hf_output = hf_func(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)\n", "\n", - " return check(my_output.array, hf_output[0].detach())\n", + "# return check(my_output.array, hf_output[0].detach())\n", "\n", - "# Testing RobertaIntermediate\n", + "# # Testing RobertaIntermediate\n", "\n", - "def test_RobertaIntermediate(key):\n", - " k_1, k_2 = jrandom.split(key, 2)\n", - " my_func = my_roberta.RobertaIntermediate.init(my_config, key=k_1)\n", - " state = my_func.to_state_dict()\n", + "# def test_RobertaIntermediate(key):\n", + "# k_1, k_2 = jrandom.split(key, 2)\n", + "# my_func = my_roberta.RobertaIntermediate.init(my_config, key=k_1)\n", + "# state = my_func.to_state_dict()\n", "\n", - " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "# state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - " # print(state.keys())\n", + "# # print(state.keys())\n", "\n", - " hf_func = hf_roberta.RobertaIntermediate(hf_config)\n", - " hf_func.load_state_dict(state, 
strict=True)\n", + "# hf_func = hf_roberta.RobertaIntermediate(hf_config)\n", + "# hf_func.load_state_dict(state, strict=True)\n", "\n", - " my_output = my_func(input_embeds, key=k_2)\n", - " hf_output = hf_func(input_embeds_torch)\n", + "# my_output = my_func(input_embeds, key=k_2)\n", + "# hf_output = hf_func(input_embeds_torch)\n", "\n", - " return check(my_output.array, hf_output.detach())\n", + "# return check(my_output.array, hf_output.detach())\n", "\n", - "# Testing RobertaOutput\n", + "# # Testing RobertaOutput\n", "\n", - "def test_RobertaOutput(key):\n", - " k_1, k_2 = jrandom.split(key, 2)\n", + "# def test_RobertaOutput(key):\n", + "# k_1, k_2 = jrandom.split(key, 2)\n", "\n", - " my_func = my_roberta.RobertaOutput.init(my_config, key=k_1)\n", - " state = my_func.to_state_dict()\n", + "# my_func = my_roberta.RobertaOutput.init(my_config, key=k_1)\n", + "# state = my_func.to_state_dict()\n", "\n", - " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "# state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - " # print(state.keys())\n", + "# # print(state.keys())\n", "\n", - " hf_func = hf_roberta.RobertaOutput(hf_config)\n", - " hf_func.load_state_dict(state, strict=True)\n", + "# hf_func = hf_roberta.RobertaOutput(hf_config)\n", + "# hf_func.load_state_dict(state, strict=True)\n", "\n", - " my_output = my_func(x_mlp, input_embeds, key=k_2)\n", - " hf_output = hf_func(x_mlp_torch, input_embeds_torch)\n", + "# my_output = my_func(x_mlp, input_embeds, key=k_2)\n", + "# hf_output = hf_func(x_mlp_torch, input_embeds_torch)\n", "\n", - " return check(my_output.array, hf_output.detach())\n", + "# return check(my_output.array, hf_output.detach())\n", "\n", - "# Testing RobertaLayer\n", + "# # Testing RobertaLayer\n", "\n", - "def test_RobertaLayer(key):\n", - " k_1, k_2 = jrandom.split(key, 2)\n", - " my_func = my_roberta.RobertaLayer.init(my_config, key=k_1)\n", - " state = my_func.to_state_dict()\n", + "# def test_RobertaLayer(key):\n", + "# k_1, k_2 = jrandom.split(key, 2)\n", + "# my_func = my_roberta.RobertaLayer.init(my_config, key=k_1)\n", + "# state = my_func.to_state_dict()\n", "\n", - " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "# state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - " # print(state.keys())\n", + "# # print(state.keys())\n", "\n", - " hf_func = hf_roberta.RobertaLayer(hf_config)\n", - " hf_func.load_state_dict(state, strict=True)\n", + "# hf_func = hf_roberta.RobertaLayer(hf_config)\n", + "# hf_func.load_state_dict(state, strict=True)\n", "\n", - " my_output = my_func(hidden_states=input_embeds, attention_mask=mask, key=k_2)\n", - " hf_output = hf_func(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)\n", + "# my_output = my_func(hidden_states=input_embeds, attention_mask=mask, key=k_2)\n", + "# hf_output = hf_func(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)\n", "\n", - " return check(my_output.array, hf_output[0].detach())\n", + "# return check(my_output.array, hf_output[0].detach())\n", "\n", - "# Testing RobertaEncoder\n", + "# # Testing RobertaEncoder\n", "\n", - "def test_RobertaEncoder(key):\n", - " k_1, k_2 = jrandom.split(key, 2)\n", - " my_func = my_roberta.RobertaEncoder.init(my_config, key=k_1)\n", - " state = my_func.to_state_dict()\n", + "# def test_RobertaEncoder(key):\n", + "# k_1, k_2 = jrandom.split(key, 2)\n", + "# my_func = my_roberta.RobertaEncoder.init(my_config, key=k_1)\n", 
+ "# state = my_func.to_state_dict()\n", "\n", - " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "# state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - " # print(state.keys())\n", + "# # print(state.keys())\n", "\n", - " hf_func = hf_roberta.RobertaEncoder(hf_config)\n", - " hf_func.load_state_dict(state, strict=True)\n", + "# hf_func = hf_roberta.RobertaEncoder(hf_config)\n", + "# hf_func.load_state_dict(state, strict=True)\n", "\n", - " my_output = my_func(hidden_states=input_embeds, attention_mask=mask, key=k_2)\n", - " hf_output = hf_func(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)\n", + "# my_output = my_func(hidden_states=input_embeds, attention_mask=mask, key=k_2)\n", + "# hf_output = hf_func(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)\n", "\n", - " return check(my_output.array, hf_output[0].detach())\n", + "# return check(my_output.array, hf_output[0].detach())\n", "\n", - "# Testing RobertaEmbedding\n", + "# # Testing RobertaEmbedding\n", "\n", - "def test_RobertaEmbedding(key, ids = True):\n", - " k_1, k_2 = jrandom.split(key, 2)\n", - " my_func = my_roberta.RobertaEmbedding.init(Vocab, my_config, key=k_1)\n", - " state = my_func.to_state_dict()\n", + "# def test_RobertaEmbedding(key, ids = True):\n", + "# k_1, k_2 = jrandom.split(key, 2)\n", + "# my_func = my_roberta.RobertaEmbedding.init(Vocab, my_config, key=k_1)\n", + "# state = my_func.to_state_dict()\n", "\n", - " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "# state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - " # print(state.keys())\n", + "# # print(state.keys())\n", "\n", - " hf_func = hf_roberta.RobertaEmbeddings(hf_config)\n", - " hf_func.load_state_dict(state, strict=True)\n", + "# hf_func = hf_roberta.RobertaEmbeddings(hf_config)\n", + "# hf_func.load_state_dict(state, strict=True)\n", "\n", - " if ids:\n", - " my_output = my_func.embed(input_ids=input_ids, key=k_2)\n", - " hf_output = hf_func(input_ids=input_ids_torch)\n", - " else: \n", - " my_output = my_func.embed(input_embeds=input_embeds, key=k_2)\n", - " hf_output = hf_func(inputs_embeds=input_embeds_torch)\n", + "# if ids:\n", + "# my_output = my_func.embed(input_ids=input_ids, key=k_2)\n", + "# hf_output = hf_func(input_ids=input_ids_torch)\n", + "# else: \n", + "# my_output = my_func.embed(input_embeds=input_embeds, key=k_2)\n", + "# hf_output = hf_func(inputs_embeds=input_embeds_torch)\n", "\n", - " return check(my_output.array, hf_output.detach())\n", + "# return check(my_output.array, hf_output.detach())\n", "\n", - "# Testing RobertaPooler\n", + "# # Testing RobertaPooler\n", "\n", - "def test_RobertaPooler(key):\n", - " k_1, k_2 = jrandom.split(key, 2)\n", - " my_func = my_roberta.RobertaPooler.init(my_config, key=k_1)\n", - " state = my_func.to_state_dict()\n", + "# def test_RobertaPooler(key):\n", + "# k_1, k_2 = jrandom.split(key, 2)\n", + "# my_func = my_roberta.RobertaPooler.init(my_config, key=k_1)\n", + "# state = my_func.to_state_dict()\n", "\n", - " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "# state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - " # print(state.keys())\n", + "# # print(state.keys())\n", "\n", - " hf_func = hf_roberta.RobertaPooler(hf_config)\n", - " hf_func.load_state_dict(state, strict=True)\n", + "# hf_func = hf_roberta.RobertaPooler(hf_config)\n", + "# hf_func.load_state_dict(state, 
strict=True)\n", "\n", - " my_output = my_func(input_embeds, key=k_2)\n", - " hf_output = hf_func(input_embeds_torch)\n", + "# my_output = my_func(input_embeds, key=k_2)\n", + "# hf_output = hf_func(input_embeds_torch)\n", "\n", - " return check(my_output.array, hf_output.detach())\n", + "# return check(my_output.array, hf_output.detach())\n", "\n", - "# Testing RobertaModel\n", + "# # Testing RobertaModel\n", "\n", - "def test_RobertaModel(key, ids = True, pool = True):\n", - " k_1, k_2 = jrandom.split(key, 2)\n", - " my_func = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=pool, key=k_1)\n", - " state = my_func.to_state_dict()\n", + "# def test_RobertaModel(key, ids = True, pool = True):\n", + "# k_1, k_2 = jrandom.split(key, 2)\n", + "# my_func = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=pool, key=k_1)\n", + "# state = my_func.to_state_dict()\n", "\n", - " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "# state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - " # print(state.keys())\n", + "# # print(state.keys())\n", "\n", - " hf_func = hf_roberta.RobertaModel(hf_config, add_pooling_layer=pool)\n", - " hf_func.load_state_dict(state, strict=True)\n", + "# hf_func = hf_roberta.RobertaModel(hf_config, add_pooling_layer=pool)\n", + "# hf_func.load_state_dict(state, strict=True)\n", "\n", - " if ids:\n", - " my_output = my_func(input_ids = input_ids, attention_mask=mask, key=k_2)\n", - " hf_output = hf_func(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", - " else:\n", - " my_output = my_func(input_embeds = input_embeds, attention_mask=mask, key=k_2)\n", - " hf_output = hf_func(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n", + "# if ids:\n", + "# my_output = my_func(input_ids = input_ids, attention_mask=mask, key=k_2)\n", + "# hf_output = hf_func(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", + "# else:\n", + "# my_output = my_func(input_embeds = input_embeds, attention_mask=mask, key=k_2)\n", + "# hf_output = hf_func(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n", "\n", - " if pool:\n", - " return check(my_output[1].array, hf_output[1].detach())\n", - " else:\n", - " return check(my_output[0].array, hf_output[0].detach())\n", + "# if pool:\n", + "# return check(my_output[1].array, hf_output[1].detach())\n", + "# else:\n", + "# return check(my_output[0].array, hf_output[0].detach())\n", "\n", - "# Testing RobertaLMHead\n", + "# # Testing RobertaLMHead\n", "\n", - "def test_RobertaLMHead(key):\n", - " k_1, k_2 = jrandom.split(key, 2)\n", - " my_func = my_roberta.RobertaLMHead.init(Vocab, my_config, key=k_1)\n", - " state = my_func.to_state_dict()\n", + "# def test_RobertaLMHead(key):\n", + "# k_1, k_2 = jrandom.split(key, 2)\n", + "# my_func = my_roberta.RobertaLMHead.init(Vocab, my_config, key=k_1)\n", + "# state = my_func.to_state_dict()\n", "\n", - " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "# state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - " state[\"bias\"] = torch.zeros(hf_config.vocab_size)\n", + "# state[\"bias\"] = torch.zeros(hf_config.vocab_size)\n", "\n", - " # print(state.keys())\n", + "# # print(state.keys())\n", "\n", - " hf_func = hf_roberta.RobertaLMHead(hf_config)\n", - " hf_func.load_state_dict(state, strict=True)\n", + "# hf_func = hf_roberta.RobertaLMHead(hf_config)\n", + "# 
hf_func.load_state_dict(state, strict=True)\n", "\n", - " my_output = my_func(features, key=k_2)\n", - " hf_output = hf_func(features_torch)\n", + "# my_output = my_func(features, key=k_2)\n", + "# hf_output = hf_func(features_torch)\n", "\n", - " return check(my_output.array, hf_output.detach())\n", + "# return check(my_output.array, hf_output.detach())\n", "\n", - "# Testing RobertaForMaskedLM\n", + "# # Testing RobertaForMaskedLM\n", "\n", - "def test_RobertaForMaskedLM(key, ids = True):\n", - " k_1, k_2 = jrandom.split(key, 2)\n", - " my_pool = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=k_1)\n", - " state = my_pool.to_state_dict()\n", + "# def test_RobertaForMaskedLM(key, ids = True):\n", + "# k_1, k_2 = jrandom.split(key, 2)\n", + "# my_pool = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=k_1)\n", + "# state = my_pool.to_state_dict()\n", "\n", - " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "# state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - " state[\"lm_head.bias\"] = torch.zeros(hf_config.vocab_size)\n", + "# state[\"lm_head.bias\"] = torch.zeros(hf_config.vocab_size)\n", "\n", - " # print(state.keys())\n", + "# # print(state.keys())\n", "\n", - " hf_pool = hf_roberta.RobertaForMaskedLM(hf_config)\n", - " hf_pool.load_state_dict(state, strict=True)\n", + "# hf_pool = hf_roberta.RobertaForMaskedLM(hf_config)\n", + "# hf_pool.load_state_dict(state, strict=True)\n", "\n", - " if ids:\n", - " my_output = my_pool(input_ids = input_ids, attention_mask=mask, key=k_2)\n", - " hf_output = hf_pool(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", - " else:\n", - " my_output = my_pool(input_embeds = input_embeds, attention_mask=mask, key=k_2)\n", - " hf_output = hf_pool(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n", + "# if ids:\n", + "# my_output = my_pool(input_ids = input_ids, attention_mask=mask, key=k_2)\n", + "# hf_output = hf_pool(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", + "# else:\n", + "# my_output = my_pool(input_embeds = input_embeds, attention_mask=mask, key=k_2)\n", + "# hf_output = hf_pool(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n", "\n", - " return check(my_output.array, hf_output[0].detach())" + "# return check(my_output.array, hf_output[0].detach())" ] }, { @@ -406,21 +409,7 @@ "metadata": {}, "outputs": [], "source": [ - "def out_func(input):\n", - " acc, stats = input\n", - " if acc < 1:\n", - " return str(acc) + \"\\t<---- here\"\n", - " else:\n", - " return str(acc)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "# seed = time() + 20\n", + "# seed = time()\n", "# print(f\"seed: {int(seed)}\")\n", "# key_vars = jrandom.PRNGKey(int(seed))\n", "# keys = jrandom.split(key_vars, 15)\n", @@ -444,79 +433,25 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "seed: 1725471643\n" + ] + } + ], "source": [ - "def get_output_RobertaEmbedding(input, key, ids = True):\n", - " k_1, k_2 = jrandom.split(key, 2)\n", - " my_func = my_roberta.RobertaEmbedding.init(Vocab, my_config, key=k_1)\n", - " state = my_func.to_state_dict()\n", - "\n", - " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", - "\n", - " # print(state.keys())\n", - 
"\n", - " hf_func = hf_roberta.RobertaEmbeddings(hf_config)\n", - " hf_func.load_state_dict(state, strict=True)\n", - "\n", - " input_torch = torch.from_numpy(np.array(input.array))\n", - "\n", - " if ids:\n", - " my_output = my_func.embed(input_ids=input, key=k_2)\n", - " hf_output = hf_func(input_ids=input_torch)\n", - " else: \n", - " my_output = my_func.embed(input_embeds=input, key=k_2)\n", - " hf_output = hf_func(inputs_embeds=input_torch)\n", - "\n", - " return check(my_output.array, hf_output.detach()), (my_output, hf_output)\n", - "\n", - "def get_output_RobertaEncoder(input, key):\n", - " k_1, k_2 = jrandom.split(key, 2)\n", - " my_func = my_roberta.RobertaEncoder.init(my_config, key=k_1)\n", - " state = my_func.to_state_dict()\n", - "\n", - " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", - "\n", - " # print(state.keys())\n", - "\n", - " hf_func = hf_roberta.RobertaEncoder(hf_config)\n", - " hf_func.load_state_dict(state, strict=True)\n", - "\n", - " input_torch = torch.from_numpy(np.array(input.array))\n", - "\n", - " attention_mask = hax.ones((Batch, Heads, Pos, KeyPos)) * -jnp.inf\n", - " attention_mask_torch = torch.from_numpy(np.array(attention_mask.array))\n", - "\n", - " my_output = my_func(hidden_states=input, attention_mask=attention_mask, key=k_2)\n", - " hf_output = hf_func(hidden_states=input_torch, attention_mask=attention_mask_torch)\n", - "\n", - " return check(my_output.array, hf_output[0].detach()), (my_output, hf_output)\n", - "\n", - "def get_output_RobertaPooler(input, key):\n", - " k_1, k_2 = jrandom.split(key, 2)\n", - " my_func = my_roberta.RobertaPooler.init(my_config, key=k_1)\n", - " state = my_func.to_state_dict()\n", - "\n", - " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", - "\n", - " # print(state.keys())\n", - "\n", - " hf_func = hf_roberta.RobertaPooler(hf_config)\n", - " hf_func.load_state_dict(state, strict=True)\n", - "\n", - " input_torch = torch.from_numpy(np.array(input.array))\n", - "\n", - " my_output = my_func(input, key=k_2)\n", - " hf_output = hf_func(input_torch)\n", - "\n", - " return check(my_output.array, hf_output.detach()), (my_output, hf_output)\n", - "\n", - "# Testing RobertaModel\n", + "seed = time()\n", + "print(f\"seed: {int(seed)}\")\n", + "key = jrandom.PRNGKey(int(seed))\n", "\n", - "def get_output_RobertaModel(input, key, ids = True, pool = True):\n", + "def test_RobertaModel_Output(key, ids = False, pool = True):\n", " k_1, k_2 = jrandom.split(key, 2)\n", - " my_func = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=pool, key=k_1)\n", + " my_func = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=pool, output_hidden_states=True, key=k_1)\n", " state = my_func.to_state_dict()\n", "\n", " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", @@ -525,172 +460,104 @@ "\n", " hf_func = hf_roberta.RobertaModel(hf_config, add_pooling_layer=pool)\n", " hf_func.load_state_dict(state, strict=True)\n", - " \n", - " input_torch = torch.from_numpy(np.array(input.array))\n", - " \n", - " # attention_mask = hax.ones((Batch, Heads, Pos, KeyPos)) * -jnp.inf\n", - " # attention_mask_torch = torch.from_numpy(np.array(attention_mask.array))\n", "\n", " if ids:\n", - " my_output = my_func(input_ids = input, attention_mask=mask, key=k_2)\n", - " hf_output = hf_func(input_ids = input_torch, attention_mask=mask_torch, return_dict=False)\n", + " my_output = my_func(input_ids = input_ids, attention_mask=mask, key=k_2)\n", + " hf_output = 
hf_func(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False, output_hidden_states=True)\n", " else:\n", - " my_output = my_func(input_embeds = input, attention_mask=mask, key=k_2)\n", - " hf_output = hf_func(inputs_embeds = input_torch, attention_mask=mask_torch, return_dict=False)\n", + " my_output = my_func(input_embeds = input_embeds, attention_mask=mask, key=k_2)\n", + " hf_output = hf_func(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False, output_hidden_states=True)\n", "\n", - " if pool:\n", - " return check(my_output[1].array, hf_output[1].detach()), (my_output, hf_output)\n", - " else:\n", - " return check(my_output[0].array, hf_output[0].detach()), (my_output, hf_output)\n" + " return my_output, hf_output\n", + "\n", + "my_output_ids, hf_output_ids = test_RobertaModel_Output(key, ids=True)\n", + "my_output_embeds, hf_output_embeds = test_RobertaModel_Output(key, ids=False)" ] }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "seed: 1723505034\n", - "{'batch': 2, 'position': 514, 'embed': 768}\n", - "(tensor(0.7984), tensor(0.7984))\n", - "(tensor(nan), tensor(nan))\n", - "(tensor(nan), tensor(nan))\n", - "(tensor(0.5408), tensor(0.5408))\n", - "acc_embeds: 1.0\n", - "acc_enc: 0.0\n", - "acc_pool: 0.0\n", - "acc_model: 1.0\n", - "my comparison pool: (0.0, (tensor(nan), tensor(0.5408)))\n", - "my comparison no pool: (0.0, (tensor(nan), tensor(0.7873)))\n", - "hf comparison: (0.0, (tensor(nan), tensor(0.5408)))\n" + "model_out: acc: 1.0 \t norms: (tensor(888.5305), tensor(888.5305)) \t diffs: 3.735790983228071e-07\n", + "pool_out: acc: 1.0 \t norms: (tensor(24.1699), tensor(24.1699)) \t diffs: 2.6738598535303026e-07\n", + "intermediates:\n", + "acc: 1.0 \t norms: (tensor(888.5306), tensor(888.5306)) \t diffs: 1.842970362986307e-07\n", + "acc: 1.0 \t norms: (tensor(888.5308), tensor(888.5308)) \t diffs: 2.625348543006112e-07\n", + "acc: 1.0 \t norms: (tensor(888.5301), tensor(888.5301)) \t diffs: 3.2068956556940975e-07\n", + "acc: 1.0 \t norms: (tensor(888.5295), tensor(888.5295)) \t diffs: 3.396997101390298e-07\n", + "acc: 1.0 \t norms: (tensor(888.5296), tensor(888.5297)) \t diffs: 3.419580139052414e-07\n", + "acc: 1.0 \t norms: (tensor(888.5310), tensor(888.5310)) \t diffs: 3.721844734627666e-07\n", + "acc: 1.0 \t norms: (tensor(888.5298), tensor(888.5299)) \t diffs: 3.591211736875266e-07\n", + "acc: 1.0 \t norms: (tensor(888.5299), tensor(888.5299)) \t diffs: 3.513960677992145e-07\n", + "acc: 1.0 \t norms: (tensor(888.5300), tensor(888.5300)) \t diffs: 3.739319538453856e-07\n", + "acc: 1.0 \t norms: (tensor(888.5305), tensor(888.5305)) \t diffs: 3.6201470265950775e-07\n", + "acc: 1.0 \t norms: (tensor(888.5309), tensor(888.5309)) \t diffs: 3.658180958154844e-07\n", + "acc: 1.0 \t norms: (tensor(888.5305), tensor(888.5305)) \t diffs: 3.735790983228071e-07\n" ] } ], "source": [ - "seed = time() + 30\n", - "print(f\"seed: {int(seed)}\")\n", - "key = jrandom.PRNGKey(int(seed))\n", + "my_out, hf_out = my_output_ids[0], hf_output_ids[0]\n", "\n", - "k_t, k_emb, k_p = jrandom.split(key, 3)\n", + "print(f\"model_out: {check(my_out.array, hf_out.detach())[3]}\")\n", "\n", - "input = input_embeds\n", + "my_pool, hf_pool = my_output_ids[1], hf_output_ids[1]\n", "\n", - "(acc_embeds, stats_embed), (my_out_embeds, hf_out_embeds) = get_output_RobertaEmbedding(input, k_t, ids = False)\n", - "print(stats_embed)\n", - "(acc_enc, 
stats_enc), (my_out_enc, hf_out_enc) = get_output_RobertaEncoder(my_out_embeds, k_emb)\n", - "print(stats_enc)\n", - "(acc_pool, stats_pool), (my_out_pool, hf_out_pool) = get_output_RobertaPooler(my_out_enc, k_p)\n", - "print(stats_pool)\n", + "print(f\"pool_out: {check(my_pool.array, hf_pool.detach())[3]}\")\n", "\n", - "(acc_model, stats_model), (my_out_model, hf_out_model) = get_output_RobertaModel(input, key, ids = False, pool = True)\n", - "print(stats_model)\n", + "print(\"intermediates:\")\n", + "my_ints, hf_ints = my_output_ids[2], hf_output_ids[2][1:]\n", "\n", - "print(f\"acc_embeds: {acc_embeds}\")\n", - "print(f\"acc_enc: {acc_enc}\")\n", - "print(f\"acc_pool: {acc_pool}\")\n", - "print(f\"acc_model: {acc_model}\")\n", - "print(f\"my comparison pool: {check(my_out_pool.array, my_out_model[1].array)}\")\n", - "print(f\"my comparison no pool: {check(my_out_enc.array, my_out_model[0].array)}\")\n", - "print(f\"hf comparison: {check(hf_out_pool.detach(), hf_out_model[1].detach())}\")" + "for i,j in zip(my_ints, hf_ints):\n", + " print(check(i.array,j.detach())[3])" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "NamedArray(float32{'batch': 2, 'embed': 768},\n", - "[[nan nan nan ... nan nan nan]\n", - " [nan nan nan ... nan nan nan]])\n", - "NamedArray(float32{'batch': 2, 'position': 514, 'embed': 768},\n", - "[[[nan nan nan ... nan nan nan]\n", - " [nan nan nan ... nan nan nan]\n", - " [nan nan nan ... nan nan nan]\n", - " ...\n", - " [nan nan nan ... nan nan nan]\n", - " [nan nan nan ... nan nan nan]\n", - " [nan nan nan ... nan nan nan]]\n", - "\n", - " [[nan nan nan ... nan nan nan]\n", - " [nan nan nan ... nan nan nan]\n", - " [nan nan nan ... nan nan nan]\n", - " ...\n", - " [nan nan nan ... nan nan nan]\n", - " [nan nan nan ... nan nan nan]\n", - " [nan nan nan ... nan nan nan]]])\n", - "NamedArray(float32{'batch': 2, 'position': 514, 'embed': 768},\n", - "[[[ 0.47864354 1.2938721 1.0534003 ... -1.4044254 0.829634\n", - " 0.00428176]\n", - " [-0.98862374 -0.943986 -1.0448135 ... -0.64666593 0.12967904\n", - " -1.0975188 ]\n", - " [ 0.43071395 -0.60738516 -1.7641208 ... -1.1334671 0.9041689\n", - " 0.9875958 ]\n", - " ...\n", - " [ 1.287216 0.507795 0.23451686 ... -0.9582702 -0.3576718\n", - " 0.6565546 ]\n", - " [ 0.33264828 -0.68922603 0.41440547 ... 0.4528543 -0.6819962\n", - " 0.4289952 ]\n", - " [ 0.725326 1.9756228 1.1881577 ... 0.5643402 0.5135605\n", - " 0.92514485]]\n", - "\n", - " [[ 1.0026733 1.1235753 -1.017235 ... -1.8810284 -0.29097554\n", - " -0.63098675]\n", - " [-0.47498116 1.9341669 0.23969549 ... -0.45160082 -0.955768\n", - " -1.4716814 ]\n", - " [-0.2948639 0.25138515 -1.3983693 ... -0.96624637 0.44848248\n", - " -0.71705264]\n", - " ...\n", - " [-0.4495727 0.07491604 0.919175 ... 0.565745 -0.34500855\n", - " -1.9166113 ]\n", - " [-0.3112308 -0.21019831 0.2379393 ... 1.3521733 0.1243041\n", - " -1.3730545 ]\n", - " [-0.9592797 -1.1558015 -1.3304269 ... 
1.4129258 0.69931823\n", - " 0.24171986]]])\n", - "(NamedArray(array=Array([[[-0.18572666, -0.80562836, -0.98205453, ..., 0.9597818 ,\n", - " 1.5924176 , 0.3848163 ],\n", - " [-0.19698945, -0.7534842 , -0.81553775, ..., 0.8907304 ,\n", - " 1.6596173 , 0.37368602],\n", - " [-0.14524579, -0.6914802 , -0.91773754, ..., 1.024012 ,\n", - " 1.6283392 , 0.34325117],\n", - " ...,\n", - " [-0.12597106, -0.76580876, -0.8392121 , ..., 0.9352241 ,\n", - " 1.5550641 , 0.46660298],\n", - " [-0.17409518, -0.749031 , -1.056306 , ..., 0.9757236 ,\n", - " 1.633118 , 0.5897971 ],\n", - " [-0.30729672, -0.69016093, -0.87607175, ..., 0.874229 ,\n", - " 1.6674999 , 0.38814685]],\n", - "\n", - " [[-0.2995884 , -0.92027843, -0.78937566, ..., 0.43273145,\n", - " 1.0177305 , 0.4611196 ],\n", - " [-0.3179962 , -0.8442052 , -0.75374943, ..., 0.7022148 ,\n", - " 1.0696493 , 0.3404984 ],\n", - " [-0.33889046, -0.90215874, -0.65796405, ..., 0.5069377 ,\n", - " 1.0210142 , 0.30466717],\n", - " ...,\n", - " [-0.47294238, -0.8416684 , -0.7532904 , ..., 0.46125498,\n", - " 1.1491499 , 0.41495347],\n", - " [-0.28819865, -0.8842407 , -0.69517034, ..., 0.49842533,\n", - " 1.0367949 , 0.58008623],\n", - " [-0.31019112, -0.90532184, -0.7528029 , ..., 0.5191115 ,\n", - " 1.299418 , 0.43962964]]], dtype=float32), axes=(Axis(name='batch', size=2), Axis(name='position', size=514), Axis(name='embed', size=768))), NamedArray(array=Array([[ 0.60376894, 0.34222102, -0.01021165, ..., -0.90009135,\n", - " -0.6602305 , 0.14456 ],\n", - " [ 0.8103463 , 0.4916969 , -0.01769697, ..., -0.94779646,\n", - " -0.8301279 , 0.29279563]], dtype=float32), axes=(Axis(name='batch', size=2), Axis(name='embed', size=768))))\n" + "model_out: acc: 1.0 \t norms: (tensor(888.5308), tensor(888.5308)) \t diffs: 3.864420250465628e-07\n", + "pool_out: acc: 1.0 \t norms: (tensor(24.0400), tensor(24.0400)) \t diffs: 2.5442224682592496e-07\n", + "intermediates:\n", + "acc: 1.0 \t norms: (tensor(888.5295), tensor(888.5295)) \t diffs: 1.46142554058315e-07\n", + "acc: 1.0 \t norms: (tensor(888.5310), tensor(888.5311)) \t diffs: 2.35209114407553e-07\n", + "acc: 1.0 \t norms: (tensor(888.5301), tensor(888.5302)) \t diffs: 3.0417106700042496e-07\n", + "acc: 1.0 \t norms: (tensor(888.5303), tensor(888.5303)) \t diffs: 3.522779365994211e-07\n", + "acc: 1.0 \t norms: (tensor(888.5314), tensor(888.5314)) \t diffs: 3.7978762179591286e-07\n", + "acc: 1.0 \t norms: (tensor(888.5302), tensor(888.5302)) \t diffs: 3.9330373624579806e-07\n", + "acc: 1.0 \t norms: (tensor(888.5305), tensor(888.5304)) \t diffs: 3.8590263784499257e-07\n", + "acc: 1.0 \t norms: (tensor(888.5306), tensor(888.5306)) \t diffs: 3.735180200692412e-07\n", + "acc: 1.0 \t norms: (tensor(888.5292), tensor(888.5291)) \t diffs: 3.75227983795412e-07\n", + "acc: 1.0 \t norms: (tensor(888.5312), tensor(888.5312)) \t diffs: 3.7935546970402356e-07\n", + "acc: 1.0 \t norms: (tensor(888.5292), tensor(888.5292)) \t diffs: 3.81289993356404e-07\n", + "acc: 1.0 \t norms: (tensor(888.5308), tensor(888.5308)) \t diffs: 3.864420250465628e-07\n" ] } ], "source": [ - "print(my_out_pool)\n", - "print(my_out_enc)\n", - "print(my_out_embeds)\n", - "print(my_out_model)" + "my_out, hf_out = my_output_embeds[0], hf_output_embeds[0]\n", + "\n", + "print(f\"model_out: {check(my_out.array, hf_out.detach())[3]}\")\n", + "\n", + "my_pool, hf_pool = my_output_embeds[1], hf_output_embeds[1]\n", + "\n", + "print(f\"pool_out: {check(my_pool.array, hf_pool.detach())[3]}\")\n", + "\n", + "print(\"intermediates:\")\n", + "my_ints, hf_ints = 
my_output_embeds[2], hf_output_embeds[2][1:]\n", + "\n", + "for i,j in zip(my_ints, hf_ints):\n", + " print(check(i.array,j.detach())[3])" ] }, { @@ -699,72 +566,146 @@ "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "'# Testing RobertaModel\\n\\nmy_model = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=False, key=key)\\nstate = my_model.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\n# print(state.keys())\\n\\nhf_model = hf_roberta.RobertaModel(hf_config, add_pooling_layer=False)\\nhf_model.load_state_dict(state, strict=True)\\n\\nmy_output = my_model(input_ids = input_ids, attention_mask=mask, key=key)\\nhf_output = hf_model(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\\n\\ncheck(my_output[0].array, hf_output[0].detach(), ppp=True)\\n\\n# Testing RobertaLMHead\\n\\nmy_head = my_roberta.RobertaLMHead.init(Vocab, my_config, key=key)\\nstate = my_head.to_state_dict()\\n\\nstate = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\\n\\nstate[\"bias\"] = torch.zeros(hf_config.vocab_size)\\n\\n# print(state.keys())\\n\\nhf_head = hf_roberta.RobertaLMHead(hf_config)\\nhf_head.load_state_dict(state, strict=True)\\n\\nmy_output = my_head(my_output[0], key=key)\\nhf_output = hf_head(hf_output[0])\\n\\ncheck(my_output.array, hf_output.detach(), ppp=True)'" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "seed: 1725471659\n" + ] } ], "source": [ - "'''# Testing RobertaModel\n", - "\n", - "my_model = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=False, key=key)\n", - "state = my_model.to_state_dict()\n", - "\n", - "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "seed = time()\n", + "print(f\"seed: {int(seed)}\")\n", + "key = jrandom.PRNGKey(int(seed))\n", "\n", - "# print(state.keys())\n", + "def test_RobertaForMaskedLM_Output(key, ids = True):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, output_hidden_states=True, key=k_1)\n", + " state = my_func.to_state_dict()\n", "\n", - "hf_model = hf_roberta.RobertaModel(hf_config, add_pooling_layer=False)\n", - "hf_model.load_state_dict(state, strict=True)\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + " \n", + " state[\"lm_head.bias\"] = torch.zeros(hf_config.vocab_size)\n", "\n", - "my_output = my_model(input_ids = input_ids, attention_mask=mask, key=key)\n", - "hf_output = hf_model(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", + " hf_func = hf_roberta.RobertaForMaskedLM(hf_config)\n", + " hf_func.load_state_dict(state, strict=True)\n", "\n", - "check(my_output[0].array, hf_output[0].detach(), ppp=True)\n", + " if ids:\n", + " my_output = my_func(input_ids = input_ids, attention_mask=mask, key=k_2)\n", + " hf_output = hf_func(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False, output_hidden_states=True)\n", + " else:\n", + " my_output = my_func(input_embeds = input_embeds, attention_mask=mask, key=k_2)\n", + " hf_output = hf_func(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False, output_hidden_states=True)\n", "\n", - "# Testing RobertaLMHead\n", + " return my_output, hf_output\n", "\n", - "my_head = my_roberta.RobertaLMHead.init(Vocab, my_config, key=key)\n", - "state = my_head.to_state_dict()\n", + 
"my_mlm_output_ids, hf_mlm_output_ids = test_RobertaForMaskedLM_Output(key, ids=True)\n", + "my_mlm_output_embeds, hf_mlm_output_embeds = test_RobertaForMaskedLM_Output(key, ids=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mlm_out: acc: 0.001719795589213743 \t norms: (tensor(7055.4185), tensor(7059.8276)) \t diffs: 0.06642061471939087\n", + "intermediates:\n", + "acc: 0.019604713845654993 \t norms: (tensor(888.5299), tensor(888.5306)) \t diffs: 0.5695138573646545\n", + "acc: 0.024172138456549936 \t norms: (tensor(888.5310), tensor(888.5300)) \t diffs: 0.46615326404571533\n", + "acc: 0.030564759646562904 \t norms: (tensor(888.5312), tensor(888.5323)) \t diffs: 0.37349948287010193\n", + "acc: 0.04000106395914397 \t norms: (tensor(888.5294), tensor(888.5300)) \t diffs: 0.28456440567970276\n", + "acc: 0.05416818660830091 \t norms: (tensor(888.5305), tensor(888.5297)) \t diffs: 0.21144814789295197\n", + "acc: 0.07160065053501946 \t norms: (tensor(888.5291), tensor(888.5300)) \t diffs: 0.16202779114246368\n", + "acc: 0.08982095087548637 \t norms: (tensor(888.5302), tensor(888.5295)) \t diffs: 0.12657980620861053\n", + "acc: 0.11758268482490272 \t norms: (tensor(888.5289), tensor(888.5308)) \t diffs: 0.09682736545801163\n", + "acc: 0.1463805123216602 \t norms: (tensor(888.5303), tensor(888.5312)) \t diffs: 0.07892335206270218\n", + "acc: 0.16748110205901426 \t norms: (tensor(888.5292), tensor(888.5310)) \t diffs: 0.06763624399900436\n", + "acc: 0.18012701645590143 \t norms: (tensor(888.5292), tensor(888.5284)) \t diffs: 0.062167659401893616\n", + "acc: 0.17417264510376135 \t norms: (tensor(888.5310), tensor(888.5300)) \t diffs: 0.06321203708648682\n" + ] + } + ], + "source": [ + "my_out, hf_out = my_mlm_output_ids[0], hf_mlm_output_ids[0]\n", "\n", - "state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "print(f\"mlm_out: {check(my_out.array, hf_out.detach())[3]}\")\n", "\n", - "state[\"bias\"] = torch.zeros(hf_config.vocab_size)\n", + "print(\"intermediates:\")\n", + "my_ints, hf_ints = my_mlm_output_ids[1], hf_mlm_output_ids[1][1:]\n", "\n", - "# print(state.keys())\n", + "for i,j in zip(my_ints, hf_ints):\n", + " print(check(i.array,j.detach(), precision = 0.01)[3])" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mlm_out: acc: 1.0 \t norms: (tensor(7033.6245), tensor(7033.6240)) \t diffs: 5.174669013285893e-07\n", + "intermediates:\n", + "acc: 1.0 \t norms: (tensor(888.5309), tensor(888.5310)) \t diffs: 1.4689783256471856e-07\n", + "acc: 1.0 \t norms: (tensor(888.5306), tensor(888.5306)) \t diffs: 2.3626945733212779e-07\n", + "acc: 1.0 \t norms: (tensor(888.5312), tensor(888.5312)) \t diffs: 3.0318486210489937e-07\n", + "acc: 1.0 \t norms: (tensor(888.5305), tensor(888.5305)) \t diffs: 3.531020809077745e-07\n", + "acc: 1.0 \t norms: (tensor(888.5297), tensor(888.5297)) \t diffs: 3.7493518334486e-07\n", + "acc: 1.0 \t norms: (tensor(888.5297), tensor(888.5297)) \t diffs: 3.8230905374803115e-07\n", + "acc: 1.0 \t norms: (tensor(888.5295), tensor(888.5296)) \t diffs: 3.8595226214965805e-07\n", + "acc: 1.0 \t norms: (tensor(888.5308), tensor(888.5308)) \t diffs: 3.713914793479489e-07\n", + "acc: 1.0 \t norms: (tensor(888.5317), tensor(888.5318)) \t diffs: 3.5173252399545163e-07\n", + "acc: 1.0 \t norms: (tensor(888.5316), tensor(888.5316)) \t 
diffs: 3.4720503094831656e-07\n", + "acc: 1.0 \t norms: (tensor(888.5290), tensor(888.5290)) \t diffs: 3.541817932273261e-07\n", + "acc: 1.0 \t norms: (tensor(888.5295), tensor(888.5295)) \t diffs: 3.6307170603322447e-07\n" + ] + } + ], + "source": [ + "my_out, hf_out = my_mlm_output_embeds[0], hf_mlm_output_embeds[0]\n", "\n", - "hf_head = hf_roberta.RobertaLMHead(hf_config)\n", - "hf_head.load_state_dict(state, strict=True)\n", + "print(f\"mlm_out: {check(my_out.array, hf_out.detach())[3]}\")\n", "\n", - "my_output = my_head(my_output[0], key=key)\n", - "hf_output = hf_head(hf_output[0])\n", + "print(\"intermediates:\")\n", + "my_ints, hf_ints = my_mlm_output_embeds[1], hf_mlm_output_embeds[1][1:]\n", "\n", - "check(my_output.array, hf_output.detach(), ppp=True)'''" + "for i,j in zip(my_ints, hf_ints):\n", + " print(check(i.array,j.detach())[3])" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dict_keys(['roberta.encoder.layer.0.attention.self.query.weight', 'roberta.encoder.layer.0.attention.self.query.bias', 'roberta.encoder.layer.0.attention.self.key.weight', 'roberta.encoder.layer.0.attention.self.key.bias', 'roberta.encoder.layer.0.attention.self.value.weight', 'roberta.encoder.layer.0.attention.self.value.bias', 'roberta.encoder.layer.0.attention.output.dense.weight', 'roberta.encoder.layer.0.attention.output.dense.bias', 'roberta.encoder.layer.0.attention.output.LayerNorm.weight', 'roberta.encoder.layer.0.attention.output.LayerNorm.bias', 'roberta.encoder.layer.0.intermediate.dense.weight', 'roberta.encoder.layer.0.intermediate.dense.bias', 'roberta.encoder.layer.0.output.dense.weight', 'roberta.encoder.layer.0.output.dense.bias', 'roberta.encoder.layer.0.output.LayerNorm.weight', 'roberta.encoder.layer.0.output.LayerNorm.bias', 'roberta.encoder.layer.1.attention.self.query.weight', 'roberta.encoder.layer.1.attention.self.query.bias', 'roberta.encoder.layer.1.attention.self.key.weight', 'roberta.encoder.layer.1.attention.self.key.bias', 'roberta.encoder.layer.1.attention.self.value.weight', 'roberta.encoder.layer.1.attention.self.value.bias', 'roberta.encoder.layer.1.attention.output.dense.weight', 'roberta.encoder.layer.1.attention.output.dense.bias', 'roberta.encoder.layer.1.attention.output.LayerNorm.weight', 'roberta.encoder.layer.1.attention.output.LayerNorm.bias', 'roberta.encoder.layer.1.intermediate.dense.weight', 'roberta.encoder.layer.1.intermediate.dense.bias', 'roberta.encoder.layer.1.output.dense.weight', 'roberta.encoder.layer.1.output.dense.bias', 'roberta.encoder.layer.1.output.LayerNorm.weight', 'roberta.encoder.layer.1.output.LayerNorm.bias', 'roberta.encoder.layer.2.attention.self.query.weight', 'roberta.encoder.layer.2.attention.self.query.bias', 'roberta.encoder.layer.2.attention.self.key.weight', 'roberta.encoder.layer.2.attention.self.key.bias', 'roberta.encoder.layer.2.attention.self.value.weight', 'roberta.encoder.layer.2.attention.self.value.bias', 'roberta.encoder.layer.2.attention.output.dense.weight', 'roberta.encoder.layer.2.attention.output.dense.bias', 'roberta.encoder.layer.2.attention.output.LayerNorm.weight', 'roberta.encoder.layer.2.attention.output.LayerNorm.bias', 'roberta.encoder.layer.2.intermediate.dense.weight', 'roberta.encoder.layer.2.intermediate.dense.bias', 'roberta.encoder.layer.2.output.dense.weight', 'roberta.encoder.layer.2.output.dense.bias', 'roberta.encoder.layer.2.output.LayerNorm.weight', 
'roberta.encoder.layer.2.output.LayerNorm.bias', 'roberta.encoder.layer.3.attention.self.query.weight', 'roberta.encoder.layer.3.attention.self.query.bias', 'roberta.encoder.layer.3.attention.self.key.weight', 'roberta.encoder.layer.3.attention.self.key.bias', 'roberta.encoder.layer.3.attention.self.value.weight', 'roberta.encoder.layer.3.attention.self.value.bias', 'roberta.encoder.layer.3.attention.output.dense.weight', 'roberta.encoder.layer.3.attention.output.dense.bias', 'roberta.encoder.layer.3.attention.output.LayerNorm.weight', 'roberta.encoder.layer.3.attention.output.LayerNorm.bias', 'roberta.encoder.layer.3.intermediate.dense.weight', 'roberta.encoder.layer.3.intermediate.dense.bias', 'roberta.encoder.layer.3.output.dense.weight', 'roberta.encoder.layer.3.output.dense.bias', 'roberta.encoder.layer.3.output.LayerNorm.weight', 'roberta.encoder.layer.3.output.LayerNorm.bias', 'roberta.encoder.layer.4.attention.self.query.weight', 'roberta.encoder.layer.4.attention.self.query.bias', 'roberta.encoder.layer.4.attention.self.key.weight', 'roberta.encoder.layer.4.attention.self.key.bias', 'roberta.encoder.layer.4.attention.self.value.weight', 'roberta.encoder.layer.4.attention.self.value.bias', 'roberta.encoder.layer.4.attention.output.dense.weight', 'roberta.encoder.layer.4.attention.output.dense.bias', 'roberta.encoder.layer.4.attention.output.LayerNorm.weight', 'roberta.encoder.layer.4.attention.output.LayerNorm.bias', 'roberta.encoder.layer.4.intermediate.dense.weight', 'roberta.encoder.layer.4.intermediate.dense.bias', 'roberta.encoder.layer.4.output.dense.weight', 'roberta.encoder.layer.4.output.dense.bias', 'roberta.encoder.layer.4.output.LayerNorm.weight', 'roberta.encoder.layer.4.output.LayerNorm.bias', 'roberta.encoder.layer.5.attention.self.query.weight', 'roberta.encoder.layer.5.attention.self.query.bias', 'roberta.encoder.layer.5.attention.self.key.weight', 'roberta.encoder.layer.5.attention.self.key.bias', 'roberta.encoder.layer.5.attention.self.value.weight', 'roberta.encoder.layer.5.attention.self.value.bias', 'roberta.encoder.layer.5.attention.output.dense.weight', 'roberta.encoder.layer.5.attention.output.dense.bias', 'roberta.encoder.layer.5.attention.output.LayerNorm.weight', 'roberta.encoder.layer.5.attention.output.LayerNorm.bias', 'roberta.encoder.layer.5.intermediate.dense.weight', 'roberta.encoder.layer.5.intermediate.dense.bias', 'roberta.encoder.layer.5.output.dense.weight', 'roberta.encoder.layer.5.output.dense.bias', 'roberta.encoder.layer.5.output.LayerNorm.weight', 'roberta.encoder.layer.5.output.LayerNorm.bias', 'roberta.encoder.layer.6.attention.self.query.weight', 'roberta.encoder.layer.6.attention.self.query.bias', 'roberta.encoder.layer.6.attention.self.key.weight', 'roberta.encoder.layer.6.attention.self.key.bias', 'roberta.encoder.layer.6.attention.self.value.weight', 'roberta.encoder.layer.6.attention.self.value.bias', 'roberta.encoder.layer.6.attention.output.dense.weight', 'roberta.encoder.layer.6.attention.output.dense.bias', 'roberta.encoder.layer.6.attention.output.LayerNorm.weight', 'roberta.encoder.layer.6.attention.output.LayerNorm.bias', 'roberta.encoder.layer.6.intermediate.dense.weight', 'roberta.encoder.layer.6.intermediate.dense.bias', 'roberta.encoder.layer.6.output.dense.weight', 'roberta.encoder.layer.6.output.dense.bias', 'roberta.encoder.layer.6.output.LayerNorm.weight', 'roberta.encoder.layer.6.output.LayerNorm.bias', 'roberta.encoder.layer.7.attention.self.query.weight', 'roberta.encoder.layer.7.attention.self.query.bias', 
'roberta.encoder.layer.7.attention.self.key.weight', 'roberta.encoder.layer.7.attention.self.key.bias', 'roberta.encoder.layer.7.attention.self.value.weight', 'roberta.encoder.layer.7.attention.self.value.bias', 'roberta.encoder.layer.7.attention.output.dense.weight', 'roberta.encoder.layer.7.attention.output.dense.bias', 'roberta.encoder.layer.7.attention.output.LayerNorm.weight', 'roberta.encoder.layer.7.attention.output.LayerNorm.bias', 'roberta.encoder.layer.7.intermediate.dense.weight', 'roberta.encoder.layer.7.intermediate.dense.bias', 'roberta.encoder.layer.7.output.dense.weight', 'roberta.encoder.layer.7.output.dense.bias', 'roberta.encoder.layer.7.output.LayerNorm.weight', 'roberta.encoder.layer.7.output.LayerNorm.bias', 'roberta.encoder.layer.8.attention.self.query.weight', 'roberta.encoder.layer.8.attention.self.query.bias', 'roberta.encoder.layer.8.attention.self.key.weight', 'roberta.encoder.layer.8.attention.self.key.bias', 'roberta.encoder.layer.8.attention.self.value.weight', 'roberta.encoder.layer.8.attention.self.value.bias', 'roberta.encoder.layer.8.attention.output.dense.weight', 'roberta.encoder.layer.8.attention.output.dense.bias', 'roberta.encoder.layer.8.attention.output.LayerNorm.weight', 'roberta.encoder.layer.8.attention.output.LayerNorm.bias', 'roberta.encoder.layer.8.intermediate.dense.weight', 'roberta.encoder.layer.8.intermediate.dense.bias', 'roberta.encoder.layer.8.output.dense.weight', 'roberta.encoder.layer.8.output.dense.bias', 'roberta.encoder.layer.8.output.LayerNorm.weight', 'roberta.encoder.layer.8.output.LayerNorm.bias', 'roberta.encoder.layer.9.attention.self.query.weight', 'roberta.encoder.layer.9.attention.self.query.bias', 'roberta.encoder.layer.9.attention.self.key.weight', 'roberta.encoder.layer.9.attention.self.key.bias', 'roberta.encoder.layer.9.attention.self.value.weight', 'roberta.encoder.layer.9.attention.self.value.bias', 'roberta.encoder.layer.9.attention.output.dense.weight', 'roberta.encoder.layer.9.attention.output.dense.bias', 'roberta.encoder.layer.9.attention.output.LayerNorm.weight', 'roberta.encoder.layer.9.attention.output.LayerNorm.bias', 'roberta.encoder.layer.9.intermediate.dense.weight', 'roberta.encoder.layer.9.intermediate.dense.bias', 'roberta.encoder.layer.9.output.dense.weight', 'roberta.encoder.layer.9.output.dense.bias', 'roberta.encoder.layer.9.output.LayerNorm.weight', 'roberta.encoder.layer.9.output.LayerNorm.bias', 'roberta.encoder.layer.10.attention.self.query.weight', 'roberta.encoder.layer.10.attention.self.query.bias', 'roberta.encoder.layer.10.attention.self.key.weight', 'roberta.encoder.layer.10.attention.self.key.bias', 'roberta.encoder.layer.10.attention.self.value.weight', 'roberta.encoder.layer.10.attention.self.value.bias', 'roberta.encoder.layer.10.attention.output.dense.weight', 'roberta.encoder.layer.10.attention.output.dense.bias', 'roberta.encoder.layer.10.attention.output.LayerNorm.weight', 'roberta.encoder.layer.10.attention.output.LayerNorm.bias', 'roberta.encoder.layer.10.intermediate.dense.weight', 'roberta.encoder.layer.10.intermediate.dense.bias', 'roberta.encoder.layer.10.output.dense.weight', 'roberta.encoder.layer.10.output.dense.bias', 'roberta.encoder.layer.10.output.LayerNorm.weight', 'roberta.encoder.layer.10.output.LayerNorm.bias', 'roberta.encoder.layer.11.attention.self.query.weight', 'roberta.encoder.layer.11.attention.self.query.bias', 'roberta.encoder.layer.11.attention.self.key.weight', 'roberta.encoder.layer.11.attention.self.key.bias', 
'roberta.encoder.layer.11.attention.self.value.weight', 'roberta.encoder.layer.11.attention.self.value.bias', 'roberta.encoder.layer.11.attention.output.dense.weight', 'roberta.encoder.layer.11.attention.output.dense.bias', 'roberta.encoder.layer.11.attention.output.LayerNorm.weight', 'roberta.encoder.layer.11.attention.output.LayerNorm.bias', 'roberta.encoder.layer.11.intermediate.dense.weight', 'roberta.encoder.layer.11.intermediate.dense.bias', 'roberta.encoder.layer.11.output.dense.weight', 'roberta.encoder.layer.11.output.dense.bias', 'roberta.encoder.layer.11.output.LayerNorm.weight', 'roberta.encoder.layer.11.output.LayerNorm.bias', 'roberta.embeddings.word_embeddings.weight', 'roberta.embeddings.position_embeddings.weight', 'roberta.embeddings.token_type_embeddings.weight', 'roberta.embeddings.LayerNorm.weight', 'roberta.embeddings.LayerNorm.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.decoder.bias', 'lm_head.bias'])\n", + "dict_keys(['encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.1.attention.self.query.bias', 'encoder.layer.1.attention.self.key.weight', 'encoder.layer.1.attention.self.key.bias', 'encoder.layer.1.attention.self.value.weight', 'encoder.layer.1.attention.self.value.bias', 'encoder.layer.1.attention.output.dense.weight', 'encoder.layer.1.attention.output.dense.bias', 'encoder.layer.1.attention.output.LayerNorm.weight', 'encoder.layer.1.attention.output.LayerNorm.bias', 'encoder.layer.1.intermediate.dense.weight', 'encoder.layer.1.intermediate.dense.bias', 'encoder.layer.1.output.dense.weight', 'encoder.layer.1.output.dense.bias', 'encoder.layer.1.output.LayerNorm.weight', 'encoder.layer.1.output.LayerNorm.bias', 'encoder.layer.2.attention.self.query.weight', 'encoder.layer.2.attention.self.query.bias', 'encoder.layer.2.attention.self.key.weight', 'encoder.layer.2.attention.self.key.bias', 'encoder.layer.2.attention.self.value.weight', 'encoder.layer.2.attention.self.value.bias', 'encoder.layer.2.attention.output.dense.weight', 'encoder.layer.2.attention.output.dense.bias', 'encoder.layer.2.attention.output.LayerNorm.weight', 'encoder.layer.2.attention.output.LayerNorm.bias', 'encoder.layer.2.intermediate.dense.weight', 'encoder.layer.2.intermediate.dense.bias', 'encoder.layer.2.output.dense.weight', 'encoder.layer.2.output.dense.bias', 'encoder.layer.2.output.LayerNorm.weight', 'encoder.layer.2.output.LayerNorm.bias', 'encoder.layer.3.attention.self.query.weight', 'encoder.layer.3.attention.self.query.bias', 'encoder.layer.3.attention.self.key.weight', 'encoder.layer.3.attention.self.key.bias', 'encoder.layer.3.attention.self.value.weight', 'encoder.layer.3.attention.self.value.bias', 'encoder.layer.3.attention.output.dense.weight', 
'encoder.layer.3.attention.output.dense.bias', 'encoder.layer.3.attention.output.LayerNorm.weight', 'encoder.layer.3.attention.output.LayerNorm.bias', 'encoder.layer.3.intermediate.dense.weight', 'encoder.layer.3.intermediate.dense.bias', 'encoder.layer.3.output.dense.weight', 'encoder.layer.3.output.dense.bias', 'encoder.layer.3.output.LayerNorm.weight', 'encoder.layer.3.output.LayerNorm.bias', 'encoder.layer.4.attention.self.query.weight', 'encoder.layer.4.attention.self.query.bias', 'encoder.layer.4.attention.self.key.weight', 'encoder.layer.4.attention.self.key.bias', 'encoder.layer.4.attention.self.value.weight', 'encoder.layer.4.attention.self.value.bias', 'encoder.layer.4.attention.output.dense.weight', 'encoder.layer.4.attention.output.dense.bias', 'encoder.layer.4.attention.output.LayerNorm.weight', 'encoder.layer.4.attention.output.LayerNorm.bias', 'encoder.layer.4.intermediate.dense.weight', 'encoder.layer.4.intermediate.dense.bias', 'encoder.layer.4.output.dense.weight', 'encoder.layer.4.output.dense.bias', 'encoder.layer.4.output.LayerNorm.weight', 'encoder.layer.4.output.LayerNorm.bias', 'encoder.layer.5.attention.self.query.weight', 'encoder.layer.5.attention.self.query.bias', 'encoder.layer.5.attention.self.key.weight', 'encoder.layer.5.attention.self.key.bias', 'encoder.layer.5.attention.self.value.weight', 'encoder.layer.5.attention.self.value.bias', 'encoder.layer.5.attention.output.dense.weight', 'encoder.layer.5.attention.output.dense.bias', 'encoder.layer.5.attention.output.LayerNorm.weight', 'encoder.layer.5.attention.output.LayerNorm.bias', 'encoder.layer.5.intermediate.dense.weight', 'encoder.layer.5.intermediate.dense.bias', 'encoder.layer.5.output.dense.weight', 'encoder.layer.5.output.dense.bias', 'encoder.layer.5.output.LayerNorm.weight', 'encoder.layer.5.output.LayerNorm.bias', 'encoder.layer.6.attention.self.query.weight', 'encoder.layer.6.attention.self.query.bias', 'encoder.layer.6.attention.self.key.weight', 'encoder.layer.6.attention.self.key.bias', 'encoder.layer.6.attention.self.value.weight', 'encoder.layer.6.attention.self.value.bias', 'encoder.layer.6.attention.output.dense.weight', 'encoder.layer.6.attention.output.dense.bias', 'encoder.layer.6.attention.output.LayerNorm.weight', 'encoder.layer.6.attention.output.LayerNorm.bias', 'encoder.layer.6.intermediate.dense.weight', 'encoder.layer.6.intermediate.dense.bias', 'encoder.layer.6.output.dense.weight', 'encoder.layer.6.output.dense.bias', 'encoder.layer.6.output.LayerNorm.weight', 'encoder.layer.6.output.LayerNorm.bias', 'encoder.layer.7.attention.self.query.weight', 'encoder.layer.7.attention.self.query.bias', 'encoder.layer.7.attention.self.key.weight', 'encoder.layer.7.attention.self.key.bias', 'encoder.layer.7.attention.self.value.weight', 'encoder.layer.7.attention.self.value.bias', 'encoder.layer.7.attention.output.dense.weight', 'encoder.layer.7.attention.output.dense.bias', 'encoder.layer.7.attention.output.LayerNorm.weight', 'encoder.layer.7.attention.output.LayerNorm.bias', 'encoder.layer.7.intermediate.dense.weight', 'encoder.layer.7.intermediate.dense.bias', 'encoder.layer.7.output.dense.weight', 'encoder.layer.7.output.dense.bias', 'encoder.layer.7.output.LayerNorm.weight', 'encoder.layer.7.output.LayerNorm.bias', 'encoder.layer.8.attention.self.query.weight', 'encoder.layer.8.attention.self.query.bias', 'encoder.layer.8.attention.self.key.weight', 'encoder.layer.8.attention.self.key.bias', 'encoder.layer.8.attention.self.value.weight', 'encoder.layer.8.attention.self.value.bias', 
'encoder.layer.8.attention.output.dense.weight', 'encoder.layer.8.attention.output.dense.bias', 'encoder.layer.8.attention.output.LayerNorm.weight', 'encoder.layer.8.attention.output.LayerNorm.bias', 'encoder.layer.8.intermediate.dense.weight', 'encoder.layer.8.intermediate.dense.bias', 'encoder.layer.8.output.dense.weight', 'encoder.layer.8.output.dense.bias', 'encoder.layer.8.output.LayerNorm.weight', 'encoder.layer.8.output.LayerNorm.bias', 'encoder.layer.9.attention.self.query.weight', 'encoder.layer.9.attention.self.query.bias', 'encoder.layer.9.attention.self.key.weight', 'encoder.layer.9.attention.self.key.bias', 'encoder.layer.9.attention.self.value.weight', 'encoder.layer.9.attention.self.value.bias', 'encoder.layer.9.attention.output.dense.weight', 'encoder.layer.9.attention.output.dense.bias', 'encoder.layer.9.attention.output.LayerNorm.weight', 'encoder.layer.9.attention.output.LayerNorm.bias', 'encoder.layer.9.intermediate.dense.weight', 'encoder.layer.9.intermediate.dense.bias', 'encoder.layer.9.output.dense.weight', 'encoder.layer.9.output.dense.bias', 'encoder.layer.9.output.LayerNorm.weight', 'encoder.layer.9.output.LayerNorm.bias', 'encoder.layer.10.attention.self.query.weight', 'encoder.layer.10.attention.self.query.bias', 'encoder.layer.10.attention.self.key.weight', 'encoder.layer.10.attention.self.key.bias', 'encoder.layer.10.attention.self.value.weight', 'encoder.layer.10.attention.self.value.bias', 'encoder.layer.10.attention.output.dense.weight', 'encoder.layer.10.attention.output.dense.bias', 'encoder.layer.10.attention.output.LayerNorm.weight', 'encoder.layer.10.attention.output.LayerNorm.bias', 'encoder.layer.10.intermediate.dense.weight', 'encoder.layer.10.intermediate.dense.bias', 'encoder.layer.10.output.dense.weight', 'encoder.layer.10.output.dense.bias', 'encoder.layer.10.output.LayerNorm.weight', 'encoder.layer.10.output.LayerNorm.bias', 'encoder.layer.11.attention.self.query.weight', 'encoder.layer.11.attention.self.query.bias', 'encoder.layer.11.attention.self.key.weight', 'encoder.layer.11.attention.self.key.bias', 'encoder.layer.11.attention.self.value.weight', 'encoder.layer.11.attention.self.value.bias', 'encoder.layer.11.attention.output.dense.weight', 'encoder.layer.11.attention.output.dense.bias', 'encoder.layer.11.attention.output.LayerNorm.weight', 'encoder.layer.11.attention.output.LayerNorm.bias', 'encoder.layer.11.intermediate.dense.weight', 'encoder.layer.11.intermediate.dense.bias', 'encoder.layer.11.output.dense.weight', 'encoder.layer.11.output.dense.bias', 'encoder.layer.11.output.LayerNorm.weight', 'encoder.layer.11.output.LayerNorm.bias', 'embeddings.word_embeddings.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.LayerNorm.weight', 'embeddings.LayerNorm.bias'])\n", + "dict_keys(['dense.weight', 'dense.bias', 'layer_norm.weight', 'layer_norm.bias', 'decoder.weight', 'decoder.bias', 'bias'])\n" + ] + }, { "data": { "text/plain": [ - "'# Testing RobertaForMaskedLM\\n\\nmy_mlm = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=key)\\nstate_mlm = my_mlm.to_state_dict()\\n\\nstate_mlm = {k: torch.from_numpy(np.array(v)) for k, v in state_mlm.items()}\\n\\n# if \"lm_head.decoder.bias\" in state:\\n# print(state[\"lm_head.decoder.bias\"])\\n# else:\\n# print(f\"RobertaForMaskedLM, {state.keys()}\")\\n\\nstate_mlm[\"lm_head.bias\"] = torch.zeros(hf_config.vocab_size)\\n\\nprint(state_mlm.keys())\\n\\nhf_mlm = 
hf_roberta.RobertaForMaskedLM(hf_config)\\nhf_mlm.load_state_dict(state_mlm, strict=True)\\n\\n# Testing RobertaModel\\n\\nkey_rob, key_head = jrandom.split(key, 2)\\n\\nmy_model = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=False, key=key_rob)\\nstate_model = my_model.to_state_dict()\\n\\nstate_model = {k: torch.from_numpy(np.array(v)) for k, v in state_model.items()}\\n\\nprint(state_model.keys())\\n\\nhf_model = hf_roberta.RobertaModel(hf_config, add_pooling_layer=False)\\nhf_model.load_state_dict(state_model, strict=True)\\n\\n# Testing RobertaLMHead\\n\\nmy_head = my_roberta.RobertaLMHead.init(Vocab, my_config, key=key_head)\\nstate_head = my_head.to_state_dict()\\n\\nstate_head = {k: torch.from_numpy(np.array(v)) for k, v in state_head.items()}\\n\\nstate_head[\"bias\"] = torch.zeros(hf_config.vocab_size)\\n\\nprint(state_head.keys())\\n\\nhf_head = hf_roberta.RobertaLMHead(hf_config)\\nhf_head.load_state_dict(state_head, strict=True)'" + "" ] }, - "execution_count": 11, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "'''# Testing RobertaForMaskedLM\n", + "# Testing RobertaForMaskedLM\n", "\n", "my_mlm = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=key)\n", "state_mlm = my_mlm.to_state_dict()\n", @@ -809,28 +750,61 @@ "print(state_head.keys())\n", "\n", "hf_head = hf_roberta.RobertaLMHead(hf_config)\n", - "hf_head.load_state_dict(state_head, strict=True)'''" + "hf_head.load_state_dict(state_head, strict=True)" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "k_rob, k_lm = jrandom.split(key, 2)\n", + "\n", + "# MLM\n", + "my_output_mlm = my_mlm(input_embeds = input_embeds, attention_mask=mask, key=key)\n", + "hf_output_mlm = hf_mlm(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n", + "\n", + "# Model + LM\n", + "my_output_model = my_model(input_embeds = input_embeds, attention_mask=mask, key=k_rob)\n", + "my_output = my_head(my_output_model[0], key=k_lm)\n", + "\n", + "hf_output_model = hf_model(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n", + "hf_output = hf_head(hf_output_model[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 15, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "'k_rob, k_lm = jrandom.split(key, 2)\\n\\n# MLM\\n\\nmy_output_mlm = my_mlm(input_ids = input_ids, attention_mask=mask, key=key)\\nhf_output_mlm = hf_mlm(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\\n\\n# Model + LM\\n\\nmy_output_model = my_model(input_ids = input_ids, attention_mask=mask, key=k_rob)\\nmy_output = my_head(my_output_model[0], key=k_lm)\\n\\nhf_output_model = hf_model(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\\nhf_output = hf_head(hf_output_model[0])\\n\\n# # MLM\\n# my_output_mlm = my_mlm(input_embeds = input_embeds, attention_mask=mask, key=key)\\n# hf_output_mlm = hf_mlm(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\\n\\n# # Model + LM\\n# my_output_model = my_model(input_embeds = input_embeds, attention_mask=mask, key=k_rob)\\n# my_output = my_head(my_output_model[0], key=k_lm)\\n\\n# hf_output_model = hf_model(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\\n# hf_output = hf_head(hf_output_model[0])\\n\\n# Checks\\n\\nprint(\"\\nChecking RobertaModel\")\\ncheck(my_output_model[0].array, 
hf_output_model[0].detach(), pppp=True)\\nprint(\"\\nChecking Roberta Model + LM head\")\\ncheck(my_output.array, hf_output.detach(), pppp=True)\\nprint(\"\\nChecking MLM\")\\ncheck(my_output_mlm.array, hf_output_mlm[0].detach(), pppp=True)\\n\\nprint(\"\\nChecking my RobertaModel + LM head and MLM\")\\ncheck(my_output.array, my_output_mlm.array, pppp=True)\\nprint(\"\\nChecking hf RobertaModel + LM head and MLM\")\\ncheck(hf_output.detach(), hf_output_mlm[0].detach(), pppp=True)\\n\\n\\n# Notes\\n# embeds works between hf and my, and within model->head and mlm \\n# ids does not work between hf and my for mlm or within hf for model->head and mlm - so hf mlm is doing something weird.'" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "RobertaModel: acc: 1.0 \t norms: (tensor(888.5305), tensor(888.5305)) \t diffs: 3.802756225468329e-07\n", + "Roberta Model + LM head: acc: 1.0 \t norms: (tensor(7087.6094), tensor(7087.6094)) \t diffs: 5.642933729177457e-07\n", + "MLM: acc: 1.0 \t norms: (tensor(7087.6094), tensor(7087.6094)) \t diffs: 5.642933729177457e-07\n", + "my RobertaModel + LM head vs MLM: acc: 1.0 \t norms: (tensor(7087.6094), tensor(7087.6094)) \t diffs: 0.0\n", + "hf RobertaModel + LM head vs MLM: acc: 1.0 \t norms: (tensor(7087.6094), tensor(7087.6094)) \t diffs: 0.0\n" + ] } ], "source": [ - "'''k_rob, k_lm = jrandom.split(key, 2)\n", + "print(f\"RobertaModel: {check(my_output_model[0].array, hf_output_model[0].detach())[3]}\")\n", + "print(f\"Roberta Model + LM head: {check(my_output.array, hf_output.detach())[3]}\")\n", + "print(f\"MLM: {check(my_output_mlm[0].array, hf_output_mlm[0].detach())[3]}\")\n", "\n", + "print(f\"my RobertaModel + LM head vs MLM: {check(my_output.array, my_output_mlm[0].array)[3]}\")\n", + "print(f\"hf RobertaModel + LM head vs MLM: {check(hf_output.detach(), hf_output_mlm[0].detach())[3]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ "# MLM\n", "\n", "my_output_mlm = my_mlm(input_ids = input_ids, attention_mask=mask, key=key)\n", @@ -844,36 +818,39 @@ "hf_output_model = hf_model(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", "hf_output = hf_head(hf_output_model[0])\n", "\n", - "# # MLM\n", - "# my_output_mlm = my_mlm(input_embeds = input_embeds, attention_mask=mask, key=key)\n", - "# hf_output_mlm = hf_mlm(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n", - "\n", - "# # Model + LM\n", - "# my_output_model = my_model(input_embeds = input_embeds, attention_mask=mask, key=k_rob)\n", - "# my_output = my_head(my_output_model[0], key=k_lm)\n", - "\n", - "# hf_output_model = hf_model(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n", - "# hf_output = hf_head(hf_output_model[0])\n", "\n", "# Checks\n", "\n", - "print(\"\\nChecking RobertaModel\")\n", - "check(my_output_model[0].array, hf_output_model[0].detach(), pppp=True)\n", - "print(\"\\nChecking Roberta Model + LM head\")\n", - "check(my_output.array, hf_output.detach(), pppp=True)\n", - "print(\"\\nChecking MLM\")\n", - "check(my_output_mlm.array, hf_output_mlm[0].detach(), pppp=True)\n", - "\n", - "print(\"\\nChecking my RobertaModel + LM head and MLM\")\n", - "check(my_output.array, my_output_mlm.array, pppp=True)\n", - "print(\"\\nChecking hf RobertaModel + LM head and MLM\")\n", - "check(hf_output.detach(), hf_output_mlm[0].detach(), 
pppp=True)\n", - "\n", - "\n", "# Notes\n", "# embeds works between hf and my, and within model->head and mlm \n", "# ids does not work between hf and my for mlm or within hf for model->head and mlm - so hf mlm is doing something weird.'''" ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RobertaModel: acc: 1.0 \t norms: (tensor(888.5287), tensor(888.5287)) \t diffs: 3.736330143055966e-07\n", + "Roberta Model + LM head: acc: 1.0 \t norms: (tensor(7065.4971), tensor(7065.4971)) \t diffs: 5.507896503331722e-07\n", + "MLM: acc: 0.0014420458728273227 \t norms: (tensor(7065.4971), tensor(7062.7324)) \t diffs: 0.07865540683269501\n", + "my RobertaModel + LM head vs MLM: acc: 1.0 \t norms: (tensor(7065.4971), tensor(7065.4971)) \t diffs: 0.0\n", + "hf RobertaModel + LM head vs MLM: acc: 0.0014420071674599332 \t norms: (tensor(7065.4971), tensor(7062.7324)) \t diffs: 0.07865539938211441\n" + ] + } + ], + "source": [ + "print(f\"RobertaModel: {check(my_output_model[0].array, hf_output_model[0].detach())[3]}\")\n", + "print(f\"Roberta Model + LM head: {check(my_output.array, hf_output.detach())[3]}\")\n", + "print(f\"MLM: {check(my_output_mlm[0].array, hf_output_mlm[0].detach())[3]}\")\n", + "\n", + "print(f\"my RobertaModel + LM head vs MLM: {check(my_output.array, my_output_mlm[0].array)[3]}\")\n", + "print(f\"hf RobertaModel + LM head vs MLM: {check(hf_output.detach(), hf_output_mlm[0].detach())[3]}\")" + ] } ], "metadata": { From 8a732e5fb7e0c9545b660a4206ce711afe4f0cf4 Mon Sep 17 00:00:00 2001 From: JulienDarve Date: Tue, 10 Sep 2024 13:53:27 -0700 Subject: [PATCH 12/29] Model can now successfully import weights from huggingface + made attention mask more robust --- src/levanter/models/roberta.py | 8 +- src/levanter/models/testing.ipynb | 1385 ++++++++++++++++++++--------- 2 files changed, 968 insertions(+), 425 deletions(-) diff --git a/src/levanter/models/roberta.py b/src/levanter/models/roberta.py index cbaaf6754..84c85527e 100644 --- a/src/levanter/models/roberta.py +++ b/src/levanter/models/roberta.py @@ -493,7 +493,7 @@ class RobertaEncoder(eqx.Module, StateDictSerializationMixin): output_hidden_states: bool @staticmethod - def init(config: RobertaConfig, output_hidden_states: bool, *, key) -> "RobertaEncoder": + def init(config: RobertaConfig, output_hidden_states: bool = False, *, key) -> "RobertaEncoder": S = BlockSeq layer = S.init(config.Layers, RobertaLayer, gradient_checkpointing=config.gradient_checkpointing)( @@ -570,7 +570,9 @@ def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0 return incremental_indices + self.padding_idx def create_position_ids_from_inputs_embeds(self, input_axes, PosInput): + # position_ids = hax.arange(axis = PosInput, start = 0, dtype=jnp.int32) position_ids = hax.arange(axis = PosInput, start = self.padding_idx + 1, dtype=jnp.int32) + return hax.broadcast_to(position_ids, input_axes) @named_call @@ -700,8 +702,8 @@ def __call__( if attention_mask is None: attention_mask = hax.ones(input_axes) - # Attention mask from mask to actual numbers - attention_mask = (attention_mask == 0) * -jnp.inf + # Attention mask from mask to actual numbers 0 -> -inf + attention_mask = (attention_mask == 0) * jnp.finfo(jnp.bfloat16).min embedding_output = self.embeddings.embed(input_ids, token_type_ids, position_ids, input_embeds, key=k_emb) diff --git a/src/levanter/models/testing.ipynb b/src/levanter/models/testing.ipynb index 
bc340fc1f..fc7dbaef1 100644 --- a/src/levanter/models/testing.ipynb +++ b/src/levanter/models/testing.ipynb @@ -2,25 +2,16 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 19, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "2024-09-04 10:40:36,597\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n" - ] - } - ], + "outputs": [], "source": [ "import roberta as my_roberta\n", "from transformers.models.roberta import modeling_roberta as hf_roberta\n", "\n", "import torch\n", "import haliax as hax\n", + "import jax\n", "import jax.random as jrandom\n", "import jax.numpy as jnp\n", "import numpy as np" @@ -28,7 +19,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -40,28 +31,37 @@ "hf_config = AutoConfig.from_pretrained(hf_model_str)\n", "hf_config.hidden_dropout_prob = 0\n", "hf_config.attention_probs_dropout_prob = 0\n", - "hf_config.pad_token_id = -1\n", + "# hf_config.pad_token_id = -1\n", "my_config = my_roberta.RobertaConfig.from_hf_config(hf_config)" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "seed: 1725471642\n" + "seed: 1725922495\n" ] } ], "source": [ "seed = time()\n", "print(f\"seed: {int(seed)}\")\n", - "key_vars = jrandom.PRNGKey(int(seed))\n", + "key = jrandom.PRNGKey(int(seed))\n", "\n", + "key_vars, key_funcs, key_run = jrandom.split(key, 3)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ "EmbedAtt = my_config.EmbedAtt\n", "Embed = my_config.Embed\n", "Mlp = my_config.Mlp\n", @@ -69,7 +69,7 @@ "KeyPos = my_config.KeyPos\n", "Heads = my_config.Heads\n", "\n", - "cut_end_for_bounds = False \n", + "cut_end_for_bounds = True \n", "\n", "Batch = hax.Axis(\"batch\", 2)\n", "Vocab = hax.Axis(\"vocab\", my_config.vocab_size)\n", @@ -77,53 +77,43 @@ "keys = jrandom.split(key_vars, 6)\n", "\n", "input_ids = hax.random.randint(keys[0], (Batch, Pos), minval = 3, maxval = my_config.vocab_size)\n", - "\n", "if cut_end_for_bounds:\n", " input_ids = input_ids[{\"position\": slice(0,-2)}]\n", - "\n", "input_ids_torch = torch.from_numpy(np.array(input_ids.array))\n", - "input_embeds = hax.random.normal(keys[1], (Batch, Pos, Embed))\n", "\n", + "input_embeds = hax.random.normal(keys[1], (Batch, Pos, Embed))\n", "if cut_end_for_bounds:\n", " input_embeds = input_embeds[{\"position\": slice(0,-2)}]\n", - "\n", "input_embeds_torch = torch.from_numpy(np.array(input_embeds.array))\n", "\n", "# mask = hax.random.randint(keys[2], (Batch, Pos), minval = 0, maxval = 2)\n", - "mask = hax.ones((Batch, Pos))\n", + "# mask = hax.ones((Batch, Pos))\n", + "mask = hax.zeros((Batch, Pos))\n", "\n", "if cut_end_for_bounds:\n", " mask = mask[{\"position\": slice(0,-2)}]\n", - "\n", "mask_torch = torch.from_numpy(np.array(mask.array))\n", "\n", - "mask_materialized = hax.ones((Batch, Heads, Pos, KeyPos)) * -jnp.inf\n", - "\n", - "if cut_end_for_bounds:\n", - " mask_materialized = 
mask_materialized[{\"position\": slice(0,-2), \"key_position\": slice(0,-2)}]\n", - "\n", - "mask_materialized_torch = torch.from_numpy(np.array(mask_materialized.array))\n", + "mask_materialized = (mask == 0) * jnp.finfo(jnp.bfloat16).min\n", + "mask_torch_materialized = hf_roberta.RobertaModel.get_extended_attention_mask(self=hf_roberta.RobertaModel(hf_config), attention_mask=mask_torch, input_shape=input_embeds_torch.shape)\n", "\n", "features = input_embeds[{\"position\": 0}]\n", "features_torch = torch.from_numpy(np.array(features.array))\n", "\n", "x_embed_att = input_embeds.rename({\"embed\": \"embed_att\"})\n", - "\n", "if cut_end_for_bounds:\n", " x_embed_att = x_embed_att[{\"position\": slice(0,-2)}]\n", - "\n", "x_embed_att_torch = torch.from_numpy(np.array(x_embed_att.array))\n", - "x_mlp = hax.random.normal(keys[5], (Batch, Pos, Mlp))\n", "\n", + "x_mlp = hax.random.normal(keys[5], (Batch, Pos, Mlp))\n", "if cut_end_for_bounds:\n", - " x_mlp = x_mlp[{\"position\": slice(0,-2)}]\n", - " \n", + " x_mlp = x_mlp[{\"position\": slice(0,-2)}] \n", "x_mlp_torch = torch.from_numpy(np.array(x_mlp.array))" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ @@ -149,359 +139,536 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 54, + "metadata": {}, + "outputs": [], + "source": [ + "def check_dicts(my_dict, hf_dict):\n", + " print(my_dict.keys())\n", + " print(hf_dict.keys())\n", + "\n", + " hf_keys_save = list(hf_dict.keys())\n", + "\n", + " flag = 0\n", + " diff = 0\n", + "\n", + " for k in my_dict.keys():\n", + " i = my_dict[k]\n", + " if k not in hf_dict:\n", + " print(f\"ERROR \\t {k}: key in my_dict but not hf_dict\")\n", + " j = hf_dict[k]\n", + " diff += (np.array(i) - np.array(j)).sum()\n", + " if check(i, j.detach())[0] < 1:\n", + " print(f\"ERROR \\t {k}: {check(i, j)[0]}\")\n", + " flag += 1\n", + " hf_keys_save.remove(k)\n", + "\n", + " if flag == 0:\n", + " print(\"success1\") \n", + " else:\n", + " print(\"fail1\") \n", + "\n", + " if len(hf_keys_save) == 0:\n", + " print(\"success2\") \n", + " else:\n", + " print(\"fail2\")\n", + " print(hf_keys_save)\n", + "\n", + " return diff" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'stop' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[25], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[43mstop\u001b[49m\n", + "\u001b[1;31mNameError\u001b[0m: name 'stop' is not defined" + ] + } + ], + "source": [ + "stop" + ] + }, + { + "cell_type": "code", + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ - "# # Testing RobertaSelfOutput\n", + "# Testing RobertaSelfOutput\n", "\n", - "# def test_RobertaSelfOutput(key):\n", - "# k_1, k_2 = jrandom.split(key, 2)\n", - "# my_func = my_roberta.RobertaSelfOutput.init(my_config, key=k_1)\n", - "# state = my_func.to_state_dict()\n", + "def test_RobertaSelfOutput(key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaSelfOutput.init(my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", "\n", - "# state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - "# # 
# print(state.keys())\n", + " # # print(state.keys())\n", "\n", - "# hf_func = hf_roberta.RobertaSelfOutput(hf_config)\n", - "# hf_func.load_state_dict(state, strict=True)\n", + " hf_func = hf_roberta.RobertaSelfOutput(hf_config)\n", + " hf_func.load_state_dict(state, strict=True, assign=True)\n", "\n", - "# my_output = my_func(x_embed_att, input_embeds, key=k_2)\n", - "# hf_output = hf_func(x_embed_att_torch, input_embeds_torch)\n", + " my_output = my_func(x_embed_att, input_embeds, key=k_2)\n", + " hf_output = hf_func(x_embed_att_torch, input_embeds_torch)\n", "\n", - "# return check(my_output.array, hf_output.detach())\n", + " return check(my_output.array, hf_output.detach())\n", "\n", - "# # Testing RobertaSelfAttention\n", + "# Testing RobertaSelfAttention\n", "\n", - "# def test_RobertaSelfAttention(key):\n", - "# k_1, k_2 = jrandom.split(key, 2)\n", + "def test_RobertaSelfAttention(key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", "\n", - "# my_func = my_roberta.RobertaSelfAttention.init(my_config, key=k_1)\n", - "# state = my_func.to_state_dict()\n", + " my_func = my_roberta.RobertaSelfAttention.init(my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", "\n", - "# state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - "# # print(state.keys())\n", + " # print(state.keys())\n", "\n", - "# hf_func = hf_roberta.RobertaSelfAttention(hf_config)\n", - "# hf_func.load_state_dict(state, strict=True)\n", + " hf_func = hf_roberta.RobertaSelfAttention(hf_config)\n", + " hf_func.load_state_dict(state, strict=True, assign=True)\n", "\n", - "# my_output = my_func(input_embeds, mask, key=k_2)\n", - "# hf_output = hf_func(input_embeds_torch, mask_torch_materialized)\n", + " my_output = my_func(input_embeds, mask_materialized, key=k_2)\n", + " hf_output = hf_func(input_embeds_torch, mask_torch_materialized)\n", "\n", - "# return check(my_output.array, hf_output[0].detach())\n", + " return check(my_output.array, hf_output[0].detach())\n", "\n", - "# # Testing RobertaAttention\n", + "# Testing RobertaAttention\n", "\n", - "# def test_RobertaAttention(key):\n", - "# k_1, k_2 = jrandom.split(key, 2)\n", - "# my_func = my_roberta.RobertaAttention.init(my_config, key=k_1)\n", - "# state = my_func.to_state_dict()\n", + "def test_RobertaAttention(key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaAttention.init(my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", "\n", - "# state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - "# # print(state.keys())\n", + " # print(state.keys())\n", "\n", - "# hf_func = hf_roberta.RobertaAttention(hf_config)\n", - "# hf_func.load_state_dict(state, strict=True)\n", + " hf_func = hf_roberta.RobertaAttention(hf_config)\n", + " hf_func.load_state_dict(state, strict=True, assign=True)\n", "\n", - "# my_output = my_func(hidden_states=input_embeds, attention_mask=mask, key=k_2)\n", - "# hf_output = hf_func(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)\n", + " my_output = my_func(hidden_states=input_embeds, attention_mask=mask_materialized, key=k_2)\n", + " hf_output = hf_func(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)\n", "\n", - "# return check(my_output.array, hf_output[0].detach())\n", + " return check(my_output.array, hf_output[0].detach())\n", "\n", - "# # 
Testing RobertaIntermediate\n", + "# Testing RobertaIntermediate\n", "\n", - "# def test_RobertaIntermediate(key):\n", - "# k_1, k_2 = jrandom.split(key, 2)\n", - "# my_func = my_roberta.RobertaIntermediate.init(my_config, key=k_1)\n", - "# state = my_func.to_state_dict()\n", + "def test_RobertaIntermediate(key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaIntermediate.init(my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", "\n", - "# state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - "# # print(state.keys())\n", + " # print(state.keys())\n", "\n", - "# hf_func = hf_roberta.RobertaIntermediate(hf_config)\n", - "# hf_func.load_state_dict(state, strict=True)\n", + " hf_func = hf_roberta.RobertaIntermediate(hf_config)\n", + " hf_func.load_state_dict(state, strict=True, assign=True)\n", "\n", - "# my_output = my_func(input_embeds, key=k_2)\n", - "# hf_output = hf_func(input_embeds_torch)\n", + " my_output = my_func(input_embeds, key=k_2)\n", + " hf_output = hf_func(input_embeds_torch)\n", "\n", - "# return check(my_output.array, hf_output.detach())\n", + " return check(my_output.array, hf_output.detach())\n", "\n", - "# # Testing RobertaOutput\n", + "# Testing RobertaOutput\n", "\n", - "# def test_RobertaOutput(key):\n", - "# k_1, k_2 = jrandom.split(key, 2)\n", + "def test_RobertaOutput(key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", "\n", - "# my_func = my_roberta.RobertaOutput.init(my_config, key=k_1)\n", - "# state = my_func.to_state_dict()\n", + " my_func = my_roberta.RobertaOutput.init(my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", "\n", - "# state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - "# # print(state.keys())\n", + " # print(state.keys())\n", "\n", - "# hf_func = hf_roberta.RobertaOutput(hf_config)\n", - "# hf_func.load_state_dict(state, strict=True)\n", + " hf_func = hf_roberta.RobertaOutput(hf_config)\n", + " hf_func.load_state_dict(state, strict=True, assign=True)\n", "\n", - "# my_output = my_func(x_mlp, input_embeds, key=k_2)\n", - "# hf_output = hf_func(x_mlp_torch, input_embeds_torch)\n", + " my_output = my_func(x_mlp, input_embeds, key=k_2)\n", + " hf_output = hf_func(x_mlp_torch, input_embeds_torch)\n", "\n", - "# return check(my_output.array, hf_output.detach())\n", + " return check(my_output.array, hf_output.detach())\n", "\n", - "# # Testing RobertaLayer\n", + "# Testing RobertaLayer\n", "\n", - "# def test_RobertaLayer(key):\n", - "# k_1, k_2 = jrandom.split(key, 2)\n", - "# my_func = my_roberta.RobertaLayer.init(my_config, key=k_1)\n", - "# state = my_func.to_state_dict()\n", + "def test_RobertaLayer(key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaLayer.init(my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", "\n", - "# state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - "# # print(state.keys())\n", + " # print(state.keys())\n", "\n", - "# hf_func = hf_roberta.RobertaLayer(hf_config)\n", - "# hf_func.load_state_dict(state, strict=True)\n", + " hf_func = hf_roberta.RobertaLayer(hf_config)\n", + " hf_func.load_state_dict(state, strict=True, assign=True)\n", "\n", - "# my_output = my_func(hidden_states=input_embeds, attention_mask=mask, 
key=k_2)\n", - "# hf_output = hf_func(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)\n", + " my_output = my_func(hidden_states=input_embeds, attention_mask=mask_materialized, key=k_2)\n", + " hf_output = hf_func(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)\n", "\n", - "# return check(my_output.array, hf_output[0].detach())\n", + " return check(my_output[0].array, hf_output[0].detach())\n", "\n", - "# # Testing RobertaEncoder\n", + "# Testing RobertaEncoder\n", "\n", - "# def test_RobertaEncoder(key):\n", - "# k_1, k_2 = jrandom.split(key, 2)\n", - "# my_func = my_roberta.RobertaEncoder.init(my_config, key=k_1)\n", - "# state = my_func.to_state_dict()\n", + "def test_RobertaEncoder(key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaEncoder.init(my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", "\n", - "# state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - "# # print(state.keys())\n", + " # print(state.keys())\n", "\n", - "# hf_func = hf_roberta.RobertaEncoder(hf_config)\n", - "# hf_func.load_state_dict(state, strict=True)\n", + " hf_func = hf_roberta.RobertaEncoder(hf_config)\n", + " hf_func.load_state_dict(state, strict=True, assign=True)\n", "\n", - "# my_output = my_func(hidden_states=input_embeds, attention_mask=mask, key=k_2)\n", - "# hf_output = hf_func(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)\n", + " my_output = my_func(hidden_states=input_embeds, attention_mask=mask_materialized, key=k_2)\n", + " hf_output = hf_func(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)\n", "\n", - "# return check(my_output.array, hf_output[0].detach())\n", + " return check(my_output[0].array, hf_output[0].detach())\n", "\n", - "# # Testing RobertaEmbedding\n", + "# Testing RobertaEmbedding\n", "\n", - "# def test_RobertaEmbedding(key, ids = True):\n", - "# k_1, k_2 = jrandom.split(key, 2)\n", - "# my_func = my_roberta.RobertaEmbedding.init(Vocab, my_config, key=k_1)\n", - "# state = my_func.to_state_dict()\n", + "def test_RobertaEmbedding(key, ids = True):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaEmbedding.init(Vocab, my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", "\n", - "# state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - "# # print(state.keys())\n", + " # print(state.keys())\n", "\n", - "# hf_func = hf_roberta.RobertaEmbeddings(hf_config)\n", - "# hf_func.load_state_dict(state, strict=True)\n", + " hf_func = hf_roberta.RobertaEmbeddings(hf_config)\n", + " hf_func.load_state_dict(state, strict=True, assign=True)\n", "\n", - "# if ids:\n", - "# my_output = my_func.embed(input_ids=input_ids, key=k_2)\n", - "# hf_output = hf_func(input_ids=input_ids_torch)\n", - "# else: \n", - "# my_output = my_func.embed(input_embeds=input_embeds, key=k_2)\n", - "# hf_output = hf_func(inputs_embeds=input_embeds_torch)\n", + " if ids:\n", + " my_output = my_func.embed(input_ids=input_ids, key=k_2)\n", + " hf_output = hf_func(input_ids=input_ids_torch)\n", + " else: \n", + " my_output = my_func.embed(input_embeds=input_embeds, key=k_2)\n", + " hf_output = hf_func(inputs_embeds=input_embeds_torch)\n", "\n", - "# return check(my_output.array, hf_output.detach())\n", + " return check(my_output.array, 
hf_output.detach())\n", "\n", - "# # Testing RobertaPooler\n", + "# Testing RobertaPooler\n", "\n", - "# def test_RobertaPooler(key):\n", - "# k_1, k_2 = jrandom.split(key, 2)\n", - "# my_func = my_roberta.RobertaPooler.init(my_config, key=k_1)\n", - "# state = my_func.to_state_dict()\n", + "def test_RobertaPooler(key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaPooler.init(my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", "\n", - "# state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - "# # print(state.keys())\n", + " # print(state.keys())\n", "\n", - "# hf_func = hf_roberta.RobertaPooler(hf_config)\n", - "# hf_func.load_state_dict(state, strict=True)\n", + " hf_func = hf_roberta.RobertaPooler(hf_config)\n", + " hf_func.load_state_dict(state, strict=True, assign=True)\n", "\n", - "# my_output = my_func(input_embeds, key=k_2)\n", - "# hf_output = hf_func(input_embeds_torch)\n", + " my_output = my_func(input_embeds, key=k_2)\n", + " hf_output = hf_func(input_embeds_torch)\n", "\n", - "# return check(my_output.array, hf_output.detach())\n", + " return check(my_output.array, hf_output.detach())\n", "\n", - "# # Testing RobertaModel\n", + "# Testing RobertaModel\n", "\n", - "# def test_RobertaModel(key, ids = True, pool = True):\n", - "# k_1, k_2 = jrandom.split(key, 2)\n", - "# my_func = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=pool, key=k_1)\n", - "# state = my_func.to_state_dict()\n", + "def test_RobertaModel(key, ids = True, pool = True):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=pool, key=k_1)\n", + " state = my_func.to_state_dict()\n", "\n", - "# state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - "# # print(state.keys())\n", + " # print(state.keys())\n", "\n", - "# hf_func = hf_roberta.RobertaModel(hf_config, add_pooling_layer=pool)\n", - "# hf_func.load_state_dict(state, strict=True)\n", + " hf_func = hf_roberta.RobertaModel(hf_config, add_pooling_layer=pool)\n", + " hf_func.load_state_dict(state, strict=True, assign=True)\n", "\n", - "# if ids:\n", - "# my_output = my_func(input_ids = input_ids, attention_mask=mask, key=k_2)\n", - "# hf_output = hf_func(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", - "# else:\n", - "# my_output = my_func(input_embeds = input_embeds, attention_mask=mask, key=k_2)\n", - "# hf_output = hf_func(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n", + " if ids:\n", + " my_output = my_func(input_ids = input_ids, attention_mask=mask, key=k_2)\n", + " hf_output = hf_func(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", + " else:\n", + " my_output = my_func(input_embeds = input_embeds, attention_mask=mask, key=k_2)\n", + " hf_output = hf_func(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n", "\n", - "# if pool:\n", - "# return check(my_output[1].array, hf_output[1].detach())\n", - "# else:\n", - "# return check(my_output[0].array, hf_output[0].detach())\n", + " if pool:\n", + " return check(my_output[1].array, hf_output[1].detach())\n", + " else:\n", + " return check(my_output[0].array, hf_output[0].detach())\n", "\n", - "# # Testing RobertaLMHead\n", + "# Testing 
RobertaLMHead\n", "\n", - "# def test_RobertaLMHead(key):\n", - "# k_1, k_2 = jrandom.split(key, 2)\n", - "# my_func = my_roberta.RobertaLMHead.init(Vocab, my_config, key=k_1)\n", - "# state = my_func.to_state_dict()\n", + "def test_RobertaLMHead(key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaLMHead.init(Vocab, my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", "\n", - "# state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - "# state[\"bias\"] = torch.zeros(hf_config.vocab_size)\n", + " state[\"bias\"] = torch.zeros(hf_config.vocab_size)\n", "\n", - "# # print(state.keys())\n", + " # print(state.keys())\n", "\n", - "# hf_func = hf_roberta.RobertaLMHead(hf_config)\n", - "# hf_func.load_state_dict(state, strict=True)\n", + " hf_func = hf_roberta.RobertaLMHead(hf_config)\n", + " hf_func.load_state_dict(state, strict=True, assign=True)\n", "\n", - "# my_output = my_func(features, key=k_2)\n", - "# hf_output = hf_func(features_torch)\n", + " my_output = my_func(features, key=k_2)\n", + " hf_output = hf_func(features_torch)\n", "\n", - "# return check(my_output.array, hf_output.detach())\n", + " return check(my_output.array, hf_output.detach())\n", "\n", - "# # Testing RobertaForMaskedLM\n", + "# Testing RobertaForMaskedLM\n", "\n", - "# def test_RobertaForMaskedLM(key, ids = True):\n", - "# k_1, k_2 = jrandom.split(key, 2)\n", - "# my_pool = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=k_1)\n", - "# state = my_pool.to_state_dict()\n", + "def test_RobertaForMaskedLM(key, ids = True):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_pool = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=k_1)\n", + " state = my_pool.to_state_dict()\n", "\n", - "# state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", "\n", - "# state[\"lm_head.bias\"] = torch.zeros(hf_config.vocab_size)\n", + " state[\"lm_head.bias\"] = torch.zeros(hf_config.vocab_size)\n", "\n", - "# # print(state.keys())\n", + " # print(state.keys())\n", "\n", - "# hf_pool = hf_roberta.RobertaForMaskedLM(hf_config)\n", - "# hf_pool.load_state_dict(state, strict=True)\n", + " hf_pool = hf_roberta.RobertaForMaskedLM(hf_config)\n", + " hf_pool.load_state_dict(state, strict=True, assign=True)\n", "\n", - "# if ids:\n", - "# my_output = my_pool(input_ids = input_ids, attention_mask=mask, key=k_2)\n", - "# hf_output = hf_pool(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", - "# else:\n", - "# my_output = my_pool(input_embeds = input_embeds, attention_mask=mask, key=k_2)\n", - "# hf_output = hf_pool(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n", + " if ids:\n", + " my_output = my_pool(input_ids = input_ids, attention_mask=mask, key=k_2)\n", + " hf_output = hf_pool(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", + " else:\n", + " my_output = my_pool(input_embeds = input_embeds, attention_mask=mask, key=k_2)\n", + " hf_output = hf_pool(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n", "\n", - "# return check(my_output.array, hf_output[0].detach())" + " return check(my_output[0].array, hf_output[0].detach())" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ - "# seed = time()\n", - "# 
print(f\"seed: {int(seed)}\")\n", - "# key_vars = jrandom.PRNGKey(int(seed))\n", - "# keys = jrandom.split(key_vars, 15)\n", - "\n", - "# print(f\"test_RobertaSelfOutput: {out_func(test_RobertaSelfOutput(keys[0]))}\")\n", - "# print(f\"test_RobertaSelfAttention: {out_func(test_RobertaSelfAttention(keys[1]))}\")\n", - "# print(f\"test_RobertaAttention: {out_func(test_RobertaAttention(keys[2]))}\")\n", - "# print(f\"test_RobertaIntermediate: {out_func(test_RobertaIntermediate(keys[3]))}\")\n", - "# print(f\"test_RobertaOutput: {out_func(test_RobertaOutput(keys[4]))}\")\n", - "# print(f\"test_RobertaEmbedding(ids = True): {out_func(test_RobertaEmbedding(keys[7], ids = True))}\")\n", - "# print(f\"test_RobertaEmbedding(ids = False): {out_func(test_RobertaEmbedding(keys[8], ids = False))}\")\n", - "# print(f\"test_RobertaModel(ids = True, pool = True): {out_func(test_RobertaModel(keys[9], ids = True, pool = True))}\")\n", - "# print(f\"test_RobertaModel(ids = False, pool = False): {out_func(test_RobertaModel(keys[10], ids = False, pool = False))}\")\n", - "# print(f\"test_RobertaModel(ids = True, pool = True): {out_func(test_RobertaModel(keys[9], ids = True, pool = True))}\")\n", - "# print(f\"test_RobertaModel(ids = False, pool = False): {out_func(test_RobertaModel(keys[10], ids = False, pool = False))}\")\n", - "# print(f\"test_RobertaPooler: {out_func(test_RobertaPooler(keys[11]))}\")\n", - "# print(f\"test_RobertaLMHead: {out_func(test_RobertaLMHead(keys[12]))}\")\n", - "# print(f\"test_RobertaForMaskedLM(ids = True): {out_func(test_RobertaForMaskedLM(keys[13], ids = True))}\")\n", - "# print(f\"test_RobertaForMaskedLM(ids = False): {out_func(test_RobertaForMaskedLM(keys[14], ids = False))}\")" + "keys = jrandom.split(key_funcs, 15)" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "outs = []\n", + "\n", + "outs.append(test_RobertaSelfOutput(keys[0]))\n", + "outs.append(test_RobertaSelfAttention(keys[1]))\n", + "outs.append(test_RobertaAttention(keys[2]))\n", + "outs.append(test_RobertaIntermediate(keys[3]))\n", + "outs.append(test_RobertaOutput(keys[4]))\n", + "outs.append(test_RobertaLayer(keys[4]))\n", + "outs.append(test_RobertaEncoder(keys[4]))\n", + "outs.append(test_RobertaEmbedding(keys[7], ids = True))\n", + "outs.append(test_RobertaEmbedding(keys[8], ids = False))\n", + "outs.append(test_RobertaModel(keys[9], ids = True, pool = True))\n", + "outs.append(test_RobertaModel(keys[10], ids = False, pool = False))\n", + "outs.append(test_RobertaModel(keys[9], ids = True, pool = True))\n", + "outs.append(test_RobertaModel(keys[10], ids = False, pool = False))\n", + "outs.append(test_RobertaPooler(keys[11]))\n", + "outs.append(test_RobertaLMHead(keys[12]))\n", + "outs.append(test_RobertaForMaskedLM(keys[13], ids = True))\n", + "outs.append(test_RobertaForMaskedLM(keys[14], ids = False))" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "types = [\n", + " \"test_RobertaSelfOutput\",\n", + " \"test_RobertaSelfAttention\",\n", + " \"test_RobertaAttention\",\n", + " \"test_RobertaIntermediate\",\n", + " \"test_RobertaOutput\",\n", + " \"test_RobertaLayer\",\n", + " \"test_RobertaEncoder\",\n", + " \"test_RobertaEmbedding(ids = True)\",\n", + " \"test_RobertaEmbedding(ids = False)\",\n", + " \"test_RobertaModel(ids = True, pool = True)\",\n", + " \"test_RobertaModel(ids = False, pool = False)\",\n", + " \"test_RobertaModel(ids = True, pool = True)\",\n", 
+ " \"test_RobertaModel(ids = False, pool = False)\",\n", + " \"test_RobertaPooler\",\n", + " \"test_RobertaLMHead\",\n", + " \"test_RobertaForMaskedLM(ids = True)\",\n", + " \"test_RobertaForMaskedLM(ids = False)\"\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "for i,o in enumerate(outs):\n", + " if o[2] * 0 != 0:\n", + " print(f\"nan alert\")\n", + " if o[0] < 1:\n", + " print(f\"{types[i]}: {o[3]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "for i,o in enumerate(outs):\n", + " print(f\"{types[i]}: {o[3]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "key_model, key_lm = jrandom.split(key_funcs, 2)\n", + "key_model_run, key_lm_run = jrandom.split(key_run, 2)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "seed: 1725471643\n" + "dict_keys(['encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.1.attention.self.query.bias', 'encoder.layer.1.attention.self.key.weight', 'encoder.layer.1.attention.self.key.bias', 'encoder.layer.1.attention.self.value.weight', 'encoder.layer.1.attention.self.value.bias', 'encoder.layer.1.attention.output.dense.weight', 'encoder.layer.1.attention.output.dense.bias', 'encoder.layer.1.attention.output.LayerNorm.weight', 'encoder.layer.1.attention.output.LayerNorm.bias', 'encoder.layer.1.intermediate.dense.weight', 'encoder.layer.1.intermediate.dense.bias', 'encoder.layer.1.output.dense.weight', 'encoder.layer.1.output.dense.bias', 'encoder.layer.1.output.LayerNorm.weight', 'encoder.layer.1.output.LayerNorm.bias', 'encoder.layer.2.attention.self.query.weight', 'encoder.layer.2.attention.self.query.bias', 'encoder.layer.2.attention.self.key.weight', 'encoder.layer.2.attention.self.key.bias', 'encoder.layer.2.attention.self.value.weight', 'encoder.layer.2.attention.self.value.bias', 'encoder.layer.2.attention.output.dense.weight', 'encoder.layer.2.attention.output.dense.bias', 'encoder.layer.2.attention.output.LayerNorm.weight', 'encoder.layer.2.attention.output.LayerNorm.bias', 'encoder.layer.2.intermediate.dense.weight', 'encoder.layer.2.intermediate.dense.bias', 'encoder.layer.2.output.dense.weight', 'encoder.layer.2.output.dense.bias', 'encoder.layer.2.output.LayerNorm.weight', 'encoder.layer.2.output.LayerNorm.bias', 'encoder.layer.3.attention.self.query.weight', 'encoder.layer.3.attention.self.query.bias', 'encoder.layer.3.attention.self.key.weight', 'encoder.layer.3.attention.self.key.bias', 'encoder.layer.3.attention.self.value.weight', 'encoder.layer.3.attention.self.value.bias', 'encoder.layer.3.attention.output.dense.weight', 
'encoder.layer.3.attention.output.dense.bias', 'encoder.layer.3.attention.output.LayerNorm.weight', 'encoder.layer.3.attention.output.LayerNorm.bias', 'encoder.layer.3.intermediate.dense.weight', 'encoder.layer.3.intermediate.dense.bias', 'encoder.layer.3.output.dense.weight', 'encoder.layer.3.output.dense.bias', 'encoder.layer.3.output.LayerNorm.weight', 'encoder.layer.3.output.LayerNorm.bias', 'encoder.layer.4.attention.self.query.weight', 'encoder.layer.4.attention.self.query.bias', 'encoder.layer.4.attention.self.key.weight', 'encoder.layer.4.attention.self.key.bias', 'encoder.layer.4.attention.self.value.weight', 'encoder.layer.4.attention.self.value.bias', 'encoder.layer.4.attention.output.dense.weight', 'encoder.layer.4.attention.output.dense.bias', 'encoder.layer.4.attention.output.LayerNorm.weight', 'encoder.layer.4.attention.output.LayerNorm.bias', 'encoder.layer.4.intermediate.dense.weight', 'encoder.layer.4.intermediate.dense.bias', 'encoder.layer.4.output.dense.weight', 'encoder.layer.4.output.dense.bias', 'encoder.layer.4.output.LayerNorm.weight', 'encoder.layer.4.output.LayerNorm.bias', 'encoder.layer.5.attention.self.query.weight', 'encoder.layer.5.attention.self.query.bias', 'encoder.layer.5.attention.self.key.weight', 'encoder.layer.5.attention.self.key.bias', 'encoder.layer.5.attention.self.value.weight', 'encoder.layer.5.attention.self.value.bias', 'encoder.layer.5.attention.output.dense.weight', 'encoder.layer.5.attention.output.dense.bias', 'encoder.layer.5.attention.output.LayerNorm.weight', 'encoder.layer.5.attention.output.LayerNorm.bias', 'encoder.layer.5.intermediate.dense.weight', 'encoder.layer.5.intermediate.dense.bias', 'encoder.layer.5.output.dense.weight', 'encoder.layer.5.output.dense.bias', 'encoder.layer.5.output.LayerNorm.weight', 'encoder.layer.5.output.LayerNorm.bias', 'encoder.layer.6.attention.self.query.weight', 'encoder.layer.6.attention.self.query.bias', 'encoder.layer.6.attention.self.key.weight', 'encoder.layer.6.attention.self.key.bias', 'encoder.layer.6.attention.self.value.weight', 'encoder.layer.6.attention.self.value.bias', 'encoder.layer.6.attention.output.dense.weight', 'encoder.layer.6.attention.output.dense.bias', 'encoder.layer.6.attention.output.LayerNorm.weight', 'encoder.layer.6.attention.output.LayerNorm.bias', 'encoder.layer.6.intermediate.dense.weight', 'encoder.layer.6.intermediate.dense.bias', 'encoder.layer.6.output.dense.weight', 'encoder.layer.6.output.dense.bias', 'encoder.layer.6.output.LayerNorm.weight', 'encoder.layer.6.output.LayerNorm.bias', 'encoder.layer.7.attention.self.query.weight', 'encoder.layer.7.attention.self.query.bias', 'encoder.layer.7.attention.self.key.weight', 'encoder.layer.7.attention.self.key.bias', 'encoder.layer.7.attention.self.value.weight', 'encoder.layer.7.attention.self.value.bias', 'encoder.layer.7.attention.output.dense.weight', 'encoder.layer.7.attention.output.dense.bias', 'encoder.layer.7.attention.output.LayerNorm.weight', 'encoder.layer.7.attention.output.LayerNorm.bias', 'encoder.layer.7.intermediate.dense.weight', 'encoder.layer.7.intermediate.dense.bias', 'encoder.layer.7.output.dense.weight', 'encoder.layer.7.output.dense.bias', 'encoder.layer.7.output.LayerNorm.weight', 'encoder.layer.7.output.LayerNorm.bias', 'encoder.layer.8.attention.self.query.weight', 'encoder.layer.8.attention.self.query.bias', 'encoder.layer.8.attention.self.key.weight', 'encoder.layer.8.attention.self.key.bias', 'encoder.layer.8.attention.self.value.weight', 'encoder.layer.8.attention.self.value.bias', 
'encoder.layer.8.attention.output.dense.weight', 'encoder.layer.8.attention.output.dense.bias', 'encoder.layer.8.attention.output.LayerNorm.weight', 'encoder.layer.8.attention.output.LayerNorm.bias', 'encoder.layer.8.intermediate.dense.weight', 'encoder.layer.8.intermediate.dense.bias', 'encoder.layer.8.output.dense.weight', 'encoder.layer.8.output.dense.bias', 'encoder.layer.8.output.LayerNorm.weight', 'encoder.layer.8.output.LayerNorm.bias', 'encoder.layer.9.attention.self.query.weight', 'encoder.layer.9.attention.self.query.bias', 'encoder.layer.9.attention.self.key.weight', 'encoder.layer.9.attention.self.key.bias', 'encoder.layer.9.attention.self.value.weight', 'encoder.layer.9.attention.self.value.bias', 'encoder.layer.9.attention.output.dense.weight', 'encoder.layer.9.attention.output.dense.bias', 'encoder.layer.9.attention.output.LayerNorm.weight', 'encoder.layer.9.attention.output.LayerNorm.bias', 'encoder.layer.9.intermediate.dense.weight', 'encoder.layer.9.intermediate.dense.bias', 'encoder.layer.9.output.dense.weight', 'encoder.layer.9.output.dense.bias', 'encoder.layer.9.output.LayerNorm.weight', 'encoder.layer.9.output.LayerNorm.bias', 'encoder.layer.10.attention.self.query.weight', 'encoder.layer.10.attention.self.query.bias', 'encoder.layer.10.attention.self.key.weight', 'encoder.layer.10.attention.self.key.bias', 'encoder.layer.10.attention.self.value.weight', 'encoder.layer.10.attention.self.value.bias', 'encoder.layer.10.attention.output.dense.weight', 'encoder.layer.10.attention.output.dense.bias', 'encoder.layer.10.attention.output.LayerNorm.weight', 'encoder.layer.10.attention.output.LayerNorm.bias', 'encoder.layer.10.intermediate.dense.weight', 'encoder.layer.10.intermediate.dense.bias', 'encoder.layer.10.output.dense.weight', 'encoder.layer.10.output.dense.bias', 'encoder.layer.10.output.LayerNorm.weight', 'encoder.layer.10.output.LayerNorm.bias', 'encoder.layer.11.attention.self.query.weight', 'encoder.layer.11.attention.self.query.bias', 'encoder.layer.11.attention.self.key.weight', 'encoder.layer.11.attention.self.key.bias', 'encoder.layer.11.attention.self.value.weight', 'encoder.layer.11.attention.self.value.bias', 'encoder.layer.11.attention.output.dense.weight', 'encoder.layer.11.attention.output.dense.bias', 'encoder.layer.11.attention.output.LayerNorm.weight', 'encoder.layer.11.attention.output.LayerNorm.bias', 'encoder.layer.11.intermediate.dense.weight', 'encoder.layer.11.intermediate.dense.bias', 'encoder.layer.11.output.dense.weight', 'encoder.layer.11.output.dense.bias', 'encoder.layer.11.output.LayerNorm.weight', 'encoder.layer.11.output.LayerNorm.bias', 'embeddings.word_embeddings.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.LayerNorm.weight', 'embeddings.LayerNorm.bias', 'pooler.dense.weight', 'pooler.dense.bias'])\n", + "odict_keys(['embeddings.word_embeddings.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.LayerNorm.weight', 'embeddings.LayerNorm.bias', 'encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 
'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.1.attention.self.query.bias', 'encoder.layer.1.attention.self.key.weight', 'encoder.layer.1.attention.self.key.bias', 'encoder.layer.1.attention.self.value.weight', 'encoder.layer.1.attention.self.value.bias', 'encoder.layer.1.attention.output.dense.weight', 'encoder.layer.1.attention.output.dense.bias', 'encoder.layer.1.attention.output.LayerNorm.weight', 'encoder.layer.1.attention.output.LayerNorm.bias', 'encoder.layer.1.intermediate.dense.weight', 'encoder.layer.1.intermediate.dense.bias', 'encoder.layer.1.output.dense.weight', 'encoder.layer.1.output.dense.bias', 'encoder.layer.1.output.LayerNorm.weight', 'encoder.layer.1.output.LayerNorm.bias', 'encoder.layer.2.attention.self.query.weight', 'encoder.layer.2.attention.self.query.bias', 'encoder.layer.2.attention.self.key.weight', 'encoder.layer.2.attention.self.key.bias', 'encoder.layer.2.attention.self.value.weight', 'encoder.layer.2.attention.self.value.bias', 'encoder.layer.2.attention.output.dense.weight', 'encoder.layer.2.attention.output.dense.bias', 'encoder.layer.2.attention.output.LayerNorm.weight', 'encoder.layer.2.attention.output.LayerNorm.bias', 'encoder.layer.2.intermediate.dense.weight', 'encoder.layer.2.intermediate.dense.bias', 'encoder.layer.2.output.dense.weight', 'encoder.layer.2.output.dense.bias', 'encoder.layer.2.output.LayerNorm.weight', 'encoder.layer.2.output.LayerNorm.bias', 'encoder.layer.3.attention.self.query.weight', 'encoder.layer.3.attention.self.query.bias', 'encoder.layer.3.attention.self.key.weight', 'encoder.layer.3.attention.self.key.bias', 'encoder.layer.3.attention.self.value.weight', 'encoder.layer.3.attention.self.value.bias', 'encoder.layer.3.attention.output.dense.weight', 'encoder.layer.3.attention.output.dense.bias', 'encoder.layer.3.attention.output.LayerNorm.weight', 'encoder.layer.3.attention.output.LayerNorm.bias', 'encoder.layer.3.intermediate.dense.weight', 'encoder.layer.3.intermediate.dense.bias', 'encoder.layer.3.output.dense.weight', 'encoder.layer.3.output.dense.bias', 'encoder.layer.3.output.LayerNorm.weight', 'encoder.layer.3.output.LayerNorm.bias', 'encoder.layer.4.attention.self.query.weight', 'encoder.layer.4.attention.self.query.bias', 'encoder.layer.4.attention.self.key.weight', 'encoder.layer.4.attention.self.key.bias', 'encoder.layer.4.attention.self.value.weight', 'encoder.layer.4.attention.self.value.bias', 'encoder.layer.4.attention.output.dense.weight', 'encoder.layer.4.attention.output.dense.bias', 'encoder.layer.4.attention.output.LayerNorm.weight', 'encoder.layer.4.attention.output.LayerNorm.bias', 'encoder.layer.4.intermediate.dense.weight', 'encoder.layer.4.intermediate.dense.bias', 'encoder.layer.4.output.dense.weight', 'encoder.layer.4.output.dense.bias', 'encoder.layer.4.output.LayerNorm.weight', 'encoder.layer.4.output.LayerNorm.bias', 'encoder.layer.5.attention.self.query.weight', 'encoder.layer.5.attention.self.query.bias', 'encoder.layer.5.attention.self.key.weight', 'encoder.layer.5.attention.self.key.bias', 'encoder.layer.5.attention.self.value.weight', 'encoder.layer.5.attention.self.value.bias', 'encoder.layer.5.attention.output.dense.weight', 'encoder.layer.5.attention.output.dense.bias', 'encoder.layer.5.attention.output.LayerNorm.weight', 
'encoder.layer.5.attention.output.LayerNorm.bias', 'encoder.layer.5.intermediate.dense.weight', 'encoder.layer.5.intermediate.dense.bias', 'encoder.layer.5.output.dense.weight', 'encoder.layer.5.output.dense.bias', 'encoder.layer.5.output.LayerNorm.weight', 'encoder.layer.5.output.LayerNorm.bias', 'encoder.layer.6.attention.self.query.weight', 'encoder.layer.6.attention.self.query.bias', 'encoder.layer.6.attention.self.key.weight', 'encoder.layer.6.attention.self.key.bias', 'encoder.layer.6.attention.self.value.weight', 'encoder.layer.6.attention.self.value.bias', 'encoder.layer.6.attention.output.dense.weight', 'encoder.layer.6.attention.output.dense.bias', 'encoder.layer.6.attention.output.LayerNorm.weight', 'encoder.layer.6.attention.output.LayerNorm.bias', 'encoder.layer.6.intermediate.dense.weight', 'encoder.layer.6.intermediate.dense.bias', 'encoder.layer.6.output.dense.weight', 'encoder.layer.6.output.dense.bias', 'encoder.layer.6.output.LayerNorm.weight', 'encoder.layer.6.output.LayerNorm.bias', 'encoder.layer.7.attention.self.query.weight', 'encoder.layer.7.attention.self.query.bias', 'encoder.layer.7.attention.self.key.weight', 'encoder.layer.7.attention.self.key.bias', 'encoder.layer.7.attention.self.value.weight', 'encoder.layer.7.attention.self.value.bias', 'encoder.layer.7.attention.output.dense.weight', 'encoder.layer.7.attention.output.dense.bias', 'encoder.layer.7.attention.output.LayerNorm.weight', 'encoder.layer.7.attention.output.LayerNorm.bias', 'encoder.layer.7.intermediate.dense.weight', 'encoder.layer.7.intermediate.dense.bias', 'encoder.layer.7.output.dense.weight', 'encoder.layer.7.output.dense.bias', 'encoder.layer.7.output.LayerNorm.weight', 'encoder.layer.7.output.LayerNorm.bias', 'encoder.layer.8.attention.self.query.weight', 'encoder.layer.8.attention.self.query.bias', 'encoder.layer.8.attention.self.key.weight', 'encoder.layer.8.attention.self.key.bias', 'encoder.layer.8.attention.self.value.weight', 'encoder.layer.8.attention.self.value.bias', 'encoder.layer.8.attention.output.dense.weight', 'encoder.layer.8.attention.output.dense.bias', 'encoder.layer.8.attention.output.LayerNorm.weight', 'encoder.layer.8.attention.output.LayerNorm.bias', 'encoder.layer.8.intermediate.dense.weight', 'encoder.layer.8.intermediate.dense.bias', 'encoder.layer.8.output.dense.weight', 'encoder.layer.8.output.dense.bias', 'encoder.layer.8.output.LayerNorm.weight', 'encoder.layer.8.output.LayerNorm.bias', 'encoder.layer.9.attention.self.query.weight', 'encoder.layer.9.attention.self.query.bias', 'encoder.layer.9.attention.self.key.weight', 'encoder.layer.9.attention.self.key.bias', 'encoder.layer.9.attention.self.value.weight', 'encoder.layer.9.attention.self.value.bias', 'encoder.layer.9.attention.output.dense.weight', 'encoder.layer.9.attention.output.dense.bias', 'encoder.layer.9.attention.output.LayerNorm.weight', 'encoder.layer.9.attention.output.LayerNorm.bias', 'encoder.layer.9.intermediate.dense.weight', 'encoder.layer.9.intermediate.dense.bias', 'encoder.layer.9.output.dense.weight', 'encoder.layer.9.output.dense.bias', 'encoder.layer.9.output.LayerNorm.weight', 'encoder.layer.9.output.LayerNorm.bias', 'encoder.layer.10.attention.self.query.weight', 'encoder.layer.10.attention.self.query.bias', 'encoder.layer.10.attention.self.key.weight', 'encoder.layer.10.attention.self.key.bias', 'encoder.layer.10.attention.self.value.weight', 'encoder.layer.10.attention.self.value.bias', 'encoder.layer.10.attention.output.dense.weight', 'encoder.layer.10.attention.output.dense.bias', 
'encoder.layer.10.attention.output.LayerNorm.weight', 'encoder.layer.10.attention.output.LayerNorm.bias', 'encoder.layer.10.intermediate.dense.weight', 'encoder.layer.10.intermediate.dense.bias', 'encoder.layer.10.output.dense.weight', 'encoder.layer.10.output.dense.bias', 'encoder.layer.10.output.LayerNorm.weight', 'encoder.layer.10.output.LayerNorm.bias', 'encoder.layer.11.attention.self.query.weight', 'encoder.layer.11.attention.self.query.bias', 'encoder.layer.11.attention.self.key.weight', 'encoder.layer.11.attention.self.key.bias', 'encoder.layer.11.attention.self.value.weight', 'encoder.layer.11.attention.self.value.bias', 'encoder.layer.11.attention.output.dense.weight', 'encoder.layer.11.attention.output.dense.bias', 'encoder.layer.11.attention.output.LayerNorm.weight', 'encoder.layer.11.attention.output.LayerNorm.bias', 'encoder.layer.11.intermediate.dense.weight', 'encoder.layer.11.intermediate.dense.bias', 'encoder.layer.11.output.dense.weight', 'encoder.layer.11.output.dense.bias', 'encoder.layer.11.output.LayerNorm.weight', 'encoder.layer.11.output.LayerNorm.bias', 'pooler.dense.weight', 'pooler.dense.bias'])\n", + "success1\n", + "success2\n", + "0.0\n" ] } ], "source": [ - "seed = time()\n", - "print(f\"seed: {int(seed)}\")\n", - "key = jrandom.PRNGKey(int(seed))\n", + "# Initializing RobertaModel\n", + "my_model = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=True, output_hidden_states=True, key=key_model)\n", + "state_model = my_model.to_state_dict()\n", "\n", - "def test_RobertaModel_Output(key, ids = False, pool = True):\n", - " k_1, k_2 = jrandom.split(key, 2)\n", - " my_func = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=pool, output_hidden_states=True, key=k_1)\n", - " state = my_func.to_state_dict()\n", + "state_model = {k: torch.from_numpy(np.array(v)) for k, v in state_model.items()}\n", "\n", - " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "hf_model = hf_roberta.RobertaModel(hf_config, add_pooling_layer=True)\n", + "hf_model.load_state_dict(state_model, strict=True)\n", "\n", - " # print(state.keys())\n", + "print(check_dicts(my_model.to_state_dict(), hf_model.state_dict()))" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dict_keys(['roberta.encoder.layer.0.attention.self.query.weight', 'roberta.encoder.layer.0.attention.self.query.bias', 'roberta.encoder.layer.0.attention.self.key.weight', 'roberta.encoder.layer.0.attention.self.key.bias', 'roberta.encoder.layer.0.attention.self.value.weight', 'roberta.encoder.layer.0.attention.self.value.bias', 'roberta.encoder.layer.0.attention.output.dense.weight', 'roberta.encoder.layer.0.attention.output.dense.bias', 'roberta.encoder.layer.0.attention.output.LayerNorm.weight', 'roberta.encoder.layer.0.attention.output.LayerNorm.bias', 'roberta.encoder.layer.0.intermediate.dense.weight', 'roberta.encoder.layer.0.intermediate.dense.bias', 'roberta.encoder.layer.0.output.dense.weight', 'roberta.encoder.layer.0.output.dense.bias', 'roberta.encoder.layer.0.output.LayerNorm.weight', 'roberta.encoder.layer.0.output.LayerNorm.bias', 'roberta.encoder.layer.1.attention.self.query.weight', 'roberta.encoder.layer.1.attention.self.query.bias', 'roberta.encoder.layer.1.attention.self.key.weight', 'roberta.encoder.layer.1.attention.self.key.bias', 'roberta.encoder.layer.1.attention.self.value.weight', 'roberta.encoder.layer.1.attention.self.value.bias', 
'roberta.encoder.layer.1.attention.output.dense.weight', 'roberta.encoder.layer.1.attention.output.dense.bias', 'roberta.encoder.layer.1.attention.output.LayerNorm.weight', 'roberta.encoder.layer.1.attention.output.LayerNorm.bias', 'roberta.encoder.layer.1.intermediate.dense.weight', 'roberta.encoder.layer.1.intermediate.dense.bias', 'roberta.encoder.layer.1.output.dense.weight', 'roberta.encoder.layer.1.output.dense.bias', 'roberta.encoder.layer.1.output.LayerNorm.weight', 'roberta.encoder.layer.1.output.LayerNorm.bias', 'roberta.encoder.layer.2.attention.self.query.weight', 'roberta.encoder.layer.2.attention.self.query.bias', 'roberta.encoder.layer.2.attention.self.key.weight', 'roberta.encoder.layer.2.attention.self.key.bias', 'roberta.encoder.layer.2.attention.self.value.weight', 'roberta.encoder.layer.2.attention.self.value.bias', 'roberta.encoder.layer.2.attention.output.dense.weight', 'roberta.encoder.layer.2.attention.output.dense.bias', 'roberta.encoder.layer.2.attention.output.LayerNorm.weight', 'roberta.encoder.layer.2.attention.output.LayerNorm.bias', 'roberta.encoder.layer.2.intermediate.dense.weight', 'roberta.encoder.layer.2.intermediate.dense.bias', 'roberta.encoder.layer.2.output.dense.weight', 'roberta.encoder.layer.2.output.dense.bias', 'roberta.encoder.layer.2.output.LayerNorm.weight', 'roberta.encoder.layer.2.output.LayerNorm.bias', 'roberta.encoder.layer.3.attention.self.query.weight', 'roberta.encoder.layer.3.attention.self.query.bias', 'roberta.encoder.layer.3.attention.self.key.weight', 'roberta.encoder.layer.3.attention.self.key.bias', 'roberta.encoder.layer.3.attention.self.value.weight', 'roberta.encoder.layer.3.attention.self.value.bias', 'roberta.encoder.layer.3.attention.output.dense.weight', 'roberta.encoder.layer.3.attention.output.dense.bias', 'roberta.encoder.layer.3.attention.output.LayerNorm.weight', 'roberta.encoder.layer.3.attention.output.LayerNorm.bias', 'roberta.encoder.layer.3.intermediate.dense.weight', 'roberta.encoder.layer.3.intermediate.dense.bias', 'roberta.encoder.layer.3.output.dense.weight', 'roberta.encoder.layer.3.output.dense.bias', 'roberta.encoder.layer.3.output.LayerNorm.weight', 'roberta.encoder.layer.3.output.LayerNorm.bias', 'roberta.encoder.layer.4.attention.self.query.weight', 'roberta.encoder.layer.4.attention.self.query.bias', 'roberta.encoder.layer.4.attention.self.key.weight', 'roberta.encoder.layer.4.attention.self.key.bias', 'roberta.encoder.layer.4.attention.self.value.weight', 'roberta.encoder.layer.4.attention.self.value.bias', 'roberta.encoder.layer.4.attention.output.dense.weight', 'roberta.encoder.layer.4.attention.output.dense.bias', 'roberta.encoder.layer.4.attention.output.LayerNorm.weight', 'roberta.encoder.layer.4.attention.output.LayerNorm.bias', 'roberta.encoder.layer.4.intermediate.dense.weight', 'roberta.encoder.layer.4.intermediate.dense.bias', 'roberta.encoder.layer.4.output.dense.weight', 'roberta.encoder.layer.4.output.dense.bias', 'roberta.encoder.layer.4.output.LayerNorm.weight', 'roberta.encoder.layer.4.output.LayerNorm.bias', 'roberta.encoder.layer.5.attention.self.query.weight', 'roberta.encoder.layer.5.attention.self.query.bias', 'roberta.encoder.layer.5.attention.self.key.weight', 'roberta.encoder.layer.5.attention.self.key.bias', 'roberta.encoder.layer.5.attention.self.value.weight', 'roberta.encoder.layer.5.attention.self.value.bias', 'roberta.encoder.layer.5.attention.output.dense.weight', 'roberta.encoder.layer.5.attention.output.dense.bias', 
'roberta.encoder.layer.5.attention.output.LayerNorm.weight', 'roberta.encoder.layer.5.attention.output.LayerNorm.bias', 'roberta.encoder.layer.5.intermediate.dense.weight', 'roberta.encoder.layer.5.intermediate.dense.bias', 'roberta.encoder.layer.5.output.dense.weight', 'roberta.encoder.layer.5.output.dense.bias', 'roberta.encoder.layer.5.output.LayerNorm.weight', 'roberta.encoder.layer.5.output.LayerNorm.bias', 'roberta.encoder.layer.6.attention.self.query.weight', 'roberta.encoder.layer.6.attention.self.query.bias', 'roberta.encoder.layer.6.attention.self.key.weight', 'roberta.encoder.layer.6.attention.self.key.bias', 'roberta.encoder.layer.6.attention.self.value.weight', 'roberta.encoder.layer.6.attention.self.value.bias', 'roberta.encoder.layer.6.attention.output.dense.weight', 'roberta.encoder.layer.6.attention.output.dense.bias', 'roberta.encoder.layer.6.attention.output.LayerNorm.weight', 'roberta.encoder.layer.6.attention.output.LayerNorm.bias', 'roberta.encoder.layer.6.intermediate.dense.weight', 'roberta.encoder.layer.6.intermediate.dense.bias', 'roberta.encoder.layer.6.output.dense.weight', 'roberta.encoder.layer.6.output.dense.bias', 'roberta.encoder.layer.6.output.LayerNorm.weight', 'roberta.encoder.layer.6.output.LayerNorm.bias', 'roberta.encoder.layer.7.attention.self.query.weight', 'roberta.encoder.layer.7.attention.self.query.bias', 'roberta.encoder.layer.7.attention.self.key.weight', 'roberta.encoder.layer.7.attention.self.key.bias', 'roberta.encoder.layer.7.attention.self.value.weight', 'roberta.encoder.layer.7.attention.self.value.bias', 'roberta.encoder.layer.7.attention.output.dense.weight', 'roberta.encoder.layer.7.attention.output.dense.bias', 'roberta.encoder.layer.7.attention.output.LayerNorm.weight', 'roberta.encoder.layer.7.attention.output.LayerNorm.bias', 'roberta.encoder.layer.7.intermediate.dense.weight', 'roberta.encoder.layer.7.intermediate.dense.bias', 'roberta.encoder.layer.7.output.dense.weight', 'roberta.encoder.layer.7.output.dense.bias', 'roberta.encoder.layer.7.output.LayerNorm.weight', 'roberta.encoder.layer.7.output.LayerNorm.bias', 'roberta.encoder.layer.8.attention.self.query.weight', 'roberta.encoder.layer.8.attention.self.query.bias', 'roberta.encoder.layer.8.attention.self.key.weight', 'roberta.encoder.layer.8.attention.self.key.bias', 'roberta.encoder.layer.8.attention.self.value.weight', 'roberta.encoder.layer.8.attention.self.value.bias', 'roberta.encoder.layer.8.attention.output.dense.weight', 'roberta.encoder.layer.8.attention.output.dense.bias', 'roberta.encoder.layer.8.attention.output.LayerNorm.weight', 'roberta.encoder.layer.8.attention.output.LayerNorm.bias', 'roberta.encoder.layer.8.intermediate.dense.weight', 'roberta.encoder.layer.8.intermediate.dense.bias', 'roberta.encoder.layer.8.output.dense.weight', 'roberta.encoder.layer.8.output.dense.bias', 'roberta.encoder.layer.8.output.LayerNorm.weight', 'roberta.encoder.layer.8.output.LayerNorm.bias', 'roberta.encoder.layer.9.attention.self.query.weight', 'roberta.encoder.layer.9.attention.self.query.bias', 'roberta.encoder.layer.9.attention.self.key.weight', 'roberta.encoder.layer.9.attention.self.key.bias', 'roberta.encoder.layer.9.attention.self.value.weight', 'roberta.encoder.layer.9.attention.self.value.bias', 'roberta.encoder.layer.9.attention.output.dense.weight', 'roberta.encoder.layer.9.attention.output.dense.bias', 'roberta.encoder.layer.9.attention.output.LayerNorm.weight', 'roberta.encoder.layer.9.attention.output.LayerNorm.bias', 
'roberta.encoder.layer.9.intermediate.dense.weight', 'roberta.encoder.layer.9.intermediate.dense.bias', 'roberta.encoder.layer.9.output.dense.weight', 'roberta.encoder.layer.9.output.dense.bias', 'roberta.encoder.layer.9.output.LayerNorm.weight', 'roberta.encoder.layer.9.output.LayerNorm.bias', 'roberta.encoder.layer.10.attention.self.query.weight', 'roberta.encoder.layer.10.attention.self.query.bias', 'roberta.encoder.layer.10.attention.self.key.weight', 'roberta.encoder.layer.10.attention.self.key.bias', 'roberta.encoder.layer.10.attention.self.value.weight', 'roberta.encoder.layer.10.attention.self.value.bias', 'roberta.encoder.layer.10.attention.output.dense.weight', 'roberta.encoder.layer.10.attention.output.dense.bias', 'roberta.encoder.layer.10.attention.output.LayerNorm.weight', 'roberta.encoder.layer.10.attention.output.LayerNorm.bias', 'roberta.encoder.layer.10.intermediate.dense.weight', 'roberta.encoder.layer.10.intermediate.dense.bias', 'roberta.encoder.layer.10.output.dense.weight', 'roberta.encoder.layer.10.output.dense.bias', 'roberta.encoder.layer.10.output.LayerNorm.weight', 'roberta.encoder.layer.10.output.LayerNorm.bias', 'roberta.encoder.layer.11.attention.self.query.weight', 'roberta.encoder.layer.11.attention.self.query.bias', 'roberta.encoder.layer.11.attention.self.key.weight', 'roberta.encoder.layer.11.attention.self.key.bias', 'roberta.encoder.layer.11.attention.self.value.weight', 'roberta.encoder.layer.11.attention.self.value.bias', 'roberta.encoder.layer.11.attention.output.dense.weight', 'roberta.encoder.layer.11.attention.output.dense.bias', 'roberta.encoder.layer.11.attention.output.LayerNorm.weight', 'roberta.encoder.layer.11.attention.output.LayerNorm.bias', 'roberta.encoder.layer.11.intermediate.dense.weight', 'roberta.encoder.layer.11.intermediate.dense.bias', 'roberta.encoder.layer.11.output.dense.weight', 'roberta.encoder.layer.11.output.dense.bias', 'roberta.encoder.layer.11.output.LayerNorm.weight', 'roberta.encoder.layer.11.output.LayerNorm.bias', 'roberta.embeddings.word_embeddings.weight', 'roberta.embeddings.position_embeddings.weight', 'roberta.embeddings.token_type_embeddings.weight', 'roberta.embeddings.LayerNorm.weight', 'roberta.embeddings.LayerNorm.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.decoder.bias', 'lm_head.bias'])\n", + "odict_keys(['roberta.embeddings.word_embeddings.weight', 'roberta.embeddings.position_embeddings.weight', 'roberta.embeddings.token_type_embeddings.weight', 'roberta.embeddings.LayerNorm.weight', 'roberta.embeddings.LayerNorm.bias', 'roberta.encoder.layer.0.attention.self.query.weight', 'roberta.encoder.layer.0.attention.self.query.bias', 'roberta.encoder.layer.0.attention.self.key.weight', 'roberta.encoder.layer.0.attention.self.key.bias', 'roberta.encoder.layer.0.attention.self.value.weight', 'roberta.encoder.layer.0.attention.self.value.bias', 'roberta.encoder.layer.0.attention.output.dense.weight', 'roberta.encoder.layer.0.attention.output.dense.bias', 'roberta.encoder.layer.0.attention.output.LayerNorm.weight', 'roberta.encoder.layer.0.attention.output.LayerNorm.bias', 'roberta.encoder.layer.0.intermediate.dense.weight', 'roberta.encoder.layer.0.intermediate.dense.bias', 'roberta.encoder.layer.0.output.dense.weight', 'roberta.encoder.layer.0.output.dense.bias', 'roberta.encoder.layer.0.output.LayerNorm.weight', 'roberta.encoder.layer.0.output.LayerNorm.bias', 
'roberta.encoder.layer.1.attention.self.query.weight', 'roberta.encoder.layer.1.attention.self.query.bias', 'roberta.encoder.layer.1.attention.self.key.weight', 'roberta.encoder.layer.1.attention.self.key.bias', 'roberta.encoder.layer.1.attention.self.value.weight', 'roberta.encoder.layer.1.attention.self.value.bias', 'roberta.encoder.layer.1.attention.output.dense.weight', 'roberta.encoder.layer.1.attention.output.dense.bias', 'roberta.encoder.layer.1.attention.output.LayerNorm.weight', 'roberta.encoder.layer.1.attention.output.LayerNorm.bias', 'roberta.encoder.layer.1.intermediate.dense.weight', 'roberta.encoder.layer.1.intermediate.dense.bias', 'roberta.encoder.layer.1.output.dense.weight', 'roberta.encoder.layer.1.output.dense.bias', 'roberta.encoder.layer.1.output.LayerNorm.weight', 'roberta.encoder.layer.1.output.LayerNorm.bias', 'roberta.encoder.layer.2.attention.self.query.weight', 'roberta.encoder.layer.2.attention.self.query.bias', 'roberta.encoder.layer.2.attention.self.key.weight', 'roberta.encoder.layer.2.attention.self.key.bias', 'roberta.encoder.layer.2.attention.self.value.weight', 'roberta.encoder.layer.2.attention.self.value.bias', 'roberta.encoder.layer.2.attention.output.dense.weight', 'roberta.encoder.layer.2.attention.output.dense.bias', 'roberta.encoder.layer.2.attention.output.LayerNorm.weight', 'roberta.encoder.layer.2.attention.output.LayerNorm.bias', 'roberta.encoder.layer.2.intermediate.dense.weight', 'roberta.encoder.layer.2.intermediate.dense.bias', 'roberta.encoder.layer.2.output.dense.weight', 'roberta.encoder.layer.2.output.dense.bias', 'roberta.encoder.layer.2.output.LayerNorm.weight', 'roberta.encoder.layer.2.output.LayerNorm.bias', 'roberta.encoder.layer.3.attention.self.query.weight', 'roberta.encoder.layer.3.attention.self.query.bias', 'roberta.encoder.layer.3.attention.self.key.weight', 'roberta.encoder.layer.3.attention.self.key.bias', 'roberta.encoder.layer.3.attention.self.value.weight', 'roberta.encoder.layer.3.attention.self.value.bias', 'roberta.encoder.layer.3.attention.output.dense.weight', 'roberta.encoder.layer.3.attention.output.dense.bias', 'roberta.encoder.layer.3.attention.output.LayerNorm.weight', 'roberta.encoder.layer.3.attention.output.LayerNorm.bias', 'roberta.encoder.layer.3.intermediate.dense.weight', 'roberta.encoder.layer.3.intermediate.dense.bias', 'roberta.encoder.layer.3.output.dense.weight', 'roberta.encoder.layer.3.output.dense.bias', 'roberta.encoder.layer.3.output.LayerNorm.weight', 'roberta.encoder.layer.3.output.LayerNorm.bias', 'roberta.encoder.layer.4.attention.self.query.weight', 'roberta.encoder.layer.4.attention.self.query.bias', 'roberta.encoder.layer.4.attention.self.key.weight', 'roberta.encoder.layer.4.attention.self.key.bias', 'roberta.encoder.layer.4.attention.self.value.weight', 'roberta.encoder.layer.4.attention.self.value.bias', 'roberta.encoder.layer.4.attention.output.dense.weight', 'roberta.encoder.layer.4.attention.output.dense.bias', 'roberta.encoder.layer.4.attention.output.LayerNorm.weight', 'roberta.encoder.layer.4.attention.output.LayerNorm.bias', 'roberta.encoder.layer.4.intermediate.dense.weight', 'roberta.encoder.layer.4.intermediate.dense.bias', 'roberta.encoder.layer.4.output.dense.weight', 'roberta.encoder.layer.4.output.dense.bias', 'roberta.encoder.layer.4.output.LayerNorm.weight', 'roberta.encoder.layer.4.output.LayerNorm.bias', 'roberta.encoder.layer.5.attention.self.query.weight', 'roberta.encoder.layer.5.attention.self.query.bias', 'roberta.encoder.layer.5.attention.self.key.weight', 
'roberta.encoder.layer.5.attention.self.key.bias', 'roberta.encoder.layer.5.attention.self.value.weight', 'roberta.encoder.layer.5.attention.self.value.bias', 'roberta.encoder.layer.5.attention.output.dense.weight', 'roberta.encoder.layer.5.attention.output.dense.bias', 'roberta.encoder.layer.5.attention.output.LayerNorm.weight', 'roberta.encoder.layer.5.attention.output.LayerNorm.bias', 'roberta.encoder.layer.5.intermediate.dense.weight', 'roberta.encoder.layer.5.intermediate.dense.bias', 'roberta.encoder.layer.5.output.dense.weight', 'roberta.encoder.layer.5.output.dense.bias', 'roberta.encoder.layer.5.output.LayerNorm.weight', 'roberta.encoder.layer.5.output.LayerNorm.bias', 'roberta.encoder.layer.6.attention.self.query.weight', 'roberta.encoder.layer.6.attention.self.query.bias', 'roberta.encoder.layer.6.attention.self.key.weight', 'roberta.encoder.layer.6.attention.self.key.bias', 'roberta.encoder.layer.6.attention.self.value.weight', 'roberta.encoder.layer.6.attention.self.value.bias', 'roberta.encoder.layer.6.attention.output.dense.weight', 'roberta.encoder.layer.6.attention.output.dense.bias', 'roberta.encoder.layer.6.attention.output.LayerNorm.weight', 'roberta.encoder.layer.6.attention.output.LayerNorm.bias', 'roberta.encoder.layer.6.intermediate.dense.weight', 'roberta.encoder.layer.6.intermediate.dense.bias', 'roberta.encoder.layer.6.output.dense.weight', 'roberta.encoder.layer.6.output.dense.bias', 'roberta.encoder.layer.6.output.LayerNorm.weight', 'roberta.encoder.layer.6.output.LayerNorm.bias', 'roberta.encoder.layer.7.attention.self.query.weight', 'roberta.encoder.layer.7.attention.self.query.bias', 'roberta.encoder.layer.7.attention.self.key.weight', 'roberta.encoder.layer.7.attention.self.key.bias', 'roberta.encoder.layer.7.attention.self.value.weight', 'roberta.encoder.layer.7.attention.self.value.bias', 'roberta.encoder.layer.7.attention.output.dense.weight', 'roberta.encoder.layer.7.attention.output.dense.bias', 'roberta.encoder.layer.7.attention.output.LayerNorm.weight', 'roberta.encoder.layer.7.attention.output.LayerNorm.bias', 'roberta.encoder.layer.7.intermediate.dense.weight', 'roberta.encoder.layer.7.intermediate.dense.bias', 'roberta.encoder.layer.7.output.dense.weight', 'roberta.encoder.layer.7.output.dense.bias', 'roberta.encoder.layer.7.output.LayerNorm.weight', 'roberta.encoder.layer.7.output.LayerNorm.bias', 'roberta.encoder.layer.8.attention.self.query.weight', 'roberta.encoder.layer.8.attention.self.query.bias', 'roberta.encoder.layer.8.attention.self.key.weight', 'roberta.encoder.layer.8.attention.self.key.bias', 'roberta.encoder.layer.8.attention.self.value.weight', 'roberta.encoder.layer.8.attention.self.value.bias', 'roberta.encoder.layer.8.attention.output.dense.weight', 'roberta.encoder.layer.8.attention.output.dense.bias', 'roberta.encoder.layer.8.attention.output.LayerNorm.weight', 'roberta.encoder.layer.8.attention.output.LayerNorm.bias', 'roberta.encoder.layer.8.intermediate.dense.weight', 'roberta.encoder.layer.8.intermediate.dense.bias', 'roberta.encoder.layer.8.output.dense.weight', 'roberta.encoder.layer.8.output.dense.bias', 'roberta.encoder.layer.8.output.LayerNorm.weight', 'roberta.encoder.layer.8.output.LayerNorm.bias', 'roberta.encoder.layer.9.attention.self.query.weight', 'roberta.encoder.layer.9.attention.self.query.bias', 'roberta.encoder.layer.9.attention.self.key.weight', 'roberta.encoder.layer.9.attention.self.key.bias', 'roberta.encoder.layer.9.attention.self.value.weight', 'roberta.encoder.layer.9.attention.self.value.bias', 
'roberta.encoder.layer.9.attention.output.dense.weight', 'roberta.encoder.layer.9.attention.output.dense.bias', 'roberta.encoder.layer.9.attention.output.LayerNorm.weight', 'roberta.encoder.layer.9.attention.output.LayerNorm.bias', 'roberta.encoder.layer.9.intermediate.dense.weight', 'roberta.encoder.layer.9.intermediate.dense.bias', 'roberta.encoder.layer.9.output.dense.weight', 'roberta.encoder.layer.9.output.dense.bias', 'roberta.encoder.layer.9.output.LayerNorm.weight', 'roberta.encoder.layer.9.output.LayerNorm.bias', 'roberta.encoder.layer.10.attention.self.query.weight', 'roberta.encoder.layer.10.attention.self.query.bias', 'roberta.encoder.layer.10.attention.self.key.weight', 'roberta.encoder.layer.10.attention.self.key.bias', 'roberta.encoder.layer.10.attention.self.value.weight', 'roberta.encoder.layer.10.attention.self.value.bias', 'roberta.encoder.layer.10.attention.output.dense.weight', 'roberta.encoder.layer.10.attention.output.dense.bias', 'roberta.encoder.layer.10.attention.output.LayerNorm.weight', 'roberta.encoder.layer.10.attention.output.LayerNorm.bias', 'roberta.encoder.layer.10.intermediate.dense.weight', 'roberta.encoder.layer.10.intermediate.dense.bias', 'roberta.encoder.layer.10.output.dense.weight', 'roberta.encoder.layer.10.output.dense.bias', 'roberta.encoder.layer.10.output.LayerNorm.weight', 'roberta.encoder.layer.10.output.LayerNorm.bias', 'roberta.encoder.layer.11.attention.self.query.weight', 'roberta.encoder.layer.11.attention.self.query.bias', 'roberta.encoder.layer.11.attention.self.key.weight', 'roberta.encoder.layer.11.attention.self.key.bias', 'roberta.encoder.layer.11.attention.self.value.weight', 'roberta.encoder.layer.11.attention.self.value.bias', 'roberta.encoder.layer.11.attention.output.dense.weight', 'roberta.encoder.layer.11.attention.output.dense.bias', 'roberta.encoder.layer.11.attention.output.LayerNorm.weight', 'roberta.encoder.layer.11.attention.output.LayerNorm.bias', 'roberta.encoder.layer.11.intermediate.dense.weight', 'roberta.encoder.layer.11.intermediate.dense.bias', 'roberta.encoder.layer.11.output.dense.weight', 'roberta.encoder.layer.11.output.dense.bias', 'roberta.encoder.layer.11.output.LayerNorm.weight', 'roberta.encoder.layer.11.output.LayerNorm.bias', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.decoder.bias'])\n", + "success1\n", + "success2\n", + "0.0\n" + ] + } + ], + "source": [ + "# Initializing RobertaForMaskedLM\n", + "my_mlm = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, output_hidden_states=True, key=key_funcs)\n", + "state_mlm = my_mlm.to_state_dict()\n", "\n", - " hf_func = hf_roberta.RobertaModel(hf_config, add_pooling_layer=pool)\n", - " hf_func.load_state_dict(state, strict=True)\n", + "state_mlm = {k: torch.from_numpy(np.array(v)) for k, v in state_mlm.items()}\n", + "state_mlm[\"lm_head.bias\"] = torch.zeros(hf_config.vocab_size)\n", + "# print(state_mlm[w_str])\n", "\n", + "hf_mlm = hf_roberta.RobertaForMaskedLM(hf_config)\n", + "hf_mlm.load_state_dict(state_mlm, strict=True, assign=True)\n", + "# print(hf_mlm.state_dict()[w_str])\n", + "\n", + "print(check_dicts(state_mlm, hf_mlm.state_dict()))" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "def test_RobertaModel_Output(key_run, ids = False):\n", " if ids:\n", - " my_output = my_func(input_ids = input_ids, attention_mask=mask, key=k_2)\n", - " hf_output = hf_func(input_ids = 
input_ids_torch, attention_mask=mask_torch, return_dict=False, output_hidden_states=True)\n", + " my_output = my_model(input_ids = input_ids, attention_mask=mask, key=key_run)\n", + " hf_output = hf_model(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False, output_hidden_states=True)\n", " else:\n", - " my_output = my_func(input_embeds = input_embeds, attention_mask=mask, key=k_2)\n", - " hf_output = hf_func(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False, output_hidden_states=True)\n", + " my_output = my_model(input_embeds = input_embeds, attention_mask=mask, key=key_run)\n", + " hf_output = hf_model(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False, output_hidden_states=True)\n", "\n", - " return my_output, hf_output\n", - "\n", - "my_output_ids, hf_output_ids = test_RobertaModel_Output(key, ids=True)\n", - "my_output_embeds, hf_output_embeds = test_RobertaModel_Output(key, ids=False)" + " return my_output, hf_output" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "my_output_ids, hf_output_ids = test_RobertaModel_Output(key_model_run, ids=True)\n", + "my_output_embeds, hf_output_embeds = test_RobertaModel_Output(key_model_run, ids=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "model_out: acc: 1.0 \t norms: (tensor(888.5305), tensor(888.5305)) \t diffs: 3.735790983228071e-07\n", - "pool_out: acc: 1.0 \t norms: (tensor(24.1699), tensor(24.1699)) \t diffs: 2.6738598535303026e-07\n", + "model_out: acc: 1.0 \t norms: (tensor(888.5331), tensor(888.5331)) \t diffs: 6.946668236196274e-07\n", + "pool_out: acc: 1.0 \t norms: (tensor(24.6133), tensor(24.6133)) \t diffs: 4.2906631847472454e-07\n", "intermediates:\n", - "acc: 1.0 \t norms: (tensor(888.5306), tensor(888.5306)) \t diffs: 1.842970362986307e-07\n", - "acc: 1.0 \t norms: (tensor(888.5308), tensor(888.5308)) \t diffs: 2.625348543006112e-07\n", - "acc: 1.0 \t norms: (tensor(888.5301), tensor(888.5301)) \t diffs: 3.2068956556940975e-07\n", - "acc: 1.0 \t norms: (tensor(888.5295), tensor(888.5295)) \t diffs: 3.396997101390298e-07\n", - "acc: 1.0 \t norms: (tensor(888.5296), tensor(888.5297)) \t diffs: 3.419580139052414e-07\n", - "acc: 1.0 \t norms: (tensor(888.5310), tensor(888.5310)) \t diffs: 3.721844734627666e-07\n", - "acc: 1.0 \t norms: (tensor(888.5298), tensor(888.5299)) \t diffs: 3.591211736875266e-07\n", - "acc: 1.0 \t norms: (tensor(888.5299), tensor(888.5299)) \t diffs: 3.513960677992145e-07\n", - "acc: 1.0 \t norms: (tensor(888.5300), tensor(888.5300)) \t diffs: 3.739319538453856e-07\n", - "acc: 1.0 \t norms: (tensor(888.5305), tensor(888.5305)) \t diffs: 3.6201470265950775e-07\n", - "acc: 1.0 \t norms: (tensor(888.5309), tensor(888.5309)) \t diffs: 3.658180958154844e-07\n", - "acc: 1.0 \t norms: (tensor(888.5305), tensor(888.5305)) \t diffs: 3.735790983228071e-07\n" + "acc: 1.0 \t norms: (tensor(888.5315), tensor(888.5314)) \t diffs: 1.6926360046909394e-07\n", + "acc: 1.0 \t norms: (tensor(888.5301), tensor(888.5301)) \t diffs: 2.57225963196106e-07\n", + "acc: 1.0 \t norms: (tensor(888.5310), tensor(888.5311)) \t diffs: 3.1373855335914413e-07\n", + "acc: 1.0 \t norms: (tensor(888.5312), tensor(888.5312)) \t diffs: 3.6955742643840495e-07\n", + "acc: 1.0 \t norms: (tensor(888.5304), tensor(888.5304)) \t diffs: 4.1312114262836985e-07\n", + "acc: 1.0 \t norms: 
(tensor(888.5308), tensor(888.5308)) \t diffs: 4.612424220340472e-07\n", + "acc: 1.0 \t norms: (tensor(888.5308), tensor(888.5308)) \t diffs: 5.164657750356128e-07\n", + "acc: 1.0 \t norms: (tensor(888.5295), tensor(888.5295)) \t diffs: 5.412561563389318e-07\n", + "acc: 1.0 \t norms: (tensor(888.5309), tensor(888.5310)) \t diffs: 5.724053266931151e-07\n", + "acc: 1.0 \t norms: (tensor(888.5297), tensor(888.5297)) \t diffs: 5.980200512567535e-07\n", + "acc: 1.0 \t norms: (tensor(888.5317), tensor(888.5317)) \t diffs: 6.434746637751232e-07\n", + "acc: 1.0 \t norms: (tensor(888.5331), tensor(888.5331)) \t diffs: 6.946668236196274e-07\n" ] } ], "source": [ + "# RobertaModel ids\n", "my_out, hf_out = my_output_ids[0], hf_output_ids[0]\n", "\n", "print(f\"model_out: {check(my_out.array, hf_out.detach())[3]}\")\n", @@ -519,32 +686,34 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "model_out: acc: 1.0 \t norms: (tensor(888.5308), tensor(888.5308)) \t diffs: 3.864420250465628e-07\n", - "pool_out: acc: 1.0 \t norms: (tensor(24.0400), tensor(24.0400)) \t diffs: 2.5442224682592496e-07\n", + "model_out: acc: 1.0 \t norms: (tensor(888.5307), tensor(888.5306)) \t diffs: 6.107513286224275e-07\n", + "pool_out: acc: 1.0 \t norms: (tensor(24.6094), tensor(24.6094)) \t diffs: 3.8713346839358564e-07\n", "intermediates:\n", - "acc: 1.0 \t norms: (tensor(888.5295), tensor(888.5295)) \t diffs: 1.46142554058315e-07\n", - "acc: 1.0 \t norms: (tensor(888.5310), tensor(888.5311)) \t diffs: 2.35209114407553e-07\n", - "acc: 1.0 \t norms: (tensor(888.5301), tensor(888.5302)) \t diffs: 3.0417106700042496e-07\n", - "acc: 1.0 \t norms: (tensor(888.5303), tensor(888.5303)) \t diffs: 3.522779365994211e-07\n", - "acc: 1.0 \t norms: (tensor(888.5314), tensor(888.5314)) \t diffs: 3.7978762179591286e-07\n", - "acc: 1.0 \t norms: (tensor(888.5302), tensor(888.5302)) \t diffs: 3.9330373624579806e-07\n", - "acc: 1.0 \t norms: (tensor(888.5305), tensor(888.5304)) \t diffs: 3.8590263784499257e-07\n", - "acc: 1.0 \t norms: (tensor(888.5306), tensor(888.5306)) \t diffs: 3.735180200692412e-07\n", - "acc: 1.0 \t norms: (tensor(888.5292), tensor(888.5291)) \t diffs: 3.75227983795412e-07\n", - "acc: 1.0 \t norms: (tensor(888.5312), tensor(888.5312)) \t diffs: 3.7935546970402356e-07\n", - "acc: 1.0 \t norms: (tensor(888.5292), tensor(888.5292)) \t diffs: 3.81289993356404e-07\n", - "acc: 1.0 \t norms: (tensor(888.5308), tensor(888.5308)) \t diffs: 3.864420250465628e-07\n" + "acc: 1.0 \t norms: (tensor(888.5317), tensor(888.5317)) \t diffs: 1.3876427829018212e-07\n", + "acc: 1.0 \t norms: (tensor(888.5305), tensor(888.5305)) \t diffs: 2.239140144411067e-07\n", + "acc: 1.0 \t norms: (tensor(888.5290), tensor(888.5290)) \t diffs: 2.9642876597790746e-07\n", + "acc: 1.0 \t norms: (tensor(888.5301), tensor(888.5300)) \t diffs: 3.554245893155894e-07\n", + "acc: 1.0 \t norms: (tensor(888.5311), tensor(888.5311)) \t diffs: 4.070468264671945e-07\n", + "acc: 1.0 \t norms: (tensor(888.5311), tensor(888.5311)) \t diffs: 4.3890696588277933e-07\n", + "acc: 1.0 \t norms: (tensor(888.5305), tensor(888.5304)) \t diffs: 4.87373824853421e-07\n", + "acc: 1.0 \t norms: (tensor(888.5308), tensor(888.5308)) \t diffs: 5.209108735471091e-07\n", + "acc: 1.0 \t norms: (tensor(888.5298), tensor(888.5298)) \t diffs: 5.463080583467672e-07\n", + "acc: 1.0 \t norms: (tensor(888.5311), tensor(888.5311)) \t diffs: 5.576325179390551e-07\n", + "acc: 1.0 \t norms: 
(tensor(888.5295), tensor(888.5294)) \t diffs: 5.659875341734733e-07\n", + "acc: 1.0 \t norms: (tensor(888.5307), tensor(888.5306)) \t diffs: 6.107513286224275e-07\n" ] } ], "source": [ + "# RobertaModel embeds\n", + "\n", "my_out, hf_out = my_output_embeds[0], hf_output_embeds[0]\n", "\n", "print(f\"model_out: {check(my_out.array, hf_out.detach())[3]}\")\n", @@ -562,74 +731,59 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 21, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "seed: 1725471659\n" - ] - } - ], + "outputs": [], "source": [ - "seed = time()\n", - "print(f\"seed: {int(seed)}\")\n", - "key = jrandom.PRNGKey(int(seed))\n", - "\n", - "def test_RobertaForMaskedLM_Output(key, ids = True):\n", - " k_1, k_2 = jrandom.split(key, 2)\n", - " my_func = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, output_hidden_states=True, key=k_1)\n", - " state = my_func.to_state_dict()\n", - "\n", - " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", - " \n", - " state[\"lm_head.bias\"] = torch.zeros(hf_config.vocab_size)\n", - "\n", - " hf_func = hf_roberta.RobertaForMaskedLM(hf_config)\n", - " hf_func.load_state_dict(state, strict=True)\n", - "\n", + "def test_RobertaForMaskedLM_Output(key_run, ids = False):\n", " if ids:\n", - " my_output = my_func(input_ids = input_ids, attention_mask=mask, key=k_2)\n", - " hf_output = hf_func(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False, output_hidden_states=True)\n", + " my_output = my_mlm(input_ids = input_ids, attention_mask=mask, key=key_run)\n", + " hf_output = hf_mlm(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False, output_hidden_states=True)\n", " else:\n", - " my_output = my_func(input_embeds = input_embeds, attention_mask=mask, key=k_2)\n", - " hf_output = hf_func(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False, output_hidden_states=True)\n", - "\n", - " return my_output, hf_output\n", + " my_output = my_mlm(input_embeds = input_embeds, attention_mask=mask, key=key_run)\n", + " hf_output = hf_mlm(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False, output_hidden_states=True)\n", "\n", - "my_mlm_output_ids, hf_mlm_output_ids = test_RobertaForMaskedLM_Output(key, ids=True)\n", - "my_mlm_output_embeds, hf_mlm_output_embeds = test_RobertaForMaskedLM_Output(key, ids=False)" + " return my_output, hf_output" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "my_mlm_output_ids, hf_mlm_output_ids = test_RobertaForMaskedLM_Output(key_run, ids=True)\n", + "my_mlm_output_embeds, hf_mlm_output_embeds = test_RobertaForMaskedLM_Output(key_run, ids=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "mlm_out: acc: 0.001719795589213743 \t norms: (tensor(7055.4185), tensor(7059.8276)) \t diffs: 0.06642061471939087\n", + "mlm_out: acc: 1.0 \t norms: (tensor(7054.6812), tensor(7054.6816)) \t diffs: 7.966510224832746e-07\n", "intermediates:\n", - "acc: 0.019604713845654993 \t norms: (tensor(888.5299), tensor(888.5306)) \t diffs: 0.5695138573646545\n", - "acc: 0.024172138456549936 \t norms: (tensor(888.5310), tensor(888.5300)) \t diffs: 0.46615326404571533\n", - "acc: 0.030564759646562904 \t norms: (tensor(888.5312), tensor(888.5323)) \t diffs: 0.37349948287010193\n", - "acc: 
0.04000106395914397 \t norms: (tensor(888.5294), tensor(888.5300)) \t diffs: 0.28456440567970276\n", - "acc: 0.05416818660830091 \t norms: (tensor(888.5305), tensor(888.5297)) \t diffs: 0.21144814789295197\n", - "acc: 0.07160065053501946 \t norms: (tensor(888.5291), tensor(888.5300)) \t diffs: 0.16202779114246368\n", - "acc: 0.08982095087548637 \t norms: (tensor(888.5302), tensor(888.5295)) \t diffs: 0.12657980620861053\n", - "acc: 0.11758268482490272 \t norms: (tensor(888.5289), tensor(888.5308)) \t diffs: 0.09682736545801163\n", - "acc: 0.1463805123216602 \t norms: (tensor(888.5303), tensor(888.5312)) \t diffs: 0.07892335206270218\n", - "acc: 0.16748110205901426 \t norms: (tensor(888.5292), tensor(888.5310)) \t diffs: 0.06763624399900436\n", - "acc: 0.18012701645590143 \t norms: (tensor(888.5292), tensor(888.5284)) \t diffs: 0.062167659401893616\n", - "acc: 0.17417264510376135 \t norms: (tensor(888.5310), tensor(888.5300)) \t diffs: 0.06321203708648682\n" + "acc: 1.0 \t norms: (tensor(888.5315), tensor(888.5314)) \t diffs: 1.6926360046909394e-07\n", + "acc: 1.0 \t norms: (tensor(888.5301), tensor(888.5301)) \t diffs: 2.57225963196106e-07\n", + "acc: 1.0 \t norms: (tensor(888.5310), tensor(888.5311)) \t diffs: 3.1373855335914413e-07\n", + "acc: 1.0 \t norms: (tensor(888.5312), tensor(888.5312)) \t diffs: 3.6955742643840495e-07\n", + "acc: 1.0 \t norms: (tensor(888.5304), tensor(888.5304)) \t diffs: 4.1312114262836985e-07\n", + "acc: 1.0 \t norms: (tensor(888.5308), tensor(888.5308)) \t diffs: 4.612424220340472e-07\n", + "acc: 1.0 \t norms: (tensor(888.5308), tensor(888.5308)) \t diffs: 5.164657750356128e-07\n", + "acc: 1.0 \t norms: (tensor(888.5295), tensor(888.5295)) \t diffs: 5.412561563389318e-07\n", + "acc: 1.0 \t norms: (tensor(888.5309), tensor(888.5310)) \t diffs: 5.724053266931151e-07\n", + "acc: 1.0 \t norms: (tensor(888.5297), tensor(888.5297)) \t diffs: 5.980200512567535e-07\n", + "acc: 1.0 \t norms: (tensor(888.5317), tensor(888.5317)) \t diffs: 6.434746637751232e-07\n", + "acc: 1.0 \t norms: (tensor(888.5331), tensor(888.5331)) \t diffs: 6.946668236196274e-07\n" ] } ], "source": [ + "#Masked MLM ids\n", "my_out, hf_out = my_mlm_output_ids[0], hf_mlm_output_ids[0]\n", "\n", "print(f\"mlm_out: {check(my_out.array, hf_out.detach())[3]}\")\n", @@ -643,31 +797,32 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "mlm_out: acc: 1.0 \t norms: (tensor(7033.6245), tensor(7033.6240)) \t diffs: 5.174669013285893e-07\n", + "mlm_out: acc: 1.0 \t norms: (tensor(7107.9902), tensor(7107.9902)) \t diffs: 7.997662692105223e-07\n", "intermediates:\n", - "acc: 1.0 \t norms: (tensor(888.5309), tensor(888.5310)) \t diffs: 1.4689783256471856e-07\n", - "acc: 1.0 \t norms: (tensor(888.5306), tensor(888.5306)) \t diffs: 2.3626945733212779e-07\n", - "acc: 1.0 \t norms: (tensor(888.5312), tensor(888.5312)) \t diffs: 3.0318486210489937e-07\n", - "acc: 1.0 \t norms: (tensor(888.5305), tensor(888.5305)) \t diffs: 3.531020809077745e-07\n", - "acc: 1.0 \t norms: (tensor(888.5297), tensor(888.5297)) \t diffs: 3.7493518334486e-07\n", - "acc: 1.0 \t norms: (tensor(888.5297), tensor(888.5297)) \t diffs: 3.8230905374803115e-07\n", - "acc: 1.0 \t norms: (tensor(888.5295), tensor(888.5296)) \t diffs: 3.8595226214965805e-07\n", - "acc: 1.0 \t norms: (tensor(888.5308), tensor(888.5308)) \t diffs: 3.713914793479489e-07\n", - "acc: 1.0 \t norms: (tensor(888.5317), tensor(888.5318)) \t diffs: 
3.5173252399545163e-07\n", - "acc: 1.0 \t norms: (tensor(888.5316), tensor(888.5316)) \t diffs: 3.4720503094831656e-07\n", - "acc: 1.0 \t norms: (tensor(888.5290), tensor(888.5290)) \t diffs: 3.541817932273261e-07\n", - "acc: 1.0 \t norms: (tensor(888.5295), tensor(888.5295)) \t diffs: 3.6307170603322447e-07\n" + "acc: 1.0 \t norms: (tensor(888.5317), tensor(888.5317)) \t diffs: 1.3876427829018212e-07\n", + "acc: 1.0 \t norms: (tensor(888.5305), tensor(888.5305)) \t diffs: 2.239140144411067e-07\n", + "acc: 1.0 \t norms: (tensor(888.5290), tensor(888.5290)) \t diffs: 2.9642876597790746e-07\n", + "acc: 1.0 \t norms: (tensor(888.5301), tensor(888.5300)) \t diffs: 3.554245893155894e-07\n", + "acc: 1.0 \t norms: (tensor(888.5311), tensor(888.5311)) \t diffs: 4.070468264671945e-07\n", + "acc: 1.0 \t norms: (tensor(888.5311), tensor(888.5311)) \t diffs: 4.3890696588277933e-07\n", + "acc: 1.0 \t norms: (tensor(888.5305), tensor(888.5304)) \t diffs: 4.87373824853421e-07\n", + "acc: 1.0 \t norms: (tensor(888.5308), tensor(888.5308)) \t diffs: 5.209108735471091e-07\n", + "acc: 1.0 \t norms: (tensor(888.5298), tensor(888.5298)) \t diffs: 5.463080583467672e-07\n", + "acc: 1.0 \t norms: (tensor(888.5311), tensor(888.5311)) \t diffs: 5.576325179390551e-07\n", + "acc: 1.0 \t norms: (tensor(888.5295), tensor(888.5294)) \t diffs: 5.659875341734733e-07\n", + "acc: 1.0 \t norms: (tensor(888.5307), tensor(888.5306)) \t diffs: 6.107513286224275e-07\n" ] } ], "source": [ + "#Masked MLM embeds\n", "my_out, hf_out = my_mlm_output_embeds[0], hf_mlm_output_embeds[0]\n", "\n", "print(f\"mlm_out: {check(my_out.array, hf_out.detach())[3]}\")\n", @@ -681,116 +836,166 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "# # Testing RobertaModel\n", + "# my_model = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=False, key=key_rob)\n", + "# state_model = my_model.to_state_dict()\n", + "\n", + "# state_model = {k: torch.from_numpy(np.array(v)) for k, v in state_model.items()}\n", + "\n", + "# hf_model = hf_roberta.RobertaModel(hf_config, add_pooling_layer=False)\n", + "# hf_model.load_state_dict(state_model, strict=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "dict_keys(['roberta.encoder.layer.0.attention.self.query.weight', 'roberta.encoder.layer.0.attention.self.query.bias', 'roberta.encoder.layer.0.attention.self.key.weight', 'roberta.encoder.layer.0.attention.self.key.bias', 'roberta.encoder.layer.0.attention.self.value.weight', 'roberta.encoder.layer.0.attention.self.value.bias', 'roberta.encoder.layer.0.attention.output.dense.weight', 'roberta.encoder.layer.0.attention.output.dense.bias', 'roberta.encoder.layer.0.attention.output.LayerNorm.weight', 'roberta.encoder.layer.0.attention.output.LayerNorm.bias', 'roberta.encoder.layer.0.intermediate.dense.weight', 'roberta.encoder.layer.0.intermediate.dense.bias', 'roberta.encoder.layer.0.output.dense.weight', 'roberta.encoder.layer.0.output.dense.bias', 'roberta.encoder.layer.0.output.LayerNorm.weight', 'roberta.encoder.layer.0.output.LayerNorm.bias', 'roberta.encoder.layer.1.attention.self.query.weight', 'roberta.encoder.layer.1.attention.self.query.bias', 'roberta.encoder.layer.1.attention.self.key.weight', 'roberta.encoder.layer.1.attention.self.key.bias', 'roberta.encoder.layer.1.attention.self.value.weight', 'roberta.encoder.layer.1.attention.self.value.bias', 
'roberta.encoder.layer.1.attention.output.dense.weight', 'roberta.encoder.layer.1.attention.output.dense.bias', 'roberta.encoder.layer.1.attention.output.LayerNorm.weight', 'roberta.encoder.layer.1.attention.output.LayerNorm.bias', 'roberta.encoder.layer.1.intermediate.dense.weight', 'roberta.encoder.layer.1.intermediate.dense.bias', 'roberta.encoder.layer.1.output.dense.weight', 'roberta.encoder.layer.1.output.dense.bias', 'roberta.encoder.layer.1.output.LayerNorm.weight', 'roberta.encoder.layer.1.output.LayerNorm.bias', 'roberta.encoder.layer.2.attention.self.query.weight', 'roberta.encoder.layer.2.attention.self.query.bias', 'roberta.encoder.layer.2.attention.self.key.weight', 'roberta.encoder.layer.2.attention.self.key.bias', 'roberta.encoder.layer.2.attention.self.value.weight', 'roberta.encoder.layer.2.attention.self.value.bias', 'roberta.encoder.layer.2.attention.output.dense.weight', 'roberta.encoder.layer.2.attention.output.dense.bias', 'roberta.encoder.layer.2.attention.output.LayerNorm.weight', 'roberta.encoder.layer.2.attention.output.LayerNorm.bias', 'roberta.encoder.layer.2.intermediate.dense.weight', 'roberta.encoder.layer.2.intermediate.dense.bias', 'roberta.encoder.layer.2.output.dense.weight', 'roberta.encoder.layer.2.output.dense.bias', 'roberta.encoder.layer.2.output.LayerNorm.weight', 'roberta.encoder.layer.2.output.LayerNorm.bias', 'roberta.encoder.layer.3.attention.self.query.weight', 'roberta.encoder.layer.3.attention.self.query.bias', 'roberta.encoder.layer.3.attention.self.key.weight', 'roberta.encoder.layer.3.attention.self.key.bias', 'roberta.encoder.layer.3.attention.self.value.weight', 'roberta.encoder.layer.3.attention.self.value.bias', 'roberta.encoder.layer.3.attention.output.dense.weight', 'roberta.encoder.layer.3.attention.output.dense.bias', 'roberta.encoder.layer.3.attention.output.LayerNorm.weight', 'roberta.encoder.layer.3.attention.output.LayerNorm.bias', 'roberta.encoder.layer.3.intermediate.dense.weight', 'roberta.encoder.layer.3.intermediate.dense.bias', 'roberta.encoder.layer.3.output.dense.weight', 'roberta.encoder.layer.3.output.dense.bias', 'roberta.encoder.layer.3.output.LayerNorm.weight', 'roberta.encoder.layer.3.output.LayerNorm.bias', 'roberta.encoder.layer.4.attention.self.query.weight', 'roberta.encoder.layer.4.attention.self.query.bias', 'roberta.encoder.layer.4.attention.self.key.weight', 'roberta.encoder.layer.4.attention.self.key.bias', 'roberta.encoder.layer.4.attention.self.value.weight', 'roberta.encoder.layer.4.attention.self.value.bias', 'roberta.encoder.layer.4.attention.output.dense.weight', 'roberta.encoder.layer.4.attention.output.dense.bias', 'roberta.encoder.layer.4.attention.output.LayerNorm.weight', 'roberta.encoder.layer.4.attention.output.LayerNorm.bias', 'roberta.encoder.layer.4.intermediate.dense.weight', 'roberta.encoder.layer.4.intermediate.dense.bias', 'roberta.encoder.layer.4.output.dense.weight', 'roberta.encoder.layer.4.output.dense.bias', 'roberta.encoder.layer.4.output.LayerNorm.weight', 'roberta.encoder.layer.4.output.LayerNorm.bias', 'roberta.encoder.layer.5.attention.self.query.weight', 'roberta.encoder.layer.5.attention.self.query.bias', 'roberta.encoder.layer.5.attention.self.key.weight', 'roberta.encoder.layer.5.attention.self.key.bias', 'roberta.encoder.layer.5.attention.self.value.weight', 'roberta.encoder.layer.5.attention.self.value.bias', 'roberta.encoder.layer.5.attention.output.dense.weight', 'roberta.encoder.layer.5.attention.output.dense.bias', 
'roberta.encoder.layer.5.attention.output.LayerNorm.weight', 'roberta.encoder.layer.5.attention.output.LayerNorm.bias', 'roberta.encoder.layer.5.intermediate.dense.weight', 'roberta.encoder.layer.5.intermediate.dense.bias', 'roberta.encoder.layer.5.output.dense.weight', 'roberta.encoder.layer.5.output.dense.bias', 'roberta.encoder.layer.5.output.LayerNorm.weight', 'roberta.encoder.layer.5.output.LayerNorm.bias', 'roberta.encoder.layer.6.attention.self.query.weight', 'roberta.encoder.layer.6.attention.self.query.bias', 'roberta.encoder.layer.6.attention.self.key.weight', 'roberta.encoder.layer.6.attention.self.key.bias', 'roberta.encoder.layer.6.attention.self.value.weight', 'roberta.encoder.layer.6.attention.self.value.bias', 'roberta.encoder.layer.6.attention.output.dense.weight', 'roberta.encoder.layer.6.attention.output.dense.bias', 'roberta.encoder.layer.6.attention.output.LayerNorm.weight', 'roberta.encoder.layer.6.attention.output.LayerNorm.bias', 'roberta.encoder.layer.6.intermediate.dense.weight', 'roberta.encoder.layer.6.intermediate.dense.bias', 'roberta.encoder.layer.6.output.dense.weight', 'roberta.encoder.layer.6.output.dense.bias', 'roberta.encoder.layer.6.output.LayerNorm.weight', 'roberta.encoder.layer.6.output.LayerNorm.bias', 'roberta.encoder.layer.7.attention.self.query.weight', 'roberta.encoder.layer.7.attention.self.query.bias', 'roberta.encoder.layer.7.attention.self.key.weight', 'roberta.encoder.layer.7.attention.self.key.bias', 'roberta.encoder.layer.7.attention.self.value.weight', 'roberta.encoder.layer.7.attention.self.value.bias', 'roberta.encoder.layer.7.attention.output.dense.weight', 'roberta.encoder.layer.7.attention.output.dense.bias', 'roberta.encoder.layer.7.attention.output.LayerNorm.weight', 'roberta.encoder.layer.7.attention.output.LayerNorm.bias', 'roberta.encoder.layer.7.intermediate.dense.weight', 'roberta.encoder.layer.7.intermediate.dense.bias', 'roberta.encoder.layer.7.output.dense.weight', 'roberta.encoder.layer.7.output.dense.bias', 'roberta.encoder.layer.7.output.LayerNorm.weight', 'roberta.encoder.layer.7.output.LayerNorm.bias', 'roberta.encoder.layer.8.attention.self.query.weight', 'roberta.encoder.layer.8.attention.self.query.bias', 'roberta.encoder.layer.8.attention.self.key.weight', 'roberta.encoder.layer.8.attention.self.key.bias', 'roberta.encoder.layer.8.attention.self.value.weight', 'roberta.encoder.layer.8.attention.self.value.bias', 'roberta.encoder.layer.8.attention.output.dense.weight', 'roberta.encoder.layer.8.attention.output.dense.bias', 'roberta.encoder.layer.8.attention.output.LayerNorm.weight', 'roberta.encoder.layer.8.attention.output.LayerNorm.bias', 'roberta.encoder.layer.8.intermediate.dense.weight', 'roberta.encoder.layer.8.intermediate.dense.bias', 'roberta.encoder.layer.8.output.dense.weight', 'roberta.encoder.layer.8.output.dense.bias', 'roberta.encoder.layer.8.output.LayerNorm.weight', 'roberta.encoder.layer.8.output.LayerNorm.bias', 'roberta.encoder.layer.9.attention.self.query.weight', 'roberta.encoder.layer.9.attention.self.query.bias', 'roberta.encoder.layer.9.attention.self.key.weight', 'roberta.encoder.layer.9.attention.self.key.bias', 'roberta.encoder.layer.9.attention.self.value.weight', 'roberta.encoder.layer.9.attention.self.value.bias', 'roberta.encoder.layer.9.attention.output.dense.weight', 'roberta.encoder.layer.9.attention.output.dense.bias', 'roberta.encoder.layer.9.attention.output.LayerNorm.weight', 'roberta.encoder.layer.9.attention.output.LayerNorm.bias', 
'roberta.encoder.layer.9.intermediate.dense.weight', 'roberta.encoder.layer.9.intermediate.dense.bias', 'roberta.encoder.layer.9.output.dense.weight', 'roberta.encoder.layer.9.output.dense.bias', 'roberta.encoder.layer.9.output.LayerNorm.weight', 'roberta.encoder.layer.9.output.LayerNorm.bias', 'roberta.encoder.layer.10.attention.self.query.weight', 'roberta.encoder.layer.10.attention.self.query.bias', 'roberta.encoder.layer.10.attention.self.key.weight', 'roberta.encoder.layer.10.attention.self.key.bias', 'roberta.encoder.layer.10.attention.self.value.weight', 'roberta.encoder.layer.10.attention.self.value.bias', 'roberta.encoder.layer.10.attention.output.dense.weight', 'roberta.encoder.layer.10.attention.output.dense.bias', 'roberta.encoder.layer.10.attention.output.LayerNorm.weight', 'roberta.encoder.layer.10.attention.output.LayerNorm.bias', 'roberta.encoder.layer.10.intermediate.dense.weight', 'roberta.encoder.layer.10.intermediate.dense.bias', 'roberta.encoder.layer.10.output.dense.weight', 'roberta.encoder.layer.10.output.dense.bias', 'roberta.encoder.layer.10.output.LayerNorm.weight', 'roberta.encoder.layer.10.output.LayerNorm.bias', 'roberta.encoder.layer.11.attention.self.query.weight', 'roberta.encoder.layer.11.attention.self.query.bias', 'roberta.encoder.layer.11.attention.self.key.weight', 'roberta.encoder.layer.11.attention.self.key.bias', 'roberta.encoder.layer.11.attention.self.value.weight', 'roberta.encoder.layer.11.attention.self.value.bias', 'roberta.encoder.layer.11.attention.output.dense.weight', 'roberta.encoder.layer.11.attention.output.dense.bias', 'roberta.encoder.layer.11.attention.output.LayerNorm.weight', 'roberta.encoder.layer.11.attention.output.LayerNorm.bias', 'roberta.encoder.layer.11.intermediate.dense.weight', 'roberta.encoder.layer.11.intermediate.dense.bias', 'roberta.encoder.layer.11.output.dense.weight', 'roberta.encoder.layer.11.output.dense.bias', 'roberta.encoder.layer.11.output.LayerNorm.weight', 'roberta.encoder.layer.11.output.LayerNorm.bias', 'roberta.embeddings.word_embeddings.weight', 'roberta.embeddings.position_embeddings.weight', 'roberta.embeddings.token_type_embeddings.weight', 'roberta.embeddings.LayerNorm.weight', 'roberta.embeddings.LayerNorm.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.decoder.bias', 'lm_head.bias'])\n", - "dict_keys(['encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.1.attention.self.query.bias', 'encoder.layer.1.attention.self.key.weight', 'encoder.layer.1.attention.self.key.bias', 'encoder.layer.1.attention.self.value.weight', 'encoder.layer.1.attention.self.value.bias', 'encoder.layer.1.attention.output.dense.weight', 'encoder.layer.1.attention.output.dense.bias', 
'encoder.layer.1.attention.output.LayerNorm.weight', 'encoder.layer.1.attention.output.LayerNorm.bias', 'encoder.layer.1.intermediate.dense.weight', 'encoder.layer.1.intermediate.dense.bias', 'encoder.layer.1.output.dense.weight', 'encoder.layer.1.output.dense.bias', 'encoder.layer.1.output.LayerNorm.weight', 'encoder.layer.1.output.LayerNorm.bias', 'encoder.layer.2.attention.self.query.weight', 'encoder.layer.2.attention.self.query.bias', 'encoder.layer.2.attention.self.key.weight', 'encoder.layer.2.attention.self.key.bias', 'encoder.layer.2.attention.self.value.weight', 'encoder.layer.2.attention.self.value.bias', 'encoder.layer.2.attention.output.dense.weight', 'encoder.layer.2.attention.output.dense.bias', 'encoder.layer.2.attention.output.LayerNorm.weight', 'encoder.layer.2.attention.output.LayerNorm.bias', 'encoder.layer.2.intermediate.dense.weight', 'encoder.layer.2.intermediate.dense.bias', 'encoder.layer.2.output.dense.weight', 'encoder.layer.2.output.dense.bias', 'encoder.layer.2.output.LayerNorm.weight', 'encoder.layer.2.output.LayerNorm.bias', 'encoder.layer.3.attention.self.query.weight', 'encoder.layer.3.attention.self.query.bias', 'encoder.layer.3.attention.self.key.weight', 'encoder.layer.3.attention.self.key.bias', 'encoder.layer.3.attention.self.value.weight', 'encoder.layer.3.attention.self.value.bias', 'encoder.layer.3.attention.output.dense.weight', 'encoder.layer.3.attention.output.dense.bias', 'encoder.layer.3.attention.output.LayerNorm.weight', 'encoder.layer.3.attention.output.LayerNorm.bias', 'encoder.layer.3.intermediate.dense.weight', 'encoder.layer.3.intermediate.dense.bias', 'encoder.layer.3.output.dense.weight', 'encoder.layer.3.output.dense.bias', 'encoder.layer.3.output.LayerNorm.weight', 'encoder.layer.3.output.LayerNorm.bias', 'encoder.layer.4.attention.self.query.weight', 'encoder.layer.4.attention.self.query.bias', 'encoder.layer.4.attention.self.key.weight', 'encoder.layer.4.attention.self.key.bias', 'encoder.layer.4.attention.self.value.weight', 'encoder.layer.4.attention.self.value.bias', 'encoder.layer.4.attention.output.dense.weight', 'encoder.layer.4.attention.output.dense.bias', 'encoder.layer.4.attention.output.LayerNorm.weight', 'encoder.layer.4.attention.output.LayerNorm.bias', 'encoder.layer.4.intermediate.dense.weight', 'encoder.layer.4.intermediate.dense.bias', 'encoder.layer.4.output.dense.weight', 'encoder.layer.4.output.dense.bias', 'encoder.layer.4.output.LayerNorm.weight', 'encoder.layer.4.output.LayerNorm.bias', 'encoder.layer.5.attention.self.query.weight', 'encoder.layer.5.attention.self.query.bias', 'encoder.layer.5.attention.self.key.weight', 'encoder.layer.5.attention.self.key.bias', 'encoder.layer.5.attention.self.value.weight', 'encoder.layer.5.attention.self.value.bias', 'encoder.layer.5.attention.output.dense.weight', 'encoder.layer.5.attention.output.dense.bias', 'encoder.layer.5.attention.output.LayerNorm.weight', 'encoder.layer.5.attention.output.LayerNorm.bias', 'encoder.layer.5.intermediate.dense.weight', 'encoder.layer.5.intermediate.dense.bias', 'encoder.layer.5.output.dense.weight', 'encoder.layer.5.output.dense.bias', 'encoder.layer.5.output.LayerNorm.weight', 'encoder.layer.5.output.LayerNorm.bias', 'encoder.layer.6.attention.self.query.weight', 'encoder.layer.6.attention.self.query.bias', 'encoder.layer.6.attention.self.key.weight', 'encoder.layer.6.attention.self.key.bias', 'encoder.layer.6.attention.self.value.weight', 'encoder.layer.6.attention.self.value.bias', 'encoder.layer.6.attention.output.dense.weight', 
'encoder.layer.6.attention.output.dense.bias', 'encoder.layer.6.attention.output.LayerNorm.weight', 'encoder.layer.6.attention.output.LayerNorm.bias', 'encoder.layer.6.intermediate.dense.weight', 'encoder.layer.6.intermediate.dense.bias', 'encoder.layer.6.output.dense.weight', 'encoder.layer.6.output.dense.bias', 'encoder.layer.6.output.LayerNorm.weight', 'encoder.layer.6.output.LayerNorm.bias', 'encoder.layer.7.attention.self.query.weight', 'encoder.layer.7.attention.self.query.bias', 'encoder.layer.7.attention.self.key.weight', 'encoder.layer.7.attention.self.key.bias', 'encoder.layer.7.attention.self.value.weight', 'encoder.layer.7.attention.self.value.bias', 'encoder.layer.7.attention.output.dense.weight', 'encoder.layer.7.attention.output.dense.bias', 'encoder.layer.7.attention.output.LayerNorm.weight', 'encoder.layer.7.attention.output.LayerNorm.bias', 'encoder.layer.7.intermediate.dense.weight', 'encoder.layer.7.intermediate.dense.bias', 'encoder.layer.7.output.dense.weight', 'encoder.layer.7.output.dense.bias', 'encoder.layer.7.output.LayerNorm.weight', 'encoder.layer.7.output.LayerNorm.bias', 'encoder.layer.8.attention.self.query.weight', 'encoder.layer.8.attention.self.query.bias', 'encoder.layer.8.attention.self.key.weight', 'encoder.layer.8.attention.self.key.bias', 'encoder.layer.8.attention.self.value.weight', 'encoder.layer.8.attention.self.value.bias', 'encoder.layer.8.attention.output.dense.weight', 'encoder.layer.8.attention.output.dense.bias', 'encoder.layer.8.attention.output.LayerNorm.weight', 'encoder.layer.8.attention.output.LayerNorm.bias', 'encoder.layer.8.intermediate.dense.weight', 'encoder.layer.8.intermediate.dense.bias', 'encoder.layer.8.output.dense.weight', 'encoder.layer.8.output.dense.bias', 'encoder.layer.8.output.LayerNorm.weight', 'encoder.layer.8.output.LayerNorm.bias', 'encoder.layer.9.attention.self.query.weight', 'encoder.layer.9.attention.self.query.bias', 'encoder.layer.9.attention.self.key.weight', 'encoder.layer.9.attention.self.key.bias', 'encoder.layer.9.attention.self.value.weight', 'encoder.layer.9.attention.self.value.bias', 'encoder.layer.9.attention.output.dense.weight', 'encoder.layer.9.attention.output.dense.bias', 'encoder.layer.9.attention.output.LayerNorm.weight', 'encoder.layer.9.attention.output.LayerNorm.bias', 'encoder.layer.9.intermediate.dense.weight', 'encoder.layer.9.intermediate.dense.bias', 'encoder.layer.9.output.dense.weight', 'encoder.layer.9.output.dense.bias', 'encoder.layer.9.output.LayerNorm.weight', 'encoder.layer.9.output.LayerNorm.bias', 'encoder.layer.10.attention.self.query.weight', 'encoder.layer.10.attention.self.query.bias', 'encoder.layer.10.attention.self.key.weight', 'encoder.layer.10.attention.self.key.bias', 'encoder.layer.10.attention.self.value.weight', 'encoder.layer.10.attention.self.value.bias', 'encoder.layer.10.attention.output.dense.weight', 'encoder.layer.10.attention.output.dense.bias', 'encoder.layer.10.attention.output.LayerNorm.weight', 'encoder.layer.10.attention.output.LayerNorm.bias', 'encoder.layer.10.intermediate.dense.weight', 'encoder.layer.10.intermediate.dense.bias', 'encoder.layer.10.output.dense.weight', 'encoder.layer.10.output.dense.bias', 'encoder.layer.10.output.LayerNorm.weight', 'encoder.layer.10.output.LayerNorm.bias', 'encoder.layer.11.attention.self.query.weight', 'encoder.layer.11.attention.self.query.bias', 'encoder.layer.11.attention.self.key.weight', 'encoder.layer.11.attention.self.key.bias', 'encoder.layer.11.attention.self.value.weight', 
'encoder.layer.11.attention.self.value.bias', 'encoder.layer.11.attention.output.dense.weight', 'encoder.layer.11.attention.output.dense.bias', 'encoder.layer.11.attention.output.LayerNorm.weight', 'encoder.layer.11.attention.output.LayerNorm.bias', 'encoder.layer.11.intermediate.dense.weight', 'encoder.layer.11.intermediate.dense.bias', 'encoder.layer.11.output.dense.weight', 'encoder.layer.11.output.dense.bias', 'encoder.layer.11.output.LayerNorm.weight', 'encoder.layer.11.output.LayerNorm.bias', 'embeddings.word_embeddings.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.LayerNorm.weight', 'embeddings.LayerNorm.bias'])\n", - "dict_keys(['dense.weight', 'dense.bias', 'layer_norm.weight', 'layer_norm.bias', 'decoder.weight', 'decoder.bias', 'bias'])\n" + "dict_keys(['dense.weight', 'dense.bias', 'layer_norm.weight', 'layer_norm.bias', 'decoder.weight', 'decoder.bias', 'bias'])\n", + "odict_keys(['bias', 'dense.weight', 'dense.bias', 'layer_norm.weight', 'layer_norm.bias', 'decoder.weight', 'decoder.bias'])\n", + "success1\n", + "success2\n", + "0.0\n" ] - }, - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ - "# Testing RobertaForMaskedLM\n", - "\n", - "my_mlm = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=key)\n", - "state_mlm = my_mlm.to_state_dict()\n", - "\n", - "state_mlm = {k: torch.from_numpy(np.array(v)) for k, v in state_mlm.items()}\n", - "\n", - "# if \"lm_head.decoder.bias\" in state:\n", - "# print(state[\"lm_head.decoder.bias\"])\n", - "# else:\n", - "# print(f\"RobertaForMaskedLM, {state.keys()}\")\n", - "\n", - "state_mlm[\"lm_head.bias\"] = torch.zeros(hf_config.vocab_size)\n", - "\n", - "print(state_mlm.keys())\n", - "\n", - "hf_mlm = hf_roberta.RobertaForMaskedLM(hf_config)\n", - "hf_mlm.load_state_dict(state_mlm, strict=True)\n", - "\n", - "# Testing RobertaModel\n", - "\n", - "key_rob, key_head = jrandom.split(key, 2)\n", + "# Testing RobertaLMHead\n", + "my_head = my_roberta.RobertaLMHead.init(Vocab, my_config, key=key_lm)\n", + "state_head = my_head.to_state_dict()\n", "\n", - "my_model = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=False, key=key_rob)\n", - "state_model = my_model.to_state_dict()\n", + "state_head = {k: torch.from_numpy(np.array(v)) for k, v in state_head.items()}\n", + "state_head[\"bias\"] = torch.zeros(hf_config.vocab_size)\n", "\n", - "state_model = {k: torch.from_numpy(np.array(v)) for k, v in state_model.items()}\n", + "hf_head = hf_roberta.RobertaLMHead(hf_config)\n", + "hf_head.load_state_dict(state_head, strict=True)\n", "\n", - "print(state_model.keys())\n", + "print(check_dicts(state_head, hf_head.state_dict()))" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "my_output_mlm, hf_output_mlm = test_RobertaForMaskedLM_Output(key_run, ids = False)\n", "\n", - "hf_model = hf_roberta.RobertaModel(hf_config, add_pooling_layer=False)\n", - "hf_model.load_state_dict(state_model, strict=True)\n", + "my_output_model, hf_output_model = test_RobertaModel_Output(key_model_run, ids = False)\n", + "my_output = my_head(my_output_model[0], key=key_lm_run)\n", + "hf_output = hf_head(hf_output_model[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "# k_rob, k_lm = jrandom.split(key, 2)\n", "\n", - "# Testing RobertaLMHead\n", + "# # MLM\n", + "# 
my_output_mlm = my_mlm(input_embeds = input_embeds, attention_mask=mask, key=key)\n", + "# hf_output_mlm = hf_mlm(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n", "\n", - "my_head = my_roberta.RobertaLMHead.init(Vocab, my_config, key=key_head)\n", - "state_head = my_head.to_state_dict()\n", + "# # Model + LM\n", "\n", - "state_head = {k: torch.from_numpy(np.array(v)) for k, v in state_head.items()}\n", + "# my_output_model = my_model(input_embeds = input_embeds, attention_mask=mask, key=k_rob)\n", + "# my_output = my_head(my_output_model[0], key=k_lm)\n", "\n", - "state_head[\"bias\"] = torch.zeros(hf_config.vocab_size)\n", + "# hf_output_model = hf_model(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n", + "# hf_output = hf_head(hf_output_model[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RobertaModel: acc: 1.0 \t norms: (tensor(888.5307), tensor(888.5306)) \t diffs: 6.107513286224275e-07\n", + "Roberta Model + LM head: acc: 1.0 \t norms: (tensor(7107.9902), tensor(7107.9902)) \t diffs: 7.997662692105223e-07\n", + "MLM: acc: 1.0 \t norms: (tensor(7107.9902), tensor(7107.9902)) \t diffs: 7.997662692105223e-07\n", + "my RobertaModel + LM head vs MLM: acc: 1.0 \t norms: (tensor(7107.9902), tensor(7107.9902)) \t diffs: 0.0\n", + "hf RobertaModel + LM head vs MLM: acc: 1.0 \t norms: (tensor(7107.9902), tensor(7107.9902)) \t diffs: 0.0\n" + ] + } + ], + "source": [ + "# embeds\n", + "print(f\"RobertaModel: {check(my_output_model[0].array, hf_output_model[0].detach())[3]}\")\n", + "print(f\"Roberta Model + LM head: {check(my_output.array, hf_output.detach())[3]}\")\n", + "print(f\"MLM: {check(my_output_mlm[0].array, hf_output_mlm[0].detach())[3]}\")\n", "\n", - "print(state_head.keys())\n", + "print(f\"my RobertaModel + LM head vs MLM: {check(my_output.array, my_output_mlm[0].array)[3]}\")\n", + "print(f\"hf RobertaModel + LM head vs MLM: {check(hf_output.detach(), hf_output_mlm[0].detach())[3]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "my_output_mlm, hf_output_mlm = test_RobertaForMaskedLM_Output(key_run, ids = True)\n", "\n", - "hf_head = hf_roberta.RobertaLMHead(hf_config)\n", - "hf_head.load_state_dict(state_head, strict=True)" + "my_output_model, hf_output_model = test_RobertaModel_Output(key_model_run, ids = True)\n", + "my_output = my_head(my_output_model[0], key=key_lm_run)\n", + "hf_output = hf_head(hf_output_model[0])" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 31, "metadata": {}, "outputs": [], "source": [ - "k_rob, k_lm = jrandom.split(key, 2)\n", + "# # MLM\n", + "# my_output_mlm = my_mlm(input_ids = input_ids, attention_mask=mask, key=key)\n", + "# hf_output_mlm = hf_mlm(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", "\n", - "# MLM\n", - "my_output_mlm = my_mlm(input_embeds = input_embeds, attention_mask=mask, key=key)\n", - "hf_output_mlm = hf_mlm(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n", + "# # Model + LM\n", + "# my_output_model = my_model(input_ids = input_ids, attention_mask=mask, key=k_rob)\n", + "# my_output = my_head(my_output_model[0], key=k_lm)\n", "\n", - "# Model + LM\n", - "my_output_model = my_model(input_embeds = input_embeds, attention_mask=mask, key=k_rob)\n", - "my_output = 
my_head(my_output_model[0], key=k_lm)\n", + "# hf_output_model = hf_model(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", + "# hf_output = hf_head(hf_output_model[0])\n", "\n", - "hf_output_model = hf_model(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n", - "hf_output = hf_head(hf_output_model[0])" + "# # Notes\n", + "# # embeds works between hf and my, and within model->head and mlm \n", + "# # ids does not work between hf and my for mlm or within hf for model->head and mlm - so hf mlm is doing something weird.'''" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "RobertaModel: acc: 1.0 \t norms: (tensor(888.5305), tensor(888.5305)) \t diffs: 3.802756225468329e-07\n", - "Roberta Model + LM head: acc: 1.0 \t norms: (tensor(7087.6094), tensor(7087.6094)) \t diffs: 5.642933729177457e-07\n", - "MLM: acc: 1.0 \t norms: (tensor(7087.6094), tensor(7087.6094)) \t diffs: 5.642933729177457e-07\n", - "my RobertaModel + LM head vs MLM: acc: 1.0 \t norms: (tensor(7087.6094), tensor(7087.6094)) \t diffs: 0.0\n", - "hf RobertaModel + LM head vs MLM: acc: 1.0 \t norms: (tensor(7087.6094), tensor(7087.6094)) \t diffs: 0.0\n" + "RobertaModel: acc: 1.0 \t norms: (tensor(888.5331), tensor(888.5331)) \t diffs: 6.946668236196274e-07\n", + "Roberta Model + LM head: acc: 1.0 \t norms: (tensor(7054.6812), tensor(7054.6816)) \t diffs: 7.966510224832746e-07\n", + "MLM: acc: 1.0 \t norms: (tensor(7054.6812), tensor(7054.6816)) \t diffs: 7.966510224832746e-07\n", + "my RobertaModel + LM head vs MLM: acc: 1.0 \t norms: (tensor(7054.6812), tensor(7054.6812)) \t diffs: 0.0\n", + "hf RobertaModel + LM head vs MLM: acc: 1.0 \t norms: (tensor(7054.6816), tensor(7054.6816)) \t diffs: 0.0\n" ] } ], "source": [ + "# ids\n", "print(f\"RobertaModel: {check(my_output_model[0].array, hf_output_model[0].detach())[3]}\")\n", "print(f\"Roberta Model + LM head: {check(my_output.array, hf_output.detach())[3]}\")\n", "print(f\"MLM: {check(my_output_mlm[0].array, hf_output_mlm[0].detach())[3]}\")\n", @@ -801,55 +1006,391 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'stop' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[33], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[43mstop\u001b[49m\n", + "\u001b[1;31mNameError\u001b[0m: name 'stop' is not defined" + ] + } + ], + "source": [ + "stop" + ] + }, + { + "cell_type": "code", + "execution_count": 38, "metadata": {}, "outputs": [], "source": [ - "# MLM\n", + "# Load pretrained weights from hf\n", + "hf_model = hf_roberta.RobertaModel.from_pretrained(\"roberta-base\")\n", + "state_model = hf_model.state_dict()\n", "\n", - "my_output_mlm = my_mlm(input_ids = input_ids, attention_mask=mask, key=key)\n", - "hf_output_mlm = hf_mlm(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", + "state_model = {k: np.array(v) for k, v in state_model.items()}\n", "\n", - "# Model + LM\n", + "hf_config = hf_model.config\n", + "hf_config.hidden_dropout_prob = 0\n", + "hf_config.attention_probs_dropout_prob = 0\n", + "my_config = my_roberta.RobertaConfig.from_hf_config(hf_config)\n", 
"\n", - "my_output_model = my_model(input_ids = input_ids, attention_mask=mask, key=k_rob)\n", - "my_output = my_head(my_output_model[0], key=k_lm)\n", + "my_model = my_roberta.RobertaModel.init(Vocab, my_config, output_hidden_states=True, key=key)\n", + "my_model = my_model.from_state_dict(state_model)" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dict_keys(['encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.1.attention.self.query.bias', 'encoder.layer.1.attention.self.key.weight', 'encoder.layer.1.attention.self.key.bias', 'encoder.layer.1.attention.self.value.weight', 'encoder.layer.1.attention.self.value.bias', 'encoder.layer.1.attention.output.dense.weight', 'encoder.layer.1.attention.output.dense.bias', 'encoder.layer.1.attention.output.LayerNorm.weight', 'encoder.layer.1.attention.output.LayerNorm.bias', 'encoder.layer.1.intermediate.dense.weight', 'encoder.layer.1.intermediate.dense.bias', 'encoder.layer.1.output.dense.weight', 'encoder.layer.1.output.dense.bias', 'encoder.layer.1.output.LayerNorm.weight', 'encoder.layer.1.output.LayerNorm.bias', 'encoder.layer.2.attention.self.query.weight', 'encoder.layer.2.attention.self.query.bias', 'encoder.layer.2.attention.self.key.weight', 'encoder.layer.2.attention.self.key.bias', 'encoder.layer.2.attention.self.value.weight', 'encoder.layer.2.attention.self.value.bias', 'encoder.layer.2.attention.output.dense.weight', 'encoder.layer.2.attention.output.dense.bias', 'encoder.layer.2.attention.output.LayerNorm.weight', 'encoder.layer.2.attention.output.LayerNorm.bias', 'encoder.layer.2.intermediate.dense.weight', 'encoder.layer.2.intermediate.dense.bias', 'encoder.layer.2.output.dense.weight', 'encoder.layer.2.output.dense.bias', 'encoder.layer.2.output.LayerNorm.weight', 'encoder.layer.2.output.LayerNorm.bias', 'encoder.layer.3.attention.self.query.weight', 'encoder.layer.3.attention.self.query.bias', 'encoder.layer.3.attention.self.key.weight', 'encoder.layer.3.attention.self.key.bias', 'encoder.layer.3.attention.self.value.weight', 'encoder.layer.3.attention.self.value.bias', 'encoder.layer.3.attention.output.dense.weight', 'encoder.layer.3.attention.output.dense.bias', 'encoder.layer.3.attention.output.LayerNorm.weight', 'encoder.layer.3.attention.output.LayerNorm.bias', 'encoder.layer.3.intermediate.dense.weight', 'encoder.layer.3.intermediate.dense.bias', 'encoder.layer.3.output.dense.weight', 'encoder.layer.3.output.dense.bias', 'encoder.layer.3.output.LayerNorm.weight', 'encoder.layer.3.output.LayerNorm.bias', 'encoder.layer.4.attention.self.query.weight', 'encoder.layer.4.attention.self.query.bias', 'encoder.layer.4.attention.self.key.weight', 'encoder.layer.4.attention.self.key.bias', 
'encoder.layer.4.attention.self.value.weight', 'encoder.layer.4.attention.self.value.bias', 'encoder.layer.4.attention.output.dense.weight', 'encoder.layer.4.attention.output.dense.bias', 'encoder.layer.4.attention.output.LayerNorm.weight', 'encoder.layer.4.attention.output.LayerNorm.bias', 'encoder.layer.4.intermediate.dense.weight', 'encoder.layer.4.intermediate.dense.bias', 'encoder.layer.4.output.dense.weight', 'encoder.layer.4.output.dense.bias', 'encoder.layer.4.output.LayerNorm.weight', 'encoder.layer.4.output.LayerNorm.bias', 'encoder.layer.5.attention.self.query.weight', 'encoder.layer.5.attention.self.query.bias', 'encoder.layer.5.attention.self.key.weight', 'encoder.layer.5.attention.self.key.bias', 'encoder.layer.5.attention.self.value.weight', 'encoder.layer.5.attention.self.value.bias', 'encoder.layer.5.attention.output.dense.weight', 'encoder.layer.5.attention.output.dense.bias', 'encoder.layer.5.attention.output.LayerNorm.weight', 'encoder.layer.5.attention.output.LayerNorm.bias', 'encoder.layer.5.intermediate.dense.weight', 'encoder.layer.5.intermediate.dense.bias', 'encoder.layer.5.output.dense.weight', 'encoder.layer.5.output.dense.bias', 'encoder.layer.5.output.LayerNorm.weight', 'encoder.layer.5.output.LayerNorm.bias', 'encoder.layer.6.attention.self.query.weight', 'encoder.layer.6.attention.self.query.bias', 'encoder.layer.6.attention.self.key.weight', 'encoder.layer.6.attention.self.key.bias', 'encoder.layer.6.attention.self.value.weight', 'encoder.layer.6.attention.self.value.bias', 'encoder.layer.6.attention.output.dense.weight', 'encoder.layer.6.attention.output.dense.bias', 'encoder.layer.6.attention.output.LayerNorm.weight', 'encoder.layer.6.attention.output.LayerNorm.bias', 'encoder.layer.6.intermediate.dense.weight', 'encoder.layer.6.intermediate.dense.bias', 'encoder.layer.6.output.dense.weight', 'encoder.layer.6.output.dense.bias', 'encoder.layer.6.output.LayerNorm.weight', 'encoder.layer.6.output.LayerNorm.bias', 'encoder.layer.7.attention.self.query.weight', 'encoder.layer.7.attention.self.query.bias', 'encoder.layer.7.attention.self.key.weight', 'encoder.layer.7.attention.self.key.bias', 'encoder.layer.7.attention.self.value.weight', 'encoder.layer.7.attention.self.value.bias', 'encoder.layer.7.attention.output.dense.weight', 'encoder.layer.7.attention.output.dense.bias', 'encoder.layer.7.attention.output.LayerNorm.weight', 'encoder.layer.7.attention.output.LayerNorm.bias', 'encoder.layer.7.intermediate.dense.weight', 'encoder.layer.7.intermediate.dense.bias', 'encoder.layer.7.output.dense.weight', 'encoder.layer.7.output.dense.bias', 'encoder.layer.7.output.LayerNorm.weight', 'encoder.layer.7.output.LayerNorm.bias', 'encoder.layer.8.attention.self.query.weight', 'encoder.layer.8.attention.self.query.bias', 'encoder.layer.8.attention.self.key.weight', 'encoder.layer.8.attention.self.key.bias', 'encoder.layer.8.attention.self.value.weight', 'encoder.layer.8.attention.self.value.bias', 'encoder.layer.8.attention.output.dense.weight', 'encoder.layer.8.attention.output.dense.bias', 'encoder.layer.8.attention.output.LayerNorm.weight', 'encoder.layer.8.attention.output.LayerNorm.bias', 'encoder.layer.8.intermediate.dense.weight', 'encoder.layer.8.intermediate.dense.bias', 'encoder.layer.8.output.dense.weight', 'encoder.layer.8.output.dense.bias', 'encoder.layer.8.output.LayerNorm.weight', 'encoder.layer.8.output.LayerNorm.bias', 'encoder.layer.9.attention.self.query.weight', 'encoder.layer.9.attention.self.query.bias', 
'encoder.layer.9.attention.self.key.weight', 'encoder.layer.9.attention.self.key.bias', 'encoder.layer.9.attention.self.value.weight', 'encoder.layer.9.attention.self.value.bias', 'encoder.layer.9.attention.output.dense.weight', 'encoder.layer.9.attention.output.dense.bias', 'encoder.layer.9.attention.output.LayerNorm.weight', 'encoder.layer.9.attention.output.LayerNorm.bias', 'encoder.layer.9.intermediate.dense.weight', 'encoder.layer.9.intermediate.dense.bias', 'encoder.layer.9.output.dense.weight', 'encoder.layer.9.output.dense.bias', 'encoder.layer.9.output.LayerNorm.weight', 'encoder.layer.9.output.LayerNorm.bias', 'encoder.layer.10.attention.self.query.weight', 'encoder.layer.10.attention.self.query.bias', 'encoder.layer.10.attention.self.key.weight', 'encoder.layer.10.attention.self.key.bias', 'encoder.layer.10.attention.self.value.weight', 'encoder.layer.10.attention.self.value.bias', 'encoder.layer.10.attention.output.dense.weight', 'encoder.layer.10.attention.output.dense.bias', 'encoder.layer.10.attention.output.LayerNorm.weight', 'encoder.layer.10.attention.output.LayerNorm.bias', 'encoder.layer.10.intermediate.dense.weight', 'encoder.layer.10.intermediate.dense.bias', 'encoder.layer.10.output.dense.weight', 'encoder.layer.10.output.dense.bias', 'encoder.layer.10.output.LayerNorm.weight', 'encoder.layer.10.output.LayerNorm.bias', 'encoder.layer.11.attention.self.query.weight', 'encoder.layer.11.attention.self.query.bias', 'encoder.layer.11.attention.self.key.weight', 'encoder.layer.11.attention.self.key.bias', 'encoder.layer.11.attention.self.value.weight', 'encoder.layer.11.attention.self.value.bias', 'encoder.layer.11.attention.output.dense.weight', 'encoder.layer.11.attention.output.dense.bias', 'encoder.layer.11.attention.output.LayerNorm.weight', 'encoder.layer.11.attention.output.LayerNorm.bias', 'encoder.layer.11.intermediate.dense.weight', 'encoder.layer.11.intermediate.dense.bias', 'encoder.layer.11.output.dense.weight', 'encoder.layer.11.output.dense.bias', 'encoder.layer.11.output.LayerNorm.weight', 'encoder.layer.11.output.LayerNorm.bias', 'embeddings.word_embeddings.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.LayerNorm.weight', 'embeddings.LayerNorm.bias', 'pooler.dense.weight', 'pooler.dense.bias'])\n", + "odict_keys(['embeddings.word_embeddings.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.LayerNorm.weight', 'embeddings.LayerNorm.bias', 'encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.1.attention.self.query.bias', 'encoder.layer.1.attention.self.key.weight', 'encoder.layer.1.attention.self.key.bias', 'encoder.layer.1.attention.self.value.weight', 'encoder.layer.1.attention.self.value.bias', 
'encoder.layer.1.attention.output.dense.weight', 'encoder.layer.1.attention.output.dense.bias', 'encoder.layer.1.attention.output.LayerNorm.weight', 'encoder.layer.1.attention.output.LayerNorm.bias', 'encoder.layer.1.intermediate.dense.weight', 'encoder.layer.1.intermediate.dense.bias', 'encoder.layer.1.output.dense.weight', 'encoder.layer.1.output.dense.bias', 'encoder.layer.1.output.LayerNorm.weight', 'encoder.layer.1.output.LayerNorm.bias', 'encoder.layer.2.attention.self.query.weight', 'encoder.layer.2.attention.self.query.bias', 'encoder.layer.2.attention.self.key.weight', 'encoder.layer.2.attention.self.key.bias', 'encoder.layer.2.attention.self.value.weight', 'encoder.layer.2.attention.self.value.bias', 'encoder.layer.2.attention.output.dense.weight', 'encoder.layer.2.attention.output.dense.bias', 'encoder.layer.2.attention.output.LayerNorm.weight', 'encoder.layer.2.attention.output.LayerNorm.bias', 'encoder.layer.2.intermediate.dense.weight', 'encoder.layer.2.intermediate.dense.bias', 'encoder.layer.2.output.dense.weight', 'encoder.layer.2.output.dense.bias', 'encoder.layer.2.output.LayerNorm.weight', 'encoder.layer.2.output.LayerNorm.bias', 'encoder.layer.3.attention.self.query.weight', 'encoder.layer.3.attention.self.query.bias', 'encoder.layer.3.attention.self.key.weight', 'encoder.layer.3.attention.self.key.bias', 'encoder.layer.3.attention.self.value.weight', 'encoder.layer.3.attention.self.value.bias', 'encoder.layer.3.attention.output.dense.weight', 'encoder.layer.3.attention.output.dense.bias', 'encoder.layer.3.attention.output.LayerNorm.weight', 'encoder.layer.3.attention.output.LayerNorm.bias', 'encoder.layer.3.intermediate.dense.weight', 'encoder.layer.3.intermediate.dense.bias', 'encoder.layer.3.output.dense.weight', 'encoder.layer.3.output.dense.bias', 'encoder.layer.3.output.LayerNorm.weight', 'encoder.layer.3.output.LayerNorm.bias', 'encoder.layer.4.attention.self.query.weight', 'encoder.layer.4.attention.self.query.bias', 'encoder.layer.4.attention.self.key.weight', 'encoder.layer.4.attention.self.key.bias', 'encoder.layer.4.attention.self.value.weight', 'encoder.layer.4.attention.self.value.bias', 'encoder.layer.4.attention.output.dense.weight', 'encoder.layer.4.attention.output.dense.bias', 'encoder.layer.4.attention.output.LayerNorm.weight', 'encoder.layer.4.attention.output.LayerNorm.bias', 'encoder.layer.4.intermediate.dense.weight', 'encoder.layer.4.intermediate.dense.bias', 'encoder.layer.4.output.dense.weight', 'encoder.layer.4.output.dense.bias', 'encoder.layer.4.output.LayerNorm.weight', 'encoder.layer.4.output.LayerNorm.bias', 'encoder.layer.5.attention.self.query.weight', 'encoder.layer.5.attention.self.query.bias', 'encoder.layer.5.attention.self.key.weight', 'encoder.layer.5.attention.self.key.bias', 'encoder.layer.5.attention.self.value.weight', 'encoder.layer.5.attention.self.value.bias', 'encoder.layer.5.attention.output.dense.weight', 'encoder.layer.5.attention.output.dense.bias', 'encoder.layer.5.attention.output.LayerNorm.weight', 'encoder.layer.5.attention.output.LayerNorm.bias', 'encoder.layer.5.intermediate.dense.weight', 'encoder.layer.5.intermediate.dense.bias', 'encoder.layer.5.output.dense.weight', 'encoder.layer.5.output.dense.bias', 'encoder.layer.5.output.LayerNorm.weight', 'encoder.layer.5.output.LayerNorm.bias', 'encoder.layer.6.attention.self.query.weight', 'encoder.layer.6.attention.self.query.bias', 'encoder.layer.6.attention.self.key.weight', 'encoder.layer.6.attention.self.key.bias', 'encoder.layer.6.attention.self.value.weight', 
'encoder.layer.6.attention.self.value.bias', 'encoder.layer.6.attention.output.dense.weight', 'encoder.layer.6.attention.output.dense.bias', 'encoder.layer.6.attention.output.LayerNorm.weight', 'encoder.layer.6.attention.output.LayerNorm.bias', 'encoder.layer.6.intermediate.dense.weight', 'encoder.layer.6.intermediate.dense.bias', 'encoder.layer.6.output.dense.weight', 'encoder.layer.6.output.dense.bias', 'encoder.layer.6.output.LayerNorm.weight', 'encoder.layer.6.output.LayerNorm.bias', 'encoder.layer.7.attention.self.query.weight', 'encoder.layer.7.attention.self.query.bias', 'encoder.layer.7.attention.self.key.weight', 'encoder.layer.7.attention.self.key.bias', 'encoder.layer.7.attention.self.value.weight', 'encoder.layer.7.attention.self.value.bias', 'encoder.layer.7.attention.output.dense.weight', 'encoder.layer.7.attention.output.dense.bias', 'encoder.layer.7.attention.output.LayerNorm.weight', 'encoder.layer.7.attention.output.LayerNorm.bias', 'encoder.layer.7.intermediate.dense.weight', 'encoder.layer.7.intermediate.dense.bias', 'encoder.layer.7.output.dense.weight', 'encoder.layer.7.output.dense.bias', 'encoder.layer.7.output.LayerNorm.weight', 'encoder.layer.7.output.LayerNorm.bias', 'encoder.layer.8.attention.self.query.weight', 'encoder.layer.8.attention.self.query.bias', 'encoder.layer.8.attention.self.key.weight', 'encoder.layer.8.attention.self.key.bias', 'encoder.layer.8.attention.self.value.weight', 'encoder.layer.8.attention.self.value.bias', 'encoder.layer.8.attention.output.dense.weight', 'encoder.layer.8.attention.output.dense.bias', 'encoder.layer.8.attention.output.LayerNorm.weight', 'encoder.layer.8.attention.output.LayerNorm.bias', 'encoder.layer.8.intermediate.dense.weight', 'encoder.layer.8.intermediate.dense.bias', 'encoder.layer.8.output.dense.weight', 'encoder.layer.8.output.dense.bias', 'encoder.layer.8.output.LayerNorm.weight', 'encoder.layer.8.output.LayerNorm.bias', 'encoder.layer.9.attention.self.query.weight', 'encoder.layer.9.attention.self.query.bias', 'encoder.layer.9.attention.self.key.weight', 'encoder.layer.9.attention.self.key.bias', 'encoder.layer.9.attention.self.value.weight', 'encoder.layer.9.attention.self.value.bias', 'encoder.layer.9.attention.output.dense.weight', 'encoder.layer.9.attention.output.dense.bias', 'encoder.layer.9.attention.output.LayerNorm.weight', 'encoder.layer.9.attention.output.LayerNorm.bias', 'encoder.layer.9.intermediate.dense.weight', 'encoder.layer.9.intermediate.dense.bias', 'encoder.layer.9.output.dense.weight', 'encoder.layer.9.output.dense.bias', 'encoder.layer.9.output.LayerNorm.weight', 'encoder.layer.9.output.LayerNorm.bias', 'encoder.layer.10.attention.self.query.weight', 'encoder.layer.10.attention.self.query.bias', 'encoder.layer.10.attention.self.key.weight', 'encoder.layer.10.attention.self.key.bias', 'encoder.layer.10.attention.self.value.weight', 'encoder.layer.10.attention.self.value.bias', 'encoder.layer.10.attention.output.dense.weight', 'encoder.layer.10.attention.output.dense.bias', 'encoder.layer.10.attention.output.LayerNorm.weight', 'encoder.layer.10.attention.output.LayerNorm.bias', 'encoder.layer.10.intermediate.dense.weight', 'encoder.layer.10.intermediate.dense.bias', 'encoder.layer.10.output.dense.weight', 'encoder.layer.10.output.dense.bias', 'encoder.layer.10.output.LayerNorm.weight', 'encoder.layer.10.output.LayerNorm.bias', 'encoder.layer.11.attention.self.query.weight', 'encoder.layer.11.attention.self.query.bias', 'encoder.layer.11.attention.self.key.weight', 
'encoder.layer.11.attention.self.key.bias', 'encoder.layer.11.attention.self.value.weight', 'encoder.layer.11.attention.self.value.bias', 'encoder.layer.11.attention.output.dense.weight', 'encoder.layer.11.attention.output.dense.bias', 'encoder.layer.11.attention.output.LayerNorm.weight', 'encoder.layer.11.attention.output.LayerNorm.bias', 'encoder.layer.11.intermediate.dense.weight', 'encoder.layer.11.intermediate.dense.bias', 'encoder.layer.11.output.dense.weight', 'encoder.layer.11.output.dense.bias', 'encoder.layer.11.output.LayerNorm.weight', 'encoder.layer.11.output.LayerNorm.bias', 'pooler.dense.weight', 'pooler.dense.bias'])\n", + "success1\n", + "success2\n", + "Total differences: 0.0\n" + ] + } + ], + "source": [ + "# Check weights loaded correctly\n", + "my_dict = my_model.to_state_dict()\n", + "hf_dict = hf_model.state_dict()\n", "\n", - "hf_output_model = hf_model(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", - "hf_output = hf_head(hf_output_model[0])\n", + "print(f\"Total differences: {check_dicts(my_dict, hf_dict)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " 
warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n" + ] + } + ], + "source": [ + "my_output_model = my_model(input_embeds = input_embeds, attention_mask=mask, key=key)\n", + "hf_output_model = hf_model(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: acc: 1.0 \t norms: (tensor(443.3569), tensor(443.3564)) \t diffs: 1.405485477334878e-06\n" + ] + } + ], + "source": [ + "print(f\"Model: {check(my_output_model[0].array, hf_output_model[0].detach())[3]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + 
"c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n" + ] + } + ], + "source": [ + "my_output_model = my_model(input_ids = input_ids, attention_mask=mask, key=key)\n", + "hf_output_model = hf_model(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: acc: 1.0 \t norms: (tensor(431.9545), tensor(431.9545)) \t diffs: 1.3033360346526024e-06\n" + ] + } + ], + "source": [ + "print(f\"Model: {check(my_output_model[0].array, hf_output_model[0].detach())[3]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [], + "source": [ + "hf_mlm = hf_roberta.RobertaForMaskedLM.from_pretrained(hf_model_str)\n", + "state_mlm = hf_mlm.state_dict()\n", "\n", + "state_mlm = {k: np.array(v) for k, v in state_mlm.items()}\n", "\n", - "# Checks\n", + "hf_config = hf_mlm.config\n", + "hf_config.hidden_dropout_prob = 0\n", + "hf_config.attention_probs_dropout_prob = 0\n", + "my_config = my_roberta.RobertaConfig.from_hf_config(hf_config)\n", "\n", - "# Notes\n", - "# embeds works between hf and my, and within model->head and mlm \n", - "# ids does not work between hf and my for mlm or within hf for model->head and mlm - so hf mlm is doing something weird.'''" + "my_mlm = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=key)\n", + "my_mlm = my_mlm.from_state_dict(state_mlm)" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 58, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "RobertaModel: acc: 1.0 \t norms: (tensor(888.5287), tensor(888.5287)) \t diffs: 3.736330143055966e-07\n", - "Roberta Model + LM head: acc: 1.0 \t norms: (tensor(7065.4971), tensor(7065.4971)) \t diffs: 5.507896503331722e-07\n", - "MLM: acc: 0.0014420458728273227 \t norms: (tensor(7065.4971), tensor(7062.7324)) \t diffs: 0.07865540683269501\n", - "my RobertaModel + LM head vs MLM: acc: 1.0 \t norms: (tensor(7065.4971), tensor(7065.4971)) \t diffs: 0.0\n", - "hf RobertaModel + LM head vs MLM: acc: 0.0014420071674599332 \t norms: (tensor(7065.4971), tensor(7062.7324)) \t diffs: 0.07865539938211441\n" + "dict_keys(['roberta.encoder.layer.0.attention.self.query.weight', 'roberta.encoder.layer.0.attention.self.query.bias', 'roberta.encoder.layer.0.attention.self.key.weight', 'roberta.encoder.layer.0.attention.self.key.bias', 'roberta.encoder.layer.0.attention.self.value.weight', 'roberta.encoder.layer.0.attention.self.value.bias', 
'roberta.encoder.layer.0.attention.output.dense.weight', 'roberta.encoder.layer.0.attention.output.dense.bias', 'roberta.encoder.layer.0.attention.output.LayerNorm.weight', 'roberta.encoder.layer.0.attention.output.LayerNorm.bias', 'roberta.encoder.layer.0.intermediate.dense.weight', 'roberta.encoder.layer.0.intermediate.dense.bias', 'roberta.encoder.layer.0.output.dense.weight', 'roberta.encoder.layer.0.output.dense.bias', 'roberta.encoder.layer.0.output.LayerNorm.weight', 'roberta.encoder.layer.0.output.LayerNorm.bias', 'roberta.encoder.layer.1.attention.self.query.weight', 'roberta.encoder.layer.1.attention.self.query.bias', 'roberta.encoder.layer.1.attention.self.key.weight', 'roberta.encoder.layer.1.attention.self.key.bias', 'roberta.encoder.layer.1.attention.self.value.weight', 'roberta.encoder.layer.1.attention.self.value.bias', 'roberta.encoder.layer.1.attention.output.dense.weight', 'roberta.encoder.layer.1.attention.output.dense.bias', 'roberta.encoder.layer.1.attention.output.LayerNorm.weight', 'roberta.encoder.layer.1.attention.output.LayerNorm.bias', 'roberta.encoder.layer.1.intermediate.dense.weight', 'roberta.encoder.layer.1.intermediate.dense.bias', 'roberta.encoder.layer.1.output.dense.weight', 'roberta.encoder.layer.1.output.dense.bias', 'roberta.encoder.layer.1.output.LayerNorm.weight', 'roberta.encoder.layer.1.output.LayerNorm.bias', 'roberta.encoder.layer.2.attention.self.query.weight', 'roberta.encoder.layer.2.attention.self.query.bias', 'roberta.encoder.layer.2.attention.self.key.weight', 'roberta.encoder.layer.2.attention.self.key.bias', 'roberta.encoder.layer.2.attention.self.value.weight', 'roberta.encoder.layer.2.attention.self.value.bias', 'roberta.encoder.layer.2.attention.output.dense.weight', 'roberta.encoder.layer.2.attention.output.dense.bias', 'roberta.encoder.layer.2.attention.output.LayerNorm.weight', 'roberta.encoder.layer.2.attention.output.LayerNorm.bias', 'roberta.encoder.layer.2.intermediate.dense.weight', 'roberta.encoder.layer.2.intermediate.dense.bias', 'roberta.encoder.layer.2.output.dense.weight', 'roberta.encoder.layer.2.output.dense.bias', 'roberta.encoder.layer.2.output.LayerNorm.weight', 'roberta.encoder.layer.2.output.LayerNorm.bias', 'roberta.encoder.layer.3.attention.self.query.weight', 'roberta.encoder.layer.3.attention.self.query.bias', 'roberta.encoder.layer.3.attention.self.key.weight', 'roberta.encoder.layer.3.attention.self.key.bias', 'roberta.encoder.layer.3.attention.self.value.weight', 'roberta.encoder.layer.3.attention.self.value.bias', 'roberta.encoder.layer.3.attention.output.dense.weight', 'roberta.encoder.layer.3.attention.output.dense.bias', 'roberta.encoder.layer.3.attention.output.LayerNorm.weight', 'roberta.encoder.layer.3.attention.output.LayerNorm.bias', 'roberta.encoder.layer.3.intermediate.dense.weight', 'roberta.encoder.layer.3.intermediate.dense.bias', 'roberta.encoder.layer.3.output.dense.weight', 'roberta.encoder.layer.3.output.dense.bias', 'roberta.encoder.layer.3.output.LayerNorm.weight', 'roberta.encoder.layer.3.output.LayerNorm.bias', 'roberta.encoder.layer.4.attention.self.query.weight', 'roberta.encoder.layer.4.attention.self.query.bias', 'roberta.encoder.layer.4.attention.self.key.weight', 'roberta.encoder.layer.4.attention.self.key.bias', 'roberta.encoder.layer.4.attention.self.value.weight', 'roberta.encoder.layer.4.attention.self.value.bias', 'roberta.encoder.layer.4.attention.output.dense.weight', 'roberta.encoder.layer.4.attention.output.dense.bias', 
'roberta.encoder.layer.4.attention.output.LayerNorm.weight', 'roberta.encoder.layer.4.attention.output.LayerNorm.bias', 'roberta.encoder.layer.4.intermediate.dense.weight', 'roberta.encoder.layer.4.intermediate.dense.bias', 'roberta.encoder.layer.4.output.dense.weight', 'roberta.encoder.layer.4.output.dense.bias', 'roberta.encoder.layer.4.output.LayerNorm.weight', 'roberta.encoder.layer.4.output.LayerNorm.bias', 'roberta.encoder.layer.5.attention.self.query.weight', 'roberta.encoder.layer.5.attention.self.query.bias', 'roberta.encoder.layer.5.attention.self.key.weight', 'roberta.encoder.layer.5.attention.self.key.bias', 'roberta.encoder.layer.5.attention.self.value.weight', 'roberta.encoder.layer.5.attention.self.value.bias', 'roberta.encoder.layer.5.attention.output.dense.weight', 'roberta.encoder.layer.5.attention.output.dense.bias', 'roberta.encoder.layer.5.attention.output.LayerNorm.weight', 'roberta.encoder.layer.5.attention.output.LayerNorm.bias', 'roberta.encoder.layer.5.intermediate.dense.weight', 'roberta.encoder.layer.5.intermediate.dense.bias', 'roberta.encoder.layer.5.output.dense.weight', 'roberta.encoder.layer.5.output.dense.bias', 'roberta.encoder.layer.5.output.LayerNorm.weight', 'roberta.encoder.layer.5.output.LayerNorm.bias', 'roberta.encoder.layer.6.attention.self.query.weight', 'roberta.encoder.layer.6.attention.self.query.bias', 'roberta.encoder.layer.6.attention.self.key.weight', 'roberta.encoder.layer.6.attention.self.key.bias', 'roberta.encoder.layer.6.attention.self.value.weight', 'roberta.encoder.layer.6.attention.self.value.bias', 'roberta.encoder.layer.6.attention.output.dense.weight', 'roberta.encoder.layer.6.attention.output.dense.bias', 'roberta.encoder.layer.6.attention.output.LayerNorm.weight', 'roberta.encoder.layer.6.attention.output.LayerNorm.bias', 'roberta.encoder.layer.6.intermediate.dense.weight', 'roberta.encoder.layer.6.intermediate.dense.bias', 'roberta.encoder.layer.6.output.dense.weight', 'roberta.encoder.layer.6.output.dense.bias', 'roberta.encoder.layer.6.output.LayerNorm.weight', 'roberta.encoder.layer.6.output.LayerNorm.bias', 'roberta.encoder.layer.7.attention.self.query.weight', 'roberta.encoder.layer.7.attention.self.query.bias', 'roberta.encoder.layer.7.attention.self.key.weight', 'roberta.encoder.layer.7.attention.self.key.bias', 'roberta.encoder.layer.7.attention.self.value.weight', 'roberta.encoder.layer.7.attention.self.value.bias', 'roberta.encoder.layer.7.attention.output.dense.weight', 'roberta.encoder.layer.7.attention.output.dense.bias', 'roberta.encoder.layer.7.attention.output.LayerNorm.weight', 'roberta.encoder.layer.7.attention.output.LayerNorm.bias', 'roberta.encoder.layer.7.intermediate.dense.weight', 'roberta.encoder.layer.7.intermediate.dense.bias', 'roberta.encoder.layer.7.output.dense.weight', 'roberta.encoder.layer.7.output.dense.bias', 'roberta.encoder.layer.7.output.LayerNorm.weight', 'roberta.encoder.layer.7.output.LayerNorm.bias', 'roberta.encoder.layer.8.attention.self.query.weight', 'roberta.encoder.layer.8.attention.self.query.bias', 'roberta.encoder.layer.8.attention.self.key.weight', 'roberta.encoder.layer.8.attention.self.key.bias', 'roberta.encoder.layer.8.attention.self.value.weight', 'roberta.encoder.layer.8.attention.self.value.bias', 'roberta.encoder.layer.8.attention.output.dense.weight', 'roberta.encoder.layer.8.attention.output.dense.bias', 'roberta.encoder.layer.8.attention.output.LayerNorm.weight', 'roberta.encoder.layer.8.attention.output.LayerNorm.bias', 
'roberta.encoder.layer.8.intermediate.dense.weight', 'roberta.encoder.layer.8.intermediate.dense.bias', 'roberta.encoder.layer.8.output.dense.weight', 'roberta.encoder.layer.8.output.dense.bias', 'roberta.encoder.layer.8.output.LayerNorm.weight', 'roberta.encoder.layer.8.output.LayerNorm.bias', 'roberta.encoder.layer.9.attention.self.query.weight', 'roberta.encoder.layer.9.attention.self.query.bias', 'roberta.encoder.layer.9.attention.self.key.weight', 'roberta.encoder.layer.9.attention.self.key.bias', 'roberta.encoder.layer.9.attention.self.value.weight', 'roberta.encoder.layer.9.attention.self.value.bias', 'roberta.encoder.layer.9.attention.output.dense.weight', 'roberta.encoder.layer.9.attention.output.dense.bias', 'roberta.encoder.layer.9.attention.output.LayerNorm.weight', 'roberta.encoder.layer.9.attention.output.LayerNorm.bias', 'roberta.encoder.layer.9.intermediate.dense.weight', 'roberta.encoder.layer.9.intermediate.dense.bias', 'roberta.encoder.layer.9.output.dense.weight', 'roberta.encoder.layer.9.output.dense.bias', 'roberta.encoder.layer.9.output.LayerNorm.weight', 'roberta.encoder.layer.9.output.LayerNorm.bias', 'roberta.encoder.layer.10.attention.self.query.weight', 'roberta.encoder.layer.10.attention.self.query.bias', 'roberta.encoder.layer.10.attention.self.key.weight', 'roberta.encoder.layer.10.attention.self.key.bias', 'roberta.encoder.layer.10.attention.self.value.weight', 'roberta.encoder.layer.10.attention.self.value.bias', 'roberta.encoder.layer.10.attention.output.dense.weight', 'roberta.encoder.layer.10.attention.output.dense.bias', 'roberta.encoder.layer.10.attention.output.LayerNorm.weight', 'roberta.encoder.layer.10.attention.output.LayerNorm.bias', 'roberta.encoder.layer.10.intermediate.dense.weight', 'roberta.encoder.layer.10.intermediate.dense.bias', 'roberta.encoder.layer.10.output.dense.weight', 'roberta.encoder.layer.10.output.dense.bias', 'roberta.encoder.layer.10.output.LayerNorm.weight', 'roberta.encoder.layer.10.output.LayerNorm.bias', 'roberta.encoder.layer.11.attention.self.query.weight', 'roberta.encoder.layer.11.attention.self.query.bias', 'roberta.encoder.layer.11.attention.self.key.weight', 'roberta.encoder.layer.11.attention.self.key.bias', 'roberta.encoder.layer.11.attention.self.value.weight', 'roberta.encoder.layer.11.attention.self.value.bias', 'roberta.encoder.layer.11.attention.output.dense.weight', 'roberta.encoder.layer.11.attention.output.dense.bias', 'roberta.encoder.layer.11.attention.output.LayerNorm.weight', 'roberta.encoder.layer.11.attention.output.LayerNorm.bias', 'roberta.encoder.layer.11.intermediate.dense.weight', 'roberta.encoder.layer.11.intermediate.dense.bias', 'roberta.encoder.layer.11.output.dense.weight', 'roberta.encoder.layer.11.output.dense.bias', 'roberta.encoder.layer.11.output.LayerNorm.weight', 'roberta.encoder.layer.11.output.LayerNorm.bias', 'roberta.embeddings.word_embeddings.weight', 'roberta.embeddings.position_embeddings.weight', 'roberta.embeddings.token_type_embeddings.weight', 'roberta.embeddings.LayerNorm.weight', 'roberta.embeddings.LayerNorm.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.decoder.bias'])\n", + "odict_keys(['roberta.embeddings.word_embeddings.weight', 'roberta.embeddings.position_embeddings.weight', 'roberta.embeddings.token_type_embeddings.weight', 'roberta.embeddings.LayerNorm.weight', 'roberta.embeddings.LayerNorm.bias', 'roberta.encoder.layer.0.attention.self.query.weight', 
'roberta.encoder.layer.0.attention.self.query.bias', 'roberta.encoder.layer.0.attention.self.key.weight', 'roberta.encoder.layer.0.attention.self.key.bias', 'roberta.encoder.layer.0.attention.self.value.weight', 'roberta.encoder.layer.0.attention.self.value.bias', 'roberta.encoder.layer.0.attention.output.dense.weight', 'roberta.encoder.layer.0.attention.output.dense.bias', 'roberta.encoder.layer.0.attention.output.LayerNorm.weight', 'roberta.encoder.layer.0.attention.output.LayerNorm.bias', 'roberta.encoder.layer.0.intermediate.dense.weight', 'roberta.encoder.layer.0.intermediate.dense.bias', 'roberta.encoder.layer.0.output.dense.weight', 'roberta.encoder.layer.0.output.dense.bias', 'roberta.encoder.layer.0.output.LayerNorm.weight', 'roberta.encoder.layer.0.output.LayerNorm.bias', 'roberta.encoder.layer.1.attention.self.query.weight', 'roberta.encoder.layer.1.attention.self.query.bias', 'roberta.encoder.layer.1.attention.self.key.weight', 'roberta.encoder.layer.1.attention.self.key.bias', 'roberta.encoder.layer.1.attention.self.value.weight', 'roberta.encoder.layer.1.attention.self.value.bias', 'roberta.encoder.layer.1.attention.output.dense.weight', 'roberta.encoder.layer.1.attention.output.dense.bias', 'roberta.encoder.layer.1.attention.output.LayerNorm.weight', 'roberta.encoder.layer.1.attention.output.LayerNorm.bias', 'roberta.encoder.layer.1.intermediate.dense.weight', 'roberta.encoder.layer.1.intermediate.dense.bias', 'roberta.encoder.layer.1.output.dense.weight', 'roberta.encoder.layer.1.output.dense.bias', 'roberta.encoder.layer.1.output.LayerNorm.weight', 'roberta.encoder.layer.1.output.LayerNorm.bias', 'roberta.encoder.layer.2.attention.self.query.weight', 'roberta.encoder.layer.2.attention.self.query.bias', 'roberta.encoder.layer.2.attention.self.key.weight', 'roberta.encoder.layer.2.attention.self.key.bias', 'roberta.encoder.layer.2.attention.self.value.weight', 'roberta.encoder.layer.2.attention.self.value.bias', 'roberta.encoder.layer.2.attention.output.dense.weight', 'roberta.encoder.layer.2.attention.output.dense.bias', 'roberta.encoder.layer.2.attention.output.LayerNorm.weight', 'roberta.encoder.layer.2.attention.output.LayerNorm.bias', 'roberta.encoder.layer.2.intermediate.dense.weight', 'roberta.encoder.layer.2.intermediate.dense.bias', 'roberta.encoder.layer.2.output.dense.weight', 'roberta.encoder.layer.2.output.dense.bias', 'roberta.encoder.layer.2.output.LayerNorm.weight', 'roberta.encoder.layer.2.output.LayerNorm.bias', 'roberta.encoder.layer.3.attention.self.query.weight', 'roberta.encoder.layer.3.attention.self.query.bias', 'roberta.encoder.layer.3.attention.self.key.weight', 'roberta.encoder.layer.3.attention.self.key.bias', 'roberta.encoder.layer.3.attention.self.value.weight', 'roberta.encoder.layer.3.attention.self.value.bias', 'roberta.encoder.layer.3.attention.output.dense.weight', 'roberta.encoder.layer.3.attention.output.dense.bias', 'roberta.encoder.layer.3.attention.output.LayerNorm.weight', 'roberta.encoder.layer.3.attention.output.LayerNorm.bias', 'roberta.encoder.layer.3.intermediate.dense.weight', 'roberta.encoder.layer.3.intermediate.dense.bias', 'roberta.encoder.layer.3.output.dense.weight', 'roberta.encoder.layer.3.output.dense.bias', 'roberta.encoder.layer.3.output.LayerNorm.weight', 'roberta.encoder.layer.3.output.LayerNorm.bias', 'roberta.encoder.layer.4.attention.self.query.weight', 'roberta.encoder.layer.4.attention.self.query.bias', 'roberta.encoder.layer.4.attention.self.key.weight', 'roberta.encoder.layer.4.attention.self.key.bias', 
'roberta.encoder.layer.4.attention.self.value.weight', 'roberta.encoder.layer.4.attention.self.value.bias', 'roberta.encoder.layer.4.attention.output.dense.weight', 'roberta.encoder.layer.4.attention.output.dense.bias', 'roberta.encoder.layer.4.attention.output.LayerNorm.weight', 'roberta.encoder.layer.4.attention.output.LayerNorm.bias', 'roberta.encoder.layer.4.intermediate.dense.weight', 'roberta.encoder.layer.4.intermediate.dense.bias', 'roberta.encoder.layer.4.output.dense.weight', 'roberta.encoder.layer.4.output.dense.bias', 'roberta.encoder.layer.4.output.LayerNorm.weight', 'roberta.encoder.layer.4.output.LayerNorm.bias', 'roberta.encoder.layer.5.attention.self.query.weight', 'roberta.encoder.layer.5.attention.self.query.bias', 'roberta.encoder.layer.5.attention.self.key.weight', 'roberta.encoder.layer.5.attention.self.key.bias', 'roberta.encoder.layer.5.attention.self.value.weight', 'roberta.encoder.layer.5.attention.self.value.bias', 'roberta.encoder.layer.5.attention.output.dense.weight', 'roberta.encoder.layer.5.attention.output.dense.bias', 'roberta.encoder.layer.5.attention.output.LayerNorm.weight', 'roberta.encoder.layer.5.attention.output.LayerNorm.bias', 'roberta.encoder.layer.5.intermediate.dense.weight', 'roberta.encoder.layer.5.intermediate.dense.bias', 'roberta.encoder.layer.5.output.dense.weight', 'roberta.encoder.layer.5.output.dense.bias', 'roberta.encoder.layer.5.output.LayerNorm.weight', 'roberta.encoder.layer.5.output.LayerNorm.bias', 'roberta.encoder.layer.6.attention.self.query.weight', 'roberta.encoder.layer.6.attention.self.query.bias', 'roberta.encoder.layer.6.attention.self.key.weight', 'roberta.encoder.layer.6.attention.self.key.bias', 'roberta.encoder.layer.6.attention.self.value.weight', 'roberta.encoder.layer.6.attention.self.value.bias', 'roberta.encoder.layer.6.attention.output.dense.weight', 'roberta.encoder.layer.6.attention.output.dense.bias', 'roberta.encoder.layer.6.attention.output.LayerNorm.weight', 'roberta.encoder.layer.6.attention.output.LayerNorm.bias', 'roberta.encoder.layer.6.intermediate.dense.weight', 'roberta.encoder.layer.6.intermediate.dense.bias', 'roberta.encoder.layer.6.output.dense.weight', 'roberta.encoder.layer.6.output.dense.bias', 'roberta.encoder.layer.6.output.LayerNorm.weight', 'roberta.encoder.layer.6.output.LayerNorm.bias', 'roberta.encoder.layer.7.attention.self.query.weight', 'roberta.encoder.layer.7.attention.self.query.bias', 'roberta.encoder.layer.7.attention.self.key.weight', 'roberta.encoder.layer.7.attention.self.key.bias', 'roberta.encoder.layer.7.attention.self.value.weight', 'roberta.encoder.layer.7.attention.self.value.bias', 'roberta.encoder.layer.7.attention.output.dense.weight', 'roberta.encoder.layer.7.attention.output.dense.bias', 'roberta.encoder.layer.7.attention.output.LayerNorm.weight', 'roberta.encoder.layer.7.attention.output.LayerNorm.bias', 'roberta.encoder.layer.7.intermediate.dense.weight', 'roberta.encoder.layer.7.intermediate.dense.bias', 'roberta.encoder.layer.7.output.dense.weight', 'roberta.encoder.layer.7.output.dense.bias', 'roberta.encoder.layer.7.output.LayerNorm.weight', 'roberta.encoder.layer.7.output.LayerNorm.bias', 'roberta.encoder.layer.8.attention.self.query.weight', 'roberta.encoder.layer.8.attention.self.query.bias', 'roberta.encoder.layer.8.attention.self.key.weight', 'roberta.encoder.layer.8.attention.self.key.bias', 'roberta.encoder.layer.8.attention.self.value.weight', 'roberta.encoder.layer.8.attention.self.value.bias', 
'roberta.encoder.layer.8.attention.output.dense.weight', 'roberta.encoder.layer.8.attention.output.dense.bias', 'roberta.encoder.layer.8.attention.output.LayerNorm.weight', 'roberta.encoder.layer.8.attention.output.LayerNorm.bias', 'roberta.encoder.layer.8.intermediate.dense.weight', 'roberta.encoder.layer.8.intermediate.dense.bias', 'roberta.encoder.layer.8.output.dense.weight', 'roberta.encoder.layer.8.output.dense.bias', 'roberta.encoder.layer.8.output.LayerNorm.weight', 'roberta.encoder.layer.8.output.LayerNorm.bias', 'roberta.encoder.layer.9.attention.self.query.weight', 'roberta.encoder.layer.9.attention.self.query.bias', 'roberta.encoder.layer.9.attention.self.key.weight', 'roberta.encoder.layer.9.attention.self.key.bias', 'roberta.encoder.layer.9.attention.self.value.weight', 'roberta.encoder.layer.9.attention.self.value.bias', 'roberta.encoder.layer.9.attention.output.dense.weight', 'roberta.encoder.layer.9.attention.output.dense.bias', 'roberta.encoder.layer.9.attention.output.LayerNorm.weight', 'roberta.encoder.layer.9.attention.output.LayerNorm.bias', 'roberta.encoder.layer.9.intermediate.dense.weight', 'roberta.encoder.layer.9.intermediate.dense.bias', 'roberta.encoder.layer.9.output.dense.weight', 'roberta.encoder.layer.9.output.dense.bias', 'roberta.encoder.layer.9.output.LayerNorm.weight', 'roberta.encoder.layer.9.output.LayerNorm.bias', 'roberta.encoder.layer.10.attention.self.query.weight', 'roberta.encoder.layer.10.attention.self.query.bias', 'roberta.encoder.layer.10.attention.self.key.weight', 'roberta.encoder.layer.10.attention.self.key.bias', 'roberta.encoder.layer.10.attention.self.value.weight', 'roberta.encoder.layer.10.attention.self.value.bias', 'roberta.encoder.layer.10.attention.output.dense.weight', 'roberta.encoder.layer.10.attention.output.dense.bias', 'roberta.encoder.layer.10.attention.output.LayerNorm.weight', 'roberta.encoder.layer.10.attention.output.LayerNorm.bias', 'roberta.encoder.layer.10.intermediate.dense.weight', 'roberta.encoder.layer.10.intermediate.dense.bias', 'roberta.encoder.layer.10.output.dense.weight', 'roberta.encoder.layer.10.output.dense.bias', 'roberta.encoder.layer.10.output.LayerNorm.weight', 'roberta.encoder.layer.10.output.LayerNorm.bias', 'roberta.encoder.layer.11.attention.self.query.weight', 'roberta.encoder.layer.11.attention.self.query.bias', 'roberta.encoder.layer.11.attention.self.key.weight', 'roberta.encoder.layer.11.attention.self.key.bias', 'roberta.encoder.layer.11.attention.self.value.weight', 'roberta.encoder.layer.11.attention.self.value.bias', 'roberta.encoder.layer.11.attention.output.dense.weight', 'roberta.encoder.layer.11.attention.output.dense.bias', 'roberta.encoder.layer.11.attention.output.LayerNorm.weight', 'roberta.encoder.layer.11.attention.output.LayerNorm.bias', 'roberta.encoder.layer.11.intermediate.dense.weight', 'roberta.encoder.layer.11.intermediate.dense.bias', 'roberta.encoder.layer.11.output.dense.weight', 'roberta.encoder.layer.11.output.dense.bias', 'roberta.encoder.layer.11.output.LayerNorm.weight', 'roberta.encoder.layer.11.output.LayerNorm.bias', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.decoder.bias'])\n", + "success1\n", + "fail2\n", + "['lm_head.bias']\n", + "Total differences: 0.0\n" ] } ], "source": [ - "print(f\"RobertaModel: {check(my_output_model[0].array, hf_output_model[0].detach())[3]}\")\n", - "print(f\"Roberta Model + LM head: {check(my_output.array, 
hf_output.detach())[3]}\")\n", - "print(f\"MLM: {check(my_output_mlm[0].array, hf_output_mlm[0].detach())[3]}\")\n", + "# Check weights loaded correctly\n", + "my_dict = my_mlm.to_state_dict()\n", + "hf_dict = hf_mlm.state_dict()\n", "\n", - "print(f\"my RobertaModel + LM head vs MLM: {check(my_output.array, my_output_mlm[0].array)[3]}\")\n", - "print(f\"hf RobertaModel + LM head vs MLM: {check(hf_output.detach(), hf_output_mlm[0].detach())[3]}\")" + "print(f\"Total differences: {check_dicts(my_dict, hf_dict)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([-0.0972, -0.0294, 0.4988, ..., -0.0312, -0.0312, -1.0000])" + ] + }, + "execution_count": 59, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hf_dict['lm_head.bias']" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([-0.0972, -0.0294, 0.4988, ..., -0.0312, -0.0312, -1.0000])" + ] + }, + "execution_count": 60, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hf_dict['lm_head.decoder.bias']" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + 
"c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n" + ] + } + ], + "source": [ + "my_output_mlm = my_mlm(input_embeds = input_embeds, attention_mask=mask, key=key)\n", + "hf_output_mlm = hf_mlm(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MLM: acc: 1.0 \t norms: (tensor(33433.4062), tensor(33433.3945)) \t diffs: 1.7561978893354535e-05\n" + ] + } + ], + "source": [ + "print(f\"MLM: {check(my_output_mlm[0].array, hf_output_mlm[0].detach())[3]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " 
warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n" + ] + } + ], + "source": [ + "my_output_mlm = my_mlm(input_ids = input_ids, attention_mask=mask, key=key)\n", + "hf_output_mlm = hf_mlm(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MLM: acc: 1.0 \t norms: (tensor(28814.2480), tensor(28814.2168)) \t diffs: 1.45375252031954e-05\n" + ] + } + ], + "source": [ + "print(f\"MLM: {check(my_output_mlm[0].array, hf_output_mlm[0].detach())[3]}\")" ] } ], From 6c105f5ff45b83454cbb8376e7f8a3401be46f78 Mon Sep 17 00:00:00 2001 From: JulienDarve Date: Wed, 11 Sep 2024 19:51:41 -0700 Subject: [PATCH 13/29] trial --- src/levanter/models/testing.ipynb | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/levanter/models/testing.ipynb b/src/levanter/models/testing.ipynb index fc7dbaef1..27a24f7bf 100644 --- a/src/levanter/models/testing.ipynb +++ b/src/levanter/models/testing.ipynb @@ -14,7 +14,9 @@ "import jax\n", "import jax.random as jrandom\n", "import jax.numpy as jnp\n", - "import numpy as np" + "import numpy as np\n", + "\n", + "# hello" ] }, { From ab8507929514a2264d916dc405113f5f29af99bd Mon Sep 17 00:00:00 2001 From: JulienDarve Date: Wed, 11 Sep 2024 20:17:48 -0700 Subject: [PATCH 14/29] update 1 --- src/levanter/main/train_mlm.py | 6 ++++-- src/levanter/models/roberta.py | 22 +++++++++++++++++++++- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/src/levanter/main/train_mlm.py b/src/levanter/main/train_mlm.py index 435abe5bf..1f1d96140 100644 --- a/src/levanter/main/train_mlm.py +++ b/src/levanter/main/train_mlm.py @@ -1,6 +1,7 @@ # train_mlm.py import dataclasses +import functools import gc import logging import os @@ -20,7 +21,7 @@ from levanter.data.text import MaskedLmDataset, LMDatasetConfig, LMMixtureDatasetConfig from levanter.models.gpt2 import Gpt2Config from levanter.models.llama import LlamaConfig -from levanter.models.lm_model import LmConfig +from levanter.models.lm_model import LmConfig, compute_next_token_loss from levanter.models.roberta import RobertaConfig from levanter.optim import AdamConfig, OptimizerConfig from levanter.trainer import Trainer, TrainerConfig @@ -82,12 +83,13 @@ def main(config: TrainMlmConfig): levanter.initialize(config) optimizer = config.optimizer.build(config.trainer.num_train_steps) + loss_function = functools.partial(compute_next_token_loss, 
logsumexp_weight=config.z_loss_weight) # Using the trainer as a context manager does 3 things: # 1. Sets the device mesh # 2. Sets the axis mapping (for fsdp) # 3. Sets the global metrics tracker - with Trainer(config.trainer, optimizer) as trainer, jax.disable_jit(True): + with Trainer(config.trainer, optimizer, loss_function) as trainer: # randomness in jax is tightly controlled by "keys" which are the states of the random number generators # this makes deterministic training pretty easy seed = config.trainer.seed diff --git a/src/levanter/models/roberta.py b/src/levanter/models/roberta.py index 84c85527e..386b8511f 100644 --- a/src/levanter/models/roberta.py +++ b/src/levanter/models/roberta.py @@ -10,6 +10,7 @@ import haliax as hax import haliax.nn as hnn +from haliax.nn import cross_entropy_loss from haliax import Axis, AxisSpec, NamedArray from haliax.jax_utils import maybe_rng_split, named_call, shaped_rng_split from haliax.nn.scan import BlockSeq @@ -27,7 +28,7 @@ from levanter.logging import silence_transformer_nag from levanter.models.attention import AttentionBackend, AttentionMask, simple_attention_with_dropout from levanter.models.gpt2 import ACT2FN -from levanter.models.lm_model import LmConfig, LmHeadModel +from levanter.models.lm_model import LmConfig, LmHeadModel, MaskedLmExample from levanter.types import BlockFoldable from levanter.utils.flop_utils import lm_flops_per_token @@ -809,6 +810,25 @@ def __call__( return (prediction_scores,) + outputs[2:] + def compute_loss( + self, + example: MaskedLmExample, + *, + key=None, + reduction: Optional[hax.ReductionFunction] = hax.mean, + reduction_axis: Optional[hax.AxisSelection] = None, + ) -> jnp.ndarray | NamedArray: + logits = self(example.tokens, example.attn_mask, key=key) + logits = logits.astype(jnp.float32) + targets = example.targets + + target_y = hax.nn.one_hot(targets, self.Vocab, dtype=logits.dtype) + target_y = jax.debug.breakpoint(token=target_y) + loss = cross_entropy_loss( + logits, self.Vocab, target_y, reduction, reduction_axis=reduction_axis, where=example.loss_mask + ) + + return loss def _rotate_half(x: NamedArray) -> NamedArray: """Rotates half of the hidden dims of the input and concatenates them.""" From 5b97400def34e0e4631322d33998167f50261bbc Mon Sep 17 00:00:00 2001 From: JulienDarve Date: Wed, 11 Sep 2024 20:43:55 -0700 Subject: [PATCH 15/29] update 2 --- src/levanter/main/train_mlm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/main/train_mlm.py b/src/levanter/main/train_mlm.py index 1f1d96140..1821b1b90 100644 --- a/src/levanter/main/train_mlm.py +++ b/src/levanter/main/train_mlm.py @@ -83,7 +83,7 @@ def main(config: TrainMlmConfig): levanter.initialize(config) optimizer = config.optimizer.build(config.trainer.num_train_steps) - loss_function = functools.partial(compute_next_token_loss, logsumexp_weight=config.z_loss_weight) + loss_function = functools.partial(compute_next_token_loss) # Using the trainer as a context manager does 3 things: # 1. 
Sets the device mesh From bd7d411ce98adf8dca38aa1d837945c86390a20c Mon Sep 17 00:00:00 2001 From: JulienDarve Date: Wed, 11 Sep 2024 21:03:58 -0700 Subject: [PATCH 16/29] update 3 --- config/roberta.yaml | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/config/roberta.yaml b/config/roberta.yaml index c854f8109..cea6bbb77 100644 --- a/config/roberta.yaml +++ b/config/roberta.yaml @@ -1,9 +1,5 @@ data: - train_urls: - - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" - validation_urls: - - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" - cache_dir: "gs://levanter-data/tokenized/openwebtext_roberta/" + id: dlwh/wikitext_103_detokenized tokenizer: "roberta-base" model: From b5d8e143c6bffbcfd009071acfe497e36fc147ed Mon Sep 17 00:00:00 2001 From: JulienDarve Date: Thu, 12 Sep 2024 10:24:07 -0700 Subject: [PATCH 17/29] update --- src/levanter/models/roberta.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/levanter/models/roberta.py b/src/levanter/models/roberta.py index 386b8511f..389bcc946 100644 --- a/src/levanter/models/roberta.py +++ b/src/levanter/models/roberta.py @@ -578,6 +578,11 @@ def create_position_ids_from_inputs_embeds(self, input_axes, PosInput): @named_call def embed(self, input_ids=None, token_type_ids=None, position_ids=None, input_embeds=None, past_key_values_length=0, *, key = None): + print(input_ids.dtype) + print(token_type_ids.dtype) + print(position_ids.dtype) + print(input_embeds.dtype) + """ Note: When inputting your own embeds into input_embeds, make sure that the embeds axis has the name "embed" for compatibility with the position_id creation function. Make sures its length is not equal to From 8717c3f5c2bd969f6bcb1c728954499627a48ac4 Mon Sep 17 00:00:00 2001 From: JulienDarve Date: Thu, 12 Sep 2024 10:27:17 -0700 Subject: [PATCH 18/29] update --- src/levanter/models/roberta.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/levanter/models/roberta.py b/src/levanter/models/roberta.py index 389bcc946..66f7db3af 100644 --- a/src/levanter/models/roberta.py +++ b/src/levanter/models/roberta.py @@ -578,11 +578,25 @@ def create_position_ids_from_inputs_embeds(self, input_axes, PosInput): @named_call def embed(self, input_ids=None, token_type_ids=None, position_ids=None, input_embeds=None, past_key_values_length=0, *, key = None): - print(input_ids.dtype) - print(token_type_ids.dtype) - print(position_ids.dtype) - print(input_embeds.dtype) + if input_ids is not None: + jax.debug.print(f"input_ids: {input_ids.dtype}") + else: + jax.debug.print(f"input_ids: None") + + if token_type_ids is not None: + jax.debug.print(f"token_type_ids: {token_type_ids.dtype}") + else: + jax.debug.print(f"token_type_ids: None") + + if position_ids is not None: + jax.debug.print(f"position_ids: {position_ids.dtype}") + else: + jax.debug.print(f"position_ids: None") + if input_embeds is not None: + jax.debug.print(f"input_embeds: {input_embeds.dtype}") + else: + jax.debug.print(f"input_embeds: None") """ Note: When inputting your own embeds into input_embeds, make sure that the embeds axis has the name "embed" for compatibility with the position_id creation function. 
Make sures its length is not equal to From 10c130c0ca7eecff98fc7655c10c5989a7ab8730 Mon Sep 17 00:00:00 2001 From: JulienDarve Date: Thu, 12 Sep 2024 10:33:09 -0700 Subject: [PATCH 19/29] update --- src/levanter/models/roberta.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/levanter/models/roberta.py b/src/levanter/models/roberta.py index 66f7db3af..12031a48c 100644 --- a/src/levanter/models/roberta.py +++ b/src/levanter/models/roberta.py @@ -579,24 +579,25 @@ def create_position_ids_from_inputs_embeds(self, input_axes, PosInput): @named_call def embed(self, input_ids=None, token_type_ids=None, position_ids=None, input_embeds=None, past_key_values_length=0, *, key = None): if input_ids is not None: - jax.debug.print(f"input_ids: {input_ids.dtype}") + jax.debug.print(f"input_ids: {d}", d=input_ids.dtype) else: jax.debug.print(f"input_ids: None") if token_type_ids is not None: - jax.debug.print(f"token_type_ids: {token_type_ids.dtype}") + jax.debug.print(f"token_type_ids: {d}", d=token_type_ids.dtype) else: jax.debug.print(f"token_type_ids: None") if position_ids is not None: - jax.debug.print(f"position_ids: {position_ids.dtype}") + jax.debug.print(f"position_ids: {d}", d=position_ids.dtype) else: jax.debug.print(f"position_ids: None") if input_embeds is not None: - jax.debug.print(f"input_embeds: {input_embeds.dtype}") + jax.debug.print(f"input_embeds: {d}", d=input_embeds.dtype) else: jax.debug.print(f"input_embeds: None") + """ Note: When inputting your own embeds into input_embeds, make sure that the embeds axis has the name "embed" for compatibility with the position_id creation function. Make sures its length is not equal to From 834d88dd71fc8a813eb34a128c9c9f9225abfc1e Mon Sep 17 00:00:00 2001 From: JulienDarve Date: Thu, 12 Sep 2024 10:33:44 -0700 Subject: [PATCH 20/29] update --- src/levanter/models/roberta.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/levanter/models/roberta.py b/src/levanter/models/roberta.py index 12031a48c..d9c00283b 100644 --- a/src/levanter/models/roberta.py +++ b/src/levanter/models/roberta.py @@ -579,24 +579,24 @@ def create_position_ids_from_inputs_embeds(self, input_axes, PosInput): @named_call def embed(self, input_ids=None, token_type_ids=None, position_ids=None, input_embeds=None, past_key_values_length=0, *, key = None): if input_ids is not None: - jax.debug.print(f"input_ids: {d}", d=input_ids.dtype) + jax.debug.print("input_ids: {d}", d=input_ids.dtype) else: - jax.debug.print(f"input_ids: None") + jax.debug.print("input_ids: None") if token_type_ids is not None: - jax.debug.print(f"token_type_ids: {d}", d=token_type_ids.dtype) + jax.debug.print("token_type_ids: {d}", d=token_type_ids.dtype) else: - jax.debug.print(f"token_type_ids: None") + jax.debug.print("token_type_ids: None") if position_ids is not None: - jax.debug.print(f"position_ids: {d}", d=position_ids.dtype) + jax.debug.print("position_ids: {d}", d=position_ids.dtype) else: - jax.debug.print(f"position_ids: None") + jax.debug.print("position_ids: None") if input_embeds is not None: - jax.debug.print(f"input_embeds: {d}", d=input_embeds.dtype) + jax.debug.print("input_embeds: {d}", d=input_embeds.dtype) else: - jax.debug.print(f"input_embeds: None") + jax.debug.print("input_embeds: None") """ Note: When inputting your own embeds into input_embeds, make sure that the embeds axis has the name "embed" From 47fe23bcfe732cc97ef52d2cd46be53cbc3fdb78 Mon Sep 17 00:00:00 2001 From: JulienDarve Date: Thu, 12 Sep 2024 
11:38:12 -0700 Subject: [PATCH 21/29] update --- src/levanter/models/roberta.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/levanter/models/roberta.py b/src/levanter/models/roberta.py index d9c00283b..9ed1328a4 100644 --- a/src/levanter/models/roberta.py +++ b/src/levanter/models/roberta.py @@ -579,24 +579,24 @@ def create_position_ids_from_inputs_embeds(self, input_axes, PosInput): @named_call def embed(self, input_ids=None, token_type_ids=None, position_ids=None, input_embeds=None, past_key_values_length=0, *, key = None): if input_ids is not None: - jax.debug.print("input_ids: {d}", d=input_ids.dtype) + print(f"input_ids: {input_ids.dtype}") else: - jax.debug.print("input_ids: None") + print("input_ids: None") if token_type_ids is not None: - jax.debug.print("token_type_ids: {d}", d=token_type_ids.dtype) + print(f"input_ids: {token_type_ids.dtype}") else: - jax.debug.print("token_type_ids: None") + print("token_type_ids: None") if position_ids is not None: - jax.debug.print("position_ids: {d}", d=position_ids.dtype) + print(f"input_ids: {position_ids.dtype}") else: - jax.debug.print("position_ids: None") + print("position_ids: None") if input_embeds is not None: - jax.debug.print("input_embeds: {d}", d=input_embeds.dtype) + print(f"input_ids: {input_embeds.dtype}") else: - jax.debug.print("input_embeds: None") + print("input_embeds: None") """ Note: When inputting your own embeds into input_embeds, make sure that the embeds axis has the name "embed" From fb5c55c5739c9fb5b14664aadc2ffa07834fd173 Mon Sep 17 00:00:00 2001 From: JulienDarve Date: Thu, 12 Sep 2024 11:50:15 -0700 Subject: [PATCH 22/29] update --- src/levanter/main/train_mlm.py | 15 +++++++++++++-- src/levanter/models/lm_model.py | 32 ++++++++++++++++++++++++++++++++ src/levanter/models/roberta.py | 6 +++--- 3 files changed, 48 insertions(+), 5 deletions(-) diff --git a/src/levanter/main/train_mlm.py b/src/levanter/main/train_mlm.py index 1821b1b90..f0a7a1692 100644 --- a/src/levanter/main/train_mlm.py +++ b/src/levanter/main/train_mlm.py @@ -21,7 +21,7 @@ from levanter.data.text import MaskedLmDataset, LMDatasetConfig, LMMixtureDatasetConfig from levanter.models.gpt2 import Gpt2Config from levanter.models.llama import LlamaConfig -from levanter.models.lm_model import LmConfig, compute_next_token_loss +from levanter.models.lm_model import LmConfig, MaskedLmExample, compute_next_token_loss from levanter.models.roberta import RobertaConfig from levanter.optim import AdamConfig, OptimizerConfig from levanter.trainer import Trainer, TrainerConfig @@ -83,7 +83,18 @@ def main(config: TrainMlmConfig): levanter.initialize(config) optimizer = config.optimizer.build(config.trainer.num_train_steps) - loss_function = functools.partial(compute_next_token_loss) + # loss_function = functools.partial(compute_next_token_loss) + + def loss_function( + model, + example: MaskedLmExample, + *, + key=None, + reduction: Optional[hax.ReductionFunction] = hax.mean, + reduction_axis: Optional[hax.AxisSelection] = None, + ): + return model.compute_loss(example, key=key, reduction=reduction, reduction_axis=reduction_axis) + # Using the trainer as a context manager does 3 things: # 1. 
Sets the device mesh diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index f8d987e1a..926384cab 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -169,3 +169,35 @@ def compute_next_token_loss( ) return loss + +def compute_next_token_loss( + model: LmHeadModel, + example: LmExample, + *, + key=None, + reduction: Optional[hax.ReductionFunction] = hax.mean, + reduction_axis: Optional[hax.AxisSelection] = None, + logsumexp_weight: Optional[float] = None, + loss_dtype: Optional[Type[jnp.dtype]] = jnp.float32, +) -> jnp.ndarray | NamedArray: + """ + Computes the cross-entropy loss for a language modeling example. If reduction is not None, the loss is reduced + across the reduction axis (with reduction_axis=None meaning all axes). If reduction is None, the loss is not + reduced, and the result is a named array with axes (*batch axes, sequence_length). + """ + logits = model(example.tokens, example.attn_mask, key=key) + if loss_dtype is not None: + logits = logits.astype(loss_dtype) + + loss = next_token_loss( + model.Pos, + model.Vocab, + logits, + example.tokens, + loss_mask=example.loss_mask, + reduction=reduction, + reduction_axis=reduction_axis, + logsumexp_weight=logsumexp_weight, + ) + + return loss \ No newline at end of file diff --git a/src/levanter/models/roberta.py b/src/levanter/models/roberta.py index 9ed1328a4..27355b5ad 100644 --- a/src/levanter/models/roberta.py +++ b/src/levanter/models/roberta.py @@ -584,17 +584,17 @@ def embed(self, input_ids=None, token_type_ids=None, position_ids=None, input_em print("input_ids: None") if token_type_ids is not None: - print(f"input_ids: {token_type_ids.dtype}") + print(f"token_type_ids: {token_type_ids.dtype}") else: print("token_type_ids: None") if position_ids is not None: - print(f"input_ids: {position_ids.dtype}") + print(f"position_ids: {position_ids.dtype}") else: print("position_ids: None") if input_embeds is not None: - print(f"input_ids: {input_embeds.dtype}") + print(f"input_embeds: {input_embeds.dtype}") else: print("input_embeds: None") From 8594e796985c471a4e9696de89c0ca053196fa3a Mon Sep 17 00:00:00 2001 From: JulienDarve Date: Thu, 12 Sep 2024 11:55:03 -0700 Subject: [PATCH 23/29] update --- src/levanter/models/roberta.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/levanter/models/roberta.py b/src/levanter/models/roberta.py index 27355b5ad..958729c1a 100644 --- a/src/levanter/models/roberta.py +++ b/src/levanter/models/roberta.py @@ -217,8 +217,8 @@ def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None) ) @property - def model_type(self) -> Type["RobertaModel"]: - return RobertaModel + def model_type(self) -> Type["RobertaForMaskedLM"]: + return RobertaForMaskedLM def flops_per_token(self, vocab_size: int): return lm_flops_per_token( From 3ae80d79df3b0f90b61a2b6bb9566ec40866011f Mon Sep 17 00:00:00 2001 From: JulienDarve Date: Thu, 12 Sep 2024 11:57:53 -0700 Subject: [PATCH 24/29] update --- src/levanter/models/roberta.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/models/roberta.py b/src/levanter/models/roberta.py index 958729c1a..c20b081ed 100644 --- a/src/levanter/models/roberta.py +++ b/src/levanter/models/roberta.py @@ -838,7 +838,7 @@ def compute_loss( reduction: Optional[hax.ReductionFunction] = hax.mean, reduction_axis: Optional[hax.AxisSelection] = None, ) -> jnp.ndarray | NamedArray: - logits = self(example.tokens, example.attn_mask, key=key) + logits = 
self(example.tokens, example.attn_mask, key=key)[0] logits = logits.astype(jnp.float32) targets = example.targets From de93fc921e071bfb56fece61a439557a6a9c3bce Mon Sep 17 00:00:00 2001 From: JulienDarve Date: Thu, 12 Sep 2024 12:06:47 -0700 Subject: [PATCH 25/29] update --- src/levanter/models/roberta.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/levanter/models/roberta.py b/src/levanter/models/roberta.py index c20b081ed..309cb1a70 100644 --- a/src/levanter/models/roberta.py +++ b/src/levanter/models/roberta.py @@ -828,7 +828,8 @@ def __call__( prediction_scores = self.lm_head(outputs[0], key=k_lm) - return (prediction_scores,) + outputs[2:] + # return (prediction_scores,) + outputs[2:] + return prediction_scores def compute_loss( self, @@ -838,7 +839,8 @@ def compute_loss( reduction: Optional[hax.ReductionFunction] = hax.mean, reduction_axis: Optional[hax.AxisSelection] = None, ) -> jnp.ndarray | NamedArray: - logits = self(example.tokens, example.attn_mask, key=key)[0] + # logits = self(example.tokens, example.attn_mask, key=key)[0] + logits = self(example.tokens, example.attn_mask, key=key) logits = logits.astype(jnp.float32) targets = example.targets From 896af7d8e367f8e0bbeba659cd705e5812285cdb Mon Sep 17 00:00:00 2001 From: JulienDarve Date: Thu, 12 Sep 2024 12:10:38 -0700 Subject: [PATCH 26/29] update --- src/levanter/models/roberta.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/levanter/models/roberta.py b/src/levanter/models/roberta.py index 309cb1a70..014f583ea 100644 --- a/src/levanter/models/roberta.py +++ b/src/levanter/models/roberta.py @@ -13,7 +13,7 @@ from haliax.nn import cross_entropy_loss from haliax import Axis, AxisSpec, NamedArray from haliax.jax_utils import maybe_rng_split, named_call, shaped_rng_split -from haliax.nn.scan import BlockSeq +from haliax.nn.scan import BlockSeq, Stacked from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig from levanter.compat.torch_serialization import ( @@ -495,7 +495,7 @@ class RobertaEncoder(eqx.Module, StateDictSerializationMixin): @staticmethod def init(config: RobertaConfig, output_hidden_states: bool = False, *, key) -> "RobertaEncoder": - S = BlockSeq + S = Stacked layer = S.init(config.Layers, RobertaLayer, gradient_checkpointing=config.gradient_checkpointing)( config, From 0be9a83b9b2c6bef6983234adbe65f8a5c291e1d Mon Sep 17 00:00:00 2001 From: JulienDarve Date: Thu, 12 Sep 2024 12:13:58 -0700 Subject: [PATCH 27/29] update --- src/levanter/models/roberta.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/levanter/models/roberta.py b/src/levanter/models/roberta.py index 014f583ea..03804934e 100644 --- a/src/levanter/models/roberta.py +++ b/src/levanter/models/roberta.py @@ -495,7 +495,7 @@ class RobertaEncoder(eqx.Module, StateDictSerializationMixin): @staticmethod def init(config: RobertaConfig, output_hidden_states: bool = False, *, key) -> "RobertaEncoder": - S = Stacked + S = BlockFoldable layer = S.init(config.Layers, RobertaLayer, gradient_checkpointing=config.gradient_checkpointing)( config, @@ -515,12 +515,15 @@ def __call__( keys = maybe_rng_split(key, self.config.num_hidden_layers) if key is not None else None - x, intermediates = self.layer.scan(hidden_states, attention_mask, key=keys) + # x, intermediates = self.layer.scan(hidden_states, attention_mask, key=keys) + x = self.layer.fold(hidden_states, attention_mask, key=keys) - if not self.output_hidden_states: - return x, None - else: - 
return x, intermediates + return x, None + + # if not self.output_hidden_states: + # return x, None + # else: + # return x, intermediates def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): out = super().from_state_dict(state_dict, prefix=prefix) From 0c94a472d117cd09696d6b7613f389c5046334ca Mon Sep 17 00:00:00 2001 From: JulienDarve Date: Thu, 12 Sep 2024 12:14:51 -0700 Subject: [PATCH 28/29] update --- src/levanter/models/roberta.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/levanter/models/roberta.py b/src/levanter/models/roberta.py index 03804934e..85b4a593d 100644 --- a/src/levanter/models/roberta.py +++ b/src/levanter/models/roberta.py @@ -485,7 +485,8 @@ def __call__( # jax.debug.print("{layer_output}", layer_output=layer_output) - return (layer_output, layer_output) + # return (layer_output, layer_output) + return layer_output class RobertaEncoder(eqx.Module, StateDictSerializationMixin): @@ -495,7 +496,7 @@ class RobertaEncoder(eqx.Module, StateDictSerializationMixin): @staticmethod def init(config: RobertaConfig, output_hidden_states: bool = False, *, key) -> "RobertaEncoder": - S = BlockFoldable + S = BlockSeq layer = S.init(config.Layers, RobertaLayer, gradient_checkpointing=config.gradient_checkpointing)( config, From 7ae681d48f8696da11b870aac30e9e70bb38c90c Mon Sep 17 00:00:00 2001 From: Julien Darve Date: Thu, 12 Sep 2024 18:49:24 -0700 Subject: [PATCH 29/29] Training works! --- src/levanter/main/train_mlm.py | 1 + src/levanter/models/roberta.py | 69 +++++++++++++++++++++++----------- 2 files changed, 48 insertions(+), 22 deletions(-) diff --git a/src/levanter/main/train_mlm.py b/src/levanter/main/train_mlm.py index f0a7a1692..a54baf13d 100644 --- a/src/levanter/main/train_mlm.py +++ b/src/levanter/main/train_mlm.py @@ -100,6 +100,7 @@ def loss_function( # 1. Sets the device mesh # 2. Sets the axis mapping (for fsdp) # 3. Sets the global metrics tracker + # with Trainer(config.trainer, optimizer, loss_function) as trainer, jax.disable_jit(True): with Trainer(config.trainer, optimizer, loss_function) as trainer: # randomness in jax is tightly controlled by "keys" which are the states of the random number generators # this makes deterministic training pretty easy diff --git a/src/levanter/models/roberta.py b/src/levanter/models/roberta.py index 85b4a593d..816f7f1ad 100644 --- a/src/levanter/models/roberta.py +++ b/src/levanter/models/roberta.py @@ -352,6 +352,8 @@ def __call__( outputs = hax.rearrange(context_layer, ("... heads position head_size -> ... 
position (embed_att: heads head_size)"), heads=self.Heads, head_size=self.HeadSize) + # jax.debug.breakpoint() + return outputs class RobertaSelfOutput(eqx.Module, StateDictSerializationMixin): @@ -483,7 +485,7 @@ def __call__( intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output, key=k_o) - # jax.debug.print("{layer_output}", layer_output=layer_output) + # # jax.debug.print("{layer_output}", layer_output=layer_output) # return (layer_output, layer_output) return layer_output @@ -572,35 +574,36 @@ def init(Vocab: Axis, config: RobertaConfig, *, key) -> "RobertaEmbedding": def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0): mask = hax.not_equal(input_ids, self.padding_idx) * 1 incremental_indices = (hax.cumsum(mask, axis=self.Pos).astype(mask) + past_key_values_length) * mask - return incremental_indices + self.padding_idx + incremental_indices -= mask.all(axis=self.Pos) + return incremental_indices def create_position_ids_from_inputs_embeds(self, input_axes, PosInput): - # position_ids = hax.arange(axis = PosInput, start = 0, dtype=jnp.int32) - position_ids = hax.arange(axis = PosInput, start = self.padding_idx + 1, dtype=jnp.int32) + position_ids = hax.arange(axis = PosInput, start = 0, dtype=jnp.int32) + # position_ids = hax.arange(axis = PosInput, start = self.padding_idx + 1, dtype=jnp.int32) return hax.broadcast_to(position_ids, input_axes) @named_call def embed(self, input_ids=None, token_type_ids=None, position_ids=None, input_embeds=None, past_key_values_length=0, *, key = None): - if input_ids is not None: - print(f"input_ids: {input_ids.dtype}") - else: - print("input_ids: None") + # if input_ids is not None: + # print(f"input_ids: {input_ids.dtype}") + # else: + # print("input_ids: None") - if token_type_ids is not None: - print(f"token_type_ids: {token_type_ids.dtype}") - else: - print("token_type_ids: None") + # if token_type_ids is not None: + # print(f"token_type_ids: {token_type_ids.dtype}") + # else: + # print("token_type_ids: None") - if position_ids is not None: - print(f"position_ids: {position_ids.dtype}") - else: - print("position_ids: None") + # if position_ids is not None: + # print(f"position_ids: {position_ids.dtype}") + # else: + # print("position_ids: None") - if input_embeds is not None: - print(f"input_embeds: {input_embeds.dtype}") - else: - print("input_embeds: None") + # if input_embeds is not None: + # print(f"input_embeds: {input_embeds.dtype}") + # else: + # print("input_embeds: None") """ Note: When inputting your own embeds into input_embeds, make sure that the embeds axis has the name "embed" @@ -637,6 +640,9 @@ def embed(self, input_ids=None, token_type_ids=None, position_ids=None, input_em embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings, key=key) + + # jax.debug.breakpoint() + return embeddings class RobertaPooler(eqx.Module, StateDictSerializationMixin): @@ -727,8 +733,12 @@ def __call__( if attention_mask is None: attention_mask = hax.ones(input_axes) + # print(f"attention_mask: {attention_mask}") + # Attention mask from mask to actual numbers 0 -> -inf attention_mask = (attention_mask == 0) * jnp.finfo(jnp.bfloat16).min + + # print(f"attention_mask_real: {attention_mask}") embedding_output = self.embeddings.embed(input_ids, token_type_ids, position_ids, input_embeds, key=k_emb) @@ -738,6 +748,8 @@ def __call__( pooled_output = self.pooler(sequence_output, key=k_p) if self.pooler is not None else None + # 
jax.debug.breakpoint() + return (sequence_output, pooled_output) + encoder_outputs[1:] if self.output_hidden_states else (sequence_output, pooled_output) class RobertaLMHead(eqx.Module, StateDictSerializationMixin): @@ -779,6 +791,7 @@ class RobertaForMaskedLM(eqx.Module, StateDictSerializationMixin): roberta: RobertaModel lm_head: RobertaLMHead Vocab: Axis + Pos: Axis output_hidden_states: bool @classmethod @@ -791,7 +804,7 @@ def init(self, Vocab: Axis, config: RobertaConfig, output_hidden_states: bool = roberta = RobertaModel.init(Vocab, config, add_pooling_layer=False, output_hidden_states=output_hidden_states, key=key_rob) lm_head = RobertaLMHead.init(Vocab, config, key=key_head) - return RobertaForMaskedLM(roberta, lm_head, Vocab, output_hidden_states) + return RobertaForMaskedLM(roberta, lm_head, Vocab, config.Pos, output_hidden_states) def get_output_embeddings(self): return self.lm_head.decoder @@ -821,6 +834,9 @@ def __call__( k_rob, k_lm = maybe_rng_split(key, 2) + # print(f"input_ids: {input_ids}") + # print(f"attention_mask: {attention_mask}") + outputs = self.roberta( input_ids, attention_mask=attention_mask, @@ -830,8 +846,14 @@ def __call__( key=k_rob ) + # print(f"outputs: {outputs}") + prediction_scores = self.lm_head(outputs[0], key=k_lm) + # print(f"prediction_scores: {prediction_scores}") + + # jax.debug.breakpoint() + # return (prediction_scores,) + outputs[2:] return prediction_scores @@ -845,15 +867,18 @@ def compute_loss( ) -> jnp.ndarray | NamedArray: # logits = self(example.tokens, example.attn_mask, key=key)[0] logits = self(example.tokens, example.attn_mask, key=key) + # print(f"Logits: {logits}") logits = logits.astype(jnp.float32) targets = example.targets target_y = hax.nn.one_hot(targets, self.Vocab, dtype=logits.dtype) - target_y = jax.debug.breakpoint(token=target_y) + #target_y = jax.debug.breakpoint(token=target_y) loss = cross_entropy_loss( logits, self.Vocab, target_y, reduction, reduction_axis=reduction_axis, where=example.loss_mask ) + # print(f"loss: {loss}") + return loss def _rotate_half(x: NamedArray) -> NamedArray:
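
Note (sketch only): the compute_loss added in the final patch builds one-hot targets with hax.nn.one_hot and feeds them to haliax's cross_entropy_loss with where=example.loss_mask, so the loss is averaged over masked positions only. The plain-JAX restatement below shows the same objective without the haliax named-axis machinery; the function name, array shapes, and the 0/1 loss_mask convention are illustrative assumptions for this sketch, not Levanter's API.

import jax
import jax.numpy as jnp

def masked_lm_loss(logits, targets, loss_mask):
    # logits: [batch, pos, vocab] floats; targets: [batch, pos] token ids;
    # loss_mask: [batch, pos] with 1.0 at masked positions, 0.0 elsewhere.
    log_probs = jax.nn.log_softmax(logits.astype(jnp.float32), axis=-1)
    # negative log-likelihood of the target token at each position
    nll = -jnp.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    # average only over masked positions, mirroring reduction with where=loss_mask
    return jnp.sum(nll * loss_mask) / jnp.maximum(jnp.sum(loss_mask), 1.0)

With hax.mean as the reduction and where=example.loss_mask, the one-hot cross entropy in the patch computes the same quantity: the mean per-token negative log-likelihood over the masked positions.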