From 431b950270f851fdccb2663961454fd4277251aa Mon Sep 17 00:00:00 2001 From: Robert Tuck Date: Tue, 22 Oct 2024 15:54:47 +0100 Subject: [PATCH] Fix various type linting issues --- src/dodal/common/signal_utils.py | 16 +++++++--------- src/dodal/devices/aperturescatterguard.py | 8 ++++---- src/dodal/devices/oav/utils.py | 2 +- .../unit_tests/test_aperture_scatterguard.py | 2 +- tests/devices/unit_tests/test_oav.py | 7 +++++-- 5 files changed, 18 insertions(+), 17 deletions(-) diff --git a/src/dodal/common/signal_utils.py b/src/dodal/common/signal_utils.py index 2c4c76aa2b..40cfad30cb 100644 --- a/src/dodal/common/signal_utils.py +++ b/src/dodal/common/signal_utils.py @@ -1,16 +1,14 @@ from collections.abc import Callable, Coroutine -from typing import Any, TypeVar +from typing import Any from bluesky.protocols import Reading -from ophyd_async.core import SignalR, SoftSignalBackend +from ophyd_async.core import SignalDatatypeT, SignalR, SoftSignalBackend -T = TypeVar("T") - -class HardwareBackedSoftSignalBackend(SoftSignalBackend[T]): +class HardwareBackedSoftSignalBackend(SoftSignalBackend[SignalDatatypeT]): def __init__( self, - get_from_hardware_func: Callable[[], Coroutine[Any, Any, T]], + get_from_hardware_func: Callable[[], Coroutine[Any, Any, SignalDatatypeT]], *args, **kwargs, ) -> None: @@ -25,14 +23,14 @@ async def get_reading(self) -> Reading: await self._update_value() return await super().get_reading() - async def get_value(self) -> T: + async def get_value(self) -> SignalDatatypeT: await self._update_value() return await super().get_value() def create_hardware_backed_soft_signal( - datatype: type[T], - get_from_hardware_func: Callable[[], Coroutine[Any, Any, T]], + datatype: type[SignalDatatypeT], + get_from_hardware_func: Callable[[], Coroutine[Any, Any, SignalDatatypeT]], units: str | None = None, precision: int | None = None, ): diff --git a/src/dodal/devices/aperturescatterguard.py b/src/dodal/devices/aperturescatterguard.py index c319e1dc78..83a6f6056d 100644 --- a/src/dodal/devices/aperturescatterguard.py +++ b/src/dodal/devices/aperturescatterguard.py @@ -27,7 +27,7 @@ class AperturePosition(BaseModel): aperture_z: float scatterguard_x: float scatterguard_y: float - radius: float | None = Field(json_schema_extra={"units": "µm"}, default=None) + radius: float = Field(json_schema_extra={"units": "µm"}, default=0.0) @property def values(self) -> tuple[float, float, float, float, float]: @@ -54,7 +54,7 @@ def tolerances_from_gda_params( @staticmethod def from_gda_params( name: ApertureValue, - radius: float | None, + radius: float, params: GDABeamlineParameters, ) -> AperturePosition: return AperturePosition( @@ -81,7 +81,7 @@ def load_positions_from_beamline_parameters( ) -> dict[ApertureValue, AperturePosition]: return { ApertureValue.ROBOT_LOAD: AperturePosition.from_gda_params( - ApertureValue.ROBOT_LOAD, None, params + ApertureValue.ROBOT_LOAD, 0, params ), ApertureValue.SMALL: AperturePosition.from_gda_params( ApertureValue.SMALL, 20, params @@ -172,7 +172,7 @@ async def _get_current_aperture_position(self) -> ApertureValue: raise InvalidApertureMove("Current aperture/scatterguard state unrecognised") - async def _get_current_radius(self) -> float | None: + async def _get_current_radius(self) -> float: current_value = await self._get_current_aperture_position() return self._loaded_positions[current_value].radius diff --git a/src/dodal/devices/oav/utils.py b/src/dodal/devices/oav/utils.py index f20e1b6e00..3a1c9d6767 100644 --- a/src/dodal/devices/oav/utils.py +++ b/src/dodal/devices/oav/utils.py @@ -106,4 +106,4 @@ def wait_for_tip_to_be_found( timeout = yield from bps.rd(ophyd_pin_tip_detection.validity_timeout) raise PinNotFoundException(f"No pin found after {timeout} seconds") - return found_tip # type: ignore + return Pixel((int(found_tip[0]), int(found_tip[1]))) # type: ignore diff --git a/tests/devices/unit_tests/test_aperture_scatterguard.py b/tests/devices/unit_tests/test_aperture_scatterguard.py index e4385a66c6..adb9da7bb4 100644 --- a/tests/devices/unit_tests/test_aperture_scatterguard.py +++ b/tests/devices/unit_tests/test_aperture_scatterguard.py @@ -174,7 +174,7 @@ async def test_aperture_unsafe_move( aperture_z=5.6, scatterguard_x=7.8, scatterguard_y=9.0, - radius=None, + radius=0, ) ap_sg = aperture_in_medium_pos await ap_sg._set_raw_unsafe(pos) diff --git a/tests/devices/unit_tests/test_oav.py b/tests/devices/unit_tests/test_oav.py index 614881e32c..c5a5e55d34 100644 --- a/tests/devices/unit_tests/test_oav.py +++ b/tests/devices/unit_tests/test_oav.py @@ -241,7 +241,7 @@ async def test_given_tip_found_when_wait_for_tip_to_be_found_called_then_tip_imm ) RE = RunEngine(call_returns_result=True) result = RE(wait_for_tip_to_be_found(mock_pin_tip_detect)) - assert all(result.plan_result == (100, 100)) + assert result.plan_result == (100, 100) # type: ignore mock_pin_tip_detect._get_tip_and_edge_data.assert_called_once() @@ -254,7 +254,10 @@ async def test_given_no_tip_when_wait_for_tip_to_be_found_called_then_exception_ await mock_pin_tip_detect.validity_timeout.set(0.2) mock_pin_tip_detect._get_tip_and_edge_data = AsyncMock( return_value=SampleLocation( - *PinTipDetection.INVALID_POSITION, np.array([]), np.array([]) + int(PinTipDetection.INVALID_POSITION[0]), + int(PinTipDetection.INVALID_POSITION[1]), + np.array([]), + np.array([]), ) ) RE = RunEngine(call_returns_result=True)