mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Mosaic] Add extension files for infer/apply vector layout.
PiperOrigin-RevId: 691868278
This commit is contained in:
parent
7ff5a4eac2
commit
8296f6e0ba
@ -43,6 +43,7 @@ mosaic_gpu_internal_users = []
|
||||
mosaic_internal_users = []
|
||||
pallas_gpu_internal_users = []
|
||||
pallas_tpu_internal_users = []
|
||||
pallas_extension_deps = []
|
||||
|
||||
jax_internal_export_back_compat_test_util_visibility = []
|
||||
jax_internal_test_harnesses_visibility = []
|
||||
|
@ -1,3 +1,6 @@
|
||||
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
|
||||
load("@rules_python//python:defs.bzl", "py_library")
|
||||
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -11,9 +14,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
|
||||
load("@rules_python//python:defs.bzl", "py_library")
|
||||
load("//jaxlib:jax.bzl", "pallas_extension_deps")
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
@ -41,6 +42,7 @@ cc_library(
|
||||
"dialect/tpu/tpu_dialect.cc",
|
||||
"dialect/tpu/tpu_ops.cc",
|
||||
"dialect/tpu/util.cc",
|
||||
":extension_srcs",
|
||||
] + glob([
|
||||
"dialect/tpu/transforms/*.cc",
|
||||
]),
|
||||
@ -83,7 +85,7 @@ cc_library(
|
||||
"@xla//xla:array",
|
||||
"@xla//xla:shape_util",
|
||||
"@xla//xla:util",
|
||||
],
|
||||
] + pallas_extension_deps,
|
||||
)
|
||||
|
||||
gentbl_cc_library(
|
||||
@ -226,3 +228,11 @@ cc_library(
|
||||
],
|
||||
alwayslink = True,
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "extension_srcs",
|
||||
srcs = [
|
||||
"dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc",
|
||||
"dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc",
|
||||
],
|
||||
)
|
||||
|
@ -13,16 +13,15 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVectorExtras.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
#include "llvm/ADT/iterator_range.h"
|
||||
#include "llvm/Support/Compiler.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/MathExtras.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Traits.h"
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
@ -52,7 +51,6 @@
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
@ -61,6 +59,7 @@
|
||||
#include "mlir/include/mlir/IR/OperationSupport.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/layout.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/util.h"
|
||||
#include "xla/array.h"
|
||||
@ -4586,6 +4585,7 @@ LogicalResult prng_random_bits_rule(RewriteContext &ctx, Operation &op,
|
||||
}
|
||||
|
||||
const llvm::StringMap<rule_type> &rules() {
|
||||
static const llvm::StringMap<rule_type> *rules = [] {
|
||||
static auto rules = new llvm::StringMap<rule_type>{
|
||||
{arith::ConstantOp::getOperationName(), arith_constant_rule},
|
||||
{arith::ExtFOp::getOperationName(), arith_extf_rule},
|
||||
@ -4625,6 +4625,13 @@ const llvm::StringMap<rule_type> &rules() {
|
||||
{vector::ShapeCastOp::getOperationName(), vector_shape_cast_rule},
|
||||
{vector::StoreOp::getOperationName(), vector_store_rule},
|
||||
{vector::TransposeOp::getOperationName(), vector_transpose_rule}};
|
||||
|
||||
llvm::StringMap<rule_type> extended_rules = mlir::tpu::extensions::rules();
|
||||
for (auto &entry : extended_rules) {
|
||||
rules->insert(&entry);
|
||||
}
|
||||
return rules;
|
||||
}();
|
||||
return *rules;
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,21 @@
|
||||
#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANFORMS_APPLY_VECTOR_LAYOUT_EXTENSIONS_H_
|
||||
#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANFORMS_APPLY_VECTOR_LAYOUT_EXTENSIONS_H_
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "llvm/include/llvm/ADT/StringMap.h"
|
||||
#include "mlir/include/mlir/IR/Operation.h"
|
||||
#include "mlir/include/mlir/Support/LLVM.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/layout.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
|
||||
|
||||
namespace mlir::tpu::extensions {
|
||||
|
||||
const llvm::StringMap<
|
||||
std::function<LogicalResult(ApplyVectorLayoutContext &, Operation &,
|
||||
ArrayRef<Layout>, ArrayRef<Layout>)>> &
|
||||
rules();
|
||||
|
||||
} // namespace mlir::tpu::extensions
|
||||
|
||||
#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANFORMS_APPLY_VECTOR_LAYOUT_EXTENSIONS_H_
|
@ -0,0 +1,19 @@
|
||||
#include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h"
|
||||
|
||||
#include "llvm/include/llvm/ADT/StringMap.h"
|
||||
#include "mlir/include/mlir/IR/Operation.h"
|
||||
|
||||
namespace mlir::tpu::extensions {
|
||||
|
||||
using RewriteContext = ApplyVectorLayoutContext;
|
||||
|
||||
using rule_type = std::function<LogicalResult(
|
||||
RewriteContext &, Operation &, ArrayRef<Layout>, ArrayRef<Layout>)>;
|
||||
|
||||
const llvm::StringMap<rule_type> &rules() {
|
||||
static const llvm::StringMap<rule_type> *rules =
|
||||
new llvm::StringMap<rule_type>{};
|
||||
return *rules;
|
||||
}
|
||||
|
||||
} // namespace mlir::tpu::extensions
|
@ -0,0 +1,13 @@
|
||||
#include "jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h"
|
||||
|
||||
#include "mlir/include/mlir/IR/Operation.h"
|
||||
#include "mlir/include/mlir/Support/LLVM.h"
|
||||
#include "mlir/include/mlir/Support/LogicalResult.h"
|
||||
|
||||
namespace mlir::tpu::extensions {
|
||||
|
||||
bool canInferVectorLayout(const Operation &op) { return false; }
|
||||
|
||||
LogicalResult inferVectorLayout(const Operation &op) { return failure(); }
|
||||
|
||||
} // namespace mlir::tpu::extensions
|
@ -50,6 +50,7 @@ limitations under the License.
|
||||
#include "mlir/include/mlir/Pass/Pass.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/layout.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/util.h"
|
||||
#include "xla/layout.h"
|
||||
|
||||
@ -337,6 +338,10 @@ class VectorLayoutInferer {
|
||||
if (inferElementwise(&any_op).failed()) {
|
||||
return failure();
|
||||
}
|
||||
} else if (mlir::tpu::extensions::canInferVectorLayout(any_op)) {
|
||||
if (mlir::tpu::extensions::inferVectorLayout(any_op).failed()) {
|
||||
return failure();
|
||||
}
|
||||
} else {
|
||||
any_op.emitOpError("unsupported in vector layout inference");
|
||||
return failure();
|
||||
|
@ -0,0 +1,15 @@
|
||||
#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_VECTOR_LAYOUT_EXTENSIONS_H_
|
||||
#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_VECTOR_LAYOUT_EXTENSIONS_H_
|
||||
|
||||
#include "mlir/include/mlir/IR/Operation.h"
|
||||
#include "mlir/include/mlir/Support/LLVM.h"
|
||||
|
||||
namespace mlir::tpu::extensions {
|
||||
|
||||
bool canInferVectorLayout(const Operation &op);
|
||||
|
||||
LogicalResult inferVectorLayout(const Operation &op);
|
||||
|
||||
} // namespace mlir::tpu::extensions
|
||||
|
||||
#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_VECTOR_LAYOUT_EXTENSIONS_H_
|
Loading…
x
Reference in New Issue
Block a user