diff --git a/config/roberta-tiny.yaml b/config/roberta-tiny.yaml
new file mode 100644
index 000000000..4b61ff7e4
--- /dev/null
+++ b/config/roberta-tiny.yaml
@@ -0,0 +1,39 @@
+data:
+  id: dlwh/wikitext_103_detokenized
+#  train_urls:
+#    - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
+#  validation_urls:
+#    - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
+  cache_dir: "cache/roberta-tiny"
+  tokenizer: "roberta-base"
+
+model:
+  type: roberta
+  vocab_size: 50265
+  hidden_size: 32
+  intermediate_size: 64
+  num_hidden_layers: 4
+  num_attention_heads: 2
+  max_position_embeddings: 512
+  hidden_act: "gelu"
+  hidden_dropout_prob: 0.1
+  attention_probs_dropout_prob: 0.1
+  gradient_checkpointing: true
+
+trainer:
+  tracker:
+    - type: wandb
+      project: "levanter"
+      tags: ["openwebtext", "roberta", "itest"]
+
+  mp: p=f32,c=bfloat16
+  model_axis_size: 1
+  per_device_parallelism: -1
+
+  train_batch_size: 32
+  num_train_steps: 20000
+
+optimizer:
+  learning_rate: 1E-3
+  weight_decay: 0.1
+  warmup: 0.01
\ No newline at end of file
diff --git a/config/roberta.yaml b/config/roberta.yaml
new file mode 100644
index 000000000..cea6bbb77
--- /dev/null
+++ b/config/roberta.yaml
@@ -0,0 +1,34 @@
+data:
+  id: dlwh/wikitext_103_detokenized
+  tokenizer: "roberta-base"
+
+model:
+  type: roberta
+  vocab_size: 50265
+  hidden_size: 768
+  intermediate_size: 3072
+  num_hidden_layers: 12
+  num_attention_heads: 12
+  max_position_embeddings: 512
+  hidden_act: "gelu"
+  hidden_dropout_prob: 0.1
+  attention_probs_dropout_prob: 0.1
+  gradient_checkpointing: true
+
+trainer:
+  tracker:
+    - type: wandb
+      project: "levanter"
+      tags: ["openwebtext", "roberta", "itest"]
+
+  mp: p=f32,c=bfloat16
+  model_axis_size: 1
+  per_device_parallelism: -1
+
+  train_batch_size: 32
+  num_train_steps: 20000
+
+optimizer:
+  learning_rate: 1E-3
+  weight_decay: 0.1
+  warmup: 0.01
\ No newline at end of file
diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index c29e55e83..e29cee205 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -14,6 +14,7 @@
 import equinox as eqx
 import fsspec
 import jax
+import jax.numpy as jnp
 import numpy as np
 import pyarrow as pa
 import regex
@@ -25,13 +26,11 @@
 
 from levanter.data.mixture import MixtureDataset, StopStrategy
 
 # intercept the logging nonsense here
 from levanter.logging import silence_transformer_nag  # noqa
 from levanter.models.attention import AttentionMask
-from levanter.models.lm_model import LmExample
+from levanter.models.lm_model import MaskedLmExample, LmExample
 from levanter.utils.hf_utils import num_cpus_used_by_tokenizer
 
-
 silence_transformer_nag()  # noqa
 from transformers import BatchEncoding, PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast  # noqa
 
@@ -47,7 +46,6 @@
 from levanter.shapes import NamedShapeSpec, ShapeSpec  # noqa
 from levanter.utils.jax_utils import use_cpu_device  # noqa
 
-
 logger = logging.getLogger("levanter.data.text")
 
 # TASKS:
@@ -58,6 +56,83 @@
 
 DEFAULT_IGNORE_INDEX = -100  # Mirrors pytorch's default ignore index
 
+class MaskedLmDataset(ShardableDataset[MaskedLmExample]):
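+    """
+    Wraps a dataset of token sequences and yields MaskedLmExamples with BERT/RoBERTa-style dynamic
+    masking applied on the fly.
+    """
+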
+    def __init__(
+        self,
+        dataset: ShardableDataset[np.ndarray],
+        QPos: Axis,
+        KPos: Axis,
+        mask_token_id: int,
+        mask_prob: float = 0.15,
+        noise_prob: float = 0.1,
+        key: Optional[PRNGKeyArray] = None,
+        # ignore_index: Optional[int] = DEFAULT_IGNORE_INDEX,
+    ):
+        self.dataset = dataset
+        self.QPos = QPos
+        self.KPos = KPos
+        self.mask_prob = mask_prob
+        self.noise_prob = noise_prob
+        self.key = key
+        self.mask_token_id = mask_token_id
+
+        if self.mask_prob > 0.0 and self.key is None:
+            raise ValueError("must provide key if mask_prob > 0.0")
+
+    def shard(self, shard_id: int, num_shards: int) -> "MaskedLmDataset":
+        return MaskedLmDataset(
+            self.dataset.shard(shard_id, num_shards), self.QPos, self.KPos,
+            self.mask_token_id,
+            self.mask_prob, self.noise_prob, self.key
+        )
+
+    def __iter__(self) -> Iterator[MaskedLmExample]:
+        key = self.key
+        sharding = jax.sharding.SingleDeviceSharding(jax.local_devices(backend="cpu")[0])
+
+        with use_cpu_device():
+            @functools.partial(eqx.filter_jit, out_shardings=sharding)
+            def _create_mlm_example(tokens, key):
+                tokens_array = tokens.array
+                targets = tokens_array.copy()
+
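+                # BERT/RoBERTa-style dynamic masking: select positions with probability mask_prob;
+                # of those, ~80% are replaced with mask_token_id, noise_prob are replaced with a
+                # random token id, and the remainder keep their original token.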
+                if self.mask_prob > 0:
+                    mask_key, rand_key, noise_key = jax.random.split(key, 3)
+                    mask_shape = tokens_array.shape
+                    mask = jax.random.bernoulli(mask_key, self.mask_prob, mask_shape)
+
+                    rand = jax.random.uniform(rand_key, mask_shape)
+                    mask_token = jnp.where(rand < 0.8, self.mask_token_id, tokens_array)
+                    random_tokens = jax.random.randint(noise_key, mask_shape, 0, tokens_array.max() + 1)
+                    mask_token = jnp.where((rand >= 0.8) & (rand < 0.8 + self.noise_prob), random_tokens, mask_token)
+                    masked_tokens = jnp.where(mask, mask_token, tokens_array)
+
+                    # Set targets to the original tokens where mask is True, otherwise set to mask_token_id
+                    targets = jnp.where(mask, tokens_array, self.mask_token_id)
+
+                    masked_tokens_named = hax.named(masked_tokens, self.QPos)
+                    targets_named = hax.named(targets, self.QPos)
+
+                    attn_mask_shape = (tokens_array.shape[0], tokens_array.shape[0])
+                    attn_mask = hax.named(jnp.ones(attn_mask_shape, dtype=jnp.bool_), (self.QPos, self.KPos))
+
+                    example = MaskedLmExample.masked_lm(tokens=masked_tokens_named, targets=targets_named, mask_token_id=self.mask_token_id, attn_mask=attn_mask)
+                else:
+                    targets_named = hax.named(targets, self.QPos)
+                    attn_mask_shape = (tokens_array.shape[0], tokens_array.shape[0])
+                    attn_mask = hax.named(jnp.ones(attn_mask_shape, dtype=jnp.bool_), (self.QPos, self.KPos))
+
+                    example = MaskedLmExample.masked_lm(tokens=tokens, targets=targets_named, mask_token_id=self.mask_token_id, attn_mask=attn_mask)
+
+                return example
+
+        for tokens in self.dataset:
+            if key is not None:
+                key, example_key = jax.random.split(key)
+            else:
+                example_key = None
+            tokens_array = jnp.array(tokens)
+            tokens_named = hax.named(tokens_array, self.QPos)
+            example = _create_mlm_example(tokens_named, example_key)
+            yield example
+
+
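+# Minimal usage sketch (assumes an existing ShardableDataset of token arrays `token_ds`, a tokenizer
+# with a mask token, and Pos/KeyPos axes taken from the model config):
+#
+#   mlm_ds = MaskedLmDataset(token_ds, Pos, KeyPos, mask_token_id=tokenizer.mask_token_id,
+#                            mask_prob=0.15, key=jax.random.PRNGKey(0))
+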
 
 class CausalLmDataset(ShardableDataset[LmExample]):
     def __init__(
@@ -89,7 +164,6 @@ def __iter__(self) -> Iterator[LmExample]:
         sharding = jax.sharding.SingleDeviceSharding(jax.local_devices(backend="cpu")[0])
 
         with use_cpu_device():
-
             @functools.partial(eqx.filter_jit, out_shardings=sharding)
             def _create_lm_example(tokens, key):
                 tokens = hax.named(tokens, self.QPos)
@@ -97,10 +171,6 @@ def _create_lm_example(tokens, key):
                 example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id)
 
                 if self.fcm_prob > 0:
                     # masks for attention
                     # We support forgetful causal masking (FCM) which is a technique that improves training speed by
                     # randomly masking out some of the context. This is a bit like dropout, but it's applied to the attention
                     # mask instead of the activations. It's described in https://arxiv.org/abs/2210.13432
                     assert self.key is not None
                     this_key, key = jax.random.split(key)
                     fcm_mask = hax.nn.attention.forgetful_causal_mask(self.KPos, self.fcm_prob, key=this_key)
@@ -114,6 +184,7 @@ def _create_lm_example(tokens, key):
                 yield example
 
 
 class TokenSeqDataset(ShardableDataset[np.ndarray]):
     """
     A dataset that yields sequences of tokens of fixed length from a TokenizedDocumentCache.
diff --git a/src/levanter/main/train_mlm.py b/src/levanter/main/train_mlm.py
new file mode 100644
index 000000000..a54baf13d
--- /dev/null
+++ b/src/levanter/main/train_mlm.py
@@ -0,0 +1,219 @@
+# train_mlm.py
+
+import dataclasses
+import gc
+import logging
+import os
+from dataclasses import dataclass, field
+from typing import Optional, Union
+
+import jax
+import jax.random as jrandom
+
+import haliax as hax
+from haliax import Axis
+from haliax.partitioning import named_jit, round_axis_for_partitioning
+
+import levanter
+from levanter import callbacks
+from levanter.compat.hf_checkpoints import HFCompatConfig, save_hf_checkpoint_callback
+from levanter.data.text import MaskedLmDataset, LMDatasetConfig, LMMixtureDatasetConfig
+from levanter.models.lm_model import LmConfig, MaskedLmExample
+from levanter.models.roberta import RobertaConfig
+from levanter.optim import AdamConfig, OptimizerConfig
+from levanter.trainer import Trainer, TrainerConfig
+from levanter.utils.jax_utils import parameter_count
+
+logger = logging.getLogger(__name__)
+
+@dataclass
+class TrainMlmConfig:
+    data: Union[LMDatasetConfig, LMMixtureDatasetConfig] = field(default_factory=LMDatasetConfig)
+    trainer: TrainerConfig = field(default_factory=TrainerConfig)
+    model: LmConfig = field(default_factory=RobertaConfig)
+    optimizer: OptimizerConfig = field(default_factory=AdamConfig)
+
+    # config related to continued pretraining
+    initialize_from_hf: Union[bool, str] = False
+    """if provided, this will override the model config in the config. if true, use the default hf checkpoint for this model class"""
+    use_hf_model_config: bool = False  # if true, replace the model config with the hf config from the checkpoint
+
+    # TODO: atm we don't support loading from a checkpoint that has a different tokenizer. this is a bit annoying
+    # TODO: atm you have to at least specify a levanter model config with the same type as the hf checkpoint
+
+    mlm_prob: float = 0.15  # masking probability for MLM
+    hf_save_path: Optional[str] = None
+    hf_upload: Optional[str] = None
+    hf_save_steps: int = 10000
+
+    update_hessian_steps: int = 10
+    data_seed: Optional[int] = None  # if provided, will override the data seed from the trainer
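+    # Example configs for this entry point live in config/roberta.yaml and config/roberta-tiny.yaml.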
+
+def main(config: TrainMlmConfig):
+    tokenizer = config.data.the_tokenizer
+
+    # this is some unpleasant code to allow us to initialize from a hf checkpoint. If this is your first read through,
+    # I recommend skipping it for now
+    if config.initialize_from_hf:
+        if config.trainer.initialize_from is not None:
+            raise ValueError("Cannot specify both initialize_from_hf and initialize_from")
+
+        assert isinstance(config.model, HFCompatConfig)
+        converter = config.model.hf_checkpoint_converter()
+        if hasattr(tokenizer, "vocab") and tokenizer.vocab != converter.tokenizer.vocab:
+            logger.warning("The tokenizers appear to be different. You may want to check this.")
+
+        if isinstance(config.initialize_from_hf, str):
+            converter = converter.replaced(reference_checkpoint=config.initialize_from_hf, tokenizer=tokenizer)
+        else:
+            converter = converter.replaced(tokenizer=tokenizer)
+
+        if config.use_hf_model_config:
+            # TODO: log diff of old and new config
+            # NB: gross mutability
+            config.model = converter.config_from_hf_config(converter.default_hf_config)
+    elif isinstance(config.model, HFCompatConfig):
+        converter = config.model.hf_checkpoint_converter()
+        converter = converter.replaced(tokenizer=tokenizer)
+    else:
+        converter = None
+
+    levanter.initialize(config)
+    optimizer = config.optimizer.build(config.trainer.num_train_steps)
+
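+    # The loss delegates to the model's own compute_loss (RobertaForMaskedLM.compute_loss), which
+    # scores the MaskedLmExample's targets rather than using the next-token loss from train_lm.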
+    def loss_function(
+            model,
+            example: MaskedLmExample,
+            *,
+            key=None,
+            reduction: Optional[hax.ReductionFunction] = hax.mean,
+            reduction_axis: Optional[hax.AxisSelection] = None,
+    ):
+        return model.compute_loss(example, key=key, reduction=reduction, reduction_axis=reduction_axis)
+
+
+    # Using the trainer as a context manager does 3 things:
+    # 1. Sets the device mesh
+    # 2. Sets the axis mapping (for fsdp)
+    # 3. Sets the global metrics tracker
+    with Trainer(config.trainer, optimizer, loss_function) as trainer:
+        # randomness in jax is tightly controlled by "keys" which are the states of the random number generators
+        # this makes deterministic training pretty easy
+        seed = config.trainer.seed
+        data_key, loader_key, model_key, training_key = jrandom.split(jrandom.PRNGKey(seed), 4)
+
+        if config.data_seed is not None:
+            logger.info(f"Overriding data seed with {config.data_seed}")
+            data_key = jrandom.PRNGKey(config.data_seed)
+
+        # We have two axis_mappings: one for storing the model and optimizer states, and one for compute
+        # This allows Zero-3-style parameter sharding, where we shard the parameters and optimizer state across the mesh
+        compute_axis_mapping = trainer.compute_axis_mapping
+        parameter_axis_mapping = trainer.parameter_axis_mapping
+
+        # some axes we need
+        Batch = config.trainer.TrainBatch
+        EvalBatch = config.trainer.EvalBatch
+        Pos = config.model.Pos
+        KeyPos = config.model.KeyPos
+
+        tagged_eval_datasets = config.data.tagged_eval_sets(Pos.size)
+        mask_id = tokenizer.mask_token_id
+        train_dataset = MaskedLmDataset(
+            config.data.train_set(Pos.size, key=data_key), Pos, KeyPos,
+            mask_token_id=mask_id,
+            mask_prob=config.mlm_prob, key=data_key, #ignore_index=config.data.ignore_token_id
+        )
+
+        # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to
+        # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of
+        # tokens: gpt-2 has 50257, for example. So we round up.
+        vocab_size = len(tokenizer)
+        Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), parameter_axis_mapping)
+        if vocab_size != Vocab.size:
+            logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning")
+
+        state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key))
+
+        if int(state.step) == 0:
+            # TODO: I don't love that we init the model twice, but it's not a big deal i think?
+            if config.initialize_from_hf:
+                # initialize from an hf pretrained model
+                logger.info(
+                    "No training checkpoint found. Initializing model from HF checkpoint"
+                    f" '{converter.reference_checkpoint}'"
+                )
+                # this is a bit gross, but we want to free up the memory from the model we just built
+                state = dataclasses.replace(state, model=None)
+                gc.collect()
+                model = converter.load_pretrained(
+                    config.model.model_type,
+                    config.model,
+                    axis_mapping=parameter_axis_mapping,
+                    dtype=trainer.mp.compute_dtype,
+                )
+                model = named_jit(trainer.mp.cast_to_param, parameter_axis_mapping)(model)
+                state = dataclasses.replace(state, model=model)
+            else:
+                logger.info("No checkpoint found. Starting from scratch.")
+
+        levanter.tracker.log_summary({"parameter_count": parameter_count(state.model)})
+
+        if len(tagged_eval_datasets) == 0:
+            logger.warning("No evaluation datasets provided.")
+        else:
+            masked_datasets = [
+                (MaskedLmDataset(ds, Pos, KeyPos, mask_prob=config.mlm_prob, key=data_key, mask_token_id=mask_id), tags)
+                for ds, tags in tagged_eval_datasets
+            ]
+            max_eval_examples_per_ds = config.trainer.max_eval_batches
+            if max_eval_examples_per_ds is not None:
+                max_eval_examples_per_ds *= config.trainer.eval_batch_size
+
+            cb = levanter.eval.cb_tagged_lm_evaluate(
+                EvalBatch, masked_datasets, trainer.device_mesh, compute_axis_mapping, max_eval_examples_per_ds
+            )
+            trainer.add_hook(cb, every=config.trainer.steps_per_eval)
+
+        flops_per_token = config.model.flops_per_token(vocab_size)
+        flops_per_example = 3 * flops_per_token * Pos.size if flops_per_token is not None else None
+        trainer.add_hook(
+            callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size, flops_per_example), every=1
+        )
+        if config.hf_save_path is not None:
+            full_save_path = os.path.join(config.hf_save_path, trainer.run_id)
+
+            trainer.add_hook(
+                save_hf_checkpoint_callback(full_save_path, converter, upload_to_hf=config.hf_upload or False),
+                every=config.hf_save_steps,
+            )
+
+        # visualize log probs (helper mirrored from train_lm; not currently hooked up to a callback here)
+        @named_jit(
+            in_axis_resources=parameter_axis_mapping,
+            axis_resources=compute_axis_mapping,
+            out_axis_resources=compute_axis_mapping,
+        )
+        def compute_log_probs(model, example):
+            model = trainer.mp.cast_to_compute(model)
+            logprobs = model.compute_loss(example, key=None, reduction=None)
+            # roll forward to get the loss for each predicted token
+            logprobs = hax.roll(logprobs, 1, Pos)
+            return logprobs.rearrange((EvalBatch, Pos)).array
+
+        train_loader = iter(trainer.sharded_loader(train_dataset, Batch))
+
+        if int(state.step) > 0:
+            import tqdm
+            for _ in tqdm.tqdm(range(state.step), desc="seeking data for resume"):
+                next(train_loader)
+
+        trainer.train(state, train_loader)
+
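+# Example invocation (assuming the same CLI conventions as the other levanter entry points):
+#   python -m levanter.main.train_mlm --config_path config/roberta-tiny.yaml
+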
+if __name__ == "__main__":
+    levanter.config.main(main)()
\ No newline at end of file
diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py
index 468f6a4a4..926384cab 100644
--- a/src/levanter/models/lm_model.py
+++ b/src/levanter/models/lm_model.py
@@ -3,6 +3,7 @@
 
 import draccus
 import equinox as eqx
 import jax.numpy as jnp
 from jax.random import PRNGKey
 
@@ -16,6 +17,36 @@
 LmConfigT = TypeVar("LmConfigT", bound="LmConfig")
 LmT = TypeVar("LmT", bound="LmHeadModel")
 
+class MaskedLmExample(eqx.Module):
+    tokens: hax.NamedArray
+    loss_mask: hax.NamedArray
+    attn_mask: hax.NamedArray
+    targets: Optional[hax.NamedArray] = None
+
+    @staticmethod
+    def masked_lm(
+        tokens: hax.NamedArray, targets: hax.NamedArray, attn_mask: hax.NamedArray, mask_token_id: Optional[int] = None
+    ) -> "MaskedLmExample":
+        if tokens.ndim != 1:
+            raise ValueError("tokens must be a 1D array")
+
+        if not jnp.issubdtype(tokens.dtype, jnp.integer):
+            raise ValueError("tokens must be an integer array")
+
+        if tokens.shape != targets.shape:
+            raise ValueError("tokens and targets must have the same shape")
+
+        Pos = tokens.axes[0]
+
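+        # Loss is computed only where the (masked/corrupted) input differs from the target. When
+        # mask_token_id is given, positions whose target equals mask_token_id (used as an "ignore"
+        # marker by MaskedLmDataset) are excluded as well.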
+        mask = tokens.array != targets.array
+        loss_mask = hax.named(mask.astype(jnp.float32), Pos)
+
+        if mask_token_id is not None:
+            ignore_mask = targets.array != mask_token_id
+            loss_mask = loss_mask * hax.named(ignore_mask.astype(jnp.float32), Pos)
+
+        return MaskedLmExample(tokens=tokens, targets=targets, loss_mask=loss_mask, attn_mask=attn_mask)
+
 
 class LmExample(eqx.Module):
     tokens: hax.NamedArray
@@ -34,12 +65,10 @@ def causal(
 
         Pos = tokens.axes[0]
 
         # don't predict the last token.
         if loss_mask is None:
             loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32)
 
         if ignore_id is not None:
             # we don't compute loss for any tokens matching the ignore index
             ignore_mask = hax.roll(tokens, -1, Pos) != ignore_id
             loss_mask = loss_mask * ignore_mask
 
@@ -47,7 +76,6 @@ def causal(
         return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask)
 
 
-# TODO: for some reason, mypy doesn't like the discover_packages_path argument?
 class LmConfig(draccus.PluginRegistry, abc.ABC, Generic[LmT], discover_packages_path="levanter.models"):  # type: ignore
     @property
     @abc.abstractmethod
@@ -70,12 +98,7 @@ def flops_per_token(self, vocab_size: int) -> Optional[float]:
     def build(self, Vocab: Axis, *, key: PRNGKey) -> "LmT":
         return self.model_type.init(Vocab, self, key=key)  # type: ignore
 
-
 class LmHeadModel(Generic[LmConfigT], abc.ABC):
-    """
-    Superclass for models with a language modeling head.
-    """
-
     @property
     @abc.abstractmethod
     def config(self) -> LmConfigT:
@@ -103,14 +126,11 @@ def init(cls, Vocab: Axis, config: LmConfigT, *, key: PRNGKey) -> "LmHeadModel[L
     def __call__(
         self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None
     ) -> NamedArray:
         pass
 
     @abc.abstractmethod
     def resize_vocab(self, new_size: int, key: Optional[PRNGKey] = None) -> "LmHeadModel[LmConfigT]":
-        """
-        Resizes the vocabulary of the model. Key may be provided to use random initialization, otherwise, there
-        should be some deterministic initialization of any new parameters.
-        """
         pass
 
     @property
@@ -149,3 +169,35 @@ def compute_next_token_loss(
     )
 
     return loss
diff --git a/src/levanter/models/roberta.py b/src/levanter/models/roberta.py
new file mode 100644
index 000000000..816f7f1ad
--- /dev/null
+++ b/src/levanter/models/roberta.py
@@ -0,0 +1,921 @@
+import dataclasses
+from dataclasses import dataclass
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+import jax.random as jrandom
+from jaxtyping import PRNGKeyArray
+
+import haliax as hax
+import haliax.nn as hnn
+from haliax.nn import cross_entropy_loss
+from haliax import Axis, AxisSpec, NamedArray
+from haliax.jax_utils import maybe_rng_split, named_call, shaped_rng_split
+from haliax.nn.scan import BlockSeq, Stacked
+
+from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig
+from levanter.compat.torch_serialization import (
+    StateDict,
+    StateDictSerializationMixin,
+    apply_prefix,
+    flatten_linear_layers,
+    stack_state_dict,
+    unflatten_linear_layers,
+    unstack_state_dict,
+)
+from levanter.logging import silence_transformer_nag
+from levanter.models.attention import AttentionBackend, AttentionMask, simple_attention_with_dropout
+from levanter.models.gpt2 import ACT2FN
+from levanter.models.lm_model import LmConfig, LmHeadModel, MaskedLmExample
+from levanter.types import BlockFoldable
+from levanter.utils.flop_utils import lm_flops_per_token
+
+silence_transformer_nag()
+from transformers import PretrainedConfig as HfConfig
+from transformers import RobertaConfig as HfRobertaConfig
+
+
+
+@LmConfig.register_subclass("roberta")
+@dataclass(frozen=True)
+class RobertaConfig(HFCompatConfig):
+    r"""
+
+    Adapted from HuggingFace RobertaConfig, description below
+
+
+    This is the configuration class to store the configuration of a [`RobertaModel`] or a [`TFRobertaModel`]. It is
+    used to instantiate a RoBERTa model according to the specified arguments, defining the model architecture.
+    Instantiating a configuration with the defaults will yield a similar configuration to that of the RoBERTa
+    [FacebookAI/roberta-base](https://huggingface.co/FacebookAI/roberta-base) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 50265):
+            Vocabulary size of the RoBERTa model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`].
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        max_position_embeddings (`int`, *optional*, defaults to 512):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        type_vocab_size (`int`, *optional*, defaults to 2):
+            The vocabulary size of the `token_type_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`].
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
+            Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
+            positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
+            [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
+            For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
+            with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
+        is_decoder (`bool`, *optional*, defaults to `False`):
+            Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        classifier_dropout (`float`, *optional*):
+            The dropout ratio for the classification head.
+
+    Examples:
+
+    ```python
+    >>> from transformers import RobertaConfig, RobertaModel
+
+    >>> # Initializing a RoBERTa configuration
+    >>> configuration = RobertaConfig()
+
+    >>> # Initializing a model (with random weights) from the configuration
+    >>> model = RobertaModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    vocab_size: int = 50265
+    hidden_size: int = 768
+    num_hidden_layers: int = 12
+    num_attention_heads: int = 12
+    intermediate_size: int = 3072
+    hidden_act: str = "gelu"
+    hidden_dropout_prob: float = 0.1
+    attention_probs_dropout_prob: float = 0.1
+    max_position_embeddings: int = 512
+    type_vocab_size: int = 2
+    initializer_range: float = 0.02
+    layer_norm_eps: float = 1e-12
+    pad_token_id: int = 1
+    bos_token_id: int = 0
+    eos_token_id: int = 2
+    position_embedding_type: Optional[str] = "absolute"
+    use_cache: bool = False
+    classifier_dropout: Optional[float] = None
+
+    scan_layers: bool = True
+    gradient_checkpointing: bool = True
+
+    reference_checkpoint: str = "FacebookAI/roberta-base"
+    tokenizer: Optional[str] = None
+
+    # Axes
+    Pos = property(lambda self: Axis(name="position", size=self.max_position_embeddings))
+    KeyPos = property(lambda self: self.Pos.alias("key_position"))
+    Embed = property(lambda self: Axis(name="embed", size=self.hidden_size))
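+    # EmbedAtt and FinalEmbed are aliases of Embed: a haliax array cannot carry two axes with the
+    # same name, so "embed -> embed" Linear layers (attention projections, pooler/LM-head dense)
+    # output an aliased axis and callers rename it back to Embed where needed.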
+    EmbedAtt = property(lambda self: self.Embed.alias("embed_att"))
+    FinalEmbed = property(lambda self: self.Embed.alias("final_embed"))
+    Heads = property(lambda self: Axis(name="heads", size=self.num_attention_heads))
+    Layers = property(lambda self: Axis(name="layers", size=self.num_hidden_layers))
+    Mlp = property(lambda self: Axis(name="mlp", size=self.intermediate_size))
+    HeadSize = property(lambda self: Axis(name="head_size", size=self.hidden_size // self.num_attention_heads))
+
+
+    @classmethod
+    def from_hf_config(cls, hf_config: HfConfig) -> "RobertaConfig":
+        return RobertaConfig(
+            vocab_size = hf_config.vocab_size,
+            hidden_size = hf_config.hidden_size,
+            num_hidden_layers = hf_config.num_hidden_layers,
+            num_attention_heads = hf_config.num_attention_heads,
+            intermediate_size = hf_config.intermediate_size,
+            hidden_act = hf_config.hidden_act,
+            hidden_dropout_prob= hf_config.hidden_dropout_prob,
+            attention_probs_dropout_prob = hf_config.attention_probs_dropout_prob,
+            max_position_embeddings = hf_config.max_position_embeddings,
+            type_vocab_size = hf_config.type_vocab_size,
+            initializer_range = hf_config.initializer_range,
+            layer_norm_eps = hf_config.layer_norm_eps,
+            pad_token_id = hf_config.pad_token_id,
+            bos_token_id = hf_config.bos_token_id,
+            eos_token_id = hf_config.eos_token_id,
+            position_embedding_type = hf_config.position_embedding_type,
+            use_cache = hf_config.use_cache,
+            classifier_dropout = hf_config.classifier_dropout,
+        )
+    
+    def hf_checkpoint_converter(self) -> HFCheckpointConverter["RobertaConfig"]:  # type: ignore
+        return HFCheckpointConverter(
+            self.__class__,
+            reference_checkpoint=self.reference_checkpoint,
+            trust_remote_code=True,
+            tokenizer=self.tokenizer if self.tokenizer else self.reference_checkpoint,
+            HfConfigClass=HfRobertaConfig,
+        )
+
+    def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None) -> HfRobertaConfig:
+        """Convert to HuggingFace's LlamaConfig
+
+        Args:
+            vocab_size (int, optional): Vocabulary size of the tokenizer. Defaults to 32000.
+            config_overrides (dict, optional): Overrides for the config. Defaults to None.
+
+        Returns:
+            HfRobertaConfig: HuggingFace's RobertaConfig
+        """
+
+        if config_overrides is None:
+            config_overrides = {}
+
+        return HfRobertaConfig(
+            vocab_size = vocab_size,
+            hidden_size = self.hidden_size,
+            num_hidden_layers = self.num_hidden_layers,
+            num_attention_heads = self.num_attention_heads,
+            intermediate_size = self.intermediate_size,
+            hidden_act = self.hidden_act,
+            hidden_dropout_prob = self.hidden_dropout_prob,
+            attention_probs_dropout_prob = self.attention_probs_dropout_prob,
+            max_position_embeddings = self.max_position_embeddings,
+            type_vocab_size = self.type_vocab_size,
+            initializer_range = self.initializer_range,
+            layer_norm_eps = self.layer_norm_eps,
+            pad_token_id = self.pad_token_id,
+            bos_token_id = self.bos_token_id,
+            eos_token_id = self.eos_token_id,
+            position_embedding_type = self.position_embedding_type,
+            use_cache = self.use_cache,
+            classifier_dropout = self.classifier_dropout,
+            **config_overrides,
+        )
+
+    @property
+    def model_type(self) -> Type["RobertaForMaskedLM"]:
+        return RobertaForMaskedLM
+    
+    def flops_per_token(self, vocab_size: int):
+        return lm_flops_per_token(
+            hidden_dim=self.hidden_size,
+            intermediate_dim=self.intermediate_size,
+            num_layers=self.num_hidden_layers,
+            num_kv_heads=self.num_attention_heads,
+            num_heads=self.num_attention_heads,
+            seq_len=self.max_position_embeddings,
+            vocab_size=vocab_size,
+            glu=False,
+        )
+
+class RobertaSelfAttention(eqx.Module, StateDictSerializationMixin):
+
+    config: RobertaConfig
+    Heads: Axis
+    HeadSize: Axis
+    EmbedAtt: Axis
+
+    q_proj: hnn.Linear
+    k_proj: hnn.Linear
+    v_proj: hnn.Linear
+    
+    dropout: hnn.Dropout
+    position_embedding_type: Optional[str]
+
+    Pos: Axis
+    KeyPos: Axis
+    distance_embedding: Optional[hnn.Embedding]
+
+    @staticmethod
+    def init(config: RobertaConfig, *, key) -> "RobertaSelfAttention":        
+        Embed = config.Embed
+        EmbedAtt = config.EmbedAtt
+
+        k_q, k_k, k_v, k_e = jrandom.split(key, 4)
+        q_proj = hnn.Linear.init(In=Embed, Out=EmbedAtt, key=k_q, out_first=True)
+        k_proj = hnn.Linear.init(In=Embed, Out=EmbedAtt, key=k_k, out_first=True)
+        v_proj = hnn.Linear.init(In=Embed, Out=EmbedAtt, key=k_v, out_first=True)
+
+        dropout = hnn.Dropout(config.attention_probs_dropout_prob)
+
+        distance_embedding = None
+        position_embedding_type = config.position_embedding_type
+
+        if position_embedding_type == "relative_key" or position_embedding_type == "relative_key_query":
+            RelPos = Axis("rel_pos", 2 * config.max_position_embeddings - 1)
+            distance_embedding = hnn.Embedding.init(RelPos, config.HeadSize, key=k_e)
+
+        return RobertaSelfAttention(config, config.Heads, config.HeadSize, EmbedAtt,
+                                    q_proj, k_proj, v_proj,
+                                    dropout, position_embedding_type,
+                                    config.Pos, config.KeyPos, distance_embedding,
+                                    )
+
+    def _state_dict_key_map(self) -> Dict[str, Optional[str]]:
+        return {"q_proj": "query", "k_proj": "key", "v_proj": "value"}
+
+    def _rope_scale_factor(self) -> float:
+        # hasattr for gemma and I'm feeling lazy
+        if hasattr(self.config, "rope_scaling") and self.config.rope_scaling is not None:
+            assert self.config.rope_scaling["type"] == "linear"
+            return self.config.rope_scaling["factor"]
+        return 1.0
+
+    def transpose_for_scores(self, x: NamedArray) -> NamedArray:
+        # Makes sure to have the correct output order as well
+        y = hax.rearrange(x, "... position (embed_att: heads head_size) -> ... heads position head_size", heads=self.Heads, head_size=self.HeadSize)
+        return y
+
+    @named_call
+    def __call__(
+        self,
+        hidden_states: NamedArray,
+        attention_mask: Optional[NamedArray] = None,
+        *,
+        key = None
+    ) -> Tuple[NamedArray]:
+
+        query_layer = self.transpose_for_scores(self.q_proj(hidden_states))
+        key_layer = self.transpose_for_scores(self.k_proj(hidden_states))
+        value_layer = self.transpose_for_scores(self.v_proj(hidden_states))
+
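+        # Note: this "rope" branch is carried over from the llama attention code and assumes
+        # llama_rotary_pos_emb / _apply_rotary_pos_emb are available (e.g. imported from
+        # levanter.models.llama); RoBERTa's default "absolute" position embeddings never hit it.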
+        if self.position_embedding_type == "rope":
+            cos, sin = llama_rotary_pos_emb(
+                self.config.HeadSize, hidden_states.resolve_axis("position"), scale=self._rope_scale_factor()
+            )
+            query_layer, key_layer = _apply_rotary_pos_emb(query_layer, key_layer, cos, sin)
+
+        key_layer = key_layer.rename({"position": "key_position"})
+        value_layer = value_layer.rename({"position": "key_position"})
+
+        attention_scores = hax.dot(query_layer, key_layer, axis=self.HeadSize) # aka hax.einsum("bhld, bhrd -> bhlr")
+
+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+            Left = self.Pos # Queries
+            Right = self.KeyPos # Keys
+            
+            position_ids_l = hax.arange(Left).broadcast_to((Left,Right))
+            position_ids_r = hax.arange(Right).broadcast_to((Left,Right))
+
+            distance = position_ids_l - position_ids_r
+
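+            # distance ranges over [-(Pos-1), Pos-1]; shifting by Pos.size - 1 maps it onto the
+            # [0, 2*Pos - 2] index range of the rel_pos embedding table.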
+            positional_embedding = self.distance_embedding(distance + self.Pos.size - 1)
+
+            if self.position_embedding_type == "relative_key":
+                relative_position_scores = hax.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                attention_scores = attention_scores + relative_position_scores
+            elif self.position_embedding_type == "relative_key_query":
+                relative_position_scores_query = hax.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                relative_position_scores_key = hax.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+        
+        attention_scores /= jnp.sqrt(self.HeadSize.size)
+
+        if attention_mask is not None:
+            # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
+            # Attention_mask should have shape Batch Pos, so it should broadcast to shape Batch Heads Pos KeyPos for summation
+            attention_scores = attention_scores + attention_mask 
+        
+        attention_probs = hnn.softmax(attention_scores, axis=self.KeyPos)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.dropout(attention_probs, key=key)
+
+        context_layer = hax.dot(attention_probs, value_layer, axis=self.KeyPos)
+
+        outputs = hax.rearrange(context_layer, "... heads position head_size -> ... position (embed_att: heads head_size)", heads=self.Heads, head_size=self.HeadSize)
+
+        return outputs
+
+class RobertaSelfOutput(eqx.Module, StateDictSerializationMixin):
+    dense: hnn.Linear
+    LayerNorm: hnn.LayerNorm
+    dropout: hnn.Dropout
+
+    @staticmethod
+    def init(config: RobertaConfig, *, key) -> "RobertaSelfOutput":
+        Embed = config.Embed
+        EmbedAtt = config.EmbedAtt
+        dense = hnn.Linear.init(In=EmbedAtt, Out=Embed, key=key, out_first=True)
+        LayerNorm = hnn.LayerNorm.init(axis=Embed, eps=config.layer_norm_eps)
+        dropout = hnn.Dropout(config.hidden_dropout_prob)
+        return RobertaSelfOutput(dense, LayerNorm, dropout)
+    
+    @named_call
+    def __call__(self, hidden_states: NamedArray, input: NamedArray,*, key) -> NamedArray:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states, key=key)
+        hidden_states = self.LayerNorm(hidden_states + input)
+        return hidden_states
+
+class RobertaAttention(eqx.Module, StateDictSerializationMixin):
+    self_attn: RobertaSelfAttention
+    output: RobertaSelfOutput
+
+    @staticmethod
+    def init(config: RobertaConfig, *, key) -> "RobertaAttention":
+        k_a, k_o = jrandom.split(key, 2)
+
+        self_attn = RobertaSelfAttention.init(config, key=k_a)
+        output = RobertaSelfOutput.init(config, key=k_o)
+
+        return RobertaAttention(self_attn, output)
+
+    def _state_dict_key_map(self) -> Dict[str, Optional[str]]:
+        return {"self_attn": "self"}
+
+    @named_call
+    def __call__(
+        self,
+        hidden_states: NamedArray,
+        attention_mask: Optional[NamedArray] = None,
+        *,
+        key
+    ) -> NamedArray:
+        k_a, k_o = maybe_rng_split(key, 2)
+        
+        self_outputs = self.self_attn(
+            hidden_states,
+            attention_mask,
+            key=k_a
+        )
+        attention_output = self.output(self_outputs, hidden_states, key=k_o)
+        return attention_output
+    
+class RobertaIntermediate(eqx.Module, StateDictSerializationMixin):
+    dense: hnn.Linear
+    intermediate_act_fn: Callable = eqx.static_field()
+
+    @staticmethod
+    def init(config, *, key) -> "RobertaIntermediate":
+        dense = hnn.Linear.init(config.Embed, config.Mlp, key=key, out_first=True)
+        if isinstance(config.hidden_act, str):
+            intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            intermediate_act_fn = config.hidden_act
+
+        return RobertaIntermediate(dense, intermediate_act_fn)
+
+    @named_call
+    def __call__(self, hidden_states: NamedArray, *, key = None) -> NamedArray:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        return hidden_states
+
+class RobertaOutput(eqx.Module, StateDictSerializationMixin):
+    dense: hnn.Linear
+    LayerNorm: hnn.LayerNorm
+    dropout: hnn.Dropout
+
+    @staticmethod
+    def init(config: RobertaConfig, *, key) -> "RobertaOutput":
+        Embed = config.Embed
+        dense = hnn.Linear.init(In=config.Mlp, Out=Embed, key=key, out_first=True)
+        LayerNorm = hnn.LayerNorm.init(axis=Embed, eps=config.layer_norm_eps)
+        dropout = hnn.Dropout(config.hidden_dropout_prob)
+        return RobertaOutput(dense, LayerNorm, dropout)
+
+    @named_call
+    def __call__(self, hidden_states: NamedArray, input: NamedArray, *, key) -> NamedArray:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states, key=key)
+        hidden_states = self.LayerNorm(hidden_states + input)
+        return hidden_states
+    
+class RobertaLayer(eqx.Module, StateDictSerializationMixin):
+    attention: RobertaAttention
+    intermediate: RobertaIntermediate
+    output: RobertaOutput
+    
+    @staticmethod
+    def init(config: RobertaConfig, *, key) -> "RobertaLayer":
+        k_a, k_i, k_o = jrandom.split(key, 3)
+
+        attention = RobertaAttention.init(config, key=k_a)
+        intermediate = RobertaIntermediate.init(config, key=k_i)
+        output = RobertaOutput.init(config, key=k_o)
+
+        return RobertaLayer(attention, intermediate, output)
+    
+    @named_call
+    def __call__(
+        self,
+        hidden_states: NamedArray,
+        attention_mask: Optional[NamedArray] = None,
+        *,
+        key
+    ) -> Tuple[NamedArray]:
+        k_a, k_o = maybe_rng_split(key, 2)
+
+        attention_output = self.attention(
+            hidden_states,
+            attention_mask,
+            key=k_a, 
+        )
+
+        intermediate_output = self.intermediate(attention_output)
+        layer_output = self.output(intermediate_output, attention_output, key=k_o)
+
+        return layer_output
+
+
+class RobertaEncoder(eqx.Module, StateDictSerializationMixin):
+    config: RobertaConfig
+    layer: BlockFoldable[RobertaLayer]
+    output_hidden_states: bool
+
+    @staticmethod
+    def init(config: RobertaConfig, output_hidden_states: bool = False, *, key) -> "RobertaEncoder":
+        S = BlockSeq
+
+        layer = S.init(config.Layers, RobertaLayer, gradient_checkpointing=config.gradient_checkpointing)(
+            config,
+            key=shaped_rng_split(key, config.num_hidden_layers), #TODO: config.gradient_checkpointing
+        )
+
+        return RobertaEncoder(config, layer, output_hidden_states)
+
+    @named_call
+    def __call__(
+        self,
+        hidden_states: NamedArray,
+        attention_mask: Optional[NamedArray] = None,
+        *,
+        key
+    ) -> Tuple[NamedArray]:
+        
+        keys = maybe_rng_split(key, self.config.num_hidden_layers) if key is not None else None
+
+        # TODO: output_hidden_states is currently ignored; switching from fold to self.layer.scan
+        # would also return the per-layer hidden states.
+        x = self.layer.fold(hidden_states, attention_mask, key=keys)
+
+        return x, None
+    
+    def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None):
+        out = super().from_state_dict(state_dict, prefix=prefix)
+        return out
+
+    def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict:
+        my_state_dict: StateDict = {}
+        super().update_state_dict(my_state_dict, prefix=prefix)
+
+        state_dict.update(my_state_dict)
+
+        return state_dict
+
+class RobertaEmbedding(eqx.Module, StateDictSerializationMixin):
+    Vocab: Axis = eqx.static_field()
+    Pos: Axis = eqx.static_field()
+
+    word_embeddings: hnn.Embedding
+    position_embeddings: hnn.Embedding
+    token_type_embeddings: Optional[hnn.Embedding]
+    padding_idx: int = eqx.static_field()
+
+    LayerNorm: hnn.LayerNorm
+    dropout: hnn.Dropout
+    position_embedding_type: Optional[str]
+
+    @staticmethod
+    def init(Vocab: Axis, config: RobertaConfig, *, key) -> "RobertaEmbedding":
+        key_w, key_p, key_t = jrandom.split(key, 3)
+
+        padding_idx = config.pad_token_id
+
+        word_embeddings = hnn.Embedding.init(Vocab, config.Embed, key=key_w) # padding_idx not specified
+        position_embeddings = hnn.Embedding.init(config.Pos, config.Embed, key=key_p)
+
+        Token = hax.Axis("token", config.type_vocab_size)
+
+        token_type_embeddings = hnn.Embedding.init(Token, config.Embed, key=key_t)
+        
+        LayerNorm = hnn.LayerNorm.init(config.Embed, config.layer_norm_eps)
+        dropout = hnn.Dropout(config.hidden_dropout_prob)
+
+        return RobertaEmbedding(Vocab, config.Pos, word_embeddings, position_embeddings, token_type_embeddings, padding_idx, LayerNorm, dropout, config.position_embedding_type)
+
+    def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0):
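+        # Position ids are a cumulative count over non-padding tokens (padding positions stay at 0);
+        # unlike HF's create_position_ids_from_input_ids, no padding_idx offset is added, and fully
+        # unpadded sequences are shifted to start at 0.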
+        mask = hax.not_equal(input_ids, self.padding_idx) * 1
+        incremental_indices = (hax.cumsum(mask, axis=self.Pos).astype(mask.dtype) + past_key_values_length) * mask
+        incremental_indices -= mask.all(axis=self.Pos)
+        return incremental_indices
+
+    def create_position_ids_from_inputs_embeds(self, input_axes, PosInput):
+        position_ids = hax.arange(axis = PosInput, start = 0, dtype=jnp.int32)
+        # position_ids = hax.arange(axis = PosInput, start = self.padding_idx + 1, dtype=jnp.int32)
+
+        return hax.broadcast_to(position_ids, input_axes)
+
+    @named_call
+    def embed(self, input_ids=None, token_type_ids=None, position_ids=None, input_embeds=None, past_key_values_length=0, *, key = None):
+        """
+        Build the input embeddings from token ids (or precomputed `input_embeds`), token type ids,
+        and position ids, then apply LayerNorm and dropout.
+
+        Note: when passing your own embeddings via `input_embeds`, make sure the embedding axis is
+        named "embed" so that the position-id creation below can recover the non-embedding axes.
+        """
+        
+        # Get Axes
+        if input_ids is not None:
+            input_axes = input_ids.axes
+        else:
+            input_axes = hax.eliminate_axes(input_embeds.axes, "embed")
+
+        # Get position_ids
+        if position_ids is None:
+            if input_ids is not None:
+                # Create the position ids from the input token ids. Any padded tokens remain padded.
+                position_ids = self.create_position_ids_from_input_ids(input_ids, past_key_values_length)
+            else:
+                position_ids = self.create_position_ids_from_inputs_embeds(input_axes, input_embeds.resolve_axis("position"))
+        
+        # Get token_type_ids
+        if token_type_ids is None:
+            token_type_ids = hax.zeros(input_axes, dtype=jnp.int32)
+
+        if input_embeds is None:
+            input_embeds = self.word_embeddings(input_ids)
+
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+        embeddings = input_embeds + token_type_embeddings
+
+        if self.position_embedding_type == "absolute":
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings += position_embeddings
+        
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings, key=key)
+
+        return embeddings
+
+class RobertaPooler(eqx.Module, StateDictSerializationMixin):
+    dense: hnn.Linear
+    config: RobertaConfig
+
+    @staticmethod
+    def init(config: RobertaConfig, *, key):
+        dense = hnn.Linear.init(In=config.Embed, Out=config.FinalEmbed, key=key, out_first=True)
+
+        return RobertaPooler(dense, config)
+
+    @named_call
+    def __call__(self, hidden_states: NamedArray, *, key=None) -> NamedArray:
+        first_token = hidden_states[{"position" : 0}]
+        x = self.dense(first_token, key=key).rename({self.config.FinalEmbed: self.config.Embed})
+        x = hax.tanh(x)
+        return x
+
+
+class RobertaModel(eqx.Module, StateDictSerializationMixin):
+    encoder: RobertaEncoder
+    embeddings: RobertaEmbedding
+    pooler : Optional[RobertaPooler]
+    output_hidden_states: bool
+
+    @staticmethod
+    def init(Vocab: Axis, config: RobertaConfig, add_pooling_layer: bool = True, output_hidden_states: bool = False, *, key) -> "RobertaModel":
+        k_t, k_emb, k_p = jrandom.split(key, 3)
+        encoder = RobertaEncoder.init(config=config, output_hidden_states=output_hidden_states, key=k_t)
+        embeddings = RobertaEmbedding.init(Vocab, config, key=k_emb)
+
+        pooler = RobertaPooler.init(config, key=k_p) if add_pooling_layer else None
+        return RobertaModel(encoder, embeddings, pooler, output_hidden_states)
+
+    @property
+    def config(self):
+        return self.encoder.config
+
+    @property
+    def vocab_size(self) -> int:
+        return self.Vocab.size
+
+    @property
+    def Vocab(self) -> Axis:
+        return self.embeddings.Vocab
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        # eqx modules are immutable; return an updated copy instead of assigning in place
+        return eqx.tree_at(lambda m: m.embeddings.word_embeddings, self, value)
+
+    @named_call
+    def __call__(
+        self,
+        input_ids: Optional[NamedArray] = None,
+        token_type_ids: Optional[NamedArray] = None, 
+        position_ids: Optional[NamedArray] = None, 
+        input_embeds: Optional[NamedArray] = None,
+        attention_mask: Optional[NamedArray] = None,
+        *,
+        key,
+    ) -> Tuple[NamedArray]:
+        """
+        Not Used: meant to be used to improve performance in decoder implementations
+
+        head_mask: Optional[NamedArray] = None,
+        encoder_hidden_states: Optional[NamedArray] = None,
+        encoder_attention_mask: Optional[NamedArray] = None,
+        past_key_values_length = 0,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        """
+        k_emb, k_e, k_p = maybe_rng_split(key, 3)
+
+        if input_ids is not None and input_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_axes = input_ids.axes
+        elif input_embeds is not None:
+            input_axes = hax.eliminate_axes(input_embeds.axes, "embed")
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if attention_mask is None:
+            attention_mask = hax.ones(input_axes)
+        
+        # print(f"attention_mask: {attention_mask}")
+        
+        # Attention mask from mask to actual numbers 0 -> -inf
+        attention_mask = (attention_mask == 0) * jnp.finfo(jnp.bfloat16).min
+
+        # print(f"attention_mask_real: {attention_mask}")
+        
+        embedding_output = self.embeddings.embed(input_ids, token_type_ids, position_ids, input_embeds, key=k_emb)
+
+        encoder_outputs = self.encoder(embedding_output, attention_mask=attention_mask, key=k_e)
+        
+        sequence_output = encoder_outputs[0]
+
+        pooled_output = self.pooler(sequence_output, key=k_p) if self.pooler is not None else None
+
+        return (sequence_output, pooled_output) + encoder_outputs[1:] if self.output_hidden_states else (sequence_output, pooled_output)
+
+class RobertaLMHead(eqx.Module, StateDictSerializationMixin):
+    """Roberta Head for masked language modeling."""
+
+    dense: hnn.Linear
+    layer_norm: hnn.LayerNorm
+    decoder: hnn.Linear
+    config: RobertaConfig
+
+    @staticmethod
+    def init(Vocab: Axis, config: RobertaConfig, *, key):
+        k_dense, k_decoder = jrandom.split(key, 2)
+        Embed = config.Embed
+
+        dense = hnn.Linear.init(In=Embed, Out=config.FinalEmbed, key=k_dense, out_first=True)
+        layer_norm = hnn.LayerNorm.init(axis=Embed, eps=config.layer_norm_eps)
+
+        decoder = hnn.Linear.init(Embed, Vocab, key=k_decoder, out_first=True)
+
+        # TODO: HF's RobertaLMHead keeps a separate `bias` parameter tied to `decoder.bias`
+        # (self.bias = nn.Parameter(torch.zeros(config.vocab_size)); self.decoder.bias = self.bias);
+        # here we rely on the decoder Linear's own bias instead.
+
+        return RobertaLMHead(dense, layer_norm, decoder, config)
+
+    @named_call
+    def __call__(self, features: NamedArray, *, key=None) -> NamedArray:
+        x = self.dense(features).rename({self.config.FinalEmbed: self.config.Embed})
+        x = hnn.gelu(x, approximate=False)
+        x = self.layer_norm(x)
+
+        # project back to size of vocabulary with bias
+        x = self.decoder(x)
+
+        return x
+
+class RobertaForMaskedLM(eqx.Module, StateDictSerializationMixin):
+    roberta: RobertaModel
+    lm_head: RobertaLMHead
+    Vocab: Axis
+    Pos: Axis
+    output_hidden_states: bool
+
+    @classmethod
+    def init(cls, Vocab: Axis, config: RobertaConfig, output_hidden_states: bool = False, *, key):
+
+        # if config.is_decoder:
+        #     raise AttributeError("Model is being run as a MaskedLM aka an encoder model, but is_decoder is true")
+
+        key_rob, key_head = jrandom.split(key, 2)
+        roberta = RobertaModel.init(Vocab, config, add_pooling_layer=False, output_hidden_states=output_hidden_states, key=key_rob)
+        lm_head = RobertaLMHead.init(Vocab, config, key=key_head)
+
+        return RobertaForMaskedLM(roberta, lm_head, Vocab, config.Pos, output_hidden_states)
+
+    def get_output_embeddings(self):
+        return self.lm_head.decoder
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head.decoder = new_embeddings
+
+    @named_call
+    def __call__(
+        self,
+        input_ids: Optional[NamedArray] = None,
+        attention_mask: Optional[NamedArray] = None,
+        token_type_ids: Optional[NamedArray] = None,
+        position_ids: Optional[NamedArray] = None,
+        input_embeds: Optional[NamedArray] = None,
+        *,
+        key=None
+    ) -> NamedArray:
+        """
+        Runs the RoBERTa encoder followed by the LM head and returns the per-token prediction
+        scores over the vocabulary. Unlike the HF implementation, labels are not passed here;
+        the masked-LM loss is computed separately in `compute_loss` from a `MaskedLmExample`.
+        """
+
+        k_rob, k_lm = maybe_rng_split(key, 2)
+
+        # print(f"input_ids: {input_ids}")
+        # print(f"attention_mask: {attention_mask}")        
+
+        outputs = self.roberta(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            input_embeds=input_embeds,
+            key=k_rob
+        )
+
+        # print(f"outputs: {outputs}")
+
+        prediction_scores = self.lm_head(outputs[0], key=k_lm)
+
+        # print(f"prediction_scores: {prediction_scores}")
+
+        # jax.debug.breakpoint()
+
+        # return (prediction_scores,) + outputs[2:]
+        return prediction_scores
+    
+    def compute_loss(
+            self,
+            example: MaskedLmExample,
+            *,
+            key=None,
+            reduction: Optional[hax.ReductionFunction] = hax.mean,
+            reduction_axis: Optional[hax.AxisSelection] = None,
+    ) -> jnp.ndarray | NamedArray:
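+        # Masked-LM loss: cross entropy of the LM-head logits against `example.targets`
+        # over the Vocab axis, restricted to the masked positions via `example.loss_mask`.
+        # Logits are upcast to float32 for numerical stability before the loss.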
+        logits = self(example.tokens, example.attn_mask, key=key)
+        logits = logits.astype(jnp.float32)
+        targets = example.targets
+
+        target_y = hax.nn.one_hot(targets, self.Vocab, dtype=logits.dtype)
+        loss = cross_entropy_loss(
+            logits, self.Vocab, target_y, reduction, reduction_axis=reduction_axis, where=example.loss_mask
+        )
+
+        return loss
+    
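+
+# NOTE: The rotary position embedding helpers below mirror Levanter's llama implementation.
+# They do not appear to be referenced by the RoBERTa modules in this file, which rely on
+# learned absolute position embeddings instead.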
+def _rotate_half(x: NamedArray) -> NamedArray:
+    """Rotates half of the hidden dims of the input and concatenates them."""
+    HeadSize = x.axes[-1]
+    x1 = x[HeadSize, : HeadSize.size // 2]
+    x2 = x[HeadSize, HeadSize.size // 2 :]
+    out = hax.concatenate(HeadSize, (-x2, x1))
+    return out
+
+
+def _apply_rotary_pos_emb(
+    q: NamedArray,  # [batch, position, kv_heads, q_heads_per_group, head_size]
+    k: NamedArray,  # [batch, position, kv_heads, head_size]
+    cos: NamedArray,  # [position, head_size]
+    sin: NamedArray,  # [position, head_size]
+) -> Tuple[NamedArray, NamedArray]:
+    """Applies rotary position embedding to q and k."""
+    q_embed = q * cos + _rotate_half(q) * sin
+    k_embed = k * cos + _rotate_half(k) * sin
+    return q_embed, k_embed
+
+
+def llama_rotary_pos_emb(
+    HeadSize: Axis, Pos: Axis, base: float = 10000, scale: float = 1.0
+) -> Tuple[NamedArray, NamedArray]:
+    with jax.ensure_compile_time_eval():
+        HeadHalfSize = HeadSize.resize(HeadSize.size // 2)
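+        # Standard RoPE frequency schedule: inv_freq_i = base^(-2i / head_size) for i in [0, head_size / 2).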
+        inv_freq: NamedArray = 1.0 / (base ** (hax.arange(HeadHalfSize, step=2) / HeadSize.size))
+
+        position_ids: NamedArray = hax.arange(Pos) / scale
+
+        freqs = position_ids * inv_freq.broadcast_axis(Pos)
+        # This is different from the paper but aligns with HF implementation:
+        # It uses a different permutation in order to obtain the same calculation
+        emb = hax.concatenate(HeadSize, (freqs, freqs))
+        cos = hax.cos(emb)
+        sin = hax.sin(emb)
+        return cos, sin
diff --git a/src/levanter/models/testing.ipynb b/src/levanter/models/testing.ipynb
new file mode 100644
index 000000000..27a24f7bf
--- /dev/null
+++ b/src/levanter/models/testing.ipynb
@@ -0,0 +1,1420 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import roberta as my_roberta\n",
+    "from transformers.models.roberta import modeling_roberta as hf_roberta\n",
+    "\n",
+    "import torch\n",
+    "import haliax as hax\n",
+    "import jax\n",
+    "import jax.random as jrandom\n",
+    "import jax.numpy as jnp\n",
+    "import numpy as np\n",
+    "\n",
+    "# hello"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import AutoConfig\n",
+    "from time import time\n",
+    "\n",
+    "hf_model_str = \"FacebookAI/roberta-base\"\n",
+    "\n",
+    "hf_config = AutoConfig.from_pretrained(hf_model_str)\n",
+    "hf_config.hidden_dropout_prob = 0\n",
+    "hf_config.attention_probs_dropout_prob = 0\n",
+    "# hf_config.pad_token_id = -1\n",
+    "my_config = my_roberta.RobertaConfig.from_hf_config(hf_config)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "seed: 1725922495\n"
+     ]
+    }
+   ],
+   "source": [
+    "seed = time()\n",
+    "print(f\"seed: {int(seed)}\")\n",
+    "key = jrandom.PRNGKey(int(seed))\n",
+    "\n",
+    "key_vars, key_funcs, key_run = jrandom.split(key, 3)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "EmbedAtt = my_config.EmbedAtt\n",
+    "Embed = my_config.Embed\n",
+    "Mlp = my_config.Mlp\n",
+    "Pos = my_config.Pos\n",
+    "KeyPos = my_config.KeyPos\n",
+    "Heads = my_config.Heads\n",
+    "\n",
+    "cut_end_for_bounds = True \n",
+    "\n",
+    "Batch = hax.Axis(\"batch\", 2)\n",
+    "Vocab = hax.Axis(\"vocab\", my_config.vocab_size)\n",
+    "\n",
+    "keys = jrandom.split(key_vars, 6)\n",
+    "\n",
+    "input_ids = hax.random.randint(keys[0], (Batch, Pos), minval = 3, maxval = my_config.vocab_size)\n",
+    "if cut_end_for_bounds:\n",
+    "    input_ids = input_ids[{\"position\": slice(0,-2)}]\n",
+    "input_ids_torch = torch.from_numpy(np.array(input_ids.array))\n",
+    "\n",
+    "input_embeds = hax.random.normal(keys[1], (Batch, Pos, Embed))\n",
+    "if cut_end_for_bounds:\n",
+    "    input_embeds = input_embeds[{\"position\": slice(0,-2)}]\n",
+    "input_embeds_torch = torch.from_numpy(np.array(input_embeds.array))\n",
+    "\n",
+    "# mask = hax.random.randint(keys[2], (Batch, Pos), minval = 0, maxval = 2)\n",
+    "# mask = hax.ones((Batch, Pos))\n",
+    "mask = hax.zeros((Batch, Pos))\n",
+    "\n",
+    "if cut_end_for_bounds:\n",
+    "    mask = mask[{\"position\": slice(0,-2)}]\n",
+    "mask_torch = torch.from_numpy(np.array(mask.array))\n",
+    "\n",
+    "mask_materialized = (mask == 0) * jnp.finfo(jnp.bfloat16).min\n",
+    "mask_torch_materialized = hf_roberta.RobertaModel.get_extended_attention_mask(self=hf_roberta.RobertaModel(hf_config), attention_mask=mask_torch, input_shape=input_embeds_torch.shape)\n",
+    "\n",
+    "features = input_embeds[{\"position\": 0}]\n",
+    "features_torch = torch.from_numpy(np.array(features.array))\n",
+    "\n",
+    "x_embed_att = input_embeds.rename({\"embed\": \"embed_att\"})\n",
+    "if cut_end_for_bounds:\n",
+    "    x_embed_att = x_embed_att[{\"position\": slice(0,-2)}]\n",
+    "x_embed_att_torch = torch.from_numpy(np.array(x_embed_att.array))\n",
+    "\n",
+    "x_mlp = hax.random.normal(keys[5], (Batch, Pos, Mlp))\n",
+    "if cut_end_for_bounds:\n",
+    "    x_mlp = x_mlp[{\"position\": slice(0,-2)}]    \n",
+    "x_mlp_torch = torch.from_numpy(np.array(x_mlp.array))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def check(my_output, hf_output, precision=1e-4):\n",
+    "    \n",
+    "    assert (np.array(my_output.shape) == np.array(hf_output.shape)).all()\n",
+    "    # print(my_output.shape)\n",
+    "    # print(hf_output.shape)\n",
+    "\n",
+    "    acc = np.isclose(hf_output, np.array(my_output), rtol=precision, atol=precision).mean()\n",
+    "\n",
+    "    # stats = (torch.tensor(np.array(my_output)).abs().mean(), torch.tensor(np.array(hf_output)).abs().mean()) \n",
+    "    stats = (torch.linalg.norm(torch.tensor(np.array(my_output))), torch.linalg.norm(torch.tensor(np.array(hf_output))))\n",
+    "    \n",
+    "    difference = torch.tensor(np.array(my_output)) - torch.tensor(np.array(hf_output))\n",
+    "\n",
+    "    diffs = difference.abs().mean()\n",
+    "\n",
+    "    to_print = f\"acc: {acc} \\t norms: {stats} \\t diffs: {diffs}\"\n",
+    "    \n",
+    "    return acc, stats, diffs, to_print"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 54,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def check_dicts(my_dict, hf_dict):\n",
+    "    print(my_dict.keys())\n",
+    "    print(hf_dict.keys())\n",
+    "\n",
+    "    hf_keys_save = list(hf_dict.keys())\n",
+    "\n",
+    "    flag = 0\n",
+    "    diff = 0\n",
+    "\n",
+    "    for k in my_dict.keys():\n",
+    "        i = my_dict[k]\n",
+    "        if k not in hf_dict:\n",
+    "            print(f\"ERROR \\t {k}: key in my_dict but not hf_dict\")\n",
+    "        j = hf_dict[k]\n",
+    "        diff += (np.array(i) - np.array(j)).sum()\n",
+    "        if check(i, j.detach())[0] < 1:\n",
+    "            print(f\"ERROR \\t {k}: {check(i, j)[0]}\")\n",
+    "            flag += 1\n",
+    "        hf_keys_save.remove(k)\n",
+    "\n",
+    "    if flag == 0:\n",
+    "        print(\"success1\") \n",
+    "    else:\n",
+    "        print(\"fail1\") \n",
+    "\n",
+    "    if len(hf_keys_save) == 0:\n",
+    "        print(\"success2\") \n",
+    "    else:\n",
+    "        print(\"fail2\")\n",
+    "        print(hf_keys_save)\n",
+    "\n",
+    "    return diff"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "NameError",
+     "evalue": "name 'stop' is not defined",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)",
+      "Cell \u001b[1;32mIn[25], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[43mstop\u001b[49m\n",
+      "\u001b[1;31mNameError\u001b[0m: name 'stop' is not defined"
+     ]
+    }
+   ],
+   "source": [
+    "stop"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Testing RobertaSelfOutput\n",
+    "\n",
+    "def test_RobertaSelfOutput(key):\n",
+    "    k_1, k_2 = jrandom.split(key, 2)\n",
+    "    my_func = my_roberta.RobertaSelfOutput.init(my_config, key=k_1)\n",
+    "    state = my_func.to_state_dict()\n",
+    "\n",
+    "    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n",
+    "\n",
+    "    # # print(state.keys())\n",
+    "\n",
+    "    hf_func = hf_roberta.RobertaSelfOutput(hf_config)\n",
+    "    hf_func.load_state_dict(state, strict=True, assign=True)\n",
+    "\n",
+    "    my_output = my_func(x_embed_att, input_embeds, key=k_2)\n",
+    "    hf_output = hf_func(x_embed_att_torch, input_embeds_torch)\n",
+    "\n",
+    "    return check(my_output.array, hf_output.detach())\n",
+    "\n",
+    "# Testing RobertaSelfAttention\n",
+    "\n",
+    "def test_RobertaSelfAttention(key):\n",
+    "    k_1, k_2 = jrandom.split(key, 2)\n",
+    "\n",
+    "    my_func = my_roberta.RobertaSelfAttention.init(my_config, key=k_1)\n",
+    "    state = my_func.to_state_dict()\n",
+    "\n",
+    "    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n",
+    "\n",
+    "    # print(state.keys())\n",
+    "\n",
+    "    hf_func  = hf_roberta.RobertaSelfAttention(hf_config)\n",
+    "    hf_func.load_state_dict(state, strict=True, assign=True)\n",
+    "\n",
+    "    my_output = my_func(input_embeds, mask_materialized, key=k_2)\n",
+    "    hf_output = hf_func(input_embeds_torch, mask_torch_materialized)\n",
+    "\n",
+    "    return check(my_output.array, hf_output[0].detach())\n",
+    "\n",
+    "# Testing RobertaAttention\n",
+    "\n",
+    "def test_RobertaAttention(key):\n",
+    "    k_1, k_2 = jrandom.split(key, 2)\n",
+    "    my_func = my_roberta.RobertaAttention.init(my_config, key=k_1)\n",
+    "    state = my_func.to_state_dict()\n",
+    "\n",
+    "    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n",
+    "\n",
+    "    # print(state.keys())\n",
+    "\n",
+    "    hf_func = hf_roberta.RobertaAttention(hf_config)\n",
+    "    hf_func.load_state_dict(state, strict=True, assign=True)\n",
+    "\n",
+    "    my_output = my_func(hidden_states=input_embeds, attention_mask=mask_materialized, key=k_2)\n",
+    "    hf_output = hf_func(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)\n",
+    "\n",
+    "    return check(my_output.array, hf_output[0].detach())\n",
+    "\n",
+    "# Testing RobertaIntermediate\n",
+    "\n",
+    "def test_RobertaIntermediate(key):\n",
+    "    k_1, k_2 = jrandom.split(key, 2)\n",
+    "    my_func = my_roberta.RobertaIntermediate.init(my_config, key=k_1)\n",
+    "    state = my_func.to_state_dict()\n",
+    "\n",
+    "    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n",
+    "\n",
+    "    # print(state.keys())\n",
+    "\n",
+    "    hf_func = hf_roberta.RobertaIntermediate(hf_config)\n",
+    "    hf_func.load_state_dict(state, strict=True, assign=True)\n",
+    "\n",
+    "    my_output = my_func(input_embeds, key=k_2)\n",
+    "    hf_output = hf_func(input_embeds_torch)\n",
+    "\n",
+    "    return check(my_output.array, hf_output.detach())\n",
+    "\n",
+    "# Testing RobertaOutput\n",
+    "\n",
+    "def test_RobertaOutput(key):\n",
+    "    k_1, k_2 = jrandom.split(key, 2)\n",
+    "\n",
+    "    my_func = my_roberta.RobertaOutput.init(my_config, key=k_1)\n",
+    "    state = my_func.to_state_dict()\n",
+    "\n",
+    "    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n",
+    "\n",
+    "    # print(state.keys())\n",
+    "\n",
+    "    hf_func = hf_roberta.RobertaOutput(hf_config)\n",
+    "    hf_func.load_state_dict(state, strict=True, assign=True)\n",
+    "\n",
+    "    my_output = my_func(x_mlp, input_embeds, key=k_2)\n",
+    "    hf_output = hf_func(x_mlp_torch, input_embeds_torch)\n",
+    "\n",
+    "    return check(my_output.array, hf_output.detach())\n",
+    "\n",
+    "# Testing RobertaLayer\n",
+    "\n",
+    "def test_RobertaLayer(key):\n",
+    "    k_1, k_2 = jrandom.split(key, 2)\n",
+    "    my_func = my_roberta.RobertaLayer.init(my_config, key=k_1)\n",
+    "    state = my_func.to_state_dict()\n",
+    "\n",
+    "    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n",
+    "\n",
+    "    # print(state.keys())\n",
+    "\n",
+    "    hf_func = hf_roberta.RobertaLayer(hf_config)\n",
+    "    hf_func.load_state_dict(state, strict=True, assign=True)\n",
+    "\n",
+    "    my_output = my_func(hidden_states=input_embeds, attention_mask=mask_materialized, key=k_2)\n",
+    "    hf_output = hf_func(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)\n",
+    "\n",
+    "    return check(my_output[0].array, hf_output[0].detach())\n",
+    "\n",
+    "# Testing RobertaEncoder\n",
+    "\n",
+    "def test_RobertaEncoder(key):\n",
+    "    k_1, k_2 = jrandom.split(key, 2)\n",
+    "    my_func = my_roberta.RobertaEncoder.init(my_config, key=k_1)\n",
+    "    state = my_func.to_state_dict()\n",
+    "\n",
+    "    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n",
+    "\n",
+    "    # print(state.keys())\n",
+    "\n",
+    "    hf_func = hf_roberta.RobertaEncoder(hf_config)\n",
+    "    hf_func.load_state_dict(state, strict=True, assign=True)\n",
+    "\n",
+    "    my_output = my_func(hidden_states=input_embeds, attention_mask=mask_materialized, key=k_2)\n",
+    "    hf_output = hf_func(hidden_states=input_embeds_torch, attention_mask=mask_torch_materialized)\n",
+    "\n",
+    "    return check(my_output[0].array, hf_output[0].detach())\n",
+    "\n",
+    "# Testing RobertaEmbedding\n",
+    "\n",
+    "def test_RobertaEmbedding(key, ids = True):\n",
+    "    k_1, k_2 = jrandom.split(key, 2)\n",
+    "    my_func = my_roberta.RobertaEmbedding.init(Vocab, my_config, key=k_1)\n",
+    "    state = my_func.to_state_dict()\n",
+    "\n",
+    "    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n",
+    "\n",
+    "    # print(state.keys())\n",
+    "\n",
+    "    hf_func = hf_roberta.RobertaEmbeddings(hf_config)\n",
+    "    hf_func.load_state_dict(state, strict=True, assign=True)\n",
+    "\n",
+    "    if ids:\n",
+    "        my_output = my_func.embed(input_ids=input_ids, key=k_2)\n",
+    "        hf_output = hf_func(input_ids=input_ids_torch)\n",
+    "    else:        \n",
+    "        my_output = my_func.embed(input_embeds=input_embeds, key=k_2)\n",
+    "        hf_output = hf_func(inputs_embeds=input_embeds_torch)\n",
+    "\n",
+    "    return check(my_output.array, hf_output.detach())\n",
+    "\n",
+    "# Testing RobertaPooler\n",
+    "\n",
+    "def test_RobertaPooler(key):\n",
+    "    k_1, k_2 = jrandom.split(key, 2)\n",
+    "    my_func = my_roberta.RobertaPooler.init(my_config, key=k_1)\n",
+    "    state = my_func.to_state_dict()\n",
+    "\n",
+    "    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n",
+    "\n",
+    "    # print(state.keys())\n",
+    "\n",
+    "    hf_func = hf_roberta.RobertaPooler(hf_config)\n",
+    "    hf_func.load_state_dict(state, strict=True, assign=True)\n",
+    "\n",
+    "    my_output = my_func(input_embeds, key=k_2)\n",
+    "    hf_output = hf_func(input_embeds_torch)\n",
+    "\n",
+    "    return check(my_output.array, hf_output.detach())\n",
+    "\n",
+    "# Testing RobertaModel\n",
+    "\n",
+    "def test_RobertaModel(key, ids = True, pool = True):\n",
+    "    k_1, k_2 = jrandom.split(key, 2)\n",
+    "    my_func = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=pool, key=k_1)\n",
+    "    state = my_func.to_state_dict()\n",
+    "\n",
+    "    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n",
+    "\n",
+    "    # print(state.keys())\n",
+    "\n",
+    "    hf_func = hf_roberta.RobertaModel(hf_config, add_pooling_layer=pool)\n",
+    "    hf_func.load_state_dict(state, strict=True, assign=True)\n",
+    "\n",
+    "    if ids:\n",
+    "        my_output = my_func(input_ids = input_ids, attention_mask=mask, key=k_2)\n",
+    "        hf_output = hf_func(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n",
+    "    else:\n",
+    "        my_output = my_func(input_embeds = input_embeds, attention_mask=mask, key=k_2)\n",
+    "        hf_output = hf_func(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n",
+    "\n",
+    "    if pool:\n",
+    "        return check(my_output[1].array, hf_output[1].detach())\n",
+    "    else:\n",
+    "        return check(my_output[0].array, hf_output[0].detach())\n",
+    "\n",
+    "# Testing RobertaLMHead\n",
+    "\n",
+    "def test_RobertaLMHead(key):\n",
+    "    k_1, k_2 = jrandom.split(key, 2)\n",
+    "    my_func = my_roberta.RobertaLMHead.init(Vocab, my_config, key=k_1)\n",
+    "    state = my_func.to_state_dict()\n",
+    "\n",
+    "    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n",
+    "\n",
+    "    state[\"bias\"] = torch.zeros(hf_config.vocab_size)\n",
+    "\n",
+    "    # print(state.keys())\n",
+    "\n",
+    "    hf_func = hf_roberta.RobertaLMHead(hf_config)\n",
+    "    hf_func.load_state_dict(state, strict=True, assign=True)\n",
+    "\n",
+    "    my_output = my_func(features, key=k_2)\n",
+    "    hf_output = hf_func(features_torch)\n",
+    "\n",
+    "    return check(my_output.array, hf_output.detach())\n",
+    "\n",
+    "# Testing RobertaForMaskedLM\n",
+    "\n",
+    "def test_RobertaForMaskedLM(key, ids = True):\n",
+    "    k_1, k_2 = jrandom.split(key, 2)\n",
+    "    my_pool = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=k_1)\n",
+    "    state = my_pool.to_state_dict()\n",
+    "\n",
+    "    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}\n",
+    "\n",
+    "    state[\"lm_head.bias\"] = torch.zeros(hf_config.vocab_size)\n",
+    "\n",
+    "    # print(state.keys())\n",
+    "\n",
+    "    hf_pool = hf_roberta.RobertaForMaskedLM(hf_config)\n",
+    "    hf_pool.load_state_dict(state, strict=True, assign=True)\n",
+    "\n",
+    "    if ids:\n",
+    "        my_output = my_pool(input_ids = input_ids, attention_mask=mask, key=k_2)\n",
+    "        hf_output = hf_pool(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n",
+    "    else:\n",
+    "        my_output = my_pool(input_embeds = input_embeds, attention_mask=mask, key=k_2)\n",
+    "        hf_output = hf_pool(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n",
+    "\n",
+    "    return check(my_output[0].array, hf_output[0].detach())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "keys = jrandom.split(key_funcs, 15)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "outs = []\n",
+    "\n",
+    "outs.append(test_RobertaSelfOutput(keys[0]))\n",
+    "outs.append(test_RobertaSelfAttention(keys[1]))\n",
+    "outs.append(test_RobertaAttention(keys[2]))\n",
+    "outs.append(test_RobertaIntermediate(keys[3]))\n",
+    "outs.append(test_RobertaOutput(keys[4]))\n",
+    "outs.append(test_RobertaLayer(keys[4]))\n",
+    "outs.append(test_RobertaEncoder(keys[4]))\n",
+    "outs.append(test_RobertaEmbedding(keys[7], ids = True))\n",
+    "outs.append(test_RobertaEmbedding(keys[8], ids = False))\n",
+    "outs.append(test_RobertaModel(keys[9], ids = True, pool = True))\n",
+    "outs.append(test_RobertaModel(keys[10], ids = False, pool = False))\n",
+    "outs.append(test_RobertaModel(keys[9], ids = True, pool = True))\n",
+    "outs.append(test_RobertaModel(keys[10], ids = False, pool = False))\n",
+    "outs.append(test_RobertaPooler(keys[11]))\n",
+    "outs.append(test_RobertaLMHead(keys[12]))\n",
+    "outs.append(test_RobertaForMaskedLM(keys[13], ids = True))\n",
+    "outs.append(test_RobertaForMaskedLM(keys[14], ids = False))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "types = [\n",
+    "    \"test_RobertaSelfOutput\",\n",
+    "    \"test_RobertaSelfAttention\",\n",
+    "    \"test_RobertaAttention\",\n",
+    "    \"test_RobertaIntermediate\",\n",
+    "    \"test_RobertaOutput\",\n",
+    "    \"test_RobertaLayer\",\n",
+    "    \"test_RobertaEncoder\",\n",
+    "    \"test_RobertaEmbedding(ids = True)\",\n",
+    "    \"test_RobertaEmbedding(ids = False)\",\n",
+    "    \"test_RobertaModel(ids = True, pool = True)\",\n",
+    "    \"test_RobertaModel(ids = False, pool = False)\",\n",
+    "    \"test_RobertaModel(ids = True, pool = True)\",\n",
+    "    \"test_RobertaModel(ids = False, pool = False)\",\n",
+    "    \"test_RobertaPooler\",\n",
+    "    \"test_RobertaLMHead\",\n",
+    "    \"test_RobertaForMaskedLM(ids = True)\",\n",
+    "    \"test_RobertaForMaskedLM(ids = False)\"\n",
+    "]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for i,o in enumerate(outs):\n",
+    "    if o[2] * 0 != 0:\n",
+    "        print(f\"nan alert\")\n",
+    "    if o[0] < 1:\n",
+    "        print(f\"{types[i]}: {o[3]}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for i,o in enumerate(outs):\n",
+    "    print(f\"{types[i]}: {o[3]}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "key_model, key_lm = jrandom.split(key_funcs, 2)\n",
+    "key_model_run, key_lm_run = jrandom.split(key_run, 2)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "dict_keys(['encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.1.attention.self.query.bias', 'encoder.layer.1.attention.self.key.weight', 'encoder.layer.1.attention.self.key.bias', 'encoder.layer.1.attention.self.value.weight', 'encoder.layer.1.attention.self.value.bias', 'encoder.layer.1.attention.output.dense.weight', 'encoder.layer.1.attention.output.dense.bias', 'encoder.layer.1.attention.output.LayerNorm.weight', 'encoder.layer.1.attention.output.LayerNorm.bias', 'encoder.layer.1.intermediate.dense.weight', 'encoder.layer.1.intermediate.dense.bias', 'encoder.layer.1.output.dense.weight', 'encoder.layer.1.output.dense.bias', 'encoder.layer.1.output.LayerNorm.weight', 'encoder.layer.1.output.LayerNorm.bias', 'encoder.layer.2.attention.self.query.weight', 'encoder.layer.2.attention.self.query.bias', 'encoder.layer.2.attention.self.key.weight', 'encoder.layer.2.attention.self.key.bias', 'encoder.layer.2.attention.self.value.weight', 'encoder.layer.2.attention.self.value.bias', 'encoder.layer.2.attention.output.dense.weight', 'encoder.layer.2.attention.output.dense.bias', 'encoder.layer.2.attention.output.LayerNorm.weight', 'encoder.layer.2.attention.output.LayerNorm.bias', 'encoder.layer.2.intermediate.dense.weight', 'encoder.layer.2.intermediate.dense.bias', 'encoder.layer.2.output.dense.weight', 'encoder.layer.2.output.dense.bias', 'encoder.layer.2.output.LayerNorm.weight', 'encoder.layer.2.output.LayerNorm.bias', 'encoder.layer.3.attention.self.query.weight', 'encoder.layer.3.attention.self.query.bias', 'encoder.layer.3.attention.self.key.weight', 'encoder.layer.3.attention.self.key.bias', 'encoder.layer.3.attention.self.value.weight', 'encoder.layer.3.attention.self.value.bias', 'encoder.layer.3.attention.output.dense.weight', 'encoder.layer.3.attention.output.dense.bias', 'encoder.layer.3.attention.output.LayerNorm.weight', 'encoder.layer.3.attention.output.LayerNorm.bias', 'encoder.layer.3.intermediate.dense.weight', 'encoder.layer.3.intermediate.dense.bias', 'encoder.layer.3.output.dense.weight', 'encoder.layer.3.output.dense.bias', 'encoder.layer.3.output.LayerNorm.weight', 'encoder.layer.3.output.LayerNorm.bias', 'encoder.layer.4.attention.self.query.weight', 'encoder.layer.4.attention.self.query.bias', 'encoder.layer.4.attention.self.key.weight', 'encoder.layer.4.attention.self.key.bias', 'encoder.layer.4.attention.self.value.weight', 'encoder.layer.4.attention.self.value.bias', 'encoder.layer.4.attention.output.dense.weight', 'encoder.layer.4.attention.output.dense.bias', 'encoder.layer.4.attention.output.LayerNorm.weight', 'encoder.layer.4.attention.output.LayerNorm.bias', 'encoder.layer.4.intermediate.dense.weight', 'encoder.layer.4.intermediate.dense.bias', 'encoder.layer.4.output.dense.weight', 'encoder.layer.4.output.dense.bias', 
'encoder.layer.4.output.LayerNorm.weight', 'encoder.layer.4.output.LayerNorm.bias', 'encoder.layer.5.attention.self.query.weight', 'encoder.layer.5.attention.self.query.bias', 'encoder.layer.5.attention.self.key.weight', 'encoder.layer.5.attention.self.key.bias', 'encoder.layer.5.attention.self.value.weight', 'encoder.layer.5.attention.self.value.bias', 'encoder.layer.5.attention.output.dense.weight', 'encoder.layer.5.attention.output.dense.bias', 'encoder.layer.5.attention.output.LayerNorm.weight', 'encoder.layer.5.attention.output.LayerNorm.bias', 'encoder.layer.5.intermediate.dense.weight', 'encoder.layer.5.intermediate.dense.bias', 'encoder.layer.5.output.dense.weight', 'encoder.layer.5.output.dense.bias', 'encoder.layer.5.output.LayerNorm.weight', 'encoder.layer.5.output.LayerNorm.bias', 'encoder.layer.6.attention.self.query.weight', 'encoder.layer.6.attention.self.query.bias', 'encoder.layer.6.attention.self.key.weight', 'encoder.layer.6.attention.self.key.bias', 'encoder.layer.6.attention.self.value.weight', 'encoder.layer.6.attention.self.value.bias', 'encoder.layer.6.attention.output.dense.weight', 'encoder.layer.6.attention.output.dense.bias', 'encoder.layer.6.attention.output.LayerNorm.weight', 'encoder.layer.6.attention.output.LayerNorm.bias', 'encoder.layer.6.intermediate.dense.weight', 'encoder.layer.6.intermediate.dense.bias', 'encoder.layer.6.output.dense.weight', 'encoder.layer.6.output.dense.bias', 'encoder.layer.6.output.LayerNorm.weight', 'encoder.layer.6.output.LayerNorm.bias', 'encoder.layer.7.attention.self.query.weight', 'encoder.layer.7.attention.self.query.bias', 'encoder.layer.7.attention.self.key.weight', 'encoder.layer.7.attention.self.key.bias', 'encoder.layer.7.attention.self.value.weight', 'encoder.layer.7.attention.self.value.bias', 'encoder.layer.7.attention.output.dense.weight', 'encoder.layer.7.attention.output.dense.bias', 'encoder.layer.7.attention.output.LayerNorm.weight', 'encoder.layer.7.attention.output.LayerNorm.bias', 'encoder.layer.7.intermediate.dense.weight', 'encoder.layer.7.intermediate.dense.bias', 'encoder.layer.7.output.dense.weight', 'encoder.layer.7.output.dense.bias', 'encoder.layer.7.output.LayerNorm.weight', 'encoder.layer.7.output.LayerNorm.bias', 'encoder.layer.8.attention.self.query.weight', 'encoder.layer.8.attention.self.query.bias', 'encoder.layer.8.attention.self.key.weight', 'encoder.layer.8.attention.self.key.bias', 'encoder.layer.8.attention.self.value.weight', 'encoder.layer.8.attention.self.value.bias', 'encoder.layer.8.attention.output.dense.weight', 'encoder.layer.8.attention.output.dense.bias', 'encoder.layer.8.attention.output.LayerNorm.weight', 'encoder.layer.8.attention.output.LayerNorm.bias', 'encoder.layer.8.intermediate.dense.weight', 'encoder.layer.8.intermediate.dense.bias', 'encoder.layer.8.output.dense.weight', 'encoder.layer.8.output.dense.bias', 'encoder.layer.8.output.LayerNorm.weight', 'encoder.layer.8.output.LayerNorm.bias', 'encoder.layer.9.attention.self.query.weight', 'encoder.layer.9.attention.self.query.bias', 'encoder.layer.9.attention.self.key.weight', 'encoder.layer.9.attention.self.key.bias', 'encoder.layer.9.attention.self.value.weight', 'encoder.layer.9.attention.self.value.bias', 'encoder.layer.9.attention.output.dense.weight', 'encoder.layer.9.attention.output.dense.bias', 'encoder.layer.9.attention.output.LayerNorm.weight', 'encoder.layer.9.attention.output.LayerNorm.bias', 'encoder.layer.9.intermediate.dense.weight', 'encoder.layer.9.intermediate.dense.bias', 
'encoder.layer.9.output.dense.weight', 'encoder.layer.9.output.dense.bias', 'encoder.layer.9.output.LayerNorm.weight', 'encoder.layer.9.output.LayerNorm.bias', 'encoder.layer.10.attention.self.query.weight', 'encoder.layer.10.attention.self.query.bias', 'encoder.layer.10.attention.self.key.weight', 'encoder.layer.10.attention.self.key.bias', 'encoder.layer.10.attention.self.value.weight', 'encoder.layer.10.attention.self.value.bias', 'encoder.layer.10.attention.output.dense.weight', 'encoder.layer.10.attention.output.dense.bias', 'encoder.layer.10.attention.output.LayerNorm.weight', 'encoder.layer.10.attention.output.LayerNorm.bias', 'encoder.layer.10.intermediate.dense.weight', 'encoder.layer.10.intermediate.dense.bias', 'encoder.layer.10.output.dense.weight', 'encoder.layer.10.output.dense.bias', 'encoder.layer.10.output.LayerNorm.weight', 'encoder.layer.10.output.LayerNorm.bias', 'encoder.layer.11.attention.self.query.weight', 'encoder.layer.11.attention.self.query.bias', 'encoder.layer.11.attention.self.key.weight', 'encoder.layer.11.attention.self.key.bias', 'encoder.layer.11.attention.self.value.weight', 'encoder.layer.11.attention.self.value.bias', 'encoder.layer.11.attention.output.dense.weight', 'encoder.layer.11.attention.output.dense.bias', 'encoder.layer.11.attention.output.LayerNorm.weight', 'encoder.layer.11.attention.output.LayerNorm.bias', 'encoder.layer.11.intermediate.dense.weight', 'encoder.layer.11.intermediate.dense.bias', 'encoder.layer.11.output.dense.weight', 'encoder.layer.11.output.dense.bias', 'encoder.layer.11.output.LayerNorm.weight', 'encoder.layer.11.output.LayerNorm.bias', 'embeddings.word_embeddings.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.LayerNorm.weight', 'embeddings.LayerNorm.bias', 'pooler.dense.weight', 'pooler.dense.bias'])\n",
+      "odict_keys(['embeddings.word_embeddings.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.LayerNorm.weight', 'embeddings.LayerNorm.bias', 'encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.1.attention.self.query.bias', 'encoder.layer.1.attention.self.key.weight', 'encoder.layer.1.attention.self.key.bias', 'encoder.layer.1.attention.self.value.weight', 'encoder.layer.1.attention.self.value.bias', 'encoder.layer.1.attention.output.dense.weight', 'encoder.layer.1.attention.output.dense.bias', 'encoder.layer.1.attention.output.LayerNorm.weight', 'encoder.layer.1.attention.output.LayerNorm.bias', 'encoder.layer.1.intermediate.dense.weight', 'encoder.layer.1.intermediate.dense.bias', 'encoder.layer.1.output.dense.weight', 'encoder.layer.1.output.dense.bias', 'encoder.layer.1.output.LayerNorm.weight', 'encoder.layer.1.output.LayerNorm.bias', 'encoder.layer.2.attention.self.query.weight', 'encoder.layer.2.attention.self.query.bias', 'encoder.layer.2.attention.self.key.weight', 'encoder.layer.2.attention.self.key.bias', 'encoder.layer.2.attention.self.value.weight', 'encoder.layer.2.attention.self.value.bias', 'encoder.layer.2.attention.output.dense.weight', 'encoder.layer.2.attention.output.dense.bias', 'encoder.layer.2.attention.output.LayerNorm.weight', 'encoder.layer.2.attention.output.LayerNorm.bias', 'encoder.layer.2.intermediate.dense.weight', 'encoder.layer.2.intermediate.dense.bias', 'encoder.layer.2.output.dense.weight', 'encoder.layer.2.output.dense.bias', 'encoder.layer.2.output.LayerNorm.weight', 'encoder.layer.2.output.LayerNorm.bias', 'encoder.layer.3.attention.self.query.weight', 'encoder.layer.3.attention.self.query.bias', 'encoder.layer.3.attention.self.key.weight', 'encoder.layer.3.attention.self.key.bias', 'encoder.layer.3.attention.self.value.weight', 'encoder.layer.3.attention.self.value.bias', 'encoder.layer.3.attention.output.dense.weight', 'encoder.layer.3.attention.output.dense.bias', 'encoder.layer.3.attention.output.LayerNorm.weight', 'encoder.layer.3.attention.output.LayerNorm.bias', 'encoder.layer.3.intermediate.dense.weight', 'encoder.layer.3.intermediate.dense.bias', 'encoder.layer.3.output.dense.weight', 'encoder.layer.3.output.dense.bias', 'encoder.layer.3.output.LayerNorm.weight', 'encoder.layer.3.output.LayerNorm.bias', 'encoder.layer.4.attention.self.query.weight', 'encoder.layer.4.attention.self.query.bias', 'encoder.layer.4.attention.self.key.weight', 'encoder.layer.4.attention.self.key.bias', 'encoder.layer.4.attention.self.value.weight', 'encoder.layer.4.attention.self.value.bias', 'encoder.layer.4.attention.output.dense.weight', 'encoder.layer.4.attention.output.dense.bias', 'encoder.layer.4.attention.output.LayerNorm.weight', 'encoder.layer.4.attention.output.LayerNorm.bias', 
'encoder.layer.4.intermediate.dense.weight', 'encoder.layer.4.intermediate.dense.bias', 'encoder.layer.4.output.dense.weight', 'encoder.layer.4.output.dense.bias', 'encoder.layer.4.output.LayerNorm.weight', 'encoder.layer.4.output.LayerNorm.bias', 'encoder.layer.5.attention.self.query.weight', 'encoder.layer.5.attention.self.query.bias', 'encoder.layer.5.attention.self.key.weight', 'encoder.layer.5.attention.self.key.bias', 'encoder.layer.5.attention.self.value.weight', 'encoder.layer.5.attention.self.value.bias', 'encoder.layer.5.attention.output.dense.weight', 'encoder.layer.5.attention.output.dense.bias', 'encoder.layer.5.attention.output.LayerNorm.weight', 'encoder.layer.5.attention.output.LayerNorm.bias', 'encoder.layer.5.intermediate.dense.weight', 'encoder.layer.5.intermediate.dense.bias', 'encoder.layer.5.output.dense.weight', 'encoder.layer.5.output.dense.bias', 'encoder.layer.5.output.LayerNorm.weight', 'encoder.layer.5.output.LayerNorm.bias', 'encoder.layer.6.attention.self.query.weight', 'encoder.layer.6.attention.self.query.bias', 'encoder.layer.6.attention.self.key.weight', 'encoder.layer.6.attention.self.key.bias', 'encoder.layer.6.attention.self.value.weight', 'encoder.layer.6.attention.self.value.bias', 'encoder.layer.6.attention.output.dense.weight', 'encoder.layer.6.attention.output.dense.bias', 'encoder.layer.6.attention.output.LayerNorm.weight', 'encoder.layer.6.attention.output.LayerNorm.bias', 'encoder.layer.6.intermediate.dense.weight', 'encoder.layer.6.intermediate.dense.bias', 'encoder.layer.6.output.dense.weight', 'encoder.layer.6.output.dense.bias', 'encoder.layer.6.output.LayerNorm.weight', 'encoder.layer.6.output.LayerNorm.bias', 'encoder.layer.7.attention.self.query.weight', 'encoder.layer.7.attention.self.query.bias', 'encoder.layer.7.attention.self.key.weight', 'encoder.layer.7.attention.self.key.bias', 'encoder.layer.7.attention.self.value.weight', 'encoder.layer.7.attention.self.value.bias', 'encoder.layer.7.attention.output.dense.weight', 'encoder.layer.7.attention.output.dense.bias', 'encoder.layer.7.attention.output.LayerNorm.weight', 'encoder.layer.7.attention.output.LayerNorm.bias', 'encoder.layer.7.intermediate.dense.weight', 'encoder.layer.7.intermediate.dense.bias', 'encoder.layer.7.output.dense.weight', 'encoder.layer.7.output.dense.bias', 'encoder.layer.7.output.LayerNorm.weight', 'encoder.layer.7.output.LayerNorm.bias', 'encoder.layer.8.attention.self.query.weight', 'encoder.layer.8.attention.self.query.bias', 'encoder.layer.8.attention.self.key.weight', 'encoder.layer.8.attention.self.key.bias', 'encoder.layer.8.attention.self.value.weight', 'encoder.layer.8.attention.self.value.bias', 'encoder.layer.8.attention.output.dense.weight', 'encoder.layer.8.attention.output.dense.bias', 'encoder.layer.8.attention.output.LayerNorm.weight', 'encoder.layer.8.attention.output.LayerNorm.bias', 'encoder.layer.8.intermediate.dense.weight', 'encoder.layer.8.intermediate.dense.bias', 'encoder.layer.8.output.dense.weight', 'encoder.layer.8.output.dense.bias', 'encoder.layer.8.output.LayerNorm.weight', 'encoder.layer.8.output.LayerNorm.bias', 'encoder.layer.9.attention.self.query.weight', 'encoder.layer.9.attention.self.query.bias', 'encoder.layer.9.attention.self.key.weight', 'encoder.layer.9.attention.self.key.bias', 'encoder.layer.9.attention.self.value.weight', 'encoder.layer.9.attention.self.value.bias', 'encoder.layer.9.attention.output.dense.weight', 'encoder.layer.9.attention.output.dense.bias', 'encoder.layer.9.attention.output.LayerNorm.weight', 
'encoder.layer.9.attention.output.LayerNorm.bias', 'encoder.layer.9.intermediate.dense.weight', 'encoder.layer.9.intermediate.dense.bias', 'encoder.layer.9.output.dense.weight', 'encoder.layer.9.output.dense.bias', 'encoder.layer.9.output.LayerNorm.weight', 'encoder.layer.9.output.LayerNorm.bias', 'encoder.layer.10.attention.self.query.weight', 'encoder.layer.10.attention.self.query.bias', 'encoder.layer.10.attention.self.key.weight', 'encoder.layer.10.attention.self.key.bias', 'encoder.layer.10.attention.self.value.weight', 'encoder.layer.10.attention.self.value.bias', 'encoder.layer.10.attention.output.dense.weight', 'encoder.layer.10.attention.output.dense.bias', 'encoder.layer.10.attention.output.LayerNorm.weight', 'encoder.layer.10.attention.output.LayerNorm.bias', 'encoder.layer.10.intermediate.dense.weight', 'encoder.layer.10.intermediate.dense.bias', 'encoder.layer.10.output.dense.weight', 'encoder.layer.10.output.dense.bias', 'encoder.layer.10.output.LayerNorm.weight', 'encoder.layer.10.output.LayerNorm.bias', 'encoder.layer.11.attention.self.query.weight', 'encoder.layer.11.attention.self.query.bias', 'encoder.layer.11.attention.self.key.weight', 'encoder.layer.11.attention.self.key.bias', 'encoder.layer.11.attention.self.value.weight', 'encoder.layer.11.attention.self.value.bias', 'encoder.layer.11.attention.output.dense.weight', 'encoder.layer.11.attention.output.dense.bias', 'encoder.layer.11.attention.output.LayerNorm.weight', 'encoder.layer.11.attention.output.LayerNorm.bias', 'encoder.layer.11.intermediate.dense.weight', 'encoder.layer.11.intermediate.dense.bias', 'encoder.layer.11.output.dense.weight', 'encoder.layer.11.output.dense.bias', 'encoder.layer.11.output.LayerNorm.weight', 'encoder.layer.11.output.LayerNorm.bias', 'pooler.dense.weight', 'pooler.dense.bias'])\n",
+      "success1\n",
+      "success2\n",
+      "0.0\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Initializing RobertaModel\n",
+    "my_model = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=True, output_hidden_states=True, key=key_model)\n",
+    "state_model = my_model.to_state_dict()\n",
+    "\n",
+    "state_model = {k: torch.from_numpy(np.array(v)) for k, v in state_model.items()}\n",
+    "\n",
+    "hf_model = hf_roberta.RobertaModel(hf_config, add_pooling_layer=True)\n",
+    "hf_model.load_state_dict(state_model, strict=True)\n",
+    "\n",
+    "print(check_dicts(my_model.to_state_dict(), hf_model.state_dict()))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "dict_keys(['roberta.encoder.layer.0.attention.self.query.weight', 'roberta.encoder.layer.0.attention.self.query.bias', 'roberta.encoder.layer.0.attention.self.key.weight', 'roberta.encoder.layer.0.attention.self.key.bias', 'roberta.encoder.layer.0.attention.self.value.weight', 'roberta.encoder.layer.0.attention.self.value.bias', 'roberta.encoder.layer.0.attention.output.dense.weight', 'roberta.encoder.layer.0.attention.output.dense.bias', 'roberta.encoder.layer.0.attention.output.LayerNorm.weight', 'roberta.encoder.layer.0.attention.output.LayerNorm.bias', 'roberta.encoder.layer.0.intermediate.dense.weight', 'roberta.encoder.layer.0.intermediate.dense.bias', 'roberta.encoder.layer.0.output.dense.weight', 'roberta.encoder.layer.0.output.dense.bias', 'roberta.encoder.layer.0.output.LayerNorm.weight', 'roberta.encoder.layer.0.output.LayerNorm.bias', 'roberta.encoder.layer.1.attention.self.query.weight', 'roberta.encoder.layer.1.attention.self.query.bias', 'roberta.encoder.layer.1.attention.self.key.weight', 'roberta.encoder.layer.1.attention.self.key.bias', 'roberta.encoder.layer.1.attention.self.value.weight', 'roberta.encoder.layer.1.attention.self.value.bias', 'roberta.encoder.layer.1.attention.output.dense.weight', 'roberta.encoder.layer.1.attention.output.dense.bias', 'roberta.encoder.layer.1.attention.output.LayerNorm.weight', 'roberta.encoder.layer.1.attention.output.LayerNorm.bias', 'roberta.encoder.layer.1.intermediate.dense.weight', 'roberta.encoder.layer.1.intermediate.dense.bias', 'roberta.encoder.layer.1.output.dense.weight', 'roberta.encoder.layer.1.output.dense.bias', 'roberta.encoder.layer.1.output.LayerNorm.weight', 'roberta.encoder.layer.1.output.LayerNorm.bias', 'roberta.encoder.layer.2.attention.self.query.weight', 'roberta.encoder.layer.2.attention.self.query.bias', 'roberta.encoder.layer.2.attention.self.key.weight', 'roberta.encoder.layer.2.attention.self.key.bias', 'roberta.encoder.layer.2.attention.self.value.weight', 'roberta.encoder.layer.2.attention.self.value.bias', 'roberta.encoder.layer.2.attention.output.dense.weight', 'roberta.encoder.layer.2.attention.output.dense.bias', 'roberta.encoder.layer.2.attention.output.LayerNorm.weight', 'roberta.encoder.layer.2.attention.output.LayerNorm.bias', 'roberta.encoder.layer.2.intermediate.dense.weight', 'roberta.encoder.layer.2.intermediate.dense.bias', 'roberta.encoder.layer.2.output.dense.weight', 'roberta.encoder.layer.2.output.dense.bias', 'roberta.encoder.layer.2.output.LayerNorm.weight', 'roberta.encoder.layer.2.output.LayerNorm.bias', 'roberta.encoder.layer.3.attention.self.query.weight', 'roberta.encoder.layer.3.attention.self.query.bias', 'roberta.encoder.layer.3.attention.self.key.weight', 'roberta.encoder.layer.3.attention.self.key.bias', 'roberta.encoder.layer.3.attention.self.value.weight', 'roberta.encoder.layer.3.attention.self.value.bias', 'roberta.encoder.layer.3.attention.output.dense.weight', 'roberta.encoder.layer.3.attention.output.dense.bias', 'roberta.encoder.layer.3.attention.output.LayerNorm.weight', 'roberta.encoder.layer.3.attention.output.LayerNorm.bias', 'roberta.encoder.layer.3.intermediate.dense.weight', 'roberta.encoder.layer.3.intermediate.dense.bias', 'roberta.encoder.layer.3.output.dense.weight', 'roberta.encoder.layer.3.output.dense.bias', 'roberta.encoder.layer.3.output.LayerNorm.weight', 'roberta.encoder.layer.3.output.LayerNorm.bias', 'roberta.encoder.layer.4.attention.self.query.weight', 'roberta.encoder.layer.4.attention.self.query.bias', 
'roberta.encoder.layer.4.attention.self.key.weight', 'roberta.encoder.layer.4.attention.self.key.bias', 'roberta.encoder.layer.4.attention.self.value.weight', 'roberta.encoder.layer.4.attention.self.value.bias', 'roberta.encoder.layer.4.attention.output.dense.weight', 'roberta.encoder.layer.4.attention.output.dense.bias', 'roberta.encoder.layer.4.attention.output.LayerNorm.weight', 'roberta.encoder.layer.4.attention.output.LayerNorm.bias', 'roberta.encoder.layer.4.intermediate.dense.weight', 'roberta.encoder.layer.4.intermediate.dense.bias', 'roberta.encoder.layer.4.output.dense.weight', 'roberta.encoder.layer.4.output.dense.bias', 'roberta.encoder.layer.4.output.LayerNorm.weight', 'roberta.encoder.layer.4.output.LayerNorm.bias', 'roberta.encoder.layer.5.attention.self.query.weight', 'roberta.encoder.layer.5.attention.self.query.bias', 'roberta.encoder.layer.5.attention.self.key.weight', 'roberta.encoder.layer.5.attention.self.key.bias', 'roberta.encoder.layer.5.attention.self.value.weight', 'roberta.encoder.layer.5.attention.self.value.bias', 'roberta.encoder.layer.5.attention.output.dense.weight', 'roberta.encoder.layer.5.attention.output.dense.bias', 'roberta.encoder.layer.5.attention.output.LayerNorm.weight', 'roberta.encoder.layer.5.attention.output.LayerNorm.bias', 'roberta.encoder.layer.5.intermediate.dense.weight', 'roberta.encoder.layer.5.intermediate.dense.bias', 'roberta.encoder.layer.5.output.dense.weight', 'roberta.encoder.layer.5.output.dense.bias', 'roberta.encoder.layer.5.output.LayerNorm.weight', 'roberta.encoder.layer.5.output.LayerNorm.bias', 'roberta.encoder.layer.6.attention.self.query.weight', 'roberta.encoder.layer.6.attention.self.query.bias', 'roberta.encoder.layer.6.attention.self.key.weight', 'roberta.encoder.layer.6.attention.self.key.bias', 'roberta.encoder.layer.6.attention.self.value.weight', 'roberta.encoder.layer.6.attention.self.value.bias', 'roberta.encoder.layer.6.attention.output.dense.weight', 'roberta.encoder.layer.6.attention.output.dense.bias', 'roberta.encoder.layer.6.attention.output.LayerNorm.weight', 'roberta.encoder.layer.6.attention.output.LayerNorm.bias', 'roberta.encoder.layer.6.intermediate.dense.weight', 'roberta.encoder.layer.6.intermediate.dense.bias', 'roberta.encoder.layer.6.output.dense.weight', 'roberta.encoder.layer.6.output.dense.bias', 'roberta.encoder.layer.6.output.LayerNorm.weight', 'roberta.encoder.layer.6.output.LayerNorm.bias', 'roberta.encoder.layer.7.attention.self.query.weight', 'roberta.encoder.layer.7.attention.self.query.bias', 'roberta.encoder.layer.7.attention.self.key.weight', 'roberta.encoder.layer.7.attention.self.key.bias', 'roberta.encoder.layer.7.attention.self.value.weight', 'roberta.encoder.layer.7.attention.self.value.bias', 'roberta.encoder.layer.7.attention.output.dense.weight', 'roberta.encoder.layer.7.attention.output.dense.bias', 'roberta.encoder.layer.7.attention.output.LayerNorm.weight', 'roberta.encoder.layer.7.attention.output.LayerNorm.bias', 'roberta.encoder.layer.7.intermediate.dense.weight', 'roberta.encoder.layer.7.intermediate.dense.bias', 'roberta.encoder.layer.7.output.dense.weight', 'roberta.encoder.layer.7.output.dense.bias', 'roberta.encoder.layer.7.output.LayerNorm.weight', 'roberta.encoder.layer.7.output.LayerNorm.bias', 'roberta.encoder.layer.8.attention.self.query.weight', 'roberta.encoder.layer.8.attention.self.query.bias', 'roberta.encoder.layer.8.attention.self.key.weight', 'roberta.encoder.layer.8.attention.self.key.bias', 'roberta.encoder.layer.8.attention.self.value.weight', 
'roberta.encoder.layer.8.attention.self.value.bias', 'roberta.encoder.layer.8.attention.output.dense.weight', 'roberta.encoder.layer.8.attention.output.dense.bias', 'roberta.encoder.layer.8.attention.output.LayerNorm.weight', 'roberta.encoder.layer.8.attention.output.LayerNorm.bias', 'roberta.encoder.layer.8.intermediate.dense.weight', 'roberta.encoder.layer.8.intermediate.dense.bias', 'roberta.encoder.layer.8.output.dense.weight', 'roberta.encoder.layer.8.output.dense.bias', 'roberta.encoder.layer.8.output.LayerNorm.weight', 'roberta.encoder.layer.8.output.LayerNorm.bias', 'roberta.encoder.layer.9.attention.self.query.weight', 'roberta.encoder.layer.9.attention.self.query.bias', 'roberta.encoder.layer.9.attention.self.key.weight', 'roberta.encoder.layer.9.attention.self.key.bias', 'roberta.encoder.layer.9.attention.self.value.weight', 'roberta.encoder.layer.9.attention.self.value.bias', 'roberta.encoder.layer.9.attention.output.dense.weight', 'roberta.encoder.layer.9.attention.output.dense.bias', 'roberta.encoder.layer.9.attention.output.LayerNorm.weight', 'roberta.encoder.layer.9.attention.output.LayerNorm.bias', 'roberta.encoder.layer.9.intermediate.dense.weight', 'roberta.encoder.layer.9.intermediate.dense.bias', 'roberta.encoder.layer.9.output.dense.weight', 'roberta.encoder.layer.9.output.dense.bias', 'roberta.encoder.layer.9.output.LayerNorm.weight', 'roberta.encoder.layer.9.output.LayerNorm.bias', 'roberta.encoder.layer.10.attention.self.query.weight', 'roberta.encoder.layer.10.attention.self.query.bias', 'roberta.encoder.layer.10.attention.self.key.weight', 'roberta.encoder.layer.10.attention.self.key.bias', 'roberta.encoder.layer.10.attention.self.value.weight', 'roberta.encoder.layer.10.attention.self.value.bias', 'roberta.encoder.layer.10.attention.output.dense.weight', 'roberta.encoder.layer.10.attention.output.dense.bias', 'roberta.encoder.layer.10.attention.output.LayerNorm.weight', 'roberta.encoder.layer.10.attention.output.LayerNorm.bias', 'roberta.encoder.layer.10.intermediate.dense.weight', 'roberta.encoder.layer.10.intermediate.dense.bias', 'roberta.encoder.layer.10.output.dense.weight', 'roberta.encoder.layer.10.output.dense.bias', 'roberta.encoder.layer.10.output.LayerNorm.weight', 'roberta.encoder.layer.10.output.LayerNorm.bias', 'roberta.encoder.layer.11.attention.self.query.weight', 'roberta.encoder.layer.11.attention.self.query.bias', 'roberta.encoder.layer.11.attention.self.key.weight', 'roberta.encoder.layer.11.attention.self.key.bias', 'roberta.encoder.layer.11.attention.self.value.weight', 'roberta.encoder.layer.11.attention.self.value.bias', 'roberta.encoder.layer.11.attention.output.dense.weight', 'roberta.encoder.layer.11.attention.output.dense.bias', 'roberta.encoder.layer.11.attention.output.LayerNorm.weight', 'roberta.encoder.layer.11.attention.output.LayerNorm.bias', 'roberta.encoder.layer.11.intermediate.dense.weight', 'roberta.encoder.layer.11.intermediate.dense.bias', 'roberta.encoder.layer.11.output.dense.weight', 'roberta.encoder.layer.11.output.dense.bias', 'roberta.encoder.layer.11.output.LayerNorm.weight', 'roberta.encoder.layer.11.output.LayerNorm.bias', 'roberta.embeddings.word_embeddings.weight', 'roberta.embeddings.position_embeddings.weight', 'roberta.embeddings.token_type_embeddings.weight', 'roberta.embeddings.LayerNorm.weight', 'roberta.embeddings.LayerNorm.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.decoder.bias', 'lm_head.bias'])\n",
+      "odict_keys(['roberta.embeddings.word_embeddings.weight', 'roberta.embeddings.position_embeddings.weight', 'roberta.embeddings.token_type_embeddings.weight', 'roberta.embeddings.LayerNorm.weight', 'roberta.embeddings.LayerNorm.bias', 'roberta.encoder.layer.0.attention.self.query.weight', 'roberta.encoder.layer.0.attention.self.query.bias', 'roberta.encoder.layer.0.attention.self.key.weight', 'roberta.encoder.layer.0.attention.self.key.bias', 'roberta.encoder.layer.0.attention.self.value.weight', 'roberta.encoder.layer.0.attention.self.value.bias', 'roberta.encoder.layer.0.attention.output.dense.weight', 'roberta.encoder.layer.0.attention.output.dense.bias', 'roberta.encoder.layer.0.attention.output.LayerNorm.weight', 'roberta.encoder.layer.0.attention.output.LayerNorm.bias', 'roberta.encoder.layer.0.intermediate.dense.weight', 'roberta.encoder.layer.0.intermediate.dense.bias', 'roberta.encoder.layer.0.output.dense.weight', 'roberta.encoder.layer.0.output.dense.bias', 'roberta.encoder.layer.0.output.LayerNorm.weight', 'roberta.encoder.layer.0.output.LayerNorm.bias', 'roberta.encoder.layer.1.attention.self.query.weight', 'roberta.encoder.layer.1.attention.self.query.bias', 'roberta.encoder.layer.1.attention.self.key.weight', 'roberta.encoder.layer.1.attention.self.key.bias', 'roberta.encoder.layer.1.attention.self.value.weight', 'roberta.encoder.layer.1.attention.self.value.bias', 'roberta.encoder.layer.1.attention.output.dense.weight', 'roberta.encoder.layer.1.attention.output.dense.bias', 'roberta.encoder.layer.1.attention.output.LayerNorm.weight', 'roberta.encoder.layer.1.attention.output.LayerNorm.bias', 'roberta.encoder.layer.1.intermediate.dense.weight', 'roberta.encoder.layer.1.intermediate.dense.bias', 'roberta.encoder.layer.1.output.dense.weight', 'roberta.encoder.layer.1.output.dense.bias', 'roberta.encoder.layer.1.output.LayerNorm.weight', 'roberta.encoder.layer.1.output.LayerNorm.bias', 'roberta.encoder.layer.2.attention.self.query.weight', 'roberta.encoder.layer.2.attention.self.query.bias', 'roberta.encoder.layer.2.attention.self.key.weight', 'roberta.encoder.layer.2.attention.self.key.bias', 'roberta.encoder.layer.2.attention.self.value.weight', 'roberta.encoder.layer.2.attention.self.value.bias', 'roberta.encoder.layer.2.attention.output.dense.weight', 'roberta.encoder.layer.2.attention.output.dense.bias', 'roberta.encoder.layer.2.attention.output.LayerNorm.weight', 'roberta.encoder.layer.2.attention.output.LayerNorm.bias', 'roberta.encoder.layer.2.intermediate.dense.weight', 'roberta.encoder.layer.2.intermediate.dense.bias', 'roberta.encoder.layer.2.output.dense.weight', 'roberta.encoder.layer.2.output.dense.bias', 'roberta.encoder.layer.2.output.LayerNorm.weight', 'roberta.encoder.layer.2.output.LayerNorm.bias', 'roberta.encoder.layer.3.attention.self.query.weight', 'roberta.encoder.layer.3.attention.self.query.bias', 'roberta.encoder.layer.3.attention.self.key.weight', 'roberta.encoder.layer.3.attention.self.key.bias', 'roberta.encoder.layer.3.attention.self.value.weight', 'roberta.encoder.layer.3.attention.self.value.bias', 'roberta.encoder.layer.3.attention.output.dense.weight', 'roberta.encoder.layer.3.attention.output.dense.bias', 'roberta.encoder.layer.3.attention.output.LayerNorm.weight', 'roberta.encoder.layer.3.attention.output.LayerNorm.bias', 'roberta.encoder.layer.3.intermediate.dense.weight', 'roberta.encoder.layer.3.intermediate.dense.bias', 'roberta.encoder.layer.3.output.dense.weight', 'roberta.encoder.layer.3.output.dense.bias', 
'roberta.encoder.layer.3.output.LayerNorm.weight', 'roberta.encoder.layer.3.output.LayerNorm.bias', 'roberta.encoder.layer.4.attention.self.query.weight', 'roberta.encoder.layer.4.attention.self.query.bias', 'roberta.encoder.layer.4.attention.self.key.weight', 'roberta.encoder.layer.4.attention.self.key.bias', 'roberta.encoder.layer.4.attention.self.value.weight', 'roberta.encoder.layer.4.attention.self.value.bias', 'roberta.encoder.layer.4.attention.output.dense.weight', 'roberta.encoder.layer.4.attention.output.dense.bias', 'roberta.encoder.layer.4.attention.output.LayerNorm.weight', 'roberta.encoder.layer.4.attention.output.LayerNorm.bias', 'roberta.encoder.layer.4.intermediate.dense.weight', 'roberta.encoder.layer.4.intermediate.dense.bias', 'roberta.encoder.layer.4.output.dense.weight', 'roberta.encoder.layer.4.output.dense.bias', 'roberta.encoder.layer.4.output.LayerNorm.weight', 'roberta.encoder.layer.4.output.LayerNorm.bias', 'roberta.encoder.layer.5.attention.self.query.weight', 'roberta.encoder.layer.5.attention.self.query.bias', 'roberta.encoder.layer.5.attention.self.key.weight', 'roberta.encoder.layer.5.attention.self.key.bias', 'roberta.encoder.layer.5.attention.self.value.weight', 'roberta.encoder.layer.5.attention.self.value.bias', 'roberta.encoder.layer.5.attention.output.dense.weight', 'roberta.encoder.layer.5.attention.output.dense.bias', 'roberta.encoder.layer.5.attention.output.LayerNorm.weight', 'roberta.encoder.layer.5.attention.output.LayerNorm.bias', 'roberta.encoder.layer.5.intermediate.dense.weight', 'roberta.encoder.layer.5.intermediate.dense.bias', 'roberta.encoder.layer.5.output.dense.weight', 'roberta.encoder.layer.5.output.dense.bias', 'roberta.encoder.layer.5.output.LayerNorm.weight', 'roberta.encoder.layer.5.output.LayerNorm.bias', 'roberta.encoder.layer.6.attention.self.query.weight', 'roberta.encoder.layer.6.attention.self.query.bias', 'roberta.encoder.layer.6.attention.self.key.weight', 'roberta.encoder.layer.6.attention.self.key.bias', 'roberta.encoder.layer.6.attention.self.value.weight', 'roberta.encoder.layer.6.attention.self.value.bias', 'roberta.encoder.layer.6.attention.output.dense.weight', 'roberta.encoder.layer.6.attention.output.dense.bias', 'roberta.encoder.layer.6.attention.output.LayerNorm.weight', 'roberta.encoder.layer.6.attention.output.LayerNorm.bias', 'roberta.encoder.layer.6.intermediate.dense.weight', 'roberta.encoder.layer.6.intermediate.dense.bias', 'roberta.encoder.layer.6.output.dense.weight', 'roberta.encoder.layer.6.output.dense.bias', 'roberta.encoder.layer.6.output.LayerNorm.weight', 'roberta.encoder.layer.6.output.LayerNorm.bias', 'roberta.encoder.layer.7.attention.self.query.weight', 'roberta.encoder.layer.7.attention.self.query.bias', 'roberta.encoder.layer.7.attention.self.key.weight', 'roberta.encoder.layer.7.attention.self.key.bias', 'roberta.encoder.layer.7.attention.self.value.weight', 'roberta.encoder.layer.7.attention.self.value.bias', 'roberta.encoder.layer.7.attention.output.dense.weight', 'roberta.encoder.layer.7.attention.output.dense.bias', 'roberta.encoder.layer.7.attention.output.LayerNorm.weight', 'roberta.encoder.layer.7.attention.output.LayerNorm.bias', 'roberta.encoder.layer.7.intermediate.dense.weight', 'roberta.encoder.layer.7.intermediate.dense.bias', 'roberta.encoder.layer.7.output.dense.weight', 'roberta.encoder.layer.7.output.dense.bias', 'roberta.encoder.layer.7.output.LayerNorm.weight', 'roberta.encoder.layer.7.output.LayerNorm.bias', 'roberta.encoder.layer.8.attention.self.query.weight', 
'roberta.encoder.layer.8.attention.self.query.bias', 'roberta.encoder.layer.8.attention.self.key.weight', 'roberta.encoder.layer.8.attention.self.key.bias', 'roberta.encoder.layer.8.attention.self.value.weight', 'roberta.encoder.layer.8.attention.self.value.bias', 'roberta.encoder.layer.8.attention.output.dense.weight', 'roberta.encoder.layer.8.attention.output.dense.bias', 'roberta.encoder.layer.8.attention.output.LayerNorm.weight', 'roberta.encoder.layer.8.attention.output.LayerNorm.bias', 'roberta.encoder.layer.8.intermediate.dense.weight', 'roberta.encoder.layer.8.intermediate.dense.bias', 'roberta.encoder.layer.8.output.dense.weight', 'roberta.encoder.layer.8.output.dense.bias', 'roberta.encoder.layer.8.output.LayerNorm.weight', 'roberta.encoder.layer.8.output.LayerNorm.bias', 'roberta.encoder.layer.9.attention.self.query.weight', 'roberta.encoder.layer.9.attention.self.query.bias', 'roberta.encoder.layer.9.attention.self.key.weight', 'roberta.encoder.layer.9.attention.self.key.bias', 'roberta.encoder.layer.9.attention.self.value.weight', 'roberta.encoder.layer.9.attention.self.value.bias', 'roberta.encoder.layer.9.attention.output.dense.weight', 'roberta.encoder.layer.9.attention.output.dense.bias', 'roberta.encoder.layer.9.attention.output.LayerNorm.weight', 'roberta.encoder.layer.9.attention.output.LayerNorm.bias', 'roberta.encoder.layer.9.intermediate.dense.weight', 'roberta.encoder.layer.9.intermediate.dense.bias', 'roberta.encoder.layer.9.output.dense.weight', 'roberta.encoder.layer.9.output.dense.bias', 'roberta.encoder.layer.9.output.LayerNorm.weight', 'roberta.encoder.layer.9.output.LayerNorm.bias', 'roberta.encoder.layer.10.attention.self.query.weight', 'roberta.encoder.layer.10.attention.self.query.bias', 'roberta.encoder.layer.10.attention.self.key.weight', 'roberta.encoder.layer.10.attention.self.key.bias', 'roberta.encoder.layer.10.attention.self.value.weight', 'roberta.encoder.layer.10.attention.self.value.bias', 'roberta.encoder.layer.10.attention.output.dense.weight', 'roberta.encoder.layer.10.attention.output.dense.bias', 'roberta.encoder.layer.10.attention.output.LayerNorm.weight', 'roberta.encoder.layer.10.attention.output.LayerNorm.bias', 'roberta.encoder.layer.10.intermediate.dense.weight', 'roberta.encoder.layer.10.intermediate.dense.bias', 'roberta.encoder.layer.10.output.dense.weight', 'roberta.encoder.layer.10.output.dense.bias', 'roberta.encoder.layer.10.output.LayerNorm.weight', 'roberta.encoder.layer.10.output.LayerNorm.bias', 'roberta.encoder.layer.11.attention.self.query.weight', 'roberta.encoder.layer.11.attention.self.query.bias', 'roberta.encoder.layer.11.attention.self.key.weight', 'roberta.encoder.layer.11.attention.self.key.bias', 'roberta.encoder.layer.11.attention.self.value.weight', 'roberta.encoder.layer.11.attention.self.value.bias', 'roberta.encoder.layer.11.attention.output.dense.weight', 'roberta.encoder.layer.11.attention.output.dense.bias', 'roberta.encoder.layer.11.attention.output.LayerNorm.weight', 'roberta.encoder.layer.11.attention.output.LayerNorm.bias', 'roberta.encoder.layer.11.intermediate.dense.weight', 'roberta.encoder.layer.11.intermediate.dense.bias', 'roberta.encoder.layer.11.output.dense.weight', 'roberta.encoder.layer.11.output.dense.bias', 'roberta.encoder.layer.11.output.LayerNorm.weight', 'roberta.encoder.layer.11.output.LayerNorm.bias', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.decoder.bias'])\n",
+      "success1\n",
+      "success2\n",
+      "0.0\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Initializing RobertaForMaskedLM\n",
+    "my_mlm = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, output_hidden_states=True, key=key_funcs)\n",
+    "state_mlm = my_mlm.to_state_dict()\n",
+    "\n",
+    "state_mlm = {k: torch.from_numpy(np.array(v)) for k, v in state_mlm.items()}\n",
+    "state_mlm[\"lm_head.bias\"] = torch.zeros(hf_config.vocab_size)\n",
+    "# print(state_mlm[w_str])\n",
+    "\n",
+    "hf_mlm = hf_roberta.RobertaForMaskedLM(hf_config)\n",
+    "hf_mlm.load_state_dict(state_mlm, strict=True, assign=True)\n",
+    "# print(hf_mlm.state_dict()[w_str])\n",
+    "\n",
+    "print(check_dicts(state_mlm, hf_mlm.state_dict()))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def test_RobertaModel_Output(key_run, ids = False):\n",
+    "    if ids:\n",
+    "        my_output = my_model(input_ids = input_ids, attention_mask=mask, key=key_run)\n",
+    "        hf_output = hf_model(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False, output_hidden_states=True)\n",
+    "    else:\n",
+    "        my_output = my_model(input_embeds = input_embeds, attention_mask=mask, key=key_run)\n",
+    "        hf_output = hf_model(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False, output_hidden_states=True)\n",
+    "\n",
+    "    return my_output, hf_output"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "my_output_ids, hf_output_ids = test_RobertaModel_Output(key_model_run, ids=True)\n",
+    "my_output_embeds, hf_output_embeds = test_RobertaModel_Output(key_model_run, ids=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "model_out: acc: 1.0 \t norms: (tensor(888.5331), tensor(888.5331)) \t diffs: 6.946668236196274e-07\n",
+      "pool_out: acc: 1.0 \t norms: (tensor(24.6133), tensor(24.6133)) \t diffs: 4.2906631847472454e-07\n",
+      "intermediates:\n",
+      "acc: 1.0 \t norms: (tensor(888.5315), tensor(888.5314)) \t diffs: 1.6926360046909394e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5301), tensor(888.5301)) \t diffs: 2.57225963196106e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5310), tensor(888.5311)) \t diffs: 3.1373855335914413e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5312), tensor(888.5312)) \t diffs: 3.6955742643840495e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5304), tensor(888.5304)) \t diffs: 4.1312114262836985e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5308), tensor(888.5308)) \t diffs: 4.612424220340472e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5308), tensor(888.5308)) \t diffs: 5.164657750356128e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5295), tensor(888.5295)) \t diffs: 5.412561563389318e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5309), tensor(888.5310)) \t diffs: 5.724053266931151e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5297), tensor(888.5297)) \t diffs: 5.980200512567535e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5317), tensor(888.5317)) \t diffs: 6.434746637751232e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5331), tensor(888.5331)) \t diffs: 6.946668236196274e-07\n"
+     ]
+    }
+   ],
+   "source": [
+    "# RobertaModel ids\n",
+    "my_out, hf_out = my_output_ids[0], hf_output_ids[0]\n",
+    "\n",
+    "print(f\"model_out: {check(my_out.array, hf_out.detach())[3]}\")\n",
+    "\n",
+    "my_pool, hf_pool = my_output_ids[1], hf_output_ids[1]\n",
+    "\n",
+    "print(f\"pool_out: {check(my_pool.array, hf_pool.detach())[3]}\")\n",
+    "\n",
+    "print(\"intermediates:\")\n",
+    "my_ints, hf_ints = my_output_ids[2], hf_output_ids[2][1:]\n",
+    "\n",
+    "for i,j in zip(my_ints, hf_ints):\n",
+    "    print(check(i.array,j.detach())[3])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "model_out: acc: 1.0 \t norms: (tensor(888.5307), tensor(888.5306)) \t diffs: 6.107513286224275e-07\n",
+      "pool_out: acc: 1.0 \t norms: (tensor(24.6094), tensor(24.6094)) \t diffs: 3.8713346839358564e-07\n",
+      "intermediates:\n",
+      "acc: 1.0 \t norms: (tensor(888.5317), tensor(888.5317)) \t diffs: 1.3876427829018212e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5305), tensor(888.5305)) \t diffs: 2.239140144411067e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5290), tensor(888.5290)) \t diffs: 2.9642876597790746e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5301), tensor(888.5300)) \t diffs: 3.554245893155894e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5311), tensor(888.5311)) \t diffs: 4.070468264671945e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5311), tensor(888.5311)) \t diffs: 4.3890696588277933e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5305), tensor(888.5304)) \t diffs: 4.87373824853421e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5308), tensor(888.5308)) \t diffs: 5.209108735471091e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5298), tensor(888.5298)) \t diffs: 5.463080583467672e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5311), tensor(888.5311)) \t diffs: 5.576325179390551e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5295), tensor(888.5294)) \t diffs: 5.659875341734733e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5307), tensor(888.5306)) \t diffs: 6.107513286224275e-07\n"
+     ]
+    }
+   ],
+   "source": [
+    "# RobertaModel embeds\n",
+    "\n",
+    "my_out, hf_out = my_output_embeds[0], hf_output_embeds[0]\n",
+    "\n",
+    "print(f\"model_out: {check(my_out.array, hf_out.detach())[3]}\")\n",
+    "\n",
+    "my_pool, hf_pool = my_output_embeds[1], hf_output_embeds[1]\n",
+    "\n",
+    "print(f\"pool_out: {check(my_pool.array, hf_pool.detach())[3]}\")\n",
+    "\n",
+    "print(\"intermediates:\")\n",
+    "my_ints, hf_ints = my_output_embeds[2], hf_output_embeds[2][1:]\n",
+    "\n",
+    "for i,j in zip(my_ints, hf_ints):\n",
+    "    print(check(i.array,j.detach())[3])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def test_RobertaForMaskedLM_Output(key_run, ids = False):\n",
+    "    if ids:\n",
+    "        my_output = my_mlm(input_ids = input_ids, attention_mask=mask, key=key_run)\n",
+    "        hf_output = hf_mlm(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False, output_hidden_states=True)\n",
+    "    else:\n",
+    "        my_output = my_mlm(input_embeds = input_embeds, attention_mask=mask, key=key_run)\n",
+    "        hf_output = hf_mlm(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False, output_hidden_states=True)\n",
+    "\n",
+    "    return my_output, hf_output"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "my_mlm_output_ids, hf_mlm_output_ids = test_RobertaForMaskedLM_Output(key_run, ids=True)\n",
+    "my_mlm_output_embeds, hf_mlm_output_embeds = test_RobertaForMaskedLM_Output(key_run, ids=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "mlm_out: acc: 1.0 \t norms: (tensor(7054.6812), tensor(7054.6816)) \t diffs: 7.966510224832746e-07\n",
+      "intermediates:\n",
+      "acc: 1.0 \t norms: (tensor(888.5315), tensor(888.5314)) \t diffs: 1.6926360046909394e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5301), tensor(888.5301)) \t diffs: 2.57225963196106e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5310), tensor(888.5311)) \t diffs: 3.1373855335914413e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5312), tensor(888.5312)) \t diffs: 3.6955742643840495e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5304), tensor(888.5304)) \t diffs: 4.1312114262836985e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5308), tensor(888.5308)) \t diffs: 4.612424220340472e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5308), tensor(888.5308)) \t diffs: 5.164657750356128e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5295), tensor(888.5295)) \t diffs: 5.412561563389318e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5309), tensor(888.5310)) \t diffs: 5.724053266931151e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5297), tensor(888.5297)) \t diffs: 5.980200512567535e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5317), tensor(888.5317)) \t diffs: 6.434746637751232e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5331), tensor(888.5331)) \t diffs: 6.946668236196274e-07\n"
+     ]
+    }
+   ],
+   "source": [
+    "#Masked MLM ids\n",
+    "my_out, hf_out = my_mlm_output_ids[0], hf_mlm_output_ids[0]\n",
+    "\n",
+    "print(f\"mlm_out: {check(my_out.array, hf_out.detach())[3]}\")\n",
+    "\n",
+    "print(\"intermediates:\")\n",
+    "my_ints, hf_ints = my_mlm_output_ids[1], hf_mlm_output_ids[1][1:]\n",
+    "\n",
+    "for i,j in zip(my_ints, hf_ints):\n",
+    "    print(check(i.array,j.detach(), precision = 0.01)[3])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "mlm_out: acc: 1.0 \t norms: (tensor(7107.9902), tensor(7107.9902)) \t diffs: 7.997662692105223e-07\n",
+      "intermediates:\n",
+      "acc: 1.0 \t norms: (tensor(888.5317), tensor(888.5317)) \t diffs: 1.3876427829018212e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5305), tensor(888.5305)) \t diffs: 2.239140144411067e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5290), tensor(888.5290)) \t diffs: 2.9642876597790746e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5301), tensor(888.5300)) \t diffs: 3.554245893155894e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5311), tensor(888.5311)) \t diffs: 4.070468264671945e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5311), tensor(888.5311)) \t diffs: 4.3890696588277933e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5305), tensor(888.5304)) \t diffs: 4.87373824853421e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5308), tensor(888.5308)) \t diffs: 5.209108735471091e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5298), tensor(888.5298)) \t diffs: 5.463080583467672e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5311), tensor(888.5311)) \t diffs: 5.576325179390551e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5295), tensor(888.5294)) \t diffs: 5.659875341734733e-07\n",
+      "acc: 1.0 \t norms: (tensor(888.5307), tensor(888.5306)) \t diffs: 6.107513286224275e-07\n"
+     ]
+    }
+   ],
+   "source": [
+    "#Masked MLM embeds\n",
+    "my_out, hf_out = my_mlm_output_embeds[0], hf_mlm_output_embeds[0]\n",
+    "\n",
+    "print(f\"mlm_out: {check(my_out.array, hf_out.detach())[3]}\")\n",
+    "\n",
+    "print(\"intermediates:\")\n",
+    "my_ints, hf_ints = my_mlm_output_embeds[1], hf_mlm_output_embeds[1][1:]\n",
+    "\n",
+    "for i,j in zip(my_ints, hf_ints):\n",
+    "    print(check(i.array,j.detach())[3])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# # Testing RobertaModel\n",
+    "# my_model = my_roberta.RobertaModel.init(Vocab, my_config, add_pooling_layer=False, key=key_rob)\n",
+    "# state_model = my_model.to_state_dict()\n",
+    "\n",
+    "# state_model = {k: torch.from_numpy(np.array(v)) for k, v in state_model.items()}\n",
+    "\n",
+    "# hf_model = hf_roberta.RobertaModel(hf_config, add_pooling_layer=False)\n",
+    "# hf_model.load_state_dict(state_model, strict=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "dict_keys(['dense.weight', 'dense.bias', 'layer_norm.weight', 'layer_norm.bias', 'decoder.weight', 'decoder.bias', 'bias'])\n",
+      "odict_keys(['bias', 'dense.weight', 'dense.bias', 'layer_norm.weight', 'layer_norm.bias', 'decoder.weight', 'decoder.bias'])\n",
+      "success1\n",
+      "success2\n",
+      "0.0\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Testing RobertaLMHead\n",
+    "my_head = my_roberta.RobertaLMHead.init(Vocab, my_config, key=key_lm)\n",
+    "state_head = my_head.to_state_dict()\n",
+    "\n",
+    "state_head = {k: torch.from_numpy(np.array(v)) for k, v in state_head.items()}\n",
+    "state_head[\"bias\"] = torch.zeros(hf_config.vocab_size)\n",
+    "\n",
+    "hf_head = hf_roberta.RobertaLMHead(hf_config)\n",
+    "hf_head.load_state_dict(state_head, strict=True)\n",
+    "\n",
+    "print(check_dicts(state_head, hf_head.state_dict()))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "my_output_mlm, hf_output_mlm = test_RobertaForMaskedLM_Output(key_run, ids = False)\n",
+    "\n",
+    "my_output_model, hf_output_model = test_RobertaModel_Output(key_model_run, ids = False)\n",
+    "my_output = my_head(my_output_model[0], key=key_lm_run)\n",
+    "hf_output = hf_head(hf_output_model[0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 28,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# k_rob, k_lm = jrandom.split(key, 2)\n",
+    "\n",
+    "# # MLM\n",
+    "# my_output_mlm = my_mlm(input_embeds = input_embeds, attention_mask=mask, key=key)\n",
+    "# hf_output_mlm = hf_mlm(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n",
+    "\n",
+    "# # Model + LM\n",
+    "\n",
+    "# my_output_model = my_model(input_embeds = input_embeds, attention_mask=mask, key=k_rob)\n",
+    "# my_output = my_head(my_output_model[0], key=k_lm)\n",
+    "\n",
+    "# hf_output_model = hf_model(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)\n",
+    "# hf_output = hf_head(hf_output_model[0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "RobertaModel: acc: 1.0 \t norms: (tensor(888.5307), tensor(888.5306)) \t diffs: 6.107513286224275e-07\n",
+      "Roberta Model + LM head: acc: 1.0 \t norms: (tensor(7107.9902), tensor(7107.9902)) \t diffs: 7.997662692105223e-07\n",
+      "MLM: acc: 1.0 \t norms: (tensor(7107.9902), tensor(7107.9902)) \t diffs: 7.997662692105223e-07\n",
+      "my RobertaModel + LM head vs MLM: acc: 1.0 \t norms: (tensor(7107.9902), tensor(7107.9902)) \t diffs: 0.0\n",
+      "hf RobertaModel + LM head vs MLM: acc: 1.0 \t norms: (tensor(7107.9902), tensor(7107.9902)) \t diffs: 0.0\n"
+     ]
+    }
+   ],
+   "source": [
+    "# embeds\n",
+    "print(f\"RobertaModel: {check(my_output_model[0].array, hf_output_model[0].detach())[3]}\")\n",
+    "print(f\"Roberta Model + LM head: {check(my_output.array, hf_output.detach())[3]}\")\n",
+    "print(f\"MLM: {check(my_output_mlm[0].array, hf_output_mlm[0].detach())[3]}\")\n",
+    "\n",
+    "print(f\"my RobertaModel + LM head vs MLM: {check(my_output.array, my_output_mlm[0].array)[3]}\")\n",
+    "print(f\"hf RobertaModel + LM head vs MLM: {check(hf_output.detach(), hf_output_mlm[0].detach())[3]}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "my_output_mlm, hf_output_mlm = test_RobertaForMaskedLM_Output(key_run, ids = True)\n",
+    "\n",
+    "my_output_model, hf_output_model = test_RobertaModel_Output(key_model_run, ids = True)\n",
+    "my_output = my_head(my_output_model[0], key=key_lm_run)\n",
+    "hf_output = hf_head(hf_output_model[0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# # MLM\n",
+    "# my_output_mlm = my_mlm(input_ids = input_ids, attention_mask=mask, key=key)\n",
+    "# hf_output_mlm = hf_mlm(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n",
+    "\n",
+    "# # Model + LM\n",
+    "# my_output_model = my_model(input_ids = input_ids, attention_mask=mask, key=k_rob)\n",
+    "# my_output = my_head(my_output_model[0], key=k_lm)\n",
+    "\n",
+    "# hf_output_model = hf_model(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)\n",
+    "# hf_output = hf_head(hf_output_model[0])\n",
+    "\n",
+    "# # Notes\n",
+    "# # embeds works between hf and my, and within model->head and mlm \n",
+    "# # ids does not work between hf and my for mlm or within hf for model->head and mlm - so hf mlm is doing something weird.'''"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "RobertaModel: acc: 1.0 \t norms: (tensor(888.5331), tensor(888.5331)) \t diffs: 6.946668236196274e-07\n",
+      "Roberta Model + LM head: acc: 1.0 \t norms: (tensor(7054.6812), tensor(7054.6816)) \t diffs: 7.966510224832746e-07\n",
+      "MLM: acc: 1.0 \t norms: (tensor(7054.6812), tensor(7054.6816)) \t diffs: 7.966510224832746e-07\n",
+      "my RobertaModel + LM head vs MLM: acc: 1.0 \t norms: (tensor(7054.6812), tensor(7054.6812)) \t diffs: 0.0\n",
+      "hf RobertaModel + LM head vs MLM: acc: 1.0 \t norms: (tensor(7054.6816), tensor(7054.6816)) \t diffs: 0.0\n"
+     ]
+    }
+   ],
+   "source": [
+    "# ids\n",
+    "print(f\"RobertaModel: {check(my_output_model[0].array, hf_output_model[0].detach())[3]}\")\n",
+    "print(f\"Roberta Model + LM head: {check(my_output.array, hf_output.detach())[3]}\")\n",
+    "print(f\"MLM: {check(my_output_mlm[0].array, hf_output_mlm[0].detach())[3]}\")\n",
+    "\n",
+    "print(f\"my RobertaModel + LM head vs MLM: {check(my_output.array, my_output_mlm[0].array)[3]}\")\n",
+    "print(f\"hf RobertaModel + LM head vs MLM: {check(hf_output.detach(), hf_output_mlm[0].detach())[3]}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 33,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "NameError",
+     "evalue": "name 'stop' is not defined",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)",
+      "Cell \u001b[1;32mIn[33], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[43mstop\u001b[49m\n",
+      "\u001b[1;31mNameError\u001b[0m: name 'stop' is not defined"
+     ]
+    }
+   ],
+   "source": [
+    "stop"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 38,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Load pretrained weights from hf\n",
+    "hf_model = hf_roberta.RobertaModel.from_pretrained(\"roberta-base\")\n",
+    "state_model = hf_model.state_dict()\n",
+    "\n",
+    "state_model = {k: np.array(v) for k, v in state_model.items()}\n",
+    "\n",
+    "hf_config = hf_model.config\n",
+    "hf_config.hidden_dropout_prob = 0\n",
+    "hf_config.attention_probs_dropout_prob = 0\n",
+    "my_config = my_roberta.RobertaConfig.from_hf_config(hf_config)\n",
+    "\n",
+    "my_model = my_roberta.RobertaModel.init(Vocab, my_config, output_hidden_states=True, key=key)\n",
+    "my_model = my_model.from_state_dict(state_model)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 39,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "dict_keys(['encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.1.attention.self.query.bias', 'encoder.layer.1.attention.self.key.weight', 'encoder.layer.1.attention.self.key.bias', 'encoder.layer.1.attention.self.value.weight', 'encoder.layer.1.attention.self.value.bias', 'encoder.layer.1.attention.output.dense.weight', 'encoder.layer.1.attention.output.dense.bias', 'encoder.layer.1.attention.output.LayerNorm.weight', 'encoder.layer.1.attention.output.LayerNorm.bias', 'encoder.layer.1.intermediate.dense.weight', 'encoder.layer.1.intermediate.dense.bias', 'encoder.layer.1.output.dense.weight', 'encoder.layer.1.output.dense.bias', 'encoder.layer.1.output.LayerNorm.weight', 'encoder.layer.1.output.LayerNorm.bias', 'encoder.layer.2.attention.self.query.weight', 'encoder.layer.2.attention.self.query.bias', 'encoder.layer.2.attention.self.key.weight', 'encoder.layer.2.attention.self.key.bias', 'encoder.layer.2.attention.self.value.weight', 'encoder.layer.2.attention.self.value.bias', 'encoder.layer.2.attention.output.dense.weight', 'encoder.layer.2.attention.output.dense.bias', 'encoder.layer.2.attention.output.LayerNorm.weight', 'encoder.layer.2.attention.output.LayerNorm.bias', 'encoder.layer.2.intermediate.dense.weight', 'encoder.layer.2.intermediate.dense.bias', 'encoder.layer.2.output.dense.weight', 'encoder.layer.2.output.dense.bias', 'encoder.layer.2.output.LayerNorm.weight', 'encoder.layer.2.output.LayerNorm.bias', 'encoder.layer.3.attention.self.query.weight', 'encoder.layer.3.attention.self.query.bias', 'encoder.layer.3.attention.self.key.weight', 'encoder.layer.3.attention.self.key.bias', 'encoder.layer.3.attention.self.value.weight', 'encoder.layer.3.attention.self.value.bias', 'encoder.layer.3.attention.output.dense.weight', 'encoder.layer.3.attention.output.dense.bias', 'encoder.layer.3.attention.output.LayerNorm.weight', 'encoder.layer.3.attention.output.LayerNorm.bias', 'encoder.layer.3.intermediate.dense.weight', 'encoder.layer.3.intermediate.dense.bias', 'encoder.layer.3.output.dense.weight', 'encoder.layer.3.output.dense.bias', 'encoder.layer.3.output.LayerNorm.weight', 'encoder.layer.3.output.LayerNorm.bias', 'encoder.layer.4.attention.self.query.weight', 'encoder.layer.4.attention.self.query.bias', 'encoder.layer.4.attention.self.key.weight', 'encoder.layer.4.attention.self.key.bias', 'encoder.layer.4.attention.self.value.weight', 'encoder.layer.4.attention.self.value.bias', 'encoder.layer.4.attention.output.dense.weight', 'encoder.layer.4.attention.output.dense.bias', 'encoder.layer.4.attention.output.LayerNorm.weight', 'encoder.layer.4.attention.output.LayerNorm.bias', 'encoder.layer.4.intermediate.dense.weight', 'encoder.layer.4.intermediate.dense.bias', 'encoder.layer.4.output.dense.weight', 'encoder.layer.4.output.dense.bias', 
'encoder.layer.4.output.LayerNorm.weight', 'encoder.layer.4.output.LayerNorm.bias', 'encoder.layer.5.attention.self.query.weight', 'encoder.layer.5.attention.self.query.bias', 'encoder.layer.5.attention.self.key.weight', 'encoder.layer.5.attention.self.key.bias', 'encoder.layer.5.attention.self.value.weight', 'encoder.layer.5.attention.self.value.bias', 'encoder.layer.5.attention.output.dense.weight', 'encoder.layer.5.attention.output.dense.bias', 'encoder.layer.5.attention.output.LayerNorm.weight', 'encoder.layer.5.attention.output.LayerNorm.bias', 'encoder.layer.5.intermediate.dense.weight', 'encoder.layer.5.intermediate.dense.bias', 'encoder.layer.5.output.dense.weight', 'encoder.layer.5.output.dense.bias', 'encoder.layer.5.output.LayerNorm.weight', 'encoder.layer.5.output.LayerNorm.bias', 'encoder.layer.6.attention.self.query.weight', 'encoder.layer.6.attention.self.query.bias', 'encoder.layer.6.attention.self.key.weight', 'encoder.layer.6.attention.self.key.bias', 'encoder.layer.6.attention.self.value.weight', 'encoder.layer.6.attention.self.value.bias', 'encoder.layer.6.attention.output.dense.weight', 'encoder.layer.6.attention.output.dense.bias', 'encoder.layer.6.attention.output.LayerNorm.weight', 'encoder.layer.6.attention.output.LayerNorm.bias', 'encoder.layer.6.intermediate.dense.weight', 'encoder.layer.6.intermediate.dense.bias', 'encoder.layer.6.output.dense.weight', 'encoder.layer.6.output.dense.bias', 'encoder.layer.6.output.LayerNorm.weight', 'encoder.layer.6.output.LayerNorm.bias', 'encoder.layer.7.attention.self.query.weight', 'encoder.layer.7.attention.self.query.bias', 'encoder.layer.7.attention.self.key.weight', 'encoder.layer.7.attention.self.key.bias', 'encoder.layer.7.attention.self.value.weight', 'encoder.layer.7.attention.self.value.bias', 'encoder.layer.7.attention.output.dense.weight', 'encoder.layer.7.attention.output.dense.bias', 'encoder.layer.7.attention.output.LayerNorm.weight', 'encoder.layer.7.attention.output.LayerNorm.bias', 'encoder.layer.7.intermediate.dense.weight', 'encoder.layer.7.intermediate.dense.bias', 'encoder.layer.7.output.dense.weight', 'encoder.layer.7.output.dense.bias', 'encoder.layer.7.output.LayerNorm.weight', 'encoder.layer.7.output.LayerNorm.bias', 'encoder.layer.8.attention.self.query.weight', 'encoder.layer.8.attention.self.query.bias', 'encoder.layer.8.attention.self.key.weight', 'encoder.layer.8.attention.self.key.bias', 'encoder.layer.8.attention.self.value.weight', 'encoder.layer.8.attention.self.value.bias', 'encoder.layer.8.attention.output.dense.weight', 'encoder.layer.8.attention.output.dense.bias', 'encoder.layer.8.attention.output.LayerNorm.weight', 'encoder.layer.8.attention.output.LayerNorm.bias', 'encoder.layer.8.intermediate.dense.weight', 'encoder.layer.8.intermediate.dense.bias', 'encoder.layer.8.output.dense.weight', 'encoder.layer.8.output.dense.bias', 'encoder.layer.8.output.LayerNorm.weight', 'encoder.layer.8.output.LayerNorm.bias', 'encoder.layer.9.attention.self.query.weight', 'encoder.layer.9.attention.self.query.bias', 'encoder.layer.9.attention.self.key.weight', 'encoder.layer.9.attention.self.key.bias', 'encoder.layer.9.attention.self.value.weight', 'encoder.layer.9.attention.self.value.bias', 'encoder.layer.9.attention.output.dense.weight', 'encoder.layer.9.attention.output.dense.bias', 'encoder.layer.9.attention.output.LayerNorm.weight', 'encoder.layer.9.attention.output.LayerNorm.bias', 'encoder.layer.9.intermediate.dense.weight', 'encoder.layer.9.intermediate.dense.bias', 
'encoder.layer.9.output.dense.weight', 'encoder.layer.9.output.dense.bias', 'encoder.layer.9.output.LayerNorm.weight', 'encoder.layer.9.output.LayerNorm.bias', 'encoder.layer.10.attention.self.query.weight', 'encoder.layer.10.attention.self.query.bias', 'encoder.layer.10.attention.self.key.weight', 'encoder.layer.10.attention.self.key.bias', 'encoder.layer.10.attention.self.value.weight', 'encoder.layer.10.attention.self.value.bias', 'encoder.layer.10.attention.output.dense.weight', 'encoder.layer.10.attention.output.dense.bias', 'encoder.layer.10.attention.output.LayerNorm.weight', 'encoder.layer.10.attention.output.LayerNorm.bias', 'encoder.layer.10.intermediate.dense.weight', 'encoder.layer.10.intermediate.dense.bias', 'encoder.layer.10.output.dense.weight', 'encoder.layer.10.output.dense.bias', 'encoder.layer.10.output.LayerNorm.weight', 'encoder.layer.10.output.LayerNorm.bias', 'encoder.layer.11.attention.self.query.weight', 'encoder.layer.11.attention.self.query.bias', 'encoder.layer.11.attention.self.key.weight', 'encoder.layer.11.attention.self.key.bias', 'encoder.layer.11.attention.self.value.weight', 'encoder.layer.11.attention.self.value.bias', 'encoder.layer.11.attention.output.dense.weight', 'encoder.layer.11.attention.output.dense.bias', 'encoder.layer.11.attention.output.LayerNorm.weight', 'encoder.layer.11.attention.output.LayerNorm.bias', 'encoder.layer.11.intermediate.dense.weight', 'encoder.layer.11.intermediate.dense.bias', 'encoder.layer.11.output.dense.weight', 'encoder.layer.11.output.dense.bias', 'encoder.layer.11.output.LayerNorm.weight', 'encoder.layer.11.output.LayerNorm.bias', 'embeddings.word_embeddings.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.LayerNorm.weight', 'embeddings.LayerNorm.bias', 'pooler.dense.weight', 'pooler.dense.bias'])\n",
+      "odict_keys(['embeddings.word_embeddings.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.LayerNorm.weight', 'embeddings.LayerNorm.bias', 'encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.1.attention.self.query.bias', 'encoder.layer.1.attention.self.key.weight', 'encoder.layer.1.attention.self.key.bias', 'encoder.layer.1.attention.self.value.weight', 'encoder.layer.1.attention.self.value.bias', 'encoder.layer.1.attention.output.dense.weight', 'encoder.layer.1.attention.output.dense.bias', 'encoder.layer.1.attention.output.LayerNorm.weight', 'encoder.layer.1.attention.output.LayerNorm.bias', 'encoder.layer.1.intermediate.dense.weight', 'encoder.layer.1.intermediate.dense.bias', 'encoder.layer.1.output.dense.weight', 'encoder.layer.1.output.dense.bias', 'encoder.layer.1.output.LayerNorm.weight', 'encoder.layer.1.output.LayerNorm.bias', 'encoder.layer.2.attention.self.query.weight', 'encoder.layer.2.attention.self.query.bias', 'encoder.layer.2.attention.self.key.weight', 'encoder.layer.2.attention.self.key.bias', 'encoder.layer.2.attention.self.value.weight', 'encoder.layer.2.attention.self.value.bias', 'encoder.layer.2.attention.output.dense.weight', 'encoder.layer.2.attention.output.dense.bias', 'encoder.layer.2.attention.output.LayerNorm.weight', 'encoder.layer.2.attention.output.LayerNorm.bias', 'encoder.layer.2.intermediate.dense.weight', 'encoder.layer.2.intermediate.dense.bias', 'encoder.layer.2.output.dense.weight', 'encoder.layer.2.output.dense.bias', 'encoder.layer.2.output.LayerNorm.weight', 'encoder.layer.2.output.LayerNorm.bias', 'encoder.layer.3.attention.self.query.weight', 'encoder.layer.3.attention.self.query.bias', 'encoder.layer.3.attention.self.key.weight', 'encoder.layer.3.attention.self.key.bias', 'encoder.layer.3.attention.self.value.weight', 'encoder.layer.3.attention.self.value.bias', 'encoder.layer.3.attention.output.dense.weight', 'encoder.layer.3.attention.output.dense.bias', 'encoder.layer.3.attention.output.LayerNorm.weight', 'encoder.layer.3.attention.output.LayerNorm.bias', 'encoder.layer.3.intermediate.dense.weight', 'encoder.layer.3.intermediate.dense.bias', 'encoder.layer.3.output.dense.weight', 'encoder.layer.3.output.dense.bias', 'encoder.layer.3.output.LayerNorm.weight', 'encoder.layer.3.output.LayerNorm.bias', 'encoder.layer.4.attention.self.query.weight', 'encoder.layer.4.attention.self.query.bias', 'encoder.layer.4.attention.self.key.weight', 'encoder.layer.4.attention.self.key.bias', 'encoder.layer.4.attention.self.value.weight', 'encoder.layer.4.attention.self.value.bias', 'encoder.layer.4.attention.output.dense.weight', 'encoder.layer.4.attention.output.dense.bias', 'encoder.layer.4.attention.output.LayerNorm.weight', 'encoder.layer.4.attention.output.LayerNorm.bias', 
'encoder.layer.4.intermediate.dense.weight', 'encoder.layer.4.intermediate.dense.bias', 'encoder.layer.4.output.dense.weight', 'encoder.layer.4.output.dense.bias', 'encoder.layer.4.output.LayerNorm.weight', 'encoder.layer.4.output.LayerNorm.bias', 'encoder.layer.5.attention.self.query.weight', 'encoder.layer.5.attention.self.query.bias', 'encoder.layer.5.attention.self.key.weight', 'encoder.layer.5.attention.self.key.bias', 'encoder.layer.5.attention.self.value.weight', 'encoder.layer.5.attention.self.value.bias', 'encoder.layer.5.attention.output.dense.weight', 'encoder.layer.5.attention.output.dense.bias', 'encoder.layer.5.attention.output.LayerNorm.weight', 'encoder.layer.5.attention.output.LayerNorm.bias', 'encoder.layer.5.intermediate.dense.weight', 'encoder.layer.5.intermediate.dense.bias', 'encoder.layer.5.output.dense.weight', 'encoder.layer.5.output.dense.bias', 'encoder.layer.5.output.LayerNorm.weight', 'encoder.layer.5.output.LayerNorm.bias', 'encoder.layer.6.attention.self.query.weight', 'encoder.layer.6.attention.self.query.bias', 'encoder.layer.6.attention.self.key.weight', 'encoder.layer.6.attention.self.key.bias', 'encoder.layer.6.attention.self.value.weight', 'encoder.layer.6.attention.self.value.bias', 'encoder.layer.6.attention.output.dense.weight', 'encoder.layer.6.attention.output.dense.bias', 'encoder.layer.6.attention.output.LayerNorm.weight', 'encoder.layer.6.attention.output.LayerNorm.bias', 'encoder.layer.6.intermediate.dense.weight', 'encoder.layer.6.intermediate.dense.bias', 'encoder.layer.6.output.dense.weight', 'encoder.layer.6.output.dense.bias', 'encoder.layer.6.output.LayerNorm.weight', 'encoder.layer.6.output.LayerNorm.bias', 'encoder.layer.7.attention.self.query.weight', 'encoder.layer.7.attention.self.query.bias', 'encoder.layer.7.attention.self.key.weight', 'encoder.layer.7.attention.self.key.bias', 'encoder.layer.7.attention.self.value.weight', 'encoder.layer.7.attention.self.value.bias', 'encoder.layer.7.attention.output.dense.weight', 'encoder.layer.7.attention.output.dense.bias', 'encoder.layer.7.attention.output.LayerNorm.weight', 'encoder.layer.7.attention.output.LayerNorm.bias', 'encoder.layer.7.intermediate.dense.weight', 'encoder.layer.7.intermediate.dense.bias', 'encoder.layer.7.output.dense.weight', 'encoder.layer.7.output.dense.bias', 'encoder.layer.7.output.LayerNorm.weight', 'encoder.layer.7.output.LayerNorm.bias', 'encoder.layer.8.attention.self.query.weight', 'encoder.layer.8.attention.self.query.bias', 'encoder.layer.8.attention.self.key.weight', 'encoder.layer.8.attention.self.key.bias', 'encoder.layer.8.attention.self.value.weight', 'encoder.layer.8.attention.self.value.bias', 'encoder.layer.8.attention.output.dense.weight', 'encoder.layer.8.attention.output.dense.bias', 'encoder.layer.8.attention.output.LayerNorm.weight', 'encoder.layer.8.attention.output.LayerNorm.bias', 'encoder.layer.8.intermediate.dense.weight', 'encoder.layer.8.intermediate.dense.bias', 'encoder.layer.8.output.dense.weight', 'encoder.layer.8.output.dense.bias', 'encoder.layer.8.output.LayerNorm.weight', 'encoder.layer.8.output.LayerNorm.bias', 'encoder.layer.9.attention.self.query.weight', 'encoder.layer.9.attention.self.query.bias', 'encoder.layer.9.attention.self.key.weight', 'encoder.layer.9.attention.self.key.bias', 'encoder.layer.9.attention.self.value.weight', 'encoder.layer.9.attention.self.value.bias', 'encoder.layer.9.attention.output.dense.weight', 'encoder.layer.9.attention.output.dense.bias', 'encoder.layer.9.attention.output.LayerNorm.weight', 
'encoder.layer.9.attention.output.LayerNorm.bias', 'encoder.layer.9.intermediate.dense.weight', 'encoder.layer.9.intermediate.dense.bias', 'encoder.layer.9.output.dense.weight', 'encoder.layer.9.output.dense.bias', 'encoder.layer.9.output.LayerNorm.weight', 'encoder.layer.9.output.LayerNorm.bias', 'encoder.layer.10.attention.self.query.weight', 'encoder.layer.10.attention.self.query.bias', 'encoder.layer.10.attention.self.key.weight', 'encoder.layer.10.attention.self.key.bias', 'encoder.layer.10.attention.self.value.weight', 'encoder.layer.10.attention.self.value.bias', 'encoder.layer.10.attention.output.dense.weight', 'encoder.layer.10.attention.output.dense.bias', 'encoder.layer.10.attention.output.LayerNorm.weight', 'encoder.layer.10.attention.output.LayerNorm.bias', 'encoder.layer.10.intermediate.dense.weight', 'encoder.layer.10.intermediate.dense.bias', 'encoder.layer.10.output.dense.weight', 'encoder.layer.10.output.dense.bias', 'encoder.layer.10.output.LayerNorm.weight', 'encoder.layer.10.output.LayerNorm.bias', 'encoder.layer.11.attention.self.query.weight', 'encoder.layer.11.attention.self.query.bias', 'encoder.layer.11.attention.self.key.weight', 'encoder.layer.11.attention.self.key.bias', 'encoder.layer.11.attention.self.value.weight', 'encoder.layer.11.attention.self.value.bias', 'encoder.layer.11.attention.output.dense.weight', 'encoder.layer.11.attention.output.dense.bias', 'encoder.layer.11.attention.output.LayerNorm.weight', 'encoder.layer.11.attention.output.LayerNorm.bias', 'encoder.layer.11.intermediate.dense.weight', 'encoder.layer.11.intermediate.dense.bias', 'encoder.layer.11.output.dense.weight', 'encoder.layer.11.output.dense.bias', 'encoder.layer.11.output.LayerNorm.weight', 'encoder.layer.11.output.LayerNorm.bias', 'pooler.dense.weight', 'pooler.dense.bias'])\n",
+      "success1\n",
+      "success2\n",
+      "Total differences: 0.0\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Check weights loaded correctly\n",
+    "my_dict = my_model.to_state_dict()\n",
+    "hf_dict = hf_model.state_dict()\n",
+    "\n",
+    "print(f\"Total differences: {check_dicts(my_dict, hf_dict)}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 40,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n"
+     ]
+    }
+   ],
+   "source": [
+    "my_output_model = my_model(input_embeds = input_embeds, attention_mask=mask, key=key)\n",
+    "hf_output_model = hf_model(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 41,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Model: acc: 1.0 \t norms: (tensor(443.3569), tensor(443.3564)) \t diffs: 1.405485477334878e-06\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(f\"Model: {check(my_output_model[0].array, hf_output_model[0].detach())[3]}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 42,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n"
+     ]
+    }
+   ],
+   "source": [
+    "my_output_model = my_model(input_ids = input_ids, attention_mask=mask, key=key)\n",
+    "hf_output_model = hf_model(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 43,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Model: acc: 1.0 \t norms: (tensor(431.9545), tensor(431.9545)) \t diffs: 1.3033360346526024e-06\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(f\"Model: {check(my_output_model[0].array, hf_output_model[0].detach())[3]}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 57,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "hf_mlm = hf_roberta.RobertaForMaskedLM.from_pretrained(hf_model_str)\n",
+    "state_mlm = hf_mlm.state_dict()\n",
+    "\n",
+    "state_mlm = {k: np.array(v) for k, v in state_mlm.items()}\n",
+    "\n",
+    "hf_config = hf_mlm.config\n",
+    "hf_config.hidden_dropout_prob = 0\n",
+    "hf_config.attention_probs_dropout_prob = 0\n",
+    "my_config = my_roberta.RobertaConfig.from_hf_config(hf_config)\n",
+    "\n",
+    "my_mlm = my_roberta.RobertaForMaskedLM.init(Vocab, my_config, key=key)\n",
+    "my_mlm = my_mlm.from_state_dict(state_mlm)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 58,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "dict_keys(['roberta.encoder.layer.0.attention.self.query.weight', 'roberta.encoder.layer.0.attention.self.query.bias', 'roberta.encoder.layer.0.attention.self.key.weight', 'roberta.encoder.layer.0.attention.self.key.bias', 'roberta.encoder.layer.0.attention.self.value.weight', 'roberta.encoder.layer.0.attention.self.value.bias', 'roberta.encoder.layer.0.attention.output.dense.weight', 'roberta.encoder.layer.0.attention.output.dense.bias', 'roberta.encoder.layer.0.attention.output.LayerNorm.weight', 'roberta.encoder.layer.0.attention.output.LayerNorm.bias', 'roberta.encoder.layer.0.intermediate.dense.weight', 'roberta.encoder.layer.0.intermediate.dense.bias', 'roberta.encoder.layer.0.output.dense.weight', 'roberta.encoder.layer.0.output.dense.bias', 'roberta.encoder.layer.0.output.LayerNorm.weight', 'roberta.encoder.layer.0.output.LayerNorm.bias', 'roberta.encoder.layer.1.attention.self.query.weight', 'roberta.encoder.layer.1.attention.self.query.bias', 'roberta.encoder.layer.1.attention.self.key.weight', 'roberta.encoder.layer.1.attention.self.key.bias', 'roberta.encoder.layer.1.attention.self.value.weight', 'roberta.encoder.layer.1.attention.self.value.bias', 'roberta.encoder.layer.1.attention.output.dense.weight', 'roberta.encoder.layer.1.attention.output.dense.bias', 'roberta.encoder.layer.1.attention.output.LayerNorm.weight', 'roberta.encoder.layer.1.attention.output.LayerNorm.bias', 'roberta.encoder.layer.1.intermediate.dense.weight', 'roberta.encoder.layer.1.intermediate.dense.bias', 'roberta.encoder.layer.1.output.dense.weight', 'roberta.encoder.layer.1.output.dense.bias', 'roberta.encoder.layer.1.output.LayerNorm.weight', 'roberta.encoder.layer.1.output.LayerNorm.bias', 'roberta.encoder.layer.2.attention.self.query.weight', 'roberta.encoder.layer.2.attention.self.query.bias', 'roberta.encoder.layer.2.attention.self.key.weight', 'roberta.encoder.layer.2.attention.self.key.bias', 'roberta.encoder.layer.2.attention.self.value.weight', 'roberta.encoder.layer.2.attention.self.value.bias', 'roberta.encoder.layer.2.attention.output.dense.weight', 'roberta.encoder.layer.2.attention.output.dense.bias', 'roberta.encoder.layer.2.attention.output.LayerNorm.weight', 'roberta.encoder.layer.2.attention.output.LayerNorm.bias', 'roberta.encoder.layer.2.intermediate.dense.weight', 'roberta.encoder.layer.2.intermediate.dense.bias', 'roberta.encoder.layer.2.output.dense.weight', 'roberta.encoder.layer.2.output.dense.bias', 'roberta.encoder.layer.2.output.LayerNorm.weight', 'roberta.encoder.layer.2.output.LayerNorm.bias', 'roberta.encoder.layer.3.attention.self.query.weight', 'roberta.encoder.layer.3.attention.self.query.bias', 'roberta.encoder.layer.3.attention.self.key.weight', 'roberta.encoder.layer.3.attention.self.key.bias', 'roberta.encoder.layer.3.attention.self.value.weight', 'roberta.encoder.layer.3.attention.self.value.bias', 'roberta.encoder.layer.3.attention.output.dense.weight', 'roberta.encoder.layer.3.attention.output.dense.bias', 'roberta.encoder.layer.3.attention.output.LayerNorm.weight', 'roberta.encoder.layer.3.attention.output.LayerNorm.bias', 'roberta.encoder.layer.3.intermediate.dense.weight', 'roberta.encoder.layer.3.intermediate.dense.bias', 'roberta.encoder.layer.3.output.dense.weight', 'roberta.encoder.layer.3.output.dense.bias', 'roberta.encoder.layer.3.output.LayerNorm.weight', 'roberta.encoder.layer.3.output.LayerNorm.bias', 'roberta.encoder.layer.4.attention.self.query.weight', 'roberta.encoder.layer.4.attention.self.query.bias', 
'roberta.encoder.layer.4.attention.self.key.weight', 'roberta.encoder.layer.4.attention.self.key.bias', 'roberta.encoder.layer.4.attention.self.value.weight', 'roberta.encoder.layer.4.attention.self.value.bias', 'roberta.encoder.layer.4.attention.output.dense.weight', 'roberta.encoder.layer.4.attention.output.dense.bias', 'roberta.encoder.layer.4.attention.output.LayerNorm.weight', 'roberta.encoder.layer.4.attention.output.LayerNorm.bias', 'roberta.encoder.layer.4.intermediate.dense.weight', 'roberta.encoder.layer.4.intermediate.dense.bias', 'roberta.encoder.layer.4.output.dense.weight', 'roberta.encoder.layer.4.output.dense.bias', 'roberta.encoder.layer.4.output.LayerNorm.weight', 'roberta.encoder.layer.4.output.LayerNorm.bias', 'roberta.encoder.layer.5.attention.self.query.weight', 'roberta.encoder.layer.5.attention.self.query.bias', 'roberta.encoder.layer.5.attention.self.key.weight', 'roberta.encoder.layer.5.attention.self.key.bias', 'roberta.encoder.layer.5.attention.self.value.weight', 'roberta.encoder.layer.5.attention.self.value.bias', 'roberta.encoder.layer.5.attention.output.dense.weight', 'roberta.encoder.layer.5.attention.output.dense.bias', 'roberta.encoder.layer.5.attention.output.LayerNorm.weight', 'roberta.encoder.layer.5.attention.output.LayerNorm.bias', 'roberta.encoder.layer.5.intermediate.dense.weight', 'roberta.encoder.layer.5.intermediate.dense.bias', 'roberta.encoder.layer.5.output.dense.weight', 'roberta.encoder.layer.5.output.dense.bias', 'roberta.encoder.layer.5.output.LayerNorm.weight', 'roberta.encoder.layer.5.output.LayerNorm.bias', 'roberta.encoder.layer.6.attention.self.query.weight', 'roberta.encoder.layer.6.attention.self.query.bias', 'roberta.encoder.layer.6.attention.self.key.weight', 'roberta.encoder.layer.6.attention.self.key.bias', 'roberta.encoder.layer.6.attention.self.value.weight', 'roberta.encoder.layer.6.attention.self.value.bias', 'roberta.encoder.layer.6.attention.output.dense.weight', 'roberta.encoder.layer.6.attention.output.dense.bias', 'roberta.encoder.layer.6.attention.output.LayerNorm.weight', 'roberta.encoder.layer.6.attention.output.LayerNorm.bias', 'roberta.encoder.layer.6.intermediate.dense.weight', 'roberta.encoder.layer.6.intermediate.dense.bias', 'roberta.encoder.layer.6.output.dense.weight', 'roberta.encoder.layer.6.output.dense.bias', 'roberta.encoder.layer.6.output.LayerNorm.weight', 'roberta.encoder.layer.6.output.LayerNorm.bias', 'roberta.encoder.layer.7.attention.self.query.weight', 'roberta.encoder.layer.7.attention.self.query.bias', 'roberta.encoder.layer.7.attention.self.key.weight', 'roberta.encoder.layer.7.attention.self.key.bias', 'roberta.encoder.layer.7.attention.self.value.weight', 'roberta.encoder.layer.7.attention.self.value.bias', 'roberta.encoder.layer.7.attention.output.dense.weight', 'roberta.encoder.layer.7.attention.output.dense.bias', 'roberta.encoder.layer.7.attention.output.LayerNorm.weight', 'roberta.encoder.layer.7.attention.output.LayerNorm.bias', 'roberta.encoder.layer.7.intermediate.dense.weight', 'roberta.encoder.layer.7.intermediate.dense.bias', 'roberta.encoder.layer.7.output.dense.weight', 'roberta.encoder.layer.7.output.dense.bias', 'roberta.encoder.layer.7.output.LayerNorm.weight', 'roberta.encoder.layer.7.output.LayerNorm.bias', 'roberta.encoder.layer.8.attention.self.query.weight', 'roberta.encoder.layer.8.attention.self.query.bias', 'roberta.encoder.layer.8.attention.self.key.weight', 'roberta.encoder.layer.8.attention.self.key.bias', 'roberta.encoder.layer.8.attention.self.value.weight', 
'roberta.encoder.layer.8.attention.self.value.bias', 'roberta.encoder.layer.8.attention.output.dense.weight', 'roberta.encoder.layer.8.attention.output.dense.bias', 'roberta.encoder.layer.8.attention.output.LayerNorm.weight', 'roberta.encoder.layer.8.attention.output.LayerNorm.bias', 'roberta.encoder.layer.8.intermediate.dense.weight', 'roberta.encoder.layer.8.intermediate.dense.bias', 'roberta.encoder.layer.8.output.dense.weight', 'roberta.encoder.layer.8.output.dense.bias', 'roberta.encoder.layer.8.output.LayerNorm.weight', 'roberta.encoder.layer.8.output.LayerNorm.bias', 'roberta.encoder.layer.9.attention.self.query.weight', 'roberta.encoder.layer.9.attention.self.query.bias', 'roberta.encoder.layer.9.attention.self.key.weight', 'roberta.encoder.layer.9.attention.self.key.bias', 'roberta.encoder.layer.9.attention.self.value.weight', 'roberta.encoder.layer.9.attention.self.value.bias', 'roberta.encoder.layer.9.attention.output.dense.weight', 'roberta.encoder.layer.9.attention.output.dense.bias', 'roberta.encoder.layer.9.attention.output.LayerNorm.weight', 'roberta.encoder.layer.9.attention.output.LayerNorm.bias', 'roberta.encoder.layer.9.intermediate.dense.weight', 'roberta.encoder.layer.9.intermediate.dense.bias', 'roberta.encoder.layer.9.output.dense.weight', 'roberta.encoder.layer.9.output.dense.bias', 'roberta.encoder.layer.9.output.LayerNorm.weight', 'roberta.encoder.layer.9.output.LayerNorm.bias', 'roberta.encoder.layer.10.attention.self.query.weight', 'roberta.encoder.layer.10.attention.self.query.bias', 'roberta.encoder.layer.10.attention.self.key.weight', 'roberta.encoder.layer.10.attention.self.key.bias', 'roberta.encoder.layer.10.attention.self.value.weight', 'roberta.encoder.layer.10.attention.self.value.bias', 'roberta.encoder.layer.10.attention.output.dense.weight', 'roberta.encoder.layer.10.attention.output.dense.bias', 'roberta.encoder.layer.10.attention.output.LayerNorm.weight', 'roberta.encoder.layer.10.attention.output.LayerNorm.bias', 'roberta.encoder.layer.10.intermediate.dense.weight', 'roberta.encoder.layer.10.intermediate.dense.bias', 'roberta.encoder.layer.10.output.dense.weight', 'roberta.encoder.layer.10.output.dense.bias', 'roberta.encoder.layer.10.output.LayerNorm.weight', 'roberta.encoder.layer.10.output.LayerNorm.bias', 'roberta.encoder.layer.11.attention.self.query.weight', 'roberta.encoder.layer.11.attention.self.query.bias', 'roberta.encoder.layer.11.attention.self.key.weight', 'roberta.encoder.layer.11.attention.self.key.bias', 'roberta.encoder.layer.11.attention.self.value.weight', 'roberta.encoder.layer.11.attention.self.value.bias', 'roberta.encoder.layer.11.attention.output.dense.weight', 'roberta.encoder.layer.11.attention.output.dense.bias', 'roberta.encoder.layer.11.attention.output.LayerNorm.weight', 'roberta.encoder.layer.11.attention.output.LayerNorm.bias', 'roberta.encoder.layer.11.intermediate.dense.weight', 'roberta.encoder.layer.11.intermediate.dense.bias', 'roberta.encoder.layer.11.output.dense.weight', 'roberta.encoder.layer.11.output.dense.bias', 'roberta.encoder.layer.11.output.LayerNorm.weight', 'roberta.encoder.layer.11.output.LayerNorm.bias', 'roberta.embeddings.word_embeddings.weight', 'roberta.embeddings.position_embeddings.weight', 'roberta.embeddings.token_type_embeddings.weight', 'roberta.embeddings.LayerNorm.weight', 'roberta.embeddings.LayerNorm.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.decoder.bias'])\n",
+      "odict_keys(['roberta.embeddings.word_embeddings.weight', 'roberta.embeddings.position_embeddings.weight', 'roberta.embeddings.token_type_embeddings.weight', 'roberta.embeddings.LayerNorm.weight', 'roberta.embeddings.LayerNorm.bias', 'roberta.encoder.layer.0.attention.self.query.weight', 'roberta.encoder.layer.0.attention.self.query.bias', 'roberta.encoder.layer.0.attention.self.key.weight', 'roberta.encoder.layer.0.attention.self.key.bias', 'roberta.encoder.layer.0.attention.self.value.weight', 'roberta.encoder.layer.0.attention.self.value.bias', 'roberta.encoder.layer.0.attention.output.dense.weight', 'roberta.encoder.layer.0.attention.output.dense.bias', 'roberta.encoder.layer.0.attention.output.LayerNorm.weight', 'roberta.encoder.layer.0.attention.output.LayerNorm.bias', 'roberta.encoder.layer.0.intermediate.dense.weight', 'roberta.encoder.layer.0.intermediate.dense.bias', 'roberta.encoder.layer.0.output.dense.weight', 'roberta.encoder.layer.0.output.dense.bias', 'roberta.encoder.layer.0.output.LayerNorm.weight', 'roberta.encoder.layer.0.output.LayerNorm.bias', 'roberta.encoder.layer.1.attention.self.query.weight', 'roberta.encoder.layer.1.attention.self.query.bias', 'roberta.encoder.layer.1.attention.self.key.weight', 'roberta.encoder.layer.1.attention.self.key.bias', 'roberta.encoder.layer.1.attention.self.value.weight', 'roberta.encoder.layer.1.attention.self.value.bias', 'roberta.encoder.layer.1.attention.output.dense.weight', 'roberta.encoder.layer.1.attention.output.dense.bias', 'roberta.encoder.layer.1.attention.output.LayerNorm.weight', 'roberta.encoder.layer.1.attention.output.LayerNorm.bias', 'roberta.encoder.layer.1.intermediate.dense.weight', 'roberta.encoder.layer.1.intermediate.dense.bias', 'roberta.encoder.layer.1.output.dense.weight', 'roberta.encoder.layer.1.output.dense.bias', 'roberta.encoder.layer.1.output.LayerNorm.weight', 'roberta.encoder.layer.1.output.LayerNorm.bias', 'roberta.encoder.layer.2.attention.self.query.weight', 'roberta.encoder.layer.2.attention.self.query.bias', 'roberta.encoder.layer.2.attention.self.key.weight', 'roberta.encoder.layer.2.attention.self.key.bias', 'roberta.encoder.layer.2.attention.self.value.weight', 'roberta.encoder.layer.2.attention.self.value.bias', 'roberta.encoder.layer.2.attention.output.dense.weight', 'roberta.encoder.layer.2.attention.output.dense.bias', 'roberta.encoder.layer.2.attention.output.LayerNorm.weight', 'roberta.encoder.layer.2.attention.output.LayerNorm.bias', 'roberta.encoder.layer.2.intermediate.dense.weight', 'roberta.encoder.layer.2.intermediate.dense.bias', 'roberta.encoder.layer.2.output.dense.weight', 'roberta.encoder.layer.2.output.dense.bias', 'roberta.encoder.layer.2.output.LayerNorm.weight', 'roberta.encoder.layer.2.output.LayerNorm.bias', 'roberta.encoder.layer.3.attention.self.query.weight', 'roberta.encoder.layer.3.attention.self.query.bias', 'roberta.encoder.layer.3.attention.self.key.weight', 'roberta.encoder.layer.3.attention.self.key.bias', 'roberta.encoder.layer.3.attention.self.value.weight', 'roberta.encoder.layer.3.attention.self.value.bias', 'roberta.encoder.layer.3.attention.output.dense.weight', 'roberta.encoder.layer.3.attention.output.dense.bias', 'roberta.encoder.layer.3.attention.output.LayerNorm.weight', 'roberta.encoder.layer.3.attention.output.LayerNorm.bias', 'roberta.encoder.layer.3.intermediate.dense.weight', 'roberta.encoder.layer.3.intermediate.dense.bias', 'roberta.encoder.layer.3.output.dense.weight', 'roberta.encoder.layer.3.output.dense.bias', 
'roberta.encoder.layer.3.output.LayerNorm.weight', 'roberta.encoder.layer.3.output.LayerNorm.bias', 'roberta.encoder.layer.4.attention.self.query.weight', 'roberta.encoder.layer.4.attention.self.query.bias', 'roberta.encoder.layer.4.attention.self.key.weight', 'roberta.encoder.layer.4.attention.self.key.bias', 'roberta.encoder.layer.4.attention.self.value.weight', 'roberta.encoder.layer.4.attention.self.value.bias', 'roberta.encoder.layer.4.attention.output.dense.weight', 'roberta.encoder.layer.4.attention.output.dense.bias', 'roberta.encoder.layer.4.attention.output.LayerNorm.weight', 'roberta.encoder.layer.4.attention.output.LayerNorm.bias', 'roberta.encoder.layer.4.intermediate.dense.weight', 'roberta.encoder.layer.4.intermediate.dense.bias', 'roberta.encoder.layer.4.output.dense.weight', 'roberta.encoder.layer.4.output.dense.bias', 'roberta.encoder.layer.4.output.LayerNorm.weight', 'roberta.encoder.layer.4.output.LayerNorm.bias', 'roberta.encoder.layer.5.attention.self.query.weight', 'roberta.encoder.layer.5.attention.self.query.bias', 'roberta.encoder.layer.5.attention.self.key.weight', 'roberta.encoder.layer.5.attention.self.key.bias', 'roberta.encoder.layer.5.attention.self.value.weight', 'roberta.encoder.layer.5.attention.self.value.bias', 'roberta.encoder.layer.5.attention.output.dense.weight', 'roberta.encoder.layer.5.attention.output.dense.bias', 'roberta.encoder.layer.5.attention.output.LayerNorm.weight', 'roberta.encoder.layer.5.attention.output.LayerNorm.bias', 'roberta.encoder.layer.5.intermediate.dense.weight', 'roberta.encoder.layer.5.intermediate.dense.bias', 'roberta.encoder.layer.5.output.dense.weight', 'roberta.encoder.layer.5.output.dense.bias', 'roberta.encoder.layer.5.output.LayerNorm.weight', 'roberta.encoder.layer.5.output.LayerNorm.bias', 'roberta.encoder.layer.6.attention.self.query.weight', 'roberta.encoder.layer.6.attention.self.query.bias', 'roberta.encoder.layer.6.attention.self.key.weight', 'roberta.encoder.layer.6.attention.self.key.bias', 'roberta.encoder.layer.6.attention.self.value.weight', 'roberta.encoder.layer.6.attention.self.value.bias', 'roberta.encoder.layer.6.attention.output.dense.weight', 'roberta.encoder.layer.6.attention.output.dense.bias', 'roberta.encoder.layer.6.attention.output.LayerNorm.weight', 'roberta.encoder.layer.6.attention.output.LayerNorm.bias', 'roberta.encoder.layer.6.intermediate.dense.weight', 'roberta.encoder.layer.6.intermediate.dense.bias', 'roberta.encoder.layer.6.output.dense.weight', 'roberta.encoder.layer.6.output.dense.bias', 'roberta.encoder.layer.6.output.LayerNorm.weight', 'roberta.encoder.layer.6.output.LayerNorm.bias', 'roberta.encoder.layer.7.attention.self.query.weight', 'roberta.encoder.layer.7.attention.self.query.bias', 'roberta.encoder.layer.7.attention.self.key.weight', 'roberta.encoder.layer.7.attention.self.key.bias', 'roberta.encoder.layer.7.attention.self.value.weight', 'roberta.encoder.layer.7.attention.self.value.bias', 'roberta.encoder.layer.7.attention.output.dense.weight', 'roberta.encoder.layer.7.attention.output.dense.bias', 'roberta.encoder.layer.7.attention.output.LayerNorm.weight', 'roberta.encoder.layer.7.attention.output.LayerNorm.bias', 'roberta.encoder.layer.7.intermediate.dense.weight', 'roberta.encoder.layer.7.intermediate.dense.bias', 'roberta.encoder.layer.7.output.dense.weight', 'roberta.encoder.layer.7.output.dense.bias', 'roberta.encoder.layer.7.output.LayerNorm.weight', 'roberta.encoder.layer.7.output.LayerNorm.bias', 'roberta.encoder.layer.8.attention.self.query.weight', 
'roberta.encoder.layer.8.attention.self.query.bias', 'roberta.encoder.layer.8.attention.self.key.weight', 'roberta.encoder.layer.8.attention.self.key.bias', 'roberta.encoder.layer.8.attention.self.value.weight', 'roberta.encoder.layer.8.attention.self.value.bias', 'roberta.encoder.layer.8.attention.output.dense.weight', 'roberta.encoder.layer.8.attention.output.dense.bias', 'roberta.encoder.layer.8.attention.output.LayerNorm.weight', 'roberta.encoder.layer.8.attention.output.LayerNorm.bias', 'roberta.encoder.layer.8.intermediate.dense.weight', 'roberta.encoder.layer.8.intermediate.dense.bias', 'roberta.encoder.layer.8.output.dense.weight', 'roberta.encoder.layer.8.output.dense.bias', 'roberta.encoder.layer.8.output.LayerNorm.weight', 'roberta.encoder.layer.8.output.LayerNorm.bias', 'roberta.encoder.layer.9.attention.self.query.weight', 'roberta.encoder.layer.9.attention.self.query.bias', 'roberta.encoder.layer.9.attention.self.key.weight', 'roberta.encoder.layer.9.attention.self.key.bias', 'roberta.encoder.layer.9.attention.self.value.weight', 'roberta.encoder.layer.9.attention.self.value.bias', 'roberta.encoder.layer.9.attention.output.dense.weight', 'roberta.encoder.layer.9.attention.output.dense.bias', 'roberta.encoder.layer.9.attention.output.LayerNorm.weight', 'roberta.encoder.layer.9.attention.output.LayerNorm.bias', 'roberta.encoder.layer.9.intermediate.dense.weight', 'roberta.encoder.layer.9.intermediate.dense.bias', 'roberta.encoder.layer.9.output.dense.weight', 'roberta.encoder.layer.9.output.dense.bias', 'roberta.encoder.layer.9.output.LayerNorm.weight', 'roberta.encoder.layer.9.output.LayerNorm.bias', 'roberta.encoder.layer.10.attention.self.query.weight', 'roberta.encoder.layer.10.attention.self.query.bias', 'roberta.encoder.layer.10.attention.self.key.weight', 'roberta.encoder.layer.10.attention.self.key.bias', 'roberta.encoder.layer.10.attention.self.value.weight', 'roberta.encoder.layer.10.attention.self.value.bias', 'roberta.encoder.layer.10.attention.output.dense.weight', 'roberta.encoder.layer.10.attention.output.dense.bias', 'roberta.encoder.layer.10.attention.output.LayerNorm.weight', 'roberta.encoder.layer.10.attention.output.LayerNorm.bias', 'roberta.encoder.layer.10.intermediate.dense.weight', 'roberta.encoder.layer.10.intermediate.dense.bias', 'roberta.encoder.layer.10.output.dense.weight', 'roberta.encoder.layer.10.output.dense.bias', 'roberta.encoder.layer.10.output.LayerNorm.weight', 'roberta.encoder.layer.10.output.LayerNorm.bias', 'roberta.encoder.layer.11.attention.self.query.weight', 'roberta.encoder.layer.11.attention.self.query.bias', 'roberta.encoder.layer.11.attention.self.key.weight', 'roberta.encoder.layer.11.attention.self.key.bias', 'roberta.encoder.layer.11.attention.self.value.weight', 'roberta.encoder.layer.11.attention.self.value.bias', 'roberta.encoder.layer.11.attention.output.dense.weight', 'roberta.encoder.layer.11.attention.output.dense.bias', 'roberta.encoder.layer.11.attention.output.LayerNorm.weight', 'roberta.encoder.layer.11.attention.output.LayerNorm.bias', 'roberta.encoder.layer.11.intermediate.dense.weight', 'roberta.encoder.layer.11.intermediate.dense.bias', 'roberta.encoder.layer.11.output.dense.weight', 'roberta.encoder.layer.11.output.dense.bias', 'roberta.encoder.layer.11.output.LayerNorm.weight', 'roberta.encoder.layer.11.output.LayerNorm.bias', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.decoder.bias'])\n",
+      "success1\n",
+      "fail2\n",
+      "['lm_head.bias']\n",
+      "Total differences: 0.0\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Check weights loaded correctly\n",
+    "my_dict = my_mlm.to_state_dict()\n",
+    "hf_dict = hf_mlm.state_dict()\n",
+    "\n",
+    "print(f\"Total differences: {check_dicts(my_dict, hf_dict)}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 59,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([-0.0972, -0.0294,  0.4988,  ..., -0.0312, -0.0312, -1.0000])"
+      ]
+     },
+     "execution_count": 59,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "hf_dict['lm_head.bias']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 60,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([-0.0972, -0.0294,  0.4988,  ..., -0.0312, -0.0312, -1.0000])"
+      ]
+     },
+     "execution_count": 60,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "hf_dict['lm_head.decoder.bias']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 61,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n"
+     ]
+    }
+   ],
+   "source": [
+    "my_output_mlm = my_mlm(input_embeds = input_embeds, attention_mask=mask, key=key)\n",
+    "hf_output_mlm = hf_mlm(inputs_embeds = input_embeds_torch, attention_mask=mask_torch, return_dict=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 62,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "MLM: acc: 1.0 \t norms: (tensor(33433.4062), tensor(33433.3945)) \t diffs: 1.7561978893354535e-05\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(f\"MLM: {check(my_output_mlm[0].array, hf_output_mlm[0].detach())[3]}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 63,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n",
+      "c:\\Users\\julie\\anaconda3\\envs\\levanter2\\lib\\site-packages\\haliax\\core.py:249: UserWarning: Found axis with same name but different size.\n",
+      "  warnings.warn(\"Found axis with same name but different size.\", UserWarning)\n"
+     ]
+    }
+   ],
+   "source": [
+    "my_output_mlm = my_mlm(input_ids = input_ids, attention_mask=mask, key=key)\n",
+    "hf_output_mlm = hf_mlm(input_ids = input_ids_torch, attention_mask=mask_torch, return_dict=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 64,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "MLM: acc: 1.0 \t norms: (tensor(28814.2480), tensor(28814.2168)) \t diffs: 1.45375252031954e-05\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(f\"MLM: {check(my_output_mlm[0].array, hf_output_mlm[0].detach())[3]}\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.14"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}