diff --git a/earthaccess/api.py b/earthaccess/api.py
index b82c550a..2df3b6c7 100644
--- a/earthaccess/api.py
+++ b/earthaccess/api.py
@@ -1,11 +1,10 @@
 from typing import Any, Dict, List, Optional, Type, Union
 
+import earthaccess
 import requests
 import s3fs
 from fsspec import AbstractFileSystem
 
-import earthaccess
-
 from .auth import Auth
 from .search import CollectionQuery, DataCollections, DataGranules, GranuleQuery
 from .store import Store
diff --git a/earthaccess/store.py b/earthaccess/store.py
index ff0c6170..3ee879f9 100644
--- a/earthaccess/store.py
+++ b/earthaccess/store.py
@@ -7,7 +7,9 @@
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 from uuid import uuid4
+from functools import lru_cache
 
+import earthaccess
 import fsspec
 import requests
 import s3fs
@@ -19,6 +21,81 @@
 from .search import DataCollections
 
 
+def _open_files(files, granules, fs):
+    """Open each URL with ``fs`` and wrap it together with its granule.
+
+    Parameters:
+        files: URLs to open through ``fs.open``.
+        granules: entries paired positionally with ``files``; each one is
+            stored on the wrapper built for the URL at the same index.
+        fs: filesystem object providing ``open(url)``.
+
+    Returns:
+        The value returned by ``pqdm`` after mapping the opener over the
+        (url, granule) pairs with 8 jobs.
+    """
+
+    def multi_thread_open(data) -> Any:
+        url, granule = data
+        return EarthAccessFile(fs.open(url), granule)
+
+    # zip() stops at the shorter input, so callers must pass ``files`` and
+    # ``granules`` aligned one-to-one or trailing URLs are silently dropped.
+    return pqdm(zip(files, granules), multi_thread_open, n_jobs=8)
+
+
+def make_instance(cls, granule, _reduce):
+    """Rebuild an ``EarthAccessFile``; used as its ``__reduce__`` callable.
+
+    Parameters:
+        cls: type of the wrapped file object when it was reduced.
+        granule: granule that was stored on the original wrapper.
+        _reduce: value returned by the wrapped file's ``__reduce__``; only
+            its first two items (callable, argument tuple) are used.
+
+    Returns:
+        An ``EarthAccessFile`` carrying ``granule``.
+    """
+    if earthaccess.__store__.running_in_aws and cls is not s3fs.S3File:
+        # On AWS but not using a S3File. Reopen the file in this case for direct S3 access.
+        # NOTE: This uses the first data_link listed in the granule. That's not
+        # guaranteed to be the right one.
+        return EarthAccessFile(earthaccess.open([granule])[0], granule)
+
+    # Replay the wrapped file's own reduction, then re-wrap it so the
+    # restored object keeps its granule and can be reduced again the same way.
+    func, args = _reduce[0], _reduce[1]
+    return EarthAccessFile(func(*args), granule)
+
+
+class EarthAccessFile(fsspec.spec.AbstractBufferedFile):
+    """File wrapper that remembers which granule the file came from.
+
+    Attributes that are not found on the wrapper itself are looked up on
+    the wrapped file object ``f``.
+    """
+
+    def __init__(self, f, granule):
+        self.f = f
+        self.granule = granule
+
+    def __getattr__(self, method):
+        # Only reached when normal lookup fails. If ``f`` itself is the
+        # missing name (instance not initialised yet), delegating would
+        # re-enter this method forever, so fail with AttributeError instead.
+        if method == "f":
+            raise AttributeError(method)
+        return getattr(self.f, method)
+
+    def __reduce__(self):
+        # Hand the reducer everything it needs to rebuild or reopen the file.
+        return make_instance, (
+            type(self.f),
+            self.granule,
+            self.f.__reduce__(),
+        )
+
+
 class Store(object):
     """
     Store class to access granules on-prem or in the cloud.
@@ -111,6 +150,7 @@ def set_requests_session( elif resp.status_code >= 500: resp.raise_for_status() + @lru_cache def get_s3fs_session( self, daac: Optional[str] = None, @@ -158,6 +198,7 @@ def get_s3fs_session( ) return None + @lru_cache def get_fsspec_session(self) -> fsspec.AbstractFileSystem: """Returns a fsspec HTTPS session with bearer tokens that are used by CMR. This HTTPS session can be used to download granules if we want to use a direct, lower level API @@ -253,12 +294,7 @@ def _open_granules( if s3_fs is not None: try: - - def multi_thread_open(url: str) -> Any: - return s3_fs.open(url) - - fileset = pqdm(data_links, multi_thread_open, n_jobs=8) - + fileset = _open_files(data_links, granules, s3_fs) except Exception: print( "An exception occurred while trying to access remote files on S3: " @@ -267,7 +303,7 @@ def multi_thread_open(url: str) -> Any: ) return None else: - fileset = self._open_urls_https(data_links, n_jobs=8) + fileset = self._open_urls_https(data_links, granules, n_jobs=8) return fileset else: access_method = "on_prem" @@ -276,7 +312,7 @@ def multi_thread_open(url: str) -> Any: granule.data_links(access=access_method) for granule in granules ) ) - fileset = self._open_urls_https(data_links, n_jobs=8) + fileset = self._open_urls_https(data_links, granules, n_jobs=8) return fileset @_open.register @@ -310,11 +346,7 @@ def _open_urls( s3_fs = self.get_s3fs_session(provider=provider) if s3_fs is not None: try: - - def multi_thread_open(url: str) -> Any: - return s3_fs.open(url) - - fileset = pqdm(data_links, multi_thread_open, n_jobs=8) + fileset = _open_files(data_links, granules, s3_fs) except Exception: print( "An exception occurred while trying to access remote files on S3: " @@ -336,7 +368,7 @@ def multi_thread_open(url: str) -> Any: "We cannot open S3 links when we are not in-region, try using HTTPS links" ) return None - fileset = self._open_urls_https(data_links, 8) + fileset = self._open_urls_https(data_links, granules, 8) return 
fileset def get( @@ -538,17 +570,12 @@ def _download_onprem_granules( return results def _open_urls_https( - self, urls: List[str] = [], n_jobs: int = 8 + self, urls: List[str] = [], granules=[], n_jobs: int = 8 ) -> List[fsspec.AbstractFileSystem]: https_fs = self.get_fsspec_session() if https_fs is not None: try: - - def multi_thread_open(url: str) -> Any: - return https_fs.open(url) - - fileset = pqdm(urls, multi_thread_open, n_jobs=8) - + fileset = _open_files(urls, granules, https_fs) except Exception: print( "An exception occurred while trying to access remote files via HTTPS: "