Mirror of https://github.com/ROCm/jax.git
[Pallas/MGPU] Implement block spec evaluation correctly
The previous implementation made some surprising assumptions about the contents of the block specs and wasn't correct in general. The new implementation handles all the cases and seems to be sufficient to finally run the matmul example with multiple k steps while producing correct results (it's also shorter!).

PiperOrigin-RevId: 679175212
parent a3284bd8a3
commit 076287fb5c
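For context, "evaluating a block spec" amounts to this: the spec's index_map is re-evaluated at the current program ids to produce block indices, and each block index is scaled by the block shape to get the starting element offset of the GMEM slice. A minimal sketch in plain Python (the helper name and example shapes are hypothetical, not the lowering's actual API):

from typing import Callable, Sequence

def eval_block_spec(
    index_map: Callable[..., Sequence[int]],
    block_shape: Sequence[int],
    program_ids: Sequence[int],
) -> tuple[slice, ...]:
  # The index map returns *block* indices; scaling by the block shape
  # turns them into element offsets for the GMEM slice.
  block_indices = index_map(*program_ids)
  return tuple(
      slice(idx * dim, (idx + 1) * dim)
      for idx, dim in zip(block_indices, block_shape)
  )

# Example: (64, 64) blocks, program ids (i=2, k=3).
assert eval_block_spec(lambda i, k: (i, k), (64, 64), (2, 3)) == (
    slice(128, 192),
    slice(192, 256),
)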
@@ -185,7 +185,7 @@ class LoweringError(Exception):  # pylint: disable=g-bad-exception-name
 def _eval_index_map(
     module_ctx: ModuleContext,
     launch_ctx: mgpu.LaunchContext,
-    idx: ir.Value,
+    idx: Sequence[ir.Value],
     block_mapping: pallas_core.BlockMapping,
 ) -> Sequence[ir.Value]:
   block_indices = lower_jaxpr_to_mosaic_gpu(
@@ -238,10 +238,7 @@ def lower_jaxpr_to_module(
       jaxpr, [True] * len(jaxpr.outvars), instantiate=True
   )
 
-  grid = grid_mapping.grid
-  if len(grid) < 3:
-    grid += (1,) * (3 - len(grid))
-  block = (128,) + (1,) * (len(grid) - 1)
+  block = (128, 1, 1)
   params = compiler_params.get("mosaic_gpu", {})
   approx_math = params.get("approx_math", False)
   max_concurrent_steps = params.get("max_concurrent_steps", 1)
@@ -256,8 +253,25 @@ def lower_jaxpr_to_module(
   sequential_axes = tuple(
       i for i, s in enumerate(dimension_semantics) if s == "sequential"
   )
-  assert all(grid[axis] for axis in sequential_axes)
-  assert all(block[axis] == 1 for axis in sequential_axes)
+  grid = [d for i, d in enumerate(grid_mapping.grid) if i not in sequential_axes]
+  if len(grid) < 3:
+    grid += (1,) * (3 - len(grid))
+  else:
+    raise NotImplementedError(
+        "Only <=3D grids are supported in Mosaic GPU lowering."
+    )
+
+  # Compute the number of steps along each sequential axis.
+  if sequential_axes:
+    # TODO(slebedev): Support multiple sequential axes.
+    if len(sequential_axes) > 1:
+      raise NotImplementedError(
+          "Multiple sequential axes are not supported in Mosaic GPU lowering."
+      )
+    [sequential_axis] = sequential_axes
+    num_steps = grid_mapping.grid[sequential_axis]
+  else:
+    num_steps = 1
 
   in_in_smem, out_in_smem = util.split_list(
       [
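The hunk above changes how the CUDA launch grid is derived: sequential axes are peeled off the Pallas grid (they become the pipeline loop rather than CUDA grid dimensions) and the remainder is padded to 3D. A small sketch of that computation in plain Python (the grid and axis values are hypothetical):

pallas_grid = (4, 10, 2)      # hypothetical (m, k, n) Pallas grid
sequential_axes = (1,)        # k is pipelined sequentially
grid = [d for i, d in enumerate(pallas_grid) if i not in sequential_axes]
if len(grid) < 3:
  grid += (1,) * (3 - len(grid))  # list += tuple pads in place
assert tuple(grid) == (4, 2, 1)
# The sequential axis's extent becomes the number of pipeline steps.
num_steps = pallas_grid[sequential_axes[0]] if sequential_axes else 1
assert num_steps == 10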
@@ -268,10 +282,9 @@ def lower_jaxpr_to_module(
   )
 
   in_structs_gmem = [*grid_mapping.in_shapes]
-  in_block_shapes = [
-      bm.block_shape
-      for bm in grid_mapping.block_mappings[: grid_mapping.num_inputs]
-  ]
+  in_block_mappings, out_block_mappings = util.split_list(
+      block_mappings, [grid_mapping.num_inputs]
+  )
   in_structs_smem = [
       jax.ShapeDtypeStruct(
           [max_concurrent_steps, *bm.ref_aval.shape], bm.ref_aval.dtype
@@ -283,8 +296,7 @@ def lower_jaxpr_to_module(
       )
   ]
   in_gmem_transforms = [
       cast(gpu_core.MemoryRefTransform, bm.transforms)
-
       for bm in grid_mapping.block_mappings[: grid_mapping.num_inputs]
   ]
   in_swizzles = map(
@@ -322,17 +334,14 @@ def lower_jaxpr_to_module(
   )
   barriers, *extra_barriers = barriers
 
+  parallel_count = it.count()
+  program_ids_template = [
+      _program_id(next(parallel_count)) if i not in sequential_axes else None
+      for i in range(len(grid_mapping.grid))
+  ]
   module_ctx = ModuleContext(
       name_and_src_info.name, grid_mapping, approx_math, runtime_smem
   )
-  program_ids = map(_program_id, range(len(grid_mapping.grid)))
-  start_indices = map(
-      partial(_eval_index_map, module_ctx, launch_ctx, program_ids),
-      block_mappings,
-  )
-  in_start_indices, out_start_indices = util.split_list(
-      start_indices, [grid_mapping.num_inputs]
-  )
 
   smem_scratch_it = iter(scratch_buffers_smem)
   scratch_buffers_template = []
@@ -385,20 +394,14 @@ def lower_jaxpr_to_module(
   )
 
   def gmem_slice(
-      start_indices: Sequence[ir.Value],
       step: ir.Value,
-      shape: Sequence[int],
+      block_mapping: pallas_core.BlockMapping,
   ) -> Sequence[mgpu.DynamicSlice]:
+    assert len(sequential_axes) == 1
+    program_ids = [step if i is None else i for i in program_ids_template]
+    idxs = _eval_index_map(module_ctx, launch_ctx, program_ids, block_mapping)
     return tuple(
-        mgpu.ds(
-            arith_dialect.addi(
-                start_index, arith_dialect.muli(step, _as_index(dim))
-            )
-            if axis in sequential_axes
-            else start_index,
-            dim,
-        )
-        for axis, (start_index, dim) in enumerate(zip(start_indices, shape))
+        mgpu.ds(idx, dim) for idx, dim in zip(idxs, block_mapping.block_shape)
     )
 
   def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None:
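The key mechanism in the new gmem_slice is the program_ids_template: parallel axes get their CUDA program ids up front, while each sequential axis is left as None and substituted with the current pipeline step every time the index map is re-evaluated. A tiny sketch of that substitution in plain Python (the string values are hypothetical stand-ins for IR values):

sequential_axes = (1,)                    # e.g. the k axis of a matmul
grid = (4, 10, 2)                         # hypothetical (m, k, n) grid
parallel_ids = iter(("pid_m", "pid_n"))   # stand-ins for _program_id(...)
program_ids_template = [
    None if i in sequential_axes else next(parallel_ids)
    for i in range(len(grid))
]
# At each pipeline step, the None slot is filled with the step counter.
step = "step_k"                           # stand-in for the loop counter
program_ids = [step if i is None else i for i in program_ids_template]
assert program_ids == ["pid_m", "step_k", "pid_n"]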
@@ -410,11 +413,7 @@ def lower_jaxpr_to_module(
     launch_ctx.async_copy(
         src_ref=in_buffers_gmem[idx],
         dst_ref=mgpu.memref_slice(in_buffers_smem[idx], slot),
-        gmem_slice=gmem_slice(
-            in_start_indices[idx],
-            step,
-            in_block_shapes[idx],
-        ),
+        gmem_slice=gmem_slice(step, in_block_mappings[idx]),
         barrier=barriers[slot],
         gmem_transform=tuple(gmem_transforms),
         swizzle=in_swizzles[idx],
@@ -430,27 +429,11 @@ def lower_jaxpr_to_module(
     launch_ctx.async_copy(
         src_ref=mgpu.memref_slice(out_buffers_smem[idx], slot),
         dst_ref=out_buffers_gmem[idx],
-        gmem_slice=gmem_slice(
-            out_start_indices[idx],
-            step,
-            ir.MemRefType(out_buffers_smem[idx].type).shape[1:],
-        ),
+        gmem_slice=gmem_slice(step, out_block_mappings[idx]),
         swizzle=None,
         uniform=False,
     )
 
-  # Compute the number of steps along each sequential axis.
-  if sequential_axes:
-    # TODO(slebedev): Support multiple sequential axes.
-    if len(sequential_axes) > 1:
-      raise NotImplementedError(
-          "Multiple sequential axes are not supported in Mosaic GPU lowering."
-      )
-    [sequential_axis] = sequential_axes
-    num_steps = grid_mapping.grid[sequential_axis]
-  else:
-    num_steps = 1
-
   with mgpu.single_thread():
     for slot in range(min(max_concurrent_steps, num_steps)):
       barriers[slot].arrive_expect_tx(in_transfer_bytes)
@@ -619,6 +602,7 @@ def lower_jaxpr_to_mosaic_gpu(
 
 @register_lowering_rule(primitives.program_id_p)
 def _program_id_lowering_rule(ctx: LoweringRuleContext, axis):
+  # TODO(apaszke): Sequential axis should be handled specially!!
   del ctx  # Unused.
   return _program_id(axis)
 
@@ -462,13 +462,8 @@ class PallasCallTest(PallasTest):
     dtype = jnp.float16
     swizzle = 128
     elems_128b = swizzle // jnp.dtype(dtype).itemsize
-    # TODO(apaszke): Make the grid and tile sizes larger
-    # grid_m, grid_k, grid_n = 132, 10, 4
-    # TODO(apaszke): Increasing grid_k causes th test to fail.
-    # It seems like our pipelining implementation has a number of races.
-    grid_m, grid_k, grid_n = 2, 1, 2
-    # tile_m = tile_n = 128
-    tile_m = tile_n = 64
+    grid_m, grid_k, grid_n = 132, 10, 4
+    tile_m = tile_n = 128
     tile_k = elems_128b
     m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n
 
     def kernel(a_ref, b_ref, o_ref, acc_ref):
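For reference, the problem size the re-enabled test now exercises works out as follows (a quick arithmetic check, not part of the diff): float16 has a 2-byte itemsize, so a 128-byte swizzle spans 64 elements.

grid_m, grid_k, grid_n = 132, 10, 4
tile_m = tile_n = 128
tile_k = 128 // 2  # elems_128b for float16 with a 128-byte swizzle
m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n
assert (m, k, n) == (16896, 640, 512)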