diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py index f6be8736..d8cd5d83 100644 --- a/torchtitan/datasets/hf_datasets.py +++ b/torchtitan/datasets/hf_datasets.py @@ -13,7 +13,7 @@ try: from torchdata.stateful_dataloader import StatefulDataLoader -except ModuleNotFoundError as e: +except ImportError as e: raise ImportError( "Please install the latest torchdata nightly to use StatefulDataloader via:" "pip3 install --pre torchdata --index-url https://download.pytorch.org/whl/nightly"