diff --git a/.bazelrc b/.bazelrc index e865ef967..b89f909f5 100644 --- a/.bazelrc +++ b/.bazelrc @@ -111,6 +111,10 @@ build:cuda_clang --copt=-Wno-gnu-offsetof-extensions # Disable clang extention that rejects unknown arguments. build:cuda_clang --copt=-Qunused-arguments +build:mosaic_gpu --@llvm-project//mlir:enable_cuda=true +build:mosaic_gpu --copt=-DLLVM_HAS_NVPTX_TARGET=1 +build:mosaic_gpu --//jax:build_mosaic_gpu=true + build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true build:rocm --@xla//xla/python:enable_gpu=true diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 3e9a585f5..287573f04 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -140,7 +140,7 @@ jobs: PY_COLORS: 1 run: | pytest -n auto --tb=short docs - pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas + pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas documentation_render: diff --git a/build/build.py b/build/build.py index dfdf33ec8..66a689b43 100755 --- a/build/build.py +++ b/build/build.py @@ -263,7 +263,7 @@ def write_bazelrc(*, python_bin_path, remote_build, rocm_amdgpu_targets, bazel_options, target_cpu_features, wheel_cpu, enable_mkl_dnn, use_clang, clang_path, clang_major_version, enable_cuda, enable_nccl, enable_rocm, - build_gpu_plugin): + build_gpu_plugin, enable_mosaic_gpu): tf_cuda_paths = [] with open("../.jax_configure.bazelrc", "w") as f: @@ -337,6 +337,8 @@ def write_bazelrc(*, python_bin_path, remote_build, if use_clang: f.write("build --config=nvcc_clang\n") f.write(f"build --action_env=CLANG_CUDA_COMPILER_PATH={clang_path}\n") + if enable_mosaic_gpu: + f.write("build --config=mosaic_gpu") if enable_rocm: f.write("build --config=rocm\n") if not enable_nccl: @@ -541,6 +543,10 @@ def main(): "--editable", action="store_true", help="Create an 'editable' jaxlib build instead of a wheel.") + add_boolean_argument( + parser, + "enable_mosaic_gpu", + help_str="Should we build with Mosaic GPU? VERY EXPERIMENTAL.") add_boolean_argument( parser, "configure_only", @@ -652,6 +658,7 @@ def main(): enable_nccl=args.enable_nccl, enable_rocm=args.enable_rocm, build_gpu_plugin=args.build_gpu_plugin, + enable_mosaic_gpu=args.enable_mosaic_gpu, ) if args.configure_only: diff --git a/jax/BUILD b/jax/BUILD index bf95b3fbc..fb8ad49c0 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -70,6 +70,19 @@ config_setting( }, ) +# If this flag is true, jaxlib will be built with Mosaic GPU. VERY EXPERIMENTAL. 
+bool_flag( + name = "build_mosaic_gpu", + build_setting_default = False, +) + +config_setting( + name = "enable_mosaic_gpu", + flag_values = { + ":build_mosaic_gpu": "True", + }, +) + exports_files([ "LICENSE", "version.py", @@ -116,6 +129,14 @@ package_group( ] + pallas_tpu_internal_users, ) +package_group( + name = "mosaic_gpu_users", + packages = [ + "//...", + "//learning/brain/research/jax", + ], +) + # JAX-private test utilities. py_library( # This build target is required in order to use private test utilities in jax._src.test_util, @@ -647,6 +668,37 @@ pytype_strict_library( ], ) +# This target only supports sm_90 GPUs. +py_library( + name = "mosaic_gpu", + srcs = glob(["experimental/mosaic/gpu/*.py"]), + visibility = [ + ":mosaic_gpu_users", + ], + deps = [ + ":config", + ":jax", + ":mlir", + "//jax/_src/lib", + "//third_party/py/absl/flags", + "//jaxlib/mlir:arithmetic_dialect", + "//jaxlib/mlir:builtin_dialect", + "//jaxlib/mlir:execution_engine", + "//jaxlib/mlir:func_dialect", + "//jaxlib/mlir:gpu_dialect", + "//jaxlib/mlir:ir", + "//jaxlib/mlir:llvm_dialect", + "//jaxlib/mlir:math_dialect", + "//jaxlib/mlir:memref_dialect", + "//jaxlib/mlir:nvgpu_dialect", + "//jaxlib/mlir:nvvm_dialect", + "//jaxlib/mlir:pass_manager", + "//jaxlib/mlir:scf_dialect", + "//jaxlib/mlir:vector_dialect", + "//third_party/py/numpy", + ], +) + pytype_strict_library( name = "partial_eval", srcs = ["_src/interpreters/partial_eval.py"], diff --git a/jax/_src/lib/BUILD b/jax/_src/lib/BUILD index e6c6f7e3f..d3ed72f47 100644 --- a/jax/_src/lib/BUILD +++ b/jax/_src/lib/BUILD @@ -15,6 +15,7 @@ load( "//jaxlib:jax.bzl", "if_building_jaxlib", + "if_building_mosaic_gpu", "jax_visibility", "py_library_providing_imports_info", "pytype_strict_library", @@ -31,6 +32,7 @@ py_library_providing_imports_info( "__init__.py", "mlir/__init__.py", "mlir/dialects/__init__.py", + "mosaic_gpu.py", "triton.py", ], lib_rule = pytype_strict_library, @@ -58,5 +60,5 @@ py_library_providing_imports_info( "//jaxlib/mlir:stablehlo_dialect", "//jaxlib/mlir:vector_dialect", # xla_client - ]), + ]) + if_building_mosaic_gpu(["//jaxlib/mosaic/gpu:mosaic_gpu"]), ) diff --git a/jax/_src/lib/mlir/dialects/__init__.py b/jax/_src/lib/mlir/dialects/__init__.py index ae47aacc9..c22cc678a 100644 --- a/jax/_src/lib/mlir/dialects/__init__.py +++ b/jax/_src/lib/mlir/dialects/__init__.py @@ -23,6 +23,13 @@ import jaxlib.mlir.dialects.func as func import jaxlib.mlir.dialects.scf as scf import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor import jaxlib.mlir.dialects.vector as vector +try: + import jaxlib.mlir.dialects.gpu as gpu # type: ignore + import jaxlib.mlir.dialects.nvgpu as nvgpu # type: ignore + import jaxlib.mlir.dialects.nvvm as nvvm # type: ignore + import jaxlib.mlir.dialects.llvm as llvm # type: ignore +except ImportError: + pass from jax._src import lib diff --git a/jax/_src/lib/mosaic_gpu.py b/jax/_src/lib/mosaic_gpu.py new file mode 100644 index 000000000..b00b556fd --- /dev/null +++ b/jax/_src/lib/mosaic_gpu.py @@ -0,0 +1,23 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa + +try: + from jaxlib.mosaic.gpu import _mosaic_gpu_ext # pytype: disable=import-error +except ImportError as e: + raise ModuleNotFoundError( + "Cannot import the Mosaic GPU bindings. You may need to build jaxlib from" + " source." + ) from e diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py new file mode 100644 index 000000000..6fed0e58b --- /dev/null +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -0,0 +1,595 @@ +# Copyright 2024 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import contextlib +import ctypes +import dataclasses +import os +import pathlib +import subprocess +import tempfile +import time +from typing import Any, Sequence + +import jax +from jax._src import config +from jax._src.interpreters import mlir +from jax._src.lib import xla_client +from jax._src.lib import mosaic_gpu as mosaic_gpu_lib +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import arith +from jaxlib.mlir.dialects import builtin +from jaxlib.mlir.dialects import func +from jaxlib.mlir.dialects import gpu +from jaxlib.mlir.dialects import llvm +from jaxlib.mlir.dialects import memref +from jaxlib.mlir.dialects import nvgpu +from jaxlib.mlir.dialects import nvvm +from jaxlib.mlir.execution_engine import ExecutionEngine +from jaxlib.mlir.passmanager import PassManager +import numpy as np + +from . import dsl as mgpu +from . import profiler +from . import utils + +# mypy: ignore-errors + +# MLIR can't find libdevice unless we point it to the CUDA path +# TODO(apaszke): Unify with jax._src.lib.cuda_path +CUDA_ROOT = "/usr/local/cuda" +if os.environ.get("CUDA_ROOT") is None: + os.environ["CUDA_ROOT"] = CUDA_ROOT +else: + CUDA_ROOT = os.environ["CUDA_ROOT"] + +PTXAS_PATH = os.path.join(CUDA_ROOT, "bin/ptxas") +NVDISASM_PATH = os.path.join(CUDA_ROOT, "bin/nvdisasm") + + +c = mgpu.c # This is too common to fully qualify. 
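A usage note for the CUDA_ROOT handling above (a sketch under assumptions, not part of the change; the install prefix below is hypothetical): the environment variable is read once, at import time, to locate libdevice as well as ptxas/nvdisasm, so a non-default CUDA install has to be exported before the module is first imported.

import os
os.environ["CUDA_ROOT"] = "/opt/cuda-12.3"  # hypothetical prefix; must be set before the import below
from jax.experimental.mosaic import gpu as mosaic_gpu  # derives PTXAS_PATH/NVDISASM_PATH from CUDA_ROOT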
+ + +xla_client.register_custom_call_target( + "mosaic_gpu", + mosaic_gpu_lib._mosaic_gpu_ext._custom_call_capsule(), + platform="CUDA", +) + + +mosaic_gpu_dump_ptx = config.define_bool_state( + name="mosaic_gpu_dump_ptx", + default=config.bool_env("MOSAIC_GPU_DUMP_PTX", False), + help="If set, prints the kernel PTX", +) +mosaic_gpu_dump_ptxas = config.define_bool_state( + name="mosaic_gpu_dump_ptxas", + default=config.bool_env("MOSAIC_GPU_DUMP_PTXAS", False), + help="If set, prints the ptxas verbose output", +) +mosaic_gpu_dump_sass = config.define_bool_state( + name="mosaic_gpu_dump_sass", + default=config.bool_env("MOSAIC_GPU_DUMP_SASS", False), + help="If set, prints the kernel SASS", +) +mosaic_gpu_print_after_all = config.define_bool_state( + name='mosaic_gpu_print_after_all', + default=config.bool_env('MOSAIC_GPU_PRINT_AFTER_ALL', False), + help="If set, prints the kernel module after every pass", +) + + +mosaic_gpu_p = jax.core.Primitive("mosaic_gpu_p") +mosaic_gpu_p.multiple_results = True + + +@mosaic_gpu_p.def_abstract_eval +def _mosaic_gpu_abstract_eval(*_, module, out_types): + return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types] + + +def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types): + runtime_path = ( + pathlib.Path(mosaic_gpu_lib._mosaic_gpu_ext.__file__).parent + / "libmlir_cuda_runtime.so" + ) + shared_libs = [str(runtime_path)] if runtime_path.exists() else [] + engine = ExecutionEngine( + module, opt_level=3, shared_libs=shared_libs, enable_object_dump=False + ) + ctx.module_context.add_keepalive(engine) + func_ptr = engine.lookup("main") + ptr_bytes = ctypes.cast(func_ptr, ctypes.c_void_p).value.to_bytes( + 8, byteorder="little" + ) # pytype: disable=attribute-error + op = mlir.custom_call( + "mosaic_gpu", + result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], + operands=args, + backend_config=ptr_bytes, + ) + return op.results + +mlir.register_lowering(mosaic_gpu_p, _mosaic_gpu_lowering_rule, "cuda") + + +@dataclasses.dataclass(frozen=True) +class MemRefTransform: + def apply(self, ref: ir.Value) -> ir.Value: + raise NotImplementedError("Subclasses should override this method") + + def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: + raise NotImplementedError("Subclasses should override this method") + + def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: + raise NotImplementedError("Subclasses should override this method") + + +@dataclasses.dataclass(frozen=True) +class TileTransform(MemRefTransform): + """Tiles a suffix of memref dimensions. + + For example, given a memref of shape (5, 128, 128) and a tiling of (64, 32), + the shape of the result will be (5, 2, 4, 64, 32). The shape always ends with + the tile shape, and the size of tiled dimensions is divided by the tile size. + This is especially useful for swizzled WGMMA, which expect tiled layouts in + shared memory. + """ + tiling: tuple[int, ...] 
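+  # For example, with tiling (64, 32): transform_shape((5, 128, 128)) gives (5, 2, 4, 64, 32),
+  # i.e. the tiled dims divided by the tile followed by the tile itself, and
+  # transform_index((i, j, k)) gives (i, j // 64, k // 32, 0, 0).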
+ + def apply(self, ref: ir.Value) -> ir.Value: + untiled_rank = ir.MemRefType(ref.type).rank + tiling_rank = len(self.tiling) + tiled_rank = untiled_rank + tiling_rank + for t, d in zip(self.tiling[::-1], range(untiled_rank)[::-1]): + ref = mgpu.memref_unfold(ref, d, (None, t)) + permutation = ( + *range(untiled_rank - tiling_rank), + *range(untiled_rank - tiling_rank, tiled_rank, 2), + *range(untiled_rank - tiling_rank + 1, tiled_rank, 2), + ) + return mgpu.memref_transpose(ref, permutation) + + def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: + index = ir.IndexType.get() + tiling_rank = len(self.tiling) + return ( + *idx[:-tiling_rank], + *( + arith.divui(i, c(t, index)) + for i, t in zip(idx[-tiling_rank:], self.tiling) + ), + *([c(0, index)] * tiling_rank), + ) + + def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: + # Note that this also checks that tiled dims are not squeezed. Their slice + # size would be 1 if so. + tiling_rank = len(self.tiling) + for size, tile_size in zip(shape[-tiling_rank:], self.tiling): + if size % tile_size: + raise ValueError( + f"Expected GMEM slice shape {shape} suffix to be a multiple" + f" of tiling {self.tiling}" + ) + return ( + *shape[:-tiling_rank], + *(s // t for s, t in zip(shape[-tiling_rank:], self.tiling)), + *self.tiling, + ) + + +@dataclasses.dataclass(frozen=True) +class TransposeTransform(MemRefTransform): + """Transposes memref dimensions.""" + permutation: tuple[int, ...] + + def __post_init__(self): + if len(self.permutation) != len(set(self.permutation)): + raise ValueError("Permutation must be a permutation") + + def apply(self, ref: ir.Value) -> ir.Value: + return mgpu.memref_transpose(ref, self.permutation) + + def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: + return tuple(idx[p] for p in self.permutation) + + def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: + return tuple(shape[p] for p in self.permutation) + + +OnDeviceProfiler = profiler.OnDeviceProfiler + + +@dataclasses.dataclass() +class LaunchContext: + launch_op: gpu.LaunchOp + profiler: OnDeviceProfiler | None = None + tma_descriptors: dict[ + tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...]], + ir.Value, + ] = dataclasses.field(default_factory=dict, init=False) + + @contextlib.contextmanager + def named_region(self, *args, **kwargs): + if self.profiler is not None: + with self.profiler.record(*args, **kwargs): + yield + else: + yield + + def _get_tma_desc( + self, + ref, + gmem_transform: tuple[MemRefTransform, ...], + transformed_slice_shape: tuple[int, ...], + swizzle: int | None, + ): + index = ir.IndexType.get() + ref_ty = ir.MemRefType(ref.type) + tma_desc_key = (ref, transformed_slice_shape, swizzle, gmem_transform) + if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None: + swizzle_str = f"swizzle_{swizzle}b" if swizzle is not None else "none" + default_tensor_map_attrs = dict( + swizzle=swizzle_str, l2promo="none", oob="zero", interleave="none" + ) + tensor_map_ty = utils.get_tensormap_descriptor( + tensor=( + f"memref<{'x'.join(map(str, transformed_slice_shape))}x{ref_ty.element_type}, 3>" + ), + **default_tensor_map_attrs, + ) + with ir.InsertionPoint(self.launch_op): + for t in gmem_transform: + ref = t.apply(ref) + ref_unranked = memref.cast( + ir.UnrankedMemRefType.get(ref_ty.element_type, None), ref + ) + tma_desc = nvgpu.tma_create_descriptor( + tensor_map_ty, + ref_unranked, + [c(s, index) for s in transformed_slice_shape], + ) + 
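+      # Cache the freshly created descriptor, keyed on (ref, slice shape, swizzle, transforms),
+      # so later async copies of the same operand reuse it and _prefetch_tma_descs can prefetch each one once.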
self.tma_descriptors[tma_desc_key] = tma_desc + return tma_desc + + def async_copy( + self, + *, + src_ref, + dst_ref, + gmem_slice: Any = (), + gmem_transform: MemRefTransform | tuple[MemRefTransform, ...] = (), + barrier: mgpu.Barrier | None = None, + swizzle: int | None = None, + arrive: bool | None = None, + uniform: bool = True, + ): + index = ir.IndexType.get() + smem = ir.Attribute.parse("#gpu.address_space") + src_ref_ty = ir.MemRefType(src_ref.type) + dst_ref_ty = ir.MemRefType(dst_ref.type) + element_type = src_ref_ty.element_type + if element_type != dst_ref_ty.element_type: + raise ValueError( + f"Expected same element type, got {element_type} and" + f" {dst_ref_ty.element_type}" + ) + if not isinstance(gmem_transform, tuple): + gmem_transform = (gmem_transform,) + + if src_ref_ty.memory_space is None and dst_ref_ty.memory_space == smem: + gmem_ref, smem_ref = src_ref, dst_ref + if barrier is None: + raise ValueError("Barriers are required for GMEM -> SMEM copies") + if arrive is None: + arrive = True # Arrive by default + elif src_ref_ty.memory_space == smem and dst_ref_ty.memory_space is None: + gmem_ref, smem_ref = dst_ref, src_ref + if barrier is not None: + raise ValueError("Barriers are unsupported for SMEM -> GMEM copies") + if arrive is not None: + raise ValueError("arrive is unsupported for SMEM -> GMEM copies") + else: + raise ValueError("Only SMEM <-> GMEM copies supported") + # TODO(apaszke): This is a very approximate check. Improve it! + expected_name = "builtin.unrealized_conversion_cast" + if ( + gmem_ref.owner is None + or gmem_ref.owner.opview.OPERATION_NAME != expected_name + ): + raise ValueError("GMEM reference in async_copy must be a kernel argument") + + base_indices, slice_shape, is_squeezed = utils.parse_indices( + gmem_slice, ir.MemRefType(gmem_ref.type).shape + ) + dyn_base_indices = tuple( + c(i, index) if not isinstance(i, ir.Value) else i for i in base_indices + ) + slice_shape = tuple(slice_shape) + for t in gmem_transform: + dyn_base_indices = t.transform_index(dyn_base_indices) + slice_shape = t.transform_shape(slice_shape) + for dim, squeezed in enumerate(is_squeezed): + if squeezed: + smem_ref = mgpu.memref_unsqueeze(smem_ref, dim) + smem_ref_ty = ir.MemRefType(smem_ref.type) + + if slice_shape != tuple(smem_ref_ty.shape): + raise ValueError( + "Expected the SMEM reference to have the same shape as the tiled" + f" slice: {tuple(smem_ref_ty.shape)} != {slice_shape}" + ) + tma_desc = self._get_tma_desc( + gmem_ref, gmem_transform, slice_shape, swizzle, + ) + + # nvgpu TMA instructions expect reversed indices... 
+ rev_dyn_based_indices = reversed(dyn_base_indices) + + uniform_ctx = mgpu.once if uniform else contextlib.nullcontext + + if gmem_ref is src_ref: + with uniform_ctx(): + assert barrier is not None # for pytype + barrier_group = barrier.barrier_array.value + barrier_idx = barrier.offset + if arrive: + slice_bytes = c( + np.prod(slice_shape) * mgpu.bytewidth(element_type), index + ) + nvgpu.mbarrier_arrive_expect_tx( + barrier_group, slice_bytes, barrier_idx + ) + nvgpu.tma_async_load( + smem_ref, barrier_group, tma_desc, rev_dyn_based_indices, barrier_idx + ) + else: + with uniform_ctx(): + nvgpu.tma_async_store(smem_ref, tma_desc, rev_dyn_based_indices) + nvvm.cp_async_bulk_commit_group() + + def await_async_copy( + self, allow_groups: int, await_read_only: bool = False + ): + nvvm.cp_async_bulk_wait_group(allow_groups, read=await_read_only) + gpu.barrier() # Groups are supposedly tracked per-thread + + def _prefetch_tma_descs(self): + with ir.InsertionPoint(self.launch_op.body.blocks[0]): + with mgpu.once(): + for desc in self.tma_descriptors.values(): + nvgpu.tma_prefetch_descriptor(desc) + + +@contextlib.contextmanager +def _launch( + token, + grid, + block, + smem_buffers, + profiler_spec: profiler.ProfilerSpec | None = None, + maybe_prof_buffer: ir.Value | None = None, +): + if (profiler_spec is None) != (maybe_prof_buffer is None): + raise ValueError + index = ir.IndexType.get() + i32 = ir.IntegerType.get_signless(32) + i8 = ir.IntegerType.get_signless(8) + grid_vals = [c(i, index) for i in grid] + block_vals = [c(i, index) for i in block] + flat_refs, smem_buffer_tree = jax.tree.flatten(smem_buffers) + + smem_ref_bytes = [] + for ref_ty in flat_refs: + smem_ref_bytes.append( + np.prod(ref_ty.shape) * np.dtype(ref_ty.dtype).itemsize + ) + + smem_bytes = sum(smem_ref_bytes) + if profiler_spec is not None: + smem_bytes += profiler_spec.smem_bytes(grid) + + launch_op = gpu.LaunchOp( + token.type, [token], *grid_vals, *block_vals, + dynamicSharedMemorySize=c(smem_bytes, i32)) + launch_op.body.blocks.append(*([index] * 12)) # Append an empty block + smem = ir.Attribute.parse("#gpu.address_space") + with ir.InsertionPoint(launch_op.body.blocks[0]): + dynamic_smem = gpu.dynamic_shared_memory( + ir.MemRefType.get( + (ir.ShapedType.get_dynamic_size(),), i8, memory_space=smem + ) + ) + smem_refs = [] + dynamic_smem_offset = 0 + for ref_ty, ref_bytes in zip(flat_refs, smem_ref_bytes): + mlir_dtype = mlir.dtype_to_ir_type(ref_ty.dtype) + tile_smem = memref.view( + ir.MemRefType.get(ref_ty.shape, mlir_dtype, memory_space=smem), + dynamic_smem, c(dynamic_smem_offset, index), [], + ) + dynamic_smem_offset += ref_bytes + smem_refs.append(tile_smem) + + if profiler_spec: + prof_smem = memref.view( + ir.MemRefType.get( + (profiler_spec.smem_i32_elements(grid=grid),), + i32, memory_space=smem, + ), + dynamic_smem, c(dynamic_smem_offset, index), [], + ) + prof = profiler.OnDeviceProfiler( + profiler_spec, prof_smem, maybe_prof_buffer + ) + else: + prof = None + smem_ref_tree = jax.tree.unflatten(smem_buffer_tree, smem_refs) + yield LaunchContext(launch_op, prof), smem_ref_tree + if prof is not None: + prof.finalize(grid=grid) + gpu.terminator() + + +def as_gpu_kernel( + body, + grid: tuple[int, ...], + block: tuple[int, ...], + in_shape, + out_shape, + smem_scratch_shape, + prof_spec: profiler.ProfilerSpec | None = None, +): + ptr_ty = ir.Type.parse("!llvm.ptr") + token_ty = ir.Type.parse("!gpu.async.token") + + def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: + return 
ir.MemRefType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype)) + + if isinstance(in_shape, list): + in_shape = tuple(in_shape) + elif not isinstance(in_shape, tuple): + in_shape = (in_shape,) + in_ref_tys = [_shape_to_ref_ty(t) for t in in_shape] + + unwrap_output_tuple = False + if isinstance(out_shape, list): + out_shape = tuple(out_shape) + elif not isinstance(out_shape, tuple): + out_shape = (out_shape,) + unwrap_output_tuple = True + out_ref_tys = [_shape_to_ref_ty(t) for t in out_shape] + if prof_spec is not None: + out_shape = (*out_shape, prof_spec.jax_buffer_type) + out_ref_tys.append(prof_spec.mlir_buffer_type) + + module = ir.Module.create() + with ir.InsertionPoint(module.body): + @func.FuncOp.from_py_func(ptr_ty, ptr_ty) + def main(token_ptr, buffers): + token = builtin.unrealized_conversion_cast([token_ty], [token_ptr]) + arg_refs = [] + for i, ref_ty in enumerate([*in_ref_tys, *out_ref_tys]): + ptr = llvm.LoadOp(ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i], ptr_ty)) + arg_refs.append(utils.ptr_as_memref(ptr, ir.MemRefType(ref_ty))) + in_refs = arg_refs[:len(in_ref_tys)] + out_refs = arg_refs[len(in_ref_tys):] + prof_buffer = out_refs.pop() if prof_spec is not None else None + with _launch( + token, grid, block, smem_scratch_shape, prof_spec, prof_buffer + ) as (launch_ctx, smem_refs): + body(launch_ctx, *in_refs, *out_refs, smem_refs) + main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + module.operation.verify() + + expected_arg_treedef = jax.tree.structure(in_shape) + def _check_args(args): + arg_treedef = jax.tree.structure(args) + if arg_treedef != expected_arg_treedef: + raise ValueError( + f"Invalid argument structure: expected {expected_arg_treedef}, got" + f" {arg_treedef}" + ) + + dump_low_level(module) + + pass_manager = PassManager.parse( + "builtin.module(gpu-lower-to-nvvm-pipeline{cubin-format=fatbin" + " cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3})" + ) + if mosaic_gpu_print_after_all.value: + pass_manager.enable_ir_printing() + pass_manager.run(module.operation) + + def bind(*args): + return mosaic_gpu_p.bind(*args, out_types=out_shape, module=module) + + if prof_spec is not None: + @jax.jit + def prof_kernel(*args): + _check_args(args) + *results, prof_buffer = bind(*args) + def dump_profile(prof_buffer): + out_file = os.path.join( + os.getenv("TEST_UNDECLARED_OUTPUTS_DIR"), + f"{time.time_ns()}-trace.json", + ) + try: + with open(out_file, "x") as f: + prof_spec.dump(prof_buffer, f) + except FileExistsError: + pass # TODO: Retry + jax.debug.callback(dump_profile, prof_buffer) + return results[0] if unwrap_output_tuple else results + return prof_kernel + else: + @jax.jit + def kernel(*args): + _check_args(args) + results = bind(*args) + return results[0] if unwrap_output_tuple else results + return kernel + + +def dump_low_level(module): + dump_ptx = mosaic_gpu_dump_ptx.value + dump_ptxas = mosaic_gpu_dump_ptxas.value + dump_sass = mosaic_gpu_dump_sass.value + if not any([dump_ptx, dump_ptxas, dump_sass]): + return + module = ir.Module.parse( + module.operation.get_asm(binary=True, enable_debug_info=True) + ) + pm = PassManager.parse( + "builtin.module(gpu-lower-to-nvvm-pipeline{cubin-format=isa" + " cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3})" + ) + pm.run(module.operation) + + for op in module.body: + if op.OPERATION_NAME == "gpu.binary": + objects = ir.ArrayAttr(op.objects) + if len(objects) != 1: + raise NotImplementedError("Expected a single object") + obj = str(objects[0]) + start = obj.find('assembly = "') + 
len('assembly = "') + end = obj.find('"', start) + ptx = obj[start:end] + ptx = ptx.replace("\\09", "\t").replace("\\0A", "\n")[:-3] + if dump_ptx: + print(ptx) + if dump_ptxas or dump_sass: + with tempfile.TemporaryDirectory() as tmp: + ptx_path = os.path.join(tmp, "kernel.ptx") + with open(ptx_path, "w") as f: + f.write(ptx) + elf_path = os.path.join(tmp, 'kernel.o') + v_flag = "-v" if dump_ptxas else "" + ptxas_flags = f"{v_flag} --opt-level 3 --gpu-name sm_90a" + ptxas_out = subprocess.check_output( + f"{PTXAS_PATH} {ptxas_flags} --output-file {elf_path} {ptx_path}", + stderr=subprocess.STDOUT, + shell=True, + ) + if dump_ptxas: + print(ptxas_out.decode()) + if dump_sass: + sass = subprocess.check_output( + f"{NVDISASM_PATH} -ndf -c {elf_path}", + stderr=subprocess.STDOUT, + shell=True, + ) + print(sass.decode()) diff --git a/jax/experimental/mosaic/gpu/dsl.py b/jax/experimental/mosaic/gpu/dsl.py new file mode 100644 index 000000000..9e4868deb --- /dev/null +++ b/jax/experimental/mosaic/gpu/dsl.py @@ -0,0 +1,47 @@ +# Copyright 2024 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from .fragmented_array import ( + FragmentedArray, + FragmentedLayout, + WGMMA_LAYOUT, + WGMMA_ROW_LAYOUT, + WGStridedFragLayout, +) +from .utils import ( + Barrier, + BarrierArray, + DynamicSlice, + Partition, + Partition1D, + bytewidth, + c, + commit_shared, + debug_print, + ds, + fori, + memref_fold, + memref_slice, + memref_transpose, + memref_unfold, + memref_unsqueeze, + once, + tile_shape, +) +from .wgmma import ( + WGMMAAccumulator, + WGMMALayout, + wgmma, +) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py new file mode 100644 index 000000000..2deb80f47 --- /dev/null +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -0,0 +1,476 @@ +# Copyright 2024 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Utilities for code generator.""" + +import dataclasses + +import jax +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import arith +from jaxlib.mlir.dialects import gpu +from jaxlib.mlir.dialects import llvm +from jaxlib.mlir.dialects import math as mlir_math +from jaxlib.mlir.dialects import memref +from jaxlib.mlir.dialects import nvvm +from jaxlib.mlir.dialects import vector +import numpy as np + +from . import utils +from . import dsl as mgpu + +# mypy: ignore-errors + +WARPGROUP_SIZE = utils.WARPGROUP_SIZE +c = utils.c + + +@dataclasses.dataclass(frozen=True) +class WGMMAFragLayout: + """[m, n] matrix, where m % 64 == 0 == n % 8.""" + + +@dataclasses.dataclass(frozen=True) +class WGMMARowFragLayout: + """[m] matrix, where m % 64 == 0.""" + + +@dataclasses.dataclass(frozen=True) +class WGStridedFragLayout: + """Convert the array to 1D and then shard across threads.""" + + shape: tuple[int, ...] + vec_size: int + + def __post_init__(self): + if np.prod(self.shape) % (self.vec_size * WARPGROUP_SIZE) != 0: + raise ValueError((self, WARPGROUP_SIZE)) + + @classmethod + def from_memref_type(cls, memref_ty: ir.Type): + if not ir.MemRefType.isinstance(memref_ty): + raise TypeError(memref_ty) + + memref_type = ir.MemRefType(memref_ty) + bw = mgpu.bytewidth(memref_type.element_type) + assert 8 % bw == 0 and 8 // bw != 0, bw + return cls(shape=memref_type.shape, vec_size=8 // bw) + + def thread_vec_idxs(self): + """The indexes to be used for vector load/store WGStridedFragLayout. + + Yields: + The indices of the vector that correspond to the current thread. + """ + cardinality = np.prod(self.shape) + assert cardinality % (WARPGROUP_SIZE * self.vec_size) == 0 + reg_num = cardinality // (WARPGROUP_SIZE * self.vec_size) + tidx = gpu.thread_id(gpu.Dimension.x) + off = arith.muli(tidx, c(self.vec_size, tidx.type)) + for i in range(reg_num): + yield [arith.addi(off, c(i * WARPGROUP_SIZE * self.vec_size, tidx.type))] + + +FragmentedLayout = WGStridedFragLayout | WGMMAFragLayout | WGMMARowFragLayout + + +WGMMA_LAYOUT = WGMMAFragLayout() +WGMMA_ROW_LAYOUT = WGMMARowFragLayout() + + +@jax.tree_util.register_pytree_node_class +class FragmentedArray: + registers: np.ndarray # of ir.Value, see checks in init for shapes. 
+ layout: FragmentedLayout + + def __init__(self, *, _registers: np.ndarray, _layout: FragmentedLayout): + self.registers = _registers + self.layout = _layout + + match self.layout: + # Registers are [m_tiles, n_tiles, 2 rows, 1 cols] in WGMMA layout + # Each element is a vector<2xdtype> + case WGMMAFragLayout(): + if self.registers.ndim != 4 or self.registers.shape[2:] != (2, 1): + raise ValueError("Invalid register array shape") + + # Registers are [m_tiles, 2 rows] in WGMMA_ROW layout + # Each element is a dtype scalar + case WGMMARowFragLayout(): + if self.registers.ndim != 2 or self.registers.shape[-1] != 2: + raise ValueError("Invalid register array shape") + + # Registers are flat + case WGStridedFragLayout(shape): + (reg_size,) = ir.VectorType(_registers.flat[0].type).shape + if np.prod(shape) != np.prod(_registers.shape) * WARPGROUP_SIZE * reg_size: + raise ValueError((reg_size, shape, _registers.shape, WARPGROUP_SIZE), _registers.flat[0].type) + case _: + raise NotImplementedError + + @classmethod + def load_strided(cls, ref: ir.Value): + if not ir.MemRefType.isinstance(ref.type): + raise TypeError(ref.type) + + ref_ty = ir.MemRefType(ref.type) + ref_1d = mgpu.memref_fold(ref, 0, len(ref_ty.shape)) + layout = WGStridedFragLayout.from_memref_type(ref_ty) + vec_ty = ir.VectorType.get((layout.vec_size,), ref_ty.element_type) + vecs = [vector.load(vec_ty, ref_1d, vec_idx) for vec_idx in layout.thread_vec_idxs()] + return cls(_registers=np.array(vecs), _layout=layout) + + @classmethod + def splat(cls, value, shape, layout): + match layout: + case WGMMARowFragLayout(): + if len(shape) != 1: + raise ValueError + if shape[0] % 64: + raise ValueError + reg_shape = (shape[0] // 64, 2) + case WGMMAFragLayout(): + if len(shape) != 2: + raise ValueError + if shape[0] % 64 or shape[1] % 8: + raise ValueError + reg_shape = (shape[0] // 64, shape[1] // 8, 2, 1) + value = vector.splat(ir.VectorType.get((2,), value.type), value) + case WGStridedFragLayout(shape=shape, vec_size=vec_size): + elems = np.prod(shape) + reg_shape = (elems // (WARPGROUP_SIZE * vec_size),) + value = vector.splat(ir.VectorType.get((vec_size,), value.type), value) + case _: + raise NotImplementedError(layout) + + return cls( + _registers=np.full(reg_shape, value, dtype=object), + _layout=layout, + ) + + @property + def shape(self): + row_tiles = self.registers.shape[0] + match self.layout: + case WGMMAFragLayout(): + col_tiles = self.registers.shape[1] + return (row_tiles * 64, col_tiles * 8) + case WGMMARowFragLayout(): + return (row_tiles * 64,) + case WGStridedFragLayout(shape): + return shape + + @property + def mlir_dtype(self): + reg_ty = self.registers.flat[0].type + match self.layout: + case WGMMAFragLayout() | WGStridedFragLayout(): + return ir.VectorType(reg_ty).element_type + case WGMMARowFragLayout(): + return reg_ty + + def _pointwise(self, op, *other): + for o in other: + if not isinstance(o, FragmentedArray): + return NotImplemented + if self.layout != o.layout: + raise ValueError("Incompatible FragmentedArray layouts") + if self.registers.shape != o.registers.shape: + raise ValueError("Incompatible FragmentedArray shapes") + new_regs = np.empty_like(self.registers) + for idx, reg in np.ndenumerate(self.registers): + new_regs[idx] = op(reg, *(o.registers[idx] for o in other)) + return FragmentedArray(_registers=new_regs, _layout=self.layout) + + def __add__(self, other): + return self._pointwise(arith.addf, other) + + def __mul__(self, other): + return self._pointwise(arith.mulf, other) + + def __sub__(self, 
other): + return self._pointwise(arith.subf, other) + + def __truediv__(self, other): + return self._pointwise(arith.divf, other) + + def max(self, other): + return self._pointwise(arith.maximumf, other) + + def exp(self, approx: bool = False): + def fast_exp(x): + f32 = ir.F32Type.get() + log2e = arith.constant(f32, ir.FloatAttr.get(f32, 1.4426950408889634)) + if x.type == f32: + scaled = arith.mulf(x, log2e) + return llvm.inline_asm( + f32, [scaled], "ex2.approx.f32 $0,$1;", "=f,f", asm_dialect=0 + ) + elif ir.VectorType.isinstance(x.type): + index = ir.IndexType.get() + result = llvm.mlir_undef(x.type) + for i in range(2): + v = vector.extractelement(x, position=c(i, index)) + vr = fast_exp(v) + result = vector.insertelement(vr, result, position=c(i, index)) + return result + else: + raise NotImplementedError(x.type) + return self._pointwise(fast_exp if approx else mlir_math.exp) + + def __and__(self, other): + if not ir.IntegerType.isinstance(self.mlir_dtype): + raise ValueError( + "Bitwise operations only defined for integer types, not" + f" {self.mlir_dtype}" + ) + + return self._pointwise(arith.andi, other) + + def bitcast(self, elt: ir.Type): + reg_type = self.registers.flat[0].type + if ir.VectorType.isinstance(reg_type): + reg_shape = ir.VectorType(reg_type).shape + ty = ir.VectorType.get(reg_shape, elt) + else: + ty = elt + + return self._pointwise(lambda x: arith.bitcast(ty, x)) + + def __getitem__(self, idx): + if self.layout != WGMMA_LAYOUT: + raise NotImplementedError("Only WGMMA layouts support slicing") + base_idx, slice_shape, is_squeezed = utils.parse_indices(idx, self.shape) + if any(is_squeezed): + raise NotImplementedError("Only slicing implemented") + if ( + base_idx[0] % 64 + or slice_shape[0] % 64 + or base_idx[1] % 8 + or slice_shape[1] % 8 + ): + raise NotImplementedError("Only tile aligned slicing supported") + base_idx[0] //= 64 + slice_shape[0] //= 64 + base_idx[1] //= 8 + slice_shape[1] //= 8 + new_regs = self.registers[ + base_idx[0] : base_idx[0] + slice_shape[0], + base_idx[1] : base_idx[1] + slice_shape[1], + ] + return FragmentedArray(_registers=new_regs, _layout=self.layout) + + # TODO(apaszke): Support JAX dtypes here as well? 
+ def astype(self, new_dtype: ir.Type): + cur_dtype = self.mlir_dtype + if cur_dtype == new_dtype: + return self + from_float = ir.FloatType.isinstance(cur_dtype) + to_float = ir.FloatType.isinstance(new_dtype) + from_integer = ir.IntegerType.isinstance(cur_dtype) + to_integer = ir.IntegerType.isinstance(new_dtype) + if from_float and to_float: + if ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width: + convert = arith.truncf + else: + convert = arith.extf + elif from_integer and to_integer: + if ir.IntegerType(cur_dtype).width > ir.IntegerType(new_dtype).width: + convert = arith.trunci + else: + convert = arith.extsi + elif from_integer and to_float: + convert = arith.sitofp + elif from_float and to_integer: + convert = arith.fptosi + new_registers = np.empty_like(self.registers) + match self.layout: + case WGMMAFragLayout(): + new_reg_ty = ir.VectorType.get((2,), new_dtype) + case WGStridedFragLayout(vec_size=vec_size): + new_reg_ty = ir.VectorType.get((vec_size,), new_dtype) + case WGMMARowFragLayout(): + new_reg_ty = new_dtype + case _: + raise NotImplementedError(f"Unsupported layout {self.layout}") + for idx, reg in np.ndenumerate(self.registers): + new_registers[idx] = convert(new_reg_ty, reg) + return FragmentedArray(_registers=new_registers, _layout=self.layout) + + def reduce(self, op, axis): + if self.layout != WGMMA_LAYOUT: + raise NotImplementedError(self.layout) + if axis != 1: + raise NotImplementedError + index = ir.IndexType.get() + i32 = ir.IntegerType.get_signless(32) + new_regs = np.empty(self.registers.shape[::2], dtype=object) + assert self.registers.shape[-1] == 1 + for row_tile, row_subtile in np.ndindex(new_regs.shape): + # Reduce the registers owned by the current thread over n tiles + thread_result_vec = self.registers[row_tile, 0, row_subtile, 0] + for n_tile in range(1, self.registers.shape[1]): + thread_result_vec = op( + thread_result_vec, self.registers[row_tile, n_tile, row_subtile, 0] + ) + thread_result = op( + vector.extractelement(thread_result_vec, position=c(0, index)), + vector.extractelement(thread_result_vec, position=c(1, index)), + ) + # Do a shuffle to reduce in groups of 4 consecutive threads. 
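+      # Each accumulator row is co-owned by 4 consecutive lanes, so two butterfly shuffles
+      # (offsets 1 and 2) are enough to fold their partial results into every lane.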
+ result = thread_result + for i in (1, 2): + other_result = nvvm.shfl_sync( + result.type, + c(0xFFFFFFFF, i32), + result, + c(i, i32), + c(0x1F, i32), + nvvm.ShflKind.bfly, + ) + result = op(result, other_result) + new_regs[row_tile, row_subtile] = result + return FragmentedArray(_registers=new_regs, _layout=WGMMA_ROW_LAYOUT) + + def broadcast_minor(self, n): + if self.layout != WGMMA_ROW_LAYOUT: + raise NotImplementedError + num_row_tiles = self.registers.shape[0] + num_col_tiles, rem = divmod(n, 8) + if rem: + raise ValueError("Number of columns must be divisible by 8") + new_regs = np.empty((num_row_tiles, num_col_tiles, 2, 1), dtype=object) + dtype = self.mlir_dtype + for (row_tile, row_subtile), reg in np.ndenumerate(self.registers): + new_regs[row_tile, :, row_subtile, :] = vector.splat( + ir.VectorType.get((2,), dtype), reg + ) + return FragmentedArray(_registers=new_regs, _layout=WGMMA_LAYOUT) + + def store_untiled(self, ref: ir.Value): + if not ir.MemRefType.isinstance(ref.type): + raise ValueError(ref) + + match self.layout: + case WGMMAFragLayout(): + self._store_untiled_wgmma(ref) + case WGStridedFragLayout(): + self._store_untiled_wg_strided(ref) + case _: + raise NotImplementedError(self.layout) + + def _store_untiled_wg_strided(self, ref: ir.Value): + ref_ty = ir.MemRefType(ref.type) + if ref_ty.shape != self.shape: + raise ValueError((ref_ty.shape, self.shape)) + smem_1d = mgpu.memref_fold(ref, 0, len(ref_ty.shape)) + assert isinstance(self.layout, WGStridedFragLayout) + for idx, reg in zip(self.layout.thread_vec_idxs(), self.registers.flat): + vector.store(reg, smem_1d, idx) + + def _store_untiled_wgmma(self, ref: ir.Value): + """Stores accumulator to a 2D memref. Not optimized at the moment.""" + assert self.layout == WGMMA_LAYOUT + index = ir.IndexType.get() + m, n = self.shape # pytype: disable=bad-unpacking + ref_ty = ir.MemRefType(ref.type) + if ref_ty.shape != [m, n]: + raise ValueError(ref.type, (m, n)) + + def c(x): + return arith.ConstantOp(index, ir.IntegerAttr.get(index, x)) + + tidx = gpu.thread_id(gpu.Dimension.x) + lane_id = arith.remui(tidx, c(32)) # {0, 1, ..., 31} + warp_id = arith.divui(tidx, c(32)) # {0, 1, 2, 3} + row_base = arith.addi( + arith.divui(lane_id, c(4)), arith.muli(warp_id, c(16)) + ) + col_base = arith.muli(arith.remui(lane_id, c(4)), c(2)) # {0, 2, 4, 6} + it = np.ndenumerate(self.registers) + for (row_tile, col_tile, row_idx, col_zero), elem in it: + del col_zero + row = arith.addi(row_base, c(row_tile * 64 + row_idx * 8)) + for col_idx in range(2): + value = vector.extractelement(elem, position=c(col_idx)) + col = arith.addi(col_base, c(col_tile * 8 + col_idx)) + memref.store(value, ref, [row, col]) + + def store_tiled(self, ref, swizzle: int | None): + if self.layout != WGMMA_LAYOUT: + raise NotImplementedError + bw = mgpu.bytewidth(self.mlir_dtype) + m, n = self.shape # pytype: disable=bad-unpacking + assert m % 64 == 0 # This is implied by the layout. 
+ if n % 32 != 0: + raise NotImplementedError + cols_per_tile = 128 // bw + expected_shape = [m // 64, n // cols_per_tile, 64, cols_per_tile] + if ir.MemRefType(ref.type).shape != expected_shape: + raise ValueError(ref.type, (m, n)) + if swizzle != 128: + raise NotImplementedError("Only 128B swizzle supported") + index = ir.IndexType.get() + + def c(x): + return arith.ConstantOp(index, ir.IntegerAttr.get(index, x)) + + tidx = gpu.thread_id(gpu.Dimension.x) + lane_id = arith.remui(tidx, c(32)) # {0, 1, ..., 31} + warp_id = arith.divui(tidx, c(32)) # {0, 1, 2, 3} + sub_row_base = arith.divui(lane_id, c(4)) # {0, 1, ..., 7} + if bw > 2: # Stagger is only necessary for values larger than 16bit. + is_even_row = arith.cmpi( + arith.CmpIPredicate.eq, arith.remui(sub_row_base, c(2)), c(0) + ) + else: + # We rely on canonicalization to clean up the selects. + i1 = ir.IntegerType.get_signless(1) + is_even_row = arith.constant(i1, ir.IntegerAttr.get(i1, 1)) + row_base = arith.addi(sub_row_base, arith.muli(warp_id, c(16))) + col_base = arith.muli(arith.remui(lane_id, c(4)), c(2)) # {0, 2, 4, 6} + # The swizzle pattern is constant for a given thread. + col_swizzle_bits = arith.muli(sub_row_base, c(16 // bw)) + for row_group in range(m // 64): + for col_group in range(n // cols_per_tile): + for row_subidx in range(2): + row = arith.addi(row_base, c(row_subidx * 8)) + for col_subidx in range(cols_per_tile // 8): + # We stagger the even and odd rows a little to avoid bank conflicts. + # It seems that the STS.64 is 2x faster (and the hardware reports no + # conflicts) when the conflicts are split between half-warps, as + # opposed to having them within the half-warp. This requires a + # little more work for the selects, but is ultimately worth it. + col_subidx_even = col_subidx + col_subidx_odd = col_subidx ^ 2 + col_off = arith.select( + is_even_row, c(col_subidx_even * 8), c(col_subidx_odd * 8) + ) + col = arith.addi(col_base, col_off) + col = arith.xori(col, col_swizzle_bits) + reg_idx_even = col_subidx_even + col_group * (cols_per_tile // 8) + reg_idx_odd = col_subidx_odd + col_group * (cols_per_tile // 8) + value_even = self.registers[row_group, reg_idx_even, row_subidx, 0] + value_odd = self.registers[row_group, reg_idx_odd, row_subidx, 0] + value = arith.select(is_even_row, value_even, value_odd) + vector.store(value, ref, [c(row_group), c(col_group), row, col]) + + def tree_flatten(self): + return list(self.registers.flat), (self.layout, self.registers.shape) + + @classmethod + def tree_unflatten(cls, aux, flat_registers): + layout, reg_shape = aux + registers = np.asarray(flat_registers, dtype=object).reshape(reg_shape) + return cls(_registers=registers, _layout=layout) diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py new file mode 100644 index 000000000..2dcc11b4f --- /dev/null +++ b/jax/experimental/mosaic/gpu/profiler.py @@ -0,0 +1,206 @@ +# Copyright 2024 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import contextlib +import json + +import jax +import jax.numpy as jnp +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import arith +from jaxlib.mlir.dialects import gpu +from jaxlib.mlir.dialects import memref +from jaxlib.mlir.dialects import scf +import numpy as np + +from .utils import * # noqa: F403 + +# ruff: noqa: F405 +# mypy: ignore-errors + +class ProfilerSpec: + ENTER = 0 + EXIT = 1 << 31 + + def __init__(self, num_entries: int): + self.num_entries = num_entries + self.interned_names = {} + + @property + def mlir_buffer_type(self) -> ir.Type: + return ir.MemRefType.get( + (1 + self.num_entries,), ir.IntegerType.get_signless(32) + ) + + @property + def jax_buffer_type(self) -> ir.Type: + return jax.ShapeDtypeStruct((1 + self.num_entries,), jnp.uint32) + + def smem_i32_elements(self, grid: tuple[int, ...]): + return int(self.num_entries // np.prod(grid)) + + def smem_bytes(self, grid: tuple[int, ...]): + bytes_per_entry = 4 + return self.smem_i32_elements(grid) * bytes_per_entry + + def intern_name(self, name: str) -> int: + if name_id := self.interned_names.get(name, None): + return name_id + name_id = self.interned_names[name] = len(self.interned_names) + if name_id & self.EXIT: + raise RuntimeError("Allocated too many names") + return name_id + + def dump(self, buffer, f): + buffer = np.asarray(buffer) + num_blocks = buffer[0] + per_block = self.num_entries // num_blocks + block_entries = buffer[1 : 1 + num_blocks * per_block].reshape( + num_blocks, per_block + ) + start_times = block_entries[:, :2].astype(np.int64) + start_times = (start_times[:, 0] << 32) + start_times[:, 1] + start_times -= start_times.min() # Normalize + entries_used = block_entries[:, 2] + if np.any(entries_used > per_block - 2): + raise RuntimeError("Insufficient space to capture a full trace") + block_traces = block_entries[:, 3:] + unintern = {v: k for k, v in self.interned_names.items()} + events = [] + for block_idx in range(num_blocks): + valid_entries = entries_used[block_idx] - 3 + local_clock_offset = None + assert valid_entries % 2 == 0 + start_time = start_times[block_idx] + block_events = [] + for i in range(0, valid_entries, 2): + tag = block_traces[block_idx, i] + time = block_traces[block_idx, i + 1] + if local_clock_offset is None: + local_clock_offset = time + time -= local_clock_offset + time -= i * 6 # Account for the overhead of profiling. 
+ if time < 0: + break # Detect a timer wraparound + name_id = tag + begin = True + if name_id & ProfilerSpec.EXIT: + name_id = name_id ^ ProfilerSpec.EXIT + begin = False + name = unintern[name_id] + block_events.append({ + "name": name, + "ph": "B" if begin else "E", + "ts": float(start_time + time) / 1e3, + "pid": 0, + "tid": block_idx, + }) + else: # If we didn't break + events.extend(block_events) + return json.dump({"displayTimeUnit": "ns", "traceEvents": events}, f) + + +class OnDeviceProfiler: + + def __init__(self, spec: ProfilerSpec, smem_buffer: ir.Value, gmem_buffer: ir.Value): + self.spec = spec + # self.should_store = gpu.thread_id(gpu.Dimension.x) + i32 = ir.IntegerType.get_signless(32) + index = ir.IndexType.get() + num_blocks = c(1, index) + for dim in gpu.Dimension: + num_blocks = arith.muli(num_blocks, gpu.grid_dim(dim)) + memref.store(arith.index_cast(i32, num_blocks), gmem_buffer, [c(0, index)]) + self.entries_per_block = arith.divui(c(spec.num_entries, index), num_blocks) + self.smem_buffer = smem_buffer + self.gmem_buffer = gmem_buffer + # Hopefully mem2reg will remove the allocation. + self.offset = memref.alloca(ir.MemRefType.get((), i32), [], []) + memref.store(c(0, i32), self.offset, []) + + @contextlib.contextmanager + def record(self, name: str): + i32 = ir.IntegerType.get_signless(32) + index = ir.IndexType.get() + name_id = self.spec.intern_name(name) + def store(modifier): + cur = arith.index_cast(index, memref.load(self.offset, [])) + # TODO(apaszke): Clamp indices + # bound = arith.subi(self.entries_per_block, c(2, index)) + # cur = arith.select( + # arith.cmpi(arith.CmpIPredicate.ult, cur, bound), cur, bound + # ) + memref.store(c(modifier | name_id, i32), self.smem_buffer, [cur]) + memref.store( + clock(), self.smem_buffer, [arith.addi(cur, c(1, cur.type))] + ) + memref.store( + arith.index_cast(i32, arith.addi(cur, c(2, cur.type))), + self.offset, + [], + ) + store(ProfilerSpec.ENTER) + yield + store(ProfilerSpec.EXIT) + + def finalize(self, grid): + index = ir.IndexType.get() + i32 = ir.IntegerType.get_signless(32) + + block_idx = c(0, index) + for dim in reversed(gpu.Dimension): # pytype: disable=wrong-arg-types + block_idx = arith.addi( + arith.muli(block_idx, gpu.grid_dim(dim)), gpu.block_id(dim) + ) + start_offset = arith.addi( + arith.muli(block_idx, self.entries_per_block), c(1, index) + ) + block_gmem_buffer = memref.subview( + self.gmem_buffer, [start_offset], [self.spec.num_entries], [1], + result_type=ir.Type.parse( + f"memref<{self.spec.num_entries}xi32, strided<[1], offset: ?>>" + ), + ) + # TODO(apaszke): Either use globaltimer or delete + # memref.store(globaltimer("high"), block_gmem_buffer, [c(0, index)]) + # memref.store(globaltimer("low"), block_gmem_buffer, [c(1, index)]) + memref.store(c(0, i32), block_gmem_buffer, [c(0, index)]) + memref.store(c(0, i32), block_gmem_buffer, [c(1, index)]) + memref.store( + arith.addi(memref.load(self.offset, []), c(3, i32)), + block_gmem_buffer, + [c(2, index)], + ) + + if_first = scf.IfOp( + arith.cmpi( + arith.CmpIPredicate.eq, gpu.thread_id(gpu.Dimension.x), c(0, index) + ) + ) + with ir.InsertionPoint(if_first.then_block): + for_op = scf.ForOp( + c(0, index), + c(self.spec.smem_i32_elements(grid) - 3, index), + c(1, index), + ) + with ir.InsertionPoint(for_op.body): + x = memref.load(self.smem_buffer, [for_op.induction_variable]) + memref.store( + x, + block_gmem_buffer, + [arith.addi(for_op.induction_variable, c(3, index))], + ) + scf.yield_([]) + scf.yield_([]) diff --git 
a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py new file mode 100644 index 000000000..103187e23 --- /dev/null +++ b/jax/experimental/mosaic/gpu/utils.py @@ -0,0 +1,634 @@ +# Copyright 2024 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for code generator.""" + +import contextlib +import dataclasses +from typing import Any, Literal, Sequence + +from absl import flags +import jax +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import arith +from jaxlib.mlir.dialects import builtin +from jaxlib.mlir.dialects import gpu +from jaxlib.mlir.dialects import llvm +from jaxlib.mlir.dialects import memref +from jaxlib.mlir.dialects import nvgpu +from jaxlib.mlir.dialects import nvvm +from jaxlib.mlir.dialects import scf +from jaxlib.mlir.dialects import vector +import numpy as np + +# mypy: ignore-errors + +WARPGROUP_SIZE: int = 128 +DYNAMIC = -9223372036854775808 + +FLAGS = flags.FLAGS + +flags.DEFINE_bool("mosaic_gpu_debug", False, "Perform debug printing") + +# pylint: disable=line-too-long, wildcard-import, missing-function-docstring, bad-continuation, g-bad-todo, protected-access, g-explicit-length-test, missing-class-docstring, g-doc-return-or-yield, g-inconsistent-quotes + + +def ptr_as_memref(ptr, memref_ty: ir.MemRefType): + if len(memref_ty.shape) == 0: + raise NotImplementedError + i64 = ir.IntegerType.get_signless(64) + rank = len(memref_ty.shape) + desc_ty = ir.Type.parse( + f"!llvm.struct<(ptr, ptr, i64, array<{rank} x i64>, array<{rank} x i64>)>" + ) + desc = llvm.UndefOp(desc_ty) + desc = llvm.InsertValueOp(desc, ptr, [0]) # Allocation + desc = llvm.InsertValueOp(desc, ptr, [1]) # Aligned Base + desc = llvm.InsertValueOp( + desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, 0)), [2] + ) + for i, s in enumerate(memref_ty.shape): + desc = llvm.InsertValueOp( + desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, s)), [3, i] + ) + for i, s in enumerate(get_contiguous_strides(memref_ty.shape)): + desc = llvm.InsertValueOp( + desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, s)), [4, i] + ) + return builtin.unrealized_conversion_cast([memref_ty], [desc]) + + +def get_contiguous_strides(xs): + strides_ret = [] + stride = 1 + for x in xs[::-1]: + strides_ret.append(stride) + stride *= x + return strides_ret[::-1] + + +def c(val: int | float, ty): + if ir.IntegerType.isinstance(ty) or ir.IndexType.isinstance(ty): + if not isinstance(val, (int, np.integer)): + raise TypeError(type(val)) + attr = ir.IntegerAttr.get(ty, val) + elif ir.FloatType.isinstance(ty): + attr = ir.FloatAttr.get(ty, val) + elif ir.VectorType.isinstance(ty): + return vector.splat(ty, c(val, ir.VectorType(ty).element_type)) + else: + raise NotImplementedError(ty) + return arith.constant(ty, attr) + + +def get_tensormap_descriptor(**attrs): + return ir.Type.parse( + f"!nvgpu.tensormap.descriptor<{', '.join(k + '=' + v for k, v in attrs.items())}>" + ) + + +def 
debug_print(fmt, *args, uniform=True): + if not FLAGS.mosaic_gpu_debug: + return + type_formats = [] + new_args = [] + for arg in args: + ty_format = None + if ir.IndexType.isinstance(arg.type): + ty_format = "%llu" + if ir.IntegerType.isinstance(arg.type): + width = ir.IntegerType(arg.type).width + if width == 64: + ty_format = "%llu" + elif width == 1: + ty_format = "%llu" + arg = arith.extui(ir.IntegerType.get_signless(64), arg) + if ir.F32Type.isinstance(arg.type): + ty_format = "%f" + if ir.F16Type.isinstance(arg.type): + ty_format = "%f" + arg = arith.extf(ir.F32Type.get(), arg) + if ty_format is None: + raise NotImplementedError(arg.type) + type_formats.append(ty_format) + new_args.append(arg) + ctx = once if uniform else contextlib.nullcontext + with ctx(): + gpu.printf(fmt.format(*type_formats) + "\n", new_args) + + +@dataclasses.dataclass(frozen=True) +class ForResult: + op: scf.ForOp + results: tuple[Any, ...] + + @property + def result(self): + if len(self.results) != 1: + raise ValueError + return self.results[0] + + +def fori(bound, carrys): + unwrap = False + if not isinstance(carrys, (list, tuple)): + carrys = [carrys] + unwrap = True + flat_carrys, carry_treedef = jax.tree.flatten(carrys) + + def wrapper(f): + index = ir.IndexType.get() + c0 = arith.ConstantOp(index, ir.IntegerAttr.get(index, 0)) + c1 = arith.ConstantOp(index, ir.IntegerAttr.get(index, 1)) + for_op = scf.ForOp(c0, bound, c1, flat_carrys) + with ir.InsertionPoint(for_op.body): + i = for_op.induction_variable + inner_carrys = jax.tree.unflatten(carry_treedef, for_op.inner_iter_args) + if unwrap: + [inner_carrys] = inner_carrys + new_carrys = f(i, inner_carrys) + if unwrap: + new_carrys = [new_carrys] + new_flat_carrys, new_carry_treedef = jax.tree.flatten(new_carrys) + if new_carry_treedef != carry_treedef: + raise ValueError(new_carry_treedef, carry_treedef) + scf.YieldOp(new_flat_carrys) + final_flat_carrys = for_op.results + return ForResult( + for_op, jax.tree.unflatten(carry_treedef, final_flat_carrys) + ) + + return wrapper + + +def get_warp_idx(): + i32 = ir.IntegerType.get_signless(32) + tidx = arith.index_cast(i32, gpu.thread_id(gpu.Dimension.x)) + warp_idx = arith.shrui(tidx, c(5, tidx.type)) + mask = c(0xFFFFFFFF, i32) + return nvvm.shfl_sync( + warp_idx.type, mask, warp_idx, c(0, i32), c(0x1F, i32), nvvm.ShflKind.idx + ) + + +# True withon `once()` contexts. +_ONCE_REGION_ACTIVE = False + + +@contextlib.contextmanager +def once(): + """Runs the context only from a single thread from the first warp. + + The block is assumed to have a size of 1 in both y and z dimensions. 
+ """ + global _ONCE_REGION_ACTIVE + + if _ONCE_REGION_ACTIVE: + yield + return + + warp = get_warp_idx() + first_warp = arith.cmpi(arith.CmpIPredicate.eq, warp, c(0, warp.type)) + elected = nvvm.elect_sync(ir.IntegerType.get_signless(1)) + should_run = arith.andi(first_warp, elected) + if_op = scf.IfOp(should_run) + _ONCE_REGION_ACTIVE = True + try: + with ir.InsertionPoint(if_op.then_block): + yield + scf.YieldOp([]) + finally: + _ONCE_REGION_ACTIVE = False + + +def clock(): + i32 = ir.IntegerType.get_signless(32) + return llvm.inline_asm( + i32, [], "mov.u32 $0,%clock;", "=r", asm_dialect=0, has_side_effects=True + ) + + +def globaltimer(kind: Literal["low", "high"] | None = None): + if kind is None: + i64 = ir.IntegerType.get_signless(64) + return llvm.inline_asm( + i64, [], "mov.u32 $0,%globaltimer;", + "=l", asm_dialect=0, has_side_effects=True, + ) + i32 = ir.IntegerType.get_signless(32) + return llvm.inline_asm( + i32, [], f"mov.u32 $0,%globaltimer_{kind[:2]};", + "=r", asm_dialect=0, has_side_effects=True, + ) + + +def bytewidth(ty: ir.Type): + if ir.IntegerType.isinstance(ty): + return ir.IntegerType(ty).width // 8 + if ir.FloatType.isinstance(ty): + return ir.FloatType(ty).width // 8 + raise NotImplementedError(ty) + + +@dataclasses.dataclass(frozen=True) +class DynamicSlice: + base: ir.Value | int + length: int + + +ds = DynamicSlice + + +def memref_slice(ref: ir.Value, index) -> ir.Value: + ref_ty = ir.MemRefType(ref.type) + base_indices, slice_shape, is_squeezed = parse_indices(index, ref_ty.shape) + + memref_strides, offset = ref_ty.get_strides_and_offset() + new_offset = offset + for idx, stride in zip(base_indices, memref_strides): + if isinstance(idx, int): + new_offset += idx * stride + else: + new_offset = ir.ShapedType.get_dynamic_stride_or_offset() + break + new_strides = [ + s for s, squeeze in zip(memref_strides, is_squeezed) if not squeeze + ] + new_shape = [s for s, squeeze in zip(slice_shape, is_squeezed) if not squeeze] + new_layout = ir.StridedLayoutAttr.get(new_offset, new_strides) + + ref_slice = memref.subview( + ref, base_indices, slice_shape, [1] * len(ref_ty.shape), + result_type=ir.MemRefType.get( + new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space + ), + ) + return ref_slice + + +def _is_contiguous_shape_slice( + ref_ty: ir.MemRefType, dim_slice: slice | None = slice(None) +): + # If it's not a strided layout then we are definitely contiguous. + if not ir.StridedLayoutAttr.isinstance(ref_ty.layout): + return True + + strides = ir.StridedLayoutAttr(ref_ty.layout).strides[dim_slice] + shape = ref_ty.shape[dim_slice] + + # Check that each dimension fits exactly it the immediately larger stride. 
+ ss = sorted(zip(strides, shape), key=lambda x: x[0], reverse=True) + for (prev_stride, _), (stride, shape) in zip(ss, ss[1:]): + if stride * shape != prev_stride: + return False + + return True + + +def memref_fold(ref: ir.Value, dim, fold_rank) -> ir.Value: + ref_ty = ir.MemRefType(ref.type) + new_shape = list(ref_ty.shape) + new_shape[dim : dim + fold_rank] = [np.prod(new_shape[dim : dim + fold_rank])] + identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(ref_ty.rank)) + if ref_ty.layout == identity: + new_layout = ir.AffineMapAttr.get( + ir.AffineMap.get_identity(ref_ty.rank - fold_rank + 1) + ) + elif _is_contiguous_shape_slice(ref_ty, slice(dim, dim + fold_rank)): + new_strides, offset = ref_ty.get_strides_and_offset() + new_strides[dim : dim + fold_rank] = [new_strides[dim + fold_rank - 1]] + new_layout = ir.StridedLayoutAttr.get(offset, new_strides) + else: + raise NotImplementedError( + f"strides={ref_ty.get_strides_and_offset()[0]}, {ref_ty.shape=}," + f" {dim=}, {fold_rank=}" + ) + + new_ty = ir.MemRefType.get( + new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space + ) + assoc = [[d] for d in range(dim)] + assoc.append([dim + i for i in range(fold_rank)]) + assoc.extend([d] for d in range(dim + fold_rank, ref_ty.rank)) + assert len(assoc) == new_ty.rank + return memref.collapse_shape(new_ty, ref, assoc) + + +def memref_unfold(ref: ir.Value, dim, factors) -> ir.Value: + """Unfolds dim into two dimensions, the size of leading one given be major_factor.""" + ref_ty = ir.MemRefType(ref.type) + new_shape = list(ref_ty.shape) + if sum(f is None for f in factors) > 1: + raise ValueError("Can only infer one dimension") + known_factor_prod = np.prod([f for f in factors if f is not None]) + if new_shape[dim] % known_factor_prod: + raise ValueError("Non-divisible unfold:", new_shape[dim], factors) + factors = tuple( + new_shape[dim] // known_factor_prod if f is None else f for f in factors + ) + new_shape[dim : dim + 1] = factors + identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(ref_ty.rank)) + if ref_ty.layout == identity: + new_layout = ir.AffineMapAttr.get( + ir.AffineMap.get_identity(ref_ty.rank + len(factors) - 1) + ) + else: + new_strides, offset = ref_ty.get_strides_and_offset() + prev_stride = new_strides[dim] + inserted_strides = [] + for f in reversed(factors): + inserted_strides.append(prev_stride) + prev_stride *= f + new_strides[dim : dim + 1] = reversed(inserted_strides) + new_layout = ir.StridedLayoutAttr.get(offset, new_strides) + new_ty = ir.MemRefType.get( + new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space + ) + if dim == ref_ty.rank: + assoc = [[d] for d in range(ref_ty.rank)] + assoc[-1].extend(range(ref_ty.rank, ref_ty.rank + len(factors) - 1)) + else: + assoc = [[d] for d in range(dim)] + assoc.append(list(range(dim, dim + len(factors)))) + assoc.extend([d + len(factors) - 1] for d in range(dim + 1, ref_ty.rank)) + assert len(assoc) == ref_ty.rank + return memref.expand_shape(new_ty, ref, assoc) + + +def memref_unsqueeze(ref: ir.Value, dim) -> ir.Value: + """Inserts a singleton dimension.""" + ref_ty = ir.MemRefType(ref.type) + if dim == ref_ty.rank: + new_shape = list(ref_ty.shape) + new_shape.append(1) + identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(ref_ty.rank)) + if ref_ty.layout == identity: + new_layout = ir.AffineMapAttr.get( + ir.AffineMap.get_identity(ref_ty.rank + 1) + ) + else: + new_strides, offset = ref_ty.get_strides_and_offset() + new_strides.append(1) + new_layout = ir.StridedLayoutAttr.get(offset, 
new_strides)
+    new_ty = ir.MemRefType.get(
+        new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space
+    )
+    assoc = [[d] for d in range(ref_ty.rank)]
+    assoc[-1].append(ref_ty.rank)
+    return memref.expand_shape(new_ty, ref, assoc)
+  else:
+    return memref_unfold(ref, dim, (1, None))
+
+
+def memref_transpose(ref: ir.Value, permutation: Sequence[int]) -> ir.Value:
+  ref_ty = ir.MemRefType(ref.type)
+  strides, offset = ref_ty.get_strides_and_offset()
+  new_strides = [strides[p] for p in permutation]
+  new_shape = [ref_ty.shape[p] for p in permutation]
+  new_layout = ir.StridedLayoutAttr.get(offset, new_strides)
+  new_ty = ir.MemRefType.get(
+      new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space
+  )
+  return memref.transpose(
+      new_ty, ref, ir.AffineMap.get_permutation(permutation)
+  )
+
+
+def parse_indices(
+    index, shape: tuple[int, ...]
+) -> tuple[list[ir.Value | int], list[int], list[bool]]:
+  if not isinstance(index, tuple):
+    index = (index,)
+  if trailing_dims := len(shape) - len(index):
+    index += (slice(None),) * trailing_dims
+  base_indices = []
+  slice_shape = []
+  is_squeezed = []
+  for idx, bound in zip(index, shape):
+    if isinstance(idx, (ir.Operation, ir.OpView)):
+      idx = idx.result
+    if isinstance(idx, int):
+      base_indices.append(idx)
+      slice_shape.append(1)
+      is_squeezed.append(True)
+    elif isinstance(idx, slice):
+      if idx.step is not None:
+        raise NotImplementedError("Strided slices not implemented")
+      base_indices.append(idx.start or 0)
+      slice_shape.append((idx.stop or bound) - (idx.start or 0))
+      is_squeezed.append(False)
+    elif isinstance(idx, DynamicSlice):
+      base_indices.append(idx.base)
+      slice_shape.append(idx.length)
+      is_squeezed.append(False)
+    elif isinstance(idx, ir.Value):
+      if not ir.IndexType.isinstance(idx.type):
+        raise ValueError("Expected an index-typed index")
+      base_indices.append(idx)
+      slice_shape.append(1)
+      is_squeezed.append(True)
+    else:
+      raise NotImplementedError(type(idx))
+  assert len(base_indices) == len(slice_shape) == len(is_squeezed) == len(shape)
+  return base_indices, slice_shape, is_squeezed
+
+
+def commit_shared():
+  gpu.barrier()
+  nvvm.fence_proxy(
+      nvvm.ProxyKind.async_shared, space=nvvm.SharedSpace.shared_cta
+  )
+
+
+class BarrierArray:
+
+  def __init__(self, num_barriers):
+    barrier_group_ty = ir.Type.parse(
+        "!nvgpu.mbarrier.group<memorySpace=#gpu.address_space<workgroup>,"
+        f" num_barriers={num_barriers}>"
+    )
+
+    self.value = nvgpu.mbarrier_create(barrier_group_ty)
+    index = ir.IndexType.get()
+    if num_barriers > 32:
+      raise NotImplementedError("Only up to 32 barriers per group supported")
+    i32 = ir.IntegerType.get_signless(32)
+    self.phases = memref.alloca(ir.MemRefType.get((), i32), [], [])
+    memref.store(c(0, i32), self.phases, [])
+    with once():
+      for i in range(num_barriers):
+        nvgpu.mbarrier_init(self.value, c(1, index), c(i, index))
+
+  def __getitem__(self, offset: ir.Value | int):
+    if isinstance(offset, int):
+      offset = c(offset, ir.IndexType.get())
+    return Barrier(self, offset)
+
+
+@dataclasses.dataclass(frozen=True)
+class Barrier:
+  barrier_array: BarrierArray
+  offset: ir.Value
+
+  def wait_parity(self, parity):
+    index = ir.IndexType.get()
+    nvgpu.mbarrier_try_wait_parity(
+        self.barrier_array.value, parity, c(10000000, index), self.offset,
+    )
+
+  def wait(self):
+    i32 = ir.IntegerType.get_signless(32)
+    parities = memref.load(self.barrier_array.phases, [])
+    offset_i32 = arith.index_castui(i32, self.offset)
+    bitmask = arith.shli(c(1, i32), offset_i32)
+    parity = arith.cmpi(
+        arith.CmpIPredicate.ne, arith.andi(parities,
bitmask), c(0, i32) + ) + new_parities = arith.xori(parities, bitmask) + memref.store(new_parities, self.barrier_array.phases, []) + self.wait_parity(parity) + + +class Partition: + source_bounds: tuple[int, ...] + target_bounds: tuple[int, ...] + partition: tuple[int | None, ...] + base_offset: tuple[ir.Value, ...] | None + + def __init__( + self, + elements: tuple[int, ...], + *, + partition: tuple[int | None, ...], + base_offset: tuple[ir.Value, ...] | None = None, + num_chunks: tuple[int, ...] | None = None, + chunk_size: tuple[int, ...] | None = None, + ): + self.target_bounds = elements + self.partition = partition + self.base_offset = base_offset + if len(self.target_bounds) != len(self.partition): + raise ValueError + if num_chunks is None == chunk_size is None: + raise ValueError( + "Exactly one of num_chunks and chunk_size must be specified" + ) + if num_chunks is not None: + self.source_bounds = num_chunks + else: + if len(chunk_size) != len(self.target_bounds): + raise ValueError + source_bounds = [] + for els, chunk in zip(elements, chunk_size): + if els % chunk: + raise ValueError("Non-divisible partition", elements, chunk_size) + source_bounds.append(els // chunk) + self.source_bounds = tuple(source_bounds) + + seen_dims = set() + for p in self.partition: + if p is None: + continue + if not (0 <= p < len(self.source_bounds)): + raise ValueError + if p in seen_dims: + raise ValueError + seen_dims.add(p) + for tb, p in zip(self.target_bounds, self.partition): + if p is not None and tb % self.source_bounds[p]: + raise ValueError("Non-divisible partitioning") + + @property + def num_chunks(self) -> tuple[int, ...]: + return self.source_bounds + + @property + def target_block_shape(self): + return tuple(tb if p is None else tb // self.source_bounds[p] + for tb, p in zip(self.target_bounds, self.partition)) + + def get_base(self, *source_coords: ir.Value | int) -> list[ir.Value]: + coords = [] + index = ir.IndexType.get() + for i, (tbs, p) in enumerate(zip(self.target_block_shape, self.partition)): + if p is None: + dim_base = c(0, index) + else: + dim_base = arith.muli(c(tbs, index), source_coords[p]) + if self.base_offset is not None: + dim_base = arith.addi(self.base_offset[i], dim_base) + coords.append(dim_base) + return coords + + +class Partition1D: + partition: Partition + + def __init__( + self, + elements: int, + *, + base_offset: ir.Value | None = None, + num_chunks: int | None = None, + chunk_size: int | None = None, + ): + self.base_offset = base_offset + if num_chunks is None == chunk_size is None: + raise ValueError( + "Exactly one of num_chunks and chunk_size must be specified" + ) + common_kwargs = dict(elements=(elements,), partition=(0,)) + if base_offset is not None: + common_kwargs["base_offset"] = (base_offset,) + if num_chunks is not None: + self.partition = Partition(num_chunks=(num_chunks,), **common_kwargs) + else: + self.partition = Partition(chunk_size=(chunk_size,), **common_kwargs) + + @property + def num_chunks(self) -> int: + return self.partition.source_bounds[0] + + def get_base(self, source_coords: ir.Value) -> ir.Value: + return self.partition.get_base(source_coords)[0] + + def refine( + self, + *, + chunk: ir.Value | None = None, + num_chunks: int | None = None, + chunk_size: int | None = None, + ): + return Partition1D( + self.partition.target_block_shape[0], + num_chunks=num_chunks, + chunk_size=chunk_size, + base_offset=self.get_base(chunk) if chunk is not None else None, + ) + + +def tile_shape(shape, tiling): + if len(tiling) > len(shape): 
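+    # A tiling with more dimensions than the shape cannot be applied.
+    # (For reference, a valid call: tile_shape((128, 256), (64, 32)) returns
+    # (2, 8, 64, 32).)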
+    raise ValueError
+  if not tiling:
+    return shape
+  tiling_rank = len(tiling)
+  for s, t in zip(shape[-tiling_rank:], tiling):
+    if s % t:
+      raise ValueError("Non-divisible tiling:", shape, tiling)
+  return (
+      *shape[:-tiling_rank],
+      *(s // t for s, t in zip(shape[-tiling_rank:], tiling)),
+      *tiling,
+  )
diff --git a/jax/experimental/mosaic/gpu/wgmma.py b/jax/experimental/mosaic/gpu/wgmma.py
new file mode 100644
index 000000000..13161dea8
--- /dev/null
+++ b/jax/experimental/mosaic/gpu/wgmma.py
@@ -0,0 +1,404 @@
+# Copyright 2024 The JAX Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import dataclasses
+import enum
+import itertools
+
+import jax
+from jaxlib.mlir import ir
+from jaxlib.mlir.dialects import arith
+from jaxlib.mlir.dialects import builtin
+from jaxlib.mlir.dialects import llvm
+from jaxlib.mlir.dialects import nvvm
+from jaxlib.mlir.dialects import vector
+import numpy as np
+
+from . import dsl as mgpu
+
+# mypy: ignore-errors
+
+c = mgpu.c
+bytewidth = mgpu.bytewidth
+
+
+@jax.tree_util.register_pytree_node_class
+@dataclasses.dataclass
+class WGMMAAccumulator:
+  """A FragmentedArray that is synchronized with the async proxy.
+
+  This implies that it requires no additional synchronization when passed in
+  as a WGMMA accumulator. In particular, when created from a
+  FragmentedArray, the necessary synchronization is inserted at construction.
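+
+  Unless `_sync=False` is passed, the constructor emits
+  `nvvm.wgmma_fence_aligned` to perform that synchronization.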
+ """ + value: mgpu.FragmentedArray + + def __init__(self, *, _value: mgpu.FragmentedArray, _sync: bool = True): + if _value.layout != mgpu.WGMMA_LAYOUT: + raise ValueError("Only WGMMA layouts supported in WGMMAAccumulator") + self.value = _value + if _sync: + nvvm.wgmma_fence_aligned() + + @classmethod + def zero(cls, m, n): + if m % 64 or n % 8: + raise ValueError + f32 = ir.F32Type.get() + zero = arith.constant(f32, ir.FloatAttr.get(f32, 0.0)) + return cls( + _value=mgpu.FragmentedArray.splat(zero, (m, n), mgpu.WGMMA_LAYOUT) + ) + + @classmethod + def from_registers(cls, registers): + return cls(_value=registers) + + def tree_flatten(self): + return (self.value,), () + + @classmethod + def tree_unflatten(cls, aux, value): + del aux + return cls(_value=value[0], _sync=False) + + +def wgmma_encode(x: int): + result = (x & 0x3FFFF) >> 4 + if result << 4 != x: + raise ValueError("Cannot encode value in a WGMMA descriptor") + return result + + +def get_memref_base(memref_arg, memory_space=None): + i64 = ir.IntegerType.get_signless(64) + memref_ty = ir.MemRefType(memref_arg.type) + if len(memref_ty.shape) == 0: + raise NotImplementedError + elem_bytewidth = bytewidth(memref_ty.element_type) + rank = len(memref_ty.shape) + # TODO: Read out memory space from memref + space = "" if memory_space is None else "<" + str(memory_space) + ">" + ptr_ty = ir.Type.parse("!llvm.ptr" + space) + desc_ty = ir.Type.parse( + f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64, array<{rank} x i64>," + f" array<{rank} x i64>)>" + ) + desc = builtin.UnrealizedConversionCastOp([desc_ty], [memref_arg]) + aligned_ptr = llvm.extractvalue(ptr_ty, desc, [1]) + offset_elems = llvm.extractvalue(i64, desc, [2]) + offset_bytes = llvm.mul(offset_elems, c(elem_bytewidth, i64)) + return llvm.inttoptr( + ptr_ty, llvm.add(llvm.ptrtoint(i64, aligned_ptr), offset_bytes) + ) + + +def create_descriptor( + memref_arg, + leading_byte_offset: int, + stride_byte_offset: int, + swizzle: int | None, + memory_space: int | None = None, + nvgpu_type=None, +): + i64 = ir.IntegerType.get_signless(64) + ptr_val = llvm.ptrtoint(i64, get_memref_base(memref_arg, memory_space)) + if swizzle is None: + swizzle_encoding = 0 + elif swizzle == 128: + swizzle_encoding = 1 + else: + raise NotImplementedError(swizzle) + encoded_base_addr = llvm.LShrOp( + llvm.AndOp(ptr_val, c(0x3FFFF, i64)), c(4, i64) + ) + desc_const = ( + (wgmma_encode(leading_byte_offset) << 16) + | (wgmma_encode(stride_byte_offset) << 32) + | + # We ignore the offset + (swizzle_encoding << 62) + ) + desc = llvm.OrOp(encoded_base_addr, c(desc_const, i64)) + if nvgpu_type is not None: + desc = builtin.UnrealizedConversionCastOp([nvgpu_type], [desc]) + return desc.result + + +def wgmma_m64k128B( + acc: np.ndarray, # of register Values + a, + b_descriptor: ir.Value, + a_transpose: bool | None, + b_transpose: bool, + a_k_stride: int | None, + b_k_stride: int, + n: int, + element_type: ir.Type, +): + f32 = ir.F32Type.get() + i32 = ir.IntegerType.get_signless(32) + i64 = ir.IntegerType.get_signless(64) + index = ir.IndexType.get() + if b_k_stride % 16: + raise ValueError + if n % (128 // bytewidth(element_type)): + raise ValueError + # Only 16-bit types support transposes + supports_transpose = bytewidth(element_type) == 2 + if not supports_transpose and (a_transpose or b_transpose): + raise ValueError("Only f16 WGMMA supports transposes") + if a_in_regs := isinstance(a, mgpu.FragmentedArray): + if a.mlir_dtype != ir.F16Type.get() and a.mlir_dtype != ir.BF16Type.get(): + raise ValueError(f"Unsupported 
A register array dtype: {a.mlir_dtype}") + if a.layout != mgpu.WGMMA_LAYOUT or a.shape != (64, 64): + raise ValueError("Unsupported A register array layout") + if a_k_stride is not None or a_transpose is not None: + raise ValueError("Unsupported WGMMA features with A in registers") + else: + if a_k_stride is None or a_k_stride % 16: + raise ValueError + if a_transpose is None: + raise ValueError + + num_acc_regs = n // 2 + num_imm_regs = 4 if supports_transpose else 2 + + if a_in_regs: + a_reg_constraints = ["r"] * 4 # 4x f16x2 registers + num_imm_regs -= 1 # transpose not supported for a in registers + else: + a_reg_constraints = ["l"] # descriptor + # Reference for i/o aliasing: https://gcc.gnu.org/onlinedocs/gcc/Extended-Asm.html + # Seems like it's not actually documented in LLVM IR docs. + reg_constraints_list = ( + ["=f"] * num_acc_regs # accumulator registers + + [str(i) for i in range(num_acc_regs)] # we alias outputs as inputs, too. + + a_reg_constraints # a descriptor / registers + + ["l"] * 1 # b descriptor + + ["n"] * (1 + num_imm_regs) # literal constants + ) + reg_constraints = ",".join(reg_constraints_list) + + reg_count = itertools.count() + + def take_regs(n): + return (f"${i}" for i in itertools.islice(reg_count, n)) + + acc_reg_vector = "{" + ",".join(take_regs(num_acc_regs)) + "}" + for _ in take_regs(num_acc_regs): # Ignore next entries: aliasing. + pass + if a_in_regs: + a_regs = "{" + ",".join(take_regs(len(a_reg_constraints))) + "}" + else: + a_regs, = take_regs(1) + b_desc_reg, use_out_reg = take_regs(2) + imm_regs = ", ".join(take_regs(num_imm_regs)) # Immediate regs (scale, ...). + assert next(reg_count) == len(reg_constraints_list) + el_ty = element_type + k_instr = 32 // bytewidth(element_type) + wgmma_instr = ( + f"wgmma.mma_async.sync.aligned.m64n{n}k{k_instr}.f32.{el_ty}.{el_ty} " + f"{acc_reg_vector}, {a_regs}, {b_desc_reg}, p, {imm_regs};" + ) + ptx = f"{{ .reg .pred p; setp.ne.b32 p, {use_out_reg}, 0; {wgmma_instr} }}\n" + + def lc(x): + return llvm.mlir_constant(i32, ir.IntegerAttr.get(i32, x)) + + def as_i32_reg(v): + return llvm.extractelement( + vector.bitcast(ir.VectorType.get((1,), i32), v), lc(0) + ) + + use_out = scale_a = scale_b = lc(1) + imms = [use_out, scale_a, scale_b] + if supports_transpose and a_transpose is not None: + imms += [lc(int(a_transpose)), lc(int(b_transpose))] + elif supports_transpose: + imms += [lc(int(b_transpose))] + if acc.ndim != 4 or acc.shape[0] != 1 or acc.shape[2:] != (2, 1): + raise ValueError(acc.shape) + acc_regs = [ # pylint: disable=g-complex-comprehension + vector.extractelement(reg, position=c(pos, index)) + for reg in acc.flat + for pos in range(2) + ] + acc_struct_type = ir.Type.parse( + f"!llvm.struct<({','.join('f32' for _ in acc_regs)})>" + ) + for i in range(4): + # Slice out the relevant part of A or advance the A descriptor. + if a_in_regs: + a_slice = a[:, (i * 16) : ((i + 1) * 16)] + a_args = [as_i32_reg(v) for v in a_slice.registers.flat] + else: + if i > 0: + a = llvm.add( + a, + llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, a_k_stride >> 4)), + ) + a_args = [a] + # Advance the B descriptor. 
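+    # (Descriptors hold the base address in 16-byte units, matching
+    # `wgmma_encode`, so adding `b_k_stride >> 4` advances the base by
+    # b_k_stride bytes.)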
+ if i > 0: + b_descriptor = llvm.add( + b_descriptor, + llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, b_k_stride >> 4)), + ) + assert len(a_args) == len(a_reg_constraints) + acc_struct = llvm.inline_asm( + acc_struct_type, + [*acc_regs, *a_args, b_descriptor, *imms], + ptx, + reg_constraints, + asm_dialect=0, + has_side_effects=True, + ) + acc_regs = [ + llvm.extractvalue(f32, acc_struct, [i]) for i in range(len(acc_regs)) + ] + acc_vec_regs = [] + for first, second in zip(acc_regs[::2], acc_regs[1::2]): + vec = llvm.mlir_undef(ir.VectorType.get((2,), f32)) + vec = llvm.insertelement(vec, first, position=lc(0)) + vec = llvm.insertelement(vec, second, position=lc(1)) + acc_vec_regs.append(vec) + return np.asarray(acc_vec_regs, dtype=object).reshape(acc.shape) + + +class WGMMALayout(enum.Enum): + ROW_MAJOR = enum.auto() + COL_MAJOR = enum.auto() + + +# TODO(apaszke): Remove WGMMALayout. Make input shapes logical and infer +# transpositions from memref strides. +def wgmma( + acc: WGMMAAccumulator, + a, + b, + *, + # Order only applies within each tile! + a_order: WGMMALayout | None = None, + b_order: WGMMALayout = WGMMALayout.ROW_MAJOR, +): + if a_in_regs := isinstance(a, mgpu.FragmentedArray): + a_element_type = a.mlir_dtype + a_shape = a.shape + else: + a_ty = ir.MemRefType(a.type) + a_element_type = a_ty.element_type + a_shape = a_ty.shape + b_ty = ir.MemRefType(b.type) + supported_types = {ir.F16Type.get(), ir.BF16Type.get(), ir.F32Type.get()} + if a_element_type not in supported_types: + raise ValueError(a_element_type) + if b_ty.element_type not in supported_types: + raise ValueError(b_ty.element_type) + if (element_type := a_element_type) != b_ty.element_type: + raise ValueError + element_bytewidth = bytewidth(element_type) + kn_tile = 128 // element_bytewidth + + groups_k, groups_n = b_ty.shape[:2] + if b_ty.shape[2:] != [kn_tile, kn_tile]: + raise ValueError(b_ty.shape) + + if a_in_regs: + if a_element_type != ir.F16Type.get() and a_element_type != ir.BF16Type.get(): + raise ValueError(a_element_type) + if a_shape[0] % 64 or a_shape[1] % kn_tile: + raise ValueError(a_shape) + if a_shape[1] // kn_tile != groups_k: + raise ValueError(a_shape[1] // kn_tile, groups_k) + groups_m = a_shape[0] // 64 + if a_order is not None: + raise ValueError( + "a_order can only be specified when A is in shared memory" + ) + else: + groups_m = a_shape[0] + if a_shape[1] != groups_k: + raise ValueError(a_shape[1], groups_k) + if a_shape[2:] != [64, kn_tile]: + raise ValueError(a_shape) + if a_order is None: + a_order = WGMMALayout.ROW_MAJOR + + row_major = WGMMALayout.ROW_MAJOR + col_major = WGMMALayout.COL_MAJOR + a_desc_fields = dict( + leading_byte_offset=((1 if a_order == row_major else 512) << 4), + stride_byte_offset=(64 << 4), + swizzle=128, + memory_space=3, + ) + b_desc_fields = dict( + leading_byte_offset=((512 if b_order == row_major else 1) << 4), + stride_byte_offset=(64 << 4), + swizzle=128, + memory_space=3, + ) + wgmma_params = dict( + a_transpose=a_order == col_major, + b_transpose=b_order == row_major, + a_k_stride=(2 if a_order == row_major else 128) * 16, + b_k_stride=(128 if b_order == row_major else 2) * 16, + n=(groups_n * kn_tile), + element_type=ir.FloatTF32Type.get() + if ir.F32Type.isinstance(element_type) + else element_type, + ) + if a_in_regs: + wgmma_params["a_k_stride"] = wgmma_params["a_transpose"] = None + + if a_in_regs: + nvvm.wgmma_fence_aligned() # Make sure the registers are ready. + a_m_byte_stride = a_k_byte_stride = a_desc_base = None # Silence pytype. 
+ else: + a_desc_base = create_descriptor(a, **a_desc_fields) + a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset() + a_byte_strides = [s * element_bytewidth for s in a_strides] + a_m_byte_stride, a_k_byte_stride = a_byte_strides[:2] + if a_byte_strides[2:] != [128, element_bytewidth]: + raise ValueError(a_byte_strides) + b_desc_base = create_descriptor(b, **b_desc_fields) + b_strides, _ = b_ty.get_strides_and_offset() + b_byte_strides = [s * element_bytewidth for s in b_strides] + b_k_byte_stride = b_byte_strides[0] + if b_byte_strides[1:] != [128 * kn_tile, 128, element_bytewidth]: + raise ValueError(b_byte_strides) + + i64 = ir.IntegerType.get_signless(64) + new_acc_regs = acc.value.registers.copy() + for mi in range(groups_m): + for ki in range(groups_k): + if a_in_regs: + a_mk = a[mi * 64 : (mi + 1) * 64, ki * kn_tile : (ki + 1) * kn_tile] + else: + a_mk = llvm.add( + a_desc_base, + c(wgmma_encode(mi * a_m_byte_stride + ki * a_k_byte_stride), i64), + ) + b_k = llvm.add(b_desc_base, c(wgmma_encode(ki * b_k_byte_stride), i64)) + new_acc_regs[mi : mi + 1] = wgmma_m64k128B( + new_acc_regs[mi : mi + 1], a_mk, b_k, **wgmma_params + ) + return WGMMAAccumulator( + _value=mgpu.FragmentedArray( + _registers=new_acc_regs, _layout=mgpu.WGMMA_LAYOUT + ), + _sync=False, + ) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 6fea4fc3a..d6dbedbe6 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -17,6 +17,7 @@ load("//jaxlib:symlink_files.bzl", "symlink_files") load( "//jaxlib:jax.bzl", + "if_building_mosaic_gpu", "if_windows", "py_library_providing_imports_info", "pybind_extension", @@ -70,7 +71,7 @@ py_library_providing_imports_info( "//jaxlib/mlir:vector_dialect", "//jaxlib/mosaic", "//jaxlib/triton", - ], + ] + if_building_mosaic_gpu(["//jaxlib/mosaic/gpu:mosaic_gpu"]), ) symlink_files( diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 07c6121d6..bf5e02414 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -161,6 +161,12 @@ def if_building_jaxlib(if_building, if_not_building = []): "//conditions:default": if_not_building, }) +def if_building_mosaic_gpu(if_building, if_not_building = []): + return select({ + "//jax:enable_mosaic_gpu": if_building, + "//conditions:default": if_not_building, + }) + # buildifier: disable=function-docstring def jax_test( name, diff --git a/jaxlib/mlir/BUILD.bazel b/jaxlib/mlir/BUILD.bazel index c9268aaf3..48dc7603c 100644 --- a/jaxlib/mlir/BUILD.bazel +++ b/jaxlib/mlir/BUILD.bazel @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-load("//jaxlib:symlink_files.bzl", "symlink_inputs") +load("//jaxlib:symlink_files.bzl", "symlink_files", "symlink_inputs") package( default_visibility = [ @@ -217,3 +217,94 @@ symlink_inputs( "//jaxlib/mlir/_mlir_libs:_stablehlo", ], ) + +symlink_inputs( + name = "execution_engine", + rule = py_library, + symlinked_inputs = {"srcs": { + ".": [ + "@llvm-project//mlir/python:ExecutionEnginePyFiles", + ], + }}, + deps = [ + ":mlir", + "//jaxlib/mlir/_mlir_libs:_mlirExecutionEngine", + ], +) + +symlink_inputs( + name = "nvgpu_dialect", + rule = py_library, + symlinked_inputs = {"srcs": {"dialects": [ + "@llvm-project//mlir/python:NVGPUOpsPyFiles", + ]}}, + deps = [ + ":core", + ":ir", + ":mlir", + ], +) + +symlink_inputs( + name = "nvvm_dialect", + rule = py_library, + symlinked_inputs = {"srcs": {"dialects": [ + "@llvm-project//mlir/python:NVVMOpsPyFiles", + ]}}, + deps = [ + ":core", + ":ir", + ":mlir", + ], +) + +symlink_files( + name = "gpu_files", + srcs = ["@llvm-project//mlir/python:GPUOpsPyFiles"], + dst = "dialects", + flatten = True, +) + +symlink_files( + name = "gpu_package_files", + srcs = ["@llvm-project//mlir/python:GPUOpsPackagePyFiles"], + dst = "dialects/gpu", + flatten = True, +) + +symlink_files( + name = "gpu_package_passes_files", + srcs = ["@llvm-project//mlir/python:GPUOpsPackagePassesPyFiles"], + dst = "dialects/gpu/passes", + flatten = True, +) + +py_library( + name = "gpu_dialect", + srcs = [ + ":gpu_files", + ":gpu_package_files", + ":gpu_package_passes_files", + ], + deps = [ + ":core", + ":ir", + ":mlir", + "//jaxlib/mlir/_mlir_libs:_mlirGPUPasses", + ], +) + +symlink_inputs( + name = "llvm_dialect", + rule = py_library, + symlinked_inputs = {"srcs": {"dialects": [ + "@llvm-project//mlir/python:LLVMOpsPyFiles", + ]}}, + deps = [ + ":core", + ":ir", + ":mlir", + "//jaxlib/mlir/_mlir_libs:_mlirDialectsLLVM", + ], +) + diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index 54e58926e..6003d148a 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -14,6 +14,7 @@ load( "//jaxlib:jax.bzl", + "if_building_mosaic_gpu", "if_windows", "py_extension", "pybind_extension", @@ -58,6 +59,51 @@ py_extension( ], ) +py_extension( + name = "_mlirExecutionEngine", + srcs = [ + "@llvm-project//mlir:lib/Bindings/Python/ExecutionEngineModule.cpp", + ], + copts = COPTS, + linkopts = LINKOPTS, + deps = [ + ":jaxlib_mlir_capi_shared_library", + "@llvm-project//mlir:CAPIExecutionEngineHeaders", + "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps", + "@pybind11", + ], +) + +py_extension( + name = "_mlirGPUPasses", + srcs = [ + "@llvm-project//mlir:lib/Bindings/Python/GPUPasses.cpp", + ], + copts = COPTS, + linkopts = LINKOPTS, + deps = [ + ":jaxlib_mlir_capi_shared_library", + "@llvm-project//mlir:CAPIGPUHeaders", + "@pybind11", + ], +) + +py_extension( + name = "_mlirDialectsLLVM", + srcs = [ + "@llvm-project//mlir:lib/Bindings/Python/DialectLLVM.cpp", + ], + copts = COPTS, + linkopts = LINKOPTS, + deps = [ + ":jaxlib_mlir_capi_shared_library", + "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPILLVMHeaders", + "@llvm-project//mlir:MLIRBindingsPythonHeaders", + "@pybind11", + ], +) + py_extension( name = "_mlirDialectsSparseTensor", srcs = [ @@ -148,11 +194,36 @@ symlink_inputs( ], ) +cc_library( + name = "jaxlib_mlir_capi_shims", + srcs = ["jaxlib_mlir_capi_shims.cc"], + hdrs = ["jaxlib_mlir_capi_shims.h"], + deps = [ + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + 
"@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:GPUPipelines", + "@llvm-project//mlir:GPUToLLVMIRTranslation", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:MemRefTransforms", + "@llvm-project//mlir:NVVMTarget", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + ], + alwayslink = 1, +) + +cc_library( + name = "jaxlib_mlir_capi_shims_hdrs", + hdrs = ["jaxlib_mlir_capi_shims.h"], + deps = [ + "@llvm-project//mlir:CAPIIRHeaders", + ], +) + # JAX-specific registrations. py_extension( name = "register_jax_dialects", srcs = ["register_jax_dialects.cc"], - copts = COPTS, + copts = COPTS + if_building_mosaic_gpu(["-DJAXLIB_MOSAIC_GPU"]), linkopts = LINKOPTS, deps = [ ":jaxlib_mlir_capi_shared_library", @@ -166,7 +237,14 @@ py_extension( "@llvm-project//mlir:MLIRBindingsPythonHeaders", "@local_config_python//:headers", "@pybind11", - ], + ] + if_building_mosaic_gpu([ + ":jaxlib_mlir_capi_shims_hdrs", + "@llvm-project//mlir:CAPIGPUHeaders", + "@llvm-project//mlir:CAPINVGPUHeaders", + "@llvm-project//mlir:CAPINVVMHeaders", + "@llvm-project//mlir:CAPILLVMHeaders", + "@llvm-project//mlir:CAPIConversionHeaders", + ]), ) ##---------------------------------------------------------------------------## @@ -270,7 +348,15 @@ cc_library( [ "//jaxlib/triton:triton_dialect_capi_objects", ], - ), + ) + if_building_mosaic_gpu([ + ":jaxlib_mlir_capi_shims", + "@llvm-project//mlir:CAPIConversionObjects", + "@llvm-project//mlir:CAPIExecutionEngineObjects", + "@llvm-project//mlir:CAPIGPUObjects", + "@llvm-project//mlir:CAPILLVMObjects", + "@llvm-project//mlir:CAPINVGPUObjects", + "@llvm-project//mlir:CAPINVVMObjects", + ]), ) cc_binary( diff --git a/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi_shims.cc b/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi_shims.cc new file mode 100644 index 000000000..1f13f48d7 --- /dev/null +++ b/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi_shims.cc @@ -0,0 +1,47 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi_shims.h" + +#include "mlir-c/IR.h" +#include "mlir/CAPI/IR.h" +#include "mlir/Dialect/GPU/Pipelines/Passes.h" +#include "mlir/Target/LLVM/NVVM/Target.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" + +extern "C" { + +void jaxMlirRegisterMemRefPasses() { + mlir::memref::registerMemRefPasses(); +} + +void jaxMlirRegisterInterfaceExternalModels(MlirDialectRegistry registry) { + mlir::NVVM::registerNVVMTargetInterfaceExternalModels(*unwrap(registry)); + mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels( + *unwrap(registry)); + mlir::registerBuiltinDialectTranslation(*unwrap(registry)); + mlir::registerGPUDialectTranslation(*unwrap(registry)); + mlir::registerLLVMDialectTranslation(*unwrap(registry)); + mlir::registerNVVMDialectTranslation(*unwrap(registry)); +} +void jaxMlirRegisterGPUToNVVMPipeline() { + mlir::gpu::registerGPUToNVVMPipeline(); +} + +} diff --git a/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi_shims.h b/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi_shims.h new file mode 100644 index 000000000..bebf40a7a --- /dev/null +++ b/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi_shims.h @@ -0,0 +1,34 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_MLIR_CAPI_SHIMS +#define JAXLIB_MLIR_CAPI_SHIMS + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_CAPI_EXPORTED void jaxMlirRegisterMemRefPasses(); +MLIR_CAPI_EXPORTED void jaxMlirRegisterInterfaceExternalModels(MlirDialectRegistry registry); +MLIR_CAPI_EXPORTED void jaxMlirRegisterGPUToNVVMPipeline(); + +#ifdef __cplusplus +} +#endif + +#endif // JAXLIB_MLIR_CAPI_SHIMS diff --git a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc index 87b1cfbb3..6f5156837 100644 --- a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc +++ b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc @@ -9,6 +9,14 @@ #include "mlir-c/Dialect/Vector.h" #include "mlir-c/Transforms.h" #include "mlir/Bindings/Python/PybindAdaptors.h" +#ifdef JAXLIB_MOSAIC_GPU +#include "mlir-c/Dialect/GPU.h" +#include "mlir-c/Dialect/NVGPU.h" +#include "mlir-c/Dialect/NVVM.h" +#include "mlir-c/Dialect/LLVM.h" +#include "mlir-c/Conversion.h" +#include "jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi_shims.h" +#endif #define REGISTER_DIALECT(name) \ MlirDialectHandle name##_dialect = mlirGetDialectHandle__##name##__(); \ @@ -27,5 +35,17 @@ PYBIND11_MODULE(register_jax_dialects, m) { mlirRegisterTransformsPasses(); // Transforms used by JAX. 
mlirRegisterTransformsStripDebugInfo(); +#ifdef JAXLIB_MOSAIC_GPU + REGISTER_DIALECT(gpu); + REGISTER_DIALECT(nvgpu); + REGISTER_DIALECT(nvvm); + REGISTER_DIALECT(llvm); + mlirRegisterGPUPasses(); + mlirRegisterConversionPasses(); + // TODO(apaszke): Upstream and remove those. + jaxMlirRegisterMemRefPasses(); + jaxMlirRegisterInterfaceExternalModels(registry); + jaxMlirRegisterGPUToNVVMPipeline(); +#endif }); } diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD new file mode 100644 index 000000000..e00baf451 --- /dev/null +++ b/jaxlib/mosaic/gpu/BUILD @@ -0,0 +1,76 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_python//python:defs.bzl", "py_library") +load("//jaxlib:jax.bzl", "pybind_extension") + +package( + default_applicable_licenses = [], + default_visibility = ["//:__subpackages__"], +) + +py_library( + name = "mosaic_gpu", + data = [":libmlir_cuda_runtime.so"], + deps = [ + ":_mosaic_gpu_ext", + "//jaxlib/mlir:execution_engine", + "//jaxlib/mlir:gpu_dialect", + "//jaxlib/mlir:llvm_dialect", + "//jaxlib/mlir:nvgpu_dialect", + "//jaxlib/mlir:nvvm_dialect", + ], +) + +pybind_extension( + name = "_mosaic_gpu_ext", + srcs = ["_mosaic_gpu_ext.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + linkopts = select({ + "@xla//xla/python:use_jax_cuda_pip_rpaths": [ + "-Wl,-rpath,$$ORIGIN/../../../nvidia/cuda_runtime/lib", + ], + "//conditions:default": [], + }), + deps = [ + "//jaxlib:kernel_nanobind_helpers", + "@xla//xla/service:custom_call_status", + "@nanobind", + ], +) + +cc_binary( + name = "libmlir_cuda_runtime.so", + srcs = ["@llvm-project//mlir:lib/ExecutionEngine/CudaRuntimeWrappers.cpp"], + copts = ["-fvisibility=default"], + linkopts = select({ + "@xla//xla/python:use_jax_cuda_pip_rpaths": [ + "-Wl,-rpath,$$ORIGIN/../../../nvidia/cuda_runtime/lib", + ], + "//conditions:default": [], + }), + linkshared = 1, + tags = [ + "manual", + "notap", + ], + deps = [ + "@llvm-project//mlir:mlir_c_runner_utils_hdrs", + "@xla//xla/tsl/cuda:cudart", + "@local_config_cuda//cuda:cuda_headers", + ], +) diff --git a/jaxlib/mosaic/gpu/_mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/_mosaic_gpu_ext.cc new file mode 100644 index 000000000..1eb137da8 --- /dev/null +++ b/jaxlib/mosaic/gpu/_mosaic_gpu_ext.cc @@ -0,0 +1,24 @@ +#include "nanobind/nanobind.h" +#include "jaxlib/kernel_nanobind_helpers.h" +#include "xla/service/custom_call_status.h" + +namespace jax::cuda { +namespace { + +namespace nb = nanobind; +using MosaicHostFunc = void(void**); + +void MosaicKernelCall(void* stream, void** buffers, char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + void* args[2] = {&stream, &buffers}; + MosaicHostFunc* func = *reinterpret_cast(opaque); + func(args); +} + +NB_MODULE(_mosaic_gpu_ext, m) { + m.def("_custom_call_capsule", + []() { return EncapsulateFunction(MosaicKernelCall); }); +} + +} // namespace +} // namespace jax::cuda diff --git a/jaxlib/setup.py b/jaxlib/setup.py index ce3d38e2d..d73bfd7b1 100644 --- 
a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -98,10 +98,13 @@ setup( 'cuda/*', 'cuda/nvvm/libdevice/libdevice*', 'mosaic/*.py', + 'mosaic/gpu/*.so', 'mosaic/python/*.py', 'mosaic/python/*.so', 'mlir/*.py', 'mlir/dialects/*.py', + 'mlir/dialects/gpu/*.py', + 'mlir/dialects/gpu/passes/*.py', 'mlir/extras/*.py', 'mlir/_mlir_libs/*.dll', 'mlir/_mlir_libs/*.dylib', diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index ed10d730f..153b566fe 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -266,12 +266,27 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi "__main__/jaxlib/mosaic/python/_tpu_gen.py", dst_dir=mosaic_python_dir ) + has_mosaic_gpu = exists(f"__main__/jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}") + def if_has_mosaic_gpu(extras): + return extras if has_mosaic_gpu else [] + + if has_mosaic_gpu: + copy_runfiles( + dst_dir=jaxlib_dir / "mosaic" / "gpu", + src_files=[ + "__main__/jaxlib/mosaic/gpu/libmlir_cuda_runtime.so", + f"__main__/jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}", + ], + ) + copy_runfiles( dst_dir=jaxlib_dir / "mlir", src_files=[ "__main__/jaxlib/mlir/ir.py", "__main__/jaxlib/mlir/passmanager.py", - ], + ] + if_has_mosaic_gpu([ + "__main__/jaxlib/mlir/execution_engine.py", + ]), ) copy_runfiles( dst_dir=jaxlib_dir / "mlir" / "dialects", @@ -302,7 +317,19 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi "__main__/jaxlib/mlir/dialects/sparse_tensor.py", "__main__/jaxlib/mlir/dialects/stablehlo.py", "__main__/jaxlib/mlir/dialects/vector.py", - ], + ] + if_has_mosaic_gpu([ + "__main__/jaxlib/mlir/dialects/_gpu_enum_gen.py", + "__main__/jaxlib/mlir/dialects/_gpu_ops_gen.py", + "__main__/jaxlib/mlir/dialects/_nvgpu_enum_gen.py", + "__main__/jaxlib/mlir/dialects/_nvgpu_ops_gen.py", + "__main__/jaxlib/mlir/dialects/_nvvm_enum_gen.py", + "__main__/jaxlib/mlir/dialects/_nvvm_ops_gen.py", + "__main__/jaxlib/mlir/dialects/_llvm_enum_gen.py", + "__main__/jaxlib/mlir/dialects/_llvm_ops_gen.py", + "__main__/jaxlib/mlir/dialects/nvgpu.py", + "__main__/jaxlib/mlir/dialects/nvvm.py", + "__main__/jaxlib/mlir/dialects/llvm.py", + ]), ) copy_runfiles( dst_dir=jaxlib_dir / "mlir" / "extras", @@ -310,6 +337,20 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi "__main__/jaxlib/mlir/extras/meta.py", ], ) + if has_mosaic_gpu: + copy_runfiles( + dst_dir=jaxlib_dir / "mlir" / "dialects" / "gpu", + src_files=[ + "__main__/jaxlib/mlir/dialects/gpu/__init__.py", + ], + ) + copy_runfiles( + dst_dir=jaxlib_dir / "mlir" / "dialects" / "gpu" / "passes", + src_files=[ + "__main__/jaxlib/mlir/dialects/gpu/passes/__init__.py", + ], + ) + if build_utils.is_windows(): capi_so = "__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll" @@ -339,7 +380,10 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi f"__main__/jaxlib/mlir/_mlir_libs/_triton_ext.{pyext}", "__main__/jaxlib/mlir/_mlir_libs/_triton_ext.pyi", ] - ), + ) + if_has_mosaic_gpu([ + f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsLLVM.{pyext}", + f"__main__/jaxlib/mlir/_mlir_libs/_mlirExecutionEngine.{pyext}", + ]), ) triton_dir = jaxlib_dir / "triton" diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD new file mode 100644 index 000000000..dd22ab80b --- /dev/null +++ b/tests/mosaic/BUILD @@ -0,0 +1,52 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load( + "//jaxlib:jax.bzl", + "jax_generate_backend_suites", + "jax_test", + "py_deps", +) + +licenses(["notice"]) + +package( + default_applicable_licenses = [], + default_visibility = ["//visibility:private"], +) + +jax_generate_backend_suites() + +jax_test( + name = "gpu_test", + srcs = [ + "gpu_test.py", + ], + disable_backends = [ + "cpu", + "tpu", + ], + disable_configs = [ + "gpu", + "gpu_a100", + "gpu_p100", + "gpu_p100_x32", + "gpu_x32", + "gpu_pjrt_c_api", + ], + shard_count = 4, + deps = [ + "//jax:mosaic_gpu", + ] + py_deps("absl/testing") + py_deps("numpy"), +) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py new file mode 100644 index 000000000..e3c1bf993 --- /dev/null +++ b/tests/mosaic/gpu_test.py @@ -0,0 +1,803 @@ +# Copyright 2024 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for Mosaic GPU DSL functions and utilities.""" + +import operator +from typing import Optional + +from absl.testing import absltest, parameterized +import numpy as np +import jax +import jax.numpy as jnp +from jax._src import config +from jax._src import test_util as jtu +from jax._src.interpreters import mlir +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import arith +from jax._src.lib.mlir.dialects import scf +from jax._src.lib.mlir.dialects import vector +try: + import jax._src.lib.mosaic_gpu # noqa: F401 + HAS_MOSAIC_GPU = True +except ImportError: + HAS_MOSAIC_GPU = False +else: + from jax.experimental.mosaic import gpu as mosaic_gpu + from jax.experimental.mosaic.gpu import dsl as mgpu + from jax.experimental.mosaic.gpu.utils import * # noqa: F403 + from jax._src.lib.mlir.dialects import gpu + from jax._src.lib.mlir.dialects import llvm + + +# ruff: noqa: F405 +config.update("jax_traceback_filtering", "off") +config.parse_flags_with_absl() + +def nd_loop(bounds, body, *, _idxs = ()): + if not bounds: + body(*_idxs) + return + bound, *other_bounds = bounds + @fori(bound, ()) + def _loop_body(i, _): + nd_loop(other_bounds, body, _idxs=(*_idxs, i)) + return () + + +def mlir_sum(elems): + assert elems + total = elems[0] + for elem in elems[1:]: + total = arith.addi(total, elem) + return total + + +def copy(src: ir.Value, dst: ir.Value, swizzle: Optional[int] = None): + index = ir.IndexType.get() + thread_id = gpu.thread_id(gpu.Dimension.x) + stride = gpu.block_dim(gpu.Dimension.x) + for dim in (gpu.Dimension.y, gpu.Dimension.z): + thread_id = arith.addi(thread_id, arith.muli(gpu.thread_id(dim), stride)) + stride = arith.muli(stride, gpu.block_dim(dim)) + is_first_thread = arith.cmpi(arith.CmpIPredicate.eq, thread_id, c(0, index)) + src_ty = ir.MemRefType(src.type) + dst_ty = ir.MemRefType(dst.type) + if src_ty.shape != dst_ty.shape: + raise ValueError( + f"src and dst shapes don't match: {src_ty.shape} != {dst_ty.shape}" + ) + shape = src_ty.shape + dyn_strides = [c(s, index) for s in get_contiguous_strides(shape)] + with ir.InsertionPoint(scf.IfOp(is_first_thread).then_block): + def body(*idx): + dst_idx = idx + if swizzle is not None: + if swizzle != 128: + raise NotImplementedError("Only swizzle 128B implemented") + # TODO(apaszke): This can probably be cleaned up. + # But it works and it's test-only, so it doesn't matter much. + # After all, swizzle should just be an xor of row and linear idx, + # adjusted for the bytewidth. 
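+        # For example, for f32 inputs (4-byte elements) the quantities below
+        # are elems_per_tile == 256, elems_per_row == 32 and
+        # elems_per_group == 4: each 128-byte row consists of eight 16-byte
+        # groups, and the group index gets XORed with the row index.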
+ bytes_per_element = bytewidth(src_ty.element_type) + elems_per_tile = 1024 // bytes_per_element + elems_per_row = elems_per_tile // 8 + elems_per_group = 16 // bytes_per_element + linear_idx = c(0, index) + for stride, i in zip(dyn_strides, idx): + linear_idx = arith.addi(linear_idx, arith.muli(i, stride)) + tile_offset = arith.remui(linear_idx, c(elems_per_tile, index)) + linear_tile_start = arith.subi(linear_idx, tile_offset) + row = arith.divui(tile_offset, c(elems_per_row, index)) + row_offset = arith.remui(tile_offset, c(elems_per_row, index)) + src_group = arith.divui(row_offset, c(elems_per_group, index)) + group_offset = arith.remui(row_offset, c(elems_per_group, index)) + dst_group = arith.xori(src_group, row) + dst_linear_idx = mlir_sum([ + linear_tile_start, + arith.muli(row, c(elems_per_row, index)), + arith.muli(dst_group, c(elems_per_group, index)), + group_offset, + ]) + dst_idx = [ + arith.remui(arith.divui(dst_linear_idx, stride), c(bound, index)) + for stride, bound in zip(dyn_strides, shape) + ] + memref.store(memref.load(src, idx), dst, dst_idx) + nd_loop([c(d, index) for d in shape], body) + scf.yield_([]) + gpu.barrier() + nvvm.fence_proxy(nvvm.ProxyKind.async_) + + +def iota_tensor(m, n, mlir_dtype): + assert m % 64 == 0 + assert n % 8 == 0 + def c(i): + return arith.constant(index, ir.IntegerAttr.get(index, i)) + index = ir.IndexType.get() + i32 = ir.IntegerType.get_signless(32) + warp_id = arith.divui(gpu.thread_id(gpu.Dimension.x), c(32)) + within_warp_id = arith.remui(gpu.thread_id(gpu.Dimension.x), c(32)) + warp_row_start = arith.muli(warp_id, c(16)) + within_warp_row = arith.divui(within_warp_id, c(4)) + start_row = arith.addi(warp_row_start, within_warp_row) + start_col = arith.muli(arith.remui(within_warp_id, c(4)), c(2)) + registers = np.empty((m // 64, n // 8, 2, 1), dtype=object) + for row_tile, col_tile, row_subtile, _ in np.ndindex(registers.shape): + row = arith.addi(start_row, c(row_tile * 64 + row_subtile * 8)) + col = arith.addi(start_col, c(col_tile * 8)) + row_value_base = arith.muli(row, c(n)) + vec = llvm.mlir_undef(ir.VectorType.get((2,), i32)) + for col_offset in range(2): + value = arith.addi(row_value_base, arith.addi(c(col_offset), col)) + value = arith.index_cast(i32, value) + vec = vector.insertelement(value, vec, position=c(col_offset)) + registers[row_tile, col_tile, row_subtile, 0] = vec + t = mgpu.FragmentedArray(_registers=registers, _layout=mgpu.WGMMA_LAYOUT) + return t.astype(mlir_dtype) + + +class TestCase(parameterized.TestCase): + + def setUp(self): + if not HAS_MOSAIC_GPU: + self.skipTest("jaxlib built without Mosaic GPU") + super().setUp() + self.prng = np.random.default_rng(1234) + self.ctx = mlir.make_ir_context() + self.ctx.__enter__() + self.loc = ir.Location.unknown() + self.loc.__enter__() + + def tearDown(self): + self.loc.__exit__(None, None, None) + self.ctx.__exit__(None, None, None) + del self.loc, self.ctx + super().tearDown() + + +class TestUtilTest(TestCase): + + def test_copy(self): + def kernel(ctx, src, dst, _): + copy(src, dst) + x = jnp.arange(2 * 3 * 5).reshape(2, 5, 3) + y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, ())(x) + np.testing.assert_array_equal(y, x) + + def test_copy_swizzle(self): + def kernel(ctx, src, dst, _): + copy(src, dst, swizzle=128) + x = jnp.arange(8 * 32, dtype=jnp.float32).reshape(8, 32) + y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, ())(x) + expected = np.zeros_like(y) + for i in range(8): + for j in range(8): + js = j ^ i + expected[i, 
(j * 4):(j * 4) + 4] = x[i, (js * 4):(js * 4) + 4] + np.testing.assert_array_equal(y, expected) + + def test_copy_swizzle_noop(self): + # Two swizzles cancel out + def kernel(ctx, src, dst, smem): + copy(src, smem, swizzle=128) + copy(smem, dst, swizzle=128) + x = jnp.arange(8 * 32, dtype=jnp.float32).reshape(8, 32) + y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) + np.testing.assert_array_equal(y, x) + + def test_iota_tensor(self): + m = n = 64 + def kernel(ctx, dst, _): + f32 = ir.F32Type.get() + index = ir.IndexType.get() + registers = iota_tensor(m, n, f32).registers + assert registers.size == 16, registers.size + for i, vec_reg in enumerate(registers.flat): + for j in range(2): + reg = vector.extractelement(vec_reg, position=c(j, index)) + memref.store( + reg, dst, [gpu.thread_id(gpu.Dimension.x), c(2 * i + j, index)] + ) + out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32) + regs = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + )() + thread_ids = np.arange(128) + warp_ids = thread_ids // 32 + lane_ids = thread_ids % 32 + thread_rows = warp_ids * 16 + lane_ids // 4 + thread_start_cols = (lane_ids % 4) * 2 + thread_cols = thread_start_cols[:, None] + (np.arange(n // 8)[None] * 8) + regs = regs.reshape(128, 8, 2, 2) + for row_half in range(2): + for col_half in range(2): + np.testing.assert_array_equal( + regs[..., row_half, col_half], + (thread_rows[:, None] + row_half * 8) * n + thread_cols + col_half + ) + + +class MemRefTest(TestCase): + @parameterized.product( + dim=tuple(range(3)), + strided=(False, True) + ) + def test_unsqueeze(self, dim, strided): + def kernel(ctx, inp, out, _): + if strided: + for i in range(8): + s = ds(i, 1) + out_slice = s if dim != 0 else (slice(None), s) + copy( + memref_unsqueeze(memref_slice(inp, s), dim), + memref_slice(out, out_slice), + ) + else: + copy(memref_unsqueeze(inp, dim), out) + x = np.arange(8 * 16, dtype=jnp.float32).reshape(8, 16) + out_shape = list(x.shape) + out_shape.insert(dim, 1) + out_ty = jax.ShapeDtypeStruct(out_shape, jnp.float32) + y = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, out_ty, () + )(x) + np.testing.assert_array_equal(y, x.reshape(out_shape)) + + @parameterized.product( + dim=tuple(range(2)), + strided=(False, True) + ) + def test_unfold(self, dim, strided): + in_shape = (8, 16) + def kernel(ctx, inp, out, _): + if strided: + # We slice the dim we don't unfold + for i in range(in_shape[1 - dim] // 4): + s = ds(i * 4, 4) + in_slice = s if dim == 1 else (slice(None), s) + out_slice = s if dim == 1 else (slice(None),) * 3 + (s,) + copy( + memref_unfold(memref_slice(inp, in_slice), dim, (2, 2, None)), + memref_slice(out, out_slice), + ) + else: + copy(memref_unfold(inp, dim, (2, 2, None)), out) + x = np.arange(np.prod(in_shape), dtype=jnp.float32).reshape(in_shape) + out_shape = list(in_shape) + out_shape[dim:dim + 1] = [2, 2, out_shape[dim] // 4] + out_ty = jax.ShapeDtypeStruct(out_shape, jnp.float32) + y = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, out_ty, () + )(x) + np.testing.assert_array_equal(y, x.reshape(out_ty.shape)) + + @parameterized.product( + dim=tuple(range(2)), + ) + def test_fold_not_strided(self, dim): + def kernel(ctx, inp, out, _): + copy(memref_fold(inp, dim, 2), out) + + x = np.arange(8 * 2 * 8, dtype=jnp.float32).reshape(8, 2, 8) + out_ty = jax.ShapeDtypeStruct((16, 8) if dim == 0 else (8, 16), jnp.float32) + y = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, out_ty, () + )(x) + 
np.testing.assert_array_equal(y, x.reshape(out_ty.shape)) + + @parameterized.named_parameters([ + ("packed", (4, 4, 4), (16, 4, 1), 1, 2, False), + ("strided_end", (4, 4, 4, 4), (256, 64, 16, 4), 1, 2, False), + ("strided_bot", (4, 4, 4, 4), (256, 16, 4, 1), 1, 2, False), + ("strided_top", (4, 4, 4, 4), (256, 64, 4, 1), 1, 2, True), + ("strided_mid", (4, 4, 4, 4), (265, 64, 16, 1), 1, 3, True), + ("overap", (2, 4, 4), (16, 1, 1), 0, 3, True), + ]) + def test_fold_strided( + self, shape, strides, dim, fold_rank, throws_not_impl + ): + expanded_shape = get_packed_shape(strides, shape) + total_size = np.prod(expanded_shape) + np_inp = np.arange(total_size, dtype=jnp.float32).reshape(expanded_shape) + index = tuple([slice(0, s) for s in shape]) + + # Reference implementation + def np_fold(inp, dim, fold_rank): + out_shape = list(inp.shape) + out_shape[dim : dim + fold_rank] = [ + int(np.prod(inp.shape[dim : dim + fold_rank])) + ] + if throws_not_impl: + return jax.ShapeDtypeStruct(shape=out_shape, dtype=inp.dtype) + else: + return inp.reshape(*out_shape) + + total_size = np.prod(shape) * np.prod(strides) + + def do_test(): + def kernel(ctx, inp, out, _): + copy(memref_fold(memref_slice(inp, index), dim, fold_rank), out) + + out = np_fold(np_inp[index], dim, fold_rank) + y = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), np_inp, out, () + )(np_inp) + assert ( + not throws_not_impl + ), "If it should have thrown it would during the call." + np.testing.assert_array_equal(y, out) + + if throws_not_impl: + with self.assertRaises(NotImplementedError): + do_test() + else: + do_test() + + +def get_packed_shape(strides, shape): + perm = sorted(range(len(strides)), key=lambda i: strides[i], reverse=True) + ordered_strides = [strides[i] for i in perm] + ordered_shape = [shape[i] for i in perm] + packed_shape = [ordered_shape[-1]] + packed_shape += [ + stride0 // stride + for stride0, stride in zip(ordered_strides, ordered_strides[1:]) + ] + # Invert permutation + inv_perm = [None] * len(perm) + for i, p in enumerate(perm): + inv_perm[p] = i + return [packed_shape[i] for i in inv_perm] + + +class WGMMATest(TestCase): + + @parameterized.named_parameters( + ("f32", ir.F32Type, jnp.float32), ("f16", ir.F16Type, jnp.float16) + ) + def test_store_untiled(self, mlir_dtype_cls, jax_dtype): + mlir_dtype = mlir_dtype_cls.get() + def kernel(ctx, out, _): + del ctx + iota_tensor(64, 64, mlir_dtype).store_untiled(out) + expected = np.arange(64 * 64, dtype=jax_dtype).reshape(64, 64) + iota = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), expected, () + )() + np.testing.assert_array_equal(iota, expected) + + @parameterized.named_parameters( + ("f32", ir.F32Type, jnp.float32), + ("f16", ir.F16Type, jnp.float16), + ) + def test_store_tiled(self, mlir_dtype_cls, jax_dtype): + mlir_dtype = mlir_dtype_cls.get() + m = 128 + n = 256 + tiling = (64, 128 // bytewidth(mlir_dtype)) + def kernel(ctx, out, smem): + del ctx + iota_tensor(m, n, mlir_dtype).store_tiled(smem, swizzle=128) + copy(smem, out, swizzle=128) + expected = ( + np.arange(m * n, dtype=jax_dtype) + .reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1]) + .transpose(0, 2, 1, 3) + ) + iota = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), expected, expected + )() + np.testing.assert_array_equal(iota, expected) + + @parameterized.product( + lhs_transpose=(False, True), + rhs_transpose=(False, True), + mlir_dtype_cls=(ir.F16Type, ir.BF16Type, ir.F32Type), + m=(64, 128, 192), + n=(32, 64, 128, 192), + k_steps=(1, 
2), + tma_inputs=(False, True), + ) + def test_wgmma( + self, + m, + n, + k_steps, + mlir_dtype_cls, + lhs_transpose, + rhs_transpose, + tma_inputs, + ): + mlir_dtype = mlir_dtype_cls.get() + if ir.F32Type.isinstance(mlir_dtype): # We actually use tf32 instead + jax_dtype = jnp.float32 + if lhs_transpose or not rhs_transpose: + self.skipTest("Transpose only supported in 16-bit WGMMA") + exponent_bits, mantissa_bits = 8, 10 # Use tf32 + elif bytewidth(mlir_dtype) == 2: + if n % 64 != 0: + self.skipTest("16-bit WGMMA only supports n % 64 == 0") + if ir.F16Type.isinstance(mlir_dtype): + jax_dtype = jnp.float16 + exponent_bits, mantissa_bits = 5, 10 + elif ir.BF16Type.isinstance(mlir_dtype): + jax_dtype = jnp.bfloat16 + exponent_bits, mantissa_bits = 8, 7 + else: + raise NotImplementedError(mlir_dtype) + else: + raise NotImplementedError(mlir_dtype) + nk_tile = 128 // bytewidth(mlir_dtype) + k = nk_tile * k_steps + assert m % 64 == 0 and n % nk_tile == 0 + index = ir.IndexType.get() + + row_major = mgpu.WGMMALayout.ROW_MAJOR + col_major = mgpu.WGMMALayout.COL_MAJOR + lhs_order = col_major if lhs_transpose else row_major + rhs_order = col_major if rhs_transpose else row_major + + def kernel(ctx, lhs, rhs, out, scratch): + lhs_smem, rhs_smem = scratch + if tma_inputs: + lhs_transform = (mosaic_gpu.TileTransform((64, nk_tile)),) + if lhs_transpose: + assert nk_tile == 64 # Make sure we didn't have to transpose tiling. + lhs_transform += (mosaic_gpu.TransposeTransform((1, 0, 2, 3)),) + rhs_transform = (mosaic_gpu.TileTransform((nk_tile, nk_tile)),) + if rhs_transpose: + rhs_transform += (mosaic_gpu.TransposeTransform((1, 0, 2, 3)),) + barriers = BarrierArray(2) + ctx.async_copy( + src_ref=lhs, + dst_ref=lhs_smem, + swizzle=128, + gmem_transform=lhs_transform, + barrier=barriers[0], + ) + ctx.async_copy( + src_ref=rhs, + dst_ref=rhs_smem, + swizzle=128, + gmem_transform=rhs_transform, + barrier=barriers[1], + ) + for i in range(2): + barriers[i].wait() + else: + for mi in range(m // 64): + for ki in range(k // nk_tile): + lhs_slice = ( + ds(c(mi * 64, index), 64), + ds(c(ki * nk_tile, index), nk_tile), + ) + if lhs_transpose: + lhs_slice = lhs_slice[::-1] + copy( + src=memref_slice(lhs, lhs_slice), + dst=memref_slice(lhs_smem, (mi, ki)), + swizzle=128, + ) + for ki in range(k // nk_tile): + k_slice = ds(c(ki * nk_tile, index), nk_tile) + for ni in range(n // nk_tile): + rhs_slice = (k_slice, ds(c(ni * nk_tile, index), nk_tile)) + if rhs_transpose: + rhs_slice = rhs_slice[::-1] + copy( + src=memref_slice(rhs, rhs_slice), + dst=memref_slice(rhs_smem, (ki, ni)), + swizzle=128, + ) + init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n) + acc = mgpu.wgmma( + init_acc, lhs_smem, rhs_smem, + a_order=lhs_order, b_order=rhs_order, + ) + nvvm.wgmma_commit_group_sync_aligned() + nvvm.wgmma_wait_group_sync_aligned(0) + acc.value.store_untiled(out) + + def quantize(x): + # Quantize the input to avoid rounding when feeding the WGMMA + return jax.lax.reduce_precision(x, exponent_bits, mantissa_bits) + + x_shape = (k, m) if lhs_transpose else (m, k) + x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(jax_dtype) + y_shape = (n, k) if rhs_transpose else (k, n) + y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(jax_dtype) + out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) + scratch_shape = [ + jax.ShapeDtypeStruct((m // 64, k // nk_tile, 64, nk_tile), jax_dtype), + jax.ShapeDtypeStruct( + (k // nk_tile, n // nk_tile, nk_tile, nk_tile), jax_dtype + ), + ] + z = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), 
(128, 1, 1), (x, y), out_shape, scratch_shape + )(x, y) + x32, y32 = x.astype(np.float32), y.astype(np.float32) + ref = (x32.T if lhs_transpose else x32) @ (y32.T if rhs_transpose else y32) + np.testing.assert_allclose(z, ref, atol=5e-6) + + # TODO(apaszke): Add support for f32 + @parameterized.product( + m=(64, 128, 192), + n=(64, 128, 192), + k_steps=(1, 2), + rhs_transpose=(False, True), + mlir_dtype_cls=(ir.F16Type, ir.BF16Type), + ) + def test_wgmma_reg_lhs(self, m, n, k_steps, rhs_transpose, mlir_dtype_cls): + k = 64 * k_steps + index = ir.IndexType.get() + + row_major = mgpu.WGMMALayout.ROW_MAJOR + col_major = mgpu.WGMMALayout.COL_MAJOR + rhs_order = col_major if rhs_transpose else row_major + + def kernel(ctx, rhs, out, rhs_smem): + del ctx + for ki in range(k_steps): + for ni in range(n // 64): + rhs_slice = (ds(c(ki * 64, index), 64), ds(c(ni * 64, index), 64)) + if rhs_transpose: + rhs_slice = rhs_slice[::-1] + copy( + src=memref_slice(rhs, rhs_slice), + dst=memref_slice(rhs_smem, (ki, ni)), + swizzle=128, + ) + init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n) + lhs_regs = iota_tensor(m, k, mlir_dtype_cls.get()) + acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, b_order=rhs_order) + nvvm.wgmma_commit_group_sync_aligned() + nvvm.wgmma_wait_group_sync_aligned(0) + acc.value.store_untiled(out) + + jax_dtype = jnp.float16 if mlir_dtype_cls == ir.F16Type else jnp.bfloat16 + y_shape = (n, k) if rhs_transpose else (k, n) + y = self.prng.uniform(-1, 1, y_shape).astype(jax_dtype) + out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) + scratch_shape = jax.ShapeDtypeStruct( + (k_steps, n // 64, 64, 64), jax_dtype + ) + z = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), y, out_shape, scratch_shape + )(y) + x = np.arange(m * k, dtype=jax_dtype).reshape(m, k) + ref = jax.lax.dot( + x, (y.T if rhs_transpose else y), preferred_element_type=jnp.float32 + ) + rtol = 0 if k_steps == 1 else 2.2e-4 + np.testing.assert_allclose(z, ref, rtol=rtol, atol=0) + + +class TMATest(TestCase): + + @parameterized.product( + swizzle=(None, 128), + shape=((64, 64), (5, 64), (2, 3, 5, 64)), + dtype=(jnp.float16, jnp.float32), + ) + def test_tma_load(self, swizzle, shape, dtype): + if dtype == jnp.float32: + shape = (*shape[:-1], shape[-1] // 2) + i1 = ir.IntegerType.get_signless(1) + def kernel(ctx, src, dst, tmp): + barrier = BarrierArray(1)[0] + ctx.async_copy(src_ref=src, dst_ref=tmp, swizzle=swizzle, barrier=barrier) + barrier.wait_parity(c(0, i1)) + copy(tmp, dst, swizzle=swizzle) + x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) + y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) + np.testing.assert_array_equal(y, x) + + @parameterized.product( + swizzle=(None, 128), + shape=((128, 128), (5, 32, 128)), + dtype=(jnp.float16, jnp.float32), + ) + def test_tma_load_tiled(self, swizzle, shape, dtype): + i1 = ir.IntegerType.get_signless(1) + index = ir.IndexType.get() + tiling = (32, 128 // jnp.dtype(dtype).itemsize) + tiled_shape = tile_shape(shape, tiling)[:len(shape)] + def kernel(ctx, src, dst, tmp): + barrier = BarrierArray(1)[0] + ctx.async_copy( + src_ref=src, + dst_ref=tmp, + swizzle=swizzle, + barrier=barrier, + gmem_transform=mosaic_gpu.TileTransform(tiling), + ) + barrier.wait_parity(c(0, i1)) + for idxs in np.ndindex(tiled_shape): + untiled_idxs, tiled_idxs = idxs[:-len(tiling)], idxs[-len(tiling):] + s = ( + *untiled_idxs, + *(ds(c(ix * t, index), t) for ix, t in zip(tiled_idxs, tiling)), + ) + copy(memref_slice(tmp, idxs), memref_slice(dst, s), 
swizzle=swizzle) + x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) + smem = jax.ShapeDtypeStruct(tile_shape(shape, tiling), dtype) + f = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem) + y = f(x) + np.testing.assert_array_equal(y, x) + + @parameterized.product( + swizzle=(None, 128), + dtype=(jnp.float16, jnp.float32), + ) + def test_tma_squeeze_indexing(self, swizzle, dtype): + shape = (4, 5, 64) + if dtype == jnp.float32: + shape = (*shape[:-1], shape[-1] // 2) + def kernel(ctx, src, dst, tmp): + barrier = BarrierArray(1)[0] + for i in range(4): + ctx.async_copy( + src_ref=src, + dst_ref=memref_slice(tmp, i), + gmem_slice=i, + swizzle=swizzle, + barrier=barrier, + ) + barrier.wait() + copy(tmp, dst, swizzle=swizzle) + x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) + y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) + np.testing.assert_array_equal(y, x) + + def test_parity_tracking(self): + shape = (16, 64) + index = ir.IndexType.get() + def kernel(ctx, src, dst, tmp): + barrier = BarrierArray(1)[0] + for i in range(shape[0]): + s = ds(c(i, index), 1) + ctx.async_copy( + src_ref=src, dst_ref=tmp, gmem_slice=s, barrier=barrier, + ) + barrier.wait() + copy(tmp, memref_slice(dst, s)) + x = np.arange(np.prod(shape), dtype=jnp.float16).reshape(shape) + y = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, x[0:1] + )(x) + np.testing.assert_array_equal(y, x) + + @parameterized.product( + swizzle=(None, 128), + shape=((64, 64), (5, 64), (2, 3, 5, 64)), + dtype=(jnp.float16, jnp.float32), + ) + def test_tma_store(self, swizzle, shape, dtype): + if dtype == jnp.float32: + shape = (*shape[:-1], shape[-1] // 2) + def kernel(ctx, src, dst, tmp): + copy(src, tmp, swizzle=swizzle) + ctx.async_copy(src_ref=tmp, dst_ref=dst, swizzle=swizzle) + ctx.await_async_copy(0) + x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) + y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) + np.testing.assert_array_equal(y, x) + + +class FragmentedArrayTest(TestCase): + + @parameterized.product( + op=( + operator.add, + operator.mul, + operator.sub, + operator.truediv, + (lambda x, y: mgpu.FragmentedArray.max(x, y), np.maximum), + ), + m=(64, 128), + n=(8, 16, 32, 64, 80, 128, 256), + ) + def test_binary(self, op, m=64, n=32): + if isinstance(op, tuple): + op, np_op = op + else: + np_op = op + def kernel(ctx, dst, _): + f32 = ir.F32Type.get() + iota = iota_tensor(m=m, n=n, mlir_dtype=f32) + op(iota, iota).store_untiled(dst) + out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) + result = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + )() + x = np.arange(m * n, dtype=jnp.float32).reshape(m, n) + if op == operator.truediv: + np.testing.assert_allclose(result, np_op(x, x), atol=2e-7) + else: + np.testing.assert_array_equal(result, np_op(x, x)) + + @parameterized.product( + ops=((lambda x: mgpu.FragmentedArray.exp(x), np.exp),), + m=(64, 128), + n=(8, 16, 32, 64, 80, 128, 256), + ) + def test_unary(self, ops, m=64, n=32): + op, np_op = ops + def kernel(ctx, dst, _): + f32 = ir.F32Type.get() + iota = iota_tensor(m=m, n=n, mlir_dtype=f32) + op(iota).store_untiled(dst) + out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) + result = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + )() + x = np.arange(m * n, dtype=jnp.float32).reshape(m, n) + np.testing.assert_allclose(result, np_op(x), atol=2e-7, rtol=2e-7) + + @parameterized.product( + op=(arith.addf, 
arith.maximumf), + m=(64, 128), + n=(8, 16, 32, 64, 80, 128, 256), + ) + def test_reduce(self, op, m=64, n=32): + def kernel(ctx, dst, _): + f32 = ir.F32Type.get() + iota = iota_tensor(m=m, n=n, mlir_dtype=f32) + iota.reduce(op, axis=1).broadcast_minor(n).store_untiled(dst) + out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) + result = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + )() + x = np.arange(m * n, dtype=jnp.float32).reshape(m, n) + if op == arith.addf: + expected = np.broadcast_to(x.sum(axis=1, keepdims=True), x.shape) + elif op == arith.maximumf: + expected = np.broadcast_to(x.max(axis=1, keepdims=True), x.shape) + else: + raise NotImplementedError(f"Unsupported op: {op}") + np.testing.assert_array_equal(result, expected) + + def test_splat(self): + def kernel(ctx, dst, _): + f32 = ir.F32Type.get() + v = arith.constant(f32, ir.FloatAttr.get(f32, 3.14)) + t = mgpu.FragmentedArray.splat(v, (128,), mgpu.WGMMA_ROW_LAYOUT) + t.broadcast_minor(32).store_untiled(dst) + out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32) + result = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + )() + np.testing.assert_array_equal(result, np.full((128, 32), 3.14, np.float32)) + + @parameterized.product(in_shape=((128, 128), (128, 64), (64, 128))) + def test_strided_load_store(self, in_shape): + def kernel(ctx, *args): + gmem_input, gmem_output, (smem_input, smem_output) = args + copy(gmem_input, smem_input) + t = mgpu.FragmentedArray.load_strided(smem_input) + t.store_untiled(smem_output) + copy(smem_output, gmem_output) + + inp = out = self.prng.uniform(-1, 1, in_shape).astype(jnp.float32) + result = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (inp,), out, [inp, out], + )(inp) + np.testing.assert_array_equal(inp, result) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader())
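For reference, the kernel-construction pattern these tests exercise, reduced to a minimal sketch. It assumes the module-level imports and helpers of this test file (mosaic_gpu, jnp, np, and the copy helper defined above this hunk) and mirrors test_copy_swizzle_noop with the swizzle left out; it is a sketch of the calling convention under those assumptions, not a definitive API reference.

def kernel(ctx, src, dst, smem):
  # Kernel signature: (ctx, *inputs, *outputs, smem_scratch).
  copy(src, smem)  # GMEM -> SMEM staging copy
  copy(smem, dst)  # SMEM -> GMEM write-back

x = jnp.arange(8 * 32, dtype=jnp.float32).reshape(8, 32)
# as_gpu_kernel(body, grid, block, in_spec, out_spec, smem_spec) returns a host callable.
f = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)
np.testing.assert_array_equal(f(x), x)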