Add support for axis_name and axis_index_groups to lax.ragged_all_to_all

PiperOrigin-RevId: 720738861

commit 809e1133c8 (parent 9cbff64251)
@@ -457,7 +457,9 @@ def all_to_all(x, axis_name, split_axis, concat_axis, *, axis_index_groups=None,
   return tree_util.tree_map(bind, x)


-def ragged_all_to_all(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes):
+def ragged_all_to_all(
+    operand, output, input_offsets, send_sizes, output_offsets, recv_sizes, *,
+    axis_name, axis_index_groups = None):
   """Ragged version of :func:`all_to_all`.

   For now, ``split_axis`` and ``concat_axis`` from `all_to_all` are equivalent
@@ -528,14 +530,26 @@ def ragged_all_to_all(operand, output, input_offsets, send_sizes, output_offsets
     send_sizes: array of ragged output data.
     output_offsets: array of ragged offsets in the target replica output.
     recv_sizes: array of ragged output receive sizes.
+    axis_name: hashable Python object used to name a pmapped axis (see the
+      :func:`jax.pmap` documentation for more details).
+    axis_index_groups: optional list of lists containing axis indices (e.g. for
+      an axis of size 4, [[0, 1], [2, 3]] would run ragged all to all over the
+      first two and last two replicas). Groups must cover all axis indices
+      exactly once, and all groups must be the same size. Otherwise, the
+      behavior is undefined.

   Returns:
     array with shape equal to ``output``.
   """
-  return ragged_all_to_all_p.bind(operand, output, input_offsets, send_sizes,
-                                  output_offsets, recv_sizes)
-
-ragged_all_to_all_p = core.Primitive('ragged_all_to_all')
+  if not isinstance(axis_name, (tuple, list)):
+    axis_name = (axis_name,)
+
+  axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups)
+  return ragged_all_to_all_p.bind(operand, output, input_offsets, send_sizes,
+                                  output_offsets, recv_sizes,
+                                  axis_name=axis_name,
+                                  axis_index_groups=axis_index_groups)


 def axis_index(axis_name):
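For context, here is a minimal end-to-end usage sketch of the new keyword arguments, condensed from the ragged_collective_test.py file added later in this diff. It is illustrative only and assumes at least two devices and that jax.make_mesh is available; inside shard_map, each shard passes its local ragged buffers and names the mapped axis via axis_name:

from functools import partial

import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = jax.make_mesh((2,), ('x',))  # assumes >= 2 devices are available

@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P('x', None),) * 6, out_specs=P('x'), check_rep=False)
def fwd(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes):
  # Each shard sees a leading axis of size 1; drop it to get rank-1 buffers.
  squeeze = lambda x: x.reshape(x.shape[1:])
  return lax.ragged_all_to_all(
      squeeze(operand), squeeze(output), squeeze(input_offsets),
      squeeze(send_sizes), squeeze(output_offsets), squeeze(recv_sizes),
      axis_name='x')  # axis_index_groups=None -> one group over the whole axis

i32 = lambda rows: jnp.array(rows, dtype=jnp.int32)
result = fwd(i32([[1, 2, 2], [3, 4, 0]]),    # operand
             jnp.zeros((2, 4), jnp.int32),   # output buffer
             i32([[0, 1], [0, 1]]),          # input_offsets
             i32([[1, 2], [1, 1]]),          # send_sizes
             i32([[0, 0], [1, 2]]),          # output_offsets
             i32([[1, 1], [2, 1]]))          # recv_sizes
# result.reshape(2, 4) == [[1, 3, 0, 0], [2, 2, 4, 0]], matching the new test.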
@@ -1134,29 +1148,44 @@ batching.fancy_primitive_batchers[all_to_all_p] = _all_to_all_batched_collective
 batching.skippable_batchers[all_to_all_p] = partial(_names_in_param, 'axis_name')


-def _ragged_all_to_all_lowering(ctx, operand, output, input_offsets, send_sizes, output_offsets, recv_sizes):
-  N = input_offsets.type.shape[0]
-  backend_config = ir.DictAttr.get({
-      'replica_groups': ir.DenseIntElementsAttr.get(
-          np.arange(0, N, 1, dtype=np.int64), shape=[1, N]
-      )
-  })
+def _ragged_all_to_all_lowering(
+    ctx, operand, output, input_offsets, send_sizes, output_offsets, recv_sizes,
+    *, axis_name, axis_index_groups
+):
+  replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name,
+                                   axis_index_groups)
+
+  # Assumes all groups are the same size
+  split_count = len(replica_groups[0])
+  if not all(split_count == len(g) for g in replica_groups):
+    raise ValueError('Replica groups must be equally sized')
+  if len(replica_groups[0]) == 1:
+    return [operand]
+
+  ragged_all_to_all_attrs = {
+      "replica_groups": _replica_groups_hlo(replica_groups)
+  }
+  is_spmd = isinstance(
+      ctx.module_context.axis_context, (SPMDAxisContext, ShardingContext))
+  if is_spmd:
+    ragged_all_to_all_attrs['channel_id'] = ir.IntegerAttr.get(
+        ir.IntegerType.get_signless(64), ctx.module_context.new_channel()
+    )
+
   return hlo.CustomCallOp(
       result=[output.type],
       inputs=[operand, output, input_offsets, send_sizes, output_offsets,
               recv_sizes],
       call_target_name=ir.StringAttr.get('ragged_all_to_all'),
-      backend_config=backend_config,
+      backend_config=ir.DictAttr.get(ragged_all_to_all_attrs),
       api_version=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 4),
   ).results

-@ragged_all_to_all_p.def_abstract_eval
-def _ragged_all_to_all_abstract_eval(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes):
-  if operand.shape[1:] != output.shape[1:]:
-    raise ValueError(
-        "ragged_all_to_all input and output shapes must be equal, except for"
-        " the outermost dimension."
-    )
+def _ragged_all_to_all_effectful_abstract_eval(
+    operand, output, input_offsets, send_sizes, output_offsets, recv_sizes,
+    axis_name, axis_index_groups
+):
+  del operand, axis_index_groups
   if not dtypes.issubdtype(input_offsets.dtype, np.integer):
     raise ValueError("ragged_all_to_all input_offsets must be integer type.")
   if not dtypes.issubdtype(send_sizes.dtype, np.integer):
@@ -1185,15 +1214,16 @@ def _ragged_all_to_all_abstract_eval(operand, output, input_offsets, send_sizes,
         "ragged_all_to_all recv_sizes must be rank 1 with positive dimension"
         " size, but got shape {}".format(recv_sizes.shape)
     )
-  return output.update(
-      shape=list(output.shape),
-      dtype=output.dtype,
-      weak_type=output.weak_type,
-  )
-
-ragged_all_to_all_p.def_impl(partial(dispatch.apply_primitive, ragged_all_to_all_p))
+  _check_axis_names(axis_name)
+  out_aval = output.update(shape=output.shape, weak_type=False)
+  effects = {*map(core.NamedAxisEffect, axis_name)}
+  return out_aval, effects
+
+ragged_all_to_all_p = core.Primitive('ragged_all_to_all')
+ragged_all_to_all_p.def_effectful_abstract_eval(_ragged_all_to_all_effectful_abstract_eval)
 mlir.register_lowering(ragged_all_to_all_p, _ragged_all_to_all_lowering)

+batching.skippable_batchers[ragged_all_to_all_p] = partial(_names_in_param, 'axis_name')
+
 def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False):
   """Gather values of x across all replicas.
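The lowering above derives the custom call's replica_groups attribute from axis_name and axis_index_groups via the existing _replica_groups helper. As a rough illustration of that mapping (a hypothetical standalone helper, not the JAX-internal function), with values matching the assertions in the new tests:

# Illustrative sketch only: how axis_index_groups is expected to translate into
# the replica_groups attribute visible in the lowered StableHLO module.
def illustrative_replica_groups(axis_size, axis_index_groups=None):
  if axis_index_groups is None:
    return [list(range(axis_size))]   # a single group spanning the whole axis
  return [list(group) for group in axis_index_groups]

print(illustrative_replica_groups(2))                    # [[0, 1]]  -> dense<[[0, 1]]> : tensor<1x2xi64>
print(illustrative_replica_groups(4, [[0, 1], [2, 3]]))  # [[0, 1], [2, 3]] -> dense<[[0, 1], [2, 3]]> : tensor<2x2xi64>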
tests/BUILD (+21)
@@ -1377,6 +1377,27 @@ jax_multiplatform_test(
     },
 )

+jax_multiplatform_test(
+    name = "ragged_collective_test",
+    srcs = ["ragged_collective_test.py"],
+    disable_configs = [
+        "tpu_pjrt_c_api",
+    ],
+    enable_backends = [
+        "gpu",
+        "tpu",
+    ],
+    enable_configs = [
+        "gpu_p100x2_shardy",
+    ],
+    tags = [
+        "multiaccelerator",
+    ],
+    deps = [
+        "//jax:experimental",
+    ],
+)
+
 jax_multiplatform_test(
     name = "shard_map_test",
     srcs = ["shard_map_test.py"],
@@ -4720,23 +4720,6 @@ class CompositeTest(jtu.JaxTestCase):

 class RaggedTest(jtu.JaxTestCase):

-  def testRaggedAllToAll(self):
-    operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32)
-    output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32)
-    input_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
-    send_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
-    output_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
-    recv_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
-    mlir_module = jax.jit(lax.ragged_all_to_all).lower(
-        operand, output, input_offsets, send_sizes, output_offsets,
-        recv_sizes).as_text()
-    self.assertIn("stablehlo.custom_call @ragged_all_to_all", mlir_module)
-    self.assertIn(
-        "backend_config = {replica_groups = dense<[[0, 1, 2]]> :"
-        " tensor<1x3xi64>}}",
-        mlir_module,
-    )
-
   def testRaggedAllToAllErrors(self):
     operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32)
     output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32)
@@ -4744,117 +4727,111 @@ class RaggedTest(jtu.JaxTestCase):
     send_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
     output_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
     recv_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)

-    with self.assertRaisesWithLiteralMatch(
-        ValueError,
-        "ragged_all_to_all input and output shapes must be equal, except for"
-        " the outermost dimension.",
-    ):
-      jax.jit(lax.ragged_all_to_all).lower(
-          operand,
-          jnp.array([[0.0], [0.0], [0.0], [0.0], [0.0]], dtype=jnp.float32),
-          input_offsets, send_sizes, output_offsets, recv_sizes)
+    axis_name = "x"

     with self.assertRaisesWithLiteralMatch(
         ValueError, "ragged_all_to_all input_offsets must be integer type."
     ):
-      jax.jit(lax.ragged_all_to_all).lower(
+      jax.jit(lax.ragged_all_to_all, static_argnames='axis_name').lower(
           operand, output, jnp.array([0.0, 1.0, 3.0], dtype=jnp.float32),
-          send_sizes, output_offsets, recv_sizes)
+          send_sizes, output_offsets, recv_sizes, axis_name=axis_name)

     with self.assertRaisesWithLiteralMatch(
         ValueError, "ragged_all_to_all send_sizes must be integer type."
     ):
-      jax.jit(lax.ragged_all_to_all).lower(
+      jax.jit(lax.ragged_all_to_all, static_argnames='axis_name').lower(
           operand, output, input_offsets,
           jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32), output_offsets,
-          recv_sizes)
+          recv_sizes, axis_name=axis_name)

     with self.assertRaisesWithLiteralMatch(
         ValueError, "ragged_all_to_all output_offsets must be integer type."
     ):
-      jax.jit(lax.ragged_all_to_all).lower(
+      jax.jit(lax.ragged_all_to_all, static_argnames='axis_name').lower(
           operand, output, input_offsets, send_sizes,
-          jnp.array([0.0, 1.0, 3.0], dtype=jnp.float32), recv_sizes)
+          jnp.array([0.0, 1.0, 3.0], dtype=jnp.float32), recv_sizes,
+          axis_name=axis_name)

     with self.assertRaisesWithLiteralMatch(
         ValueError, "ragged_all_to_all recv_sizes must be integer type."
     ):
-      jax.jit(lax.ragged_all_to_all).lower(
+      jax.jit(lax.ragged_all_to_all, static_argnames='axis_name').lower(
           operand, output, input_offsets, send_sizes, output_offsets,
-          jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32))
+          jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32), axis_name=axis_name)

     with self.assertRaisesWithLiteralMatch(
         ValueError,
         "ragged_all_to_all input_offsets must be rank 1 with positive dimension"
         " size, but got shape (1, 3)",
     ):
-      jax.jit(lax.ragged_all_to_all).lower(
+      jax.jit(lax.ragged_all_to_all, static_argnames='axis_name').lower(
           operand, output, jnp.array([[0, 1, 3]], dtype=jnp.int32), send_sizes,
-          output_offsets, recv_sizes)
+          output_offsets, recv_sizes, axis_name=axis_name)

     with self.assertRaisesWithLiteralMatch(
         ValueError,
         "ragged_all_to_all input_offsets must be rank 1 with positive dimension"
         " size, but got shape (0,)",
     ):
-      jax.jit(lax.ragged_all_to_all).lower(
+      jax.jit(lax.ragged_all_to_all, static_argnames='axis_name').lower(
           operand, output, jnp.array([], dtype=jnp.int32), send_sizes,
-          output_offsets, recv_sizes)
+          output_offsets, recv_sizes, axis_name=axis_name)

     with self.assertRaisesWithLiteralMatch(
         ValueError,
         "ragged_all_to_all send_sizes must be rank 1 with positive dimension"
         " size, but got shape (1, 3)",
     ):
-      jax.jit(lax.ragged_all_to_all).lower(
+      jax.jit(lax.ragged_all_to_all, static_argnames='axis_name').lower(
           operand, output, input_offsets,
-          jnp.array([[1, 2, 3]], dtype=jnp.int32), output_offsets, recv_sizes)
+          jnp.array([[1, 2, 3]], dtype=jnp.int32), output_offsets, recv_sizes,
+          axis_name=axis_name)

     with self.assertRaisesWithLiteralMatch(
         ValueError,
         "ragged_all_to_all send_sizes must be rank 1 with positive dimension"
         " size, but got shape (0,)",
     ):
-      jax.jit(lax.ragged_all_to_all).lower(
+      jax.jit(lax.ragged_all_to_all, static_argnames='axis_name').lower(
           operand, output, input_offsets, jnp.array([], dtype=jnp.int32),
-          output_offsets, recv_sizes)
+          output_offsets, recv_sizes, axis_name=axis_name)

     with self.assertRaisesWithLiteralMatch(
         ValueError,
         "ragged_all_to_all output_offsets must be rank 1 with positive"
         " dimension size, but got shape (1, 3)",
     ):
-      jax.jit(lax.ragged_all_to_all).lower(
+      jax.jit(lax.ragged_all_to_all, static_argnames='axis_name').lower(
           operand, output, input_offsets, send_sizes,
-          jnp.array([[0, 1, 3]], dtype=jnp.int32), recv_sizes)
+          jnp.array([[0, 1, 3]], dtype=jnp.int32), recv_sizes,
+          axis_name=axis_name)

     with self.assertRaisesWithLiteralMatch(
         ValueError,
         "ragged_all_to_all output_offsets must be rank 1 with positive"
         " dimension size, but got shape (0,)",
     ):
-      jax.jit(lax.ragged_all_to_all).lower(
+      jax.jit(lax.ragged_all_to_all, static_argnames='axis_name').lower(
           operand, output, input_offsets, send_sizes,
-          jnp.array([], dtype=jnp.int32), recv_sizes)
+          jnp.array([], dtype=jnp.int32), recv_sizes, axis_name=axis_name)

     with self.assertRaisesWithLiteralMatch(
         ValueError,
         "ragged_all_to_all recv_sizes must be rank 1 with positive dimension"
         " size, but got shape (1, 3)",
     ):
-      jax.jit(lax.ragged_all_to_all).lower(
+      jax.jit(lax.ragged_all_to_all, static_argnames='axis_name').lower(
           operand, output, input_offsets, send_sizes, output_offsets,
-          jnp.array([[1, 2, 3]], dtype=jnp.int32))
+          jnp.array([[1, 2, 3]], dtype=jnp.int32), axis_name=axis_name)

     with self.assertRaisesWithLiteralMatch(
         ValueError,
         "ragged_all_to_all recv_sizes must be rank 1 with positive dimension"
         " size, but got shape (0,)",
     ):
-      jax.jit(lax.ragged_all_to_all).lower(
+      jax.jit(lax.ragged_all_to_all, static_argnames='axis_name').lower(
           operand, output, input_offsets, send_sizes, output_offsets,
-          jnp.array([], dtype=jnp.int32))
+          jnp.array([], dtype=jnp.int32), axis_name=axis_name)

   @jtu.sample_product(
       [
tests/ragged_collective_test.py (new file, +219)
@@ -0,0 +1,219 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
import unittest

from absl.testing import absltest
from absl.testing import parameterized

import jax
import jax.ad_checkpoint
from jax import lax
from jax.sharding import PartitionSpec as P
from jax._src import config
from jax._src import test_util as jtu
import jax.numpy as jnp

from jax.experimental.shard_map import shard_map


config.parse_flags_with_absl()


class RaggedCollectiveTest(jtu.JaxTestCase):

  @parameterized.named_parameters(
      dict(
          testcase_name='_single_axis_name', axis_name='x', mesh_axes=dict(x=2)
      ),
  )
  def test_ragged_all_to_all(self, axis_name, mesh_axes):
    device_type = jax.devices()[0].platform
    if device_type == 'tpu' and jtu.get_tpu_version() == 3:
      raise unittest.SkipTest(
          'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by'
          ' TPU v3'
      )
    mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys()))
    operand = jax.device_put(
        jnp.array([[1, 2, 2], [3, 4, 0]], dtype=jnp.int32),
        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
    )
    output = jax.device_put(
        jnp.zeros((2, 4), dtype=jnp.int32),
        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
    )
    input_offsets = jax.device_put(
        jnp.array([[0, 1], [0, 1]], dtype=jnp.int32),
        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
    )
    send_sizes = jax.device_put(
        jnp.array([[1, 2], [1, 1]], dtype=jnp.int32),
        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
    )
    output_offsets = jax.device_put(
        jnp.array([[0, 0], [1, 2]], dtype=jnp.int32),
        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
    )
    recv_sizes = jax.device_put(
        jnp.array([[1, 1], [2, 1]], dtype=jnp.int32),
        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
    )

    @jax.jit
    @partial(
        shard_map,
        mesh=mesh,
        in_specs=(
            P(axis_name, None),
            P(axis_name, None),
            P(axis_name, None),
            P(axis_name, None),
            P(axis_name, None),
            P(axis_name, None),
        ),
        out_specs=P(axis_name),
        check_rep=False,
    )
    def fwd(
        operand, output, input_offsets, send_sizes, output_offsets, recv_sizes
    ):
      operand = operand.reshape(operand.shape[1:])
      output = output.reshape(output.shape[1:])
      input_offsets = input_offsets.reshape(input_offsets.shape[1:])
      send_sizes = send_sizes.reshape(send_sizes.shape[1:])
      output_offsets = output_offsets.reshape(output_offsets.shape[1:])
      recv_sizes = recv_sizes.reshape(recv_sizes.shape[1:])
      return lax.ragged_all_to_all(
          operand,
          output,
          input_offsets,
          send_sizes,
          output_offsets,
          recv_sizes,
          axis_name=axis_name,
      )

    mlir_module = fwd.lower(
        operand, output, input_offsets, send_sizes, output_offsets, recv_sizes
    ).as_text()
    self.assertIn('stablehlo.custom_call @ragged_all_to_all', mlir_module)
    self.assertIn(
        'replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>', mlir_module
    )

    c = fwd(
        operand, output, input_offsets, send_sizes, output_offsets, recv_sizes
    ).reshape((2, 4))
    self.assertAllClose(
        c, jnp.array([[1, 3, 0, 0], [2, 2, 4, 0]], dtype=jnp.int32)
    )

  @parameterized.named_parameters(
      dict(
          testcase_name='_single_axis_name', axis_name='x', mesh_axes=dict(x=4)
      ),
  )
  def test_ragged_all_to_all_axis_index_groups(self, axis_name, mesh_axes):
    device_type = jax.devices()[0].platform
    if device_type == 'tpu' and jtu.get_tpu_version() == 3:
      raise unittest.SkipTest(
          'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by'
          ' TPU v3'
      )
    mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys()))
    operand = jax.device_put(
        jnp.array([[1, 2, 2], [3, 4, 0],
                   [10, 20, 20], [30, 40, 0]], dtype=jnp.int32),
        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
    )
    output = jax.device_put(
        jnp.zeros((4, 4), dtype=jnp.int32),
        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
    )
    input_offsets = jax.device_put(
        jnp.array([[0, 1], [0, 1],
                   [0, 1], [0, 1]], dtype=jnp.int32),
        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
    )
    send_sizes = jax.device_put(
        jnp.array([[1, 2], [1, 1],
                   [1, 2], [1, 1]], dtype=jnp.int32),
        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
    )
    output_offsets = jax.device_put(
        jnp.array([[0, 0], [1, 2],
                   [0, 0], [1, 2]], dtype=jnp.int32),
        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
    )
    recv_sizes = jax.device_put(
        jnp.array([[1, 1], [2, 1],
                   [1, 1], [2, 1]], dtype=jnp.int32),
        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
    )
    axis_index_groups = ((0, 1), (2, 3))

    @jax.jit
    @partial(
        shard_map,
        mesh=mesh,
        in_specs=(
            P(axis_name, None),
            P(axis_name, None),
            P(axis_name, None),
            P(axis_name, None),
            P(axis_name, None),
            P(axis_name, None),
        ),
        out_specs=P(axis_name),
        check_rep=False,
    )
    def fwd(
        operand, output, input_offsets, send_sizes, output_offsets, recv_sizes
    ):
      operand = operand.reshape(operand.shape[1:])
      output = output.reshape(output.shape[1:])
      input_offsets = input_offsets.reshape(input_offsets.shape[1:])
      send_sizes = send_sizes.reshape(send_sizes.shape[1:])
      output_offsets = output_offsets.reshape(output_offsets.shape[1:])
      recv_sizes = recv_sizes.reshape(recv_sizes.shape[1:])
      return lax.ragged_all_to_all(
          operand,
          output,
          input_offsets,
          send_sizes,
          output_offsets,
          recv_sizes,
          axis_name=axis_name,
          axis_index_groups=axis_index_groups,
      )

    mlir_module = fwd.lower(
        operand, output, input_offsets, send_sizes, output_offsets,
        recv_sizes).as_text()
    self.assertIn("stablehlo.custom_call @ragged_all_to_all", mlir_module)
    self.assertIn("replica_groups = dense<[[0, 1], [2, 3]]> :"
                  " tensor<2x2xi64>", mlir_module)

    c = fwd(
        operand, output, input_offsets, send_sizes, output_offsets, recv_sizes
    ).reshape((4, 4))
    self.assertAllClose(
        c, jnp.array([[1, 3, 0, 0], [2, 2, 4, 0],
                      [10, 30, 0, 0], [20, 20, 40, 0]], dtype=jnp.int32)
    )


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())
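As a sanity check on the test data above, the expected output of test_ragged_all_to_all can be reproduced with a small pure-NumPy reference. This helper is not part of JAX; it is a sketch of the per-shard semantics as documented (shard i sends operand[input_offsets[i][j] : input_offsets[i][j] + send_sizes[i][j]] to shard j, which writes it at output_offsets[i][j] in its own output buffer):

import numpy as np

# Hypothetical reference of the per-shard data movement, used only to
# hand-check the expected values asserted in test_ragged_all_to_all.
def ragged_all_to_all_reference(operand, output, input_offsets, send_sizes,
                                output_offsets, recv_sizes):
  del recv_sizes  # mirrors send_sizes from the receiver's side; unused here
  out = [np.array(buf) for buf in output]
  n = len(operand)
  for i in range(n):        # sending shard
    for j in range(n):      # receiving shard
      start, size = input_offsets[i][j], send_sizes[i][j]
      dst = output_offsets[i][j]
      out[j][dst:dst + size] = operand[i][start:start + size]
  return out

result = ragged_all_to_all_reference(
    operand=[[1, 2, 2], [3, 4, 0]],
    output=[[0, 0, 0, 0], [0, 0, 0, 0]],
    input_offsets=[[0, 1], [0, 1]], send_sizes=[[1, 2], [1, 1]],
    output_offsets=[[0, 0], [1, 2]], recv_sizes=[[1, 1], [2, 1]])
assert [r.tolist() for r in result] == [[1, 3, 0, 0], [2, 2, 4, 0]]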