[Mosaic TPU] Support non-32 bit mask relayout

PiperOrigin-RevId: 721552594
This commit is contained in:
Jevin Jiang 2025-01-30 16:12:28 -08:00 committed by jax authors
parent 9dfe03c5ea
commit 785a63ad0f

View File

@ -6585,13 +6585,6 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
}
VectorType vty = v.getType();
const bool is_mask = vty.getElementTypeBitWidth() == 1;
if (is_mask) {
if (src.bitwidth() != 32 || dst.bitwidth() != 32) {
return emitError(v.getLoc(),
"Not implemented: mask relayout with non-32 bitwidth in "
"vector layout");
}
}
{
// Replication imposes a replication constraint on the *logical* value of
// the vector: When moving along a replicated axis, all elements must be
@ -6626,21 +6619,22 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
xla::Array<Value> src_tiles,
disassemble(builder, src, v, target_shape, /*use_implicit_shape=*/true));
if (is_mask) {
auto new_tile_ty =
getNativeVregOrVmaskType(builder.getI32Type(), 32, target_shape);
auto new_tile_ty = getNativeVregOrVmaskType(
builder.getIntegerType(bitwidth), bitwidth, target_shape);
src_tiles.Each([&](const absl::Span<const int64_t> idx, Value *tile) {
*tile =
builder.create<arith::ExtUIOp>(tile->getLoc(), new_tile_ty, *tile);
});
vty = VectorType::get(vty.getShape(), builder.getI32Type());
vty = VectorType::get(vty.getShape(), builder.getIntegerType(bitwidth));
}
auto assemble_with_mask_check = [&](xla::Array<Value> &tiles,
bool use_implicit_shape = false) {
if (is_mask) {
auto zeros_tile = builder.create<arith::ConstantOp>(
tiles.begin()->getLoc(),
DenseElementsAttr::get(cast<VectorType>(tiles.begin()->getType()),
builder.getI32IntegerAttr(0)));
DenseElementsAttr::get(
cast<VectorType>(tiles.begin()->getType()),
builder.getIntegerAttr(builder.getIntegerType(bitwidth), 0)));
tiles.Each([&](const absl::Span<const int64_t> idx, Value *tile) {
*tile = builder.create<arith::CmpIOp>(
tile->getLoc(), arith::CmpIPredicate::ne, *tile, zeros_tile);
@ -6695,9 +6689,7 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
}
*vreg = src_vregs(local_idx);
});
return assemble(builder, vty, dst, std::move(dst_vregs), target_shape,
/*use_implicit_shape=*/true)
.getResult();
return assemble_with_mask_check(dst_vregs, /*use_implicit_shape=*/true);
}
src_tiles.Reshape(dst.tileArrayImplicitShape(vty.getShape(), target_shape));
return assemble_with_mask_check(src_tiles,