Skip to content

Commit

Permalink
Cylinder embedding model and training example script
Browse files Browse the repository at this point in the history
- Added training script for cylinder example
- Added viz method for embedding reconstruction
- Some embedding training updates
- Added MIT liscense to setup.py
  • Loading branch information
Nicholas Geneva committed Jul 23, 2021
1 parent 9f889e6 commit 0b6150e
Show file tree
Hide file tree
Showing 9 changed files with 234 additions and 31 deletions.
67 changes: 67 additions & 0 deletions examples/cylinder/train_cylinder_enn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""
=====
Training embedding model for the Lorenz numerical example.
This is a built-in model from the paper.
Distributed by: Notre Dame SCAI Lab (MIT Liscense)
- Associated publication:
url: https://arxiv.org/abs/2010.03957
doi:
github: https://github.com/zabaras/transformer-physx
=====
"""
import sys
import logging
import torch
from torch.optim.lr_scheduler import ExponentialLR

from trphysx.config.configuration_auto import AutoPhysConfig
from trphysx.embedding.embedding_auto import AutoEmbeddingModel
from trphysx.viz.viz_auto import AutoViz
from trphysx.embedding.training import *

logger = logging.getLogger(__name__)

if __name__ == '__main__':

sys.argv = sys.argv + ["--exp_name", "cylinder"]
sys.argv = sys.argv + ["--training_h5_file", "./data/cylinder_train.hdf5"]
sys.argv = sys.argv + ["--eval_h5_file", "./data/cylinder_valid.hdf5"]
sys.argv = sys.argv + ["--batch_size", "32"]
sys.argv = sys.argv + ["--block_size", "4"]
sys.argv = sys.argv + ["--ntrain", "27"]

# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO)

args = EmbeddingParser().parse()
if(torch.cuda.is_available()):
use_cuda = "cuda"
args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info("Torch device: {}".format(args.device))

# Load transformer config file
config = AutoPhysConfig.load_config(args.exp_name)
data_handler = AutoDataHandler.load_data_handler(args.exp_name)
viz = AutoViz.init_viz(args.exp_name)(args.plot_dir)

# Set up data-loaders
training_loader = data_handler.createTrainingLoader(args.training_h5_file, block_size=args.block_size, stride=args.stride, ndata=args.ntrain, batch_size=args.batch_size)
testing_loader = data_handler.createTestingLoader(args.eval_h5_file, block_size=32, ndata=args.ntest, batch_size=8)

# Set up model
model = AutoEmbeddingModel.init_trainer(args.exp_name, config).to(args.device)
mu, std = data_handler.norm_params
model.embedding_model.mu = mu.to(args.device)
model.embedding_model.std = std.to(args.device)
if args.epoch_start > 1:
model.load_model(args.ckpt_dir, args.epoch_start)

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr*0.995**(args.epoch_start-1), weight_decay=1e-8)
scheduler = ExponentialLR(optimizer, gamma=0.995)

trainer = EmbeddingTrainer(model, args, (optimizer, scheduler), viz)
trainer.train(training_loader, testing_loader)
14 changes: 8 additions & 6 deletions examples/lorenz/train_lorenz_enn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from trphysx.embedding.embedding_auto import AutoEmbeddingModel
from trphysx.embedding.training import *

logger = logging.getLogger(__name__)

if __name__ == '__main__':

sys.argv = sys.argv + ["--exp_name", "lorenz"]
Expand All @@ -28,18 +30,18 @@
sys.argv = sys.argv + ["--block_size", "16"]
sys.argv = sys.argv + ["--ntrain", "2048"]

args = EmbeddingParser().parse()
if(torch.cuda.is_available()):
use_cuda = "cuda"
args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Torch device:{}".format(args.device))

# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO)

args = EmbeddingParser().parse()
if(torch.cuda.is_available()):
use_cuda = "cuda"
args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info("Torch device: {}".format(args.device))

# Load transformer config file
config = AutoPhysConfig.load_config(args.exp_name)
data_handler = AutoDataHandler.load_data_handler(args.exp_name)
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
# https://packaging.python.org/specifications/core-metadata/#home-page-optional
url='https://github.com/zabaras/transformer-physx', # Optional

license='MIT',

author='Nicholas Geneva', # Optional
author_email='[email protected]', # Optional
Expand Down
43 changes: 30 additions & 13 deletions trphysx/data_utils/dataset_phys.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,7 @@ def __init__(
with FileLock(lock_path):

if os.path.exists(cached_features_file) and not overwrite_cache:
start = time.time()
with open(cached_features_file, "rb") as handle:
self.examples, self.states = pickle.load(handle)
logger.info(
f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start)

self.read_cache(cached_features_file)
else:
logger.info(f"Creating features from dataset file at {directory}")

Expand All @@ -85,13 +80,35 @@ def __init__(
with h5py.File(file_path, "r") as f:
self.embed_data(f, embedder, **kwargs)

start = time.time()
os.makedirs(cache_path, exist_ok=True)
with open(cached_features_file, "wb") as handle:
pickle.dump((self.examples, self.states), handle, protocol=pickle.HIGHEST_PROTOCOL)
logger.info(
"Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
)
self.write_cache(cached_features_file)

def read_cache(self, cached_features_file:str) -> None:
"""Default method to read cache file into object.
Args:
cached_features_file (str): Cache file path
"""
assert os.path.isfile(cached_features_file), 'Provided cache file path does not exist!'

start = time.time()
with open(cached_features_file, "rb") as handle:
self.examples, self.states = pickle.load(handle)
logger.info(
f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start)

def write_cache(self, cached_features_file:str) -> None:
"""Default method to write cache file .
Args:
cached_features_file (str): Cache file path
"""
start = time.time()
os.makedirs(os.path.dirname(cached_features_file), exist_ok=True)
with open(cached_features_file, "wb") as handle:
pickle.dump((self.examples, self.states), handle, protocol=pickle.HIGHEST_PROTOCOL)
logger.info(
"Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
)

@abstractmethod
def embed_data(self, h5_file: h5py.File, embedder: EmbeddingModel):
Expand Down
35 changes: 34 additions & 1 deletion trphysx/embedding/embedding_cylinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,4 +277,37 @@ def forward(self, states: Tensor, viscosity: Tensor) -> FloatTuple:
loss_reconstruct = loss_reconstruct + mseLoss(xRec1, xin0).detach()
g1_old = g1Pred

return loss, loss_reconstruct
return loss, loss_reconstruct

def evaluate(self, states: Tensor, viscosity: Tensor) -> Tuple[float, Tensor, Tensor]:
"""Evaluates the embedding models reconstruction error and returns its
predictions.
Args:
states (Tensor): [B, T, 3, H, W] Time-series feature tensor
viscosity (Tensor): [B] Viscosities of the fluid in the mini-batch
Returns:
Tuple[Float, Tensor, Tensor]: Test error, Predicted states, Target states
"""
self.embedding_model.eval()
device = self.embedding_model.devices[0]

mseLoss = nn.MSELoss()

# Pull out targets from prediction dataset
yTarget = states[:,1:].to(device)
xInput = states[:,:-1].to(device)
yPred = torch.zeros(yTarget.size()).to(device)
viscosity = viscosity.to(device)

# Test accuracy of one time-step
for i in range(xInput.size(1)):
xInput0 = xInput[:,i].to(device)
g0 = self.embedding_model.embed(xInput0, viscosity)
yPred0 = self.embedding_model.recover(g0)
yPred[:,i] = yPred0.squeeze().detach()

test_loss = mseLoss(yTarget, yPred)

return test_loss, yPred, yTarget
12 changes: 8 additions & 4 deletions trphysx/embedding/training/enn_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,17 @@ def parse(self, args:List = None, dirs: bool = True) -> None:
else:
args = self.parse_args()

args.run_dir = os.path.join(HOME, args.exp_dir, "embedding_{}".format(args.exp_name),
"ntrain{}_epochs{:d}_batch{:d}".format(args.ntrain, args.epochs, args.batch_size))
if len(args.notes) > 0:
args.run_dir = os.path.join(HOME, args.exp_dir, "embedding_{}".format(args.exp_name),
"ntrain{}_epochs{:d}_batch{:d}_{:s}".format(args.ntrain, args.epochs, args.batch_size, args.notes))
else:
args.run_dir = os.path.join(HOME, args.exp_dir, "embedding_{}".format(args.exp_name),
"ntrain{}_epochs{:d}_batch{:d}".format(args.ntrain, args.epochs, args.batch_size))
args.ckpt_dir = os.path.join(args.run_dir,"checkpoints")
args.pred_dir = os.path.join(args.run_dir, "predictions")
args.plot_dir = os.path.join(args.run_dir, "predictions")

if(dirs):
self.mkdirs(args.run_dir, args.ckpt_dir, args.pred_dir)
self.mkdirs(args.run_dir, args.ckpt_dir, args.plot_dir)

# Set random seed
if args.seed is None:
Expand Down
10 changes: 5 additions & 5 deletions trphysx/embedding/training/enn_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@ class EmbeddingTrainer:
viz (Viz, optional): Visualization class. Defaults to None.
"""
def __init__(self,
model: EmbeddingTrainingHead,
args: argparse.ArgumentParser,
optimizers: Tuple[Optimizer, Scheduler],
viz: Viz = None
) -> None:
model: EmbeddingTrainingHead,
args: argparse.ArgumentParser,
optimizers: Tuple[Optimizer, Scheduler],
viz: Viz = None
) -> None:
"""Constructor
"""
self.model = model.to(args.device)
Expand Down
3 changes: 2 additions & 1 deletion trphysx/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def evaluate(

for mbidx, inputs in enumerate(eval_dataloader):

states = inputs['states'].to(self.args.src_device)
states = inputs['states']
del inputs['states']

if mbidx == 0:
Expand Down Expand Up @@ -383,6 +383,7 @@ def eval_states(
tsize = pred_embeds.size(1)
device = self.embedding_model.devices[0]

states = states.to(device)
x_in = pred_embeds.contiguous().view(-1, pred_embeds.size(-1)).to(device)
out = self.embedding_model.recover(x_in)
out = out.view([bsize, tsize] + self.embedding_model.input_dims)
Expand Down
80 changes: 79 additions & 1 deletion trphysx/viz/viz_cylinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def yGrad(u, dy=1, padding=(1, 1, 1, 1)):
c_min = min([np.amin(vortTarget[:, :, :])+4])
c_max = 7
c_min = -7
print(vortPred.shape)

for t0 in range(nsteps):
# Plot target
ax[0, t0].imshow(vortTarget[t0 * stride, :, :], extent=[-2, 14, -4, 4], cmap=cmap0, origin='lower',
Expand Down Expand Up @@ -223,4 +223,82 @@ def yGrad(u, dy=1, padding=(1, 1, 1, 1)):
else:
file_name = 'cylinderVortPred{:d}'.format(pid)

self.saveFigure(plot_dir, file_name)


def plotEmbeddingPrediction(self,
y_pred: Tensor,
y_target: Tensor,
plot_dir: str = None,
epoch: int = None,
bidx: int = 0,
tidx: int = 0,
pid: int = 0
) -> None:
"""Plots the predicted x-velocity, y-velocity and pressure field contours
Args:
y_pred (Tensor): [B, T, 3, H, W] Prediction tensor.
y_target (Tensor): [B, T, 3, H, W] Target tensor.
plot_dir (str, optional): Directory to save figure, overrides plot_dir one if provided. Defaults to None.
epoch (int, optional): Current epoch, used for file name. Defaults to None.
bidx (int, optional): Batch index to plot. Defaults to 0.
tidx (int, optional): Timestep index to plot. Defaults to None (plot random time-step).
pid (int, optional): Optional plotting id for indexing file name manually. Defaults to 0.
"""
if plot_dir is None:
plot_dir = self.plot_dir
# Convert to numpy array
if tidx is None:
tidx = np.random.randint(0, y_pred.size(1))
y_pred = y_pred[bidx, tidx].detach().cpu().numpy()
y_target = y_target[bidx, tidx].detach().cpu().numpy()
y_error = np.power(y_pred - y_target, 2)

plt.close('all')
mpl.rcParams['font.family'] = ['serif'] # default is sans-serif
mpl.rcParams['figure.dpi'] = 300
mpl.rcParams['xtick.labelsize'] = 2
mpl.rcParams['ytick.labelsize'] = 2
# rc('text', usetex=True)

# Set up figure
cmap0 = 'viridis'
cmap1 = 'inferno'
# fig, ax = plt.subplots(2+yPred0.size(0), yPred0.size(2), figsize=(2*yPred0.size(1), 3+3*yPred0.size(0)))
fig, ax = plt.subplots(3, 3, figsize=(2.1*3, 2.25))
fig.subplots_adjust(wspace=0.1)

for i, field in enumerate(['ux', 'uy', 'p']):
c_max = max([np.amax(y_target[i, :, :])])
c_min = min([np.amin(y_target[i, :, :])])

ax[0,i].imshow(y_target[i, :, :], extent=[-2, 14, -4, 4], cmap=cmap0, origin='lower',
vmax=c_max, vmin=c_min)

ax[1,i].imshow(y_pred[i, :, :], extent=[-2, 14, -4, 4], cmap=cmap0, origin='lower',
vmax=c_max, vmin=c_min)

ax[2,i].imshow(y_error[i, :, :], extent=[-2, 14, -4, 4], cmap=cmap1, origin='lower')

for j in range(3):
ax[j, i].set_yticks(np.linspace(-4, 4, 5))

for tick in ax[j, i].yaxis.get_major_ticks():
tick.label.set_fontsize(5)

ax[2, i].set_xticks(np.linspace(-2, 14, 9))
for tick in ax[2, i].xaxis.get_major_ticks():
tick.label.set_fontsize(5)

ax[0, i].set_title(f'{field}', fontsize=8)

ax[0, 0].set_ylabel('Target', fontsize=8)
ax[1, 0].set_ylabel('Prediction', fontsize=8)
ax[2, 0].set_ylabel('Error', fontsize=8)

if (not epoch is None):
file_name = 'embeddingPred{:d}_{:d}'.format(pid, epoch)
else:
file_name = 'embeddingPred{:d}'.format(pid)
self.saveFigure(plot_dir, file_name)

0 comments on commit 0b6150e

Please sign in to comment.