Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise an error or warning in qml.compile if basis_set contains operations instead of strings #6137

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
23 changes: 21 additions & 2 deletions pennylane/transforms/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,11 @@ def compile(
expansion will continue until gates in the specific set are
reached. If no basis set is specified, a default of
``pennylane.ops.__all__`` will be used. This decomposes templates and
operator arithmetic.
operator arithmetic. If the basis set contains
any items that are not strings or subclasses of ``qml.operation.Operator``,
a ValueError will be raised. Additionally, if the basis set contains
operator types instead of names and results in an empty set, an error
will be raised.
num_passes (int): The number of times to apply the set of transforms in
``pipeline``. The default is to perform each transform once;
however, doing so may produce a new circuit where applying the set
Expand Down Expand Up @@ -182,12 +186,27 @@ def qfunc(x, y, z):
with QueuingManager.stop_recording():
basis_set = basis_set or all_ops

if basis_set:
# Handle the case where basis_set becomes equivalent to an empty list due to improper types
class_types = tuple(
o
for o in basis_set
if isinstance(o, type) and issubclass(o, qml.operation.Operator)
)
class_names = set(o for o in basis_set if isinstance(o, str))

# Convert operator types to their names and merge with string-based names
basis_set = class_names.union({op.name for op in class_types})

if not basis_set:
raise ValueError("basis_set contains no valid operation names or types.")

def stop_at(obj):
if not isinstance(obj, qml.operation.Operator):
return True
if not obj.has_decomposition:
return True
return obj.name in basis_set and (not getattr(obj, "only_visual", False))
return obj.name in class_names or isinstance(obj, class_types)

[expanded_tape], _ = qml.devices.preprocess.decompose(
tape,
Expand Down
39 changes: 39 additions & 0 deletions tests/transforms/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,45 @@ def test_compile_invalid_num_passes(self):
with pytest.raises(ValueError, match="Number of passes must be an integer"):
transformed_qnode(0.1, 0.2, 0.3)

def test_stop_at_with_simple_basis(self):
"""
Test the stop_at function with a valid basis set and a simple quantum function.
"""

# Define a simple quantum function
def qfunc(x):
qml.RX(x, wires=0)
return qml.expval(qml.PauliZ(0)) # Return a measurement

# Define a valid basis set
basis_set = ["RX"]

# Apply the compile function with stop_at and a basis set
transformed_qfunc = compile(qfunc, basis_set=basis_set)

# Ensure the transformation occurred
assert transformed_qfunc is not None

# Create a QNode using the transformed function
dev = qml.device("default.qubit", wires=1)
qnode = qml.QNode(transformed_qfunc, dev)

# Evaluate the QNode with an input
result = qnode(0.5)

# Merged isinstance checks to resolve pylint error
if isinstance(result, (float, qml.numpy.ndarray)):
if isinstance(result, qml.numpy.ndarray):
# If result is a tensor (qml.numpy.ndarray), convert it to a float
result = result.item()
else:
raise ValueError("Result is not a valid numerical value")

# Verify the QNode's operations are within the allowed basis set
allowed_ops = ["RX"]
for op in qnode.qtape.operations:
assert op.name in allowed_ops, f"Operation {op.name} is not in the allowed basis set"

def test_compile_mixed_tape_qfunc_transform(self):
"""Test that we can interchange tape and qfunc transforms."""

Expand Down
Loading