mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[XLA:Mosaic][Pallas] Enable vector.ExtractOp for non-zero indices.
PiperOrigin-RevId: 679283281
This commit is contained in:
parent
46dbb6588a
commit
9f4e8d0039
@ -3555,24 +3555,53 @@ LogicalResult vector_extract_rule(RewriteContext &ctx, Operation &op,
|
||||
op.erase();
|
||||
return success();
|
||||
} else {
|
||||
for (int64_t i : extract_op.getStaticPosition()) {
|
||||
if (i != 0) {
|
||||
return op.emitOpError(
|
||||
"Not implemented: Only 0 indices supported for scalar results");
|
||||
}
|
||||
}
|
||||
// TODO(b/367459476): Support non-zero offsets.
|
||||
if (layout_in.offsets() != LayoutOffsets{0, 0}) {
|
||||
return op.emitOpError("Not implemented: Unsupported layout");
|
||||
}
|
||||
auto [sub_tile, lane_tile] = layout_in.tiling();
|
||||
FAILUREOR_ASSIGN_OR_RETURN(
|
||||
const xla::Array<Value> vregs,
|
||||
disassemble(builder, layout_in, extract_op.getVector(),
|
||||
ctx.target_shape));
|
||||
TPU_ASSERT_GT_OP(vregs.num_elements(), 0);
|
||||
|
||||
SmallVector<int64_t> indices(extract_op.getStaticPosition());
|
||||
auto vreg_slice = layout_in.vregSlice(ctx.target_shape);
|
||||
std::array<int64_t, 2> position = {0, 0};
|
||||
SmallVector<int64_t> vreg_index(indices);
|
||||
// TODO(b/367459476): Support non-VREG-aligned tiling.
|
||||
CHECK_EQ(lane_tile, ctx.target_shape[1]);
|
||||
layout_in.insertImplicit(indices, static_cast<int64_t>(0));
|
||||
layout_in.insertImplicit(vreg_index, static_cast<int64_t>(0));
|
||||
int i = *(indices.end()-2);
|
||||
int j = *(indices.end()-1);
|
||||
*(vreg_index.end() -2) = i / vreg_slice[0];
|
||||
*(vreg_index.end() -1) = j / vreg_slice[1];
|
||||
layout_in.eraseImplicit(vreg_index);
|
||||
position[0] = ((j % vreg_slice[1]) / lane_tile * sub_tile
|
||||
) + i % sub_tile;
|
||||
position[1] = j % lane_tile;
|
||||
|
||||
TPU_ASSERT_LT_OP(vreg_index, vregs.dimensions());
|
||||
Value extracted_vreg = vregs(vreg_index);
|
||||
|
||||
// Invert the offsets to get the rotation amount.
|
||||
position[0] = (ctx.target_shape[0] - position[0]) % ctx.target_shape[0];
|
||||
position[1] = (ctx.target_shape[1] - position[1]) % ctx.target_shape[1];
|
||||
auto res_vreg_ty = extracted_vreg.getType();
|
||||
Value shift = builder.create<arith::ConstantOp>(
|
||||
builder.getIntegerAttr(builder.getI32Type(), position[0]));
|
||||
Value rotated_vreg = builder.create<tpu::DynamicRotateOp>(
|
||||
res_vreg_ty, extracted_vreg, shift, 0, /*stride*/nullptr, nullptr);
|
||||
shift = builder.create<arith::ConstantOp>(
|
||||
builder.getIntegerAttr(builder.getI32Type(), position[1]));
|
||||
rotated_vreg = builder.create<tpu::DynamicRotateOp>(
|
||||
res_vreg_ty, rotated_vreg, shift, 1, /*stride*/nullptr, nullptr);
|
||||
extract_op.replaceAllUsesWith(
|
||||
builder
|
||||
.create<vector::ExtractOp>(op.getLoc(), *vregs.data(),
|
||||
ArrayRef<int64_t>{0, 0})
|
||||
builder.create<vector::ExtractOp>(
|
||||
op.getLoc(), rotated_vreg,
|
||||
ArrayRef<int64_t>{0, 0})
|
||||
.getResult());
|
||||
}
|
||||
extract_op.erase();
|
||||
|
@ -44,22 +44,22 @@ class PallasErrorHandlingTest(jtu.JaxTestCase):
|
||||
if not jtu.test_device_matches(["tpu"]):
|
||||
self.skipTest("Test only works on TPU.")
|
||||
|
||||
def test_vector_extract_nonzero(self):
|
||||
input_arr = jax.random.uniform(jax.random.key(0), (2, 2), dtype=jnp.float32)
|
||||
out_shape = jax.ShapeDtypeStruct((1, 1), jnp.float32)
|
||||
def test_non_singular_stride(self):
|
||||
input_arr = jax.random.uniform(
|
||||
jax.random.key(0), (8, 128), dtype=jnp.float32)
|
||||
out_shape = jax.ShapeDtypeStruct((8, 16), jnp.float32)
|
||||
grid_spec = pltpu.PrefetchScalarGridSpec(
|
||||
num_scalar_prefetch=0,
|
||||
in_specs=[
|
||||
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
|
||||
],
|
||||
out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
|
||||
out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
|
||||
)
|
||||
|
||||
@functools.partial(pl.pallas_call, out_shape=out_shape, grid_spec=grid_spec)
|
||||
def test_kernel(input_ref, output_ref):
|
||||
val = input_ref[...]
|
||||
x = val[0, 0] + val[0, 1]
|
||||
output_ref[0, 0] = x
|
||||
x = input_ref[:, ::8]
|
||||
output_ref[...] = x
|
||||
|
||||
# Test that a Mosaic error is raised. This assert is a guard against
|
||||
# underlying changes in Mosaic.
|
||||
@ -67,7 +67,7 @@ class PallasErrorHandlingTest(jtu.JaxTestCase):
|
||||
# the test example to force a different error.
|
||||
with self.assertRaisesRegex(
|
||||
error_handling.MosaicError,
|
||||
"Not implemented: Only 0 indices supported for scalar results",
|
||||
"Not Implemented: Stride on last dim is not 1",
|
||||
):
|
||||
test_kernel(input_arr)
|
||||
|
||||
@ -78,7 +78,7 @@ class PallasErrorHandlingTest(jtu.JaxTestCase):
|
||||
except error_handling.MosaicError as e:
|
||||
tb_string = traceback.format_tb(e.__traceback__)
|
||||
tb_string = "".join(tb_string)
|
||||
self.assertEndsWith(tb_string, "x = val[0, 0] + val[0, 1]\n")
|
||||
self.assertEndsWith(tb_string, "x = input_ref[:, ::8]\n")
|
||||
|
||||
@jax.jit
|
||||
def kernel_in_jitted_fn(x):
|
||||
@ -91,7 +91,7 @@ class PallasErrorHandlingTest(jtu.JaxTestCase):
|
||||
except error_handling.MosaicError as e:
|
||||
tb_string = traceback.format_tb(e.__traceback__)
|
||||
tb_string = "".join(tb_string)
|
||||
self.assertEndsWith(tb_string, "x = val[0, 0] + val[0, 1]\n")
|
||||
self.assertEndsWith(tb_string, "x = input_ref[:, ::8]\n")
|
||||
|
||||
def test_invalid_smem_vmem_verification_error(self):
|
||||
input_arr = jax.random.uniform(jax.random.key(0), (2, 2), dtype=jnp.float32)
|
||||
|
Loading…
x
Reference in New Issue
Block a user