diff --git a/src/access_py_telemetry/api.py b/src/access_py_telemetry/api.py index 360cdcd..c222a66 100644 --- a/src/access_py_telemetry/api.py +++ b/src/access_py_telemetry/api.py @@ -11,6 +11,7 @@ import httpx import asyncio import pydantic +import re import yaml import multiprocessing from pathlib import Path, PurePosixPath @@ -171,9 +172,7 @@ def send_api_request( f"Endpoint for '{service_name}' not found in {self.endpoints}" ) from e - endpoint = str(PurePosixPath(self.server_url) / endpoint.lstrip("/")).replace( - "http:/", "http://" - ) + endpoint = _format_endpoint(self.server_url, endpoint) send_in_loop(endpoint, telemetry_data, self._request_timeout) return None @@ -377,3 +376,12 @@ def _run_in_proc(endpoint: str, telemetry_data: dict[str, Any], timeout: float) stacklevel=2, ) return None + + +def _format_endpoint(server_url: str, endpoint: str) -> str: + """ + Concatenates the server URL and endpoint, ensuring that there is only one + slash between them. + """ + endpoint = str(PurePosixPath(server_url) / endpoint.lstrip("/")) + return re.sub(r"^(https?:/)", r"\1/", endpoint) diff --git a/tests/test_api.py b/tests/test_api.py index 47bd0cb..7b8728f 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -4,7 +4,12 @@ """Tests for `access_py_telemetry` package.""" import access_py_telemetry.api -from access_py_telemetry.api import SessionID, ApiHandler, send_in_loop +from access_py_telemetry.api import ( + SessionID, + ApiHandler, + send_in_loop, + _format_endpoint, +) from pydantic import ValidationError import pytest @@ -281,3 +286,32 @@ def test_api_handler_set_timeout(api_handler): api_handler.request_timeout = None assert api_handler.request_timeout is None + + +@pytest.mark.parametrize( + "server_url, endpoint, expected", + [ + ( + "http://localhost:8000", + "/some/endpoint", + "http://localhost:8000/some/endpoint", + ), + ( + "http://localhost:8000/", + "some/endpoint/", + "http://localhost:8000/some/endpoint", + ), + ( + "https://localhost:8000", + "/some/endpoint", + "https://localhost:8000/some/endpoint", + ), + ( + "https://localhost:8000/", + "some/endpoint/", + "https://localhost:8000/some/endpoint", + ), + ], +) +def test_format_endpoint(server_url, endpoint, expected): + assert _format_endpoint(server_url, endpoint) == expected