From 049fde473e7bef56cc3e9c3a67a1e2e8d7f313b8 Mon Sep 17 00:00:00 2001 From: Charles Turner Date: Fri, 24 Jan 2025 14:29:33 +0800 Subject: [PATCH] Allow for https --- src/access_py_telemetry/api.py | 14 ++++++++++--- tests/test_api.py | 36 +++++++++++++++++++++++++++++++++- 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/src/access_py_telemetry/api.py b/src/access_py_telemetry/api.py index 360cdcd..c222a66 100644 --- a/src/access_py_telemetry/api.py +++ b/src/access_py_telemetry/api.py @@ -11,6 +11,7 @@ import httpx import asyncio import pydantic +import re import yaml import multiprocessing from pathlib import Path, PurePosixPath @@ -171,9 +172,7 @@ def send_api_request( f"Endpoint for '{service_name}' not found in {self.endpoints}" ) from e - endpoint = str(PurePosixPath(self.server_url) / endpoint.lstrip("/")).replace( - "http:/", "http://" - ) + endpoint = _format_endpoint(self.server_url, endpoint) send_in_loop(endpoint, telemetry_data, self._request_timeout) return None @@ -377,3 +376,12 @@ def _run_in_proc(endpoint: str, telemetry_data: dict[str, Any], timeout: float) stacklevel=2, ) return None + + +def _format_endpoint(server_url: str, endpoint: str) -> str: + """ + Concatenates the server URL and endpoint, ensuring that there is only one + slash between them. + """ + endpoint = str(PurePosixPath(server_url) / endpoint.lstrip("/")) + return re.sub(r"^(https?:/)", r"\1/", endpoint) diff --git a/tests/test_api.py b/tests/test_api.py index 47bd0cb..7b8728f 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -4,7 +4,12 @@ """Tests for `access_py_telemetry` package.""" import access_py_telemetry.api -from access_py_telemetry.api import SessionID, ApiHandler, send_in_loop +from access_py_telemetry.api import ( + SessionID, + ApiHandler, + send_in_loop, + _format_endpoint, +) from pydantic import ValidationError import pytest @@ -281,3 +286,32 @@ def test_api_handler_set_timeout(api_handler): api_handler.request_timeout = None assert api_handler.request_timeout is None + + +@pytest.mark.parametrize( + "server_url, endpoint, expected", + [ + ( + "http://localhost:8000", + "/some/endpoint", + "http://localhost:8000/some/endpoint", + ), + ( + "http://localhost:8000/", + "some/endpoint/", + "http://localhost:8000/some/endpoint", + ), + ( + "https://localhost:8000", + "/some/endpoint", + "https://localhost:8000/some/endpoint", + ), + ( + "https://localhost:8000/", + "some/endpoint/", + "https://localhost:8000/some/endpoint", + ), + ], +) +def test_format_endpoint(server_url, endpoint, expected): + assert _format_endpoint(server_url, endpoint) == expected