diff --git a/config/roberta-tiny.yaml b/config/roberta-tiny.yaml new file mode 100644 index 000000000..4b61ff7e4 --- /dev/null +++ b/config/roberta-tiny.yaml @@ -0,0 +1,39 @@ +data: + id: dlwh/wikitext_103_detokenized +# train_urls: +# - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" +# validation_urls: +# - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" + cache_dir: "cache/roberta-tiny" + tokenizer: "roberta-base" + +model: + type: roberta + vocab_size: 50265 + hidden_size: 32 + intermediate_size: 64 + num_hidden_layers: 4 + num_attention_heads: 2 + max_position_embeddings: 512 + hidden_act: "gelu" + hidden_dropout_prob: 0.1 + attention_probs_dropout_prob: 0.1 + gradient_checkpointing: true + +trainer: + tracker: + - type: wandb + project: "levanter" + tags: ["openwebtext", "roberta", "itest"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + per_device_parallelism: -1 + + train_batch_size: 32 + num_train_steps: 20000 + +optimizer: + learning_rate: 1E-3 + weight_decay: 0.1 + warmup: 0.01 \ No newline at end of file diff --git a/config/roberta.yaml b/config/roberta.yaml new file mode 100644 index 000000000..cea6bbb77 --- /dev/null +++ b/config/roberta.yaml @@ -0,0 +1,34 @@ +data: + id: dlwh/wikitext_103_detokenized + tokenizer: "roberta-base" + +model: + type: roberta + vocab_size: 50265 + hidden_size: 768 + intermediate_size: 3072 + num_hidden_layers: 12 + num_attention_heads: 12 + max_position_embeddings: 512 + hidden_act: "gelu" + hidden_dropout_prob: 0.1 + attention_probs_dropout_prob: 0.1 + gradient_checkpointing: true + +trainer: + tracker: + - type: wandb + project: "levanter" + tags: ["openwebtext", "roberta", "itest"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + per_device_parallelism: -1 + + train_batch_size: 32 + num_train_steps: 20000 + +optimizer: + learning_rate: 1E-3 + weight_decay: 0.1 + warmup: 0.01 \ No newline at end of file diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index c29e55e83..e29cee205 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -14,6 +14,7 @@ import equinox as eqx import fsspec import jax +import jax.numpy as jnp import numpy as np import pyarrow as pa import regex @@ -25,13 +26,11 @@ from levanter.data.mixture import MixtureDataset, StopStrategy -# intercept the logging nonsense here from levanter.logging import silence_transformer_nag # noqa from levanter.models.attention import AttentionMask -from levanter.models.lm_model import LmExample +from levanter.models.lm_model import MaskedLmExample, LmExample from levanter.utils.hf_utils import num_cpus_used_by_tokenizer - silence_transformer_nag() # noqa from transformers import BatchEncoding, PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast # noqa @@ -47,7 +46,6 @@ from levanter.shapes import NamedShapeSpec, ShapeSpec # noqa from levanter.utils.jax_utils import use_cpu_device # noqa - logger = logging.getLogger("levanter.data.text") # TASKS: @@ -58,6 +56,83 @@ DEFAULT_IGNORE_INDEX = -100 # Mirrors pytorch's default ignore index +class MaskedLmDataset(ShardableDataset[MaskedLmExample]): + def __init__( + self, + dataset: ShardableDataset[np.ndarray], + QPos: Axis, + KPos: Axis, + mask_token_id: int, + mask_prob: float = 0.15, + noise_prob: float = 0.1, + key: Optional[PRNGKeyArray] = None, + # ignore_index: Optional[int] = DEFAULT_IGNORE_INDEX, + ): + self.dataset = dataset + self.QPos = QPos + self.KPos = KPos + self.mask_prob = mask_prob + self.noise_prob = noise_prob + 
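+ # mask_prob is the fraction of positions selected for prediction (0.15, as in BERT/RoBERTa).
+ # Of the selected positions, ~80% are replaced with mask_token_id, noise_prob (10%) with a
+ # random token, and the remainder are left unchanged.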
self.key = key + self.mask_token_id = mask_token_id + + if self.mask_prob > 0.0 and self.key is None: + raise ValueError("must provide key if mask_prob > 0.0") + + def shard(self, shard_id: int, num_shards: int) -> "MaskedLmDataset": + return MaskedLmDataset( + self.dataset.shard(shard_id, num_shards), self.QPos, self.KPos, + self.mask_token_id, + self.mask_prob, self.noise_prob, self.key + ) + + def __iter__(self) -> Iterator[MaskedLmExample]: + key = self.key + sharding = jax.sharding.SingleDeviceSharding(jax.local_devices(backend="cpu")[0]) + + with use_cpu_device(): + @functools.partial(eqx.filter_jit, out_shardings=sharding) + def _create_mlm_example(tokens, key): + tokens_array = tokens.array + targets = tokens_array.copy() + + if self.mask_prob > 0: + this_key, key = jax.random.split(key) + mask_shape = tokens_array.shape + mask = jax.random.bernoulli(this_key, self.mask_prob, mask_shape) + + rand = jax.random.uniform(this_key, mask_shape) + mask_token = jnp.where(rand < 0.8, self.mask_token_id, tokens_array) + random_tokens = jax.random.randint(this_key, mask_shape, 0, tokens_array.max() + 1) + mask_token = jnp.where((rand >= 0.8) & (rand < 0.8 + self.noise_prob), random_tokens, mask_token) + masked_tokens = jnp.where(mask, mask_token, tokens_array) + + # Set targets to the original tokens where mask is True, otherwise set to mask_token_id + targets = jnp.where(mask, tokens_array, self.mask_token_id) + + masked_tokens_named = hax.named(masked_tokens, self.QPos) + targets_named = hax.named(targets, self.QPos) + + attn_mask_shape = (tokens_array.shape[0], tokens_array.shape[0]) + attn_mask = hax.named(jnp.ones(attn_mask_shape, dtype=jnp.bool_), (self.QPos, self.KPos)) + + example = MaskedLmExample.masked_lm(tokens=masked_tokens_named, targets=targets_named, mask_token_id=self.mask_token_id, attn_mask=attn_mask) + else: + targets_named = hax.named(targets, self.QPos) + attn_mask_shape = (tokens_array.shape[0], tokens_array.shape[0]) + attn_mask = hax.named(jnp.ones(attn_mask_shape, dtype=jnp.bool_), (self.QPos, self.KPos)) + + example = MaskedLmExample.masked_lm(tokens=tokens, targets=targets_named, mask_token_id=self.mask_token_id, attn_mask=attn_mask) + + return example + + for tokens in self.dataset: + tokens_array = jnp.array(tokens) + tokens_named = hax.named(tokens_array, self.QPos) + example = _create_mlm_example(tokens_named, key) + yield example + + class CausalLmDataset(ShardableDataset[LmExample]): def __init__( @@ -89,7 +164,6 @@ def __iter__(self) -> Iterator[LmExample]: sharding = jax.sharding.SingleDeviceSharding(jax.local_devices(backend="cpu")[0]) with use_cpu_device(): - @functools.partial(eqx.filter_jit, out_shardings=sharding) def _create_lm_example(tokens, key): tokens = hax.named(tokens, self.QPos) @@ -97,10 +171,6 @@ def _create_lm_example(tokens, key): example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id) if self.fcm_prob > 0: - # masks for attention - # We support forgetful causal masking (FCM) which is a technique that improves training speed by - # randomly masking out some of the context. This is a bit like dropout, but it's applied to the attention - # mask instead of the activations. 
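For reference, the masking rule implemented in `_create_mlm_example` above can be written as a small standalone JAX function. This is only a sketch: unlike the dataset code it splits the PRNG key per draw instead of reusing `this_key` for all three, and it takes `vocab_size` explicitly rather than deriving it from `tokens_array.max() + 1`.

```python
import jax
import jax.numpy as jnp


def bert_style_mask(tokens: jnp.ndarray, key, mask_token_id: int, vocab_size: int,
                    mask_prob: float = 0.15, noise_prob: float = 0.1):
    """Return (masked_tokens, targets) following the rule used by MaskedLmDataset."""
    k_select, k_kind, k_rand = jax.random.split(key, 3)

    # positions chosen for prediction
    selected = jax.random.bernoulli(k_select, mask_prob, tokens.shape)

    # of the selected positions: kind < 0.8 -> mask token,
    # 0.8 <= kind < 0.8 + noise_prob -> random token, the rest stay unchanged
    kind = jax.random.uniform(k_kind, tokens.shape)
    random_tokens = jax.random.randint(k_rand, tokens.shape, 0, vocab_size)
    replacement = jnp.where(kind < 0.8, mask_token_id, tokens)
    replacement = jnp.where((kind >= 0.8) & (kind < 0.8 + noise_prob), random_tokens, replacement)

    masked_tokens = jnp.where(selected, replacement, tokens)
    # targets keep the original token at selected positions and mask_token_id elsewhere,
    # which is what MaskedLmExample.masked_lm later uses to build its loss mask
    targets = jnp.where(selected, tokens, mask_token_id)
    return masked_tokens, targets


# Example (the <mask> token has id 50264 in the roberta-base tokenizer):
# masked, targets = bert_style_mask(jnp.array([11, 42, 7, 8]), jax.random.PRNGKey(0),
#                                   mask_token_id=50264, vocab_size=50265)
```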
It's described in https://arxiv.org/abs/2210.13432 assert self.key is not None this_key, key = jax.random.split(key) fcm_mask = hax.nn.attention.forgetful_causal_mask(self.KPos, self.fcm_prob, key=this_key) @@ -114,6 +184,7 @@ def _create_lm_example(tokens, key): yield example + class TokenSeqDataset(ShardableDataset[np.ndarray]): """ A dataset that yields sequences of tokens of fixed length from a TokenizedDocumentCache. @@ -826,4 +897,4 @@ def build_caches( @property def sources(self) -> dict[str, LMDatasetSourceConfig]: - return self.configs + return self.configs \ No newline at end of file diff --git a/src/levanter/main/train_mlm.py b/src/levanter/main/train_mlm.py new file mode 100644 index 000000000..a54baf13d --- /dev/null +++ b/src/levanter/main/train_mlm.py @@ -0,0 +1,219 @@ +# train_mlm.py + +import dataclasses +import functools +import gc +import logging +import os +from dataclasses import dataclass, field +from typing import Optional, Union + +import jax +import jax.random as jrandom + +import haliax as hax +from haliax import Axis +from haliax.partitioning import named_jit, round_axis_for_partitioning + +import levanter +from levanter import callbacks +from levanter.compat.hf_checkpoints import HFCompatConfig, save_hf_checkpoint_callback +from levanter.data.text import MaskedLmDataset, LMDatasetConfig, LMMixtureDatasetConfig +from levanter.models.gpt2 import Gpt2Config +from levanter.models.llama import LlamaConfig +from levanter.models.lm_model import LmConfig, MaskedLmExample, compute_next_token_loss +from levanter.models.roberta import RobertaConfig +from levanter.optim import AdamConfig, OptimizerConfig +from levanter.trainer import Trainer, TrainerConfig +from levanter.utils.jax_utils import parameter_count + +logger = logging.getLogger(__name__) + +@dataclass +class TrainMlmConfig: + data: Union[LMDatasetConfig, LMMixtureDatasetConfig] = field(default_factory=LMDatasetConfig) + trainer: TrainerConfig = field(default_factory=TrainerConfig) + model: LmConfig = field(default_factory=RobertaConfig) + optimizer: OptimizerConfig = field(default_factory=AdamConfig) + + # config related to continued pretraining + initialize_from_hf: Union[bool, str] = False + """if provided, this will override the model config in the config. if true, use the default hf checkpoint for this model class""" + use_hf_model_config: bool = False # if true, replace the model config with the hf config from the checkpoint + + # TODO: atm we don't support loading from a checkpoint that has a different tokenizer. this is a bit annoying + # TODO: atm you have to at least specify a levanter model config with the same type as the hf checkpoint + + mlm_prob: float = 0.15 # masking probability for MLM + hf_save_path: Optional[str] = None + hf_upload: Optional[str] = None + hf_save_steps: int = 10000 + + update_hessian_steps: int = 10 + data_seed: Optional[int] = None # if provided, will override the data seed from the trainer + +def main(config: TrainMlmConfig): + tokenizer = config.data.the_tokenizer + + # this is some unpleasant code to allow us to initialize from a hf checkpoint. 
If this is your first read through, + # I recommend skipping it for now + if config.initialize_from_hf: + if config.trainer.initialize_from is not None: + raise ValueError("Cannot specify both initialize_from_hf and initialize_from") + + assert isinstance(config.model, HFCompatConfig) + converter = config.model.hf_checkpoint_converter() + if hasattr(tokenizer, "vocab") and tokenizer.vocab != converter.tokenizer.vocab: + logger.warning("The tokenizers appear to be different. You may want to check this.") + + if isinstance(config.initialize_from_hf, str): + converter = converter.replaced(reference_checkpoint=config.initialize_from_hf, tokenizer=tokenizer) + else: + converter = converter.replaced(tokenizer=tokenizer) + + if config.use_hf_model_config: + # TODO: log diff of old and new config + # NB: gross mutability + config.model = converter.config_from_hf_config(converter.default_hf_config) + elif isinstance(config.model, HFCompatConfig): + converter = config.model.hf_checkpoint_converter() + converter = converter.replaced(tokenizer=tokenizer) + else: + converter = None + + levanter.initialize(config) + optimizer = config.optimizer.build(config.trainer.num_train_steps) + # loss_function = functools.partial(compute_next_token_loss) + + def loss_function( + model, + example: MaskedLmExample, + *, + key=None, + reduction: Optional[hax.ReductionFunction] = hax.mean, + reduction_axis: Optional[hax.AxisSelection] = None, + ): + return model.compute_loss(example, key=key, reduction=reduction, reduction_axis=reduction_axis) + + + # Using the trainer as a context manager does 3 things: + # 1. Sets the device mesh + # 2. Sets the axis mapping (for fsdp) + # 3. Sets the global metrics tracker + # with Trainer(config.trainer, optimizer, loss_function) as trainer, jax.disable_jit(True): + with Trainer(config.trainer, optimizer, loss_function) as trainer: + # randomness in jax is tightly controlled by "keys" which are the states of the random number generators + # this makes deterministic training pretty easy + seed = config.trainer.seed + data_key, loader_key, model_key, training_key = jrandom.split(jrandom.PRNGKey(seed), 4) + + if config.data_seed is not None: + logger.info(f"Overriding data seed with {config.data_seed}") + data_key = jrandom.PRNGKey(config.data_seed) + + # We have two axis_mappings: one for storing the model and optimizer states, and one for compute + # This allows Zero-3-style parameter sharding, where we shard the parameters and optimizer state across the mesh + compute_axis_mapping = trainer.compute_axis_mapping + parameter_axis_mapping = trainer.parameter_axis_mapping + + # some axes we need + Batch = config.trainer.TrainBatch + EvalBatch = config.trainer.EvalBatch + Pos = config.model.Pos + KeyPos = config.model.KeyPos + + tagged_eval_datasets = config.data.tagged_eval_sets(Pos.size) + mask_id = tokenizer.mask_token_id + train_dataset = MaskedLmDataset( + config.data.train_set(Pos.size, key=data_key), Pos, KeyPos, + mask_token_id=mask_id, + mask_prob=config.mlm_prob, key=data_key, #ignore_index=config.data.ignore_token_id + ) + + # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to + # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of + # tokens: gpt-2 has 50257, for example. So we round up. 
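To make the rounding concrete, here is the arithmetic only; the actual work is done by `round_axis_for_partitioning` against the parameter axis mapping, and the multiple of 128 below is just an assumed shard size.

```python
def round_up(size: int, multiple: int = 128) -> int:
    """Round `size` up to the next multiple of `multiple` (illustrative only)."""
    return ((size + multiple - 1) // multiple) * multiple

print(round_up(50257))  # gpt-2:        50257 -> 50304
print(round_up(50265))  # roberta-base: 50265 -> 50304
```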
+ vocab_size = len(tokenizer) + Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), parameter_axis_mapping) + if vocab_size != Vocab.size: + logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning") + + state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key)) + + if int(state.step) == 0: + # TODO: I don't love that we init the model twice, but it's not a big deal i think? + if config.initialize_from_hf: + # initialize from an hf pretrained model + logger.info( + "No training checkpoint found. Initializing model from HF checkpoint" + f" '{converter.reference_checkpoint}'" + ) + # this is a bit gross, but we want to free up the memory from the model we just built + state = dataclasses.replace(state, model=None) + gc.collect() + model = converter.load_pretrained( + config.model.model_type, + config.model, + axis_mapping=parameter_axis_mapping, + dtype=trainer.mp.compute_dtype, + ) + model = named_jit(trainer.mp.cast_to_param, parameter_axis_mapping)(model) + state = dataclasses.replace(state, model=model) + else: + logger.info("No checkpoint found. Starting from scratch.") + + levanter.tracker.log_summary({"parameter_count": parameter_count(state.model)}) + + if len(tagged_eval_datasets) == 0: + logger.warning("No evaluation datasets provided.") + else: + masked_datasets = [ + (MaskedLmDataset(ds, Pos, KeyPos, mask_prob=config.mlm_prob, key=data_key, mask_token_id=mask_id), tags) + for ds, tags in tagged_eval_datasets + ] + max_eval_examples_per_ds = config.trainer.max_eval_batches + if max_eval_examples_per_ds is not None: + max_eval_examples_per_ds *= config.trainer.eval_batch_size + + cb = levanter.eval.cb_tagged_lm_evaluate( + EvalBatch, masked_datasets, trainer.device_mesh, compute_axis_mapping, max_eval_examples_per_ds + ) + trainer.add_hook(cb, every=config.trainer.steps_per_eval) + + flops_per_token = config.model.flops_per_token(vocab_size) + flops_per_example = 3 * flops_per_token * Pos.size if flops_per_token is not None else None + trainer.add_hook( + callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size, flops_per_example), every=1 + ) + if config.hf_save_path is not None: + full_save_path = os.path.join(config.hf_save_path, trainer.run_id) + + trainer.add_hook( + save_hf_checkpoint_callback(full_save_path, converter, upload_to_hf=config.hf_upload or False), + every=config.hf_save_steps, + ) + + # visualize log probs + @named_jit( + in_axis_resources=parameter_axis_mapping, + axis_resources=compute_axis_mapping, + out_axis_resources=compute_axis_mapping, + ) + def compute_log_probs(model, example): + model = trainer.mp.cast_to_compute(model) + logprobs = model.compute_loss(example, key=None, reduction=None) + # roll forward to get the loss for each predicted token + logprobs = hax.roll(logprobs, 1, Pos) + return logprobs.rearrange((EvalBatch, Pos)).array + + train_loader = iter(trainer.sharded_loader(train_dataset, Batch)) + + if int(state.step) > 0: + import tqdm + for _ in tqdm.tqdm(range(state.step), desc="seeking data for resume"): + next(train_loader) + + trainer.train(state, train_loader) + +if __name__ == "__main__": + levanter.config.main(main)() \ No newline at end of file diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index 468f6a4a4..926384cab 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -3,6 +3,7 @@ import draccus import equinox as eqx +import jax import jax.numpy as jnp from jax.random 
import PRNGKey @@ -16,6 +17,36 @@ LmConfigT = TypeVar("LmConfigT", bound="LmConfig") LmT = TypeVar("LmT", bound="LmHeadModel") +class MaskedLmExample(eqx.Module): + tokens: hax.NamedArray + loss_mask: hax.NamedArray + attn_mask: hax.NamedArray + targets: Optional[hax.NamedArray] = None + + @staticmethod + def masked_lm( + tokens: hax.NamedArray, targets: hax.NamedArray, attn_mask: hax.NamedArray, mask_token_id: Optional[int] = None + ) -> "MaskedLmExample": + if tokens.ndim != 1: + raise ValueError("tokens must be a 1D array") + + if not jnp.issubdtype(tokens.dtype, jnp.integer): + raise ValueError("tokens must be an integer array") + + if tokens.shape != targets.shape: + raise ValueError("tokens and targets must have the same shape") + + Pos = tokens.axes[0] + + mask = tokens.array != targets.array + loss_mask = hax.named(mask.astype(jnp.float32), Pos) + + if mask_token_id is not None: + ignore_mask = targets.array != mask_token_id + loss_mask = loss_mask * hax.named(ignore_mask.astype(jnp.float32), Pos) + + return MaskedLmExample(tokens=tokens, targets=targets, loss_mask=loss_mask, attn_mask=attn_mask) + class LmExample(eqx.Module): tokens: hax.NamedArray @@ -34,12 +65,10 @@ def causal( Pos = tokens.axes[0] - # don't predict the last token. if loss_mask is None: loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32) if ignore_id is not None: - # we don't compute loss for any tokens matching the ignore index ignore_mask = hax.roll(tokens, -1, Pos) != ignore_id loss_mask = loss_mask * ignore_mask @@ -47,7 +76,6 @@ def causal( return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask) -# TODO: for some reason, mypy doesn't like the discover_packages_path argument? class LmConfig(draccus.PluginRegistry, abc.ABC, Generic[LmT], discover_packages_path="levanter.models"): # type: ignore @property @abc.abstractmethod @@ -70,12 +98,7 @@ def flops_per_token(self, vocab_size: int) -> Optional[float]: def build(self, Vocab: Axis, *, key: PRNGKey) -> "LmT": return self.model_type.init(Vocab, self, key=key) # type: ignore - class LmHeadModel(Generic[LmConfigT], abc.ABC): - """ - Superclass for models with a language modeling head. - """ - @property @abc.abstractmethod def config(self) -> LmConfigT: @@ -103,14 +126,11 @@ def init(cls, Vocab: Axis, config: LmConfigT, *, key: PRNGKey) -> "LmHeadModel[L def __call__( self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None ) -> NamedArray: - pass + print(f"input_ids shape: {input_ids.shape}") + print(f"attn_mask shape: {attn_mask.shape}") @abc.abstractmethod def resize_vocab(self, new_size: int, key: Optional[PRNGKey] = None) -> "LmHeadModel[LmConfigT]": - """ - Resizes the vocabulary of the model. Key may be provided to use random initialization, otherwise, there - should be some deterministic initialization of any new parameters. - """ pass @property @@ -149,3 +169,35 @@ def compute_next_token_loss( ) return loss + +def compute_next_token_loss( + model: LmHeadModel, + example: LmExample, + *, + key=None, + reduction: Optional[hax.ReductionFunction] = hax.mean, + reduction_axis: Optional[hax.AxisSelection] = None, + logsumexp_weight: Optional[float] = None, + loss_dtype: Optional[Type[jnp.dtype]] = jnp.float32, +) -> jnp.ndarray | NamedArray: + """ + Computes the cross-entropy loss for a language modeling example. If reduction is not None, the loss is reduced + across the reduction axis (with reduction_axis=None meaning all axes). 
If reduction is None, the loss is not + reduced, and the result is a named array with axes (*batch axes, sequence_length). + """ + logits = model(example.tokens, example.attn_mask, key=key) + if loss_dtype is not None: + logits = logits.astype(loss_dtype) + + loss = next_token_loss( + model.Pos, + model.Vocab, + logits, + example.tokens, + loss_mask=example.loss_mask, + reduction=reduction, + reduction_axis=reduction_axis, + logsumexp_weight=logsumexp_weight, + ) + + return loss \ No newline at end of file diff --git a/src/levanter/models/roberta.py b/src/levanter/models/roberta.py new file mode 100644 index 000000000..816f7f1ad --- /dev/null +++ b/src/levanter/models/roberta.py @@ -0,0 +1,921 @@ +import dataclasses +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Tuple, Type, Union + +import equinox as eqx +import jax +import jax.numpy as jnp +import jax.random as jrandom +from jaxtyping import PRNGKeyArray + +import haliax as hax +import haliax.nn as hnn +from haliax.nn import cross_entropy_loss +from haliax import Axis, AxisSpec, NamedArray +from haliax.jax_utils import maybe_rng_split, named_call, shaped_rng_split +from haliax.nn.scan import BlockSeq, Stacked + +from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig +from levanter.compat.torch_serialization import ( + StateDict, + StateDictSerializationMixin, + apply_prefix, + flatten_linear_layers, + stack_state_dict, + unflatten_linear_layers, + unstack_state_dict, +) +from levanter.logging import silence_transformer_nag +from levanter.models.attention import AttentionBackend, AttentionMask, simple_attention_with_dropout +from levanter.models.gpt2 import ACT2FN +from levanter.models.lm_model import LmConfig, LmHeadModel, MaskedLmExample +from levanter.types import BlockFoldable +from levanter.utils.flop_utils import lm_flops_per_token + +silence_transformer_nag() +from transformers import PretrainedConfig as HfConfig +from transformers import RobertaConfig as HfRobertaConfig + + + +@LmConfig.register_subclass("roberta") +@dataclass(frozen=True) +class RobertaConfig(HFCompatConfig): + r""" + + Adapted from HuggingFace RobertaConfig, description below + + + This is the configuration class to store the configuration of a [`RobertaModel`] or a [`TFRobertaModel`]. It is + used to instantiate a RoBERTa model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the RoBERTa + [FacebookAI/roberta-base](https://huggingface.co/FacebookAI/roberta-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50265): + Vocabulary size of the RoBERTa model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. 
+ intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658). + is_decoder (`bool`, *optional*, defaults to `False`): + Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + classifier_dropout (`float`, *optional*): + The dropout ratio for the classification head. 
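In addition to the HF example below, a sketch of constructing the Levanter-side config directly (values match config/roberta-tiny.yaml; the import path assumes this file lands at src/levanter/models/roberta.py as in this diff):

```python
from levanter.models.roberta import RobertaConfig

tiny = RobertaConfig(
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=4,
    num_attention_heads=2,
    max_position_embeddings=512,
)
assert tiny.Embed.size == 32
assert tiny.HeadSize.size == 16   # hidden_size // num_attention_heads
assert tiny.Pos.size == 512

hf_cfg = tiny.to_hf_config(vocab_size=50265)  # round-trip to transformers' RobertaConfig
```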
+ + Examples: + + ```python + >>> from transformers import RobertaConfig, RobertaModel + + >>> # Initializing a RoBERTa configuration + >>> configuration = RobertaConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = RobertaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + vocab_size: int = 50265 + hidden_size: int = 768 + num_hidden_layers: int = 12 + num_attention_heads: int = 12 + intermediate_size: int = 3072 + hidden_act: str = "gelu" + hidden_dropout_prob: float = 0.1 + attention_probs_dropout_prob: float = 0.1 + max_position_embeddings: int = 512 + type_vocab_size: int = 2 + initializer_range: float = 0.02 + layer_norm_eps: float = 1e-12 + pad_token_id: int = 1 + bos_token_id: int = 0 + eos_token_id: int = 2 + position_embedding_type: Optional[str] = "absolute" + use_cache: bool = False + classifier_dropout: Optional[float] = None + + scan_layers: bool = True + gradient_checkpointing: bool = True + + reference_checkpoint: str = "FacebookAI/roberta-base" + tokenizer: Optional[str] = None + + # Axes + Pos = property(lambda self: Axis(name="position", size=self.max_position_embeddings)) + KeyPos = property(lambda self: self.Pos.alias("key_position")) + Embed = property(lambda self: Axis(name="embed", size=self.hidden_size)) + EmbedAtt = property(lambda self: self.Embed.alias("embed_att")) + FinalEmbed = property(lambda self: self.Embed.alias("final_embed")) + Heads = property(lambda self: Axis(name="heads", size=self.num_attention_heads)) + Layers = property(lambda self: Axis(name="layers", size=self.num_hidden_layers)) + Mlp = property(lambda self: Axis(name="mlp", size=self.intermediate_size)) + HeadSize = property(lambda self: Axis(name="head_size", size=self.hidden_size // self.num_attention_heads)) + + + @classmethod + def from_hf_config(cls, hf_config: HfConfig) -> "RobertaConfig": + return RobertaConfig( + vocab_size = hf_config.vocab_size, + hidden_size = hf_config.hidden_size, + num_hidden_layers = hf_config.num_hidden_layers, + num_attention_heads = hf_config.num_attention_heads, + intermediate_size = hf_config.intermediate_size, + hidden_act = hf_config.hidden_act, + hidden_dropout_prob= hf_config.hidden_dropout_prob, + attention_probs_dropout_prob = hf_config.attention_probs_dropout_prob, + max_position_embeddings = hf_config.max_position_embeddings, + type_vocab_size = hf_config.type_vocab_size, + initializer_range = hf_config.initializer_range, + layer_norm_eps = hf_config.layer_norm_eps, + pad_token_id = hf_config.pad_token_id, + bos_token_id = hf_config.bos_token_id, + eos_token_id = hf_config.eos_token_id, + position_embedding_type = hf_config.position_embedding_type, + use_cache = hf_config.use_cache, + classifier_dropout = hf_config.classifier_dropout, + ) + + def hf_checkpoint_converter(self) -> HFCheckpointConverter["RobertaConfig"]: # type: ignore + return HFCheckpointConverter( + self.__class__, + reference_checkpoint=self.reference_checkpoint, + trust_remote_code=True, + tokenizer=self.tokenizer if self.tokenizer else self.reference_checkpoint, + HfConfigClass=HfRobertaConfig, + ) + + def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None) -> HfRobertaConfig: + """Convert to HuggingFace's LlamaConfig + + Args: + vocab_size (int, optional): Vocabulary size of the tokenizer. Defaults to 32000. + config_overrides (dict, optional): Overrides for the config. Defaults to None. 
+ + Returns: + HfRobertaConfig: HuggingFace's RobertaConfig + """ + + if config_overrides is None: + config_overrides = {} + + return HfRobertaConfig( + vocab_size = vocab_size, + hidden_size = self.hidden_size, + num_hidden_layers = self.num_hidden_layers, + num_attention_heads = self.num_attention_heads, + intermediate_size = self.intermediate_size, + hidden_act = self.hidden_act, + hidden_dropout_prob = self.hidden_dropout_prob, + attention_probs_dropout_prob = self.attention_probs_dropout_prob, + max_position_embeddings = self.max_position_embeddings, + type_vocab_size = self.type_vocab_size, + initializer_range = self.initializer_range, + layer_norm_eps = self.layer_norm_eps, + pad_token_id = self.pad_token_id, + bos_token_id = self.bos_token_id, + eos_token_id = self.eos_token_id, + position_embedding_type = self.position_embedding_type, + use_cache = self.use_cache, + classifier_dropout = self.classifier_dropout, + ) + + @property + def model_type(self) -> Type["RobertaForMaskedLM"]: + return RobertaForMaskedLM + + def flops_per_token(self, vocab_size: int): + return lm_flops_per_token( + hidden_dim=self.hidden_size, + intermediate_dim=self.intermediate_size, + num_layers=self.num_hidden_layers, + num_kv_heads=self.num_attention_heads, + num_heads=self.num_attention_heads, + seq_len=self.max_position_embeddings, + vocab_size=vocab_size, + glu=True, + ) + +class RobertaSelfAttention(eqx.Module, StateDictSerializationMixin): + + config: RobertaConfig + Heads: Axis + HeadSize: Axis + EmbedAtt: Axis + + q_proj: hnn.Linear + k_proj: hnn.Linear + v_proj: hnn.Linear + + dropout: hnn.Dropout + position_embedding_type: Optional[str] + + Pos: Axis + KeyPos: Axis + distance_embedding: Optional[hnn.Embedding] + + @staticmethod + def init(config: RobertaConfig, *, key) -> "RobertaSelfAttention": + Embed = config.Embed + EmbedAtt = config.EmbedAtt + + k_q, k_k, k_v, k_e = jrandom.split(key, 4) + q_proj = hnn.Linear.init(In=Embed, Out=EmbedAtt, key=k_q, out_first=True) + k_proj = hnn.Linear.init(In=Embed, Out=EmbedAtt, key=k_k, out_first=True) + v_proj = hnn.Linear.init(In=Embed, Out=EmbedAtt, key=k_v, out_first=True) + + dropout = hnn.Dropout(config.attention_probs_dropout_prob) + + distance_embedding = None + position_embedding_type = config.position_embedding_type + + if position_embedding_type == "relative_key" or position_embedding_type == "relative_key_query": + RelPos = Axis("rel_pos", 2 * config.max_position_embeddings - 1) + distance_embedding = hnn.Embedding.init(RelPos, config.HeadSize, k_e) + + return RobertaSelfAttention(config, config.Heads, config.HeadSize, EmbedAtt, + q_proj, k_proj, v_proj, + dropout, position_embedding_type, + config.Pos, config.KeyPos, distance_embedding, + ) + + def _state_dict_key_map(self) -> Dict[str, Optional[str]]: + return {"q_proj": "query", "k_proj": "key", "v_proj": "value"} + + def _rope_scale_factor(self) -> float: + # hasattr for gemma and I'm feeling lazy + if hasattr(self.config, "rope_scaling") and self.config.rope_scaling is not None: + assert self.config.rope_scaling["type"] == "linear" + return self.config.rope_scaling["factor"] + return 1.0 + + def transpose_for_scores(self, x: NamedArray) -> NamedArray: + # Makes sure to have the correct output order as well + y = hax.rearrange(x, "... position (embed_att: heads head_size) -> ... 
heads position head_size", heads=self.Heads, head_size=self.HeadSize) + return y + + @named_call + def __call__( + self, + hidden_states: NamedArray, + attention_mask: Optional[NamedArray] = None, + *, + key = None + ) -> Tuple[NamedArray]: + + query_layer = self.transpose_for_scores(self.q_proj(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(hidden_states)) + value_layer = self.transpose_for_scores(self.v_proj(hidden_states)) + + if self.position_embedding_type == "rope": + cos, sin = llama_rotary_pos_emb( + self.config.HeadSize, hidden_states.resolve_axis("position"), scale=self._rope_scale_factor() + ) + query_layer, key_layer = _apply_rotary_pos_emb(query_layer, key_layer, cos, sin) + + key_layer = key_layer.rename({"position": "key_position"}) + value_layer = value_layer.rename({"position": "key_position"}) + + attention_scores = hax.dot(query_layer, key_layer, axis=self.HeadSize) # aka hax.einsum("bhld, bhrd -> bhlr") + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + Left = self.Pos # Queries + Right = self.KeyPos # Keys + + position_ids_l = hax.arange(Left).broadcast_to((Left,Right)) + position_ids_r = hax.arange(Right).broadcast_to((Left,Right)) + + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.Pos.size) + + if self.position_embedding_type == "relative_key": + relative_position_scores = hax.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = hax.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = hax.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores /= jnp.sqrt(self.HeadSize.size) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function) + # Attention_mask should have shape Batch Pos, so it should broadcast to shape Batch Heads Pos KeyPos for summation + attention_scores = attention_scores + attention_mask + + attention_probs = hnn.softmax(attention_scores, axis=self.KeyPos) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs, key=key) + + hax.dot(query_layer, key_layer, axis=self.HeadSize) + + context_layer = hax.dot(attention_probs, value_layer, axis=self.KeyPos) + + outputs = hax.rearrange(context_layer, ("... heads position head_size -> ... 
position (embed_att: heads head_size)"), heads=self.Heads, head_size=self.HeadSize) + + # jax.debug.breakpoint() + + return outputs + +class RobertaSelfOutput(eqx.Module, StateDictSerializationMixin): + dense: hnn.Linear + LayerNorm: hnn.LayerNorm + dropout: hnn.Dropout + + @staticmethod + def init(config: RobertaConfig, *, key) -> "RobertaSelfOutput": + Embed = config.Embed + EmbedAtt = config.EmbedAtt + dense = hnn.Linear.init(In=EmbedAtt, Out=Embed, key=key, out_first=True) + LayerNorm = hnn.LayerNorm.init(axis=Embed, eps=config.layer_norm_eps) + dropout = hnn.Dropout(config.hidden_dropout_prob) + return RobertaSelfOutput(dense, LayerNorm, dropout) + + @named_call + def __call__(self, hidden_states: NamedArray, input: NamedArray,*, key) -> NamedArray: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, key=key) + hidden_states = self.LayerNorm(hidden_states + input) + return hidden_states + +class RobertaAttention(eqx.Module, StateDictSerializationMixin): + self_attn: RobertaSelfAttention + output: RobertaSelfOutput + + @staticmethod + def init(config: RobertaConfig, *, key) -> "RobertaAttention": + k_a, k_o = jrandom.split(key, 2) + + self_attn = RobertaSelfAttention.init(config, key=k_a) + output = RobertaSelfOutput.init(config, key=k_o) + + return RobertaAttention(self_attn, output) + + def _state_dict_key_map(self) -> Dict[str, Optional[str]]: + return {"self_attn": "self"} + + @named_call + def __call__( + self, + hidden_states: NamedArray, + attention_mask: Optional[NamedArray] = None, + *, + key + ) -> NamedArray: + k_a, k_o = maybe_rng_split(key, 2) + + self_outputs = self.self_attn( + hidden_states, + attention_mask, + key=k_a + ) + attention_output = self.output(self_outputs, hidden_states, key=k_o) + return attention_output + +class RobertaIntermediate(eqx.Module, StateDictSerializationMixin): + dense: hnn.Linear + intermediate_act_fn: Callable = eqx.static_field() + + @staticmethod + def init(config, *, key) -> "RobertaIntermediate": + dense = hnn.Linear.init(config.Embed, config.Mlp, key=key, out_first=True) + if isinstance(config.hidden_act, str): + intermediate_act_fn = ACT2FN[config.hidden_act] + else: + intermediate_act_fn = config.hidden_act + + return RobertaIntermediate(dense, intermediate_act_fn) + + @named_call + def __call__(self, hidden_states: NamedArray, *, key = None) -> NamedArray: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + +class RobertaOutput(eqx.Module, StateDictSerializationMixin): + dense: hnn.Linear + LayerNorm: hnn.LayerNorm + dropout: hnn.Dropout + + @staticmethod + def init(config: RobertaConfig, *, key) -> "RobertaSelfOutput": + Embed = config.Embed + dense = hnn.Linear.init(In=config.Mlp, Out=Embed, key=key, out_first=True) + LayerNorm = hnn.LayerNorm.init(axis=Embed, eps=config.layer_norm_eps) + dropout = hnn.Dropout(config.hidden_dropout_prob) + return RobertaSelfOutput(dense, LayerNorm, dropout) + + @named_call + def __call__(self, hidden_states: NamedArray, input: NamedArray, *, key) -> NamedArray: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, key=key) + hidden_states = self.LayerNorm(hidden_states + input) + return hidden_states + +class RobertaLayer(eqx.Module, StateDictSerializationMixin): + attention: RobertaAttention + intermediate: RobertaIntermediate + output: RobertaOutput + + @staticmethod + def init(config: RobertaConfig, *, key) -> "RobertaLayer": + k_a, k_i, k_o = 
jrandom.split(key, 3) + + attention = RobertaAttention.init(config, key=k_a) + intermediate = RobertaIntermediate.init(config, key=k_i) + output = RobertaOutput.init(config, key=k_o) + + return RobertaLayer(attention, intermediate, output) + + @named_call + def __call__( + self, + hidden_states: NamedArray, + attention_mask: Optional[NamedArray] = None, + *, + key + ) -> Tuple[NamedArray]: + k_a, k_o = maybe_rng_split(key, 2) + + attention_output = self.attention( + hidden_states, + attention_mask, + key=k_a, + ) + + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output, key=k_o) + + # # jax.debug.print("{layer_output}", layer_output=layer_output) + + # return (layer_output, layer_output) + return layer_output + + +class RobertaEncoder(eqx.Module, StateDictSerializationMixin): + config: RobertaConfig + layer: BlockFoldable[RobertaLayer] + output_hidden_states: bool + + @staticmethod + def init(config: RobertaConfig, output_hidden_states: bool = False, *, key) -> "RobertaEncoder": + S = BlockSeq + + layer = S.init(config.Layers, RobertaLayer, gradient_checkpointing=config.gradient_checkpointing)( + config, + key=shaped_rng_split(key, config.num_hidden_layers), #TODO: config.gradient_checkpointing + ) + + return RobertaEncoder(config, layer, output_hidden_states) + + @named_call + def __call__( + self, + hidden_states: NamedArray, + attention_mask: Optional[NamedArray] = None, + *, + key + ) -> Tuple[NamedArray]: + + keys = maybe_rng_split(key, self.config.num_hidden_layers) if key is not None else None + + # x, intermediates = self.layer.scan(hidden_states, attention_mask, key=keys) + x = self.layer.fold(hidden_states, attention_mask, key=keys) + + return x, None + + # if not self.output_hidden_states: + # return x, None + # else: + # return x, intermediates + + def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): + out = super().from_state_dict(state_dict, prefix=prefix) + return out + + def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: + my_state_dict: StateDict = {} + super().update_state_dict(my_state_dict, prefix=prefix) + + state_dict.update(my_state_dict) + + return state_dict + +class RobertaEmbedding(eqx.Module, StateDictSerializationMixin): + Vocab: Axis = eqx.static_field() + Pos: Axis = eqx.static_field() + + word_embeddings: hnn.Embedding + position_embeddings: hnn.Embedding + token_type_embeddings: Optional[hnn.Embedding] + padding_idx: NamedArray + + LayerNorm: hnn.LayerNorm + dropout: hnn.Dropout + position_embedding_type: Optional[str] + + @staticmethod + def init(Vocab: Axis, config: RobertaConfig, *, key) -> "RobertaEmbedding": + key_w, key_p, key_t = jrandom.split(key, 3) + + padding_idx = config.pad_token_id + + word_embeddings = hnn.Embedding.init(Vocab, config.Embed, key=key_w) # padding_idx not specified + position_embeddings = hnn.Embedding.init(config.Pos, config.Embed, key=key_p) + + Token = hax.Axis("token", config.type_vocab_size) + + token_type_embeddings = hnn.Embedding.init(Token, config.Embed, key=key_t) + + LayerNorm = hnn.LayerNorm.init(config.Embed, config.layer_norm_eps) + dropout = hnn.Dropout(config.hidden_dropout_prob) + + return RobertaEmbedding(Vocab, config.Pos, word_embeddings, position_embeddings, token_type_embeddings, padding_idx, LayerNorm, dropout, config.position_embedding_type) + + def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0): + mask = hax.not_equal(input_ids, 
self.padding_idx) * 1 + incremental_indices = (hax.cumsum(mask, axis=self.Pos).astype(mask) + past_key_values_length) * mask + incremental_indices -= mask.all(axis=self.Pos) + return incremental_indices + + def create_position_ids_from_inputs_embeds(self, input_axes, PosInput): + position_ids = hax.arange(axis = PosInput, start = 0, dtype=jnp.int32) + # position_ids = hax.arange(axis = PosInput, start = self.padding_idx + 1, dtype=jnp.int32) + + return hax.broadcast_to(position_ids, input_axes) + + @named_call + def embed(self, input_ids=None, token_type_ids=None, position_ids=None, input_embeds=None, past_key_values_length=0, *, key = None): + # if input_ids is not None: + # print(f"input_ids: {input_ids.dtype}") + # else: + # print("input_ids: None") + + # if token_type_ids is not None: + # print(f"token_type_ids: {token_type_ids.dtype}") + # else: + # print("token_type_ids: None") + + # if position_ids is not None: + # print(f"position_ids: {position_ids.dtype}") + # else: + # print("position_ids: None") + + # if input_embeds is not None: + # print(f"input_embeds: {input_embeds.dtype}") + # else: + # print("input_embeds: None") + + """ + Note: When inputting your own embeds into input_embeds, make sure that the embeds axis has the name "embed" + for compatibility with the position_id creation function. Make sures its length is not equal to + """ + + # Get Axes + if input_ids is not None: + input_axes = input_ids.axes + else: + input_axes = hax.eliminate_axes(input_embeds.axes, "embed") + + # Get position_ids + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = self.create_position_ids_from_input_ids(input_ids, past_key_values_length) + else: + position_ids = self.create_position_ids_from_inputs_embeds(input_axes, input_embeds.resolve_axis("position")) + + # Get token_type_ids + if token_type_ids is None: + token_type_ids = hax.zeros(input_axes, dtype=jnp.int32) + + if input_embeds is None: + input_embeds = self.word_embeddings(input_ids) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = input_embeds + token_type_embeddings + + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings, key=key) + + # jax.debug.breakpoint() + + return embeddings + +class RobertaPooler(eqx.Module, StateDictSerializationMixin): + dense: hnn.Linear + config: RobertaConfig + + @staticmethod + def init(config: RobertaConfig, *, key): + dense = hnn.Linear.init(In=config.Embed, Out=config.FinalEmbed, key=key, out_first=True) + + return RobertaPooler(dense, config) + + @named_call + def __call__(self, hidden_states: NamedArray, *, key=None) -> NamedArray: + first_token = hidden_states[{"position" : 0}] + x = self.dense(first_token, key=key).rename({self.config.FinalEmbed: self.config.Embed}) + x = hax.tanh(x) + return x + + +class RobertaModel(eqx.Module, StateDictSerializationMixin): + encoder: RobertaEncoder + embeddings: RobertaEmbedding + pooler : Optional[RobertaPooler] + output_hidden_states: bool + + @staticmethod + def init(Vocab: Axis, config: RobertaConfig, add_pooling_layer: bool = True, output_hidden_states: bool = False, *, key) -> "RobertaModel": + k_t, k_emb, k_p = jrandom.split(key, 3) + encoder = RobertaEncoder.init(config=config, output_hidden_states=output_hidden_states, key=k_t) + 
embeddings = RobertaEmbedding.init(Vocab, config, key=k_emb) + + pooler = RobertaPooler.init(config, key=k_p) if add_pooling_layer else None + return RobertaModel(encoder, embeddings, pooler, output_hidden_states) + + @property + def config(self): + return self.encoder.config + + @property + def vocab_size(self) -> int: + return self.Vocab.size + + @property + def Vocab(self) -> Axis: + return self.embeddings.Vocab + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + @named_call + def __call__( + self, + input_ids: Optional[NamedArray] = None, + token_type_ids: Optional[NamedArray] = None, + position_ids: Optional[NamedArray] = None, + input_embeds: Optional[NamedArray] = None, + attention_mask: Optional[NamedArray] = None, + *, + key, + ) -> Tuple[NamedArray]: + """ + Not Used: meant to be used to improve performance in decoder implementations + + head_mask: Optional[NamedArray] = None, + encoder_hidden_states: Optional[NamedArray] = None, + encoder_attention_mask: Optional[NamedArray] = None, + past_key_values_length = 0, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + """ + k_emb, k_e, k_p = maybe_rng_split(key, 3) + + if input_ids is not None and input_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_axes = input_ids.axes + elif input_embeds is not None: + input_axes = hax.eliminate_axes(input_embeds.axes, "embed") + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if attention_mask is None: + attention_mask = hax.ones(input_axes) + + # print(f"attention_mask: {attention_mask}") + + # Attention mask from mask to actual numbers 0 -> -inf + attention_mask = (attention_mask == 0) * jnp.finfo(jnp.bfloat16).min + + # print(f"attention_mask_real: {attention_mask}") + + embedding_output = self.embeddings.embed(input_ids, token_type_ids, position_ids, input_embeds, key=k_emb) + + encoder_outputs = self.encoder(embedding_output, attention_mask=attention_mask, key=k_e) + + sequence_output = encoder_outputs[0] + + pooled_output = self.pooler(sequence_output, key=k_p) if self.pooler is not None else None + + # jax.debug.breakpoint() + + return (sequence_output, pooled_output) + encoder_outputs[1:] if self.output_hidden_states else (sequence_output, pooled_output) + +class RobertaLMHead(eqx.Module, StateDictSerializationMixin): + """Roberta Head for masked language modeling.""" + + dense: hnn.Linear + layer_norm: hnn.LayerNorm + decoder: hnn.Linear + config: RobertaConfig + + @staticmethod + def init(Vocab: Axis, config: RobertaConfig, *, key): + k_dense, k_decoder = jrandom.split(key, 2) + Embed = config.Embed + + dense = hnn.Linear.init(In=Embed, Out=config.FinalEmbed, key=k_dense, out_first=True) + layer_norm = hnn.LayerNorm.init(axis=Embed, eps=config.layer_norm_eps) + + decoder = hnn.Linear.init(Embed, Vocab, key=k_decoder, out_first=True) + + # idk what this is: TODO + # self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + # self.decoder.bias = self.bias + + return RobertaLMHead(dense, layer_norm, decoder, config) + + @named_call + def __call__(self, features: NamedArray, *, key=None) -> NamedArray: + x = self.dense(features).rename({self.config.FinalEmbed: self.config.Embed}) + x = hnn.gelu(x, approximate=False) + x = 
self.layer_norm(x) + + # project back to size of vocabulary with bias + x = self.decoder(x) + + return x + +class RobertaForMaskedLM(eqx.Module, StateDictSerializationMixin): + roberta: RobertaModel + lm_head: RobertaLMHead + Vocab: Axis + Pos: Axis + output_hidden_states: bool + + @classmethod + def init(self, Vocab: Axis, config: RobertaConfig, output_hidden_states: bool = False, *, key): + + # if config.is_decoder: + # raise AttributeError("Model is being run as a MaskedLM aka an encoder model, but is_decoder is true") + + key_rob, key_head = jrandom.split(key, 2) + roberta = RobertaModel.init(Vocab, config, add_pooling_layer=False, output_hidden_states=output_hidden_states, key=key_rob) + lm_head = RobertaLMHead.init(Vocab, config, key=key_head) + + return RobertaForMaskedLM(roberta, lm_head, Vocab, config.Pos, output_hidden_states) + + def get_output_embeddings(self): + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + + @named_call + def __call__( + self, + input_ids: Optional[NamedArray] = None, + attention_mask: Optional[NamedArray] = None, + token_type_ids: Optional[NamedArray] = None, + position_ids: Optional[NamedArray] = None, + input_embeds: Optional[NamedArray] = None, + *, + key=None + ) -> Tuple[NamedArray]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + kwargs (`Dict[str, any]`, optional, defaults to *{}*): + Used to hide legacy arguments that have been deprecated. 
+ """ + + k_rob, k_lm = maybe_rng_split(key, 2) + + # print(f"input_ids: {input_ids}") + # print(f"attention_mask: {attention_mask}") + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + input_embeds=input_embeds, + key=k_rob + ) + + # print(f"outputs: {outputs}") + + prediction_scores = self.lm_head(outputs[0], key=k_lm) + + # print(f"prediction_scores: {prediction_scores}") + + # jax.debug.breakpoint() + + # return (prediction_scores,) + outputs[2:] + return prediction_scores + + def compute_loss( + self, + example: MaskedLmExample, + *, + key=None, + reduction: Optional[hax.ReductionFunction] = hax.mean, + reduction_axis: Optional[hax.AxisSelection] = None, + ) -> jnp.ndarray | NamedArray: + # logits = self(example.tokens, example.attn_mask, key=key)[0] + logits = self(example.tokens, example.attn_mask, key=key) + # print(f"Logits: {logits}") + logits = logits.astype(jnp.float32) + targets = example.targets + + target_y = hax.nn.one_hot(targets, self.Vocab, dtype=logits.dtype) + #target_y = jax.debug.breakpoint(token=target_y) + loss = cross_entropy_loss( + logits, self.Vocab, target_y, reduction, reduction_axis=reduction_axis, where=example.loss_mask + ) + + # print(f"loss: {loss}") + + return loss + +def _rotate_half(x: NamedArray) -> NamedArray: + """Rotates half of the hidden dims of the input and concatenates them.""" + HeadSize = x.axes[-1] + x1 = x[HeadSize, : HeadSize.size // 2] + x2 = x[HeadSize, HeadSize.size // 2 :] + out = hax.concatenate(HeadSize, (-x2, x1)) + return out + + +def _apply_rotary_pos_emb( + q: NamedArray, # [batch, position, kv_heads, q_heads_per_group, head_size] + k: NamedArray, # [batch, position, kv_heads, head_size] + cos: NamedArray, # [position, head_size] + sin: NamedArray, # [position, head_size] +) -> Tuple[NamedArray, NamedArray]: + """Applies rotary position embedding to q and k.""" + q_embed = q * cos + _rotate_half(q) * sin + k_embed = k * cos + _rotate_half(k) * sin + return q_embed, k_embed + + +def llama_rotary_pos_emb( + HeadSize: Axis, Pos: Axis, base: float = 10000, scale: float = 1.0 +) -> Tuple[NamedArray, NamedArray]: + with jax.ensure_compile_time_eval(): + HeadHalfSize = HeadSize.resize(HeadSize.size // 2) + inv_freq: NamedArray = 1.0 / (base ** (hax.arange(HeadHalfSize, step=2) / HeadSize.size)) + + position_ids: NamedArray = hax.arange(Pos) / scale + + freqs = position_ids * inv_freq.broadcast_axis(Pos) + # This is different from the paper but aligns with HF implementation: + # It uses a different permutation in order to obtain the same calculation + emb = hax.concatenate(HeadSize, (freqs, freqs)) + cos = hax.cos(emb) + sin = hax.sin(emb) + # This is different from the paper but aligns with HF implementation: + return cos, sin diff --git a/src/levanter/models/testing.ipynb b/src/levanter/models/testing.ipynb new file mode 100644 index 000000000..27a24f7bf --- /dev/null +++ b/src/levanter/models/testing.ipynb @@ -0,0 +1,1420 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "import roberta as my_roberta\n", + "from transformers.models.roberta import modeling_roberta as hf_roberta\n", + "\n", + "import torch\n", + "import haliax as hax\n", + "import jax\n", + "import jax.random as jrandom\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "\n", + "# hello" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + 
"from transformers import AutoConfig\n", + "from time import time\n", + "\n", + "hf_model_str = \"FacebookAI/roberta-base\"\n", + "\n", + "hf_config = AutoConfig.from_pretrained(hf_model_str)\n", + "hf_config.hidden_dropout_prob = 0\n", + "hf_config.attention_probs_dropout_prob = 0\n", + "# hf_config.pad_token_id = -1\n", + "my_config = my_roberta.RobertaConfig.from_hf_config(hf_config)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "seed: 1725922495\n" + ] + } + ], + "source": [ + "seed = time()\n", + "print(f\"seed: {int(seed)}\")\n", + "key = jrandom.PRNGKey(int(seed))\n", + "\n", + "key_vars, key_funcs, key_run = jrandom.split(key, 3)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "EmbedAtt = my_config.EmbedAtt\n", + "Embed = my_config.Embed\n", + "Mlp = my_config.Mlp\n", + "Pos = my_config.Pos\n", + "KeyPos = my_config.KeyPos\n", + "Heads = my_config.Heads\n", + "\n", + "cut_end_for_bounds = True \n", + "\n", + "Batch = hax.Axis(\"batch\", 2)\n", + "Vocab = hax.Axis(\"vocab\", my_config.vocab_size)\n", + "\n", + "keys = jrandom.split(key_vars, 6)\n", + "\n", + "input_ids = hax.random.randint(keys[0], (Batch, Pos), minval = 3, maxval = my_config.vocab_size)\n", + "if cut_end_for_bounds:\n", + " input_ids = input_ids[{\"position\": slice(0,-2)}]\n", + "input_ids_torch = torch.from_numpy(np.array(input_ids.array))\n", + "\n", + "input_embeds = hax.random.normal(keys[1], (Batch, Pos, Embed))\n", + "if cut_end_for_bounds:\n", + " input_embeds = input_embeds[{\"position\": slice(0,-2)}]\n", + "input_embeds_torch = torch.from_numpy(np.array(input_embeds.array))\n", + "\n", + "# mask = hax.random.randint(keys[2], (Batch, Pos), minval = 0, maxval = 2)\n", + "# mask = hax.ones((Batch, Pos))\n", + "mask = hax.zeros((Batch, Pos))\n", + "\n", + "if cut_end_for_bounds:\n", + " mask = mask[{\"position\": slice(0,-2)}]\n", + "mask_torch = torch.from_numpy(np.array(mask.array))\n", + "\n", + "mask_materialized = (mask == 0) * jnp.finfo(jnp.bfloat16).min\n", + "mask_torch_materialized = hf_roberta.RobertaModel.get_extended_attention_mask(self=hf_roberta.RobertaModel(hf_config), attention_mask=mask_torch, input_shape=input_embeds_torch.shape)\n", + "\n", + "features = input_embeds[{\"position\": 0}]\n", + "features_torch = torch.from_numpy(np.array(features.array))\n", + "\n", + "x_embed_att = input_embeds.rename({\"embed\": \"embed_att\"})\n", + "if cut_end_for_bounds:\n", + " x_embed_att = x_embed_att[{\"position\": slice(0,-2)}]\n", + "x_embed_att_torch = torch.from_numpy(np.array(x_embed_att.array))\n", + "\n", + "x_mlp = hax.random.normal(keys[5], (Batch, Pos, Mlp))\n", + "if cut_end_for_bounds:\n", + " x_mlp = x_mlp[{\"position\": slice(0,-2)}] \n", + "x_mlp_torch = torch.from_numpy(np.array(x_mlp.array))" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "def check(my_output, hf_output, precision=1e-4):\n", + " \n", + " assert (np.array(my_output.shape) == np.array(hf_output.shape)).all()\n", + " # print(my_output.shape)\n", + " # print(hf_output.shape)\n", + "\n", + " acc = np.isclose(hf_output, np.array(my_output), rtol=precision, atol=precision).mean()\n", + "\n", + " # stats = (torch.tensor(np.array(my_output)).abs().mean(), torch.tensor(np.array(hf_output)).abs().mean()) \n", + " stats = (torch.linalg.norm(torch.tensor(np.array(my_output))), 
torch.linalg.norm(torch.tensor(np.array(hf_output))))\n", + " \n", + " difference = torch.tensor(np.array(my_output)) - torch.tensor(np.array(hf_output))\n", + "\n", + " diffs = difference.abs().mean()\n", + "\n", + " to_print = f\"acc: {acc} \\t norms: {stats} \\t diffs: {diffs}\"\n", + " \n", + " return acc, stats, diffs, to_print" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [], + "source": [ + "def check_dicts(my_dict, hf_dict):\n", + " print(my_dict.keys())\n", + " print(hf_dict.keys())\n", + "\n", + " hf_keys_save = list(hf_dict.keys())\n", + "\n", + " flag = 0\n", + " diff = 0\n", + "\n", + " for k in my_dict.keys():\n", + " i = my_dict[k]\n", + " if k not in hf_dict:\n", + " print(f\"ERROR \\t {k}: key in my_dict but not hf_dict\")\n", + " j = hf_dict[k]\n", + " diff += (np.array(i) - np.array(j)).sum()\n", + " if check(i, j.detach())[0] < 1:\n", + " print(f\"ERROR \\t {k}: {check(i, j)[0]}\")\n", + " flag += 1\n", + " hf_keys_save.remove(k)\n", + "\n", + " if flag == 0:\n", + " print(\"success1\") \n", + " else:\n", + " print(\"fail1\") \n", + "\n", + " if len(hf_keys_save) == 0:\n", + " print(\"success2\") \n", + " else:\n", + " print(\"fail2\")\n", + " print(hf_keys_save)\n", + "\n", + " return diff" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'stop' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[25], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[43mstop\u001b[49m\n", + "\u001b[1;31mNameError\u001b[0m: name 'stop' is not defined" + ] + } + ], + "source": [ + "stop" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# Testing RobertaSelfOutput\n", + "\n", + "def test_RobertaSelfOutput(key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaSelfOutput.init(my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", + "\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + " # # print(state.keys())\n", + "\n", + " hf_func = hf_roberta.RobertaSelfOutput(hf_config)\n", + " hf_func.load_state_dict(state, strict=True, assign=True)\n", + "\n", + " my_output = my_func(x_embed_att, input_embeds, key=k_2)\n", + " hf_output = hf_func(x_embed_att_torch, input_embeds_torch)\n", + "\n", + " return check(my_output.array, hf_output.detach())\n", + "\n", + "# Testing RobertaSelfAttention\n", + "\n", + "def test_RobertaSelfAttention(key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + "\n", + " my_func = my_roberta.RobertaSelfAttention.init(my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", + "\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + " # print(state.keys())\n", + "\n", + " hf_func = hf_roberta.RobertaSelfAttention(hf_config)\n", + " hf_func.load_state_dict(state, strict=True, assign=True)\n", + "\n", + " my_output = my_func(input_embeds, mask_materialized, key=k_2)\n", + " hf_output = hf_func(input_embeds_torch, mask_torch_materialized)\n", + "\n", + " return check(my_output.array, hf_output[0].detach())\n", + "\n", + "# Testing RobertaAttention\n", + "\n", + "def test_RobertaAttention(key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = 
my_roberta.RobertaAttention.init(my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", + "\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + " # print(state.keys())\n", + "\n", + " hf_func = hf_roberta.RobertaAttention(hf_config)\n", + " hf_func.load_state_dict(state, strict=True, assign=True)\n", + "\n", + " my_output = my_func(hidden_states=input_embeds, attention_mask=mask_materialized, key=k_2)\n", + " hf_output = hf_func(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)\n", + "\n", + " return check(my_output.array, hf_output[0].detach())\n", + "\n", + "# Testing RobertaIntermediate\n", + "\n", + "def test_RobertaIntermediate(key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaIntermediate.init(my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", + "\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + " # print(state.keys())\n", + "\n", + " hf_func = hf_roberta.RobertaIntermediate(hf_config)\n", + " hf_func.load_state_dict(state, strict=True, assign=True)\n", + "\n", + " my_output = my_func(input_embeds, key=k_2)\n", + " hf_output = hf_func(input_embeds_torch)\n", + "\n", + " return check(my_output.array, hf_output.detach())\n", + "\n", + "# Testing RobertaOutput\n", + "\n", + "def test_RobertaOutput(key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + "\n", + " my_func = my_roberta.RobertaOutput.init(my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", + "\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + " # print(state.keys())\n", + "\n", + " hf_func = hf_roberta.RobertaOutput(hf_config)\n", + " hf_func.load_state_dict(state, strict=True, assign=True)\n", + "\n", + " my_output = my_func(x_mlp, input_embeds, key=k_2)\n", + " hf_output = hf_func(x_mlp_torch, input_embeds_torch)\n", + "\n", + " return check(my_output.array, hf_output.detach())\n", + "\n", + "# Testing RobertaLayer\n", + "\n", + "def test_RobertaLayer(key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaLayer.init(my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", + "\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + " # print(state.keys())\n", + "\n", + " hf_func = hf_roberta.RobertaLayer(hf_config)\n", + " hf_func.load_state_dict(state, strict=True, assign=True)\n", + "\n", + " my_output = my_func(hidden_states=input_embeds, attention_mask=mask_materialized, key=k_2)\n", + " hf_output = hf_func(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)\n", + "\n", + " return check(my_output[0].array, hf_output[0].detach())\n", + "\n", + "# Testing RobertaEncoder\n", + "\n", + "def test_RobertaEncoder(key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaEncoder.init(my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", + "\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + " # print(state.keys())\n", + "\n", + " hf_func = hf_roberta.RobertaEncoder(hf_config)\n", + " hf_func.load_state_dict(state, strict=True, assign=True)\n", + "\n", + " my_output = my_func(hidden_states=input_embeds, attention_mask=mask_materialized, key=k_2)\n", + " hf_output = hf_func(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)\n", + "\n", + " return check(my_output[0].array, hf_output[0].detach())\n", + "\n", + "# Testing 
RobertaEmbedding\n", + "\n", + "def test_RobertaEmbedding(key, ids = True):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaEmbedding.init(Vocab, my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", + "\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + " # print(state.keys())\n", + "\n", + " hf_func = hf_roberta.RobertaEmbeddings(hf_config)\n", + " hf_func.load_state_dict(state, strict=True, assign=True)\n", + "\n", + " if ids:\n", + " my_output = my_func.embed(input_ids=input_ids, key=k_2)\n", + " hf_output = hf_func(input_ids=input_ids_torch)\n", + " else: \n", + " my_output = my_func.embed(input_embeds=input_embeds, key=k_2)\n", + " hf_output = hf_func(inputs_embeds=input_embeds_torch)\n", + "\n", + " return check(my_output.array, hf_output.detach())\n", + "\n", + "# Testing RobertaPooler\n", + "\n", + "def test_RobertaPooler(key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaPooler.init(my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", + "\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + " # print(state.keys())\n", + "\n", + " hf_func = hf_roberta.RobertaPooler(hf_config)\n", + " hf_func.load_state_dict(state, strict=True, assign=True)\n", + "\n", + " my_output = my_func(input_embeds, key=k_2)\n", + " hf_output = hf_func(input_embeds_torch)\n", + "\n", + " return check(my_output.array, hf_output.detach())\n", + "\n", + "# Testing RobertaModel\n", + "\n", + "def test_RobertaModel(key, ids = True, pool = True):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=pool, key=k_1)\n", + " state = my_func.to_state_dict()\n", + "\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + " # print(state.keys())\n", + "\n", + " hf_func = hf_roberta.RobertaModel(hf_config, add_pooling_layer=pool)\n", + " hf_func.load_state_dict(state, strict=True, assign=True)\n", + "\n", + " if ids:\n", + " my_output = my_func(input_ids = input_ids, attention_mask=mask, key=k_2)\n", + " hf_output = hf_func(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", + " else:\n", + " my_output = my_func(input_embeds = input_embeds, attention_mask=mask, key=k_2)\n", + " hf_output = hf_func(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n", + "\n", + " if pool:\n", + " return check(my_output[1].array, hf_output[1].detach())\n", + " else:\n", + " return check(my_output[0].array, hf_output[0].detach())\n", + "\n", + "# Testing RobertaLMHead\n", + "\n", + "def test_RobertaLMHead(key):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_func = my_roberta.RobertaLMHead.init(Vocab, my_config, key=k_1)\n", + " state = my_func.to_state_dict()\n", + "\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + " state[\"bias\"] = torch.zeros(hf_config.vocab_size)\n", + "\n", + " # print(state.keys())\n", + "\n", + " hf_func = hf_roberta.RobertaLMHead(hf_config)\n", + " hf_func.load_state_dict(state, strict=True, assign=True)\n", + "\n", + " my_output = my_func(features, key=k_2)\n", + " hf_output = hf_func(features_torch)\n", + "\n", + " return check(my_output.array, hf_output.detach())\n", + "\n", + "# Testing RobertaForMaskedLM\n", + "\n", + "def test_RobertaForMaskedLM(key, ids = True):\n", + " k_1, k_2 = jrandom.split(key, 2)\n", + " my_pool = 
my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=k_1)\n", + " state = my_pool.to_state_dict()\n", + "\n", + " state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n", + "\n", + " state[\"lm_head.bias\"] = torch.zeros(hf_config.vocab_size)\n", + "\n", + " # print(state.keys())\n", + "\n", + " hf_pool = hf_roberta.RobertaForMaskedLM(hf_config)\n", + " hf_pool.load_state_dict(state, strict=True, assign=True)\n", + "\n", + " if ids:\n", + " my_output = my_pool(input_ids = input_ids, attention_mask=mask, key=k_2)\n", + " hf_output = hf_pool(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", + " else:\n", + " my_output = my_pool(input_embeds = input_embeds, attention_mask=mask, key=k_2)\n", + " hf_output = hf_pool(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n", + "\n", + " return check(my_output[0].array, hf_output[0].detach())" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "keys = jrandom.split(key_funcs, 15)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "outs = []\n", + "\n", + "outs.append(test_RobertaSelfOutput(keys[0]))\n", + "outs.append(test_RobertaSelfAttention(keys[1]))\n", + "outs.append(test_RobertaAttention(keys[2]))\n", + "outs.append(test_RobertaIntermediate(keys[3]))\n", + "outs.append(test_RobertaOutput(keys[4]))\n", + "outs.append(test_RobertaLayer(keys[4]))\n", + "outs.append(test_RobertaEncoder(keys[4]))\n", + "outs.append(test_RobertaEmbedding(keys[7], ids = True))\n", + "outs.append(test_RobertaEmbedding(keys[8], ids = False))\n", + "outs.append(test_RobertaModel(keys[9], ids = True, pool = True))\n", + "outs.append(test_RobertaModel(keys[10], ids = False, pool = False))\n", + "outs.append(test_RobertaModel(keys[9], ids = True, pool = True))\n", + "outs.append(test_RobertaModel(keys[10], ids = False, pool = False))\n", + "outs.append(test_RobertaPooler(keys[11]))\n", + "outs.append(test_RobertaLMHead(keys[12]))\n", + "outs.append(test_RobertaForMaskedLM(keys[13], ids = True))\n", + "outs.append(test_RobertaForMaskedLM(keys[14], ids = False))" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "types = [\n", + " \"test_RobertaSelfOutput\",\n", + " \"test_RobertaSelfAttention\",\n", + " \"test_RobertaAttention\",\n", + " \"test_RobertaIntermediate\",\n", + " \"test_RobertaOutput\",\n", + " \"test_RobertaLayer\",\n", + " \"test_RobertaEncoder\",\n", + " \"test_RobertaEmbedding(ids = True)\",\n", + " \"test_RobertaEmbedding(ids = False)\",\n", + " \"test_RobertaModel(ids = True, pool = True)\",\n", + " \"test_RobertaModel(ids = False, pool = False)\",\n", + " \"test_RobertaModel(ids = True, pool = True)\",\n", + " \"test_RobertaModel(ids = False, pool = False)\",\n", + " \"test_RobertaPooler\",\n", + " \"test_RobertaLMHead\",\n", + " \"test_RobertaForMaskedLM(ids = True)\",\n", + " \"test_RobertaForMaskedLM(ids = False)\"\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "for i,o in enumerate(outs):\n", + " if o[2] * 0 != 0:\n", + " print(f\"nan alert\")\n", + " if o[0] < 1:\n", + " print(f\"{types[i]}: {o[3]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "for i,o in enumerate(outs):\n", + " print(f\"{types[i]}: {o[3]}\")" + ] + }, + { + "cell_type": "code", + 
"execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "key_model, key_lm = jrandom.split(key_funcs, 2)\n", + "key_model_run, key_lm_run = jrandom.split(key_run, 2)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dict_keys(['encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.1.attention.self.query.bias', 'encoder.layer.1.attention.self.key.weight', 'encoder.layer.1.attention.self.key.bias', 'encoder.layer.1.attention.self.value.weight', 'encoder.layer.1.attention.self.value.bias', 'encoder.layer.1.attention.output.dense.weight', 'encoder.layer.1.attention.output.dense.bias', 'encoder.layer.1.attention.output.LayerNorm.weight', 'encoder.layer.1.attention.output.LayerNorm.bias', 'encoder.layer.1.intermediate.dense.weight', 'encoder.layer.1.intermediate.dense.bias', 'encoder.layer.1.output.dense.weight', 'encoder.layer.1.output.dense.bias', 'encoder.layer.1.output.LayerNorm.weight', 'encoder.layer.1.output.LayerNorm.bias', 'encoder.layer.2.attention.self.query.weight', 'encoder.layer.2.attention.self.query.bias', 'encoder.layer.2.attention.self.key.weight', 'encoder.layer.2.attention.self.key.bias', 'encoder.layer.2.attention.self.value.weight', 'encoder.layer.2.attention.self.value.bias', 'encoder.layer.2.attention.output.dense.weight', 'encoder.layer.2.attention.output.dense.bias', 'encoder.layer.2.attention.output.LayerNorm.weight', 'encoder.layer.2.attention.output.LayerNorm.bias', 'encoder.layer.2.intermediate.dense.weight', 'encoder.layer.2.intermediate.dense.bias', 'encoder.layer.2.output.dense.weight', 'encoder.layer.2.output.dense.bias', 'encoder.layer.2.output.LayerNorm.weight', 'encoder.layer.2.output.LayerNorm.bias', 'encoder.layer.3.attention.self.query.weight', 'encoder.layer.3.attention.self.query.bias', 'encoder.layer.3.attention.self.key.weight', 'encoder.layer.3.attention.self.key.bias', 'encoder.layer.3.attention.self.value.weight', 'encoder.layer.3.attention.self.value.bias', 'encoder.layer.3.attention.output.dense.weight', 'encoder.layer.3.attention.output.dense.bias', 'encoder.layer.3.attention.output.LayerNorm.weight', 'encoder.layer.3.attention.output.LayerNorm.bias', 'encoder.layer.3.intermediate.dense.weight', 'encoder.layer.3.intermediate.dense.bias', 'encoder.layer.3.output.dense.weight', 'encoder.layer.3.output.dense.bias', 'encoder.layer.3.output.LayerNorm.weight', 'encoder.layer.3.output.LayerNorm.bias', 'encoder.layer.4.attention.self.query.weight', 'encoder.layer.4.attention.self.query.bias', 'encoder.layer.4.attention.self.key.weight', 'encoder.layer.4.attention.self.key.bias', 'encoder.layer.4.attention.self.value.weight', 'encoder.layer.4.attention.self.value.bias', 
'encoder.layer.4.attention.output.dense.weight', 'encoder.layer.4.attention.output.dense.bias', 'encoder.layer.4.attention.output.LayerNorm.weight', 'encoder.layer.4.attention.output.LayerNorm.bias', 'encoder.layer.4.intermediate.dense.weight', 'encoder.layer.4.intermediate.dense.bias', 'encoder.layer.4.output.dense.weight', 'encoder.layer.4.output.dense.bias', 'encoder.layer.4.output.LayerNorm.weight', 'encoder.layer.4.output.LayerNorm.bias', 'encoder.layer.5.attention.self.query.weight', 'encoder.layer.5.attention.self.query.bias', 'encoder.layer.5.attention.self.key.weight', 'encoder.layer.5.attention.self.key.bias', 'encoder.layer.5.attention.self.value.weight', 'encoder.layer.5.attention.self.value.bias', 'encoder.layer.5.attention.output.dense.weight', 'encoder.layer.5.attention.output.dense.bias', 'encoder.layer.5.attention.output.LayerNorm.weight', 'encoder.layer.5.attention.output.LayerNorm.bias', 'encoder.layer.5.intermediate.dense.weight', 'encoder.layer.5.intermediate.dense.bias', 'encoder.layer.5.output.dense.weight', 'encoder.layer.5.output.dense.bias', 'encoder.layer.5.output.LayerNorm.weight', 'encoder.layer.5.output.LayerNorm.bias', 'encoder.layer.6.attention.self.query.weight', 'encoder.layer.6.attention.self.query.bias', 'encoder.layer.6.attention.self.key.weight', 'encoder.layer.6.attention.self.key.bias', 'encoder.layer.6.attention.self.value.weight', 'encoder.layer.6.attention.self.value.bias', 'encoder.layer.6.attention.output.dense.weight', 'encoder.layer.6.attention.output.dense.bias', 'encoder.layer.6.attention.output.LayerNorm.weight', 'encoder.layer.6.attention.output.LayerNorm.bias', 'encoder.layer.6.intermediate.dense.weight', 'encoder.layer.6.intermediate.dense.bias', 'encoder.layer.6.output.dense.weight', 'encoder.layer.6.output.dense.bias', 'encoder.layer.6.output.LayerNorm.weight', 'encoder.layer.6.output.LayerNorm.bias', 'encoder.layer.7.attention.self.query.weight', 'encoder.layer.7.attention.self.query.bias', 'encoder.layer.7.attention.self.key.weight', 'encoder.layer.7.attention.self.key.bias', 'encoder.layer.7.attention.self.value.weight', 'encoder.layer.7.attention.self.value.bias', 'encoder.layer.7.attention.output.dense.weight', 'encoder.layer.7.attention.output.dense.bias', 'encoder.layer.7.attention.output.LayerNorm.weight', 'encoder.layer.7.attention.output.LayerNorm.bias', 'encoder.layer.7.intermediate.dense.weight', 'encoder.layer.7.intermediate.dense.bias', 'encoder.layer.7.output.dense.weight', 'encoder.layer.7.output.dense.bias', 'encoder.layer.7.output.LayerNorm.weight', 'encoder.layer.7.output.LayerNorm.bias', 'encoder.layer.8.attention.self.query.weight', 'encoder.layer.8.attention.self.query.bias', 'encoder.layer.8.attention.self.key.weight', 'encoder.layer.8.attention.self.key.bias', 'encoder.layer.8.attention.self.value.weight', 'encoder.layer.8.attention.self.value.bias', 'encoder.layer.8.attention.output.dense.weight', 'encoder.layer.8.attention.output.dense.bias', 'encoder.layer.8.attention.output.LayerNorm.weight', 'encoder.layer.8.attention.output.LayerNorm.bias', 'encoder.layer.8.intermediate.dense.weight', 'encoder.layer.8.intermediate.dense.bias', 'encoder.layer.8.output.dense.weight', 'encoder.layer.8.output.dense.bias', 'encoder.layer.8.output.LayerNorm.weight', 'encoder.layer.8.output.LayerNorm.bias', 'encoder.layer.9.attention.self.query.weight', 'encoder.layer.9.attention.self.query.bias', 'encoder.layer.9.attention.self.key.weight', 'encoder.layer.9.attention.self.key.bias', 'encoder.layer.9.attention.self.value.weight', 
'encoder.layer.9.attention.self.value.bias', 'encoder.layer.9.attention.output.dense.weight', 'encoder.layer.9.attention.output.dense.bias', 'encoder.layer.9.attention.output.LayerNorm.weight', 'encoder.layer.9.attention.output.LayerNorm.bias', 'encoder.layer.9.intermediate.dense.weight', 'encoder.layer.9.intermediate.dense.bias', 'encoder.layer.9.output.dense.weight', 'encoder.layer.9.output.dense.bias', 'encoder.layer.9.output.LayerNorm.weight', 'encoder.layer.9.output.LayerNorm.bias', 'encoder.layer.10.attention.self.query.weight', 'encoder.layer.10.attention.self.query.bias', 'encoder.layer.10.attention.self.key.weight', 'encoder.layer.10.attention.self.key.bias', 'encoder.layer.10.attention.self.value.weight', 'encoder.layer.10.attention.self.value.bias', 'encoder.layer.10.attention.output.dense.weight', 'encoder.layer.10.attention.output.dense.bias', 'encoder.layer.10.attention.output.LayerNorm.weight', 'encoder.layer.10.attention.output.LayerNorm.bias', 'encoder.layer.10.intermediate.dense.weight', 'encoder.layer.10.intermediate.dense.bias', 'encoder.layer.10.output.dense.weight', 'encoder.layer.10.output.dense.bias', 'encoder.layer.10.output.LayerNorm.weight', 'encoder.layer.10.output.LayerNorm.bias', 'encoder.layer.11.attention.self.query.weight', 'encoder.layer.11.attention.self.query.bias', 'encoder.layer.11.attention.self.key.weight', 'encoder.layer.11.attention.self.key.bias', 'encoder.layer.11.attention.self.value.weight', 'encoder.layer.11.attention.self.value.bias', 'encoder.layer.11.attention.output.dense.weight', 'encoder.layer.11.attention.output.dense.bias', 'encoder.layer.11.attention.output.LayerNorm.weight', 'encoder.layer.11.attention.output.LayerNorm.bias', 'encoder.layer.11.intermediate.dense.weight', 'encoder.layer.11.intermediate.dense.bias', 'encoder.layer.11.output.dense.weight', 'encoder.layer.11.output.dense.bias', 'encoder.layer.11.output.LayerNorm.weight', 'encoder.layer.11.output.LayerNorm.bias', 'embeddings.word_embeddings.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.LayerNorm.weight', 'embeddings.LayerNorm.bias', 'pooler.dense.weight', 'pooler.dense.bias'])\n", + "odict_keys(['embeddings.word_embeddings.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.LayerNorm.weight', 'embeddings.LayerNorm.bias', 'encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.1.attention.self.query.bias', 'encoder.layer.1.attention.self.key.weight', 'encoder.layer.1.attention.self.key.bias', 'encoder.layer.1.attention.self.value.weight', 'encoder.layer.1.attention.self.value.bias', 'encoder.layer.1.attention.output.dense.weight', 'encoder.layer.1.attention.output.dense.bias', 'encoder.layer.1.attention.output.LayerNorm.weight', 
'encoder.layer.1.attention.output.LayerNorm.bias', 'encoder.layer.1.intermediate.dense.weight', 'encoder.layer.1.intermediate.dense.bias', 'encoder.layer.1.output.dense.weight', 'encoder.layer.1.output.dense.bias', 'encoder.layer.1.output.LayerNorm.weight', 'encoder.layer.1.output.LayerNorm.bias', 'encoder.layer.2.attention.self.query.weight', 'encoder.layer.2.attention.self.query.bias', 'encoder.layer.2.attention.self.key.weight', 'encoder.layer.2.attention.self.key.bias', 'encoder.layer.2.attention.self.value.weight', 'encoder.layer.2.attention.self.value.bias', 'encoder.layer.2.attention.output.dense.weight', 'encoder.layer.2.attention.output.dense.bias', 'encoder.layer.2.attention.output.LayerNorm.weight', 'encoder.layer.2.attention.output.LayerNorm.bias', 'encoder.layer.2.intermediate.dense.weight', 'encoder.layer.2.intermediate.dense.bias', 'encoder.layer.2.output.dense.weight', 'encoder.layer.2.output.dense.bias', 'encoder.layer.2.output.LayerNorm.weight', 'encoder.layer.2.output.LayerNorm.bias', 'encoder.layer.3.attention.self.query.weight', 'encoder.layer.3.attention.self.query.bias', 'encoder.layer.3.attention.self.key.weight', 'encoder.layer.3.attention.self.key.bias', 'encoder.layer.3.attention.self.value.weight', 'encoder.layer.3.attention.self.value.bias', 'encoder.layer.3.attention.output.dense.weight', 'encoder.layer.3.attention.output.dense.bias', 'encoder.layer.3.attention.output.LayerNorm.weight', 'encoder.layer.3.attention.output.LayerNorm.bias', 'encoder.layer.3.intermediate.dense.weight', 'encoder.layer.3.intermediate.dense.bias', 'encoder.layer.3.output.dense.weight', 'encoder.layer.3.output.dense.bias', 'encoder.layer.3.output.LayerNorm.weight', 'encoder.layer.3.output.LayerNorm.bias', 'encoder.layer.4.attention.self.query.weight', 'encoder.layer.4.attention.self.query.bias', 'encoder.layer.4.attention.self.key.weight', 'encoder.layer.4.attention.self.key.bias', 'encoder.layer.4.attention.self.value.weight', 'encoder.layer.4.attention.self.value.bias', 'encoder.layer.4.attention.output.dense.weight', 'encoder.layer.4.attention.output.dense.bias', 'encoder.layer.4.attention.output.LayerNorm.weight', 'encoder.layer.4.attention.output.LayerNorm.bias', 'encoder.layer.4.intermediate.dense.weight', 'encoder.layer.4.intermediate.dense.bias', 'encoder.layer.4.output.dense.weight', 'encoder.layer.4.output.dense.bias', 'encoder.layer.4.output.LayerNorm.weight', 'encoder.layer.4.output.LayerNorm.bias', 'encoder.layer.5.attention.self.query.weight', 'encoder.layer.5.attention.self.query.bias', 'encoder.layer.5.attention.self.key.weight', 'encoder.layer.5.attention.self.key.bias', 'encoder.layer.5.attention.self.value.weight', 'encoder.layer.5.attention.self.value.bias', 'encoder.layer.5.attention.output.dense.weight', 'encoder.layer.5.attention.output.dense.bias', 'encoder.layer.5.attention.output.LayerNorm.weight', 'encoder.layer.5.attention.output.LayerNorm.bias', 'encoder.layer.5.intermediate.dense.weight', 'encoder.layer.5.intermediate.dense.bias', 'encoder.layer.5.output.dense.weight', 'encoder.layer.5.output.dense.bias', 'encoder.layer.5.output.LayerNorm.weight', 'encoder.layer.5.output.LayerNorm.bias', 'encoder.layer.6.attention.self.query.weight', 'encoder.layer.6.attention.self.query.bias', 'encoder.layer.6.attention.self.key.weight', 'encoder.layer.6.attention.self.key.bias', 'encoder.layer.6.attention.self.value.weight', 'encoder.layer.6.attention.self.value.bias', 'encoder.layer.6.attention.output.dense.weight', 'encoder.layer.6.attention.output.dense.bias', 
'encoder.layer.6.attention.output.LayerNorm.weight', 'encoder.layer.6.attention.output.LayerNorm.bias', 'encoder.layer.6.intermediate.dense.weight', 'encoder.layer.6.intermediate.dense.bias', 'encoder.layer.6.output.dense.weight', 'encoder.layer.6.output.dense.bias', 'encoder.layer.6.output.LayerNorm.weight', 'encoder.layer.6.output.LayerNorm.bias', 'encoder.layer.7.attention.self.query.weight', 'encoder.layer.7.attention.self.query.bias', 'encoder.layer.7.attention.self.key.weight', 'encoder.layer.7.attention.self.key.bias', 'encoder.layer.7.attention.self.value.weight', 'encoder.layer.7.attention.self.value.bias', 'encoder.layer.7.attention.output.dense.weight', 'encoder.layer.7.attention.output.dense.bias', 'encoder.layer.7.attention.output.LayerNorm.weight', 'encoder.layer.7.attention.output.LayerNorm.bias', 'encoder.layer.7.intermediate.dense.weight', 'encoder.layer.7.intermediate.dense.bias', 'encoder.layer.7.output.dense.weight', 'encoder.layer.7.output.dense.bias', 'encoder.layer.7.output.LayerNorm.weight', 'encoder.layer.7.output.LayerNorm.bias', 'encoder.layer.8.attention.self.query.weight', 'encoder.layer.8.attention.self.query.bias', 'encoder.layer.8.attention.self.key.weight', 'encoder.layer.8.attention.self.key.bias', 'encoder.layer.8.attention.self.value.weight', 'encoder.layer.8.attention.self.value.bias', 'encoder.layer.8.attention.output.dense.weight', 'encoder.layer.8.attention.output.dense.bias', 'encoder.layer.8.attention.output.LayerNorm.weight', 'encoder.layer.8.attention.output.LayerNorm.bias', 'encoder.layer.8.intermediate.dense.weight', 'encoder.layer.8.intermediate.dense.bias', 'encoder.layer.8.output.dense.weight', 'encoder.layer.8.output.dense.bias', 'encoder.layer.8.output.LayerNorm.weight', 'encoder.layer.8.output.LayerNorm.bias', 'encoder.layer.9.attention.self.query.weight', 'encoder.layer.9.attention.self.query.bias', 'encoder.layer.9.attention.self.key.weight', 'encoder.layer.9.attention.self.key.bias', 'encoder.layer.9.attention.self.value.weight', 'encoder.layer.9.attention.self.value.bias', 'encoder.layer.9.attention.output.dense.weight', 'encoder.layer.9.attention.output.dense.bias', 'encoder.layer.9.attention.output.LayerNorm.weight', 'encoder.layer.9.attention.output.LayerNorm.bias', 'encoder.layer.9.intermediate.dense.weight', 'encoder.layer.9.intermediate.dense.bias', 'encoder.layer.9.output.dense.weight', 'encoder.layer.9.output.dense.bias', 'encoder.layer.9.output.LayerNorm.weight', 'encoder.layer.9.output.LayerNorm.bias', 'encoder.layer.10.attention.self.query.weight', 'encoder.layer.10.attention.self.query.bias', 'encoder.layer.10.attention.self.key.weight', 'encoder.layer.10.attention.self.key.bias', 'encoder.layer.10.attention.self.value.weight', 'encoder.layer.10.attention.self.value.bias', 'encoder.layer.10.attention.output.dense.weight', 'encoder.layer.10.attention.output.dense.bias', 'encoder.layer.10.attention.output.LayerNorm.weight', 'encoder.layer.10.attention.output.LayerNorm.bias', 'encoder.layer.10.intermediate.dense.weight', 'encoder.layer.10.intermediate.dense.bias', 'encoder.layer.10.output.dense.weight', 'encoder.layer.10.output.dense.bias', 'encoder.layer.10.output.LayerNorm.weight', 'encoder.layer.10.output.LayerNorm.bias', 'encoder.layer.11.attention.self.query.weight', 'encoder.layer.11.attention.self.query.bias', 'encoder.layer.11.attention.self.key.weight', 'encoder.layer.11.attention.self.key.bias', 'encoder.layer.11.attention.self.value.weight', 'encoder.layer.11.attention.self.value.bias', 
'encoder.layer.11.attention.output.dense.weight', 'encoder.layer.11.attention.output.dense.bias', 'encoder.layer.11.attention.output.LayerNorm.weight', 'encoder.layer.11.attention.output.LayerNorm.bias', 'encoder.layer.11.intermediate.dense.weight', 'encoder.layer.11.intermediate.dense.bias', 'encoder.layer.11.output.dense.weight', 'encoder.layer.11.output.dense.bias', 'encoder.layer.11.output.LayerNorm.weight', 'encoder.layer.11.output.LayerNorm.bias', 'pooler.dense.weight', 'pooler.dense.bias'])\n", + "success1\n", + "success2\n", + "0.0\n" + ] + } + ], + "source": [ + "# Initializing RobertaModel\n", + "my_model = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=True, output_hidden_states=True, key=key_model)\n", + "state_model = my_model.to_state_dict()\n", + "\n", + "state_model = {k: torch.from_numpy(np.array(v)) for k, v in state_model.items()}\n", + "\n", + "hf_model = hf_roberta.RobertaModel(hf_config, add_pooling_layer=True)\n", + "hf_model.load_state_dict(state_model, strict=True)\n", + "\n", + "print(check_dicts(my_model.to_state_dict(), hf_model.state_dict()))" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dict_keys(['roberta.encoder.layer.0.attention.self.query.weight', 'roberta.encoder.layer.0.attention.self.query.bias', 'roberta.encoder.layer.0.attention.self.key.weight', 'roberta.encoder.layer.0.attention.self.key.bias', 'roberta.encoder.layer.0.attention.self.value.weight', 'roberta.encoder.layer.0.attention.self.value.bias', 'roberta.encoder.layer.0.attention.output.dense.weight', 'roberta.encoder.layer.0.attention.output.dense.bias', 'roberta.encoder.layer.0.attention.output.LayerNorm.weight', 'roberta.encoder.layer.0.attention.output.LayerNorm.bias', 'roberta.encoder.layer.0.intermediate.dense.weight', 'roberta.encoder.layer.0.intermediate.dense.bias', 'roberta.encoder.layer.0.output.dense.weight', 'roberta.encoder.layer.0.output.dense.bias', 'roberta.encoder.layer.0.output.LayerNorm.weight', 'roberta.encoder.layer.0.output.LayerNorm.bias', 'roberta.encoder.layer.1.attention.self.query.weight', 'roberta.encoder.layer.1.attention.self.query.bias', 'roberta.encoder.layer.1.attention.self.key.weight', 'roberta.encoder.layer.1.attention.self.key.bias', 'roberta.encoder.layer.1.attention.self.value.weight', 'roberta.encoder.layer.1.attention.self.value.bias', 'roberta.encoder.layer.1.attention.output.dense.weight', 'roberta.encoder.layer.1.attention.output.dense.bias', 'roberta.encoder.layer.1.attention.output.LayerNorm.weight', 'roberta.encoder.layer.1.attention.output.LayerNorm.bias', 'roberta.encoder.layer.1.intermediate.dense.weight', 'roberta.encoder.layer.1.intermediate.dense.bias', 'roberta.encoder.layer.1.output.dense.weight', 'roberta.encoder.layer.1.output.dense.bias', 'roberta.encoder.layer.1.output.LayerNorm.weight', 'roberta.encoder.layer.1.output.LayerNorm.bias', 'roberta.encoder.layer.2.attention.self.query.weight', 'roberta.encoder.layer.2.attention.self.query.bias', 'roberta.encoder.layer.2.attention.self.key.weight', 'roberta.encoder.layer.2.attention.self.key.bias', 'roberta.encoder.layer.2.attention.self.value.weight', 'roberta.encoder.layer.2.attention.self.value.bias', 'roberta.encoder.layer.2.attention.output.dense.weight', 'roberta.encoder.layer.2.attention.output.dense.bias', 'roberta.encoder.layer.2.attention.output.LayerNorm.weight', 'roberta.encoder.layer.2.attention.output.LayerNorm.bias', 
'roberta.encoder.layer.2.intermediate.dense.weight', 'roberta.encoder.layer.2.intermediate.dense.bias', 'roberta.encoder.layer.2.output.dense.weight', 'roberta.encoder.layer.2.output.dense.bias', 'roberta.encoder.layer.2.output.LayerNorm.weight', 'roberta.encoder.layer.2.output.LayerNorm.bias', 'roberta.encoder.layer.3.attention.self.query.weight', 'roberta.encoder.layer.3.attention.self.query.bias', 'roberta.encoder.layer.3.attention.self.key.weight', 'roberta.encoder.layer.3.attention.self.key.bias', 'roberta.encoder.layer.3.attention.self.value.weight', 'roberta.encoder.layer.3.attention.self.value.bias', 'roberta.encoder.layer.3.attention.output.dense.weight', 'roberta.encoder.layer.3.attention.output.dense.bias', 'roberta.encoder.layer.3.attention.output.LayerNorm.weight', 'roberta.encoder.layer.3.attention.output.LayerNorm.bias', 'roberta.encoder.layer.3.intermediate.dense.weight', 'roberta.encoder.layer.3.intermediate.dense.bias', 'roberta.encoder.layer.3.output.dense.weight', 'roberta.encoder.layer.3.output.dense.bias', 'roberta.encoder.layer.3.output.LayerNorm.weight', 'roberta.encoder.layer.3.output.LayerNorm.bias', 'roberta.encoder.layer.4.attention.self.query.weight', 'roberta.encoder.layer.4.attention.self.query.bias', 'roberta.encoder.layer.4.attention.self.key.weight', 'roberta.encoder.layer.4.attention.self.key.bias', 'roberta.encoder.layer.4.attention.self.value.weight', 'roberta.encoder.layer.4.attention.self.value.bias', 'roberta.encoder.layer.4.attention.output.dense.weight', 'roberta.encoder.layer.4.attention.output.dense.bias', 'roberta.encoder.layer.4.attention.output.LayerNorm.weight', 'roberta.encoder.layer.4.attention.output.LayerNorm.bias', 'roberta.encoder.layer.4.intermediate.dense.weight', 'roberta.encoder.layer.4.intermediate.dense.bias', 'roberta.encoder.layer.4.output.dense.weight', 'roberta.encoder.layer.4.output.dense.bias', 'roberta.encoder.layer.4.output.LayerNorm.weight', 'roberta.encoder.layer.4.output.LayerNorm.bias', 'roberta.encoder.layer.5.attention.self.query.weight', 'roberta.encoder.layer.5.attention.self.query.bias', 'roberta.encoder.layer.5.attention.self.key.weight', 'roberta.encoder.layer.5.attention.self.key.bias', 'roberta.encoder.layer.5.attention.self.value.weight', 'roberta.encoder.layer.5.attention.self.value.bias', 'roberta.encoder.layer.5.attention.output.dense.weight', 'roberta.encoder.layer.5.attention.output.dense.bias', 'roberta.encoder.layer.5.attention.output.LayerNorm.weight', 'roberta.encoder.layer.5.attention.output.LayerNorm.bias', 'roberta.encoder.layer.5.intermediate.dense.weight', 'roberta.encoder.layer.5.intermediate.dense.bias', 'roberta.encoder.layer.5.output.dense.weight', 'roberta.encoder.layer.5.output.dense.bias', 'roberta.encoder.layer.5.output.LayerNorm.weight', 'roberta.encoder.layer.5.output.LayerNorm.bias', 'roberta.encoder.layer.6.attention.self.query.weight', 'roberta.encoder.layer.6.attention.self.query.bias', 'roberta.encoder.layer.6.attention.self.key.weight', 'roberta.encoder.layer.6.attention.self.key.bias', 'roberta.encoder.layer.6.attention.self.value.weight', 'roberta.encoder.layer.6.attention.self.value.bias', 'roberta.encoder.layer.6.attention.output.dense.weight', 'roberta.encoder.layer.6.attention.output.dense.bias', 'roberta.encoder.layer.6.attention.output.LayerNorm.weight', 'roberta.encoder.layer.6.attention.output.LayerNorm.bias', 'roberta.encoder.layer.6.intermediate.dense.weight', 'roberta.encoder.layer.6.intermediate.dense.bias', 'roberta.encoder.layer.6.output.dense.weight', 
'roberta.encoder.layer.6.output.dense.bias', 'roberta.encoder.layer.6.output.LayerNorm.weight', 'roberta.encoder.layer.6.output.LayerNorm.bias', 'roberta.encoder.layer.7.attention.self.query.weight', 'roberta.encoder.layer.7.attention.self.query.bias', 'roberta.encoder.layer.7.attention.self.key.weight', 'roberta.encoder.layer.7.attention.self.key.bias', 'roberta.encoder.layer.7.attention.self.value.weight', 'roberta.encoder.layer.7.attention.self.value.bias', 'roberta.encoder.layer.7.attention.output.dense.weight', 'roberta.encoder.layer.7.attention.output.dense.bias', 'roberta.encoder.layer.7.attention.output.LayerNorm.weight', 'roberta.encoder.layer.7.attention.output.LayerNorm.bias', 'roberta.encoder.layer.7.intermediate.dense.weight', 'roberta.encoder.layer.7.intermediate.dense.bias', 'roberta.encoder.layer.7.output.dense.weight', 'roberta.encoder.layer.7.output.dense.bias', 'roberta.encoder.layer.7.output.LayerNorm.weight', 'roberta.encoder.layer.7.output.LayerNorm.bias', 'roberta.encoder.layer.8.attention.self.query.weight', 'roberta.encoder.layer.8.attention.self.query.bias', 'roberta.encoder.layer.8.attention.self.key.weight', 'roberta.encoder.layer.8.attention.self.key.bias', 'roberta.encoder.layer.8.attention.self.value.weight', 'roberta.encoder.layer.8.attention.self.value.bias', 'roberta.encoder.layer.8.attention.output.dense.weight', 'roberta.encoder.layer.8.attention.output.dense.bias', 'roberta.encoder.layer.8.attention.output.LayerNorm.weight', 'roberta.encoder.layer.8.attention.output.LayerNorm.bias', 'roberta.encoder.layer.8.intermediate.dense.weight', 'roberta.encoder.layer.8.intermediate.dense.bias', 'roberta.encoder.layer.8.output.dense.weight', 'roberta.encoder.layer.8.output.dense.bias', 'roberta.encoder.layer.8.output.LayerNorm.weight', 'roberta.encoder.layer.8.output.LayerNorm.bias', 'roberta.encoder.layer.9.attention.self.query.weight', 'roberta.encoder.layer.9.attention.self.query.bias', 'roberta.encoder.layer.9.attention.self.key.weight', 'roberta.encoder.layer.9.attention.self.key.bias', 'roberta.encoder.layer.9.attention.self.value.weight', 'roberta.encoder.layer.9.attention.self.value.bias', 'roberta.encoder.layer.9.attention.output.dense.weight', 'roberta.encoder.layer.9.attention.output.dense.bias', 'roberta.encoder.layer.9.attention.output.LayerNorm.weight', 'roberta.encoder.layer.9.attention.output.LayerNorm.bias', 'roberta.encoder.layer.9.intermediate.dense.weight', 'roberta.encoder.layer.9.intermediate.dense.bias', 'roberta.encoder.layer.9.output.dense.weight', 'roberta.encoder.layer.9.output.dense.bias', 'roberta.encoder.layer.9.output.LayerNorm.weight', 'roberta.encoder.layer.9.output.LayerNorm.bias', 'roberta.encoder.layer.10.attention.self.query.weight', 'roberta.encoder.layer.10.attention.self.query.bias', 'roberta.encoder.layer.10.attention.self.key.weight', 'roberta.encoder.layer.10.attention.self.key.bias', 'roberta.encoder.layer.10.attention.self.value.weight', 'roberta.encoder.layer.10.attention.self.value.bias', 'roberta.encoder.layer.10.attention.output.dense.weight', 'roberta.encoder.layer.10.attention.output.dense.bias', 'roberta.encoder.layer.10.attention.output.LayerNorm.weight', 'roberta.encoder.layer.10.attention.output.LayerNorm.bias', 'roberta.encoder.layer.10.intermediate.dense.weight', 'roberta.encoder.layer.10.intermediate.dense.bias', 'roberta.encoder.layer.10.output.dense.weight', 'roberta.encoder.layer.10.output.dense.bias', 'roberta.encoder.layer.10.output.LayerNorm.weight', 'roberta.encoder.layer.10.output.LayerNorm.bias', 
'roberta.encoder.layer.11.attention.self.query.weight', 'roberta.encoder.layer.11.attention.self.query.bias', 'roberta.encoder.layer.11.attention.self.key.weight', 'roberta.encoder.layer.11.attention.self.key.bias', 'roberta.encoder.layer.11.attention.self.value.weight', 'roberta.encoder.layer.11.attention.self.value.bias', 'roberta.encoder.layer.11.attention.output.dense.weight', 'roberta.encoder.layer.11.attention.output.dense.bias', 'roberta.encoder.layer.11.attention.output.LayerNorm.weight', 'roberta.encoder.layer.11.attention.output.LayerNorm.bias', 'roberta.encoder.layer.11.intermediate.dense.weight', 'roberta.encoder.layer.11.intermediate.dense.bias', 'roberta.encoder.layer.11.output.dense.weight', 'roberta.encoder.layer.11.output.dense.bias', 'roberta.encoder.layer.11.output.LayerNorm.weight', 'roberta.encoder.layer.11.output.LayerNorm.bias', 'roberta.embeddings.word_embeddings.weight', 'roberta.embeddings.position_embeddings.weight', 'roberta.embeddings.token_type_embeddings.weight', 'roberta.embeddings.LayerNorm.weight', 'roberta.embeddings.LayerNorm.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.decoder.bias', 'lm_head.bias'])\n", + "odict_keys(['roberta.embeddings.word_embeddings.weight', 'roberta.embeddings.position_embeddings.weight', 'roberta.embeddings.token_type_embeddings.weight', 'roberta.embeddings.LayerNorm.weight', 'roberta.embeddings.LayerNorm.bias', 'roberta.encoder.layer.0.attention.self.query.weight', 'roberta.encoder.layer.0.attention.self.query.bias', 'roberta.encoder.layer.0.attention.self.key.weight', 'roberta.encoder.layer.0.attention.self.key.bias', 'roberta.encoder.layer.0.attention.self.value.weight', 'roberta.encoder.layer.0.attention.self.value.bias', 'roberta.encoder.layer.0.attention.output.dense.weight', 'roberta.encoder.layer.0.attention.output.dense.bias', 'roberta.encoder.layer.0.attention.output.LayerNorm.weight', 'roberta.encoder.layer.0.attention.output.LayerNorm.bias', 'roberta.encoder.layer.0.intermediate.dense.weight', 'roberta.encoder.layer.0.intermediate.dense.bias', 'roberta.encoder.layer.0.output.dense.weight', 'roberta.encoder.layer.0.output.dense.bias', 'roberta.encoder.layer.0.output.LayerNorm.weight', 'roberta.encoder.layer.0.output.LayerNorm.bias', 'roberta.encoder.layer.1.attention.self.query.weight', 'roberta.encoder.layer.1.attention.self.query.bias', 'roberta.encoder.layer.1.attention.self.key.weight', 'roberta.encoder.layer.1.attention.self.key.bias', 'roberta.encoder.layer.1.attention.self.value.weight', 'roberta.encoder.layer.1.attention.self.value.bias', 'roberta.encoder.layer.1.attention.output.dense.weight', 'roberta.encoder.layer.1.attention.output.dense.bias', 'roberta.encoder.layer.1.attention.output.LayerNorm.weight', 'roberta.encoder.layer.1.attention.output.LayerNorm.bias', 'roberta.encoder.layer.1.intermediate.dense.weight', 'roberta.encoder.layer.1.intermediate.dense.bias', 'roberta.encoder.layer.1.output.dense.weight', 'roberta.encoder.layer.1.output.dense.bias', 'roberta.encoder.layer.1.output.LayerNorm.weight', 'roberta.encoder.layer.1.output.LayerNorm.bias', 'roberta.encoder.layer.2.attention.self.query.weight', 'roberta.encoder.layer.2.attention.self.query.bias', 'roberta.encoder.layer.2.attention.self.key.weight', 'roberta.encoder.layer.2.attention.self.key.bias', 'roberta.encoder.layer.2.attention.self.value.weight', 'roberta.encoder.layer.2.attention.self.value.bias', 
'roberta.encoder.layer.2.attention.output.dense.weight', 'roberta.encoder.layer.2.attention.output.dense.bias', 'roberta.encoder.layer.2.attention.output.LayerNorm.weight', 'roberta.encoder.layer.2.attention.output.LayerNorm.bias', 'roberta.encoder.layer.2.intermediate.dense.weight', 'roberta.encoder.layer.2.intermediate.dense.bias', 'roberta.encoder.layer.2.output.dense.weight', 'roberta.encoder.layer.2.output.dense.bias', 'roberta.encoder.layer.2.output.LayerNorm.weight', 'roberta.encoder.layer.2.output.LayerNorm.bias', 'roberta.encoder.layer.3.attention.self.query.weight', 'roberta.encoder.layer.3.attention.self.query.bias', 'roberta.encoder.layer.3.attention.self.key.weight', 'roberta.encoder.layer.3.attention.self.key.bias', 'roberta.encoder.layer.3.attention.self.value.weight', 'roberta.encoder.layer.3.attention.self.value.bias', 'roberta.encoder.layer.3.attention.output.dense.weight', 'roberta.encoder.layer.3.attention.output.dense.bias', 'roberta.encoder.layer.3.attention.output.LayerNorm.weight', 'roberta.encoder.layer.3.attention.output.LayerNorm.bias', 'roberta.encoder.layer.3.intermediate.dense.weight', 'roberta.encoder.layer.3.intermediate.dense.bias', 'roberta.encoder.layer.3.output.dense.weight', 'roberta.encoder.layer.3.output.dense.bias', 'roberta.encoder.layer.3.output.LayerNorm.weight', 'roberta.encoder.layer.3.output.LayerNorm.bias', 'roberta.encoder.layer.4.attention.self.query.weight', 'roberta.encoder.layer.4.attention.self.query.bias', 'roberta.encoder.layer.4.attention.self.key.weight', 'roberta.encoder.layer.4.attention.self.key.bias', 'roberta.encoder.layer.4.attention.self.value.weight', 'roberta.encoder.layer.4.attention.self.value.bias', 'roberta.encoder.layer.4.attention.output.dense.weight', 'roberta.encoder.layer.4.attention.output.dense.bias', 'roberta.encoder.layer.4.attention.output.LayerNorm.weight', 'roberta.encoder.layer.4.attention.output.LayerNorm.bias', 'roberta.encoder.layer.4.intermediate.dense.weight', 'roberta.encoder.layer.4.intermediate.dense.bias', 'roberta.encoder.layer.4.output.dense.weight', 'roberta.encoder.layer.4.output.dense.bias', 'roberta.encoder.layer.4.output.LayerNorm.weight', 'roberta.encoder.layer.4.output.LayerNorm.bias', 'roberta.encoder.layer.5.attention.self.query.weight', 'roberta.encoder.layer.5.attention.self.query.bias', 'roberta.encoder.layer.5.attention.self.key.weight', 'roberta.encoder.layer.5.attention.self.key.bias', 'roberta.encoder.layer.5.attention.self.value.weight', 'roberta.encoder.layer.5.attention.self.value.bias', 'roberta.encoder.layer.5.attention.output.dense.weight', 'roberta.encoder.layer.5.attention.output.dense.bias', 'roberta.encoder.layer.5.attention.output.LayerNorm.weight', 'roberta.encoder.layer.5.attention.output.LayerNorm.bias', 'roberta.encoder.layer.5.intermediate.dense.weight', 'roberta.encoder.layer.5.intermediate.dense.bias', 'roberta.encoder.layer.5.output.dense.weight', 'roberta.encoder.layer.5.output.dense.bias', 'roberta.encoder.layer.5.output.LayerNorm.weight', 'roberta.encoder.layer.5.output.LayerNorm.bias', 'roberta.encoder.layer.6.attention.self.query.weight', 'roberta.encoder.layer.6.attention.self.query.bias', 'roberta.encoder.layer.6.attention.self.key.weight', 'roberta.encoder.layer.6.attention.self.key.bias', 'roberta.encoder.layer.6.attention.self.value.weight', 'roberta.encoder.layer.6.attention.self.value.bias', 'roberta.encoder.layer.6.attention.output.dense.weight', 'roberta.encoder.layer.6.attention.output.dense.bias', 
'roberta.encoder.layer.6.attention.output.LayerNorm.weight', 'roberta.encoder.layer.6.attention.output.LayerNorm.bias', 'roberta.encoder.layer.6.intermediate.dense.weight', 'roberta.encoder.layer.6.intermediate.dense.bias', 'roberta.encoder.layer.6.output.dense.weight', 'roberta.encoder.layer.6.output.dense.bias', 'roberta.encoder.layer.6.output.LayerNorm.weight', 'roberta.encoder.layer.6.output.LayerNorm.bias', 'roberta.encoder.layer.7.attention.self.query.weight', 'roberta.encoder.layer.7.attention.self.query.bias', 'roberta.encoder.layer.7.attention.self.key.weight', 'roberta.encoder.layer.7.attention.self.key.bias', 'roberta.encoder.layer.7.attention.self.value.weight', 'roberta.encoder.layer.7.attention.self.value.bias', 'roberta.encoder.layer.7.attention.output.dense.weight', 'roberta.encoder.layer.7.attention.output.dense.bias', 'roberta.encoder.layer.7.attention.output.LayerNorm.weight', 'roberta.encoder.layer.7.attention.output.LayerNorm.bias', 'roberta.encoder.layer.7.intermediate.dense.weight', 'roberta.encoder.layer.7.intermediate.dense.bias', 'roberta.encoder.layer.7.output.dense.weight', 'roberta.encoder.layer.7.output.dense.bias', 'roberta.encoder.layer.7.output.LayerNorm.weight', 'roberta.encoder.layer.7.output.LayerNorm.bias', 'roberta.encoder.layer.8.attention.self.query.weight', 'roberta.encoder.layer.8.attention.self.query.bias', 'roberta.encoder.layer.8.attention.self.key.weight', 'roberta.encoder.layer.8.attention.self.key.bias', 'roberta.encoder.layer.8.attention.self.value.weight', 'roberta.encoder.layer.8.attention.self.value.bias', 'roberta.encoder.layer.8.attention.output.dense.weight', 'roberta.encoder.layer.8.attention.output.dense.bias', 'roberta.encoder.layer.8.attention.output.LayerNorm.weight', 'roberta.encoder.layer.8.attention.output.LayerNorm.bias', 'roberta.encoder.layer.8.intermediate.dense.weight', 'roberta.encoder.layer.8.intermediate.dense.bias', 'roberta.encoder.layer.8.output.dense.weight', 'roberta.encoder.layer.8.output.dense.bias', 'roberta.encoder.layer.8.output.LayerNorm.weight', 'roberta.encoder.layer.8.output.LayerNorm.bias', 'roberta.encoder.layer.9.attention.self.query.weight', 'roberta.encoder.layer.9.attention.self.query.bias', 'roberta.encoder.layer.9.attention.self.key.weight', 'roberta.encoder.layer.9.attention.self.key.bias', 'roberta.encoder.layer.9.attention.self.value.weight', 'roberta.encoder.layer.9.attention.self.value.bias', 'roberta.encoder.layer.9.attention.output.dense.weight', 'roberta.encoder.layer.9.attention.output.dense.bias', 'roberta.encoder.layer.9.attention.output.LayerNorm.weight', 'roberta.encoder.layer.9.attention.output.LayerNorm.bias', 'roberta.encoder.layer.9.intermediate.dense.weight', 'roberta.encoder.layer.9.intermediate.dense.bias', 'roberta.encoder.layer.9.output.dense.weight', 'roberta.encoder.layer.9.output.dense.bias', 'roberta.encoder.layer.9.output.LayerNorm.weight', 'roberta.encoder.layer.9.output.LayerNorm.bias', 'roberta.encoder.layer.10.attention.self.query.weight', 'roberta.encoder.layer.10.attention.self.query.bias', 'roberta.encoder.layer.10.attention.self.key.weight', 'roberta.encoder.layer.10.attention.self.key.bias', 'roberta.encoder.layer.10.attention.self.value.weight', 'roberta.encoder.layer.10.attention.self.value.bias', 'roberta.encoder.layer.10.attention.output.dense.weight', 'roberta.encoder.layer.10.attention.output.dense.bias', 'roberta.encoder.layer.10.attention.output.LayerNorm.weight', 'roberta.encoder.layer.10.attention.output.LayerNorm.bias', 
'roberta.encoder.layer.10.intermediate.dense.weight', 'roberta.encoder.layer.10.intermediate.dense.bias', 'roberta.encoder.layer.10.output.dense.weight', 'roberta.encoder.layer.10.output.dense.bias', 'roberta.encoder.layer.10.output.LayerNorm.weight', 'roberta.encoder.layer.10.output.LayerNorm.bias', 'roberta.encoder.layer.11.attention.self.query.weight', 'roberta.encoder.layer.11.attention.self.query.bias', 'roberta.encoder.layer.11.attention.self.key.weight', 'roberta.encoder.layer.11.attention.self.key.bias', 'roberta.encoder.layer.11.attention.self.value.weight', 'roberta.encoder.layer.11.attention.self.value.bias', 'roberta.encoder.layer.11.attention.output.dense.weight', 'roberta.encoder.layer.11.attention.output.dense.bias', 'roberta.encoder.layer.11.attention.output.LayerNorm.weight', 'roberta.encoder.layer.11.attention.output.LayerNorm.bias', 'roberta.encoder.layer.11.intermediate.dense.weight', 'roberta.encoder.layer.11.intermediate.dense.bias', 'roberta.encoder.layer.11.output.dense.weight', 'roberta.encoder.layer.11.output.dense.bias', 'roberta.encoder.layer.11.output.LayerNorm.weight', 'roberta.encoder.layer.11.output.LayerNorm.bias', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.decoder.bias'])\n", + "success1\n", + "success2\n", + "0.0\n" + ] + } + ], + "source": [ + "# Initializing RobertaForMaskedLM\n", + "my_mlm = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, output_hidden_states=True, key=key_funcs)\n", + "state_mlm = my_mlm.to_state_dict()\n", + "\n", + "state_mlm = {k: torch.from_numpy(np.array(v)) for k, v in state_mlm.items()}\n", + "state_mlm[\"lm_head.bias\"] = torch.zeros(hf_config.vocab_size)\n", + "# print(state_mlm[w_str])\n", + "\n", + "hf_mlm = hf_roberta.RobertaForMaskedLM(hf_config)\n", + "hf_mlm.load_state_dict(state_mlm, strict=True, assign=True)\n", + "# print(hf_mlm.state_dict()[w_str])\n", + "\n", + "print(check_dicts(state_mlm, hf_mlm.state_dict()))" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "def test_RobertaModel_Output(key_run, ids = False):\n", + " if ids:\n", + " my_output = my_model(input_ids = input_ids, attention_mask=mask, key=key_run)\n", + " hf_output = hf_model(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False, output_hidden_states=True)\n", + " else:\n", + " my_output = my_model(input_embeds = input_embeds, attention_mask=mask, key=key_run)\n", + " hf_output = hf_model(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False, output_hidden_states=True)\n", + "\n", + " return my_output, hf_output" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "my_output_ids, hf_output_ids = test_RobertaModel_Output(key_model_run, ids=True)\n", + "my_output_embeds, hf_output_embeds = test_RobertaModel_Output(key_model_run, ids=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model_out: acc: 1.0 \t norms: (tensor(888.5331), tensor(888.5331)) \t diffs: 6.946668236196274e-07\n", + "pool_out: acc: 1.0 \t norms: (tensor(24.6133), tensor(24.6133)) \t diffs: 4.2906631847472454e-07\n", + "intermediates:\n", + "acc: 1.0 \t norms: (tensor(888.5315), tensor(888.5314)) \t diffs: 1.6926360046909394e-07\n", + "acc: 1.0 \t norms: (tensor(888.5301), 
tensor(888.5301)) \t diffs: 2.57225963196106e-07\n", + "acc: 1.0 \t norms: (tensor(888.5310), tensor(888.5311)) \t diffs: 3.1373855335914413e-07\n", + "acc: 1.0 \t norms: (tensor(888.5312), tensor(888.5312)) \t diffs: 3.6955742643840495e-07\n", + "acc: 1.0 \t norms: (tensor(888.5304), tensor(888.5304)) \t diffs: 4.1312114262836985e-07\n", + "acc: 1.0 \t norms: (tensor(888.5308), tensor(888.5308)) \t diffs: 4.612424220340472e-07\n", + "acc: 1.0 \t norms: (tensor(888.5308), tensor(888.5308)) \t diffs: 5.164657750356128e-07\n", + "acc: 1.0 \t norms: (tensor(888.5295), tensor(888.5295)) \t diffs: 5.412561563389318e-07\n", + "acc: 1.0 \t norms: (tensor(888.5309), tensor(888.5310)) \t diffs: 5.724053266931151e-07\n", + "acc: 1.0 \t norms: (tensor(888.5297), tensor(888.5297)) \t diffs: 5.980200512567535e-07\n", + "acc: 1.0 \t norms: (tensor(888.5317), tensor(888.5317)) \t diffs: 6.434746637751232e-07\n", + "acc: 1.0 \t norms: (tensor(888.5331), tensor(888.5331)) \t diffs: 6.946668236196274e-07\n" + ] + } + ], + "source": [ + "# RobertaModel ids\n", + "my_out, hf_out = my_output_ids[0], hf_output_ids[0]\n", + "\n", + "print(f\"model_out: {check(my_out.array, hf_out.detach())[3]}\")\n", + "\n", + "my_pool, hf_pool = my_output_ids[1], hf_output_ids[1]\n", + "\n", + "print(f\"pool_out: {check(my_pool.array, hf_pool.detach())[3]}\")\n", + "\n", + "print(\"intermediates:\")\n", + "my_ints, hf_ints = my_output_ids[2], hf_output_ids[2][1:]\n", + "\n", + "for i,j in zip(my_ints, hf_ints):\n", + " print(check(i.array,j.detach())[3])" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model_out: acc: 1.0 \t norms: (tensor(888.5307), tensor(888.5306)) \t diffs: 6.107513286224275e-07\n", + "pool_out: acc: 1.0 \t norms: (tensor(24.6094), tensor(24.6094)) \t diffs: 3.8713346839358564e-07\n", + "intermediates:\n", + "acc: 1.0 \t norms: (tensor(888.5317), tensor(888.5317)) \t diffs: 1.3876427829018212e-07\n", + "acc: 1.0 \t norms: (tensor(888.5305), tensor(888.5305)) \t diffs: 2.239140144411067e-07\n", + "acc: 1.0 \t norms: (tensor(888.5290), tensor(888.5290)) \t diffs: 2.9642876597790746e-07\n", + "acc: 1.0 \t norms: (tensor(888.5301), tensor(888.5300)) \t diffs: 3.554245893155894e-07\n", + "acc: 1.0 \t norms: (tensor(888.5311), tensor(888.5311)) \t diffs: 4.070468264671945e-07\n", + "acc: 1.0 \t norms: (tensor(888.5311), tensor(888.5311)) \t diffs: 4.3890696588277933e-07\n", + "acc: 1.0 \t norms: (tensor(888.5305), tensor(888.5304)) \t diffs: 4.87373824853421e-07\n", + "acc: 1.0 \t norms: (tensor(888.5308), tensor(888.5308)) \t diffs: 5.209108735471091e-07\n", + "acc: 1.0 \t norms: (tensor(888.5298), tensor(888.5298)) \t diffs: 5.463080583467672e-07\n", + "acc: 1.0 \t norms: (tensor(888.5311), tensor(888.5311)) \t diffs: 5.576325179390551e-07\n", + "acc: 1.0 \t norms: (tensor(888.5295), tensor(888.5294)) \t diffs: 5.659875341734733e-07\n", + "acc: 1.0 \t norms: (tensor(888.5307), tensor(888.5306)) \t diffs: 6.107513286224275e-07\n" + ] + } + ], + "source": [ + "# RobertaModel embeds\n", + "\n", + "my_out, hf_out = my_output_embeds[0], hf_output_embeds[0]\n", + "\n", + "print(f\"model_out: {check(my_out.array, hf_out.detach())[3]}\")\n", + "\n", + "my_pool, hf_pool = my_output_embeds[1], hf_output_embeds[1]\n", + "\n", + "print(f\"pool_out: {check(my_pool.array, hf_pool.detach())[3]}\")\n", + "\n", + "print(\"intermediates:\")\n", + "my_ints, hf_ints = my_output_embeds[2], hf_output_embeds[2][1:]\n", + 
"\n", + "for i,j in zip(my_ints, hf_ints):\n", + " print(check(i.array,j.detach())[3])" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "def test_RobertaForMaskedLM_Output(key_run, ids = False):\n", + " if ids:\n", + " my_output = my_mlm(input_ids = input_ids, attention_mask=mask, key=key_run)\n", + " hf_output = hf_mlm(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False, output_hidden_states=True)\n", + " else:\n", + " my_output = my_mlm(input_embeds = input_embeds, attention_mask=mask, key=key_run)\n", + " hf_output = hf_mlm(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False, output_hidden_states=True)\n", + "\n", + " return my_output, hf_output" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "my_mlm_output_ids, hf_mlm_output_ids = test_RobertaForMaskedLM_Output(key_run, ids=True)\n", + "my_mlm_output_embeds, hf_mlm_output_embeds = test_RobertaForMaskedLM_Output(key_run, ids=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mlm_out: acc: 1.0 \t norms: (tensor(7054.6812), tensor(7054.6816)) \t diffs: 7.966510224832746e-07\n", + "intermediates:\n", + "acc: 1.0 \t norms: (tensor(888.5315), tensor(888.5314)) \t diffs: 1.6926360046909394e-07\n", + "acc: 1.0 \t norms: (tensor(888.5301), tensor(888.5301)) \t diffs: 2.57225963196106e-07\n", + "acc: 1.0 \t norms: (tensor(888.5310), tensor(888.5311)) \t diffs: 3.1373855335914413e-07\n", + "acc: 1.0 \t norms: (tensor(888.5312), tensor(888.5312)) \t diffs: 3.6955742643840495e-07\n", + "acc: 1.0 \t norms: (tensor(888.5304), tensor(888.5304)) \t diffs: 4.1312114262836985e-07\n", + "acc: 1.0 \t norms: (tensor(888.5308), tensor(888.5308)) \t diffs: 4.612424220340472e-07\n", + "acc: 1.0 \t norms: (tensor(888.5308), tensor(888.5308)) \t diffs: 5.164657750356128e-07\n", + "acc: 1.0 \t norms: (tensor(888.5295), tensor(888.5295)) \t diffs: 5.412561563389318e-07\n", + "acc: 1.0 \t norms: (tensor(888.5309), tensor(888.5310)) \t diffs: 5.724053266931151e-07\n", + "acc: 1.0 \t norms: (tensor(888.5297), tensor(888.5297)) \t diffs: 5.980200512567535e-07\n", + "acc: 1.0 \t norms: (tensor(888.5317), tensor(888.5317)) \t diffs: 6.434746637751232e-07\n", + "acc: 1.0 \t norms: (tensor(888.5331), tensor(888.5331)) \t diffs: 6.946668236196274e-07\n" + ] + } + ], + "source": [ + "#Masked MLM ids\n", + "my_out, hf_out = my_mlm_output_ids[0], hf_mlm_output_ids[0]\n", + "\n", + "print(f\"mlm_out: {check(my_out.array, hf_out.detach())[3]}\")\n", + "\n", + "print(\"intermediates:\")\n", + "my_ints, hf_ints = my_mlm_output_ids[1], hf_mlm_output_ids[1][1:]\n", + "\n", + "for i,j in zip(my_ints, hf_ints):\n", + " print(check(i.array,j.detach(), precision = 0.01)[3])" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mlm_out: acc: 1.0 \t norms: (tensor(7107.9902), tensor(7107.9902)) \t diffs: 7.997662692105223e-07\n", + "intermediates:\n", + "acc: 1.0 \t norms: (tensor(888.5317), tensor(888.5317)) \t diffs: 1.3876427829018212e-07\n", + "acc: 1.0 \t norms: (tensor(888.5305), tensor(888.5305)) \t diffs: 2.239140144411067e-07\n", + "acc: 1.0 \t norms: (tensor(888.5290), tensor(888.5290)) \t diffs: 2.9642876597790746e-07\n", + "acc: 1.0 \t norms: (tensor(888.5301), tensor(888.5300)) 
\t diffs: 3.554245893155894e-07\n", + "acc: 1.0 \t norms: (tensor(888.5311), tensor(888.5311)) \t diffs: 4.070468264671945e-07\n", + "acc: 1.0 \t norms: (tensor(888.5311), tensor(888.5311)) \t diffs: 4.3890696588277933e-07\n", + "acc: 1.0 \t norms: (tensor(888.5305), tensor(888.5304)) \t diffs: 4.87373824853421e-07\n", + "acc: 1.0 \t norms: (tensor(888.5308), tensor(888.5308)) \t diffs: 5.209108735471091e-07\n", + "acc: 1.0 \t norms: (tensor(888.5298), tensor(888.5298)) \t diffs: 5.463080583467672e-07\n", + "acc: 1.0 \t norms: (tensor(888.5311), tensor(888.5311)) \t diffs: 5.576325179390551e-07\n", + "acc: 1.0 \t norms: (tensor(888.5295), tensor(888.5294)) \t diffs: 5.659875341734733e-07\n", + "acc: 1.0 \t norms: (tensor(888.5307), tensor(888.5306)) \t diffs: 6.107513286224275e-07\n" + ] + } + ], + "source": [ + "#Masked MLM embeds\n", + "my_out, hf_out = my_mlm_output_embeds[0], hf_mlm_output_embeds[0]\n", + "\n", + "print(f\"mlm_out: {check(my_out.array, hf_out.detach())[3]}\")\n", + "\n", + "print(\"intermediates:\")\n", + "my_ints, hf_ints = my_mlm_output_embeds[1], hf_mlm_output_embeds[1][1:]\n", + "\n", + "for i,j in zip(my_ints, hf_ints):\n", + " print(check(i.array,j.detach())[3])" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "# # Testing RobertaModel\n", + "# my_model = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=False, key=key_rob)\n", + "# state_model = my_model.to_state_dict()\n", + "\n", + "# state_model = {k: torch.from_numpy(np.array(v)) for k, v in state_model.items()}\n", + "\n", + "# hf_model = hf_roberta.RobertaModel(hf_config, add_pooling_layer=False)\n", + "# hf_model.load_state_dict(state_model, strict=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dict_keys(['dense.weight', 'dense.bias', 'layer_norm.weight', 'layer_norm.bias', 'decoder.weight', 'decoder.bias', 'bias'])\n", + "odict_keys(['bias', 'dense.weight', 'dense.bias', 'layer_norm.weight', 'layer_norm.bias', 'decoder.weight', 'decoder.bias'])\n", + "success1\n", + "success2\n", + "0.0\n" + ] + } + ], + "source": [ + "# Testing RobertaLMHead\n", + "my_head = my_roberta.RobertaLMHead.init(Vocab, my_config, key=key_lm)\n", + "state_head = my_head.to_state_dict()\n", + "\n", + "state_head = {k: torch.from_numpy(np.array(v)) for k, v in state_head.items()}\n", + "state_head[\"bias\"] = torch.zeros(hf_config.vocab_size)\n", + "\n", + "hf_head = hf_roberta.RobertaLMHead(hf_config)\n", + "hf_head.load_state_dict(state_head, strict=True)\n", + "\n", + "print(check_dicts(state_head, hf_head.state_dict()))" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "my_output_mlm, hf_output_mlm = test_RobertaForMaskedLM_Output(key_run, ids = False)\n", + "\n", + "my_output_model, hf_output_model = test_RobertaModel_Output(key_model_run, ids = False)\n", + "my_output = my_head(my_output_model[0], key=key_lm_run)\n", + "hf_output = hf_head(hf_output_model[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "# k_rob, k_lm = jrandom.split(key, 2)\n", + "\n", + "# # MLM\n", + "# my_output_mlm = my_mlm(input_embeds = input_embeds, attention_mask=mask, key=key)\n", + "# hf_output_mlm = hf_mlm(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n", + "\n", + "# # Model + LM\n", + 
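"# (reference) manually chain RobertaModel -> RobertaLMHead; later cells show this matches RobertaForMaskedLM with diffs of 0.0\n", + 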
"\n", + "# my_output_model = my_model(input_embeds = input_embeds, attention_mask=mask, key=k_rob)\n", + "# my_output = my_head(my_output_model[0], key=k_lm)\n", + "\n", + "# hf_output_model = hf_model(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n", + "# hf_output = hf_head(hf_output_model[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RobertaModel: acc: 1.0 \t norms: (tensor(888.5307), tensor(888.5306)) \t diffs: 6.107513286224275e-07\n", + "Roberta Model + LM head: acc: 1.0 \t norms: (tensor(7107.9902), tensor(7107.9902)) \t diffs: 7.997662692105223e-07\n", + "MLM: acc: 1.0 \t norms: (tensor(7107.9902), tensor(7107.9902)) \t diffs: 7.997662692105223e-07\n", + "my RobertaModel + LM head vs MLM: acc: 1.0 \t norms: (tensor(7107.9902), tensor(7107.9902)) \t diffs: 0.0\n", + "hf RobertaModel + LM head vs MLM: acc: 1.0 \t norms: (tensor(7107.9902), tensor(7107.9902)) \t diffs: 0.0\n" + ] + } + ], + "source": [ + "# embeds\n", + "print(f\"RobertaModel: {check(my_output_model[0].array, hf_output_model[0].detach())[3]}\")\n", + "print(f\"Roberta Model + LM head: {check(my_output.array, hf_output.detach())[3]}\")\n", + "print(f\"MLM: {check(my_output_mlm[0].array, hf_output_mlm[0].detach())[3]}\")\n", + "\n", + "print(f\"my RobertaModel + LM head vs MLM: {check(my_output.array, my_output_mlm[0].array)[3]}\")\n", + "print(f\"hf RobertaModel + LM head vs MLM: {check(hf_output.detach(), hf_output_mlm[0].detach())[3]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "my_output_mlm, hf_output_mlm = test_RobertaForMaskedLM_Output(key_run, ids = True)\n", + "\n", + "my_output_model, hf_output_model = test_RobertaModel_Output(key_model_run, ids = True)\n", + "my_output = my_head(my_output_model[0], key=key_lm_run)\n", + "hf_output = hf_head(hf_output_model[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "# # MLM\n", + "# my_output_mlm = my_mlm(input_ids = input_ids, attention_mask=mask, key=key)\n", + "# hf_output_mlm = hf_mlm(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", + "\n", + "# # Model + LM\n", + "# my_output_model = my_model(input_ids = input_ids, attention_mask=mask, key=k_rob)\n", + "# my_output = my_head(my_output_model[0], key=k_lm)\n", + "\n", + "# hf_output_model = hf_model(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n", + "# hf_output = hf_head(hf_output_model[0])\n", + "\n", + "# # Notes\n", + "# # embeds works between hf and my, and within model->head and mlm \n", + "# # ids does not work between hf and my for mlm or within hf for model->head and mlm - so hf mlm is doing something weird.'''" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RobertaModel: acc: 1.0 \t norms: (tensor(888.5331), tensor(888.5331)) \t diffs: 6.946668236196274e-07\n", + "Roberta Model + LM head: acc: 1.0 \t norms: (tensor(7054.6812), tensor(7054.6816)) \t diffs: 7.966510224832746e-07\n", + "MLM: acc: 1.0 \t norms: (tensor(7054.6812), tensor(7054.6816)) \t diffs: 7.966510224832746e-07\n", + "my RobertaModel + LM head vs MLM: acc: 1.0 \t norms: (tensor(7054.6812), tensor(7054.6812)) \t diffs: 0.0\n", + "hf RobertaModel + LM head vs MLM: acc: 1.0 \t norms: 
(tensor(7054.6816), tensor(7054.6816)) \t diffs: 0.0\n" + ] + } + ], + "source": [ + "# ids\n", + "print(f\"RobertaModel: {check(my_output_model[0].array, hf_output_model[0].detach())[3]}\")\n", + "print(f\"Roberta Model + LM head: {check(my_output.array, hf_output.detach())[3]}\")\n", + "print(f\"MLM: {check(my_output_mlm[0].array, hf_output_mlm[0].detach())[3]}\")\n", + "\n", + "print(f\"my RobertaModel + LM head vs MLM: {check(my_output.array, my_output_mlm[0].array)[3]}\")\n", + "print(f\"hf RobertaModel + LM head vs MLM: {check(hf_output.detach(), hf_output_mlm[0].detach())[3]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'stop' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[33], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[43mstop\u001b[49m\n", + "\u001b[1;31mNameError\u001b[0m: name 'stop' is not defined" + ] + } + ], + "source": [ + "stop" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [], + "source": [ + "# Load pretrained weights from hf\n", + "hf_model = hf_roberta.RobertaModel.from_pretrained(\"roberta-base\")\n", + "state_model = hf_model.state_dict()\n", + "\n", + "state_model = {k: np.array(v) for k, v in state_model.items()}\n", + "\n", + "hf_config = hf_model.config\n", + "hf_config.hidden_dropout_prob = 0\n", + "hf_config.attention_probs_dropout_prob = 0\n", + "my_config = my_roberta.RobertaConfig.from_hf_config(hf_config)\n", + "\n", + "my_model = my_roberta.RobertaModel.init(Vocab, my_config, output_hidden_states=True, key=key)\n", + "my_model = my_model.from_state_dict(state_model)" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dict_keys(['encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.1.attention.self.query.bias', 'encoder.layer.1.attention.self.key.weight', 'encoder.layer.1.attention.self.key.bias', 'encoder.layer.1.attention.self.value.weight', 'encoder.layer.1.attention.self.value.bias', 'encoder.layer.1.attention.output.dense.weight', 'encoder.layer.1.attention.output.dense.bias', 'encoder.layer.1.attention.output.LayerNorm.weight', 'encoder.layer.1.attention.output.LayerNorm.bias', 'encoder.layer.1.intermediate.dense.weight', 'encoder.layer.1.intermediate.dense.bias', 'encoder.layer.1.output.dense.weight', 'encoder.layer.1.output.dense.bias', 'encoder.layer.1.output.LayerNorm.weight', 'encoder.layer.1.output.LayerNorm.bias', 'encoder.layer.2.attention.self.query.weight', 
'encoder.layer.2.attention.self.query.bias', 'encoder.layer.2.attention.self.key.weight', 'encoder.layer.2.attention.self.key.bias', 'encoder.layer.2.attention.self.value.weight', 'encoder.layer.2.attention.self.value.bias', 'encoder.layer.2.attention.output.dense.weight', 'encoder.layer.2.attention.output.dense.bias', 'encoder.layer.2.attention.output.LayerNorm.weight', 'encoder.layer.2.attention.output.LayerNorm.bias', 'encoder.layer.2.intermediate.dense.weight', 'encoder.layer.2.intermediate.dense.bias', 'encoder.layer.2.output.dense.weight', 'encoder.layer.2.output.dense.bias', 'encoder.layer.2.output.LayerNorm.weight', 'encoder.layer.2.output.LayerNorm.bias', 'encoder.layer.3.attention.self.query.weight', 'encoder.layer.3.attention.self.query.bias', 'encoder.layer.3.attention.self.key.weight', 'encoder.layer.3.attention.self.key.bias', 'encoder.layer.3.attention.self.value.weight', 'encoder.layer.3.attention.self.value.bias', 'encoder.layer.3.attention.output.dense.weight', 'encoder.layer.3.attention.output.dense.bias', 'encoder.layer.3.attention.output.LayerNorm.weight', 'encoder.layer.3.attention.output.LayerNorm.bias', 'encoder.layer.3.intermediate.dense.weight', 'encoder.layer.3.intermediate.dense.bias', 'encoder.layer.3.output.dense.weight', 'encoder.layer.3.output.dense.bias', 'encoder.layer.3.output.LayerNorm.weight', 'encoder.layer.3.output.LayerNorm.bias', 'encoder.layer.4.attention.self.query.weight', 'encoder.layer.4.attention.self.query.bias', 'encoder.layer.4.attention.self.key.weight', 'encoder.layer.4.attention.self.key.bias', 'encoder.layer.4.attention.self.value.weight', 'encoder.layer.4.attention.self.value.bias', 'encoder.layer.4.attention.output.dense.weight', 'encoder.layer.4.attention.output.dense.bias', 'encoder.layer.4.attention.output.LayerNorm.weight', 'encoder.layer.4.attention.output.LayerNorm.bias', 'encoder.layer.4.intermediate.dense.weight', 'encoder.layer.4.intermediate.dense.bias', 'encoder.layer.4.output.dense.weight', 'encoder.layer.4.output.dense.bias', 'encoder.layer.4.output.LayerNorm.weight', 'encoder.layer.4.output.LayerNorm.bias', 'encoder.layer.5.attention.self.query.weight', 'encoder.layer.5.attention.self.query.bias', 'encoder.layer.5.attention.self.key.weight', 'encoder.layer.5.attention.self.key.bias', 'encoder.layer.5.attention.self.value.weight', 'encoder.layer.5.attention.self.value.bias', 'encoder.layer.5.attention.output.dense.weight', 'encoder.layer.5.attention.output.dense.bias', 'encoder.layer.5.attention.output.LayerNorm.weight', 'encoder.layer.5.attention.output.LayerNorm.bias', 'encoder.layer.5.intermediate.dense.weight', 'encoder.layer.5.intermediate.dense.bias', 'encoder.layer.5.output.dense.weight', 'encoder.layer.5.output.dense.bias', 'encoder.layer.5.output.LayerNorm.weight', 'encoder.layer.5.output.LayerNorm.bias', 'encoder.layer.6.attention.self.query.weight', 'encoder.layer.6.attention.self.query.bias', 'encoder.layer.6.attention.self.key.weight', 'encoder.layer.6.attention.self.key.bias', 'encoder.layer.6.attention.self.value.weight', 'encoder.layer.6.attention.self.value.bias', 'encoder.layer.6.attention.output.dense.weight', 'encoder.layer.6.attention.output.dense.bias', 'encoder.layer.6.attention.output.LayerNorm.weight', 'encoder.layer.6.attention.output.LayerNorm.bias', 'encoder.layer.6.intermediate.dense.weight', 'encoder.layer.6.intermediate.dense.bias', 'encoder.layer.6.output.dense.weight', 'encoder.layer.6.output.dense.bias', 'encoder.layer.6.output.LayerNorm.weight', 'encoder.layer.6.output.LayerNorm.bias', 
'encoder.layer.7.attention.self.query.weight', 'encoder.layer.7.attention.self.query.bias', 'encoder.layer.7.attention.self.key.weight', 'encoder.layer.7.attention.self.key.bias', 'encoder.layer.7.attention.self.value.weight', 'encoder.layer.7.attention.self.value.bias', 'encoder.layer.7.attention.output.dense.weight', 'encoder.layer.7.attention.output.dense.bias', 'encoder.layer.7.attention.output.LayerNorm.weight', 'encoder.layer.7.attention.output.LayerNorm.bias', 'encoder.layer.7.intermediate.dense.weight', 'encoder.layer.7.intermediate.dense.bias', 'encoder.layer.7.output.dense.weight', 'encoder.layer.7.output.dense.bias', 'encoder.layer.7.output.LayerNorm.weight', 'encoder.layer.7.output.LayerNorm.bias', 'encoder.layer.8.attention.self.query.weight', 'encoder.layer.8.attention.self.query.bias', 'encoder.layer.8.attention.self.key.weight', 'encoder.layer.8.attention.self.key.bias', 'encoder.layer.8.attention.self.value.weight', 'encoder.layer.8.attention.self.value.bias', 'encoder.layer.8.attention.output.dense.weight', 'encoder.layer.8.attention.output.dense.bias', 'encoder.layer.8.attention.output.LayerNorm.weight', 'encoder.layer.8.attention.output.LayerNorm.bias', 'encoder.layer.8.intermediate.dense.weight', 'encoder.layer.8.intermediate.dense.bias', 'encoder.layer.8.output.dense.weight', 'encoder.layer.8.output.dense.bias', 'encoder.layer.8.output.LayerNorm.weight', 'encoder.layer.8.output.LayerNorm.bias', 'encoder.layer.9.attention.self.query.weight', 'encoder.layer.9.attention.self.query.bias', 'encoder.layer.9.attention.self.key.weight', 'encoder.layer.9.attention.self.key.bias', 'encoder.layer.9.attention.self.value.weight', 'encoder.layer.9.attention.self.value.bias', 'encoder.layer.9.attention.output.dense.weight', 'encoder.layer.9.attention.output.dense.bias', 'encoder.layer.9.attention.output.LayerNorm.weight', 'encoder.layer.9.attention.output.LayerNorm.bias', 'encoder.layer.9.intermediate.dense.weight', 'encoder.layer.9.intermediate.dense.bias', 'encoder.layer.9.output.dense.weight', 'encoder.layer.9.output.dense.bias', 'encoder.layer.9.output.LayerNorm.weight', 'encoder.layer.9.output.LayerNorm.bias', 'encoder.layer.10.attention.self.query.weight', 'encoder.layer.10.attention.self.query.bias', 'encoder.layer.10.attention.self.key.weight', 'encoder.layer.10.attention.self.key.bias', 'encoder.layer.10.attention.self.value.weight', 'encoder.layer.10.attention.self.value.bias', 'encoder.layer.10.attention.output.dense.weight', 'encoder.layer.10.attention.output.dense.bias', 'encoder.layer.10.attention.output.LayerNorm.weight', 'encoder.layer.10.attention.output.LayerNorm.bias', 'encoder.layer.10.intermediate.dense.weight', 'encoder.layer.10.intermediate.dense.bias', 'encoder.layer.10.output.dense.weight', 'encoder.layer.10.output.dense.bias', 'encoder.layer.10.output.LayerNorm.weight', 'encoder.layer.10.output.LayerNorm.bias', 'encoder.layer.11.attention.self.query.weight', 'encoder.layer.11.attention.self.query.bias', 'encoder.layer.11.attention.self.key.weight', 'encoder.layer.11.attention.self.key.bias', 'encoder.layer.11.attention.self.value.weight', 'encoder.layer.11.attention.self.value.bias', 'encoder.layer.11.attention.output.dense.weight', 'encoder.layer.11.attention.output.dense.bias', 'encoder.layer.11.attention.output.LayerNorm.weight', 'encoder.layer.11.attention.output.LayerNorm.bias', 'encoder.layer.11.intermediate.dense.weight', 'encoder.layer.11.intermediate.dense.bias', 'encoder.layer.11.output.dense.weight', 'encoder.layer.11.output.dense.bias', 
'encoder.layer.11.output.LayerNorm.weight', 'encoder.layer.11.output.LayerNorm.bias', 'embeddings.word_embeddings.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.LayerNorm.weight', 'embeddings.LayerNorm.bias', 'pooler.dense.weight', 'pooler.dense.bias'])\n", + "odict_keys(['embeddings.word_embeddings.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.LayerNorm.weight', 'embeddings.LayerNorm.bias', 'encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.1.attention.self.query.bias', 'encoder.layer.1.attention.self.key.weight', 'encoder.layer.1.attention.self.key.bias', 'encoder.layer.1.attention.self.value.weight', 'encoder.layer.1.attention.self.value.bias', 'encoder.layer.1.attention.output.dense.weight', 'encoder.layer.1.attention.output.dense.bias', 'encoder.layer.1.attention.output.LayerNorm.weight', 'encoder.layer.1.attention.output.LayerNorm.bias', 'encoder.layer.1.intermediate.dense.weight', 'encoder.layer.1.intermediate.dense.bias', 'encoder.layer.1.output.dense.weight', 'encoder.layer.1.output.dense.bias', 'encoder.layer.1.output.LayerNorm.weight', 'encoder.layer.1.output.LayerNorm.bias', 'encoder.layer.2.attention.self.query.weight', 'encoder.layer.2.attention.self.query.bias', 'encoder.layer.2.attention.self.key.weight', 'encoder.layer.2.attention.self.key.bias', 'encoder.layer.2.attention.self.value.weight', 'encoder.layer.2.attention.self.value.bias', 'encoder.layer.2.attention.output.dense.weight', 'encoder.layer.2.attention.output.dense.bias', 'encoder.layer.2.attention.output.LayerNorm.weight', 'encoder.layer.2.attention.output.LayerNorm.bias', 'encoder.layer.2.intermediate.dense.weight', 'encoder.layer.2.intermediate.dense.bias', 'encoder.layer.2.output.dense.weight', 'encoder.layer.2.output.dense.bias', 'encoder.layer.2.output.LayerNorm.weight', 'encoder.layer.2.output.LayerNorm.bias', 'encoder.layer.3.attention.self.query.weight', 'encoder.layer.3.attention.self.query.bias', 'encoder.layer.3.attention.self.key.weight', 'encoder.layer.3.attention.self.key.bias', 'encoder.layer.3.attention.self.value.weight', 'encoder.layer.3.attention.self.value.bias', 'encoder.layer.3.attention.output.dense.weight', 'encoder.layer.3.attention.output.dense.bias', 'encoder.layer.3.attention.output.LayerNorm.weight', 'encoder.layer.3.attention.output.LayerNorm.bias', 'encoder.layer.3.intermediate.dense.weight', 'encoder.layer.3.intermediate.dense.bias', 'encoder.layer.3.output.dense.weight', 'encoder.layer.3.output.dense.bias', 'encoder.layer.3.output.LayerNorm.weight', 'encoder.layer.3.output.LayerNorm.bias', 'encoder.layer.4.attention.self.query.weight', 'encoder.layer.4.attention.self.query.bias', 'encoder.layer.4.attention.self.key.weight', 
'encoder.layer.4.attention.self.key.bias', 'encoder.layer.4.attention.self.value.weight', 'encoder.layer.4.attention.self.value.bias', 'encoder.layer.4.attention.output.dense.weight', 'encoder.layer.4.attention.output.dense.bias', 'encoder.layer.4.attention.output.LayerNorm.weight', 'encoder.layer.4.attention.output.LayerNorm.bias', 'encoder.layer.4.intermediate.dense.weight', 'encoder.layer.4.intermediate.dense.bias', 'encoder.layer.4.output.dense.weight', 'encoder.layer.4.output.dense.bias', 'encoder.layer.4.output.LayerNorm.weight', 'encoder.layer.4.output.LayerNorm.bias', 'encoder.layer.5.attention.self.query.weight', 'encoder.layer.5.attention.self.query.bias', 'encoder.layer.5.attention.self.key.weight', 'encoder.layer.5.attention.self.key.bias', 'encoder.layer.5.attention.self.value.weight', 'encoder.layer.5.attention.self.value.bias', 'encoder.layer.5.attention.output.dense.weight', 'encoder.layer.5.attention.output.dense.bias', 'encoder.layer.5.attention.output.LayerNorm.weight', 'encoder.layer.5.attention.output.LayerNorm.bias', 'encoder.layer.5.intermediate.dense.weight', 'encoder.layer.5.intermediate.dense.bias', 'encoder.layer.5.output.dense.weight', 'encoder.layer.5.output.dense.bias', 'encoder.layer.5.output.LayerNorm.weight', 'encoder.layer.5.output.LayerNorm.bias', 'encoder.layer.6.attention.self.query.weight', 'encoder.layer.6.attention.self.query.bias', 'encoder.layer.6.attention.self.key.weight', 'encoder.layer.6.attention.self.key.bias', 'encoder.layer.6.attention.self.value.weight', 'encoder.layer.6.attention.self.value.bias', 'encoder.layer.6.attention.output.dense.weight', 'encoder.layer.6.attention.output.dense.bias', 'encoder.layer.6.attention.output.LayerNorm.weight', 'encoder.layer.6.attention.output.LayerNorm.bias', 'encoder.layer.6.intermediate.dense.weight', 'encoder.layer.6.intermediate.dense.bias', 'encoder.layer.6.output.dense.weight', 'encoder.layer.6.output.dense.bias', 'encoder.layer.6.output.LayerNorm.weight', 'encoder.layer.6.output.LayerNorm.bias', 'encoder.layer.7.attention.self.query.weight', 'encoder.layer.7.attention.self.query.bias', 'encoder.layer.7.attention.self.key.weight', 'encoder.layer.7.attention.self.key.bias', 'encoder.layer.7.attention.self.value.weight', 'encoder.layer.7.attention.self.value.bias', 'encoder.layer.7.attention.output.dense.weight', 'encoder.layer.7.attention.output.dense.bias', 'encoder.layer.7.attention.output.LayerNorm.weight', 'encoder.layer.7.attention.output.LayerNorm.bias', 'encoder.layer.7.intermediate.dense.weight', 'encoder.layer.7.intermediate.dense.bias', 'encoder.layer.7.output.dense.weight', 'encoder.layer.7.output.dense.bias', 'encoder.layer.7.output.LayerNorm.weight', 'encoder.layer.7.output.LayerNorm.bias', 'encoder.layer.8.attention.self.query.weight', 'encoder.layer.8.attention.self.query.bias', 'encoder.layer.8.attention.self.key.weight', 'encoder.layer.8.attention.self.key.bias', 'encoder.layer.8.attention.self.value.weight', 'encoder.layer.8.attention.self.value.bias', 'encoder.layer.8.attention.output.dense.weight', 'encoder.layer.8.attention.output.dense.bias', 'encoder.layer.8.attention.output.LayerNorm.weight', 'encoder.layer.8.attention.output.LayerNorm.bias', 'encoder.layer.8.intermediate.dense.weight', 'encoder.layer.8.intermediate.dense.bias', 'encoder.layer.8.output.dense.weight', 'encoder.layer.8.output.dense.bias', 'encoder.layer.8.output.LayerNorm.weight', 'encoder.layer.8.output.LayerNorm.bias', 'encoder.layer.9.attention.self.query.weight', 'encoder.layer.9.attention.self.query.bias', 
'encoder.layer.9.attention.self.key.weight', 'encoder.layer.9.attention.self.key.bias', 'encoder.layer.9.attention.self.value.weight', 'encoder.layer.9.attention.self.value.bias', 'encoder.layer.9.attention.output.dense.weight', 'encoder.layer.9.attention.output.dense.bias', 'encoder.layer.9.attention.output.LayerNorm.weight', 'encoder.layer.9.attention.output.LayerNorm.bias', 'encoder.layer.9.intermediate.dense.weight', 'encoder.layer.9.intermediate.dense.bias', 'encoder.layer.9.output.dense.weight', 'encoder.layer.9.output.dense.bias', 'encoder.layer.9.output.LayerNorm.weight', 'encoder.layer.9.output.LayerNorm.bias', 'encoder.layer.10.attention.self.query.weight', 'encoder.layer.10.attention.self.query.bias', 'encoder.layer.10.attention.self.key.weight', 'encoder.layer.10.attention.self.key.bias', 'encoder.layer.10.attention.self.value.weight', 'encoder.layer.10.attention.self.value.bias', 'encoder.layer.10.attention.output.dense.weight', 'encoder.layer.10.attention.output.dense.bias', 'encoder.layer.10.attention.output.LayerNorm.weight', 'encoder.layer.10.attention.output.LayerNorm.bias', 'encoder.layer.10.intermediate.dense.weight', 'encoder.layer.10.intermediate.dense.bias', 'encoder.layer.10.output.dense.weight', 'encoder.layer.10.output.dense.bias', 'encoder.layer.10.output.LayerNorm.weight', 'encoder.layer.10.output.LayerNorm.bias', 'encoder.layer.11.attention.self.query.weight', 'encoder.layer.11.attention.self.query.bias', 'encoder.layer.11.attention.self.key.weight', 'encoder.layer.11.attention.self.key.bias', 'encoder.layer.11.attention.self.value.weight', 'encoder.layer.11.attention.self.value.bias', 'encoder.layer.11.attention.output.dense.weight', 'encoder.layer.11.attention.output.dense.bias', 'encoder.layer.11.attention.output.LayerNorm.weight', 'encoder.layer.11.attention.output.LayerNorm.bias', 'encoder.layer.11.intermediate.dense.weight', 'encoder.layer.11.intermediate.dense.bias', 'encoder.layer.11.output.dense.weight', 'encoder.layer.11.output.dense.bias', 'encoder.layer.11.output.LayerNorm.weight', 'encoder.layer.11.output.LayerNorm.bias', 'pooler.dense.weight', 'pooler.dense.bias'])\n", + "success1\n", + "success2\n", + "Total differences: 0.0\n" + ] + } + ], + "source": [ + "# Check weights loaded correctly\n", + "my_dict = my_model.to_state_dict()\n", + "hf_dict = hf_model.state_dict()\n", + "\n", + "print(f\"Total differences: {check_dicts(my_dict, hf_dict)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + 
"c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n" + ] + } + ], + "source": [ + "my_output_model = my_model(input_embeds = input_embeds, attention_mask=mask, key=key)\n", + "hf_output_model = hf_model(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: acc: 1.0 \t norms: (tensor(443.3569), tensor(443.3564)) \t diffs: 1.405485477334878e-06\n" + ] + } + ], + "source": [ + "print(f\"Model: {check(my_output_model[0].array, hf_output_model[0].detach())[3]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " 
warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n" + ] + } + ], + "source": [ + "my_output_model = my_model(input_ids = input_ids, attention_mask=mask, key=key)\n", + "hf_output_model = hf_model(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: acc: 1.0 \t norms: (tensor(431.9545), tensor(431.9545)) \t diffs: 1.3033360346526024e-06\n" + ] + } + ], + "source": [ + "print(f\"Model: {check(my_output_model[0].array, hf_output_model[0].detach())[3]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [], + "source": [ + "hf_mlm = hf_roberta.RobertaForMaskedLM.from_pretrained(hf_model_str)\n", + "state_mlm = hf_mlm.state_dict()\n", + "\n", + "state_mlm = {k: np.array(v) for k, v in state_mlm.items()}\n", + "\n", + "hf_config = hf_mlm.config\n", + "hf_config.hidden_dropout_prob = 0\n", + "hf_config.attention_probs_dropout_prob = 0\n", + "my_config = my_roberta.RobertaConfig.from_hf_config(hf_config)\n", + "\n", + "my_mlm = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=key)\n", + "my_mlm = my_mlm.from_state_dict(state_mlm)" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [ 
+ { + "name": "stdout", + "output_type": "stream", + "text": [ + "dict_keys(['roberta.encoder.layer.0.attention.self.query.weight', 'roberta.encoder.layer.0.attention.self.query.bias', 'roberta.encoder.layer.0.attention.self.key.weight', 'roberta.encoder.layer.0.attention.self.key.bias', 'roberta.encoder.layer.0.attention.self.value.weight', 'roberta.encoder.layer.0.attention.self.value.bias', 'roberta.encoder.layer.0.attention.output.dense.weight', 'roberta.encoder.layer.0.attention.output.dense.bias', 'roberta.encoder.layer.0.attention.output.LayerNorm.weight', 'roberta.encoder.layer.0.attention.output.LayerNorm.bias', 'roberta.encoder.layer.0.intermediate.dense.weight', 'roberta.encoder.layer.0.intermediate.dense.bias', 'roberta.encoder.layer.0.output.dense.weight', 'roberta.encoder.layer.0.output.dense.bias', 'roberta.encoder.layer.0.output.LayerNorm.weight', 'roberta.encoder.layer.0.output.LayerNorm.bias', 'roberta.encoder.layer.1.attention.self.query.weight', 'roberta.encoder.layer.1.attention.self.query.bias', 'roberta.encoder.layer.1.attention.self.key.weight', 'roberta.encoder.layer.1.attention.self.key.bias', 'roberta.encoder.layer.1.attention.self.value.weight', 'roberta.encoder.layer.1.attention.self.value.bias', 'roberta.encoder.layer.1.attention.output.dense.weight', 'roberta.encoder.layer.1.attention.output.dense.bias', 'roberta.encoder.layer.1.attention.output.LayerNorm.weight', 'roberta.encoder.layer.1.attention.output.LayerNorm.bias', 'roberta.encoder.layer.1.intermediate.dense.weight', 'roberta.encoder.layer.1.intermediate.dense.bias', 'roberta.encoder.layer.1.output.dense.weight', 'roberta.encoder.layer.1.output.dense.bias', 'roberta.encoder.layer.1.output.LayerNorm.weight', 'roberta.encoder.layer.1.output.LayerNorm.bias', 'roberta.encoder.layer.2.attention.self.query.weight', 'roberta.encoder.layer.2.attention.self.query.bias', 'roberta.encoder.layer.2.attention.self.key.weight', 'roberta.encoder.layer.2.attention.self.key.bias', 'roberta.encoder.layer.2.attention.self.value.weight', 'roberta.encoder.layer.2.attention.self.value.bias', 'roberta.encoder.layer.2.attention.output.dense.weight', 'roberta.encoder.layer.2.attention.output.dense.bias', 'roberta.encoder.layer.2.attention.output.LayerNorm.weight', 'roberta.encoder.layer.2.attention.output.LayerNorm.bias', 'roberta.encoder.layer.2.intermediate.dense.weight', 'roberta.encoder.layer.2.intermediate.dense.bias', 'roberta.encoder.layer.2.output.dense.weight', 'roberta.encoder.layer.2.output.dense.bias', 'roberta.encoder.layer.2.output.LayerNorm.weight', 'roberta.encoder.layer.2.output.LayerNorm.bias', 'roberta.encoder.layer.3.attention.self.query.weight', 'roberta.encoder.layer.3.attention.self.query.bias', 'roberta.encoder.layer.3.attention.self.key.weight', 'roberta.encoder.layer.3.attention.self.key.bias', 'roberta.encoder.layer.3.attention.self.value.weight', 'roberta.encoder.layer.3.attention.self.value.bias', 'roberta.encoder.layer.3.attention.output.dense.weight', 'roberta.encoder.layer.3.attention.output.dense.bias', 'roberta.encoder.layer.3.attention.output.LayerNorm.weight', 'roberta.encoder.layer.3.attention.output.LayerNorm.bias', 'roberta.encoder.layer.3.intermediate.dense.weight', 'roberta.encoder.layer.3.intermediate.dense.bias', 'roberta.encoder.layer.3.output.dense.weight', 'roberta.encoder.layer.3.output.dense.bias', 'roberta.encoder.layer.3.output.LayerNorm.weight', 'roberta.encoder.layer.3.output.LayerNorm.bias', 'roberta.encoder.layer.4.attention.self.query.weight', 
'roberta.encoder.layer.4.attention.self.query.bias', 'roberta.encoder.layer.4.attention.self.key.weight', 'roberta.encoder.layer.4.attention.self.key.bias', 'roberta.encoder.layer.4.attention.self.value.weight', 'roberta.encoder.layer.4.attention.self.value.bias', 'roberta.encoder.layer.4.attention.output.dense.weight', 'roberta.encoder.layer.4.attention.output.dense.bias', 'roberta.encoder.layer.4.attention.output.LayerNorm.weight', 'roberta.encoder.layer.4.attention.output.LayerNorm.bias', 'roberta.encoder.layer.4.intermediate.dense.weight', 'roberta.encoder.layer.4.intermediate.dense.bias', 'roberta.encoder.layer.4.output.dense.weight', 'roberta.encoder.layer.4.output.dense.bias', 'roberta.encoder.layer.4.output.LayerNorm.weight', 'roberta.encoder.layer.4.output.LayerNorm.bias', 'roberta.encoder.layer.5.attention.self.query.weight', 'roberta.encoder.layer.5.attention.self.query.bias', 'roberta.encoder.layer.5.attention.self.key.weight', 'roberta.encoder.layer.5.attention.self.key.bias', 'roberta.encoder.layer.5.attention.self.value.weight', 'roberta.encoder.layer.5.attention.self.value.bias', 'roberta.encoder.layer.5.attention.output.dense.weight', 'roberta.encoder.layer.5.attention.output.dense.bias', 'roberta.encoder.layer.5.attention.output.LayerNorm.weight', 'roberta.encoder.layer.5.attention.output.LayerNorm.bias', 'roberta.encoder.layer.5.intermediate.dense.weight', 'roberta.encoder.layer.5.intermediate.dense.bias', 'roberta.encoder.layer.5.output.dense.weight', 'roberta.encoder.layer.5.output.dense.bias', 'roberta.encoder.layer.5.output.LayerNorm.weight', 'roberta.encoder.layer.5.output.LayerNorm.bias', 'roberta.encoder.layer.6.attention.self.query.weight', 'roberta.encoder.layer.6.attention.self.query.bias', 'roberta.encoder.layer.6.attention.self.key.weight', 'roberta.encoder.layer.6.attention.self.key.bias', 'roberta.encoder.layer.6.attention.self.value.weight', 'roberta.encoder.layer.6.attention.self.value.bias', 'roberta.encoder.layer.6.attention.output.dense.weight', 'roberta.encoder.layer.6.attention.output.dense.bias', 'roberta.encoder.layer.6.attention.output.LayerNorm.weight', 'roberta.encoder.layer.6.attention.output.LayerNorm.bias', 'roberta.encoder.layer.6.intermediate.dense.weight', 'roberta.encoder.layer.6.intermediate.dense.bias', 'roberta.encoder.layer.6.output.dense.weight', 'roberta.encoder.layer.6.output.dense.bias', 'roberta.encoder.layer.6.output.LayerNorm.weight', 'roberta.encoder.layer.6.output.LayerNorm.bias', 'roberta.encoder.layer.7.attention.self.query.weight', 'roberta.encoder.layer.7.attention.self.query.bias', 'roberta.encoder.layer.7.attention.self.key.weight', 'roberta.encoder.layer.7.attention.self.key.bias', 'roberta.encoder.layer.7.attention.self.value.weight', 'roberta.encoder.layer.7.attention.self.value.bias', 'roberta.encoder.layer.7.attention.output.dense.weight', 'roberta.encoder.layer.7.attention.output.dense.bias', 'roberta.encoder.layer.7.attention.output.LayerNorm.weight', 'roberta.encoder.layer.7.attention.output.LayerNorm.bias', 'roberta.encoder.layer.7.intermediate.dense.weight', 'roberta.encoder.layer.7.intermediate.dense.bias', 'roberta.encoder.layer.7.output.dense.weight', 'roberta.encoder.layer.7.output.dense.bias', 'roberta.encoder.layer.7.output.LayerNorm.weight', 'roberta.encoder.layer.7.output.LayerNorm.bias', 'roberta.encoder.layer.8.attention.self.query.weight', 'roberta.encoder.layer.8.attention.self.query.bias', 'roberta.encoder.layer.8.attention.self.key.weight', 'roberta.encoder.layer.8.attention.self.key.bias', 
'roberta.encoder.layer.8.attention.self.value.weight', 'roberta.encoder.layer.8.attention.self.value.bias', 'roberta.encoder.layer.8.attention.output.dense.weight', 'roberta.encoder.layer.8.attention.output.dense.bias', 'roberta.encoder.layer.8.attention.output.LayerNorm.weight', 'roberta.encoder.layer.8.attention.output.LayerNorm.bias', 'roberta.encoder.layer.8.intermediate.dense.weight', 'roberta.encoder.layer.8.intermediate.dense.bias', 'roberta.encoder.layer.8.output.dense.weight', 'roberta.encoder.layer.8.output.dense.bias', 'roberta.encoder.layer.8.output.LayerNorm.weight', 'roberta.encoder.layer.8.output.LayerNorm.bias', 'roberta.encoder.layer.9.attention.self.query.weight', 'roberta.encoder.layer.9.attention.self.query.bias', 'roberta.encoder.layer.9.attention.self.key.weight', 'roberta.encoder.layer.9.attention.self.key.bias', 'roberta.encoder.layer.9.attention.self.value.weight', 'roberta.encoder.layer.9.attention.self.value.bias', 'roberta.encoder.layer.9.attention.output.dense.weight', 'roberta.encoder.layer.9.attention.output.dense.bias', 'roberta.encoder.layer.9.attention.output.LayerNorm.weight', 'roberta.encoder.layer.9.attention.output.LayerNorm.bias', 'roberta.encoder.layer.9.intermediate.dense.weight', 'roberta.encoder.layer.9.intermediate.dense.bias', 'roberta.encoder.layer.9.output.dense.weight', 'roberta.encoder.layer.9.output.dense.bias', 'roberta.encoder.layer.9.output.LayerNorm.weight', 'roberta.encoder.layer.9.output.LayerNorm.bias', 'roberta.encoder.layer.10.attention.self.query.weight', 'roberta.encoder.layer.10.attention.self.query.bias', 'roberta.encoder.layer.10.attention.self.key.weight', 'roberta.encoder.layer.10.attention.self.key.bias', 'roberta.encoder.layer.10.attention.self.value.weight', 'roberta.encoder.layer.10.attention.self.value.bias', 'roberta.encoder.layer.10.attention.output.dense.weight', 'roberta.encoder.layer.10.attention.output.dense.bias', 'roberta.encoder.layer.10.attention.output.LayerNorm.weight', 'roberta.encoder.layer.10.attention.output.LayerNorm.bias', 'roberta.encoder.layer.10.intermediate.dense.weight', 'roberta.encoder.layer.10.intermediate.dense.bias', 'roberta.encoder.layer.10.output.dense.weight', 'roberta.encoder.layer.10.output.dense.bias', 'roberta.encoder.layer.10.output.LayerNorm.weight', 'roberta.encoder.layer.10.output.LayerNorm.bias', 'roberta.encoder.layer.11.attention.self.query.weight', 'roberta.encoder.layer.11.attention.self.query.bias', 'roberta.encoder.layer.11.attention.self.key.weight', 'roberta.encoder.layer.11.attention.self.key.bias', 'roberta.encoder.layer.11.attention.self.value.weight', 'roberta.encoder.layer.11.attention.self.value.bias', 'roberta.encoder.layer.11.attention.output.dense.weight', 'roberta.encoder.layer.11.attention.output.dense.bias', 'roberta.encoder.layer.11.attention.output.LayerNorm.weight', 'roberta.encoder.layer.11.attention.output.LayerNorm.bias', 'roberta.encoder.layer.11.intermediate.dense.weight', 'roberta.encoder.layer.11.intermediate.dense.bias', 'roberta.encoder.layer.11.output.dense.weight', 'roberta.encoder.layer.11.output.dense.bias', 'roberta.encoder.layer.11.output.LayerNorm.weight', 'roberta.encoder.layer.11.output.LayerNorm.bias', 'roberta.embeddings.word_embeddings.weight', 'roberta.embeddings.position_embeddings.weight', 'roberta.embeddings.token_type_embeddings.weight', 'roberta.embeddings.LayerNorm.weight', 'roberta.embeddings.LayerNorm.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 
'lm_head.decoder.weight', 'lm_head.decoder.bias'])\n", + "odict_keys(['roberta.embeddings.word_embeddings.weight', 'roberta.embeddings.position_embeddings.weight', 'roberta.embeddings.token_type_embeddings.weight', 'roberta.embeddings.LayerNorm.weight', 'roberta.embeddings.LayerNorm.bias', 'roberta.encoder.layer.0.attention.self.query.weight', 'roberta.encoder.layer.0.attention.self.query.bias', 'roberta.encoder.layer.0.attention.self.key.weight', 'roberta.encoder.layer.0.attention.self.key.bias', 'roberta.encoder.layer.0.attention.self.value.weight', 'roberta.encoder.layer.0.attention.self.value.bias', 'roberta.encoder.layer.0.attention.output.dense.weight', 'roberta.encoder.layer.0.attention.output.dense.bias', 'roberta.encoder.layer.0.attention.output.LayerNorm.weight', 'roberta.encoder.layer.0.attention.output.LayerNorm.bias', 'roberta.encoder.layer.0.intermediate.dense.weight', 'roberta.encoder.layer.0.intermediate.dense.bias', 'roberta.encoder.layer.0.output.dense.weight', 'roberta.encoder.layer.0.output.dense.bias', 'roberta.encoder.layer.0.output.LayerNorm.weight', 'roberta.encoder.layer.0.output.LayerNorm.bias', 'roberta.encoder.layer.1.attention.self.query.weight', 'roberta.encoder.layer.1.attention.self.query.bias', 'roberta.encoder.layer.1.attention.self.key.weight', 'roberta.encoder.layer.1.attention.self.key.bias', 'roberta.encoder.layer.1.attention.self.value.weight', 'roberta.encoder.layer.1.attention.self.value.bias', 'roberta.encoder.layer.1.attention.output.dense.weight', 'roberta.encoder.layer.1.attention.output.dense.bias', 'roberta.encoder.layer.1.attention.output.LayerNorm.weight', 'roberta.encoder.layer.1.attention.output.LayerNorm.bias', 'roberta.encoder.layer.1.intermediate.dense.weight', 'roberta.encoder.layer.1.intermediate.dense.bias', 'roberta.encoder.layer.1.output.dense.weight', 'roberta.encoder.layer.1.output.dense.bias', 'roberta.encoder.layer.1.output.LayerNorm.weight', 'roberta.encoder.layer.1.output.LayerNorm.bias', 'roberta.encoder.layer.2.attention.self.query.weight', 'roberta.encoder.layer.2.attention.self.query.bias', 'roberta.encoder.layer.2.attention.self.key.weight', 'roberta.encoder.layer.2.attention.self.key.bias', 'roberta.encoder.layer.2.attention.self.value.weight', 'roberta.encoder.layer.2.attention.self.value.bias', 'roberta.encoder.layer.2.attention.output.dense.weight', 'roberta.encoder.layer.2.attention.output.dense.bias', 'roberta.encoder.layer.2.attention.output.LayerNorm.weight', 'roberta.encoder.layer.2.attention.output.LayerNorm.bias', 'roberta.encoder.layer.2.intermediate.dense.weight', 'roberta.encoder.layer.2.intermediate.dense.bias', 'roberta.encoder.layer.2.output.dense.weight', 'roberta.encoder.layer.2.output.dense.bias', 'roberta.encoder.layer.2.output.LayerNorm.weight', 'roberta.encoder.layer.2.output.LayerNorm.bias', 'roberta.encoder.layer.3.attention.self.query.weight', 'roberta.encoder.layer.3.attention.self.query.bias', 'roberta.encoder.layer.3.attention.self.key.weight', 'roberta.encoder.layer.3.attention.self.key.bias', 'roberta.encoder.layer.3.attention.self.value.weight', 'roberta.encoder.layer.3.attention.self.value.bias', 'roberta.encoder.layer.3.attention.output.dense.weight', 'roberta.encoder.layer.3.attention.output.dense.bias', 'roberta.encoder.layer.3.attention.output.LayerNorm.weight', 'roberta.encoder.layer.3.attention.output.LayerNorm.bias', 'roberta.encoder.layer.3.intermediate.dense.weight', 'roberta.encoder.layer.3.intermediate.dense.bias', 'roberta.encoder.layer.3.output.dense.weight', 
'roberta.encoder.layer.3.output.dense.bias', 'roberta.encoder.layer.3.output.LayerNorm.weight', 'roberta.encoder.layer.3.output.LayerNorm.bias', 'roberta.encoder.layer.4.attention.self.query.weight', 'roberta.encoder.layer.4.attention.self.query.bias', 'roberta.encoder.layer.4.attention.self.key.weight', 'roberta.encoder.layer.4.attention.self.key.bias', 'roberta.encoder.layer.4.attention.self.value.weight', 'roberta.encoder.layer.4.attention.self.value.bias', 'roberta.encoder.layer.4.attention.output.dense.weight', 'roberta.encoder.layer.4.attention.output.dense.bias', 'roberta.encoder.layer.4.attention.output.LayerNorm.weight', 'roberta.encoder.layer.4.attention.output.LayerNorm.bias', 'roberta.encoder.layer.4.intermediate.dense.weight', 'roberta.encoder.layer.4.intermediate.dense.bias', 'roberta.encoder.layer.4.output.dense.weight', 'roberta.encoder.layer.4.output.dense.bias', 'roberta.encoder.layer.4.output.LayerNorm.weight', 'roberta.encoder.layer.4.output.LayerNorm.bias', 'roberta.encoder.layer.5.attention.self.query.weight', 'roberta.encoder.layer.5.attention.self.query.bias', 'roberta.encoder.layer.5.attention.self.key.weight', 'roberta.encoder.layer.5.attention.self.key.bias', 'roberta.encoder.layer.5.attention.self.value.weight', 'roberta.encoder.layer.5.attention.self.value.bias', 'roberta.encoder.layer.5.attention.output.dense.weight', 'roberta.encoder.layer.5.attention.output.dense.bias', 'roberta.encoder.layer.5.attention.output.LayerNorm.weight', 'roberta.encoder.layer.5.attention.output.LayerNorm.bias', 'roberta.encoder.layer.5.intermediate.dense.weight', 'roberta.encoder.layer.5.intermediate.dense.bias', 'roberta.encoder.layer.5.output.dense.weight', 'roberta.encoder.layer.5.output.dense.bias', 'roberta.encoder.layer.5.output.LayerNorm.weight', 'roberta.encoder.layer.5.output.LayerNorm.bias', 'roberta.encoder.layer.6.attention.self.query.weight', 'roberta.encoder.layer.6.attention.self.query.bias', 'roberta.encoder.layer.6.attention.self.key.weight', 'roberta.encoder.layer.6.attention.self.key.bias', 'roberta.encoder.layer.6.attention.self.value.weight', 'roberta.encoder.layer.6.attention.self.value.bias', 'roberta.encoder.layer.6.attention.output.dense.weight', 'roberta.encoder.layer.6.attention.output.dense.bias', 'roberta.encoder.layer.6.attention.output.LayerNorm.weight', 'roberta.encoder.layer.6.attention.output.LayerNorm.bias', 'roberta.encoder.layer.6.intermediate.dense.weight', 'roberta.encoder.layer.6.intermediate.dense.bias', 'roberta.encoder.layer.6.output.dense.weight', 'roberta.encoder.layer.6.output.dense.bias', 'roberta.encoder.layer.6.output.LayerNorm.weight', 'roberta.encoder.layer.6.output.LayerNorm.bias', 'roberta.encoder.layer.7.attention.self.query.weight', 'roberta.encoder.layer.7.attention.self.query.bias', 'roberta.encoder.layer.7.attention.self.key.weight', 'roberta.encoder.layer.7.attention.self.key.bias', 'roberta.encoder.layer.7.attention.self.value.weight', 'roberta.encoder.layer.7.attention.self.value.bias', 'roberta.encoder.layer.7.attention.output.dense.weight', 'roberta.encoder.layer.7.attention.output.dense.bias', 'roberta.encoder.layer.7.attention.output.LayerNorm.weight', 'roberta.encoder.layer.7.attention.output.LayerNorm.bias', 'roberta.encoder.layer.7.intermediate.dense.weight', 'roberta.encoder.layer.7.intermediate.dense.bias', 'roberta.encoder.layer.7.output.dense.weight', 'roberta.encoder.layer.7.output.dense.bias', 'roberta.encoder.layer.7.output.LayerNorm.weight', 'roberta.encoder.layer.7.output.LayerNorm.bias', 
'roberta.encoder.layer.8.attention.self.query.weight', 'roberta.encoder.layer.8.attention.self.query.bias', 'roberta.encoder.layer.8.attention.self.key.weight', 'roberta.encoder.layer.8.attention.self.key.bias', 'roberta.encoder.layer.8.attention.self.value.weight', 'roberta.encoder.layer.8.attention.self.value.bias', 'roberta.encoder.layer.8.attention.output.dense.weight', 'roberta.encoder.layer.8.attention.output.dense.bias', 'roberta.encoder.layer.8.attention.output.LayerNorm.weight', 'roberta.encoder.layer.8.attention.output.LayerNorm.bias', 'roberta.encoder.layer.8.intermediate.dense.weight', 'roberta.encoder.layer.8.intermediate.dense.bias', 'roberta.encoder.layer.8.output.dense.weight', 'roberta.encoder.layer.8.output.dense.bias', 'roberta.encoder.layer.8.output.LayerNorm.weight', 'roberta.encoder.layer.8.output.LayerNorm.bias', 'roberta.encoder.layer.9.attention.self.query.weight', 'roberta.encoder.layer.9.attention.self.query.bias', 'roberta.encoder.layer.9.attention.self.key.weight', 'roberta.encoder.layer.9.attention.self.key.bias', 'roberta.encoder.layer.9.attention.self.value.weight', 'roberta.encoder.layer.9.attention.self.value.bias', 'roberta.encoder.layer.9.attention.output.dense.weight', 'roberta.encoder.layer.9.attention.output.dense.bias', 'roberta.encoder.layer.9.attention.output.LayerNorm.weight', 'roberta.encoder.layer.9.attention.output.LayerNorm.bias', 'roberta.encoder.layer.9.intermediate.dense.weight', 'roberta.encoder.layer.9.intermediate.dense.bias', 'roberta.encoder.layer.9.output.dense.weight', 'roberta.encoder.layer.9.output.dense.bias', 'roberta.encoder.layer.9.output.LayerNorm.weight', 'roberta.encoder.layer.9.output.LayerNorm.bias', 'roberta.encoder.layer.10.attention.self.query.weight', 'roberta.encoder.layer.10.attention.self.query.bias', 'roberta.encoder.layer.10.attention.self.key.weight', 'roberta.encoder.layer.10.attention.self.key.bias', 'roberta.encoder.layer.10.attention.self.value.weight', 'roberta.encoder.layer.10.attention.self.value.bias', 'roberta.encoder.layer.10.attention.output.dense.weight', 'roberta.encoder.layer.10.attention.output.dense.bias', 'roberta.encoder.layer.10.attention.output.LayerNorm.weight', 'roberta.encoder.layer.10.attention.output.LayerNorm.bias', 'roberta.encoder.layer.10.intermediate.dense.weight', 'roberta.encoder.layer.10.intermediate.dense.bias', 'roberta.encoder.layer.10.output.dense.weight', 'roberta.encoder.layer.10.output.dense.bias', 'roberta.encoder.layer.10.output.LayerNorm.weight', 'roberta.encoder.layer.10.output.LayerNorm.bias', 'roberta.encoder.layer.11.attention.self.query.weight', 'roberta.encoder.layer.11.attention.self.query.bias', 'roberta.encoder.layer.11.attention.self.key.weight', 'roberta.encoder.layer.11.attention.self.key.bias', 'roberta.encoder.layer.11.attention.self.value.weight', 'roberta.encoder.layer.11.attention.self.value.bias', 'roberta.encoder.layer.11.attention.output.dense.weight', 'roberta.encoder.layer.11.attention.output.dense.bias', 'roberta.encoder.layer.11.attention.output.LayerNorm.weight', 'roberta.encoder.layer.11.attention.output.LayerNorm.bias', 'roberta.encoder.layer.11.intermediate.dense.weight', 'roberta.encoder.layer.11.intermediate.dense.bias', 'roberta.encoder.layer.11.output.dense.weight', 'roberta.encoder.layer.11.output.dense.bias', 'roberta.encoder.layer.11.output.LayerNorm.weight', 'roberta.encoder.layer.11.output.LayerNorm.bias', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 
'lm_head.decoder.weight', 'lm_head.decoder.bias'])\n", + "success1\n", + "fail2\n", + "['lm_head.bias']\n", + "Total differences: 0.0\n" + ] + } + ], + "source": [ + "# Check weights loaded correctly\n", + "my_dict = my_mlm.to_state_dict()\n", + "hf_dict = hf_mlm.state_dict()\n", + "\n", + "print(f\"Total differences: {check_dicts(my_dict, hf_dict)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([-0.0972, -0.0294, 0.4988, ..., -0.0312, -0.0312, -1.0000])" + ] + }, + "execution_count": 59, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hf_dict['lm_head.bias']" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([-0.0972, -0.0294, 0.4988, ..., -0.0312, -0.0312, -1.0000])" + ] + }, + "execution_count": 60, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hf_dict['lm_head.decoder.bias']" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but 
different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n" + ] + } + ], + "source": [ + "my_output_mlm = my_mlm(input_embeds = input_embeds, attention_mask=mask, key=key)\n", + "hf_output_mlm = hf_mlm(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MLM: acc: 1.0 \t norms: (tensor(33433.4062), tensor(33433.3945)) \t diffs: 1.7561978893354535e-05\n" + ] + } + ], + "source": [ + "print(f\"MLM: {check(my_output_mlm[0].array, hf_output_mlm[0].detach())[3]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but 
different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n", + "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n", + " warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n" + ] + } + ], + "source": [ + "my_output_mlm = my_mlm(input_ids = input_ids, attention_mask=mask, key=key)\n", + "hf_output_mlm = hf_mlm(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MLM: acc: 1.0 \t norms: (tensor(28814.2480), tensor(28814.2168)) \t diffs: 1.45375252031954e-05\n" + ] + } + ], + "source": [ + "print(f\"MLM: {check(my_output_mlm[0].array, hf_output_mlm[0].detach())[3]}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}
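
Note on the parity checks recorded in the notebook above: check_dicts and check are helper functions defined in an earlier cell and do not appear in this hunk. The sketch below is only an illustration of what they are assumed to do, inferred from the printed results ("Total differences: 0.0" and the "acc / norms / diffs" summary); the names, signatures, tolerances, and reshape handling here are assumptions, not the notebook's actual code.

import numpy as np
import torch

def check_dicts(my_dict, hf_dict):
    # Assumed behaviour: total absolute difference over parameters present in both
    # state dicts; keys that exist only on one side (e.g. the tied lm_head.decoder.*
    # duplicates shown above) are simply skipped.
    total = 0.0
    for name, hf_param in hf_dict.items():
        if name not in my_dict:
            continue
        mine = np.asarray(my_dict[name], dtype=np.float32)
        theirs = hf_param.detach().cpu().numpy().astype(np.float32)
        total += float(np.abs(mine - theirs.reshape(mine.shape)).sum())
    return total

def check(my_out, hf_out, atol=1e-4):
    # Assumed behaviour: element-wise comparison of a Levanter output array against the
    # torch reference; check(...)[3] is the summary string printed in the cells above.
    mine = torch.from_numpy(np.asarray(my_out, dtype=np.float32))
    theirs = hf_out.detach().cpu().float()
    acc = torch.isclose(mine, theirs, atol=atol).float().mean().item()
    norms = (mine.norm(), theirs.norm())
    diffs = (mine - theirs).abs().mean().item()
    return mine, theirs, acc, f"acc: {acc} \t norms: {norms} \t diffs: {diffs}"

Under these assumptions, the recorded outputs ("Total differences: 0.0", "acc: 1.0" with per-element differences around 1e-5) are consistent with the ported MLM weights and forward pass matching the Hugging Face reference for both the inputs_embeds and input_ids paths.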