[Pallas/Mosaic GPU] Add initial support for warpgroup semantics in lowering.

This will allow us to lower Pallas kernels using the Mosaic GPU dialect, and
in turn to perform layout inference and optimization automatically.

The change contains lowering rules for `get` and `swap` (which are necessary
to get a basic example to run), as well as for `add`.

The new lowering path can be used by specifying the `Warpgroup` thread
semantics as part of `pallas_call`'s compiler params.
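
For example, opting a kernel into the new lowering path looks roughly like
the following (a minimal sketch mirroring the test added at the bottom of
this change; all names come from the diff below):

  import functools
  import jax
  import jax.numpy as jnp
  import numpy as np
  from jax.experimental import pallas as pl
  from jax.experimental.mosaic.gpu import core as mosaic_gpu_core
  from jax.experimental.pallas import mosaic_gpu as plgpu

  # Select the new warpgroup lowering path via the compiler params.
  compiler_params = plgpu.GPUCompilerParams(
      thread_semantics=mosaic_gpu_core.ThreadSemantics.Warpgroup
  )

  @functools.partial(
      pl.pallas_call,
      out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
      compiler_params=compiler_params,
  )
  def add_kernel(x_ref, y_ref, o_ref):
    # Reads and writes go through the new `get`/`swap` rules; the
    # addition goes through the new `add` rule.
    o_ref[...] = x_ref[...] + y_ref[...]

  x = jnp.arange(256, dtype=jnp.float32)
  np.testing.assert_array_equal(add_kernel(x, x + 1), x + (x + 1))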

PiperOrigin-RevId: 725958027
Benjamin Chetioui, 2025-02-12 01:47:12 -08:00 (committed by jax authors)
parent 72e7b93b4d
commit 5ad89006c3
6 changed files with 211 additions and 52 deletions


@@ -86,6 +86,7 @@ class GPUCompilerParams(pallas_core.CompilerParams):
delay_release: int = 0
profile_space: int = 0
profile_dir: str = ""
thread_semantics: mgpu.core.ThreadSemantics = mgpu.core.ThreadSemantics.Lane
def __post_init__(self):
if bool(self.profile_space) ^ bool(self.profile_dir):


@@ -39,6 +39,7 @@ from jax._src.lib.mlir.dialects import gpu as gpu_dialect
from jax._src.lib.mlir.dialects import memref as memref_dialect
from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect
from jax._src.lib.mlir.dialects import scf as scf_dialect
from jax._src.lib.mlir.dialects import vector as vector_dialect
from jax._src.pallas import core as pallas_core
from jax._src.pallas import pallas_call
from jax._src.pallas import primitives
@@ -272,6 +273,7 @@ class ModuleContext:
class LoweringRuleContext:
module_ctx: ModuleContext
launch_ctx: mgpu.LaunchContext
thread_semantics: mgpu_core.ThreadSemantics
predicate: ir.Value
avals_in: Sequence[jax_core.ShapedArray]
avals_out: Sequence[jax_core.ShapedArray]
@@ -368,7 +370,7 @@ def lower_jaxpr_to_module(
jaxpr: jax_core.Jaxpr,
name_and_src_info: pallas_core.NameAndSrcInfo,
compiler_params: dict[str, Any],
cost_estimate: pallas_core.CostEstimate | None,
cost_estimate: pallas_core.CostEstimate | None
) -> LoweringResult:
del cost_estimate # Unused.
@@ -390,6 +392,9 @@ def lower_jaxpr_to_module(
approx_math = params.get("approx_math", False)
max_concurrent_steps = params.get("max_concurrent_steps", 1)
delay_release = params.get("delay_release", 0)
thread_semantics = params.get(
"thread_semantics", mgpu_core.ThreadSemantics.Lane
)
dimension_semantics = params.get("dimension_semantics")
if dimension_semantics is None:
dimension_semantics = ["parallel"] * len(grid_mapping.grid)
@@ -724,6 +729,7 @@ def lower_jaxpr_to_module(
launch_ctx,
lowered_jaxpr,
args,
thread_semantics=thread_semantics
)
if not all(out_sequential_invariant):
@@ -826,6 +832,13 @@ def lower_jaxpr_to_module(
prof_spec=prof_spec,
)
)
if thread_semantics == mgpu.ThreadSemantics.Warpgroup:
# Run Python lowering passes. The remaining passes will be run in C++ in
# jax/jaxlib/mosaic/gpu/custom_call.cc
mgpu.infer_layout(module) # pytype: disable=attribute-error
mgpu.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error
mgpu_core._initialize_scratch(launch_ctx, scratch_arr)
return LoweringResult(
@@ -833,12 +846,19 @@
)
mosaic_lowering_rules = {}
mosaic_lowering_rules = {
# Lowering rules when using Mosaic GPU lane semantics.
mgpu.ThreadSemantics.Lane: {},
# Lowering rules when using Mosaic GPU warpgroup semantics.
mgpu.ThreadSemantics.Warpgroup: {},
}
def register_lowering_rule(primitive: jax_core.Primitive):
def register_lowering_rule(
primitive: jax_core.Primitive, thread_semantics: mgpu.ThreadSemantics
):
def deco(fn):
mosaic_lowering_rules[primitive] = fn
mosaic_lowering_rules[thread_semantics][primitive] = fn
return fn
return deco
@@ -863,6 +883,7 @@ def lower_jaxpr_to_mosaic_gpu(
jaxpr: jax_core.Jaxpr,
args: Sequence[ir.Value],
consts=(),
thread_semantics: mgpu.ThreadSemantics = mgpu.ThreadSemantics.Lane,
) -> Sequence[ir.Value]:
env = {}
@@ -884,7 +905,7 @@ def lower_jaxpr_to_mosaic_gpu(
)
loc = mlir._source_info_to_location(module_ctx, eqn.primitive, source_info)
with source_info_util.user_context(eqn.source_info.traceback), loc:
if eqn.primitive not in mosaic_lowering_rules:
if eqn.primitive not in mosaic_lowering_rules[thread_semantics]:
raise NotImplementedError(
"Unimplemented primitive in Pallas Mosaic GPU lowering: "
f"{eqn.primitive.name}. "
@@ -899,10 +920,11 @@ def lower_jaxpr_to_mosaic_gpu(
wrapper_stack = contextlib.ExitStack()
wrapper_stack.enter_context(launch_ctx.named_region(name))
named_regions.append(wrapper_stack)
rule = mosaic_lowering_rules[eqn.primitive]
rule = mosaic_lowering_rules[thread_semantics][eqn.primitive]
rule_ctx = LoweringRuleContext(
module_ctx,
launch_ctx,
thread_semantics,
predicate=mgpu.single_thread_predicate(per_block=False),
avals_in=[cast(jax_core.ShapedArray, v.aval) for v in eqn.invars],
avals_out=[cast(jax_core.ShapedArray, v.aval) for v in eqn.outvars],
@@ -928,7 +950,8 @@ def lower_jaxpr_to_mosaic_gpu(
return map(read_env, jaxpr.outvars)
@register_lowering_rule(primitives.program_id_p)
@register_lowering_rule(primitives.program_id_p, mgpu.ThreadSemantics.Lane)
@register_lowering_rule(primitives.program_id_p, mgpu.ThreadSemantics.Warpgroup)
def _program_id_lowering_rule(ctx: LoweringRuleContext, axis):
if ctx.module_ctx.program_ids is None:
raise NotImplementedError("pl.program_id() is not supported in this context")
@@ -972,7 +995,8 @@ def _program_id(parallel_axis: int, squashed_dims: tuple[int, ...]) -> ir.Value:
)
@register_lowering_rule(primitives.num_programs_p)
@register_lowering_rule(primitives.num_programs_p, mgpu.ThreadSemantics.Lane)
@register_lowering_rule(primitives.num_programs_p, mgpu.ThreadSemantics.Warpgroup)
def _num_programs_lowering_rule(ctx: LoweringRuleContext, axis):
del ctx # Unused.
return arith_dialect.index_cast(
@@ -1056,7 +1080,7 @@ def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ...]:
return tuple(indices)
@register_lowering_rule(sp.get_p)
@register_lowering_rule(sp.get_p, mgpu.ThreadSemantics.Lane)
def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *leaves, tree):
if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem):
raise TypeError(f"Can only load from references (got {x_smem}).")
@@ -1088,7 +1112,33 @@ def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *leaves, tree):
raise NotImplementedError(f"Unsupported transforms: {transforms}")
@register_lowering_rule(sp.swap_p)
@register_lowering_rule(sp.get_p, mgpu.ThreadSemantics.Warpgroup)
def _get_lowering_rule_wg(ctx: LoweringRuleContext, x_smem, *leaves, tree):
if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem):
raise TypeError(f"Can only load from references (got {x_smem}).")
x_aval = ctx.avals_in[0]
transforms = jax.tree.unflatten(tree, leaves)
x_smem, transforms = _handle_reshaping(x_smem, transforms)
x_smem, transforms = _handle_indexing(x_smem, transforms)
if transforms:
raise NotImplementedError(
"Transforms are not yet implemented for warpgroup semantics"
)
shape = ctx.avals_out[0].shape
ty = ir.VectorType.get(shape, mgpu_utils.dtype_to_ir_type(x_aval.dtype))
if shape:
zero_index = arith_dialect.constant(ir.IndexType.get(), 0)
indices = [zero_index for _ in range(len(shape))]
else:
indices = []
return vector_dialect.load(ty, x_smem, indices)
@register_lowering_rule(sp.swap_p, mgpu.ThreadSemantics.Lane)
def _swap_lowering_rule(
ctx: LoweringRuleContext, x_smem, value, *leaves, tree
):
@@ -1119,19 +1169,53 @@ def _swap_lowering_rule(
raise NotImplementedError(f"Unsupported transforms: {transforms}")
@register_lowering_rule(pjit.pjit_p)
def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
@register_lowering_rule(sp.swap_p, mgpu.ThreadSemantics.Warpgroup)
def _swap_lowering_rule_wg(
ctx: LoweringRuleContext, x_smem, value, *leaves, tree
):
if not ir.VectorType.isinstance(value.type):
raise TypeError(f"Can only store vectors (got {value}).")
if not ir.MemRefType.isinstance(x_smem.type):
raise TypeError(f"Can only store to references (got {x_smem}).")
x_aval = ctx.avals_in[0]
transforms = jax.tree.unflatten(tree, leaves)
x_smem, transforms = _handle_reshaping(x_smem, transforms)
x_smem, transforms = _handle_indexing(x_smem, transforms)
if transforms:
raise NotImplementedError(
"Transforms are not yet implemented for warpgroup semantics"
)
shape = ctx.avals_out[0].shape
ty = ir.VectorType.get(shape, mgpu_utils.dtype_to_ir_type(x_aval.dtype))
if shape:
zero_index = arith_dialect.constant(ir.IndexType.get(), 0)
indices = [zero_index for _ in range(len(shape))]
else:
indices = []
old_value = vector_dialect.load(ty, x_smem, indices)
vector_dialect.store(value, x_smem, indices)
return old_value
@register_lowering_rule(pjit.pjit_p, mgpu.ThreadSemantics.Lane)
@register_lowering_rule(pjit.pjit_p, mgpu.ThreadSemantics.Warpgroup)
def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **kwargs):
if jaxpr.consts:
raise NotImplementedError
return lower_jaxpr_to_mosaic_gpu(
ctx.module_ctx, ctx.launch_ctx, jaxpr.jaxpr, args
ctx.module_ctx, ctx.launch_ctx, jaxpr.jaxpr, args,
thread_semantics=ctx.thread_semantics
)
@register_lowering_rule(pjit.mesh_cast_p)
@register_lowering_rule(pjit.mesh_cast_p, mgpu.ThreadSemantics.Lane)
def _mesh_cast_lowering_rule(ctx, x, dst_sharding):
return x
@register_lowering_rule(lax.slice_p)
@register_lowering_rule(lax.slice_p, mgpu.ThreadSemantics.Lane)
def _slice_lowering_rule(
ctx: LoweringRuleContext, x, limit_indices, start_indices, strides
):
@@ -1141,7 +1225,7 @@ def _slice_lowering_rule(
return x[tuple(slice(b, e) for b, e in zip(start_indices, limit_indices))]
@register_lowering_rule(lax.select_n_p)
@register_lowering_rule(lax.select_n_p, mgpu.ThreadSemantics.Lane)
def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases):
if len(cases) != 2:
raise NotImplementedError(
@@ -1157,7 +1241,7 @@ def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases):
return pred.select(*reversed(cases))
@register_lowering_rule(lax.broadcast_in_dim_p)
@register_lowering_rule(lax.broadcast_in_dim_p, mgpu.ThreadSemantics.Lane)
def _broadcast_in_dim_lowering_rule(
ctx: LoweringRuleContext,
x: mgpu.FragmentedArray,
@@ -1181,7 +1265,7 @@ def _broadcast_in_dim_lowering_rule(
return x.broadcast(shape)
@register_lowering_rule(lax.convert_element_type_p)
@register_lowering_rule(lax.convert_element_type_p, mgpu.ThreadSemantics.Lane)
def _convert_element_type_lowering_rule(
ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding
):
@@ -1192,7 +1276,7 @@ def _convert_element_type_lowering_rule(
)
mosaic_lowering_rules.update({
mosaic_lowering_rules[mgpu.ThreadSemantics.Lane].update({
lax.neg_p: lambda ctx, x: -x,
lax.not_p: lambda ctx, x: ~x,
})
@@ -1203,7 +1287,7 @@ def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl):
return impl(x, y)
mosaic_lowering_rules.update({
mosaic_lowering_rules[mgpu.ThreadSemantics.Lane].update({
lax.add_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x + y),
lax.sub_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x - y),
lax.mul_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x * y),
@@ -1222,7 +1306,35 @@ mosaic_lowering_rules.update({
})
@register_lowering_rule(lax.div_p)
# TODO(bchetioui): explore how this can be generalized for more binary ops.
def _add_lowering_rule_wg(ctx: LoweringRuleContext, x, y):
x_aval, y_aval = ctx.avals_in
[out_aval] = ctx.avals_out
# TODO(bchetioui): support implicit broadcast.
if x_aval.shape != out_aval.shape or y_aval.shape != out_aval.shape:
raise NotImplementedError(
"Implicit broadcast not implemented with warpgroup semantics")
if np.issubdtype(ctx.avals_in[0].dtype, np.floating):
add_op = arith_dialect.addf
else:
raise NotImplementedError(
"Lowering of non-float addition is not implemented"
)
x = _ensure_vector(x, x_aval.dtype)
y = _ensure_vector(y, y_aval.dtype)
return add_op(x, y)
mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup].update({
lax.add_p: _add_lowering_rule_wg,
# TODO(bchetioui): add support for the remaining binary ops.
})
@register_lowering_rule(lax.div_p, mgpu.ThreadSemantics.Lane)
def _div_lowering_rule(ctx: LoweringRuleContext, x, y):
x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
if ir.FloatType.isinstance(x.mlir_dtype):
@@ -1230,7 +1342,7 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y):
return x // y
@register_lowering_rule(lax.integer_pow_p)
@register_lowering_rule(lax.integer_pow_p, mgpu.ThreadSemantics.Lane)
def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y):
[x_aval] = ctx.avals_in
x = _ensure_fa(x, x_aval.dtype)
@@ -1238,44 +1350,44 @@ def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y):
return x * x
return NotImplementedError
@register_lowering_rule(lax.square_p)
@register_lowering_rule(lax.square_p, mgpu.ThreadSemantics.Lane)
def _square_lowering_rule(ctx: LoweringRuleContext, x):
[x_aval] = ctx.avals_in
x = _ensure_fa(x, x_aval.dtype)
return x * x
@register_lowering_rule(lax.rsqrt_p)
@register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Lane)
def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x):
[x_aval] = ctx.avals_in
return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math)
@register_lowering_rule(lax.tanh_p)
@register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Lane)
def _tanh_lowering_rule(ctx: LoweringRuleContext, x):
[x_aval] = ctx.avals_in
return _ensure_fa(x, x_aval.dtype).tanh(approx=ctx.module_ctx.approx_math)
@register_lowering_rule(lax.logistic_p)
@register_lowering_rule(lax.logistic_p, mgpu.ThreadSemantics.Lane)
def _logistic_lowering_rule(ctx: LoweringRuleContext, x):
[x_aval] = ctx.avals_in
a = _ensure_fa(x, x_aval.dtype)
return 1. / (1. + (-a).exp(approx=ctx.module_ctx.approx_math))
@register_lowering_rule(lax.exp_p)
@register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Lane)
def _exp_lowering_rule(ctx: LoweringRuleContext, x):
[x_aval] = ctx.avals_in
a = _ensure_fa(x, x_aval.dtype)
return a.exp(approx=ctx.module_ctx.approx_math)
@register_lowering_rule(lax.exp2_p)
@register_lowering_rule(lax.exp2_p, mgpu.ThreadSemantics.Lane)
def _exp2_lowering_rule(ctx: LoweringRuleContext, x):
[x_aval] = ctx.avals_in
a = _ensure_fa(x, x_aval.dtype)
return a.exp2(approx=ctx.module_ctx.approx_math)
@register_lowering_rule(lax.reduce_sum_p)
@register_lowering_rule(lax.reduce_sum_p, mgpu.ThreadSemantics.Lane)
def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
[x_aval] = ctx.avals_in
match x.layout:
@@ -1295,7 +1407,7 @@ def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
raise NotImplementedError(f"Unsupported layout {x.layout}")
@register_lowering_rule(lax.reduce_max_p)
@register_lowering_rule(lax.reduce_max_p, mgpu.ThreadSemantics.Lane)
def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
[x_aval] = ctx.avals_in
match x.layout:
@@ -1309,7 +1421,7 @@ def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
raise NotImplementedError(f"Unsupported layout {x.layout}")
@register_lowering_rule(lax.axis_index_p)
@register_lowering_rule(lax.axis_index_p, mgpu.ThreadSemantics.Lane)
def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
i32 = ir.IntegerType.get_signless(32)
grid_names = ctx.module_ctx.grid_names
@@ -1354,7 +1466,7 @@ def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
)
@register_lowering_rule(primitives.debug_print_p)
@register_lowering_rule(primitives.debug_print_p, mgpu.ThreadSemantics.Lane)
def _debug_print_lowering_rule(
ctx: LoweringRuleContext,
*args,
@@ -1387,7 +1499,7 @@ def _debug_print_lowering_rule(
return ()
@register_lowering_rule(primitives.run_scoped_p)
@register_lowering_rule(primitives.run_scoped_p, mgpu.ThreadSemantics.Lane)
def _run_scoped_lowering_rule(
ctx: LoweringRuleContext, *consts, jaxpr: jax_core.Jaxpr
):
@@ -1443,7 +1555,7 @@ def _run_scoped_lowering_rule(
return outs
@register_lowering_rule(discharge.run_state_p)
@register_lowering_rule(discharge.run_state_p, mgpu.ThreadSemantics.Lane)
def _run_state_lowering_rule(
ctx: LoweringRuleContext,
*args,
@@ -1530,7 +1642,7 @@ def _lower_jaxpr_to_for_loop(
return loop.results
@register_lowering_rule(lax.scan_p)
@register_lowering_rule(lax.scan_p, mgpu.ThreadSemantics.Lane)
def _scan_lowering_rule(
ctx: LoweringRuleContext,
*args,
@@ -1618,7 +1730,7 @@ def _lower_while_via_fori(
return ub, ub, *for_out
@register_lowering_rule(lax.while_p)
@register_lowering_rule(lax.while_p, mgpu.ThreadSemantics.Lane)
def _while_lowering_rule(
ctx: LoweringRuleContext,
*args,
@@ -1696,7 +1808,7 @@ def _while_lowering_rule(
return carry_treedef.unflatten(list(while_op.results))
@register_lowering_rule(lax.cond_p)
@register_lowering_rule(lax.cond_p, mgpu.ThreadSemantics.Lane)
def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
index_aval, *_arg_avals = ctx.avals_in
@@ -1753,7 +1865,7 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
return treedef.unflatten(list(switch_op.results))
@register_lowering_rule(lax.bitcast_convert_type_p)
@register_lowering_rule(lax.bitcast_convert_type_p, mgpu.ThreadSemantics.Lane)
def _bitcast_convert_type_lowering_rule(
ctx: LoweringRuleContext, operand, *, new_dtype
):
@@ -1778,7 +1890,7 @@ def _bitcast_convert_type_lowering_rule(
)
@register_lowering_rule(lax.optimization_barrier_p)
@register_lowering_rule(lax.optimization_barrier_p, mgpu.ThreadSemantics.Lane)
def _optimization_barrier_lowering(ctx: LoweringRuleContext, *args):
args = (_ensure_fa(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in))
return mgpu.optimization_barrier(*args)
@@ -1817,6 +1929,17 @@ def _ensure_fa(x: object, dtype: jnp.dtype) -> mgpu.FragmentedArray:
)
def _ensure_vector(x: object, dtype: jnp.dtype) -> ir.Value:
if isinstance(x, ir.Value) and ir.VectorType.isinstance(x.type):
assert ir.VectorType(x.type).element_type == mgpu_utils.dtype_to_ir_type(dtype)
return x
if isinstance(x, (np.number, np.ndarray, int, float)):
return _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype))
raise NotImplementedError(f"Unsupported conversion to vector for: {x!r}")
def _ensure_ir_value(x: object, dtype: jnp.dtype) -> ir.Value:
if isinstance(x, ir.Value):
assert x.type == mgpu_utils.dtype_to_ir_type(dtype)


@@ -55,6 +55,12 @@ def pallas_call_lowering(
print(f"The grid mapping for pallas_call {name_and_src_info}:")
print(grid_mapping)
thread_semantics = compiler_params.get("mosaic_gpu", {}).get(
"thread_semantics", mosaic_core.ThreadSemantics.Lane
)
if thread_semantics == mosaic_core.ThreadSemantics.Warpgroup:
mosaic_core.dialect.register_dialect(ctx.module_context.context) # pytype: disable=attribute-error
lowering_result = lowering.lower_jaxpr_to_module(
grid_mapping,
jaxpr,


@@ -73,7 +73,7 @@ def _copy_smem_to_gmem_abstract_eval(src, dst, *args, **params):
return (), {state.ReadEffect(0), state.WriteEffect(1)}
@lowering.register_lowering_rule(copy_smem_to_gmem_p)
@lowering.register_lowering_rule(copy_smem_to_gmem_p, mgpu.ThreadSemantics.Lane)
def _copy_smem_to_gmem_lowering(
ctx: lowering.LoweringRuleContext,
src,
@@ -187,7 +187,7 @@ def _copy_gmem_to_smem_abstract_eval(src, dst, barrier, *args, **params):
return (), {state.ReadEffect(0), state.WriteEffect(1)}
@lowering.register_lowering_rule(copy_gmem_to_smem_p)
@lowering.register_lowering_rule(copy_gmem_to_smem_p, mgpu.ThreadSemantics.Lane)
def _copy_gmem_to_smem_lowering(
ctx: lowering.LoweringRuleContext,
src,
@@ -307,7 +307,7 @@ def _barrier_arrive_abstract_eval(barrier, *args, **params):
return (), {gpu_core._memory_effect}
@lowering.register_lowering_rule(barrier_arrive_p)
@lowering.register_lowering_rule(barrier_arrive_p, mgpu.ThreadSemantics.Lane)
def _barrier_arrive_lowering(
ctx: lowering.LoweringRuleContext,
barrier,
@@ -345,7 +345,7 @@ def _barrier_wait_abstract_eval(barrier, *args, **params):
return (), {gpu_core._memory_effect}
@lowering.register_lowering_rule(barrier_wait_p)
@lowering.register_lowering_rule(barrier_wait_p, mgpu.ThreadSemantics.Lane)
def _barrier_wait_lowering(
ctx: lowering.LoweringRuleContext,
barrier,
@@ -382,7 +382,7 @@ def _wait_smem_to_gmem_abstract_eval(n, *, wait_read_only):
return (), {gpu_core._memory_effect}
@lowering.register_lowering_rule(wait_smem_to_gmem_p)
@lowering.register_lowering_rule(wait_smem_to_gmem_p, mgpu.ThreadSemantics.Lane)
def _wait_smem_to_gmem_lowering(
ctx: lowering.LoweringRuleContext, n, *, wait_read_only
):
@@ -483,7 +483,7 @@ def _wgmma_ref_discharge(in_avals, out_avals, *args, **kwargs):
wgmma_p = jax_core.Primitive("wgmma")
@lowering.register_lowering_rule(wgmma_p)
@lowering.register_lowering_rule(wgmma_p, mgpu.ThreadSemantics.Lane)
def _wgmma_lowering(
ctx: lowering.LoweringRuleContext,
acc,
@@ -588,7 +588,7 @@ def wgmma_wait_effectful_abstract_eval(_):
return [], {gpu_core._wgmma_pipeline_effect}
@lowering.register_lowering_rule(wgmma_wait_p)
@lowering.register_lowering_rule(wgmma_wait_p, mgpu.ThreadSemantics.Lane)
def _wgmma_wait_lowering(ctx: lowering.LoweringRuleContext, allow_groups):
del ctx
nvvm_dialect.wgmma_wait_group_sync_aligned(allow_groups)
@@ -619,7 +619,7 @@ def _wgmma_accumulator_deref_discharge(in_avals, out_avals, acc):
return (None,), wgmma_accumulator_deref_p.bind(acc)
@lowering.register_lowering_rule(wgmma_accumulator_deref_p)
@lowering.register_lowering_rule(wgmma_accumulator_deref_p, mgpu.ThreadSemantics.Lane)
def _wgmma_accumulator_deref_lowering(ctx: lowering.LoweringRuleContext, acc):
del ctx
nvvm_dialect.wgmma_wait_group_sync_aligned(0)
@@ -665,7 +665,7 @@ def _layout_cast_abstract_eval(x, new_layout):
return x
@lowering.register_lowering_rule(layout_cast_p)
@lowering.register_lowering_rule(layout_cast_p, mgpu.ThreadSemantics.Lane)
def _layout_cast_lowering(ctx: lowering.LoweringRuleContext, x, *, new_layout):
del ctx # Unused.
return x.to_layout(_get_mgpu_layout(new_layout))
@@ -686,7 +686,7 @@ def _set_max_registers_abstract_eval(n, *, action):
return (), {gpu_core._memory_effect}
@lowering.register_lowering_rule(set_max_registers_p)
@lowering.register_lowering_rule(set_max_registers_p, mgpu.ThreadSemantics.Lane)
def _set_max_registers_lowering(
ctx: lowering.LoweringRuleContext, n, *, action
):
@@ -714,7 +714,7 @@ def _commit_smem_abstract_eval():
return (), {gpu_core._memory_effect}
@lowering.register_lowering_rule(commit_smem_p)
@lowering.register_lowering_rule(commit_smem_p, mgpu.ThreadSemantics.Lane)
def _commit_smem_lowering(ctx: lowering.LoweringRuleContext):
mgpu.commit_shared()
return ()
@@ -733,7 +733,7 @@ def _broadcasted_iota_abstract_eval(dtype, shape, dimension, layout):
return jax_core.ShapedArray(shape, dtype)
@lowering.register_lowering_rule(broadcasted_iota_p)
@lowering.register_lowering_rule(broadcasted_iota_p, mgpu.ThreadSemantics.Lane)
def _broadcasted_iota_lowering(
ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout
):


@@ -1212,9 +1212,11 @@ _PALLAS_USE_MOSAIC_GPU = config.bool_flag(
default=config.bool_env("JAX_PALLAS_USE_MOSAIC_GPU", False),
help=(
"If True, lower Pallas kernels to the experimental Mosaic GPU"
" dialect, instead of Trition IR."
" dialect, instead of Triton IR."
),
)
_PALLAS_VERBOSE_ERRORS = config.bool_flag(
"jax_pallas_verbose_errors",
default=config.bool_env("JAX_PALLAS_VERBOSE_ERRORS", True),


@@ -27,6 +27,7 @@ from jax import lax
from jax._src import test_util as jtu
from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline
from jax.experimental import pallas as pl
from jax.experimental.mosaic.gpu import core as mosaic_gpu_core
from jax.experimental.pallas import mosaic_gpu as plgpu
import jax.numpy as jnp
import numpy as np
@@ -2073,5 +2074,31 @@ class ExamplesTest(PallasTest):
# TODO(apaszke): Clusters and multicast
class PallasCallWarpgroupSemanticsTest(PallasTest):
def setUp(self):
self.compiler_params = plgpu.GPUCompilerParams(
thread_semantics=mosaic_gpu_core.ThreadSemantics.Warpgroup
)
super().setUp()
@parameterized.named_parameters(
("add_float", lambda x, y: x + y, np.float32),
)
def test_binary_op_wg_semantics(self, bop, dtype):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], dtype=dtype),
compiler_params=self.compiler_params
)
def kernel(x_ref, y_ref, o_ref):
o_ref[...] = bop(x_ref[...], y_ref[...])
x = jnp.arange(256).astype(dtype)
y = x + 1
np.testing.assert_array_equal(kernel(x, y), bop(x, y))
if __name__ == "__main__":
absltest.main()