Improved Timm models (#109) (#118)
* [minimal working example] HuggingFace hub models now supported (#109)
* Add gitignore
* Add support for TIMM image feature extractors
* docs
* renamed imagenet labels and added in21k labels
* extract_frames(hf): loading pretr weights now;
* extract_frames(hf): implemented show pred for timm models
* utils: test for model_name (should be specified)
* hf.yaml: rm model_name default; style fix
* extract_frames: a note with assumption
* renamed hf to timm
* timm.md: init
* conda_env, install_conda: upd for timm
* test_timm: test timm models
* extract_frames: not all hf models have 'tag'
* rename extract_frames.py to extract_timm.py
* README, index: added timm models

---------

Co-authored-by: Bruno Korbar <[email protected]>
v-iashin and bjuncek authored Jan 25, 2024
1 parent c22f49b commit 896b852
Showing 15 changed files with 22,118 additions and 13 deletions.
12 changes: 6 additions & 6 deletions README.md
@@ -29,6 +29,7 @@ Optical Flow

 Frame-wise Features

+- [All models from TIMM, e.g. ViT, ConvNeXt, EVA, Swin, DINO (ImageNet, LAION, etc.)](https://v-iashin.github.io/video_features/models/timm)
 - [CLIP](https://v-iashin.github.io/video_features/models/clip)
 - [ResNet-18,34,50,101,152 (ImageNet)](https://v-iashin.github.io/video_features/models/resnet)

@@ -85,12 +86,11 @@ On a rare occasion when the collision happens, the script will rewrite previousl

 ## Input
 The inputs are paths to video files.
-Paths can be passed as a list of paths or as a text file formatted with a single path per line.
+Paths can be passed as a list of paths to videos or as a text file formatted with a single path per line.

 ## Output
 Output is defined by the `on_extraction` argument; by default it prints the features to the command line.
-Possible values of output are ['print', 'save_numpy', 'save_pickle']. `save_*` options save the features in
+Possible values of output are `['print', 'save_numpy', 'save_pickle']`. `save_*` options save the features in
 the `output_path` folder with the same name as the input video file but with the `.npy` or `.pkl` extension.
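
For example, features saved with `save_numpy` can be read back with NumPy; a minimal sketch, assuming the default `output_path` and one of the sample videos (the file name follows the convention stated above):

```python
import numpy as np

# a sketch: assumes on_extraction=save_numpy, output_path=./output, and an
# input video named v_ZNVhz7ctTq0.mp4; the name follows the convention above
feats = np.load('./output/v_ZNVhz7ctTq0.npy')
print(feats.shape)  # (num_frames, feature_dim) for frame-wise extractors
```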

## Used in
@@ -103,7 +103,7 @@ Please, let me know if you found this repo useful for your projects or papers.

 ## Acknowledgements

-- [@Kamino666](https://github.com/Kamino666): added CLIP model as well as Windows and CPU support (and
-many other small things).
-- [@borijang](https://github.com/borijang): for solving bugs with file names, I3D checkpoint loading enhancement and code style improvements.
+- [@Kamino666](https://github.com/Kamino666): added CLIP model as well as Windows and CPU support (and many other useful things).
+- [@ohjho](https://github.com/ohjho): added support for 37-layer R(2+1)d flavors.
+- [@borijang](https://github.com/borijang): for solving bugs with file names, I3D checkpoint loading enhancement and code style improvements.
+- [@bjuncek](https://github.com/bjuncek): for helping with timm models and offline discussion.
5 changes: 5 additions & 0 deletions conda_env.yml
@@ -242,3 +242,8 @@ dependencies:
   - yaml=0.2.5=h7b6447c_0
   - zlib=1.2.13=hd590300_5
   - zstd=1.5.5=hc292b87_0
+  - pip:
+      - fsspec==2023.12.2
+      - huggingface-hub==0.20.2
+      - safetensors==0.4.1
+      - timm==0.9.12
22 changes: 22 additions & 0 deletions configs/timm.yml
@@ -0,0 +1,22 @@
# Model
feature_type: 'timm'
model_name: null # any model from `timm.list_pretrained()`
batch_size: 1 # batch size (only frame-wise extractors are supported)
extraction_fps: null # to use the original video fps, leave as 'null' (None)
extraction_total: null # extract a fixed number of frames; mutually exclusive with 'extraction_fps'

# Extraction Parameters
device: 'cuda:0' # device as in `torch`, can be 'cpu'
on_extraction: 'print' # what to do once the features are extracted. Can be ['print', 'save_numpy', 'save_pickle']
output_path: './output' # where to store results if saved
tmp_path: './tmp' # folder to store the temporary files used for extraction (frames or audio files)
keep_tmp_files: false # whether to keep the temporary files after feature extraction
show_pred: false # whether to print the model's predictions on its pretraining dataset (ImageNet 1K/21K)
pred_texts: null # a list of sentences for zero-shot prediction (used by CLIP-style models; not used by timm)

# config
config: null

# Video paths
video_paths: null
file_with_video_paths: null # if the list of videos is long, put one path per line in a txt file and pass its path here
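
For orientation, a minimal OmegaConf sketch of how such a config file can be loaded and patched programmatically; the override values are illustrative, and `main.py` achieves the same via its CLI:

```python
from omegaconf import OmegaConf

# load the defaults above and patch a couple of values, roughly what
# happens when overrides such as model_name=... are passed on the CLI
cfg = OmegaConf.load('configs/timm.yml')
cli = OmegaConf.from_dotlist(['model_name=efficientnet_b0', 'extraction_fps=2'])
cfg = OmegaConf.merge(cfg, cli)
print(cfg.model_name, cfg.batch_size, cfg.device)
```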
7 changes: 4 additions & 3 deletions docs/index.md
@@ -24,6 +24,7 @@ Optical Flow

 Frame-wise Features

+- [All models from TIMM, e.g. ViT, ConvNeXt, EVA, Swin, DINO (ImageNet, LAION, etc.)](https://v-iashin.github.io/video_features/models/timm)
 - [CLIP](https://v-iashin.github.io/video_features/models/clip)
 - [ResNet-18,34,50,101,152 (ImageNet)](https://v-iashin.github.io/video_features/models/resnet)

@@ -87,7 +88,7 @@ Please, let me know if you found this repo useful for your projects or papers.

 ## Acknowledgements

-- [@Kamino666](https://github.com/Kamino666): added CLIP model as well as Windows and CPU support (and
-many other small things).
-- [@borijang](https://github.com/borijang): for solving bugs with file names, I3D checkpoint loading enhancement and code style improvements.
+- [@Kamino666](https://github.com/Kamino666): added CLIP model as well as Windows and CPU support (and many other useful things).
+- [@ohjho](https://github.com/ohjho): added support for 37-layer R(2+1)d flavors.
+- [@borijang](https://github.com/borijang): for solving bugs with file names, I3D checkpoint loading enhancement and code style improvements.
+- [@bjuncek](https://github.com/bjuncek): for helping with timm models and offline discussion.
2 changes: 2 additions & 0 deletions docs/meta/install_conda.md
@@ -7,4 +7,6 @@ conda install -c conda-forge omegaconf scipy tqdm pytest opencv
 conda install -c conda-forge ftfy regex
 # vggish
 conda install -c conda-forge resampy pysoundfile
+# timm models
+pip install timm
 ```
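
A quick way to verify the install from Python (a sketch):

```python
# if this runs without an ImportError, the timm extractor can be used
import timm
print(timm.__version__)
```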
91 changes: 91 additions & 0 deletions docs/models/timm.md
@@ -0,0 +1,91 @@
# timm

`video_features` ❤️ [timm](https://huggingface.co/docs/timm/index).
We support all the models from the `timm` library (technically, those for which `pretrained=True` can be specified).

For details, see the [timm docs](https://huggingface.co/docs/timm/index) and, specifically, the
[model summaries](https://huggingface.co/docs/timm/models) and
[model benchmark results](https://huggingface.co/docs/timm/results).
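
To discover valid `model_name` values, you can query `timm` directly; a small sketch (the filter pattern is only an example):

```python
import timm

# every name returned here can be passed as `model_name`
print(len(timm.list_pretrained()))            # several thousand checkpoints
print(timm.list_pretrained('convnext*')[:5])  # wildcard filter by model family
```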

## Supported Arguments
<!-- the <div> makes columns wider -->
| <div style="width: 12em">Argument</div> | <div style="width: 8em">Default</div> | Description |
| --------------------------------------- | ------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `model_name` | `null` | Any model from `timm.list_pretrained()`, e.g. `efficientnet_b0` or `efficientnet_b0.ra_in1k`. |
| `batch_size` | `1` | You may speed up extraction of features by increasing the batch size as much as your GPU permits. |
| `extraction_fps` | `null` | If specified (e.g. as `5`), the video will be re-encoded to the `extraction_fps` fps. Leave unspecified or `null` to skip re-encoding. |
| `device` | `"cuda:0"` | The device specification. It follows the PyTorch style. Use `"cuda:3"` for the 4th GPU on the machine or `"cpu"` for CPU-only. |
| `video_paths` | `null` | A list of videos for feature extraction. E.g. `"[./sample/v_ZNVhz7ctTq0.mp4, ./sample/v_GGSY1Qvo990.mp4]"` or just one path `"./sample/v_GGSY1Qvo990.mp4"`. |
| `file_with_video_paths` | `null` | A path to a text file with video paths (one path per line). Hint: given a folder `./dataset` with `.mp4` files one could use: `find ./dataset -name "*mp4" > ./video_paths.txt`. |
| `on_extraction` | `print` | If `print`, the features are printed to the terminal. If `save_numpy` or `save_pickle`, the features are saved as a `.npy` or a `.pkl` file. |
| `output_path` | `"./output"` | A path to a folder for storing the extracted features (if `on_extraction` is either `save_numpy` or `save_pickle`). |
| `keep_tmp_files` | `false` | If `true`, the re-encoded videos are kept in `tmp_path`. |
| `tmp_path` | `"./tmp"` | A path to a folder for storing temporary files (e.g. re-encoded videos). |
| `show_pred` | `false` | If `true`, the script prints the predictions of the model on a downstream task. It is useful for debugging. This flag is only supported for models that were trained on ImageNet 1K and 21K. |


## Examples

```bash
python main.py \
feature_type=timm \
model_name=efficientnet_b0 \
device="cuda:0" \
video_paths="[./sample/v_ZNVhz7ctTq0.mp4, ./sample/v_GGSY1Qvo990.mp4]"
```

If you want to use particular pretrained weights, specify them in the `model_name` argument, just as
you would with `timm`, e.g.
```bash
python main.py \
feature_type=timm \
model_name=efficientnet_b0.ra_in1k \
device="cuda:0" \
video_paths="[./sample/v_GGSY1Qvo990.mp4]"
```

If you'd like to check the model's outputs on a downstream task (ImageNet 1K or 21K), you can use the `show_pred` argument.
```bash
python main.py \
feature_type=timm \
model_name=swin_small_patch4_window7_224.ms_in22k \
device="cuda:0" \
extraction_fps=1 \
video_paths="[./sample/v_GGSY1Qvo990.mp4]" \
show_pred=true
# Logits | Prob. | Label
# 12.029 | 0.456 | barbell
# 11.676 | 0.321 | weight, free_weight, exercising_weight
# 9.653 | 0.042 | pusher, thruster
# 9.499 | 0.036 | dumbbell
# 8.787 | 0.018 | bench_press

# Logits | Prob. | Label
# 11.742 | 0.467 | barbell
# 11.233 | 0.281 | weight, free_weight, exercising_weight
# 9.489 | 0.049 | dumbbell
# 8.923 | 0.028 | pusher, thruster
# 8.406 | 0.017 | bench_press

# Logits | Prob. | Label
# 12.257 | 0.571 | barbell
# 11.391 | 0.240 | weight, free_weight, exercising_weight
# 9.708 | 0.045 | dumbbell
# 9.031 | 0.023 | pusher, thruster
# 8.756 | 0.017 | bench_press

# Logits | Prob. | Label
# 12.469 | 0.571 | barbell
# 11.655 | 0.253 | weight, free_weight, exercising_weight
# 9.818 | 0.040 | dumbbell
# 9.648 | 0.034 | pusher, thruster
# 8.527 | 0.011 | bench_press

...
```
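
For intuition, the table above amounts to a softmax over the classifier logits followed by a top-5 label lookup; a self-contained sketch with dummy logits and hypothetical label names:

```python
import torch

# dummy classifier output standing in for class_head(feats) of one frame
logits = torch.randn(1, 1000)
probs = logits.softmax(dim=-1)
top_probs, top_ids = probs.topk(5)
for j in range(5):
    idx = top_ids[0, j]
    print(f'{logits[0, idx]:.3f} | {top_probs[0, j]:.3f} | class_{idx.item()}')
```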

## Credits
* [timm](https://huggingface.co/docs/timm/index) library

## License
`video_features` is released under MIT; `timm` is under [Apache 2.0](https://github.com/huggingface/pytorch-image-models/blob/main/LICENSE).
2 changes: 2 additions & 0 deletions main.py
@@ -32,6 +32,8 @@ def main(args_cli):
         from models.raft.extract_raft import ExtractRAFT as Extractor
     elif args.feature_type == 'clip':
         from models.clip.extract_clip import ExtractCLIP as Extractor
+    elif args.feature_type == 'timm':
+        from models.timm.extract_timm import ExtractTIMM as Extractor
     else:
         raise NotImplementedError(f'Extractor {args.feature_type} is not implemented.')

2 changes: 2 additions & 0 deletions models/_base/base_framewise_extractor.py
@@ -82,7 +82,9 @@ def extract(self, video_path: str) -> Dict[str, np.ndarray]:

     def run_on_a_batch(self, batch: List[torch.Tensor]) -> torch.Tensor:
         model = self.name2module['model']
+        # e.g. for ResNet, batch is (B, C, H, W)
         batch = torch.cat(batch).to(self.device)
+        # the model output is (B, D)
         batch_feats = model(batch)
         self.maybe_show_pred(batch_feats)
         return batch_feats
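
To make the shape contract concrete, a tiny sketch with dummy tensors (no real model involved):

```python
import torch

# each list item is one transformed frame with a leading batch dim of 1
frames = [torch.randn(1, 3, 224, 224) for _ in range(4)]
batch = torch.cat(frames)  # -> (4, 3, 224, 224), i.e. (B, C, H, W)
print(batch.shape)
# a frame-wise model then maps (B, C, H, W) to (B, D) features
```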
2 changes: 1 addition & 1 deletion models/resnet/extract_resnet.py
@@ -56,4 +56,4 @@ def load_model(self) -> Dict[str, torch.nn.Module]:
     def maybe_show_pred(self, feats: torch.Tensor):
         if self.show_pred:
             logits = self.name2module['class_head'](feats)
-            show_predictions_on_dataset(logits, 'imagenet')
+            show_predictions_on_dataset(logits, 'imagenet1k')
1 change: 1 addition & 0 deletions models/timm/__init__.py
@@ -0,0 +1 @@
from .extract_timm import *
91 changes: 91 additions & 0 deletions models/timm/extract_timm.py
@@ -0,0 +1,91 @@

import omegaconf
from typing import Dict, List
import torch
from PIL import Image
from torchvision.transforms import Compose
from models._base.base_framewise_extractor import BaseFrameWiseExtractor
from utils.utils import show_predictions_on_dataset

try:
    import timm
    from timm.data import resolve_data_config
    from timm.data.transforms_factory import create_transform
except ImportError:
    raise ImportError('These features require the `timm` library to be installed.')

class ExtractTIMM(BaseFrameWiseExtractor):

    def __init__(self, args: omegaconf.DictConfig) -> None:
        super().__init__(
            feature_type=args.feature_type,
            on_extraction=args.on_extraction,
            tmp_path=args.tmp_path,
            output_path=args.output_path,
            keep_tmp_files=args.keep_tmp_files,
            device=args.device,
            model_name=args.model_name,
            batch_size=args.batch_size,
            extraction_fps=args.extraction_fps,
            extraction_total=args.extraction_total,
            show_pred=args.show_pred,
        )

        # the transforms are defined in `load_model` as they depend on the checkpoint
        self.transforms = None
        self.name2module = self.load_model()

    def load_model(self) -> Dict[str, torch.nn.Module]:
        """Defines the model, loads the checkpoint and the related transforms,
        and sends the modules to the device.

        Returns:
            Dict[str, torch.nn.Module]: model-agnostic dict holding modules for extraction and show_pred
        """
        model = timm.create_model(self.model_name, pretrained=True)

        # build the input transforms that match the checkpoint's pretrained config
        # (input size, interpolation, mean/std); frames come in as numpy arrays,
        # so they are converted to PIL images first
        self.transforms = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
        self.transforms = Compose([lambda np_array: Image.fromarray(np_array), self.transforms])
        print(self.transforms)

        model.to(self.device)
        model.eval()

        # grab the classifier head, then remove it from the model: with 0 classes,
        # the forward pass returns pooled (pre-logits) features
        class_head = model.get_classifier()
        model.reset_classifier(0)

        # used in `maybe_show_pred` to determine how to show predictions
        self.hf_arch = model.default_cfg['architecture']
        self.hf_tag = model.default_cfg.get('tag', '')

        return {'model': model, 'class_head': class_head}

    def run_on_a_batch(self, batch: List) -> torch.Tensor:
        """This is a hack to make timm models output features.
        Ideally, one would use a model spec in the config file to define the
        behaviour of the forward pass.
        """
        model = self.name2module['model']
        batch = torch.cat(batch).to(self.device)
        batch_feats = model(batch)
        self.maybe_show_pred(batch_feats)
        return batch_feats

    def maybe_show_pred(self, feats: torch.Tensor):
        if self.show_pred:
            logits = self.name2module['class_head'](feats)
            # NOTE: these hardcoded suffixes assume that the end of the tag
            # corresponds to the dataset the classifier was trained on last
            if self.hf_tag.endswith(('in1k', 'in1k_288', 'in1k_320', 'in1k_384', 'in1k_475', 'in1k_512')):
                show_predictions_on_dataset(logits, 'imagenet1k')
            elif self.hf_tag.endswith(('in21k', 'in21k_288', 'in21k_320', 'in21k_384', 'in21k_475', 'in21k_512',
                                       'in22k', 'in22k_288', 'in22k_320', 'in22k_384', 'in22k_475', 'in22k_512')):
                show_predictions_on_dataset(logits, 'imagenet21k')
            else:
                print(f'No show_pred for {self.hf_arch} with tag {self.hf_tag}; use `show_pred=False`')
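
For reference, a minimal sketch of driving this extractor from Python instead of `main.py` (assumes the repository root as the working directory; the model and video names are just examples):

```python
import omegaconf

from models.timm.extract_timm import ExtractTIMM

# load the default config shipped in this commit and patch it
args = omegaconf.OmegaConf.load('configs/timm.yml')
args.model_name = 'efficientnet_b0'
args.video_paths = ['./sample/v_GGSY1Qvo990.mp4']

extractor = ExtractTIMM(args)
feats = extractor.extract(args.video_paths[0])  # Dict[str, np.ndarray]
print({k: v.shape for k, v in feats.items()})
```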
39 changes: 39 additions & 0 deletions tests/timm/test_timm.py
@@ -0,0 +1,39 @@
import sys
from pathlib import Path

import pytest

sys.path.insert(0, '.') # nopep8

from models.timm import ExtractTIMM as Extractor
from tests.utils import base_test_script

# A bit ugly, but this assumes that the feature being tested shares the folder name,
# e.g. for r21d: ./tests/r21d/THIS_FILE.
# It prevents repeating the same tests for different features.
THIS_FILE_PATH = __file__
FEATURE_TYPE = Path(THIS_FILE_PATH).parent.name

# True when run for the first time, then must be False
TO_MAKE_REF = False

signature = 'device, video_paths, model_name, batch_size, extraction_fps, to_make_ref'
test_params = [
    ('cuda:0', './sample/v_GGSY1Qvo990.mp4', 'vit_base_patch16_224.dino', 1, 1, TO_MAKE_REF),
    ('cuda:0', './sample/v_GGSY1Qvo990.mp4', 'coat_tiny.in1k', 1, 1, TO_MAKE_REF),
    ('cuda:0', './sample/v_GGSY1Qvo990.mp4', 'hf-hub:nateraw/resnet50-oxford-iiit-pet', 4, 2, TO_MAKE_REF),
    ('cuda:0', './sample/v_GGSY1Qvo990.mp4', 'mobilenetv3_small_050', 1, None, TO_MAKE_REF),
]


@pytest.mark.parametrize(signature, test_params)
def test(device, video_paths, model_name, batch_size, extraction_fps, to_make_ref):
    # patch the default config with the test parameters
    patch_kwargs = dict(
        device=device,
        video_paths=video_paths,
        model_name=model_name,
        batch_size=batch_size,
        extraction_fps=extraction_fps,
    )
    base_test_script(FEATURE_TYPE, Extractor, to_make_ref, **patch_kwargs)
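
These tests can be run for this module alone; a sketch, equivalent to invoking pytest from the repository root:

```python
import pytest

# same as running `pytest tests/timm/test_timm.py -v` in a shell
raise SystemExit(pytest.main(['tests/timm/test_timm.py', '-v']))
```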
File renamed without changes.