From 476b6dc79720e5d9ddfb3cd589680b2308871926 Mon Sep 17 00:00:00 2001 From: ZXMMD <46393405+ZXMMD@users.noreply.github.com> Date: Fri, 12 Jul 2024 14:52:21 +0800 Subject: [PATCH] fix ckpt_utils.py (#580) --- opensora/utils/ckpt_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/opensora/utils/ckpt_utils.py b/opensora/utils/ckpt_utils.py index b2ac5e24..d730981c 100644 --- a/opensora/utils/ckpt_utils.py +++ b/opensora/utils/ckpt_utils.py @@ -193,6 +193,12 @@ def load_checkpoint(model, ckpt_path, save_as_pt=False, model_name="model", stri missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=strict) get_logger().info("Missing keys: %s", missing_keys) get_logger().info("Unexpected keys: %s", unexpected_keys) + elif ckpt_path.endswith(".safetensors"): + from safetensors.torch import load_file + state_dict = load_file(ckpt_path) + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + print(f"Missing keys: {missing_keys}") + print(f"Unexpected keys: {unexpected_keys}") elif os.path.isdir(ckpt_path): load_from_sharded_state_dict(model, ckpt_path, model_name, strict=strict) get_logger().info("Model checkpoint loaded from %s", ckpt_path)