Skip to content

Commit

Permalink
Update for command classes
Browse files Browse the repository at this point in the history
  • Loading branch information
thierry-martinez committed Jul 10, 2024
1 parent 7dfb7c0 commit 784dd34
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 79 deletions.
3 changes: 2 additions & 1 deletion graphix/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pydantic import BaseModel

from graphix.pauli import Plane
from graphix.parameter import ExpressionOrFloat

Node = int

Expand Down Expand Up @@ -48,7 +49,7 @@ class M(Command):
kind: CommandKind = CommandKind.M
node: Node
plane: Plane = Plane.XY
angle: float = 0.0
angle: ExpressionOrFloat = 0.0
s_domain: list[Node] = []
t_domain: list[Node] = []
vop: int = 0
Expand Down
9 changes: 5 additions & 4 deletions graphix/instruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pydantic import BaseModel

from graphix.pauli import Plane
from graphix.parameter import ExpressionOrFloat


class InstructionKind(enum.Enum):
Expand All @@ -18,7 +19,7 @@ class InstructionKind(enum.Enum):
X = "X"
Y = "Y"
Z = "Z"
I = "I"
I = "I" # noqa: E741
M = "M"
RX = "RX"
RY = "RY"
Expand Down Expand Up @@ -58,7 +59,7 @@ class RotationInstruction(OneQubitInstruction):
Rotation instruction base class model.
"""

angle: float
angle: ExpressionOrFloat


class OneControlInstruction(OneQubitInstruction):
Expand Down Expand Up @@ -166,7 +167,7 @@ class Z(OneQubitInstruction):
kind: InstructionKind = InstructionKind.Z


class I(OneQubitInstruction):
class I(OneQubitInstruction): # noqa: E742
"""
I circuit instruction.
"""
Expand All @@ -181,7 +182,7 @@ class M(OneQubitInstruction):

kind: InstructionKind = InstructionKind.M
plane: Plane
angle: float
angle: ExpressionOrFloat


class RX(RotationInstruction):
Expand Down
41 changes: 22 additions & 19 deletions graphix/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@

from __future__ import annotations

from abc import ABC, abstractmethod
import numbers
from abc import ABC, abstractmethod
import pydantic
import pydantic_core
import typing


class Expression(ABC):
"""Expression with parameters."""
Expand Down Expand Up @@ -89,22 +93,31 @@ def __str__(self) -> str: ...
@abstractmethod
def subs(self, variable: Parameter, value: ExpressionOrNumber) -> ExpressionOrNumber: ...

@abstractmethod
def flatten(self) -> ExpressionOrNumber: ...
@classmethod
def __get_pydantic_core_schema__(
cls, source_type: typing.Any, handler: pydantic.GetCoreSchemaHandler
) -> pydantic_core.CoreSchema:
def check_expression(obj) -> Expression:
if not isinstance(obj, Expression):
raise ValueError("Expression expected")
return obj

@abstractmethod
def conj(self) -> ExpressionOrNumber: ...
return pydantic_core.core_schema.no_info_plain_validator_function(function=check_expression)


class PlaceholderOperationError(ValueError):
def __init__(self):
super().__init__("Placeholder angles do not support any form of computation before substitution. Either use `subst` with an actual value before the computation, or use a symbolic parameter implementation, such that https://github.com/TeamGraphix/graphix-symbolic .")
super().__init__(
"Placeholder angles do not support any form of computation before substitution. Either use `subst` with an actual value before the computation, or use a symbolic parameter implementation, such that https://github.com/TeamGraphix/graphix-symbolic ."
)


class Parameter(Expression):
"""Abstract class for substituable parameter."""

...


class Placeholder(Parameter):
"""Placeholder for measurement angles, which allows the pattern optimizations
without specifying measurement angles for measurement commands.
Expand Down Expand Up @@ -133,12 +146,11 @@ def __str__(self) -> str:
return self.__name

def subs(self, variable: Parameter, value: ExpressionOrNumber) -> ExpressionOrNumber:
if self == variable:
if self is variable:
if isinstance(value, numbers.Number):
return complex(value)
return value
else:
return self
return self

def __mul__(self, other) -> ExpressionOrNumber:
return NotImplemented
Expand Down Expand Up @@ -206,17 +218,8 @@ def conjugate(self) -> ExpressionOrNumber:
def sqrt(self) -> ExpressionOrNumber:
raise PlaceholderOperationError()

def subs(self, variable: Parameter, value: ExpressionOrNumber) -> ExpressionOrNumber:
if variable is self:
return value
return self

def flatten(self) -> ExpressionOrNumber:
raise PlaceholderOperationError()

def conj(self) -> ExpressionOrNumber:
raise PlaceholderOperationError()

ExpressionOrFloat = Expression | float

ExpressionOrNumber = Expression | numbers.Number

Expand Down
16 changes: 7 additions & 9 deletions graphix/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@

from __future__ import annotations

from copy import deepcopy
import numbers
from copy import deepcopy

import networkx as nx
import numpy as np

import graphix.clifford
import graphix.parameter
import graphix.pauli
from graphix import command
import graphix.parameter
from graphix.clifford import CLIFFORD_CONJ, CLIFFORD_MEASURE, CLIFFORD_TO_QASM3
from graphix.device_interface import PatternRunner
from graphix.gflow import find_flow, find_gflow, get_layers
Expand Down Expand Up @@ -985,7 +985,6 @@ def get_meas_plane(self):
list of planes representing measurement plane for each node.
"""
meas_plane = dict()
order = [graphix.pauli.Axis.X, graphix.pauli.Axis.Y, graphix.pauli.Axis.Z]
for cmd in self.__seq:
if cmd.kind == command.CommandKind.M:
mplane = cmd.plane
Expand Down Expand Up @@ -1419,8 +1418,8 @@ def to_qasm3(self, filename):
file.write("// measurement result of qubit q" + str(id) + "\n")
file.write("bit c" + str(id) + " = " + str(res) + ";\n")
file.write("\n")
for command in self.__seq:
for line in cmd_to_qasm3(command):
for cmd in self.__seq:
for line in cmd_to_qasm3(cmd):
file.write(line)

def is_parameterized(self) -> bool:
Expand All @@ -1430,7 +1429,7 @@ def is_parameterized(self) -> bool:
expression that is not a number, typically an instance of `sympy.Expr`
(but we don't force to choose `sympy` here).
"""
return any(not isinstance(cmd[3], numbers.Number) for cmd in self if cmd[0] == "M")
return any(not isinstance(cmd.angle, numbers.Number) for cmd in self if cmd.kind == command.CommandKind.M)

def subs(self, variable, substitute) -> Pattern:
"""Return a copy of the pattern where all occurrences of the
Expand All @@ -1447,9 +1446,8 @@ def subs(self, variable, substitute) -> Pattern:
"""
result = Pattern(input_nodes=self.input_nodes)
for cmd in self:
if cmd[0] == "M":
new_cmd = cmd.copy()
new_cmd[3] = graphix.parameter.subs(new_cmd[3], variable, substitute)
if cmd.kind == command.CommandKind.M:
new_cmd = cmd.model_copy(update={"angle": graphix.parameter.subs(cmd.angle, variable, substitute)})
result.add(new_cmd)
else:
result.add(cmd)
Expand Down
9 changes: 8 additions & 1 deletion graphix/sim/density_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,14 @@ def subs(self, variable, substitute) -> DensityMatrix:
class DensityMatrixBackend(Backend):
"""MBQC simulator with density matrix method."""

def __init__(self, pattern, max_qubit_num=12, pr_calc=True, input_state: Data = graphix.states.BasicStates.PLUS, rng: np.random.Generator | None = None):
def __init__(
self,
pattern,
max_qubit_num=12,
pr_calc=True,
input_state: Data = graphix.states.BasicStates.PLUS,
rng: np.random.Generator | None = None,
):
"""
Parameters
----------
Expand Down
2 changes: 1 addition & 1 deletion graphix/sim/statevec.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import numpy as np
import numpy.typing as npt

import graphix.pauli
import graphix.parameter
import graphix.pauli
import graphix.sim.base_backend
import graphix.states
import graphix.types
Expand Down
Loading

0 comments on commit 784dd34

Please sign in to comment.