[pallas:triton] Simplify lowering code. BlockInfo is now always present for memory refs.

PiperOrigin-RevId: 695414469
Chris Jones 2024-11-11 11:10:06 -08:00 committed by jax authors
parent 242ac2b92d
commit 0995bc231c
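For orientation, here is a minimal, self-contained sketch (not part of the commit) of the record that the lowering below now assumes is present for every memory ref. The Mapped sentinel and the plain-tuple shape field are simplified stand-ins for pallas_core.Mapped and jax.ShapeDtypeStruct.

import dataclasses
from typing import Any, Sequence


class _Mapped:
  """Stand-in for pallas_core.Mapped: a block dimension collapsed by the grid
  mapping, for which no index is taken when computing pointers."""


mapped = _Mapped()  # stand-in for pallas_core.mapped


@dataclasses.dataclass(frozen=True)
class BlockInfo:
  full_shape: tuple[int, ...]             # shape of the full ref in memory
  start_indices: Sequence[Any]            # per-dimension block start offsets
  block_shape: tuple[int | _Mapped, ...]  # block dims; mapped where collapsed

With block info always available in this form, call sites no longer need a separate array_shape_dtype argument, which is what the hunks below remove.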


@@ -87,7 +87,7 @@ class ModuleContext:
 class BlockInfo:
   full_shape_dtype: jax.ShapeDtypeStruct
   start_indices: Sequence[Any]
-  block_shape: tuple[int, ...]  # TODO(necula): can this contain "mapped"?
+  block_shape: tuple[int | pallas_core.Mapped, ...]
 @dataclasses.dataclass
@@ -95,7 +95,7 @@ class LoweringRuleContext:
   context: ModuleContext
   avals_in: Sequence[jax_core.ShapedArray]
   avals_out: Sequence[jax_core.ShapedArray]
-  block_infos: Sequence[BlockInfo | None]  # TODO(necula): can this be None?
+  block_infos: Sequence[BlockInfo | None]
   replace = dataclasses.replace
@@ -362,14 +362,15 @@ def lower_jaxpr_to_triton_ir(
   def read_block_info_env(atom: jax_core.Atom):
     if isinstance(atom, jax_core.Literal):
       return None
-    return block_info_env.get(atom, None)
+    return block_info_env.get(atom)
   def write_env(var: jax_core.Var, val):
     env[var] = val
   if block_infos is not None:
     for invar, block_info in zip(jaxpr.invars, block_infos):
-      block_info_env[invar] = block_info
+      if block_info is not None:
+        block_info_env[invar] = block_info
   map(write_env, jaxpr.invars, args)
@@ -393,7 +394,7 @@ def lower_jaxpr_to_triton_ir(
         raise  # We only add the extra info to the innermost exception.
       except Exception as e:
         if not pallas_call._verbose_errors_enabled():
-          raise
+          raise
         inval_types = map(lambda t: getattr(t, "type", None), invals)
         raise LoweringError(
             f"Exception while lowering eqn:\n {eqn}\nWith context:\n "
@@ -474,14 +475,14 @@ def _atomic_lowering_rule(
     args_tree,
     atomic_type: primitives.AtomicOpType,
 ):
+  block_info, *_ = ctx.block_infos
+  assert block_info is not None
   ptr, indexers, val, mask = args_tree.unflatten(args_flat)
   *_, value_aval, mask_aval = args_tree.unflatten(ctx.avals_in)
   if len(indexers) != 1:
     raise NotImplementedError("Only single indexer is supported.")
   idx = indexers[0]
-  ptr = _compute_pointers_from_indices(
-      ptr, ctx.block_infos[0], idx, ctx.avals_in[0]
-  )
+  ptr = _compute_pointers_from_indices(ptr, block_info, idx)
   val = _ensure_ir_value(val, value_aval)
   if mask is not None:
     mask = _ensure_ir_value(mask, mask_aval)
@@ -1631,21 +1632,10 @@ def _reshape_lowering_rule(
 def _compute_pointers_from_indices(
-    root_ptr: ir.Value,
-    block_info: BlockInfo | None,
-    nd_indexer: NDIndexer,
-    array_shape_dtype: Any,
+    root_ptr: ir.Value, block_info: BlockInfo, nd_indexer: NDIndexer
 ) -> ir.Value:
-  if block_info is None:  # TODO(necula): is this branch dead?
-    full_shape = array_shape_dtype.shape
-    num_mapped_dims = 0
-    block_shape = array_shape_dtype.shape
-  else:
-    full_shape = block_info.full_shape_dtype.shape
-    num_mapped_dims = sum(
-        b is pallas_core.mapped for b in block_info.block_shape
-    )
-    block_shape = block_info.block_shape
+  full_shape = block_info.full_shape_dtype.shape
+  num_mapped_dims = sum(b is pallas_core.mapped for b in block_info.block_shape)
   strides = pallas_utils.strides_from_shape(full_shape)
   indexer_shape = nd_indexer.get_indexer_shape()
   int_indexer_shape = nd_indexer.int_indexer_shape
@@ -1653,14 +1643,10 @@ def _compute_pointers_from_indices(
   indices = nd_indexer.indices
   other_shape = indexer_shape[len(int_indexer_shape) :]
   other_shape_idx = 0
-  if block_info is None:
-    start_index_offsets = [None] * len(indices)
-  else:
-    start_index_offsets = block_info.start_indices
   assert len(indices) + num_mapped_dims == len(full_shape)
-  assert len(start_index_offsets) == len(full_shape)
+  assert len(block_info.start_indices) == len(full_shape)
-  array_dtype = jnp.dtype(array_shape_dtype.dtype)
+  array_dtype = jnp.dtype(block_info.full_shape_dtype.dtype)
   full_size = math.prod(full_shape) * array_dtype.itemsize
   # Use 64-bit indexing when offset might be >= 2**32 bytes.
   offset_eltype = ir.IntegerType.get_signless(64 if full_size > 2**32 else 32)
@@ -1671,7 +1657,7 @@ def _compute_pointers_from_indices(
   indexer_iter = iter(indices)
   for dim_stride, dim_block_size, start_offset in zip(
-      strides, block_shape, start_index_offsets
+      strides, block_info.block_shape, block_info.start_indices
   ):
     if dim_block_size is pallas_core.mapped:
       index = _ir_constant(0, offset_eltype)
@@ -1831,6 +1817,8 @@ def _masked_load_lowering_rule(
     cache_modifier,
     is_volatile,
 ):
+  block_info, *_ = ctx.block_infos
+  assert block_info is not None
   ptr, indexers, mask, other = args_tree.unflatten(args_flat)
   *_, mask_aval, other_aval = args_tree.unflatten(ctx.avals_in)
   if len(indexers) > 1:
@@ -1839,9 +1827,7 @@ def _masked_load_lowering_rule(
   if not tt_dialect.PointerType.isinstance(ptr.type):
     assert len(ctx.avals_in) == 1
     return ptr
-  ptr = _compute_pointers_from_indices(
-      ptr, ctx.block_infos[0], idx, ctx.avals_in[0]
-  )
+  ptr = _compute_pointers_from_indices(ptr, block_info, idx)
   if mask is not None:
     mask = _bcast_to(_ensure_ir_value(mask, mask_aval), idx.get_indexer_shape())
   if other is not None:
@@ -1931,14 +1917,14 @@ def _store(
 def _masked_swap_lowering_rule(
     ctx: LoweringRuleContext, *args_flat, args_tree, eviction_policy
 ):
+  block_info, *_ = ctx.block_infos
+  assert block_info is not None
   ptr, indexers, value, mask = args_tree.unflatten(args_flat)
   *_, value_aval, mask_aval = args_tree.unflatten(ctx.avals_in)
   if len(indexers) > 1:
     raise NotImplementedError("No support for multiple indexers yet.")
   idx = indexers[0]
-  ptr = _compute_pointers_from_indices(
-      ptr, ctx.block_infos[0], idx, ctx.avals_in[0]
-  )
+  ptr = _compute_pointers_from_indices(ptr, block_info, idx)
   other = None
   if value is not None:
     value = _ensure_ir_value(value, value_aval)
@@ -1954,6 +1940,8 @@ def _masked_swap_lowering_rule(
 @register_lowering(sp.addupdate_p)
 def _addupdate_lowering_rule(ctx: LoweringRuleContext, ptr, value, *idx, tree):
+  block_info, *_ = ctx.block_infos
+  assert block_info is not None
   indexers = tree_util.tree_unflatten(tree, idx)
   if not tt_dialect.PointerType.isinstance(ptr.type):
     assert len(indexers) == 0
@@ -1961,9 +1949,7 @@ def _addupdate_lowering_rule(ctx: LoweringRuleContext, ptr, value, *idx, tree):
   if len(indexers) > 1:
     raise NotImplementedError("No support for multiple indexers yet.")
   indexer = indexers[0]
-  ptr = _compute_pointers_from_indices(
-      ptr, ctx.block_infos[0], indexer, ctx.avals_in[0]
-  )
+  ptr = _compute_pointers_from_indices(ptr, block_info, indexer)
   op = tt_dialect.RMWOp.FADD
   if isinstance(_element_type(value.type), ir.IntegerType):
     op = tt_dialect.RMWOp.ADD
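As a rough illustration of the pointer arithmetic that _compute_pointers_from_indices keeps once a BlockInfo is guaranteed, here is a simplified, pure-Python sketch (assumed behavior, not the library's API): a mapped block dimension consumes no caller index and contributes only its start offset, while every other dimension contributes (start + index) * stride. The real lowering builds Triton IR values, handles NDIndexer slices and integer arrays, and switches to 64-bit offsets for refs larger than 2**32 bytes.

MAPPED = object()  # stand-in for pallas_core.mapped


def strides_from_shape(shape):
  # Row-major element strides, e.g. (8, 128) -> (128, 1).
  strides, acc = [], 1
  for dim in reversed(shape):
    strides.append(acc)
    acc *= dim
  return tuple(reversed(strides))


def element_offset(full_shape, block_shape, start_indices, indices):
  # One caller index per non-mapped dimension; mapped dims are pinned to
  # their block start, mirroring the "dim_block_size is pallas_core.mapped"
  # branch in the lowering.
  assert len(indices) + sum(b is MAPPED for b in block_shape) == len(full_shape)
  index_iter = iter(indices)
  offset = 0
  for stride, block_dim, start in zip(
      strides_from_shape(full_shape), block_shape, start_indices
  ):
    idx = 0 if block_dim is MAPPED else next(index_iter)
    offset += (start + idx) * stride
  return offset


# Example: an (8, 128) ref whose leading dimension is mapped to grid step 4.
assert element_offset((8, 128), (MAPPED, 128), (4, 0), [7]) == 4 * 128 + 7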