diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index 43d01113..600a6f82 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -341,7 +341,7 @@ def _add_iteration_node(self, node: LinSpaceIter): self.add_node(node.body) if node.length > 1: - self.iterations[-1] = node.length + self.iterations[-1] = node.length - 1 label, jmp = self.new_loop(node.length - 1) self.commands.append(label) self.add_node(node.body) @@ -412,3 +412,72 @@ def to_increment_commands(linspace_nodes: Sequence[LinSpaceNode]) -> List[Comman state.add_node(linspace_nodes) return state.commands + +class LinSpaceVM: + def __init__(self, channels: int): + self.current_values = [np.nan] * channels + self.time = TimeType(0) + self.registers = tuple({} for _ in range(channels)) + + self.history: List[Tuple[TimeType, Tuple[float, ...]]] = [] + + self.commands = None + self.label_targets = None + self.label_counts = None + self.current_command = None + + def change_state(self, cmd: Union[Set, Increment, Wait, Play]): + if isinstance(cmd, Play): + raise NotImplementedError("TODO: Implement arbitrary waveform simulation") + elif isinstance(cmd, Wait): + self.history.append( + (self.time, self.current_values.copy()) + ) + self.time += cmd.duration + elif isinstance(cmd, Set): + self.current_values[cmd.channel] = cmd.value + self.registers[cmd.channel][cmd.key] = cmd.value + elif isinstance(cmd, Increment): + value = self.registers[cmd.channel][cmd.dependency_key] + value += cmd.value + self.registers[cmd.channel][cmd.dependency_key] = value + self.current_values[cmd.channel] = value + else: + raise NotImplementedError(cmd) + + def set_commands(self, commands: Sequence[Command]): + self.commands = [] + self.label_targets = {} + self.label_counts = {} + self.current_command = None + + for cmd in commands: + self.commands.append(cmd) + if isinstance(cmd, LoopLabel): + # a loop label signifies a reset count followed by the actual label that targets the following command + assert cmd.idx not in self.label_targets + self.label_targets[cmd.idx] = len(self.commands) + + self.current_command = 0 + + def step(self): + cmd = self.commands[self.current_command] + if isinstance(cmd, LoopJmp): + if self.label_counts[cmd.idx] > 0: + self.label_counts[cmd.idx] -= 1 + self.current_command = self.label_targets[cmd.idx] + else: + # ignore jump + self.current_command += 1 + elif isinstance(cmd, LoopLabel): + self.label_counts[cmd.idx] = cmd.count - 1 + self.current_command += 1 + else: + self.change_state(cmd) + self.current_command += 1 + + def run(self): + while self.current_command < len(self.commands): + self.step() + + diff --git a/tests/program/linspace_tests.py b/tests/program/linspace_tests.py index 03a5b297..23d050ac 100644 --- a/tests/program/linspace_tests.py +++ b/tests/program/linspace_tests.py @@ -6,6 +6,17 @@ from qupulse.program.linspace import * from qupulse.program.transformation import * + +def assert_vm_output_almost_equal(test: TestCase, expected, actual): + """Compare two vm outputs with default TestCase.assertAlmostEqual accuracy""" + test.assertEqual(len(expected), len(actual)) + for idx, ((t_e, vals_e), (t_a, vals_a)) in enumerate(zip(expected, actual)): + test.assertEqual(t_e, t_a, f"Differing times in {idx} element") + test.assertEqual(len(vals_e), len(vals_a), f"Differing channel count in {idx} element") + for ch, (val_e, val_a) in enumerate(zip(vals_e, vals_a)): + test.assertAlmostEqual(val_e, val_a, msg=f"Differing values in {idx} element channel {ch}") + + class SingleRampTest(TestCase): def setUp(self): hold = ConstantPT(10 ** 6, {'a': '-1. + idx * 0.01'}) @@ -32,6 +43,10 @@ def setUp(self): LoopJmp(0) ] + self.output = [ + (TimeType(10**6 * idx), [sum([-1.0] + [0.01] * idx)]) for idx in range(200) + ] + def test_program(self): program_builder = LinSpaceBuilder(('a',)) program = self.pulse_template.create_program(program_builder=program_builder) @@ -41,6 +56,12 @@ def test_commands(self): commands = to_increment_commands([self.program]) self.assertEqual(self.commands, commands) + def test_output(self): + vm = LinSpaceVM(1) + vm.set_commands(commands=self.commands) + vm.run() + assert_vm_output_almost_equal(self, self.output, vm.history) + class PlainCSDTest(TestCase): def setUp(self): @@ -74,7 +95,7 @@ def setUp(self): LoopLabel(1, 99), - Increment(0, -2.0, key_0), + Increment(0, -1.99, key_0), Increment(1, 0.02, key_1), Wait(TimeType(10 ** 6)), @@ -86,6 +107,16 @@ def setUp(self): LoopJmp(1), ] + a_values = [sum([-1.] + [0.01] * i) for i in range(200)] + b_values = [sum([-.5] + [0.02] * j) for j in range(100)] + + self.output = [ + ( + TimeType(10 ** 6 * (i + 200 * j)), + [a_values[i], b_values[j]] + ) for j in range(100) for i in range(200) + ] + def test_program(self): program_builder = LinSpaceBuilder(('a', 'b')) program = self.pulse_template.create_program(program_builder=program_builder) @@ -95,13 +126,20 @@ def test_increment_commands(self): commands = to_increment_commands([self.program]) self.assertEqual(self.commands, commands) + def test_output(self): + vm = LinSpaceVM(2) + vm.set_commands(self.commands) + vm.run() + assert_vm_output_almost_equal(self, self.output, vm.history) + class TiltedCSDTest(TestCase): def setUp(self): + repetition_count = 3 hold = ConstantPT(10**6, {'a': '-1. + idx_a * 0.01 + idx_b * 1e-3', 'b': '-.5 + idx_b * 0.02 - 3e-3 * idx_a'}) scan_a = hold.with_iteration('idx_a', 200) self.pulse_template = scan_a.with_iteration('idx_b', 100) - self.repeated_pt = self.pulse_template.with_repetition(42) + self.repeated_pt = self.pulse_template.with_repetition(repetition_count) self.program = LinSpaceIter(length=100, body=(LinSpaceIter( length=200, @@ -113,7 +151,7 @@ def setUp(self): duration_factors=None ),) ),)) - self.repeated_program = LinSpaceRepeat(body=(self.program,), count=42) + self.repeated_program = LinSpaceRepeat(body=(self.program,), count=repetition_count) key_0 = DepKey.from_voltages((1e-3, 0.01,), DEFAULT_INCREMENT_RESOLUTION) key_1 = DepKey.from_voltages((0.02, -3e-3), DEFAULT_INCREMENT_RESOLUTION) @@ -131,8 +169,8 @@ def setUp(self): LoopLabel(1, 99), - Increment(0, 1e-3 + -200 * 1e-2, key_0), - Increment(1, 0.02 + -200 * -3e-3, key_1), + Increment(0, 1e-3 + -199 * 1e-2, key_0), + Increment(1, 0.02 + -199 * -3e-3, key_1), Wait(TimeType(10 ** 6)), LoopLabel(2, 199), @@ -147,7 +185,19 @@ def setUp(self): for cmd in inner_commands: if hasattr(cmd, 'idx'): cmd.idx += 1 - self.repeated_commands = [LoopLabel(0, 42)] + inner_commands + [LoopJmp(0)] + self.repeated_commands = [LoopLabel(0, repetition_count)] + inner_commands + [LoopJmp(0)] + + self.output = [ + ( + TimeType(10 ** 6 * (i + 200 * j)), + [-1. + i * 0.01 + j * 1e-3, -.5 + j * 0.02 - 3e-3 * i] + ) for j in range(100) for i in range(200) + ] + self.repeated_output = [ + (t + TimeType(10**6) * (n * 100 * 200), vals) + for n in range(repetition_count) + for t, vals in self.output + ] def test_program(self): program_builder = LinSpaceBuilder(('a', 'b')) @@ -167,6 +217,18 @@ def test_repeated_increment_commands(self): commands = to_increment_commands([self.repeated_program]) self.assertEqual(self.repeated_commands, commands) + def test_output(self): + vm = LinSpaceVM(2) + vm.set_commands(self.commands) + vm.run() + assert_vm_output_almost_equal(self, self.output, vm.history) + + def test_repeated_output(self): + vm = LinSpaceVM(2) + vm.set_commands(self.repeated_commands) + vm.run() + assert_vm_output_almost_equal(self, self.repeated_output, vm.history) + class SingletLoadProcessing(TestCase): def setUp(self): @@ -223,7 +285,7 @@ def setUp(self): Set(0, -0.4), Set(1, -0.3), Wait(TimeType(10 ** 5)), - Increment(0, -2.0, key_0), + Increment(0, -1.99, key_0), Increment(1, 0.02, key_1), Wait(TimeType(10 ** 6)), Set(0, 0.05), @@ -247,6 +309,23 @@ def setUp(self): LoopJmp(1), ] + self.output = [] + time = TimeType(0) + for idx_b in range(100): + for idx_a in range(200): + self.output.append( + (time, [-.4, -.3]) + ) + time += 10 ** 5 + self.output.append( + (time, [-1. + idx_a * 0.01, -.5 + idx_b * 0.02]) + ) + time += 10 ** 6 + self.output.append( + (time, [0.05, 0.06]) + ) + time += 10 ** 5 + def test_singlet_scan_program(self): program_builder = LinSpaceBuilder(('a', 'b')) program = self.pulse_template.create_program(program_builder=program_builder) @@ -256,6 +335,12 @@ def test_singlet_scan_commands(self): commands = to_increment_commands([self.program]) self.assertEqual(self.commands, commands) + def test_singlet_scan_output(self): + vm = LinSpaceVM(2) + vm.set_commands(self.commands) + vm.run() + assert_vm_output_almost_equal(self, self.output, vm.history) + class TransformedRampTest(TestCase): def setUp(self):