mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #21071 from mattjj:vmap-spmd-axis-name-errors
PiperOrigin-RevId: 638783978
This commit is contained in:
commit
b1d37d1d20
@ -2458,6 +2458,13 @@ mlir.register_lowering(sharding_constraint_p,
|
||||
def _sharding_constraint_batcher(insert_axis, spmd_axis_name, axis_size,
|
||||
axis_name, main_type, vals_in, dims_in,
|
||||
sharding, resource_env, unconstrained_dims):
|
||||
if spmd_axis_name is not None and isinstance(sharding, NamedSharding):
|
||||
used = {n for ns in sharding.spec
|
||||
for n in (ns if isinstance(ns, tuple) else (ns,))}
|
||||
if set(spmd_axis_name) & used:
|
||||
raise ValueError("vmap spmd_axis_name cannot appear in "
|
||||
"with_sharding_constraint spec, but got spec"
|
||||
f"{sharding}")
|
||||
x, = vals_in
|
||||
d, = dims_in
|
||||
# None means unconstrained in ParsedPartitionSpec
|
||||
|
@ -1274,6 +1274,9 @@ def _shard_map_batch(
|
||||
for ax in names} for names, d in zip(in_names, in_dims)]
|
||||
spmd_axis_name = trace.spmd_axis_name
|
||||
if spmd_axis_name is not None:
|
||||
used = {n for names in in_names for ns in names.values() for n in ns}
|
||||
if set(spmd_axis_name) & used:
|
||||
raise ValueError("vmap spmd_axis_name cannot appear in shard_map in_specs")
|
||||
new_in_names = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped # type: ignore
|
||||
else ns for ns, d in zip(new_in_names, in_dims)]
|
||||
@as_hashable_function(closure=out_names_thunk)
|
||||
@ -1306,6 +1309,9 @@ def _batch_out_names(spmd_axis_name, dims, out_names):
|
||||
out_names_ = [{ax + (d is not batching.not_mapped and d <= ax): names[ax]
|
||||
for ax in names} for names, d in zip(out_names, dims)]
|
||||
if spmd_axis_name is not None:
|
||||
used = {n for names in out_names for ns in names.values() for n in ns}
|
||||
if set(spmd_axis_name) & used:
|
||||
raise ValueError("vmap spmd_axis_name cannot appear in shard_map out_specs")
|
||||
out_names_ = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped
|
||||
else ns for ns, d in zip(out_names_, dims)]
|
||||
return out_names_
|
||||
|
@ -1303,6 +1303,16 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
""").strip(),
|
||||
)
|
||||
|
||||
def test_with_sharding_constraint_vmap_spmd_axis_name_error(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
|
||||
def f(x):
|
||||
return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('x')))
|
||||
|
||||
xs = jnp.arange(4 * 16.).reshape(4, 16)
|
||||
with self.assertRaisesRegex(ValueError, "spmd_axis_name"):
|
||||
jax.vmap(f, spmd_axis_name='x')(xs)
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
class CustomPartitionerTest(jtu.JaxTestCase):
|
||||
@ -4270,8 +4280,7 @@ class PJitErrorTest(jtu.JaxTestCase):
|
||||
r".*rank at least 2, but was applied to a value of rank 1", re.M | re.S)
|
||||
with self.assertRaisesRegex(ValueError, error):
|
||||
pjit(
|
||||
lambda x: with_sharding_constraint(x, spec),
|
||||
in_shardings=None,
|
||||
lambda x: with_sharding_constraint(x, spec), in_shardings=None,
|
||||
out_shardings=None,
|
||||
)(x)
|
||||
|
||||
|
@ -1798,6 +1798,37 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
ir.as_text()
|
||||
)
|
||||
|
||||
def test_vmap_spmd_axis_name_error(self):
|
||||
mesh = jtu.create_global_mesh((4, 2), ('i', 'j'))
|
||||
|
||||
@partial(
|
||||
shard_map,
|
||||
mesh=mesh,
|
||||
in_specs=P('i'),
|
||||
out_specs=P('i'),
|
||||
)
|
||||
def f(x):
|
||||
return jnp.sin(x)
|
||||
|
||||
xs = jnp.arange(4 * 16.).reshape(4, 16)
|
||||
with self.assertRaisesRegex(ValueError, "spmd_axis_name cannot appear"):
|
||||
jax.vmap(f, spmd_axis_name='i')(xs)
|
||||
|
||||
@partial(
|
||||
shard_map,
|
||||
mesh=mesh,
|
||||
in_specs=P('j'),
|
||||
out_specs=P(('i', 'j')),
|
||||
check_rep=False,
|
||||
)
|
||||
def g(x):
|
||||
return jnp.sin(x)
|
||||
|
||||
xs = jnp.arange(4 * 16.).reshape(4, 16)
|
||||
with self.assertRaisesRegex(ValueError, "spmd_axis_name cannot appear"):
|
||||
jax.vmap(g, spmd_axis_name='i')(xs)
|
||||
|
||||
|
||||
class FunSpec(NamedTuple):
|
||||
name: str
|
||||
num_inputs: int
|
||||
|
Loading…
x
Reference in New Issue
Block a user