From 47dbf9fd4ed123a92b3847c716f0f9f818d5f305 Mon Sep 17 00:00:00 2001 From: Evan Racah Date: Fri, 24 May 2024 09:34:14 -0700 Subject: [PATCH] Update docstring for get_model_state_dict (#3318) Turns out it's empty dict for nonzero ranks for unsharded state dicts because for torch 2.1.2 we set the `FullStateDictConfig` `rank0_only` flag to `True` and for torch >2.1.2, the `dcp.get_model_state_dict` function always has empty dict for nonzero ranks for unsharded state dicts --- composer/checkpoint/state_dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/checkpoint/state_dict.py b/composer/checkpoint/state_dict.py index 5417188466..e3b091c071 100644 --- a/composer/checkpoint/state_dict.py +++ b/composer/checkpoint/state_dict.py @@ -32,7 +32,7 @@ def get_model_state_dict( Args: model: The model to get the state dict from. sharded_state_dict: Whether the model state dict should be sharded or not. If True, every rank returns the state dict of its shards. - If False, then rank 0 returns the state dict of the entire model and the other ranks return a dict of their shards. Default is False. + If False, then rank 0 returns the state dict of the entire model and the other ranks return an empty dict. Default is False. precision: The precision of the model. Can be specified as a string ('fp32', 'fp16', 'bf16') or a torch.dtype. include_keys: The list of keys to exclusively include in the state dict. If None, all keys are included. Both include_keys and ignore_keys cannot be non-None. ignore_keys: The list of keys to ignore in the state dict. If None, no keys are ignored. Both include_keys and ignore_keys cannot be non-None.