Skip to content

Commit

Permalink
Allow for https
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-turner-1 committed Jan 24, 2025
1 parent e0241c6 commit 049fde4
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 4 deletions.
14 changes: 11 additions & 3 deletions src/access_py_telemetry/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import httpx
import asyncio
import pydantic
import re
import yaml
import multiprocessing
from pathlib import Path, PurePosixPath
Expand Down Expand Up @@ -171,9 +172,7 @@ def send_api_request(
f"Endpoint for '{service_name}' not found in {self.endpoints}"
) from e

endpoint = str(PurePosixPath(self.server_url) / endpoint.lstrip("/")).replace(
"http:/", "http://"
)
endpoint = _format_endpoint(self.server_url, endpoint)

send_in_loop(endpoint, telemetry_data, self._request_timeout)
return None
Expand Down Expand Up @@ -377,3 +376,12 @@ def _run_in_proc(endpoint: str, telemetry_data: dict[str, Any], timeout: float)
stacklevel=2,
)
return None


def _format_endpoint(server_url: str, endpoint: str) -> str:
"""
Concatenates the server URL and endpoint, ensuring that there is only one
slash between them.
"""
endpoint = str(PurePosixPath(server_url) / endpoint.lstrip("/"))
return re.sub(r"^(https?:/)", r"\1/", endpoint)
36 changes: 35 additions & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
"""Tests for `access_py_telemetry` package."""

import access_py_telemetry.api
from access_py_telemetry.api import SessionID, ApiHandler, send_in_loop
from access_py_telemetry.api import (
SessionID,
ApiHandler,
send_in_loop,
_format_endpoint,
)
from pydantic import ValidationError
import pytest

Expand Down Expand Up @@ -281,3 +286,32 @@ def test_api_handler_set_timeout(api_handler):
api_handler.request_timeout = None

assert api_handler.request_timeout is None


@pytest.mark.parametrize(
"server_url, endpoint, expected",
[
(
"http://localhost:8000",
"/some/endpoint",
"http://localhost:8000/some/endpoint",
),
(
"http://localhost:8000/",
"some/endpoint/",
"http://localhost:8000/some/endpoint",
),
(
"https://localhost:8000",
"/some/endpoint",
"https://localhost:8000/some/endpoint",
),
(
"https://localhost:8000/",
"some/endpoint/",
"https://localhost:8000/some/endpoint",
),
],
)
def test_format_endpoint(server_url, endpoint, expected):
assert _format_endpoint(server_url, endpoint) == expected

0 comments on commit 049fde4

Please sign in to comment.