[Mosaic] Fix infer/apply extensions.

1. For apply, llvm::StringMap()::insert(MapEntryTy*) will cause dangling reference if not constructing mlir::tpu::extensions::rules() with const-reference. However, if we do construct it with const-reference, the signature is not const-qualified and fails to compile. Hence, change it to llvm::StringMap()::insert(std::pair<...>) and get extension rules by const-reference.
2. Pass default tiling to infer rule, we need it to infer single op. See infer of tpu::MatmulOp.

PiperOrigin-RevId: 716274818
This commit is contained in:
Tzu-Wei Sung 2025-01-16 09:56:35 -08:00 committed by jax authors
parent 0df4475aeb
commit 5c020ee317
4 changed files with 17 additions and 7 deletions

View File

@ -4683,9 +4683,8 @@ const llvm::StringMap<rule_type> &rules() {
{vector::StoreOp::getOperationName(), vector_store_rule},
{vector::TransposeOp::getOperationName(), vector_transpose_rule}};
llvm::StringMap<rule_type> extended_rules = mlir::tpu::extensions::rules();
for (auto &entry : extended_rules) {
rules->insert(&entry);
for (const auto &[name, rule] : mlir::tpu::extensions::rules()) {
rules->insert({name, rule});
}
return rules;
}();

View File

@ -1,5 +1,8 @@
#include "jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h"
#include <array>
#include <cstdint>
#include "mlir/include/mlir/IR/Operation.h"
#include "mlir/include/mlir/Support/LLVM.h"
#include "mlir/include/mlir/Support/LogicalResult.h"
@ -8,6 +11,9 @@ namespace mlir::tpu::extensions {
bool canInferVectorLayout(const Operation &op) { return false; }
LogicalResult inferVectorLayout(const Operation &op) { return failure(); }
LogicalResult inferVectorLayout(const Operation &op,
std::array<int64_t, 2> target_shape) {
return failure();
}
} // namespace mlir::tpu::extensions
} // namespace mlir::tpu::extensions

View File

@ -328,7 +328,8 @@ class VectorLayoutInferer {
return failure();
}
} else if (mlir::tpu::extensions::canInferVectorLayout(any_op)) {
if (mlir::tpu::extensions::inferVectorLayout(any_op).failed()) {
if (mlir::tpu::extensions::inferVectorLayout(any_op, target_shape_)
.failed()) {
return failure();
}
} else {

View File

@ -1,6 +1,9 @@
#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_VECTOR_LAYOUT_EXTENSIONS_H_
#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_VECTOR_LAYOUT_EXTENSIONS_H_
#include <array>
#include <cstdint>
#include "mlir/include/mlir/IR/Operation.h"
#include "mlir/include/mlir/Support/LLVM.h"
@ -8,7 +11,8 @@ namespace mlir::tpu::extensions {
bool canInferVectorLayout(const Operation &op);
LogicalResult inferVectorLayout(const Operation &op);
LogicalResult inferVectorLayout(const Operation &op,
std::array<int64_t, 2> target_shape);
} // namespace mlir::tpu::extensions