Add support for reverse-mode AD of xmap
commit 478828a21a
parent b9aaff4870
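For context, a minimal usage sketch (not part of the commit) of what this change enables: taking reverse-mode gradients of an xmapped function with jax.grad. The function, shapes, and axis names below are illustrative, loosely mirroring the tests added at the bottom of this diff.

import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap

# Axis 0 of both inputs is mapped to the named axis 'i'; each instance of 'i'
# computes a single dot product, so f returns one value per 'i'.
f = xmap(lambda x, y: jnp.vdot(x, y),
         in_axes=[['i', ...], ['i', ...]], out_axes=['i'])

x = jnp.arange(12, dtype=jnp.float32).reshape((3, 4)) / 100
y = jnp.ones((3, 4), jnp.float32)

# Reverse-mode AD through xmap, which is what this commit adds (forward mode was
# already handled via ad.JVPTrace.process_xmap, visible as context below).
loss = lambda x, y: f(x, y).sum()
dx, dy = jax.grad(loss, argnums=(0, 1))(x, y)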
@@ -43,7 +43,7 @@ from ..lib import xla_bridge as xb
 from ..lib import xla_client as xc
 from .._src.util import (safe_map, safe_zip, HashableFunction,
                          as_hashable_function, unzip2, distributed_debug_log,
-                         tuple_insert, moveaxis)
+                         tuple_insert, moveaxis, split_list, wrap_name)
 from .._src.lax.parallel import _axis_index_translation_rule
 from .. import lax

@@ -846,6 +846,44 @@ core.axis_substitution_rules[xmap_p] = _xmap_axis_subst
 ad.JVPTrace.process_xmap = ad.JVPTrace.process_call  # type: ignore
 ad.call_param_updaters[xmap_p] = ad.call_param_updaters[xla.xla_call_p]
 
+def _xmap_transpose(params, call_jaxpr, args, cts_in, cts_in_avals, reduce_axes):
+  all_args, in_tree_def = tree_flatten(((), args, cts_in))  # empty consts
+  fun = lu.hashable_partial(
+      lu.wrap_init(ad.backward_pass),
+      call_jaxpr, reduce_axes + tuple(params['global_axis_sizes'].keys()))
+  fun, nz_arg_cts = ad.nonzero_outputs(fun)
+  fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
+  # Preserve axis for primal arguments, skip tangents (represented as undefined primals).
+  in_axes, out_axes = params['in_axes'], params['out_axes']
+  new_in_axes = (*(axis for axis, x in zip(in_axes, args) if not ad.is_undefined_primal(x)),
+                 *(axis for axis, x in zip(out_axes, cts_in) if type(x) is not ad.Zero))
+  # NOTE: This assumes that the output cotangents being zero is a deterministic
+  #       function of which input cotangents were zero.
+  @as_hashable_function(closure=(in_axes, tuple(type(c) is ad.Zero for c in cts_in)))
+  def out_axes_thunk():
+    return tuple(axis for axis, nz in zip(in_axes, nz_arg_cts()) if nz)
+  new_params = dict(params,
+                    name=wrap_name(params['name'], 'transpose'),
+                    in_axes=new_in_axes,
+                    out_axes_thunk=out_axes_thunk,
+                    donated_invars=(False,) * len(new_in_axes),
+                    spmd_out_axes_thunk=None)
+  del new_params['out_axes']
+  del new_params['spmd_out_axes']
+  out_flat = xmap_p.bind(fun, *all_args, **new_params)
+  arg_cts = tree_unflatten(out_tree(), out_flat)
+
+  axis_resource_count = _get_axis_resource_count(params['axis_resources'],
+                                                 params['resource_env'])
+  local_axis_sizes = {axis: axis_resource_count[axis].to_local(global_size)
+                      for axis, global_size in params['global_axis_sizes'].items()}
+  def unmap_zero(zero, axes):
+    return ad.Zero(_insert_aval_axes(zero.aval, axes, local_axis_sizes))
+  return tuple(unmap_zero(arg_ct, in_axis) if type(arg_ct) is ad.Zero else arg_ct
+               for arg_ct, in_axis in zip(arg_cts, in_axes))
+ad.primitive_transposes[xmap_p] = _xmap_transpose
+
+
 def _typecheck_xmap(
     *in_avals, call_jaxpr, name, in_axes, out_axes, donated_invars,
     global_axis_sizes, axis_resources, resource_env, backend,
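A brief, hedged illustration (not from the commit) of what registering ad.primitive_transposes[xmap_p] provides. Reverse-mode AD in JAX factors into linearization followed by transposition, so once the xmap call primitive has a transpose rule, a linear function that goes through xmap can be transposed with the public jax.linear_transpose API. The weight matrix, shapes, and axis names here are made up for illustration.

import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap

W = jnp.ones((4, 5), jnp.float32)

# Linear in x; the leading axis of x is mapped to the named axis 'i'.
f = xmap(lambda x: x @ W, in_axes=['i', ...], out_axes=['i', ...])

x = jnp.zeros((3, 4), jnp.float32)   # only the shape/dtype of the primal matters here
ct = jnp.ones((3, 5), jnp.float32)   # a cotangent for the output
f_transposed = jax.linear_transpose(f, x)
(x_ct,) = f_transposed(ct)           # cotangent for the input, shape (3, 4)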
@@ -917,7 +955,7 @@ pxla.custom_resource_typing_rules[xmap_p] = _resource_typing_xmap
 # This is DynamicJaxprTrace.process_map with some very minor modifications
 def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
   from jax.interpreters.partial_eval import (
-    trace_to_subjaxpr_dynamic, DynamicJaxprTracer, source_info_util,
+    trace_to_subjaxpr_dynamic, DynamicJaxprTracer,
     convert_constvars_jaxpr, new_jaxpr_eqn)
   assert primitive is xmap_p
   in_avals = [t.aval for t in tracers]
@@ -966,6 +1004,142 @@ def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
 pe.DynamicJaxprTrace.process_xmap = _dynamic_jaxpr_process_xmap  # type: ignore
 
 
+@lu.transformation_with_aux
+def out_local_named_shapes(local_axes, *args, **kwargs):
+  ans = yield args, kwargs
+  ans_axes = [frozenset(a.aval.named_shape) & local_axes for a in ans]
+  yield ans, ans_axes
+
+@lu.transformation_with_aux
+def hide_units(unit_args, *args, **kwargs):
+  ans = yield restore_units(unit_args, args), kwargs
+  yield filter_units(ans)
+
+def filter_units(vals):
+  vals_no_units = [v for v in vals if v is not core.unit]
+  vals_is_unit = [v is core.unit for v in vals]
+  return vals_no_units, vals_is_unit
+
+def restore_units(is_unit, vals):
+  vals_it = iter(vals)
+  vals_with_units = [core.unit if u else next(vals_it) for u in is_unit]
+  try:
+    next(vals_it)
+    raise RuntimeError("Expected the iterator to be exhausted")
+  except StopIteration:
+    return vals_with_units
+
+
+def _jaxpr_trace_process_xmap(self, primitive, f: lu.WrappedFun, tracers, params):
+  from jax.interpreters.partial_eval import (
+    PartialVal, JaxprTracer, _drop_invars, _dce_open_jaxpr,
+    convert_constvars_jaxpr, new_eqn_recipe)
+  assert primitive is xmap_p
+  in_axes = params['in_axes']
+  donated_invars = params['donated_invars']
+  global_axis_sizes = params['global_axis_sizes']
+
+  in_pvals = [t.pval for t in tracers]
+  in_pvals = [pval if pval.is_known()
+              else PartialVal.unknown(_delete_aval_axes(pval[0], axes))
+              for pval, axes in zip(in_pvals, in_axes)]
+
+  const_axes_s = lu.Store()
+  def app(f, *args):
+    args_no_units, in_units = filter_units(args)
+    f, out_units = hide_units(f, tuple(in_units))
+    f, out_named_shapes = out_local_named_shapes(f, frozenset(global_axis_sizes))
+    out_axes_thunk = params['out_axes_thunk']
+    @as_hashable_function(closure=out_axes_thunk)
+    def new_out_axes_thunk():
+      out_axes = out_axes_thunk()
+      axes_units, const_units = split_list(out_units(), [len(out_axes)])
+      assert not any(const_units)
+      num_consts = len(const_units)
+      out_axes_no_units = [a for a, u in zip(out_axes, axes_units) if not u]
+      const_axes = [
+          AxisNamePos(zip(sort_named_shape, range(len(sort_named_shape))),
+                      user_repr=f'<internal: {sort_named_shape}>')
+          for named_shape in out_named_shapes()[-num_consts:]
+          # We sort here to make the iteration order deterministic
+          for sort_named_shape in [sorted(named_shape, key=str)]
+      ]
+      if not const_axes_s:  # NOTE: This can be called multiple times
+        const_axes_s.store(const_axes)
+      assert const_axes_s.val == const_axes
+      return (*out_axes_no_units, *const_axes)
+    pe_params = dict(
+        params,
+        in_axes=tuple(a for a, u in zip(in_axes, in_units) if not u),
+        donated_invars=tuple(a for a, u in zip(donated_invars, in_units) if not u),
+        out_axes_thunk=new_out_axes_thunk)
+    outs_no_units = primitive.bind(f, *args_no_units, **pe_params)
+    new_out_axes_thunk()  # Make sure it is called at least once to compute const_axes
+    return restore_units(out_units(), outs_no_units)
+
+  jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
+      f, in_pvals, app, instantiate=False)
+
+  out_axes = params['out_axes_thunk']()
+  const_axes = const_axes_s.val
+  axis_resource_count = _get_axis_resource_count(params['axis_resources'],
+                                                 params['resource_env'])
+  local_axis_sizes = {axis: axis_resource_count[axis].to_local(global_size)
+                      for axis, global_size in global_axis_sizes.items()}
+  out_pvals = [pval if pval.is_known() else
+               PartialVal.unknown(_insert_aval_axes(pval[0], axes, local_axis_sizes))
+               for pval, axes in zip(out_pvals, out_axes)]
+
+  with core.extend_axis_env_nd(global_axis_sizes.items()):
+    # Skip known invars and outvars, and lift constants as regular invars
+    in_knowns = tuple(t.pval.is_known() for t in it.chain(env_tracers, tracers))
+    out_unknowns = tuple(not pval.is_known() for pval in out_pvals)
+    jaxpr = _drop_invars(jaxpr, in_knowns)
+    jaxpr = _dce_open_jaxpr(jaxpr, out_unknowns, drop_outputs=True)
+    jaxpr = convert_constvars_jaxpr(jaxpr)
+
+  # Known tracers get propagated as if they were constants
+  known_tracers_out = [self.new_const(pval.get_known()) for pval in out_pvals
+                       if pval.is_known()]
+
+  # I'm not 100% sure if that's correct, but it is an assumption that
+  # JaxprTrace.process_call already makes.
+  if any(t.pval.is_known() for t in env_tracers):
+    raise AssertionError("Please open a bug report!")
+  # Unknown tracers need to have the jaxpr set up as their recipe
+  unknown_tracers_in = (*env_tracers, *(t for t in tracers if not t.pval.is_known()))
+  unknown_tracers_out = [JaxprTracer(self, pval, None) for pval in out_pvals
+                         if not pval.is_known()]
+  const_tracers = map(self.new_instantiated_const, consts)
+
+  # Set up new params
+  new_in_axes = (*const_axes,
+                 *(None for _ in env_tracers),
+                 *(axis for axis, t in zip(in_axes, tracers)
+                   if not t.pval.is_known()))
+  new_out_axes = tuple(axis for axis, pval in zip(out_axes, out_pvals)
+                       if not pval.is_known())
+
+  assert params['spmd_in_axes'] is None and params['spmd_out_axes_thunk'] is None
+  new_params = dict(
+      params,
+      call_jaxpr=jaxpr,
+      donated_invars=(*(False for _ in const_tracers),
+                      *(d for d, t in zip(donated_invars, tracers) if not t.pval.is_known())),
+      in_axes=new_in_axes,
+      out_axes=new_out_axes,
+      spmd_out_axes=None)
+  del new_params['out_axes_thunk']
+  del new_params['spmd_out_axes_thunk']
+
+  eqn = new_eqn_recipe((*const_tracers, *unknown_tracers_in),
+                       unknown_tracers_out,
+                       primitive, new_params, source_info_util.current())
+  for t in unknown_tracers_out: t.recipe = eqn
+  return pe._zip_knowns(known_tracers_out, unknown_tracers_out, out_unknowns)
+pe.JaxprTrace.process_xmap = _jaxpr_trace_process_xmap
+
+
 def _batch_trace_update_spmd_axes(
     spmd_in_axes, spmd_out_axes_thunk,
     axis_name, dims, dims_out_thunk):
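The filter_units/restore_units helpers above implement a small mask-and-restore pattern: drop core.unit placeholders before binding the primitive, remember where they were, and reinsert them afterwards. A standalone sketch of the same round trip, with a stand-in UNIT sentinel in place of jax.core.unit:

# Stand-in sentinel, purely for illustration.
UNIT = object()

def filter_units(vals):
  vals_no_units = [v for v in vals if v is not UNIT]
  vals_is_unit = [v is UNIT for v in vals]
  return vals_no_units, vals_is_unit

def restore_units(is_unit, vals):
  vals_it = iter(vals)
  return [UNIT if u else next(vals_it) for u in is_unit]

vals = [1.0, UNIT, 2.0, UNIT]
no_units, mask = filter_units(vals)   # [1.0, 2.0], [False, True, False, True]
assert restore_units(mask, no_units) == vals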
@@ -515,13 +515,26 @@ class XMapTest(XMapTestCase):
     y = rng.randn(*yshape)
     self.assertAllClose(fm(x, y), fref(x, y))
 
-  def testJVP(self):
+  def testAutodiffBroadcast(self):
     f = xmap(lambda x, y: jnp.cos(lax.dot(x, jnp.sin(y),
                                           precision=lax.Precision.HIGHEST)),
              in_axes=[['i', ...], {}], out_axes=['i', ...])
     x = jnp.arange(12, dtype=jnp.float32).reshape((3, 4)) / 100
     y = jnp.arange(20, dtype=jnp.float32).reshape((4, 5)) / 100
     jtu.check_grads(f, (x, y), order=2, modes=['fwd'])
+    jtu.check_grads(f, (x, y), order=1, modes=['rev'])
+    with self.assertRaises(AssertionError):
+      # Second-order reverse-mode differentiation seems to be broken,
+      # likely due to the transpose of psum being defined incorrectly.
+      jtu.check_grads(f, (x, y), order=2, modes=['rev'])
+
+  def testAutodiffNoBroadcast(self):
+    f = xmap(lambda x, y: jnp.cos(lax.dot(x, jnp.sin(y),
+                                          precision=lax.Precision.HIGHEST)),
+             in_axes=[['i', ...], [None, 'i']], out_axes=['i'])
+    x = jnp.arange(12, dtype=jnp.float32).reshape((3, 4)) / 100
+    y = jnp.arange(12, dtype=jnp.float32).reshape((4, 3)) / 100
+    jtu.check_grads(f, (x, y), order=2)
 
   @jtu.with_and_without_mesh
   def testNamedShape(self, mesh, axis_resources):
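As a user-level counterpart to testAutodiffBroadcast above, here is a hedged sketch (not part of the commit) of the broadcast case the new transpose rule has to handle: y carries no named axes, so its cotangent has to be summed over the mapped axis 'i' during transposition, which is why _xmap_transpose appends the xmap axis names to reduce_axes for the backward pass. The shapes and axis names are illustrative.

import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap

# x is mapped over 'i'; y is broadcast to every instance of 'i'.
f = xmap(lambda x, y: jnp.vdot(x, y),
         in_axes=[['i', ...], {}], out_axes=['i'])

x = jnp.ones((3, 4), jnp.float32)
y = jnp.ones((4,), jnp.float32)

out, vjp_fn = jax.vjp(f, x, y)
dx, dy = vjp_fn(jnp.ones((3,), jnp.float32))
# dx has shape (3, 4); dy has shape (4,) and sums contributions across 'i'.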