Added pl.program_id and pl.num_programs to Mosaic GPU lowering

PiperOrigin-RevId: 662836490
Authored by Sergei Lebedev on 2024-08-14 02:22:50 -07:00; committed by jax authors
parent 2ab7558425
commit 6290cd77fc
2 changed files with 87 additions and 30 deletions
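
For context, the user-visible effect of this change is that pl.program_id and pl.num_programs become usable inside Pallas kernels lowered through the Mosaic GPU backend. Below is a minimal sketch mirroring the tests added in this commit; the shapes and grid size are illustrative, and selecting the Mosaic GPU backend (handled by the test harness) is omitted here.

import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


@functools.partial(
    pl.pallas_call,
    in_specs=(),
    out_specs=pl.BlockSpec((128,), lambda *i: i),
    out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32),
    grid=2,
)
def kernel(o_ref):
  # pl.program_id(0) is the index of the current program along grid axis 0;
  # pl.num_programs(0) would evaluate to the grid size along that axis (2 here).
  o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0))


out = kernel()  # expected: 128 zeros followed by 128 ones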


@@ -19,7 +19,6 @@ from __future__ import annotations
 from collections.abc import Sequence
 import dataclasses
 import functools
-import itertools as it
 import math
 from typing import Any, cast
@@ -161,22 +160,6 @@ def lower_jaxpr_to_module(
 ) -> LoweringResult:
   del cost_estimate  # Unused.
-  in_structs_gmem = [*grid_mapping.in_shapes]
-  in_structs_smem = [
-      jax.ShapeDtypeStruct(bm.block_shape, s.dtype)
-      for bm, s in zip(
-          grid_mapping.block_mappings[: grid_mapping.num_inputs],
-          grid_mapping.in_shapes,
-      )
-  ]
-  out_structs_gmem = [*grid_mapping.out_shapes]
-  out_structs_smem = [
-      jax.ShapeDtypeStruct(bm.block_shape, s.dtype)
-      for bm, s in zip(
-          grid_mapping.block_mappings[grid_mapping.num_inputs :],
-          grid_mapping.out_shapes,
-      )
-  ]
   assert len(jaxpr.outvars) == 0
   assert not grid_mapping.vmapped_dims
   if len(grid_mapping.grid) > 3:
@@ -209,31 +192,46 @@ def lower_jaxpr_to_module(
   grid += (1,) * (3 - len(grid))
   block = (128,) + (1,) * (len(grid) - 1)
+  in_structs_gmem = [*grid_mapping.in_shapes]
+  in_structs_smem = [
+      jax.ShapeDtypeStruct(bm.block_shape, s.dtype)
+      for bm, s in zip(
+          grid_mapping.block_mappings[: grid_mapping.num_inputs],
+          grid_mapping.in_shapes,
+      )
+  ]
+  out_structs_gmem = [*grid_mapping.out_shapes]
+  out_structs_smem = [
+      jax.ShapeDtypeStruct(bm.block_shape, s.dtype)
+      for bm, s in zip(
+          grid_mapping.block_mappings[grid_mapping.num_inputs :],
+          grid_mapping.out_shapes,
+      )
+  ]
+
   def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers):
     *buffers_gmem, (*buffers_smem, runtime_smem, barriers) = buffers
     assert len(buffers_gmem) == len(buffers_smem)
-    in_buffers_gmem = buffers_gmem[: len(in_structs_gmem)]
-    in_buffers_smem = buffers_smem[: len(in_structs_smem)]
-    out_buffers_gmem = buffers_gmem[len(in_structs_gmem) :]
-    out_buffers_smem = buffers_smem[len(in_structs_smem) :]
+    in_buffers_gmem, out_buffers_gmem = util.split_list(
+        buffers_gmem, [grid_mapping.num_inputs]
+    )
+    in_buffers_smem, out_buffers_smem = util.split_list(
+        buffers_smem, [grid_mapping.num_inputs]
+    )
     [barrier] = cast(mgpu.BarrierRef, barriers)
     module_ctx = ModuleContext(
         name_and_src_info.name, grid_mapping, runtime_smem, smem_used_bytes=0
     )
-    program_ids = [
-        arith_dialect.index_cast(
-            ir.IntegerType.get_signless(32), gpu_dialect.block_id(dim)
-        )
-        for dim in it.islice(gpu_dialect.Dimension, len(grid_mapping.grid))
-    ]
+    program_ids = map(_program_id, range(len(grid_mapping.grid)))
     start_indices = map(
         functools.partial(_eval_index_map, module_ctx, program_ids),
         grid_mapping.block_mappings,
     )
-    in_start_indices = start_indices[: len(in_structs_gmem)]
-    out_start_indices = start_indices[len(in_structs_gmem) :]
+    in_start_indices, out_start_indices = util.split_list(
+        start_indices, [grid_mapping.num_inputs]
+    )
     with mgpu.single_thread():
       for start_indices, b_gmem, b_smem in zip(
@@ -252,7 +250,9 @@ def lower_jaxpr_to_module(
            uniform=False,
        )
-    barrier.wait()
+    if grid_mapping.num_inputs:
+      # Only wait if async copies were issued.
+      barrier.wait()
     _ = lower_jaxpr_to_mosaic_gpu(module_ctx, jaxpr, buffers_smem)
     mgpu.commit_shared()
@@ -359,6 +359,28 @@ def lower_jaxpr_to_mosaic_gpu(
   return map(read_env, jaxpr.outvars)
 
 
+@register_lowering_rule(primitives.program_id_p)
+def _program_id_lowering_rule(ctx: LoweringRuleContext, axis):
+  del ctx  # Unused.
+  return _program_id(axis)
+
+
+def _program_id(axis: int) -> ir.Value:
+  return arith_dialect.index_cast(
+      ir.IntegerType.get_signless(32),
+      gpu_dialect.block_id(gpu_dialect.Dimension(axis)),
+  )
+
+
+@register_lowering_rule(primitives.num_programs_p)
+def _num_programs_lowering_rule(ctx: LoweringRuleContext, axis):
+  del ctx  # Unused.
+  # num_programs is the grid size along ``axis``, i.e. the number of blocks.
+  return arith_dialect.index_cast(
+      ir.IntegerType.get_signless(32),
+      gpu_dialect.grid_dim(gpu_dialect.Dimension(axis)),
+  )
+
+
 @register_lowering_rule(sp.get_p)
 def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *indexers, tree):
   del ctx, tree  # Unused.
@@ -510,6 +532,9 @@ def _ensure_fa(x: object, aval: jax_core.ShapedArray) -> mgpu.FragmentedArray:
     return mgpu.FragmentedArray.splat(
         _ir_constant(x, mlir.dtype_to_ir_type(aval.dtype)), ()
     )
+  elif isinstance(x, ir.Value):
+    if isinstance(x.type, (ir.IntegerType, ir.FloatType)):
+      return mgpu.FragmentedArray.splat(x, ())
   raise NotImplementedError


@@ -162,6 +162,38 @@ class PallasCallTest(PallasTest):
     o = f(inp)
     np.testing.assert_array_equal(o, inp + 1.0)
 
+  def test_program_id(self):
+    @functools.partial(
+        pl.pallas_call,
+        in_specs=(),
+        out_specs=pl.BlockSpec((128,), lambda *i: i),
+        out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32),
+        grid=2,
+    )
+    def kernel(o_ref):
+      o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0))
+
+    np.testing.assert_array_equal(
+        kernel(),
+        jnp.array([0] * 128 + [1] * 128, dtype=jnp.int32),
+    )
+
+  def test_num_programs(self):
+    @functools.partial(
+        pl.pallas_call,
+        in_specs=(),
+        out_specs=pl.BlockSpec((128,), lambda *i: i),
+        out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32),
+        grid=2,
+    )
+    def kernel(o_ref):
+      o_ref[...] = jnp.full(o_ref.shape, pl.num_programs(0))
+
+    np.testing.assert_array_equal(
+        kernel(),
+        jnp.full([256], 2, dtype=jnp.int32),
+    )
+
 
 if __name__ == "__main__":
   absltest.main()