Migrated atomic operations to lower directly to Triton IR

PiperOrigin-RevId: 602384705
Sergei Lebedev 2024-01-29 07:44:50 -08:00 committed by jax authors
parent 2518a6f6d2
commit 07f8f700ca
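In effect, the change removes the round trip through the Triton Python frontend (`tl.core`) and the `builder` shim: lowering now constructs `tt_dialect` ops directly. A before/after sketch of a call site (illustrative only; `ptr`, `cmp`, and `val` stand for `tensor` values produced earlier during lowering):

# Before: dispatch via the Triton frontend, which called back into the
# shim's builder.create_atomic_cas(...).
#   old = tl.core.atomic_cas(ptr, cmp, val, _builder=builder)

# After: the module-level helper added in this commit emits
# tt_dialect.atomic_cas itself.
old = atomic_cas(ptr, cmp, val)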


@@ -598,45 +598,6 @@ class builder:
     result_type = ir.RankedTensorType.get(shape, arg.type)
     return tt_dialect.splat(result_type, arg)
 
-  def create_atomic_cas(
-      self,
-      ptr: ir.Value,
-      cmp: ir.Value,
-      val: ir.Value,
-      sem: tt_dialect.MemSemantic,
-      scope: tt_dialect.MemSyncScope,
-  ) -> ir.Value:
-    if ir.RankedTensorType.isinstance(ptr.type):
-      ptr_type = ir.RankedTensorType(ptr.type)
-      element_type = tt_dialect.PointerType(ptr_type.element_type)
-      result_type = ir.RankedTensorType.get(
-          ptr_type.shape, element_type.pointee_type, ptr_type.encoding
-      )
-    else:
-      result_type = tt_dialect.PointerType(ptr.type).pointee_type
-    return tt_dialect.atomic_cas(result_type, ptr, cmp, val, sem, scope)
-
-  def create_atomic_rmw(
-      self,
-      rmw_op: tt_dialect.RMWOp,
-      ptr: ir.Value,
-      val: ir.Value,
-      mask: ir.Value,
-      sem: tt_dialect.MemSemantic,
-      scope: tt_dialect.MemSyncScope,
-  ) -> ir.Value:
-    if ir.RankedTensorType.isinstance(ptr.type):
-      ptr_type = ir.RankedTensorType(ptr.type)
-      element_type = tt_dialect.PointerType(ptr_type.element_type)
-      result_type = ir.RankedTensorType.get(
-          ptr_type.shape, element_type.pointee_type, ptr_type.encoding
-      )
-    else:
-      result_type = tt_dialect.PointerType(ptr.type).pointee_type
-    return tt_dialect.atomic_rmw(
-        result_type, rmw_op, ptr, val, sem, scope, mask=mask
-    )
-
   def create_extern_elementwise(
       self,
       lib_name: str,
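Both deleted methods computed the result type the same way; that logic moves into the module-level helpers added in the next hunk. Restated on its own (a sketch using the `ir`/`tt_dialect` bindings the file already imports; the helper name is invented):

def _atomic_result_type(ptr_type: ir.Type) -> ir.Type:
  # A tensor of pointers yields a tensor of pointee values with the
  # same shape and encoding...
  if ir.RankedTensorType.isinstance(ptr_type):
    tensor_type = ir.RankedTensorType(ptr_type)
    pointee = tt_dialect.PointerType(tensor_type.element_type).pointee_type
    return ir.RankedTensorType.get(
        tensor_type.shape, pointee, tensor_type.encoding
    )
  # ...while a scalar pointer yields a single pointee value.
  return tt_dialect.PointerType(ptr_type).pointee_type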
@@ -983,14 +944,81 @@ def reshape(x: tensor, dst_shape: Sequence[int]) -> tensor:
 
 dot = wrap_with_builder(tl.core.dot)
 
-atomic_xchg = wrap_with_builder(tl.core.atomic_xchg)
-atomic_add = wrap_with_builder(tl.core.atomic_add)
-atomic_max = wrap_with_builder(tl.core.atomic_max)
-atomic_min = wrap_with_builder(tl.core.atomic_min)
-atomic_and = wrap_with_builder(tl.core.atomic_and)
-atomic_or = wrap_with_builder(tl.core.atomic_or)
-atomic_xor = wrap_with_builder(tl.core.atomic_xor)
-atomic_cas = wrap_with_builder(tl.atomic_cas)
+
+def atomic_cas(
+    ptr: tensor,
+    cmp: tensor,
+    val: tensor,
+    semantic: tt_dialect.MemSemantic = tt_dialect.MemSemantic.ACQUIRE_RELEASE,
+    sync_scope: tt_dialect.MemSyncScope = tt_dialect.MemSyncScope.GPU,
+):
+  if ir.RankedTensorType.isinstance(ptr.handle.type):
+    ptr_type = ir.RankedTensorType(ptr.handle.type)
+    element_type = tt_dialect.PointerType(ptr_type.element_type)
+    result_type = ir.RankedTensorType.get(
+        ptr_type.shape, element_type.pointee_type, ptr_type.encoding
+    )
+  else:
+    result_type = tt_dialect.PointerType(ptr.handle.type).pointee_type
+  result_handle = tt_dialect.atomic_cas(
+      result_type,
+      ptr.handle,
+      cmp.handle,
+      val.handle,
+      sem=semantic,
+      scope=sync_scope,
+  )
+  return tensor(result_handle, val.type)
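The new `atomic_cas` wraps and unwraps `tensor` handles itself and defaults to acquire-release semantics at GPU scope. A hypothetical call site (names and dtypes are assumptions, not part of the diff):

# lock_ptr wraps a scalar !tt.ptr<i32>; zero and one wrap i32 constants.
# Returns the value observed at lock_ptr before the exchange attempt.
old = atomic_cas(lock_ptr, zero, one)

# Weaker ordering can be requested explicitly:
old = atomic_cas(lock_ptr, zero, one,
                 semantic=tt_dialect.MemSemantic.RELAXED)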
+
+
+def _atomic_rmw(
+    op: tt_dialect.RMWOp,
+    ptr: tensor,
+    val: tensor,
+    mask: tensor | None = None,
+    semantic: tt_dialect.MemSemantic = tt_dialect.MemSemantic.ACQUIRE_RELEASE,
+    sync_scope: tt_dialect.MemSyncScope = tt_dialect.MemSyncScope.GPU,
+) -> tensor:
+  if ir.RankedTensorType.isinstance(ptr.handle.type):
+    ptr_type = ir.RankedTensorType(ptr.handle.type)
+    element_type = tt_dialect.PointerType(ptr_type.element_type)
+    result_type = ir.RankedTensorType.get(
+        ptr_type.shape, element_type.pointee_type, ptr_type.encoding
+    )
+  else:
+    result_type = tt_dialect.PointerType(ptr.handle.type).pointee_type
+  result_handle = tt_dialect.atomic_rmw(
+      result_type,
+      op,
+      ptr.handle,
+      val.handle,
+      mask=mask.handle if mask is not None else None,
+      sem=semantic,
+      scope=sync_scope,
+  )
+  return tensor(result_handle, val.type)
+
+
+atomic_xchg = partial(_atomic_rmw, tt_dialect.RMWOp.XCHG)
+atomic_max = partial(_atomic_rmw, tt_dialect.RMWOp.MAX)
+atomic_min = partial(_atomic_rmw, tt_dialect.RMWOp.MIN)
+atomic_and = partial(_atomic_rmw, tt_dialect.RMWOp.AND)
+atomic_or = partial(_atomic_rmw, tt_dialect.RMWOp.OR)
+atomic_xor = partial(_atomic_rmw, tt_dialect.RMWOp.XOR)
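With the `RMWOp` pre-bound via `functools.partial` (assumed imported at the top of the file), every named atomic shares the masking and memory-semantics plumbing in `_atomic_rmw`. For example, the following two calls are equivalent (`ptrs`, `vals`, and `mask` are hypothetical `tensor` values built earlier in lowering):

# Elementwise maximum into memory, skipping masked-off lanes.
old = atomic_max(ptrs, vals, mask=mask)
# ...is the same as:
old = _atomic_rmw(tt_dialect.RMWOp.MAX, ptrs, vals, mask=mask)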
+
+
+def atomic_add(
+    ptr: tensor,
+    val: tensor,
+    mask: tensor | None = None,
+    semantic: tt_dialect.MemSemantic = tt_dialect.MemSemantic.ACQUIRE_RELEASE,
+    sync_scope: tt_dialect.MemSyncScope = tt_dialect.MemSyncScope.GPU,
+):
+  if val.dtype.is_floating():
+    op = tt_dialect.RMWOp.FADD
+  else:
+    op = tt_dialect.RMWOp.ADD
+  return _atomic_rmw(op, ptr, val, mask, semantic, sync_scope)
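`atomic_add` is the one wrapper that cannot be a plain `partial`: Triton IR has distinct RMW ops for integer (`add`) and floating-point (`fadd`) addition, so the op is picked from the value's dtype. Illustratively (the operands are hypothetical):

atomic_add(f32_ptrs, f32_vals)  # val is floating -> tt.atomic_rmw fadd
atomic_add(i32_ptrs, i32_vals)  # val is integral -> tt.atomic_rmw add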
 
 
 def abs(x: object) -> tensor:
@@ -1041,7 +1069,7 @@ def cos(x: object) -> tensor:
   return tensor(math_dialect.cos(x.handle), x.type)
 
 
-def multiple_of(x: tensor, values: Sequence[int]) -> tl.tensor:
+def multiple_of(x: tensor, values: Sequence[int]) -> tensor:
   assert max(1, len(x.shape)) == len(values)
   set_attr(
       x.handle,
@@ -1053,7 +1081,7 @@ def multiple_of(x: tensor, values: Sequence[int]) -> tl.tensor:
   return x
 
 
-def max_contiguous(x: tensor, values: Sequence[int]) -> tl.tensor:
+def max_contiguous(x: tensor, values: Sequence[int]) -> tensor:
   assert len(x.shape) == len(values)
   set_attr(
       x.handle,
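These last two hunks are pure annotation fixes: `multiple_of` and `max_contiguous` attach alignment/contiguity attributes to `x.handle` and return `x` unchanged, so the return type is this module's `tensor`, not the Triton frontend's `tl.tensor`. A hypothetical use as an optimizer hint:

# Promise the compiler each offset is a multiple of 16 and that runs of
# 16 consecutive elements are contiguous; x is returned unmodified.
offsets = multiple_of(offsets, [16])
offsets = max_contiguous(offsets, [16])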