diff --git a/qiskit_ibm_runtime/execution_span/__init__.py b/qiskit_ibm_runtime/execution_span/__init__.py index 7637ae82d..49d891b25 100644 --- a/qiskit_ibm_runtime/execution_span/__init__.py +++ b/qiskit_ibm_runtime/execution_span/__init__.py @@ -35,9 +35,11 @@ ExecutionSpans ShapeType SliceSpan + TwirledSliceSpan """ from .double_slice_span import DoubleSliceSpan from .execution_span import ExecutionSpan, ShapeType from .execution_spans import ExecutionSpans from .slice_span import SliceSpan +from .twirled_slice_span import TwirledSliceSpan diff --git a/qiskit_ibm_runtime/execution_span/double_slice_span.py b/qiskit_ibm_runtime/execution_span/double_slice_span.py index 2e9bc0b0c..70238f505 100644 --- a/qiskit_ibm_runtime/execution_span/double_slice_span.py +++ b/qiskit_ibm_runtime/execution_span/double_slice_span.py @@ -28,16 +28,16 @@ class DoubleSliceSpan(ExecutionSpan): """An :class:`~.ExecutionSpan` for data stored in a sliceable format. This type of execution span references pub result data by assuming that it is a sliceable - portion of the data where the shots are the outermost slice and the rest of the data is flattened. - Therefore, for each pub dependent on this span, the constructor accepts two :class:`slice` objects, - along with the corresponding shape of the data to be sliced; in contrast to - :class:`~.SliceSpan`, this class does not assume that *all* shots for a particular set of parameter - values are contiguous in the array of data. + portion of the data where the shots are the outermost slice and the rest of the data is + flattened. Therefore, for each pub dependent on this span, the constructor accepts two + :class:`slice` objects, along with the corresponding shape of the data to be sliced; in contrast + to :class:`~.SliceSpan`, this class does not assume that *all* shots for a particular set of + parameter values are contiguous in the array of data. Args: start: The start time of the span, in UTC. stop: The stop time of the span, in UTC. - data_slices: A map from pub indices to ``(shape_tuple, slice, slice)``. + data_slices: A map from pub indices to ``(shape_tuple, flat_shape_slice, shots_slice)``. """ def __init__( diff --git a/qiskit_ibm_runtime/execution_span/twirled_slice_span.py b/qiskit_ibm_runtime/execution_span/twirled_slice_span.py new file mode 100644 index 000000000..06be56d70 --- /dev/null +++ b/qiskit_ibm_runtime/execution_span/twirled_slice_span.py @@ -0,0 +1,92 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2024. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +"""TwirledSliceSpan""" + +from __future__ import annotations + +from datetime import datetime +from typing import Iterable + +import math +import numpy as np +import numpy.typing as npt + +from .execution_span import ExecutionSpan, ShapeType + + +class TwirledSliceSpan(ExecutionSpan): + """An :class:`~.ExecutionSpan` for data stored in a sliceable format when twirling. + + This type of execution span references pub result data that came from a twirled sampler + experiment which was executed by either prepending or appending an axis to paramater values + to account for twirling. Concretely, ``data_slices`` is a map from pub slices to tuples + ``(twirled_shape, at_front, shape_slice, shots_slice)`` where + + * ``twirled_shape`` is the shape tuple including a twirling axis, and where the last + axis is shots per randomization, + * ``at_front`` is whether ``num_randomizations`` is at the front of the tuple, as + opposed to right before the ``shots`` axis at the end, + * ``shape_slice`` is a slice of an array of shape ``twirled_shape[:-1]``, flattened, + * and ``shots_slice`` is a slice of ``twirled_shape[-1]``. + + Args: + start: The start time of the span, in UTC. + stop: The stop time of the span, in UTC. + data_slices: A map from pub indices to length-4 tuples described above. + """ + + def __init__( + self, + start: datetime, + stop: datetime, + data_slices: dict[int, tuple[ShapeType, bool, slice, slice]], + ): + super().__init__(start, stop) + self._data_slices = data_slices + + def __eq__(self, other: object) -> bool: + return isinstance(other, TwirledSliceSpan) and ( + self.start == other.start + and self.stop == other.stop + and self._data_slices == other._data_slices + ) + + @property + def pub_idxs(self) -> list[int]: + return sorted(self._data_slices) + + @property + def size(self) -> int: + size = 0 + for shape, _, shape_sl, shots_sl in self._data_slices.values(): + size += len(range(math.prod(shape[:-1]))[shape_sl]) * len(range(shape[-1])[shots_sl]) + return size + + def mask(self, pub_idx: int) -> npt.NDArray[np.bool_]: + twirled_shape, at_front, shape_sl, shots_sl = self._data_slices[pub_idx] + mask = np.zeros(twirled_shape, dtype=np.bool_) + mask.reshape((np.prod(twirled_shape[:-1]), twirled_shape[-1]))[(shape_sl, shots_sl)] = True + + if at_front: + # if the first axis is over twirling samples, push them right before shots + ndim = len(twirled_shape) + mask = mask.transpose((*range(1, ndim - 1), 0, ndim - 1)) + twirled_shape = twirled_shape[1:-1] + twirled_shape[:1] + twirled_shape[-1:] + + # merge twirling axis and shots axis before returning + return mask.reshape((*twirled_shape[:-2], math.prod(twirled_shape[-2:]))) + + def filter_by_pub(self, pub_idx: int | Iterable[int]) -> "TwirledSliceSpan": + pub_idx = {pub_idx} if isinstance(pub_idx, int) else set(pub_idx) + slices = {idx: val for idx, val in self._data_slices.items() if idx in pub_idx} + return TwirledSliceSpan(self.start, self.stop, slices) diff --git a/qiskit_ibm_runtime/utils/json.py b/qiskit_ibm_runtime/utils/json.py index 1ceb745cd..f0b85483c 100644 --- a/qiskit_ibm_runtime/utils/json.py +++ b/qiskit_ibm_runtime/utils/json.py @@ -79,6 +79,7 @@ DoubleSliceSpan, SliceSpan, ExecutionSpans, + TwirledSliceSpan, ) from .noise_learner_result import NoiseLearnerResult @@ -341,6 +342,16 @@ def default(self, obj: Any) -> Any: # pylint: disable=arguments-differ }, } return {"__type__": "DoubleSliceSpan", "__value__": out_val} + if isinstance(obj, TwirledSliceSpan): + out_val = { + "start": obj.start, + "stop": obj.stop, + "data_slices": { + idx: (shape, at_front, arg_sl.start, arg_sl.stop, shot_sl.start, shot_sl.stop) + for idx, (shape, at_front, arg_sl, shot_sl) in obj._data_slices.items() + }, + } + return {"__type__": "TwirledSliceSpan", "__value__": out_val} if isinstance(obj, SliceSpan): out_val = { "start": obj.start, @@ -470,6 +481,13 @@ def object_hook(self, obj: Any) -> Any: for idx, (shape, arg0, arg1, shot0, shot1) in obj_val["data_slices"].items() } return DoubleSliceSpan(**obj_val) + if obj_type == "TwirledSliceSpan": + data_slices = obj_val["data_slices"] + obj_val["data_slices"] = { + int(idx): (tuple(shape), at_start, slice(arg0, arg1), slice(shot0, shot1)) + for idx, (shape, at_start, arg0, arg1, shot0, shot1) in data_slices.items() + } + return TwirledSliceSpan(**obj_val) if obj_type == "ExecutionSpan": new_slices = { int(idx): (tuple(shape), slice(*sl_args)) diff --git a/release-notes/unreleased/2011.feat.rst b/release-notes/unreleased/2011.feat.rst new file mode 100644 index 000000000..b1a6f83dc --- /dev/null +++ b/release-notes/unreleased/2011.feat.rst @@ -0,0 +1,4 @@ +Added :class:`.TwirledSliceSpan`, an :class:`ExecutionSpan` to be used when +twirling is enabled in the sampler. In particular, it keeps track of an extra shape +axis corresponding to twirling randomizations, and also whether this axis exists at +the front of the shape tuple, or right before the shots axis. \ No newline at end of file diff --git a/test/unit/test_data_serialization.py b/test/unit/test_data_serialization.py index 1b12e54b9..8343c54af 100644 --- a/test/unit/test_data_serialization.py +++ b/test/unit/test_data_serialization.py @@ -51,6 +51,7 @@ DoubleSliceSpan, SliceSpan, ExecutionSpans, + TwirledSliceSpan, ) from .mock.fake_runtime_client import CustomResultRuntimeJob @@ -468,6 +469,14 @@ def make_test_primitive_results(self): datetime(2024, 8, 21), {0: ((14,), slice(2, 3), slice(1, 9))}, ), + TwirledSliceSpan( + datetime(2024, 9, 20), + datetime(2024, 3, 21), + { + 0: ((14, 18, 21), True, slice(2, 3), slice(1, 9)), + 2: ((18, 14, 19), False, slice(2, 3), slice(1, 9)), + }, + ), ] ) } diff --git a/test/unit/test_execution_span.py b/test/unit/test_execution_span.py index 2f55ddc5a..26911f5f3 100644 --- a/test/unit/test_execution_span.py +++ b/test/unit/test_execution_span.py @@ -18,7 +18,12 @@ import numpy as np import numpy.testing as npt -from qiskit_ibm_runtime.execution_span import SliceSpan, DoubleSliceSpan, ExecutionSpans +from qiskit_ibm_runtime.execution_span import ( + SliceSpan, + DoubleSliceSpan, + ExecutionSpans, + TwirledSliceSpan, +) from ..ibm_test_case import IBMTestCase @@ -222,6 +227,118 @@ def test_filter_by_pub(self): ) +@ddt.ddt +class TestTwirledSliceSpan(IBMTestCase): + """Class for testing TwirledSliceSpan.""" + + def setUp(self) -> None: + super().setUp() + self.start1 = datetime(2024, 10, 11, 4, 31, 30) + self.stop1 = datetime(2024, 10, 11, 4, 31, 34) + self.slices1 = { + 2: ((3, 1, 5), True, slice(1), slice(2, 4)), + 0: ((3, 5, 18, 10), False, slice(10, 13), slice(2, 5)), + } + self.span1 = TwirledSliceSpan(self.start1, self.stop1, self.slices1) + + self.start2 = datetime(2024, 10, 16, 11, 9, 20) + self.stop2 = datetime(2024, 10, 16, 11, 9, 30) + self.slices2 = { + 0: ((7, 5, 100), True, slice(3, 5), slice(20, 40)), + 1: ((1, 5, 2, 3), False, slice(3, 9), slice(1, 3)), + } + self.span2 = TwirledSliceSpan(self.start2, self.stop2, self.slices2) + + def test_limits(self): + """Test the start and stop properties""" + self.assertEqual(self.span1.start, self.start1) + self.assertEqual(self.span1.stop, self.stop1) + self.assertEqual(self.span2.start, self.start2) + self.assertEqual(self.span2.stop, self.stop2) + + def test_equality(self): + """Test the equality method.""" + self.assertEqual(self.span1, self.span1) + self.assertEqual(self.span1, TwirledSliceSpan(self.start1, self.stop1, self.slices1)) + self.assertNotEqual(self.span1, "aoeu") + self.assertNotEqual(self.span1, self.span2) + + def test_duration(self): + """Test the duration property""" + self.assertEqual(self.span1.duration, 4) + self.assertEqual(self.span2.duration, 10) + + def test_repr(self): + """Test the repr method""" + expect = "start='2024-10-11 04:31:30', stop='2024-10-11 04:31:34', size=11" + self.assertEqual(repr(self.span1), f"TwirledSliceSpan(<{expect}>)") + + def test_size(self): + """Test the size property""" + self.assertEqual(self.span1.size, 1 * 2 + 3 * 3) + self.assertEqual(self.span2.size, 2 * 20 + 6 * 2) + + def test_pub_idxs(self): + """Test the pub_idxs property""" + self.assertEqual(self.span1.pub_idxs, [0, 2]) + self.assertEqual(self.span2.pub_idxs, [0, 1]) + + def test_mask(self): + """Test the mask() method""" + # reminder: ((3, 1, 5), True, slice(1), slice(2, 4)) + mask1 = np.zeros((3, 1, 5), dtype=bool) + mask1.reshape((3, 5))[:1, 2:4] = True + mask1 = mask1.transpose((1, 0, 2)).reshape((1, 15)) + npt.assert_array_equal(self.span1.mask(2), mask1) + + # reminder: ((1, 5, 2, 3), False, slice(3,9), slice(1, 3)), + mask2 = [ + [ + [[[0, 0, 0], [0, 0, 0]]], + [[[0, 0, 0], [0, 1, 1]]], + [[[0, 1, 1], [0, 1, 1]]], + [[[0, 1, 1], [0, 1, 1]]], + [[[0, 1, 1], [0, 0, 0]]], + ] + ] + mask2 = np.array(mask2, dtype=bool).reshape((1, 5, 6)) + npt.assert_array_equal(self.span2.mask(1), mask2) + + @ddt.data( + (0, True, True), + ([0, 1], True, True), + ([0, 1, 2], True, True), + ([1, 2], True, True), + ([1], False, True), + (2, True, False), + ([0, 2], True, True), + ) + @ddt.unpack + def test_contains_pub(self, idx, span1_expected_res, span2_expected_res): + """The the contains_pub method""" + self.assertEqual(self.span1.contains_pub(idx), span1_expected_res) + self.assertEqual(self.span2.contains_pub(idx), span2_expected_res) + + def test_filter_by_pub(self): + """The the filter_by_pub method""" + self.assertEqual( + self.span1.filter_by_pub([]), TwirledSliceSpan(self.start1, self.stop1, {}) + ) + self.assertEqual( + self.span2.filter_by_pub([]), TwirledSliceSpan(self.start2, self.stop2, {}) + ) + + self.assertEqual( + self.span1.filter_by_pub([1, 0]), + TwirledSliceSpan(self.start1, self.stop1, {0: self.slices1[0]}), + ) + + self.assertEqual( + self.span1.filter_by_pub(2), + TwirledSliceSpan(self.start1, self.stop1, {2: self.slices1[2]}), + ) + + @ddt.ddt class TestExecutionSpans(IBMTestCase): """Class for testing ExecutionSpans."""