Commit

Merge pull request #114 from aai-institute/feature/trainer-with-torch-model

Make Trainer usable with standard torch model.
Samuel Burbulla authored Apr 8, 2024
2 parents d67fa7a + df821d0 commit 544631d
Showing 3 changed files with 67 additions and 15 deletions.
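
For context, a minimal sketch of the usage this commit enables, mirroring the new test added below (the `Trainer(...)` and `fit(...)` arguments are taken from that test; the `Sequential` model here is just a stand-in for any plain torch module):

```python
import torch
from continuity.trainer import Trainer

# Any plain torch.nn.Module works; no Continuity operator is required.
model = torch.nn.Sequential(
    torch.nn.Linear(1, 32), torch.nn.Tanh(), torch.nn.Linear(32, 1)
)

# Toy regression data: y = sin(2*pi*x).
x = torch.rand(128, 1)
dataset = torch.utils.data.TensorDataset(x, torch.sin(2 * torch.pi * x))

# The loss function receives the model plus whatever tensors the dataset yields.
mse = torch.nn.MSELoss()

def loss_fn(op, x, y):
    return mse(op(x), y)

trainer = Trainer(model, loss_fn=loss_fn)
trainer.fit(dataset, tol=1e-3)
```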
16 changes: 12 additions & 4 deletions src/continuity/trainer/callbacks.py
@@ -60,14 +60,13 @@ def __init__(self, epochs: Optional[int] = None, steps: Optional[int] = None):
         self.start_time = time()
         super().__init__()
 
-    def step(self, logs: Logs):
-        """Called after every gradient step.
+    def __call__(self, logs: Logs):
+        """Callback function.
+
+        Called at the end of each epoch.
 
         Args:
             logs: Training logs.
         """
-        self.steps_performed += 1
-
         elapsed = time() - self.start_time
         sec_per_step = elapsed / self.steps_performed
 
@@ -113,6 +112,15 @@ def to_min(t):
 
         print("\r" + s, end="")
 
+    def step(self, logs: Logs):
+        """Called after every gradient step.
+
+        Args:
+            logs: Training logs.
+        """
+        self.steps_performed += 1
+        self.__call__(logs)
+
     def on_train_begin(self):
         """Called at the beginning of training."""
         self.start_time = time()
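
Taken together, the two hunks above split the progress callback into `__call__`, which reports at the end of an epoch, and `step`, which counts gradient steps and then delegates to `__call__`. A condensed sketch of the resulting class, with the class name and the printed string assumed rather than copied from the file:

```python
from time import time

class PrintTrainingLoss:  # name assumed; defined in src/continuity/trainer/callbacks.py
    def __init__(self, epochs=None, steps=None):
        self.epochs = epochs
        self.steps = steps
        self.steps_performed = 0
        self.start_time = time()

    def __call__(self, logs):
        """Called at the end of each epoch: report loss and seconds per step."""
        elapsed = time() - self.start_time
        sec_per_step = elapsed / max(self.steps_performed, 1)
        print(f"\r{logs} ({sec_per_step:.3g} s/step)", end="")  # output abbreviated

    def step(self, logs):
        """Called after every gradient step: count it, then reuse __call__."""
        self.steps_performed += 1
        self.__call__(logs)

    def on_train_begin(self):
        """Called at the beginning of training."""
        self.start_time = time()
```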
21 changes: 11 additions & 10 deletions src/continuity/trainer/trainer.py
@@ -120,7 +120,10 @@ def fit(
 
         # Print number of model parameters
         if self.verbose:
-            num_params = self.operator.num_params()
+            if hasattr(self.operator, "num_params"):
+                num_params = self.operator.num_params()
+            else:
+                num_params = sum(p.numel() for p in self.operator.parameters())
             print(f"Parameters: {num_params}", end=" ")
 
         # Move operator to device
@@ -194,13 +197,12 @@ def fit(
                 loss_test=loss_test,
             )
 
-            for x, u, y, v in data_loader:
-                x, u = x.to(self.device), u.to(self.device)
-                y, v = y.to(self.device), v.to(self.device)
+            for xuyv in data_loader:
+                xuyv = [t.to(self.device) for t in xuyv]
 
-                def closure(x=x, u=u, y=y, v=v):
+                def closure(xuyv=xuyv):
                     self.optimizer.zero_grad()
-                    loss = self.loss_fn(operator, x, u, y, v)
+                    loss = self.loss_fn(operator, *xuyv)
                     loss.backward(retain_graph=True)
                     return loss
 
@@ -221,10 +223,9 @@ def closure(x=x, u=u, y=y, v=v):
             # Compute test loss
             if test_dataset is not None:
                 loss_test = 0
-                for x, u, y, v in test_data_loader:
-                    x, u = x.to(self.device), u.to(self.device)
-                    y, v = y.to(self.device), v.to(self.device)
-                    loss = self.loss_fn(operator, x, u, y, v)
+                for xuyv in test_data_loader:
+                    xuyv = [t.to(self.device) for t in xuyv]
+                    loss = self.loss_fn(operator, *xuyv)
                     if is_distributed:
                         dist.all_reduce(loss)
                         loss /= dist.get_world_size()
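
The loop changes above are what free the `Trainer` from the fixed `(x, u, y, v)` operator signature: each batch is moved to the device as a list and unpacked into the loss function, so dataset and loss function only need to agree on arity. A standalone sketch of that pattern (not the `Trainer` class itself):

```python
import torch

def train_one_epoch(model, loss_fn, data_loader, optimizer, device="cpu"):
    """Hypothetical helper illustrating the batch handling introduced above."""
    for batch in data_loader:
        # Works for operator datasets (x, u, y, v) and plain (x, y) datasets alike.
        batch = [t.to(device) for t in batch]

        def closure(batch=batch):
            optimizer.zero_grad()
            loss = loss_fn(model, *batch)
            loss.backward(retain_graph=True)
            return loss

        # Closure-style step: LBFGS re-evaluates it, other optimizers call it once.
        optimizer.step(closure)
```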
45 changes: 44 additions & 1 deletion tests/trainer/test_trainer.py
@@ -1,5 +1,7 @@
 import pytest
+import torch
 from continuity.operators import DeepONet
+from continuity.operators.common import DeepResidualNetwork
 from continuity.benchmarks.sine import SineBenchmark
 from continuity.trainer import Trainer

@@ -17,10 +19,51 @@ def train():
 
 
 @pytest.mark.slow
-def test_trainer():
+def test_trainer_with_operator():
     train()
 
 
+@pytest.mark.slow
+def test_trainer_with_torch_model():
+    def f(x):
+        return torch.sin(2 * torch.pi * x)
+
+    x_train = torch.rand(128, 1)
+    x_test = torch.rand(32, 1)
+
+    y_train = f(x_train)
+    y_test = f(x_test)
+
+    train_dataset = torch.utils.data.TensorDataset(x_train, y_train)
+    test_dataset = torch.utils.data.TensorDataset(x_test, y_test)
+
+    # Create a model
+    model = DeepResidualNetwork(
+        input_size=1,
+        output_size=1,
+        width=32,
+        depth=3,
+    )
+
+    # Define loss function (in Continuity style)
+    mse = torch.nn.MSELoss()
+
+    def loss_fn(op, x, y):
+        y_pred = op(x)
+        return mse(y_pred, y)
+
+    # Train the model
+    trainer = Trainer(model, loss_fn=loss_fn)
+    logs = trainer.fit(
+        train_dataset,
+        tol=1e-3,
+        test_dataset=test_dataset,
+    )
+
+    # Test the model
+    assert logs.loss_test < 1e-3
+
+
 # Use ./run_parallel.sh to run test with CUDA
 if __name__ == "__main__":
     train()
