diff --git a/docs/zh/api/arch.md b/docs/zh/api/arch.md
index 4b493d143..be639dcf9 100644
--- a/docs/zh/api/arch.md
+++ b/docs/zh/api/arch.md
@@ -19,5 +19,6 @@
- AFNONet
- PrecipNet
- UNetEx
+ - NowcastNet
show_root_heading: false
heading_level: 3
diff --git a/docs/zh/api/data/dataset.md b/docs/zh/api/data/dataset.md
index cbaef5893..c5946a1cd 100644
--- a/docs/zh/api/data/dataset.md
+++ b/docs/zh/api/data/dataset.md
@@ -20,5 +20,6 @@
- VtuDataset
- MeshAirfoilDataset
- MeshCylinderDataset
+ - RadarDataset
- build_dataset
show_root_heading: false
diff --git a/docs/zh/api/visualize.md b/docs/zh/api/visualize.md
index 2c120b412..3b64d8797 100644
--- a/docs/zh/api/visualize.md
+++ b/docs/zh/api/visualize.md
@@ -12,6 +12,7 @@
- Visualizer2DPlot
- Visualizer3D
- VisualizerWeather
+ - VisualizerRadar
- save_vtu_from_dict
- save_vtu_to_mesh
- save_plot_from_1d_dict
diff --git a/docs/zh/examples/nowcastnet.md b/docs/zh/examples/nowcastnet.md
new file mode 100644
index 000000000..d665fb886
--- /dev/null
+++ b/docs/zh/examples/nowcastnet.md
@@ -0,0 +1,92 @@
+# NowcastNet
+
+=== "模型训练命令"
+
+ 暂无
+
+=== "模型评估命令"
+
+ ``` sh
+ # linux
+ wget https://paddle-org.bj.bcebos.com/paddlescience/datasets/nowcastnet/mrms.tar
+ # windows
+ # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/nowcastnet/mrms.tar --output mrms.tar
+ tar -xvf mrms.tar -C datasets/
+ python nowcastnet.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/nowcastnet/nowcastnet_pretrained.pdparams
+ ```
+
+## 1. 背景简介
+
+近年来,深度学习方法已被应用于天气预报,尤其是雷达观测的降水预报。这些方法利用大量雷达复合观测数据来训练神经网络模型,以端到端的方式进行训练,无需明确参考降水过程的物理定律。
+这里复现了一个针对极端降水的非线性短临预报模型——NowcastNet,该模型将物理演变方案和条件学习法统一到一个神经网络框架中,实现了端到端的优化。
+
+## 2. 模型原理
+
+本章节仅对 NowcastNet 的模型原理进行简单地介绍,详细的理论推导请阅读 [Skilful nowcasting of extreme precipitation with NowcastNet](https://www.nature.com/articles/s41586-023-06184-4#Abs1)。
+
+模型的总体结构如图所示:
+
+
+
+模型使用预训练权重推理,接下来将介绍模型的推理过程。
+
+## 3. 模型构建
+
+在该案例中,用 PaddleScience 代码表示如下:
+
+``` py linenums="24" title="examples/nowcastnet/nowcastnet.py"
+--8<--
+examples/nowcastnet/nowcastnet.py:24:36
+--8<--
+```
+
+``` yaml linenums="35" title="examples/nowcastnet/conf/nowcastnet.yaml"
+--8<--
+examples/nowcastnet/conf/nowcastnet.yaml:35:53
+--8<--
+```
+
+其中,`input_keys` 和 `output_keys` 分别代表网络模型输入、输出变量的名称。
+
+## 4 模型评估可视化
+
+完成上述设置之后,将上述实例化的对象按顺序传递给 `ppsci.solver.Solver`:
+
+``` py linenums="57" title="examples/nowcastnet/nowcastnet.py"
+--8<--
+examples/nowcastnet/nowcastnet.py:57:61
+--8<--
+```
+
+然后构建 VisualizerRadar 生成图片结果:
+
+``` py linenums="69" title="examples/nowcastnet/nowcastnet.py"
+--8<--
+examples/nowcastnet/nowcastnet.py:69:82
+--8<--
+```
+
+## 5. 完整代码
+
+``` py linenums="1" title="examples/nowcastnet/nowcastnet.py"
+--8<--
+examples/nowcastnet/nowcastnet.py
+--8<--
+```
+
+## 6. 结果展示
+
+下图展示了模型的预测结果和真值结果。
+
+
+
+
diff --git a/examples/nowcastnet/conf/nowcastnet.yaml b/examples/nowcastnet/conf/nowcastnet.yaml
new file mode 100644
index 000000000..7a258cba8
--- /dev/null
+++ b/examples/nowcastnet/conf/nowcastnet.yaml
@@ -0,0 +1,57 @@
+hydra:
+ run:
+ # dynamic output directory according to running time and override name
+ dir: outputs_nowcastnet/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
+ job:
+ name: ${mode} # name of logfile
+ chdir: false # keep current working direcotry unchaned
+ config:
+ override_dirname:
+ exclude_keys:
+ - TRAIN.checkpoint_path
+ - TRAIN.pretrained_model_path
+ - EVAL.pretrained_model_path
+ - mode
+ - output_dir
+ - log_freq
+ sweep:
+ # output directory for multirun
+ dir: ${hydra.run.dir}
+ subdir: ./
+
+# general settings
+mode: eval # running mode: train/eval
+seed: 42
+output_dir: ${hydra:run.dir}
+NORMAL_DATASET_PATH: datasets/mrms/figure
+LARGE_DATASET_PATH: datasets/mrms/large_figure
+
+# set working condition
+CASE_TYPE: normal # normal/large
+NUM_SAVE_SAMPLES: 10
+CPU_WORKER: 0
+
+# model settings
+MODEL:
+ normal:
+ input_keys: ["input"]
+ output_keys: ["output"]
+ input_length: 9
+ total_length: 29
+ image_width: 512
+ image_height: 512
+ image_ch: 2
+ ngf: 32
+ large:
+ input_keys: ["input"]
+ output_keys: ["output"]
+ input_length: 9
+ total_length: 29
+ image_width: 1024
+ image_height: 1024
+ image_ch: 2
+ ngf: 32
+
+# evaluation settings
+EVAL:
+ pretrained_model_path: checkpoints/paddle_mrms_model
diff --git a/examples/nowcastnet/nowcastnet.py b/examples/nowcastnet/nowcastnet.py
new file mode 100644
index 000000000..6f3fbde79
--- /dev/null
+++ b/examples/nowcastnet/nowcastnet.py
@@ -0,0 +1,96 @@
+"""
+Reference: https://codeocean.com/capsule/3935105/tree/v1
+"""
+from os import path as osp
+
+import hydra
+import paddle
+from omegaconf import DictConfig
+
+import ppsci
+from ppsci.utils import logger
+
+
+def train(cfg: DictConfig):
+ print("Not supported.")
+
+
+def evaluate(cfg: DictConfig):
+ # set random seed for reproducibility
+ ppsci.utils.misc.set_random_seed(cfg.seed)
+ # initialize logger
+ logger.init_logger("ppsci", osp.join(cfg.output_dir, "train.log"), "info")
+
+ if cfg.CASE_TYPE == "large":
+ dataset_path = cfg.LARGE_DATASET_PATH
+ model_cfg = cfg.MODEL.large
+ output_dir = osp.join(cfg.output_dir, "large")
+ elif cfg.CASE_TYPE == "normal":
+ dataset_path = cfg.NORMAL_DATASET_PATH
+ model_cfg = cfg.MODEL.normal
+ output_dir = osp.join(cfg.output_dir, "normal")
+ else:
+ raise ValueError(
+ f"cfg.CASE_TYPE should in ['normal', 'large'], but got '{cfg.mode}'"
+ )
+ model = ppsci.arch.NowcastNet(**model_cfg)
+
+ input_keys = ("radar_frames",)
+ dataset_param = {
+ "input_keys": input_keys,
+ "label_keys": (),
+ "image_width": model_cfg.image_width,
+ "image_height": model_cfg.image_height,
+ "total_length": model_cfg.total_length,
+ "dataset_path": dataset_path,
+ "data_type": paddle.get_default_dtype(),
+ }
+ test_data_loader = paddle.io.DataLoader(
+ ppsci.data.dataset.RadarDataset(**dataset_param),
+ batch_size=1,
+ shuffle=False,
+ num_workers=cfg.CPU_WORKER,
+ drop_last=True,
+ )
+
+ # initialize solver
+ solver = ppsci.solver.Solver(
+ model,
+ output_dir=output_dir,
+ pretrained_model_path=cfg.EVAL.pretrained_model_path,
+ )
+
+ for batch_id, test_ims in enumerate(test_data_loader):
+ test_ims = test_ims[0][input_keys[0]].numpy()
+ frames_tensor = paddle.to_tensor(
+ data=test_ims, dtype=paddle.get_default_dtype()
+ )
+ if batch_id <= cfg.NUM_SAVE_SAMPLES:
+ visualizer = {
+ "v_nowcastnet": ppsci.visualize.VisualizerRadar(
+ {"input": frames_tensor},
+ {
+ "output": lambda out: out["output"],
+ },
+ prefix="v_nowcastnet",
+ case_type=cfg.CASE_TYPE,
+ total_length=model_cfg.total_length,
+ )
+ }
+ solver.visualizer = visualizer
+ # visualize prediction
+ solver.visualize(batch_id)
+
+
+@hydra.main(version_base=None, config_path="./conf", config_name="nowcastnet.yaml")
+def main(cfg: DictConfig):
+ if cfg.mode == "train":
+ train(cfg)
+ elif cfg.mode == "eval":
+ evaluate(cfg)
+ else:
+ raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/mkdocs.yml b/mkdocs.yml
index 5dafa40fe..54402cbc6 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -69,6 +69,7 @@ nav:
- EPNN: zh/examples/epnn.md
- 地球科学(AI for Earth Science):
- FourCastNet: zh/examples/fourcastnet.md
+ - NowcastNet: zh/examples/nowcastnet.md
- API文档:
- " ":
- ppsci.arch: zh/api/arch.md
diff --git a/ppsci/arch/__init__.py b/ppsci/arch/__init__.py
index baa15a892..7a8e6c263 100644
--- a/ppsci/arch/__init__.py
+++ b/ppsci/arch/__init__.py
@@ -32,6 +32,7 @@
from ppsci.arch.afno import PrecipNet # isort:skip
from ppsci.arch.unetex import UNetEx # isort:skip
from ppsci.arch.epnn import Epnn # isort:skip
+from ppsci.arch.nowcastnet import NowcastNet # isort:skip
from ppsci.utils import logger # isort:skip
@@ -52,6 +53,7 @@
"PrecipNet",
"UNetEx",
"Epnn",
+ "NowcastNet",
"build_model",
]
diff --git a/ppsci/arch/nowcastnet.py b/ppsci/arch/nowcastnet.py
new file mode 100644
index 000000000..4fc3d6a61
--- /dev/null
+++ b/ppsci/arch/nowcastnet.py
@@ -0,0 +1,623 @@
+import collections
+from typing import Tuple
+
+import paddle
+
+from ppsci.arch import base
+
+
+class NowcastNet(base.Arch):
+ """The NowcastNet model.
+
+ Args:
+ input_keys (Tuple[str, ...]): Name of input keys, such as ("input",).
+ output_keys (Tuple[str, ...]): Name of output keys, such as ("output",).
+ input_length (int, optional): Input length. Defaults to 9.
+ total_length (int, optional): Total length. Defaults to 29.
+ image_height (int, optional): Image height. Defaults to 512.
+ image_width (int, optional): Image width. Defaults to 512.
+ image_ch (int, optional): Image channel. Defaults to 2.
+ ngf (int, optional): Noise Projector input length. Defaults to 32.
+
+ Examples:
+ >>> import ppsci
+ >>> model = ppsci.arch.NowcastNet(("input", ), ("output", ))
+ """
+
+ def __init__(
+ self,
+ input_keys: Tuple[str, ...],
+ output_keys: Tuple[str, ...],
+ input_length: int = 9,
+ total_length: int = 29,
+ image_height: int = 512,
+ image_width: int = 512,
+ image_ch: int = 2,
+ ngf: int = 32,
+ ):
+ super().__init__()
+ self.input_keys = input_keys
+ self.output_keys = output_keys
+
+ self.input_length = input_length
+ self.total_length = total_length
+ self.image_height = image_height
+ self.image_width = image_width
+ self.image_ch = image_ch
+ self.ngf = ngf
+
+ configs = collections.namedtuple(
+ "Object", ["ngf", "evo_ic", "gen_oc", "ic_feature"]
+ )
+ configs.ngf = self.ngf
+ configs.evo_ic = self.total_length - self.input_length
+ configs.gen_oc = self.total_length - self.input_length
+ configs.ic_feature = self.ngf * 10
+
+ self.pred_length = self.total_length - self.input_length
+ self.evo_net = Evolution_Network(self.input_length, self.pred_length, base_c=32)
+ self.gen_enc = Generative_Encoder(self.total_length, base_c=self.ngf)
+ self.gen_dec = Generative_Decoder(configs)
+ self.proj = Noise_Projector(self.ngf)
+ sample_tensor = paddle.zeros(shape=[1, 1, self.image_height, self.image_width])
+ self.grid = make_grid(sample_tensor)
+
+ def split_to_dict(
+ self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...]
+ ):
+ return {key: data_tensors[i] for i, key in enumerate(keys)}
+
+ def forward(self, x):
+ if self._input_transform is not None:
+ x = self._input_transform(x)
+
+ x_tensor = self.concat_to_tensor(x, self.input_keys)
+
+ y = []
+ out = self.forward_tensor(x_tensor)
+ y.append(out)
+ y = self.split_to_dict(y, self.output_keys)
+
+ if self._output_transform is not None:
+ y = self._output_transform(x, y)
+ return y
+
+ def forward_tensor(self, x):
+ all_frames = x[:, :, :, :, :1]
+ frames = all_frames.transpose(perm=[0, 1, 4, 2, 3])
+ batch = frames.shape[0]
+ height = frames.shape[3]
+ width = frames.shape[4]
+ # Input Frames
+ input_frames = frames[:, : self.input_length]
+ input_frames = input_frames.reshape((batch, self.input_length, height, width))
+ # Evolution Network
+ intensity, motion = self.evo_net(input_frames)
+ motion_ = motion.reshape((batch, self.pred_length, 2, height, width))
+ intensity_ = intensity.reshape((batch, self.pred_length, 1, height, width))
+ series = []
+ last_frames = all_frames[:, self.input_length - 1 : self.input_length, :, :, 0]
+ grid = self.grid.tile((batch, 1, 1, 1))
+ for i in range(self.pred_length):
+ last_frames = warp(
+ last_frames, motion_[:, i], grid, mode="nearest", padding_mode="border"
+ )
+ last_frames = last_frames + intensity_[:, i]
+ series.append(last_frames)
+ evo_result = paddle.concat(x=series, axis=1)
+ evo_result = evo_result / 128
+ # Generative Network
+ evo_feature = self.gen_enc(paddle.concat(x=[input_frames, evo_result], axis=1))
+ noise = paddle.randn(shape=[batch, self.ngf, height // 32, width // 32])
+ noise_feature = (
+ self.proj(noise)
+ .reshape((batch, -1, 4, 4, 8, 8))
+ .transpose(perm=[0, 1, 4, 5, 2, 3])
+ .reshape((batch, -1, height // 8, width // 8))
+ )
+ feature = paddle.concat(x=[evo_feature, noise_feature], axis=1)
+ gen_result = self.gen_dec(feature, evo_result)
+ return gen_result.unsqueeze(axis=-1)
+
+
+class Evolution_Network(paddle.nn.Layer):
+ def __init__(self, n_channels, n_classes, base_c=64, bilinear=True):
+ super().__init__()
+ self.n_channels = n_channels
+ self.n_classes = n_classes
+ self.bilinear = bilinear
+ base_c = base_c
+ self.inc = DoubleConv(n_channels, base_c)
+ self.down1 = Down(base_c * 1, base_c * 2)
+ self.down2 = Down(base_c * 2, base_c * 4)
+ self.down3 = Down(base_c * 4, base_c * 8)
+ factor = 2 if bilinear else 1
+ self.down4 = Down(base_c * 8, base_c * 16 // factor)
+ self.up1 = Up(base_c * 16, base_c * 8 // factor, bilinear)
+ self.up2 = Up(base_c * 8, base_c * 4 // factor, bilinear)
+ self.up3 = Up(base_c * 4, base_c * 2 // factor, bilinear)
+ self.up4 = Up(base_c * 2, base_c * 1, bilinear)
+ self.outc = OutConv(base_c * 1, n_classes)
+ param1 = paddle.zeros(shape=[1, n_classes, 1, 1])
+ gamma = self.create_parameter(
+ shape=param1.shape,
+ dtype=param1.dtype,
+ default_initializer=paddle.nn.initializer.Assign(param1),
+ )
+ gamma.stop_gradient = False
+ self.gamma = gamma
+ self.up1_v = Up(base_c * 16, base_c * 8 // factor, bilinear)
+ self.up2_v = Up(base_c * 8, base_c * 4 // factor, bilinear)
+ self.up3_v = Up(base_c * 4, base_c * 2 // factor, bilinear)
+ self.up4_v = Up(base_c * 2, base_c * 1, bilinear)
+ self.outc_v = OutConv(base_c * 1, n_classes * 2)
+
+ def forward(self, x):
+ x1 = self.inc(x)
+ x2 = self.down1(x1)
+ x3 = self.down2(x2)
+ x4 = self.down3(x3)
+ x5 = self.down4(x4)
+ x = self.up1(x5, x4)
+ x = self.up2(x, x3)
+ x = self.up3(x, x2)
+ x = self.up4(x, x1)
+ x = self.outc(x) * self.gamma
+ v = self.up1_v(x5, x4)
+ v = self.up2_v(v, x3)
+ v = self.up3_v(v, x2)
+ v = self.up4_v(v, x1)
+ v = self.outc_v(v)
+ return x, v
+
+
+class DoubleConv(paddle.nn.Layer):
+ def __init__(self, in_channels, out_channels, kernel=3, mid_channels=None):
+ super().__init__()
+ if not mid_channels:
+ mid_channels = out_channels
+ self.double_conv = paddle.nn.Sequential(
+ paddle.nn.BatchNorm2D(num_features=in_channels),
+ paddle.nn.ReLU(),
+ paddle.nn.utils.spectral_norm(
+ layer=paddle.nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=mid_channels,
+ kernel_size=kernel,
+ padding=kernel // 2,
+ )
+ ),
+ paddle.nn.BatchNorm2D(num_features=mid_channels),
+ paddle.nn.ReLU(),
+ paddle.nn.utils.spectral_norm(
+ layer=paddle.nn.Conv2D(
+ in_channels=mid_channels,
+ out_channels=out_channels,
+ kernel_size=kernel,
+ padding=kernel // 2,
+ )
+ ),
+ )
+ self.single_conv = paddle.nn.Sequential(
+ paddle.nn.BatchNorm2D(num_features=in_channels),
+ paddle.nn.utils.spectral_norm(
+ layer=paddle.nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel,
+ padding=kernel // 2,
+ )
+ ),
+ )
+
+ def forward(self, x):
+ shortcut = self.single_conv(x)
+ x = self.double_conv(x)
+ x = x + shortcut
+ return x
+
+
+class Down(paddle.nn.Layer):
+ def __init__(self, in_channels, out_channels, kernel=3):
+ super().__init__()
+ self.maxpool_conv = paddle.nn.Sequential(
+ paddle.nn.MaxPool2D(kernel_size=2),
+ DoubleConv(in_channels, out_channels, kernel),
+ )
+
+ def forward(self, x):
+ x = self.maxpool_conv(x)
+ return x
+
+
+class Up(paddle.nn.Layer):
+ def __init__(self, in_channels, out_channels, bilinear=True, kernel=3):
+ super().__init__()
+ if bilinear:
+ self.up = paddle.nn.Upsample(
+ scale_factor=2, mode="bilinear", align_corners=True
+ )
+ self.conv = DoubleConv(
+ in_channels, out_channels, kernel=kernel, mid_channels=in_channels // 2
+ )
+ else:
+ self.up = paddle.nn.Conv2DTranspose(
+ in_channels=in_channels,
+ out_channels=in_channels // 2,
+ kernel_size=2,
+ stride=2,
+ )
+ self.conv = DoubleConv(in_channels, out_channels, kernel)
+
+ def forward(self, x1, x2):
+ x1 = self.up(x1)
+ # input is CHW
+ diffY = x2.shape[2] - x1.shape[2]
+ diffX = x2.shape[3] - x1.shape[3]
+ x1 = paddle.nn.functional.pad(
+ x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]
+ )
+ x = paddle.concat(x=[x2, x1], axis=1)
+ return self.conv(x)
+
+
+class Up_S(paddle.nn.Layer):
+ def __init__(self, in_channels, out_channels, bilinear=True, kernel=3):
+ super().__init__()
+ if bilinear:
+ self.up = paddle.nn.Upsample(
+ scale_factor=2, mode="bilinear", align_corners=True
+ )
+ self.conv = DoubleConv(
+ in_channels, out_channels, kernel=kernel, mid_channels=in_channels
+ )
+ else:
+ self.up = paddle.nn.Conv2DTranspose(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=2,
+ stride=2,
+ )
+ self.conv = DoubleConv(in_channels, out_channels, kernel)
+
+ def forward(self, x):
+ x = self.up(x)
+ return self.conv(x)
+
+
+class OutConv(paddle.nn.Layer):
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.conv = paddle.nn.Conv2D(
+ in_channels=in_channels, out_channels=out_channels, kernel_size=1
+ )
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class Generative_Encoder(paddle.nn.Layer):
+ def __init__(self, n_channels, base_c=64):
+ super().__init__()
+ base_c = base_c
+ self.inc = DoubleConv(n_channels, base_c, kernel=3)
+ self.down1 = Down(base_c * 1, base_c * 2, 3)
+ self.down2 = Down(base_c * 2, base_c * 4, 3)
+ self.down3 = Down(base_c * 4, base_c * 8, 3)
+
+ def forward(self, x):
+ x = self.inc(x)
+ x = self.down1(x)
+ x = self.down2(x)
+ x = self.down3(x)
+ return x
+
+
+class Generative_Decoder(paddle.nn.Layer):
+ def __init__(self, opt):
+ super().__init__()
+ self.opt = opt
+ nf = opt.ngf
+ ic = opt.ic_feature
+ self.fc = paddle.nn.Conv2D(
+ in_channels=ic, out_channels=8 * nf, kernel_size=3, padding=1
+ )
+ self.head_0 = GenBlock(8 * nf, 8 * nf, opt)
+ self.G_middle_0 = GenBlock(8 * nf, 4 * nf, opt, double_conv=True)
+ self.G_middle_1 = GenBlock(4 * nf, 4 * nf, opt, double_conv=True)
+ self.up_0 = GenBlock(4 * nf, 2 * nf, opt)
+ self.up_1 = GenBlock(2 * nf, 1 * nf, opt, double_conv=True)
+ self.up_2 = GenBlock(1 * nf, 1 * nf, opt, double_conv=True)
+ final_nc = nf * 1
+ self.conv_img = paddle.nn.Conv2D(
+ in_channels=final_nc, out_channels=self.opt.gen_oc, kernel_size=3, padding=1
+ )
+ self.up = paddle.nn.Upsample(scale_factor=2)
+
+ def forward(self, x, evo):
+ x = self.fc(x)
+ x = self.head_0(x, evo)
+ x = self.up(x)
+ x = self.G_middle_0(x, evo)
+ x = self.G_middle_1(x, evo)
+ x = self.up(x)
+ x = self.up_0(x, evo)
+ x = self.up(x)
+ x = self.up_1(x, evo)
+ x = self.up_2(x, evo)
+ x = self.conv_img(paddle.nn.functional.leaky_relu(x=x, negative_slope=0.2))
+ return x
+
+
+class GenBlock(paddle.nn.Layer):
+ def __init__(self, fin, fout, opt, use_se=False, dilation=1, double_conv=False):
+ super().__init__()
+ self.learned_shortcut = fin != fout
+ fmiddle = min(fin, fout)
+ self.opt = opt
+ self.double_conv = double_conv
+ self.pad = paddle.nn.Pad2D(padding=dilation, mode="reflect")
+ self.conv_0 = paddle.nn.Conv2D(
+ in_channels=fin,
+ out_channels=fmiddle,
+ kernel_size=3,
+ padding=0,
+ dilation=dilation,
+ )
+ self.conv_1 = paddle.nn.Conv2D(
+ in_channels=fmiddle,
+ out_channels=fout,
+ kernel_size=3,
+ padding=0,
+ dilation=dilation,
+ )
+ if self.learned_shortcut:
+ self.conv_s = paddle.nn.Conv2D(
+ in_channels=fin, out_channels=fout, kernel_size=1, bias_attr=False
+ )
+ self.conv_0 = paddle.nn.utils.spectral_norm(layer=self.conv_0)
+ self.conv_1 = paddle.nn.utils.spectral_norm(layer=self.conv_1)
+ if self.learned_shortcut:
+ self.conv_s = paddle.nn.utils.spectral_norm(layer=self.conv_s)
+ ic = opt.evo_ic
+ self.norm_0 = SPADE(fin, ic)
+ self.norm_1 = SPADE(fmiddle, ic)
+ if self.learned_shortcut:
+ self.norm_s = SPADE(fin, ic)
+
+ def forward(self, x, evo):
+ x_s = self.shortcut(x, evo)
+ dx = self.conv_0(self.pad(self.actvn(self.norm_0(x, evo))))
+ if self.double_conv:
+ dx = self.conv_1(self.pad(self.actvn(self.norm_1(dx, evo))))
+ out = x_s + dx
+ return out
+
+ def shortcut(self, x, evo):
+ if self.learned_shortcut:
+ x_s = self.conv_s(self.norm_s(x, evo))
+ else:
+ x_s = x
+ return x_s
+
+ def actvn(self, x):
+ return paddle.nn.functional.leaky_relu(x=x, negative_slope=0.2)
+
+
+class SPADE(paddle.nn.Layer):
+ def __init__(self, norm_nc, label_nc):
+ super().__init__()
+ ks = 3
+ self.param_free_norm = paddle.nn.InstanceNorm2D(
+ num_features=norm_nc, weight_attr=False, bias_attr=False, momentum=1 - 0.1
+ )
+ nhidden = 64
+ ks = 3
+ pw = ks // 2
+ self.mlp_shared = paddle.nn.Sequential(
+ paddle.nn.Pad2D(padding=pw, mode="reflect"),
+ paddle.nn.Conv2D(
+ in_channels=label_nc, out_channels=nhidden, kernel_size=ks, padding=0
+ ),
+ paddle.nn.ReLU(),
+ )
+ self.pad = paddle.nn.Pad2D(padding=pw, mode="reflect")
+ self.mlp_gamma = paddle.nn.Conv2D(
+ in_channels=nhidden, out_channels=norm_nc, kernel_size=ks, padding=0
+ )
+ self.mlp_beta = paddle.nn.Conv2D(
+ in_channels=nhidden, out_channels=norm_nc, kernel_size=ks, padding=0
+ )
+
+ def forward(self, x, evo):
+ normalized = self.param_free_norm(x)
+ evo = paddle.nn.functional.adaptive_avg_pool2d(x=evo, output_size=x.shape[2:])
+ actv = self.mlp_shared(evo)
+ gamma = self.mlp_gamma(self.pad(actv))
+ beta = self.mlp_beta(self.pad(actv))
+ out = normalized * (1 + gamma) + beta
+ return out
+
+
+class Noise_Projector(paddle.nn.Layer):
+ def __init__(self, input_length):
+ super().__init__()
+ self.input_length = input_length
+ self.conv_first = spectral_norm(
+ paddle.nn.Conv2D(
+ in_channels=self.input_length,
+ out_channels=self.input_length * 2,
+ kernel_size=3,
+ padding=1,
+ )
+ )
+ self.L1 = ProjBlock(self.input_length * 2, self.input_length * 4)
+ self.L2 = ProjBlock(self.input_length * 4, self.input_length * 8)
+ self.L3 = ProjBlock(self.input_length * 8, self.input_length * 16)
+ self.L4 = ProjBlock(self.input_length * 16, self.input_length * 32)
+
+ def forward(self, x):
+ x = self.conv_first(x)
+ x = self.L1(x)
+ x = self.L2(x)
+ x = self.L3(x)
+ x = self.L4(x)
+ return x
+
+
+class ProjBlock(paddle.nn.Layer):
+ def __init__(self, in_channel, out_channel):
+ super().__init__()
+ self.one_conv = spectral_norm(
+ paddle.nn.Conv2D(
+ in_channels=in_channel,
+ out_channels=out_channel - in_channel,
+ kernel_size=1,
+ padding=0,
+ )
+ )
+ self.double_conv = paddle.nn.Sequential(
+ spectral_norm(
+ paddle.nn.Conv2D(
+ in_channels=in_channel,
+ out_channels=out_channel,
+ kernel_size=3,
+ padding=1,
+ )
+ ),
+ paddle.nn.ReLU(),
+ spectral_norm(
+ paddle.nn.Conv2D(
+ in_channels=out_channel,
+ out_channels=out_channel,
+ kernel_size=3,
+ padding=1,
+ )
+ ),
+ )
+
+ def forward(self, x):
+ x1 = paddle.concat(x=[x, self.one_conv(x)], axis=1)
+ x2 = self.double_conv(x)
+ output = x1 + x2
+ return output
+
+
+def make_grid(input):
+ B, C, H, W = input.shape
+ xx = paddle.arange(start=0, end=W).reshape((1, -1)).tile((H, 1))
+ yy = paddle.arange(start=0, end=H).reshape((-1, 1)).tile((1, W))
+ xx = xx.reshape((1, 1, H, W)).tile((B, 1, 1, 1))
+ yy = yy.reshape((1, 1, H, W)).tile((B, 1, 1, 1))
+ grid = paddle.concat(x=(xx, yy), axis=1).astype(dtype=paddle.get_default_dtype())
+ return grid
+
+
+def warp(input, flow, grid, mode="bilinear", padding_mode="zeros"):
+ B, C, H, W = input.shape
+ vgrid = grid + flow
+ vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0
+ vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0
+ vgrid = vgrid.transpose(perm=[0, 2, 3, 1])
+ output = paddle.nn.functional.grid_sample(
+ x=input.cpu(),
+ grid=vgrid.cpu(),
+ padding_mode=padding_mode,
+ mode=mode,
+ align_corners=True,
+ )
+ return output.cuda()
+
+
+def l2normalize(v, eps=1e-12):
+ return v / (v.norm() + eps)
+
+
+class spectral_norm(paddle.nn.Layer):
+ def __init__(self, module, name="weight", power_iterations=1):
+ super().__init__()
+ self.module = module
+ self.name = name
+ self.power_iterations = power_iterations
+ if not self._made_params():
+ self._make_params()
+
+ def _update_u_v(self):
+ u = getattr(self.module, self.name + "_u")
+ v = getattr(self.module, self.name + "_v")
+ w = getattr(self.module, self.name + "_bar")
+ height = w.detach().shape[0]
+ for _ in range(self.power_iterations):
+ v = l2normalize(
+ paddle.mv(
+ x=paddle.t(input=w.reshape((height, -1)).detach()), vec=u.detach()
+ )
+ )
+ u = l2normalize(
+ paddle.mv(x=w.reshape((height, -1)).detach(), vec=v.detach())
+ )
+ sigma = u.dot(y=w.reshape((height, -1)).mv(vec=v))
+ setattr(self.module, self.name, w / sigma.expand_as(y=w))
+
+ def _made_params(self):
+ try:
+ _ = getattr(self.module, self.name + "_u")
+ _ = getattr(self.module, self.name + "_v")
+ _ = getattr(self.module, self.name + "_bar")
+ return True
+ except AttributeError:
+ return False
+
+ def _make_params(self):
+ w = getattr(self.module, self.name)
+ height = w.detach().shape[0]
+ width = w.reshape((height, -1)).detach().shape[1]
+
+ tmp_w = paddle.normal(shape=[height])
+ out_0 = paddle.create_parameter(
+ shape=tmp_w.shape,
+ dtype=tmp_w.numpy().dtype,
+ default_initializer=paddle.nn.initializer.Assign(tmp_w),
+ )
+ out_0.stop_gradient = True
+ u = out_0
+
+ tmp_w = paddle.normal(shape=[width])
+ out_1 = paddle.create_parameter(
+ shape=tmp_w.shape,
+ dtype=tmp_w.numpy().dtype,
+ default_initializer=paddle.nn.initializer.Assign(tmp_w),
+ )
+ out_1.stop_gradient = True
+ v = out_1
+ u = l2normalize(u)
+ v = l2normalize(v)
+ tmp_w = w.detach()
+ out_2 = paddle.create_parameter(
+ shape=tmp_w.shape,
+ dtype=tmp_w.numpy().dtype,
+ default_initializer=paddle.nn.initializer.Assign(tmp_w),
+ )
+ out_2.stop_gradient = False
+ w_bar = out_2
+ del self.module._parameters[self.name]
+
+ u = create_param(u)
+ v = create_param(v)
+ self.module.add_parameter(name=self.name + "_u", parameter=u)
+ self.module.add_parameter(name=self.name + "_v", parameter=v)
+ self.module.add_parameter(name=self.name + "_bar", parameter=w_bar)
+
+ def forward(self, *args):
+ self._update_u_v()
+ return self.module.forward(*args)
+
+
+def create_param(x):
+ param = paddle.create_parameter(
+ shape=x.shape,
+ dtype=x.dtype,
+ default_initializer=paddle.nn.initializer.Assign(x),
+ )
+ param.stop_gradient = x.stop_gradient
+ return param
diff --git a/ppsci/data/dataset/__init__.py b/ppsci/data/dataset/__init__.py
index 997573502..bbddf17b2 100644
--- a/ppsci/data/dataset/__init__.py
+++ b/ppsci/data/dataset/__init__.py
@@ -27,6 +27,7 @@
from ppsci.data.dataset.mat_dataset import MatDataset
from ppsci.data.dataset.npz_dataset import IterableNPZDataset
from ppsci.data.dataset.npz_dataset import NPZDataset
+from ppsci.data.dataset.radar_dataset import RadarDataset
from ppsci.data.dataset.trphysx_dataset import CylinderDataset
from ppsci.data.dataset.trphysx_dataset import LorenzDataset
from ppsci.data.dataset.trphysx_dataset import RosslerDataset
@@ -50,6 +51,7 @@
"NPZDataset",
"CylinderDataset",
"LorenzDataset",
+ "RadarDataset",
"RosslerDataset",
"VtuDataset",
"MeshAirfoilDataset",
diff --git a/ppsci/data/dataset/radar_dataset.py b/ppsci/data/dataset/radar_dataset.py
new file mode 100644
index 000000000..e4913ea33
--- /dev/null
+++ b/ppsci/data/dataset/radar_dataset.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import os
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+
+import cv2
+import numpy as np
+import paddle
+from paddle import io
+
+
+class RadarDataset(io.Dataset):
+ """Class for Radar dataset.
+
+ Args:
+ input_keys (Tuple[str, ...]): Input keys, such as ("input",).
+ label_keys (Tuple[str, ...]): Output keys, such as ("output",).
+ image_width (int): Image width.
+ image_height (int): Image height.
+ total_length (int): Total length.
+ dataset_path (str): Dataset path.
+ data_type (str): Input and output data type. Defaults to paddle.get_default_dtype().
+ weight_dict (Optional[Dict[str, float]]): Weight dictionary. Defaults to None.
+
+ Examples:
+ >>> import ppsci
+ >>> dataset = ppsci.data.dataset.RadarDataset(
+ ... "input_keys": ("input",),
+ ... "label_keys": ("output",),
+ ... "image_width": 512,
+ ... "image_height": 512,
+ ... "total_length": 29,
+ ... "dataset_path": "datasets/mrms/figure",
+ ... "data_type": paddle.get_default_dtype(),
+ ... ) # doctest: +SKIP
+ """
+
+ def __init__(
+ self,
+ input_keys: Tuple[str, ...],
+ label_keys: Tuple[str, ...],
+ image_width: int,
+ image_height: int,
+ total_length: int,
+ dataset_path: str,
+ data_type: str = paddle.get_default_dtype(),
+ weight_dict: Optional[Dict[str, float]] = None,
+ ):
+ self.input_keys = input_keys
+ self.label_keys = label_keys
+ self.img_width = image_width
+ self.img_height = image_height
+ self.length = total_length
+ self.dataset_path = dataset_path
+ self.data_type = data_type
+
+ self.weight_dict = {} if weight_dict is None else weight_dict
+ if weight_dict is not None:
+ self.weight_dict = {key: 1.0 for key in self.label_keys}
+ self.weight_dict.update(weight_dict)
+
+ self.case_list = []
+ name_list = os.listdir(self.dataset_path)
+ name_list.sort()
+ for name in name_list:
+ case = []
+ for i in range(29):
+ case.append(
+ self.dataset_path
+ + "/"
+ + name
+ + "/"
+ + name
+ + "-"
+ + str(i).zfill(2)
+ + ".png"
+ )
+ self.case_list.append(case)
+
+ def load(self, index):
+ data = []
+ for img_path in self.case_list[index]:
+ img = cv2.imread(img_path, 2)
+ data.append(np.expand_dims(img, axis=0))
+ data = np.concatenate(data, axis=0).astype(self.data_type) / 10.0 - 3.0
+ assert data.shape[1] <= 1024 and data.shape[2] <= 1024
+ return data
+
+ def __getitem__(self, index):
+ data = self.load(index)[-self.length :].copy()
+ mask = np.ones_like(data)
+ mask[data < 0] = 0
+ data[data < 0] = 0
+ data = np.clip(data, 0, 128)
+ vid = np.zeros((self.length, self.img_height, self.img_width, 2))
+ vid[..., 0] = data
+ vid[..., 1] = mask
+
+ input_item = {self.input_keys[0]: vid}
+ label_item = {}
+ weight_item = {}
+ for key in self.label_keys:
+ label_item[key] = np.asarray([], paddle.get_default_dtype())
+ if len(label_item) > 0:
+ weight_shape = [1] * len(next(iter(label_item.values())).shape)
+ weight_item = {
+ key: np.full(weight_shape, value, paddle.get_default_dtype())
+ for key, value in self.weight_dict.items()
+ }
+ return input_item, label_item, weight_item
+
+ def __len__(self):
+ return len(self.case_list)
diff --git a/ppsci/visualize/__init__.py b/ppsci/visualize/__init__.py
index a0ea90e38..9082fbe95 100644
--- a/ppsci/visualize/__init__.py
+++ b/ppsci/visualize/__init__.py
@@ -26,6 +26,7 @@
from ppsci.visualize.visualizer import Visualizer2DPlot # isort:skip
from ppsci.visualize.visualizer import Visualizer3D # isort:skip
from ppsci.visualize.visualizer import VisualizerWeather # isort:skip
+from ppsci.visualize.radar import VisualizerRadar # isort:skip
from ppsci.visualize.vtu import save_vtu_from_dict # isort:skip
from ppsci.visualize.plot import save_plot_from_1d_dict # isort:skip
from ppsci.visualize.plot import save_plot_from_3d_dict # isort:skip
@@ -40,6 +41,7 @@
"Visualizer2DPlot",
"Visualizer3D",
"VisualizerWeather",
+ "VisualizerRadar",
"save_vtu_from_dict",
"save_vtu_to_mesh",
"save_plot_from_1d_dict",
diff --git a/ppsci/visualize/radar.py b/ppsci/visualize/radar.py
new file mode 100644
index 000000000..3d92ccfbe
--- /dev/null
+++ b/ppsci/visualize/radar.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import os
+from typing import Callable
+from typing import Dict
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from ppsci.visualize import base
+
+
+class VisualizerRadar(base.Visualizer):
+ """Visualizer for NowcastNet Radar Dataset.
+
+ Args:
+ input_dict (Dict[str, np.ndarray]): Input dict.
+ output_expr (Dict[str, Callable]): Output expression.
+ batch_size (int, optional): Batch size of data when computing result in visu.py. Defaults to 64.
+ num_timestamps (int, optional): Number of timestamps
+ prefix (str, optional): Prefix for output file.
+ case_type (str, optional): Case type.
+ total_length (str, optional): Total length.
+
+ Examples:
+ >>> import ppsci
+ >>> frames_tensor = paddle.randn([1, 29, 512, 512, 2])
+ >>> visualizer = ppsci.visualize.VisualizerRadar(
+ ... {"input": frames_tensor},
+ ... {"output": lambda out: out["output"]},
+ ... num_timestamps=1,
+ ... prefix="v_nowcastnet",
+ ... )
+ """
+
+ def __init__(
+ self,
+ input_dict: Dict[str, np.ndarray],
+ output_expr: Dict[str, Callable],
+ batch_size: int = 64,
+ num_timestamps: int = 1,
+ prefix: str = "vtu",
+ case_type: str = "normal",
+ total_length: int = 29,
+ ):
+ super().__init__(input_dict, output_expr, batch_size, num_timestamps, prefix)
+ self.case_type = case_type
+ self.total_length = total_length
+ self.input_dict = input_dict
+
+ def save(self, path, data_dict):
+ if not os.path.exists(path):
+ os.makedirs(path)
+ test_ims = self.input_dict[list(self.input_dict.keys())[0]]
+ # keys: {"input", "output"}
+ img_gen = data_dict[list(data_dict.keys())[1]]
+ vis_info = {"vmin": 1, "vmax": 40}
+ if self.case_type == "normal":
+ test_ims_plot = test_ims[0][
+ :-2, 256 - 192 : 256 + 192, 256 - 192 : 256 + 192
+ ]
+ img_gen_plot = img_gen[0][:-2, 256 - 192 : 256 + 192, 256 - 192 : 256 + 192]
+ else:
+ test_ims_plot = test_ims[0][:-2]
+ img_gen_plot = img_gen[0][:-2]
+ save_plots(
+ test_ims_plot,
+ labels=[f"gt{i + 1}" for i in range(self.total_length)],
+ res_path=path,
+ vmin=vis_info["vmin"],
+ vmax=vis_info["vmax"],
+ )
+ save_plots(
+ img_gen_plot,
+ labels=[f"pd{i + 1}" for i in range(9, self.total_length)],
+ res_path=path,
+ vmin=vis_info["vmin"],
+ vmax=vis_info["vmax"],
+ )
+
+
+def save_plots(
+ field,
+ labels,
+ res_path,
+ figsize=None,
+ vmin=0,
+ vmax=10,
+ cmap="viridis",
+ npy=False,
+ **imshow_args,
+):
+ for i, data in enumerate(field):
+ if i >= len(labels):
+ break
+ plt.figure(figsize=figsize)
+ ax = plt.axes()
+ ax.set_axis_off()
+ alpha = data[..., 0] / 1
+ alpha[alpha < 1] = 0
+ alpha[alpha > 1] = 1
+ ax.imshow(
+ data[..., 0], alpha=alpha, vmin=vmin, vmax=vmax, cmap=cmap, **imshow_args
+ )
+ plt.savefig(os.path.join(res_path, labels[i] + ".png"))
+ plt.close()
+ if npy:
+ with open(os.path.join(res_path, labels[i] + ".npy"), "wb") as f:
+ np.save(f, data[..., 0])
diff --git a/requirements.txt b/requirements.txt
index 0dbdb4dc0..de5751ef6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -17,3 +17,4 @@ typing-extensions
seaborn
colorlog
hydra-core
+opencv-python