Migrated store/load to lower directly to Triton IR

PiperOrigin-RevId: 603764118
Sergei Lebedev 2024-02-02 12:52:57 -08:00 committed by jax authors
parent 16636f9c97
commit 5867a05cdd
3 changed files with 163 additions and 122 deletions
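
For orientation: before this change, Pallas' `load`/`store` lowering went through a Python `builder` shim mimicking Triton's `tl.core` helpers; after it, the lowering emits `tt` dialect MLIR ops directly. The user-facing API is unchanged. A minimal kernel exercising the affected entry points, assuming the Triton (GPU) backend; the kernel itself is illustrative and not part of this commit:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def copy_kernel(x_ref, o_ref):
  idx = (pl.dslice(0, 8),)
  mask = jnp.arange(8) < 6
  # Masked load: `other` fills the masked-off lanes.
  val = pl.load(x_ref, idx, mask=mask, other=0.0)
  pl.store(o_ref, idx, val)

x = jnp.arange(8, dtype=jnp.float32)
out = pl.pallas_call(
    copy_kernel, out_shape=jax.ShapeDtypeStruct((8,), jnp.float32)
)(x)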


@@ -417,8 +417,8 @@ def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
 
 state_discharge.register_discharge_rule(swap_p)(_swap_discharge_rule)
 
-def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier="",
-         eviction_policy="", volatile=False) -> jax.Array:
+def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier=None,
+         eviction_policy=None, volatile=False) -> jax.Array:
   x_ref, indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, "load")
   args_flat, args_tree = tree_util.tree_flatten((x_ref, indexers, mask, other))
   return load_p.bind(
@@ -429,7 +429,7 @@ def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier="",
       is_volatile=volatile,
   )
 
-def swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy="",
+def swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None,
          _function_name="swap") -> Any:
   x_ref, indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, _function_name)
   args_flat, args_tree = tree_util.tree_flatten((x_ref, indexers, val, mask))
@@ -437,7 +437,7 @@ def swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy="",
       *args_flat, args_tree=args_tree, eviction_policy=eviction_policy
   )
 
-def store(x_ref_or_view, idx, val, *, mask=None, eviction_policy="") -> None:
+def store(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None) -> None:
   _ = swap(x_ref_or_view, idx, val, mask=mask, eviction_policy=eviction_policy,
            _function_name="store")


@@ -880,12 +880,14 @@ def _masked_load_lowering_rule(
   ptr = _compute_pointers_from_indices(
       ptr, ctx.block_infos[0], idx, ctx.avals_in[0].shape
   )
+  if other is not None and mask is not None:
+    other = tc.broadcast_to(other, mask.shape)
   val = tc.load(
       ptr,
       mask=mask,
       other=other,
       cache_modifier=cache_modifier,
-      volatile=is_volatile,
+      is_volatile=is_volatile,
       eviction_policy=eviction_policy,
   )
   # `tl.load` of a `*int1` returns a tensor with type `int8`, so fix the type.
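
The added `broadcast_to` handles a scalar or lower-rank `other`: the `tt.load` op requires `other` to match the loaded shape, so the lowering now expands it to the mask's shape up front. The same operation in plain `jax.numpy` terms (a toy restatement, not the lowering code):

import jax.numpy as jnp

mask = jnp.arange(8) < 6
other = 0.0                                   # scalar fallback from the user
other = jnp.broadcast_to(other, mask.shape)   # shape (8,), matching the load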
@@ -931,7 +933,9 @@ def _masked_swap_lowering_rule(
   ptr = _compute_pointers_from_indices(
       ptr, ctx.block_infos[0], idx, ctx.avals_in[0].shape
   )
-  other = None if mask is None else value
+  other = None
+  if value is not None and mask is not None:
+    other = tc.broadcast_to(value, mask.shape)
   old_value = tc.load(ptr, mask=mask, other=other)
   tc.store(
       ptr,
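
`swap` returns the previous contents of the ref, but masked-off lanes have nothing to read, so the lowering substitutes the incoming `value` (now broadcast like `other` above) as the `tc.load` fallback. Reference semantics in plain `jax.numpy` (illustrative only; names are not from the diff):

import jax.numpy as jnp

def masked_swap_reference(buf, value, mask):
  old = jnp.where(mask, buf, value)      # masked-off lanes report `value` back
  new_buf = jnp.where(mask, value, buf)  # masked-off lanes keep old contents
  return new_buf, old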


@@ -453,120 +453,6 @@ class builder:
   def create_or(self, lhs: ir.Value, rhs: ir.Value) -> ir.Value:
     return arith_dialect.ori(lhs, rhs)
 
-  def create_load(
-      self,
-      ptr: ir.Value,
-      cache_modifier: tt_dialect.CacheModifier,
-      eviction_policy: tt_dialect.EvictionPolicy,
-      is_volatile: bool,
-  ) -> ir.Value:
-    if ir.RankedTensorType.isinstance(ptr.type):
-      ptr_type = ir.RankedTensorType(ptr.type)
-      element_type = tt_dialect.PointerType(ptr_type.element_type)
-      result_type = ir.RankedTensorType.get(
-          ptr_type.shape,
-          element_type.pointee_type,
-          ptr_type.encoding,
-      )
-    else:
-      ptr_type = tt_dialect.PointerType(ptr.type)
-      result_type = ptr_type.pointee_type
-    return tt_dialect.load(
-        result_type, ptr, cache_modifier, eviction_policy, is_volatile
-    )
-
-  def create_store(
-      self,
-      ptr: ir.Value,
-      value: ir.Value,
-      cache_modifier: tt_dialect.CacheModifier,
-      eviction_policy: tt_dialect.EvictionPolicy,
-  ) -> ir.Value:
-    return tt_dialect.store(
-        ptr, value, cache=cache_modifier, evict=eviction_policy
-    )
-
-  def create_tensor_pointer_load(
-      self,
-      ptr: ir.Value,
-      boundary_check: Sequence[int],
-      padding_option: Sequence[tt_dialect.PaddingOption],
-      cache_modifier: tt_dialect.CacheModifier,
-      eviction_policy: tt_dialect.EvictionPolicy,
-      is_volatile: bool,
-  ) -> ir.Value:
-    return tt_dialect.load(
-        ptr.type,
-        ptr,
-        cache_modifier,
-        eviction_policy,
-        is_volatile,
-        boundary_check=boundary_check,
-        padding=padding_option,
-    )
-
-  def create_tensor_pointer_store(
-      self,
-      ptr: ir.Value,
-      value: ir.Value,
-      boundary_check: Sequence[int],
-      cache_modifier: tt_dialect.CacheModifier,
-      eviction_policy: tt_dialect.EvictionPolicy,
-  ) -> ir.Value:
-    return tt_dialect.store(
-        ptr,
-        value,
-        boundary_check=boundary_check,
-        cache=cache_modifier,
-        evict=eviction_policy,
-    )
-
-  def create_masked_load(
-      self,
-      ptr: ir.Value,
-      mask: ir.Value,
-      other: ir.Value | None,
-      cache_modifier: tt_dialect.CacheModifier,
-      eviction_policy: tt_dialect.EvictionPolicy,
-      is_volatile: bool,
-  ) -> ir.Value:
-    if ir.RankedTensorType.isinstance(ptr.type):
-      ptr_type = ir.RankedTensorType(ptr.type)
-      element_type = tt_dialect.PointerType(ptr_type.element_type)
-      result_type = ir.RankedTensorType.get(
-          ptr_type.shape,
-          element_type.pointee_type,
-          ptr_type.encoding,
-      )
-    else:
-      ptr_type = tt_dialect.PointerType(ptr.type)
-      result_type = ptr_type.pointee_type
-    return tt_dialect.load(
-        result_type,
-        ptr,
-        cache_modifier,
-        eviction_policy,
-        is_volatile,
-        mask=mask,
-        other=other,
-    )
-
-  def create_masked_store(
-      self,
-      ptr: ir.Value,
-      value: ir.Value,
-      mask: ir.Value,
-      cache_modifier: tt_dialect.CacheModifier,
-      eviction_policy: tt_dialect.EvictionPolicy,
-  ) -> ir.Value:
-    return tt_dialect.store(
-        ptr,
-        value,
-        mask=mask,
-        cache=cache_modifier,
-        evict=eviction_policy,
-    )
-
   def create_cat(self, lhs: ir.Value, rhs: ir.Value) -> ir.Value:
     assert ir.RankedTensorType.isinstance(lhs.type)
     assert ir.RankedTensorType.isinstance(rhs.type)
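
Every deleted `create_*` method bottomed out in the same two `tt_dialect.load`/`tt_dialect.store` calls; the only non-trivial logic they shared, computing a load's result type from the pointer type, survives as `_infer_load_return_type` below. A self-contained toy version of that inference, with dataclasses standing in for the MLIR type classes (illustrative, not the real bindings):

from dataclasses import dataclass

@dataclass(frozen=True)
class PointerType:
  pointee: object

@dataclass(frozen=True)
class RankedTensorType:
  shape: tuple
  element_type: object  # a PointerType when the tensor holds pointers

def infer_load_return_type(ptr_type):
  # A tensor of pointers loads to a tensor of pointees;
  # a single pointer loads to its pointee.
  if isinstance(ptr_type, RankedTensorType):
    return RankedTensorType(ptr_type.shape, ptr_type.element_type.pointee)
  return ptr_type.pointee

assert infer_load_return_type(PointerType("f32")) == "f32"
assert infer_load_return_type(
    RankedTensorType((8,), PointerType("f32"))
) == RankedTensorType((8,), "f32")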
@@ -739,11 +625,13 @@ block_type = tl.core.block_type
 function_type = tl.core.function_type
 pointer_type = tl.core.pointer_type
 void = tl.core.void
+bfloat16 = tl.core.bfloat16
 float16 = tl.core.float16
 float32 = tl.core.float32
 float64 = tl.core.float64
 int1 = tl.core.int1
+int8 = tl.core.int8
 int32 = tl.core.int32
 int64 = tl.core.int64
 uint32 = tl.core.uint32
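
The new `int8` alias is used by the `load`/`store` rewrite below: Triton has no storable `int1`, so `*int1` pointers are retyped as `*int8` and values are cast at the boundary (the TODO in `load` tracks casting the result back); `bfloat16` is exported alongside it, rounding out the dtype aliases. The round-trip, restated in `jax.numpy` (a toy illustration):

import jax.numpy as jnp

bits = jnp.array([True, False, True])
stored = bits.astype(jnp.int8)        # what actually reaches memory
restored = stored.astype(jnp.bool_)   # the cast the TODO refers to
assert bool((restored == bits).all())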
@@ -858,8 +746,157 @@ def program_id(axis: int) -> tensor:
   return tensor(tt_dialect.get_program_id(axis), int32)
 
-load = wrap_with_builder(tl.core.load)
-store = wrap_with_builder(tl.core.store)
+_STR_TO_EVICTION_POLICY = {str(e): e for e in tt_dialect.EvictionPolicy}
+_STR_TO_CACHE_MODIFIER = {str(c): c for c in tt_dialect.CacheModifier}
+
+
+def _infer_load_return_type(ptr: ir.Value) -> ir.Type:
+  if ir.RankedTensorType.isinstance(ptr.type):
+    ptr_type = ir.RankedTensorType(ptr.type)
+    element_type = tt_dialect.PointerType(ptr_type.element_type)
+    return ir.RankedTensorType.get(
+        ptr_type.shape,
+        element_type.pointee_type,
+        ptr_type.encoding,
+    )
+  else:
+    ptr_type = tt_dialect.PointerType(ptr.type)
+    return ptr_type.pointee_type
+
+
+def load(
+    ptr: tensor,
+    mask: tensor | None = None,
+    other: tensor | None = None,
+    *,
+    cache_modifier: str | None = None,
+    eviction_policy: str | None = None,
+    is_volatile: bool = False,
+) -> tensor:
+  if cache_modifier is None:
+    cache_modifier = tt_dialect.CacheModifier.NONE
+  elif cache_modifier == ".ca" or cache_modifier == ".cg":
+    cache_modifier = _STR_TO_CACHE_MODIFIER[cache_modifier]
+  else:
+    raise ValueError(f"unsupported cache modifier: {cache_modifier}")
+  if eviction_policy is None:
+    eviction_policy = tt_dialect.EvictionPolicy.NORMAL
+  else:
+    try:
+      eviction_policy = _STR_TO_EVICTION_POLICY[eviction_policy]
+    except KeyError:
+      raise ValueError(
+          f"unsupported eviction policy: {eviction_policy}"
+      ) from None
+
+  if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
+    # TODO(slebedev): Support load from a block pointer.
+    raise NotImplementedError("loading from a block pointer is not supported")
+  if not ptr.dtype.is_ptr():
+    raise ValueError(f"unsupported pointer dtype: {ptr.dtype}")
+
+  if other is not None:
+    if mask is None:
+      raise ValueError("other requires mask to be provided")
+    assert mask.shape == other.shape == ptr.shape, (
+        mask.shape,
+        other.shape,
+        ptr.shape,
+    )
+  elif mask is not None:
+    assert mask.shape == ptr.shape
+  if not ptr.type.is_block():
+    if other is not None and other.type.is_block():
+      raise ValueError("other cannot be a block if pointer is not a block")
+    if mask is not None and mask.type.is_block():
+      raise ValueError("mask cannot be a block if pointer is not a block")
+
+  ptr_type = ptr.dtype
+  element_type = ptr_type.element_ty
+  if element_type == int1:
+    # TODO(slebedev): Cast the result back to int1 before returning.
+    element_type = int8
+    ptr_type = pointer_type(element_type, ptr_type.address_space)
+    ptr = semantic.cast(ptr, ptr_type)
+  if other is not None:
+    other = semantic.cast(other, element_type)
+
+  result_handle = tt_dialect.load(
+      _infer_load_return_type(ptr.handle),
+      ptr.handle,
+      mask=mask.handle if mask is not None else None,
+      other=other.handle if other is not None else None,
+      cache=cache_modifier,
+      evict=eviction_policy,
+      is_volatile=is_volatile,
+  )
+  if ptr.type.is_block():
+    return tensor(result_handle, block_type(element_type, ptr.type.shape))
+  else:
+    return tensor(result_handle, element_type)
+
+
+def store(
+    ptr: tensor,
+    value: tensor,
+    mask: tensor | None = None,
+    *,
+    cache_modifier: str | None = None,
+    eviction_policy: str | None = None,
+) -> tensor:
+  if cache_modifier is None:
+    cache_modifier = tt_dialect.CacheModifier.NONE
+  elif cache_modifier == ".ca":
+    raise ValueError(f"unsupported cache modifier: {cache_modifier}")
+  else:
+    cache_modifier = _STR_TO_CACHE_MODIFIER[cache_modifier]
+  if eviction_policy is None:
+    eviction_policy = tt_dialect.EvictionPolicy.NORMAL
+  else:
+    try:
+      eviction_policy = _STR_TO_EVICTION_POLICY[eviction_policy]
+    except KeyError:
+      raise ValueError(
+          f"unsupported eviction policy: {eviction_policy}"
+      ) from None
+
+  if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
+    # TODO(slebedev): Support store to a block pointer.
+    raise NotImplementedError("storing to a block pointer is not supported")
+  if not ptr.dtype.is_ptr():
+    raise ValueError(f"unsupported pointer dtype: {ptr.dtype}")
+
+  assert value.shape == ptr.shape
+  if mask is not None:
+    assert mask.shape == ptr.shape
+  if not ptr.type.is_block():
+    if value.type.is_block():
+      raise ValueError("value cannot be a block if pointer is not a block")
+    if mask is not None and mask.type.is_block():
+      raise ValueError("mask cannot be a block if pointer is not a block")
+
+  ptr_type = ptr.dtype
+  element_type = ptr_type.element_ty
+  if element_type == int1:
+    # Triton cannot store a *int1, so store the value as int8 instead.
+    element_type = int8
+    ptr_type = pointer_type(element_type, ptr_type.address_space)
+    ptr = semantic.cast(ptr, ptr_type)
+  value = semantic.cast(value, element_type)
+
+  return tensor(
+      tt_dialect.store(
+          ptr.handle,
+          value.handle,
+          mask=mask.handle if mask is not None else None,
+          cache=cache_modifier,
+          evict=eviction_policy,
+      ),
+      void,
+  )
+
+
 def arange(start: int, end: int) -> tensor:
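
Both new helpers share the same string-to-enum validation: loads accept the `.ca`/`.cg` cache modifiers while stores reject `.ca`, and unknown eviction policies surface as `ValueError` rather than `KeyError`. A self-contained model of that logic with stand-in enums (the real code uses `tt_dialect.CacheModifier`/`EvictionPolicy`, whose `str()` values are assumed to match these spellings):

import enum

class CacheModifier(enum.Enum):
  NONE = "none"
  CA = ".ca"
  CG = ".cg"

class EvictionPolicy(enum.Enum):
  NORMAL = "normal"
  EVICT_FIRST = "evict_first"
  EVICT_LAST = "evict_last"

_STR_TO_EVICTION_POLICY = {e.value: e for e in EvictionPolicy}

def parse_load_policies(cache_modifier, eviction_policy):
  if cache_modifier is None:
    cache = CacheModifier.NONE
  elif cache_modifier in (".ca", ".cg"):  # loads accept only these two
    cache = CacheModifier(cache_modifier)
  else:
    raise ValueError(f"unsupported cache modifier: {cache_modifier}")
  if eviction_policy is None:
    evict = EvictionPolicy.NORMAL
  else:
    try:
      evict = _STR_TO_EVICTION_POLICY[eviction_policy]
    except KeyError:
      raise ValueError(
          f"unsupported eviction policy: {eviction_policy}") from None
  return cache, evict

assert parse_load_policies(None, None) == (
    CacheModifier.NONE, EvictionPolicy.NORMAL)
assert parse_load_policies(".cg", "evict_last") == (
    CacheModifier.CG, EvictionPolicy.EVICT_LAST)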