diff --git a/jax/BUILD b/jax/BUILD
index 91c6d02f0..40add06ba 100644
--- a/jax/BUILD
+++ b/jax/BUILD
@@ -19,6 +19,7 @@ load("@rules_python//python:defs.bzl", "py_library")
 load(
     "//jaxlib:jax.bzl",
     "if_building_jaxlib",
+    "if_building_mosaic_gpu",
     "jax_extend_internal_users",
     "jax_extra_deps",
     "jax_internal_export_back_compat_test_util_visibility",
@@ -669,7 +670,7 @@ pytype_strict_library(
     deps = [
         ":pallas",
         "//jax/_src/pallas/triton",
-    ],
+    ] + if_building_mosaic_gpu(["//third_party/py/jax/_src/pallas/mosaic_gpu"]),
 )
 
 # This target only supports sm_90 GPUs.
diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD
new file mode 100644
index 000000000..76ab118f8
--- /dev/null
+++ b/jax/_src/pallas/mosaic_gpu/BUILD
@@ -0,0 +1,60 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Package for Mosaic GPU-specific Pallas extensions
+
+load("//jaxlib:jax.bzl", "pytype_strict_library")
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(
+    default_applicable_licenses = [],
+    default_visibility = [
+        "//third_party/py/jax:internal",
+    ],
+)
+
+py_library(
+    name = "mosaic_gpu",
+    srcs = ["__init__.py"],
+    deps = [
+        ":pallas_call_registration",
+    ],
+)
+
+pytype_strict_library(
+    name = "pallas_call_registration",
+    srcs = ["pallas_call_registration.py"],
+    deps = [
+        ":lowering",
+        "//third_party/py/jax",
+        "//third_party/py/jax:mlir",
+        "//third_party/py/jax:mosaic_gpu",
+        "//third_party/py/jax/_src/pallas",
+    ],
+)
+
+pytype_strict_library(
+    name = "lowering",
+    srcs = ["lowering.py"],
+    deps = [
+        "//third_party/py/jax",
+        "//third_party/py/jax:core",
+        "//third_party/py/jax:mlir",
+        "//third_party/py/jax:mosaic_gpu",
+        "//third_party/py/jax:util",
+        "//third_party/py/jax/_src/lib",
+        "//third_party/py/jax/_src/pallas",
+        "//third_party/py/numpy",
+    ],
+)
diff --git a/jax/_src/pallas/mosaic_gpu/__init__.py b/jax/_src/pallas/mosaic_gpu/__init__.py
new file mode 100644
index 000000000..862a661e2
--- /dev/null
+++ b/jax/_src/pallas/mosaic_gpu/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
new file mode 100644
index 000000000..11bc4add9
--- /dev/null
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -0,0 +1,297 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module for lowering JAX primitives to Mosaic GPU."""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+import dataclasses
+import math
+from typing import Any, cast
+
+import jax
+from jax._src import core as jax_core
+from jax._src import util
+from jax._src.interpreters import mlir
+from jax._src.lax import lax
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import arith as arith_dialect
+from jax._src.lib.mlir.dialects import memref as memref_dialect
+from jax._src.lib.mlir.dialects import nvgpu as nvgpu_dialect
+from jax._src.pallas import core as pl_core
+from jax._src.state import primitives as sp
+from jax.experimental.mosaic import gpu as mosaic_gpu
+from jax.experimental.mosaic.gpu import dsl as mgpu
+import numpy as np
+
+# TODO(slebedev): Enable type checking.
+# mypy: ignore-errors
+# pytype: skip-file
+
+map, unsafe_map = util.safe_map, map
+zip, unsafe_zip = util.safe_zip, zip
+
+
+@dataclasses.dataclass
+class ModuleContext:
+  name: str
+  grid_mapping: pl_core.GridMapping
+
+
+@dataclasses.dataclass
+class LoweringRuleContext:
+  context: ModuleContext
+  avals_in: Sequence[jax_core.ShapedArray]
+  avals_out: Sequence[jax_core.ShapedArray]
+  block_shapes: list[tuple[int | pl_core.Mapped, ...]] | None
+
+  replace = dataclasses.replace
+
+
+@dataclasses.dataclass
+class LoweringResult:
+  module: ir.Module
+  grid: tuple[int, ...]
+  gmem_scratch_bytes: int
+  out_structs: tuple[jax.ShapeDtypeStruct, ...]
+
+
+@dataclasses.dataclass
+class BlockInfo:
+  full_shape_dtype: jax.ShapeDtypeStruct
+  start_indices: Sequence[Any]
+  block_shape: tuple[int, ...]
+
+
+class LoweringError(Exception):
+  pass
+
+
+def lower_jaxpr_to_module(
+    grid_mapping: pl_core.GridMapping,
+    in_structs: tuple[jax.ShapeDtypeStruct, ...],
+    out_structs: tuple[jax.ShapeDtypeStruct, ...],
+    jaxpr: jax_core.Jaxpr,
+    name: str,
+) -> LoweringResult:
+  assert len(jaxpr.outvars) == 0
+  assert not grid_mapping.mapped_dims
+  grid = grid_mapping.grid
+  if len(grid) < 3:
+    grid += (1,) * (3 - len(grid))
+  block = (128,) + (1,) * (len(grid) - 1)
+
+  def body(ctx: mosaic_gpu.LaunchContext, *buffers):
+    *buffers_gmem, buffers_smem = buffers
+    assert len(buffers_gmem) == len(buffers_smem)
+    in_buffers_gmem = buffers_gmem[: len(in_structs)]
+    in_buffers_smem = buffers_smem[: len(in_structs)]
+    out_buffers_gmem = buffers_gmem[len(in_structs) :]
+    out_buffers_smem = buffers_smem[len(in_structs) :]
+
+    # arrival_count= determines the expected number of arrivals for each
+    # barrier in the array. It is not accidental that we do just a single
+    # mbarrier_arrive_expect_tx below.
+    # TODO(slebedev): Consider enforcing this in the mgpu.BarrierArray.
+    [barrier] = mgpu.BarrierArray(1, arrival_count=1)
+
+    with mgpu.once():
+      nvgpu_dialect.mbarrier_arrive_expect_tx(
+          barrier.barrier_array.value,
+          _index(
+              sum(math.prod(s.shape) * s.dtype.itemsize for s in in_structs)
+          ),
+          barrier.offset,
+      )
+
+    for b_gmem, b_smem in zip(in_buffers_gmem, in_buffers_smem):
+      # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls.
+      ctx.async_copy(
+          src_ref=b_gmem,
+          dst_ref=b_smem,
+          barrier=barrier,
+          swizzle=None,
+          arrive=False,
+          uniform=False,
+      )
+
+    barrier.wait()
+
+    module_ctx = ModuleContext(name, grid_mapping)
+    _ = lower_jaxpr_to_mosaic_gpu(module_ctx, jaxpr, None, *buffers_smem)
+
+    for b_gmem, b_smem in zip(out_buffers_gmem, out_buffers_smem):
+      # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls.
+      ctx.async_copy(src_ref=b_smem, dst_ref=b_gmem, swizzle=None)
+
+    ctx.await_async_copy(0)
+
+  module, out_structs, gmem_scratch_bytes, _ = mosaic_gpu._lower_as_gpu_kernel(
+      body,
+      grid,
+      block,
+      in_shape=in_structs,
+      out_shape=out_structs,
+      smem_scratch_shape=in_structs + out_structs,
+  )
+
+  return LoweringResult(module, grid, gmem_scratch_bytes, out_structs)
+
+
+mosaic_lowering_rules = {}
+
+
+def register_lowering_rule(primitive: jax_core.Primitive):
+  def deco(fn):
+    mosaic_lowering_rules[primitive] = fn
+    return fn
+
+  return deco
+
+
+def lower_jaxpr_to_mosaic_gpu(
+    ctx: ModuleContext,
+    jaxpr: jax_core.Jaxpr,
+    block_infos: Sequence[BlockInfo | None] | None,
+    *args,
+) -> Sequence[ir.Value]:
+  env = {}
+  block_info_env = {}
+
+  def read_env(atom: jax_core.Atom):
+    return atom.val if isinstance(atom, jax_core.Literal) else env[atom]
+
+  def read_block_info_env(atom: jax_core.Atom):
+    if isinstance(atom, jax_core.Literal):
+      return None
+    return block_info_env.get(atom, None)
+
+  def write_env(var: jax_core.Var, val):
+    env[var] = val
+
+  if block_infos is None:
+    block_infos = [None] * len(jaxpr.invars)
+  for invar, block_info in zip(jaxpr.invars, block_infos):
+    block_info_env[invar] = block_info
+  map(write_env, jaxpr.invars, args)
+  for eqn in jaxpr.eqns:
+    invals = map(read_env, eqn.invars)
+    if eqn.primitive not in mosaic_lowering_rules:
+      raise NotImplementedError(
+          "Unimplemented primitive in Pallas Mosaic GPU lowering: "
+          f"{eqn.primitive.name}. "
+          "Please file an issue on https://github.com/google/jax/issues."
+      )
+    rule = mosaic_lowering_rules[eqn.primitive]
+    rule_ctx = LoweringRuleContext(
+        ctx,
+        avals_in=[cast(jax_core.ShapedArray, v.aval) for v in eqn.invars],
+        avals_out=[cast(jax_core.ShapedArray, v.aval) for v in eqn.outvars],
+        block_shapes=map(read_block_info_env, eqn.invars),
+    )
+    try:
+      outvals = rule(rule_ctx, *invals, **eqn.params)
+    except LoweringError:
+      raise  # We only add the extra info to the innermost exception.
+    except Exception as e:
+      inval_types = map(lambda t: getattr(t, "type", None), invals)
+      raise LoweringError(
+          f"Exception while lowering eqn:\n  {eqn}\nWith context:\n "
+          f" {rule_ctx}\nWith inval types={inval_types}\nIn jaxpr:\n{jaxpr}"
+      ) from e
+    if eqn.primitive.multiple_results:
+      map(write_env, eqn.outvars, outvals)
+    else:
+      write_env(eqn.outvars[0], outvals)
+  return map(read_env, jaxpr.outvars)
+
+
+@register_lowering_rule(sp.get_p)
+def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *indexers, tree):
+  del tree  # Unused.
+  if indexers:
+    raise NotImplementedError("No support for indexers yet")
+  return mgpu.FragmentedArray.load_strided(x_smem)
+
+
+@register_lowering_rule(sp.swap_p)
+def _swap_lowering_rule(
+    ctx: LoweringRuleContext, x_smem, value, *indexers, tree
+):
+  del tree  # Unused.
+  if indexers:
+    raise NotImplementedError("No support for indexers yet")
+  old_value = mgpu.FragmentedArray.load_strided(x_smem)
+  value.store_untiled(x_smem)
+  return old_value
+
+
+@register_lowering_rule(lax.add_p)
+def _add_lowering_rule(ctx: LoweringRuleContext, x, y):
+  x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
+  return x + y
+
+
+def _bcast_to(a: Any, shape: tuple[int, ...]) -> ir.Value:
+  if not isinstance(a, mgpu.FragmentedArray):
+    if not shape:
+      return a
+    layout = mgpu.WGStridedFragLayout.from_memref_type(
+        memref_dialect.MemRefType.get(shape, a.type)
+    )
+    return mgpu.FragmentedArray.splat(a, shape, layout)
+  else:
+    if a.shape == shape:
+      return a
+    raise NotImplementedError
+
+
+def _bcast(
+    x: ir.Value,
+    y: ir.Value,
+    x_aval: jax_core.ShapedArray,
+    y_aval: jax_core.ShapedArray,
+    out_aval: jax_core.ShapedArray,
+) -> ir.Value:
+  if isinstance(x, (np.ndarray, np.number, int, float)):
+    x_dtype = x_aval.dtype
+    if x_aval.weak_type:
+      x_dtype = y_aval.dtype
+    x = _ir_constant(x, mlir.dtype_to_ir_type(x_dtype))
+  if isinstance(y, (np.ndarray, np.number, int, float)):
+    y_dtype = y_aval.dtype
+    if y_aval.weak_type:
+      y_dtype = x_aval.dtype
+    y = _ir_constant(y, mlir.dtype_to_ir_type(y_dtype))
+  if x_aval.shape != out_aval.shape:
+    x = _bcast_to(x, out_aval.shape)
+  if y_aval.shape != out_aval.shape:
+    y = _bcast_to(y, out_aval.shape)
+  return x, y
+
+
+def _ir_constant(v: object, t: ir.Type) -> ir.Value:
+  if isinstance(v, (np.number, np.ndarray, int, float)):
+    if isinstance(t, ir.IntegerType):
+      v = int(v)
+    else:
+      assert isinstance(t, ir.FloatType)
+      v = float(v)
+    return arith_dialect.constant(t, v)
+  raise NotImplementedError(f"Unsupported constant: {v!r}")
+
+
+def _index(i: int) -> ir.Value:
+  return arith_dialect.constant(ir.IndexType.get(), i)
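Aside, not part of this change: additional elementwise primitives can later be registered against mosaic_lowering_rules the same way lax.add_p is above. A hypothetical sketch only, assuming mgpu.FragmentedArray overloads * the same way it overloads +:

@register_lowering_rule(lax.mul_p)
def _mul_lowering_rule(ctx: LoweringRuleContext, x, y):
  # Reuses the same scalar/weak-type broadcasting helper as the add rule.
  x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
  return x * y
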
diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py
new file mode 100644
index 000000000..3403e5677
--- /dev/null
+++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py
@@ -0,0 +1,90 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module registering a lowering rule for pallas_call on GPU."""
+
+
+from __future__ import annotations
+
+from typing import Any
+
+import jax
+from jax import core as jax_core
+from jax._src.interpreters import mlir
+from jax._src.pallas import core as pallas_core
+from jax._src.pallas.mosaic_gpu import lowering
+from jax._src.pallas.pallas_call import pallas_call_p
+from jax.experimental.mosaic import gpu as mosaic_gpu
+
+
+def pallas_call_lowering(
+    ctx: mlir.LoweringRuleContext,
+    *args,
+    jaxpr: jax_core.Jaxpr,
+    name: str,
+    in_shapes: tuple[jax.ShapeDtypeStruct, ...],
+    out_shapes: tuple[jax.ShapeDtypeStruct, ...],
+    which_linear: tuple[bool, ...],
+    interpret: bool,
+    debug: bool,
+    input_output_aliases: tuple[tuple[int, int], ...],
+    grid_mapping: pallas_core.GridMapping,
+    compiler_params: dict[str, Any],
+):
+  if interpret:
+    return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
+        ctx,
+        *args,
+        jaxpr=jaxpr,
+        name=name,
+        out_shapes=out_shapes,
+        in_shapes=in_shapes,
+        which_linear=which_linear,
+        interpret=interpret,
+        debug=debug,
+        input_output_aliases=input_output_aliases,
+        grid_mapping=grid_mapping,
+        compiler_params=compiler_params,
+    )
+
+  if grid_mapping.num_dynamic_grid_bounds:
+    raise NotImplementedError(
+        "dynamic grid bounds not supported in the Mosaic GPU backend"
+    )
+  if input_output_aliases:
+    raise NotImplementedError(
+        "input_output_aliases not supported in the Mosaic GPU backend"
+    )
+
+  if debug:
+    print(jaxpr)
+    print(grid_mapping)
+
+  lowering_result = lowering.lower_jaxpr_to_module(
+      grid_mapping,
+      in_shapes,
+      out_shapes,
+      jaxpr,
+      name,
+  )
+  if debug:
+    print(lowering_result.module.operation)
+
+  return mosaic_gpu._mosaic_gpu_lowering_rule(
+      ctx,
+      *args,
+      module=lowering_result.module,
+      gmem_scratch_bytes=lowering_result.gmem_scratch_bytes,
+      out_types=lowering_result.out_structs,
+  )
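Aside, not part of this change: with the jax_pallas_use_mosaic_gpu option introduced in the pallas_call.py change just below, the new backend can be exercised end to end roughly as sketched here. This is illustrative only; it assumes an sm_90 GPU, a jaxlib built with Mosaic GPU support, and that the option is updatable via jax.config.update like other DEFINE_bool options (otherwise set JAX_PALLAS_USE_MOSAIC_GPU=1 in the environment, as the test BUILD rule at the end does).

import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

jax.config.update("jax_pallas_use_mosaic_gpu", True)

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
)
def add_one(x_ref, o_ref):
  o_ref[...] = x_ref[...] + 1.0

x = jnp.arange(256, dtype=jnp.float32)
print(add_one(x)[:4])  # Expected: [1. 2. 3. 4.]
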
diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index a52a41a2e..c2cf2fd4b 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -15,17 +15,16 @@
 """Module for calling pallas functions from JAX."""
 from __future__ import annotations
 
-import itertools
-from functools import partial
-from functools import reduce
-
-from typing import Any, Callable
 from collections.abc import Sequence
+import itertools
+from functools import partial, reduce
+from typing import Any, Callable
 
 import jax
 from jax import api_util
 from jax import tree_util
 from jax import lax
+from jax._src import config
 from jax._src import state
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
@@ -532,6 +531,16 @@ def _extract_function_name(f: Callable, name: str | None) -> str:
   return name
 
 
+_PALLAS_USE_MOSAIC_GPU = config.DEFINE_bool(
+    "jax_pallas_use_mosaic_gpu",
+    default=config.bool_env("JAX_PALLAS_USE_MOSAIC_GPU", False),
+    help=(
+        "If True, lower Pallas kernels to the experimental Mosaic GPU"
+        " dialect, instead of Triton IR."
+    ),
+)
+
+
 def _unsupported_lowering_error(platform: str) -> Exception:
   return ValueError(
       f"Cannot lower pallas_call on platform: {platform}. To use Pallas on GPU,"
@@ -560,7 +569,10 @@ def _pallas_call_lowering(
     raise ValueError("Only interpret mode is supported on CPU backend.")
   elif platform == "cuda" or platform == "rocm":
     try:
-      from jax._src.pallas.triton import pallas_call_registration  # type: ignore
+      if _PALLAS_USE_MOSAIC_GPU.value:
+        from jax._src.pallas.mosaic_gpu import pallas_call_registration  # type: ignore
+      else:
+        from jax._src.pallas.triton import pallas_call_registration  # type: ignore
     except ImportError:
       pass
     else:
diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py
index ed96a08d7..07017de1f 100644
--- a/jax/experimental/mosaic/gpu/__init__.py
+++ b/jax/experimental/mosaic/gpu/__init__.py
@@ -530,7 +530,7 @@ def _launch(
     gpu.terminator()
 
 
-def as_gpu_kernel(
+def _lower_as_gpu_kernel(
     body,
     grid: tuple[int, ...],
     block: tuple[int, ...],
@@ -609,6 +609,31 @@
   main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
   module.operation.verify()
 
+  dump_low_level(module)
+
+  pass_manager = _get_mosaic_gpu_pipeline("fatbin")
+  if mosaic_gpu_print_after_all.value:
+    pass_manager.enable_ir_printing()
+  pass_manager.run(module.operation)
+
+  return module, out_shape, gmem_scratch_bytes, unwrap_output_tuple
+
+
+def as_gpu_kernel(
+    body,
+    grid: tuple[int, ...],
+    block: tuple[int, ...],
+    in_shape,
+    out_shape,
+    smem_scratch_shape,
+    prof_spec: profiler.ProfilerSpec | None = None,
+):
+  module, out_shape, gmem_scratch_bytes, unwrap_output_tuple = (
+      _lower_as_gpu_kernel(
+          body, grid, block, in_shape, out_shape, smem_scratch_shape, prof_spec
+      )
+  )
+
   expected_arg_treedef = jax.tree.structure(in_shape)
   def _check_args(args):
     arg_treedef = jax.tree.structure(args)
@@ -618,13 +643,6 @@ as_gpu_kernel(
           f" {arg_treedef}"
       )
 
-  dump_low_level(module)
-
-  pass_manager = _get_mosaic_gpu_pipeline("fatbin")
-  if mosaic_gpu_print_after_all.value:
-    pass_manager.enable_ir_printing()
-  pass_manager.run(module.operation)
-
   def bind(*args):
     return mosaic_gpu_p.bind(
         *args,
diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py
index d7fddf1aa..4a8206c46 100644
--- a/jax/experimental/mosaic/gpu/utils.py
+++ b/jax/experimental/mosaic/gpu/utils.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Utilities for code generator."""
 
+from collections.abc import Iterator
 import contextlib
 import dataclasses
 from typing import Any, Literal, Sequence
@@ -496,6 +497,7 @@ class BarrierArray:
           f" num_barriers={num_barriers}>"
       )
 
+    self.num_barriers = num_barriers
     self.value = nvgpu.mbarrier_create(barrier_group_ty)
     self.num_barriers = num_barriers
     index = ir.IndexType.get()
@@ -508,6 +510,10 @@ class BarrierArray:
     for i in range(num_barriers):
       nvgpu.mbarrier_init(self.value, c(arrival_count, index), c(i, index))
 
+  def __iter__(self) -> Iterator["Barrier"]:
+    for offset in range(self.num_barriers):
+      yield self[offset]
+
   def __getitem__(self, offset: ir.Value | int):
     if isinstance(offset, int):
       offset = c(offset, ir.IndexType.get())
diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD
index 49af8c9be..7489c5147 100644
--- a/tests/pallas/BUILD
+++ b/tests/pallas/BUILD
@@ -254,6 +254,42 @@ jax_test(
     ] + py_deps("absl/testing") + py_deps("absl/flags") + py_deps("numpy") + py_deps("hypothesis"),
 )
 
+jax_test(
+    name = "mosaic_gpu_test",
+    srcs = [
+        "mosaic_gpu_test.py",
+    ],
+    config_tags_overrides = {
+        # TODO(slebedev): Switch to False once Mosaic GPU is unconditionally enabled.
+        "gpu_h100_x32": {
+            "ondemand": True,  # Include in presubmit.
+        },
+    },
+    disable_backends = [
+        "cpu",
+        "tpu",
+    ],
+    disable_configs = [
+        "gpu",
+        "gpu_x32",
+        "gpu_a100",
+        "gpu_a100_x32",
+        "gpu_p100",
+        "gpu_p100_x32",
+        "gpu_h100",
+    ],
+    enable_configs = [
+        "gpu_h100_x32",
+    ],
+    env = {
+        "JAX_PALLAS_USE_MOSAIC_GPU": "1",
+    },
+    deps = [
+        "//jax:pallas",
+        "//jax:pallas_gpu",  # build_cleaner: keep
+    ] + py_deps("absl/testing") + py_deps("numpy"),
+)
+
 jax_test(
     name = "export_back_compat_pallas_test",
     srcs = ["export_back_compat_pallas_test.py"],
diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
new file mode 100644
index 000000000..10650d64c
--- /dev/null
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -0,0 +1,53 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+
+from absl.testing import absltest
+import jax
+from jax._src import test_util as jtu
+from jax.experimental import pallas as pl
+import jax.numpy as jnp
+import numpy as np
+
+
+jax.config.parse_flags_with_absl()
+
+
+class PallasTest(jtu.JaxTestCase):
+
+  def setUp(self):
+    super().setUp()
+
+    if not jtu.is_cuda_compute_capability_at_least("9.0"):
+      self.skipTest("Only works on a GPU with capability >= sm90")
+
+
+class PallasCallTest(PallasTest):
+
+  def test_add_one(self):
+    @functools.partial(
+        pl.pallas_call,
+        out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
+    )
+    def add_one(x_ref, o_ref):
+      print(">>>", x_ref, o_ref)
+      o_ref[...] = x_ref[...] + 1.0
+
+    x = jnp.arange(256).astype(jnp.float32)
+    np.testing.assert_array_equal(add_one(x), x + 1.0)
+
+
+if __name__ == "__main__":
+  absltest.main()