[Mosaic] Extend the TPU matmul op with dimension numbers. Add support for batching and simple transposition.

PiperOrigin-RevId: 691706218
jax authors 2024-10-31 00:58:28 -07:00
parent f355dcf34b
commit 5aeffde707
9 changed files with 542 additions and 12 deletions

View File

@@ -1615,6 +1615,7 @@ def _dot_general_lowering_rule(
)
return vector.shape_cast(out_type, red)
# TODO(mvoz): Plumb these into dot dimension numbers on the matmul op!
if lhs_dims == (1,):
transpose_lhs = False
elif lhs_dims == (0,):

View File

@@ -384,22 +384,47 @@ def TPU_DynamicGatherOp : TPU_Op<"dynamic_gather", [Pure]> {
}];
}
// TODO(apaszke): Add a verifier for this op
def TPU_DotDimensionNumbersAttr : TPU_Attr<"DotDimensionNumbers", "dot_dimension_numbers"> {
let parameters = (ins
ArrayRefParameter<"int64_t", "">:$lhs_contracting_dims,
ArrayRefParameter<"int64_t", "">:$rhs_contracting_dims,
ArrayRefParameter<"int64_t", "">:$lhs_non_contracting_dims,
ArrayRefParameter<"int64_t", "">:$rhs_non_contracting_dims,
// The output dim order is a flattened list of (operand, dim) pairs: the
// first element of each pair is 0 (lhs) or 1 (rhs), and the second is a
// dimension index into that operand.
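// For example, a plain (M, K) x (K, N) matmul has output_dim_order
// [0, 0, 1, 1]: pair (0, 0) selects lhs dim 0 (M) and pair (1, 1) selects
// rhs dim 1 (N).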
ArrayRefParameter<"int64_t", "">:$output_dim_order,
OptionalArrayRefParameter<"int64_t", "">:$lhs_batch_dims,
OptionalArrayRefParameter<"int64_t", "">:$rhs_batch_dims
);
let assemblyFormat = "`<` `[` $lhs_contracting_dims `]` `,` `[` $rhs_contracting_dims `]` `,` "
"`[` $lhs_non_contracting_dims `]` `,` `[` $rhs_non_contracting_dims `]` `,` "
"`[` $output_dim_order `]` `,` "
"`[` (`]`):($lhs_batch_dims^ `]`)? `,` "
"`[` (`]`):($rhs_batch_dims^ `]`)? `>`";
}
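// For illustration, the default MKN dimension numbers would print roughly
// as: #tpu.dot_dimension_numbers<[1], [0], [0], [1], [0, 0, 1, 1], [], []>
// (the exact text follows the assemblyFormat above).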
// TODO(apaszke): Think hard about precision
def TPU_MatmulOp : TPU_Op<"matmul", [Pure]> {
let arguments = (ins
AnyVector:$lhs,
AnyVector:$rhs,
AnyVector:$acc,
// These flags are deprecated: if dimension_numbers are defined, the flags
// are ignored. They will always be false after canonicalization.
DefaultValuedAttr<BoolAttr, "false">:$transpose_lhs,
DefaultValuedAttr<BoolAttr, "false">:$transpose_rhs,
OptionalAttr<TPU_ContractPrecisionEnum>:$precision,
// NOTE: Optional at the user level; always present after canonicalization.
OptionalAttr<TPU_DotDimensionNumbersAttr>:$dimension_numbers
);
let results = (outs AnyVector:$result);
let assemblyFormat = [{
$lhs `,` $rhs `,` $acc attr-dict `:` type($lhs) `,` type($rhs) `,` type($acc) `->` type($result)
}];
let hasCanonicalizer = 1;
let hasVerifier = 1;
}
def TPU_ConcatenateOp : TPU_Op<"concatenate", [Pure]> {

View File

@@ -223,4 +223,18 @@ bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel) {
return false;
}
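// Builds the dimension numbers of a plain 2-D matmul, folding the legacy
// transpose flags into the contracting/non-contracting dims. For
// transpose_lhs = transpose_rhs = false this yields lhs_contracting = [1],
// rhs_contracting = [0], lhs_non_contracting = [0],
// rhs_non_contracting = [1], and output_dim_order = [0, 0, 1, 1], with no
// batch dims.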
DotDimensionNumbersAttr defaultDimensionNumbers(Builder &builder,
bool transpose_lhs,
bool transpose_rhs) {
return tpu::DotDimensionNumbersAttr::get(
builder.getContext(),
/*lhs_contracting_dims=*/{transpose_lhs ? 0 : 1},
/*rhs_contracting_dims=*/{transpose_rhs ? 1 : 0},
/*lhs_non_contracting_dims=*/{transpose_lhs ? 1 : 0},
/*rhs_non_contracting_dims=*/{transpose_rhs ? 0 : 1},
/*output_dim_order=*/{0, transpose_lhs ? 1 : 0, 1, transpose_rhs ? 0 : 1},
/*lhs_batch_dims=*/{},
/*rhs_batch_dims=*/{});
}
} // namespace mlir::tpu

View File

@@ -104,6 +104,10 @@ MemRefType getMemRefType(Value value);
bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel = 8);
DotDimensionNumbersAttr defaultDimensionNumbers(Builder &builder,
bool transpose_lhs,
bool transpose_rhs);
#define GEN_PASS_REGISTRATION
#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"

View File

@@ -14,6 +14,9 @@ limitations under the License.
==============================================================================*/
#include <cstdint>
#include <optional>
#include <string_view>
#include <vector>
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
@@ -507,6 +510,280 @@ class CanonicalizeAddOfMatmul : public OpRewritePattern<AddOp> {
}
};
LogicalResult MatmulOp::verify() {
// Note - this is not yet an exhaustive verification of matmul. Many of the
// invariants are spread across infer, apply, llo and below. This is,
// however, a good start and the recommended place to add more invariants.
const VectorType lhs_ty = getLhs().getType();
const VectorType rhs_ty = getRhs().getType();
if (getTransposeLhs()) {
emitOpError(
"Lhs transpose not supported via this API - please use the "
"dimension numbers API.");
return failure();
}
if (getDimensionNumbers().has_value()) {
auto dimension_numbers = getDimensionNumbers().value();
auto lhs_contracting_dims = dimension_numbers.getLhsContractingDims();
auto rhs_contracting_dims = dimension_numbers.getRhsContractingDims();
if (lhs_contracting_dims.size() != 1) {
emitOpError("Not implemented: lhs contracting dims must be of size 1");
return failure();
}
if (rhs_contracting_dims.size() != 1) {
emitOpError("Not implemented: rhs contracting dims must be of size 1");
return failure();
}
auto lhs_contracting_dim = lhs_contracting_dims[0];
auto rhs_contracting_dim = rhs_contracting_dims[0];
auto lhs_batch_dims = dimension_numbers.getLhsBatchDims();
auto rhs_batch_dims = dimension_numbers.getRhsBatchDims();
auto lhs_non_contracting_dims =
dimension_numbers.getLhsNonContractingDims();
auto rhs_non_contracting_dims =
dimension_numbers.getRhsNonContractingDims();
if (lhs_contracting_dims.size() + lhs_non_contracting_dims.size() +
lhs_batch_dims.size() !=
lhs_ty.getShape().size()) {
emitOpError(
"Not implemented: lhs contracting + non contracting + batch dims "
"must be of the same size as the lhs shape");
return failure();
}
if (rhs_contracting_dims.size() + rhs_non_contracting_dims.size() +
rhs_batch_dims.size() !=
rhs_ty.getShape().size()) {
emitOpError(
"Not implemented: rhs contracting + non contracting + batch dims "
"must be of the same size as the rhs shape");
return failure();
}
if (lhs_ty.getShape()[lhs_contracting_dim] !=
rhs_ty.getShape()[rhs_contracting_dim]) {
emitOpError(
"Not implemented: lhs and rhs contracting dims must be of the same "
"size");
return failure();
}
if (lhs_batch_dims.size() != rhs_batch_dims.size()) {
emitOpError(
"Not implemented: lhs and rhs should have the same number of batch "
"dims");
return failure();
}
if (lhs_batch_dims.size() > 1) {
emitOpError("Not implemented: Up to 1 batch dim supported");
return failure();
}
int64_t lhs_rank = lhs_ty.getShape().size();
int64_t rhs_rank = rhs_ty.getShape().size();
std::vector<bool> seen_dims_lhs(lhs_rank, false);
std::vector<bool> seen_dims_rhs(rhs_rank, false);
auto check_and_mark_dims = [&](const std::vector<int64_t> &dims,
std::vector<bool> &seen_dims,
const std::string_view operand) {
for (int64_t dim : dims) {
if (seen_dims[dim]) {
emitOpError("Illegal: Dim ")
<< dim << " repeats in dimension numbers of " << operand;
return failure();
}
seen_dims[dim] = true;
}
return success();
};
if (failed(
check_and_mark_dims(lhs_contracting_dims, seen_dims_lhs, "lhs")) ||
failed(check_and_mark_dims(lhs_non_contracting_dims, seen_dims_lhs,
"lhs")) ||
failed(check_and_mark_dims(lhs_batch_dims, seen_dims_lhs, "lhs"))) {
return failure();
}
if (failed(
check_and_mark_dims(rhs_contracting_dims, seen_dims_rhs, "rhs")) ||
failed(check_and_mark_dims(rhs_non_contracting_dims, seen_dims_rhs,
"rhs")) ||
failed(check_and_mark_dims(rhs_batch_dims, seen_dims_rhs, "rhs"))) {
return failure();
}
for (int64_t dim = 0; dim < lhs_rank; ++dim) {
if (!seen_dims_lhs[dim]) {
emitOpError("Illegal: Dim ")
<< dim << " is not seen in lhs dimension numbers";
return failure();
}
}
for (int64_t dim = 0; dim < rhs_rank; ++dim) {
if (!seen_dims_rhs[dim]) {
emitOpError("Illegal: Dim ")
<< dim << " is not seen in rhs dimension numbers";
return failure();
}
}
const std::optional<int64_t> batch_dim_lhs =
lhs_batch_dims.empty() ? std::nullopt
: std::optional<int64_t>(lhs_batch_dims[0]);
const std::optional<int64_t> batch_dim_rhs =
rhs_batch_dims.empty() ? std::nullopt
: std::optional<int64_t>(rhs_batch_dims[0]);
if (batch_dim_lhs != batch_dim_rhs) {
emitOpError("Not implemented: batch dims must be equal");
return failure();
}
if (batch_dim_lhs.has_value() && (batch_dim_lhs.value() != 0)) {
emitOpError("Not implemented: batch dim position must be 0");
return failure();
}
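// e.g. a valid batched matmul here has lhs [B, M, K] and rhs [B, K, N],
// with lhs_batch_dims = rhs_batch_dims = [0].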
// The invariants above enforce at most one batch dim for now, and that the
// lhs and rhs batch dims are equal.
std::optional<int64_t> batch_size = std::nullopt;
if (batch_dim_lhs.has_value()) {
batch_size = lhs_ty.getShape()[batch_dim_lhs.value()];
auto rhs_batch_size = rhs_ty.getShape()[batch_dim_rhs.value()];
if (batch_size != rhs_batch_size) {
emitOpError("Not Implemented: batch dims must be equal");
return failure();
}
if (batch_size == 0) {
emitOpError("Illegal: batch size must be > 0");
return failure();
}
}
auto output_dim_order = dimension_numbers.getOutputDimOrder();
if (output_dim_order.size() % 2 != 0) {
emitOpError(
"Illegal: output dim order must have an even number of elements.");
return failure();
}
if (batch_size.has_value()) {
if (output_dim_order[0] != 0 || output_dim_order[1] != 0) {
emitOpError(
"Not implemented: with a batch size, the output dim order must start "
"with lhs dim 0 for now.");
return failure();
}
}
// Invariants above enforce a single batch idx for now, and that it is in
// position 0. Future extensions to this will be to:
// 1. Support multiple batch dims
// 2. Support batch dims in any position in the output dim order
if (lhs_non_contracting_dims.size() != 1) {
emitOpError(
"Not implemented: lhs non contracting dims must be of size 1");
return failure();
}
if (rhs_non_contracting_dims.size() != 1) {
emitOpError(
"Not implemented: rhs non contracting dims must be of size 1");
return failure();
}
// A bit long-winded, but the invariants we enforce below are:
// 1. Each output dim order index is 0 (lhs) or 1 (rhs)
// 2. Each output dim order dimension is within valid bounds
// 3. The lhs and rhs non-contracting dims appear in the output dim order
// 4. The contracting dims never appear in the output dim order
// 5. Each non-contracting dim appears only once
std::vector<bool> lhs_dims_seen_in_output(lhs_rank, false);
std::vector<bool> rhs_dims_seen_in_output(rhs_rank, false);
// Iterate over the output dimension order
for (int dim_pos = 0; dim_pos < output_dim_order.size(); dim_pos += 2) {
auto idx = output_dim_order[dim_pos];
auto dim = output_dim_order[dim_pos + 1];
if (idx != 0 && idx != 1) {
emitOpError("Illegal: output dim order index must be 0 or 1");
return failure();
}
auto is_lhs = (idx == 0);
if (is_lhs) {
if (dim < 0 || dim >= lhs_rank) {
emitOpError("Illegal: lhs dimension index out of bounds");
return failure();
}
if (lhs_dims_seen_in_output[dim]) {
emitOpError("Illegal: lhs dimension ")
<< dim << " appears more than once in output dim order";
return failure();
}
if (dim == lhs_contracting_dim) {
emitOpError("Illegal: contracting dimension ")
<< dim << " appears in lhs output dim order";
return failure();
}
// batch_dim_lhs is either 0 or nullopt
if (dim == batch_dim_lhs) {
// Upstream invariants enforce that batch dim is in position 0
// of the output dim order.
rhs_dims_seen_in_output[dim] = true;
}
lhs_dims_seen_in_output[dim] = true;
} else {
if (dim < 0 || dim >= rhs_rank) {
emitOpError("Illegal: rhs dimension index out of bounds");
return failure();
}
if (rhs_dims_seen_in_output[dim]) {
emitOpError("Illegal: rhs dimension ")
<< dim << " appears more than once in output dim order";
return failure();
}
if (dim == rhs_contracting_dim) {
emitOpError("Illegal: contracting dimension ")
<< dim << " appears in rhs output dim order";
return failure();
}
if (dim == batch_dim_rhs) {
// Upstream invariants enforce that batch dim is in position 0
// of the output dim order.
lhs_dims_seen_in_output[dim] = true;
}
rhs_dims_seen_in_output[dim] = true;
}
}
// Check that all dims have been seen (except contracting dims)
for (int i = 0; i < lhs_rank; ++i) {
if (i == lhs_contracting_dim) {
continue;
}
if (!lhs_dims_seen_in_output[i]) {
emitOpError("Illegal: lhs non-contracting dimension ")
<< i << " is not seen in output dim order";
return failure();
}
}
for (int i = 0; i < rhs_rank; ++i) {
if (i == rhs_contracting_dim) {
continue;
}
if (!rhs_dims_seen_in_output[i]) {
emitOpError("Illegal: rhs non-contracting dimension ")
<< i << " is not seen in output dim order";
return failure();
}
}
}
return success();
}
void MatmulOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<CanonicalizeAddOfMatmul<arith::AddFOp>,

View File

@@ -1696,15 +1696,36 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
llvm::all_of(layouts_in, [&](const Layout &l) { return l.has_value(); }));
TPU_ASSERT_OP(layouts_out.front().has_value());
auto matmul_op = cast<tpu::MatmulOp>(op);
if (matmul_op.getTransposeRhs()) {
return op.emitOpError(
"Transposition must have been erased into dimension numbers during "
"canonicalization");
}
auto dimension_numbers = matmul_op.getDimensionNumbers();
if (!dimension_numbers.has_value()) {
return op.emitOpError(
"Dimension numbers must be provided, ensure canonicalization has been "
"run.");
}
auto transposed_mkn = isTransposedMatmul(dimension_numbers.value());
if (!transposed_mkn.has_value()) {
return op.emitOpError(
"Dimension numbers must be MKN, ensure canonicalization has been "
"run.");
}
auto [transpose_lhs, transpose_rhs] = transposed_mkn.value();
if (transpose_lhs) {
return op.emitOpError(
"Transposition of LHS is not supported in apply_vector_layout, ensure "
"canonicalization has been run.");
}
auto &layout_lhs = *layouts_in[0];
auto &layout_rhs = *layouts_in[1];
auto &layout_acc = *layouts_in[2];
auto &layout_out = *layouts_out[0];
const std::array<std::reference_wrapper<const VectorLayout>, 4> all_layouts =
{layout_lhs, layout_rhs, layout_acc, layout_out};
for (const VectorLayout &layout : all_layouts) {
@@ -1965,6 +1986,8 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
const tpu::ContractPrecisionAttr precision_attr = // May be null
op.getAttrOfType<tpu::ContractPrecisionAttr>("precision");
const tpu::DotDimensionNumbersAttr dot_dimension_numbers_attr =
defaultDimensionNumbers(builder, false, transpose_rhs);
for (int64_t j = 0; j < nj; ++j) {
for (int64_t k = 0; k < nk; ++k) {
// TODO(tlongeri): there should be a way to slice without copying
@@ -1981,7 +2004,8 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
acc_col->setAttr("out_layout", acc_layout_attr);
auto new_acc_col = builder.create<tpu::MatmulOp>(
op.getLoc(), acc_col_ty, lhs_cols[k], rhs_rolled_group, acc_col,
/*transpose_lhs=*/false, /*transpose_rhs=*/false, precision_attr,
dot_dimension_numbers_attr);
auto new_acc_vregs = builder.create<tpu::UnrollVectorsOp>(
op.getLoc(),
TypeRange(ValueRange(XlaArrayToFlatArrayRef(acc_col_vregs))),

View File

@@ -1,6 +1,10 @@
#include <algorithm>
#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <optional>
#include <utility>
#include <vector>
#include "llvm/ADT/STLExtras.h"
@@ -16,6 +20,8 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "absl/log/check.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h"
@@ -23,6 +29,7 @@
#include "mlir/include/mlir/IR/Attributes.h"
#include "mlir/include/mlir/IR/Block.h"
#include "mlir/include/mlir/IR/Builders.h"
#include "mlir/include/mlir/IR/BuiltinAttributes.h"
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/include/mlir/IR/OpDefinition.h"
#include "mlir/include/mlir/IR/Operation.h"
@@ -40,6 +47,9 @@ namespace mlir::tpu {
LogicalResult tpu_matmul_rule(tpu::MatmulOp op) {
ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation());
auto transpose_lhs = op.getTransposeLhs();
auto transpose_rhs = op.getTransposeRhs();
auto lhs = op.getLhs();
auto rhs = op.getRhs();
auto acc = op.getAcc();
@@ -52,6 +62,51 @@ LogicalResult tpu_matmul_rule(tpu::MatmulOp op) {
auto rhs_element_type = rhs_ty.getElementType();
auto acc_element_type = acc_ty.getElementType();
// There are three primary paths for dimension_numbers in matmul:
// 1) No dimension numbers provided -> set to the default.
// 2) Defined and not default -> verify and apply.
// 3) Defined and matching defaultDimensionNumbers -> no-op for
// canonicalization of dims.
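// e.g. the legacy transpose_rhs = true maps to rhs_contracting_dims = [1],
// rhs_non_contracting_dims = [0], and output_dim_order = [0, 0, 1, 0].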
std::optional<int64_t> batch_size = std::nullopt;
// Path 1 (legacy API): MKN matmul with no dimension numbers set - convert
// the transpose flags into default dimension numbers.
if (!op.getDimensionNumbers().has_value()) {
op.setDimensionNumbersAttr(
defaultDimensionNumbers(builder, transpose_lhs, transpose_rhs));
} else if (
// Path 2: dimension numbers are provided and are not the default
(op.getDimensionNumbers().value() !=
defaultDimensionNumbers(builder, false, false))) {
auto dimension_numbers = op.getDimensionNumbers();
auto lhs_contracting_dims = dimension_numbers->getLhsContractingDims();
auto rhs_contracting_dims = dimension_numbers->getRhsContractingDims();
auto lhs_batch_dims = dimension_numbers->getLhsBatchDims();
auto rhs_batch_dims = dimension_numbers->getRhsBatchDims();
// Invariant from the matmul verifier: at most one batch dim for now, and
// the lhs and rhs batch dims are equal.
// Invariant from the matmul verifier: exactly one contracting and one
// non-contracting dim in each of lhs and rhs for now.
batch_size =
lhs_batch_dims.empty()
? std::nullopt
: std::optional<int64_t>(lhs_ty.getShape()[lhs_batch_dims[0]]);
// Lower each contracting dim by the number of batch dims.
auto batch_adjusted_lhs_contracting_dim =
lhs_contracting_dims[0] - lhs_batch_dims.size();
auto batch_adjusted_rhs_contracting_dim =
rhs_contracting_dims[0] - rhs_batch_dims.size();
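// e.g. lhs [B, M, K] with lhs_contracting_dims = [2] and one batch dim
// adjusts to 1 (not transposed), while lhs [B, K, M] with
// lhs_contracting_dims = [1] adjusts to 0 and marks the lhs transposed.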
if (batch_adjusted_lhs_contracting_dim != 1) {
transpose_lhs = true;
}
if (batch_adjusted_rhs_contracting_dim != 0) {
transpose_rhs = true;
}
}
auto extsi_sitofp = [&builder, &op](TypedValue<VectorType> element) {
const VectorType ty = element.getType();
auto shape = ty.getShape();
@@ -88,10 +143,12 @@ LogicalResult tpu_matmul_rule(tpu::MatmulOp op) {
if (lhs_element_type.isInteger()) {
auto float_lhs = extsi_sitofp(lhs);
op->setOperand(0, float_lhs);
lhs = cast<TypedValue<VectorType>>(float_lhs.getResult());
}
if (rhs_element_type.isInteger()) {
auto float_rhs = extsi_sitofp(rhs);
op->setOperand(1, float_rhs);
rhs = cast<TypedValue<VectorType>>(float_rhs.getResult());
}
}
// TODO(mvoz): Add more invariants.
@@ -114,6 +171,91 @@ LogicalResult tpu_matmul_rule(tpu::MatmulOp op) {
return failure();
}
}
auto dot_dim_matmul = [&](auto lhs, auto rhs, auto acc) {
auto precision_attr = op.getPrecisionAttr();
// If we are transposing the lhs, we need to transpose the lhs before
// matmul here, as we don't have lhs fusion implemented in apply.
if (transpose_lhs) {
auto lhs_ty = cast<VectorType>(lhs.getType());
auto rank = lhs_ty.getShape().size();
// This transposition must run on vectors with rank >= 2
CHECK_GE(rank, 2);
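// Build the identity permutation and swap the two minormost dims,
// e.g. for rank 3: [0, 1, 2] -> [0, 2, 1].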
std::vector<int64_t> perm(rank);
std::iota(perm.begin(), perm.end(), 0);
std::swap(perm[rank - 2], perm[rank - 1]);
std::vector<int64_t> shape(lhs_ty.getShape());
std::swap(shape[rank - 2], shape[rank - 1]);
auto lhs_ty_transposed = VectorType::get(shape, lhs_ty.getElementType());
const SmallVector<int64_t> perm_vec =
SmallVector<int64_t>(perm.begin(), perm.end());
lhs = builder.create<vector::TransposeOp>(
lhs_ty_transposed, lhs,
DenseI64ArrayAttr::get(builder.getContext(), perm_vec));
}
auto ddn = defaultDimensionNumbers(builder, /*transpose_lhs=*/false,
transpose_rhs);
// The transpose flags are always false here, because the dimension
// numbers take precedence after this pass.
auto matmul_res = builder.create<tpu::MatmulOp>(
op.getLoc(), acc.getType(), lhs, rhs, acc,
/*transpose_lhs=*/false,
/*transpose_rhs=*/false, precision_attr, ddn);
return matmul_res;
};
// If we have a batch_size, we slice the lhs and rhs along the batch dim
// and compute O[i] = A[i] @ B[i] for each slice,
// producing an output of shape [batch_size, m, n].
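// e.g. lhs [B, M, K] @ rhs [B, K, N] unrolls into B matmuls of
// [M, K] @ [K, N], each shape_cast to [1, M, N] and concatenated along
// dim 0 into [B, M, N].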
if (batch_size.has_value()) {
std::vector<Value> outputs;
for (int64_t i = 0; i < batch_size; ++i) {
auto sliced_lhs = builder.create<vector::ExtractOp>(op.getLoc(), lhs,
ArrayRef<int64_t>{i});
auto sliced_rhs = builder.create<vector::ExtractOp>(op.getLoc(), rhs,
ArrayRef<int64_t>{i});
auto sliced_acc = builder.create<vector::ExtractOp>(op.getLoc(), acc,
ArrayRef<int64_t>{i});
auto matmul_res =
dot_dim_matmul(sliced_lhs.getResult(), sliced_rhs.getResult(),
sliced_acc.getResult());
auto res_ty = matmul_res.getType().cast<VectorType>();
auto res_shape = res_ty.getShape();
// reshape to 1x[prior_shape]
auto reshape_shape = llvm::to_vector(res_shape);
reshape_shape.insert(reshape_shape.begin(), 1);
auto shape_cast = builder.create<vector::ShapeCastOp>(
op.getLoc(), VectorType::get(reshape_shape, res_ty.getElementType()),
matmul_res);
outputs.push_back(shape_cast);
}
// The batch_size == 1 case is technically almost identical to the general
// case, but we want to avoid the spurious concat here.
if (batch_size == 1) {
op.replaceAllUsesWith(outputs[0]);
op.erase();
return success();
}
auto output = builder
.create<tpu::ConcatenateOp>(op.getLoc(), acc_ty, outputs,
/*dimension=*/0)
.getResult();
op.replaceAllUsesWith(output);
op.erase();
} else {
auto matmul_res = dot_dim_matmul(lhs, rhs, acc).getResult();
op.replaceAllUsesWith(matmul_res);
op.erase();
}
return success();
};
@@ -309,9 +451,14 @@ LogicalResult canonicalize_contraction(int hardware_generation, Operation &op) {
}
const tpu::ContractPrecisionAttr precision_attr = // May be null
contraction_op->getAttrOfType<tpu::ContractPrecisionAttr>("precision");
const auto dot_dimension_numbers_attr =
defaultDimensionNumbers(builder, false, transpose_rhs);
auto matmul_op = builder.create<tpu::MatmulOp>(
contraction_op->getLoc(), acc_ty, lhs, rhs, acc,
/*transpose_lhs=*/false,
/*transpose_rhs=*/false, precision_attr, dot_dimension_numbers_attr);
contraction_op.replaceAllUsesWith(matmul_op.getResult());
contraction_op.erase();
auto result = tpu_matmul_rule(matmul_op);

View File

@@ -17,6 +17,8 @@ limitations under the License.
#include <array>
#include <cstdint>
#include <optional>
#include <utility>
#include "llvm/Support/MathExtras.h"
#include "absl/types/span.h"
@@ -42,6 +44,31 @@ SmallVector<int64_t> ComputeTileStrides(MemRefType memref_ty,
return tile_strides;
}
std::optional<std::pair<bool, bool>> isTransposedMatmul(
DotDimensionNumbersAttr dim_numbers) {
auto lhs_contracting_dims = dim_numbers.getLhsContractingDims();
auto rhs_contracting_dims = dim_numbers.getRhsContractingDims();
auto lhs_non_contracting_dims = dim_numbers.getLhsNonContractingDims();
auto rhs_non_contracting_dims = dim_numbers.getRhsNonContractingDims();
if (lhs_contracting_dims.size() != 1 || rhs_contracting_dims.size() != 1 ||
lhs_non_contracting_dims.size() != 1 ||
rhs_non_contracting_dims.size() != 1) {
return std::nullopt;
}
int64_t lhs_non_contracting_dim = lhs_non_contracting_dims[0];
int64_t lhs_contracting_dim = lhs_contracting_dims[0];
int64_t rhs_non_contracting_dim = rhs_non_contracting_dims[0];
int64_t rhs_contracting_dim = rhs_contracting_dims[0];
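// e.g. a plain [M, K] x [K, N] matmul has lhs nc = 0 < c = 1 and
// rhs c = 0 < nc = 1, so neither is transposed; an lhs of [K, M]
// (nc = 1 > c = 0) or an rhs of [N, K] (c = 1 > nc = 0) is transposed.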
bool lhs_transposed = lhs_non_contracting_dim > lhs_contracting_dim;
bool rhs_transposed = rhs_contracting_dim > rhs_non_contracting_dim;
return std::pair<bool, bool>{lhs_transposed, rhs_transposed};
}
bool canReinterpretToUntiledMemref(MemRefType tiled_memref_ty,
const std::array<int64_t, 2>& target_shape,
bool allow_minormost_padding) {
@@ -68,4 +95,5 @@ bool canReinterpretToUntiledMemref(MemRefType tiled_memref_ty,
*(tiled_layout.getTileStrides().end() - 1) == 1 &&
*(tiled_layout.getTileStrides().end() - 2) == 1);
}
} // namespace mlir::tpu

View File

@@ -16,6 +16,7 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "absl/types/span.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "tsl/platform/statusor.h"
// TODO: Instead of CHECK_EQs, can we do something like TF_RET_CHECK but with
@@ -98,6 +99,14 @@ std::string shapeToString(const T &shape) {
SmallVector<int64_t> ComputeTileStrides(MemRefType memref_ty,
absl::Span<const int64_t> tiling);
// Assumes an MKN matmul - this function must only be called after the
// canonicalization passes have run.
//
// Given a set of dimension numbers, returns a pair of booleans: the first
// is true if the lhs is transposed and the second is true if the rhs is
// transposed.
std::optional<std::pair<bool, bool>> isTransposedMatmul(
DotDimensionNumbersAttr dim_numbers);
// Returns true if a >=2D memref has a tiled layout and can be equivalently
// considered as an untiled memref, except for potential padding in the
@@ -106,6 +115,7 @@ SmallVector<int64_t> ComputeTileStrides(MemRefType memref_ty,
bool canReinterpretToUntiledMemref(MemRefType tiled_memref_ty,
const std::array<int64_t, 2> &target_shape,
bool allow_minormost_padding = false);
} // namespace mlir::tpu
#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_