From 594c0c2a7a85062d4461f75c34327844e35f8e2a Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Wed, 2 Oct 2024 11:49:17 -0700 Subject: [PATCH] Also skip model.norm.weight when loading OLMo state dicts --- src/levanter/compat/torch_serialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/compat/torch_serialization.py b/src/levanter/compat/torch_serialization.py index ea110d981..ddfb8dc11 100644 --- a/src/levanter/compat/torch_serialization.py +++ b/src/levanter/compat/torch_serialization.py @@ -168,7 +168,7 @@ def default_eqx_module_from_state_dict(mod: Mod, state_dict: StateDict, prefix: # Hack to get around the fact we're using llama code for # olmo model and something weird w layernorm if prefix is not None: - if 'layernorm' in prefix.lower(): + if 'layernorm' in prefix.lower() or 'model.norm.weight' in prefix.lower(): continue new = jax_tree_from_state_dict(value, state_dict, apply_prefix(prefix, key)) # Do not try to update parameters that are never defined