shard_map and other fixes to direct-linearize

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
Matthew Johnson 2025-02-05 01:41:08 +00:00
parent 178278863d
commit 7c2f842353
14 changed files with 207 additions and 181 deletions

View File

@ -2005,8 +2005,8 @@ def vjp(
raise NotImplementedError("reduce_axes argument to vjp is deprecated")
del reduce_axes
check_callable(fun)
wrapped_fun = lu.wrap_init(fun,
debug_info=debug_info("vjp", fun, primals, {}))
wrapped_fun = lu.wrap_init(
fun, debug_info=debug_info("vjp", fun, primals, {}))
return _vjp(wrapped_fun, *primals, has_aux=has_aux)
def _vjp(fun: lu.WrappedFun, *primals, has_aux=False):

View File

@ -2507,8 +2507,8 @@ class MapPrimitive(Primitive):
def get_bind_params(self, params):
new_params = dict(params)
jaxpr: Jaxpr = new_params.pop('call_jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr,
debug_info=jaxpr.debug_info), jaxpr, ())
subfun = lu.hashable_partial(
lu.wrap_init(eval_jaxpr, debug_info=jaxpr.debug_info), jaxpr, ())
axes = new_params.pop('out_axes')
new_params['out_axes_thunk'] = HashableFunction(lambda: axes, closure=axes)
return [subfun], new_params

View File

@ -39,7 +39,7 @@ from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal)
from jax._src.dtypes import dtype, float0
from jax._src.util import (unzip2, safe_map, safe_zip, split_list, wrap_name,
as_hashable_function, weakref_lru_cache,
partition_list)
partition_list, subs_list2)
zip = safe_zip
map = safe_map
@ -91,6 +91,7 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag,
*primals, **params):
with core.take_current_trace() as parent_trace:
tangent_trace = pe.DynamicJaxprTrace(debug_info)
tangent_trace.tag = _tag
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=_tag)
tracers = [LinearizeTracer(linearize_trace, p,
tangent_trace.new_arg(get_aval(p).to_tangent_aval()))
@ -104,11 +105,23 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag,
out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz)
out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) # type: ignore[assignment]
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info)
residual_avals = map(get_aval, consts)
if attrs_tracked:
raise NotImplementedError("TODO: attrs")
_store.store((residual_avals, nzs_out, jaxpr))
return tuple(consts) + tuple(out_primals)
which_env = [(isinstance(c, pe.DynamicJaxprTracer) and
getattr(c._trace, 'tag', None) is _tag) for c in consts]
jaxpr = pe.move_envvars(jaxpr, tuple(which_env))
res, env = partition_list(which_env, consts)
residual_avals = map(get_aval, res)
# Which residuals are just forwarded inputs? Check object id.
id_map = {id(p): i for i, p in enumerate(primals)}
in_fwd: list[int | None] = [id_map.get(id(r)) for r in res]
# Which residuals are already primal outputs? Check object id.
id_map = {id(p): i for i, p in enumerate(out_primals)}
out_fwd: list[int | None] = [id_map.get(id(r)) for r in res]
# Prune the residuals to exclude forwarded primal inputs and outputs.
res = [p for p, f1, f2 in zip(res, in_fwd, out_fwd) if f1 is None and f2 is None]
_store.store((residual_avals, nzs_out, jaxpr, env, in_fwd, out_fwd))
return *res, *out_primals
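
The pruning above works because some residuals are literally the same objects as primal inputs or outputs, so they can be recovered at the call site instead of being returned again. Below is a minimal sketch (not part of this commit; helper name hypothetical) of the substitution that subs_list2 is later asked to perform, assuming input-forwards take precedence over output-forwards as in the residual_axes and residual_names computations elsewhere in this diff:

def _subs_list2_sketch(in_fwd, out_fwd, primals, primals_out, non_fwd_res):
  # Rebuild the full residual list: forwarded entries come from the primal
  # inputs or outputs, everything else from the freshly returned residuals.
  fresh = iter(non_fwd_res)
  res = [primals[f1] if f1 is not None else
         primals_out[f2] if f2 is not None else
         next(fresh) for f1, f2 in zip(in_fwd, out_fwd)]
  assert next(fresh, None) is None  # every non-forwarded residual is consumed
  return res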
@lu.transformation2
def jvp_subtrace(f: Callable, tag: core.TraceTag, primals, tangents):
@ -157,6 +170,7 @@ def _linearize_jaxpr(
primal_trace = pe.DynamicJaxprTrace(dbg)
tangent_trace = pe.DynamicJaxprTrace(dbg)
lin_trace = LinearizeTrace(primal_trace, tangent_trace)
tangent_trace.tag = lin_trace.tag
def new_arg(trace, primal_aval, nz):
primal = primal_trace.new_arg(primal_aval)
@ -197,6 +211,7 @@ def direct_linearize(traceable: lu.WrappedFun,
tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals]
tangents = [Zero.from_primal_value(t) if dtype(t) == float0 else t for t in tangents]
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=tag)
tangent_trace.tag = linearize_trace.tag
tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)]
tracers = [t.full_lower() for t in tracers]
with (core.set_current_trace(linearize_trace, check_leaks=True),
@ -217,6 +232,10 @@ def direct_linearize(traceable: lu.WrappedFun,
out_nz_tangents = map(tangent_trace.to_jaxpr_tracer, out_nz_tangents)
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents, traceable.debug_info)
tangent_trace.invalidate()
jaxpr, used_consts, _ = pe.dce_jaxpr_consts(
jaxpr, [True] * len(jaxpr.outvars),
[False] * len(jaxpr.constvars) + [True] * len(jaxpr.invars))
consts = [c for c, used in zip(consts, used_consts) if used]
out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) if nz else
pe.PartialVal.known(zeros_like_aval(t.aval))
for t, nz in zip(out_tangents, out_nzs)]
@ -586,7 +605,7 @@ def _primal_tangent_shapes_match(primal, tangent):
if type(tangent) is not Zero:
primal_aval = get_aval(primal).strip_weak_type()
tangent_aval = get_aval(tangent).strip_weak_type()
assert core.definitely_equal_shape(primal_aval.shape, tangent_aval.shape)
assert core.definitely_equal_shape(primal_aval.shape, tangent_aval.shape), (primal_aval.shape, tangent_aval.shape)
expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_aval.dtype)
assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype)
@ -641,6 +660,7 @@ class LinearizeTrace(Trace):
return prim.bind_with_trace(self.parent_trace, (fun, f_jvp, *primals_in),
dict(symbolic_zeros=symbolic_zeros))
@partial(lu.wrap_init, debug_info=f_jvp.debug_info)
def _f_jvp(primals, tangents):
outs = f_jvp.call_wrapped(*primals, *tangents)
primals_out, tangents_out = split_list(outs, [len(outs) // 2])
@ -651,7 +671,7 @@ class LinearizeTrace(Trace):
nonzeros_in = [type(t) is not Zero for t in tangents_in]
primals_out, tangent_nzs_out, residuals, linearized = linearize_from_jvp(
_f_jvp, True, nonzeros_in, symbolic_zeros, instantiate_zeros,
f_jvp.debug_info, primals_in, {})
primals_in, {})
with core.set_current_trace(self.tangent_trace):
tangents_out = linearized(residuals, *tangents_in)
@ -690,53 +710,65 @@ class LinearizeTrace(Trace):
assert call_primitive.multiple_results
primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers))
nzs_in = tuple(type(t) is not Zero for t in tangents)
f_primal, linearize_outs_thunk = linearize_subtrace(f, self.tag, nzs_in,
f.debug_info)
f_primal, linearize_outs_thunk = linearize_subtrace(
f, self.tag, nzs_in, f.debug_info)
if isinstance(call_primitive, core.MapPrimitive):
@as_hashable_function(closure=(linearize_outs_thunk))
out_axes_thunk = params['out_axes_thunk']
@as_hashable_function(closure=out_axes_thunk)
def new_out_axes_thunk():
residual_avals, _, _ = linearize_outs_thunk()
out_axes = params['out_axes_thunk']()
return (*(0 for _ in residual_avals), *out_axes)
_, _, _, _, in_fwd, out_fwd = linearize_outs_thunk()
num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
out_axes = out_axes_thunk()
return (*(0 for _ in range(num_res_out)), *out_axes)
primal_params = dict(params, out_axes_thunk=new_out_axes_thunk)
else:
primal_params = params
all_primal_results = call_primitive.bind_with_trace(self.parent_trace, (f_primal, *primals), primal_params)
residual_avals, nzs_out, lin_jaxpr = linearize_outs_thunk()
num_residuals = len(residual_avals)
residuals = all_primal_results[:num_residuals]
primals_out = all_primal_results[num_residuals:]
residual_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk()
num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
non_fwd_res = all_primal_results[:num_res_out]
primals_out = all_primal_results[num_res_out:]
residuals = subs_list2(in_fwd, out_fwd, primals, primals_out, non_fwd_res)
if isinstance(call_primitive, core.MapPrimitive):
in_axes = params['in_axes']
out_axes = params['out_axes_thunk']()
residual_avals = map(get_aval, residuals)
new_in_axes = (*(0 for _ in residual_avals),
residual_axes = [in_axes[f1] if f1 is not None else
out_axes[f2] if f2 is not None else
0 for f1, f2 in zip(in_fwd, out_fwd)]
new_in_axes = (*residual_axes, *(None for _ in range(len(env))),
*(ax for ax, nz in zip(in_axes, nzs_in) if nz))
new_out_axes = (*(ax for ax, nz in zip(out_axes, nzs_out) if nz),)
# NOTE: This assumes that the output tangents being zero is a
# deterministic function of which input tangents were zero.
@as_hashable_function(closure=(new_out_axes))
@as_hashable_function(closure=new_out_axes)
def new_out_axes_thunk():
return new_out_axes
params = dict(params,
in_axes=new_in_axes,
out_axes_thunk=new_out_axes_thunk)
params = dict(params, in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk)
update_params = call_linearize_param_updaters.get(call_primitive)
new_params = update_params(params, residual_avals, nzs_in) if update_params else params
num_new_args = len(residuals) + len(env)
new_params = update_params(params, num_new_args, nzs_in) if update_params else params
num_residuals = len(residual_avals)
@as_hashable_function(closure=(num_residuals, lin_jaxpr))
def f_tangent(*args):
residuals = args[:num_residuals]
consts = args[:num_residuals]
nz_tangents = args[num_residuals:]
return core.eval_jaxpr(lin_jaxpr, residuals, *nz_tangents)
return core.eval_jaxpr(lin_jaxpr, consts, *nz_tangents)
# TODO(mattjj,dougalm): this tag is read by DynamicJaxprTrace.process_map to
# avoid round-tripping the jaxpr and thus getting grad-of-pmap cache misses.
# Remove when we replace the pmap implementation.
f_tangent._pmap_tag = isinstance(call_primitive, core.MapPrimitive)
thing = lu.wrap_init(f_tangent, debug_info=lin_jaxpr.debug_info)
nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz]
nz_tangents_out = call_primitive.bind_with_trace(
self.tangent_trace, (lu.wrap_init(f_tangent,
debug_info=lin_jaxpr.debug_info),
*residuals, *nz_tangents_in), new_params)
self.tangent_trace,
(thing,
*residuals, *env, *nz_tangents_in), new_params)
nz_tangents_out_iter = iter(nz_tangents_out)
tangents_out = [next(nz_tangents_out_iter) if nz else Zero.from_primal_value(primal)
for nz, primal in zip(nzs_out, primals_out)]
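
For mapped calls, each residual's mapping axis depends on where it came from; the worked example below (hypothetical axes and forwarding indices) mirrors the residual_axes computation above:

# Residual 0 is forwarded from input 1, residual 1 from output 0, residual 2 is fresh.
in_axes = (0, None)
out_axes = (0,)
in_fwd = [1, None, None]
out_fwd = [None, 0, None]
residual_axes = [in_axes[f1] if f1 is not None else
                 out_axes[f2] if f2 is not None else
                 0 for f1, f2 in zip(in_fwd, out_fwd)]
# residual_axes == [None, 0, 0]: forwarded residuals reuse their source axis,
# fresh residuals are mapped along axis 0.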
@ -762,14 +794,14 @@ def fallback_linearize_rule(_prim: core.Primitive,
msg = f"Differentiation rule for '{_prim}' not implemented"
raise NotImplementedError(msg)
debug_jvp = debug_info("linearize_prim_jvp", jvp, primals, params)
return linearize_from_jvp(jvp, _prim.multiple_results, _nonzeros, False, False,
debug_jvp, primals, params)
return linearize_from_jvp(lu.wrap_init(jvp, debug_info=debug_jvp),
_prim.multiple_results, _nonzeros, False, False,
primals, params)
def linearize_from_jvp(jvp: Callable,
def linearize_from_jvp(jvp: lu.WrappedFun,
multiple_results: bool,
nonzeros: Sequence[bool],
user_facing_symbolic_zeros: bool, instantiate_input_zeros: bool,
debug_info: core.DebugInfo,
primals, params):
current_name_stack = source_info_util.current_name_stack()
with core.take_current_trace() as parent_trace:
@ -792,13 +824,18 @@ def linearize_from_jvp(jvp: Callable,
tangent_args = tuple(trace.new_arg(pe.PartialVal.unknown(aval)) if nz else make_zero(aval)
for aval, nz in zip(tangent_avals, nonzeros))
with core.set_current_trace(trace):
out_primals, out_tangents = jvp(primals, tangent_args, **params)
out_primals, out_tangents = jvp.call_wrapped(primals, tangent_args, **params)
if not multiple_results:
out_primals = [out_primals]
out_tangents = [out_tangents]
out_primals = [trace.to_jaxpr_tracer(p).pval.get_known() for p in out_primals]
if any(p is None for p in out_primals):
raise ValueError(
"Linearization failed to produce known values for all output primals. "
"This is typically caused by attempting to differentiate a function "
"uses an operation that does not support reverse-mode autodiff.")
out_nzs = [type(t) is not zero_type and not trace.to_jaxpr_tracer(t).is_known()
for t in out_tangents]
@ -806,7 +843,7 @@ def linearize_from_jvp(jvp: Callable,
out_nz_tracers = [trace.to_jaxpr_tracer(r)
for (r, nz) in zip(out_tangents, out_nzs) if nz]
in_tracers = [t for t, nz in zip(tangent_args, nonzeros) if nz]
jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, debug_info)
jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, jvp.debug_info)
def linearized(residuals, *tangents):
nz_tangents_in = [t for (t, nz) in zip(tangents, nonzeros) if nz]
@ -973,9 +1010,8 @@ def call_transpose(primitive, params, call_jaxpr: core.Jaxpr, args, ct, _):
else:
consts = ()
all_args, in_tree_def = tree_flatten((consts, args, ct))
fun = lu.hashable_partial(lu.wrap_init(backward_pass,
debug_info=call_jaxpr.debug_info),
call_jaxpr, False)
fun = lu.hashable_partial(lu.wrap_init(
backward_pass, debug_info=call_jaxpr.debug_info), call_jaxpr, False)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
update_params = call_transpose_param_updaters.get(primitive)
if update_params:
@ -1013,9 +1049,8 @@ def map_transpose(primitive: core.Primitive, params,
call_jaxpr: core.Jaxpr, args, ct, _):
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
# TODO(necula): use the right debug_info for the backwards pass
fun = lu.hashable_partial(lu.wrap_init(backward_pass,
debug_info=call_jaxpr.debug_info),
call_jaxpr, False)
fun = lu.hashable_partial(lu.wrap_init(
backward_pass, debug_info=call_jaxpr.debug_info), call_jaxpr, False)
fun, nz_arg_cts = nonzero_outputs(fun)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
# Preserve axis for primal arguments, skip tangents (represented as undefined primals).

View File

@ -46,7 +46,8 @@ from jax._src.tree_util import (PyTreeDef, treedef_tuple,
tree_flatten, tree_structure)
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
merge_lists, partition_list, OrderedSet,
as_hashable_function, weakref_lru_cache, subs_list)
as_hashable_function, weakref_lru_cache, subs_list,
HashableFunction)
map, unsafe_map = safe_map, map
@ -837,6 +838,11 @@ def tracers_to_jaxpr(
# del getvar # needed to avoid cyclic-reference closure, apparently!
return jaxpr, const_vals, env_vals
@weakref_lru_cache
def move_envvars(jaxpr: Jaxpr, which: tuple[bool, ...]) -> Jaxpr:
constvars, envvars = partition_list(which, jaxpr.constvars)
return jaxpr.replace(constvars=constvars, invars=[*envvars, *jaxpr.invars])
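
A small illustrative example of the partition_list convention that move_envvars relies on, assuming (consistent with its uses in this diff) that entries with a False flag come back first and entries with a True flag second; the values are made up:

which = (False, True, False)
constvars = ['c0', 'e0', 'c1']  # hypothetical constvars
kept, moved = partition_list(which, constvars)
# kept  == ['c0', 'c1']  -> remain constvars
# moved == ['e0']        -> prepended to invars, so the env value is passed as an argument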
@weakref_lru_cache
def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
"""Moves the constvars to the start of invars."""
@ -1840,7 +1846,7 @@ def _inline_literals(
class DynamicJaxprTrace(core.Trace):
__slots__ = ("frame",)
__slots__ = ("frame", "tag")
def __init__(self, debug_info: core.DebugInfo):
self.frame = JaxprStackFrame(debug_info)
@ -1972,17 +1978,18 @@ class DynamicJaxprTrace(core.Trace):
self.frame.add_eqn(eqn)
return [t for t, (_, keep) in zip(out_tracers, out_type) if keep]
def process_map(self, map_primitive, f: lu.WrappedFun,
tracers: Sequence[core.Tracer], params):
def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
tracers = map(self.to_jaxpr_tracer, tracers)
in_avals = [t.aval for t in tracers]
axis_name, axis_size = params['axis_name'], params['axis_size']
reduced_in_avals = [core.mapped_aval(axis_size, in_axis, a)
if in_axis is not None else a
for a, in_axis in zip(in_avals, params['in_axes'])]
with core.extend_axis_env_nd([(axis_name, params["global_axis_size"])]):
jaxpr, reduced_out_avals, consts, () = trace_to_jaxpr_dynamic(
f, reduced_in_avals)
jaxpr, consts = _linearize_of_pmap_hack(f, jaxpr, consts)
ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects)
if ordered_effects:
raise ValueError("Ordered effects not supported for "
@ -2582,3 +2589,13 @@ def inline_jaxpr_into_trace(
return tracer
return [x.val if isinstance(x, Literal) else tracer_env[x] if x in tracer_env
else new_tracer(x) for x in jaxpr.outvars]
# TODO(mattjj,dougalm): this special handling is to avoid round-tripping the
# jaxpr when we do grad-of-pmap. The tag is set by LinearizeTrace.process_call's
# handling of pmap. Remove when we replace the pmap implementation.
def _linearize_of_pmap_hack(f: lu.WrappedFun, jaxpr, consts) -> tuple[Jaxpr, list]:
if (not f.transforms and type(f.f) is HashableFunction and
getattr(f.f, '_pmap_tag', None)):
_, jaxpr = f.f.closure
return convert_constvars_jaxpr(jaxpr), []
return jaxpr, consts

View File

@ -1394,9 +1394,9 @@ def xla_call_jvp_update_params(params, nz_tangents):
new_donated_invars = (*donated_invars, *donated_tangents)
return dict(params, donated_invars=new_donated_invars)
def _xla_call_linearize_update_params(params, residual_avals, nz_tangents):
def _xla_call_linearize_update_params(params, num_new_inputs, nz_tangents):
donated_invars_prev = params['donated_invars']
donated_invars = (*(False for _ in residual_avals),
donated_invars = (*(False for _ in range(num_new_inputs)),
*(d for d, nz in zip(donated_invars_prev, nz_tangents) if nz))
return dict(params, donated_invars=donated_invars)
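
A worked example of the updated parameter updater, with hypothetical values; the new residual/env inputs are never donated, and donation flags for pruned (zero) tangents are dropped:

donated_invars_prev = (True, False, True)
nz_tangents = (True, True, False)
num_new_inputs = 2
donated_invars = (*(False for _ in range(num_new_inputs)),
                  *(d for d, nz in zip(donated_invars_prev, nz_tangents) if nz))
# donated_invars == (False, False, True, False)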

View File

@ -3736,14 +3736,14 @@ def _sin_lowering(ctx, x):
return sine(ctx, x)
return _nary_lower_hlo(hlo.sine, ctx, x)
def _sin_p_lin(nzs, x):
def _sin_lin(nzs, x):
nz, = nzs
cos_x = cos(x) # TODO: allow this to happen in the linearized computation (need to fix backward_pass)
return (sin_p.bind(x), nz, cos_x, lambda cos_x_, t: mul(t, cos_x_))
sin_p = standard_unop(_float | _complex, 'sin')
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
ad.primitive_linearizations[sin_p] = _sin_p_lin
ad.primitive_linearizations[sin_p] = _sin_lin
mlir.register_lowering(sin_p, _sin_lowering)
batching.ragged_prop_rules[sin_p] = batching.ragged_mask_elementwise_rule
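
The rule returns (primal output, output-nonzero flag, residual, linearized function). As a parallel illustration only (not part of this commit), a hypothetical rule for exp following the same convention might look like this:

def _exp_lin(nzs, x):
  nz, = nzs
  exp_x = exp(x)  # residual: the derivative of exp at x
  return exp_p.bind(x), nz, exp_x, lambda exp_x_, t: mul(t, exp_x_)
# hypothetical registration, mirroring the sin_p line above:
# ad.primitive_linearizations[exp_p] = _exp_lin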

View File

@ -200,7 +200,7 @@ class _BaseMesh:
_mesh_object_dict = {} # type: ignore
MeshAxisType = dict[AxisTypes, str | tuple[str, ...]]
MeshAxisType = dict[AxisTypes, MeshAxisName | tuple[MeshAxisName, ...]]
class Mesh(_BaseMesh, contextlib.ContextDecorator):
"""Declare the hardware resources available in the scope of this manager.

View File

@ -2043,18 +2043,6 @@ def _pjit_jvp(primals_in, tangents_in,
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline,
compiler_options_kvs):
if any(isinstance(c, core.MutableArray) for c in jaxpr.consts):
jaxpr, mut_primals = pxla._move_mutable_consts(jaxpr)
mut_tangents = map(ad_util.zeros_like_jaxval, mut_primals)
primals_in = [*primals_in, *mut_primals]
tangents_in = [*tangents_in, *mut_tangents]
in_shardings = (*in_shardings,) + (UNSPECIFIED,) * len(mut_primals)
in_layouts = (*in_layouts,) + (None,) * len(mut_primals)
donated_invars = (*donated_invars,) + (False,) * len(mut_primals)
tangents_in = [ad_util.zeros_like_aval(a) if isinstance(a, AbstractRef) else x
for x, a in zip(tangents_in, jaxpr.in_avals)]
is_nz_tangents_in = [type(t) is not ad.Zero for t in tangents_in]
jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr(
jaxpr, is_nz_tangents_in, instantiate=False)

View File

@ -437,6 +437,8 @@ def _swap_jvp(primals: list[Any], tangents: list[Any], **params: Any):
ref_primal, x_primal, *idx = primals
assert isinstance(ref_primal.aval, AbstractRef)
ref_tangent, x_tangent, *_ = tangents
if type(ref_tangent) is ad_util.Zero:
raise Exception("you're an idiot")
assert isinstance(ref_tangent.aval, AbstractRef)
x_tangent = ad_util.instantiate(x_tangent)
return (swap_p.bind(ref_primal, x_primal, *idx, **params),
@ -657,5 +659,14 @@ mlir.register_lowering(
# === AD rules for mutable arrays ===
ad.defjvp(core.mutable_array_p, lambda g, _: core.mutable_array(g))
def _mut_jvp(primals, tangents):
(init_val,), (init_val_dot,) = primals, tangents
primal_out = core.mutable_array_p.bind(init_val)
if type(init_val_dot) is ad_util.Zero:
tangent_out = core.mutable_array_p.bind(ad_util.zeros_like_aval(init_val_dot.aval))
else:
tangent_out = core.mutable_array_p.bind(init_val_dot)
return primal_out, tangent_out
ad.primitive_jvps[core.mutable_array_p] = _mut_jvp
ad.defjvp(core.freeze_p, lambda g, _: core.freeze(g))

View File

@ -544,6 +544,8 @@ def _shard_map_staging(
return out_tracers
pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging
# TODO add underscore version, for direct-linearize to consume
def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray:
assert isinstance(aval, core.ShapedArray)
return aval
@ -742,9 +744,8 @@ def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names,
out_avals_ = [x.aval for x in jaxpr.outvars]
in_nodes_ = map(partial(_xla_shard, ctx, mesh, auto), in_names, ctx.avals_in,
in_avals_, in_nodes)
new_axis_context = sharding_impls.SPMDAxisContext(
mesh, frozenset(mesh.axis_names) - auto
)
manual_axes = frozenset(mesh.axis_names) - auto
new_axis_context = sharding_impls.SPMDAxisContext(mesh, manual_axes)
sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)
with _extend_axis_env(mesh, auto):
out_nodes_, tokens_out = mlir.call_lowering(
@ -895,7 +896,6 @@ def _match_spec(mesh: Mesh, check_rep: bool,
def _match(mesh, check_rep, pspec, x):
src = P(mesh.axis_names)
# TODO: maybe put back; needed for rep checking in eager? for now, test the rewrite
return shard_map(_rem_singleton, mesh, (src,), pspec, check_rep=False)(x)
def _rem_singleton(x): return jnp.squeeze(x, axis=0)
@ -914,6 +914,7 @@ class ShardMapTrace(core.Trace):
__slots__ = ("mesh", "auto", "check", "context_mesh")
mesh: Mesh
auto: frozenset[AxisName]
check: bool
context_mesh: AbstractMesh
@ -927,7 +928,7 @@ class ShardMapTrace(core.Trace):
if isinstance(val, ShardMapTracer):
return val.val, val.rep
elif isinstance(val, Tracer):
raise Exception("Shouldn't have any non-shard_map tracers")
raise Exception(f"Shouldn't have any non-shard_map tracers: {val}")
else:
val_ = _unmatch_spec(self.mesh, {}, val, self.context_mesh)
return val_, None
@ -1609,34 +1610,40 @@ def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun,
out_names_thunk, check_rep, rewrite, auto):
primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers))
nzs_in = tuple(type(t) is not ad.Zero for t in tangents)
f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in,
f.debug_info)
f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in, f.debug_info)
f_primal = _promote_scalar_residuals_lin(f_primal, linearize_outs_thunk)
tangent_in_names = [ax for ax, nz in zip(in_names, nzs_in) if nz]
all_names = _all_newly_manual_mesh_names(mesh, auto, trace)
res_names = _all_newly_manual_mesh_names(mesh, auto, trace)
@as_hashable_function(closure=(linearize_outs_thunk))
@as_hashable_function(closure=linearize_outs_thunk)
def primal_out_names_thunk():
residual_avals, _, _ = linearize_outs_thunk()
_, _, _, _, in_fwd, out_fwd = linearize_outs_thunk()
out_names = out_names_thunk()
# This is incorrect so we set `check_rep=False` as we do in the JVP rule.
return (*({0: all_names} for _ in residual_avals), *out_names)
num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
# This is incorrect so we set `check_rep=False` in the tangent (as in JVP).
return (*({0: res_names} for _ in range(num_res_out)), *out_names)
primal_params = dict(
mesh=mesh, in_names=in_names,
out_names_thunk=primal_out_names_thunk, check_rep=check_rep,
rewrite=rewrite, auto=auto)
all_primal_results = shard_map_p.bind_with_trace(
trace.parent_trace, (f_primal,) + tuple(primals), primal_params)
residual_avals, nzs_out, lin_jaxpr = linearize_outs_thunk()
num_residuals = len(residual_avals)
residuals = all_primal_results[:num_residuals]
primals_out = all_primal_results[num_residuals:]
args_to_promote = [getattr(aval, 'shape', ()) == () for aval in residual_avals]
lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote)
trace.parent_trace, (f_primal, *primals), primal_params)
residual_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk()
num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
non_fwd_res = all_primal_results[:num_res_out]
primals_out = all_primal_results[num_res_out:]
residuals = subs_list2(in_fwd, out_fwd, primals, primals_out, non_fwd_res)
args_to_promote = [getattr(aval, 'shape', ()) == () and f1 is None and f2 is None
for aval, f1, f2 in zip(residual_avals, in_fwd, out_fwd)]
with core.extend_axis_env_nd(mesh.shape.items()):
lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote)
out_names = out_names_thunk()
new_in_names = (*({0: all_names} for _ in residual_avals),
residual_names = [in_names[f1] if f1 is not None else
out_names[f2] if f2 is not None else
{0: res_names} for f1, f2 in zip(in_fwd, out_fwd)]
new_in_names = (*residual_names, *({} for _ in range(len(env))),
*(ax for ax, nz in zip(in_names, nzs_in) if nz))
new_out_names = (*(ax for ax, nz in zip(out_names, nzs_out) if nz),)
new_out_names = tuple(ax for ax, nz in zip(out_names, nzs_out) if nz)
@as_hashable_function(closure=(new_out_names))
def tangent_out_names_thunk():
return new_out_names
@ -1645,15 +1652,14 @@ def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun,
out_names_thunk=tangent_out_names_thunk, check_rep=False,
rewrite=rewrite, auto=auto)
# TODO: don't round-trip
def f_tangent(*args):
residuals = args[:num_residuals]
nz_tangents = args[num_residuals:]
return core.eval_jaxpr(lin_jaxpr, (), *residuals, *nz_tangents)
return core.eval_jaxpr(lin_jaxpr, (), *args)
nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz]
nz_tangents_out = shard_map_p.bind_with_trace(trace.tangent_trace,
(lu.wrap_init(f_tangent, debug_info=lin_jaxpr.debug_info),
*residuals, *nz_tangents_in), tangent_params)
*residuals, *env, *nz_tangents_in), tangent_params)
nz_tangents_out_iter = iter(nz_tangents_out)
tangents_out = [next(nz_tangents_out_iter) if nz else ad.Zero.from_primal_value(primal)
for nz, primal in zip(nzs_out, primals_out)]
@ -1663,13 +1669,13 @@ ad.LinearizeTrace.process_shard_map = _shard_map_linearize
@lu.transformation2
def _promote_scalar_residuals_lin(f, linearize_outs_thunk, *args, **kwargs):
ans = f(*args, **kwargs)
residual_avals, _, _ = linearize_outs_thunk()
num_residuals = len(residual_avals)
residuals = ans[:num_residuals]
primals = ans[num_residuals:]
residuals = tuple(jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x
for x in residuals)
return residuals + primals
_, _, _, _, in_fwd, out_fwd = linearize_outs_thunk()
num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
residuals = ans[:num_res_out]
primals = ans[num_res_out:]
residuals = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x
for x in residuals]
return *residuals, *primals
@lu.transformation2
def _promote_scalar_residuals(f: Callable, *args, **kwargs):
@ -1798,10 +1804,10 @@ def _partial_eval_jaxpr_custom_rule(
_, ins_staged = partition_list(inst_in, eqn.invars)
_, out_binders_staged = partition_list(inst_out, eqn.outvars)
newvar = core.gensym()
params_known, params_staged, all_names = _pe_custom_params(
params_known, params_staged, res_names = _pe_custom_params(
unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd, which,
dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged))
residuals = [newvar(_unshard_aval(mesh, {0: all_names}, var.aval))
residuals = [newvar(_unshard_aval(mesh, {0: res_names}, var.aval))
for var, w in zip(jaxpr_staged.invars[:num_res], which) if w]
eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
eqn.primitive, params_known, jaxpr_known.effects,
@ -1853,10 +1859,10 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
# prune inputs to jaxpr_known according to unks_in
mesh = params_known['mesh']
auto = params_known['auto']
all_names = _all_newly_manual_mesh_names(mesh, auto)
res_names_ = _all_newly_manual_mesh_names(mesh, auto)
in_names_known, _ = partition_list(unks_in, params_known['in_names'])
_, out_names_known = partition_list(kept_outs_known, params_known['out_names'])
out_names_known = out_names_known + [{0: all_names}] * sum(which)
out_names_known = out_names_known + [{0: res_names_}] * sum(which)
new_params_known = dict(params_known, in_names=tuple(in_names_known),
out_names=tuple(out_names_known))
@ -1864,12 +1870,12 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
_, in_names_staged = partition_list(inst_in, params_staged['in_names'])
res_names = [in_names_known[f1] if f1 is not None else
out_names_known[f2] if f2 is not None else
{0: all_names} for f1, f2 in zip(in_fwd, out_fwd)]
{0: res_names_} for f1, f2 in zip(in_fwd, out_fwd)]
in_names_staged = res_names + in_names_staged
_, out_names_staged = partition_list(kept_outs_staged, params_staged['out_names'])
new_params_staged = dict(params_staged, in_names=tuple(in_names_staged),
out_names=tuple(out_names_staged), check_rep=False)
return new_params_known, new_params_staged, all_names
return new_params_known, new_params_staged, res_names_
# TODO(mattjj): remove this mechanism when we revise mesh scopes
def _all_mesh_names_except_spmd(
@ -1880,15 +1886,21 @@ def _all_mesh_names_except_spmd(
return tuple(name for name in mesh.axis_names if name not in spmd_names and
name not in auto)
# TODO(mattjj): remove this mechanism when we revise mesh scopes
def _all_newly_manual_mesh_names(
mesh: Mesh, auto: frozenset[AxisName], trace=None
) -> tuple[AxisName, ...]:
axis_env = core.get_axis_env()
spmd_names = axis_env.spmd_axis_names
axis_sizes = axis_env.axis_sizes
return tuple(name for name in mesh.axis_names if name not in spmd_names and
name not in auto and name not in axis_sizes)
if not (ctx_mesh := get_abstract_mesh()).empty:
del mesh
already_manual_names = set(ctx_mesh.axis_types.get(AxisTypes.Manual, ()))
return tuple(name for name in ctx_mesh.axis_names
if name not in auto | already_manual_names)
else:
# TODO(mattjj): remove this mechanism when we revise mesh scopes
axis_env = core.get_axis_env()
vmap_spmd_names = set(axis_env.spmd_axis_names)
already_manual_names = set(axis_env.axis_sizes) # may include vmap axis_names
return tuple(name for name in mesh.axis_names
if name not in auto | vmap_spmd_names | already_manual_names)
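
A toy example of the set arithmetic above, with hypothetical axis names:

axis_names = ('i', 'j', 'k')
auto = frozenset({'k'})
already_manual_names = {'i'}  # e.g. axes already Manual in the context mesh
newly_manual = tuple(n for n in axis_names
                     if n not in auto | already_manual_names)
# newly_manual == ('j',)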
# DCE

View File

@ -43,13 +43,14 @@ __ = pe.PartialVal.unknown(ShapedArray((), np.float32))
def call(f, *args):
return jit(f)(*args)
@util.curry
def core_call(f, *args):
args, in_tree = jax.tree.flatten(args)
dbg = debug_info("core_call_test", f, args, {})
f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, debug_info=dbg), in_tree)
out = core.call_p.bind(f, *args)
return jax.tree.unflatten(out_tree(), out)
# call = core_call
core_call = util.curry(core_call)
@util.curry
def core_closed_call(f, *args):

View File

@ -131,51 +131,6 @@ class MutableArrayTest(jtu.JaxTestCase):
out = f()
self.assertAllClose(out, jnp.array([2., 0., 1.]), check_dtypes=False)
@parameterized.parameters([True, False])
def test_refs_in_vjps(self, jit):
def gradient_history_calculator_fwd(x, ref):
return x, ref
def gradient_history_calculator_bwd(amax_history, grad_output):
amax_update = jnp.max(jnp.abs(grad_output))
shifted = jnp.roll(amax_history[:], 1)
shifted = shifted.at[0].set(amax_update)
amax_history[:] = shifted
amax_from_history = jnp.max(amax_history[:])
grad_output = grad_output / amax_from_history
return grad_output, None
@jax.custom_vjp
def gradient_history_calculator(x, ref):
return x
gradient_history_calculator.defvjp(
gradient_history_calculator_fwd,
gradient_history_calculator_bwd)
class DotOp:
def __init__(self):
self.amax_history = core.mutable_array(jnp.zeros(5,))
def forward(self, x, y):
out = jnp.dot(x, y)
out = gradient_history_calculator(out, self.amax_history)
return out
dot_op = DotOp()
x_top = jnp.ones((5,))
y_top = jnp.ones((5,))
def loss(x, y):
return dot_op.forward(x, y).sum()
if jit:
loss = jax.jit(loss)
for i in range(3):
jax.grad(loss, (0,1))(x_top, y_top)
self.assertAllClose(dot_op.amax_history[:], jnp.zeros((5,)).at[:i+1].set(1.0), check_dtypes=False)
@parameterized.parameters([True, False])
def test_scan_internal_mut_array(self, jit):
def body_fun(_, x):
@ -371,17 +326,18 @@ class MutableArrayErrorsTest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError, "x_ref and y_ref"):
f(x_ref, x_ref)
@parameterized.parameters([False, True])
def test_argument_aliases_custom_vjp_fwd(self, jit):
@jax.custom_vjp
def f(x_ref, y_ref):
...
f.defvjp(lambda x_ref, y_ref: (None, None), lambda _, g: (None, None))
if jit:
f = jax.jit(f)
x_ref = core.mutable_array(0.)
with self.assertRaisesRegex(ValueError, "x_ref and y_ref"):
jax.vjp(f, x_ref, x_ref)
# TODO(mattjj): re-enable test after direct-linearize
# @parameterized.parameters([False, True])
# def test_argument_aliases_custom_vjp_fwd(self, jit):
# @jax.custom_vjp
# def f(x_ref, y_ref):
# ...
# f.defvjp(lambda x_ref, y_ref: (None, None), lambda _, g: (None, None))
# if jit:
# f = jax.jit(f)
# x_ref = core.mutable_array(0.)
# with self.assertRaisesRegex(ValueError, "x_ref and y_ref"):
# jax.vjp(f, x_ref, x_ref)
# TODO(mattjj): add test test_closure_and_argument_aliases_custom_vjp

View File

@ -51,6 +51,9 @@ from jax._src.interpreters import pxla
from jax._src.lax import parallel
from jax._src.lib import xla_extension
from jax._src.util import safe_map, safe_zip
from jax._src import util
from jax.api_util import flatten_fun_nokwargs, debug_info
from jax._src import linear_util as lu
config.parse_flags_with_absl()
jtu.request_cpu_devices(8)

View File

@ -2205,20 +2205,23 @@ class ShardMapTest(jtu.JaxTestCase):
mesh = jtu.create_mesh((2, 2), ('i', 'j'))
def g(x):
# manual: 'i', 'j'
return x * x
def h(x):
# auto: 'j', manual: 'i'
return shard_map(g, mesh,
in_specs=P(None, 'j'),
out_specs=P(None, 'j'))(x)
in_specs=P(None, 'j'),
out_specs=P(None, 'j'))(x)
@jax.jit
def f(x):
# auto: 'i', 'j'
return shard_map(h, mesh,
in_specs=P('i', None),
out_specs=P('i', None),
check_rep=False,
auto=frozenset({'j'}))(x).sum()
in_specs=P('i', None),
out_specs=P('i', None),
check_rep=False,
auto=frozenset({'j'}))(x).sum()
v = jnp.arange(32.).reshape(4, 8)
v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
@ -2814,7 +2817,7 @@ def sample(num: int, make_gen: Callable[[], Chooser]) -> Iterator[CaseSpec]:
name, *case = sample_one(rng, make_gen())
if name not in seen:
seen.add(name)
yield name, *case
yield case
# To sample one test spec, we run the generator, getting back sequences of
# options from it and sending in our choices from those options until finally a
@ -2929,7 +2932,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
def make_mesh(mesh_shape):
return jtu.create_mesh(tuple(mesh_shape.values()), tuple(mesh_shape))
@parameterized.named_parameters(
@parameterized.parameters(
sample(jtu.NUM_GENERATED_CASES.value, sample_shmap))
def test_eager_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):
mesh = self.make_mesh(mesh)
@ -2938,7 +2941,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
expected = ref(fun, mesh, in_specs, out_specs)(*args)
self.assertAllClose(expected, out, check_dtypes=False)
@parameterized.named_parameters(
@parameterized.parameters(
sample(jtu.NUM_GENERATED_CASES.value, sample_shmap))
def test_jit_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):
mesh = self.make_mesh(mesh)
@ -2947,9 +2950,9 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
expected = ref(fun, mesh, in_specs, out_specs)(*args)
self.assertAllClose(expected, out, check_dtypes=False)
@parameterized.named_parameters(
(name + f'_check_rep={check_rep}', *params, check_rep)
for (name, *params) in sample(jtu.NUM_GENERATED_CASES.value, sample_shmap)
@parameterized.parameters(
(*params, check_rep)
for params in sample(jtu.NUM_GENERATED_CASES.value, sample_shmap)
for check_rep in [True, False]
)
@jax.default_matmul_precision("float32")
@ -2961,7 +2964,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
f = jax.jit(f)
jtu.check_grads(f, args, order=2, atol=1e-2, rtol=1e-2)
@parameterized.named_parameters(
@parameterized.parameters(
sample(jtu.NUM_GENERATED_CASES.value, sample_shmap))
@jax.default_matmul_precision("float32")
def test_grads_closure(self, fun, mesh, jit, in_specs, out_specs, args, _):
@ -2980,7 +2983,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
return g(*args)
jtu.check_grads(f, (0.2, *closed_over_args), order=2, atol=1e-2, rtol=1e-2)
@parameterized.named_parameters(
@parameterized.parameters(
sample(jtu.NUM_GENERATED_CASES.value,
partial(sample_shmap_batched, 5)))
def test_vmap(self, bdims, fun, mesh, jit, in_specs, out_specs, args, ref):
@ -3003,7 +3006,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
tol = 1e-2 if jtu.test_device_matches(['tpu']) else None
self.assertAllClose(ans, expected, check_dtypes=False, atol=tol, rtol=tol)
@parameterized.named_parameters(
@parameterized.parameters(
sample(jtu.NUM_GENERATED_CASES.value,
partial(sample_shmap_batched, 5)))
def test_vmap_closure(self, bdims, fun, mesh, jit, in_specs, out_specs, args, _):