diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py
index b0a2b4dbc..61439cad9 100644
--- a/jax/_src/pallas/triton/lowering.py
+++ b/jax/_src/pallas/triton/lowering.py
@@ -462,6 +462,74 @@ def _atomic_rmw(
       result_type, op, ptr, val, mask=mask, sem=semantic, scope=sync_scope
   )
 
+def _fp_bits_type(t: ir.Type) -> ir.Type:
+  if ir.RankedTensorType.isinstance(t):
+    t_type = ir.RankedTensorType(t)
+    return ir.RankedTensorType.get(
+        t_type.shape, _fp_bits_type(t_type.element_type), t_type.encoding
+    )
+  elif tt_dialect.PointerType.isinstance(t):
+    ptr_type = tt_dialect.PointerType(t)
+    return tt_dialect.PointerType.get(
+        _fp_bits_type(ptr_type.pointee_type), ptr_type.address_space
+    )
+  else:
+    assert isinstance(t, ir.FloatType)
+    return ir.IntegerType.get_signless(t.width)
+
+
+def _expand_atomic_fp_min_max(
+    atomic_type: primitives.AtomicOpType,
+    ptr: ir.Value,
+    val: ir.Value,
+    mask: ir.Value | None = None,
+    semantic: tt_dialect.MemSemantic = tt_dialect.MemSemantic.ACQUIRE_RELEASE,
+    sync_scope: tt_dialect.MemSyncScope = tt_dialect.MemSyncScope.GPU,
+) -> ir.Value:
+  """
+  Expands floating point min/max via sequence of integer min/max. Does not handle NaNs.
+
+  min:
+    return atomic_smin(i_ptr, i_val) if i_val >= 0 else atomic_umax(i_ptr, i_val)
+
+  max:
+    return atomic_smax(i_ptr, i_val) if i_val >= 0 else atomic_umin(i_ptr, i_val)
+
+  """
+
+  if ir.RankedTensorType.isinstance(ptr.type):
+    ptr_type = ir.RankedTensorType(ptr.type)
+    element_type = tt_dialect.PointerType(ptr_type.element_type)
+    result_type = ir.RankedTensorType.get(
+        ptr_type.shape, element_type.pointee_type, ptr_type.encoding
+    )
+  else:
+    result_type = tt_dialect.PointerType(ptr.type).pointee_type
+
+  ptr_cast = tt_dialect.bitcast(_fp_bits_type(ptr.type), ptr)
+  val_cast = tt_dialect.bitcast(_fp_bits_type(val.type), val)
+
+  zero = _full(val_cast.type, 0)
+  pos_cmp = _greater_equal(val_cast, zero, signed=True)
+  neg_cmp = _less_than(val_cast, zero, signed=True)
+
+  pos_mask = pos_cmp if mask is None else arith_dialect.andi(mask, pos_cmp)
+  neg_mask = neg_cmp if mask is None else arith_dialect.andi(mask, neg_cmp)
+
+  pos_op, neg_op = (
+      (tt_dialect.RMWOp.MAX, tt_dialect.RMWOp.UMIN)
+      if atomic_type == primitives.AtomicOpType.MAX
+      else (tt_dialect.RMWOp.MIN, tt_dialect.RMWOp.UMAX)
+  )
+
+  pos_val = _atomic_rmw(
+      pos_op, ptr_cast, val_cast, mask=pos_mask, semantic=semantic, sync_scope=sync_scope
+  )
+  neg_val = _atomic_rmw(
+      neg_op, ptr_cast, val_cast, mask=neg_mask, semantic=semantic, sync_scope=sync_scope
+  )
+  result = arith_dialect.select(pos_cmp, pos_val, neg_val)
+  return tt_dialect.bitcast(result_type, result)
 
 @register_lowering(primitives.atomic_rmw_p)
 def _atomic_lowering_rule(
@@ -489,9 +557,23 @@ def _atomic_lowering_rule(
     else:
       op = tt_dialect.RMWOp.FADD
   elif atomic_type == primitives.AtomicOpType.MIN:
-    op = tt_dialect.RMWOp.MIN
+    if isinstance(val.type, ir.IntegerType):
+      op = (
+          tt_dialect.RMWOp.MIN
+          if jnp.issubdtype(value_aval.dtype, jnp.signedinteger)
+          else tt_dialect.RMWOp.UMIN
+      )
+    else:
+      return _expand_atomic_fp_min_max(atomic_type, ptr, val, mask=mask)
   elif atomic_type == primitives.AtomicOpType.MAX:
-    op = tt_dialect.RMWOp.MAX
+    if isinstance(val.type, ir.IntegerType):
+      op = (
+          tt_dialect.RMWOp.MAX
+          if jnp.issubdtype(value_aval.dtype, jnp.signedinteger)
+          else tt_dialect.RMWOp.UMAX
+      )
+    else:
+      return _expand_atomic_fp_min_max(atomic_type, ptr, val, mask=mask)
   elif atomic_type == primitives.AtomicOpType.AND:
     op = tt_dialect.RMWOp.AND
   elif atomic_type == primitives.AtomicOpType.OR:
diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py
index 863c64892..087359ae9 100644
--- a/tests/pallas/ops_test.py
+++ b/tests/pallas/ops_test.py
@@ -1611,12 +1611,14 @@ class OpsTest(PallasBaseTest):
 
   @parameterized.named_parameters(
       ("add_i32", pl.atomic_add, np.array([1, 2, 3, 4], np.int32), np.sum),
-      ("max_i", pl.atomic_max, np.array([1, 2, 3, 4], np.int32), np.max),
+      ("max_i32", pl.atomic_max, np.array([1, 2, 3, 4], np.int32), np.max),
       ("min_i32", pl.atomic_min, np.array([1, 2, 3, 4], np.int32), np.min),
+      ("max_u32", pl.atomic_max, np.array([1, 2, 3, 4], np.uint32), np.max),
+      ("min_u32", pl.atomic_min, np.array([1, 2, 3, 4], np.uint32), np.min),
       ("add_f16", pl.atomic_add, np.array([1, 2, 3, 4], np.float16), np.sum),
       ("add_f32", pl.atomic_add, np.array([1, 2, 3, 4], np.float32), np.sum),
-      ("max_f32", pl.atomic_max, np.array([1, 2, 3, 4], np.float32), np.max),
-      ("min_f32", pl.atomic_min, np.array([1, 2, 3, 4], np.float32), np.min),
+      ("max_f32", pl.atomic_max, np.array([-2, -1, 0, 1], np.float32), np.max),
+      ("min_f32", pl.atomic_min, np.array([-2, -1, 0, 1], np.float32), np.min),
   )
   def test_scalar_atomic(self, op, value, numpy_op):
     if (numpy_op.__name__ == "max" or numpy_op.__name__ == "min"):