mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
[Mosaic] Fix infer/apply extensions.
1. For apply, llvm::StringMap()::insert(MapEntryTy*) will cause dangling reference if not constructing mlir::tpu::extensions::rules() with const-reference. However, if we do construct it with const-reference, the signature is not const-qualified and fails to compile. Hence, change it to llvm::StringMap()::insert(std::pair<...>) and get extension rules by const-reference. 2. Pass default tiling to infer rule, we need it to infer single op. See infer of tpu::MatmulOp. PiperOrigin-RevId: 716274818
This commit is contained in:
parent
0df4475aeb
commit
5c020ee317
@ -4683,9 +4683,8 @@ const llvm::StringMap<rule_type> &rules() {
|
||||
{vector::StoreOp::getOperationName(), vector_store_rule},
|
||||
{vector::TransposeOp::getOperationName(), vector_transpose_rule}};
|
||||
|
||||
llvm::StringMap<rule_type> extended_rules = mlir::tpu::extensions::rules();
|
||||
for (auto &entry : extended_rules) {
|
||||
rules->insert(&entry);
|
||||
for (const auto &[name, rule] : mlir::tpu::extensions::rules()) {
|
||||
rules->insert({name, rule});
|
||||
}
|
||||
return rules;
|
||||
}();
|
||||
|
@ -1,5 +1,8 @@
|
||||
#include "jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h"
|
||||
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
|
||||
#include "mlir/include/mlir/IR/Operation.h"
|
||||
#include "mlir/include/mlir/Support/LLVM.h"
|
||||
#include "mlir/include/mlir/Support/LogicalResult.h"
|
||||
@ -8,6 +11,9 @@ namespace mlir::tpu::extensions {
|
||||
|
||||
bool canInferVectorLayout(const Operation &op) { return false; }
|
||||
|
||||
LogicalResult inferVectorLayout(const Operation &op) { return failure(); }
|
||||
LogicalResult inferVectorLayout(const Operation &op,
|
||||
std::array<int64_t, 2> target_shape) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
} // namespace mlir::tpu::extensions
|
||||
} // namespace mlir::tpu::extensions
|
||||
|
@ -328,7 +328,8 @@ class VectorLayoutInferer {
|
||||
return failure();
|
||||
}
|
||||
} else if (mlir::tpu::extensions::canInferVectorLayout(any_op)) {
|
||||
if (mlir::tpu::extensions::inferVectorLayout(any_op).failed()) {
|
||||
if (mlir::tpu::extensions::inferVectorLayout(any_op, target_shape_)
|
||||
.failed()) {
|
||||
return failure();
|
||||
}
|
||||
} else {
|
||||
|
@ -1,6 +1,9 @@
|
||||
#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_VECTOR_LAYOUT_EXTENSIONS_H_
|
||||
#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_VECTOR_LAYOUT_EXTENSIONS_H_
|
||||
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
|
||||
#include "mlir/include/mlir/IR/Operation.h"
|
||||
#include "mlir/include/mlir/Support/LLVM.h"
|
||||
|
||||
@ -8,7 +11,8 @@ namespace mlir::tpu::extensions {
|
||||
|
||||
bool canInferVectorLayout(const Operation &op);
|
||||
|
||||
LogicalResult inferVectorLayout(const Operation &op);
|
||||
LogicalResult inferVectorLayout(const Operation &op,
|
||||
std::array<int64_t, 2> target_shape);
|
||||
|
||||
} // namespace mlir::tpu::extensions
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user