Skip to content

Commit

Permalink
Don't apply local_add_neg_to_sub rewrite if negative variable is a c…
Browse files Browse the repository at this point in the history
…onstant
  • Loading branch information
ricardoV94 committed Jan 10, 2025
1 parent a7e08e5 commit 4660cd7
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 57 deletions.
80 changes: 50 additions & 30 deletions pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,30 +535,59 @@ def local_mul_pow_to_pow_add(fgraph, node):
@register_stabilize
@register_specialize
@register_canonicalize
@node_rewriter([sub])
@node_rewriter([add, sub])
def local_expm1(fgraph, node):
"""Detect ``exp(a) - 1`` and convert them to ``expm1(a)``."""
in1, in2 = node.inputs
out = node.outputs[0]
"""Detect ``exp(a) - 1`` or ``-1 + exp(a)`` and convert them to ``expm1(a)``."""
if len(node.inputs) != 2:
# TODO: handle more than two inputs in add
return None

if (
in1.owner
and isinstance(in1.owner.op, Elemwise)
and isinstance(in1.owner.op.scalar_op, ps.Exp)
and get_underlying_scalar_constant_value(in2, raise_not_constant=False) == 1
):
in11 = in1.owner.inputs[0]
new_out = expm1(in11)
if isinstance(node.op.scalar_op, ps.Sub):
exp_x, other_inp = node.inputs
if not (
exp_x.owner
and isinstance(exp_x.owner.op, Elemwise)
and isinstance(exp_x.owner.op.scalar_op, ps.Exp)
and get_underlying_scalar_constant_value(
other_inp, raise_not_constant=False
)
== 1
):
return None
else:
# Try both orders
other_inp, exp_x = node.inputs
for i in range(2):
if i == 1:
other_inp, exp_x = exp_x, other_inp
if (
exp_x.owner
and isinstance(exp_x.owner.op, Elemwise)
and isinstance(exp_x.owner.op.scalar_op, ps.Exp)
and get_underlying_scalar_constant_value(
other_inp, raise_not_constant=False
)
== -1
):
break
else: # no break
return None

if new_out.type.broadcastable != out.type.broadcastable:
new_out = broadcast_arrays(in11, in2)[0]
[old_out] = node.outputs

if new_out.dtype != out.dtype:
new_out = cast(new_out, dtype=out.dtype)
[x] = exp_x.owner.inputs
if x.type.broadcastable != old_out.type.broadcastable:
x = broadcast_arrays(x, other_inp)[0]

if not out.type.is_super(new_out.type):
return
return [new_out]
new_out = expm1(x)

if new_out.dtype != old_out.dtype:
new_out = cast(new_out, dtype=old_out.dtype)

if not old_out.type.is_super(new_out.type):
return None

return [new_out]


@register_specialize
Expand Down Expand Up @@ -1824,15 +1853,6 @@ def local_add_neg_to_sub(fgraph, node):
new_out = sub(first, pre_neg)
return [new_out]

# Check if it is a negative constant
if (
isinstance(second, TensorConstant)
and second.unique_value is not None
and second.unique_value < 0
):
new_out = sub(first, np.abs(second.data))
return [new_out]


@register_canonicalize
@node_rewriter([mul])
Expand Down Expand Up @@ -2606,9 +2626,9 @@ def local_greedy_distributor(fgraph, node):
register_stabilize(local_one_minus_erfc)
register_specialize(local_one_minus_erfc)

# erfc(-x)-1=>erf(x)
# -1 + erfc(-x)=>erf(x)
local_erf_neg_minus_one = PatternNodeRewriter(
(sub, (erfc, (neg, "x")), 1),
(add, -1, (erfc, (neg, "x"))),
(erf, "x"),
allow_multiple_clients=True,
name="local_erf_neg_minus_one",
Expand Down
30 changes: 3 additions & 27 deletions tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -3806,14 +3806,9 @@ def test_local_expm1():
for n in h.maker.fgraph.toposort()
)

# This rewrite works when `local_add_neg_to_sub` specialization rewrite is invoked
expect_rewrite = config.mode != "FAST_COMPILE"
assert (
any(
isinstance(n.op, Elemwise) and isinstance(n.op.scalar_op, ps.basic.Expm1)
for n in r.maker.fgraph.toposort()
)
== expect_rewrite
assert not any(
isinstance(n.op, Elemwise) and isinstance(n.op.scalar_op, ps.basic.Expm1)
for n in r.maker.fgraph.toposort()
)


Expand Down Expand Up @@ -4440,25 +4435,6 @@ def test_local_add_neg_to_sub(first_negative):
assert np.allclose(f(x_test, y_test), exp)


@pytest.mark.parametrize("const_left", (True, False))
def test_local_add_neg_to_sub_const(const_left):
x = vector("x")
const = np.full((3, 2), 5.0)
out = -const + x if const_left else x + (-const)

f = function([x], out, mode=Mode("py"))

nodes = [
node.op
for node in f.maker.fgraph.toposort()
if not isinstance(node.op, DimShuffle | Alloc)
]
assert nodes == [pt.sub]

x_test = np.array([3, 4], dtype=config.floatX)
assert np.allclose(f(x_test), x_test + (-const))


def test_log1mexp_stabilization():
mode = Mode("py").including("stabilize")

Expand Down

0 comments on commit 4660cd7

Please sign in to comment.