mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic TPU] Set in_bounds for transfer_read used in replicated loads
This is in preparation for integrating changes from MLIR:
2ee5586ac7 (diff-3cbcc8f6c740f2d6e16f5a0c19daf4bb8224ad92d9e430fc10c935587a67dcce)
Also don't pass in `padding` since there is a builder that uses `padding` of zero as default.
PiperOrigin-RevId: 654370142
This commit is contained in:
parent
ff36ea5de3
commit
f4b09234a0
@ -2898,7 +2898,6 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
|
||||
auto batch_base_idxs = ArrayRef<Value>(base_indices).drop_back(tiled_dims);
|
||||
const LayoutOffsets offsets = layout_out.offsets();
|
||||
AffineMap load_map;
|
||||
arith::ConstantOp padding;
|
||||
if (offsets[1] == std::nullopt) {
|
||||
return op.emitOpError(
|
||||
"Not implemented: Load replicated along lanes is unsupported");
|
||||
@ -2918,10 +2917,6 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
|
||||
{getAffineConstantExpr(0, mlir_ctx),
|
||||
getAffineDimExpr(memref_ty.getRank() - 1, mlir_ctx)},
|
||||
mlir_ctx);
|
||||
FAILUREOR_ASSIGN_OR_RETURN(const TypedAttr zero_attr,
|
||||
getZeroIntOrFloatAttr(vty.getElementType()));
|
||||
padding =
|
||||
builder.create<arith::ConstantOp>(vty.getElementType(), zero_attr);
|
||||
}
|
||||
|
||||
xla::Array<Value> tiles(
|
||||
@ -2962,14 +2957,16 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
|
||||
builder.getI32IntegerAttr(sublane_stride));
|
||||
} else {
|
||||
if (load_map) {
|
||||
CHECK(padding);
|
||||
if (layout_out.bitwidth() != 32) {
|
||||
load_op.emitOpError("Not implemented");
|
||||
return absl::UnimplementedError("");
|
||||
}
|
||||
tile = builder.create<vector::TransferReadOp>(
|
||||
target_ty, base_addr, idxs, load_map, padding, nullptr,
|
||||
nullptr);
|
||||
target_ty, base_addr, idxs, load_map,
|
||||
// TODO(tlongeri): Not sure whether we are obeying the semantics
|
||||
// of in_bounds, but our lowering ignores it and this path will
|
||||
// removed soon anyway.
|
||||
SmallVector<bool>(2, true));
|
||||
} else {
|
||||
const SmallVector<bool> sublane_mask(ctx.target_shape[0], true);
|
||||
const auto sublane_mask_attr =
|
||||
|
Loading…
x
Reference in New Issue
Block a user