Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change simpleexpression for equality handling #828

Open
wants to merge 35 commits into
base: feat/linspace_timesweep
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
7cfcb74
equality handling SimpleExpression; file refactor
Nomos11 May 31, 2024
74eb4d7
supply __lt__ as well to be logically consistent
Nomos11 May 31, 2024
f595806
implement dummywaveform abstract compare subset
Nomos11 Jun 3, 2024
bc985cd
tests
Nomos11 Jun 3, 2024
627a27f
Update setup.cfg
Nomos11 Jun 3, 2024
06aa1b3
delete wrong decorator
Nomos11 Jun 3, 2024
f87b1ea
somewhat fix the tests
Nomos11 Jun 3, 2024
e6b1e9e
uncomment imports
Nomos11 Jun 4, 2024
3c65a9c
forward voltage increment resolution
Nomos11 Jun 5, 2024
634cd30
wrongly assumed hardware resolution
Nomos11 Jun 5, 2024
7d6b01a
draft dependent waits & dependency domains
Nomos11 Jun 6, 2024
1ea5176
first syntactic debug
Nomos11 Jun 6, 2024
37ee51d
fix definition of iterations
Nomos11 Jun 6, 2024
b37b510
see if one can replace int->ChannelID
Nomos11 Jun 7, 2024
c796b61
fix channel trafo call
Nomos11 Jun 7, 2024
6d4e835
fix transform commands
Nomos11 Jun 8, 2024
9fd050c
resolution dependent set/increment
Nomos11 Jun 8, 2024
051e8b9
remove outdates resolution handling attempt
Nomos11 Jun 8, 2024
e58277c
math methods for resolution class
Nomos11 Jun 8, 2024
cf74051
fix domain check
Nomos11 Jun 8, 2024
2447e23
fix __mul__
Nomos11 Jun 8, 2024
a9f1710
test wf amp sweep
Nomos11 Jun 11, 2024
d9dad64
bugfix
Nomos11 Jun 11, 2024
5f22a52
not sure if correct (depstate comparison)
Nomos11 Jun 11, 2024
b04cbcf
fix some of the depkey confusion
Nomos11 Jun 12, 2024
e4c2366
more flexible repetition in sequence structure
Nomos11 Jun 13, 2024
f250226
dependency_key -> key for consistency
Nomos11 Jun 13, 2024
e84ecf8
always emit incr/set before wait
Nomos11 Jun 13, 2024
73513eb
dirty stepped play node in LinSpaceBuilder
Nomos11 Jun 15, 2024
468a46c
further bugfixes
Nomos11 Jun 16, 2024
a7eacf4
further bug patching
Nomos11 Jun 16, 2024
550a31c
hash Commands
Nomos11 Jun 17, 2024
fc8aee4
only modify commands that affect the current awg
Nomos11 Jul 5, 2024
b3176f0
re-commit P.S.' initial changes
Nomos11 Jul 10, 2024
def9369
re-commit P.S.' bugfixes
Nomos11 Jul 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 140 additions & 0 deletions qupulse/expressions/simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import numpy as np
from numbers import Real, Number
from typing import Optional, Union, Sequence, ContextManager, Mapping, Tuple, Generic, TypeVar, Iterable, Dict, List
from dataclasses import dataclass

from functools import total_ordering
from qupulse.utils.sympy import _lambdify_modules
from qupulse.expressions import sympy as sym_expr, Expression
from qupulse.utils.types import MeasurementWindow, TimeType, FrozenMapping


NumVal = TypeVar('NumVal', bound=Real)


@total_ordering
@dataclass
class SimpleExpression(Generic[NumVal]):
"""This is a potential hardware evaluable expression of the form

C + C1*R1 + C2*R2 + ...
where R1, R2, ... are potential runtime parameters.

The main use case is the expression of for loop dependent variables where the Rs are loop indices. There the
expressions can be calculated via simple increments.
"""

base: NumVal
offsets: Mapping[str, NumVal]

def __post_init__(self):
assert isinstance(self.offsets, Mapping)

def value(self, scope: Mapping[str, NumVal]) -> NumVal:
value = self.base
for name, factor in self.offsets.items():
value += scope[name] * factor
return value

def __abs__(self):
return abs(self.base)+sum([abs(o) for o in self.offsets.values()])

def __eq__(self, other):
#there is no good way to compare it without having a value,
#but cannot require more parameters in magic method?
#so have this weird full equality for now which doesn logically make sense
#in most cases to catch unintended consequences

if isinstance(other, (float, int, TimeType)):
return self.base==other and all([o==other for o in self.offsets])

if type(other) == type(self):
if len(self.offsets)!=len(other.offsets): return False
return self.base==other.base and all([o1==o2 for o1,o2 in zip(self.offsets,other.offsets)])

return NotImplemented

def __gt__(self, other):
return all([b for b in self._return_greater_comparison_bools(other)])

def __lt__(self, other):
return all([not b for b in self._return_greater_comparison_bools(other)])

def _return_greater_comparison_bools(self, other) -> List[bool]:
#there is no good way to compare it without having a value,
#but cannot require more parameters in magic method?
#so have this weird full equality for now which doesn logically make sense
#in most cases to catch unintended consequences
if isinstance(other, (float, int, TimeType)):
return [self.base>other] + [o>other for o in self.offsets.values()]

if type(other) == type(self):
if len(self.offsets)!=len(other.offsets): return [False]
return [self.base>other.base] + [o1>o2 for o1,o2 in zip(self.offsets.values(),other.offsets.values())]

return NotImplemented

def __add__(self, other):
if isinstance(other, (float, int, TimeType)):
return SimpleExpression(self.base + other, self.offsets)

if type(other) == type(self):
offsets = self.offsets.copy()
for name, value in other.offsets.items():
offsets[name] = value + offsets.get(name, 0)
return SimpleExpression(self.base + other.base, offsets)

return NotImplemented

def __radd__(self, other):
return self.__add__(other)

def __sub__(self, other):
return self.__add__(-other)

def __rsub__(self, other):
(-self).__add__(other)

def __neg__(self):
return SimpleExpression(-self.base, {name: -value for name, value in self.offsets.items()})

def __mul__(self, other: NumVal):
if isinstance(other, (float, int, TimeType)):
return SimpleExpression(self.base * other, {name: other * value for name, value in self.offsets.items()})

return NotImplemented

def __rmul__(self, other):
return self.__mul__(other)

def __truediv__(self, other):
inv = 1 / other
return self.__mul__(inv)

def __hash__(self):
return hash((self.base,frozenset(sorted(self.offsets.items()))))

@property
def free_symbols(self):
return ()

def _sympy_(self):
return self

def replace(self, r, s):
return self

def evaluate_in_scope_(self, *args, **kwargs):
# TODO: remove. It is currently required to avoid nesting this class in an expression for the MappedScope
# We can maybe replace is with a HardwareScope or something along those lines
return self


#alibi class to allow instance check?
@dataclass
class SimpleExpressionStepped(SimpleExpression):
step_nesting_level: int
rng: range


_lambdify_modules.append({'SimpleExpression': SimpleExpression, 'SimpleExpressionStepped': SimpleExpressionStepped})
44 changes: 35 additions & 9 deletions qupulse/hardware/awgs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
from qupulse.hardware.util import get_sample_times, not_none_indices
from qupulse.utils.types import ChannelID
from qupulse.program.linspace import LinSpaceNode, LinSpaceArbitraryWaveform, to_increment_commands, Command, \
Increment, Set as LSPSet, LoopLabel, LoopJmp, Wait, Play
Increment, Set as LSPSet, LoopLabel, LoopJmp, Wait, Play, DEFAULT_INCREMENT_RESOLUTION, DepDomain
from qupulse.program.loop import Loop
from qupulse.program.waveforms import Waveform
from qupulse.program.waveforms import Waveform, WaveformCollection
from qupulse.comparable import Comparable
from qupulse.utils.types import TimeType

Expand Down Expand Up @@ -191,6 +191,7 @@ def __init__(self, program: AllowedProgramTypes,
voltage_transformations: Tuple[Optional[Callable], ...],
sample_rate: TimeType,
waveforms: Sequence[Waveform] = None,
# voltage_resolution: Optional[float] = None,
program_type: _ProgramType = _ProgramType.Loop):
"""

Expand All @@ -204,6 +205,8 @@ def __init__(self, program: AllowedProgramTypes,
sample_rate:
waveforms: These waveforms are sampled and stored in _waveforms. If None the waveforms are extracted from
loop
# voltage_resolution: voltage resolution for LinSpaceProgram, i.e. 2**(-16) for 16 bit AWG
program_type: type of program from _ProgramType, determined by the ProgramBuilder used.
"""
assert len(channels) == len(amplitudes) == len(offsets) == len(voltage_transformations)

Expand All @@ -218,8 +221,11 @@ def __init__(self, program: AllowedProgramTypes,
self._program_type = program_type
self._program = program

# self._voltage_resolution = voltage_resolution

if program_type == _ProgramType.Linspace:
self._transformed_commands = self._transform_linspace_commands(to_increment_commands(program))
#!!! the voltage resolution may not be adequately represented if voltage transformations are not None?
self._transformed_commands = self._transform_linspace_commands(to_increment_commands(program,))

if waveforms is None:
if program_type is _ProgramType.Loop:
Expand All @@ -228,8 +234,18 @@ def __init__(self, program: AllowedProgramTypes,
elif program_type is _ProgramType.Linspace:
#not so clean
#TODO: also marker handling not optimal
waveforms = OrderedDict((command.waveform, None)
for command in self._transformed_commands if isinstance(command,Play)).keys()
waveforms_d = OrderedDict()
for command in self._transformed_commands:
if not isinstance(command,Play):
continue
if isinstance(command.waveform,Waveform):
waveforms_d[command.waveform] = None
elif isinstance(command.waveform,WaveformCollection):
for w in command.waveform.flatten():
waveforms_d[w] = None
else:
raise NotImplementedError()
waveforms = waveforms_d.keys()
else:
raise NotImplementedError()

Expand Down Expand Up @@ -267,20 +283,30 @@ def _channel_transformations(self) -> Mapping[ChannelID, ChannelTransformation]:

def _transform_linspace_commands(self, command_list: List[Command]) -> List[Command]:
# all commands = Union[Increment, Set, LoopLabel, LoopJmp, Wait, Play]
trafos_by_channel_idx = list(self._channel_transformations().values())

# TODO: voltage resolution

# trafos_by_channel_idx = list(self._channel_transformations().values())
# increment_domains_to_transform = {DepDomain.VOLTAGE, DepDomain.WF_SCALE, DepDomain.WF_OFFSET}

for command in command_list:
if isinstance(command, (LoopLabel, LoopJmp, Play, Wait)):
# play is handled by transforming the sampled waveform
continue
elif isinstance(command, Increment):
ch_trafo = trafos_by_channel_idx[command.channel]
if command.key.domain is not DepDomain.VOLTAGE:
#for sweeps of wf-scale and wf-offset, the channel amplitudes/offsets are already considered in the wf sampling.
continue
ch_trafo = self._channel_transformations()[command.channel]
if ch_trafo.voltage_transformation:
raise RuntimeError("Cannot apply a voltage transformation to a linspace increment command")
command.value /= ch_trafo.amplitude
elif isinstance(command, LSPSet):
ch_trafo = trafos_by_channel_idx[command.channel]
if command.key.domain is not DepDomain.VOLTAGE:
#for sweeps of wf-scale and wf-offset, the channel amplitudes/offsets are already considered in the wf sampling.
continue
ch_trafo = self._channel_transformations()[command.channel]
if ch_trafo.voltage_transformation:
# for the case of swept parameters, this is defaulted to identity
command.value = float(ch_trafo.voltage_transformation(command.value))
command.value -= ch_trafo.offset
command.value /= ch_trafo.amplitude
Expand Down
104 changes: 10 additions & 94 deletions qupulse/program/__init__.py
Original file line number Diff line number Diff line change
@@ -1,101 +1,12 @@
import contextlib
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Union, Sequence, ContextManager, Mapping, Tuple, Generic, TypeVar, Iterable, Dict
from numbers import Real, Number

import numpy as np
from typing import Protocol, runtime_checkable

from qupulse._program.waveforms import Waveform
from qupulse.utils.types import MeasurementWindow, TimeType, FrozenMapping
from qupulse.utils.types import MeasurementWindow, TimeType
from qupulse._program.volatile import VolatileRepetitionCount
from qupulse.parameter_scope import Scope
from qupulse.expressions import sympy as sym_expr, Expression
from qupulse.utils.sympy import _lambdify_modules

from typing import Protocol, runtime_checkable


NumVal = TypeVar('NumVal', bound=Real)


@dataclass
class SimpleExpression(Generic[NumVal]):
"""This is a potential hardware evaluable expression of the form

C + C1*R1 + C2*R2 + ...
where R1, R2, ... are potential runtime parameters.

The main use case is the expression of for loop dependent variables where the Rs are loop indices. There the
expressions can be calculated via simple increments.
"""

base: NumVal
offsets: Mapping[str, NumVal]

def __post_init__(self):
assert isinstance(self.offsets, Mapping)

def value(self, scope: Mapping[str, NumVal]) -> NumVal:
value = self.base
for name, factor in self.offsets:
value += scope[name] * factor
return value

def __add__(self, other):
if isinstance(other, (float, int, TimeType)):
return SimpleExpression(self.base + other, self.offsets)

if type(other) == type(self):
offsets = self.offsets.copy()
for name, value in other.offsets.items():
offsets[name] = value + offsets.get(name, 0)
return SimpleExpression(self.base + other.base, offsets)

return NotImplemented

def __radd__(self, other):
return self.__add__(other)

def __sub__(self, other):
return self.__add__(-other)

def __rsub__(self, other):
(-self).__add__(other)

def __neg__(self):
return SimpleExpression(-self.base, {name: -value for name, value in self.offsets.items()})

def __mul__(self, other: NumVal):
if isinstance(other, (float, int, TimeType)):
return SimpleExpression(self.base * other, {name: other * value for name, value in self.offsets.items()})

return NotImplemented

def __rmul__(self, other):
return self.__mul__(other)

def __truediv__(self, other):
inv = 1 / other
return self.__mul__(inv)

@property
def free_symbols(self):
return ()

def _sympy_(self):
return self

def replace(self, r, s):
return self

def evaluate_in_scope_(self, *args, **kwargs):
# TODO: remove. It is currently required to avoid nesting this class in an expression for the MappedScope
# We can maybe replace is with a HardwareScope or something along those lines
return self


_lambdify_modules.append({'SimpleExpression': SimpleExpression})
from qupulse.expressions import sympy as sym_expr
from qupulse.expressions.simple import SimpleExpression


RepetitionCount = Union[int, VolatileRepetitionCount, SimpleExpression[int]]
Expand Down Expand Up @@ -155,9 +66,14 @@ def new_subprogram(self, global_transformation: 'Transformation' = None) -> Cont
it is not empty."""

def with_iteration(self, index_name: str, rng: range,
pt_obj: 'ForLoopPT', #hack this in for now.
# can be placed more suitably, like in pulsemetadata later on, but need some working thing now.
measurements: Optional[Sequence[MeasurementWindow]] = None) -> Iterable['ProgramBuilder']:
pass


def evaluate_nested_stepping(self, scope: Scope, parameter_names: set[str]) -> bool:
return False

def to_program(self) -> Optional[Program]:
"""Further addition of new elements might fail after finalizing the program."""

Expand Down
Loading
Loading