diff --git a/ra2ce/analyses/analysis_config_data/analysis_config_data.py b/ra2ce/analyses/analysis_config_data/analysis_config_data.py index 26ad15bca..b662d296c 100644 --- a/ra2ce/analyses/analysis_config_data/analysis_config_data.py +++ b/ra2ce/analyses/analysis_config_data/analysis_config_data.py @@ -21,6 +21,7 @@ from __future__ import annotations + from ra2ce.common.configuration.config_data_protocol import ConfigDataProtocol diff --git a/ra2ce/analyses/analysis_config_data/analysis_config_data_validator_without_network.py b/ra2ce/analyses/analysis_config_data/analysis_config_data_validator_without_network.py index c3fd95cd7..f9c03d842 100644 --- a/ra2ce/analyses/analysis_config_data/analysis_config_data_validator_without_network.py +++ b/ra2ce/analyses/analysis_config_data/analysis_config_data_validator_without_network.py @@ -21,17 +21,16 @@ from pathlib import Path + from ra2ce.analyses.analysis_config_data.analysis_config_data import ( AnalysisConfigDataWithoutNetwork, ) - from ra2ce.common.validation.ra2ce_validator_protocol import Ra2ceIoValidator from ra2ce.common.validation.validation_report import ValidationReport from ra2ce.graph.network_config_data.network_config_data_validator import ( NetworkDictValues, ) - IndirectAnalysisNameList: list[str] = [ "single_link_redundancy", "multi_link_redundancy", diff --git a/ra2ce/analyses/analysis_config_data/readers/analysis_config_reader_factory.py b/ra2ce/analyses/analysis_config_data/readers/analysis_config_reader_factory.py index 0b997fda2..304104a06 100644 --- a/ra2ce/analyses/analysis_config_data/readers/analysis_config_reader_factory.py +++ b/ra2ce/analyses/analysis_config_data/readers/analysis_config_reader_factory.py @@ -23,9 +23,6 @@ from pathlib import Path from typing import Optional -from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_base import ( - AnalysisConfigWrapperBase, -) from ra2ce.analyses.analysis_config_data.readers.analysis_config_reader_base import ( 
AnalysisConfigReaderBase, ) @@ -35,6 +32,9 @@ from ra2ce.analyses.analysis_config_data.readers.analysis_config_reader_without_network import ( AnalysisConfigReaderWithoutNetwork, ) +from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_base import ( + AnalysisConfigWrapperBase, +) from ra2ce.graph.network_config_data.network_config_data import NetworkConfigData diff --git a/ra2ce/analyses/analysis_config_data/readers/analysis_config_reader_with_network.py b/ra2ce/analyses/analysis_config_data/readers/analysis_config_reader_with_network.py index 8603ecbce..48df087a7 100644 --- a/ra2ce/analyses/analysis_config_data/readers/analysis_config_reader_with_network.py +++ b/ra2ce/analyses/analysis_config_data/readers/analysis_config_reader_with_network.py @@ -22,15 +22,15 @@ from pathlib import Path -from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_base import ( - AnalysisConfigWrapperBase, -) from ra2ce.analyses.analysis_config_data.analysis_config_data import ( AnalysisConfigDataWithNetwork, ) from ra2ce.analyses.analysis_config_data.readers.analysis_config_reader_base import ( AnalysisConfigReaderBase, ) +from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_base import ( + AnalysisConfigWrapperBase, +) from ra2ce.graph.network_config_wrapper import NetworkConfigWrapper diff --git a/ra2ce/analyses/analysis_config_data/readers/analysis_config_reader_without_network.py b/ra2ce/analyses/analysis_config_data/readers/analysis_config_reader_without_network.py index b4114e322..aac67030e 100644 --- a/ra2ce/analyses/analysis_config_data/readers/analysis_config_reader_without_network.py +++ b/ra2ce/analyses/analysis_config_data/readers/analysis_config_reader_without_network.py @@ -23,15 +23,15 @@ import logging from pathlib import Path -from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_base import ( - AnalysisConfigWrapperBase, -) from ra2ce.analyses.analysis_config_data.analysis_config_data import ( 
AnalysisConfigDataWithoutNetwork, ) from ra2ce.analyses.analysis_config_data.readers.analysis_config_reader_base import ( AnalysisConfigReaderBase, ) +from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_base import ( + AnalysisConfigWrapperBase, +) from ra2ce.graph.network_config_data.network_config_data_reader import ( NetworkConfigDataReader, ) diff --git a/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_base.py b/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_base.py index fafd2f6f8..0c6e5838d 100644 --- a/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_base.py +++ b/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_base.py @@ -20,7 +20,7 @@ """ -from abc import abstractmethod, abstractclassmethod +from abc import abstractclassmethod, abstractmethod from pathlib import Path from typing import Optional diff --git a/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_factory.py b/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_factory.py index 6d72124f9..b73cf46a3 100644 --- a/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_factory.py +++ b/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_factory.py @@ -23,14 +23,14 @@ from pathlib import Path from typing import Optional -from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_base import ( - AnalysisConfigWrapperBase, -) from ra2ce.analyses.analysis_config_data.analysis_config_data import ( AnalysisConfigData, AnalysisConfigDataWithNetwork, AnalysisConfigDataWithoutNetwork, ) +from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_base import ( + AnalysisConfigWrapperBase, +) from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_with_network import ( AnalysisConfigWrapperWithNetwork, ) diff --git a/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_with_network.py 
b/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_with_network.py index 093cd0968..11de109af 100644 --- a/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_with_network.py +++ b/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_with_network.py @@ -23,12 +23,14 @@ from __future__ import annotations from pathlib import Path -from ra2ce.analyses.analysis_config_data.analysis_config_data_validator_with_network import AnalysisConfigDataValidatorWithNetwork +from ra2ce.analyses.analysis_config_data.analysis_config_data import AnalysisConfigData +from ra2ce.analyses.analysis_config_data.analysis_config_data_validator_with_network import ( + AnalysisConfigDataValidatorWithNetwork, +) from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_base import ( AnalysisConfigWrapperBase, ) -from ra2ce.analyses.analysis_config_data.analysis_config_data import AnalysisConfigData from ra2ce.graph.network_config_data.network_config_data import NetworkConfigData from ra2ce.graph.network_config_wrapper import NetworkConfigWrapper @@ -98,5 +100,7 @@ def configure(self) -> None: def is_valid(self) -> bool: _file_is_valid = self.ini_file.is_file() and self.ini_file.suffix == ".ini" - _validation_report = AnalysisConfigDataValidatorWithNetwork(self.config_data).validate() + _validation_report = AnalysisConfigDataValidatorWithNetwork( + self.config_data + ).validate() return _file_is_valid and _validation_report.is_valid() diff --git a/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_without_network.py b/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_without_network.py index f948d769e..b16c9d4fb 100644 --- a/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_without_network.py +++ b/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_without_network.py @@ -24,12 +24,14 @@ import logging from pathlib import Path -from 
ra2ce.analyses.analysis_config_data.analysis_config_data_validator_without_network import AnalysisConfigDataValidatorWithoutNetwork +from ra2ce.analyses.analysis_config_data.analysis_config_data import AnalysisConfigData +from ra2ce.analyses.analysis_config_data.analysis_config_data_validator_without_network import ( + AnalysisConfigDataValidatorWithoutNetwork, +) from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_base import ( AnalysisConfigWrapperBase, ) -from ra2ce.analyses.analysis_config_data.analysis_config_data import AnalysisConfigData from ra2ce.graph.network_config_wrapper import NetworkConfigWrapper @@ -77,5 +79,7 @@ def configure(self) -> None: def is_valid(self) -> bool: _file_is_valid = self.ini_file.is_file() and self.ini_file.suffix == ".ini" - _validation_report = AnalysisConfigDataValidatorWithoutNetwork(self.config_data).validate() + _validation_report = AnalysisConfigDataValidatorWithoutNetwork( + self.config_data + ).validate() return _file_is_valid and _validation_report.is_valid() diff --git a/ra2ce/common/configuration/config_data_protocol.py b/ra2ce/common/configuration/config_data_protocol.py index 5960901dc..4f505f126 100644 --- a/ra2ce/common/configuration/config_data_protocol.py +++ b/ra2ce/common/configuration/config_data_protocol.py @@ -33,4 +33,4 @@ def to_dict(self) -> dict: Returns: dict: Dictionary representing the `ConfigDataProtocol` instance. 
""" - pass \ No newline at end of file + pass diff --git a/ra2ce/common/configuration/ini_configuration_reader_protocol.py b/ra2ce/common/configuration/ini_configuration_reader_protocol.py index 1c081257a..edf92c54a 100644 --- a/ra2ce/common/configuration/ini_configuration_reader_protocol.py +++ b/ra2ce/common/configuration/ini_configuration_reader_protocol.py @@ -21,6 +21,7 @@ from pathlib import Path from typing import Protocol, runtime_checkable + from ra2ce.common.configuration.config_data_protocol import ConfigDataProtocol from ra2ce.common.io.readers.file_reader_protocol import FileReaderProtocol diff --git a/ra2ce/configuration/config_factory.py b/ra2ce/configuration/config_factory.py index 956fb3dd8..d3da83cb7 100644 --- a/ra2ce/configuration/config_factory.py +++ b/ra2ce/configuration/config_factory.py @@ -20,19 +20,19 @@ """ -from pathlib import Path import shutil +from pathlib import Path from typing import Optional +from ra2ce.analyses.analysis_config_data.readers.analysis_config_reader_factory import ( + AnalysisConfigReaderFactory, +) from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_base import ( AnalysisConfigWrapperBase, ) from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_factory import ( AnalysisConfigWrapperFactory, ) -from ra2ce.analyses.analysis_config_data.readers.analysis_config_reader_factory import ( - AnalysisConfigReaderFactory, -) from ra2ce.configuration.config_wrapper import ConfigWrapper from ra2ce.graph.network_config_data.network_config_data import NetworkConfigData from ra2ce.graph.network_config_data.network_config_data_reader import ( diff --git a/ra2ce/graph/network_config_data/network_config_data.py b/ra2ce/graph/network_config_data/network_config_data.py index e0f826753..74ab33976 100644 --- a/ra2ce/graph/network_config_data/network_config_data.py +++ b/ra2ce/graph/network_config_data/network_config_data.py @@ -25,6 +25,8 @@ from pathlib import Path from typing import Optional +from pyproj 
import CRS + from ra2ce.common.configuration.config_data_protocol import ConfigDataProtocol @@ -37,10 +39,10 @@ class ProjectSection: class NetworkSection: directed: bool = False source: str = "" # should be enum - primary_file: str = "" # TODO. Unclear whether this is `Path` or `list[Path]` - diversion_file: str = "" # TODO. Unclear whether this is `Path` or `list[Path]` + primary_file: list[Path] = field(default_factory=list) + diversion_file: list[Path] = field(default_factory=list) file_id: str = "" - polygon: str = "" # TODO. Unclear whether this is `str`` or `Path` + polygon: Optional[Path] = None network_type: str = "" # Should be enum road_types: list[str] = field(default_factory=list) save_shp: bool = False @@ -90,7 +92,8 @@ class NetworkConfigData(ConfigDataProtocol): input_path: Optional[Path] = None output_path: Optional[Path] = None static_path: Optional[Path] = None - + # CRS is not yet supported in the ini file, it might be relocated to a subsection. + crs: CRS = field(default_factory=lambda: CRS.from_user_input(4326)) project: ProjectSection = field(default_factory=lambda: ProjectSection()) network: NetworkSection = field(default_factory=lambda: NetworkSection()) origins_destinations: OriginsDestinationsSection = field( @@ -106,8 +109,15 @@ def output_graph_dir(self) -> Optional[Path]: return None return self.static_path.joinpath("output_graph") + @property + def network_dir(self) -> Optional[Path]: + if not self.static_path: + return None + return self.static_path.joinpath("network") + def to_dict(self) -> dict: _dict = self.__dict__ + _dict["crs"] = self.crs.to_epsg() _dict["project"] = self.project.__dict__ _dict["network"] = self.network.__dict__ _dict["origins_destinations"] = self.origins_destinations.__dict__ diff --git a/ra2ce/graph/network_config_data/network_config_data_reader.py b/ra2ce/graph/network_config_data/network_config_data_reader.py index 14ce57685..8cc0347e6 100644 --- 
a/ra2ce/graph/network_config_data/network_config_data_reader.py +++ b/ra2ce/graph/network_config_data/network_config_data_reader.py @@ -1,6 +1,6 @@ from configparser import ConfigParser from pathlib import Path -from typing import Union +from typing import Any, Union from ra2ce.common.configuration.ini_configuration_reader_protocol import ( ConfigDataReaderProtocol, @@ -58,6 +58,15 @@ def _select_to_correct(path_value: Union[list[Path], Path]) -> bool: return _select_to_correct(path_value[0]) return not path_value.exists() + def _correct_list(path_root: Path, path_value_list: list[Path]) -> list[Path]: + _corrected_list = [] + for _path_value in path_value_list: + if not _path_value.exists(): + _corrected_list.append(path_root.joinpath(_path_value)) + else: + _corrected_list.append(_path_value) + return _corrected_list + # Relative to network directory. _network_directory = config_data.static_path.joinpath("network") if _select_to_correct(config_data.origins_destinations.origins): @@ -70,15 +79,28 @@ def _select_to_correct(path_value: Union[list[Path], Path]) -> bool: config_data.origins_destinations.destinations ) + if _select_to_correct(config_data.origins_destinations.region): + config_data.origins_destinations.region = _network_directory.joinpath( + config_data.origins_destinations.region + ) + + if _select_to_correct(config_data.network.polygon): + config_data.network.polygon = _network_directory.joinpath( + config_data.network.polygon + ) + + config_data.network.primary_file = _correct_list( + _network_directory, config_data.network.primary_file + ) + config_data.network.diversion_file = _correct_list( + _network_directory, config_data.network.diversion_file + ) + # Relative to hazard directory. 
_hazard_directory = config_data.static_path.joinpath("hazard") - if _select_to_correct(config_data.hazard.hazard_map): - config_data.hazard.hazard_map = list( - map( - lambda x: _hazard_directory.joinpath(x), - config_data.hazard.hazard_map, - ) - ) + config_data.hazard.hazard_map = _correct_list( + _hazard_directory, config_data.hazard.hazard_map + ) def _get_str_as_path(self, str_value: Union[str, Path]) -> Path: if str_value and not isinstance(str_value, Path): @@ -107,9 +129,23 @@ def _get_sections(self) -> dict: def get_project_section(self) -> ProjectSection: return ProjectSection(**self._parser["project"]) + def _get_path_list( + self, section_name: str, property: str, fallback_opt: Any + ) -> list[Path]: + _value_list = self._parser.getlist( + section_name, property, fallback=fallback_opt + ) + return list(map(self._get_str_as_path, _value_list)) + def get_network_section(self) -> NetworkSection: _section = "network" _network_section = NetworkSection(**self._parser[_section]) + _network_section.primary_file = self._get_path_list( + _section, "primary_file", _network_section.primary_file + ) + _network_section.diversion_file = self._get_path_list( + _section, "diversion_file", _network_section.diversion_file + ) _network_section.directed = self._parser.getboolean( _section, "directed", fallback=_network_section.directed ) @@ -119,6 +155,7 @@ def get_network_section(self) -> NetworkSection: _network_section.road_types = self._parser.getlist( _section, "road_types", fallback=_network_section.road_types ) + _network_section.polygon = self._get_str_as_path(_network_section.polygon) return _network_section def get_origins_destinations_section(self) -> OriginsDestinationsSection: diff --git a/ra2ce/graph/network_config_data/network_config_data_validator.py b/ra2ce/graph/network_config_data/network_config_data_validator.py index f411c7ca9..1fe334168 100644 --- a/ra2ce/graph/network_config_data/network_config_data_validator.py +++ 
b/ra2ce/graph/network_config_data/network_config_data_validator.py @@ -20,6 +20,7 @@ """ from typing import Any + from ra2ce.common.validation.ra2ce_validator_protocol import Ra2ceIoValidator from ra2ce.common.validation.validation_report import ValidationReport from ra2ce.graph.network_config_data.network_config_data import ( diff --git a/ra2ce/graph/network_config_wrapper.py b/ra2ce/graph/network_config_wrapper.py index a25ef1d04..03e01168b 100644 --- a/ra2ce/graph/network_config_wrapper.py +++ b/ra2ce/graph/network_config_wrapper.py @@ -37,6 +37,7 @@ ) from ra2ce.graph.networks import Network + class NetworkConfigWrapper(ConfigWrapperProtocol): files: Dict[str, Path] = {} config_data: NetworkConfigData diff --git a/ra2ce/graph/network_wrappers/README.md b/ra2ce/graph/network_wrappers/README.md new file mode 100644 index 000000000..bc04349e7 --- /dev/null +++ b/ra2ce/graph/network_wrappers/README.md @@ -0,0 +1,8 @@ +# Network wrappers + +In this module we define wrappers for reading a ra2ce network (`tuple[MultiGraph, GeoDataFrame]`). Each different class defines how a network will be created and cleaned. +Network wrappers need to implement the `NetworkWrapperProtocol`. + +An instance of the `NetworkWrapperProtocol` defines the `get_network` method, which will return the aforementioned ra2ce network. + +Any network wrapper can be created with the use of a `NetworkConfigData` instance. At the same time, if the user does not know which wrapper to use, the `NetworkWrapperFactory` will resolve this for us. 
\ No newline at end of file diff --git a/ra2ce/graph/osm_network_wrapper/__init__.py b/ra2ce/graph/network_wrappers/__init__.py similarity index 100% rename from ra2ce/graph/osm_network_wrapper/__init__.py rename to ra2ce/graph/network_wrappers/__init__.py diff --git a/ra2ce/graph/network_wrappers/network_wrapper_factory.py b/ra2ce/graph/network_wrappers/network_wrapper_factory.py new file mode 100644 index 000000000..be558019a --- /dev/null +++ b/ra2ce/graph/network_wrappers/network_wrapper_factory.py @@ -0,0 +1,71 @@ +""" + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Risk Assessment and Adaptation for Critical Infrastructure (RA2CE). + Copyright (C) 2023 Stichting Deltares + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . 
+""" + +import logging + +import geopandas as gpd +from geopandas import GeoDataFrame +from networkx import MultiGraph + +from ra2ce.common.io.readers import GraphPickleReader +from ra2ce.graph.network_config_data.network_config_data import NetworkConfigData +from ra2ce.graph.network_wrappers.network_wrapper_protocol import NetworkWrapperProtocol +from ra2ce.graph.network_wrappers.osm_network_wrapper.osm_network_wrapper import ( + OsmNetworkWrapper, +) +from ra2ce.graph.network_wrappers.shp_network_wrapper import ShpNetworkWrapper +from ra2ce.graph.network_wrappers.trails_network_wrapper import TrailsNetworkWrapper +from ra2ce.graph.network_wrappers.vector_network_wrapper import VectorNetworkWrapper + + +class NetworkWrapperFactory(NetworkWrapperProtocol): + def __init__(self, config_data: NetworkConfigData) -> None: + self._config_data = config_data + + def _any_cleanup_enabled(self) -> bool: + _cleanup = self._config_data.cleanup + return ( + _cleanup.snapping_threshold + or _cleanup.pruning_threshold + or _cleanup.merge_lines + or _cleanup.cut_at_intersections + ) + + def get_network(self) -> tuple[MultiGraph, GeoDataFrame]: + logging.info("Start creating a network from the submitted shapefile.") + source = self._config_data.network.source + if source == "shapefile": + if self._any_cleanup_enabled(): + return ShpNetworkWrapper(self._config_data).get_network() + return VectorNetworkWrapper(self._config_data).get_network() + elif source == "OSM PBF": + return TrailsNetworkWrapper(self._config_data).get_network() + elif source == "OSM download": + return OsmNetworkWrapper(self._config_data).get_network() + elif source == "pickle": + logging.info("Start importing a network from pickle") + base_graph = GraphPickleReader().read( + self.output_graph_dir.joinpath("base_graph.p") + ) + network_gdf = gpd.read_feather( + self.output_graph_dir.joinpath("base_network.feather") + ) + return base_graph, network_gdf diff --git 
a/ra2ce/graph/network_wrappers/network_wrapper_protocol.py b/ra2ce/graph/network_wrappers/network_wrapper_protocol.py new file mode 100644 index 000000000..b6159869d --- /dev/null +++ b/ra2ce/graph/network_wrappers/network_wrapper_protocol.py @@ -0,0 +1,36 @@ +""" + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Risk Assessment and Adaptation for Critical Infrastructure (RA2CE). + Copyright (C) 2023 Stichting Deltares + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" +from typing import Protocol, runtime_checkable + +from geopandas import GeoDataFrame +from networkx import MultiGraph + + +@runtime_checkable +class NetworkWrapperProtocol(Protocol): + def get_network(self) -> tuple[MultiGraph, GeoDataFrame]: + """ + Gets a network built within this wrapper instance. No arguments are accepted, the `__init__` method is meant to assign all required attributes for a wrapper. + + Returns: + tuple[MultiGraph, GeoDataFrame]: Tuple of MultiGraph representing the graph and GeoDataFrame representing the network. 
+ """ + pass diff --git a/tests/graph/osm_network_wrapper/__init__.py b/ra2ce/graph/network_wrappers/osm_network_wrapper/__init__.py similarity index 100% rename from tests/graph/osm_network_wrapper/__init__.py rename to ra2ce/graph/network_wrappers/osm_network_wrapper/__init__.py diff --git a/ra2ce/graph/network_wrappers/osm_network_wrapper/extremities_data.py b/ra2ce/graph/network_wrappers/osm_network_wrapper/extremities_data.py new file mode 100644 index 000000000..e96c7697b --- /dev/null +++ b/ra2ce/graph/network_wrappers/osm_network_wrapper/extremities_data.py @@ -0,0 +1,125 @@ +""" + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Risk Assessment and Adaptation for Critical Infrastructure (RA2CE). + Copyright (C) 2023 Stichting Deltares + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . 
+""" + +from dataclasses import dataclass + +from networkx import MultiDiGraph + + +@dataclass +class ExtremitiesData: + from_id: int = None + to_id: int = None + from_to_id: tuple = None + to_from_id: tuple = None + from_to_coor: tuple = None + to_from_coor: tuple = None + + @staticmethod + def get_extremities_data_for_sub_graph( + from_node_id: int, + to_node_id: int, + sub_graph: MultiDiGraph, + graph: MultiDiGraph, + shared_elements: set, + ): + """Both extremities should be in the unique_graph still makes an edge between similar node to u (the node + with u coordinates and different id, included in the unique_graph) and v Here, sub_graph is the unique_graph + and graph is complex_graph Shared elements are shared btw sub_graph and graph, which are elements to include + when dropping duplicates""" + if shared_elements is None or not isinstance(shared_elements, set): + raise ValueError("unique_elements should be a set") + if from_node_id in sub_graph.nodes() and to_node_id in sub_graph.nodes(): + return ExtremitiesData.arrange_extremities_data( + from_node_id=from_node_id, to_node_id=to_node_id, graph=sub_graph + ) + elif ( + graph.nodes[from_node_id]["x"], + graph.nodes[from_node_id]["y"], + ) in shared_elements and to_node_id in sub_graph.nodes(): + + from_node_id_prime = ExtremitiesData.find_node_id_by_coor( + sub_graph, + graph.nodes[from_node_id]["x"], + graph.nodes[from_node_id]["y"], + ) + if from_node_id_prime == to_node_id: + return ExtremitiesData() + else: + return ExtremitiesData.arrange_extremities_data( + from_node_id=from_node_id_prime, + to_node_id=to_node_id, + graph=sub_graph, + ) + + elif ( + from_node_id in sub_graph.nodes() + and (graph.nodes[to_node_id]["x"], graph.nodes[to_node_id]["y"]) + in shared_elements + ): + + to_node_id_prime = ExtremitiesData.find_node_id_by_coor( + sub_graph, graph.nodes[to_node_id]["x"], graph.nodes[to_node_id]["y"] + ) + if from_node_id == to_node_id_prime: + return ExtremitiesData() + else: + return 
ExtremitiesData.arrange_extremities_data( + from_node_id=from_node_id, + to_node_id=to_node_id_prime, + graph=sub_graph, + ) + else: + return ExtremitiesData() + + @staticmethod + def arrange_extremities_data( + from_node_id: int, to_node_id: int, graph: MultiDiGraph + ): + return ExtremitiesData( + from_id=from_node_id, + to_id=to_node_id, + from_to_id=(from_node_id, to_node_id), + to_from_id=(to_node_id, from_node_id), + from_to_coor=( + (graph.nodes[from_node_id]["x"], graph.nodes[to_node_id]["x"]), + (graph.nodes[from_node_id]["y"], graph.nodes[to_node_id]["y"]), + ), + to_from_coor=( + (graph.nodes[to_node_id]["x"], graph.nodes[from_node_id]["x"]), + (graph.nodes[to_node_id]["y"], graph.nodes[from_node_id]["y"]), + ), + ) + + @staticmethod + def find_node_id_by_coor(graph: MultiDiGraph, target_x: float, target_y: float): + """ + finds the node in unique graph with the same coor + """ + for node, data in graph.nodes(data=True): + if ( + "x" in data + and "y" in data + and data["x"] == target_x + and data["y"] == target_y + ): + return node + return None diff --git a/ra2ce/graph/osm_network_wrapper/osm_network_wrapper.py b/ra2ce/graph/network_wrappers/osm_network_wrapper/osm_network_wrapper.py similarity index 69% rename from ra2ce/graph/osm_network_wrapper/osm_network_wrapper.py rename to ra2ce/graph/network_wrappers/osm_network_wrapper/osm_network_wrapper.py index 45950730a..82d9c9b31 100644 --- a/ra2ce/graph/osm_network_wrapper/osm_network_wrapper.py +++ b/ra2ce/graph/network_wrappers/osm_network_wrapper/osm_network_wrapper.py @@ -16,38 +16,104 @@ """ import logging from pathlib import Path +from typing import Any import networkx as nx import osmnx -from networkx import MultiDiGraph -from osmnx import consolidate_intersections +import pandas as pd +from geopandas import GeoDataFrame +from networkx import MultiDiGraph, MultiGraph from shapely.geometry.base import BaseGeometry -from ra2ce.graph.network_config_data.network_config_data import NetworkSection 
import ra2ce.graph.networks_utils as nut -from ra2ce.graph.osm_network_wrapper.extremities_data import ExtremitiesData - - -class OsmNetworkWrapper: - network_type: str - road_types: list[str] - graph_crs: str - polygon_path: Path - - def __init__( - self, - network_type: str, - road_types: list[str], - graph_crs: str, - polygon_path: Path, - ) -> None: - self.network_type = network_type - self.road_types = road_types - self.polygon_path = polygon_path - self.graph_crs = graph_crs - if not graph_crs: - # Set default value - self.graph_crs = "epsg:4326" +from ra2ce.graph.exporters.json_exporter import JsonExporter +from ra2ce.graph.network_config_data.network_config_data import NetworkConfigData +from ra2ce.graph.network_wrappers.network_wrapper_protocol import NetworkWrapperProtocol +from ra2ce.graph.network_wrappers.osm_network_wrapper.extremities_data import ( + ExtremitiesData, +) + + +class OsmNetworkWrapper(NetworkWrapperProtocol): + def __init__(self, config_data: NetworkConfigData) -> None: + self.output_graph_dir = config_data.output_graph_dir + self.graph_crs = config_data.crs + + # Network options + self.network_type = config_data.network.network_type + self.road_types = config_data.network.road_types + self.polygon_path = config_data.network.polygon + self.is_directed = config_data.network.directed + + def get_network(self) -> tuple[MultiGraph, GeoDataFrame]: + """ + Gets an indirected graph + + Returns: + tuple[MultiGraph, GeoDataFrame]: _description_ + """ + logging.info("Start downloading a network from OSM.") + graph_complex = self.get_clean_graph_from_osm() + + # Create 'graph_simple' + graph_simple, graph_complex, link_tables = nut.create_simplified_graph( + graph_complex + ) + + # Create 'edges_complex', convert complex graph to geodataframe + logging.info("Start converting the graph to a geodataframe") + edges_complex, node_complex = nut.graph_to_gdf(graph_complex) + logging.info("Finished converting the graph to a geodataframe") + + # Save the 
link tables linking complex and simple IDs + self._export_linking_tables(link_tables) + + if not self.is_directed and isinstance(graph_simple, MultiDiGraph): + graph_simple = graph_simple.to_undirected() + + # Check if all geometries between nodes are there, if not, add them as a straight line. + graph_simple = nut.add_missing_geoms_graph(graph_simple, geom_name="geometry") + graph_simple = self._get_avg_speed(graph_simple) + return graph_simple, edges_complex + + def _get_avg_speed( + self, original_graph: nx.classes.graph.Graph + ) -> nx.classes.graph.Graph: + if all(["length" in e for u, v, e in original_graph.edges.data()]) and any( + ["maxspeed" in e for u, v, e in original_graph.edges.data()] + ): + # Add time weighing - Define and assign average speeds; or take the average speed from an existing CSV + path_avg_speed = self.output_graph_dir.joinpath("avg_speed.csv") + if path_avg_speed.is_file(): + avg_speeds = pd.read_csv(path_avg_speed) + else: + avg_speeds = nut.calc_avg_speed( + original_graph, + "highway", + save_csv=True, + save_path=path_avg_speed, + ) + original_graph = nut.assign_avg_speed(original_graph, avg_speeds, "highway") + + # make a time value of seconds, length of road streches is in meters + for u, v, k, edata in original_graph.edges.data(keys=True): + hours = (edata["length"] / 1000) / edata["avgspeed"] + original_graph[u][v][k]["time"] = round(hours * 3600, 0) + + return original_graph + logging.info( + "No attributes found in the graph to estimate average speed per network segment." 
+ ) + return original_graph + + def _export_linking_tables(self, linking_tables: tuple[Any]) -> None: + _exporter = JsonExporter() + _exporter.export( + self.output_graph_dir.joinpath("simple_to_complex.json"), linking_tables[0] + ) + _exporter.export( + self.output_graph_dir.joinpath("complex_to_simple.json"), linking_tables[1] + ) def get_clean_graph_from_osm(self) -> MultiDiGraph: """ @@ -234,7 +300,7 @@ def valid_extremity_data(u, v, data) -> tuple[ExtremitiesData, dict]: @staticmethod def snap_nodes_to_nodes(graph: MultiDiGraph, threshold: float) -> MultiDiGraph: - return consolidate_intersections( + return osmnx.consolidate_intersections( G=graph, rebuild_graph=True, tolerance=threshold, dead_ends=False ) diff --git a/ra2ce/graph/osm_utils.py b/ra2ce/graph/network_wrappers/osm_network_wrapper/osm_utils.py similarity index 100% rename from ra2ce/graph/osm_utils.py rename to ra2ce/graph/network_wrappers/osm_network_wrapper/osm_utils.py diff --git a/ra2ce/graph/network_wrappers/shp_network_wrapper.py b/ra2ce/graph/network_wrappers/shp_network_wrapper.py new file mode 100644 index 000000000..a3d558174 --- /dev/null +++ b/ra2ce/graph/network_wrappers/shp_network_wrapper.py @@ -0,0 +1,196 @@ +""" + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Risk Assessment and Adaptation for Critical Infrastructure (RA2CE). + Copyright (C) 2023 Stichting Deltares + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . 
+""" + +import logging +import math + +import geopandas as gpd +import networkx as nx +import pandas as pd +from shapely.geometry import MultiLineString + +import ra2ce.graph.networks_utils as nut +from ra2ce.graph.network_config_data.network_config_data import NetworkConfigData +from ra2ce.graph.network_wrappers.network_wrapper_protocol import NetworkWrapperProtocol +from ra2ce.graph.segmentation import Segmentation + + +class ShpNetworkWrapper(NetworkWrapperProtocol): + def __init__( + self, + config_data: NetworkConfigData, + ) -> None: + _network_options = config_data.network + _cleanup_options = config_data.cleanup + + self.project_name = config_data.project.name + self.output_graph_dir = config_data.output_graph_dir + self.crs = config_data.crs + + # Network options + self.primary_files = _network_options.primary_file + self.diversion_files = _network_options.diversion_file + self.directed = _network_options.directed + self.file_id = _network_options.file_id + + # Cleanup options + self.merge_lines = _cleanup_options.merge_lines + self.snapping_threshold = _cleanup_options.snapping_threshold + self.segmentation_length = _cleanup_options.segmentation_length + self.cut_at_intersections = _cleanup_options.cut_at_intersections + + # Origins Destinations + self.region_path = config_data.origins_destinations.region + + def _read_merge_shp(self) -> gpd.GeoDataFrame: + """Imports shapefile(s) and saves attributes in a pandas dataframe. 
+ + Returns: + lines (list of shapely LineStrings): full list of linestrings + properties (pandas dataframe): attributes of shapefile(s), in order of the linestrings in lines + """ + # concatenate all shapefile into one geodataframe and set analysis to 1 or 0 for diversions + lines = [gpd.read_file(shp) for shp in self.primary_files] + + if any(self.diversion_files): + lines.extend( + [ + nut.check_crs_gdf(gpd.read_file(shp), self.crs) + for shp in self.diversion_files + ] + ) + lines = pd.concat(lines) + + lines.crs = self.crs + + # Check if there are any multilinestrings and convert them to linestrings. + if lines["geometry"].apply(lambda row: isinstance(row, MultiLineString)).any(): + mls_idx = lines.loc[ + lines["geometry"].apply(lambda row: isinstance(row, MultiLineString)) + ].index + for idx in mls_idx: + # Multilinestrings to linestrings + new_rows_geoms = list(lines.iloc[idx]["geometry"].geoms) + for nrg in new_rows_geoms: + dict_attributes = dict(lines.iloc[idx]) + dict_attributes["geometry"] = nrg + lines.loc[max(lines.index) + 1] = dict_attributes + + lines = lines.drop(labels=mls_idx, axis=0) + + # append the length of the road stretches + lines["length"] = lines["geometry"].apply( + lambda x: nut.line_length(x, self.crs) + ) + + logging.info( + "Shapefile(s) loaded with attributes: {}.".format( + list(lines.columns.values) + ) + ) # fill in parameter names + + return lines + + def _get_complex_graph_and_edges( + self, edges: gpd.GeoDataFrame, id_name: str + ) -> tuple[nx.MultiGraph, gpd.GeoDataFrame]: + # Get the unique points at the end of lines and at intersections to create nodes + nodes = nut.create_nodes(edges, self.crs, self.cut_at_intersections) + logging.info("Function [create_nodes]: executed") + + edges = nut.cut_lines( + edges, nodes, id_name, tolerance=0.00001, crs_=self.crs + ) ## PAY ATTENTION TO THE TOLERANCE, THE UNIT IS DEGREES + logging.info("Function [cut_lines]: executed") + + if not edges.crs: + edges.crs = self.crs + + # create 
tuples from the adjecent nodes and add as column in geodataframe + edges_complex = nut.join_nodes_edges(nodes, edges, id_name) + edges_complex.crs = self.crs # set the right CRS + edges_complex.dropna(subset=["node_A", "node_B"], inplace=True) + + assert ( + edges_complex["node_A"].isnull().sum() == 0 + ), "Some edges cannot be assigned nodes, please check your input shapefile." + assert ( + edges_complex["node_B"].isnull().sum() == 0 + ), "Some edges cannot be assigned nodes, please check your input shapefile." + + # Create networkx graph from geodataframe + graph_complex = nut.graph_from_gdf(edges_complex, nodes, node_id="node_fid") + logging.info("Function [graph_from_gdf]: executed") + return graph_complex, edges_complex + + def get_network( + self, + ) -> tuple[nx.MultiGraph, gpd.GeoDataFrame]: + edges = self._read_merge_shp() + lines_merged = gpd.GeoDataFrame() + # Check which of the lines are merged, also for the fid. The fid of the first line with a traffic count is taken. + # The list of fid's is reduced by the fid's that are not anymore in the merged lines + if self.merge_lines: + aadt_names = [] + edges, lines_merged = nut.merge_lines_automatic( + edges, self.file_id, aadt_names, self.crs + ) + logging.info( + "Function [merge_lines_automatic]: executed with properties {}".format( + list(edges.columns) + ) + ) + + edges, id_name = nut.gdf_check_create_unique_ids(edges, self.file_id) + + if self.snapping_threshold: + # TODO: snapping threshold it's a bool yet here we expect a float. + edges = nut.snap_endpoints_lines( + edges, self.snapping_threshold, id_name, self.crs + ) + logging.info( + "Function [snap_endpoints_lines]: executed with threshold = {}".format( + self.snapping_threshold + ) + ) + + # merge merged lines if there are any merged lines + if not lines_merged.empty: + # save the merged lines to a shapefile - CHECK if there are lines merged that should not be merged (e.g. 
main + secondary road) + lines_merged.set_geometry( + col="geometry", inplace=True + ) # To ensure the object is a GeoDataFrame and not a Series + _emerged_lines_file = self.output_graph_dir.joinpath( + f"{self.project_name}_lines_that_merged.shp" + ) + lines_merged.to_file(_emerged_lines_file) + logging.info( + "Function [edges_to_shp]: saved at {}".format(_emerged_lines_file) + ) + + graph_complex, edges_complex = self._get_complex_graph_and_edges(edges, id_name) + + if not math.isnan(self.segmentation_length): + edges_complex = Segmentation(edges_complex, self.segmentation_length) + edges_complex = edges_complex.apply_segmentation() + if edges_complex.crs is None: # The CRS might have dissapeared. + edges_complex.crs = self.crs # set the right CRS + return graph_complex, edges_complex diff --git a/ra2ce/graph/network_wrappers/trails_network_wrapper.py b/ra2ce/graph/network_wrappers/trails_network_wrapper.py new file mode 100644 index 000000000..5e95669b2 --- /dev/null +++ b/ra2ce/graph/network_wrappers/trails_network_wrapper.py @@ -0,0 +1,105 @@ +""" + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Risk Assessment and Adaptation for Critical Infrastructure (RA2CE). + Copyright (C) 2023 Stichting Deltares + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . 
+""" + +import logging + +import geopandas as gpd +from geopandas import GeoDataFrame +from networkx import MultiGraph + +from ra2ce.graph.network_config_data.network_config_data import NetworkConfigData +from ra2ce.graph.network_wrappers.network_wrapper_protocol import NetworkWrapperProtocol +from ra2ce.graph.networks_utils import graph_from_gdf +from ra2ce.graph.segmentation import Segmentation + + +class TrailsNetworkWrapper(NetworkWrapperProtocol): + def __init__(self, config_data: NetworkConfigData) -> None: + logging.info( + """The original OSM PBF import is no longer supported. + Instead, the beta version of package TRAILS is used. + First stable release of TRAILS is expected in 2023.""" + ) + self.primary_files = config_data.network.primary_file + self.segmentation_length = config_data.cleanup.segmentation_length + self.crs = config_data.crs + + def get_network(self) -> tuple[MultiGraph, GeoDataFrame]: + """Creates a network which has been prepared in the TRAILS package + + #Todo: we might later simply import the whole trails code as a package, and directly use these functions + #Todo: because TRAILS is still in beta version we better wait with that untill the first stable version is + # released + + Returns: + graph_simple (NetworkX graph): Simplified graph (for use in the indirect analyses). + complex_edges (GeoDataFrame): Complex graph (for use in the direct analyses). 
+ """ + logging.info( + "TRAILS importer: Reads the provided primary edge file: {}, assumes there also is a_nodes file".format( + self.primary_files + ) + ) + + logging.warning( + "Any coordinate projection information in the feather file will be overwritten (with default WGS84)" + ) + # Make a pyproj CRS from the EPSG code + + _edge_file = self.primary_files[0] + edges = gpd.read_feather(_edge_file) + edges = edges.set_crs(self.crs) + + corresponding_node_file = _edge_file.replace("edges", "nodes") + if not corresponding_node_file.exists(): + raise FileNotFoundError( + "The node file could not be found while importing from TRAILS" + ) + + nodes = gpd.read_feather(corresponding_node_file) + nodes = nodes.set_crs(self.crs) + + logging.info("TRAILS importer: start generating graph") + # tempfix to rename columns + edges = edges.rename({"from_id": "node_A", "to_id": "node_B"}, axis="columns") + node_id = "id" + graph_simple = graph_from_gdf(edges, nodes, name="network", node_id=node_id) + + logging.info("TRAILS importer: graph generating was succesfull.") + logging.warning( + "RA2CE will not clean-up your graph, assuming that it is already done in TRAILS" + ) + + edges_complex = edges + if self.segmentation_length: + logging.info("TRAILS importer: start segmentating graph") + to_segment = Segmentation(edges, self._cleanup.segmentation_length) + edges_simple_segmented = to_segment.apply_segmentation() + if edges_simple_segmented.crs is None: # The CRS might have dissapeared. 
+ edges_simple_segmented.crs = edges.crs # set the right CRS + edges_complex = edges_simple_segmented + + graph_complex = graph_simple # NOTE THAT DIFFERENCE + # BETWEEN SIMPLE AND COMPLEX DOES NOT EXIST WHEN IMPORTING WITH TRAILS + + # Todo: better control over metadata in trails + # Todo: better control over where things are saved in the pipeline + return graph_complex, edges_complex diff --git a/ra2ce/graph/network_wrappers/vector_network_wrapper.py b/ra2ce/graph/network_wrappers/vector_network_wrapper.py new file mode 100644 index 000000000..3ba60c472 --- /dev/null +++ b/ra2ce/graph/network_wrappers/vector_network_wrapper.py @@ -0,0 +1,227 @@ +""" + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Risk Assessment and Adaptation for Critical Infrastructure (RA2CE). + Copyright (C) 2023 Stichting Deltares + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +import logging +from pathlib import Path + +import geopandas as gpd +import momepy +import networkx as nx +import pandas as pd +from shapely.geometry import Point + +import ra2ce.graph.networks_utils as nut +from ra2ce.graph.network_config_data.network_config_data import NetworkConfigData +from ra2ce.graph.network_wrappers.network_wrapper_protocol import NetworkWrapperProtocol + + +class VectorNetworkWrapper(NetworkWrapperProtocol): + """A class for handling and manipulating vector files. 
+ + Provides methods for reading vector data, cleaning it, and setting up graph and + network. + """ + + def __init__( + self, + config_data: NetworkConfigData, + ) -> None: + self.crs = config_data.crs + + # Network options + self.primary_files = config_data.network.primary_file + self.directed = config_data.network.directed + + # Origins Destinations + self.region_path = config_data.origins_destinations.region + + def get_network( + self, + ) -> tuple[nx.MultiGraph, gpd.GeoDataFrame]: + """Gets a network built from vector files. + + Returns: + nx.MultiGraph: MultiGraph representing the graph. + gpd.GeoDataFrame: GeoDataFrame representing the network. + """ + gdf = self._read_vector_to_project_region_and_crs() + gdf = self.clean_vector(gdf) + if self.directed: + graph = self.get_direct_graph_from_vector(gdf) + else: + graph = self.get_indirect_graph_from_vector(gdf) + edges, nodes = self.get_network_edges_and_nodes_from_graph(graph) + graph_complex = nut.graph_from_gdf(edges, nodes, node_id="node_fid") + return graph_complex, edges + + def _read_vector_to_project_region_and_crs(self) -> gpd.GeoDataFrame: + gdf = self._read_files(self.primary_files) + if gdf is None: + logging.info("no file is read.") + return None + + # set crs and reproject if needed + if not gdf.crs and self.crs: + gdf = gdf.set_crs(self.crs) + logging.info("setting crs as default EPSG:4326. specify crs if incorrect") + + if self.crs: + gdf = gdf.to_crs(self.crs) + logging.info("reproject vector file to project crs") + + # clip for region + if self.region_path: + _region_gpd = self._read_files([self.region_path]) + gdf = gpd.overlay(gdf, _region_gpd, how="intersection", keep_geom_type=True) + logging.info("clip vector file to project region") + + # validate + if not any(gdf): + logging.warning("No vector features found within project region") + return None + + return gdf + + def _read_files(self, file_list: list[Path]) -> gpd.GeoDataFrame: + """Reads a list of files into a GeoDataFrame. 
+ + Args: + file_list (list[Path]): List of file paths. + + Returns: + gpd.GeoDataFrame: GeoDataFrame representing the data. + """ + # read file + gdf = gpd.GeoDataFrame(pd.concat(list(map(gpd.read_file, file_list)))) + logging.info( + "Read files {} into a 'GeoDataFrame'.".format( + ", ".join(map(str, file_list)) + ) + ) + return gdf + + @staticmethod + def get_direct_graph_from_vector(gdf: gpd.GeoDataFrame) -> nx.DiGraph: + """Creates a simple directed graph with node and edge geometries based on a given GeoDataFrame. + + Args: + gdf (gpd.GeoDataFrame): Input GeoDataFrame containing line geometries. + Allow both LineString and MultiLineString. + + Returns: + nx.DiGraph: NetworkX graph object with "crs", "approach" as graph properties. + """ + + # simple geometry handeling + gdf = VectorNetworkWrapper.explode_and_deduplicate_geometries(gdf) + + # to graph + digraph = nx.DiGraph(crs=gdf.crs, approach="primal") + for _, row in gdf.iterrows(): + from_node = row.geometry.coords[0] + to_node = row.geometry.coords[-1] + digraph.add_node(from_node, geometry=Point(from_node)) + digraph.add_node(to_node, geometry=Point(to_node)) + digraph.add_edge( + from_node, + to_node, + geometry=row.pop( + "geometry" + ), # **row TODO: check if we do need all columns + ) + + return digraph + + @staticmethod + def get_indirect_graph_from_vector(gdf: gpd.GeoDataFrame) -> nx.Graph: + """Creates a simple undirected graph with node and edge geometries based on a given GeoDataFrame. + + Args: + gdf (gpd.GeoDataFrame): Input GeoDataFrame containing line geometries. + Allow both LineString and MultiLineString. + + Returns: + nx.Graph: NetworkX graph object with "crs", "approach" as graph properties. + """ + digraph = VectorNetworkWrapper.get_direct_graph_from_vector(gdf) + return digraph.to_undirected() + + @staticmethod + def get_network_edges_and_nodes_from_graph( + graph: nx.Graph, + ) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]: + """Sets up network nodes and edges from a given graph. 
+ + Args: + graph (nx.Graph): Input graph with geometry for nodes and edges. + Must contain "crs" as graph property. + + Returns: + gpd.GeoDataFrame: GeoDataFrame representing the network edges with "edge_fid", "node_A", and "node_B". + gpd.GeoDataFrame: GeoDataFrame representing the network nodes with "node_fid". + """ + + # TODO ths function use conventions. Good to make consistant convention with osm + nodes, edges = momepy.nx_to_gdf(graph, nodeID="node_fid") + edges["edge_fid"] = ( + edges["node_start"].astype(str) + "_" + edges["node_end"].astype(str) + ) + edges.rename( + {"node_start": "node_A", "node_end": "node_B"}, axis=1, inplace=True + ) + if not nodes.crs: + nodes.crs = graph.graph["crs"] + if not edges.crs: + edges.crs = graph.graph["crs"] + return edges, nodes + + @staticmethod + def clean_vector(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame: + """Cleans a GeoDataFrame. + + Args: + gdf (gpd.GeoDataFrame): Input GeoDataFrame. + + Returns: + gpd.GeoDataFrame: Cleaned GeoDataFrame. + """ + + gdf = VectorNetworkWrapper.explode_and_deduplicate_geometries(gdf) + + return gdf + + @staticmethod + def explode_and_deduplicate_geometries(gpd: gpd.GeoDataFrame) -> gpd.GeoDataFrame: + """Explodes and deduplicates geometries a GeoDataFrame. + + Args: + gpd (gpd.GeoDataFrame): Input GeoDataFrame. + + Returns: + gpd.GeoDataFrame: GeoDataFrame with exploded and deduplicated geometries. 
+ """ + gpd = gpd.explode() + gpd = gpd[ + gpd.index.isin( + gpd.geometry.apply(lambda geom: geom.wkb).drop_duplicates().index + ) + ] + return gpd diff --git a/ra2ce/graph/networks.py b/ra2ce/graph/networks.py index e97dfc5b7..ac05c21a2 100644 --- a/ra2ce/graph/networks.py +++ b/ra2ce/graph/networks.py @@ -20,26 +20,18 @@ """ import logging -import math -import os from pathlib import Path -from typing import Any, List, Tuple +from typing import Any import geopandas as gpd import networkx as nx -import pandas as pd import pyproj -import momepy -from shapely.geometry import MultiLineString from ra2ce.common.io.readers import GraphPickleReader from ra2ce.graph import networks_utils as nut -from ra2ce.graph.exporters.json_exporter import JsonExporter from ra2ce.graph.exporters.network_exporter_factory import NetworkExporterFactory from ra2ce.graph.network_config_data.network_config_data import NetworkConfigData -from ra2ce.graph.osm_network_wrapper.osm_network_wrapper import OsmNetworkWrapper -from ra2ce.graph.segmentation import Segmentation -from ra2ce.graph.vector_network_wrapper import VectorNetworkWrapper +from ra2ce.graph.network_wrappers.network_wrapper_factory import NetworkWrapperFactory class Network: @@ -52,20 +44,16 @@ class Network: config: A dictionary with the configuration details on how to create and adjust the network. 
""" - output_graph_dir: Path - files: dict - base_graph_crs: Any - base_network_crs: Any - def __init__(self, network_config: NetworkConfigData, files: dict): # General + self._config_data = network_config self.project_name = network_config.project.name - self.output_graph_dir = network_config.static_path.joinpath("output_graph") + self.output_graph_dir = network_config.output_graph_dir if not self.output_graph_dir.is_dir(): self.output_graph_dir.mkdir(parents=True) # Network - self._network_dir = network_config.static_path.joinpath("network") + self._network_dir = network_config.network_dir self.base_graph_crs = None # Initiate variable self.base_network_crs = None # Initiate variable @@ -82,290 +70,12 @@ def __init__(self, network_config: NetworkConfigData, files: dict): ) self.origin_count = _origins_destinations.origin_count self.od_category = _origins_destinations.category - self.region = ( - None - if not _origins_destinations.region - else network_config.static_path.joinpath( - "network", _origins_destinations.region - ) - ) + self.region = _origins_destinations.region self.region_var = _origins_destinations.region_var - # Cleanup - self._cleanup = network_config.cleanup - # files self.files = files - def network_shp( - self, crs: int = 4326 - ) -> Tuple[nx.classes.graph.Graph, gpd.GeoDataFrame]: - """Creates a (graph) network from a shapefile. - - Returns the same geometries for the network (GeoDataFrame) as for the graph (NetworkX graph), because - it is assumed that the user wants to keep the same geometries as their shapefile input. - - Args: - crs (int): the EPSG number of the coordinate reference system that is used - - Returns: - graph_complex (NetworkX graph): The resulting graph. - edges_complex (GeoDataFrame): The resulting network. 
- """ - # Make a pyproj CRS from the EPSG code - crs = pyproj.CRS.from_user_input(crs) - - lines = self.read_merge_shp(crs) - - logging.info( - "Function [read_merge_shp]: executed with {} {}".format( - self._network_config.primary_file, - self._network_config.diversion_file, - ) - ) - - # Check which of the lines are merged, also for the fid. The fid of the first line with a traffic count is taken. - # The list of fid's is reduced by the fid's that are not anymore in the merged lines - if self._cleanup.merge_lines: - aadt_names = [] - edges, lines_merged = nut.merge_lines_automatic( - lines, self._network_config.file_id, aadt_names, crs - ) - logging.info( - "Function [merge_lines_automatic]: executed with properties {}".format( - list(edges.columns) - ) - ) - else: - edges, lines_merged = lines, gpd.GeoDataFrame() - - edges, id_name = nut.gdf_check_create_unique_ids( - edges, self._network_config.file_id - ) - - if self._cleanup.snapping_threshold: - edges = nut.snap_endpoints_lines( - edges, self._cleanup.snapping_threshold, id_name, crs - ) - logging.info( - "Function [snap_endpoints_lines]: executed with threshold = {}".format( - self._cleanup.snapping_threshold - ) - ) - - # merge merged lines if there are any merged lines - if not lines_merged.empty: - # save the merged lines to a shapefile - CHECK if there are lines merged that should not be merged (e.g. 
main + secondary road) - lines_merged.set_geometry( - col="geometry", inplace=True - ) # To ensure the object is a GeoDataFrame and not a Series - lines_merged.to_file( - os.path.join( - self.output_graph_dir, - "{}_lines_that_merged.shp".format(self.project_name), - ) - ) - logging.info( - "Function [edges_to_shp]: saved at {}".format( - os.path.join( - self.output_graph_dir, - "{}_lines_that_merged".format(self.project_name), - ) - ) - ) - - # Get the unique points at the end of lines and at intersections to create nodes - nodes = nut.create_nodes(edges, crs, self._cleanup.cut_at_intersections) - logging.info("Function [create_nodes]: executed") - - edges = nut.cut_lines( - edges, nodes, id_name, tolerance=0.00001, crs_=crs - ) ## PAY ATTENTION TO THE TOLERANCE, THE UNIT IS DEGREES - logging.info("Function [cut_lines]: executed") - - if not edges.crs: - edges.crs = crs - - # create tuples from the adjecent nodes and add as column in geodataframe - edges_complex = nut.join_nodes_edges(nodes, edges, id_name) - edges_complex.crs = crs # set the right CRS - edges_complex.dropna(subset=["node_A", "node_B"], inplace=True) - - assert ( - edges_complex["node_A"].isnull().sum() == 0 - ), "Some edges cannot be assigned nodes, please check your input shapefile." - assert ( - edges_complex["node_B"].isnull().sum() == 0 - ), "Some edges cannot be assigned nodes, please check your input shapefile." - - # Create networkx graph from geodataframe - graph_complex = nut.graph_from_gdf(edges_complex, nodes, node_id="node_fid") - logging.info("Function [graph_from_gdf]: executed") - - if not math.isnan(self._cleanup.segmentation_length): - edges_complex = Segmentation( - edges_complex, self._cleanup.segmentation_length - ) - edges_complex = edges_complex.apply_segmentation() - if edges_complex.crs is None: # The CRS might have dissapeared. 
- edges_complex.crs = crs # set the right CRS - - self.base_graph_crs = pyproj.CRS.from_user_input(crs) - self.base_network_crs = pyproj.CRS.from_user_input(crs) - - # Exporting complex graph because the shapefile should be kept the same as much as possible. - return graph_complex, edges_complex - - def network_cleanshp(self) -> Tuple[nx.classes.graph.Graph, gpd.GeoDataFrame]: - """Creates a (graph) network from a clean shapefile (primary_file - no further advance cleanup is needed) - - Returns the same geometries for the network (GeoDataFrame) as for the graph (NetworkX graph), because - it is assumed that the user wants to keep the same geometries as their shapefile input. - - Returns: - graph_complex (NetworkX graph): The resulting graph. - edges_complex (GeoDataFrame): The resulting network. - """ - # initialise vector network wrapper - vector_network_wrapper = VectorNetworkWrapper(config=self.config) - - # setup network using the wrapper - ( - graph_complex, - edges_complex, - ) = vector_network_wrapper.setup_network_from_vector() - - # Set the CRS of the graph and network to wrapper crs - self.base_graph_crs = vector_network_wrapper.crs - self.base_network_crs = vector_network_wrapper.crs - - return graph_complex, edges_complex - - def _export_linking_tables(self, linking_tables: List[Any]) -> None: - _exporter = JsonExporter() - _exporter.export( - self.output_graph_dir.joinpath("simple_to_complex.json"), linking_tables[0] - ) - _exporter.export( - self.output_graph_dir.joinpath("complex_to_simple.json"), linking_tables[1] - ) - - def network_trails_import( - self, crs: int = 4326 - ) -> Tuple[nx.classes.graph.Graph, gpd.GeoDataFrame]: - """Creates a network which has been prepared in the TRAILS package - - #Todo: we might later simply import the whole trails code as a package, and directly use these functions - #Todo: because TRAILS is still in beta version we better wait with that untill the first stable version is - # released - - Returns: - graph_simple 
(NetworkX graph): Simplified graph (for use in the indirect analyses). - complex_edges (GeoDataFrame): Complex graph (for use in the direct analyses). - """ - - logging.info( - "TRAILS importer: Reads the provided primary edge file: {}, assumes there also is a_nodes file".format( - self._network_config.primary_file - ) - ) - - logging.warning( - "Any coordinate projection information in the feather file will be overwritten (with default WGS84)" - ) - # Make a pyproj CRS from the EPSG code - crs = pyproj.CRS.from_user_input(crs) - - edge_file = self._network_dir.joinpath(self._network_config.primary_file) - edges = gpd.read_feather(edge_file) - edges = edges.set_crs(crs) - - corresponding_node_file = self._network_dir.joinpath( - self._network_config.primary_file.replace("edges", "nodes") - ) - assert ( - corresponding_node_file.exists() - ), "The node file could not be found while importing from TRAILS" - nodes = gpd.read_feather(corresponding_node_file) - nodes = nodes.set_crs(crs) - # nodes = pd.read_pickle( - # corresponding_node_file - # ) # Todo: Throw exception if nodes file is not present - - logging.info("TRAILS importer: start generating graph") - # tempfix to rename columns - edges = edges.rename({"from_id": "node_A", "to_id": "node_B"}, axis="columns") - node_id = "id" - graph_simple = nut.graph_from_gdf(edges, nodes, name="network", node_id=node_id) - - logging.info("TRAILS importer: graph generating was succesfull.") - logging.warning( - "RA2CE will not clean-up your graph, assuming that it is already done in TRAILS" - ) - - if self._cleanup.segmentation_length: - logging.info("TRAILS importer: start segmentating graph") - to_segment = Segmentation(edges, self._cleanup.segmentation_length) - edges_simple_segmented = to_segment.apply_segmentation() - if edges_simple_segmented.crs is None: # The CRS might have dissapeared. 
- edges_simple_segmented.crs = edges.crs # set the right CRS - edges_complex = edges_simple_segmented - - else: - edges_complex = edges - - graph_complex = graph_simple # NOTE THAT DIFFERENCE - # BETWEEN SIMPLE AND COMPLEX DOES NOT EXIST WHEN IMPORTING WITH TRAILS - - # Todo: better control over metadata in trails - # Todo: better control over where things are saved in the pipeline - - return graph_complex, edges_complex - - def network_osm_download(self) -> tuple[nx.classes.graph.Graph, gpd.GeoDataFrame]: - """ - Creates a network from a polygon by downloading via the OSM API in the extent of the polygon. - - Returns: - tuple[nx.classes.graph.Graph, gpd.GeoDataFrame]: Tuple of Simplified graph (for use in the indirect analyses) and Complex graph (for use in the direct analyses). - """ - osm_network = OsmNetworkWrapper( - network_type=self._network_config.network_type, - road_types=self._network_config.road_types, - graph_crs="", - polygon_path=self._network_dir.joinpath(self._network_config.polygon), - ) - graph_complex = osm_network.get_clean_graph_from_osm() - - # Create 'graph_simple' - graph_simple, graph_complex, link_tables = nut.create_simplified_graph( - graph_complex - ) - - # Create 'edges_complex', convert complex graph to geodataframe - logging.info("Start converting the graph to a geodataframe") - edges_complex, node_complex = nut.graph_to_gdf(graph_complex) - logging.info("Finished converting the graph to a geodataframe") - - # Save the link tables linking complex and simple IDs - self._export_linking_tables(link_tables) - - # If the user wants to use undirected graphs, turn into an undirected graph (default). 
- if not self._network_config.directed: - if type(graph_simple) == nx.classes.multidigraph.MultiDiGraph: - graph_simple = graph_simple.to_undirected() - - # No segmentation required, the non-simplified road segments from OSM are already small enough - - self.base_graph_crs = pyproj.CRS.from_user_input( - "EPSG:4326" - ) # Graphs from OSM download are always in this CRS. - self.base_network_crs = pyproj.CRS.from_user_input( - "EPSG:4326" - ) # Graphs from OSM download are always in this CRS. - - return graph_simple, edges_complex - def add_od_nodes( self, graph: nx.classes.graph.Graph, crs: pyproj.CRS ) -> nx.classes.graph.Graph: @@ -426,100 +136,8 @@ def generate_origins_from_raster(self): return out_fn - def read_merge_shp(self, crs_: pyproj.CRS) -> gpd.GeoDataFrame: - """Imports shapefile(s) and saves attributes in a pandas dataframe. - - Args: - crs_ (int): the EPSG number of the coordinate reference system that is used - Returns: - lines (list of shapely LineStrings): full list of linestrings - properties (pandas dataframe): attributes of shapefile(s), in order of the linestrings in lines - """ - - # read shapefiles and add to list with path - def get_shp_paths(shp_str: str) -> Path: - return self._network_dir.joinpath(shp_str) - - shapefiles_analysis = list( - map(get_shp_paths, self._network_config.primary_file.split(",")) - ) - shapefiles_diversion = list( - map(get_shp_paths, self._network_config.diversion_file.split(",")) - ) - - # concatenate all shapefile into one geodataframe and set analysis to 1 or 0 for diversions - lines = [gpd.read_file(shp) for shp in shapefiles_analysis] - - if self._network_config.diversion_file: - lines.extend( - [ - nut.check_crs_gdf(gpd.read_file(shp), crs_) - for shp in shapefiles_diversion - ] - ) - lines = pd.concat(lines) - - lines.crs = crs_ - - # Check if there are any multilinestrings and convert them to linestrings. 
- if lines["geometry"].apply(lambda row: isinstance(row, MultiLineString)).any(): - mls_idx = lines.loc[ - lines["geometry"].apply(lambda row: isinstance(row, MultiLineString)) - ].index - for idx in mls_idx: - # Multilinestrings to linestrings - new_rows_geoms = list(lines.iloc[idx]["geometry"].geoms) - for nrg in new_rows_geoms: - dict_attributes = dict(lines.iloc[idx]) - dict_attributes["geometry"] = nrg - lines.loc[max(lines.index) + 1] = dict_attributes - - lines = lines.drop(labels=mls_idx, axis=0) - - # append the length of the road stretches - lines["length"] = lines["geometry"].apply(lambda x: nut.line_length(x, crs_)) - - logging.info( - "Shapefile(s) loaded with attributes: {}.".format( - list(lines.columns.values) - ) - ) # fill in parameter names - - return lines - - def get_avg_speed( - self, original_graph: nx.classes.graph.Graph - ) -> nx.classes.graph.Graph: - if all(["length" in e for u, v, e in original_graph.edges.data()]) and any( - ["maxspeed" in e for u, v, e in original_graph.edges.data()] - ): - # Add time weighing - Define and assign average speeds; or take the average speed from an existing CSV - path_avg_speed = self.output_graph_dir.joinpath("avg_speed.csv") - if path_avg_speed.is_file(): - avg_speeds = pd.read_csv(path_avg_speed) - else: - avg_speeds = nut.calc_avg_speed( - original_graph, - "highway", - save_csv=True, - save_path=path_avg_speed, - ) - original_graph = nut.assign_avg_speed(original_graph, avg_speeds, "highway") - - # make a time value of seconds, length of road streches is in meters - for u, v, k, edata in original_graph.edges.data(keys=True): - hours = (edata["length"] / 1000) / edata["avgspeed"] - original_graph[u][v][k]["time"] = round(hours * 3600, 0) - - return original_graph - else: - logging.info( - "No attributes found in the graph to estimate average speed per network segment." 
- ) - return original_graph - def _export_network_files( - self, network: Any, graph_name: str, types_to_export: List[str] + self, network: Any, graph_name: str, types_to_export: list[str] ): _exporter = NetworkExporterFactory() _exporter.export( @@ -530,14 +148,60 @@ def _export_network_files( ) self.files[graph_name] = _exporter.get_pickle_path() - def _any_cleanup_enabled(self) -> bool: - return ( - self._cleanup.snapping_threshold - or self._cleanup.pruning_threshold - or self._cleanup.merge_lines - or self._cleanup.cut_at_intersections + def _get_new_network_and_graph( + self, export_types: list[str] + ) -> tuple[nx.classes.graph.Graph, gpd.GeoDataFrame]: + _base_graph, _network_gdf = NetworkWrapperFactory( + self._config_data + ).get_network() + + self.base_graph_crs = _network_gdf.crs + self.base_network_crs = _network_gdf.crs + + # Set the road lengths to meters for both the base_graph and network_gdf + # TODO: rename "length" column to "length [m]" to be explicit + edges_lengths_meters = { + (e[0], e[1], e[2]): { + "length": nut.line_length(e[-1]["geometry"], _network_gdf.crs) + } + for e in _base_graph.edges.data(keys=True) + } + nx.set_edge_attributes(_base_graph, edges_lengths_meters) + + _network_gdf["length"] = _network_gdf["geometry"].apply( + lambda x: nut.line_length(x, _network_gdf.crs) + ) + + # Save the graph and geodataframe + self._export_network_files(_base_graph, "base_graph", export_types) + self._export_network_files(_network_gdf, "base_network", export_types) + return _base_graph, _network_gdf + + def _get_stored_network_and_graph( + self, base_graph_filepath: Path, base_network_filepath: Path + ): + logging.info( + "Apparently, you already did create a network with ra2ce earlier. 
" + + "Ra2ce will use this: {}".format(base_graph_filepath) ) + def check_base_file(file_type: str, file_path: Path): + if not isinstance(base_graph_filepath) or not base_graph_filepath.is_file(): + raise FileNotFoundError( + "No base {} file found at {}.".format(file_type, file_path) + ) + + check_base_file("graph", base_graph_filepath) + check_base_file("network", base_network_filepath) + + _base_graph = GraphPickleReader().read(base_graph_filepath) + _network_gdf = gpd.read_feather(base_network_filepath) + + # Assuming the same CRS for both the network and graph + self.base_graph_crs = _network_gdf.crs + self.base_network_crs = _network_gdf.crs + return _base_graph, _network_gdf + def create(self) -> dict: """Handler function with the logic to call the right functions to create a network. @@ -552,87 +216,12 @@ def create(self) -> dict: # For all graph and networks - check if it exists, otherwise, make the graph and/or network. if not (self.files["base_graph"] or self.files["base_network"]): - # Create the network from the network source - if self._network_config.source == "shapefile": - logging.info("Start creating a network from the submitted shapefile.") - if self._any_cleanup_enabled(): - base_graph, network_gdf = self.network_shp() - else: - base_graph, network_gdf = self.network_cleanshp() - - elif self._network_config.source == "OSM PBF": - logging.info( - """The original OSM PBF import is no longer supported. - Instead, the beta version of package TRAILS is used. 
- First stable release of TRAILS is expected in 2023.""" - ) - - # base_graph, network_gdf = self.network_osm_pbf() #The old approach is depreciated - base_graph, network_gdf = self.network_trails_import() - - self.base_network_crs = network_gdf.crs - - elif self._network_config.source == "OSM download": - logging.info("Start downloading a network from OSM.") - base_graph, network_gdf = self.network_osm_download() - # Graph & Network from OSM download - # Check if all geometries between nodes are there, if not, add them as a straight line. - base_graph = nut.add_missing_geoms_graph( - base_graph, geom_name="geometry" - ) - elif self._network_config.source == "pickle": - logging.info("Start importing a network from pickle") - base_graph = GraphPickleReader().read( - self.output_graph_dir.joinpath("base_graph.p") - ) - network_gdf = gpd.read_feather( - self.output_graph_dir.joinpath("base_network.feather") - ) - - # Assuming the same CRS for both the network and graph - self.base_graph_crs = pyproj.CRS.from_user_input(network_gdf.crs) - self.base_network_crs = pyproj.CRS.from_user_input(network_gdf.crs) - - # Set the road lengths to meters for both the base_graph and network_gdf - # TODO: rename "length" column to "length [m]" to be explicit - edges_lengths_meters = { - (e[0], e[1], e[2]): { - "length": nut.line_length(e[-1]["geometry"], self.base_graph_crs) - } - for e in base_graph.edges.data(keys=True) - } - nx.set_edge_attributes(base_graph, edges_lengths_meters) - - network_gdf["length"] = network_gdf["geometry"].apply( - lambda x: nut.line_length(x, self.base_network_crs) - ) - - if self._network_config.source == "OSM download": - base_graph = self.get_avg_speed(base_graph) - - # Save the graph and geodataframe - self._export_network_files(base_graph, "base_graph", to_save) - self._export_network_files(network_gdf, "base_network", to_save) + base_graph, network_gdf = self._get_new_network_and_graph(to_save) else: - logging.info( - "Apparently, you already did 
create a network with ra2ce earlier. " - + "Ra2ce will use this: {}".format(self.files["base_graph"]) + base_graph, network_gdf = self._get_stored_network_and_graph( + self.files["base_graph"], self.files["base_network"] ) - if self.files["base_graph"] is not None: - base_graph = GraphPickleReader().read(self.files["base_graph"]) - else: - base_graph = None - - if self.files["base_network"] is not None: - network_gdf = gpd.read_feather(self.files["base_network"]) - else: - network_gdf = None - - # Assuming the same CRS for both the network and graph - self.base_graph_crs = pyproj.CRS.from_user_input(network_gdf.crs) - self.base_network_crs = pyproj.CRS.from_user_input(network_gdf.crs) - # create origins destinations graph if ( (self.origins) diff --git a/ra2ce/graph/osm_network_wrapper/extremities_data.py b/ra2ce/graph/osm_network_wrapper/extremities_data.py deleted file mode 100644 index 6e4308de1..000000000 --- a/ra2ce/graph/osm_network_wrapper/extremities_data.py +++ /dev/null @@ -1,79 +0,0 @@ -from dataclasses import dataclass - -from networkx import MultiDiGraph - - -@dataclass -class ExtremitiesData: - from_id: int = None - to_id: int = None - from_to_id: tuple = None - to_from_id: tuple = None - from_to_coor: tuple = None - to_from_coor: tuple = None - - @staticmethod - def get_extremities_data_for_sub_graph(from_node_id: int, to_node_id: int, sub_graph: MultiDiGraph, - graph: MultiDiGraph, shared_elements: set): - """Both extremities should be in the unique_graph still makes an edge between similar node to u (the node - with u coordinates and different id, included in the unique_graph) and v Here, sub_graph is the unique_graph - and graph is complex_graph Shared elements are shared btw sub_graph and graph, which are elements to include - when dropping duplicates""" - if shared_elements is None or not isinstance(shared_elements, set): - raise ValueError("unique_elements should be a set") - if from_node_id in sub_graph.nodes() and to_node_id in 
sub_graph.nodes(): - return ExtremitiesData.arrange_extremities_data( - from_node_id=from_node_id, to_node_id=to_node_id, graph=sub_graph - ) - elif (graph.nodes[from_node_id]['x'], graph.nodes[from_node_id]['y']) in shared_elements and \ - to_node_id in sub_graph.nodes(): - - from_node_id_prime = ExtremitiesData.find_node_id_by_coor( - sub_graph, graph.nodes[from_node_id]['x'], graph.nodes[from_node_id]['y'] - ) - if from_node_id_prime == to_node_id: - return ExtremitiesData() - else: - return ExtremitiesData.arrange_extremities_data(from_node_id=from_node_id_prime, to_node_id=to_node_id, - graph=sub_graph) - - elif from_node_id in sub_graph.nodes() and \ - (graph.nodes[to_node_id]['x'], graph.nodes[to_node_id]['y']) in shared_elements: - - to_node_id_prime = ExtremitiesData.find_node_id_by_coor( - sub_graph, graph.nodes[to_node_id]['x'], graph.nodes[to_node_id]['y'] - ) - if from_node_id == to_node_id_prime: - return ExtremitiesData() - else: - return ExtremitiesData.arrange_extremities_data(from_node_id=from_node_id, to_node_id=to_node_id_prime, - graph=sub_graph) - else: - return ExtremitiesData() - - @staticmethod - def arrange_extremities_data(from_node_id: int, to_node_id: int, graph: MultiDiGraph): - return ExtremitiesData( - from_id=from_node_id, - to_id=to_node_id, - from_to_id=(from_node_id, to_node_id), - to_from_id=(to_node_id, from_node_id), - from_to_coor=( - (graph.nodes[from_node_id]['x'], graph.nodes[to_node_id]['x']), - (graph.nodes[from_node_id]['y'], graph.nodes[to_node_id]['y']) - ), - to_from_coor=( - (graph.nodes[to_node_id]['x'], graph.nodes[from_node_id]['x']), - (graph.nodes[to_node_id]['y'], graph.nodes[from_node_id]['y']) - ) - ) - - @staticmethod - def find_node_id_by_coor(graph: MultiDiGraph, target_x: float, target_y: float): - """ - finds the node in unique graph with the same coor - """ - for node, data in graph.nodes(data=True): - if 'x' in data and 'y' in data and data['x'] == target_x and data['y'] == target_y: - return node - 
return None diff --git a/ra2ce/graph/vector_network_wrapper.py b/ra2ce/graph/vector_network_wrapper.py deleted file mode 100644 index e6e3b2035..000000000 --- a/ra2ce/graph/vector_network_wrapper.py +++ /dev/null @@ -1,327 +0,0 @@ -import logging -from pathlib import Path -from typing import Union - -import networkx as nx -import pandas as pd -import geopandas as gpd -import momepy - -from shapely.geometry import Point -from pyproj import CRS -import ra2ce.graph.networks_utils as nut - -logger = logging.getLogger() - - -class VectorNetworkWrapper: - """A class for handling and manipulating vector files. - - Provides methods for reading vector data, cleaning it, and setting up graph and - network. - """ - - name: str - region: gpd.GeoDataFrame - crs: CRS - input_path: Path - output_path: Path - network_dict: dict - - def __init__(self, config: dict) -> None: - """Initializes the VectorNetworkWrapper object. - - Args: - config (dict): Configuration dictionary. - - Raises: - ValueError: If the config is None or doesn't contain a network dictionary, - or if config['network'] is not a dictionary. - """ - - if not config: - raise ValueError("Config cannot be None") - if not config.get("network", {}): - raise ValueError( - "A network dictionary is required for creating a " - + f"{self.__class__.__name__} object." - ) - if not isinstance(config.get("network"), dict): - raise ValueError('Config["network"] should be a dictionary') - - self._setup_global(config) - self.network_dict = self._get_network_opt(config["network"]) - - def _parse_ini_stringlist(self, value: str) -> Union[str, list, None]: - """Parses a string with "," into a list from an ini file. - - Args: - value (str): Value to parse. - - Returns: - list of str If the value contains a comma, it is split and returned - as a list, otherwise the original value is returned. - None if the value is an empty string. 
- """ - if not value: - return None - elif isinstance(value, str) and "," in value: - return value.split(",") - else: - return value - - def _setup_global(self, config: dict) -> None: - """Sets up project properties based on provided configuration. - - Args: - config (dict): Project configuration dictionary. - """ - self.input_path = config.get("static").joinpath("network") - self.output_path = config.get("output") - - project_config = config.get("project") - name = project_config.get("name", "project_name") - region = project_config.get("region", None) - crs = project_config.get("crs", 4326) - self.name = name - self.crs = CRS.from_user_input(crs) - self.region = self._read_files([Path(region)]) if region else region - - def _parse_ini_filenamelist(self, filename: str) -> list[Path]: - """Makes a list of file paths by joining with input path and checks validity of files. - - Args: - filename (str): String of file names separated by comma (","). - - Returns: - List[Path]: List of file paths. - """ - if not isinstance(filename, str): - logger.error("file names are not valid.") - - file_paths = [self.input_path.joinpath(f.strip()) for f in filename.split(",")] - - for f in file_paths: - if not f.resolve().is_file(): - logger.error(f"vector file {f} is not found.") - - return file_paths - - def _get_network_opt(self, network_config: dict) -> dict: - """Retrieves network options used in this wrapper from provided configuration. - - Args: - network_config (dict): Network configuration dictionary. - - Returns: - dict: Dictionary of network options. 
- """ - - files = self._parse_ini_filenamelist(network_config.get("primary_file", "")) - file_id = self._parse_ini_stringlist( - network_config.get("file_id", "") - ) # only needed when cleanup based on fid - file_filter = self._parse_ini_stringlist(network_config.get("filter", "")) - file_crs = CRS.from_user_input(network_config.get("crs", self.crs)) - is_directed = network_config.get("directed", False) - return dict( - files=files, - file_id=file_id, - file_filter=file_filter, - file_crs=file_crs, - is_directed=is_directed, - ) - - @staticmethod - def setup_digraph_from_vector(gdf: gpd.GeoDataFrame) -> nx.DiGraph: - """Creates a simple directed graph with node and edge geometries based on a given GeoDataFrame. - - Args: - gdf (gpd.GeoDataFrame): Input GeoDataFrame containing line geometries. - Allow both LineString and MultiLineString. - - Returns: - nx.DiGraph: NetworkX graph object with "crs", "approach" as graph properties. - """ - - # simple geometry handeling - gdf = VectorNetworkWrapper.explode_and_deduplicate_geometries(gdf) - - # to graph - digraph = nx.DiGraph(crs=gdf.crs, approach="primal") - for index, row in gdf.iterrows(): - from_node = row.geometry.coords[0] - to_node = row.geometry.coords[-1] - digraph.add_node(from_node, geometry=Point(from_node)) - digraph.add_node(to_node, geometry=Point(to_node)) - digraph.add_edge( - from_node, - to_node, - geometry=row.pop( - "geometry" - ), # **row TODO: check if we do need all columns - ) - - return digraph - - @staticmethod - def setup_graph_from_vector(gdf: gpd.GeoDataFrame) -> nx.Graph: - """Creates a simple undirected graph with node and edge geometries based on a given GeoDataFrame. - - Args: - gdf (gpd.GeoDataFrame): Input GeoDataFrame containing line geometries. - Allow both LineString and MultiLineString. - - Returns: - nx.Graph: NetworkX graph object with "crs", "approach" as graph properties. 
- """ - digraph = VectorNetworkWrapper.setup_digraph_from_vector(gdf) - return digraph.to_undirected() - - @staticmethod - def setup_network_edges_and_nodes_from_graph( - graph: nx.Graph, - ) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]: - """Sets up network nodes and edges from a given graph. - - Args: - graph (nx.Graph): Input graph with geometry for nodes and edges. - Must contain "crs" as graph property. - - Returns: - gpd.GeoDataFrame: GeoDataFrame representing the network edges with "edge_fid", "node_A", and "node_B". - gpd.GeoDataFrame: GeoDataFrame representing the network nodes with "node_fid". - """ - - # TODO ths function use conventions. Good to make consistant convention with osm - nodes, edges = momepy.nx_to_gdf(graph, nodeID="node_fid") - edges["edge_fid"] = ( - edges["node_start"].astype(str) + "_" + edges["node_end"].astype(str) - ) - edges.rename( - {"node_start": "node_A", "node_end": "node_B"}, axis=1, inplace=True - ) - if not nodes.crs: - nodes.crs = graph.graph["crs"] - if not edges.crs: - edges.crs = graph.graph["crs"] - return edges, nodes - - def setup_network_from_vector( - self, - ) -> tuple[nx.MultiGraph, gpd.GeoDataFrame]: - """Sets up a network from vector files. - - Returns: - nx.MultiGraph: MultiGraph representing the graph. - gpd.GeoDataFrame: GeoDataFrame representing the network. 
- """ - files = self.network_dict["files"] - file_crs = self.network_dict["file_crs"] - is_directed = self.network_dict["is_directed"] - - gdf = self._read_vector_to_project_region_and_crs( - vector_filenames=files, crs=file_crs - ) - gdf = VectorNetworkWrapper.clean_vector(gdf) - if is_directed: - graph = VectorNetworkWrapper.setup_digraph_from_vector(gdf) - else: - graph = VectorNetworkWrapper.setup_graph_from_vector(gdf) - edges, nodes = VectorNetworkWrapper.setup_network_edges_and_nodes_from_graph( - graph - ) - graph_complex = nut.graph_from_gdf(edges, nodes, node_id="node_fid") - return graph_complex, edges - - def _read_vector_to_project_region_and_crs( - self, vector_filenames: list[Path], crs: CRS - ) -> gpd.GeoDataFrame: - """Reads a vector file or a list of vector files. - - Clips for project region and reproject to project crs if available. - Explodes multi geometry into single geometry. - - Args: - vector_filenames (list[Path]): List of Path to the vector files. - crs (CRS): Coordinate reference system for the files. Allow only one crs for all `vector_filenames`. - - Returns: - gpd.GeoDataFrame: GeoDataFrame representing the vector data. - """ - gdf = self._read_files(vector_filenames) - if gdf is None: - logger.info("no file is read.") - return None - - # set crs and reproject if needed - if not gdf.crs and crs: - gdf = gdf.set_crs(crs) - logger.info("setting crs as default EPSG:4326. specify crs if incorrect") - - if self.crs: - gdf = gdf.to_crs(self.crs) - logger.info("reproject vector file to project crs") - - # clip for region - if self.region is not None: - gdf = gpd.overlay(gdf, self.region, how="intersection", keep_geom_type=True) - logger.info("clip vector file to project region") - - # validate - if not any(gdf): - logger.warning("No vector features found within project region") - return None - - return gdf - - def _read_files(self, file_list: list[Path]) -> gpd.GeoDataFrame: - """Reads a list of files into a GeoDataFrame. 
- - Args: - file_list (list[Path]): List of file paths. - - Returns: - gpd.GeoDataFrame: GeoDataFrame representing the data. - """ - # read file - if isinstance(file_list, list): - gdf = gpd.GeoDataFrame(pd.concat([gpd.read_file(_fn) for _fn in file_list])) - logger.info("read vector files.") - else: - gdf = None - logger.info("no file is read.") - return gdf - - @staticmethod - def clean_vector(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame: - """Cleans a GeoDataFrame. - - Args: - gdf (gpd.GeoDataFrame): Input GeoDataFrame. - - Returns: - gpd.GeoDataFrame: Cleaned GeoDataFrame. - """ - - gdf = VectorNetworkWrapper.explode_and_deduplicate_geometries(gdf) - - return gdf - - @staticmethod - def explode_and_deduplicate_geometries(gpd: gpd.GeoDataFrame) -> gpd.GeoDataFrame: - """Explodes and deduplicates geometries a GeoDataFrame. - - Args: - gpd (gpd.GeoDataFrame): Input GeoDataFrame. - - Returns: - gpd.GeoDataFrame: GeoDataFrame with exploded and deduplicated geometries. - """ - gpd = gpd.explode() - gpd = gpd[ - gpd.index.isin( - gpd.geometry.apply(lambda geom: geom.wkb).drop_duplicates().index - ) - ] - return gpd diff --git a/tests/analyses/analysis_config_data/readers/test_analysis_config_reader_with_network.py b/tests/analyses/analysis_config_data/readers/test_analysis_config_reader_with_network.py index 2c9aed388..8f332b53e 100644 --- a/tests/analyses/analysis_config_data/readers/test_analysis_config_reader_with_network.py +++ b/tests/analyses/analysis_config_data/readers/test_analysis_config_reader_with_network.py @@ -1,5 +1,6 @@ from pathlib import Path from typing import Any + import pytest from ra2ce.analyses.analysis_config_data.readers.analysis_config_reader_with_network import ( diff --git a/tests/analyses/analysis_config_data/readers/test_analysis_config_reader_without_network.py b/tests/analyses/analysis_config_data/readers/test_analysis_config_reader_without_network.py index 409895a89..24055f1af 100644 --- 
a/tests/analyses/analysis_config_data/readers/test_analysis_config_reader_without_network.py +++ b/tests/analyses/analysis_config_data/readers/test_analysis_config_reader_without_network.py @@ -1,5 +1,8 @@ from pathlib import Path from typing import Any + +import pytest + from ra2ce.analyses.analysis_config_data.readers.analysis_config_reader_base import ( AnalysisConfigReaderBase, ) @@ -7,8 +10,6 @@ AnalysisConfigReaderWithoutNetwork, ) -import pytest - class TestAnalysisWithoutNetworkConfigReader: def test_initialize(self): diff --git a/tests/analyses/analysis_config_data/test_analysis_config_data_validator_with_network.py b/tests/analyses/analysis_config_data/test_analysis_config_data_validator_with_network.py index d9ba07eeb..a9ab3302f 100644 --- a/tests/analyses/analysis_config_data/test_analysis_config_data_validator_with_network.py +++ b/tests/analyses/analysis_config_data/test_analysis_config_data_validator_with_network.py @@ -1,3 +1,8 @@ +from pathlib import Path +from typing import Optional + +import pytest + from ra2ce.analyses.analysis_config_data.analysis_config_data import ( AnalysisConfigDataWithNetwork, ) @@ -6,10 +11,7 @@ ) from ra2ce.common.validation.ra2ce_validator_protocol import Ra2ceIoValidator from ra2ce.common.validation.validation_report import ValidationReport -import pytest from tests import test_data -from typing import Optional -from pathlib import Path class TestAnalysisConfigDataValidatorWithNetwork: diff --git a/tests/analyses/analysis_config_data/test_analysis_config_data_validator_without_network.py b/tests/analyses/analysis_config_data/test_analysis_config_data_validator_without_network.py index 56c0fcb85..2bb0a0120 100644 --- a/tests/analyses/analysis_config_data/test_analysis_config_data_validator_without_network.py +++ b/tests/analyses/analysis_config_data/test_analysis_config_data_validator_without_network.py @@ -9,7 +9,6 @@ from ra2ce.analyses.analysis_config_data.analysis_config_data_validator_without_network import ( 
AnalysisConfigDataValidatorWithoutNetwork, ) - from ra2ce.common.validation.validation_report import ValidationReport from tests import test_data, test_results diff --git a/tests/analyses/analysis_config_wrapper/test_analysis_config_wrapper_factory.py b/tests/analyses/analysis_config_wrapper/test_analysis_config_wrapper_factory.py index d3ba7ac27..d5ad4b13a 100644 --- a/tests/analyses/analysis_config_wrapper/test_analysis_config_wrapper_factory.py +++ b/tests/analyses/analysis_config_wrapper/test_analysis_config_wrapper_factory.py @@ -1,13 +1,13 @@ import pytest -from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_factory import ( - AnalysisConfigWrapperFactory, -) from ra2ce.analyses.analysis_config_data.analysis_config_data import ( AnalysisConfigData, AnalysisConfigDataWithNetwork, AnalysisConfigDataWithoutNetwork, ) +from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_factory import ( + AnalysisConfigWrapperFactory, +) from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_with_network import ( AnalysisConfigWrapperWithNetwork, ) diff --git a/tests/analyses/analysis_config_wrapper/test_analysis_config_wrapper_with_network.py b/tests/analyses/analysis_config_wrapper/test_analysis_config_wrapper_with_network.py index 88e75a021..4be363067 100644 --- a/tests/analyses/analysis_config_wrapper/test_analysis_config_wrapper_with_network.py +++ b/tests/analyses/analysis_config_wrapper/test_analysis_config_wrapper_with_network.py @@ -1,3 +1,4 @@ +import shutil from pathlib import Path import pytest @@ -8,7 +9,6 @@ ) from ra2ce.graph.network_config_wrapper import NetworkConfigWrapper from tests import test_data, test_results -import shutil class TestAnalysisWithNetworkConfig: diff --git a/tests/analyses/analysis_config_wrapper/test_analysis_config_wrapper_without_network.py b/tests/analyses/analysis_config_wrapper/test_analysis_config_wrapper_without_network.py index 702273fed..5bcc22d41 100644 --- 
a/tests/analyses/analysis_config_wrapper/test_analysis_config_wrapper_without_network.py +++ b/tests/analyses/analysis_config_wrapper/test_analysis_config_wrapper_without_network.py @@ -1,13 +1,12 @@ +import shutil + import pytest -from ra2ce.analyses.analysis_config_data.analysis_config_data import ( - AnalysisConfigData, -) +from ra2ce.analyses.analysis_config_data.analysis_config_data import AnalysisConfigData from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_without_network import ( AnalysisConfigWrapperWithoutNetwork, ) from tests import acceptance_test_data, test_results -import shutil class TestAnalysisWithoutNetworkConfiguration: diff --git a/tests/graph/network_wrappers/__init__.py b/tests/graph/network_wrappers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/graph/osm_network_wrapper/test_osm_network_wrapper.py b/tests/graph/network_wrappers/test_osm_network_wrapper.py similarity index 89% rename from tests/graph/osm_network_wrapper/test_osm_network_wrapper.py rename to tests/graph/network_wrappers/test_osm_network_wrapper.py index 4661f933f..1dde51df2 100644 --- a/tests/graph/osm_network_wrapper/test_osm_network_wrapper.py +++ b/tests/graph/network_wrappers/test_osm_network_wrapper.py @@ -1,4 +1,5 @@ from pathlib import Path + import networkx as nx import pytest from networkx import Graph, MultiDiGraph @@ -6,26 +7,42 @@ from shapely.geometry import LineString, Polygon from shapely.geometry.base import BaseGeometry -from tests import test_data, slow_test import ra2ce.graph.networks_utils as nut -from ra2ce.graph.osm_network_wrapper.osm_network_wrapper import OsmNetworkWrapper +from ra2ce.graph.network_config_data.network_config_data import ( + NetworkConfigData, + NetworkSection, +) +from ra2ce.graph.network_wrappers.network_wrapper_protocol import NetworkWrapperProtocol +from ra2ce.graph.network_wrappers.osm_network_wrapper.osm_network_wrapper import ( + OsmNetworkWrapper, +) +from tests import slow_test, 
test_data, test_results class TestOsmNetworkWrapper: def test_initialize_without_graph_crs(self): - _wrapper = OsmNetworkWrapper("a_network", ["r"], "", Path()) + # 1. Define test data. + _network_section = NetworkSection(network_type="a_network", road_types=["r"]) + _network_config_data = NetworkConfigData(network=_network_section) + + # 2. Run test. + _wrapper = OsmNetworkWrapper(_network_config_data) + + # 3. Verify final expectations. assert isinstance(_wrapper, OsmNetworkWrapper) - assert _wrapper.graph_crs == "epsg:4326" + assert isinstance(_wrapper, NetworkWrapperProtocol) + assert _wrapper.graph_crs.to_epsg() == 4326 @pytest.fixture def _network_wrapper_without_polygon(self) -> OsmNetworkWrapper: - _network_type = "drive" - _road_types = ["road_link"] + _network_section = NetworkSection( + network_type="drive", road_types=["road_link"], directed=True + ) + _output_dir = test_results.joinpath("test_osm_network_wrapper") + if not _output_dir.exists(): + _output_dir.mkdir(parents=True) yield OsmNetworkWrapper( - network_type=_network_type, - road_types=_road_types, - graph_crs="", - polygon_path=None, + NetworkConfigData(network=_network_section, output_path=_output_dir) ) def test_download_clean_graph_from_osm_with_invalid_polygon_arg( diff --git a/tests/graph/test_osm_utils.py b/tests/graph/network_wrappers/test_osm_utils.py similarity index 95% rename from tests/graph/test_osm_utils.py rename to tests/graph/network_wrappers/test_osm_utils.py index 67303be27..dc2f31c4c 100644 --- a/tests/graph/test_osm_utils.py +++ b/tests/graph/network_wrappers/test_osm_utils.py @@ -2,7 +2,9 @@ import pytest -from ra2ce.graph.osm_utils import from_shapefile_to_poly +from ra2ce.graph.network_wrappers.osm_network_wrapper.osm_utils import ( + from_shapefile_to_poly, +) from tests import test_data, test_results diff --git a/tests/graph/network_wrappers/test_shp_network_wrapper.py b/tests/graph/network_wrappers/test_shp_network_wrapper.py new file mode 100644 index 
000000000..56a7c79fc --- /dev/null +++ b/tests/graph/network_wrappers/test_shp_network_wrapper.py @@ -0,0 +1,19 @@ +from ra2ce.graph.network_config_data.network_config_data import ( + NetworkConfigData, +) +from ra2ce.graph.network_wrappers.network_wrapper_protocol import NetworkWrapperProtocol +from ra2ce.graph.network_wrappers.shp_network_wrapper import ShpNetworkWrapper + + +class TestShpNetworkWrapper: + def test_init(self): + # 1.Define test data. + _config_data = NetworkConfigData() + + # 2. Run test. + _wrapper = ShpNetworkWrapper(_config_data) + + # 3. Verify expectations. + assert isinstance(_wrapper, ShpNetworkWrapper) + assert isinstance(_wrapper, NetworkWrapperProtocol) + assert _wrapper.crs.to_epsg() == 4326 diff --git a/tests/graph/network_wrappers/test_trails_network_wrapper.py b/tests/graph/network_wrappers/test_trails_network_wrapper.py new file mode 100644 index 000000000..58c0fbbee --- /dev/null +++ b/tests/graph/network_wrappers/test_trails_network_wrapper.py @@ -0,0 +1,16 @@ +from ra2ce.graph.network_config_data.network_config_data import NetworkConfigData +from ra2ce.graph.network_wrappers.network_wrapper_protocol import NetworkWrapperProtocol +from ra2ce.graph.network_wrappers.trails_network_wrapper import TrailsNetworkWrapper + + +class TestTrailsNetworkWrapper: + def test_initialize(self): + # 1. Define test data. + _config_data = NetworkConfigData() + + # 2. Create wrapper. + _wrapper = TrailsNetworkWrapper(_config_data) + + # 3. Verify expectations. 
+ assert isinstance(_wrapper, TrailsNetworkWrapper) + assert isinstance(_wrapper, NetworkWrapperProtocol) diff --git a/tests/graph/network_wrappers/test_vector_network_wrapper.py b/tests/graph/network_wrappers/test_vector_network_wrapper.py new file mode 100644 index 000000000..1bbfb9a83 --- /dev/null +++ b/tests/graph/network_wrappers/test_vector_network_wrapper.py @@ -0,0 +1,173 @@ +from pathlib import Path + +import geopandas as gpd +import networkx as nx +import pytest +from pyproj import CRS +from shapely.geometry import LineString, MultiLineString, Point + +from ra2ce.graph.network_config_data.network_config_data import NetworkConfigData +from ra2ce.graph.network_wrappers.network_wrapper_protocol import NetworkWrapperProtocol +from ra2ce.graph.network_wrappers.vector_network_wrapper import VectorNetworkWrapper +from tests import test_data + +_test_dir = test_data / "vector_network_wrapper" + + +class TestVectorNetworkWrapper: + @pytest.fixture + def points_gdf(self) -> gpd.GeoDataFrame: + points = [Point(-122.3, 47.6), Point(-122.2, 47.5), Point(-122.1, 47.6)] + return gpd.GeoDataFrame(geometry=points, crs=4326) + + @pytest.fixture + def lines_gdf(self) -> gpd.GeoDataFrame: + points = [Point(-122.3, 47.6), Point(-122.2, 47.5), Point(-122.1, 47.6)] + lines = [LineString([points[0], points[1]]), LineString([points[1], points[2]])] + return gpd.GeoDataFrame(geometry=lines, crs=4326) + + @pytest.fixture + def mock_graph(self): + points = [(-122.3, 47.6), (-122.2, 47.5), (-122.1, 47.6)] + lines = [(points[0], points[1]), (points[1], points[2])] + + graph = nx.Graph(crs=4326) + for point in points: + graph.add_node(point, geometry=Point(point)) + for line in lines: + graph.add_edge( + line[0], line[1], geometry=LineString([Point(line[0]), Point(line[1])]) + ) + + return graph + + def test_initialize(self): + # 1. Define test data. 
+ _config_data = NetworkConfigData() + _config_data.network.primary_file = [Path("dummy_primary")] + _config_data.network.directed = False + _config_data.origins_destinations.region = Path("dummy_region") + + # 2. Run test. + _wrapper = VectorNetworkWrapper(_config_data) + + # 3. Verify expectations. + assert isinstance(_wrapper, VectorNetworkWrapper) + assert isinstance(_wrapper, NetworkWrapperProtocol) + assert _wrapper.primary_files == _config_data.network.primary_file + assert _wrapper.region_path == _config_data.origins_destinations.region + assert _wrapper.crs.to_epsg() == 4326 + + @pytest.fixture + def _valid_wrapper(self) -> VectorNetworkWrapper: + _network_dir = _test_dir.joinpath("static", "network") + _config_data = NetworkConfigData() + _config_data.network.primary_file = [ + _network_dir.joinpath("_test_lines.geojson") + ] + _config_data.network.directed = False + _config_data.origins_destinations.region = None + _config_data.crs = CRS.from_user_input(4326) + yield VectorNetworkWrapper( + config_data=_config_data, + ) + + def test_read_vector_to_project_region_and_crs( + self, _valid_wrapper: VectorNetworkWrapper + ): + # Given + assert not _valid_wrapper.directed + + # When + vector = _valid_wrapper._read_vector_to_project_region_and_crs() + + # Then + assert isinstance(vector, gpd.GeoDataFrame) + + def test_read_vector_to_project_region_and_crs_with_region( + self, _valid_wrapper: VectorNetworkWrapper + ): + # Given + _valid_wrapper.region_path = _test_dir / "_test_polygon.geojson" + _expected_region = gpd.read_file(_valid_wrapper.region_path) + assert isinstance(_expected_region, gpd.GeoDataFrame) + + # When + vector = _valid_wrapper._read_vector_to_project_region_and_crs() + + # Then + assert vector.crs == _expected_region.crs + assert _expected_region.covers(vector.unary_union).all() + + @pytest.mark.parametrize( + "region_path", + [ + pytest.param(None, id="No region"), + pytest.param(_test_dir / "_test_polygon.geojson", id="With region"), + ], + 
) + def test_get_network_from_vector( + self, _valid_wrapper: VectorNetworkWrapper, region_path: Path + ): + # Given + _valid_wrapper.region_path = region_path + + # When + graph, edges = _valid_wrapper.get_network() + + # Then + assert isinstance(graph, nx.MultiGraph) + assert isinstance(edges, gpd.GeoDataFrame) + + def test_clean_vector(self, lines_gdf: gpd.GeoDataFrame): + # Given + gdf1 = VectorNetworkWrapper.explode_and_deduplicate_geometries(lines_gdf) + + # When + gdf2 = VectorNetworkWrapper.clean_vector( + lines_gdf + ) # for now cleanup only does the above + + # Then + assert gdf1.equals(gdf2) + + def test_get_indirect_graph_from_vector(self, lines_gdf: gpd.GeoDataFrame): + # When + graph = VectorNetworkWrapper.get_indirect_graph_from_vector(lines_gdf) + + # Then + assert graph.nodes(data="geometry") is not None + assert graph.edges(data="geometry") is not None + assert graph.graph["crs"] == lines_gdf.crs + assert isinstance(graph, nx.Graph) and not isinstance(graph, nx.DiGraph) + + def test_get_direct_graph_from_vector(self, lines_gdf: gpd.GeoDataFrame): + # When + graph = VectorNetworkWrapper.get_direct_graph_from_vector(lines_gdf) + + # Then + assert isinstance(graph, nx.DiGraph) + + def test_get_network_edges_and_nodes_from_graph( + self, mock_graph, points_gdf, lines_gdf + ): + # When + edges, nodes = VectorNetworkWrapper.get_network_edges_and_nodes_from_graph( + mock_graph + ) + + # Then + assert edges.geometry.equals(lines_gdf.geometry) + assert nodes.geometry.equals(points_gdf.geometry) + assert set(["node_A", "node_B", "edge_fid"]).issubset(edges.columns) + assert set(["node_fid"]).issubset(nodes.columns) + + def test_explode_and_deduplicate_geometries(self, lines_gdf): + # Given + multi_lines = lines_gdf.geometry.apply(lambda x: MultiLineString([x])) + + # When + gdf = VectorNetworkWrapper.explode_and_deduplicate_geometries(multi_lines) + + # Then + assert isinstance(gdf.geometry.iloc[0], LineString) diff --git 
a/tests/graph/test_vector_network_wrapper.py b/tests/graph/test_vector_network_wrapper.py deleted file mode 100644 index e34a1a879..000000000 --- a/tests/graph/test_vector_network_wrapper.py +++ /dev/null @@ -1,243 +0,0 @@ -import pytest - -from pathlib import Path - -import geopandas as gpd -import networkx as nx -from shapely.geometry import LineString, Point, MultiLineString - -from tests import test_data -from ra2ce.graph.vector_network_wrapper import VectorNetworkWrapper - -_test_dir = test_data / "vector_network_wrapper" - - -class TestVectorNetworkWrapper: - @pytest.fixture - def _config_fixture(self) -> dict: - yield { - "project": { - "name": "test", - "crs": 4326, - }, - "network": { - "directed": False, - "source": "shapefile", - "primary_file": "_test_lines.geojson", - "diversion_file": None, - "file_id": "fid", - "polygon": None, - "network_type": None, - "road_types": None, - "save_shp": False, - }, - "static": _test_dir / "static", - "output": _test_dir / "output", - } - - @pytest.fixture - def points_gdf(self): - points = [Point(-122.3, 47.6), Point(-122.2, 47.5), Point(-122.1, 47.6)] - return gpd.GeoDataFrame(geometry=points, crs=4326) - - @pytest.fixture - def lines_gdf(self): - points = [Point(-122.3, 47.6), Point(-122.2, 47.5), Point(-122.1, 47.6)] - lines = [LineString([points[0], points[1]]), LineString([points[1], points[2]])] - return gpd.GeoDataFrame(geometry=lines, crs=4326) - - @pytest.fixture - def mock_graph(self): - points = [(-122.3, 47.6), (-122.2, 47.5), (-122.1, 47.6)] - lines = [(points[0], points[1]), (points[1], points[2])] - - graph = nx.Graph(crs=4326) - for point in points: - graph.add_node(point, geometry=Point(point)) - for line in lines: - graph.add_edge( - line[0], line[1], geometry=LineString([Point(line[0]), Point(line[1])]) - ) - - return graph - - @pytest.mark.parametrize( - "config", - [ - pytest.param(None, id="NONE as dictionary"), - pytest.param({}, id="Empty dictionary"), - pytest.param({"network": {}}, id='Empty 
"network" in Config'), - pytest.param({"network": "string"}, id='Invalid "network" type in Config'), - ], - ) - def test_init(self, config: dict): - with pytest.raises(ValueError) as exc_err: - VectorNetworkWrapper(config=config) - assert str(exc_err.value) in [ - "Config cannot be None", - "A network dictionary is required for creating a VectorNetworkWrapper object.", - 'Config["network"] should be a dictionary', - ] - - def test_parse_ini_stringlist_with_comma_separated_string(self, _config_fixture): - # Given - test_wrapper = VectorNetworkWrapper(_config_fixture) - ini_value = "a,b,c" - - # When - result = test_wrapper._parse_ini_stringlist(ini_value) - - # Then - assert result == ["a", "b", "c"] - - def test_parse_ini_stringlist_with_single_string(self, _config_fixture): - # Given - test_wrapper = VectorNetworkWrapper(_config_fixture) - ini_value = "abc" - - # When - result = test_wrapper._parse_ini_stringlist(ini_value) - - # Then - assert result == "abc" - - def test_parse_ini_stringlist_with_empty_string(self, _config_fixture): - # Given - test_wrapper = VectorNetworkWrapper(_config_fixture) - ini_value = "" - - # When - result = test_wrapper._parse_ini_stringlist(ini_value) - - # Then - assert result is None - - def test_parse_ini_filenamelist(self, _config_fixture): - # Given - test_wrapper = VectorNetworkWrapper(_config_fixture) - ini_value = "_test_lines.geojson, dummy.geojson" - - # When - file_paths = test_wrapper._parse_ini_filenamelist(ini_value) - - # Then - assert file_paths[0].is_file() - - def test_setup_global(self, _config_fixture): - test_wrapper = VectorNetworkWrapper(_config_fixture) - test_wrapper._setup_global(_config_fixture) - assert test_wrapper.name == "test" - assert test_wrapper.region is None - assert test_wrapper.crs.to_epsg() == 4326 - assert test_wrapper.input_path == _test_dir / "static/network" - assert test_wrapper.output_path == _test_dir / "output" - - def test_get_network_opt(self, _config_fixture): - test_wrapper = 
VectorNetworkWrapper(_config_fixture) - network_dict = test_wrapper._get_network_opt(_config_fixture["network"]) - assert network_dict["files"][0].is_file() - assert network_dict["file_id"] == "fid" - assert network_dict["file_crs"].to_epsg() == 4326 - assert network_dict["is_directed"] is False - - def test_read_vector_to_project_region_and_crs(self, _config_fixture): - # Given - test_wrapper = VectorNetworkWrapper(_config_fixture) - files = test_wrapper.network_dict["files"] - file_crs = test_wrapper.network_dict["file_crs"] - - # When - vector = test_wrapper._read_vector_to_project_region_and_crs(files, file_crs) - - # Then - assert isinstance(vector, gpd.GeoDataFrame) - - def test_read_vector_to_project_region_and_crs_with_region(self, _config_fixture): - # Given - _config_fixture["project"]["region"] = _test_dir / "_test_polygon.geojson" - test_wrapper = VectorNetworkWrapper(_config_fixture) - files = test_wrapper.network_dict["files"] - file_crs = test_wrapper.network_dict["file_crs"] - - # When - vector = test_wrapper._read_vector_to_project_region_and_crs(files, file_crs) - - # Then - assert vector.crs == test_wrapper.region.crs - assert test_wrapper.region.covers(vector.unary_union).all() - - def test_setup_network_from_vector(self, _config_fixture): - # Given - test_wrapper = VectorNetworkWrapper(_config_fixture) - - # When - graph, edges = test_wrapper.setup_network_from_vector() - - # Then - assert isinstance(graph, nx.MultiGraph) - assert isinstance(edges, gpd.GeoDataFrame) - - def test_setup_network_from_vector_with_region(self, _config_fixture): - # Given - _config_fixture["project"]["region"] = _test_dir / "_test_polygon.geojson" - test_wrapper = VectorNetworkWrapper(_config_fixture) - - # When - graph, edges = test_wrapper.setup_network_from_vector() - - # Then - assert isinstance(graph, nx.MultiGraph) - assert isinstance(edges, gpd.GeoDataFrame) - - def test_clean_vector(self, lines_gdf): - # Given - gdf1 = 
VectorNetworkWrapper.explode_and_deduplicate_geometries(lines_gdf) - - # When - gdf2 = VectorNetworkWrapper.clean_vector( - lines_gdf - ) # for now cleanup only does the above - - # Then - assert gdf1.equals(gdf2) - - def test_setup_graph_from_vector(self, lines_gdf): - # When - graph = VectorNetworkWrapper.setup_graph_from_vector(lines_gdf) - - # Then - assert graph.nodes(data="geometry") is not None - assert graph.edges(data="geometry") is not None - assert graph.graph["crs"] == lines_gdf.crs - assert isinstance(graph, nx.Graph) and not isinstance(graph, nx.DiGraph) - - def test_setup_digraph_from_vector(self, lines_gdf): - # When - graph = VectorNetworkWrapper.setup_digraph_from_vector(lines_gdf) - - # Then - assert isinstance(graph, nx.DiGraph) - - def test_setup_network_edges_and_nodes_from_graph( - self, mock_graph, points_gdf, lines_gdf - ): - # When - edges, nodes = VectorNetworkWrapper.setup_network_edges_and_nodes_from_graph( - mock_graph - ) - - # Then - assert edges.geometry.equals(lines_gdf.geometry) - assert nodes.geometry.equals(points_gdf.geometry) - assert set(["node_A", "node_B", "edge_fid"]).issubset(edges.columns) - assert set(["node_fid"]).issubset(nodes.columns) - - def test_explode_and_deduplicate_geometries(self, lines_gdf): - # Given - multi_lines = lines_gdf.geometry.apply(lambda x: MultiLineString([x])) - - # When - gdf = VectorNetworkWrapper.explode_and_deduplicate_geometries(multi_lines) - - # Then - assert isinstance(gdf.geometry.iloc[0], LineString) diff --git a/tests/test_acceptance.py b/tests/test_acceptance.py index ea4116380..d81cd7d1b 100644 --- a/tests/test_acceptance.py +++ b/tests/test_acceptance.py @@ -14,14 +14,13 @@ _analysis_ini_name = "analyses.ini" _base_graph_p_filename = "base_graph.p" _base_network_feather_filename = "base_network.feather" +_skip_cases = [] def get_external_test_cases() -> list[pytest.param]: if not test_external_data.exists(): return [] - _skip_cases = ["bolivia"] - def get_pytest_param(test_dir: 
Path) -> pytest.param: _marks = [external_test] if test_dir.stem.lower() in _skip_cases: