Added a very rough sketch of Mosaic GPU lowering for Pallas

Almost nothing is supported, including

* PyTree inputs/outputs
* indexers
* non-trivial grids
* block specs
* any primitives beyond the ones added here
* etc etc

PiperOrigin-RevId: 633713366
Sergei Lebedev 2024-05-14 14:47:24 -07:00 committed by jax authors
parent 0ad5167da8
commit e2918ca138
10 changed files with 601 additions and 15 deletions
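
For context, the sketch already handles a trivial elementwise kernel end to end. A minimal example mirroring the test added in this change (assumes an sm_90 GPU and the JAX_PALLAS_USE_MOSAIC_GPU=1 environment variable introduced below):

import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
)
def add_one(x_ref, o_ref):
  o_ref[...] = x_ref[...] + 1.0

x = jnp.arange(256, dtype=jnp.float32)
assert (add_one(x) == x + 1.0).all()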

jax/BUILD

@@ -19,6 +19,7 @@ load("@rules_python//python:defs.bzl", "py_library")
 load(
     "//jaxlib:jax.bzl",
     "if_building_jaxlib",
+    "if_building_mosaic_gpu",
     "jax_extend_internal_users",
     "jax_extra_deps",
     "jax_internal_export_back_compat_test_util_visibility",
@@ -669,7 +670,7 @@ pytype_strict_library(
     deps = [
         ":pallas",
         "//jax/_src/pallas/triton",
-    ],
+    ] + if_building_mosaic_gpu(["//third_party/py/jax/_src/pallas/mosaic_gpu"]),
 )

 # This target only supports sm_90 GPUs.

jax/_src/pallas/mosaic_gpu/BUILD

@@ -0,0 +1,60 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Package for Mosaic-specific Pallas extensions

load("//jaxlib:jax.bzl", "pytype_strict_library")
load("@rules_python//python:defs.bzl", "py_library")

package(
    default_applicable_licenses = [],
    default_visibility = [
        "//third_party/py/jax:internal",
    ],
)

py_library(
    name = "mosaic_gpu",
    srcs = ["__init__.py"],
    deps = [
        ":pallas_call_registration",
    ],
)

pytype_strict_library(
    name = "pallas_call_registration",
    srcs = ["pallas_call_registration.py"],
    deps = [
        ":lowering",
        "//third_party/py/jax",
        "//third_party/py/jax:mlir",
        "//third_party/py/jax:mosaic_gpu",
        "//third_party/py/jax/_src/pallas",
    ],
)

pytype_strict_library(
    name = "lowering",
    srcs = ["lowering.py"],
    deps = [
        "//third_party/py/jax",
        "//third_party/py/jax:core",
        "//third_party/py/jax:mlir",
        "//third_party/py/jax:mosaic_gpu",
        "//third_party/py/jax:util",
        "//third_party/py/jax/_src/lib",
        "//third_party/py/jax/_src/pallas",
        "//third_party/py/numpy",
    ],
)

jax/_src/pallas/mosaic_gpu/__init__.py

@@ -0,0 +1,13 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

jax/_src/pallas/mosaic_gpu/lowering.py

@@ -0,0 +1,297 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for lowering JAX primitives to Mosaic GPU."""
from __future__ import annotations
from collections.abc import Sequence
import dataclasses
import math
from typing import Any, cast
import jax
from jax._src import core as jax_core
from jax._src import util
from jax._src.interpreters import mlir
from jax._src.lax import lax
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith as arith_dialect
from jax._src.lib.mlir.dialects import memref as memref_dialect
from jax._src.lib.mlir.dialects import nvgpu as nvgpu_dialect
from jax._src.pallas import core as pl_core
from jax._src.state import primitives as sp
from jax.experimental.mosaic import gpu as mosaic_gpu
from jax.experimental.mosaic.gpu import dsl as mgpu
import numpy as np
# TODO(slebedev): Enable type checking.
# mypy: ignore-errors
# pytype: skip-file
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip


@dataclasses.dataclass
class ModuleContext:
  name: str
  grid_mapping: pl_core.GridMapping


@dataclasses.dataclass
class LoweringRuleContext:
  context: ModuleContext
  avals_in: Sequence[jax_core.ShapedArray]
  avals_out: Sequence[jax_core.ShapedArray]
  block_shapes: list[tuple[int | pl_core.Mapped, ...]] | None

  replace = dataclasses.replace


@dataclasses.dataclass
class LoweringResult:
  module: ir.Module
  grid: tuple[int, ...]
  gmem_scratch_bytes: int
  out_structs: tuple[jax.ShapeDtypeStruct, ...]


@dataclasses.dataclass
class BlockInfo:
  full_shape_dtype: jax.ShapeDtypeStruct
  start_indices: Sequence[Any]
  block_shape: tuple[int, ...]


class LoweringError(Exception):
  pass


def lower_jaxpr_to_module(
    grid_mapping: pl_core.GridMapping,
    in_structs: tuple[jax.ShapeDtypeStruct, ...],
    out_structs: tuple[jax.ShapeDtypeStruct, ...],
    jaxpr: jax_core.Jaxpr,
    name: str,
) -> LoweringResult:
  assert len(jaxpr.outvars) == 0
  assert not grid_mapping.mapped_dims
  grid = grid_mapping.grid
  if len(grid) < 3:
    grid += (1,) * (3 - len(grid))
  block = (128,) + (1,) * (len(grid) - 1)

  def body(ctx: mosaic_gpu.LaunchContext, *buffers):
    *buffers_gmem, buffers_smem = buffers
    assert len(buffers_gmem) == len(buffers_smem)
    in_buffers_gmem = buffers_gmem[: len(in_structs)]
    in_buffers_smem = buffers_smem[: len(in_structs)]
    out_buffers_gmem = buffers_gmem[len(in_structs) :]
    out_buffers_smem = buffers_smem[len(in_structs) :]

    # arrival_count= determines the expected number of arrivals for each
    # barrier in the array. It is not accidental that we do just a single
    # mbarrier_arrive_expect_tx below.
    # TODO(slebedev): Consider enforcing this in the mgpu.BarrierArray.
    [barrier] = mgpu.BarrierArray(1, arrival_count=1)

    with mgpu.once():
      nvgpu_dialect.mbarrier_arrive_expect_tx(
          barrier.barrier_array.value,
          _index(
              sum(math.prod(s.shape) * s.dtype.itemsize for s in in_structs)
          ),
          barrier.offset,
      )

    for b_gmem, b_smem in zip(in_buffers_gmem, in_buffers_smem):
      # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls.
      ctx.async_copy(
          src_ref=b_gmem,
          dst_ref=b_smem,
          barrier=barrier,
          swizzle=None,
          arrive=False,
          uniform=False,
      )

    barrier.wait()
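
    # All input blocks are now resident in SMEM, so the kernel body can run
    # entirely on SMEM references.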
    module_ctx = ModuleContext(name, grid_mapping)
    _ = lower_jaxpr_to_mosaic_gpu(module_ctx, jaxpr, None, *buffers_smem)

    for b_gmem, b_smem in zip(out_buffers_gmem, out_buffers_smem):
      # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls.
      ctx.async_copy(src_ref=b_smem, dst_ref=b_gmem, swizzle=None)

    ctx.await_async_copy(0)

  module, out_structs, gmem_scratch_bytes, _ = mosaic_gpu._lower_as_gpu_kernel(
      body,
      grid,
      block,
      in_shape=in_structs,
      out_shape=out_structs,
      smem_scratch_shape=in_structs + out_structs,
  )

  return LoweringResult(module, grid, gmem_scratch_bytes, out_structs)


mosaic_lowering_rules = {}


def register_lowering_rule(primitive: jax_core.Primitive):
  def deco(fn):
    mosaic_lowering_rules[primitive] = fn
    return fn

  return deco
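
# Hypothetical example (not part of this commit): a new primitive is wired in
# by registering a rule that maps MLIR values to MLIR values. A multiplication
# rule would mirror _add_lowering_rule below and rely on the arithmetic
# operators that mgpu.FragmentedArray provides:
#
#   @register_lowering_rule(lax.mul_p)
#   def _mul_lowering_rule(ctx: LoweringRuleContext, x, y):
#     x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
#     return x * y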


def lower_jaxpr_to_mosaic_gpu(
    ctx: ModuleContext,
    jaxpr: jax_core.Jaxpr,
    block_infos: Sequence[BlockInfo | None] | None,
    *args,
) -> Sequence[ir.Value]:
  env = {}
  block_info_env = {}

  def read_env(atom: jax_core.Atom):
    return atom.val if isinstance(atom, jax_core.Literal) else env[atom]

  def read_block_info_env(atom: jax_core.Atom):
    if isinstance(atom, jax_core.Literal):
      return None
    return block_info_env.get(atom, None)

  def write_env(var: jax_core.Var, val):
    env[var] = val

  if block_infos is None:
    block_infos = [None] * len(jaxpr.invars)
  for invar, block_info in zip(jaxpr.invars, block_infos):
    block_info_env[invar] = block_info

  map(write_env, jaxpr.invars, args)

  for eqn in jaxpr.eqns:
    invals = map(read_env, eqn.invars)
    if eqn.primitive not in mosaic_lowering_rules:
      raise NotImplementedError(
          "Unimplemented primitive in Pallas Mosaic GPU lowering: "
          f"{eqn.primitive.name}. "
          "Please file an issue on https://github.com/google/jax/issues."
      )
    rule = mosaic_lowering_rules[eqn.primitive]
    rule_ctx = LoweringRuleContext(
        ctx,
        avals_in=[cast(jax_core.ShapedArray, v.aval) for v in eqn.invars],
        avals_out=[cast(jax_core.ShapedArray, v.aval) for v in eqn.outvars],
        block_shapes=map(read_block_info_env, eqn.invars),
    )
    try:
      outvals = rule(rule_ctx, *invals, **eqn.params)
    except LoweringError:
      raise  # We only add the extra info to the innermost exception.
    except Exception as e:
      inval_types = map(lambda t: getattr(t, "type", None), invals)
      raise LoweringError(
          f"Exception while lowering eqn:\n  {eqn}\nWith context:\n"
          f"  {rule_ctx}\nWith inval types={inval_types}\nIn jaxpr:\n{jaxpr}"
      ) from e
    if eqn.primitive.multiple_results:
      map(write_env, eqn.outvars, outvals)
    else:
      write_env(eqn.outvars[0], outvals)
  return map(read_env, jaxpr.outvars)


@register_lowering_rule(sp.get_p)
def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *indexers, tree):
  del tree  # Unused.
  if indexers:
    raise NotImplementedError("No support for indexers yet")
  return mgpu.FragmentedArray.load_strided(x_smem)


@register_lowering_rule(sp.swap_p)
def _swap_lowering_rule(
    ctx: LoweringRuleContext, x_smem, value, *indexers, tree
):
  del tree  # Unused.
  if indexers:
    raise NotImplementedError("No support for indexers yet")
  old_value = mgpu.FragmentedArray.load_strided(x_smem)
  value.store_untiled(x_smem)
  return old_value


@register_lowering_rule(lax.add_p)
def _add_lowering_rule(ctx: LoweringRuleContext, x, y):
  x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
  return x + y
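
# Worked illustration (hypothetical shapes, not part of this commit): when
# lowering x + 1.0 with x: f32[256], the Python float reaches
# _add_lowering_rule as a scalar. _bcast below first materializes it as an
# arith.constant via _ir_constant and then splats it to the output shape with
# _bcast_to, so the add itself is FragmentedArray.__add__ on two f32[256]
# fragmented arrays.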


def _bcast_to(a: Any, shape: tuple[int, ...]) -> ir.Value:
  if not isinstance(a, mgpu.FragmentedArray):
    if not shape:
      return a
    layout = mgpu.WGStridedFragLayout.from_memref_type(
        memref_dialect.MemRefType.get(shape, a.type)
    )
    return mgpu.FragmentedArray.splat(a, shape, layout)
  else:
    if a.shape == shape:
      return a
    raise NotImplementedError


def _bcast(
    x: ir.Value,
    y: ir.Value,
    x_aval: jax_core.ShapedArray,
    y_aval: jax_core.ShapedArray,
    out_aval: jax_core.ShapedArray,
) -> ir.Value:
  if isinstance(x, (np.ndarray, np.number, int, float)):
    x_dtype = x_aval.dtype
    if x_aval.weak_type:
      x_dtype = y_aval.dtype
    x = _ir_constant(x, mlir.dtype_to_ir_type(x_dtype))
  if isinstance(y, (np.ndarray, np.number, int, float)):
    y_dtype = y_aval.dtype
    if y_aval.weak_type:
      y_dtype = x_aval.dtype
    y = _ir_constant(y, mlir.dtype_to_ir_type(y_dtype))
  if x_aval.shape != out_aval.shape:
    x = _bcast_to(x, out_aval.shape)
  if y_aval.shape != out_aval.shape:
    y = _bcast_to(y, out_aval.shape)
  return x, y


def _ir_constant(v: object, t: ir.Type) -> ir.Value:
  if isinstance(v, (np.number, np.ndarray, int, float)):
    if isinstance(t, ir.IntegerType):
      v = int(v)
    else:
      assert isinstance(t, ir.FloatType)
      v = float(v)
    return arith_dialect.constant(t, v)
  raise NotImplementedError(f"Unsupported constant: {v!r}")


def _index(i: int) -> ir.Value:
  return arith_dialect.constant(ir.IndexType.get(), i)

jax/_src/pallas/mosaic_gpu/pallas_call_registration.py

@@ -0,0 +1,90 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module registering a lowering rule for pallas_call on GPU."""
from __future__ import annotations
from typing import Any
import jax
from jax import core as jax_core
from jax._src.interpreters import mlir
from jax._src.pallas import core as pallas_core
from jax._src.pallas.mosaic_gpu import lowering
from jax._src.pallas.pallas_call import pallas_call_p
from jax.experimental.mosaic import gpu as mosaic_gpu


def pallas_call_lowering(
    ctx: mlir.LoweringRuleContext,
    *args,
    jaxpr: jax_core.Jaxpr,
    name: str,
    in_shapes: tuple[jax.ShapeDtypeStruct, ...],
    out_shapes: tuple[jax.ShapeDtypeStruct, ...],
    which_linear: tuple[bool, ...],
    interpret: bool,
    debug: bool,
    input_output_aliases: tuple[tuple[int, int], ...],
    grid_mapping: pallas_core.GridMapping,
    compiler_params: dict[str, Any],
):
  if interpret:
    return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
        ctx,
        *args,
        jaxpr=jaxpr,
        name=name,
        out_shapes=out_shapes,
        in_shapes=in_shapes,
        which_linear=which_linear,
        interpret=interpret,
        debug=debug,
        input_output_aliases=input_output_aliases,
        grid_mapping=grid_mapping,
        compiler_params=compiler_params,
    )

  if grid_mapping.num_dynamic_grid_bounds:
    raise NotImplementedError(
        "dynamic grid bounds not supported in the Mosaic GPU backend"
    )
  if input_output_aliases:
    raise NotImplementedError(
        "input_output_aliases not supported in the Mosaic GPU backend"
    )
  if debug:
    print(jaxpr)
    print(grid_mapping)

  lowering_result = lowering.lower_jaxpr_to_module(
      grid_mapping,
      in_shapes,
      out_shapes,
      jaxpr,
      name,
  )
  if debug:
    print(lowering_result.module.operation)

  return mosaic_gpu._mosaic_gpu_lowering_rule(
      ctx,
      *args,
      module=lowering_result.module,
      gmem_scratch_bytes=lowering_result.gmem_scratch_bytes,
      out_types=lowering_result.out_structs,
  )
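
When debugging a kernel, `debug=True` on `pl.pallas_call` flows through this registration and prints each stage. A sketch, where `kernel` and `x` are hypothetical placeholders:

# Hypothetical usage: with debug=True, pallas_call_lowering above prints the
# jaxpr, the grid mapping, and the lowered Mosaic GPU module.
f = pl.pallas_call(
    kernel, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), debug=True
)
f(x)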

jax/_src/pallas/pallas_call.py

@@ -15,17 +15,16 @@
 """Module for calling pallas functions from JAX."""
 from __future__ import annotations

-import itertools
-from functools import partial
-from functools import reduce
-from typing import Any, Callable
+from collections.abc import Sequence
+import itertools
+from functools import partial, reduce
+from typing import Any, Callable

 import jax
 from jax import api_util
 from jax import tree_util
 from jax import lax
 from jax._src import config
 from jax._src import state
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
@@ -532,6 +531,16 @@ def _extract_function_name(f: Callable, name: str | None) -> str:
   return name


+_PALLAS_USE_MOSAIC_GPU = config.DEFINE_bool(
+    "jax_pallas_use_mosaic_gpu",
+    default=config.bool_env("JAX_PALLAS_USE_MOSAIC_GPU", False),
+    help=(
+        "If True, lower Pallas kernels to the experimental Mosaic GPU"
+        " dialect instead of Triton IR."
+    ),
+)
+
+
 def _unsupported_lowering_error(platform: str) -> Exception:
   return ValueError(
       f"Cannot lower pallas_call on platform: {platform}. To use Pallas on GPU,"
@@ -560,6 +569,9 @@ def _pallas_call_lowering(
       raise ValueError("Only interpret mode is supported on CPU backend.")
     elif platform == "cuda" or platform == "rocm":
       try:
-        from jax._src.pallas.triton import pallas_call_registration  # type: ignore
+        if _PALLAS_USE_MOSAIC_GPU.value:
+          from jax._src.pallas.mosaic_gpu import pallas_call_registration  # type: ignore
+        else:
+          from jax._src.pallas.triton import pallas_call_registration  # type: ignore
       except ImportError:
         pass
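
For completeness, a minimal sketch of the two equivalent ways to flip the new flag (assuming a jaxlib build with Mosaic GPU available):

# The environment variable must be set before jax is imported, since
# config.bool_env reads it to compute the flag's default value.
import os
os.environ["JAX_PALLAS_USE_MOSAIC_GPU"] = "1"

import jax

# Equivalent, after import:
jax.config.update("jax_pallas_use_mosaic_gpu", True)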

jax/experimental/mosaic/gpu/__init__.py

@@ -530,7 +530,7 @@ def _launch(
     gpu.terminator()


-def as_gpu_kernel(
+def _lower_as_gpu_kernel(
     body,
     grid: tuple[int, ...],
     block: tuple[int, ...],
@@ -609,6 +609,31 @@ def as_gpu_kernel(
   main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
   module.operation.verify()
+  dump_low_level(module)
+
+  pass_manager = _get_mosaic_gpu_pipeline("fatbin")
+  if mosaic_gpu_print_after_all.value:
+    pass_manager.enable_ir_printing()
+  pass_manager.run(module.operation)
+
+  return module, out_shape, gmem_scratch_bytes, unwrap_output_tuple
+
+
+def as_gpu_kernel(
+    body,
+    grid: tuple[int, ...],
+    block: tuple[int, ...],
+    in_shape,
+    out_shape,
+    smem_scratch_shape,
+    prof_spec: profiler.ProfilerSpec | None = None,
+):
+  module, out_shape, gmem_scratch_bytes, unwrap_output_tuple = (
+      _lower_as_gpu_kernel(
+          body, grid, block, in_shape, out_shape, smem_scratch_shape, prof_spec
+      )
+  )
+
   expected_arg_treedef = jax.tree.structure(in_shape)

   def _check_args(args):
     arg_treedef = jax.tree.structure(args)
@@ -618,13 +643,6 @@ def as_gpu_kernel(
           f" {arg_treedef}"
       )

-  dump_low_level(module)
-
-  pass_manager = _get_mosaic_gpu_pipeline("fatbin")
-  if mosaic_gpu_print_after_all.value:
-    pass_manager.enable_ir_printing()
-  pass_manager.run(module.operation)
-
   def bind(*args):
     return mosaic_gpu_p.bind(
         *args,

jax/experimental/mosaic/gpu/utils.py

@@ -14,6 +14,7 @@
 # ==============================================================================
 """Utilities for code generator."""

+from collections.abc import Iterator
 import contextlib
 import dataclasses
 from typing import Any, Literal, Sequence
@@ -496,6 +497,7 @@ class BarrierArray:
           f" num_barriers={num_barriers}>"
       )
+    self.num_barriers = num_barriers
     self.value = nvgpu.mbarrier_create(barrier_group_ty)
     index = ir.IndexType.get()
@@ -508,6 +510,10 @@ class BarrierArray:
     for i in range(num_barriers):
       nvgpu.mbarrier_init(self.value, c(arrival_count, index), c(i, index))

+  def __iter__(self) -> Iterator["Barrier"]:
+    for offset in range(self.num_barriers):
+      yield self[offset]
+
   def __getitem__(self, offset: ir.Value | int):
     if isinstance(offset, int):
       offset = c(offset, ir.IndexType.get())
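
The new `__iter__` exists so that a one-element barrier array can be destructured, which is exactly how the Pallas lowering above consumes it:

# Usage sketch (valid inside a kernel body, as in lowering.py above):
# unpacking the single Barrier from a one-element BarrierArray.
[barrier] = mgpu.BarrierArray(1, arrival_count=1)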

tests/pallas/BUILD

@@ -254,6 +254,42 @@ jax_test(
     ] + py_deps("absl/testing") + py_deps("absl/flags") + py_deps("numpy") + py_deps("hypothesis"),
 )

+jax_test(
+    name = "mosaic_gpu_test",
+    srcs = [
+        "mosaic_gpu_test.py",
+    ],
+    config_tags_overrides = {
+        # TODO(slebedev): Switch to False once Mosaic GPU is unconditionally enabled.
+        "gpu_h100_x32": {
+            "ondemand": True,  # Include in presubmit.
+        },
+    },
+    disable_backends = [
+        "cpu",
+        "tpu",
+    ],
+    disable_configs = [
+        "gpu",
+        "gpu_x32",
+        "gpu_a100",
+        "gpu_a100_x32",
+        "gpu_p100",
+        "gpu_p100_x32",
+        "gpu_h100",
+    ],
+    enable_configs = [
+        "gpu_h100_x32",
+    ],
+    env = {
+        "JAX_PALLAS_USE_MOSAIC_GPU": "1",
+    },
+    deps = [
+        "//jax:pallas",
+        "//jax:pallas_gpu",  # build_cleaner: keep
+    ] + py_deps("absl/testing") + py_deps("numpy"),
+)
+
 jax_test(
     name = "export_back_compat_pallas_test",
     srcs = ["export_back_compat_pallas_test.py"],

tests/pallas/mosaic_gpu_test.py

@@ -0,0 +1,53 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools

from absl.testing import absltest

import jax
from jax._src import test_util as jtu
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np

jax.config.parse_flags_with_absl()


class PallasTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    if not jtu.is_cuda_compute_capability_at_least("9.0"):
      self.skipTest("Only works on a GPU with capability >= sm90")


class PallasCallTest(PallasTest):

  def test_add_one(self):
    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
    )
    def add_one(x_ref, o_ref):
      print(">>>", x_ref, o_ref)
      o_ref[...] = x_ref[...] + 1.0

    x = jnp.arange(256).astype(jnp.float32)
    np.testing.assert_array_equal(add_one(x), x + 1.0)


if __name__ == "__main__":
  absltest.main()