Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Added pl.program_id and pl.num_programs to Mosaic GPU lowering
PiperOrigin-RevId: 662836490
Commit 6290cd77fc (parent 2ab7558425)
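This change adds Mosaic GPU lowering rules for the Pallas grid-query primitives: pl.program_id(axis) returns the index of the current program instance along a grid axis, and pl.num_programs(axis) returns the number of program instances along that axis. A minimal usage sketch, not part of the diff, assuming the Mosaic GPU backend is enabled the same way the tests below enable it:

import functools
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

@functools.partial(
    pl.pallas_call,
    in_specs=(),
    out_specs=pl.BlockSpec((128,), lambda *i: i),
    out_shape=jax.ShapeDtypeStruct([128 * 4], jnp.int32),
    grid=4,
)
def block_index_kernel(o_ref):
  # Each of the four grid steps fills its own 128-element block with its index.
  o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0))

print(block_index_kernel())  # [0]*128 + [1]*128 + [2]*128 + [3]*128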
jax/_src/pallas/mosaic_gpu/lowering.py

@@ -19,7 +19,6 @@ from __future__ import annotations
 from collections.abc import Sequence
 import dataclasses
 import functools
-import itertools as it
 import math
 from typing import Any, cast
 
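itertools is no longer needed: its only use was the it.islice over gpu_dialect.Dimension in the old inline program_ids computation, which a later hunk replaces with the new _program_id helper.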
@@ -161,22 +160,6 @@ def lower_jaxpr_to_module(
 ) -> LoweringResult:
   del cost_estimate  # Unused.
 
-  in_structs_gmem = [*grid_mapping.in_shapes]
-  in_structs_smem = [
-      jax.ShapeDtypeStruct(bm.block_shape, s.dtype)
-      for bm, s in zip(
-          grid_mapping.block_mappings[: grid_mapping.num_inputs],
-          grid_mapping.in_shapes,
-      )
-  ]
-  out_structs_gmem = [*grid_mapping.out_shapes]
-  out_structs_smem = [
-      jax.ShapeDtypeStruct(bm.block_shape, s.dtype)
-      for bm, s in zip(
-          grid_mapping.block_mappings[grid_mapping.num_inputs :],
-          grid_mapping.out_shapes,
-      )
-  ]
   assert len(jaxpr.outvars) == 0
   assert not grid_mapping.vmapped_dims
   if len(grid_mapping.grid) > 3:
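Note that the in_structs_*/out_structs_* block is moved rather than deleted: the next hunk re-inserts it unchanged further down in lower_jaxpr_to_module, after the grid and block dimensions have been normalized, placing it next to the body callback that uses it.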
@@ -209,31 +192,46 @@ def lower_jaxpr_to_module(
   grid += (1,) * (3 - len(grid))
   block = (128,) + (1,) * (len(grid) - 1)
 
+  in_structs_gmem = [*grid_mapping.in_shapes]
+  in_structs_smem = [
+      jax.ShapeDtypeStruct(bm.block_shape, s.dtype)
+      for bm, s in zip(
+          grid_mapping.block_mappings[: grid_mapping.num_inputs],
+          grid_mapping.in_shapes,
+      )
+  ]
+  out_structs_gmem = [*grid_mapping.out_shapes]
+  out_structs_smem = [
+      jax.ShapeDtypeStruct(bm.block_shape, s.dtype)
+      for bm, s in zip(
+          grid_mapping.block_mappings[grid_mapping.num_inputs :],
+          grid_mapping.out_shapes,
+      )
+  ]
+
   def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers):
     *buffers_gmem, (*buffers_smem, runtime_smem, barriers) = buffers
     assert len(buffers_gmem) == len(buffers_smem)
-    in_buffers_gmem = buffers_gmem[: len(in_structs_gmem)]
-    in_buffers_smem = buffers_smem[: len(in_structs_smem)]
-    out_buffers_gmem = buffers_gmem[len(in_structs_gmem) :]
-    out_buffers_smem = buffers_smem[len(in_structs_smem) :]
+    in_buffers_gmem, out_buffers_gmem = util.split_list(
+        buffers_gmem, [grid_mapping.num_inputs]
+    )
+    in_buffers_smem, out_buffers_smem = util.split_list(
+        buffers_smem, [grid_mapping.num_inputs]
+    )
 
     [barrier] = cast(mgpu.BarrierRef, barriers)
 
     module_ctx = ModuleContext(
         name_and_src_info.name, grid_mapping, runtime_smem, smem_used_bytes=0
     )
-    program_ids = [
-        arith_dialect.index_cast(
-            ir.IntegerType.get_signless(32), gpu_dialect.block_id(dim)
-        )
-        for dim in it.islice(gpu_dialect.Dimension, len(grid_mapping.grid))
-    ]
+    program_ids = map(_program_id, range(len(grid_mapping.grid)))
     start_indices = map(
         functools.partial(_eval_index_map, module_ctx, program_ids),
         grid_mapping.block_mappings,
     )
-    in_start_indices = start_indices[: len(in_structs_gmem)]
-    out_start_indices = start_indices[len(in_structs_gmem) :]
+    in_start_indices, out_start_indices = util.split_list(
+        start_indices, [grid_mapping.num_inputs]
+    )
 
     with mgpu.single_thread():
       for start_indices, b_gmem, b_smem in zip(
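All three slicing sites are replaced with util.split_list, which splits a list at the given sizes and returns the remainder as the last chunk, so the input/output boundary is expressed once via grid_mapping.num_inputs instead of being recomputed from len(in_structs_gmem). A minimal sketch of its semantics (values are illustrative):

from jax._src import util

buffers = ["in0", "in1", "out0"]
ins, outs = util.split_list(buffers, [2])
assert ins == ["in0", "in1"] and outs == ["out0"]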
@@ -252,7 +250,9 @@ def lower_jaxpr_to_module(
             uniform=False,
         )
 
-    barrier.wait()
+    if grid_mapping.num_inputs:
+      # Only wait if async copies were issued.
+      barrier.wait()
 
     _ = lower_jaxpr_to_mosaic_gpu(module_ctx, jaxpr, buffers_smem)
     mgpu.commit_shared()
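The new comment states the rationale: the barrier is only armed by the async_copy calls that stage inputs into SMEM, so a kernel with no inputs (such as the tests added below, which pass in_specs=()) never issues a copy, and an unconditional barrier.wait() would wait for an arrival that never happens.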
@@ -359,6 +359,28 @@ def lower_jaxpr_to_mosaic_gpu(
   return map(read_env, jaxpr.outvars)
 
 
+@register_lowering_rule(primitives.program_id_p)
+def _program_id_lowering_rule(ctx: LoweringRuleContext, axis):
+  del ctx  # Unused.
+  return _program_id(axis)
+
+
+def _program_id(axis: int) -> ir.Value:
+  return arith_dialect.index_cast(
+      ir.IntegerType.get_signless(32),
+      gpu_dialect.block_id(gpu_dialect.Dimension(axis)),
+  )
+
+
+@register_lowering_rule(primitives.num_programs_p)
+def _num_programs_lowering_rule(ctx: LoweringRuleContext, axis):
+  del ctx  # Unused.
+  # The number of programs along an axis is the grid dimension, not the
+  # thread-block dimension.
+  return arith_dialect.index_cast(
+      ir.IntegerType.get_signless(32),
+      gpu_dialect.grid_dim(gpu_dialect.Dimension(axis)),
+  )
+
+
 @register_lowering_rule(sp.get_p)
 def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *indexers, tree):
   del ctx, tree  # Unused.
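Each Pallas program instance corresponds to one CUDA thread block here: the launch computed earlier uses the grid for program instances and block = (128,) + ... threads inside each instance. A sketch of the correspondence for the one-dimensional grid=2 launches used in the tests below (comments only, illustrative):

# pl.program_id(0)   lowers to gpu.block_id(x)  -> 0 or 1
# pl.num_programs(0) lowers to gpu.grid_dim(x)  -> 2
# gpu.block_dim(x) would be 128, the number of threads inside one program
# instance, not the number of programs, so it is the wrong op for
# pl.num_programs (compare test_num_programs, which expects 2).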
@@ -510,6 +532,9 @@ def _ensure_fa(x: object, aval: jax_core.ShapedArray) -> mgpu.FragmentedArray:
     return mgpu.FragmentedArray.splat(
         _ir_constant(x, mlir.dtype_to_ir_type(aval.dtype)), ()
     )
+  elif isinstance(x, ir.Value):
+    if isinstance(x.type, (ir.IntegerType, ir.FloatType)):
+      return mgpu.FragmentedArray.splat(x, ())
   raise NotImplementedError
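This extension lets _ensure_fa accept a bare scalar ir.Value of integer or float type, such as the 32-bit value produced by _program_id, by splatting it into a rank-0 mgpu.FragmentedArray. Without it, expressions like jnp.full(o_ref.shape, pl.program_id(0)) in the tests below would fall through to the NotImplementedError.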
tests/pallas/mosaic_gpu_test.py

@@ -162,6 +162,38 @@ class PallasCallTest(PallasTest):
     o = f(inp)
     np.testing.assert_array_equal(o, inp + 1.0)
 
+  def test_program_id(self):
+    @functools.partial(
+        pl.pallas_call,
+        in_specs=(),
+        out_specs=pl.BlockSpec((128,), lambda *i: i),
+        out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32),
+        grid=2,
+    )
+    def kernel(o_ref):
+      o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0))
+
+    np.testing.assert_array_equal(
+        kernel(),
+        jnp.array([0] * 128 + [1] * 128, dtype=jnp.int32),
+    )
+
+  def test_num_programs(self):
+    @functools.partial(
+        pl.pallas_call,
+        in_specs=(),
+        out_specs=pl.BlockSpec((128,), lambda *i: i),
+        out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32),
+        grid=2,
+    )
+    def kernel(o_ref):
+      o_ref[...] = jnp.full(o_ref.shape, pl.num_programs(0))
+
+    np.testing.assert_array_equal(
+        kernel(),
+        jnp.full([256], 2, dtype=jnp.int32),
+    )
+
 
 if __name__ == "__main__":
   absltest.main()