diff --git a/ra2ce/graph/networks.py b/ra2ce/graph/networks.py index 417be3ba8..04d4f2dd8 100644 --- a/ra2ce/graph/networks.py +++ b/ra2ce/graph/networks.py @@ -19,19 +19,18 @@ along with this program. If not, see . """ - import logging import os from typing import Any, List, Tuple import geopandas as gpd import networkx as nx -import osmnx import pandas as pd import pyproj from shapely.geometry import MultiLineString import ra2ce.graph.networks_utils as nut +from ra2ce.graph.osm_network_wrapper.osm_network_wrapper import OsmNetworkWrapper from ra2ce.graph.segmentation import Segmentation from ra2ce.io.readers import GraphPickleReader from ra2ce.io.writers import JsonExporter @@ -202,11 +201,15 @@ def network_shp( # Exporting complex graph because the shapefile should be kept the same as much as possible. return graph_complex, edges_complex - def _export_linking_tables(self, linking_tables: List[Any]) -> None: + def _export_linking_tables(self, linking_tables: list[Any]) -> None: _exporter = JsonExporter() _output_dir = self.config["static"] / "output_graph" - _exporter.export(_output_dir / "simple_to_complex.json", linking_tables[0]) - _exporter.export(_output_dir / "complex_to_simple.json", linking_tables[1]) + _exporter.export( + _output_dir.joinpath("simple_to_complex.json"), linking_tables[0] + ) + _exporter.export( + _output_dir.joinpath("complex_to_simple.json"), linking_tables[1] + ) def network_trails_import( self, crs: int = 4326 @@ -288,52 +291,15 @@ def network_trails_import( return graph_complex, edges_complex - def network_osm_download(self) -> Tuple[nx.classes.graph.Graph, gpd.GeoDataFrame]: - """Creates a network from a polygon by downloading via the OSM API in the extent of the polygon. + def network_osm_download(self) -> tuple[nx.classes.graph.Graph, gpd.GeoDataFrame]: + """ + Creates a network from a polygon by downloading via the OSM API in the extent of the polygon. Returns: - graph_simple (NetworkX graph): Simplified graph (for use in the indirect analyses). - complex_edges (GeoDataFrame): Complex graph (for use in the direct analyses). + tuple[nx.classes.graph.Graph, gpd.GeoDataFrame]: Tuple of Simplified graph (for use in the indirect analyses) and Complex graph (for use in the direct analyses). """ - poly_dict = nut.read_geojson( - self.config["network"]["polygon"][0] - ) # It can only read in one geojson - poly = nut.geojson_to_shp(poly_dict) - - if not self.config["network"]["road_types"]: - # The user specified only the network type. - graph_complex = osmnx.graph_from_polygon( - polygon=poly, - network_type=self.config["network"]["network_type"], - simplify=False, - retain_all=True, - ) - elif not self.config["network"]["network_type"]: - # The user specified only the road types. - cf = '["highway"~"{}"]'.format( - self.config["network"]["road_types"].replace(",", "|") - ) - graph_complex = osmnx.graph_from_polygon( - polygon=poly, custom_filter=cf, simplify=False, retain_all=True - ) - else: - # The user specified the network type and road types. - cf = '["highway"~"{}"]'.format( - self.config["network"]["road_types"].replace(",", "|") - ) - graph_complex = osmnx.graph_from_polygon( - polygon=poly, - network_type=self.config["network"]["network_type"], - custom_filter=cf, - simplify=False, - retain_all=True, - ) - - logging.info( - "graph downloaded from OSM with {:,} nodes and {:,} edges".format( - len(list(graph_complex.nodes())), len(list(graph_complex.edges())) - ) - ) + osm_network = OsmNetworkWrapper(self.config, "") + graph_complex = osm_network.get_clean_graph_from_osm() # Create 'graph_simple' graph_simple, graph_complex, link_tables = nut.create_simplified_graph( diff --git a/ra2ce/graph/networks_utils.py b/ra2ce/graph/networks_utils.py index 8dd671356..2cddcf0b3 100644 --- a/ra2ce/graph/networks_utils.py +++ b/ra2ce/graph/networks_utils.py @@ -19,7 +19,6 @@ along with this program. If not, see . """ - import itertools import logging import os @@ -1045,7 +1044,7 @@ def graph_check_create_unique_ids( ): i = 0 - for u, v, k in graph.edges(keys=True): + for u, v, k in graph.edges(data=True): graph[u][v][k][new_id_name] = i i += 1 logging.info( @@ -1080,10 +1079,10 @@ def add_missing_geoms_graph(graph: nx.Graph, geom_name: str = "geometry") -> nx. graph.nodes[nd][geom_name] = Point(graph.nodes[nd]["x"], graph.nodes[nd]["y"]) edges_without_geom = [ - e for e in graph.edges.data(keys=True) if geom_name not in e[-1] + e for e in graph.edges.data(data=True) if geom_name not in e[-1] ] for ed in edges_without_geom: - graph[ed[0]][ed[1]][ed[2]][geom_name] = LineString( + graph[ed[0]][ed[1]][0][geom_name] = LineString( [graph.nodes[ed[0]][geom_name], graph.nodes[ed[1]][geom_name]] ) diff --git a/ra2ce/graph/osm_network_wrapper/__init__.py b/ra2ce/graph/osm_network_wrapper/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ra2ce/graph/osm_network_wrapper/extremities_data.py b/ra2ce/graph/osm_network_wrapper/extremities_data.py new file mode 100644 index 000000000..6e4308de1 --- /dev/null +++ b/ra2ce/graph/osm_network_wrapper/extremities_data.py @@ -0,0 +1,79 @@ +from dataclasses import dataclass + +from networkx import MultiDiGraph + + +@dataclass +class ExtremitiesData: + from_id: int = None + to_id: int = None + from_to_id: tuple = None + to_from_id: tuple = None + from_to_coor: tuple = None + to_from_coor: tuple = None + + @staticmethod + def get_extremities_data_for_sub_graph(from_node_id: int, to_node_id: int, sub_graph: MultiDiGraph, + graph: MultiDiGraph, shared_elements: set): + """Both extremities should be in the unique_graph still makes an edge between similar node to u (the node + with u coordinates and different id, included in the unique_graph) and v Here, sub_graph is the unique_graph + and graph is complex_graph Shared elements are shared btw sub_graph and graph, which are elements to include + when dropping duplicates""" + if shared_elements is None or not isinstance(shared_elements, set): + raise ValueError("unique_elements should be a set") + if from_node_id in sub_graph.nodes() and to_node_id in sub_graph.nodes(): + return ExtremitiesData.arrange_extremities_data( + from_node_id=from_node_id, to_node_id=to_node_id, graph=sub_graph + ) + elif (graph.nodes[from_node_id]['x'], graph.nodes[from_node_id]['y']) in shared_elements and \ + to_node_id in sub_graph.nodes(): + + from_node_id_prime = ExtremitiesData.find_node_id_by_coor( + sub_graph, graph.nodes[from_node_id]['x'], graph.nodes[from_node_id]['y'] + ) + if from_node_id_prime == to_node_id: + return ExtremitiesData() + else: + return ExtremitiesData.arrange_extremities_data(from_node_id=from_node_id_prime, to_node_id=to_node_id, + graph=sub_graph) + + elif from_node_id in sub_graph.nodes() and \ + (graph.nodes[to_node_id]['x'], graph.nodes[to_node_id]['y']) in shared_elements: + + to_node_id_prime = ExtremitiesData.find_node_id_by_coor( + sub_graph, graph.nodes[to_node_id]['x'], graph.nodes[to_node_id]['y'] + ) + if from_node_id == to_node_id_prime: + return ExtremitiesData() + else: + return ExtremitiesData.arrange_extremities_data(from_node_id=from_node_id, to_node_id=to_node_id_prime, + graph=sub_graph) + else: + return ExtremitiesData() + + @staticmethod + def arrange_extremities_data(from_node_id: int, to_node_id: int, graph: MultiDiGraph): + return ExtremitiesData( + from_id=from_node_id, + to_id=to_node_id, + from_to_id=(from_node_id, to_node_id), + to_from_id=(to_node_id, from_node_id), + from_to_coor=( + (graph.nodes[from_node_id]['x'], graph.nodes[to_node_id]['x']), + (graph.nodes[from_node_id]['y'], graph.nodes[to_node_id]['y']) + ), + to_from_coor=( + (graph.nodes[to_node_id]['x'], graph.nodes[from_node_id]['x']), + (graph.nodes[to_node_id]['y'], graph.nodes[from_node_id]['y']) + ) + ) + + @staticmethod + def find_node_id_by_coor(graph: MultiDiGraph, target_x: float, target_y: float): + """ + finds the node in unique graph with the same coor + """ + for node, data in graph.nodes(data=True): + if 'x' in data and 'y' in data and data['x'] == target_x and data['y'] == target_y: + return node + return None diff --git a/ra2ce/graph/osm_network_wrapper/osm_network_wrapper.py b/ra2ce/graph/osm_network_wrapper/osm_network_wrapper.py new file mode 100644 index 000000000..e4b3cf6d4 --- /dev/null +++ b/ra2ce/graph/osm_network_wrapper/osm_network_wrapper.py @@ -0,0 +1,242 @@ +""" + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + Risk Assessment and Adaptation for Critical Infrastructure (RA2CE). + Copyright (C) 2023 Stichting Deltares + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" +import logging +from pathlib import Path + +import networkx as nx +import osmnx +from networkx import MultiDiGraph +from osmnx import consolidate_intersections +from shapely.geometry.base import BaseGeometry + +import ra2ce.graph.networks_utils as nut +from ra2ce.graph.osm_network_wrapper.extremities_data import ExtremitiesData + + +class OsmNetworkWrapper: + network_dict: dict + output_path: Path + graph_crs: str + + def __init__(self, config: dict, graph_crs: str) -> None: + if not config: + raise ValueError("Config cannot be None") + if not config.get("network", {}): + raise ValueError( + "A network dictionary is required for creating a OsmNetworkWrapper object." + ) + if not isinstance(config.get("network"), dict): + raise ValueError('Config["network"] should be a dictionary') + + self.network_dict = config["network"] + self.output_path = config["static"] / "output_graph" + self.graph_crs = graph_crs + if not self.graph_crs: + self.graph_crs = "epsg:4326" + + def get_clean_graph_from_osm(self) -> MultiDiGraph: + """ + Creates a network from a polygon by by downloading via the OSM API in its extent. + + Raises: + FileNotFoundError: When no valid polygon file is provided. + + Returns: + MultiDiGraph: Complex (clean) graph after download from OSM, for use in the direct analyses and input to derive simplified network. + """ + # It can only read in one geojson + if not self.network_dict.get("polygon", []): + raise ValueError("No valid value provided for polygon file.") + + polygon_file = self.output_path.parent.joinpath( + "network", self.network_dict.get("polygon", [])[0] + ) + if not polygon_file.is_file(): + raise FileNotFoundError("No polygon_file file found.") + + poly_dict = nut.read_geojson(geojson_file=polygon_file) + _complex_graph = self._download_clean_graph_from_osm( + polygon=nut.geojson_to_shp(poly_dict), + network_type=self.network_dict.get("network_type", ""), + road_types=self.network_dict.get("road_types", ""), + ) + return _complex_graph + + def _download_clean_graph_from_osm( + self, polygon: BaseGeometry, road_types: str, network_type: str + ) -> MultiDiGraph: + if not road_types and not network_type: + raise ValueError("Either of the link_type or network_type should be known") + elif not road_types: + # The user specified only the network type. + _complex_graph = osmnx.graph_from_polygon( + polygon=polygon, + network_type=network_type, + simplify=False, + retain_all=True, + ) + elif not network_type: + # The user specified only the road types. + cf = f'["highway"~"{road_types.replace(",", "|")}"]' + _complex_graph = osmnx.graph_from_polygon( + polygon=polygon, custom_filter=cf, simplify=False, retain_all=True + ) + else: + cf = f'["highway"~"{road_types.replace(",", "|")}"]' + _complex_graph = osmnx.graph_from_polygon( + polygon=polygon, + network_type=network_type, + custom_filter=cf, + simplify=False, + retain_all=True, + ) + + logging.info( + "graph downloaded from OSM with {:,} nodes and {:,} edges".format( + len(list(_complex_graph.nodes())), len(list(_complex_graph.edges())) + ) + ) + if "crs" not in _complex_graph.graph.keys(): + _complex_graph.graph["crs"] = self.graph_crs + self.get_clean_graph(_complex_graph) + return _complex_graph + + @staticmethod + def get_clean_graph(complex_graph: MultiDiGraph): + complex_graph = OsmNetworkWrapper.drop_duplicates(complex_graph) + complex_graph = nut.add_missing_geoms_graph( + graph=complex_graph, geom_name="geometry" + ).to_directed() + complex_graph = OsmNetworkWrapper.snap_nodes_to_nodes( + graph=complex_graph, threshold=0.000025 + ) + return complex_graph + + @staticmethod + def drop_duplicates(complex_graph: MultiDiGraph) -> MultiDiGraph: + unique_elements = ( + set() + ) # This gets updated during the drop_duplicates_in_nodes and drop_duplicates_in_edges + + unique_graph = OsmNetworkWrapper.drop_duplicates_in_nodes( + unique_elements=unique_elements, graph=complex_graph + ) + unique_graph = OsmNetworkWrapper.drop_duplicates_in_edges( + unique_elements=unique_elements, + unique_graph=unique_graph, + graph=complex_graph, + ) + return unique_graph + + @staticmethod + def drop_duplicates_in_nodes( + unique_elements: set, graph: MultiDiGraph + ) -> MultiDiGraph: + if unique_elements is None or not isinstance(unique_elements, set): + raise ValueError("unique_elements should be a set") + + unique_graph = nx.MultiDiGraph() + for node, data in graph.nodes(data=True): + if data["x"] is None or data["y"] is None: + raise ValueError( + "Incompatible coordinate keys. Check the keys which define the x and y coordinates" + ) + + x, y = data["x"], data["y"] + coord = (x, y) + if coord not in unique_elements: + node_attributes = {key: value for key, value in data.items()} + unique_graph.add_node(node, **node_attributes) + unique_elements.add(coord) + # Copy the graph dictionary from the source one. + unique_graph.graph = graph.graph + return unique_graph + + @staticmethod + def drop_duplicates_in_edges( + unique_elements: set, unique_graph: MultiDiGraph, graph: MultiDiGraph + ): + """ + Checks if both extremities are in the unique_graph (u has not the same coor of v, no line from u to itself is + allowed). Checks if an edge is already made between such extremities with the given id and coordinates before + considering it in the unique graph + """ + if ( + not unique_elements + or not any(unique_elements) + or not all(isinstance(item, tuple) for item in unique_elements) + ): + raise ValueError( + """unique_elements cannot be None, empty, or have non-tuple elements. + Provide a set with all unique node coordinates as tuples of (x, y)""" + ) + if unique_graph is None: + raise ValueError( + """unique_graph cannot be None. Provide a graph with unique nodes or perform the + drop_duplicates_in_nodes on the graph to generate a unique_graph""" + ) + + def validity_check(extremities_tuple) -> bool: + extremities = extremities_tuple[0] + return not ( + extremities.from_to_id is None + or extremities.from_to_coor is None + or extremities.from_id == extremities.to_id + ) + + def valid_extremity_data(u, v, data) -> tuple[ExtremitiesData, dict]: + _extremities_data = ExtremitiesData.get_extremities_data_for_sub_graph( + from_node_id=u, + to_node_id=v, + sub_graph=unique_graph, + graph=graph, + shared_elements=unique_elements, + ) + + return _extremities_data, data + + for _extremity_data, _edge_data in filter( + validity_check, + map(lambda edge: valid_extremity_data(*edge), graph.edges(data=True)), + ): + _id_combination = (_extremity_data.from_to_id, _extremity_data.to_from_id) + _coor_combination = ( + _extremity_data.from_to_coor, + _extremity_data.to_from_coor, + ) + if all( + _combination not in unique_elements + for _combination in [_id_combination, _coor_combination] + ): + edge_attributes = {key: value for key, value in _edge_data.items()} + unique_graph.add_edge( + _extremity_data.from_id, _extremity_data.to_id, **edge_attributes + ) + unique_elements.add(_id_combination) + unique_elements.add(_coor_combination) + + return unique_graph + + @staticmethod + def snap_nodes_to_nodes(graph: MultiDiGraph, threshold: float) -> MultiDiGraph: + return consolidate_intersections( + G=graph, rebuild_graph=True, tolerance=threshold, dead_ends=False + ) + + @staticmethod + def snap_nodes_to_edges(graph: MultiDiGraph, threshold: float): + raise NotImplementedError("Next thing to do!") diff --git a/tests/graph/osm_network_wrapper/__init__.py b/tests/graph/osm_network_wrapper/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/graph/osm_network_wrapper/test_osm_network_wrapper.py b/tests/graph/osm_network_wrapper/test_osm_network_wrapper.py new file mode 100644 index 000000000..638e202b1 --- /dev/null +++ b/tests/graph/osm_network_wrapper/test_osm_network_wrapper.py @@ -0,0 +1,298 @@ +import networkx as nx +import pytest +from networkx import Graph, MultiDiGraph +from networkx.utils import graphs_equal +from shapely.geometry import LineString, Polygon +from shapely.geometry.base import BaseGeometry + +from tests import test_data, slow_test +import ra2ce.graph.networks_utils as nut +from ra2ce.graph.osm_network_wrapper.osm_network_wrapper import OsmNetworkWrapper + + +class TestOsmNetworkWrapper: + @pytest.fixture + def _config_fixture(self) -> dict: + _test_dir = test_data / "graph" / "test_osm_network_wrapper" + + yield { + "static": _test_dir / "static", + "network": {"polygon": "_test_polygon.geojson"}, + "origins_destinations": { + "origins": None, + "destinations": None, + "origins_names": None, + "destinations_names": None, + "id_name_origin_destination": None, + "category": "dummy_category", + "region": "", + }, + "cleanup": {"snapping_threshold": None, "segmentation_length": None}, + } + + @pytest.mark.parametrize( + "config", + [ + pytest.param(None, id="None"), + pytest.param({}, id="{}"), + pytest.param({"network": {}}, id='"Config["network"]" = None'), + pytest.param({"network": "string"}, id='"invalid type(Config["network"])"'), + ], + ) + def test_osm_network_wrapper_initialisation_with_invalid_config(self, config: dict): + _files = [] + with pytest.raises(ValueError) as exc_err: + OsmNetworkWrapper(config=config, graph_crs="") + assert ( + str(exc_err.value) == "Config cannot be None" + or "A network dictionary is required for creating a OsmNetworkWrapper object" + or 'Config["network"] should be a dictionary' + ) + + def test__download_clean_graph_from_osm_with_invalid_polygon_arg( + self, _config_fixture: dict + ): + _osm_network = OsmNetworkWrapper(config=_config_fixture, graph_crs="") + _polygon = None + _link_type = "road_link" + _network_type = "drive" + with pytest.raises(AttributeError) as exc_err: + _osm_network._download_clean_graph_from_osm( + _polygon, _link_type, _network_type + ) + + assert str(exc_err.value) == "'NoneType' object has no attribute 'is_valid'" + + def test__download_clean_graph_from_osm_with_invalid_polygon_arg_geometry( + self, _config_fixture: dict + ): + _osm_network = OsmNetworkWrapper(config=_config_fixture, graph_crs="") + _polygon = LineString([[0, 0], [1, 0], [1, 1]]) + _link_type = "road_link" + _network_type = "drive" + with pytest.raises(TypeError) as exc_err: + _osm_network._download_clean_graph_from_osm( + _polygon, _link_type, _network_type + ) + + assert ( + str(exc_err.value) + == "Geometry must be a shapely Polygon or MultiPolygon. If you requested graph from place name, make sure your query resolves to a Polygon or MultiPolygon, and not some other geometry, like a Point. See OSMnx documentation for details." + ) + + def test__download_clean_graph_from_osm_with_invalid_network_type_arg( + self, _config_fixture: dict + ): + _osm_network = OsmNetworkWrapper(config=_config_fixture, graph_crs="") + _polygon = Polygon([(0.0, 0.0), (0.0, 1.0), (1.0, 1.0), (1.0, 0.0), (0.0, 0.0)]) + _link_type = "" + _network_type = "drv" + with pytest.raises(ValueError) as exc_err: + _osm_network._download_clean_graph_from_osm( + _polygon, _link_type, _network_type + ) + + assert str(exc_err.value) == f'Unrecognized network_type "{_network_type}"' + + @pytest.fixture + def _valid_network_polygon_fixture(self) -> BaseGeometry: + _test_input_directory = test_data.joinpath("graph", "test_osm_network_wrapper") + _polygon_file = _test_input_directory.joinpath("_test_polygon.geojson") + assert _polygon_file.exists() + _polygon_dict = nut.read_geojson(_polygon_file) + yield nut.geojson_to_shp(_polygon_dict) + + @slow_test + def test__download_clean_graph_from_osm_output( + self, _config_fixture: dict, _valid_network_polygon_fixture: BaseGeometry + ): + # 1. Define test data. + _osm_network = OsmNetworkWrapper(config=_config_fixture, graph_crs="") + _link_type = "" + _network_type = "drive" + + # 2. Run test. + graph_complex = _osm_network._download_clean_graph_from_osm( + polygon=_valid_network_polygon_fixture, + network_type=_network_type, + road_types=_link_type, + ) + + # 3. Verify expectations + # reference: https://www.openstreetmap.org/node/1402598729#map=17/51.98816/4.39126&layers=T + _osm_id_node_to_validate = 4987298323 + # reference: https://www.openstreetmap.org/way/334316041#map=19/51.98945/4.39166&layers=T + _osm_id_edge_to_validate = 334316041 + + assert isinstance(graph_complex, Graph) + assert _osm_id_node_to_validate in list(graph_complex.nodes.keys()) + assert _osm_id_edge_to_validate in list( + map(lambda x: x["osmid"], graph_complex.edges.values()) + ) + + @pytest.mark.parametrize( + "polygon_values", + [ + pytest.param([""], id="Empty polygon file value"), + pytest.param(["Not a valid name"], id="Invalid polygon file name"), + ], + ) + def test_get_clean_graph_from_osm_with_invalid_polygon_parameter_filename( + self, polygon_values + ): + _config_dict = { + "static": test_data.joinpath("graph", "test_osm_network_wrapper", "static"), + "network": {"polygon": polygon_values}, + } + _osm_network = OsmNetworkWrapper(config=_config_dict, graph_crs="") + with pytest.raises(FileNotFoundError) as exc_err: + _osm_network.get_clean_graph_from_osm() + + assert str(exc_err.value) == "No polygon_file file found." + + @pytest.mark.parametrize( + "polygon_value", + [ + pytest.param(None, id="None polygon file"), + pytest.param([], id="Invalid polygon file name"), + ], + ) + def test_get_clean_graph_from_osm_with_invalid_polygon_parameter( + self, polygon_value + ): + _config_dict = { + "static": test_data.joinpath("graph", "test_osm_network_wrapper", "static"), + "network": {"polygon": polygon_value}, + } + _osm_network = OsmNetworkWrapper(config=_config_dict, graph_crs="") + with pytest.raises(ValueError) as exc_err: + _osm_network.get_clean_graph_from_osm() + + assert str(exc_err.value) == "No valid value provided for polygon file." + + @pytest.fixture + def _valid_graph_fixture(self) -> MultiDiGraph: + _valid_graph = nx.MultiDiGraph() + _valid_graph.add_node(1, x=1, y=10) + _valid_graph.add_node(2, x=2, y=20) + _valid_graph.add_node(3, x=1, y=10) + _valid_graph.add_node(4, x=2, y=40) + _valid_graph.add_node(5, x=3, y=50) + + _valid_graph.add_edge(1, 2, x=[1, 2], y=[10, 20]) + _valid_graph.add_edge(1, 3, x=[1, 1], y=[10, 10]) + _valid_graph.add_edge(2, 4, x=[2, 2], y=[20, 40]) + _valid_graph.add_edge(3, 4, x=[1, 2], y=[10, 40]) + _valid_graph.add_edge(1, 4, x=[1, 2], y=[10, 40]) + _valid_graph.add_edge(5, 3, x=[3, 1], y=[50, 10]) + _valid_graph.add_edge(5, 5, x=[3, 3], y=[50, 50]) + + # Add a valid CRS value. + _valid_graph.graph["crs"] = "EPSG:4326" + + return _valid_graph + + @pytest.fixture + def _expected_unique_graph_fixture(self) -> MultiDiGraph: + _valid_unique_graph = nx.MultiDiGraph() + _valid_unique_graph.add_node(1, x=1, y=10) + _valid_unique_graph.add_node(2, x=2, y=20) + _valid_unique_graph.add_node(4, x=2, y=40) + _valid_unique_graph.add_node(5, x=3, y=50) + + _valid_unique_graph.add_edge(1, 2, x=[1, 2], y=[10, 20]) + _valid_unique_graph.add_edge(2, 4, x=[2, 2], y=[20, 40]) + _valid_unique_graph.add_edge(1, 4, x=[1, 2], y=[10, 40]) + _valid_unique_graph.add_edge(5, 1, x=[3, 1], y=[50, 10]) + + return _valid_unique_graph + + def test_drop_duplicates_in_nodes( + self, + _valid_graph_fixture: MultiDiGraph, + _expected_unique_graph_fixture: MultiDiGraph, + ): + unique_graph = OsmNetworkWrapper.drop_duplicates_in_nodes( + graph=_valid_graph_fixture, unique_elements=set() + ) + + assert unique_graph.nodes() == _expected_unique_graph_fixture.nodes() + + @pytest.mark.parametrize( + "unique_elements", + [ + pytest.param(None, id="None unique_elements"), + pytest.param(set(), id="Empty unique_elements"), + pytest.param({1, 2}, id="Non-tuple elements"), + ], + ) + def test_drop_duplicates_in_edges_invalid_unique_elements_input( + self, + unique_elements: set, + _valid_graph_fixture: MultiDiGraph, + _expected_unique_graph_fixture: MultiDiGraph, + ): + with pytest.raises(ValueError) as exc_err: + OsmNetworkWrapper.drop_duplicates_in_edges( + graph=_valid_graph_fixture, + unique_elements=unique_elements, + unique_graph=None, + ) + + assert ( + str(exc_err.value) + == """unique_elements cannot be None, empty, or have non-tuple elements. + Provide a set with all unique node coordinates as tuples of (x, y)""" + ) + + def test_drop_duplicates_in_edges_invalid_unique_graph_input( + self, + _valid_graph_fixture: MultiDiGraph, + _expected_unique_graph_fixture: MultiDiGraph, + ): + with pytest.raises(ValueError) as exc_err: + OsmNetworkWrapper.drop_duplicates_in_edges( + graph=_valid_graph_fixture, unique_elements={(1, 2)}, unique_graph=None + ) + + assert ( + str(exc_err.value) + == """unique_graph cannot be None. Provide a graph with unique nodes or perform the + drop_duplicates_in_nodes on the graph to generate a unique_graph""" + ) + + def test_drop_duplicates_in_edges( + self, + _valid_graph_fixture: MultiDiGraph, + _expected_unique_graph_fixture: MultiDiGraph, + ): + # 1. Define test data. + unique_elements = {(1, 10), (2, 20), (4, 40), (5, 50)} + + unique_graph = nx.MultiDiGraph() + unique_graph.add_node(1, x=1, y=10) + unique_graph.add_node(2, x=2, y=20) + unique_graph.add_node(4, x=2, y=40) + unique_graph.add_node(5, x=3, y=50) + + # 2. Run test. + unique_graph = OsmNetworkWrapper.drop_duplicates_in_edges( + graph=_valid_graph_fixture, + unique_elements=unique_elements, + unique_graph=unique_graph, + ) + + # 3. Verify results + assert graphs_equal(unique_graph, _expected_unique_graph_fixture) + + def test_snap_nodes_to_nodes(self, _valid_graph_fixture: MultiDiGraph): + # 1. Define test data. + _threshold = 0.00002 + + # 2. Run test. + _result_graph = OsmNetworkWrapper.snap_nodes_to_nodes( + _valid_graph_fixture, _threshold + ) + + # 3. Verify expectations. + assert isinstance(_result_graph, MultiDiGraph) diff --git a/tests/graph/test_networks.py b/tests/graph/test_networks.py index 9197a0259..8dd7481c0 100644 --- a/tests/graph/test_networks.py +++ b/tests/graph/test_networks.py @@ -1,6 +1,5 @@ import shutil from typing import Iterator - import pytest from ra2ce.graph.networks import Network diff --git a/tests/graph/test_networks_utils.py b/tests/graph/test_networks_utils.py index e2801170f..8f9146219 100644 --- a/tests/graph/test_networks_utils.py +++ b/tests/graph/test_networks_utils.py @@ -394,7 +394,7 @@ def test_with_valid_data(self): # 3. Verify final expectations assert _return_graph == _graph - _items = list(_return_graph.edges.data(keys=True)) + _items = list(_return_graph.edges.data(data=True)) assert len(_items) == 1 _data = _items[0][-1] assert isinstance(_data, dict) diff --git a/tests/test_data/graph/test_osm_network_wrapper/_test_polygon.PNG b/tests/test_data/graph/test_osm_network_wrapper/_test_polygon.PNG new file mode 100644 index 000000000..1b09b25cf Binary files /dev/null and b/tests/test_data/graph/test_osm_network_wrapper/_test_polygon.PNG differ diff --git a/tests/test_data/graph/test_osm_network_wrapper/_test_polygon.geojson b/tests/test_data/graph/test_osm_network_wrapper/_test_polygon.geojson new file mode 100644 index 000000000..6d9bf2645 --- /dev/null +++ b/tests/test_data/graph/test_osm_network_wrapper/_test_polygon.geojson @@ -0,0 +1,36 @@ +{ + "type": "FeatureCollection", + "features": [ + { + "type": "Feature", + "properties": {}, + "geometry": { + "coordinates": [ + [ + [ + 4.391052553972401, + 51.99092158517831 + ], + [ + 4.392773550123422, + 51.98700503937195 + ], + [ + 4.393053863048749, + 51.9870598530963 + ], + [ + 4.391344199966596, + 51.99101022946144 + ], + [ + 4.391052553972401, + 51.99092158517831 + ] + ] + ], + "type": "Polygon" + } + } + ] +} \ No newline at end of file