mirror of https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00

Added a very rough sketch of Mosaic GPU lowering for Pallas

Almost nothing is supported, including:
* PyTree inputs/outputs
* indexers
* non-trivial grids
* block specs
* any primitives beyond the ones added here
* etc etc

PiperOrigin-RevId: 633713366

This commit is contained in:
parent 0ad5167da8
commit e2918ca138
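
For orientation, a minimal end-to-end sketch of what this commit wires up. It mirrors the new test at the bottom of this diff and assumes an sm_90 GPU plus the opt-in flag added to pallas_call.py below:

    import functools

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    # Opt in to the new backend (flag defined in pallas_call.py below).
    jax.config.update("jax_pallas_use_mosaic_gpu", True)

    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
    )
    def add_one(x_ref, o_ref):
      o_ref[...] = x_ref[...] + 1.0

    x = jnp.arange(256, dtype=jnp.float32)
    print(add_one(x))  # lowered via Mosaic GPU instead of Triton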
@@ -19,6 +19,7 @@ load("@rules_python//python:defs.bzl", "py_library")
 load(
     "//jaxlib:jax.bzl",
     "if_building_jaxlib",
+    "if_building_mosaic_gpu",
     "jax_extend_internal_users",
     "jax_extra_deps",
     "jax_internal_export_back_compat_test_util_visibility",
@@ -669,7 +670,7 @@ pytype_strict_library(
     deps = [
         ":pallas",
         "//jax/_src/pallas/triton",
-    ],
+    ] + if_building_mosaic_gpu(["//third_party/py/jax/_src/pallas/mosaic_gpu"]),
 )

 # This target only supports sm_90 GPUs.
jax/_src/pallas/mosaic_gpu/BUILD (Normal file, 60 lines)
@@ -0,0 +1,60 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Package for Mosaic-specific Pallas extensions

load("//jaxlib:jax.bzl", "pytype_strict_library")
load("@rules_python//python:defs.bzl", "py_library")

package(
    default_applicable_licenses = [],
    default_visibility = [
        "//third_party/py/jax:internal",
    ],
)

py_library(
    name = "mosaic_gpu",
    srcs = ["__init__.py"],
    deps = [
        ":pallas_call_registration",
    ],
)

pytype_strict_library(
    name = "pallas_call_registration",
    srcs = ["pallas_call_registration.py"],
    deps = [
        ":lowering",
        "//third_party/py/jax",
        "//third_party/py/jax:mlir",
        "//third_party/py/jax:mosaic_gpu",
        "//third_party/py/jax/_src/pallas",
    ],
)

pytype_strict_library(
    name = "lowering",
    srcs = ["lowering.py"],
    deps = [
        "//third_party/py/jax",
        "//third_party/py/jax:core",
        "//third_party/py/jax:mlir",
        "//third_party/py/jax:mosaic_gpu",
        "//third_party/py/jax:util",
        "//third_party/py/jax/_src/lib",
        "//third_party/py/jax/_src/pallas",
        "//third_party/py/numpy",
    ],
)
jax/_src/pallas/mosaic_gpu/__init__.py (Normal file, 13 lines)
@@ -0,0 +1,13 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
jax/_src/pallas/mosaic_gpu/lowering.py (Normal file, 297 lines)
@@ -0,0 +1,297 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for lowering JAX primitives to Mosaic GPU."""

from __future__ import annotations

from collections.abc import Sequence
import dataclasses
import math
from typing import Any, cast

import jax
from jax._src import core as jax_core
from jax._src import util
from jax._src.interpreters import mlir
from jax._src.lax import lax
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith as arith_dialect
from jax._src.lib.mlir.dialects import memref as memref_dialect
from jax._src.lib.mlir.dialects import nvgpu as nvgpu_dialect
from jax._src.pallas import core as pl_core
from jax._src.state import primitives as sp
from jax.experimental.mosaic import gpu as mosaic_gpu
from jax.experimental.mosaic.gpu import dsl as mgpu
import numpy as np

# TODO(slebedev): Enable type checking.
# mypy: ignore-errors
# pytype: skip-file

map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip


@dataclasses.dataclass
class ModuleContext:
  name: str
  grid_mapping: pl_core.GridMapping


@dataclasses.dataclass
class LoweringRuleContext:
  context: ModuleContext
  avals_in: Sequence[jax_core.ShapedArray]
  avals_out: Sequence[jax_core.ShapedArray]
  block_shapes: list[tuple[int | pl_core.Mapped, ...]] | None

  replace = dataclasses.replace


@dataclasses.dataclass
class LoweringResult:
  module: ir.Module
  grid: tuple[int, ...]
  gmem_scratch_bytes: int
  out_structs: tuple[jax.ShapeDtypeStruct, ...]


@dataclasses.dataclass
class BlockInfo:
  full_shape_dtype: jax.ShapeDtypeStruct
  start_indices: Sequence[Any]
  block_shape: tuple[int, ...]


class LoweringError(Exception):
  pass


def lower_jaxpr_to_module(
    grid_mapping: pl_core.GridMapping,
    in_structs: tuple[jax.ShapeDtypeStruct, ...],
    out_structs: tuple[jax.ShapeDtypeStruct, ...],
    jaxpr: jax_core.Jaxpr,
    name: str,
) -> LoweringResult:
  assert len(jaxpr.outvars) == 0
  assert not grid_mapping.mapped_dims
  grid = grid_mapping.grid
  if len(grid) < 3:
    grid += (1,) * (3 - len(grid))
  block = (128,) + (1,) * (len(grid) - 1)

  def body(ctx: mosaic_gpu.LaunchContext, *buffers):
    *buffers_gmem, buffers_smem = buffers
    assert len(buffers_gmem) == len(buffers_smem)
    in_buffers_gmem = buffers_gmem[: len(in_structs)]
    in_buffers_smem = buffers_smem[: len(in_structs)]
    out_buffers_gmem = buffers_gmem[len(in_structs) :]
    out_buffers_smem = buffers_smem[len(in_structs) :]

    # arrival_count= determines the expected number of arrivals for each
    # barrier in the array. It is not accidental that we do just a single
    # mbarrier_arrive_expect_tx below.
    # TODO(slebedev): Consider enforcing this in the mgpu.BarrierArray.
    [barrier] = mgpu.BarrierArray(1, arrival_count=1)

    with mgpu.once():
      nvgpu_dialect.mbarrier_arrive_expect_tx(
          barrier.barrier_array.value,
          _index(
              sum(math.prod(s.shape) * s.dtype.itemsize for s in in_structs)
          ),
          barrier.offset,
      )

    for b_gmem, b_smem in zip(in_buffers_gmem, in_buffers_smem):
      # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls.
      ctx.async_copy(
          src_ref=b_gmem,
          dst_ref=b_smem,
          barrier=barrier,
          swizzle=None,
          arrive=False,
          uniform=False,
      )

    barrier.wait()

    module_ctx = ModuleContext(name, grid_mapping)
    _ = lower_jaxpr_to_mosaic_gpu(module_ctx, jaxpr, None, *buffers_smem)

    for b_gmem, b_smem in zip(out_buffers_gmem, out_buffers_smem):
      # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls.
      ctx.async_copy(src_ref=b_smem, dst_ref=b_gmem, swizzle=None)

    ctx.await_async_copy(0)

  module, out_structs, gmem_scratch_bytes, _ = mosaic_gpu._lower_as_gpu_kernel(
      body,
      grid,
      block,
      in_shape=in_structs,
      out_shape=out_structs,
      smem_scratch_shape=in_structs + out_structs,
  )

  return LoweringResult(module, grid, gmem_scratch_bytes, out_structs)


mosaic_lowering_rules = {}


def register_lowering_rule(primitive: jax_core.Primitive):
  def deco(fn):
    mosaic_lowering_rules[primitive] = fn
    return fn

  return deco


def lower_jaxpr_to_mosaic_gpu(
    ctx: ModuleContext,
    jaxpr: jax_core.Jaxpr,
    block_infos: Sequence[BlockInfo | None] | None,
    *args,
) -> Sequence[ir.Value]:
  env = {}
  block_info_env = {}

  def read_env(atom: jax_core.Atom):
    return atom.val if isinstance(atom, jax_core.Literal) else env[atom]

  def read_block_info_env(atom: jax_core.Atom):
    if isinstance(atom, jax_core.Literal):
      return None
    return block_info_env.get(atom, None)

  def write_env(var: jax_core.Var, val):
    env[var] = val

  if block_infos is None:
    block_infos = [None] * len(jaxpr.invars)
  for invar, block_info in zip(jaxpr.invars, block_infos):
    block_info_env[invar] = block_info
  map(write_env, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    invals = map(read_env, eqn.invars)
    if eqn.primitive not in mosaic_lowering_rules:
      raise NotImplementedError(
          "Unimplemented primitive in Pallas Mosaic GPU lowering: "
          f"{eqn.primitive.name}. "
          "Please file an issue on https://github.com/google/jax/issues."
      )
    rule = mosaic_lowering_rules[eqn.primitive]
    rule_ctx = LoweringRuleContext(
        ctx,
        avals_in=[cast(jax_core.ShapedArray, v.aval) for v in eqn.invars],
        avals_out=[cast(jax_core.ShapedArray, v.aval) for v in eqn.outvars],
        block_shapes=map(read_block_info_env, eqn.invars),
    )
    try:
      outvals = rule(rule_ctx, *invals, **eqn.params)
    except LoweringError:
      raise  # We only add the extra info to the innermost exception.
    except Exception as e:
      inval_types = map(lambda t: getattr(t, "type", None), invals)
      raise LoweringError(
          f"Exception while lowering eqn:\n  {eqn}\nWith context:\n "
          f" {rule_ctx}\nWith inval types={inval_types}\nIn jaxpr:\n{jaxpr}"
      ) from e
    if eqn.primitive.multiple_results:
      map(write_env, eqn.outvars, outvals)
    else:
      write_env(eqn.outvars[0], outvals)
  return map(read_env, jaxpr.outvars)


@register_lowering_rule(sp.get_p)
def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *indexers, tree):
  del tree  # Unused.
  if indexers:
    raise NotImplementedError("No support for indexers yet")
  return mgpu.FragmentedArray.load_strided(x_smem)


@register_lowering_rule(sp.swap_p)
def _swap_lowering_rule(
    ctx: LoweringRuleContext, x_smem, value, *indexers, tree
):
  del tree  # Unused.
  if indexers:
    raise NotImplementedError("No support for indexers yet")
  old_value = mgpu.FragmentedArray.load_strided(x_smem)
  value.store_untiled(x_smem)
  return old_value


@register_lowering_rule(lax.add_p)
def _add_lowering_rule(ctx: LoweringRuleContext, x, y):
  x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
  return x + y


def _bcast_to(a: Any, shape: tuple[int, ...]) -> ir.Value:
  if not isinstance(a, mgpu.FragmentedArray):
    if not shape:
      return a
    layout = mgpu.WGStridedFragLayout.from_memref_type(
        memref_dialect.MemRefType.get(shape, a.type)
    )
    return mgpu.FragmentedArray.splat(a, shape, layout)
  else:
    if a.shape == shape:
      return a
    raise NotImplementedError


def _bcast(
    x: ir.Value,
    y: ir.Value,
    x_aval: jax_core.ShapedArray,
    y_aval: jax_core.ShapedArray,
    out_aval: jax_core.ShapedArray,
) -> tuple[ir.Value, ir.Value]:
  if isinstance(x, (np.ndarray, np.number, int, float)):
    x_dtype = x_aval.dtype
    if x_aval.weak_type:
      x_dtype = y_aval.dtype
    x = _ir_constant(x, mlir.dtype_to_ir_type(x_dtype))
  if isinstance(y, (np.ndarray, np.number, int, float)):
    y_dtype = y_aval.dtype
    if y_aval.weak_type:
      y_dtype = x_aval.dtype
    y = _ir_constant(y, mlir.dtype_to_ir_type(y_dtype))
  if x_aval.shape != out_aval.shape:
    x = _bcast_to(x, out_aval.shape)
  if y_aval.shape != out_aval.shape:
    y = _bcast_to(y, out_aval.shape)
  return x, y


def _ir_constant(v: object, t: ir.Type) -> ir.Value:
  if isinstance(v, (np.number, np.ndarray, int, float)):
    if isinstance(t, ir.IntegerType):
      v = int(v)
    else:
      assert isinstance(t, ir.FloatType)
      v = float(v)
    return arith_dialect.constant(t, v)
  raise NotImplementedError(f"Unsupported constant: {v!r}")


def _index(i: int) -> ir.Value:
  return arith_dialect.constant(ir.IndexType.get(), i)
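
Not part of this commit, but a sketch of how the rule registry above is meant to grow: a lowering rule for lax.mul_p that mirrors _add_lowering_rule, assuming mgpu.FragmentedArray overloads elementwise `*` the way it overloads `+`:

    from jax._src.lax import lax
    from jax._src.pallas.mosaic_gpu import lowering

    @lowering.register_lowering_rule(lax.mul_p)
    def _mul_lowering_rule(ctx: lowering.LoweringRuleContext, x, y):
      # Reuse the commit's broadcasting helper before the elementwise op.
      x, y = lowering._bcast(x, y, *ctx.avals_in, *ctx.avals_out)
      return x * y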
jax/_src/pallas/mosaic_gpu/pallas_call_registration.py (Normal file, 90 lines)
@@ -0,0 +1,90 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module registering a lowering rule for pallas_call on GPU."""


from __future__ import annotations

from typing import Any

import jax
from jax import core as jax_core
from jax._src.interpreters import mlir
from jax._src.pallas import core as pallas_core
from jax._src.pallas.mosaic_gpu import lowering
from jax._src.pallas.pallas_call import pallas_call_p
from jax.experimental.mosaic import gpu as mosaic_gpu


def pallas_call_lowering(
    ctx: mlir.LoweringRuleContext,
    *args,
    jaxpr: jax_core.Jaxpr,
    name: str,
    in_shapes: tuple[jax.ShapeDtypeStruct, ...],
    out_shapes: tuple[jax.ShapeDtypeStruct, ...],
    which_linear: tuple[bool, ...],
    interpret: bool,
    debug: bool,
    input_output_aliases: tuple[tuple[int, int], ...],
    grid_mapping: pallas_core.GridMapping,
    compiler_params: dict[str, Any],
):
  if interpret:
    return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
        ctx,
        *args,
        jaxpr=jaxpr,
        name=name,
        out_shapes=out_shapes,
        in_shapes=in_shapes,
        which_linear=which_linear,
        interpret=interpret,
        debug=debug,
        input_output_aliases=input_output_aliases,
        grid_mapping=grid_mapping,
        compiler_params=compiler_params,
    )

  if grid_mapping.num_dynamic_grid_bounds:
    raise NotImplementedError(
        "dynamic grid bounds not supported in the Mosaic GPU backend"
    )
  if input_output_aliases:
    raise NotImplementedError(
        "input_output_aliases not supported in the Mosaic GPU backend"
    )

  if debug:
    print(jaxpr)
    print(grid_mapping)

  lowering_result = lowering.lower_jaxpr_to_module(
      grid_mapping,
      in_shapes,
      out_shapes,
      jaxpr,
      name,
  )
  if debug:
    print(lowering_result.module.operation)

  return mosaic_gpu._mosaic_gpu_lowering_rule(
      ctx,
      *args,
      module=lowering_result.module,
      gmem_scratch_bytes=lowering_result.gmem_scratch_bytes,
      out_types=lowering_result.out_structs,
  )
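
Usage note: the interpret=True branch above bypasses the Mosaic GPU path entirely and runs the kernel through the Pallas interpreter, which is handy for debugging on any backend. A minimal sketch:

    import functools

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
        interpret=True,  # takes the mlir.lower_fun(pallas_call_p.impl) branch
    )
    def add_one(x_ref, o_ref):
      o_ref[...] = x_ref[...] + 1.0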
jax/_src/pallas/pallas_call.py
@@ -15,17 +15,16 @@
 """Module for calling pallas functions from JAX."""
 from __future__ import annotations

-import itertools
-from functools import partial
-from functools import reduce
-
-from typing import Any, Callable
 from collections.abc import Sequence
+import itertools
+from functools import partial, reduce
+from typing import Any, Callable

 import jax
 from jax import api_util
 from jax import tree_util
 from jax import lax
+from jax._src import config
 from jax._src import state
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
@@ -532,6 +531,16 @@ def _extract_function_name(f: Callable, name: str | None) -> str:
   return name


+_PALLAS_USE_MOSAIC_GPU = config.DEFINE_bool(
+    "jax_pallas_use_mosaic_gpu",
+    default=config.bool_env("JAX_PALLAS_USE_MOSAIC_GPU", False),
+    help=(
+        "If True, lower Pallas kernels to the experimental Mosaic GPU"
+        " dialect, instead of Triton IR."
+    ),
+)
+
+
 def _unsupported_lowering_error(platform: str) -> Exception:
   return ValueError(
       f"Cannot lower pallas_call on platform: {platform}. To use Pallas on GPU,"
@@ -560,7 +569,10 @@ def _pallas_call_lowering(
     raise ValueError("Only interpret mode is supported on CPU backend.")
   elif platform == "cuda" or platform == "rocm":
     try:
-      from jax._src.pallas.triton import pallas_call_registration  # type: ignore
+      if _PALLAS_USE_MOSAIC_GPU.value:
+        from jax._src.pallas.mosaic_gpu import pallas_call_registration  # type: ignore
+      else:
+        from jax._src.pallas.triton import pallas_call_registration  # type: ignore
     except ImportError:
       pass
   else:
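
Because the flag's default is read through config.bool_env, the backend can also be enabled from the environment with no code changes. A hypothetical session (the variable must be set before jax is imported, since the default is captured at module import time):

    import os
    os.environ["JAX_PALLAS_USE_MOSAIC_GPU"] = "1"

    import jax  # _PALLAS_USE_MOSAIC_GPU now defaults to True
    assert jax.config.jax_pallas_use_mosaic_gpu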
jax/experimental/mosaic/gpu/__init__.py
@@ -530,7 +530,7 @@ def _launch(
   gpu.terminator()


-def as_gpu_kernel(
+def _lower_as_gpu_kernel(
    body,
    grid: tuple[int, ...],
    block: tuple[int, ...],
@@ -609,6 +609,31 @@ def as_gpu_kernel(
   main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
   module.operation.verify()

+  dump_low_level(module)
+
+  pass_manager = _get_mosaic_gpu_pipeline("fatbin")
+  if mosaic_gpu_print_after_all.value:
+    pass_manager.enable_ir_printing()
+  pass_manager.run(module.operation)
+
+  return module, out_shape, gmem_scratch_bytes, unwrap_output_tuple
+
+
+def as_gpu_kernel(
+    body,
+    grid: tuple[int, ...],
+    block: tuple[int, ...],
+    in_shape,
+    out_shape,
+    smem_scratch_shape,
+    prof_spec: profiler.ProfilerSpec | None = None,
+):
+  module, out_shape, gmem_scratch_bytes, unwrap_output_tuple = (
+      _lower_as_gpu_kernel(
+          body, grid, block, in_shape, out_shape, smem_scratch_shape, prof_spec
+      )
+  )
+
   expected_arg_treedef = jax.tree.structure(in_shape)
   def _check_args(args):
     arg_treedef = jax.tree.structure(args)
@@ -618,13 +643,6 @@ def as_gpu_kernel(
           f" {arg_treedef}"
       )

-  dump_low_level(module)
-
-  pass_manager = _get_mosaic_gpu_pipeline("fatbin")
-  if mosaic_gpu_print_after_all.value:
-    pass_manager.enable_ir_printing()
-  pass_manager.run(module.operation)
-
   def bind(*args):
     return mosaic_gpu_p.bind(
         *args,
jax/experimental/mosaic/gpu/utils.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Utilities for code generator."""

+from collections.abc import Iterator
 import contextlib
 import dataclasses
 from typing import Any, Literal, Sequence
@@ -496,6 +497,7 @@ class BarrierArray:
           f" num_barriers={num_barriers}>"
       )

+    self.num_barriers = num_barriers
     self.value = nvgpu.mbarrier_create(barrier_group_ty)
-    self.num_barriers = num_barriers
     index = ir.IndexType.get()
@@ -508,6 +510,10 @@ class BarrierArray:
     for i in range(num_barriers):
       nvgpu.mbarrier_init(self.value, c(arrival_count, index), c(i, index))

+  def __iter__(self) -> Iterator["Barrier"]:
+    for offset in range(self.num_barriers):
+      yield self[offset]
+
   def __getitem__(self, offset: ir.Value | int):
     if isinstance(offset, int):
       offset = c(offset, ir.IndexType.get())
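
The new __iter__ is what lets callers destructure a BarrierArray, which is exactly how the Pallas lowering in this commit consumes it. A small sketch (valid only under an active MLIR context and insertion point, since construction emits nvgpu ops):

    from jax.experimental.mosaic.gpu import dsl as mgpu

    [barrier] = mgpu.BarrierArray(1, arrival_count=1)  # unpack the only barrier
    for b in mgpu.BarrierArray(4, arrival_count=1):    # or walk all four
      ...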
tests/pallas/BUILD
@@ -254,6 +254,42 @@ jax_test(
     ] + py_deps("absl/testing") + py_deps("absl/flags") + py_deps("numpy") + py_deps("hypothesis"),
 )

+jax_test(
+    name = "mosaic_gpu_test",
+    srcs = [
+        "mosaic_gpu_test.py",
+    ],
+    config_tags_overrides = {
+        # TODO(slebedev): Switch to False once Mosaic GPU is unconditionally enabled.
+        "gpu_h100_x32": {
+            "ondemand": True,  # Include in presubmit.
+        },
+    },
+    disable_backends = [
+        "cpu",
+        "tpu",
+    ],
+    disable_configs = [
+        "gpu",
+        "gpu_x32",
+        "gpu_a100",
+        "gpu_a100_x32",
+        "gpu_p100",
+        "gpu_p100_x32",
+        "gpu_h100",
+    ],
+    enable_configs = [
+        "gpu_h100_x32",
+    ],
+    env = {
+        "JAX_PALLAS_USE_MOSAIC_GPU": "1",
+    },
+    deps = [
+        "//jax:pallas",
+        "//jax:pallas_gpu",  # build_cleaner: keep
+    ] + py_deps("absl/testing") + py_deps("numpy"),
+)
+
 jax_test(
     name = "export_back_compat_pallas_test",
     srcs = ["export_back_compat_pallas_test.py"],
tests/pallas/mosaic_gpu_test.py (Normal file, 53 lines)
@@ -0,0 +1,53 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools

from absl.testing import absltest
import jax
from jax._src import test_util as jtu
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np


jax.config.parse_flags_with_absl()


class PallasTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()

    if not jtu.is_cuda_compute_capability_at_least("9.0"):
      self.skipTest("Only works on a GPU with capability >= sm90")


class PallasCallTest(PallasTest):

  def test_add_one(self):
    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
    )
    def add_one(x_ref, o_ref):
      print(">>>", x_ref, o_ref)
      o_ref[...] = x_ref[...] + 1.0

    x = jnp.arange(256).astype(jnp.float32)
    np.testing.assert_array_equal(add_one(x), x + 1.0)


if __name__ == "__main__":
  absltest.main()