mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Mosaic TPU] Support mask concat
PiperOrigin-RevId: 728349788
This commit is contained in:
parent
1dc58b79bf
commit
bb68124c33
@ -777,7 +777,6 @@ class VectorLayoutInferer {
|
|||||||
TPU_CHECK_OP(0 <= dimension && dimension < res_rank,
|
TPU_CHECK_OP(0 <= dimension && dimension < res_rank,
|
||||||
"Expect a valid concatenate dimension");
|
"Expect a valid concatenate dimension");
|
||||||
VectorType res_ty = op.getResult().getType();
|
VectorType res_ty = op.getResult().getType();
|
||||||
int8_t bitwidth = res_ty.getElementTypeBitWidth();
|
|
||||||
|
|
||||||
std::optional<int64_t> tiling_dim;
|
std::optional<int64_t> tiling_dim;
|
||||||
if (dimension == res_ty.getRank() - 1) {
|
if (dimension == res_ty.getRank() - 1) {
|
||||||
@ -793,6 +792,7 @@ class VectorLayoutInferer {
|
|||||||
SmallVector<Layout, 4> op_layouts = getLayoutFromOperands(op);
|
SmallVector<Layout, 4> op_layouts = getLayoutFromOperands(op);
|
||||||
SmallVector<Layout> in_layouts;
|
SmallVector<Layout> in_layouts;
|
||||||
in_layouts.reserve(op.getSources().size());
|
in_layouts.reserve(op.getSources().size());
|
||||||
|
int8_t bitwidth = first_layout->bitwidth();
|
||||||
|
|
||||||
// Set implicit dim to treat 1D as (1, N) and tile it as (1, 128)
|
// Set implicit dim to treat 1D as (1, N) and tile it as (1, 128)
|
||||||
std::array<int64_t, 2> tiling =
|
std::array<int64_t, 2> tiling =
|
||||||
|
@ -438,6 +438,38 @@ class OpsTest(PallasBaseTest):
|
|||||||
y = jax.random.normal(k2, (8, 128), dtype=dtype)
|
y = jax.random.normal(k2, (8, 128), dtype=dtype)
|
||||||
np.testing.assert_allclose(run(x, y), jax.lax.div(x, y), **kwargs)
|
np.testing.assert_allclose(run(x, y), jax.lax.div(x, y), **kwargs)
|
||||||
|
|
||||||
|
@parameterized.product(
|
||||||
|
dtype=[jnp.float32, jnp.bfloat16, jnp.int8],
|
||||||
|
)
|
||||||
|
def test_concat_mask(self, dtype):
|
||||||
|
if not jtu.if_cloud_tpu_at_least(2025, 2, 19):
|
||||||
|
self.skipTest("Requires libtpu built after 2025-02-19")
|
||||||
|
bitwidth = pallas_utils.dtype_bitwidth(dtype)
|
||||||
|
if jtu.get_tpu_version() < 5 and bitwidth < 32:
|
||||||
|
self.skipTest(
|
||||||
|
f"Not implemented: cast vector to mask with bitwidth == {bitwidth}"
|
||||||
|
)
|
||||||
|
shape = (128, 128)
|
||||||
|
|
||||||
|
def kernel(x, out):
|
||||||
|
mask = x[...] != 0
|
||||||
|
concated_mask = jnp.concatenate([mask, mask], axis=0)
|
||||||
|
concated_x = jnp.concatenate([x[:], x[:]], axis=0)
|
||||||
|
out[:] = lax.select(concated_mask, concated_x, jnp.zeros_like(concated_x))
|
||||||
|
|
||||||
|
x = jax.random.normal(jax.random.key(1234), shape, dtype=jnp.float32)
|
||||||
|
if dtype == jnp.int8:
|
||||||
|
x = (x * 100).astype(jnp.int8)
|
||||||
|
else:
|
||||||
|
x = x.astype(dtype)
|
||||||
|
out = self.pallas_call(
|
||||||
|
kernel, out_shape=jax.ShapeDtypeStruct((shape[0] * 2, shape[1]), dtype)
|
||||||
|
)(x)
|
||||||
|
concated_mask = jnp.concatenate([x != 0, x != 0], axis=0)
|
||||||
|
concated_x = jnp.concatenate([x, x], axis=0)
|
||||||
|
expected = lax.select(concated_mask, concated_x, jnp.zeros_like(concated_x))
|
||||||
|
np.testing.assert_array_equal(out, expected)
|
||||||
|
|
||||||
|
|
||||||
class OpsInterpretTest(OpsTest):
|
class OpsInterpretTest(OpsTest):
|
||||||
INTERPRET = True
|
INTERPRET = True
|
||||||
|
Loading…
x
Reference in New Issue
Block a user