From 32dd6e1d88ff1d23af8e2fb2ee0da07903c0d5ea Mon Sep 17 00:00:00 2001 From: "Ware, Joseph (DLSLtd,RAL,LSCI)" Date: Thu, 24 Oct 2024 10:06:58 +0100 Subject: [PATCH] Move plans from deprecated module --- pyproject.toml | 2 + src/dodal/plan_stubs/wrapped.py | 151 +++++++++++++++++++++++++ src/dodal/plans/__init__.py | 4 + src/dodal/plans/scanspec.py | 65 +++++++++++ src/dodal/plans/wrapped.py | 42 +++++++ tests/plan_stubs/test_wrapped_stubs.py | 145 ++++++++++++++++++++++++ tests/plans/test_compliance.py | 69 +++++++++++ tests/plans/test_scanspec.py | 98 ++++++++++++++++ 8 files changed, 576 insertions(+) create mode 100644 src/dodal/plan_stubs/wrapped.py create mode 100644 src/dodal/plans/__init__.py create mode 100644 src/dodal/plans/scanspec.py create mode 100644 src/dodal/plans/wrapped.py create mode 100644 tests/plan_stubs/test_wrapped_stubs.py create mode 100644 tests/plans/test_compliance.py create mode 100644 tests/plans/test_scanspec.py diff --git a/pyproject.toml b/pyproject.toml index 9202389772..31d593261b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "aiohttp", "redis", "deepdiff", + "scanspec>=0.7.3", ] dynamic = ["version"] @@ -47,6 +48,7 @@ dev = [ # Commented out due to dependency version conflict with pydantic 1.x # "copier", "myst-parser", + "ophyd_async[sim]", "pipdeptree", "pre-commit", "psutil", diff --git a/src/dodal/plan_stubs/wrapped.py b/src/dodal/plan_stubs/wrapped.py new file mode 100644 index 0000000000..3f46f2d6b2 --- /dev/null +++ b/src/dodal/plan_stubs/wrapped.py @@ -0,0 +1,151 @@ +import itertools +from collections.abc import Mapping +from typing import Annotated, Any + +import bluesky.plan_stubs as bps +from bluesky.protocols import Movable +from bluesky.utils import MsgGenerator + +""" +Wrappers for Bluesky built-in plan stubs with type hinting +""" + +Group = Annotated[str, "String identifier used by 'wait' or stubs that await"] + + +# https://github.com/bluesky/bluesky/issues/1821 +def set_absolute( + movable: Movable, value: Any, group: Group | None = None, wait: bool = False +) -> MsgGenerator: + """ + Set a device, wrapper for `bp.abs_set`. + + Args: + movable (Movable): The device to set + value (T): The new value + group (Group | None, optional): The message group to associate with the + setting, for sequencing. Defaults to None. + wait (bool, optional): The group should wait until all setting is complete + (e.g. a motor has finished moving). Defaults to False. + + Returns: + MsgGenerator: Plan + + Yields: + Iterator[MsgGenerator]: Bluesky messages + """ + return (yield from bps.abs_set(movable, value, group=group, wait=wait)) + + +# https://github.com/bluesky/bluesky/issues/1821 +def set_relative( + movable: Movable, value: Any, group: Group | None = None, wait: bool = False +) -> MsgGenerator: + """ + Change a device, wrapper for `bp.rel_set`. + + Args: + movable (Movable): The device to set + value (T): The new value + group (Group | None, optional): The message group to associate with the + setting, for sequencing. Defaults to None. + wait (bool, optional): The group should wait until all setting is complete + (e.g. a motor has finished moving). Defaults to False. + + Returns: + MsgGenerator: Plan + + Yields: + Iterator[MsgGenerator]: Bluesky messages + """ + + return (yield from bps.rel_set(movable, value, group=group, wait=wait)) + + +# https://github.com/bluesky/bluesky/issues/1821 +def move(moves: Mapping[Movable, Any], group: Group | None = None) -> MsgGenerator: + """ + Move a device, wrapper for `bp.mv`. + + Args: + moves (Mapping[Movable, T]): Mapping of Movables to target positions + group (Group | None, optional): The message group to associate with the + setting, for sequencing. Defaults to None. + + Returns: + MsgGenerator: Plan + + Yields: + Iterator[MsgGenerator]: Bluesky messages + """ + + return ( + # https://github.com/bluesky/bluesky/issues/1809 + yield from bps.mv(*itertools.chain.from_iterable(moves.items()), group=group) # type: ignore + ) + + +# https://github.com/bluesky/bluesky/issues/1821 +def move_relative( + moves: Mapping[Movable, Any], group: Group | None = None +) -> MsgGenerator: + """ + Move a device relative to its current position, wrapper for `bp.mvr`. + + Args: + moves (Mapping[Movable, T]): Mapping of Movables to target deltas + group (Group | None, optional): The message group to associate with the + setting, for sequencing. Defaults to None. + + Returns: + MsgGenerator: Plan + + Yields: + Iterator[MsgGenerator]: Bluesky messages + """ + + return ( + # https://github.com/bluesky/bluesky/issues/1809 + yield from bps.mvr(*itertools.chain.from_iterable(moves.items()), group=group) # type: ignore + ) + + +def sleep(time: float) -> MsgGenerator: + """ + Suspend all action for a given time, wrapper for `bp.sleep` + + Args: + time (float): Time to wait in seconds + + Returns: + MsgGenerator: Plan + + Yields: + Iterator[MsgGenerator]: Bluesky messages + """ + + return (yield from bps.sleep(time)) + + +def wait( + group: Group | None = None, + timeout: float | None = None, +) -> MsgGenerator: + """ + Wait for a group status to complete, wrapper for `bp.wait` + + Args: + group (Group | None, optional): The name of the group to wait for, defaults + to None, in which case waits for all + groups that have not yet been awaited. + timeout (float | None, default=None): a timeout in seconds + + + Returns: + MsgGenerator: Plan + + Yields: + Iterator[MsgGenerator]: Bluesky messages + """ + + return (yield from bps.wait(group, timeout=timeout)) diff --git a/src/dodal/plans/__init__.py b/src/dodal/plans/__init__.py new file mode 100644 index 0000000000..fb40245969 --- /dev/null +++ b/src/dodal/plans/__init__.py @@ -0,0 +1,4 @@ +from .scanspec import spec_scan +from .wrapped import count + +__all__ = ["count", "spec_scan"] diff --git a/src/dodal/plans/scanspec.py b/src/dodal/plans/scanspec.py new file mode 100644 index 0000000000..c9af3b5df9 --- /dev/null +++ b/src/dodal/plans/scanspec.py @@ -0,0 +1,65 @@ +import operator +from functools import reduce +from typing import Annotated, Any + +import bluesky.plans as bp +from bluesky.protocols import Movable, Readable +from cycler import Cycler, cycler +from pydantic import Field, validate_call +from scanspec.specs import Spec + +from dodal.common import MsgGenerator +from dodal.plan_stubs.data_session import attach_data_session_metadata_decorator + + +@attach_data_session_metadata_decorator() +@validate_call(config={"arbitrary_types_allowed": True}) +def spec_scan( + detectors: Annotated[ + set[Readable], + Field( + description="Set of readable devices, will take a reading at each point, \ + in addition to any Movables in the Spec", + ), + ], + spec: Annotated[ + Spec[Movable], + Field(description="ScanSpec modelling the path of the scan"), + ], + metadata: dict[str, Any] | None = None, +) -> MsgGenerator: + """Generic plan for reading `detectors` at every point of a ScanSpec `Spec`. + A `Spec` is an N-dimensional path. + """ + _md = { + "plan_args": { + "detectors": {det.name for det in detectors}, + "spec": repr(spec), + }, + "plan_name": "spec_scan", + "shape": spec.shape(), + **(metadata or {}), + } + + yield from bp.scan_nd(tuple(detectors), _as_cycler(spec), md=_md) + + +def _as_cycler( + spec: Spec[Movable], # type: ignore +) -> Cycler: + """ + Convert a scanspec to a cycler for compatibility with legacy Bluesky plans such as + `bp.scan_nd`. Use the midpoints of the scanspec since cyclers are normally used + for software triggered scans. + + Args: + spec: A scanspec + + Returns: + Cycler: A new cycler + """ + + midpoints = spec.frames().midpoints + # Need to "add" the cyclers for all the axes together. The code below is + # effectively: cycler(motor1, [...]) + cycler(motor2, [...]) + ... + return reduce(operator.add, (cycler(*args) for args in midpoints.items())) diff --git a/src/dodal/plans/wrapped.py b/src/dodal/plans/wrapped.py new file mode 100644 index 0000000000..36b65b10b6 --- /dev/null +++ b/src/dodal/plans/wrapped.py @@ -0,0 +1,42 @@ +from typing import Annotated, Any + +import bluesky.plans as bp +from bluesky.protocols import Readable +from pydantic import Field, NonNegativeFloat, validate_call + +from dodal.common import MsgGenerator +from dodal.plan_stubs.data_session import attach_data_session_metadata_decorator + + +@attach_data_session_metadata_decorator() +@validate_call(config={"arbitrary_types_allowed": True}) +def count( + detectors: Annotated[ + set[Readable], + Field( + description="Set of readable devices, will take a reading at each point", + min_length=1, + ), + ], + num: Annotated[int, Field(description="Number of frames to collect", ge=1)] = 1, + delay: Annotated[ + NonNegativeFloat | list[NonNegativeFloat], + Field( + description="Delay between readings: if list, len(delay) == num - 1 and \ + the delays are between each point, if value or None is the delay for every \ + gap", + json_schema_extra={"units": "s"}, + ), + ] = 0.0, + metadata: dict[str, Any] | None = None, +) -> MsgGenerator: + """Reads from a number of devices. + Wraps bluesky.plans.count(det, num, delay, md=metadata) exposing only serializable + parameters and metadata.""" + if isinstance(delay, list): + assert ( + delays := len(delay) + ) == num - 1, f"Number of delays given must be {num - 1}: was given {delays}" + metadata = metadata or {} + metadata["shape"] = (num,) + yield from bp.count(tuple(detectors), num, delay=delay, md=metadata) diff --git a/tests/plan_stubs/test_wrapped_stubs.py b/tests/plan_stubs/test_wrapped_stubs.py new file mode 100644 index 0000000000..2755fc9fb0 --- /dev/null +++ b/tests/plan_stubs/test_wrapped_stubs.py @@ -0,0 +1,145 @@ +from unittest.mock import ANY + +import pytest +from bluesky.run_engine import RunEngine +from bluesky.utils import Msg +from ophyd_async.core import ( + DeviceCollector, +) +from ophyd_async.sim.demo import SimMotor + +from dodal.plan_stubs.wrapped import ( + move, + move_relative, + set_absolute, + set_relative, + sleep, + wait, +) + + +@pytest.fixture +def x_axis(RE: RunEngine) -> SimMotor: + with DeviceCollector(): + x_axis = SimMotor() + return x_axis + + +@pytest.fixture +def y_axis(RE: RunEngine) -> SimMotor: + with DeviceCollector(): + y_axis = SimMotor() + return y_axis + + +def test_set_absolute(x_axis: SimMotor): + assert list(set_absolute(x_axis, 0.5)) == [Msg("set", x_axis, 0.5, group=None)] + + +def test_set_absolute_with_group(x_axis: SimMotor): + assert list(set_absolute(x_axis, 0.5, group="foo")) == [ + Msg("set", x_axis, 0.5, group="foo") + ] + + +def test_set_absolute_with_wait(x_axis: SimMotor): + msgs = list(set_absolute(x_axis, 0.5, wait=True)) + assert len(msgs) == 2 + assert msgs[0] == Msg("set", x_axis, 0.5, group=ANY) + assert msgs[1] == Msg("wait", group=msgs[0].kwargs["group"]) + + +def test_set_absolute_with_group_and_wait(x_axis: SimMotor): + assert list(set_absolute(x_axis, 0.5, group="foo", wait=True)) == [ + Msg("set", x_axis, 0.5, group="foo"), + Msg("wait", group="foo"), + ] + + +def test_set_relative(x_axis: SimMotor): + assert list(set_relative(x_axis, 0.5)) == [ + Msg("read", x_axis), + Msg("set", x_axis, 0.5, group=None), + ] + + +def test_set_relative_with_group(x_axis: SimMotor): + assert list(set_relative(x_axis, 0.5, group="foo")) == [ + Msg("read", x_axis), + Msg("set", x_axis, 0.5, group="foo"), + ] + + +def test_set_relative_with_wait(x_axis: SimMotor): + msgs = list(set_relative(x_axis, 0.5, wait=True)) + assert len(msgs) == 3 + assert msgs[0] == Msg("read", x_axis) + assert msgs[1] == Msg("set", x_axis, 0.5, group=ANY) + assert msgs[2] == Msg("wait", group=msgs[1].kwargs["group"]) + + +def test_set_relative_with_group_and_wait(x_axis: SimMotor): + assert list(set_relative(x_axis, 0.5, group="foo", wait=True)) == [ + Msg("read", x_axis), + Msg("set", x_axis, 0.5, group="foo"), + Msg("wait", group="foo"), + ] + + +def test_move(x_axis: SimMotor, y_axis: SimMotor): + msgs = list(move({x_axis: 0.5, y_axis: 1.0})) + assert msgs[0] == Msg("set", x_axis, 0.5, group=ANY) + assert msgs[1] == Msg("set", y_axis, 1.0, group=msgs[0].kwargs["group"]) + assert msgs[2] == Msg("wait", group=msgs[0].kwargs["group"]) + + +def test_move_group(x_axis: SimMotor, y_axis: SimMotor): + msgs = list(move({x_axis: 0.5, y_axis: 1.0}, group="foo")) + assert msgs[0] == Msg("set", x_axis, 0.5, group="foo") + assert msgs[1] == Msg("set", y_axis, 1.0, group="foo") + assert msgs[2] == Msg("wait", group="foo") + + +def test_move_relative(x_axis: SimMotor, y_axis: SimMotor): + msgs = list(move_relative({x_axis: 0.5, y_axis: 1.0})) + assert msgs[0] == Msg("read", x_axis) + assert msgs[1] == Msg("set", x_axis, 0.5, group=ANY) + group = msgs[1].kwargs["group"] + assert msgs[2] == Msg("read", y_axis) + assert msgs[3] == Msg("set", y_axis, 1.0, group=group) + assert msgs[4] == Msg("wait", group=group) + + +def test_move_relative_group(x_axis: SimMotor, y_axis: SimMotor): + msgs = list(move_relative({x_axis: 0.5, y_axis: 1.0}, group="foo")) + assert msgs[0] == Msg("read", x_axis) + assert msgs[1] == Msg("set", x_axis, 0.5, group="foo") + assert msgs[2] == Msg("read", y_axis) + assert msgs[3] == Msg("set", y_axis, 1.0, group="foo") + assert msgs[4] == Msg("wait", group="foo") + + +def test_sleep(): + assert list(sleep(1.5)) == [Msg("sleep", None, 1.5)] + + +def test_wait(): + # Waits for all groups + # move_on and wait are antonyms: what does move_on do? + assert list(wait()) == [Msg("wait", group=None, timeout=None, move_on=False)] + + +def test_wait_group(): + assert list(wait("foo")) == [Msg("wait", group="foo", timeout=None, move_on=False)] + + +def test_wait_timeout(): + assert list(wait(timeout=5.0)) == [ + Msg("wait", group=None, timeout=5.0, move_on=False) + ] + + +def test_wait_group_and_timeout(): + assert list(wait("foo", 5.0)) == [ + Msg("wait", group="foo", timeout=5.0, move_on=False) + ] diff --git a/tests/plans/test_compliance.py b/tests/plans/test_compliance.py new file mode 100644 index 0000000000..a3b036c0a8 --- /dev/null +++ b/tests/plans/test_compliance.py @@ -0,0 +1,69 @@ +import inspect +from types import ModuleType +from typing import Any, get_type_hints + +from bluesky.utils import MsgGenerator + +from dodal import plan_stubs, plans +from dodal.common.types import PlanGenerator + + +def is_bluesky_plan_generator(func: Any) -> bool: + try: + return callable(func) and get_type_hints(func).get("return") == MsgGenerator + except TypeError: + # get_type_hints fails on some objects (such as Union or Optional) + return False + + +def get_all_available_generators(mod: ModuleType): + def get_named_subset(names: list[str]): + for name in names: + yield getattr(mod, name) + + if explicit_exports := mod.__dict__.get("__export__"): + yield from get_named_subset(explicit_exports) + elif implicit_exports := mod.__dict__.get("__all__"): + yield from get_named_subset(implicit_exports) + else: + for name, value in mod.__dict__.items(): + if not name.startswith("_"): + yield value + + +def assert_hard_requirements(plan: PlanGenerator, signature: inspect.Signature): + assert plan.__doc__ is not None, f"'{plan.__name__}' has no docstring" + for parameter in signature.parameters.values(): + assert ( + parameter.kind is not parameter.VAR_POSITIONAL + and parameter.kind is not parameter.VAR_KEYWORD + ), f"'{plan.__name__}' has variadic arguments" + + +def assert_metadata_requirements(plan: PlanGenerator, signature: inspect.Signature): + assert ( + "metadata" in signature.parameters + ), f"'{plan.__name__}' does not allow metadata" + metadata = signature.parameters["metadata"] + assert ( + metadata.annotation == dict[str, Any] | None + and metadata.default is not inspect.Parameter.empty + ), f"'{plan.__name__}' metadata is not optional" + assert metadata.default is None, f"'{plan.__name__}' metadata default is mutable" + + +def test_plans_comply(): + for plan in get_all_available_generators(plans): + if is_bluesky_plan_generator(plan): + signature = inspect.Signature.from_callable(plan) + assert_hard_requirements(plan, signature) + assert_metadata_requirements(plan, signature) + + +def test_stubs_comply(): + for stub in get_all_available_generators(plan_stubs): + if is_bluesky_plan_generator(stub): + signature = inspect.Signature.from_callable(stub) + assert_hard_requirements(stub, signature) + if "metadata" in signature.parameters: + assert_metadata_requirements(stub, signature) diff --git a/tests/plans/test_scanspec.py b/tests/plans/test_scanspec.py new file mode 100644 index 0000000000..abe615c103 --- /dev/null +++ b/tests/plans/test_scanspec.py @@ -0,0 +1,98 @@ +from pathlib import Path +from typing import cast +from unittest.mock import patch + +import pytest +from bluesky.run_engine import RunEngine +from event_model.documents import ( + DocumentType, + Event, + EventDescriptor, + RunStart, + RunStop, +) +from ophyd_async.core import ( + DeviceCollector, + PathProvider, + StandardDetector, +) +from ophyd_async.sim.demo import PatternDetector, SimMotor +from scanspec.specs import Line + +from dodal.plans import spec_scan + + +@pytest.fixture +def det(RE: RunEngine, tmp_path: Path) -> StandardDetector: + with DeviceCollector(mock=True): + det = PatternDetector(tmp_path / "foo.h5") + return det + + +@pytest.fixture +def x_axis(RE: RunEngine) -> SimMotor: + with DeviceCollector(mock=True): + x_axis = SimMotor() + return x_axis + + +@pytest.fixture +def y_axis(RE: RunEngine) -> SimMotor: + with DeviceCollector(mock=True): + y_axis = SimMotor() + return y_axis + + +@pytest.fixture +def path_provider(static_path_provider: PathProvider): + # Prevents issue with leftover state from beamline tests + with patch("dodal.plan_stubs.data_session.get_path_provider") as mock: + mock.return_value = static_path_provider + yield + + +def test_output_of_simple_spec( + RE: RunEngine, x_axis: SimMotor, det: StandardDetector, path_provider +): + docs: dict[str, list[DocumentType]] = {} + RE( + spec_scan( + {det}, + Line(axis=x_axis, start=1, stop=2, num=3), + ), + lambda name, doc: docs.setdefault(name, []).append(doc), + ) + for metadata_doc in ("start", "stop", "descriptor"): + assert metadata_doc in docs + assert len(docs[metadata_doc]) == 1 + + start = cast(RunStart, docs["start"][0]) + assert (hints := start.get("hints")) and ( + hints.get("dimensions") == [([x_axis.user_readback.name], "primary")] + ) + assert start.get("shape") == (3,) + + descriptor = cast(EventDescriptor, docs["descriptor"][0]) + assert x_axis.name in descriptor.get("object_keys", {}) + assert det.name in descriptor.get("object_keys", {}) + + stop = cast(RunStop, docs["stop"][0]) + assert stop.get("exit_status") == "success" + assert stop.get("num_events") == {"primary": 3} + assert stop.get("run_start") == start.get("uid") + + assert "event" in docs + + initial_position = 1.0 + step = 0.5 + for doc, index in zip(docs["event"], range(1, 4), strict=True): + event = cast(Event, doc) + location = initial_position + ((index - 1) * step) + assert event.get("data").get(x_axis.user_readback.name) == location + + # Output of detector not linked to Spec, just check that dets are all triggered + assert "stream_resource" in docs + assert len(docs["stream_resource"]) == 2 # det, det.sum + + assert "stream_datum" in docs + assert len(docs["stream_datum"]) == 3 * 2 # each point per resource