mirror of https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00

[pallas:triton] Simplify lowering code. BlockInfo is now always present for memory refs.

PiperOrigin-RevId: 695414469

This commit is contained in:
  parent 242ac2b92d
  commit 0995bc231c
@@ -87,7 +87,7 @@ class ModuleContext:
 class BlockInfo:
   full_shape_dtype: jax.ShapeDtypeStruct
   start_indices: Sequence[Any]
-  block_shape: tuple[int, ...]  # TODO(necula): can this contain "mapped"?
+  block_shape: tuple[int | pallas_core.Mapped, ...]


 @dataclasses.dataclass
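Note: the updated annotation makes explicit that block_shape may contain the pallas_core.mapped sentinel for dimensions the grid maps away. The following is an illustrative sketch only, not code from this commit: MappedSketch/mapped_sketch stand in for pallas_core.Mapped/mapped, and _BlockInfoSketch mirrors the fields shown in the hunk above.

# Illustrative stand-ins, not this file's code.
import dataclasses
from collections.abc import Sequence
from typing import Any

import jax
import jax.numpy as jnp


class MappedSketch:
  """Stand-in for pallas_core.Mapped; marks a grid-mapped dimension."""

mapped_sketch = MappedSketch()


@dataclasses.dataclass
class _BlockInfoSketch:
  full_shape_dtype: jax.ShapeDtypeStruct
  start_indices: Sequence[Any]
  block_shape: tuple[int | MappedSketch, ...]


# A (8, 128, 128) ref whose leading dimension is mapped away by the grid:
info = _BlockInfoSketch(
    full_shape_dtype=jax.ShapeDtypeStruct((8, 128, 128), jnp.float32),
    start_indices=(0, 0, 0),
    block_shape=(mapped_sketch, 64, 128),
)
# The lowering counts mapped dimensions exactly as the diff does below:
assert sum(b is mapped_sketch for b in info.block_shape) == 1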
@@ -95,7 +95,7 @@ class LoweringRuleContext:
   context: ModuleContext
   avals_in: Sequence[jax_core.ShapedArray]
   avals_out: Sequence[jax_core.ShapedArray]
-  block_infos: Sequence[BlockInfo | None]  # TODO(necula): can this be None?
+  block_infos: Sequence[BlockInfo | None]

   replace = dataclasses.replace

@@ -362,14 +362,15 @@ def lower_jaxpr_to_triton_ir(
   def read_block_info_env(atom: jax_core.Atom):
     if isinstance(atom, jax_core.Literal):
       return None
-    return block_info_env.get(atom, None)
+    return block_info_env.get(atom)

   def write_env(var: jax_core.Var, val):
     env[var] = val

   if block_infos is not None:
     for invar, block_info in zip(jaxpr.invars, block_infos):
-      block_info_env[invar] = block_info
+      if block_info is not None:
+        block_info_env[invar] = block_info

   map(write_env, jaxpr.invars, args)

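Note: both edits in this hunk are behavior-preserving for lookups. dict.get(key) already returns None for a missing key, so dropping the explicit default changes nothing, and skipping None entries when populating the env means the env only ever holds real BlockInfo objects. A minimal, self-contained sketch with hypothetical keys:

# Minimal sketch of the env-lookup behavior (hypothetical invar -> BlockInfo map).
block_info_env = {}
entries = {"x": "block_info_for_x", "y": None}
for invar, block_info in entries.items():
    if block_info is not None:
        block_info_env[invar] = block_info

assert block_info_env.get("x") == "block_info_for_x"
assert block_info_env.get("y") is None  # missing key, same as .get("y", None)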
@@ -393,7 +394,7 @@ def lower_jaxpr_to_triton_ir(
       raise  # We only add the extra info to the innermost exception.
     except Exception as e:
       if not pallas_call._verbose_errors_enabled():
-        raise
+        raise
       inval_types = map(lambda t: getattr(t, "type", None), invals)
       raise LoweringError(
           f"Exception while lowering eqn:\n  {eqn}\nWith context:\n  "
@@ -474,14 +475,14 @@ def _atomic_lowering_rule(
     args_tree,
     atomic_type: primitives.AtomicOpType,
 ):
+  block_info, *_ = ctx.block_infos
+  assert block_info is not None
   ptr, indexers, val, mask = args_tree.unflatten(args_flat)
   *_, value_aval, mask_aval = args_tree.unflatten(ctx.avals_in)
   if len(indexers) != 1:
     raise NotImplementedError("Only single indexer is supported.")
   idx = indexers[0]
-  ptr = _compute_pointers_from_indices(
-      ptr, ctx.block_infos[0], idx, ctx.avals_in[0]
-  )
+  ptr = _compute_pointers_from_indices(ptr, block_info, idx)
   val = _ensure_ir_value(val, value_aval)
   if mask is not None:
     mask = _ensure_ir_value(mask, mask_aval)
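Note: the same two-line pattern recurs in every memory-ref rule touched by this commit (atomic, masked load, masked swap, addupdate). A compressed, hypothetical sketch of the new shape of those rules; everything suffixed _sketch is a stand-in, and the pointer computation is passed in rather than imported:

# Hypothetical condensation of the pattern this commit applies to the rules below.
def _memory_ref_rule_sketch(ctx, ptr, indexers, compute_ptrs_sketch):
  # BlockInfo is now guaranteed for memory refs, so rules assert on it instead
  # of threading ctx.avals_in[0] through as a fallback.
  block_info, *_ = ctx.block_infos
  assert block_info is not None
  if len(indexers) != 1:
    raise NotImplementedError("Only single indexer is supported.")
  idx = indexers[0]
  # Old call: compute_ptrs(ptr, ctx.block_infos[0], idx, ctx.avals_in[0])
  # New call: the block info alone supplies full shape/dtype and start indices.
  return compute_ptrs_sketch(ptr, block_info, idx)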
@@ -1631,21 +1632,10 @@ def _reshape_lowering_rule(


 def _compute_pointers_from_indices(
-    root_ptr: ir.Value,
-    block_info: BlockInfo | None,
-    nd_indexer: NDIndexer,
-    array_shape_dtype: Any,
+    root_ptr: ir.Value, block_info: BlockInfo, nd_indexer: NDIndexer
 ) -> ir.Value:
-  if block_info is None:  # TODO(necula): is this branch dead?
-    full_shape = array_shape_dtype.shape
-    num_mapped_dims = 0
-    block_shape = array_shape_dtype.shape
-  else:
-    full_shape = block_info.full_shape_dtype.shape
-    num_mapped_dims = sum(
-        b is pallas_core.mapped for b in block_info.block_shape
-    )
-    block_shape = block_info.block_shape
+  full_shape = block_info.full_shape_dtype.shape
+  num_mapped_dims = sum(b is pallas_core.mapped for b in block_info.block_shape)
   strides = pallas_utils.strides_from_shape(full_shape)
   indexer_shape = nd_indexer.get_indexer_shape()
   int_indexer_shape = nd_indexer.int_indexer_shape
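Note: the function keys all pointer math off row-major strides of the ref's full shape. The snippet below is a re-implementation for illustration only, not the pallas_utils code itself, under the assumption that strides are expressed in elements with the last dimension contiguous:

# Illustrative re-implementation of row-major stride computation.
import math

def strides_from_shape_sketch(shape: tuple[int, ...]) -> tuple[int, ...]:
    # stride of dim i = product of all trailing dimension sizes (in elements)
    return tuple(math.prod(shape[i + 1:]) for i in range(len(shape)))

assert strides_from_shape_sketch((8, 128, 128)) == (128 * 128, 128, 1)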
@@ -1653,14 +1643,10 @@ def _compute_pointers_from_indices(
   indices = nd_indexer.indices
   other_shape = indexer_shape[len(int_indexer_shape) :]
   other_shape_idx = 0
-  if block_info is None:
-    start_index_offsets = [None] * len(indices)
-  else:
-    start_index_offsets = block_info.start_indices
   assert len(indices) + num_mapped_dims == len(full_shape)
-  assert len(start_index_offsets) == len(full_shape)
+  assert len(block_info.start_indices) == len(full_shape)

-  array_dtype = jnp.dtype(array_shape_dtype.dtype)
+  array_dtype = jnp.dtype(block_info.full_shape_dtype.dtype)
   full_size = math.prod(full_shape) * array_dtype.itemsize
   # Use 64-bit indexing when offset might be >= 2**32 bytes.
   offset_eltype = ir.IntegerType.get_signless(64 if full_size > 2**32 else 32)
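Note: the 32- vs 64-bit decision above depends only on the ref's total size in bytes. A worked example with an illustrative shape (the shape and dtype below are assumptions, not values from the diff):

# Worked example of the offset-width choice.
import math
import jax.numpy as jnp

full_shape = (64 * 1024, 64 * 1024)           # hypothetical ref shape
itemsize = jnp.dtype(jnp.float32).itemsize    # 4 bytes
full_size = math.prod(full_shape) * itemsize  # 16 GiB of addressable bytes
offset_bits = 64 if full_size > 2**32 else 32
assert offset_bits == 64  # byte offsets can exceed 2**32, so index with i64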
@@ -1671,7 +1657,7 @@ def _compute_pointers_from_indices(

   indexer_iter = iter(indices)
   for dim_stride, dim_block_size, start_offset in zip(
-      strides, block_shape, start_index_offsets
+      strides, block_info.block_shape, block_info.start_indices
   ):
     if dim_block_size is pallas_core.mapped:
       index = _ir_constant(0, offset_eltype)
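Note: a scalar model of the per-dimension offset accumulation in this loop, for intuition only. The real code builds Triton IR values and broadcasts over indexer shapes; MAPPED below stands in for the pallas_core.mapped sentinel, and the "(start + index) * stride" accumulation is a simplification under that assumption.

# Scalar sketch: a mapped dimension contributes only its start offset, since
# its in-block index is pinned to 0 (mirroring _ir_constant(0, offset_eltype)).
MAPPED = object()

def flat_offset_sketch(strides, block_shape, start_indices, other_indices):
  offset = 0
  it = iter(other_indices)  # indices for the non-mapped dimensions
  for dim_stride, dim_block_size, start_offset in zip(
      strides, block_shape, start_indices
  ):
    index = 0 if dim_block_size is MAPPED else next(it)
    offset += (start_offset + index) * dim_stride
  return offset

# Block (mapped, 64) starting at (3, 128) in an (8, 1024) ref, element (., 5):
assert flat_offset_sketch((1024, 1), (MAPPED, 64), (3, 128), (5,)) == 3205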
@@ -1831,6 +1817,8 @@ def _masked_load_lowering_rule(
     cache_modifier,
     is_volatile,
 ):
+  block_info, *_ = ctx.block_infos
+  assert block_info is not None
   ptr, indexers, mask, other = args_tree.unflatten(args_flat)
   *_, mask_aval, other_aval = args_tree.unflatten(ctx.avals_in)
   if len(indexers) > 1:
@@ -1839,9 +1827,7 @@ def _masked_load_lowering_rule(
   if not tt_dialect.PointerType.isinstance(ptr.type):
     assert len(ctx.avals_in) == 1
     return ptr
-  ptr = _compute_pointers_from_indices(
-      ptr, ctx.block_infos[0], idx, ctx.avals_in[0]
-  )
+  ptr = _compute_pointers_from_indices(ptr, block_info, idx)
   if mask is not None:
     mask = _bcast_to(_ensure_ir_value(mask, mask_aval), idx.get_indexer_shape())
   if other is not None:
@@ -1931,14 +1917,14 @@ def _store(
 def _masked_swap_lowering_rule(
     ctx: LoweringRuleContext, *args_flat, args_tree, eviction_policy
 ):
+  block_info, *_ = ctx.block_infos
+  assert block_info is not None
   ptr, indexers, value, mask = args_tree.unflatten(args_flat)
   *_, value_aval, mask_aval = args_tree.unflatten(ctx.avals_in)
   if len(indexers) > 1:
     raise NotImplementedError("No support for multiple indexers yet.")
   idx = indexers[0]
-  ptr = _compute_pointers_from_indices(
-      ptr, ctx.block_infos[0], idx, ctx.avals_in[0]
-  )
+  ptr = _compute_pointers_from_indices(ptr, block_info, idx)
   other = None
   if value is not None:
     value = _ensure_ir_value(value, value_aval)
@@ -1954,6 +1940,8 @@ def _masked_swap_lowering_rule(

 @register_lowering(sp.addupdate_p)
 def _addupdate_lowering_rule(ctx: LoweringRuleContext, ptr, value, *idx, tree):
+  block_info, *_ = ctx.block_infos
+  assert block_info is not None
   indexers = tree_util.tree_unflatten(tree, idx)
   if not tt_dialect.PointerType.isinstance(ptr.type):
     assert len(indexers) == 0
@@ -1961,9 +1949,7 @@ def _addupdate_lowering_rule(ctx: LoweringRuleContext, ptr, value, *idx, tree):
   if len(indexers) > 1:
     raise NotImplementedError("No support for multiple indexers yet.")
   indexer = indexers[0]
-  ptr = _compute_pointers_from_indices(
-      ptr, ctx.block_infos[0], indexer, ctx.avals_in[0]
-  )
+  ptr = _compute_pointers_from_indices(ptr, block_info, indexer)
   op = tt_dialect.RMWOp.FADD
   if isinstance(_element_type(value.type), ir.IntegerType):
     op = tt_dialect.RMWOp.ADD
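Note: the unchanged context at the end of this hunk shows the addupdate rule defaulting to a floating-point atomic add and switching to the integer variant when the element type is integral. A dialect-free sketch of that decision; RMWOpSketch is a stand-in for tt_dialect.RMWOp, not the real enum:

# Dialect-free sketch of the RMW-op choice.
import enum

class RMWOpSketch(enum.Enum):
  FADD = "fadd"  # floating-point atomic add
  ADD = "add"    # integer atomic add

def pick_rmw_op_sketch(is_integer_elt: bool) -> RMWOpSketch:
  op = RMWOpSketch.FADD
  if is_integer_elt:
    op = RMWOpSketch.ADD
  return op

assert pick_rmw_op_sketch(False) is RMWOpSketch.FADD
assert pick_rmw_op_sketch(True) is RMWOpSketch.ADD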