Merge pull request #399 from microsoft/PyTorchWildlife_prerelease
Pytorch-Wildlife version 1.0.2
zhmiao authored Feb 15, 2024
2 parents 14529bc + 05d9835 commit 832f123
Showing 34 changed files with 1,703 additions and 4 deletions.
21 changes: 21 additions & 0 deletions PT_FT_classification/LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) [2023] [Microsoft]

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
127 changes: 127 additions & 0 deletions PT_FT_classification/README.md
@@ -0,0 +1,127 @@
## Welcome to the classification finetuning module for Pytorch-Wildlife!

This repository focuses on training classification models for Pytorch-Wildlife. This module is designed to help both programmers and biologists train a classification model for animal identification. The output weights of this training codebase can be easily integrated in our [Pytorch-Wildlife](https://github.com/microsoft/CameraTraps/) framework. Our goal is to make this tool accessible and easy to use, regardless of your background in machine learning or programming.


## Installation

Before you start, ensure you have Python installed on your machine. This project is developed using Python 3.8. If you're not sure how to install Python, you can find detailed instructions on the [official Python website](https://www.python.org/).

To install the required libraries and dependencies, follow these steps:

1. Make sure you are in the `PT_FT_classification` directory.

2. Install the required packages using one of the two methods below.

### Using pip and `requirements.txt`
```bash
pip install -r requirements.txt
```

### Using conda and `environment.yaml`

Create and activate a Conda environment with Python 3.8:

```bash
conda env create -f environment.yaml
conda activate PT_Finetuning
```

### Usage

**Data Structure:**

This codebase has been optimized to facilitate its use for non-technical users. To ensure the code works correctly, your images should be stored in a single directory with no nested directories. The `annotations.csv` file, containing image paths and their classification IDs and labels, should be placed outside of the images directory. Image paths in the CSV should be relative to the position of the `annotations.csv` file.

Example directory structure:

```
PT_FT_classification/
├── data/
│ ├── imgs/ # All images stored here
│ └── annotation_example.csv # Annotations file
└── configs/config.yaml # Configuration file
```
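As a rough illustration of the path rule described above, the sketch below resolves each `path` entry against the folder that holds the annotation file. The helper name `resolve_image_paths` is hypothetical and not part of this codebase; the only assumption is that the CSV has a `path` column.

```python
import csv
from pathlib import Path


def resolve_image_paths(annotation_csv):
    """Return image paths resolved relative to the annotation file's folder.

    Hypothetical helper for illustration; not part of PT_FT_classification.
    """
    csv_path = Path(annotation_csv)
    base_dir = csv_path.parent  # paths in the CSV are relative to this folder
    with csv_path.open(newline="") as f:
        return [base_dir / row["path"] for row in csv.DictReader(f)]
```

Checking `Path.exists()` on each returned path before training is a cheap way to catch a misplaced annotation file.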

**Annotation file structure**

To ensure the code works correctly, your annotation file should contain the following columns:

1. `path`: relative path to the image file
2. `classification`: unique identifier for each class (e.g., 0, 1, 2, etc.)
3. `label`: name of the class (e.g., "cat", "dog", "bird", etc.)
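A quick header check can catch a malformed annotation file before training starts. The following is a minimal sketch using only the standard library; `check_annotation_columns` and `REQUIRED_COLUMNS` are hypothetical names introduced here, not functions provided by this module.

```python
import csv

# Column names taken from the list above.
REQUIRED_COLUMNS = {"path", "classification", "label"}


def check_annotation_columns(annotation_csv):
    """Return the CSV header, or raise ValueError if a required column is missing."""
    with open(annotation_csv, newline="") as f:
        header = next(csv.reader(f), [])  # empty file -> empty header
    missing = REQUIRED_COLUMNS - set(header)
    if missing:
        raise ValueError(f"Missing columns: {sorted(missing)}")
    return header
```

Extra columns such as `Photo_Time` or `Location` pass this check, since only the three required names are tested.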

**Data splitting**

If you want to split your data into training, validation, and test sets, use the `split_path` and `split_data` parameters in the `config.yaml` file. The `split_path` parameter should point to a CSV file containing the image paths and their corresponding classification IDs and labels, and `split_data` should be set to `True`.

Currently, Pytorch-Wildlife classification supports three types of data splitting: `random`, `location`, and `sequence`. Random splitting uses the class ID to randomly split the data into training, validation, and test sets while keeping a balanced class distribution. **Camera traps commonly capture a burst of pictures when movement is detected, so random splitting is not recommended: similar-looking images of the same animal could end up in both the training and validation sets, leading to overfitting that the validation results do not reveal.**

Location splitting requires an additional `Location` column in the data. It splits the data by image location so that all images from one location end up in a single split; this method does not guarantee a balanced class distribution. Sequence splitting requires a `Photo_Time` column containing the time each picture was taken, in `YYYY-MM-DD HH:MM:SS` format. This method groups images taken within a 30-second period into a "sequence" and then splits the data based on these sequences; it does not guarantee a balanced class distribution either.
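To make the idea of a "sequence" concrete, here is a minimal sketch of one possible grouping rule. It is not the module's implementation: the function name `group_into_sequences` is hypothetical, and the rule used (a new sequence starts when an image is taken more than 30 seconds after the first image of the current sequence) is an assumption, since the text above does not define the window precisely. The sketch also ignores location.

```python
from datetime import datetime, timedelta


def group_into_sequences(timestamps, window_seconds=30):
    """Assign a sequence index to each 'YYYY-MM-DD HH:MM:SS' timestamp.

    Assumption: a new sequence starts when an image is taken more than
    `window_seconds` after the first image of the current sequence.
    """
    parsed = [datetime.strptime(t, "%Y-%m-%d %H:%M:%S") for t in timestamps]
    # Visit images in chronological order, but keep results in input order.
    order = sorted(range(len(parsed)), key=lambda i: parsed[i])
    window = timedelta(seconds=window_seconds)
    sequence_ids = [0] * len(parsed)
    current_id, start = -1, None
    for i in order:
        if start is None or parsed[i] - start > window:
            current_id += 1
            start = parsed[i]
        sequence_ids[i] = current_id
    return sequence_ids
```

All images that share a sequence index would then be assigned to the same split, which keeps a burst of near-identical frames from being divided between training and validation.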


The CSV file should have the previously mentioned structure. The code will then split the data into training, validation, and test sets based on the splitting type and the proportions specified in the `config.yaml` file. [The annotation example](data/imgs/annotation_example.csv) shows how files should be annotated for each type of splitting.
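The sketch below shows one way a location-based split could honor the `test_size` and `val_size` proportions. It is an illustration only: `split_by_location` is a hypothetical name, and applying the proportions to the number of distinct locations (rather than the number of images) is an assumption of this sketch, not documented behavior.

```python
import random


def split_by_location(locations, test_size=0.2, val_size=0.2, seed=0):
    """Map each distinct location to 'train', 'val', or 'test'.

    Assumption: proportions apply to the count of locations, not images.
    """
    unique = sorted(set(locations))
    random.Random(seed).shuffle(unique)  # seeded for reproducibility
    n_test = round(len(unique) * test_size)
    n_val = round(len(unique) * val_size)
    assignment = {}
    for i, loc in enumerate(unique):
        if i < n_test:
            assignment[loc] = "test"
        elif i < n_test + n_val:
            assignment[loc] = "val"
        else:
            assignment[loc] = "train"
    return assignment
```

Because whole locations are assigned at once, the per-class balance of each split depends on which animals appear at which sites.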

If you don't require data splitting, you can set the `split_data` parameter to `False` in the `config.yaml` file.

### Configuration

Before training your model, you need to configure the training and data parameters in the `config.yaml` file. Here's a brief explanation of the parameters to help both technical and non-technical users understand their purposes:

- **Training Parameters:**
  - `conf_id`: A unique identifier for your training configuration.
  - `algorithm`: The training algorithm to use. Default is "Plain".
  - `log_dir`: Directory where training logs are saved.
  - `num_epochs`: Total number of training epochs.
  - `log_interval`: How often to log training information.
  - `parallel`: Set to 1 to enable parallel computing (if supported by your hardware).

- **Data Parameters:**
  - `dataset_root`: The root directory where your images are stored.
  - `dataset_name`: Name of the dataset. `Custom_Crop` is required for finetuning.
  - `annotation_dir`: Directory where annotation files are located.
  - `split_path`: Path to the single CSV file containing the annotations. It is used for data splitting.
  - `test_size`: Proportion of data to use as the test set.
  - `val_size`: Proportion of data to use as the validation set.
  - `split_data`: Set to `True` if you want the code to split your data into training, validation, and test sets using the `split_path`.
  - `split_type`: Type of data splitting. It can be "random", "location", or "sequence".
  - `batch_size`: Number of images to process in a batch.
  - `num_workers`: Number of subprocesses to use for data loading.

- **Model Parameters:**
  - `model_name`: The name of the model architecture to use. The current version only supports `PlainResNetClassifier`.
  - `num_layers`: Number of layers in the model. Currently only 18 and 50 are supported.
  - `weights_init`: Initial weights setting for the model. Currently only "ImageNet" is supported.

- **Optimization Parameters:**
  - `lr_feature`, `momentum_feature`, `weight_decay_feature`: Learning rate, momentum, and weight decay for the feature extractor.
  - `lr_classifier`, `momentum_classifier`, `weight_decay_classifier`: Learning rate, momentum, and weight decay for the classifier.
  - `step_size`, `gamma`: Parameters for the learning rate scheduler.
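To give a feel for how `step_size` and `gamma` interact, the sketch below computes the learning rate under a step-decay schedule, where the rate is multiplied by `gamma` every `step_size` epochs. This assumes the parameters behave as in a step-decay scheduler such as PyTorch's `torch.optim.lr_scheduler.StepLR`; the text above does not state which scheduler the code uses, and `stepped_lr` is a hypothetical helper.

```python
def stepped_lr(base_lr, epoch, step_size, gamma):
    """Learning rate at a given epoch under a step-decay schedule (assumed)."""
    # Integer division counts how many full decay steps have elapsed.
    return base_lr * gamma ** (epoch // step_size)
```

Under this assumption, a base rate of 0.01 with `step_size: 20` and `gamma: 0.1` stays at 0.01 for epochs 0 to 19, then drops by a factor of ten at epoch 20 and again at epoch 40.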


## Usage

### Currently, our support is limited to the ResNet architecture. You are encouraged to explore other architectures, but it's important to maintain consistency with our code structure (particularly, an independent feature extractor and a classifier) for compatibility with the PyTorch-Wildlife framework.

After configuring your `config.yaml` file, you can start training your model by running:

```bash
python main.py
```

This command will initiate the training process based on the parameters specified in `config.yaml`. Make sure to monitor the output for any errors or important information regarding the training progress.

### We have provided 10 example images and an annotation file in the `data` directory so you can test the code without providing your own data.

## Output
Once training is complete, the output weights will be saved in the `weights` directory. These weights can be used to classify new images using the [Pytorch-Wildlife](https://github.com/microsoft/CameraTraps/) framework.

### We are working on adding a feature in a future release to directly integrate the output weights with the Pytorch-Wildlife framework and the Gradio App.

## License

This project is licensed under the MIT License - see the LICENSE file for details.

## Support

If you encounter any issues or have questions, please feel free to open an issue on the GitHub repository page. We aim to make this tool as accessible as possible and will gladly provide assistance.

Thank you!
1 change: 1 addition & 0 deletions PT_FT_classification/__init__.py
@@ -0,0 +1 @@
from batch_detection_cropping import *
40 changes: 40 additions & 0 deletions PT_FT_classification/configs/Raw/Crop_res50_plain_082723.yaml
@@ -0,0 +1,40 @@
# training
conf_id: Crop_Res50_plain_082723
algorithm: Plain
log_dir: Crop
num_epochs: 60
log_interval: 10
parallel: 0

# data
dataset_root: ./data/imgs
dataset_name: Custom_Crop
# annotation directory (if you have train/val/test splits)
annotation_dir: ./data/imgs
# data splitting (if you don't have train/val/test splits)
split_path: ./data/imgs/annotation_example.csv
test_size: 0.2
val_size: 0.2
split_data: True
split_type: location # options are: random, location, sequence
# data loading
batch_size: 256
num_workers: 0 #40
# model
model_name: PlainResNetClassifier
num_layers: 50
weights_init: ImageNet

# optim
## feature
lr_feature: 0.01
momentum_feature: 0.9
weight_decay_feature: 0.0005
## classifier
lr_classifier: 0.01
momentum_classifier: 0.9
weight_decay_classifier: 0.0005
## lr_scheduler
step_size: 20
gamma: 0.1

Binary file added PT_FT_classification/data/imgs/10050028_0.JPG
Binary file added PT_FT_classification/data/imgs/10050028_1.JPG
Binary file added PT_FT_classification/data/imgs/10050028_2.JPG
Binary file added PT_FT_classification/data/imgs/10050028_3.JPG
Binary file added PT_FT_classification/data/imgs/10050028_4.JPG
Binary file added PT_FT_classification/data/imgs/10050028_5.JPG
Binary file added PT_FT_classification/data/imgs/10050028_6.JPG
Binary file added PT_FT_classification/data/imgs/10050028_7.JPG
Binary file added PT_FT_classification/data/imgs/10050028_8.JPG
Binary file added PT_FT_classification/data/imgs/10050028_9.JPG
11 changes: 11 additions & 0 deletions PT_FT_classification/data/imgs/annotation_example.csv
@@ -0,0 +1,11 @@
path,classification,label,Photo_Time,Location
10050028_0.JPG,0,Anteater,2024-01-27 21:30:11,Amazon_Guaviare
10050028_1.JPG,0,Anteater,2024-01-27 21:30:13,Amazon_Guaviare
10050028_2.JPG,0,Anteater,2024-01-27 21:30:15,Amazon_Guaviare
10050028_3.JPG,1,Non-Anteater,2024-01-27 22:10:11,Amazon_Caqueta
10050028_4.JPG,1,Non-Anteater,2024-01-27 22:10:11,Amazon_Caqueta
10050028_5.JPG,1,Non-Anteater,2024-01-27 22:10:11,Amazon_Caqueta
10050028_6.JPG,1,Non-Anteater,2024-01-28 01:54:43,Amazon_Guainia
10050028_7.JPG,1,Non-Anteater,2024-01-28 01:55:13,Amazon_Guainia
10050028_8.JPG,0,Anteater,2024-01-28 21:30:11,Amazon_Guainia
10050028_9.JPG,0,Anteater,2024-01-28 21:30:11,Amazon_Guainia
155 changes: 155 additions & 0 deletions PT_FT_classification/environment.yaml
@@ -0,0 +1,155 @@
name: PT_Finetuning
channels:
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_gnu
- bzip2=1.0.8=hd590300_5
- ca-certificates=2023.11.17=hbcca054_0
- ld_impl_linux-64=2.40=h41732ed_0
- libffi=3.4.2=h7f98852_5
- libgcc-ng=13.2.0=h807b86a_4
- libgomp=13.2.0=h807b86a_4
- libnsl=2.0.1=hd590300_0
- libsqlite=3.44.2=h2797004_0
- libuuid=2.38.1=h0b41bf4_0
- libxcrypt=4.4.36=hd590300_1
- libzlib=1.2.13=hd590300_5
- ncurses=6.4=h59595ed_2
- openssl=3.2.0=hd590300_1
- pip=23.3.2=pyhd8ed1ab_0
- python=3.8.18=hd12c33a_1_cpython
- readline=8.2=h8228510_1
- setuptools=69.0.3=pyhd8ed1ab_0
- tk=8.6.13=noxft_h4845f30_101
- wheel=0.42.0=pyhd8ed1ab_0
- xz=5.2.6=h166bdaf_0
- pip:
    - absl-py==2.1.0
    - aiofiles==23.2.1
    - aiohttp==3.9.3
    - aiosignal==1.3.1
    - altair==5.2.0
    - annotated-types==0.6.0
    - anyio==4.2.0
    - asttokens==2.4.1
    - async-timeout==4.0.3
    - attrs==23.2.0
    - backcall==0.2.0
    - cachetools==5.3.2
    - certifi==2023.11.17
    - charset-normalizer==3.3.2
    - click==8.1.7
    - colorama==0.4.6
    - contourpy==1.1.1
    - cycler==0.12.1
    - decorator==5.1.1
    - exceptiongroup==1.2.0
    - executing==2.0.1
    - fastapi==0.109.0
    - ffmpy==0.3.1
    - filelock==3.13.1
    - fire==0.5.0
    - fonttools==4.47.2
    - frozenlist==1.4.1
    - fsspec==2023.12.2
    - google-auth==2.27.0
    - google-auth-oauthlib==1.0.0
    - gradio==4.8.0
    - gradio-client==0.7.1
    - grpcio==1.60.0
    - h11==0.14.0
    - httpcore==1.0.2
    - httpx==0.26.0
    - huggingface-hub==0.20.3
    - idna==3.6
    - importlib-metadata==7.0.1
    - importlib-resources==6.1.1
    - ipython==8.12.3
    - jedi==0.19.1
    - jinja2==3.1.3
    - joblib==1.3.2
    - jsonschema==4.21.1
    - jsonschema-specifications==2023.12.1
    - kiwisolver==1.4.5
    - lightning-utilities==0.10.1
    - markdown==3.5.2
    - markdown-it-py==3.0.0
    - markupsafe==2.1.4
    - matplotlib==3.7.4
    - matplotlib-inline==0.1.6
    - mdurl==0.1.2
    - multidict==6.0.4
    - munch==2.5.0
    - numpy==1.24.4
    - oauthlib==3.2.2
    - opencv-python==4.9.0.80
    - opencv-python-headless==4.9.0.80
    - orjson==3.9.12
    - packaging==23.2
    - pandas==2.0.3
    - parso==0.8.3
    - pexpect==4.9.0
    - pickleshare==0.7.5
    - pillow==10.1.0
    - pkgutil-resolve-name==1.3.10
    - prompt-toolkit==3.0.43
    - protobuf==3.20.1
    - psutil==5.9.8
    - ptyprocess==0.7.0
    - pure-eval==0.2.2
    - pyasn1==0.5.1
    - pyasn1-modules==0.3.0
    - pydantic==2.6.0
    - pydantic-core==2.16.1
    - pydub==0.25.1
    - pygments==2.17.2
    - pyparsing==3.1.1
    - python-dateutil==2.8.2
    - python-multipart==0.0.6
    - pytorch-lightning==1.9.0
    - pytorchwildlife==1.0.1.1
    - pytz==2023.4
    - pyyaml==6.0.1
    - referencing==0.33.0
    - requests==2.31.0
    - requests-oauthlib==1.3.1
    - rich==13.7.0
    - rpds-py==0.17.1
    - rsa==4.9
    - scikit-learn==1.2.0
    - scipy==1.10.1
    - seaborn==0.13.2
    - semantic-version==2.10.0
    - shellingham==1.5.4
    - six==1.16.0
    - sniffio==1.3.0
    - stack-data==0.6.3
    - starlette==0.35.1
    - supervision==0.16.0
    - tensorboard==2.14.0
    - tensorboard-data-server==0.7.2
    - termcolor==2.4.0
    - thop==0.1.1-2209072238
    - threadpoolctl==3.2.0
    - tomlkit==0.12.0
    - toolz==0.12.1
    - torch==1.10.1
    - torchaudio==0.10.1
    - torchmetrics==1.3.0.post0
    - torchvision==0.11.2
    - tqdm==4.66.1
    - traitlets==5.14.1
    - typer==0.9.0
    - typing-extensions==4.9.0
    - tzdata==2023.4
    - ultralytics-yolov5==0.1.1
    - urllib3==2.2.0
    - uvicorn==0.27.0.post1
    - wcwidth==0.2.13
    - websockets==11.0.3
    - werkzeug==3.0.1
    - yarl==1.9.4
    - zipp==3.17.0
prefix: /home/andreshernandezcelisadeccoc/.conda/envs/PT_Finetuning