mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Restore support for nested xmaps when using the (standard) replica lowering
Fortunately this wasn't too difficult, as most of the code was already there. The biggest issues were a lack of axis name substitution and inability to handle multiple resource assignments for a single logical axis. I'm planning to add a more comprehensive test suite for this, but I'd like to wait for the pdot test PR (#5459) to land first. It contains a bunch of utilities that would come in handy.
This commit is contained in:
parent
802c773268
commit
137321f3f7
@ -457,7 +457,6 @@ def _xmap_translation_rule_replica(c, axis_env,
|
||||
local_mesh = resource_env.physical_mesh.local_mesh
|
||||
local_mesh_shape = local_mesh.shape
|
||||
mesh_in_axes, mesh_out_axes = plan.to_mesh_axes(in_axes, out_axes)
|
||||
raise NotImplementedError("TODO: Substitute axis names!")
|
||||
|
||||
assert type(call_jaxpr) is core.Jaxpr
|
||||
local_avals = [pxla.tile_aval_nd(
|
||||
@ -465,9 +464,15 @@ def _xmap_translation_rule_replica(c, axis_env,
|
||||
_insert_aval_axes(v.aval, aval_in_axes, axis_sizes))
|
||||
for v, aval_in_axes, aval_mesh_in_axes
|
||||
in zip(call_jaxpr.invars, in_axes, mesh_in_axes)]
|
||||
f = lu.wrap_init(core.jaxpr_as_fun(core.ClosedJaxpr(call_jaxpr, ())))
|
||||
# We have to substitute before tracing, because we want the vectorized
|
||||
# axes to be used in the jaxpr.
|
||||
resource_call_jaxpr = subst_axis_names(call_jaxpr, plan.axis_subst)
|
||||
f = lu.wrap_init(core.jaxpr_as_fun(core.ClosedJaxpr(resource_call_jaxpr, ())))
|
||||
f = hide_mapped_axes(f, tuple(in_axes), tuple(out_axes))
|
||||
f = plan.vectorize(f, in_axes, out_axes)
|
||||
# NOTE: We don't extend the resource env with the mesh shape, because those
|
||||
# resources are already in scope! It's the outermost xmap that introduces
|
||||
# them!
|
||||
vectorized_jaxpr, _, consts = pe.trace_to_jaxpr_final(f, local_avals)
|
||||
assert not consts
|
||||
|
||||
@ -476,12 +481,14 @@ def _xmap_translation_rule_replica(c, axis_env,
|
||||
if v.aval is not core.abstract_unit else in_node
|
||||
for v, in_node, arg_in_axes in zip(call_jaxpr.invars, in_nodes, mesh_in_axes))
|
||||
|
||||
# NOTE: We don't extend the resource env with the mesh shape, because those
|
||||
# resources are already in scope! It's the outermost xmap that introduces
|
||||
# them!
|
||||
# We in-line here rather than generating a Call HLO as in the xla_call
|
||||
# translation rule just because the extra tuple stuff is a pain.
|
||||
with core.extend_axis_env_nd(axis_sizes.items()):
|
||||
tiled_outs = xla.jaxpr_subcomp(
|
||||
c, vectorized_jaxpr, backend, axis_env, (),
|
||||
xla.extend_name_stack(name_stack, xla.wrap_name(name, 'xmap')), *tiled_ins)
|
||||
tiled_outs = xla.jaxpr_subcomp(
|
||||
c, vectorized_jaxpr, backend, axis_env, (),
|
||||
xla.extend_name_stack(name_stack, xla.wrap_name(name, 'xmap')), *tiled_ins)
|
||||
|
||||
outs = [_xla_untile(c, axis_env, tiled_out, ans_out_axes, local_mesh_shape, backend)
|
||||
if v.aval is not core.abstract_unit else tiled_out
|
||||
@ -490,25 +497,34 @@ def _xmap_translation_rule_replica(c, axis_env,
|
||||
|
||||
return xops.Tuple(c, outs)
|
||||
|
||||
def _xla_tile_base_indices(c, axis_env, tile_shape, axes, axis_sizes):
|
||||
zero = xb.constant(c, np.zeros((), dtype=np.int32))
|
||||
linear_idxs = [zero] * len(tile_shape)
|
||||
strides = [1] * len(tile_shape)
|
||||
for name, axis in reversed(axes.items()):
|
||||
axis_index = _axis_index_translation_rule(
|
||||
c, axis_name=name, axis_env=axis_env, platform=None)
|
||||
stride_c = xb.constant(c, np.array(strides[axis], np.int32))
|
||||
if linear_idxs[axis] is zero and strides[axis] == 1:
|
||||
linear_idxs[axis] = axis_index
|
||||
else:
|
||||
linear_idxs[axis] = xops.Add(linear_idxs[axis], xops.Mul(axis_index, stride_c))
|
||||
strides[axis] *= axis_sizes[name]
|
||||
return [zero if linear_idx is zero else
|
||||
xops.Mul(linear_idx, xb.constant(c, np.array(tile_dim_size, np.int32)))
|
||||
for linear_idx, tile_dim_size in zip(linear_idxs, tile_shape)]
|
||||
|
||||
def _xla_tile(c, axis_env, x, in_axes, axis_sizes):
|
||||
if not in_axes:
|
||||
return x
|
||||
shape = list(c.get_shape(x).dimensions())
|
||||
zero = xb.constant(c, np.zeros((), dtype=np.int32))
|
||||
start_idxs = [zero] * len(shape)
|
||||
tiled_shape = list(shape)
|
||||
tile_shape = list(shape)
|
||||
for name, axis in in_axes.items():
|
||||
axis_size = axis_sizes[name]
|
||||
|
||||
assert tiled_shape[axis] % axis_size == 0
|
||||
tiled_shape[axis] //= axis_size
|
||||
|
||||
axis_size_c = xb.constant(c, np.array(axis_size, np.int32))
|
||||
assert start_idxs[axis] is zero # TODO(apaszke): tiling over multiple mesh axes
|
||||
axis_index = _axis_index_translation_rule(
|
||||
c, axis_name=name, axis_env=axis_env, platform=None)
|
||||
start_idxs[axis] = xops.Mul(axis_index, axis_size_c)
|
||||
return xops.DynamicSlice(x, start_idxs, tiled_shape)
|
||||
assert tile_shape[axis] % axis_size == 0
|
||||
tile_shape[axis] //= axis_size
|
||||
base_idxs = _xla_tile_base_indices(c, axis_env, tile_shape, in_axes, axis_sizes)
|
||||
return xops.DynamicSlice(x, base_idxs, tile_shape)
|
||||
|
||||
# TODO(b/110096942): more efficient gather
|
||||
def _xla_untile(c, axis_env, x, out_axes, axis_sizes, backend):
|
||||
@ -520,23 +536,14 @@ def _xla_untile(c, axis_env, x, out_axes, axis_sizes, backend):
|
||||
if convert_bool:
|
||||
x = xops.ConvertElementType(x, xb.dtype_to_etype(np.float32))
|
||||
|
||||
untiled_shape = list(xla_shape.dimensions())
|
||||
zero_idx = xb.constant(c, np.zeros((), dtype=np.int32))
|
||||
start_idxs = [zero_idx] * len(untiled_shape)
|
||||
tile_shape = list(xla_shape.dimensions())
|
||||
shape = list(tile_shape)
|
||||
for name, axis in out_axes.items():
|
||||
axis_size = axis_sizes[name]
|
||||
shape[axis] *= axis_sizes[name]
|
||||
base_idxs = _xla_tile_base_indices(c, axis_env, tile_shape, out_axes, axis_sizes)
|
||||
|
||||
untiled_shape[axis] *= axis_size
|
||||
|
||||
axis_size_c = xb.constant(c, np.array(axis_size, np.int32))
|
||||
assert start_idxs[axis] is zero_idx # TODO(apaszke): tiling over multiple mesh axes
|
||||
axis_index = _axis_index_translation_rule(
|
||||
c, axis_name=name, axis_env=axis_env, platform=None)
|
||||
start_idxs[axis] = xops.Mul(axis_index, axis_size_c)
|
||||
|
||||
zero = xb.constant(c, np.array(0, x_dtype))
|
||||
padded = xops.Broadcast(zero, untiled_shape)
|
||||
padded = xops.DynamicUpdateSlice(padded, x, start_idxs)
|
||||
padded = xops.Broadcast(xb.constant(c, np.array(0, x_dtype)), shape)
|
||||
padded = xops.DynamicUpdateSlice(padded, x, base_idxs)
|
||||
replica_groups_protos = xc.make_replica_groups(
|
||||
xla.axis_groups(axis_env, tuple(out_axes.keys())))
|
||||
out = xops.CrossReplicaSum(padded, replica_groups_protos)
|
||||
@ -635,6 +642,16 @@ def subst_axis_names(jaxpr, axis_subst: Dict[AxisName, Tuple[AxisName]]):
|
||||
|
||||
def subst_eqn_axis_names(eqn, axis_subst: Dict[AxisName, Tuple[AxisName]]):
|
||||
# TODO: Support custom_vjp, custom_jvp
|
||||
if eqn.primitive is xmap_p:
|
||||
shadowed_axes = set(eqn.params['axis_sizes']) & set(axis_subst)
|
||||
if shadowed_axes:
|
||||
shadowed_subst = dict(axis_subst)
|
||||
for saxis in shadowed_axes:
|
||||
del shadowed_subst[saxis]
|
||||
else:
|
||||
shadowed_subst = axis_subst
|
||||
new_call_jaxpr = subst_axis_names(eqn.params['call_jaxpr'], shadowed_subst)
|
||||
return eqn._replace(params=dict(eqn.params, call_jaxpr=new_call_jaxpr))
|
||||
if isinstance(eqn.primitive, (core.CallPrimitive, core.MapPrimitive)):
|
||||
bound_name = eqn.params.get('axis_name', None)
|
||||
if bound_name in axis_subst: # Check for shadowing
|
||||
|
@ -182,42 +182,35 @@ class XMapTest(jtu.JaxTestCase):
|
||||
python_should_be_executing = False
|
||||
fm(x)
|
||||
|
||||
@skip("Need to implement vmap(xmap)")
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
|
||||
for name, mesh, axis_resources in (
|
||||
('OneToOne', (('x', 2), ('y', 2)), (('a', 'y'), ('b', 'x'))),
|
||||
('Multiple', (('x', 2), ('y', 2), ('z', 2)), (('a', 'y'), ('b', ('x', 'z')))),
|
||||
))
|
||||
@ignore_xmap_warning()
|
||||
@with_mesh([('x', 2), ('y', 3)])
|
||||
def testNestedMesh(self):
|
||||
@partial(xmap, in_axes={1: 'a'}, out_axes={0: 'a'}, axis_resources={'a': 'y'})
|
||||
def f(x):
|
||||
y = x * 2
|
||||
@partial(xmap, in_axes={0: 'b'}, out_axes={1: 'b'}, axis_resources={'b': 'x'})
|
||||
def h(y):
|
||||
return jnp.sin(y)
|
||||
return h(y)
|
||||
xshape = (2, 3, 5)
|
||||
x = jnp.arange(np.prod(xshape)).reshape(xshape)
|
||||
y = f(x)
|
||||
self.assertAllClose(y, jnp.sin(x * 2).transpose((1, 2, 0)))
|
||||
# Make sure the op really ran accros a 2D mesh.
|
||||
self.assertEqual(y.sharding_spec.sharding,
|
||||
(pxla.Chunked(3), None, None))
|
||||
self.assertEqual(y.sharding_spec.mesh_mapping,
|
||||
(pxla.Replicated(2), pxla.ShardedAxis(0)))
|
||||
def testNestedMesh(self, mesh, axis_resources):
|
||||
@with_mesh(mesh)
|
||||
def run_test():
|
||||
@partial(xmap, in_axes={1: 'a'}, out_axes=({0: 'a'}, {}),
|
||||
axis_resources=dict([axis_resources[0]]))
|
||||
def f(x):
|
||||
y = x * 2
|
||||
@partial(xmap, in_axes={0: 'b'}, out_axes=({1: 'b'}, {}),
|
||||
axis_resources=dict([axis_resources[1]]))
|
||||
def h(y):
|
||||
return jnp.sin(y), lax.psum(y, ('a', 'b'))
|
||||
return h(y)
|
||||
|
||||
@ignore_xmap_warning()
|
||||
@with_mesh([('x', 2)])
|
||||
def testNestedDifferentResources(self):
|
||||
@partial(xmap, in_axes={0: 'a'}, out_axes={0: 'a'}, axis_resources={'a': 'x'})
|
||||
def f(x):
|
||||
with mesh(np.empty((), dtype=np.object), ()):
|
||||
@partial(xmap, in_axes={0: 'b'}, out_axes={0: 'b'})
|
||||
def h(x):
|
||||
return x
|
||||
return h(x)
|
||||
xshape = (2, 5, 6)
|
||||
x = jnp.arange(np.prod(xshape)).reshape(xshape)
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
"Changing the resource environment.*"):
|
||||
f(x)
|
||||
xshape = (4, 2, 5)
|
||||
x = jnp.arange(np.prod(xshape)).reshape(xshape)
|
||||
y = f(x)
|
||||
self.assertAllClose(y, (jnp.sin(x * 2).transpose((1, 2, 0)), (x * 2).sum((0, 1))))
|
||||
self.assertEqual(y[0].sharding_spec.sharding,
|
||||
(pxla.Chunked(2), pxla.NoSharding(), pxla.NoSharding()))
|
||||
self.assertEqual(y[0].sharding_spec.mesh_mapping,
|
||||
(pxla.Replicated(2), pxla.ShardedAxis(0)) + (pxla.Replicated(2),) * (len(mesh) - 2))
|
||||
run_test()
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
|
||||
@ -334,6 +327,9 @@ class XMapTestSPMD(XMapTest):
|
||||
super().setUp()
|
||||
if jtu.device_under_test() != "tpu":
|
||||
raise SkipTest
|
||||
# Nesting xmap calls is not supported in the SPMD lowering yet
|
||||
if "NestedMesh" in self._testMethodName:
|
||||
raise SkipTest
|
||||
jax.experimental.maps.make_xmap_callable.cache_clear()
|
||||
self.old_lowering_flag = jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING
|
||||
jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = True
|
||||
@ -667,6 +663,22 @@ class XMapErrorTest(jtu.JaxTestCase):
|
||||
fxy = xmap(f, in_axes=['a', ...], out_axes=['a', ...],
|
||||
axis_resources={'a': ('x', 'x')})
|
||||
|
||||
@ignore_xmap_warning()
|
||||
@with_mesh([('x', 2)])
|
||||
def testNestedDifferentResources(self):
|
||||
@partial(xmap, in_axes={0: 'a'}, out_axes={0: 'a'}, axis_resources={'a': 'x'})
|
||||
def f(x):
|
||||
with mesh(np.empty((), dtype=np.object), ()):
|
||||
@partial(xmap, in_axes={0: 'b'}, out_axes={0: 'b'})
|
||||
def h(x):
|
||||
return x
|
||||
return h(x)
|
||||
xshape = (2, 5, 6)
|
||||
x = jnp.arange(np.prod(xshape)).reshape(xshape)
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
"Changing the resource environment.*"):
|
||||
f(x)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user