Skip to content

Commit

Permalink
Merge pull request #4 from dirac-institute/awo/initial-model
Browse files Browse the repository at this point in the history
Adding example CNN to the set of models
  • Loading branch information
drewoldag authored Sep 17, 2024
2 parents 807a47f + c2ce047 commit 9920d1e
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 0 deletions.
72 changes: 72 additions & 0 deletions example_config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
[general]
use_gpu = true

# Destination of log messages
# 'stderr' and 'stdout' specify the console.
log_destination = "stderr"
# A path name specifies a file e.g.
# log = "fibad_log.txt"

# Lowest log level to emit.
# As you go down the list, fibad will become more verbose in the log.
#
# log_level = "critical" # Only emit the most severe of errors
# log_level = "error" # Emit all errors
# log_level = "warning" # Emit warnings and all errors
log_level = "info" # Emit informational messages, warnings and all errors
# log_level = "debug" # Very verbose, emit all log messages.

[download]
sw = "22asec"
sh = "22asec"
filter = ["HSC-G", "HSC-R", "HSC-I", "HSC-Z", "HSC-Y"]
type = "coadd"
rerun = "pdr3_wide"
username = "mtauraso@local"
password = "cCw+nX53lmNLHMy+JbizpH/dl4t7sxljiNm6a7k1"
max_connections = 2
fits_file = "../hscplay/temp.fits"
cutout_dir = "../hscplay/cutouts/"
offset = 0
num_sources = 500

# These control the downloader's HTTP requests and retries
# `retry_wait` How long to wait before retrying a failed HTTP request in seconds. Default 30s
retry_wait = 30
# `retries` How many times to retry a failed HTTP request before moving on to the next one. Default 3 times
retries = 3
# `timepout` How long should we wait to get a full HTTP response from the server. Default 3600s (1hr)
timeout = 3600
# `chunksize` How many sky location rectangles should we request in a single request. Default is 990
chunksize = 990

[model]
# name = "ExampleCNN"
# name = "ExampleAutoencoder"

# An example of requesting an external model class
# external_class = "user_package.submodule.ExternalModel"
external_cls = "kbmod_ml.models.cnn.CNN"

weights_filepath = "example_model.pth"
epochs = 10

[data_loader]
# Name of data loader to use
name = "CifarDataLoader"
# name = "HSCDataLoader"

# An example of requesting an external data loader class
# external_class = "user_package.submodule.ExternalDataLoader"

# Directory path where the data is stored
path = "/Users/drew/code/fibad/data/cifar-10-batches-py"
# path = "/Users/drew/code/fibad/data/hsc-samples"

# Default PyTorch DataLoader parameters
batch_size = 10
shuffle = true
num_workers = 10

[predict]
batch_size = 32
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ classifiers = [
dynamic = ["version"]
requires-python = ">=3.9"
dependencies = [
"torch", # PyTorch
# "fibad", when it is available on PyPI
]

[project.urls]
Expand Down
74 changes: 74 additions & 0 deletions src/kbmod_ml/models/cnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# ruff: noqa: D101, D102

# This example model is taken from the PyTorch CIFAR10 tutorial:
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#define-a-convolutional-neural-network
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F # noqa N812
import torch.optim as optim
from fibad.models.model_registry import fibad_model

logger = logging.getLogger(__name__)


@fibad_model
class CNN(nn.Module):
def __init__(self, model_config, shape):
logger.info("This is an external model, not in FIBAD!!!")
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

self.config = model_config

# Optimizer and criterion could be set directly, i.e. `self.optimizer = optim.SGD(...)`
# but we define them as methods as a way to allow for more flexibility in the future.
self.optimizer = self._optimizer()
self.criterion = self._criterion()

def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = torch.flatten(x, 1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x

def train_step(self, batch):
"""This function contains the logic for a single training step. i.e. the
contents of the inner loop of a ML training process.
Parameters
----------
batch : tuple
A tuple containing the inputs and labels for the current batch.
Returns
-------
Current loss value
The loss value for the current batch.
"""
inputs, labels = batch

self.optimizer.zero_grad()
outputs = self(inputs)
loss = self.criterion(outputs, labels)
loss.backward()
self.optimizer.step()
return {"loss": loss.item()}

def _criterion(self):
return nn.CrossEntropyLoss()

def _optimizer(self):
return optim.SGD(self.parameters(), lr=0.001, momentum=0.9)

def save(self):
torch.save(self.state_dict(), self.config.get("weights_filepath", "example_cnn.pth"))

0 comments on commit 9920d1e

Please sign in to comment.