Merge pull request #399 from microsoft/PyTorchWildlife_prerelease
Pytorch-Wildlife version 1.0.2
Showing 34 changed files with 1,703 additions and 4 deletions.
MIT License

Copyright (c) [2023] [Microsoft]

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
## Welcome to the classification finetuning module for Pytorch-Wildlife!

This repository focuses on training classification models for Pytorch-Wildlife. The module is designed to help both programmers and biologists train a classification model for animal identification. The output weights of this training codebase can be easily integrated into our [Pytorch-Wildlife](https://github.com/microsoft/CameraTraps/) framework. Our goal is to make this tool accessible and easy to use, regardless of your background in machine learning or programming.
## Installation

Before you start, ensure you have Python installed on your machine. This project is developed with Python 3.8. If you're not sure how to install Python, you can find detailed instructions on the [official Python website](https://www.python.org/).

To install the required libraries and dependencies, follow these steps:

1. Make sure you are in the `PT_FT_classification` directory.

2. Install the required packages:

### Using pip and `requirements.txt`

```bash
pip install -r requirements.txt
```

### Using conda and `environment.yaml`

Create and activate a Conda environment with Python 3.8:

```bash
conda env create -f environment.yaml
conda activate PT_Finetuning
```
### Usage

**Data Structure:**

This codebase has been optimized to facilitate its use by non-technical users. To ensure the code works correctly, your images should be stored in a single directory with no nested directories. The `annotations.csv` file, containing image paths and their classification IDs and labels, should be placed outside of the images directory. Image paths in the CSV should be relative to the location of the `annotations.csv` file.

Example directory structure:

```
PT_FT_classification/
│
├── data/
│   ├── imgs/                      # All images stored here
│   └── annotation_example.csv     # Annotations file
│
└── configs/config.yaml            # Configuration file
```
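Before training, it can help to sanity-check that your data matches this layout. The `check_layout` helper below is a hypothetical sketch, not part of the codebase, and assumes the exact directory and file names shown in the tree above:

```python
from pathlib import Path

def check_layout(root):
    """Return a list of layout problems; an empty list means the layout looks OK."""
    base = Path(root)
    imgs = base / "data" / "imgs"
    ann = base / "data" / "annotation_example.csv"
    problems = []
    if not imgs.is_dir():
        problems.append(f"missing images directory: {imgs}")
    elif any(p.is_dir() for p in imgs.iterdir()):
        # The README requires a flat images directory with no nesting.
        problems.append("images directory must not contain nested directories")
    if not ann.is_file():
        problems.append(f"missing annotations file: {ann}")
    return problems
```

Running this against your project root before launching training catches the most common setup mistakes early.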
**Annotation file structure**

To ensure the code works correctly, your annotation file should contain the following columns:

1. `path`: relative path to the image file
2. `classification`: unique identifier for each class (e.g., 0, 1, 2, etc.)
3. `label`: name of the class (e.g., "cat", "dog", "bird", etc.)
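For illustration, a minimal annotation file with exactly these columns can be built with Python's standard `csv` module (the file names and classes below are made-up examples):

```python
import csv
import io

# Example rows with the three required columns; paths are relative to the
# annotations file, per the data-structure rules above.
rows = [
    {"path": "imgs/10050028_0.JPG", "classification": 0, "label": "Anteater"},
    {"path": "imgs/10050028_3.JPG", "classification": 1, "label": "Non-Anteater"},
]

buf = io.StringIO()
writer = csv.DictWriter(buf, fieldnames=["path", "classification", "label"])
writer.writeheader()
writer.writerows(rows)
csv_text = buf.getvalue()  # write this to annotations.csv in practice
```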
**Data splitting**

If you want to split your data into training, validation, and test sets, you can use the `split_path` and `split_data` parameters in the `config.py` file. The `split_path` should point to a CSV file containing the image paths and their corresponding classification IDs and labels, and the `split_data` parameter should be set to `True`.

Currently, Pytorch-Wildlife classification supports three types of data splitting: `random`, `location`, and `sequence`. Random splitting uses the class ID to randomly split the data into training, validation, and test sets while keeping a balanced class distribution. **Due to the nature of camera trap images, it is common to capture a burst of pictures when movement is detected. For this reason, random splitting is not recommended: similar-looking images of the same animal could end up in both the training and validation sets, leading to overfitting.**

Location splitting requires an additional "Location" column in the data. It splits the data based on the location of the images, ensuring that all images from one location end up in a single split; this splitting method does not guarantee a balanced class distribution. Finally, sequence splitting requires a "Photo_time" column containing the shooting time of the picture in YYYY-MM-DD HH:MM:SS format. This method groups images taken within a 30-second period into a "sequence" and then splits the data based on these sequences; this splitting method does not guarantee a balanced class distribution either.

The CSV file should have the previously mentioned structure. The code will then split the data into training, validation, and test sets based on the proportions specified in the `config.py` file and the splitting type. [The annotation example](data/imgs/annotation_example.csv) shows how files should be annotated for each type of splitting.

If you don't require data splitting, set the `split_data` parameter to `False` in the `config.py` file.
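To make the sequence idea concrete, here is a rough stdlib sketch of 30-second sequence grouping. This is not the project's implementation; it assumes a new sequence starts whenever the gap between consecutive photos exceeds 30 seconds, which is one plausible reading of "images within a 30-second period":

```python
from datetime import datetime

def assign_sequences(photo_times, gap_seconds=30):
    """Group timestamps into sequence IDs. Timestamps are sorted, and a new
    sequence starts when the gap to the previous photo exceeds gap_seconds.
    Note: IDs are returned in chronological order, not input order."""
    times = sorted(datetime.strptime(t, "%Y-%m-%d %H:%M:%S") for t in photo_times)
    seq_ids = []
    seq = 0
    for i, t in enumerate(times):
        if i > 0 and (t - times[i - 1]).total_seconds() > gap_seconds:
            seq += 1
        seq_ids.append(seq)
    return seq_ids
```

Splitting by sequence ID rather than by image then keeps a whole burst inside a single split, which is exactly what protects against the burst-overfitting problem described above.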
### Configuration

Before training your model, configure the training and data parameters in the `config.py` file. Here's a brief explanation of the parameters to help both technical and non-technical users understand their purpose:

- **Training Parameters:**
  - `conf_id`: A unique identifier for your training configuration.
  - `algorithm`: The training algorithm to use. Default is "Plain".
  - `log_dir`: Directory where training logs are saved.
  - `num_epochs`: Total number of training epochs.
  - `log_interval`: How often to log training information.
  - `parallel`: Set to 1 to enable parallel computing (if supported by your hardware).

- **Data Parameters:**
  - `dataset_root`: The root directory where your images are stored.
  - `dataset_name`: Name of the dataset. `Custom_Crop` is required for finetuning.
  - `annotation_dir`: Directory where annotation files are located.
  - `split_path`: Path to the single CSV file containing the annotations; it will be used for data splitting.
  - `test_size`: Proportion of data to use as the test set.
  - `val_size`: Proportion of data to use as the validation set.
  - `split_data`: Set to `True` if you want the code to split your data into training, validation, and test sets using the `split_path`.
  - `split_type`: Type of data splitting; it can be "random", "location", or "sequence".
  - `batch_size`: Number of images to process in a batch.
  - `num_workers`: Number of subprocesses to use for data loading.

- **Model Parameters:**
  - `model_name`: The name of the model architecture to use. The current version only supports `PlainResNetClassifier`.
  - `num_layers`: Number of layers in the model. Currently only 18 and 50 are supported.
  - `weights_init`: Initial weights setting for the model. Currently only "ImageNet" is supported.

- **Optimization Parameters:**
  - `lr_feature`, `momentum_feature`, `weight_decay_feature`: Learning rate, momentum, and weight decay for the feature extractor.
  - `lr_classifier`, `momentum_classifier`, `weight_decay_classifier`: Learning rate, momentum, and weight decay for the classifier.
  - `step_size`, `gamma`: Parameters for the learning rate scheduler.
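The `step_size` and `gamma` parameters follow the usual step-decay convention: the learning rate is multiplied by `gamma` every `step_size` epochs. Assuming a standard StepLR-style schedule (a sketch for intuition, not the project's code), the effective rate at any epoch is:

```python
def lr_at_epoch(base_lr, step_size, gamma, epoch):
    """Effective learning rate under a StepLR-style step-decay schedule."""
    return base_lr * (gamma ** (epoch // step_size))

# With the example config values (lr 0.01, step_size 20, gamma 0.1),
# the rate drops from 0.01 to 0.001 at epoch 20 and to 0.0001 at epoch 40.
```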
## Usage

### Currently, our support is limited to the ResNet architecture. You are encouraged to explore other architectures, but it's important to maintain consistency with our code structure (particularly, an independent feature extractor and classifier) for compatibility with the PyTorch-Wildlife framework.

After configuring your `config.yaml` file, you can start training your model by running:

```bash
python main.py
```

This command will initiate the training process based on the parameters specified in `config.py`. Make sure to monitor the output for any errors or important information regarding the training progress.

### We have provided 10 example images and an annotation file in the `data` directory for code testing, so you don't need to provide your own data.
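As a quick sanity check of the split proportions: with the 10 example images, `test_size: 0.2`, and `val_size: 0.2`, you would expect roughly 6 training, 2 validation, and 2 test images. A sketch, assuming both fractions are taken from the full dataset (the actual code may compute them differently):

```python
def split_counts(n, test_size, val_size):
    """Approximate split sizes, assuming test and validation fractions are
    both taken from the full dataset (an assumption, not the project's code)."""
    n_test = round(n * test_size)
    n_val = round(n * val_size)
    return n - n_test - n_val, n_val, n_test
```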
## Output

Once training is complete, the output weights will be saved in the `weights` directory. These weights can be used to classify new images with the [Pytorch-Wildlife](https://github.com/microsoft/CameraTraps/) framework.

### We are working on adding a feature in a future release to directly integrate the output weights with the Pytorch-Wildlife framework and the Gradio App.
## License

This project is licensed under the MIT License - see the LICENSE file for details.

## Support

If you encounter any issues or have questions, please feel free to open an issue on the GitHub repository page. We aim to make this tool as accessible as possible and will gladly provide assistance.

Thank you!
from batch_detection_cropping import *
`PT_FT_classification/configs/Raw/Crop_res50_plain_082723.yaml` (40 additions)
# training
conf_id: Crop_Res50_plain_082723
algorithm: Plain
log_dir: Crop
num_epochs: 60
log_interval: 10
parallel: 0

# data
dataset_root: ./data/imgs
dataset_name: Custom_Crop
# annotation directory (if you have train/val/test splits)
annotation_dir: ./data/imgs
# data splitting (if you don't have train/val/test splits)
split_path: ./data/imgs/annotation_example.csv
test_size: 0.2
val_size: 0.2
split_data: True
split_type: location # options are: random, location, sequence
# data loading
batch_size: 256
num_workers: 0 #40
# model
model_name: PlainResNetClassifier
num_layers: 50
weights_init: ImageNet

# optim
## feature
lr_feature: 0.01
momentum_feature: 0.9
weight_decay_feature: 0.0005
## classifier
lr_classifier: 0.01
momentum_classifier: 0.9
weight_decay_classifier: 0.0005
## lr_scheduler
step_size: 20
gamma: 0.1
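This example config uses `split_type: location`. The guarantee that a location split makes is that every image from a given location lands in exactly one split. A simplified stdlib sketch (illustrative only; the real implementation may weight splits by image counts and will not balance classes, as the README notes):

```python
import random

def split_by_location(locations, test_frac=0.2, val_frac=0.2, seed=0):
    """Assign each distinct location wholly to train/val/test.
    Simplified sketch: splits by number of locations, not number of
    images, and makes no attempt to balance class distributions."""
    locs = sorted(set(locations))
    random.Random(seed).shuffle(locs)
    n = len(locs)
    n_test = max(1, int(n * test_frac))
    n_val = max(1, int(n * val_frac))
    test = set(locs[:n_test])
    val = set(locs[n_test:n_test + n_val])
    assignment = {
        loc: "test" if loc in test else "val" if loc in val else "train"
        for loc in locs
    }
    return [assignment[loc] for loc in locations]
```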
path,classification,label,Photo_Time,Location
10050028_0.JPG,0,Anteater,2024-01-27 21:30:11,Amazon_Guaviare
10050028_1.JPG,0,Anteater,2024-01-27 21:30:13,Amazon_Guaviare
10050028_2.JPG,0,Anteater,2024-01-27 21:30:15,Amazon_Guaviare
10050028_3.JPG,1,Non-Anteater,2024-01-27 22:10:11,Amazon_Caqueta
10050028_4.JPG,1,Non-Anteater,2024-01-27 22:10:11,Amazon_Caqueta
10050028_5.JPG,1,Non-Anteater,2024-01-27 22:10:11,Amazon_Caqueta
10050028_6.JPG,1,Non-Anteater,2024-01-28 01:54:43,Amazon_Guainia
10050028_7.JPG,1,Non-Anteater,2024-01-28 01:55:13,Amazon_Guainia
10050028_8.JPG,0,Anteater,2024-01-28 21:30:11,Amazon_Guainia
10050028_9.JPG,0,Anteater,2024-01-28 21:30:11,Amazon_Guainia
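Annotations like these can be read with Python's standard `csv` module, for instance to check the class balance before training (a sketch using a trimmed copy of the rows above):

```python
import csv
import io
from collections import Counter

# A trimmed copy of the example annotation file, embedded for illustration.
annotation_csv = """path,classification,label,Photo_Time,Location
10050028_0.JPG,0,Anteater,2024-01-27 21:30:11,Amazon_Guaviare
10050028_3.JPG,1,Non-Anteater,2024-01-27 22:10:11,Amazon_Caqueta
10050028_8.JPG,0,Anteater,2024-01-28 21:30:11,Amazon_Guainia
"""

rows = list(csv.DictReader(io.StringIO(annotation_csv)))
label_counts = Counter(r["label"] for r in rows)  # per-class image counts
```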
name: PT_Finetuning
channels:
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_gnu
  - bzip2=1.0.8=hd590300_5
  - ca-certificates=2023.11.17=hbcca054_0
  - ld_impl_linux-64=2.40=h41732ed_0
  - libffi=3.4.2=h7f98852_5
  - libgcc-ng=13.2.0=h807b86a_4
  - libgomp=13.2.0=h807b86a_4
  - libnsl=2.0.1=hd590300_0
  - libsqlite=3.44.2=h2797004_0
  - libuuid=2.38.1=h0b41bf4_0
  - libxcrypt=4.4.36=hd590300_1
  - libzlib=1.2.13=hd590300_5
  - ncurses=6.4=h59595ed_2
  - openssl=3.2.0=hd590300_1
  - pip=23.3.2=pyhd8ed1ab_0
  - python=3.8.18=hd12c33a_1_cpython
  - readline=8.2=h8228510_1
  - setuptools=69.0.3=pyhd8ed1ab_0
  - tk=8.6.13=noxft_h4845f30_101
  - wheel=0.42.0=pyhd8ed1ab_0
  - xz=5.2.6=h166bdaf_0
  - pip:
      - absl-py==2.1.0
      - aiofiles==23.2.1
      - aiohttp==3.9.3
      - aiosignal==1.3.1
      - altair==5.2.0
      - annotated-types==0.6.0
      - anyio==4.2.0
      - asttokens==2.4.1
      - async-timeout==4.0.3
      - attrs==23.2.0
      - backcall==0.2.0
      - cachetools==5.3.2
      - certifi==2023.11.17
      - charset-normalizer==3.3.2
      - click==8.1.7
      - colorama==0.4.6
      - contourpy==1.1.1
      - cycler==0.12.1
      - decorator==5.1.1
      - exceptiongroup==1.2.0
      - executing==2.0.1
      - fastapi==0.109.0
      - ffmpy==0.3.1
      - filelock==3.13.1
      - fire==0.5.0
      - fonttools==4.47.2
      - frozenlist==1.4.1
      - fsspec==2023.12.2
      - google-auth==2.27.0
      - google-auth-oauthlib==1.0.0
      - gradio==4.8.0
      - gradio-client==0.7.1
      - grpcio==1.60.0
      - h11==0.14.0
      - httpcore==1.0.2
      - httpx==0.26.0
      - huggingface-hub==0.20.3
      - idna==3.6
      - importlib-metadata==7.0.1
      - importlib-resources==6.1.1
      - ipython==8.12.3
      - jedi==0.19.1
      - jinja2==3.1.3
      - joblib==1.3.2
      - jsonschema==4.21.1
      - jsonschema-specifications==2023.12.1
      - kiwisolver==1.4.5
      - lightning-utilities==0.10.1
      - markdown==3.5.2
      - markdown-it-py==3.0.0
      - markupsafe==2.1.4
      - matplotlib==3.7.4
      - matplotlib-inline==0.1.6
      - mdurl==0.1.2
      - multidict==6.0.4
      - munch==2.5.0
      - numpy==1.24.4
      - oauthlib==3.2.2
      - opencv-python==4.9.0.80
      - opencv-python-headless==4.9.0.80
      - orjson==3.9.12
      - packaging==23.2
      - pandas==2.0.3
      - parso==0.8.3
      - pexpect==4.9.0
      - pickleshare==0.7.5
      - pillow==10.1.0
      - pkgutil-resolve-name==1.3.10
      - prompt-toolkit==3.0.43
      - protobuf==3.20.1
      - psutil==5.9.8
      - ptyprocess==0.7.0
      - pure-eval==0.2.2
      - pyasn1==0.5.1
      - pyasn1-modules==0.3.0
      - pydantic==2.6.0
      - pydantic-core==2.16.1
      - pydub==0.25.1
      - pygments==2.17.2
      - pyparsing==3.1.1
      - python-dateutil==2.8.2
      - python-multipart==0.0.6
      - pytorch-lightning==1.9.0
      - pytorchwildlife==1.0.1.1
      - pytz==2023.4
      - pyyaml==6.0.1
      - referencing==0.33.0
      - requests==2.31.0
      - requests-oauthlib==1.3.1
      - rich==13.7.0
      - rpds-py==0.17.1
      - rsa==4.9
      - scikit-learn==1.2.0
      - scipy==1.10.1
      - seaborn==0.13.2
      - semantic-version==2.10.0
      - shellingham==1.5.4
      - six==1.16.0
      - sniffio==1.3.0
      - stack-data==0.6.3
      - starlette==0.35.1
      - supervision==0.16.0
      - tensorboard==2.14.0
      - tensorboard-data-server==0.7.2
      - termcolor==2.4.0
      - thop==0.1.1-2209072238
      - threadpoolctl==3.2.0
      - tomlkit==0.12.0
      - toolz==0.12.1
      - torch==1.10.1
      - torchaudio==0.10.1
      - torchmetrics==1.3.0.post0
      - torchvision==0.11.2
      - tqdm==4.66.1
      - traitlets==5.14.1
      - typer==0.9.0
      - typing-extensions==4.9.0
      - tzdata==2023.4
      - ultralytics-yolov5==0.1.1
      - urllib3==2.2.0
      - uvicorn==0.27.0.post1
      - wcwidth==0.2.13
      - websockets==11.0.3
      - werkzeug==3.0.1
      - yarl==1.9.4
      - zipp==3.17.0
prefix: /home/andreshernandezcelisadeccoc/.conda/envs/PT_Finetuning