mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Mosaic TPU] Support mask concat
PiperOrigin-RevId: 728349788
This commit is contained in:
parent
1dc58b79bf
commit
bb68124c33
@ -777,7 +777,6 @@ class VectorLayoutInferer {
|
||||
TPU_CHECK_OP(0 <= dimension && dimension < res_rank,
|
||||
"Expect a valid concatenate dimension");
|
||||
VectorType res_ty = op.getResult().getType();
|
||||
int8_t bitwidth = res_ty.getElementTypeBitWidth();
|
||||
|
||||
std::optional<int64_t> tiling_dim;
|
||||
if (dimension == res_ty.getRank() - 1) {
|
||||
@ -793,6 +792,7 @@ class VectorLayoutInferer {
|
||||
SmallVector<Layout, 4> op_layouts = getLayoutFromOperands(op);
|
||||
SmallVector<Layout> in_layouts;
|
||||
in_layouts.reserve(op.getSources().size());
|
||||
int8_t bitwidth = first_layout->bitwidth();
|
||||
|
||||
// Set implicit dim to treat 1D as (1, N) and tile it as (1, 128)
|
||||
std::array<int64_t, 2> tiling =
|
||||
|
@ -438,6 +438,38 @@ class OpsTest(PallasBaseTest):
|
||||
y = jax.random.normal(k2, (8, 128), dtype=dtype)
|
||||
np.testing.assert_allclose(run(x, y), jax.lax.div(x, y), **kwargs)
|
||||
|
||||
@parameterized.product(
|
||||
dtype=[jnp.float32, jnp.bfloat16, jnp.int8],
|
||||
)
|
||||
def test_concat_mask(self, dtype):
|
||||
if not jtu.if_cloud_tpu_at_least(2025, 2, 19):
|
||||
self.skipTest("Requires libtpu built after 2025-02-19")
|
||||
bitwidth = pallas_utils.dtype_bitwidth(dtype)
|
||||
if jtu.get_tpu_version() < 5 and bitwidth < 32:
|
||||
self.skipTest(
|
||||
f"Not implemented: cast vector to mask with bitwidth == {bitwidth}"
|
||||
)
|
||||
shape = (128, 128)
|
||||
|
||||
def kernel(x, out):
|
||||
mask = x[...] != 0
|
||||
concated_mask = jnp.concatenate([mask, mask], axis=0)
|
||||
concated_x = jnp.concatenate([x[:], x[:]], axis=0)
|
||||
out[:] = lax.select(concated_mask, concated_x, jnp.zeros_like(concated_x))
|
||||
|
||||
x = jax.random.normal(jax.random.key(1234), shape, dtype=jnp.float32)
|
||||
if dtype == jnp.int8:
|
||||
x = (x * 100).astype(jnp.int8)
|
||||
else:
|
||||
x = x.astype(dtype)
|
||||
out = self.pallas_call(
|
||||
kernel, out_shape=jax.ShapeDtypeStruct((shape[0] * 2, shape[1]), dtype)
|
||||
)(x)
|
||||
concated_mask = jnp.concatenate([x != 0, x != 0], axis=0)
|
||||
concated_x = jnp.concatenate([x, x], axis=0)
|
||||
expected = lax.select(concated_mask, concated_x, jnp.zeros_like(concated_x))
|
||||
np.testing.assert_array_equal(out, expected)
|
||||
|
||||
|
||||
class OpsInterpretTest(OpsTest):
|
||||
INTERPRET = True
|
||||
|
Loading…
x
Reference in New Issue
Block a user