From 19678ab8026895402f40e616484d8071d8a154cd Mon Sep 17 00:00:00 2001 From: wangshuai09 <391746016@qq.com> Date: Sat, 30 Mar 2024 14:56:04 +0800 Subject: [PATCH] fix import error --- torchtext/datasets/cnndm.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/torchtext/datasets/cnndm.py b/torchtext/datasets/cnndm.py index 92b2da8ce1..a9022a926b 100644 --- a/torchtext/datasets/cnndm.py +++ b/torchtext/datasets/cnndm.py @@ -77,12 +77,14 @@ def _hash_urls(s: tuple): def _get_split_list(source: str, split: str): + from torchdata.datapipes.iter import IterableWrapper, OnlineReader # noqa url_dp = IterableWrapper([SPLIT_LIST[source + "_" + split]]) online_dp = OnlineReader(url_dp) return online_dp.readlines().map(fn=_hash_urls) def _load_stories(root: str, source: str, split: str): + from torchdata.datapipes.iter import FileOpener, IterableWrapper, GDriveReader # noqa split_list = set(_get_split_list(source, split)) story_dp = IterableWrapper([URL[source]]) cache_compressed_dp = story_dp.on_disk_cache( @@ -135,12 +137,6 @@ def CNNDM(root: str, split: Union[Tuple[str], str]): raise ModuleNotFoundError( "Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data" ) - from torchdata.datapipes.iter import ( # noqa - FileOpener, - IterableWrapper, - OnlineReader, - GDriveReader, - ) cnn_dp = _load_stories(root, "cnn", split) dailymail_dp = _load_stories(root, "dailymail", split)