Merge pull request #21071 from mattjj:vmap-spmd-axis-name-errors

PiperOrigin-RevId: 638783978
This commit is contained in:
jax authors 2024-05-30 14:30:21 -07:00
commit b1d37d1d20
4 changed files with 55 additions and 2 deletions

View File

@ -2458,6 +2458,13 @@ mlir.register_lowering(sharding_constraint_p,
def _sharding_constraint_batcher(insert_axis, spmd_axis_name, axis_size,
axis_name, main_type, vals_in, dims_in,
sharding, resource_env, unconstrained_dims):
if spmd_axis_name is not None and isinstance(sharding, NamedSharding):
used = {n for ns in sharding.spec
for n in (ns if isinstance(ns, tuple) else (ns,))}
if set(spmd_axis_name) & used:
raise ValueError("vmap spmd_axis_name cannot appear in "
"with_sharding_constraint spec, but got spec"
f"{sharding}")
x, = vals_in
d, = dims_in
# None means unconstrained in ParsedPartitionSpec

View File

@ -1274,6 +1274,9 @@ def _shard_map_batch(
for ax in names} for names, d in zip(in_names, in_dims)]
spmd_axis_name = trace.spmd_axis_name
if spmd_axis_name is not None:
used = {n for names in in_names for ns in names.values() for n in ns}
if set(spmd_axis_name) & used:
raise ValueError("vmap spmd_axis_name cannot appear in shard_map in_specs")
new_in_names = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped # type: ignore
else ns for ns, d in zip(new_in_names, in_dims)]
@as_hashable_function(closure=out_names_thunk)
@ -1306,6 +1309,9 @@ def _batch_out_names(spmd_axis_name, dims, out_names):
out_names_ = [{ax + (d is not batching.not_mapped and d <= ax): names[ax]
for ax in names} for names, d in zip(out_names, dims)]
if spmd_axis_name is not None:
used = {n for names in out_names for ns in names.values() for n in ns}
if set(spmd_axis_name) & used:
raise ValueError("vmap spmd_axis_name cannot appear in shard_map out_specs")
out_names_ = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped
else ns for ns, d in zip(out_names_, dims)]
return out_names_

View File

@ -1303,6 +1303,16 @@ class PJitTest(jtu.BufferDonationTestCase):
""").strip(),
)
def test_with_sharding_constraint_vmap_spmd_axis_name_error(self):
  # vmap's spmd_axis_name must not also appear in a with_sharding_constraint
  # spec inside the vmapped function; that collision should raise.
  mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))

  def constrain(arr):
    sharding = NamedSharding(mesh, P('x'))
    return jax.lax.with_sharding_constraint(arr, sharding)

  batch = jnp.arange(4 * 16.).reshape(4, 16)
  with self.assertRaisesRegex(ValueError, "spmd_axis_name"):
    jax.vmap(constrain, spmd_axis_name='x')(batch)
@jtu.pytest_mark_if_available('multiaccelerator')
class CustomPartitionerTest(jtu.JaxTestCase):
@ -4270,8 +4280,7 @@ class PJitErrorTest(jtu.JaxTestCase):
r".*rank at least 2, but was applied to a value of rank 1", re.M | re.S)
with self.assertRaisesRegex(ValueError, error):
pjit(
lambda x: with_sharding_constraint(x, spec),
in_shardings=None,
lambda x: with_sharding_constraint(x, spec), in_shardings=None,
out_shardings=None,
)(x)

View File

@ -1798,6 +1798,37 @@ class ShardMapTest(jtu.JaxTestCase):
ir.as_text()
)
def test_vmap_spmd_axis_name_error(self):
  # A mesh axis handed to vmap as spmd_axis_name must not also appear in
  # shard_map's in_specs or out_specs; both collisions should raise.
  mesh = jtu.create_global_mesh((4, 2), ('i', 'j'))
  batch = jnp.arange(4 * 16.).reshape(4, 16)

  # Collision via in_specs: 'i' shards the input and is also the vmap
  # spmd_axis_name.
  sharded_sin = shard_map(jnp.sin, mesh=mesh, in_specs=P('i'),
                          out_specs=P('i'))
  with self.assertRaisesRegex(ValueError, "spmd_axis_name cannot appear"):
    jax.vmap(sharded_sin, spmd_axis_name='i')(batch)

  # Collision via out_specs only: the input uses 'j', but the output spec
  # mentions 'i', which vmap wants to claim as its spmd axis.
  sharded_sin_out = shard_map(jnp.sin, mesh=mesh, in_specs=P('j'),
                              out_specs=P(('i', 'j')), check_rep=False)
  with self.assertRaisesRegex(ValueError, "spmd_axis_name cannot appear"):
    jax.vmap(sharded_sin_out, spmd_axis_name='i')(batch)
class FunSpec(NamedTuple):
name: str
num_inputs: int