diff --git a/ra2ce/graph/networks.py b/ra2ce/graph/networks.py
index 417be3ba8..04d4f2dd8 100644
--- a/ra2ce/graph/networks.py
+++ b/ra2ce/graph/networks.py
@@ -19,19 +19,18 @@
along with this program. If not, see .
"""
-
import logging
import os
from typing import Any, List, Tuple
import geopandas as gpd
import networkx as nx
-import osmnx
import pandas as pd
import pyproj
from shapely.geometry import MultiLineString
import ra2ce.graph.networks_utils as nut
+from ra2ce.graph.osm_network_wrapper.osm_network_wrapper import OsmNetworkWrapper
from ra2ce.graph.segmentation import Segmentation
from ra2ce.io.readers import GraphPickleReader
from ra2ce.io.writers import JsonExporter
@@ -202,11 +201,15 @@ def network_shp(
# Exporting complex graph because the shapefile should be kept the same as much as possible.
return graph_complex, edges_complex
- def _export_linking_tables(self, linking_tables: List[Any]) -> None:
+ def _export_linking_tables(self, linking_tables: list[Any]) -> None:
_exporter = JsonExporter()
_output_dir = self.config["static"] / "output_graph"
- _exporter.export(_output_dir / "simple_to_complex.json", linking_tables[0])
- _exporter.export(_output_dir / "complex_to_simple.json", linking_tables[1])
+ _exporter.export(
+ _output_dir.joinpath("simple_to_complex.json"), linking_tables[0]
+ )
+ _exporter.export(
+ _output_dir.joinpath("complex_to_simple.json"), linking_tables[1]
+ )
def network_trails_import(
self, crs: int = 4326
@@ -288,52 +291,15 @@ def network_trails_import(
return graph_complex, edges_complex
- def network_osm_download(self) -> Tuple[nx.classes.graph.Graph, gpd.GeoDataFrame]:
- """Creates a network from a polygon by downloading via the OSM API in the extent of the polygon.
+ def network_osm_download(self) -> tuple[nx.classes.graph.Graph, gpd.GeoDataFrame]:
+ """
+ Creates a network from a polygon by downloading via the OSM API in the extent of the polygon.
Returns:
- graph_simple (NetworkX graph): Simplified graph (for use in the indirect analyses).
- complex_edges (GeoDataFrame): Complex graph (for use in the direct analyses).
+ tuple[nx.classes.graph.Graph, gpd.GeoDataFrame]: Tuple of Simplified graph (for use in the indirect analyses) and Complex graph (for use in the direct analyses).
"""
- poly_dict = nut.read_geojson(
- self.config["network"]["polygon"][0]
- ) # It can only read in one geojson
- poly = nut.geojson_to_shp(poly_dict)
-
- if not self.config["network"]["road_types"]:
- # The user specified only the network type.
- graph_complex = osmnx.graph_from_polygon(
- polygon=poly,
- network_type=self.config["network"]["network_type"],
- simplify=False,
- retain_all=True,
- )
- elif not self.config["network"]["network_type"]:
- # The user specified only the road types.
- cf = '["highway"~"{}"]'.format(
- self.config["network"]["road_types"].replace(",", "|")
- )
- graph_complex = osmnx.graph_from_polygon(
- polygon=poly, custom_filter=cf, simplify=False, retain_all=True
- )
- else:
- # The user specified the network type and road types.
- cf = '["highway"~"{}"]'.format(
- self.config["network"]["road_types"].replace(",", "|")
- )
- graph_complex = osmnx.graph_from_polygon(
- polygon=poly,
- network_type=self.config["network"]["network_type"],
- custom_filter=cf,
- simplify=False,
- retain_all=True,
- )
-
- logging.info(
- "graph downloaded from OSM with {:,} nodes and {:,} edges".format(
- len(list(graph_complex.nodes())), len(list(graph_complex.edges()))
- )
- )
+ osm_network = OsmNetworkWrapper(self.config, "")
+ graph_complex = osm_network.get_clean_graph_from_osm()
# Create 'graph_simple'
graph_simple, graph_complex, link_tables = nut.create_simplified_graph(
diff --git a/ra2ce/graph/networks_utils.py b/ra2ce/graph/networks_utils.py
index 8dd671356..2cddcf0b3 100644
--- a/ra2ce/graph/networks_utils.py
+++ b/ra2ce/graph/networks_utils.py
@@ -19,7 +19,6 @@
along with this program. If not, see .
"""
-
import itertools
import logging
import os
@@ -1045,7 +1044,7 @@ def graph_check_create_unique_ids(
):
i = 0
- for u, v, k in graph.edges(keys=True):
+ for u, v, k in graph.edges(data=True):
graph[u][v][k][new_id_name] = i
i += 1
logging.info(
@@ -1080,10 +1079,10 @@ def add_missing_geoms_graph(graph: nx.Graph, geom_name: str = "geometry") -> nx.
graph.nodes[nd][geom_name] = Point(graph.nodes[nd]["x"], graph.nodes[nd]["y"])
edges_without_geom = [
- e for e in graph.edges.data(keys=True) if geom_name not in e[-1]
+ e for e in graph.edges.data(data=True) if geom_name not in e[-1]
]
for ed in edges_without_geom:
- graph[ed[0]][ed[1]][ed[2]][geom_name] = LineString(
+ graph[ed[0]][ed[1]][0][geom_name] = LineString(
[graph.nodes[ed[0]][geom_name], graph.nodes[ed[1]][geom_name]]
)
diff --git a/ra2ce/graph/osm_network_wrapper/__init__.py b/ra2ce/graph/osm_network_wrapper/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/ra2ce/graph/osm_network_wrapper/extremities_data.py b/ra2ce/graph/osm_network_wrapper/extremities_data.py
new file mode 100644
index 000000000..6e4308de1
--- /dev/null
+++ b/ra2ce/graph/osm_network_wrapper/extremities_data.py
@@ -0,0 +1,79 @@
+from dataclasses import dataclass
+
+from networkx import MultiDiGraph
+
+
+@dataclass
+class ExtremitiesData:
+ from_id: int = None
+ to_id: int = None
+ from_to_id: tuple = None
+ to_from_id: tuple = None
+ from_to_coor: tuple = None
+ to_from_coor: tuple = None
+
+ @staticmethod
+ def get_extremities_data_for_sub_graph(from_node_id: int, to_node_id: int, sub_graph: MultiDiGraph,
+ graph: MultiDiGraph, shared_elements: set):
+ """Both extremities should be in the unique_graph still makes an edge between similar node to u (the node
+ with u coordinates and different id, included in the unique_graph) and v Here, sub_graph is the unique_graph
+ and graph is complex_graph Shared elements are shared btw sub_graph and graph, which are elements to include
+ when dropping duplicates"""
+ if shared_elements is None or not isinstance(shared_elements, set):
+ raise ValueError("unique_elements should be a set")
+ if from_node_id in sub_graph.nodes() and to_node_id in sub_graph.nodes():
+ return ExtremitiesData.arrange_extremities_data(
+ from_node_id=from_node_id, to_node_id=to_node_id, graph=sub_graph
+ )
+ elif (graph.nodes[from_node_id]['x'], graph.nodes[from_node_id]['y']) in shared_elements and \
+ to_node_id in sub_graph.nodes():
+
+ from_node_id_prime = ExtremitiesData.find_node_id_by_coor(
+ sub_graph, graph.nodes[from_node_id]['x'], graph.nodes[from_node_id]['y']
+ )
+ if from_node_id_prime == to_node_id:
+ return ExtremitiesData()
+ else:
+ return ExtremitiesData.arrange_extremities_data(from_node_id=from_node_id_prime, to_node_id=to_node_id,
+ graph=sub_graph)
+
+ elif from_node_id in sub_graph.nodes() and \
+ (graph.nodes[to_node_id]['x'], graph.nodes[to_node_id]['y']) in shared_elements:
+
+ to_node_id_prime = ExtremitiesData.find_node_id_by_coor(
+ sub_graph, graph.nodes[to_node_id]['x'], graph.nodes[to_node_id]['y']
+ )
+ if from_node_id == to_node_id_prime:
+ return ExtremitiesData()
+ else:
+ return ExtremitiesData.arrange_extremities_data(from_node_id=from_node_id, to_node_id=to_node_id_prime,
+ graph=sub_graph)
+ else:
+ return ExtremitiesData()
+
+ @staticmethod
+ def arrange_extremities_data(from_node_id: int, to_node_id: int, graph: MultiDiGraph):
+ return ExtremitiesData(
+ from_id=from_node_id,
+ to_id=to_node_id,
+ from_to_id=(from_node_id, to_node_id),
+ to_from_id=(to_node_id, from_node_id),
+ from_to_coor=(
+ (graph.nodes[from_node_id]['x'], graph.nodes[to_node_id]['x']),
+ (graph.nodes[from_node_id]['y'], graph.nodes[to_node_id]['y'])
+ ),
+ to_from_coor=(
+ (graph.nodes[to_node_id]['x'], graph.nodes[from_node_id]['x']),
+ (graph.nodes[to_node_id]['y'], graph.nodes[from_node_id]['y'])
+ )
+ )
+
+ @staticmethod
+ def find_node_id_by_coor(graph: MultiDiGraph, target_x: float, target_y: float):
+ """
+ finds the node in unique graph with the same coor
+ """
+ for node, data in graph.nodes(data=True):
+ if 'x' in data and 'y' in data and data['x'] == target_x and data['y'] == target_y:
+ return node
+ return None
diff --git a/ra2ce/graph/osm_network_wrapper/osm_network_wrapper.py b/ra2ce/graph/osm_network_wrapper/osm_network_wrapper.py
new file mode 100644
index 000000000..e4b3cf6d4
--- /dev/null
+++ b/ra2ce/graph/osm_network_wrapper/osm_network_wrapper.py
@@ -0,0 +1,242 @@
+"""
+ GNU GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+ Risk Assessment and Adaptation for Critical Infrastructure (RA2CE).
+ Copyright (C) 2023 Stichting Deltares
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see .
+"""
+import logging
+from pathlib import Path
+
+import networkx as nx
+import osmnx
+from networkx import MultiDiGraph
+from osmnx import consolidate_intersections
+from shapely.geometry.base import BaseGeometry
+
+import ra2ce.graph.networks_utils as nut
+from ra2ce.graph.osm_network_wrapper.extremities_data import ExtremitiesData
+
+
+class OsmNetworkWrapper:
+ network_dict: dict
+ output_path: Path
+ graph_crs: str
+
+ def __init__(self, config: dict, graph_crs: str) -> None:
+ if not config:
+ raise ValueError("Config cannot be None")
+ if not config.get("network", {}):
+ raise ValueError(
+ "A network dictionary is required for creating a OsmNetworkWrapper object."
+ )
+ if not isinstance(config.get("network"), dict):
+ raise ValueError('Config["network"] should be a dictionary')
+
+ self.network_dict = config["network"]
+ self.output_path = config["static"] / "output_graph"
+ self.graph_crs = graph_crs
+ if not self.graph_crs:
+ self.graph_crs = "epsg:4326"
+
+ def get_clean_graph_from_osm(self) -> MultiDiGraph:
+ """
+ Creates a network from a polygon by by downloading via the OSM API in its extent.
+
+ Raises:
+ FileNotFoundError: When no valid polygon file is provided.
+
+ Returns:
+ MultiDiGraph: Complex (clean) graph after download from OSM, for use in the direct analyses and input to derive simplified network.
+ """
+ # It can only read in one geojson
+ if not self.network_dict.get("polygon", []):
+ raise ValueError("No valid value provided for polygon file.")
+
+ polygon_file = self.output_path.parent.joinpath(
+ "network", self.network_dict.get("polygon", [])[0]
+ )
+ if not polygon_file.is_file():
+ raise FileNotFoundError("No polygon_file file found.")
+
+ poly_dict = nut.read_geojson(geojson_file=polygon_file)
+ _complex_graph = self._download_clean_graph_from_osm(
+ polygon=nut.geojson_to_shp(poly_dict),
+ network_type=self.network_dict.get("network_type", ""),
+ road_types=self.network_dict.get("road_types", ""),
+ )
+ return _complex_graph
+
+ def _download_clean_graph_from_osm(
+ self, polygon: BaseGeometry, road_types: str, network_type: str
+ ) -> MultiDiGraph:
+ if not road_types and not network_type:
+ raise ValueError("Either of the link_type or network_type should be known")
+ elif not road_types:
+ # The user specified only the network type.
+ _complex_graph = osmnx.graph_from_polygon(
+ polygon=polygon,
+ network_type=network_type,
+ simplify=False,
+ retain_all=True,
+ )
+ elif not network_type:
+ # The user specified only the road types.
+ cf = f'["highway"~"{road_types.replace(",", "|")}"]'
+ _complex_graph = osmnx.graph_from_polygon(
+ polygon=polygon, custom_filter=cf, simplify=False, retain_all=True
+ )
+ else:
+ cf = f'["highway"~"{road_types.replace(",", "|")}"]'
+ _complex_graph = osmnx.graph_from_polygon(
+ polygon=polygon,
+ network_type=network_type,
+ custom_filter=cf,
+ simplify=False,
+ retain_all=True,
+ )
+
+ logging.info(
+ "graph downloaded from OSM with {:,} nodes and {:,} edges".format(
+ len(list(_complex_graph.nodes())), len(list(_complex_graph.edges()))
+ )
+ )
+ if "crs" not in _complex_graph.graph.keys():
+ _complex_graph.graph["crs"] = self.graph_crs
+ self.get_clean_graph(_complex_graph)
+ return _complex_graph
+
+ @staticmethod
+ def get_clean_graph(complex_graph: MultiDiGraph):
+ complex_graph = OsmNetworkWrapper.drop_duplicates(complex_graph)
+ complex_graph = nut.add_missing_geoms_graph(
+ graph=complex_graph, geom_name="geometry"
+ ).to_directed()
+ complex_graph = OsmNetworkWrapper.snap_nodes_to_nodes(
+ graph=complex_graph, threshold=0.000025
+ )
+ return complex_graph
+
+ @staticmethod
+ def drop_duplicates(complex_graph: MultiDiGraph) -> MultiDiGraph:
+ unique_elements = (
+ set()
+ ) # This gets updated during the drop_duplicates_in_nodes and drop_duplicates_in_edges
+
+ unique_graph = OsmNetworkWrapper.drop_duplicates_in_nodes(
+ unique_elements=unique_elements, graph=complex_graph
+ )
+ unique_graph = OsmNetworkWrapper.drop_duplicates_in_edges(
+ unique_elements=unique_elements,
+ unique_graph=unique_graph,
+ graph=complex_graph,
+ )
+ return unique_graph
+
+ @staticmethod
+ def drop_duplicates_in_nodes(
+ unique_elements: set, graph: MultiDiGraph
+ ) -> MultiDiGraph:
+ if unique_elements is None or not isinstance(unique_elements, set):
+ raise ValueError("unique_elements should be a set")
+
+ unique_graph = nx.MultiDiGraph()
+ for node, data in graph.nodes(data=True):
+ if data["x"] is None or data["y"] is None:
+ raise ValueError(
+ "Incompatible coordinate keys. Check the keys which define the x and y coordinates"
+ )
+
+ x, y = data["x"], data["y"]
+ coord = (x, y)
+ if coord not in unique_elements:
+ node_attributes = {key: value for key, value in data.items()}
+ unique_graph.add_node(node, **node_attributes)
+ unique_elements.add(coord)
+ # Copy the graph dictionary from the source one.
+ unique_graph.graph = graph.graph
+ return unique_graph
+
+ @staticmethod
+ def drop_duplicates_in_edges(
+ unique_elements: set, unique_graph: MultiDiGraph, graph: MultiDiGraph
+ ):
+ """
+ Checks if both extremities are in the unique_graph (u has not the same coor of v, no line from u to itself is
+ allowed). Checks if an edge is already made between such extremities with the given id and coordinates before
+ considering it in the unique graph
+ """
+ if (
+ not unique_elements
+ or not any(unique_elements)
+ or not all(isinstance(item, tuple) for item in unique_elements)
+ ):
+ raise ValueError(
+ """unique_elements cannot be None, empty, or have non-tuple elements.
+ Provide a set with all unique node coordinates as tuples of (x, y)"""
+ )
+ if unique_graph is None:
+ raise ValueError(
+ """unique_graph cannot be None. Provide a graph with unique nodes or perform the
+ drop_duplicates_in_nodes on the graph to generate a unique_graph"""
+ )
+
+ def validity_check(extremities_tuple) -> bool:
+ extremities = extremities_tuple[0]
+ return not (
+ extremities.from_to_id is None
+ or extremities.from_to_coor is None
+ or extremities.from_id == extremities.to_id
+ )
+
+ def valid_extremity_data(u, v, data) -> tuple[ExtremitiesData, dict]:
+ _extremities_data = ExtremitiesData.get_extremities_data_for_sub_graph(
+ from_node_id=u,
+ to_node_id=v,
+ sub_graph=unique_graph,
+ graph=graph,
+ shared_elements=unique_elements,
+ )
+
+ return _extremities_data, data
+
+ for _extremity_data, _edge_data in filter(
+ validity_check,
+ map(lambda edge: valid_extremity_data(*edge), graph.edges(data=True)),
+ ):
+ _id_combination = (_extremity_data.from_to_id, _extremity_data.to_from_id)
+ _coor_combination = (
+ _extremity_data.from_to_coor,
+ _extremity_data.to_from_coor,
+ )
+ if all(
+ _combination not in unique_elements
+ for _combination in [_id_combination, _coor_combination]
+ ):
+ edge_attributes = {key: value for key, value in _edge_data.items()}
+ unique_graph.add_edge(
+ _extremity_data.from_id, _extremity_data.to_id, **edge_attributes
+ )
+ unique_elements.add(_id_combination)
+ unique_elements.add(_coor_combination)
+
+ return unique_graph
+
+ @staticmethod
+ def snap_nodes_to_nodes(graph: MultiDiGraph, threshold: float) -> MultiDiGraph:
+ return consolidate_intersections(
+ G=graph, rebuild_graph=True, tolerance=threshold, dead_ends=False
+ )
+
+ @staticmethod
+ def snap_nodes_to_edges(graph: MultiDiGraph, threshold: float):
+ raise NotImplementedError("Next thing to do!")
diff --git a/tests/graph/osm_network_wrapper/__init__.py b/tests/graph/osm_network_wrapper/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/graph/osm_network_wrapper/test_osm_network_wrapper.py b/tests/graph/osm_network_wrapper/test_osm_network_wrapper.py
new file mode 100644
index 000000000..638e202b1
--- /dev/null
+++ b/tests/graph/osm_network_wrapper/test_osm_network_wrapper.py
@@ -0,0 +1,298 @@
+import networkx as nx
+import pytest
+from networkx import Graph, MultiDiGraph
+from networkx.utils import graphs_equal
+from shapely.geometry import LineString, Polygon
+from shapely.geometry.base import BaseGeometry
+
+from tests import test_data, slow_test
+import ra2ce.graph.networks_utils as nut
+from ra2ce.graph.osm_network_wrapper.osm_network_wrapper import OsmNetworkWrapper
+
+
+class TestOsmNetworkWrapper:
+ @pytest.fixture
+ def _config_fixture(self) -> dict:
+ _test_dir = test_data / "graph" / "test_osm_network_wrapper"
+
+ yield {
+ "static": _test_dir / "static",
+ "network": {"polygon": "_test_polygon.geojson"},
+ "origins_destinations": {
+ "origins": None,
+ "destinations": None,
+ "origins_names": None,
+ "destinations_names": None,
+ "id_name_origin_destination": None,
+ "category": "dummy_category",
+ "region": "",
+ },
+ "cleanup": {"snapping_threshold": None, "segmentation_length": None},
+ }
+
+ @pytest.mark.parametrize(
+ "config",
+ [
+ pytest.param(None, id="None"),
+ pytest.param({}, id="{}"),
+ pytest.param({"network": {}}, id='"Config["network"]" = None'),
+ pytest.param({"network": "string"}, id='"invalid type(Config["network"])"'),
+ ],
+ )
+ def test_osm_network_wrapper_initialisation_with_invalid_config(self, config: dict):
+ _files = []
+ with pytest.raises(ValueError) as exc_err:
+ OsmNetworkWrapper(config=config, graph_crs="")
+ assert (
+ str(exc_err.value) == "Config cannot be None"
+ or "A network dictionary is required for creating a OsmNetworkWrapper object"
+ or 'Config["network"] should be a dictionary'
+ )
+
+ def test__download_clean_graph_from_osm_with_invalid_polygon_arg(
+ self, _config_fixture: dict
+ ):
+ _osm_network = OsmNetworkWrapper(config=_config_fixture, graph_crs="")
+ _polygon = None
+ _link_type = "road_link"
+ _network_type = "drive"
+ with pytest.raises(AttributeError) as exc_err:
+ _osm_network._download_clean_graph_from_osm(
+ _polygon, _link_type, _network_type
+ )
+
+ assert str(exc_err.value) == "'NoneType' object has no attribute 'is_valid'"
+
+ def test__download_clean_graph_from_osm_with_invalid_polygon_arg_geometry(
+ self, _config_fixture: dict
+ ):
+ _osm_network = OsmNetworkWrapper(config=_config_fixture, graph_crs="")
+ _polygon = LineString([[0, 0], [1, 0], [1, 1]])
+ _link_type = "road_link"
+ _network_type = "drive"
+ with pytest.raises(TypeError) as exc_err:
+ _osm_network._download_clean_graph_from_osm(
+ _polygon, _link_type, _network_type
+ )
+
+ assert (
+ str(exc_err.value)
+ == "Geometry must be a shapely Polygon or MultiPolygon. If you requested graph from place name, make sure your query resolves to a Polygon or MultiPolygon, and not some other geometry, like a Point. See OSMnx documentation for details."
+ )
+
+ def test__download_clean_graph_from_osm_with_invalid_network_type_arg(
+ self, _config_fixture: dict
+ ):
+ _osm_network = OsmNetworkWrapper(config=_config_fixture, graph_crs="")
+ _polygon = Polygon([(0.0, 0.0), (0.0, 1.0), (1.0, 1.0), (1.0, 0.0), (0.0, 0.0)])
+ _link_type = ""
+ _network_type = "drv"
+ with pytest.raises(ValueError) as exc_err:
+ _osm_network._download_clean_graph_from_osm(
+ _polygon, _link_type, _network_type
+ )
+
+ assert str(exc_err.value) == f'Unrecognized network_type "{_network_type}"'
+
+ @pytest.fixture
+ def _valid_network_polygon_fixture(self) -> BaseGeometry:
+ _test_input_directory = test_data.joinpath("graph", "test_osm_network_wrapper")
+ _polygon_file = _test_input_directory.joinpath("_test_polygon.geojson")
+ assert _polygon_file.exists()
+ _polygon_dict = nut.read_geojson(_polygon_file)
+ yield nut.geojson_to_shp(_polygon_dict)
+
+ @slow_test
+ def test__download_clean_graph_from_osm_output(
+ self, _config_fixture: dict, _valid_network_polygon_fixture: BaseGeometry
+ ):
+ # 1. Define test data.
+ _osm_network = OsmNetworkWrapper(config=_config_fixture, graph_crs="")
+ _link_type = ""
+ _network_type = "drive"
+
+ # 2. Run test.
+ graph_complex = _osm_network._download_clean_graph_from_osm(
+ polygon=_valid_network_polygon_fixture,
+ network_type=_network_type,
+ road_types=_link_type,
+ )
+
+ # 3. Verify expectations
+ # reference: https://www.openstreetmap.org/node/1402598729#map=17/51.98816/4.39126&layers=T
+ _osm_id_node_to_validate = 4987298323
+ # reference: https://www.openstreetmap.org/way/334316041#map=19/51.98945/4.39166&layers=T
+ _osm_id_edge_to_validate = 334316041
+
+ assert isinstance(graph_complex, Graph)
+ assert _osm_id_node_to_validate in list(graph_complex.nodes.keys())
+ assert _osm_id_edge_to_validate in list(
+ map(lambda x: x["osmid"], graph_complex.edges.values())
+ )
+
+ @pytest.mark.parametrize(
+ "polygon_values",
+ [
+ pytest.param([""], id="Empty polygon file value"),
+ pytest.param(["Not a valid name"], id="Invalid polygon file name"),
+ ],
+ )
+ def test_get_clean_graph_from_osm_with_invalid_polygon_parameter_filename(
+ self, polygon_values
+ ):
+ _config_dict = {
+ "static": test_data.joinpath("graph", "test_osm_network_wrapper", "static"),
+ "network": {"polygon": polygon_values},
+ }
+ _osm_network = OsmNetworkWrapper(config=_config_dict, graph_crs="")
+ with pytest.raises(FileNotFoundError) as exc_err:
+ _osm_network.get_clean_graph_from_osm()
+
+ assert str(exc_err.value) == "No polygon_file file found."
+
+ @pytest.mark.parametrize(
+ "polygon_value",
+ [
+ pytest.param(None, id="None polygon file"),
+ pytest.param([], id="Invalid polygon file name"),
+ ],
+ )
+ def test_get_clean_graph_from_osm_with_invalid_polygon_parameter(
+ self, polygon_value
+ ):
+ _config_dict = {
+ "static": test_data.joinpath("graph", "test_osm_network_wrapper", "static"),
+ "network": {"polygon": polygon_value},
+ }
+ _osm_network = OsmNetworkWrapper(config=_config_dict, graph_crs="")
+ with pytest.raises(ValueError) as exc_err:
+ _osm_network.get_clean_graph_from_osm()
+
+ assert str(exc_err.value) == "No valid value provided for polygon file."
+
+ @pytest.fixture
+ def _valid_graph_fixture(self) -> MultiDiGraph:
+ _valid_graph = nx.MultiDiGraph()
+ _valid_graph.add_node(1, x=1, y=10)
+ _valid_graph.add_node(2, x=2, y=20)
+ _valid_graph.add_node(3, x=1, y=10)
+ _valid_graph.add_node(4, x=2, y=40)
+ _valid_graph.add_node(5, x=3, y=50)
+
+ _valid_graph.add_edge(1, 2, x=[1, 2], y=[10, 20])
+ _valid_graph.add_edge(1, 3, x=[1, 1], y=[10, 10])
+ _valid_graph.add_edge(2, 4, x=[2, 2], y=[20, 40])
+ _valid_graph.add_edge(3, 4, x=[1, 2], y=[10, 40])
+ _valid_graph.add_edge(1, 4, x=[1, 2], y=[10, 40])
+ _valid_graph.add_edge(5, 3, x=[3, 1], y=[50, 10])
+ _valid_graph.add_edge(5, 5, x=[3, 3], y=[50, 50])
+
+ # Add a valid CRS value.
+ _valid_graph.graph["crs"] = "EPSG:4326"
+
+ return _valid_graph
+
+ @pytest.fixture
+ def _expected_unique_graph_fixture(self) -> MultiDiGraph:
+ _valid_unique_graph = nx.MultiDiGraph()
+ _valid_unique_graph.add_node(1, x=1, y=10)
+ _valid_unique_graph.add_node(2, x=2, y=20)
+ _valid_unique_graph.add_node(4, x=2, y=40)
+ _valid_unique_graph.add_node(5, x=3, y=50)
+
+ _valid_unique_graph.add_edge(1, 2, x=[1, 2], y=[10, 20])
+ _valid_unique_graph.add_edge(2, 4, x=[2, 2], y=[20, 40])
+ _valid_unique_graph.add_edge(1, 4, x=[1, 2], y=[10, 40])
+ _valid_unique_graph.add_edge(5, 1, x=[3, 1], y=[50, 10])
+
+ return _valid_unique_graph
+
+ def test_drop_duplicates_in_nodes(
+ self,
+ _valid_graph_fixture: MultiDiGraph,
+ _expected_unique_graph_fixture: MultiDiGraph,
+ ):
+ unique_graph = OsmNetworkWrapper.drop_duplicates_in_nodes(
+ graph=_valid_graph_fixture, unique_elements=set()
+ )
+
+ assert unique_graph.nodes() == _expected_unique_graph_fixture.nodes()
+
+ @pytest.mark.parametrize(
+ "unique_elements",
+ [
+ pytest.param(None, id="None unique_elements"),
+ pytest.param(set(), id="Empty unique_elements"),
+ pytest.param({1, 2}, id="Non-tuple elements"),
+ ],
+ )
+ def test_drop_duplicates_in_edges_invalid_unique_elements_input(
+ self,
+ unique_elements: set,
+ _valid_graph_fixture: MultiDiGraph,
+ _expected_unique_graph_fixture: MultiDiGraph,
+ ):
+ with pytest.raises(ValueError) as exc_err:
+ OsmNetworkWrapper.drop_duplicates_in_edges(
+ graph=_valid_graph_fixture,
+ unique_elements=unique_elements,
+ unique_graph=None,
+ )
+
+ assert (
+ str(exc_err.value)
+ == """unique_elements cannot be None, empty, or have non-tuple elements.
+ Provide a set with all unique node coordinates as tuples of (x, y)"""
+ )
+
+ def test_drop_duplicates_in_edges_invalid_unique_graph_input(
+ self,
+ _valid_graph_fixture: MultiDiGraph,
+ _expected_unique_graph_fixture: MultiDiGraph,
+ ):
+ with pytest.raises(ValueError) as exc_err:
+ OsmNetworkWrapper.drop_duplicates_in_edges(
+ graph=_valid_graph_fixture, unique_elements={(1, 2)}, unique_graph=None
+ )
+
+ assert (
+ str(exc_err.value)
+ == """unique_graph cannot be None. Provide a graph with unique nodes or perform the
+ drop_duplicates_in_nodes on the graph to generate a unique_graph"""
+ )
+
+ def test_drop_duplicates_in_edges(
+ self,
+ _valid_graph_fixture: MultiDiGraph,
+ _expected_unique_graph_fixture: MultiDiGraph,
+ ):
+ # 1. Define test data.
+ unique_elements = {(1, 10), (2, 20), (4, 40), (5, 50)}
+
+ unique_graph = nx.MultiDiGraph()
+ unique_graph.add_node(1, x=1, y=10)
+ unique_graph.add_node(2, x=2, y=20)
+ unique_graph.add_node(4, x=2, y=40)
+ unique_graph.add_node(5, x=3, y=50)
+
+ # 2. Run test.
+ unique_graph = OsmNetworkWrapper.drop_duplicates_in_edges(
+ graph=_valid_graph_fixture,
+ unique_elements=unique_elements,
+ unique_graph=unique_graph,
+ )
+
+ # 3. Verify results
+ assert graphs_equal(unique_graph, _expected_unique_graph_fixture)
+
+ def test_snap_nodes_to_nodes(self, _valid_graph_fixture: MultiDiGraph):
+ # 1. Define test data.
+ _threshold = 0.00002
+
+ # 2. Run test.
+ _result_graph = OsmNetworkWrapper.snap_nodes_to_nodes(
+ _valid_graph_fixture, _threshold
+ )
+
+ # 3. Verify expectations.
+ assert isinstance(_result_graph, MultiDiGraph)
diff --git a/tests/graph/test_networks.py b/tests/graph/test_networks.py
index 9197a0259..8dd7481c0 100644
--- a/tests/graph/test_networks.py
+++ b/tests/graph/test_networks.py
@@ -1,6 +1,5 @@
import shutil
from typing import Iterator
-
import pytest
from ra2ce.graph.networks import Network
diff --git a/tests/graph/test_networks_utils.py b/tests/graph/test_networks_utils.py
index e2801170f..8f9146219 100644
--- a/tests/graph/test_networks_utils.py
+++ b/tests/graph/test_networks_utils.py
@@ -394,7 +394,7 @@ def test_with_valid_data(self):
# 3. Verify final expectations
assert _return_graph == _graph
- _items = list(_return_graph.edges.data(keys=True))
+ _items = list(_return_graph.edges.data(data=True))
assert len(_items) == 1
_data = _items[0][-1]
assert isinstance(_data, dict)
diff --git a/tests/test_data/graph/test_osm_network_wrapper/_test_polygon.PNG b/tests/test_data/graph/test_osm_network_wrapper/_test_polygon.PNG
new file mode 100644
index 000000000..1b09b25cf
Binary files /dev/null and b/tests/test_data/graph/test_osm_network_wrapper/_test_polygon.PNG differ
diff --git a/tests/test_data/graph/test_osm_network_wrapper/_test_polygon.geojson b/tests/test_data/graph/test_osm_network_wrapper/_test_polygon.geojson
new file mode 100644
index 000000000..6d9bf2645
--- /dev/null
+++ b/tests/test_data/graph/test_osm_network_wrapper/_test_polygon.geojson
@@ -0,0 +1,36 @@
+{
+ "type": "FeatureCollection",
+ "features": [
+ {
+ "type": "Feature",
+ "properties": {},
+ "geometry": {
+ "coordinates": [
+ [
+ [
+ 4.391052553972401,
+ 51.99092158517831
+ ],
+ [
+ 4.392773550123422,
+ 51.98700503937195
+ ],
+ [
+ 4.393053863048749,
+ 51.9870598530963
+ ],
+ [
+ 4.391344199966596,
+ 51.99101022946144
+ ],
+ [
+ 4.391052553972401,
+ 51.99092158517831
+ ]
+ ]
+ ],
+ "type": "Polygon"
+ }
+ }
+ ]
+}
\ No newline at end of file