diff --git a/ra2ce/analyses/analysis_config_data/analysis_config_data.py b/ra2ce/analyses/analysis_config_data/analysis_config_data.py
index 26ad15bca..b662d296c 100644
--- a/ra2ce/analyses/analysis_config_data/analysis_config_data.py
+++ b/ra2ce/analyses/analysis_config_data/analysis_config_data.py
@@ -21,6 +21,7 @@
from __future__ import annotations
+
from ra2ce.common.configuration.config_data_protocol import ConfigDataProtocol
diff --git a/ra2ce/analyses/analysis_config_data/analysis_config_data_validator_without_network.py b/ra2ce/analyses/analysis_config_data/analysis_config_data_validator_without_network.py
index c3fd95cd7..f9c03d842 100644
--- a/ra2ce/analyses/analysis_config_data/analysis_config_data_validator_without_network.py
+++ b/ra2ce/analyses/analysis_config_data/analysis_config_data_validator_without_network.py
@@ -21,17 +21,16 @@
from pathlib import Path
+
from ra2ce.analyses.analysis_config_data.analysis_config_data import (
AnalysisConfigDataWithoutNetwork,
)
-
from ra2ce.common.validation.ra2ce_validator_protocol import Ra2ceIoValidator
from ra2ce.common.validation.validation_report import ValidationReport
from ra2ce.graph.network_config_data.network_config_data_validator import (
NetworkDictValues,
)
-
IndirectAnalysisNameList: list[str] = [
"single_link_redundancy",
"multi_link_redundancy",
diff --git a/ra2ce/analyses/analysis_config_data/readers/analysis_config_reader_factory.py b/ra2ce/analyses/analysis_config_data/readers/analysis_config_reader_factory.py
index 0b997fda2..304104a06 100644
--- a/ra2ce/analyses/analysis_config_data/readers/analysis_config_reader_factory.py
+++ b/ra2ce/analyses/analysis_config_data/readers/analysis_config_reader_factory.py
@@ -23,9 +23,6 @@
from pathlib import Path
from typing import Optional
-from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_base import (
- AnalysisConfigWrapperBase,
-)
from ra2ce.analyses.analysis_config_data.readers.analysis_config_reader_base import (
AnalysisConfigReaderBase,
)
@@ -35,6 +32,9 @@
from ra2ce.analyses.analysis_config_data.readers.analysis_config_reader_without_network import (
AnalysisConfigReaderWithoutNetwork,
)
+from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_base import (
+ AnalysisConfigWrapperBase,
+)
from ra2ce.graph.network_config_data.network_config_data import NetworkConfigData
diff --git a/ra2ce/analyses/analysis_config_data/readers/analysis_config_reader_with_network.py b/ra2ce/analyses/analysis_config_data/readers/analysis_config_reader_with_network.py
index 8603ecbce..48df087a7 100644
--- a/ra2ce/analyses/analysis_config_data/readers/analysis_config_reader_with_network.py
+++ b/ra2ce/analyses/analysis_config_data/readers/analysis_config_reader_with_network.py
@@ -22,15 +22,15 @@
from pathlib import Path
-from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_base import (
- AnalysisConfigWrapperBase,
-)
from ra2ce.analyses.analysis_config_data.analysis_config_data import (
AnalysisConfigDataWithNetwork,
)
from ra2ce.analyses.analysis_config_data.readers.analysis_config_reader_base import (
AnalysisConfigReaderBase,
)
+from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_base import (
+ AnalysisConfigWrapperBase,
+)
from ra2ce.graph.network_config_wrapper import NetworkConfigWrapper
diff --git a/ra2ce/analyses/analysis_config_data/readers/analysis_config_reader_without_network.py b/ra2ce/analyses/analysis_config_data/readers/analysis_config_reader_without_network.py
index b4114e322..aac67030e 100644
--- a/ra2ce/analyses/analysis_config_data/readers/analysis_config_reader_without_network.py
+++ b/ra2ce/analyses/analysis_config_data/readers/analysis_config_reader_without_network.py
@@ -23,15 +23,15 @@
import logging
from pathlib import Path
-from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_base import (
- AnalysisConfigWrapperBase,
-)
from ra2ce.analyses.analysis_config_data.analysis_config_data import (
AnalysisConfigDataWithoutNetwork,
)
from ra2ce.analyses.analysis_config_data.readers.analysis_config_reader_base import (
AnalysisConfigReaderBase,
)
+from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_base import (
+ AnalysisConfigWrapperBase,
+)
from ra2ce.graph.network_config_data.network_config_data_reader import (
NetworkConfigDataReader,
)
diff --git a/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_base.py b/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_base.py
index fafd2f6f8..0c6e5838d 100644
--- a/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_base.py
+++ b/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_base.py
@@ -20,7 +20,7 @@
"""
-from abc import abstractmethod, abstractclassmethod
+from abc import abstractclassmethod, abstractmethod
from pathlib import Path
from typing import Optional
diff --git a/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_factory.py b/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_factory.py
index 6d72124f9..b73cf46a3 100644
--- a/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_factory.py
+++ b/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_factory.py
@@ -23,14 +23,14 @@
from pathlib import Path
from typing import Optional
-from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_base import (
- AnalysisConfigWrapperBase,
-)
from ra2ce.analyses.analysis_config_data.analysis_config_data import (
AnalysisConfigData,
AnalysisConfigDataWithNetwork,
AnalysisConfigDataWithoutNetwork,
)
+from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_base import (
+ AnalysisConfigWrapperBase,
+)
from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_with_network import (
AnalysisConfigWrapperWithNetwork,
)
diff --git a/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_with_network.py b/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_with_network.py
index 093cd0968..11de109af 100644
--- a/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_with_network.py
+++ b/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_with_network.py
@@ -23,12 +23,14 @@
from __future__ import annotations
from pathlib import Path
-from ra2ce.analyses.analysis_config_data.analysis_config_data_validator_with_network import AnalysisConfigDataValidatorWithNetwork
+from ra2ce.analyses.analysis_config_data.analysis_config_data import AnalysisConfigData
+from ra2ce.analyses.analysis_config_data.analysis_config_data_validator_with_network import (
+ AnalysisConfigDataValidatorWithNetwork,
+)
from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_base import (
AnalysisConfigWrapperBase,
)
-from ra2ce.analyses.analysis_config_data.analysis_config_data import AnalysisConfigData
from ra2ce.graph.network_config_data.network_config_data import NetworkConfigData
from ra2ce.graph.network_config_wrapper import NetworkConfigWrapper
@@ -98,5 +100,7 @@ def configure(self) -> None:
def is_valid(self) -> bool:
_file_is_valid = self.ini_file.is_file() and self.ini_file.suffix == ".ini"
- _validation_report = AnalysisConfigDataValidatorWithNetwork(self.config_data).validate()
+ _validation_report = AnalysisConfigDataValidatorWithNetwork(
+ self.config_data
+ ).validate()
return _file_is_valid and _validation_report.is_valid()
diff --git a/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_without_network.py b/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_without_network.py
index f948d769e..b16c9d4fb 100644
--- a/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_without_network.py
+++ b/ra2ce/analyses/analysis_config_wrapper/analysis_config_wrapper_without_network.py
@@ -24,12 +24,14 @@
import logging
from pathlib import Path
-from ra2ce.analyses.analysis_config_data.analysis_config_data_validator_without_network import AnalysisConfigDataValidatorWithoutNetwork
+from ra2ce.analyses.analysis_config_data.analysis_config_data import AnalysisConfigData
+from ra2ce.analyses.analysis_config_data.analysis_config_data_validator_without_network import (
+ AnalysisConfigDataValidatorWithoutNetwork,
+)
from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_base import (
AnalysisConfigWrapperBase,
)
-from ra2ce.analyses.analysis_config_data.analysis_config_data import AnalysisConfigData
from ra2ce.graph.network_config_wrapper import NetworkConfigWrapper
@@ -77,5 +79,7 @@ def configure(self) -> None:
def is_valid(self) -> bool:
_file_is_valid = self.ini_file.is_file() and self.ini_file.suffix == ".ini"
- _validation_report = AnalysisConfigDataValidatorWithoutNetwork(self.config_data).validate()
+ _validation_report = AnalysisConfigDataValidatorWithoutNetwork(
+ self.config_data
+ ).validate()
return _file_is_valid and _validation_report.is_valid()
diff --git a/ra2ce/common/configuration/config_data_protocol.py b/ra2ce/common/configuration/config_data_protocol.py
index 5960901dc..4f505f126 100644
--- a/ra2ce/common/configuration/config_data_protocol.py
+++ b/ra2ce/common/configuration/config_data_protocol.py
@@ -33,4 +33,4 @@ def to_dict(self) -> dict:
Returns:
dict: Dictionary representing the `ConfigDataProtocol` instance.
"""
- pass
\ No newline at end of file
+ pass
diff --git a/ra2ce/common/configuration/ini_configuration_reader_protocol.py b/ra2ce/common/configuration/ini_configuration_reader_protocol.py
index 1c081257a..edf92c54a 100644
--- a/ra2ce/common/configuration/ini_configuration_reader_protocol.py
+++ b/ra2ce/common/configuration/ini_configuration_reader_protocol.py
@@ -21,6 +21,7 @@
from pathlib import Path
from typing import Protocol, runtime_checkable
+
from ra2ce.common.configuration.config_data_protocol import ConfigDataProtocol
from ra2ce.common.io.readers.file_reader_protocol import FileReaderProtocol
diff --git a/ra2ce/configuration/config_factory.py b/ra2ce/configuration/config_factory.py
index 956fb3dd8..d3da83cb7 100644
--- a/ra2ce/configuration/config_factory.py
+++ b/ra2ce/configuration/config_factory.py
@@ -20,19 +20,19 @@
"""
-from pathlib import Path
import shutil
+from pathlib import Path
from typing import Optional
+from ra2ce.analyses.analysis_config_data.readers.analysis_config_reader_factory import (
+ AnalysisConfigReaderFactory,
+)
from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_base import (
AnalysisConfigWrapperBase,
)
from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_factory import (
AnalysisConfigWrapperFactory,
)
-from ra2ce.analyses.analysis_config_data.readers.analysis_config_reader_factory import (
- AnalysisConfigReaderFactory,
-)
from ra2ce.configuration.config_wrapper import ConfigWrapper
from ra2ce.graph.network_config_data.network_config_data import NetworkConfigData
from ra2ce.graph.network_config_data.network_config_data_reader import (
diff --git a/ra2ce/graph/network_config_data/network_config_data.py b/ra2ce/graph/network_config_data/network_config_data.py
index e0f826753..74ab33976 100644
--- a/ra2ce/graph/network_config_data/network_config_data.py
+++ b/ra2ce/graph/network_config_data/network_config_data.py
@@ -25,6 +25,8 @@
from pathlib import Path
from typing import Optional
+from pyproj import CRS
+
from ra2ce.common.configuration.config_data_protocol import ConfigDataProtocol
@@ -37,10 +39,10 @@ class ProjectSection:
class NetworkSection:
directed: bool = False
source: str = "" # should be enum
- primary_file: str = "" # TODO. Unclear whether this is `Path` or `list[Path]`
- diversion_file: str = "" # TODO. Unclear whether this is `Path` or `list[Path]`
+ primary_file: list[Path] = field(default_factory=list)
+ diversion_file: list[Path] = field(default_factory=list)
file_id: str = ""
- polygon: str = "" # TODO. Unclear whether this is `str`` or `Path`
+ polygon: Optional[Path] = None
network_type: str = "" # Should be enum
road_types: list[str] = field(default_factory=list)
save_shp: bool = False
@@ -90,7 +92,8 @@ class NetworkConfigData(ConfigDataProtocol):
input_path: Optional[Path] = None
output_path: Optional[Path] = None
static_path: Optional[Path] = None
-
+ # CRS is not yet supported in the ini file, it might be relocated to a subsection.
+ crs: CRS = field(default_factory=lambda: CRS.from_user_input(4326))
project: ProjectSection = field(default_factory=lambda: ProjectSection())
network: NetworkSection = field(default_factory=lambda: NetworkSection())
origins_destinations: OriginsDestinationsSection = field(
@@ -106,8 +109,15 @@ def output_graph_dir(self) -> Optional[Path]:
return None
return self.static_path.joinpath("output_graph")
+ @property
+ def network_dir(self) -> Optional[Path]:
+ if not self.static_path:
+ return None
+ return self.static_path.joinpath("network")
+
def to_dict(self) -> dict:
_dict = self.__dict__
+ _dict["crs"] = self.crs.to_epsg()
_dict["project"] = self.project.__dict__
_dict["network"] = self.network.__dict__
_dict["origins_destinations"] = self.origins_destinations.__dict__
diff --git a/ra2ce/graph/network_config_data/network_config_data_reader.py b/ra2ce/graph/network_config_data/network_config_data_reader.py
index 14ce57685..8cc0347e6 100644
--- a/ra2ce/graph/network_config_data/network_config_data_reader.py
+++ b/ra2ce/graph/network_config_data/network_config_data_reader.py
@@ -1,6 +1,6 @@
from configparser import ConfigParser
from pathlib import Path
-from typing import Union
+from typing import Any, Union
from ra2ce.common.configuration.ini_configuration_reader_protocol import (
ConfigDataReaderProtocol,
@@ -58,6 +58,15 @@ def _select_to_correct(path_value: Union[list[Path], Path]) -> bool:
return _select_to_correct(path_value[0])
return not path_value.exists()
+ def _correct_list(path_root: Path, path_value_list: list[Path]) -> list[Path]:
+ _corrected_list = []
+ for _path_value in path_value_list:
+ if not _path_value.exists():
+ _corrected_list.append(path_root.joinpath(_path_value))
+ else:
+ _corrected_list.append(_path_value)
+ return _corrected_list
+
# Relative to network directory.
_network_directory = config_data.static_path.joinpath("network")
if _select_to_correct(config_data.origins_destinations.origins):
@@ -70,15 +79,28 @@ def _select_to_correct(path_value: Union[list[Path], Path]) -> bool:
config_data.origins_destinations.destinations
)
+ if _select_to_correct(config_data.origins_destinations.region):
+ config_data.origins_destinations.region = _network_directory.joinpath(
+ config_data.origins_destinations.region
+ )
+
+ if _select_to_correct(config_data.network.polygon):
+ config_data.network.polygon = _network_directory.joinpath(
+ config_data.network.polygon
+ )
+
+ config_data.network.primary_file = _correct_list(
+ _network_directory, config_data.network.primary_file
+ )
+ config_data.network.diversion_file = _correct_list(
+ _network_directory, config_data.network.diversion_file
+ )
+
# Relative to hazard directory.
_hazard_directory = config_data.static_path.joinpath("hazard")
- if _select_to_correct(config_data.hazard.hazard_map):
- config_data.hazard.hazard_map = list(
- map(
- lambda x: _hazard_directory.joinpath(x),
- config_data.hazard.hazard_map,
- )
- )
+ config_data.hazard.hazard_map = _correct_list(
+ _hazard_directory, config_data.hazard.hazard_map
+ )
def _get_str_as_path(self, str_value: Union[str, Path]) -> Path:
if str_value and not isinstance(str_value, Path):
@@ -107,9 +129,23 @@ def _get_sections(self) -> dict:
def get_project_section(self) -> ProjectSection:
return ProjectSection(**self._parser["project"])
+ def _get_path_list(
+ self, section_name: str, property: str, fallback_opt: Any
+ ) -> list[Path]:
+ _value_list = self._parser.getlist(
+ section_name, property, fallback=fallback_opt
+ )
+ return list(map(self._get_str_as_path, _value_list))
+
def get_network_section(self) -> NetworkSection:
_section = "network"
_network_section = NetworkSection(**self._parser[_section])
+ _network_section.primary_file = self._get_path_list(
+ _section, "primary_file", _network_section.primary_file
+ )
+ _network_section.diversion_file = self._get_path_list(
+ _section, "diversion_file", _network_section.diversion_file
+ )
_network_section.directed = self._parser.getboolean(
_section, "directed", fallback=_network_section.directed
)
@@ -119,6 +155,7 @@ def get_network_section(self) -> NetworkSection:
_network_section.road_types = self._parser.getlist(
_section, "road_types", fallback=_network_section.road_types
)
+ _network_section.polygon = self._get_str_as_path(_network_section.polygon)
return _network_section
def get_origins_destinations_section(self) -> OriginsDestinationsSection:
diff --git a/ra2ce/graph/network_config_data/network_config_data_validator.py b/ra2ce/graph/network_config_data/network_config_data_validator.py
index f411c7ca9..1fe334168 100644
--- a/ra2ce/graph/network_config_data/network_config_data_validator.py
+++ b/ra2ce/graph/network_config_data/network_config_data_validator.py
@@ -20,6 +20,7 @@
"""
from typing import Any
+
from ra2ce.common.validation.ra2ce_validator_protocol import Ra2ceIoValidator
from ra2ce.common.validation.validation_report import ValidationReport
from ra2ce.graph.network_config_data.network_config_data import (
diff --git a/ra2ce/graph/network_config_wrapper.py b/ra2ce/graph/network_config_wrapper.py
index a25ef1d04..03e01168b 100644
--- a/ra2ce/graph/network_config_wrapper.py
+++ b/ra2ce/graph/network_config_wrapper.py
@@ -37,6 +37,7 @@
)
from ra2ce.graph.networks import Network
+
class NetworkConfigWrapper(ConfigWrapperProtocol):
files: Dict[str, Path] = {}
config_data: NetworkConfigData
diff --git a/ra2ce/graph/network_wrappers/README.md b/ra2ce/graph/network_wrappers/README.md
new file mode 100644
index 000000000..bc04349e7
--- /dev/null
+++ b/ra2ce/graph/network_wrappers/README.md
@@ -0,0 +1,8 @@
+# Network wrappers
+
+In this module we define wrappers for reading a ra2ce network (`tuple[MultiGraph, GeoDataFrame]`). Each different class defines how a network will be created and cleaned.
+Network wrappers need to instantiate the `NetworkWrapperProtocol`.
+
+An instance of the `NetworkWrapperProtocol` defines the `get_network` method, which will return the aforementioend ra2ce network.
+
+Any network wrapper can be created with the use of a `NetworkConfigData` instance. At the same time, if the user does not know which wrapper to use, the `NetworkWrapperFactory` will resolve this for us.
\ No newline at end of file
diff --git a/ra2ce/graph/osm_network_wrapper/__init__.py b/ra2ce/graph/network_wrappers/__init__.py
similarity index 100%
rename from ra2ce/graph/osm_network_wrapper/__init__.py
rename to ra2ce/graph/network_wrappers/__init__.py
diff --git a/ra2ce/graph/network_wrappers/network_wrapper_factory.py b/ra2ce/graph/network_wrappers/network_wrapper_factory.py
new file mode 100644
index 000000000..be558019a
--- /dev/null
+++ b/ra2ce/graph/network_wrappers/network_wrapper_factory.py
@@ -0,0 +1,71 @@
+"""
+ GNU GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Risk Assessment and Adaptation for Critical Infrastructure (RA2CE).
+ Copyright (C) 2023 Stichting Deltares
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see .
+"""
+
+import logging
+
+import geopandas as gpd
+from geopandas import GeoDataFrame
+from networkx import MultiGraph
+
+from ra2ce.common.io.readers import GraphPickleReader
+from ra2ce.graph.network_config_data.network_config_data import NetworkConfigData
+from ra2ce.graph.network_wrappers.network_wrapper_protocol import NetworkWrapperProtocol
+from ra2ce.graph.network_wrappers.osm_network_wrapper.osm_network_wrapper import (
+ OsmNetworkWrapper,
+)
+from ra2ce.graph.network_wrappers.shp_network_wrapper import ShpNetworkWrapper
+from ra2ce.graph.network_wrappers.trails_network_wrapper import TrailsNetworkWrapper
+from ra2ce.graph.network_wrappers.vector_network_wrapper import VectorNetworkWrapper
+
+
+class NetworkWrapperFactory(NetworkWrapperProtocol):
+ def __init__(self, config_data: NetworkConfigData) -> None:
+ self._config_data = config_data
+
+ def _any_cleanup_enabled(self) -> bool:
+ _cleanup = self._config_data.cleanup
+ return (
+ _cleanup.snapping_threshold
+ or _cleanup.pruning_threshold
+ or _cleanup.merge_lines
+ or _cleanup.cut_at_intersections
+ )
+
+ def get_network(self) -> tuple[MultiGraph, GeoDataFrame]:
+ logging.info("Start creating a network from the submitted shapefile.")
+ source = self._config_data.network.source
+ if source == "shapefile":
+ if self._any_cleanup_enabled():
+ return ShpNetworkWrapper(self._config_data).get_network()
+ return VectorNetworkWrapper(self._config_data).get_network()
+ elif source == "OSM PBF":
+ return TrailsNetworkWrapper(self._config_data).get_network()
+ elif source == "OSM download":
+ return OsmNetworkWrapper(self._config_data).get_network()
+ elif source == "pickle":
+ logging.info("Start importing a network from pickle")
+ base_graph = GraphPickleReader().read(
+ self.output_graph_dir.joinpath("base_graph.p")
+ )
+ network_gdf = gpd.read_feather(
+ self.output_graph_dir.joinpath("base_network.feather")
+ )
+ return base_graph, network_gdf
diff --git a/ra2ce/graph/network_wrappers/network_wrapper_protocol.py b/ra2ce/graph/network_wrappers/network_wrapper_protocol.py
new file mode 100644
index 000000000..b6159869d
--- /dev/null
+++ b/ra2ce/graph/network_wrappers/network_wrapper_protocol.py
@@ -0,0 +1,36 @@
+"""
+ GNU GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Risk Assessment and Adaptation for Critical Infrastructure (RA2CE).
+ Copyright (C) 2023 Stichting Deltares
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see .
+"""
+from typing import Protocol, runtime_checkable
+
+from geopandas import GeoDataFrame
+from networkx import MultiGraph
+
+
+@runtime_checkable
+class NetworkWrapperProtocol(Protocol):
+ def get_network(self) -> tuple[MultiGraph, GeoDataFrame]:
+ """
+ Gets a network built within this wrapper instance. No arguments are accepted, the `__init__` method is meant to assign all required attributes for a wrapper.
+
+ Returns:
+ tuple[MultiGraph, GeoDataFrame]: Tuple of MultiGraph representing the graph and GeoDataFrame representing the network.
+ """
+ pass
diff --git a/tests/graph/osm_network_wrapper/__init__.py b/ra2ce/graph/network_wrappers/osm_network_wrapper/__init__.py
similarity index 100%
rename from tests/graph/osm_network_wrapper/__init__.py
rename to ra2ce/graph/network_wrappers/osm_network_wrapper/__init__.py
diff --git a/ra2ce/graph/network_wrappers/osm_network_wrapper/extremities_data.py b/ra2ce/graph/network_wrappers/osm_network_wrapper/extremities_data.py
new file mode 100644
index 000000000..e96c7697b
--- /dev/null
+++ b/ra2ce/graph/network_wrappers/osm_network_wrapper/extremities_data.py
@@ -0,0 +1,125 @@
+"""
+ GNU GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Risk Assessment and Adaptation for Critical Infrastructure (RA2CE).
+ Copyright (C) 2023 Stichting Deltares
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see .
+"""
+
+from dataclasses import dataclass
+
+from networkx import MultiDiGraph
+
+
+@dataclass
+class ExtremitiesData:
+ from_id: int = None
+ to_id: int = None
+ from_to_id: tuple = None
+ to_from_id: tuple = None
+ from_to_coor: tuple = None
+ to_from_coor: tuple = None
+
+ @staticmethod
+ def get_extremities_data_for_sub_graph(
+ from_node_id: int,
+ to_node_id: int,
+ sub_graph: MultiDiGraph,
+ graph: MultiDiGraph,
+ shared_elements: set,
+ ):
+ """Both extremities should be in the unique_graph still makes an edge between similar node to u (the node
+ with u coordinates and different id, included in the unique_graph) and v Here, sub_graph is the unique_graph
+ and graph is complex_graph Shared elements are shared btw sub_graph and graph, which are elements to include
+ when dropping duplicates"""
+ if shared_elements is None or not isinstance(shared_elements, set):
+ raise ValueError("unique_elements should be a set")
+ if from_node_id in sub_graph.nodes() and to_node_id in sub_graph.nodes():
+ return ExtremitiesData.arrange_extremities_data(
+ from_node_id=from_node_id, to_node_id=to_node_id, graph=sub_graph
+ )
+ elif (
+ graph.nodes[from_node_id]["x"],
+ graph.nodes[from_node_id]["y"],
+ ) in shared_elements and to_node_id in sub_graph.nodes():
+
+ from_node_id_prime = ExtremitiesData.find_node_id_by_coor(
+ sub_graph,
+ graph.nodes[from_node_id]["x"],
+ graph.nodes[from_node_id]["y"],
+ )
+ if from_node_id_prime == to_node_id:
+ return ExtremitiesData()
+ else:
+ return ExtremitiesData.arrange_extremities_data(
+ from_node_id=from_node_id_prime,
+ to_node_id=to_node_id,
+ graph=sub_graph,
+ )
+
+ elif (
+ from_node_id in sub_graph.nodes()
+ and (graph.nodes[to_node_id]["x"], graph.nodes[to_node_id]["y"])
+ in shared_elements
+ ):
+
+ to_node_id_prime = ExtremitiesData.find_node_id_by_coor(
+ sub_graph, graph.nodes[to_node_id]["x"], graph.nodes[to_node_id]["y"]
+ )
+ if from_node_id == to_node_id_prime:
+ return ExtremitiesData()
+ else:
+ return ExtremitiesData.arrange_extremities_data(
+ from_node_id=from_node_id,
+ to_node_id=to_node_id_prime,
+ graph=sub_graph,
+ )
+ else:
+ return ExtremitiesData()
+
+ @staticmethod
+ def arrange_extremities_data(
+ from_node_id: int, to_node_id: int, graph: MultiDiGraph
+ ):
+ return ExtremitiesData(
+ from_id=from_node_id,
+ to_id=to_node_id,
+ from_to_id=(from_node_id, to_node_id),
+ to_from_id=(to_node_id, from_node_id),
+ from_to_coor=(
+ (graph.nodes[from_node_id]["x"], graph.nodes[to_node_id]["x"]),
+ (graph.nodes[from_node_id]["y"], graph.nodes[to_node_id]["y"]),
+ ),
+ to_from_coor=(
+ (graph.nodes[to_node_id]["x"], graph.nodes[from_node_id]["x"]),
+ (graph.nodes[to_node_id]["y"], graph.nodes[from_node_id]["y"]),
+ ),
+ )
+
+ @staticmethod
+ def find_node_id_by_coor(graph: MultiDiGraph, target_x: float, target_y: float):
+ """
+ finds the node in unique graph with the same coor
+ """
+ for node, data in graph.nodes(data=True):
+ if (
+ "x" in data
+ and "y" in data
+ and data["x"] == target_x
+ and data["y"] == target_y
+ ):
+ return node
+ return None
diff --git a/ra2ce/graph/osm_network_wrapper/osm_network_wrapper.py b/ra2ce/graph/network_wrappers/osm_network_wrapper/osm_network_wrapper.py
similarity index 69%
rename from ra2ce/graph/osm_network_wrapper/osm_network_wrapper.py
rename to ra2ce/graph/network_wrappers/osm_network_wrapper/osm_network_wrapper.py
index 45950730a..82d9c9b31 100644
--- a/ra2ce/graph/osm_network_wrapper/osm_network_wrapper.py
+++ b/ra2ce/graph/network_wrappers/osm_network_wrapper/osm_network_wrapper.py
@@ -16,38 +16,104 @@
"""
import logging
from pathlib import Path
+from typing import Any
import networkx as nx
import osmnx
-from networkx import MultiDiGraph
-from osmnx import consolidate_intersections
+import pandas as pd
+from geopandas import GeoDataFrame
+from networkx import MultiDiGraph, MultiGraph
from shapely.geometry.base import BaseGeometry
-from ra2ce.graph.network_config_data.network_config_data import NetworkSection
import ra2ce.graph.networks_utils as nut
-from ra2ce.graph.osm_network_wrapper.extremities_data import ExtremitiesData
-
-
-class OsmNetworkWrapper:
- network_type: str
- road_types: list[str]
- graph_crs: str
- polygon_path: Path
-
- def __init__(
- self,
- network_type: str,
- road_types: list[str],
- graph_crs: str,
- polygon_path: Path,
- ) -> None:
- self.network_type = network_type
- self.road_types = road_types
- self.polygon_path = polygon_path
- self.graph_crs = graph_crs
- if not graph_crs:
- # Set default value
- self.graph_crs = "epsg:4326"
+from ra2ce.graph.exporters.json_exporter import JsonExporter
+from ra2ce.graph.network_config_data.network_config_data import NetworkConfigData
+from ra2ce.graph.network_wrappers.network_wrapper_protocol import NetworkWrapperProtocol
+from ra2ce.graph.network_wrappers.osm_network_wrapper.extremities_data import (
+ ExtremitiesData,
+)
+
+
+class OsmNetworkWrapper(NetworkWrapperProtocol):
+ def __init__(self, config_data: NetworkConfigData) -> None:
+ self.output_graph_dir = config_data.output_graph_dir
+ self.graph_crs = config_data.crs
+
+ # Network options
+ self.network_type = config_data.network.network_type
+ self.road_types = config_data.network.road_types
+ self.polygon_path = config_data.network.polygon
+ self.is_directed = config_data.network.directed
+
+ def get_network(self) -> tuple[MultiGraph, GeoDataFrame]:
+ """
+ Gets an indirected graph
+
+ Returns:
+ tuple[MultiGraph, GeoDataFrame]: _description_
+ """
+ logging.info("Start downloading a network from OSM.")
+ graph_complex = self.get_clean_graph_from_osm()
+
+ # Create 'graph_simple'
+ graph_simple, graph_complex, link_tables = nut.create_simplified_graph(
+ graph_complex
+ )
+
+ # Create 'edges_complex', convert complex graph to geodataframe
+ logging.info("Start converting the graph to a geodataframe")
+ edges_complex, node_complex = nut.graph_to_gdf(graph_complex)
+ logging.info("Finished converting the graph to a geodataframe")
+
+ # Save the link tables linking complex and simple IDs
+ self._export_linking_tables(link_tables)
+
+ if not self.is_directed and isinstance(graph_simple, MultiDiGraph):
+ graph_simple = graph_simple.to_undirected()
+
+ # Check if all geometries between nodes are there, if not, add them as a straight line.
+ graph_simple = nut.add_missing_geoms_graph(graph_simple, geom_name="geometry")
+ graph_simple = self._get_avg_speed(graph_simple)
+ return graph_simple, edges_complex
+
+ def _get_avg_speed(
+ self, original_graph: nx.classes.graph.Graph
+ ) -> nx.classes.graph.Graph:
+ if all(["length" in e for u, v, e in original_graph.edges.data()]) and any(
+ ["maxspeed" in e for u, v, e in original_graph.edges.data()]
+ ):
+ # Add time weighing - Define and assign average speeds; or take the average speed from an existing CSV
+ path_avg_speed = self.output_graph_dir.joinpath("avg_speed.csv")
+ if path_avg_speed.is_file():
+ avg_speeds = pd.read_csv(path_avg_speed)
+ else:
+ avg_speeds = nut.calc_avg_speed(
+ original_graph,
+ "highway",
+ save_csv=True,
+ save_path=path_avg_speed,
+ )
+ original_graph = nut.assign_avg_speed(original_graph, avg_speeds, "highway")
+
+ # make a time value of seconds, length of road streches is in meters
+ for u, v, k, edata in original_graph.edges.data(keys=True):
+ hours = (edata["length"] / 1000) / edata["avgspeed"]
+ original_graph[u][v][k]["time"] = round(hours * 3600, 0)
+
+ return original_graph
+ logging.info(
+ "No attributes found in the graph to estimate average speed per network segment."
+ )
+ return original_graph
+
+ def _export_linking_tables(self, linking_tables: tuple[Any]) -> None:
+ _exporter = JsonExporter()
+ _exporter.export(
+ self.output_graph_dir.joinpath("simple_to_complex.json"), linking_tables[0]
+ )
+ _exporter.export(
+ self.output_graph_dir.joinpath("complex_to_simple.json"), linking_tables[1]
+ )
def get_clean_graph_from_osm(self) -> MultiDiGraph:
"""
@@ -234,7 +300,7 @@ def valid_extremity_data(u, v, data) -> tuple[ExtremitiesData, dict]:
@staticmethod
def snap_nodes_to_nodes(graph: MultiDiGraph, threshold: float) -> MultiDiGraph:
- return consolidate_intersections(
+ return osmnx.consolidate_intersections(
G=graph, rebuild_graph=True, tolerance=threshold, dead_ends=False
)
diff --git a/ra2ce/graph/osm_utils.py b/ra2ce/graph/network_wrappers/osm_network_wrapper/osm_utils.py
similarity index 100%
rename from ra2ce/graph/osm_utils.py
rename to ra2ce/graph/network_wrappers/osm_network_wrapper/osm_utils.py
diff --git a/ra2ce/graph/network_wrappers/shp_network_wrapper.py b/ra2ce/graph/network_wrappers/shp_network_wrapper.py
new file mode 100644
index 000000000..a3d558174
--- /dev/null
+++ b/ra2ce/graph/network_wrappers/shp_network_wrapper.py
@@ -0,0 +1,196 @@
+"""
+ GNU GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Risk Assessment and Adaptation for Critical Infrastructure (RA2CE).
+ Copyright (C) 2023 Stichting Deltares
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see .
+"""
+
+import logging
+import math
+
+import geopandas as gpd
+import networkx as nx
+import pandas as pd
+from shapely.geometry import MultiLineString
+
+import ra2ce.graph.networks_utils as nut
+from ra2ce.graph.network_config_data.network_config_data import NetworkConfigData
+from ra2ce.graph.network_wrappers.network_wrapper_protocol import NetworkWrapperProtocol
+from ra2ce.graph.segmentation import Segmentation
+
+
+class ShpNetworkWrapper(NetworkWrapperProtocol):
+ def __init__(
+ self,
+ config_data: NetworkConfigData,
+ ) -> None:
+ _network_options = config_data.network
+ _cleanup_options = config_data.cleanup
+
+ self.project_name = config_data.project.name
+ self.output_graph_dir = config_data.output_graph_dir
+ self.crs = config_data.crs
+
+ # Network options
+ self.primary_files = _network_options.primary_file
+ self.diversion_files = _network_options.diversion_file
+ self.directed = _network_options.directed
+ self.file_id = _network_options.file_id
+
+ # Cleanup options
+ self.merge_lines = _cleanup_options.merge_lines
+ self.snapping_threshold = _cleanup_options.snapping_threshold
+ self.segmentation_length = _cleanup_options.segmentation_length
+ self.cut_at_intersections = _cleanup_options.cut_at_intersections
+
+ # Origins Destinations
+ self.region_path = config_data.origins_destinations.region
+
+ def _read_merge_shp(self) -> gpd.GeoDataFrame:
+ """Imports shapefile(s) and saves attributes in a pandas dataframe.
+
+ Returns:
+ lines (list of shapely LineStrings): full list of linestrings
+ properties (pandas dataframe): attributes of shapefile(s), in order of the linestrings in lines
+ """
+ # concatenate all shapefile into one geodataframe and set analysis to 1 or 0 for diversions
+ lines = [gpd.read_file(shp) for shp in self.primary_files]
+
+ if any(self.diversion_files):
+ lines.extend(
+ [
+ nut.check_crs_gdf(gpd.read_file(shp), self.crs)
+ for shp in self.diversion_files
+ ]
+ )
+ lines = pd.concat(lines)
+
+ lines.crs = self.crs
+
+ # Check if there are any multilinestrings and convert them to linestrings.
+ if lines["geometry"].apply(lambda row: isinstance(row, MultiLineString)).any():
+ mls_idx = lines.loc[
+ lines["geometry"].apply(lambda row: isinstance(row, MultiLineString))
+ ].index
+ for idx in mls_idx:
+ # Multilinestrings to linestrings
+ new_rows_geoms = list(lines.iloc[idx]["geometry"].geoms)
+ for nrg in new_rows_geoms:
+ dict_attributes = dict(lines.iloc[idx])
+ dict_attributes["geometry"] = nrg
+ lines.loc[max(lines.index) + 1] = dict_attributes
+
+ lines = lines.drop(labels=mls_idx, axis=0)
+
+ # append the length of the road stretches
+ lines["length"] = lines["geometry"].apply(
+ lambda x: nut.line_length(x, self.crs)
+ )
+
+ logging.info(
+ "Shapefile(s) loaded with attributes: {}.".format(
+ list(lines.columns.values)
+ )
+ ) # fill in parameter names
+
+ return lines
+
+ def _get_complex_graph_and_edges(
+ self, edges: gpd.GeoDataFrame, id_name: str
+ ) -> tuple[nx.MultiGraph, gpd.GeoDataFrame]:
+ # Get the unique points at the end of lines and at intersections to create nodes
+ nodes = nut.create_nodes(edges, self.crs, self.cut_at_intersections)
+ logging.info("Function [create_nodes]: executed")
+
+ edges = nut.cut_lines(
+ edges, nodes, id_name, tolerance=0.00001, crs_=self.crs
+ ) ## PAY ATTENTION TO THE TOLERANCE, THE UNIT IS DEGREES
+ logging.info("Function [cut_lines]: executed")
+
+ if not edges.crs:
+ edges.crs = self.crs
+
+ # create tuples from the adjecent nodes and add as column in geodataframe
+ edges_complex = nut.join_nodes_edges(nodes, edges, id_name)
+ edges_complex.crs = self.crs # set the right CRS
+ edges_complex.dropna(subset=["node_A", "node_B"], inplace=True)
+
+ assert (
+ edges_complex["node_A"].isnull().sum() == 0
+ ), "Some edges cannot be assigned nodes, please check your input shapefile."
+ assert (
+ edges_complex["node_B"].isnull().sum() == 0
+ ), "Some edges cannot be assigned nodes, please check your input shapefile."
+
+ # Create networkx graph from geodataframe
+ graph_complex = nut.graph_from_gdf(edges_complex, nodes, node_id="node_fid")
+ logging.info("Function [graph_from_gdf]: executed")
+ return graph_complex, edges_complex
+
+ def get_network(
+ self,
+ ) -> tuple[nx.MultiGraph, gpd.GeoDataFrame]:
+ edges = self._read_merge_shp()
+ lines_merged = gpd.GeoDataFrame()
+ # Check which of the lines are merged, also for the fid. The fid of the first line with a traffic count is taken.
+ # The list of fid's is reduced by the fid's that are not anymore in the merged lines
+ if self.merge_lines:
+ aadt_names = []
+ edges, lines_merged = nut.merge_lines_automatic(
+ edges, self.file_id, aadt_names, self.crs
+ )
+ logging.info(
+ "Function [merge_lines_automatic]: executed with properties {}".format(
+ list(edges.columns)
+ )
+ )
+
+ edges, id_name = nut.gdf_check_create_unique_ids(edges, self.file_id)
+
+ if self.snapping_threshold:
+ # TODO: snapping threshold it's a bool yet here we expect a float.
+ edges = nut.snap_endpoints_lines(
+ edges, self.snapping_threshold, id_name, self.crs
+ )
+ logging.info(
+ "Function [snap_endpoints_lines]: executed with threshold = {}".format(
+ self.snapping_threshold
+ )
+ )
+
+ # merge merged lines if there are any merged lines
+ if not lines_merged.empty:
+ # save the merged lines to a shapefile - CHECK if there are lines merged that should not be merged (e.g. main + secondary road)
+ lines_merged.set_geometry(
+ col="geometry", inplace=True
+ ) # To ensure the object is a GeoDataFrame and not a Series
+ _emerged_lines_file = self.output_graph_dir.joinpath(
+ f"{self.project_name}_lines_that_merged.shp"
+ )
+ lines_merged.to_file(_emerged_lines_file)
+ logging.info(
+ "Function [edges_to_shp]: saved at {}".format(_emerged_lines_file)
+ )
+
+ graph_complex, edges_complex = self._get_complex_graph_and_edges(edges, id_name)
+
+ if not math.isnan(self.segmentation_length):
+ edges_complex = Segmentation(edges_complex, self.segmentation_length)
+ edges_complex = edges_complex.apply_segmentation()
+ if edges_complex.crs is None: # The CRS might have dissapeared.
+ edges_complex.crs = self.crs # set the right CRS
+ return graph_complex, edges_complex
diff --git a/ra2ce/graph/network_wrappers/trails_network_wrapper.py b/ra2ce/graph/network_wrappers/trails_network_wrapper.py
new file mode 100644
index 000000000..5e95669b2
--- /dev/null
+++ b/ra2ce/graph/network_wrappers/trails_network_wrapper.py
@@ -0,0 +1,105 @@
+"""
+ GNU GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Risk Assessment and Adaptation for Critical Infrastructure (RA2CE).
+ Copyright (C) 2023 Stichting Deltares
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see .
+"""
+
+import logging
+
+import geopandas as gpd
+from geopandas import GeoDataFrame
+from networkx import MultiGraph
+
+from ra2ce.graph.network_config_data.network_config_data import NetworkConfigData
+from ra2ce.graph.network_wrappers.network_wrapper_protocol import NetworkWrapperProtocol
+from ra2ce.graph.networks_utils import graph_from_gdf
+from ra2ce.graph.segmentation import Segmentation
+
+
+class TrailsNetworkWrapper(NetworkWrapperProtocol):
+ def __init__(self, config_data: NetworkConfigData) -> None:
+ logging.info(
+ """The original OSM PBF import is no longer supported.
+ Instead, the beta version of package TRAILS is used.
+ First stable release of TRAILS is expected in 2023."""
+ )
+ self.primary_files = config_data.network.primary_file
+ self.segmentation_length = config_data.cleanup.segmentation_length
+ self.crs = config_data.crs
+
+ def get_network(self) -> tuple[MultiGraph, GeoDataFrame]:
+ """Creates a network which has been prepared in the TRAILS package
+
+ #Todo: we might later simply import the whole trails code as a package, and directly use these functions
+ #Todo: because TRAILS is still in beta version we better wait with that untill the first stable version is
+ # released
+
+ Returns:
+ graph_simple (NetworkX graph): Simplified graph (for use in the indirect analyses).
+ complex_edges (GeoDataFrame): Complex graph (for use in the direct analyses).
+ """
+ logging.info(
+ "TRAILS importer: Reads the provided primary edge file: {}, assumes there also is a_nodes file".format(
+ self.primary_files
+ )
+ )
+
+ logging.warning(
+ "Any coordinate projection information in the feather file will be overwritten (with default WGS84)"
+ )
+ # Make a pyproj CRS from the EPSG code
+
+ _edge_file = self.primary_files[0]
+ edges = gpd.read_feather(_edge_file)
+ edges = edges.set_crs(self.crs)
+
+ corresponding_node_file = _edge_file.replace("edges", "nodes")
+ if not corresponding_node_file.exists():
+ raise FileNotFoundError(
+ "The node file could not be found while importing from TRAILS"
+ )
+
+ nodes = gpd.read_feather(corresponding_node_file)
+ nodes = nodes.set_crs(self.crs)
+
+ logging.info("TRAILS importer: start generating graph")
+ # tempfix to rename columns
+ edges = edges.rename({"from_id": "node_A", "to_id": "node_B"}, axis="columns")
+ node_id = "id"
+ graph_simple = graph_from_gdf(edges, nodes, name="network", node_id=node_id)
+
+ logging.info("TRAILS importer: graph generating was succesfull.")
+ logging.warning(
+ "RA2CE will not clean-up your graph, assuming that it is already done in TRAILS"
+ )
+
+ edges_complex = edges
+ if self.segmentation_length:
+ logging.info("TRAILS importer: start segmentating graph")
+ to_segment = Segmentation(edges, self._cleanup.segmentation_length)
+ edges_simple_segmented = to_segment.apply_segmentation()
+ if edges_simple_segmented.crs is None: # The CRS might have dissapeared.
+ edges_simple_segmented.crs = edges.crs # set the right CRS
+ edges_complex = edges_simple_segmented
+
+ graph_complex = graph_simple # NOTE THAT DIFFERENCE
+ # BETWEEN SIMPLE AND COMPLEX DOES NOT EXIST WHEN IMPORTING WITH TRAILS
+
+ # Todo: better control over metadata in trails
+ # Todo: better control over where things are saved in the pipeline
+ return graph_complex, edges_complex
diff --git a/ra2ce/graph/network_wrappers/vector_network_wrapper.py b/ra2ce/graph/network_wrappers/vector_network_wrapper.py
new file mode 100644
index 000000000..3ba60c472
--- /dev/null
+++ b/ra2ce/graph/network_wrappers/vector_network_wrapper.py
@@ -0,0 +1,227 @@
+"""
+ GNU GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Risk Assessment and Adaptation for Critical Infrastructure (RA2CE).
+ Copyright (C) 2023 Stichting Deltares
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see .
+"""
+
+import logging
+from pathlib import Path
+
+import geopandas as gpd
+import momepy
+import networkx as nx
+import pandas as pd
+from shapely.geometry import Point
+
+import ra2ce.graph.networks_utils as nut
+from ra2ce.graph.network_config_data.network_config_data import NetworkConfigData
+from ra2ce.graph.network_wrappers.network_wrapper_protocol import NetworkWrapperProtocol
+
+
+class VectorNetworkWrapper(NetworkWrapperProtocol):
+ """A class for handling and manipulating vector files.
+
+ Provides methods for reading vector data, cleaning it, and setting up graph and
+ network.
+ """
+
+ def __init__(
+ self,
+ config_data: NetworkConfigData,
+ ) -> None:
+ self.crs = config_data.crs
+
+ # Network options
+ self.primary_files = config_data.network.primary_file
+ self.directed = config_data.network.directed
+
+ # Origins Destinations
+ self.region_path = config_data.origins_destinations.region
+
+ def get_network(
+ self,
+ ) -> tuple[nx.MultiGraph, gpd.GeoDataFrame]:
+ """Gets a network built from vector files.
+
+ Returns:
+ nx.MultiGraph: MultiGraph representing the graph.
+ gpd.GeoDataFrame: GeoDataFrame representing the network.
+ """
+ gdf = self._read_vector_to_project_region_and_crs()
+ gdf = self.clean_vector(gdf)
+ if self.directed:
+ graph = self.get_direct_graph_from_vector(gdf)
+ else:
+ graph = self.get_indirect_graph_from_vector(gdf)
+ edges, nodes = self.get_network_edges_and_nodes_from_graph(graph)
+ graph_complex = nut.graph_from_gdf(edges, nodes, node_id="node_fid")
+ return graph_complex, edges
+
+ def _read_vector_to_project_region_and_crs(self) -> gpd.GeoDataFrame:
+ gdf = self._read_files(self.primary_files)
+ if gdf is None:
+ logging.info("no file is read.")
+ return None
+
+ # set crs and reproject if needed
+ if not gdf.crs and self.crs:
+ gdf = gdf.set_crs(self.crs)
+ logging.info("setting crs as default EPSG:4326. specify crs if incorrect")
+
+ if self.crs:
+ gdf = gdf.to_crs(self.crs)
+ logging.info("reproject vector file to project crs")
+
+ # clip for region
+ if self.region_path:
+ _region_gpd = self._read_files([self.region_path])
+ gdf = gpd.overlay(gdf, _region_gpd, how="intersection", keep_geom_type=True)
+ logging.info("clip vector file to project region")
+
+ # validate
+ if not any(gdf):
+ logging.warning("No vector features found within project region")
+ return None
+
+ return gdf
+
+ def _read_files(self, file_list: list[Path]) -> gpd.GeoDataFrame:
+ """Reads a list of files into a GeoDataFrame.
+
+ Args:
+ file_list (list[Path]): List of file paths.
+
+ Returns:
+ gpd.GeoDataFrame: GeoDataFrame representing the data.
+ """
+ # read file
+ gdf = gpd.GeoDataFrame(pd.concat(list(map(gpd.read_file, file_list))))
+ logging.info(
+ "Read files {} into a 'GeoDataFrame'.".format(
+ ", ".join(map(str, file_list))
+ )
+ )
+ return gdf
+
+ @staticmethod
+ def get_direct_graph_from_vector(gdf: gpd.GeoDataFrame) -> nx.DiGraph:
+ """Creates a simple directed graph with node and edge geometries based on a given GeoDataFrame.
+
+ Args:
+ gdf (gpd.GeoDataFrame): Input GeoDataFrame containing line geometries.
+ Allow both LineString and MultiLineString.
+
+ Returns:
+ nx.DiGraph: NetworkX graph object with "crs", "approach" as graph properties.
+ """
+
+ # simple geometry handeling
+ gdf = VectorNetworkWrapper.explode_and_deduplicate_geometries(gdf)
+
+ # to graph
+ digraph = nx.DiGraph(crs=gdf.crs, approach="primal")
+ for _, row in gdf.iterrows():
+ from_node = row.geometry.coords[0]
+ to_node = row.geometry.coords[-1]
+ digraph.add_node(from_node, geometry=Point(from_node))
+ digraph.add_node(to_node, geometry=Point(to_node))
+ digraph.add_edge(
+ from_node,
+ to_node,
+ geometry=row.pop(
+ "geometry"
+ ), # **row TODO: check if we do need all columns
+ )
+
+ return digraph
+
+ @staticmethod
+ def get_indirect_graph_from_vector(gdf: gpd.GeoDataFrame) -> nx.Graph:
+ """Creates a simple undirected graph with node and edge geometries based on a given GeoDataFrame.
+
+ Args:
+ gdf (gpd.GeoDataFrame): Input GeoDataFrame containing line geometries.
+ Allow both LineString and MultiLineString.
+
+ Returns:
+ nx.Graph: NetworkX graph object with "crs", "approach" as graph properties.
+ """
+ digraph = VectorNetworkWrapper.get_direct_graph_from_vector(gdf)
+ return digraph.to_undirected()
+
+ @staticmethod
+ def get_network_edges_and_nodes_from_graph(
+ graph: nx.Graph,
+ ) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]:
+ """Sets up network nodes and edges from a given graph.
+
+ Args:
+ graph (nx.Graph): Input graph with geometry for nodes and edges.
+ Must contain "crs" as graph property.
+
+ Returns:
+ gpd.GeoDataFrame: GeoDataFrame representing the network edges with "edge_fid", "node_A", and "node_B".
+ gpd.GeoDataFrame: GeoDataFrame representing the network nodes with "node_fid".
+ """
+
+ # TODO ths function use conventions. Good to make consistant convention with osm
+ nodes, edges = momepy.nx_to_gdf(graph, nodeID="node_fid")
+ edges["edge_fid"] = (
+ edges["node_start"].astype(str) + "_" + edges["node_end"].astype(str)
+ )
+ edges.rename(
+ {"node_start": "node_A", "node_end": "node_B"}, axis=1, inplace=True
+ )
+ if not nodes.crs:
+ nodes.crs = graph.graph["crs"]
+ if not edges.crs:
+ edges.crs = graph.graph["crs"]
+ return edges, nodes
+
+ @staticmethod
+ def clean_vector(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
+ """Cleans a GeoDataFrame.
+
+ Args:
+ gdf (gpd.GeoDataFrame): Input GeoDataFrame.
+
+ Returns:
+ gpd.GeoDataFrame: Cleaned GeoDataFrame.
+ """
+
+ gdf = VectorNetworkWrapper.explode_and_deduplicate_geometries(gdf)
+
+ return gdf
+
+ @staticmethod
+ def explode_and_deduplicate_geometries(gpd: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
+ """Explodes and deduplicates geometries a GeoDataFrame.
+
+ Args:
+ gpd (gpd.GeoDataFrame): Input GeoDataFrame.
+
+ Returns:
+ gpd.GeoDataFrame: GeoDataFrame with exploded and deduplicated geometries.
+ """
+ gpd = gpd.explode()
+ gpd = gpd[
+ gpd.index.isin(
+ gpd.geometry.apply(lambda geom: geom.wkb).drop_duplicates().index
+ )
+ ]
+ return gpd
diff --git a/ra2ce/graph/networks.py b/ra2ce/graph/networks.py
index e97dfc5b7..ac05c21a2 100644
--- a/ra2ce/graph/networks.py
+++ b/ra2ce/graph/networks.py
@@ -20,26 +20,18 @@
"""
import logging
-import math
-import os
from pathlib import Path
-from typing import Any, List, Tuple
+from typing import Any
import geopandas as gpd
import networkx as nx
-import pandas as pd
import pyproj
-import momepy
-from shapely.geometry import MultiLineString
from ra2ce.common.io.readers import GraphPickleReader
from ra2ce.graph import networks_utils as nut
-from ra2ce.graph.exporters.json_exporter import JsonExporter
from ra2ce.graph.exporters.network_exporter_factory import NetworkExporterFactory
from ra2ce.graph.network_config_data.network_config_data import NetworkConfigData
-from ra2ce.graph.osm_network_wrapper.osm_network_wrapper import OsmNetworkWrapper
-from ra2ce.graph.segmentation import Segmentation
-from ra2ce.graph.vector_network_wrapper import VectorNetworkWrapper
+from ra2ce.graph.network_wrappers.network_wrapper_factory import NetworkWrapperFactory
class Network:
@@ -52,20 +44,16 @@ class Network:
config: A dictionary with the configuration details on how to create and adjust the network.
"""
- output_graph_dir: Path
- files: dict
- base_graph_crs: Any
- base_network_crs: Any
-
def __init__(self, network_config: NetworkConfigData, files: dict):
# General
+ self._config_data = network_config
self.project_name = network_config.project.name
- self.output_graph_dir = network_config.static_path.joinpath("output_graph")
+ self.output_graph_dir = network_config.output_graph_dir
if not self.output_graph_dir.is_dir():
self.output_graph_dir.mkdir(parents=True)
# Network
- self._network_dir = network_config.static_path.joinpath("network")
+ self._network_dir = network_config.network_dir
self.base_graph_crs = None # Initiate variable
self.base_network_crs = None # Initiate variable
@@ -82,290 +70,12 @@ def __init__(self, network_config: NetworkConfigData, files: dict):
)
self.origin_count = _origins_destinations.origin_count
self.od_category = _origins_destinations.category
- self.region = (
- None
- if not _origins_destinations.region
- else network_config.static_path.joinpath(
- "network", _origins_destinations.region
- )
- )
+ self.region = _origins_destinations.region
self.region_var = _origins_destinations.region_var
- # Cleanup
- self._cleanup = network_config.cleanup
-
# files
self.files = files
- def network_shp(
- self, crs: int = 4326
- ) -> Tuple[nx.classes.graph.Graph, gpd.GeoDataFrame]:
- """Creates a (graph) network from a shapefile.
-
- Returns the same geometries for the network (GeoDataFrame) as for the graph (NetworkX graph), because
- it is assumed that the user wants to keep the same geometries as their shapefile input.
-
- Args:
- crs (int): the EPSG number of the coordinate reference system that is used
-
- Returns:
- graph_complex (NetworkX graph): The resulting graph.
- edges_complex (GeoDataFrame): The resulting network.
- """
- # Make a pyproj CRS from the EPSG code
- crs = pyproj.CRS.from_user_input(crs)
-
- lines = self.read_merge_shp(crs)
-
- logging.info(
- "Function [read_merge_shp]: executed with {} {}".format(
- self._network_config.primary_file,
- self._network_config.diversion_file,
- )
- )
-
- # Check which of the lines are merged, also for the fid. The fid of the first line with a traffic count is taken.
- # The list of fid's is reduced by the fid's that are not anymore in the merged lines
- if self._cleanup.merge_lines:
- aadt_names = []
- edges, lines_merged = nut.merge_lines_automatic(
- lines, self._network_config.file_id, aadt_names, crs
- )
- logging.info(
- "Function [merge_lines_automatic]: executed with properties {}".format(
- list(edges.columns)
- )
- )
- else:
- edges, lines_merged = lines, gpd.GeoDataFrame()
-
- edges, id_name = nut.gdf_check_create_unique_ids(
- edges, self._network_config.file_id
- )
-
- if self._cleanup.snapping_threshold:
- edges = nut.snap_endpoints_lines(
- edges, self._cleanup.snapping_threshold, id_name, crs
- )
- logging.info(
- "Function [snap_endpoints_lines]: executed with threshold = {}".format(
- self._cleanup.snapping_threshold
- )
- )
-
- # merge merged lines if there are any merged lines
- if not lines_merged.empty:
- # save the merged lines to a shapefile - CHECK if there are lines merged that should not be merged (e.g. main + secondary road)
- lines_merged.set_geometry(
- col="geometry", inplace=True
- ) # To ensure the object is a GeoDataFrame and not a Series
- lines_merged.to_file(
- os.path.join(
- self.output_graph_dir,
- "{}_lines_that_merged.shp".format(self.project_name),
- )
- )
- logging.info(
- "Function [edges_to_shp]: saved at {}".format(
- os.path.join(
- self.output_graph_dir,
- "{}_lines_that_merged".format(self.project_name),
- )
- )
- )
-
- # Get the unique points at the end of lines and at intersections to create nodes
- nodes = nut.create_nodes(edges, crs, self._cleanup.cut_at_intersections)
- logging.info("Function [create_nodes]: executed")
-
- edges = nut.cut_lines(
- edges, nodes, id_name, tolerance=0.00001, crs_=crs
- ) ## PAY ATTENTION TO THE TOLERANCE, THE UNIT IS DEGREES
- logging.info("Function [cut_lines]: executed")
-
- if not edges.crs:
- edges.crs = crs
-
- # create tuples from the adjecent nodes and add as column in geodataframe
- edges_complex = nut.join_nodes_edges(nodes, edges, id_name)
- edges_complex.crs = crs # set the right CRS
- edges_complex.dropna(subset=["node_A", "node_B"], inplace=True)
-
- assert (
- edges_complex["node_A"].isnull().sum() == 0
- ), "Some edges cannot be assigned nodes, please check your input shapefile."
- assert (
- edges_complex["node_B"].isnull().sum() == 0
- ), "Some edges cannot be assigned nodes, please check your input shapefile."
-
- # Create networkx graph from geodataframe
- graph_complex = nut.graph_from_gdf(edges_complex, nodes, node_id="node_fid")
- logging.info("Function [graph_from_gdf]: executed")
-
- if not math.isnan(self._cleanup.segmentation_length):
- edges_complex = Segmentation(
- edges_complex, self._cleanup.segmentation_length
- )
- edges_complex = edges_complex.apply_segmentation()
- if edges_complex.crs is None: # The CRS might have dissapeared.
- edges_complex.crs = crs # set the right CRS
-
- self.base_graph_crs = pyproj.CRS.from_user_input(crs)
- self.base_network_crs = pyproj.CRS.from_user_input(crs)
-
- # Exporting complex graph because the shapefile should be kept the same as much as possible.
- return graph_complex, edges_complex
-
- def network_cleanshp(self) -> Tuple[nx.classes.graph.Graph, gpd.GeoDataFrame]:
- """Creates a (graph) network from a clean shapefile (primary_file - no further advance cleanup is needed)
-
- Returns the same geometries for the network (GeoDataFrame) as for the graph (NetworkX graph), because
- it is assumed that the user wants to keep the same geometries as their shapefile input.
-
- Returns:
- graph_complex (NetworkX graph): The resulting graph.
- edges_complex (GeoDataFrame): The resulting network.
- """
- # initialise vector network wrapper
- vector_network_wrapper = VectorNetworkWrapper(config=self.config)
-
- # setup network using the wrapper
- (
- graph_complex,
- edges_complex,
- ) = vector_network_wrapper.setup_network_from_vector()
-
- # Set the CRS of the graph and network to wrapper crs
- self.base_graph_crs = vector_network_wrapper.crs
- self.base_network_crs = vector_network_wrapper.crs
-
- return graph_complex, edges_complex
-
- def _export_linking_tables(self, linking_tables: List[Any]) -> None:
- _exporter = JsonExporter()
- _exporter.export(
- self.output_graph_dir.joinpath("simple_to_complex.json"), linking_tables[0]
- )
- _exporter.export(
- self.output_graph_dir.joinpath("complex_to_simple.json"), linking_tables[1]
- )
-
- def network_trails_import(
- self, crs: int = 4326
- ) -> Tuple[nx.classes.graph.Graph, gpd.GeoDataFrame]:
- """Creates a network which has been prepared in the TRAILS package
-
- #Todo: we might later simply import the whole trails code as a package, and directly use these functions
- #Todo: because TRAILS is still in beta version we better wait with that untill the first stable version is
- # released
-
- Returns:
- graph_simple (NetworkX graph): Simplified graph (for use in the indirect analyses).
- complex_edges (GeoDataFrame): Complex graph (for use in the direct analyses).
- """
-
- logging.info(
- "TRAILS importer: Reads the provided primary edge file: {}, assumes there also is a_nodes file".format(
- self._network_config.primary_file
- )
- )
-
- logging.warning(
- "Any coordinate projection information in the feather file will be overwritten (with default WGS84)"
- )
- # Make a pyproj CRS from the EPSG code
- crs = pyproj.CRS.from_user_input(crs)
-
- edge_file = self._network_dir.joinpath(self._network_config.primary_file)
- edges = gpd.read_feather(edge_file)
- edges = edges.set_crs(crs)
-
- corresponding_node_file = self._network_dir.joinpath(
- self._network_config.primary_file.replace("edges", "nodes")
- )
- assert (
- corresponding_node_file.exists()
- ), "The node file could not be found while importing from TRAILS"
- nodes = gpd.read_feather(corresponding_node_file)
- nodes = nodes.set_crs(crs)
- # nodes = pd.read_pickle(
- # corresponding_node_file
- # ) # Todo: Throw exception if nodes file is not present
-
- logging.info("TRAILS importer: start generating graph")
- # tempfix to rename columns
- edges = edges.rename({"from_id": "node_A", "to_id": "node_B"}, axis="columns")
- node_id = "id"
- graph_simple = nut.graph_from_gdf(edges, nodes, name="network", node_id=node_id)
-
- logging.info("TRAILS importer: graph generating was succesfull.")
- logging.warning(
- "RA2CE will not clean-up your graph, assuming that it is already done in TRAILS"
- )
-
- if self._cleanup.segmentation_length:
- logging.info("TRAILS importer: start segmentating graph")
- to_segment = Segmentation(edges, self._cleanup.segmentation_length)
- edges_simple_segmented = to_segment.apply_segmentation()
- if edges_simple_segmented.crs is None: # The CRS might have dissapeared.
- edges_simple_segmented.crs = edges.crs # set the right CRS
- edges_complex = edges_simple_segmented
-
- else:
- edges_complex = edges
-
- graph_complex = graph_simple # NOTE THAT DIFFERENCE
- # BETWEEN SIMPLE AND COMPLEX DOES NOT EXIST WHEN IMPORTING WITH TRAILS
-
- # Todo: better control over metadata in trails
- # Todo: better control over where things are saved in the pipeline
-
- return graph_complex, edges_complex
-
- def network_osm_download(self) -> tuple[nx.classes.graph.Graph, gpd.GeoDataFrame]:
- """
- Creates a network from a polygon by downloading via the OSM API in the extent of the polygon.
-
- Returns:
- tuple[nx.classes.graph.Graph, gpd.GeoDataFrame]: Tuple of Simplified graph (for use in the indirect analyses) and Complex graph (for use in the direct analyses).
- """
- osm_network = OsmNetworkWrapper(
- network_type=self._network_config.network_type,
- road_types=self._network_config.road_types,
- graph_crs="",
- polygon_path=self._network_dir.joinpath(self._network_config.polygon),
- )
- graph_complex = osm_network.get_clean_graph_from_osm()
-
- # Create 'graph_simple'
- graph_simple, graph_complex, link_tables = nut.create_simplified_graph(
- graph_complex
- )
-
- # Create 'edges_complex', convert complex graph to geodataframe
- logging.info("Start converting the graph to a geodataframe")
- edges_complex, node_complex = nut.graph_to_gdf(graph_complex)
- logging.info("Finished converting the graph to a geodataframe")
-
- # Save the link tables linking complex and simple IDs
- self._export_linking_tables(link_tables)
-
- # If the user wants to use undirected graphs, turn into an undirected graph (default).
- if not self._network_config.directed:
- if type(graph_simple) == nx.classes.multidigraph.MultiDiGraph:
- graph_simple = graph_simple.to_undirected()
-
- # No segmentation required, the non-simplified road segments from OSM are already small enough
-
- self.base_graph_crs = pyproj.CRS.from_user_input(
- "EPSG:4326"
- ) # Graphs from OSM download are always in this CRS.
- self.base_network_crs = pyproj.CRS.from_user_input(
- "EPSG:4326"
- ) # Graphs from OSM download are always in this CRS.
-
- return graph_simple, edges_complex
-
def add_od_nodes(
self, graph: nx.classes.graph.Graph, crs: pyproj.CRS
) -> nx.classes.graph.Graph:
@@ -426,100 +136,8 @@ def generate_origins_from_raster(self):
return out_fn
- def read_merge_shp(self, crs_: pyproj.CRS) -> gpd.GeoDataFrame:
- """Imports shapefile(s) and saves attributes in a pandas dataframe.
-
- Args:
- crs_ (int): the EPSG number of the coordinate reference system that is used
- Returns:
- lines (list of shapely LineStrings): full list of linestrings
- properties (pandas dataframe): attributes of shapefile(s), in order of the linestrings in lines
- """
-
- # read shapefiles and add to list with path
- def get_shp_paths(shp_str: str) -> Path:
- return self._network_dir.joinpath(shp_str)
-
- shapefiles_analysis = list(
- map(get_shp_paths, self._network_config.primary_file.split(","))
- )
- shapefiles_diversion = list(
- map(get_shp_paths, self._network_config.diversion_file.split(","))
- )
-
- # concatenate all shapefile into one geodataframe and set analysis to 1 or 0 for diversions
- lines = [gpd.read_file(shp) for shp in shapefiles_analysis]
-
- if self._network_config.diversion_file:
- lines.extend(
- [
- nut.check_crs_gdf(gpd.read_file(shp), crs_)
- for shp in shapefiles_diversion
- ]
- )
- lines = pd.concat(lines)
-
- lines.crs = crs_
-
- # Check if there are any multilinestrings and convert them to linestrings.
- if lines["geometry"].apply(lambda row: isinstance(row, MultiLineString)).any():
- mls_idx = lines.loc[
- lines["geometry"].apply(lambda row: isinstance(row, MultiLineString))
- ].index
- for idx in mls_idx:
- # Multilinestrings to linestrings
- new_rows_geoms = list(lines.iloc[idx]["geometry"].geoms)
- for nrg in new_rows_geoms:
- dict_attributes = dict(lines.iloc[idx])
- dict_attributes["geometry"] = nrg
- lines.loc[max(lines.index) + 1] = dict_attributes
-
- lines = lines.drop(labels=mls_idx, axis=0)
-
- # append the length of the road stretches
- lines["length"] = lines["geometry"].apply(lambda x: nut.line_length(x, crs_))
-
- logging.info(
- "Shapefile(s) loaded with attributes: {}.".format(
- list(lines.columns.values)
- )
- ) # fill in parameter names
-
- return lines
-
- def get_avg_speed(
- self, original_graph: nx.classes.graph.Graph
- ) -> nx.classes.graph.Graph:
- if all(["length" in e for u, v, e in original_graph.edges.data()]) and any(
- ["maxspeed" in e for u, v, e in original_graph.edges.data()]
- ):
- # Add time weighing - Define and assign average speeds; or take the average speed from an existing CSV
- path_avg_speed = self.output_graph_dir.joinpath("avg_speed.csv")
- if path_avg_speed.is_file():
- avg_speeds = pd.read_csv(path_avg_speed)
- else:
- avg_speeds = nut.calc_avg_speed(
- original_graph,
- "highway",
- save_csv=True,
- save_path=path_avg_speed,
- )
- original_graph = nut.assign_avg_speed(original_graph, avg_speeds, "highway")
-
- # make a time value of seconds, length of road streches is in meters
- for u, v, k, edata in original_graph.edges.data(keys=True):
- hours = (edata["length"] / 1000) / edata["avgspeed"]
- original_graph[u][v][k]["time"] = round(hours * 3600, 0)
-
- return original_graph
- else:
- logging.info(
- "No attributes found in the graph to estimate average speed per network segment."
- )
- return original_graph
-
def _export_network_files(
- self, network: Any, graph_name: str, types_to_export: List[str]
+ self, network: Any, graph_name: str, types_to_export: list[str]
):
_exporter = NetworkExporterFactory()
_exporter.export(
@@ -530,14 +148,60 @@ def _export_network_files(
)
self.files[graph_name] = _exporter.get_pickle_path()
- def _any_cleanup_enabled(self) -> bool:
- return (
- self._cleanup.snapping_threshold
- or self._cleanup.pruning_threshold
- or self._cleanup.merge_lines
- or self._cleanup.cut_at_intersections
+ def _get_new_network_and_graph(
+ self, export_types: list[str]
+ ) -> tuple[nx.classes.graph.Graph, gpd.GeoDataFrame]:
+ _base_graph, _network_gdf = NetworkWrapperFactory(
+ self._config_data
+ ).get_network()
+
+ self.base_graph_crs = _network_gdf.crs
+ self.base_network_crs = _network_gdf.crs
+
+ # Set the road lengths to meters for both the base_graph and network_gdf
+ # TODO: rename "length" column to "length [m]" to be explicit
+ edges_lengths_meters = {
+ (e[0], e[1], e[2]): {
+ "length": nut.line_length(e[-1]["geometry"], _network_gdf.crs)
+ }
+ for e in _base_graph.edges.data(keys=True)
+ }
+ nx.set_edge_attributes(_base_graph, edges_lengths_meters)
+
+ _network_gdf["length"] = _network_gdf["geometry"].apply(
+ lambda x: nut.line_length(x, _network_gdf.crs)
+ )
+
+ # Save the graph and geodataframe
+ self._export_network_files(_base_graph, "base_graph", export_types)
+ self._export_network_files(_network_gdf, "base_network", export_types)
+ return _base_graph, _network_gdf
+
+ def _get_stored_network_and_graph(
+ self, base_graph_filepath: Path, base_network_filepath: Path
+ ):
+ logging.info(
+ "Apparently, you already did create a network with ra2ce earlier. "
+ + "Ra2ce will use this: {}".format(base_graph_filepath)
)
+ def check_base_file(file_type: str, file_path: Path):
+ if not isinstance(base_graph_filepath) or not base_graph_filepath.is_file():
+ raise FileNotFoundError(
+ "No base {} file found at {}.".format(file_type, file_path)
+ )
+
+ check_base_file("graph", base_graph_filepath)
+ check_base_file("network", base_network_filepath)
+
+ _base_graph = GraphPickleReader().read(base_graph_filepath)
+ _network_gdf = gpd.read_feather(base_network_filepath)
+
+ # Assuming the same CRS for both the network and graph
+ self.base_graph_crs = _network_gdf.crs
+ self.base_network_crs = _network_gdf.crs
+ return _base_graph, _network_gdf
+
def create(self) -> dict:
"""Handler function with the logic to call the right functions to create a network.
@@ -552,87 +216,12 @@ def create(self) -> dict:
# For all graph and networks - check if it exists, otherwise, make the graph and/or network.
if not (self.files["base_graph"] or self.files["base_network"]):
- # Create the network from the network source
- if self._network_config.source == "shapefile":
- logging.info("Start creating a network from the submitted shapefile.")
- if self._any_cleanup_enabled():
- base_graph, network_gdf = self.network_shp()
- else:
- base_graph, network_gdf = self.network_cleanshp()
-
- elif self._network_config.source == "OSM PBF":
- logging.info(
- """The original OSM PBF import is no longer supported.
- Instead, the beta version of package TRAILS is used.
- First stable release of TRAILS is expected in 2023."""
- )
-
- # base_graph, network_gdf = self.network_osm_pbf() #The old approach is depreciated
- base_graph, network_gdf = self.network_trails_import()
-
- self.base_network_crs = network_gdf.crs
-
- elif self._network_config.source == "OSM download":
- logging.info("Start downloading a network from OSM.")
- base_graph, network_gdf = self.network_osm_download()
- # Graph & Network from OSM download
- # Check if all geometries between nodes are there, if not, add them as a straight line.
- base_graph = nut.add_missing_geoms_graph(
- base_graph, geom_name="geometry"
- )
- elif self._network_config.source == "pickle":
- logging.info("Start importing a network from pickle")
- base_graph = GraphPickleReader().read(
- self.output_graph_dir.joinpath("base_graph.p")
- )
- network_gdf = gpd.read_feather(
- self.output_graph_dir.joinpath("base_network.feather")
- )
-
- # Assuming the same CRS for both the network and graph
- self.base_graph_crs = pyproj.CRS.from_user_input(network_gdf.crs)
- self.base_network_crs = pyproj.CRS.from_user_input(network_gdf.crs)
-
- # Set the road lengths to meters for both the base_graph and network_gdf
- # TODO: rename "length" column to "length [m]" to be explicit
- edges_lengths_meters = {
- (e[0], e[1], e[2]): {
- "length": nut.line_length(e[-1]["geometry"], self.base_graph_crs)
- }
- for e in base_graph.edges.data(keys=True)
- }
- nx.set_edge_attributes(base_graph, edges_lengths_meters)
-
- network_gdf["length"] = network_gdf["geometry"].apply(
- lambda x: nut.line_length(x, self.base_network_crs)
- )
-
- if self._network_config.source == "OSM download":
- base_graph = self.get_avg_speed(base_graph)
-
- # Save the graph and geodataframe
- self._export_network_files(base_graph, "base_graph", to_save)
- self._export_network_files(network_gdf, "base_network", to_save)
+ base_graph, network_gdf = self._get_new_network_and_graph(to_save)
else:
- logging.info(
- "Apparently, you already did create a network with ra2ce earlier. "
- + "Ra2ce will use this: {}".format(self.files["base_graph"])
+ base_graph, network_gdf = self._get_stored_network_and_graph(
+ self.files["base_graph"], self.files["base_network"]
)
- if self.files["base_graph"] is not None:
- base_graph = GraphPickleReader().read(self.files["base_graph"])
- else:
- base_graph = None
-
- if self.files["base_network"] is not None:
- network_gdf = gpd.read_feather(self.files["base_network"])
- else:
- network_gdf = None
-
- # Assuming the same CRS for both the network and graph
- self.base_graph_crs = pyproj.CRS.from_user_input(network_gdf.crs)
- self.base_network_crs = pyproj.CRS.from_user_input(network_gdf.crs)
-
# create origins destinations graph
if (
(self.origins)
diff --git a/ra2ce/graph/osm_network_wrapper/extremities_data.py b/ra2ce/graph/osm_network_wrapper/extremities_data.py
deleted file mode 100644
index 6e4308de1..000000000
--- a/ra2ce/graph/osm_network_wrapper/extremities_data.py
+++ /dev/null
@@ -1,79 +0,0 @@
-from dataclasses import dataclass
-
-from networkx import MultiDiGraph
-
-
-@dataclass
-class ExtremitiesData:
- from_id: int = None
- to_id: int = None
- from_to_id: tuple = None
- to_from_id: tuple = None
- from_to_coor: tuple = None
- to_from_coor: tuple = None
-
- @staticmethod
- def get_extremities_data_for_sub_graph(from_node_id: int, to_node_id: int, sub_graph: MultiDiGraph,
- graph: MultiDiGraph, shared_elements: set):
- """Both extremities should be in the unique_graph still makes an edge between similar node to u (the node
- with u coordinates and different id, included in the unique_graph) and v Here, sub_graph is the unique_graph
- and graph is complex_graph Shared elements are shared btw sub_graph and graph, which are elements to include
- when dropping duplicates"""
- if shared_elements is None or not isinstance(shared_elements, set):
- raise ValueError("unique_elements should be a set")
- if from_node_id in sub_graph.nodes() and to_node_id in sub_graph.nodes():
- return ExtremitiesData.arrange_extremities_data(
- from_node_id=from_node_id, to_node_id=to_node_id, graph=sub_graph
- )
- elif (graph.nodes[from_node_id]['x'], graph.nodes[from_node_id]['y']) in shared_elements and \
- to_node_id in sub_graph.nodes():
-
- from_node_id_prime = ExtremitiesData.find_node_id_by_coor(
- sub_graph, graph.nodes[from_node_id]['x'], graph.nodes[from_node_id]['y']
- )
- if from_node_id_prime == to_node_id:
- return ExtremitiesData()
- else:
- return ExtremitiesData.arrange_extremities_data(from_node_id=from_node_id_prime, to_node_id=to_node_id,
- graph=sub_graph)
-
- elif from_node_id in sub_graph.nodes() and \
- (graph.nodes[to_node_id]['x'], graph.nodes[to_node_id]['y']) in shared_elements:
-
- to_node_id_prime = ExtremitiesData.find_node_id_by_coor(
- sub_graph, graph.nodes[to_node_id]['x'], graph.nodes[to_node_id]['y']
- )
- if from_node_id == to_node_id_prime:
- return ExtremitiesData()
- else:
- return ExtremitiesData.arrange_extremities_data(from_node_id=from_node_id, to_node_id=to_node_id_prime,
- graph=sub_graph)
- else:
- return ExtremitiesData()
-
- @staticmethod
- def arrange_extremities_data(from_node_id: int, to_node_id: int, graph: MultiDiGraph):
- return ExtremitiesData(
- from_id=from_node_id,
- to_id=to_node_id,
- from_to_id=(from_node_id, to_node_id),
- to_from_id=(to_node_id, from_node_id),
- from_to_coor=(
- (graph.nodes[from_node_id]['x'], graph.nodes[to_node_id]['x']),
- (graph.nodes[from_node_id]['y'], graph.nodes[to_node_id]['y'])
- ),
- to_from_coor=(
- (graph.nodes[to_node_id]['x'], graph.nodes[from_node_id]['x']),
- (graph.nodes[to_node_id]['y'], graph.nodes[from_node_id]['y'])
- )
- )
-
- @staticmethod
- def find_node_id_by_coor(graph: MultiDiGraph, target_x: float, target_y: float):
- """
- finds the node in unique graph with the same coor
- """
- for node, data in graph.nodes(data=True):
- if 'x' in data and 'y' in data and data['x'] == target_x and data['y'] == target_y:
- return node
- return None
diff --git a/ra2ce/graph/vector_network_wrapper.py b/ra2ce/graph/vector_network_wrapper.py
deleted file mode 100644
index e6e3b2035..000000000
--- a/ra2ce/graph/vector_network_wrapper.py
+++ /dev/null
@@ -1,327 +0,0 @@
-import logging
-from pathlib import Path
-from typing import Union
-
-import networkx as nx
-import pandas as pd
-import geopandas as gpd
-import momepy
-
-from shapely.geometry import Point
-from pyproj import CRS
-import ra2ce.graph.networks_utils as nut
-
-logger = logging.getLogger()
-
-
-class VectorNetworkWrapper:
- """A class for handling and manipulating vector files.
-
- Provides methods for reading vector data, cleaning it, and setting up graph and
- network.
- """
-
- name: str
- region: gpd.GeoDataFrame
- crs: CRS
- input_path: Path
- output_path: Path
- network_dict: dict
-
- def __init__(self, config: dict) -> None:
- """Initializes the VectorNetworkWrapper object.
-
- Args:
- config (dict): Configuration dictionary.
-
- Raises:
- ValueError: If the config is None or doesn't contain a network dictionary,
- or if config['network'] is not a dictionary.
- """
-
- if not config:
- raise ValueError("Config cannot be None")
- if not config.get("network", {}):
- raise ValueError(
- "A network dictionary is required for creating a "
- + f"{self.__class__.__name__} object."
- )
- if not isinstance(config.get("network"), dict):
- raise ValueError('Config["network"] should be a dictionary')
-
- self._setup_global(config)
- self.network_dict = self._get_network_opt(config["network"])
-
- def _parse_ini_stringlist(self, value: str) -> Union[str, list, None]:
- """Parses a string with "," into a list from an ini file.
-
- Args:
- value (str): Value to parse.
-
- Returns:
- list of str If the value contains a comma, it is split and returned
- as a list, otherwise the original value is returned.
- None if the value is an empty string.
- """
- if not value:
- return None
- elif isinstance(value, str) and "," in value:
- return value.split(",")
- else:
- return value
-
- def _setup_global(self, config: dict) -> None:
- """Sets up project properties based on provided configuration.
-
- Args:
- config (dict): Project configuration dictionary.
- """
- self.input_path = config.get("static").joinpath("network")
- self.output_path = config.get("output")
-
- project_config = config.get("project")
- name = project_config.get("name", "project_name")
- region = project_config.get("region", None)
- crs = project_config.get("crs", 4326)
- self.name = name
- self.crs = CRS.from_user_input(crs)
- self.region = self._read_files([Path(region)]) if region else region
-
- def _parse_ini_filenamelist(self, filename: str) -> list[Path]:
- """Makes a list of file paths by joining with input path and checks validity of files.
-
- Args:
- filename (str): String of file names separated by comma (",").
-
- Returns:
- List[Path]: List of file paths.
- """
- if not isinstance(filename, str):
- logger.error("file names are not valid.")
-
- file_paths = [self.input_path.joinpath(f.strip()) for f in filename.split(",")]
-
- for f in file_paths:
- if not f.resolve().is_file():
- logger.error(f"vector file {f} is not found.")
-
- return file_paths
-
- def _get_network_opt(self, network_config: dict) -> dict:
- """Retrieves network options used in this wrapper from provided configuration.
-
- Args:
- network_config (dict): Network configuration dictionary.
-
- Returns:
- dict: Dictionary of network options.
- """
-
- files = self._parse_ini_filenamelist(network_config.get("primary_file", ""))
- file_id = self._parse_ini_stringlist(
- network_config.get("file_id", "")
- ) # only needed when cleanup based on fid
- file_filter = self._parse_ini_stringlist(network_config.get("filter", ""))
- file_crs = CRS.from_user_input(network_config.get("crs", self.crs))
- is_directed = network_config.get("directed", False)
- return dict(
- files=files,
- file_id=file_id,
- file_filter=file_filter,
- file_crs=file_crs,
- is_directed=is_directed,
- )
-
- @staticmethod
- def setup_digraph_from_vector(gdf: gpd.GeoDataFrame) -> nx.DiGraph:
- """Creates a simple directed graph with node and edge geometries based on a given GeoDataFrame.
-
- Args:
- gdf (gpd.GeoDataFrame): Input GeoDataFrame containing line geometries.
- Allow both LineString and MultiLineString.
-
- Returns:
- nx.DiGraph: NetworkX graph object with "crs", "approach" as graph properties.
- """
-
- # simple geometry handeling
- gdf = VectorNetworkWrapper.explode_and_deduplicate_geometries(gdf)
-
- # to graph
- digraph = nx.DiGraph(crs=gdf.crs, approach="primal")
- for index, row in gdf.iterrows():
- from_node = row.geometry.coords[0]
- to_node = row.geometry.coords[-1]
- digraph.add_node(from_node, geometry=Point(from_node))
- digraph.add_node(to_node, geometry=Point(to_node))
- digraph.add_edge(
- from_node,
- to_node,
- geometry=row.pop(
- "geometry"
- ), # **row TODO: check if we do need all columns
- )
-
- return digraph
-
- @staticmethod
- def setup_graph_from_vector(gdf: gpd.GeoDataFrame) -> nx.Graph:
- """Creates a simple undirected graph with node and edge geometries based on a given GeoDataFrame.
-
- Args:
- gdf (gpd.GeoDataFrame): Input GeoDataFrame containing line geometries.
- Allow both LineString and MultiLineString.
-
- Returns:
- nx.Graph: NetworkX graph object with "crs", "approach" as graph properties.
- """
- digraph = VectorNetworkWrapper.setup_digraph_from_vector(gdf)
- return digraph.to_undirected()
-
- @staticmethod
- def setup_network_edges_and_nodes_from_graph(
- graph: nx.Graph,
- ) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]:
- """Sets up network nodes and edges from a given graph.
-
- Args:
- graph (nx.Graph): Input graph with geometry for nodes and edges.
- Must contain "crs" as graph property.
-
- Returns:
- gpd.GeoDataFrame: GeoDataFrame representing the network edges with "edge_fid", "node_A", and "node_B".
- gpd.GeoDataFrame: GeoDataFrame representing the network nodes with "node_fid".
- """
-
- # TODO ths function use conventions. Good to make consistant convention with osm
- nodes, edges = momepy.nx_to_gdf(graph, nodeID="node_fid")
- edges["edge_fid"] = (
- edges["node_start"].astype(str) + "_" + edges["node_end"].astype(str)
- )
- edges.rename(
- {"node_start": "node_A", "node_end": "node_B"}, axis=1, inplace=True
- )
- if not nodes.crs:
- nodes.crs = graph.graph["crs"]
- if not edges.crs:
- edges.crs = graph.graph["crs"]
- return edges, nodes
-
- def setup_network_from_vector(
- self,
- ) -> tuple[nx.MultiGraph, gpd.GeoDataFrame]:
- """Sets up a network from vector files.
-
- Returns:
- nx.MultiGraph: MultiGraph representing the graph.
- gpd.GeoDataFrame: GeoDataFrame representing the network.
- """
- files = self.network_dict["files"]
- file_crs = self.network_dict["file_crs"]
- is_directed = self.network_dict["is_directed"]
-
- gdf = self._read_vector_to_project_region_and_crs(
- vector_filenames=files, crs=file_crs
- )
- gdf = VectorNetworkWrapper.clean_vector(gdf)
- if is_directed:
- graph = VectorNetworkWrapper.setup_digraph_from_vector(gdf)
- else:
- graph = VectorNetworkWrapper.setup_graph_from_vector(gdf)
- edges, nodes = VectorNetworkWrapper.setup_network_edges_and_nodes_from_graph(
- graph
- )
- graph_complex = nut.graph_from_gdf(edges, nodes, node_id="node_fid")
- return graph_complex, edges
-
- def _read_vector_to_project_region_and_crs(
- self, vector_filenames: list[Path], crs: CRS
- ) -> gpd.GeoDataFrame:
- """Reads a vector file or a list of vector files.
-
- Clips for project region and reproject to project crs if available.
- Explodes multi geometry into single geometry.
-
- Args:
- vector_filenames (list[Path]): List of Path to the vector files.
- crs (CRS): Coordinate reference system for the files. Allow only one crs for all `vector_filenames`.
-
- Returns:
- gpd.GeoDataFrame: GeoDataFrame representing the vector data.
- """
- gdf = self._read_files(vector_filenames)
- if gdf is None:
- logger.info("no file is read.")
- return None
-
- # set crs and reproject if needed
- if not gdf.crs and crs:
- gdf = gdf.set_crs(crs)
- logger.info("setting crs as default EPSG:4326. specify crs if incorrect")
-
- if self.crs:
- gdf = gdf.to_crs(self.crs)
- logger.info("reproject vector file to project crs")
-
- # clip for region
- if self.region is not None:
- gdf = gpd.overlay(gdf, self.region, how="intersection", keep_geom_type=True)
- logger.info("clip vector file to project region")
-
- # validate
- if not any(gdf):
- logger.warning("No vector features found within project region")
- return None
-
- return gdf
-
- def _read_files(self, file_list: list[Path]) -> gpd.GeoDataFrame:
- """Reads a list of files into a GeoDataFrame.
-
- Args:
- file_list (list[Path]): List of file paths.
-
- Returns:
- gpd.GeoDataFrame: GeoDataFrame representing the data.
- """
- # read file
- if isinstance(file_list, list):
- gdf = gpd.GeoDataFrame(pd.concat([gpd.read_file(_fn) for _fn in file_list]))
- logger.info("read vector files.")
- else:
- gdf = None
- logger.info("no file is read.")
- return gdf
-
- @staticmethod
- def clean_vector(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
- """Cleans a GeoDataFrame.
-
- Args:
- gdf (gpd.GeoDataFrame): Input GeoDataFrame.
-
- Returns:
- gpd.GeoDataFrame: Cleaned GeoDataFrame.
- """
-
- gdf = VectorNetworkWrapper.explode_and_deduplicate_geometries(gdf)
-
- return gdf
-
- @staticmethod
- def explode_and_deduplicate_geometries(gpd: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
- """Explodes and deduplicates geometries a GeoDataFrame.
-
- Args:
- gpd (gpd.GeoDataFrame): Input GeoDataFrame.
-
- Returns:
- gpd.GeoDataFrame: GeoDataFrame with exploded and deduplicated geometries.
- """
- gpd = gpd.explode()
- gpd = gpd[
- gpd.index.isin(
- gpd.geometry.apply(lambda geom: geom.wkb).drop_duplicates().index
- )
- ]
- return gpd
diff --git a/tests/analyses/analysis_config_data/readers/test_analysis_config_reader_with_network.py b/tests/analyses/analysis_config_data/readers/test_analysis_config_reader_with_network.py
index 2c9aed388..8f332b53e 100644
--- a/tests/analyses/analysis_config_data/readers/test_analysis_config_reader_with_network.py
+++ b/tests/analyses/analysis_config_data/readers/test_analysis_config_reader_with_network.py
@@ -1,5 +1,6 @@
from pathlib import Path
from typing import Any
+
import pytest
from ra2ce.analyses.analysis_config_data.readers.analysis_config_reader_with_network import (
diff --git a/tests/analyses/analysis_config_data/readers/test_analysis_config_reader_without_network.py b/tests/analyses/analysis_config_data/readers/test_analysis_config_reader_without_network.py
index 409895a89..24055f1af 100644
--- a/tests/analyses/analysis_config_data/readers/test_analysis_config_reader_without_network.py
+++ b/tests/analyses/analysis_config_data/readers/test_analysis_config_reader_without_network.py
@@ -1,5 +1,8 @@
from pathlib import Path
from typing import Any
+
+import pytest
+
from ra2ce.analyses.analysis_config_data.readers.analysis_config_reader_base import (
AnalysisConfigReaderBase,
)
@@ -7,8 +10,6 @@
AnalysisConfigReaderWithoutNetwork,
)
-import pytest
-
class TestAnalysisWithoutNetworkConfigReader:
def test_initialize(self):
diff --git a/tests/analyses/analysis_config_data/test_analysis_config_data_validator_with_network.py b/tests/analyses/analysis_config_data/test_analysis_config_data_validator_with_network.py
index d9ba07eeb..a9ab3302f 100644
--- a/tests/analyses/analysis_config_data/test_analysis_config_data_validator_with_network.py
+++ b/tests/analyses/analysis_config_data/test_analysis_config_data_validator_with_network.py
@@ -1,3 +1,8 @@
+from pathlib import Path
+from typing import Optional
+
+import pytest
+
from ra2ce.analyses.analysis_config_data.analysis_config_data import (
AnalysisConfigDataWithNetwork,
)
@@ -6,10 +11,7 @@
)
from ra2ce.common.validation.ra2ce_validator_protocol import Ra2ceIoValidator
from ra2ce.common.validation.validation_report import ValidationReport
-import pytest
from tests import test_data
-from typing import Optional
-from pathlib import Path
class TestAnalysisConfigDataValidatorWithNetwork:
diff --git a/tests/analyses/analysis_config_data/test_analysis_config_data_validator_without_network.py b/tests/analyses/analysis_config_data/test_analysis_config_data_validator_without_network.py
index 56c0fcb85..2bb0a0120 100644
--- a/tests/analyses/analysis_config_data/test_analysis_config_data_validator_without_network.py
+++ b/tests/analyses/analysis_config_data/test_analysis_config_data_validator_without_network.py
@@ -9,7 +9,6 @@
from ra2ce.analyses.analysis_config_data.analysis_config_data_validator_without_network import (
AnalysisConfigDataValidatorWithoutNetwork,
)
-
from ra2ce.common.validation.validation_report import ValidationReport
from tests import test_data, test_results
diff --git a/tests/analyses/analysis_config_wrapper/test_analysis_config_wrapper_factory.py b/tests/analyses/analysis_config_wrapper/test_analysis_config_wrapper_factory.py
index d3ba7ac27..d5ad4b13a 100644
--- a/tests/analyses/analysis_config_wrapper/test_analysis_config_wrapper_factory.py
+++ b/tests/analyses/analysis_config_wrapper/test_analysis_config_wrapper_factory.py
@@ -1,13 +1,13 @@
import pytest
-from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_factory import (
- AnalysisConfigWrapperFactory,
-)
from ra2ce.analyses.analysis_config_data.analysis_config_data import (
AnalysisConfigData,
AnalysisConfigDataWithNetwork,
AnalysisConfigDataWithoutNetwork,
)
+from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_factory import (
+ AnalysisConfigWrapperFactory,
+)
from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_with_network import (
AnalysisConfigWrapperWithNetwork,
)
diff --git a/tests/analyses/analysis_config_wrapper/test_analysis_config_wrapper_with_network.py b/tests/analyses/analysis_config_wrapper/test_analysis_config_wrapper_with_network.py
index 88e75a021..4be363067 100644
--- a/tests/analyses/analysis_config_wrapper/test_analysis_config_wrapper_with_network.py
+++ b/tests/analyses/analysis_config_wrapper/test_analysis_config_wrapper_with_network.py
@@ -1,3 +1,4 @@
+import shutil
from pathlib import Path
import pytest
@@ -8,7 +9,6 @@
)
from ra2ce.graph.network_config_wrapper import NetworkConfigWrapper
from tests import test_data, test_results
-import shutil
class TestAnalysisWithNetworkConfig:
diff --git a/tests/analyses/analysis_config_wrapper/test_analysis_config_wrapper_without_network.py b/tests/analyses/analysis_config_wrapper/test_analysis_config_wrapper_without_network.py
index 702273fed..5bcc22d41 100644
--- a/tests/analyses/analysis_config_wrapper/test_analysis_config_wrapper_without_network.py
+++ b/tests/analyses/analysis_config_wrapper/test_analysis_config_wrapper_without_network.py
@@ -1,13 +1,12 @@
+import shutil
+
import pytest
-from ra2ce.analyses.analysis_config_data.analysis_config_data import (
- AnalysisConfigData,
-)
+from ra2ce.analyses.analysis_config_data.analysis_config_data import AnalysisConfigData
from ra2ce.analyses.analysis_config_wrapper.analysis_config_wrapper_without_network import (
AnalysisConfigWrapperWithoutNetwork,
)
from tests import acceptance_test_data, test_results
-import shutil
class TestAnalysisWithoutNetworkConfiguration:
diff --git a/tests/graph/network_wrappers/__init__.py b/tests/graph/network_wrappers/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/graph/osm_network_wrapper/test_osm_network_wrapper.py b/tests/graph/network_wrappers/test_osm_network_wrapper.py
similarity index 89%
rename from tests/graph/osm_network_wrapper/test_osm_network_wrapper.py
rename to tests/graph/network_wrappers/test_osm_network_wrapper.py
index 4661f933f..1dde51df2 100644
--- a/tests/graph/osm_network_wrapper/test_osm_network_wrapper.py
+++ b/tests/graph/network_wrappers/test_osm_network_wrapper.py
@@ -1,4 +1,5 @@
from pathlib import Path
+
import networkx as nx
import pytest
from networkx import Graph, MultiDiGraph
@@ -6,26 +7,42 @@
from shapely.geometry import LineString, Polygon
from shapely.geometry.base import BaseGeometry
-from tests import test_data, slow_test
import ra2ce.graph.networks_utils as nut
-from ra2ce.graph.osm_network_wrapper.osm_network_wrapper import OsmNetworkWrapper
+from ra2ce.graph.network_config_data.network_config_data import (
+ NetworkConfigData,
+ NetworkSection,
+)
+from ra2ce.graph.network_wrappers.network_wrapper_protocol import NetworkWrapperProtocol
+from ra2ce.graph.network_wrappers.osm_network_wrapper.osm_network_wrapper import (
+ OsmNetworkWrapper,
+)
+from tests import slow_test, test_data, test_results
class TestOsmNetworkWrapper:
def test_initialize_without_graph_crs(self):
- _wrapper = OsmNetworkWrapper("a_network", ["r"], "", Path())
+ # 1. Define test data.
+ _network_section = NetworkSection(network_type="a_network", road_types=["r"])
+ _network_config_data = NetworkConfigData(network=_network_section)
+
+ # 2. Run test.
+ _wrapper = OsmNetworkWrapper(_network_config_data)
+
+ # 3. Verify final expectations.
assert isinstance(_wrapper, OsmNetworkWrapper)
- assert _wrapper.graph_crs == "epsg:4326"
+ assert isinstance(_wrapper, NetworkWrapperProtocol)
+ assert _wrapper.graph_crs.to_epsg() == 4326
@pytest.fixture
def _network_wrapper_without_polygon(self) -> OsmNetworkWrapper:
- _network_type = "drive"
- _road_types = ["road_link"]
+ _network_section = NetworkSection(
+ network_type="drive", road_types=["road_link"], directed=True
+ )
+ _output_dir = test_results.joinpath("test_osm_network_wrapper")
+ if not _output_dir.exists():
+ _output_dir.mkdir(parents=True)
yield OsmNetworkWrapper(
- network_type=_network_type,
- road_types=_road_types,
- graph_crs="",
- polygon_path=None,
+ NetworkConfigData(network=_network_section, output_path=_output_dir)
)
def test_download_clean_graph_from_osm_with_invalid_polygon_arg(
diff --git a/tests/graph/test_osm_utils.py b/tests/graph/network_wrappers/test_osm_utils.py
similarity index 95%
rename from tests/graph/test_osm_utils.py
rename to tests/graph/network_wrappers/test_osm_utils.py
index 67303be27..dc2f31c4c 100644
--- a/tests/graph/test_osm_utils.py
+++ b/tests/graph/network_wrappers/test_osm_utils.py
@@ -2,7 +2,9 @@
import pytest
-from ra2ce.graph.osm_utils import from_shapefile_to_poly
+from ra2ce.graph.network_wrappers.osm_network_wrapper.osm_utils import (
+ from_shapefile_to_poly,
+)
from tests import test_data, test_results
diff --git a/tests/graph/network_wrappers/test_shp_network_wrapper.py b/tests/graph/network_wrappers/test_shp_network_wrapper.py
new file mode 100644
index 000000000..56a7c79fc
--- /dev/null
+++ b/tests/graph/network_wrappers/test_shp_network_wrapper.py
@@ -0,0 +1,19 @@
+from ra2ce.graph.network_config_data.network_config_data import (
+ NetworkConfigData,
+)
+from ra2ce.graph.network_wrappers.network_wrapper_protocol import NetworkWrapperProtocol
+from ra2ce.graph.network_wrappers.shp_network_wrapper import ShpNetworkWrapper
+
+
+class TestShpNetworkWrapper:
+ def test_init(self):
+ # 1.Define test data.
+ _config_data = NetworkConfigData()
+
+ # 2. Run test.
+ _wrapper = ShpNetworkWrapper(_config_data)
+
+ # 3. Verify expectations.
+ assert isinstance(_wrapper, ShpNetworkWrapper)
+ assert isinstance(_wrapper, NetworkWrapperProtocol)
+ assert _wrapper.crs.to_epsg() == 4326
diff --git a/tests/graph/network_wrappers/test_trails_network_wrapper.py b/tests/graph/network_wrappers/test_trails_network_wrapper.py
new file mode 100644
index 000000000..58c0fbbee
--- /dev/null
+++ b/tests/graph/network_wrappers/test_trails_network_wrapper.py
@@ -0,0 +1,16 @@
+from ra2ce.graph.network_config_data.network_config_data import NetworkConfigData
+from ra2ce.graph.network_wrappers.network_wrapper_protocol import NetworkWrapperProtocol
+from ra2ce.graph.network_wrappers.trails_network_wrapper import TrailsNetworkWrapper
+
+
+class TestTrailsNetworkWrapper:
+ def test_initialize(self):
+ # 1. Define test data.
+ _config_data = NetworkConfigData()
+
+ # 2. Create wrapper.
+ _wrapper = TrailsNetworkWrapper(_config_data)
+
+ # 3. Verify expectations.
+ assert isinstance(_wrapper, TrailsNetworkWrapper)
+ assert isinstance(_wrapper, NetworkWrapperProtocol)
diff --git a/tests/graph/network_wrappers/test_vector_network_wrapper.py b/tests/graph/network_wrappers/test_vector_network_wrapper.py
new file mode 100644
index 000000000..1bbfb9a83
--- /dev/null
+++ b/tests/graph/network_wrappers/test_vector_network_wrapper.py
@@ -0,0 +1,173 @@
+from pathlib import Path
+
+import geopandas as gpd
+import networkx as nx
+import pytest
+from pyproj import CRS
+from shapely.geometry import LineString, MultiLineString, Point
+
+from ra2ce.graph.network_config_data.network_config_data import NetworkConfigData
+from ra2ce.graph.network_wrappers.network_wrapper_protocol import NetworkWrapperProtocol
+from ra2ce.graph.network_wrappers.vector_network_wrapper import VectorNetworkWrapper
+from tests import test_data
+
+_test_dir = test_data / "vector_network_wrapper"
+
+
+class TestVectorNetworkWrapper:
+ @pytest.fixture
+ def points_gdf(self) -> gpd.GeoDataFrame:
+ points = [Point(-122.3, 47.6), Point(-122.2, 47.5), Point(-122.1, 47.6)]
+ return gpd.GeoDataFrame(geometry=points, crs=4326)
+
+ @pytest.fixture
+ def lines_gdf(self) -> gpd.GeoDataFrame:
+ points = [Point(-122.3, 47.6), Point(-122.2, 47.5), Point(-122.1, 47.6)]
+ lines = [LineString([points[0], points[1]]), LineString([points[1], points[2]])]
+ return gpd.GeoDataFrame(geometry=lines, crs=4326)
+
+ @pytest.fixture
+ def mock_graph(self):
+ points = [(-122.3, 47.6), (-122.2, 47.5), (-122.1, 47.6)]
+ lines = [(points[0], points[1]), (points[1], points[2])]
+
+ graph = nx.Graph(crs=4326)
+ for point in points:
+ graph.add_node(point, geometry=Point(point))
+ for line in lines:
+ graph.add_edge(
+ line[0], line[1], geometry=LineString([Point(line[0]), Point(line[1])])
+ )
+
+ return graph
+
+ def test_initialize(self):
+ # 1. Define test data.
+ _config_data = NetworkConfigData()
+ _config_data.network.primary_file = [Path("dummy_primary")]
+ _config_data.network.directed = False
+ _config_data.origins_destinations.region = Path("dummy_region")
+
+ # 2. Run test.
+ _wrapper = VectorNetworkWrapper(_config_data)
+
+ # 3. Verify expectations.
+ assert isinstance(_wrapper, VectorNetworkWrapper)
+ assert isinstance(_wrapper, NetworkWrapperProtocol)
+ assert _wrapper.primary_files == _config_data.network.primary_file
+ assert _wrapper.region_path == _config_data.origins_destinations.region
+ assert _wrapper.crs.to_epsg() == 4326
+
+ @pytest.fixture
+ def _valid_wrapper(self) -> VectorNetworkWrapper:
+ _network_dir = _test_dir.joinpath("static", "network")
+ _config_data = NetworkConfigData()
+ _config_data.network.primary_file = [
+ _network_dir.joinpath("_test_lines.geojson")
+ ]
+ _config_data.network.directed = False
+ _config_data.origins_destinations.region = None
+ _config_data.crs = CRS.from_user_input(4326)
+ yield VectorNetworkWrapper(
+ config_data=_config_data,
+ )
+
+ def test_read_vector_to_project_region_and_crs(
+ self, _valid_wrapper: VectorNetworkWrapper
+ ):
+ # Given
+ assert not _valid_wrapper.directed
+
+ # When
+ vector = _valid_wrapper._read_vector_to_project_region_and_crs()
+
+ # Then
+ assert isinstance(vector, gpd.GeoDataFrame)
+
+ def test_read_vector_to_project_region_and_crs_with_region(
+ self, _valid_wrapper: VectorNetworkWrapper
+ ):
+ # Given
+ _valid_wrapper.region_path = _test_dir / "_test_polygon.geojson"
+ _expected_region = gpd.read_file(_valid_wrapper.region_path)
+ assert isinstance(_expected_region, gpd.GeoDataFrame)
+
+ # When
+ vector = _valid_wrapper._read_vector_to_project_region_and_crs()
+
+ # Then
+ assert vector.crs == _expected_region.crs
+ assert _expected_region.covers(vector.unary_union).all()
+
+ @pytest.mark.parametrize(
+ "region_path",
+ [
+ pytest.param(None, id="No region"),
+ pytest.param(_test_dir / "_test_polygon.geojson", id="With region"),
+ ],
+ )
+ def test_get_network_from_vector(
+ self, _valid_wrapper: VectorNetworkWrapper, region_path: Path
+ ):
+ # Given
+ _valid_wrapper.region_path = region_path
+
+ # When
+ graph, edges = _valid_wrapper.get_network()
+
+ # Then
+ assert isinstance(graph, nx.MultiGraph)
+ assert isinstance(edges, gpd.GeoDataFrame)
+
+ def test_clean_vector(self, lines_gdf: gpd.GeoDataFrame):
+ # Given
+ gdf1 = VectorNetworkWrapper.explode_and_deduplicate_geometries(lines_gdf)
+
+ # When
+ gdf2 = VectorNetworkWrapper.clean_vector(
+ lines_gdf
+ ) # for now cleanup only does the above
+
+ # Then
+ assert gdf1.equals(gdf2)
+
+ def test_get_indirect_graph_from_vector(self, lines_gdf: gpd.GeoDataFrame):
+ # When
+ graph = VectorNetworkWrapper.get_indirect_graph_from_vector(lines_gdf)
+
+ # Then
+ assert graph.nodes(data="geometry") is not None
+ assert graph.edges(data="geometry") is not None
+ assert graph.graph["crs"] == lines_gdf.crs
+ assert isinstance(graph, nx.Graph) and not isinstance(graph, nx.DiGraph)
+
+ def test_get_direct_graph_from_vector(self, lines_gdf: gpd.GeoDataFrame):
+ # When
+ graph = VectorNetworkWrapper.get_direct_graph_from_vector(lines_gdf)
+
+ # Then
+ assert isinstance(graph, nx.DiGraph)
+
+ def test_get_network_edges_and_nodes_from_graph(
+ self, mock_graph, points_gdf, lines_gdf
+ ):
+ # When
+ edges, nodes = VectorNetworkWrapper.get_network_edges_and_nodes_from_graph(
+ mock_graph
+ )
+
+ # Then
+ assert edges.geometry.equals(lines_gdf.geometry)
+ assert nodes.geometry.equals(points_gdf.geometry)
+ assert set(["node_A", "node_B", "edge_fid"]).issubset(edges.columns)
+ assert set(["node_fid"]).issubset(nodes.columns)
+
+ def test_explode_and_deduplicate_geometries(self, lines_gdf):
+ # Given
+ multi_lines = lines_gdf.geometry.apply(lambda x: MultiLineString([x]))
+
+ # When
+ gdf = VectorNetworkWrapper.explode_and_deduplicate_geometries(multi_lines)
+
+ # Then
+ assert isinstance(gdf.geometry.iloc[0], LineString)
diff --git a/tests/graph/test_vector_network_wrapper.py b/tests/graph/test_vector_network_wrapper.py
deleted file mode 100644
index e34a1a879..000000000
--- a/tests/graph/test_vector_network_wrapper.py
+++ /dev/null
@@ -1,243 +0,0 @@
-import pytest
-
-from pathlib import Path
-
-import geopandas as gpd
-import networkx as nx
-from shapely.geometry import LineString, Point, MultiLineString
-
-from tests import test_data
-from ra2ce.graph.vector_network_wrapper import VectorNetworkWrapper
-
-_test_dir = test_data / "vector_network_wrapper"
-
-
-class TestVectorNetworkWrapper:
- @pytest.fixture
- def _config_fixture(self) -> dict:
- yield {
- "project": {
- "name": "test",
- "crs": 4326,
- },
- "network": {
- "directed": False,
- "source": "shapefile",
- "primary_file": "_test_lines.geojson",
- "diversion_file": None,
- "file_id": "fid",
- "polygon": None,
- "network_type": None,
- "road_types": None,
- "save_shp": False,
- },
- "static": _test_dir / "static",
- "output": _test_dir / "output",
- }
-
- @pytest.fixture
- def points_gdf(self):
- points = [Point(-122.3, 47.6), Point(-122.2, 47.5), Point(-122.1, 47.6)]
- return gpd.GeoDataFrame(geometry=points, crs=4326)
-
- @pytest.fixture
- def lines_gdf(self):
- points = [Point(-122.3, 47.6), Point(-122.2, 47.5), Point(-122.1, 47.6)]
- lines = [LineString([points[0], points[1]]), LineString([points[1], points[2]])]
- return gpd.GeoDataFrame(geometry=lines, crs=4326)
-
- @pytest.fixture
- def mock_graph(self):
- points = [(-122.3, 47.6), (-122.2, 47.5), (-122.1, 47.6)]
- lines = [(points[0], points[1]), (points[1], points[2])]
-
- graph = nx.Graph(crs=4326)
- for point in points:
- graph.add_node(point, geometry=Point(point))
- for line in lines:
- graph.add_edge(
- line[0], line[1], geometry=LineString([Point(line[0]), Point(line[1])])
- )
-
- return graph
-
- @pytest.mark.parametrize(
- "config",
- [
- pytest.param(None, id="NONE as dictionary"),
- pytest.param({}, id="Empty dictionary"),
- pytest.param({"network": {}}, id='Empty "network" in Config'),
- pytest.param({"network": "string"}, id='Invalid "network" type in Config'),
- ],
- )
- def test_init(self, config: dict):
- with pytest.raises(ValueError) as exc_err:
- VectorNetworkWrapper(config=config)
- assert str(exc_err.value) in [
- "Config cannot be None",
- "A network dictionary is required for creating a VectorNetworkWrapper object.",
- 'Config["network"] should be a dictionary',
- ]
-
- def test_parse_ini_stringlist_with_comma_separated_string(self, _config_fixture):
- # Given
- test_wrapper = VectorNetworkWrapper(_config_fixture)
- ini_value = "a,b,c"
-
- # When
- result = test_wrapper._parse_ini_stringlist(ini_value)
-
- # Then
- assert result == ["a", "b", "c"]
-
- def test_parse_ini_stringlist_with_single_string(self, _config_fixture):
- # Given
- test_wrapper = VectorNetworkWrapper(_config_fixture)
- ini_value = "abc"
-
- # When
- result = test_wrapper._parse_ini_stringlist(ini_value)
-
- # Then
- assert result == "abc"
-
- def test_parse_ini_stringlist_with_empty_string(self, _config_fixture):
- # Given
- test_wrapper = VectorNetworkWrapper(_config_fixture)
- ini_value = ""
-
- # When
- result = test_wrapper._parse_ini_stringlist(ini_value)
-
- # Then
- assert result is None
-
- def test_parse_ini_filenamelist(self, _config_fixture):
- # Given
- test_wrapper = VectorNetworkWrapper(_config_fixture)
- ini_value = "_test_lines.geojson, dummy.geojson"
-
- # When
- file_paths = test_wrapper._parse_ini_filenamelist(ini_value)
-
- # Then
- assert file_paths[0].is_file()
-
- def test_setup_global(self, _config_fixture):
- test_wrapper = VectorNetworkWrapper(_config_fixture)
- test_wrapper._setup_global(_config_fixture)
- assert test_wrapper.name == "test"
- assert test_wrapper.region is None
- assert test_wrapper.crs.to_epsg() == 4326
- assert test_wrapper.input_path == _test_dir / "static/network"
- assert test_wrapper.output_path == _test_dir / "output"
-
- def test_get_network_opt(self, _config_fixture):
- test_wrapper = VectorNetworkWrapper(_config_fixture)
- network_dict = test_wrapper._get_network_opt(_config_fixture["network"])
- assert network_dict["files"][0].is_file()
- assert network_dict["file_id"] == "fid"
- assert network_dict["file_crs"].to_epsg() == 4326
- assert network_dict["is_directed"] is False
-
- def test_read_vector_to_project_region_and_crs(self, _config_fixture):
- # Given
- test_wrapper = VectorNetworkWrapper(_config_fixture)
- files = test_wrapper.network_dict["files"]
- file_crs = test_wrapper.network_dict["file_crs"]
-
- # When
- vector = test_wrapper._read_vector_to_project_region_and_crs(files, file_crs)
-
- # Then
- assert isinstance(vector, gpd.GeoDataFrame)
-
- def test_read_vector_to_project_region_and_crs_with_region(self, _config_fixture):
- # Given
- _config_fixture["project"]["region"] = _test_dir / "_test_polygon.geojson"
- test_wrapper = VectorNetworkWrapper(_config_fixture)
- files = test_wrapper.network_dict["files"]
- file_crs = test_wrapper.network_dict["file_crs"]
-
- # When
- vector = test_wrapper._read_vector_to_project_region_and_crs(files, file_crs)
-
- # Then
- assert vector.crs == test_wrapper.region.crs
- assert test_wrapper.region.covers(vector.unary_union).all()
-
- def test_setup_network_from_vector(self, _config_fixture):
- # Given
- test_wrapper = VectorNetworkWrapper(_config_fixture)
-
- # When
- graph, edges = test_wrapper.setup_network_from_vector()
-
- # Then
- assert isinstance(graph, nx.MultiGraph)
- assert isinstance(edges, gpd.GeoDataFrame)
-
- def test_setup_network_from_vector_with_region(self, _config_fixture):
- # Given
- _config_fixture["project"]["region"] = _test_dir / "_test_polygon.geojson"
- test_wrapper = VectorNetworkWrapper(_config_fixture)
-
- # When
- graph, edges = test_wrapper.setup_network_from_vector()
-
- # Then
- assert isinstance(graph, nx.MultiGraph)
- assert isinstance(edges, gpd.GeoDataFrame)
-
- def test_clean_vector(self, lines_gdf):
- # Given
- gdf1 = VectorNetworkWrapper.explode_and_deduplicate_geometries(lines_gdf)
-
- # When
- gdf2 = VectorNetworkWrapper.clean_vector(
- lines_gdf
- ) # for now cleanup only does the above
-
- # Then
- assert gdf1.equals(gdf2)
-
- def test_setup_graph_from_vector(self, lines_gdf):
- # When
- graph = VectorNetworkWrapper.setup_graph_from_vector(lines_gdf)
-
- # Then
- assert graph.nodes(data="geometry") is not None
- assert graph.edges(data="geometry") is not None
- assert graph.graph["crs"] == lines_gdf.crs
- assert isinstance(graph, nx.Graph) and not isinstance(graph, nx.DiGraph)
-
- def test_setup_digraph_from_vector(self, lines_gdf):
- # When
- graph = VectorNetworkWrapper.setup_digraph_from_vector(lines_gdf)
-
- # Then
- assert isinstance(graph, nx.DiGraph)
-
- def test_setup_network_edges_and_nodes_from_graph(
- self, mock_graph, points_gdf, lines_gdf
- ):
- # When
- edges, nodes = VectorNetworkWrapper.setup_network_edges_and_nodes_from_graph(
- mock_graph
- )
-
- # Then
- assert edges.geometry.equals(lines_gdf.geometry)
- assert nodes.geometry.equals(points_gdf.geometry)
- assert set(["node_A", "node_B", "edge_fid"]).issubset(edges.columns)
- assert set(["node_fid"]).issubset(nodes.columns)
-
- def test_explode_and_deduplicate_geometries(self, lines_gdf):
- # Given
- multi_lines = lines_gdf.geometry.apply(lambda x: MultiLineString([x]))
-
- # When
- gdf = VectorNetworkWrapper.explode_and_deduplicate_geometries(multi_lines)
-
- # Then
- assert isinstance(gdf.geometry.iloc[0], LineString)
diff --git a/tests/test_acceptance.py b/tests/test_acceptance.py
index ea4116380..d81cd7d1b 100644
--- a/tests/test_acceptance.py
+++ b/tests/test_acceptance.py
@@ -14,14 +14,13 @@
_analysis_ini_name = "analyses.ini"
_base_graph_p_filename = "base_graph.p"
_base_network_feather_filename = "base_network.feather"
+_skip_cases = []
def get_external_test_cases() -> list[pytest.param]:
if not test_external_data.exists():
return []
- _skip_cases = ["bolivia"]
-
def get_pytest_param(test_dir: Path) -> pytest.param:
_marks = [external_test]
if test_dir.stem.lower() in _skip_cases: