mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Move ragged tests under a new class.
PiperOrigin-RevId: 708811348
This commit is contained in:
parent
1719986aaa
commit
38747a7a5d
@ -1385,83 +1385,6 @@ class LaxTest(jtu.JaxTestCase):
|
||||
numpy_op = lambda x, y: lax_reference.dot_general(x, y, dimension_numbers)
|
||||
self._CheckAgainstNumpy(numpy_op, op, args_maker)
|
||||
|
||||
def testRaggedAllToAllErrors(self):
|
||||
operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32)
|
||||
output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32)
|
||||
input_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
|
||||
send_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
|
||||
output_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
|
||||
recv_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all input and output shapes must be equal, except for the outermost dimension."):
|
||||
jax.jit(lax.ragged_all_to_all).lower(operand, jnp.array([[0.0], [0.0], [0.0], [0.0], [0.0]], dtype=jnp.float32), input_offsets, send_sizes, output_offsets, recv_sizes)
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all input_offsets must be integer type."):
|
||||
jax.jit(lax.ragged_all_to_all).lower(operand, output, jnp.array([0.0, 1.0, 3.0], dtype=jnp.float32), send_sizes, output_offsets, recv_sizes)
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all send_sizes must be integer type."):
|
||||
jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32), output_offsets, recv_sizes)
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all output_offsets must be integer type."):
|
||||
jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, jnp.array([0.0, 1.0, 3.0], dtype=jnp.float32), recv_sizes)
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all recv_sizes must be integer type."):
|
||||
jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, output_offsets, jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32))
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all input_offsets must be rank 1 with positive dimension size, but got shape (1, 3)"):
|
||||
jax.jit(lax.ragged_all_to_all).lower(operand, output, jnp.array([[0, 1, 3]], dtype=jnp.int32), send_sizes, output_offsets, recv_sizes)
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all input_offsets must be rank 1 with positive dimension size, but got shape (0,)"):
|
||||
jax.jit(lax.ragged_all_to_all).lower(operand, output, jnp.array([], dtype=jnp.int32), send_sizes, output_offsets, recv_sizes)
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all send_sizes must be rank 1 with positive dimension size, but got shape (1, 3)"):
|
||||
jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, jnp.array([[1, 2, 3]], dtype=jnp.int32), output_offsets, recv_sizes)
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all send_sizes must be rank 1 with positive dimension size, but got shape (0,)"):
|
||||
jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, jnp.array([], dtype=jnp.int32), output_offsets, recv_sizes)
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all output_offsets must be rank 1 with positive dimension size, but got shape (1, 3)"):
|
||||
jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, jnp.array([[0, 1, 3]], dtype=jnp.int32), recv_sizes)
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all output_offsets must be rank 1 with positive dimension size, but got shape (0,)"):
|
||||
jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, jnp.array([], dtype=jnp.int32), recv_sizes)
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all recv_sizes must be rank 1 with positive dimension size, but got shape (1, 3)"):
|
||||
jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, output_offsets, jnp.array([[1, 2, 3]], dtype=jnp.int32))
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all recv_sizes must be rank 1 with positive dimension size, but got shape (0,)"):
|
||||
jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, output_offsets, jnp.array([], dtype=jnp.int32))
|
||||
|
||||
def testRaggedAllToAll(self):
|
||||
operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32)
|
||||
output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32)
|
||||
input_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
|
||||
send_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
|
||||
output_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
|
||||
recv_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
|
||||
mlir_module = jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes).as_text()
|
||||
self.assertIn("stablehlo.custom_call @ragged_all_to_all", mlir_module)
|
||||
self.assertIn(
|
||||
"backend_config = {replica_groups = dense<[[0, 1, 2]]> :"
|
||||
" tensor<1x3xi64>}}",
|
||||
mlir_module,
|
||||
)
|
||||
|
||||
@jtu.sample_product(
|
||||
[
|
||||
{'m': 5, 'k': 4, 'n': 3, 'num_groups': 1},
|
||||
{'m': 10, 'k': 9, 'n': 8, 'num_groups': 2},
|
||||
],
|
||||
dtype=jtu.dtypes.numeric,
|
||||
)
|
||||
def testRaggedDot(self, m, k, n, num_groups, dtype):
|
||||
"""Tests ragged_dot.
|
||||
|
||||
The ragged_dot is tested against numpy reference implementation, and by running JAX compilation.
|
||||
|
||||
Raises:
|
||||
SkipTest: in the case dtype is not supported.
|
||||
"""
|
||||
lhs_shape = (m, k)
|
||||
rhs_shape = (num_groups, k, n)
|
||||
def group_sizes(m, num_groups):
|
||||
ends_no_final = jnp.sort(self.rng().choice(m, size=num_groups - 1))
|
||||
ends = jnp.concatenate([ends_no_final, jnp.array([m], dtype=ends_no_final.dtype)])
|
||||
starts = jnp.concatenate([jnp.zeros(1, dtype=ends_no_final.dtype), ends_no_final])
|
||||
return ends - starts
|
||||
rng = jtu.rand_small(self.rng())
|
||||
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype), group_sizes(m, num_groups)]
|
||||
self._CompileAndCheck(lax.ragged_dot, args_maker)
|
||||
self._CheckAgainstNumpy(lax_reference.ragged_dot, lax.ragged_dot, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[(), (2, 3)],
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
@ -4757,5 +4680,181 @@ class CompositeTest(jtu.JaxTestCase):
|
||||
):
|
||||
grad(my_square)(1.0)
|
||||
|
||||
|
||||
class RaggedTest(jtu.JaxTestCase):
|
||||
|
||||
def testRaggedAllToAll(self):
|
||||
operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32)
|
||||
output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32)
|
||||
input_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
|
||||
send_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
|
||||
output_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
|
||||
recv_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
|
||||
mlir_module = jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand, output, input_offsets, send_sizes, output_offsets,
|
||||
recv_sizes).as_text()
|
||||
self.assertIn("stablehlo.custom_call @ragged_all_to_all", mlir_module)
|
||||
self.assertIn(
|
||||
"backend_config = {replica_groups = dense<[[0, 1, 2]]> :"
|
||||
" tensor<1x3xi64>}}",
|
||||
mlir_module,
|
||||
)
|
||||
|
||||
def testRaggedAllToAllErrors(self):
|
||||
operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32)
|
||||
output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32)
|
||||
input_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
|
||||
send_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
|
||||
output_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
|
||||
recv_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
"ragged_all_to_all input and output shapes must be equal, except for"
|
||||
" the outermost dimension.",
|
||||
):
|
||||
jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand,
|
||||
jnp.array([[0.0], [0.0], [0.0], [0.0], [0.0]], dtype=jnp.float32),
|
||||
input_offsets, send_sizes, output_offsets, recv_sizes)
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError, "ragged_all_to_all input_offsets must be integer type."
|
||||
):
|
||||
jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand, output, jnp.array([0.0, 1.0, 3.0], dtype=jnp.float32),
|
||||
send_sizes, output_offsets, recv_sizes)
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError, "ragged_all_to_all send_sizes must be integer type."
|
||||
):
|
||||
jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand, output, input_offsets,
|
||||
jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32), output_offsets,
|
||||
recv_sizes)
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError, "ragged_all_to_all output_offsets must be integer type."
|
||||
):
|
||||
jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand, output, input_offsets, send_sizes,
|
||||
jnp.array([0.0, 1.0, 3.0], dtype=jnp.float32), recv_sizes)
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError, "ragged_all_to_all recv_sizes must be integer type."
|
||||
):
|
||||
jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand, output, input_offsets, send_sizes, output_offsets,
|
||||
jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32))
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
"ragged_all_to_all input_offsets must be rank 1 with positive dimension"
|
||||
" size, but got shape (1, 3)",
|
||||
):
|
||||
jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand, output, jnp.array([[0, 1, 3]], dtype=jnp.int32), send_sizes,
|
||||
output_offsets, recv_sizes)
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
"ragged_all_to_all input_offsets must be rank 1 with positive dimension"
|
||||
" size, but got shape (0,)",
|
||||
):
|
||||
jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand, output, jnp.array([], dtype=jnp.int32), send_sizes,
|
||||
output_offsets, recv_sizes)
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
"ragged_all_to_all send_sizes must be rank 1 with positive dimension"
|
||||
" size, but got shape (1, 3)",
|
||||
):
|
||||
jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand, output, input_offsets,
|
||||
jnp.array([[1, 2, 3]], dtype=jnp.int32), output_offsets, recv_sizes)
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
"ragged_all_to_all send_sizes must be rank 1 with positive dimension"
|
||||
" size, but got shape (0,)",
|
||||
):
|
||||
jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand, output, input_offsets, jnp.array([], dtype=jnp.int32),
|
||||
output_offsets, recv_sizes)
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
"ragged_all_to_all output_offsets must be rank 1 with positive"
|
||||
" dimension size, but got shape (1, 3)",
|
||||
):
|
||||
jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand, output, input_offsets, send_sizes,
|
||||
jnp.array([[0, 1, 3]], dtype=jnp.int32), recv_sizes)
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
"ragged_all_to_all output_offsets must be rank 1 with positive"
|
||||
" dimension size, but got shape (0,)",
|
||||
):
|
||||
jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand, output, input_offsets, send_sizes,
|
||||
jnp.array([], dtype=jnp.int32), recv_sizes)
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
"ragged_all_to_all recv_sizes must be rank 1 with positive dimension"
|
||||
" size, but got shape (1, 3)",
|
||||
):
|
||||
jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand, output, input_offsets, send_sizes, output_offsets,
|
||||
jnp.array([[1, 2, 3]], dtype=jnp.int32))
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
"ragged_all_to_all recv_sizes must be rank 1 with positive dimension"
|
||||
" size, but got shape (0,)",
|
||||
):
|
||||
jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand, output, input_offsets, send_sizes, output_offsets,
|
||||
jnp.array([], dtype=jnp.int32))
|
||||
|
||||
@jtu.sample_product(
|
||||
[
|
||||
{'m': 5, 'k': 4, 'n': 3, 'num_groups': 1},
|
||||
{'m': 10, 'k': 9, 'n': 8, 'num_groups': 2},
|
||||
],
|
||||
dtype=jtu.dtypes.numeric,
|
||||
)
|
||||
def testRaggedDot(self, m, k, n, num_groups, dtype):
|
||||
"""Tests ragged_dot.
|
||||
|
||||
The ragged_dot is tested against numpy reference implementation, and by
|
||||
running JAX compilation.
|
||||
|
||||
Raises:
|
||||
SkipTest: in the case dtype is not supported.
|
||||
"""
|
||||
lhs_shape = (m, k)
|
||||
rhs_shape = (num_groups, k, n)
|
||||
|
||||
def group_sizes(m, num_groups):
|
||||
ends_no_final = jnp.sort(self.rng().choice(m, size=num_groups - 1))
|
||||
ends = jnp.concatenate(
|
||||
[ends_no_final, jnp.array([m], dtype=ends_no_final.dtype)])
|
||||
starts = jnp.concatenate(
|
||||
[jnp.zeros(1, dtype=ends_no_final.dtype), ends_no_final])
|
||||
return ends - starts
|
||||
|
||||
rng = jtu.rand_small(self.rng())
|
||||
args_maker = lambda: [
|
||||
rng(lhs_shape, dtype),
|
||||
rng(rhs_shape, dtype),
|
||||
group_sizes(m, num_groups),
|
||||
]
|
||||
self._CompileAndCheck(lax.ragged_dot, args_maker)
|
||||
self._CheckAgainstNumpy(
|
||||
lax_reference.ragged_dot, lax.ragged_dot, args_maker)
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user