Skip to content

Commit

Permalink
Merge pull request #958 from azavea/lf/defaults
Browse files Browse the repository at this point in the history
Bug fixes
  • Loading branch information
lewfish authored Jul 1, 2020
2 parents 19d5d04 + 2be53b2 commit 006937b
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 49 deletions.
14 changes: 9 additions & 5 deletions rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from datetime import datetime
from urllib.parse import urlparse

from everett.manager import ConfigurationMissingError

from rastervision.pipeline.file_system import (FileSystem, NotReadableError,
NotWritableError)

Expand Down Expand Up @@ -83,11 +85,13 @@ class S3FileSystem(FileSystem):
def get_request_payer():
    """Return the request-payer value derived from the S3 configuration.

    Looks up the ``requester_pays`` option (parsed with ``bool``, default
    ``'False'``) in the AWS_S3 namespace of the Raster Vision config.

    Returns:
        The string 'requester' if ``requester_pays`` is true, otherwise the
        string 'None'. The string 'None' is also returned if the lookup
        raises ConfigurationMissingError.
    """
    # Import here to avoid circular reference.
    from rastervision.pipeline import rv_config

    try:
        s3_config = rv_config.get_namespace_config(AWS_S3)
        requester_pays = s3_config(
            'requester_pays', parser=bool, default='False')
    except ConfigurationMissingError:
        # The option could not be resolved: treat it as not set.
        return 'None'

    # 'None' needs the quotes because boto3 cannot handle None.
    return 'requester' if requester_pays else 'None'

@staticmethod
def get_session():
Expand Down
29 changes: 10 additions & 19 deletions rastervision_pipeline/rastervision/pipeline/rv_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,14 @@
import json
from typing import Optional, List, Dict

from everett import NO_VALUE
from everett.manager import (ConfigManager, ConfigDictEnv, ConfigIniEnv,
ConfigOSEnv)
ConfigOSEnv, ConfigurationMissingError)

from rastervision.pipeline.verbosity import Verbosity

log = logging.getLogger(__name__)


# Config environment object exposing a get(key, namespace) lookup.
# set_everett_config assigns an instance of it to config_ini_env (instead of
# a ConfigIniEnv) when the profile is 'local'.
# NOTE(review): indentation is missing in this view, so it cannot be
# determined here which `if` the `else` below pairs with -- confirm against
# the real file before relying on the fall-through behavior.
class LocalEnv(object):
def get(self, key, namespace=None):
# Only list-valued namespaces are inspected; `key` itself is not consulted.
if isinstance(namespace, list):
# A namespace list containing 'batch' or 'aws_s3' yields an empty string.
if 'batch' in namespace or 'aws_s3' in namespace:
return ''
else:
# NO_VALUE is the sentinel imported from everett at the top of the file.
return NO_VALUE


def load_conf_list(s):
"""Loads a list of items from the config.
Expand Down Expand Up @@ -187,12 +177,9 @@ def set_everett_config(self,
rv_home = os.path.join(home, '.rastervision')
self.rv_home = rv_home

if self.profile == 'local':
config_ini_env = LocalEnv()
else:
config_file_locations = self._discover_config_file_locations(
self.profile)
config_ini_env = ConfigIniEnv(config_file_locations)
config_file_locations = self._discover_config_file_locations(
self.profile)
config_ini_env = ConfigIniEnv(config_file_locations)

self.config = ConfigManager(
[
Expand Down Expand Up @@ -226,8 +213,12 @@ def get_config_dict(
config_dict = {}
for namespace, keys in rv_config_schema.items():
for key in keys:
config_dict[namespace + '_' + key] = \
self.get_namespace_config(namespace)(key)
try:
config_dict[namespace + '_' + key] = \
self.get_namespace_config(namespace)(key)
except ConfigurationMissingError:
pass

return config_dict

def _discover_config_file_locations(self, profile) -> List[str]:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from rastervision.pipeline.config import register_config
from rastervision.pipeline_example_plugin1.sample_pipeline2 import (MessageMakerConfig,
MessageMaker)
from rastervision.pipeline_example_plugin1.sample_pipeline2 import (
MessageMakerConfig, MessageMaker)


# You always need to use the register_config decorator.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from rastervision.pipeline.file_system import (
sync_to_dir, json_to_file, file_to_json, make_dir, zipdir,
download_if_needed, sync_from_dir, get_local_path, unzip, list_paths,
str_to_file)
str_to_file, FileSystem, LocalFileSystem)
from rastervision.pipeline.utils import terminate_at_exit
from rastervision.pipeline.config import (build_config, ConfigError,
upgrade_config, save_pipeline_config)
Expand Down Expand Up @@ -106,14 +106,14 @@ def __init__(self,
self.test_ds = None
self.test_dl = None

if cfg.output_uri.startswith('s3://'):
if FileSystem.get_file_system(cfg.output_uri) == LocalFileSystem:
self.output_dir = cfg.output_uri
make_dir(self.output_dir)
else:
self.output_dir = get_local_path(cfg.output_uri, tmp_dir)
make_dir(self.output_dir, force_empty=True)
if not cfg.overfit_mode:
self.sync_from_cloud()
else:
self.output_dir = cfg.output_uri
make_dir(self.output_dir)

self.last_model_path = join(self.output_dir, 'last-model.pth')
self.config_path = join(self.output_dir, 'learner-config.json')
Expand Down Expand Up @@ -162,13 +162,11 @@ def main(self):

def sync_to_cloud(self):
    """Sync any output to the cloud at output_uri.

    Passes ``self.output_dir`` and ``self.cfg.output_uri`` to
    ``sync_to_dir``. No URI-scheme check is made here, so the call is not
    restricted to ``s3://`` URIs.
    """
    sync_to_dir(self.output_dir, self.cfg.output_uri)

def sync_from_cloud(self):
    """Sync any previous output in the cloud to output_dir.

    Passes ``self.cfg.output_uri`` and ``self.output_dir`` to
    ``sync_from_dir``. No URI-scheme check is made here, so the call is not
    restricted to ``s3://`` URIs.
    """
    sync_from_dir(self.cfg.output_uri, self.output_dir)

def setup_tensorboard(self):
"""Setup for logging stats to TB."""
Expand Down Expand Up @@ -208,20 +206,13 @@ def unzip_data(self, uri: Union[str, List[str]]) -> List[str]:
Returns:
paths to directories that each contain contents of one zip file
"""
cfg = self.cfg
data_dirs = []

if isinstance(uri, list):
zip_uris = uri
else:
# TODO generalize this to work with any file system
if uri.startswith('s3://') or uri.startswith('/'):
data_uri = uri
else:
data_uri = join(cfg.base_uri, uri)
zip_uris = ([data_uri]
if data_uri.endswith('.zip') else list_paths(
data_uri, 'zip'))
zip_uris = ([uri]
if uri.endswith('.zip') else list_paths(uri, 'zip'))

for zip_ind, zip_uri in enumerate(zip_uris):
zip_path = get_local_path(zip_uri, self.data_cache_dir)
Expand Down
6 changes: 1 addition & 5 deletions scripts/integration_tests
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,5 @@ Run integration tests.
if [ "${1:-}" = "--help" ]; then
    usage
else
    # Run the integration tests once, forwarding every command-line argument.
    # The expansion is quoted so arguments containing whitespace or glob
    # characters reach Python unchanged.
    python -m integration_tests "${@:1}"
fi

0 comments on commit 006937b

Please sign in to comment.