[pallas:mosaic_gpu] Added WG lowering rule for lax.bitcast_convert_type_p

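With this change, `lax.bitcast_convert_type` between equal-width dtypes lowers under both Lane and Warpgroup thread semantics. A minimal kernel sketch exercising the new Warpgroup path, mirroring the test added below (import paths and the kernel/variable names here are illustrative, not part of the change):

    import functools
    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import mosaic_gpu as plgpu

    m, n = 16, 8

    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct((m, n), jnp.int32),
        compiler_params=plgpu.GPUCompilerParams(
            thread_semantics=plgpu.ThreadSemantics.Warpgroup
        ),
    )
    def bitcast_kernel(x_ref, y_ref):
      # float32 -> int32 is a same-width bitcast, so it is supported;
      # mixed-width bitcasts still raise NotImplementedError.
      y_ref[...] = jax.lax.bitcast_convert_type(x_ref[...], jnp.int32)

    x = jnp.arange(m * n, dtype=jnp.float32).reshape((m, n))
    y = bitcast_kernel(x)  # bit pattern of x, reinterpreted as int32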
PiperOrigin-RevId: 734081448
Author: Sergei Lebedev, 2025-03-06 04:08:57 -08:00 (committed by jax authors)
parent d6b97c2026
commit 2a34019388
4 changed files with 60 additions and 22 deletions

View File

@@ -1908,27 +1908,36 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
 @register_lowering_rule(lax.bitcast_convert_type_p, mgpu.ThreadSemantics.Lane)
+@register_lowering_rule(
+    lax.bitcast_convert_type_p, mgpu.ThreadSemantics.Warpgroup
+)
 def _bitcast_convert_type_lowering_rule(
-    ctx: LoweringRuleContext, operand, *, new_dtype
+    ctx: LoweringRuleContext, x, *, new_dtype
 ):
   # TODO(petebu) Handle case where src and dst types have different bitwidths
-  [operand_aval] = ctx.avals_in
-  operand = _ensure_fa(operand, operand_aval.dtype)
-  src_elem_type = mgpu_utils.dtype_to_ir_type(operand_aval.dtype)
+  [x_aval] = ctx.avals_in
+  src_elem_type = mgpu_utils.dtype_to_ir_type(x_aval.dtype)
   dst_elem_type = mgpu_utils.dtype_to_ir_type(new_dtype)
   assert isinstance(src_elem_type, (ir.IntegerType, ir.FloatType))
   assert isinstance(dst_elem_type, (ir.IntegerType, ir.FloatType))
   if src_elem_type.width != dst_elem_type.width:
     raise NotImplementedError(
-        f"Can't bitcast from {operand_aval.dtype} to {new_dtype} because they"
+        f"Cannot bitcast from {x_aval.dtype} to {new_dtype} because they"
         " have different widths"
     )
+  if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Warpgroup:
+    x = _ensure_ir_value(x, x_aval.dtype)
+    return arith_dialect.bitcast(
+        ir.VectorType.get(x_aval.shape, dst_elem_type), x
+    )
+  x = _ensure_fa(x, x_aval.dtype)
   if ir.IntegerType.isinstance(dst_elem_type):
     output_is_signed = mgpu_utils.is_signed(new_dtype)
   else:
     output_is_signed = None
   return mgpu.FragmentedArray.bitcast(
-      operand, dst_elem_type, output_is_signed=output_is_signed
+      x, dst_elem_type, output_is_signed=output_is_signed
   )

View File

@@ -525,6 +525,25 @@ def _cmpf_op_lowering_rule(
   return [_fragmented_array_to_ir(impl(lhs, rhs), op.result.type)]
 
 
+@_register_lowering(arith.BitcastOp)
+def _bitcast_op_lowering_rule(
+    _: LoweringContext, op: arith.BitcastOp
+) -> Sequence[ir.Value]:
+  in_layouts = inference_utils.in_layouts(op)
+  [layout] = inference_utils.out_layouts(op)
+  if any(in_layout != layout for in_layout in in_layouts):
+    raise ValueError("Layout mismatch")
+  in_ = _fragmented_array_from_ir(op.in_, layout)
+  out_element_type = ir.VectorType(op.result.type).element_type
+  out = in_.bitcast(
+      out_element_type,
+      output_is_signed=False
+      if ir.IntegerType.isinstance(out_element_type)
+      else None,
+  )
+  return [_fragmented_array_to_ir(out, op.result.type)]
+
+
 @_register_lowering(mgpu.WGMMAOp)
 def _mgpu_wgmma_op_lowering_rule(
     _: LoweringContext, wgmma_op: mgpu.WGMMAOp

View File

@@ -194,6 +194,7 @@ def _infer_pointwise_op_layouts(op: ir.OpView) -> OptionalLayouts:
 for op in [
     arith.AddIOp, arith.AddFOp,
     arith.AndIOp,
+    arith.BitcastOp,
     arith.CmpFOp,
     arith.CmpIOp,
     arith.ExtFOp, arith.ExtSIOp, arith.ExtUIOp,

View File

@@ -1161,29 +1161,38 @@ class PallasCallTest(PallasTest):
     self.assertEqual(data.count('"name": "store"'), 2)
     np.testing.assert_array_equal(y, x + x)
 
-  @parameterized.parameters(
-      (jnp.float16, jnp.float16),  # Noop
-      (jnp.int16, jnp.bfloat16),
-      (jnp.int16, jnp.float16),
-      (jnp.uint16, jnp.float16),
-      (jnp.float32, jnp.int32),
-      (jnp.float32, jnp.uint32),
-      (jnp.uint32, jnp.int32),
-      (jnp.int32, jnp.uint32),
+  @parameterized.product(
+      dtypes=[
+          (jnp.float16, jnp.float16),  # Noop
+          (jnp.int16, jnp.bfloat16),
+          (jnp.int16, jnp.float16),
+          (jnp.uint16, jnp.float16),
+          (jnp.float32, jnp.int32),
+          (jnp.float32, jnp.uint32),
+          (jnp.uint32, jnp.int32),
+          (jnp.int32, jnp.uint32),
+      ],
+      thread_semantics=[*plgpu.ThreadSemantics],
   )
-  def test_bitcast_convert_type(self, in_dtype, out_dtype):
+  def test_bitcast_convert_type(self, dtypes, thread_semantics):
+    in_dtype, out_dtype = dtypes
     m, n = 16, 8
     out_shape = jax.ShapeDtypeStruct((m, n), out_dtype)
-    grid = ()
 
-    @functools.partial(pl.pallas_call, out_shape=out_shape, grid=grid)
+    @functools.partial(
+        pl.pallas_call,
+        out_shape=out_shape,
+        compiler_params=plgpu.GPUCompilerParams(
+            thread_semantics=thread_semantics
+        ),
+    )
     def convert(x_ref, y_ref):
       y_ref[...] = jax.lax.bitcast_convert_type(x_ref[...], out_shape)
 
     x = jnp.arange(m * n, dtype=in_dtype).reshape((m, n))
-    y = convert(x)
-    y_ref = jax.lax.bitcast_convert_type(x, out_dtype)
-    np.testing.assert_array_equal(y, y_ref)
+    np.testing.assert_array_equal(
+        convert(x), jax.lax.bitcast_convert_type(x, out_dtype)
+    )
 
 
 class PallasCallSm90ATest(PallasSm90ATest):