mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[Pallas/Mosaic GPU] Add initial support for warpgroup semantics in lowering.
This will allow us to lower Pallas kernels using the Mosaic GPU dialect, and in turn to perform layout inference and optimization automatically. The change contains lowering rules for `get` and `swap` (which are necessary to get a basic example to run), as well as for `add`. The new lowering path can be used by specifying the `Warpgroup` thread semantics as part of `pallas_call`'s compiler params. PiperOrigin-RevId: 725958027
This commit is contained in:
parent
72e7b93b4d
commit
5ad89006c3
@ -86,6 +86,7 @@ class GPUCompilerParams(pallas_core.CompilerParams):
|
||||
delay_release: int = 0
|
||||
profile_space: int = 0
|
||||
profile_dir: str = ""
|
||||
thread_semantics: mgpu.core.ThreadSemantics = mgpu.core.ThreadSemantics.Lane
|
||||
|
||||
def __post_init__(self):
|
||||
if bool(self.profile_space) ^ bool(self.profile_dir):
|
||||
|
@ -39,6 +39,7 @@ from jax._src.lib.mlir.dialects import gpu as gpu_dialect
|
||||
from jax._src.lib.mlir.dialects import memref as memref_dialect
|
||||
from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect
|
||||
from jax._src.lib.mlir.dialects import scf as scf_dialect
|
||||
from jax._src.lib.mlir.dialects import vector as vector_dialect
|
||||
from jax._src.pallas import core as pallas_core
|
||||
from jax._src.pallas import pallas_call
|
||||
from jax._src.pallas import primitives
|
||||
@ -272,6 +273,7 @@ class ModuleContext:
|
||||
class LoweringRuleContext:
|
||||
module_ctx: ModuleContext
|
||||
launch_ctx: mgpu.LaunchContext
|
||||
thread_semantics: mgpu_core.ThreadSemantics
|
||||
predicate: ir.Value
|
||||
avals_in: Sequence[jax_core.ShapedArray]
|
||||
avals_out: Sequence[jax_core.ShapedArray]
|
||||
@ -368,7 +370,7 @@ def lower_jaxpr_to_module(
|
||||
jaxpr: jax_core.Jaxpr,
|
||||
name_and_src_info: pallas_core.NameAndSrcInfo,
|
||||
compiler_params: dict[str, Any],
|
||||
cost_estimate: pallas_core.CostEstimate | None,
|
||||
cost_estimate: pallas_core.CostEstimate | None
|
||||
) -> LoweringResult:
|
||||
del cost_estimate # Unused.
|
||||
|
||||
@ -390,6 +392,9 @@ def lower_jaxpr_to_module(
|
||||
approx_math = params.get("approx_math", False)
|
||||
max_concurrent_steps = params.get("max_concurrent_steps", 1)
|
||||
delay_release = params.get("delay_release", 0)
|
||||
thread_semantics = params.get(
|
||||
"thread_semantics", mgpu_core.ThreadSemantics.Lane
|
||||
)
|
||||
dimension_semantics = params.get("dimension_semantics")
|
||||
if dimension_semantics is None:
|
||||
dimension_semantics = ["parallel"] * len(grid_mapping.grid)
|
||||
@ -724,6 +729,7 @@ def lower_jaxpr_to_module(
|
||||
launch_ctx,
|
||||
lowered_jaxpr,
|
||||
args,
|
||||
thread_semantics=thread_semantics
|
||||
)
|
||||
|
||||
if not all(out_sequential_invariant):
|
||||
@ -826,6 +832,13 @@ def lower_jaxpr_to_module(
|
||||
prof_spec=prof_spec,
|
||||
)
|
||||
)
|
||||
|
||||
if thread_semantics == mgpu.ThreadSemantics.Warpgroup:
|
||||
# Run Python lowering passes. The remaining passes will be run in C++ in
|
||||
# jax/jaxlib/mosaic/gpu/custom_call.cc
|
||||
mgpu.infer_layout(module) # pytype: disable=attribute-error
|
||||
mgpu.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error
|
||||
|
||||
mgpu_core._initialize_scratch(launch_ctx, scratch_arr)
|
||||
|
||||
return LoweringResult(
|
||||
@ -833,12 +846,19 @@ def lower_jaxpr_to_module(
|
||||
)
|
||||
|
||||
|
||||
mosaic_lowering_rules = {}
|
||||
mosaic_lowering_rules = {
|
||||
# Lowering rules when using Mosaic GPU lane semantics.
|
||||
mgpu.ThreadSemantics.Lane: {} ,
|
||||
# Lowering rules when using Mosaic GPU warpgroup semantics.
|
||||
mgpu.ThreadSemantics.Warpgroup: {},
|
||||
}
|
||||
|
||||
|
||||
def register_lowering_rule(primitive: jax_core.Primitive):
|
||||
def register_lowering_rule(
|
||||
primitive: jax_core.Primitive, thread_semantics: mgpu.ThreadSemantics
|
||||
):
|
||||
def deco(fn):
|
||||
mosaic_lowering_rules[primitive] = fn
|
||||
mosaic_lowering_rules[thread_semantics][primitive] = fn
|
||||
return fn
|
||||
|
||||
return deco
|
||||
@ -863,6 +883,7 @@ def lower_jaxpr_to_mosaic_gpu(
|
||||
jaxpr: jax_core.Jaxpr,
|
||||
args: Sequence[ir.Value],
|
||||
consts=(),
|
||||
thread_semantics: mgpu.ThreadSemantics = mgpu.ThreadSemantics.Lane,
|
||||
) -> Sequence[ir.Value]:
|
||||
env = {}
|
||||
|
||||
@ -884,7 +905,7 @@ def lower_jaxpr_to_mosaic_gpu(
|
||||
)
|
||||
loc = mlir._source_info_to_location(module_ctx, eqn.primitive, source_info)
|
||||
with source_info_util.user_context(eqn.source_info.traceback), loc:
|
||||
if eqn.primitive not in mosaic_lowering_rules:
|
||||
if eqn.primitive not in mosaic_lowering_rules[thread_semantics]:
|
||||
raise NotImplementedError(
|
||||
"Unimplemented primitive in Pallas Mosaic GPU lowering: "
|
||||
f"{eqn.primitive.name}. "
|
||||
@ -899,10 +920,11 @@ def lower_jaxpr_to_mosaic_gpu(
|
||||
wrapper_stack = contextlib.ExitStack()
|
||||
wrapper_stack.enter_context(launch_ctx.named_region(name))
|
||||
named_regions.append(wrapper_stack)
|
||||
rule = mosaic_lowering_rules[eqn.primitive]
|
||||
rule = mosaic_lowering_rules[thread_semantics][eqn.primitive]
|
||||
rule_ctx = LoweringRuleContext(
|
||||
module_ctx,
|
||||
launch_ctx,
|
||||
thread_semantics,
|
||||
predicate=mgpu.single_thread_predicate(per_block=False),
|
||||
avals_in=[cast(jax_core.ShapedArray, v.aval) for v in eqn.invars],
|
||||
avals_out=[cast(jax_core.ShapedArray, v.aval) for v in eqn.outvars],
|
||||
@ -928,7 +950,8 @@ def lower_jaxpr_to_mosaic_gpu(
|
||||
return map(read_env, jaxpr.outvars)
|
||||
|
||||
|
||||
@register_lowering_rule(primitives.program_id_p)
|
||||
@register_lowering_rule(primitives.program_id_p, mgpu.ThreadSemantics.Lane)
|
||||
@register_lowering_rule(primitives.program_id_p, mgpu.ThreadSemantics.Warpgroup)
|
||||
def _program_id_lowering_rule(ctx: LoweringRuleContext, axis):
|
||||
if ctx.module_ctx.program_ids is None:
|
||||
raise NotImplementedError("pl.program_id() is not supported in this context")
|
||||
@ -972,7 +995,8 @@ def _program_id(parallel_axis: int, squashed_dims: tuple[int, ...]) -> ir.Value:
|
||||
)
|
||||
|
||||
|
||||
@register_lowering_rule(primitives.num_programs_p)
|
||||
@register_lowering_rule(primitives.num_programs_p, mgpu.ThreadSemantics.Lane)
|
||||
@register_lowering_rule(primitives.num_programs_p, mgpu.ThreadSemantics.Warpgroup)
|
||||
def _num_programs_lowering_rule(ctx: LoweringRuleContext, axis):
|
||||
del ctx # Unused.
|
||||
return arith_dialect.index_cast(
|
||||
@ -1056,7 +1080,7 @@ def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ...
|
||||
return tuple(indices)
|
||||
|
||||
|
||||
@register_lowering_rule(sp.get_p)
|
||||
@register_lowering_rule(sp.get_p, mgpu.ThreadSemantics.Lane)
|
||||
def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *leaves, tree):
|
||||
if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem):
|
||||
raise TypeError(f"Can only load from references (got {x_smem}).")
|
||||
@ -1088,7 +1112,33 @@ def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *leaves, tree):
|
||||
raise NotImplementedError(f"Unsupported transforms: {transforms}")
|
||||
|
||||
|
||||
@register_lowering_rule(sp.swap_p)
|
||||
@register_lowering_rule(sp.get_p, mgpu.ThreadSemantics.Warpgroup)
|
||||
def _get_lowering_rule_wg(ctx: LoweringRuleContext, x_smem, *leaves, tree):
|
||||
if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem):
|
||||
raise TypeError(f"Can only load from references (got {x_smem}).")
|
||||
|
||||
x_aval = ctx.avals_in[0]
|
||||
|
||||
transforms = jax.tree.unflatten(tree, leaves)
|
||||
x_smem, transforms = _handle_reshaping(x_smem, transforms)
|
||||
x_smem, transforms = _handle_indexing(x_smem, transforms)
|
||||
|
||||
if transforms:
|
||||
raise NotImplementedError(
|
||||
"Transforms are not yet implemented for warpgroup semantics"
|
||||
)
|
||||
|
||||
shape = ctx.avals_out[0].shape
|
||||
ty = ir.VectorType.get(shape, mgpu_utils.dtype_to_ir_type(x_aval.dtype))
|
||||
if shape:
|
||||
zero_index = arith_dialect.constant(ir.IndexType.get(), 0)
|
||||
indices = [zero_index for _ in range(len(shape))]
|
||||
else:
|
||||
indices = []
|
||||
return vector_dialect.load(ty, x_smem, indices)
|
||||
|
||||
|
||||
@register_lowering_rule(sp.swap_p, mgpu.ThreadSemantics.Lane)
|
||||
def _swap_lowering_rule(
|
||||
ctx: LoweringRuleContext, x_smem, value, *leaves, tree
|
||||
):
|
||||
@ -1119,19 +1169,53 @@ def _swap_lowering_rule(
|
||||
raise NotImplementedError(f"Unsupported transforms: {transforms}")
|
||||
|
||||
|
||||
@register_lowering_rule(pjit.pjit_p)
|
||||
def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
|
||||
@register_lowering_rule(sp.swap_p, mgpu.ThreadSemantics.Warpgroup)
|
||||
def _swap_lowering_rule_wg(
|
||||
ctx: LoweringRuleContext, x_smem, value, *leaves, tree
|
||||
):
|
||||
if not ir.VectorType.isinstance(value.type):
|
||||
raise TypeError(f"Can only store vectors (got {value}).")
|
||||
if not ir.MemRefType.isinstance(x_smem.type):
|
||||
raise TypeError(f"Can only store to references (got {x_smem}).")
|
||||
|
||||
x_aval = ctx.avals_in[0]
|
||||
|
||||
transforms = jax.tree.unflatten(tree, leaves)
|
||||
x_smem, transforms = _handle_reshaping(x_smem, transforms)
|
||||
x_smem, transforms = _handle_indexing(x_smem, transforms)
|
||||
|
||||
if transforms:
|
||||
raise NotImplementedError(
|
||||
"Transforms are not yet implemented for warpgroup semantics"
|
||||
)
|
||||
|
||||
shape = ctx.avals_out[0].shape
|
||||
ty = ir.VectorType.get(shape, mgpu_utils.dtype_to_ir_type(x_aval.dtype))
|
||||
if shape:
|
||||
zero_index = arith_dialect.constant(ir.IndexType.get(), 0)
|
||||
indices = [zero_index for _ in range(len(shape))]
|
||||
else:
|
||||
indices = []
|
||||
old_value = vector_dialect.load(ty, x_smem, indices)
|
||||
vector_dialect.store(value, x_smem, indices)
|
||||
return old_value
|
||||
|
||||
|
||||
@register_lowering_rule(pjit.pjit_p, mgpu.ThreadSemantics.Lane)
|
||||
@register_lowering_rule(pjit.pjit_p, mgpu.ThreadSemantics.Warpgroup)
|
||||
def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **kwargs):
|
||||
if jaxpr.consts:
|
||||
raise NotImplementedError
|
||||
return lower_jaxpr_to_mosaic_gpu(
|
||||
ctx.module_ctx, ctx.launch_ctx, jaxpr.jaxpr, args
|
||||
ctx.module_ctx, ctx.launch_ctx, jaxpr.jaxpr, args,
|
||||
thread_semantics=ctx.thread_semantics
|
||||
)
|
||||
|
||||
@register_lowering_rule(pjit.mesh_cast_p)
|
||||
@register_lowering_rule(pjit.mesh_cast_p, mgpu.ThreadSemantics.Lane)
|
||||
def _mesh_cast_lowering_rule(ctx, x, dst_sharding):
|
||||
return x
|
||||
|
||||
@register_lowering_rule(lax.slice_p)
|
||||
@register_lowering_rule(lax.slice_p, mgpu.ThreadSemantics.Lane)
|
||||
def _slice_lowering_rule(
|
||||
ctx: LoweringRuleContext, x, limit_indices, start_indices, strides
|
||||
):
|
||||
@ -1141,7 +1225,7 @@ def _slice_lowering_rule(
|
||||
return x[tuple(slice(b, e) for b, e in zip(start_indices, limit_indices))]
|
||||
|
||||
|
||||
@register_lowering_rule(lax.select_n_p)
|
||||
@register_lowering_rule(lax.select_n_p, mgpu.ThreadSemantics.Lane)
|
||||
def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases):
|
||||
if len(cases) != 2:
|
||||
raise NotImplementedError(
|
||||
@ -1157,7 +1241,7 @@ def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases):
|
||||
return pred.select(*reversed(cases))
|
||||
|
||||
|
||||
@register_lowering_rule(lax.broadcast_in_dim_p)
|
||||
@register_lowering_rule(lax.broadcast_in_dim_p, mgpu.ThreadSemantics.Lane)
|
||||
def _broadcast_in_dim_lowering_rule(
|
||||
ctx: LoweringRuleContext,
|
||||
x: mgpu.FragmentedArray,
|
||||
@ -1181,7 +1265,7 @@ def _broadcast_in_dim_lowering_rule(
|
||||
return x.broadcast(shape)
|
||||
|
||||
|
||||
@register_lowering_rule(lax.convert_element_type_p)
|
||||
@register_lowering_rule(lax.convert_element_type_p, mgpu.ThreadSemantics.Lane)
|
||||
def _convert_element_type_lowering_rule(
|
||||
ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding
|
||||
):
|
||||
@ -1192,7 +1276,7 @@ def _convert_element_type_lowering_rule(
|
||||
)
|
||||
|
||||
|
||||
mosaic_lowering_rules.update({
|
||||
mosaic_lowering_rules[mgpu.ThreadSemantics.Lane].update({
|
||||
lax.neg_p: lambda ctx, x: -x,
|
||||
lax.not_p: lambda ctx, x: ~x,
|
||||
})
|
||||
@ -1203,7 +1287,7 @@ def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl):
|
||||
return impl(x, y)
|
||||
|
||||
|
||||
mosaic_lowering_rules.update({
|
||||
mosaic_lowering_rules[mgpu.ThreadSemantics.Lane].update({
|
||||
lax.add_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x + y),
|
||||
lax.sub_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x - y),
|
||||
lax.mul_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x * y),
|
||||
@ -1222,7 +1306,35 @@ mosaic_lowering_rules.update({
|
||||
})
|
||||
|
||||
|
||||
@register_lowering_rule(lax.div_p)
|
||||
# TODO(bchetioui): explore how this can be generalized for more binary ops.
|
||||
def _add_lowering_rule_wg(ctx: LoweringRuleContext, x, y):
|
||||
x_aval, y_aval = ctx.avals_in
|
||||
[out_aval] = ctx.avals_out
|
||||
# TODO(bchetioui): support implicit broadcast.
|
||||
if x_aval.shape != out_aval.shape or y_aval.shape != out_aval.shape:
|
||||
raise NotImplementedError(
|
||||
"Implicit broadcast not implemented with warpgroup semantics")
|
||||
|
||||
if np.issubdtype(ctx.avals_in[0].dtype, np.floating):
|
||||
add_op = arith_dialect.addf
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Lowering of non-float addition is not implemented"
|
||||
)
|
||||
|
||||
x = _ensure_vector(x, x_aval.dtype)
|
||||
y = _ensure_vector(y, y_aval.dtype)
|
||||
|
||||
return add_op(x, y)
|
||||
|
||||
|
||||
mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup].update({
|
||||
lax.add_p: _add_lowering_rule_wg,
|
||||
# TODO(bchetioui): add support for the remaining binary ops.
|
||||
})
|
||||
|
||||
|
||||
@register_lowering_rule(lax.div_p, mgpu.ThreadSemantics.Lane)
|
||||
def _div_lowering_rule(ctx: LoweringRuleContext, x, y):
|
||||
x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
|
||||
if ir.FloatType.isinstance(x.mlir_dtype):
|
||||
@ -1230,7 +1342,7 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y):
|
||||
return x // y
|
||||
|
||||
|
||||
@register_lowering_rule(lax.integer_pow_p)
|
||||
@register_lowering_rule(lax.integer_pow_p, mgpu.ThreadSemantics.Lane)
|
||||
def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y):
|
||||
[x_aval] = ctx.avals_in
|
||||
x = _ensure_fa(x, x_aval.dtype)
|
||||
@ -1238,44 +1350,44 @@ def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y):
|
||||
return x * x
|
||||
return NotImplementedError
|
||||
|
||||
@register_lowering_rule(lax.square_p)
|
||||
@register_lowering_rule(lax.square_p, mgpu.ThreadSemantics.Lane)
|
||||
def _square_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
[x_aval] = ctx.avals_in
|
||||
x = _ensure_fa(x, x_aval.dtype)
|
||||
return x * x
|
||||
|
||||
@register_lowering_rule(lax.rsqrt_p)
|
||||
@register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Lane)
|
||||
def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
[x_aval] = ctx.avals_in
|
||||
return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math)
|
||||
|
||||
@register_lowering_rule(lax.tanh_p)
|
||||
@register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Lane)
|
||||
def _tanh_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
[x_aval] = ctx.avals_in
|
||||
return _ensure_fa(x, x_aval.dtype).tanh(approx=ctx.module_ctx.approx_math)
|
||||
|
||||
|
||||
@register_lowering_rule(lax.logistic_p)
|
||||
@register_lowering_rule(lax.logistic_p, mgpu.ThreadSemantics.Lane)
|
||||
def _logistic_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
[x_aval] = ctx.avals_in
|
||||
a = _ensure_fa(x, x_aval.dtype)
|
||||
return 1. / (1. + (-a).exp(approx=ctx.module_ctx.approx_math))
|
||||
|
||||
@register_lowering_rule(lax.exp_p)
|
||||
@register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Lane)
|
||||
def _exp_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
[x_aval] = ctx.avals_in
|
||||
a = _ensure_fa(x, x_aval.dtype)
|
||||
return a.exp(approx=ctx.module_ctx.approx_math)
|
||||
|
||||
|
||||
@register_lowering_rule(lax.exp2_p)
|
||||
@register_lowering_rule(lax.exp2_p, mgpu.ThreadSemantics.Lane)
|
||||
def _exp2_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
[x_aval] = ctx.avals_in
|
||||
a = _ensure_fa(x, x_aval.dtype)
|
||||
return a.exp2(approx=ctx.module_ctx.approx_math)
|
||||
|
||||
|
||||
@register_lowering_rule(lax.reduce_sum_p)
|
||||
@register_lowering_rule(lax.reduce_sum_p, mgpu.ThreadSemantics.Lane)
|
||||
def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
|
||||
[x_aval] = ctx.avals_in
|
||||
match x.layout:
|
||||
@ -1295,7 +1407,7 @@ def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
|
||||
raise NotImplementedError(f"Unsupported layout {x.layout}")
|
||||
|
||||
|
||||
@register_lowering_rule(lax.reduce_max_p)
|
||||
@register_lowering_rule(lax.reduce_max_p, mgpu.ThreadSemantics.Lane)
|
||||
def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
|
||||
[x_aval] = ctx.avals_in
|
||||
match x.layout:
|
||||
@ -1309,7 +1421,7 @@ def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
|
||||
raise NotImplementedError(f"Unsupported layout {x.layout}")
|
||||
|
||||
|
||||
@register_lowering_rule(lax.axis_index_p)
|
||||
@register_lowering_rule(lax.axis_index_p, mgpu.ThreadSemantics.Lane)
|
||||
def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
grid_names = ctx.module_ctx.grid_names
|
||||
@ -1354,7 +1466,7 @@ def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
|
||||
)
|
||||
|
||||
|
||||
@register_lowering_rule(primitives.debug_print_p)
|
||||
@register_lowering_rule(primitives.debug_print_p, mgpu.ThreadSemantics.Lane)
|
||||
def _debug_print_lowering_rule(
|
||||
ctx: LoweringRuleContext,
|
||||
*args,
|
||||
@ -1387,7 +1499,7 @@ def _debug_print_lowering_rule(
|
||||
return ()
|
||||
|
||||
|
||||
@register_lowering_rule(primitives.run_scoped_p)
|
||||
@register_lowering_rule(primitives.run_scoped_p, mgpu.ThreadSemantics.Lane)
|
||||
def _run_scoped_lowering_rule(
|
||||
ctx: LoweringRuleContext, *consts, jaxpr: jax_core.Jaxpr
|
||||
):
|
||||
@ -1443,7 +1555,7 @@ def _run_scoped_lowering_rule(
|
||||
return outs
|
||||
|
||||
|
||||
@register_lowering_rule(discharge.run_state_p)
|
||||
@register_lowering_rule(discharge.run_state_p, mgpu.ThreadSemantics.Lane)
|
||||
def _run_state_lowering_rule(
|
||||
ctx: LoweringRuleContext,
|
||||
*args,
|
||||
@ -1530,7 +1642,7 @@ def _lower_jaxpr_to_for_loop(
|
||||
return loop.results
|
||||
|
||||
|
||||
@register_lowering_rule(lax.scan_p)
|
||||
@register_lowering_rule(lax.scan_p, mgpu.ThreadSemantics.Lane)
|
||||
def _scan_lowering_rule(
|
||||
ctx: LoweringRuleContext,
|
||||
*args,
|
||||
@ -1618,7 +1730,7 @@ def _lower_while_via_fori(
|
||||
return ub, ub, *for_out
|
||||
|
||||
|
||||
@register_lowering_rule(lax.while_p)
|
||||
@register_lowering_rule(lax.while_p, mgpu.ThreadSemantics.Lane)
|
||||
def _while_lowering_rule(
|
||||
ctx: LoweringRuleContext,
|
||||
*args,
|
||||
@ -1696,7 +1808,7 @@ def _while_lowering_rule(
|
||||
return carry_treedef.unflatten(list(while_op.results))
|
||||
|
||||
|
||||
@register_lowering_rule(lax.cond_p)
|
||||
@register_lowering_rule(lax.cond_p, mgpu.ThreadSemantics.Lane)
|
||||
def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
|
||||
index_aval, *_arg_avals = ctx.avals_in
|
||||
|
||||
@ -1753,7 +1865,7 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
|
||||
return treedef.unflatten(list(switch_op.results))
|
||||
|
||||
|
||||
@register_lowering_rule(lax.bitcast_convert_type_p)
|
||||
@register_lowering_rule(lax.bitcast_convert_type_p, mgpu.ThreadSemantics.Lane)
|
||||
def _bitcast_convert_type_lowering_rule(
|
||||
ctx: LoweringRuleContext, operand, *, new_dtype
|
||||
):
|
||||
@ -1778,7 +1890,7 @@ def _bitcast_convert_type_lowering_rule(
|
||||
)
|
||||
|
||||
|
||||
@register_lowering_rule(lax.optimization_barrier_p)
|
||||
@register_lowering_rule(lax.optimization_barrier_p, mgpu.ThreadSemantics.Lane)
|
||||
def _optimization_barrier_lowering(ctx: LoweringRuleContext, *args):
|
||||
args = (_ensure_fa(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in))
|
||||
return mgpu.optimization_barrier(*args)
|
||||
@ -1817,6 +1929,17 @@ def _ensure_fa(x: object, dtype: jnp.dtype) -> mgpu.FragmentedArray:
|
||||
)
|
||||
|
||||
|
||||
def _ensure_vector(x: object, dtype: jnp.dtype) -> ir.Value:
|
||||
if isinstance(x, ir.Value) and ir.VectorType.isinstance(x.type):
|
||||
assert ir.VectorType(x.type).element_type == mgpu_utils.dtype_to_ir_type(dtype)
|
||||
return x
|
||||
|
||||
if isinstance(x, (np.number, np.ndarray, int, float)):
|
||||
return _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype))
|
||||
|
||||
raise NotImplementedError(f"Unsupported conversion to vector for: {x!r}")
|
||||
|
||||
|
||||
def _ensure_ir_value(x: object, dtype: jnp.dtype) -> ir.Value:
|
||||
if isinstance(x, ir.Value):
|
||||
assert x.type == mgpu_utils.dtype_to_ir_type(dtype)
|
||||
|
@ -55,6 +55,12 @@ def pallas_call_lowering(
|
||||
print(f"The grid mapping for pallas_call {name_and_src_info}:")
|
||||
print(grid_mapping)
|
||||
|
||||
thread_semantics = compiler_params.get("mosaic_gpu", {}).get(
|
||||
"thread_semantics", mosaic_core.ThreadSemantics.Lane
|
||||
)
|
||||
if thread_semantics == mosaic_core.ThreadSemantics.Warpgroup:
|
||||
mosaic_core.dialect.register_dialect(ctx.module_context.context) # pytype: disable=attribute-error
|
||||
|
||||
lowering_result = lowering.lower_jaxpr_to_module(
|
||||
grid_mapping,
|
||||
jaxpr,
|
||||
|
@ -73,7 +73,7 @@ def _copy_smem_to_gmem_abstract_eval(src, dst, *args, **params):
|
||||
return (), {state.ReadEffect(0), state.WriteEffect(1)}
|
||||
|
||||
|
||||
@lowering.register_lowering_rule(copy_smem_to_gmem_p)
|
||||
@lowering.register_lowering_rule(copy_smem_to_gmem_p, mgpu.ThreadSemantics.Lane)
|
||||
def _copy_smem_to_gmem_lowering(
|
||||
ctx: lowering.LoweringRuleContext,
|
||||
src,
|
||||
@ -187,7 +187,7 @@ def _copy_gmem_to_smem_abstract_eval(src, dst, barrier, *args, **params):
|
||||
return (), {state.ReadEffect(0), state.WriteEffect(1)}
|
||||
|
||||
|
||||
@lowering.register_lowering_rule(copy_gmem_to_smem_p)
|
||||
@lowering.register_lowering_rule(copy_gmem_to_smem_p, mgpu.ThreadSemantics.Lane)
|
||||
def _copy_gmem_to_smem_lowering(
|
||||
ctx: lowering.LoweringRuleContext,
|
||||
src,
|
||||
@ -307,7 +307,7 @@ def _barrier_arrive_abstract_eval(barrier, *args, **params):
|
||||
return (), {gpu_core._memory_effect}
|
||||
|
||||
|
||||
@lowering.register_lowering_rule(barrier_arrive_p)
|
||||
@lowering.register_lowering_rule(barrier_arrive_p, mgpu.ThreadSemantics.Lane)
|
||||
def _barrier_arrive_lowering(
|
||||
ctx: lowering.LoweringRuleContext,
|
||||
barrier,
|
||||
@ -345,7 +345,7 @@ def _barrier_wait_abstract_eval(barrier, *args, **params):
|
||||
return (), {gpu_core._memory_effect}
|
||||
|
||||
|
||||
@lowering.register_lowering_rule(barrier_wait_p)
|
||||
@lowering.register_lowering_rule(barrier_wait_p, mgpu.ThreadSemantics.Lane)
|
||||
def _barrier_wait_lowering(
|
||||
ctx: lowering.LoweringRuleContext,
|
||||
barrier,
|
||||
@ -382,7 +382,7 @@ def _wait_smem_to_gmem_abstract_eval(n, *, wait_read_only):
|
||||
return (), {gpu_core._memory_effect}
|
||||
|
||||
|
||||
@lowering.register_lowering_rule(wait_smem_to_gmem_p)
|
||||
@lowering.register_lowering_rule(wait_smem_to_gmem_p, mgpu.ThreadSemantics.Lane)
|
||||
def _wait_smem_to_gmem_lowering(
|
||||
ctx: lowering.LoweringRuleContext, n, *, wait_read_only
|
||||
):
|
||||
@ -483,7 +483,7 @@ def _wgmma_ref_discharge(in_avals, out_avals, *args, **kwargs):
|
||||
wgmma_p = jax_core.Primitive("wgmma")
|
||||
|
||||
|
||||
@lowering.register_lowering_rule(wgmma_p)
|
||||
@lowering.register_lowering_rule(wgmma_p, mgpu.ThreadSemantics.Lane)
|
||||
def _wgmma_lowering(
|
||||
ctx: lowering.LoweringRuleContext,
|
||||
acc,
|
||||
@ -588,7 +588,7 @@ def wgmma_wait_effectful_abstract_eval(_):
|
||||
return [], {gpu_core._wgmma_pipeline_effect}
|
||||
|
||||
|
||||
@lowering.register_lowering_rule(wgmma_wait_p)
|
||||
@lowering.register_lowering_rule(wgmma_wait_p, mgpu.ThreadSemantics.Lane)
|
||||
def _wgmma_wait_lowering(ctx: lowering.LoweringRuleContext, allow_groups):
|
||||
del ctx
|
||||
nvvm_dialect.wgmma_wait_group_sync_aligned(allow_groups)
|
||||
@ -619,7 +619,7 @@ def _wgmma_accumulator_deref_discharge(in_avals, out_avals, acc):
|
||||
return (None,), wgmma_accumulator_deref_p.bind(acc)
|
||||
|
||||
|
||||
@lowering.register_lowering_rule(wgmma_accumulator_deref_p)
|
||||
@lowering.register_lowering_rule(wgmma_accumulator_deref_p, mgpu.ThreadSemantics.Lane)
|
||||
def _wgmma_accumulator_deref_lowering(ctx: lowering.LoweringRuleContext, acc):
|
||||
del ctx
|
||||
nvvm_dialect.wgmma_wait_group_sync_aligned(0)
|
||||
@ -665,7 +665,7 @@ def _layout_cast_abstract_eval(x, new_layout):
|
||||
return x
|
||||
|
||||
|
||||
@lowering.register_lowering_rule(layout_cast_p)
|
||||
@lowering.register_lowering_rule(layout_cast_p, mgpu.ThreadSemantics.Lane)
|
||||
def _layout_cast_lowering(ctx: lowering.LoweringRuleContext, x, *, new_layout):
|
||||
del ctx # Unused.
|
||||
return x.to_layout(_get_mgpu_layout(new_layout))
|
||||
@ -686,7 +686,7 @@ def _set_max_registers_abstract_eval(n, *, action):
|
||||
return (), {gpu_core._memory_effect}
|
||||
|
||||
|
||||
@lowering.register_lowering_rule(set_max_registers_p)
|
||||
@lowering.register_lowering_rule(set_max_registers_p, mgpu.ThreadSemantics.Lane)
|
||||
def _set_max_registers_lowering(
|
||||
ctx: lowering.LoweringRuleContext, n, *, action
|
||||
):
|
||||
@ -714,7 +714,7 @@ def _commit_smem_abstract_eval():
|
||||
return (), {gpu_core._memory_effect}
|
||||
|
||||
|
||||
@lowering.register_lowering_rule(commit_smem_p)
|
||||
@lowering.register_lowering_rule(commit_smem_p, mgpu.ThreadSemantics.Lane)
|
||||
def _commit_smem_lowering(ctx: lowering.LoweringRuleContext):
|
||||
mgpu.commit_shared()
|
||||
return ()
|
||||
@ -733,7 +733,7 @@ def _broadcasted_iota_abstract_eval(dtype, shape, dimension, layout):
|
||||
return jax_core.ShapedArray(shape, dtype)
|
||||
|
||||
|
||||
@lowering.register_lowering_rule(broadcasted_iota_p)
|
||||
@lowering.register_lowering_rule(broadcasted_iota_p, mgpu.ThreadSemantics.Lane)
|
||||
def _broadcasted_iota_lowering(
|
||||
ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout
|
||||
):
|
||||
|
@ -1212,9 +1212,11 @@ _PALLAS_USE_MOSAIC_GPU = config.bool_flag(
|
||||
default=config.bool_env("JAX_PALLAS_USE_MOSAIC_GPU", False),
|
||||
help=(
|
||||
"If True, lower Pallas kernels to the experimental Mosaic GPU"
|
||||
" dialect, instead of Trition IR."
|
||||
" dialect, instead of Triton IR."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
_PALLAS_VERBOSE_ERRORS = config.bool_flag(
|
||||
"jax_pallas_verbose_errors",
|
||||
default=config.bool_env("JAX_PALLAS_VERBOSE_ERRORS", True),
|
||||
|
@ -27,6 +27,7 @@ from jax import lax
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.experimental.mosaic.gpu import core as mosaic_gpu_core
|
||||
from jax.experimental.pallas import mosaic_gpu as plgpu
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
@ -2073,5 +2074,31 @@ class ExamplesTest(PallasTest):
|
||||
# TODO(apaszke): Clusters and multicast
|
||||
|
||||
|
||||
class PallasCallWarpgroupSemanticsTest(PallasTest):
|
||||
|
||||
def setUp(self):
|
||||
self.compiler_params = plgpu.GPUCompilerParams(
|
||||
thread_semantics=mosaic_gpu_core.ThreadSemantics.Warpgroup
|
||||
)
|
||||
|
||||
super().setUp()
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("add_float", lambda x, y: x + y, np.float32),
|
||||
)
|
||||
def test_binary_op_wg_semantics(self, bop, dtype):
|
||||
@functools.partial(
|
||||
pl.pallas_call,
|
||||
out_shape=jax.ShapeDtypeStruct([256], dtype=dtype),
|
||||
compiler_params=self.compiler_params
|
||||
)
|
||||
def kernel(x_ref, y_ref, o_ref):
|
||||
o_ref[...] = bop(x_ref[...], y_ref[...])
|
||||
|
||||
x = jnp.arange(256).astype(dtype)
|
||||
y = x + 1
|
||||
np.testing.assert_array_equal(kernel(x, y), bop(x, y))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user