Mirror of https://github.com/ROCm/jax.git
[TPU][Mosaic] Replace the tpu.repeat lowering with a canonicalization to concat, which handles far more cases.
PiperOrigin-RevId: 691192121
Commit: 5ad066eeaa (parent: 7c4cc9552c)
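
The canonicalization rewrites tpu.repeat into a tpu.concatenate of `times` copies of its operand along the repeat dimension, with a plain no-op when times == 1. A minimal JAX-level sketch of that equivalence, illustrative only and not code from this change:

    import jax.numpy as jnp
    import numpy as np

    def repeat_via_concat(x, times, axis):
      # Mirrors the new canonicalization: times == 1 is a true no-op,
      # otherwise the result is `times` copies of x concatenated along `axis`.
      if times == 1:
        return x
      return jnp.concatenate([x] * times, axis=axis)

    x = jnp.arange(2048, dtype=jnp.float32).reshape((8, 256))
    np.testing.assert_array_equal(repeat_via_concat(x, 2, axis=1),
                                  jnp.concatenate([x, x], axis=1))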
@@ -295,6 +295,8 @@ def TPU_IotaOp : TPU_Op<"iota", [Pure]> {
  let assemblyFormat = [{ attr-dict `:` type($output) }];
}

// TODO(mvoz): deprecated - use concat. Canonicalization will do so automatically.
// b/376295711
def TPU_RepeatOp : TPU_Op<"repeat", [Pure]> {
  let arguments = (ins
    AnyVector:$source,
@@ -170,25 +170,6 @@ FailureOr<TypedValue<MemRefType>> getInternalScratch(
      .getResult();
}

// Models Numpy's np.repeat, repeating each element `repeats` times along the
// specified axis. For example, if `src` is [1, 2], `axis` is 0 and `repeats` is
// 3, this will return [1, 1, 1, 2, 2, 2].
xla::Array<Value> repeat(const xla::Array<Value> &src, const int repeats,
                         const int64_t axis) {
  SmallVector<int64_t> dims(toArrayRef(src.dimensions()));
  dims[axis] *= repeats;
  xla::Array<Value> res(dims);
  src.Each([&](absl::Span<const int64_t> idx, const Value v) {
    SmallVector<int64_t> res_idx(toArrayRef(idx));
    res_idx[axis] *= repeats;
    for (int i = 0; i < repeats; ++i) {
      res(res_idx) = v;
      ++res_idx[axis];
    }
  });
  return res;
}

// Models Numpy's np.concatenate
xla::Array<Value> concatenate(const ArrayRef<xla::Array<Value>> arrays,
                              const int64_t axis) {
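
The helper removed above duplicates each vreg value `repeats` times along an axis, in the style of np.repeat. A rough NumPy transcription of its indexing logic (illustrative only; `src` stands in for the xla::Array<Value> of vregs):

    import numpy as np

    def repeat_vregs(src: np.ndarray, repeats: int, axis: int) -> np.ndarray:
      # The output is `repeats` times larger along `axis`; each source entry is
      # written to `repeats` adjacent output slots, so [1, 2] with repeats=3 on
      # axis 0 becomes [1, 1, 1, 2, 2, 2].
      dims = list(src.shape)
      dims[axis] *= repeats
      res = np.empty(dims, dtype=src.dtype)
      for idx in np.ndindex(src.shape):
        res_idx = list(idx)
        res_idx[axis] *= repeats
        for _ in range(repeats):
          res[tuple(res_idx)] = src[idx]
          res_idx[axis] += 1
      return res

    assert repeat_vregs(np.array([1, 2]), 3, axis=0).tolist() == [1, 1, 1, 2, 2, 2]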
@@ -2949,48 +2930,6 @@ LogicalResult tpu_region_rule(RewriteContext &ctx, Operation &op,
  return success();
}

LogicalResult tpu_repeat_rule(RewriteContext &ctx, Operation &op,
                              const ArrayRef<Layout> layouts_in,
                              const ArrayRef<Layout> layouts_out) {
  TPU_ASSERT_EQ_OP(layouts_in.size(), 1);
  TPU_ASSERT_EQ_OP(layouts_out.size(), 1);
  TPU_ASSERT_OP(layouts_in.front().has_value());
  TPU_ASSERT_OP(layouts_out.front().has_value());
  const VectorLayout &layout_in = *layouts_in.front();
  const VectorLayout &layout_out = *layouts_out.front();
  if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
    return op.emitOpError("Not implemented: Only 2D layouts supported");
  }
  if (layout_in != layout_out) {
    return op.emitOpError("Not implemented: Changing layout mid-repeat");
  }
  if (!layout_in.hasNaturalTopology(ctx.target_shape) ||
      layout_in.offsets() != LayoutOffsets{0, 0}) {
    return op.emitOpError("Not implemented: Non-trivial layouts unsupported");
  }
  OpBuilder builder(&op);
  tpu::RepeatOp repeat_op = cast<tpu::RepeatOp>(op);
  VectorType src_ty = repeat_op.getSource().getType();
  const uint32_t dim = repeat_op.getDimension();
  if (dim != src_ty.getRank() - 1) {
    return op.emitOpError(
        "Not implemented: Only repeats along the last dim supported");
  }
  if (src_ty.getShape().back() % ctx.target_shape.back() != 0) {
    return op.emitOpError("Not implemented: Only free repeats are suppported");
  }
  FAILUREOR_ASSIGN_OR_RETURN(
      const xla::Array<Value> &in_vregs,
      disassemble(builder, layout_in, repeat_op.getSource(), ctx.target_shape));
  xla::Array<Value> out_vregs = repeat(in_vregs, repeat_op.getTimes(), dim);
  repeat_op->replaceAllUsesWith(
      assemble(builder, repeat_op.getResult().getType(), layout_out, out_vregs,
               ctx.target_shape)
          .getOperation());
  repeat_op->erase();
  return success();
}

LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
                               const ArrayRef<Layout> layouts_in,
                               const ArrayRef<Layout> layouts_out) {
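
The removed tpu_repeat_rule only applied under narrow conditions: 2D layouts with natural topology and {0, 0} offsets, repeats along the last dimension, and "free" repeats whose last dimension is already a whole number of vregs wide. A hedged sketch of that last check (the 128-lane width is an assumption for illustration; the real rule reads it from ctx.target_shape):

    TARGET_LANES = 128  # assumption for illustration; see ctx.target_shape.back()

    def is_free_repeat(src_shape: tuple, dim: int) -> bool:
      # The old lowering bailed out unless the repeat was along the last dim
      # and that dim was a multiple of the lane count.
      return dim == len(src_shape) - 1 and src_shape[-1] % TARGET_LANES == 0

The concat-based canonicalization does not need these restrictions, which is what "handles far more cases" refers to.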
@@ -4648,7 +4587,6 @@ const llvm::StringMap<rule_type> &rules() {
      {tpu::StridedStoreOp::getOperationName(), tpu_strided_store_rule},
      {tpu::MatmulOp::getOperationName(), tpu_matmul_rule},
      {tpu::RegionOp::getOperationName(), tpu_region_rule},
      {tpu::RepeatOp::getOperationName(), tpu_repeat_rule},
      {tpu::BitcastOp::getOperationName(), tpu_bitcast_rule},
      {tpu::TraceOp::getOperationName(), tpu_trace_rule},
      {tpu::AssumeLayoutOp::getOperationName(), tpu_assume_layout_rule},
@@ -1,3 +1,4 @@
#include <cstdint>
#include <functional>
#include <memory>
#include <vector>
@@ -350,6 +351,29 @@ LogicalResult canonicalize_select(int hardware_generation, Operation &raw_op) {
  return success();
}

LogicalResult canonicalize_repeat(int hardware_generation, Operation &raw_op) {
  auto op = dyn_cast<tpu::RepeatOp>(raw_op);
  if (!isa<VectorType>(op.getType())) {
    return op.emitOpError("Only vector types supported");
  }
  auto operand = op.getSource();
  auto times = op.getTimes();
  if (times == 1) {
    // A true no op - kind of an odd edge case, but this does come up in
    // flash_attention_backward tests.
    op.replaceAllUsesWith(operand);
    op.erase();
    return success();
  }
  auto operands = std::vector<Value>(times, operand);
  ImplicitLocOpBuilder builder(op->getLoc(), op.getOperation());
  auto concat = builder.create<tpu::ConcatenateOp>(op.getLoc(), op.getType(),
                                                   operands, op.getDimension());
  op.replaceAllUsesWith(concat.getResult());
  op.erase();
  return success();
}

using canonicalize_rule_type =
    std::function<LogicalResult(int hardware_generation, Operation &op)>;
@@ -360,7 +384,8 @@ const llvm::StringMap<canonicalize_rule_type> &rules() {
      {vector::ContractionOp::getOperationName(), canonicalize_extract},
      {vector::MultiDimReductionOp::getOperationName(),
       canonicalize_multi_dim_reduction},
      {arith::SelectOp::getOperationName(), canonicalize_select}};
      {arith::SelectOp::getOperationName(), canonicalize_select},
      {tpu::RepeatOp::getOperationName(), canonicalize_repeat}};
  return *rules;
}
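
For orientation, the canonicalization pass looks up each op by name in this table and invokes the matching callback; this change adds the tpu.repeat entry. A rough Python analogue of that dispatch (names are illustrative stand-ins, not the C++ API):

    from typing import Callable, Dict

    def canonicalize_select(hardware_generation: int, op) -> bool:
      return True  # stand-in for the real arith.select rewrite

    def canonicalize_repeat(hardware_generation: int, op) -> bool:
      return True  # stand-in for the repeat -> concat rewrite above

    # Analogue of the llvm::StringMap<canonicalize_rule_type>: op name -> rule.
    RULES: Dict[str, Callable[[int, object], bool]] = {
        "arith.select": canonicalize_select,
        "tpu.repeat": canonicalize_repeat,
    }

    def canonicalize(op_name: str, hardware_generation: int, op) -> bool:
      rule = RULES.get(op_name)
      return rule(hardware_generation, op) if rule is not None else True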
@@ -288,10 +288,6 @@ class VectorLayoutInferer {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<tpu::RepeatOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<tpu::TraceOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
@@ -1020,12 +1016,6 @@ class VectorLayoutInferer {
    return success();
  }

  LogicalResult infer(tpu::RepeatOp op) {
    auto src_layout = getLayout(op.getSource());
    setLayout(op, src_layout, src_layout);
    return success();
  }

  LogicalResult infer(tpu::TraceOp op) {
    static LogicalResult (*match_yield)(Operation *) = [](Operation *op) {
      TPU_CHECK_OP(isa<tpu::YieldOp>(op), "expected yield terminator");
@@ -1587,7 +1587,6 @@ class PallasCallTest(PallasBaseTest):
    self.assertEqual(analysis_result['transcendentals'], 21)
    self.assertEqual(analysis_result['bytes accessed'], 12345)

  def test_cost_analysis_vmap(self):
    def kernel(x, y):
      y[:] = x[:]
@@ -1606,7 +1605,6 @@ class PallasCallTest(PallasBaseTest):
    self.assertEqual(analysis_result['transcendentals'], batch_size * 21)
    self.assertEqual(analysis_result['bytes accessed'], batch_size * 12345)

  def test_vmem_limit(self):
    shape = (128, 128)
@@ -1673,6 +1671,23 @@ class PallasCallTest(PallasBaseTest):
        ),
    )(x)

  @parameterized.product(dtype=[jnp.bfloat16, jnp.float32])
  def test_pltpu_repeat(self, dtype):
    def test_kernel(x_ref, o_ref):
      x = x_ref[...]
      o_ref[...] = pltpu.repeat(x, 2, axis=1)

    @jax.jit
    def test(x: jax.Array) -> jax.Array:
      return pl.pallas_call(
          test_kernel,
          out_shape=jax.ShapeDtypeStruct([x.shape[0], x.shape[1] * 2], x.dtype),
      )(x)

    x = jnp.arange(2048, dtype=dtype).reshape((8, 256))
    y = test(x)
    np.testing.assert_array_equal(y, jnp.concatenate([x, x], axis=1))


class PallasUXTest(PallasBaseTest):
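
The new test exercises repeats=2 along the last axis. A hypothetical variant for a different repeat count, assuming pltpu.repeat accepts it and keeps the concat-of-copies semantics shown above (needs a TPU backend to run):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import tpu as pltpu

    def kernel3(x_ref, o_ref):
      # Repeat the block three times along the lane (last) dimension.
      o_ref[...] = pltpu.repeat(x_ref[...], 3, axis=1)

    @jax.jit
    def repeat3(x: jax.Array) -> jax.Array:
      return pl.pallas_call(
          kernel3,
          out_shape=jax.ShapeDtypeStruct((x.shape[0], x.shape[1] * 3), x.dtype),
      )(x)

    x = jnp.arange(2048, dtype=jnp.float32).reshape((8, 256))
    np.testing.assert_array_equal(repeat3(x), jnp.concatenate([x, x, x], axis=1))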