Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test and implementation for TimeReversalPulseTemplate ProgramBuilder support #866

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions qupulse/program/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,9 @@ def with_iteration(self, index_name: str, rng: range,
measurements: Optional[Sequence[MeasurementWindow]] = None) -> Iterable['ProgramBuilder']:
pass

def time_reversed(self) -> ContextManager['ProgramBuilder']:
pass

def to_program(self) -> Optional[Program]:
"""Further addition of new elements might fail after finalizing the program."""

Expand Down
63 changes: 62 additions & 1 deletion qupulse/program/linspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ class LinSpaceNode:
def dependencies(self) -> Mapping[int, set]:
raise NotImplementedError

def reversed(self, offset: int, lengths: list):
raise NotImplementedError


@dataclass
class LinSpaceHold(LinSpaceNode):
Expand All @@ -60,13 +63,42 @@ def dependencies(self) -> Mapping[int, set]:
for idx, factors in enumerate(self.factors)
if factors}

def reversed(self, offset: int, lengths: list):
if not lengths:
return self
bases = []
factors = []
for ch_base, ch_factors in zip(self.bases, self.factors):
if ch_factors is None or len(ch_factors) <= offset:
bases.append(ch_base)
factors.append(ch_factors)
else:
ch_reverse_base = ch_base + sum(length*factor for factor, length in zip(ch_factors[offset:], lengths))
reversed_factors = ch_factors[:offset] + tuple(-f for f in ch_factors[offset:])
bases.append(ch_reverse_base)
factors.append(reversed_factors)

if self.duration_factors is None or len(self.duration_factors) <= offset:
duration_factors = self.duration_factors
duration_base = self.duration_base
else:
duration_base = self.duration_base + sum((length*factor for factor, length in zip(self.duration_factors[offset:], lengths)), TimeType(0))
duration_factors = self.duration_factors[:offset] + tuple(-f for f in self.duration_factors[offset:])
return LinSpaceHold(tuple(bases), tuple(factors), duration_base=duration_base, duration_factors=duration_factors)


@dataclass
class LinSpaceArbitraryWaveform(LinSpaceNode):
"""This is just a wrapper to pipe arbitrary waveforms through the system."""
waveform: Waveform
channels: Tuple[ChannelID, ...]

def reversed(self, offset: int, lengths: list):
return LinSpaceArbitraryWaveform(
waveform=self.waveform.reversed(),
channels=self.channels,
)


@dataclass
class LinSpaceRepeat(LinSpaceNode):
Expand All @@ -81,6 +113,9 @@ def dependencies(self):
dependencies.setdefault(idx, set()).update(deps)
return dependencies

def reversed(self, offset: int, counts: list):
return LinSpaceRepeat(tuple(node.reversed(offset, counts) for node in reversed(self.body)), self.count)


@dataclass
class LinSpaceIter(LinSpaceNode):
Expand All @@ -100,6 +135,12 @@ def dependencies(self):
dependencies.setdefault(idx, set()).update(shortened)
return dependencies

def reversed(self, offset: int, lengths: list):
lengths.append(self.length)
reversed_iter = LinSpaceIter(tuple(node.reversed(offset, lengths) for node in reversed(self.body)), self.length)
lengths.pop()
return reversed_iter


class LinSpaceBuilder(ProgramBuilder):
"""This program builder supports efficient translation of pulse templates that use symbolic linearly
Expand Down Expand Up @@ -214,6 +255,14 @@ def with_iteration(self, index_name: str, rng: range,
if cmds:
self._stack[-1].append(LinSpaceIter(body=tuple(cmds), length=len(rng)))

@contextlib.contextmanager
def time_reversed(self) -> ContextManager['LinSpaceBuilder']:
self._stack.append([])
yield self
inner = self._stack.pop()
offset = len(self._ranges)
self._stack[-1].extend(node.reversed(offset, []) for node in reversed(inner))

def to_program(self) -> Optional[Sequence[LinSpaceNode]]:
if self._root():
return self._root()
Expand Down Expand Up @@ -428,7 +477,19 @@ def __init__(self, channels: int):

def change_state(self, cmd: Union[Set, Increment, Wait, Play]):
if isinstance(cmd, Play):
raise NotImplementedError("TODO: Implement arbitrary waveform simulation")
num = 17
dt = cmd.waveform.duration / num
t = TimeType(0)
for _ in range(num):
sample_time = np.array([float(t)])
values = []
for (idx, ch) in enumerate(cmd.channels):
self.current_values[idx] = values.append(cmd.waveform.get_sampled(channel=ch, sample_times=sample_time)[0])
self.history.append(
(self.time, self.current_values.copy())
)
self.time += dt
t += dt
elif isinstance(cmd, Wait):
self.history.append(
(self.time, self.current_values.copy())
Expand Down
10 changes: 10 additions & 0 deletions qupulse/program/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,16 @@ def with_iteration(self, index_name: str, rng: range,
top_frame.iterating = (index_name, value)
yield self

@contextmanager
def time_reversed(self) -> ContextManager['LoopBuilder']:
inner_builder = LoopBuilder()
yield inner_builder
inner_program = inner_builder.to_program()

if inner_program:
inner_program.reverse_inplace()
self._try_append(inner_program, None)

@contextmanager
def with_sequence(self, measurements: Optional[Sequence[MeasurementWindow]] = None) -> ContextManager['ProgramBuilder']:
top_frame = StackFrame(LoopGuard(self._top, measurements), None)
Expand Down
2 changes: 1 addition & 1 deletion qupulse/program/waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1257,7 +1257,7 @@ def unsafe_sample(self, channel: ChannelID, sample_times: np.ndarray,
else:
inner_output_array = output_array[::-1]
inner_output_array = self._inner.unsafe_sample(channel, inner_sample_times, output_array=inner_output_array)
if inner_output_array.base not in (output_array, output_array.base):
if id(inner_output_array.base) not in (id(output_array), id(output_array.base)):
# TODO: is there a guarantee by numpy we never end up here?
output_array[:] = inner_output_array[::-1]
return output_array
Expand Down
11 changes: 4 additions & 7 deletions qupulse/pulses/time_reversal_pulse_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Optional, Set, Dict, Union

from qupulse import ChannelID
from qupulse.program.loop import Loop
from qupulse.program import ProgramBuilder
from qupulse.program.waveforms import Waveform
from qupulse.serialization import PulseRegistryType
from qupulse.expressions import ExpressionScalar
Expand Down Expand Up @@ -50,12 +50,9 @@ def defined_channels(self) -> Set['ChannelID']:
def integral(self) -> Dict[ChannelID, ExpressionScalar]:
return self._inner.integral

def _internal_create_program(self, *, parent_loop: Loop, **kwargs) -> None:
inner_loop = Loop()
self._inner._internal_create_program(parent_loop=inner_loop, **kwargs)
inner_loop.reverse_inplace()

parent_loop.append_child(inner_loop)
def _internal_create_program(self, *, program_builder: ProgramBuilder, **kwargs) -> None:
with program_builder.time_reversed() as reversed_builder:
self._inner._internal_create_program(program_builder=reversed_builder, **kwargs)

def build_waveform(self,
*args, **kwargs) -> Optional[Waveform]:
Expand Down
2 changes: 1 addition & 1 deletion tests/program/linspace_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def assert_vm_output_almost_equal(test: TestCase, expected, actual):
test.assertEqual(t_e, t_a, f"Differing times in {idx} element")
test.assertEqual(len(vals_e), len(vals_a), f"Differing channel count in {idx} element")
for ch, (val_e, val_a) in enumerate(zip(vals_e, vals_a)):
test.assertAlmostEqual(val_e, val_a, msg=f"Differing values in {idx} element channel {ch}")
test.assertAlmostEqual(val_e, val_a, msg=f"Differing values in {idx} of {len(expected)} element channel {ch}")


class SingleRampTest(TestCase):
Expand Down
51 changes: 49 additions & 2 deletions tests/pulses/time_reversal_pulse_template_tests.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import unittest

import numpy as np

from qupulse.pulses import ConstantPT, FunctionPT
from qupulse.plotting import render
from qupulse.pulses.time_reversal_pulse_template import TimeReversalPulseTemplate
from qupulse.utils.types import TimeType
from qupulse.expressions import ExpressionScalar

from qupulse.program.loop import LoopBuilder
from qupulse.program.linspace import LinSpaceBuilder, LinSpaceVM, to_increment_commands
from tests.pulses.sequencing_dummies import DummyPulseTemplate
from tests.serialization_tests import SerializableTests

from tests.program.linspace_tests import assert_vm_output_almost_equal

class TimeReversalPulseTemplateTests(unittest.TestCase):
def test_simple_properties(self):
Expand All @@ -25,6 +30,48 @@ def test_simple_properties(self):

self.assertEqual(reversed_pt.identifier, 'reverse')

def test_time_reversal_loop(self):
inner = ConstantPT(4, {'a': 3}) @ FunctionPT('sin(t)', 5, channel='a')
manual_reverse = FunctionPT('sin(5 - t)', 5, channel='a') @ ConstantPT(4, {'a': 3})
time_reversed = TimeReversalPulseTemplate(inner)

program = time_reversed.create_program(program_builder=LoopBuilder())
manual_program = manual_reverse.create_program(program_builder=LoopBuilder())

t, data, _ = render(program, 9 / 10)
_, manual_data, _ = render(manual_program, 9 / 10)

np.testing.assert_allclose(data['a'], manual_data['a'])

def test_time_reversal_linspace(self):
constant_pt = ConstantPT(4, {'a': '3.0 + x * 1.0 + y * -0.3'})
function_pt = FunctionPT('sin(t)', 5, channel='a')
reversed_function_pt = FunctionPT('sin(5 - t)', 5, channel='a')

inner = (constant_pt @ function_pt).with_iteration('x', 6)
inner_manual = (reversed_function_pt @ constant_pt).with_iteration('x', (5, -1, -1))

outer = inner.with_time_reversal().with_iteration('y', 8)
outer_man = inner_manual.with_iteration('y', 8)

self.assertEqual(outer.duration, outer_man.duration)

program = outer.create_program(program_builder=LinSpaceBuilder(channels=('a',)))
manual_program = outer_man.create_program(program_builder=LinSpaceBuilder(channels=('a',)))

commands = to_increment_commands(program)
manual_commands = to_increment_commands(manual_program)

manual_vm = LinSpaceVM(1)
manual_vm.set_commands(manual_commands)
manual_vm.run()

vm = LinSpaceVM(1)
vm.set_commands(commands)
vm.run()

assert_vm_output_almost_equal(self, manual_vm.history, vm.history)


class TimeReversalPulseTemplateSerializationTests(unittest.TestCase, SerializableTests):
@property
Expand Down
Loading