Cylinder embedding and transformer training update
- Updated viz auto loader
- Version 0.0.6 on pypi
Nicholas Geneva committed Jul 25, 2021
1 parent 0b6150e commit 5de5340
Showing 6 changed files with 129 additions and 12 deletions.
17 changes: 13 additions & 4 deletions examples/cylinder/train_cylinder_enn.py
@@ -1,6 +1,6 @@
"""
=====
Training embedding model for the Lorenz numerical example.
Training embedding model for the flow around a cylinder numerical example.
This is a built-in model from the paper.
Distributed by: Notre Dame SCAI Lab (MIT License)
@@ -46,11 +46,20 @@
# Load transformer config file
config = AutoPhysConfig.load_config(args.exp_name)
data_handler = AutoDataHandler.load_data_handler(args.exp_name)
viz = AutoViz.init_viz(args.exp_name)(args.plot_dir)
viz = AutoViz.load_viz(args.exp_name, plot_dir=args.plot_dir)

# Set up data-loaders
training_loader = data_handler.createTrainingLoader(args.training_h5_file, block_size=args.block_size, stride=args.stride, ndata=args.ntrain, batch_size=args.batch_size)
testing_loader = data_handler.createTestingLoader(args.eval_h5_file, block_size=32, ndata=args.ntest, batch_size=8)
training_loader = data_handler.createTrainingLoader(
args.training_h5_file,
block_size=args.block_size,
stride=args.stride,
ndata=args.ntrain,
batch_size=args.batch_size)
testing_loader = data_handler.createTestingLoader(
args.eval_h5_file,
block_size=32,
ndata=args.ntest,
batch_size=8)

# Set up model
model = AutoEmbeddingModel.init_trainer(args.exp_name, config).to(args.device)
104 changes: 104 additions & 0 deletions examples/cylinder/train_cylinder_transformer.py
@@ -0,0 +1,104 @@
"""
=====
Training transformer model for the flow around a cylinder numerical example.
This is a built-in model from the paper.
Distributed by: Notre Dame SCAI Lab (MIT License)
- Associated publication:
url: https://arxiv.org/abs/2010.03957
doi:
github: https://github.com/zabaras/transformer-physx
=====
"""
import sys
import logging
import torch
from trphysx.config import HfArgumentParser
from trphysx.config.args import ModelArguments, TrainingArguments, DataArguments, ArgUtils
from trphysx.config import AutoPhysConfig
from trphysx.transformer import PhysformerTrain, PhysformerGPT2
from trphysx.embedding import AutoEmbeddingModel
from trphysx.viz import AutoViz
from trphysx.data_utils import AutoDataset
from trphysx.utils.trainer import Trainer

logger = logging.getLogger(__name__)

if __name__ == "__main__":

sys.argv = sys.argv + ["--init_name", "cylinder"]
sys.argv = sys.argv + ["--embedding_file_or_path", "./embedding_cylinder300.pth"]
sys.argv = sys.argv + ["--training_h5_file","./data/cylinder_training.hdf5"]
sys.argv = sys.argv + ["--eval_h5_file","./data/cylinder_valid.hdf5"]
sys.argv = sys.argv + ["--train_batch_size", "8"]
sys.argv = sys.argv + ["--n_train", "27"]
sys.argv = sys.argv + ["--n_eval", "3"]
sys.argv = sys.argv + ["--stride", "4"]
sys.argv = sys.argv + ["--max_grad_norm", "0.01"]
sys.argv = sys.argv + ["--save_steps", "25"]

# Parse arguments using the hugging face argument parser
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN)
# Configure arguments after initialization
model_args, data_args, training_args = ArgUtils.config(model_args, data_args, training_args)

# Load model configuration
config = AutoPhysConfig.load_config(model_args.config_name)

# Load embedding model
embedding_model = AutoEmbeddingModel.load_model(
model_args.embedding_name,
config,
model_args.embedding_file_or_path).to(training_args.src_device)

# Load visualization utility class
viz = AutoViz.load_viz(model_args.viz_name, plot_dir=training_args.plot_dir)

# Init transformer model
transformer = PhysformerGPT2(config, model_args.model_name)
model = PhysformerTrain(config, transformer)
if(training_args.epoch_start > 0):
model.load_model(training_args.ckpt_dir, epoch=training_args.epoch_start)
if(model_args.transformer_file_or_path):
model.load_model(model_args.transformer_file_or_path)

# Initialize training and validation datasets
training_data = AutoDataset.create_dataset(
model_args.model_name,
embedding_model,
data_args.training_h5_file,
block_size=config.n_ctx,
stride=data_args.stride,
ndata=data_args.n_train,
overwrite_cache=data_args.overwrite_cache)

eval_data = AutoDataset.create_dataset(
model_args.model_name,
embedding_model,
data_args.eval_h5_file,
block_size=256,
stride=1024,
ndata=data_args.n_eval,
eval=True,
overwrite_cache=data_args.overwrite_cache)

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=training_args.lr, weight_decay=1e-10)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 14, 2, eta_min=1e-9)
trainer = Trainer(
model,
training_args,
(optimizer, scheduler),
train_dataset=training_data,
eval_dataset=eval_data,
embedding_model=embedding_model,
viz=viz)

trainer.train()
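
Note (not part of the commit): the new script above schedules the learning rate with CosineAnnealingWarmRestarts(optimizer, 14, 2, eta_min=1e-9), i.e. cosine cycles of 14, 28, 56, ... epochs. The following is a minimal standalone sketch of how that schedule evolves; the dummy parameter and the lr value 0.001 are placeholders, since the actual base lr comes from training_args.lr.

import torch

params = [torch.nn.Parameter(torch.zeros(1))]  # dummy parameter group
optimizer = torch.optim.Adam(params, lr=0.001, weight_decay=1e-10)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=14, T_mult=2, eta_min=1e-9)

for epoch in range(60):
    optimizer.step()   # stand-in for one epoch of training
    scheduler.step()   # advance the cosine schedule by one epoch
    # lr decays toward eta_min over each cycle and resets to the base lr
    # at the warm restarts (around epochs 14 and 42 with these settings)
    print(epoch, scheduler.get_last_lr())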
2 changes: 1 addition & 1 deletion examples/lorenz/train_lorenz_transformer.py
@@ -58,7 +58,7 @@
model_args.embedding_file_or_path).to(training_args.src_device)

# Load visualization utility class
viz = AutoViz.init_viz(model_args.viz_name)(training_args.plot_dir)
viz = AutoViz.load_viz(model_args.viz_name, plot_dir=training_args.plot_dir)

# Init transformer model
transformer = PhysformerGPT2(config, model_args.model_name)
2 changes: 1 addition & 1 deletion setup.py
@@ -35,7 +35,7 @@
# For a discussion on single-sourcing the version across setup.py and the
# project code, see
# https://packaging.python.org/en/latest/single_source_version.html
version='0.0.5', # Required
version='0.0.6', # Required

# This is a one-line description or tagline of what your project does. This
# corresponds to the "Summary" metadata field:
7 changes: 4 additions & 3 deletions trphysx/viz/viz_auto.py
@@ -33,8 +33,8 @@ def __init__(self):
)

@classmethod
def init_viz(cls, viz_name: str) -> Viz:
"""Initializes visualization class.
def load_viz(cls, viz_name: str, *args, **kwargs) -> Viz:
"""Loads built in visualization class.
Currently supports: "lorenz", "cylinder", "grayscott"
Args:
@@ -48,7 +48,7 @@ def init_viz(cls, viz_name: str) -> Viz:
"""
# First check if the model name is a pre-defined config
if(viz_name in VIZ_MAPPING.keys()):
return VIZ_MAPPING[viz_name]
viz_class = VIZ_MAPPING[viz_name]
return viz_class(*args, **kwargs)
else:
err_str = "Provided viz name, {:s}, not found in existing visualization classes.".format(viz_name)
raise KeyError(err_str)
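
For reference, the call-site change that goes with the new load_viz classmethod is visible in the two example scripts above. A short sketch of the before/after usage; the "./plots" directory is just a placeholder.

from trphysx.viz import AutoViz

# Before (0.0.5): init_viz returned the visualization class,
# which the caller then instantiated itself:
#   viz = AutoViz.init_viz("cylinder")(plot_dir)

# After (0.0.6): load_viz instantiates the class directly and
# forwards any extra *args/**kwargs to its constructor:
viz = AutoViz.load_viz("cylinder", plot_dir="./plots")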
9 changes: 6 additions & 3 deletions trphysx/viz/viz_cylinder.py
@@ -231,8 +231,8 @@ def plotEmbeddingPrediction(self,
y_target: Tensor,
plot_dir: str = None,
epoch: int = None,
bidx: int = 0,
tidx: int = 0,
bidx: int = None,
tidx: int = None,
pid: int = 0
) -> None:
"""Plots the predicted x-velocity, y-velocity and pressure field contours
@@ -242,13 +242,16 @@
y_target (Tensor): [B, T, 3, H, W] Target tensor.
plot_dir (str, optional): Directory to save figure; overrides the default plot_dir if provided. Defaults to None.
epoch (int, optional): Current epoch, used for file name. Defaults to None.
bidx (int, optional): Batch index to plot. Defaults to 0.
bidx (int, optional): Batch index to plot. Defaults to None (plot random example in batch).
tidx (int, optional): Timestep index to plot. Defaults to None (plot random time-step).
pid (int, optional): Optional plotting id for indexing file name manually. Defaults to 0.
"""
print(y_pred.size())
if plot_dir is None:
plot_dir = self.plot_dir
# Convert to numpy array
if bidx is None:
bidx = np.random.randint(0, y_pred.size(0))
if tidx is None:
tidx = np.random.randint(0, y_pred.size(1))
y_pred = y_pred[bidx, tidx].detach().cpu().numpy()
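
The updated plotEmbeddingPrediction defaults above pick a random batch index and time-step when bidx/tidx are left as None. A tiny standalone sketch of that indexing pattern; the tensor shape here is a made-up placeholder.

import numpy as np
import torch

y_pred = torch.zeros(8, 32, 3, 64, 128)  # dummy [B, T, 3, H, W] prediction tensor
bidx = np.random.randint(0, y_pred.size(0))  # random example in the batch
tidx = np.random.randint(0, y_pred.size(1))  # random time-step
frame = y_pred[bidx, tidx].detach().cpu().numpy()  # [3, H, W] field passed to the contour plots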
