Move ragged tests under a new class.

PiperOrigin-RevId: 708811348
Gunhyun Park 2024-12-22 07:49:43 -08:00 committed by jax authors
parent 1719986aaa
commit 38747a7a5d

@@ -1385,83 +1385,6 @@ class LaxTest(jtu.JaxTestCase):
    numpy_op = lambda x, y: lax_reference.dot_general(x, y, dimension_numbers)
    self._CheckAgainstNumpy(numpy_op, op, args_maker)

  def testRaggedAllToAllErrors(self):
    operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32)
    output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32)
    input_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
    send_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
    output_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
    recv_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
    with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all input and output shapes must be equal, except for the outermost dimension."):
      jax.jit(lax.ragged_all_to_all).lower(operand, jnp.array([[0.0], [0.0], [0.0], [0.0], [0.0]], dtype=jnp.float32), input_offsets, send_sizes, output_offsets, recv_sizes)
    with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all input_offsets must be integer type."):
      jax.jit(lax.ragged_all_to_all).lower(operand, output, jnp.array([0.0, 1.0, 3.0], dtype=jnp.float32), send_sizes, output_offsets, recv_sizes)
    with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all send_sizes must be integer type."):
      jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32), output_offsets, recv_sizes)
    with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all output_offsets must be integer type."):
      jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, jnp.array([0.0, 1.0, 3.0], dtype=jnp.float32), recv_sizes)
    with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all recv_sizes must be integer type."):
      jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, output_offsets, jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32))
    with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all input_offsets must be rank 1 with positive dimension size, but got shape (1, 3)"):
      jax.jit(lax.ragged_all_to_all).lower(operand, output, jnp.array([[0, 1, 3]], dtype=jnp.int32), send_sizes, output_offsets, recv_sizes)
    with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all input_offsets must be rank 1 with positive dimension size, but got shape (0,)"):
      jax.jit(lax.ragged_all_to_all).lower(operand, output, jnp.array([], dtype=jnp.int32), send_sizes, output_offsets, recv_sizes)
    with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all send_sizes must be rank 1 with positive dimension size, but got shape (1, 3)"):
      jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, jnp.array([[1, 2, 3]], dtype=jnp.int32), output_offsets, recv_sizes)
    with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all send_sizes must be rank 1 with positive dimension size, but got shape (0,)"):
      jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, jnp.array([], dtype=jnp.int32), output_offsets, recv_sizes)
    with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all output_offsets must be rank 1 with positive dimension size, but got shape (1, 3)"):
      jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, jnp.array([[0, 1, 3]], dtype=jnp.int32), recv_sizes)
    with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all output_offsets must be rank 1 with positive dimension size, but got shape (0,)"):
      jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, jnp.array([], dtype=jnp.int32), recv_sizes)
    with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all recv_sizes must be rank 1 with positive dimension size, but got shape (1, 3)"):
      jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, output_offsets, jnp.array([[1, 2, 3]], dtype=jnp.int32))
    with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all recv_sizes must be rank 1 with positive dimension size, but got shape (0,)"):
      jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, output_offsets, jnp.array([], dtype=jnp.int32))

  def testRaggedAllToAll(self):
    operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32)
    output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32)
    input_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
    send_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
    output_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
    recv_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
    mlir_module = jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes).as_text()
    self.assertIn("stablehlo.custom_call @ragged_all_to_all", mlir_module)
    self.assertIn(
        "backend_config = {replica_groups = dense<[[0, 1, 2]]> :"
        " tensor<1x3xi64>}}",
        mlir_module,
    )

  @jtu.sample_product(
      [
          {'m': 5, 'k': 4, 'n': 3, 'num_groups': 1},
          {'m': 10, 'k': 9, 'n': 8, 'num_groups': 2},
      ],
      dtype=jtu.dtypes.numeric,
  )
  def testRaggedDot(self, m, k, n, num_groups, dtype):
    """Tests ragged_dot.

    The ragged_dot is tested against numpy reference implementation, and by running JAX compilation.

    Raises:
      SkipTest: in the case dtype is not supported.
    """
    lhs_shape = (m, k)
    rhs_shape = (num_groups, k, n)

    def group_sizes(m, num_groups):
      ends_no_final = jnp.sort(self.rng().choice(m, size=num_groups - 1))
      ends = jnp.concatenate([ends_no_final, jnp.array([m], dtype=ends_no_final.dtype)])
      starts = jnp.concatenate([jnp.zeros(1, dtype=ends_no_final.dtype), ends_no_final])
      return ends - starts

    rng = jtu.rand_small(self.rng())
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype), group_sizes(m, num_groups)]
    self._CompileAndCheck(lax.ragged_dot, args_maker)
    self._CheckAgainstNumpy(lax_reference.ragged_dot, lax.ragged_dot, args_maker)

  @jtu.sample_product(
      shape=[(), (2, 3)],
      dtype=lax_test_util.default_dtypes,
@@ -4757,5 +4680,181 @@ class CompositeTest(jtu.JaxTestCase):
    ):
      grad(my_square)(1.0)


class RaggedTest(jtu.JaxTestCase):

  def testRaggedAllToAll(self):
    operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32)
    output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32)
    input_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
    send_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
    output_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
    recv_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
    mlir_module = jax.jit(lax.ragged_all_to_all).lower(
        operand, output, input_offsets, send_sizes, output_offsets,
        recv_sizes).as_text()
    self.assertIn("stablehlo.custom_call @ragged_all_to_all", mlir_module)
    self.assertIn(
        "backend_config = {replica_groups = dense<[[0, 1, 2]]> :"
        " tensor<1x3xi64>}}",
        mlir_module,
    )

  def testRaggedAllToAllErrors(self):
    operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32)
    output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32)
    input_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
    send_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
    output_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
    recv_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        "ragged_all_to_all input and output shapes must be equal, except for"
        " the outermost dimension.",
    ):
      jax.jit(lax.ragged_all_to_all).lower(
          operand,
          jnp.array([[0.0], [0.0], [0.0], [0.0], [0.0]], dtype=jnp.float32),
          input_offsets, send_sizes, output_offsets, recv_sizes)
    with self.assertRaisesWithLiteralMatch(
        ValueError, "ragged_all_to_all input_offsets must be integer type."
    ):
      jax.jit(lax.ragged_all_to_all).lower(
          operand, output, jnp.array([0.0, 1.0, 3.0], dtype=jnp.float32),
          send_sizes, output_offsets, recv_sizes)
    with self.assertRaisesWithLiteralMatch(
        ValueError, "ragged_all_to_all send_sizes must be integer type."
    ):
      jax.jit(lax.ragged_all_to_all).lower(
          operand, output, input_offsets,
          jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32), output_offsets,
          recv_sizes)
    with self.assertRaisesWithLiteralMatch(
        ValueError, "ragged_all_to_all output_offsets must be integer type."
    ):
      jax.jit(lax.ragged_all_to_all).lower(
          operand, output, input_offsets, send_sizes,
          jnp.array([0.0, 1.0, 3.0], dtype=jnp.float32), recv_sizes)
    with self.assertRaisesWithLiteralMatch(
        ValueError, "ragged_all_to_all recv_sizes must be integer type."
    ):
      jax.jit(lax.ragged_all_to_all).lower(
          operand, output, input_offsets, send_sizes, output_offsets,
          jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32))
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        "ragged_all_to_all input_offsets must be rank 1 with positive dimension"
        " size, but got shape (1, 3)",
    ):
      jax.jit(lax.ragged_all_to_all).lower(
          operand, output, jnp.array([[0, 1, 3]], dtype=jnp.int32), send_sizes,
          output_offsets, recv_sizes)
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        "ragged_all_to_all input_offsets must be rank 1 with positive dimension"
        " size, but got shape (0,)",
    ):
      jax.jit(lax.ragged_all_to_all).lower(
          operand, output, jnp.array([], dtype=jnp.int32), send_sizes,
          output_offsets, recv_sizes)
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        "ragged_all_to_all send_sizes must be rank 1 with positive dimension"
        " size, but got shape (1, 3)",
    ):
      jax.jit(lax.ragged_all_to_all).lower(
          operand, output, input_offsets,
          jnp.array([[1, 2, 3]], dtype=jnp.int32), output_offsets, recv_sizes)
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        "ragged_all_to_all send_sizes must be rank 1 with positive dimension"
        " size, but got shape (0,)",
    ):
      jax.jit(lax.ragged_all_to_all).lower(
          operand, output, input_offsets, jnp.array([], dtype=jnp.int32),
          output_offsets, recv_sizes)
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        "ragged_all_to_all output_offsets must be rank 1 with positive"
        " dimension size, but got shape (1, 3)",
    ):
      jax.jit(lax.ragged_all_to_all).lower(
          operand, output, input_offsets, send_sizes,
          jnp.array([[0, 1, 3]], dtype=jnp.int32), recv_sizes)
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        "ragged_all_to_all output_offsets must be rank 1 with positive"
        " dimension size, but got shape (0,)",
    ):
      jax.jit(lax.ragged_all_to_all).lower(
          operand, output, input_offsets, send_sizes,
          jnp.array([], dtype=jnp.int32), recv_sizes)
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        "ragged_all_to_all recv_sizes must be rank 1 with positive dimension"
        " size, but got shape (1, 3)",
    ):
      jax.jit(lax.ragged_all_to_all).lower(
          operand, output, input_offsets, send_sizes, output_offsets,
          jnp.array([[1, 2, 3]], dtype=jnp.int32))
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        "ragged_all_to_all recv_sizes must be rank 1 with positive dimension"
        " size, but got shape (0,)",
    ):
      jax.jit(lax.ragged_all_to_all).lower(
          operand, output, input_offsets, send_sizes, output_offsets,
          jnp.array([], dtype=jnp.int32))

  @jtu.sample_product(
      [
          {'m': 5, 'k': 4, 'n': 3, 'num_groups': 1},
          {'m': 10, 'k': 9, 'n': 8, 'num_groups': 2},
      ],
      dtype=jtu.dtypes.numeric,
  )
  def testRaggedDot(self, m, k, n, num_groups, dtype):
    """Tests ragged_dot.

    The ragged_dot is tested against numpy reference implementation, and by
    running JAX compilation.

    Raises:
      SkipTest: in the case dtype is not supported.
    """
    lhs_shape = (m, k)
    rhs_shape = (num_groups, k, n)

    def group_sizes(m, num_groups):
      ends_no_final = jnp.sort(self.rng().choice(m, size=num_groups - 1))
      ends = jnp.concatenate(
          [ends_no_final, jnp.array([m], dtype=ends_no_final.dtype)])
      starts = jnp.concatenate(
          [jnp.zeros(1, dtype=ends_no_final.dtype), ends_no_final])
      return ends - starts

    rng = jtu.rand_small(self.rng())
    args_maker = lambda: [
        rng(lhs_shape, dtype),
        rng(rhs_shape, dtype),
        group_sizes(m, num_groups),
    ]
    self._CompileAndCheck(lax.ragged_dot, args_maker)
    self._CheckAgainstNumpy(
        lax_reference.ragged_dot, lax.ragged_dot, args_maker)


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())
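
For readers unfamiliar with the operation the moved testRaggedDot exercises: lax.ragged_dot multiplies each contiguous block of lhs rows, whose length is group_sizes[g], by its own rhs[g] slice, and the test compares the result against lax_reference.ragged_dot. Below is a minimal illustrative sketch of that same contract against a plain-numpy loop; it is not part of the commit, the helper name ragged_dot_reference and the concrete inputs are made up for illustration, and it assumes a JAX version that exposes jax.lax.ragged_dot.

# Illustrative sketch only; not part of the commit above.
import numpy as np
import jax.numpy as jnp
from jax import lax

def ragged_dot_reference(lhs, rhs, group_sizes):
  # lhs: (m, k); rhs: (num_groups, k, n); group_sizes: (num_groups,), summing to m.
  out = np.zeros((lhs.shape[0], rhs.shape[-1]), dtype=lhs.dtype)
  start = 0
  for g, size in enumerate(group_sizes):
    # Rows [start, start + size) all use the g-th rhs matrix.
    out[start:start + size] = lhs[start:start + size] @ rhs[g]
    start += size
  return out

lhs = np.arange(20, dtype=np.float32).reshape(5, 4)
rhs = np.ones((2, 4, 3), dtype=np.float32)
group_sizes = np.array([2, 3], dtype=np.int32)  # first 2 rows use rhs[0], last 3 use rhs[1]

np.testing.assert_allclose(
    lax.ragged_dot(jnp.asarray(lhs), jnp.asarray(rhs), jnp.asarray(group_sizes)),
    ragged_dot_reference(lhs, rhs, group_sizes))

The group_sizes helper in the test builds such a vector at random: it sorts num_groups - 1 cut points drawn from range(m), prepends 0 and appends m, and takes consecutive differences, so the sizes are non-negative and always sum to m.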