[XLA:Mosaic] Support trunc/ext op for 1D vector with any implicit dim.

PiperOrigin-RevId: 626466602
This commit is contained in:
Jevin Jiang 2024-04-19 14:13:41 -07:00 committed by jax authors
parent 6e23c14f85
commit 167161706c
3 changed files with 123 additions and 145 deletions

View File

@ -270,7 +270,6 @@ class VectorLayout {
SmallVector<int64_t> implicitShape(ArrayRef<int64_t> shape) const;
private:
SmallVector<int64_t> tileArrayImplicitShape(
ArrayRef<int64_t> shape, std::array<int64_t, 2> target_shape) const;

View File

@ -658,15 +658,23 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op,
const auto result_ty = cast<VectorType>(op.getResult().getType());
auto source = cast<TypedValue<VectorType>>(op.getIn());
const auto source_ty = source.getType();
auto output_vregs_shape =
layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape);
if (layout_out.bitwidth() != 32) {
return op.emitOpError(
"Not implemented: Only extensions to 32-bit supported");
}
FAILUREOR_ASSIGN_OR_RETURN(
const xla::Array<Value> input_vregs,
xla::Array<Value> input_vregs,
disassemble(builder, layout_in, source, ctx.target_shape));
xla::Array<Value> output_vregs(
layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape));
xla::Array<Value> output_vregs(output_vregs_shape);
// TODO(jevinjiang): maybe just use tileArrayImplicitShape in disassemble?
if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
input_vregs.Reshape(layout_in.tileArrayImplicitShape(source_ty.getShape(),
ctx.target_shape));
output_vregs.Reshape(layout_out.tileArrayImplicitShape(result_ty.getShape(),
ctx.target_shape));
}
FAILUREOR_ASSIGN_OR_RETURN(
const VectorType res_vreg_ty,
getNativeVregType(result_ty.getElementType(), ctx.target_shape));
@ -676,51 +684,24 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op,
if (layout_in.offsets() != layout_out.offsets()) {
return op.emitOpError("Not implemented: Change of offsets during the cast");
}
switch (layout_in.implicit_dim()) {
case VectorLayout::ImplicitDim::kNone: {
if (layout_in.tiling() != layout_out.tiling()) {
return op.emitOpError(
"Not implemented: Changing tiling during the cast");
}
auto tiling = layout_in.tiling();
if (ctx.target_shape[0] % tiling[0] != 0 ||
ctx.target_shape[1] != tiling[1]) {
return op.emitOpError("Not implemented: tiling not supported");
}
const int packing = layout_in.packing();
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
SmallVector<int64_t> input_vreg_idxs(toArrayRef(idxs));
input_vreg_idxs.back() /= packing;
const int64_t vreg_part = idxs.back() % packing;
*v = builder.create<UnpackSubelementsOp>(
res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part);
});
} break;
case VectorLayout::ImplicitDim::kMinor:
return op.emitOpError(
"Not implemented: Only casts of lane-oriented values supported");
case VectorLayout::ImplicitDim::kSecondMinor: {
auto is_one_tile = [](VectorType vty, VectorLayout layout) {
auto implicit_shape = layout.implicitShape(vty.getShape());
auto tiled_shape = ArrayRef<int64_t>(implicit_shape).take_back(2);
return (layout.offsets()[0].value_or(0) + tiled_shape[0] <=
layout.tiling()[0]) &&
(layout.offsets()[1].value_or(0) + tiled_shape[1] <=
layout.tiling()[1]);
};
if (input_vregs.dimensions() != absl::Span<const int64_t>{1} ||
output_vregs.dimensions() != absl::Span<const int64_t>{1} ||
!is_one_tile(source_ty, layout_in) ||
!is_one_tile(result_ty, layout_out)) {
return op.emitOpError("Not implemented");
}
if (layout_in.offsets()[0] >= ctx.target_shape[0]) {
return op.emitOpError("Not implemented");
}
auto unpack_subelements_op = builder.create<UnpackSubelementsOp>(
res_vreg_ty, *input_vregs.begin(), 0);
output_vregs.Fill(unpack_subelements_op.getResult());
}
if (layout_in.tiling() != layout_out.tiling()) {
return op.emitOpError("Not implemented: Changing tiling during the cast");
}
auto tiling = layout_in.tiling();
if (ctx.target_shape[0] % tiling[0] != 0 ||
ctx.target_shape[1] != tiling[1]) {
return op.emitOpError("Not implemented: tiling not supported");
}
const int packing = layout_in.packing();
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
SmallVector<int64_t> input_vreg_idxs(toArrayRef(idxs));
input_vreg_idxs.back() /= packing;
const int64_t vreg_part = idxs.back() % packing;
*v = builder.create<UnpackSubelementsOp>(
res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part);
});
if (layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
output_vregs.Reshape(output_vregs_shape);
}
op.replaceAllUsesWith(assemble(builder, result_ty, layout_out,
std::move(output_vregs), ctx.target_shape)
@ -762,73 +743,85 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
const VectorLayout &layout_in,
const VectorLayout &layout_out) {
ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation());
auto source = cast<TypedValue<VectorType>>(op.getIn());
const auto source_ty = source.getType();
auto result_ty = cast<VectorType>(op.getResult().getType());
auto output_vregs_shape =
layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape);
FAILUREOR_ASSIGN_OR_RETURN(
const xla::Array<Value> input_vregs,
disassemble(builder, layout_in, cast<TypedValue<VectorType>>(op.getIn()),
ctx.target_shape));
xla::Array<Value> output_vregs(
layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape));
xla::Array<Value> input_vregs,
disassemble(builder, layout_in, source, ctx.target_shape));
xla::Array<Value> output_vregs(output_vregs_shape);
if (layout_in.bitwidth() != 32) {
return op.emitOpError("Not implemented: Only 32-bit truncation supported");
}
if (layout_in.offsets() != layout_out.offsets()) {
return op.emitOpError(
"Not implemented: Change of offsets during the truncation");
}
if (layout_in.implicit_dim() != layout_out.implicit_dim()) {
return op.emitOpError("Not implemented: Change of layout during the cast");
}
if (layout_in.tiling() != ctx.target_shape) {
return op.emitOpError("Not implemented: Only (8,128) tiling supported");
}
if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
input_vregs.Reshape(layout_in.tileArrayImplicitShape(source_ty.getShape(),
ctx.target_shape));
output_vregs.Reshape(layout_out.tileArrayImplicitShape(result_ty.getShape(),
ctx.target_shape));
}
FAILUREOR_ASSIGN_OR_RETURN(
VectorType res_vreg_ty,
getNativeVregType(result_ty.getElementType(), ctx.target_shape));
if (layout_in.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
layout_out.implicit_dim() == VectorLayout::ImplicitDim::kNone) {
if (layout_in.tiling() != ctx.target_shape) {
return op.emitOpError("Not implemented: Only (8,128) tiling supported");
}
if (layout_out.tiling() == ctx.target_shape) {
const int packing = layout_out.packing();
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
SmallVector<Value> parts;
SmallVector<int64_t> idxs_local(toArrayRef(idxs));
idxs_local.back() *= packing;
for (int64_t i = 0; i < packing; ++i) {
parts.push_back(input_vregs(idxs_local));
// Pack any data lying around if OOB
if (idxs_local.back() < input_vregs.dimensions().back() - 1) {
++idxs_local.back();
}
}
*v = builder.create<PackSubelementsOp>(res_vreg_ty, parts);
});
} else if (layout_out.hasNativeTiling(ctx.target_shape)) {
int packing = layout_out.packing();
if (layout_out.tiling() == ctx.target_shape) {
const int packing = layout_out.packing();
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
SmallVector<Value> parts;
parts.reserve(packing);
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
CHECK_GE(idxs.size(), 2);
SmallVector<int64_t> idxs_local(toArrayRef(idxs));
idxs_local[idxs.size() - 2] *= packing;
SmallVector<int64_t> idxs_local(toArrayRef(idxs));
idxs_local.back() *= packing;
for (int64_t i = 0; i < packing; ++i) {
parts.push_back(input_vregs(idxs_local));
idxs_local[idxs.size() - 2]++;
while (parts.size() < packing) {
if (*(idxs_local.end() - 2) < *(input_vregs.dimensions().end() - 2)) {
parts.push_back(input_vregs(idxs_local));
idxs_local[idxs.size() - 2]++;
} else {
// Once we run out of tiles, we can pick any one we like.
parts.push_back(parts.back());
}
// Pack any data lying around if OOB
if (idxs_local.back() < input_vregs.dimensions().back() - 1) {
++idxs_local.back();
}
*v = builder.create<PackSubelementsOp>(res_vreg_ty, parts);
parts.clear();
});
} else {
return op.emitOpError("Not implemented: unsupported output tiling");
}
op.replaceAllUsesWith(assemble(builder, result_ty, layout_out,
std::move(output_vregs), ctx.target_shape)
.getResult());
op.erase();
return success();
}
*v = builder.create<PackSubelementsOp>(res_vreg_ty, parts);
});
} else if (layout_out.hasNativeTiling(ctx.target_shape)) {
int packing = layout_out.packing();
SmallVector<Value> parts;
parts.reserve(packing);
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
CHECK_GE(idxs.size(), 2);
SmallVector<int64_t> idxs_local(toArrayRef(idxs));
idxs_local[idxs.size() - 2] *= packing;
parts.push_back(input_vregs(idxs_local));
idxs_local[idxs.size() - 2]++;
while (parts.size() < packing) {
if (*(idxs_local.end() - 2) < *(input_vregs.dimensions().end() - 2)) {
parts.push_back(input_vregs(idxs_local));
idxs_local[idxs.size() - 2]++;
} else {
// Once we run out of tiles, we can pick any one we like.
parts.push_back(parts.back());
}
}
*v = builder.create<PackSubelementsOp>(res_vreg_ty, parts);
parts.clear();
});
} else {
return op.emitOpError("Not implemented: unsupported output tiling");
}
// TODO(tlongeri): why wasn't this part of the original code?
return op.emitOpError("Not implemented");
if (layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
output_vregs.Reshape(output_vregs_shape);
}
op.replaceAllUsesWith(assemble(builder, result_ty, layout_out,
std::move(output_vregs), ctx.target_shape)
.getResult());
op.erase();
return success();
}
LogicalResult arith_truncf_rule(RewriteContext &ctx, Operation &op,

View File

@ -1473,34 +1473,24 @@ class VectorLayoutInferer {
"Only extensions to 32-bit supported");
}
auto &layout = *some_layout;
if (layout.implicit_dim() == ImplicitDim::kNone) {
// TODO(apaszke): Support native packed layouts here.
Layout src_layout;
Layout dst_layout;
// All layouts that subdivide the rows of the default tiling evenly
// can be handled uniformly with the default case, by preserving the
// tiling through the op.
if (default_tiling_[0] % layout.tiling()[0] == 0 &&
default_tiling_[1] == layout.tiling()[1]) {
src_layout = layout;
} else {
src_layout = VectorLayout(layout.bitwidth(), layout.offsets(),
default_tiling_, ImplicitDim::kNone);
}
dst_layout = VectorLayout(32, layout.offsets(), src_layout->tiling(),
ImplicitDim::kNone);
setLayout(op, src_layout, dst_layout);
return success();
// TODO(apaszke): Support native packed layouts here.
Layout src_layout;
Layout dst_layout;
// All layouts that subdivide the rows of the default tiling evenly
// can be handled uniformly with the default case, by preserving the
// tiling through the op.
if (default_tiling_[0] % layout.tiling()[0] == 0 &&
default_tiling_[1] == layout.tiling()[1]) {
src_layout = layout;
} else {
// TODO(b/335863273): we should also reduce offsets.
src_layout = VectorLayout(layout.bitwidth(), layout.offsets(),
default_tiling_, layout.implicit_dim());
}
if (layout.implicit_dim() == ImplicitDim::kSecondMinor) {
TPU_CHECK_OP(layout.tiling() == nativeTiling(16), "unsupported tiling");
auto dst_layout = VectorLayout(32, layout.offsets(), default_tiling_,
layout.implicit_dim());
setLayout(op, some_layout, dst_layout);
return success();
}
op->emitOpError("unsupported extension layout");
return failure();
dst_layout = VectorLayout(32, layout.offsets(), src_layout->tiling(),
layout.implicit_dim());
setLayout(op, src_layout, dst_layout);
return success();
}
LogicalResult inferTrunc(Operation *op) {
@ -1523,20 +1513,16 @@ class VectorLayoutInferer {
"Only 32-bit truncation supported");
}
auto &layout = *some_layout;
if (layout.implicit_dim() == ImplicitDim::kNone) {
bool select_native = allUsersRequireNativeTiling(op->getResult(0));
auto src_layout = VectorLayout(32, layout.offsets(), default_tiling_,
ImplicitDim::kNone);
auto dst_layout = VectorLayout(
dst_ty.getElementTypeBitWidth(), layout.offsets(),
select_native ? nativeTiling(dst_ty.getElementTypeBitWidth())
: default_tiling_,
ImplicitDim::kNone);
setLayout(op, src_layout, dst_layout);
return success();
}
op->emitOpError("unsupported truncation layout");
return failure();
bool select_native = allUsersRequireNativeTiling(op->getResult(0));
auto src_layout = VectorLayout(32, layout.offsets(), default_tiling_,
layout.implicit_dim());
auto dst_layout = VectorLayout(
dst_ty.getElementTypeBitWidth(), layout.offsets(),
select_native ? nativeTiling(dst_ty.getElementTypeBitWidth())
: default_tiling_,
layout.implicit_dim());
setLayout(op, src_layout, dst_layout);
return success();
}
LogicalResult inferElementwise(Operation *op, bool check_bitwidth = true) {