Simplify save_dir and some directory -> dir renames #151

Merged
merged 36 commits into from
Oct 25, 2021
Changes from 11 commits
Commits
73ed0cb
wip renames
ejm714 Oct 23, 2021
5802102
renames in docs
ejm714 Oct 23, 2021
be8478d
readme
ejm714 Oct 23, 2021
1e986d3
data dir renamme in docs
ejm714 Oct 23, 2021
5630332
rename in code from data_directory to data_dir
ejm714 Oct 23, 2021
7f28c59
maintaining update
ejm714 Oct 23, 2021
3e51468
fix capitalization
ejm714 Oct 23, 2021
2b8b282
further updates
ejm714 Oct 23, 2021
3618be5
tweak
ejm714 Oct 23, 2021
49a3029
do not overwrite
ejm714 Oct 23, 2021
bd45a60
add overwrite save dir
ejm714 Oct 23, 2021
bacc14d
add overwrite save dir to config
ejm714 Oct 23, 2021
efb5ec2
update configs with all info
ejm714 Oct 25, 2021
c7578b5
use full train configuration
ejm714 Oct 25, 2021
70d9b1c
only upload if does not exist
ejm714 Oct 25, 2021
72adb39
tests for save
ejm714 Oct 25, 2021
0578a8a
overwrite param
ejm714 Oct 25, 2021
ddd4c78
better set up and test for overwrite
ejm714 Oct 25, 2021
bfbb24f
docs
ejm714 Oct 25, 2021
dc28179
update docs with overwrite
ejm714 Oct 25, 2021
0b4a098
from overwrite_save_dir to overwrite
ejm714 Oct 25, 2021
2515b3f
missed rename
ejm714 Oct 25, 2021
247fc5c
remove machine specific from vlc
ejm714 Oct 25, 2021
2b8f546
unindent so test actually runs
ejm714 Oct 25, 2021
2bbe343
check for local and cached checkpoints
ejm714 Oct 25, 2021
e7fee37
should be and
ejm714 Oct 25, 2021
e6eb0db
write out predict config before preds start like we do for train config
ejm714 Oct 25, 2021
a5680a2
update all configs and use only first 10 digits of hash
ejm714 Oct 25, 2021
1d1bfbb
dry run check after save is configured; more robust test
ejm714 Oct 25, 2021
488c47a
reorder
ejm714 Oct 25, 2021
61d3878
show save directory
ejm714 Oct 25, 2021
e78cea0
copy edits
ejm714 Oct 25, 2021
78270d5
update template
ejm714 Oct 25, 2021
4dd9aae
fix test
ejm714 Oct 25, 2021
2a0a315
lower case for consistency
ejm714 Oct 25, 2021
59322a7
fix test
ejm714 Oct 25, 2021
5 changes: 5 additions & 0 deletions docs/docs/configurations.md
@@ -155,6 +155,7 @@ class PredictConfig(ZambaBaseModel)
batch_size: int = 2,
save: bool = True,
save_dir: Optional[Path] = None,
overwrite: bool = False,
dry_run: bool = False,
proba_threshold: float = None,
output_class_names: bool = False,
@@ -204,6 +205,10 @@ Whether to save out predictions. If `False`, predictions are not saved. Defaults
An optional directory in which to save the model predictions and configuration yaml. If
no `save_dir` is specified and `save` is True, outputs will be written to the current working directory. Defaults to `None`

#### `overwrite (bool)`

If `True`, overwrites `zamba_predictions.csv` and `predict_configuration.yaml` in `save_dir` if they exist. Defaults to `False`.

#### `dry_run (bool, optional)`

Specifying `True` is useful for trying out model implementations more quickly by running only a single batch of inference. Defaults to `False`
2 changes: 2 additions & 0 deletions docs/docs/quickstart.md
@@ -194,6 +194,8 @@ Options:
loaded prior to inference. Only use if
you're very confident all your videos can be
loaded.
-o, --overwrite Overwrite outputs in the save directory if
they exist.
-y, --yes Skip confirmation of configuration and
proceed right to prediction.
--help Show this message and exit.
8 changes: 8 additions & 0 deletions tests/test_cli.py
@@ -145,6 +145,14 @@ def test_predict_specific_options(mocker, minimum_valid_predict, tmp_path): # n
)
assert result.exit_code == 0

# test save overwrite
(tmp_path / "zamba_predictions.csv").touch()
result = runner.invoke(
app,
minimum_valid_predict + ["--output-class-names", "--save-dir", str(tmp_path), "-o"],
)
assert result.exit_code == 0


def test_actual_prediction_on_single_video(tmp_path): # noqa: F811
data_dir = tmp_path / "videos"
37 changes: 37 additions & 0 deletions tests/test_config.py
@@ -271,8 +271,45 @@ def test_predict_save(labels_absolute_path, tmp_path, dummy_trained_model_checkp
skip_load_validation=True,
)
assert config.save is True
# save dir gets created
assert (tmp_path / "my_dir").exists()

# empty save dir does not error
save_dir = tmp_path / "save_dir"
save_dir.mkdir()

config = PredictConfig(
filepaths=labels_absolute_path, save_dir=save_dir, skip_load_validation=True
)
assert config.save_dir == save_dir

# save dir with prediction csv or yaml will error
for pred_file in [
(save_dir / "zamba_predictions.csv"),
(save_dir / "predict_configuration.yaml"),
]:
with pytest.raises(ValueError) as error:
# just takes one of the two files to raise error
pred_file.touch()
_ = PredictConfig(
filepaths=labels_absolute_path, save_dir=save_dir, skip_load_validation=True
)
assert (
f"zamba_predictions.csv and/or predict_configuration.yaml already exist in {save_dir}. If you would like to overwrite, set overwrite=True"
== error.value.errors()[0]["msg"]
)
pred_file.unlink()

# can overwrite
pred_file.touch()
config = PredictConfig(
filepaths=labels_absolute_path,
save_dir=save_dir,
skip_load_validation=True,
overwrite=True,
)
assert config.save_dir == save_dir


def test_validate_scheduler(labels_absolute_path):
# None gets transformed into SchedulerConfig
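The scrape drops indentation, so the shape of the overwrite test is hard to read from the hunk. The sketch below shows the same pattern with the standard library only: touch one output file, expect an error, confirm that `overwrite=True` passes, then clean up before trying the next file. `guard_outputs` and `check_overwrite_guard` are hypothetical stand-ins, not zamba functions, and the error text is shortened; the message is built from the directory the test created rather than from a fixed path.

```python
import tempfile
from pathlib import Path

# File names taken from the test in this hunk.
OUTPUT_FILES = ("zamba_predictions.csv", "predict_configuration.yaml")


def guard_outputs(save_dir: Path, overwrite: bool = False) -> None:
    """Hypothetical stand-in for the config validation under test."""
    if any((save_dir / name).exists() for name in OUTPUT_FILES) and not overwrite:
        raise ValueError(f"outputs already exist in {save_dir}")


def check_overwrite_guard(save_dir: Path) -> list:
    """Exercise each protected file on its own and return the messages raised."""
    messages = []
    for name in OUTPUT_FILES:
        pred_file = save_dir / name
        pred_file.touch()  # either file alone should trigger the error
        try:
            guard_outputs(save_dir)
        except ValueError as error:
            messages.append(str(error))
        else:
            raise AssertionError(f"{name} did not trigger the guard")
        guard_outputs(save_dir, overwrite=True)  # overwrite=True must not raise
        pred_file.unlink()  # clean up so the next file is tested alone
    return messages


with tempfile.TemporaryDirectory() as tmp:
    messages = check_overwrite_guard(Path(tmp))
```

Keeping the raising call as the only statement that is expected to fail makes it clear which line the error must come from, and unlinking the file at the end of each iteration keeps the two cases independent.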
6 changes: 6 additions & 0 deletions zamba/cli.py
@@ -239,6 +239,9 @@ def predict(
None,
help="Skip check that verifies all videos can be loaded prior to inference. Only use if you're very confident all your videos can be loaded.",
),
overwrite: bool = typer.Option(
None, "--overwrite", "-o", help="Overwrite outputs in the save directory if they exist."
),
yes: bool = typer.Option(
False,
"--yes",
@@ -314,6 +317,9 @@ def predict(
if skip_load_validation is not None:
predict_dict["skip_load_validation"] = skip_load_validation

if overwrite is not None:
predict_dict["overwrite"] = overwrite

try:
manager = ModelManager(
ModelConfig(
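The pattern in this hunk, where an option defaults to `None` and is copied into `predict_dict` only when the user actually passed it, can be sketched without typer. The helper name `merge_cli_overrides` and the sample values are hypothetical and not part of zamba; the point is only that an unset flag leaves the existing value (here `overwrite: False`) untouched.

```python
from typing import Any, Dict, Optional


def merge_cli_overrides(base: Dict[str, Any], **cli_options: Optional[Any]) -> Dict[str, Any]:
    """Return a copy of `base` updated with only the CLI options that were set.

    Hypothetical helper: an option left at None means "not passed on the
    command line", so the value already in `base` is kept.
    """
    merged = dict(base)
    for name, value in cli_options.items():
        if value is not None:
            merged[name] = value
    return merged


# Values that came from elsewhere (for example a config file); illustrative only.
predict_dict = {"save_dir": "outputs", "overwrite": False}

# Flag not passed: the existing value survives.
unchanged = merge_cli_overrides(predict_dict, overwrite=None)

# `-o` passed: the CLI value wins.
overridden = merge_cli_overrides(predict_dict, overwrite=True)
```

Using `None` rather than `False` as the option default is what lets the code tell "flag omitted" apart from "flag explicitly off".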
51 changes: 16 additions & 35 deletions zamba/models/config.py
@@ -522,35 +522,6 @@ def preprocess_labels(cls, values):
values["labels"] = labels.reset_index()
return values

def get_model_only_params(self):
"""Return only params that are not data or machine specific.
Used for generating official configs.
"""
train_config = self.dict()

# remove data and machine specific params
for key in [
"labels",
"data_dir",
"dry_run",
"batch_size",
"auto_lr_find",
"gpus",
"num_workers",
"max_epochs",
"weight_download_region",
"split_proportions",
"save_dir",
"overwrite_save_dir",
"skip_load_validation",
"from_scratch",
"model_cache_dir",
"predict_all_zamba_species",
]:
train_config.pop(key)

return train_config


class PredictConfig(ZambaBaseModel):
"""
@@ -580,6 +551,8 @@ class PredictConfig(ZambaBaseModel):
predictions and configuration yaml.
If no save_dir is specified and save=True, outputs will be written to
the current working directory. Defaults to None.
overwrite (bool): If True, overwrite outputs in save_dir if they exist.
Defaults to False.
dry_run (bool): Perform inference on a single batch for testing. Predictions
will not be saved. Defaults to False.
proba_threshold (float, optional): Probability threshold for classification.
@@ -611,6 +584,7 @@ class PredictConfig(ZambaBaseModel):
batch_size: int = 2
save: bool = True
save_dir: Optional[Path] = None
overwrite: bool = False
dry_run: bool = False
proba_threshold: Optional[float] = None
output_class_names: bool = False
@@ -638,16 +612,23 @@ def validate_save_dir(cls, values):
         save = values["save"]

         # if no save_dir but save is True, use current working directory
-        if save_dir is None:
-            if save:
-                save_dir = Path.cwd()
+        if save_dir is None and save:
+            save_dir = Path.cwd()

+        if save_dir is not None:
+            # check if files exist
+            if (
+                (save_dir / "zamba_predictions.csv").exists()
+                or (save_dir / "predict_configuration.yaml").exists()
+            ) and not values["overwrite"]:
+                raise ValueError(
+                    f"zamba_predictions.csv and/or predict_configuration.yaml already exist in {save_dir}. If you would like to overwrite, set overwrite=True"
+                )
+
-        # if save dir is not None
-        else:
             # make a directory if needed
             save_dir.mkdir(parents=True, exist_ok=True)

-            # set set to True is save_dir is set (save dir takes precedence)
+            # set save to True if save_dir is set (save dir takes precedence)
             if not save:
                 save = True

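Read as a whole, the updated `validate_save_dir` logic resolves the directory first, then refuses to clobber existing outputs, then creates the directory. A minimal standalone sketch of that flow follows, assuming plain function arguments in place of the validator's `values` dict; `resolve_save_dir` and `OUTPUT_FILES` are hypothetical names, not zamba's API.

```python
from pathlib import Path
from typing import Optional, Tuple

# Output file names taken from the error message in the diff.
OUTPUT_FILES = ("zamba_predictions.csv", "predict_configuration.yaml")


def resolve_save_dir(
    save_dir: Optional[Path], save: bool, overwrite: bool = False
) -> Tuple[Optional[Path], bool]:
    """Standalone sketch of the save_dir validation (not zamba's API)."""
    # No save_dir but save is True: fall back to the current working directory.
    if save_dir is None and save:
        save_dir = Path.cwd()

    if save_dir is not None:
        # Refuse to clobber existing outputs unless overwrite is set.
        if any((save_dir / name).exists() for name in OUTPUT_FILES) and not overwrite:
            raise ValueError(
                f"zamba_predictions.csv and/or predict_configuration.yaml already exist in {save_dir}. "
                "If you would like to overwrite, set overwrite=True"
            )
        # Create the directory if needed.
        save_dir.mkdir(parents=True, exist_ok=True)
        # A save_dir implies saving, even if save=False was passed.
        save = True

    return save_dir, save
```

Checking for existing files before calling `mkdir` means a failed validation leaves the filesystem untouched; a directory that does not exist yet simply has no outputs to protect.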
30 changes: 22 additions & 8 deletions zamba/models/official_models/european/config.yaml
@@ -1,6 +1,4 @@
train_config:
scheduler_config: default
model_name: european
backbone_finetune_config:
backbone_initial_ratio_lr: 0.01
multiplier: 1
@@ -9,21 +7,37 @@ train_config:
unfreeze_backbone_at_epoch: 15
verbose: true
early_stopping_config:
mode: max
monitor: val_macro_f1
patience: 3
verbose: true
mode: max
model_name: european
scheduler_config: default
video_loader_config:
model_input_height: 240
model_input_width: 426
crop_bottom_pixels: 50
fps: 4
total_frames: 16
early_bias: false
ensure_total_frames: true
evenly_sample_total_frames: false
fps: 4.0
frame_indices: null
frame_selection_height: null
frame_selection_width: null
i_frames: false
megadetector_lite_config:
confidence: 0.25
device: cuda
fill_mode: score_sorted
image_height: 416
image_width: 416
n_frames: 16
nms_threshold: 0.45
seed: 55
sort_by_time: true
model_input_height: 240
model_input_width: 426
pix_fmt: rgb24
scene_threshold: null
total_frames: 16
predict_config:
model_name: european
public_checkpoint: european_0c69da8a888c499411deaa040a91d76546ddf78a.ckpt
public_checkpoint: european_01cba5835bdfd9948ae392b16d0b38f12578a54c.ckpt
Original file line number Diff line number Diff line change
@@ -23,7 +23,6 @@ train_config:
unfreeze_backbone_at_epoch: 15
verbose: true
batch_size: 2
cache_dir: /home/ubuntu/.cache/zamba
checkpoint: data/results/experiments/td_dev_set_full_size_mdlite/results/version_1/time_distributed_zamba.ckpt
data_dir: /home/ubuntu/zamba-algorithms
dry_run: false
@@ -35,6 +34,7 @@ train_config:
from_scratch: false
gpus: 1
max_epochs: null
model_cache_dir: /home/ubuntu/.cache/zamba
model_name: null
num_workers: 3
overwrite_save_dir: false
45 changes: 32 additions & 13 deletions zamba/models/official_models/slowfast/config.yaml
@@ -1,12 +1,4 @@
train_config:
scheduler_config:
scheduler: MultiStepLR
scheduler_params:
gamma: 0.5
milestones:
- 1
verbose: true
model_name: slowfast
backbone_finetune_config:
backbone_initial_ratio_lr: 0.01
multiplier: 10
@@ -15,18 +7,45 @@ train_config:
unfreeze_backbone_at_epoch: 3
verbose: true
early_stopping_config:
mode: max
monitor: val_macro_f1
patience: 5
verbose: true
model_name: slowfast
scheduler_config:
scheduler: MultiStepLR
scheduler_params:
gamma: 0.5
milestones:
- 1
verbose: true
video_loader_config:
model_input_height: 240
model_input_width: 426
cache_dir: /tmp/zamba_cache
cleanup_cache: false
crop_bottom_pixels: 50
fps: 8
total_frames: 32
early_bias: false
ensure_total_frames: true
evenly_sample_total_frames: false
fps: 8.0
frame_indices: null
frame_selection_height: null
frame_selection_width: null
i_frames: false
megadetector_lite_config:
confidence: 0.25
device: cuda
fill_mode: score_sorted
image_height: 416
image_width: 416
n_frames: 32
nms_threshold: 0.45
seed: 55
sort_by_time: true
model_input_height: 240
model_input_width: 426
pix_fmt: rgb24
scene_threshold: null
total_frames: 32
predict_config:
model_name: slowfast
public_checkpoint: slowfast_501182d969aabf49805829a2b09ed8078b4255a3.ckpt
public_checkpoint: slowfast_23ad8c41beb9c86eef8fd0f0dc02a35a7807e4c6.ckpt
43 changes: 30 additions & 13 deletions zamba/models/official_models/time_distributed/config.yaml
@@ -1,12 +1,4 @@
train_config:
scheduler_config:
scheduler: MultiStepLR
scheduler_params:
gamma: 0.5
milestones:
- 3
verbose: true
model_name: time_distributed
backbone_finetune_config:
backbone_initial_ratio_lr: 0.01
multiplier: 1
@@ -15,18 +7,43 @@ train_config:
unfreeze_backbone_at_epoch: 3
verbose: true
early_stopping_config:
mode: max
monitor: val_macro_f1
patience: 5
verbose: true
model_name: time_distributed
scheduler_config:
scheduler: MultiStepLR
scheduler_params:
gamma: 0.5
milestones:
- 3
verbose: true
video_loader_config:
model_input_height: 240
model_input_width: 426
crop_bottom_pixels: 50
fps: 4
total_frames: 16
early_bias: false
ensure_total_frames: true
evenly_sample_total_frames: false
fps: 4.0
frame_indices: null
frame_selection_height: null
frame_selection_width: null
i_frames: false
megadetector_lite_config:
confidence: 0.25
device: cuda
fill_mode: score_sorted
image_height: 416
image_width: 416
n_frames: 16
nms_threshold: 0.45
seed: 55
sort_by_time: true
model_input_height: 240
model_input_width: 426
pix_fmt: rgb24
scene_threshold: null
total_frames: 16
predict_config:
model_name: time_distributed
public_checkpoint: time_distributed_9e710aa8c92d25190a64b3b04b9122bdcb456982.ckpt
public_checkpoint: time_distributed_c68adb0e9eb4a6cf49b83e86f03dea55b0edc77f.ckpt