Skip to content

Commit

Permalink
Compare report output for "all::iamc"
Browse files Browse the repository at this point in the history
  • Loading branch information
khaeru committed Apr 19, 2024
1 parent a7ef3f5 commit dd518e9
Showing 1 changed file with 49 additions and 25 deletions.
74 changes: 49 additions & 25 deletions message_ix_models/tests/test_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import re
from importlib.metadata import version
from typing import List
from typing import List, Optional

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -315,10 +315,37 @@ def test_prepare_reporter(test_context):
assert 14299 <= len(rep.graph) - N


# Filters for comparison
#
# NB: these fragments are interpolated into regular expressions, so every literal
#     "|" in an IAMC 'variable' code must be escaped as r"\|". An unescaped "|" is
#     regex alternation: r"variable='Primary Energy|Coal': …" would match ANY message
#     starting with "variable='Primary Energy" and silently hide real differences.

# Primary energy carriers that have data in both sources
PE0 = r"Primary Energy\|(Coal|Gas|Hydro|Nuclear|Solar|Wind)"
# Subset of PE0 for which units and most values are known to differ
PE1 = r"Primary Energy\|(Coal|Gas|Solar|Wind)"
# Single emissions variable for which units and most values are known to differ
E = (
    r"Emissions\|CO2\|Energy\|Demand\|Transportation\|Road Rail and Domestic "
    "Shipping"
)

#: Patterns for comparison messages that are expected and thus not treated as errors.
#: Each pattern is applied with :meth:`re.Pattern.match`, i.e. anchored at the start
#: of the message.
IGNORE = [
    # Other 'variable' codes are missing from `obs`
    re.compile(f"variable='(?!{PE0}).*': no right data"),
    # 'variable' codes with further parts are missing from `obs`
    re.compile(f"variable='{PE0}.*': no right data"),
    # For `PE1` (NB: not Hydro or Nuclear) units and most values differ.
    # The ":" after "units mismatch" is optional so that both the emitted message
    # format ("units mismatch: {…} != {…}") and the older spelling are matched.
    re.compile(f"variable='{PE1}.*': units mismatch:? .*EJ/yr.*'', nan"),
    re.compile(r"variable='Primary Energy\|Coal': 220 of 240 values with \|diff"),
    re.compile(r"variable='Primary Energy\|Gas': 234 of 240 values with \|diff"),
    re.compile(r"variable='Primary Energy\|Solar': 191 of 240 values with \|diff"),
    re.compile(r"variable='Primary Energy\|Wind': 179 of 240 values with \|diff"),
    # For `E` units and most values differ
    re.compile(f"variable='{E}': units mismatch: .*Mt CO2/yr.*Mt / a"),
    re.compile(rf"variable='{E}': 20 missing right entries"),
    re.compile(rf"variable='{E}': 220 of 240 values with \|diff"),
]


@to_simulate.minimum_version
def test_compare(test_context):
"""Compare the output of genno-based and legacy reporting."""
key = "pe test"
key = "all::iamc"
# key = "pe test"

# Obtain the output from reporting `key` on `snapshot_id`
snapshot_id: int = 1
Expand All @@ -340,24 +367,8 @@ def test_compare(test_context):
engine="pyarrow",
)

# Filters for comparison
pe0 = r"Primary Energy\|(Coal|Gas|Hydro|Nuclear|Solar|Wind)"
pe1 = r"Primary Energy\|(Coal|Gas|Solar|Wind)"
ignore = [
# Other 'variable' codes are missing from `obs`
re.compile(f"variable='(?!{pe0}).*': no right data"),
# 'variable' codes with further parts are missing from `obs`
re.compile(f"variable='{pe0}.*': no right data"),
# For `pe1` (NB: not Hydro or Nuclear) units and most values differ
re.compile(f"variable='{pe1}.*': units mismatch .*EJ/yr.*'', nan"),
re.compile(r"variable='Primary Energy|Coal': 220 of 240 values with \|diff"),
re.compile(r"variable='Primary Energy|Gas': 234 of 240 values with \|diff"),
re.compile(r"variable='Primary Energy|Solar': 191 of 240 values with \|diff"),
re.compile(r"variable='Primary Energy|Wind': 179 of 240 values with \|diff"),
]

# Perform the comparison, ignoring some messages
if messages := compare_iamc(exp, obs, ignore=ignore):
if messages := compare_iamc(exp, obs, ignore=IGNORE):
# Other messages that were not explicitly ignored → some error
print("\n".join(messages))
assert False
Expand All @@ -369,8 +380,8 @@ def compare_iamc(
"""Compare IAMC-structured data in `left` and `right`; return a list of messages."""
result = []

def record(message: str) -> None:
if any(p.match(message) for p in ignore):
def record(message: str, condition: Optional[bool] = True) -> None:
if not condition or any(p.match(message) for p in ignore):
return
result.append(message)

Expand All @@ -388,16 +399,29 @@ def checks(df: pd.DataFrame):
"value_rel = value_diff / value_left"
)

na_left = tmp.isna()[["unit_left", "value_left"]]
if na_left.any(axis=None):
record(f"{prefix} {na_left.sum(axis=0).max()} missing left entries")
tmp = tmp[~na_left.any(axis=1)]
na_right = tmp.isna()[["unit_right", "value_right"]]
if na_right.any(axis=None):
record(f"{prefix} {na_right.sum(axis=0).max()} missing right entries")
tmp = tmp[~na_right.any(axis=1)]

units_left = set(tmp.unit_left.unique())
units_right = set(tmp.unit_right.unique())
if units_left != units_right:
record(f"{prefix} units mismatch: {units_left} != {units_right}")
record(
condition=units_left != units_right,
message=f"{prefix} units mismatch: {units_left} != {units_right}",
)

N0 = len(df)

mask1 = tmp.query("abs(value_diff) > @atol")
if len(mask1):
record(f"{prefix} {len(mask1)} of {N0} values with |diff| > {atol}")
record(
condition=len(mask1),
message=f"{prefix} {len(mask1)} of {N0} values with |diff| > {atol}",
)

for (model, scenario), group_0 in left.merge(
right,
Expand Down

0 comments on commit dd518e9

Please sign in to comment.