Restore support for nested xmaps when using the (standard) replica lowering

Fortunately this wasn't too difficult, as most of the code was already
there. The biggest issues were the lack of axis name substitution and
the inability to handle multiple resource assignments for a single
logical axis.

I'm planning to add a more comprehensive test suite for this, but I'd
like to wait for the pdot test PR (#5459) to land first. It contains a
bunch of utilities that would come in handy.
This commit is contained in:
Adam Paszke 2021-01-19 14:18:53 +00:00
parent 802c773268
commit 137321f3f7
2 changed files with 97 additions and 68 deletions

View File

@ -457,7 +457,6 @@ def _xmap_translation_rule_replica(c, axis_env,
local_mesh = resource_env.physical_mesh.local_mesh
local_mesh_shape = local_mesh.shape
mesh_in_axes, mesh_out_axes = plan.to_mesh_axes(in_axes, out_axes)
raise NotImplementedError("TODO: Substitute axis names!")
assert type(call_jaxpr) is core.Jaxpr
local_avals = [pxla.tile_aval_nd(
@ -465,9 +464,15 @@ def _xmap_translation_rule_replica(c, axis_env,
_insert_aval_axes(v.aval, aval_in_axes, axis_sizes))
for v, aval_in_axes, aval_mesh_in_axes
in zip(call_jaxpr.invars, in_axes, mesh_in_axes)]
f = lu.wrap_init(core.jaxpr_as_fun(core.ClosedJaxpr(call_jaxpr, ())))
# We have to substitute before tracing, because we want the vectorized
# axes to be used in the jaxpr.
resource_call_jaxpr = subst_axis_names(call_jaxpr, plan.axis_subst)
f = lu.wrap_init(core.jaxpr_as_fun(core.ClosedJaxpr(resource_call_jaxpr, ())))
f = hide_mapped_axes(f, tuple(in_axes), tuple(out_axes))
f = plan.vectorize(f, in_axes, out_axes)
# NOTE: We don't extend the resource env with the mesh shape, because those
# resources are already in scope! It's the outermost xmap that introduces
# them!
vectorized_jaxpr, _, consts = pe.trace_to_jaxpr_final(f, local_avals)
assert not consts
@ -476,12 +481,14 @@ def _xmap_translation_rule_replica(c, axis_env,
if v.aval is not core.abstract_unit else in_node
for v, in_node, arg_in_axes in zip(call_jaxpr.invars, in_nodes, mesh_in_axes))
# NOTE: We don't extend the resource env with the mesh shape, because those
# resources are already in scope! It's the outermost xmap that introduces
# them!
# We in-line here rather than generating a Call HLO as in the xla_call
# translation rule just because the extra tuple stuff is a pain.
with core.extend_axis_env_nd(axis_sizes.items()):
tiled_outs = xla.jaxpr_subcomp(
c, vectorized_jaxpr, backend, axis_env, (),
xla.extend_name_stack(name_stack, xla.wrap_name(name, 'xmap')), *tiled_ins)
tiled_outs = xla.jaxpr_subcomp(
c, vectorized_jaxpr, backend, axis_env, (),
xla.extend_name_stack(name_stack, xla.wrap_name(name, 'xmap')), *tiled_ins)
outs = [_xla_untile(c, axis_env, tiled_out, ans_out_axes, local_mesh_shape, backend)
if v.aval is not core.abstract_unit else tiled_out
@ -490,25 +497,34 @@ def _xmap_translation_rule_replica(c, axis_env,
return xops.Tuple(c, outs)
def _xla_tile_base_indices(c, axis_env, tile_shape, axes, axis_sizes):
  """Builds the per-dimension start offsets of the current tile.

  For each positional dimension, the indices of all named axes in `axes` that
  point at that dimension are folded into one linear index, which is then
  multiplied by the tile extent of the dimension.

  Args:
    c: the XLA builder the ops are emitted into.
    axis_env: axis environment forwarded to the axis-index translation rule.
    tile_shape: sequence with the per-dimension size of a single tile.
    axes: ordered mapping from axis name to the positional dimension it tiles.
      Several names may refer to the same dimension.
    axis_sizes: mapping from axis name to its size, used to grow the stride of
      a dimension after each axis assigned to it has been processed.

  Returns:
    A list with one int32 scalar op per dimension of `tile_shape`. Dimensions
    that no axis refers to get the shared zero constant.
  """
  zero = xb.constant(c, np.zeros((), dtype=np.int32))
  rank = len(tile_shape)
  combined = [zero for _ in range(rank)]
  dim_stride = [1 for _ in range(rank)]
  # Axes are visited last-to-first, so the last axis listed for a dimension
  # gets stride 1 and earlier ones get progressively larger strides.
  for name, dim in reversed(axes.items()):
    index_op = _axis_index_translation_rule(
        c, axis_name=name, axis_env=axis_env, platform=None)
    stride_op = xb.constant(c, np.array(dim_stride[dim], np.int32))
    first_for_dim = combined[dim] is zero and dim_stride[dim] == 1
    if first_for_dim:
      # Nothing accumulated yet and the stride is 1: use the index directly.
      combined[dim] = index_op
    else:
      combined[dim] = xops.Add(combined[dim], xops.Mul(index_op, stride_op))
    dim_stride[dim] *= axis_sizes[name]
  base_indices = []
  for index_op, extent in zip(combined, tile_shape):
    if index_op is zero:
      base_indices.append(zero)
    else:
      # Scale the linear tile index by the tile extent to get an element offset.
      base_indices.append(
          xops.Mul(index_op, xb.constant(c, np.array(extent, np.int32))))
  return base_indices
def _xla_tile(c, axis_env, x, in_axes, axis_sizes):
if not in_axes:
return x
shape = list(c.get_shape(x).dimensions())
zero = xb.constant(c, np.zeros((), dtype=np.int32))
start_idxs = [zero] * len(shape)
tiled_shape = list(shape)
tile_shape = list(shape)
for name, axis in in_axes.items():
axis_size = axis_sizes[name]
assert tiled_shape[axis] % axis_size == 0
tiled_shape[axis] //= axis_size
axis_size_c = xb.constant(c, np.array(axis_size, np.int32))
assert start_idxs[axis] is zero # TODO(apaszke): tiling over multiple mesh axes
axis_index = _axis_index_translation_rule(
c, axis_name=name, axis_env=axis_env, platform=None)
start_idxs[axis] = xops.Mul(axis_index, axis_size_c)
return xops.DynamicSlice(x, start_idxs, tiled_shape)
assert tile_shape[axis] % axis_size == 0
tile_shape[axis] //= axis_size
base_idxs = _xla_tile_base_indices(c, axis_env, tile_shape, in_axes, axis_sizes)
return xops.DynamicSlice(x, base_idxs, tile_shape)
# TODO(b/110096942): more efficient gather
def _xla_untile(c, axis_env, x, out_axes, axis_sizes, backend):
@ -520,23 +536,14 @@ def _xla_untile(c, axis_env, x, out_axes, axis_sizes, backend):
if convert_bool:
x = xops.ConvertElementType(x, xb.dtype_to_etype(np.float32))
untiled_shape = list(xla_shape.dimensions())
zero_idx = xb.constant(c, np.zeros((), dtype=np.int32))
start_idxs = [zero_idx] * len(untiled_shape)
tile_shape = list(xla_shape.dimensions())
shape = list(tile_shape)
for name, axis in out_axes.items():
axis_size = axis_sizes[name]
shape[axis] *= axis_sizes[name]
base_idxs = _xla_tile_base_indices(c, axis_env, tile_shape, out_axes, axis_sizes)
untiled_shape[axis] *= axis_size
axis_size_c = xb.constant(c, np.array(axis_size, np.int32))
assert start_idxs[axis] is zero_idx # TODO(apaszke): tiling over multiple mesh axes
axis_index = _axis_index_translation_rule(
c, axis_name=name, axis_env=axis_env, platform=None)
start_idxs[axis] = xops.Mul(axis_index, axis_size_c)
zero = xb.constant(c, np.array(0, x_dtype))
padded = xops.Broadcast(zero, untiled_shape)
padded = xops.DynamicUpdateSlice(padded, x, start_idxs)
padded = xops.Broadcast(xb.constant(c, np.array(0, x_dtype)), shape)
padded = xops.DynamicUpdateSlice(padded, x, base_idxs)
replica_groups_protos = xc.make_replica_groups(
xla.axis_groups(axis_env, tuple(out_axes.keys())))
out = xops.CrossReplicaSum(padded, replica_groups_protos)
@ -635,6 +642,16 @@ def subst_axis_names(jaxpr, axis_subst: Dict[AxisName, Tuple[AxisName]]):
def subst_eqn_axis_names(eqn, axis_subst: Dict[AxisName, Tuple[AxisName]]):
# TODO: Support custom_vjp, custom_jvp
if eqn.primitive is xmap_p:
shadowed_axes = set(eqn.params['axis_sizes']) & set(axis_subst)
if shadowed_axes:
shadowed_subst = dict(axis_subst)
for saxis in shadowed_axes:
del shadowed_subst[saxis]
else:
shadowed_subst = axis_subst
new_call_jaxpr = subst_axis_names(eqn.params['call_jaxpr'], shadowed_subst)
return eqn._replace(params=dict(eqn.params, call_jaxpr=new_call_jaxpr))
if isinstance(eqn.primitive, (core.CallPrimitive, core.MapPrimitive)):
bound_name = eqn.params.get('axis_name', None)
if bound_name in axis_subst: # Check for shadowing

View File

@ -182,42 +182,35 @@ class XMapTest(jtu.JaxTestCase):
python_should_be_executing = False
fm(x)
@skip("Need to implement vmap(xmap)")
@parameterized.named_parameters(
{"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
for name, mesh, axis_resources in (
('OneToOne', (('x', 2), ('y', 2)), (('a', 'y'), ('b', 'x'))),
('Multiple', (('x', 2), ('y', 2), ('z', 2)), (('a', 'y'), ('b', ('x', 'z')))),
))
@ignore_xmap_warning()
@with_mesh([('x', 2), ('y', 3)])
def testNestedMesh(self):
@partial(xmap, in_axes={1: 'a'}, out_axes={0: 'a'}, axis_resources={'a': 'y'})
def f(x):
y = x * 2
@partial(xmap, in_axes={0: 'b'}, out_axes={1: 'b'}, axis_resources={'b': 'x'})
def h(y):
return jnp.sin(y)
return h(y)
xshape = (2, 3, 5)
x = jnp.arange(np.prod(xshape)).reshape(xshape)
y = f(x)
self.assertAllClose(y, jnp.sin(x * 2).transpose((1, 2, 0)))
# Make sure the op really ran across a 2D mesh.
self.assertEqual(y.sharding_spec.sharding,
(pxla.Chunked(3), None, None))
self.assertEqual(y.sharding_spec.mesh_mapping,
(pxla.Replicated(2), pxla.ShardedAxis(0)))
def testNestedMesh(self, mesh, axis_resources):
@with_mesh(mesh)
def run_test():
@partial(xmap, in_axes={1: 'a'}, out_axes=({0: 'a'}, {}),
axis_resources=dict([axis_resources[0]]))
def f(x):
y = x * 2
@partial(xmap, in_axes={0: 'b'}, out_axes=({1: 'b'}, {}),
axis_resources=dict([axis_resources[1]]))
def h(y):
return jnp.sin(y), lax.psum(y, ('a', 'b'))
return h(y)
@ignore_xmap_warning()
@with_mesh([('x', 2)])
def testNestedDifferentResources(self):
@partial(xmap, in_axes={0: 'a'}, out_axes={0: 'a'}, axis_resources={'a': 'x'})
def f(x):
with mesh(np.empty((), dtype=np.object), ()):
@partial(xmap, in_axes={0: 'b'}, out_axes={0: 'b'})
def h(x):
return x
return h(x)
xshape = (2, 5, 6)
x = jnp.arange(np.prod(xshape)).reshape(xshape)
with self.assertRaisesRegex(RuntimeError,
"Changing the resource environment.*"):
f(x)
xshape = (4, 2, 5)
x = jnp.arange(np.prod(xshape)).reshape(xshape)
y = f(x)
self.assertAllClose(y, (jnp.sin(x * 2).transpose((1, 2, 0)), (x * 2).sum((0, 1))))
self.assertEqual(y[0].sharding_spec.sharding,
(pxla.Chunked(2), pxla.NoSharding(), pxla.NoSharding()))
self.assertEqual(y[0].sharding_spec.mesh_mapping,
(pxla.Replicated(2), pxla.ShardedAxis(0)) + (pxla.Replicated(2),) * (len(mesh) - 2))
run_test()
@parameterized.named_parameters(
{"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
@ -334,6 +327,9 @@ class XMapTestSPMD(XMapTest):
super().setUp()
if jtu.device_under_test() != "tpu":
raise SkipTest
# Nesting xmap calls is not supported in the SPMD lowering yet
if "NestedMesh" in self._testMethodName:
raise SkipTest
jax.experimental.maps.make_xmap_callable.cache_clear()
self.old_lowering_flag = jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING
jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = True
@ -667,6 +663,22 @@ class XMapErrorTest(jtu.JaxTestCase):
fxy = xmap(f, in_axes=['a', ...], out_axes=['a', ...],
axis_resources={'a': ('x', 'x')})
@ignore_xmap_warning()
@with_mesh([('x', 2)])
def testNestedDifferentResources(self):
  """Checks that switching the mesh inside an xmap body raises an error.

  The outer xmap runs under the ('x', 2) mesh installed by the decorator. The
  inner xmap is defined and called under a different, zero-dimensional mesh,
  which is expected to fail with a RuntimeError matching
  "Changing the resource environment".
  """
  @partial(xmap, in_axes={0: 'a'}, out_axes={0: 'a'}, axis_resources={'a': 'x'})
  def f(x):
    # Use the builtin `object` as the dtype: the `np.object` alias was
    # deprecated in NumPy 1.20 and removed in 1.24, and would make this test
    # fail with an AttributeError instead of reaching the check below.
    with mesh(np.empty((), dtype=object), ()):
      @partial(xmap, in_axes={0: 'b'}, out_axes={0: 'b'})
      def h(x):
        return x
      return h(x)
  xshape = (2, 5, 6)
  x = jnp.arange(np.prod(xshape)).reshape(xshape)
  with self.assertRaisesRegex(RuntimeError,
                              "Changing the resource environment.*"):
    f(x)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())