[Mosaic] Add extension files for infer/apply vector layout.

PiperOrigin-RevId: 691868278
This commit is contained in:
Praveen Batra 2024-10-31 11:07:52 -07:00 committed by jax authors
parent 7ff5a4eac2
commit 8296f6e0ba
8 changed files with 137 additions and 46 deletions

View File

@ -43,6 +43,7 @@ mosaic_gpu_internal_users = []
mosaic_internal_users = []
pallas_gpu_internal_users = []
pallas_tpu_internal_users = []
pallas_extension_deps = []
jax_internal_export_back_compat_test_util_visibility = []
jax_internal_test_harnesses_visibility = []

View File

@ -1,3 +1,6 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
load("@rules_python//python:defs.bzl", "py_library")
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -11,9 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
load("@rules_python//python:defs.bzl", "py_library")
load("//jaxlib:jax.bzl", "pallas_extension_deps")
licenses(["notice"])
@ -41,6 +42,7 @@ cc_library(
"dialect/tpu/tpu_dialect.cc",
"dialect/tpu/tpu_ops.cc",
"dialect/tpu/util.cc",
":extension_srcs",
] + glob([
"dialect/tpu/transforms/*.cc",
]),
@ -83,7 +85,7 @@ cc_library(
"@xla//xla:array",
"@xla//xla:shape_util",
"@xla//xla:util",
],
] + pallas_extension_deps,
)
gentbl_cc_library(
@ -226,3 +228,11 @@ cc_library(
],
alwayslink = True,
)
filegroup(
name = "extension_srcs",
srcs = [
"dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc",
"dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc",
],
)

View File

@ -13,16 +13,15 @@
#include <utility>
#include <vector>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@ -52,7 +51,6 @@
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/types/span.h"
#include "llvm/ADT/ArrayRef.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
@ -61,6 +59,7 @@
#include "mlir/include/mlir/IR/OperationSupport.h"
#include "jaxlib/mosaic/dialect/tpu/layout.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h"
#include "jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h"
#include "jaxlib/mosaic/dialect/tpu/util.h"
#include "xla/array.h"
@ -4586,6 +4585,7 @@ LogicalResult prng_random_bits_rule(RewriteContext &ctx, Operation &op,
}
const llvm::StringMap<rule_type> &rules() {
static const llvm::StringMap<rule_type> *rules = [] {
static auto rules = new llvm::StringMap<rule_type>{
{arith::ConstantOp::getOperationName(), arith_constant_rule},
{arith::ExtFOp::getOperationName(), arith_extf_rule},
@ -4625,6 +4625,13 @@ const llvm::StringMap<rule_type> &rules() {
{vector::ShapeCastOp::getOperationName(), vector_shape_cast_rule},
{vector::StoreOp::getOperationName(), vector_store_rule},
{vector::TransposeOp::getOperationName(), vector_transpose_rule}};
llvm::StringMap<rule_type> extended_rules = mlir::tpu::extensions::rules();
for (auto &entry : extended_rules) {
rules->insert(&entry);
}
return rules;
}();
return *rules;
}

View File

@ -0,0 +1,21 @@
#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANFORMS_APPLY_VECTOR_LAYOUT_EXTENSIONS_H_
#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANFORMS_APPLY_VECTOR_LAYOUT_EXTENSIONS_H_
#include <functional>
#include "llvm/include/llvm/ADT/StringMap.h"
#include "mlir/include/mlir/IR/Operation.h"
#include "mlir/include/mlir/Support/LLVM.h"
#include "jaxlib/mosaic/dialect/tpu/layout.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
namespace mlir::tpu::extensions {
const llvm::StringMap<
std::function<LogicalResult(ApplyVectorLayoutContext &, Operation &,
ArrayRef<Layout>, ArrayRef<Layout>)>> &
rules();
} // namespace mlir::tpu::extensions
#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANFORMS_APPLY_VECTOR_LAYOUT_EXTENSIONS_H_

View File

@ -0,0 +1,19 @@
#include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h"
#include "llvm/include/llvm/ADT/StringMap.h"
#include "mlir/include/mlir/IR/Operation.h"
namespace mlir::tpu::extensions {
using RewriteContext = ApplyVectorLayoutContext;
using rule_type = std::function<LogicalResult(
RewriteContext &, Operation &, ArrayRef<Layout>, ArrayRef<Layout>)>;
const llvm::StringMap<rule_type> &rules() {
static const llvm::StringMap<rule_type> *rules =
new llvm::StringMap<rule_type>{};
return *rules;
}
} // namespace mlir::tpu::extensions

View File

@ -0,0 +1,13 @@
#include "jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h"
#include "mlir/include/mlir/IR/Operation.h"
#include "mlir/include/mlir/Support/LLVM.h"
#include "mlir/include/mlir/Support/LogicalResult.h"
namespace mlir::tpu::extensions {
bool canInferVectorLayout(const Operation &op) { return false; }
LogicalResult inferVectorLayout(const Operation &op) { return failure(); }
} // namespace mlir::tpu::extensions

View File

@ -50,6 +50,7 @@ limitations under the License.
#include "mlir/include/mlir/Pass/Pass.h"
#include "jaxlib/mosaic/dialect/tpu/layout.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h"
#include "jaxlib/mosaic/dialect/tpu/util.h"
#include "xla/layout.h"
@ -337,6 +338,10 @@ class VectorLayoutInferer {
if (inferElementwise(&any_op).failed()) {
return failure();
}
} else if (mlir::tpu::extensions::canInferVectorLayout(any_op)) {
if (mlir::tpu::extensions::inferVectorLayout(any_op).failed()) {
return failure();
}
} else {
any_op.emitOpError("unsupported in vector layout inference");
return failure();

View File

@ -0,0 +1,15 @@
#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_VECTOR_LAYOUT_EXTENSIONS_H_
#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_VECTOR_LAYOUT_EXTENSIONS_H_
#include "mlir/include/mlir/IR/Operation.h"
#include "mlir/include/mlir/Support/LLVM.h"
namespace mlir::tpu::extensions {
bool canInferVectorLayout(const Operation &op);
LogicalResult inferVectorLayout(const Operation &op);
} // namespace mlir::tpu::extensions
#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_VECTOR_LAYOUT_EXTENSIONS_H_