-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #115 from aai-institute/feature/dvc
Feature: dvc
- Loading branch information
Showing
36 changed files
with
493 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
*.png filter=lfs diff=lfs merge=lfs -text |
File renamed without changes.
File renamed without changes
File renamed without changes
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import torch | ||
import matplotlib.pyplot as plt | ||
from continuity.benchmarks import NavierStokes | ||
from continuity.operators import FourierNeuralOperator | ||
|
||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
|
||
ns = NavierStokes() | ||
|
||
operator = FourierNeuralOperator( | ||
ns.train_dataset.shapes, | ||
grid_shape=(64, 64, 10), | ||
width=32, | ||
depth=4, | ||
device=device, | ||
) | ||
|
||
operator.load( | ||
"mlruns/271016623891034109/8755b17d3af9494db843e3a8d0c42ad6/artifacts/final.pt" | ||
) | ||
operator.eval() | ||
|
||
# Compute train loss | ||
loss_fn = ns.losses[0] | ||
|
||
|
||
def compute_loss(dataset): | ||
train_loader = torch.utils.data.DataLoader(dataset, batch_size=1) | ||
avg_loss = 0 | ||
max_loss, min_loss = 0, 1e10 | ||
max_i, min_i = 0, 0 | ||
for i, xuyv in enumerate(train_loader): | ||
x, u, y, v = [t.to(device) for t in xuyv] | ||
loss = loss_fn(operator, x, u, y, v) | ||
avg_loss += loss.detach() | ||
if loss > max_loss: | ||
max_loss = loss | ||
max_i = i | ||
if loss < min_loss: | ||
min_loss = loss | ||
min_i = i | ||
avg_loss = avg_loss / len(train_loader) | ||
return avg_loss, max_loss, max_i, min_loss, min_i | ||
|
||
|
||
loss_train, max_loss, max_i_train, min_loss, min_i_train = compute_loss( | ||
ns.train_dataset | ||
) | ||
print(f"rel. error train = {loss_train:.4e}") | ||
print(f"min loss = {min_loss:.4e} at index {min_i_train}") | ||
print(f"max loss = {max_loss:.4e} at index {max_i_train}") | ||
|
||
# Compute test loss | ||
loss_test, max_loss, max_i_test, min_loss, min_i_test = compute_loss(ns.test_dataset) | ||
print(f"rel. error test = {loss_test:.4e}") | ||
print(f"min loss = {min_loss:.4e} at index {min_i_test}") | ||
print(f"max loss = {max_loss:.4e} at index {max_i_test}") | ||
|
||
|
||
# Plot | ||
def plot_sample(split, sample): | ||
dataset = ns.train_dataset if split == "train" else ns.test_dataset | ||
x, u, y, v = [t.to(device) for t in dataset[sample : sample + 1]] | ||
v_pred = operator(x, u, y) | ||
v = v.reshape(1, 64, 64, 10, 1).cpu() | ||
v_pred = v_pred.reshape(1, 64, 64, 10, 1).detach().cpu() | ||
|
||
fig, axs = plt.subplots(10, 3, figsize=(4, 16)) | ||
|
||
axs[0][0].set_title("Truth") | ||
axs[0][1].set_title("Prediction") | ||
axs[0][2].set_title("Error") | ||
for t in range(10): | ||
axs[t][0].imshow(v[0, :, :, t, 0], cmap="jet") | ||
axs[t][1].imshow(v_pred[0, :, :, t, 0], cmap="jet") | ||
im = axs[t][2].imshow((v - v_pred)[0, :, :, t, 0], cmap="jet") | ||
fig.colorbar(im, ax=axs[t][2]) | ||
axs[t][0].axis("off") | ||
axs[t][1].axis("off") | ||
axs[t][2].axis("off") | ||
|
||
plt.tight_layout() | ||
plt.savefig(f"navierstokes/ns_{split}_{sample}.png", dpi=500) | ||
|
||
|
||
plot_sample("train", min_i_train) | ||
plot_sample("train", max_i_train) | ||
plot_sample("test", min_i_test) | ||
plot_sample("test", max_i_test) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
rel. error train = 1.8508e-02 | ||
min loss = 8.8748e-03 at index 237 | ||
max loss = 3.1433e-02 at index 420 | ||
rel. error test = 1.8408e-01 | ||
min loss = 1.0220e-01 at index 144 | ||
max loss = 4.4655e-01 at index 179 | ||
|
||
-- | ||
|
||
Reference: _Li, Zongyi, et al. "Fourier neural operator for parametric partial | ||
differential equations." arXiv preprint arXiv:2010.08895 (2020)_ | ||
|
||
Table 1: nu=1e−5 T=20 N=1000 | ||
FNO-3D: 0.1893 | ||
|
||
(Ours: 1.8408e-01) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from functools import partial | ||
from continuity.benchmarks.run import BenchmarkRunner, RunConfig | ||
from continuity.benchmarks import NavierStokes | ||
from continuity.operators import FourierNeuralOperator | ||
|
||
config = RunConfig( | ||
benchmark_factory=NavierStokes, | ||
operator_factory=partial( | ||
FourierNeuralOperator, | ||
grid_shape=(64, 64, 10), | ||
width=32, | ||
depth=4, | ||
), | ||
lr=1e-3, | ||
max_epochs=100, | ||
batch_size=10, | ||
) | ||
|
||
if __name__ == "__main__": | ||
BenchmarkRunner.run(config) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
/config.local | ||
/tmp | ||
/cache |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
[core] | ||
remote = gdrive | ||
['remote "gdrive"'] | ||
url = gdrive://1Mts9tmPjKzqw-as_j1XTRwa9ulNssy5a |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Add patterns of files dvc should ignore, which could improve | ||
# the performance. Learn more at | ||
# https://dvc.org/doc/user-guide/dvcignore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
/flame | ||
/navierstokes |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Data | ||
|
||
## Prerequisites | ||
|
||
We use `dvc` to manage the data. You can install the required packages by | ||
installing the benchmarks requirements. | ||
|
||
``` | ||
pip install -e .[benchmarks] | ||
``` | ||
|
||
## Downloading the data | ||
|
||
The data is stored in a remote storage on GDrive. | ||
To download the data, you can run: | ||
|
||
``` | ||
cd data | ||
dvc pull <NAME> | ||
``` | ||
|
||
where `<NAME>` is the name of the data set you want to download, | ||
e.g., `flame` or `navierstokes`, or empty. | ||
|
||
|
||
## Data sets | ||
|
||
### FLAME | ||
|
||
`data/flame` contains the dataset from [2023 FLAME AI | ||
Challenge](https://www.kaggle.com/competitions/2023-flame-ai-challenge/data). | ||
|
||
### Navier-Stokes | ||
|
||
`data/navierstokes` contains a part of the dataset linked in | ||
[neuraloperator/graph-pde](https://github.com/neuraloperator/graph-pde) | ||
(Zongyi Li et al. 2020). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
outs: | ||
- md5: 2e61c8311b09a4fdf29d3ec3527cf629.dir | ||
size: 415040265 | ||
nfiles: 13138 | ||
path: flame |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
outs: | ||
- md5: 0bb228674e976bab14e9493606e14a27.dir | ||
size: 412877192 | ||
nfiles: 1 | ||
hash: md5 | ||
path: navierstokes |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
table.html | ||
style.css | ||
img | ||
*.png | ||
*.svg |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.