mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)
[Mosaic] Allow part of x2 int casts.
This should at least allow int2 -> int4 casts for native-tiling vregs. Many tests are skipped due to XLA compatibility issues.

PiperOrigin-RevId: 736710186
This commit is contained in:
parent: 34d6bb2e16
commit: e235fb9760
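As a concrete illustration of what this change enables, here is a minimal, hedged sketch (not code from this commit) of an int2 -> int4 element-wise cast in a Pallas TPU kernel. The kernel name, the use of `ml_dtypes.int2` for the input, and the (16, 128) shape are assumptions; the shape follows the tests below, which keep the row count at or above int2's packing factor.

```python
# Hedged sketch: an int2 -> int4 cast in a Pallas TPU kernel. Assumes a
# jax/libtpu build with int2 support (ml_dtypes >= 0.4); (16, 128) keeps
# the sublane dimension at int2's packing factor of 16.
import jax
import jax.numpy as jnp
import numpy as np
import ml_dtypes
from jax.experimental import pallas as pl

def cast_kernel(x_ref, y_ref):
  y_ref[...] = x_ref[...].astype(jnp.int4)

x = jnp.asarray(np.arange(16 * 128).reshape(16, 128) % 2, dtype=ml_dtypes.int2)
y = pl.pallas_call(
    cast_kernel,
    out_shape=jax.ShapeDtypeStruct((16, 128), jnp.int4),
)(x)
```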
```diff
@@ -39,21 +39,26 @@ namespace mlir::tpu {
 // Arguments:
 //  src_sublane: A number of lanes in the full operand.
 //  hardware_generation: An integer indicating the target TPU generation.
-//  tiling_sublane: The number of sublane in the target shape.
+//  target_sublane_count: The number of sublanes in the target shape.
 //  tpu_tiling_flags: A struct of flags indicating which large tiling modes are
 //    enabled by XLA for memrefs.
 //  bitwidth: The bitwidth of the element type of the operand.
 //  is_kernel_argument: Whether the operand is a kernel argument.
-int getTilingFactor(const int src_sublane,
-                    const int hardware_generation,
-                    const int64_t tiling_sublane,
+int getTilingFactor(const int src_sublane, const int hardware_generation,
+                    const int64_t target_sublane_count,
                     const TpuTilingFlags &tpu_tiling_flags,
                     const int8_t bitwidth, const bool is_kernel_argument) {
   CHECK(llvm::isPowerOf2_32(bitwidth));
-  CHECK_LE(4, bitwidth);
+  CHECK_LE(2, bitwidth);
   CHECK_LE(bitwidth, 32);
   const int packing = 32 / bitwidth;
   const int min_tiling = (1 + (hardware_generation < 4)) * packing;
+  // When packing is larger than the sublane count, we want its tiling to be at
+  // least as large as the packing to make sure we can fully pack values. For
+  // example, for int2 on the target with 8 sublanes, we want the tiling to be
+  // at least 16.
+  const int64_t tiling_sublane =
+      std::max(target_sublane_count, static_cast<int64_t>(packing));
   const int max_normal_tiling = tiling_sublane;
 
   int large_tiling = [&] {
```
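For intuition, a hedged Python model of the new tiling rule above (the function name and the default sublane count of 8 are illustrative, not from the commit):

```python
# Hedged model of the sublane-tiling rule: packing is the number of values
# per 32-bit vreg word, and the tiling must be at least that large so that
# values can be fully packed.
def tiling_sublane(bitwidth: int, target_sublane_count: int = 8) -> int:
  assert bitwidth in (2, 4, 8, 16, 32)
  packing = 32 // bitwidth
  return max(target_sublane_count, packing)

assert tiling_sublane(32) == 8   # packing 1: plain sublane count
assert tiling_sublane(4) == 8    # packing 8 still fits in 8 sublanes
assert tiling_sublane(2) == 16   # int2: packing 16 forces 16-sublane tiling
```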
```diff
@@ -1702,6 +1702,33 @@ class VectorLayoutInferer {
       src_layout = layout;
       dst_layout = VectorLayout(dst_bitwidth, layout.offsets(),
                                 src_layout->tiling(), layout.implicit_dim());
+    } else if (layout.packing() > target_shape_[0]) {
+      // When the input dtype has packing greater than the sublane count, we
+      // can't preserve its native tiling in the output (the tile would be too
+      // big to fit in a vreg). At the same time, we can't use the default
+      // tiling either, because the tile size in the input dtype is smaller than
+      // a sublane.
+      // For example, for int2 on the target with 8 sublanes, subelements are
+      // unpacked into 16 consecutive sublanes.
+      // TODO(b/401624977): Perhaps there is a better layout for this case, or
+      // if it's impossible, such layout should be used everywhere for int2, not
+      // just ExtOp.
+      std::array<int64_t, 2> src_native_tiling = nativeTiling(src_bitwidth);
+      std::array<int64_t, 2> dst_native_tiling = nativeTiling(dst_bitwidth);
+      LayoutOffsets src_offsets = {
+          layout.offsets()[0] ? *layout.offsets()[0] % src_native_tiling[0]
+                              : LayoutOffset(),
+          layout.offsets()[1] ? *layout.offsets()[1] % src_native_tiling[1]
+                              : LayoutOffset()};
+      LayoutOffsets dst_offsets = {
+          layout.offsets()[0] ? *layout.offsets()[0] % dst_native_tiling[0]
+                              : LayoutOffset(),
+          layout.offsets()[1] ? *layout.offsets()[1] % dst_native_tiling[1]
+                              : LayoutOffset()};
+      src_layout = VectorLayout(src_bitwidth, src_offsets, src_native_tiling,
+                                layout.implicit_dim());
+      dst_layout = VectorLayout(dst_bitwidth, dst_offsets, dst_native_tiling,
+                                layout.implicit_dim());
+    } else {
       LayoutOffsets offsets = {
           layout.offsets()[0] ? *layout.offsets()[0] % default_tiling_[0]
```
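The offset handling in the new branch reduces each present offset modulo the corresponding native tile size, leaving replicated (absent) offsets untouched. A hedged Python model, with `LayoutOffset()` rendered as `None` and illustrative tile sizes:

```python
# Hedged model of the offset normalization in the inferExt branch above.
from typing import Optional, Sequence, Tuple

def adjust_offsets(
    offsets: Sequence[Optional[int]], native_tiling: Tuple[int, int]
) -> Tuple[Optional[int], ...]:
  # None plays the role of LayoutOffset() (a replicated offset) in the C++.
  return tuple(
      None if o is None else o % t for o, t in zip(offsets, native_tiling)
  )

# Illustrative tile sizes: a 16x128 source tile (int2 unpacked over 16
# consecutive sublanes) and an 8x128 destination tile.
assert adjust_offsets((17, 130), (16, 128)) == (1, 2)
assert adjust_offsets((17, None), (8, 128)) == (1, None)
```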
```diff
@@ -1742,6 +1769,9 @@ class VectorLayoutInferer {
       select_native |= tpu_tiling_flags_.use_x8_large_second_minor;
     } else if (dst_ty.getElementTypeBitWidth() == 4) {
       select_native |= tpu_tiling_flags_.use_x4_large_second_minor;
+    } else if (dst_ty.getElementTypeBitWidth() == 2) {
+      // Force it to native tiling. See comments in `inferExt`.
+      select_native = true;
     } else {
       return op->emitOpError("Unsupported target bitwidth for truncation");
     }
```
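So on the truncation path, large second-minor tiling stays opt-in via XLA flags for 8-bit and 4-bit outputs, but becomes unconditional for 2-bit outputs. A hedged sketch of just the selection logic visible in this hunk (flag names mirror `tpu_tiling_flags_` in the diff; the earlier branches of the C++ chain are not shown here):

```python
# Hedged sketch of native-tiling selection for truncation by output bitwidth.
def select_native_tiling(
    dst_bitwidth: int,
    use_x8_large_second_minor: bool,
    use_x4_large_second_minor: bool,
) -> bool:
  if dst_bitwidth == 8:
    return use_x8_large_second_minor
  if dst_bitwidth == 4:
    return use_x4_large_second_minor
  if dst_bitwidth == 2:
    return True  # forced to native tiling; see the inferExt comments
  raise ValueError("Unsupported target bitwidth for truncation")
```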
```diff
@@ -108,9 +108,11 @@ _DTYPES_SUB_32BIT = (
     "int16",
     "int8",
     "int4",
+    "int2",
     "uint16",
     "uint8",
     "uint4",
+    "uint2",
     "bool",
     "float8_e4m3b11fnuz",
     "float8_e5m2",
```
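For reference, the two new entries are the ml_dtypes 2-bit integer types; a quick illustrative check of their value ranges (assuming ml_dtypes >= 0.4, which introduced them):

```python
# int2 is a signed 2-bit type (range -2..1); uint2 is unsigned (0..3).
import ml_dtypes

assert (ml_dtypes.iinfo(ml_dtypes.int2).min,
        ml_dtypes.iinfo(ml_dtypes.int2).max) == (-2, 1)
assert (ml_dtypes.iinfo(ml_dtypes.uint2).min,
        ml_dtypes.iinfo(ml_dtypes.uint2).max) == (0, 3)
```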
```diff
@@ -591,10 +593,15 @@ class OpsTest(PallasBaseTest):
       self.skipTest("Not supported on this hardware")
     if not jtu.if_cloud_tpu_at_least(2025, 3, 8):
       self.skipTest("Test requires libtpu from 2025/3/8 or later")
+    if from_dtype in {"int2", "uint2"} or to_dtype in {"int2", "uint2"}:
+      if jtu.test_device_matches(["tpu"]) and not jtu.if_cloud_tpu_at_least(
+          2025, 4, 1
+      ):
+        self.skipTest("Test requires libtpu from 2025/4/1 or later")
     if from_dtype == to_dtype:
       self.skipTest("Unnecessary test")
     if jtu.is_device_tpu(version=4):
-      if to_dtype in {"int8", "uint8", "int4", "uint4"}:
+      if to_dtype in {"int8", "uint8", "int4", "uint4", "int2", "uint2"}:
         self.skipTest("Not supported on this TPU generation")
     if to_dtype in {"int16", "uint16"} and not jtu.if_cloud_tpu_at_least(2025, 1, 18):
       self.skipTest("Test requires libtpu from 2025/1/18 or later")
```
```diff
@@ -602,8 +609,13 @@ class OpsTest(PallasBaseTest):
       # Currently only casts between 32-bit types and to bf16 are supported.
       if to_dtype not in {"int32", "uint32", "float32", "bfloat16"}:
         self.skipTest("Not supported on this TPU generation")
-    if jtu.test_device_matches(["gpu"]) and to_dtype in {"int4", "uint4"}:
-      self.skipTest("int4/uint4 casts are buggy on GPU")  # b/391292861
+    if jtu.test_device_matches(["gpu"]) and to_dtype in {
+        "int4",
+        "uint4",
+        "int2",
+        "uint2",
+    }:
+      self.skipTest("sub-byte casts are buggy on GPU")  # b/391292861
     if to_dtype == "float16" and not sut_is_mosaic_gpu:
       self.skipTest("float16 is only supported with Mosaic GPU")
     if sut_is_mosaic_gpu and to_dtype == "bool":
```
```diff
@@ -611,7 +623,11 @@ class OpsTest(PallasBaseTest):
 
     # XLA does not specify the float->int conversion result for NaNs.
     elements = dict(allow_nan=not jnp.issubdtype(to_dtype, jnp.integer))
-    x = data.draw(hnp.arrays(from_dtype, (8, 128), elements=elements))
+    shape = (8, 128)
+    if to_dtype in {"int2", "uint2"}:
+      # Make sure #rows is at least the packing factor of int2.
+      shape = (16, 128)
+    x = data.draw(hnp.arrays(from_dtype, shape, elements=elements))
     x = jnp.asarray(x)
     def kernel(x_ref, y_ref):
       x = x_ref[...]
```
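The shape bump exists because an (8, 128) input has fewer rows than int2's packing factor. A hedged one-liner of the relationship (illustrative helper, not from the test):

```python
# Packing factor: values per 32-bit vreg word. int2 packs 16 values, so
# tests producing int2 draw (16, 128) inputs instead of (8, 128).
def packing_factor(bitwidth: int) -> int:
  return 32 // bitwidth

assert packing_factor(2) == 16  # needs >= 16 rows
assert packing_factor(4) == 8   # (8, 128) already suffices
```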
```diff
@@ -643,11 +659,26 @@ class OpsTest(PallasBaseTest):
 
     if from_dtype == to_dtype:
       self.skipTest("Unnecessary test")
+    if from_dtype in {"int2", "uint2"} or to_dtype in {"int2", "uint2"}:
+      if jtu.test_device_matches(["tpu"]) and not jtu.if_cloud_tpu_at_least(
+          2025, 4, 1
+      ):
+        self.skipTest("Test requires libtpu from 2025/4/1 or later")
     if jtu.is_device_tpu(version=4):
       allowed_v4_cats = {("int16", "int32"): (2025, 1, 18)}
       if (
-          from_dtype in {"int16", "int8", "uint16", "uint8", "int4", "uint4"}
-          or to_dtype in {"int8", "uint8", "int4", "uint4"}
+          from_dtype
+          in {
+              "int16",
+              "int8",
+              "uint16",
+              "uint8",
+              "int4",
+              "uint4",
+              "int2",
+              "uint2",
+          }
+          or to_dtype in {"int8", "uint8", "int4", "uint4", "int2", "uint2"}
       ) and (from_dtype, to_dtype) not in allowed_v4_cats:
         self.skipTest("Not supported on this TPU generation")
     if minimum_libtpu_date := allowed_v4_cats.get((from_dtype, to_dtype), None):
```
```diff
@@ -657,12 +688,21 @@ class OpsTest(PallasBaseTest):
         self.skipTest("Test requires libtpu from 2025/1/18 or later")
     if jtu.test_device_matches(["tpu"]) and jtu.get_tpu_version() < 4:
       self.skipTest("Not supported on this TPU generation")
-    if jtu.test_device_matches(["gpu"]) and to_dtype in {"int4", "uint4"}:
-      self.skipTest("int4/uint4 casts are buggy on GPU")  # b/391292861
+    if jtu.test_device_matches(["gpu"]) and (
+        to_dtype
+        in {
+            "int4",
+            "uint4",
+            "int2",
+            "uint2",
+        }
+        or from_dtype in {"int2", "uint2"}
+    ):
+      self.skipTest("sub-byte casts are buggy on GPU")  # b/391292861
     if from_dtype == "float16" or to_dtype == "float16" and not sut_is_mosaic_gpu:
       self.skipTest("float16 is only supported with Mosaic GPU")
     if sut_is_mosaic_gpu:
-      unsupported_types = {"bool", "int4", "uint4"}
+      unsupported_types = {"bool", "int4", "uint4", "int2", "uint2"}
       if to_dtype in unsupported_types or from_dtype in unsupported_types:
         self.skipTest("Sub-byte types are not yet supported with Mosaic GPU")
     if not randomize:
```
```diff
@@ -677,6 +717,21 @@ class OpsTest(PallasBaseTest):
       self.skipTest("Not supported on this hardware")
     if not jtu.if_cloud_tpu_at_least(2025, 3, 9):
       self.skipTest("Test requires libtpu from 2025/3/9 or later")
+    if from_dtype == "int2" and to_dtype == "bool":
+      self.skipTest(
+          "TODO(b/343490729): XLA compare(s2, s2) yields wrong results"
+      )
+    if not randomize:
+      if from_dtype in {"int2", "uint2"}:
+        # TODO(b/343490729): XLA doesn't work well with int2.
+        # iota 1D is unsupported, and XLA tends to select an unsupported
+        # layout too when passing constants created in numpy. Thankfully,
+        # there is randomize=True for the test coverage.
+        self.skipTest("XLA tends to select an unsupported layout for int2")
+      if to_dtype in {"int2", "uint2"}:
+        self.skipTest(
+            "TODO(b/401624977): Mask on int2 is not yet supported in Mosaic"
+        )
 
     from_int = np.issubdtype(np.dtype(from_dtype), np.integer)
     to_int = np.issubdtype(np.dtype(to_dtype), np.integer)
```
```diff
@@ -687,8 +742,16 @@ class OpsTest(PallasBaseTest):
       self.skipTest("trunc from non-32 bit only implemented recently")
 
     # TODO(sharadmv,apaszke): add support for the following casts
-    if (from_dtype == "bool" and
-        to_dtype in {"int16", "int8", "int4", "uint16", "uint8", "uint4"}):
+    if from_dtype == "bool" and to_dtype in {
+        "int16",
+        "int8",
+        "int4",
+        "uint16",
+        "uint8",
+        "uint4",
+        "int2",
+        "uint2",
+    }:
       self.skipTest("Not supported: cannot extend to sub-32 bit types")
 
     def bitwidth(dtype):
```
```diff
@@ -753,7 +816,10 @@ class OpsTest(PallasBaseTest):
     if to_dtype == jnp.bool:
       y = y.astype(jnp.bool)
     y_ref = x.astype(to_dtype)
-    if jnp.dtype(to_dtype) in map(jnp.dtype, (jnp.bfloat16, jnp.int4, jnp.uint4)):
+    if jnp.dtype(to_dtype) in map(
+        jnp.dtype,
+        (jnp.bfloat16, jnp.int4, jnp.uint4, dtypes.int2, dtypes.uint2),
+    ):
       y, y_ref = y.astype(np.float32), y_ref.astype(np.float32)
     np.testing.assert_allclose(y, y_ref, atol=0., rtol=0.)
 
```
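The float32 upcast before the exact comparison is there because numpy's testing helpers are most reliable on native dtypes; every bfloat16, int4, and int2 value is exactly representable in float32, so the comparison stays exact. A hedged standalone illustration of the same pattern (assuming ml_dtypes >= 0.4):

```python
# Lossless upcast to float32, then an exact comparison, mirroring the
# test's handling of sub-byte result dtypes.
import ml_dtypes
import numpy as np

y = np.array([1, -2, 0, 1], dtype=ml_dtypes.int2)
y_ref = np.array([1, -2, 0, 1], dtype=ml_dtypes.int2)
np.testing.assert_allclose(
    y.astype(np.float32), y_ref.astype(np.float32), atol=0.0, rtol=0.0
)
```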