From 3e4cee6e367c5e71681430e76dcbe0d8823da7e5 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Mon, 1 Jul 2024 18:13:20 +0200 Subject: [PATCH 1/4] Add failing test --- .../time_reversal_pulse_template_tests.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/pulses/time_reversal_pulse_template_tests.py b/tests/pulses/time_reversal_pulse_template_tests.py index 0ded8423..ad012e5f 100644 --- a/tests/pulses/time_reversal_pulse_template_tests.py +++ b/tests/pulses/time_reversal_pulse_template_tests.py @@ -1,5 +1,9 @@ import unittest +import numpy as np + +from qupulse.pulses import ConstantPT, FunctionPT +from qupulse.plotting import render from qupulse.pulses.time_reversal_pulse_template import TimeReversalPulseTemplate from qupulse.utils.types import TimeType from qupulse.expressions import ExpressionScalar @@ -25,6 +29,19 @@ def test_simple_properties(self): self.assertEqual(reversed_pt.identifier, 'reverse') + def test_time_reversal_program(self): + inner = ConstantPT(4, {'a': 3}) @ FunctionPT('sin(t)', 5, channel='a') + manual_reverse = FunctionPT('sin(5 - t)', 5, channel='a') @ ConstantPT(4, {'a': 3}) + time_reversed = TimeReversalPulseTemplate(inner) + + program = time_reversed.create_program() + manual_program = manual_reverse.create_program() + + t, data, _ = render(program, 9 / 10) + _, manual_data, _ = render(manual_program, 9 / 10) + + np.testing.assert_allclose(data['a'], manual_data['a']) + class TimeReversalPulseTemplateSerializationTests(unittest.TestCase, SerializableTests): @property From 2fd223e8c12fc90b9b3f60a9039dc42ac5192019 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Mon, 1 Jul 2024 18:24:56 +0200 Subject: [PATCH 2/4] Add time_reversed to LoopBuilder --- qupulse/program/__init__.py | 3 +++ qupulse/program/loop.py | 10 ++++++++++ qupulse/program/waveforms.py | 2 +- qupulse/pulses/time_reversal_pulse_template.py | 11 ++++------- 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/qupulse/program/__init__.py b/qupulse/program/__init__.py index cd578dd5..63dc4f60 100644 --- a/qupulse/program/__init__.py +++ b/qupulse/program/__init__.py @@ -164,6 +164,9 @@ def with_iteration(self, index_name: str, rng: range, measurements: Optional[Sequence[MeasurementWindow]] = None) -> Iterable['ProgramBuilder']: pass + def time_reversed(self) -> ContextManager['ProgramBuilder']: + pass + def to_program(self) -> Optional[Program]: """Further addition of new elements might fail after finalizing the program.""" diff --git a/qupulse/program/loop.py b/qupulse/program/loop.py index 0e5dccc3..9e59f9d3 100644 --- a/qupulse/program/loop.py +++ b/qupulse/program/loop.py @@ -817,6 +817,16 @@ def with_iteration(self, index_name: str, rng: range, top_frame.iterating = (index_name, value) yield self + @contextmanager + def time_reversed(self) -> ContextManager['LoopBuilder']: + inner_builder = LoopBuilder() + yield inner_builder + inner_program = inner_builder.to_program() + + if inner_program: + inner_program.reverse_inplace() + self._try_append(inner_program, None) + @contextmanager def with_sequence(self, measurements: Optional[Sequence[MeasurementWindow]] = None) -> ContextManager['ProgramBuilder']: top_frame = StackFrame(LoopGuard(self._top, measurements), None) diff --git a/qupulse/program/waveforms.py b/qupulse/program/waveforms.py index 08b54411..1e35bc1b 100644 --- a/qupulse/program/waveforms.py +++ b/qupulse/program/waveforms.py @@ -1257,7 +1257,7 @@ def unsafe_sample(self, channel: ChannelID, sample_times: np.ndarray, else: inner_output_array = output_array[::-1] inner_output_array = self._inner.unsafe_sample(channel, inner_sample_times, output_array=inner_output_array) - if inner_output_array.base not in (output_array, output_array.base): + if id(inner_output_array.base) not in (id(output_array), id(output_array.base)): # TODO: is there a guarantee by numpy we never end up here? output_array[:] = inner_output_array[::-1] return output_array diff --git a/qupulse/pulses/time_reversal_pulse_template.py b/qupulse/pulses/time_reversal_pulse_template.py index 5dc9fcab..47bec232 100644 --- a/qupulse/pulses/time_reversal_pulse_template.py +++ b/qupulse/pulses/time_reversal_pulse_template.py @@ -5,7 +5,7 @@ from typing import Optional, Set, Dict, Union from qupulse import ChannelID -from qupulse.program.loop import Loop +from qupulse.program import ProgramBuilder from qupulse.program.waveforms import Waveform from qupulse.serialization import PulseRegistryType from qupulse.expressions import ExpressionScalar @@ -50,12 +50,9 @@ def defined_channels(self) -> Set['ChannelID']: def integral(self) -> Dict[ChannelID, ExpressionScalar]: return self._inner.integral - def _internal_create_program(self, *, parent_loop: Loop, **kwargs) -> None: - inner_loop = Loop() - self._inner._internal_create_program(parent_loop=inner_loop, **kwargs) - inner_loop.reverse_inplace() - - parent_loop.append_child(inner_loop) + def _internal_create_program(self, *, program_builder: ProgramBuilder, **kwargs) -> None: + with program_builder.time_reversed() as reversed_builder: + self._inner._internal_create_program(program_builder=reversed_builder, **kwargs) def build_waveform(self, *args, **kwargs) -> Optional[Waveform]: From cd67b0da7c42e1f27e11c44c6a4de6ecf250bde3 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Mon, 1 Jul 2024 18:46:05 +0200 Subject: [PATCH 3/4] Add untested linspace implementation --- qupulse/program/linspace.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index 600a6f82..cb6d26d4 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -44,6 +44,9 @@ class LinSpaceNode: def dependencies(self) -> Mapping[int, set]: raise NotImplementedError + def reversed(self, level: int): + raise NotImplementedError + @dataclass class LinSpaceHold(LinSpaceNode): @@ -60,6 +63,21 @@ def dependencies(self) -> Mapping[int, set]: for idx, factors in enumerate(self.factors) if factors} + def reversed(self, level: int): + factors = [] + for ch_factors in self.factors: + if ch_factors is None or len(ch_factors) <= level: + factors.append(ch_factors) + else: + reversed_factors = ch_factors[:level] + tuple(-f for f in ch_factors[level:]) + factors.append(reversed_factors) + + if self.duration_factors is not None and len(self.duration_factors) <= level: + duration_factors = self.duration_factors + else: + duration_factors = self.duration_factors[:level] + tuple(-f for f in self.duration_factors[level:]) + return LinSpaceHold(self.bases, factors, duration_base=self.duration_base, duration_factors=duration_factors) + @dataclass class LinSpaceArbitraryWaveform(LinSpaceNode): @@ -67,6 +85,12 @@ class LinSpaceArbitraryWaveform(LinSpaceNode): waveform: Waveform channels: Tuple[ChannelID, ...] + def reversed(self, level: int): + return LinSpaceArbitraryWaveform( + waveform=self.waveform.reversed(), + channels=self.channels, + ) + @dataclass class LinSpaceRepeat(LinSpaceNode): @@ -81,6 +105,9 @@ def dependencies(self): dependencies.setdefault(idx, set()).update(deps) return dependencies + def reversed(self, level: int): + return LinSpaceRepeat(tuple(node.reversed(level) for node in reversed(self.body)), self.count) + @dataclass class LinSpaceIter(LinSpaceNode): @@ -100,6 +127,9 @@ def dependencies(self): dependencies.setdefault(idx, set()).update(shortened) return dependencies + def reversed(self, level: int): + return LinSpaceIter(tuple(node.reversed() for node in reversed(self.body)), self.length) + class LinSpaceBuilder(ProgramBuilder): """This program builder supports efficient translation of pulse templates that use symbolic linearly @@ -214,6 +244,13 @@ def with_iteration(self, index_name: str, rng: range, if cmds: self._stack[-1].append(LinSpaceIter(body=tuple(cmds), length=len(rng))) + def time_reversed(self) -> ContextManager['LinSpaceBuilder']: + self._stack.append([]) + yield self + inner = self._stack.pop() + level = len(self._ranges) + self._stack[-1].extend(node.reversed(level) for node in reversed(inner)) + def to_program(self) -> Optional[Sequence[LinSpaceNode]]: if self._root(): return self._root() From 3faa5345d7a23e51125042547275650d6247066f Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Mon, 1 Jul 2024 19:28:56 +0200 Subject: [PATCH 4/4] Failing test for linspace implementation --- qupulse/program/linspace.py | 56 +++++++++++++------ tests/program/linspace_tests.py | 2 +- .../time_reversal_pulse_template_tests.py | 40 +++++++++++-- 3 files changed, 76 insertions(+), 22 deletions(-) diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index cb6d26d4..10d26da6 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -44,7 +44,7 @@ class LinSpaceNode: def dependencies(self) -> Mapping[int, set]: raise NotImplementedError - def reversed(self, level: int): + def reversed(self, offset: int, lengths: list): raise NotImplementedError @@ -63,20 +63,28 @@ def dependencies(self) -> Mapping[int, set]: for idx, factors in enumerate(self.factors) if factors} - def reversed(self, level: int): + def reversed(self, offset: int, lengths: list): + if not lengths: + return self + bases = [] factors = [] - for ch_factors in self.factors: - if ch_factors is None or len(ch_factors) <= level: + for ch_base, ch_factors in zip(self.bases, self.factors): + if ch_factors is None or len(ch_factors) <= offset: + bases.append(ch_base) factors.append(ch_factors) else: - reversed_factors = ch_factors[:level] + tuple(-f for f in ch_factors[level:]) + ch_reverse_base = ch_base + sum(length*factor for factor, length in zip(ch_factors[offset:], lengths)) + reversed_factors = ch_factors[:offset] + tuple(-f for f in ch_factors[offset:]) + bases.append(ch_reverse_base) factors.append(reversed_factors) - if self.duration_factors is not None and len(self.duration_factors) <= level: + if self.duration_factors is None or len(self.duration_factors) <= offset: duration_factors = self.duration_factors + duration_base = self.duration_base else: - duration_factors = self.duration_factors[:level] + tuple(-f for f in self.duration_factors[level:]) - return LinSpaceHold(self.bases, factors, duration_base=self.duration_base, duration_factors=duration_factors) + duration_base = self.duration_base + sum((length*factor for factor, length in zip(self.duration_factors[offset:], lengths)), TimeType(0)) + duration_factors = self.duration_factors[:offset] + tuple(-f for f in self.duration_factors[offset:]) + return LinSpaceHold(tuple(bases), tuple(factors), duration_base=duration_base, duration_factors=duration_factors) @dataclass @@ -85,7 +93,7 @@ class LinSpaceArbitraryWaveform(LinSpaceNode): waveform: Waveform channels: Tuple[ChannelID, ...] - def reversed(self, level: int): + def reversed(self, offset: int, lengths: list): return LinSpaceArbitraryWaveform( waveform=self.waveform.reversed(), channels=self.channels, @@ -105,8 +113,8 @@ def dependencies(self): dependencies.setdefault(idx, set()).update(deps) return dependencies - def reversed(self, level: int): - return LinSpaceRepeat(tuple(node.reversed(level) for node in reversed(self.body)), self.count) + def reversed(self, offset: int, counts: list): + return LinSpaceRepeat(tuple(node.reversed(offset, counts) for node in reversed(self.body)), self.count) @dataclass @@ -127,8 +135,11 @@ def dependencies(self): dependencies.setdefault(idx, set()).update(shortened) return dependencies - def reversed(self, level: int): - return LinSpaceIter(tuple(node.reversed() for node in reversed(self.body)), self.length) + def reversed(self, offset: int, lengths: list): + lengths.append(self.length) + reversed_iter = LinSpaceIter(tuple(node.reversed(offset, lengths) for node in reversed(self.body)), self.length) + lengths.pop() + return reversed_iter class LinSpaceBuilder(ProgramBuilder): @@ -244,12 +255,13 @@ def with_iteration(self, index_name: str, rng: range, if cmds: self._stack[-1].append(LinSpaceIter(body=tuple(cmds), length=len(rng))) + @contextlib.contextmanager def time_reversed(self) -> ContextManager['LinSpaceBuilder']: self._stack.append([]) yield self inner = self._stack.pop() - level = len(self._ranges) - self._stack[-1].extend(node.reversed(level) for node in reversed(inner)) + offset = len(self._ranges) + self._stack[-1].extend(node.reversed(offset, []) for node in reversed(inner)) def to_program(self) -> Optional[Sequence[LinSpaceNode]]: if self._root(): @@ -465,7 +477,19 @@ def __init__(self, channels: int): def change_state(self, cmd: Union[Set, Increment, Wait, Play]): if isinstance(cmd, Play): - raise NotImplementedError("TODO: Implement arbitrary waveform simulation") + num = 17 + dt = cmd.waveform.duration / num + t = TimeType(0) + for _ in range(num): + sample_time = np.array([float(t)]) + values = [] + for (idx, ch) in enumerate(cmd.channels): + self.current_values[idx] = values.append(cmd.waveform.get_sampled(channel=ch, sample_times=sample_time)[0]) + self.history.append( + (self.time, self.current_values.copy()) + ) + self.time += dt + t += dt elif isinstance(cmd, Wait): self.history.append( (self.time, self.current_values.copy()) diff --git a/tests/program/linspace_tests.py b/tests/program/linspace_tests.py index 23d050ac..cd717048 100644 --- a/tests/program/linspace_tests.py +++ b/tests/program/linspace_tests.py @@ -14,7 +14,7 @@ def assert_vm_output_almost_equal(test: TestCase, expected, actual): test.assertEqual(t_e, t_a, f"Differing times in {idx} element") test.assertEqual(len(vals_e), len(vals_a), f"Differing channel count in {idx} element") for ch, (val_e, val_a) in enumerate(zip(vals_e, vals_a)): - test.assertAlmostEqual(val_e, val_a, msg=f"Differing values in {idx} element channel {ch}") + test.assertAlmostEqual(val_e, val_a, msg=f"Differing values in {idx} of {len(expected)} element channel {ch}") class SingleRampTest(TestCase): diff --git a/tests/pulses/time_reversal_pulse_template_tests.py b/tests/pulses/time_reversal_pulse_template_tests.py index ad012e5f..ca158504 100644 --- a/tests/pulses/time_reversal_pulse_template_tests.py +++ b/tests/pulses/time_reversal_pulse_template_tests.py @@ -7,10 +7,11 @@ from qupulse.pulses.time_reversal_pulse_template import TimeReversalPulseTemplate from qupulse.utils.types import TimeType from qupulse.expressions import ExpressionScalar - +from qupulse.program.loop import LoopBuilder +from qupulse.program.linspace import LinSpaceBuilder, LinSpaceVM, to_increment_commands from tests.pulses.sequencing_dummies import DummyPulseTemplate from tests.serialization_tests import SerializableTests - +from tests.program.linspace_tests import assert_vm_output_almost_equal class TimeReversalPulseTemplateTests(unittest.TestCase): def test_simple_properties(self): @@ -29,19 +30,48 @@ def test_simple_properties(self): self.assertEqual(reversed_pt.identifier, 'reverse') - def test_time_reversal_program(self): + def test_time_reversal_loop(self): inner = ConstantPT(4, {'a': 3}) @ FunctionPT('sin(t)', 5, channel='a') manual_reverse = FunctionPT('sin(5 - t)', 5, channel='a') @ ConstantPT(4, {'a': 3}) time_reversed = TimeReversalPulseTemplate(inner) - program = time_reversed.create_program() - manual_program = manual_reverse.create_program() + program = time_reversed.create_program(program_builder=LoopBuilder()) + manual_program = manual_reverse.create_program(program_builder=LoopBuilder()) t, data, _ = render(program, 9 / 10) _, manual_data, _ = render(manual_program, 9 / 10) np.testing.assert_allclose(data['a'], manual_data['a']) + def test_time_reversal_linspace(self): + constant_pt = ConstantPT(4, {'a': '3.0 + x * 1.0 + y * -0.3'}) + function_pt = FunctionPT('sin(t)', 5, channel='a') + reversed_function_pt = FunctionPT('sin(5 - t)', 5, channel='a') + + inner = (constant_pt @ function_pt).with_iteration('x', 6) + inner_manual = (reversed_function_pt @ constant_pt).with_iteration('x', (5, -1, -1)) + + outer = inner.with_time_reversal().with_iteration('y', 8) + outer_man = inner_manual.with_iteration('y', 8) + + self.assertEqual(outer.duration, outer_man.duration) + + program = outer.create_program(program_builder=LinSpaceBuilder(channels=('a',))) + manual_program = outer_man.create_program(program_builder=LinSpaceBuilder(channels=('a',))) + + commands = to_increment_commands(program) + manual_commands = to_increment_commands(manual_program) + + manual_vm = LinSpaceVM(1) + manual_vm.set_commands(manual_commands) + manual_vm.run() + + vm = LinSpaceVM(1) + vm.set_commands(commands) + vm.run() + + assert_vm_output_almost_equal(self, manual_vm.history, vm.history) + class TimeReversalPulseTemplateSerializationTests(unittest.TestCase, SerializableTests): @property