Merge pull request #70 from CosmoStat/refactor_plotting
Refactor plotting
sfarrens authored Sep 29, 2023
2 parents 07367b1 + 77011c6 commit 2458a7f
Showing 61 changed files with 2,087 additions and 1,020 deletions.
2 changes: 0 additions & 2 deletions config/configs.yaml
@@ -1,4 +1,2 @@
---
data_conf: data_config.yaml
training_conf: training_config.yaml
metrics_conf: metrics_config.yaml
15 changes: 0 additions & 15 deletions config/features_config.yaml

This file was deleted.

24 changes: 11 additions & 13 deletions config/metrics_config.yaml
@@ -1,17 +1,15 @@
metrics:
# Set flag to evaluate model weights; if False, then final updated model will be evaluated
use_callback: False
# Choose the training cycle for the evaluation. Can be: 1, 2, ...
saved_training_cycle: 1
# Provide path to saved model checkpoints.
chkp_save_path: checkpoint
# Fill model_params if computing metrics_only on a pre-trained model
# ID name
id_name: -coherent_euclid_200stars
# Path to Trained Model
trained_model_path: /Users/jenniferpollack/Projects/wf-outputs/wf-outputs-202305262344/
# Name of Trained Model Config file
trained_model_config: config/training_config.yaml
# Specify the type of model weights to load by entering "psf_model" to load weights of final psf model or "checkpoint" to load weights from a checkpoint callback.
model_save_path: <enter psf_model or checkpoint>
# Choose the training cycle for which to evaluate the psf_model. Can be: 1, 2, ...
saved_training_cycle: 2
# Metrics-only run: Specify model_params for a pre-trained model else leave blank if running training + metrics
# Specify path to Trained Model
trained_model_path: </path/to/trained/model>
# Path to Trained Model Config file inside /trained_model_path/ parent directory
trained_model_config: </path/to/trained/model>
# Name of Plotting Config file - enter the name of a yaml file to run plot metrics, or leave empty to run metrics evaluation only
plotting_config: <enter name of plotting_config.yaml file or leave empty>
# Evaluate the monochromatic RMSE metric.
eval_mono_metric_rmse: True
# Evaluate the OPD RMSE metric.
12 changes: 12 additions & 0 deletions config/plotting_config.yaml
@@ -0,0 +1,12 @@
plotting_params:
# Specify path to parent folder containing wf-psf metrics outputs for all runs, ex: $WORK/wf-outputs/
metrics_output_path: <PATH>
# List directory(s) for metrics of trained PSF models to include in the plot
metrics_dir:
# - wf-outputs-xxxxxxxxxxx1
# - wf-outputs-xxxxxxxxxxx2
# List the names of metrics config files to add to the plot (would like to change so that the code finds them in metrics_dir)
metrics_config:
# - metrics_config_1.yaml
# - metrics_config_2.yaml
plot_show: False
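For orientation only (not part of this commit): the new plotting_params block can be read with PyYAML before being handed to the plotting code. The keys and file path below follow the config shown above; the use of yaml.safe_load is an assumption, as wf-psf's own config reader may differ.

import yaml  # assumes PyYAML is available

with open("config/plotting_config.yaml") as f:
    plotting_params = yaml.safe_load(f)["plotting_params"]

# metrics_dir may be left commented out, in which case it loads as None.
for run_dir in plotting_params["metrics_dir"] or []:
    print(plotting_params["metrics_output_path"], run_dir)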
7 changes: 7 additions & 0 deletions config/training_config.yaml
@@ -1,6 +1,10 @@
training:
# ID name
id_name: -coherent_euclid_200stars
# Name of Data Config file
data_config: data_config.yaml
# Metrics Config file - enter a file to run metrics evaluation, or leave empty to run training only
metrics_config: metrics_config.yaml
model_params:
# Model type. Options are: 'mccd', 'graph', 'poly', 'param', 'poly_physical'.
model_name: poly
@@ -46,6 +50,9 @@ training:

# Hyperparameters for Parametric model
param_hparams:
# Random seed for TensorFlow initialization
random_seed: 3877572

# Parameter for the l2 loss function for the Optical path differences (OPD)/WFE
l2_param: 0.

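For context, a seed such as the random_seed added above is typically applied through TensorFlow's global seeding API; the snippet below is a generic illustration, not code from this commit.

import tensorflow as tf

# Hypothetical use of the configured value; the actual call site in wf-psf is not shown in this diff.
tf.random.set_seed(3877572)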
83 changes: 83 additions & 0 deletions pyproject.toml
@@ -0,0 +1,83 @@
[project]
name = "wf_psf"
requires-python = ">=3.9"
authors = [
{ "name" = "Tobias Liaudat", "email" = "[email protected]"},
{ "name" = "Jennifer Pollack", "email" = "[email protected]"},
]
maintainers = [
{ "name" = "Jennifer Pollack", "email" = "[email protected]" },
]

description = 'A software framework to perform Differentiable wavefront-based PSF modelling.'
dependencies = [
"numpy>=1.19.2",
"scipy>=1.11.2",
"tensorflow>=2.9.1",
"tensorflow-addons>=0.12.1",
"tensorflow-estimator>=2.9.0",
"zernike==0.0.31",
"opencv-python>=4.5.1.48",
"pillow>=9.5.0",
"galsim>=2.4.11",
"astropy>=5.3.3",
"matplotlib>=3.3.2",
"seaborn>=0.12.2",
]
version = "1.0.0"

[project.optional-dependencies]
docs = [
"importlib_metadata",
"myst-parser",
"numpydoc",
"sphinx",
"sphinxcontrib-bibtex",
"sphinxawesome-theme",
"sphinx-gallery",
]

lint = [
"black",
]

release = [
"build",
"twine",
]

test = [
"pytest",
"pytest-black",
"pytest-cases",
"pytest-cov",
"pytest-emoji",
"pytest-pydocstyle",
"pytest-raises",
"pytest-xdist",
]

# Install for development
dev = ["wf_psf[docs,lint,release,test]"]

[project.scripts]
wavediff = "wf_psf.run:mainMethod"

[tool.black]
line-length = 88

[tool.pydocstyle]
convention = "numpy"

[tool.pytest.ini_options]
addopts = [
"--verbose",
"--black",
"--emoji",
"--pydocstyle",
"--cov=wf_psf",
"--cov-report=term-missing",
"--cov-report=xml",
"--junitxml=pytest.xml",
]
testpaths = ["src/wf_psf"]
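With this pyproject.toml, a development install can pull in every optional group through the dev extra (for example, pip install -e ".[dev]"). The wavediff console script declared under [project.scripts] resolves to the mainMethod function in wf_psf.run, so running it is roughly equivalent to the sketch below, shown for orientation only.

# Rough equivalent of the wavediff entry point defined above.
from wf_psf.run import mainMethod

if __name__ == "__main__":
    mainMethod()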
25 changes: 0 additions & 25 deletions setup.py

This file was deleted.

File renamed without changes.
File renamed without changes.
@@ -9,7 +9,7 @@
import wf_psf.utils.utils as utils
import tensorflow as tf
import tensorflow_addons as tfa
import wf_psf.SimPSFToolkit as SimPSFToolkit
import wf_psf.sims.SimPSFToolkit as SimPSFToolkit
import os


File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -5,7 +5,7 @@
import wf_psf.utils.utils as utils
from wf_psf.psf_models.tf_psf_field import build_PSF_model
from wf_psf.psf_models import tf_psf_field as psf_field
from wf_psf import SimPSFToolkit as SimPSFToolkit
from wf_psf.sims import SimPSFToolkit as SimPSFToolkit
import logging

logger = logging.getLogger(__name__)
@@ -24,6 +24,20 @@
logger = logging.getLogger(__name__)


def ground_truth_psf_model(metrics_params, coeff_matrix):
psf_model = psf_models.get_psf_model(
metrics_params.ground_truth_model.model_params,
metrics_params.metrics_hparams.batch_size,
)
psf_model.tf_poly_Z_field.assign_coeff_matrix(coeff_matrix)

psf_model.tf_np_poly_opd.alpha_mat.assign(
np.zeros_like(psf_model.tf_np_poly_opd.alpha_mat) # type: ignore
)

return psf_model


class MetricsParamsHandler:
"""Metrics Parameters Handler.
@@ -43,15 +57,6 @@ def __init__(self, metrics_params, trained_model):
self.metrics_params = metrics_params
self.trained_model = trained_model

@property
def ground_truth_psf_model(self):
psf_model = psf_models.get_psf_model(
self.metrics_params.ground_truth_model.model_params,
self.metrics_params.metrics_hparams.batch_size,
)
psf_model.set_zero_nonparam()
return psf_model

def evaluate_metrics_polychromatic_lowres(self, psf_model, simPSF, dataset):
"""Evaluate Polychromatic PSF Low-Res Metrics.
@@ -74,9 +79,12 @@ def evaluate_metrics_polychromatic_lowres(self, psf_model, simPSF, dataset):
"""
logger.info("Computing polychromatic metrics at low resolution.")

rmse, rel_rmse, std_rmse, std_rel_rmse = wf_metrics.compute_poly_metric(
tf_semiparam_field=psf_model,
GT_tf_semiparam_field=self.ground_truth_psf_model,
GT_tf_semiparam_field=ground_truth_psf_model(
self.metrics_params, dataset["C_poly"]
),
simPSF_np=simPSF,
tf_pos=dataset["positions"],
tf_SEDs=dataset["SEDs"],
@@ -124,7 +132,9 @@ def evaluate_metrics_mono_rmse(self, psf_model, simPSF, dataset):
std_rel_rmse_lda,
) = wf_metrics.compute_mono_metric(
tf_semiparam_field=psf_model,
GT_tf_semiparam_field=self.ground_truth_psf_model,
GT_tf_semiparam_field=ground_truth_psf_model(
self.metrics_params, dataset["C_poly"]
),
simPSF_np=simPSF,
tf_pos=dataset["positions"],
lambda_list=lambda_list,
@@ -167,7 +177,9 @@ def evaluate_metrics_opd(self, psf_model, simPSF, dataset):
rel_rmse_std_opd,
) = wf_metrics.compute_opd_metrics(
tf_semiparam_field=psf_model,
GT_tf_semiparam_field=self.ground_truth_psf_model,
GT_tf_semiparam_field=ground_truth_psf_model(
self.metrics_params, dataset["C_poly"]
),
pos=dataset["positions"],
batch_size=self.metrics_params.metrics_hparams.batch_size,
)
@@ -204,7 +216,9 @@ def evaluate_metrics_shape(self, psf_model, simPSF, dataset):

shape_results = wf_metrics.compute_shape_metrics(
tf_semiparam_field=psf_model,
GT_tf_semiparam_field=self.ground_truth_psf_model,
GT_tf_semiparam_field=ground_truth_psf_model(
self.metrics_params, dataset["C_poly"]
),
simPSF_np=simPSF,
SEDs=dataset["SEDs"],
tf_pos=dataset["positions"],
@@ -259,6 +273,7 @@ def evaluate_model(
# Get training data
logger.info(f"Fetching and preprocessing training and test data...")

# Initialize metrics_handler
metrics_handler = MetricsParamsHandler(metrics_params, trained_model_params)

## Prepare models
@@ -267,6 +282,7 @@

## Load the model's weights
try:
logger.info("Loading PSF model weights from {}".format(weights_path))
psf_model.load_weights(weights_path)
except:
logger.exception("An error occurred with the weights_path file.")
@@ -325,15 +341,15 @@ def evaluate_model(
psf_model, simPSF_np, training_data.train_dataset
)
else:
trained_monoe_metric = None
train_mono_metric = None

# OPD metrics turn into a class
if metrics_params.eval_opd_metric_rmse:
train_opd_metric = metrics_handler.evaluate_metrics_opd(
psf_model, simPSF_np, training_data.train_dataset
)
else:
trained_opd_metric = None
train_opd_metric = None

# Shape metrics turn into a class
if metrics_params.eval_train_shape_sr_metric_rmse:
@@ -354,7 +370,7 @@ def evaluate_model(
## Save results
metrics = {"test_metrics": test_metrics, "train_metrics": train_metrics}
run_id_name = (
trained_model_params.model_params.model_name + metrics_params.id_name
trained_model_params.model_params.model_name + trained_model_params.id_name
)
output_path = metrics_output + "/" + "metrics-" + run_id_name
np.save(output_path, metrics, allow_pickle=True)
Expand All @@ -366,6 +382,7 @@ def evaluate_model(
## Close log file
print("\n Good bye..")

return metrics
except Exception as e:
print("Error: %s" % e)
raise
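Taken together, these hunks replace the ground_truth_psf_model property on MetricsParamsHandler with a module-level helper that receives the dataset's polynomial coefficient matrix explicitly. The call pattern, as it appears in the updated hunks, is simply:

# Illustrative call pattern after the refactor; "dataset" stands in for a train or test dataset dict.
gt_model = ground_truth_psf_model(metrics_params, dataset["C_poly"])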
File renamed without changes.
