-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
148 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
[general] | ||
use_gpu = true | ||
|
||
# Destination of log messages | ||
# 'stderr' and 'stdout' specify the console. | ||
log_destination = "stderr" | ||
# A path name specifies a file e.g. | ||
# log = "fibad_log.txt" | ||
|
||
# Lowest log level to emit. | ||
# As you go down the list, fibad will become more verbose in the log. | ||
# | ||
# log_level = "critical" # Only emit the most severe of errors | ||
# log_level = "error" # Emit all errors | ||
# log_level = "warning" # Emit warnings and all errors | ||
log_level = "info" # Emit informational messages, warnings and all errors | ||
# log_level = "debug" # Very verbose, emit all log messages. | ||
|
||
[download] | ||
sw = "22asec" | ||
sh = "22asec" | ||
filter = ["HSC-G", "HSC-R", "HSC-I", "HSC-Z", "HSC-Y"] | ||
type = "coadd" | ||
rerun = "pdr3_wide" | ||
username = "mtauraso@local" | ||
password = "cCw+nX53lmNLHMy+JbizpH/dl4t7sxljiNm6a7k1" | ||
max_connections = 2 | ||
fits_file = "../hscplay/temp.fits" | ||
cutout_dir = "../hscplay/cutouts/" | ||
offset = 0 | ||
num_sources = 500 | ||
|
||
# These control the downloader's HTTP requests and retries | ||
# `retry_wait` How long to wait before retrying a failed HTTP request in seconds. Default 30s | ||
retry_wait = 30 | ||
# `retries` How many times to retry a failed HTTP request before moving on to the next one. Default 3 times | ||
retries = 3 | ||
# `timepout` How long should we wait to get a full HTTP response from the server. Default 3600s (1hr) | ||
timeout = 3600 | ||
# `chunksize` How many sky location rectangles should we request in a single request. Default is 990 | ||
chunksize = 990 | ||
|
||
[model] | ||
# name = "ExampleCNN" | ||
# name = "ExampleAutoencoder" | ||
|
||
# An example of requesting an external model class | ||
# external_class = "user_package.submodule.ExternalModel" | ||
external_cls = "kbmod_ml.models.cnn.CNN" | ||
|
||
weights_filepath = "example_model.pth" | ||
epochs = 10 | ||
|
||
[data_loader] | ||
# Name of data loader to use | ||
name = "CifarDataLoader" | ||
# name = "HSCDataLoader" | ||
|
||
# An example of requesting an external data loader class | ||
# external_class = "user_package.submodule.ExternalDataLoader" | ||
|
||
# Directory path where the data is stored | ||
path = "/Users/drew/code/fibad/data/cifar-10-batches-py" | ||
# path = "/Users/drew/code/fibad/data/hsc-samples" | ||
|
||
# Default PyTorch DataLoader parameters | ||
batch_size = 10 | ||
shuffle = true | ||
num_workers = 10 | ||
|
||
[predict] | ||
batch_size = 32 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# ruff: noqa: D101, D102 | ||
|
||
# This example model is taken from the PyTorch CIFAR10 tutorial: | ||
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#define-a-convolutional-neural-network | ||
import logging | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F # noqa N812 | ||
import torch.optim as optim | ||
from fibad.models.model_registry import fibad_model | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@fibad_model | ||
class CNN(nn.Module): | ||
def __init__(self, model_config, shape): | ||
logger.info("This is an external model, not in FIBAD!!!") | ||
super().__init__() | ||
self.conv1 = nn.Conv2d(3, 6, 5) | ||
self.pool = nn.MaxPool2d(2, 2) | ||
self.conv2 = nn.Conv2d(6, 16, 5) | ||
self.fc1 = nn.Linear(16 * 5 * 5, 120) | ||
self.fc2 = nn.Linear(120, 84) | ||
self.fc3 = nn.Linear(84, 10) | ||
|
||
self.config = model_config | ||
|
||
# Optimizer and criterion could be set directly, i.e. `self.optimizer = optim.SGD(...)` | ||
# but we define them as methods as a way to allow for more flexibility in the future. | ||
self.optimizer = self._optimizer() | ||
self.criterion = self._criterion() | ||
|
||
def forward(self, x): | ||
x = self.pool(F.relu(self.conv1(x))) | ||
x = self.pool(F.relu(self.conv2(x))) | ||
x = torch.flatten(x, 1) | ||
x = F.relu(self.fc1(x)) | ||
x = F.relu(self.fc2(x)) | ||
x = self.fc3(x) | ||
return x | ||
|
||
def train_step(self, batch): | ||
"""This function contains the logic for a single training step. i.e. the | ||
contents of the inner loop of a ML training process. | ||
Parameters | ||
---------- | ||
batch : tuple | ||
A tuple containing the inputs and labels for the current batch. | ||
Returns | ||
------- | ||
Current loss value | ||
The loss value for the current batch. | ||
""" | ||
inputs, labels = batch | ||
|
||
self.optimizer.zero_grad() | ||
outputs = self(inputs) | ||
loss = self.criterion(outputs, labels) | ||
loss.backward() | ||
self.optimizer.step() | ||
return {"loss": loss.item()} | ||
|
||
def _criterion(self): | ||
return nn.CrossEntropyLoss() | ||
|
||
def _optimizer(self): | ||
return optim.SGD(self.parameters(), lr=0.001, momentum=0.9) | ||
|
||
def save(self): | ||
torch.save(self.state_dict(), self.config.get("weights_filepath", "example_cnn.pth")) |