[XLA:Mosaic] Support trunc/ext op for 1D vector with any implicit dim.

PiperOrigin-RevId: 626466602

commit 167161706c
parent 6e23c14f85
@@ -270,7 +270,6 @@ class VectorLayout {
SmallVector<int64_t> implicitShape(ArrayRef<int64_t> shape) const;

private:
SmallVector<int64_t> tileArrayImplicitShape(
ArrayRef<int64_t> shape, std::array<int64_t, 2> target_shape) const;
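Note on the helpers referenced above: implicitShape() expands a low-rank vector shape to the rank the layout expects, and tileArrayImplicitShape() gives the shape of the vreg array for that expanded shape. The sketch below is a standalone, simplified model of that bookkeeping; the helper names, the rank-2 simplification, and the (8, 128) tiling in the example are illustrative assumptions, not the Mosaic implementation (which also accounts for offsets and the exact implicit-dim position).

```cpp
#include <array>
#include <cstdint>
#include <iostream>
#include <vector>

// Simplified model: a 1-D vector with an implicit dim is viewed as a rank-2
// shape, and the vreg ("tile array") shape is the ceil-division of that
// implicit shape by the (sublane, lane) tiling.
std::vector<int64_t> ImplicitShape(std::vector<int64_t> shape) {
  while (shape.size() < 2) shape.insert(shape.begin(), 1);  // pad leading dims
  return shape;
}

std::vector<int64_t> TileArrayImplicitShape(const std::vector<int64_t>& shape,
                                            std::array<int64_t, 2> tiling) {
  std::vector<int64_t> vregs = ImplicitShape(shape);
  auto cdiv = [](int64_t a, int64_t b) { return (a + b - 1) / b; };
  vregs[vregs.size() - 2] = cdiv(vregs[vregs.size() - 2], tiling[0]);
  vregs[vregs.size() - 1] = cdiv(vregs[vregs.size() - 1], tiling[1]);
  return vregs;
}

int main() {
  // A 1-D vector of 512 elements under (8, 128) tiling occupies a 1 x 4
  // array of vregs once the implicit dim is taken into account.
  for (int64_t d : TileArrayImplicitShape({512}, {8, 128})) std::cout << d << " ";
  std::cout << "\n";
}
```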
@@ -658,15 +658,23 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op,
const auto result_ty = cast<VectorType>(op.getResult().getType());
auto source = cast<TypedValue<VectorType>>(op.getIn());
const auto source_ty = source.getType();
auto output_vregs_shape =
layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape);
if (layout_out.bitwidth() != 32) {
return op.emitOpError(
"Not implemented: Only extensions to 32-bit supported");
}
FAILUREOR_ASSIGN_OR_RETURN(
const xla::Array<Value> input_vregs,
xla::Array<Value> input_vregs,
disassemble(builder, layout_in, source, ctx.target_shape));
xla::Array<Value> output_vregs(
layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape));
xla::Array<Value> output_vregs(output_vregs_shape);
// TODO(jevinjiang): maybe just use tileArrayImplicitShape in disassemble?
if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
input_vregs.Reshape(layout_in.tileArrayImplicitShape(source_ty.getShape(),
ctx.target_shape));
output_vregs.Reshape(layout_out.tileArrayImplicitShape(result_ty.getShape(),
ctx.target_shape));
}
FAILUREOR_ASSIGN_OR_RETURN(
const VectorType res_vreg_ty,
getNativeVregType(result_ty.getElementType(), ctx.target_shape));
@@ -676,51 +684,24 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op,
if (layout_in.offsets() != layout_out.offsets()) {
return op.emitOpError("Not implemented: Change of offsets during the cast");
}
switch (layout_in.implicit_dim()) {
case VectorLayout::ImplicitDim::kNone: {
if (layout_in.tiling() != layout_out.tiling()) {
return op.emitOpError(
"Not implemented: Changing tiling during the cast");
}
auto tiling = layout_in.tiling();
if (ctx.target_shape[0] % tiling[0] != 0 ||
ctx.target_shape[1] != tiling[1]) {
return op.emitOpError("Not implemented: tiling not supported");
}
const int packing = layout_in.packing();
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
SmallVector<int64_t> input_vreg_idxs(toArrayRef(idxs));
input_vreg_idxs.back() /= packing;
const int64_t vreg_part = idxs.back() % packing;
*v = builder.create<UnpackSubelementsOp>(
res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part);
});
} break;
case VectorLayout::ImplicitDim::kMinor:
return op.emitOpError(
"Not implemented: Only casts of lane-oriented values supported");
case VectorLayout::ImplicitDim::kSecondMinor: {
auto is_one_tile = [](VectorType vty, VectorLayout layout) {
auto implicit_shape = layout.implicitShape(vty.getShape());
auto tiled_shape = ArrayRef<int64_t>(implicit_shape).take_back(2);
return (layout.offsets()[0].value_or(0) + tiled_shape[0] <=
layout.tiling()[0]) &&
(layout.offsets()[1].value_or(0) + tiled_shape[1] <=
layout.tiling()[1]);
};
if (input_vregs.dimensions() != absl::Span<const int64_t>{1} ||
output_vregs.dimensions() != absl::Span<const int64_t>{1} ||
!is_one_tile(source_ty, layout_in) ||
!is_one_tile(result_ty, layout_out)) {
return op.emitOpError("Not implemented");
}
if (layout_in.offsets()[0] >= ctx.target_shape[0]) {
return op.emitOpError("Not implemented");
}
auto unpack_subelements_op = builder.create<UnpackSubelementsOp>(
res_vreg_ty, *input_vregs.begin(), 0);
output_vregs.Fill(unpack_subelements_op.getResult());
}
if (layout_in.tiling() != layout_out.tiling()) {
return op.emitOpError("Not implemented: Changing tiling during the cast");
}
auto tiling = layout_in.tiling();
if (ctx.target_shape[0] % tiling[0] != 0 ||
ctx.target_shape[1] != tiling[1]) {
return op.emitOpError("Not implemented: tiling not supported");
}
const int packing = layout_in.packing();
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
SmallVector<int64_t> input_vreg_idxs(toArrayRef(idxs));
input_vreg_idxs.back() /= packing;
const int64_t vreg_part = idxs.back() % packing;
*v = builder.create<UnpackSubelementsOp>(
res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part);
});
if (layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
output_vregs.Reshape(output_vregs_shape);
}
op.replaceAllUsesWith(assemble(builder, result_ty, layout_out,
std::move(output_vregs), ctx.target_shape)
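The new extension path above handles any implicit dim by first reshaping both vreg arrays to their implicit shapes, then applying one per-vreg rule: each output vreg reads one sub-element part of a packed input vreg. Below is a standalone sketch of just that index arithmetic; the struct and function names are hypothetical and the MLIR builder/UnpackSubelementsOp plumbing is elided.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// For an extension from a packed type (e.g. 16-bit) to 32-bit, each output
// vreg takes one "part" of an input vreg: the input vreg index is the output
// minor index divided by the packing factor, and the part is the remainder.
struct UnpackSource {
  std::vector<int64_t> input_vreg_idx;  // index into the input vreg array
  int64_t vreg_part;                    // which sub-element slice to unpack
};

UnpackSource UnpackSourceFor(std::vector<int64_t> output_idx, int packing) {
  UnpackSource src;
  src.vreg_part = output_idx.back() % packing;
  output_idx.back() /= packing;
  src.input_vreg_idx = std::move(output_idx);
  return src;
}

int main() {
  // packing = 2 (16-bit -> 32-bit): output vregs {0,0}..{0,3} read parts 0
  // and 1 of input vregs {0,0} and {0,1}.
  for (int64_t j = 0; j < 4; ++j) {
    UnpackSource s = UnpackSourceFor({0, j}, /*packing=*/2);
    std::cout << "out {0," << j << "} <- in {" << s.input_vreg_idx[0] << ","
              << s.input_vreg_idx[1] << "} part " << s.vreg_part << "\n";
  }
}
```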
@@ -762,73 +743,85 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
const VectorLayout &layout_in,
const VectorLayout &layout_out) {
ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation());
auto source = cast<TypedValue<VectorType>>(op.getIn());
const auto source_ty = source.getType();
auto result_ty = cast<VectorType>(op.getResult().getType());
auto output_vregs_shape =
layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape);
FAILUREOR_ASSIGN_OR_RETURN(
const xla::Array<Value> input_vregs,
disassemble(builder, layout_in, cast<TypedValue<VectorType>>(op.getIn()),
ctx.target_shape));
xla::Array<Value> output_vregs(
layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape));
xla::Array<Value> input_vregs,
disassemble(builder, layout_in, source, ctx.target_shape));
xla::Array<Value> output_vregs(output_vregs_shape);
if (layout_in.bitwidth() != 32) {
return op.emitOpError("Not implemented: Only 32-bit truncation supported");
}
if (layout_in.offsets() != layout_out.offsets()) {
return op.emitOpError(
"Not implemented: Change of offsets during the truncation");
}
if (layout_in.implicit_dim() != layout_out.implicit_dim()) {
return op.emitOpError("Not implemented: Change of layout during the cast");
}
if (layout_in.tiling() != ctx.target_shape) {
return op.emitOpError("Not implemented: Only (8,128) tiling supported");
}
if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
input_vregs.Reshape(layout_in.tileArrayImplicitShape(source_ty.getShape(),
ctx.target_shape));
output_vregs.Reshape(layout_out.tileArrayImplicitShape(result_ty.getShape(),
ctx.target_shape));
}
FAILUREOR_ASSIGN_OR_RETURN(
VectorType res_vreg_ty,
getNativeVregType(result_ty.getElementType(), ctx.target_shape));
if (layout_in.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
layout_out.implicit_dim() == VectorLayout::ImplicitDim::kNone) {
if (layout_in.tiling() != ctx.target_shape) {
return op.emitOpError("Not implemented: Only (8,128) tiling supported");
}
if (layout_out.tiling() == ctx.target_shape) {
const int packing = layout_out.packing();
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
SmallVector<Value> parts;
SmallVector<int64_t> idxs_local(toArrayRef(idxs));
idxs_local.back() *= packing;
for (int64_t i = 0; i < packing; ++i) {
parts.push_back(input_vregs(idxs_local));
// Pack any data lying around if OOB
if (idxs_local.back() < input_vregs.dimensions().back() - 1) {
++idxs_local.back();
}
}
*v = builder.create<PackSubelementsOp>(res_vreg_ty, parts);
});

} else if (layout_out.hasNativeTiling(ctx.target_shape)) {
int packing = layout_out.packing();
if (layout_out.tiling() == ctx.target_shape) {
const int packing = layout_out.packing();
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
SmallVector<Value> parts;
parts.reserve(packing);
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
CHECK_GE(idxs.size(), 2);
SmallVector<int64_t> idxs_local(toArrayRef(idxs));
idxs_local[idxs.size() - 2] *= packing;
SmallVector<int64_t> idxs_local(toArrayRef(idxs));
idxs_local.back() *= packing;
for (int64_t i = 0; i < packing; ++i) {
parts.push_back(input_vregs(idxs_local));
idxs_local[idxs.size() - 2]++;
while (parts.size() < packing) {
if (*(idxs_local.end() - 2) < *(input_vregs.dimensions().end() - 2)) {
parts.push_back(input_vregs(idxs_local));
idxs_local[idxs.size() - 2]++;
} else {
// Once we run out of tiles, we can pick any one we like.
parts.push_back(parts.back());
}
// Pack any data lying around if OOB
if (idxs_local.back() < input_vregs.dimensions().back() - 1) {
++idxs_local.back();
}
*v = builder.create<PackSubelementsOp>(res_vreg_ty, parts);
parts.clear();
});
} else {
return op.emitOpError("Not implemented: unsupported output tiling");
}
op.replaceAllUsesWith(assemble(builder, result_ty, layout_out,
std::move(output_vregs), ctx.target_shape)
.getResult());
op.erase();
return success();
}
*v = builder.create<PackSubelementsOp>(res_vreg_ty, parts);
});
} else if (layout_out.hasNativeTiling(ctx.target_shape)) {
int packing = layout_out.packing();
SmallVector<Value> parts;
parts.reserve(packing);
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
CHECK_GE(idxs.size(), 2);
SmallVector<int64_t> idxs_local(toArrayRef(idxs));
idxs_local[idxs.size() - 2] *= packing;
parts.push_back(input_vregs(idxs_local));
idxs_local[idxs.size() - 2]++;
while (parts.size() < packing) {
if (*(idxs_local.end() - 2) < *(input_vregs.dimensions().end() - 2)) {
parts.push_back(input_vregs(idxs_local));
idxs_local[idxs.size() - 2]++;
} else {
// Once we run out of tiles, we can pick any one we like.
parts.push_back(parts.back());
}
}
*v = builder.create<PackSubelementsOp>(res_vreg_ty, parts);
parts.clear();
});
} else {
return op.emitOpError("Not implemented: unsupported output tiling");
}
// TODO(tlongeri): why wasn't this part of the original code?
return op.emitOpError("Not implemented");
if (layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
output_vregs.Reshape(output_vregs_shape);
}
op.replaceAllUsesWith(assemble(builder, result_ty, layout_out,
std::move(output_vregs), ctx.target_shape)
.getResult());
op.erase();
return success();
}

LogicalResult arith_truncf_rule(RewriteContext &ctx, Operation &op,
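Both truncation branches above build each output vreg by packing `packing` input vregs (stepping the minor index for target-shape output tiling, the second-minor index for native output tiling), and once the walk steps past the last input vreg they simply reuse one already gathered, as the in-code comments note. A standalone sketch of the native-tiling gather follows; the function name and the flat index model are illustrative assumptions.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Collect the `packing` input-vreg indices (along the second-minor dim) that
// feed one output vreg, reusing the last valid index once we step past the
// end of the input array (mirroring "pick any one we like" above).
std::vector<int64_t> PackedSources(int64_t out_second_minor, int packing,
                                   int64_t num_input_second_minor) {
  std::vector<int64_t> parts;
  int64_t idx = out_second_minor * packing;
  parts.push_back(idx++);
  while (static_cast<int>(parts.size()) < packing) {
    if (idx < num_input_second_minor) {
      parts.push_back(idx++);
    } else {
      parts.push_back(parts.back());  // out of tiles: repeat the last one
    }
  }
  return parts;
}

int main() {
  // packing = 2 (32-bit -> 16-bit), 3 input vregs along the second-minor dim:
  // output 0 packs inputs {0,1}; output 1 packs {2,2} (2 is reused at the edge).
  for (int64_t out = 0; out < 2; ++out) {
    std::cout << "out " << out << " <-";
    for (int64_t p : PackedSources(out, 2, 3)) std::cout << " " << p;
    std::cout << "\n";
  }
}
```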
@@ -1473,34 +1473,24 @@ class VectorLayoutInferer {
"Only extensions to 32-bit supported");
}
auto &layout = *some_layout;
if (layout.implicit_dim() == ImplicitDim::kNone) {
// TODO(apaszke): Support native packed layouts here.
Layout src_layout;
Layout dst_layout;
// All layouts that subdivide the rows of the default tiling evenly
// can be handled uniformly with the default case, by preserving the
// tiling through the op.
if (default_tiling_[0] % layout.tiling()[0] == 0 &&
default_tiling_[1] == layout.tiling()[1]) {
src_layout = layout;
} else {
src_layout = VectorLayout(layout.bitwidth(), layout.offsets(),
default_tiling_, ImplicitDim::kNone);
}
dst_layout = VectorLayout(32, layout.offsets(), src_layout->tiling(),
ImplicitDim::kNone);
setLayout(op, src_layout, dst_layout);
return success();
// TODO(apaszke): Support native packed layouts here.
Layout src_layout;
Layout dst_layout;
// All layouts that subdivide the rows of the default tiling evenly
// can be handled uniformly with the default case, by preserving the
// tiling through the op.
if (default_tiling_[0] % layout.tiling()[0] == 0 &&
default_tiling_[1] == layout.tiling()[1]) {
src_layout = layout;
} else {
// TODO(b/335863273): we should also reduce offsets.
src_layout = VectorLayout(layout.bitwidth(), layout.offsets(),
default_tiling_, layout.implicit_dim());
}
if (layout.implicit_dim() == ImplicitDim::kSecondMinor) {
TPU_CHECK_OP(layout.tiling() == nativeTiling(16), "unsupported tiling");
auto dst_layout = VectorLayout(32, layout.offsets(), default_tiling_,
layout.implicit_dim());
setLayout(op, some_layout, dst_layout);
return success();
}
op->emitOpError("unsupported extension layout");
return failure();
dst_layout = VectorLayout(32, layout.offsets(), src_layout->tiling(),
layout.implicit_dim());
setLayout(op, src_layout, dst_layout);
return success();
}

LogicalResult inferTrunc(Operation *op) {
@@ -1523,20 +1513,16 @@ class VectorLayoutInferer {
"Only 32-bit truncation supported");
}
auto &layout = *some_layout;
if (layout.implicit_dim() == ImplicitDim::kNone) {
bool select_native = allUsersRequireNativeTiling(op->getResult(0));
auto src_layout = VectorLayout(32, layout.offsets(), default_tiling_,
ImplicitDim::kNone);
auto dst_layout = VectorLayout(
dst_ty.getElementTypeBitWidth(), layout.offsets(),
select_native ? nativeTiling(dst_ty.getElementTypeBitWidth())
: default_tiling_,
ImplicitDim::kNone);
setLayout(op, src_layout, dst_layout);
return success();
}
op->emitOpError("unsupported truncation layout");
return failure();
bool select_native = allUsersRequireNativeTiling(op->getResult(0));
auto src_layout = VectorLayout(32, layout.offsets(), default_tiling_,
layout.implicit_dim());
auto dst_layout = VectorLayout(
dst_ty.getElementTypeBitWidth(), layout.offsets(),
select_native ? nativeTiling(dst_ty.getElementTypeBitWidth())
: default_tiling_,
layout.implicit_dim());
setLayout(op, src_layout, dst_layout);
return success();
}

LogicalResult inferElementwise(Operation *op, bool check_bitwidth = true) {
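On the inference side, the change drops the implicit_dim() == kNone requirement: extensions keep the source tiling (falling back to the default tiling when it does not subdivide it) and infer a 32-bit destination with the same offsets and implicit dim, while truncations read the source as 32-bit with default tiling and give the destination native or default tiling depending on its users, again preserving the implicit dim. A toy sketch of the truncation rule, with a simplified Layout stand-in for VectorLayout (field names and the nativeTiling formula here are assumptions):

```cpp
#include <array>
#include <cstdint>
#include <iostream>
#include <utility>

// Toy stand-in for VectorLayout: bitwidth, offsets, tiling, and implicit dim.
enum class ImplicitDim { kNone, kMinor, kSecondMinor };
struct Layout {
  int bitwidth;
  std::array<int64_t, 2> offsets;
  std::array<int64_t, 2> tiling;
  ImplicitDim implicit_dim;
};

constexpr std::array<int64_t, 2> kDefaultTiling = {8, 128};

// Assumed native tiling: narrower types pack more sublanes per vreg.
std::array<int64_t, 2> NativeTiling(int bitwidth) {
  return {kDefaultTiling[0] * (32 / bitwidth), kDefaultTiling[1]};
}

// Sketch of the truncation inference: the source is treated as 32-bit with
// default tiling; the destination keeps offsets and implicit dim and uses
// native tiling only when all users require it.
std::pair<Layout, Layout> InferTruncLayouts(const Layout& in, int dst_bitwidth,
                                            bool users_want_native) {
  Layout src{32, in.offsets, kDefaultTiling, in.implicit_dim};
  Layout dst{dst_bitwidth, in.offsets,
             users_want_native ? NativeTiling(dst_bitwidth) : kDefaultTiling,
             in.implicit_dim};
  return {src, dst};
}

int main() {
  // A 1-D (implicit second-minor) 32-bit value truncated to 16 bits, with
  // users that want native tiling: destination tiling becomes (16, 128).
  Layout in{32, {0, 0}, kDefaultTiling, ImplicitDim::kSecondMinor};
  auto [src, dst] = InferTruncLayouts(in, /*dst_bitwidth=*/16,
                                      /*users_want_native=*/true);
  std::cout << "dst tiling: (" << dst.tiling[0] << ", " << dst.tiling[1]
            << "), implicit dim preserved: "
            << (dst.implicit_dim == in.implicit_dim ? "yes" : "no") << "\n";
}
```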