Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)
[pallas:mosaic_gpu] Added WG lowering rule for lax.bitcast_convert_type_p
PiperOrigin-RevId: 734081448
parent d6b97c2026
commit 2a34019388
@@ -1908,27 +1908,36 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
 
 
 @register_lowering_rule(lax.bitcast_convert_type_p, mgpu.ThreadSemantics.Lane)
+@register_lowering_rule(
+    lax.bitcast_convert_type_p, mgpu.ThreadSemantics.Warpgroup
+)
 def _bitcast_convert_type_lowering_rule(
-    ctx: LoweringRuleContext, operand, *, new_dtype
+    ctx: LoweringRuleContext, x, *, new_dtype
 ):
   # TODO(petebu) Handle case where src and dst types have different bitwidths
-  [operand_aval] = ctx.avals_in
-  operand = _ensure_fa(operand, operand_aval.dtype)
-  src_elem_type = mgpu_utils.dtype_to_ir_type(operand_aval.dtype)
+  [x_aval] = ctx.avals_in
+  src_elem_type = mgpu_utils.dtype_to_ir_type(x_aval.dtype)
   dst_elem_type = mgpu_utils.dtype_to_ir_type(new_dtype)
   assert isinstance(src_elem_type, (ir.IntegerType, ir.FloatType))
   assert isinstance(dst_elem_type, (ir.IntegerType, ir.FloatType))
   if src_elem_type.width != dst_elem_type.width:
     raise NotImplementedError(
-        f"Can't bitcast from {operand_aval.dtype} to {new_dtype} because they"
+        f"Cannot bitcast from {x_aval.dtype} to {new_dtype} because they"
         " have different widths"
     )
 
+  if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Warpgroup:
+    x = _ensure_ir_value(x, x_aval.dtype)
+    return arith_dialect.bitcast(
+        ir.VectorType.get(x_aval.shape, dst_elem_type), x
+    )
+
+  x = _ensure_fa(x, x_aval.dtype)
   if ir.IntegerType.isinstance(dst_elem_type):
     output_is_signed = mgpu_utils.is_signed(new_dtype)
   else:
     output_is_signed = None
   return mgpu.FragmentedArray.bitcast(
-      operand, dst_elem_type, output_is_signed=output_is_signed
+      x, dst_elem_type, output_is_signed=output_is_signed
   )
 
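Note (editorial illustration, not part of the diff): the width guard above only allows reinterpretation between element types of equal bit width; every dtype pairing exercised by the updated test further down satisfies it. A minimal host-side sanity check of those pairings:

import numpy as np
import jax.numpy as jnp

# Each pair has matching bit widths (16/16 or 32/32), so the
# NotImplementedError guard in the rule above does not trigger for them.
for src, dst in [(jnp.int16, jnp.float16), (jnp.uint16, jnp.float16),
                 (jnp.float32, jnp.int32), (jnp.int32, jnp.uint32)]:
  assert np.dtype(src).itemsize == np.dtype(dst).itemsize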
@@ -525,6 +525,25 @@ def _cmpf_op_lowering_rule(
   return [_fragmented_array_to_ir(impl(lhs, rhs), op.result.type)]
 
 
+@_register_lowering(arith.BitcastOp)
+def _bitcast_op_lowering_rule(
+    _: LoweringContext, op: arith.BitcastOp
+) -> Sequence[ir.Value]:
+  in_layouts = inference_utils.in_layouts(op)
+  [layout] = inference_utils.out_layouts(op)
+  if any(in_layout != layout for in_layout in in_layouts):
+    raise ValueError("Layout mismatch")
+  in_ = _fragmented_array_from_ir(op.in_, layout)
+  out_element_type = ir.VectorType(op.result.type).element_type
+  out = in_.bitcast(
+      out_element_type,
+      output_is_signed=False
+      if ir.IntegerType.isinstance(out_element_type)
+      else None,
+  )
+  return [_fragmented_array_to_ir(out, op.result.type)]
+
+
 @_register_lowering(mgpu.WGMMAOp)
 def _mgpu_wgmma_op_lowering_rule(
     _: LoweringContext, wgmma_op: mgpu.WGMMAOp
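Note (editorial illustration, not part of the diff): the arith.bitcast this rule lowers is a pure same-width bit reinterpretation on a vector, so the shape is preserved and only the element type changes. NumPy's view is the host-side analogue:

import numpy as np

# Host-side analogue of a same-width bitcast: reinterpret int16 bits as
# float16 without copying or reshaping.
x = np.arange(8, dtype=np.int16)
y = x.view(np.float16)
assert y.shape == x.shape and y.dtype == np.float16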
@@ -194,6 +194,7 @@ def _infer_pointwise_op_layouts(op: ir.OpView) -> OptionalLayouts:
 for op in [
     arith.AddIOp, arith.AddFOp,
     arith.AndIOp,
+    arith.BitcastOp,
     arith.CmpFOp,
     arith.CmpIOp,
     arith.ExtFOp, arith.ExtSIOp, arith.ExtUIOp,
@@ -1161,29 +1161,38 @@ class PallasCallTest(PallasTest):
     self.assertEqual(data.count('"name": "store"'), 2)
     np.testing.assert_array_equal(y, x + x)
 
-  @parameterized.parameters(
-      (jnp.float16, jnp.float16),  # Noop
-      (jnp.int16, jnp.bfloat16),
-      (jnp.int16, jnp.float16),
-      (jnp.uint16, jnp.float16),
-      (jnp.float32, jnp.int32),
-      (jnp.float32, jnp.uint32),
-      (jnp.uint32, jnp.int32),
-      (jnp.int32, jnp.uint32),
+  @parameterized.product(
+      dtypes=[
+          (jnp.float16, jnp.float16),  # Noop
+          (jnp.int16, jnp.bfloat16),
+          (jnp.int16, jnp.float16),
+          (jnp.uint16, jnp.float16),
+          (jnp.float32, jnp.int32),
+          (jnp.float32, jnp.uint32),
+          (jnp.uint32, jnp.int32),
+          (jnp.int32, jnp.uint32),
+      ],
+      thread_semantics=[*plgpu.ThreadSemantics],
   )
-  def test_bitcast_convert_type(self, in_dtype, out_dtype):
+  def test_bitcast_convert_type(self, dtypes, thread_semantics):
+    in_dtype, out_dtype = dtypes
     m, n = 16, 8
     out_shape = jax.ShapeDtypeStruct((m, n), out_dtype)
-    grid = ()
 
-    @functools.partial(pl.pallas_call, out_shape=out_shape, grid=grid)
+    @functools.partial(
+        pl.pallas_call,
+        out_shape=out_shape,
+        compiler_params=plgpu.GPUCompilerParams(
+            thread_semantics=thread_semantics
+        ),
+    )
     def convert(x_ref, y_ref):
       y_ref[...] = jax.lax.bitcast_convert_type(x_ref[...], out_shape)
 
     x = jnp.arange(m * n, dtype=in_dtype).reshape((m, n))
-    y = convert(x)
-    y_ref = jax.lax.bitcast_convert_type(x, out_dtype)
-    np.testing.assert_array_equal(y, y_ref)
+    np.testing.assert_array_equal(
+        convert(x), jax.lax.bitcast_convert_type(x, out_dtype)
+    )
 
 
 class PallasCallSm90ATest(PallasSm90ATest):
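For reference, a standalone sketch (editorial addition, not part of the commit) that pins one dtype pair from the test above and explicitly selects Warpgroup thread semantics, i.e. the lowering path added by this change. The import paths and parameter names follow the test; running it assumes a GPU supported by Pallas Mosaic GPU.

import functools
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

m, n = 16, 8
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float16)

@functools.partial(
    pl.pallas_call,
    out_shape=out_shape,
    compiler_params=plgpu.GPUCompilerParams(
        thread_semantics=plgpu.ThreadSemantics.Warpgroup
    ),
)
def convert(x_ref, y_ref):
  # Same-width reinterpret: int16 bits viewed as float16.
  y_ref[...] = jax.lax.bitcast_convert_type(x_ref[...], jnp.float16)

x = jnp.arange(m * n, dtype=jnp.int16).reshape((m, n))
np.testing.assert_array_equal(
    convert(x), jax.lax.bitcast_convert_type(x, jnp.float16)
)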