Merge pull request #26399 from gnecula:debug_info_jaxpr_6
PiperOrigin-RevId: 725563435
Commit: 49eccd6c60
@@ -27,7 +27,7 @@ from jax._src import dtypes
from jax._src.state.types import AbstractRef
from jax._src.tree_util import (
PyTreeDef, tree_flatten, tree_unflatten, tree_map,
treedef_children, generate_key_paths, keystr, broadcast_prefix,
treedef_children, generate_key_paths, broadcast_prefix,
prefix_errors)
from jax._src.tree_util import _replace_nones
from jax._src import linear_util as lu
@@ -595,6 +595,15 @@ def debug_info(
sourceinfo: str | None = None,
signature: inspect.Signature | None = None,
) -> core.DebugInfo:
"""Construct core.DebugInfo for a function given example args and kwargs.

`args` and `kwargs` are example positional and keyword arguments, used with
`inspect.Signature` to get the names of arguments. The arguments that are
considered static for tracing purposes should be included, and designated
using `static_argnums` and `static_argnames`.

See docstring for linear_util.DebugInfo.
"""
if sourceinfo is None:
sourceinfo = fun_sourceinfo(fun)
if signature is None:
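A minimal usage sketch of the constructor documented above (the function and argument names are hypothetical; the positional call shape `debug_info(traced_for, fun, args, kwargs)` matches the call sites and tests later in this commit):

    from jax._src import api_util

    def my_fn(x, y, flag):  # hypothetical user function; `flag` is static
      return x * y

    dbg = api_util.debug_info("jit", my_fn, (1.0, 2.0, True), {},
                              static_argnums=(2,))
    # dbg.arg_names then covers only the traced arguments, e.g. ('x', 'y')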
@@ -671,12 +680,13 @@ def _non_static_arg_names(fn_signature: inspect.Signature | None,
except (ValueError, TypeError):
pass
else:
return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items()
return tuple(f'{name}{lu._clean_keystr_arg_names(path)}'
for name, x in ba.arguments.items()
for path, l in generate_key_paths(x) if l is not static)
args_arg_names = tuple(f'args{keystr(path)}'
args_arg_names = tuple(f'args{lu._clean_keystr_arg_names(path)}'
for path, l in generate_key_paths(args_)
if l is not static)
kwargs_arg_names = tuple(f'kwargs{keystr(path)}'
kwargs_arg_names = tuple(f'kwargs{lu._clean_keystr_arg_names(path)}'
for path, l in generate_key_paths(kwargs_)
if l is not static)
arg_names = args_arg_names + kwargs_arg_names

@@ -833,7 +833,7 @@ def checkify_while_body_jaxpr(
# This checks if the next cond application will error
_ = cond_f(*c_consts, *out)
return out
new_body_f_ = lu.wrap_init(new_body_f)
new_body_f_ = lu.wrap_init(new_body_f, debug_info=body_jaxpr.jaxpr.debug_info)
c_consts_avals = cond_jaxpr.in_avals[:c_consts_num]
jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(new_body_f_, [*c_consts_avals,
*body_jaxpr.in_avals])
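The hunk above shows the change this commit repeats throughout the tree: wherever a callable is wrapped with `lu.wrap_init`, the `DebugInfo` of the jaxpr it was derived from is now threaded through instead of being dropped. In sketch form (variable names are illustrative):

    # before: the wrapped function carries no debug info
    wrapped = lu.wrap_init(fn)
    # after: reuse the DebugInfo of the jaxpr that `fn` re-evaluates
    wrapped = lu.wrap_init(fn, debug_info=closed_jaxpr.jaxpr.debug_info)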
@@ -952,7 +952,8 @@ error_checks[ad_checkpoint.remat_p] = remat_error_check


def shard_map_error_check(
error, enabled_errors, *vals_in, jaxpr, in_names, out_names, **kwargs
error: Error, enabled_errors, *vals_in,
jaxpr: core.Jaxpr, in_names, out_names, **kwargs
):
if (mesh := kwargs.get('mesh')) is None:
raise ValueError('Mesh must be provided for shard_map with checkify.')
@@ -976,7 +977,6 @@ def shard_map_error_check(
)
num_out_error_vals = out_tree.num_leaves - len(out_names)

@lu.wrap_init
def expand_errors_leading_dim(*xs):
outs = core.eval_jaxpr(checked_jaxpr.jaxpr, checked_jaxpr.consts, *xs)
errs, outs = split_list(outs, [num_out_error_vals])
@@ -985,7 +985,9 @@ def shard_map_error_check(

with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
expand_errors_leading_dim, checked_jaxpr.in_avals
lu.wrap_init(expand_errors_leading_dim,
debug_info=checked_jaxpr.jaxpr.debug_info),
checked_jaxpr.in_avals
)
checked_jaxpr = core.ClosedJaxpr(jaxpr, consts)

@@ -993,7 +995,8 @@ def shard_map_error_check(
# Use fully sharded partitioning for out errors.
new_out_names = (*([{0: mesh.axis_names}] * num_out_error_vals), *out_names)
subfun = lu.hashable_partial(
lu.wrap_init(core.eval_jaxpr), checked_jaxpr.jaxpr, checked_jaxpr.consts
lu.wrap_init(core.eval_jaxpr, debug_info=checked_jaxpr.jaxpr.debug_info),
checked_jaxpr.jaxpr, checked_jaxpr.consts
)
new_params = dict(
jaxpr=checked_jaxpr.jaxpr,
@@ -1007,8 +1010,10 @@ def shard_map_error_check(
return tree_unflatten(out_tree, err_and_out)
error_checks[shard_map.shard_map_p] = shard_map_error_check

def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,
jvp_jaxpr_thunk, call_jaxpr, **params):
def custom_jvp_call_rule(in_err: Error,
enabled_errors: set, *in_vals, num_consts,
jvp_jaxpr_fun: lu.WrappedFun,
call_jaxpr: core.ClosedJaxpr, **params):
# The types to have in mind are:
# jvp : (a -> b) -> (a, T a) -> (b, T b)
# checkify : (a -> b) -> a -> Err b
@@ -1021,10 +1026,11 @@ def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,
err_vals, err_tree = jtu.tree_flatten(in_err)
partial_checkify = lu.wrap_init(
functools.partial(checkify_jaxpr_flat, call_jaxpr.jaxpr,
call_jaxpr.consts, enabled_errors, err_tree))
call_jaxpr.consts, enabled_errors, err_tree),
debug_info=call_jaxpr.jaxpr.debug_info)
partial_checkify, f_metadata = _flatten_and_get_error_metadata_thunk(
partial_checkify)
jvp = lift_jvp(err_tree.num_leaves, num_consts, jvp_jaxpr_thunk)
jvp = lift_jvp(err_tree.num_leaves, num_consts, jvp_jaxpr_fun)
jvp, jvp_out_tree = flatten_fun_output(jvp)
all_outs = custom_derivatives.custom_jvp_call_p.bind(
partial_checkify, jvp, *err_vals, *in_vals, **params)
@@ -1041,17 +1047,17 @@ error_checks[custom_derivatives.custom_jvp_call_p] = custom_jvp_call_rule

# Compared to custom_derivatives.lift_jvp, we're handling the extra inputs and
# outputs that checkify adds (just forwarding the error data's primal and
# tangent components). The jaxpr in jvp_jaxpr_thunk doesn't expect those.
# tangent components). The jaxpr in jvp_jaxpr_fun doesn't expect those.
# TODO(mattjj): can we simplify this, or dedup with custom_derivatives.lift_jvp?
# Adding another layer of lu.transformation was tricky, though maybe doable.
def lift_jvp(num_errs, num_consts, jvp_jaxpr_thunk):
@lu.wrap_init
def lift_jvp(num_errs: int, num_consts: int,
jvp_jaxpr_fun: lu.WrappedFun) -> lu.WrappedFun:
def jvp(*xs):
n, ragged = divmod(len(xs), 2)
assert not ragged
primals, tangents = xs[num_consts+num_errs:n], xs[n+num_consts+num_errs:]
zeros = [type(t) is SymbolicZero for t in tangents]
jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_thunk(*zeros)
jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_fun.call_wrapped(*zeros)
nonzero_tangents = [t for t in tangents if type(t) is not SymbolicZero]
out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents)
out_primals, nz_out_tangents = split_list(out, [len(out_zeros)])
@@ -1063,7 +1069,7 @@ def lift_jvp(num_errs, num_consts, jvp_jaxpr_thunk):
primal_errs = xs[num_consts:num_consts+num_errs]
tangent_errs = xs[n+num_consts:n+num_consts+num_errs]
return [*primal_errs, *out_primals, *tangent_errs, *out_tangents]
return jvp
return lu.wrap_init(jvp, debug_info=jvp_jaxpr_fun.debug_info)
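Two things change in checkify's `lift_jvp` above: the memoized thunk becomes a `lu.WrappedFun` invoked via `.call_wrapped`, and the returned `jvp` is wrapped with that function's debug info instead of a bare `@lu.wrap_init`. As a worked illustration of the index arithmetic in `jvp` (my own example, not part of the diff): with `num_consts=1`, `num_errs=1`, and two real inputs, `xs` is

    [c, err_primal, x1, x2, c_dot, err_tangent, x1_dot, x2_dot]

so `n = 4`, `primals = xs[2:4]`, `tangents = xs[6:8]`, `primal_errs = xs[1:2]`, and `tangent_errs = xs[5:6]`; the error components are forwarded to the outputs unchanged.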
def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals,
fun_jaxpr: core.ClosedJaxpr,

@@ -378,20 +378,19 @@ class CustomJVPCallPrimitive(core.Primitive):
new_params = dict(params)
call_jaxpr: core.ClosedJaxpr = new_params.pop('call_jaxpr')
num_consts: int = new_params.pop('num_consts')
jvp_jaxpr_thunk = new_params.pop('jvp_jaxpr_thunk')
jvp_jaxpr_fun = new_params.pop('jvp_jaxpr_fun')
fun = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr),
debug_info=call_jaxpr.jaxpr.debug_info)
jvp = lift_jvp(num_consts, jvp_jaxpr_thunk, call_jaxpr.jaxpr.debug_info)
jvp = lift_jvp(num_consts, jvp_jaxpr_fun)
return [fun, jvp], new_params

def lift_jvp(num_consts: int, jvp_jaxpr_thunk: Callable,
debug_info: core.DebugInfo | None) -> lu.WrappedFun:
def lift_jvp(num_consts: int, jvp_jaxpr_fun: lu.WrappedFun) -> lu.WrappedFun:
def jvp(*xs):
n, ragged = divmod(len(xs), 2)
assert not ragged
primals, tangents = xs[num_consts:n], xs[n+num_consts:]
zeros = [type(t) is SymbolicZero for t in tangents]
jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_thunk(*zeros)
jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_fun.call_wrapped(*zeros)
nonzero_tangents = [t for t in tangents if type(t) is not SymbolicZero]
out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents)
out_primals, nz_out_tangents = split_list(out, [len(out_zeros)])
@@ -401,16 +400,16 @@ def lift_jvp(num_consts: int, jvp_jaxpr_thunk: Callable,
for p, z in zip(out_primals, out_zeros)]
assert next(nz_out_tangents_, None) is None
return [*out_primals, *out_tangents]
return lu.wrap_init(jvp, debug_info=debug_info)
return lu.wrap_init(jvp, debug_info=jvp_jaxpr_fun.debug_info)
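The same substitution is applied to the `custom_jvp_call` primitive parameters: `jvp_jaxpr_thunk` (a bare memoized thunk) becomes `jvp_jaxpr_fun` (a `lu.WrappedFun`), so rules call it via `.call_wrapped(*zeros)` and can read its `.debug_info` when re-wrapping functions. A condensed before/after of the recurring call site:

    # before
    jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_thunk(*zeros)
    # after
    jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_fun.call_wrapped(*zeros)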
effects.custom_derivatives_allowed_effects.add_type(lax.InOutFeedEffect)

custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call')

def _custom_jvp_call_typecheck(_, *in_avals, call_jaxpr, jvp_jaxpr_thunk,
def _custom_jvp_call_typecheck(_, *in_avals, call_jaxpr, jvp_jaxpr_fun,
num_consts, symbolic_zeros):
# TODO(mattjj): could do more checking here...
del in_avals, jvp_jaxpr_thunk, num_consts
del in_avals, jvp_jaxpr_fun, num_consts
disallowed_effects = effects.custom_derivatives_allowed_effects.filter_not_in(call_jaxpr.effects)
if disallowed_effects:
raise NotImplementedError(
@@ -418,9 +417,9 @@ def _custom_jvp_call_typecheck(_, *in_avals, call_jaxpr, jvp_jaxpr_thunk,
return call_jaxpr.out_avals, call_jaxpr.effects
core.custom_typechecks[custom_jvp_call_p] = _custom_jvp_call_typecheck

def _custom_jvp_call_mlir_translation(ctx, *args, call_jaxpr, jvp_jaxpr_thunk,
def _custom_jvp_call_mlir_translation(ctx, *args, call_jaxpr, jvp_jaxpr_fun,
num_consts, symbolic_zeros):
del jvp_jaxpr_thunk, num_consts, symbolic_zeros
del jvp_jaxpr_fun, num_consts, symbolic_zeros
consts = mlir._ir_consts(call_jaxpr.consts)
out, tokens = mlir.jaxpr_subcomp(ctx.module_context, call_jaxpr.jaxpr,
ctx.name_stack, ctx.tokens_in, consts,
@@ -452,7 +451,7 @@ def _custom_jvp_call_dce(
return [False] * len(eqn.invars), None

call_jaxpr = eqn.params["call_jaxpr"]
jvp_jaxpr_thunk = eqn.params["jvp_jaxpr_thunk"]
jvp_jaxpr_fun = eqn.params["jvp_jaxpr_fun"]
# We must set instantiate=True because some inputs that are unused by the
# DCE'ed primal might be used in the JVP rule.
dce_call_jaxpr, used_ins = _cached_closed_call_dce_instantiate(
@@ -461,7 +460,7 @@ def _custom_jvp_call_dce(

@pe._memoize
def dce_jvp_jaxpr_thunk(*in_zeros):
jvp_jaxpr, consts, out_zeros = jvp_jaxpr_thunk(*in_zeros)
jvp_jaxpr, consts, out_zeros = jvp_jaxpr_fun.call_wrapped(*in_zeros)
dce_jvp_jaxpr, _ = pe.dce_jaxpr(jvp_jaxpr, [*used_outs, *used_outs], True)
dce_out_zeros = [v for used, v in zip(used_outs, out_zeros) if used]
return dce_jvp_jaxpr, consts, dce_out_zeros
@@ -470,7 +469,8 @@ def _custom_jvp_call_dce(
new_params = dict(
eqn.params,
call_jaxpr=dce_call_jaxpr,
jvp_jaxpr_thunk=dce_jvp_jaxpr_thunk,
jvp_jaxpr_fun=lu.wrap_init(dce_jvp_jaxpr_thunk,
debug_info=jvp_jaxpr_fun.debug_info)
)
new_eqn = pe.new_jaxpr_eqn(
eqn.invars, outvars, eqn.primitive, new_params, dce_call_jaxpr.effects,
@@ -84,9 +84,12 @@ def jvpfun(f: Callable, instantiate, transform_stack, primals, tangents):
return out_primals, out_tangents

@lu.transformation_with_aux2
def linearize_subtrace(_f: Callable, _store, _tag, nzs_in, *primals, **params):
def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag,
nzs_in: Sequence[bool],
debug_info: core.DebugInfo | None,
*primals, **params):
with core.take_current_trace() as parent_trace:
tangent_trace = pe.DynamicJaxprTrace(None) # TODO(necula): fill-in the debug info (use JAX_USE_DIRECT_LINEARIZATION=1)
tangent_trace = pe.DynamicJaxprTrace(debug_info)
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=_tag)
tracers = [LinearizeTracer(linearize_trace, p,
tangent_trace.new_arg(get_aval(p).to_tangent_aval()))
@@ -99,7 +102,7 @@ def linearize_subtrace(_f: Callable, _store, _tag, nzs_in, *primals, **params):
nzs_out = tuple(type(t) is not Zero for t in out_tangents)
out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz)
out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) # type: ignore[assignment]
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, None)
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info)
residual_avals = map(get_aval, consts)
if attrs_tracked:
raise NotImplementedError("TODO: attrs")
@@ -678,11 +681,13 @@ class LinearizeTrace(Trace):
tangent_nzs_out = [type(t) is not Zero for t in tangents_out]
return map(partial(maybe_linearize_tracer, self), primals_out, tangent_nzs_out, tangents_out)

def process_call(self, call_primitive, f, tracers, params):
def process_call(self, call_primitive, f: lu.WrappedFun,
tracers, params):
assert call_primitive.multiple_results
primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers))
nzs_in = tuple(type(t) is not Zero for t in tangents)
f_primal, linearize_outs_thunk = linearize_subtrace(f, self.tag, nzs_in)
f_primal, linearize_outs_thunk = linearize_subtrace(f, self.tag, nzs_in,
f.debug_info)
if isinstance(call_primitive, core.MapPrimitive):
@as_hashable_function(closure=(linearize_outs_thunk))
def new_out_axes_thunk():
@@ -725,7 +730,9 @@ class LinearizeTrace(Trace):

nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz]
nz_tangents_out = call_primitive.bind_with_trace(
self.tangent_trace, (lu.wrap_init(f_tangent), *residuals, *nz_tangents_in), new_params)
self.tangent_trace, (lu.wrap_init(f_tangent,
debug_info=lin_jaxpr.debug_info),
*residuals, *nz_tangents_in), new_params)
nz_tangents_out_iter = iter(nz_tangents_out)
tangents_out = [next(nz_tangents_out_iter) if nz else Zero.from_primal_value(primal)
for nz, primal in zip(nzs_out, primals_out)]
@@ -502,7 +502,7 @@ def _closed_call_param_updater(params, _, __):
return dict(params, call_jaxpr=core.ClosedJaxpr(jaxpr, ()))
call_param_updaters[core.closed_call_p] = _closed_call_param_updater

def abstract_eval_fun(fun, *avals, debug_info=None, **params):
def abstract_eval_fun(fun: Callable, *avals, debug_info=None, **params):
_, avals_out, _, () = trace_to_jaxpr_dynamic(
lu.wrap_init(fun, params, debug_info=debug_info), avals)
assert all(isinstance(aval, AbstractValue) for aval in avals_out)
@@ -1992,7 +1992,9 @@ class DynamicJaxprTrace(core.Trace):
self.frame.add_eqn(eqn)
return out_tracers

def process_custom_jvp_call(self, prim, fun, jvp, tracers, symbolic_zeros):
def process_custom_jvp_call(self, prim, fun: lu.WrappedFun,
jvp: lu.WrappedFun, tracers,
symbolic_zeros: bool):
tracers = map(self.to_jaxpr_tracer, tracers)
in_avals = [t.aval for t in tracers]
in_tangent_avals = [t.to_tangent_aval() for t in in_avals]
@@ -2014,7 +2016,8 @@ class DynamicJaxprTrace(core.Trace):
outvars = map(self.makevar, out_tracers)
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim,
dict(call_jaxpr=closed_fun_jaxpr,
jvp_jaxpr_thunk=jvp_jaxpr_thunk,
jvp_jaxpr_fun=lu.wrap_init(jvp_jaxpr_thunk,
debug_info=jvp.debug_info),
num_consts=len(consts),
symbolic_zeros=symbolic_zeros),
fun_jaxpr.effects,
@@ -1255,7 +1255,9 @@ def _scan_state_partial_discharge_rule(should_discharge, in_avals, out_avals, *a
)
# TODO(cperivol): avoid tracing the jaxpr twice. When doing so don't
# forget to manage the effects.
new_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(wrapped), avals_for_wrapped_no_refs)
new_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(wrapped, debug_info=discharged_jaxpr.debug_info),
avals_for_wrapped_no_refs)
all_out = scan_p.bind(*args_for_wrapped,
jaxpr=core.ClosedJaxpr(new_jaxpr, ()),
length=length,
@@ -1922,9 +1924,9 @@ def _while_partial_discharge_rule(should_discharge, in_avals, out_avals, *args,
carry, refs_out = split_list(carry_refs, [num_carry])
return [*refs_out, *carry]
new_body_jaxpr, _, new_body_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(new_body), [*remaining_body_const_avals, *[a.inner_aval for a
in ref_avals],
*carry_avals])
lu.wrap_init(new_body, debug_info=discharged_body_jaxpr.debug_info),
[*remaining_body_const_avals, *[a.inner_aval for a in ref_avals],
*carry_avals])
if new_body_consts: raise NotImplementedError

# Since some `Ref`s that were previously consts are now carries, we need to
@@ -1936,9 +1938,8 @@ def _while_partial_discharge_rule(should_discharge, in_avals, out_avals, *args,
del refs # We don't use them here!
return core.eval_jaxpr(cond_jaxpr, cond_jaxpr_consts, *consts, *carry)
new_cond_jaxpr, _, new_cond_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(new_cond), [*cond_consts_avals,
*[a.inner_aval for a in ref_avals],
*carry_avals])
lu.wrap_init(new_cond, debug_info=cond_jaxpr.debug_info),
[*cond_consts_avals, *[a.inner_aval for a in ref_avals], *carry_avals])
if new_cond_consts: raise NotImplementedError

out = while_p.bind(*cond_consts, *remaining_body_consts, *refs, *carry,
@@ -159,7 +159,8 @@ def _root_jvp(const_lengths, jaxprs, primals, tangents):
linearize_and_solve = partial(
core.jaxpr_as_fun(jaxprs.l_and_s), *params.l_and_s)
f_at_solution = lambda *params: f(*params, *solution)
_, rhs = ad.jvp(lu.wrap_init(f_at_solution)).call_wrapped(
_, rhs = ad.jvp(lu.wrap_init(f_at_solution,
debug_info=jaxprs.f.jaxpr.debug_info)).call_wrapped(
params.f, params_dot.f)
solution_dot = _map(
operator.neg, linearize_and_solve(*solution, *rhs))
@@ -65,13 +65,14 @@ from __future__ import annotations

from collections.abc import Callable, Sequence
from functools import partial
import re
from typing import Any, NamedTuple
import weakref

from jax._src import config
from jax._src import core
from jax._src import traceback_util
from jax._src.tree_util import keystr, generate_key_paths
from jax._src.tree_util import keystr, KeyPath, generate_key_paths
from jax._src.util import curry, cache_clearing_funs, HashableFunction


@@ -266,13 +267,17 @@ class DebugInfo(NamedTuple):
# The paths of the flattened non-static argnames,
# e.g. ('x', 'dict_arg["a"]', ... ).
# Uses `None` for the args that do not correspond to user-named arguments,
# e.g., tangent args in jax.jvp.
# e.g., tangent args in jax.jvp. At the moment, `arg_names` accuracy is
# best-effort. Use `safe_arg_names` to detect and handle an unexpected
# number of elements in `arg_names`.
arg_names: tuple[str | None, ...]

# The result paths are not available while we are tracing the function,
# instead we keep a thunk. Once we are done tracing, we use
# `self.resolve_result_paths()` to execute the thunk and replace the
# actual result paths.
# actual result paths. At the moment, `result_paths` accuracy is
# best-effort. Use `safe_result_paths` to detect and handle an unexpected
# number of elements in `result_paths`.
# e.g. ('[0]', '[1]', ...)
result_paths: tuple[str, ...] | Callable[[], tuple[str, ...]] | None
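A sketch of the guarded access pattern the comments above point at (the exact fallback behavior on a length mismatch is my reading of "best-effort", not spelled out in this hunk):

    dbg = fun.debug_info
    arg_names = dbg.safe_arg_names(len(in_avals))        # always len(in_avals) entries
    result_paths = dbg.safe_result_paths(len(out_avals)) # always len(out_avals) entries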
@@ -329,10 +334,16 @@ def wrap_init(f: Callable, params=None, *,
return fun


# We replace <flat index 0> with 0
_re_clean_keystr_arg_names = re.compile(r"<flat index ([^>]+)>")
def _clean_keystr_arg_names(k: KeyPath) -> str:
res = keystr(k)
return _re_clean_keystr_arg_names.sub(r"\1", res)
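This helper strips the `<flat index ...>` decoration that `keystr` produces for pytree leaves addressed by flat position, which is also why the expected strings in the debug-info tests near the end of this diff change from `[1][<flat index 0>][0]...` to `[1][0][0]...`. Roughly:

    # keystr(path)                   -> "[1][<flat index 0>][0]"
    # _clean_keystr_arg_names(path)  -> "[1][0][0]"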
@transformation_with_aux2
def _get_result_paths_thunk(_fun: Callable, _store: Store, *args, **kwargs):
ans = _fun(*args, **kwargs)
result_paths = [keystr(path) for path, _ in generate_key_paths(ans)]
result_paths = [_clean_keystr_arg_names(path) for path, _ in generate_key_paths(ans)]
if _store:
# In some instances a lu.WrappedFun is called multiple times, e.g.,
# the bwd function in a custom_vjp
@@ -3044,11 +3044,11 @@ def _custom_jvp_call_lowering_rule(
ctx: LoweringRuleContext,
*args,
call_jaxpr: jax_core.Jaxpr,
jvp_jaxpr_thunk: Callable,
jvp_jaxpr_fun: lu.WrappedFun,
num_consts: int,
symbolic_zeros: bool,
):
del jvp_jaxpr_thunk
del jvp_jaxpr_fun
if symbolic_zeros: raise NotImplementedError
if num_consts: raise NotImplementedError
if call_jaxpr.consts: raise NotImplementedError
@@ -242,7 +242,9 @@ def _batch_block_mapping(

with grid_mapping.trace_env():
block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(_block_map_function), idx_avals)
lu.wrap_init(_block_map_function,
debug_info=block_mapping.index_map_jaxpr.jaxpr.debug_info),
idx_avals)
shape = block_mapping.block_shape
if dim is batching.not_mapped:
new_block_shape = shape
@@ -73,7 +73,7 @@ def hoist_consts_to_refs(
return core.eval_jaxpr(jaxpr, all_consts, *args0, *args1)

hoisted_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(_hoist), in_avals)
lu.wrap_init(_hoist, debug_info=jaxpr.debug_info), in_avals)
assert not consts, "All consts should have been converted to refs"
return hoisted_jaxpr
@@ -14,7 +14,7 @@

from __future__ import annotations

from typing import Any
from typing import Any, Callable

from jax._src import core
from jax._src import source_info_util
@@ -90,7 +90,9 @@ def jvp(f, primals, tangents, attr_tangents):
primals_flat, in_tree = tree_flatten((attr_primals, *primals))
tangents_flat, in_tree_ = tree_flatten((attr_tangents, *tangents))
if in_tree != in_tree_: raise Exception
f_, out_tree = flatten_fun_nokwargs(_set_attrs(lu.wrap_init(f), attrs), in_tree)
dbg = api_util.debug_info("attrs_jvp", f, primals, {})
f_, out_tree = flatten_fun_nokwargs(
_set_attrs(lu.wrap_init(f, debug_info=dbg), attrs), in_tree)
out_primals_flat, out_tangents_flat, tangent_attrs_out = _jvp(f_).call_wrapped(
primals_flat, tangents_flat)
out_primals = tree_unflatten(out_tree(), out_primals_flat)
@@ -151,12 +153,14 @@ ad.JVPTrace.process_getattr = _getattr_jvp
ad.LinearizeTrace.process_setattr = _setattr_jvp
ad.LinearizeTrace.process_getattr = _getattr_jvp

def linearize(f, *primals, attrs: list[tuple[Any, str]] = []):
def linearize(f: Callable, *primals, attrs: list[tuple[Any, str]] = []):
attr_primals = [jax_getattr(o, a) for o, a in attrs]
attr_avals = [core.get_aval(p) for p in attr_primals]
primals_flat, in_tree = tree_flatten(primals)
tree = treedef_tuple((tree_structure(attr_primals), *in_tree.children()))
f_, out_tree = flatten_fun_nokwargs(_set_attrs(lu.wrap_init(f), attrs), tree)
dbg = api_util.debug_info("attrs linearize", f, primals, {})
f_, out_tree = flatten_fun_nokwargs(
_set_attrs(lu.wrap_init(f, debug_info=dbg), attrs), tree)
primal_out, out_pvals, jaxpr, consts, attrs_out = _linearize(
f_, *attr_primals, *primals_flat)
f_lin = _lin_wrap(jaxpr, consts, out_pvals, attr_avals, (in_tree, out_tree()),
@@ -206,7 +210,9 @@ def vjp(f, *primals, attrs: list[tuple[Any, str]] = []):
attr_primals = [jax_getattr(o, a) for o, a in attrs]
primals_flat, in_tree = tree_flatten(primals)
tree = treedef_tuple((tree_structure(attr_primals), *in_tree.children()))
f_, out_tree = flatten_fun_nokwargs(_set_attrs(lu.wrap_init(f), attrs), tree)
dbg = api_util.debug_info("attrs vjp", f, primals, {})
f_, out_tree = flatten_fun_nokwargs(
_set_attrs(lu.wrap_init(f, debug_info=dbg), attrs), tree)
primal_out, out_pvals, jaxpr, consts, attrs_out = _linearize(
f_, *attr_primals, *primals_flat)
attr_avals = [core.get_aval(jax_getattr(o, a)).to_tangent_aval()
@@ -599,6 +599,8 @@ class GraphSerializationImpl(SerializationImpl):
name_stack = util.wrap_name(fun_name, "jax2tf")
self.name_stack = name_stack
self.args_flat_tf = args_flat_tf
self.debug = api_util.debug_info("jax2tf", fun_jax,
args_specs, kwargs_specs)

def before_conversion(self):
prev_enable_xla = _thread_local_state.enable_xla
@@ -623,7 +625,10 @@ class GraphSerializationImpl(SerializationImpl):
dim_values, _ = _interpret_fun_jax(
partial(shape_poly.compute_dim_vars_from_arg_shapes,
self.args_avals_flat, args_kwargs_tree=self.in_tree),
self.args_flat_tf, self.args_avals_flat, self.name_stack)
self.args_flat_tf, self.args_avals_flat, self.name_stack,
debug_info=api_util.debug_info("jax2tf dim_vars",
shape_poly.compute_dim_vars_from_arg_shapes,
self.args_specs, self.kwargs_specs))

_thread_local_state.shape_env = zip(dim_vars, dim_values)

@@ -639,7 +644,8 @@ class GraphSerializationImpl(SerializationImpl):
fun_flat_jax,
args_flat_tf, self.args_avals_flat,
self.name_stack,
fresh_constant_cache=True)
fresh_constant_cache=True,
debug_info=self.debug)
return outs_tf, self.outs_avals, out_tree_thunk()

def get_vjp_fun(self) -> tuple[Callable,
@@ -849,10 +855,12 @@ def _interpret_fun_jax(
fun_jax: Callable,
args_tf: Sequence[TfVal],
args_avals: Sequence[core.ShapedArray],
extra_name_stack: str | None,
extra_name_stack: str | None, *,
fresh_constant_cache: bool = False,
debug_info: core.DebugInfo,
) -> tuple[tuple[TfVal, ...], tuple[core.ShapedArray, ...]]:
subtrace_fun = _interpret_subtrace(lu.wrap_init(fun_jax), args_avals)
subtrace_fun = _interpret_subtrace(
lu.wrap_init(fun_jax, debug_info=debug_info), args_avals)
with _extended_name_stack(extra_name_stack):
out_vals: Sequence[tuple[TfVal, core.ShapedArray]] = \
_call_wrapped_with_new_constant_cache(subtrace_fun, args_tf,
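`_interpret_fun_jax` now takes `debug_info` as a required keyword-only argument (`fresh_constant_cache` becomes keyword-only as well), so each jax2tf call site constructs one explicitly. A condensed, hypothetical call under the new signature:

    outs_tf, out_avals = _interpret_fun_jax(
        fun_jax, args_tf, args_avals, "my_name_stack",
        debug_info=api_util.debug_info("jax2tf", fun_jax, args_specs, kwargs_specs))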
@@ -1033,7 +1041,9 @@ def _convert_jax_impl(impl_jax: Callable, *,

results_tf, _ = _interpret_fun_jax(
impl_multiple_results_jax, args_tf, _in_avals,
extra_name_stack)
extra_name_stack,
debug_info=api_util.debug_info("jax2tf", impl_jax,
args_tf, kwargs))
return results_tf if multiple_results else results_tf[0]

return wrapped_tf
@@ -1066,7 +1076,8 @@ def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args_tf: TfVal,
"""
outs_tf, _ = _interpret_fun_jax(core.jaxpr_as_fun(jaxpr),
args_tf, jaxpr.in_avals, extra_name_stack,
fresh_constant_cache=fresh_constant_cache)
fresh_constant_cache=fresh_constant_cache,
debug_info=jaxpr.jaxpr.debug_info)
return outs_tf


@@ -1197,7 +1208,9 @@ def _eval_shape(shape: Sequence[shape_poly.DimSize], dtype=None) -> Sequence[TfV
dim_vars, dim_values = util.unzip2(_thread_local_state.shape_env)
shape_values_tf, _ = _interpret_fun_jax(
partial(core.evaluate_shape, shape, dim_vars),
dim_values, [core.dim_value_aval()] * len(dim_values), "") # type: ignore
dim_values, [core.dim_value_aval()] * len(dim_values), "", # type: ignore
debug_info=api_util.debug_info("jax2tf evaluate_shape", core.evaluate_shape,
(0, 0, *dim_values), {}))
# Keep only the non-constant dimensions
return tuple(operator.index(d) if core.is_constant_dim(d) else d_tf # type: ignore
for d, d_tf in zip(shape, shape_values_tf))
@@ -3431,10 +3444,10 @@ def _tridiagonal_solve(*args: TfVal, _in_avals, _out_aval, **params):
tf_impl_with_avals[lax.linalg.tridiagonal_solve_p] = _tridiagonal_solve

def _custom_jvp_call(*args: TfVal, call_jaxpr: core.ClosedJaxpr,
jvp_jaxpr_thunk: Callable,
jvp_jaxpr_fun: Callable,
num_consts: int) -> Sequence[TfVal]:
# TODO(necula): ensure that there is no AD transformation in scope
del jvp_jaxpr_thunk, num_consts
del jvp_jaxpr_fun, num_consts
return _interpret_jaxpr(call_jaxpr, *args, extra_name_stack="custom_jvp",
fresh_constant_cache=False)
@@ -38,6 +38,7 @@ from jax import lax
import jax.numpy as jnp
from jax import random
from jax import tree_util
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import test_util as jtu
@@ -442,7 +443,10 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
partial(shape_poly.compute_dim_vars_from_arg_shapes,
avals,
args_kwargs_tree=tree_util.tree_flatten((avals, {}))[1]),
args_tf, avals, "")
args_tf, avals, "",
debug_info=api_util.debug_info("jax2tf dim_vars",
shape_poly.compute_dim_vars_from_arg_shapes,
avals, {}))
if expected_shapes is not None:
expected_avals = tree_util.tree_map(
lambda shape_str: core.ShapedArray(
@@ -1387,7 +1387,7 @@ def _closed_call_check(mesh, *in_rep, call_jaxpr, **kwargs):


@register_check(custom_derivatives.custom_jvp_call_p)
def _custom_jvp_call_check(mesh, *in_rep, call_jaxpr, jvp_jaxpr_thunk,
def _custom_jvp_call_check(mesh, *in_rep, call_jaxpr, jvp_jaxpr_fun,
num_consts, symbolic_zeros):
return _check_rep(mesh, call_jaxpr.jaxpr, in_rep)

@@ -1578,11 +1578,13 @@ def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p,
return pe.merge_lists(out_knowns, out_tracers, out_consts)
pe.JaxprTrace.process_shard_map = _shard_map_partial_eval

def _shard_map_linearize(trace, shard_map_p, f, tracers, mesh, in_names,
def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun,
tracers, mesh, in_names,
out_names_thunk, check_rep, rewrite, auto):
primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers))
nzs_in = tuple(type(t) is not ad.Zero for t in tangents)
f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in)
f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in,
f.debug_info)
f_primal = _promote_scalar_residuals_lin(f_primal, linearize_outs_thunk)
tangent_in_names = [ax for ax, nz in zip(in_names, nzs_in) if nz]
all_names = _all_mesh_names_except_spmd(mesh, trace)
@@ -1780,23 +1782,24 @@ def _partial_eval_jaxpr_custom_rule(
pe.partial_eval_jaxpr_custom_rules[shard_map_p] = \
_partial_eval_jaxpr_custom_rule

def _add_reshapes(which, jaxpr_known, jaxpr_staged):
def _add_reshapes(which: Sequence[bool],
jaxpr_known: core.Jaxpr,
jaxpr_staged: core.Jaxpr) -> tuple[core.Jaxpr, core.Jaxpr]:
# add singleton axes to residuals which are from jaxpr_known and are scalars
which_ = [w and not v.aval.shape # pytype: disable=attribute-error
for w, v in zip(which, jaxpr_staged.invars[:len(which)])]
if not any(which_): return jaxpr_known, jaxpr_staged
assert not jaxpr_known.constvars and not jaxpr_staged.constvars

@lu.wrap_init
def known(*args):
out = core.eval_jaxpr(jaxpr_known, (), *args)
out_known, res = split_list(out, [len(out) - sum(which)])
res = [_add_singleton(x) if not x.shape else x for x in res]
return [*out_known, *res]
avals_in = [v.aval for v in jaxpr_known.invars]
jaxpr_known, _, (), () = pe.trace_to_jaxpr_dynamic(known, avals_in)
jaxpr_known, _, (), () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(known, debug_info=jaxpr_known.debug_info), avals_in)

@lu.wrap_init
def staged(*args):
res_, ins = split_list(args, [len(which)])
res = [_rem_singleton(x) if w else x for x, w in zip(res_, which_)]
@@ -1804,7 +1807,8 @@ def _add_reshapes(which, jaxpr_known, jaxpr_staged):
res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval
for w, v in zip(which_, jaxpr_staged.invars[:len(which)])]
avals_in = [*res_avals, *[v.aval for v in jaxpr_staged.invars[len(which):]]]
jaxpr_staged, _, (), () = pe.trace_to_jaxpr_dynamic(staged, avals_in)
jaxpr_staged, _, (), () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(staged, debug_info=jaxpr_staged.debug_info), avals_in)

return jaxpr_known, jaxpr_staged

@@ -2070,7 +2074,8 @@ def _replication_rewrite_match(
in_rep: Sequence[set[AxisName]],
out_rep_dst: Sequence[set[AxisName]],
) -> core.ClosedJaxpr:
f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts))
f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts),
debug_info=jaxpr.jaxpr.debug_info)
f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep)
f = _match_rep(f, mesh, out_rep, out_rep_dst)
jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)
@@ -1042,8 +1042,14 @@ def _bcoo_dot_general_sampled_impl(A, B, indices, *, dimension_numbers):

@bcoo_dot_general_sampled_p.def_abstract_eval
def _bcoo_dot_general_sampled_abstract_eval(A, B, indices, *, dimension_numbers):
dense_result, = pe.abstract_eval_fun(lambda *args: [lax.dot_general(*args, dimension_numbers=dimension_numbers)], A, B)
sparse_result, = pe.abstract_eval_fun(lambda *args: [_bcoo_extract(*args)], indices, dense_result)
dbg = api_util.debug_info("bcoo_dot_general_sampled_abstract_eval",
lax.dot_general, (A, B), dict(dimension_numbers=dimension_numbers))
dense_result, = pe.abstract_eval_fun(lambda *args: [lax.dot_general(*args, dimension_numbers=dimension_numbers)], A, B,
debug_info=dbg)
dbg = api_util.debug_info("bcoo_dot_general_sampled_abstract_eval",
_bcoo_extract, (indices, dense_result), {})
sparse_result, = pe.abstract_eval_fun(lambda *args: [_bcoo_extract(*args)], indices, dense_result,
debug_info=dbg)
return sparse_result

def _bcoo_dot_general_sampled_transpose(ct, A, B, indices, *, dimension_numbers):
@@ -55,6 +55,7 @@ import numpy as np

import jax
from jax import lax
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src.custom_derivatives import lift_jvp
@@ -365,12 +366,15 @@ def sparsify_fun(wrapped_fun, args: list[ArrayOrSparse]):
spenv = SparsifyEnv(out_bufs)
return spvalues_to_arrays(spenv, out_spvalues())

def _sparsify_with_tracer(fun):
def _sparsify_with_tracer(fun: Callable):
"""Implementation of sparsify() using tracers."""
@functools.wraps(fun)
def _wrapped(*args):
args_flat, in_tree = tree_flatten(args, is_leaf=_is_sparse_obj)
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
wrapped_fun, out_tree = flatten_fun_nokwargs(
lu.wrap_init(fun,
debug_info=api_util.debug_info("sparsify", fun, args, {})),
in_tree)
out = sparsify_fun(wrapped_fun, args_flat)
return tree_unflatten(out_tree(), out)
return _wrapped
@@ -439,7 +443,12 @@ def sparsify_raw(f):
) -> tuple[Sequence[SparsifyValue], pytree.PyTreeDef]:
spvalues_flat, in_tree = tree_flatten(spvalues, is_leaf=_is_spvalue)
in_avals_flat = spvalues_to_avals(spenv, spvalues_flat)
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, params), in_tree)
wrapped_fun, out_tree = flatten_fun_nokwargs(
lu.wrap_init(
f, params,
debug_info=api_util.debug_info("sparsify", f,
spvalues_to_arrays(spenv, spvalues), {})),
in_tree)
jaxpr, out_avals_flat, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
result = eval_sparse(jaxpr, consts, spvalues_flat, spenv)
if len(out_avals_flat) != len(result):
@@ -716,14 +725,14 @@ def _gather_sparse_rule(spenv, *args, dimension_numbers, slice_sizes, unique_ind

sparse_rules_bcoo[lax.gather_p] = _gather_sparse_rule

def _sparsify_jaxpr(spenv, jaxpr, *spvalues):
def _sparsify_jaxpr(spenv: SparsifyEnv,
jaxpr: core.ClosedJaxpr, *spvalues):
# TODO(jakevdp): currently this approach discards all information about
# shared data & indices when generating the sparsified jaxpr. The
# current approach produces valid sparsified while loops, but they
# don't work in corner cases (see associated TODO in sparsify_test.py)
out_tree: pytree.PyTreeDef | None = None

@lu.wrap_init
def wrapped(*args_flat):
# TODO(frostig,jakevdp): This closes over `spenv`, which can bring
# in buffers from the "outer scope" as constants. Is this a
@@ -740,7 +749,8 @@ def _sparsify_jaxpr(spenv, jaxpr, *spvalues):
args = spvalues_to_arrays(spenv, spvalues)
args_flat, in_tree = tree_flatten(args)
avals_flat = [core.get_aval(arg) for arg in args_flat]
sp_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat)
sp_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(wrapped, debug_info=jaxpr.jaxpr.debug_info), avals_flat)
sp_jaxpr = pe.ClosedJaxpr(sp_jaxpr, consts)
assert out_tree is not None
return sp_jaxpr, out_tree
@@ -866,17 +876,18 @@ sparse_rules_bcsr[sparse.todense_p] = _todense_sparse_rule

def _custom_jvp_sparse_rule(spenv, *spvalues, **params):
call_jaxpr: core.ClosedJaxpr = params.pop('call_jaxpr')
jvp_jaxpr_thunk = params.pop('jvp_jaxpr_thunk')
num_consts = params.pop('num_consts')
jvp_jaxpr_fun: lu.WrappedFun = params.pop('jvp_jaxpr_fun')
num_consts: int = params.pop('num_consts')
sp_call_jaxpr, out_tree = _sparsify_jaxpr(spenv, call_jaxpr, *spvalues)
@lu.wrap_init
def fun(*arrs):
sparrs = arrays_to_spvalues(spenv, arrs)
out = eval_sparse(call_jaxpr.jaxpr, call_jaxpr.consts, sparrs, spenv)
return spvalues_to_arrays(spenv, out)
jvp = lift_jvp(num_consts, jvp_jaxpr_thunk, call_jaxpr.jaxpr.debug_info)
jvp = lift_jvp(num_consts, jvp_jaxpr_fun)
invals = spvalues_to_arrays(spenv, spvalues)
outvals = jax.custom_derivatives.custom_jvp_call_p.bind(fun, jvp, *invals, **params)
outvals = jax.custom_derivatives.custom_jvp_call_p.bind(
lu.wrap_init(fun, debug_info=call_jaxpr.jaxpr.debug_info),
jvp, *invals, **params)
return arrays_to_spvalues(spenv, outvals)

sparse_rules_bcoo[jax.custom_derivatives.custom_jvp_call_p] = _custom_jvp_sparse_rule
@@ -26,9 +26,8 @@ import jax
from jax import lax
from jax import numpy as jnp
from jax import jvp, linearize, vjp, jit, make_jaxpr
from jax.api_util import flatten_fun_nokwargs
from jax.api_util import flatten_fun_nokwargs, debug_info
from jax._src import config

from jax._src import core
from jax._src import linear_util as lu
from jax._src import util
@@ -48,14 +47,16 @@ def call(f, *args):
@util.curry
def core_call(f, *args):
args, in_tree = jax.tree.flatten(args)
f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
dbg = debug_info("core_call_test", f, args, {})
f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, debug_info=dbg), in_tree)
out = core.call_p.bind(f, *args)
return jax.tree.unflatten(out_tree(), out)

@util.curry
def core_closed_call(f, *args):
args, in_tree = jax.tree.flatten(args)
f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
dbg = debug_info("core_closed_call_test", f, args, {})
f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, debug_info=dbg), in_tree)
out = core.closed_call_p.bind(f, *args)
return jax.tree.unflatten(out_tree(), out)

@@ -362,7 +363,10 @@ class CoreTest(jtu.JaxTestCase):

aval = core.ShapedArray((), jnp.dtype('int32'))
pval = pe.PartialVal.unknown(aval)
jaxpr, _, _ = pe.trace_to_jaxpr_nounits(lu.wrap_init(f), [pval], False)
jaxpr, _, _ = pe.trace_to_jaxpr_nounits(
lu.wrap_init(f,
debug_info=debug_info("test", f, (0,), {})),
[pval], False)
dropvar, b = jaxpr.eqns[0].outvars
self.assertEqual(dropvar.aval, aval)

@@ -548,12 +552,13 @@ class DynamicShapesTest(jtu.JaxTestCase):
a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)

@lu.wrap_init
def f(x, y):
return x, y

jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
f, [n, a, b], keep_inputs=[False, True, True])
lu.wrap_init(f,
debug_info=debug_info("test", f, (1, 2), {})),
[n, a, b], keep_inputs=[False, True, True])

self.assertLen(jaxpr.invars, 3)
self.assertEqual((jaxpr.invars[0],), jaxpr.invars[1].aval.shape)
@@ -569,7 +574,6 @@ class DynamicShapesTest(jtu.JaxTestCase):
a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)

@lu.wrap_init
def f(x, y):
@jax.jit
def g(x, y, z, w):
@@ -577,7 +581,9 @@ class DynamicShapesTest(jtu.JaxTestCase):
return g(x, y, x, y)

jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
f, [n, a, b], keep_inputs=[False, True, True])
lu.wrap_init(f,
debug_info=debug_info("test", f, (0, 1), {})),
[n, a, b], keep_inputs=[False, True, True])

self.assertLen(jaxpr.invars, 1 + 2) # one axis size var, two other inputs
self.assertEqual((jaxpr.invars[0],), jaxpr.invars[1].aval.shape)
@@ -605,7 +611,6 @@ class DynamicShapesTest(jtu.JaxTestCase):
a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)

@lu.wrap_init
def f(x, y):
@jax.jit
def g(_, x, y, z, w):
@@ -613,7 +618,9 @@ class DynamicShapesTest(jtu.JaxTestCase):
return g(x.shape[0], x, y, x, y)

jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
f, [n, a, b], keep_inputs=[False, True, True])
lu.wrap_init(f,
debug_info=debug_info("test", f, (1, 2), {})),
[n, a, b], keep_inputs=[False, True, True])

# { lambda ; a:i32[] b:f32[a] c:f32[a]. let
# d:f32[a] e:f32[a] = xla_call[
@@ -641,7 +648,6 @@ class DynamicShapesTest(jtu.JaxTestCase):
a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)

@lu.wrap_init
def f(x, y):
z = lax.mul(x, y)
w = lax.sin(z)
@@ -649,7 +655,9 @@ class DynamicShapesTest(jtu.JaxTestCase):
return (u,)

jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
f, [n, a, b], keep_inputs=[False, True, True])
lu.wrap_init(f,
debug_info=debug_info("test", f, (1, 2), {})),
[n, a, b], keep_inputs=[False, True, True])

self.assertLen(jaxpr.invars, 1 + 2) # one axis size var, two other inputs
self.assertLen(jaxpr.eqns, 3)
@@ -667,14 +675,15 @@ class DynamicShapesTest(jtu.JaxTestCase):
a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
b = core.DShapedArray((DBIdx(1),), jnp.dtype('float32'), weak_type=False)

@lu.wrap_init
def f(a, b):
@jax.jit
def g(x): return x
return g(a),

jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
f, [n, m, a, b], keep_inputs=[False, False, True, True])
lu.wrap_init(f,
debug_info=debug_info("test", f, (1, 2), {})),
[n, m, a, b], keep_inputs=[False, False, True, True])
# { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
# e:f32[a] = xla_call[
# call_jaxpr={ lambda ; f:i32[] g:f32[f]. let in (g,) }
@@ -901,7 +901,7 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
# TODO(necula): what are these flat_index components?
"traced_for=jit, fun=apply_fn, arg_names=inp, result_paths=[0],[1][<flat index 0>][0][<flat index 0>][0][0]",
"traced_for=jit, fun=apply_fn, arg_names=inp, result_paths=[0],[1][0][0][0][0][0]",
re.compile(r"traced_for=custom_jvp fun, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths="),
re.compile(r"traced_for=jit, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths="),
],
@@ -1091,8 +1091,7 @@ class DebugInfoTest(jtu.JaxTestCase):
expected_jaxpr_debug_infos=[
"traced_for=cond, fun=my_f, arg_names=x['c'], result_paths=",
"traced_for=cond, fun=<lambda>, arg_names=x['c'], result_paths=",
# TODO(necula): flat_index?
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=[<flat index 0>][0][0],[<flat index 0>][0][1]",
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=[0][0][0],[0][0][1]",
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
@@ -29,6 +29,7 @@ import numpy as np

import jax
import jax.ad_checkpoint
from jax import api_util
from jax import lax
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
@@ -1329,7 +1330,11 @@ class ShardMapTest(jtu.JaxTestCase):

def test_rewrite_process_call(self):
def f(x):
return core.call_p.bind(lu.wrap_init(lambda x: [2. * x]), x)[0] * x
return core.call_p.bind(
lu.wrap_init(lambda x: [2. * x],
debug_info=api_util.debug_info("test", lambda x: [2. * x],
(x,), {})),
x)[0] * x

mesh = jtu.create_mesh((4,), ('x',))
g = shard_map(f, mesh, in_specs=(P('x'),), out_specs=P('x'))
@@ -1345,7 +1350,10 @@ class ShardMapTest(jtu.JaxTestCase):
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'))
def f(x):
return core.call_p.bind(lu.wrap_init(lambda: [2. * x]))[0] * x
return core.call_p.bind(
lu.wrap_init(lambda: [2. * x],
debug_info=api_util.debug_info("test", lambda: [2. * x],
(), {})))[0] * x

x = jnp.arange(4.)
y = f(x)