Skip to content

Commit

Permalink
[better_errors] Add debug info to more Jaxprs and WrappedFun
Browse files Browse the repository at this point in the history
Here we pass debug info in more places, so that it
ends up in more Jaxprs and Tracers. As a result some of the tests
are showing more complete debug info.

There are three kinds of information in the debug info:

  * the func_src_info: this is the easiest to keep track of
    because all we need is to pass it down. In this chain
    of refactorings, I will prioritize having this everywhere.
  * the arg_names: this is collected from the function signature,
    and it is passed down, but it needs to be adjusted as we add
    and remove arguments. This is used when we generate location
    information in the lowering and when we explain some leaked
    tracers.
   * the result_paths: this is the hardest to keep track of, because
    you can only read it after tracing. This is also the least useful.
    It is used only for locations in the lowering.

To enable progress I will de-prioritize keeping accurate the
arg names and result paths, for now. I relax a safety check in the Jaxpr
constructor that was verifying that arg_names and result_paths
have the proper length. Therefore, I needed to add some checks
where the arg_names and result_paths are used (`safe_arg_names`,
and `safe_result_paths`).
  • Loading branch information
gnecula committed Jan 26, 2025
1 parent 75584c3 commit cfd34b2
Show file tree
Hide file tree
Showing 11 changed files with 165 additions and 86 deletions.
2 changes: 1 addition & 1 deletion jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def _trace_to_jaxpr(fun: Callable,
in_avals: Sequence[core.AbstractValue],
debug: lu.TracingDebugInfo
) -> tuple[core.Jaxpr, Sequence[Any], PyTreeDef]:
flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun), in_tree)
flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun, debug_info=debug), in_tree)
try:
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
except core.ConcretizationTypeError as e:
Expand Down
46 changes: 34 additions & 12 deletions jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,18 @@ def tracing_debug_info(
return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk)


def tracing_debug_info_from_jaxpr(maybe_closed_jaxpr: core.ClosedJaxpr | core.Jaxpr) -> TracingDebugInfo | None:
if isinstance(maybe_closed_jaxpr, core.ClosedJaxpr):
jaxpr = maybe_closed_jaxpr.jaxpr
else:
jaxpr = maybe_closed_jaxpr
jaxpr_dbg = jaxpr._debug_info
if jaxpr_dbg is None: return None
return TracingDebugInfo(jaxpr_dbg.traced_for,
jaxpr_dbg.func_src_info,
jaxpr_dbg.arg_names,
lambda: jaxpr_dbg.result_paths)

def fun_signature(fun: Callable) -> inspect.Signature | None:
try:
return inspect.signature(fun)
Expand Down Expand Up @@ -705,22 +717,32 @@ def result_paths(_fun, _store, *args, **kwargs):
_store.store([keystr(path) for path, _ in generate_key_paths(ans)])
return ans

# TODO(necula): simplify this function, all it needs is to add the trace_debug to the Jaxpr
def jaxpr_debug_info(trace_debug: TracingDebugInfo | None,
result_paths: tuple[str, ...] | None = None) -> core.JaxprDebugInfo | None:
# TODO(necula): re-enable this check
# assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None)
if trace_debug is None:
return None
if result_paths is None:
if trace_debug.result_paths_thunk is not None:
result_paths = tuple(trace_debug.result_paths_thunk()) # type: ignore
else:
# TODO(necula): fix result paths
result_paths = ()
else:
result_paths = tuple(result_paths)
return core.JaxprDebugInfo(
trace_debug.traced_for, trace_debug.func_src_info,
trace_debug.arg_names, result_paths)

def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
trace_debug: TracingDebugInfo | None,
result_paths: tuple[str, ...] | None = None,
) -> core.Jaxpr:
"""Add debug info to jaxpr, given trace-time debug info and result paths."""
if trace_debug is None:
return jaxpr

# assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None)
if result_paths is None:
result_paths = trace_debug.result_paths_thunk() # type: ignore
debug_info = core.JaxprDebugInfo(
trace_debug.traced_for, trace_debug.func_src_info,
trace_debug.arg_names, tuple(result_paths)) # type: ignore
return jaxpr.replace(debug_info=debug_info)
return jaxpr.replace(debug_info=jaxpr_debug_info(trace_debug, result_paths))

def debug_info_final(f: lu.WrappedFun, dbg: TracingDebugInfo | None,
res_paths_thunk: Callable[[], tuple[str, ...]]
Expand Down Expand Up @@ -756,7 +778,7 @@ def register_class_with_attrs(t: type) -> None:
_class_with_attrs: set[type] = set()

# TODO(mattjj): make this function faster
def _check_no_aliased_ref_args(dbg, avals, args):
def _check_no_aliased_ref_args(dbg: TracingDebugInfo | None, avals, args):
assert config.mutable_array_checks.value
refs: dict[int, int] = {}
for i, (a, x) in enumerate(zip(avals, args)):
Expand All @@ -770,7 +792,7 @@ def _check_no_aliased_ref_args(dbg, avals, args):
if dbg else
f"at both flat index {dup_idx} and flat index {i}") from None

def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None:
def _check_no_aliased_closed_over_refs(dbg: TracingDebugInfo | None, consts, args) -> None:
assert config.mutable_array_checks.value
refs: set[int] = {id(core.get_referent(c)) for c in consts
if isinstance(core.get_aval(c), AbstractRef)}
Expand All @@ -781,4 +803,4 @@ def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None:
f"when tracing {dbg.func_src_info} for {dbg.traced_for}, a mutable "
f"array reference of type {a.str_short()} was both closed over and "
f"passed as the argument "
f"{dbg.arg_names[i]}" if dbg else "at flat index {i}")
f"{dbg.safe_arg_names(len(args))[i]}" if dbg else "at flat index {i}")
31 changes: 28 additions & 3 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,30 @@ class JaxprDebugInfo(NamedTuple):
# This is formed after tracing, when we have concrete `result_paths`
result_paths: tuple[str, ...] # e.g. ('[0]', '[1]', ...)

def safe_arg_names(self, expected: int) -> tuple[str | None, ...]:
"""Get the arg_names with a safety check."""
if len(self.arg_names) == expected:
return self.arg_names
else:
# TODO(necula): this should not happen
return (None,) * expected

def filter_arg_names(self, keep: Sequence[bool]) -> tuple[str | None, ...]:
"""Keep only the arg_names for which `keep` is True."""
return tuple(v for v, b in zip(self.safe_arg_names(len(keep)), keep) if b)

def safe_result_paths(self, expected: int) -> tuple[str, ...]:
"""Get the result_paths with a safety check."""
if len(self.result_paths) == expected:
return self.result_paths
else:
# TODO(necula): this should not happen
return ("",) * expected

def filter_result_paths(self, keep: Sequence[bool]) -> tuple[str, ...]:
"""Keep only the result_paths for which `keep` is True."""
return tuple(v for v, b in zip(self.safe_result_paths(len(keep)), keep) if b)


class Jaxpr:
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns',
Expand Down Expand Up @@ -146,8 +170,9 @@ def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
self._eqns = list(eqns)
self._effects = effects
self._debug_info = debug_info
assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)
# TODO(necula): re-enable these checks
# assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
# assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)

def __str__(self):
return str(self.pretty_print())
Expand Down Expand Up @@ -2327,7 +2352,7 @@ class MapPrimitive(Primitive):
map_primitive = True

def bind_with_trace(self, trace, fun_and_args, params):
fun = fun_and_args[0]
fun: lu.WrappedFun = fun_and_args[0]
args = fun_and_args[1:]
assert len(params['in_axes']) == len(args)
return trace.process_map(self, fun, args, params)
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/custom_partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def __call__(self, *args, **kwargs):
flat_fun, out_tree = api_util.flatten_fun_nokwargs(f_, in_tree)
in_avals = [core.get_aval(x) for x in args_flat]
debug = pe.tracing_debug_info(self.fun, in_tree, out_tree, False,
"custom_partitioning")
"custom_partitioning")
mesh = mesh_lib.thread_resources.env.physical_mesh
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
Expand Down
18 changes: 11 additions & 7 deletions jax/_src/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
add_jaxvals, replace_internal_symbolic_zeros,
replace_rule_output_symbolic_zeros, Zero, zeros_like_aval, SymbolicZero)
from jax._src.ad_util import zeros_like_p, add_jaxvals_p # noqa: F401
from jax._src import api_util
from jax._src.api_util import flatten_fun, flatten_fun_nokwargs
from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal)
from jax._src.dtypes import dtype, float0
Expand Down Expand Up @@ -98,7 +99,7 @@ def linearize_subtrace(_f: Callable, _store, _tag, nzs_in, *primals, **params):
nzs_out = tuple(type(t) is not Zero for t in out_tangents)
out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz)
out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) # type: ignore[assignment]
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, None)
residual_avals = map(get_aval, consts)
if attrs_tracked:
raise NotImplementedError("TODO: attrs")
Expand Down Expand Up @@ -147,7 +148,7 @@ def _linearize_jaxpr(
jaxpr: core.ClosedJaxpr,
nonzeros: tuple[bool, ...]
) -> tuple[core.ClosedJaxpr, int, Sequence[bool], core.ClosedJaxpr]:
dbg = lu.TracingDebugInfo.from_jaxpr(jaxpr)
dbg = api_util.tracing_debug_info_from_jaxpr(jaxpr)
primal_trace = pe.DynamicJaxprTrace(dbg)
tangent_trace = pe.DynamicJaxprTrace(dbg)
lin_trace = LinearizeTrace(primal_trace, tangent_trace)
Expand All @@ -166,16 +167,17 @@ def new_arg(trace, primal_aval, nz):
out_primals, out_tangents = unzip2(map(lin_trace.to_primal_tangent_pair, ans))
del lin_trace, ans, tracers, new_arg

trace_debug_info = api_util.tracing_debug_info_from_jaxpr(jaxpr)
nzs_out = [type(t) is not Zero for t in out_tangents]
out_tangents = tuple(tangent_trace.to_jaxpr_tracer(t)
for (nz, t) in zip(nzs_out, out_tangents) if nz)
tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, trace_debug_info)
tangent_trace.invalidate()
if attrs_tracked:
raise NotImplementedError("TODO: attrs")
residuals_and_primals = (*tangent_consts, *out_primals)
residuals_and_primals = map(primal_trace.to_jaxpr_tracer, residuals_and_primals) # type: ignore[assignment]
primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals)
primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals, trace_debug_info)
primal_trace.invalidate()
num_residuals = len(tangent_consts)
tangent_jaxpr = pe.close_jaxpr(convert_constvars_jaxpr_constvars_at_end(tangent_jaxpr))
Expand Down Expand Up @@ -207,7 +209,7 @@ def direct_linearize(traceable: lu.WrappedFun,
out_nzs = [type(t) is not Zero for t in out_tangents]
out_nz_tangents = [t for t, nz in zip(out_tangents, out_nzs) if nz]
out_nz_tangents = map(tangent_trace.to_jaxpr_tracer, out_nz_tangents) # type: ignore
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents)
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents, traceable.debug_info)
tangent_trace.invalidate()
out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) if nz else
pe.PartialVal.known(zeros_like_aval(t.aval))
Expand Down Expand Up @@ -1019,12 +1021,14 @@ def jvp_jaxpr(jaxpr: core.ClosedJaxpr, nonzeros: Sequence[bool],
def _jvp_jaxpr(jaxpr: core.ClosedJaxpr,
nonzeros: Sequence[bool], instantiate: Sequence[bool]):
assert len(jaxpr.in_avals) == len(nonzeros)
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
debug_info = api_util.tracing_debug_info_from_jaxpr(jaxpr)
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=debug_info)
f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate, transform_stack=False),
nonzeros)
tangent_avals = [aval.to_tangent_aval() for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(
f_jvp, avals_in, debug_info)
return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros()

@lu.transformation_with_aux2
Expand Down
41 changes: 26 additions & 15 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ def trace_to_subjaxpr_nounits(

@lu.transformation2
def trace_to_subjaxpr_nounits2(
f,
f: Callable,
tag: TraceTag,
instantiate: bool | Sequence[bool],
in_pvals: Sequence[PartialVal]):
Expand Down Expand Up @@ -933,7 +933,7 @@ def _partial_eval_jaxpr_nounits(jaxpr: ClosedJaxpr,
in_unknowns: Sequence[bool],
instantiate: bool | Sequence[bool]):
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr),
debug_info=lu.TracingDebugInfo.from_jaxpr(jaxpr))
debug_info=api_util.tracing_debug_info_from_jaxpr(jaxpr))

cell = []
def fun(*known_vals_in):
Expand Down Expand Up @@ -1335,10 +1335,13 @@ def prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: Sequence[bool]) -> Jaxpr:

def _prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: tuple[bool, ...]) -> Jaxpr:
outvars = [v for v, b in zip(jaxpr.outvars, used_outputs) if b]
dbg = jaxpr.debug_info and core.JaxprDebugInfo(
jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
jaxpr.debug_info.arg_names,
tuple(v for v, b in zip(jaxpr.debug_info.result_paths, used_outputs) if b))
if jaxpr.debug_info:
dbg = jaxpr.debug_info and core.JaxprDebugInfo(
jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
jaxpr.debug_info.arg_names,
jaxpr.debug_info.filter_result_paths(used_outputs))
else:
dbg = None
new_jaxpr = jaxpr.replace(outvars=outvars, debug_info=dbg)
config.enable_checks.value and core.check_jaxpr(new_jaxpr)
return new_jaxpr
Expand Down Expand Up @@ -1425,8 +1428,8 @@ def write(x: Atom, b: bool) -> None:

dbg = jaxpr.debug_info and core.JaxprDebugInfo(
jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
tuple(v for v, b in zip(jaxpr.debug_info.arg_names, used_inputs) if b),
tuple(v for v, b in zip(jaxpr.debug_info.result_paths, used_outputs) if b))
jaxpr.debug_info.filter_arg_names(used_inputs),
jaxpr.debug_info.filter_result_paths(used_outputs))
new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects, dbg)
config.enable_checks.value and core.check_jaxpr(new_jaxpr)

Expand Down Expand Up @@ -1643,8 +1646,11 @@ def __init__(self, debug_info: lu.TracingDebugInfo | None):
def add_eqn(self, eqn: core.JaxprEqn):
self.eqns.append(eqn)

def to_jaxpr(self, trace: DynamicJaxprTrace, out_tracers: Sequence[Tracer]
def to_jaxpr(self, trace: DynamicJaxprTrace,
out_tracers: Sequence[Tracer],
debug_info: api_util.TracingDebugInfo | None,
) -> tuple[Jaxpr, list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
jaxpr_debug_info = api_util.jaxpr_debug_info(debug_info)
# It's not necessary, but we keep the tracer-to-var mapping injective:
assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
invars = self.attrs_vars + self.invars
Expand All @@ -1656,7 +1662,7 @@ def to_jaxpr(self, trace: DynamicJaxprTrace, out_tracers: Sequence[Tracer]
outvars = state_outvars + explicit_outvars
constvars, constvals = unzip2(self.constvar_to_val.items())
jaxpr_effects = make_jaxpr_effects(constvars, self.invars, explicit_outvars, self.eqns)
jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects)
jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects, jaxpr_debug_info)
jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
jaxpr, constvals = _inline_literals(jaxpr, constvals) # type: ignore
init_trees = [tree_structure(init_val) for init_val in self.attrs_inits]
Expand Down Expand Up @@ -2074,8 +2080,9 @@ def transpose_jaxpr_thunk():
self.frame.add_eqn(eqn)
return out_tracers

def to_jaxpr(self, out_tracers: Sequence[Tracer]):
return self.frame.to_jaxpr(self, out_tracers)
def to_jaxpr(self, out_tracers: Sequence[Tracer],
debug_info: api_util.TracingDebugInfo | None):
return self.frame.to_jaxpr(self, out_tracers, debug_info)


custom_staging_rules: dict[Primitive, Callable] = {}
Expand Down Expand Up @@ -2133,7 +2140,11 @@ def tracing_debug_info(
args, kwargs = dummy_args if has_kwargs else (dummy_args, {}) # type: ignore
def res_paths_thunk() -> tuple[str, ...]:
out_tree = out_tree_thunk()
dummy_result = tree_unflatten(out_tree, [False] * out_tree.num_leaves)
try:
dummy_result = tree_unflatten(out_tree, [False] * out_tree.num_leaves)
except:
# TODO(necula): remove this catch-all. Repro in batching_test:test_basic_jit
dummy_result = 0
return tuple(tree_util.keystr(path)
for path, _ in tree_util.generate_key_paths(dummy_result))
return api_util.tracing_debug_info(traced_for, fn, args, kwargs,
Expand Down Expand Up @@ -2168,7 +2179,7 @@ def trace_to_jaxpr_dynamic(

out_tracers = map(trace.to_jaxpr_tracer, ans)
_check_no_returned_refs(debug_info, out_tracers)
jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers)
jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers, debug_info)
del trace, fun, in_tracers, out_tracers, ans

config.enable_checks.value and core.check_jaxpr(jaxpr)
Expand Down Expand Up @@ -2198,7 +2209,7 @@ def _check_no_returned_refs(
origin_info = ('\n\nThe returned mutable array was created on line '
f'{source_info_util.summarize(eqn.source_info)}.')
elif v in frame.invars:
arg_name = dbg.arg_names[frame.invars.index(v)] # type: ignore
arg_name = dbg.safe_arg_names(len(frame.invars))[frame.invars.index(v)] # type: ignore
origin_info = ('\n\nThe returned mutable array was passed in as the '
f'argument {arg_name}.')
else:
Expand Down
6 changes: 3 additions & 3 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,8 +881,8 @@ def lower_parallel_callable(
replicated_args=replicated_args,
arg_shardings=None,
result_shardings=None,
arg_names=jaxpr._debug_info and jaxpr._debug_info.arg_names,
result_names=jaxpr._debug_info and jaxpr._debug_info.result_paths,
arg_names=jaxpr._debug_info and jaxpr._debug_info.safe_arg_names(len(jaxpr.invars)),
result_names=jaxpr._debug_info and jaxpr._debug_info.safe_result_paths(len(jaxpr.outvars)),
num_replicas=replicas.num_global_replicas,
lowering_parameters=lowering_parameters)
return PmapComputation(lowering_result.module,
Expand Down Expand Up @@ -3161,7 +3161,7 @@ def check_arg_avals_for_call(ref_avals, arg_avals,
f"but called with {len(arg_avals)}")

if jaxpr_debug_info is not None:
arg_names = [f"'{name}'" for name in jaxpr_debug_info.arg_names]
arg_names = [f"'{name}'" for name in jaxpr_debug_info.safe_arg_names(len(ref_avals))]
else:
num_args = len(ref_avals)
arg_names = [f"{i + 1}/{num_args}" for i in range(num_args)]
Expand Down
Loading

0 comments on commit cfd34b2

Please sign in to comment.