[pallas:mosaic_gpu] Ported two pipelining optimizations to emit_pipeline

* Skip SMEM->GMEM copy if the destination buffer is being revisited
* Skip SMEM->GMEM copy if the corresponding index map does not use grid indices

PiperOrigin-RevId: 696448043
Authored by Sergei Lebedev on 2024-11-14 02:37:08 -08:00 (committed by jax authors)
parent 83700828c5
commit aefe6215ca
5 changed files with 195 additions and 28 deletions
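The two skipped-copy cases map directly onto the shape of the output BlockSpec's index map. A rough illustration (not part of the diff; the specs below are made up for the example):

from jax.experimental import pallas as pl

# Case 1: with grid=(4, 4), this output block depends only on `i`, so the same
# GMEM block is revisited for every `j`. The SMEM->GMEM copy is now skipped
# until the block index changes (or the last pipeline step is reached).
revisited_out_spec = pl.BlockSpec((64, 64), lambda i, j: (i, 0))

# Case 2: this index map ignores the grid indices entirely, so the output is
# index-invariant and is written exactly once, after the pipeline loop.
invariant_out_spec = pl.BlockSpec((64, 64), lambda i, j: (0, 0))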


@@ -107,7 +107,10 @@ pytype_strict_library(
":core",
":primitives",
"//jax",
"//jax:core",
"//jax:mosaic_gpu",
"//jax:pallas",
"//jax:partial_eval",
"//jax:util",
"//jax/_src/pallas",
],


@@ -264,6 +264,7 @@ class ModuleContext:
class LoweringRuleContext:
module_ctx: ModuleContext
launch_ctx: mgpu.LaunchContext
predicate: ir.Value
avals_in: Sequence[jax_core.ShapedArray]
avals_out: Sequence[jax_core.ShapedArray]
@@ -878,6 +879,7 @@ def lower_jaxpr_to_mosaic_gpu(
rule_ctx = LoweringRuleContext(
module_ctx,
launch_ctx,
predicate=mgpu.single_thread_predicate(per_block=False),
avals_in=[cast(jax_core.ShapedArray, v.aval) for v in eqn.invars],
avals_out=[cast(jax_core.ShapedArray, v.aval) for v in eqn.outvars],
)
@@ -1120,6 +1122,12 @@ def _convert_element_type_lowering_rule(
)
mosaic_lowering_rules.update({
lax.neg_p: lambda ctx, x: -x,
lax.not_p: lambda ctx, x: ~x,
})
def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl):
x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
return impl(x, y)
@@ -1576,4 +1584,4 @@ def _as_index(v: object) -> ir.Value:
case mgpu.FragmentedArray(layout=mgpu.WGSplatFragLayout()):
return _as_index(v.registers.item())
case _:
raise ValueError(f"Unsupported index: {v}")
raise ValueError(f"Unsupported index: {v} of type {type(v)}")


@@ -16,7 +16,7 @@
from __future__ import annotations
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import dataclasses
import functools
import itertools as it
@@ -25,7 +25,10 @@ from typing import Any
import jax
from jax import lax
from jax._src import core
from jax._src import linear_util as lu
from jax._src import util
from jax._src.interpreters import partial_eval as pe
from jax._src.pallas import core as pallas_core
from jax._src.pallas.mosaic_gpu import core as gpu_core
from jax._src.pallas.mosaic_gpu import primitives as gpu_primitives
@@ -37,17 +40,19 @@ map = util.safe_map
zip = util.safe_zip
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class BufferedRef:
spec: pallas_core.BlockSpec
spec: pallas_core.BlockSpec = dataclasses.field(metadata={"static": True})
is_index_invariant: bool = dataclasses.field(metadata={"static": True})
gmem_ref: pallas_core.AbstractMemoryRef
smem_ref: pallas_core.AbstractMemoryRef # [num_slots, *spec.block_shape]
def compute_gmem_slice(self, grid_indices) -> tuple[Any, ...]:
def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice, ...]:
index_map = self.spec.index_map
assert index_map is not None
return tuple(
pl.ds(idx * size, size)
pl.Slice(idx * size, size) # type: ignore[arg-type]
for idx, size in zip(
index_map(*grid_indices), self.spec.block_shape # type: ignore[arg-type]
)
@@ -61,16 +66,31 @@ class BufferedRef:
barrier=barrier_ref.at[slot],
)
def copy_out(self, slot, grid_indices):
def copy_out(self, slot, grid_indices, predicate=None):
gmem_slices = self.compute_gmem_slice(grid_indices)
gpu_primitives.copy_smem_to_gmem(
self.smem_ref.at[slot], self.gmem_ref.at[gmem_slices] # pytype: disable=unsupported-operands
self.smem_ref.at[slot],
self.gmem_ref.at[gmem_slices], # pytype: disable=unsupported-operands
predicate=predicate,
)
jax.tree_util.register_dataclass(
BufferedRef, data_fields=["gmem_ref", "smem_ref"], meta_fields=["spec"]
)
def _uses_arguments(
index_map: Callable[..., Any], num_args: int
) -> Sequence[bool]:
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(index_map), (core.ShapedArray((), jnp.int32),) * num_args
)
_, used_inputs = pe.dce_jaxpr(jaxpr, used_outputs=[True] * len(jaxpr.outvars))
return used_inputs
def _is_index_invariant(
spec: pallas_core.BlockSpec, grid: pallas_core.StaticGrid
) -> bool:
index_map = spec.index_map
assert index_map is not None
return not any(_uses_arguments(index_map, len(grid)))
def _inc_grid_by_1(
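Illustration of how the helpers above detect index invariance (a hypothetical REPL session against the private helpers; they trace the index map to a jaxpr on scalar int32 arguments and run DCE to see which grid indices are actually read):

_uses_arguments(lambda i, j: (i, 0), 2)  # -> [True, False]
_uses_arguments(lambda i, j: (0, 0), 2)  # -> [False, False]
_is_index_invariant(pl.BlockSpec((32, 16), lambda i, j: (0, 0)), (4, 4))  # -> True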
@@ -85,6 +105,25 @@ def _inc_grid_by_1(
return tuple(reversed(next_indices))
# ``pl.Slice`` uses a different pytree encoding, depending on whether the
# start/size are static or dynamic. This leads to pytree structure mismatch
# in the pipeline body. So, we define a different ``Slice`` class below.
@dataclasses.dataclass(frozen=True)
class _Slice:
start: int | jax.Array
size: int | jax.Array
def __eq__(self, other: _Slice) -> jax.Array: # type: ignore
return lax.bitwise_and(self.start == other.start, self.size == other.size)
jax.tree_util.register_dataclass(
_Slice, data_fields=["start", "size"], meta_fields=[]
)
def emit_pipeline(
body,
*,
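The comment above refers to pytree flattening: pl.Slice keeps statically known start/size values in the treedef and only dynamic ones as leaves, so slices built from constants and slices built from traced grid indices have different treedefs and cannot share a fori_loop carry slot. A small sketch of the mismatch, assuming the usual pl.ds helper:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

print(jax.tree.structure(pl.ds(0, 16)).num_leaves)             # expected: 0 (all static)
print(jax.tree.structure(pl.ds(jnp.int32(3), 16)).num_leaves)  # expected: 1 (dynamic start)

_Slice above always flattens both fields as leaves, so its structure stays the same whether the values are Python ints or traced arrays.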
@@ -102,6 +141,16 @@ def emit_pipeline(
max_concurrent_steps = num_steps
def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef):
for gmem_ref, spec in zip(gmem_refs, it.chain(in_specs, out_specs)):
if any(
spec.block_shape[-idx] * grid[-idx] != gmem_ref.shape[-idx] # type: ignore
for idx in range(1, len(grid) + 1)
):
raise NotImplementedError(
f"Cannot emit a pipeline over the {grid=} for {gmem_ref} with block"
f" shape {spec.block_shape}."
)
in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)])
in_smem_refs, out_smem_refs = util.split_list(
map(
@@ -132,13 +181,18 @@ def emit_pipeline(
def scoped_pipeline(
*, in_gmem_refs, out_gmem_refs, in_smem_refs, out_smem_refs, barrier_ref
):
in_brefs: Sequence[BufferedRef] = map(
BufferedRef, in_specs, in_gmem_refs, in_smem_refs
)
out_brefs: Sequence[BufferedRef] = map(
BufferedRef, out_specs, out_gmem_refs, out_smem_refs
)
in_brefs: Sequence[BufferedRef] = [
BufferedRef(spec, _is_index_invariant(spec, grid), gmem_ref, smem_ref)
for spec, gmem_ref, smem_ref in zip(
in_specs, in_gmem_refs, in_smem_refs
)
]
out_brefs: Sequence[BufferedRef] = [
BufferedRef(spec, _is_index_invariant(spec, grid), gmem_ref, smem_ref)
for spec, gmem_ref, smem_ref in zip(
out_specs, out_gmem_refs, out_smem_refs
)
]
for step, indices in enumerate(
it.islice(it.product(*map(range, grid)), max_concurrent_steps)
@@ -147,10 +201,11 @@ def emit_pipeline(
def loop_body(step, carry):
slot = step % max_concurrent_steps
indices, fetch_indices = carry
indices, fetch_indices, last_store_slices = carry
# Wait for the current GMEM->SMEM copy to complete.
gpu_primitives.barrier_wait(barrier_ref.at[slot])
if in_specs:
# Wait for the current GMEM->SMEM copy to complete.
gpu_primitives.barrier_wait(barrier_ref.at[slot])
# Wait for the previous output SMEM->GMEM copy to complete.
gpu_primitives.wait_smem_to_gmem(max_concurrent_steps - 1)
@@ -159,9 +214,34 @@ def emit_pipeline(
*(bref.smem_ref.at[slot] for bref in it.chain(in_brefs, out_brefs))
)
if not all(bref.is_index_invariant for bref in out_brefs):
gpu_primitives.commit_smem()
# Copy the output from SMEM to GMEM.
gpu_primitives.commit_smem()
map(lambda bref: bref.copy_out(slot, indices), out_brefs)
new_store_slices = last_store_slices[:]
for idx, bref in enumerate(out_brefs):
if bref.is_index_invariant:
assert last_store_slices[idx] is None
continue
assert last_store_slices[idx] is not None
new_store_slices[idx] = tuple(
_Slice(s.start, s.size) for s in bref.compute_gmem_slice(indices)
)
are_same_slices = map(
lambda old, new: old == new,
last_store_slices[idx],
new_store_slices[idx],
)
slices_changed = ~functools.reduce(lax.bitwise_and, are_same_slices)
is_last_step = step == num_steps - 1
# TODO(apaszke,slebedev): This still diverges significantly from the
# TPU semantics in that it will move on to the next SMEM output slice
# even if it's not storing the previous one.
bref.copy_out(
slot,
indices,
predicate=lax.bitwise_or(slices_changed, is_last_step),
)
fetch_step = step + max_concurrent_steps
fetch_slot = slot # (x + y) % y == x % y
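The predicate computed in the hunk above amounts to: store only if the output's GMEM slices changed since the previous step, or if this is the last step. A minimal standalone sketch of that reduction (placeholder names; the real code also compares slice sizes via _Slice.__eq__):

import functools
from jax import lax

def store_predicate(old_starts, new_starts, step, num_steps):
  # The destination block is being revisited iff every slice start is unchanged.
  same_slices = functools.reduce(
      lax.bitwise_and,
      [old == new for old, new in zip(old_starts, new_starts)],
  )
  # Store when something changed, or unconditionally on the last step.
  return lax.bitwise_or(~same_slices, step == num_steps - 1)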
@@ -174,13 +254,34 @@ def emit_pipeline(
lambda: [None] * len(in_brefs),
)
return _inc_grid_by_1(indices, grid), _inc_grid_by_1(fetch_indices, grid)
return (
_inc_grid_by_1(indices, grid),
_inc_grid_by_1(fetch_indices, grid),
new_store_slices,
)
indices = (jnp.asarray(0, dtype=lax.dtype(0)),) * len(grid)
fetch_indices = indices
for _ in range(max_concurrent_steps):
fetch_indices = _inc_grid_by_1(fetch_indices, grid)
lax.fori_loop(0, num_steps, loop_body, (indices, fetch_indices))
last_store_slices = [
None
if bref.is_index_invariant
else (_Slice(-1, -1),) * len(bref.spec.block_shape)
for bref in out_brefs
]
last_indices, _, _ = lax.fori_loop(
0, num_steps, loop_body, (indices, fetch_indices, last_store_slices)
)
# Outputs invariant to the sequential axis are never written from inside the
# loop. This is the only place where we store them.
if all(bref.is_index_invariant for bref in out_brefs):
gpu_primitives.commit_smem()
last_slot = (num_steps - 1) % max_concurrent_steps
for bref in out_brefs:
if bref.is_index_invariant:
bref.copy_out(last_slot, last_indices, predicate=None)
# Finalize the pipeline.
gpu_primitives.wait_smem_to_gmem(0)


@@ -26,6 +26,7 @@ from jax._src import state
from jax._src import tree_util
from jax._src import util
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith as arith_dialect
from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect
from jax._src.pallas import core as pallas_core
from jax._src.pallas.mosaic_gpu import core as gpu_core
@@ -34,6 +35,7 @@ from jax._src.state import discharge
from jax._src.state import indexing
from jax._src.state import primitives as state_primitives
import jax.experimental.mosaic.gpu as mgpu
import jax.numpy as jnp
WARPGROUP_SIZE = 128
@@ -54,19 +56,31 @@ def _copy_smem_to_gmem_lowering(
ctx: lowering.LoweringRuleContext,
src,
dst,
*flat_transforms,
*flat_args,
src_transforms_treedef,
dst_transforms_treedef,
has_user_predicate,
):
predicate = ctx.predicate
if has_user_predicate:
flat_args, user_predicate = flat_args[:-1], flat_args[-1]
predicate = arith_dialect.andi(
predicate, lowering._ensure_ir_value(user_predicate, jnp.bool)
)
flat_src_transforms, flat_dst_transforms = util.split_list(
flat_transforms,
flat_args,
[src_transforms_treedef.num_leaves],
)
src_transforms = src_transforms_treedef.unflatten(flat_src_transforms)
dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms)
src, src_transforms = lowering._handle_indexing(src, src_transforms)
copy_params = _extract_gmem_copy_params(dst_transforms) | _extract_smem_copy_params(src_transforms)
ctx.launch_ctx.async_copy(src_ref=src, dst_ref=dst, **copy_params)
ctx.launch_ctx.async_copy(
src_ref=src,
dst_ref=dst,
predicate=predicate,
**copy_params,
)
return ()
@@ -98,10 +112,18 @@ def _extract_smem_copy_params(transforms):
def copy_smem_to_gmem(
src: pallas_core.AbstractMemoryRef, dst: pallas_core.AbstractMemoryRef
src: pallas_core.AbstractMemoryRef,
dst: pallas_core.AbstractMemoryRef,
predicate: jax.Array | None = None,
) -> None:
"""Asynchronously copies a SMEM reference to a GMEM reference.
Args:
src: The SMEM reference to copy from.
dst: The GMEM reference to copy to.
predicate: A boolean indicating whether the copy should be performed. If
``None``, the copy is always performed.
See also:
:func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`
:func:`jax.experimental.mosaic.gpu.commit_smem`
@@ -127,8 +149,10 @@ def copy_smem_to_gmem(
dst,
*flat_src_transforms,
*flat_dst_transforms,
*[] if predicate is None else [predicate],
src_transforms_treedef=src_transforms_treedef,
dst_transforms_treedef=dst_transforms_treedef,
has_user_predicate=predicate is not None,
)
return None
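With the new argument a kernel can make the store itself conditional. A hypothetical usage sketch, assuming these primitives are re-exported as plgpu.* and that o_smem, o_gmem, and should_store exist in the surrounding kernel:

from jax.experimental.pallas import mosaic_gpu as plgpu

plgpu.commit_smem()  # make prior SMEM writes visible to the async copy
plgpu.copy_smem_to_gmem(o_smem, o_gmem, predicate=should_store)
plgpu.wait_smem_to_gmem(0)  # a predicated-off copy is simply never issued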


@@ -1146,6 +1146,37 @@ class PipelineTest(PallasTest):
)
np.testing.assert_array_equal(kernel_fn(x), x + 1.0)
def test_emit_with_grid_invariant_output(self):
num_steps = 4
def kernel(x_gmem, o_gmem):
plgpu.emit_pipeline(
kernel_body,
in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))],
out_specs=[pl.BlockSpec((32, 16), lambda i: (0, 0))],
grid=(num_steps,),
max_concurrent_steps=2,
)(x_gmem, o_gmem)
def kernel_body(x_smem, o_smem):
o_smem[...] = x_smem[...] + 1.0
x = jnp.arange(32 * num_steps * 16)
x = x.reshape(-1, num_steps * 16).astype(jnp.float32)
kernel_fn = pl.pallas_call(
kernel,
in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
)
y = jnp.empty_like(x)
for i in range(num_steps):
i_slice = slice(16 * i, 16 * (i + 1))
y = y.at[:, :16].set(x[:, i_slice] + 1)
# We only compare the elements in the first 16 columns, because the rest
# are never written to.
np.testing.assert_array_equal(kernel_fn(x)[:, :16], y[:, :16])
def test_emit_with_parallel_grid(self):
self.skipTest("Enable once we support multiple levels of indexing")
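For comparison, a kernel exercising the other optimization (a revisited output block) might be set up as follows. This is an illustrative sketch, not taken from the diff, and would be wrapped in pl.pallas_call with GMEM block specs and an input of shape (64, 64), just like the test above:

def kernel_body(x_smem, o_smem):
  o_smem[...] = x_smem[...] + 1.0

def kernel(x_gmem, o_gmem):
  plgpu.emit_pipeline(
      kernel_body,
      in_specs=[pl.BlockSpec((32, 16), lambda i, j: (i, j))],
      # The output block depends only on `i`, so it is revisited for every `j`
      # and is only stored when `i` changes or on the last pipeline step.
      out_specs=[pl.BlockSpec((32, 16), lambda i, j: (i, 0))],
      grid=(2, 4),
      max_concurrent_steps=2,
  )(x_gmem, o_gmem)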