[Mosaic] Extend tpu matmul op to have dimension numbers. Add support for batching and simple transposition.
PiperOrigin-RevId: 691706218
This commit is contained in:
parent f355dcf34b
commit 5aeffde707
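
For orientation, a sketch (illustrative shapes and names, using the assembly formats added below) of a canonicalized op carrying the new attribute:

  %out = tpu.matmul %lhs, %rhs, %acc {
      dimension_numbers = #tpu.dot_dimension_numbers<[1], [0], [0], [1], [0, 0, 1, 1], [], []>
    } : vector<128x256xf32>, vector<256x512xf32>, vector<128x512xf32> -> vector<128x512xf32>

This is the default MKN contraction: lhs contracts dim 1 with rhs dim 0, and the output dim order is (lhs 0, rhs 1).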
@@ -1615,6 +1615,7 @@ def _dot_general_lowering_rule(
    )
    return vector.shape_cast(out_type, red)

  # TODO(mvoz): Plumb these into dot dimension numbers on the matmul op!
  if lhs_dims == (1,):
    transpose_lhs = False
  elif lhs_dims == (0,):
@@ -384,22 +384,47 @@ def TPU_DynamicGatherOp : TPU_Op<"dynamic_gather", [Pure]> {
  }];
}

// TODO(apaszke): Add a verifier for this op

def TPU_DotDimensionNumbersAttr : TPU_Attr<"DotDimensionNumbers", "dot_dimension_numbers"> {
  let parameters = (ins
    ArrayRefParameter<"int64_t", "">:$lhs_contracting_dims,
    ArrayRefParameter<"int64_t", "">:$rhs_contracting_dims,
    ArrayRefParameter<"int64_t", "">:$lhs_non_contracting_dims,
    ArrayRefParameter<"int64_t", "">:$rhs_non_contracting_dims,
    // A flattened list of (operand, dimension) pairs: the first element of
    // each pair is always 0 (lhs) or 1 (rhs), and the second is a dimension
    // index into that operand.
    ArrayRefParameter<"int64_t", "">:$output_dim_order,
    OptionalArrayRefParameter<"int64_t", "">:$lhs_batch_dims,
    OptionalArrayRefParameter<"int64_t", "">:$rhs_batch_dims
  );
  let assemblyFormat = "`<` `[` $lhs_contracting_dims `]` `,` `[` $rhs_contracting_dims `]` `,` "
                       "`[` $lhs_non_contracting_dims `]` `,` `[` $rhs_non_contracting_dims `]` `,` "
                       "`[` $output_dim_order `]` `,` "
                       "`[` (`]`):($lhs_batch_dims^ `]`)? `,` "
                       "`[` (`]`):($rhs_batch_dims^ `]`)? `>`";
}

// TODO(apaszke): Think hard about precision
def TPU_MatmulOp : TPU_Op<"matmul", [Pure]> {
  let arguments = (ins
    AnyVector:$lhs,
    AnyVector:$rhs,
    AnyVector:$acc,
    // These flags are deprecated - if dimension_numbers are defined,
    // these flags are ignored. They will always be false after
    // canonicalization.
    DefaultValuedAttr<BoolAttr, "false">:$transpose_lhs,
    DefaultValuedAttr<BoolAttr, "false">:$transpose_rhs,
    OptionalAttr<TPU_ContractPrecisionEnum>:$precision
    OptionalAttr<TPU_ContractPrecisionEnum>:$precision,
    // NOTE: User-level optional; once canonicalized, always present.
    OptionalAttr<TPU_DotDimensionNumbersAttr>:$dimension_numbers
  );
  let results = (outs AnyVector:$result);
  let assemblyFormat = [{
    $lhs `,` $rhs `,` $acc attr-dict `:` type($lhs) `,` type($rhs) `,` type($acc) `->` type($result)
  }];
  let hasCanonicalizer = 1;
  let hasVerifier = 1;
}

def TPU_ConcatenateOp : TPU_Op<"concatenate", [Pure]> {
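
To make the pair encoding concrete, a hypothetical batched case: lhs vector<8x128x256xf32> and rhs vector<8x256x512xf32>, batched over dim 0 and producing vector<8x128x512xf32>, would be described by

  #tpu.dot_dimension_numbers<[2], [1], [1], [2], [0, 0, 0, 1, 1, 2], [0], [0]>

where the output dim order decodes as the pairs (0,0), (0,1), (1,2): lhs batch dim 0, then lhs non-contracting dim 1, then rhs non-contracting dim 2.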
@@ -223,4 +223,18 @@ bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel) {
  return false;
}

DotDimensionNumbersAttr defaultDimensionNumbers(Builder &builder,
                                                bool transpose_lhs,
                                                bool transpose_rhs) {
  return tpu::DotDimensionNumbersAttr::get(
      builder.getContext(),
      /*lhs_contracting_dims=*/{transpose_lhs ? 0 : 1},
      /*rhs_contracting_dims=*/{transpose_rhs ? 1 : 0},
      /*lhs_non_contracting_dims=*/{transpose_lhs ? 1 : 0},
      /*rhs_non_contracting_dims=*/{transpose_rhs ? 0 : 1},
      /*output_dim_order=*/{0, transpose_lhs ? 1 : 0, 1, transpose_rhs ? 0 : 1},
      /*lhs_batch_dims=*/{},
      /*rhs_batch_dims=*/{});
}

} // namespace mlir::tpu
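
Worked example, following the ternaries above: defaultDimensionNumbers(builder, /*transpose_lhs=*/false, /*transpose_rhs=*/true) prints as

  #tpu.dot_dimension_numbers<[1], [1], [0], [0], [0, 0, 1, 0], [], []>

i.e. both operands contract over their dim 1, there are no batch dims, and the output order is (lhs 0, rhs 0).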
@@ -104,6 +104,10 @@ MemRefType getMemRefType(Value value);

bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel = 8);

DotDimensionNumbersAttr defaultDimensionNumbers(Builder &builder,
                                                bool transpose_lhs,
                                                bool transpose_rhs);

#define GEN_PASS_REGISTRATION
#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"
@@ -14,6 +14,9 @@ limitations under the License.
==============================================================================*/

#include <cstdint>
#include <optional>
#include <string_view>
#include <vector>

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
@@ -507,6 +510,280 @@ class CanonicalizeAddOfMatmul : public OpRewritePattern<AddOp> {
  }
};

LogicalResult MatmulOp::verify() {
  // Note - this is not yet an exhaustive verification of matmul. Many of the
  // invariants are spread across infer, apply, llo and below. This is,
  // however, a good start and the recommended place to add more invariants.
  const VectorType lhs_ty = getLhs().getType();
  const VectorType rhs_ty = getRhs().getType();

  if (getTransposeLhs()) {
    emitOpError(
        "Lhs transpose not supported via this API - please use the "
        "dimension numbers API.");
    return failure();
  }

  if (getDimensionNumbers().has_value()) {
    auto dimension_numbers = getDimensionNumbers().value();
    auto lhs_contracting_dims = dimension_numbers.getLhsContractingDims();
    auto rhs_contracting_dims = dimension_numbers.getRhsContractingDims();
    if (lhs_contracting_dims.size() != 1) {
      emitOpError("Not implemented: lhs contracting dims must be of size 1");
      return failure();
    }
    if (rhs_contracting_dims.size() != 1) {
      emitOpError("Not implemented: rhs contracting dims must be of size 1");
      return failure();
    }

    auto lhs_contracting_dim = lhs_contracting_dims[0];
    auto rhs_contracting_dim = rhs_contracting_dims[0];

    auto lhs_batch_dims = dimension_numbers.getLhsBatchDims();
    auto rhs_batch_dims = dimension_numbers.getRhsBatchDims();

    auto lhs_non_contracting_dims =
        dimension_numbers.getLhsNonContractingDims();
    auto rhs_non_contracting_dims =
        dimension_numbers.getRhsNonContractingDims();

    if (lhs_contracting_dims.size() + lhs_non_contracting_dims.size() +
            lhs_batch_dims.size() !=
        lhs_ty.getShape().size()) {
      emitOpError(
          "Not implemented: lhs contracting + non contracting + batch dims "
          "must be of the same size as the lhs shape");
      return failure();
    }
    if (rhs_contracting_dims.size() + rhs_non_contracting_dims.size() +
            rhs_batch_dims.size() !=
        rhs_ty.getShape().size()) {
      emitOpError(
          "Not implemented: rhs contracting + non contracting + batch dims "
          "must be of the same size as the rhs shape");
      return failure();
    }

    if (lhs_ty.getShape()[lhs_contracting_dim] !=
        rhs_ty.getShape()[rhs_contracting_dim]) {
      emitOpError(
          "Not implemented: lhs and rhs contracting dims must be of the same "
          "size");
      return failure();
    }

    if (lhs_batch_dims.size() != rhs_batch_dims.size()) {
      emitOpError(
          "Not implemented: lhs and rhs should have the same number of batch "
          "dims");
      return failure();
    }
    if (lhs_batch_dims.size() > 1) {
      emitOpError("Not implemented: Up to 1 batch dim supported");
      return failure();
    }

    int64_t lhs_rank = lhs_ty.getShape().size();
    int64_t rhs_rank = rhs_ty.getShape().size();

    std::vector<bool> seen_dims_lhs(lhs_rank, false);
    std::vector<bool> seen_dims_rhs(rhs_rank, false);

    auto check_and_mark_dims = [&](const std::vector<int64_t> &dims,
                                   std::vector<bool> &seen_dims,
                                   const std::string_view operand) {
      for (int64_t dim : dims) {
        if (seen_dims[dim]) {
          emitOpError("Illegal: Dim ")
              << dim << " repeats in dimension numbers of " << operand;
          return failure();
        }
        seen_dims[dim] = true;
      }
      return success();
    };

    if (failed(
            check_and_mark_dims(lhs_contracting_dims, seen_dims_lhs, "lhs")) ||
        failed(check_and_mark_dims(lhs_non_contracting_dims, seen_dims_lhs,
                                   "lhs")) ||
        failed(check_and_mark_dims(lhs_batch_dims, seen_dims_lhs, "lhs"))) {
      return failure();
    }

    if (failed(
            check_and_mark_dims(rhs_contracting_dims, seen_dims_rhs, "rhs")) ||
        failed(check_and_mark_dims(rhs_non_contracting_dims, seen_dims_rhs,
                                   "rhs")) ||
        failed(check_and_mark_dims(rhs_batch_dims, seen_dims_rhs, "rhs"))) {
      return failure();
    }

    for (int64_t dim = 0; dim < lhs_rank; ++dim) {
      if (!seen_dims_lhs[dim]) {
        emitOpError("Illegal: Dim ")
            << dim << " is not seen in lhs dimension numbers";
        return failure();
      }
    }
    for (int64_t dim = 0; dim < rhs_rank; ++dim) {
      if (!seen_dims_rhs[dim]) {
        emitOpError("Illegal: Dim ")
            << dim << " is not seen in rhs dimension numbers";
        return failure();
      }
    }

    const std::optional<int64_t> batch_dim_lhs =
        lhs_batch_dims.empty() ? std::nullopt
                               : std::optional<int64_t>(lhs_batch_dims[0]);
    const std::optional<int64_t> batch_dim_rhs =
        rhs_batch_dims.empty() ? std::nullopt
                               : std::optional<int64_t>(rhs_batch_dims[0]);
    if (batch_dim_lhs != batch_dim_rhs) {
      emitOpError("Not Implemented: batch dims must be equal");
      return failure();
    }
    if (batch_dim_lhs.has_value() && (batch_dim_lhs.value() != 0)) {
      emitOpError("Not Implemented: batch dims pos must be 0");
      return failure();
    }
    // The invariants above enforce at most one batch dim for now, and that
    // the lhs and rhs batch dims are equal.
    std::optional<int64_t> batch_size = std::nullopt;
    if (batch_dim_lhs.has_value()) {
      batch_size = lhs_ty.getShape()[batch_dim_lhs.value()];
      auto rhs_batch_size = rhs_ty.getShape()[batch_dim_rhs.value()];
      if (batch_size != rhs_batch_size) {
        emitOpError("Not Implemented: batch dims must be equal");
        return failure();
      }
      if (batch_size == 0) {
        emitOpError("Illegal: batch size must be > 0");
        return failure();
      }
    }
    auto output_dim_order = dimension_numbers.getOutputDimOrder();
    if (output_dim_order.size() % 2 != 0) {
      emitOpError(
          "Illegal: output dim order must have an even number of elements.");
      return failure();
    }
    if (batch_size.has_value()) {
      if (output_dim_order[0] != 0 || output_dim_order[1] != 0) {
        emitOpError(
            "Not implemented: Output with batch size must be the lhs 0 idx "
            "for now.");
        return failure();
      }
    }

    // The invariants above enforce a single batch idx for now, and that it is
    // in position 0. Future extensions to this will be to:
    // 1. Support multiple batch dims
    // 2. Support batch dims in any position in the output dim order
    if (lhs_non_contracting_dims.size() != 1) {
      emitOpError(
          "Not implemented: lhs non contracting dims must be of size 1");
      return failure();
    }
    if (rhs_non_contracting_dims.size() != 1) {
      emitOpError(
          "Not implemented: rhs non contracting dims must be of size 1");
      return failure();
    }

    // A bit long winded, but the invariants we enforce below are:
    // 1. The output order index is 0 (lhs) or 1 (rhs)
    // 2. The output dim order is in valid bounds
    // 3. The lhs and rhs non contracting dims appear in the output dim order
    // 4. The contracting dims never appear in the output dim order
    // 5. Each non contracting dim appears only once
    std::vector<bool> lhs_dims_seen_in_output(lhs_rank, false);
    std::vector<bool> rhs_dims_seen_in_output(rhs_rank, false);

    // Iterate over the output dimension order
    for (int dim_pos = 0; dim_pos < output_dim_order.size(); dim_pos += 2) {
      auto idx = output_dim_order[dim_pos];
      auto dim = output_dim_order[dim_pos + 1];

      if (idx != 0 && idx != 1) {
        emitOpError("Illegal: output dim order index must be 0 or 1");
        return failure();
      }
      auto is_lhs = (idx == 0);

      if (is_lhs) {
        if (dim < 0 || dim >= lhs_rank) {
          emitOpError("Illegal: lhs dimension index out of bounds");
          return failure();
        }
        if (lhs_dims_seen_in_output[dim]) {
          emitOpError("Illegal: lhs dimension ")
              << dim << " appears more than once in output dim order";
          return failure();
        }
        if (dim == lhs_contracting_dim) {
          emitOpError("Illegal: contracting dimension ")
              << dim << " appears in lhs output dim order";
          return failure();
        }
        // batch_dim_lhs is either 0 or nullopt
        if (dim == batch_dim_lhs) {
          // Upstream invariants enforce that the batch dim is in position 0
          // of the output dim order.
          rhs_dims_seen_in_output[dim] = true;
        }
        lhs_dims_seen_in_output[dim] = true;
      } else {
        if (dim < 0 || dim >= rhs_rank) {
          emitOpError("Illegal: rhs dimension index out of bounds");
          return failure();
        }
        if (rhs_dims_seen_in_output[dim]) {
          emitOpError("Illegal: rhs dimension ")
              << dim << " appears more than once in output dim order";
          return failure();
        }
        if (dim == rhs_contracting_dim) {
          emitOpError("Illegal: contracting dimension ")
              << dim << " appears in rhs output dim order";
          return failure();
        }
        if (dim == batch_dim_rhs) {
          // Upstream invariants enforce that the batch dim is in position 0
          // of the output dim order.
          lhs_dims_seen_in_output[dim] = true;
        }
        rhs_dims_seen_in_output[dim] = true;
      }
    }

    // Check that all dims have been seen (except contracting dims)
    for (int i = 0; i < lhs_rank; ++i) {
      if (i == lhs_contracting_dim) {
        continue;
      }
      if (!lhs_dims_seen_in_output[i]) {
        emitOpError("Illegal: lhs non-contracting dimension ")
            << i << " is not seen in output dim order";
        return failure();
      }
    }

    for (int i = 0; i < rhs_rank; ++i) {
      if (i == rhs_contracting_dim) {
        continue;
      }
      if (!rhs_dims_seen_in_output[i]) {
        emitOpError("Illegal: rhs non-contracting dimension ")
            << i << " is not seen in output dim order";
        return failure();
      }
    }
  }
  return success();
}

void MatmulOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  patterns.add<CanonicalizeAddOfMatmul<arith::AddFOp>,
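
Conversely, an encoding the verifier rejects (hypothetical example, rank-2 operands): #tpu.dot_dimension_numbers<[1], [0], [1], [1], [0, 0, 1, 1], [], []> lists lhs dim 1 as both contracting and non-contracting, so check_and_mark_dims fails with "Illegal: Dim 1 repeats in dimension numbers of lhs".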
@@ -1696,15 +1696,36 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
      llvm::all_of(layouts_in, [&](const Layout &l) { return l.has_value(); }));
  TPU_ASSERT_OP(layouts_out.front().has_value());
  auto matmul_op = cast<tpu::MatmulOp>(op);
  const auto transpose_lhs = matmul_op.getTransposeLhs();
  const auto transpose_rhs = matmul_op.getTransposeRhs();
  const auto &layout_lhs = *layouts_in[0];
  const auto &layout_rhs = *layouts_in[1];
  const auto &layout_acc = *layouts_in[2];
  const auto &layout_out = *layouts_out[0];
  if (transpose_lhs) {
    return op.emitOpError("Not implemented: Transposed LHS");
  if (matmul_op.getTransposeRhs()) {
    return op.emitOpError(
        "Transposition must have been erased into dimension numbers during "
        "canonicalization");
  }

  auto dimension_numbers = matmul_op.getDimensionNumbers();
  if (!dimension_numbers.has_value()) {
    return op.emitOpError(
        "Dimension numbers must be provided, ensure canonicalization has been "
        "run.");
  }
  auto transposed_mkn = isTransposedMatmul(dimension_numbers.value());
  if (!transposed_mkn.has_value()) {
    return op.emitOpError(
        "Dimension numbers must be MKN, ensure canonicalization has been "
        "run.");
  }
  auto [transpose_lhs, transpose_rhs] = transposed_mkn.value();
  if (transpose_lhs) {
    return op.emitOpError(
        "Transposition of LHS is not supported in apply_vector_layout, ensure "
        "canonicalization has been run.");
  }

  auto &layout_lhs = *layouts_in[0];
  auto &layout_rhs = *layouts_in[1];
  auto &layout_acc = *layouts_in[2];
  auto &layout_out = *layouts_out[0];

  const std::array<std::reference_wrapper<const VectorLayout>, 4> all_layouts =
      {layout_lhs, layout_rhs, layout_acc, layout_out};
  for (const VectorLayout &layout : all_layouts) {
@@ -1965,6 +1986,8 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,

  const tpu::ContractPrecisionAttr precision_attr =  // May be null
      op.getAttrOfType<tpu::ContractPrecisionAttr>("precision");
  const tpu::DotDimensionNumbersAttr dot_dimension_numbers_attr =
      defaultDimensionNumbers(builder, false, transpose_rhs);
  for (int64_t j = 0; j < nj; ++j) {
    for (int64_t k = 0; k < nk; ++k) {
      // TODO(tlongeri): there should be a way to slice without copying
@@ -1981,7 +2004,8 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
      acc_col->setAttr("out_layout", acc_layout_attr);
      auto new_acc_col = builder.create<tpu::MatmulOp>(
          op.getLoc(), acc_col_ty, lhs_cols[k], rhs_rolled_group, acc_col,
          transpose_lhs, transpose_rhs, precision_attr);
          /*transpose_lhs=*/false, /*transpose_rhs=*/false, precision_attr,
          dot_dimension_numbers_attr);
      auto new_acc_vregs = builder.create<tpu::UnrollVectorsOp>(
          op.getLoc(),
          TypeRange(ValueRange(XlaArrayToFlatArrayRef(acc_col_vregs))),
@@ -1,6 +1,10 @@
#include <algorithm>
#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <optional>
#include <utility>
#include <vector>

#include "llvm/ADT/STLExtras.h"
@@ -16,6 +20,8 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "absl/log/check.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h"
@@ -23,6 +29,7 @@
#include "mlir/include/mlir/IR/Attributes.h"
#include "mlir/include/mlir/IR/Block.h"
#include "mlir/include/mlir/IR/Builders.h"
#include "mlir/include/mlir/IR/BuiltinAttributes.h"
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/include/mlir/IR/OpDefinition.h"
#include "mlir/include/mlir/IR/Operation.h"
@@ -40,6 +47,9 @@ namespace mlir::tpu {
LogicalResult tpu_matmul_rule(tpu::MatmulOp op) {
  ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation());

  auto transpose_lhs = op.getTransposeLhs();
  auto transpose_rhs = op.getTransposeRhs();

  auto lhs = op.getLhs();
  auto rhs = op.getRhs();
  auto acc = op.getAcc();
@@ -52,6 +62,51 @@ LogicalResult tpu_matmul_rule(tpu::MatmulOp op) {
  auto rhs_element_type = rhs_ty.getElementType();
  auto acc_element_type = acc_ty.getElementType();

  // There are a few primary paths for dimension_numbers in matmul:
  // 1) No dimension numbers provided -> set to default
  // 2) Defined and not default -> verify and apply
  // 3) Defined and matching defaultDimensionNumbers -> no-op for
  //    canonicalization of dims
  std::optional<int64_t> batch_size = std::nullopt;

  // MKN matmul - no dims or transpositions set
  if (!op.getDimensionNumbers().has_value()) {
    // Legacy API - convert it to dimension numbers
    op.setDimensionNumbersAttr(
        defaultDimensionNumbers(builder, transpose_lhs, transpose_rhs));
  } else if (
      // Dot dim API - dimensions are provided and are not default
      (op.getDimensionNumbers().value() !=
       defaultDimensionNumbers(builder, false, false))) {
    auto dimension_numbers = op.getDimensionNumbers();
    auto lhs_contracting_dims = dimension_numbers->getLhsContractingDims();
    auto rhs_contracting_dims = dimension_numbers->getRhsContractingDims();

    auto lhs_batch_dims = dimension_numbers->getLhsBatchDims();
    auto rhs_batch_dims = dimension_numbers->getRhsBatchDims();

    // Invariant in matmul verifier: <= 1 batch dim for now, and the lhs and
    // rhs batch dims are the same.
    // Invariant in matmul verifier: Exactly one contracting and one non
    // contracting dim in each of lhs and rhs for now.
    batch_size =
        lhs_batch_dims.empty()
            ? std::nullopt
            : std::optional<int64_t>(lhs_ty.getShape()[lhs_batch_dims[0]]);
    // Lower each dim in contracting dims by size(batch_dims)
    auto batch_adjusted_lhs_contracting_dim =
        lhs_contracting_dims[0] - lhs_batch_dims.size();
    auto batch_adjusted_rhs_contracting_dim =
        rhs_contracting_dims[0] - rhs_batch_dims.size();

    if (batch_adjusted_lhs_contracting_dim != 1) {
      transpose_lhs = true;
    }
    if (batch_adjusted_rhs_contracting_dim != 0) {
      transpose_rhs = true;
    }
  }

  auto extsi_sitofp = [&builder, &op](TypedValue<VectorType> element) {
    const VectorType ty = element.getType();
    auto shape = ty.getShape();
@@ -88,10 +143,12 @@ LogicalResult tpu_matmul_rule(tpu::MatmulOp op) {
    if (lhs_element_type.isInteger()) {
      auto float_lhs = extsi_sitofp(lhs);
      op->setOperand(0, float_lhs);
      lhs = cast<TypedValue<VectorType>>(float_lhs.getResult());
    }
    if (rhs_element_type.isInteger()) {
      auto float_rhs = extsi_sitofp(rhs);
      op->setOperand(1, float_rhs);
      rhs = cast<TypedValue<VectorType>>(float_rhs.getResult());
    }
  }
  // TODO(mvoz): Add more invariants.
@@ -114,6 +171,91 @@ LogicalResult tpu_matmul_rule(tpu::MatmulOp op) {
      return failure();
    }
  }

  auto dot_dim_matmul = [&](auto lhs, auto rhs, auto acc) {
    auto precision_attr = op.getPrecisionAttr();

    // If we are transposing the lhs, we need to transpose the lhs before
    // matmul here, as we don't have lhs fusion implemented in apply.
    if (transpose_lhs) {
      auto lhs_ty = cast<VectorType>(lhs.getType());
      auto rank = lhs_ty.getShape().size();

      // This transposition must run on vectors with rank >= 2
      CHECK_GE(rank, 2);

      std::vector<int64_t> perm(rank);
      std::iota(perm.begin(), perm.end(), 0);
      std::swap(perm[rank - 2], perm[rank - 1]);

      std::vector<int64_t> shape(lhs_ty.getShape());
      std::swap(shape[rank - 2], shape[rank - 1]);

      auto lhs_ty_transposed = VectorType::get(shape, lhs_ty.getElementType());

      const SmallVector<int64_t> perm_vec =
          SmallVector<int64_t>(perm.begin(), perm.end());
      lhs = builder.create<vector::TransposeOp>(
          lhs_ty_transposed, lhs,
          DenseI64ArrayAttr::get(builder.getContext(), perm_vec));
    }
    auto ddn = defaultDimensionNumbers(builder, /*transpose_lhs=*/false,
                                       transpose_rhs);
    // Transpose flags are always false here, because ddn takes precedence
    // after this pass.
    auto matmul_res = builder.create<tpu::MatmulOp>(
        op.getLoc(), acc.getType(), lhs, rhs, acc,
        /*transpose_lhs=*/false,
        /*transpose_rhs=*/false, precision_attr, ddn);
    return matmul_res;
  };

  // If we have a batch_size, we want to slice rhs and lhs [:batch_size],
  // and then do O[i] = A[i] @ B[i]
  // Produce an output shape of [batch_size, m, n]
  if (batch_size.has_value()) {
    std::vector<Value> outputs;

    for (int64_t i = 0; i < batch_size; ++i) {
      auto sliced_lhs = builder.create<vector::ExtractOp>(
          op.getLoc(), lhs, ArrayRef<int64_t>{i});
      auto sliced_rhs = builder.create<vector::ExtractOp>(
          op.getLoc(), rhs, ArrayRef<int64_t>{i});

      auto sliced_acc = builder.create<vector::ExtractOp>(
          op.getLoc(), acc, ArrayRef<int64_t>{i});

      auto matmul_res =
          dot_dim_matmul(sliced_lhs.getResult(), sliced_rhs.getResult(),
                         sliced_acc.getResult());
      auto res_ty = matmul_res.getType().cast<VectorType>();
      auto res_shape = res_ty.getShape();
      // Reshape to 1x[prior_shape]
      auto reshape_shape = llvm::to_vector(res_shape);
      reshape_shape.insert(reshape_shape.begin(), 1);
      auto shape_cast = builder.create<vector::ShapeCastOp>(
          op.getLoc(), VectorType::get(reshape_shape, res_ty.getElementType()),
          matmul_res);
      outputs.push_back(shape_cast);
    }
    // Technically almost identical to the case where batch_size is 1, but
    // we want to avoid the spurious concat here.
    if (batch_size == 1) {
      op.replaceAllUsesWith(outputs[0]);
      op.erase();
      return success();
    }
    auto output = builder
                      .create<tpu::ConcatenateOp>(op.getLoc(), acc_ty, outputs,
                                                  /*dimension=*/0)
                      .getResult();
    op.replaceAllUsesWith(output);
    op.erase();
  } else {
    auto matmul_res = dot_dim_matmul(lhs, rhs, acc).getResult();
    op.replaceAllUsesWith(matmul_res);
    op.erase();
  }
  return success();
};
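
To make the batched lowering concrete (shapes invented for illustration): with lhs vector<4x128x256xf32>, rhs vector<4x256x512xf32> and acc vector<4x128x512xf32>, the loop above emits, for i = 0 (and similarly for i = 1..3), roughly (MLIR abbreviated):

  %l0 = vector.extract %lhs[0] ...      // vector<128x256xf32>
  %r0 = vector.extract %rhs[0] ...      // vector<256x512xf32>
  %a0 = vector.extract %acc[0] ...      // vector<128x512xf32>
  %m0 = tpu.matmul %l0, %r0, %a0 {...}  // vector<128x512xf32>
  %s0 = vector.shape_cast %m0 ...       // vector<1x128x512xf32>

and, for batch_size > 1, the slices %s0..%s3 are joined with tpu.concatenate along dimension 0 to rebuild vector<4x128x512xf32>; the batch_size == 1 case skips the concat.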
@@ -309,9 +451,14 @@ LogicalResult canonicalize_contraction(int hardware_generation, Operation &op) {
  }
  const tpu::ContractPrecisionAttr precision_attr =  // May be null
      contraction_op->getAttrOfType<tpu::ContractPrecisionAttr>("precision");

  const auto dot_dimension_numbers_attr =
      defaultDimensionNumbers(builder, false, transpose_rhs);

  auto matmul_op = builder.create<tpu::MatmulOp>(
      contraction_op->getLoc(), acc_ty, lhs, rhs, acc,
      /*transpose_lhs=*/false, transpose_rhs, precision_attr);
      /*transpose_lhs=*/false,
      /*transpose_rhs=*/false, precision_attr, dot_dimension_numbers_attr);
  contraction_op.replaceAllUsesWith(matmul_op.getResult());
  contraction_op.erase();
  auto result = tpu_matmul_rule(matmul_op);
@@ -17,6 +17,8 @@ limitations under the License.

#include <array>
#include <cstdint>
#include <optional>
#include <utility>

#include "llvm/Support/MathExtras.h"
#include "absl/types/span.h"
@@ -42,6 +44,31 @@ SmallVector<int64_t> ComputeTileStrides(MemRefType memref_ty,
  return tile_strides;
}

std::optional<std::pair<bool, bool>> isTransposedMatmul(
    DotDimensionNumbersAttr dim_numbers) {
  auto lhs_contracting_dims = dim_numbers.getLhsContractingDims();
  auto rhs_contracting_dims = dim_numbers.getRhsContractingDims();
  auto lhs_non_contracting_dims = dim_numbers.getLhsNonContractingDims();
  auto rhs_non_contracting_dims = dim_numbers.getRhsNonContractingDims();

  if (lhs_contracting_dims.size() != 1 || rhs_contracting_dims.size() != 1 ||
      lhs_non_contracting_dims.size() != 1 ||
      rhs_non_contracting_dims.size() != 1) {
    return std::nullopt;
  }

  int64_t lhs_non_contracting_dim = lhs_non_contracting_dims[0];
  int64_t lhs_contracting_dim = lhs_contracting_dims[0];
  int64_t rhs_non_contracting_dim = rhs_non_contracting_dims[0];
  int64_t rhs_contracting_dim = rhs_contracting_dims[0];

  bool lhs_transposed = lhs_non_contracting_dim > lhs_contracting_dim;
  bool rhs_transposed = rhs_contracting_dim > rhs_non_contracting_dim;

  return std::pair<bool, bool>{lhs_transposed, rhs_transposed};
}

bool canReinterpretToUntiledMemref(MemRefType tiled_memref_ty,
                                   const std::array<int64_t, 2>& target_shape,
                                   bool allow_minormost_padding) {
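
A minimal usage sketch of the helper above together with defaultDimensionNumbers (surrounding context such as builder is assumed, not part of this change):

  // Canonical MKN dimension numbers with a transposed RHS...
  auto ddn = defaultDimensionNumbers(builder, /*transpose_lhs=*/false,
                                     /*transpose_rhs=*/true);
  // ...decode back to the transposition flags the MKN lowering expects.
  if (auto transposed_mkn = isTransposedMatmul(ddn)) {
    auto [transpose_lhs, transpose_rhs] = *transposed_mkn;
    // Here transpose_lhs == false and transpose_rhs == true.
  }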
@@ -16,6 +16,7 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "absl/types/span.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "tsl/platform/statusor.h"

// TODO: Instead of CHECK_EQs, can we do something like TF_RET_CHECK but with
@@ -98,6 +99,14 @@ std::string shapeToString(const T &shape) {

SmallVector<int64_t> ComputeTileStrides(MemRefType memref_ty,
                                        absl::Span<const int64_t> tiling);
// Assuming an MKN matmul - this function must only be called after the
// canonicalization passes.
//
// Given a set of dimension numbers, returns a pair of booleans, where the
// first is true if the lhs is transposed and the second is true if the rhs
// is transposed.
std::optional<std::pair<bool, bool>> isTransposedMatmul(
    DotDimensionNumbersAttr dim_numbers);

// Returns true if a >=2D memref has a tiled layout and can be equivalently
// considered as an untiled memref, except for potential padding in the
@@ -106,6 +115,7 @@ SmallVector<int64_t> ComputeTileStrides(MemRefType memref_ty,
bool canReinterpretToUntiledMemref(MemRefType tiled_memref_ty,
                                   const std::array<int64_t, 2> &target_shape,
                                   bool allow_minormost_padding = false);

} // namespace mlir::tpu

#endif  // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_