Skip to content

Commit

Permalink
【PPSCI Export&Infer No.9】Bubble (PaddlePaddle#887)
Browse files Browse the repository at this point in the history
* 【PPSCI Export&Infer No.9】

* update examples/bubble/conf/bubble.yaml

* fix codestyle bugs

* Update examples/bubble/bubble.py

* update examples/bubble/bubble.py

---------

Co-authored-by: HydrogenSulfate <[email protected]>
  • Loading branch information
wufei2 and HydrogenSulfate authored May 12, 2024
1 parent 60a6369 commit 7c04bf8
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 1 deletion.
16 changes: 16 additions & 0 deletions docs/zh/examples/bubble.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,22 @@
python bubble.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/bubble/bubble_pretrained.pdparams
```

=== "模型导出命令"

``` sh
python bubble.py mode=export
```

=== "模型推理命令"

``` sh
# linux
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/BubbleNet/bubble.mat
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/BubbleNet/bubble.mat --output bubble.mat
python bubble.py mode=infer
```

| 预训练模型 | 指标 |
|:--| :--|
| [bubble_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/bubble/bubble_pretrained.pdparams) | loss(bubble_mse): 0.00558<br>MSE.u(bubble_mse): 0.00090<br>MSE.v(bubble_mse): 0.00322<br>MSE.p(bubble_mse): 0.00066<br>MSE.phil(bubble_mse): 0.00079 |
Expand Down
106 changes: 105 additions & 1 deletion examples/bubble/bubble.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,14 +403,118 @@ def transform_out(in_, out):
)


def export(cfg: DictConfig):
# set model
model_psi = ppsci.arch.MLP(**cfg.MODEL.psi_net)
model_p = ppsci.arch.MLP(**cfg.MODEL.p_net)
model_phil = ppsci.arch.MLP(**cfg.MODEL.phil_net)

# transform
def transform_out(in_, out):
psi_y = out["psi"]
y = in_["y"]
x = in_["x"]
u = jacobian(psi_y, y, create_graph=False)
v = -jacobian(psi_y, x, create_graph=False)
return {"u": u, "v": v}

# register transform
model_psi.register_output_transform(transform_out)
model_list = ppsci.arch.ModelList((model_psi, model_p, model_phil))

# initialize solver
solver = ppsci.solver.Solver(
model_list,
pretrained_model_path=cfg.INFER.pretrained_model_path,
)
# export model
from paddle.static import InputSpec

input_spec = [
{
key: InputSpec([None, 1], "float32", name=key)
for key in model_list.input_keys
},
]
solver.export(input_spec, cfg.INFER.export_path)


def inference(cfg: DictConfig):
# load Data
data = scipy.io.loadmat(cfg.DATA_PATH)
# normalize data
p_max = data["p"].max(axis=0)
p_min = data["p"].min(axis=0)
u_max = data["u"].max(axis=0)
u_min = data["u"].min(axis=0)
v_max = data["v"].max(axis=0)
v_min = data["v"].min(axis=0)

from deploy.python_infer import pinn_predictor

predictor = pinn_predictor.PINNPredictor(cfg)
# set time-geometry
timestamps = np.linspace(0, 126, 127, endpoint=True)
geom = {
"time_rect_visu": ppsci.geometry.TimeXGeometry(
ppsci.geometry.TimeDomain(1, 126, timestamps=timestamps),
ppsci.geometry.Rectangle((0, 0), (15, 5)),
),
}
NTIME_ALL = len(timestamps)
NPOINT_PDE, NTIME_PDE = 300 * 100, NTIME_ALL - 1
input_dict = geom["time_rect_visu"].sample_interior(
NPOINT_PDE * NTIME_PDE, evenly=True
)
output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)

# mapping data to cfg.INFER.output_keys
output_dict = {
store_key: output_dict[infer_key]
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
}

# inverse normalization
p_pred = output_dict["p"].reshape([NTIME_PDE, NPOINT_PDE]).T
u_pred = output_dict["u"].reshape([NTIME_PDE, NPOINT_PDE]).T
v_pred = output_dict["v"].reshape([NTIME_PDE, NPOINT_PDE]).T
pred = {
"p": (p_pred * (p_max - p_min) + p_min).T.reshape([-1, 1]),
"u": (u_pred * (u_max - u_min) + u_min).T.reshape([-1, 1]),
"v": (v_pred * (v_max - v_min) + v_min).T.reshape([-1, 1]),
"phil": output_dict["phil"],
}
ppsci.visualize.save_vtu_from_dict(
"./visual/bubble_pred.vtu",
{
"t": input_dict["t"],
"x": input_dict["x"],
"y": input_dict["y"],
"u": pred["u"],
"v": pred["v"],
"p": pred["p"],
"phil": pred["phil"],
},
("t", "x", "y"),
("u", "v", "p", "phil"),
NTIME_PDE,
)


@hydra.main(version_base=None, config_path="./conf", config_name="bubble.yaml")
def main(cfg: DictConfig):
if cfg.mode == "train":
train(cfg)
elif cfg.mode == "eval":
evaluate(cfg)
elif cfg.mode == "export":
export(cfg)
elif cfg.mode == "infer":
inference(cfg)
else:
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
raise ValueError(
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
)


if __name__ == "__main__":
Expand Down
19 changes: 19 additions & 0 deletions examples/bubble/conf/bubble.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ MODEL:
num_layers: 9
hidden_size: 30
activation: "tanh"
output_keys: ["u", "v", "p", "phil"]

# training settings
TRAIN:
Expand All @@ -65,3 +66,21 @@ TRAIN:
EVAL:
pretrained_model_path: null
eval_with_no_grad: true

# inference settings
INFER:
pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/bubble/bubble_pretrained.pdparams
export_path: ./inference/bubble
pdmodel_path: ${INFER.export_path}.pdmodel
pdiparams_path: ${INFER.export_path}.pdiparams
onnx_path: ${INFER.export_path}.onnx
device: gpu
engine: native
precision: fp32
ir_optim: true
min_subgraph_size: 5
gpu_mem: 2000
gpu_id: 0
max_batch_size: 8192
num_cpu_threads: 10
batch_size: 8192

0 comments on commit 7c04bf8

Please sign in to comment.