mirror of https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00

[Pallas] Reductions with replicated axes.

PiperOrigin-RevId: 727292293

This commit is contained in:
parent b6361b3e76
commit eaceac3bf9
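
What this enables, in user-facing terms: Mosaic's multi-reduction lowering now accepts source layouts whose reduced axes are replicated (offset-less), as produced by broadcasts inside a Pallas TPU kernel. A minimal sketch of the pattern, assuming a TPU backend; the kernel, names, and shapes here are illustrative and not taken from this diff:

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def kernel(x_ref, o_ref):
        # Broadcasting keeps the row replicated across sublanes in vregs
        # instead of materializing 128 copies of it.
        x = jnp.broadcast_to(x_ref[:], (128, 128))
        # Reducing over the replicated axis previously failed with
        # "Not implemented: Reductions over replicated axes".
        o_ref[:] = jnp.sum(x, axis=0, keepdims=True)

    @jax.jit
    def run(x):
        return pl.pallas_call(
            kernel,
            out_shape=jax.ShapeDtypeStruct((1, 128), jnp.float32),
        )(x)

    # run(jnp.ones((1, 128), jnp.float32)) == jnp.full((1, 128), 128.0)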
@@ -94,6 +94,7 @@ cc_library(
         "@xla//xla:array",
         "@xla//xla:shape_util",
         "@xla//xla:util",
+        "@xla//xla/tsl/platform:errors",
     ] + pallas_extension_deps,
 )
 
@@ -478,7 +478,8 @@ std::unique_ptr<VRegDataBounds> VectorLayout::tileDataBounds(
   if (!hasNaturalTopology(target_shape)) {
     if (!offsets_[0].has_value() || !offsets_[1].has_value()) {
-      emitError(UnknownLoc::get(mlir_ctx), "Not implemented");
+      emitError(UnknownLoc::get(mlir_ctx),
+                "Not implemented: non-natural topology with replication");
       return nullptr;
     }
     const int64_t so = *offsets_[0];
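
Context for this message change (my gloss, not stated in the diff): a VectorLayout offset of std::nullopt on a dimension marks that dimension as replicated across the vreg, which is how broadcast results are laid out inside a kernel. Replication can arise from code as simple as this sketch:

    import jax.numpy as jnp

    row = jnp.ones((1, 256), jnp.float32)   # ordinary layout, concrete offsets
    rep = jnp.broadcast_to(row, (8, 256))   # inside a Pallas kernel, Mosaic can
                                            # keep this sublane-replicated
                                            # (offset None) rather than copied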
@@ -69,6 +69,7 @@
 #include "jaxlib/mosaic/dialect/tpu/vreg_util.h"
 #include "xla/array.h"
 #include "xla/layout.h"
+#include "xla/tsl/platform/errors.h"
 #include "xla/util.h"
 
 // TODO(tlongeri): Prefer returning failure over CHECKs. In particular, be more
@@ -1997,7 +1998,7 @@ LogicalResult tpu_bitcast_rule(RewriteContext &ctx, Operation &op,
   if (in_tiling != out_tiling) {
     return op.emitOpError(
         "Expected tilings are the same after multiplying the "
-        "second-minor dimension by the ratio of bitwidths.");
+        "second-minor dimension by the ratio of bitwidths.");
   }
   auto in_offsets = in_layout.offsets();
   auto out_offsets = out_layout.offsets();
@@ -2012,7 +2013,7 @@ LogicalResult tpu_bitcast_rule(RewriteContext &ctx, Operation &op,
       in_offsets[1] != out_offsets[1]) {
     return op.emitOpError(
         "Expected offsets are the same after multiplying the "
-        "second-minor dimension by the ratio of bitwidths.");
+        "second-minor dimension by the ratio of bitwidths.");
   }
   if (in_layout.implicit_dim() != out_layout.implicit_dim()) {
     return op.emitOpError(
@@ -3805,7 +3806,7 @@ LogicalResult vector_extract_rule(RewriteContext &ctx, Operation &op,
       extract_op.replaceAllUsesWith(
           builder.create<vector::ExtractOp>(
                   op.getLoc(), rotated_vreg,
-                  ArrayRef<int64_t>{0, 0})
+                  ArrayRef<int64_t>{0, 0})
               .getResult());
     }
     extract_op.erase();
@@ -3956,7 +3957,6 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
             false};
       break;
   }
-  const std::array<bool, 2> allow_replicated = {!reduces[0], !reduces[1]};
 
   if ((reduces[0] || reduces[1]) &&
       !src_layout.hasNativeTiling(ctx.target_shape)) {
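
The deleted allow_replicated array used to tell tileDataBounds to treat non-reduced axes as fully in-bounds; the rewritten loop in the large hunk below always passes {true, true} and instead skips masking at run time when both offsets are replicated. What maskOOB contributes when masking is still needed, sketched in plain jax.numpy (the neutral values mirror, by my reading, the per-kind neutral element):

    import jax.numpy as jnp

    vals = jnp.array([3.0, 7.0, 5.0, 2.0], jnp.float32)
    in_bounds = jnp.array([True, True, False, False])

    # SUM pads out-of-bounds lanes with 0, MAX with -inf, so padded lanes
    # cannot change the result.
    assert jnp.sum(jnp.where(in_bounds, vals, 0.0)) == 10.0
    assert jnp.max(jnp.where(in_bounds, vals, -jnp.inf)) == 7.0

    # A fully replicated vreg has no out-of-bounds lanes to hide, which is
    # why the new code bypasses maskOOB when both offsets are std::nullopt.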
@@ -3968,9 +3968,10 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
     return multi_reduction_op.emitOpError("Not implemented: Tiling change");
   }
   for (int i = 0; i < 2; ++i) {
-    if (reduces[i] && src_layout.offsets()[i] == std::nullopt) {
+    if (reduces[i] && src_layout.offsets()[i] == std::nullopt &&
+        element_type.getIntOrFloatBitWidth() != 32) {
       return multi_reduction_op.emitOpError(
-          "Not implemented: Reductions over replicated axes");
+          "Not implemented: Non-32-bit reductions over replicated axes");
     }
     // Offsets have to be equal, unless we're reducing over that dimension.
     if (src_layout.offsets()[i] != dst_layout.offsets()[i] && !reduces[i]) {
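
Restated: reductions over replicated axes are now accepted, but only for 32-bit element types; packed types such as bf16 still hit the renamed error. A small predicate capturing my reading of the guard (not an exhaustive support matrix):

    import numpy as np
    import jax.numpy as jnp

    def replicated_reduction_supported(dtype) -> bool:
        # Mirrors `element_type.getIntOrFloatBitWidth() != 32` in the check.
        return np.dtype(dtype).itemsize * 8 == 32

    assert replicated_reduction_supported(jnp.float32)
    assert replicated_reduction_supported(jnp.int32)
    assert not replicated_reduction_supported(jnp.bfloat16)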
@@ -4034,130 +4035,202 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
   const ArrayRef<int64_t> src_shape = src_ty.getShape();
   auto all_results_ok = dst_vregs.EachStatus(
       [&](const absl::Span<const int64_t> idx, Value *const dst_vreg) {
         // Extract a subset of source vregs that reduce into this result vreg.
         SmallVector<int64_t> src_slice_start;
         src_slice_start.reserve(src_rank);
         SmallVector<int64_t> src_slice_end;
         src_slice_end.reserve(src_rank);
         for (int64_t i : idx) {
           src_slice_start.push_back(i);
           src_slice_end.push_back(i + 1);
         }
         for (int64_t d : dims) {
+          int64_t d_size = src_vregs.dim(d);
           src_slice_start.insert(src_slice_start.begin() + d, 0);
-          src_slice_end.insert(src_slice_end.begin() + d, src_vregs.dim(d));
+          if (!src_layout.offsets()[0].has_value() && d == src_rank - 2) {
+            d_size = 1;
+          }
+          if (!src_layout.offsets()[1].has_value() && d == src_rank - 1) {
+            d_size = 1;
+          }
+          src_slice_end.insert(src_slice_end.begin() + d, d_size);
         }
         xla::Array<Value> reduced_vregs =
             src_vregs.Slice(src_slice_start, src_slice_end);
         std::optional<Value> acc_vreg;
         auto reduce_elementwise = [&](Value lhs, Value rhs) -> Value {
           Value result;
           switch (tpu_kind) {
             case tpu::ReductionKind::SUM:
               result =
                   is_int
                       ? builder.create<arith::AddIOp>(loc, lhs, rhs).getResult()
                       : builder.create<arith::AddFOp>(loc, lhs, rhs)
                             .getResult();
               break;
             case tpu::ReductionKind::MAX:
               result = is_int ? builder.create<arith::MaxSIOp>(loc, lhs, rhs)
                                     .getResult()
                               : builder.create<arith::MaximumFOp>(loc, lhs, rhs)
                                     .getResult();
               break;
             case tpu::ReductionKind::MIN:
               result = is_int ? builder.create<arith::MinSIOp>(loc, lhs, rhs)
                                     .getResult()
                               : builder.create<arith::MinimumFOp>(loc, lhs, rhs)
                                     .getResult();
               break;
           }
           return result;
         };
         auto reduction_status = reduced_vregs.EachStatus(
-            [&](const absl::Span<const int64_t> red_idx, Value *const src_vreg) {
+            [&](const absl::Span<const int64_t> red_idx,
+                Value *const src_vreg) {
               SmallVector<int64_t> src_idx(red_idx.begin(), red_idx.end());
               for (int i = 0; i < src_idx.size(); ++i) {
                 src_idx[i] += src_slice_start[i];
               }
               const std::unique_ptr<VRegDataBounds> data_bounds =
                   src_layout.tileDataBounds(builder.getContext(), src_shape,
                                             src_idx, ctx.target_shape,
-                                            allow_replicated);
+                                            {true, true});
               if (data_bounds == nullptr) {
                 // Op error has already been emitted inside tileDataBounds().
                 return absl::UnknownError("Unable to obtain data bounds");
               }
-              // TODO(tlongeri): Maybe assemble/disassemble should take
-              // TypedValue<VectorType> and we could save casts here and
-              // elsewhere
-              FailureOr<Value> failure_or_vreg =
-                  maskOOB(ctx, builder, cast<TypedValue<VectorType>>(*src_vreg),
-                          *data_bounds, neutral);
-              if (failed(failure_or_vreg)) {
-                op.emitOpError("Failed to mask vreg");
-                return absl::UnknownError("");
-              }
-              Value vreg = failure_or_vreg.value();
+              Value vreg = *src_vreg;
+              // If replicated, we don't need to mask.
+              if (src_layout.offsets()[0].has_value() ||
+                  src_layout.offsets()[1].has_value()) {
+                // TODO(tlongeri): Maybe assemble/disassemble should take
+                // TypedValue<VectorType> and we could save casts here and
+                // elsewhere
+                FailureOr<Value> failure_or_vreg =
+                    maskOOB(ctx, builder, cast<TypedValue<VectorType>>(*src_vreg),
+                            *data_bounds, neutral);
+                if (failed(failure_or_vreg)) {
+                  op.emitOpError("Failed to mask vreg");
+                  return absl::UnknownError("");
+                }
+                vreg = failure_or_vreg.value();
+              }
               if (!acc_vreg.has_value()) {
                 acc_vreg = vreg;
               } else {
                 acc_vreg = reduce_elementwise(*acc_vreg, vreg);
               }
               return absl::OkStatus();
             });
-        if (!reduction_status.ok()) {
-          return reduction_status;
-        }
+        TF_RETURN_IF_ERROR(reduction_status);
         TPU_ASSERT_OP(acc_vreg.has_value());
+        const bool is_double_replicated_double_reduced =
+            reduces[0] && reduces[1] && !src_layout.offsets()[0].has_value() &&
+            !src_layout.offsets()[1].has_value();
         if (reduces[1]) {
-          acc_vreg = builder.create<tpu::AllReduceOp>(
-              multi_reduction_op->getLoc(), *acc_vreg, 1, tpu_kind);
+          if (src_layout.offsets()[1].has_value()) {
+            acc_vreg = builder.create<tpu::AllReduceOp>(
+                multi_reduction_op->getLoc(), *acc_vreg, /* dim= */ 1, tpu_kind);
+          } else {
+            int64_t size_dim1 = src_layout.getImplicitTiledDims(src_shape, 1)[1];
+            if (is_double_replicated_double_reduced) {
+              size_dim1 *= src_layout.getImplicitTiledDims(src_shape, 1)[0];
+            }
+            switch (tpu_kind) {
+              case tpu::ReductionKind::SUM:
+                if (is_int) {
+                  IntegerAttr size_attr = builder.getI32IntegerAttr(size_dim1);
+                  TypedValue<VectorType> source_value = getFullVector(
+                      builder,
+                      getNativeVregType(builder.getI32Type(), ctx.target_shape),
+                      size_attr);
+                  acc_vreg =
+                      builder.create<arith::MulIOp>(loc, *acc_vreg, source_value);
+                } else {
+                  FloatAttr size_attr = builder.getF32FloatAttr(size_dim1);
+                  TypedValue<VectorType> source_value = getFullVector(
+                      builder,
+                      getNativeVregType(builder.getF32Type(), ctx.target_shape),
+                      size_attr);
+                  acc_vreg =
+                      builder.create<arith::MulFOp>(loc, *acc_vreg, source_value);
+                }
+                break;
+              // We don't need to do anything for other reduction kinds.
+              case tpu::ReductionKind::MAX:
+              case tpu::ReductionKind::MIN:
+                break;
+            }
+          }
         }
         if (reduces[0]) {
           // Packed types are compressed along rows, so we need to reduce them
           // within each 32-bit word. There's no performance penalty for doing
           // this in 32-bit precision, so we take advantage of it.
           Type acc_vreg_ty = acc_vreg->getType();
           if (acc_layout.packing() > 1) {
             Type vreg_ty_32 = nullptr;
             if (acc.getType().getElementType().isBF16()) {
               vreg_ty_32 =
                   getNativeVregType(builder.getF32Type(), ctx.target_shape);
             } else {
               multi_reduction_op.emitOpError(
                   "Not implemented: Unsupported reduction dtype");
               return absl::UnknownError("");
             }
             Value acc_vreg_32 = builder.create<tpu::UnpackSubelementsOp>(
                 loc, vreg_ty_32, *acc_vreg, 0, tpu::PackFormat::kInterleaved);
             for (int i = 1; i < acc_layout.packing(); ++i) {
               Value acc_vreg_part_32 = builder.create<tpu::UnpackSubelementsOp>(
                   loc, vreg_ty_32, *acc_vreg, i, tpu::PackFormat::kInterleaved);
               acc_vreg_32 = reduce_elementwise(acc_vreg_32, acc_vreg_part_32);
             }
             acc_vreg = acc_vreg_32;
           }
           // At this point acc_vreg is always 32-bit.
-          acc_vreg = builder.create<tpu::AllReduceOp>(
-              multi_reduction_op->getLoc(), *acc_vreg, 0, tpu_kind);
+          if (src_layout.offsets()[0].has_value()) {
+            acc_vreg = builder.create<tpu::AllReduceOp>(
+                multi_reduction_op->getLoc(), *acc_vreg, 0, tpu_kind);
+          } else if (!is_double_replicated_double_reduced) {
+            int64_t size_dim0 = src_layout.getImplicitTiledDims(src_shape, 1)[0];
+            switch (tpu_kind) {
+              case tpu::ReductionKind::SUM:
+                if (is_int) {
+                  IntegerAttr size_attr = builder.getI32IntegerAttr(size_dim0);
+                  TypedValue<VectorType> source_value = getFullVector(
+                      builder,
+                      getNativeVregType(builder.getI32Type(), ctx.target_shape),
+                      size_attr);
+                  acc_vreg =
+                      builder.create<arith::MulIOp>(loc, *acc_vreg, source_value);
+                } else {
+                  FloatAttr size_attr = builder.getF32FloatAttr(size_dim0);
+                  TypedValue<VectorType> source_value = getFullVector(
+                      builder,
+                      getNativeVregType(builder.getF32Type(), ctx.target_shape),
+                      size_attr);
+                  acc_vreg =
+                      builder.create<arith::MulFOp>(loc, *acc_vreg, source_value);
+                }
+                break;
+              case tpu::ReductionKind::MAX:
+              case tpu::ReductionKind::MIN:
+                break;
+            }
+          }
           // We pack the final result back into the original type.
           if (acc_layout.packing() > 1) {
             SmallVector<int32_t> positions(acc_layout.packing());
             std::iota(positions.begin(), positions.end(),
                       static_cast<int32_t>(0));
             SmallVector<Value> parts(acc_layout.packing(), *acc_vreg);
             acc_vreg = builder.create<tpu::PackSubelementsOp>(
                 loc, acc_vreg_ty, parts,
                 builder.getDenseI32ArrayAttr(positions),
                 tpu::PackFormat::kInterleaved);
           }
         }
         *dst_vreg = *acc_vreg;
         return absl::OkStatus();
       });
   if (!all_results_ok.ok()) {
     return failure();
   }
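
The heart of the new lowering: if the reduced axis is replicated, every vreg element already holds the operand of the reduction, so SUM reduces to a multiply by the logical axis size (the getFullVector splat fed to arith::MulIOp or arith::MulFOp above), MAX and MIN need no work at all, and the doubly-replicated, doubly-reduced case multiplies by the product of both sizes. The arithmetic, checked in plain jax.numpy with illustrative values:

    import jax.numpy as jnp

    v = jnp.float32(2.0)
    n = 9                              # logical size of the replicated axis
    col = jnp.broadcast_to(v, (n,))    # what a replicated vreg represents

    assert jnp.sum(col) == v * n       # SUM: multiply by the axis size
    assert jnp.max(col) == v           # MAX: the value is already the answer
    assert jnp.min(col) == v           # MIN: likewise

    # Doubly replicated and doubly reduced: multiply by both sizes, matching
    # `size_dim1 *= src_layout.getImplicitTiledDims(src_shape, 1)[0]`.
    rows, cols = 4, 9
    assert jnp.sum(jnp.broadcast_to(v, (rows, cols))) == v * rows * cols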
@@ -4702,7 +4775,7 @@ LogicalResult tpu_prng_random_bits_rule(RewriteContext &ctx, Operation &op,
   const VectorLayout &layout_out = *layouts_out.front();
   tpu::PRNGRandomBitsOp rng_op = cast<tpu::PRNGRandomBitsOp>(op);
   if (layout_out != VectorLayout(32, {0, 0}, ctx.target_shape,
-                                 VectorLayout::ImplicitDim::kNone)) {
+                                 VectorLayout::ImplicitDim::kNone)) {
     return op.emitOpError(
                "Unsupported output layout for ") << rng_op->getName();
   }
@@ -16,6 +16,7 @@
 import contextlib
 import functools
+import itertools
 import gc
 import io
 import math
@@ -1708,6 +1709,98 @@ class PallasCallDMAInterpretTest(PallasCallDMATest):
 
 class PallasCallTest(PallasBaseTest):
 
+  @parameterized.parameters([
+      dict(shape=shape, dty=dty)
+      for shape, dty in itertools.product(
+          [(4, 2, 9), (1, 1025), (1024, 1024)], [jnp.float32, jnp.int32]
+      )
+  ])
+  def test_double_replicated_reduction(self, shape, dty):
+    if not jtu.if_cloud_tpu_at_least(2025, 2, 19):
+      self.skipTest("Needs a newer libTPU")
+
+    def body(o_ref):
+      x = jnp.full(shape, 2.0, dtype=dty)
+      reduction = jnp.sum(x, axis=None)
+      bcast = jnp.full((vregs_in_block * 1024,), reduction)
+      o_ref[:] = bcast
+
+    vregs_in_block = 2
+    total_vregs = 4
+
+    data_size = total_vregs * 1024
+    block_size = vregs_in_block * 1024
+
+    @jax.jit
+    def reduce():
+      return self.pallas_call(
+          body,
+          out_shape=jax.ShapeDtypeStruct((data_size,), dty),
+          in_specs=[],
+          out_specs=pl.BlockSpec((block_size,), lambda i: i),
+          grid=data_size // block_size,
+      )()
+
+    x = jnp.full(shape, 2.0, dtype=dty)
+    z = jax.block_until_ready(reduce())
+    reduce_value = jnp.sum(x, dtype=dty)
+    np.testing.assert_allclose(z, reduce_value)
+
+  @parameterized.parameters([
+      dict(
+          m=m,
+          replicated=replicated,
+          reduced_dims=reduced_dims,
+          dty=dty,
+          reduce_func=reduce_func,
+      )
+      for m, replicated, reduced_dims, dty, reduce_func in itertools.product(
+          [128, 256],
+          [(True, True), (False, True), (True, False)],
+          [(0, 1), (0,), (1,)],
+          [jnp.float32, jnp.int32],
+          [jnp.sum, jnp.max, jnp.min],
+      )
+  ])
+  def test_replicated_broadcast_reduction(
+      self, m, replicated, reduced_dims, dty, reduce_func
+  ):
+    if not jtu.if_cloud_tpu_at_least(2025, 2, 19):
+      self.skipTest("Needs a newer libTPU")
+    if dty == jnp.int32 and 1 in reduced_dims:
+      # TODO(b/395579834): Remove this skip once we implement this.
+      self.skipTest('int32 reduction on last dimension not supported')
+    if not jtu.is_device_tpu_at_least(4) and all(replicated):
+      self.skipTest(
+          'Broadcast in both sublanes and lanes not supported on this hardware'
+      )
+
+    in_shape = (1 if replicated[0] else m, 1 if replicated[1] else m)
+    red_shape = [m, m]
+    for d in reduced_dims:
+      red_shape[d] = 1
+
+    def body(x_ref, o_ref):
+      x = x_ref[:]
+      dilated_x = jnp.broadcast_to(x, (m, m))
+      reduced = reduce_func(dilated_x, axis=reduced_dims).reshape(red_shape)
+      o_ref[:] = reduced
+
+    @jax.jit
+    def reduce(x):
+      return self.pallas_call(
+          body,
+          out_shape=jax.ShapeDtypeStruct(red_shape, dty),
+          in_specs=[pl.BlockSpec(in_shape)],
+          out_specs=pl.BlockSpec(red_shape),
+          grid=1,
+      )(x)
+
+    x = jnp.full(in_shape, 2.0, dtype=dty)
+    y = jax.block_until_ready(reduce(x))
+    dilated_x = jnp.broadcast_to(x, (m, m))
+    expected = reduce_func(dilated_x, axis=reduced_dims).reshape(red_shape)
+    np.testing.assert_allclose(y, expected)
+
   def test_cost_analysis(self):
     def kernel(x, y):
       y[:] = x[:]