diff --git a/docs/docs/predict-tutorial.md b/docs/docs/predict-tutorial.md index cc0d82e6..11bf3c80 100644 --- a/docs/docs/predict-tutorial.md +++ b/docs/docs/predict-tutorial.md @@ -158,3 +158,21 @@ And there's so much more! You can also do things like specify your region for fa ### 5. Test your configuration with a dry run Before kicking off a full run of inference, we recommend testing your code with a "dry run". This will run one batch of inference to quickly detect any bugs. See the [Debugging](debugging.md) page for details. + + +## Predicting species from images + +Zamba does not currently provide comprehensive support for images by default, only videos. We do, however, have experimental support for making predictions on images using our existing models. This may be useful if you have a few images that you would like to classify or you want to compare the performance on a small set of images. + +To do this, you will need to set the environment variable `PREDICT_ON_IMAGES=True` (for example by prefixing the `zamba` command with it: `PREDICT_ON_IMAGES=True zamba predict --data-dir example_images/`). + +By default, `zamba` will look for files with the following suffixes: `.jpg`, `.jpeg`, `.png`, and `.webp`. To use other image suffixes that are supported by OpenCV, set the `IMAGE_SUFFIXES` environment variable to a comma-separated list of lowercase suffixes, each including the leading dot (for example `IMAGE_SUFFIXES=.jpg,.png,.bmp`); this list replaces the defaults. + +The caveats are: + + - The models may be less accurate since there is less information in a single image than in a video. + - This approach will be computationally inefficient compared to a model that works natively on images. + - Blank / non-blank detection may be less effective since only the classification portion is executed, not the detection portion. + - This is not recommended for training or finetuning scenarios given the computational inefficiency. + +More comprehensive image support is planned for a future release. 
\ No newline at end of file diff --git a/tests/assets/images/chimpanzee_bonobo.jpeg b/tests/assets/images/chimpanzee_bonobo.jpeg new file mode 100644 index 00000000..f65ede22 Binary files /dev/null and b/tests/assets/images/chimpanzee_bonobo.jpeg differ diff --git a/tests/assets/images/equid.webp b/tests/assets/images/equid.webp new file mode 100644 index 00000000..a5c9d94f Binary files /dev/null and b/tests/assets/images/equid.webp differ diff --git a/tests/assets/images/small_cat.jpg b/tests/assets/images/small_cat.jpg new file mode 100644 index 00000000..bd7de714 Binary files /dev/null and b/tests/assets/images/small_cat.jpg differ diff --git a/tests/assets/images/wild_dog_jackal.jpg b/tests/assets/images/wild_dog_jackal.jpg new file mode 100644 index 00000000..858cf5f7 Binary files /dev/null and b/tests/assets/images/wild_dog_jackal.jpg differ diff --git a/tests/test_cli.py b/tests/test_cli.py index 180a8202..ed901ada 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,4 +1,5 @@ import os +from pathlib import Path import shutil from typer.testing import CliRunner @@ -6,6 +7,7 @@ import pytest from pytest_mock import mocker # noqa: F401 +import zamba from zamba.cli import app from conftest import ASSETS_DIR, TEST_VIDEOS_DIR @@ -189,6 +191,44 @@ def test_actual_prediction_on_single_video(tmp_path, model): # noqa: F811 ) +@pytest.mark.parametrize("model", ["time_distributed", "blank_nonblank"]) +def test_actual_prediction_on_images(tmp_path, model, mocker): # noqa: F811 + """Tests experimental feature of predicting on images.""" + shutil.copytree(ASSETS_DIR / "images", tmp_path / "images") + data_dir = tmp_path / "images" + + save_dir = tmp_path / "zamba" + + mocker.patch.object(zamba.models.config, "PREDICT_ON_IMAGES", True) + + result = runner.invoke( + app, + [ + "predict", + "--data-dir", + str(data_dir), + "--yes", + "--save-dir", + str(save_dir), + "--model", + model, + ], + ) + assert result.exit_code == 0 + # check preds file got saved out + 
assert save_dir.exists() + # check config got saved out too + assert (save_dir / "predict_configuration.yaml").exists() + df = pd.read_csv(save_dir / "zamba_predictions.csv", index_col="filepath") + + if model == "time_distributed": + for img, label in df.idxmax(axis=1).items(): + assert Path(img).stem == label + + if model == "blank_nonblank": + assert (df.blank < 0.1).all() + + def test_depth_cli_options(mocker, tmp_path): # noqa: F811 mocker.patch("zamba.models.depth_estimation.config.DepthEstimationConfig.run_model", pred_mock) diff --git a/zamba/data/video.py b/zamba/data/video.py index 813e435c..eeb748c8 100644 --- a/zamba/data/video.py +++ b/zamba/data/video.py @@ -23,6 +23,7 @@ MegadetectorLiteYoloX, MegadetectorLiteYoloXConfig, ) +from zamba.settings import IMAGE_SUFFIXES def ffprobe(path: os.PathLike) -> pd.Series: @@ -414,6 +415,40 @@ def __del__(self): ) +def load_and_repeat_image(path, target_size=(224, 224), repeat_count=4): + """ + Loads an image, resizes it, and repeats it N times. + + Args: + path: Path to the image file. + target_size: A tuple (w, h) representing the desired width and height of the resized image. + repeat_count: Number of times to repeat the image. + + Returns: + A NumPy array of shape (N, h, w, 3) representing the repeated image. 
+ """ + image = cv2.imread(str(path)) + + # Resize the image in same way as video frames are in `load_video_frames` + image = cv2.resize( + image, + target_size, + # https://stackoverflow.com/a/51042104/1692709 + interpolation=( + cv2.INTER_LINEAR + if image.shape[1] < target_size[0] # compare image width with target width + else cv2.INTER_AREA + ), + ) + + image_array = np.expand_dims(image, axis=0) + + # Repeat the image N times + repeated_image = np.repeat(image_array, repeat_count, axis=0) + + return repeated_image + + def load_video_frames( filepath: os.PathLike, config: Optional[VideoLoaderConfig] = None, @@ -421,6 +456,8 @@ def load_video_frames( ): """Loads frames from videos using fast ffmpeg commands. + Supports images as well, but it is inefficient since we just replicate the frames. + Args: filepath (os.PathLike): Path to the video. config (VideoLoaderConfig, optional): Configuration for video loading. @@ -435,6 +472,13 @@ def load_video_frames( if config is None: config = VideoLoaderConfig(**kwargs) + if Path(filepath).suffix.lower() in IMAGE_SUFFIXES: + return load_and_repeat_image( + filepath, + target_size=(config.model_input_width, config.model_input_height), + repeat_count=config.total_frames, + ) + video_stream = get_video_stream(filepath) w = int(video_stream["width"]) h = int(video_stream["height"]) diff --git a/zamba/models/config.py b/zamba/models/config.py index 0c3ba4e4..bd826381 100644 --- a/zamba/models/config.py +++ b/zamba/models/config.py @@ -29,7 +29,7 @@ RegionEnum, ) from zamba.pytorch.transforms import zamba_image_model_transforms, slowfast_transforms -from zamba.settings import SPLIT_SEED, VIDEO_SUFFIXES +from zamba.settings import IMAGE_SUFFIXES, PREDICT_ON_IMAGES, SPLIT_SEED, VIDEO_SUFFIXES GPUS_AVAILABLE = torch.cuda.device_count() @@ -224,11 +224,13 @@ def get_filepaths(cls, values): new_suffixes = [] # iterate over all files in data directory - for f in values["data_dir"].rglob("*"): + for f in 
Path(values["data_dir"]).rglob("*"): if f.is_file(): # keep just files with supported suffixes if f.suffix.lower() in VIDEO_SUFFIXES: files.append(f.resolve()) + elif PREDICT_ON_IMAGES and f.suffix.lower() in IMAGE_SUFFIXES: + files.append(f.resolve()) else: new_suffixes.append(f.suffix.lower()) diff --git a/zamba/settings.py b/zamba/settings.py index 7f59f7d7..8d3b9125 100644 --- a/zamba/settings.py +++ b/zamba/settings.py @@ -9,3 +9,10 @@ # random seed to use for splitting data without site info into train / val / holdout sets SPLIT_SEED = os.environ.get("SPLIT_SEED", 4007) + + +# experimental support for predicting on images +IMAGE_SUFFIXES = [ + ext.strip() for ext in os.environ.get("IMAGE_SUFFIXES", ".jpg,.jpeg,.png,.webp").split(",") +] +PREDICT_ON_IMAGES = os.environ.get("PREDICT_ON_IMAGES", "False").lower() == "true"