Mirror of https://github.com/ROCm/jax.git (synced 2025-04-15 19:36:06 +00:00)

Initial commit for Mosaic GPU

Moving this to JAX to make it easier to explore Pallas integration.

PiperOrigin-RevId: 625982382

parent c4dea624cc
commit 8e3f5b1018
.bazelrc (4 lines changed)

@@ -111,6 +111,10 @@ build:cuda_clang --copt=-Wno-gnu-offsetof-extensions
# Disable clang extension that rejects unknown arguments.
build:cuda_clang --copt=-Qunused-arguments

build:mosaic_gpu --@llvm-project//mlir:enable_cuda=true
build:mosaic_gpu --copt=-DLLVM_HAS_NVPTX_TARGET=1
build:mosaic_gpu --//jax:build_mosaic_gpu=true

build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
build:rocm --@xla//xla/python:enable_gpu=true
.github/workflows/ci-build.yaml (vendored, 2 lines changed)

@@ -140,7 +140,7 @@ jobs:
        PY_COLORS: 1
      run: |
        pytest -n auto --tb=short docs
        pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas
        pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas

  documentation_render:
build/build.py

@@ -263,7 +263,7 @@ def write_bazelrc(*, python_bin_path, remote_build,
                   rocm_amdgpu_targets, bazel_options, target_cpu_features,
                   wheel_cpu, enable_mkl_dnn, use_clang, clang_path,
                   clang_major_version, enable_cuda, enable_nccl, enable_rocm,
                   build_gpu_plugin):
                   build_gpu_plugin, enable_mosaic_gpu):
  tf_cuda_paths = []

  with open("../.jax_configure.bazelrc", "w") as f:
@@ -337,6 +337,8 @@ def write_bazelrc(*, python_bin_path, remote_build,
    if use_clang:
      f.write("build --config=nvcc_clang\n")
      f.write(f"build --action_env=CLANG_CUDA_COMPILER_PATH={clang_path}\n")
  if enable_mosaic_gpu:
    f.write("build --config=mosaic_gpu\n")
  if enable_rocm:
    f.write("build --config=rocm\n")
    if not enable_nccl:
@@ -541,6 +543,10 @@ def main():
      "--editable",
      action="store_true",
      help="Create an 'editable' jaxlib build instead of a wheel.")
  add_boolean_argument(
      parser,
      "enable_mosaic_gpu",
      help_str="Should we build with Mosaic GPU? VERY EXPERIMENTAL.")
  add_boolean_argument(
      parser,
      "configure_only",
@@ -652,6 +658,7 @@ def main():
      enable_nccl=args.enable_nccl,
      enable_rocm=args.enable_rocm,
      build_gpu_plugin=args.build_gpu_plugin,
      enable_mosaic_gpu=args.enable_mosaic_gpu,
  )

  if args.configure_only:
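For context, a hedged note (not part of the diff): `add_boolean_argument` exposes the new option on the build CLI, and `write_bazelrc` then emits the matching config line, so the expected flow is roughly:

# Hypothetical invocation; the flag spelling is assumed from add_boolean_argument above.
#   python build/build.py --enable_mosaic_gpu
# which appends to ../.jax_configure.bazelrc:
#   build --config=mosaic_gpu
# picking up the build:mosaic_gpu settings added to .bazelrc earlier in this commit.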
jax/BUILD (52 lines changed)

@@ -70,6 +70,19 @@ config_setting(
    },
)

# If this flag is true, jaxlib will be built with Mosaic GPU. VERY EXPERIMENTAL.
bool_flag(
    name = "build_mosaic_gpu",
    build_setting_default = False,
)

config_setting(
    name = "enable_mosaic_gpu",
    flag_values = {
        ":build_mosaic_gpu": "True",
    },
)

exports_files([
    "LICENSE",
    "version.py",
@@ -116,6 +129,14 @@ package_group(
    ] + pallas_tpu_internal_users,
)

package_group(
    name = "mosaic_gpu_users",
    packages = [
        "//...",
        "//learning/brain/research/jax",
    ],
)

# JAX-private test utilities.
py_library(
    # This build target is required in order to use private test utilities in jax._src.test_util,
@@ -647,6 +668,37 @@ pytype_strict_library(
    ],
)

# This target only supports sm_90 GPUs.
py_library(
    name = "mosaic_gpu",
    srcs = glob(["experimental/mosaic/gpu/*.py"]),
    visibility = [
        ":mosaic_gpu_users",
    ],
    deps = [
        ":config",
        ":jax",
        ":mlir",
        "//jax/_src/lib",
        "//third_party/py/absl/flags",
        "//jaxlib/mlir:arithmetic_dialect",
        "//jaxlib/mlir:builtin_dialect",
        "//jaxlib/mlir:execution_engine",
        "//jaxlib/mlir:func_dialect",
        "//jaxlib/mlir:gpu_dialect",
        "//jaxlib/mlir:ir",
        "//jaxlib/mlir:llvm_dialect",
        "//jaxlib/mlir:math_dialect",
        "//jaxlib/mlir:memref_dialect",
        "//jaxlib/mlir:nvgpu_dialect",
        "//jaxlib/mlir:nvvm_dialect",
        "//jaxlib/mlir:pass_manager",
        "//jaxlib/mlir:scf_dialect",
        "//jaxlib/mlir:vector_dialect",
        "//third_party/py/numpy",
    ],
)

pytype_strict_library(
    name = "partial_eval",
    srcs = ["_src/interpreters/partial_eval.py"],
jax/_src/lib/BUILD

@@ -15,6 +15,7 @@
load(
    "//jaxlib:jax.bzl",
    "if_building_jaxlib",
    "if_building_mosaic_gpu",
    "jax_visibility",
    "py_library_providing_imports_info",
    "pytype_strict_library",
@@ -31,6 +32,7 @@ py_library_providing_imports_info(
        "__init__.py",
        "mlir/__init__.py",
        "mlir/dialects/__init__.py",
        "mosaic_gpu.py",
        "triton.py",
    ],
    lib_rule = pytype_strict_library,
@@ -58,5 +60,5 @@ py_library_providing_imports_info(
        "//jaxlib/mlir:stablehlo_dialect",
        "//jaxlib/mlir:vector_dialect",
        # xla_client
    ]),
    ]) + if_building_mosaic_gpu(["//jaxlib/mosaic/gpu:mosaic_gpu"]),
)
@@ -23,6 +23,13 @@ import jaxlib.mlir.dialects.func as func
import jaxlib.mlir.dialects.scf as scf
import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor
import jaxlib.mlir.dialects.vector as vector
try:
  import jaxlib.mlir.dialects.gpu as gpu  # type: ignore
  import jaxlib.mlir.dialects.nvgpu as nvgpu  # type: ignore
  import jaxlib.mlir.dialects.nvvm as nvvm  # type: ignore
  import jaxlib.mlir.dialects.llvm as llvm  # type: ignore
except ImportError:
  pass

from jax._src import lib
jax/_src/lib/mosaic_gpu.py (new file, 23 lines)

@@ -0,0 +1,23 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ruff: noqa

try:
  from jaxlib.mosaic.gpu import _mosaic_gpu_ext  # pytype: disable=import-error
except ImportError as e:
  raise ModuleNotFoundError(
      "Cannot import the Mosaic GPU bindings. You may need to build jaxlib from"
      " source."
  ) from e
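A minimal sketch of how a caller can probe for these bindings without hard-failing (hypothetical helper, not part of this commit; it only relies on the ModuleNotFoundError raised above):

# Hypothetical feature check built on jax/_src/lib/mosaic_gpu.py as added above.
try:
    from jax._src.lib import mosaic_gpu as mosaic_gpu_lib
    HAS_MOSAIC_GPU = True
except ModuleNotFoundError:
    HAS_MOSAIC_GPU = False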
jax/experimental/mosaic/gpu/__init__.py (new file, 595 lines)

@@ -0,0 +1,595 @@
|
||||
# Copyright 2024 The JAX Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import contextlib
|
||||
import ctypes
|
||||
import dataclasses
|
||||
import os
|
||||
import pathlib
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Any, Sequence
|
||||
|
||||
import jax
|
||||
from jax._src import config
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import mosaic_gpu as mosaic_gpu_lib
|
||||
from jaxlib.mlir import ir
|
||||
from jaxlib.mlir.dialects import arith
|
||||
from jaxlib.mlir.dialects import builtin
|
||||
from jaxlib.mlir.dialects import func
|
||||
from jaxlib.mlir.dialects import gpu
|
||||
from jaxlib.mlir.dialects import llvm
|
||||
from jaxlib.mlir.dialects import memref
|
||||
from jaxlib.mlir.dialects import nvgpu
|
||||
from jaxlib.mlir.dialects import nvvm
|
||||
from jaxlib.mlir.execution_engine import ExecutionEngine
|
||||
from jaxlib.mlir.passmanager import PassManager
|
||||
import numpy as np
|
||||
|
||||
from . import dsl as mgpu
|
||||
from . import profiler
|
||||
from . import utils
|
||||
|
||||
# mypy: ignore-errors
|
||||
|
||||
# MLIR can't find libdevice unless we point it to the CUDA path
|
||||
# TODO(apaszke): Unify with jax._src.lib.cuda_path
|
||||
CUDA_ROOT = "/usr/local/cuda"
|
||||
if os.environ.get("CUDA_ROOT") is None:
|
||||
os.environ["CUDA_ROOT"] = CUDA_ROOT
|
||||
else:
|
||||
CUDA_ROOT = os.environ["CUDA_ROOT"]
|
||||
|
||||
PTXAS_PATH = os.path.join(CUDA_ROOT, "bin/ptxas")
|
||||
NVDISASM_PATH = os.path.join(CUDA_ROOT, "bin/nvdisasm")
|
||||
|
||||
|
||||
c = mgpu.c # This is too common to fully qualify.
|
||||
|
||||
|
||||
xla_client.register_custom_call_target(
|
||||
"mosaic_gpu",
|
||||
mosaic_gpu_lib._mosaic_gpu_ext._custom_call_capsule(),
|
||||
platform="CUDA",
|
||||
)
|
||||
|
||||
|
||||
mosaic_gpu_dump_ptx = config.define_bool_state(
|
||||
name="mosaic_gpu_dump_ptx",
|
||||
default=config.bool_env("MOSAIC_GPU_DUMP_PTX", False),
|
||||
help="If set, prints the kernel PTX",
|
||||
)
|
||||
mosaic_gpu_dump_ptxas = config.define_bool_state(
|
||||
name="mosaic_gpu_dump_ptxas",
|
||||
default=config.bool_env("MOSAIC_GPU_DUMP_PTXAS", False),
|
||||
help="If set, prints the ptxas verbose output",
|
||||
)
|
||||
mosaic_gpu_dump_sass = config.define_bool_state(
|
||||
name="mosaic_gpu_dump_sass",
|
||||
default=config.bool_env("MOSAIC_GPU_DUMP_SASS", False),
|
||||
help="If set, prints the kernel SASS",
|
||||
)
|
||||
mosaic_gpu_print_after_all = config.define_bool_state(
|
||||
name='mosaic_gpu_print_after_all',
|
||||
default=config.bool_env('MOSAIC_GPU_PRINT_AFTER_ALL', False),
|
||||
help="If set, prints the kernel module after every pass",
|
||||
)
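# Editorial note (not in the original file): each dump flag above can be flipped
# through its environment variable before JAX is imported, e.g.
#   MOSAIC_GPU_DUMP_PTX=1 python my_kernel.py
# or, assuming the standard behavior of config.define_bool_state, at runtime via
#   jax.config.update("mosaic_gpu_dump_ptx", True)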
|
||||
|
||||
|
||||
mosaic_gpu_p = jax.core.Primitive("mosaic_gpu_p")
|
||||
mosaic_gpu_p.multiple_results = True
|
||||
|
||||
|
||||
@mosaic_gpu_p.def_abstract_eval
|
||||
def _mosaic_gpu_abstract_eval(*_, module, out_types):
|
||||
return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types]
|
||||
|
||||
|
||||
def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types):
|
||||
runtime_path = (
|
||||
pathlib.Path(mosaic_gpu_lib._mosaic_gpu_ext.__file__).parent
|
||||
/ "libmlir_cuda_runtime.so"
|
||||
)
|
||||
shared_libs = [str(runtime_path)] if runtime_path.exists() else []
|
||||
engine = ExecutionEngine(
|
||||
module, opt_level=3, shared_libs=shared_libs, enable_object_dump=False
|
||||
)
|
||||
ctx.module_context.add_keepalive(engine)
|
||||
func_ptr = engine.lookup("main")
|
||||
ptr_bytes = ctypes.cast(func_ptr, ctypes.c_void_p).value.to_bytes(
|
||||
8, byteorder="little"
|
||||
) # pytype: disable=attribute-error
|
||||
op = mlir.custom_call(
|
||||
"mosaic_gpu",
|
||||
result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
|
||||
operands=args,
|
||||
backend_config=ptr_bytes,
|
||||
)
|
||||
return op.results
|
||||
|
||||
mlir.register_lowering(mosaic_gpu_p, _mosaic_gpu_lowering_rule, "cuda")
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class MemRefTransform:
|
||||
def apply(self, ref: ir.Value) -> ir.Value:
|
||||
raise NotImplementedError("Subclasses should override this method")
|
||||
|
||||
def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]:
|
||||
raise NotImplementedError("Subclasses should override this method")
|
||||
|
||||
def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]:
|
||||
raise NotImplementedError("Subclasses should override this method")
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TileTransform(MemRefTransform):
|
||||
"""Tiles a suffix of memref dimensions.
|
||||
|
||||
For example, given a memref of shape (5, 128, 128) and a tiling of (64, 32),
|
||||
the shape of the result will be (5, 2, 4, 64, 32). The shape always ends with
|
||||
the tile shape, and the size of tiled dimensions is divided by the tile size.
|
||||
This is especially useful for swizzled WGMMA, which expects tiled layouts in
|
||||
shared memory.
|
||||
"""
|
||||
tiling: tuple[int, ...]
|
||||
|
||||
def apply(self, ref: ir.Value) -> ir.Value:
|
||||
untiled_rank = ir.MemRefType(ref.type).rank
|
||||
tiling_rank = len(self.tiling)
|
||||
tiled_rank = untiled_rank + tiling_rank
|
||||
for t, d in zip(self.tiling[::-1], range(untiled_rank)[::-1]):
|
||||
ref = mgpu.memref_unfold(ref, d, (None, t))
|
||||
permutation = (
|
||||
*range(untiled_rank - tiling_rank),
|
||||
*range(untiled_rank - tiling_rank, tiled_rank, 2),
|
||||
*range(untiled_rank - tiling_rank + 1, tiled_rank, 2),
|
||||
)
|
||||
return mgpu.memref_transpose(ref, permutation)
|
||||
|
||||
def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]:
|
||||
index = ir.IndexType.get()
|
||||
tiling_rank = len(self.tiling)
|
||||
return (
|
||||
*idx[:-tiling_rank],
|
||||
*(
|
||||
arith.divui(i, c(t, index))
|
||||
for i, t in zip(idx[-tiling_rank:], self.tiling)
|
||||
),
|
||||
*([c(0, index)] * tiling_rank),
|
||||
)
|
||||
|
||||
def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]:
|
||||
# Note that this also checks that tiled dims are not squeezed. Their slice
|
||||
# size would be 1 if so.
|
||||
tiling_rank = len(self.tiling)
|
||||
for size, tile_size in zip(shape[-tiling_rank:], self.tiling):
|
||||
if size % tile_size:
|
||||
raise ValueError(
|
||||
f"Expected GMEM slice shape {shape} suffix to be a multiple"
|
||||
f" of tiling {self.tiling}"
|
||||
)
|
||||
return (
|
||||
*shape[:-tiling_rank],
|
||||
*(s // t for s, t in zip(shape[-tiling_rank:], self.tiling)),
|
||||
*self.tiling,
|
||||
)
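# Worked example (editorial comment, not in the original): with tiling=(64, 32)
# and a GMEM slice of shape (5, 128, 128),
#   transform_shape((5, 128, 128)) == (5, 2, 4, 64, 32)
# and a base index (i, j, k) is remapped to
#   transform_index((i, j, k)) == (i, j // 64, k // 32, 0, 0).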
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TransposeTransform(MemRefTransform):
|
||||
"""Transposes memref dimensions."""
|
||||
permutation: tuple[int, ...]
|
||||
|
||||
def __post_init__(self):
|
||||
if len(self.permutation) != len(set(self.permutation)):
|
||||
raise ValueError("Permutation must be a permutation")
|
||||
|
||||
def apply(self, ref: ir.Value) -> ir.Value:
|
||||
return mgpu.memref_transpose(ref, self.permutation)
|
||||
|
||||
def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]:
|
||||
return tuple(idx[p] for p in self.permutation)
|
||||
|
||||
def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]:
|
||||
return tuple(shape[p] for p in self.permutation)
|
||||
|
||||
|
||||
OnDeviceProfiler = profiler.OnDeviceProfiler
|
||||
|
||||
|
||||
@dataclasses.dataclass()
|
||||
class LaunchContext:
|
||||
launch_op: gpu.LaunchOp
|
||||
profiler: OnDeviceProfiler | None = None
|
||||
tma_descriptors: dict[
|
||||
tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...]],
|
||||
ir.Value,
|
||||
] = dataclasses.field(default_factory=dict, init=False)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def named_region(self, *args, **kwargs):
|
||||
if self.profiler is not None:
|
||||
with self.profiler.record(*args, **kwargs):
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
def _get_tma_desc(
|
||||
self,
|
||||
ref,
|
||||
gmem_transform: tuple[MemRefTransform, ...],
|
||||
transformed_slice_shape: tuple[int, ...],
|
||||
swizzle: int | None,
|
||||
):
|
||||
index = ir.IndexType.get()
|
||||
ref_ty = ir.MemRefType(ref.type)
|
||||
tma_desc_key = (ref, transformed_slice_shape, swizzle, gmem_transform)
|
||||
if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None:
|
||||
swizzle_str = f"swizzle_{swizzle}b" if swizzle is not None else "none"
|
||||
default_tensor_map_attrs = dict(
|
||||
swizzle=swizzle_str, l2promo="none", oob="zero", interleave="none"
|
||||
)
|
||||
tensor_map_ty = utils.get_tensormap_descriptor(
|
||||
tensor=(
|
||||
f"memref<{'x'.join(map(str, transformed_slice_shape))}x{ref_ty.element_type}, 3>"
|
||||
),
|
||||
**default_tensor_map_attrs,
|
||||
)
|
||||
with ir.InsertionPoint(self.launch_op):
|
||||
for t in gmem_transform:
|
||||
ref = t.apply(ref)
|
||||
ref_unranked = memref.cast(
|
||||
ir.UnrankedMemRefType.get(ref_ty.element_type, None), ref
|
||||
)
|
||||
tma_desc = nvgpu.tma_create_descriptor(
|
||||
tensor_map_ty,
|
||||
ref_unranked,
|
||||
[c(s, index) for s in transformed_slice_shape],
|
||||
)
|
||||
self.tma_descriptors[tma_desc_key] = tma_desc
|
||||
return tma_desc
|
||||
|
||||
def async_copy(
|
||||
self,
|
||||
*,
|
||||
src_ref,
|
||||
dst_ref,
|
||||
gmem_slice: Any = (),
|
||||
gmem_transform: MemRefTransform | tuple[MemRefTransform, ...] = (),
|
||||
barrier: mgpu.Barrier | None = None,
|
||||
swizzle: int | None = None,
|
||||
arrive: bool | None = None,
|
||||
uniform: bool = True,
|
||||
):
|
||||
index = ir.IndexType.get()
|
||||
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
|
||||
src_ref_ty = ir.MemRefType(src_ref.type)
|
||||
dst_ref_ty = ir.MemRefType(dst_ref.type)
|
||||
element_type = src_ref_ty.element_type
|
||||
if element_type != dst_ref_ty.element_type:
|
||||
raise ValueError(
|
||||
f"Expected same element type, got {element_type} and"
|
||||
f" {dst_ref_ty.element_type}"
|
||||
)
|
||||
if not isinstance(gmem_transform, tuple):
|
||||
gmem_transform = (gmem_transform,)
|
||||
|
||||
if src_ref_ty.memory_space is None and dst_ref_ty.memory_space == smem:
|
||||
gmem_ref, smem_ref = src_ref, dst_ref
|
||||
if barrier is None:
|
||||
raise ValueError("Barriers are required for GMEM -> SMEM copies")
|
||||
if arrive is None:
|
||||
arrive = True # Arrive by default
|
||||
elif src_ref_ty.memory_space == smem and dst_ref_ty.memory_space is None:
|
||||
gmem_ref, smem_ref = dst_ref, src_ref
|
||||
if barrier is not None:
|
||||
raise ValueError("Barriers are unsupported for SMEM -> GMEM copies")
|
||||
if arrive is not None:
|
||||
raise ValueError("arrive is unsupported for SMEM -> GMEM copies")
|
||||
else:
|
||||
raise ValueError("Only SMEM <-> GMEM copies supported")
|
||||
# TODO(apaszke): This is a very approximate check. Improve it!
|
||||
expected_name = "builtin.unrealized_conversion_cast"
|
||||
if (
|
||||
gmem_ref.owner is None
|
||||
or gmem_ref.owner.opview.OPERATION_NAME != expected_name
|
||||
):
|
||||
raise ValueError("GMEM reference in async_copy must be a kernel argument")
|
||||
|
||||
base_indices, slice_shape, is_squeezed = utils.parse_indices(
|
||||
gmem_slice, ir.MemRefType(gmem_ref.type).shape
|
||||
)
|
||||
dyn_base_indices = tuple(
|
||||
c(i, index) if not isinstance(i, ir.Value) else i for i in base_indices
|
||||
)
|
||||
slice_shape = tuple(slice_shape)
|
||||
for t in gmem_transform:
|
||||
dyn_base_indices = t.transform_index(dyn_base_indices)
|
||||
slice_shape = t.transform_shape(slice_shape)
|
||||
for dim, squeezed in enumerate(is_squeezed):
|
||||
if squeezed:
|
||||
smem_ref = mgpu.memref_unsqueeze(smem_ref, dim)
|
||||
smem_ref_ty = ir.MemRefType(smem_ref.type)
|
||||
|
||||
if slice_shape != tuple(smem_ref_ty.shape):
|
||||
raise ValueError(
|
||||
"Expected the SMEM reference to have the same shape as the tiled"
|
||||
f" slice: {tuple(smem_ref_ty.shape)} != {slice_shape}"
|
||||
)
|
||||
tma_desc = self._get_tma_desc(
|
||||
gmem_ref, gmem_transform, slice_shape, swizzle,
|
||||
)
|
||||
|
||||
# nvgpu TMA instructions expect reversed indices...
|
||||
rev_dyn_based_indices = reversed(dyn_base_indices)
|
||||
|
||||
uniform_ctx = mgpu.once if uniform else contextlib.nullcontext
|
||||
|
||||
if gmem_ref is src_ref:
|
||||
with uniform_ctx():
|
||||
assert barrier is not None # for pytype
|
||||
barrier_group = barrier.barrier_array.value
|
||||
barrier_idx = barrier.offset
|
||||
if arrive:
|
||||
slice_bytes = c(
|
||||
np.prod(slice_shape) * mgpu.bytewidth(element_type), index
|
||||
)
|
||||
nvgpu.mbarrier_arrive_expect_tx(
|
||||
barrier_group, slice_bytes, barrier_idx
|
||||
)
|
||||
nvgpu.tma_async_load(
|
||||
smem_ref, barrier_group, tma_desc, rev_dyn_based_indices, barrier_idx
|
||||
)
|
||||
else:
|
||||
with uniform_ctx():
|
||||
nvgpu.tma_async_store(smem_ref, tma_desc, rev_dyn_based_indices)
|
||||
nvvm.cp_async_bulk_commit_group()
|
||||
|
||||
def await_async_copy(
|
||||
self, allow_groups: int, await_read_only: bool = False
|
||||
):
|
||||
nvvm.cp_async_bulk_wait_group(allow_groups, read=await_read_only)
|
||||
gpu.barrier() # Groups are supposedly tracked per-thread
|
||||
|
||||
def _prefetch_tma_descs(self):
|
||||
with ir.InsertionPoint(self.launch_op.body.blocks[0]):
|
||||
with mgpu.once():
|
||||
for desc in self.tma_descriptors.values():
|
||||
nvgpu.tma_prefetch_descriptor(desc)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _launch(
|
||||
token,
|
||||
grid,
|
||||
block,
|
||||
smem_buffers,
|
||||
profiler_spec: profiler.ProfilerSpec | None = None,
|
||||
maybe_prof_buffer: ir.Value | None = None,
|
||||
):
|
||||
if (profiler_spec is None) != (maybe_prof_buffer is None):
|
||||
raise ValueError
|
||||
index = ir.IndexType.get()
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
i8 = ir.IntegerType.get_signless(8)
|
||||
grid_vals = [c(i, index) for i in grid]
|
||||
block_vals = [c(i, index) for i in block]
|
||||
flat_refs, smem_buffer_tree = jax.tree.flatten(smem_buffers)
|
||||
|
||||
smem_ref_bytes = []
|
||||
for ref_ty in flat_refs:
|
||||
smem_ref_bytes.append(
|
||||
np.prod(ref_ty.shape) * np.dtype(ref_ty.dtype).itemsize
|
||||
)
|
||||
|
||||
smem_bytes = sum(smem_ref_bytes)
|
||||
if profiler_spec is not None:
|
||||
smem_bytes += profiler_spec.smem_bytes(grid)
|
||||
|
||||
launch_op = gpu.LaunchOp(
|
||||
token.type, [token], *grid_vals, *block_vals,
|
||||
dynamicSharedMemorySize=c(smem_bytes, i32))
|
||||
launch_op.body.blocks.append(*([index] * 12)) # Append an empty block
|
||||
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
|
||||
with ir.InsertionPoint(launch_op.body.blocks[0]):
|
||||
dynamic_smem = gpu.dynamic_shared_memory(
|
||||
ir.MemRefType.get(
|
||||
(ir.ShapedType.get_dynamic_size(),), i8, memory_space=smem
|
||||
)
|
||||
)
|
||||
smem_refs = []
|
||||
dynamic_smem_offset = 0
|
||||
for ref_ty, ref_bytes in zip(flat_refs, smem_ref_bytes):
|
||||
mlir_dtype = mlir.dtype_to_ir_type(ref_ty.dtype)
|
||||
tile_smem = memref.view(
|
||||
ir.MemRefType.get(ref_ty.shape, mlir_dtype, memory_space=smem),
|
||||
dynamic_smem, c(dynamic_smem_offset, index), [],
|
||||
)
|
||||
dynamic_smem_offset += ref_bytes
|
||||
smem_refs.append(tile_smem)
|
||||
|
||||
if profiler_spec:
|
||||
prof_smem = memref.view(
|
||||
ir.MemRefType.get(
|
||||
(profiler_spec.smem_i32_elements(grid=grid),),
|
||||
i32, memory_space=smem,
|
||||
),
|
||||
dynamic_smem, c(dynamic_smem_offset, index), [],
|
||||
)
|
||||
prof = profiler.OnDeviceProfiler(
|
||||
profiler_spec, prof_smem, maybe_prof_buffer
|
||||
)
|
||||
else:
|
||||
prof = None
|
||||
smem_ref_tree = jax.tree.unflatten(smem_buffer_tree, smem_refs)
|
||||
yield LaunchContext(launch_op, prof), smem_ref_tree
|
||||
if prof is not None:
|
||||
prof.finalize(grid=grid)
|
||||
gpu.terminator()
|
||||
|
||||
|
||||
def as_gpu_kernel(
|
||||
body,
|
||||
grid: tuple[int, ...],
|
||||
block: tuple[int, ...],
|
||||
in_shape,
|
||||
out_shape,
|
||||
smem_scratch_shape,
|
||||
prof_spec: profiler.ProfilerSpec | None = None,
|
||||
):
|
||||
ptr_ty = ir.Type.parse("!llvm.ptr")
|
||||
token_ty = ir.Type.parse("!gpu.async.token")
|
||||
|
||||
def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType:
|
||||
return ir.MemRefType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype))
|
||||
|
||||
if isinstance(in_shape, list):
|
||||
in_shape = tuple(in_shape)
|
||||
elif not isinstance(in_shape, tuple):
|
||||
in_shape = (in_shape,)
|
||||
in_ref_tys = [_shape_to_ref_ty(t) for t in in_shape]
|
||||
|
||||
unwrap_output_tuple = False
|
||||
if isinstance(out_shape, list):
|
||||
out_shape = tuple(out_shape)
|
||||
elif not isinstance(out_shape, tuple):
|
||||
out_shape = (out_shape,)
|
||||
unwrap_output_tuple = True
|
||||
out_ref_tys = [_shape_to_ref_ty(t) for t in out_shape]
|
||||
if prof_spec is not None:
|
||||
out_shape = (*out_shape, prof_spec.jax_buffer_type)
|
||||
out_ref_tys.append(prof_spec.mlir_buffer_type)
|
||||
|
||||
module = ir.Module.create()
|
||||
with ir.InsertionPoint(module.body):
|
||||
@func.FuncOp.from_py_func(ptr_ty, ptr_ty)
|
||||
def main(token_ptr, buffers):
|
||||
token = builtin.unrealized_conversion_cast([token_ty], [token_ptr])
|
||||
arg_refs = []
|
||||
for i, ref_ty in enumerate([*in_ref_tys, *out_ref_tys]):
|
||||
ptr = llvm.LoadOp(ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i], ptr_ty))
|
||||
arg_refs.append(utils.ptr_as_memref(ptr, ir.MemRefType(ref_ty)))
|
||||
in_refs = arg_refs[:len(in_ref_tys)]
|
||||
out_refs = arg_refs[len(in_ref_tys):]
|
||||
prof_buffer = out_refs.pop() if prof_spec is not None else None
|
||||
with _launch(
|
||||
token, grid, block, smem_scratch_shape, prof_spec, prof_buffer
|
||||
) as (launch_ctx, smem_refs):
|
||||
body(launch_ctx, *in_refs, *out_refs, smem_refs)
|
||||
main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
|
||||
module.operation.verify()
|
||||
|
||||
expected_arg_treedef = jax.tree.structure(in_shape)
|
||||
def _check_args(args):
|
||||
arg_treedef = jax.tree.structure(args)
|
||||
if arg_treedef != expected_arg_treedef:
|
||||
raise ValueError(
|
||||
f"Invalid argument structure: expected {expected_arg_treedef}, got"
|
||||
f" {arg_treedef}"
|
||||
)
|
||||
|
||||
dump_low_level(module)
|
||||
|
||||
pass_manager = PassManager.parse(
|
||||
"builtin.module(gpu-lower-to-nvvm-pipeline{cubin-format=fatbin"
|
||||
" cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3})"
|
||||
)
|
||||
if mosaic_gpu_print_after_all.value:
|
||||
pass_manager.enable_ir_printing()
|
||||
pass_manager.run(module.operation)
|
||||
|
||||
def bind(*args):
|
||||
return mosaic_gpu_p.bind(*args, out_types=out_shape, module=module)
|
||||
|
||||
if prof_spec is not None:
|
||||
@jax.jit
|
||||
def prof_kernel(*args):
|
||||
_check_args(args)
|
||||
*results, prof_buffer = bind(*args)
|
||||
def dump_profile(prof_buffer):
|
||||
out_file = os.path.join(
|
||||
os.getenv("TEST_UNDECLARED_OUTPUTS_DIR"),
|
||||
f"{time.time_ns()}-trace.json",
|
||||
)
|
||||
try:
|
||||
with open(out_file, "x") as f:
|
||||
prof_spec.dump(prof_buffer, f)
|
||||
except FileExistsError:
|
||||
pass # TODO: Retry
|
||||
jax.debug.callback(dump_profile, prof_buffer)
|
||||
return results[0] if unwrap_output_tuple else results
|
||||
return prof_kernel
|
||||
else:
|
||||
@jax.jit
|
||||
def kernel(*args):
|
||||
_check_args(args)
|
||||
results = bind(*args)
|
||||
return results[0] if unwrap_output_tuple else results
|
||||
return kernel
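# Usage sketch (editorial comment; the body and shapes are hypothetical):
#   import jax.numpy as jnp
#   xy_ty = jax.ShapeDtypeStruct((128, 128), jnp.float32)
#   def body(ctx, src, dst, scratch):
#     ...  # issue copies / WGMMA through ctx and the mgpu DSL
#   kernel = as_gpu_kernel(body, grid=(1,), block=(128,), in_shape=xy_ty,
#                          out_shape=xy_ty, smem_scratch_shape=xy_ty)
#   y = kernel(x)  # x must match the in_shape pytree checked by _check_args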
|
||||
|
||||
|
||||
def dump_low_level(module):
|
||||
dump_ptx = mosaic_gpu_dump_ptx.value
|
||||
dump_ptxas = mosaic_gpu_dump_ptxas.value
|
||||
dump_sass = mosaic_gpu_dump_sass.value
|
||||
if not any([dump_ptx, dump_ptxas, dump_sass]):
|
||||
return
|
||||
module = ir.Module.parse(
|
||||
module.operation.get_asm(binary=True, enable_debug_info=True)
|
||||
)
|
||||
pm = PassManager.parse(
|
||||
"builtin.module(gpu-lower-to-nvvm-pipeline{cubin-format=isa"
|
||||
" cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3})"
|
||||
)
|
||||
pm.run(module.operation)
|
||||
|
||||
for op in module.body:
|
||||
if op.OPERATION_NAME == "gpu.binary":
|
||||
objects = ir.ArrayAttr(op.objects)
|
||||
if len(objects) != 1:
|
||||
raise NotImplementedError("Expected a single object")
|
||||
obj = str(objects[0])
|
||||
start = obj.find('assembly = "') + len('assembly = "')
|
||||
end = obj.find('"', start)
|
||||
ptx = obj[start:end]
|
||||
ptx = ptx.replace("\\09", "\t").replace("\\0A", "\n")[:-3]
|
||||
if dump_ptx:
|
||||
print(ptx)
|
||||
if dump_ptxas or dump_sass:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
ptx_path = os.path.join(tmp, "kernel.ptx")
|
||||
with open(ptx_path, "w") as f:
|
||||
f.write(ptx)
|
||||
elf_path = os.path.join(tmp, 'kernel.o')
|
||||
v_flag = "-v" if dump_ptxas else ""
|
||||
ptxas_flags = f"{v_flag} --opt-level 3 --gpu-name sm_90a"
|
||||
ptxas_out = subprocess.check_output(
|
||||
f"{PTXAS_PATH} {ptxas_flags} --output-file {elf_path} {ptx_path}",
|
||||
stderr=subprocess.STDOUT,
|
||||
shell=True,
|
||||
)
|
||||
if dump_ptxas:
|
||||
print(ptxas_out.decode())
|
||||
if dump_sass:
|
||||
sass = subprocess.check_output(
|
||||
f"{NVDISASM_PATH} -ndf -c {elf_path}",
|
||||
stderr=subprocess.STDOUT,
|
||||
shell=True,
|
||||
)
|
||||
print(sass.decode())
|
jax/experimental/mosaic/gpu/dsl.py (new file, 47 lines)

@@ -0,0 +1,47 @@
|
||||
# Copyright 2024 The JAX Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from .fragmented_array import (
|
||||
FragmentedArray,
|
||||
FragmentedLayout,
|
||||
WGMMA_LAYOUT,
|
||||
WGMMA_ROW_LAYOUT,
|
||||
WGStridedFragLayout,
|
||||
)
|
||||
from .utils import (
|
||||
Barrier,
|
||||
BarrierArray,
|
||||
DynamicSlice,
|
||||
Partition,
|
||||
Partition1D,
|
||||
bytewidth,
|
||||
c,
|
||||
commit_shared,
|
||||
debug_print,
|
||||
ds,
|
||||
fori,
|
||||
memref_fold,
|
||||
memref_slice,
|
||||
memref_transpose,
|
||||
memref_unfold,
|
||||
memref_unsqueeze,
|
||||
once,
|
||||
tile_shape,
|
||||
)
|
||||
from .wgmma import (
|
||||
WGMMAAccumulator,
|
||||
WGMMALayout,
|
||||
wgmma,
|
||||
)
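# Editorial note (not part of the file): the package's own modules consume these
# re-exports under a single alias, e.g. `from . import dsl as mgpu` in
# __init__.py above; external users would similarly write
#   from jax.experimental.mosaic.gpu import dsl as mgpu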
|
jax/experimental/mosaic/gpu/fragmented_array.py (new file, 476 lines)

@@ -0,0 +1,476 @@
|
||||
# Copyright 2024 The JAX Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utilities for code generator."""
|
||||
|
||||
import dataclasses
|
||||
|
||||
import jax
|
||||
from jaxlib.mlir import ir
|
||||
from jaxlib.mlir.dialects import arith
|
||||
from jaxlib.mlir.dialects import gpu
|
||||
from jaxlib.mlir.dialects import llvm
|
||||
from jaxlib.mlir.dialects import math as mlir_math
|
||||
from jaxlib.mlir.dialects import memref
|
||||
from jaxlib.mlir.dialects import nvvm
|
||||
from jaxlib.mlir.dialects import vector
|
||||
import numpy as np
|
||||
|
||||
from . import utils
|
||||
from . import dsl as mgpu
|
||||
|
||||
# mypy: ignore-errors
|
||||
|
||||
WARPGROUP_SIZE = utils.WARPGROUP_SIZE
|
||||
c = utils.c
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class WGMMAFragLayout:
|
||||
"""[m, n] matrix, where m % 64 == 0 == n % 8."""
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class WGMMARowFragLayout:
|
||||
"""[m] matrix, where m % 64 == 0."""
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class WGStridedFragLayout:
|
||||
"""Convert the array to 1D and then shard across threads."""
|
||||
|
||||
shape: tuple[int, ...]
|
||||
vec_size: int
|
||||
|
||||
def __post_init__(self):
|
||||
if np.prod(self.shape) % (self.vec_size * WARPGROUP_SIZE) != 0:
|
||||
raise ValueError((self, WARPGROUP_SIZE))
|
||||
|
||||
@classmethod
|
||||
def from_memref_type(cls, memref_ty: ir.Type):
|
||||
if not ir.MemRefType.isinstance(memref_ty):
|
||||
raise TypeError(memref_ty)
|
||||
|
||||
memref_type = ir.MemRefType(memref_ty)
|
||||
bw = mgpu.bytewidth(memref_type.element_type)
|
||||
assert 8 % bw == 0 and 8 // bw != 0, bw
|
||||
return cls(shape=memref_type.shape, vec_size=8 // bw)
|
||||
|
||||
def thread_vec_idxs(self):
|
||||
"""The indexes to be used for vector load/store WGStridedFragLayout.
|
||||
|
||||
Yields:
|
||||
The indices of the vector that correspond to the current thread.
|
||||
"""
|
||||
cardinality = np.prod(self.shape)
|
||||
assert cardinality % (WARPGROUP_SIZE * self.vec_size) == 0
|
||||
reg_num = cardinality // (WARPGROUP_SIZE * self.vec_size)
|
||||
tidx = gpu.thread_id(gpu.Dimension.x)
|
||||
off = arith.muli(tidx, c(self.vec_size, tidx.type))
|
||||
for i in range(reg_num):
|
||||
yield [arith.addi(off, c(i * WARPGROUP_SIZE * self.vec_size, tidx.type))]
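# Worked example (editorial comment, not in the original): for
# WGStridedFragLayout(shape=(256, 128), vec_size=4) and WARPGROUP_SIZE == 128,
# reg_num == 256 * 128 // (128 * 4) == 64, and thread t yields the element
# offsets t*4, t*4 + 512, t*4 + 1024, ..., each covering vec_size == 4 elements.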
|
||||
|
||||
|
||||
FragmentedLayout = WGStridedFragLayout | WGMMAFragLayout | WGMMARowFragLayout
|
||||
|
||||
|
||||
WGMMA_LAYOUT = WGMMAFragLayout()
|
||||
WGMMA_ROW_LAYOUT = WGMMARowFragLayout()
|
||||
|
||||
|
||||
@jax.tree_util.register_pytree_node_class
|
||||
class FragmentedArray:
|
||||
registers: np.ndarray # of ir.Value, see checks in init for shapes.
|
||||
layout: FragmentedLayout
|
||||
|
||||
def __init__(self, *, _registers: np.ndarray, _layout: FragmentedLayout):
|
||||
self.registers = _registers
|
||||
self.layout = _layout
|
||||
|
||||
match self.layout:
|
||||
# Registers are [m_tiles, n_tiles, 2 rows, 1 cols] in WGMMA layout
|
||||
# Each element is a vector<2xdtype>
|
||||
case WGMMAFragLayout():
|
||||
if self.registers.ndim != 4 or self.registers.shape[2:] != (2, 1):
|
||||
raise ValueError("Invalid register array shape")
|
||||
|
||||
# Registers are [m_tiles, 2 rows] in WGMMA_ROW layout
|
||||
# Each element is a dtype scalar
|
||||
case WGMMARowFragLayout():
|
||||
if self.registers.ndim != 2 or self.registers.shape[-1] != 2:
|
||||
raise ValueError("Invalid register array shape")
|
||||
|
||||
# Registers are flat
|
||||
case WGStridedFragLayout(shape):
|
||||
(reg_size,) = ir.VectorType(_registers.flat[0].type).shape
|
||||
if np.prod(shape) != np.prod(_registers.shape) * WARPGROUP_SIZE * reg_size:
|
||||
raise ValueError((reg_size, shape, _registers.shape, WARPGROUP_SIZE), _registers.flat[0].type)
|
||||
case _:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def load_strided(cls, ref: ir.Value):
|
||||
if not ir.MemRefType.isinstance(ref.type):
|
||||
raise TypeError(ref.type)
|
||||
|
||||
ref_ty = ir.MemRefType(ref.type)
|
||||
ref_1d = mgpu.memref_fold(ref, 0, len(ref_ty.shape))
|
||||
layout = WGStridedFragLayout.from_memref_type(ref_ty)
|
||||
vec_ty = ir.VectorType.get((layout.vec_size,), ref_ty.element_type)
|
||||
vecs = [vector.load(vec_ty, ref_1d, vec_idx) for vec_idx in layout.thread_vec_idxs()]
|
||||
return cls(_registers=np.array(vecs), _layout=layout)
|
||||
|
||||
@classmethod
|
||||
def splat(cls, value, shape, layout):
|
||||
match layout:
|
||||
case WGMMARowFragLayout():
|
||||
if len(shape) != 1:
|
||||
raise ValueError
|
||||
if shape[0] % 64:
|
||||
raise ValueError
|
||||
reg_shape = (shape[0] // 64, 2)
|
||||
case WGMMAFragLayout():
|
||||
if len(shape) != 2:
|
||||
raise ValueError
|
||||
if shape[0] % 64 or shape[1] % 8:
|
||||
raise ValueError
|
||||
reg_shape = (shape[0] // 64, shape[1] // 8, 2, 1)
|
||||
value = vector.splat(ir.VectorType.get((2,), value.type), value)
|
||||
case WGStridedFragLayout(shape=shape, vec_size=vec_size):
|
||||
elems = np.prod(shape)
|
||||
reg_shape = (elems // (WARPGROUP_SIZE * vec_size),)
|
||||
value = vector.splat(ir.VectorType.get((vec_size,), value.type), value)
|
||||
case _:
|
||||
raise NotImplementedError(layout)
|
||||
|
||||
return cls(
|
||||
_registers=np.full(reg_shape, value, dtype=object),
|
||||
_layout=layout,
|
||||
)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
row_tiles = self.registers.shape[0]
|
||||
match self.layout:
|
||||
case WGMMAFragLayout():
|
||||
col_tiles = self.registers.shape[1]
|
||||
return (row_tiles * 64, col_tiles * 8)
|
||||
case WGMMARowFragLayout():
|
||||
return (row_tiles * 64,)
|
||||
case WGStridedFragLayout(shape):
|
||||
return shape
|
||||
|
||||
@property
|
||||
def mlir_dtype(self):
|
||||
reg_ty = self.registers.flat[0].type
|
||||
match self.layout:
|
||||
case WGMMAFragLayout() | WGStridedFragLayout():
|
||||
return ir.VectorType(reg_ty).element_type
|
||||
case WGMMARowFragLayout():
|
||||
return reg_ty
|
||||
|
||||
def _pointwise(self, op, *other):
|
||||
for o in other:
|
||||
if not isinstance(o, FragmentedArray):
|
||||
return NotImplemented
|
||||
if self.layout != o.layout:
|
||||
raise ValueError("Incompatible FragmentedArray layouts")
|
||||
if self.registers.shape != o.registers.shape:
|
||||
raise ValueError("Incompatible FragmentedArray shapes")
|
||||
new_regs = np.empty_like(self.registers)
|
||||
for idx, reg in np.ndenumerate(self.registers):
|
||||
new_regs[idx] = op(reg, *(o.registers[idx] for o in other))
|
||||
return FragmentedArray(_registers=new_regs, _layout=self.layout)
|
||||
|
||||
def __add__(self, other):
|
||||
return self._pointwise(arith.addf, other)
|
||||
|
||||
def __mul__(self, other):
|
||||
return self._pointwise(arith.mulf, other)
|
||||
|
||||
def __sub__(self, other):
|
||||
return self._pointwise(arith.subf, other)
|
||||
|
||||
def __truediv__(self, other):
|
||||
return self._pointwise(arith.divf, other)
|
||||
|
||||
def max(self, other):
|
||||
return self._pointwise(arith.maximumf, other)
|
||||
|
||||
def exp(self, approx: bool = False):
|
||||
def fast_exp(x):
|
||||
f32 = ir.F32Type.get()
|
||||
log2e = arith.constant(f32, ir.FloatAttr.get(f32, 1.4426950408889634))
|
||||
if x.type == f32:
|
||||
scaled = arith.mulf(x, log2e)
|
||||
return llvm.inline_asm(
|
||||
f32, [scaled], "ex2.approx.f32 $0,$1;", "=f,f", asm_dialect=0
|
||||
)
|
||||
elif ir.VectorType.isinstance(x.type):
|
||||
index = ir.IndexType.get()
|
||||
result = llvm.mlir_undef(x.type)
|
||||
for i in range(2):
|
||||
v = vector.extractelement(x, position=c(i, index))
|
||||
vr = fast_exp(v)
|
||||
result = vector.insertelement(vr, result, position=c(i, index))
|
||||
return result
|
||||
else:
|
||||
raise NotImplementedError(x.type)
|
||||
return self._pointwise(fast_exp if approx else mlir_math.exp)
|
||||
|
||||
def __and__(self, other):
|
||||
if not ir.IntegerType.isinstance(self.mlir_dtype):
|
||||
raise ValueError(
|
||||
"Bitwise operations only defined for integer types, not"
|
||||
f" {self.mlir_dtype}"
|
||||
)
|
||||
|
||||
return self._pointwise(arith.andi, other)
|
||||
|
||||
def bitcast(self, elt: ir.Type):
|
||||
reg_type = self.registers.flat[0].type
|
||||
if ir.VectorType.isinstance(reg_type):
|
||||
reg_shape = ir.VectorType(reg_type).shape
|
||||
ty = ir.VectorType.get(reg_shape, elt)
|
||||
else:
|
||||
ty = elt
|
||||
|
||||
return self._pointwise(lambda x: arith.bitcast(ty, x))
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if self.layout != WGMMA_LAYOUT:
|
||||
raise NotImplementedError("Only WGMMA layouts support slicing")
|
||||
base_idx, slice_shape, is_squeezed = utils.parse_indices(idx, self.shape)
|
||||
if any(is_squeezed):
|
||||
raise NotImplementedError("Only slicing implemented")
|
||||
if (
|
||||
base_idx[0] % 64
|
||||
or slice_shape[0] % 64
|
||||
or base_idx[1] % 8
|
||||
or slice_shape[1] % 8
|
||||
):
|
||||
raise NotImplementedError("Only tile aligned slicing supported")
|
||||
base_idx[0] //= 64
|
||||
slice_shape[0] //= 64
|
||||
base_idx[1] //= 8
|
||||
slice_shape[1] //= 8
|
||||
new_regs = self.registers[
|
||||
base_idx[0] : base_idx[0] + slice_shape[0],
|
||||
base_idx[1] : base_idx[1] + slice_shape[1],
|
||||
]
|
||||
return FragmentedArray(_registers=new_regs, _layout=self.layout)
|
||||
|
||||
# TODO(apaszke): Support JAX dtypes here as well?
|
||||
def astype(self, new_dtype: ir.Type):
|
||||
cur_dtype = self.mlir_dtype
|
||||
if cur_dtype == new_dtype:
|
||||
return self
|
||||
from_float = ir.FloatType.isinstance(cur_dtype)
|
||||
to_float = ir.FloatType.isinstance(new_dtype)
|
||||
from_integer = ir.IntegerType.isinstance(cur_dtype)
|
||||
to_integer = ir.IntegerType.isinstance(new_dtype)
|
||||
if from_float and to_float:
|
||||
if ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width:
|
||||
convert = arith.truncf
|
||||
else:
|
||||
convert = arith.extf
|
||||
elif from_integer and to_integer:
|
||||
if ir.IntegerType(cur_dtype).width > ir.IntegerType(new_dtype).width:
|
||||
convert = arith.trunci
|
||||
else:
|
||||
convert = arith.extsi
|
||||
elif from_integer and to_float:
|
||||
convert = arith.sitofp
|
||||
elif from_float and to_integer:
|
||||
convert = arith.fptosi
|
||||
new_registers = np.empty_like(self.registers)
|
||||
match self.layout:
|
||||
case WGMMAFragLayout():
|
||||
new_reg_ty = ir.VectorType.get((2,), new_dtype)
|
||||
case WGStridedFragLayout(vec_size=vec_size):
|
||||
new_reg_ty = ir.VectorType.get((vec_size,), new_dtype)
|
||||
case WGMMARowFragLayout():
|
||||
new_reg_ty = new_dtype
|
||||
case _:
|
||||
raise NotImplementedError(f"Unsupported layout {self.layout}")
|
||||
for idx, reg in np.ndenumerate(self.registers):
|
||||
new_registers[idx] = convert(new_reg_ty, reg)
|
||||
return FragmentedArray(_registers=new_registers, _layout=self.layout)
|
||||
|
||||
def reduce(self, op, axis):
|
||||
if self.layout != WGMMA_LAYOUT:
|
||||
raise NotImplementedError(self.layout)
|
||||
if axis != 1:
|
||||
raise NotImplementedError
|
||||
index = ir.IndexType.get()
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
new_regs = np.empty(self.registers.shape[::2], dtype=object)
|
||||
assert self.registers.shape[-1] == 1
|
||||
for row_tile, row_subtile in np.ndindex(new_regs.shape):
|
||||
# Reduce the registers owned by the current thread over n tiles
|
||||
thread_result_vec = self.registers[row_tile, 0, row_subtile, 0]
|
||||
for n_tile in range(1, self.registers.shape[1]):
|
||||
thread_result_vec = op(
|
||||
thread_result_vec, self.registers[row_tile, n_tile, row_subtile, 0]
|
||||
)
|
||||
thread_result = op(
|
||||
vector.extractelement(thread_result_vec, position=c(0, index)),
|
||||
vector.extractelement(thread_result_vec, position=c(1, index)),
|
||||
)
|
||||
# Do a shuffle to reduce in groups of 4 consecutive threads.
|
||||
result = thread_result
|
||||
for i in (1, 2):
|
||||
other_result = nvvm.shfl_sync(
|
||||
result.type,
|
||||
c(0xFFFFFFFF, i32),
|
||||
result,
|
||||
c(i, i32),
|
||||
c(0x1F, i32),
|
||||
nvvm.ShflKind.bfly,
|
||||
)
|
||||
result = op(result, other_result)
|
||||
new_regs[row_tile, row_subtile] = result
|
||||
return FragmentedArray(_registers=new_regs, _layout=WGMMA_ROW_LAYOUT)
|
||||
|
||||
def broadcast_minor(self, n):
|
||||
if self.layout != WGMMA_ROW_LAYOUT:
|
||||
raise NotImplementedError
|
||||
num_row_tiles = self.registers.shape[0]
|
||||
num_col_tiles, rem = divmod(n, 8)
|
||||
if rem:
|
||||
raise ValueError("Number of columns must be divisible by 8")
|
||||
new_regs = np.empty((num_row_tiles, num_col_tiles, 2, 1), dtype=object)
|
||||
dtype = self.mlir_dtype
|
||||
for (row_tile, row_subtile), reg in np.ndenumerate(self.registers):
|
||||
new_regs[row_tile, :, row_subtile, :] = vector.splat(
|
||||
ir.VectorType.get((2,), dtype), reg
|
||||
)
|
||||
return FragmentedArray(_registers=new_regs, _layout=WGMMA_LAYOUT)
|
||||
|
||||
def store_untiled(self, ref: ir.Value):
|
||||
if not ir.MemRefType.isinstance(ref.type):
|
||||
raise ValueError(ref)
|
||||
|
||||
match self.layout:
|
||||
case WGMMAFragLayout():
|
||||
self._store_untiled_wgmma(ref)
|
||||
case WGStridedFragLayout():
|
||||
self._store_untiled_wg_strided(ref)
|
||||
case _:
|
||||
raise NotImplementedError(self.layout)
|
||||
|
||||
def _store_untiled_wg_strided(self, ref: ir.Value):
|
||||
ref_ty = ir.MemRefType(ref.type)
|
||||
if ref_ty.shape != self.shape:
|
||||
raise ValueError((ref_ty.shape, self.shape))
|
||||
smem_1d = mgpu.memref_fold(ref, 0, len(ref_ty.shape))
|
||||
assert isinstance(self.layout, WGStridedFragLayout)
|
||||
for idx, reg in zip(self.layout.thread_vec_idxs(), self.registers.flat):
|
||||
vector.store(reg, smem_1d, idx)
|
||||
|
||||
def _store_untiled_wgmma(self, ref: ir.Value):
|
||||
"""Stores accumulator to a 2D memref. Not optimized at the moment."""
|
||||
assert self.layout == WGMMA_LAYOUT
|
||||
index = ir.IndexType.get()
|
||||
m, n = self.shape # pytype: disable=bad-unpacking
|
||||
ref_ty = ir.MemRefType(ref.type)
|
||||
if ref_ty.shape != [m, n]:
|
||||
raise ValueError(ref.type, (m, n))
|
||||
|
||||
def c(x):
|
||||
return arith.ConstantOp(index, ir.IntegerAttr.get(index, x))
|
||||
|
||||
tidx = gpu.thread_id(gpu.Dimension.x)
|
||||
lane_id = arith.remui(tidx, c(32)) # {0, 1, ..., 31}
|
||||
warp_id = arith.divui(tidx, c(32)) # {0, 1, 2, 3}
|
||||
row_base = arith.addi(
|
||||
arith.divui(lane_id, c(4)), arith.muli(warp_id, c(16))
|
||||
)
|
||||
col_base = arith.muli(arith.remui(lane_id, c(4)), c(2)) # {0, 2, 4, 6}
|
||||
it = np.ndenumerate(self.registers)
|
||||
for (row_tile, col_tile, row_idx, col_zero), elem in it:
|
||||
del col_zero
|
||||
row = arith.addi(row_base, c(row_tile * 64 + row_idx * 8))
|
||||
for col_idx in range(2):
|
||||
value = vector.extractelement(elem, position=c(col_idx))
|
||||
col = arith.addi(col_base, c(col_tile * 8 + col_idx))
|
||||
memref.store(value, ref, [row, col])
|
||||
|
||||
def store_tiled(self, ref, swizzle: int | None):
|
||||
if self.layout != WGMMA_LAYOUT:
|
||||
raise NotImplementedError
|
||||
bw = mgpu.bytewidth(self.mlir_dtype)
|
||||
m, n = self.shape # pytype: disable=bad-unpacking
|
||||
assert m % 64 == 0 # This is implied by the layout.
|
||||
if n % 32 != 0:
|
||||
raise NotImplementedError
|
||||
cols_per_tile = 128 // bw
|
||||
expected_shape = [m // 64, n // cols_per_tile, 64, cols_per_tile]
|
||||
if ir.MemRefType(ref.type).shape != expected_shape:
|
||||
raise ValueError(ref.type, (m, n))
|
||||
if swizzle != 128:
|
||||
raise NotImplementedError("Only 128B swizzle supported")
|
||||
index = ir.IndexType.get()
|
||||
|
||||
def c(x):
|
||||
return arith.ConstantOp(index, ir.IntegerAttr.get(index, x))
|
||||
|
||||
tidx = gpu.thread_id(gpu.Dimension.x)
|
||||
lane_id = arith.remui(tidx, c(32)) # {0, 1, ..., 31}
|
||||
warp_id = arith.divui(tidx, c(32)) # {0, 1, 2, 3}
|
||||
sub_row_base = arith.divui(lane_id, c(4)) # {0, 1, ..., 7}
|
||||
if bw > 2: # Stagger is only necessary for values larger than 16bit.
|
||||
is_even_row = arith.cmpi(
|
||||
arith.CmpIPredicate.eq, arith.remui(sub_row_base, c(2)), c(0)
|
||||
)
|
||||
else:
|
||||
# We rely on canonicalization to clean up the selects.
|
||||
i1 = ir.IntegerType.get_signless(1)
|
||||
is_even_row = arith.constant(i1, ir.IntegerAttr.get(i1, 1))
|
||||
row_base = arith.addi(sub_row_base, arith.muli(warp_id, c(16)))
|
||||
col_base = arith.muli(arith.remui(lane_id, c(4)), c(2)) # {0, 2, 4, 6}
|
||||
# The swizzle pattern is constant for a given thread.
|
||||
col_swizzle_bits = arith.muli(sub_row_base, c(16 // bw))
|
||||
for row_group in range(m // 64):
|
||||
for col_group in range(n // cols_per_tile):
|
||||
for row_subidx in range(2):
|
||||
row = arith.addi(row_base, c(row_subidx * 8))
|
||||
for col_subidx in range(cols_per_tile // 8):
|
||||
# We stagger the even and odd rows a little to avoid bank conflicts.
|
||||
# It seems that the STS.64 is 2x faster (and the hardware reports no
|
||||
# conflicts) when the conflicts are split between half-warps, as
|
||||
# opposed to having them within the half-warp. This requires a
|
||||
# little more work for the selects, but is ultimately worth it.
|
||||
col_subidx_even = col_subidx
|
||||
col_subidx_odd = col_subidx ^ 2
|
||||
col_off = arith.select(
|
||||
is_even_row, c(col_subidx_even * 8), c(col_subidx_odd * 8)
|
||||
)
|
||||
col = arith.addi(col_base, col_off)
|
||||
col = arith.xori(col, col_swizzle_bits)
|
||||
reg_idx_even = col_subidx_even + col_group * (cols_per_tile // 8)
|
||||
reg_idx_odd = col_subidx_odd + col_group * (cols_per_tile // 8)
|
||||
value_even = self.registers[row_group, reg_idx_even, row_subidx, 0]
|
||||
value_odd = self.registers[row_group, reg_idx_odd, row_subidx, 0]
|
||||
value = arith.select(is_even_row, value_even, value_odd)
|
||||
vector.store(value, ref, [c(row_group), c(col_group), row, col])
|
||||
|
||||
def tree_flatten(self):
|
||||
return list(self.registers.flat), (self.layout, self.registers.shape)
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, aux, flat_registers):
|
||||
layout, reg_shape = aux
|
||||
registers = np.asarray(flat_registers, dtype=object).reshape(reg_shape)
|
||||
return cls(_registers=registers, _layout=layout)
|
jax/experimental/mosaic/gpu/profiler.py (new file, 206 lines)

@@ -0,0 +1,206 @@
|
||||
# Copyright 2024 The JAX Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jaxlib.mlir import ir
|
||||
from jaxlib.mlir.dialects import arith
|
||||
from jaxlib.mlir.dialects import gpu
|
||||
from jaxlib.mlir.dialects import memref
|
||||
from jaxlib.mlir.dialects import scf
|
||||
import numpy as np
|
||||
|
||||
from .utils import * # noqa: F403
|
||||
|
||||
# ruff: noqa: F405
|
||||
# mypy: ignore-errors
|
||||
|
||||
class ProfilerSpec:
|
||||
ENTER = 0
|
||||
EXIT = 1 << 31
|
||||
|
||||
def __init__(self, num_entries: int):
|
||||
self.num_entries = num_entries
|
||||
self.interned_names = {}
|
||||
|
||||
@property
|
||||
def mlir_buffer_type(self) -> ir.Type:
|
||||
return ir.MemRefType.get(
|
||||
(1 + self.num_entries,), ir.IntegerType.get_signless(32)
|
||||
)
|
||||
|
||||
@property
|
||||
def jax_buffer_type(self) -> ir.Type:
|
||||
return jax.ShapeDtypeStruct((1 + self.num_entries,), jnp.uint32)
|
||||
|
||||
def smem_i32_elements(self, grid: tuple[int, ...]):
|
||||
return int(self.num_entries // np.prod(grid))
|
||||
|
||||
def smem_bytes(self, grid: tuple[int, ...]):
|
||||
bytes_per_entry = 4
|
||||
return self.smem_i32_elements(grid) * bytes_per_entry
|
||||
|
||||
def intern_name(self, name: str) -> int:
|
||||
if (name_id := self.interned_names.get(name, None)) is not None:
|
||||
return name_id
|
||||
name_id = self.interned_names[name] = len(self.interned_names)
|
||||
if name_id & self.EXIT:
|
||||
raise RuntimeError("Allocated too many names")
|
||||
return name_id
|
||||
|
||||
def dump(self, buffer, f):
|
||||
buffer = np.asarray(buffer)
|
||||
num_blocks = buffer[0]
|
||||
per_block = self.num_entries // num_blocks
|
||||
block_entries = buffer[1 : 1 + num_blocks * per_block].reshape(
|
||||
num_blocks, per_block
|
||||
)
|
||||
start_times = block_entries[:, :2].astype(np.int64)
|
||||
start_times = (start_times[:, 0] << 32) + start_times[:, 1]
|
||||
start_times -= start_times.min() # Normalize
|
||||
entries_used = block_entries[:, 2]
|
||||
if np.any(entries_used > per_block - 2):
|
||||
raise RuntimeError("Insufficient space to capture a full trace")
|
||||
block_traces = block_entries[:, 3:]
|
||||
unintern = {v: k for k, v in self.interned_names.items()}
|
||||
events = []
|
||||
for block_idx in range(num_blocks):
|
||||
valid_entries = entries_used[block_idx] - 3
|
||||
local_clock_offset = None
|
||||
assert valid_entries % 2 == 0
|
||||
start_time = start_times[block_idx]
|
||||
block_events = []
|
||||
for i in range(0, valid_entries, 2):
|
||||
tag = block_traces[block_idx, i]
|
||||
time = block_traces[block_idx, i + 1]
|
||||
if local_clock_offset is None:
|
||||
local_clock_offset = time
|
||||
time -= local_clock_offset
|
||||
time -= i * 6 # Account for the overhead of profiling.
|
||||
if time < 0:
|
||||
break # Detect a timer wraparound
|
||||
name_id = tag
|
||||
begin = True
|
||||
if name_id & ProfilerSpec.EXIT:
|
||||
name_id = name_id ^ ProfilerSpec.EXIT
|
||||
begin = False
|
||||
name = unintern[name_id]
|
||||
block_events.append({
|
||||
"name": name,
|
||||
"ph": "B" if begin else "E",
|
||||
"ts": float(start_time + time) / 1e3,
|
||||
"pid": 0,
|
||||
"tid": block_idx,
|
||||
})
|
||||
else: # If we didn't break
|
||||
events.extend(block_events)
|
||||
return json.dump({"displayTimeUnit": "ns", "traceEvents": events}, f)
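# Editorial note (not in the original): the emitted JSON is in the Chrome
# trace-event format, so a dumped trace can be opened in chrome://tracing or
# Perfetto. A two-event example of the structure produced above:
#   {"displayTimeUnit": "ns", "traceEvents": [
#     {"name": "tma", "ph": "B", "ts": 0.0, "pid": 0, "tid": 0},
#     {"name": "tma", "ph": "E", "ts": 1.5, "pid": 0, "tid": 0}]}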
|
||||
|
||||
|
||||
class OnDeviceProfiler:
|
||||
|
||||
def __init__(self, spec: ProfilerSpec, smem_buffer: ir.Value, gmem_buffer: ir.Value):
|
||||
self.spec = spec
|
||||
# self.should_store = gpu.thread_id(gpu.Dimension.x)
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
index = ir.IndexType.get()
|
||||
num_blocks = c(1, index)
|
||||
for dim in gpu.Dimension:
|
||||
num_blocks = arith.muli(num_blocks, gpu.grid_dim(dim))
|
||||
memref.store(arith.index_cast(i32, num_blocks), gmem_buffer, [c(0, index)])
|
||||
self.entries_per_block = arith.divui(c(spec.num_entries, index), num_blocks)
|
||||
self.smem_buffer = smem_buffer
|
||||
self.gmem_buffer = gmem_buffer
|
||||
# Hopefully mem2reg will remove the allocation.
|
||||
self.offset = memref.alloca(ir.MemRefType.get((), i32), [], [])
|
||||
memref.store(c(0, i32), self.offset, [])
|
||||
|
||||
@contextlib.contextmanager
|
||||
def record(self, name: str):
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
index = ir.IndexType.get()
|
||||
name_id = self.spec.intern_name(name)
|
||||
def store(modifier):
|
||||
cur = arith.index_cast(index, memref.load(self.offset, []))
|
||||
# TODO(apaszke): Clamp indices
|
||||
# bound = arith.subi(self.entries_per_block, c(2, index))
|
||||
# cur = arith.select(
|
||||
# arith.cmpi(arith.CmpIPredicate.ult, cur, bound), cur, bound
|
||||
# )
|
||||
memref.store(c(modifier | name_id, i32), self.smem_buffer, [cur])
|
||||
memref.store(
|
||||
clock(), self.smem_buffer, [arith.addi(cur, c(1, cur.type))]
|
||||
)
|
||||
memref.store(
|
||||
arith.index_cast(i32, arith.addi(cur, c(2, cur.type))),
|
||||
self.offset,
|
||||
[],
|
||||
)
|
||||
store(ProfilerSpec.ENTER)
|
||||
yield
|
||||
store(ProfilerSpec.EXIT)
|
||||
|
||||
def finalize(self, grid):
|
||||
index = ir.IndexType.get()
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
|
||||
block_idx = c(0, index)
|
||||
for dim in reversed(gpu.Dimension): # pytype: disable=wrong-arg-types
|
||||
block_idx = arith.addi(
|
||||
arith.muli(block_idx, gpu.grid_dim(dim)), gpu.block_id(dim)
|
||||
)
|
||||
start_offset = arith.addi(
|
||||
arith.muli(block_idx, self.entries_per_block), c(1, index)
|
||||
)
|
||||
block_gmem_buffer = memref.subview(
|
||||
self.gmem_buffer, [start_offset], [self.spec.num_entries], [1],
|
||||
result_type=ir.Type.parse(
|
||||
f"memref<{self.spec.num_entries}xi32, strided<[1], offset: ?>>"
|
||||
),
|
||||
)
|
||||
# TODO(apaszke): Either use globaltimer or delete
|
||||
# memref.store(globaltimer("high"), block_gmem_buffer, [c(0, index)])
|
||||
# memref.store(globaltimer("low"), block_gmem_buffer, [c(1, index)])
|
||||
memref.store(c(0, i32), block_gmem_buffer, [c(0, index)])
|
||||
memref.store(c(0, i32), block_gmem_buffer, [c(1, index)])
|
||||
memref.store(
|
||||
arith.addi(memref.load(self.offset, []), c(3, i32)),
|
||||
block_gmem_buffer,
|
||||
[c(2, index)],
|
||||
)
|
||||
|
||||
if_first = scf.IfOp(
|
||||
arith.cmpi(
|
||||
arith.CmpIPredicate.eq, gpu.thread_id(gpu.Dimension.x), c(0, index)
|
||||
)
|
||||
)
|
||||
with ir.InsertionPoint(if_first.then_block):
|
||||
for_op = scf.ForOp(
|
||||
c(0, index),
|
||||
c(self.spec.smem_i32_elements(grid) - 3, index),
|
||||
c(1, index),
|
||||
)
|
||||
with ir.InsertionPoint(for_op.body):
|
||||
x = memref.load(self.smem_buffer, [for_op.induction_variable])
|
||||
memref.store(
|
||||
x,
|
||||
block_gmem_buffer,
|
||||
[arith.addi(for_op.induction_variable, c(3, index))],
|
||||
)
|
||||
scf.yield_([])
|
||||
scf.yield_([])
|
634
jax/experimental/mosaic/gpu/utils.py
Normal file
@ -0,0 +1,634 @@
|
||||
# Copyright 2024 The JAX Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utilities for code generator."""
|
||||
|
||||
import contextlib
|
||||
import dataclasses
|
||||
from typing import Any, Literal, Sequence
|
||||
|
||||
from absl import flags
|
||||
import jax
|
||||
from jaxlib.mlir import ir
|
||||
from jaxlib.mlir.dialects import arith
|
||||
from jaxlib.mlir.dialects import builtin
|
||||
from jaxlib.mlir.dialects import gpu
|
||||
from jaxlib.mlir.dialects import llvm
|
||||
from jaxlib.mlir.dialects import memref
|
||||
from jaxlib.mlir.dialects import nvgpu
|
||||
from jaxlib.mlir.dialects import nvvm
|
||||
from jaxlib.mlir.dialects import scf
|
||||
from jaxlib.mlir.dialects import vector
|
||||
import numpy as np
|
||||
|
||||
# mypy: ignore-errors
|
||||
|
||||
WARPGROUP_SIZE: int = 128
|
||||
DYNAMIC = -9223372036854775808
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_bool("mosaic_gpu_debug", False, "Perform debug printing")
|
||||
|
||||
# pylint: disable=line-too-long, wildcard-import, missing-function-docstring, bad-continuation, g-bad-todo, protected-access, g-explicit-length-test, missing-class-docstring, g-doc-return-or-yield, g-inconsistent-quotes
|
||||
|
||||
|
||||
def ptr_as_memref(ptr, memref_ty: ir.MemRefType):
|
||||
if len(memref_ty.shape) == 0:
|
||||
raise NotImplementedError
|
||||
i64 = ir.IntegerType.get_signless(64)
|
||||
rank = len(memref_ty.shape)
|
||||
desc_ty = ir.Type.parse(
|
||||
f"!llvm.struct<(ptr, ptr, i64, array<{rank} x i64>, array<{rank} x i64>)>"
|
||||
)
|
||||
desc = llvm.UndefOp(desc_ty)
|
||||
desc = llvm.InsertValueOp(desc, ptr, [0]) # Allocation
|
||||
desc = llvm.InsertValueOp(desc, ptr, [1]) # Aligned Base
|
||||
desc = llvm.InsertValueOp(
|
||||
desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, 0)), [2]
|
||||
)
|
||||
for i, s in enumerate(memref_ty.shape):
|
||||
desc = llvm.InsertValueOp(
|
||||
desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, s)), [3, i]
|
||||
)
|
||||
for i, s in enumerate(get_contiguous_strides(memref_ty.shape)):
|
||||
desc = llvm.InsertValueOp(
|
||||
desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, s)), [4, i]
|
||||
)
|
||||
return builtin.unrealized_conversion_cast([memref_ty], [desc])
|
||||
|
||||
|
||||
def get_contiguous_strides(xs):
|
||||
strides_ret = []
|
||||
stride = 1
|
||||
for x in xs[::-1]:
|
||||
strides_ret.append(stride)
|
||||
stride *= x
|
||||
return strides_ret[::-1]
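# Worked example (illustrative, not part of the original file): the last
# dimension is unit-stride and every earlier stride is the product of the
# trailing extents.
#
#   >>> get_contiguous_strides([2, 3, 4])
#   [12, 4, 1]
#   >>> get_contiguous_strides([16])
#   [1]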
|
||||
|
||||
|
||||
def c(val: int | float, ty):
|
||||
if ir.IntegerType.isinstance(ty) or ir.IndexType.isinstance(ty):
|
||||
if not isinstance(val, (int, np.integer)):
|
||||
raise TypeError(type(val))
|
||||
attr = ir.IntegerAttr.get(ty, val)
|
||||
elif ir.FloatType.isinstance(ty):
|
||||
attr = ir.FloatAttr.get(ty, val)
|
||||
elif ir.VectorType.isinstance(ty):
|
||||
return vector.splat(ty, c(val, ir.VectorType(ty).element_type))
|
||||
else:
|
||||
raise NotImplementedError(ty)
|
||||
return arith.constant(ty, attr)
|
||||
|
||||
|
||||
def get_tensormap_descriptor(**attrs):
|
||||
return ir.Type.parse(
|
||||
f"!nvgpu.tensormap.descriptor<{', '.join(k + '=' + v for k, v in attrs.items())}>"
|
||||
)
|
||||
|
||||
|
||||
def debug_print(fmt, *args, uniform=True):
|
||||
if not FLAGS.mosaic_gpu_debug:
|
||||
return
|
||||
type_formats = []
|
||||
new_args = []
|
||||
for arg in args:
|
||||
ty_format = None
|
||||
if ir.IndexType.isinstance(arg.type):
|
||||
ty_format = "%llu"
|
||||
if ir.IntegerType.isinstance(arg.type):
|
||||
width = ir.IntegerType(arg.type).width
|
||||
if width == 64:
|
||||
ty_format = "%llu"
|
||||
elif width == 1:
|
||||
ty_format = "%llu"
|
||||
arg = arith.extui(ir.IntegerType.get_signless(64), arg)
|
||||
if ir.F32Type.isinstance(arg.type):
|
||||
ty_format = "%f"
|
||||
if ir.F16Type.isinstance(arg.type):
|
||||
ty_format = "%f"
|
||||
arg = arith.extf(ir.F32Type.get(), arg)
|
||||
if ty_format is None:
|
||||
raise NotImplementedError(arg.type)
|
||||
type_formats.append(ty_format)
|
||||
new_args.append(arg)
|
||||
ctx = once if uniform else contextlib.nullcontext
|
||||
with ctx():
|
||||
gpu.printf(fmt.format(*type_formats) + "\n", new_args)
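# Usage sketch (illustrative): printf formats are inferred from the MLIR types
# of the arguments, so the caller only writes `{}` placeholders. Nothing is
# emitted unless --mosaic_gpu_debug is set.
#
#   debug_print("block={} thread={}", gpu.block_id(gpu.Dimension.x),
#               gpu.thread_id(gpu.Dimension.x))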
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ForResult:
|
||||
op: scf.ForOp
|
||||
results: tuple[Any, ...]
|
||||
|
||||
@property
|
||||
def result(self):
|
||||
if len(self.results) != 1:
|
||||
raise ValueError
|
||||
return self.results[0]
|
||||
|
||||
|
||||
def fori(bound, carrys):
|
||||
unwrap = False
|
||||
if not isinstance(carrys, (list, tuple)):
|
||||
carrys = [carrys]
|
||||
unwrap = True
|
||||
flat_carrys, carry_treedef = jax.tree.flatten(carrys)
|
||||
|
||||
def wrapper(f):
|
||||
index = ir.IndexType.get()
|
||||
c0 = arith.ConstantOp(index, ir.IntegerAttr.get(index, 0))
|
||||
c1 = arith.ConstantOp(index, ir.IntegerAttr.get(index, 1))
|
||||
for_op = scf.ForOp(c0, bound, c1, flat_carrys)
|
||||
with ir.InsertionPoint(for_op.body):
|
||||
i = for_op.induction_variable
|
||||
inner_carrys = jax.tree.unflatten(carry_treedef, for_op.inner_iter_args)
|
||||
if unwrap:
|
||||
[inner_carrys] = inner_carrys
|
||||
new_carrys = f(i, inner_carrys)
|
||||
if unwrap:
|
||||
new_carrys = [new_carrys]
|
||||
new_flat_carrys, new_carry_treedef = jax.tree.flatten(new_carrys)
|
||||
if new_carry_treedef != carry_treedef:
|
||||
raise ValueError(new_carry_treedef, carry_treedef)
|
||||
scf.YieldOp(new_flat_carrys)
|
||||
final_flat_carrys = for_op.results
|
||||
return ForResult(
|
||||
for_op, jax.tree.unflatten(carry_treedef, final_flat_carrys)
|
||||
)
|
||||
|
||||
return wrapper
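# Usage sketch (illustrative): `fori` is used as a decorator; the decorated
# function receives the induction variable and the carry and returns the new
# carry. Here `init_carry` and `step` are assumed to be existing MLIR values.
#
#   @fori(c(128, ir.IndexType.get()), init_carry)
#   def body(i, carry):
#     return arith.addi(carry, step)
#   total = body.result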
|
||||
|
||||
|
||||
def get_warp_idx():
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
tidx = arith.index_cast(i32, gpu.thread_id(gpu.Dimension.x))
|
||||
warp_idx = arith.shrui(tidx, c(5, tidx.type))
|
||||
mask = c(0xFFFFFFFF, i32)
|
||||
return nvvm.shfl_sync(
|
||||
warp_idx.type, mask, warp_idx, c(0, i32), c(0x1F, i32), nvvm.ShflKind.idx
|
||||
)
|
||||
|
||||
|
||||
# True within `once()` contexts.
|
||||
_ONCE_REGION_ACTIVE = False
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def once():
|
||||
"""Runs the context only from a single thread from the first warp.
|
||||
|
||||
The block is assumed to have a size of 1 in both y and z dimensions.
|
||||
"""
|
||||
global _ONCE_REGION_ACTIVE
|
||||
|
||||
if _ONCE_REGION_ACTIVE:
|
||||
yield
|
||||
return
|
||||
|
||||
warp = get_warp_idx()
|
||||
first_warp = arith.cmpi(arith.CmpIPredicate.eq, warp, c(0, warp.type))
|
||||
elected = nvvm.elect_sync(ir.IntegerType.get_signless(1))
|
||||
should_run = arith.andi(first_warp, elected)
|
||||
if_op = scf.IfOp(should_run)
|
||||
_ONCE_REGION_ACTIVE = True
|
||||
try:
|
||||
with ir.InsertionPoint(if_op.then_block):
|
||||
yield
|
||||
scf.YieldOp([])
|
||||
finally:
|
||||
_ONCE_REGION_ACTIVE = False
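# Usage sketch (illustrative, assumes an active insertion point inside a GPU
# kernel; `flag_ref` and `i32` are hypothetical values defined elsewhere): the
# body is emitted under an scf.if entered by a single elected thread of warp 0.
#
#   with once():
#     memref.store(c(1, i32), flag_ref, [])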
|
||||
|
||||
|
||||
def clock():
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
return llvm.inline_asm(
|
||||
i32, [], "mov.u32 $0,%clock;", "=r", asm_dialect=0, has_side_effects=True
|
||||
)
|
||||
|
||||
|
||||
def globaltimer(kind: Literal["low", "high"] | None = None):
|
||||
if kind is None:
|
||||
i64 = ir.IntegerType.get_signless(64)
|
||||
return llvm.inline_asm(
|
||||
i64, [], "mov.u32 $0,%globaltimer;",
|
||||
"=l", asm_dialect=0, has_side_effects=True,
|
||||
)
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
return llvm.inline_asm(
|
||||
i32, [], f"mov.u32 $0,%globaltimer_{kind[:2]};",
|
||||
"=r", asm_dialect=0, has_side_effects=True,
|
||||
)
|
||||
|
||||
|
||||
def bytewidth(ty: ir.Type):
|
||||
if ir.IntegerType.isinstance(ty):
|
||||
return ir.IntegerType(ty).width // 8
|
||||
if ir.FloatType.isinstance(ty):
|
||||
return ir.FloatType(ty).width // 8
|
||||
raise NotImplementedError(ty)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class DynamicSlice:
|
||||
base: ir.Value | int
|
||||
length: int
|
||||
|
||||
|
||||
ds = DynamicSlice
|
||||
|
||||
|
||||
def memref_slice(ref: ir.Value, index) -> ir.Value:
|
||||
ref_ty = ir.MemRefType(ref.type)
|
||||
base_indices, slice_shape, is_squeezed = parse_indices(index, ref_ty.shape)
|
||||
|
||||
memref_strides, offset = ref_ty.get_strides_and_offset()
|
||||
new_offset = offset
|
||||
for idx, stride in zip(base_indices, memref_strides):
|
||||
if isinstance(idx, int):
|
||||
new_offset += idx * stride
|
||||
else:
|
||||
new_offset = ir.ShapedType.get_dynamic_stride_or_offset()
|
||||
break
|
||||
new_strides = [
|
||||
s for s, squeeze in zip(memref_strides, is_squeezed) if not squeeze
|
||||
]
|
||||
new_shape = [s for s, squeeze in zip(slice_shape, is_squeezed) if not squeeze]
|
||||
new_layout = ir.StridedLayoutAttr.get(new_offset, new_strides)
|
||||
|
||||
ref_slice = memref.subview(
|
||||
ref, base_indices, slice_shape, [1] * len(ref_ty.shape),
|
||||
result_type=ir.MemRefType.get(
|
||||
new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space
|
||||
),
|
||||
)
|
||||
return ref_slice
|
||||
|
||||
|
||||
def _is_contiguous_shape_slice(
|
||||
ref_ty: ir.MemRefType, dim_slice: slice | None = slice(None)
|
||||
):
|
||||
# If it's not a strided layout then we are definitely contiguous.
|
||||
if not ir.StridedLayoutAttr.isinstance(ref_ty.layout):
|
||||
return True
|
||||
|
||||
strides = ir.StridedLayoutAttr(ref_ty.layout).strides[dim_slice]
|
||||
shape = ref_ty.shape[dim_slice]
|
||||
|
||||
# Check that each dimension fits exactly into the immediately larger stride.
|
||||
ss = sorted(zip(strides, shape), key=lambda x: x[0], reverse=True)
|
||||
for (prev_stride, _), (stride, shape) in zip(ss, ss[1:]):
|
||||
if stride * shape != prev_stride:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def memref_fold(ref: ir.Value, dim, fold_rank) -> ir.Value:
|
||||
ref_ty = ir.MemRefType(ref.type)
|
||||
new_shape = list(ref_ty.shape)
|
||||
new_shape[dim : dim + fold_rank] = [np.prod(new_shape[dim : dim + fold_rank])]
|
||||
identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(ref_ty.rank))
|
||||
if ref_ty.layout == identity:
|
||||
new_layout = ir.AffineMapAttr.get(
|
||||
ir.AffineMap.get_identity(ref_ty.rank - fold_rank + 1)
|
||||
)
|
||||
elif _is_contiguous_shape_slice(ref_ty, slice(dim, dim + fold_rank)):
|
||||
new_strides, offset = ref_ty.get_strides_and_offset()
|
||||
new_strides[dim : dim + fold_rank] = [new_strides[dim + fold_rank - 1]]
|
||||
new_layout = ir.StridedLayoutAttr.get(offset, new_strides)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"strides={ref_ty.get_strides_and_offset()[0]}, {ref_ty.shape=},"
|
||||
f" {dim=}, {fold_rank=}"
|
||||
)
|
||||
|
||||
new_ty = ir.MemRefType.get(
|
||||
new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space
|
||||
)
|
||||
assoc = [[d] for d in range(dim)]
|
||||
assoc.append([dim + i for i in range(fold_rank)])
|
||||
assoc.extend([d] for d in range(dim + fold_rank, ref_ty.rank))
|
||||
assert len(assoc) == new_ty.rank
|
||||
return memref.collapse_shape(new_ty, ref, assoc)
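# Shape-level example (illustrative): folding 2 dims starting at dim 1 of an
# identity-layout memref<4x8x16x2xf32> yields memref<4x128x2xf32>; for strided
# layouts the folded dims must be contiguous among themselves.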
|
||||
|
||||
|
||||
def memref_unfold(ref: ir.Value, dim, factors) -> ir.Value:
|
||||
"""Unfolds dim into two dimensions, the size of leading one given be major_factor."""
|
||||
ref_ty = ir.MemRefType(ref.type)
|
||||
new_shape = list(ref_ty.shape)
|
||||
if sum(f is None for f in factors) > 1:
|
||||
raise ValueError("Can only infer one dimension")
|
||||
known_factor_prod = np.prod([f for f in factors if f is not None])
|
||||
if new_shape[dim] % known_factor_prod:
|
||||
raise ValueError("Non-divisible unfold:", new_shape[dim], factors)
|
||||
factors = tuple(
|
||||
new_shape[dim] // known_factor_prod if f is None else f for f in factors
|
||||
)
|
||||
new_shape[dim : dim + 1] = factors
|
||||
identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(ref_ty.rank))
|
||||
if ref_ty.layout == identity:
|
||||
new_layout = ir.AffineMapAttr.get(
|
||||
ir.AffineMap.get_identity(ref_ty.rank + len(factors) - 1)
|
||||
)
|
||||
else:
|
||||
new_strides, offset = ref_ty.get_strides_and_offset()
|
||||
prev_stride = new_strides[dim]
|
||||
inserted_strides = []
|
||||
for f in reversed(factors):
|
||||
inserted_strides.append(prev_stride)
|
||||
prev_stride *= f
|
||||
new_strides[dim : dim + 1] = reversed(inserted_strides)
|
||||
new_layout = ir.StridedLayoutAttr.get(offset, new_strides)
|
||||
new_ty = ir.MemRefType.get(
|
||||
new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space
|
||||
)
|
||||
if dim == ref_ty.rank:
|
||||
assoc = [[d] for d in range(ref_ty.rank)]
|
||||
assoc[-1].extend(range(ref_ty.rank, ref_ty.rank + len(factors) - 1))
|
||||
else:
|
||||
assoc = [[d] for d in range(dim)]
|
||||
assoc.append(list(range(dim, dim + len(factors))))
|
||||
assoc.extend([d + len(factors) - 1] for d in range(dim + 1, ref_ty.rank))
|
||||
assert len(assoc) == ref_ty.rank
|
||||
return memref.expand_shape(new_ty, ref, assoc)
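# Shape-level example (illustrative): unfolding dim 1 of memref<4x128x16xf32>
# with factors=(8, None) infers the None factor as 16 and produces
# memref<4x8x16x16xf32>, with strides adjusted when the layout is strided.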
|
||||
|
||||
|
||||
def memref_unsqueeze(ref: ir.Value, dim) -> ir.Value:
|
||||
"""Inserts a singleton dimension."""
|
||||
ref_ty = ir.MemRefType(ref.type)
|
||||
if dim == ref_ty.rank:
|
||||
new_shape = list(ref_ty.shape)
|
||||
new_shape.append(1)
|
||||
identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(ref_ty.rank))
|
||||
if ref_ty.layout == identity:
|
||||
new_layout = ir.AffineMapAttr.get(
|
||||
ir.AffineMap.get_identity(ref_ty.rank + 1)
|
||||
)
|
||||
else:
|
||||
new_strides, offset = ref_ty.get_strides_and_offset()
|
||||
new_strides.append(1)
|
||||
new_layout = ir.StridedLayoutAttr.get(offset, new_strides)
|
||||
new_ty = ir.MemRefType.get(
|
||||
new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space
|
||||
)
|
||||
assoc = [[d] for d in range(ref_ty.rank)]
|
||||
assoc[-1].append(ref_ty.rank)
|
||||
return memref.expand_shape(new_ty, ref, assoc)
|
||||
else:
|
||||
return memref_unfold(ref, dim, (1, None))
|
||||
|
||||
|
||||
def memref_transpose(ref: ir.Value, permutation: Sequence[int]) -> ir.Value:
|
||||
ref_ty = ir.MemRefType(ref.type)
|
||||
strides, offset = ref_ty.get_strides_and_offset()
|
||||
new_strides = [strides[p] for p in permutation]
|
||||
new_shape = [ref_ty.shape[p] for p in permutation]
|
||||
new_layout = ir.StridedLayoutAttr.get(offset, new_strides)
|
||||
new_ty = ir.MemRefType.get(
|
||||
new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space
|
||||
)
|
||||
return memref.transpose(
|
||||
new_ty, ref, ir.AffineMap.get_permutation(permutation)
|
||||
)
|
||||
|
||||
|
||||
def parse_indices(
|
||||
index, shape: tuple[int, ...]
|
||||
) -> tuple[list[ir.Value | int], list[int], list[bool]]:
|
||||
if not isinstance(index, tuple):
|
||||
index = (index,)
|
||||
if trailing_dims := len(shape) - len(index):
|
||||
index += (slice(None),) * trailing_dims
|
||||
base_indices = []
|
||||
slice_shape = []
|
||||
is_squeezed = []
|
||||
for idx, bound in zip(index, shape):
|
||||
if isinstance(idx, (ir.Operation, ir.OpView)):
|
||||
idx = idx.result
|
||||
if isinstance(idx, int):
|
||||
base_indices.append(idx)
|
||||
slice_shape.append(1)
|
||||
is_squeezed.append(True)
|
||||
elif isinstance(idx, slice):
|
||||
if idx.step is not None:
|
||||
raise NotImplementedError("Strided slices not implemented")
|
||||
base_indices.append(idx.start or 0)
|
||||
slice_shape.append((idx.stop or bound) - (idx.start or 0))
|
||||
is_squeezed.append(False)
|
||||
elif isinstance(idx, DynamicSlice):
|
||||
base_indices.append(idx.base)
|
||||
slice_shape.append(idx.length)
|
||||
is_squeezed.append(False)
|
||||
elif isinstance(idx, ir.Value):
|
||||
if not ir.IndexType.isinstance(idx.type):
|
||||
raise ValueError("Expected an index-typed index")
|
||||
base_indices.append(idx)
|
||||
slice_shape.append(1)
|
||||
is_squeezed.append(True)
|
||||
else:
|
||||
raise NotImplementedError(type(idx))
|
||||
assert len(base_indices) == len(slice_shape) == len(is_squeezed) == len(shape)
|
||||
return base_indices, slice_shape, is_squeezed
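# Worked example (illustrative, using only plain Python indices): integer and
# ir.Value indices are squeezed away, slices and DynamicSlices keep their
# dimension.
#
#   >>> parse_indices((0, slice(None), ds(0, 8)), (4, 16, 32))
#   ([0, 0, 0], [1, 16, 8], [True, False, False])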
|
||||
|
||||
|
||||
def commit_shared():
|
||||
gpu.barrier()
|
||||
nvvm.fence_proxy(
|
||||
nvvm.ProxyKind.async_shared, space=nvvm.SharedSpace.shared_cta
|
||||
)
|
||||
|
||||
|
||||
class BarrierArray:
|
||||
|
||||
def __init__(self, num_barriers):
|
||||
barrier_group_ty = ir.Type.parse(
|
||||
"!nvgpu.mbarrier.group<memorySpace=#gpu.address_space<workgroup>,"
|
||||
f" num_barriers={num_barriers}>"
|
||||
)
|
||||
|
||||
self.value = nvgpu.mbarrier_create(barrier_group_ty)
|
||||
index = ir.IndexType.get()
|
||||
if num_barriers > 32:
|
||||
raise NotImplementedError("Only up to 32 barriers per group supported")
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
self.phases = memref.alloca(ir.MemRefType.get((), i32), [], [])
|
||||
memref.store(c(0, i32), self.phases, [])
|
||||
with once():
|
||||
for i in range(num_barriers):
|
||||
nvgpu.mbarrier_init(self.value, c(1, index), c(i, index))
|
||||
|
||||
def __getitem__(self, offset: ir.Value | int):
|
||||
if isinstance(offset, int):
|
||||
offset = c(offset, ir.IndexType.get())
|
||||
return Barrier(self, offset)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Barrier:
|
||||
barrier_array: BarrierArray
|
||||
offset: ir.Value
|
||||
|
||||
def wait_parity(self, parity):
|
||||
index = ir.IndexType.get()
|
||||
nvgpu.mbarrier_try_wait_parity(
|
||||
self.barrier_array.value, parity, c(10000000, index), self.offset,
|
||||
)
|
||||
|
||||
def wait(self):
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
parities = memref.load(self.barrier_array.phases, [])
|
||||
offset_i32 = arith.index_castui(i32, self.offset)
|
||||
bitmask = arith.shli(c(1, i32), offset_i32)
|
||||
parity = arith.cmpi(
|
||||
arith.CmpIPredicate.ne, arith.andi(parities, bitmask), c(0, i32)
|
||||
)
|
||||
new_parities = arith.xori(parities, bitmask)
|
||||
memref.store(new_parities, self.barrier_array.phases, [])
|
||||
self.wait_parity(parity)
|
||||
|
||||
|
||||
class Partition:
|
||||
source_bounds: tuple[int, ...]
|
||||
target_bounds: tuple[int, ...]
|
||||
partition: tuple[int | None, ...]
|
||||
base_offset: tuple[ir.Value, ...] | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
elements: tuple[int, ...],
|
||||
*,
|
||||
partition: tuple[int | None, ...],
|
||||
base_offset: tuple[ir.Value, ...] | None = None,
|
||||
num_chunks: tuple[int, ...] | None = None,
|
||||
chunk_size: tuple[int, ...] | None = None,
|
||||
):
|
||||
self.target_bounds = elements
|
||||
self.partition = partition
|
||||
self.base_offset = base_offset
|
||||
if len(self.target_bounds) != len(self.partition):
|
||||
raise ValueError
|
||||
if (num_chunks is None) == (chunk_size is None):
|
||||
raise ValueError(
|
||||
"Exactly one of num_chunks and chunk_size must be specified"
|
||||
)
|
||||
if num_chunks is not None:
|
||||
self.source_bounds = num_chunks
|
||||
else:
|
||||
if len(chunk_size) != len(self.target_bounds):
|
||||
raise ValueError
|
||||
source_bounds = []
|
||||
for els, chunk in zip(elements, chunk_size):
|
||||
if els % chunk:
|
||||
raise ValueError("Non-divisible partition", elements, chunk_size)
|
||||
source_bounds.append(els // chunk)
|
||||
self.source_bounds = tuple(source_bounds)
|
||||
|
||||
seen_dims = set()
|
||||
for p in self.partition:
|
||||
if p is None:
|
||||
continue
|
||||
if not (0 <= p < len(self.source_bounds)):
|
||||
raise ValueError
|
||||
if p in seen_dims:
|
||||
raise ValueError
|
||||
seen_dims.add(p)
|
||||
for tb, p in zip(self.target_bounds, self.partition):
|
||||
if p is not None and tb % self.source_bounds[p]:
|
||||
raise ValueError("Non-divisible partitioning")
|
||||
|
||||
@property
|
||||
def num_chunks(self) -> tuple[int, ...]:
|
||||
return self.source_bounds
|
||||
|
||||
@property
|
||||
def target_block_shape(self):
|
||||
return tuple(tb if p is None else tb // self.source_bounds[p]
|
||||
for tb, p in zip(self.target_bounds, self.partition))
|
||||
|
||||
def get_base(self, *source_coords: ir.Value | int) -> list[ir.Value]:
|
||||
coords = []
|
||||
index = ir.IndexType.get()
|
||||
for i, (tbs, p) in enumerate(zip(self.target_block_shape, self.partition)):
|
||||
if p is None:
|
||||
dim_base = c(0, index)
|
||||
else:
|
||||
dim_base = arith.muli(c(tbs, index), source_coords[p])
|
||||
if self.base_offset is not None:
|
||||
dim_base = arith.addi(self.base_offset[i], dim_base)
|
||||
coords.append(dim_base)
|
||||
return coords
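# Shape-level example (illustrative): Partition(elements=(128, 64),
# partition=(0, None), num_chunks=(4,)) has target_block_shape == (32, 64),
# and get_base for chunk index 2 yields the base coordinates (64, 0) as
# index-typed MLIR values.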
|
||||
|
||||
|
||||
class Partition1D:
|
||||
partition: Partition
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
elements: int,
|
||||
*,
|
||||
base_offset: ir.Value | None = None,
|
||||
num_chunks: int | None = None,
|
||||
chunk_size: int | None = None,
|
||||
):
|
||||
self.base_offset = base_offset
|
||||
if (num_chunks is None) == (chunk_size is None):
|
||||
raise ValueError(
|
||||
"Exactly one of num_chunks and chunk_size must be specified"
|
||||
)
|
||||
common_kwargs = dict(elements=(elements,), partition=(0,))
|
||||
if base_offset is not None:
|
||||
common_kwargs["base_offset"] = (base_offset,)
|
||||
if num_chunks is not None:
|
||||
self.partition = Partition(num_chunks=(num_chunks,), **common_kwargs)
|
||||
else:
|
||||
self.partition = Partition(chunk_size=(chunk_size,), **common_kwargs)
|
||||
|
||||
@property
|
||||
def num_chunks(self) -> int:
|
||||
return self.partition.source_bounds[0]
|
||||
|
||||
def get_base(self, source_coords: ir.Value) -> ir.Value:
|
||||
return self.partition.get_base(source_coords)[0]
|
||||
|
||||
def refine(
|
||||
self,
|
||||
*,
|
||||
chunk: ir.Value | None = None,
|
||||
num_chunks: int | None = None,
|
||||
chunk_size: int | None = None,
|
||||
):
|
||||
return Partition1D(
|
||||
self.partition.target_block_shape[0],
|
||||
num_chunks=num_chunks,
|
||||
chunk_size=chunk_size,
|
||||
base_offset=self.get_base(chunk) if chunk is not None else None,
|
||||
)
|
||||
|
||||
|
||||
def tile_shape(shape, tiling):
|
||||
if len(tiling) > len(shape):
|
||||
raise ValueError
|
||||
if not tiling:
|
||||
return shape
|
||||
tiling_rank = len(tiling)
|
||||
for s, t in zip(shape[-tiling_rank:], tiling):
|
||||
if s % t:
|
||||
raise ValueError("Non-divisible tiling:", shape, tiling)
|
||||
return (
|
||||
*shape[:-tiling_rank],
|
||||
*(s // t for s, t in zip(shape[-tiling_rank:], tiling)),
|
||||
*tiling,
|
||||
)
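# Worked example (illustrative): the tiled dims are replaced by their tile
# counts and the tile shape is appended at the end.
#
#   >>> tile_shape((4, 128, 128), (64, 32))
#   (4, 2, 4, 64, 32)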
|
404
jax/experimental/mosaic/gpu/wgmma.py
Normal file
@ -0,0 +1,404 @@
|
||||
# Copyright 2024 The JAX Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import dataclasses
|
||||
import enum
|
||||
import itertools
|
||||
|
||||
import jax
|
||||
from jaxlib.mlir import ir
|
||||
from jaxlib.mlir.dialects import arith
|
||||
from jaxlib.mlir.dialects import builtin
|
||||
from jaxlib.mlir.dialects import llvm
|
||||
from jaxlib.mlir.dialects import nvvm
|
||||
from jaxlib.mlir.dialects import vector
|
||||
import numpy as np
|
||||
|
||||
from . import dsl as mgpu
|
||||
|
||||
# mypy: ignore-errors
|
||||
|
||||
c = mgpu.c
|
||||
bytewidth = mgpu.bytewidth
|
||||
|
||||
|
||||
@jax.tree_util.register_pytree_node_class
|
||||
@dataclasses.dataclass
|
||||
class WGMMAAccumulator:
|
||||
"""A FragmentedArray that has is synchronized with the async proxy.
|
||||
|
||||
This implies that it requires no additional synchronization when passed in
|
||||
as a WGMMA accumulator. In particular, when created from a
|
||||
FragmentedArray, the necessary synchronization is inserted at construction.
|
||||
"""
|
||||
value: mgpu.FragmentedArray
|
||||
|
||||
def __init__(self, *, _value: mgpu.FragmentedArray, _sync: bool = True):
|
||||
if _value.layout != mgpu.WGMMA_LAYOUT:
|
||||
raise ValueError("Only WGMMA layouts supported in WGMMAAccumulator")
|
||||
self.value = _value
|
||||
if _sync:
|
||||
nvvm.wgmma_fence_aligned()
|
||||
|
||||
@classmethod
|
||||
def zero(cls, m, n):
|
||||
if m % 64 or n % 8:
|
||||
raise ValueError
|
||||
f32 = ir.F32Type.get()
|
||||
zero = arith.constant(f32, ir.FloatAttr.get(f32, 0.0))
|
||||
return cls(
|
||||
_value=mgpu.FragmentedArray.splat(zero, (m, n), mgpu.WGMMA_LAYOUT)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_registers(cls, registers):
|
||||
return cls(_value=registers)
|
||||
|
||||
def tree_flatten(self):
|
||||
return (self.value,), ()
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, aux, value):
|
||||
del aux
|
||||
return cls(_value=value[0], _sync=False)
|
||||
|
||||
|
||||
def wgmma_encode(x: int):
|
||||
result = (x & 0x3FFFF) >> 4
|
||||
if result << 4 != x:
|
||||
raise ValueError("Cannot encode value in a WGMMA descriptor")
|
||||
return result
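# Worked example (illustrative): descriptor fields hold byte offsets that must
# be multiples of 16 and smaller than 2**18, e.g. wgmma_encode(256) == 16,
# while wgmma_encode(8) raises ValueError because the low 4 bits would be lost.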
|
||||
|
||||
|
||||
def get_memref_base(memref_arg, memory_space=None):
|
||||
i64 = ir.IntegerType.get_signless(64)
|
||||
memref_ty = ir.MemRefType(memref_arg.type)
|
||||
if len(memref_ty.shape) == 0:
|
||||
raise NotImplementedError
|
||||
elem_bytewidth = bytewidth(memref_ty.element_type)
|
||||
rank = len(memref_ty.shape)
|
||||
# TODO: Read out memory space from memref
|
||||
space = "" if memory_space is None else "<" + str(memory_space) + ">"
|
||||
ptr_ty = ir.Type.parse("!llvm.ptr" + space)
|
||||
desc_ty = ir.Type.parse(
|
||||
f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64, array<{rank} x i64>,"
|
||||
f" array<{rank} x i64>)>"
|
||||
)
|
||||
desc = builtin.UnrealizedConversionCastOp([desc_ty], [memref_arg])
|
||||
aligned_ptr = llvm.extractvalue(ptr_ty, desc, [1])
|
||||
offset_elems = llvm.extractvalue(i64, desc, [2])
|
||||
offset_bytes = llvm.mul(offset_elems, c(elem_bytewidth, i64))
|
||||
return llvm.inttoptr(
|
||||
ptr_ty, llvm.add(llvm.ptrtoint(i64, aligned_ptr), offset_bytes)
|
||||
)
|
||||
|
||||
|
||||
def create_descriptor(
|
||||
memref_arg,
|
||||
leading_byte_offset: int,
|
||||
stride_byte_offset: int,
|
||||
swizzle: int | None,
|
||||
memory_space: int | None = None,
|
||||
nvgpu_type=None,
|
||||
):
|
||||
i64 = ir.IntegerType.get_signless(64)
|
||||
ptr_val = llvm.ptrtoint(i64, get_memref_base(memref_arg, memory_space))
|
||||
if swizzle is None:
|
||||
swizzle_encoding = 0
|
||||
elif swizzle == 128:
|
||||
swizzle_encoding = 1
|
||||
else:
|
||||
raise NotImplementedError(swizzle)
|
||||
encoded_base_addr = llvm.LShrOp(
|
||||
llvm.AndOp(ptr_val, c(0x3FFFF, i64)), c(4, i64)
|
||||
)
|
||||
desc_const = (
|
||||
(wgmma_encode(leading_byte_offset) << 16)
|
||||
| (wgmma_encode(stride_byte_offset) << 32)
|
||||
|
|
||||
# We ignore the offset
|
||||
(swizzle_encoding << 62)
|
||||
)
|
||||
desc = llvm.OrOp(encoded_base_addr, c(desc_const, i64))
|
||||
if nvgpu_type is not None:
|
||||
desc = builtin.UnrealizedConversionCastOp([nvgpu_type], [desc])
|
||||
return desc.result
|
||||
|
||||
|
||||
def wgmma_m64k128B(
|
||||
acc: np.ndarray, # of register Values
|
||||
a,
|
||||
b_descriptor: ir.Value,
|
||||
a_transpose: bool | None,
|
||||
b_transpose: bool,
|
||||
a_k_stride: int | None,
|
||||
b_k_stride: int,
|
||||
n: int,
|
||||
element_type: ir.Type,
|
||||
):
|
||||
f32 = ir.F32Type.get()
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
i64 = ir.IntegerType.get_signless(64)
|
||||
index = ir.IndexType.get()
|
||||
if b_k_stride % 16:
|
||||
raise ValueError
|
||||
if n % (128 // bytewidth(element_type)):
|
||||
raise ValueError
|
||||
# Only 16-bit types support transposes
|
||||
supports_transpose = bytewidth(element_type) == 2
|
||||
if not supports_transpose and (a_transpose or b_transpose):
|
||||
raise ValueError("Only f16 WGMMA supports transposes")
|
||||
if a_in_regs := isinstance(a, mgpu.FragmentedArray):
|
||||
if a.mlir_dtype != ir.F16Type.get() and a.mlir_dtype != ir.BF16Type.get():
|
||||
raise ValueError(f"Unsupported A register array dtype: {a.mlir_dtype}")
|
||||
if a.layout != mgpu.WGMMA_LAYOUT or a.shape != (64, 64):
|
||||
raise ValueError("Unsupported A register array layout")
|
||||
if a_k_stride is not None or a_transpose is not None:
|
||||
raise ValueError("Unsupported WGMMA features with A in registers")
|
||||
else:
|
||||
if a_k_stride is None or a_k_stride % 16:
|
||||
raise ValueError
|
||||
if a_transpose is None:
|
||||
raise ValueError
|
||||
|
||||
num_acc_regs = n // 2
|
||||
num_imm_regs = 4 if supports_transpose else 2
|
||||
|
||||
if a_in_regs:
|
||||
a_reg_constraints = ["r"] * 4 # 4x f16x2 registers
|
||||
num_imm_regs -= 1 # transpose not supported for a in registers
|
||||
else:
|
||||
a_reg_constraints = ["l"] # descriptor
|
||||
# Reference for i/o aliasing: https://gcc.gnu.org/onlinedocs/gcc/Extended-Asm.html
|
||||
# Seems like it's not actually documented in LLVM IR docs.
|
||||
reg_constraints_list = (
|
||||
["=f"] * num_acc_regs # accumulator registers
|
||||
+ [str(i) for i in range(num_acc_regs)] # we alias outputs as inputs, too.
|
||||
+ a_reg_constraints # a descriptor / registers
|
||||
+ ["l"] * 1 # b descriptor
|
||||
+ ["n"] * (1 + num_imm_regs) # literal constants
|
||||
)
|
||||
reg_constraints = ",".join(reg_constraints_list)
|
||||
|
||||
reg_count = itertools.count()
|
||||
|
||||
def take_regs(n):
|
||||
return (f"${i}" for i in itertools.islice(reg_count, n))
|
||||
|
||||
acc_reg_vector = "{" + ",".join(take_regs(num_acc_regs)) + "}"
|
||||
for _ in take_regs(num_acc_regs): # Ignore next entries: aliasing.
|
||||
pass
|
||||
if a_in_regs:
|
||||
a_regs = "{" + ",".join(take_regs(len(a_reg_constraints))) + "}"
|
||||
else:
|
||||
a_regs, = take_regs(1)
|
||||
b_desc_reg, use_out_reg = take_regs(2)
|
||||
imm_regs = ", ".join(take_regs(num_imm_regs)) # Immediate regs (scale, ...).
|
||||
assert next(reg_count) == len(reg_constraints_list)
|
||||
el_ty = element_type
|
||||
k_instr = 32 // bytewidth(element_type)
|
||||
wgmma_instr = (
|
||||
f"wgmma.mma_async.sync.aligned.m64n{n}k{k_instr}.f32.{el_ty}.{el_ty} "
|
||||
f"{acc_reg_vector}, {a_regs}, {b_desc_reg}, p, {imm_regs};"
|
||||
)
|
||||
ptx = f"{{ .reg .pred p; setp.ne.b32 p, {use_out_reg}, 0; {wgmma_instr} }}\n"
|
||||
|
||||
def lc(x):
|
||||
return llvm.mlir_constant(i32, ir.IntegerAttr.get(i32, x))
|
||||
|
||||
def as_i32_reg(v):
|
||||
return llvm.extractelement(
|
||||
vector.bitcast(ir.VectorType.get((1,), i32), v), lc(0)
|
||||
)
|
||||
|
||||
use_out = scale_a = scale_b = lc(1)
|
||||
imms = [use_out, scale_a, scale_b]
|
||||
if supports_transpose and a_transpose is not None:
|
||||
imms += [lc(int(a_transpose)), lc(int(b_transpose))]
|
||||
elif supports_transpose:
|
||||
imms += [lc(int(b_transpose))]
|
||||
if acc.ndim != 4 or acc.shape[0] != 1 or acc.shape[2:] != (2, 1):
|
||||
raise ValueError(acc.shape)
|
||||
acc_regs = [ # pylint: disable=g-complex-comprehension
|
||||
vector.extractelement(reg, position=c(pos, index))
|
||||
for reg in acc.flat
|
||||
for pos in range(2)
|
||||
]
|
||||
acc_struct_type = ir.Type.parse(
|
||||
f"!llvm.struct<({','.join('f32' for _ in acc_regs)})>"
|
||||
)
|
||||
for i in range(4):
|
||||
# Slice out the relevant part of A or advance the A descriptor.
|
||||
if a_in_regs:
|
||||
a_slice = a[:, (i * 16) : ((i + 1) * 16)]
|
||||
a_args = [as_i32_reg(v) for v in a_slice.registers.flat]
|
||||
else:
|
||||
if i > 0:
|
||||
a = llvm.add(
|
||||
a,
|
||||
llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, a_k_stride >> 4)),
|
||||
)
|
||||
a_args = [a]
|
||||
# Advance the B descriptor.
|
||||
if i > 0:
|
||||
b_descriptor = llvm.add(
|
||||
b_descriptor,
|
||||
llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, b_k_stride >> 4)),
|
||||
)
|
||||
assert len(a_args) == len(a_reg_constraints)
|
||||
acc_struct = llvm.inline_asm(
|
||||
acc_struct_type,
|
||||
[*acc_regs, *a_args, b_descriptor, *imms],
|
||||
ptx,
|
||||
reg_constraints,
|
||||
asm_dialect=0,
|
||||
has_side_effects=True,
|
||||
)
|
||||
acc_regs = [
|
||||
llvm.extractvalue(f32, acc_struct, [i]) for i in range(len(acc_regs))
|
||||
]
|
||||
acc_vec_regs = []
|
||||
for first, second in zip(acc_regs[::2], acc_regs[1::2]):
|
||||
vec = llvm.mlir_undef(ir.VectorType.get((2,), f32))
|
||||
vec = llvm.insertelement(vec, first, position=lc(0))
|
||||
vec = llvm.insertelement(vec, second, position=lc(1))
|
||||
acc_vec_regs.append(vec)
|
||||
return np.asarray(acc_vec_regs, dtype=object).reshape(acc.shape)
|
||||
|
||||
|
||||
class WGMMALayout(enum.Enum):
|
||||
ROW_MAJOR = enum.auto()
|
||||
COL_MAJOR = enum.auto()
|
||||
|
||||
|
||||
# TODO(apaszke): Remove WGMMALayout. Make input shapes logical and infer
|
||||
# transpositions from memref strides.
|
||||
def wgmma(
|
||||
acc: WGMMAAccumulator,
|
||||
a,
|
||||
b,
|
||||
*,
|
||||
# Order only applies within each tile!
|
||||
a_order: WGMMALayout | None = None,
|
||||
b_order: WGMMALayout = WGMMALayout.ROW_MAJOR,
|
||||
):
|
||||
if a_in_regs := isinstance(a, mgpu.FragmentedArray):
|
||||
a_element_type = a.mlir_dtype
|
||||
a_shape = a.shape
|
||||
else:
|
||||
a_ty = ir.MemRefType(a.type)
|
||||
a_element_type = a_ty.element_type
|
||||
a_shape = a_ty.shape
|
||||
b_ty = ir.MemRefType(b.type)
|
||||
supported_types = {ir.F16Type.get(), ir.BF16Type.get(), ir.F32Type.get()}
|
||||
if a_element_type not in supported_types:
|
||||
raise ValueError(a_element_type)
|
||||
if b_ty.element_type not in supported_types:
|
||||
raise ValueError(b_ty.element_type)
|
||||
if (element_type := a_element_type) != b_ty.element_type:
|
||||
raise ValueError
|
||||
element_bytewidth = bytewidth(element_type)
|
||||
kn_tile = 128 // element_bytewidth
|
||||
|
||||
groups_k, groups_n = b_ty.shape[:2]
|
||||
if b_ty.shape[2:] != [kn_tile, kn_tile]:
|
||||
raise ValueError(b_ty.shape)
|
||||
|
||||
if a_in_regs:
|
||||
if a_element_type != ir.F16Type.get() and a_element_type != ir.BF16Type.get():
|
||||
raise ValueError(a_element_type)
|
||||
if a_shape[0] % 64 or a_shape[1] % kn_tile:
|
||||
raise ValueError(a_shape)
|
||||
if a_shape[1] // kn_tile != groups_k:
|
||||
raise ValueError(a_shape[1] // kn_tile, groups_k)
|
||||
groups_m = a_shape[0] // 64
|
||||
if a_order is not None:
|
||||
raise ValueError(
|
||||
"a_order can only be specified when A is in shared memory"
|
||||
)
|
||||
else:
|
||||
groups_m = a_shape[0]
|
||||
if a_shape[1] != groups_k:
|
||||
raise ValueError(a_shape[1], groups_k)
|
||||
if a_shape[2:] != [64, kn_tile]:
|
||||
raise ValueError(a_shape)
|
||||
if a_order is None:
|
||||
a_order = WGMMALayout.ROW_MAJOR
|
||||
|
||||
row_major = WGMMALayout.ROW_MAJOR
|
||||
col_major = WGMMALayout.COL_MAJOR
|
||||
a_desc_fields = dict(
|
||||
leading_byte_offset=((1 if a_order == row_major else 512) << 4),
|
||||
stride_byte_offset=(64 << 4),
|
||||
swizzle=128,
|
||||
memory_space=3,
|
||||
)
|
||||
b_desc_fields = dict(
|
||||
leading_byte_offset=((512 if b_order == row_major else 1) << 4),
|
||||
stride_byte_offset=(64 << 4),
|
||||
swizzle=128,
|
||||
memory_space=3,
|
||||
)
|
||||
wgmma_params = dict(
|
||||
a_transpose=a_order == col_major,
|
||||
b_transpose=b_order == row_major,
|
||||
a_k_stride=(2 if a_order == row_major else 128) * 16,
|
||||
b_k_stride=(128 if b_order == row_major else 2) * 16,
|
||||
n=(groups_n * kn_tile),
|
||||
element_type=ir.FloatTF32Type.get()
|
||||
if ir.F32Type.isinstance(element_type)
|
||||
else element_type,
|
||||
)
|
||||
if a_in_regs:
|
||||
wgmma_params["a_k_stride"] = wgmma_params["a_transpose"] = None
|
||||
|
||||
if a_in_regs:
|
||||
nvvm.wgmma_fence_aligned() # Make sure the registers are ready.
|
||||
a_m_byte_stride = a_k_byte_stride = a_desc_base = None # Silence pytype.
|
||||
else:
|
||||
a_desc_base = create_descriptor(a, **a_desc_fields)
|
||||
a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
|
||||
a_byte_strides = [s * element_bytewidth for s in a_strides]
|
||||
a_m_byte_stride, a_k_byte_stride = a_byte_strides[:2]
|
||||
if a_byte_strides[2:] != [128, element_bytewidth]:
|
||||
raise ValueError(a_byte_strides)
|
||||
b_desc_base = create_descriptor(b, **b_desc_fields)
|
||||
b_strides, _ = b_ty.get_strides_and_offset()
|
||||
b_byte_strides = [s * element_bytewidth for s in b_strides]
|
||||
b_k_byte_stride = b_byte_strides[0]
|
||||
if b_byte_strides[1:] != [128 * kn_tile, 128, element_bytewidth]:
|
||||
raise ValueError(b_byte_strides)
|
||||
|
||||
i64 = ir.IntegerType.get_signless(64)
|
||||
new_acc_regs = acc.value.registers.copy()
|
||||
for mi in range(groups_m):
|
||||
for ki in range(groups_k):
|
||||
if a_in_regs:
|
||||
a_mk = a[mi * 64 : (mi + 1) * 64, ki * kn_tile : (ki + 1) * kn_tile]
|
||||
else:
|
||||
a_mk = llvm.add(
|
||||
a_desc_base,
|
||||
c(wgmma_encode(mi * a_m_byte_stride + ki * a_k_byte_stride), i64),
|
||||
)
|
||||
b_k = llvm.add(b_desc_base, c(wgmma_encode(ki * b_k_byte_stride), i64))
|
||||
new_acc_regs[mi : mi + 1] = wgmma_m64k128B(
|
||||
new_acc_regs[mi : mi + 1], a_mk, b_k, **wgmma_params
|
||||
)
|
||||
return WGMMAAccumulator(
|
||||
_value=mgpu.FragmentedArray(
|
||||
_registers=new_acc_regs, _layout=mgpu.WGMMA_LAYOUT
|
||||
),
|
||||
_sync=False,
|
||||
)
|
@ -17,6 +17,7 @@
|
||||
load("//jaxlib:symlink_files.bzl", "symlink_files")
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"if_building_mosaic_gpu",
|
||||
"if_windows",
|
||||
"py_library_providing_imports_info",
|
||||
"pybind_extension",
|
||||
@ -70,7 +71,7 @@ py_library_providing_imports_info(
|
||||
"//jaxlib/mlir:vector_dialect",
|
||||
"//jaxlib/mosaic",
|
||||
"//jaxlib/triton",
|
||||
],
|
||||
] + if_building_mosaic_gpu(["//jaxlib/mosaic/gpu:mosaic_gpu"]),
|
||||
)
|
||||
|
||||
symlink_files(
|
||||
|
@ -161,6 +161,12 @@ def if_building_jaxlib(if_building, if_not_building = []):
|
||||
"//conditions:default": if_not_building,
|
||||
})
|
||||
|
||||
def if_building_mosaic_gpu(if_building, if_not_building = []):
|
||||
return select({
|
||||
"//jax:enable_mosaic_gpu": if_building,
|
||||
"//conditions:default": if_not_building,
|
||||
})
|
||||
|
||||
# buildifier: disable=function-docstring
|
||||
def jax_test(
|
||||
name,
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("//jaxlib:symlink_files.bzl", "symlink_inputs")
|
||||
load("//jaxlib:symlink_files.bzl", "symlink_files", "symlink_inputs")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
@ -217,3 +217,94 @@ symlink_inputs(
|
||||
"//jaxlib/mlir/_mlir_libs:_stablehlo",
|
||||
],
|
||||
)
|
||||
|
||||
symlink_inputs(
|
||||
name = "execution_engine",
|
||||
rule = py_library,
|
||||
symlinked_inputs = {"srcs": {
|
||||
".": [
|
||||
"@llvm-project//mlir/python:ExecutionEnginePyFiles",
|
||||
],
|
||||
}},
|
||||
deps = [
|
||||
":mlir",
|
||||
"//jaxlib/mlir/_mlir_libs:_mlirExecutionEngine",
|
||||
],
|
||||
)
|
||||
|
||||
symlink_inputs(
|
||||
name = "nvgpu_dialect",
|
||||
rule = py_library,
|
||||
symlinked_inputs = {"srcs": {"dialects": [
|
||||
"@llvm-project//mlir/python:NVGPUOpsPyFiles",
|
||||
]}},
|
||||
deps = [
|
||||
":core",
|
||||
":ir",
|
||||
":mlir",
|
||||
],
|
||||
)
|
||||
|
||||
symlink_inputs(
|
||||
name = "nvvm_dialect",
|
||||
rule = py_library,
|
||||
symlinked_inputs = {"srcs": {"dialects": [
|
||||
"@llvm-project//mlir/python:NVVMOpsPyFiles",
|
||||
]}},
|
||||
deps = [
|
||||
":core",
|
||||
":ir",
|
||||
":mlir",
|
||||
],
|
||||
)
|
||||
|
||||
symlink_files(
|
||||
name = "gpu_files",
|
||||
srcs = ["@llvm-project//mlir/python:GPUOpsPyFiles"],
|
||||
dst = "dialects",
|
||||
flatten = True,
|
||||
)
|
||||
|
||||
symlink_files(
|
||||
name = "gpu_package_files",
|
||||
srcs = ["@llvm-project//mlir/python:GPUOpsPackagePyFiles"],
|
||||
dst = "dialects/gpu",
|
||||
flatten = True,
|
||||
)
|
||||
|
||||
symlink_files(
|
||||
name = "gpu_package_passes_files",
|
||||
srcs = ["@llvm-project//mlir/python:GPUOpsPackagePassesPyFiles"],
|
||||
dst = "dialects/gpu/passes",
|
||||
flatten = True,
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "gpu_dialect",
|
||||
srcs = [
|
||||
":gpu_files",
|
||||
":gpu_package_files",
|
||||
":gpu_package_passes_files",
|
||||
],
|
||||
deps = [
|
||||
":core",
|
||||
":ir",
|
||||
":mlir",
|
||||
"//jaxlib/mlir/_mlir_libs:_mlirGPUPasses",
|
||||
],
|
||||
)
|
||||
|
||||
symlink_inputs(
|
||||
name = "llvm_dialect",
|
||||
rule = py_library,
|
||||
symlinked_inputs = {"srcs": {"dialects": [
|
||||
"@llvm-project//mlir/python:LLVMOpsPyFiles",
|
||||
]}},
|
||||
deps = [
|
||||
":core",
|
||||
":ir",
|
||||
":mlir",
|
||||
"//jaxlib/mlir/_mlir_libs:_mlirDialectsLLVM",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"if_building_mosaic_gpu",
|
||||
"if_windows",
|
||||
"py_extension",
|
||||
"pybind_extension",
|
||||
@ -58,6 +59,51 @@ py_extension(
|
||||
],
|
||||
)
|
||||
|
||||
py_extension(
|
||||
name = "_mlirExecutionEngine",
|
||||
srcs = [
|
||||
"@llvm-project//mlir:lib/Bindings/Python/ExecutionEngineModule.cpp",
|
||||
],
|
||||
copts = COPTS,
|
||||
linkopts = LINKOPTS,
|
||||
deps = [
|
||||
":jaxlib_mlir_capi_shared_library",
|
||||
"@llvm-project//mlir:CAPIExecutionEngineHeaders",
|
||||
"@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
py_extension(
|
||||
name = "_mlirGPUPasses",
|
||||
srcs = [
|
||||
"@llvm-project//mlir:lib/Bindings/Python/GPUPasses.cpp",
|
||||
],
|
||||
copts = COPTS,
|
||||
linkopts = LINKOPTS,
|
||||
deps = [
|
||||
":jaxlib_mlir_capi_shared_library",
|
||||
"@llvm-project//mlir:CAPIGPUHeaders",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
py_extension(
|
||||
name = "_mlirDialectsLLVM",
|
||||
srcs = [
|
||||
"@llvm-project//mlir:lib/Bindings/Python/DialectLLVM.cpp",
|
||||
],
|
||||
copts = COPTS,
|
||||
linkopts = LINKOPTS,
|
||||
deps = [
|
||||
":jaxlib_mlir_capi_shared_library",
|
||||
"@llvm-project//mlir:CAPIIRHeaders",
|
||||
"@llvm-project//mlir:CAPILLVMHeaders",
|
||||
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
py_extension(
|
||||
name = "_mlirDialectsSparseTensor",
|
||||
srcs = [
|
||||
@ -148,11 +194,36 @@ symlink_inputs(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "jaxlib_mlir_capi_shims",
|
||||
srcs = ["jaxlib_mlir_capi_shims.cc"],
|
||||
hdrs = ["jaxlib_mlir_capi_shims.h"],
|
||||
deps = [
|
||||
"@llvm-project//mlir:BuiltinToLLVMIRTranslation",
|
||||
"@llvm-project//mlir:CAPIIRHeaders",
|
||||
"@llvm-project//mlir:GPUPipelines",
|
||||
"@llvm-project//mlir:GPUToLLVMIRTranslation",
|
||||
"@llvm-project//mlir:LLVMToLLVMIRTranslation",
|
||||
"@llvm-project//mlir:MemRefTransforms",
|
||||
"@llvm-project//mlir:NVVMTarget",
|
||||
"@llvm-project//mlir:NVVMToLLVMIRTranslation",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "jaxlib_mlir_capi_shims_hdrs",
|
||||
hdrs = ["jaxlib_mlir_capi_shims.h"],
|
||||
deps = [
|
||||
"@llvm-project//mlir:CAPIIRHeaders",
|
||||
],
|
||||
)
|
||||
|
||||
# JAX-specific registrations.
|
||||
py_extension(
|
||||
name = "register_jax_dialects",
|
||||
srcs = ["register_jax_dialects.cc"],
|
||||
copts = COPTS,
|
||||
copts = COPTS + if_building_mosaic_gpu(["-DJAXLIB_MOSAIC_GPU"]),
|
||||
linkopts = LINKOPTS,
|
||||
deps = [
|
||||
":jaxlib_mlir_capi_shared_library",
|
||||
@ -166,7 +237,14 @@ py_extension(
|
||||
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
|
||||
"@local_config_python//:headers",
|
||||
"@pybind11",
|
||||
],
|
||||
] + if_building_mosaic_gpu([
|
||||
":jaxlib_mlir_capi_shims_hdrs",
|
||||
"@llvm-project//mlir:CAPIGPUHeaders",
|
||||
"@llvm-project//mlir:CAPINVGPUHeaders",
|
||||
"@llvm-project//mlir:CAPINVVMHeaders",
|
||||
"@llvm-project//mlir:CAPILLVMHeaders",
|
||||
"@llvm-project//mlir:CAPIConversionHeaders",
|
||||
]),
|
||||
)
|
||||
|
||||
##---------------------------------------------------------------------------##
|
||||
@ -270,7 +348,15 @@ cc_library(
|
||||
[
|
||||
"//jaxlib/triton:triton_dialect_capi_objects",
|
||||
],
|
||||
),
|
||||
) + if_building_mosaic_gpu([
|
||||
":jaxlib_mlir_capi_shims",
|
||||
"@llvm-project//mlir:CAPIConversionObjects",
|
||||
"@llvm-project//mlir:CAPIExecutionEngineObjects",
|
||||
"@llvm-project//mlir:CAPIGPUObjects",
|
||||
"@llvm-project//mlir:CAPILLVMObjects",
|
||||
"@llvm-project//mlir:CAPINVGPUObjects",
|
||||
"@llvm-project//mlir:CAPINVVMObjects",
|
||||
]),
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
|
47
jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi_shims.cc
Normal file
@ -0,0 +1,47 @@
|
||||
/* Copyright 2024 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi_shims.h"
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/Dialect/GPU/Pipelines/Passes.h"
|
||||
#include "mlir/Target/LLVM/NVVM/Target.h"
|
||||
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
|
||||
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
|
||||
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
|
||||
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
void jaxMlirRegisterMemRefPasses() {
|
||||
mlir::memref::registerMemRefPasses();
|
||||
}
|
||||
|
||||
void jaxMlirRegisterInterfaceExternalModels(MlirDialectRegistry registry) {
|
||||
mlir::NVVM::registerNVVMTargetInterfaceExternalModels(*unwrap(registry));
|
||||
mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(
|
||||
*unwrap(registry));
|
||||
mlir::registerBuiltinDialectTranslation(*unwrap(registry));
|
||||
mlir::registerGPUDialectTranslation(*unwrap(registry));
|
||||
mlir::registerLLVMDialectTranslation(*unwrap(registry));
|
||||
mlir::registerNVVMDialectTranslation(*unwrap(registry));
|
||||
}
|
||||
void jaxMlirRegisterGPUToNVVMPipeline() {
|
||||
mlir::gpu::registerGPUToNVVMPipeline();
|
||||
}
|
||||
|
||||
}
|
34
jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi_shims.h
Normal file
@ -0,0 +1,34 @@
|
||||
/* Copyright 2024 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef JAXLIB_MLIR_CAPI_SHIMS
|
||||
#define JAXLIB_MLIR_CAPI_SHIMS
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Support.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
MLIR_CAPI_EXPORTED void jaxMlirRegisterMemRefPasses();
|
||||
MLIR_CAPI_EXPORTED void jaxMlirRegisterInterfaceExternalModels(MlirDialectRegistry registry);
|
||||
MLIR_CAPI_EXPORTED void jaxMlirRegisterGPUToNVVMPipeline();
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // JAXLIB_MLIR_CAPI_SHIMS
|
@ -9,6 +9,14 @@
|
||||
#include "mlir-c/Dialect/Vector.h"
|
||||
#include "mlir-c/Transforms.h"
|
||||
#include "mlir/Bindings/Python/PybindAdaptors.h"
|
||||
#ifdef JAXLIB_MOSAIC_GPU
|
||||
#include "mlir-c/Dialect/GPU.h"
|
||||
#include "mlir-c/Dialect/NVGPU.h"
|
||||
#include "mlir-c/Dialect/NVVM.h"
|
||||
#include "mlir-c/Dialect/LLVM.h"
|
||||
#include "mlir-c/Conversion.h"
|
||||
#include "jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi_shims.h"
|
||||
#endif
|
||||
|
||||
#define REGISTER_DIALECT(name) \
|
||||
MlirDialectHandle name##_dialect = mlirGetDialectHandle__##name##__(); \
|
||||
@ -27,5 +35,17 @@ PYBIND11_MODULE(register_jax_dialects, m) {
|
||||
mlirRegisterTransformsPasses();
|
||||
// Transforms used by JAX.
|
||||
mlirRegisterTransformsStripDebugInfo();
|
||||
#ifdef JAXLIB_MOSAIC_GPU
|
||||
REGISTER_DIALECT(gpu);
|
||||
REGISTER_DIALECT(nvgpu);
|
||||
REGISTER_DIALECT(nvvm);
|
||||
REGISTER_DIALECT(llvm);
|
||||
mlirRegisterGPUPasses();
|
||||
mlirRegisterConversionPasses();
|
||||
// TODO(apaszke): Upstream and remove those.
|
||||
jaxMlirRegisterMemRefPasses();
|
||||
jaxMlirRegisterInterfaceExternalModels(registry);
|
||||
jaxMlirRegisterGPUToNVVMPipeline();
|
||||
#endif
|
||||
});
|
||||
}
|
||||
|
76
jaxlib/mosaic/gpu/BUILD
Normal file
@ -0,0 +1,76 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("@rules_python//python:defs.bzl", "py_library")
|
||||
load("//jaxlib:jax.bzl", "pybind_extension")
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
default_visibility = ["//:__subpackages__"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "mosaic_gpu",
|
||||
data = [":libmlir_cuda_runtime.so"],
|
||||
deps = [
|
||||
":_mosaic_gpu_ext",
|
||||
"//jaxlib/mlir:execution_engine",
|
||||
"//jaxlib/mlir:gpu_dialect",
|
||||
"//jaxlib/mlir:llvm_dialect",
|
||||
"//jaxlib/mlir:nvgpu_dialect",
|
||||
"//jaxlib/mlir:nvvm_dialect",
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_mosaic_gpu_ext",
|
||||
srcs = ["_mosaic_gpu_ext.cc"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
linkopts = select({
|
||||
"@xla//xla/python:use_jax_cuda_pip_rpaths": [
|
||||
"-Wl,-rpath,$$ORIGIN/../../../nvidia/cuda_runtime/lib",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
deps = [
|
||||
"//jaxlib:kernel_nanobind_helpers",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
"@nanobind",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "libmlir_cuda_runtime.so",
|
||||
srcs = ["@llvm-project//mlir:lib/ExecutionEngine/CudaRuntimeWrappers.cpp"],
|
||||
copts = ["-fvisibility=default"],
|
||||
linkopts = select({
|
||||
"@xla//xla/python:use_jax_cuda_pip_rpaths": [
|
||||
"-Wl,-rpath,$$ORIGIN/../../../nvidia/cuda_runtime/lib",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
linkshared = 1,
|
||||
tags = [
|
||||
"manual",
|
||||
"notap",
|
||||
],
|
||||
deps = [
|
||||
"@llvm-project//mlir:mlir_c_runner_utils_hdrs",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
],
|
||||
)
|
24
jaxlib/mosaic/gpu/_mosaic_gpu_ext.cc
Normal file
@ -0,0 +1,24 @@
|
||||
#include "nanobind/nanobind.h"
|
||||
#include "jaxlib/kernel_nanobind_helpers.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax::cuda {
|
||||
namespace {
|
||||
|
||||
namespace nb = nanobind;
|
||||
using MosaicHostFunc = void(void**);
|
||||
|
||||
void MosaicKernelCall(void* stream, void** buffers, char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
void* args[2] = {&stream, &buffers};
|
||||
MosaicHostFunc* func = *reinterpret_cast<MosaicHostFunc**>(opaque);
|
||||
func(args);
|
||||
}
|
||||
|
||||
NB_MODULE(_mosaic_gpu_ext, m) {
|
||||
m.def("_custom_call_capsule",
|
||||
[]() { return EncapsulateFunction(MosaicKernelCall); });
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace jax::cuda
|
@ -98,10 +98,13 @@ setup(
|
||||
'cuda/*',
|
||||
'cuda/nvvm/libdevice/libdevice*',
|
||||
'mosaic/*.py',
|
||||
'mosaic/gpu/*.so',
|
||||
'mosaic/python/*.py',
|
||||
'mosaic/python/*.so',
|
||||
'mlir/*.py',
|
||||
'mlir/dialects/*.py',
|
||||
'mlir/dialects/gpu/*.py',
|
||||
'mlir/dialects/gpu/passes/*.py',
|
||||
'mlir/extras/*.py',
|
||||
'mlir/_mlir_libs/*.dll',
|
||||
'mlir/_mlir_libs/*.dylib',
|
||||
|
@ -266,12 +266,27 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi
      "__main__/jaxlib/mosaic/python/_tpu_gen.py", dst_dir=mosaic_python_dir
  )

  has_mosaic_gpu = exists(f"__main__/jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}")
  def if_has_mosaic_gpu(extras):
    return extras if has_mosaic_gpu else []

  if has_mosaic_gpu:
    copy_runfiles(
        dst_dir=jaxlib_dir / "mosaic" / "gpu",
        src_files=[
            "__main__/jaxlib/mosaic/gpu/libmlir_cuda_runtime.so",
            f"__main__/jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}",
        ],
    )

  copy_runfiles(
      dst_dir=jaxlib_dir / "mlir",
      src_files=[
          "__main__/jaxlib/mlir/ir.py",
          "__main__/jaxlib/mlir/passmanager.py",
      ],
      ] + if_has_mosaic_gpu([
          "__main__/jaxlib/mlir/execution_engine.py",
      ]),
  )
  copy_runfiles(
      dst_dir=jaxlib_dir / "mlir" / "dialects",
@ -302,7 +317,19 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi
          "__main__/jaxlib/mlir/dialects/sparse_tensor.py",
          "__main__/jaxlib/mlir/dialects/stablehlo.py",
          "__main__/jaxlib/mlir/dialects/vector.py",
      ],
      ] + if_has_mosaic_gpu([
          "__main__/jaxlib/mlir/dialects/_gpu_enum_gen.py",
          "__main__/jaxlib/mlir/dialects/_gpu_ops_gen.py",
          "__main__/jaxlib/mlir/dialects/_nvgpu_enum_gen.py",
          "__main__/jaxlib/mlir/dialects/_nvgpu_ops_gen.py",
          "__main__/jaxlib/mlir/dialects/_nvvm_enum_gen.py",
          "__main__/jaxlib/mlir/dialects/_nvvm_ops_gen.py",
          "__main__/jaxlib/mlir/dialects/_llvm_enum_gen.py",
          "__main__/jaxlib/mlir/dialects/_llvm_ops_gen.py",
          "__main__/jaxlib/mlir/dialects/nvgpu.py",
          "__main__/jaxlib/mlir/dialects/nvvm.py",
          "__main__/jaxlib/mlir/dialects/llvm.py",
      ]),
  )
  copy_runfiles(
      dst_dir=jaxlib_dir / "mlir" / "extras",
@ -310,6 +337,20 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi
          "__main__/jaxlib/mlir/extras/meta.py",
      ],
  )
  if has_mosaic_gpu:
    copy_runfiles(
        dst_dir=jaxlib_dir / "mlir" / "dialects" / "gpu",
        src_files=[
            "__main__/jaxlib/mlir/dialects/gpu/__init__.py",
        ],
    )
    copy_runfiles(
        dst_dir=jaxlib_dir / "mlir" / "dialects" / "gpu" / "passes",
        src_files=[
            "__main__/jaxlib/mlir/dialects/gpu/passes/__init__.py",
        ],
    )


  if build_utils.is_windows():
    capi_so = "__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll"
@ -339,7 +380,10 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi
              f"__main__/jaxlib/mlir/_mlir_libs/_triton_ext.{pyext}",
              "__main__/jaxlib/mlir/_mlir_libs/_triton_ext.pyi",
          ]
      ),
      ) + if_has_mosaic_gpu([
          f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsLLVM.{pyext}",
          f"__main__/jaxlib/mlir/_mlir_libs/_mlirExecutionEngine.{pyext}",
      ]),
  )

  triton_dir = jaxlib_dir / "triton"

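These wheel rules only pick up the Mosaic GPU artifacts when the extension was actually built, so whether an installed jaxlib has Mosaic GPU support can be detected at runtime the same way the test below does. A small sketch:

try:
  import jax._src.lib.mosaic_gpu  # noqa: F401
  print("jaxlib was built with Mosaic GPU")
except ImportError:
  print("jaxlib was built without Mosaic GPU")
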
52  tests/mosaic/BUILD  Normal file
@ -0,0 +1,52 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load(
    "//jaxlib:jax.bzl",
    "jax_generate_backend_suites",
    "jax_test",
    "py_deps",
)

licenses(["notice"])

package(
    default_applicable_licenses = [],
    default_visibility = ["//visibility:private"],
)

jax_generate_backend_suites()

jax_test(
    name = "gpu_test",
    srcs = [
        "gpu_test.py",
    ],
    disable_backends = [
        "cpu",
        "tpu",
    ],
    disable_configs = [
        "gpu",
        "gpu_a100",
        "gpu_p100",
        "gpu_p100_x32",
        "gpu_x32",
        "gpu_pjrt_c_api",
    ],
    shard_count = 4,
    deps = [
        "//jax:mosaic_gpu",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)

803  tests/mosaic/gpu_test.py  Normal file
@ -0,0 +1,803 @@
# Copyright 2024 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Mosaic GPU DSL functions and utilities."""

import operator
from typing import Optional

from absl.testing import absltest, parameterized
import numpy as np
import jax
import jax.numpy as jnp
from jax._src import config
from jax._src import test_util as jtu
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector
try:
  import jax._src.lib.mosaic_gpu  # noqa: F401
  HAS_MOSAIC_GPU = True
except ImportError:
  HAS_MOSAIC_GPU = False
else:
  from jax.experimental.mosaic import gpu as mosaic_gpu
  from jax.experimental.mosaic.gpu import dsl as mgpu
  from jax.experimental.mosaic.gpu.utils import *  # noqa: F403
  from jax._src.lib.mlir.dialects import gpu
  from jax._src.lib.mlir.dialects import llvm


# ruff: noqa: F405
config.update("jax_traceback_filtering", "off")
config.parse_flags_with_absl()

def nd_loop(bounds, body, *, _idxs = ()):
  if not bounds:
    body(*_idxs)
    return
  bound, *other_bounds = bounds
  @fori(bound, ())
  def _loop_body(i, _):
    nd_loop(other_bounds, body, _idxs=(*_idxs, i))
    return ()


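# Note on nd_loop above: it emits one `fori` loop per bound and calls `body`
# with the MLIR index values of each iteration, so, roughly,
# nd_loop([c(2, index), c(3, index)], body) corresponds to
#
#   for i in range(2):
#     for j in range(3):
#       body(i, j)
#
# except that the loops are built in the IR rather than executed in Python.
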
def mlir_sum(elems):
  assert elems
  total = elems[0]
  for elem in elems[1:]:
    total = arith.addi(total, elem)
  return total


def copy(src: ir.Value, dst: ir.Value, swizzle: Optional[int] = None):
  index = ir.IndexType.get()
  thread_id = gpu.thread_id(gpu.Dimension.x)
  stride = gpu.block_dim(gpu.Dimension.x)
  for dim in (gpu.Dimension.y, gpu.Dimension.z):
    thread_id = arith.addi(thread_id, arith.muli(gpu.thread_id(dim), stride))
    stride = arith.muli(stride, gpu.block_dim(dim))
  is_first_thread = arith.cmpi(arith.CmpIPredicate.eq, thread_id, c(0, index))
  src_ty = ir.MemRefType(src.type)
  dst_ty = ir.MemRefType(dst.type)
  if src_ty.shape != dst_ty.shape:
    raise ValueError(
        f"src and dst shapes don't match: {src_ty.shape} != {dst_ty.shape}"
    )
  shape = src_ty.shape
  dyn_strides = [c(s, index) for s in get_contiguous_strides(shape)]
  with ir.InsertionPoint(scf.IfOp(is_first_thread).then_block):
    def body(*idx):
      dst_idx = idx
      if swizzle is not None:
        if swizzle != 128:
          raise NotImplementedError("Only swizzle 128B implemented")
        # TODO(apaszke): This can probably be cleaned up.
        # But it works and it's test-only, so it doesn't matter much.
        # After all, swizzle should just be an xor of row and linear idx,
        # adjusted for the bytewidth.
        bytes_per_element = bytewidth(src_ty.element_type)
        elems_per_tile = 1024 // bytes_per_element
        elems_per_row = elems_per_tile // 8
        elems_per_group = 16 // bytes_per_element
        linear_idx = c(0, index)
        for stride, i in zip(dyn_strides, idx):
          linear_idx = arith.addi(linear_idx, arith.muli(i, stride))
        tile_offset = arith.remui(linear_idx, c(elems_per_tile, index))
        linear_tile_start = arith.subi(linear_idx, tile_offset)
        row = arith.divui(tile_offset, c(elems_per_row, index))
        row_offset = arith.remui(tile_offset, c(elems_per_row, index))
        src_group = arith.divui(row_offset, c(elems_per_group, index))
        group_offset = arith.remui(row_offset, c(elems_per_group, index))
        dst_group = arith.xori(src_group, row)
        dst_linear_idx = mlir_sum([
            linear_tile_start,
            arith.muli(row, c(elems_per_row, index)),
            arith.muli(dst_group, c(elems_per_group, index)),
            group_offset,
        ])
        dst_idx = [
            arith.remui(arith.divui(dst_linear_idx, stride), c(bound, index))
            for stride, bound in zip(dyn_strides, shape)
        ]
      memref.store(memref.load(src, idx), dst, dst_idx)
    nd_loop([c(d, index) for d in shape], body)
    scf.yield_([])
  gpu.barrier()
  nvvm.fence_proxy(nvvm.ProxyKind.async_)


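# A self-contained NumPy sketch of the 128-byte swizzle implemented in copy()
# above, assuming 4-byte elements and a row-major array whose last dimension
# holds 32 elements (one 128B tile row), as in test_copy_swizzle below: the
# g-th 16-byte group of row r is stored at group position g ^ (r % 8).
# Illustrative only; not used by the tests.
def reference_swizzle_128b(x):
  assert x.dtype.itemsize == 4 and x.shape[-1] == 32
  src_rows = np.ascontiguousarray(x).reshape(-1, 32)
  dst_rows = np.empty_like(src_rows)
  for r in range(src_rows.shape[0]):
    for g in range(8):  # eight groups of four 4-byte elements = 16 bytes each
      dst = g ^ (r % 8)
      dst_rows[r, dst * 4:(dst + 1) * 4] = src_rows[r, g * 4:(g + 1) * 4]
  return dst_rows.reshape(x.shape)

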
def iota_tensor(m, n, mlir_dtype):
  assert m % 64 == 0
  assert n % 8 == 0
  def c(i):
    return arith.constant(index, ir.IntegerAttr.get(index, i))
  index = ir.IndexType.get()
  i32 = ir.IntegerType.get_signless(32)
  warp_id = arith.divui(gpu.thread_id(gpu.Dimension.x), c(32))
  within_warp_id = arith.remui(gpu.thread_id(gpu.Dimension.x), c(32))
  warp_row_start = arith.muli(warp_id, c(16))
  within_warp_row = arith.divui(within_warp_id, c(4))
  start_row = arith.addi(warp_row_start, within_warp_row)
  start_col = arith.muli(arith.remui(within_warp_id, c(4)), c(2))
  registers = np.empty((m // 64, n // 8, 2, 1), dtype=object)
  for row_tile, col_tile, row_subtile, _ in np.ndindex(registers.shape):
    row = arith.addi(start_row, c(row_tile * 64 + row_subtile * 8))
    col = arith.addi(start_col, c(col_tile * 8))
    row_value_base = arith.muli(row, c(n))
    vec = llvm.mlir_undef(ir.VectorType.get((2,), i32))
    for col_offset in range(2):
      value = arith.addi(row_value_base, arith.addi(c(col_offset), col))
      value = arith.index_cast(i32, value)
      vec = vector.insertelement(value, vec, position=c(col_offset))
    registers[row_tile, col_tile, row_subtile, 0] = vec
  t = mgpu.FragmentedArray(_registers=registers, _layout=mgpu.WGMMA_LAYOUT)
  return t.astype(mlir_dtype)


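# A pure-Python sketch of the per-thread WGMMA fragment mapping that
# iota_tensor above relies on (and that test_iota_tensor checks): each of the
# 128 threads in a warpgroup owns 2-element vectors at the (row, col)
# coordinates returned here. Illustrative only; not used by the tests.
def wgmma_thread_coords(thread_id, m, n):
  warp, lane = divmod(thread_id, 32)
  base_row = warp * 16 + lane // 4
  base_col = (lane % 4) * 2
  coords = []
  for row_tile in range(m // 64):
    for col_tile in range(n // 8):
      for row_subtile in range(2):
        row = base_row + row_tile * 64 + row_subtile * 8
        for col_offset in range(2):
          coords.append((row, base_col + col_tile * 8 + col_offset))
  return coords

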
class TestCase(parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
if not HAS_MOSAIC_GPU:
|
||||
self.skipTest("jaxlib built without Mosaic GPU")
|
||||
super().setUp()
|
||||
self.prng = np.random.default_rng(1234)
|
||||
self.ctx = mlir.make_ir_context()
|
||||
self.ctx.__enter__()
|
||||
self.loc = ir.Location.unknown()
|
||||
self.loc.__enter__()
|
||||
|
||||
def tearDown(self):
|
||||
self.loc.__exit__(None, None, None)
|
||||
self.ctx.__exit__(None, None, None)
|
||||
del self.loc, self.ctx
|
||||
super().tearDown()
|
||||
|
||||
|
||||
class TestUtilTest(TestCase):
|
||||
|
||||
def test_copy(self):
|
||||
def kernel(ctx, src, dst, _):
|
||||
copy(src, dst)
|
||||
x = jnp.arange(2 * 3 * 5).reshape(2, 5, 3)
|
||||
y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, ())(x)
|
||||
np.testing.assert_array_equal(y, x)
|
||||
|
||||
def test_copy_swizzle(self):
|
||||
def kernel(ctx, src, dst, _):
|
||||
copy(src, dst, swizzle=128)
|
||||
x = jnp.arange(8 * 32, dtype=jnp.float32).reshape(8, 32)
|
||||
y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, ())(x)
|
||||
expected = np.zeros_like(y)
|
||||
for i in range(8):
|
||||
for j in range(8):
|
||||
js = j ^ i
|
||||
expected[i, (j * 4):(j * 4) + 4] = x[i, (js * 4):(js * 4) + 4]
|
||||
np.testing.assert_array_equal(y, expected)
|
||||
|
||||
def test_copy_swizzle_noop(self):
|
||||
# Two swizzles cancel out
|
||||
def kernel(ctx, src, dst, smem):
|
||||
copy(src, smem, swizzle=128)
|
||||
copy(smem, dst, swizzle=128)
|
||||
x = jnp.arange(8 * 32, dtype=jnp.float32).reshape(8, 32)
|
||||
y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x)
|
||||
np.testing.assert_array_equal(y, x)
|
||||
|
||||
def test_iota_tensor(self):
|
||||
m = n = 64
|
||||
def kernel(ctx, dst, _):
|
||||
f32 = ir.F32Type.get()
|
||||
index = ir.IndexType.get()
|
||||
registers = iota_tensor(m, n, f32).registers
|
||||
assert registers.size == 16, registers.size
|
||||
for i, vec_reg in enumerate(registers.flat):
|
||||
for j in range(2):
|
||||
reg = vector.extractelement(vec_reg, position=c(j, index))
|
||||
memref.store(
|
||||
reg, dst, [gpu.thread_id(gpu.Dimension.x), c(2 * i + j, index)]
|
||||
)
|
||||
out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32)
|
||||
regs = mosaic_gpu.as_gpu_kernel(
|
||||
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
|
||||
)()
|
||||
thread_ids = np.arange(128)
|
||||
warp_ids = thread_ids // 32
|
||||
lane_ids = thread_ids % 32
|
||||
thread_rows = warp_ids * 16 + lane_ids // 4
|
||||
thread_start_cols = (lane_ids % 4) * 2
|
||||
thread_cols = thread_start_cols[:, None] + (np.arange(n // 8)[None] * 8)
|
||||
regs = regs.reshape(128, 8, 2, 2)
|
||||
for row_half in range(2):
|
||||
for col_half in range(2):
|
||||
np.testing.assert_array_equal(
|
||||
regs[..., row_half, col_half],
|
||||
(thread_rows[:, None] + row_half * 8) * n + thread_cols + col_half
|
||||
)
|
||||
|
||||
|
||||
class MemRefTest(TestCase):
|
||||
@parameterized.product(
|
||||
dim=tuple(range(3)),
|
||||
strided=(False, True)
|
||||
)
|
||||
def test_unsqueeze(self, dim, strided):
|
||||
def kernel(ctx, inp, out, _):
|
||||
if strided:
|
||||
for i in range(8):
|
||||
s = ds(i, 1)
|
||||
out_slice = s if dim != 0 else (slice(None), s)
|
||||
copy(
|
||||
memref_unsqueeze(memref_slice(inp, s), dim),
|
||||
memref_slice(out, out_slice),
|
||||
)
|
||||
else:
|
||||
copy(memref_unsqueeze(inp, dim), out)
|
||||
x = np.arange(8 * 16, dtype=jnp.float32).reshape(8, 16)
|
||||
out_shape = list(x.shape)
|
||||
out_shape.insert(dim, 1)
|
||||
out_ty = jax.ShapeDtypeStruct(out_shape, jnp.float32)
|
||||
y = mosaic_gpu.as_gpu_kernel(
|
||||
kernel, (1, 1, 1), (128, 1, 1), x, out_ty, ()
|
||||
)(x)
|
||||
np.testing.assert_array_equal(y, x.reshape(out_shape))
|
||||
|
||||
@parameterized.product(
|
||||
dim=tuple(range(2)),
|
||||
strided=(False, True)
|
||||
)
|
||||
def test_unfold(self, dim, strided):
|
||||
in_shape = (8, 16)
|
||||
def kernel(ctx, inp, out, _):
|
||||
if strided:
|
||||
# We slice the dim we don't unfold
|
||||
for i in range(in_shape[1 - dim] // 4):
|
||||
s = ds(i * 4, 4)
|
||||
in_slice = s if dim == 1 else (slice(None), s)
|
||||
out_slice = s if dim == 1 else (slice(None),) * 3 + (s,)
|
||||
copy(
|
||||
memref_unfold(memref_slice(inp, in_slice), dim, (2, 2, None)),
|
||||
memref_slice(out, out_slice),
|
||||
)
|
||||
else:
|
||||
copy(memref_unfold(inp, dim, (2, 2, None)), out)
|
||||
x = np.arange(np.prod(in_shape), dtype=jnp.float32).reshape(in_shape)
|
||||
out_shape = list(in_shape)
|
||||
out_shape[dim:dim + 1] = [2, 2, out_shape[dim] // 4]
|
||||
out_ty = jax.ShapeDtypeStruct(out_shape, jnp.float32)
|
||||
y = mosaic_gpu.as_gpu_kernel(
|
||||
kernel, (1, 1, 1), (128, 1, 1), x, out_ty, ()
|
||||
)(x)
|
||||
np.testing.assert_array_equal(y, x.reshape(out_ty.shape))
|
||||
|
||||
@parameterized.product(
|
||||
dim=tuple(range(2)),
|
||||
)
|
||||
def test_fold_not_strided(self, dim):
|
||||
def kernel(ctx, inp, out, _):
|
||||
copy(memref_fold(inp, dim, 2), out)
|
||||
|
||||
x = np.arange(8 * 2 * 8, dtype=jnp.float32).reshape(8, 2, 8)
|
||||
out_ty = jax.ShapeDtypeStruct((16, 8) if dim == 0 else (8, 16), jnp.float32)
|
||||
y = mosaic_gpu.as_gpu_kernel(
|
||||
kernel, (1, 1, 1), (128, 1, 1), x, out_ty, ()
|
||||
)(x)
|
||||
np.testing.assert_array_equal(y, x.reshape(out_ty.shape))
|
||||
|
||||
@parameterized.named_parameters([
|
||||
("packed", (4, 4, 4), (16, 4, 1), 1, 2, False),
|
||||
("strided_end", (4, 4, 4, 4), (256, 64, 16, 4), 1, 2, False),
|
||||
("strided_bot", (4, 4, 4, 4), (256, 16, 4, 1), 1, 2, False),
|
||||
("strided_top", (4, 4, 4, 4), (256, 64, 4, 1), 1, 2, True),
|
||||
("strided_mid", (4, 4, 4, 4), (265, 64, 16, 1), 1, 3, True),
|
||||
("overap", (2, 4, 4), (16, 1, 1), 0, 3, True),
|
||||
])
|
||||
def test_fold_strided(
|
||||
self, shape, strides, dim, fold_rank, throws_not_impl
|
||||
):
|
||||
expanded_shape = get_packed_shape(strides, shape)
|
||||
total_size = np.prod(expanded_shape)
|
||||
np_inp = np.arange(total_size, dtype=jnp.float32).reshape(expanded_shape)
|
||||
index = tuple([slice(0, s) for s in shape])
|
||||
|
||||
# Reference implementation
|
||||
def np_fold(inp, dim, fold_rank):
|
||||
out_shape = list(inp.shape)
|
||||
out_shape[dim : dim + fold_rank] = [
|
||||
int(np.prod(inp.shape[dim : dim + fold_rank]))
|
||||
]
|
||||
if throws_not_impl:
|
||||
return jax.ShapeDtypeStruct(shape=out_shape, dtype=inp.dtype)
|
||||
else:
|
||||
return inp.reshape(*out_shape)
|
||||
|
||||
total_size = np.prod(shape) * np.prod(strides)
|
||||
|
||||
def do_test():
|
||||
def kernel(ctx, inp, out, _):
|
||||
copy(memref_fold(memref_slice(inp, index), dim, fold_rank), out)
|
||||
|
||||
out = np_fold(np_inp[index], dim, fold_rank)
|
||||
y = mosaic_gpu.as_gpu_kernel(
|
||||
kernel, (1, 1, 1), (128, 1, 1), np_inp, out, ()
|
||||
)(np_inp)
|
||||
assert (
|
||||
not throws_not_impl
|
||||
), "If it should have thrown it would during the call."
|
||||
np.testing.assert_array_equal(y, out)
|
||||
|
||||
if throws_not_impl:
|
||||
with self.assertRaises(NotImplementedError):
|
||||
do_test()
|
||||
else:
|
||||
do_test()
|
||||
|
||||
|
||||
def get_packed_shape(strides, shape):
|
||||
perm = sorted(range(len(strides)), key=lambda i: strides[i], reverse=True)
|
||||
ordered_strides = [strides[i] for i in perm]
|
||||
ordered_shape = [shape[i] for i in perm]
|
||||
packed_shape = [ordered_shape[-1]]
|
||||
packed_shape += [
|
||||
stride0 // stride
|
||||
for stride0, stride in zip(ordered_strides, ordered_strides[1:])
|
||||
]
|
||||
# Invert permutation
|
||||
inv_perm = [None] * len(perm)
|
||||
for i, p in enumerate(perm):
|
||||
inv_perm[p] = i
|
||||
return [packed_shape[i] for i in inv_perm]
|
||||
|
||||
|
||||
class WGMMATest(TestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("f32", ir.F32Type, jnp.float32), ("f16", ir.F16Type, jnp.float16)
|
||||
)
|
||||
def test_store_untiled(self, mlir_dtype_cls, jax_dtype):
|
||||
mlir_dtype = mlir_dtype_cls.get()
|
||||
def kernel(ctx, out, _):
|
||||
del ctx
|
||||
iota_tensor(64, 64, mlir_dtype).store_untiled(out)
|
||||
expected = np.arange(64 * 64, dtype=jax_dtype).reshape(64, 64)
|
||||
iota = mosaic_gpu.as_gpu_kernel(
|
||||
kernel, (1, 1, 1), (128, 1, 1), (), expected, ()
|
||||
)()
|
||||
np.testing.assert_array_equal(iota, expected)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("f32", ir.F32Type, jnp.float32),
|
||||
("f16", ir.F16Type, jnp.float16),
|
||||
)
|
||||
def test_store_tiled(self, mlir_dtype_cls, jax_dtype):
|
||||
mlir_dtype = mlir_dtype_cls.get()
|
||||
m = 128
|
||||
n = 256
|
||||
tiling = (64, 128 // bytewidth(mlir_dtype))
|
||||
def kernel(ctx, out, smem):
|
||||
del ctx
|
||||
iota_tensor(m, n, mlir_dtype).store_tiled(smem, swizzle=128)
|
||||
copy(smem, out, swizzle=128)
|
||||
expected = (
|
||||
np.arange(m * n, dtype=jax_dtype)
|
||||
.reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1])
|
||||
.transpose(0, 2, 1, 3)
|
||||
)
|
||||
iota = mosaic_gpu.as_gpu_kernel(
|
||||
kernel, (1, 1, 1), (128, 1, 1), (), expected, expected
|
||||
)()
|
||||
np.testing.assert_array_equal(iota, expected)
|
||||
|
||||
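# A NumPy sketch of the (64, 128 // bytewidth) tiling applied by store_tiled
# and mosaic_gpu.TileTransform, mirroring how `expected` is built in
# test_store_tiled above: the array is cut into tiles and the tile-grid
# indices are moved in front of the in-tile indices. Illustrative only.
def reference_tile(x, tile_m, tile_n):
  m, n = x.shape
  return x.reshape(m // tile_m, tile_m, n // tile_n, tile_n).transpose(0, 2, 1, 3)
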
@parameterized.product(
|
||||
lhs_transpose=(False, True),
|
||||
rhs_transpose=(False, True),
|
||||
mlir_dtype_cls=(ir.F16Type, ir.BF16Type, ir.F32Type),
|
||||
m=(64, 128, 192),
|
||||
n=(32, 64, 128, 192),
|
||||
k_steps=(1, 2),
|
||||
tma_inputs=(False, True),
|
||||
)
|
||||
def test_wgmma(
|
||||
self,
|
||||
m,
|
||||
n,
|
||||
k_steps,
|
||||
mlir_dtype_cls,
|
||||
lhs_transpose,
|
||||
rhs_transpose,
|
||||
tma_inputs,
|
||||
):
|
||||
mlir_dtype = mlir_dtype_cls.get()
|
||||
if ir.F32Type.isinstance(mlir_dtype): # We actually use tf32 instead
|
||||
jax_dtype = jnp.float32
|
||||
if lhs_transpose or not rhs_transpose:
|
||||
self.skipTest("Transpose only supported in 16-bit WGMMA")
|
||||
exponent_bits, mantissa_bits = 8, 10 # Use tf32
|
||||
elif bytewidth(mlir_dtype) == 2:
|
||||
if n % 64 != 0:
|
||||
self.skipTest("16-bit WGMMA only supports n % 64 == 0")
|
||||
if ir.F16Type.isinstance(mlir_dtype):
|
||||
jax_dtype = jnp.float16
|
||||
exponent_bits, mantissa_bits = 5, 10
|
||||
elif ir.BF16Type.isinstance(mlir_dtype):
|
||||
jax_dtype = jnp.bfloat16
|
||||
exponent_bits, mantissa_bits = 8, 7
|
||||
else:
|
||||
raise NotImplementedError(mlir_dtype)
|
||||
else:
|
||||
raise NotImplementedError(mlir_dtype)
|
||||
nk_tile = 128 // bytewidth(mlir_dtype)
|
||||
k = nk_tile * k_steps
|
||||
assert m % 64 == 0 and n % nk_tile == 0
|
||||
index = ir.IndexType.get()
|
||||
|
||||
row_major = mgpu.WGMMALayout.ROW_MAJOR
|
||||
col_major = mgpu.WGMMALayout.COL_MAJOR
|
||||
lhs_order = col_major if lhs_transpose else row_major
|
||||
rhs_order = col_major if rhs_transpose else row_major
|
||||
|
||||
def kernel(ctx, lhs, rhs, out, scratch):
|
||||
lhs_smem, rhs_smem = scratch
|
||||
if tma_inputs:
|
||||
lhs_transform = (mosaic_gpu.TileTransform((64, nk_tile)),)
|
||||
if lhs_transpose:
|
||||
assert nk_tile == 64 # Make sure we didn't have to transpose tiling.
|
||||
lhs_transform += (mosaic_gpu.TransposeTransform((1, 0, 2, 3)),)
|
||||
rhs_transform = (mosaic_gpu.TileTransform((nk_tile, nk_tile)),)
|
||||
if rhs_transpose:
|
||||
rhs_transform += (mosaic_gpu.TransposeTransform((1, 0, 2, 3)),)
|
||||
barriers = BarrierArray(2)
|
||||
ctx.async_copy(
|
||||
src_ref=lhs,
|
||||
dst_ref=lhs_smem,
|
||||
swizzle=128,
|
||||
gmem_transform=lhs_transform,
|
||||
barrier=barriers[0],
|
||||
)
|
||||
ctx.async_copy(
|
||||
src_ref=rhs,
|
||||
dst_ref=rhs_smem,
|
||||
swizzle=128,
|
||||
gmem_transform=rhs_transform,
|
||||
barrier=barriers[1],
|
||||
)
|
||||
for i in range(2):
|
||||
barriers[i].wait()
|
||||
else:
|
||||
for mi in range(m // 64):
|
||||
for ki in range(k // nk_tile):
|
||||
lhs_slice = (
|
||||
ds(c(mi * 64, index), 64),
|
||||
ds(c(ki * nk_tile, index), nk_tile),
|
||||
)
|
||||
if lhs_transpose:
|
||||
lhs_slice = lhs_slice[::-1]
|
||||
copy(
|
||||
src=memref_slice(lhs, lhs_slice),
|
||||
dst=memref_slice(lhs_smem, (mi, ki)),
|
||||
swizzle=128,
|
||||
)
|
||||
for ki in range(k // nk_tile):
|
||||
k_slice = ds(c(ki * nk_tile, index), nk_tile)
|
||||
for ni in range(n // nk_tile):
|
||||
rhs_slice = (k_slice, ds(c(ni * nk_tile, index), nk_tile))
|
||||
if rhs_transpose:
|
||||
rhs_slice = rhs_slice[::-1]
|
||||
copy(
|
||||
src=memref_slice(rhs, rhs_slice),
|
||||
dst=memref_slice(rhs_smem, (ki, ni)),
|
||||
swizzle=128,
|
||||
)
|
||||
init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n)
|
||||
acc = mgpu.wgmma(
|
||||
init_acc, lhs_smem, rhs_smem,
|
||||
a_order=lhs_order, b_order=rhs_order,
|
||||
)
|
||||
nvvm.wgmma_commit_group_sync_aligned()
|
||||
nvvm.wgmma_wait_group_sync_aligned(0)
|
||||
acc.value.store_untiled(out)
|
||||
|
||||
def quantize(x):
|
||||
# Quantize the input to avoid rounding when feeding the WGMMA
|
||||
return jax.lax.reduce_precision(x, exponent_bits, mantissa_bits)
|
||||
|
||||
x_shape = (k, m) if lhs_transpose else (m, k)
|
||||
x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(jax_dtype)
|
||||
y_shape = (n, k) if rhs_transpose else (k, n)
|
||||
y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(jax_dtype)
|
||||
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
|
||||
scratch_shape = [
|
||||
jax.ShapeDtypeStruct((m // 64, k // nk_tile, 64, nk_tile), jax_dtype),
|
||||
jax.ShapeDtypeStruct(
|
||||
(k // nk_tile, n // nk_tile, nk_tile, nk_tile), jax_dtype
|
||||
),
|
||||
]
|
||||
z = mosaic_gpu.as_gpu_kernel(
|
||||
kernel, (1, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape
|
||||
)(x, y)
|
||||
x32, y32 = x.astype(np.float32), y.astype(np.float32)
|
||||
ref = (x32.T if lhs_transpose else x32) @ (y32.T if rhs_transpose else y32)
|
||||
np.testing.assert_allclose(z, ref, atol=5e-6)
|
||||
|
||||
# TODO(apaszke): Add support for f32
|
||||
@parameterized.product(
|
||||
m=(64, 128, 192),
|
||||
n=(64, 128, 192),
|
||||
k_steps=(1, 2),
|
||||
rhs_transpose=(False, True),
|
||||
mlir_dtype_cls=(ir.F16Type, ir.BF16Type),
|
||||
)
|
||||
def test_wgmma_reg_lhs(self, m, n, k_steps, rhs_transpose, mlir_dtype_cls):
|
||||
k = 64 * k_steps
|
||||
index = ir.IndexType.get()
|
||||
|
||||
row_major = mgpu.WGMMALayout.ROW_MAJOR
|
||||
col_major = mgpu.WGMMALayout.COL_MAJOR
|
||||
rhs_order = col_major if rhs_transpose else row_major
|
||||
|
||||
def kernel(ctx, rhs, out, rhs_smem):
|
||||
del ctx
|
||||
for ki in range(k_steps):
|
||||
for ni in range(n // 64):
|
||||
rhs_slice = (ds(c(ki * 64, index), 64), ds(c(ni * 64, index), 64))
|
||||
if rhs_transpose:
|
||||
rhs_slice = rhs_slice[::-1]
|
||||
copy(
|
||||
src=memref_slice(rhs, rhs_slice),
|
||||
dst=memref_slice(rhs_smem, (ki, ni)),
|
||||
swizzle=128,
|
||||
)
|
||||
init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n)
|
||||
lhs_regs = iota_tensor(m, k, mlir_dtype_cls.get())
|
||||
acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, b_order=rhs_order)
|
||||
nvvm.wgmma_commit_group_sync_aligned()
|
||||
nvvm.wgmma_wait_group_sync_aligned(0)
|
||||
acc.value.store_untiled(out)
|
||||
|
||||
jax_dtype = jnp.float16 if mlir_dtype_cls == ir.F16Type else jnp.bfloat16
|
||||
y_shape = (n, k) if rhs_transpose else (k, n)
|
||||
y = self.prng.uniform(-1, 1, y_shape).astype(jax_dtype)
|
||||
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
|
||||
scratch_shape = jax.ShapeDtypeStruct(
|
||||
(k_steps, n // 64, 64, 64), jax_dtype
|
||||
)
|
||||
z = mosaic_gpu.as_gpu_kernel(
|
||||
kernel, (1, 1, 1), (128, 1, 1), y, out_shape, scratch_shape
|
||||
)(y)
|
||||
x = np.arange(m * k, dtype=jax_dtype).reshape(m, k)
|
||||
ref = jax.lax.dot(
|
||||
x, (y.T if rhs_transpose else y), preferred_element_type=jnp.float32
|
||||
)
|
||||
rtol = 0 if k_steps == 1 else 2.2e-4
|
||||
np.testing.assert_allclose(z, ref, rtol=rtol, atol=0)
|
||||
|
||||
|
||||
class TMATest(TestCase):
|
||||
|
||||
@parameterized.product(
|
||||
swizzle=(None, 128),
|
||||
shape=((64, 64), (5, 64), (2, 3, 5, 64)),
|
||||
dtype=(jnp.float16, jnp.float32),
|
||||
)
|
||||
def test_tma_load(self, swizzle, shape, dtype):
|
||||
if dtype == jnp.float32:
|
||||
shape = (*shape[:-1], shape[-1] // 2)
|
||||
i1 = ir.IntegerType.get_signless(1)
|
||||
def kernel(ctx, src, dst, tmp):
|
||||
barrier = BarrierArray(1)[0]
|
||||
ctx.async_copy(src_ref=src, dst_ref=tmp, swizzle=swizzle, barrier=barrier)
|
||||
barrier.wait_parity(c(0, i1))
|
||||
copy(tmp, dst, swizzle=swizzle)
|
||||
x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
|
||||
y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x)
|
||||
np.testing.assert_array_equal(y, x)
|
||||
|
||||
@parameterized.product(
|
||||
swizzle=(None, 128),
|
||||
shape=((128, 128), (5, 32, 128)),
|
||||
dtype=(jnp.float16, jnp.float32),
|
||||
)
|
||||
def test_tma_load_tiled(self, swizzle, shape, dtype):
|
||||
i1 = ir.IntegerType.get_signless(1)
|
||||
index = ir.IndexType.get()
|
||||
tiling = (32, 128 // jnp.dtype(dtype).itemsize)
|
||||
tiled_shape = tile_shape(shape, tiling)[:len(shape)]
|
||||
def kernel(ctx, src, dst, tmp):
|
||||
barrier = BarrierArray(1)[0]
|
||||
ctx.async_copy(
|
||||
src_ref=src,
|
||||
dst_ref=tmp,
|
||||
swizzle=swizzle,
|
||||
barrier=barrier,
|
||||
gmem_transform=mosaic_gpu.TileTransform(tiling),
|
||||
)
|
||||
barrier.wait_parity(c(0, i1))
|
||||
for idxs in np.ndindex(tiled_shape):
|
||||
untiled_idxs, tiled_idxs = idxs[:-len(tiling)], idxs[-len(tiling):]
|
||||
s = (
|
||||
*untiled_idxs,
|
||||
*(ds(c(ix * t, index), t) for ix, t in zip(tiled_idxs, tiling)),
|
||||
)
|
||||
copy(memref_slice(tmp, idxs), memref_slice(dst, s), swizzle=swizzle)
|
||||
x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
|
||||
smem = jax.ShapeDtypeStruct(tile_shape(shape, tiling), dtype)
|
||||
f = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem)
|
||||
y = f(x)
|
||||
np.testing.assert_array_equal(y, x)
|
||||
|
||||
@parameterized.product(
|
||||
swizzle=(None, 128),
|
||||
dtype=(jnp.float16, jnp.float32),
|
||||
)
|
||||
def test_tma_squeeze_indexing(self, swizzle, dtype):
|
||||
shape = (4, 5, 64)
|
||||
if dtype == jnp.float32:
|
||||
shape = (*shape[:-1], shape[-1] // 2)
|
||||
def kernel(ctx, src, dst, tmp):
|
||||
barrier = BarrierArray(1)[0]
|
||||
for i in range(4):
|
||||
ctx.async_copy(
|
||||
src_ref=src,
|
||||
dst_ref=memref_slice(tmp, i),
|
||||
gmem_slice=i,
|
||||
swizzle=swizzle,
|
||||
barrier=barrier,
|
||||
)
|
||||
barrier.wait()
|
||||
copy(tmp, dst, swizzle=swizzle)
|
||||
x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
|
||||
y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x)
|
||||
np.testing.assert_array_equal(y, x)
|
||||
|
||||
def test_parity_tracking(self):
|
||||
shape = (16, 64)
|
||||
index = ir.IndexType.get()
|
||||
def kernel(ctx, src, dst, tmp):
|
||||
barrier = BarrierArray(1)[0]
|
||||
for i in range(shape[0]):
|
||||
s = ds(c(i, index), 1)
|
||||
ctx.async_copy(
|
||||
src_ref=src, dst_ref=tmp, gmem_slice=s, barrier=barrier,
|
||||
)
|
||||
barrier.wait()
|
||||
copy(tmp, memref_slice(dst, s))
|
||||
x = np.arange(np.prod(shape), dtype=jnp.float16).reshape(shape)
|
||||
y = mosaic_gpu.as_gpu_kernel(
|
||||
kernel, (1, 1, 1), (128, 1, 1), x, x, x[0:1]
|
||||
)(x)
|
||||
np.testing.assert_array_equal(y, x)
|
||||
|
||||
@parameterized.product(
|
||||
swizzle=(None, 128),
|
||||
shape=((64, 64), (5, 64), (2, 3, 5, 64)),
|
||||
dtype=(jnp.float16, jnp.float32),
|
||||
)
|
||||
def test_tma_store(self, swizzle, shape, dtype):
|
||||
if dtype == jnp.float32:
|
||||
shape = (*shape[:-1], shape[-1] // 2)
|
||||
def kernel(ctx, src, dst, tmp):
|
||||
copy(src, tmp, swizzle=swizzle)
|
||||
ctx.async_copy(src_ref=tmp, dst_ref=dst, swizzle=swizzle)
|
||||
ctx.await_async_copy(0)
|
||||
x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
|
||||
y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x)
|
||||
np.testing.assert_array_equal(y, x)
|
||||
|
||||
|
||||
class FragmentedArrayTest(TestCase):
|
||||
|
||||
@parameterized.product(
|
||||
op=(
|
||||
operator.add,
|
||||
operator.mul,
|
||||
operator.sub,
|
||||
operator.truediv,
|
||||
(lambda x, y: mgpu.FragmentedArray.max(x, y), np.maximum),
|
||||
),
|
||||
m=(64, 128),
|
||||
n=(8, 16, 32, 64, 80, 128, 256),
|
||||
)
|
||||
def test_binary(self, op, m=64, n=32):
|
||||
if isinstance(op, tuple):
|
||||
op, np_op = op
|
||||
else:
|
||||
np_op = op
|
||||
def kernel(ctx, dst, _):
|
||||
f32 = ir.F32Type.get()
|
||||
iota = iota_tensor(m=m, n=n, mlir_dtype=f32)
|
||||
op(iota, iota).store_untiled(dst)
|
||||
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
|
||||
result = mosaic_gpu.as_gpu_kernel(
|
||||
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
|
||||
)()
|
||||
x = np.arange(m * n, dtype=jnp.float32).reshape(m, n)
|
||||
if op == operator.truediv:
|
||||
np.testing.assert_allclose(result, np_op(x, x), atol=2e-7)
|
||||
else:
|
||||
np.testing.assert_array_equal(result, np_op(x, x))
|
||||
|
||||
@parameterized.product(
|
||||
ops=((lambda x: mgpu.FragmentedArray.exp(x), np.exp),),
|
||||
m=(64, 128),
|
||||
n=(8, 16, 32, 64, 80, 128, 256),
|
||||
)
|
||||
def test_unary(self, ops, m=64, n=32):
|
||||
op, np_op = ops
|
||||
def kernel(ctx, dst, _):
|
||||
f32 = ir.F32Type.get()
|
||||
iota = iota_tensor(m=m, n=n, mlir_dtype=f32)
|
||||
op(iota).store_untiled(dst)
|
||||
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
|
||||
result = mosaic_gpu.as_gpu_kernel(
|
||||
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
|
||||
)()
|
||||
x = np.arange(m * n, dtype=jnp.float32).reshape(m, n)
|
||||
np.testing.assert_allclose(result, np_op(x), atol=2e-7, rtol=2e-7)
|
||||
|
||||
@parameterized.product(
|
||||
op=(arith.addf, arith.maximumf),
|
||||
m=(64, 128),
|
||||
n=(8, 16, 32, 64, 80, 128, 256),
|
||||
)
|
||||
def test_reduce(self, op, m=64, n=32):
|
||||
def kernel(ctx, dst, _):
|
||||
f32 = ir.F32Type.get()
|
||||
iota = iota_tensor(m=m, n=n, mlir_dtype=f32)
|
||||
iota.reduce(op, axis=1).broadcast_minor(n).store_untiled(dst)
|
||||
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
|
||||
result = mosaic_gpu.as_gpu_kernel(
|
||||
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
|
||||
)()
|
||||
x = np.arange(m * n, dtype=jnp.float32).reshape(m, n)
|
||||
if op == arith.addf:
|
||||
expected = np.broadcast_to(x.sum(axis=1, keepdims=True), x.shape)
|
||||
elif op == arith.maximumf:
|
||||
expected = np.broadcast_to(x.max(axis=1, keepdims=True), x.shape)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported op: {op}")
|
||||
np.testing.assert_array_equal(result, expected)
|
||||
|
||||
def test_splat(self):
|
||||
def kernel(ctx, dst, _):
|
||||
f32 = ir.F32Type.get()
|
||||
v = arith.constant(f32, ir.FloatAttr.get(f32, 3.14))
|
||||
t = mgpu.FragmentedArray.splat(v, (128,), mgpu.WGMMA_ROW_LAYOUT)
|
||||
t.broadcast_minor(32).store_untiled(dst)
|
||||
out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32)
|
||||
result = mosaic_gpu.as_gpu_kernel(
|
||||
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
|
||||
)()
|
||||
np.testing.assert_array_equal(result, np.full((128, 32), 3.14, np.float32))
|
||||
|
||||
@parameterized.product(in_shape=((128, 128), (128, 64), (64, 128)))
|
||||
def test_strided_load_store(self, in_shape):
|
||||
def kernel(ctx, *args):
|
||||
gmem_input, gmem_output, (smem_input, smem_output) = args
|
||||
copy(gmem_input, smem_input)
|
||||
t = mgpu.FragmentedArray.load_strided(smem_input)
|
||||
t.store_untiled(smem_output)
|
||||
copy(smem_output, gmem_output)
|
||||
|
||||
inp = out = self.prng.uniform(-1, 1, in_shape).astype(jnp.float32)
|
||||
result = mosaic_gpu.as_gpu_kernel(
|
||||
kernel, (1, 1, 1), (128, 1, 1), (inp,), out, [inp, out],
|
||||
)(inp)
|
||||
np.testing.assert_array_equal(inp, result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())