mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
[Mosaic TPU] Support non-32 bit mask relayout
PiperOrigin-RevId: 721552594
This commit is contained in:
parent
9dfe03c5ea
commit
785a63ad0f
@ -6585,13 +6585,6 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
|
||||
}
|
||||
VectorType vty = v.getType();
|
||||
const bool is_mask = vty.getElementTypeBitWidth() == 1;
|
||||
if (is_mask) {
|
||||
if (src.bitwidth() != 32 || dst.bitwidth() != 32) {
|
||||
return emitError(v.getLoc(),
|
||||
"Not implemented: mask relayout with non-32 bitwidth in "
|
||||
"vector layout");
|
||||
}
|
||||
}
|
||||
{
|
||||
// Replication imposes a replication constraint on the *logical* value of
|
||||
// the vector: When moving along a replicated axis, all elements must be
|
||||
@ -6626,21 +6619,22 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
|
||||
xla::Array<Value> src_tiles,
|
||||
disassemble(builder, src, v, target_shape, /*use_implicit_shape=*/true));
|
||||
if (is_mask) {
|
||||
auto new_tile_ty =
|
||||
getNativeVregOrVmaskType(builder.getI32Type(), 32, target_shape);
|
||||
auto new_tile_ty = getNativeVregOrVmaskType(
|
||||
builder.getIntegerType(bitwidth), bitwidth, target_shape);
|
||||
src_tiles.Each([&](const absl::Span<const int64_t> idx, Value *tile) {
|
||||
*tile =
|
||||
builder.create<arith::ExtUIOp>(tile->getLoc(), new_tile_ty, *tile);
|
||||
});
|
||||
vty = VectorType::get(vty.getShape(), builder.getI32Type());
|
||||
vty = VectorType::get(vty.getShape(), builder.getIntegerType(bitwidth));
|
||||
}
|
||||
auto assemble_with_mask_check = [&](xla::Array<Value> &tiles,
|
||||
bool use_implicit_shape = false) {
|
||||
if (is_mask) {
|
||||
auto zeros_tile = builder.create<arith::ConstantOp>(
|
||||
tiles.begin()->getLoc(),
|
||||
DenseElementsAttr::get(cast<VectorType>(tiles.begin()->getType()),
|
||||
builder.getI32IntegerAttr(0)));
|
||||
DenseElementsAttr::get(
|
||||
cast<VectorType>(tiles.begin()->getType()),
|
||||
builder.getIntegerAttr(builder.getIntegerType(bitwidth), 0)));
|
||||
tiles.Each([&](const absl::Span<const int64_t> idx, Value *tile) {
|
||||
*tile = builder.create<arith::CmpIOp>(
|
||||
tile->getLoc(), arith::CmpIPredicate::ne, *tile, zeros_tile);
|
||||
@ -6695,9 +6689,7 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
|
||||
}
|
||||
*vreg = src_vregs(local_idx);
|
||||
});
|
||||
return assemble(builder, vty, dst, std::move(dst_vregs), target_shape,
|
||||
/*use_implicit_shape=*/true)
|
||||
.getResult();
|
||||
return assemble_with_mask_check(dst_vregs, /*use_implicit_shape=*/true);
|
||||
}
|
||||
src_tiles.Reshape(dst.tileArrayImplicitShape(vty.getShape(), target_shape));
|
||||
return assemble_with_mask_check(src_tiles,
|
||||
|
Loading…
x
Reference in New Issue
Block a user