
Deprecate redundant utilities for extracting constants #1046

Merged
merged 7 commits from fix_rewrite_bug into pymc-devs:main on Jan 13, 2025

Conversation

ricardoV94
Member

@ricardoV94 ricardoV94 commented Oct 21, 2024

Description

In PyTensor we have the following similar utilities:

  1. pytensor.get_underlying_scalar_constant
  2. pytensor.tensor.basic.get_underlying_scalar_constant_value
  3. pytensor.tensor.basic.get_scalar_constant_value
  4. pytensor.tensor.basic.extract_constant
  5. pytensor.tensor.rewriting.math.get_constant

This PR removes or deprecates all of them except:

  1. pytensor.tensor.basic.get_underlying_scalar_constant_value
  2. pytensor.tensor.basic.get_scalar_constant_value

The reason for this distinction is that the core utility, get_underlying_scalar_constant_value, actually works for non-scalar inputs: it succeeds whenever it can find a single scalar value that underlies a potentially n-dimensional tensor (say, as in pt.zeros(5, 3, 2)). This is powerful, but can lead to subtle bugs when the caller forgets about it. This was the source of the bug behind #584, and it was likely also present in other graphs such as gt(x, [0, 0, 0, 0]) and the like, where the repeated condition broadcasts the output of the operation.

The utility get_scalar_constant_value raises if the input is not a scalar (ndim=0) type.
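To make the distinction concrete, here is a pure-NumPy sketch of the two behaviors. This models only the semantics described above, not the real implementations (which traverse the symbolic graph); the names mirror the two utilities kept by this PR:

```python
import numpy as np


class NotScalarConstantError(ValueError):
    """Raised when no underlying scalar constant can be found."""


def get_underlying_scalar_constant_value(x):
    # Succeeds whenever a single scalar value underlies the (possibly
    # n-dimensional) constant, e.g. an array filled with one value.
    flat = np.asarray(x).ravel()
    if flat.size and bool((flat == flat[0]).all()):
        return flat[0]
    raise NotScalarConstantError(x)


def get_scalar_constant_value(x):
    # Stricter variant: refuses anything that is not ndim=0.
    if np.asarray(x).ndim != 0:
        raise NotScalarConstantError("input is not a scalar (ndim=0) type")
    return get_underlying_scalar_constant_value(x)
```

In this sketch, `get_underlying_scalar_constant_value(np.zeros((5, 3, 2)))` returns `0.0`, while `get_scalar_constant_value` raises on the same input because it is not ndim=0.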

I don't love that `underlying` is what distinguishes the two names. Perhaps `unique` would be better.

Both utilities now accept a `raise_not_constant` keyword which, when False (not the default), returns the variable as is. I think I would prefer for it to return None, but this requires fewer code changes.
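A minimal sketch of the new flag's contract. This is a toy model, not the PyTensor code: `_try_extract` is a hypothetical stand-in for the real constant extraction, which here only treats plain Python numbers as scalar constants.

```python
class NotScalarConstantError(ValueError):
    """Raised when the input is not a scalar constant."""


def _try_extract(v):
    # Hypothetical stand-in for the real extraction logic.
    if isinstance(v, (int, float)):
        return v
    raise NotScalarConstantError(v)


def get_scalar_constant_value(v, raise_not_constant=True):
    # raise_not_constant=False returns the variable as is, per the PR
    # description (the author notes returning None might be preferable).
    try:
        return _try_extract(v)
    except NotScalarConstantError:
        if raise_not_constant:
            raise
        return v
```

So `get_scalar_constant_value("x", raise_not_constant=False)` hands back `"x"` unchanged instead of raising.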

Related Issue

@ricardoV94 ricardoV94 changed the title Remove repeated utilities for extracting constants Deprecate redundant utilities for extracting constants Oct 21, 2024
@ricardoV94 ricardoV94 force-pushed the fix_rewrite_bug branch 7 times, most recently from 24716ac to 8dcbc2b Compare November 29, 2024 11:04

codecov bot commented Nov 29, 2024

Codecov Report

Attention: Patch coverage is 86.66667% with 36 lines in your changes missing coverage. Please review.

Project coverage is 82.11%. Comparing base (8cc489b) to head (628c321).
Report is 10 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/tensor/rewriting/math.py 75.25% 11 Missing and 13 partials ⚠️
pytensor/tensor/basic.py 88.00% 6 Missing ⚠️
pytensor/tensor/variable.py 85.71% 3 Missing ⚠️
pytensor/scan/rewriting.py 90.90% 0 Missing and 1 partial ⚠️
pytensor/tensor/rewriting/subtensor.py 80.00% 0 Missing and 1 partial ⚠️
pytensor/tensor/shape.py 83.33% 1 Missing ⚠️
Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1046      +/-   ##
==========================================
- Coverage   82.14%   82.11%   -0.03%     
==========================================
  Files         186      186              
  Lines       48207    48201       -6     
  Branches     8677     8679       +2     
==========================================
- Hits        39598    39579      -19     
- Misses       6441     6447       +6     
- Partials     2168     2175       +7     
Files with missing lines Coverage Δ
pytensor/gradient.py 77.62% <100.00%> (+0.05%) ⬆️
pytensor/link/jax/dispatch/tensor_basic.py 98.21% <100.00%> (ø)
pytensor/scan/basic.py 84.56% <100.00%> (ø)
pytensor/sparse/basic.py 82.57% <100.00%> (+0.02%) ⬆️
pytensor/tensor/conv/abstract_conv.py 76.05% <100.00%> (ø)
pytensor/tensor/extra_ops.py 88.01% <100.00%> (ø)
pytensor/tensor/rewriting/basic.py 94.24% <100.00%> (+0.09%) ⬆️
pytensor/tensor/rewriting/elemwise.py 91.13% <100.00%> (-0.02%) ⬇️
pytensor/tensor/rewriting/shape.py 82.02% <100.00%> (ø)
pytensor/tensor/slinalg.py 93.52% <100.00%> (+0.04%) ⬆️
... and 7 more

... and 1 file with indirect coverage changes

@ricardoV94 ricardoV94 marked this pull request as ready for review November 29, 2024 22:05
@ricardoV94 ricardoV94 requested a review from Armavica November 29, 2024 22:05
@ricardoV94
Member Author

@Armavica I removed the negative constant rewrite in the last commit

@ricardoV94 ricardoV94 force-pushed the fix_rewrite_bug branch 2 times, most recently from 7961d57 to 8134f62 Compare December 9, 2024 11:45
@ricardoV94 ricardoV94 force-pushed the fix_rewrite_bug branch 3 times, most recently from 7c38bed to 15c06ff Compare January 10, 2025 18:37
Member

@jessegrabowski jessegrabowski left a comment


Left some ticky-tacky comments, feel free to ignore.

One thing I noticed was that get_underlying_scalar_constant_value is available in the pt namespace. We should maybe curate that namespace a bit more, because this isn't a function an average user is going to need; it can be used from ptb directly.

if v.owner is not None and isinstance(v.owner.op, sparse.CSM):
data = v.owner.inputs[0]
return tensor.get_underlying_scalar_constant_value(data)
warnings.warn(
Member

Use the logger instead?

Member Author

wdym? I've never seen deprecation warnings go through logging.

@@ -2157,6 +2157,9 @@ def _is_zero(x):
'maybe' means that x is an expression that is complicated enough
that we can't tell that it simplifies to 0.
"""
from pytensor.tensor import get_underlying_scalar_constant_value
Member

Why does this need a local import here but not in the above function?

Member Author

Because the other is using it through tensor. This was not done on purpose, but I prefer explicit imports, and there is in fact a circular dependency here.

Parameters
----------
v: Variable
elemwise : bool
Member

check_through_elemwise_ops or something? Just elemwise makes it sound like it's expecting an elemwise input.

Member Author

These were all kwargs that already existed. If I change the kwarg names it will be a breaking change, which I don't think this PR should be.

Member Author

I guess I could keep the wrapper with the old kwargs and give the inner function better ones, but I'm not sure whether that helps or just makes things more confusing.

If False, we won't try to go into elemwise. So this call is faster.
But we still investigate in Second Elemwise (as this is a substitute
for Alloc)
only_process_constants : bool
Member

shallow_search? Why is "constants" plural?

Member Author

Same here

underlying constant scalar value. If False, return `v` as is.


Raises
Member

numpy docs suggest having a Raises section only if it's not obvious; I don't think it's needed here.

Member Author

This is needed. It's actually how you use the function most of the time: you need to know what to try/except.
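The try/except pattern the Raises section documents might look like this at a call site. This is a sketch, not PyTensor code: `simplify_mul_by_one` is a hypothetical rewrite, and the toy extractor stands in for the real utility.

```python
class NotScalarConstantError(ValueError):
    pass


def get_scalar_constant_value(v):
    # Toy extractor standing in for the real utility.
    if isinstance(v, (int, float)):
        return float(v)
    raise NotScalarConstantError(v)


def simplify_mul_by_one(x, c):
    # Typical rewrite-style caller: bail out gracefully (rather than
    # crash) when the candidate is not a known scalar constant.
    try:
        value = get_scalar_constant_value(c)
    except NotScalarConstantError:
        return None  # rewrite does not apply
    return x if value == 1.0 else None
```

Without the documented exception type, callers could not write this fallback, which is why the Raises section carries real information here.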

return x
warnings.warn(
"extract_constant is deprecated. Use `get_underlying_scalar_constant_value(..., raise_not_constant=False)`",
FutureWarning,
Member

DeprecationWarning ?

Member Author

DeprecationWarning is more invisible to users, so I always use FutureWarning
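The visibility difference can be sketched with the `warnings` machinery. The ignore filter below mimics Python's default treatment of DeprecationWarning outside `__main__` (this re-creates the defaults in an isolated context rather than measuring them directly):

```python
import warnings


def is_shown_by_default(category):
    # Mimic Python's default filters: outside __main__,
    # DeprecationWarning is ignored, while FutureWarning is always
    # displayed to users -- the rationale given above.
    with warnings.catch_warnings(record=True) as caught:
        warnings.resetwarnings()
        warnings.filterwarnings("ignore", category=DeprecationWarning)
        warnings.warn("deprecated thing", category)
        return len(caught) > 0
```

Under these mimicked defaults, `is_shown_by_default(FutureWarning)` is True while `is_shown_by_default(DeprecationWarning)` is False, which is why FutureWarning reaches end users.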

Member Author

I see I used a DeprecationWarning above; I'll make it Future.

if (flat_data == flat_data[0]).all():
return flat_data[0]

warnings.warn("get_unique_constant_value is deprecated.", FutureWarning)
Member

DeprecationWarning ?

1. Use actual Solve Op to infer output dtype as CholSolve outputs a different dtype than basic Solve in Scipy==1.15

2. Tweaked test related to pymc-devs#1152

3. Tweak tolerance
@ricardoV94 ricardoV94 merged commit 581f65a into pymc-devs:main Jan 13, 2025
60 of 62 checks passed
Development

Successfully merging this pull request may close these issues.

local_add_neg_to_sub rewrite gives wrong results with negative constants