Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
[pallas:mosaic_gpu] Ported two pipelining optimizations to emit_pipeline

* Skip SMEM->GMEM copy if the destination buffer is being revisited
* Skip SMEM->GMEM copy if the corresponding index map does not use grid indices

PiperOrigin-RevId: 696448043
This commit is contained in:
parent
83700828c5
commit
aefe6215ca
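What the two optimizations decide: on every pipeline step the loop computes the output block from the out spec's index map, and the SMEM->GMEM store is issued only when that block differs from the one written on the previous step (or on the last step); an output whose index map ignores the grid indices altogether is stored once, after the loop. A minimal pure-Python sketch of that decision (the helper and the index maps below are illustrative, not part of the change):

def steps_that_store(index_map, num_steps):
  """Which steps issue the SMEM->GMEM store for one output buffer (sketch)."""
  # Optimization 2: if the index map never reads the grid index, the output
  # block is the same on every step, so it is stored once, after the loop.
  # (The real change detects this statically from the index map's jaxpr;
  # here we just probe the map over the grid.)
  if all(index_map(i) == index_map(0) for i in range(num_steps)):
    return [num_steps - 1]
  # Optimization 1: otherwise, skip the store whenever the destination block
  # is the same as on the previous step (the buffer is being revisited).
  # The last step always stores.
  stores, last_block = [], None
  for step in range(num_steps):
    block = index_map(step)
    if block != last_block or step == num_steps - 1:
      stores.append(step)
    last_block = block
  return stores


print(steps_that_store(lambda i: (0, i), 4))       # [0, 1, 2, 3]: every step stores
print(steps_that_store(lambda i: (0, i // 2), 4))  # [0, 2, 3]: revisited block skipped
print(steps_that_store(lambda i: (0, 0), 4))       # [3]: grid-invariant output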
@@ -107,7 +107,10 @@ pytype_strict_library(
        ":core",
        ":primitives",
        "//jax",
        "//jax:core",
        "//jax:mosaic_gpu",
        "//jax:pallas",
        "//jax:partial_eval",
        "//jax:util",
        "//jax/_src/pallas",
    ],

@@ -264,6 +264,7 @@ class ModuleContext:

 class LoweringRuleContext:
   module_ctx: ModuleContext
   launch_ctx: mgpu.LaunchContext
+  predicate: ir.Value
   avals_in: Sequence[jax_core.ShapedArray]
   avals_out: Sequence[jax_core.ShapedArray]

@@ -878,6 +879,7 @@ def lower_jaxpr_to_mosaic_gpu(
     rule_ctx = LoweringRuleContext(
         module_ctx,
         launch_ctx,
+        predicate=mgpu.single_thread_predicate(per_block=False),
         avals_in=[cast(jax_core.ShapedArray, v.aval) for v in eqn.invars],
         avals_out=[cast(jax_core.ShapedArray, v.aval) for v in eqn.outvars],
     )

@@ -1120,6 +1122,12 @@ def _convert_element_type_lowering_rule(
     )


+mosaic_lowering_rules.update({
+    lax.neg_p: lambda ctx, x: -x,
+    lax.not_p: lambda ctx, x: ~x,
+})
+
+
 def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl):
   x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
   return impl(x, y)

@@ -1576,4 +1584,4 @@ def _as_index(v: object) -> ir.Value:
     case mgpu.FragmentedArray(layout=mgpu.WGSplatFragLayout()):
       return _as_index(v.registers.item())
     case _:
-      raise ValueError(f"Unsupported index: {v}")
+      raise ValueError(f"Unsupported index: {v} of type {type(v)}")

@@ -16,7 +16,7 @@

 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import dataclasses
 import functools
 import itertools as it

@@ -25,7 +25,10 @@ from typing import Any

 import jax
 from jax import lax
+from jax._src import core
+from jax._src import linear_util as lu
 from jax._src import util
+from jax._src.interpreters import partial_eval as pe
 from jax._src.pallas import core as pallas_core
 from jax._src.pallas.mosaic_gpu import core as gpu_core
 from jax._src.pallas.mosaic_gpu import primitives as gpu_primitives

@@ -37,17 +40,19 @@ map = util.safe_map
 zip = util.safe_zip


+@jax.tree_util.register_dataclass
 @dataclasses.dataclass(frozen=True)
 class BufferedRef:
-  spec: pallas_core.BlockSpec
+  spec: pallas_core.BlockSpec = dataclasses.field(metadata={"static": True})
+  is_index_invariant: bool = dataclasses.field(metadata={"static": True})
   gmem_ref: pallas_core.AbstractMemoryRef
   smem_ref: pallas_core.AbstractMemoryRef  # [num_slots, *spec.block_shape]

-  def compute_gmem_slice(self, grid_indices) -> tuple[Any, ...]:
+  def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice, ...]:
     index_map = self.spec.index_map
     assert index_map is not None
     return tuple(
-        pl.ds(idx * size, size)
+        pl.Slice(idx * size, size)  # type: ignore[arg-type]
         for idx, size in zip(
             index_map(*grid_indices), self.spec.block_shape  # type: ignore[arg-type]
         )

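The field metadata above replaces the explicit jax.tree_util.register_dataclass(...) call that previously followed the class (removed in the next hunk): fields marked static travel with the treedef, while the refs stay pytree leaves. A small self-contained sketch of the same registration pattern, with an illustrative class unrelated to the actual BufferedRef:

import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class Buffer:  # illustrative only
  name: str = dataclasses.field(metadata={"static": True})  # treedef metadata, not traced
  data: jax.Array  # pytree leaf

leaves, treedef = jax.tree_util.tree_flatten(Buffer("acc", jnp.zeros(4)))
print(leaves)   # [Array([0., 0., 0., 0.], dtype=float32)] -- only "data" is a leaf
print(treedef)  # "name" travels with the treedef, so transformations treat it as static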
@@ -61,16 +66,31 @@ class BufferedRef:
         barrier=barrier_ref.at[slot],
     )

-  def copy_out(self, slot, grid_indices):
+  def copy_out(self, slot, grid_indices, predicate=None):
     gmem_slices = self.compute_gmem_slice(grid_indices)
     gpu_primitives.copy_smem_to_gmem(
-        self.smem_ref.at[slot], self.gmem_ref.at[gmem_slices]  # pytype: disable=unsupported-operands
+        self.smem_ref.at[slot],
+        self.gmem_ref.at[gmem_slices],  # pytype: disable=unsupported-operands
+        predicate=predicate,
     )


-jax.tree_util.register_dataclass(
-    BufferedRef, data_fields=["gmem_ref", "smem_ref"], meta_fields=["spec"]
-)
+def _uses_arguments(
+    index_map: Callable[..., Any], num_args: int
+) -> Sequence[bool]:
+  jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
+      lu.wrap_init(index_map), (core.ShapedArray((), jnp.int32),) * num_args
+  )
+  _, used_inputs = pe.dce_jaxpr(jaxpr, used_outputs=[True] * len(jaxpr.outvars))
+  return used_inputs
+
+
+def _is_index_invariant(
+    spec: pallas_core.BlockSpec, grid: pallas_core.StaticGrid
+) -> bool:
+  index_map = spec.index_map
+  assert index_map is not None
+  return not any(_uses_arguments(index_map, len(grid)))


 def _inc_grid_by_1(

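For reference, the question _uses_arguments answers -- does the index map read any of its grid indices? -- can be approximated with the public jaxpr API alone. A simplified stand-in (it skips the dead-code-elimination step, so an argument that only feeds dead computation still counts as used):

import jax


def uses_arguments(index_map, num_args):
  """Approximates _uses_arguments: which arguments appear in the traced jaxpr."""
  jaxpr = jax.make_jaxpr(index_map)(*(0,) * num_args).jaxpr
  var_type = type(jaxpr.invars[0])  # filters out literals without importing internals
  used = set()
  for eqn in jaxpr.eqns:
    used.update(v for v in eqn.invars if isinstance(v, var_type))
  used.update(v for v in jaxpr.outvars if isinstance(v, var_type))
  return [v in used for v in jaxpr.invars]


print(uses_arguments(lambda i, j: (i, 0), 2))  # [True, False]
print(uses_arguments(lambda i: (0, 0), 1))     # [False]: index-invariant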
@@ -85,6 +105,25 @@ def _inc_grid_by_1(
   return tuple(reversed(next_indices))


+# ``pl.Slice`` uses a different pytree encoding, depending on whether the
+# start/size are static or dynamic. This leads to pytree structure mismatch
+# in the pipeline body. So, we define a different ``Slice`` class below.
+
+
+@dataclasses.dataclass(frozen=True)
+class _Slice:
+  start: int | jax.Array
+  size: int | jax.Array
+
+  def __eq__(self, other: _Slice) -> jax.Array:  # type: ignore
+    return lax.bitwise_and(self.start == other.start, self.size == other.size)
+
+
+jax.tree_util.register_dataclass(
+    _Slice, data_fields=["start", "size"], meta_fields=[]
+)


 def emit_pipeline(
     body,
     *,

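_Slice.__eq__ returns a jax.Array rather than a Python bool because the comparison runs on traced values inside the fori_loop body. A small self-contained sketch of the same pattern (the standalone Slice class and the values below are illustrative):

import dataclasses
import functools

import jax
from jax import lax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Slice:  # stand-in for the _Slice class above
  start: int | jax.Array
  size: int | jax.Array

  def __eq__(self, other) -> jax.Array:  # type: ignore[override]
    return lax.bitwise_and(self.start == other.start, self.size == other.size)


jax.tree_util.register_dataclass(Slice, data_fields=["start", "size"], meta_fields=[])


def store_needed(last_slices, new_slices, is_last_step):
  # Mirrors the loop body below: store if any slice changed or on the last step.
  are_same = map(lambda old, new: old == new, last_slices, new_slices)
  return lax.bitwise_or(~functools.reduce(lax.bitwise_and, are_same), is_last_step)


last = (Slice(jnp.int32(0), jnp.int32(16)), Slice(jnp.int32(0), jnp.int32(32)))
new = (Slice(jnp.int32(0), jnp.int32(16)), Slice(jnp.int32(32), jnp.int32(32)))
print(jax.jit(store_needed)(last, new, jnp.asarray(False)))  # True: the second slice moved

Under jit the result stays a traced boolean, which is what copy_out's predicate argument expects.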
@@ -102,6 +141,16 @@ def emit_pipeline(
     max_concurrent_steps = num_steps

   def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef):
+    for gmem_ref, spec in zip(gmem_refs, it.chain(in_specs, out_specs)):
+      if any(
+          spec.block_shape[-idx] * grid[-idx] != gmem_ref.shape[-idx]  # type: ignore
+          for idx in range(1, len(grid) + 1)
+      ):
+        raise NotImplementedError(
+            f"Cannot emit a pipeline over the {grid=} for {gmem_ref} with block"
+            f" shape {spec.block_shape}."
+        )
+
     in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)])
     in_smem_refs, out_smem_refs = util.split_list(
         map(

@@ -132,13 +181,18 @@ def emit_pipeline(
   def scoped_pipeline(
       *, in_gmem_refs, out_gmem_refs, in_smem_refs, out_smem_refs, barrier_ref
   ):
-
-    in_brefs: Sequence[BufferedRef] = map(
-        BufferedRef, in_specs, in_gmem_refs, in_smem_refs
-    )
-    out_brefs: Sequence[BufferedRef] = map(
-        BufferedRef, out_specs, out_gmem_refs, out_smem_refs
-    )
+    in_brefs: Sequence[BufferedRef] = [
+        BufferedRef(spec, _is_index_invariant(spec, grid), gmem_ref, smem_ref)
+        for spec, gmem_ref, smem_ref in zip(
+            in_specs, in_gmem_refs, in_smem_refs
+        )
+    ]
+    out_brefs: Sequence[BufferedRef] = [
+        BufferedRef(spec, _is_index_invariant(spec, grid), gmem_ref, smem_ref)
+        for spec, gmem_ref, smem_ref in zip(
+            out_specs, out_gmem_refs, out_smem_refs
+        )
+    ]

     for step, indices in enumerate(
         it.islice(it.product(*map(range, grid)), max_concurrent_steps)

@@ -147,10 +201,11 @@ def emit_pipeline(

     def loop_body(step, carry):
       slot = step % max_concurrent_steps
-      indices, fetch_indices = carry
+      indices, fetch_indices, last_store_slices = carry

-      # Wait for the current GMEM->SMEM copy to complete.
-      gpu_primitives.barrier_wait(barrier_ref.at[slot])
+      if in_specs:
+        # Wait for the current GMEM->SMEM copy to complete.
+        gpu_primitives.barrier_wait(barrier_ref.at[slot])
       # Wait for the previous output SMEM->GMEM copy to complete.
       gpu_primitives.wait_smem_to_gmem(max_concurrent_steps - 1)

@@ -159,9 +214,34 @@ def emit_pipeline(
           *(bref.smem_ref.at[slot] for bref in it.chain(in_brefs, out_brefs))
       )

+      if not all(bref.is_index_invariant for bref in out_brefs):
+        gpu_primitives.commit_smem()
+
       # Copy the output from SMEM to GMEM.
-      gpu_primitives.commit_smem()
-      map(lambda bref: bref.copy_out(slot, indices), out_brefs)
+      new_store_slices = last_store_slices[:]
+      for idx, bref in enumerate(out_brefs):
+        if bref.is_index_invariant:
+          assert last_store_slices[idx] is None
+          continue
+        assert last_store_slices[idx] is not None
+        new_store_slices[idx] = tuple(
+            _Slice(s.start, s.size) for s in bref.compute_gmem_slice(indices)
+        )
+        are_same_slices = map(
+            lambda old, new: old == new,
+            last_store_slices[idx],
+            new_store_slices[idx],
+        )
+        slices_changed = ~functools.reduce(lax.bitwise_and, are_same_slices)
+        is_last_step = step == num_steps - 1
+        # TODO(apaszke,slebedev): This still diverges significantly from the
+        # TPU semantics in that it will move on to the next SMEM output slice
+        # even if it's not storing the previous one.
+        bref.copy_out(
+            slot,
+            indices,
+            predicate=lax.bitwise_or(slices_changed, is_last_step),
+        )

       fetch_step = step + max_concurrent_steps
       fetch_slot = slot  # (x + y) % y == x % y

@@ -174,13 +254,34 @@ def emit_pipeline(
           lambda: [None] * len(in_brefs),
       )

-      return _inc_grid_by_1(indices, grid), _inc_grid_by_1(fetch_indices, grid)
+      return (
+          _inc_grid_by_1(indices, grid),
+          _inc_grid_by_1(fetch_indices, grid),
+          new_store_slices,
+      )

     indices = (jnp.asarray(0, dtype=lax.dtype(0)),) * len(grid)
     fetch_indices = indices
     for _ in range(max_concurrent_steps):
       fetch_indices = _inc_grid_by_1(fetch_indices, grid)
-    lax.fori_loop(0, num_steps, loop_body, (indices, fetch_indices))
+    last_store_slices = [
+        None
+        if bref.is_index_invariant
+        else (_Slice(-1, -1),) * len(bref.spec.block_shape)
+        for bref in out_brefs
+    ]
+    last_indices, _, _ = lax.fori_loop(
+        0, num_steps, loop_body, (indices, fetch_indices, last_store_slices)
+    )
+
+    # Outputs invariant to the sequential axis are never written from inside the
+    # loop. This is the only place where we store them.
+    if all(bref.is_index_invariant for bref in out_brefs):
+      gpu_primitives.commit_smem()
+    last_slot = (num_steps - 1) % max_concurrent_steps
+    for bref in out_brefs:
+      if bref.is_index_invariant:
+        bref.copy_out(last_slot, last_indices, predicate=None)

     # Finalize the pipeline.
     gpu_primitives.wait_smem_to_gmem(0)

@@ -26,6 +26,7 @@ from jax._src import state
 from jax._src import tree_util
 from jax._src import util
 from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import arith as arith_dialect
 from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect
 from jax._src.pallas import core as pallas_core
 from jax._src.pallas.mosaic_gpu import core as gpu_core

@@ -34,6 +35,7 @@ from jax._src.state import discharge
 from jax._src.state import indexing
 from jax._src.state import primitives as state_primitives
 import jax.experimental.mosaic.gpu as mgpu
+import jax.numpy as jnp


 WARPGROUP_SIZE = 128

@@ -54,19 +56,31 @@ def _copy_smem_to_gmem_lowering(
     ctx: lowering.LoweringRuleContext,
     src,
     dst,
-    *flat_transforms,
+    *flat_args,
     src_transforms_treedef,
     dst_transforms_treedef,
+    has_user_predicate,
 ):
+  predicate = ctx.predicate
+  if has_user_predicate:
+    flat_args, user_predicate = flat_args[:-1], flat_args[-1]
+    predicate = arith_dialect.andi(
+        predicate, lowering._ensure_ir_value(user_predicate, jnp.bool)
+    )
   flat_src_transforms, flat_dst_transforms = util.split_list(
-      flat_transforms,
+      flat_args,
       [src_transforms_treedef.num_leaves],
   )
   src_transforms = src_transforms_treedef.unflatten(flat_src_transforms)
   dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms)
   src, src_transforms = lowering._handle_indexing(src, src_transforms)
   copy_params = _extract_gmem_copy_params(dst_transforms) | _extract_smem_copy_params(src_transforms)
-  ctx.launch_ctx.async_copy(src_ref=src, dst_ref=dst, **copy_params)
+  ctx.launch_ctx.async_copy(
+      src_ref=src,
+      dst_ref=dst,
+      predicate=predicate,
+      **copy_params,
+  )
   return ()

@@ -98,10 +112,18 @@ def _extract_smem_copy_params(transforms):


 def copy_smem_to_gmem(
-    src: pallas_core.AbstractMemoryRef, dst: pallas_core.AbstractMemoryRef
+    src: pallas_core.AbstractMemoryRef,
+    dst: pallas_core.AbstractMemoryRef,
+    predicate: jax.Array | None = None,
 ) -> None:
   """Asynchronously copies a SMEM reference to a GMEM reference.

+  Args:
+    src: The SMEM reference to copy from.
+    dst: The GMEM reference to copy to.
+    predicate: A boolean indicating whether the copy should be performed. If
+      ``None``, the copy is always performed.
+
   See also:
     :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`
     :func:`jax.experimental.mosaic.gpu.commit_smem`

@@ -127,8 +149,10 @@ def copy_smem_to_gmem(
       dst,
       *flat_src_transforms,
       *flat_dst_transforms,
+      *[] if predicate is None else [predicate],
      src_transforms_treedef=src_transforms_treedef,
      dst_transforms_treedef=dst_transforms_treedef,
+      has_user_predicate=predicate is not None,
   )
   return None

@@ -1146,6 +1146,37 @@ class PipelineTest(PallasTest):
     )
     np.testing.assert_array_equal(kernel_fn(x), x + 1.0)

+  def test_emit_with_grid_invariant_output(self):
+    num_steps = 4
+
+    def kernel(x_gmem, o_gmem):
+      plgpu.emit_pipeline(
+          kernel_body,
+          in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))],
+          out_specs=[pl.BlockSpec((32, 16), lambda i: (0, 0))],
+          grid=(num_steps,),
+          max_concurrent_steps=2,
+      )(x_gmem, o_gmem)
+
+    def kernel_body(x_smem, o_smem):
+      o_smem[...] = x_smem[...] + 1.0
+
+    x = jnp.arange(32 * num_steps * 16)
+    x = x.reshape(-1, num_steps * 16).astype(jnp.float32)
+    kernel_fn = pl.pallas_call(
+        kernel,
+        in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
+        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
+        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
+    )
+    y = jnp.empty_like(x)
+    for i in range(num_steps):
+      i_slice = slice(16 * i, 16 * (i + 1))
+      y = y.at[:, :16].set(x[:, i_slice] + 1)
+    # We only compare the elements in the first 16 columns, because the rest
+    # are never written to.
+    np.testing.assert_array_equal(kernel_fn(x)[:, :16], y[:, :16])
+
   def test_emit_with_parallel_grid(self):
     self.skipTest("Enable once we support multiple levels of indexing")