
Merge pull request #111 from CosmoStat/bug_plotting
Bug plotting missing Shape Metrics e2 and R2
jeipollack authored Mar 6, 2024
2 parents 478aed1 + 59ff07b commit 8186939
Showing 3 changed files with 141 additions and 86 deletions.
217 changes: 136 additions & 81 deletions src/wf_psf/plotting/plots_interface.py
@@ -49,6 +49,7 @@ def make_plot(
x_axis,
y_axis,
y_axis_err,
y2_axis,
label,
plot_title,
x_axis_label,
@@ -69,6 +70,8 @@
y-axis values
y_axis_err: list
Error values for y-axis points
y2_axis: list
y2-axis values for right axis
label: str
Label for the points
plot_title: str
@@ -114,7 +117,7 @@
kwargs = dict(
linewidth=2, linestyle="dashed", markersize=4, marker="^", alpha=0.5
)
ax2.plot(x_axis[it], y_axis[it][k], **kwargs)
ax2.plot(x_axis[it], y2_axis[it][k], **kwargs)

plt.savefig(filename)
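
For context (not part of the diff): make_plot's full body is elided above, but the new y2_axis values are drawn on a secondary right-hand axis. A minimal sketch of that dual-axis matplotlib pattern, assuming the second axis is created with ax.twinx(); the function and variable names below are illustrative, not taken from the file:

import matplotlib.pyplot as plt

def dual_axis_sketch(x, y, y_err, y2, label):
    # Left axis: absolute error with error bars; right axis: relative error.
    fig, ax = plt.subplots()
    ax2 = ax.twinx()  # secondary y-axis sharing the same x values
    ax.errorbar(x, y, yerr=y_err, label=label)
    ax2.plot(x, y2, linewidth=2, linestyle="dashed", markersize=4, marker="^", alpha=0.5)
    ax.set_ylabel("Absolute error")
    ax2.set_ylabel("Relative error [%]")
    return fig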

@@ -158,6 +161,7 @@ def __init__(
metric_name,
rmse,
std_rmse,
rel_rmse,
plot_title,
plots_dir,
):
@@ -166,6 +170,7 @@
self.metric_name = metric_name
self.rmse = rmse
self.std_rmse = std_rmse
self.rel_rmse = rel_rmse
self.plot_title = plot_title
self.plots_dir = plots_dir
self.list_of_stars = list_of_stars
@@ -189,6 +194,7 @@ def get_metrics(self, dataset):
"""
rmse = []
std_rmse = []
rel_rmse = []
metrics_id = []
for k, v in self.metrics.items():
for metrics_data in v:
@@ -210,7 +216,15 @@
}
)

return metrics_id, rmse, std_rmse
rel_rmse.append(
{
(k + "-" + run_id): metrics_data[run_id][0][dataset][
self.metric_name
][self.rel_rmse]
}
)

return metrics_id, rmse, std_rmse, rel_rmse

def plot(self):
"""Plot.
@@ -220,11 +234,12 @@
"""
for plot_dataset in ["test_metrics", "train_metrics"]:
metrics_id, rmse, std_rmse = self.get_metrics(plot_dataset)
metrics_id, rmse, std_rmse, rel_rmse = self.get_metrics(plot_dataset)
make_plot(
x_axis=self.list_of_stars,
y_axis=rmse,
y_axis_err=std_rmse,
y2_axis=rel_rmse,
label=metrics_id,
plot_title="Stars " + plot_dataset + self.plot_title,
x_axis_label="Number of stars",
@@ -295,6 +310,7 @@ def plot(self):
for plot_dataset in ["test_metrics", "train_metrics"]:
y_axis = []
y_axis_err = []
y2_axis = []
metrics_id = []

for k, v in self.metrics.items():
@@ -316,11 +332,19 @@
]["mono_metric"]["std_rmse_lda"]
}
)
y2_axis.append(
{
(k + "-" + run_id): metrics_data[run_id][0][
plot_dataset
]["mono_metric"]["rel_rmse_lda"]
}
)

make_plot(
x_axis=[lambda_list for _ in range(len(y_axis))],
y_axis=y_axis,
y_axis_err=y_axis_err,
y2_axis=y2_axis,
label=metrics_id,
plot_title="Stars "
+ plot_dataset # type: ignore
@@ -343,10 +367,8 @@ def plot(self):

class ShapeMetricsPlotHandler:
"""ShapeMetricsPlotHandler class.
A class to handle plot parameters for shape
metrics results.
Parameters
----------
id: str
Expand All @@ -359,7 +381,6 @@ class ShapeMetricsPlotHandler:
List containing the number of stars used for each training data set
plots_dir: str
Output directory for metrics plots
"""

id = "shape_metrics"
@@ -373,96 +394,126 @@ def __init__(self, plotting_params, metrics, list_of_stars, plots_dir):
def plot(self):
"""Plot.
A function to generate plots for the train and test
A generic function to generate plots for the train and test
metrics.
"""
# Define common data
# Common data
e1_req_euclid = 2e-04
e2_req_euclid = 2e-04
R2_req_euclid = 1e-03

for plot_dataset in ["test_metrics", "train_metrics"]:
e1_rmse = []
e1_std_rmse = []
e2_rmse = []
e2_std_rmse = []
rmse_R2_meanR2 = []
std_rmse_R2_meanR2 = []
metrics_id = []
metrics_data = self.prepare_metrics_data(
plot_dataset, e1_req_euclid, e2_req_euclid, R2_req_euclid
)

for k, v in self.metrics.items():
for metrics_data in v:
run_id = list(metrics_data.keys())[0]
metrics_id.append(run_id + "-" + k)

e1_rmse.append(
{
(k + "-" + run_id): metrics_data[run_id][0][plot_dataset][
"shape_results_dict"
]["rmse_e1"]
/ e1_req_euclid
}
)
e1_std_rmse.append(
{
(k + "-" + run_id): metrics_data[run_id][0][plot_dataset][
"shape_results_dict"
]["std_rmse_e1"]
}
# Plot for e1
for k, v in metrics_data.items():
self.make_shape_metrics_plot(
metrics_data[k]["rmse"],
metrics_data[k]["std_rmse"],
metrics_data[k]["rel_rmse"],
plot_dataset,
k,
)

def prepare_metrics_data(
self, plot_dataset, e1_req_euclid, e2_req_euclid, R2_req_euclid
):
"""Prepare Metrics Data.
A function to prepare the metrics data for plotting.
Parameters
----------
plot_dataset: str
A string representing the dataset, i.e. training or test metrics.
e1_req_euclid: float
A float denoting the Euclid requirement for the `e1` shape metric.
e2_req_euclid: float
A float denoting the Euclid requirement for the `e2` shape metric.
R2_req_euclid: float
A float denoting the Euclid requirement for the `R2` shape metric.
Returns
-------
shape_metrics_data: dict
A dictionary containing the shape metrics data from a set of runs.
"""
shape_metrics_data = {
"e1": {"rmse": [], "std_rmse": [], "rel_rmse": []},
"e2": {"rmse": [], "std_rmse": [], "rel_rmse": []},
"R2_meanR2": {"rmse": [], "std_rmse": [], "rel_rmse": []},
}

for k, v in self.metrics.items():
for metrics_data in v:
run_id = list(metrics_data.keys())[0]

for metric in ["e1", "e2", "R2_meanR2"]:
metric_rmse = metrics_data[run_id][0][plot_dataset][
"shape_results_dict"
][f"rmse_{metric}"]
metric_std_rmse = metrics_data[run_id][0][plot_dataset][
"shape_results_dict"
][f"std_rmse_{metric}"]

relative_metric_rmse = metric_rmse / (
e1_req_euclid
if metric == "e1"
else (e2_req_euclid if metric == "e2" else R2_req_euclid)
)

e2_rmse.append(
{
(k + "-" + run_id): metrics_data[run_id][0][plot_dataset][
"shape_results_dict"
]["rmse_e2"]
/ e2_req_euclid
}
shape_metrics_data[metric]["rmse"].append(
{f"{k}-{run_id}": metric_rmse}
)
e2_std_rmse.append(
{
(k + "-" + run_id): metrics_data[run_id][0][plot_dataset][
"shape_results_dict"
]["std_rmse_e2"]
}
shape_metrics_data[metric]["std_rmse"].append(
{f"{k}-{run_id}": metric_std_rmse}
)

rmse_R2_meanR2.append(
{
(k + "-" + run_id): metrics_data[run_id][0][plot_dataset][
"shape_results_dict"
]["rmse_R2_meanR2"]
/ R2_req_euclid
}
shape_metrics_data[metric]["rel_rmse"].append(
{f"{k}-{run_id}": relative_metric_rmse}
)

std_rmse_R2_meanR2.append(
{
(k + "-" + run_id): metrics_data[run_id][0][plot_dataset][
"shape_results_dict"
]["std_rmse_R2_meanR2"]
}
)
return shape_metrics_data

make_plot(
x_axis=self.list_of_stars,
y_axis=e1_rmse,
y_axis_err=e1_std_rmse,
label=metrics_id,
plot_title="Stars " + plot_dataset + ".\nShape RMSE",
x_axis_label="Number of stars",
y_left_axis_label="Absolute error",
y_right_axis_label="Relative error [%]",
filename=os.path.join(
self.plots_dir,
plot_dataset
+ "_nstars_"
+ "_".join(str(nstar) for nstar in self.list_of_stars)
+ "_Shape_RMSE.png",
),
plot_show=self.plotting_params.plot_show,
)
def make_shape_metrics_plot(
self, rmse_data, std_rmse_data, rel_rmse_data, plot_dataset, metric
):
"""Make Shape Metrics Plot.
A function to produce plots for the shape metrics.
Parameters
----------
rmse_data: list
A list of dictionaries where each dictionary stores the run as the key and the Root Mean Square Error (RMSE) as the value.
std_rmse_data: list
A list of dictionaries where each dictionary stores run as the key and the Standard Deviation of the Root Mean Square Error (rmse) as the value.
rel_rmse_data: list
A list of dictionaries where each dictionary stores run as the key and the Root Mean Square Error (rmse) relative to the Euclid requirements as the value.
plot_dataset: str
A string denoting whether metrics are for the train or test datasets.
metric: str
A string representing the type of shape metric, i.e., e1, e2, or R2.
"""
make_plot(
x_axis=self.list_of_stars,
y_axis=rmse_data,
y_axis_err=std_rmse_data,
y2_axis=rel_rmse_data,
label=[key for item in rmse_data for key in item],
plot_title=f"Stars {plot_dataset}. Shape {metric.upper()} RMSE",
x_axis_label="Number of stars",
y_left_axis_label="Absolute error",
y_right_axis_label="Relative error [%]",
filename=os.path.join(
self.plots_dir,
f"{plot_dataset}_nstars_{'_'.join(str(nstar) for nstar in self.list_of_stars)}_Shape_{metric.upper()}_RMSE.png",
),
plot_show=self.plotting_params.plot_show,
)
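
For orientation (not part of the diff): a minimal sketch of the layout prepare_metrics_data returns and of how plot() hands each metric to make_shape_metrics_plot. The configuration/run names and numeric values below are hypothetical:

# Hypothetical example: prepare_metrics_data returns one entry per shape metric,
# each holding parallel lists of single-key dicts keyed "<config>-<run_id>".
metrics_data = {
    "e1": {
        "rmse": [{"wavediff-run1": 3.1e-04}],
        "std_rmse": [{"wavediff-run1": 5.0e-05}],
        "rel_rmse": [{"wavediff-run1": 3.1e-04 / 2e-04}],  # divided by the e1 Euclid requirement
    },
    "e2": {"rmse": [], "std_rmse": [], "rel_rmse": []},
    "R2_meanR2": {"rmse": [], "std_rmse": [], "rel_rmse": []},
}

for metric, data in metrics_data.items():
    # plot() then produces one figure per metric and dataset via
    # make_shape_metrics_plot(data["rmse"], data["std_rmse"], data["rel_rmse"], plot_dataset, metric)
    print(metric, data["rmse"], data["rel_rmse"])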


def get_number_of_stars(metrics):
@@ -509,16 +560,19 @@ def plot_metrics(plotting_params, list_of_metrics, metrics_confs, plot_saving_pa
"poly_metric": {
"rmse": "rmse",
"std_rmse": "std_rmse",
"rel_rmse": "rel_rmse",
"plot_title": ".\nPolychromatic pixel RMSE @ Euclid resolution",
},
"opd_metric": {
"rmse": "rmse_opd",
"std_rmse": "rmse_std_opd",
"rel_rmse": "rel_rmse_opd",
"plot_title": ".\nOPD RMSE",
},
"shape_results_dict": {
"rmse": "pix_rmse",
"std_rmse": "pix_rmse_std",
"rel_rmse": "rel_pix_rmse",
"plot_title": "\nPixel RMSE @ 3x Euclid resolution",
},
}
Expand All @@ -533,6 +587,7 @@ def plot_metrics(plotting_params, list_of_metrics, metrics_confs, plot_saving_pa
k,
v["rmse"],
v["std_rmse"],
v["rel_rmse"],
v["plot_title"],
plot_saving_path,
)
8 changes: 4 additions & 4 deletions src/wf_psf/tests/metrics_test.py
@@ -108,7 +108,7 @@ def main_metrics(training_params):
return np.load(os.path.join(main_dir, metrics_filename), allow_pickle=True)[()]


@pytest.mark.skip(reason="Requires gpu")
@pytest.mark.skipif("GITHUB_ENV" in os.environ, reason="Skipping GPU tests in CI")
def test_eval_metrics_polychromatic_lowres(
training_params,
weights_path_basename,
@@ -156,7 +156,7 @@ def test_eval_metrics_polychromatic_lowres(
assert ratio_rel_std_rmse < tol


@pytest.mark.skip(reason="Requires gpu")
@pytest.mark.skipif("GITHUB_ENV" in os.environ, reason="Skipping GPU tests in CI")
def test_evaluate_metrics_opd(
training_params,
weights_path_basename,
@@ -206,7 +206,7 @@ def test_evaluate_metrics_opd(
assert ratio_rel_rmse_std_opd < tol


@pytest.mark.skip(reason="Requires gpu")
@pytest.mark.skipif("GITHUB_ENV" in os.environ, reason="Skipping GPU tests in CI")
def test_eval_metrics_mono_rmse(
training_params,
weights_path_basename,
@@ -271,7 +271,7 @@ def test_eval_metrics_mono_rmse(
assert ratio_rel_rmse_std_mono < tol


@pytest.mark.skip(reason="Requires gpu")
@pytest.mark.skipif("GITHUB_ENV" in os.environ, reason="Skipping GPU tests in CI")
def test_evaluate_metrics_shape(
training_params,
weights_path_basename,
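
The test-side change replaces the unconditional @pytest.mark.skip with a conditional skip: GitHub Actions runners define the GITHUB_ENV environment variable, so these GPU tests are skipped in CI but still run locally. A minimal self-contained sketch of the same pattern (the test name and body are placeholders, not from the repository):

import os
import pytest

# Skipped on GitHub Actions (which sets GITHUB_ENV); runs locally where a GPU is available.
@pytest.mark.skipif("GITHUB_ENV" in os.environ, reason="Skipping GPU tests in CI")
def test_gpu_only_placeholder():
    assert True  # placeholder body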
(diff for the third changed file not shown)
