[Mosaic TPU] Support mask concat

PiperOrigin-RevId: 728349788
This commit is contained in:
Jevin Jiang 2025-02-18 14:03:14 -08:00 committed by jax authors
parent 1dc58b79bf
commit bb68124c33
2 changed files with 33 additions and 1 deletions

View File

@ -777,7 +777,6 @@ class VectorLayoutInferer {
TPU_CHECK_OP(0 <= dimension && dimension < res_rank,
"Expect a valid concatenate dimension");
VectorType res_ty = op.getResult().getType();
int8_t bitwidth = res_ty.getElementTypeBitWidth();
std::optional<int64_t> tiling_dim;
if (dimension == res_ty.getRank() - 1) {
@ -793,6 +792,7 @@ class VectorLayoutInferer {
SmallVector<Layout, 4> op_layouts = getLayoutFromOperands(op);
SmallVector<Layout> in_layouts;
in_layouts.reserve(op.getSources().size());
int8_t bitwidth = first_layout->bitwidth();
// Set implicit dim to treat 1D as (1, N) and tile it as (1, 128)
std::array<int64_t, 2> tiling =

View File

@ -438,6 +438,38 @@ class OpsTest(PallasBaseTest):
y = jax.random.normal(k2, (8, 128), dtype=dtype)
np.testing.assert_allclose(run(x, y), jax.lax.div(x, y), **kwargs)
@parameterized.product(
dtype=[jnp.float32, jnp.bfloat16, jnp.int8],
)
def test_concat_mask(self, dtype):
if not jtu.if_cloud_tpu_at_least(2025, 2, 19):
self.skipTest("Requires libtpu built after 2025-02-19")
bitwidth = pallas_utils.dtype_bitwidth(dtype)
if jtu.get_tpu_version() < 5 and bitwidth < 32:
self.skipTest(
f"Not implemented: cast vector to mask with bitwidth == {bitwidth}"
)
shape = (128, 128)
def kernel(x, out):
mask = x[...] != 0
concated_mask = jnp.concatenate([mask, mask], axis=0)
concated_x = jnp.concatenate([x[:], x[:]], axis=0)
out[:] = lax.select(concated_mask, concated_x, jnp.zeros_like(concated_x))
x = jax.random.normal(jax.random.key(1234), shape, dtype=jnp.float32)
if dtype == jnp.int8:
x = (x * 100).astype(jnp.int8)
else:
x = x.astype(dtype)
out = self.pallas_call(
kernel, out_shape=jax.ShapeDtypeStruct((shape[0] * 2, shape[1]), dtype)
)(x)
concated_mask = jnp.concatenate([x != 0, x != 0], axis=0)
concated_x = jnp.concatenate([x, x], axis=0)
expected = lax.select(concated_mask, concated_x, jnp.zeros_like(concated_x))
np.testing.assert_array_equal(out, expected)
class OpsInterpretTest(OpsTest):
INTERPRET = True