Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[better_errors] Refactor more uses of pe.tracing_debug_info (part 3) #26100

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

gnecula
Copy link
Collaborator

@gnecula gnecula commented Jan 25, 2025

We replace uses of pe.tracing_debug_info with with api_util.tracing_debug_info,
which uses the actual args and kwargs, instead of in_tree to manufacture fake
args and kwargs. This ends up being more accurate, especially for arg_names;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in WrappedFun and
Jaxpr.

This is part 3 of a series (following #26097, #26099) for jit, checkify, custom_dce.

@gnecula gnecula self-assigned this Jan 25, 2025
… (part 1)

We replace those uses with api_util.tracing_debug_info, which means we
have to move the call further upstream. But this is better because we
have the actual args and kwargs, and we can do a better job, especially
for `arg_names`.

This is part 1 of a series, for: cond, switch, while, scan, composite,
custom_dce, saved_residuals.
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.

This is part 2 of a series (following jax-ml#26097) for Pallas.
@gnecula gnecula force-pushed the debug_info_no_pe_debug_info_3 branch from 33ff715 to 4c3bbcf Compare January 26, 2025 08:26
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.

This is part 3 of a series (following jax-ml#26097, jax-ml#26099) for jit, pmap, checkify, custom_dce.
@gnecula gnecula force-pushed the debug_info_no_pe_debug_info_3 branch from 4c3bbcf to 75584c3 Compare January 26, 2025 09:07
@gnecula gnecula changed the title [better_errors] Refactor more uses of partial_eval.tracing_debug_info (part 3) [better_errors] Refactor more uses of pe.tracing_debug_info (part 3) Jan 26, 2025
@gnecula gnecula added the pull ready Ready for copybara import and testing label Jan 26, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant