Mirror of https://github.com/ROCm/jax.git (synced 2025-04-14 10:56:06 +00:00)

Commit 7c2f842353 (parent: 178278863d)
shard_map and other fixes to direct-linearize

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
@ -2005,8 +2005,8 @@ def vjp(
raise NotImplementedError("reduce_axes argument to vjp is deprecated")
del reduce_axes
check_callable(fun)
wrapped_fun = lu.wrap_init(fun,
debug_info=debug_info("vjp", fun, primals, {}))
wrapped_fun = lu.wrap_init(
fun, debug_info=debug_info("vjp", fun, primals, {}))
return _vjp(wrapped_fun, *primals, has_aux=has_aux)

def _vjp(fun: lu.WrappedFun, *primals, has_aux=False):
@ -2507,8 +2507,8 @@ class MapPrimitive(Primitive):
def get_bind_params(self, params):
new_params = dict(params)
jaxpr: Jaxpr = new_params.pop('call_jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr,
debug_info=jaxpr.debug_info), jaxpr, ())
subfun = lu.hashable_partial(
lu.wrap_init(eval_jaxpr, debug_info=jaxpr.debug_info), jaxpr, ())
axes = new_params.pop('out_axes')
new_params['out_axes_thunk'] = HashableFunction(lambda: axes, closure=axes)
return [subfun], new_params
@ -39,7 +39,7 @@ from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal)
from jax._src.dtypes import dtype, float0
from jax._src.util import (unzip2, safe_map, safe_zip, split_list, wrap_name,
as_hashable_function, weakref_lru_cache,
partition_list)
partition_list, subs_list2)

zip = safe_zip
map = safe_map
@ -91,6 +91,7 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag,
*primals, **params):
with core.take_current_trace() as parent_trace:
tangent_trace = pe.DynamicJaxprTrace(debug_info)
tangent_trace.tag = _tag
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=_tag)
tracers = [LinearizeTracer(linearize_trace, p,
tangent_trace.new_arg(get_aval(p).to_tangent_aval()))
@ -104,11 +105,23 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag,
out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz)
out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) # type: ignore[assignment]
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info)
residual_avals = map(get_aval, consts)
if attrs_tracked:
raise NotImplementedError("TODO: attrs")
_store.store((residual_avals, nzs_out, jaxpr))
return tuple(consts) + tuple(out_primals)
which_env = [(isinstance(c, pe.DynamicJaxprTracer) and
getattr(c._trace, 'tag', None) is _tag) for c in consts]
jaxpr = pe.move_envvars(jaxpr, tuple(which_env))
res, env = partition_list(which_env, consts)
residual_avals = map(get_aval, res)
# Which residuals are just forwarded inputs? Check object id.
id_map = {id(p): i for i, p in enumerate(primals)}
in_fwd: list[int | None] = [id_map.get(id(r)) for r in res]
# Which residuals are already primal outputs? Check object id.
id_map = {id(p): i for i, p in enumerate(out_primals)}
out_fwd: list[int | None] = [id_map.get(id(r)) for r in res]
# Prune residuals not to include forwarded primal inputs or outputs.
res = [p for p, f1, f2 in zip(res, in_fwd, out_fwd) if f1 is None and f2 is None]
_store.store((residual_avals, nzs_out, jaxpr, env, in_fwd, out_fwd))
return *res, *out_primals

@lu.transformation2
def jvp_subtrace(f: Callable, tag: core.TraceTag, primals, tangents):
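The residual bookkeeping added above is plain index arithmetic over object identities. A minimal sketch in ordinary Python, with generic objects standing in for tracers (the names below are illustrative, not JAX internals):

# A residual that is (by object identity) a primal input or a primal output does not
# need to be returned again; only its index is recorded in in_fwd / out_fwd.
primals = [object(), object()]
out_primals = [object()]
res = [primals[1], out_primals[0], object()]   # third residual is genuinely new

id_map = {id(p): i for i, p in enumerate(primals)}
in_fwd = [id_map.get(id(r)) for r in res]       # [1, None, None]
id_map = {id(p): i for i, p in enumerate(out_primals)}
out_fwd = [id_map.get(id(r)) for r in res]      # [None, 0, None]

# Only residuals forwarded from neither side are kept and actually returned.
kept = [r for r, f1, f2 in zip(res, in_fwd, out_fwd) if f1 is None and f2 is None]
assert len(kept) == 1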
@ -157,6 +170,7 @@ def _linearize_jaxpr(
|
||||
primal_trace = pe.DynamicJaxprTrace(dbg)
|
||||
tangent_trace = pe.DynamicJaxprTrace(dbg)
|
||||
lin_trace = LinearizeTrace(primal_trace, tangent_trace)
|
||||
tangent_trace.tag = lin_trace.tag
|
||||
|
||||
def new_arg(trace, primal_aval, nz):
|
||||
primal = primal_trace.new_arg(primal_aval)
|
||||
@ -197,6 +211,7 @@ def direct_linearize(traceable: lu.WrappedFun,
|
||||
tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals]
|
||||
tangents = [Zero.from_primal_value(t) if dtype(t) == float0 else t for t in tangents]
|
||||
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=tag)
|
||||
tangent_trace.tag = linearize_trace.tag
|
||||
tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)]
|
||||
tracers = [t.full_lower() for t in tracers]
|
||||
with (core.set_current_trace(linearize_trace, check_leaks=True),
|
||||
@ -217,6 +232,10 @@ def direct_linearize(traceable: lu.WrappedFun,
|
||||
out_nz_tangents = map(tangent_trace.to_jaxpr_tracer, out_nz_tangents)
|
||||
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents, traceable.debug_info)
|
||||
tangent_trace.invalidate()
|
||||
jaxpr, used_consts, _ = pe.dce_jaxpr_consts(
|
||||
jaxpr, [True] * len(jaxpr.outvars),
|
||||
[False] * len(jaxpr.constvars) + [True] * len(jaxpr.invars))
|
||||
consts = [c for c, used in zip(consts, used_consts) if used]
|
||||
out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) if nz else
|
||||
pe.PartialVal.known(zeros_like_aval(t.aval))
|
||||
for t, nz in zip(out_tangents, out_nzs)]
|
||||
@ -586,7 +605,7 @@ def _primal_tangent_shapes_match(primal, tangent):
|
||||
if type(tangent) is not Zero:
|
||||
primal_aval = get_aval(primal).strip_weak_type()
|
||||
tangent_aval = get_aval(tangent).strip_weak_type()
|
||||
assert core.definitely_equal_shape(primal_aval.shape, tangent_aval.shape)
|
||||
assert core.definitely_equal_shape(primal_aval.shape, tangent_aval.shape), (primal_aval.shape, tangent_aval.shape)
|
||||
expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_aval.dtype)
|
||||
assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype)
|
||||
|
||||
@ -641,6 +660,7 @@ class LinearizeTrace(Trace):
|
||||
return prim.bind_with_trace(self.parent_trace, (fun, f_jvp, *primals_in),
|
||||
dict(symbolic_zeros=symbolic_zeros))
|
||||
|
||||
@partial(lu.wrap_init, debug_info=f_jvp.debug_info)
|
||||
def _f_jvp(primals, tangents):
|
||||
outs = f_jvp.call_wrapped(*primals, *tangents)
|
||||
primals_out, tangents_out = split_list(outs, [len(outs) // 2])
|
||||
@ -651,7 +671,7 @@ class LinearizeTrace(Trace):
|
||||
nonzeros_in = [type(t) is not Zero for t in tangents_in]
|
||||
primals_out, tangent_nzs_out, residuals, linearized = linearize_from_jvp(
|
||||
_f_jvp, True, nonzeros_in, symbolic_zeros, instantiate_zeros,
|
||||
f_jvp.debug_info, primals_in, {})
|
||||
primals_in, {})
|
||||
|
||||
with core.set_current_trace(self.tangent_trace):
|
||||
tangents_out = linearized(residuals, *tangents_in)
|
||||
@ -690,53 +710,65 @@ class LinearizeTrace(Trace):
|
||||
assert call_primitive.multiple_results
|
||||
primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers))
|
||||
nzs_in = tuple(type(t) is not Zero for t in tangents)
|
||||
f_primal, linearize_outs_thunk = linearize_subtrace(f, self.tag, nzs_in,
|
||||
f.debug_info)
|
||||
f_primal, linearize_outs_thunk = linearize_subtrace(
|
||||
f, self.tag, nzs_in, f.debug_info)
|
||||
if isinstance(call_primitive, core.MapPrimitive):
|
||||
@as_hashable_function(closure=(linearize_outs_thunk))
|
||||
out_axes_thunk = params['out_axes_thunk']
|
||||
@as_hashable_function(closure=out_axes_thunk)
|
||||
def new_out_axes_thunk():
|
||||
residual_avals, _, _ = linearize_outs_thunk()
|
||||
out_axes = params['out_axes_thunk']()
|
||||
return (*(0 for _ in residual_avals), *out_axes)
|
||||
_, _, _, _, in_fwd, out_fwd = linearize_outs_thunk()
|
||||
num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
|
||||
out_axes = out_axes_thunk()
|
||||
return (*(0 for _ in range(num_res_out)), *out_axes)
|
||||
primal_params = dict(params, out_axes_thunk=new_out_axes_thunk)
|
||||
else:
|
||||
primal_params = params
|
||||
|
||||
all_primal_results = call_primitive.bind_with_trace(self.parent_trace, (f_primal, *primals), primal_params)
|
||||
residual_avals, nzs_out, lin_jaxpr = linearize_outs_thunk()
|
||||
num_residuals = len(residual_avals)
|
||||
residuals = all_primal_results[:num_residuals]
|
||||
primals_out = all_primal_results[num_residuals:]
|
||||
residual_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk()
|
||||
num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
|
||||
non_fwd_res = all_primal_results[:num_res_out]
|
||||
primals_out = all_primal_results[num_res_out:]
|
||||
residuals = subs_list2(in_fwd, out_fwd, primals, primals_out, non_fwd_res)
|
||||
|
||||
if isinstance(call_primitive, core.MapPrimitive):
|
||||
in_axes = params['in_axes']
|
||||
out_axes = params['out_axes_thunk']()
|
||||
residual_avals = map(get_aval, residuals)
|
||||
new_in_axes = (*(0 for _ in residual_avals),
|
||||
residual_axes = [in_axes[f1] if f1 is not None else
|
||||
out_axes[f2] if f2 is not None else
|
||||
0 for f1, f2 in zip(in_fwd, out_fwd)]
|
||||
new_in_axes = (*residual_axes, *(None for _ in range(len(env))),
|
||||
*(ax for ax, nz in zip(in_axes, nzs_in) if nz))
|
||||
new_out_axes = (*(ax for ax, nz in zip(out_axes, nzs_out) if nz),)
|
||||
# NOTE: This assumes that the output tangents being zero is a
|
||||
# deterministic function of which input tangents were zero.
|
||||
@as_hashable_function(closure=(new_out_axes))
|
||||
@as_hashable_function(closure=new_out_axes)
|
||||
def new_out_axes_thunk():
|
||||
return new_out_axes
|
||||
params = dict(params,
|
||||
in_axes=new_in_axes,
|
||||
out_axes_thunk=new_out_axes_thunk)
|
||||
params = dict(params, in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk)
|
||||
|
||||
update_params = call_linearize_param_updaters.get(call_primitive)
|
||||
new_params = update_params(params, residual_avals, nzs_in) if update_params else params
|
||||
num_new_args = len(residuals) + len(env)
|
||||
new_params = update_params(params, num_new_args, nzs_in) if update_params else params
|
||||
num_residuals = len(residual_avals)
|
||||
|
||||
@as_hashable_function(closure=(num_residuals, lin_jaxpr))
|
||||
def f_tangent(*args):
|
||||
residuals = args[:num_residuals]
|
||||
consts = args[:num_residuals]
|
||||
nz_tangents = args[num_residuals:]
|
||||
return core.eval_jaxpr(lin_jaxpr, residuals, *nz_tangents)
|
||||
return core.eval_jaxpr(lin_jaxpr, consts, *nz_tangents)
|
||||
# TODO(mattjj,dougalm): this tag is read by DynamicJaxprTrace.process_map to
|
||||
# avoid round-tripping the jaxpr and thus getting grad-of-pmap cache misses.
|
||||
# Remove when we replace the pmap implementation.
|
||||
f_tangent._pmap_tag = isinstance(call_primitive, core.MapPrimitive)
|
||||
|
||||
thing = lu.wrap_init(f_tangent, debug_info=lin_jaxpr.debug_info)
|
||||
nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz]
|
||||
nz_tangents_out = call_primitive.bind_with_trace(
|
||||
self.tangent_trace, (lu.wrap_init(f_tangent,
|
||||
debug_info=lin_jaxpr.debug_info),
|
||||
*residuals, *nz_tangents_in), new_params)
|
||||
self.tangent_trace,
|
||||
(thing,
|
||||
*residuals, *env, *nz_tangents_in), new_params)
|
||||
nz_tangents_out_iter = iter(nz_tangents_out)
|
||||
tangents_out = [next(nz_tangents_out_iter) if nz else Zero.from_primal_value(primal)
|
||||
for nz, primal in zip(nzs_out, primals_out)]
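On the consumer side, `subs_list2` rebuilds the full residual list from the forwarding maps. Its implementation is not part of this diff; the helper below is an assumed equivalent inferred only from the call sites above, shown as a sketch:

# Assumed behavior of subs_list2 (the real jax._src.util.subs_list2 is not shown here):
# pull forwarded residuals out of the primal inputs/outputs and consume the explicitly
# returned residuals for everything else.
def subs_list2_sketch(in_fwd, out_fwd, primals, primals_out, non_fwd_res):
    non_fwd = iter(non_fwd_res)
    out = [primals[f1] if f1 is not None else
           primals_out[f2] if f2 is not None else next(non_fwd)
           for f1, f2 in zip(in_fwd, out_fwd)]
    assert next(non_fwd, None) is None   # every non-forwarded residual is used exactly once
    return out

# e.g. the first residual is forwarded from primals[1]; the second comes from the
# residuals actually returned by the primal call.
assert subs_list2_sketch([1, None], [None, None], ['p0', 'p1'], ['o0'], ['r0']) == ['p1', 'r0']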
@ -762,14 +794,14 @@ def fallback_linearize_rule(_prim: core.Primitive,
|
||||
msg = f"Differentiation rule for '{_prim}' not implemented"
|
||||
raise NotImplementedError(msg)
|
||||
debug_jvp = debug_info("linearize_prim_jvp", jvp, primals, params)
|
||||
return linearize_from_jvp(jvp, _prim.multiple_results, _nonzeros, False, False,
|
||||
debug_jvp, primals, params)
|
||||
return linearize_from_jvp(lu.wrap_init(jvp, debug_info=debug_jvp),
|
||||
_prim.multiple_results, _nonzeros, False, False,
|
||||
primals, params)
|
||||
|
||||
def linearize_from_jvp(jvp: Callable,
|
||||
def linearize_from_jvp(jvp: lu.WrappedFun,
|
||||
multiple_results: bool,
|
||||
nonzeros: Sequence[bool],
|
||||
user_facing_symbolic_zeros: bool, instantiate_input_zeros: bool,
|
||||
debug_info: core.DebugInfo,
|
||||
primals, params):
|
||||
current_name_stack = source_info_util.current_name_stack()
|
||||
with core.take_current_trace() as parent_trace:
|
||||
@ -792,13 +824,18 @@ def linearize_from_jvp(jvp: Callable,
tangent_args = tuple(trace.new_arg(pe.PartialVal.unknown(aval)) if nz else make_zero(aval)
for aval, nz in zip(tangent_avals, nonzeros))
with core.set_current_trace(trace):
out_primals, out_tangents = jvp(primals, tangent_args, **params)
out_primals, out_tangents = jvp.call_wrapped(primals, tangent_args, **params)

if not multiple_results:
out_primals = [out_primals]
out_tangents = [out_tangents]

out_primals = [trace.to_jaxpr_tracer(p).pval.get_known() for p in out_primals]
if any(p is None for p in out_primals):
raise ValueError(
"Linearization failed to produce known values for all output primals. "
"This is typically caused by attempting to differentiate a function "
"that uses an operation that does not support reverse-mode autodiff.")

out_nzs = [type(t) is not zero_type and not trace.to_jaxpr_tracer(t).is_known()
for t in out_tangents]
@ -806,7 +843,7 @@ def linearize_from_jvp(jvp: Callable,
|
||||
out_nz_tracers = [trace.to_jaxpr_tracer(r)
|
||||
for (r, nz) in zip(out_tangents, out_nzs) if nz]
|
||||
in_tracers = [t for t, nz in zip(tangent_args, nonzeros) if nz]
|
||||
jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, debug_info)
|
||||
jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, jvp.debug_info)
|
||||
|
||||
def linearized(residuals, *tangents):
|
||||
nz_tangents_in = [t for (t, nz) in zip(tangents, nonzeros) if nz]
|
||||
@ -973,9 +1010,8 @@ def call_transpose(primitive, params, call_jaxpr: core.Jaxpr, args, ct, _):
|
||||
else:
|
||||
consts = ()
|
||||
all_args, in_tree_def = tree_flatten((consts, args, ct))
|
||||
fun = lu.hashable_partial(lu.wrap_init(backward_pass,
|
||||
debug_info=call_jaxpr.debug_info),
|
||||
call_jaxpr, False)
|
||||
fun = lu.hashable_partial(lu.wrap_init(
|
||||
backward_pass, debug_info=call_jaxpr.debug_info), call_jaxpr, False)
|
||||
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
|
||||
update_params = call_transpose_param_updaters.get(primitive)
|
||||
if update_params:
|
||||
@ -1013,9 +1049,8 @@ def map_transpose(primitive: core.Primitive, params,
|
||||
call_jaxpr: core.Jaxpr, args, ct, _):
|
||||
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
|
||||
# TODO(necula): use the right debug_info for the backwards pass
|
||||
fun = lu.hashable_partial(lu.wrap_init(backward_pass,
|
||||
debug_info=call_jaxpr.debug_info),
|
||||
call_jaxpr, False)
|
||||
fun = lu.hashable_partial(lu.wrap_init(
|
||||
backward_pass, debug_info=call_jaxpr.debug_info), call_jaxpr, False)
|
||||
fun, nz_arg_cts = nonzero_outputs(fun)
|
||||
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
|
||||
# Preserve axis for primal arguments, skip tangents (represented as undefined primals).
|
||||
|
@ -46,7 +46,8 @@ from jax._src.tree_util import (PyTreeDef, treedef_tuple,
|
||||
tree_flatten, tree_structure)
|
||||
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
|
||||
merge_lists, partition_list, OrderedSet,
|
||||
as_hashable_function, weakref_lru_cache, subs_list)
|
||||
as_hashable_function, weakref_lru_cache, subs_list,
|
||||
HashableFunction)
|
||||
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
@ -837,6 +838,11 @@ def tracers_to_jaxpr(
# del getvar # needed to avoid cyclic-reference closure, apparently!
return jaxpr, const_vals, env_vals

@weakref_lru_cache
def move_envvars(jaxpr: Jaxpr, which: tuple[bool, ...]) -> Jaxpr:
constvars, envvars = partition_list(which, jaxpr.constvars)
return jaxpr.replace(constvars=constvars, invars=[*envvars, *jaxpr.invars])

@weakref_lru_cache
def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
"""Moves the constvars to the start of invars."""
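The new `move_envvars` relies on `partition_list` from jax._src.util, which is not shown in this diff. A small sketch of the splitting behavior assumed here (a stand-in helper, not the library function):

# Split a list by a boolean mask, returning (entries where the flag is False,
# entries where the flag is True) -- here, (kept constvars, env vars).
def partition_list_sketch(which, xs):
    false_list, true_list = [], []
    for w, x in zip(which, xs):
        (true_list if w else false_list).append(x)
    return false_list, true_list

constvars, envvars = partition_list_sketch([False, True, False], ['a', 'b', 'c'])
assert (constvars, envvars) == (['a', 'c'], ['b'])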
@ -1840,7 +1846,7 @@ def _inline_literals(
|
||||
|
||||
|
||||
class DynamicJaxprTrace(core.Trace):
|
||||
__slots__ = ("frame",)
|
||||
__slots__ = ("frame", "tag")
|
||||
|
||||
def __init__(self, debug_info: core.DebugInfo):
|
||||
self.frame = JaxprStackFrame(debug_info)
|
||||
@ -1972,17 +1978,18 @@ class DynamicJaxprTrace(core.Trace):
|
||||
self.frame.add_eqn(eqn)
|
||||
return [t for t, (_, keep) in zip(out_tracers, out_type) if keep]
|
||||
|
||||
def process_map(self, map_primitive, f: lu.WrappedFun,
|
||||
tracers: Sequence[core.Tracer], params):
|
||||
def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
|
||||
tracers = map(self.to_jaxpr_tracer, tracers)
|
||||
in_avals = [t.aval for t in tracers]
|
||||
axis_name, axis_size = params['axis_name'], params['axis_size']
|
||||
reduced_in_avals = [core.mapped_aval(axis_size, in_axis, a)
|
||||
if in_axis is not None else a
|
||||
for a, in_axis in zip(in_avals, params['in_axes'])]
|
||||
|
||||
with core.extend_axis_env_nd([(axis_name, params["global_axis_size"])]):
|
||||
jaxpr, reduced_out_avals, consts, () = trace_to_jaxpr_dynamic(
|
||||
f, reduced_in_avals)
|
||||
jaxpr, consts = _linearize_of_pmap_hack(f, jaxpr, consts)
|
||||
ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects)
|
||||
if ordered_effects:
|
||||
raise ValueError("Ordered effects not supported for "
|
||||
@ -2582,3 +2589,13 @@ def inline_jaxpr_into_trace(
|
||||
return tracer
|
||||
return [x.val if isinstance(x, Literal) else tracer_env[x] if x in tracer_env
|
||||
else new_tracer(x) for x in jaxpr.outvars]
|
||||
|
||||
# TODO(mattjj,dougalm): this special handling is to avoid round-tripping the
|
||||
# jaxpr when we do grad-of-pmap. The tag is set by LinearizeTrace.process_call's
|
||||
# handling of pmap. Remove when we replace the pmap implementation.
|
||||
def _linearize_of_pmap_hack(f: lu.WrappedFun, jaxpr, consts) -> tuple[Jaxpr, list]:
|
||||
if (not f.transforms and type(f.f) is HashableFunction and
|
||||
getattr(f.f, '_pmap_tag', None)):
|
||||
_, jaxpr = f.f.closure
|
||||
return convert_constvars_jaxpr(jaxpr), []
|
||||
return jaxpr, consts
|
||||
|
@ -1394,9 +1394,9 @@ def xla_call_jvp_update_params(params, nz_tangents):
new_donated_invars = (*donated_invars, *donated_tangents)
return dict(params, donated_invars=new_donated_invars)

def _xla_call_linearize_update_params(params, residual_avals, nz_tangents):
def _xla_call_linearize_update_params(params, num_new_inputs, nz_tangents):
donated_invars_prev = params['donated_invars']
donated_invars = (*(False for _ in residual_avals),
donated_invars = (*(False for _ in range(num_new_inputs)),
*(d for d, nz in zip(donated_invars_prev, nz_tangents) if nz))
return dict(params, donated_invars=donated_invars)
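The signature change above swaps the list of residual avals for a plain count of newly prepended inputs. A short sketch of the donation-flag padding it performs, with invented flag values:

# One False per newly prepended residual/env input, then the original donation flags
# kept only for inputs whose tangents are nonzero.
donated_invars_prev = (True, False, True)
nz_tangents = (True, False, True)
num_new_inputs = 2
donated_invars = (*(False for _ in range(num_new_inputs)),
                  *(d for d, nz in zip(donated_invars_prev, nz_tangents) if nz))
assert donated_invars == (False, False, True, True)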
@ -3736,14 +3736,14 @@ def _sin_lowering(ctx, x):
return sine(ctx, x)
return _nary_lower_hlo(hlo.sine, ctx, x)

def _sin_p_lin(nzs, x):
def _sin_lin(nzs, x):
nz, = nzs
cos_x = cos(x) # TODO: allow this to happen in the linearized computation (need to fix backward_pass)
return (sin_p.bind(x), nz, cos_x, lambda cos_x_, t: mul(t, cos_x_))

sin_p = standard_unop(_float | _complex, 'sin')
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
ad.primitive_linearizations[sin_p] = _sin_p_lin
ad.primitive_linearizations[sin_p] = _sin_lin
mlir.register_lowering(sin_p, _sin_lowering)
batching.ragged_prop_rules[sin_p] = batching.ragged_mask_elementwise_rule
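The renamed `_sin_lin` follows the rule shape registered via `ad.primitive_linearizations`: given the input-tangent nonzero flags and the primals, return (primal_out, out_nz, residuals, linearized_fn). A standalone sketch of that contract using jax.numpy, applied by hand rather than registered (the toy function below is not JAX internals):

import jax.numpy as jnp

def toy_sin_linearize(nzs, x):
    nz, = nzs                                # nonzero-tangent flag for the single input
    cos_x = jnp.cos(x)                       # residual saved for the linear map
    return jnp.sin(x), nz, cos_x, lambda cos_x_, t: t * cos_x_

primal_out, out_nz, res, lin = toy_sin_linearize((True,), jnp.float32(1.0))
tangent_out = lin(res, jnp.float32(0.5))     # 0.5 * cos(1.0), matching the JVP of sin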
@ -200,7 +200,7 @@ class _BaseMesh:
|
||||
|
||||
_mesh_object_dict = {} # type: ignore
|
||||
|
||||
MeshAxisType = dict[AxisTypes, str | tuple[str, ...]]
|
||||
MeshAxisType = dict[AxisTypes, MeshAxisName | tuple[MeshAxisName, ...]]
|
||||
|
||||
class Mesh(_BaseMesh, contextlib.ContextDecorator):
|
||||
"""Declare the hardware resources available in the scope of this manager.
|
||||
|
@ -2043,18 +2043,6 @@ def _pjit_jvp(primals_in, tangents_in,
|
||||
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
resource_env, donated_invars, name, keep_unused, inline,
|
||||
compiler_options_kvs):
|
||||
if any(isinstance(c, core.MutableArray) for c in jaxpr.consts):
|
||||
jaxpr, mut_primals = pxla._move_mutable_consts(jaxpr)
|
||||
mut_tangents = map(ad_util.zeros_like_jaxval, mut_primals)
|
||||
primals_in = [*primals_in, *mut_primals]
|
||||
tangents_in = [*tangents_in, *mut_tangents]
|
||||
in_shardings = (*in_shardings,) + (UNSPECIFIED,) * len(mut_primals)
|
||||
in_layouts = (*in_layouts,) + (None,) * len(mut_primals)
|
||||
donated_invars = (*donated_invars,) + (False,) * len(mut_primals)
|
||||
|
||||
tangents_in = [ad_util.zeros_like_aval(a) if isinstance(a, AbstractRef) else x
|
||||
for x, a in zip(tangents_in, jaxpr.in_avals)]
|
||||
|
||||
is_nz_tangents_in = [type(t) is not ad.Zero for t in tangents_in]
|
||||
jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr(
|
||||
jaxpr, is_nz_tangents_in, instantiate=False)
|
||||
|
@ -437,6 +437,8 @@ def _swap_jvp(primals: list[Any], tangents: list[Any], **params: Any):
ref_primal, x_primal, *idx = primals
assert isinstance(ref_primal.aval, AbstractRef)
ref_tangent, x_tangent, *_ = tangents
if type(ref_tangent) is ad_util.Zero:
raise Exception("cannot differentiate a ref swap when the ref has a symbolic-zero tangent")
assert isinstance(ref_tangent.aval, AbstractRef)
x_tangent = ad_util.instantiate(x_tangent)
return (swap_p.bind(ref_primal, x_primal, *idx, **params),
@ -657,5 +659,14 @@ mlir.register_lowering(
|
||||
|
||||
# === AD rules for mutable arrays ===
|
||||
|
||||
ad.defjvp(core.mutable_array_p, lambda g, _: core.mutable_array(g))
|
||||
def _mut_jvp(primals, tangents):
|
||||
(init_val,), (init_val_dot,) = primals, tangents
|
||||
primal_out = core.mutable_array_p.bind(init_val)
|
||||
if type(init_val_dot) is ad_util.Zero:
|
||||
tangent_out = core.mutable_array_p.bind(ad_util.zeros_like_aval(init_val_dot.aval))
|
||||
else:
|
||||
tangent_out = core.mutable_array_p.bind(init_val_dot)
|
||||
return primal_out, tangent_out
|
||||
|
||||
ad.primitive_jvps[core.mutable_array_p] = _mut_jvp
|
||||
ad.defjvp(core.freeze_p, lambda g, _: core.freeze(g))
|
||||
|
@ -544,6 +544,8 @@ def _shard_map_staging(
|
||||
return out_tracers
|
||||
pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging
|
||||
|
||||
# TODO add underscore version, for direct-linearize to consume
|
||||
|
||||
def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray:
|
||||
assert isinstance(aval, core.ShapedArray)
|
||||
return aval
|
||||
@ -742,9 +744,8 @@ def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names,
|
||||
out_avals_ = [x.aval for x in jaxpr.outvars]
|
||||
in_nodes_ = map(partial(_xla_shard, ctx, mesh, auto), in_names, ctx.avals_in,
|
||||
in_avals_, in_nodes)
|
||||
new_axis_context = sharding_impls.SPMDAxisContext(
|
||||
mesh, frozenset(mesh.axis_names) - auto
|
||||
)
|
||||
manual_axes = frozenset(mesh.axis_names) - auto
|
||||
new_axis_context = sharding_impls.SPMDAxisContext(mesh, manual_axes)
|
||||
sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)
|
||||
with _extend_axis_env(mesh, auto):
|
||||
out_nodes_, tokens_out = mlir.call_lowering(
|
||||
@ -895,7 +896,6 @@ def _match_spec(mesh: Mesh, check_rep: bool,
|
||||
|
||||
def _match(mesh, check_rep, pspec, x):
|
||||
src = P(mesh.axis_names)
|
||||
# TODO put back (?) needed for rep checking in eager? for now test rewrite
|
||||
return shard_map(_rem_singleton, mesh, (src,), pspec, check_rep=False)(x)
|
||||
|
||||
def _rem_singleton(x): return jnp.squeeze(x, axis=0)
|
||||
@ -914,6 +914,7 @@ class ShardMapTrace(core.Trace):
|
||||
__slots__ = ("mesh", "auto", "check", "context_mesh")
|
||||
|
||||
mesh: Mesh
|
||||
auto: frozenset[AxisName]
|
||||
check: bool
|
||||
context_mesh: AbstractMesh
|
||||
|
||||
@ -927,7 +928,7 @@ class ShardMapTrace(core.Trace):
|
||||
if isinstance(val, ShardMapTracer):
|
||||
return val.val, val.rep
|
||||
elif isinstance(val, Tracer):
|
||||
raise Exception("Shouldn't have any non-shard_map tracers")
|
||||
raise Exception(f"Shouldn't have any non-shard_map tracers: {val}")
|
||||
else:
|
||||
val_ = _unmatch_spec(self.mesh, {}, val, self.context_mesh)
|
||||
return val_, None
|
||||
@ -1609,34 +1610,40 @@ def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun,
|
||||
out_names_thunk, check_rep, rewrite, auto):
|
||||
primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers))
|
||||
nzs_in = tuple(type(t) is not ad.Zero for t in tangents)
|
||||
f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in,
|
||||
f.debug_info)
|
||||
f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in, f.debug_info)
|
||||
f_primal = _promote_scalar_residuals_lin(f_primal, linearize_outs_thunk)
|
||||
tangent_in_names = [ax for ax, nz in zip(in_names, nzs_in) if nz]
|
||||
all_names = _all_newly_manual_mesh_names(mesh, auto, trace)
|
||||
res_names = _all_newly_manual_mesh_names(mesh, auto, trace)
|
||||
|
||||
@as_hashable_function(closure=(linearize_outs_thunk))
|
||||
@as_hashable_function(closure=linearize_outs_thunk)
|
||||
def primal_out_names_thunk():
|
||||
residual_avals, _, _ = linearize_outs_thunk()
|
||||
_, _, _, _, in_fwd, out_fwd = linearize_outs_thunk()
|
||||
out_names = out_names_thunk()
|
||||
# This is incorrect so we set `check_rep=False` as we do in the JVP rule.
|
||||
return (*({0: all_names} for _ in residual_avals), *out_names)
|
||||
num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
|
||||
# This is incorrect so we set `check_rep=False` in the tangent (as in JVP).
|
||||
return (*({0: res_names} for _ in range(num_res_out)), *out_names)
|
||||
primal_params = dict(
|
||||
mesh=mesh, in_names=in_names,
|
||||
out_names_thunk=primal_out_names_thunk, check_rep=check_rep,
|
||||
rewrite=rewrite, auto=auto)
|
||||
all_primal_results = shard_map_p.bind_with_trace(
|
||||
trace.parent_trace, (f_primal,) + tuple(primals), primal_params)
|
||||
residual_avals, nzs_out, lin_jaxpr = linearize_outs_thunk()
|
||||
num_residuals = len(residual_avals)
|
||||
residuals = all_primal_results[:num_residuals]
|
||||
primals_out = all_primal_results[num_residuals:]
|
||||
args_to_promote = [getattr(aval, 'shape', ()) == () for aval in residual_avals]
|
||||
lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote)
|
||||
trace.parent_trace, (f_primal, *primals), primal_params)
|
||||
residual_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk()
|
||||
num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
|
||||
non_fwd_res = all_primal_results[:num_res_out]
|
||||
primals_out = all_primal_results[num_res_out:]
|
||||
residuals = subs_list2(in_fwd, out_fwd, primals, primals_out, non_fwd_res)
|
||||
args_to_promote = [getattr(aval, 'shape', ()) == () and f1 is None and f2 is None
|
||||
for aval, f1, f2 in zip(residual_avals, in_fwd, out_fwd)]
|
||||
with core.extend_axis_env_nd(mesh.shape.items()):
|
||||
lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote)
|
||||
out_names = out_names_thunk()
|
||||
new_in_names = (*({0: all_names} for _ in residual_avals),
|
||||
residual_names = [in_names[f1] if f1 is not None else
|
||||
out_names[f2] if f2 is not None else
|
||||
{0: res_names} for f1, f2 in zip(in_fwd, out_fwd)]
|
||||
new_in_names = (*residual_names, *({} for _ in range(len(env))),
|
||||
*(ax for ax, nz in zip(in_names, nzs_in) if nz))
|
||||
new_out_names = (*(ax for ax, nz in zip(out_names, nzs_out) if nz),)
|
||||
new_out_names = tuple(ax for ax, nz in zip(out_names, nzs_out) if nz)
|
||||
@as_hashable_function(closure=(new_out_names))
|
||||
def tangent_out_names_thunk():
|
||||
return new_out_names
|
||||
@ -1645,15 +1652,14 @@ def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun,
|
||||
out_names_thunk=tangent_out_names_thunk, check_rep=False,
|
||||
rewrite=rewrite, auto=auto)
|
||||
|
||||
# TODO TODO don't round-trip
|
||||
def f_tangent(*args):
|
||||
residuals = args[:num_residuals]
|
||||
nz_tangents = args[num_residuals:]
|
||||
return core.eval_jaxpr(lin_jaxpr, (), *residuals, *nz_tangents)
|
||||
return core.eval_jaxpr(lin_jaxpr, (), *args)
|
||||
|
||||
nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz]
|
||||
nz_tangents_out = shard_map_p.bind_with_trace(trace.tangent_trace,
|
||||
(lu.wrap_init(f_tangent, debug_info=lin_jaxpr.debug_info),
|
||||
*residuals, *nz_tangents_in), tangent_params)
|
||||
*residuals, *env, *nz_tangents_in), tangent_params)
|
||||
nz_tangents_out_iter = iter(nz_tangents_out)
|
||||
tangents_out = [next(nz_tangents_out_iter) if nz else ad.Zero.from_primal_value(primal)
|
||||
for nz, primal in zip(nzs_out, primals_out)]
|
||||
@ -1663,13 +1669,13 @@ ad.LinearizeTrace.process_shard_map = _shard_map_linearize
@lu.transformation2
def _promote_scalar_residuals_lin(f, linearize_outs_thunk, *args, **kwargs):
ans = f(*args, **kwargs)
residual_avals, _, _ = linearize_outs_thunk()
num_residuals = len(residual_avals)
residuals = ans[:num_residuals]
primals = ans[num_residuals:]
residuals = tuple(jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x
for x in residuals)
return residuals + primals
_, _, _, _, in_fwd, out_fwd = linearize_outs_thunk()
num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
residuals = ans[:num_res_out]
primals = ans[num_res_out:]
residuals = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x
for x in residuals]
return *residuals, *primals
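For reference, the scalar promotion above boils down to one call: jax.lax.broadcast prepends a length-1 axis, so a scalar residual gains an axis 0 that can carry the {0: ...} residual sharding used by the primal shard_map. A minimal runnable check:

import jax, jax.numpy as jnp

x = jnp.float32(3.0)                 # scalar residual, shape ()
y = jax.lax.broadcast(x, (1,))       # shape (1,)
assert x.shape == () and y.shape == (1,)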
|
||||
@lu.transformation2
|
||||
def _promote_scalar_residuals(f: Callable, *args, **kwargs):
|
||||
@ -1798,10 +1804,10 @@ def _partial_eval_jaxpr_custom_rule(
|
||||
_, ins_staged = partition_list(inst_in, eqn.invars)
|
||||
_, out_binders_staged = partition_list(inst_out, eqn.outvars)
|
||||
newvar = core.gensym()
|
||||
params_known, params_staged, all_names = _pe_custom_params(
|
||||
params_known, params_staged, res_names = _pe_custom_params(
|
||||
unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd, which,
|
||||
dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged))
|
||||
residuals = [newvar(_unshard_aval(mesh, {0: all_names}, var.aval))
|
||||
residuals = [newvar(_unshard_aval(mesh, {0: res_names}, var.aval))
|
||||
for var, w in zip(jaxpr_staged.invars[:num_res], which) if w]
|
||||
eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
|
||||
eqn.primitive, params_known, jaxpr_known.effects,
|
||||
@ -1853,10 +1859,10 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
|
||||
# prune inputs to jaxpr_known according to unks_in
|
||||
mesh = params_known['mesh']
|
||||
auto = params_known['auto']
|
||||
all_names = _all_newly_manual_mesh_names(mesh, auto)
|
||||
res_names_ = _all_newly_manual_mesh_names(mesh, auto)
|
||||
in_names_known, _ = partition_list(unks_in, params_known['in_names'])
|
||||
_, out_names_known = partition_list(kept_outs_known, params_known['out_names'])
|
||||
out_names_known = out_names_known + [{0: all_names}] * sum(which)
|
||||
out_names_known = out_names_known + [{0: res_names_}] * sum(which)
|
||||
new_params_known = dict(params_known, in_names=tuple(in_names_known),
|
||||
out_names=tuple(out_names_known))
|
||||
|
||||
@ -1864,12 +1870,12 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
|
||||
_, in_names_staged = partition_list(inst_in, params_staged['in_names'])
|
||||
res_names = [in_names_known[f1] if f1 is not None else
|
||||
out_names_known[f2] if f2 is not None else
|
||||
{0: all_names} for f1, f2 in zip(in_fwd, out_fwd)]
|
||||
{0: res_names_} for f1, f2 in zip(in_fwd, out_fwd)]
|
||||
in_names_staged = res_names + in_names_staged
|
||||
_, out_names_staged = partition_list(kept_outs_staged, params_staged['out_names'])
|
||||
new_params_staged = dict(params_staged, in_names=tuple(in_names_staged),
|
||||
out_names=tuple(out_names_staged), check_rep=False)
|
||||
return new_params_known, new_params_staged, all_names
|
||||
return new_params_known, new_params_staged, res_names_
|
||||
|
||||
# TODO(mattjj): remove this mechanism when we revise mesh scopes
|
||||
def _all_mesh_names_except_spmd(
|
||||
@ -1880,15 +1886,21 @@ def _all_mesh_names_except_spmd(
return tuple(name for name in mesh.axis_names if name not in spmd_names and
name not in auto)

# TODO(mattjj): remove this mechanism when we revise mesh scopes
def _all_newly_manual_mesh_names(
mesh: Mesh, auto: frozenset[AxisName], trace=None
) -> tuple[AxisName, ...]:
axis_env = core.get_axis_env()
spmd_names = axis_env.spmd_axis_names
axis_sizes = axis_env.axis_sizes
return tuple(name for name in mesh.axis_names if name not in spmd_names and
name not in auto and name not in axis_sizes)
if not (ctx_mesh := get_abstract_mesh()).empty:
del mesh
already_manual_names = set(ctx_mesh.axis_types.get(AxisTypes.Manual, ()))
return tuple(name for name in ctx_mesh.axis_names
if name not in auto | already_manual_names)
else:
# TODO(mattjj): remove this mechanism when we revise mesh scopes
axis_env = core.get_axis_env()
vmap_spmd_names = set(axis_env.spmd_axis_names)
already_manual_names = set(axis_env.axis_sizes) # may include vmap axis_names
return tuple(name for name in mesh.axis_names
if name not in auto | vmap_spmd_names | already_manual_names)
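Both branches above reduce to the same set arithmetic over axis names: start from the mesh's axes and drop anything automatic or already manual. A sketch with invented axis names, standing in for the mesh and axis-env state:

mesh_axis_names = ('i', 'j', 'k')
auto = frozenset({'k'})                    # axes left automatic by this shard_map
already_manual_names = {'i'}               # axes already manual (outer shard_map) or vmapped
res_names = tuple(n for n in mesh_axis_names
                  if n not in auto | already_manual_names)
assert res_names == ('j',)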
|
||||
|
||||
# DCE
|
||||
|
||||
|
@ -43,13 +43,14 @@ __ = pe.PartialVal.unknown(ShapedArray((), np.float32))
|
||||
def call(f, *args):
|
||||
return jit(f)(*args)
|
||||
|
||||
@util.curry
|
||||
def core_call(f, *args):
|
||||
args, in_tree = jax.tree.flatten(args)
|
||||
dbg = debug_info("core_call_test", f, args, {})
|
||||
f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, debug_info=dbg), in_tree)
|
||||
out = core.call_p.bind(f, *args)
|
||||
return jax.tree.unflatten(out_tree(), out)
|
||||
# call = core_call
|
||||
core_call = util.curry(core_call)
|
||||
|
||||
@util.curry
|
||||
def core_closed_call(f, *args):
|
||||
|
@ -131,51 +131,6 @@ class MutableArrayTest(jtu.JaxTestCase):
|
||||
out = f()
|
||||
self.assertAllClose(out, jnp.array([2., 0., 1.]), check_dtypes=False)
|
||||
|
||||
@parameterized.parameters([True, False])
|
||||
def test_refs_in_vjps(self, jit):
|
||||
def gradient_history_calculator_fwd(x, ref):
|
||||
return x, ref
|
||||
|
||||
def gradient_history_calculator_bwd(amax_history, grad_output):
|
||||
amax_update = jnp.max(jnp.abs(grad_output))
|
||||
shifted = jnp.roll(amax_history[:], 1)
|
||||
shifted = shifted.at[0].set(amax_update)
|
||||
amax_history[:] = shifted
|
||||
amax_from_history = jnp.max(amax_history[:])
|
||||
grad_output = grad_output / amax_from_history
|
||||
return grad_output, None
|
||||
|
||||
@jax.custom_vjp
|
||||
def gradient_history_calculator(x, ref):
|
||||
return x
|
||||
|
||||
gradient_history_calculator.defvjp(
|
||||
gradient_history_calculator_fwd,
|
||||
gradient_history_calculator_bwd)
|
||||
|
||||
class DotOp:
|
||||
def __init__(self):
|
||||
self.amax_history = core.mutable_array(jnp.zeros(5,))
|
||||
|
||||
def forward(self, x, y):
|
||||
out = jnp.dot(x, y)
|
||||
out = gradient_history_calculator(out, self.amax_history)
|
||||
return out
|
||||
|
||||
dot_op = DotOp()
|
||||
x_top = jnp.ones((5,))
|
||||
y_top = jnp.ones((5,))
|
||||
|
||||
def loss(x, y):
|
||||
return dot_op.forward(x, y).sum()
|
||||
|
||||
if jit:
|
||||
loss = jax.jit(loss)
|
||||
|
||||
for i in range(3):
|
||||
jax.grad(loss, (0,1))(x_top, y_top)
|
||||
self.assertAllClose(dot_op.amax_history[:], jnp.zeros((5,)).at[:i+1].set(1.0), check_dtypes=False)
|
||||
|
||||
@parameterized.parameters([True, False])
|
||||
def test_scan_internal_mut_array(self, jit):
|
||||
def body_fun(_, x):
|
||||
@ -371,17 +326,18 @@ class MutableArrayErrorsTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(ValueError, "x_ref and y_ref"):
|
||||
f(x_ref, x_ref)
|
||||
|
||||
@parameterized.parameters([False, True])
|
||||
def test_argument_aliases_custom_vjp_fwd(self, jit):
|
||||
@jax.custom_vjp
|
||||
def f(x_ref, y_ref):
|
||||
...
|
||||
f.defvjp(lambda x_ref, y_ref: (None, None), lambda _, g: (None, None))
|
||||
if jit:
|
||||
f = jax.jit(f)
|
||||
x_ref = core.mutable_array(0.)
|
||||
with self.assertRaisesRegex(ValueError, "x_ref and y_ref"):
|
||||
jax.vjp(f, x_ref, x_ref)
|
||||
# TODO(mattjj): re-enable test after direct-linearize
|
||||
# @parameterized.parameters([False, True])
|
||||
# def test_argument_aliases_custom_vjp_fwd(self, jit):
|
||||
# @jax.custom_vjp
|
||||
# def f(x_ref, y_ref):
|
||||
# ...
|
||||
# f.defvjp(lambda x_ref, y_ref: (None, None), lambda _, g: (None, None))
|
||||
# if jit:
|
||||
# f = jax.jit(f)
|
||||
# x_ref = core.mutable_array(0.)
|
||||
# with self.assertRaisesRegex(ValueError, "x_ref and y_ref"):
|
||||
# jax.vjp(f, x_ref, x_ref)
|
||||
|
||||
# TODO(mattjj): add test test_closure_and_argument_aliases_custom_vjp
|
||||
|
||||
|
@ -51,6 +51,9 @@ from jax._src.interpreters import pxla
|
||||
from jax._src.lax import parallel
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.util import safe_map, safe_zip
|
||||
from jax._src import util
|
||||
from jax.api_util import flatten_fun_nokwargs, debug_info
|
||||
from jax._src import linear_util as lu
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jtu.request_cpu_devices(8)
|
||||
|
@ -2205,20 +2205,23 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
mesh = jtu.create_mesh((2, 2), ('i', 'j'))
|
||||
|
||||
def g(x):
|
||||
# manual: 'i', 'j'
|
||||
return x * x
|
||||
|
||||
def h(x):
|
||||
# auto: 'j', manual: 'i'
|
||||
return shard_map(g, mesh,
|
||||
in_specs=P(None, 'j'),
|
||||
out_specs=P(None, 'j'))(x)
|
||||
in_specs=P(None, 'j'),
|
||||
out_specs=P(None, 'j'))(x)
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
# auto: 'i', 'j'
|
||||
return shard_map(h, mesh,
|
||||
in_specs=P('i', None),
|
||||
out_specs=P('i', None),
|
||||
check_rep=False,
|
||||
auto=frozenset({'j'}))(x).sum()
|
||||
in_specs=P('i', None),
|
||||
out_specs=P('i', None),
|
||||
check_rep=False,
|
||||
auto=frozenset({'j'}))(x).sum()
|
||||
|
||||
v = jnp.arange(32.).reshape(4, 8)
|
||||
v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
|
||||
@ -2814,7 +2817,7 @@ def sample(num: int, make_gen: Callable[[], Chooser]) -> Iterator[CaseSpec]:
|
||||
name, *case = sample_one(rng, make_gen())
|
||||
if name not in seen:
|
||||
seen.add(name)
|
||||
yield name, *case
|
||||
yield case
|
||||
|
||||
# To sample one test spec, we run the generator, getting back sequences of
|
||||
# options from it and sending in our choices from those options until finally a
|
||||
@ -2929,7 +2932,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
|
||||
def make_mesh(mesh_shape):
|
||||
return jtu.create_mesh(tuple(mesh_shape.values()), tuple(mesh_shape))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@parameterized.parameters(
|
||||
sample(jtu.NUM_GENERATED_CASES.value, sample_shmap))
|
||||
def test_eager_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):
|
||||
mesh = self.make_mesh(mesh)
|
||||
@ -2938,7 +2941,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
|
||||
expected = ref(fun, mesh, in_specs, out_specs)(*args)
|
||||
self.assertAllClose(expected, out, check_dtypes=False)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@parameterized.parameters(
|
||||
sample(jtu.NUM_GENERATED_CASES.value, sample_shmap))
|
||||
def test_jit_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):
|
||||
mesh = self.make_mesh(mesh)
|
||||
@ -2947,9 +2950,9 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
|
||||
expected = ref(fun, mesh, in_specs, out_specs)(*args)
|
||||
self.assertAllClose(expected, out, check_dtypes=False)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
(name + f'_check_rep={check_rep}', *params, check_rep)
|
||||
for (name, *params) in sample(jtu.NUM_GENERATED_CASES.value, sample_shmap)
|
||||
@parameterized.parameters(
|
||||
(*params, check_rep)
|
||||
for params in sample(jtu.NUM_GENERATED_CASES.value, sample_shmap)
|
||||
for check_rep in [True, False]
|
||||
)
|
||||
@jax.default_matmul_precision("float32")
|
||||
@ -2961,7 +2964,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
|
||||
f = jax.jit(f)
|
||||
jtu.check_grads(f, args, order=2, atol=1e-2, rtol=1e-2)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@parameterized.parameters(
|
||||
sample(jtu.NUM_GENERATED_CASES.value, sample_shmap))
|
||||
@jax.default_matmul_precision("float32")
|
||||
def test_grads_closure(self, fun, mesh, jit, in_specs, out_specs, args, _):
|
||||
@ -2980,7 +2983,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
|
||||
return g(*args)
|
||||
jtu.check_grads(f, (0.2, *closed_over_args), order=2, atol=1e-2, rtol=1e-2)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@parameterized.parameters(
|
||||
sample(jtu.NUM_GENERATED_CASES.value,
|
||||
partial(sample_shmap_batched, 5)))
|
||||
def test_vmap(self, bdims, fun, mesh, jit, in_specs, out_specs, args, ref):
|
||||
@ -3003,7 +3006,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
|
||||
tol = 1e-2 if jtu.test_device_matches(['tpu']) else None
|
||||
self.assertAllClose(ans, expected, check_dtypes=False, atol=tol, rtol=tol)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@parameterized.parameters(
|
||||
sample(jtu.NUM_GENERATED_CASES.value,
|
||||
partial(sample_shmap_batched, 5)))
|
||||
def test_vmap_closure(self, bdims, fun, mesh, jit, in_specs, out_specs, args, _):
|
||||
|