diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index f36a58fcc3..aa2d279f43 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -535,30 +535,59 @@ def local_mul_pow_to_pow_add(fgraph, node):
 @register_stabilize
 @register_specialize
 @register_canonicalize
-@node_rewriter([sub])
+@node_rewriter([add, sub])
 def local_expm1(fgraph, node):
-    """Detect ``exp(a) - 1`` and convert them to ``expm1(a)``."""
-    in1, in2 = node.inputs
-    out = node.outputs[0]
+    """Detect ``exp(a) - 1`` or ``-1 + exp(a)`` and convert them to ``expm1(a)``."""
+    if len(node.inputs) != 2:
+        # TODO: handle more than two inputs in add
+        return None
 
-    if (
-        in1.owner
-        and isinstance(in1.owner.op, Elemwise)
-        and isinstance(in1.owner.op.scalar_op, ps.Exp)
-        and get_underlying_scalar_constant_value(in2, raise_not_constant=False) == 1
-    ):
-        in11 = in1.owner.inputs[0]
-        new_out = expm1(in11)
+    if isinstance(node.op.scalar_op, ps.Sub):
+        exp_x, other_inp = node.inputs
+        if not (
+            exp_x.owner
+            and isinstance(exp_x.owner.op, Elemwise)
+            and isinstance(exp_x.owner.op.scalar_op, ps.Exp)
+            and get_underlying_scalar_constant_value(
+                other_inp, raise_not_constant=False
+            )
+            == 1
+        ):
+            return None
+    else:
+        # Try both orders
+        other_inp, exp_x = node.inputs
+        for i in range(2):
+            if i == 1:
+                other_inp, exp_x = exp_x, other_inp
+            if (
+                exp_x.owner
+                and isinstance(exp_x.owner.op, Elemwise)
+                and isinstance(exp_x.owner.op.scalar_op, ps.Exp)
+                and get_underlying_scalar_constant_value(
+                    other_inp, raise_not_constant=False
+                )
+                == -1
+            ):
+                break
+        else:  # no break
+            return None
 
-        if new_out.type.broadcastable != out.type.broadcastable:
-            new_out = broadcast_arrays(in11, in2)[0]
+    [old_out] = node.outputs
 
-        if new_out.dtype != out.dtype:
-            new_out = cast(new_out, dtype=out.dtype)
+    [x] = exp_x.owner.inputs
+    if x.type.broadcastable != old_out.type.broadcastable:
+        x = broadcast_arrays(x, other_inp)[0]
 
-        if not out.type.is_super(new_out.type):
-            return
-        return [new_out]
+    new_out = expm1(x)
+
+    if new_out.dtype != old_out.dtype:
+        new_out = cast(new_out, dtype=old_out.dtype)
+
+    if not old_out.type.is_super(new_out.type):
+        return None
+
+    return [new_out]
 
 
 @register_specialize
@@ -1824,15 +1853,6 @@ def local_add_neg_to_sub(fgraph, node):
                     new_out = sub(first, pre_neg)
                     return [new_out]
 
-            # Check if it is a negative constant
-            if (
-                isinstance(second, TensorConstant)
-                and second.unique_value is not None
-                and second.unique_value < 0
-            ):
-                new_out = sub(first, np.abs(second.data))
-                return [new_out]
-
 
 
 @register_canonicalize
 @node_rewriter([mul])
@@ -2606,9 +2626,9 @@ def local_greedy_distributor(fgraph, node):
 register_stabilize(local_one_minus_erfc)
 register_specialize(local_one_minus_erfc)
 
-# erfc(-x)-1=>erf(x)
+# -1 + erfc(-x)=>erf(x)
 local_erf_neg_minus_one = PatternNodeRewriter(
-    (sub, (erfc, (neg, "x")), 1),
+    (add, -1, (erfc, (neg, "x"))),
     (erf, "x"),
     allow_multiple_clients=True,
     name="local_erf_neg_minus_one",
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index f2f421c6a5..7156b8fcbf 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -3806,14 +3806,9 @@ def test_local_expm1():
         for n in h.maker.fgraph.toposort()
     )
 
-    # This rewrite works when `local_add_neg_to_sub` specialization rewrite is invoked
-    expect_rewrite = config.mode != "FAST_COMPILE"
-    assert (
-        any(
-            isinstance(n.op, Elemwise) and isinstance(n.op.scalar_op, ps.basic.Expm1)
-            for n in r.maker.fgraph.toposort()
-        )
-        == expect_rewrite
+    assert not any(
+        isinstance(n.op, Elemwise) and isinstance(n.op.scalar_op, ps.basic.Expm1)
+        for n in r.maker.fgraph.toposort()
     )
 
 
@@ -4440,25 +4435,6 @@ def test_local_add_neg_to_sub(first_negative):
     assert np.allclose(f(x_test, y_test), exp)
 
 
-@pytest.mark.parametrize("const_left", (True, False))
-def test_local_add_neg_to_sub_const(const_left):
-    x = vector("x")
-    const = np.full((3, 2), 5.0)
-    out = -const + x if const_left else x + (-const)
-
-    f = function([x], out, mode=Mode("py"))
-
-    nodes = [
-        node.op
-        for node in f.maker.fgraph.toposort()
-        if not isinstance(node.op, DimShuffle | Alloc)
-    ]
-    assert nodes == [pt.sub]
-
-    x_test = np.array([3, 4], dtype=config.floatX)
-    assert np.allclose(f(x_test), x_test + (-const))
-
-
 def test_log1mexp_stabilization():
     mode = Mode("py").including("stabilize")
 