Merge pull request #299 from ROCm/ci-upstream-sync-152_1

CI: 03/19/25 upstream sync
rocm-repo-management-api-2[bot] 2025-03-19 07:20:19 -05:00 committed by GitHub
commit b505df9973
42 changed files with 1048 additions and 649 deletions

View File

@ -1,8 +1,8 @@
diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt
index dfefaf042..2700e140e 100644 index e7a2968e9..d37e11ee3 100644
--- a/build/requirements_lock_3_13_ft.txt --- a/build/requirements_lock_3_13_ft.txt
+++ b/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt
@@ -4,6 +4,12 @@ @@ -4,6 +4,11 @@
# #
# pip-compile --allow-unsafe --generate-hashes --output-file=build/requirements_lock_3_13_ft.txt build/requirements.in # pip-compile --allow-unsafe --generate-hashes --output-file=build/requirements_lock_3_13_ft.txt build/requirements.in
# #
@ -10,12 +10,11 @@ index dfefaf042..2700e140e 100644
+--pre +--pre
+--extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple +--extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple
+numpy +numpy
+
+ +
absl-py==2.1.0 \ absl-py==2.1.0 \
--hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \
--hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff
@@ -328,68 +334,6 @@ mpmath==1.3.0 \ @@ -328,68 +333,6 @@ mpmath==1.3.0 \
--hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \
--hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c --hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c
# via -r build/test-requirements.txt # via -r build/test-requirements.txt
@ -81,6 +80,6 @@ index dfefaf042..2700e140e 100644
- # matplotlib - # matplotlib
- # ml-dtypes - # ml-dtypes
- # scipy - # scipy
opt-einsum==3.4.0 \ nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \
--hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \
--hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \

View File

@ -173,12 +173,18 @@ jobs:
--bazel_options=--copt=-g \ --bazel_options=--copt=-g \
--clang_path=/usr/bin/clang-18 --clang_path=/usr/bin/clang-18
# Update the patch to use TSAN instrumented numpy # Patch build/requirements_lock_3_13_ft.txt to use TSAN instrumented NumPy
sed -i "s|+--extra-index-url.*|+--extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/|" .github/workflows/requirements_lock_3_13_ft.patch sed -i "s|+--extra-index-url.*|+--extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/|" .github/workflows/requirements_lock_3_13_ft.patch
cat .github/workflows/requirements_lock_3_13_ft.patch cat .github/workflows/requirements_lock_3_13_ft.patch
git apply .github/workflows/requirements_lock_3_13_ft.patch || exit 1
# Apply a patch to numpy in requirements lock 3.13 ft to use the nightly version # Display the content for debugging in logs
git apply .github/workflows/requirements_lock_3_13_ft.patch cat build/requirements_lock_3_13_ft.txt | head -15
# Check the patch
cat build/requirements_lock_3_13_ft.txt | head -15 | grep -E "(--pre|.*${GITHUB_WORKSPACE}/wheelhouse/|numpy)"
if [ "$?" == "1" ]; then echo "Could not find the patch in the requirements_lock_3_13_ft.txt"; exit 1; fi
cat build/requirements_lock_3_13_ft.txt | grep -E "(numpy==)"
if [ "$?" == "0" ]; then "Found original numpy dependency in the requirements_lock_3_13_ft.txt"; exit 1; fi
echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES" echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
echo "JAX_ENABLE_X64=$JAX_ENABLE_X64" echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
@ -188,6 +194,13 @@ jobs:
bazel_exec=($(ls bazel-*)) bazel_exec=($(ls bazel-*))
ln -s ${bazel_exec} bazel ln -s ${bazel_exec} bazel
# Check python version
./bazel run --@rules_python//python/config_settings:py_freethreaded="yes" @python//:python3 -- -VV
# Check numpy version
./bazel cquery @pypi_numpy//:* | grep whl
# Build JAX and run tests
./bazel test \ ./bazel test \
--test_env=JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES \ --test_env=JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES \
--test_env=JAX_ENABLE_X64=$JAX_ENABLE_X64 \ --test_env=JAX_ENABLE_X64=$JAX_ENABLE_X64 \

View File

@ -33,7 +33,6 @@ from jax._src import errors
from jax._src import profiler from jax._src import profiler
from jax._src import util from jax._src import util
from jax._src import xla_bridge from jax._src import xla_bridge
from jax._src.mesh import use_concrete_mesh
from jax._src.interpreters import mlir from jax._src.interpreters import mlir
from jax._src.interpreters import pxla from jax._src.interpreters import pxla
from jax._src.interpreters import xla from jax._src.interpreters import xla
@ -43,7 +42,8 @@ from jax._src.lib import xla_extension as xe
from jax._src.sharding import Sharding from jax._src.sharding import Sharding
from jax._src.sharding_impls import ( from jax._src.sharding_impls import (
PmapSharding, SingleDeviceSharding, PmapSharding, SingleDeviceSharding,
device_replica_id_map, hashed_index, num_addressable_indices, local_to_global_shape) # pyformat: disable device_replica_id_map, hashed_index, num_addressable_indices,
local_to_global_shape, use_concrete_mesh) # pyformat: disable
from jax._src.typing import ArrayLike, DLDeviceType, DTypeLike from jax._src.typing import ArrayLike, DLDeviceType, DTypeLike
from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method, cache from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method, cache
import numpy as np import numpy as np

View File

@ -29,8 +29,8 @@ class SampleFn(Protocol):
def _compute_tile_index(block_index: Sequence[int], def _compute_tile_index(block_index: Sequence[int],
total_size_in_blocks: Shape,
block_size_in_tiles: Shape, block_size_in_tiles: Shape,
total_size_in_tiles: Shape,
tile_index_in_block: Sequence[int]) -> int: tile_index_in_block: Sequence[int]) -> int:
ndims = len(block_index) ndims = len(block_index)
dim_size = 1 dim_size = 1
@ -38,7 +38,7 @@ def _compute_tile_index(block_index: Sequence[int],
for i in range(ndims-1, -1, -1): for i in range(ndims-1, -1, -1):
dim_idx = tile_index_in_block[i] + block_index[i] * block_size_in_tiles[i] dim_idx = tile_index_in_block[i] + block_index[i] * block_size_in_tiles[i]
total_idx += dim_idx * dim_size total_idx += dim_idx * dim_size
dim_size *= total_size_in_blocks[i] * block_size_in_tiles[i] dim_size *= total_size_in_tiles[i]
return total_idx return total_idx
@ -103,15 +103,17 @@ def blocked_fold_in(
_shape // _element for _shape, _element in zip(block_size, tile_size) _shape // _element for _shape, _element in zip(block_size, tile_size)
) )
total_size_in_blocks = tuple( # Round up to make sure every tile is numbered.
_shape // _element for _shape, _element in zip(total_size, block_size) total_size_in_tiles = tuple(
(_shape + _element - 1) // _element
for _shape, _element in zip(total_size, tile_size)
) )
def _keygen_loop(axis, prefix): def _keygen_loop(axis, prefix):
if axis == len(block_size_in_tiles): if axis == len(block_size_in_tiles):
subtile_key = jax.random.fold_in( subtile_key = jax.random.fold_in(
global_key, _compute_tile_index( global_key, _compute_tile_index(
block_index, total_size_in_blocks, block_size_in_tiles, prefix)) block_index, block_size_in_tiles, total_size_in_tiles, prefix))
return subtile_key return subtile_key
else: else:
keys = [] keys = []
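A quick numeric illustration of the ceil-division introduced above (hypothetical sizes, not taken from the change): rounding up ensures a partially covered final tile still receives an index.

total_size = (33, 16)   # hypothetical array shape
tile_size = (8, 8)
# Round up so the ragged last tile along axis 0 is still numbered.
total_size_in_tiles = tuple(
    (s + t - 1) // t for s, t in zip(total_size, tile_size)
)
assert total_size_in_tiles == (5, 2)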

View File

@ -179,7 +179,7 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
for sharding, s in zip(result_shardings, result_shapes) for sharding, s in zip(result_shardings, result_shapes)
] ]
closed_jaxpr = jax.make_jaxpr(lower_fn, axis_env=list(mesh.shape.items()))( closed_jaxpr = jax.make_jaxpr(lower_fn, axis_env=list(mesh.shape.items()))(
*tiled_args *info.in_tree.unflatten(tiled_args)
) )
if ([(o.shape, o.dtype) for o in closed_jaxpr.out_avals] != if ([(o.shape, o.dtype) for o in closed_jaxpr.out_avals] !=
[(t.shape, t.dtype) for t in tiled_results]): [(t.shape, t.dtype) for t in tiled_results]):

View File

@ -41,7 +41,7 @@ from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval,
JaxprEqn, Primitive, ShapedArray, DShapedArray, JaxprEqn, Primitive, ShapedArray, DShapedArray,
mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx, mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
InputType, OutputType, get_referent, JaxprEqnContext) InputType, OutputType, get_referent, JaxprEqnContext)
from jax._src.state.types import AbstractRef from jax._src.state.types import AbstractRef, ReadEffect
from jax._src.tree_util import (PyTreeDef, treedef_tuple, from jax._src.tree_util import (PyTreeDef, treedef_tuple,
tree_flatten, tree_structure) tree_flatten, tree_structure)
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
@ -1423,7 +1423,8 @@ def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool],
def has_effects(eqn: JaxprEqn) -> bool: def has_effects(eqn: JaxprEqn) -> bool:
effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)} effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)
and not isinstance(e, ReadEffect)}
return bool(effs) return bool(effs)

View File

@ -615,8 +615,23 @@ def tanh(x: ArrayLike) -> Array:
""" """
return tanh_p.bind(x) return tanh_p.bind(x)
@export
def logistic(x: ArrayLike) -> Array: def logistic(x: ArrayLike) -> Array:
r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`.""" r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`.
There is no HLO logistic/sigmoid primitive, so this lowers to a sequence
of HLO arithmetic operations.
Args:
x: input array. Must have floating point or complex dtype.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
logistic/sigmoid function.
See also:
- :func:`jax.nn.sigmoid`: an alternative API for this functionality.
"""
return logistic_p.bind(x) return logistic_p.bind(x)
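A minimal usage sketch for the expanded docstring above (the call and values are illustrative, not part of the change):

import jax.numpy as jnp
from jax import lax

lax.logistic(jnp.array([0.0, 2.0]))  # ~[0.5, 0.8808], same result as jax.nn.sigmoid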
@export @export
@ -1018,12 +1033,45 @@ def bitwise_xor(x: ArrayLike, y: ArrayLike) -> Array:
""" """
return xor_p.bind(x, y) return xor_p.bind(x, y)
@export
def population_count(x: ArrayLike) -> Array: def population_count(x: ArrayLike) -> Array:
r"""Elementwise popcount, count the number of set bits in each element.""" r"""Elementwise popcount, count the number of set bits in each element.
This function lowers directly to the `stablehlo.popcnt`_ operation.
Args:
x: Input array. Must have integer dtype.
Returns:
An array of the same shape and dtype as ``x``, containing the number of
set bits in the input.
See also:
- :func:`jax.lax.clz`: Elementwise count leading zeros.
- :func:`jax.numpy.bitwise_count`: More flexible NumPy-style API for bit counts.
.. _stablehlo.popcnt: https://openxla.org/stablehlo/spec#popcnt
"""
return population_count_p.bind(x) return population_count_p.bind(x)
@export
def clz(x: ArrayLike) -> Array: def clz(x: ArrayLike) -> Array:
r"""Elementwise count-leading-zeros.""" r"""Elementwise count-leading-zeros.
This function lowers directly to the `stablehlo.count_leading_zeros`_ operation.
Args:
x: Input array. Must have integer dtype.
Returns:
An array of the same shape and dtype as ``x``, containing the number of
leading zeros in the input.
See also:
- :func:`jax.lax.population_count`: Count the number of set bits in each element.
.. _stablehlo.count_leading_zeros: https://openxla.org/stablehlo/spec#count_leading_zeros
"""
return clz_p.bind(x) return clz_p.bind(x)
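A small illustrative example for the two bit-counting primitives documented above (hypothetical values, not part of the change):

import jax.numpy as jnp
from jax import lax

x = jnp.array([1, 2, 255], dtype=jnp.uint8)
lax.population_count(x)  # [1, 1, 8]  set bits per element
lax.clz(x)               # [7, 6, 0]  leading zeros in each 8-bit value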
@export @export
@ -1124,31 +1172,81 @@ def div(x: ArrayLike, y: ArrayLike) -> Array:
""" """
return div_p.bind(x, y) return div_p.bind(x, y)
@export
def rem(x: ArrayLike, y: ArrayLike) -> Array: def rem(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise remainder: :math:`x \bmod y`. r"""Elementwise remainder: :math:`x \bmod y`.
The sign of the result is taken from the dividend, This function lowers directly to the `stablehlo.remainder`_ operation.
and the absolute value of the result is always The sign of the result is taken from the dividend, and the absolute value
less than the divisor's absolute value. of the result is always less than the divisor's absolute value.
Integer division overflow Integer division overflow (remainder by zero or remainder of INT_SMIN with -1)
(remainder by zero or remainder of INT_SMIN with -1)
produces an implementation defined value. produces an implementation defined value.
Args:
x, y: Input arrays. Must have matching int or float dtypes. If neither
is a scalar, ``x`` and ``y`` must have the same number of dimensions
and be broadcast compatible.
Returns:
An array of the same dtype as ``x`` and ``y`` containing the remainder.
See also:
- :func:`jax.numpy.remainder`: NumPy-style remainder with different
sign semantics.
.. _stablehlo.remainder: https://openxla.org/stablehlo/spec#remainder
""" """
return rem_p.bind(x, y) return rem_p.bind(x, y)
@export
def max(x: ArrayLike, y: ArrayLike) -> Array: def max(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise maximum: :math:`\mathrm{max}(x, y)` r"""Elementwise maximum: :math:`\mathrm{max}(x, y)`.
For complex numbers, uses a lexicographic comparison on the This function lowers directly to the `stablehlo.maximum`_ operation for
`(real, imaginary)` pairs.""" non-complex inputs. For complex numbers, this uses a lexicographic
comparison on the `(real, imaginary)` pairs.
Args:
x, y: Input arrays. Must have matching dtypes. If neither is a scalar,
``x`` and ``y`` must have the same rank and be broadcast compatible.
Returns:
An array of the same dtype as ``x`` and ``y`` containing the elementwise
maximum.
See also:
- :func:`jax.numpy.maximum`: more flexible NumPy-style maximum.
- :func:`jax.lax.reduce_max`: maximum along an axis of an array.
- :func:`jax.lax.min`: elementwise minimum.
.. _stablehlo.maximum: https://openxla.org/stablehlo/spec#maximum
"""
return max_p.bind(x, y) return max_p.bind(x, y)
@export
def min(x: ArrayLike, y: ArrayLike) -> Array: def min(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise minimum: :math:`\mathrm{min}(x, y)` r"""Elementwise minimum: :math:`\mathrm{min}(x, y)`
For complex numbers, uses a lexicographic comparison on the This function lowers directly to the `stablehlo.minimum`_ operation for
`(real, imaginary)` pairs.""" non-complex inputs. For complex numbers, this uses a lexicographic
comparison on the `(real, imaginary)` pairs.
Args:
x, y: Input arrays. Must have matching dtypes. If neither is a scalar,
``x`` and ``y`` must have the same rank and be broadcast compatible.
Returns:
An array of the same dtype as ``x`` and ``y`` containing the elementwise
minimum.
See also:
- :func:`jax.numpy.minimum`: more flexible NumPy-style minimum.
- :func:`jax.lax.reduce_min`: minimum along an axis of an array.
- :func:`jax.lax.max`: elementwise maximum.
.. _stablehlo.minimum: https://openxla.org/stablehlo/spec#minimum
"""
return min_p.bind(x, y) return min_p.bind(x, y)
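An illustrative contrast of the sign convention described for ``rem`` against the NumPy-style API it points to (hypothetical values, not part of the change):

import jax.numpy as jnp
from jax import lax

lax.rem(jnp.int32(-7), jnp.int32(3))        # -1: sign follows the dividend
jnp.remainder(jnp.int32(-7), jnp.int32(3))  #  2: NumPy-style sign semantics
lax.max(jnp.float32(2.0), jnp.float32(3.0)) #  3.0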
@export @export
@ -1408,21 +1506,38 @@ def lt(x: ArrayLike, y: ArrayLike) -> Array:
""" """
return lt_p.bind(x, y) return lt_p.bind(x, y)
@export
def convert_element_type(operand: ArrayLike, def convert_element_type(operand: ArrayLike,
new_dtype: DTypeLike | dtypes.ExtendedDType) -> Array: new_dtype: DTypeLike | dtypes.ExtendedDType) -> Array:
"""Elementwise cast. """Elementwise cast.
Wraps XLA's `ConvertElementType This function lowers directly to the `stablehlo.convert`_ operation, which
<https://www.tensorflow.org/xla/operation_semantics#convertelementtype>`_ performs an elementwise conversion from one type to another, similar to a
operator, which performs an elementwise conversion from one type to another. C++ ``static_cast``.
Similar to a C++ `static_cast`.
Args: Args:
operand: an array or scalar value to be cast. operand: an array or scalar value to be cast.
new_dtype: a NumPy dtype representing the target type. new_dtype: a dtype-like object (e.g. a :class:`numpy.dtype`, a scalar type,
or a valid dtype name) representing the target dtype.
Returns: Returns:
An array with the same shape as `operand`, cast elementwise to `new_dtype`. An array with the same shape as ``operand``, cast elementwise to ``new_dtype``.
.. note::
If ``new_dtype`` is a 64-bit type and `x64 mode`_ is not enabled,
the appropriate 32-bit type will be used in its place.
If the input is a JAX array and the input dtype and output dtype match, then
the input array will be returned unmodified.
See also:
- :func:`jax.numpy.astype`: NumPy-style dtype casting API.
- :meth:`jax.Array.astype`: dtype casting as an array method.
- :func:`jax.lax.bitcast_convert_type`: cast bits directly to a new dtype.
.. _stablehlo.convert: https://openxla.org/stablehlo/spec#convert
.. _x64 mode: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
""" """
return _convert_element_type(operand, new_dtype, weak_type=False) # type: ignore[unused-ignore,bad-return-type] return _convert_element_type(operand, new_dtype, weak_type=False) # type: ignore[unused-ignore,bad-return-type]
@ -1500,12 +1615,11 @@ def _convert_element_type(
operand, new_dtype=new_dtype, weak_type=bool(weak_type), operand, new_dtype=new_dtype, weak_type=bool(weak_type),
sharding=sharding) sharding=sharding)
@export
def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array: def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
"""Elementwise bitcast. """Elementwise bitcast.
Wraps XLA's `BitcastConvertType This function lowers directly to the `stablehlo.bitcast_convert`_ operation.
<https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype>`_
operator, which performs a bit cast from one type to another.
The output shape depends on the size of the input and output dtypes with The output shape depends on the size of the input and output dtypes with
the following logic:: the following logic::
@ -1525,6 +1639,12 @@ def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
Returns: Returns:
An array of shape `output_shape` (see above) and type `new_dtype`, An array of shape `output_shape` (see above) and type `new_dtype`,
constructed from the same bits as operand. constructed from the same bits as operand.
See also:
- :func:`jax.lax.convert_element_type`: value-preserving dtype conversion.
- :func:`jax.Array.view`: NumPy-style API for bitcast type conversion.
.. _stablehlo.bitcast_convert: https://openxla.org/stablehlo/spec#bitcast_convert
""" """
new_dtype = dtypes.canonicalize_dtype(new_dtype) new_dtype = dtypes.canonicalize_dtype(new_dtype)
return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype) return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype)
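A short sketch contrasting the two casts documented above (illustrative only, not part of the change):

import jax.numpy as jnp
from jax import lax

x = jnp.float32(1.0)
lax.convert_element_type(x, jnp.int32)  # 1           (value-preserving cast)
lax.bitcast_convert_type(x, jnp.int32)  # 1065353216  (same bits, reinterpreted)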

View File

@ -1301,6 +1301,7 @@ def _ragged_all_to_all_transpose(
mask = jax.numpy.cumsum( mask = jax.numpy.cumsum(
jax.numpy.zeros(t.shape[0], dtype='int32').at[output_offsets_].set(1)\ jax.numpy.zeros(t.shape[0], dtype='int32').at[output_offsets_].set(1)\
.at[output_offsets_ + recv_sizes].add(-1)) .at[output_offsets_ + recv_sizes].add(-1))
mask = jax.numpy.expand_dims(mask, (*range(1, t.ndim),))
output_t = jax.numpy.where(mask, 0, t) output_t = jax.numpy.where(mask, 0, t)
return [operand_t, output_t] + [None] * 4 return [operand_t, output_t] + [None] * 4
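The new ``expand_dims`` broadcasts the 1-D cumsum mask over the trailing dimensions of ``t``. A standalone illustration of how that mask is formed (hypothetical offsets and sizes, not part of the change):

import jax.numpy as jnp

output_offsets_ = jnp.array([1, 5])
recv_sizes = jnp.array([2, 3])
mask = jnp.cumsum(
    jnp.zeros(10, dtype='int32').at[output_offsets_].set(1)
    .at[output_offsets_ + recv_sizes].add(-1))
# mask == [0, 1, 1, 0, 0, 1, 1, 1, 0, 0]: ones over [offset, offset + size)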

View File

@ -543,6 +543,10 @@ class UseAbstractMeshContextManager:
__slots__ = ['mesh', 'prev'] __slots__ = ['mesh', 'prev']
def __init__(self, mesh: AbstractMesh): def __init__(self, mesh: AbstractMesh):
if not isinstance(mesh, AbstractMesh):
raise ValueError(
"Expected mesh of type `jax.sharding.AbstractMesh`. Got type:"
f" {type(mesh)}")
self.mesh = mesh self.mesh = mesh
def __enter__(self): def __enter__(self):
@ -557,13 +561,5 @@ def get_abstract_mesh():
val = jax_config.abstract_mesh_context_manager.value val = jax_config.abstract_mesh_context_manager.value
return empty_abstract_mesh if val is None else val return empty_abstract_mesh if val is None else val
@contextlib.contextmanager
def use_concrete_mesh(mesh: Mesh | None):
prev_val = jax_config.device_context.swap_local(mesh)
try:
yield
finally:
jax_config.device_context.set_local(prev_val)
def get_concrete_mesh() -> Mesh | None: def get_concrete_mesh() -> Mesh | None:
return jax_config.device_context.value return jax_config.device_context.value

View File

@ -71,7 +71,7 @@ import numpy as np
export = set_module('jax.numpy') export = set_module('jax.numpy')
for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib']: for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib.cuda']:
try: try:
cuda_plugin_extension = importlib.import_module( cuda_plugin_extension = importlib.import_module(
f'{pkg_name}.cuda_plugin_extension' f'{pkg_name}.cuda_plugin_extension'

View File

@ -35,6 +35,7 @@ from jax._src import linear_util as lu
from jax._src import state from jax._src import state
from jax._src import tree_util from jax._src import tree_util
from jax._src import util from jax._src import util
from jax._src.export._export import export
from jax._src.interpreters import mlir from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
from jax._src.state import discharge as state_discharge from jax._src.state import discharge as state_discharge
@ -1165,14 +1166,16 @@ jax_core.custom_typechecks[core_map_p] = _core_map_typecheck_rule
def lower_as_mlir( def lower_as_mlir(
f, *args, dynamic_shapes=False, device=None, **kwargs f, *args, dynamic_shapes=False, device=None, static_argnames=(), **kwargs
) -> mlir.ir.Module: ) -> mlir.ir.Module:
with pallas_export_experimental(dynamic_shapes): with pallas_export_experimental(dynamic_shapes):
lowered = jax.jit(f, device=device).lower(*args, **kwargs) f = jax.jit(f, device=device, static_argnames=static_argnames)
stablehlo = lowered.compiler_ir(dialect="stablehlo") exported = export(f, platforms=["tpu"])(*args, **kwargs)
stablehlo = exported.mlir_module()
return stablehlo # type: ignore[return-value] return stablehlo # type: ignore[return-value]
_out_shape_to_aval_mapping: dict[ _out_shape_to_aval_mapping: dict[
type[Any], Callable[[Any], jax_core.AbstractValue] type[Any], Callable[[Any], jax_core.AbstractValue]
] = {} ] = {}

View File

@ -244,6 +244,13 @@ def pull_block_spec(
_unwrap_block_spec_scalar_prefetch, out_block_specs _unwrap_block_spec_scalar_prefetch, out_block_specs
) )
flat_block_specs, out_tree = jax.tree.flatten(block_specs_) flat_block_specs, out_tree = jax.tree.flatten(block_specs_)
jaxpr, used_consts, used_invars = pe.dce_jaxpr_consts(
jaxpr,
used_outputs=[True] * len(jaxpr.outvars),
instantiate=True,
)
assert all(used_invars)
assert all(used_consts)
in_block_specs, env, read_usage_env = _pull_block_spec( in_block_specs, env, read_usage_env = _pull_block_spec(
jaxpr, jaxpr,
tuple(flat_block_specs), tuple(flat_block_specs),

View File

@ -83,10 +83,15 @@ class TPUInterpretParams:
replaced with arrays all of `jnp.inf`. Additionally any floating point replaced with arrays all of `jnp.inf`. Additionally any floating point
operands to any operation will be replaced with (arrays of) `jnp.inf`. operands to any operation will be replaced with (arrays of) `jnp.inf`.
Default: False. Default: False.
uninitialized_memory: If "nan", allocated buffers are initialized to
contain all NaNs (or to their maximum possible value for integers).
If "zero", allocated buffers are initialized to all zeros.
Default: "nan".
""" """
dma_execution_mode: Literal["eager", "on_wait"] = "on_wait" dma_execution_mode: Literal["eager", "on_wait"] = "on_wait"
detect_races: bool = False detect_races: bool = False
skip_floating_point_ops: bool = False skip_floating_point_ops: bool = False
uninitialized_memory: Literal["nan", "zero"] = "nan"
VectorClock = np.ndarray VectorClock = np.ndarray
@ -1114,7 +1119,8 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params):
jax.ShapeDtypeStruct((), jnp.int16), jax.ShapeDtypeStruct((), jnp.int16),
device_id, device_id,
TPU_MEMORY_SPACE_IDXS[v.aval.memory_space], TPU_MEMORY_SPACE_IDXS[v.aval.memory_space],
primitives.uninitialized_value(v.aval.shape, v.aval.dtype), _uninitialized_value(
v.aval.shape, v.aval.dtype, interpret_params),
ordered=True)) ordered=True))
out = _interpret(eqn.params['jaxpr'], *deferred_invals(), *allocs) out = _interpret(eqn.params['jaxpr'], *deferred_invals(), *allocs)
@ -1279,16 +1285,19 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params):
def _initialize_output_vals( def _initialize_output_vals(
block_mappings_output: Iterable[BlockMapping], block_mappings_output: Iterable[BlockMapping],
input_args, input_output_aliases) -> Sequence[jax.Array]: input_args, input_output_aliases,
interpret_params: TPUInterpretParams,
) -> Sequence[jax.Array]:
oi_map = {v: k for k, v in input_output_aliases} oi_map = {v: k for k, v in input_output_aliases}
output_vals = [] output_vals = []
for i, bm in enumerate(block_mappings_output): for i, bm in enumerate(block_mappings_output):
if i in oi_map: if i in oi_map:
output_vals.append(input_args[oi_map[i]]) output_vals.append(input_args[oi_map[i]])
else: else:
output_vals.append(primitives.uninitialized_value( output_vals.append(_uninitialized_value(
bm.array_shape_dtype.shape, bm.array_shape_dtype.shape,
bm.array_shape_dtype.dtype)) bm.array_shape_dtype.dtype,
interpret_params))
return output_vals return output_vals
def _compute_start_indices(block_mapping, loop_idx, *args): def _compute_start_indices(block_mapping, loop_idx, *args):
@ -1319,7 +1328,20 @@ def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing):
dtype=np.bool_)]) dtype=np.bool_)])
return lax.squeeze(output, squeeze_dims) return lax.squeeze(output, squeeze_dims)
def _pad_to_block_dimension(value, block_shape): def _uninitialized_value(shape, dtype, interpret_params):
if interpret_params.uninitialized_memory == 'nan':
if jnp.issubdtype(dtype, jnp.floating):
return jnp.full(shape, jnp.nan, dtype)
elif jnp.issubdtype(dtype, jnp.integer):
return jnp.full(shape, jnp.iinfo(dtype).max, dtype)
elif jnp.issubdtype(dtype, jnp.bool):
return jnp.full(shape, False, dtype)
if interpret_params.uninitialized_memory == 'zero':
return jnp.full(shape, 0, dtype)
raise NotImplementedError(
interpret_params.uninitialized_memory + ' + ' + str(dtype))
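For reference, the fill values the helper above produces, written out as plain ``jnp`` calls (dtypes chosen arbitrarily for illustration; not part of the change):

import jax.numpy as jnp

# uninitialized_memory == "nan"
jnp.full((2,), jnp.nan, jnp.float32)                 # floats -> NaN
jnp.full((2,), jnp.iinfo(jnp.int32).max, jnp.int32)  # ints   -> INT_MAX
jnp.full((2,), False, jnp.bool_)                     # bools  -> False
# uninitialized_memory == "zero"
jnp.full((2,), 0, jnp.float32)                       # any dtype -> zeros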
def _pad_to_block_dimension(value, block_shape, interpret_params):
"""Pads values so the shape evenly divides into block dimensions. """Pads values so the shape evenly divides into block dimensions.
For example, if values has a shape of (33, 2, 5) with a block_shape of For example, if values has a shape of (33, 2, 5) with a block_shape of
@ -1338,7 +1360,7 @@ def _pad_to_block_dimension(value, block_shape):
) )
if padded_shape != value.shape: if padded_shape != value.shape:
pad_width = tuple((0, a-b) for a, b in zip(padded_shape, value.shape)) pad_width = tuple((0, a-b) for a, b in zip(padded_shape, value.shape))
pad_value = primitives.uninitialized_value(shape=(), dtype=value.dtype) pad_value = _uninitialized_value((), value.dtype, interpret_params)
value = jnp.pad(value, pad_width, constant_values=pad_value) value = jnp.pad(value, pad_width, constant_values=pad_value)
return value return value
@ -1397,7 +1419,7 @@ def interpret_pallas_call(
] ]
num_inputs = grid_mapping.num_inputs num_inputs = grid_mapping.num_inputs
input_args = [ input_args = [
_pad_to_block_dimension(a, bs) _pad_to_block_dimension(a, bs, interpret_params)
for a, bs in zip(input_args, block_shapes[:num_inputs]) for a, bs in zip(input_args, block_shapes[:num_inputs])
] ]
@ -1407,11 +1429,12 @@ def interpret_pallas_call(
output_vals = _initialize_output_vals( output_vals = _initialize_output_vals(
grid_mapping.block_mappings_output, grid_mapping.block_mappings_output,
scalars + input_args, scalars + input_args,
input_output_aliases) input_output_aliases,
interpret_params)
num_outputs = grid_mapping.num_outputs num_outputs = grid_mapping.num_outputs
output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs] output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs]
for out_val, bs in zip(output_vals, output_block_shapes): for out_val, bs in zip(output_vals, output_block_shapes):
padded_val = _pad_to_block_dimension(out_val, bs) padded_val = _pad_to_block_dimension(out_val, bs, interpret_params)
output_buffer_shapes.append(padded_val.shape) output_buffer_shapes.append(padded_val.shape)
output_buffer_ids.append(callback.io_callback( output_buffer_ids.append(callback.io_callback(
_allocate_buffer, _allocate_buffer,
@ -1466,7 +1489,8 @@ def interpret_pallas_call(
jax.ShapeDtypeStruct((), jnp.int16), jax.ShapeDtypeStruct((), jnp.int16),
device_id, device_id,
TPU_MEMORY_SPACE_IDXS[var.aval.memory_space], TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
primitives.uninitialized_value(var.aval.shape, var.aval.dtype), _uninitialized_value(
var.aval.shape, var.aval.dtype, interpret_params),
ordered=True)) ordered=True))
_, input_ids, kernel_output_ids, _ = split_list( _, input_ids, kernel_output_ids, _ = split_list(

View File

@ -40,6 +40,7 @@ from jax._src import source_info_util
from jax._src import state from jax._src import state
from jax._src import traceback_util from jax._src import traceback_util
from jax._src.cloud_tpu_init import is_cloud_tpu_older_than from jax._src.cloud_tpu_init import is_cloud_tpu_older_than
from jax._src.export._export import export
from jax._src.interpreters import mlir from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal from jax._src.lax import lax as lax_internal
@ -89,6 +90,11 @@ BOOL_MEMREF_TYPE = np.dtype('int32')
# The value interpreted as a dynamic dimension by MLIR. # The value interpreted as a dynamic dimension by MLIR.
MLIR_DYNAMIC = -9223372036854775808 MLIR_DYNAMIC = -9223372036854775808
# TODO(mvoz): Find a way to make this a contract we can share with the
# export specialization step in XLA export.
DIM_UPPER_BOUND = np.iinfo(np.int32).max
DIM_LOWER_BOUND = -128
partial = functools.partial partial = functools.partial
map, unsafe_map = safe_map, map # pylint: disable=redefined-builtin map, unsafe_map = safe_map, map # pylint: disable=redefined-builtin
zip, unsafe_zip = safe_zip, zip # pylint: disable=redefined-builtin zip, unsafe_zip = safe_zip, zip # pylint: disable=redefined-builtin
@ -102,17 +108,49 @@ class MeshContext:
# Note - On Export Placeholders # Note - On Export Placeholders
# #
# Mosaic uses vector IR, which does not have a concept of dynamic # Since the vector dialect used by Mosaic does not support dynamic shapes,
# dimensions. We need to come up with a way to represent dynamic dimensions in # we replace all top-level symbolic dimensions with placeholder
# vector IR, and so we use placeholders, which are later replaced during # constants (between max(int32) - 128 and max(int32)) and we keep a
# specialization. # mapping from the placeholder constants to SHLO functions that encode
# the symbolic dimension expression, as a function of the dimension
# variables.
#
# The calling convention of the produced MLIR module is the same as
# a regular Mosaic module, except that we add two new attributes to the custom call
# *per* intermediary placeholder dimension.
#
# The attributes are:
#
# tpu.dynamic_dimension_mapping_arg_name_<placeholder>
# tpu.dynamic_dimension_mapping_module_<placeholder>
#
# The first attribute is a comma-separated list of the dimension variables
# that are used to compute the symbolic dimension expression for the
# placeholder. The second attribute is the MLIR module that contains the
# SHLO functions that compute the symbolic dimension expression for the
# placeholder.
class LoweringDynamicShapeEnv: class LoweringDynamicShapeEnv:
dim_expr_to_placeholder: dict[Any, ir.Value] = {} dim_expr_to_placeholder: dict[shape_poly._DimExpr, int] = {}
placeholder_to_dim_expr: dict[int, shape_poly._DimExpr] = {}
def to_placeholder(self, dim_expr: Any) -> ir.Value: def to_placeholder(self, dim_expr: Any) -> ir.Value:
if jax_core.is_constant_dim(dim_expr):
# avoid ints, these are not dynamic
return dim_expr
if dim_expr not in self.dim_expr_to_placeholder: if dim_expr not in self.dim_expr_to_placeholder:
next_val = np.iinfo(np.int32).max - len(self.dim_expr_to_placeholder) next_val = DIM_UPPER_BOUND - len(self.dim_expr_to_placeholder)
if next_val < DIM_LOWER_BOUND:
# In practice, even with the largest of programs, we rarely see
# anything even close to this limit. It is arbitrary, and can be safely
# increased if needed.
raise ValueError(
"Too many dynamic shapes in the input. Mosaic currently only"
" supports up to 128 dynamic dimension values."
)
self.dim_expr_to_placeholder[dim_expr] = next_val self.dim_expr_to_placeholder[dim_expr] = next_val
# Reverse mapping - this is consumed to generate a table that is either
# input<>placeholder or intermediary computation<>placeholder.
self.placeholder_to_dim_expr[next_val] = dim_expr
return self.dim_expr_to_placeholder[dim_expr] return self.dim_expr_to_placeholder[dim_expr]
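A simplified, self-contained sketch of the placeholder scheme described in the note above (hypothetical standalone code, not the class itself): each distinct symbolic dimension gets a unique integer counting down from max(int32), and a reverse table is kept for the export step.

import numpy as np

DIM_UPPER_BOUND = np.iinfo(np.int32).max
dim_expr_to_placeholder: dict = {}
placeholder_to_dim_expr: dict = {}

def to_placeholder(dim_expr):
    if isinstance(dim_expr, int):  # constant dims are left untouched
        return dim_expr
    if dim_expr not in dim_expr_to_placeholder:
        next_val = DIM_UPPER_BOUND - len(dim_expr_to_placeholder)
        dim_expr_to_placeholder[dim_expr] = next_val
        placeholder_to_dim_expr[next_val] = dim_expr
    return dim_expr_to_placeholder[dim_expr]

# Strings stand in for symbolic dimension expressions in this sketch.
to_placeholder("b")        # 2147483647
to_placeholder("b * 128")  # 2147483646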
@ -622,6 +660,7 @@ def lower_jaxpr_to_module(
"Pallas TPU requires a libTPU version that's at most a month old" "Pallas TPU requires a libTPU version that's at most a month old"
) )
debug_info = jaxpr.debug_info debug_info = jaxpr.debug_info
_mosaic_lowering_dynamic_shape_env = None
if dynamic_shape_replacement_enabled: if dynamic_shape_replacement_enabled:
_mosaic_lowering_dynamic_shape_env = LoweringDynamicShapeEnv() _mosaic_lowering_dynamic_shape_env = LoweringDynamicShapeEnv()
@ -663,10 +702,12 @@ def lower_jaxpr_to_module(
for_verification=for_verification, for_verification=for_verification,
forward_compatible=lowering_context.is_forward_compat(), forward_compatible=lowering_context.is_forward_compat(),
dynamic_shape_replacement_fn=dynamic_shape_replacement_fn, dynamic_shape_replacement_fn=dynamic_shape_replacement_fn,
dynamic_shape_replacement_enabled=dynamic_shape_replacement_enabled,
) )
m.body.append(func_op) m.body.append(func_op)
sym_tab.insert(func_op) sym_tab.insert(func_op)
window_params = [] window_params = []
static_grid = None
grid = mosaic_grid_mapping.grid grid = mosaic_grid_mapping.grid
if grid: if grid:
for i, bm in enumerate(grid_mapping.block_mappings): for i, bm in enumerate(grid_mapping.block_mappings):
@ -738,7 +779,6 @@ def lower_jaxpr_to_module(
] ]
static_grid = dynamic_shape_replacement_fn(static_grid) static_grid = dynamic_shape_replacement_fn(static_grid)
func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(static_grid) func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(static_grid)
func_op.attributes["scalar_prefetch"] = ir.IntegerAttr.get( func_op.attributes["scalar_prefetch"] = ir.IntegerAttr.get(
ir.IntegerType.get_signless(64), len(mosaic_grid_mapping.scalar_prefetch_types)) ir.IntegerType.get_signless(64), len(mosaic_grid_mapping.scalar_prefetch_types))
func_op.attributes["scratch_operands"] = ir.IntegerAttr.get( func_op.attributes["scratch_operands"] = ir.IntegerAttr.get(
@ -746,6 +786,60 @@ def lower_jaxpr_to_module(
func_op.attributes["dimension_semantics"] = ( func_op.attributes["dimension_semantics"] = (
mosaic_grid_mapping.get_dimension_semantics() mosaic_grid_mapping.get_dimension_semantics()
) )
if dynamic_shape_replacement_enabled:
if _mosaic_lowering_dynamic_shape_env is None:
raise ValueError(
"Dynamic shape env is None, invariant violated. Unreachable?"
)
# Now we can use jax to compute the dynamic shape graph
if static_grid is not None:
grid_vars = [
_mosaic_lowering_dynamic_shape_env.placeholder_to_dim_expr.get(g, g)
for g in static_grid
]
else:
grid_vars = []
invars = [invar.aval for invar in jaxpr.invars]
# Faux shape for grid, just to get the avals
invars.append(jax.ShapeDtypeStruct(grid_vars, jax.numpy.int32))
args_dimvars = shape_poly.all_dim_vars(invars)
# This is dimexpr var -> placeholder value for when we jit the dim expr
env: dict[str, int] = {}
for aval in args_dimvars:
env[aval] = _mosaic_lowering_dynamic_shape_env.to_placeholder(aval)
for (
placeholder,
dim_expr,
) in _mosaic_lowering_dynamic_shape_env.placeholder_to_dim_expr.items():
top_level_names = list(env.keys())
if dim_expr not in top_level_names:
jitted_eval = jax.jit(
jax_core.evaluate_shape,
static_argnames=(
"shape",
"dim_vars",
),
keep_unused=True,
)
stablehlo = export(
jitted_eval, platforms=[str(jax.devices()[0].platform)]
)(
(dim_expr,), tuple(args_dimvars), *(env[v] for v in args_dimvars)
).mlir_module()
arg_name = args_dimvars
# See Note - On Export Placeholders for more details.
m.operation.attributes[
"tpu.dynamic_dimension_mapping_module_" + str(placeholder)
] = ir.StringAttr.get(str(stablehlo))
arg_name_str = ",".join(arg_name)
m.operation.attributes[
"tpu.dynamic_dimension_mapping_arg_name_" + str(placeholder)
] = ir.StringAttr.get(arg_name_str)
return m, mosaic_grid_mapping.get_extra_args() return m, mosaic_grid_mapping.get_extra_args()
@ -828,6 +922,7 @@ def lower_jaxpr_to_func(
dynamic_shape_replacement_fn: ( dynamic_shape_replacement_fn: (
Callable[[tuple[jax.DimSize, ...]], tuple[int, ...]] | None Callable[[tuple[jax.DimSize, ...]], tuple[int, ...]] | None
) = None, ) = None,
dynamic_shape_replacement_enabled: bool = False,
) -> func.FuncOp: ) -> func.FuncOp:
num_grid = len(mosaic_grid_mapping.grid_types) num_grid = len(mosaic_grid_mapping.grid_types)
num_scalar_prefetch = len(mosaic_grid_mapping.scalar_prefetch_types) num_scalar_prefetch = len(mosaic_grid_mapping.scalar_prefetch_types)
@ -874,6 +969,12 @@ def lower_jaxpr_to_func(
) )
body_func.__name__ = name body_func.__name__ = name
body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func) body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func)
if dynamic_shape_replacement_enabled:
# Skip verification for dynamic shape replacement - you can potentially
# produce ir like ex: add(x[placeholder_0, placeholder_1], y[128, 128])
# which is not valid, but we don't care since we'll run the verifier again
# after the dynamic shape replacement pass.
return body.func_op
try: try:
body.func_op.verify() body.func_op.verify()
except ir.MLIRError as e: except ir.MLIRError as e:
@ -3851,3 +3952,15 @@ def _platform_index_lowering(
lowering_rules[jax._src.lax.control_flow.platform_index_p] = _platform_index_lowering lowering_rules[jax._src.lax.control_flow.platform_index_p] = _platform_index_lowering
def _dim_as_value_lowering(ctx: mlir.LoweringRuleContext, *, dim):
placeholder = ctx.lowering_context.dynamic_shape_replacement_fn((dim,))[0]
return ir_constant(
placeholder, mlir_type=_dtype_to_ir_type(jnp.dtype("int32"))
)
import jax._src.export.shape_poly as shape_poly
lowering_rules[shape_poly.dim_as_value_p] = _dim_as_value_lowering

View File

@ -1196,9 +1196,8 @@ def emit_pipeline(
schedule = map_brefs( schedule = map_brefs(
lambda _, x: get_pipeline_schedule(x), allocations, schedule) lambda _, x: get_pipeline_schedule(x), allocations, schedule)
def loop_body(step, indices): def make_scheduler(step, indices):
nonlocal allocations return Scheduler(
scheduler = Scheduler(
step, step,
indices, indices,
grid, grid,
@ -1208,13 +1207,15 @@ def emit_pipeline(
init_accumulators=init_accumulators, init_accumulators=init_accumulators,
trace_scopes=trace_scopes, trace_scopes=trace_scopes,
) )
def loop_body(step, indices):
scheduler = make_scheduler(step, indices)
with scheduler.grid_env(): with scheduler.grid_env():
# prepare any local VMEM aliases # prepare any local VMEM aliases
brefs = map_brefs(scheduler.alias_local_refs, allocations, refs) brefs = map_brefs(scheduler.alias_local_refs, allocations, refs)
# loop input handling phase # loop input handling phase
map_brefs(scheduler.initialize, brefs, refs, schedule)
map_brefs(scheduler.copy_in, brefs, refs, schedule) map_brefs(scheduler.copy_in, brefs, refs, schedule)
map_brefs(scheduler.wait_in, brefs, refs, schedule) map_brefs(scheduler.wait_in, brefs, refs, schedule)
@ -1243,12 +1244,24 @@ def emit_pipeline(
lambda: None) lambda: None)
map_brefs(scheduler.swap_slots, brefs, refs, schedule) map_brefs(scheduler.swap_slots, brefs, refs, schedule)
map_brefs(scheduler.finalize, brefs, refs, schedule)
return _next_index(indices, grid) return _next_index(indices, grid)
# run pipeline @pl.when(num_steps > 0)
lax.fori_loop(0, num_steps, loop_body, (0,) * len(grid)) def _():
# pipeline prologue
initial_indices = (0,) * len(grid)
scheduler = make_scheduler(0, initial_indices)
brefs = map_brefs(scheduler.alias_local_refs, allocations, refs)
map_brefs(scheduler.initialize, brefs, refs, schedule)
# pipeline loop
next_indices = lax.fori_loop(0, num_steps, loop_body, initial_indices)
# pipeline epilogue
final_indices = _prev_index(next_indices, grid)
scheduler = make_scheduler(num_steps - 1, final_indices)
brefs = map_brefs(scheduler.alias_local_refs, allocations, refs)
map_brefs(scheduler.finalize, brefs, refs, schedule)
return pipeline return pipeline

View File

@ -1387,16 +1387,38 @@ def use_mesh(mesh: mesh_lib.Mesh):
# if not core.trace_state_clean(): # if not core.trace_state_clean():
# raise ValueError('`use_mesh` can only be used outside of `jax.jit`') # raise ValueError('`use_mesh` can only be used outside of `jax.jit`')
with (mesh_lib.use_abstract_mesh(mesh.abstract_mesh), with mesh_lib.use_abstract_mesh(mesh.abstract_mesh), use_concrete_mesh(mesh):
mesh_lib.use_concrete_mesh(mesh)):
yield yield
def set_mesh(mesh: mesh_lib.Mesh) -> None: def set_mesh(mesh: mesh_lib.Mesh | None) -> mesh_lib.Mesh | None:
if not isinstance(mesh, mesh_lib.Mesh): """Sets the given concrete mesh globally and returns the previous concrete
mesh."""
if mesh is not None and not isinstance(mesh, mesh_lib.Mesh):
raise ValueError( raise ValueError(
f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}") f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}")
if not core.trace_state_clean(): if not core.trace_state_clean():
raise ValueError('`set_mesh` can only be used outside of `jax.jit`.') raise ValueError('`set_mesh` can only be used outside of `jax.jit`.')
config.abstract_mesh_context_manager.set_local(mesh.abstract_mesh) if mesh is None:
config.device_context.set_local(mesh) config.abstract_mesh_context_manager.set_global(mesh_lib.empty_abstract_mesh) # type: ignore
else:
config.abstract_mesh_context_manager.set_global(mesh.abstract_mesh) # type: ignore
prev_mesh = config.device_context.get_global()
config.device_context.set_global(mesh)
return prev_mesh
@contextlib.contextmanager
def use_concrete_mesh(mesh: mesh_lib.Mesh | None):
if mesh is not None and not isinstance(mesh, mesh_lib.Mesh):
raise ValueError(
f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}")
# TODO(yashkatariya): Enable this.
# if not core.trace_state_clean():
# raise ValueError('`use_concrete_mesh` can only be used outside of `jax.jit`.')
prev_val = config.device_context.swap_local(mesh)
try:
yield
finally:
config.device_context.set_local(prev_val)
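A usage sketch of the two entry points above, calling them by the names defined in this hunk; whether and where they are re-exported publicly is not shown in the diff, so the call sites and device layout here are illustrative assumptions only.

import numpy as np
import jax

mesh = jax.sharding.Mesh(np.array(jax.devices()), ('x',))

# Scoped: the concrete mesh applies only inside the block.
with use_concrete_mesh(mesh):
    ...  # code that should see `mesh` as the concrete mesh

# Global: set_mesh installs the mesh and returns the previously installed one
# so the caller can restore it later; None clears to the empty abstract mesh.
prev = set_mesh(mesh)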

View File

@ -41,20 +41,16 @@ from jaxlib.mlir.dialects import memref
from jaxlib.mlir.dialects import nvvm from jaxlib.mlir.dialects import nvvm
import numpy as np import numpy as np
if dialect is not None:
from . import dialect_lowering
from . import layout_inference
else:
dialect_lowering = None
layout_inference = None
from . import profiler
from . import utils
from . import launch_context
from . import tcgen05
# mypy: ignore-errors # mypy: ignore-errors
from . import dialect_lowering
from . import launch_context
from . import layout_inference
from . import profiler
from . import tcgen05
from . import transform_inference
from . import utils
# MLIR can't find libdevice unless we point it to the CUDA path # MLIR can't find libdevice unless we point it to the CUDA path
# TODO(apaszke): Unify with jax._src.lib.cuda_path # TODO(apaszke): Unify with jax._src.lib.cuda_path
CUDA_ROOT = "/usr/local/cuda" CUDA_ROOT = "/usr/local/cuda"
@ -584,6 +580,7 @@ def as_gpu_kernel(
# Run Python lowering passes. The remaining passes will be run in C++ in # Run Python lowering passes. The remaining passes will be run in C++ in
# jax/jaxlib/mosaic/gpu/custom_call.cc # jax/jaxlib/mosaic/gpu/custom_call.cc
layout_inference.infer_layout(module) # pytype: disable=attribute-error layout_inference.infer_layout(module) # pytype: disable=attribute-error
transform_inference.infer_transforms(module) # pytype: disable=attribute-error
dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error
_initialize_scratch(launch_ctx, scratch_arr) _initialize_scratch(launch_ctx, scratch_arr)
@ -666,6 +663,7 @@ def as_torch_gpu_kernel(
# Run Python lowering passes. The remaining passes will be run in C++ in # Run Python lowering passes. The remaining passes will be run in C++ in
# jax/jaxlib/mosaic/gpu/custom_call.cc # jax/jaxlib/mosaic/gpu/custom_call.cc
layout_inference.infer_layout(module) # pytype: disable=attribute-error layout_inference.infer_layout(module) # pytype: disable=attribute-error
transform_inference.infer_transforms(module) # pytype: disable=attribute-error
dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error
_initialize_scratch(launch_ctx, scratch_arr) _initialize_scratch(launch_ctx, scratch_arr)

View File

@ -17,6 +17,7 @@
from collections.abc import Callable from collections.abc import Callable
import dataclasses import dataclasses
import functools import functools
import itertools
import operator import operator
from typing import Any, Sequence, Type, cast from typing import Any, Sequence, Type, cast
@ -58,7 +59,7 @@ class LoweringContext:
if not _should_lower(op): if not _should_lower(op):
return return
if (name := op.OPERATION_NAME) not in _lowerings: if (name := op.OPERATION_NAME) not in _lowerings: # pytype: disable=attribute-error
raise NotImplementedError(f"Missing lowering rule for {op}") raise NotImplementedError(f"Missing lowering rule for {op}")
lowering_rule = _lowerings[name] lowering_rule = _lowerings[name]
@ -227,6 +228,60 @@ def _arith_constant_op_lowering_rule(
] ]
def _check_transforms_and_swizzle_are_supported(
ref_ty: ir.MemRefType,
transforms: Sequence[launch_context.MemRefTransform],
swizzle: mgpu.SwizzlingMode,
minimum_swizzle: mgpu.SwizzlingMode = mgpu.SwizzlingMode.kNoSwizzle,
):
"""Checks that the list of provided transforms and swizzle are supported.
Currently, we allow the following:
- any swizzle that is larger than or equal to `minimum_swizzle`;
- optionally, a single tile transform (with rank equal to the rank of the
memref being annotated);
- optionally, a single transpose transform.
"""
if swizzle < minimum_swizzle:
raise NotImplementedError(
f"Unsupported swizzle {swizzle} smaller than {minimum_swizzle}."
)
partitioned_transforms = {
k: list(v)
for k, v in itertools.groupby(
transforms, lambda t: isinstance(t, launch_context.TileTransform)
)
}
tile_transforms = partitioned_transforms.get(True, [])
other_transforms = partitioned_transforms.get(False, [])
if len(tile_transforms) > 1:
raise NotImplementedError(
f"{tile_transforms} contains more than one tile transform."
)
if len(tile_transforms) == 1:
if len(tile_transforms[0].tiling) != len(ref_ty.shape):
raise NotImplementedError(
f"Only tile transforms with rank equal to the rank of the memref "
f"being annotated are supported but got {tile_transforms[0]} for "
f"{ref_ty}."
)
if len(other_transforms) > 1:
raise NotImplementedError(
f"{other_transforms} contains more than one transform."
)
if len(other_transforms) == 1:
if not isinstance(other_transforms[0], launch_context.TransposeTransform):
raise NotImplementedError(
f"{other_transforms[0]} is not a transpose transform."
)
@_register_lowering(vector.LoadOp) @_register_lowering(vector.LoadOp)
def _vector_load_op_lowering_rule( def _vector_load_op_lowering_rule(
_: LoweringContext, vector_load_op: vector.LoadOp _: LoweringContext, vector_load_op: vector.LoadOp
@ -260,8 +315,11 @@ def _vector_load_op_lowering_rule(
vec_size=strided_layout.vec_size, vec_size=strided_layout.vec_size,
) )
elif layouts.from_layout_attr(out_layout_attr) == fa.WGMMA_LAYOUT: elif layouts.from_layout_attr(out_layout_attr) == fa.WGMMA_LAYOUT:
layout = ir.MemRefType(vector_load_op.base.type).layout swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
swizzle, transforms = memref_layout_to_swizzle_and_transforms(layout) inference_utils.in_transforms(vector_load_op)[0]
)
ref_ty = ir.MemRefType(vector_load_op.base.type)
_check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle)
transformed_ref = transform_memref(vector_load_op.base, transforms) transformed_ref = transform_memref(vector_load_op.base, transforms)
fragmented_array = fa.FragmentedArray.load_tiled( fragmented_array = fa.FragmentedArray.load_tiled(
transformed_ref, transformed_ref,
@ -297,8 +355,22 @@ def _vector_store_op_lowering_rule(
vector_store_op.valueToStore, to_store_layout vector_store_op.valueToStore, to_store_layout
) )
# TODO(dasenov): This is not efficient for WGMMA layouts if fragmented_array.layout == fa.WGMMA_LAYOUT:
fragmented_array.store_untiled(vector_store_op.base) swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
inference_utils.in_transforms(vector_store_op)[0]
)
ref_ty = ir.MemRefType(vector_store_op.base.type)
_check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle)
fragmented_array.store_tiled(
transform_memref(vector_store_op.base, transforms), swizzle
)
elif (isinstance(fragmented_array.layout, fa.WGStridedFragLayout) or
isinstance(fragmented_array.layout, fa.WGSplatFragLayout)):
fragmented_array.store_untiled(vector_store_op.base)
else:
raise ValueError(
f"{vector_store_op} has an unsupported layout: {to_store_layout}"
)
return [] return []
@ -362,39 +434,43 @@ def _vector_reduction_op_lowering_rule(
return [_fragmented_array_to_ir(result, op.result.type)] return [_fragmented_array_to_ir(result, op.result.type)]
def memref_layout_to_swizzle_and_transforms( def swizzle_and_transforms_from_transforms_attr(
layout: ir.Attribute, transforms: ir.ArrayAttr,
) -> tuple[mgpu.SwizzlingMode, tuple[launch_context.MemRefTransform, ...]]: ) -> tuple[mgpu.SwizzlingMode, tuple[launch_context.MemRefTransform, ...]]:
"""Returns the swizzle and transforms that are encoded in the given layout. """Returns the swizzle and MemrefTransforms for the given transforms.
If the layout is not a LayoutAttr, the swizzle is kNoSwizzle and the Args:
transforms are empty. Otherwise, the layout may have at most one swizzle transforms: a list of transform attributes.
transform and any combination of tiling and transpose transforms.
Returns:
A tuple containing the swizzle mode and MemRefTransforms corresponding to
the parameter transforms. If `transforms` is empty, or does not contain
any swizzling transform, the swizzle mode is assumed to be kNoSwizzle.
Raises:
ValueError: if a swizzling transform is followed by any other transform.
""" """
swizzle = None swizzle = None
gmem_transforms: list[launch_context.MemRefTransform] = [] gmem_transforms: list[launch_context.MemRefTransform] = []
if mgpu.LayoutAttr.isinstance(layout): for transform in transforms:
transforms_attr = mgpu.LayoutAttr(layout).transforms if swizzle is not None:
for transform in transforms_attr: raise ValueError(f"{transforms} contain more transforms after swizzle.")
if swizzle is not None: if mgpu.SwizzleTransformAttr.isinstance(transform):
raise ValueError(f"{layout} contains more transforms after the initial swizzle.") # TODO(dasenov): Swizzling can change if the ref is sliced in certain
if mgpu.SwizzleTransformAttr.isinstance(transform): # ways. We might want to enforce some restrictions here.
# TODO(dasenov): Swizzling can change if the ref is sliced in certain swizzle = mgpu.SwizzleTransformAttr(transform).swizzle
# ways. We might want to enforce some restrictions here. elif mgpu.TileTransformAttr.isinstance(transform):
swizzle = mgpu.SwizzleTransformAttr(transform).swizzle tiling = mgpu.TileTransformAttr(transform).tiling
elif mgpu.TileTransformAttr.isinstance(transform): tiling_transform = launch_context.TileTransform(tuple(tiling))
tiling = mgpu.TileTransformAttr(transform).tiling gmem_transforms.append(tiling_transform)
tiling_transform = launch_context.TileTransform(tuple(tiling)) elif mgpu.TransposeTransformAttr.isinstance(transform):
gmem_transforms.append(tiling_transform) permutation = mgpu.TransposeTransformAttr(transform).permutation
elif mgpu.TransposeTransformAttr.isinstance(transform): transpose_transform = launch_context.TransposeTransform(
permutation = mgpu.TransposeTransformAttr(transform).permutation tuple(permutation)
transpose_transform = launch_context.TransposeTransform( )
tuple(permutation) gmem_transforms.append(transpose_transform)
) else:
gmem_transforms.append(transpose_transform) raise ValueError("Unknown transform: {transform}")
else:
raise ValueError(f"{layout} has an unsupported transform: {transform}")
return swizzle or mgpu.SwizzlingMode.kNoSwizzle, tuple(gmem_transforms) return swizzle or mgpu.SwizzlingMode.kNoSwizzle, tuple(gmem_transforms)
@ -434,8 +510,14 @@ def _mgpu_async_load_op_lowering_rule(
assert ctx.launch_context is not None assert ctx.launch_context is not None
barrier = utils.BarrierRef.from_dialect_barrier_memref(load_op.barrier) barrier = utils.BarrierRef.from_dialect_barrier_memref(load_op.barrier)
dst_layout = ir.MemRefType(load_op.destination.type).layout if inference_utils.has_in_transforms_set(load_op):
swizzle, transforms = memref_layout_to_swizzle_and_transforms(dst_layout) [transforms] = inference_utils.in_transforms(load_op)
swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
transforms
)
else:
swizzle = mgpu.SwizzlingMode.kNoSwizzle
transforms = ()
gmem_slice = [] gmem_slice = []
for idx_i32, size in zip(load_op.indices, load_op.slice_lengths): for idx_i32, size in zip(load_op.indices, load_op.slice_lengths):
@ -464,8 +546,14 @@ def _mgpu_async_store_op_lowering_rule(
) -> Sequence[ir.Value]: ) -> Sequence[ir.Value]:
assert ctx.launch_context is not None assert ctx.launch_context is not None
src_layout = ir.MemRefType(store_op.source.type).layout if inference_utils.has_in_transforms_set(store_op):
swizzle, transforms = memref_layout_to_swizzle_and_transforms(src_layout) [transforms] = inference_utils.in_transforms(store_op)
swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
transforms
)
else:
swizzle = mgpu.SwizzlingMode.kNoSwizzle
transforms = ()
gmem_slice = [] gmem_slice = []
for idx_i32, size in zip(store_op.indices, store_op.slice_lengths): for idx_i32, size in zip(store_op.indices, store_op.slice_lengths):
@ -673,6 +761,9 @@ def _bitcast_op_lowering_rule(
def _mgpu_wgmma_op_lowering_rule( def _mgpu_wgmma_op_lowering_rule(
_: LoweringContext, wgmma_op: mgpu.WGMMAOp _: LoweringContext, wgmma_op: mgpu.WGMMAOp
) -> Sequence[ir.Value]: ) -> Sequence[ir.Value]:
if wgmma_op.transpose_a or wgmma_op.transpose_b:
raise ValueError("Transpose arguments are to be deleted.")
fa_layouts = ( fa_layouts = (
*inference_utils.in_layouts(wgmma_op), *inference_utils.in_layouts(wgmma_op),
*inference_utils.out_layouts(wgmma_op), *inference_utils.out_layouts(wgmma_op),
@@ -691,25 +782,38 @@ def _mgpu_wgmma_op_lowering_rule(
   regs = acc_in.to_layout(fa.WGMMA_LAYOUT)
   acc = wgmma.WGMMAAccumulator.from_registers(regs)
-  b_layout = ir.MemRefType(wgmma_op.b.type).layout
-  b_swizzle, b_transforms = memref_layout_to_swizzle_and_transforms(b_layout)
+  if ir.VectorType.isinstance(wgmma_op.a.type):
+    a_transforms = None
+    b_transforms = inference_utils.in_transforms(wgmma_op)[0]
+  else:
+    a_transforms, b_transforms = inference_utils.in_transforms(wgmma_op)
+  b_swizzle, b_transforms = swizzle_and_transforms_from_transforms_attr(
+      b_transforms
+  )
+  minimum_swizzle = mgpu.SwizzlingMode.k32ByteSwizzle
+  ref_ty = ir.MemRefType(wgmma_op.b.type)
+  _check_transforms_and_swizzle_are_supported(
+      ref_ty, b_transforms, b_swizzle, minimum_swizzle
+  )
   b_operand = transform_memref(wgmma_op.b, b_transforms)
-  if wgmma_op.transpose_b:
-    b_operand = utils.memref_transpose(b_operand, (0, 1, 3, 2))
   if ir.VectorType.isinstance(wgmma_op.a.type):
     a_operand = _fragmented_array_from_ir(wgmma_op.a, wgmma_layout)
   else:
-    a_layout = ir.MemRefType(wgmma_op.a.type).layout
-    a_swizzle, a_transforms = memref_layout_to_swizzle_and_transforms(a_layout)
+    a_swizzle, a_transforms = swizzle_and_transforms_from_transforms_attr(
+        a_transforms
+    )
+    ref_ty = ir.MemRefType(wgmma_op.a.type)
+    _check_transforms_and_swizzle_are_supported(
+        ref_ty, a_transforms, a_swizzle, minimum_swizzle
+    )
     if a_swizzle != b_swizzle:
       raise ValueError(
           f"Non-matching swizzles of operands a and b in WGMMA: {a_swizzle} !="
          f" {b_swizzle}"
      )
    a_operand = transform_memref(wgmma_op.a, a_transforms)
-    if wgmma_op.transpose_a:
-      a_operand = utils.memref_transpose(a_operand, (0, 1, 3, 2))
   new_acc = wgmma.wgmma(acc, a_operand, b_operand, swizzle=b_swizzle)
@@ -902,7 +1006,7 @@ def single_thread_predicates(module: ir.Module) -> tuple[ir.Value, ir.Value]:
 def _should_lower(op: ir.OpView) -> bool:
   """Returns 'true' if the operation should be lowered."""
   return (
-      op.OPERATION_NAME.startswith("mosaic_gpu.")
+      op.OPERATION_NAME.startswith("mosaic_gpu.")  # pytype: disable=attribute-error
       or inference_utils.should_have_layout(op)
       or any(bool(b) for r in op.regions for b in r)  # Does it have subblocks?
   )

View File

@@ -387,7 +387,21 @@ class WGMMARowFragLayout:
   """[m] matrix, where m % 64 == 0."""

   def thread_idxs(self, shape):
-    raise NotImplementedError
+    index = ir.IndexType.get()
+    assert len(shape) == 1
+    assert shape[0] % 64 == 0
+    tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx())
+    tid_wg = arith.remui(tid, c(WARPGROUP_SIZE, index))
+    warp_idx = arith.divui(tid_wg, c(32, index))
+    lane_id = arith.remui(tid_wg, c(32, index))
+    row_base = arith.addi(
+        arith.divui(lane_id, c(4, index)), arith.muli(warp_idx, c(16, index))
+    )
+    for row_group in range(0, shape[0], 64):
+      for row_subgroup in (0, 8):
+        row = arith.addi(row_base, c(row_group + row_subgroup, index))
+        yield (row,)

 @dataclasses.dataclass(frozen=True)
@ -660,6 +674,31 @@ class FragmentedArray:
vecs = [vector.load(vec_ty, ref, vec_idx) for vec_idx in layout.thread_idxs(shape)] vecs = [vector.load(vec_ty, ref, vec_idx) for vec_idx in layout.thread_idxs(shape)]
return cls(_registers=np.array(vecs), _layout=layout, _is_signed=is_signed) return cls(_registers=np.array(vecs), _layout=layout, _is_signed=is_signed)
@classmethod
def load_wgmma_row(
cls,
ref: ir.Value,
*,
is_signed: bool | None = None,
):
if not ir.MemRefType.isinstance(ref.type):
raise TypeError(ref.type)
ref_ty = ir.MemRefType(ref.type)
shape = tuple(ref_ty.shape)
if len(shape) != 1:
raise ValueError("WGMMARowFragLayout requires a 1D shape")
if shape[0] % 64:
raise ValueError(
"WGMMARowFragLayout requires shape[0] to be a multiple of 64"
)
layout = WGMMARowFragLayout()
registers = [memref.load(ref, [idx]) for (idx,) in layout.thread_idxs(shape)]
registers = np.array(registers).reshape(-1, 2)
return cls(_registers=registers, _layout=layout, _is_signed=is_signed)
@classmethod @classmethod
def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None): def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None):
layout = layout or WGSplatFragLayout(shape) layout = layout or WGSplatFragLayout(shape)
@ -1743,6 +1782,8 @@ class FragmentedArray:
) )
match self.layout: match self.layout:
case WGMMARowFragLayout():
self._store_untiled_wgmma_row(ref)
case WGSplatFragLayout(): case WGSplatFragLayout():
vs_unsupported() vs_unsupported()
self._store_untiled_splat(ref) self._store_untiled_splat(ref)
@ -1789,6 +1830,23 @@ class FragmentedArray:
for idx, reg in zip(idxs, self.registers.flat): for idx, reg in zip(idxs, self.registers.flat):
vector.store(reg, ref_, idx) vector.store(reg, ref_, idx)
def _store_untiled_wgmma_row(self, ref: ir.Value):
"""Stores an array with a WGMMA row layout."""
assert self.layout == WGMMA_ROW_LAYOUT
index = ir.IndexType.get()
tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx())
is_first = arith.cmpi(
arith.CmpIPredicate.eq, arith.remui(tid, c(4, index)), c(0, index)
)
# Consecutive groups of 4 threads hold the same value in this layout,
# therefore we only need to transfer data from one of them.
with utils.when(is_first):
for (idx,), value in zip(
self.layout.thread_idxs(self.shape), self.registers.flatten()
):
memref.store(value, ref, [idx])
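  # A minimal sketch of the row mapping that thread_idxs implements for this
  # layout, using plain Python ints (wgmma_rows is a hypothetical helper, shown
  # only to illustrate why lanes that differ in t % 4 hold the same value):
  #
  #   def wgmma_rows(t, m):
  #     warp, lane = (t % 128) // 32, (t % 128) % 32
  #     base = warp * 16 + lane // 4
  #     return [base + g + s for g in range(0, m, 64) for s in (0, 8)]
  #
  #   wgmma_rows(36, 128) == wgmma_rows(37, 128) == [17, 25, 81, 89]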
def _store_untiled_tiled(self, ref: ir.Value, *, vector_store: bool = True): def _store_untiled_tiled(self, ref: ir.Value, *, vector_store: bool = True):
"""Stores an array with a tiled layout. Not optimized at the moment.""" """Stores an array with a tiled layout. Not optimized at the moment."""
if utils.bitwidth(self.mlir_dtype) < 8: if utils.bitwidth(self.mlir_dtype) < 8:

View File

@ -383,89 +383,6 @@ def _infer_wgmma_op_layout(wgmma_op: mgpu.WGMMAOp) -> OptionalLayouts:
return [layout], [layout] return [layout], [layout]
@dataclasses.dataclass()
class WGMMATransforms:
swizzle: mgpu.SwizzlingMode
a_tile: tuple[int, ...]
a_transpose: bool
b_tile: tuple[int, ...]
b_transpose: bool
def infer_wgmma_transforms(wgmma_op: mgpu.WGMMAOp) -> WGMMATransforms:
a_shape = cast(ir.ShapedType, wgmma_op.a.type).shape
k = a_shape[0] if wgmma_op.transpose_a else a_shape[1]
bitwidth = cast(ir.ShapedType, wgmma_op.a.type).element_type.width
# Try tiling with all swizzling modes starting from the largest one.
for swizzle in [
mgpu.SwizzlingMode.k128ByteSwizzle,
mgpu.SwizzlingMode.k64ByteSwizzle,
mgpu.SwizzlingMode.k32ByteSwizzle,
]:
s = swizzle * 8 // bitwidth
if k % s == 0:
return WGMMATransforms(
swizzle=swizzle,
a_tile=(s, 64) if wgmma_op.transpose_a else (64, s),
a_transpose=wgmma_op.transpose_a,
b_tile=(s, s),
b_transpose=wgmma_op.transpose_b,
)
raise ValueError(
"Could not infer layouts for memref feeding into WGMMA. The "
"non-contracting dimension ({k}) must be a multiple of "
"s = swizzle * (8 / bitwidth) where swizzle is a valid swizzle "
f"(32, 64, or 128) and bitwidth ({bitwidth}) is the element size of "
"`a` and `b`."
)
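# Illustrative arithmetic for the swizzle constraint described in the message
# above (values assumed for bf16, bitwidth 16; not part of the change):
#   s_for = lambda swizzle_bytes, bitwidth: swizzle_bytes * 8 // bitwidth
#   [s_for(sw, 16) for sw in (128, 64, 32)]  # -> [64, 32, 16] elements
# so the dimension k checked here must be a multiple of one of 64, 32, or 16.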
def _layout_for_memref_view(view_op: memref.ViewOp) -> ir.Attribute | None:
wgmma_use = None
uses = cast(ir.OpResult, view_op.result).uses
for use in uses:
user = use.owner
if isinstance(user, memref.CastOp):
# This memref is already cast, so we don't need to do anything.
return None
if isinstance(user, mgpu.WGMMAOp):
if wgmma_use is not None:
raise NotImplementedError(f"Multiple WGMMA consumers of {view_op}.")
wgmma_use = use
break
if (
not isinstance(user, mgpu.AsyncLoadOp)
and not isinstance(user, mgpu.AsyncStoreOp)
and not isinstance(user, vector.LoadOp)
and not isinstance(user, vector.StoreOp)
):
raise NotImplementedError(f"Unsupported user {user} of {view_op}.")
if wgmma_use is None:
# This memref is not used by a WGMMA operation, so we don't need to do
# anything.
return None
transforms = infer_wgmma_transforms(wgmma_use.owner)
if wgmma_use.operand_number == 1:
tile = transforms.a_tile
transpose = transforms.a_transpose
else:
tile = transforms.b_tile
transpose = transforms.b_transpose
transpose_attr = (
[mgpu.TransposeTransformAttr.get([1, 0, 2, 3])] if transpose else []
)
layout = mgpu.LayoutAttr.get(
2,
[mgpu.TileTransformAttr.get(tile)]
+ transpose_attr
+ [mgpu.SwizzleTransformAttr.get(transforms.swizzle)],
)
return layout
def _earliest_use(regions: list[ir.Region], uses: Sequence[ir.OpOperand]) -> ir.OpView: def _earliest_use(regions: list[ir.Region], uses: Sequence[ir.OpOperand]) -> ir.OpView:
owners = [use.owner for use in uses] owners = [use.owner for use in uses]
@ -607,11 +524,3 @@ def infer_layout(module: ir.Module):
for op in module.body: for op in module.body:
traverse_op(op, set_default_layout) traverse_op(op, set_default_layout)
def infer_memref_layouts_and_insert_casts(op: ir.OpView):
if op.name == "memref.view":
if layout := _layout_for_memref_view(op):
_insert_memref_layout_cast(layout, op)
for op in module.body:
traverse_op(op, infer_memref_layouts_and_insert_casts)

View File

@ -26,6 +26,9 @@ from typing import cast
from jax._src.lib import mosaic_gpu_dialect as mgpu from jax._src.lib import mosaic_gpu_dialect as mgpu
from jax._src.lib.mlir import ir from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import builtin
from jax._src.lib.mlir.dialects import gpu
from jax._src.lib.mlir.dialects import memref
from jax._src.lib.mlir.dialects import vector from jax._src.lib.mlir.dialects import vector
from . import fragmented_array as fa from . import fragmented_array as fa
@ -169,7 +172,6 @@ def _infer_vector_load_store_transforms(
return None return None
# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2. # TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2.
SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None) SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None)
@ -196,6 +198,60 @@ def _infer_slice_smem_transforms(op: SliceSMEMOp) -> OptionalTransforms:
return None if transforms is None else ([], [transforms]) return None if transforms is None else ([], [transforms])
# TODO(bchetioui,apaszke): this empty rule is necessary while Mosaic doesn't use
# the dialect in all cases.
# The rule is necessary in order to handle the lowering of `utils.memref_ptr`
# which is used in `_construct_smem_reftree`.
@partial(_add_transform_inference_rule, builtin.UnrealizedConversionCastOp)
def _infer_unrealized_conversion_cast_transforms(
_: builtin.UnrealizedConversionCastOp,
) -> OptionalTransforms:
return None
@partial(_add_transform_inference_rule, memref.ViewOp)
def _infer_memref_view_transforms(op: memref.ViewOp) -> OptionalTransforms:
if not isinstance(op.source.owner.opview, gpu.DynamicSharedMemoryOp):
raise NotImplementedError(
"Memref view transforms are only inferred when the op is a direct user "
f"of a DynamicSharedMemoryOp but got {op}."
)
transforms = inference_utils.value_transforms(op.source)
if transforms is not None:
raise NotImplementedError(
"memref view with in_transforms aren't yet supported"
)
uses = cast(ir.OpResult, op.result).uses
for op_operand_use in uses:
consumer = op_operand_use.owner
op_user = consumer.operands[op_operand_use.operand_number]
out_transforms = inference_utils.in_transforms_for_operand(
consumer, op_user
)
if transforms is not None and out_transforms is not None:
if transforms != out_transforms:
raise ValueError(
f"Conflicting transforms for {op_user} in {op}: "
f"{transforms} != {out_transforms}."
)
elif out_transforms is not None:
transforms = out_transforms
# TODO(bchetioui): do we actually need to assign a transform to the input of
# the view op? Presumably, it'll only be used to access scratch memory.
return None if transforms is None else ([], [transforms])
# TODO(bchetioui,apaszke): this empty rule is necessary while Mosaic doesn't use
# the dialect in all cases.
@partial(_add_transform_inference_rule, gpu.DynamicSharedMemoryOp)
def _infer_dynamic_smem_transforms(
_: gpu.DynamicSharedMemoryOp,
) -> OptionalTransforms:
return None
def _should_have_transforms(op: ir.OpView) -> bool: def _should_have_transforms(op: ir.OpView) -> bool:
"""Returns 'True' if the operation should be assigned in/out transforms.""" """Returns 'True' if the operation should be assigned in/out transforms."""
return any( return any(
@ -218,7 +274,6 @@ def infer_transforms(module: ir.Module):
specified. We error out if two distinct sets of transforms are competing to specified. We error out if two distinct sets of transforms are competing to
annotate the same memref. annotate the same memref.
""" """
def inference_step(op: ir.Operation): def inference_step(op: ir.Operation):
if not _should_have_transforms(op): if not _should_have_transforms(op):
return return

View File

@@ -24,7 +24,7 @@ import jax._src.xla_bridge as xb
 # cuda_plugin_extension locates inside jaxlib. `jaxlib` is for testing without
 # preinstalled jax cuda plugin packages.
-for pkg_name in ['jax_cuda12_plugin', 'jaxlib']:
+for pkg_name in ['jax_cuda12_plugin', 'jaxlib.cuda']:
   try:
     cuda_plugin_extension = importlib.import_module(
         f'{pkg_name}.cuda_plugin_extension'

View File

@@ -23,7 +23,7 @@ import jax._src.xla_bridge as xb
 # rocm_plugin_extension locates inside jaxlib. `jaxlib` is for testing without
 # preinstalled jax rocm plugin packages.
-for pkg_name in ['jax_rocm60_plugin', 'jaxlib']:
+for pkg_name in ['jax_rocm60_plugin', 'jaxlib.cuda']:
   try:
     rocm_plugin_extension = importlib.import_module(
         f'{pkg_name}.rocm_plugin_extension'

View File

@ -222,84 +222,3 @@ nanobind_extension(
"@xla//third_party/python_runtime:headers", "@xla//third_party/python_runtime:headers",
], ],
) )
cc_library(
name = "gpu_plugin_extension",
srcs = ["gpu_plugin_extension.cc"],
hdrs = ["gpu_plugin_extension.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
deps = [
":kernel_nanobind_helpers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@nanobind",
"@xla//xla:util",
"@xla//xla/ffi/api:c_api",
"@xla//xla/pjrt:status_casters",
"@xla//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_helpers",
"@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs",
"@xla//xla/python:py_client_gpu",
"@xla//xla/tsl/python/lib/core:numpy",
],
)
nanobind_extension(
name = "cuda_plugin_extension",
srcs = ["cuda_plugin_extension.cc"],
module_name = "cuda_plugin_extension",
deps = [
":gpu_plugin_extension",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@local_config_cuda//cuda:cuda_headers",
"@nanobind",
"@xla//xla/pjrt:status_casters",
"@xla//xla/tsl/cuda:cublas",
"@xla//xla/tsl/cuda:cudart",
],
)
nanobind_extension(
name = "rocm_plugin_extension",
srcs = ["rocm_plugin_extension.cc"],
module_name = "rocm_plugin_extension",
deps = [
":gpu_plugin_extension",
"@com_google_absl//absl/log",
"@com_google_absl//absl/strings",
"@local_config_rocm//rocm:hip",
"@local_config_rocm//rocm:rocm_headers",
"@nanobind",
],
)
# CPU kernels
# TODO(phawkins): Remove this forwarding target.
cc_library(
name = "cpu_kernels",
visibility = ["//visibility:public"],
deps = [
"//jaxlib/cpu:cpu_kernels",
],
alwayslink = 1,
)
# TODO(phawkins): Remove this forwarding target.
cc_library(
name = "gpu_kernels",
visibility = ["//visibility:public"],
deps = [
"//jaxlib/cuda:cuda_gpu_kernels",
],
alwayslink = 1,
)

View File

@ -657,6 +657,22 @@ py_library(
], ],
) )
nanobind_extension(
name = "cuda_plugin_extension",
srcs = ["cuda_plugin_extension.cc"],
module_name = "cuda_plugin_extension",
deps = [
"//jaxlib/gpu:gpu_plugin_extension",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@local_config_cuda//cuda:cuda_headers",
"@nanobind",
"@xla//xla/pjrt:status_casters",
"@xla//xla/tsl/cuda:cublas",
"@xla//xla/tsl/cuda:cudart",
],
)
# We cannot nest select and if_cuda_is_configured so we introduce # We cannot nest select and if_cuda_is_configured so we introduce
# a standalone py_library target. # a standalone py_library target.
py_library( py_library(
@@ -664,6 +680,6 @@ py_library(
     # `if_cuda_is_configured` will default to `[]`.
     deps = if_cuda_is_configured([
         ":cuda_gpu_support",
-        "//jaxlib:cuda_plugin_extension",
+        ":cuda_plugin_extension",
     ]),
 )

View File

@@ -20,7 +20,7 @@ limitations under the License.
 #include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
 #include "third_party/gpus/cuda/include/cuda.h"
-#include "jaxlib/gpu_plugin_extension.h"
+#include "jaxlib/gpu/gpu_plugin_extension.h"
 #include "xla/pjrt/status_casters.h"
 namespace nb = nanobind;

View File

@ -90,3 +90,32 @@ xla_py_proto_library(
visibility = jax_visibility("triton_proto_py_users"), visibility = jax_visibility("triton_proto_py_users"),
deps = [":triton_proto"], deps = [":triton_proto"],
) )
cc_library(
name = "gpu_plugin_extension",
srcs = ["gpu_plugin_extension.cc"],
hdrs = ["gpu_plugin_extension.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
deps = [
"//jaxlib:kernel_nanobind_helpers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@nanobind",
"@xla//xla:util",
"@xla//xla/ffi/api:c_api",
"@xla//xla/pjrt:status_casters",
"@xla//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_helpers",
"@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs",
"@xla//xla/python:py_client_gpu",
"@xla//xla/tsl/python/lib/core:numpy",
],
)

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "jaxlib/gpu_plugin_extension.h" #include "jaxlib/gpu/gpu_plugin_extension.h"
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>

View File

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef JAXLIB_GPU_PLUGIN_EXTENSION_H_
-#define JAXLIB_GPU_PLUGIN_EXTENSION_H_
+#ifndef JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_
+#define JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_
 #include "nanobind/nanobind.h"
@@ -24,4 +24,4 @@ void BuildGpuPluginExtension(nanobind::module_& m);
 } // namespace xla
-#endif // JAXLIB_GPU_PLUGIN_EXTENSION_H_
+#endif // JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_

View File

@@ -143,7 +143,7 @@ mlir::FailureOr<mlir::OpPassManager> GetPassPipeline(
   mlir::memref::registerMemRefPasses();
   mlir::registerConvertToLLVMPass();
   mlir::registerGPUPasses();
-  mlir::registerGpuLaunchSinkIndexComputations();
+  mlir::registerGpuLaunchSinkIndexComputationsPass();
   mosaic::gpu::registerGpuLaunchLoweringPass();
   mosaic::gpu::registerConvertGpuToLLVMPass();
   mosaic::gpu::registerByvalInsertionPass();

View File

@ -555,11 +555,25 @@ py_library(
], ],
) )
nanobind_extension(
name = "rocm_plugin_extension",
srcs = ["rocm_plugin_extension.cc"],
module_name = "rocm_plugin_extension",
deps = [
"//jaxlib/gpu:gpu_plugin_extension",
"@com_google_absl//absl/log",
"@com_google_absl//absl/strings",
"@local_config_rocm//rocm:hip",
"@local_config_rocm//rocm:rocm_headers",
"@nanobind",
],
)
 py_library(
     name = "gpu_only_test_deps",
     # `if_rocm_is_configured` will default to `[]`.
     deps = if_rocm_is_configured([
         ":rocm_gpu_support",
-        "//jaxlib:rocm_plugin_extension",
+        ":rocm_plugin_extension",
     ]),
 )

View File

@ -20,7 +20,7 @@ limitations under the License.
#include "absl/log/log.h" #include "absl/log/log.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "rocm/include/hip/hip_runtime.h" #include "rocm/include/hip/hip_runtime.h"
#include "jaxlib/gpu_plugin_extension.h" #include "jaxlib/gpu/gpu_plugin_extension.h"
namespace nb = nanobind; namespace nb = nanobind;

View File

@@ -143,16 +143,16 @@ py_binary(
     data = [
         "LICENSE.txt",
     ] + if_cuda([
-        "//jaxlib/mosaic/gpu:mosaic_gpu",
-        "//jaxlib:cuda_plugin_extension",
         "//jaxlib:version",
+        "//jaxlib/mosaic/gpu:mosaic_gpu",
+        "//jaxlib/cuda:cuda_plugin_extension",
         "//jaxlib/cuda:cuda_gpu_support",
         "//jax_plugins/cuda:plugin_pyproject.toml",
         "//jax_plugins/cuda:plugin_setup.py",
         "@local_config_cuda//cuda:cuda-nvvm",
     ]) + if_rocm([
-        "//jaxlib:rocm_plugin_extension",
         "//jaxlib:version",
+        "//jaxlib/rocm:rocm_plugin_extension",
         "//jaxlib/rocm:rocm_gpu_support",
         "//jax_plugins/rocm:plugin_pyproject.toml",
         "//jax_plugins/rocm:plugin_setup.py",

View File

@@ -110,7 +110,7 @@ def prepare_wheel_cuda(
       f"__main__/jaxlib/cuda/_triton.{pyext}",
       f"__main__/jaxlib/cuda/_hybrid.{pyext}",
       f"__main__/jaxlib/cuda/_versions.{pyext}",
-      f"__main__/jaxlib/cuda_plugin_extension.{pyext}",
+      f"__main__/jaxlib/cuda/cuda_plugin_extension.{pyext}",
       f"__main__/jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}",
       "__main__/jaxlib/mosaic/gpu/libmosaic_gpu_runtime.so",
       "__main__/jaxlib/version.py",
@@ -148,7 +148,7 @@ def prepare_wheel_rocm(
       f"__main__/jaxlib/rocm/_hybrid.{pyext}",
       f"__main__/jaxlib/rocm/_rnn.{pyext}",
       f"__main__/jaxlib/rocm/_triton.{pyext}",
-      f"__main__/jaxlib/rocm_plugin_extension.{pyext}",
+      f"__main__/jaxlib/rocm/rocm_plugin_extension.{pyext}",
       "__main__/jaxlib/version.py",
   ],
 )

View File

@@ -29,16 +29,23 @@ def call_kernel(
     kernel,
     grid: tuple[int, int],
     transpose_grid: bool,
-    *args
+    key: jax.Array,
+    total_size: tuple[int, int],
+    block_size: tuple[int, int],
+    tile_size: tuple[int, int],
 ):
   """Calls a kernel over a grid and concatenates results to a single array."""
   if transpose_grid:
     grid = (grid[1], grid[0])
   m, n = grid
-  return jnp.concatenate([
+  samples = jnp.concatenate([
       jnp.concatenate([
-          kernel((i, j), *args) for j in range(n)], axis=1)
+          kernel((i, j), key, total_size, block_size, tile_size)
+          for j in range(n)], axis=1)
       for i in range(m)], axis=0)
+  # Slice out the padding.
+  samples = samples[:total_size[0], :total_size[1]]
+  return samples

 def call_kernel_3d(
@ -73,10 +80,10 @@ def uniform_kernel(block_index, key, total_size, block_size, tile_size):
block_size=block_size, block_size=block_size,
tile_size=tile_size) tile_size=tile_size)
return blocked_sampler.sample_block(jax.random.uniform, return blocked_sampler.sample_block(jax.random.uniform,
keys, keys,
block_size=block_size, block_size=block_size,
tile_size=tile_size, tile_size=tile_size,
minval=0.0, maxval=1.0) minval=0.0, maxval=1.0)
class BlockedSamplerTest(jtu.JaxTestCase): class BlockedSamplerTest(jtu.JaxTestCase):
@@ -94,16 +101,25 @@ class BlockedSamplerTest(jtu.JaxTestCase):
       dict(testcase_name='16x256_vs_32x128', total_size=(32, 256),
            block_size_a=(16, 256), block_size_b=(32, 128),
            tile_size=(8, 128), transpose_grid=False),
+      dict(testcase_name='128x128_vs_128x256_padding',
+           total_size=(256, 128), block_size_a=(128, 128),
+           block_size_b=(128, 256), tile_size=(128, 128), transpose_grid=False),
+      dict(testcase_name='128x128_vs_128x256_padding2',
+           total_size=(257, 129), block_size_a=(128, 128),
+           block_size_b=(128, 256), tile_size=(128, 128), transpose_grid=False),
   )
   def test_block_shape_invariance(self, total_size, block_size_a,
                                   block_size_b, tile_size, transpose_grid):
     global_key = jax.random.key(0)
-    grid_a = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_a))
+    ceil_div = lambda x, y: (x + y - 1) // y
+    grid_a = tuple(ceil_div(_tot, _blk)
+                   for _tot, _blk in zip(total_size, block_size_a))
     result_a = call_kernel(
         uniform_kernel, grid_a, transpose_grid, global_key,
         total_size, block_size_a, tile_size)
-    grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b))
+    grid_b = tuple(ceil_div(_tot, _blk)
+                   for _tot, _blk in zip(total_size, block_size_b))
     result_b = call_kernel(
         uniform_kernel, grid_b, transpose_grid, global_key,
         total_size, block_size_b, tile_size)
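    # A quick check of the ceil_div grid arithmetic above (illustrative numbers
    # only): for the '128x128_vs_128x256_padding2' case, total_size=(257, 129)
    # with block_size_a=(128, 128) gives grid_a = (ceil_div(257, 128),
    # ceil_div(129, 128)) = (3, 2), while block_size_b=(128, 256) gives
    # grid_b = (3, 1); the padded rows and columns are sliced off inside
    # call_kernel before the two results are compared.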

View File

@ -15,9 +15,9 @@
import contextlib import contextlib
import io import io
import logging import logging
import os
import platform import platform
import re import re
import shlex
import subprocess import subprocess
import sys import sys
import tempfile import tempfile
@ -78,6 +78,31 @@ def capture_jax_logs():
logger.removeHandler(handler) logger.removeHandler(handler)
# Saves and runs script from the file in order to fix the problem with
# `tsl::Env::Default()->GetExecutablePath()` not working properly with
# command flag.
def _run(program, env_var = {}):
# strip the leading whitespace from the program script
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
with tempfile.NamedTemporaryFile(
mode="w+", encoding="utf-8", suffix=".py", dir=os.getcwd()
) as f:
f.write(textwrap.dedent(program))
f.flush()
python = sys.executable
assert "python" in python
if env_var:
env_var.update(os.environ)
else:
env_var = os.environ
# Make sure C++ logging is at default level for the test process.
p = subprocess.run([python, f.name], env=env_var, capture_output=True, text=True)
return type("", (object,), { "stdout": p.stdout, "stderr": p.stderr })
class LoggingTest(jtu.JaxTestCase): class LoggingTest(jtu.JaxTestCase):
@unittest.skipIf(platform.system() == "Windows", @unittest.skipIf(platform.system() == "Windows",
@ -90,36 +115,25 @@ class LoggingTest(jtu.JaxTestCase):
if sys.executable is None: if sys.executable is None:
raise self.skipTest("test requires access to python binary") raise self.skipTest("test requires access to python binary")
# Save script in file to fix the problem with o = _run("""
# `tsl::Env::Default()->GetExecutablePath()` not working properly with
# command flag.
with tempfile.NamedTemporaryFile(
mode="w+", encoding="utf-8", suffix=".py"
) as f:
f.write(textwrap.dedent("""
import jax import jax
jax.device_count() jax.device_count()
f = jax.jit(lambda x: x + 1) f = jax.jit(lambda x: x + 1)
f(1) f(1)
f(2) f(2)
jax.numpy.add(1, 1) jax.numpy.add(1, 1)
""")) """)
python = sys.executable
assert "python" in python
# Make sure C++ logging is at default level for the test process.
proc = subprocess.run([python, f.name], capture_output=True)
lines = proc.stdout.split(b"\n") lines = o.stdout.split("\n")
lines.extend(proc.stderr.split(b"\n")) lines.extend(o.stderr.split("\n"))
allowlist = [ allowlist = [
b"", (
( "An NVIDIA GPU may be present on this machine, but a"
b"An NVIDIA GPU may be present on this machine, but a" " CUDA-enabled jaxlib is not installed. Falling back to cpu."
b" CUDA-enabled jaxlib is not installed. Falling back to cpu." ),
), ]
] lines = [l for l in lines if l in allowlist]
lines = [l for l in lines if l not in allowlist] self.assertEmpty(lines)
self.assertEmpty(lines)
def test_debug_logging(self): def test_debug_logging(self):
# Warmup so we don't get "No GPU/TPU" warning later. # Warmup so we don't get "No GPU/TPU" warning later.
@ -164,19 +178,12 @@ class LoggingTest(jtu.JaxTestCase):
if sys.executable is None: if sys.executable is None:
raise self.skipTest("test requires access to python binary") raise self.skipTest("test requires access to python binary")
program = """ o = _run("""
import jax # this prints INFO logging from backend imports import jax # this prints INFO logging from backend imports
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation) jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
""" """, { "JAX_LOGGING_LEVEL": "INFO" })
# strip the leading whitespace from the program script log_output = o.stderr
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
# test INFO
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=INFO {sys.executable} -c"
f" '{program}'")
p = subprocess.run(cmd, capture_output=True, text=True)
log_output = p.stderr
info_lines = log_output.split("\n") info_lines = log_output.split("\n")
self.assertGreater(len(info_lines), 0) self.assertGreater(len(info_lines), 0)
self.assertIn("INFO", log_output) self.assertIn("INFO", log_output)
@ -194,22 +201,14 @@ class LoggingTest(jtu.JaxTestCase):
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation) jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
""" """
# strip the leading whitespace from the program script o = _run(program, { "JAX_LOGGING_LEVEL": "DEBUG" })
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
# test DEBUG log_output = o.stderr
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c"
f" '{program}'")
p = subprocess.run(cmd, capture_output=True, text=True)
log_output = p.stderr
self.assertIn("INFO", log_output) self.assertIn("INFO", log_output)
self.assertIn("DEBUG", log_output) self.assertIn("DEBUG", log_output)
# test JAX_DEBUG_MODULES o = _run(program, { "JAX_DEBUG_LOG_MODULES": "jax" })
cmd = shlex.split(f"env JAX_DEBUG_LOG_MODULES=jax {sys.executable} -c" log_output = o.stderr
f" '{program}'")
p = subprocess.run(cmd, capture_output=True, text=True)
log_output = p.stderr
self.assertIn("DEBUG", log_output) self.assertIn("DEBUG", log_output)
@jtu.skip_on_devices("tpu") @jtu.skip_on_devices("tpu")
@ -220,22 +219,15 @@ class LoggingTest(jtu.JaxTestCase):
raise self.skipTest("test requires access to python binary") raise self.skipTest("test requires access to python binary")
_separator = "---------------------------" _separator = "---------------------------"
program = f""" o = _run(f"""
import sys import sys
import jax # this prints INFO logging from backend imports import jax # this prints INFO logging from backend imports
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation) jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
jax.config.update("jax_logging_level", None) jax.config.update("jax_logging_level", None)
sys.stderr.write("{_separator}") sys.stderr.write("{_separator}")
jax.jit(lambda x: x)(1) # should not log anything now jax.jit(lambda x: x)(1) # should not log anything now
""" """, {"JAX_LOGGING_LEVEL": "DEBUG"})
log_output = o.stderr
# strip the leading whitespace from the program script
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c"
f" '{program}'")
p = subprocess.run(cmd, capture_output=True, text=True)
log_output = p.stderr
m = re.search(_separator, log_output) m = re.search(_separator, log_output)
self.assertTrue(m is not None) self.assertTrue(m is not None)
log_output_verbose = log_output[:m.start()] log_output_verbose = log_output[:m.start()]
@ -252,19 +244,13 @@ class LoggingTest(jtu.JaxTestCase):
if sys.executable is None: if sys.executable is None:
raise self.skipTest("test requires access to python binary") raise self.skipTest("test requires access to python binary")
program = """ o = _run("""
import jax # this prints INFO logging from backend imports import jax # this prints INFO logging from backend imports
jax.config.update("jax_debug_log_modules", "jax._src.compiler,jax._src.dispatch") jax.config.update("jax_debug_log_modules", "jax._src.compiler,jax._src.dispatch")
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation) jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
""" """, { "JAX_LOGGING_LEVEL": "DEBUG" })
# strip the leading whitespace from the program script log_output = o.stderr
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c"
f" '{program}'")
p = subprocess.run(cmd, capture_output=True, text=True)
log_output = p.stderr
self.assertNotEmpty(log_output) self.assertNotEmpty(log_output)
log_lines = log_output.strip().split("\n") log_lines = log_output.strip().split("\n")
# only one tracing line should be printed, if there's more than one # only one tracing line should be printed, if there's more than one
@ -285,31 +271,19 @@ class LoggingTest(jtu.JaxTestCase):
jax.distributed.initialize("127.0.0.1:12345", num_processes=1, process_id=0) jax.distributed.initialize("127.0.0.1:12345", num_processes=1, process_id=0)
""" """
# strip the leading whitespace from the program script o = _run(program, { "JAX_LOGGING_LEVEL": "DEBUG" })
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE) self.assertIn("Initializing CoordinationService", o.stderr)
# verbose logging: DEBUG, VERBOSE o = _run(program, { "JAX_LOGGING_LEVEL": "INFO" })
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c" self.assertIn("Initializing CoordinationService", o.stderr)
f" '{program}'")
p = subprocess.run(cmd, capture_output=True, text=True)
self.assertIn("Initializing CoordinationService", p.stderr)
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=INFO {sys.executable} -c"
f" '{program}'")
p = subprocess.run(cmd, capture_output=True, text=True)
self.assertIn("Initializing CoordinationService", p.stderr)
# verbose logging: WARNING, None # verbose logging: WARNING, None
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=WARNING {sys.executable} -c" o = _run(program, { "JAX_LOGGING_LEVEL": "WARNING" })
f" '{program}'") self.assertNotIn("Initializing CoordinationService", o.stderr)
p = subprocess.run(cmd, capture_output=True, text=True)
self.assertNotIn("Initializing CoordinationService", p.stderr)
cmd = shlex.split(f"{sys.executable} -c" o = _run(program)
f" '{program}'")
p = subprocess.run(cmd, capture_output=True, text=True)
if int(_default_TF_CPP_MIN_LOG_LEVEL) >= 1: if int(_default_TF_CPP_MIN_LOG_LEVEL) >= 1:
self.assertNotIn("Initializing CoordinationService", p.stderr) self.assertNotIn("Initializing CoordinationService", o.stderr)
if __name__ == "__main__": if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader()) absltest.main(testLoader=jtu.JaxTestLoader())

View File

@@ -1664,31 +1664,30 @@ class StreamAnnotationTest(jtu.JaxTestCase):
   def test_stream_annotation_inside_shmap(self):
     if not jtu.test_device_matches(["gpu"]):
       self.skipTest("Stream annotation is only supported on GPU.")
-    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
-    s = NamedSharding(mesh, P('x', 'y'))
-    np_inp = np.ones((8, 8))
+    mesh = jtu.create_mesh((2,), ('x',))
+    s = NamedSharding(mesh, P('x'))
+    np_inp = np.ones((8,))
     arr1 = jax.device_put(np_inp, s)
     arr2 = jax.device_put(np_inp, s)
     @compute_on('gpu_stream:1')
     @jax.jit
     def g(x, y):
-      return x @ y
+      return x * y
     @compute_on('gpu_stream:2')
     @jax.jit
     def h(x, y):
-      return x @ y
+      return x * y
     def f(x, y):
       z = g(x, y)
       w = h(3 * x, 2 * y)
       return z + w
-    out = jax.jit(shard_map(f, mesh=mesh,
-                            in_specs=(P('x', 'y'), P('x', 'y')),
-                            out_specs=P('x', 'y')))(arr1, arr2)
-    self.assertArraysEqual(out, arr1 * 28)
+    out = jax.jit(shard_map(f, mesh=mesh, in_specs=(P('x'), P('x')),
+                            out_specs=P('x')))(arr1, arr2)
+    self.assertArraysEqual(out, arr1 * 7)
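    # Worked check of the new expectation (all-ones inputs, elementwise ops):
    # z + w = x*y + (3*x)*(2*y) = (1 + 6) * ones, hence arr1 * 7.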
class ActivationOffloadingTest(jtu.JaxTestCase): class ActivationOffloadingTest(jtu.JaxTestCase):

View File

@ -55,6 +55,7 @@ else:
from jax.experimental.mosaic.gpu import launch_context from jax.experimental.mosaic.gpu import launch_context
from jax.experimental.mosaic.gpu import utils as utils from jax.experimental.mosaic.gpu import utils as utils
from jax.experimental.mosaic.gpu import profiler from jax.experimental.mosaic.gpu import profiler
from jax.experimental.mosaic.gpu import inference_utils
from jax.experimental.mosaic.gpu.utils import * # noqa: F403 from jax.experimental.mosaic.gpu.utils import * # noqa: F403
from jax._src.lib.mlir.dialects import gpu from jax._src.lib.mlir.dialects import gpu
from jax._src.lib.mlir.dialects import llvm from jax._src.lib.mlir.dialects import llvm
@ -1945,6 +1946,21 @@ class FragmentedArrayTest(TestCase):
)(inp) )(inp)
np.testing.assert_array_equal(inp, result) np.testing.assert_array_equal(inp, result)
@parameterized.product(in_shape=((128,), (64,)))
def test_wgmma_row_load_store_with_layout(self, in_shape):
def kernel(ctx, *args):
gmem_input, gmem_output, (smem_input, smem_output) = args
copy(gmem_input, smem_input)
t = mgpu.FragmentedArray.load_wgmma_row(smem_input)
t.store_untiled(smem_output)
copy(smem_output, gmem_output)
inp = out = self.prng.uniform(-1, 1, in_shape).astype(jnp.float32)
result = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (inp,), out, [inp, out],
)(inp)
np.testing.assert_array_equal(inp, result)
def test_warp_tree_reduce(self): def test_warp_tree_reduce(self):
def kernel(ctx, out, *_): def kernel(ctx, out, *_):
del ctx del ctx
@@ -2405,25 +2421,21 @@ class Swizzle:
     return mgpu_dialect.SwizzleTransformAttr.get(self.swizzle)

-def memref_with_transforms(
-    mem_ref: ir.Value,
-    transforms: Sequence[Tile | Transpose | Swizzle],
-) -> ir.Value:
-  """Casts the memref to one that has a layout with the given transforms."""
-  mem_ref_type = ir.MemRefType(mem_ref.type)
-  transform_attr = [t.attr() for t in transforms]
-  if not transform_attr:
-    return mem_ref
-  layout = mgpu_dialect.LayoutAttr.get(mem_ref_type.rank, transform_attr)
-  memref_new_type = ir.MemRefType.get(
-      mem_ref_type.shape,
-      mem_ref_type.element_type,
-      layout,
-      mem_ref_type.memory_space,
-  )
-  return memref.cast(memref_new_type, mem_ref)
+def set_in_transforms(
+    op: ir.OpView, transforms: Sequence[Sequence[Tile | Transpose | Swizzle]],
+) -> None:
+  """Annotates an op with in_transforms."""
+  if not transforms:
+    return
+  in_transforms = []
+  smem_refs = filter(inference_utils.is_transformable_smem_memref, op.operands)  # pylint: disable=undefined-variable
+  for _, result_transforms in jax.util.safe_zip(smem_refs, transforms):
+    in_transforms.append(
+        ir.ArrayAttr.get([t.attr() for t in result_transforms])
+    )
+  op.attributes["in_transforms"] = ir.ArrayAttr.get(in_transforms)
class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase): class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
@ -2556,7 +2568,6 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
): ):
del ctx del ctx
smem_ref, tma_barrier = smem smem_ref, tma_barrier = smem
smem_ref = memref_with_transforms(smem_ref, test_case.transforms)
dialect_barrier = tma_barrier.as_dialect_barrier_memref() dialect_barrier = tma_barrier.as_dialect_barrier_memref()
elt_type = ir.MemRefType(in_gmem_ref.type).element_type elt_type = ir.MemRefType(in_gmem_ref.type).element_type
@ -2571,7 +2582,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
slice_indices = [arith.constant(i32, i) for i in test_case.slice_indices] slice_indices = [arith.constant(i32, i) for i in test_case.slice_indices]
# GMEM -> SMEM # GMEM -> SMEM
mgpu_dialect.async_load( load_op = mgpu_dialect.AsyncLoadOp(
source=in_gmem_ref, source=in_gmem_ref,
destination=smem_ref, destination=smem_ref,
barrier=dialect_barrier, barrier=dialect_barrier,
@ -2579,6 +2590,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
slice_lengths=test_case.slice_lengths, slice_lengths=test_case.slice_lengths,
collective=ir.ArrayAttr.get([]), collective=ir.ArrayAttr.get([]),
) )
set_in_transforms(load_op, [test_case.transforms])
parities = memref.load(tma_barrier.phases, []) parities = memref.load(tma_barrier.phases, [])
parity, _ = tma_barrier.update_parities(parities) parity, _ = tma_barrier.update_parities(parities)
@ -2623,58 +2635,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
(x[input_slice]).reshape(test_case.shape_sliced), (x[input_slice]).reshape(test_case.shape_sliced),
) )
@staticmethod def test_pointwise_kernel_with_tma(self):
def pointwise_kernel_with_tma_cases(dtype: jnp.dtype):
@dataclasses.dataclass(frozen=True)
class TestCaseInput:
shape: tuple[int, ...]
transforms: tuple[Tile | Transpose | Swizzle, ...] = ()
result = []
for swizzle in mgpu_dialect.SwizzlingMode:
n = swizzle * 8 // jnp.finfo(dtype).bits
if swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle:
# We need at least one case with no transforms, as this is handled
# differently.
result.append(TestCaseInput(shape=[128, n]))
result.extend([
TestCaseInput(
shape=[128, n],
transforms=[Swizzle(swizzle)],
),
TestCaseInput(
shape=[2, 3, 64, n],
transforms=[Transpose([0, 1, 2, 3]), Swizzle(swizzle)],
),
TestCaseInput(
shape=[2, 3, 64, n],
transforms=[
Transpose([1, 0, 2, 3]),
Transpose([1, 0, 2, 3]),
Swizzle(swizzle),
],
),
TestCaseInput(
shape=[2, 3, 64, n],
transforms=[Transpose([1, 0, 2, 3]), Swizzle(swizzle)],
),
TestCaseInput(
shape=[128, n],
transforms=[Tile([64, n]), Swizzle(swizzle)],
),
TestCaseInput(
shape=[2 * 64, 3 * n],
transforms=[
Tile([64, n]),
Transpose([1, 0, 2, 3]),
Swizzle(swizzle),
],
),
])
return result
@parameterized.parameters(pointwise_kernel_with_tma_cases(jnp.bfloat16))
def test_pointwise_kernel_with_tma(self, test_case):
def add( def add(
ctx: launch_context.LaunchContext, ctx: launch_context.LaunchContext,
a_gmem_ref: ir.Value, a_gmem_ref: ir.Value,
@ -2701,9 +2662,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
# GMEM -> SMEM # GMEM -> SMEM
mgpu_dialect.async_load( mgpu_dialect.async_load(
source=a_gmem_ref, source=a_gmem_ref,
destination=memref_with_transforms( destination=a_smem_ref,
a_smem_ref, test_case.transforms
),
barrier=dialect_barrier, barrier=dialect_barrier,
indices=zero_slice_indices, indices=zero_slice_indices,
slice_lengths=shape, slice_lengths=shape,
@ -2711,9 +2670,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
) )
mgpu_dialect.async_load( mgpu_dialect.async_load(
source=b_gmem_ref, source=b_gmem_ref,
destination=memref_with_transforms( destination=b_smem_ref,
b_smem_ref, test_case.transforms
),
barrier=dialect_barrier, barrier=dialect_barrier,
indices=zero_slice_indices, indices=zero_slice_indices,
slice_lengths=shape, slice_lengths=shape,
@ -2740,9 +2697,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
# SMEM -> GMEM # SMEM -> GMEM
mgpu_dialect.async_store( mgpu_dialect.async_store(
source=memref_with_transforms( source=result_smem_ref,
result_smem_ref, test_case.transforms
),
destination=result_gmem_ref, destination=result_gmem_ref,
indices=zero_slice_indices, indices=zero_slice_indices,
slice_lengths=shape, slice_lengths=shape,
@ -2752,114 +2707,76 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
dtype = jnp.bfloat16 dtype = jnp.bfloat16
jax_shape = jax.ShapeDtypeStruct(test_case.shape, dtype) spec = jax.ShapeDtypeStruct((2, 3, 4, 64), dtype)
kernel = mgpu.as_gpu_kernel( kernel = mgpu.as_gpu_kernel(
add, add,
grid=(1, 1, 1), grid=(1, 1, 1),
block=(128, 1, 1), block=(128, 1, 1),
in_shape=(jax_shape, jax_shape), in_shape=(spec, spec),
out_shape=jax_shape, out_shape=spec,
smem_scratch_shape=[ smem_scratch_shape=[
jax_shape, spec,
jax_shape, spec,
jax_shape, spec,
core.TMABarrier(1), core.TMABarrier(1),
], ],
thread_semantics=mgpu.ThreadSemantics.Warpgroup, thread_semantics=mgpu.ThreadSemantics.Warpgroup,
) )
x = self.prng.uniform(-1, 1, test_case.shape).astype(dtype) x = self.prng.uniform(-1, 1, spec.shape).astype(dtype)
y = self.prng.uniform(-1, 1, test_case.shape).astype(dtype) y = self.prng.uniform(-1, 1, spec.shape).astype(dtype)
self.assertArraysEqual(jax.jit(kernel)(x, y), x + y + y) self.assertArraysEqual(jax.jit(kernel)(x, y), x + y + y)
class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase): class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
@staticmethod @parameterized.named_parameters(
def wgmma_kernel_with_tma_cases(abtype: jnp.dtype): (
@dataclasses.dataclass(frozen=True) f"swizzle={int(swizzle)}_{transpose_lhs=}_{transpose_rhs=}_{lhs_in_registers=}",
class TestCaseInput: swizzle,
shape_a: tuple[int, ...] = () transpose_lhs,
shape_b: tuple[int, ...] = () transpose_rhs,
shape_res: tuple[int, ...] = () lhs_in_registers,
transforms_a: tuple[Tile | Transpose | Swizzle, ...] = () )
transforms_b: tuple[Tile | Transpose | Swizzle, ...] = () for swizzle in mgpu_dialect.SwizzlingMode
transpose_a: bool = False for transpose_lhs in [False, True]
transpose_b: bool = False for transpose_rhs in [False, True]
load_a_in_registers: bool = False for lhs_in_registers in [False, True]
)
def test_wgmma_kernel_with_tma(
self, swizzle, transpose_lhs, transpose_rhs, load_a_in_registers
):
if swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle:
self.skipTest("No swizzle is not supported by wgmma")
result = [] if transpose_lhs or transpose_rhs:
for swizzle in [ self.skipTest("Transposes are not supported by transform inference yet.")
# TODO(dasenov): Add a test for kNoSwizzle, i.e. all swizzling modes.
mgpu_dialect.SwizzlingMode.k32ByteSwizzle,
mgpu_dialect.SwizzlingMode.k64ByteSwizzle,
mgpu_dialect.SwizzlingMode.k128ByteSwizzle,
]:
k = swizzle // np.dtype(abtype).itemsize
groups_m = 4
groups_n = 1
groups_k = 1
result.extend([
TestCaseInput(
shape_a=[groups_m * 64, groups_k * k],
shape_b=[groups_k * k, groups_n * k],
shape_res=[groups_m * 64, groups_n * k],
),
TestCaseInput(
shape_a=[groups_m * 64, groups_k * k],
shape_b=[groups_n * k, groups_k * k],
shape_res=[groups_m * 64, groups_n * k],
transpose_b=True,
),
TestCaseInput(
shape_a=[groups_m * 64, groups_k * k],
shape_b=[groups_k * k, groups_n * k],
shape_res=[groups_m * 64, groups_n * k],
transforms_a=[Tile([64, k]), Swizzle(swizzle)],
transforms_b=[Tile([k, k]), Swizzle(swizzle)],
),
TestCaseInput(
shape_a=[groups_m * 64, groups_k * k],
shape_b=[groups_k * k, groups_n * k],
shape_res=[groups_m * 64, groups_n * k],
transforms_a=[Tile([64, k]), Swizzle(swizzle)],
load_a_in_registers=True,
),
])
# The below only works for 128-byte swizzling. Regardless of transposing,
# TMA needs the size of the last dimension to be compatible with the
# swizzle.
if swizzle == mgpu_dialect.SwizzlingMode.k128ByteSwizzle:
result.append(
TestCaseInput(
shape_a=[groups_k * k, groups_m * 64],
shape_b=[groups_k * k, groups_n * k],
shape_res=[groups_m * 64, groups_n * k],
transpose_a=True,
)
)
return result
@parameterized.parameters(wgmma_kernel_with_tma_cases(jnp.bfloat16)) swizzle_elems = swizzle // np.dtype(jnp.bfloat16).itemsize
def test_wgmma_kernel_with_tma(self, test_case): tiling_m, tiling_n, tiling_k = 64, swizzle_elems, swizzle_elems
groups_m, groups_n, groups_k = 4, 1, 1
m, n, k = groups_m * tiling_m, groups_n * tiling_n, groups_k * tiling_k
lhs_shape = (k, m) if transpose_lhs else (m, k)
rhs_shape = (n, k) if transpose_rhs else (k, n)
out_shape = (m, n)
def matmul( def matmul(
ctx: launch_context.LaunchContext, ctx: launch_context.LaunchContext,
a_gmem_ref: ir.Value, lhs_gmem_ref: ir.Value,
b_gmem_ref: ir.Value, rhs_gmem_ref: ir.Value,
result_gmem_ref: ir.Value, result_gmem_ref: ir.Value,
smem: list[ir.Value], smem: list[ir.Value],
): ):
del ctx del ctx
a_smem_ref, b_smem_ref, result_smem_ref, tma_barrier = smem lhs_smem_ref, rhs_smem_ref, result_smem_ref, tma_barrier = smem
a_smem_ref = memref_with_transforms(a_smem_ref, test_case.transforms_a)
b_smem_ref = memref_with_transforms(b_smem_ref, test_case.transforms_b)
dialect_barrier = tma_barrier.as_dialect_barrier_memref() dialect_barrier = tma_barrier.as_dialect_barrier_memref()
ab_elt_type = ir.MemRefType(a_gmem_ref.type).element_type operand_elt_type = ir.MemRefType(lhs_gmem_ref.type).element_type
bytes_a = utils.bytewidth(ab_elt_type) * math.prod(test_case.shape_a) bytes_a = utils.bytewidth(operand_elt_type) * math.prod(lhs_shape)
bytes_b = utils.bytewidth(ab_elt_type) * math.prod(test_case.shape_b) bytes_b = utils.bytewidth(operand_elt_type) * math.prod(rhs_shape)
mgpu_dialect.arrive_expect_tx( mgpu_dialect.arrive_expect_tx(
barrier=dialect_barrier, barrier=dialect_barrier,
@ -2869,19 +2786,19 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0) zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0)
# GMEM -> SMEM # GMEM -> SMEM
mgpu_dialect.async_load( mgpu_dialect.async_load(
source=a_gmem_ref, source=lhs_gmem_ref,
destination=a_smem_ref, destination=lhs_smem_ref,
barrier=dialect_barrier, barrier=dialect_barrier,
indices=[zero_i32] * len(test_case.shape_a), indices=[zero_i32] * len(lhs_shape),
slice_lengths=test_case.shape_a, slice_lengths=lhs_shape,
collective=ir.ArrayAttr.get([]), collective=ir.ArrayAttr.get([]),
) )
mgpu_dialect.async_load( mgpu_dialect.async_load(
source=b_gmem_ref, source=rhs_gmem_ref,
destination=b_smem_ref, destination=rhs_smem_ref,
barrier=dialect_barrier, barrier=dialect_barrier,
indices=[zero_i32] * len(test_case.shape_b), indices=[zero_i32] * len(rhs_shape),
slice_lengths=test_case.shape_b, slice_lengths=rhs_shape,
collective=ir.ArrayAttr.get([]), collective=ir.ArrayAttr.get([]),
) )
@ -2889,29 +2806,34 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
parity, _ = tma_barrier.update_parities(parities) parity, _ = tma_barrier.update_parities(parities)
mgpu_dialect.wait(dialect_barrier, parity) mgpu_dialect.wait(dialect_barrier, parity)
# SMEM -> Registers
a_operand = a_smem_ref
zero_index = arith.constant(ir.IndexType.get(), 0)
if test_case.load_a_in_registers:
a_vector_type = ir.VectorType.get(test_case.shape_a, ab_elt_type)
zero_vector_indices = [zero_index] * len(test_case.shape_a)
a_operand = vector.load(a_vector_type, a_smem_ref, zero_vector_indices)
# Computation # Computation
shape_result = ir.MemRefType(result_gmem_ref.type).shape shape_result = ir.MemRefType(result_gmem_ref.type).shape
result_elt_type = ir.MemRefType(result_gmem_ref.type).element_type result_elt_type = ir.MemRefType(result_gmem_ref.type).element_type
acc_elt_type = ir.F32Type.get()
acc_type = ir.VectorType.get(shape_result, acc_elt_type)
zero_acc = arith.constant( zero_acc = arith.constant(
result_elt_type, ir.FloatAttr.get(result_elt_type, 0.0) result_elt_type, ir.FloatAttr.get(acc_elt_type, 0.0)
)
accumulator = vector.splat(
ir.VectorType.get(shape_result, result_elt_type), zero_acc
) )
accumulator = vector.splat(acc_type, zero_acc)
if transpose_lhs:
lhs_smem_ref = utils.memref_transpose(lhs_smem_ref, (1, 0))
if transpose_rhs:
rhs_smem_ref = utils.memref_transpose(rhs_smem_ref, (1, 0))
zero_index = arith.constant(ir.IndexType.get(), 0)
if load_a_in_registers:
# SMEM -> Registers
lhs_ty = ir.VectorType.get(lhs_shape, operand_elt_type)
zero_vector_indices = [zero_index] * len(lhs_shape)
lhs_operand = vector.load(lhs_ty, lhs_smem_ref, zero_vector_indices)
else:
lhs_operand = lhs_smem_ref
result = mgpu_dialect.wgmma( result = mgpu_dialect.wgmma(
accumulator, accumulator,
a_operand, lhs_operand,
b_smem_ref, rhs_smem_ref,
transpose_a=test_case.transpose_a,
transpose_b=test_case.transpose_b,
) )
nvvm.wgmma_commit_group_sync_aligned() nvvm.wgmma_commit_group_sync_aligned()
@ -2929,38 +2851,41 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
) )
nvvm.cp_async_bulk_wait_group(0) nvvm.cp_async_bulk_wait_group(0)
abtype = jnp.bfloat16 operand_type = jnp.bfloat16
acctype = jnp.float32 acctype = jnp.float32
a_jax_shape = jax.ShapeDtypeStruct(test_case.shape_a, abtype) lhs_jax_shape = jax.ShapeDtypeStruct(lhs_shape, operand_type)
b_jax_shape = jax.ShapeDtypeStruct(test_case.shape_b, abtype) rhs_jax_shape = jax.ShapeDtypeStruct(rhs_shape, operand_type)
result_jax_shape = jax.ShapeDtypeStruct(test_case.shape_res, acctype) result_jax_shape = jax.ShapeDtypeStruct(out_shape, acctype)
kernel = mgpu.as_gpu_kernel( kernel = mgpu.as_gpu_kernel(
matmul, matmul,
grid=(1, 1, 1), grid=(1, 1, 1),
block=(128, 1, 1), block=(128, 1, 1),
in_shape=(a_jax_shape, b_jax_shape), in_shape=(lhs_jax_shape, rhs_jax_shape),
out_shape=result_jax_shape, out_shape=result_jax_shape,
smem_scratch_shape=[ smem_scratch_shape=[
a_jax_shape, lhs_jax_shape,
b_jax_shape, rhs_jax_shape,
result_jax_shape, result_jax_shape,
core.TMABarrier(1), core.TMABarrier(1),
], ],
thread_semantics=mgpu.ThreadSemantics.Warpgroup, thread_semantics=mgpu.ThreadSemantics.Warpgroup,
) )
x = self.prng.uniform(-1, 1, test_case.shape_a).astype(abtype) prng_key = jax.random.key(1234)
y = self.prng.uniform(-1, 1, test_case.shape_b).astype(abtype) k0, k1 = jax.random.split(prng_key, 2)
x = jax.random.randint(k0, lhs_shape, 0, 2).astype(operand_type)
y = jax.random.randint(k1, rhs_shape, 0, 2).astype(operand_type)
transpose = lambda x, t: x.T if t else x transpose = lambda x, t: x.T if t else x
self.assertArraysAllClose( self.assertArraysAllClose(
jax.jit(kernel)(x, y), jax.jit(kernel)(x, y),
np.matmul( np.matmul(
transpose(x.reshape(test_case.shape_a), test_case.transpose_a), transpose(x, transpose_lhs),
transpose(y.reshape(test_case.shape_b), test_case.transpose_b), transpose(y, transpose_rhs)
), ),
atol=1e-5, atol=0,
rtol=1e-5, rtol=0,
) )

View File

@@ -2501,7 +2501,8 @@ class SymbolicPallasTest(PallasBaseTest):
     )
     assert exported_module is not None
     self.assertIn(
-        "tensor<?x?xf32>, %arg6: tensor<?x?xf32>, %arg7: tensor<?x?xf32>",
+        "%arg0: tensor<?x?xf32> loc(unknown), %arg1: tensor<?x?xf32>"
+        " loc(unknown), %arg2: tensor<?x?xf32>",
         str(exported_module),
     )
     x = jax.ShapeDtypeStruct((128, 1024), jax.numpy.float32)
@@ -2512,7 +2513,7 @@
     )
     assert exported_module is not None
     self.assertIn(
-        "@sym_matmul(%arg0: tensor<128x1024xf32>, %arg1: tensor<1024x512xf32>",
+        "call @sym_matmul(%arg0, %arg1)",
         str(exported_module),
     )

View File

@ -156,5 +156,36 @@ class InterpretTest(jtu.JaxTestCase):
lowered = jax.jit(matmul).lower(x, y).as_text(dialect="stablehlo") lowered = jax.jit(matmul).lower(x, y).as_text(dialect="stablehlo")
self.assertNotIn("dot_general", lowered) self.assertNotIn("dot_general", lowered)
@parameterized.parameters('nan', 'zero')
def test_uninitialized_memory(self, uninitialized_memory):
def kernel(o1_ref, o2_ref, o3_ref, t1_ref, t2_ref):
o1_ref[...] = t1_ref[...]
o2_ref[...] = t2_ref[...]
x, y, z = pl.pallas_call(
kernel,
out_shape=[
jax.ShapeDtypeStruct((8, 128), jnp.bfloat16),
jax.ShapeDtypeStruct((8, 128), jnp.int16),
jax.ShapeDtypeStruct((8, 128), jnp.float32),
],
in_specs=[],
scratch_shapes=[
pltpu.VMEM((8, 128), jnp.bfloat16),
pltpu.VMEM((8, 128), jnp.int16),
],
interpret=mosaic_interpret.TPUInterpretParams(
uninitialized_memory=uninitialized_memory),
)()
if uninitialized_memory == 'nan':
self.assertTrue(jnp.isnan(x).all())
np.testing.assert_equal(np.array(y), 32767)
self.assertTrue(jnp.isnan(z).all())
if uninitialized_memory == 'zero':
np.testing.assert_equal(np.array(x), 0)
np.testing.assert_equal(np.array(y), 0)
np.testing.assert_equal(np.array(z), 0)
if __name__ == "__main__": if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader()) absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -1680,6 +1680,47 @@ class CustomPartitionerTest(jtu.JaxTestCase):
jit_f = jax.jit(f, in_shardings=s, out_shardings=s) jit_f = jax.jit(f, in_shardings=s, out_shardings=s)
self.assertArraysEqual(x, jit_f(x)) self.assertArraysEqual(x, jit_f(x))
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_custom_partitioner_pytree_inputs(self):
self.skip_if_custom_partitioning_not_supported()
def partition(mesh, arg_shapes, result_shape):
def lower_fn(xs):
x, y, z = xs
return x + y + z
return (
mesh,
lower_fn,
arg_shapes[0][0].sharding,
jax.tree.map(lambda x: x.sharding, arg_shapes),
)
def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
return arg_shapes[0][0].sharding
def propagate_user_sharding(mesh, user_shape):
return user_shape.sharding
@custom_partitioning
def f(xs):
x, y, z = xs
return x + y + z
f.def_partition(
infer_sharding_from_operands=infer_sharding_from_operands,
partition=partition,
propagate_user_sharding=propagate_user_sharding,
sharding_rule='i j, i j, i j -> i j',
)
def f2(a):
return a + f((a, a, a))
pjit_f = pjit(f2, in_shardings=(P(None, 'x')), out_shardings=P('x'))
x = np.asarray(np.random.randint(0, 20, (32, 16)), dtype=np.float32)
self.assertArraysEqual(x * 4, pjit_f(x))
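    # Worked check of the expectation: f2(a) = a + f((a, a, a)) = a + 3*a = 4*a,
    # hence x * 4.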
@jtu.pytest_mark_if_available('multiaccelerator') @jtu.pytest_mark_if_available('multiaccelerator')
class AutoShardingPjitTest(jtu.JaxTestCase): class AutoShardingPjitTest(jtu.JaxTestCase):
@@ -7096,16 +7137,12 @@ class ShardingInTypesTest(jtu.JaxTestCase):
   def test_set_mesh(self):
     mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisType.Explicit,))
-    prev_mesh = config.device_context.value
-    prev_abstract_mesh = config.abstract_mesh_context_manager.value
     try:
-      jax.sharding.set_mesh(mesh)
+      prev_mesh = jax.sharding.set_mesh(mesh)
       out = reshard(np.arange(8), P('x'))
       self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
     finally:
-      config.device_context.set_local(prev_mesh)
-      config.abstract_mesh_context_manager.set_local(prev_abstract_mesh)
+      jax.sharding.set_mesh(prev_mesh)

   @jtu.with_user_mesh((2,), ('x',))
   def test_auto_axes_late_bind(self, mesh):

View File

@@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update XLA_SHA256 with the result.
-XLA_COMMIT = "3bb765472122548cc227b8bd2990f00bd533f438"
-XLA_SHA256 = "72126aac7602153aee985ca20f73d11c39e3ba9cfb8027492951e787559d0497"
+XLA_COMMIT = "df971129bd82e381954da0185b534220e21798a4"
+XLA_SHA256 = "11e9a568320cf7e7d61819620fd369927527ecefb68d5d1154b1521456bbdb72"

def repo():
  tf_http_archive(