mirror of https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Fix lax.ragged_all_to_all degenerate case
In the singleton group case, unlike regular all_to_all, the ragged op becomes a generic equivalent of DynamicUpdateSlice, except that the update size is not statically known. This operation can't be expressed with standard HLO instructions, so the backend will handle this case separately. Also added a small improvement to error messages.

PiperOrigin-RevId: 721473063
This commit is contained in:
parent f4e2c6c34c
commit a8df383ccf
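To illustrate the DynamicUpdateSlice point in the message: `jax.lax.dynamic_update_slice` accepts a dynamic start index, but the update's shape must be static, so a runtime `send_sizes` value cannot drive it directly. Below is a minimal sketch of a mask-based equivalent one could write in user code; it is an illustration only, not the backend's implementation, and `ragged_copy` is a hypothetical helper name.

import jax
import jax.numpy as jnp

def ragged_copy(output, operand, in_off, size, out_off):
  # Positions [out_off, out_off + size) read operand[in_off + (pos - out_off)];
  # all other positions keep the existing `output` value. Shapes stay static,
  # so this jits even though `size` is only known at runtime.
  pos = jnp.arange(output.shape[0])
  src = jnp.take(operand, in_off + (pos - out_off), mode='clip')
  mask = (pos >= out_off) & (pos < out_off + size)
  return jnp.where(mask, src, output)

print(jax.jit(ragged_copy)(jnp.zeros(4, jnp.int32),
                           jnp.array([1, 0, 0, 0]), 0, 1, 2))
# -> [0 0 1 0]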
@@ -1159,8 +1159,6 @@ def _ragged_all_to_all_lowering(
   split_count = len(replica_groups[0])
   if not all(split_count == len(g) for g in replica_groups):
     raise ValueError('Replica groups must be equally sized')
-  if len(replica_groups[0]) == 1:
-    return [operand]
 
   ragged_all_to_all_attrs = {
       "replica_groups": _replica_groups_hlo(replica_groups)
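For context on the deleted early return (a hedged sketch, not part of the commit): even with a single member per group the result is not the raw operand — each shard must still copy its runtime-sized send region into `output` at `output_offsets`. A NumPy reference of that per-shard semantics, using the values from the new test further below (`singleton_group_ragged_all_to_all` is a hypothetical name):

import numpy as np

def singleton_group_ragged_all_to_all(operand, output, input_offsets,
                                      send_sizes, output_offsets, recv_sizes):
  del recv_sizes  # equals send_sizes when every group is a singleton
  # Every send is local: copy each runtime-sized region of `operand`
  # into `output` at its output offset.
  result = output.copy()
  for i in range(len(input_offsets)):
    s, n, d = input_offsets[i], send_sizes[i], output_offsets[i]
    result[d:d + n] = operand[s:s + n]
  return result

# Shard 0 of the new test: one element sent from offset 0 to offset 2.
print(singleton_group_ragged_all_to_all(
    np.array([1, 0, 0, 0]), np.zeros(4, np.int32),
    np.array([0]), np.array([1]), np.array([2]), np.array([1])))
# -> [0 0 1 0]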
@@ -1390,6 +1390,10 @@ jax_multiplatform_test(
     enable_configs = [
         "gpu_p100x2_shardy",
     ],
+    shard_count = {
+        "gpu": 10,
+        "tpu": 10,
+    },
     tags = [
         "multiaccelerator",
     ],
@@ -45,10 +45,10 @@ class RaggedCollectiveTest(jtu.JaxTestCase):
   )
   def test_ragged_all_to_all(self, axis_name, mesh_axes):
     device_type = jax.devices()[0].platform
-    if device_type == 'tpu' and jtu.get_tpu_version() == 3:
+    if device_type == 'tpu' and jtu.get_tpu_version() < 4:
       raise unittest.SkipTest(
-          'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by'
-          ' TPU v3'
+          'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by TPU'
+          f' v{jtu.get_tpu_version()}'
       )
     mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys()))
     operand = jax.device_put(
@@ -132,10 +132,10 @@ class RaggedCollectiveTest(jtu.JaxTestCase):
   )
   def test_ragged_all_to_all_axis_index_groups(self, axis_name, mesh_axes):
     device_type = jax.devices()[0].platform
-    if device_type == 'tpu' and jtu.get_tpu_version() == 3:
+    if device_type == 'tpu' and jtu.get_tpu_version() < 4:
       raise unittest.SkipTest(
-          'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by'
-          ' TPU v3'
+          'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by TPU'
+          f' v{jtu.get_tpu_version()}'
       )
     mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys()))
     operand = jax.device_put(
@@ -219,6 +219,94 @@ class RaggedCollectiveTest(jtu.JaxTestCase):
             [10, 30, 0, 0], [20, 20, 40, 0]], dtype=jnp.int32)
     )
 
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='_single_axis_name', axis_name='x', mesh_axes=dict(x=2)
+      ),
+  )
+  def test_ragged_all_to_all_degenerate_groups(self, axis_name, mesh_axes):
+    device_type = jax.devices()[0].platform
+    if device_type == 'tpu':
+      raise unittest.SkipTest(
+          'UNSUPPORTED: HLO opcode `ragged-all-to-all` with singleton group is'
+          ' not supported by TPU'
+      )
+    mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys()))
+    operand = jax.device_put(
+        jnp.array([[1, 0, 0, 0], [2, 3, 4, 0]], dtype=jnp.int32),
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
+    output = jax.device_put(
+        jnp.zeros((2, 4), dtype=jnp.int32),
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
+    input_offsets = jax.device_put(
+        jnp.array([[0], [0]], dtype=jnp.int32),
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
+    send_sizes = jax.device_put(
+        jnp.array([[1], [3]], dtype=jnp.int32),
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
+    output_offsets = jax.device_put(
+        jnp.array([[2], [1]], dtype=jnp.int32),
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
+    recv_sizes = jax.device_put(
+        jnp.array([[1], [3]], dtype=jnp.int32),
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
+    axis_index_groups = ((0,), (1,))
+
+    @jax.jit
+    @partial(
+        shard_map,
+        mesh=mesh,
+        in_specs=(
+            P(axis_name, None),
+            P(axis_name, None),
+            P(axis_name, None),
+            P(axis_name, None),
+            P(axis_name, None),
+            P(axis_name, None),
+        ),
+        out_specs=P(axis_name),
+        check_rep=False,
+    )
+    def fwd(
+        operand, output, input_offsets, send_sizes, output_offsets, recv_sizes
+    ):
+      operand = operand.reshape(operand.shape[1:])
+      output = output.reshape(output.shape[1:])
+      input_offsets = input_offsets.reshape(input_offsets.shape[1:])
+      send_sizes = send_sizes.reshape(send_sizes.shape[1:])
+      output_offsets = output_offsets.reshape(output_offsets.shape[1:])
+      recv_sizes = recv_sizes.reshape(recv_sizes.shape[1:])
+      return lax.ragged_all_to_all(
+          operand,
+          output,
+          input_offsets,
+          send_sizes,
+          output_offsets,
+          recv_sizes,
+          axis_name=axis_name,
+          axis_index_groups=axis_index_groups,
+      )
+
+    mlir_module = fwd.lower(
+        operand, output, input_offsets, send_sizes, output_offsets,
+        recv_sizes).as_text()
+    self.assertIn("stablehlo.custom_call @ragged_all_to_all", mlir_module)
+    self.assertIn("replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>",
+                  mlir_module)
+
+    c = fwd(
+        operand, output, input_offsets, send_sizes, output_offsets, recv_sizes
+    ).reshape((2, 4))
+    self.assertAllClose(
+        c, jnp.array([[0, 0, 1, 0], [0, 2, 3, 4]], dtype=jnp.int32)
+    )
+
 
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())