[Pallas/MGPU] Implement block spec evaluation correctly

The previous implementation made some surprising assumptions about the contents
of the block specs and wasn't correct in general. The new implementation handles
all the cases and seems to be sufficient to finally run the matmul example with
multiple k steps while producing correct results (it's also shorter!).

PiperOrigin-RevId: 679175212
Adam Paszke 2024-09-26 09:14:38 -07:00 committed by jax authors
parent a3284bd8a3
commit 076287fb5c
2 changed files with 40 additions and 61 deletions
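Note on terminology: "evaluating a block spec" means running the BlockSpec's index_map on the grid indices to obtain block indices, then scaling those by the block shape to get element offsets into the GMEM array. A rough, standalone Python sketch of that semantics (illustrative names only, not the lowering code itself):

    import numpy as np

    def eval_block_spec(index_map, block_shape, grid_indices):
        # The index map returns *block* indices; scaling by the block extent
        # along each dimension yields the element-wise start of the slice.
        block_indices = index_map(*grid_indices)
        return tuple(b * s for b, s in zip(block_indices, block_shape))

    x = np.arange(16 * 16).reshape(16, 16)
    # Block spec for the LHS of a matmul: the row block follows the parallel
    # m axis, the column block follows the sequential k axis.
    starts = eval_block_spec(lambda m, k: (m, k), (8, 8), (1, 0))
    assert starts == (8, 0)
    block = x[starts[0]:starts[0] + 8, starts[1]:starts[1] + 8]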


@@ -185,7 +185,7 @@ class LoweringError(Exception): # pylint: disable=g-bad-exception-name
def _eval_index_map(
module_ctx: ModuleContext,
launch_ctx: mgpu.LaunchContext,
idx: ir.Value,
idx: Sequence[ir.Value],
block_mapping: pallas_core.BlockMapping,
) -> Sequence[ir.Value]:
block_indices = lower_jaxpr_to_mosaic_gpu(
@@ -238,10 +238,7 @@ def lower_jaxpr_to_module(
jaxpr, [True] * len(jaxpr.outvars), instantiate=True
)
grid = grid_mapping.grid
if len(grid) < 3:
grid += (1,) * (3 - len(grid))
block = (128,) + (1,) * (len(grid) - 1)
block = (128, 1, 1)
params = compiler_params.get("mosaic_gpu", {})
approx_math = params.get("approx_math", False)
max_concurrent_steps = params.get("max_concurrent_steps", 1)
@@ -256,8 +253,25 @@ def lower_jaxpr_to_module(
sequential_axes = tuple(
i for i, s in enumerate(dimension_semantics) if s == "sequential"
)
assert all(grid[axis] for axis in sequential_axes)
assert all(block[axis] == 1 for axis in sequential_axes)
grid = [d for i, d in enumerate(grid_mapping.grid) if i not in sequential_axes]
if len(grid) < 3:
grid += (1,) * (3 - len(grid))
else:
raise NotImplementedError(
"Only <=3D grids are supported in Mosaic GPU lowering."
)
# Compute the number of steps along each sequential axis.
if sequential_axes:
# TODO(slebedev): Support multiple sequential axes.
if len(sequential_axes) > 1:
raise NotImplementedError(
"Multiple sequential axes are not supported in Mosaic GPU lowering."
)
[sequential_axis] = sequential_axes
num_steps = grid_mapping.grid[sequential_axis]
else:
num_steps = 1
in_in_smem, out_in_smem = util.split_list(
[
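The hunk above separates the two roles a grid axis can play: parallel axes stay in the CUDA launch grid (padded to 3D), while the single supported sequential axis becomes a loop inside the kernel with num_steps iterations. A small illustrative sketch of that bookkeeping (hypothetical helper, not part of the lowering):

    def split_grid(grid, dimension_semantics):
        # Sequential axes are iterated inside the kernel, so only parallel
        # axes remain in the (at most 3D) CUDA launch grid.
        sequential_axes = tuple(
            i for i, s in enumerate(dimension_semantics) if s == "sequential")
        launch_grid = [d for i, d in enumerate(grid) if i not in sequential_axes]
        launch_grid += [1] * (3 - len(launch_grid))
        num_steps = grid[sequential_axes[0]] if sequential_axes else 1
        return tuple(launch_grid), num_steps

    # A (grid_m, grid_k, grid_n) matmul grid with a sequential k axis:
    assert split_grid((132, 10, 4), ("parallel", "sequential", "parallel")) == ((132, 4, 1), 10)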
@@ -268,10 +282,9 @@ def lower_jaxpr_to_module(
)
in_structs_gmem = [*grid_mapping.in_shapes]
in_block_shapes = [
bm.block_shape
for bm in grid_mapping.block_mappings[: grid_mapping.num_inputs]
]
in_block_mappings, out_block_mappings = util.split_list(
block_mappings, [grid_mapping.num_inputs]
)
in_structs_smem = [
jax.ShapeDtypeStruct(
[max_concurrent_steps, *bm.ref_aval.shape], bm.ref_aval.dtype
@@ -283,8 +296,7 @@ def lower_jaxpr_to_module(
)
]
in_gmem_transforms = [
cast(gpu_core.MemoryRefTransform, bm.transforms)
cast(gpu_core.MemoryRefTransform, bm.transforms)
for bm in grid_mapping.block_mappings[: grid_mapping.num_inputs]
]
in_swizzles = map(
@@ -322,17 +334,14 @@ def lower_jaxpr_to_module(
)
barriers, *extra_barriers = barriers
parallel_count = it.count()
program_ids_template = [
_program_id(next(parallel_count)) if i not in sequential_axes else None
for i in range(len(grid_mapping.grid))
]
module_ctx = ModuleContext(
name_and_src_info.name, grid_mapping, approx_math, runtime_smem
)
program_ids = map(_program_id, range(len(grid_mapping.grid)))
start_indices = map(
partial(_eval_index_map, module_ctx, launch_ctx, program_ids),
block_mappings,
)
in_start_indices, out_start_indices = util.split_list(
start_indices, [grid_mapping.num_inputs]
)
smem_scratch_it = iter(scratch_buffers_smem)
scratch_buffers_template = []
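program_ids_template is the key new piece here: parallel axes get their index straight from the hardware grid, while sequential axes are left as None placeholders that gmem_slice later fills with the current loop step. An illustrative stand-in (strings instead of ir.Values):

    import itertools as it

    def make_program_ids_template(grid, sequential_axes, program_id):
        # One entry per logical grid axis; None marks a sequential axis whose
        # index is only known once the kernel's step loop runs.
        parallel_count = it.count()
        return [
            program_id(next(parallel_count)) if i not in sequential_axes else None
            for i in range(len(grid))
        ]

    template = make_program_ids_template((132, 10, 4), (1,), lambda d: f"pid{d}")
    assert template == ["pid0", None, "pid1"]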
@@ -385,20 +394,14 @@ def lower_jaxpr_to_module(
)
def gmem_slice(
start_indices: Sequence[ir.Value],
step: ir.Value,
shape: Sequence[int],
block_mapping: pallas_core.BlockMapping,
) -> Sequence[mgpu.DynamicSlice]:
assert len(sequential_axes) == 1
program_ids = [step if i is None else i for i in program_ids_template]
idxs = _eval_index_map(module_ctx, launch_ctx, program_ids, block_mapping)
return tuple(
mgpu.ds(
arith_dialect.addi(
start_index, arith_dialect.muli(step, _as_index(dim))
)
if axis in sequential_axes
else start_index,
dim,
)
for axis, (start_index, dim) in enumerate(zip(start_indices, shape))
mgpu.ds(idx, dim) for idx, dim in zip(idxs, block_mapping.block_shape)
)
def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None:
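With the template in place, gmem_slice no longer needs precomputed start indices: it substitutes the loop step for the sequential axis, evaluates the block mapping's index map, and emits one dynamic slice of block_shape extent per dimension. The same idea in plain Python (illustrative only; the real code builds mgpu.ds values from MLIR ir.Values):

    def gmem_slice_sketch(step, program_ids_template, index_map, block_shape):
        # Fill in the sequential-axis placeholder, evaluate the block spec,
        # and turn the resulting element offsets into slices.
        program_ids = [step if pid is None else pid for pid in program_ids_template]
        starts = [i * b for i, b in zip(index_map(*program_ids), block_shape)]
        return tuple(slice(s, s + b) for s, b in zip(starts, block_shape))

    # LHS block of a matmul on step 3 of the sequential k loop, with the
    # m-axis program id equal to 1:
    slices = gmem_slice_sketch(3, [1, None], lambda m, k: (m, k), (64, 64))
    # -> (slice(64, 128), slice(192, 256))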
@@ -410,11 +413,7 @@ def lower_jaxpr_to_module(
launch_ctx.async_copy(
src_ref=in_buffers_gmem[idx],
dst_ref=mgpu.memref_slice(in_buffers_smem[idx], slot),
gmem_slice=gmem_slice(
in_start_indices[idx],
step,
in_block_shapes[idx],
),
gmem_slice=gmem_slice(step, in_block_mappings[idx]),
barrier=barriers[slot],
gmem_transform=tuple(gmem_transforms),
swizzle=in_swizzles[idx],
@@ -430,27 +429,11 @@ def lower_jaxpr_to_module(
launch_ctx.async_copy(
src_ref=mgpu.memref_slice(out_buffers_smem[idx], slot),
dst_ref=out_buffers_gmem[idx],
gmem_slice=gmem_slice(
out_start_indices[idx],
step,
ir.MemRefType(out_buffers_smem[idx].type).shape[1:],
),
gmem_slice=gmem_slice(step, out_block_mappings[idx]),
swizzle=None,
uniform=False,
)
# Compute the number of steps along each sequential axis.
if sequential_axes:
# TODO(slebedev): Support multiple sequential axes.
if len(sequential_axes) > 1:
raise NotImplementedError(
"Multiple sequential axes are not supported in Mosaic GPU lowering."
)
[sequential_axis] = sequential_axes
num_steps = grid_mapping.grid[sequential_axis]
else:
num_steps = 1
with mgpu.single_thread():
for slot in range(min(max_concurrent_steps, num_steps)):
barriers[slot].arrive_expect_tx(in_transfer_bytes)
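The block deleted at the end of this hunk is the num_steps computation that moved up next to the grid handling. What remains is the pipeline prologue, which pre-issues up to max_concurrent_steps copies so the copy for step i + max_concurrent_steps can overlap with compute on step i. Schematically (hypothetical fetch/compute callbacks, not the emitted MLIR):

    def pipeline_schedule(num_steps, max_concurrent_steps, fetch, compute):
        # Prologue: one fetch per SMEM slot.
        for slot in range(min(max_concurrent_steps, num_steps)):
            fetch(step=slot, slot=slot)
        for step in range(num_steps):
            slot = step % max_concurrent_steps
            compute(step=step, slot=slot)  # waits on the slot's barrier first
            next_step = step + max_concurrent_steps
            if next_step < num_steps:
                fetch(step=next_step, slot=slot)  # refill the slot just consumed

    log = []
    pipeline_schedule(4, 2,
                      lambda step, slot: log.append(("fetch", step, slot)),
                      lambda step, slot: log.append(("compute", step, slot)))
    # log: fetch 0 and 1, then compute 0 / fetch 2, compute 1 / fetch 3,
    # compute 2, compute 3.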
@@ -619,6 +602,7 @@ def lower_jaxpr_to_mosaic_gpu(
@register_lowering_rule(primitives.program_id_p)
def _program_id_lowering_rule(ctx: LoweringRuleContext, axis):
# TODO(apaszke): Sequential axis should be handled specially!!
del ctx # Unused.
return _program_id(axis)


@@ -462,13 +462,8 @@ class PallasCallTest(PallasTest):
dtype = jnp.float16
swizzle = 128
elems_128b = swizzle // jnp.dtype(dtype).itemsize
# TODO(apaszke): Make the grid and tile sizes larger
# grid_m, grid_k, grid_n = 132, 10, 4
# TODO(apaszke): Increasing grid_k causes the test to fail.
# It seems like our pipelining implementation has a number of races.
grid_m, grid_k, grid_n = 2, 1, 2
# tile_m = tile_n = 128
tile_m = tile_n = 64
grid_m, grid_k, grid_n = 132, 10, 4
tile_m = tile_n = 128
tile_k = elems_128b
m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n
def kernel(a_ref, b_ref, o_ref, acc_ref):
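For reference, the restored test parameters imply the following problem size (quick arithmetic check, assuming float16 as in the test):

    swizzle, itemsize = 128, 2            # float16
    elems_128b = swizzle // itemsize      # 64
    grid_m, grid_k, grid_n = 132, 10, 4
    tile_m = tile_n = 128
    tile_k = elems_128b
    m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n
    assert (m, k, n) == (16896, 640, 512)  # 10 sequential k steps per output tile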