From 5ad89006c360ad62c8ec5f960c10b4ac0a307ed1 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 12 Feb 2025 01:47:12 -0800 Subject: [PATCH] [Pallas/Mosaic GPU] Add initial support for warpgroup semantics in lowering. This will allow us to lower Pallas kernels using the Mosaic GPU dialect, and in turn to perform layout inference and optimization automatically. The change contains lowering rules for `get` and `swap` (which are necessary to get a basic example to run), as well as for `add`. The new lowering path can be used by specifying the `Warpgroup` thread semantics as part of `pallas_call`'s compiler params. PiperOrigin-RevId: 725958027 --- jax/_src/pallas/mosaic_gpu/core.py | 1 + jax/_src/pallas/mosaic_gpu/lowering.py | 201 ++++++++++++++---- .../mosaic_gpu/pallas_call_registration.py | 6 + jax/_src/pallas/mosaic_gpu/primitives.py | 24 +-- jax/_src/pallas/pallas_call.py | 4 +- tests/pallas/mosaic_gpu_test.py | 27 +++ 6 files changed, 211 insertions(+), 52 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 42cf59ec6..c9a0ddd44 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -86,6 +86,7 @@ class GPUCompilerParams(pallas_core.CompilerParams): delay_release: int = 0 profile_space: int = 0 profile_dir: str = "" + thread_semantics: mgpu.core.ThreadSemantics = mgpu.core.ThreadSemantics.Lane def __post_init__(self): if bool(self.profile_space) ^ bool(self.profile_dir): diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index c2d4c8450..c39fa4c3b 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -39,6 +39,7 @@ from jax._src.lib.mlir.dialects import gpu as gpu_dialect from jax._src.lib.mlir.dialects import memref as memref_dialect from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect from jax._src.lib.mlir.dialects import scf as scf_dialect +from jax._src.lib.mlir.dialects import vector as vector_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas import pallas_call from jax._src.pallas import primitives @@ -272,6 +273,7 @@ class ModuleContext: class LoweringRuleContext: module_ctx: ModuleContext launch_ctx: mgpu.LaunchContext + thread_semantics: mgpu_core.ThreadSemantics predicate: ir.Value avals_in: Sequence[jax_core.ShapedArray] avals_out: Sequence[jax_core.ShapedArray] @@ -368,7 +370,7 @@ def lower_jaxpr_to_module( jaxpr: jax_core.Jaxpr, name_and_src_info: pallas_core.NameAndSrcInfo, compiler_params: dict[str, Any], - cost_estimate: pallas_core.CostEstimate | None, + cost_estimate: pallas_core.CostEstimate | None ) -> LoweringResult: del cost_estimate # Unused. @@ -390,6 +392,9 @@ def lower_jaxpr_to_module( approx_math = params.get("approx_math", False) max_concurrent_steps = params.get("max_concurrent_steps", 1) delay_release = params.get("delay_release", 0) + thread_semantics = params.get( + "thread_semantics", mgpu_core.ThreadSemantics.Lane + ) dimension_semantics = params.get("dimension_semantics") if dimension_semantics is None: dimension_semantics = ["parallel"] * len(grid_mapping.grid) @@ -724,6 +729,7 @@ def lower_jaxpr_to_module( launch_ctx, lowered_jaxpr, args, + thread_semantics=thread_semantics ) if not all(out_sequential_invariant): @@ -826,6 +832,13 @@ def lower_jaxpr_to_module( prof_spec=prof_spec, ) ) + + if thread_semantics == mgpu.ThreadSemantics.Warpgroup: + # Run Python lowering passes. The remaining passes will be run in C++ in + # jax/jaxlib/mosaic/gpu/custom_call.cc + mgpu.infer_layout(module) # pytype: disable=attribute-error + mgpu.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error + mgpu_core._initialize_scratch(launch_ctx, scratch_arr) return LoweringResult( @@ -833,12 +846,19 @@ def lower_jaxpr_to_module( ) -mosaic_lowering_rules = {} +mosaic_lowering_rules = { + # Lowering rules when using Mosaic GPU lane semantics. + mgpu.ThreadSemantics.Lane: {} , + # Lowering rules when using Mosaic GPU warpgroup semantics. + mgpu.ThreadSemantics.Warpgroup: {}, +} -def register_lowering_rule(primitive: jax_core.Primitive): +def register_lowering_rule( + primitive: jax_core.Primitive, thread_semantics: mgpu.ThreadSemantics +): def deco(fn): - mosaic_lowering_rules[primitive] = fn + mosaic_lowering_rules[thread_semantics][primitive] = fn return fn return deco @@ -863,6 +883,7 @@ def lower_jaxpr_to_mosaic_gpu( jaxpr: jax_core.Jaxpr, args: Sequence[ir.Value], consts=(), + thread_semantics: mgpu.ThreadSemantics = mgpu.ThreadSemantics.Lane, ) -> Sequence[ir.Value]: env = {} @@ -884,7 +905,7 @@ def lower_jaxpr_to_mosaic_gpu( ) loc = mlir._source_info_to_location(module_ctx, eqn.primitive, source_info) with source_info_util.user_context(eqn.source_info.traceback), loc: - if eqn.primitive not in mosaic_lowering_rules: + if eqn.primitive not in mosaic_lowering_rules[thread_semantics]: raise NotImplementedError( "Unimplemented primitive in Pallas Mosaic GPU lowering: " f"{eqn.primitive.name}. " @@ -899,10 +920,11 @@ def lower_jaxpr_to_mosaic_gpu( wrapper_stack = contextlib.ExitStack() wrapper_stack.enter_context(launch_ctx.named_region(name)) named_regions.append(wrapper_stack) - rule = mosaic_lowering_rules[eqn.primitive] + rule = mosaic_lowering_rules[thread_semantics][eqn.primitive] rule_ctx = LoweringRuleContext( module_ctx, launch_ctx, + thread_semantics, predicate=mgpu.single_thread_predicate(per_block=False), avals_in=[cast(jax_core.ShapedArray, v.aval) for v in eqn.invars], avals_out=[cast(jax_core.ShapedArray, v.aval) for v in eqn.outvars], @@ -928,7 +950,8 @@ def lower_jaxpr_to_mosaic_gpu( return map(read_env, jaxpr.outvars) -@register_lowering_rule(primitives.program_id_p) +@register_lowering_rule(primitives.program_id_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(primitives.program_id_p, mgpu.ThreadSemantics.Warpgroup) def _program_id_lowering_rule(ctx: LoweringRuleContext, axis): if ctx.module_ctx.program_ids is None: raise NotImplementedError("pl.program_id() is not supported in this context") @@ -972,7 +995,8 @@ def _program_id(parallel_axis: int, squashed_dims: tuple[int, ...]) -> ir.Value: ) -@register_lowering_rule(primitives.num_programs_p) +@register_lowering_rule(primitives.num_programs_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(primitives.num_programs_p, mgpu.ThreadSemantics.Warpgroup) def _num_programs_lowering_rule(ctx: LoweringRuleContext, axis): del ctx # Unused. return arith_dialect.index_cast( @@ -1056,7 +1080,7 @@ def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ... return tuple(indices) -@register_lowering_rule(sp.get_p) +@register_lowering_rule(sp.get_p, mgpu.ThreadSemantics.Lane) def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *leaves, tree): if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem): raise TypeError(f"Can only load from references (got {x_smem}).") @@ -1088,7 +1112,33 @@ def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *leaves, tree): raise NotImplementedError(f"Unsupported transforms: {transforms}") -@register_lowering_rule(sp.swap_p) +@register_lowering_rule(sp.get_p, mgpu.ThreadSemantics.Warpgroup) +def _get_lowering_rule_wg(ctx: LoweringRuleContext, x_smem, *leaves, tree): + if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem): + raise TypeError(f"Can only load from references (got {x_smem}).") + + x_aval = ctx.avals_in[0] + + transforms = jax.tree.unflatten(tree, leaves) + x_smem, transforms = _handle_reshaping(x_smem, transforms) + x_smem, transforms = _handle_indexing(x_smem, transforms) + + if transforms: + raise NotImplementedError( + "Transforms are not yet implemented for warpgroup semantics" + ) + + shape = ctx.avals_out[0].shape + ty = ir.VectorType.get(shape, mgpu_utils.dtype_to_ir_type(x_aval.dtype)) + if shape: + zero_index = arith_dialect.constant(ir.IndexType.get(), 0) + indices = [zero_index for _ in range(len(shape))] + else: + indices = [] + return vector_dialect.load(ty, x_smem, indices) + + +@register_lowering_rule(sp.swap_p, mgpu.ThreadSemantics.Lane) def _swap_lowering_rule( ctx: LoweringRuleContext, x_smem, value, *leaves, tree ): @@ -1119,19 +1169,53 @@ def _swap_lowering_rule( raise NotImplementedError(f"Unsupported transforms: {transforms}") -@register_lowering_rule(pjit.pjit_p) -def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_): +@register_lowering_rule(sp.swap_p, mgpu.ThreadSemantics.Warpgroup) +def _swap_lowering_rule_wg( + ctx: LoweringRuleContext, x_smem, value, *leaves, tree +): + if not ir.VectorType.isinstance(value.type): + raise TypeError(f"Can only store vectors (got {value}).") + if not ir.MemRefType.isinstance(x_smem.type): + raise TypeError(f"Can only store to references (got {x_smem}).") + + x_aval = ctx.avals_in[0] + + transforms = jax.tree.unflatten(tree, leaves) + x_smem, transforms = _handle_reshaping(x_smem, transforms) + x_smem, transforms = _handle_indexing(x_smem, transforms) + + if transforms: + raise NotImplementedError( + "Transforms are not yet implemented for warpgroup semantics" + ) + + shape = ctx.avals_out[0].shape + ty = ir.VectorType.get(shape, mgpu_utils.dtype_to_ir_type(x_aval.dtype)) + if shape: + zero_index = arith_dialect.constant(ir.IndexType.get(), 0) + indices = [zero_index for _ in range(len(shape))] + else: + indices = [] + old_value = vector_dialect.load(ty, x_smem, indices) + vector_dialect.store(value, x_smem, indices) + return old_value + + +@register_lowering_rule(pjit.pjit_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(pjit.pjit_p, mgpu.ThreadSemantics.Warpgroup) +def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **kwargs): if jaxpr.consts: raise NotImplementedError return lower_jaxpr_to_mosaic_gpu( - ctx.module_ctx, ctx.launch_ctx, jaxpr.jaxpr, args + ctx.module_ctx, ctx.launch_ctx, jaxpr.jaxpr, args, + thread_semantics=ctx.thread_semantics ) -@register_lowering_rule(pjit.mesh_cast_p) +@register_lowering_rule(pjit.mesh_cast_p, mgpu.ThreadSemantics.Lane) def _mesh_cast_lowering_rule(ctx, x, dst_sharding): return x -@register_lowering_rule(lax.slice_p) +@register_lowering_rule(lax.slice_p, mgpu.ThreadSemantics.Lane) def _slice_lowering_rule( ctx: LoweringRuleContext, x, limit_indices, start_indices, strides ): @@ -1141,7 +1225,7 @@ def _slice_lowering_rule( return x[tuple(slice(b, e) for b, e in zip(start_indices, limit_indices))] -@register_lowering_rule(lax.select_n_p) +@register_lowering_rule(lax.select_n_p, mgpu.ThreadSemantics.Lane) def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases): if len(cases) != 2: raise NotImplementedError( @@ -1157,7 +1241,7 @@ def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases): return pred.select(*reversed(cases)) -@register_lowering_rule(lax.broadcast_in_dim_p) +@register_lowering_rule(lax.broadcast_in_dim_p, mgpu.ThreadSemantics.Lane) def _broadcast_in_dim_lowering_rule( ctx: LoweringRuleContext, x: mgpu.FragmentedArray, @@ -1181,7 +1265,7 @@ def _broadcast_in_dim_lowering_rule( return x.broadcast(shape) -@register_lowering_rule(lax.convert_element_type_p) +@register_lowering_rule(lax.convert_element_type_p, mgpu.ThreadSemantics.Lane) def _convert_element_type_lowering_rule( ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding ): @@ -1192,7 +1276,7 @@ def _convert_element_type_lowering_rule( ) -mosaic_lowering_rules.update({ +mosaic_lowering_rules[mgpu.ThreadSemantics.Lane].update({ lax.neg_p: lambda ctx, x: -x, lax.not_p: lambda ctx, x: ~x, }) @@ -1203,7 +1287,7 @@ def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): return impl(x, y) -mosaic_lowering_rules.update({ +mosaic_lowering_rules[mgpu.ThreadSemantics.Lane].update({ lax.add_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x + y), lax.sub_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x - y), lax.mul_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x * y), @@ -1222,7 +1306,35 @@ mosaic_lowering_rules.update({ }) -@register_lowering_rule(lax.div_p) +# TODO(bchetioui): explore how this can be generalized for more binary ops. +def _add_lowering_rule_wg(ctx: LoweringRuleContext, x, y): + x_aval, y_aval = ctx.avals_in + [out_aval] = ctx.avals_out + # TODO(bchetioui): support implicit broadcast. + if x_aval.shape != out_aval.shape or y_aval.shape != out_aval.shape: + raise NotImplementedError( + "Implicit broadcast not implemented with warpgroup semantics") + + if np.issubdtype(ctx.avals_in[0].dtype, np.floating): + add_op = arith_dialect.addf + else: + raise NotImplementedError( + "Lowering of non-float addition is not implemented" + ) + + x = _ensure_vector(x, x_aval.dtype) + y = _ensure_vector(y, y_aval.dtype) + + return add_op(x, y) + + +mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup].update({ + lax.add_p: _add_lowering_rule_wg, + # TODO(bchetioui): add support for the remaining binary ops. +}) + + +@register_lowering_rule(lax.div_p, mgpu.ThreadSemantics.Lane) def _div_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) if ir.FloatType.isinstance(x.mlir_dtype): @@ -1230,7 +1342,7 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y): return x // y -@register_lowering_rule(lax.integer_pow_p) +@register_lowering_rule(lax.integer_pow_p, mgpu.ThreadSemantics.Lane) def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y): [x_aval] = ctx.avals_in x = _ensure_fa(x, x_aval.dtype) @@ -1238,44 +1350,44 @@ def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y): return x * x return NotImplementedError -@register_lowering_rule(lax.square_p) +@register_lowering_rule(lax.square_p, mgpu.ThreadSemantics.Lane) def _square_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in x = _ensure_fa(x, x_aval.dtype) return x * x -@register_lowering_rule(lax.rsqrt_p) +@register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Lane) def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math) -@register_lowering_rule(lax.tanh_p) +@register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Lane) def _tanh_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in return _ensure_fa(x, x_aval.dtype).tanh(approx=ctx.module_ctx.approx_math) -@register_lowering_rule(lax.logistic_p) +@register_lowering_rule(lax.logistic_p, mgpu.ThreadSemantics.Lane) def _logistic_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in a = _ensure_fa(x, x_aval.dtype) return 1. / (1. + (-a).exp(approx=ctx.module_ctx.approx_math)) -@register_lowering_rule(lax.exp_p) +@register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Lane) def _exp_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in a = _ensure_fa(x, x_aval.dtype) return a.exp(approx=ctx.module_ctx.approx_math) -@register_lowering_rule(lax.exp2_p) +@register_lowering_rule(lax.exp2_p, mgpu.ThreadSemantics.Lane) def _exp2_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in a = _ensure_fa(x, x_aval.dtype) return a.exp2(approx=ctx.module_ctx.approx_math) -@register_lowering_rule(lax.reduce_sum_p) +@register_lowering_rule(lax.reduce_sum_p, mgpu.ThreadSemantics.Lane) def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes): [x_aval] = ctx.avals_in match x.layout: @@ -1295,7 +1407,7 @@ def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes): raise NotImplementedError(f"Unsupported layout {x.layout}") -@register_lowering_rule(lax.reduce_max_p) +@register_lowering_rule(lax.reduce_max_p, mgpu.ThreadSemantics.Lane) def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes): [x_aval] = ctx.avals_in match x.layout: @@ -1309,7 +1421,7 @@ def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes): raise NotImplementedError(f"Unsupported layout {x.layout}") -@register_lowering_rule(lax.axis_index_p) +@register_lowering_rule(lax.axis_index_p, mgpu.ThreadSemantics.Lane) def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): i32 = ir.IntegerType.get_signless(32) grid_names = ctx.module_ctx.grid_names @@ -1354,7 +1466,7 @@ def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): ) -@register_lowering_rule(primitives.debug_print_p) +@register_lowering_rule(primitives.debug_print_p, mgpu.ThreadSemantics.Lane) def _debug_print_lowering_rule( ctx: LoweringRuleContext, *args, @@ -1387,7 +1499,7 @@ def _debug_print_lowering_rule( return () -@register_lowering_rule(primitives.run_scoped_p) +@register_lowering_rule(primitives.run_scoped_p, mgpu.ThreadSemantics.Lane) def _run_scoped_lowering_rule( ctx: LoweringRuleContext, *consts, jaxpr: jax_core.Jaxpr ): @@ -1443,7 +1555,7 @@ def _run_scoped_lowering_rule( return outs -@register_lowering_rule(discharge.run_state_p) +@register_lowering_rule(discharge.run_state_p, mgpu.ThreadSemantics.Lane) def _run_state_lowering_rule( ctx: LoweringRuleContext, *args, @@ -1530,7 +1642,7 @@ def _lower_jaxpr_to_for_loop( return loop.results -@register_lowering_rule(lax.scan_p) +@register_lowering_rule(lax.scan_p, mgpu.ThreadSemantics.Lane) def _scan_lowering_rule( ctx: LoweringRuleContext, *args, @@ -1618,7 +1730,7 @@ def _lower_while_via_fori( return ub, ub, *for_out -@register_lowering_rule(lax.while_p) +@register_lowering_rule(lax.while_p, mgpu.ThreadSemantics.Lane) def _while_lowering_rule( ctx: LoweringRuleContext, *args, @@ -1696,7 +1808,7 @@ def _while_lowering_rule( return carry_treedef.unflatten(list(while_op.results)) -@register_lowering_rule(lax.cond_p) +@register_lowering_rule(lax.cond_p, mgpu.ThreadSemantics.Lane) def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): index_aval, *_arg_avals = ctx.avals_in @@ -1753,7 +1865,7 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): return treedef.unflatten(list(switch_op.results)) -@register_lowering_rule(lax.bitcast_convert_type_p) +@register_lowering_rule(lax.bitcast_convert_type_p, mgpu.ThreadSemantics.Lane) def _bitcast_convert_type_lowering_rule( ctx: LoweringRuleContext, operand, *, new_dtype ): @@ -1778,7 +1890,7 @@ def _bitcast_convert_type_lowering_rule( ) -@register_lowering_rule(lax.optimization_barrier_p) +@register_lowering_rule(lax.optimization_barrier_p, mgpu.ThreadSemantics.Lane) def _optimization_barrier_lowering(ctx: LoweringRuleContext, *args): args = (_ensure_fa(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in)) return mgpu.optimization_barrier(*args) @@ -1817,6 +1929,17 @@ def _ensure_fa(x: object, dtype: jnp.dtype) -> mgpu.FragmentedArray: ) +def _ensure_vector(x: object, dtype: jnp.dtype) -> ir.Value: + if isinstance(x, ir.Value) and ir.VectorType.isinstance(x.type): + assert ir.VectorType(x.type).element_type == mgpu_utils.dtype_to_ir_type(dtype) + return x + + if isinstance(x, (np.number, np.ndarray, int, float)): + return _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype)) + + raise NotImplementedError(f"Unsupported conversion to vector for: {x!r}") + + def _ensure_ir_value(x: object, dtype: jnp.dtype) -> ir.Value: if isinstance(x, ir.Value): assert x.type == mgpu_utils.dtype_to_ir_type(dtype) diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 19c8b0ad5..17ebe2034 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -55,6 +55,12 @@ def pallas_call_lowering( print(f"The grid mapping for pallas_call {name_and_src_info}:") print(grid_mapping) + thread_semantics = compiler_params.get("mosaic_gpu", {}).get( + "thread_semantics", mosaic_core.ThreadSemantics.Lane + ) + if thread_semantics == mosaic_core.ThreadSemantics.Warpgroup: + mosaic_core.dialect.register_dialect(ctx.module_context.context) # pytype: disable=attribute-error + lowering_result = lowering.lower_jaxpr_to_module( grid_mapping, jaxpr, diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index afb350a38..4fd63fb9f 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -73,7 +73,7 @@ def _copy_smem_to_gmem_abstract_eval(src, dst, *args, **params): return (), {state.ReadEffect(0), state.WriteEffect(1)} -@lowering.register_lowering_rule(copy_smem_to_gmem_p) +@lowering.register_lowering_rule(copy_smem_to_gmem_p, mgpu.ThreadSemantics.Lane) def _copy_smem_to_gmem_lowering( ctx: lowering.LoweringRuleContext, src, @@ -187,7 +187,7 @@ def _copy_gmem_to_smem_abstract_eval(src, dst, barrier, *args, **params): return (), {state.ReadEffect(0), state.WriteEffect(1)} -@lowering.register_lowering_rule(copy_gmem_to_smem_p) +@lowering.register_lowering_rule(copy_gmem_to_smem_p, mgpu.ThreadSemantics.Lane) def _copy_gmem_to_smem_lowering( ctx: lowering.LoweringRuleContext, src, @@ -307,7 +307,7 @@ def _barrier_arrive_abstract_eval(barrier, *args, **params): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(barrier_arrive_p) +@lowering.register_lowering_rule(barrier_arrive_p, mgpu.ThreadSemantics.Lane) def _barrier_arrive_lowering( ctx: lowering.LoweringRuleContext, barrier, @@ -345,7 +345,7 @@ def _barrier_wait_abstract_eval(barrier, *args, **params): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(barrier_wait_p) +@lowering.register_lowering_rule(barrier_wait_p, mgpu.ThreadSemantics.Lane) def _barrier_wait_lowering( ctx: lowering.LoweringRuleContext, barrier, @@ -382,7 +382,7 @@ def _wait_smem_to_gmem_abstract_eval(n, *, wait_read_only): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(wait_smem_to_gmem_p) +@lowering.register_lowering_rule(wait_smem_to_gmem_p, mgpu.ThreadSemantics.Lane) def _wait_smem_to_gmem_lowering( ctx: lowering.LoweringRuleContext, n, *, wait_read_only ): @@ -483,7 +483,7 @@ def _wgmma_ref_discharge(in_avals, out_avals, *args, **kwargs): wgmma_p = jax_core.Primitive("wgmma") -@lowering.register_lowering_rule(wgmma_p) +@lowering.register_lowering_rule(wgmma_p, mgpu.ThreadSemantics.Lane) def _wgmma_lowering( ctx: lowering.LoweringRuleContext, acc, @@ -588,7 +588,7 @@ def wgmma_wait_effectful_abstract_eval(_): return [], {gpu_core._wgmma_pipeline_effect} -@lowering.register_lowering_rule(wgmma_wait_p) +@lowering.register_lowering_rule(wgmma_wait_p, mgpu.ThreadSemantics.Lane) def _wgmma_wait_lowering(ctx: lowering.LoweringRuleContext, allow_groups): del ctx nvvm_dialect.wgmma_wait_group_sync_aligned(allow_groups) @@ -619,7 +619,7 @@ def _wgmma_accumulator_deref_discharge(in_avals, out_avals, acc): return (None,), wgmma_accumulator_deref_p.bind(acc) -@lowering.register_lowering_rule(wgmma_accumulator_deref_p) +@lowering.register_lowering_rule(wgmma_accumulator_deref_p, mgpu.ThreadSemantics.Lane) def _wgmma_accumulator_deref_lowering(ctx: lowering.LoweringRuleContext, acc): del ctx nvvm_dialect.wgmma_wait_group_sync_aligned(0) @@ -665,7 +665,7 @@ def _layout_cast_abstract_eval(x, new_layout): return x -@lowering.register_lowering_rule(layout_cast_p) +@lowering.register_lowering_rule(layout_cast_p, mgpu.ThreadSemantics.Lane) def _layout_cast_lowering(ctx: lowering.LoweringRuleContext, x, *, new_layout): del ctx # Unused. return x.to_layout(_get_mgpu_layout(new_layout)) @@ -686,7 +686,7 @@ def _set_max_registers_abstract_eval(n, *, action): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(set_max_registers_p) +@lowering.register_lowering_rule(set_max_registers_p, mgpu.ThreadSemantics.Lane) def _set_max_registers_lowering( ctx: lowering.LoweringRuleContext, n, *, action ): @@ -714,7 +714,7 @@ def _commit_smem_abstract_eval(): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(commit_smem_p) +@lowering.register_lowering_rule(commit_smem_p, mgpu.ThreadSemantics.Lane) def _commit_smem_lowering(ctx: lowering.LoweringRuleContext): mgpu.commit_shared() return () @@ -733,7 +733,7 @@ def _broadcasted_iota_abstract_eval(dtype, shape, dimension, layout): return jax_core.ShapedArray(shape, dtype) -@lowering.register_lowering_rule(broadcasted_iota_p) +@lowering.register_lowering_rule(broadcasted_iota_p, mgpu.ThreadSemantics.Lane) def _broadcasted_iota_lowering( ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout ): diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 130895c21..583754cde 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1212,9 +1212,11 @@ _PALLAS_USE_MOSAIC_GPU = config.bool_flag( default=config.bool_env("JAX_PALLAS_USE_MOSAIC_GPU", False), help=( "If True, lower Pallas kernels to the experimental Mosaic GPU" - " dialect, instead of Trition IR." + " dialect, instead of Triton IR." ), ) + + _PALLAS_VERBOSE_ERRORS = config.bool_flag( "jax_pallas_verbose_errors", default=config.bool_env("JAX_PALLAS_VERBOSE_ERRORS", True), diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 151fb4fe3..bbad669fc 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -27,6 +27,7 @@ from jax import lax from jax._src import test_util as jtu from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline from jax.experimental import pallas as pl +from jax.experimental.mosaic.gpu import core as mosaic_gpu_core from jax.experimental.pallas import mosaic_gpu as plgpu import jax.numpy as jnp import numpy as np @@ -2073,5 +2074,31 @@ class ExamplesTest(PallasTest): # TODO(apaszke): Clusters and multicast +class PallasCallWarpgroupSemanticsTest(PallasTest): + + def setUp(self): + self.compiler_params = plgpu.GPUCompilerParams( + thread_semantics=mosaic_gpu_core.ThreadSemantics.Warpgroup + ) + + super().setUp() + + @parameterized.named_parameters( + ("add_float", lambda x, y: x + y, np.float32), + ) + def test_binary_op_wg_semantics(self, bop, dtype): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], dtype=dtype), + compiler_params=self.compiler_params + ) + def kernel(x_ref, y_ref, o_ref): + o_ref[...] = bop(x_ref[...], y_ref[...]) + + x = jnp.arange(256).astype(dtype) + y = x + 1 + np.testing.assert_array_equal(kernel(x, y), bop(x, y)) + + if __name__ == "__main__": absltest.main()