From bb68124c336e6e0ad69a78f03929a4be8e4bf9e8 Mon Sep 17 00:00:00 2001 From: Jevin Jiang <jevinjiang@google.com> Date: Tue, 18 Feb 2025 14:03:14 -0800 Subject: [PATCH] [Mosaic TPU] Support mask concat PiperOrigin-RevId: 728349788 --- .../tpu/transforms/infer_vector_layout.cc | 2 +- tests/pallas/tpu_ops_test.py | 32 +++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index d22edf08e..33912fddf 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -777,7 +777,6 @@ class VectorLayoutInferer { TPU_CHECK_OP(0 <= dimension && dimension < res_rank, "Expect a valid concatenate dimension"); VectorType res_ty = op.getResult().getType(); - int8_t bitwidth = res_ty.getElementTypeBitWidth(); std::optional<int64_t> tiling_dim; if (dimension == res_ty.getRank() - 1) { @@ -793,6 +792,7 @@ class VectorLayoutInferer { SmallVector<Layout> op_layouts = getLayoutFromOperands(op); SmallVector<Layout> in_layouts; in_layouts.reserve(op.getSources().size()); + int8_t bitwidth = first_layout->bitwidth(); // Set implicit dim to treat 1D as (1, N) and tile it as (1, 128) std::array<int64_t, 2> tiling = diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index 5c4e60726..2011602d8 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -438,6 +438,38 @@ class OpsTest(PallasBaseTest): y = jax.random.normal(k2, (8, 128), dtype=dtype) np.testing.assert_allclose(run(x, y), jax.lax.div(x, y), **kwargs) + @parameterized.product( + dtype=[jnp.float32, jnp.bfloat16, jnp.int8], + ) + def test_concat_mask(self, dtype): + if not jtu.if_cloud_tpu_at_least(2025, 2, 19): + self.skipTest("Requires libtpu built after 2025-02-19") + bitwidth = pallas_utils.dtype_bitwidth(dtype) + if jtu.get_tpu_version() < 5 and bitwidth < 32: + self.skipTest( + f"Not implemented: cast vector to mask with bitwidth == {bitwidth}" + ) + shape = (128, 128) + + def kernel(x, out): + mask = x[...] != 0 + concated_mask = jnp.concatenate([mask, mask], axis=0) + concated_x = jnp.concatenate([x[:], x[:]], axis=0) + out[:] = lax.select(concated_mask, concated_x, jnp.zeros_like(concated_x)) + + x = jax.random.normal(jax.random.key(1234), shape, dtype=jnp.float32) + if dtype == jnp.int8: + x = (x * 100).astype(jnp.int8) + else: + x = x.astype(dtype) + out = self.pallas_call( + kernel, out_shape=jax.ShapeDtypeStruct((shape[0] * 2, shape[1]), dtype) + )(x) + concated_mask = jnp.concatenate([x != 0, x != 0], axis=0) + concated_x = jnp.concatenate([x, x], axis=0) + expected = lax.select(concated_mask, concated_x, jnp.zeros_like(concated_x)) + np.testing.assert_array_equal(out, expected) + class OpsInterpretTest(OpsTest): INTERPRET = True