Fix failing integration tests #869

Merged: 6 commits, Nov 5, 2024
15 changes: 15 additions & 0 deletions .github/workflows/integration-test.yml
@@ -73,6 +73,21 @@ jobs:

- name: Checkout source
uses: actions/checkout@v4
with:
# Getting the correct commit for a pull_request_target event is a known,
# problematic issue: https://github.com/actions/checkout/issues/518
# Ideally, we would use github.event.pull_request.merge_commit_sha, but it
# is not reliable and can sometimes be a null value. Checking out the merge
# ref appears to be the most reasonable way to ensure we are pulling the
# same code that triggered the workflow, based upon this comment:
# https://github.com/actions/checkout/issues/518#issuecomment-1661941548
ref: "refs/pull/${{ github.event.number }}/merge"
fetch-depth: 2

- name: Sanity check
# Continuing from the comment in the checkout step above.
run: |
[[ "$(git rev-parse 'HEAD~2')" == "${{ github.event.pull_request.head.sha }}" ]]

- name: Install package with dependencies
uses: ./.github/actions/install-pkg
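
The sanity check above can be reproduced outside of CI. A minimal sketch in Python, assuming a local checkout of the fetched merge ref at depth 2; the function name and the expected-SHA argument are illustrative, not part of the workflow:

import subprocess

def merge_ref_matches_pr_head(expected_head_sha: str) -> bool:
    # Mirrors the workflow step: on the refs/pull/<N>/merge checkout,
    # HEAD~2 is expected to resolve to the PR head commit.
    actual = subprocess.run(
        ["git", "rev-parse", "HEAD~2"],
        capture_output=True,
        text=True,
        check=True,
    ).stdout.strip()
    return actual == expected_head_sha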
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -27,7 +27,7 @@ jobs:
run: mypy

- name: Test
run: pytest tests/unit --cov=earthaccess --cov=tests --cov-report=term-missing --capture=no --tb=native --log-cli-level=INFO
run: pytest tests/unit --verbose --cov=earthaccess --cov-report=term-missing --capture=no --tb=native --log-cli-level=INFO

- name: Upload coverage
# Don't upload coverage when using the `act` tool to run the workflow locally
40 changes: 20 additions & 20 deletions earthaccess/api.py
@@ -1,4 +1,5 @@
import logging
from pathlib import Path

import requests
import s3fs
@@ -202,9 +203,10 @@ def login(strategy: str = "all", persist: bool = False, system: System = PROD) -

def download(
granules: Union[DataGranule, List[DataGranule], str, List[str]],
local_path: Optional[str],
local_path: Optional[Union[Path, str]] = None,
provider: Optional[str] = None,
threads: int = 8,
*,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[str]:
"""Retrieves data granules from a remote storage system.
@@ -215,7 +217,11 @@ def download(

Parameters:
granules: a granule, list of granules, a granule link (HTTP), or a list of granule links (HTTP)
local_path: local directory to store the remote data granules
local_path: Local directory to store the remote data granules. If not
supplied, defaults to a subdirectory of the current working directory
of the form `data/YYYY-MM-DD-UUID`, where `YYYY-MM-DD` is the year,
month, and day of the current date, and `UUID` is the first 6
characters of a UUID4 hex value.
provider: if we download a list of URLs, we need to specify the provider.
threads: number of parallel threads to use to download the files; adjust as necessary, default = 8
pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library.
@@ -228,31 +234,29 @@ def download(
Raises:
Exception: A file download failed.
"""
provider = _normalize_location(provider)
pqdm_kwargs = {
"exception_behavior": "immediate",
"n_jobs": threads,
**(pqdm_kwargs or {}),
}
provider = _normalize_location(str(provider))

if isinstance(granules, DataGranule):
granules = [granules]
elif isinstance(granules, str):
granules = [granules]

try:
results = earthaccess.__store__.get(
granules, local_path, provider, threads, pqdm_kwargs
return earthaccess.__store__.get(
granules, local_path, provider, threads, pqdm_kwargs=pqdm_kwargs
)
except AttributeError as err:
logger.error(
f"{err}: You must call earthaccess.login() before you can download data"
)
return []
return results

return []


def open(
granules: Union[List[str], List[DataGranule]],
provider: Optional[str] = None,
*,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[AbstractFileSystem]:
"""Returns a list of file-like objects that can be used to access files
@@ -269,15 +273,11 @@ def open(
Returns:
A list of "file pointers" to remote (i.e. s3 or https) files.
"""
provider = _normalize_location(provider)
pqdm_kwargs = {
"exception_behavior": "immediate",
**(pqdm_kwargs or {}),
}
results = earthaccess.__store__.open(
granules=granules, provider=provider, pqdm_kwargs=pqdm_kwargs
return earthaccess.__store__.open(
granules=granules,
provider=_normalize_location(provider),
pqdm_kwargs=pqdm_kwargs,
)
return results


def get_s3_credentials(
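For reference, a usage sketch of the updated API, with `pqdm_kwargs` now keyword-only and `local_path` optional; the dataset short name and option values are illustrative assumptions, not taken from this PR:

import earthaccess

earthaccess.login()
granules = earthaccess.search_data(short_name="ATL06", count=4)  # illustrative query

# local_path omitted: files land under ./data/YYYY-MM-DD-<uuid>/ by default.
# pqdm_kwargs must be passed by keyword; n_jobs overrides the default
# derived from `threads`.
files = earthaccess.download(
    granules,
    pqdm_kwargs={"exception_behaviour": "immediate", "n_jobs": 4},
)

# open() forwards pqdm_kwargs the same way.
file_objs = earthaccess.open(granules, pqdm_kwargs={"n_jobs": 4})
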
94 changes: 51 additions & 43 deletions earthaccess/store.py
@@ -63,22 +63,20 @@ def __repr__(self) -> str:
def _open_files(
url_mapping: Mapping[str, Union[DataGranule, None]],
fs: fsspec.AbstractFileSystem,
threads: int = 8,
*,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[fsspec.spec.AbstractBufferedFile]:
def multi_thread_open(data: tuple[str, Optional[DataGranule]]) -> EarthAccessFile:
url, granule = data
return EarthAccessFile(fs.open(url), granule) # type: ignore

pqdm_kwargs = {
"exception_behavior": "immediate",
"exception_behaviour": "immediate",
"n_jobs": 8,
**(pqdm_kwargs or {}),
}

fileset = pqdm(
url_mapping.items(), multi_thread_open, n_jobs=threads, **pqdm_kwargs
)
return fileset
return pqdm(url_mapping.items(), multi_thread_open, **pqdm_kwargs)
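
Note the one-letter change above: pqdm spells this option `exception_behaviour` (British spelling), so the earlier `exception_behavior` key would not have configured pqdm's exception handling. A minimal sketch of the corrected usage:

from pqdm.threads import pqdm

def double(x: int) -> int:
    return x * 2

# With exception_behaviour="immediate", pqdm re-raises a worker exception
# as soon as it occurs instead of deferring it to the end of the run.
results = pqdm(range(4), double, n_jobs=2, exception_behaviour="immediate")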


def make_instance(
@@ -344,6 +342,7 @@ def open(
self,
granules: Union[List[str], List[DataGranule]],
provider: Optional[str] = None,
*,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[fsspec.spec.AbstractBufferedFile]:
"""Returns a list of file-like objects that can be used to access files
@@ -361,14 +360,15 @@ def open(
A list of "file pointers" to remote (i.e. s3 or https) files.
"""
if len(granules):
return self._open(granules, provider, pqdm_kwargs)
return self._open(granules, provider, pqdm_kwargs=pqdm_kwargs)
return []

@singledispatchmethod
def _open(
self,
granules: Union[List[str], List[DataGranule]],
provider: Optional[str] = None,
*,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[Any]:
raise NotImplementedError("granules should be a list of DataGranule or URLs")
@@ -378,7 +378,8 @@ def _open_granules(
self,
granules: List[DataGranule],
provider: Optional[str] = None,
threads: int = 8,
*,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[Any]:
fileset: List = []
total_size = round(sum([granule.size() for granule in granules]) / 1024, 2)
@@ -411,7 +412,7 @@ def _open_granules(
fileset = _open_files(
url_mapping,
fs=s3_fs,
threads=threads,
pqdm_kwargs=pqdm_kwargs,
)
except Exception as e:
raise RuntimeError(
Expand All @@ -420,19 +421,19 @@ def _open_granules(
f"Exception: {traceback.format_exc()}"
) from e
else:
fileset = self._open_urls_https(url_mapping, threads=threads)
return fileset
fileset = self._open_urls_https(url_mapping, pqdm_kwargs=pqdm_kwargs)
else:
url_mapping = _get_url_granule_mapping(granules, access="on_prem")
fileset = self._open_urls_https(url_mapping, threads=threads)
return fileset
fileset = self._open_urls_https(url_mapping, pqdm_kwargs=pqdm_kwargs)

return fileset

@_open.register
def _open_urls(
self,
granules: List[str],
provider: Optional[str] = None,
threads: int = 8,
*,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[Any]:
fileset: List = []
@@ -460,7 +461,6 @@ def _open_urls(
fileset = _open_files(
url_mapping,
fs=s3_fs,
threads=threads,
pqdm_kwargs=pqdm_kwargs,
)
except Exception as e:
@@ -481,15 +481,16 @@
raise ValueError(
"We cannot open S3 links when we are not in-region, try using HTTPS links"
)
fileset = self._open_urls_https(url_mapping, threads, pqdm_kwargs)
fileset = self._open_urls_https(url_mapping, pqdm_kwargs=pqdm_kwargs)
return fileset

def get(
self,
granules: Union[List[DataGranule], List[str]],
local_path: Union[Path, str, None] = None,
local_path: Optional[Union[Path, str]] = None,
provider: Optional[str] = None,
threads: int = 8,
*,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[str]:
"""Retrieves data granules from a remote storage system.
@@ -503,7 +504,11 @@ def get(

Parameters:
granules: A list of granule (DataGranule) instances or a list of granule links (HTTP).
local_path: Local directory to store the remote data granules.
local_path: Local directory to store the remote data granules. If not
supplied, defaults to a subdirectory of the current working directory
of the form `data/YYYY-MM-DD-UUID`, where `YYYY-MM-DD` is the year,
month, and day of the current date, and `UUID` is the first 6
characters of a UUID4 hex value.
provider: a valid cloud provider, each DAAC has a provider code for their cloud distributions
threads: Number of parallel threads to use to download the files;
adjust as necessary, default = 8.
@@ -514,26 +519,28 @@
Returns:
List of downloaded files
"""
if not granules:
raise ValueError("List of URLs or DataGranule instances expected")

if local_path is None:
today = datetime.datetime.today().strftime("%Y-%m-%d")
today = datetime.datetime.now().strftime("%Y-%m-%d")
uuid = uuid4().hex[:6]
local_path = Path.cwd() / "data" / f"{today}-{uuid}"
elif isinstance(local_path, str):
local_path = Path(local_path)

if len(granules):
files = self._get(granules, local_path, provider, threads, pqdm_kwargs)
return files
else:
raise ValueError("List of URLs or DataGranule instances expected")
pqdm_kwargs = {
"n_jobs": threads,
**(pqdm_kwargs or {}),
}

return self._get(granules, Path(local_path), provider, pqdm_kwargs=pqdm_kwargs)
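
The dict-spreading here gives `threads` its default role while still letting callers override it through `pqdm_kwargs`. A standalone sketch of the merging behavior; the helper name is illustrative:

from typing import Any, Mapping, Optional

def merged_pqdm_kwargs(
    threads: int,
    pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> dict:
    # Defaults first, caller-supplied values last, so the caller wins.
    return {"n_jobs": threads, **(pqdm_kwargs or {})}

assert merged_pqdm_kwargs(8) == {"n_jobs": 8}
assert merged_pqdm_kwargs(8, {"n_jobs": 2}) == {"n_jobs": 2}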

@singledispatchmethod
def _get(
self,
granules: Union[List[DataGranule], List[str]],
local_path: Path,
provider: Optional[str] = None,
threads: int = 8,
*,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[str]:
"""Retrieves data granules from a remote storage system.
@@ -566,7 +573,7 @@ def _get_urls(
granules: List[str],
local_path: Path,
provider: Optional[str] = None,
threads: int = 8,
*,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[str]:
data_links = granules
@@ -590,7 +597,7 @@ def _get_urls(
else:
# if we are not in AWS
return self._download_onprem_granules(
data_links, local_path, threads, pqdm_kwargs
data_links, local_path, pqdm_kwargs=pqdm_kwargs
)

@_get.register
@@ -599,7 +606,7 @@ def _get_granules(
granules: List[DataGranule],
local_path: Path,
provider: Optional[str] = None,
threads: int = 8,
*,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[str]:
data_links: List = []
@@ -615,7 +622,7 @@ def _get_granules(
for granule in granules
)
)
total_size = round(sum([granule.size() for granule in granules]) / 1024, 2)
total_size = round(sum(granule.size() for granule in granules) / 1024, 2)
logger.info(
f" Getting {len(granules)} granules, approx download size: {total_size} GB"
)
@@ -642,7 +649,7 @@ def _get_granules(
# if the data are cloud-based, but we are not in AWS,
# it will be downloaded as if it was on prem
return self._download_onprem_granules(
data_links, local_path, threads, pqdm_kwargs
data_links, local_path, pqdm_kwargs=pqdm_kwargs
)

def _download_file(self, url: str, directory: Path) -> str:
@@ -684,7 +691,7 @@ def _download_onprem_granules(
self,
urls: List[str],
directory: Path,
threads: int = 8,
*,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[Any]:
"""Downloads a list of URLS into the data directory.
@@ -711,25 +718,26 @@

arguments = [(url, directory) for url in urls]

results = pqdm(
arguments,
self._download_file,
n_jobs=threads,
argument_type="args",
**pqdm_kwargs,
)
return results
pqdm_kwargs = {
"exception_behaviour": "immediate",
**(pqdm_kwargs or {}),
# We don't want a user to be able to override the following kwargs,
# which is why they appear *after* spreading pqdm_kwargs above.
"argument_type": "args",
}

return pqdm(arguments, self._download_file, **pqdm_kwargs)
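
Here the spread order is reversed on purpose: user-supplied `pqdm_kwargs` come first so that `argument_type="args"` cannot be overridden, since `_download_file` takes its inputs positionally. A small sketch of what that option does in pqdm; the stub function and URLs are illustrative:

from pqdm.threads import pqdm

def download_stub(url: str, directory: str) -> str:
    # Stand-in for Store._download_file; just builds the target path.
    return f"{directory}/{url.rsplit('/', 1)[-1]}"

args = [("https://example.com/a.nc", "/tmp"), ("https://example.com/b.nc", "/tmp")]

# argument_type="args" unpacks each tuple into positional arguments,
# i.e. download_stub(url, directory) rather than download_stub((url, directory)).
paths = pqdm(args, download_stub, n_jobs=2, argument_type="args")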

def _open_urls_https(
self,
url_mapping: Mapping[str, Union[DataGranule, None]],
threads: int = 8,
*,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[fsspec.AbstractFileSystem]:
https_fs = self.get_fsspec_session()

try:
return _open_files(url_mapping, https_fs, threads, pqdm_kwargs)
return _open_files(url_mapping, https_fs, pqdm_kwargs=pqdm_kwargs)
except Exception:
logger.exception(
"An exception occurred while trying to access remote files via HTTPS"
2 changes: 1 addition & 1 deletion scripts/integration-test.sh
@@ -1,7 +1,7 @@
#!/usr/bin/env bash

set -x
pytest tests/integration --cov=earthaccess --cov=tests/integration --cov-report=term-missing "${@}" --capture=no --tb=native --log-cli-level=INFO
pytest tests/integration --cov=earthaccess --cov-report=term-missing "${@}" --capture=no --tb=native --log-cli-level=INFO
RET=$?
set +x
