Initial commit for Mosaic GPU

Moving this to JAX to make it easier to explore Pallas integration.

PiperOrigin-RevId: 625982382
Adam Paszke 2024-04-18 04:03:03 -07:00 committed by jax authors
parent c4dea624cc
commit 8e3f5b1018
26 changed files with 3755 additions and 11 deletions

View File

@ -111,6 +111,10 @@ build:cuda_clang --copt=-Wno-gnu-offsetof-extensions
# Disable clang extension that rejects unknown arguments.
build:cuda_clang --copt=-Qunused-arguments
build:mosaic_gpu --@llvm-project//mlir:enable_cuda=true
build:mosaic_gpu --copt=-DLLVM_HAS_NVPTX_TARGET=1
build:mosaic_gpu --//jax:build_mosaic_gpu=true
build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
build:rocm --@xla//xla/python:enable_gpu=true

View File

@ -140,7 +140,7 @@ jobs:
PY_COLORS: 1
run: |
pytest -n auto --tb=short docs
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas
documentation_render:

View File

@ -263,7 +263,7 @@ def write_bazelrc(*, python_bin_path, remote_build,
rocm_amdgpu_targets, bazel_options, target_cpu_features,
wheel_cpu, enable_mkl_dnn, use_clang, clang_path,
clang_major_version, enable_cuda, enable_nccl, enable_rocm,
build_gpu_plugin):
build_gpu_plugin, enable_mosaic_gpu):
tf_cuda_paths = []
with open("../.jax_configure.bazelrc", "w") as f:
@ -337,6 +337,8 @@ def write_bazelrc(*, python_bin_path, remote_build,
if use_clang:
f.write("build --config=nvcc_clang\n")
f.write(f"build --action_env=CLANG_CUDA_COMPILER_PATH={clang_path}\n")
if enable_mosaic_gpu:
f.write("build --config=mosaic_gpu")
if enable_rocm:
f.write("build --config=rocm\n")
if not enable_nccl:
@ -541,6 +543,10 @@ def main():
"--editable",
action="store_true",
help="Create an 'editable' jaxlib build instead of a wheel.")
add_boolean_argument(
parser,
"enable_mosaic_gpu",
help_str="Should we build with Mosaic GPU? VERY EXPERIMENTAL.")
add_boolean_argument(
parser,
"configure_only",
@ -652,6 +658,7 @@ def main():
enable_nccl=args.enable_nccl,
enable_rocm=args.enable_rocm,
build_gpu_plugin=args.build_gpu_plugin,
enable_mosaic_gpu=args.enable_mosaic_gpu,
)
if args.configure_only:

View File

@ -70,6 +70,19 @@ config_setting(
},
)
# If this flag is true, jaxlib will be built with Mosaic GPU. VERY EXPERIMENTAL.
bool_flag(
name = "build_mosaic_gpu",
build_setting_default = False,
)
config_setting(
name = "enable_mosaic_gpu",
flag_values = {
":build_mosaic_gpu": "True",
},
)
exports_files([
"LICENSE",
"version.py",
@ -116,6 +129,14 @@ package_group(
] + pallas_tpu_internal_users,
)
package_group(
name = "mosaic_gpu_users",
packages = [
"//...",
"//learning/brain/research/jax",
],
)
# JAX-private test utilities.
py_library(
# This build target is required in order to use private test utilities in jax._src.test_util,
@ -647,6 +668,37 @@ pytype_strict_library(
],
)
# This target only supports sm_90 GPUs.
py_library(
name = "mosaic_gpu",
srcs = glob(["experimental/mosaic/gpu/*.py"]),
visibility = [
":mosaic_gpu_users",
],
deps = [
":config",
":jax",
":mlir",
"//jax/_src/lib",
"//third_party/py/absl/flags",
"//jaxlib/mlir:arithmetic_dialect",
"//jaxlib/mlir:builtin_dialect",
"//jaxlib/mlir:execution_engine",
"//jaxlib/mlir:func_dialect",
"//jaxlib/mlir:gpu_dialect",
"//jaxlib/mlir:ir",
"//jaxlib/mlir:llvm_dialect",
"//jaxlib/mlir:math_dialect",
"//jaxlib/mlir:memref_dialect",
"//jaxlib/mlir:nvgpu_dialect",
"//jaxlib/mlir:nvvm_dialect",
"//jaxlib/mlir:pass_manager",
"//jaxlib/mlir:scf_dialect",
"//jaxlib/mlir:vector_dialect",
"//third_party/py/numpy",
],
)
pytype_strict_library(
name = "partial_eval",
srcs = ["_src/interpreters/partial_eval.py"],

View File

@ -15,6 +15,7 @@
load(
"//jaxlib:jax.bzl",
"if_building_jaxlib",
"if_building_mosaic_gpu",
"jax_visibility",
"py_library_providing_imports_info",
"pytype_strict_library",
@ -31,6 +32,7 @@ py_library_providing_imports_info(
"__init__.py",
"mlir/__init__.py",
"mlir/dialects/__init__.py",
"mosaic_gpu.py",
"triton.py",
],
lib_rule = pytype_strict_library,
@ -58,5 +60,5 @@ py_library_providing_imports_info(
"//jaxlib/mlir:stablehlo_dialect",
"//jaxlib/mlir:vector_dialect",
# xla_client
]),
]) + if_building_mosaic_gpu(["//jaxlib/mosaic/gpu:mosaic_gpu"]),
)

View File

@ -23,6 +23,13 @@ import jaxlib.mlir.dialects.func as func
import jaxlib.mlir.dialects.scf as scf
import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor
import jaxlib.mlir.dialects.vector as vector
try:
import jaxlib.mlir.dialects.gpu as gpu # type: ignore
import jaxlib.mlir.dialects.nvgpu as nvgpu # type: ignore
import jaxlib.mlir.dialects.nvvm as nvvm # type: ignore
import jaxlib.mlir.dialects.llvm as llvm # type: ignore
except ImportError:
pass
from jax._src import lib

View File

@ -0,0 +1,23 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
try:
from jaxlib.mosaic.gpu import _mosaic_gpu_ext # pytype: disable=import-error
except ImportError as e:
raise ModuleNotFoundError(
"Cannot import the Mosaic GPU bindings. You may need to build jaxlib from"
" source."
) from e

View File

@ -0,0 +1,595 @@
# Copyright 2024 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import contextlib
import ctypes
import dataclasses
import os
import pathlib
import subprocess
import tempfile
import time
from typing import Any, Sequence
import jax
from jax._src import config
from jax._src.interpreters import mlir
from jax._src.lib import xla_client
from jax._src.lib import mosaic_gpu as mosaic_gpu_lib
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import builtin
from jaxlib.mlir.dialects import func
from jaxlib.mlir.dialects import gpu
from jaxlib.mlir.dialects import llvm
from jaxlib.mlir.dialects import memref
from jaxlib.mlir.dialects import nvgpu
from jaxlib.mlir.dialects import nvvm
from jaxlib.mlir.execution_engine import ExecutionEngine
from jaxlib.mlir.passmanager import PassManager
import numpy as np
from . import dsl as mgpu
from . import profiler
from . import utils
# mypy: ignore-errors
# MLIR can't find libdevice unless we point it to the CUDA path
# TODO(apaszke): Unify with jax._src.lib.cuda_path
CUDA_ROOT = "/usr/local/cuda"
if os.environ.get("CUDA_ROOT") is None:
os.environ["CUDA_ROOT"] = CUDA_ROOT
else:
CUDA_ROOT = os.environ["CUDA_ROOT"]
PTXAS_PATH = os.path.join(CUDA_ROOT, "bin/ptxas")
NVDISASM_PATH = os.path.join(CUDA_ROOT, "bin/nvdisasm")
c = mgpu.c # This is too common to fully qualify.
xla_client.register_custom_call_target(
"mosaic_gpu",
mosaic_gpu_lib._mosaic_gpu_ext._custom_call_capsule(),
platform="CUDA",
)
mosaic_gpu_dump_ptx = config.define_bool_state(
name="mosaic_gpu_dump_ptx",
default=config.bool_env("MOSAIC_GPU_DUMP_PTX", False),
help="If set, prints the kernel PTX",
)
mosaic_gpu_dump_ptxas = config.define_bool_state(
name="mosaic_gpu_dump_ptxas",
default=config.bool_env("MOSAIC_GPU_DUMP_PTXAS", False),
help="If set, prints the ptxas verbose output",
)
mosaic_gpu_dump_sass = config.define_bool_state(
name="mosaic_gpu_dump_sass",
default=config.bool_env("MOSAIC_GPU_DUMP_SASS", False),
help="If set, prints the kernel SASS",
)
mosaic_gpu_print_after_all = config.define_bool_state(
name='mosaic_gpu_print_after_all',
default=config.bool_env('MOSAIC_GPU_PRINT_AFTER_ALL', False),
help="If set, prints the kernel module after every pass",
)
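# Illustrative note (a sketch, not guaranteed by this module): each dump can be
# requested through the corresponding MOSAIC_GPU_* environment variable or,
# assuming the flag names above are registered with the standard JAX config
# mechanism, programmatically:
#
#   import jax
#   jax.config.update("mosaic_gpu_dump_ptx", True)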
mosaic_gpu_p = jax.core.Primitive("mosaic_gpu_p")
mosaic_gpu_p.multiple_results = True
@mosaic_gpu_p.def_abstract_eval
def _mosaic_gpu_abstract_eval(*_, module, out_types):
return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types]
def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types):
runtime_path = (
pathlib.Path(mosaic_gpu_lib._mosaic_gpu_ext.__file__).parent
/ "libmlir_cuda_runtime.so"
)
shared_libs = [str(runtime_path)] if runtime_path.exists() else []
engine = ExecutionEngine(
module, opt_level=3, shared_libs=shared_libs, enable_object_dump=False
)
ctx.module_context.add_keepalive(engine)
func_ptr = engine.lookup("main")
ptr_bytes = ctypes.cast(func_ptr, ctypes.c_void_p).value.to_bytes(
8, byteorder="little"
) # pytype: disable=attribute-error
op = mlir.custom_call(
"mosaic_gpu",
result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
operands=args,
backend_config=ptr_bytes,
)
return op.results
mlir.register_lowering(mosaic_gpu_p, _mosaic_gpu_lowering_rule, "cuda")
@dataclasses.dataclass(frozen=True)
class MemRefTransform:
def apply(self, ref: ir.Value) -> ir.Value:
raise NotImplementedError("Subclasses should override this method")
def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]:
raise NotImplementedError("Subclasses should override this method")
def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]:
raise NotImplementedError("Subclasses should override this method")
@dataclasses.dataclass(frozen=True)
class TileTransform(MemRefTransform):
"""Tiles a suffix of memref dimensions.
For example, given a memref of shape (5, 128, 128) and a tiling of (64, 32),
the shape of the result will be (5, 2, 4, 64, 32). The shape always ends with
the tile shape, and the size of tiled dimensions is divided by the tile size.
This is especially useful for swizzled WGMMA, which expects tiled layouts in
shared memory.
"""
tiling: tuple[int, ...]
def apply(self, ref: ir.Value) -> ir.Value:
untiled_rank = ir.MemRefType(ref.type).rank
tiling_rank = len(self.tiling)
tiled_rank = untiled_rank + tiling_rank
for t, d in zip(self.tiling[::-1], range(untiled_rank)[::-1]):
ref = mgpu.memref_unfold(ref, d, (None, t))
permutation = (
*range(untiled_rank - tiling_rank),
*range(untiled_rank - tiling_rank, tiled_rank, 2),
*range(untiled_rank - tiling_rank + 1, tiled_rank, 2),
)
return mgpu.memref_transpose(ref, permutation)
def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]:
index = ir.IndexType.get()
tiling_rank = len(self.tiling)
return (
*idx[:-tiling_rank],
*(
arith.divui(i, c(t, index))
for i, t in zip(idx[-tiling_rank:], self.tiling)
),
*([c(0, index)] * tiling_rank),
)
def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]:
# Note that this also checks that tiled dims are not squeezed. Their slice
# size would be 1 if so.
tiling_rank = len(self.tiling)
for size, tile_size in zip(shape[-tiling_rank:], self.tiling):
if size % tile_size:
raise ValueError(
f"Expected GMEM slice shape {shape} suffix to be a multiple"
f" of tiling {self.tiling}"
)
return (
*shape[:-tiling_rank],
*(s // t for s, t in zip(shape[-tiling_rank:], self.tiling)),
*self.tiling,
)
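# A minimal NumPy sketch of the tiling above (illustrative only, for a host
# array rather than a memref): with shape (5, 128, 128) and tiling (64, 32),
# TileTransform.apply is equivalent to
#
#   x.reshape(5, 2, 64, 4, 32).transpose(0, 1, 3, 2, 4)  # -> (5, 2, 4, 64, 32)
#
# i.e. every tiled dimension is split into (outer, tile) and the tile axes are
# moved to the end.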
@dataclasses.dataclass(frozen=True)
class TransposeTransform(MemRefTransform):
"""Transposes memref dimensions."""
permutation: tuple[int, ...]
def __post_init__(self):
if len(self.permutation) != len(set(self.permutation)):
raise ValueError("Permutation must be a permutation")
def apply(self, ref: ir.Value) -> ir.Value:
return mgpu.memref_transpose(ref, self.permutation)
def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]:
return tuple(idx[p] for p in self.permutation)
def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]:
return tuple(shape[p] for p in self.permutation)
OnDeviceProfiler = profiler.OnDeviceProfiler
@dataclasses.dataclass()
class LaunchContext:
launch_op: gpu.LaunchOp
profiler: OnDeviceProfiler | None = None
tma_descriptors: dict[
tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...]],
ir.Value,
] = dataclasses.field(default_factory=dict, init=False)
@contextlib.contextmanager
def named_region(self, *args, **kwargs):
if self.profiler is not None:
with self.profiler.record(*args, **kwargs):
yield
else:
yield
def _get_tma_desc(
self,
ref,
gmem_transform: tuple[MemRefTransform, ...],
transformed_slice_shape: tuple[int, ...],
swizzle: int | None,
):
index = ir.IndexType.get()
ref_ty = ir.MemRefType(ref.type)
tma_desc_key = (ref, transformed_slice_shape, swizzle, gmem_transform)
if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None:
swizzle_str = f"swizzle_{swizzle}b" if swizzle is not None else "none"
default_tensor_map_attrs = dict(
swizzle=swizzle_str, l2promo="none", oob="zero", interleave="none"
)
tensor_map_ty = utils.get_tensormap_descriptor(
tensor=(
f"memref<{'x'.join(map(str, transformed_slice_shape))}x{ref_ty.element_type}, 3>"
),
**default_tensor_map_attrs,
)
with ir.InsertionPoint(self.launch_op):
for t in gmem_transform:
ref = t.apply(ref)
ref_unranked = memref.cast(
ir.UnrankedMemRefType.get(ref_ty.element_type, None), ref
)
tma_desc = nvgpu.tma_create_descriptor(
tensor_map_ty,
ref_unranked,
[c(s, index) for s in transformed_slice_shape],
)
self.tma_descriptors[tma_desc_key] = tma_desc
return tma_desc
def async_copy(
self,
*,
src_ref,
dst_ref,
gmem_slice: Any = (),
gmem_transform: MemRefTransform | tuple[MemRefTransform, ...] = (),
barrier: mgpu.Barrier | None = None,
swizzle: int | None = None,
arrive: bool | None = None,
uniform: bool = True,
):
index = ir.IndexType.get()
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
src_ref_ty = ir.MemRefType(src_ref.type)
dst_ref_ty = ir.MemRefType(dst_ref.type)
element_type = src_ref_ty.element_type
if element_type != dst_ref_ty.element_type:
raise ValueError(
f"Expected same element type, got {element_type} and"
f" {dst_ref_ty.element_type}"
)
if not isinstance(gmem_transform, tuple):
gmem_transform = (gmem_transform,)
if src_ref_ty.memory_space is None and dst_ref_ty.memory_space == smem:
gmem_ref, smem_ref = src_ref, dst_ref
if barrier is None:
raise ValueError("Barriers are required for GMEM -> SMEM copies")
if arrive is None:
arrive = True # Arrive by default
elif src_ref_ty.memory_space == smem and dst_ref_ty.memory_space is None:
gmem_ref, smem_ref = dst_ref, src_ref
if barrier is not None:
raise ValueError("Barriers are unsupported for SMEM -> GMEM copies")
if arrive is not None:
raise ValueError("arrive is unsupported for SMEM -> GMEM copies")
else:
raise ValueError("Only SMEM <-> GMEM copies supported")
# TODO(apaszke): This is a very approximate check. Improve it!
expected_name = "builtin.unrealized_conversion_cast"
if (
gmem_ref.owner is None
or gmem_ref.owner.opview.OPERATION_NAME != expected_name
):
raise ValueError("GMEM reference in async_copy must be a kernel argument")
base_indices, slice_shape, is_squeezed = utils.parse_indices(
gmem_slice, ir.MemRefType(gmem_ref.type).shape
)
dyn_base_indices = tuple(
c(i, index) if not isinstance(i, ir.Value) else i for i in base_indices
)
slice_shape = tuple(slice_shape)
for t in gmem_transform:
dyn_base_indices = t.transform_index(dyn_base_indices)
slice_shape = t.transform_shape(slice_shape)
for dim, squeezed in enumerate(is_squeezed):
if squeezed:
smem_ref = mgpu.memref_unsqueeze(smem_ref, dim)
smem_ref_ty = ir.MemRefType(smem_ref.type)
if slice_shape != tuple(smem_ref_ty.shape):
raise ValueError(
"Expected the SMEM reference to have the same shape as the tiled"
f" slice: {tuple(smem_ref_ty.shape)} != {slice_shape}"
)
tma_desc = self._get_tma_desc(
gmem_ref, gmem_transform, slice_shape, swizzle,
)
# nvgpu TMA instructions expect reversed indices...
rev_dyn_based_indices = reversed(dyn_base_indices)
uniform_ctx = mgpu.once if uniform else contextlib.nullcontext
if gmem_ref is src_ref:
with uniform_ctx():
assert barrier is not None # for pytype
barrier_group = barrier.barrier_array.value
barrier_idx = barrier.offset
if arrive:
slice_bytes = c(
np.prod(slice_shape) * mgpu.bytewidth(element_type), index
)
nvgpu.mbarrier_arrive_expect_tx(
barrier_group, slice_bytes, barrier_idx
)
nvgpu.tma_async_load(
smem_ref, barrier_group, tma_desc, rev_dyn_based_indices, barrier_idx
)
else:
with uniform_ctx():
nvgpu.tma_async_store(smem_ref, tma_desc, rev_dyn_based_indices)
nvvm.cp_async_bulk_commit_group()
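# Illustrative call shape (a sketch; `ctx`, `gmem_ref`, `smem_ref` and `barrier`
# are assumed to come from the surrounding kernel): a GMEM -> SMEM copy that
# tiles the source for WGMMA and arrives on the barrier once the TMA completes:
#
#   ctx.async_copy(
#       src_ref=gmem_ref, dst_ref=smem_ref,
#       gmem_transform=TileTransform((64, 64)),
#       swizzle=128, barrier=barrier,
#   )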
def await_async_copy(
self, allow_groups: int, await_read_only: bool = False
):
nvvm.cp_async_bulk_wait_group(allow_groups, read=await_read_only)
gpu.barrier() # Groups are supposedly tracked per-thread
def _prefetch_tma_descs(self):
with ir.InsertionPoint(self.launch_op.body.blocks[0]):
with mgpu.once():
for desc in self.tma_descriptors.values():
nvgpu.tma_prefetch_descriptor(desc)
@contextlib.contextmanager
def _launch(
token,
grid,
block,
smem_buffers,
profiler_spec: profiler.ProfilerSpec | None = None,
maybe_prof_buffer: ir.Value | None = None,
):
if (profiler_spec is None) != (maybe_prof_buffer is None):
raise ValueError
index = ir.IndexType.get()
i32 = ir.IntegerType.get_signless(32)
i8 = ir.IntegerType.get_signless(8)
grid_vals = [c(i, index) for i in grid]
block_vals = [c(i, index) for i in block]
flat_refs, smem_buffer_tree = jax.tree.flatten(smem_buffers)
smem_ref_bytes = []
for ref_ty in flat_refs:
smem_ref_bytes.append(
np.prod(ref_ty.shape) * np.dtype(ref_ty.dtype).itemsize
)
smem_bytes = sum(smem_ref_bytes)
if profiler_spec is not None:
smem_bytes += profiler_spec.smem_bytes(grid)
launch_op = gpu.LaunchOp(
token.type, [token], *grid_vals, *block_vals,
dynamicSharedMemorySize=c(smem_bytes, i32))
launch_op.body.blocks.append(*([index] * 12)) # Append an empty block
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
with ir.InsertionPoint(launch_op.body.blocks[0]):
dynamic_smem = gpu.dynamic_shared_memory(
ir.MemRefType.get(
(ir.ShapedType.get_dynamic_size(),), i8, memory_space=smem
)
)
smem_refs = []
dynamic_smem_offset = 0
for ref_ty, ref_bytes in zip(flat_refs, smem_ref_bytes):
mlir_dtype = mlir.dtype_to_ir_type(ref_ty.dtype)
tile_smem = memref.view(
ir.MemRefType.get(ref_ty.shape, mlir_dtype, memory_space=smem),
dynamic_smem, c(dynamic_smem_offset, index), [],
)
dynamic_smem_offset += ref_bytes
smem_refs.append(tile_smem)
if profiler_spec:
prof_smem = memref.view(
ir.MemRefType.get(
(profiler_spec.smem_i32_elements(grid=grid),),
i32, memory_space=smem,
),
dynamic_smem, c(dynamic_smem_offset, index), [],
)
prof = profiler.OnDeviceProfiler(
profiler_spec, prof_smem, maybe_prof_buffer
)
else:
prof = None
smem_ref_tree = jax.tree.unflatten(smem_buffer_tree, smem_refs)
yield LaunchContext(launch_op, prof), smem_ref_tree
if prof is not None:
prof.finalize(grid=grid)
gpu.terminator()
def as_gpu_kernel(
body,
grid: tuple[int, ...],
block: tuple[int, ...],
in_shape,
out_shape,
smem_scratch_shape,
prof_spec: profiler.ProfilerSpec | None = None,
):
ptr_ty = ir.Type.parse("!llvm.ptr")
token_ty = ir.Type.parse("!gpu.async.token")
def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType:
return ir.MemRefType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype))
if isinstance(in_shape, list):
in_shape = tuple(in_shape)
elif not isinstance(in_shape, tuple):
in_shape = (in_shape,)
in_ref_tys = [_shape_to_ref_ty(t) for t in in_shape]
unwrap_output_tuple = False
if isinstance(out_shape, list):
out_shape = tuple(out_shape)
elif not isinstance(out_shape, tuple):
out_shape = (out_shape,)
unwrap_output_tuple = True
out_ref_tys = [_shape_to_ref_ty(t) for t in out_shape]
if prof_spec is not None:
out_shape = (*out_shape, prof_spec.jax_buffer_type)
out_ref_tys.append(prof_spec.mlir_buffer_type)
module = ir.Module.create()
with ir.InsertionPoint(module.body):
@func.FuncOp.from_py_func(ptr_ty, ptr_ty)
def main(token_ptr, buffers):
token = builtin.unrealized_conversion_cast([token_ty], [token_ptr])
arg_refs = []
for i, ref_ty in enumerate([*in_ref_tys, *out_ref_tys]):
ptr = llvm.LoadOp(ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i], ptr_ty))
arg_refs.append(utils.ptr_as_memref(ptr, ir.MemRefType(ref_ty)))
in_refs = arg_refs[:len(in_ref_tys)]
out_refs = arg_refs[len(in_ref_tys):]
prof_buffer = out_refs.pop() if prof_spec is not None else None
with _launch(
token, grid, block, smem_scratch_shape, prof_spec, prof_buffer
) as (launch_ctx, smem_refs):
body(launch_ctx, *in_refs, *out_refs, smem_refs)
main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
module.operation.verify()
expected_arg_treedef = jax.tree.structure(in_shape)
def _check_args(args):
arg_treedef = jax.tree.structure(args)
if arg_treedef != expected_arg_treedef:
raise ValueError(
f"Invalid argument structure: expected {expected_arg_treedef}, got"
f" {arg_treedef}"
)
dump_low_level(module)
pass_manager = PassManager.parse(
"builtin.module(gpu-lower-to-nvvm-pipeline{cubin-format=fatbin"
" cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3})"
)
if mosaic_gpu_print_after_all.value:
pass_manager.enable_ir_printing()
pass_manager.run(module.operation)
def bind(*args):
return mosaic_gpu_p.bind(*args, out_types=out_shape, module=module)
if prof_spec is not None:
@jax.jit
def prof_kernel(*args):
_check_args(args)
*results, prof_buffer = bind(*args)
def dump_profile(prof_buffer):
out_file = os.path.join(
os.getenv("TEST_UNDECLARED_OUTPUTS_DIR"),
f"{time.time_ns()}-trace.json",
)
try:
with open(out_file, "x") as f:
prof_spec.dump(prof_buffer, f)
except FileExistsError:
pass # TODO: Retry
jax.debug.callback(dump_profile, prof_buffer)
return results[0] if unwrap_output_tuple else results
return prof_kernel
else:
@jax.jit
def kernel(*args):
_check_args(args)
results = bind(*args)
return results[0] if unwrap_output_tuple else results
return kernel
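# Illustrative usage sketch (the shapes, names, and `jnp` = jax.numpy are
# assumptions, not part of the API above): the body receives the launch
# context, the GMEM input/output refs, and the SMEM scratch refs:
#
#   spec = jax.ShapeDtypeStruct((128, 128), jnp.float32)
#   def body(ctx, src, dst, smem):
#     ...  # issue ctx.async_copy / WGMMA ops here
#   kernel = as_gpu_kernel(
#       body, grid=(1, 1, 1), block=(128, 1, 1),
#       in_shape=spec, out_shape=spec, smem_scratch_shape=spec,
#   )
#   out = kernel(jnp.zeros((128, 128), jnp.float32))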
def dump_low_level(module):
dump_ptx = mosaic_gpu_dump_ptx.value
dump_ptxas = mosaic_gpu_dump_ptxas.value
dump_sass = mosaic_gpu_dump_sass.value
if not any([dump_ptx, dump_ptxas, dump_sass]):
return
module = ir.Module.parse(
module.operation.get_asm(binary=True, enable_debug_info=True)
)
pm = PassManager.parse(
"builtin.module(gpu-lower-to-nvvm-pipeline{cubin-format=isa"
" cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3})"
)
pm.run(module.operation)
for op in module.body:
if op.OPERATION_NAME == "gpu.binary":
objects = ir.ArrayAttr(op.objects)
if len(objects) != 1:
raise NotImplementedError("Expected a single object")
obj = str(objects[0])
start = obj.find('assembly = "') + len('assembly = "')
end = obj.find('"', start)
ptx = obj[start:end]
ptx = ptx.replace("\\09", "\t").replace("\\0A", "\n")[:-3]
if dump_ptx:
print(ptx)
if dump_ptxas or dump_sass:
with tempfile.TemporaryDirectory() as tmp:
ptx_path = os.path.join(tmp, "kernel.ptx")
with open(ptx_path, "w") as f:
f.write(ptx)
elf_path = os.path.join(tmp, 'kernel.o')
v_flag = "-v" if dump_ptxas else ""
ptxas_flags = f"{v_flag} --opt-level 3 --gpu-name sm_90a"
ptxas_out = subprocess.check_output(
f"{PTXAS_PATH} {ptxas_flags} --output-file {elf_path} {ptx_path}",
stderr=subprocess.STDOUT,
shell=True,
)
if dump_ptxas:
print(ptxas_out.decode())
if dump_sass:
sass = subprocess.check_output(
f"{NVDISASM_PATH} -ndf -c {elf_path}",
stderr=subprocess.STDOUT,
shell=True,
)
print(sass.decode())

View File

@ -0,0 +1,47 @@
# Copyright 2024 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from .fragmented_array import (
FragmentedArray,
FragmentedLayout,
WGMMA_LAYOUT,
WGMMA_ROW_LAYOUT,
WGStridedFragLayout,
)
from .utils import (
Barrier,
BarrierArray,
DynamicSlice,
Partition,
Partition1D,
bytewidth,
c,
commit_shared,
debug_print,
ds,
fori,
memref_fold,
memref_slice,
memref_transpose,
memref_unfold,
memref_unsqueeze,
once,
tile_shape,
)
from .wgmma import (
WGMMAAccumulator,
WGMMALayout,
wgmma,
)

View File

@ -0,0 +1,476 @@
# Copyright 2024 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for code generator."""
import dataclasses
import jax
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import gpu
from jaxlib.mlir.dialects import llvm
from jaxlib.mlir.dialects import math as mlir_math
from jaxlib.mlir.dialects import memref
from jaxlib.mlir.dialects import nvvm
from jaxlib.mlir.dialects import vector
import numpy as np
from . import utils
from . import dsl as mgpu
# mypy: ignore-errors
WARPGROUP_SIZE = utils.WARPGROUP_SIZE
c = utils.c
@dataclasses.dataclass(frozen=True)
class WGMMAFragLayout:
"""[m, n] matrix, where m % 64 == 0 == n % 8."""
@dataclasses.dataclass(frozen=True)
class WGMMARowFragLayout:
"""[m] matrix, where m % 64 == 0."""
@dataclasses.dataclass(frozen=True)
class WGStridedFragLayout:
"""Convert the array to 1D and then shard across threads."""
shape: tuple[int, ...]
vec_size: int
def __post_init__(self):
if np.prod(self.shape) % (self.vec_size * WARPGROUP_SIZE) != 0:
raise ValueError((self, WARPGROUP_SIZE))
@classmethod
def from_memref_type(cls, memref_ty: ir.Type):
if not ir.MemRefType.isinstance(memref_ty):
raise TypeError(memref_ty)
memref_type = ir.MemRefType(memref_ty)
bw = mgpu.bytewidth(memref_type.element_type)
assert 8 % bw == 0 and 8 // bw != 0, bw
return cls(shape=memref_type.shape, vec_size=8 // bw)
def thread_vec_idxs(self):
"""The indexes to be used for vector load/store WGStridedFragLayout.
Yields:
The indices of the vector that correspond to the current thread.
"""
cardinality = np.prod(self.shape)
assert cardinality % (WARPGROUP_SIZE * self.vec_size) == 0
reg_num = cardinality // (WARPGROUP_SIZE * self.vec_size)
tidx = gpu.thread_id(gpu.Dimension.x)
off = arith.muli(tidx, c(self.vec_size, tidx.type))
for i in range(reg_num):
yield [arith.addi(off, c(i * WARPGROUP_SIZE * self.vec_size, tidx.type))]
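# Worked example (illustrative): for an f32 memref of shape (256, 128) the
# element bytewidth is 4, so vec_size = 8 // 4 = 2 and each of the 128 threads
# in a warpgroup owns 256 * 128 / (128 * 2) = 128 vectors of 2 elements.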
FragmentedLayout = WGStridedFragLayout | WGMMAFragLayout | WGMMARowFragLayout
WGMMA_LAYOUT = WGMMAFragLayout()
WGMMA_ROW_LAYOUT = WGMMARowFragLayout()
@jax.tree_util.register_pytree_node_class
class FragmentedArray:
registers: np.ndarray # of ir.Value, see checks in init for shapes.
layout: FragmentedLayout
def __init__(self, *, _registers: np.ndarray, _layout: FragmentedLayout):
self.registers = _registers
self.layout = _layout
match self.layout:
# Registers are [m_tiles, n_tiles, 2 rows, 1 cols] in WGMMA layout
# Each element is a vector<2xdtype>
case WGMMAFragLayout():
if self.registers.ndim != 4 or self.registers.shape[2:] != (2, 1):
raise ValueError("Invalid register array shape")
# Registers are [m_tiles, 2 rows] in WGMMA_ROW layout
# Each element is a dtype scalar
case WGMMARowFragLayout():
if self.registers.ndim != 2 or self.registers.shape[-1] != 2:
raise ValueError("Invalid register array shape")
# Registers are flat
case WGStridedFragLayout(shape):
(reg_size,) = ir.VectorType(_registers.flat[0].type).shape
if np.prod(shape) != np.prod(_registers.shape) * WARPGROUP_SIZE * reg_size:
raise ValueError((reg_size, shape, _registers.shape, WARPGROUP_SIZE), _registers.flat[0].type)
case _:
raise NotImplementedError
@classmethod
def load_strided(cls, ref: ir.Value):
if not ir.MemRefType.isinstance(ref.type):
raise TypeError(ref.type)
ref_ty = ir.MemRefType(ref.type)
ref_1d = mgpu.memref_fold(ref, 0, len(ref_ty.shape))
layout = WGStridedFragLayout.from_memref_type(ref_ty)
vec_ty = ir.VectorType.get((layout.vec_size,), ref_ty.element_type)
vecs = [vector.load(vec_ty, ref_1d, vec_idx) for vec_idx in layout.thread_vec_idxs()]
return cls(_registers=np.array(vecs), _layout=layout)
@classmethod
def splat(cls, value, shape, layout):
match layout:
case WGMMARowFragLayout():
if len(shape) != 1:
raise ValueError
if shape[0] % 64:
raise ValueError
reg_shape = (shape[0] // 64, 2)
case WGMMAFragLayout():
if len(shape) != 2:
raise ValueError
if shape[0] % 64 or shape[1] % 8:
raise ValueError
reg_shape = (shape[0] // 64, shape[1] // 8, 2, 1)
value = vector.splat(ir.VectorType.get((2,), value.type), value)
case WGStridedFragLayout(shape=shape, vec_size=vec_size):
elems = np.prod(shape)
reg_shape = (elems // (WARPGROUP_SIZE * vec_size),)
value = vector.splat(ir.VectorType.get((vec_size,), value.type), value)
case _:
raise NotImplementedError(layout)
return cls(
_registers=np.full(reg_shape, value, dtype=object),
_layout=layout,
)
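# Illustrative sketch (assuming `f32 = ir.F32Type.get()` and that m and n are
# tile aligned): a zero-initialized WGMMA-layout accumulator can be built with
#
#   acc = FragmentedArray.splat(c(0.0, f32), (m, n), WGMMA_LAYOUT)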
@property
def shape(self):
row_tiles = self.registers.shape[0]
match self.layout:
case WGMMAFragLayout():
col_tiles = self.registers.shape[1]
return (row_tiles * 64, col_tiles * 8)
case WGMMARowFragLayout():
return (row_tiles * 64,)
case WGStridedFragLayout(shape):
return shape
@property
def mlir_dtype(self):
reg_ty = self.registers.flat[0].type
match self.layout:
case WGMMAFragLayout() | WGStridedFragLayout():
return ir.VectorType(reg_ty).element_type
case WGMMARowFragLayout():
return reg_ty
def _pointwise(self, op, *other):
for o in other:
if not isinstance(o, FragmentedArray):
return NotImplemented
if self.layout != o.layout:
raise ValueError("Incompatible FragmentedArray layouts")
if self.registers.shape != o.registers.shape:
raise ValueError("Incompatible FragmentedArray shapes")
new_regs = np.empty_like(self.registers)
for idx, reg in np.ndenumerate(self.registers):
new_regs[idx] = op(reg, *(o.registers[idx] for o in other))
return FragmentedArray(_registers=new_regs, _layout=self.layout)
def __add__(self, other):
return self._pointwise(arith.addf, other)
def __mul__(self, other):
return self._pointwise(arith.mulf, other)
def __sub__(self, other):
return self._pointwise(arith.subf, other)
def __truediv__(self, other):
return self._pointwise(arith.divf, other)
def max(self, other):
return self._pointwise(arith.maximumf, other)
def exp(self, approx: bool = False):
def fast_exp(x):
f32 = ir.F32Type.get()
log2e = arith.constant(f32, ir.FloatAttr.get(f32, 1.4426950408889634))
if x.type == f32:
scaled = arith.mulf(x, log2e)
return llvm.inline_asm(
f32, [scaled], "ex2.approx.f32 $0,$1;", "=f,f", asm_dialect=0
)
elif ir.VectorType.isinstance(x.type):
index = ir.IndexType.get()
result = llvm.mlir_undef(x.type)
for i in range(2):
v = vector.extractelement(x, position=c(i, index))
vr = fast_exp(v)
result = vector.insertelement(vr, result, position=c(i, index))
return result
else:
raise NotImplementedError(x.type)
return self._pointwise(fast_exp if approx else mlir_math.exp)
def __and__(self, other):
if not ir.IntegerType.isinstance(self.mlir_dtype):
raise ValueError(
"Bitwise operations only defined for integer types, not"
f" {self.mlir_dtype}"
)
return self._pointwise(arith.andi, other)
def bitcast(self, elt: ir.Type):
reg_type = self.registers.flat[0].type
if ir.VectorType.isinstance(reg_type):
reg_shape = ir.VectorType(reg_type).shape
ty = ir.VectorType.get(reg_shape, elt)
else:
ty = elt
return self._pointwise(lambda x: arith.bitcast(ty, x))
def __getitem__(self, idx):
if self.layout != WGMMA_LAYOUT:
raise NotImplementedError("Only WGMMA layouts support slicing")
base_idx, slice_shape, is_squeezed = utils.parse_indices(idx, self.shape)
if any(is_squeezed):
raise NotImplementedError("Only slicing implemented")
if (
base_idx[0] % 64
or slice_shape[0] % 64
or base_idx[1] % 8
or slice_shape[1] % 8
):
raise NotImplementedError("Only tile aligned slicing supported")
base_idx[0] //= 64
slice_shape[0] //= 64
base_idx[1] //= 8
slice_shape[1] //= 8
new_regs = self.registers[
base_idx[0] : base_idx[0] + slice_shape[0],
base_idx[1] : base_idx[1] + slice_shape[1],
]
return FragmentedArray(_registers=new_regs, _layout=self.layout)
# TODO(apaszke): Support JAX dtypes here as well?
def astype(self, new_dtype: ir.Type):
cur_dtype = self.mlir_dtype
if cur_dtype == new_dtype:
return self
from_float = ir.FloatType.isinstance(cur_dtype)
to_float = ir.FloatType.isinstance(new_dtype)
from_integer = ir.IntegerType.isinstance(cur_dtype)
to_integer = ir.IntegerType.isinstance(new_dtype)
if from_float and to_float:
if ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width:
convert = arith.truncf
else:
convert = arith.extf
elif from_integer and to_integer:
if ir.IntegerType(cur_dtype).width > ir.IntegerType(new_dtype).width:
convert = arith.trunci
else:
convert = arith.extsi
elif from_integer and to_float:
convert = arith.sitofp
elif from_float and to_integer:
convert = arith.fptosi
new_registers = np.empty_like(self.registers)
match self.layout:
case WGMMAFragLayout():
new_reg_ty = ir.VectorType.get((2,), new_dtype)
case WGStridedFragLayout(vec_size=vec_size):
new_reg_ty = ir.VectorType.get((vec_size,), new_dtype)
case WGMMARowFragLayout():
new_reg_ty = new_dtype
case _:
raise NotImplementedError(f"Unsupported layout {self.layout}")
for idx, reg in np.ndenumerate(self.registers):
new_registers[idx] = convert(new_reg_ty, reg)
return FragmentedArray(_registers=new_registers, _layout=self.layout)
def reduce(self, op, axis):
if self.layout != WGMMA_LAYOUT:
raise NotImplementedError(self.layout)
if axis != 1:
raise NotImplementedError
index = ir.IndexType.get()
i32 = ir.IntegerType.get_signless(32)
new_regs = np.empty(self.registers.shape[::2], dtype=object)
assert self.registers.shape[-1] == 1
for row_tile, row_subtile in np.ndindex(new_regs.shape):
# Reduce the registers owned by the current thread over n tiles
thread_result_vec = self.registers[row_tile, 0, row_subtile, 0]
for n_tile in range(1, self.registers.shape[1]):
thread_result_vec = op(
thread_result_vec, self.registers[row_tile, n_tile, row_subtile, 0]
)
thread_result = op(
vector.extractelement(thread_result_vec, position=c(0, index)),
vector.extractelement(thread_result_vec, position=c(1, index)),
)
# Do a shuffle to reduce in groups of 4 consecutive threads.
result = thread_result
for i in (1, 2):
other_result = nvvm.shfl_sync(
result.type,
c(0xFFFFFFFF, i32),
result,
c(i, i32),
c(0x1F, i32),
nvvm.ShflKind.bfly,
)
result = op(result, other_result)
new_regs[row_tile, row_subtile] = result
return FragmentedArray(_registers=new_regs, _layout=WGMMA_ROW_LAYOUT)
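# Illustrative sketch: a row-wise maximum over a WGMMA-layout accumulator, e.g.
#
#   row_max = acc.reduce(arith.maximumf, axis=1)  # result uses WGMMA_ROW_LAYOUT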
def broadcast_minor(self, n):
if self.layout != WGMMA_ROW_LAYOUT:
raise NotImplementedError
num_row_tiles = self.registers.shape[0]
num_col_tiles, rem = divmod(n, 8)
if rem:
raise ValueError("Number of columns must be divisible by 8")
new_regs = np.empty((num_row_tiles, num_col_tiles, 2, 1), dtype=object)
dtype = self.mlir_dtype
for (row_tile, row_subtile), reg in np.ndenumerate(self.registers):
new_regs[row_tile, :, row_subtile, :] = vector.splat(
ir.VectorType.get((2,), dtype), reg
)
return FragmentedArray(_registers=new_regs, _layout=WGMMA_LAYOUT)
def store_untiled(self, ref: ir.Value):
if not ir.MemRefType.isinstance(ref.type):
raise ValueError(ref)
match self.layout:
case WGMMAFragLayout():
self._store_untiled_wgmma(ref)
case WGStridedFragLayout():
self._store_untiled_wg_strided(ref)
case _:
raise NotImplementedError(self.layout)
def _store_untiled_wg_strided(self, ref: ir.Value):
ref_ty = ir.MemRefType(ref.type)
if ref_ty.shape != self.shape:
raise ValueError((ref_ty.shape, self.shape))
smem_1d = mgpu.memref_fold(ref, 0, len(ref_ty.shape))
assert isinstance(self.layout, WGStridedFragLayout)
for idx, reg in zip(self.layout.thread_vec_idxs(), self.registers.flat):
vector.store(reg, smem_1d, idx)
def _store_untiled_wgmma(self, ref: ir.Value):
"""Stores accumulator to a 2D memref. Not optimized at the moment."""
assert self.layout == WGMMA_LAYOUT
index = ir.IndexType.get()
m, n = self.shape # pytype: disable=bad-unpacking
ref_ty = ir.MemRefType(ref.type)
if ref_ty.shape != [m, n]:
raise ValueError(ref.type, (m, n))
def c(x):
return arith.ConstantOp(index, ir.IntegerAttr.get(index, x))
tidx = gpu.thread_id(gpu.Dimension.x)
lane_id = arith.remui(tidx, c(32)) # {0, 1, ..., 31}
warp_id = arith.divui(tidx, c(32)) # {0, 1, 2, 3}
row_base = arith.addi(
arith.divui(lane_id, c(4)), arith.muli(warp_id, c(16))
)
col_base = arith.muli(arith.remui(lane_id, c(4)), c(2)) # {0, 2, 4, 6}
it = np.ndenumerate(self.registers)
for (row_tile, col_tile, row_idx, col_zero), elem in it:
del col_zero
row = arith.addi(row_base, c(row_tile * 64 + row_idx * 8))
for col_idx in range(2):
value = vector.extractelement(elem, position=c(col_idx))
col = arith.addi(col_base, c(col_tile * 8 + col_idx))
memref.store(value, ref, [row, col])
def store_tiled(self, ref, swizzle: int | None):
if self.layout != WGMMA_LAYOUT:
raise NotImplementedError
bw = mgpu.bytewidth(self.mlir_dtype)
m, n = self.shape # pytype: disable=bad-unpacking
assert m % 64 == 0 # This is implied by the layout.
if n % 32 != 0:
raise NotImplementedError
cols_per_tile = 128 // bw
expected_shape = [m // 64, n // cols_per_tile, 64, cols_per_tile]
if ir.MemRefType(ref.type).shape != expected_shape:
raise ValueError(ref.type, (m, n))
if swizzle != 128:
raise NotImplementedError("Only 128B swizzle supported")
index = ir.IndexType.get()
def c(x):
return arith.ConstantOp(index, ir.IntegerAttr.get(index, x))
tidx = gpu.thread_id(gpu.Dimension.x)
lane_id = arith.remui(tidx, c(32)) # {0, 1, ..., 31}
warp_id = arith.divui(tidx, c(32)) # {0, 1, 2, 3}
sub_row_base = arith.divui(lane_id, c(4)) # {0, 1, ..., 7}
if bw > 2: # Stagger is only necessary for values larger than 16bit.
is_even_row = arith.cmpi(
arith.CmpIPredicate.eq, arith.remui(sub_row_base, c(2)), c(0)
)
else:
# We rely on canonicalization to clean up the selects.
i1 = ir.IntegerType.get_signless(1)
is_even_row = arith.constant(i1, ir.IntegerAttr.get(i1, 1))
row_base = arith.addi(sub_row_base, arith.muli(warp_id, c(16)))
col_base = arith.muli(arith.remui(lane_id, c(4)), c(2)) # {0, 2, 4, 6}
# The swizzle pattern is constant for a given thread.
col_swizzle_bits = arith.muli(sub_row_base, c(16 // bw))
for row_group in range(m // 64):
for col_group in range(n // cols_per_tile):
for row_subidx in range(2):
row = arith.addi(row_base, c(row_subidx * 8))
for col_subidx in range(cols_per_tile // 8):
# We stagger the even and odd rows a little to avoid bank conflicts.
# It seems that the STS.64 is 2x faster (and the hardware reports no
# conflicts) when the conflicts are split between half-warps, as
# opposed to having them within the half-warp. This requires a
# little more work for the selects, but is ultimately worth it.
col_subidx_even = col_subidx
col_subidx_odd = col_subidx ^ 2
col_off = arith.select(
is_even_row, c(col_subidx_even * 8), c(col_subidx_odd * 8)
)
col = arith.addi(col_base, col_off)
col = arith.xori(col, col_swizzle_bits)
reg_idx_even = col_subidx_even + col_group * (cols_per_tile // 8)
reg_idx_odd = col_subidx_odd + col_group * (cols_per_tile // 8)
value_even = self.registers[row_group, reg_idx_even, row_subidx, 0]
value_odd = self.registers[row_group, reg_idx_odd, row_subidx, 0]
value = arith.select(is_even_row, value_even, value_odd)
vector.store(value, ref, [c(row_group), c(col_group), row, col])
def tree_flatten(self):
return list(self.registers.flat), (self.layout, self.registers.shape)
@classmethod
def tree_unflatten(cls, aux, flat_registers):
layout, reg_shape = aux
registers = np.asarray(flat_registers, dtype=object).reshape(reg_shape)
return cls(_registers=registers, _layout=layout)

View File

@ -0,0 +1,206 @@
# Copyright 2024 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import contextlib
import json
import jax
import jax.numpy as jnp
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import gpu
from jaxlib.mlir.dialects import memref
from jaxlib.mlir.dialects import scf
import numpy as np
from .utils import * # noqa: F403
# ruff: noqa: F405
# mypy: ignore-errors
class ProfilerSpec:
ENTER = 0
EXIT = 1 << 31
def __init__(self, num_entries: int):
self.num_entries = num_entries
self.interned_names = {}
@property
def mlir_buffer_type(self) -> ir.Type:
return ir.MemRefType.get(
(1 + self.num_entries,), ir.IntegerType.get_signless(32)
)
@property
def jax_buffer_type(self) -> ir.Type:
return jax.ShapeDtypeStruct((1 + self.num_entries,), jnp.uint32)
def smem_i32_elements(self, grid: tuple[int, ...]):
return int(self.num_entries // np.prod(grid))
def smem_bytes(self, grid: tuple[int, ...]):
bytes_per_entry = 4
return self.smem_i32_elements(grid) * bytes_per_entry
def intern_name(self, name: str) -> int:
# Use an explicit None check so the first interned name (id 0) is not re-interned.
if (name_id := self.interned_names.get(name)) is not None:
return name_id
name_id = self.interned_names[name] = len(self.interned_names)
if name_id & self.EXIT:
raise RuntimeError("Allocated too many names")
return name_id
def dump(self, buffer, f):
buffer = np.asarray(buffer)
num_blocks = buffer[0]
per_block = self.num_entries // num_blocks
block_entries = buffer[1 : 1 + num_blocks * per_block].reshape(
num_blocks, per_block
)
start_times = block_entries[:, :2].astype(np.int64)
start_times = (start_times[:, 0] << 32) + start_times[:, 1]
start_times -= start_times.min() # Normalize
entries_used = block_entries[:, 2]
if np.any(entries_used > per_block - 2):
raise RuntimeError("Insufficient space to capture a full trace")
block_traces = block_entries[:, 3:]
unintern = {v: k for k, v in self.interned_names.items()}
events = []
for block_idx in range(num_blocks):
valid_entries = entries_used[block_idx] - 3
local_clock_offset = None
assert valid_entries % 2 == 0
start_time = start_times[block_idx]
block_events = []
for i in range(0, valid_entries, 2):
tag = block_traces[block_idx, i]
time = block_traces[block_idx, i + 1]
if local_clock_offset is None:
local_clock_offset = time
time -= local_clock_offset
time -= i * 6 # Account for the overhead of profiling.
if time < 0:
break # Detect a timer wraparound
name_id = tag
begin = True
if name_id & ProfilerSpec.EXIT:
name_id = name_id ^ ProfilerSpec.EXIT
begin = False
name = unintern[name_id]
block_events.append({
"name": name,
"ph": "B" if begin else "E",
"ts": float(start_time + time) / 1e3,
"pid": 0,
"tid": block_idx,
})
else: # If we didn't break
events.extend(block_events)
return json.dump({"displayTimeUnit": "ns", "traceEvents": events}, f)
class OnDeviceProfiler:
def __init__(self, spec: ProfilerSpec, smem_buffer: ir.Value, gmem_buffer: ir.Value):
self.spec = spec
# self.should_store = gpu.thread_id(gpu.Dimension.x)
i32 = ir.IntegerType.get_signless(32)
index = ir.IndexType.get()
num_blocks = c(1, index)
for dim in gpu.Dimension:
num_blocks = arith.muli(num_blocks, gpu.grid_dim(dim))
memref.store(arith.index_cast(i32, num_blocks), gmem_buffer, [c(0, index)])
self.entries_per_block = arith.divui(c(spec.num_entries, index), num_blocks)
self.smem_buffer = smem_buffer
self.gmem_buffer = gmem_buffer
# Hopefully mem2reg will remove the allocation.
self.offset = memref.alloca(ir.MemRefType.get((), i32), [], [])
memref.store(c(0, i32), self.offset, [])
@contextlib.contextmanager
def record(self, name: str):
i32 = ir.IntegerType.get_signless(32)
index = ir.IndexType.get()
name_id = self.spec.intern_name(name)
def store(modifier):
cur = arith.index_cast(index, memref.load(self.offset, []))
# TODO(apaszke): Clamp indices
# bound = arith.subi(self.entries_per_block, c(2, index))
# cur = arith.select(
# arith.cmpi(arith.CmpIPredicate.ult, cur, bound), cur, bound
# )
memref.store(c(modifier | name_id, i32), self.smem_buffer, [cur])
memref.store(
clock(), self.smem_buffer, [arith.addi(cur, c(1, cur.type))]
)
memref.store(
arith.index_cast(i32, arith.addi(cur, c(2, cur.type))),
self.offset,
[],
)
store(ProfilerSpec.ENTER)
yield
store(ProfilerSpec.EXIT)
def finalize(self, grid):
index = ir.IndexType.get()
i32 = ir.IntegerType.get_signless(32)
block_idx = c(0, index)
for dim in reversed(gpu.Dimension): # pytype: disable=wrong-arg-types
block_idx = arith.addi(
arith.muli(block_idx, gpu.grid_dim(dim)), gpu.block_id(dim)
)
start_offset = arith.addi(
arith.muli(block_idx, self.entries_per_block), c(1, index)
)
block_gmem_buffer = memref.subview(
self.gmem_buffer, [start_offset], [self.spec.num_entries], [1],
result_type=ir.Type.parse(
f"memref<{self.spec.num_entries}xi32, strided<[1], offset: ?>>"
),
)
# TODO(apaszke): Either use globaltimer or delete
# memref.store(globaltimer("high"), block_gmem_buffer, [c(0, index)])
# memref.store(globaltimer("low"), block_gmem_buffer, [c(1, index)])
memref.store(c(0, i32), block_gmem_buffer, [c(0, index)])
memref.store(c(0, i32), block_gmem_buffer, [c(1, index)])
memref.store(
arith.addi(memref.load(self.offset, []), c(3, i32)),
block_gmem_buffer,
[c(2, index)],
)
if_first = scf.IfOp(
arith.cmpi(
arith.CmpIPredicate.eq, gpu.thread_id(gpu.Dimension.x), c(0, index)
)
)
with ir.InsertionPoint(if_first.then_block):
for_op = scf.ForOp(
c(0, index),
c(self.spec.smem_i32_elements(grid) - 3, index),
c(1, index),
)
with ir.InsertionPoint(for_op.body):
x = memref.load(self.smem_buffer, [for_op.induction_variable])
memref.store(
x,
block_gmem_buffer,
[arith.addi(for_op.induction_variable, c(3, index))],
)
scf.yield_([])
scf.yield_([])

View File

@ -0,0 +1,634 @@
# Copyright 2024 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for code generator."""
import contextlib
import dataclasses
from typing import Any, Literal, Sequence
from absl import flags
import jax
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import builtin
from jaxlib.mlir.dialects import gpu
from jaxlib.mlir.dialects import llvm
from jaxlib.mlir.dialects import memref
from jaxlib.mlir.dialects import nvgpu
from jaxlib.mlir.dialects import nvvm
from jaxlib.mlir.dialects import scf
from jaxlib.mlir.dialects import vector
import numpy as np
# mypy: ignore-errors
WARPGROUP_SIZE: int = 128
DYNAMIC = -9223372036854775808
FLAGS = flags.FLAGS
flags.DEFINE_bool("mosaic_gpu_debug", False, "Perform debug printing")
# pylint: disable=line-too-long, wildcard-import, missing-function-docstring, bad-continuation, g-bad-todo, protected-access, g-explicit-length-test, missing-class-docstring, g-doc-return-or-yield, g-inconsistent-quotes
def ptr_as_memref(ptr, memref_ty: ir.MemRefType):
if len(memref_ty.shape) == 0:
raise NotImplementedError
i64 = ir.IntegerType.get_signless(64)
rank = len(memref_ty.shape)
desc_ty = ir.Type.parse(
f"!llvm.struct<(ptr, ptr, i64, array<{rank} x i64>, array<{rank} x i64>)>"
)
desc = llvm.UndefOp(desc_ty)
desc = llvm.InsertValueOp(desc, ptr, [0]) # Allocation
desc = llvm.InsertValueOp(desc, ptr, [1]) # Aligned Base
desc = llvm.InsertValueOp(
desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, 0)), [2]
)
for i, s in enumerate(memref_ty.shape):
desc = llvm.InsertValueOp(
desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, s)), [3, i]
)
for i, s in enumerate(get_contiguous_strides(memref_ty.shape)):
desc = llvm.InsertValueOp(
desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, s)), [4, i]
)
return builtin.unrealized_conversion_cast([memref_ty], [desc])
def get_contiguous_strides(xs):
strides_ret = []
stride = 1
for x in xs[::-1]:
strides_ret.append(stride)
stride *= x
return strides_ret[::-1]
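# For example, get_contiguous_strides((5, 128, 128)) returns [16384, 128, 1].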
def c(val: int | float, ty):
if ir.IntegerType.isinstance(ty) or ir.IndexType.isinstance(ty):
if not isinstance(val, (int, np.integer)):
raise TypeError(type(val))
attr = ir.IntegerAttr.get(ty, val)
elif ir.FloatType.isinstance(ty):
attr = ir.FloatAttr.get(ty, val)
elif ir.VectorType.isinstance(ty):
return vector.splat(ty, c(val, ir.VectorType(ty).element_type))
else:
raise NotImplementedError(ty)
return arith.constant(ty, attr)
def get_tensormap_descriptor(**attrs):
return ir.Type.parse(
f"!nvgpu.tensormap.descriptor<{', '.join(k + '=' + v for k, v in attrs.items())}>"
)
def debug_print(fmt, *args, uniform=True):
if not FLAGS.mosaic_gpu_debug:
return
type_formats = []
new_args = []
for arg in args:
ty_format = None
if ir.IndexType.isinstance(arg.type):
ty_format = "%llu"
if ir.IntegerType.isinstance(arg.type):
width = ir.IntegerType(arg.type).width
if width == 64:
ty_format = "%llu"
elif width == 1:
ty_format = "%llu"
arg = arith.extui(ir.IntegerType.get_signless(64), arg)
if ir.F32Type.isinstance(arg.type):
ty_format = "%f"
if ir.F16Type.isinstance(arg.type):
ty_format = "%f"
arg = arith.extf(ir.F32Type.get(), arg)
if ty_format is None:
raise NotImplementedError(arg.type)
type_formats.append(ty_format)
new_args.append(arg)
ctx = once if uniform else contextlib.nullcontext
with ctx():
gpu.printf(fmt.format(*type_formats) + "\n", new_args)
@dataclasses.dataclass(frozen=True)
class ForResult:
op: scf.ForOp
results: tuple[Any, ...]
@property
def result(self):
if len(self.results) != 1:
raise ValueError
return self.results[0]
def fori(bound, carrys):
unwrap = False
if not isinstance(carrys, (list, tuple)):
carrys = [carrys]
unwrap = True
flat_carrys, carry_treedef = jax.tree.flatten(carrys)
def wrapper(f):
index = ir.IndexType.get()
c0 = arith.ConstantOp(index, ir.IntegerAttr.get(index, 0))
c1 = arith.ConstantOp(index, ir.IntegerAttr.get(index, 1))
for_op = scf.ForOp(c0, bound, c1, flat_carrys)
with ir.InsertionPoint(for_op.body):
i = for_op.induction_variable
inner_carrys = jax.tree.unflatten(carry_treedef, for_op.inner_iter_args)
if unwrap:
[inner_carrys] = inner_carrys
new_carrys = f(i, inner_carrys)
if unwrap:
new_carrys = [new_carrys]
new_flat_carrys, new_carry_treedef = jax.tree.flatten(new_carrys)
if new_carry_treedef != carry_treedef:
raise ValueError(new_carry_treedef, carry_treedef)
scf.YieldOp(new_flat_carrys)
final_flat_carrys = for_op.results
return ForResult(
for_op, jax.tree.unflatten(carry_treedef, final_flat_carrys)
)
return wrapper
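# Illustrative usage sketch (assuming `index = ir.IndexType.get()` and an
# existing float SSA value `acc0` as the initial carry):
#
#   @fori(c(steps, index), acc0)
#   def loop(i, acc):
#     return arith.addf(acc, acc)  # return the next carry (same structure)
#   final = loop.result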
def get_warp_idx():
i32 = ir.IntegerType.get_signless(32)
tidx = arith.index_cast(i32, gpu.thread_id(gpu.Dimension.x))
warp_idx = arith.shrui(tidx, c(5, tidx.type))
mask = c(0xFFFFFFFF, i32)
return nvvm.shfl_sync(
warp_idx.type, mask, warp_idx, c(0, i32), c(0x1F, i32), nvvm.ShflKind.idx
)
# True within `once()` contexts.
_ONCE_REGION_ACTIVE = False
@contextlib.contextmanager
def once():
"""Runs the context only from a single thread from the first warp.
The block is assumed to have a size of 1 in both y and z dimensions.
"""
global _ONCE_REGION_ACTIVE
if _ONCE_REGION_ACTIVE:
yield
return
warp = get_warp_idx()
first_warp = arith.cmpi(arith.CmpIPredicate.eq, warp, c(0, warp.type))
elected = nvvm.elect_sync(ir.IntegerType.get_signless(1))
should_run = arith.andi(first_warp, elected)
if_op = scf.IfOp(should_run)
_ONCE_REGION_ACTIVE = True
try:
with ir.InsertionPoint(if_op.then_block):
yield
scf.YieldOp([])
finally:
_ONCE_REGION_ACTIVE = False
def clock():
i32 = ir.IntegerType.get_signless(32)
return llvm.inline_asm(
i32, [], "mov.u32 $0,%clock;", "=r", asm_dialect=0, has_side_effects=True
)
def globaltimer(kind: Literal["low", "high"] | None = None):
if kind is None:
i64 = ir.IntegerType.get_signless(64)
return llvm.inline_asm(
i64, [], "mov.u32 $0,%globaltimer;",
"=l", asm_dialect=0, has_side_effects=True,
)
i32 = ir.IntegerType.get_signless(32)
return llvm.inline_asm(
i32, [], f"mov.u32 $0,%globaltimer_{kind[:2]};",
"=r", asm_dialect=0, has_side_effects=True,
)
def bytewidth(ty: ir.Type):
if ir.IntegerType.isinstance(ty):
return ir.IntegerType(ty).width // 8
if ir.FloatType.isinstance(ty):
return ir.FloatType(ty).width // 8
raise NotImplementedError(ty)
@dataclasses.dataclass(frozen=True)
class DynamicSlice:
base: ir.Value | int
length: int
ds = DynamicSlice
def memref_slice(ref: ir.Value, index) -> ir.Value:
ref_ty = ir.MemRefType(ref.type)
base_indices, slice_shape, is_squeezed = parse_indices(index, ref_ty.shape)
memref_strides, offset = ref_ty.get_strides_and_offset()
new_offset = offset
for idx, stride in zip(base_indices, memref_strides):
if isinstance(idx, int):
new_offset += idx * stride
else:
new_offset = ir.ShapedType.get_dynamic_stride_or_offset()
break
new_strides = [
s for s, squeeze in zip(memref_strides, is_squeezed) if not squeeze
]
new_shape = [s for s, squeeze in zip(slice_shape, is_squeezed) if not squeeze]
new_layout = ir.StridedLayoutAttr.get(new_offset, new_strides)
ref_slice = memref.subview(
ref, base_indices, slice_shape, [1] * len(ref_ty.shape),
result_type=ir.MemRefType.get(
new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space
),
)
return ref_slice
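# Illustrative examples (a sketch, assuming parse_indices accepts the `ds`
# helper defined above): for `ref` of shape (4, 128), memref_slice(ref, 0)
# yields a memref of shape (128,) (the integer index is squeezed), while
# memref_slice(ref, (ds(i, 2), slice(None))) yields a (2, 128) window starting
# at dynamic row `i`.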
def _is_contiguous_shape_slice(
ref_ty: ir.MemRefType, dim_slice: slice | None = slice(None)
):
# If it's not a strided layout then we are definitely contiguous.
if not ir.StridedLayoutAttr.isinstance(ref_ty.layout):
return True
strides = ir.StridedLayoutAttr(ref_ty.layout).strides[dim_slice]
shape = ref_ty.shape[dim_slice]
# Check that each dimension fits exactly into the immediately larger stride.
ss = sorted(zip(strides, shape), key=lambda x: x[0], reverse=True)
for (prev_stride, _), (stride, shape) in zip(ss, ss[1:]):
if stride * shape != prev_stride:
return False
return True
def memref_fold(ref: ir.Value, dim, fold_rank) -> ir.Value:
ref_ty = ir.MemRefType(ref.type)
new_shape = list(ref_ty.shape)
new_shape[dim : dim + fold_rank] = [np.prod(new_shape[dim : dim + fold_rank])]
identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(ref_ty.rank))
if ref_ty.layout == identity:
new_layout = ir.AffineMapAttr.get(
ir.AffineMap.get_identity(ref_ty.rank - fold_rank + 1)
)
elif _is_contiguous_shape_slice(ref_ty, slice(dim, dim + fold_rank)):
new_strides, offset = ref_ty.get_strides_and_offset()
new_strides[dim : dim + fold_rank] = [new_strides[dim + fold_rank - 1]]
new_layout = ir.StridedLayoutAttr.get(offset, new_strides)
else:
raise NotImplementedError(
f"strides={ref_ty.get_strides_and_offset()[0]}, {ref_ty.shape=},"
f" {dim=}, {fold_rank=}"
)
new_ty = ir.MemRefType.get(
new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space
)
assoc = [[d] for d in range(dim)]
assoc.append([dim + i for i in range(fold_rank)])
assoc.extend([d] for d in range(dim + fold_rank, ref_ty.rank))
assert len(assoc) == new_ty.rank
return memref.collapse_shape(new_ty, ref, assoc)
def memref_unfold(ref: ir.Value, dim, factors) -> ir.Value:
"""Unfolds dim into two dimensions, the size of leading one given be major_factor."""
ref_ty = ir.MemRefType(ref.type)
new_shape = list(ref_ty.shape)
if sum(f is None for f in factors) > 1:
raise ValueError("Can only infer one dimension")
known_factor_prod = np.prod([f for f in factors if f is not None])
if new_shape[dim] % known_factor_prod:
raise ValueError("Non-divisible unfold:", new_shape[dim], factors)
factors = tuple(
new_shape[dim] // known_factor_prod if f is None else f for f in factors
)
new_shape[dim : dim + 1] = factors
identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(ref_ty.rank))
if ref_ty.layout == identity:
new_layout = ir.AffineMapAttr.get(
ir.AffineMap.get_identity(ref_ty.rank + len(factors) - 1)
)
else:
new_strides, offset = ref_ty.get_strides_and_offset()
prev_stride = new_strides[dim]
inserted_strides = []
for f in reversed(factors):
inserted_strides.append(prev_stride)
prev_stride *= f
new_strides[dim : dim + 1] = reversed(inserted_strides)
new_layout = ir.StridedLayoutAttr.get(offset, new_strides)
new_ty = ir.MemRefType.get(
new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space
)
if dim == ref_ty.rank:
assoc = [[d] for d in range(ref_ty.rank)]
assoc[-1].extend(range(ref_ty.rank, ref_ty.rank + len(factors) - 1))
else:
assoc = [[d] for d in range(dim)]
assoc.append(list(range(dim, dim + len(factors))))
assoc.extend([d + len(factors) - 1] for d in range(dim + 1, ref_ty.rank))
assert len(assoc) == ref_ty.rank
return memref.expand_shape(new_ty, ref, assoc)
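# Illustrative sketch of the unfold arithmetic above, assuming `ref` is an
# (8, 16) memref value: factors (2, 2, None) infer the trailing factor as
# 8 // (2 * 2) == 2, so
#
#   memref_unfold(ref, 0, (2, 2, None))  # yields a (2, 2, 2, 16) memref
#
# At most one factor may be None, and the dim size must divide evenly.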
def memref_unsqueeze(ref: ir.Value, dim) -> ir.Value:
"""Inserts a singleton dimension."""
ref_ty = ir.MemRefType(ref.type)
if dim == ref_ty.rank:
new_shape = list(ref_ty.shape)
new_shape.append(1)
identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(ref_ty.rank))
if ref_ty.layout == identity:
new_layout = ir.AffineMapAttr.get(
ir.AffineMap.get_identity(ref_ty.rank + 1)
)
else:
new_strides, offset = ref_ty.get_strides_and_offset()
new_strides.append(1)
new_layout = ir.StridedLayoutAttr.get(offset, new_strides)
new_ty = ir.MemRefType.get(
new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space
)
assoc = [[d] for d in range(ref_ty.rank)]
assoc[-1].append(ref_ty.rank)
return memref.expand_shape(new_ty, ref, assoc)
else:
return memref_unfold(ref, dim, (1, None))
def memref_transpose(ref: ir.Value, permutation: Sequence[int]) -> ir.Value:
ref_ty = ir.MemRefType(ref.type)
strides, offset = ref_ty.get_strides_and_offset()
new_strides = [strides[p] for p in permutation]
new_shape = [ref_ty.shape[p] for p in permutation]
new_layout = ir.StridedLayoutAttr.get(offset, new_strides)
new_ty = ir.MemRefType.get(
new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space
)
return memref.transpose(
new_ty, ref, ir.AffineMap.get_permutation(permutation)
)
def parse_indices(
index, shape: tuple[int, ...]
) -> tuple[list[ir.Value | int], list[int], list[bool]]:
if not isinstance(index, tuple):
index = (index,)
if trailing_dims := len(shape) - len(index):
index += (slice(None),) * trailing_dims
base_indices = []
slice_shape = []
is_squeezed = []
for idx, bound in zip(index, shape):
if isinstance(idx, (ir.Operation, ir.OpView)):
idx = idx.result
if isinstance(idx, int):
base_indices.append(idx)
slice_shape.append(1)
is_squeezed.append(True)
elif isinstance(idx, slice):
if idx.step is not None:
raise NotImplementedError("Strided slices not implemented")
base_indices.append(idx.start or 0)
slice_shape.append((idx.stop or bound) - (idx.start or 0))
is_squeezed.append(False)
elif isinstance(idx, DynamicSlice):
base_indices.append(idx.base)
slice_shape.append(idx.length)
is_squeezed.append(False)
elif isinstance(idx, ir.Value):
if not ir.IndexType.isinstance(idx.type):
raise ValueError("Expected an index-typed index")
base_indices.append(idx)
slice_shape.append(1)
is_squeezed.append(True)
else:
raise NotImplementedError(type(idx))
assert len(base_indices) == len(slice_shape) == len(is_squeezed) == len(shape)
return base_indices, slice_shape, is_squeezed
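# Illustrative sketch: for shape (4, 128, 32), the index (2, ds(i, 8)) (with
# `i` an index-typed ir.Value) is padded with a trailing slice(None) and
# parses into
#
#   base_indices == [2, i, 0]
#   slice_shape  == [1, 8, 32]
#   is_squeezed  == [True, False, False]
#
# Integer indices are squeezed away; slices and DynamicSlices keep their dim.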
def commit_shared():
gpu.barrier()
nvvm.fence_proxy(
nvvm.ProxyKind.async_shared, space=nvvm.SharedSpace.shared_cta
)
class BarrierArray:
def __init__(self, num_barriers):
barrier_group_ty = ir.Type.parse(
"!nvgpu.mbarrier.group<memorySpace=#gpu.address_space<workgroup>,"
f" num_barriers={num_barriers}>"
)
self.value = nvgpu.mbarrier_create(barrier_group_ty)
index = ir.IndexType.get()
if num_barriers > 32:
raise NotImplementedError("Only up to 32 barriers per group supported")
i32 = ir.IntegerType.get_signless(32)
self.phases = memref.alloca(ir.MemRefType.get((), i32), [], [])
memref.store(c(0, i32), self.phases, [])
with once():
for i in range(num_barriers):
nvgpu.mbarrier_init(self.value, c(1, index), c(i, index))
def __getitem__(self, offset: ir.Value | int):
if isinstance(offset, int):
offset = c(offset, ir.IndexType.get())
return Barrier(self, offset)
@dataclasses.dataclass(frozen=True)
class Barrier:
barrier_array: BarrierArray
offset: ir.Value
def wait_parity(self, parity):
index = ir.IndexType.get()
nvgpu.mbarrier_try_wait_parity(
self.barrier_array.value, parity, c(10000000, index), self.offset,
)
def wait(self):
i32 = ir.IntegerType.get_signless(32)
parities = memref.load(self.barrier_array.phases, [])
offset_i32 = arith.index_castui(i32, self.offset)
bitmask = arith.shli(c(1, i32), offset_i32)
parity = arith.cmpi(
arith.CmpIPredicate.ne, arith.andi(parities, bitmask), c(0, i32)
)
new_parities = arith.xori(parities, bitmask)
memref.store(new_parities, self.barrier_array.phases, [])
self.wait_parity(parity)
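# Illustrative sketch, mirroring how the tests added in this change use these
# classes (`ctx`, `src`, and `smem` are hypothetical kernel-body values):
#
#   barrier = BarrierArray(1)[0]
#   ctx.async_copy(src_ref=src, dst_ref=smem, barrier=barrier)
#   barrier.wait()  # flips the locally tracked phase and waits for arrival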
class Partition:
source_bounds: tuple[int, ...]
target_bounds: tuple[int, ...]
partition: tuple[int | None, ...]
base_offset: tuple[ir.Value, ...] | None
def __init__(
self,
elements: tuple[int, ...],
*,
partition: tuple[int | None, ...],
base_offset: tuple[ir.Value, ...] | None = None,
num_chunks: tuple[int, ...] | None = None,
chunk_size: tuple[int, ...] | None = None,
):
self.target_bounds = elements
self.partition = partition
self.base_offset = base_offset
if len(self.target_bounds) != len(self.partition):
raise ValueError
if (num_chunks is None) == (chunk_size is None):
raise ValueError(
"Exactly one of num_chunks and chunk_size must be specified"
)
if num_chunks is not None:
self.source_bounds = num_chunks
else:
if len(chunk_size) != len(self.target_bounds):
raise ValueError
source_bounds = []
for els, chunk in zip(elements, chunk_size):
if els % chunk:
raise ValueError("Non-divisible partition", elements, chunk_size)
source_bounds.append(els // chunk)
self.source_bounds = tuple(source_bounds)
seen_dims = set()
for p in self.partition:
if p is None:
continue
if not (0 <= p < len(self.source_bounds)):
raise ValueError
if p in seen_dims:
raise ValueError
seen_dims.add(p)
for tb, p in zip(self.target_bounds, self.partition):
if p is not None and tb % self.source_bounds[p]:
raise ValueError("Non-divisible partitioning")
@property
def num_chunks(self) -> tuple[int, ...]:
return self.source_bounds
@property
def target_block_shape(self):
return tuple(tb if p is None else tb // self.source_bounds[p]
for tb, p in zip(self.target_bounds, self.partition))
def get_base(self, *source_coords: ir.Value | int) -> list[ir.Value]:
coords = []
index = ir.IndexType.get()
for i, (tbs, p) in enumerate(zip(self.target_block_shape, self.partition)):
if p is None:
dim_base = c(0, index)
else:
dim_base = arith.muli(c(tbs, index), source_coords[p])
if self.base_offset is not None:
dim_base = arith.addi(self.base_offset[i], dim_base)
coords.append(dim_base)
return coords
class Partition1D:
partition: Partition
def __init__(
self,
elements: int,
*,
base_offset: ir.Value | None = None,
num_chunks: int | None = None,
chunk_size: int | None = None,
):
self.base_offset = base_offset
if (num_chunks is None) == (chunk_size is None):
raise ValueError(
"Exactly one of num_chunks and chunk_size must be specified"
)
common_kwargs = dict(elements=(elements,), partition=(0,))
if base_offset is not None:
common_kwargs["base_offset"] = (base_offset,)
if num_chunks is not None:
self.partition = Partition(num_chunks=(num_chunks,), **common_kwargs)
else:
self.partition = Partition(chunk_size=(chunk_size,), **common_kwargs)
@property
def num_chunks(self) -> int:
return self.partition.source_bounds[0]
def get_base(self, source_coords: ir.Value) -> ir.Value:
return self.partition.get_base(source_coords)[0]
def refine(
self,
*,
chunk: ir.Value | None = None,
num_chunks: int | None = None,
chunk_size: int | None = None,
):
return Partition1D(
self.partition.target_block_shape[0],
num_chunks=num_chunks,
chunk_size=chunk_size,
base_offset=self.get_base(chunk) if chunk is not None else None,
)
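# Illustrative sketch: Partition1D(256, num_chunks=4) describes 4 chunks of 64
# elements each, and get_base returns the element offset of a chunk as an
# index-typed value (`chunk_idx` below is assumed to be an ir.Value):
#
#   p = Partition1D(256, num_chunks=4)
#   base = p.get_base(chunk_idx)  # == chunk_idx * 64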
def tile_shape(shape, tiling):
if len(tiling) > len(shape):
raise ValueError
if not tiling:
return shape
tiling_rank = len(tiling)
for s, t in zip(shape[-tiling_rank:], tiling):
if s % t:
raise ValueError("Non-divisible tiling:", shape, tiling)
return (
*shape[:-tiling_rank],
*(s // t for s, t in zip(shape[-tiling_rank:], tiling)),
*tiling,
)
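# Illustrative sketch: tiling applies to the trailing dimensions only, e.g.
#
#   tile_shape((3, 128, 256), (64, 32)) == (3, 2, 8, 64, 32)
#
# since 128 // 64 == 2 and 256 // 32 == 8; non-divisible shapes raise.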

View File

@ -0,0 +1,404 @@
# Copyright 2024 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import dataclasses
import enum
import itertools
import jax
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import builtin
from jaxlib.mlir.dialects import llvm
from jaxlib.mlir.dialects import nvvm
from jaxlib.mlir.dialects import vector
import numpy as np
from . import dsl as mgpu
# mypy: ignore-errors
c = mgpu.c
bytewidth = mgpu.bytewidth
@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class WGMMAAccumulator:
"""A FragmentedArray that has is synchronized with the async proxy.
This implies that it requires no additional synchronization when passed in
as a WGMMA accumulator. In particular, when created from a
FragmentedArray, the necessary synchronization is inserted at construction.
"""
value: mgpu.FragmentedArray
def __init__(self, *, _value: mgpu.FragmentedArray, _sync: bool = True):
if _value.layout != mgpu.WGMMA_LAYOUT:
raise ValueError("Only WGMMA layouts supported in WGMMAAccumulator")
self.value = _value
if _sync:
nvvm.wgmma_fence_aligned()
@classmethod
def zero(cls, m, n):
if m % 64 or n % 8:
raise ValueError
f32 = ir.F32Type.get()
zero = arith.constant(f32, ir.FloatAttr.get(f32, 0.0))
return cls(
_value=mgpu.FragmentedArray.splat(zero, (m, n), mgpu.WGMMA_LAYOUT)
)
@classmethod
def from_registers(cls, registers):
return cls(_value=registers)
def tree_flatten(self):
return (self.value,), ()
@classmethod
def tree_unflatten(cls, aux, value):
del aux
return cls(_value=value[0], _sync=False)
def wgmma_encode(x: int):
result = (x & 0x3FFFF) >> 4
if result << 4 != x:
raise ValueError("Cannot encode value in a WGMMA descriptor")
return result
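# Illustrative sketch: WGMMA descriptors drop the low 4 bits of a byte offset
# and keep at most 14 bits, so only 16-byte-aligned offsets below 2**18 are
# encodable:
#
#   wgmma_encode(512) == 32   # i.e. 512 >> 4
#   wgmma_encode(8)           # raises ValueError (not 16-byte aligned)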
def get_memref_base(memref_arg, memory_space=None):
i64 = ir.IntegerType.get_signless(64)
memref_ty = ir.MemRefType(memref_arg.type)
if len(memref_ty.shape) == 0:
raise NotImplementedError
elem_bytewidth = bytewidth(memref_ty.element_type)
rank = len(memref_ty.shape)
# TODO: Read out memory space from memref
space = "" if memory_space is None else "<" + str(memory_space) + ">"
ptr_ty = ir.Type.parse("!llvm.ptr" + space)
desc_ty = ir.Type.parse(
f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64, array<{rank} x i64>,"
f" array<{rank} x i64>)>"
)
desc = builtin.UnrealizedConversionCastOp([desc_ty], [memref_arg])
aligned_ptr = llvm.extractvalue(ptr_ty, desc, [1])
offset_elems = llvm.extractvalue(i64, desc, [2])
offset_bytes = llvm.mul(offset_elems, c(elem_bytewidth, i64))
return llvm.inttoptr(
ptr_ty, llvm.add(llvm.ptrtoint(i64, aligned_ptr), offset_bytes)
)
def create_descriptor(
memref_arg,
leading_byte_offset: int,
stride_byte_offset: int,
swizzle: int | None,
memory_space: int | None = None,
nvgpu_type=None,
):
i64 = ir.IntegerType.get_signless(64)
ptr_val = llvm.ptrtoint(i64, get_memref_base(memref_arg, memory_space))
if swizzle is None:
swizzle_encoding = 0
elif swizzle == 128:
swizzle_encoding = 1
else:
raise NotImplementedError(swizzle)
encoded_base_addr = llvm.LShrOp(
llvm.AndOp(ptr_val, c(0x3FFFF, i64)), c(4, i64)
)
desc_const = (
(wgmma_encode(leading_byte_offset) << 16)
| (wgmma_encode(stride_byte_offset) << 32)
|
# We ignore the offset
(swizzle_encoding << 62)
)
desc = llvm.OrOp(encoded_base_addr, c(desc_const, i64))
if nvgpu_type is not None:
desc = builtin.UnrealizedConversionCastOp([nvgpu_type], [desc])
return desc.result
def wgmma_m64k128B(
acc: np.ndarray, # of register Values
a,
b_descriptor: ir.Value,
a_transpose: bool | None,
b_transpose: bool,
a_k_stride: int | None,
b_k_stride: int,
n: int,
element_type: ir.Type,
):
f32 = ir.F32Type.get()
i32 = ir.IntegerType.get_signless(32)
i64 = ir.IntegerType.get_signless(64)
index = ir.IndexType.get()
if b_k_stride % 16:
raise ValueError
if n % (128 // bytewidth(element_type)):
raise ValueError
# Only 16-bit types support transposes
supports_transpose = bytewidth(element_type) == 2
if not supports_transpose and (a_transpose or b_transpose):
raise ValueError("Only f16 WGMMA supports transposes")
if a_in_regs := isinstance(a, mgpu.FragmentedArray):
if a.mlir_dtype != ir.F16Type.get() and a.mlir_dtype != ir.BF16Type.get():
raise ValueError(f"Unsupported A register array dtype: {a.mlir_dtype}")
if a.layout != mgpu.WGMMA_LAYOUT or a.shape != (64, 64):
raise ValueError("Unsupported A register array layout")
if a_k_stride is not None or a_transpose is not None:
raise ValueError("Unsupported WGMMA features with A in registers")
else:
if a_k_stride is None or a_k_stride % 16:
raise ValueError
if a_transpose is None:
raise ValueError
num_acc_regs = n // 2
num_imm_regs = 4 if supports_transpose else 2
if a_in_regs:
a_reg_constraints = ["r"] * 4 # 4x f16x2 registers
num_imm_regs -= 1 # transpose not supported for a in registers
else:
a_reg_constraints = ["l"] # descriptor
# Reference for i/o aliasing: https://gcc.gnu.org/onlinedocs/gcc/Extended-Asm.html
# Seems like it's not actually documented in LLVM IR docs.
reg_constraints_list = (
["=f"] * num_acc_regs # accumulator registers
+ [str(i) for i in range(num_acc_regs)] # we alias outputs as inputs, too.
+ a_reg_constraints # a descriptor / registers
+ ["l"] * 1 # b descriptor
+ ["n"] * (1 + num_imm_regs) # literal constants
)
reg_constraints = ",".join(reg_constraints_list)
reg_count = itertools.count()
def take_regs(n):
return (f"${i}" for i in itertools.islice(reg_count, n))
acc_reg_vector = "{" + ",".join(take_regs(num_acc_regs)) + "}"
for _ in take_regs(num_acc_regs): # Ignore next entries: aliasing.
pass
if a_in_regs:
a_regs = "{" + ",".join(take_regs(len(a_reg_constraints))) + "}"
else:
a_regs, = take_regs(1)
b_desc_reg, use_out_reg = take_regs(2)
imm_regs = ", ".join(take_regs(num_imm_regs)) # Immediate regs (scale, ...).
assert next(reg_count) == len(reg_constraints_list)
el_ty = element_type
k_instr = 32 // bytewidth(element_type)
wgmma_instr = (
f"wgmma.mma_async.sync.aligned.m64n{n}k{k_instr}.f32.{el_ty}.{el_ty} "
f"{acc_reg_vector}, {a_regs}, {b_desc_reg}, p, {imm_regs};"
)
ptx = f"{{ .reg .pred p; setp.ne.b32 p, {use_out_reg}, 0; {wgmma_instr} }}\n"
def lc(x):
return llvm.mlir_constant(i32, ir.IntegerAttr.get(i32, x))
def as_i32_reg(v):
return llvm.extractelement(
vector.bitcast(ir.VectorType.get((1,), i32), v), lc(0)
)
use_out = scale_a = scale_b = lc(1)
imms = [use_out, scale_a, scale_b]
if supports_transpose and a_transpose is not None:
imms += [lc(int(a_transpose)), lc(int(b_transpose))]
elif supports_transpose:
imms += [lc(int(b_transpose))]
if acc.ndim != 4 or acc.shape[0] != 1 or acc.shape[2:] != (2, 1):
raise ValueError(acc.shape)
acc_regs = [ # pylint: disable=g-complex-comprehension
vector.extractelement(reg, position=c(pos, index))
for reg in acc.flat
for pos in range(2)
]
acc_struct_type = ir.Type.parse(
f"!llvm.struct<({','.join('f32' for _ in acc_regs)})>"
)
for i in range(4):
# Slice out the relevant part of A or advance the A descriptor.
if a_in_regs:
a_slice = a[:, (i * 16) : ((i + 1) * 16)]
a_args = [as_i32_reg(v) for v in a_slice.registers.flat]
else:
if i > 0:
a = llvm.add(
a,
llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, a_k_stride >> 4)),
)
a_args = [a]
# Advance the B descriptor.
if i > 0:
b_descriptor = llvm.add(
b_descriptor,
llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, b_k_stride >> 4)),
)
assert len(a_args) == len(a_reg_constraints)
acc_struct = llvm.inline_asm(
acc_struct_type,
[*acc_regs, *a_args, b_descriptor, *imms],
ptx,
reg_constraints,
asm_dialect=0,
has_side_effects=True,
)
acc_regs = [
llvm.extractvalue(f32, acc_struct, [i]) for i in range(len(acc_regs))
]
acc_vec_regs = []
for first, second in zip(acc_regs[::2], acc_regs[1::2]):
vec = llvm.mlir_undef(ir.VectorType.get((2,), f32))
vec = llvm.insertelement(vec, first, position=lc(0))
vec = llvm.insertelement(vec, second, position=lc(1))
acc_vec_regs.append(vec)
return np.asarray(acc_vec_regs, dtype=object).reshape(acc.shape)
class WGMMALayout(enum.Enum):
ROW_MAJOR = enum.auto()
COL_MAJOR = enum.auto()
# TODO(apaszke): Remove WGMMALayout. Make input shapes logical and infer
# transpositions from memref strides.
def wgmma(
acc: WGMMAAccumulator,
a,
b,
*,
# Order only applies within each tile!
a_order: WGMMALayout | None = None,
b_order: WGMMALayout = WGMMALayout.ROW_MAJOR,
):
if a_in_regs := isinstance(a, mgpu.FragmentedArray):
a_element_type = a.mlir_dtype
a_shape = a.shape
else:
a_ty = ir.MemRefType(a.type)
a_element_type = a_ty.element_type
a_shape = a_ty.shape
b_ty = ir.MemRefType(b.type)
supported_types = {ir.F16Type.get(), ir.BF16Type.get(), ir.F32Type.get()}
if a_element_type not in supported_types:
raise ValueError(a_element_type)
if b_ty.element_type not in supported_types:
raise ValueError(b_ty.element_type)
if (element_type := a_element_type) != b_ty.element_type:
raise ValueError
element_bytewidth = bytewidth(element_type)
kn_tile = 128 // element_bytewidth
groups_k, groups_n = b_ty.shape[:2]
if b_ty.shape[2:] != [kn_tile, kn_tile]:
raise ValueError(b_ty.shape)
if a_in_regs:
if a_element_type != ir.F16Type.get() and a_element_type != ir.BF16Type.get():
raise ValueError(a_element_type)
if a_shape[0] % 64 or a_shape[1] % kn_tile:
raise ValueError(a_shape)
if a_shape[1] // kn_tile != groups_k:
raise ValueError(a_shape[1] // kn_tile, groups_k)
groups_m = a_shape[0] // 64
if a_order is not None:
raise ValueError(
"a_order can only be specified when A is in shared memory"
)
else:
groups_m = a_shape[0]
if a_shape[1] != groups_k:
raise ValueError(a_shape[1], groups_k)
if a_shape[2:] != [64, kn_tile]:
raise ValueError(a_shape)
if a_order is None:
a_order = WGMMALayout.ROW_MAJOR
row_major = WGMMALayout.ROW_MAJOR
col_major = WGMMALayout.COL_MAJOR
a_desc_fields = dict(
leading_byte_offset=((1 if a_order == row_major else 512) << 4),
stride_byte_offset=(64 << 4),
swizzle=128,
memory_space=3,
)
b_desc_fields = dict(
leading_byte_offset=((512 if b_order == row_major else 1) << 4),
stride_byte_offset=(64 << 4),
swizzle=128,
memory_space=3,
)
wgmma_params = dict(
a_transpose=a_order == col_major,
b_transpose=b_order == row_major,
a_k_stride=(2 if a_order == row_major else 128) * 16,
b_k_stride=(128 if b_order == row_major else 2) * 16,
n=(groups_n * kn_tile),
element_type=ir.FloatTF32Type.get()
if ir.F32Type.isinstance(element_type)
else element_type,
)
if a_in_regs:
wgmma_params["a_k_stride"] = wgmma_params["a_transpose"] = None
if a_in_regs:
nvvm.wgmma_fence_aligned() # Make sure the registers are ready.
a_m_byte_stride = a_k_byte_stride = a_desc_base = None # Silence pytype.
else:
a_desc_base = create_descriptor(a, **a_desc_fields)
a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
a_byte_strides = [s * element_bytewidth for s in a_strides]
a_m_byte_stride, a_k_byte_stride = a_byte_strides[:2]
if a_byte_strides[2:] != [128, element_bytewidth]:
raise ValueError(a_byte_strides)
b_desc_base = create_descriptor(b, **b_desc_fields)
b_strides, _ = b_ty.get_strides_and_offset()
b_byte_strides = [s * element_bytewidth for s in b_strides]
b_k_byte_stride = b_byte_strides[0]
if b_byte_strides[1:] != [128 * kn_tile, 128, element_bytewidth]:
raise ValueError(b_byte_strides)
i64 = ir.IntegerType.get_signless(64)
new_acc_regs = acc.value.registers.copy()
for mi in range(groups_m):
for ki in range(groups_k):
if a_in_regs:
a_mk = a[mi * 64 : (mi + 1) * 64, ki * kn_tile : (ki + 1) * kn_tile]
else:
a_mk = llvm.add(
a_desc_base,
c(wgmma_encode(mi * a_m_byte_stride + ki * a_k_byte_stride), i64),
)
b_k = llvm.add(b_desc_base, c(wgmma_encode(ki * b_k_byte_stride), i64))
new_acc_regs[mi : mi + 1] = wgmma_m64k128B(
new_acc_regs[mi : mi + 1], a_mk, b_k, **wgmma_params
)
return WGMMAAccumulator(
_value=mgpu.FragmentedArray(
_registers=new_acc_regs, _layout=mgpu.WGMMA_LAYOUT
),
_sync=False,
)
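# Illustrative sketch of the intended call sequence, mirroring the tests added
# in this change (`lhs_smem`/`rhs_smem` are tiled SMEM refs, `out` a GMEM ref):
#
#   acc = WGMMAAccumulator.zero(m, n)
#   acc = wgmma(acc, lhs_smem, rhs_smem)
#   nvvm.wgmma_commit_group_sync_aligned()
#   nvvm.wgmma_wait_group_sync_aligned(0)
#   acc.value.store_untiled(out)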

View File

@ -17,6 +17,7 @@
load("//jaxlib:symlink_files.bzl", "symlink_files")
load(
"//jaxlib:jax.bzl",
"if_building_mosaic_gpu",
"if_windows",
"py_library_providing_imports_info",
"pybind_extension",
@ -70,7 +71,7 @@ py_library_providing_imports_info(
"//jaxlib/mlir:vector_dialect",
"//jaxlib/mosaic",
"//jaxlib/triton",
],
] + if_building_mosaic_gpu(["//jaxlib/mosaic/gpu:mosaic_gpu"]),
)
symlink_files(

View File

@ -161,6 +161,12 @@ def if_building_jaxlib(if_building, if_not_building = []):
"//conditions:default": if_not_building,
})
def if_building_mosaic_gpu(if_building, if_not_building = []):
return select({
"//jax:enable_mosaic_gpu": if_building,
"//conditions:default": if_not_building,
})
# buildifier: disable=function-docstring
def jax_test(
name,

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("//jaxlib:symlink_files.bzl", "symlink_inputs")
load("//jaxlib:symlink_files.bzl", "symlink_files", "symlink_inputs")
package(
default_visibility = [
@ -217,3 +217,94 @@ symlink_inputs(
"//jaxlib/mlir/_mlir_libs:_stablehlo",
],
)
symlink_inputs(
name = "execution_engine",
rule = py_library,
symlinked_inputs = {"srcs": {
".": [
"@llvm-project//mlir/python:ExecutionEnginePyFiles",
],
}},
deps = [
":mlir",
"//jaxlib/mlir/_mlir_libs:_mlirExecutionEngine",
],
)
symlink_inputs(
name = "nvgpu_dialect",
rule = py_library,
symlinked_inputs = {"srcs": {"dialects": [
"@llvm-project//mlir/python:NVGPUOpsPyFiles",
]}},
deps = [
":core",
":ir",
":mlir",
],
)
symlink_inputs(
name = "nvvm_dialect",
rule = py_library,
symlinked_inputs = {"srcs": {"dialects": [
"@llvm-project//mlir/python:NVVMOpsPyFiles",
]}},
deps = [
":core",
":ir",
":mlir",
],
)
symlink_files(
name = "gpu_files",
srcs = ["@llvm-project//mlir/python:GPUOpsPyFiles"],
dst = "dialects",
flatten = True,
)
symlink_files(
name = "gpu_package_files",
srcs = ["@llvm-project//mlir/python:GPUOpsPackagePyFiles"],
dst = "dialects/gpu",
flatten = True,
)
symlink_files(
name = "gpu_package_passes_files",
srcs = ["@llvm-project//mlir/python:GPUOpsPackagePassesPyFiles"],
dst = "dialects/gpu/passes",
flatten = True,
)
py_library(
name = "gpu_dialect",
srcs = [
":gpu_files",
":gpu_package_files",
":gpu_package_passes_files",
],
deps = [
":core",
":ir",
":mlir",
"//jaxlib/mlir/_mlir_libs:_mlirGPUPasses",
],
)
symlink_inputs(
name = "llvm_dialect",
rule = py_library,
symlinked_inputs = {"srcs": {"dialects": [
"@llvm-project//mlir/python:LLVMOpsPyFiles",
]}},
deps = [
":core",
":ir",
":mlir",
"//jaxlib/mlir/_mlir_libs:_mlirDialectsLLVM",
],
)

View File

@ -14,6 +14,7 @@
load(
"//jaxlib:jax.bzl",
"if_building_mosaic_gpu",
"if_windows",
"py_extension",
"pybind_extension",
@ -58,6 +59,51 @@ py_extension(
],
)
py_extension(
name = "_mlirExecutionEngine",
srcs = [
"@llvm-project//mlir:lib/Bindings/Python/ExecutionEngineModule.cpp",
],
copts = COPTS,
linkopts = LINKOPTS,
deps = [
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPIExecutionEngineHeaders",
"@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps",
"@pybind11",
],
)
py_extension(
name = "_mlirGPUPasses",
srcs = [
"@llvm-project//mlir:lib/Bindings/Python/GPUPasses.cpp",
],
copts = COPTS,
linkopts = LINKOPTS,
deps = [
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPIGPUHeaders",
"@pybind11",
],
)
py_extension(
name = "_mlirDialectsLLVM",
srcs = [
"@llvm-project//mlir:lib/Bindings/Python/DialectLLVM.cpp",
],
copts = COPTS,
linkopts = LINKOPTS,
deps = [
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPIIRHeaders",
"@llvm-project//mlir:CAPILLVMHeaders",
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
"@pybind11",
],
)
py_extension(
name = "_mlirDialectsSparseTensor",
srcs = [
@ -148,11 +194,36 @@ symlink_inputs(
],
)
cc_library(
name = "jaxlib_mlir_capi_shims",
srcs = ["jaxlib_mlir_capi_shims.cc"],
hdrs = ["jaxlib_mlir_capi_shims.h"],
deps = [
"@llvm-project//mlir:BuiltinToLLVMIRTranslation",
"@llvm-project//mlir:CAPIIRHeaders",
"@llvm-project//mlir:GPUPipelines",
"@llvm-project//mlir:GPUToLLVMIRTranslation",
"@llvm-project//mlir:LLVMToLLVMIRTranslation",
"@llvm-project//mlir:MemRefTransforms",
"@llvm-project//mlir:NVVMTarget",
"@llvm-project//mlir:NVVMToLLVMIRTranslation",
],
alwayslink = 1,
)
cc_library(
name = "jaxlib_mlir_capi_shims_hdrs",
hdrs = ["jaxlib_mlir_capi_shims.h"],
deps = [
"@llvm-project//mlir:CAPIIRHeaders",
],
)
# JAX-specific registrations.
py_extension(
name = "register_jax_dialects",
srcs = ["register_jax_dialects.cc"],
copts = COPTS,
copts = COPTS + if_building_mosaic_gpu(["-DJAXLIB_MOSAIC_GPU"]),
linkopts = LINKOPTS,
deps = [
":jaxlib_mlir_capi_shared_library",
@ -166,7 +237,14 @@ py_extension(
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
"@local_config_python//:headers",
"@pybind11",
],
] + if_building_mosaic_gpu([
":jaxlib_mlir_capi_shims_hdrs",
"@llvm-project//mlir:CAPIGPUHeaders",
"@llvm-project//mlir:CAPINVGPUHeaders",
"@llvm-project//mlir:CAPINVVMHeaders",
"@llvm-project//mlir:CAPILLVMHeaders",
"@llvm-project//mlir:CAPIConversionHeaders",
]),
)
##---------------------------------------------------------------------------##
@ -270,7 +348,15 @@ cc_library(
[
"//jaxlib/triton:triton_dialect_capi_objects",
],
),
) + if_building_mosaic_gpu([
":jaxlib_mlir_capi_shims",
"@llvm-project//mlir:CAPIConversionObjects",
"@llvm-project//mlir:CAPIExecutionEngineObjects",
"@llvm-project//mlir:CAPIGPUObjects",
"@llvm-project//mlir:CAPILLVMObjects",
"@llvm-project//mlir:CAPINVGPUObjects",
"@llvm-project//mlir:CAPINVVMObjects",
]),
)
cc_binary(

View File

@ -0,0 +1,47 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi_shims.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/IR.h"
#include "mlir/Dialect/GPU/Pipelines/Passes.h"
#include "mlir/Target/LLVM/NVVM/Target.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
extern "C" {
void jaxMlirRegisterMemRefPasses() {
mlir::memref::registerMemRefPasses();
}
void jaxMlirRegisterInterfaceExternalModels(MlirDialectRegistry registry) {
mlir::NVVM::registerNVVMTargetInterfaceExternalModels(*unwrap(registry));
mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(
*unwrap(registry));
mlir::registerBuiltinDialectTranslation(*unwrap(registry));
mlir::registerGPUDialectTranslation(*unwrap(registry));
mlir::registerLLVMDialectTranslation(*unwrap(registry));
mlir::registerNVVMDialectTranslation(*unwrap(registry));
}
void jaxMlirRegisterGPUToNVVMPipeline() {
mlir::gpu::registerGPUToNVVMPipeline();
}
}

View File

@ -0,0 +1,34 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_MLIR_CAPI_SHIMS
#define JAXLIB_MLIR_CAPI_SHIMS
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#ifdef __cplusplus
extern "C" {
#endif
MLIR_CAPI_EXPORTED void jaxMlirRegisterMemRefPasses();
MLIR_CAPI_EXPORTED void jaxMlirRegisterInterfaceExternalModels(MlirDialectRegistry registry);
MLIR_CAPI_EXPORTED void jaxMlirRegisterGPUToNVVMPipeline();
#ifdef __cplusplus
}
#endif
#endif // JAXLIB_MLIR_CAPI_SHIMS

View File

@ -9,6 +9,14 @@
#include "mlir-c/Dialect/Vector.h"
#include "mlir-c/Transforms.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#ifdef JAXLIB_MOSAIC_GPU
#include "mlir-c/Dialect/GPU.h"
#include "mlir-c/Dialect/NVGPU.h"
#include "mlir-c/Dialect/NVVM.h"
#include "mlir-c/Dialect/LLVM.h"
#include "mlir-c/Conversion.h"
#include "jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi_shims.h"
#endif
#define REGISTER_DIALECT(name) \
MlirDialectHandle name##_dialect = mlirGetDialectHandle__##name##__(); \
@ -27,5 +35,17 @@ PYBIND11_MODULE(register_jax_dialects, m) {
mlirRegisterTransformsPasses();
// Transforms used by JAX.
mlirRegisterTransformsStripDebugInfo();
#ifdef JAXLIB_MOSAIC_GPU
REGISTER_DIALECT(gpu);
REGISTER_DIALECT(nvgpu);
REGISTER_DIALECT(nvvm);
REGISTER_DIALECT(llvm);
mlirRegisterGPUPasses();
mlirRegisterConversionPasses();
// TODO(apaszke): Upstream and remove those.
jaxMlirRegisterMemRefPasses();
jaxMlirRegisterInterfaceExternalModels(registry);
jaxMlirRegisterGPUToNVVMPipeline();
#endif
});
}

76
jaxlib/mosaic/gpu/BUILD Normal file
View File

@ -0,0 +1,76 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@rules_python//python:defs.bzl", "py_library")
load("//jaxlib:jax.bzl", "pybind_extension")
package(
default_applicable_licenses = [],
default_visibility = ["//:__subpackages__"],
)
py_library(
name = "mosaic_gpu",
data = [":libmlir_cuda_runtime.so"],
deps = [
":_mosaic_gpu_ext",
"//jaxlib/mlir:execution_engine",
"//jaxlib/mlir:gpu_dialect",
"//jaxlib/mlir:llvm_dialect",
"//jaxlib/mlir:nvgpu_dialect",
"//jaxlib/mlir:nvvm_dialect",
],
)
pybind_extension(
name = "_mosaic_gpu_ext",
srcs = ["_mosaic_gpu_ext.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
linkopts = select({
"@xla//xla/python:use_jax_cuda_pip_rpaths": [
"-Wl,-rpath,$$ORIGIN/../../../nvidia/cuda_runtime/lib",
],
"//conditions:default": [],
}),
deps = [
"//jaxlib:kernel_nanobind_helpers",
"@xla//xla/service:custom_call_status",
"@nanobind",
],
)
cc_binary(
name = "libmlir_cuda_runtime.so",
srcs = ["@llvm-project//mlir:lib/ExecutionEngine/CudaRuntimeWrappers.cpp"],
copts = ["-fvisibility=default"],
linkopts = select({
"@xla//xla/python:use_jax_cuda_pip_rpaths": [
"-Wl,-rpath,$$ORIGIN/../../../nvidia/cuda_runtime/lib",
],
"//conditions:default": [],
}),
linkshared = 1,
tags = [
"manual",
"notap",
],
deps = [
"@llvm-project//mlir:mlir_c_runner_utils_hdrs",
"@xla//xla/tsl/cuda:cudart",
"@local_config_cuda//cuda:cuda_headers",
],
)

View File

@ -0,0 +1,24 @@
#include "nanobind/nanobind.h"
#include "jaxlib/kernel_nanobind_helpers.h"
#include "xla/service/custom_call_status.h"
namespace jax::cuda {
namespace {
namespace nb = nanobind;
using MosaicHostFunc = void(void**);
void MosaicKernelCall(void* stream, void** buffers, char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
void* args[2] = {&stream, &buffers};
MosaicHostFunc* func = *reinterpret_cast<MosaicHostFunc**>(opaque);
func(args);
}
NB_MODULE(_mosaic_gpu_ext, m) {
m.def("_custom_call_capsule",
[]() { return EncapsulateFunction(MosaicKernelCall); });
}
} // namespace
} // namespace jax::cuda

View File

@ -98,10 +98,13 @@ setup(
'cuda/*',
'cuda/nvvm/libdevice/libdevice*',
'mosaic/*.py',
'mosaic/gpu/*.so',
'mosaic/python/*.py',
'mosaic/python/*.so',
'mlir/*.py',
'mlir/dialects/*.py',
'mlir/dialects/gpu/*.py',
'mlir/dialects/gpu/passes/*.py',
'mlir/extras/*.py',
'mlir/_mlir_libs/*.dll',
'mlir/_mlir_libs/*.dylib',

View File

@ -266,12 +266,27 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi
"__main__/jaxlib/mosaic/python/_tpu_gen.py", dst_dir=mosaic_python_dir
)
has_mosaic_gpu = exists(f"__main__/jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}")
def if_has_mosaic_gpu(extras):
return extras if has_mosaic_gpu else []
if has_mosaic_gpu:
copy_runfiles(
dst_dir=jaxlib_dir / "mosaic" / "gpu",
src_files=[
"__main__/jaxlib/mosaic/gpu/libmlir_cuda_runtime.so",
f"__main__/jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}",
],
)
copy_runfiles(
dst_dir=jaxlib_dir / "mlir",
src_files=[
"__main__/jaxlib/mlir/ir.py",
"__main__/jaxlib/mlir/passmanager.py",
],
] + if_has_mosaic_gpu([
"__main__/jaxlib/mlir/execution_engine.py",
]),
)
copy_runfiles(
dst_dir=jaxlib_dir / "mlir" / "dialects",
@ -302,7 +317,19 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi
"__main__/jaxlib/mlir/dialects/sparse_tensor.py",
"__main__/jaxlib/mlir/dialects/stablehlo.py",
"__main__/jaxlib/mlir/dialects/vector.py",
],
] + if_has_mosaic_gpu([
"__main__/jaxlib/mlir/dialects/_gpu_enum_gen.py",
"__main__/jaxlib/mlir/dialects/_gpu_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_nvgpu_enum_gen.py",
"__main__/jaxlib/mlir/dialects/_nvgpu_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_nvvm_enum_gen.py",
"__main__/jaxlib/mlir/dialects/_nvvm_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_llvm_enum_gen.py",
"__main__/jaxlib/mlir/dialects/_llvm_ops_gen.py",
"__main__/jaxlib/mlir/dialects/nvgpu.py",
"__main__/jaxlib/mlir/dialects/nvvm.py",
"__main__/jaxlib/mlir/dialects/llvm.py",
]),
)
copy_runfiles(
dst_dir=jaxlib_dir / "mlir" / "extras",
@ -310,6 +337,20 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi
"__main__/jaxlib/mlir/extras/meta.py",
],
)
if has_mosaic_gpu:
copy_runfiles(
dst_dir=jaxlib_dir / "mlir" / "dialects" / "gpu",
src_files=[
"__main__/jaxlib/mlir/dialects/gpu/__init__.py",
],
)
copy_runfiles(
dst_dir=jaxlib_dir / "mlir" / "dialects" / "gpu" / "passes",
src_files=[
"__main__/jaxlib/mlir/dialects/gpu/passes/__init__.py",
],
)
if build_utils.is_windows():
capi_so = "__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll"
@ -339,7 +380,10 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi
f"__main__/jaxlib/mlir/_mlir_libs/_triton_ext.{pyext}",
"__main__/jaxlib/mlir/_mlir_libs/_triton_ext.pyi",
]
),
) + if_has_mosaic_gpu([
f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsLLVM.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_mlirExecutionEngine.{pyext}",
]),
)
triton_dir = jaxlib_dir / "triton"

52
tests/mosaic/BUILD Normal file
View File

@ -0,0 +1,52 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load(
"//jaxlib:jax.bzl",
"jax_generate_backend_suites",
"jax_test",
"py_deps",
)
licenses(["notice"])
package(
default_applicable_licenses = [],
default_visibility = ["//visibility:private"],
)
jax_generate_backend_suites()
jax_test(
name = "gpu_test",
srcs = [
"gpu_test.py",
],
disable_backends = [
"cpu",
"tpu",
],
disable_configs = [
"gpu",
"gpu_a100",
"gpu_p100",
"gpu_p100_x32",
"gpu_x32",
"gpu_pjrt_c_api",
],
shard_count = 4,
deps = [
"//jax:mosaic_gpu",
] + py_deps("absl/testing") + py_deps("numpy"),
)

803
tests/mosaic/gpu_test.py Normal file
View File

@ -0,0 +1,803 @@
# Copyright 2024 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Mosaic GPU DSL functions and utilities."""
import operator
from typing import Optional
from absl.testing import absltest, parameterized
import numpy as np
import jax
import jax.numpy as jnp
from jax._src import config
from jax._src import test_util as jtu
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector
try:
import jax._src.lib.mosaic_gpu # noqa: F401
HAS_MOSAIC_GPU = True
except ImportError:
HAS_MOSAIC_GPU = False
else:
from jax.experimental.mosaic import gpu as mosaic_gpu
from jax.experimental.mosaic.gpu import dsl as mgpu
from jax.experimental.mosaic.gpu.utils import * # noqa: F403
from jax._src.lib.mlir.dialects import gpu
from jax._src.lib.mlir.dialects import llvm
# ruff: noqa: F405
config.update("jax_traceback_filtering", "off")
config.parse_flags_with_absl()
def nd_loop(bounds, body, *, _idxs = ()):
if not bounds:
body(*_idxs)
return
bound, *other_bounds = bounds
@fori(bound, ())
def _loop_body(i, _):
nd_loop(other_bounds, body, _idxs=(*_idxs, i))
return ()
def mlir_sum(elems):
assert elems
total = elems[0]
for elem in elems[1:]:
total = arith.addi(total, elem)
return total
def copy(src: ir.Value, dst: ir.Value, swizzle: Optional[int] = None):
index = ir.IndexType.get()
thread_id = gpu.thread_id(gpu.Dimension.x)
stride = gpu.block_dim(gpu.Dimension.x)
for dim in (gpu.Dimension.y, gpu.Dimension.z):
thread_id = arith.addi(thread_id, arith.muli(gpu.thread_id(dim), stride))
stride = arith.muli(stride, gpu.block_dim(dim))
is_first_thread = arith.cmpi(arith.CmpIPredicate.eq, thread_id, c(0, index))
src_ty = ir.MemRefType(src.type)
dst_ty = ir.MemRefType(dst.type)
if src_ty.shape != dst_ty.shape:
raise ValueError(
f"src and dst shapes don't match: {src_ty.shape} != {dst_ty.shape}"
)
shape = src_ty.shape
dyn_strides = [c(s, index) for s in get_contiguous_strides(shape)]
with ir.InsertionPoint(scf.IfOp(is_first_thread).then_block):
def body(*idx):
dst_idx = idx
if swizzle is not None:
if swizzle != 128:
raise NotImplementedError("Only swizzle 128B implemented")
# TODO(apaszke): This can probably be cleaned up.
# But it works and it's test-only, so it doesn't matter much.
# After all, swizzle should just be an xor of row and linear idx,
# adjusted for the bytewidth.
bytes_per_element = bytewidth(src_ty.element_type)
elems_per_tile = 1024 // bytes_per_element
elems_per_row = elems_per_tile // 8
elems_per_group = 16 // bytes_per_element
linear_idx = c(0, index)
for stride, i in zip(dyn_strides, idx):
linear_idx = arith.addi(linear_idx, arith.muli(i, stride))
tile_offset = arith.remui(linear_idx, c(elems_per_tile, index))
linear_tile_start = arith.subi(linear_idx, tile_offset)
row = arith.divui(tile_offset, c(elems_per_row, index))
row_offset = arith.remui(tile_offset, c(elems_per_row, index))
src_group = arith.divui(row_offset, c(elems_per_group, index))
group_offset = arith.remui(row_offset, c(elems_per_group, index))
dst_group = arith.xori(src_group, row)
dst_linear_idx = mlir_sum([
linear_tile_start,
arith.muli(row, c(elems_per_row, index)),
arith.muli(dst_group, c(elems_per_group, index)),
group_offset,
])
dst_idx = [
arith.remui(arith.divui(dst_linear_idx, stride), c(bound, index))
for stride, bound in zip(dyn_strides, shape)
]
memref.store(memref.load(src, idx), dst, dst_idx)
nd_loop([c(d, index) for d in shape], body)
scf.yield_([])
gpu.barrier()
nvvm.fence_proxy(nvvm.ProxyKind.async_)
def iota_tensor(m, n, mlir_dtype):
assert m % 64 == 0
assert n % 8 == 0
def c(i):
return arith.constant(index, ir.IntegerAttr.get(index, i))
index = ir.IndexType.get()
i32 = ir.IntegerType.get_signless(32)
warp_id = arith.divui(gpu.thread_id(gpu.Dimension.x), c(32))
within_warp_id = arith.remui(gpu.thread_id(gpu.Dimension.x), c(32))
warp_row_start = arith.muli(warp_id, c(16))
within_warp_row = arith.divui(within_warp_id, c(4))
start_row = arith.addi(warp_row_start, within_warp_row)
start_col = arith.muli(arith.remui(within_warp_id, c(4)), c(2))
registers = np.empty((m // 64, n // 8, 2, 1), dtype=object)
for row_tile, col_tile, row_subtile, _ in np.ndindex(registers.shape):
row = arith.addi(start_row, c(row_tile * 64 + row_subtile * 8))
col = arith.addi(start_col, c(col_tile * 8))
row_value_base = arith.muli(row, c(n))
vec = llvm.mlir_undef(ir.VectorType.get((2,), i32))
for col_offset in range(2):
value = arith.addi(row_value_base, arith.addi(c(col_offset), col))
value = arith.index_cast(i32, value)
vec = vector.insertelement(value, vec, position=c(col_offset))
registers[row_tile, col_tile, row_subtile, 0] = vec
t = mgpu.FragmentedArray(_registers=registers, _layout=mgpu.WGMMA_LAYOUT)
return t.astype(mlir_dtype)
class TestCase(parameterized.TestCase):
def setUp(self):
if not HAS_MOSAIC_GPU:
self.skipTest("jaxlib built without Mosaic GPU")
super().setUp()
self.prng = np.random.default_rng(1234)
self.ctx = mlir.make_ir_context()
self.ctx.__enter__()
self.loc = ir.Location.unknown()
self.loc.__enter__()
def tearDown(self):
self.loc.__exit__(None, None, None)
self.ctx.__exit__(None, None, None)
del self.loc, self.ctx
super().tearDown()
class TestUtilTest(TestCase):
def test_copy(self):
def kernel(ctx, src, dst, _):
copy(src, dst)
x = jnp.arange(2 * 3 * 5).reshape(2, 5, 3)
y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, ())(x)
np.testing.assert_array_equal(y, x)
def test_copy_swizzle(self):
def kernel(ctx, src, dst, _):
copy(src, dst, swizzle=128)
x = jnp.arange(8 * 32, dtype=jnp.float32).reshape(8, 32)
y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, ())(x)
expected = np.zeros_like(y)
for i in range(8):
for j in range(8):
js = j ^ i
expected[i, (j * 4):(j * 4) + 4] = x[i, (js * 4):(js * 4) + 4]
np.testing.assert_array_equal(y, expected)
def test_copy_swizzle_noop(self):
# Two swizzles cancel out
def kernel(ctx, src, dst, smem):
copy(src, smem, swizzle=128)
copy(smem, dst, swizzle=128)
x = jnp.arange(8 * 32, dtype=jnp.float32).reshape(8, 32)
y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x)
np.testing.assert_array_equal(y, x)
def test_iota_tensor(self):
m = n = 64
def kernel(ctx, dst, _):
f32 = ir.F32Type.get()
index = ir.IndexType.get()
registers = iota_tensor(m, n, f32).registers
assert registers.size == 16, registers.size
for i, vec_reg in enumerate(registers.flat):
for j in range(2):
reg = vector.extractelement(vec_reg, position=c(j, index))
memref.store(
reg, dst, [gpu.thread_id(gpu.Dimension.x), c(2 * i + j, index)]
)
out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32)
regs = mosaic_gpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
)()
thread_ids = np.arange(128)
warp_ids = thread_ids // 32
lane_ids = thread_ids % 32
thread_rows = warp_ids * 16 + lane_ids // 4
thread_start_cols = (lane_ids % 4) * 2
thread_cols = thread_start_cols[:, None] + (np.arange(n // 8)[None] * 8)
regs = regs.reshape(128, 8, 2, 2)
for row_half in range(2):
for col_half in range(2):
np.testing.assert_array_equal(
regs[..., row_half, col_half],
(thread_rows[:, None] + row_half * 8) * n + thread_cols + col_half
)
class MemRefTest(TestCase):
@parameterized.product(
dim=tuple(range(3)),
strided=(False, True)
)
def test_unsqueeze(self, dim, strided):
def kernel(ctx, inp, out, _):
if strided:
for i in range(8):
s = ds(i, 1)
out_slice = s if dim != 0 else (slice(None), s)
copy(
memref_unsqueeze(memref_slice(inp, s), dim),
memref_slice(out, out_slice),
)
else:
copy(memref_unsqueeze(inp, dim), out)
x = np.arange(8 * 16, dtype=jnp.float32).reshape(8, 16)
out_shape = list(x.shape)
out_shape.insert(dim, 1)
out_ty = jax.ShapeDtypeStruct(out_shape, jnp.float32)
y = mosaic_gpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), x, out_ty, ()
)(x)
np.testing.assert_array_equal(y, x.reshape(out_shape))
@parameterized.product(
dim=tuple(range(2)),
strided=(False, True)
)
def test_unfold(self, dim, strided):
in_shape = (8, 16)
def kernel(ctx, inp, out, _):
if strided:
# We slice the dim we don't unfold
for i in range(in_shape[1 - dim] // 4):
s = ds(i * 4, 4)
in_slice = s if dim == 1 else (slice(None), s)
out_slice = s if dim == 1 else (slice(None),) * 3 + (s,)
copy(
memref_unfold(memref_slice(inp, in_slice), dim, (2, 2, None)),
memref_slice(out, out_slice),
)
else:
copy(memref_unfold(inp, dim, (2, 2, None)), out)
x = np.arange(np.prod(in_shape), dtype=jnp.float32).reshape(in_shape)
out_shape = list(in_shape)
out_shape[dim:dim + 1] = [2, 2, out_shape[dim] // 4]
out_ty = jax.ShapeDtypeStruct(out_shape, jnp.float32)
y = mosaic_gpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), x, out_ty, ()
)(x)
np.testing.assert_array_equal(y, x.reshape(out_ty.shape))
@parameterized.product(
dim=tuple(range(2)),
)
def test_fold_not_strided(self, dim):
def kernel(ctx, inp, out, _):
copy(memref_fold(inp, dim, 2), out)
x = np.arange(8 * 2 * 8, dtype=jnp.float32).reshape(8, 2, 8)
out_ty = jax.ShapeDtypeStruct((16, 8) if dim == 0 else (8, 16), jnp.float32)
y = mosaic_gpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), x, out_ty, ()
)(x)
np.testing.assert_array_equal(y, x.reshape(out_ty.shape))
@parameterized.named_parameters([
("packed", (4, 4, 4), (16, 4, 1), 1, 2, False),
("strided_end", (4, 4, 4, 4), (256, 64, 16, 4), 1, 2, False),
("strided_bot", (4, 4, 4, 4), (256, 16, 4, 1), 1, 2, False),
("strided_top", (4, 4, 4, 4), (256, 64, 4, 1), 1, 2, True),
("strided_mid", (4, 4, 4, 4), (265, 64, 16, 1), 1, 3, True),
("overap", (2, 4, 4), (16, 1, 1), 0, 3, True),
])
def test_fold_strided(
self, shape, strides, dim, fold_rank, throws_not_impl
):
expanded_shape = get_packed_shape(strides, shape)
total_size = np.prod(expanded_shape)
np_inp = np.arange(total_size, dtype=jnp.float32).reshape(expanded_shape)
index = tuple([slice(0, s) for s in shape])
# Reference implementation
def np_fold(inp, dim, fold_rank):
out_shape = list(inp.shape)
out_shape[dim : dim + fold_rank] = [
int(np.prod(inp.shape[dim : dim + fold_rank]))
]
if throws_not_impl:
return jax.ShapeDtypeStruct(shape=out_shape, dtype=inp.dtype)
else:
return inp.reshape(*out_shape)
total_size = np.prod(shape) * np.prod(strides)
def do_test():
def kernel(ctx, inp, out, _):
copy(memref_fold(memref_slice(inp, index), dim, fold_rank), out)
out = np_fold(np_inp[index], dim, fold_rank)
y = mosaic_gpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), np_inp, out, ()
)(np_inp)
assert (
not throws_not_impl
), "If it should have thrown it would during the call."
np.testing.assert_array_equal(y, out)
if throws_not_impl:
with self.assertRaises(NotImplementedError):
do_test()
else:
do_test()
def get_packed_shape(strides, shape):
perm = sorted(range(len(strides)), key=lambda i: strides[i], reverse=True)
ordered_strides = [strides[i] for i in perm]
ordered_shape = [shape[i] for i in perm]
packed_shape = [ordered_shape[-1]]
packed_shape += [
stride0 // stride
for stride0, stride in zip(ordered_strides, ordered_strides[1:])
]
# Invert permutation
inv_perm = [None] * len(perm)
for i, p in enumerate(perm):
inv_perm[p] = i
return [packed_shape[i] for i in inv_perm]
class WGMMATest(TestCase):
@parameterized.named_parameters(
("f32", ir.F32Type, jnp.float32), ("f16", ir.F16Type, jnp.float16)
)
def test_store_untiled(self, mlir_dtype_cls, jax_dtype):
mlir_dtype = mlir_dtype_cls.get()
def kernel(ctx, out, _):
del ctx
iota_tensor(64, 64, mlir_dtype).store_untiled(out)
expected = np.arange(64 * 64, dtype=jax_dtype).reshape(64, 64)
iota = mosaic_gpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), expected, ()
)()
np.testing.assert_array_equal(iota, expected)
@parameterized.named_parameters(
("f32", ir.F32Type, jnp.float32),
("f16", ir.F16Type, jnp.float16),
)
def test_store_tiled(self, mlir_dtype_cls, jax_dtype):
mlir_dtype = mlir_dtype_cls.get()
m = 128
n = 256
tiling = (64, 128 // bytewidth(mlir_dtype))
def kernel(ctx, out, smem):
del ctx
iota_tensor(m, n, mlir_dtype).store_tiled(smem, swizzle=128)
copy(smem, out, swizzle=128)
expected = (
np.arange(m * n, dtype=jax_dtype)
.reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1])
.transpose(0, 2, 1, 3)
)
iota = mosaic_gpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), expected, expected
)()
np.testing.assert_array_equal(iota, expected)
@parameterized.product(
lhs_transpose=(False, True),
rhs_transpose=(False, True),
mlir_dtype_cls=(ir.F16Type, ir.BF16Type, ir.F32Type),
m=(64, 128, 192),
n=(32, 64, 128, 192),
k_steps=(1, 2),
tma_inputs=(False, True),
)
def test_wgmma(
self,
m,
n,
k_steps,
mlir_dtype_cls,
lhs_transpose,
rhs_transpose,
tma_inputs,
):
mlir_dtype = mlir_dtype_cls.get()
if ir.F32Type.isinstance(mlir_dtype): # We actually use tf32 instead
jax_dtype = jnp.float32
if lhs_transpose or not rhs_transpose:
self.skipTest("Transpose only supported in 16-bit WGMMA")
exponent_bits, mantissa_bits = 8, 10 # Use tf32
elif bytewidth(mlir_dtype) == 2:
if n % 64 != 0:
self.skipTest("16-bit WGMMA only supports n % 64 == 0")
if ir.F16Type.isinstance(mlir_dtype):
jax_dtype = jnp.float16
exponent_bits, mantissa_bits = 5, 10
elif ir.BF16Type.isinstance(mlir_dtype):
jax_dtype = jnp.bfloat16
exponent_bits, mantissa_bits = 8, 7
else:
raise NotImplementedError(mlir_dtype)
else:
raise NotImplementedError(mlir_dtype)
nk_tile = 128 // bytewidth(mlir_dtype)
k = nk_tile * k_steps
assert m % 64 == 0 and n % nk_tile == 0
index = ir.IndexType.get()
row_major = mgpu.WGMMALayout.ROW_MAJOR
col_major = mgpu.WGMMALayout.COL_MAJOR
lhs_order = col_major if lhs_transpose else row_major
rhs_order = col_major if rhs_transpose else row_major
def kernel(ctx, lhs, rhs, out, scratch):
lhs_smem, rhs_smem = scratch
if tma_inputs:
lhs_transform = (mosaic_gpu.TileTransform((64, nk_tile)),)
if lhs_transpose:
assert nk_tile == 64 # Make sure we didn't have to transpose tiling.
lhs_transform += (mosaic_gpu.TransposeTransform((1, 0, 2, 3)),)
rhs_transform = (mosaic_gpu.TileTransform((nk_tile, nk_tile)),)
if rhs_transpose:
rhs_transform += (mosaic_gpu.TransposeTransform((1, 0, 2, 3)),)
barriers = BarrierArray(2)
ctx.async_copy(
src_ref=lhs,
dst_ref=lhs_smem,
swizzle=128,
gmem_transform=lhs_transform,
barrier=barriers[0],
)
ctx.async_copy(
src_ref=rhs,
dst_ref=rhs_smem,
swizzle=128,
gmem_transform=rhs_transform,
barrier=barriers[1],
)
for i in range(2):
barriers[i].wait()
else:
for mi in range(m // 64):
for ki in range(k // nk_tile):
lhs_slice = (
ds(c(mi * 64, index), 64),
ds(c(ki * nk_tile, index), nk_tile),
)
if lhs_transpose:
lhs_slice = lhs_slice[::-1]
copy(
src=memref_slice(lhs, lhs_slice),
dst=memref_slice(lhs_smem, (mi, ki)),
swizzle=128,
)
for ki in range(k // nk_tile):
k_slice = ds(c(ki * nk_tile, index), nk_tile)
for ni in range(n // nk_tile):
rhs_slice = (k_slice, ds(c(ni * nk_tile, index), nk_tile))
if rhs_transpose:
rhs_slice = rhs_slice[::-1]
copy(
src=memref_slice(rhs, rhs_slice),
dst=memref_slice(rhs_smem, (ki, ni)),
swizzle=128,
)
init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n)
acc = mgpu.wgmma(
init_acc, lhs_smem, rhs_smem,
a_order=lhs_order, b_order=rhs_order,
)
nvvm.wgmma_commit_group_sync_aligned()
nvvm.wgmma_wait_group_sync_aligned(0)
acc.value.store_untiled(out)
def quantize(x):
# Quantize the input to avoid rounding when feeding the WGMMA
return jax.lax.reduce_precision(x, exponent_bits, mantissa_bits)
x_shape = (k, m) if lhs_transpose else (m, k)
x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(jax_dtype)
y_shape = (n, k) if rhs_transpose else (k, n)
y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(jax_dtype)
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
scratch_shape = [
jax.ShapeDtypeStruct((m // 64, k // nk_tile, 64, nk_tile), jax_dtype),
jax.ShapeDtypeStruct(
(k // nk_tile, n // nk_tile, nk_tile, nk_tile), jax_dtype
),
]
z = mosaic_gpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape
)(x, y)
x32, y32 = x.astype(np.float32), y.astype(np.float32)
ref = (x32.T if lhs_transpose else x32) @ (y32.T if rhs_transpose else y32)
np.testing.assert_allclose(z, ref, atol=5e-6)
# TODO(apaszke): Add support for f32
@parameterized.product(
m=(64, 128, 192),
n=(64, 128, 192),
k_steps=(1, 2),
rhs_transpose=(False, True),
mlir_dtype_cls=(ir.F16Type, ir.BF16Type),
)
def test_wgmma_reg_lhs(self, m, n, k_steps, rhs_transpose, mlir_dtype_cls):
k = 64 * k_steps
index = ir.IndexType.get()
row_major = mgpu.WGMMALayout.ROW_MAJOR
col_major = mgpu.WGMMALayout.COL_MAJOR
rhs_order = col_major if rhs_transpose else row_major
def kernel(ctx, rhs, out, rhs_smem):
del ctx
for ki in range(k_steps):
for ni in range(n // 64):
rhs_slice = (ds(c(ki * 64, index), 64), ds(c(ni * 64, index), 64))
if rhs_transpose:
rhs_slice = rhs_slice[::-1]
copy(
src=memref_slice(rhs, rhs_slice),
dst=memref_slice(rhs_smem, (ki, ni)),
swizzle=128,
)
init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n)
lhs_regs = iota_tensor(m, k, mlir_dtype_cls.get())
acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, b_order=rhs_order)
nvvm.wgmma_commit_group_sync_aligned()
nvvm.wgmma_wait_group_sync_aligned(0)
acc.value.store_untiled(out)
jax_dtype = jnp.float16 if mlir_dtype_cls == ir.F16Type else jnp.bfloat16
y_shape = (n, k) if rhs_transpose else (k, n)
y = self.prng.uniform(-1, 1, y_shape).astype(jax_dtype)
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
scratch_shape = jax.ShapeDtypeStruct(
(k_steps, n // 64, 64, 64), jax_dtype
)
z = mosaic_gpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), y, out_shape, scratch_shape
)(y)
x = np.arange(m * k, dtype=jax_dtype).reshape(m, k)
ref = jax.lax.dot(
x, (y.T if rhs_transpose else y), preferred_element_type=jnp.float32
)
rtol = 0 if k_steps == 1 else 2.2e-4
np.testing.assert_allclose(z, ref, rtol=rtol, atol=0)
class TMATest(TestCase):
@parameterized.product(
swizzle=(None, 128),
shape=((64, 64), (5, 64), (2, 3, 5, 64)),
dtype=(jnp.float16, jnp.float32),
)
def test_tma_load(self, swizzle, shape, dtype):
if dtype == jnp.float32:
shape = (*shape[:-1], shape[-1] // 2)
i1 = ir.IntegerType.get_signless(1)
def kernel(ctx, src, dst, tmp):
barrier = BarrierArray(1)[0]
ctx.async_copy(src_ref=src, dst_ref=tmp, swizzle=swizzle, barrier=barrier)
barrier.wait_parity(c(0, i1))
copy(tmp, dst, swizzle=swizzle)
x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x)
np.testing.assert_array_equal(y, x)
@parameterized.product(
swizzle=(None, 128),
shape=((128, 128), (5, 32, 128)),
dtype=(jnp.float16, jnp.float32),
)
def test_tma_load_tiled(self, swizzle, shape, dtype):
i1 = ir.IntegerType.get_signless(1)
index = ir.IndexType.get()
tiling = (32, 128 // jnp.dtype(dtype).itemsize)
tiled_shape = tile_shape(shape, tiling)[:len(shape)]
def kernel(ctx, src, dst, tmp):
barrier = BarrierArray(1)[0]
ctx.async_copy(
src_ref=src,
dst_ref=tmp,
swizzle=swizzle,
barrier=barrier,
gmem_transform=mosaic_gpu.TileTransform(tiling),
)
barrier.wait_parity(c(0, i1))
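# Copy each SMEM tile back into its corresponding slice of the GMEM output.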
for idxs in np.ndindex(tiled_shape):
untiled_idxs, tiled_idxs = idxs[:-len(tiling)], idxs[-len(tiling):]
s = (
*untiled_idxs,
*(ds(c(ix * t, index), t) for ix, t in zip(tiled_idxs, tiling)),
)
copy(memref_slice(tmp, idxs), memref_slice(dst, s), swizzle=swizzle)
x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
smem = jax.ShapeDtypeStruct(tile_shape(shape, tiling), dtype)
f = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem)
y = f(x)
np.testing.assert_array_equal(y, x)
@parameterized.product(
swizzle=(None, 128),
dtype=(jnp.float16, jnp.float32),
)
def test_tma_squeeze_indexing(self, swizzle, dtype):
shape = (4, 5, 64)
if dtype == jnp.float32:
shape = (*shape[:-1], shape[-1] // 2)
def kernel(ctx, src, dst, tmp):
barrier = BarrierArray(1)[0]
for i in range(4):
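# gmem_slice=i indexes (and squeezes) the leading GMEM dimension; the SMEM destination is sliced explicitly.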
ctx.async_copy(
src_ref=src,
dst_ref=memref_slice(tmp, i),
gmem_slice=i,
swizzle=swizzle,
barrier=barrier,
)
barrier.wait()
copy(tmp, dst, swizzle=swizzle)
x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x)
np.testing.assert_array_equal(y, x)
def test_parity_tracking(self):
shape = (16, 64)
index = ir.IndexType.get()
def kernel(ctx, src, dst, tmp):
barrier = BarrierArray(1)[0]
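# One barrier is reused across all iterations, so wait() has to track the phase (parity) flip on every reuse.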
for i in range(shape[0]):
s = ds(c(i, index), 1)
ctx.async_copy(
src_ref=src, dst_ref=tmp, gmem_slice=s, barrier=barrier,
)
barrier.wait()
copy(tmp, memref_slice(dst, s))
x = np.arange(np.prod(shape), dtype=jnp.float16).reshape(shape)
y = mosaic_gpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), x, x, x[0:1]
)(x)
np.testing.assert_array_equal(y, x)
@parameterized.product(
swizzle=(None, 128),
shape=((64, 64), (5, 64), (2, 3, 5, 64)),
dtype=(jnp.float16, jnp.float32),
)
def test_tma_store(self, swizzle, shape, dtype):
if dtype == jnp.float32:
shape = (*shape[:-1], shape[-1] // 2)
def kernel(ctx, src, dst, tmp):
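# Copy GMEM -> SMEM, then issue a TMA store back to GMEM and wait for it to complete.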
copy(src, tmp, swizzle=swizzle)
ctx.async_copy(src_ref=tmp, dst_ref=dst, swizzle=swizzle)
ctx.await_async_copy(0)
x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x)
np.testing.assert_array_equal(y, x)
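
# FragmentedArray tests: elementwise ops, reductions, and layout handling for register-resident arrays.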
class FragmentedArrayTest(TestCase):
@parameterized.product(
op=(
operator.add,
operator.mul,
operator.sub,
operator.truediv,
(lambda x, y: mgpu.FragmentedArray.max(x, y), np.maximum),
),
m=(64, 128),
n=(8, 16, 32, 64, 80, 128, 256),
)
def test_binary(self, op, m=64, n=32):
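# Ops are given either as a plain operator or as an (mgpu_op, numpy_reference) pair.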
if isinstance(op, tuple):
op, np_op = op
else:
np_op = op
def kernel(ctx, dst, _):
f32 = ir.F32Type.get()
iota = iota_tensor(m=m, n=n, mlir_dtype=f32)
op(iota, iota).store_untiled(dst)
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
result = mosaic_gpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
)()
x = np.arange(m * n, dtype=jnp.float32).reshape(m, n)
if op == operator.truediv:
np.testing.assert_allclose(result, np_op(x, x), atol=2e-7)
else:
np.testing.assert_array_equal(result, np_op(x, x))
@parameterized.product(
ops=((lambda x: mgpu.FragmentedArray.exp(x), np.exp),),
m=(64, 128),
n=(8, 16, 32, 64, 80, 128, 256),
)
def test_unary(self, ops, m=64, n=32):
op, np_op = ops
def kernel(ctx, dst, _):
f32 = ir.F32Type.get()
iota = iota_tensor(m=m, n=n, mlir_dtype=f32)
op(iota).store_untiled(dst)
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
result = mosaic_gpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
)()
x = np.arange(m * n, dtype=jnp.float32).reshape(m, n)
np.testing.assert_allclose(result, np_op(x), atol=2e-7, rtol=2e-7)
@parameterized.product(
op=(arith.addf, arith.maximumf),
m=(64, 128),
n=(8, 16, 32, 64, 80, 128, 256),
)
def test_reduce(self, op, m=64, n=32):
def kernel(ctx, dst, _):
f32 = ir.F32Type.get()
iota = iota_tensor(m=m, n=n, mlir_dtype=f32)
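# Reduce over the minor axis, then broadcast the per-row result back to (m, n) before storing.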
iota.reduce(op, axis=1).broadcast_minor(n).store_untiled(dst)
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
result = mosaic_gpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
)()
x = np.arange(m * n, dtype=jnp.float32).reshape(m, n)
if op == arith.addf:
expected = np.broadcast_to(x.sum(axis=1, keepdims=True), x.shape)
elif op == arith.maximumf:
expected = np.broadcast_to(x.max(axis=1, keepdims=True), x.shape)
else:
raise NotImplementedError(f"Unsupported op: {op}")
np.testing.assert_array_equal(result, expected)
def test_splat(self):
def kernel(ctx, dst, _):
f32 = ir.F32Type.get()
v = arith.constant(f32, ir.FloatAttr.get(f32, 3.14))
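# Splat the constant into a length-128 vector in the WGMMA row layout, then broadcast it along the minor dimension.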
t = mgpu.FragmentedArray.splat(v, (128,), mgpu.WGMMA_ROW_LAYOUT)
t.broadcast_minor(32).store_untiled(dst)
out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32)
result = mosaic_gpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
)()
np.testing.assert_array_equal(result, np.full((128, 32), 3.14, np.float32))
@parameterized.product(in_shape=((128, 128), (128, 64), (64, 128)))
def test_strided_load_store(self, in_shape):
def kernel(ctx, *args):
gmem_input, gmem_output, (smem_input, smem_output) = args
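# Round trip: GMEM -> SMEM, load into registers (strided layout), store back to SMEM, copy out to GMEM.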
copy(gmem_input, smem_input)
t = mgpu.FragmentedArray.load_strided(smem_input)
t.store_untiled(smem_output)
copy(smem_output, gmem_output)
inp = out = self.prng.uniform(-1, 1, in_shape).astype(jnp.float32)
result = mosaic_gpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (inp,), out, [inp, out],
)(inp)
np.testing.assert_array_equal(inp, result)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())