1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-20 05:46:06 +00:00

[Mosaic TPU][NFC] Throw NYI error instead of crash when squeeze ref to 1d.

PiperOrigin-RevId: 736263705
This commit is contained in:
Jevin Jiang 2025-03-12 14:17:45 -07:00 committed by jax authors
parent 47480b4493
commit 12c0987e2f
2 changed files with 7 additions and 1 deletions
jaxlib/mosaic/dialect/tpu

@ -33,6 +33,7 @@ limitations under the License.
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "absl/hash/hash.h"
#include "absl/log/log.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.cc.inc"
@ -182,7 +183,8 @@ AffineMap TiledLayoutAttr::getAffineMap() const {
auto dimensions = tile.dimensions();
int64_t untiled_dims = map.getNumResults() - dimensions.size();
if (untiled_dims < 0) {
LOG(FATAL) << "Invalid TiledLayoutAttr!";
LOG(FATAL) << "Invalid TiledLayoutAttr: Number of dims must be larger "
"or equal to the rank of the tile";
}
for (int64_t i = 0; i < untiled_dims; ++i) {
exprs.push_back(getAffineDimExpr(i, getContext()));

@ -177,6 +177,10 @@ LogicalResult MemRefSqueezeOp::verify() {
this->emitOpError("Element types don't match.");
return failure();
}
if (!HasMemorySpace(source_type, tpu::MemorySpace::kSemaphoreMem) &&
source_type.getRank() > 1 && target_type.getRank() == 1) {
return emitError("Not implemented: squeeze memref to 1d.");
}
auto source_shape = source_type.getShape();
auto target_shape = target_type.getShape();
int source_index = source_shape.size() - 1;