Skip to content

Commit

Permalink
Merge pull request #64 from RWTH-EBC/63-add-option-for-different-kpis…
Browse files Browse the repository at this point in the history
…-for-different-variables

#63 added capability for different kpis for different goal variables
  • Loading branch information
jkriwet authored Nov 11, 2024
2 parents fb1756c + df47550 commit fbf9137
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 11 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,5 @@ venv
.pytest_cache
htmlcov
.coverage
.vscode/settings.json
tests/testzone
36 changes: 26 additions & 10 deletions aixcalibuha/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""
import warnings
import logging
from typing import Union, Callable
from typing import Union, Callable, List
from copy import deepcopy
import pandas as pd
import numpy as np
Expand Down Expand Up @@ -68,7 +68,7 @@ class Goals:
def __init__(self,
meas_target_data: Union[TimeSeriesData, pd.DataFrame],
variable_names: dict,
statistical_measure: str,
statistical_measure: Union[str, List[str]],
weightings: list = None):
"""Initialize class-objects and check correct input."""

Expand Down Expand Up @@ -170,18 +170,34 @@ def statistical_measure(self):
return self._stat_meas

@statistical_measure.setter
def statistical_measure(self,
                        statistical_measure: Union[str, Callable, List[Union[str, Callable]]]):
    """
    Set the statistical measure(s) used to evaluate the goals.

    Accepts either a single measure, which is applied to every goal
    variable, or a list with exactly one measure per goal variable.
    Each entry must be supported by the ``method`` argument of the
    ``StatisticsAnalyzer`` class of ``ebcpy``; a callable is labelled
    by its ``__name__``.

    :param statistical_measure:
        Single measure or list of measures (one per goal variable).
    :raises ValueError:
        If a list is given whose length does not match the number of
        goal variables.
    """

    def _get_stat_meas(measure):
        # Callables are represented by their function name in the label.
        if callable(measure):
            return measure.__name__
        return measure

    if isinstance(statistical_measure, list):
        # Validate before any element access so a mismatched list fails
        # with a clear error instead of a downstream IndexError.
        if len(statistical_measure) != len(self.variable_names):
            raise ValueError("The number of statistical measures does not match "
                             "the number of goals.")
        # Combined label, e.g. "RMSE_MAE", reflecting all per-goal measures.
        self._stat_meas = '_'.join(_get_stat_meas(measure)
                                   for measure in statistical_measure)
    else:
        # Single measure: label by it and apply it to every goal variable.
        self._stat_meas = _get_stat_meas(statistical_measure)
        statistical_measure = [statistical_measure] * len(self.variable_names)

    # One analyzer per goal variable; StatisticsAnalyzer itself rejects
    # unsupported method names (raises ValueError).
    self._stat_analyzer = {
        goal_name: StatisticsAnalyzer(method=measure)
        for goal_name, measure in zip(self.variable_names, statistical_measure)
    }


def eval_difference(self, verbose=False, penaltyfactor=1):
"""
Evaluate the difference of the measurement and simulated data based on the
Expand All @@ -208,7 +224,7 @@ def eval_difference(self, verbose=False, penaltyfactor=1):
"interval of measured and simulated data "
"are not equal. \nPlease check the frequencies "
"in the toml file (output_interval & frequency).")
_diff = self._stat_analyzer.calc(
_diff = self._stat_analyzer[goal_name].calc(
meas=self._tsd[(goal_name, self.meas_tag_str)],
sim=self._tsd[(goal_name, self.sim_tag_str)]
)
Expand Down
27 changes: 26 additions & 1 deletion tests/test_data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def test_goals(self):
goals = Goals(meas_target_data=meas_target_data,
variable_names=var_names,
statistical_measure="RMSE")

# Check set_sim_target_data:
goals.set_sim_target_data(sim_target_data)

Expand All @@ -83,6 +82,32 @@ def test_goals(self):
statistical_measure="RMSE",
weightings=weightings)

# Check different KPIs for different goals:
goals = Goals(meas_target_data=meas_target_data,
variable_names=var_names,
statistical_measure=["RMSE", "MAE"])
# Check set_sim_target_data:
goals.set_sim_target_data(sim_target_data)

# Set relevant time interval test:
goals.set_relevant_time_intervals([(0, 100)])

# Check the eval_difference function:
self.assertIsInstance(goals.eval_difference(), float)

with self.assertRaises(ValueError):
# Test if wrong statistical_measure raises an error.
Goals(meas_target_data=meas_target_data,
variable_names=var_names,
statistical_measure="not a valid KPI")

with self.assertRaises(ValueError):
# Test that the length of the statistical_measure list is equal to the number of variables.
goals = Goals(meas_target_data=meas_target_data,
variable_names=var_names,
statistical_measure=["RMSE", "MAE", "MSE"])


def test_tuner_paras(self):
"""Test the class TunerParas"""
dim = np.random.randint(1, 100)
Expand Down

0 comments on commit fbf9137

Please sign in to comment.