Add support for reverse-mode AD of xmap

This commit is contained in:
Adam Paszke 2021-07-15 13:10:05 +00:00
parent b9aaff4870
commit 478828a21a
2 changed files with 190 additions and 3 deletions

View File

@ -43,7 +43,7 @@ from ..lib import xla_bridge as xb
from ..lib import xla_client as xc
from .._src.util import (safe_map, safe_zip, HashableFunction,
as_hashable_function, unzip2, distributed_debug_log,
tuple_insert, moveaxis)
tuple_insert, moveaxis, split_list, wrap_name)
from .._src.lax.parallel import _axis_index_translation_rule
from .. import lax
@ -846,6 +846,44 @@ core.axis_substitution_rules[xmap_p] = _xmap_axis_subst
ad.JVPTrace.process_xmap = ad.JVPTrace.process_call # type: ignore
ad.call_param_updaters[xmap_p] = ad.call_param_updaters[xla.xla_call_p]
def _xmap_transpose(params, call_jaxpr, args, cts_in, cts_in_avals, reduce_axes):
  """Transpose rule for ``xmap_p``.

  Runs ``ad.backward_pass`` over ``call_jaxpr`` inside a fresh ``xmap_p`` bind.
  The new xmap receives the defined primal arguments and the non-zero output
  cotangents, and produces the cotangents of the original inputs.

  Args:
    params: parameters of the xmap equation being transposed.
    call_jaxpr: jaxpr of the xmapped computation.
    args: inputs of the equation; tangent inputs show up as undefined primals.
    cts_in: cotangents of the equation outputs (possibly ``ad.Zero``).
    cts_in_avals: not used by this rule.
    reduce_axes: axis names forwarded to ``ad.backward_pass``, extended here
      with every axis name in ``params['global_axis_sizes']``.

  Returns:
    A tuple of input cotangents. Entries that are ``ad.Zero`` get the mapped
    axes re-inserted into their avals.
  """
  in_axes = params['in_axes']
  out_axes = params['out_axes']
  global_axis_sizes = params['global_axis_sizes']

  flat_args, in_tree = tree_flatten(((), args, cts_in))  # empty consts
  backward = lu.hashable_partial(
      lu.wrap_init(ad.backward_pass),
      call_jaxpr, reduce_axes + tuple(global_axis_sizes.keys()))
  backward, nonzero_cts = ad.nonzero_outputs(backward)
  backward, out_tree = flatten_fun_nokwargs(backward, in_tree)

  # Keep the axes of primal arguments; tangent inputs (undefined primals) and
  # symbolically-zero cotangents are not passed to the transposed xmap.
  primal_axes = [axis for axis, arg in zip(in_axes, args)
                 if not ad.is_undefined_primal(arg)]
  ct_axes = [axis for axis, ct in zip(out_axes, cts_in)
             if type(ct) is not ad.Zero]
  new_in_axes = (*primal_axes, *ct_axes)

  # NOTE: This assumes that the output cotangents being zero is a deterministic
  #       function of which input cotangents were zero.
  zero_ct_mask = tuple(type(ct) is ad.Zero for ct in cts_in)
  @as_hashable_function(closure=(in_axes, zero_ct_mask))
  def out_axes_thunk():
    return tuple(axis for axis, nz in zip(in_axes, nonzero_cts()) if nz)

  new_params = dict(params)
  new_params.update(
      name=wrap_name(params['name'], 'transpose'),
      in_axes=new_in_axes,
      out_axes_thunk=out_axes_thunk,
      donated_invars=(False,) * len(new_in_axes),
      spmd_out_axes_thunk=None)
  # The transposed bind uses thunks, so the eager output axes are dropped.
  del new_params['out_axes']
  del new_params['spmd_out_axes']

  flat_cts = xmap_p.bind(backward, *flat_args, **new_params)
  arg_cts = tree_unflatten(out_tree(), flat_cts)

  resource_count = _get_axis_resource_count(params['axis_resources'],
                                            params['resource_env'])
  local_axis_sizes = {}
  for axis, global_size in global_axis_sizes.items():
    local_axis_sizes[axis] = resource_count[axis].to_local(global_size)

  def with_mapped_axes(ct, axes):
    # Only symbolic zeros need fixing up: their avals lack the mapped axes.
    if type(ct) is not ad.Zero:
      return ct
    return ad.Zero(_insert_aval_axes(ct.aval, axes, local_axis_sizes))

  return tuple(with_mapped_axes(ct, axes) for ct, axes in zip(arg_cts, in_axes))
ad.primitive_transposes[xmap_p] = _xmap_transpose
def _typecheck_xmap(
*in_avals, call_jaxpr, name, in_axes, out_axes, donated_invars,
global_axis_sizes, axis_resources, resource_env, backend,
@ -917,7 +955,7 @@ pxla.custom_resource_typing_rules[xmap_p] = _resource_typing_xmap
# This is DynamicJaxprTrace.process_map with some very minor modifications
def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
from jax.interpreters.partial_eval import (
trace_to_subjaxpr_dynamic, DynamicJaxprTracer, source_info_util,
trace_to_subjaxpr_dynamic, DynamicJaxprTracer,
convert_constvars_jaxpr, new_jaxpr_eqn)
assert primitive is xmap_p
in_avals = [t.aval for t in tracers]
@ -966,6 +1004,142 @@ def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
pe.DynamicJaxprTrace.process_xmap = _dynamic_jaxpr_process_xmap # type: ignore
@lu.transformation_with_aux
def out_local_named_shapes(local_axes, *args, **kwargs):
  """Record which of ``local_axes`` appear in each output's named shape.

  The wrapped function is called with ``args``/``kwargs`` unchanged and its
  outputs are passed through untouched. The auxiliary output is a list with
  one frozenset per output: the names in ``out.aval.named_shape`` that are
  also in ``local_axes``.
  """
  outs = yield args, kwargs
  local_names_per_out = []
  for out in outs:
    local_names_per_out.append(frozenset(out.aval.named_shape) & local_axes)
  yield outs, local_names_per_out
@lu.transformation_with_aux
def hide_units(unit_args, *args, **kwargs):
  """Call the wrapped function while hiding units from the caller.

  ``unit_args`` is a mask saying where units have to be re-inserted into
  ``args`` before the call (see ``restore_units``). Units are then removed from
  the outputs (see ``filter_units``): the non-unit outputs become the result
  and the per-output unit mask becomes the auxiliary output.
  """
  full_args = restore_units(unit_args, args)
  outs = yield full_args, kwargs
  non_unit_outs, out_is_unit = filter_units(outs)
  yield non_unit_outs, out_is_unit
def filter_units(vals):
  """Split ``vals`` into its non-unit entries and a per-entry unit mask.

  Args:
    vals: any iterable of values. It is traversed exactly once, so one-shot
      iterators (e.g. generators) are handled correctly; traversing it twice
      would leave the mask empty and out of sync with the filtered values.

  Returns:
    A pair ``(vals_no_units, vals_is_unit)`` of lists: the entries that are
    not ``core.unit`` in their original order, and one boolean per entry of
    ``vals`` that is True where the entry is ``core.unit``.
  """
  vals_no_units = []
  vals_is_unit = []
  for v in vals:
    is_unit = v is core.unit
    vals_is_unit.append(is_unit)
    if not is_unit:
      vals_no_units.append(v)
  return vals_no_units, vals_is_unit
def restore_units(is_unit, vals):
  """Inverse of ``filter_units``: re-insert units into ``vals``.

  Args:
    is_unit: one flag per output position; a truthy flag places ``core.unit``
      at that position, a falsy one consumes the next element of ``vals``.
    vals: the non-unit values, in order.

  Returns:
    A list with one entry per flag in ``is_unit``.

  Raises:
    RuntimeError: if ``vals`` still has elements left after every position
      has been filled.
  """
  remaining = iter(vals)
  restored = []
  for unit_here in is_unit:
    restored.append(core.unit if unit_here else next(remaining))
  # A fresh sentinel can never be an element of ``vals``, so getting it back
  # from ``next`` means the iterator is exhausted.
  exhausted = object()
  if next(remaining, exhausted) is not exhausted:
    raise RuntimeError("Expected the iterator to be exhausted")
  return restored
def _jaxpr_trace_process_xmap(self, primitive, f: lu.WrappedFun, tracers, params):
  """Partial-evaluation rule for ``xmap_p``, installed on ``pe.JaxprTrace``.

  The known part of the computation is evaluated by re-binding ``primitive``
  inside ``app`` (with units filtered out). The unknown part is recorded as a
  single ``xmap_p`` equation whose inputs are the residual constants, the
  environment tracers and the unknown input tracers.

  Args:
    self: the trace this rule is called on.
    primitive: must be ``xmap_p``.
    f: the wrapped function being mapped.
    tracers: input tracers; each one exposes a ``pval``.
    params: parameters of the xmap bind. Reads ``in_axes``, ``donated_invars``,
      ``global_axis_sizes``, ``out_axes_thunk``, ``axis_resources`` and
      ``resource_env``; ``spmd_in_axes`` and ``spmd_out_axes_thunk`` must be
      ``None``.

  Returns:
    The output tracers assembled by ``pe._zip_knowns`` from the known and
    unknown outputs.
  """
  from jax.interpreters.partial_eval import (
    PartialVal, JaxprTracer, _drop_invars, _dce_open_jaxpr,
    convert_constvars_jaxpr, new_eqn_recipe)
  assert primitive is xmap_p
  in_axes = params['in_axes']
  donated_invars = params['donated_invars']
  global_axis_sizes = params['global_axis_sizes']

  # Inside the xmap body the mapped axes are not part of the avals, so they
  # are deleted from the avals of unknown inputs.
  in_pvals = [t.pval for t in tracers]
  in_pvals = [pval if pval.is_known()
              else PartialVal.unknown(_delete_aval_axes(pval[0], axes))
              for pval, axes in zip(in_pvals, in_axes)]

  # Filled in by new_out_axes_thunk with the axes of the residual constants.
  const_axes_s = lu.Store()
  def app(f, *args):
    args_no_units, in_units = filter_units(args)
    f, out_units = hide_units(f, tuple(in_units))
    f, out_named_shapes = out_local_named_shapes(f, frozenset(global_axis_sizes))
    out_axes_thunk = params['out_axes_thunk']
    @as_hashable_function(closure=out_axes_thunk)
    def new_out_axes_thunk():
      out_axes = out_axes_thunk()
      # The outputs are the user-visible outputs followed by the constants.
      axes_units, const_units = split_list(out_units(), [len(out_axes)])
      assert not any(const_units)
      num_consts = len(const_units)
      out_axes_no_units = [a for a, u in zip(out_axes, axes_units) if not u]
      named_shapes = out_named_shapes()
      # ``seq[-0:]`` is the whole sequence rather than an empty one, so the
      # no-constants case has to be handled explicitly. Otherwise every output
      # would be treated as a constant and too many axes would be returned.
      const_named_shapes = named_shapes[-num_consts:] if num_consts else []
      const_axes = []
      for named_shape in const_named_shapes:
        # We sort here to make the iteration order deterministic
        sort_named_shape = sorted(named_shape, key=str)
        const_axes.append(
            AxisNamePos(zip(sort_named_shape, range(len(sort_named_shape))),
                        user_repr=f'<internal: {sort_named_shape}>'))
      if not const_axes_s:  # NOTE: This can be called multiple times
        const_axes_s.store(const_axes)
      assert const_axes_s.val == const_axes
      return (*out_axes_no_units, *const_axes)
    pe_params = dict(
      params,
      in_axes=tuple(a for a, u in zip(in_axes, in_units) if not u),
      donated_invars=tuple(a for a, u in zip(donated_invars, in_units) if not u),
      out_axes_thunk=new_out_axes_thunk)
    outs_no_units = primitive.bind(f, *args_no_units, **pe_params)
    new_out_axes_thunk()  # Make sure it is called at least once to compute const_axes
    return restore_units(out_units(), outs_no_units)

  jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
      f, in_pvals, app, instantiate=False)

  out_axes = params['out_axes_thunk']()
  const_axes = const_axes_s.val
  axis_resource_count = _get_axis_resource_count(params['axis_resources'],
                                                 params['resource_env'])
  local_axis_sizes = {axis: axis_resource_count[axis].to_local(global_size)
                      for axis, global_size in global_axis_sizes.items()}
  # Outside of the xmap the mapped axes are visible again in unknown outputs.
  out_pvals = [pval if pval.is_known() else
               PartialVal.unknown(_insert_aval_axes(pval[0], axes, local_axis_sizes))
               for pval, axes in zip(out_pvals, out_axes)]

  with core.extend_axis_env_nd(global_axis_sizes.items()):
    # Skip known invars and outvars, and lift constants as regular invars
    in_knowns = tuple(t.pval.is_known() for t in it.chain(env_tracers, tracers))
    out_unknowns = tuple(not pval.is_known() for pval in out_pvals)
    jaxpr = _drop_invars(jaxpr, in_knowns)
    jaxpr = _dce_open_jaxpr(jaxpr, out_unknowns, drop_outputs=True)
    jaxpr = convert_constvars_jaxpr(jaxpr)

  # Known tracers get propagated as if they were constants
  known_tracers_out = [self.new_const(pval.get_known()) for pval in out_pvals
                       if pval.is_known()]

  # I'm not 100% sure that's correct, but it is an assumption that
  # JaxprTrace.process_call already makes.
  if any(t.pval.is_known() for t in env_tracers):
    raise AssertionError("Please open a bug report!")
  # Unknown tracers need to have the jaxpr set up as their recipe
  unknown_tracers_in = (*env_tracers, *(t for t in tracers if not t.pval.is_known()))
  unknown_tracers_out = [JaxprTracer(self, pval, None) for pval in out_pvals
                         if not pval.is_known()]
  # Materialized because const_tracers is iterated twice below; this is a
  # no-op if ``map`` already returns a list.
  const_tracers = list(map(self.new_instantiated_const, consts))

  # Set up new params
  new_in_axes = (*const_axes,
                 *(None for _ in env_tracers),
                 *(axis for axis, t in zip(in_axes, tracers)
                   if not t.pval.is_known()))
  new_out_axes = tuple(axis for axis, pval in zip(out_axes, out_pvals)
                       if not pval.is_known())

  assert params['spmd_in_axes'] is None and params['spmd_out_axes_thunk'] is None
  new_params = dict(
      params,
      call_jaxpr=jaxpr,
      donated_invars=(*(False for _ in const_tracers),
                      *(d for d, t in zip(donated_invars, tracers) if not t.pval.is_known())),
      in_axes=new_in_axes,
      out_axes=new_out_axes,
      spmd_out_axes=None)
  # The recorded equation carries concrete output axes instead of thunks.
  del new_params['out_axes_thunk']
  del new_params['spmd_out_axes_thunk']

  eqn = new_eqn_recipe((*const_tracers, *unknown_tracers_in),
                       unknown_tracers_out,
                       primitive, new_params, source_info_util.current())
  for t in unknown_tracers_out: t.recipe = eqn
  return pe._zip_knowns(known_tracers_out, unknown_tracers_out, out_unknowns)
pe.JaxprTrace.process_xmap = _jaxpr_trace_process_xmap
def _batch_trace_update_spmd_axes(
spmd_in_axes, spmd_out_axes_thunk,
axis_name, dims, dims_out_thunk):

View File

@ -515,13 +515,26 @@ class XMapTest(XMapTestCase):
y = rng.randn(*yshape)
self.assertAllClose(fm(x, y), fref(x, y))
def testJVP(self):
def testAutodiffBroadcast(self):
  """Check gradients of an xmap whose second argument has empty in_axes."""
  lhs = jnp.arange(12, dtype=jnp.float32).reshape((3, 4)) / 100
  rhs = jnp.arange(20, dtype=jnp.float32).reshape((4, 5)) / 100
  operands = (lhs, rhs)
  mapped = xmap(
      lambda x, y: jnp.cos(lax.dot(x, jnp.sin(y),
                                   precision=lax.Precision.HIGHEST)),
      in_axes=[['i', ...], {}],
      out_axes=['i', ...])
  jtu.check_grads(mapped, operands, order=2, modes=['fwd'])
  jtu.check_grads(mapped, operands, order=1, modes=['rev'])
  with self.assertRaises(AssertionError):
    # Second-order reverse-mode differentiation seems to be broken,
    # likely due to the transpose of psum being defined incorrectly.
    jtu.check_grads(mapped, operands, order=2, modes=['rev'])
def testAutodiffNoBroadcast(self):
  """Check gradients up to second order when both arguments are mapped over 'i'."""
  lhs = jnp.arange(12, dtype=jnp.float32).reshape((3, 4)) / 100
  rhs = jnp.arange(12, dtype=jnp.float32).reshape((4, 3)) / 100
  mapped = xmap(
      lambda x, y: jnp.cos(lax.dot(x, jnp.sin(y),
                                   precision=lax.Precision.HIGHEST)),
      in_axes=[['i', ...], [None, 'i']],
      out_axes=['i'])
  jtu.check_grads(mapped, (lhs, rhs), order=2)
@jtu.with_and_without_mesh
def testNamedShape(self, mesh, axis_resources):