diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index b868549d..6b563135 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -73,6 +73,21 @@ jobs: - name: Checkout source uses: actions/checkout@v4 + with: + # Getting the correct commit for a pull_request_target event appears to be + # a known, problematic issue: https://github.com/actions/checkout/issues/518 + # It seems that ideally, we want github.event.pull_request.merge_commit_sha, + # but that it is not reliable, and can sometimes be a null value. It + # appears that this is the most reasonable way to ensure that we are pulling + # the same code that triggered things, based upon this particular comment: + # https://github.com/actions/checkout/issues/518#issuecomment-1661941548 + ref: "refs/pull/${{ github.event.number }}/merge" + fetch-depth: 2 + + - name: Sanity check + # Continuing from previous comment in checkout step above. + run: | + [[ "$(git rev-parse 'HEAD~2')" == "${{ github.event.pull_request.head.sha }}" ]] - name: Install package with dependencies uses: ./.github/actions/install-pkg diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1d6250d4..cd7b6cd3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -27,7 +27,7 @@ jobs: run: mypy - name: Test - run: pytest tests/unit --cov=earthaccess --cov=tests --cov-report=term-missing --capture=no --tb=native --log-cli-level=INFO + run: pytest tests/unit --verbose --cov=earthaccess --cov-report=term-missing --capture=no --tb=native --log-cli-level=INFO - name: Upload coverage # Don't upload coverage when using the `act` tool to run the workflow locally diff --git a/earthaccess/api.py b/earthaccess/api.py index 5ad75ed5..6b758aa2 100644 --- a/earthaccess/api.py +++ b/earthaccess/api.py @@ -1,4 +1,5 @@ import logging +from pathlib import Path import requests import s3fs @@ -202,9 +203,10 @@ def login(strategy: str = 
"all", persist: bool = False, system: System = PROD) - def download( granules: Union[DataGranule, List[DataGranule], str, List[str]], - local_path: Optional[str], + local_path: Optional[Union[Path, str]] = None, provider: Optional[str] = None, threads: int = 8, + *, pqdm_kwargs: Optional[Mapping[str, Any]] = None, ) -> List[str]: """Retrieves data granules from a remote storage system. @@ -215,7 +217,11 @@ def download( Parameters: granules: a granule, list of granules, a granule link (HTTP), or a list of granule links (HTTP) - local_path: local directory to store the remote data granules + local_path: Local directory to store the remote data granules. If not + supplied, defaults to a subdirectory of the current working directory + of the form `data/YYYY-MM-DD-UUID`, where `YYYY-MM-DD` is the year, + month, and day of the current date, and `UUID` is the last 6 digits + of a UUID4 value. provider: if we download a list of URLs, we need to specify the provider. threads: parallel number of threads to use to download the files, adjust as necessary, default = 8 pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library. @@ -228,31 +234,29 @@ def download( Raises: Exception: A file download failed. 
""" - provider = _normalize_location(provider) - pqdm_kwargs = { - "exception_behavior": "immediate", - "n_jobs": threads, - **(pqdm_kwargs or {}), - } + provider = _normalize_location(str(provider)) + if isinstance(granules, DataGranule): granules = [granules] elif isinstance(granules, str): granules = [granules] + try: - results = earthaccess.__store__.get( - granules, local_path, provider, threads, pqdm_kwargs + return earthaccess.__store__.get( + granules, local_path, provider, threads, pqdm_kwargs=pqdm_kwargs ) except AttributeError as err: logger.error( f"{err}: You must call earthaccess.login() before you can download data" ) - return [] - return results + + return [] def open( granules: Union[List[str], List[DataGranule]], provider: Optional[str] = None, + *, pqdm_kwargs: Optional[Mapping[str, Any]] = None, ) -> List[AbstractFileSystem]: """Returns a list of file-like objects that can be used to access files @@ -269,15 +273,11 @@ def open( Returns: A list of "file pointers" to remote (i.e. s3 or https) files. 
""" - provider = _normalize_location(provider) - pqdm_kwargs = { - "exception_behavior": "immediate", - **(pqdm_kwargs or {}), - } - results = earthaccess.__store__.open( - granules=granules, provider=provider, pqdm_kwargs=pqdm_kwargs + return earthaccess.__store__.open( + granules=granules, + provider=_normalize_location(provider), + pqdm_kwargs=pqdm_kwargs, ) - return results def get_s3_credentials( diff --git a/earthaccess/store.py b/earthaccess/store.py index f7b5c85e..58ac9f59 100644 --- a/earthaccess/store.py +++ b/earthaccess/store.py @@ -63,7 +63,7 @@ def __repr__(self) -> str: def _open_files( url_mapping: Mapping[str, Union[DataGranule, None]], fs: fsspec.AbstractFileSystem, - threads: int = 8, + *, pqdm_kwargs: Optional[Mapping[str, Any]] = None, ) -> List[fsspec.spec.AbstractBufferedFile]: def multi_thread_open(data: tuple[str, Optional[DataGranule]]) -> EarthAccessFile: @@ -71,14 +71,12 @@ def multi_thread_open(data: tuple[str, Optional[DataGranule]]) -> EarthAccessFil return EarthAccessFile(fs.open(url), granule) # type: ignore pqdm_kwargs = { - "exception_behavior": "immediate", + "exception_behaviour": "immediate", + "n_jobs": 8, **(pqdm_kwargs or {}), } - fileset = pqdm( - url_mapping.items(), multi_thread_open, n_jobs=threads, **pqdm_kwargs - ) - return fileset + return pqdm(url_mapping.items(), multi_thread_open, **pqdm_kwargs) def make_instance( @@ -344,6 +342,7 @@ def open( self, granules: Union[List[str], List[DataGranule]], provider: Optional[str] = None, + *, pqdm_kwargs: Optional[Mapping[str, Any]] = None, ) -> List[fsspec.spec.AbstractBufferedFile]: """Returns a list of file-like objects that can be used to access files @@ -361,7 +360,7 @@ def open( A list of "file pointers" to remote (i.e. s3 or https) files. 
""" if len(granules): - return self._open(granules, provider, pqdm_kwargs) + return self._open(granules, provider, pqdm_kwargs=pqdm_kwargs) return [] @singledispatchmethod @@ -369,6 +368,7 @@ def _open( self, granules: Union[List[str], List[DataGranule]], provider: Optional[str] = None, + *, pqdm_kwargs: Optional[Mapping[str, Any]] = None, ) -> List[Any]: raise NotImplementedError("granules should be a list of DataGranule or URLs") @@ -378,7 +378,8 @@ def _open_granules( self, granules: List[DataGranule], provider: Optional[str] = None, - threads: int = 8, + *, + pqdm_kwargs: Optional[Mapping[str, Any]] = None, ) -> List[Any]: fileset: List = [] total_size = round(sum([granule.size() for granule in granules]) / 1024, 2) @@ -411,7 +412,7 @@ def _open_granules( fileset = _open_files( url_mapping, fs=s3_fs, - threads=threads, + pqdm_kwargs=pqdm_kwargs, ) except Exception as e: raise RuntimeError( @@ -420,19 +421,19 @@ def _open_granules( f"Exception: {traceback.format_exc()}" ) from e else: - fileset = self._open_urls_https(url_mapping, threads=threads) - return fileset + fileset = self._open_urls_https(url_mapping, pqdm_kwargs=pqdm_kwargs) else: url_mapping = _get_url_granule_mapping(granules, access="on_prem") - fileset = self._open_urls_https(url_mapping, threads=threads) - return fileset + fileset = self._open_urls_https(url_mapping, pqdm_kwargs=pqdm_kwargs) + + return fileset @_open.register def _open_urls( self, granules: List[str], provider: Optional[str] = None, - threads: int = 8, + *, pqdm_kwargs: Optional[Mapping[str, Any]] = None, ) -> List[Any]: fileset: List = [] @@ -460,7 +461,6 @@ def _open_urls( fileset = _open_files( url_mapping, fs=s3_fs, - threads=threads, pqdm_kwargs=pqdm_kwargs, ) except Exception as e: @@ -481,15 +481,16 @@ def _open_urls( raise ValueError( "We cannot open S3 links when we are not in-region, try using HTTPS links" ) - fileset = self._open_urls_https(url_mapping, threads, pqdm_kwargs) + fileset = 
self._open_urls_https(url_mapping, pqdm_kwargs=pqdm_kwargs) return fileset def get( self, granules: Union[List[DataGranule], List[str]], - local_path: Union[Path, str, None] = None, + local_path: Optional[Union[Path, str]] = None, provider: Optional[str] = None, threads: int = 8, + *, pqdm_kwargs: Optional[Mapping[str, Any]] = None, ) -> List[str]: """Retrieves data granules from a remote storage system. @@ -503,7 +504,11 @@ def get( Parameters: granules: A list of granules(DataGranule) instances or a list of granule links (HTTP). - local_path: Local directory to store the remote data granules. + local_path: Local directory to store the remote data granules. If not + supplied, defaults to a subdirectory of the current working directory + of the form `data/YYYY-MM-DD-UUID`, where `YYYY-MM-DD` is the year, + month, and day of the current date, and `UUID` is the first 6 characters + of a UUID4 hex value. provider: a valid cloud provider, each DAAC has a provider code for their cloud distributions threads: Parallel number of threads to use to download the files; adjust as necessary, default = 8. 
@@ -514,18 +519,20 @@ def get( Returns: List of downloaded files """ + if not granules: + raise ValueError("List of URLs or DataGranule instances expected") + if local_path is None: - today = datetime.datetime.today().strftime("%Y-%m-%d") + today = datetime.datetime.now().strftime("%Y-%m-%d") uuid = uuid4().hex[:6] local_path = Path.cwd() / "data" / f"{today}-{uuid}" - elif isinstance(local_path, str): - local_path = Path(local_path) - if len(granules): - files = self._get(granules, local_path, provider, threads, pqdm_kwargs) - return files - else: - raise ValueError("List of URLs or DataGranule instances expected") + pqdm_kwargs = { + "n_jobs": threads, + **(pqdm_kwargs or {}), + } + + return self._get(granules, Path(local_path), provider, pqdm_kwargs=pqdm_kwargs) @singledispatchmethod def _get( @@ -533,7 +540,7 @@ def _get( granules: Union[List[DataGranule], List[str]], local_path: Path, provider: Optional[str] = None, - threads: int = 8, + *, pqdm_kwargs: Optional[Mapping[str, Any]] = None, ) -> List[str]: """Retrieves data granules from a remote storage system. 
@@ -566,7 +573,7 @@ def _get_urls( granules: List[str], local_path: Path, provider: Optional[str] = None, - threads: int = 8, + *, pqdm_kwargs: Optional[Mapping[str, Any]] = None, ) -> List[str]: data_links = granules @@ -590,7 +597,7 @@ def _get_urls( else: # if we are not in AWS return self._download_onprem_granules( - data_links, local_path, threads, pqdm_kwargs + data_links, local_path, pqdm_kwargs=pqdm_kwargs ) @_get.register @@ -599,7 +606,7 @@ def _get_granules( granules: List[DataGranule], local_path: Path, provider: Optional[str] = None, - threads: int = 8, + *, pqdm_kwargs: Optional[Mapping[str, Any]] = None, ) -> List[str]: data_links: List = [] @@ -615,7 +622,7 @@ def _get_granules( for granule in granules ) ) - total_size = round(sum([granule.size() for granule in granules]) / 1024, 2) + total_size = round(sum(granule.size() for granule in granules) / 1024, 2) logger.info( f" Getting {len(granules)} granules, approx download size: {total_size} GB" ) @@ -642,7 +649,7 @@ def _get_granules( # if the data are cloud-based, but we are not in AWS, # it will be downloaded as if it was on prem return self._download_onprem_granules( - data_links, local_path, threads, pqdm_kwargs + data_links, local_path, pqdm_kwargs=pqdm_kwargs ) def _download_file(self, url: str, directory: Path) -> str: @@ -684,7 +691,7 @@ def _download_onprem_granules( self, urls: List[str], directory: Path, - threads: int = 8, + *, pqdm_kwargs: Optional[Mapping[str, Any]] = None, ) -> List[Any]: """Downloads a list of URLS into the data directory. @@ -711,25 +718,26 @@ def _download_onprem_granules( arguments = [(url, directory) for url in urls] - results = pqdm( - arguments, - self._download_file, - n_jobs=threads, - argument_type="args", - **pqdm_kwargs, - ) - return results + pqdm_kwargs = { + "exception_behaviour": "immediate", + **(pqdm_kwargs or {}), + # We don't want a user to be able to override the following kwargs, + # which is why they appear *after* spreading pqdm_kwargs above. 
+ "argument_type": "args", + } + + return pqdm(arguments, self._download_file, **pqdm_kwargs) def _open_urls_https( self, url_mapping: Mapping[str, Union[DataGranule, None]], - threads: int = 8, + *, pqdm_kwargs: Optional[Mapping[str, Any]] = None, ) -> List[fsspec.AbstractFileSystem]: https_fs = self.get_fsspec_session() try: - return _open_files(url_mapping, https_fs, threads, pqdm_kwargs) + return _open_files(url_mapping, https_fs, pqdm_kwargs=pqdm_kwargs) except Exception: logger.exception( "An exception occurred while trying to access remote files via HTTPS" diff --git a/scripts/integration-test.sh b/scripts/integration-test.sh index 506976ad..1a8ad4f6 100755 --- a/scripts/integration-test.sh +++ b/scripts/integration-test.sh @@ -1,7 +1,7 @@ #!/usr/bin/env bash set -x -pytest tests/integration --cov=earthaccess --cov=tests/integration --cov-report=term-missing "${@}" --capture=no --tb=native --log-cli-level=INFO +pytest tests/integration --cov=earthaccess --cov-report=term-missing "${@}" --capture=no --tb=native --log-cli-level=INFO RET=$? 
set +x diff --git a/tests/integration/test_api.py b/tests/integration/test_api.py index f0fdd219..f6ec1fcd 100644 --- a/tests/integration/test_api.py +++ b/tests/integration/test_api.py @@ -1,6 +1,7 @@ import logging import os from pathlib import Path +from unittest.mock import patch import earthaccess import pytest @@ -77,6 +78,51 @@ def test_download(tmp_path, selection, use_url): assert all(Path(f).exists() for f in files) +def fail_to_download_file(*args, **kwargs): + raise IOError("Download failed") + + +def test_download_immediate_failure(tmp_path: Path): + results = earthaccess.search_data( + short_name="ATL06", + bounding_box=(-10, 20, 10, 50), + temporal=("1999-02", "2019-03"), + count=3, + ) + + with patch.object(earthaccess.__store__, "_download_file", fail_to_download_file): + with pytest.raises(IOError, match="Download failed"): + # By default, we set pqdm exception_behaviour to "immediate" so that + # it simply propagates the first download error it encounters, halting + # any further downloads. + earthaccess.download(results, tmp_path, pqdm_kwargs=dict(disable=True)) + + +def test_download_deferred_failure(tmp_path: Path): + count = 3 + results = earthaccess.search_data( + short_name="ATL06", + bounding_box=(-10, 20, 10, 50), + temporal=("1999-02", "2019-03"), + count=count, + ) + + with patch.object(earthaccess.__store__, "_download_file", fail_to_download_file): + # With "deferred" exceptions, pqdm catches all exceptions, then at the end + # raises a single generic Exception, passing the sequence of caught exceptions + # as arguments to the Exception constructor. 
+ with pytest.raises(Exception) as exc_info: + earthaccess.download( + results, + tmp_path, + pqdm_kwargs=dict(exception_behaviour="deferred", disable=True), + ) + + errors = exc_info.value.args + assert len(errors) == count + assert all(isinstance(e, IOError) and str(e) == "Download failed" for e in errors) + + def test_auth_environ(): earthaccess.login(strategy="environment") environ = earthaccess.auth_environ() diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py deleted file mode 100644 index 20980e35..00000000 --- a/tests/unit/test_api.py +++ /dev/null @@ -1,50 +0,0 @@ -from unittest.mock import Mock - -import earthaccess -import pytest - - -def test_download_immediate_failure(monkeypatch): - earthaccess.login() - - results = earthaccess.search_data( - short_name="ATL06", - bounding_box=(-10, 20, 10, 50), - temporal=("1999-02", "2019-03"), - count=10, - ) - - def mock_get(*args, **kwargs): - raise Exception("Download failed") - - mock_store = Mock() - monkeypatch.setattr(earthaccess, "__store__", mock_store) - monkeypatch.setattr(mock_store, "get", mock_get) - - with pytest.raises(Exception, match="Download failed"): - earthaccess.download(results, "/home/download-folder") - - -def test_download_deferred_failure(monkeypatch): - earthaccess.login() - - results = earthaccess.search_data( - short_name="ATL06", - bounding_box=(-10, 20, 10, 50), - temporal=("1999-02", "2019-03"), - count=10, - ) - - def mock_get(*args, **kwargs): - return [Exception("Download failed")] * len(results) - - mock_store = Mock() - monkeypatch.setattr(earthaccess, "__store__", mock_store) - monkeypatch.setattr(mock_store, "get", mock_get) - - results = earthaccess.download( - results, "/home/download-folder", None, 8, {"exception_behavior": "deferred"} - ) - - assert all(isinstance(e, Exception) for e in results) - assert len(results) == 10