[Mosaic] Allow part of x2 int casts.

This should at least allow int2 -> int4 casts for vregs with native tiling. Many tests are skipped due to XLA compatibility issues.

PiperOrigin-RevId: 736710186
Tzu-Wei Sung authored 2025-03-13 18:56:56 -07:00, committed by jax authors
parent 34d6bb2e16
commit e235fb9760
3 changed files with 118 additions and 17 deletions
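For context, a minimal sketch of the kind of cast this change enables. This is not part of the commit; the shapes follow the tests below, and materializing int2 inputs via jnp.asarray with dtypes.int2 is an assumption:

import jax
import jax.numpy as jnp
import numpy as np
from jax._src import dtypes
from jax.experimental import pallas as pl

def cast_kernel(x_ref, y_ref):
  # int2 -> int4: an extension between two sub-byte types, which Mosaic
  # can now keep in native (packed) tiling.
  y_ref[...] = x_ref[...].astype(jnp.int4)

# 16 rows so the second-minor dimension covers int2's packing factor (16).
x = jnp.asarray(np.arange(16 * 128).reshape(16, 128) % 4, dtype=dtypes.int2)
y = pl.pallas_call(
    cast_kernel,
    out_shape=jax.ShapeDtypeStruct((16, 128), jnp.int4),
)(x)
np.testing.assert_array_equal(y.astype(jnp.int32), x.astype(jnp.int32))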

View File

@@ -39,21 +39,26 @@ namespace mlir::tpu {
 // Arguments:
 // src_sublane: A number of lanes in the full operand.
 // hardware_generation: An integer indicating the target TPU generation.
-// tiling_sublane: The number of sublanes in the target shape.
+// target_sublane_count: The number of sublanes in the target shape.
 // tpu_tiling_flags: A struct of flags indicating which large tiling modes are
 //   enabled by XLA for memrefs.
 // bitwidth: The bitwidth of the element type of the operand.
 // is_kernel_argument: Whether the operand is a kernel argument.
-int getTilingFactor(const int src_sublane,
-                    const int hardware_generation,
-                    const int64_t tiling_sublane,
+int getTilingFactor(const int src_sublane, const int hardware_generation,
+                    const int64_t target_sublane_count,
                     const TpuTilingFlags &tpu_tiling_flags,
                     const int8_t bitwidth, const bool is_kernel_argument) {
   CHECK(llvm::isPowerOf2_32(bitwidth));
-  CHECK_LE(4, bitwidth);
+  CHECK_LE(2, bitwidth);
   CHECK_LE(bitwidth, 32);
   const int packing = 32 / bitwidth;
   const int min_tiling = (1 + (hardware_generation < 4)) * packing;
+  // When packing is larger than the sublane count, we want its tiling to be at
+  // least as large as the packing to make sure we can fully pack values. For
+  // example, for int2 on a target with 8 sublanes, we want the tiling to be
+  // at least 16.
+  const int64_t tiling_sublane =
+      std::max(target_sublane_count, static_cast<int64_t>(packing));
   const int max_normal_tiling = tiling_sublane;
   int large_tiling = [&] {
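To make the arithmetic above concrete, here is a Python transcription of just the displayed lines; the large-tiling branch that follows in the real code is omitted, and the names merely mirror the C++:

def sublane_tiling(hardware_generation: int, target_sublane_count: int,
                   bitwidth: int) -> int:
  # packing counts how many values fit in one 32-bit subelement.
  assert bitwidth in (2, 4, 8, 16, 32)
  packing = 32 // bitwidth
  # int2: packing == 16 exceeds 8 sublanes, so the tiling grows to 16.
  return max(target_sublane_count, packing)

assert sublane_tiling(5, 8, 2) == 16  # the int2 example from the comment
assert sublane_tiling(5, 8, 4) == 8   # int4 still fits the 8 native sublanes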

View File

@@ -1702,6 +1702,33 @@ class VectorLayoutInferer {
       src_layout = layout;
       dst_layout = VectorLayout(dst_bitwidth, layout.offsets(),
                                 src_layout->tiling(), layout.implicit_dim());
+    } else if (layout.packing() > target_shape_[0]) {
+      // When the input dtype has packing greater than the sublane count, we
+      // can't preserve its native tiling in the output (the tile would be too
+      // big to fit in a vreg). At the same time, we can't use the default
+      // tiling either, because the tile size in the input dtype is smaller than
+      // a sublane.
+      // For example, for int2 on a target with 8 sublanes, subelements are
+      // unpacked into 16 consecutive sublanes.
+      // TODO(b/401624977): Perhaps there is a better layout for this case, or,
+      // if that's impossible, such a layout should be used everywhere for
+      // int2, not just ExtOp.
+      std::array<int64_t, 2> src_native_tiling = nativeTiling(src_bitwidth);
+      std::array<int64_t, 2> dst_native_tiling = nativeTiling(dst_bitwidth);
+      LayoutOffsets src_offsets = {
+          layout.offsets()[0] ? *layout.offsets()[0] % src_native_tiling[0]
+                              : LayoutOffset(),
+          layout.offsets()[1] ? *layout.offsets()[1] % src_native_tiling[1]
+                              : LayoutOffset()};
+      LayoutOffsets dst_offsets = {
+          layout.offsets()[0] ? *layout.offsets()[0] % dst_native_tiling[0]
+                              : LayoutOffset(),
+          layout.offsets()[1] ? *layout.offsets()[1] % dst_native_tiling[1]
+                              : LayoutOffset()};
+      src_layout = VectorLayout(src_bitwidth, src_offsets, src_native_tiling,
+                                layout.implicit_dim());
+      dst_layout = VectorLayout(dst_bitwidth, dst_offsets, dst_native_tiling,
+                                layout.implicit_dim());
     } else {
       LayoutOffsets offsets = {
           layout.offsets()[0] ? *layout.offsets()[0] % default_tiling_[0]
@@ -1742,6 +1769,9 @@ class VectorLayoutInferer {
       select_native |= tpu_tiling_flags_.use_x8_large_second_minor;
     } else if (dst_ty.getElementTypeBitWidth() == 4) {
       select_native |= tpu_tiling_flags_.use_x4_large_second_minor;
+    } else if (dst_ty.getElementTypeBitWidth() == 2) {
+      // Force it to native tiling. See comments in `inferExt`.
+      select_native = true;
     } else {
       return op->emitOpError("Unsupported target bitwidth for truncation");
     }
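A hedged Python sketch of what the new inferExt branch above computes: when a layout's packing exceeds the sublane count (int2: 32 // 2 == 16 > 8), both sides fall back to their native tilings, with offsets reduced modulo the respective tiling. The (8, 128) target shape and the exact form of nativeTiling are assumptions consistent with the surrounding code:

SUBLANES, LANES = 8, 128  # assumed vreg target shape

def native_tiling(bitwidth):
  # Assumed nativeTiling: one fully packed vreg per tile.
  packing = 32 // bitwidth
  return (SUBLANES * packing, LANES)

def ext_layouts(offsets, src_bitwidth, dst_bitwidth):
  def reduce(tiling):
    return tuple(o % t if o is not None else None
                 for o, t in zip(offsets, tiling))
  src_tiling = native_tiling(src_bitwidth)  # (128, 128) for int2
  dst_tiling = native_tiling(dst_bitwidth)  # (64, 128) for int4
  return (reduce(src_tiling), src_tiling), (reduce(dst_tiling), dst_tiling)

src, dst = ext_layouts((70, 3), src_bitwidth=2, dst_bitwidth=4)
assert src == ((70, 3), (128, 128))
assert dst == ((6, 3), (64, 128))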

View File

@@ -108,9 +108,11 @@ _DTYPES_SUB_32BIT = (
     "int16",
     "int8",
     "int4",
+    "int2",
     "uint16",
     "uint8",
     "uint4",
+    "uint2",
     "bool",
     "float8_e4m3b11fnuz",
     "float8_e5m2",
@@ -591,10 +593,15 @@ class OpsTest(PallasBaseTest):
       self.skipTest("Not supported on this hardware")
     if not jtu.if_cloud_tpu_at_least(2025, 3, 8):
       self.skipTest("Test requires libtpu from 2025/3/8 or later")
+    if from_dtype in {"int2", "uint2"} or to_dtype in {"int2", "uint2"}:
+      if jtu.test_device_matches(["tpu"]) and not jtu.if_cloud_tpu_at_least(
+          2025, 4, 1
+      ):
+        self.skipTest("Test requires libtpu from 2025/4/1 or later")
     if from_dtype == to_dtype:
       self.skipTest("Unnecessary test")
     if jtu.is_device_tpu(version=4):
-      if to_dtype in {"int8", "uint8", "int4", "uint4"}:
+      if to_dtype in {"int8", "uint8", "int4", "uint4", "int2", "uint2"}:
         self.skipTest("Not supported on this TPU generation")
       if to_dtype in {"int16", "uint16"} and not jtu.if_cloud_tpu_at_least(2025, 1, 18):
         self.skipTest("Test requires libtpu from 2025/1/18 or later")
@@ -602,8 +609,13 @@ class OpsTest(PallasBaseTest):
       # Currently only casts between 32-bit types and to bf16 are supported.
       if to_dtype not in {"int32", "uint32", "float32", "bfloat16"}:
         self.skipTest("Not supported on this TPU generation")
-    if jtu.test_device_matches(["gpu"]) and to_dtype in {"int4", "uint4"}:
-      self.skipTest("int4/uint4 casts are buggy on GPU")  # b/391292861
+    if jtu.test_device_matches(["gpu"]) and to_dtype in {
+        "int4",
+        "uint4",
+        "int2",
+        "uint2",
+    }:
+      self.skipTest("sub-byte casts are buggy on GPU")  # b/391292861
     if to_dtype == "float16" and not sut_is_mosaic_gpu:
       self.skipTest("float16 is only supported with Mosaic GPU")
     if sut_is_mosaic_gpu and to_dtype == "bool":
@@ -611,7 +623,11 @@ class OpsTest(PallasBaseTest):
     # XLA does not specify the float->int conversion result for NaNs.
     elements = dict(allow_nan=not jnp.issubdtype(to_dtype, jnp.integer))
-    x = data.draw(hnp.arrays(from_dtype, (8, 128), elements=elements))
+    shape = (8, 128)
+    if to_dtype in {"int2", "uint2"}:
+      # Make sure #rows is at least the packing factor of int2.
+      shape = (16, 128)
+    x = data.draw(hnp.arrays(from_dtype, shape, elements=elements))
     x = jnp.asarray(x)

     def kernel(x_ref, y_ref):
       x = x_ref[...]
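The 16-row requirement follows directly from int2's packing factor; a quick illustrative check:

int2_packing = 32 // 2        # 16 int2 values per 32-bit subelement
rows = max(8, int2_packing)   # the default 8 rows would be under-packed
assert (rows, 128) == (16, 128)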
@@ -643,11 +659,26 @@ class OpsTest(PallasBaseTest):
     if from_dtype == to_dtype:
       self.skipTest("Unnecessary test")
+    if from_dtype in {"int2", "uint2"} or to_dtype in {"int2", "uint2"}:
+      if jtu.test_device_matches(["tpu"]) and not jtu.if_cloud_tpu_at_least(
+          2025, 4, 1
+      ):
+        self.skipTest("Test requires libtpu from 2025/4/1 or later")

     if jtu.is_device_tpu(version=4):
       allowed_v4_cats = {("int16", "int32"): (2025, 1, 18)}
       if (
-          from_dtype in {"int16", "int8", "uint16", "uint8", "int4", "uint4"}
-          or to_dtype in {"int8", "uint8", "int4", "uint4"}
+          from_dtype
+          in {
+              "int16",
+              "int8",
+              "uint16",
+              "uint8",
+              "int4",
+              "uint4",
+              "int2",
+              "uint2",
+          }
+          or to_dtype in {"int8", "uint8", "int4", "uint4", "int2", "uint2"}
       ) and (from_dtype, to_dtype) not in allowed_v4_cats:
         self.skipTest("Not supported on this TPU generation")
       if minimum_libtpu_date := allowed_v4_cats.get((from_dtype, to_dtype), None):
@@ -657,12 +688,21 @@ class OpsTest(PallasBaseTest):
         self.skipTest("Test requires libtpu from 2025/1/18 or later")
     if jtu.test_device_matches(["tpu"]) and jtu.get_tpu_version() < 4:
       self.skipTest("Not supported on this TPU generation")
-    if jtu.test_device_matches(["gpu"]) and to_dtype in {"int4", "uint4"}:
-      self.skipTest("int4/uint4 casts are buggy on GPU")  # b/391292861
+    if jtu.test_device_matches(["gpu"]) and (
+        to_dtype
+        in {
+            "int4",
+            "uint4",
+            "int2",
+            "uint2",
+        }
+        or from_dtype in {"int2", "uint2"}
+    ):
+      self.skipTest("sub-byte casts are buggy on GPU")  # b/391292861
     if from_dtype == "float16" or to_dtype == "float16" and not sut_is_mosaic_gpu:
       self.skipTest("float16 is only supported with Mosaic GPU")
     if sut_is_mosaic_gpu:
-      unsupported_types = {"bool", "int4", "uint4"}
+      unsupported_types = {"bool", "int4", "uint4", "int2", "uint2"}
       if to_dtype in unsupported_types or from_dtype in unsupported_types:
         self.skipTest("Sub-byte types are not yet supported with Mosaic GPU")
     if not randomize:
@@ -677,6 +717,21 @@ class OpsTest(PallasBaseTest):
       self.skipTest("Not supported on this hardware")
     if not jtu.if_cloud_tpu_at_least(2025, 3, 9):
       self.skipTest("Test requires libtpu from 2025/3/9 or later")
+    if from_dtype == "int2" and to_dtype == "bool":
+      self.skipTest(
+          "TODO(b/343490729): XLA compare(s2, s2) yields wrong results"
+      )
+    if not randomize:
+      if from_dtype in {"int2", "uint2"}:
+        # TODO(b/343490729): XLA doesn't work well with int2.
+        # iota 1D is unsupported, and XLA tends to select an unsupported
+        # layout too when passing constants created in numpy. Thankfully,
+        # there is randomize=True for the test coverage.
+        self.skipTest("XLA tends to select an unsupported layout for int2")
+      if to_dtype in {"int2", "uint2"}:
+        self.skipTest(
+            "TODO(b/401624977): Mask on int2 is not yet supported in Mosaic"
+        )

     from_int = np.issubdtype(np.dtype(from_dtype), np.integer)
     to_int = np.issubdtype(np.dtype(to_dtype), np.integer)
@@ -687,8 +742,16 @@ class OpsTest(PallasBaseTest):
       self.skipTest("trunc from non-32 bit only implemented recently")
     # TODO(sharadmv,apaszke): add support for the following casts
-    if (from_dtype == "bool" and
-        to_dtype in {"int16", "int8", "int4", "uint16", "uint8", "uint4"}):
+    if from_dtype == "bool" and to_dtype in {
+        "int16",
+        "int8",
+        "int4",
+        "uint16",
+        "uint8",
+        "uint4",
+        "int2",
+        "uint2",
+    }:
       self.skipTest("Not supported: cannot extend to sub-32 bit types")

     def bitwidth(dtype):
@@ -753,7 +816,10 @@ class OpsTest(PallasBaseTest):
     if to_dtype == jnp.bool:
       y = y.astype(jnp.bool)
     y_ref = x.astype(to_dtype)
-    if jnp.dtype(to_dtype) in map(jnp.dtype, (jnp.bfloat16, jnp.int4, jnp.uint4)):
+    if jnp.dtype(to_dtype) in map(
+        jnp.dtype,
+        (jnp.bfloat16, jnp.int4, jnp.uint4, dtypes.int2, dtypes.uint2),
+    ):
       y, y_ref = y.astype(np.float32), y_ref.astype(np.float32)
     np.testing.assert_allclose(y, y_ref, atol=0., rtol=0.)
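The float32 upcast in the final assertion exists because numpy comparisons are unreliable for bfloat16 and sub-byte dtypes; a standalone sketch of the same pattern (the helper name is illustrative, not from the commit):

def assert_exact_cast(y, y_ref):
  # Upcast dtypes numpy can't compare reliably; still require an exact match.
  if y.dtype in (jnp.bfloat16, jnp.int4, jnp.uint4, dtypes.int2, dtypes.uint2):
    y, y_ref = y.astype(np.float32), y_ref.astype(np.float32)
  np.testing.assert_allclose(y, y_ref, atol=0.0, rtol=0.0)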