Skip to content

Commit

Permalink
Add more detailed exception message for bad topology graph. (#521)
Browse files Browse the repository at this point in the history
Related Issues: #463 #462 
Reviewed By: @lukasturcani

This  commit starts the process of adding more clear errors
for issues with construction with stk, to help the user as they
build new topology graphs.
  • Loading branch information
andrewtarzia authored Feb 13, 2024
1 parent 9a818c8 commit 1627838
Show file tree
Hide file tree
Showing 55 changed files with 326 additions and 150 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
cache: "pip"
- run: "pip install -e '.[dev]'"
- run: mypy src
black:
ruff-format:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
Expand All @@ -35,7 +35,7 @@ jobs:
python-version: "3.11"
cache: "pip"
- run: "pip install -e '.[dev]'"
- run: black --check .
- run: ruff format --check .
pytest:
runs-on: ubuntu-22.04
services:
Expand Down
4 changes: 2 additions & 2 deletions justfile
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ check:
(set -x; ruff . )

echo
( set -x; black --check . )
( set -x; ruff format --check . )

echo
( set -x; mypy src )
Expand All @@ -37,7 +37,7 @@ check:

# Auto-fix code issues.
fix:
black .
ruff format .
ruff --fix .

# Start a MongoDB instance in docker.
Expand Down
9 changes: 5 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@ readme = "README.rst"

[project.optional-dependencies]
dev = [
"black",
"ruff",
"moldoc",
"mypy",
"pip-tools",
"pytest",
# TODO: Remove pin when https://github.com/TvoroG/pytest-lazy-fixture/issues/65 is resolved.
# pytest-lazy-fixture 0.6.0 is incompatible with pytest 8.0.0
"pytest<8",
"pytest-benchmark",
"pytest-datadir",
"pytest-lazy-fixture",
Expand All @@ -51,11 +52,11 @@ documentation = "https://stk.readthedocs.io"
[tool.setuptools_scm]
write_to = "src/stk/_version.py"

[tool.black]
line-length = 79

[tool.ruff]
line-length = 79

[too.ruff.lint]
extend-select = ["I"]

[tool.pytest.ini_options]
Expand Down
70 changes: 44 additions & 26 deletions src/stk/_internal/building_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,11 @@ class BuildingBlock(Molecule):
def __init__(
self,
smiles: str,
functional_groups: FunctionalGroup
| FunctionalGroupFactory
| Iterable[FunctionalGroup | FunctionalGroupFactory] = (),
functional_groups: (
FunctionalGroup
| FunctionalGroupFactory
| Iterable[FunctionalGroup | FunctionalGroupFactory]
) = (),
placer_ids: Iterable[int] | None = None,
position_matrix: np.ndarray | None = None,
) -> None:
Expand Down Expand Up @@ -137,9 +139,11 @@ def __init__(
def init_from_molecule(
cls,
molecule: Molecule,
functional_groups: FunctionalGroup
| FunctionalGroupFactory
| Iterable[FunctionalGroup | FunctionalGroupFactory] = (),
functional_groups: (
FunctionalGroup
| FunctionalGroupFactory
| Iterable[FunctionalGroup | FunctionalGroupFactory]
) = (),
placer_ids: Iterable[int] | None = None,
) -> typing.Self:
"""
Expand Down Expand Up @@ -200,9 +204,11 @@ def init_from_molecule(
def init_from_vabene_molecule(
cls,
molecule: vabene.Molecule,
functional_groups: FunctionalGroup
| FunctionalGroupFactory
| Iterable[FunctionalGroup | FunctionalGroupFactory] = (),
functional_groups: (
FunctionalGroup
| FunctionalGroupFactory
| Iterable[FunctionalGroup | FunctionalGroupFactory]
) = (),
placer_ids: Iterable[int] | None = None,
position_matrix: np.ndarray | None = None,
) -> typing.Self:
Expand Down Expand Up @@ -316,9 +322,11 @@ def init(
atoms: Iterable[Atom],
bonds: Iterable[Bond],
position_matrix: np.ndarray,
functional_groups: FunctionalGroup
| FunctionalGroupFactory
| Iterable[FunctionalGroup | FunctionalGroupFactory] = (),
functional_groups: (
FunctionalGroup
| FunctionalGroupFactory
| Iterable[FunctionalGroup | FunctionalGroupFactory]
) = (),
placer_ids: Iterable[int] | None = None,
) -> typing.Self:
"""
Expand Down Expand Up @@ -380,11 +388,13 @@ def init(
bonds=bonds,
position_matrix=position_matrix,
)
building_block._fg_repr = repr(functional_groups)
functional_groups = building_block._extract_functional_groups(
functional_groups=functional_groups,
)
building_block._with_functional_groups(functional_groups)
building_block._fg_repr = repr( # type: ignore[has-type]
tuple(building_block.get_functional_groups())
)
building_block._placer_ids = building_block._normalize_placer_ids(
placer_ids=placer_ids,
functional_groups=building_block._functional_groups,
Expand All @@ -400,9 +410,11 @@ def init(
def init_from_file(
cls,
path: pathlib.Path | str,
functional_groups: FunctionalGroup
| FunctionalGroupFactory
| Iterable[FunctionalGroup | FunctionalGroupFactory] = (),
functional_groups: (
FunctionalGroup
| FunctionalGroupFactory
| Iterable[FunctionalGroup | FunctionalGroupFactory]
) = (),
placer_ids: Iterable[int] | None = None,
) -> typing.Self:
"""
Expand Down Expand Up @@ -479,9 +491,11 @@ def init_from_file(
def init_from_rdkit_mol(
cls,
molecule: rdkit.Mol,
functional_groups: FunctionalGroup
| FunctionalGroupFactory
| Iterable[FunctionalGroup | FunctionalGroupFactory] = (),
functional_groups: (
FunctionalGroup
| FunctionalGroupFactory
| Iterable[FunctionalGroup | FunctionalGroupFactory]
) = (),
placer_ids: Iterable[int] | None = None,
) -> typing.Self:
"""
Expand Down Expand Up @@ -548,9 +562,11 @@ def init_from_rdkit_mol(
def _init_from_rdkit_mol(
self,
molecule: rdkit.Mol,
functional_groups: FunctionalGroup
| FunctionalGroupFactory
| Iterable[FunctionalGroup | FunctionalGroupFactory],
functional_groups: (
FunctionalGroup
| FunctionalGroupFactory
| Iterable[FunctionalGroup | FunctionalGroupFactory]
),
placer_ids: Iterable[int] | None,
) -> None:
"""
Expand Down Expand Up @@ -595,7 +611,6 @@ def _init_from_rdkit_mol(
"""

self._fg_repr = repr(functional_groups)
atoms = tuple(
Atom(a.GetIdx(), a.GetAtomicNum(), a.GetFormalCharge())
for a in molecule.GetAtoms()
Expand All @@ -620,6 +635,7 @@ def _init_from_rdkit_mol(
functional_groups=functional_groups,
)
)
self._fg_repr = repr(self._functional_groups)
self._placer_ids = self._normalize_placer_ids(
placer_ids=placer_ids,
functional_groups=self._functional_groups,
Expand Down Expand Up @@ -714,9 +730,11 @@ def _get_core_ids(

def _extract_functional_groups(
self,
functional_groups: FunctionalGroup
| FunctionalGroupFactory
| Iterable[FunctionalGroup | FunctionalGroupFactory],
functional_groups: (
FunctionalGroup
| FunctionalGroupFactory
| Iterable[FunctionalGroup | FunctionalGroupFactory]
),
) -> Iterator[FunctionalGroup]:
"""
Yield functional groups.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def get_key(record: T) -> str:
population, keys = dedupe(
items=(
molecule_record
for molecule_record, in self._generation_selector.select(
for (molecule_record,) in self._generation_selector.select(
population=normalized_fitness_values
)
),
Expand Down
13 changes: 7 additions & 6 deletions src/stk/_internal/ea/fitness_normalizers/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,8 @@ class Add(FitnessNormalizer[T]):
def __init__(
self,
number: float | Iterable[float],
filter: Callable[
[dict[T, Any], T], bool
] = lambda fitness_values, record: True,
filter: Callable[[dict[T, Any], T], bool] = lambda fitness_values,
record: True,
) -> None:
"""
Parameters:
Expand All @@ -134,8 +133,10 @@ def __init__(

def normalize(self, fitness_values: dict[T, Any]) -> dict[T, Any]:
return {
record: np.add(self._number, fitness_value)
if self._filter(fitness_values, record)
else fitness_value
record: (
np.add(self._number, fitness_value)
if self._filter(fitness_values, record)
else fitness_value
)
for record, fitness_value in fitness_values.items()
}
13 changes: 7 additions & 6 deletions src/stk/_internal/ea/fitness_normalizers/divide_by_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,8 @@ class DivideByMean(FitnessNormalizer[T]):

def __init__(
self,
filter: Callable[
[dict[T, Any], T], bool
] = lambda fitness_values, record: True,
filter: Callable[[dict[T, Any], T], bool] = lambda fitness_values,
record: True,
) -> None:
"""
Parameters:
Expand Down Expand Up @@ -116,8 +115,10 @@ def normalize(self, fitness_values: dict[T, Any]) -> dict[T, Any]:
logger.debug(f"Means used: {mean}")

return {
record: np.divide(fitness_value, mean)
if self._filter(fitness_values, record)
else fitness_value
record: (
np.divide(fitness_value, mean)
if self._filter(fitness_values, record)
else fitness_value
)
for record, fitness_value in fitness_values.items()
}
13 changes: 7 additions & 6 deletions src/stk/_internal/ea/fitness_normalizers/multiply.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,8 @@ class Multiply(FitnessNormalizer[T]):
def __init__(
self,
coefficient: float | Iterable[float],
filter: Callable[
[dict[T, Any], T], bool
] = lambda fitness_values, record: True,
filter: Callable[[dict[T, Any], T], bool] = lambda fitness_values,
record: True,
) -> None:
"""
Parameters:
Expand All @@ -135,8 +134,10 @@ def __init__(

def normalize(self, fitness_values: dict[T, Any]) -> dict[T, Any]:
return {
record: np.multiply(self._coefficient, fitness_value)
if self._filter(fitness_values, record)
else fitness_value
record: (
np.multiply(self._coefficient, fitness_value)
if self._filter(fitness_values, record)
else fitness_value
)
for record, fitness_value in fitness_values.items()
}
13 changes: 7 additions & 6 deletions src/stk/_internal/ea/fitness_normalizers/power.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,8 @@ class Power(FitnessNormalizer[T]):
def __init__(
self,
power: float | Iterable[float],
filter: Callable[
[dict[T, Any], T], bool
] = lambda fitness_values, record: True,
filter: Callable[[dict[T, Any], T], bool] = lambda fitness_values,
record: True,
) -> None:
"""
Parameters:
Expand All @@ -169,8 +168,10 @@ def __init__(

def normalize(self, fitness_values: dict[T, Any]) -> dict[T, Any]:
return {
record: np.float_power(fitness_value, self._power)
if self._filter(fitness_values, record)
else fitness_value
record: (
np.float_power(fitness_value, self._power)
if self._filter(fitness_values, record)
else fitness_value
)
for record, fitness_value in fitness_values.items()
}
13 changes: 7 additions & 6 deletions src/stk/_internal/ea/fitness_normalizers/replace_fitness.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,8 @@ def get_minimum_fitness_value(fitness_values):
def __init__(
self,
get_replacement: Callable[[dict[T, Any]], Any],
filter: Callable[
[dict[T, Any], T], bool
] = lambda fitness_values, record: True,
filter: Callable[[dict[T, Any], T], bool] = lambda fitness_values,
record: True,
) -> None:
"""
Parameters:
Expand All @@ -97,8 +96,10 @@ def __init__(
def normalize(self, fitness_values: dict[T, Any]) -> dict[T, Any]:
replacement = self._get_replacement(fitness_values)
return {
record: replacement
if self._filter(fitness_values, record)
else fitness_value
record: (
replacement
if self._filter(fitness_values, record)
else fitness_value
)
for record, fitness_value in fitness_values.items()
}
13 changes: 7 additions & 6 deletions src/stk/_internal/ea/fitness_normalizers/shift_up.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,8 @@ class ShiftUp(FitnessNormalizer[T]):

def __init__(
self,
filter: Callable[
[dict[T, Any], T], bool
] = lambda fitness_values, record: True,
filter: Callable[[dict[T, Any], T], bool] = lambda fitness_values,
record: True,
) -> None:
"""
Parameters:
Expand Down Expand Up @@ -184,8 +183,10 @@ def normalize(self, fitness_values: dict[T, Any]) -> dict[T, Any]:
shift[i] = 1 - min_

return {
record: fitness_value + shift
if self._filter(fitness_values, record)
else fitness_value
record: (
fitness_value + shift
if self._filter(fitness_values, record)
else fitness_value
)
for record, fitness_value in fitness_values.items()
}
Loading

0 comments on commit 1627838

Please sign in to comment.