Mirror of https://github.com/ROCm/jax.git, synced 2025-04-17 12:26:07 +00:00
Migrated atomic operations to lower directly to Triton IR

PiperOrigin-RevId: 602384705

Parent: 2518a6f6d2
Commit: 07f8f700ca
@@ -598,45 +598,6 @@ class builder:
     result_type = ir.RankedTensorType.get(shape, arg.type)
     return tt_dialect.splat(result_type, arg)
 
-  def create_atomic_cas(
-      self,
-      ptr: ir.Value,
-      cmp: ir.Value,
-      val: ir.Value,
-      sem: tt_dialect.MemSemantic,
-      scope: tt_dialect.MemSyncScope,
-  ) -> ir.Value:
-    if ir.RankedTensorType.isinstance(ptr.type):
-      ptr_type = ir.RankedTensorType(ptr.type)
-      element_type = tt_dialect.PointerType(ptr_type.element_type)
-      result_type = ir.RankedTensorType.get(
-          ptr_type.shape, element_type.pointee_type, ptr_type.encoding
-      )
-    else:
-      result_type = tt_dialect.PointerType(ptr.type).pointee_type
-    return tt_dialect.atomic_cas(result_type, ptr, cmp, val, sem, scope)
-
-  def create_atomic_rmw(
-      self,
-      rmw_op: tt_dialect.RMWOp,
-      ptr: ir.Value,
-      val: ir.Value,
-      mask: ir.Value,
-      sem: tt_dialect.MemSemantic,
-      scope: tt_dialect.MemSyncScope,
-  ) -> ir.Value:
-    if ir.RankedTensorType.isinstance(ptr.type):
-      ptr_type = ir.RankedTensorType(ptr.type)
-      element_type = tt_dialect.PointerType(ptr_type.element_type)
-      result_type = ir.RankedTensorType.get(
-          ptr_type.shape, element_type.pointee_type, ptr_type.encoding
-      )
-    else:
-      result_type = tt_dialect.PointerType(ptr.type).pointee_type
-    return tt_dialect.atomic_rmw(
-        result_type, rmw_op, ptr, val, sem, scope, mask=mask
-    )
-
   def create_extern_elementwise(
       self,
       lib_name: str,
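Both removed builder methods computed the operation's result type with the same branch on the pointer type. A minimal standalone sketch of that shared logic (the helper name `_infer_atomic_result_type` is hypothetical, not part of this change; `ir` and `tt_dialect` are assumed to be this module's existing MLIR and Triton-dialect imports):

def _infer_atomic_result_type(ptr: ir.Value) -> ir.Type:
  if ir.RankedTensorType.isinstance(ptr.type):
    # A tensor of pointers: the result is a tensor of the pointee type
    # with the same shape and encoding as the pointer tensor.
    ptr_type = ir.RankedTensorType(ptr.type)
    element_type = tt_dialect.PointerType(ptr_type.element_type)
    return ir.RankedTensorType.get(
        ptr_type.shape, element_type.pointee_type, ptr_type.encoding
    )
  # A scalar pointer: the result is simply the pointee type.
  return tt_dialect.PointerType(ptr.type).pointee_type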
@@ -983,14 +944,81 @@ def reshape(x: tensor, dst_shape: Sequence[int]) -> tensor:
 
 dot = wrap_with_builder(tl.core.dot)
 
-atomic_xchg = wrap_with_builder(tl.core.atomic_xchg)
-atomic_add = wrap_with_builder(tl.core.atomic_add)
-atomic_max = wrap_with_builder(tl.core.atomic_max)
-atomic_min = wrap_with_builder(tl.core.atomic_min)
-atomic_and = wrap_with_builder(tl.core.atomic_and)
-atomic_or = wrap_with_builder(tl.core.atomic_or)
-atomic_xor = wrap_with_builder(tl.core.atomic_xor)
-atomic_cas = wrap_with_builder(tl.atomic_cas)
+
+def atomic_cas(
+    ptr: tensor,
+    cmp: tensor,
+    val: tensor,
+    semantic: tt_dialect.MemSemantic = tt_dialect.MemSemantic.ACQUIRE_RELEASE,
+    sync_scope: tt_dialect.MemSyncScope = tt_dialect.MemSyncScope.GPU,
+):
+  if ir.RankedTensorType.isinstance(ptr.handle.type):
+    ptr_type = ir.RankedTensorType(ptr.handle.type)
+    element_type = tt_dialect.PointerType(ptr_type.element_type)
+    result_type = ir.RankedTensorType.get(
+        ptr_type.shape, element_type.pointee_type, ptr_type.encoding
+    )
+  else:
+    result_type = tt_dialect.PointerType(ptr.handle.type).pointee_type
+  result_handle = tt_dialect.atomic_cas(
+      result_type,
+      ptr.handle,
+      cmp.handle,
+      val.handle,
+      sem=semantic,
+      scope=sync_scope,
+  )
+  return tensor(result_handle, val.type)
+
+
+def _atomic_rmw(
+    op: tt_dialect.RMWOp,
+    ptr: tensor,
+    val: tensor,
+    mask: tensor | None = None,
+    semantic: tt_dialect.MemSemantic = tt_dialect.MemSemantic.ACQUIRE_RELEASE,
+    sync_scope: tt_dialect.MemSyncScope = tt_dialect.MemSyncScope.GPU,
+) -> tensor:
+  if ir.RankedTensorType.isinstance(ptr.handle.type):
+    ptr_type = ir.RankedTensorType(ptr.handle.type)
+    element_type = tt_dialect.PointerType(ptr_type.element_type)
+    result_type = ir.RankedTensorType.get(
+        ptr_type.shape, element_type.pointee_type, ptr_type.encoding
+    )
+  else:
+    result_type = tt_dialect.PointerType(ptr.handle.type).pointee_type
+  result_handle = tt_dialect.atomic_rmw(
+      result_type,
+      op,
+      ptr.handle,
+      val.handle,
+      mask=mask.handle if mask is not None else None,
+      sem=semantic,
+      scope=sync_scope,
+  )
+  return tensor(result_handle, val.type)
+
+
+atomic_xchg = partial(_atomic_rmw, tt_dialect.RMWOp.XCHG)
+atomic_max = partial(_atomic_rmw, tt_dialect.RMWOp.MAX)
+atomic_min = partial(_atomic_rmw, tt_dialect.RMWOp.MIN)
+atomic_and = partial(_atomic_rmw, tt_dialect.RMWOp.AND)
+atomic_or = partial(_atomic_rmw, tt_dialect.RMWOp.OR)
+atomic_xor = partial(_atomic_rmw, tt_dialect.RMWOp.XOR)
+
+
+def atomic_add(
+    ptr: tensor,
+    val: tensor,
+    mask: tensor | None = None,
+    semantic: tt_dialect.MemSemantic = tt_dialect.MemSemantic.ACQUIRE_RELEASE,
+    sync_scope: tt_dialect.MemSyncScope = tt_dialect.MemSyncScope.GPU,
+):
+  if val.dtype.is_floating():
+    op = tt_dialect.RMWOp.FADD
+  else:
+    op = tt_dialect.RMWOp.ADD
+  return _atomic_rmw(op, ptr, val, mask, semantic, sync_scope)
 
 
 def abs(x: object) -> tensor:
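For reference, a hypothetical call-site sketch of the new module-level helpers (the names `ptrs`, `values`, `expected`, `desired`, and `where` are made up and assumed to be this module's `tensor` wrappers):

# atomic_add dispatches on dtype: RMWOp.FADD for floats, RMWOp.ADD otherwise.
old_vals = atomic_add(ptrs, values, mask=where)
# The partials forward straight to _atomic_rmw with the RMWOp baked in.
old_vals = atomic_xchg(ptrs, values, where)
# atomic_cas uses the defaults: ACQUIRE_RELEASE semantics, GPU sync scope.
old_vals = atomic_cas(ptrs, expected, desired)

Each helper emits the Triton op directly via `tt_dialect` and wraps the resulting IR value back into a `tensor` of the input value's type, replacing the old detour through `wrap_with_builder` and `tl.core`.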
@@ -1041,7 +1069,7 @@ def cos(x: object) -> tensor:
   return tensor(math_dialect.cos(x.handle), x.type)
 
 
-def multiple_of(x: tensor, values: Sequence[int]) -> tl.tensor:
+def multiple_of(x: tensor, values: Sequence[int]) -> tensor:
   assert max(1, len(x.shape)) == len(values)
   set_attr(
       x.handle,
@@ -1053,7 +1081,7 @@ def multiple_of(x: tensor, values: Sequence[int]) -> tl.tensor:
   return x
 
 
-def max_contiguous(x: tensor, values: Sequence[int]) -> tl.tensor:
+def max_contiguous(x: tensor, values: Sequence[int]) -> tensor:
   assert len(x.shape) == len(values)
   set_attr(
       x.handle,
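The last two hunks only fix the return annotations (`tl.tensor` becomes the module's own `tensor`); behavior is unchanged. A hypothetical usage sketch, assuming these wrappers keep the semantics of Triton's `tl.multiple_of`/`tl.max_contiguous` hints (`offsets` is a made-up `tensor` of indices):

# Hint that every element of `offsets` is a multiple of 16 ...
offsets = multiple_of(offsets, [16])
# ... and that elements come in contiguous runs of 16, which lets the
# compiler vectorize the memory accesses.  Both return `x` itself after
# recording the hint via set_attr on x.handle.
offsets = max_contiguous(offsets, [16])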