From 0b6150ec719733fb55c6866216e3779be4cdb61c Mon Sep 17 00:00:00 2001
From: Nicholas Geneva
Date: Fri, 23 Jul 2021 13:48:42 -0400
Subject: [PATCH] Cylinder embedding model and training example script

- Added training script for the cylinder example
- Added viz method for embedding reconstruction
- Assorted embedding training updates
- Added MIT license to setup.py
---
 examples/cylinder/train_cylinder_enn.py   | 67 +++++++++++++++++++
 examples/lorenz/train_lorenz_enn.py       | 14 ++--
 setup.py                                  |  1 +
 trphysx/data_utils/dataset_phys.py        | 43 ++++++++----
 trphysx/embedding/embedding_cylinder.py   | 35 +++++++++-
 trphysx/embedding/training/enn_args.py    | 12 ++--
 trphysx/embedding/training/enn_trainer.py | 10 +--
 trphysx/utils/trainer.py                  |  3 +-
 trphysx/viz/viz_cylinder.py               | 80 ++++++++++++++++++++++-
 9 files changed, 234 insertions(+), 31 deletions(-)
 create mode 100644 examples/cylinder/train_cylinder_enn.py

diff --git a/examples/cylinder/train_cylinder_enn.py b/examples/cylinder/train_cylinder_enn.py
new file mode 100644
index 0000000..c0ec164
--- /dev/null
+++ b/examples/cylinder/train_cylinder_enn.py
@@ -0,0 +1,67 @@
+"""
+=====
+Training embedding model for the flow around a cylinder numerical example.
+This is a built-in model from the paper.
+
+Distributed by: Notre Dame SCAI Lab (MIT License)
+- Associated publication:
+url: https://arxiv.org/abs/2010.03957
+doi:
+github: https://github.com/zabaras/transformer-physx
+=====
+"""
+import sys
+import logging
+import torch
+from torch.optim.lr_scheduler import ExponentialLR
+
+from trphysx.config.configuration_auto import AutoPhysConfig
+from trphysx.embedding.embedding_auto import AutoEmbeddingModel
+from trphysx.viz.viz_auto import AutoViz
+from trphysx.embedding.training import *
+
+logger = logging.getLogger(__name__)
+
+if __name__ == '__main__':
+
+    sys.argv = sys.argv + ["--exp_name", "cylinder"]
+    sys.argv = sys.argv + ["--training_h5_file", "./data/cylinder_train.hdf5"]
+    sys.argv = sys.argv + ["--eval_h5_file", "./data/cylinder_valid.hdf5"]
+    sys.argv = sys.argv + ["--batch_size", "32"]
+    sys.argv = sys.argv + ["--block_size", "4"]
+    sys.argv = sys.argv + ["--ntrain", "27"]
+
+    # Setup logging
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO)
+
+    args = EmbeddingParser().parse()
+    if(torch.cuda.is_available()):
+        use_cuda = "cuda"
+    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    logger.info("Torch device: {}".format(args.device))
+
+    # Load transformer config file
+    config = AutoPhysConfig.load_config(args.exp_name)
+    data_handler = AutoDataHandler.load_data_handler(args.exp_name)
+    viz = AutoViz.init_viz(args.exp_name)(args.plot_dir)
+
+    # Set up data-loaders
+    training_loader = data_handler.createTrainingLoader(args.training_h5_file, block_size=args.block_size, stride=args.stride, ndata=args.ntrain, batch_size=args.batch_size)
+    testing_loader = data_handler.createTestingLoader(args.eval_h5_file, block_size=32, ndata=args.ntest, batch_size=8)
+
+    # Set up model
+    model = AutoEmbeddingModel.init_trainer(args.exp_name, config).to(args.device)
+    mu, std = data_handler.norm_params
+    model.embedding_model.mu = mu.to(args.device)
+    model.embedding_model.std = std.to(args.device)
+    if args.epoch_start > 1:
+        model.load_model(args.ckpt_dir, args.epoch_start)
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr*0.995**(args.epoch_start-1), weight_decay=1e-8)
+    scheduler = ExponentialLR(optimizer, gamma=0.995)
+
+    trainer = EmbeddingTrainer(model, args, (optimizer, scheduler), viz)
+    trainer.train(training_loader, testing_loader)
\ No newline at end of file
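A note on the restart arithmetic in the script above: initializing Adam with lr*0.995**(epoch_start-1) and pairing it with ExponentialLR(gamma=0.995) puts a run resumed from a checkpoint back on the same learning-rate curve an uninterrupted run would follow. A minimal sketch verifying the equivalence (the restart epoch below is hypothetical, not from the example):

import torch
from torch.optim.lr_scheduler import ExponentialLR

lr, gamma, epoch_start = 1e-3, 0.995, 50  # hypothetical restart epoch
param = torch.nn.Parameter(torch.zeros(1))

# An uninterrupted schedule stepped through epoch_start - 1 epochs...
opt_a = torch.optim.Adam([param], lr=lr)
sched_a = ExponentialLR(opt_a, gamma=gamma)
for _ in range(epoch_start - 1):
    opt_a.step()
    sched_a.step()

# ...lands on the same learning rate as a restart that pre-decays lr.
opt_b = torch.optim.Adam([param], lr=lr * gamma ** (epoch_start - 1))
assert abs(sched_a.get_last_lr()[0] - opt_b.param_groups[0]['lr']) < 1e-12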
diff --git a/examples/lorenz/train_lorenz_enn.py b/examples/lorenz/train_lorenz_enn.py
index 5a24850..5c5ddee 100644
--- a/examples/lorenz/train_lorenz_enn.py
+++ b/examples/lorenz/train_lorenz_enn.py
@@ -19,6 +19,8 @@
 from trphysx.embedding.embedding_auto import AutoEmbeddingModel
 from trphysx.embedding.training import *
 
+logger = logging.getLogger(__name__)
+
 if __name__ == '__main__':
 
     sys.argv = sys.argv + ["--exp_name", "lorenz"]
@@ -28,18 +30,18 @@
     sys.argv = sys.argv + ["--block_size", "16"]
     sys.argv = sys.argv + ["--ntrain", "2048"]
 
-    args = EmbeddingParser().parse()
-    if(torch.cuda.is_available()):
-        use_cuda = "cuda"
-    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    print("Torch device:{}".format(args.device))
-
     # Setup logging
     logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO)
 
+    args = EmbeddingParser().parse()
+    if(torch.cuda.is_available()):
+        use_cuda = "cuda"
+    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    logger.info("Torch device: {}".format(args.device))
+
     # Load transformer config file
     config = AutoPhysConfig.load_config(args.exp_name)
     data_handler = AutoDataHandler.load_data_handler(args.exp_name)
diff --git a/setup.py b/setup.py
index 4534639..0a44208 100644
--- a/setup.py
+++ b/setup.py
@@ -56,6 +56,7 @@
     # https://packaging.python.org/specifications/core-metadata/#home-page-optional
     url='https://github.com/zabaras/transformer-physx',  # Optional
 
+    license='MIT',
     author='Nicholas Geneva',  # Optional
     author_email='ngeneva@nd.edu',  # Optional
 
diff --git a/trphysx/data_utils/dataset_phys.py b/trphysx/data_utils/dataset_phys.py
index cac703f..e85dfb6 100644
--- a/trphysx/data_utils/dataset_phys.py
+++ b/trphysx/data_utils/dataset_phys.py
@@ -70,12 +70,7 @@ def __init__(
         with FileLock(lock_path):
             if os.path.exists(cached_features_file) and not overwrite_cache:
-                start = time.time()
-                with open(cached_features_file, "rb") as handle:
-                    self.examples, self.states = pickle.load(handle)
-                logger.info(
-                    f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start)
-
+                self.read_cache(cached_features_file)
             else:
                 logger.info(f"Creating features from dataset file at {directory}")
@@ -85,13 +80,35 @@ def __init__(
                 with h5py.File(file_path, "r") as f:
                     self.embed_data(f, embedder, **kwargs)
 
-                start = time.time()
-                os.makedirs(cache_path, exist_ok=True)
-                with open(cached_features_file, "wb") as handle:
-                    pickle.dump((self.examples, self.states), handle, protocol=pickle.HIGHEST_PROTOCOL)
-                logger.info(
-                    "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
-                )
+                self.write_cache(cached_features_file)
+
+    def read_cache(self, cached_features_file: str) -> None:
+        """Default method to read a cache file into the object.
+
+        Args:
+            cached_features_file (str): Cache file path
+        """
+        assert os.path.isfile(cached_features_file), 'Provided cache file path does not exist!'
+
+        start = time.time()
+        with open(cached_features_file, "rb") as handle:
+            self.examples, self.states = pickle.load(handle)
+        logger.info(
+            "Loading features from cached file %s [took %.3f s]", cached_features_file, time.time() - start)
+
+    def write_cache(self, cached_features_file: str) -> None:
+        """Default method to write the cache file.
+
+        Args:
+            cached_features_file (str): Cache file path
+        """
+        start = time.time()
+        os.makedirs(os.path.dirname(cached_features_file), exist_ok=True)
+        with open(cached_features_file, "wb") as handle:
+            pickle.dump((self.examples, self.states), handle, protocol=pickle.HIGHEST_PROTOCOL)
+        logger.info(
+            "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
+        )
 
     @abstractmethod
     def embed_data(self, h5_file: h5py.File, embedder: EmbeddingModel):
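Factoring the pickle I/O into read_cache/write_cache makes the serialization format a subclass decision: the FileLock and cache-hit logic in __init__ stay untouched, and a dataset only needs to override the pair. A rough sketch of such an override, assuming the abstract base class in dataset_phys.py is named PhysicalDataset and with the abstract embed_data omitted (the subclass name here is made up):

import os
import torch

class TensorCacheDataset(PhysicalDataset):  # hypothetical subclass
    """Illustrative override: cache features with torch.save/torch.load
    instead of pickle."""

    def read_cache(self, cached_features_file: str) -> None:
        self.examples, self.states = torch.load(cached_features_file)

    def write_cache(self, cached_features_file: str) -> None:
        os.makedirs(os.path.dirname(cached_features_file), exist_ok=True)
        torch.save((self.examples, self.states), cached_features_file)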
diff --git a/trphysx/embedding/embedding_cylinder.py b/trphysx/embedding/embedding_cylinder.py
index 88f2570..fc7aaa9 100644
--- a/trphysx/embedding/embedding_cylinder.py
+++ b/trphysx/embedding/embedding_cylinder.py
@@ -277,4 +277,37 @@ def forward(self, states: Tensor, viscosity: Tensor) -> FloatTuple:
             loss_reconstruct = loss_reconstruct + mseLoss(xRec1, xin0).detach()
             g1_old = g1Pred
 
-        return loss, loss_reconstruct
\ No newline at end of file
+        return loss, loss_reconstruct
+
+    def evaluate(self, states: Tensor, viscosity: Tensor) -> Tuple[float, Tensor, Tensor]:
+        """Evaluates the embedding model's reconstruction error and returns its
+        predictions.
+
+        Args:
+            states (Tensor): [B, T, 3, H, W] Time-series feature tensor
+            viscosity (Tensor): [B] Viscosities of the fluid in the mini-batch
+
+        Returns:
+            Tuple[float, Tensor, Tensor]: Test error, predicted states, target states
+        """
+        self.embedding_model.eval()
+        device = self.embedding_model.devices[0]
+
+        mseLoss = nn.MSELoss()
+
+        # Pull out targets from the prediction time-series
+        yTarget = states[:, 1:].to(device)
+        xInput = states[:, :-1].to(device)
+        yPred = torch.zeros(yTarget.size()).to(device)
+        viscosity = viscosity.to(device)
+
+        # Test single time-step reconstruction accuracy
+        for i in range(xInput.size(1)):
+            xInput0 = xInput[:, i].to(device)
+            g0 = self.embedding_model.embed(xInput0, viscosity)
+            yPred0 = self.embedding_model.recover(g0)
+            yPred[:, i] = yPred0.squeeze().detach()
+
+        test_loss = mseLoss(yTarget, yPred)
+
+        return test_loss, yPred, yTarget
\ No newline at end of file
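The new evaluate() hook pairs with the plotEmbeddingPrediction() viz method added at the end of this patch. A rough usage sketch in the spirit of the cylinder script above; it assumes the test loader batches expose 'states' and 'viscosity' entries (those key names are an assumption, not confirmed by this diff):

import torch

# Single-step reconstruction test on one mini-batch (illustrative glue code)
batch = next(iter(testing_loader))
with torch.no_grad():
    test_loss, y_pred, y_target = model.evaluate(batch['states'], batch['viscosity'])

# Contour plots of target, prediction and squared error for ux, uy and p
viz.plotEmbeddingPrediction(y_pred, y_target, plot_dir=args.plot_dir)
print('Single-step reconstruction MSE: {:.4e}'.format(test_loss.item()))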
diff --git a/trphysx/embedding/training/enn_args.py b/trphysx/embedding/training/enn_args.py
index 7c39488..e0eaf93 100644
--- a/trphysx/embedding/training/enn_args.py
+++ b/trphysx/embedding/training/enn_args.py
@@ -74,13 +74,17 @@ def parse(self, args:List = None, dirs: bool = True) -> None:
         else:
             args = self.parse_args()
 
-        args.run_dir = os.path.join(HOME, args.exp_dir, "embedding_{}".format(args.exp_name),
-            "ntrain{}_epochs{:d}_batch{:d}".format(args.ntrain, args.epochs, args.batch_size))
+        if len(args.notes) > 0:
+            args.run_dir = os.path.join(HOME, args.exp_dir, "embedding_{}".format(args.exp_name),
+                "ntrain{}_epochs{:d}_batch{:d}_{:s}".format(args.ntrain, args.epochs, args.batch_size, args.notes))
+        else:
+            args.run_dir = os.path.join(HOME, args.exp_dir, "embedding_{}".format(args.exp_name),
+                "ntrain{}_epochs{:d}_batch{:d}".format(args.ntrain, args.epochs, args.batch_size))
 
         args.ckpt_dir = os.path.join(args.run_dir,"checkpoints")
-        args.pred_dir = os.path.join(args.run_dir, "predictions")
+        args.plot_dir = os.path.join(args.run_dir, "predictions")
         if(dirs):
-            self.mkdirs(args.run_dir, args.ckpt_dir, args.pred_dir)
+            self.mkdirs(args.run_dir, args.ckpt_dir, args.plot_dir)
 
         # Set random seed
         if args.seed is None:
diff --git a/trphysx/embedding/training/enn_trainer.py b/trphysx/embedding/training/enn_trainer.py
index a563b1e..85e3c05 100644
--- a/trphysx/embedding/training/enn_trainer.py
+++ b/trphysx/embedding/training/enn_trainer.py
@@ -43,11 +43,11 @@ class EmbeddingTrainer:
         viz (Viz, optional): Visualization class. Defaults to None.
     """
     def __init__(self,
-                 model: EmbeddingTrainingHead,
-                 args: argparse.ArgumentParser,
-                 optimizers: Tuple[Optimizer, Scheduler],
-                 viz: Viz = None
-                 ) -> None:
+        model: EmbeddingTrainingHead,
+        args: argparse.ArgumentParser,
+        optimizers: Tuple[Optimizer, Scheduler],
+        viz: Viz = None
+    ) -> None:
         """Constructor
         """
         self.model = model.to(args.device)
diff --git a/trphysx/utils/trainer.py b/trphysx/utils/trainer.py
index 5ed131b..d51c2fa 100644
--- a/trphysx/utils/trainer.py
+++ b/trphysx/utils/trainer.py
@@ -300,7 +300,7 @@ def evaluate(
 
         for mbidx, inputs in enumerate(eval_dataloader):
 
-            states = inputs['states'].to(self.args.src_device)
+            states = inputs['states']
             del inputs['states']
 
             if mbidx == 0:
@@ -383,6 +383,7 @@ def eval_states(
         tsize = pred_embeds.size(1)
         device = self.embedding_model.devices[0]
 
+        states = states.to(device)
         x_in = pred_embeds.contiguous().view(-1, pred_embeds.size(-1)).to(device)
         out = self.embedding_model.recover(x_in)
         out = out.view([bsize, tsize] + self.embedding_model.input_dims)
diff --git a/trphysx/viz/viz_cylinder.py b/trphysx/viz/viz_cylinder.py
index 234b8cb..b854640 100644
--- a/trphysx/viz/viz_cylinder.py
+++ b/trphysx/viz/viz_cylinder.py
@@ -190,7 +190,7 @@ def yGrad(u, dy=1, padding=(1, 1, 1, 1)):
         c_min = min([np.amin(vortTarget[:, :, :])+4])
         c_max = 7
         c_min = -7
-        print(vortPred.shape)
+
         for t0 in range(nsteps):
             # Plot target
             ax[0, t0].imshow(vortTarget[t0 * stride, :, :], extent=[-2, 14, -4, 4], cmap=cmap0, origin='lower',
@@ -223,4 +223,82 @@
         else:
             file_name = 'cylinderVortPred{:d}'.format(pid)
 
-        self.saveFigure(plot_dir, file_name)
\ No newline at end of file
+        self.saveFigure(plot_dir, file_name)
+
+    def plotEmbeddingPrediction(self,
+        y_pred: Tensor,
+        y_target: Tensor,
+        plot_dir: str = None,
+        epoch: int = None,
+        bidx: int = 0,
+        tidx: int = None,
+        pid: int = 0
+    ) -> None:
+        """Plots the predicted x-velocity, y-velocity and pressure field contours.
+
+        Args:
+            y_pred (Tensor): [B, T, 3, H, W] Prediction tensor.
+            y_target (Tensor): [B, T, 3, H, W] Target tensor.
+            plot_dir (str, optional): Directory to save figure; overrides the class plot_dir if provided. Defaults to None.
+            epoch (int, optional): Current epoch, used for the file name. Defaults to None.
+            bidx (int, optional): Batch index to plot. Defaults to 0.
+            tidx (int, optional): Time-step index to plot. Defaults to None (plot a random time-step).
+            pid (int, optional): Optional plotting id for indexing the file name manually. Defaults to 0.
+        """
+        if plot_dir is None:
+            plot_dir = self.plot_dir
+        if tidx is None:
+            tidx = np.random.randint(0, y_pred.size(1))
+        # Convert to numpy arrays
+        y_pred = y_pred[bidx, tidx].detach().cpu().numpy()
+        y_target = y_target[bidx, tidx].detach().cpu().numpy()
+        y_error = np.power(y_pred - y_target, 2)
+
+        plt.close('all')
+        mpl.rcParams['font.family'] = ['serif']  # default is sans-serif
+        mpl.rcParams['figure.dpi'] = 300
+        mpl.rcParams['xtick.labelsize'] = 2
+        mpl.rcParams['ytick.labelsize'] = 2
+
+        # Set up figure
+        cmap0 = 'viridis'
+        cmap1 = 'inferno'
+        fig, ax = plt.subplots(3, 3, figsize=(2.1*3, 2.25))
+        fig.subplots_adjust(wspace=0.1)
+
+        for i, field in enumerate(['ux', 'uy', 'p']):
+            c_max = np.amax(y_target[i, :, :])
+            c_min = np.amin(y_target[i, :, :])
+
+            ax[0, i].imshow(y_target[i, :, :], extent=[-2, 14, -4, 4], cmap=cmap0, origin='lower',
+                vmax=c_max, vmin=c_min)
+            ax[1, i].imshow(y_pred[i, :, :], extent=[-2, 14, -4, 4], cmap=cmap0, origin='lower',
+                vmax=c_max, vmin=c_min)
+            ax[2, i].imshow(y_error[i, :, :], extent=[-2, 14, -4, 4], cmap=cmap1, origin='lower')
+
+            for j in range(3):
+                ax[j, i].set_yticks(np.linspace(-4, 4, 5))
+                for tick in ax[j, i].yaxis.get_major_ticks():
+                    tick.label.set_fontsize(5)
+
+            ax[2, i].set_xticks(np.linspace(-2, 14, 9))
+            for tick in ax[2, i].xaxis.get_major_ticks():
+                tick.label.set_fontsize(5)
+
+            ax[0, i].set_title(f'{field}', fontsize=8)
+
+        ax[0, 0].set_ylabel('Target', fontsize=8)
+        ax[1, 0].set_ylabel('Prediction', fontsize=8)
+        ax[2, 0].set_ylabel('Error', fontsize=8)
+
+        if epoch is not None:
+            file_name = 'embeddingPred{:d}_{:d}'.format(pid, epoch)
+        else:
+            file_name = 'embeddingPred{:d}'.format(pid)
+        self.saveFigure(plot_dir, file_name)
\ No newline at end of file