[Pallas/MGPU] Skip output transfers when they don't depend on sequential dims

Note that, thanks to the existing revisiting-related checks, we weren't performing the transfers anyway, but this way we also avoid paying for the checks themselves.

PiperOrigin-RevId: 679516275
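The change detects sequential-invariance with the new `_uses_arguments` helper: trace the output's index map to a jaxpr and ask dead-code elimination which inputs actually influence its outputs. Below is a minimal standalone sketch of that trick; the toy index maps and the `jax.make_jaxpr` tracing are assumptions for illustration (the commit receives the pre-traced `bm.index_map_jaxpr`), and `pe.dce_jaxpr` is internal `jax._src` API that may change between releases.

# Illustrative sketch, not the commit's code: which arguments of a traced
# function can affect its outputs? Uses internal jax._src APIs.
import jax
from jax._src.interpreters import partial_eval as pe

def uses_arguments(fn, num_args):
  closed_jaxpr = jax.make_jaxpr(fn)(*([0] * num_args))
  jaxpr = closed_jaxpr.jaxpr
  # dce_jaxpr returns (pruned_jaxpr, used_inputs); used_inputs[i] is True
  # iff input i can influence any of the (all marked used) outputs.
  return pe.dce_jaxpr(jaxpr, used_outputs=[True] * len(jaxpr.outvars))[1]

# An output index map that ignores the sequential axis (argument 1): the
# same block is written at every sequential step, so per-step transfers
# and revisit checks can be skipped for this output.
print(uses_arguments(lambda i, k: (i, 0), 2))  # [True, False]
# One that depends on it: each step may write a different block.
print(uses_arguments(lambda i, k: (i, k), 2))  # [True, True]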
commit 5740ab3b02
parent afaf8b823d
@@ -201,6 +201,11 @@ def _eval_index_map(
   return tuple(result)
 
 
+def _uses_arguments(cjaxpr: jax_core.ClosedJaxpr) -> list[bool]:
+  jaxpr = cjaxpr.jaxpr
+  return pe.dce_jaxpr(jaxpr, used_outputs=[True] * len(jaxpr.outvars))[1]
+
+
 def lower_jaxpr_to_module(
     grid_mapping: pallas_core.GridMapping,
     jaxpr: jax_core.Jaxpr,
@@ -270,8 +275,13 @@ def lower_jaxpr_to_module(
     )
     [sequential_axis] = sequential_axes
     num_steps = grid_mapping.grid[sequential_axis]
+    out_sequential_invariant = [
+        not _uses_arguments(bm.index_map_jaxpr)[sequential_axis]
+        for bm in grid_mapping.block_mappings_output
+    ]
   else:
     num_steps = 1
+    out_sequential_invariant = [True] * len(grid_mapping.out_shapes)
 
   in_in_smem, out_in_smem = util.split_list(
       [
@@ -429,36 +439,42 @@ def lower_jaxpr_to_module(
   )
 
   def store(
-      idx: int, step: ir.Value, slot: ir.Value, prev_base_offset: ir.Value
-  ) -> ir.Value:
+      idx: int, step: ir.Value, slot: ir.Value, prev_base_offset: ir.Value | None
+  ) -> ir.Value | None:
     if not out_in_smem[idx]:
       return _as_index(-1)
 
-    # We have to do some work to make sure that consecutive stores are not
-    # going to be writing to the same location, or else we'll end up with
-    # multiple concurrent writes and a racy program.
-    # TODO(apaszke,slebedev): In most cases output index maps depend only on
-    # parallel grid axes and in that case we can simply move the store to
-    # happen after the loop.
-    # TODO(apaszke,slebedev): This still diverges significantly from the TPU
-    # semantics in that it will move on to the next SMEM output slice even if
-    # it's not storing the previous one.
     store_slice = gmem_slice(step, out_block_mappings[idx])
-    strides, _ = ir.MemRefType(out_buffers_gmem[idx].type).get_strides_and_offset()
-    base_offset = _as_index(0)
-    for stride, slc in zip(strides, store_slice):
-      base_offset = arith_dialect.addi(
-          base_offset, arith_dialect.muli(slc.base, _as_index(stride))
+    if out_sequential_invariant[idx]:
+      assert prev_base_offset is None
+      do_store = None  # Lack of predicate defaults to True.
+      base_offset = None
+    else:
+      assert prev_base_offset is not None
+      # We have to do some work to make sure that consecutive stores are not
+      # going to be writing to the same location, or else we'll end up with
+      # multiple concurrent writes and a racy program.
+      # TODO(apaszke,slebedev): In most cases output index maps depend only on
+      # parallel grid axes and in that case we can simply move the store to
+      # happen after the loop.
+      # TODO(apaszke,slebedev): This still diverges significantly from the TPU
+      # semantics in that it will move on to the next SMEM output slice even if
+      # it's not storing the previous one.
+      strides, _ = ir.MemRefType(out_buffers_gmem[idx].type).get_strides_and_offset()
+      base_offset = _as_index(0)
+      for stride, slc in zip(strides, store_slice):
+        base_offset = arith_dialect.addi(
+            base_offset, arith_dialect.muli(slc.base, _as_index(stride))
+        )
+      base_offset_changed = arith_dialect.cmpi(
+          arith_dialect.CmpIPredicate.ne, base_offset, prev_base_offset
       )
-    base_offset_changed = arith_dialect.cmpi(
-        arith_dialect.CmpIPredicate.ne, base_offset, prev_base_offset
-    )
-    is_last_step = arith_dialect.cmpi(
-        arith_dialect.CmpIPredicate.eq, step, _as_index(num_steps - 1)
-    )
-    do_store = arith_dialect.andi(
-        is_memory_thread, arith_dialect.ori(base_offset_changed, is_last_step)
-    )
+      is_last_step = arith_dialect.cmpi(
+          arith_dialect.CmpIPredicate.eq, step, _as_index(num_steps - 1)
+      )
+      do_store = arith_dialect.andi(
+          is_memory_thread, arith_dialect.ori(base_offset_changed, is_last_step)
+      )
     # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls.
     launch_ctx.async_copy(
         src_ref=mgpu.memref_slice(out_buffers_smem[idx], slot),
@@ -475,7 +491,7 @@ def lower_jaxpr_to_module(
     for idx in range(grid_mapping.num_inputs):
       fetch(idx, _as_index(slot), _as_index(slot))
 
-  last_store_offsets = [_as_index(-1)] * grid_mapping.num_outputs
+  last_store_offsets = [None if inv else _as_index(-1) for inv in out_sequential_invariant]
   @mgpu.fori(_as_index(num_steps), (accs, last_store_offsets))
   def _(step, carry):
     accs, last_store_offsets = carry
@@ -510,8 +526,11 @@ def lower_jaxpr_to_module(
     mgpu.commit_shared()
     new_store_offsets = []
     for idx in range(grid_mapping.num_outputs):
+      last_offset = last_store_offsets[idx]
       new_store_offsets.append(
-          store(idx, step, slot, last_store_offsets[idx])
+          store(idx, step, slot, last_offset)
+          if not out_sequential_invariant[idx]
+          else last_offset  # Only store if the output can depend on the step.
       )
 
     next_step = arith_dialect.addi(step, _as_index(max_concurrent_steps))
@@ -526,6 +545,13 @@ def lower_jaxpr_to_module(
 
     return list(new_accs), new_store_offsets
 
+  # Outputs invariant to the sequential axis are never written from inside the
+  # loop. This is the only place where we store them.
+  last_slot = _as_index((num_steps - 1) % max_concurrent_steps)
+  for idx in range(grid_mapping.num_outputs):
+    if out_sequential_invariant[idx]:
+      store(idx, _as_index(0), last_slot, None)
+
   launch_ctx.await_async_copy(0)
 
   scratch_avals = [
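For intuition, here is a hedged, pure-Python model of the predication that `store` still computes for sequential-dependent outputs; the function and parameter names are made up for the sketch, and the real lowering additionally gates on `is_memory_thread` and operates on MLIR index values rather than Python ints.

# Pure-Python model of the store predicate (illustration only): a transfer is
# issued when the destination base offset changed since the previous step, or
# on the last step. The real code also ANDs in is_memory_thread.
def issued_steps(num_steps, index_map, strides):
  issued = []
  prev_base = None  # Plays the role of the _as_index(-1) sentinel.
  for step in range(num_steps):
    # Offset of the destination slice computed from the strides, mirroring
    # the addi/muli chain in `store`.
    base = sum(idx * stride for idx, stride in zip(index_map(step), strides))
    if base != prev_base or step == num_steps - 1:
      issued.append(step)
    prev_base = base
  return issued

# An output index map that revisits each block twice along the sequential
# axis: only half the steps (plus the final one) pay for a transfer.
print(issued_steps(6, lambda step: (step // 2,), (128,)))  # [0, 2, 4, 5]

Outputs whose index map is invariant to the sequential axis never enter this predicate at all: after this commit they are stored exactly once, after the loop, from the last SMEM slot `(num_steps - 1) % max_concurrent_steps`, with no per-step check to pay for.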