[pallas:triton] Fix atomic min/max lowering for unsigned integers and float types (#263)

This commit is contained in:
Dragan Mladjenovic 2025-03-10 16:39:22 +01:00 committed by GitHub
parent 01c8d7feb8
commit a8c11ba79e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 89 additions and 5 deletions

View File

@ -462,6 +462,74 @@ def _atomic_rmw(
result_type, op, ptr, val, mask=mask, sem=semantic, scope=sync_scope
)
def _fp_bits_type(t: ir.Type) -> ir.Type:
  """Returns the signless integer type matching the bit width of a float type.

  Ranked tensor types and Triton pointer types are rebuilt around the
  converted element/pointee type, keeping the original shape and encoding
  (tensors) or address space (pointers). Any other type must be a float type
  and is mapped to a signless integer of the same width.
  """
  if ir.RankedTensorType.isinstance(t):
    tensor_type = ir.RankedTensorType(t)
    element_bits = _fp_bits_type(tensor_type.element_type)
    return ir.RankedTensorType.get(
        tensor_type.shape, element_bits, tensor_type.encoding
    )

  if tt_dialect.PointerType.isinstance(t):
    pointer_type = tt_dialect.PointerType(t)
    pointee_bits = _fp_bits_type(pointer_type.pointee_type)
    return tt_dialect.PointerType.get(pointee_bits, pointer_type.address_space)

  # Base case of the recursion: a plain floating point type.
  assert isinstance(t, ir.FloatType)
  return ir.IntegerType.get_signless(t.width)
def _expand_atomic_fp_min_max(
    atomic_type: primitives.AtomicOpType,
    ptr: ir.Value,
    val: ir.Value,
    mask: ir.Value | None = None,
    semantic: tt_dialect.MemSemantic = tt_dialect.MemSemantic.ACQUIRE_RELEASE,
    sync_scope: tt_dialect.MemSyncScope = tt_dialect.MemSyncScope.GPU,
) -> ir.Value:
  """Lowers a floating point atomic min/max to integer atomics on the raw bits.

  Both `ptr` and `val` are bitcast to the integer type of the same width and
  the operation is split on the sign of the integer value. NaNs are not
  handled.

    min: atomic_smin(i_ptr, i_val) if i_val >= 0 else atomic_umax(i_ptr, i_val)
    max: atomic_smax(i_ptr, i_val) if i_val >= 0 else atomic_umin(i_ptr, i_val)

  Any `atomic_type` other than `AtomicOpType.MAX` is treated as min.

  Returns:
    The combined result of the two atomics, bitcast back to the pointee type
    of `ptr` (a ranked tensor of that type when `ptr` is a tensor of
    pointers).
  """
  # Work out the floating point type of the value handed back to the caller.
  if not ir.RankedTensorType.isinstance(ptr.type):
    result_type = tt_dialect.PointerType(ptr.type).pointee_type
  else:
    ptrs_type = ir.RankedTensorType(ptr.type)
    pointee_type = tt_dialect.PointerType(ptrs_type.element_type).pointee_type
    result_type = ir.RankedTensorType.get(
        ptrs_type.shape, pointee_type, ptrs_type.encoding
    )

  # Reinterpret the pointer and the operand as same-width integers.
  int_ptr = tt_dialect.bitcast(_fp_bits_type(ptr.type), ptr)
  int_val = tt_dialect.bitcast(_fp_bits_type(val.type), val)

  # Signed comparison of the bit pattern against zero; the two predicates are
  # complementary, so each element is enabled in exactly one of the atomics.
  zero = _full(int_val.type, 0)
  is_non_negative = _greater_equal(int_val, zero, signed=True)
  is_negative = _less_than(int_val, zero, signed=True)
  if mask is None:
    non_negative_mask = is_non_negative
    negative_mask = is_negative
  else:
    non_negative_mask = arith_dialect.andi(mask, is_non_negative)
    negative_mask = arith_dialect.andi(mask, is_negative)

  # Non-negative bit patterns use the signed op of the requested kind, while
  # negative ones use the opposite unsigned op (see the docstring recipe).
  if atomic_type == primitives.AtomicOpType.MAX:
    non_negative_op = tt_dialect.RMWOp.MAX
    negative_op = tt_dialect.RMWOp.UMIN
  else:
    non_negative_op = tt_dialect.RMWOp.MIN
    negative_op = tt_dialect.RMWOp.UMAX

  non_negative_result = _atomic_rmw(
      non_negative_op,
      int_ptr,
      int_val,
      mask=non_negative_mask,
      semantic=semantic,
      sync_scope=sync_scope,
  )
  negative_result = _atomic_rmw(
      negative_op,
      int_ptr,
      int_val,
      mask=negative_mask,
      semantic=semantic,
      sync_scope=sync_scope,
  )

  # Pick, per element, the result of whichever atomic was enabled for it and
  # reinterpret the integer bits as the original floating point type.
  merged = arith_dialect.select(
      is_non_negative, non_negative_result, negative_result
  )
  return tt_dialect.bitcast(result_type, merged)
@register_lowering(primitives.atomic_rmw_p)
def _atomic_lowering_rule(
@ -489,9 +557,23 @@ def _atomic_lowering_rule(
else:
op = tt_dialect.RMWOp.FADD
elif atomic_type == primitives.AtomicOpType.MIN:
op = tt_dialect.RMWOp.MIN
if isinstance(val.type, ir.IntegerType):
op = (
tt_dialect.RMWOp.MIN
if jnp.issubdtype(value_aval.dtype, jnp.signedinteger)
else tt_dialect.RMWOp.UMIN
)
else:
return _expand_atomic_fp_min_max(atomic_type, ptr, val, mask=mask)
elif atomic_type == primitives.AtomicOpType.MAX:
op = tt_dialect.RMWOp.MAX
if isinstance(val.type, ir.IntegerType):
op = (
tt_dialect.RMWOp.MAX
if jnp.issubdtype(value_aval.dtype, jnp.signedinteger)
else tt_dialect.RMWOp.UMAX
)
else:
return _expand_atomic_fp_min_max(atomic_type, ptr, val, mask=mask)
elif atomic_type == primitives.AtomicOpType.AND:
op = tt_dialect.RMWOp.AND
elif atomic_type == primitives.AtomicOpType.OR:

View File

@ -1611,12 +1611,14 @@ class OpsTest(PallasBaseTest):
@parameterized.named_parameters(
("add_i32", pl.atomic_add, np.array([1, 2, 3, 4], np.int32), np.sum),
("max_i", pl.atomic_max, np.array([1, 2, 3, 4], np.int32), np.max),
("max_i32", pl.atomic_max, np.array([1, 2, 3, 4], np.int32), np.max),
("min_i32", pl.atomic_min, np.array([1, 2, 3, 4], np.int32), np.min),
("max_u32", pl.atomic_max, np.array([1, 2, 3, 4], np.uint32), np.max),
("min_u32", pl.atomic_min, np.array([1, 2, 3, 4], np.uint32), np.min),
("add_f16", pl.atomic_add, np.array([1, 2, 3, 4], np.float16), np.sum),
("add_f32", pl.atomic_add, np.array([1, 2, 3, 4], np.float32), np.sum),
("max_f32", pl.atomic_max, np.array([1, 2, 3, 4], np.float32), np.max),
("min_f32", pl.atomic_min, np.array([1, 2, 3, 4], np.float32), np.min),
("max_f32", pl.atomic_max, np.array([-2, -1, 0, 1], np.float32), np.max),
("min_f32", pl.atomic_min, np.array([-2, -1, 0, 1], np.float32), np.min),
)
def test_scalar_atomic(self, op, value, numpy_op):
if (numpy_op.__name__ == "max" or numpy_op.__name__ == "min"):