Skip to content

Commit

Permalink
Updates to tests to make sure no test pollution is occuring & add som…
Browse files Browse the repository at this point in the history
…e Payu related changes to them
  • Loading branch information
charles-turner-1 committed Jan 10, 2025
1 parent 378f0e1 commit 10b24e7
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 36 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ jobs:
shell: bash -l {0}
run: conda list

- name: Run tests
- name: Run tests with randomised order
shell: bash -l {0}
run: coverage run -m --source=access_py_telemetry pytest tests
run: coverage run -m --source=access_py_telemetry pytest --random-order tests

- name: Generate coverage report
shell: bash -l {0}
Expand Down
4 changes: 3 additions & 1 deletion src/access_py_telemetry/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
config = yaml.safe_load(f)

ENDPOINTS = {registry: content.get("endpoint") for registry, content in config.items()}
REGISTRIES = {registry for registry in config.keys()}
SERVER_URL = "https://tracking-services-d6c2fd311c12.herokuapp.com"


Expand Down Expand Up @@ -49,6 +50,7 @@ def __init__(
self._initialized = True
self._server_url = SERVER_URL
self.endpoints = ENDPOINTS
self.registries = REGISTRIES
self._extra_fields: dict[str, dict[str, Any]] = {
ep_name: {} for ep_name in self.endpoints.keys()
}
Expand Down Expand Up @@ -148,7 +150,7 @@ async def send_telemetry(endpoint: str, data: dict[str, Any]) -> None:
print(f"Posting telemetry to {endpoint}")
response = await client.post(endpoint, json=data, headers=headers)
response.raise_for_status()
except httpx.RequestError as e:
except (httpx.RequestError, httpx.HTTPStatusError) as e:
warnings.warn(
f"Request failed: {e}", category=RuntimeWarning, stacklevel=2
)
Expand Down
44 changes: 25 additions & 19 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env python
# type: ignore

"""Tests for `access_py_telemetry` package."""

Expand All @@ -15,7 +16,7 @@ def local_host():

@pytest.fixture
def default_url():
return "https://tracking-services-d6c2fd311c12.herokuapp.com"
return access_py_telemetry.api.SERVER_URL


def test_session_id_properties():
Expand All @@ -41,6 +42,8 @@ def test_api_handler_server_url(local_host, default_url):
"""
Check that the APIHandler class is a singleton.
"""
ApiHandler._instance = None

session1 = ApiHandler()
session2 = ApiHandler()

Expand All @@ -53,12 +56,16 @@ def test_api_handler_server_url(local_host, default_url):
session1.server_url = local_host
assert session2.server_url == local_host

ApiHandler._instance = None


def test_api_handler_extra_fields(local_host):
"""
Check that adding extra fields to the APIHandler class works as expected.
"""

ApiHandler._instance = None

session1 = ApiHandler()
session2 = ApiHandler()

Expand All @@ -71,7 +78,9 @@ def test_api_handler_extra_fields(local_host):

session1.add_extra_fields("catalog", {"version": "1.0"})

assert session2.extra_fields == {"catalog": {"version": "1.0"}}
blank_registries = {key: {} for key in session1.registries if key != "catalog"}

assert session2.extra_fields == {"catalog": {"version": "1.0"}, **blank_registries}

with pytest.raises(KeyError) as excinfo:
session1.add_extra_fields("catalogue", {"version": "2.0"})
Expand All @@ -83,8 +92,7 @@ def test_api_handler_extra_fields(local_host):
assert session1.server_url == local_host
assert session3.server_url == local_host

# Reset the server URL to avoid breaking other tests
session1.server_url = default_url
ApiHandler._instance = None


def test_api_handler_extra_fields_validation():
Expand All @@ -93,6 +101,7 @@ def test_api_handler_extra_fields_validation():
to pass the correct types, and only let us update fields through the
add_extra_field method.
"""
ApiHandler._instance = None
api_handler = ApiHandler()

# Mock a couple of extra services
Expand Down Expand Up @@ -123,11 +132,14 @@ def test_api_handler_extra_fields_validation():
ep_name: {} for ep_name in api_handler.endpoints.keys()
}

ApiHandler._instance = None


def test_api_handler_remove_fields():
"""
Check that we can remove fields from the telemetry record.
"""
ApiHandler._instance = None
api_handler = ApiHandler()

# Pretend we only have catalog & payu services and then mock the initialisation
Expand Down Expand Up @@ -164,21 +176,17 @@ def test_api_handler_remove_fields():

assert api_handler._pop_fields == {"payu": ["session_id"]}

# Reset endpoints to avoid breaking other tests - we have to be careful here
# because we're using a singleton
api_handler.endpoints = access_py_telemetry.api.ENDPOINTS
api_handler._extra_fields = {
ep_name: {} for ep_name in api_handler.endpoints.keys()
}
api_handler._pop_fields = {}
ApiHandler._instance = None


def test_api_handler_send_api_request_no_loop():
def test_api_handler_send_api_request_no_loop(local_host):
"""
Create and send an API request with telemetry data.
"""

ApiHandler._instance = None
api_handler = ApiHandler()
api_handler.server_url = local_host

# Pretend we only have catalog & payu services and then mock the initialisation
# of the _extra_fields attribute
Expand Down Expand Up @@ -217,20 +225,15 @@ def test_api_handler_send_api_request_no_loop():
"random_number": 2,
}

# Reset endpoints to avoid breaking other tests - we have to be careful here
# because we're using a singleton
api_handler.endpoints = access_py_telemetry.api.ENDPOINTS
api_handler._extra_fields = {
ep_name: {} for ep_name in api_handler.endpoints.keys()
}
api_handler._pop_fields = {}
ApiHandler._instance = None


def test_api_handler_invalid_endpoint():
"""
Create and send an API request with telemetry data.
"""

ApiHandler._instance = None
api_handler = ApiHandler()

# Pretend we only have catalog & payu services and then mock the initialisation
Expand All @@ -253,3 +256,6 @@ def test_api_handler_invalid_endpoint():
)

assert "Endpoint for 'payu' not found " in str(excinfo.value)

ApiHandler._instance = None
api_handler._instance = None
1 change: 1 addition & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# type: ignore
import pytest
from unittest import mock

Expand Down
23 changes: 15 additions & 8 deletions tests/test_decorators.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# type: ignore

from access_py_telemetry.decorators import ipy_register_func, register_func
from access_py_telemetry.registry import TelemetryRegister
from access_py_telemetry.api import ApiHandler
Expand All @@ -22,9 +24,11 @@ def my_func():

register = TelemetryRegister("catalog")
api_handler = ApiHandler()
blank_registries = {key: {} for key in api_handler.registries if key != "catalog"}

assert api_handler.extra_fields == {
"catalog": {"model": "ACCESS-OM2", "random_number": 2}
"catalog": {"model": "ACCESS-OM2", "random_number": 2},
**blank_registries,
}

assert api_handler.pop_fields == {"catalog": ["session_id"]}
Expand All @@ -35,9 +39,9 @@ def my_func():

register.deregister(my_func.__name__)

# Reset the api_handler to avoid breaking other tests

api_handler._extra_fields = {}
# Reset the api_handler and register to avoid breaking other tests
api_handler._instance = None
register._instances = {}


@pytest.mark.asyncio
Expand All @@ -57,8 +61,11 @@ def my_func():
register = TelemetryRegister("catalog")
api_handler = ApiHandler()

blank_registries = {key: {} for key in api_handler.registries if key != "catalog"}

assert api_handler.extra_fields == {
"catalog": {"model": "ACCESS-OM2", "random_number": 2}
"catalog": {"model": "ACCESS-OM2", "random_number": 2},
**blank_registries,
}

assert api_handler.pop_fields == {"catalog": ["session_id"]}
Expand All @@ -81,6 +88,6 @@ def my_func():

register.deregister(my_func.__name__)

# Reset the api_handler to avoid breaking other tests

api_handler._extra_fields = {}
# Reset the api_handler and register to avoid breaking other tests
api_handler._instance = None
register._instances = {}
11 changes: 5 additions & 6 deletions tests/test_registry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env python
# type: ignore

"""Tests for `access_py_telemetry` package."""

Expand All @@ -12,6 +13,7 @@ def test_telemetry_register_unique():
Check that the TelemetryRegister class is a singleton & that we can register
and deregister functions as we would expect.
"""
TelemetryRegister._instances = {}
session1 = TelemetryRegister("catalog")
session2 = TelemetryRegister("catalog")

Expand Down Expand Up @@ -46,12 +48,11 @@ def test_telemetry_register_unique():
"DfFileCatalog.search",
}

from access_py_telemetry.registry import REGISTRIES

session1.registry = REGISTRIES["catalog"]
TelemetryRegister._instances = {}


def test_telemetry_register_validation():
TelemetryRegister._instances = {}
session_register = TelemetryRegister("catalog")

with pytest.raises(ValidationError):
Expand Down Expand Up @@ -88,6 +89,4 @@ def test_function():

assert "test_function" not in session_register

from access_py_telemetry.registry import REGISTRIES

session_register.registry = REGISTRIES["catalog"]
TelemetryRegister._instances = {}

0 comments on commit 10b24e7

Please sign in to comment.