mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Migrated store/load to lower directly to Triton IR
PiperOrigin-RevId: 603764118
This commit is contained in:
parent
16636f9c97
commit
5867a05cdd
@ -417,8 +417,8 @@ def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
|
||||
state_discharge.register_discharge_rule(swap_p)(_swap_discharge_rule)
|
||||
|
||||
|
||||
def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier="",
|
||||
eviction_policy="", volatile=False) -> jax.Array:
|
||||
def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier=None,
|
||||
eviction_policy=None, volatile=False) -> jax.Array:
|
||||
x_ref, indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, "load")
|
||||
args_flat, args_tree = tree_util.tree_flatten((x_ref, indexers, mask, other))
|
||||
return load_p.bind(
|
||||
@ -429,7 +429,7 @@ def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier="",
|
||||
is_volatile=volatile,
|
||||
)
|
||||
|
||||
def swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy="",
|
||||
def swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None,
|
||||
_function_name="swap") -> Any:
|
||||
x_ref, indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, _function_name)
|
||||
args_flat, args_tree = tree_util.tree_flatten((x_ref, indexers, val, mask))
|
||||
@ -437,7 +437,7 @@ def swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy="",
|
||||
*args_flat, args_tree=args_tree, eviction_policy=eviction_policy
|
||||
)
|
||||
|
||||
def store(x_ref_or_view, idx, val, *, mask=None, eviction_policy="") -> None:
|
||||
def store(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None) -> None:
|
||||
_ = swap(x_ref_or_view, idx, val, mask=mask, eviction_policy=eviction_policy,
|
||||
_function_name="store")
|
||||
|
||||
|
@ -880,12 +880,14 @@ def _masked_load_lowering_rule(
|
||||
ptr = _compute_pointers_from_indices(
|
||||
ptr, ctx.block_infos[0], idx, ctx.avals_in[0].shape
|
||||
)
|
||||
if other is not None and mask is not None:
|
||||
other = tc.broadcast_to(other, mask.shape)
|
||||
val = tc.load(
|
||||
ptr,
|
||||
mask=mask,
|
||||
other=other,
|
||||
cache_modifier=cache_modifier,
|
||||
volatile=is_volatile,
|
||||
is_volatile=is_volatile,
|
||||
eviction_policy=eviction_policy,
|
||||
)
|
||||
# `tl.load` of a `*int1` returns a tensor with type `int8`, so fix the type.
|
||||
@ -931,7 +933,9 @@ def _masked_swap_lowering_rule(
|
||||
ptr = _compute_pointers_from_indices(
|
||||
ptr, ctx.block_infos[0], idx, ctx.avals_in[0].shape
|
||||
)
|
||||
other = None if mask is None else value
|
||||
other = None
|
||||
if value is not None and mask is not None:
|
||||
other = tc.broadcast_to(value, mask.shape)
|
||||
old_value = tc.load(ptr, mask=mask, other=other)
|
||||
tc.store(
|
||||
ptr,
|
||||
|
@ -453,120 +453,6 @@ class builder:
|
||||
def create_or(self, lhs: ir.Value, rhs: ir.Value) -> ir.Value:
|
||||
return arith_dialect.ori(lhs, rhs)
|
||||
|
||||
def create_load(
|
||||
self,
|
||||
ptr: ir.Value,
|
||||
cache_modifier: tt_dialect.CacheModifier,
|
||||
eviction_policy: tt_dialect.EvictionPolicy,
|
||||
is_volatile: bool,
|
||||
) -> ir.Value:
|
||||
if ir.RankedTensorType.isinstance(ptr.type):
|
||||
ptr_type = ir.RankedTensorType(ptr.type)
|
||||
element_type = tt_dialect.PointerType(ptr_type.element_type)
|
||||
result_type = ir.RankedTensorType.get(
|
||||
ptr_type.shape,
|
||||
element_type.pointee_type,
|
||||
ptr_type.encoding,
|
||||
)
|
||||
else:
|
||||
ptr_type = tt_dialect.PointerType(ptr.type)
|
||||
result_type = ptr_type.pointee_type
|
||||
return tt_dialect.load(
|
||||
result_type, ptr, cache_modifier, eviction_policy, is_volatile
|
||||
)
|
||||
|
||||
def create_store(
|
||||
self,
|
||||
ptr: ir.Value,
|
||||
value: ir.Value,
|
||||
cache_modifier: tt_dialect.CacheModifier,
|
||||
eviction_policy: tt_dialect.EvictionPolicy,
|
||||
) -> ir.Value:
|
||||
return tt_dialect.store(
|
||||
ptr, value, cache=cache_modifier, evict=eviction_policy
|
||||
)
|
||||
|
||||
def create_tensor_pointer_load(
|
||||
self,
|
||||
ptr: ir.Value,
|
||||
boundary_check: Sequence[int],
|
||||
padding_option: Sequence[tt_dialect.PaddingOption],
|
||||
cache_modifier: tt_dialect.CacheModifier,
|
||||
eviction_policy: tt_dialect.EvictionPolicy,
|
||||
is_volatile: bool,
|
||||
) -> ir.Value:
|
||||
return tt_dialect.load(
|
||||
ptr.type,
|
||||
ptr,
|
||||
cache_modifier,
|
||||
eviction_policy,
|
||||
is_volatile,
|
||||
boundary_check=boundary_check,
|
||||
padding=padding_option,
|
||||
)
|
||||
|
||||
def create_tensor_pointer_store(
|
||||
self,
|
||||
ptr: ir.Value,
|
||||
value: ir.Value,
|
||||
boundary_check: Sequence[int],
|
||||
cache_modifier: tt_dialect.CacheModifier,
|
||||
eviction_policy: tt_dialect.EvictionPolicy,
|
||||
) -> ir.Value:
|
||||
return tt_dialect.store(
|
||||
ptr,
|
||||
value,
|
||||
boundary_check=boundary_check,
|
||||
cache=cache_modifier,
|
||||
evict=eviction_policy,
|
||||
)
|
||||
|
||||
def create_masked_load(
|
||||
self,
|
||||
ptr: ir.Value,
|
||||
mask: ir.Value,
|
||||
other: ir.Value | None,
|
||||
cache_modifier: tt_dialect.CacheModifier,
|
||||
eviction_policy: tt_dialect.EvictionPolicy,
|
||||
is_volatile: bool,
|
||||
) -> ir.Value:
|
||||
if ir.RankedTensorType.isinstance(ptr.type):
|
||||
ptr_type = ir.RankedTensorType(ptr.type)
|
||||
element_type = tt_dialect.PointerType(ptr_type.element_type)
|
||||
result_type = ir.RankedTensorType.get(
|
||||
ptr_type.shape,
|
||||
element_type.pointee_type,
|
||||
ptr_type.encoding,
|
||||
)
|
||||
else:
|
||||
ptr_type = tt_dialect.PointerType(ptr.type)
|
||||
result_type = ptr_type.pointee_type
|
||||
return tt_dialect.load(
|
||||
result_type,
|
||||
ptr,
|
||||
cache_modifier,
|
||||
eviction_policy,
|
||||
is_volatile,
|
||||
mask=mask,
|
||||
other=other,
|
||||
)
|
||||
|
||||
def create_masked_store(
|
||||
self,
|
||||
ptr: ir.Value,
|
||||
value: ir.Value,
|
||||
mask: ir.Value,
|
||||
cache_modifier: tt_dialect.CacheModifier,
|
||||
eviction_policy: tt_dialect.EvictionPolicy,
|
||||
) -> ir.Value:
|
||||
return tt_dialect.store(
|
||||
ptr,
|
||||
value,
|
||||
mask=mask,
|
||||
cache=cache_modifier,
|
||||
evict=eviction_policy,
|
||||
)
|
||||
|
||||
def create_cat(self, lhs: ir.Value, rhs: ir.Value) -> ir.Value:
|
||||
assert ir.RankedTensorType.isinstance(lhs.type)
|
||||
assert ir.RankedTensorType.isinstance(rhs.type)
|
||||
@ -739,11 +625,13 @@ block_type = tl.core.block_type
|
||||
function_type = tl.core.function_type
|
||||
pointer_type = tl.core.pointer_type
|
||||
|
||||
void = tl.core.void
|
||||
bfloat16 = tl.core.bfloat16
|
||||
float16 = tl.core.float16
|
||||
float32 = tl.core.float32
|
||||
float64 = tl.core.float64
|
||||
int1 = tl.core.int1
|
||||
int8 = tl.core.int8
|
||||
int32 = tl.core.int32
|
||||
int64 = tl.core.int64
|
||||
uint32 = tl.core.uint32
|
||||
@ -858,8 +746,157 @@ def program_id(axis: int) -> tensor:
|
||||
return tensor(tt_dialect.get_program_id(axis), int32)
|
||||
|
||||
|
||||
load = wrap_with_builder(tl.core.load)
|
||||
store = wrap_with_builder(tl.core.store)
|
||||
_STR_TO_EVICTION_POLICY = {str(e): e for e in tt_dialect.EvictionPolicy}
|
||||
_STR_TO_CACHE_MODIFIER = {str(c): c for c in tt_dialect.CacheModifier}
|
||||
|
||||
|
||||
def _infer_load_return_type(ptr: ir.Value) -> ir.Type:
|
||||
if ir.RankedTensorType.isinstance(ptr.type):
|
||||
ptr_type = ir.RankedTensorType(ptr.type)
|
||||
element_type = tt_dialect.PointerType(ptr_type.element_type)
|
||||
return ir.RankedTensorType.get(
|
||||
ptr_type.shape,
|
||||
element_type.pointee_type,
|
||||
ptr_type.encoding,
|
||||
)
|
||||
else:
|
||||
ptr_type = tt_dialect.PointerType(ptr.type)
|
||||
return ptr_type.pointee_type
|
||||
|
||||
|
||||
def load(
|
||||
ptr: tensor,
|
||||
mask: tensor | None = None,
|
||||
other: tensor | None = None,
|
||||
*,
|
||||
cache_modifier: str | None = None,
|
||||
eviction_policy: str | None = None,
|
||||
is_volatile: bool = False,
|
||||
) -> tensor:
|
||||
if cache_modifier is None:
|
||||
cache_modifier = tt_dialect.CacheModifier.NONE
|
||||
elif cache_modifier == ".ca" or cache_modifier == ".cg":
|
||||
cache_modifier = _STR_TO_CACHE_MODIFIER[cache_modifier]
|
||||
else:
|
||||
raise ValueError(f"unsupported cache modifier: {cache_modifier}")
|
||||
if eviction_policy is None:
|
||||
eviction_policy = tt_dialect.EvictionPolicy.NORMAL
|
||||
else:
|
||||
try:
|
||||
eviction_policy = _STR_TO_EVICTION_POLICY[eviction_policy]
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"unsupported eviction policy: {eviction_policy}"
|
||||
) from None
|
||||
|
||||
if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
|
||||
# TODO(slebedev): Support load from a block pointer.
|
||||
raise NotImplementedError("loading from a block pointer is not supported")
|
||||
if not ptr.dtype.is_ptr():
|
||||
raise ValueError(f"unsupported pointer dtype: {ptr.dtype}")
|
||||
if other is not None:
|
||||
if mask is None:
|
||||
raise ValueError("other requires mask to be provided")
|
||||
assert mask.shape == other.shape == ptr.shape, (
|
||||
mask.shape,
|
||||
other.shape,
|
||||
ptr.shape,
|
||||
)
|
||||
elif mask is not None:
|
||||
assert mask.shape == ptr.shape
|
||||
if not ptr.type.is_block():
|
||||
if other is not None and other.type.is_block():
|
||||
raise ValueError("other cannot be a block if pointer is not a block")
|
||||
if mask is not None and mask.type.is_block():
|
||||
raise ValueError("mask cannot be a block if pointer is not a block")
|
||||
|
||||
ptr_type = ptr.dtype
|
||||
element_type = ptr_type.element_ty
|
||||
|
||||
if element_type == int1:
|
||||
# TODO(slebedev): Cast the result back to int1 before returning.
|
||||
element_type = int8
|
||||
ptr_type = pointer_type(element_type, ptr_type.address_space)
|
||||
ptr = semantic.cast(ptr, ptr_type)
|
||||
|
||||
if other is not None:
|
||||
other = semantic.cast(other, element_type)
|
||||
|
||||
result_handle = tt_dialect.load(
|
||||
_infer_load_return_type(ptr.handle),
|
||||
ptr.handle,
|
||||
mask=mask.handle if mask is not None else None,
|
||||
other=other.handle if other is not None else None,
|
||||
cache=cache_modifier,
|
||||
evict=eviction_policy,
|
||||
is_volatile=is_volatile,
|
||||
)
|
||||
if ptr.type.is_block():
|
||||
return tensor(result_handle, block_type(element_type, ptr.type.shape))
|
||||
else:
|
||||
return tensor(result_handle, element_type)
|
||||
|
||||
|
||||
def store(
|
||||
ptr: tensor,
|
||||
value: tensor,
|
||||
mask: tensor | None = None,
|
||||
*,
|
||||
cache_modifier: str | None = None,
|
||||
eviction_policy: str | None = None,
|
||||
) -> tensor:
|
||||
if cache_modifier is None:
|
||||
cache_modifier = tt_dialect.CacheModifier.NONE
|
||||
elif cache_modifier != ".ca":
|
||||
cache_modifier = _STR_TO_CACHE_MODIFIER[cache_modifier]
|
||||
else:
|
||||
raise ValueError(f"unsupported cache modifier: {cache_modifier}")
|
||||
if eviction_policy is None:
|
||||
eviction_policy = tt_dialect.EvictionPolicy.NORMAL
|
||||
else:
|
||||
try:
|
||||
eviction_policy = _STR_TO_EVICTION_POLICY[eviction_policy]
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"unsupported eviction policy: {eviction_policy}"
|
||||
) from None
|
||||
|
||||
if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
|
||||
# TODO(slebedev): Support load from a block pointer.
|
||||
raise NotImplementedError("storing to a block pointer is not supported")
|
||||
|
||||
if not ptr.dtype.is_ptr():
|
||||
raise ValueError(f"unsupported pointer dtype: {ptr.dtype}")
|
||||
assert value.shape == ptr.shape
|
||||
if mask is not None:
|
||||
assert mask.shape == ptr.shape
|
||||
if not ptr.type.is_block():
|
||||
if value.type.is_block():
|
||||
raise ValueError("other cannot be a block if pointer is not a block")
|
||||
if mask is not None and mask.type.is_block():
|
||||
raise ValueError("mask cannot be a block if pointer is not a block")
|
||||
|
||||
ptr_type = ptr.dtype
|
||||
element_type = ptr_type.element_ty
|
||||
|
||||
if element_type == int1:
|
||||
# TODO(slebedev): Cast the result back to int1 before returning.
|
||||
element_type = int8
|
||||
ptr_type = pointer_type(element_type, ptr_type.address_space)
|
||||
ptr = semantic.cast(ptr, ptr_type)
|
||||
|
||||
value = semantic.cast(value, element_type)
|
||||
|
||||
return tensor(
|
||||
tt_dialect.store(
|
||||
ptr.handle,
|
||||
value.handle,
|
||||
mask=mask.handle if mask is not None else None,
|
||||
cache=cache_modifier,
|
||||
evict=eviction_policy,
|
||||
),
|
||||
void,
|
||||
)
|
||||
|
||||
|
||||
def arange(start: int, end: int) -> tensor:
|
||||
|
Loading…
x
Reference in New Issue
Block a user