CPU profiling (not tracing) #24349

Open
joaospinto opened this issue Oct 16, 2024 · 4 comments
Labels
enhancement New feature or request

Comments

@joaospinto
Contributor

I want to get some flamegraphs from some JIT'd JAX CPU code to understand where time is being spent (in terms of my user-defined functions).

My understanding (based on the docs) is that currently the recommended approach is to add some custom tracing events and run JAX's tracing feature.

This seems rather suboptimal. Is there a better way?

Related discussion: #19888

@joaospinto joaospinto added the enhancement New feature or request label Oct 16, 2024
@justinjfu
Collaborator

One difficulty with this feature request is that after a function is compiled into HLO, information about the original Python function boundaries is lost, so we would not be able to automatically generate a profile that contains information about user-defined functions. You can look at the compiled code yourself by running jax.jit(f).lower(*args).compiler_ir('hlo').
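As a rough illustration of inspecting the lowered program, here is a minimal sketch. The functions `inner` and `f` are hypothetical, and it uses `Lowered.as_text()`, which prints the StableHLO module rather than the `compiler_ir('hlo')` form mentioned above, so treat it as an approximation of that workflow.

```python
import jax
import jax.numpy as jnp


def inner(x):
    # Hypothetical user-defined helper. It is traced inline into `f`,
    # so it is not expected to show up as its own function in the IR.
    return jnp.sin(x) * 2.0


def f(x):
    return inner(x) + 1.0


x = jnp.arange(4.0)

# Lower (but do not compile) the jitted function for this input.
lowered = jax.jit(f).lower(x)

# Text form of the lowered module (StableHLO dialect by default).
hlo_text = lowered.as_text()
print(hlo_text)
```

Reading the printed module should make it easier to see which operations survive lowering and how little of the Python call structure remains.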

One workaround for this could be to decorate all of your user functions using jax.named_scope. After this, they should be visible in the trace viewer (https://jax.readthedocs.io/en/latest/profiling.html#tensorboard-profiling). It's not automatic, but it shouldn't be too much of an overhead.
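A minimal sketch of that workaround follows. The function names (`normalize`, `project`, `model`) and the scope names are made up for illustration, and the sketch uses the context-manager form of `jax.named_scope` inside each function body.

```python
import jax
import jax.numpy as jnp


def normalize(x):
    # Ops traced inside this block carry the "normalize" scope name.
    with jax.named_scope("normalize"):
        return (x - x.mean()) / (x.std() + 1e-6)


def project(w, x):
    # Ops traced inside this block carry the "project" scope name.
    with jax.named_scope("project"):
        return w @ x


@jax.jit
def model(w, x):
    return project(w, normalize(x))


w = jnp.ones((3, 4))
x = jnp.arange(4.0)
out = model(w, x)
out.block_until_ready()  # wait for the asynchronous computation to finish
```

To actually see the scopes, the call would presumably be wrapped in `with jax.profiler.trace(log_dir):` (with `log_dir` being any writable directory of your choosing) and the result opened in the trace viewer linked above.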

@joaospinto
Contributor Author

joaospinto commented Oct 17, 2024

> One difficulty with this feature request is that after a function is compiled into HLO, information about the original python function boundaries is lost.

There are several ways of exporting HLO/StableHLO from JAX, and many (certainly the StableHLO MLIR bytecode portable artifacts) do export location information (which maps HLO/StableHLO ops to the Python code that created them).

@joaospinto
Contributor Author

For example, this can be used (although it might not be the most compact representation):

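# Assumption: `ir` is the MLIR module of a lowered function, e.g. from
# jax.jit(f).lower(*args).compiler_ir(dialect="stablehlo").
with open("output.hlo", "w") as f:
  ir.operation.print(
    enable_debug_info=True,   # include source location info for each op
    pretty_debug_info=True,   # print locations in a human-readable form
    use_local_scope=True,
    file=f,
  )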
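For anyone who wants to try this end to end, here is a self-contained sketch. It assumes `ir` comes from `compiler_ir(dialect="stablehlo")` on a lowered function (the original snippet does not say where `ir` is defined), uses a hypothetical function `f`, and writes to an in-memory buffer instead of a file.

```python
import io

import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical user function; any traceable function would do.
    return jnp.sin(x) * 2.0


# Assumption: `ir` is the MLIR module produced by lowering `f`.
ir = jax.jit(f).lower(jnp.arange(4.0)).compiler_ir(dialect="stablehlo")

# Print without location info, for comparison.
plain_buf = io.StringIO()
ir.operation.print(file=plain_buf)
plain = plain_buf.getvalue()

# Print with location info attached to each op.
debug_buf = io.StringIO()
ir.operation.print(
    enable_debug_info=True,
    pretty_debug_info=True,
    use_local_scope=True,
    file=debug_buf,
)
text = debug_buf.getvalue()

print(text)
```

Comparing `plain` with `text` shows the extra per-op location annotations, which is the information a profiler would need to attribute time back to Python source.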

@joaospinto
Contributor Author

Related discussion: #23251
