Skip to content

Commit

Permalink
Hotfix/clean graph (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
JPXKQX authored Aug 9, 2024
1 parent 863c489 commit d1a3298
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ Keep it human-readable, your future self will thank you!
- added downstream-ci pipeline

### Changed
- Fix `anemoi-graphs create`. Config argument is cast to a Path.
- Fix GraphCreator().clean() to not iterate over a dictionary that may change size during iterations.

### Removed

Expand Down
2 changes: 1 addition & 1 deletion src/anemoi/graphs/commands/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def add_arguments(self, command_parser):
help="Overwrite existing files. This will delete the target graph if it already exists.",
)
command_parser.add_argument(
"config", help="Configuration yaml file path defining the recipe to create the graph."
"config", type=Path, help="Configuration yaml file path defining the recipe to create the graph."
)
command_parser.add_argument("save_path", type=Path, help="Path to store the created graph.")

Expand Down
7 changes: 5 additions & 2 deletions src/anemoi/graphs/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ def __init__(
self,
config: Union[Path, DotDict],
):
self.config = DotDict.from_file(config) if isinstance(config, Path) else config
if isinstance(config, Path) or isinstance(config, str):
self.config = DotDict.from_file(config)
else:
self.config = config

def generate_graph(self) -> HeteroData:
"""Generate the graph.
Expand Down Expand Up @@ -59,7 +62,7 @@ def clean(self, graph: HeteroData) -> HeteroData:
cleaned graph
"""
for type_name in chain(graph.node_types, graph.edge_types):
for attr_name in graph[type_name].keys():
for attr_name in list(graph[type_name].keys()):
if attr_name.startswith("_"):
del graph[type_name][attr_name]

Expand Down

0 comments on commit d1a3298

Please sign in to comment.