[Pallas/MGPU] Undo transforms before giving refs back to users

This is a second attempt at this change. The first one was rolled back because of reported failures.

Reverts 411928b9668570bbc3795522aba94cece6894881

PiperOrigin-RevId: 680943744
Adam Paszke 2024-10-01 03:30:15 -07:00 committed by jax authors
parent 14ef2b6a21
commit cac2b8d5fc
8 changed files with 249 additions and 90 deletions

View File

@@ -40,6 +40,7 @@ from jax._src import util
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.state import discharge as state_discharge
from jax._src.state.types import TransformedRef
import jax.numpy as jnp
@@ -496,7 +497,7 @@ class BlockSpec:
mapping = BlockMapping(
block_shape=mapped_block_shape,
block_aval=block_aval,
transformed_block_aval=block_aval, # There are no transforms by default
index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts),
index_map_src_info=index_map_src_info,
indexing_mode=self.indexing_mode,
@@ -523,7 +524,7 @@ BlockSpecTree = Any
class MemoryRefTransform(Protocol):
"""Transforms a memory reference on load or store."""
def __call__(self, block_aval: AbstractMemoryRef) -> AbstractMemoryRef:
def undo(self, ref: TransformedRef) -> TransformedRef:
raise NotImplementedError("Abstract evaluation not implemented.")
@@ -533,8 +534,10 @@ class BlockMapping:
See the `check_invariants` method for precise specification.
"""
# TODO(apaszke,sharadmv): Replace mapped dims in block_shape with a transform.
# After all, it's just indexing out singleton dimensions.
block_shape: tuple[Mapped | int, ...]
block_aval: AbstractMemoryRef # The block ref aval
transformed_block_aval: AbstractMemoryRef
index_map_jaxpr: jax_core.ClosedJaxpr
index_map_src_info: NameAndSrcInfo
indexing_mode: IndexingMode
@@ -546,8 +549,8 @@ class BlockMapping:
if not config.enable_checks.value: return
unmapped_block_shape = tuple(s for s in self.block_shape if s is not mapped)
assert unmapped_block_shape == self.block_aval.shape, (
self.block_shape, self.block_aval)
assert unmapped_block_shape == self.ref_aval.shape, (
self.block_shape, self.ref_aval.shape)
assert len(self.block_shape) == len(self.array_shape_dtype.shape), (
self.block_shape, self.array_shape_dtype
)
@@ -568,12 +571,21 @@ class BlockMapping:
return new_self
@property
def ref_aval(self) -> AbstractMemoryRef:
def block_aval(self) -> AbstractMemoryRef:
# If you hit this, make sure you take transforms into account and use either
# ref_aval or transformed_block_aval.
assert not self.transforms, "Lowering failed to handle transforms"
return self.transformed_block_aval
@property
def ref_aval(self) -> AbstractMemoryRef | TransformedRef:
"""Returns the abstract value of the Ref after transformations."""
block_aval = self.block_aval
for transform in self.transforms:
block_aval = transform(block_aval)
return block_aval
if not self.transforms:
return self.transformed_block_aval
ref = TransformedRef(self.transformed_block_aval, ())
for transform in reversed(self.transforms):
ref = transform.undo(ref)
return ref
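
The property above is the heart of the change: instead of handing users the transformed block aval, each forward MemoryRefTransform contributes an inverse ref transform via undo, and walking the stack in reverse yields a TransformedRef whose user-visible shape is the logical, untransformed one. A standalone sketch of that mechanism in plain Python (Tile, Untile, and SimpleRef are illustrative names, not the Pallas API):

# Standalone sketch of the undo mechanism; illustrative names, not Pallas.
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class SimpleRef:
  shape: tuple
  transforms: tuple = ()


@dataclass(frozen=True)
class Untile:  # plays the role of UntileRef: a view transform on a ref
  tiling: tuple

  def transform_shape(self, shape):
    n = len(self.tiling)
    shape = shape[:-n]  # drop the tile dims
    return shape[:-n] + tuple(d * t for d, t in zip(shape[-n:], self.tiling))


@dataclass(frozen=True)
class Tile:  # plays the role of TilingTransform: a forward memory transform
  tiling: tuple

  def undo(self, ref):
    return replace(ref, transforms=(*ref.transforms, Untile(self.tiling)))


transforms = (Tile((64, 64)),)                 # applied when building the block
transformed_block = SimpleRef((2, 1, 64, 64))  # the tiled layout stored in SMEM

ref = transformed_block
for t in reversed(transforms):                 # same loop as ref_aval above
  ref = t.undo(ref)

shape = transformed_block.shape
for t in ref.transforms:                       # what TransformedRef.shape computes
  shape = t.transform_shape(shape)
assert shape == (128, 64)                      # the user sees the logical shape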
def compute_start_indices_interpret(self, loop_idx, *args):
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(

View File

@@ -14,14 +14,16 @@
"""Contains GPU-specific Pallas abstractions."""
import abc
from collections.abc import Sequence
import dataclasses
import enum
from typing import Any, ClassVar, Literal, Protocol
from typing import Any, ClassVar, Literal
from jax._src import core as jax_core
from jax._src import dtypes
from jax._src import tree_util
from jax._src.state.types import Transform
from jax._src.pallas import core as pallas_core
import jax.experimental.mosaic.gpu as mgpu
import jax.numpy as jnp
@@ -63,9 +65,15 @@ class GPUMemorySpace(enum.Enum):
return pallas_core.MemoryRef(shape, dtype, memory_space=self)
class MemoryRefTransform(pallas_core.MemoryRefTransform, Protocol):
class MemoryRefTransform(pallas_core.MemoryRefTransform, abc.ABC):
@abc.abstractmethod
def to_gpu_transform(self) -> mgpu.MemRefTransform:
...
pass
def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray:
return aval.update(
shape=self.to_gpu_transform().transform_shape(aval.shape)
)
@dataclasses.dataclass(frozen=True)
@@ -79,52 +87,67 @@ class TilingTransform(MemoryRefTransform):
tiling: tuple[int, ...]
def __call__(
self, block_aval: pallas_core.AbstractMemoryRef
) -> pallas_core.AbstractMemoryRef:
block_shape = block_aval.shape
old_tiled_dims = block_shape[-len(self.tiling) :]
num_tiles = tuple(
block_dim // tiling_dim
for block_dim, tiling_dim in zip(old_tiled_dims, self.tiling)
)
rem = (
block_dim % tiling_dim
for block_dim, tiling_dim in zip(old_tiled_dims, self.tiling)
)
if any(rem):
raise ValueError(
f"Block shape {block_shape} is not divisible by tiling {self.tiling}"
)
new_block_shape = block_shape[: -len(self.tiling)] + num_tiles + self.tiling
return block_aval.update(
inner_aval=block_aval.inner_aval.update(shape=new_block_shape)
def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef:
return dataclasses.replace(
ref, transforms=(*ref.transforms, UntileRef(self.tiling))
)
def to_gpu_transform(self) -> mgpu.MemRefTransform:
return mgpu.TileTransform(self.tiling)
@dataclasses.dataclass(frozen=True)
class UntileRef(Transform):
tiling: tuple[int, ...]
def transform_shape(self, shape):
if shape is None:
return None
assert shape[-len(self.tiling) :] == self.tiling
shape = shape[: -len(self.tiling)] # Drop tiling
return shape[: -len(self.tiling)] + tuple(
block_dim * tiling_dim
for block_dim, tiling_dim in zip(shape[-len(self.tiling) :], self.tiling)
)
def transform_dtype(self, dtype):
return dtype
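
As a concrete shape check mirroring the arithmetic above (a plain-Python sketch, not the Pallas classes): a (128, 64) block tiled with (64, 64) is stored as (2, 1, 64, 64), and the untiling view maps it back.

# Plain-Python check that the tiled layout round-trips; tile/untile mirror
# the arithmetic of TilingTransform and UntileRef above.
def tile(shape, tiling):
  n = len(tiling)
  if any(d % t for d, t in zip(shape[-n:], tiling)):
    raise ValueError(f"Block shape {shape} is not divisible by tiling {tiling}")
  return shape[:-n] + tuple(d // t for d, t in zip(shape[-n:], tiling)) + tiling


def untile(shape, tiling):
  n = len(tiling)
  assert shape[-n:] == tiling
  shape = shape[:-n]  # drop the tile dims, as UntileRef does
  return shape[:-n] + tuple(d * t for d, t in zip(shape[-n:], tiling))


assert tile((128, 64), (64, 64)) == (2, 1, 64, 64)
assert untile((2, 1, 64, 64), (64, 64)) == (128, 64)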
@dataclasses.dataclass(frozen=True)
class TransposeTransform(MemoryRefTransform):
"""Transpose a tiled memref."""
permutation: tuple[int, ...]
def __call__(
self, block_aval: pallas_core.AbstractMemoryRef
) -> pallas_core.AbstractMemoryRef:
shape = block_aval.shape # pytype: disable=attribute-error
return block_aval.update(
inner_aval=block_aval.inner_aval.update(
shape=self.to_gpu_transform().transform_shape(shape)
)
def __post_init__(self):
if set(self.permutation) != set(range(len(self.permutation))):
raise ValueError(f"Permutation {self.permutation} is not a permutation.")
def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef:
inverse = [-1] * len(self.permutation)
for i, p in enumerate(self.permutation):
inverse[p] = i
return dataclasses.replace(
ref, transforms=(*ref.transforms, TransposeRef(tuple(inverse)))
)
def to_gpu_transform(self) -> mgpu.MemRefTransform:
return mgpu.TransposeTransform(self.permutation)
@dataclasses.dataclass(frozen=True)
class TransposeRef(Transform):
permutation: tuple[int, ...]
def transform_shape(self, shape):
if shape is None:
return None
return tuple(shape[i] for i in self.permutation)
def transform_dtype(self, dtype):
return dtype
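
A quick standalone check of the inverse-permutation logic in undo, assuming the forward transform permutes the shape the same way TransposeRef does (helper names are illustrative):

# Standalone check that the inverse permutation computed by undo composes
# with the forward permutation to the identity on shapes.
def permute(shape, permutation):
  return tuple(shape[i] for i in permutation)  # what TransposeRef.transform_shape does


def inverse(permutation):
  inv = [-1] * len(permutation)
  for i, p in enumerate(permutation):          # same loop as undo above
    inv[p] = i
  return tuple(inv)


perm = (2, 0, 1)
shape = (4, 8, 16)
transposed = permute(shape, perm)                    # (16, 4, 8): the stored layout
assert permute(transposed, inverse(perm)) == shape   # the user-visible shape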
@dataclasses.dataclass(frozen=True)
class GPUBlockMapping(pallas_core.BlockMapping):
swizzle: int | None = None
@@ -156,9 +179,14 @@ class GPUBlockSpec(pallas_core.BlockSpec):
transforms = self.transforms
if not isinstance(transforms, tuple):
transforms = (transforms,)
block_inner_aval = bm.block_aval.inner_aval
for t in transforms:
block_inner_aval = t(block_inner_aval)
return GPUBlockMapping(
block_shape=bm.block_shape,
block_aval=bm.block_aval,
transformed_block_aval=bm.block_aval.update(
inner_aval=block_inner_aval
),
origin=bm.origin,
index_map_jaxpr=bm.index_map_jaxpr,
index_map_src_info=bm.index_map_src_info,

View File

@@ -280,7 +280,7 @@ def lower_jaxpr_to_module(
in_in_smem, out_in_smem = util.split_list(
[
bm.block_aval.memory_space in (None, gpu_core.SMEM)
bm.transformed_block_aval.memory_space in (None, gpu_core.SMEM)
for bm in block_mappings
],
[grid_mapping.num_inputs],
@@ -290,9 +290,13 @@ def lower_jaxpr_to_module(
in_block_mappings, out_block_mappings = util.split_list(
block_mappings, [grid_mapping.num_inputs]
)
# TODO(apaszke): We can shrink allocation if max_concurrent_steps is more than the actual number of steps.
# We allocate the fully transformed shapes here. All primitives have seen the
# inverse transformation stack and will understand how to handle it.
in_structs_smem = [
jax.ShapeDtypeStruct(
[max_concurrent_steps, *bm.ref_aval.shape], bm.ref_aval.dtype
[max_concurrent_steps, *bm.transformed_block_aval.shape],
bm.transformed_block_aval.dtype,
)
if in_smem
else None
@@ -312,6 +316,9 @@ def lower_jaxpr_to_module(
)
out_structs_gmem = [*grid_mapping.out_shapes]
# TODO(justinfu): Implement output Memref transforms
for bm in block_mappings[grid_mapping.num_inputs :]:
if bm.transforms:
raise NotImplementedError("Output transforms are not supported")
out_structs_smem = [
jax.ShapeDtypeStruct([max_concurrent_steps, *bm.block_shape], s.dtype)
if in_smem

View File

@@ -119,7 +119,7 @@ effects.control_flow_allowed_effects.add_type(_WGMMAPipelineEffect)
wgmma_ref_p = jax_core.Primitive("wgmma_ref")
wgmma_ref_p.multiple_results = True
def wgmma(acc, a, b, *, rhs_transpose: bool = False, swizzle: int = 128):
def wgmma(acc, a, b, *, swizzle: int = 128):
"""Asynchronous warp group matmul.
The sm90 wgmma instruction, essentially acc[...] += a @ b. Requires
@@ -129,24 +129,49 @@ def wgmma(acc, a, b, *, rhs_transpose: bool = False, swizzle: int = 128):
acc: The accumulator register.
a: The left hand side operand.
b: The right hand side operand.
transpose: Whether to transpose b.
n_tile: The number of tiles to use.
swizzle: The swizzle pattern.
"""
if not isinstance(acc.aval, gpu_core.WGMMAAbstractAccumulatorRef):
raise TypeError(f"Expected WGMMAAbstractAccumulatorRef got {acc}")
ma, ka, tma, tka = a.shape
kb, nb, tkb, tnb = b.shape
mc, nc = acc.shape
# TODO(apaszke): Make swizzling another transform and read it from the refs.
if not isinstance(a, pallas_core.TransformedRef):
raise ValueError("WGMMA inputs must be tiled references.")
if rhs_transpose:
kb, nb, tkb, tnb = nb, kb, tnb, tkb
m, n = acc.shape
m2, k = a.shape
k2, n2 = b.shape
if tma * ma != mc or nb * tnb != nc or ka != kb or tka != tkb:
raise ValueError(f"Incompatible shapes: {a.shape=}, {b.shape=}, {acc.shape=}, {rhs_transpose=}")
if m != m2 or n != n2 or k != k2:
raise ValueError(
f"Incompatible shapes for matrix multiplication: lhs={a.shape},"
f" rhs={b.shape=}, acc={acc.shape}"
)
return wgmma_ref_p.bind(acc, a, b, swizzle=swizzle, rhs_transpose=rhs_transpose)
if (dtype := a.dtype) != b.dtype:
raise ValueError(f"Mixed input dtypes for matrix multiplication unsupported: lhs={a.dtype}, rhs={b.dtype}")
if not isinstance(a, pallas_core.TransformedRef):
raise ValueError("WGMMA lhs must be a tiled reference.")
if not isinstance(b, pallas_core.TransformedRef):
raise ValueError("WGMMA rhs must be a tiled reference.")
elems_128b = swizzle // dtype.itemsize
if a.transforms != (gpu_core.UntileRef((64, elems_128b)),):
raise ValueError(
f"WGMMA lhs must be tiled with 64x{elems_128b} tiles for element type"
f" {dtype}."
)
rhs_transpose_transform = gpu_core.TransposeRef((1, 0, 2, 3))
rhs_tiling = gpu_core.UntileRef((elems_128b, elems_128b))
if not (
rhs_transpose := (b.transforms == (rhs_transpose_transform, rhs_tiling))
) and not (b.transforms == (rhs_tiling,)):
raise ValueError(
f"WGMMA rhs must be tiled with {elems_128b}x{elems_128b} tiles for"
f" element type {dtype} (and optionally transposed)."
)
return wgmma_ref_p.bind(acc, a.ref, b.ref, swizzle=swizzle, rhs_transpose=rhs_transpose)
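
With the rhs_transpose keyword gone, the transposition is now inferred from the rhs transform stack, as the checks above show. A standalone sketch of that dispatch (Untile and Transpose are illustrative stand-ins for UntileRef and TransposeRef, not the Pallas API):

# Standalone sketch of how wgmma infers rhs_transpose from the transform stack.
from dataclasses import dataclass


@dataclass(frozen=True)
class Untile:
  tiling: tuple


@dataclass(frozen=True)
class Transpose:
  permutation: tuple


def infer_rhs_transpose(rhs_transforms, elems_128b):
  tiling = Untile((elems_128b, elems_128b))
  if rhs_transforms == (Transpose((1, 0, 2, 3)), tiling):
    return True
  if rhs_transforms == (tiling,):
    return False
  raise ValueError("WGMMA rhs must be tiled (and optionally transposed).")


elems_128b = 128 // 2  # swizzle=128, float16 itemsize=2 -> 64
assert infer_rhs_transpose((Untile((64, 64)),), elems_128b) is False
assert infer_rhs_transpose(
    (Transpose((1, 0, 2, 3)), Untile((64, 64))), elems_128b) is True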
@wgmma_ref_p.def_effectful_abstract_eval

View File

@@ -39,6 +39,7 @@ from jax._src.pallas import core as pallas_core
from jax._src.pallas import primitives
from jax._src.pallas import utils as pallas_utils
from jax._src.state import discharge as state_discharge
from jax._src.state import types as state_types
from jax._src.util import (
safe_map,
safe_zip,
@@ -941,6 +942,7 @@ def _pallas_call_batching_rule(
batched_grid_mapping,
tuple(flat_kernel_avals),
kernel_in_tree,
tuple(() for _ in flat_kernel_avals),
interpret=interpret,
)
@@ -1136,12 +1138,27 @@ def pallas_call_checkify_rule(error: checkify.Error,
return new_error, results
checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule
# All of those shenanigans are because we can't make TransformedRef a PyTree,
# because they should appear as atomic JAX values to the users.
@lu.transformation
def wrap_with_transforms(transforms, *args):
new_args = tuple(
state_types.TransformedRef(a, t) if t else a
for a, t in zip(args, transforms)
)
res = yield new_args, {}
yield res
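
Because TransformedRef cannot be a pytree, the per-argument transforms travel out-of-band and are re-attached to the flat refs right before the user's kernel runs. A plain-Python sketch of that wrapping (illustrative names; the real code uses lu.transformation):

# Plain-Python sketch of wrap_with_transforms: transforms travel out-of-band
# and are re-attached to the flat kernel arguments before calling the kernel.
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class FakeTransformedRef:  # stand-in for state_types.TransformedRef
  ref: Any
  transforms: tuple


def wrap_with_transforms(fun: Callable, transforms: tuple) -> Callable:
  def wrapped(*flat_refs):
    new_args = tuple(
        FakeTransformedRef(r, t) if t else r
        for r, t in zip(flat_refs, transforms)
    )
    return fun(*new_args)
  return wrapped


def kernel(x_ref, o_ref):
  # The kernel sees a TransformedRef-like wrapper only where transforms exist.
  assert isinstance(x_ref, FakeTransformedRef)
  assert not isinstance(o_ref, FakeTransformedRef)


wrap_with_transforms(kernel, (("untile",), ()))("x_ref", "o_ref")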
@weakref_lru_cache
def _trace_kernel_to_jaxpr(fun: Callable,
def _trace_kernel_to_jaxpr(
fun: Callable,
name_and_src_info: pallas_core.NameAndSrcInfo,
grid_mapping: GridMapping,
kernel_avals: tuple[pallas_core.AbstractMemRef, ...],
kernel_in_tree: tree_util.PyTreeDef,
kernel_in_transforms: tuple[tuple[pallas_core.Transform, ...], ...],
interpret: bool,
) -> jax_core.ClosedJaxpr:
if interpret:
@@ -1149,6 +1166,7 @@ def _trace_kernel_to_jaxpr(fun: Callable,
kernel_avals))
wrapped_kernel_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(fun), kernel_in_tree)
wrapped_kernel_fun = wrap_with_transforms(wrapped_kernel_fun, kernel_in_transforms)
debug = pe.debug_info(fun, kernel_in_tree, out_tree_thunk, False, "pallas_call")
with grid_mapping.trace_env():
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
@@ -1406,16 +1424,25 @@ def pallas_call(
for p in in_paths)
out_origins = tuple(f"outputs{tree_util.keystr(p)}" for p in out_paths)
# TODO(necula): check that input_output_aliases is well-formed: no duplicates, etc.
kernel_avals, grid_mapping = pallas_core.get_grid_mapping(
kernel_args, grid_mapping = pallas_core.get_grid_mapping(
grid_spec,
flat_in_avals, in_tree, in_origins,
flat_out_avals, out_tree, out_origins)
flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten(kernel_avals)
flat_kernel_args, kernel_in_tree = tree_util.tree_flatten(kernel_args)
flat_kernel_avals = tuple(
x.ref if isinstance(x, state_types.TransformedRef) else x
for x in flat_kernel_args
)
# Note that only a subset of all transforms can be found here, and they are
# never expected to contain any arrays.
kernel_arg_transforms = tuple(
x.transforms if isinstance(x, state_types.TransformedRef) else ()
for x in flat_kernel_args
)
with pallas_core.interpret_mode_env(interpret):
jaxpr = _trace_kernel_to_jaxpr(
kernel, kernel_src_info,
grid_mapping, tuple(flat_kernel_avals), kernel_in_tree,
interpret=interpret)
kernel, kernel_src_info, grid_mapping, tuple(flat_kernel_avals),
kernel_in_tree, kernel_arg_transforms, interpret=interpret)
for i_idx, o_idx in input_output_aliases.items():
if i_idx not in range(len(flat_in_avals)):
raise ValueError(

View File

@@ -256,3 +256,10 @@ class NDIndexer:
# In NDIndexers, the int_indexer_shape is *always* at the front of the
# result.
return (*self.int_indexer_shape, *slice_shape)
def transform_shape(self, shape: None | tuple[int | Array, ...]) -> None | tuple[int | Array, ...]:
del shape # Unused
return self.get_indexer_shape()
def transform_dtype(self, dtype):
return dtype

View File

@@ -18,8 +18,9 @@ from __future__ import annotations
from collections.abc import Sequence
import dataclasses
import math
from typing import Any, Union
from typing import Any, Union, Protocol
from jax._src.typing import DTypeLike
from jax._src import core
from jax._src import dtypes
from jax._src import effects
@@ -105,8 +106,39 @@ class RefBitcaster:
assert not arrays
return cls(*metadata)
def transform_shape(
self, shape: tuple[int | Array, ...] | None
) -> tuple[int | Array, ...] | None:
del shape # Unused
return self.shape
def transform_dtype(self, dtype):
del dtype # Unused
return self.dtype
class Transform(Protocol):
def transform_shape(
self, shape: tuple[int | Array, ...] | None
) -> tuple[int | Array, ...] | None:
"""Transform the shape.
Can return None if the input shape is not known, but must return a concrete
result when the input shape is known.
"""
return shape
def transform_dtype(
self, dtype: DTypeLike | None
) -> DTypeLike | None:
"""Transform the dtype.
Can return None if the input dtype is not known, but must return a concrete
result when the input dtype is known.
"""
return dtype
Transform = indexing.NDIndexer | RefBitcaster
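
A minimal structural implementation of the protocol, as a standalone sketch (CastTo is a made-up transform, not RefBitcaster): a transform only has to answer the shape and dtype questions, possibly without seeing its input.

# Minimal structural implementation of the Transform protocol above.
# CastTo is an illustrative elementwise-cast view, not a Pallas class.
import dataclasses
import numpy as np


@dataclasses.dataclass(frozen=True)
class CastTo:
  dtype: np.dtype

  def transform_shape(self, shape):
    return shape  # an elementwise cast leaves the shape unchanged (None stays None)

  def transform_dtype(self, dtype):
    del dtype  # the output dtype is known even when the input dtype is not
    return self.dtype


t = CastTo(np.dtype(np.float16))
assert t.transform_dtype(None) == np.dtype(np.float16)
assert t.transform_shape((128, 64)) == (128, 64)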
@dataclasses.dataclass
class RefIndexer:
@@ -122,30 +154,51 @@ class RefIndexer:
return TransformedRef(self.ref_or_view, (indexer,))
@dataclasses.dataclass
@dataclasses.dataclass(frozen=True)
class TransformedRef:
ref: Any
transforms: tuple[Transform, ...]
@property
def is_dynamic_size(self):
return self.transforms[-1].is_dynamic_size
return any(not isinstance(i, int) for i in self.shape)
@property
def shape(self) -> tuple[int | Array, ...]:
assert (
len(self.transforms) > 0
), "Should not be able to create a trivial TransformedRef"
if isinstance(self.transforms[-1], indexing.NDIndexer):
return self.transforms[-1].get_indexer_shape()
return self.transforms[-1].shape
unprocessed, shape = 0, None
# We first go backwards to find the first transform that knows its output
# shape. It's possible none of them do!
for unprocessed, t in enumerate(reversed(self.transforms), 1):
if (shape := t.transform_shape(None)) is not None:
unprocessed -= 1
break
if shape is None:
shape = self.ref.shape
if not unprocessed:
return shape
# If there are any unprocessed transforms left, we apply them to the shape
# we've found previously.
for t in self.transforms[-unprocessed:]:
shape = t.transform_shape(shape)
assert shape is not None
return shape
@property
def dtype(self):
for transform in reversed(self.transforms):
if isinstance(transform, RefBitcaster):
return transform.dtype
return self.ref.dtype
# The structure of this method is analogous to `shape`. See comments there.
unprocessed, dtype = 0, None
for unprocessed, t in enumerate(reversed(self.transforms), 1):
if (dtype := t.transform_dtype(None)) is not None:
unprocessed -= 1
break
if dtype is None:
dtype = self.ref.dtype
if not unprocessed:
return dtype
for t in self.transforms[-unprocessed:]:
dtype = t.transform_dtype(dtype)
assert dtype is not None
return dtype
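
A standalone walk-through of the backwards scan used by shape (and, analogously, dtype): one toy transform only knows its output shape given an input (like UntileRef), the other knows it unconditionally (like an NDIndexer). Names are illustrative, not the Pallas classes.

# Toy transforms for the backwards-scan logic in TransformedRef.shape:
# Slice knows its output shape without an input, Untile does not.
class Slice:
  def __init__(self, shape):
    self.out_shape = shape

  def transform_shape(self, shape):
    del shape  # the output shape is fixed regardless of the input
    return self.out_shape


class Untile:
  def __init__(self, tiling):
    self.tiling = tiling

  def transform_shape(self, shape):
    if shape is None:
      return None  # cannot be computed without knowing the input shape
    n = len(self.tiling)
    shape = shape[:-n]
    return shape[:-n] + tuple(d * t for d, t in zip(shape[-n:], self.tiling))


def transformed_shape(ref_shape, transforms):
  # Same algorithm as TransformedRef.shape above.
  unprocessed, shape = 0, None
  for unprocessed, t in enumerate(reversed(transforms), 1):
    if (shape := t.transform_shape(None)) is not None:
      unprocessed -= 1
      break
  if shape is None:
    shape = ref_shape
  if not unprocessed:
    return shape
  for t in transforms[-unprocessed:]:
    shape = t.transform_shape(shape)
  return shape


# The last transform knows its shape: nothing else needs to run.
assert transformed_shape((2, 1, 64, 64), (Untile((64, 64)), Slice((8, 64)))) == (8, 64)
# No transform knows its shape a priori: start from the ref and apply them all.
assert transformed_shape((2, 1, 64, 64), (Untile((64, 64)),)) == (128, 64)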
@property
def at(self) -> RefIndexer:

View File

@@ -404,17 +404,13 @@ class PallasCallTest(PallasTest):
swizzle=128,
),
],
out_specs=pl.BlockSpec((2, 1, 64, 64), lambda i, j: (i, j, 64, 64)),
out_shape=jax.ShapeDtypeStruct((4, 2, 64, 64), jnp.float16),
out_shape=jax.ShapeDtypeStruct((128, 64), jnp.float16),
grid=(2, 2),
)
def kernel(x_ref, o_ref):
assert x_ref.shape == (2, 1, 64, 64), x_ref.shape
o_ref[...] = x_ref[...]
assert x_ref.shape == (128, 64), x_ref.shape
x = jnp.zeros((256, 128), dtype=jnp.float16)
result = kernel(x)
self.assertEqual(result.shape, (4, 2, 64, 64))
kernel.lower(jax.ShapeDtypeStruct((256, 128), jnp.float16))
def test_fori_loop_array(self):
@functools.partial(
@@ -473,7 +469,7 @@ class PallasCallTest(PallasTest):
elems_128b = swizzle // jnp.dtype(dtype).itemsize
def kernel(a_ref, b_ref, o_ref):
def scope(acc_ref):
plgpu.wgmma(acc_ref, a_ref, b_ref, rhs_transpose=rhs_transpose)
plgpu.wgmma(acc_ref, a_ref, b_ref)
return acc_ref[...]
o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 128), jnp.float32))
@@ -534,6 +530,10 @@ class PallasCallTest(PallasTest):
tile_k = elems_128b
m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n
def kernel(a_ref, b_ref, o_ref, acc_ref):
# Make sure tiling does not alter the shape of references
assert a_ref.shape == (tile_m, tile_k)
assert b_ref.shape == (tile_k, tile_n)
assert o_ref.shape == acc_ref.shape == (tile_m, tile_n)
plgpu.wgmma(acc_ref, a_ref, b_ref)
plgpu.wgmma_wait(0) # TODO(apaszke): Delay the pipeline to avoid memory races
# TODO(apaszke): Only store in the last step. It doesn't work because we