mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[pallas:triton] Fix atomic min/max lowering for unsigned integers and float types (#263)
This commit is contained in:
parent
01c8d7feb8
commit
a8c11ba79e
@ -462,6 +462,74 @@ def _atomic_rmw(
|
||||
result_type, op, ptr, val, mask=mask, sem=semantic, scope=sync_scope
|
||||
)
|
||||
|
||||
def _fp_bits_type(t: ir.Type) -> ir.Type:
  """Returns the type obtained from `t` by swapping floats for signless ints.

  Ranked tensor types and Triton pointer types are rebuilt around the
  converted element/pointee type, keeping their shape and encoding (tensors)
  or address space (pointers). Anything else must be a float type and is
  mapped to the signless integer type with the same bit width.
  """
  # Tensor: recurse into the element type, keep shape and encoding.
  if ir.RankedTensorType.isinstance(t):
    tensor = ir.RankedTensorType(t)
    element_bits = _fp_bits_type(tensor.element_type)
    return ir.RankedTensorType.get(tensor.shape, element_bits, tensor.encoding)

  # Pointer: recurse into the pointee type, keep the address space.
  if tt_dialect.PointerType.isinstance(t):
    pointer = tt_dialect.PointerType(t)
    pointee_bits = _fp_bits_type(pointer.pointee_type)
    return tt_dialect.PointerType.get(pointee_bits, pointer.address_space)

  # Leaf: only floating point types are expected to reach this point.
  assert isinstance(t, ir.FloatType)
  return ir.IntegerType.get_signless(t.width)
|
||||
|
||||
|
||||
def _expand_atomic_fp_min_max(
    atomic_type: primitives.AtomicOpType,
    ptr: ir.Value,
    val: ir.Value,
    mask: ir.Value | None = None,
    semantic: tt_dialect.MemSemantic = tt_dialect.MemSemantic.ACQUIRE_RELEASE,
    sync_scope: tt_dialect.MemSyncScope = tt_dialect.MemSyncScope.GPU,
) -> ir.Value:
  """Lowers a floating point atomic min/max to a pair of integer atomics.

  Both `ptr` and `val` are bitcast to signless integers of the same width and
  the integer operation is chosen from the sign of the value's bit pattern:

    min: atomic_smin(i_ptr, i_val) if i_val >= 0 else atomic_umax(i_ptr, i_val)
    max: atomic_smax(i_ptr, i_val) if i_val >= 0 else atomic_umin(i_ptr, i_val)

  NaNs are not handled.

  Args:
    atomic_type: `AtomicOpType.MAX` selects the max expansion; every other
      value is lowered with the min expansion.
    ptr: A Triton pointer, or a ranked tensor of Triton pointers.
    val: The floating point operand (scalar or tensor).
    mask: Optional predicate; when given it is AND-ed with each sign predicate
      before being passed to the corresponding atomic.
    semantic: Memory semantic forwarded to `_atomic_rmw`.
    sync_scope: Synchronization scope forwarded to `_atomic_rmw`.

  Returns:
    The results of the two integer atomics, selected by sign and bitcast back
    to the floating point type that `ptr` points to.
  """
  # Integer ops for the non-negative and the negative bit patterns.
  if atomic_type == primitives.AtomicOpType.MAX:
    nonneg_op, neg_op = tt_dialect.RMWOp.MAX, tt_dialect.RMWOp.UMIN
  else:
    nonneg_op, neg_op = tt_dialect.RMWOp.MIN, tt_dialect.RMWOp.UMAX

  # Floating point type of the final result, derived from the pointee type.
  if ir.RankedTensorType.isinstance(ptr.type):
    ptrs_type = ir.RankedTensorType(ptr.type)
    pointee_type = tt_dialect.PointerType(ptrs_type.element_type).pointee_type
    fp_result_type = ir.RankedTensorType.get(
        ptrs_type.shape, pointee_type, ptrs_type.encoding
    )
  else:
    fp_result_type = tt_dialect.PointerType(ptr.type).pointee_type

  # Reinterpret pointer and value as integers of the same width.
  int_ptr = tt_dialect.bitcast(_fp_bits_type(ptr.type), ptr)
  int_val = tt_dialect.bitcast(_fp_bits_type(val.type), val)

  # Signed comparison against zero splits on the sign bit of the float.
  zeros = _full(int_val.type, 0)
  is_nonneg = _greater_equal(int_val, zeros, signed=True)
  is_neg = _less_than(int_val, zeros, signed=True)

  if mask is None:
    nonneg_mask, neg_mask = is_nonneg, is_neg
  else:
    nonneg_mask = arith_dialect.andi(mask, is_nonneg)
    neg_mask = arith_dialect.andi(mask, is_neg)

  # Each atomic is masked with its own sign predicate.
  rmw_kwargs = dict(semantic=semantic, sync_scope=sync_scope)
  nonneg_result = _atomic_rmw(
      nonneg_op, int_ptr, int_val, mask=nonneg_mask, **rmw_kwargs
  )
  neg_result = _atomic_rmw(
      neg_op, int_ptr, int_val, mask=neg_mask, **rmw_kwargs
  )

  result_bits = arith_dialect.select(is_nonneg, nonneg_result, neg_result)
  return tt_dialect.bitcast(fp_result_type, result_bits)
|
||||
|
||||
@register_lowering(primitives.atomic_rmw_p)
|
||||
def _atomic_lowering_rule(
|
||||
@ -489,9 +557,23 @@ def _atomic_lowering_rule(
|
||||
else:
|
||||
op = tt_dialect.RMWOp.FADD
|
||||
elif atomic_type == primitives.AtomicOpType.MIN:
|
||||
op = tt_dialect.RMWOp.MIN
|
||||
if isinstance(val.type, ir.IntegerType):
|
||||
op = (
|
||||
tt_dialect.RMWOp.MIN
|
||||
if jnp.issubdtype(value_aval.dtype, jnp.signedinteger)
|
||||
else tt_dialect.RMWOp.UMIN
|
||||
)
|
||||
else:
|
||||
return _expand_atomic_fp_min_max(atomic_type, ptr, val, mask=mask)
|
||||
elif atomic_type == primitives.AtomicOpType.MAX:
|
||||
op = tt_dialect.RMWOp.MAX
|
||||
if isinstance(val.type, ir.IntegerType):
|
||||
op = (
|
||||
tt_dialect.RMWOp.MAX
|
||||
if jnp.issubdtype(value_aval.dtype, jnp.signedinteger)
|
||||
else tt_dialect.RMWOp.UMAX
|
||||
)
|
||||
else:
|
||||
return _expand_atomic_fp_min_max(atomic_type, ptr, val, mask=mask)
|
||||
elif atomic_type == primitives.AtomicOpType.AND:
|
||||
op = tt_dialect.RMWOp.AND
|
||||
elif atomic_type == primitives.AtomicOpType.OR:
|
||||
|
@ -1611,12 +1611,14 @@ class OpsTest(PallasBaseTest):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("add_i32", pl.atomic_add, np.array([1, 2, 3, 4], np.int32), np.sum),
|
||||
("max_i", pl.atomic_max, np.array([1, 2, 3, 4], np.int32), np.max),
|
||||
("max_i32", pl.atomic_max, np.array([1, 2, 3, 4], np.int32), np.max),
|
||||
("min_i32", pl.atomic_min, np.array([1, 2, 3, 4], np.int32), np.min),
|
||||
("max_u32", pl.atomic_max, np.array([1, 2, 3, 4], np.uint32), np.max),
|
||||
("min_u32", pl.atomic_min, np.array([1, 2, 3, 4], np.uint32), np.min),
|
||||
("add_f16", pl.atomic_add, np.array([1, 2, 3, 4], np.float16), np.sum),
|
||||
("add_f32", pl.atomic_add, np.array([1, 2, 3, 4], np.float32), np.sum),
|
||||
("max_f32", pl.atomic_max, np.array([1, 2, 3, 4], np.float32), np.max),
|
||||
("min_f32", pl.atomic_min, np.array([1, 2, 3, 4], np.float32), np.min),
|
||||
("max_f32", pl.atomic_max, np.array([-2, -1, 0, 1], np.float32), np.max),
|
||||
("min_f32", pl.atomic_min, np.array([-2, -1, 0, 1], np.float32), np.min),
|
||||
)
|
||||
def test_scalar_atomic(self, op, value, numpy_op):
|
||||
if (numpy_op.__name__ == "max" or numpy_op.__name__ == "min"):
|
||||
|
Loading…
x
Reference in New Issue
Block a user