-
Notifications
You must be signed in to change notification settings - Fork 170
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add comet_ml callback * fix typo * add comet to readme * add link to notebook * train for more epochs * add TrainingCallbacks in doc * apply isort and black * prepare release * typing * Update README.md * Update README.md * update gitignore * add viz in notebooks
- Loading branch information
1 parent
b045748
commit a04ea06
Showing
10 changed files
with
547 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
********************************** | ||
TrainingCallbacks | ||
********************************** | ||
|
||
.. automodule:: | ||
pythae.trainers.training_callbacks | ||
|
||
.. autoclass:: pythae.trainers.training_callbacks.TrainingCallback | ||
:members: | ||
|
||
.. autoclass:: pythae.trainers.training_callbacks.CallbackHandler | ||
:members: | ||
|
||
.. autoclass:: pythae.trainers.training_callbacks.MetricConsolePrinterCallback | ||
:members: | ||
|
||
.. autoclass:: pythae.trainers.training_callbacks.ProgressBarCallback | ||
:members: | ||
|
||
.. autoclass:: pythae.trainers.training_callbacks.WandbCallback | ||
:members: | ||
|
||
.. autoclass:: pythae.trainers.training_callbacks.MLFlowCallback | ||
:members: | ||
|
||
.. autoclass:: pythae.trainers.training_callbacks.CometCallback | ||
:members: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Tutorial - Comet ml experiments monitoring\n", | ||
"\n", | ||
"In this notebook, we will see how to monitor your experiments using the integrated **comet_ml** callbacks." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Install the library\n", | ||
"%pip install pythae" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Train your Pythae model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torchvision.datasets as datasets\n", | ||
"\n", | ||
"%load_ext autoreload\n", | ||
"%autoreload 2" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"mnist_trainset = datasets.MNIST(root='../data', train=True, download=True, transform=None)\n", | ||
"\n", | ||
"train_dataset = mnist_trainset.data[:-10000].reshape(-1, 1, 28, 28) / 255.\n", | ||
"eval_dataset = mnist_trainset.data[-10000:].reshape(-1, 1, 28, 28) / 255." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from pythae.models import BetaVAE, BetaVAEConfig\n", | ||
"from pythae.trainers import BaseTrainerConfig\n", | ||
"from pythae.pipelines.training import TrainingPipeline\n", | ||
"from pythae.models.nn.benchmarks.mnist import Encoder_ResNet_VAE_MNIST, Decoder_ResNet_AE_MNIST" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"training_config = BaseTrainerConfig(\n", | ||
" output_dir='my_model',\n", | ||
" learning_rate=1e-4,\n", | ||
" batch_size=100,\n", | ||
" num_epochs=10, # Change this to train the model a bit more,\n", | ||
" steps_predict=3\n", | ||
")\n", | ||
"\n", | ||
"\n", | ||
"model_config = BetaVAEConfig(\n", | ||
" input_dim=(1, 28, 28),\n", | ||
" latent_dim=16,\n", | ||
" beta=2.\n", | ||
"\n", | ||
")\n", | ||
"\n", | ||
"model = BetaVAE(\n", | ||
" model_config=model_config,\n", | ||
" encoder=Encoder_ResNet_VAE_MNIST(model_config), \n", | ||
" decoder=Decoder_ResNet_AE_MNIST(model_config) \n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Before lauching the pipeline, you will need to build your `CometCallback`\n", | ||
"\n", | ||
"To be able to access this feature you will need:\n", | ||
"- a valid comet_ml acccount\n", | ||
"- the `comet_ml` package installed in your virtual env. You can install it by running (`pip install comet_ml`)\n", | ||
"- Your `api_key` when setting up the `CometCallback`. Note that you may need to run `comet init --api-key` to set up your api-key locally and be able to synchronize your offline runs." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Before being allowed to monitor your experiments you may need to run the following\n", | ||
"# !pip install comet_ml" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Create you callback\n", | ||
"from pythae.trainers.training_callbacks import CometCallback\n", | ||
"\n", | ||
"callbacks = [] # the TrainingPipeline expects a list of callbacks\n", | ||
"\n", | ||
"comet_cb = CometCallback() # Build the callback \n", | ||
"\n", | ||
"# SetUp the callback \n", | ||
"comet_cb.setup(\n", | ||
" training_config=training_config, # training config\n", | ||
" model_config=model_config, # model config\n", | ||
" api_key=\"your_comet_api_key\", # specify your comet api-key\n", | ||
" project_name=\"your_comet_project\", # specify your wandb project\n", | ||
" #offline_run=True, # run in offline mode\n", | ||
" #offline_directory='my_offline_runs' # set the directory to store the offline runs\n", | ||
")\n", | ||
"\n", | ||
"callbacks.append(comet_cb) # Add it to the callbacks list" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"pipeline = TrainingPipeline(\n", | ||
" training_config=training_config,\n", | ||
" model=model\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"pipeline(\n", | ||
" train_data=train_dataset,\n", | ||
" eval_data=eval_dataset,\n", | ||
" callbacks=callbacks # pass the callbacks to the TrainingPipeline and you are done!\n", | ||
")\n", | ||
"# You can log to https://comet.com/your_comet_username/your_comet_project to monitor your training" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Or you can alternatively ability to view the Comet UI in the jupyter notebook\n", | ||
"import comet_ml\n", | ||
"\n", | ||
"experiment = comet_ml.get_global_experiment()\n", | ||
"experiment.display()" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"interpreter": { | ||
"hash": "3efa06c4da850a09a4898b773c7e91b0da3286dbbffa369a8099a14a8fa43098" | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3.8.11 64-bit ('pythae_dev': conda)", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.13" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.