Skip to content

Commit

Permalink
Merge pull request #861 from qutech/issues/860_linspace_vm
Browse files Browse the repository at this point in the history
Add simulator for linspace program
  • Loading branch information
terrorfisch authored Jun 21, 2024
2 parents 83a6421 + e2801ff commit c6b0ff4
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 8 deletions.
71 changes: 70 additions & 1 deletion qupulse/program/linspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def _add_iteration_node(self, node: LinSpaceIter):
self.add_node(node.body)

if node.length > 1:
self.iterations[-1] = node.length
self.iterations[-1] = node.length - 1
label, jmp = self.new_loop(node.length - 1)
self.commands.append(label)
self.add_node(node.body)
Expand Down Expand Up @@ -412,3 +412,72 @@ def to_increment_commands(linspace_nodes: Sequence[LinSpaceNode]) -> List[Comman
state.add_node(linspace_nodes)
return state.commands


class LinSpaceVM:
def __init__(self, channels: int):
self.current_values = [np.nan] * channels
self.time = TimeType(0)
self.registers = tuple({} for _ in range(channels))

self.history: List[Tuple[TimeType, Tuple[float, ...]]] = []

self.commands = None
self.label_targets = None
self.label_counts = None
self.current_command = None

def change_state(self, cmd: Union[Set, Increment, Wait, Play]):
if isinstance(cmd, Play):
raise NotImplementedError("TODO: Implement arbitrary waveform simulation")
elif isinstance(cmd, Wait):
self.history.append(
(self.time, self.current_values.copy())
)
self.time += cmd.duration
elif isinstance(cmd, Set):
self.current_values[cmd.channel] = cmd.value
self.registers[cmd.channel][cmd.key] = cmd.value
elif isinstance(cmd, Increment):
value = self.registers[cmd.channel][cmd.dependency_key]
value += cmd.value
self.registers[cmd.channel][cmd.dependency_key] = value
self.current_values[cmd.channel] = value
else:
raise NotImplementedError(cmd)

def set_commands(self, commands: Sequence[Command]):
self.commands = []
self.label_targets = {}
self.label_counts = {}
self.current_command = None

for cmd in commands:
self.commands.append(cmd)
if isinstance(cmd, LoopLabel):
# a loop label signifies a reset count followed by the actual label that targets the following command
assert cmd.idx not in self.label_targets
self.label_targets[cmd.idx] = len(self.commands)

self.current_command = 0

def step(self):
cmd = self.commands[self.current_command]
if isinstance(cmd, LoopJmp):
if self.label_counts[cmd.idx] > 0:
self.label_counts[cmd.idx] -= 1
self.current_command = self.label_targets[cmd.idx]
else:
# ignore jump
self.current_command += 1
elif isinstance(cmd, LoopLabel):
self.label_counts[cmd.idx] = cmd.count - 1
self.current_command += 1
else:
self.change_state(cmd)
self.current_command += 1

def run(self):
while self.current_command < len(self.commands):
self.step()


99 changes: 92 additions & 7 deletions tests/program/linspace_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,17 @@
from qupulse.program.linspace import *
from qupulse.program.transformation import *


def assert_vm_output_almost_equal(test: TestCase, expected, actual):
"""Compare two vm outputs with default TestCase.assertAlmostEqual accuracy"""
test.assertEqual(len(expected), len(actual))
for idx, ((t_e, vals_e), (t_a, vals_a)) in enumerate(zip(expected, actual)):
test.assertEqual(t_e, t_a, f"Differing times in {idx} element")
test.assertEqual(len(vals_e), len(vals_a), f"Differing channel count in {idx} element")
for ch, (val_e, val_a) in enumerate(zip(vals_e, vals_a)):
test.assertAlmostEqual(val_e, val_a, msg=f"Differing values in {idx} element channel {ch}")


class SingleRampTest(TestCase):
def setUp(self):
hold = ConstantPT(10 ** 6, {'a': '-1. + idx * 0.01'})
Expand All @@ -32,6 +43,10 @@ def setUp(self):
LoopJmp(0)
]

self.output = [
(TimeType(10**6 * idx), [sum([-1.0] + [0.01] * idx)]) for idx in range(200)
]

def test_program(self):
program_builder = LinSpaceBuilder(('a',))
program = self.pulse_template.create_program(program_builder=program_builder)
Expand All @@ -41,6 +56,12 @@ def test_commands(self):
commands = to_increment_commands([self.program])
self.assertEqual(self.commands, commands)

def test_output(self):
vm = LinSpaceVM(1)
vm.set_commands(commands=self.commands)
vm.run()
assert_vm_output_almost_equal(self, self.output, vm.history)


class PlainCSDTest(TestCase):
def setUp(self):
Expand Down Expand Up @@ -74,7 +95,7 @@ def setUp(self):

LoopLabel(1, 99),

Increment(0, -2.0, key_0),
Increment(0, -1.99, key_0),
Increment(1, 0.02, key_1),
Wait(TimeType(10 ** 6)),

Expand All @@ -86,6 +107,16 @@ def setUp(self):
LoopJmp(1),
]

a_values = [sum([-1.] + [0.01] * i) for i in range(200)]
b_values = [sum([-.5] + [0.02] * j) for j in range(100)]

self.output = [
(
TimeType(10 ** 6 * (i + 200 * j)),
[a_values[i], b_values[j]]
) for j in range(100) for i in range(200)
]

def test_program(self):
program_builder = LinSpaceBuilder(('a', 'b'))
program = self.pulse_template.create_program(program_builder=program_builder)
Expand All @@ -95,13 +126,20 @@ def test_increment_commands(self):
commands = to_increment_commands([self.program])
self.assertEqual(self.commands, commands)

def test_output(self):
vm = LinSpaceVM(2)
vm.set_commands(self.commands)
vm.run()
assert_vm_output_almost_equal(self, self.output, vm.history)


class TiltedCSDTest(TestCase):
def setUp(self):
repetition_count = 3
hold = ConstantPT(10**6, {'a': '-1. + idx_a * 0.01 + idx_b * 1e-3', 'b': '-.5 + idx_b * 0.02 - 3e-3 * idx_a'})
scan_a = hold.with_iteration('idx_a', 200)
self.pulse_template = scan_a.with_iteration('idx_b', 100)
self.repeated_pt = self.pulse_template.with_repetition(42)
self.repeated_pt = self.pulse_template.with_repetition(repetition_count)

self.program = LinSpaceIter(length=100, body=(LinSpaceIter(
length=200,
Expand All @@ -113,7 +151,7 @@ def setUp(self):
duration_factors=None
),)
),))
self.repeated_program = LinSpaceRepeat(body=(self.program,), count=42)
self.repeated_program = LinSpaceRepeat(body=(self.program,), count=repetition_count)

key_0 = DepKey.from_voltages((1e-3, 0.01,), DEFAULT_INCREMENT_RESOLUTION)
key_1 = DepKey.from_voltages((0.02, -3e-3), DEFAULT_INCREMENT_RESOLUTION)
Expand All @@ -131,8 +169,8 @@ def setUp(self):

LoopLabel(1, 99),

Increment(0, 1e-3 + -200 * 1e-2, key_0),
Increment(1, 0.02 + -200 * -3e-3, key_1),
Increment(0, 1e-3 + -199 * 1e-2, key_0),
Increment(1, 0.02 + -199 * -3e-3, key_1),
Wait(TimeType(10 ** 6)),

LoopLabel(2, 199),
Expand All @@ -147,7 +185,19 @@ def setUp(self):
for cmd in inner_commands:
if hasattr(cmd, 'idx'):
cmd.idx += 1
self.repeated_commands = [LoopLabel(0, 42)] + inner_commands + [LoopJmp(0)]
self.repeated_commands = [LoopLabel(0, repetition_count)] + inner_commands + [LoopJmp(0)]

self.output = [
(
TimeType(10 ** 6 * (i + 200 * j)),
[-1. + i * 0.01 + j * 1e-3, -.5 + j * 0.02 - 3e-3 * i]
) for j in range(100) for i in range(200)
]
self.repeated_output = [
(t + TimeType(10**6) * (n * 100 * 200), vals)
for n in range(repetition_count)
for t, vals in self.output
]

def test_program(self):
program_builder = LinSpaceBuilder(('a', 'b'))
Expand All @@ -167,6 +217,18 @@ def test_repeated_increment_commands(self):
commands = to_increment_commands([self.repeated_program])
self.assertEqual(self.repeated_commands, commands)

def test_output(self):
vm = LinSpaceVM(2)
vm.set_commands(self.commands)
vm.run()
assert_vm_output_almost_equal(self, self.output, vm.history)

def test_repeated_output(self):
vm = LinSpaceVM(2)
vm.set_commands(self.repeated_commands)
vm.run()
assert_vm_output_almost_equal(self, self.repeated_output, vm.history)


class SingletLoadProcessing(TestCase):
def setUp(self):
Expand Down Expand Up @@ -223,7 +285,7 @@ def setUp(self):
Set(0, -0.4),
Set(1, -0.3),
Wait(TimeType(10 ** 5)),
Increment(0, -2.0, key_0),
Increment(0, -1.99, key_0),
Increment(1, 0.02, key_1),
Wait(TimeType(10 ** 6)),
Set(0, 0.05),
Expand All @@ -247,6 +309,23 @@ def setUp(self):
LoopJmp(1),
]

self.output = []
time = TimeType(0)
for idx_b in range(100):
for idx_a in range(200):
self.output.append(
(time, [-.4, -.3])
)
time += 10 ** 5
self.output.append(
(time, [-1. + idx_a * 0.01, -.5 + idx_b * 0.02])
)
time += 10 ** 6
self.output.append(
(time, [0.05, 0.06])
)
time += 10 ** 5

def test_singlet_scan_program(self):
program_builder = LinSpaceBuilder(('a', 'b'))
program = self.pulse_template.create_program(program_builder=program_builder)
Expand All @@ -256,6 +335,12 @@ def test_singlet_scan_commands(self):
commands = to_increment_commands([self.program])
self.assertEqual(self.commands, commands)

def test_singlet_scan_output(self):
vm = LinSpaceVM(2)
vm.set_commands(self.commands)
vm.run()
assert_vm_output_almost_equal(self, self.output, vm.history)


class TransformedRampTest(TestCase):
def setUp(self):
Expand Down

0 comments on commit c6b0ff4

Please sign in to comment.