[TPU][Mosaic] Replace the TPU lowering for repeat with a canonicalization to concat (which handles far more cases)

PiperOrigin-RevId: 691192121
Author: jax authors, 2024-10-29 15:56:44 -07:00
Commit: 5ad066eeaa (parent 7c4cc9552c)
5 changed files with 45 additions and 75 deletions
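At the JAX level, the identity the new canonicalization relies on is that repeating a vector `times` times along a dimension (the tiling behavior exercised by the pltpu.repeat test added below) is the same as concatenating `times` copies along that dimension. A minimal illustrative sketch, not part of the commit:

import jax.numpy as jnp

# Illustration only: concatenating `times` copies along `axis` gives the
# tiled result that the canonicalization produces for tpu.repeat.
x = jnp.arange(8, dtype=jnp.float32).reshape(2, 4)
times, axis = 3, 1
concat_form = jnp.concatenate([x] * times, axis=axis)
tile_form = jnp.tile(x, (1, times))
assert (concat_form == tile_form).all()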


@@ -295,6 +295,8 @@ def TPU_IotaOp : TPU_Op<"iota", [Pure]> {
  let assemblyFormat = [{ attr-dict `:` type($output) }];
}

// TODO(mvoz): deprecated - use concat. Canonicalization will do so automatically.
// b/376295711
def TPU_RepeatOp : TPU_Op<"repeat", [Pure]> {
  let arguments = (ins
    AnyVector:$source,


@@ -170,25 +170,6 @@ FailureOr<TypedValue<MemRefType>> getInternalScratch(
.getResult();
}
// Models Numpy's np.repeat, repeating each element `repeats` times along the
// specified axis. For example, if `src` is [1, 2], `axis` is 0 and `repeats` is
// 3, this will return [1, 1, 1, 2, 2, 2].
xla::Array<Value> repeat(const xla::Array<Value> &src, const int repeats,
                         const int64_t axis) {
  SmallVector<int64_t> dims(toArrayRef(src.dimensions()));
  dims[axis] *= repeats;
  xla::Array<Value> res(dims);
  src.Each([&](absl::Span<const int64_t> idx, const Value v) {
    SmallVector<int64_t> res_idx(toArrayRef(idx));
    res_idx[axis] *= repeats;
    for (int i = 0; i < repeats; ++i) {
      res(res_idx) = v;
      ++res_idx[axis];
    }
  });
  return res;
}
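For contrast, an illustrative plain-NumPy comparison (not part of the commit): np.repeat, which the deleted vreg-level helper above models, duplicates each element in place, whereas concatenating whole copies, which the new canonicalization emits, matches np.tile along the chosen axis.

import numpy as np

src = np.array([1, 2])
# Element-wise repetition, as modeled by the deleted helper.
assert (np.repeat(src, 3, axis=0) == np.array([1, 1, 1, 2, 2, 2])).all()
# Whole-copy concatenation, as emitted by the new canonicalization.
assert (np.concatenate([src] * 3, axis=0) == np.array([1, 2, 1, 2, 1, 2])).all()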
// Models Numpy's np.concatenate
xla::Array<Value> concatenate(const ArrayRef<xla::Array<Value>> arrays,
                              const int64_t axis) {
@@ -2949,48 +2930,6 @@ LogicalResult tpu_region_rule(RewriteContext &ctx, Operation &op,
  return success();
}

LogicalResult tpu_repeat_rule(RewriteContext &ctx, Operation &op,
                              const ArrayRef<Layout> layouts_in,
                              const ArrayRef<Layout> layouts_out) {
  TPU_ASSERT_EQ_OP(layouts_in.size(), 1);
  TPU_ASSERT_EQ_OP(layouts_out.size(), 1);
  TPU_ASSERT_OP(layouts_in.front().has_value());
  TPU_ASSERT_OP(layouts_out.front().has_value());
  const VectorLayout &layout_in = *layouts_in.front();
  const VectorLayout &layout_out = *layouts_out.front();
  if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
    return op.emitOpError("Not implemented: Only 2D layouts supported");
  }
  if (layout_in != layout_out) {
    return op.emitOpError("Not implemented: Changing layout mid-repeat");
  }
  if (!layout_in.hasNaturalTopology(ctx.target_shape) ||
      layout_in.offsets() != LayoutOffsets{0, 0}) {
    return op.emitOpError("Not implemented: Non-trivial layouts unsupported");
  }
  OpBuilder builder(&op);
  tpu::RepeatOp repeat_op = cast<tpu::RepeatOp>(op);
  VectorType src_ty = repeat_op.getSource().getType();
  const uint32_t dim = repeat_op.getDimension();
  if (dim != src_ty.getRank() - 1) {
    return op.emitOpError(
        "Not implemented: Only repeats along the last dim supported");
  }
  if (src_ty.getShape().back() % ctx.target_shape.back() != 0) {
    return op.emitOpError("Not implemented: Only free repeats are supported");
  }
  FAILUREOR_ASSIGN_OR_RETURN(
      const xla::Array<Value> &in_vregs,
      disassemble(builder, layout_in, repeat_op.getSource(), ctx.target_shape));
  xla::Array<Value> out_vregs = repeat(in_vregs, repeat_op.getTimes(), dim);
  repeat_op->replaceAllUsesWith(
      assemble(builder, repeat_op.getResult().getType(), layout_out, out_vregs,
               ctx.target_shape)
          .getOperation());
  repeat_op->erase();
  return success();
}

LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
                               const ArrayRef<Layout> layouts_in,
                               const ArrayRef<Layout> layouts_out) {
@@ -4648,7 +4587,6 @@ const llvm::StringMap<rule_type> &rules() {
{tpu::StridedStoreOp::getOperationName(), tpu_strided_store_rule},
{tpu::MatmulOp::getOperationName(), tpu_matmul_rule},
{tpu::RegionOp::getOperationName(), tpu_region_rule},
{tpu::RepeatOp::getOperationName(), tpu_repeat_rule},
{tpu::BitcastOp::getOperationName(), tpu_bitcast_rule},
{tpu::TraceOp::getOperationName(), tpu_trace_rule},
{tpu::AssumeLayoutOp::getOperationName(), tpu_assume_layout_rule},


@@ -1,3 +1,4 @@
#include <cstdint>
#include <functional>
#include <memory>
#include <vector>
@@ -350,6 +351,29 @@ LogicalResult canonicalize_select(int hardware_generation, Operation &raw_op) {
  return success();
}

LogicalResult canonicalize_repeat(int hardware_generation, Operation &raw_op) {
  auto op = dyn_cast<tpu::RepeatOp>(raw_op);
  if (!isa<VectorType>(op.getType())) {
    return op.emitOpError("Only vector types supported");
  }
  auto operand = op.getSource();
  auto times = op.getTimes();
  if (times == 1) {
    // A true no op - kind of an odd edge case, but this does come up in
    // flash_attention_backward tests.
    op.replaceAllUsesWith(operand);
    op.erase();
    return success();
  }
  auto operands = std::vector<Value>(times, operand);
  ImplicitLocOpBuilder builder(op->getLoc(), op.getOperation());
  auto concat = builder.create<tpu::ConcatenateOp>(op.getLoc(), op.getType(),
                                                   operands, op.getDimension());
  op.replaceAllUsesWith(concat.getResult());
  op.erase();
  return success();
}
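A minimal Python sketch of what the rule above does, with a hypothetical helper name and jnp arrays standing in for MLIR values (the real pass rewrites tpu.repeat into tpu.concatenate):

import jax.numpy as jnp

def canonicalize_repeat_sketch(x, times, dim):
  if times == 1:
    # True no-op, mirroring the edge case handled above.
    return x
  # `times` copies concatenated along `dim`, as the emitted tpu.concatenate does.
  return jnp.concatenate([x] * times, axis=dim)

# e.g. canonicalize_repeat_sketch(jnp.ones((8, 128)), 2, 1).shape == (8, 256)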
using canonicalize_rule_type =
    std::function<LogicalResult(int hardware_generation, Operation &op)>;
@@ -360,7 +384,8 @@ const llvm::StringMap<canonicalize_rule_type> &rules() {
{vector::ContractionOp::getOperationName(), canonicalize_extract},
{vector::MultiDimReductionOp::getOperationName(),
canonicalize_multi_dim_reduction},
{arith::SelectOp::getOperationName(), canonicalize_select}};
{arith::SelectOp::getOperationName(), canonicalize_select},
{tpu::RepeatOp::getOperationName(), canonicalize_repeat}};
return *rules;
}


@@ -288,10 +288,6 @@ class VectorLayoutInferer {
if (infer(op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<tpu::RepeatOp>(any_op)) {
if (infer(op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<tpu::TraceOp>(any_op)) {
if (infer(op).failed()) {
return failure();
@@ -1020,12 +1016,6 @@ class VectorLayoutInferer {
    return success();
  }

  LogicalResult infer(tpu::RepeatOp op) {
    auto src_layout = getLayout(op.getSource());
    setLayout(op, src_layout, src_layout);
    return success();
  }

  LogicalResult infer(tpu::TraceOp op) {
    static LogicalResult (*match_yield)(Operation *) = [](Operation *op) {
      TPU_CHECK_OP(isa<tpu::YieldOp>(op), "expected yield terminator");


@@ -1587,7 +1587,6 @@ class PallasCallTest(PallasBaseTest):
self.assertEqual(analysis_result['transcendentals'], 21)
self.assertEqual(analysis_result['bytes accessed'], 12345)
def test_cost_analysis_vmap(self):
def kernel(x, y):
y[:] = x[:]
@@ -1606,7 +1605,6 @@ class PallasCallTest(PallasBaseTest):
self.assertEqual(analysis_result['transcendentals'], batch_size * 21)
self.assertEqual(analysis_result['bytes accessed'], batch_size * 12345)
def test_vmem_limit(self):
shape = (128, 128)
@@ -1673,6 +1671,23 @@ class PallasCallTest(PallasBaseTest):
),
)(x)
  @parameterized.product(dtype=[jnp.bfloat16, jnp.float32])
  def test_pltpu_repeat(self, dtype):
    def test_kernel(x_ref, o_ref):
      x = x_ref[...]
      o_ref[...] = pltpu.repeat(x, 2, axis=1)

    @jax.jit
    def test(x: jax.Array) -> jax.Array:
      return pl.pallas_call(
          test_kernel,
          out_shape=jax.ShapeDtypeStruct([x.shape[0], x.shape[1] * 2], x.dtype),
      )(x)

    x = jnp.arange(2048, dtype=dtype).reshape((8, 256))
    y = test(x)
    np.testing.assert_array_equal(y, jnp.concatenate([x, x], axis=1))

class PallasUXTest(PallasBaseTest):