Skip to content

Commit

Permalink
Update to pydantic v2 (#297)
Browse files Browse the repository at this point in the history
* Add pyproject.toml

* Update to pydantic v2

* Make pydantic 2 changes so that tests can be run

* Fix named_attributes

* Use ValidationInfo for accessing other fields

* Use ValidationInfo

* Adjust error messages

* Update for changes in pydantic API

* Clean up NomenclatureConfig for pydantic v2

* Add coverage to dev dependencies

* Update type hints

* Update required_data and tests

* Update custom errors for pydantic 2

* Use new model_dump method

* Fix nightly workflow

* Use the new model_dump over dict

* Add ErrorCollector

* Switch to use PydanticCustomError

* Adjust to new each_item implementation

* Use ErrorCollector

* Use PydanticCustomErrors

* Adjust tests

* Simplify Code.__eq__

* Remove unused importa
  • Loading branch information
phackstock authored Dec 21, 2023
1 parent df34986 commit 72fec3f
Show file tree
Hide file tree
Showing 21 changed files with 813 additions and 597 deletions.
3 changes: 1 addition & 2 deletions nomenclature/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
from importlib.metadata import version
from pathlib import Path

import yaml

Expand All @@ -10,7 +9,7 @@
from nomenclature.countries import countries # noqa
from nomenclature.definition import SPECIAL_CODELIST, DataStructureDefinition # noqa
from nomenclature.processor import RegionAggregationMapping # noqa
from nomenclature.processor import RegionProcessor, RequiredDataValidator
from nomenclature.processor import RegionProcessor, RequiredDataValidator # noqa

# set up logging
logging.basicConfig(
Expand Down
75 changes: 32 additions & 43 deletions nomenclature/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import pycountry
from keyword import iskeyword
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Union
from pydantic import BaseModel, Field, validator
from typing import Any, Dict, List, Set, Union
from pydantic import field_validator, ConfigDict, BaseModel, Field, ValidationInfo

from pyam.utils import to_list

Expand All @@ -13,24 +13,25 @@ class Code(BaseModel):
"""A simple class for a mapping of a "code" to its attributes"""

name: str
description: Optional[str]
file: Optional[Union[str, Path]] = None
description: str | None = None
file: Union[str, Path] | None = None
extra_attributes: Dict[str, Any] = {}

def __eq__(self, other) -> bool:
return {key: value for key, value in self.dict().items() if key != "file"} == {
key: value for key, value in other.dict().items() if key != "file"
}
return self.model_dump(exclude="file") == other.model_dump(exclude="file")

@validator("extra_attributes")
def check_attribute_names(cls, v, values):
@field_validator("extra_attributes")
@classmethod
def check_attribute_names(
cls, v: Dict[str, Any], info: ValidationInfo
) -> Dict[str, Any]:
# Check that attributes only contains keys which are valid identifiers
if illegal_keys := [
key for key in v.keys() if not key.isidentifier() or iskeyword(key)
]:
raise ValueError(
"Only valid identifiers are allowed as attribute keys. Found "
f"'{illegal_keys}' in '{values['name']}' which are not allowed."
f"'{illegal_keys}' in '{info.data['name']}' which are not allowed."
)
return v

Expand Down Expand Up @@ -68,7 +69,7 @@ def from_dict(cls, mapping) -> "Code":

@classmethod
def named_attributes(cls) -> Set[str]:
return {a for a in cls.__dict__["__fields__"].keys() if a != "extra_attributes"}
return {a for a in cls.model_fields if a != "extra_attributes"}

@property
def contains_tags(self) -> bool:
Expand All @@ -80,15 +81,10 @@ def tags(self):

@property
def flattened_dict(self):
fields_set_alias = {
self.__fields__[field].alias for field in self.__fields_set__
}
return {
**{
k: v
for k, v in self.dict(by_alias=True).items()
if k != "extra_attributes" and k in fields_set_alias
},
**self.model_dump(
by_alias=True, exclude_unset=True, exclude="extra_attributes"
),
**self.extra_attributes,
}

Expand Down Expand Up @@ -148,25 +144,20 @@ def __setattr__(self, name, value):


class VariableCode(Code):
unit: Optional[Union[str, List[str]]] = Field(...)
weight: Optional[str] = None
region_aggregation: Optional[List[Dict[str, Dict]]] = Field(
unit: Union[str, List[str]] | None = Field(...)
weight: str | None = None
region_aggregation: List[Dict[str, Dict]] | None = Field(
None, alias="region-aggregation"
)
skip_region_aggregation: Optional[bool] = Field(
False, alias="skip-region-aggregation"
)
method: Optional[str] = None
check_aggregate: Optional[bool] = Field(False, alias="check-aggregate")
components: Optional[Union[List[str], List[Dict[str, List[str]]]]] = None
drop_negative_weights: Optional[bool] = None

class Config:
# this allows using both "check_aggregate" and "check-aggregate" for attribute
# setting
allow_population_by_field_name = True

@validator("region_aggregation", "components", "unit", pre=True)
skip_region_aggregation: bool | None = Field(False, alias="skip-region-aggregation")
method: str | None = None
check_aggregate: bool | None = Field(False, alias="check-aggregate")
components: Union[List[str], List[Dict[str, List[str]]]] | None = None
drop_negative_weights: bool | None = None
model_config = ConfigDict(populate_by_name=True)

@field_validator("region_aggregation", "components", "unit", mode="before")
@classmethod
def deserialize_json(cls, v):
try:
return json.loads(v) if isinstance(v, str) else v
Expand All @@ -180,9 +171,7 @@ def units(self) -> List[Union[str, None]]:
@classmethod
def named_attributes(cls) -> Set[str]:
return (
super()
.named_attributes()
.union(f.alias for f in cls.__dict__["__fields__"].values())
super().named_attributes().union(f.alias for f in cls.model_fields.values())
)

@property
Expand Down Expand Up @@ -225,16 +214,16 @@ class RegionCode(Code):
hierarchy: str = None
iso3_codes: Union[List[str], str] = None

@validator("iso3_codes")
def check_iso3_codes(cls, v, values) -> List[str]:
@field_validator("iso3_codes")
def check_iso3_codes(cls, v: List[str], info: ValidationInfo) -> List[str]:
"""Verifies that each ISO3 code is valid according to pycountry library."""
if invalid_iso3_codes := [
iso3_code
for iso3_code in to_list(v)
if pycountry.countries.get(alpha_3=iso3_code) is None
]:
raise ValueError(
f"Region '{values['name']}' has invalid ISO3 country code(s): "
f"Region '{info.data['name']}' has invalid ISO3 country code(s): "
+ ", ".join(invalid_iso3_codes)
)
return v
Expand All @@ -250,4 +239,4 @@ class MetaCode(Code):
"""

allowed_values: Optional[List[Any]]
allowed_values: List[Any] | None = None
63 changes: 34 additions & 29 deletions nomenclature/codelist.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,13 @@
import pandas as pd
import yaml
from pyam.utils import write_sheet
from pydantic import BaseModel, validator
from pydantic import field_validator, BaseModel, ValidationInfo
from pydantic_core import PydanticCustomError

import nomenclature
from nomenclature.code import Code, MetaCode, RegionCode, VariableCode
from nomenclature.config import NomenclatureConfig
from nomenclature.error.codelist import DuplicateCodeError
from nomenclature.error.variable import (
MissingWeightError,
VariableRenameArgError,
VariableRenameTargetError,
)
from nomenclature.error import custom_pydantic_errors
from pyam.utils import is_list_like

here = Path(__file__).parent.absolute()
Expand Down Expand Up @@ -45,8 +41,9 @@ class CodeList(BaseModel):
def __eq__(self, other):
return self.name == other.name and self.mapping == other.mapping

@validator("mapping")
def check_stray_tag(cls, v):
@field_validator("mapping")
@classmethod
def check_stray_tag(cls, v: Dict[str, Code]) -> Dict[str, Code]:
"""Check that no '{' are left in codes after tag replacement"""
for code in v:
if "{" in code:
Expand All @@ -56,20 +53,22 @@ def check_stray_tag(cls, v):
)
return v

@validator("mapping")
def check_end_whitespace(cls, v, values):
@field_validator("mapping")
def check_end_whitespace(
cls, v: Dict[str, Code], info: ValidationInfo
) -> Dict[str, Code]:
"""Check that no code ends with a whitespace"""
for code in v:
if code.endswith(" "):
raise ValueError(
f"Unexpected whitespace at the end of a {values['name']}"
f"Unexpected whitespace at the end of a {info.data['name']}"
f" code: '{code}'."
)
return v

def __setitem__(self, key, value):
if key in self.mapping:
raise DuplicateCodeError(name=self.name, code=key)
raise ValueError(f"Duplicate item in {self.name} codelist: {key}")
self.mapping[key] = value

def __getitem__(self, k):
Expand Down Expand Up @@ -154,7 +153,7 @@ def _parse_and_replace_tags(
for tag in _tag_list:
tag_name = next(iter(tag))
if tag_name in tag_dict:
raise DuplicateCodeError(name="tag", code=tag_name)
raise ValueError(f"Duplicate item in tag codelist: {tag_name}")
tag_dict[tag_name] = [Code.from_dict(t) for t in tag[tag_name]]

# start with all non tag codes
Expand Down Expand Up @@ -213,7 +212,7 @@ def from_directory(
mapping: Dict[str, Code] = {}
for code in code_list:
if code.name in mapping:
raise DuplicateCodeError(name=name, code=code.name)
raise ValueError(f"Duplicate item in {name} codelist: {code.name}")
mapping[code.name] = code
return cls(name=name, mapping=mapping)

Expand Down Expand Up @@ -469,7 +468,8 @@ def to_dimensionless(u):

return sorted(list(units))

@validator("mapping")
@field_validator("mapping")
@classmethod
def check_variable_region_aggregation_args(cls, v):
"""Check that any variable "region-aggregation" mappings are valid"""

Expand All @@ -478,39 +478,44 @@ def check_variable_region_aggregation_args(cls, v):
# pyam-aggregation-kwargs and a 'region-aggregation' attribute
if var.region_aggregation is not None:
if conflict_args := list(var.pyam_agg_kwargs.keys()):
raise VariableRenameArgError(
variable=var.name,
file=var.file,
args=conflict_args,
raise PydanticCustomError(
*custom_pydantic_errors.VariableRenameArgError,
{"variable": var.name, "file": var.file, "args": conflict_args},
)

# ensure that mapped variables are defined in the nomenclature
invalid = []
for inst in var.region_aggregation:
invalid.extend(var for var in inst if var not in v)
if invalid:
raise VariableRenameTargetError(
variable=var.name, file=var.file, target=invalid
raise PydanticCustomError(
*custom_pydantic_errors.VariableRenameTargetError,
{"variable": var.name, "file": var.file, "target": invalid},
)
return v

@validator("mapping")
@field_validator("mapping")
@classmethod
def check_weight_in_vars(cls, v):
"""Check that all variables specified in 'weight' are present in the codelist"""
if missing_weights := [
(var.name, var.weight, var.file)
for var in v.values()
if var.weight is not None and var.weight not in v
]:
raise MissingWeightError(
missing_weights="".join(
f"'{weight}' used for '{var}' in: {file}\n"
for var, weight, file in missing_weights
)
raise PydanticCustomError(
*custom_pydantic_errors.MissingWeightError,
{
"missing_weights": "".join(
f"'{weight}' used for '{var}' in: {file}\n"
for var, weight, file in missing_weights
)
},
)
return v

@validator("mapping")
@field_validator("mapping")
@classmethod
def cast_variable_components_args(cls, v):
"""Cast "components" list of dicts to a codelist"""

Expand Down
Loading

0 comments on commit 72fec3f

Please sign in to comment.