[Pallas] Reductions with replicated axes.

PiperOrigin-RevId: 727292293
commit eaceac3bf9 (parent b6361b3e76)
Author: jax authors
Date: 2025-02-15 07:40:44 -08:00
4 changed files with 284 additions and 116 deletions
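In Pallas terms, this change allows reductions (jnp.sum, jnp.max, jnp.min) over axes whose layout is replicated, e.g. values produced by jnp.full or jnp.broadcast_to inside a kernel. A minimal sketch of the pattern, modeled on the new test_double_replicated_reduction below (shapes and dtype are illustrative only, and running it requires a TPU backend with a recent enough libTPU):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def kernel(o_ref):
  # x is materialized with a replicated layout inside the kernel.
  x = jnp.full((8, 128), 2.0, dtype=jnp.float32)
  # Reduce over both replicated axes, then broadcast into the output block.
  o_ref[:] = jnp.full(o_ref.shape, jnp.sum(x))

out = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct((1024,), jnp.float32),
)()
# Every element of out equals 2.0 * 8 * 128.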


@@ -94,6 +94,7 @@ cc_library(
"@xla//xla:array",
"@xla//xla:shape_util",
"@xla//xla:util",
"@xla//xla/tsl/platform:errors",
] + pallas_extension_deps,
)


@@ -478,7 +478,8 @@ std::unique_ptr<VRegDataBounds> VectorLayout::tileDataBounds(
if (!hasNaturalTopology(target_shape)) {
if (!offsets_[0].has_value() || !offsets_[1].has_value()) {
-     emitError(UnknownLoc::get(mlir_ctx), "Not implemented");
+     emitError(UnknownLoc::get(mlir_ctx),
+               "Not implemented: non-natural topology with replication");
return nullptr;
}
const int64_t so = *offsets_[0];


@@ -69,6 +69,7 @@
#include "jaxlib/mosaic/dialect/tpu/vreg_util.h"
#include "xla/array.h"
#include "xla/layout.h"
#include "xla/tsl/platform/errors.h"
#include "xla/util.h"
// TODO(tlongeri): Prefer returning failure over CHECKs. In particular, be more
@@ -1997,7 +1998,7 @@ LogicalResult tpu_bitcast_rule(RewriteContext &ctx, Operation &op,
if (in_tiling != out_tiling) {
return op.emitOpError(
"Expected tilings are the same after multiplying the "
"second-minor dimension by the ratio of bitwidths.");
"second-minor dimension by the ratio of bitwidths.");
}
auto in_offsets = in_layout.offsets();
auto out_offsets = out_layout.offsets();
@@ -2012,7 +2013,7 @@ LogicalResult tpu_bitcast_rule(RewriteContext &ctx, Operation &op,
in_offsets[1] != out_offsets[1]) {
return op.emitOpError(
"Expected offsets are the same after multiplying the "
"second-minor dimension by the ratio of bitwidths.");
"second-minor dimension by the ratio of bitwidths.");
}
if (in_layout.implicit_dim() != out_layout.implicit_dim()) {
return op.emitOpError(
@@ -3805,7 +3806,7 @@ LogicalResult vector_extract_rule(RewriteContext &ctx, Operation &op,
extract_op.replaceAllUsesWith(
builder.create<vector::ExtractOp>(
op.getLoc(), rotated_vreg,
ArrayRef<int64_t>{0, 0})
.getResult());
}
extract_op.erase();
@@ -3956,7 +3957,6 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
false};
break;
}
- const std::array<bool, 2> allow_replicated = {!reduces[0], !reduces[1]};
if ((reduces[0] || reduces[1]) &&
!src_layout.hasNativeTiling(ctx.target_shape)) {
@@ -3968,9 +3968,10 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
return multi_reduction_op.emitOpError("Not implemented: Tiling change");
}
for (int i = 0; i < 2; ++i) {
-   if (reduces[i] && src_layout.offsets()[i] == std::nullopt) {
+   if (reduces[i] && src_layout.offsets()[i] == std::nullopt &&
+       element_type.getIntOrFloatBitWidth() != 32) {
      return multi_reduction_op.emitOpError(
-         "Not implemented: Reductions over replicated axes");
+         "Not implemented: Non-32-bit reductions over replicated axes");
}
// Offsets have to be equal, unless we're reducing over that dimension.
if (src_layout.offsets()[i] != dst_layout.offsets()[i] && !reduces[i]) {
@@ -4034,130 +4035,202 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<int64_t> src_shape = src_ty.getShape();
auto all_results_ok = dst_vregs.EachStatus(
[&](const absl::Span<const int64_t> idx, Value *const dst_vreg) {
  // Extract a subset of source vregs that reduce into this result vreg.
  SmallVector<int64_t> src_slice_start;
  src_slice_start.reserve(src_rank);
  SmallVector<int64_t> src_slice_end;
  src_slice_end.reserve(src_rank);
  for (int64_t i : idx) {
    src_slice_start.push_back(i);
    src_slice_end.push_back(i + 1);
  }
  for (int64_t d : dims) {
+   int64_t d_size = src_vregs.dim(d);
    src_slice_start.insert(src_slice_start.begin() + d, 0);
-   src_slice_end.insert(src_slice_end.begin() + d, src_vregs.dim(d));
+   if (!src_layout.offsets()[0].has_value() && d == src_rank - 2) {
+     d_size = 1;
+   }
+   if (!src_layout.offsets()[1].has_value() && d == src_rank - 1) {
+     d_size = 1;
+   }
+   src_slice_end.insert(src_slice_end.begin() + d, d_size);
  }
  xla::Array<Value> reduced_vregs =
      src_vregs.Slice(src_slice_start, src_slice_end);
  std::optional<Value> acc_vreg;
  auto reduce_elementwise = [&](Value lhs, Value rhs) -> Value {
    Value result;
    switch (tpu_kind) {
      case tpu::ReductionKind::SUM:
        result =
            is_int
                ? builder.create<arith::AddIOp>(loc, lhs, rhs).getResult()
                : builder.create<arith::AddFOp>(loc, lhs, rhs)
                      .getResult();
        break;
      case tpu::ReductionKind::MAX:
        result = is_int ? builder.create<arith::MaxSIOp>(loc, lhs, rhs)
                              .getResult()
                        : builder.create<arith::MaximumFOp>(loc, lhs, rhs)
                              .getResult();
        break;
      case tpu::ReductionKind::MIN:
        result = is_int ? builder.create<arith::MinSIOp>(loc, lhs, rhs)
                              .getResult()
                        : builder.create<arith::MinimumFOp>(loc, lhs, rhs)
                              .getResult();
        break;
    }
    return result;
  };
  auto reduction_status = reduced_vregs.EachStatus(
-     [&](const absl::Span<const int64_t> red_idx, Value *const src_vreg) {
+     [&](const absl::Span<const int64_t> red_idx,
+         Value *const src_vreg) {
        SmallVector<int64_t> src_idx(red_idx.begin(), red_idx.end());
        for (int i = 0; i < src_idx.size(); ++i) {
          src_idx[i] += src_slice_start[i];
        }
        const std::unique_ptr<VRegDataBounds> data_bounds =
            src_layout.tileDataBounds(builder.getContext(), src_shape,
                                      src_idx, ctx.target_shape,
-                                     allow_replicated);
+                                     {true, true});
        if (data_bounds == nullptr) {
          // Op error has already been emitted inside tileDataBounds().
          return absl::UnknownError("Unable to obtain data bounds");
        }
-       // TODO(tlongeri): Maybe assemble/disassemble should take
-       // TypedValue<VectorType> and we could save casts here and
-       // elsewhere
-       FailureOr<Value> failure_or_vreg =
-           maskOOB(ctx, builder, cast<TypedValue<VectorType>>(*src_vreg),
-                   *data_bounds, neutral);
-       if (failed(failure_or_vreg)) {
-         op.emitOpError("Failed to mask vreg");
-         return absl::UnknownError("");
-       }
-       Value vreg = failure_or_vreg.value();
+       Value vreg = *src_vreg;
+       // If replicated, we don't need to mask.
+       if (src_layout.offsets()[0].has_value() ||
+           src_layout.offsets()[1].has_value()) {
+         // TODO(tlongeri): Maybe assemble/disassemble should take
+         // TypedValue<VectorType> and we could save casts here and
+         // elsewhere
+         FailureOr<Value> failure_or_vreg =
+             maskOOB(ctx, builder, cast<TypedValue<VectorType>>(*src_vreg),
+                     *data_bounds, neutral);
+         if (failed(failure_or_vreg)) {
+           op.emitOpError("Failed to mask vreg");
+           return absl::UnknownError("");
+         }
+         vreg = failure_or_vreg.value();
+       }
        if (!acc_vreg.has_value()) {
          acc_vreg = vreg;
        } else {
          acc_vreg = reduce_elementwise(*acc_vreg, vreg);
        }
        return absl::OkStatus();
      });
- if (!reduction_status.ok()) {
-   return reduction_status;
- }
+ TF_RETURN_IF_ERROR(reduction_status);
  TPU_ASSERT_OP(acc_vreg.has_value());
+ const bool is_double_replicated_double_reduced =
+     reduces[0] && reduces[1] && !src_layout.offsets()[0].has_value() &&
+     !src_layout.offsets()[1].has_value();
  if (reduces[1]) {
-   acc_vreg = builder.create<tpu::AllReduceOp>(
-       multi_reduction_op->getLoc(), *acc_vreg, 1, tpu_kind);
+   if (src_layout.offsets()[1].has_value()) {
+     acc_vreg = builder.create<tpu::AllReduceOp>(
+         multi_reduction_op->getLoc(), *acc_vreg, /* dim= */ 1, tpu_kind);
+   } else {
+     int64_t size_dim1 = src_layout.getImplicitTiledDims(src_shape, 1)[1];
+     if (is_double_replicated_double_reduced) {
+       size_dim1 *= src_layout.getImplicitTiledDims(src_shape, 1)[0];
+     }
+     switch (tpu_kind) {
+       case tpu::ReductionKind::SUM:
+         if (is_int) {
+           IntegerAttr size_attr = builder.getI32IntegerAttr(size_dim1);
+           TypedValue<VectorType> source_value = getFullVector(
+               builder,
+               getNativeVregType(builder.getI32Type(), ctx.target_shape),
+               size_attr);
+           acc_vreg =
+               builder.create<arith::MulIOp>(loc, *acc_vreg, source_value);
+         } else {
+           FloatAttr size_attr = builder.getF32FloatAttr(size_dim1);
+           TypedValue<VectorType> source_value = getFullVector(
+               builder,
+               getNativeVregType(builder.getF32Type(), ctx.target_shape),
+               size_attr);
+           acc_vreg =
+               builder.create<arith::MulFOp>(loc, *acc_vreg, source_value);
+         }
+         break;
+       // We don't need to do anything for other reduction kinds.
+       case tpu::ReductionKind::MAX:
+       case tpu::ReductionKind::MIN:
+         break;
+     }
+   }
  }
  if (reduces[0]) {
    // Packed types are compressed along rows, so we need to reduce them
    // within each 32-bit word. There's no performance penalty for doing
    // this in 32-bit precision, so we take advantage of it.
    Type acc_vreg_ty = acc_vreg->getType();
    if (acc_layout.packing() > 1) {
      Type vreg_ty_32 = nullptr;
      if (acc.getType().getElementType().isBF16()) {
        vreg_ty_32 =
            getNativeVregType(builder.getF32Type(), ctx.target_shape);
      } else {
        multi_reduction_op.emitOpError(
            "Not implemented: Unsupported reduction dtype");
        return absl::UnknownError("");
      }
      Value acc_vreg_32 = builder.create<tpu::UnpackSubelementsOp>(
          loc, vreg_ty_32, *acc_vreg, 0, tpu::PackFormat::kInterleaved);
      for (int i = 1; i < acc_layout.packing(); ++i) {
        Value acc_vreg_part_32 = builder.create<tpu::UnpackSubelementsOp>(
            loc, vreg_ty_32, *acc_vreg, i, tpu::PackFormat::kInterleaved);
        acc_vreg_32 = reduce_elementwise(acc_vreg_32, acc_vreg_part_32);
      }
      acc_vreg = acc_vreg_32;
    }
    // At this point acc_vreg is always 32-bit.
-   acc_vreg = builder.create<tpu::AllReduceOp>(
-       multi_reduction_op->getLoc(), *acc_vreg, 0, tpu_kind);
+   if (src_layout.offsets()[0].has_value()) {
+     acc_vreg = builder.create<tpu::AllReduceOp>(
+         multi_reduction_op->getLoc(), *acc_vreg, 0, tpu_kind);
+   } else if (!is_double_replicated_double_reduced) {
+     int64_t size_dim0 = src_layout.getImplicitTiledDims(src_shape, 1)[0];
+     switch (tpu_kind) {
+       case tpu::ReductionKind::SUM:
+         if (is_int) {
+           IntegerAttr size_attr = builder.getI32IntegerAttr(size_dim0);
+           TypedValue<VectorType> source_value = getFullVector(
+               builder,
+               getNativeVregType(builder.getI32Type(), ctx.target_shape),
+               size_attr);
+           acc_vreg =
+               builder.create<arith::MulIOp>(loc, *acc_vreg, source_value);
+         } else {
+           FloatAttr size_attr = builder.getF32FloatAttr(size_dim0);
+           TypedValue<VectorType> source_value = getFullVector(
+               builder,
+               getNativeVregType(builder.getF32Type(), ctx.target_shape),
+               size_attr);
+           acc_vreg =
+               builder.create<arith::MulFOp>(loc, *acc_vreg, source_value);
+         }
+         break;
+       case tpu::ReductionKind::MAX:
+       case tpu::ReductionKind::MIN:
+         break;
+     }
+   }
    // We pack the final result back into the original type.
    if (acc_layout.packing() > 1) {
      SmallVector<int32_t> positions(acc_layout.packing());
      std::iota(positions.begin(), positions.end(),
                static_cast<int32_t>(0));
      SmallVector<Value> parts(acc_layout.packing(), *acc_vreg);
      acc_vreg = builder.create<tpu::PackSubelementsOp>(
          loc, acc_vreg_ty, parts,
          builder.getDenseI32ArrayAttr(positions),
          tpu::PackFormat::kInterleaved);
    }
  }
  *dst_vreg = *acc_vreg;
  return absl::OkStatus();
});
if (!all_results_ok.ok()) {
return failure();
}
@@ -4702,7 +4775,7 @@ LogicalResult tpu_prng_random_bits_rule(RewriteContext &ctx, Operation &op,
const VectorLayout &layout_out = *layouts_out.front();
tpu::PRNGRandomBitsOp rng_op = cast<tpu::PRNGRandomBitsOp>(op);
if (layout_out != VectorLayout(32, {0, 0}, ctx.target_shape,
VectorLayout::ImplicitDim::kNone)) {
return op.emitOpError(
"Unsupported output layout for ") << rng_op->getName();
}
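The lowering above relies on a simple identity once an axis is replicated: every position along that axis holds the same data, so a SUM over it is the per-element value multiplied by the axis size (hence the getFullVector/MulIOp/MulFOp path), while MAX and MIN are unchanged (hence the empty cases). A small jax.numpy check of that identity, separate from the Mosaic code and with arbitrary shapes:

import jax.numpy as jnp

v = jnp.arange(128, dtype=jnp.float32).reshape(1, 128)
x = jnp.broadcast_to(v, (256, 128))  # axis 0 is "replicated": 256 identical rows

assert jnp.allclose(jnp.sum(x, axis=0), v[0] * 256)  # SUM -> scale by axis size
assert jnp.allclose(jnp.max(x, axis=0), v[0])        # MAX unchanged
assert jnp.allclose(jnp.min(x, axis=0), v[0])        # MIN unchanged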


@@ -16,6 +16,7 @@
import contextlib
import functools
import itertools
import gc
import io
import math
@@ -1708,6 +1709,98 @@ class PallasCallDMAInterpretTest(PallasCallDMATest):
class PallasCallTest(PallasBaseTest):
+  @parameterized.parameters([
+      dict(shape=shape, dty=dty)
+      for shape, dty in itertools.product(
+          [(4, 2, 9), (1, 1025), (1024, 1024)], [jnp.float32, jnp.int32]
+      )
+  ])
+  def test_double_replicated_reduction(self, shape, dty):
+    if not jtu.if_cloud_tpu_at_least(2025, 2, 19):
+      self.skipTest("Needs a newer libTPU")
+
+    def body(o_ref):
+      x = jnp.full(shape, 2.0, dtype=dty)
+      reduction = jnp.sum(x, axis=None)
+      bcast = jnp.full((vregs_in_block * 1024,), reduction)
+      o_ref[:] = bcast
+
+    vregs_in_block = 2
+    total_vregs = 4
+    data_size = total_vregs * 1024
+    block_size = vregs_in_block * 1024
+
+    @jax.jit
+    def reduce():
+      return self.pallas_call(
+          body,
+          out_shape=jax.ShapeDtypeStruct((data_size,), dty),
+          in_specs=[],
+          out_specs=pl.BlockSpec((block_size,), lambda i: i),
+          grid=data_size // block_size,
+      )()
+
+    x = jnp.full(shape, 2.0, dtype=dty)
+    z = jax.block_until_ready(reduce())
+    reduce_value = jnp.sum(jnp.full(shape, x), dtype=dty)
+    np.testing.assert_allclose(z, reduce_value)
+
+  @parameterized.parameters([
+      dict(
+          m=m,
+          replicated=replicated,
+          reduced_dims=reduced_dims,
+          dty=dty,
+          reduce_func=reduce_func,
+      )
+      for m, replicated, reduced_dims, dty, reduce_func in itertools.product(
+          [128, 256],
+          [(True, True), (False, True), (True, False)],
+          [(0, 1), (0,), (1,)],
+          [jnp.float32, jnp.int32],
+          [jnp.sum, jnp.max, jnp.min],
+      )
+  ])
+  def test_replicated_broadcast_reduction(
+      self, m, replicated, reduced_dims, dty, reduce_func
+  ):
+    if not jtu.if_cloud_tpu_at_least(2025, 2, 19):
+      self.skipTest("Needs a newer libTPU")
+    if dty == jnp.int32 and 1 in reduced_dims:
+      # TODO(b/395579834): Remove this skip once we implement this.
+      self.skipTest('int32 reduction on last dimension not supported')
+    if not jtu.is_device_tpu_at_least(4) and len(replicated) == 2:
+      self.skipTest(
+          'Broadcast in both sublanes and lanes not supported on this hardware'
+      )
+
+    in_shape = (1 if replicated[0] else m, 1 if replicated[1] else m)
+    red_shape = [m, m]
+    for d in reduced_dims:
+      red_shape[d] = 1
+
+    def body(x_ref, o_ref):
+      x = x_ref[:]
+      dilated_x = jnp.broadcast_to(x, (m, m))
+      reduced = reduce_func(dilated_x, axis=reduced_dims).reshape(red_shape)
+      o_ref[:] = reduced
+
+    @jax.jit
+    def reduce(x):
+      return self.pallas_call(
+          body,
+          out_shape=jax.ShapeDtypeStruct(red_shape, dty),
+          in_specs=[pl.BlockSpec(in_shape)],
+          out_specs=pl.BlockSpec(red_shape),
+          grid=1,
+      )(x)
+
+    x = jnp.full(in_shape, 2.0, dtype=dty)
+    y = jax.block_until_ready(reduce(x))
+    dilated_x = jnp.broadcast_to(x, (m, m))
+    expected = reduce_func(dilated_x, axis=reduced_dims).reshape(red_shape)
+    np.testing.assert_allclose(y, expected)
+
   def test_cost_analysis(self):
     def kernel(x, y):
       y[:] = x[:]