Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[better_errors] Add debug info to more Jaxprs and Wrappedfun #26078

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

gnecula
Copy link
Collaborator

@gnecula gnecula commented Jan 24, 2025

Here we pass debug info in more places, so that it ends up in more Jaxprs and WrappedFun. As a result some of the tests are showing more complete debug info.

There are three kinds of information in the debug info:

  • the func_src_info: this is the easiest to keep track of because all we need is to pass it down. In this long chain of refactorings, I will prioritize having this everywhere.
  • the arg_names: this is collected from the function signature, and it is passed down, but it needs to be adjusted as we add and remove arguments. This is used when we generate location information in the lowering and when we explain some leaked tracers.
  • the result_paths: this is the hardest to keep track of, because you can only read it after tracing. This is also the least useful. It is used only for locations in the lowering.

To enable progress I will de-prioritize keeping accurate the arg names and result paths, for now. I relax a safety check in the Jaxpr constructor that was verifying that arg_names and result_paths have the proper length. Therefore, I needed to add some checks where the arg_names and result_paths are used (safe_arg_names, and safe_result_paths).

@gnecula gnecula self-assigned this Jan 24, 2025
@gnecula gnecula added the pull ready Ready for copybara import and testing label Jan 24, 2025
@gnecula gnecula marked this pull request as draft January 25, 2025 14:57
… (part 1)

We replace those uses with api_util.tracing_debug_info, which means we
have to move the call further upstream. But this is better because we
have the actual args and kwargs, and we can do a better job, especially
for `arg_names`.

This is part 1 of a series, for: cond, switch, while, scan, composite,
custom_dce, saved_residuals.
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.

This is part 2 of a series (following jax-ml#26097) for Pallas.
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.

This is part 3 of a series (following jax-ml#26097, jax-ml#26099) for jit, pmap, checkify, custom_dce.
Here we pass debug info in more places, so that it
ends up in more Jaxprs and Tracers. As a result some of the tests
are showing more complete debug info.

There are three kinds of information in the debug info:

  * the func_src_info: this is the easiest to keep track of
    because all we need is to pass it down. In this chain
    of refactorings, I will prioritize having this everywhere.
  * the arg_names: this is collected from the function signature,
    and it is passed down, but it needs to be adjusted as we add
    and remove arguments. This is used when we generate location
    information in the lowering and when we explain some leaked
    tracers.
   * the result_paths: this is the hardest to keep track of, because
    you can only read it after tracing. This is also the least useful.
    It is used only for locations in the lowering.

To enable progress I will de-prioritize keeping accurate the
arg names and result paths, for now. I relax a safety check in the Jaxpr
constructor that was verifying that arg_names and result_paths
have the proper length. Therefore, I needed to add some checks
where the arg_names and result_paths are used (`safe_arg_names`,
and `safe_result_paths`).
@gnecula gnecula changed the title [better_errors] Add debug info to more Jaxprs [better_errors] Add debug info to more Jaxprs and Wrappedfun Jan 26, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant