mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Merge pull request #299 from ROCm/ci-upstream-sync-152_1
CI: 03/19/25 upstream sync
This commit is contained in:
commit
b505df9973
@ -1,8 +1,8 @@
|
||||
diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt
|
||||
index dfefaf042..2700e140e 100644
|
||||
index e7a2968e9..d37e11ee3 100644
|
||||
--- a/build/requirements_lock_3_13_ft.txt
|
||||
+++ b/build/requirements_lock_3_13_ft.txt
|
||||
@@ -4,6 +4,12 @@
|
||||
@@ -4,6 +4,11 @@
|
||||
#
|
||||
# pip-compile --allow-unsafe --generate-hashes --output-file=build/requirements_lock_3_13_ft.txt build/requirements.in
|
||||
#
|
||||
@ -10,12 +10,11 @@ index dfefaf042..2700e140e 100644
|
||||
+--pre
|
||||
+--extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple
|
||||
+numpy
|
||||
+
|
||||
+
|
||||
absl-py==2.1.0 \
|
||||
--hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \
|
||||
--hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff
|
||||
@@ -328,68 +334,6 @@ mpmath==1.3.0 \
|
||||
@@ -328,68 +333,6 @@ mpmath==1.3.0 \
|
||||
--hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \
|
||||
--hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c
|
||||
# via -r build/test-requirements.txt
|
||||
@ -81,6 +80,6 @@ index dfefaf042..2700e140e 100644
|
||||
- # matplotlib
|
||||
- # ml-dtypes
|
||||
- # scipy
|
||||
opt-einsum==3.4.0 \
|
||||
--hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \
|
||||
--hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac
|
||||
nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \
|
||||
--hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \
|
||||
--hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \
|
||||
|
19
.github/workflows/tsan.yaml
vendored
19
.github/workflows/tsan.yaml
vendored
@ -173,12 +173,18 @@ jobs:
|
||||
--bazel_options=--copt=-g \
|
||||
--clang_path=/usr/bin/clang-18
|
||||
|
||||
# Update the patch to use TSAN instrumented numpy
|
||||
# Patch build/requirements_lock_3_13_ft.txt to use TSAN instrumented NumPy
|
||||
sed -i "s|+--extra-index-url.*|+--extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/|" .github/workflows/requirements_lock_3_13_ft.patch
|
||||
cat .github/workflows/requirements_lock_3_13_ft.patch
|
||||
git apply .github/workflows/requirements_lock_3_13_ft.patch || exit 1
|
||||
|
||||
# Apply a patch to numpy in requirements lock 3.13 ft to use the nightly version
|
||||
git apply .github/workflows/requirements_lock_3_13_ft.patch
|
||||
# Display the content for debugging in logs
|
||||
cat build/requirements_lock_3_13_ft.txt | head -15
|
||||
# Check the patch
|
||||
cat build/requirements_lock_3_13_ft.txt | head -15 | grep -E "(--pre|.*${GITHUB_WORKSPACE}/wheelhouse/|numpy)"
|
||||
if [ "$?" == "1" ]; then echo "Could not find the patch in the requirements_lock_3_13_ft.txt"; exit 1; fi
|
||||
cat build/requirements_lock_3_13_ft.txt | grep -E "(numpy==)"
|
||||
# Fail the step if the pinned "numpy==" entry survived the patch; the message must be echoed, not executed.
if [ "$?" == "0" ]; then echo "Found original numpy dependency in the requirements_lock_3_13_ft.txt"; exit 1; fi
|
||||
|
||||
echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
|
||||
echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
|
||||
@ -188,6 +194,13 @@ jobs:
|
||||
bazel_exec=($(ls bazel-*))
|
||||
ln -s ${bazel_exec} bazel
|
||||
|
||||
# Check python version
|
||||
./bazel run --@rules_python//python/config_settings:py_freethreaded="yes" @python//:python3 -- -VV
|
||||
|
||||
# Check numpy version
|
||||
./bazel cquery @pypi_numpy//:* | grep whl
|
||||
|
||||
# Build JAX and run tests
|
||||
./bazel test \
|
||||
--test_env=JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES \
|
||||
--test_env=JAX_ENABLE_X64=$JAX_ENABLE_X64 \
|
||||
|
@ -33,7 +33,6 @@ from jax._src import errors
|
||||
from jax._src import profiler
|
||||
from jax._src import util
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.mesh import use_concrete_mesh
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.interpreters import xla
|
||||
@ -43,7 +42,8 @@ from jax._src.lib import xla_extension as xe
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src.sharding_impls import (
|
||||
PmapSharding, SingleDeviceSharding,
|
||||
device_replica_id_map, hashed_index, num_addressable_indices, local_to_global_shape) # pyformat: disable
|
||||
device_replica_id_map, hashed_index, num_addressable_indices,
|
||||
local_to_global_shape, use_concrete_mesh) # pyformat: disable
|
||||
from jax._src.typing import ArrayLike, DLDeviceType, DTypeLike
|
||||
from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method, cache
|
||||
import numpy as np
|
||||
|
@ -29,8 +29,8 @@ class SampleFn(Protocol):
|
||||
|
||||
|
||||
def _compute_tile_index(block_index: Sequence[int],
|
||||
total_size_in_blocks: Shape,
|
||||
block_size_in_tiles: Shape,
|
||||
total_size_in_tiles: Shape,
|
||||
tile_index_in_block: Sequence[int]) -> int:
|
||||
ndims = len(block_index)
|
||||
dim_size = 1
|
||||
@ -38,7 +38,7 @@ def _compute_tile_index(block_index: Sequence[int],
|
||||
for i in range(ndims-1, -1, -1):
|
||||
dim_idx = tile_index_in_block[i] + block_index[i] * block_size_in_tiles[i]
|
||||
total_idx += dim_idx * dim_size
|
||||
dim_size *= total_size_in_blocks[i] * block_size_in_tiles[i]
|
||||
dim_size *= total_size_in_tiles[i]
|
||||
return total_idx
|
||||
|
||||
|
||||
@ -103,15 +103,17 @@ def blocked_fold_in(
|
||||
_shape // _element for _shape, _element in zip(block_size, tile_size)
|
||||
)
|
||||
|
||||
total_size_in_blocks = tuple(
|
||||
_shape // _element for _shape, _element in zip(total_size, block_size)
|
||||
# Round up to make sure every tile is numbered.
|
||||
total_size_in_tiles = tuple(
|
||||
(_shape + _element - 1) // _element
|
||||
for _shape, _element in zip(total_size, tile_size)
|
||||
)
|
||||
|
||||
def _keygen_loop(axis, prefix):
|
||||
if axis == len(block_size_in_tiles):
|
||||
subtile_key = jax.random.fold_in(
|
||||
global_key, _compute_tile_index(
|
||||
block_index, total_size_in_blocks, block_size_in_tiles, prefix))
|
||||
block_index, block_size_in_tiles, total_size_in_tiles, prefix))
|
||||
return subtile_key
|
||||
else:
|
||||
keys = []
|
||||
|
@ -179,7 +179,7 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
|
||||
for sharding, s in zip(result_shardings, result_shapes)
|
||||
]
|
||||
closed_jaxpr = jax.make_jaxpr(lower_fn, axis_env=list(mesh.shape.items()))(
|
||||
*tiled_args
|
||||
*info.in_tree.unflatten(tiled_args)
|
||||
)
|
||||
if ([(o.shape, o.dtype) for o in closed_jaxpr.out_avals] !=
|
||||
[(t.shape, t.dtype) for t in tiled_results]):
|
||||
|
@ -41,7 +41,7 @@ from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval,
|
||||
JaxprEqn, Primitive, ShapedArray, DShapedArray,
|
||||
mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
|
||||
InputType, OutputType, get_referent, JaxprEqnContext)
|
||||
from jax._src.state.types import AbstractRef
|
||||
from jax._src.state.types import AbstractRef, ReadEffect
|
||||
from jax._src.tree_util import (PyTreeDef, treedef_tuple,
|
||||
tree_flatten, tree_structure)
|
||||
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
|
||||
@ -1423,7 +1423,8 @@ def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool],
|
||||
|
||||
|
||||
def has_effects(eqn: JaxprEqn) -> bool:
|
||||
effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)}
|
||||
effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)
|
||||
and not isinstance(e, ReadEffect)}
|
||||
return bool(effs)
|
||||
|
||||
|
||||
|
@ -615,8 +615,23 @@ def tanh(x: ArrayLike) -> Array:
|
||||
"""
|
||||
return tanh_p.bind(x)
|
||||
|
||||
@export
|
||||
def logistic(x: ArrayLike) -> Array:
|
||||
r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`."""
|
||||
r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`.
|
||||
|
||||
There is no HLO logistic/sigmoid primitive, so this lowers to a sequence
|
||||
of HLO arithmetic operations.
|
||||
|
||||
Args:
|
||||
x: input array. Must have floating point or complex dtype.
|
||||
|
||||
Returns:
|
||||
Array of the same shape and dtype as ``x`` containing the element-wise
|
||||
logistic/sigmoid function.
|
||||
|
||||
See also:
|
||||
- :func:`jax.nn.sigmoid`: an alternative API for this functionality.
|
||||
"""
|
||||
return logistic_p.bind(x)
|
||||
|
||||
@export
|
||||
@ -1018,12 +1033,45 @@ def bitwise_xor(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
"""
|
||||
return xor_p.bind(x, y)
|
||||
|
||||
@export
|
||||
def population_count(x: ArrayLike) -> Array:
|
||||
r"""Elementwise popcount, count the number of set bits in each element."""
|
||||
r"""Elementwise popcount, count the number of set bits in each element.
|
||||
|
||||
This function lowers directly to the `stablehlo.popcnt`_ operation.
|
||||
|
||||
Args:
|
||||
x: Input array. Must have integer dtype.
|
||||
|
||||
Returns:
|
||||
An array of the same shape and dtype as ``x``, containing the number of
|
||||
set bits in the input.
|
||||
|
||||
See also:
|
||||
- :func:`jax.lax.clz`: Elementwise count leading zeros.
|
||||
- :func:`jax.numpy.bitwise_count`: More flexible NumPy-style API for bit counts.
|
||||
|
||||
.. _stablehlo.popcnt: https://openxla.org/stablehlo/spec#popcnt
|
||||
"""
|
||||
return population_count_p.bind(x)
|
||||
|
||||
@export
|
||||
def clz(x: ArrayLike) -> Array:
|
||||
r"""Elementwise count-leading-zeros."""
|
||||
r"""Elementwise count-leading-zeros.
|
||||
|
||||
This function lowers directly to the `stablehlo.count_leading_zeros`_ operation.
|
||||
|
||||
Args:
|
||||
x: Input array. Must have integer dtype.
|
||||
|
||||
Returns:
|
||||
An array of the same shape and dtype as ``x``, containing the number of
|
||||
leading zero bits in the input.
|
||||
|
||||
See also:
|
||||
- :func:`jax.lax.population_count`: Count the number of set bits in each element.
|
||||
|
||||
.. _stablehlo.count_leading_zeros: https://openxla.org/stablehlo/spec#count_leading_zeros
|
||||
"""
|
||||
return clz_p.bind(x)
|
||||
|
||||
@export
|
||||
@ -1124,31 +1172,81 @@ def div(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
"""
|
||||
return div_p.bind(x, y)
|
||||
|
||||
@export
|
||||
def rem(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
r"""Elementwise remainder: :math:`x \bmod y`.
|
||||
|
||||
The sign of the result is taken from the dividend,
|
||||
and the absolute value of the result is always
|
||||
less than the divisor's absolute value.
|
||||
This function lowers directly to the `stablehlo.remainder`_ operation.
|
||||
The sign of the result is taken from the dividend, and the absolute value
|
||||
of the result is always less than the divisor's absolute value.
|
||||
|
||||
Integer division overflow
|
||||
(remainder by zero or remainder of INT_SMIN with -1)
|
||||
Integer division overflow (remainder by zero or remainder of INT_SMIN with -1)
|
||||
produces an implementation defined value.
|
||||
|
||||
Args:
|
||||
x, y: Input arrays. Must have matching int or float dtypes. If neither
|
||||
is a scalar, ``x`` and ``y`` must have the same number of dimensions
|
||||
and be broadcast compatible.
|
||||
|
||||
Returns:
|
||||
An array of the same dtype as ``x`` and ``y`` containing the remainder.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.remainder`: NumPy-style remainder with different
|
||||
sign semantics.
|
||||
|
||||
.. _stablehlo.remainder: https://openxla.org/stablehlo/spec#remainder
|
||||
"""
|
||||
return rem_p.bind(x, y)
|
||||
|
||||
@export
|
||||
def max(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
r"""Elementwise maximum: :math:`\mathrm{max}(x, y)`
|
||||
r"""Elementwise maximum: :math:`\mathrm{max}(x, y)`.
|
||||
|
||||
For complex numbers, uses a lexicographic comparison on the
|
||||
`(real, imaginary)` pairs."""
|
||||
This function lowers directly to the `stablehlo.maximum`_ operation for
|
||||
non-complex inputs. For complex numbers, this uses a lexicographic
|
||||
comparison on the `(real, imaginary)` pairs.
|
||||
|
||||
Args:
|
||||
x, y: Input arrays. Must have matching dtypes. If neither is a scalar,
|
||||
``x`` and ``y`` must have the same rank and be broadcast compatible.
|
||||
|
||||
Returns:
|
||||
An array of the same dtype as ``x`` and ``y`` containing the elementwise
|
||||
maximum.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.maximum`: more flexible NumPy-style maximum.
|
||||
- :func:`jax.lax.reduce_max`: maximum along an axis of an array.
|
||||
- :func:`jax.lax.min`: elementwise minimum.
|
||||
|
||||
.. _stablehlo.maximum: https://openxla.org/stablehlo/spec#maximum
|
||||
"""
|
||||
return max_p.bind(x, y)
|
||||
|
||||
@export
|
||||
def min(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
r"""Elementwise minimum: :math:`\mathrm{min}(x, y)`
|
||||
r"""Elementwise minimum: :math:`\mathrm{min}(x, y)`
|
||||
|
||||
For complex numbers, uses a lexicographic comparison on the
|
||||
`(real, imaginary)` pairs."""
|
||||
This function lowers directly to the `stablehlo.minimum`_ operation for
|
||||
non-complex inputs. For complex numbers, this uses a lexicographic
|
||||
comparison on the `(real, imaginary)` pairs.
|
||||
|
||||
Args:
|
||||
x, y: Input arrays. Must have matching dtypes. If neither is a scalar,
|
||||
``x`` and ``y`` must have the same rank and be broadcast compatible.
|
||||
|
||||
Returns:
|
||||
An array of the same dtype as ``x`` and ``y`` containing the elementwise
|
||||
minimum.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.minimum`: more flexible NumPy-style minimum.
|
||||
- :func:`jax.lax.reduce_min`: minimum along an axis of an array.
|
||||
- :func:`jax.lax.max`: elementwise maximum.
|
||||
|
||||
.. _stablehlo.minimum: https://openxla.org/stablehlo/spec#minimum
|
||||
"""
|
||||
return min_p.bind(x, y)
|
||||
|
||||
@export
|
||||
@ -1408,21 +1506,38 @@ def lt(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
"""
|
||||
return lt_p.bind(x, y)
|
||||
|
||||
@export
|
||||
def convert_element_type(operand: ArrayLike,
|
||||
new_dtype: DTypeLike | dtypes.ExtendedDType) -> Array:
|
||||
"""Elementwise cast.
|
||||
|
||||
Wraps XLA's `ConvertElementType
|
||||
<https://www.tensorflow.org/xla/operation_semantics#convertelementtype>`_
|
||||
operator, which performs an elementwise conversion from one type to another.
|
||||
Similar to a C++ `static_cast`.
|
||||
This function lowers directly to the `stablehlo.convert`_ operation, which
|
||||
performs an elementwise conversion from one type to another, similar to a
|
||||
C++ ``static_cast``.
|
||||
|
||||
Args:
|
||||
operand: an array or scalar value to be cast.
|
||||
new_dtype: a NumPy dtype representing the target type.
|
||||
new_dtype: a dtype-like object (e.g. a :class:`numpy.dtype`, a scalar type,
|
||||
or a valid dtype name) representing the target dtype.
|
||||
|
||||
Returns:
|
||||
An array with the same shape as `operand`, cast elementwise to `new_dtype`.
|
||||
An array with the same shape as ``operand``, cast elementwise to ``new_dtype``.
|
||||
|
||||
.. note::
|
||||
|
||||
If ``new_dtype`` is a 64-bit type and `x64 mode`_ is not enabled,
|
||||
the appropriate 32-bit type will be used in its place.
|
||||
|
||||
If the input is a JAX array and the input dtype and output dtype match, then
|
||||
the input array will be returned unmodified.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.astype`: NumPy-style dtype casting API.
|
||||
- :meth:`jax.Array.astype`: dtype casting as an array method.
|
||||
- :func:`jax.lax.bitcast_convert_type`: cast bits directly to a new dtype.
|
||||
|
||||
.. _stablehlo.convert: https://openxla.org/stablehlo/spec#convert
|
||||
.. _x64 mode: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
|
||||
"""
|
||||
return _convert_element_type(operand, new_dtype, weak_type=False) # type: ignore[unused-ignore,bad-return-type]
|
||||
|
||||
@ -1500,12 +1615,11 @@ def _convert_element_type(
|
||||
operand, new_dtype=new_dtype, weak_type=bool(weak_type),
|
||||
sharding=sharding)
|
||||
|
||||
@export
|
||||
def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
|
||||
"""Elementwise bitcast.
|
||||
|
||||
Wraps XLA's `BitcastConvertType
|
||||
<https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype>`_
|
||||
operator, which performs a bit cast from one type to another.
|
||||
This function lowers directly to the `stablehlo.bitcast_convert`_ operation.
|
||||
|
||||
The output shape depends on the size of the input and output dtypes with
|
||||
the following logic::
|
||||
@ -1525,6 +1639,12 @@ def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
|
||||
Returns:
|
||||
An array of shape `output_shape` (see above) and type `new_dtype`,
|
||||
constructed from the same bits as operand.
|
||||
|
||||
See also:
|
||||
- :func:`jax.lax.convert_element_type`: value-preserving dtype conversion.
|
||||
- :func:`jax.Array.view`: NumPy-style API for bitcast type conversion.
|
||||
|
||||
.. _stablehlo.bitcast_convert: https://openxla.org/stablehlo/spec#bitcast_convert
|
||||
"""
|
||||
new_dtype = dtypes.canonicalize_dtype(new_dtype)
|
||||
return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype)
|
||||
|
@ -1301,6 +1301,7 @@ def _ragged_all_to_all_transpose(
|
||||
mask = jax.numpy.cumsum(
|
||||
jax.numpy.zeros(t.shape[0], dtype='int32').at[output_offsets_].set(1)\
|
||||
.at[output_offsets_ + recv_sizes].add(-1))
|
||||
mask = jax.numpy.expand_dims(mask, (*range(1, t.ndim),))
|
||||
output_t = jax.numpy.where(mask, 0, t)
|
||||
return [operand_t, output_t] + [None] * 4
|
||||
|
||||
|
@ -543,6 +543,10 @@ class UseAbstractMeshContextManager:
|
||||
__slots__ = ['mesh', 'prev']
|
||||
|
||||
def __init__(self, mesh: AbstractMesh):
|
||||
if not isinstance(mesh, AbstractMesh):
|
||||
raise ValueError(
|
||||
"Expected mesh of type `jax.sharding.AbstractMesh`. Got type:"
|
||||
f" {type(mesh)}")
|
||||
self.mesh = mesh
|
||||
|
||||
def __enter__(self):
|
||||
@ -557,13 +561,5 @@ def get_abstract_mesh():
|
||||
val = jax_config.abstract_mesh_context_manager.value
|
||||
return empty_abstract_mesh if val is None else val
|
||||
|
||||
@contextlib.contextmanager
|
||||
def use_concrete_mesh(mesh: Mesh | None):
|
||||
prev_val = jax_config.device_context.swap_local(mesh)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
jax_config.device_context.set_local(prev_val)
|
||||
|
||||
def get_concrete_mesh() -> Mesh | None:
|
||||
return jax_config.device_context.value
|
||||
|
@ -71,7 +71,7 @@ import numpy as np
|
||||
|
||||
export = set_module('jax.numpy')
|
||||
|
||||
for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib']:
|
||||
for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib.cuda']:
|
||||
try:
|
||||
cuda_plugin_extension = importlib.import_module(
|
||||
f'{pkg_name}.cuda_plugin_extension'
|
||||
|
@ -35,6 +35,7 @@ from jax._src import linear_util as lu
|
||||
from jax._src import state
|
||||
from jax._src import tree_util
|
||||
from jax._src import util
|
||||
from jax._src.export._export import export
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.state import discharge as state_discharge
|
||||
@ -1165,14 +1166,16 @@ jax_core.custom_typechecks[core_map_p] = _core_map_typecheck_rule
|
||||
|
||||
|
||||
def lower_as_mlir(
|
||||
f, *args, dynamic_shapes=False, device=None, **kwargs
|
||||
f, *args, dynamic_shapes=False, device=None, static_argnames=(), **kwargs
|
||||
) -> mlir.ir.Module:
|
||||
with pallas_export_experimental(dynamic_shapes):
|
||||
lowered = jax.jit(f, device=device).lower(*args, **kwargs)
|
||||
stablehlo = lowered.compiler_ir(dialect="stablehlo")
|
||||
f = jax.jit(f, device=device, static_argnames=static_argnames)
|
||||
exported = export(f, platforms=["tpu"])(*args, **kwargs)
|
||||
stablehlo = exported.mlir_module()
|
||||
|
||||
return stablehlo # type: ignore[return-value]
|
||||
|
||||
|
||||
_out_shape_to_aval_mapping: dict[
|
||||
type[Any], Callable[[Any], jax_core.AbstractValue]
|
||||
] = {}
|
||||
|
@ -244,6 +244,13 @@ def pull_block_spec(
|
||||
_unwrap_block_spec_scalar_prefetch, out_block_specs
|
||||
)
|
||||
flat_block_specs, out_tree = jax.tree.flatten(block_specs_)
|
||||
jaxpr, used_consts, used_invars = pe.dce_jaxpr_consts(
|
||||
jaxpr,
|
||||
used_outputs=[True] * len(jaxpr.outvars),
|
||||
instantiate=True,
|
||||
)
|
||||
assert all(used_invars)
|
||||
assert all(used_consts)
|
||||
in_block_specs, env, read_usage_env = _pull_block_spec(
|
||||
jaxpr,
|
||||
tuple(flat_block_specs),
|
||||
|
@ -83,10 +83,15 @@ class TPUInterpretParams:
|
||||
replaced with arrays all of `jnp.inf`. Additionally any floating point
|
||||
operands to any operation will be replaced with (arrays of) `jnp.inf`.
|
||||
Default: False.
|
||||
uninitialized_memory: If "nan", allocated buffers are initialized to
|
||||
contain all NaNs (or to their maximum possible value for integers).
|
||||
If "zero", allocated buffers are initialized to all zeros.
|
||||
Default: "nan".
|
||||
"""
|
||||
dma_execution_mode: Literal["eager", "on_wait"] = "on_wait"
|
||||
detect_races: bool = False
|
||||
skip_floating_point_ops: bool = False
|
||||
uninitialized_memory: Literal["nan", "zero"] = "nan"
|
||||
|
||||
|
||||
VectorClock = np.ndarray
|
||||
@ -1114,7 +1119,8 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params):
|
||||
jax.ShapeDtypeStruct((), jnp.int16),
|
||||
device_id,
|
||||
TPU_MEMORY_SPACE_IDXS[v.aval.memory_space],
|
||||
primitives.uninitialized_value(v.aval.shape, v.aval.dtype),
|
||||
_uninitialized_value(
|
||||
v.aval.shape, v.aval.dtype, interpret_params),
|
||||
ordered=True))
|
||||
|
||||
out = _interpret(eqn.params['jaxpr'], *deferred_invals(), *allocs)
|
||||
@ -1279,16 +1285,19 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params):
|
||||
|
||||
def _initialize_output_vals(
|
||||
block_mappings_output: Iterable[BlockMapping],
|
||||
input_args, input_output_aliases) -> Sequence[jax.Array]:
|
||||
input_args, input_output_aliases,
|
||||
interpret_params: TPUInterpretParams,
|
||||
) -> Sequence[jax.Array]:
|
||||
oi_map = {v: k for k, v in input_output_aliases}
|
||||
output_vals = []
|
||||
for i, bm in enumerate(block_mappings_output):
|
||||
if i in oi_map:
|
||||
output_vals.append(input_args[oi_map[i]])
|
||||
else:
|
||||
output_vals.append(primitives.uninitialized_value(
|
||||
output_vals.append(_uninitialized_value(
|
||||
bm.array_shape_dtype.shape,
|
||||
bm.array_shape_dtype.dtype))
|
||||
bm.array_shape_dtype.dtype,
|
||||
interpret_params))
|
||||
return output_vals
|
||||
|
||||
def _compute_start_indices(block_mapping, loop_idx, *args):
|
||||
@ -1319,7 +1328,20 @@ def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing):
|
||||
dtype=np.bool_)])
|
||||
return lax.squeeze(output, squeeze_dims)
|
||||
|
||||
def _pad_to_block_dimension(value, block_shape):
|
||||
def _uninitialized_value(shape, dtype, interpret_params):
|
||||
if interpret_params.uninitialized_memory == 'nan':
|
||||
if jnp.issubdtype(dtype, jnp.floating):
|
||||
return jnp.full(shape, jnp.nan, dtype)
|
||||
elif jnp.issubdtype(dtype, jnp.integer):
|
||||
return jnp.full(shape, jnp.iinfo(dtype).max, dtype)
|
||||
elif jnp.issubdtype(dtype, jnp.bool):
|
||||
return jnp.full(shape, False, dtype)
|
||||
if interpret_params.uninitialized_memory == 'zero':
|
||||
return jnp.full(shape, 0, dtype)
|
||||
raise NotImplementedError(
|
||||
interpret_params.uninitialized_memory + ' + ' + str(dtype))
|
||||
|
||||
def _pad_to_block_dimension(value, block_shape, interpret_params):
|
||||
"""Pads values so the shape evenly divides into block dimensions.
|
||||
|
||||
For example, if values has a shape of (33, 2, 5) with a block_shape of
|
||||
@ -1338,7 +1360,7 @@ def _pad_to_block_dimension(value, block_shape):
|
||||
)
|
||||
if padded_shape != value.shape:
|
||||
pad_width = tuple((0, a-b) for a, b in zip(padded_shape, value.shape))
|
||||
pad_value = primitives.uninitialized_value(shape=(), dtype=value.dtype)
|
||||
pad_value = _uninitialized_value((), value.dtype, interpret_params)
|
||||
value = jnp.pad(value, pad_width, constant_values=pad_value)
|
||||
return value
|
||||
|
||||
@ -1397,7 +1419,7 @@ def interpret_pallas_call(
|
||||
]
|
||||
num_inputs = grid_mapping.num_inputs
|
||||
input_args = [
|
||||
_pad_to_block_dimension(a, bs)
|
||||
_pad_to_block_dimension(a, bs, interpret_params)
|
||||
for a, bs in zip(input_args, block_shapes[:num_inputs])
|
||||
]
|
||||
|
||||
@ -1407,11 +1429,12 @@ def interpret_pallas_call(
|
||||
output_vals = _initialize_output_vals(
|
||||
grid_mapping.block_mappings_output,
|
||||
scalars + input_args,
|
||||
input_output_aliases)
|
||||
input_output_aliases,
|
||||
interpret_params)
|
||||
num_outputs = grid_mapping.num_outputs
|
||||
output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs]
|
||||
for out_val, bs in zip(output_vals, output_block_shapes):
|
||||
padded_val = _pad_to_block_dimension(out_val, bs)
|
||||
padded_val = _pad_to_block_dimension(out_val, bs, interpret_params)
|
||||
output_buffer_shapes.append(padded_val.shape)
|
||||
output_buffer_ids.append(callback.io_callback(
|
||||
_allocate_buffer,
|
||||
@ -1466,7 +1489,8 @@ def interpret_pallas_call(
|
||||
jax.ShapeDtypeStruct((), jnp.int16),
|
||||
device_id,
|
||||
TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
|
||||
primitives.uninitialized_value(var.aval.shape, var.aval.dtype),
|
||||
_uninitialized_value(
|
||||
var.aval.shape, var.aval.dtype, interpret_params),
|
||||
ordered=True))
|
||||
|
||||
_, input_ids, kernel_output_ids, _ = split_list(
|
||||
|
@ -40,6 +40,7 @@ from jax._src import source_info_util
|
||||
from jax._src import state
|
||||
from jax._src import traceback_util
|
||||
from jax._src.cloud_tpu_init import is_cloud_tpu_older_than
|
||||
from jax._src.export._export import export
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.lax import lax as lax_internal
|
||||
@ -89,6 +90,11 @@ BOOL_MEMREF_TYPE = np.dtype('int32')
|
||||
# The value interpreted as a dynamic dimension by MLIR.
|
||||
MLIR_DYNAMIC = -9223372036854775808
|
||||
|
||||
# TODO(mvoz): Find a way to make this a contract we can share with the
|
||||
# export specialization step in XLA export.
|
||||
DIM_UPPER_BOUND = np.iinfo(np.int32).max
|
||||
DIM_LOWER_BOUND = -128
|
||||
|
||||
partial = functools.partial
|
||||
map, unsafe_map = safe_map, map # pylint: disable=redefined-builtin
|
||||
zip, unsafe_zip = safe_zip, zip # pylint: disable=redefined-builtin
|
||||
@ -102,17 +108,49 @@ class MeshContext:
|
||||
|
||||
# Note - On Export Placeholders
|
||||
#
|
||||
# Mosaic uses vector IR, which does not have a concept of dynamic
|
||||
# dimensions. We need to come up with a way to represent dynamic dimensions in
|
||||
# vector IR, and so we use placeholders, which are later replaced during
|
||||
# specialization.
|
||||
# Since the vector dialect used by Mosaic does not support dynamic shapes,
|
||||
# we replace all top-level symbolic dimensions with placeholder
|
||||
# constants (between max(int32) - 128 and max(int32)) and we keep a
|
||||
# mapping from the placeholder constants to SHLO functions that encode
|
||||
# the symbolic dimension expression, as a function of the dimension
|
||||
# variables.
|
||||
#
|
||||
# The calling convention of the produced MLIR module is the same as
|
||||
# regular mosaic module, except we add on two new attributes to the custom call
|
||||
# *per* intermediary placeholder dimension.
|
||||
#
|
||||
# The attributes are:
|
||||
#
|
||||
# tpu.dynamic_dimension_mapping_arg_name_<placeholder>
|
||||
# tpu.dynamic_dimension_mapping_module_<placeholder>
|
||||
#
|
||||
# The first attribute is a comma-separated list of the dimension variables
|
||||
# that are used to compute the symbolic dimension expression for the
|
||||
# placeholder. The second attribute is the MLIR module that contains the
|
||||
# SHLO functions that compute the symbolic dimension expression for the
|
||||
# placeholder.
|
||||
class LoweringDynamicShapeEnv:
|
||||
dim_expr_to_placeholder: dict[Any, ir.Value] = {}
|
||||
dim_expr_to_placeholder: dict[shape_poly._DimExpr, int] = {}
|
||||
placeholder_to_dim_expr: dict[int, shape_poly._DimExpr] = {}
|
||||
|
||||
def to_placeholder(self, dim_expr: Any) -> ir.Value:
|
||||
if jax_core.is_constant_dim(dim_expr):
|
||||
# avoid ints, these are not dynamic
|
||||
return dim_expr
|
||||
if dim_expr not in self.dim_expr_to_placeholder:
|
||||
next_val = np.iinfo(np.int32).max - len(self.dim_expr_to_placeholder)
|
||||
next_val = DIM_UPPER_BOUND - len(self.dim_expr_to_placeholder)
|
||||
if next_val < DIM_LOWER_BOUND:
|
||||
# In practice, even with the largest of programs, we rarely see
|
||||
# anything even close to this limit. It is arbitrary, and can be safely
|
||||
# increased if needed.
|
||||
raise ValueError(
|
||||
"Too many dynamic shapes in the input. Mosaic currently only"
|
||||
" supports up to 128 dynamic dimension values."
|
||||
)
|
||||
self.dim_expr_to_placeholder[dim_expr] = next_val
|
||||
# Reverse mapping - this is consumed to generate a table that is either
|
||||
# input<>placeholder or intermediary computation<>placeholder.
|
||||
self.placeholder_to_dim_expr[next_val] = dim_expr
|
||||
return self.dim_expr_to_placeholder[dim_expr]
|
||||
|
||||
|
||||
@ -622,6 +660,7 @@ def lower_jaxpr_to_module(
|
||||
"Pallas TPU requires a libTPU version that's at most a month old"
|
||||
)
|
||||
debug_info = jaxpr.debug_info
|
||||
_mosaic_lowering_dynamic_shape_env = None
|
||||
if dynamic_shape_replacement_enabled:
|
||||
_mosaic_lowering_dynamic_shape_env = LoweringDynamicShapeEnv()
|
||||
|
||||
@ -663,10 +702,12 @@ def lower_jaxpr_to_module(
|
||||
for_verification=for_verification,
|
||||
forward_compatible=lowering_context.is_forward_compat(),
|
||||
dynamic_shape_replacement_fn=dynamic_shape_replacement_fn,
|
||||
dynamic_shape_replacement_enabled=dynamic_shape_replacement_enabled,
|
||||
)
|
||||
m.body.append(func_op)
|
||||
sym_tab.insert(func_op)
|
||||
window_params = []
|
||||
static_grid = None
|
||||
grid = mosaic_grid_mapping.grid
|
||||
if grid:
|
||||
for i, bm in enumerate(grid_mapping.block_mappings):
|
||||
@ -738,7 +779,6 @@ def lower_jaxpr_to_module(
|
||||
]
|
||||
static_grid = dynamic_shape_replacement_fn(static_grid)
|
||||
func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(static_grid)
|
||||
|
||||
func_op.attributes["scalar_prefetch"] = ir.IntegerAttr.get(
|
||||
ir.IntegerType.get_signless(64), len(mosaic_grid_mapping.scalar_prefetch_types))
|
||||
func_op.attributes["scratch_operands"] = ir.IntegerAttr.get(
|
||||
@ -746,6 +786,60 @@ def lower_jaxpr_to_module(
|
||||
func_op.attributes["dimension_semantics"] = (
|
||||
mosaic_grid_mapping.get_dimension_semantics()
|
||||
)
|
||||
if dynamic_shape_replacement_enabled:
|
||||
if _mosaic_lowering_dynamic_shape_env is None:
|
||||
raise ValueError(
|
||||
"Dynamic shape env is None, invariant violated. Unreachable?"
|
||||
)
|
||||
|
||||
# Now we can use jax to compute the dynamic shape graph
|
||||
|
||||
if static_grid is not None:
|
||||
grid_vars = [
|
||||
_mosaic_lowering_dynamic_shape_env.placeholder_to_dim_expr.get(g, g)
|
||||
for g in static_grid
|
||||
]
|
||||
else:
|
||||
grid_vars = []
|
||||
|
||||
invars = [invar.aval for invar in jaxpr.invars]
|
||||
# Faux shape for grid, just to get the avals
|
||||
invars.append(jax.ShapeDtypeStruct(grid_vars, jax.numpy.int32))
|
||||
args_dimvars = shape_poly.all_dim_vars(invars)
|
||||
|
||||
# This is dimexpr var -> placeholder value for when we jit the dim expr
|
||||
env: dict[str, int] = {}
|
||||
for aval in args_dimvars:
|
||||
env[aval] = _mosaic_lowering_dynamic_shape_env.to_placeholder(aval)
|
||||
|
||||
for (
|
||||
placeholder,
|
||||
dim_expr,
|
||||
) in _mosaic_lowering_dynamic_shape_env.placeholder_to_dim_expr.items():
|
||||
top_level_names = list(env.keys())
|
||||
if dim_expr not in top_level_names:
|
||||
jitted_eval = jax.jit(
|
||||
jax_core.evaluate_shape,
|
||||
static_argnames=(
|
||||
"shape",
|
||||
"dim_vars",
|
||||
),
|
||||
keep_unused=True,
|
||||
)
|
||||
stablehlo = export(
|
||||
jitted_eval, platforms=[str(jax.devices()[0].platform)]
|
||||
)(
|
||||
(dim_expr,), tuple(args_dimvars), *(env[v] for v in args_dimvars)
|
||||
).mlir_module()
|
||||
arg_name = args_dimvars
|
||||
# See Note - On Export Placeholders for more details.
|
||||
m.operation.attributes[
|
||||
"tpu.dynamic_dimension_mapping_module_" + str(placeholder)
|
||||
] = ir.StringAttr.get(str(stablehlo))
|
||||
arg_name_str = ",".join(arg_name)
|
||||
m.operation.attributes[
|
||||
"tpu.dynamic_dimension_mapping_arg_name_" + str(placeholder)
|
||||
] = ir.StringAttr.get(arg_name_str)
|
||||
return m, mosaic_grid_mapping.get_extra_args()
|
||||
|
||||
|
||||
@ -828,6 +922,7 @@ def lower_jaxpr_to_func(
|
||||
dynamic_shape_replacement_fn: (
|
||||
Callable[[tuple[jax.DimSize, ...]], tuple[int, ...]] | None
|
||||
) = None,
|
||||
dynamic_shape_replacement_enabled: bool = False,
|
||||
) -> func.FuncOp:
|
||||
num_grid = len(mosaic_grid_mapping.grid_types)
|
||||
num_scalar_prefetch = len(mosaic_grid_mapping.scalar_prefetch_types)
|
||||
@ -874,6 +969,12 @@ def lower_jaxpr_to_func(
|
||||
)
|
||||
body_func.__name__ = name
|
||||
body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func)
|
||||
if dynamic_shape_replacement_enabled:
|
||||
# Skip verification for dynamic shape replacement - you can potentially
|
||||
# produce ir like ex: add(x[placeholder_0, placeholder_1], y[128, 128])
|
||||
# which is not valid, but we don't care since we'll run the verifier again
|
||||
# after the dynamic shape replacement pass.
|
||||
return body.func_op
|
||||
try:
|
||||
body.func_op.verify()
|
||||
except ir.MLIRError as e:
|
||||
@ -3851,3 +3952,15 @@ def _platform_index_lowering(
|
||||
|
||||
|
||||
lowering_rules[jax._src.lax.control_flow.platform_index_p] = _platform_index_lowering
|
||||
|
||||
|
||||
def _dim_as_value_lowering(ctx: mlir.LoweringRuleContext, *, dim):
  """Lowers a symbolic dimension to an int32 constant.

  `dim` is passed (as a 1-tuple) through the lowering context's
  `dynamic_shape_replacement_fn`; the first element of the result is the value
  that gets materialized as the constant.
  """
  replace_dims = ctx.lowering_context.dynamic_shape_replacement_fn
  replaced = replace_dims((dim,))
  placeholder = replaced[0]
  int32_type = _dtype_to_ir_type(jnp.dtype("int32"))
  return ir_constant(placeholder, mlir_type=int32_type)
|
||||
|
||||
|
||||
import jax._src.export.shape_poly as shape_poly
|
||||
|
||||
lowering_rules[shape_poly.dim_as_value_p] = _dim_as_value_lowering
|
||||
|
@ -1196,9 +1196,8 @@ def emit_pipeline(
|
||||
schedule = map_brefs(
|
||||
lambda _, x: get_pipeline_schedule(x), allocations, schedule)
|
||||
|
||||
def loop_body(step, indices):
|
||||
nonlocal allocations
|
||||
scheduler = Scheduler(
|
||||
def make_scheduler(step, indices):
|
||||
return Scheduler(
|
||||
step,
|
||||
indices,
|
||||
grid,
|
||||
@ -1208,13 +1207,15 @@ def emit_pipeline(
|
||||
init_accumulators=init_accumulators,
|
||||
trace_scopes=trace_scopes,
|
||||
)
|
||||
|
||||
def loop_body(step, indices):
|
||||
scheduler = make_scheduler(step, indices)
|
||||
with scheduler.grid_env():
|
||||
|
||||
# prepare any local VMEM aliases
|
||||
brefs = map_brefs(scheduler.alias_local_refs, allocations, refs)
|
||||
|
||||
# loop input handling phase
|
||||
map_brefs(scheduler.initialize, brefs, refs, schedule)
|
||||
map_brefs(scheduler.copy_in, brefs, refs, schedule)
|
||||
map_brefs(scheduler.wait_in, brefs, refs, schedule)
|
||||
|
||||
@ -1243,12 +1244,24 @@ def emit_pipeline(
|
||||
lambda: None)
|
||||
|
||||
map_brefs(scheduler.swap_slots, brefs, refs, schedule)
|
||||
map_brefs(scheduler.finalize, brefs, refs, schedule)
|
||||
|
||||
return _next_index(indices, grid)
|
||||
|
||||
# run pipeline
|
||||
lax.fori_loop(0, num_steps, loop_body, (0,) * len(grid))
|
||||
@pl.when(num_steps > 0)
|
||||
def _():
|
||||
# pipeline prologue
|
||||
initial_indices = (0,) * len(grid)
|
||||
scheduler = make_scheduler(0, initial_indices)
|
||||
brefs = map_brefs(scheduler.alias_local_refs, allocations, refs)
|
||||
map_brefs(scheduler.initialize, brefs, refs, schedule)
|
||||
|
||||
# pipeline loop
|
||||
next_indices = lax.fori_loop(0, num_steps, loop_body, initial_indices)
|
||||
|
||||
# pipeline epilogue
|
||||
final_indices = _prev_index(next_indices, grid)
|
||||
scheduler = make_scheduler(num_steps - 1, final_indices)
|
||||
brefs = map_brefs(scheduler.alias_local_refs, allocations, refs)
|
||||
map_brefs(scheduler.finalize, brefs, refs, schedule)
|
||||
|
||||
return pipeline
|
||||
|
||||
|
@ -1387,16 +1387,38 @@ def use_mesh(mesh: mesh_lib.Mesh):
|
||||
# if not core.trace_state_clean():
|
||||
# raise ValueError('`use_mesh` can only be used outside of `jax.jit`')
|
||||
|
||||
with (mesh_lib.use_abstract_mesh(mesh.abstract_mesh),
|
||||
mesh_lib.use_concrete_mesh(mesh)):
|
||||
with mesh_lib.use_abstract_mesh(mesh.abstract_mesh), use_concrete_mesh(mesh):
|
||||
yield
|
||||
|
||||
def set_mesh(mesh: mesh_lib.Mesh) -> None:
|
||||
if not isinstance(mesh, mesh_lib.Mesh):
|
||||
def set_mesh(mesh: mesh_lib.Mesh | None) -> mesh_lib.Mesh | None:
|
||||
"""Sets the given concrete mesh globally and returns the previous concrete
|
||||
mesh."""
|
||||
if mesh is not None and not isinstance(mesh, mesh_lib.Mesh):
|
||||
raise ValueError(
|
||||
f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}")
|
||||
if not core.trace_state_clean():
|
||||
raise ValueError('`set_mesh` can only be used outside of `jax.jit`.')
|
||||
|
||||
config.abstract_mesh_context_manager.set_local(mesh.abstract_mesh)
|
||||
config.device_context.set_local(mesh)
|
||||
if mesh is None:
|
||||
config.abstract_mesh_context_manager.set_global(mesh_lib.empty_abstract_mesh) # type: ignore
|
||||
else:
|
||||
config.abstract_mesh_context_manager.set_global(mesh.abstract_mesh) # type: ignore
|
||||
|
||||
prev_mesh = config.device_context.get_global()
|
||||
config.device_context.set_global(mesh)
|
||||
return prev_mesh
|
||||
|
||||
@contextlib.contextmanager
def use_concrete_mesh(mesh: mesh_lib.Mesh | None):
  """Context manager that installs `mesh` as the local device context.

  On entry `mesh` is swapped into `config.device_context`; on exit (including
  when the body raises) the value handed back by that swap is set again.

  Args:
    mesh: a `jax.sharding.Mesh`, or `None`.

  Raises:
    ValueError: if `mesh` is neither `None` nor a `jax.sharding.Mesh`.
  """
  if not (mesh is None or isinstance(mesh, mesh_lib.Mesh)):
    raise ValueError(
        f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}")
  # TODO(yashkatariya): Enable this.
  # if not core.trace_state_clean():
  #   raise ValueError('`use_concrete_mesh` can only be used outside of `jax.jit`.')

  previous = config.device_context.swap_local(mesh)
  try:
    yield
  finally:
    # Always put back whatever was installed before we entered.
    config.device_context.set_local(previous)
|
||||
|
@ -41,20 +41,16 @@ from jaxlib.mlir.dialects import memref
|
||||
from jaxlib.mlir.dialects import nvvm
|
||||
import numpy as np
|
||||
|
||||
if dialect is not None:
|
||||
from . import dialect_lowering
|
||||
from . import layout_inference
|
||||
else:
|
||||
dialect_lowering = None
|
||||
layout_inference = None
|
||||
|
||||
from . import profiler
|
||||
from . import utils
|
||||
from . import launch_context
|
||||
from . import tcgen05
|
||||
|
||||
# mypy: ignore-errors
|
||||
|
||||
from . import dialect_lowering
|
||||
from . import launch_context
|
||||
from . import layout_inference
|
||||
from . import profiler
|
||||
from . import tcgen05
|
||||
from . import transform_inference
|
||||
from . import utils
|
||||
|
||||
# MLIR can't find libdevice unless we point it to the CUDA path
|
||||
# TODO(apaszke): Unify with jax._src.lib.cuda_path
|
||||
CUDA_ROOT = "/usr/local/cuda"
|
||||
@ -584,6 +580,7 @@ def as_gpu_kernel(
|
||||
# Run Python lowering passes. The remaining passes will be run in C++ in
|
||||
# jax/jaxlib/mosaic/gpu/custom_call.cc
|
||||
layout_inference.infer_layout(module) # pytype: disable=attribute-error
|
||||
transform_inference.infer_transforms(module) # pytype: disable=attribute-error
|
||||
dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error
|
||||
|
||||
_initialize_scratch(launch_ctx, scratch_arr)
|
||||
@ -666,6 +663,7 @@ def as_torch_gpu_kernel(
|
||||
# Run Python lowering passes. The remaining passes will be run in C++ in
|
||||
# jax/jaxlib/mosaic/gpu/custom_call.cc
|
||||
layout_inference.infer_layout(module) # pytype: disable=attribute-error
|
||||
transform_inference.infer_transforms(module) # pytype: disable=attribute-error
|
||||
dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error
|
||||
|
||||
_initialize_scratch(launch_ctx, scratch_arr)
|
||||
|
@ -17,6 +17,7 @@
|
||||
from collections.abc import Callable
|
||||
import dataclasses
|
||||
import functools
|
||||
import itertools
|
||||
import operator
|
||||
from typing import Any, Sequence, Type, cast
|
||||
|
||||
@ -58,7 +59,7 @@ class LoweringContext:
|
||||
if not _should_lower(op):
|
||||
return
|
||||
|
||||
if (name := op.OPERATION_NAME) not in _lowerings:
|
||||
if (name := op.OPERATION_NAME) not in _lowerings: # pytype: disable=attribute-error
|
||||
raise NotImplementedError(f"Missing lowering rule for {op}")
|
||||
|
||||
lowering_rule = _lowerings[name]
|
||||
@ -227,6 +228,60 @@ def _arith_constant_op_lowering_rule(
|
||||
]
|
||||
|
||||
|
||||
def _check_transforms_and_swizzle_are_supported(
    ref_ty: ir.MemRefType,
    transforms: Sequence[launch_context.MemRefTransform],
    swizzle: mgpu.SwizzlingMode,
    minimum_swizzle: mgpu.SwizzlingMode = mgpu.SwizzlingMode.kNoSwizzle,
):
  """Checks that the list of provided transforms and swizzle are supported.

  Currently, we allow the following:
    - any swizzle that is larger than or equal to `minimum_swizzle`;
    - optionally, a single tile transform (with rank equal to the rank of the
      memref being annotated);
    - optionally, a single transpose transform.

  Args:
    ref_ty: type of the memref being annotated; only the length of its shape
      is inspected.
    transforms: the transforms to validate.
    swizzle: the swizzle to validate.
    minimum_swizzle: the smallest swizzle that is accepted.

  Raises:
    NotImplementedError: if any of the constraints above is violated.
  """
  if swizzle < minimum_swizzle:
    raise NotImplementedError(
        f"Unsupported swizzle {swizzle} smaller than {minimum_swizzle}."
    )

  # Partition over the *whole* sequence. `itertools.groupby` only merges
  # adjacent items with equal keys, so collecting its groups into a dict keeps
  # just the last run per key and would let e.g. (tile, transpose, tile) slip
  # past the "more than one" checks below.
  tile_transforms = []
  other_transforms = []
  for transform in transforms:
    if isinstance(transform, launch_context.TileTransform):
      tile_transforms.append(transform)
    else:
      other_transforms.append(transform)

  if len(tile_transforms) > 1:
    raise NotImplementedError(
        f"{tile_transforms} contains more than one tile transform."
    )

  if len(tile_transforms) == 1:
    if len(tile_transforms[0].tiling) != len(ref_ty.shape):
      raise NotImplementedError(
          f"Only tile transforms with rank equal to the rank of the memref "
          f"being annotated are supported but got {tile_transforms[0]} for "
          f"{ref_ty}."
      )

  if len(other_transforms) > 1:
    raise NotImplementedError(
        f"{other_transforms} contains more than one transform."
    )

  if len(other_transforms) == 1:
    if not isinstance(other_transforms[0], launch_context.TransposeTransform):
      raise NotImplementedError(
          f"{other_transforms[0]} is not a transpose transform."
      )
|
||||
|
||||
|
||||
@_register_lowering(vector.LoadOp)
|
||||
def _vector_load_op_lowering_rule(
|
||||
_: LoweringContext, vector_load_op: vector.LoadOp
|
||||
@ -260,8 +315,11 @@ def _vector_load_op_lowering_rule(
|
||||
vec_size=strided_layout.vec_size,
|
||||
)
|
||||
elif layouts.from_layout_attr(out_layout_attr) == fa.WGMMA_LAYOUT:
|
||||
layout = ir.MemRefType(vector_load_op.base.type).layout
|
||||
swizzle, transforms = memref_layout_to_swizzle_and_transforms(layout)
|
||||
swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
|
||||
inference_utils.in_transforms(vector_load_op)[0]
|
||||
)
|
||||
ref_ty = ir.MemRefType(vector_load_op.base.type)
|
||||
_check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle)
|
||||
transformed_ref = transform_memref(vector_load_op.base, transforms)
|
||||
fragmented_array = fa.FragmentedArray.load_tiled(
|
||||
transformed_ref,
|
||||
@ -297,8 +355,22 @@ def _vector_store_op_lowering_rule(
|
||||
vector_store_op.valueToStore, to_store_layout
|
||||
)
|
||||
|
||||
# TODO(dasenov): This is not efficient for WGMMA layouts
|
||||
fragmented_array.store_untiled(vector_store_op.base)
|
||||
if fragmented_array.layout == fa.WGMMA_LAYOUT:
|
||||
swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
|
||||
inference_utils.in_transforms(vector_store_op)[0]
|
||||
)
|
||||
ref_ty = ir.MemRefType(vector_store_op.base.type)
|
||||
_check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle)
|
||||
fragmented_array.store_tiled(
|
||||
transform_memref(vector_store_op.base, transforms), swizzle
|
||||
)
|
||||
elif (isinstance(fragmented_array.layout, fa.WGStridedFragLayout) or
|
||||
isinstance(fragmented_array.layout, fa.WGSplatFragLayout)):
|
||||
fragmented_array.store_untiled(vector_store_op.base)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{vector_store_op} has an unsupported layout: {to_store_layout}"
|
||||
)
|
||||
|
||||
return []
|
||||
|
||||
@ -362,39 +434,43 @@ def _vector_reduction_op_lowering_rule(
|
||||
return [_fragmented_array_to_ir(result, op.result.type)]
|
||||
|
||||
|
||||
def memref_layout_to_swizzle_and_transforms(
|
||||
layout: ir.Attribute,
|
||||
def swizzle_and_transforms_from_transforms_attr(
    transforms: ir.ArrayAttr,
) -> tuple[mgpu.SwizzlingMode, tuple[launch_context.MemRefTransform, ...]]:
  """Returns the swizzle and MemrefTransforms for the given transforms.

  Args:
    transforms: a list of transform attributes.

  Returns:
    A tuple containing the swizzle mode and MemRefTransforms corresponding to
    the parameter transforms. If `transforms` is empty, or does not contain
    any swizzling transform, the swizzle mode is assumed to be kNoSwizzle.

  Raises:
    ValueError: if a swizzling transform is followed by any transform, or if
      a transform is not a swizzle, tile, or transpose transform attribute.
  """
  swizzle = None
  gmem_transforms: list[launch_context.MemRefTransform] = []

  for transform in transforms:
    # A swizzle, if present, must be the last entry of the list.
    if swizzle is not None:
      raise ValueError(f"{transforms} contain more transforms after swizzle.")
    if mgpu.SwizzleTransformAttr.isinstance(transform):
      # TODO(dasenov): Swizzling can change if the ref is sliced in certain
      # ways. We might want to enforce some restrictions here.
      swizzle = mgpu.SwizzleTransformAttr(transform).swizzle
    elif mgpu.TileTransformAttr.isinstance(transform):
      tiling = mgpu.TileTransformAttr(transform).tiling
      tiling_transform = launch_context.TileTransform(tuple(tiling))
      gmem_transforms.append(tiling_transform)
    elif mgpu.TransposeTransformAttr.isinstance(transform):
      permutation = mgpu.TransposeTransformAttr(transform).permutation
      transpose_transform = launch_context.TransposeTransform(
          tuple(permutation)
      )
      gmem_transforms.append(transpose_transform)
    else:
      # The f-prefix matters: without it the message shows the literal text
      # "{transform}" instead of the offending attribute.
      raise ValueError(f"Unknown transform: {transform}")

  return swizzle or mgpu.SwizzlingMode.kNoSwizzle, tuple(gmem_transforms)
|
||||
|
||||
@ -434,8 +510,14 @@ def _mgpu_async_load_op_lowering_rule(
|
||||
assert ctx.launch_context is not None
|
||||
barrier = utils.BarrierRef.from_dialect_barrier_memref(load_op.barrier)
|
||||
|
||||
dst_layout = ir.MemRefType(load_op.destination.type).layout
|
||||
swizzle, transforms = memref_layout_to_swizzle_and_transforms(dst_layout)
|
||||
if inference_utils.has_in_transforms_set(load_op):
|
||||
[transforms] = inference_utils.in_transforms(load_op)
|
||||
swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
|
||||
transforms
|
||||
)
|
||||
else:
|
||||
swizzle = mgpu.SwizzlingMode.kNoSwizzle
|
||||
transforms = ()
|
||||
|
||||
gmem_slice = []
|
||||
for idx_i32, size in zip(load_op.indices, load_op.slice_lengths):
|
||||
@ -464,8 +546,14 @@ def _mgpu_async_store_op_lowering_rule(
|
||||
) -> Sequence[ir.Value]:
|
||||
assert ctx.launch_context is not None
|
||||
|
||||
src_layout = ir.MemRefType(store_op.source.type).layout
|
||||
swizzle, transforms = memref_layout_to_swizzle_and_transforms(src_layout)
|
||||
if inference_utils.has_in_transforms_set(store_op):
|
||||
[transforms] = inference_utils.in_transforms(store_op)
|
||||
swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
|
||||
transforms
|
||||
)
|
||||
else:
|
||||
swizzle = mgpu.SwizzlingMode.kNoSwizzle
|
||||
transforms = ()
|
||||
|
||||
gmem_slice = []
|
||||
for idx_i32, size in zip(store_op.indices, store_op.slice_lengths):
|
||||
@ -673,6 +761,9 @@ def _bitcast_op_lowering_rule(
|
||||
def _mgpu_wgmma_op_lowering_rule(
|
||||
_: LoweringContext, wgmma_op: mgpu.WGMMAOp
|
||||
) -> Sequence[ir.Value]:
|
||||
if wgmma_op.transpose_a or wgmma_op.transpose_b:
|
||||
raise ValueError("Transpose arguments are to be deleted.")
|
||||
|
||||
fa_layouts = (
|
||||
*inference_utils.in_layouts(wgmma_op),
|
||||
*inference_utils.out_layouts(wgmma_op),
|
||||
@ -691,25 +782,38 @@ def _mgpu_wgmma_op_lowering_rule(
|
||||
regs = acc_in.to_layout(fa.WGMMA_LAYOUT)
|
||||
acc = wgmma.WGMMAAccumulator.from_registers(regs)
|
||||
|
||||
b_layout = ir.MemRefType(wgmma_op.b.type).layout
|
||||
b_swizzle, b_transforms = memref_layout_to_swizzle_and_transforms(b_layout)
|
||||
if ir.VectorType.isinstance(wgmma_op.a.type):
|
||||
a_transforms = None
|
||||
b_transforms = inference_utils.in_transforms(wgmma_op)[0]
|
||||
else:
|
||||
a_transforms, b_transforms = inference_utils.in_transforms(wgmma_op)
|
||||
|
||||
b_swizzle, b_transforms = swizzle_and_transforms_from_transforms_attr(
|
||||
b_transforms
|
||||
)
|
||||
minimum_swizzle = mgpu.SwizzlingMode.k32ByteSwizzle
|
||||
ref_ty = ir.MemRefType(wgmma_op.b.type)
|
||||
_check_transforms_and_swizzle_are_supported(
|
||||
ref_ty, b_transforms, b_swizzle, minimum_swizzle
|
||||
)
|
||||
b_operand = transform_memref(wgmma_op.b, b_transforms)
|
||||
if wgmma_op.transpose_b:
|
||||
b_operand = utils.memref_transpose(b_operand, (0, 1, 3, 2))
|
||||
|
||||
if ir.VectorType.isinstance(wgmma_op.a.type):
|
||||
a_operand = _fragmented_array_from_ir(wgmma_op.a, wgmma_layout)
|
||||
else:
|
||||
a_layout = ir.MemRefType(wgmma_op.a.type).layout
|
||||
a_swizzle, a_transforms = memref_layout_to_swizzle_and_transforms(a_layout)
|
||||
a_swizzle, a_transforms = swizzle_and_transforms_from_transforms_attr(
|
||||
a_transforms
|
||||
)
|
||||
ref_ty = ir.MemRefType(wgmma_op.a.type)
|
||||
_check_transforms_and_swizzle_are_supported(
|
||||
ref_ty, a_transforms, a_swizzle, minimum_swizzle
|
||||
)
|
||||
if a_swizzle != b_swizzle:
|
||||
raise ValueError(
|
||||
f"Non-matching swizzles of operands a and b in WGMMA: {a_swizzle} !="
|
||||
f" {b_swizzle}"
|
||||
)
|
||||
a_operand = transform_memref(wgmma_op.a, a_transforms)
|
||||
if wgmma_op.transpose_a:
|
||||
a_operand = utils.memref_transpose(a_operand, (0, 1, 3, 2))
|
||||
|
||||
new_acc = wgmma.wgmma(acc, a_operand, b_operand, swizzle=b_swizzle)
|
||||
|
||||
@ -902,7 +1006,7 @@ def single_thread_predicates(module: ir.Module) -> tuple[ir.Value, ir.Value]:
|
||||
def _should_lower(op: ir.OpView) -> bool:
|
||||
"""Returns 'true' if the operation should be lowered."""
|
||||
return (
|
||||
op.OPERATION_NAME.startswith("mosaic_gpu.")
|
||||
op.OPERATION_NAME.startswith("mosaic_gpu.") # pytype: disable=attribute-error
|
||||
or inference_utils.should_have_layout(op)
|
||||
or any(bool(b) for r in op.regions for b in r) # Does it have subblocks?
|
||||
)
|
||||
|
@ -387,7 +387,21 @@ class WGMMARowFragLayout:
|
||||
"""[m] matrix, where m % 64 == 0."""
|
||||
|
||||
def thread_idxs(self, shape):
|
||||
raise NotImplementedError
|
||||
index = ir.IndexType.get()
|
||||
assert len(shape) == 1
|
||||
assert shape[0] % 64 == 0
|
||||
tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx())
|
||||
tid_wg = arith.remui(tid, c(WARPGROUP_SIZE, index))
|
||||
warp_idx = arith.divui(tid_wg, c(32, index))
|
||||
lane_id = arith.remui(tid_wg, c(32, index))
|
||||
row_base = arith.addi(
|
||||
arith.divui(lane_id, c(4, index)), arith.muli(warp_idx, c(16, index))
|
||||
)
|
||||
|
||||
for row_group in range(0, shape[0], 64):
|
||||
for row_subgroup in (0, 8):
|
||||
row = arith.addi(row_base, c(row_group + row_subgroup, index))
|
||||
yield (row,)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@ -660,6 +674,31 @@ class FragmentedArray:
|
||||
vecs = [vector.load(vec_ty, ref, vec_idx) for vec_idx in layout.thread_idxs(shape)]
|
||||
return cls(_registers=np.array(vecs), _layout=layout, _is_signed=is_signed)
|
||||
|
||||
@classmethod
def load_wgmma_row(
    cls,
    ref: ir.Value,
    *,
    is_signed: bool | None = None,
):
  """Loads a 1D memref into an array laid out as `WGMMARowFragLayout`.

  Args:
    ref: the memref to read from. It must be 1D and its length must be a
      multiple of 64.
    is_signed: forwarded unchanged to the constructed array.

  Returns:
    A new instance of `cls` whose registers are reshaped to `(-1, 2)`.

  Raises:
    TypeError: if `ref` does not have a memref type.
    ValueError: if the memref is not 1D, or its length is not a multiple of
      64.
  """
  if not ir.MemRefType.isinstance(ref.type):
    raise TypeError(ref.type)

  shape = tuple(ir.MemRefType(ref.type).shape)
  if len(shape) != 1:
    raise ValueError("WGMMARowFragLayout requires a 1D shape")
  if shape[0] % 64 != 0:
    raise ValueError(
        "WGMMARowFragLayout requires shape[0] to be a multiple of 64"
    )

  layout = WGMMARowFragLayout()
  loaded = []
  # One scalar load per row index that the layout hands out.
  for (row,) in layout.thread_idxs(shape):
    loaded.append(memref.load(ref, [row]))
  return cls(
      _registers=np.array(loaded).reshape(-1, 2),
      _layout=layout,
      _is_signed=is_signed,
  )
|
||||
|
||||
|
||||
@classmethod
|
||||
def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None):
|
||||
layout = layout or WGSplatFragLayout(shape)
|
||||
@ -1743,6 +1782,8 @@ class FragmentedArray:
|
||||
)
|
||||
|
||||
match self.layout:
|
||||
case WGMMARowFragLayout():
|
||||
self._store_untiled_wgmma_row(ref)
|
||||
case WGSplatFragLayout():
|
||||
vs_unsupported()
|
||||
self._store_untiled_splat(ref)
|
||||
@ -1789,6 +1830,23 @@ class FragmentedArray:
|
||||
for idx, reg in zip(idxs, self.registers.flat):
|
||||
vector.store(reg, ref_, idx)
|
||||
|
||||
def _store_untiled_wgmma_row(self, ref: ir.Value):
  """Writes this array, which must use the WGMMA row layout, into `ref`.

  Only threads whose index is a multiple of 4 perform the stores.
  """
  assert self.layout == WGMMA_ROW_LAYOUT
  index = ir.IndexType.get()
  tid = arith.index_cast(index, mgpu.thread_idx())

  # Operands are built in the same order as they are consumed so the emitted
  # ops keep their relative order.
  four = c(4, index)
  tid_mod_4 = arith.remui(tid, four)
  zero = c(0, index)
  is_first = arith.cmpi(arith.CmpIPredicate.eq, tid_mod_4, zero)

  # Consecutive groups of 4 threads hold the same value in this layout,
  # therefore we only need to transfer data from one of them.
  with utils.when(is_first):
    row_idxs = self.layout.thread_idxs(self.shape)
    values = self.registers.flatten()
    for (idx,), value in zip(row_idxs, values):
      memref.store(value, ref, [idx])
|
||||
|
||||
def _store_untiled_tiled(self, ref: ir.Value, *, vector_store: bool = True):
|
||||
"""Stores an array with a tiled layout. Not optimized at the moment."""
|
||||
if utils.bitwidth(self.mlir_dtype) < 8:
|
||||
|
@ -383,89 +383,6 @@ def _infer_wgmma_op_layout(wgmma_op: mgpu.WGMMAOp) -> OptionalLayouts:
|
||||
|
||||
return [layout], [layout]
|
||||
|
||||
@dataclasses.dataclass
class WGMMATransforms:
  """Swizzle, tiling and transpose choices for the operands of a WGMMA op."""

  # Swizzling mode shared by both operands.
  swizzle: mgpu.SwizzlingMode
  # Tile shape and transpose flag for operand `a`.
  a_tile: tuple[int, ...]
  a_transpose: bool
  # Tile shape and transpose flag for operand `b`.
  b_tile: tuple[int, ...]
  b_transpose: bool
|
||||
|
||||
|
||||
def infer_wgmma_transforms(wgmma_op: mgpu.WGMMAOp) -> WGMMATransforms:
  """Picks the largest swizzle whose tile width divides `k` for a WGMMA op.

  `k` is `a.shape[0]` when `transpose_a` is set and `a.shape[1]` otherwise.
  For each swizzle (128, 64, then 32 bytes) the tile width is
  `s = swizzle * 8 // bitwidth`, where `bitwidth` is the element width of `a`;
  the first swizzle for which `k % s == 0` wins.

  Args:
    wgmma_op: the WGMMA operation whose operand transforms are inferred.

  Returns:
    The swizzle together with the tile shapes and transpose flags for `a` and
    `b`.

  Raises:
    ValueError: if `k` is not a multiple of the tile width of any swizzle.
  """
  a_type = cast(ir.ShapedType, wgmma_op.a.type)
  a_shape = a_type.shape
  k = a_shape[0] if wgmma_op.transpose_a else a_shape[1]
  bitwidth = a_type.element_type.width

  # Try tiling with all swizzling modes starting from the largest one.
  for swizzle in [
      mgpu.SwizzlingMode.k128ByteSwizzle,
      mgpu.SwizzlingMode.k64ByteSwizzle,
      mgpu.SwizzlingMode.k32ByteSwizzle,
  ]:
    s = swizzle * 8 // bitwidth
    if k % s == 0:
      return WGMMATransforms(
          swizzle=swizzle,
          a_tile=(s, 64) if wgmma_op.transpose_a else (64, s),
          a_transpose=wgmma_op.transpose_a,
          b_tile=(s, s),
          b_transpose=wgmma_op.transpose_b,
      )
  # Every segment that interpolates a value needs its own f-prefix; the second
  # one previously lacked it and printed the literal text "{k}".
  raise ValueError(
      "Could not infer layouts for memref feeding into WGMMA. The "
      f"non-contracting dimension ({k}) must be a multiple of "
      "s = swizzle * (8 / bitwidth) where swizzle is a valid swizzle "
      f"(32, 64, or 128) and bitwidth ({bitwidth}) is the element size of "
      "`a` and `b`."
  )
|
||||
|
||||
def _layout_for_memref_view(view_op: memref.ViewOp) -> ir.Attribute | None:
  """Returns the layout a `memref.view` result needs for its WGMMA user.

  The users of the view are scanned in order. A `memref.CastOp` user ends the
  scan with `None`; the first `mgpu.WGMMAOp` user ends the scan and determines
  the layout; any other user must be an async load/store or a vector
  load/store.

  Returns:
    A layout attribute made of a tile transform, an optional transpose
    transform and a swizzle transform, or `None` when the view is already cast
    or has no WGMMA user.

  Raises:
    NotImplementedError: if an unsupported user is encountered.
  """
  wgmma_use = None
  for use in cast(ir.OpResult, view_op.result).uses:
    user = use.owner
    if isinstance(user, memref.CastOp):
      # This memref is already cast, so we don't need to do anything.
      return None
    if isinstance(user, mgpu.WGMMAOp):
      if wgmma_use is not None:
        raise NotImplementedError(f"Multiple WGMMA consumers of {view_op}.")
      wgmma_use = use
      break
    is_supported = (
        isinstance(user, mgpu.AsyncLoadOp)
        or isinstance(user, mgpu.AsyncStoreOp)
        or isinstance(user, vector.LoadOp)
        or isinstance(user, vector.StoreOp)
    )
    if not is_supported:
      raise NotImplementedError(f"Unsupported user {user} of {view_op}.")

  if wgmma_use is None:
    # This memref is not used by a WGMMA operation, so we don't need to do
    # anything.
    return None

  transforms = infer_wgmma_transforms(wgmma_use.owner)
  # Operand number 1 selects the `a` settings; anything else selects `b`.
  if wgmma_use.operand_number == 1:
    tile, transpose = transforms.a_tile, transforms.a_transpose
  else:
    tile, transpose = transforms.b_tile, transforms.b_transpose

  layout_transforms = [mgpu.TileTransformAttr.get(tile)]
  if transpose:
    layout_transforms.append(mgpu.TransposeTransformAttr.get([1, 0, 2, 3]))
  layout_transforms.append(mgpu.SwizzleTransformAttr.get(transforms.swizzle))

  return mgpu.LayoutAttr.get(2, layout_transforms)
|
||||
|
||||
|
||||
def _earliest_use(regions: list[ir.Region], uses: Sequence[ir.OpOperand]) -> ir.OpView:
|
||||
owners = [use.owner for use in uses]
|
||||
@ -607,11 +524,3 @@ def infer_layout(module: ir.Module):
|
||||
|
||||
for op in module.body:
|
||||
traverse_op(op, set_default_layout)
|
||||
|
||||
def infer_memref_layouts_and_insert_casts(op: ir.OpView):
|
||||
if op.name == "memref.view":
|
||||
if layout := _layout_for_memref_view(op):
|
||||
_insert_memref_layout_cast(layout, op)
|
||||
|
||||
for op in module.body:
|
||||
traverse_op(op, infer_memref_layouts_and_insert_casts)
|
||||
|
@ -26,6 +26,9 @@ from typing import cast
|
||||
from jax._src.lib import mosaic_gpu_dialect as mgpu
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import arith
|
||||
from jax._src.lib.mlir.dialects import builtin
|
||||
from jax._src.lib.mlir.dialects import gpu
|
||||
from jax._src.lib.mlir.dialects import memref
|
||||
from jax._src.lib.mlir.dialects import vector
|
||||
|
||||
from . import fragmented_array as fa
|
||||
@ -169,7 +172,6 @@ def _infer_vector_load_store_transforms(
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2.
|
||||
SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None)
|
||||
|
||||
@ -196,6 +198,60 @@ def _infer_slice_smem_transforms(op: SliceSMEMOp) -> OptionalTransforms:
|
||||
return None if transforms is None else ([], [transforms])
|
||||
|
||||
|
||||
# TODO(bchetioui,apaszke): this empty rule is necessary while Mosaic doesn't use
|
||||
# the dialect in all cases.
|
||||
# The rule is necessary in order to handle the lowering of `utils.memref_ptr`
|
||||
# which is used in `_construct_smem_reftree`.
|
||||
@partial(_add_transform_inference_rule, builtin.UnrealizedConversionCastOp)
def _infer_unrealized_conversion_cast_transforms(
    _: builtin.UnrealizedConversionCastOp,
) -> OptionalTransforms:
  """Placeholder rule: never assigns transforms to the cast op."""
  no_transforms = None
  return no_transforms
|
||||
|
||||
|
||||
@partial(_add_transform_inference_rule, memref.ViewOp)
def _infer_memref_view_transforms(op: memref.ViewOp) -> OptionalTransforms:
  """Infers the out transforms of a `memref.view` from what its users expect.

  Each use of the view's result is asked for the in-transforms of the
  corresponding operand. The first non-`None` answer is adopted, and every
  later non-`None` answer must equal it.

  Returns:
    `None` if no user specifies transforms, otherwise `([], [transforms])`.

  Raises:
    NotImplementedError: if the view's source is not produced by a
      `gpu.DynamicSharedMemoryOp`, or if the source already carries
      transforms.
    ValueError: if two users expect different transforms.
  """
  if not isinstance(op.source.owner.opview, gpu.DynamicSharedMemoryOp):
    raise NotImplementedError(
        "Memref view transforms are only inferred when the op is a direct user "
        f"of a DynamicSharedMemoryOp but got {op}."
    )
  if inference_utils.value_transforms(op.source) is not None:
    raise NotImplementedError(
        "memref view with in_transforms aren't yet supported"
    )

  agreed = None
  for use in cast(ir.OpResult, op.result).uses:
    consumer = use.owner
    operand = consumer.operands[use.operand_number]
    requested = inference_utils.in_transforms_for_operand(consumer, operand)
    if requested is None:
      # This user has no opinion (yet).
      continue
    if agreed is None:
      agreed = requested
    elif agreed != requested:
      raise ValueError(
          f"Conflicting transforms for {operand} in {op}: "
          f"{agreed} != {requested}."
      )

  # TODO(bchetioui): do we actually need to assign a transform to the input of
  # the view op? Presumably, it'll only be used to access scratch memory.
  if agreed is None:
    return None
  return [], [agreed]
|
||||
|
||||
|
||||
# TODO(bchetioui,apaszke): this empty rule is necessary while Mosaic doesn't use
|
||||
# the dialect in all cases.
|
||||
@partial(_add_transform_inference_rule, gpu.DynamicSharedMemoryOp)
def _infer_dynamic_smem_transforms(
    _: gpu.DynamicSharedMemoryOp,
) -> OptionalTransforms:
  """No-op transform inference rule for `gpu.DynamicSharedMemoryOp`.

  The op argument is ignored.

  Returns:
    Always `None`, i.e. no transforms are inferred for the op.
  """
  return None
|
||||
|
||||
|
||||
def _should_have_transforms(op: ir.OpView) -> bool:
|
||||
"""Returns 'True' if the operation should be assigned in/out transforms."""
|
||||
return any(
|
||||
@ -218,7 +274,6 @@ def infer_transforms(module: ir.Module):
|
||||
specified. We error out if two distinct sets of transforms are competing to
|
||||
annotate the same memref.
|
||||
"""
|
||||
|
||||
def inference_step(op: ir.Operation):
|
||||
if not _should_have_transforms(op):
|
||||
return
|
||||
|
@ -24,7 +24,7 @@ import jax._src.xla_bridge as xb
|
||||
|
||||
# cuda_plugin_extension is located inside jaxlib. `jaxlib` is for testing without
|
||||
# preinstalled jax cuda plugin packages.
|
||||
for pkg_name in ['jax_cuda12_plugin', 'jaxlib']:
|
||||
for pkg_name in ['jax_cuda12_plugin', 'jaxlib.cuda']:
|
||||
try:
|
||||
cuda_plugin_extension = importlib.import_module(
|
||||
f'{pkg_name}.cuda_plugin_extension'
|
||||
|
@ -23,7 +23,7 @@ import jax._src.xla_bridge as xb
|
||||
|
||||
# rocm_plugin_extension is located inside jaxlib. `jaxlib` is for testing without
|
||||
# preinstalled jax rocm plugin packages.
|
||||
for pkg_name in ['jax_rocm60_plugin', 'jaxlib']:
|
||||
for pkg_name in ['jax_rocm60_plugin', 'jaxlib.cuda']:
|
||||
try:
|
||||
rocm_plugin_extension = importlib.import_module(
|
||||
f'{pkg_name}.rocm_plugin_extension'
|
||||
|
81
jaxlib/BUILD
81
jaxlib/BUILD
@ -222,84 +222,3 @@ nanobind_extension(
|
||||
"@xla//third_party/python_runtime:headers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gpu_plugin_extension",
|
||||
srcs = ["gpu_plugin_extension.cc"],
|
||||
hdrs = ["gpu_plugin_extension.h"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
":kernel_nanobind_helpers",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/strings:string_view",
|
||||
"@nanobind",
|
||||
"@xla//xla:util",
|
||||
"@xla//xla/ffi/api:c_api",
|
||||
"@xla//xla/pjrt:status_casters",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_helpers",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs",
|
||||
"@xla//xla/python:py_client_gpu",
|
||||
"@xla//xla/tsl/python/lib/core:numpy",
|
||||
],
|
||||
)
|
||||
|
||||
nanobind_extension(
|
||||
name = "cuda_plugin_extension",
|
||||
srcs = ["cuda_plugin_extension.cc"],
|
||||
module_name = "cuda_plugin_extension",
|
||||
deps = [
|
||||
":gpu_plugin_extension",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@nanobind",
|
||||
"@xla//xla/pjrt:status_casters",
|
||||
"@xla//xla/tsl/cuda:cublas",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
],
|
||||
)
|
||||
|
||||
nanobind_extension(
|
||||
name = "rocm_plugin_extension",
|
||||
srcs = ["rocm_plugin_extension.cc"],
|
||||
module_name = "rocm_plugin_extension",
|
||||
deps = [
|
||||
":gpu_plugin_extension",
|
||||
"@com_google_absl//absl/log",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@local_config_rocm//rocm:hip",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@nanobind",
|
||||
],
|
||||
)
|
||||
|
||||
# CPU kernels
|
||||
|
||||
# TODO(phawkins): Remove this forwarding target.
|
||||
cc_library(
|
||||
name = "cpu_kernels",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//jaxlib/cpu:cpu_kernels",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# TODO(phawkins): Remove this forwarding target.
|
||||
cc_library(
|
||||
name = "gpu_kernels",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//jaxlib/cuda:cuda_gpu_kernels",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -657,6 +657,22 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
nanobind_extension(
|
||||
name = "cuda_plugin_extension",
|
||||
srcs = ["cuda_plugin_extension.cc"],
|
||||
module_name = "cuda_plugin_extension",
|
||||
deps = [
|
||||
"//jaxlib/gpu:gpu_plugin_extension",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@nanobind",
|
||||
"@xla//xla/pjrt:status_casters",
|
||||
"@xla//xla/tsl/cuda:cublas",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
],
|
||||
)
|
||||
|
||||
# We cannot nest select and if_cuda_is_configured so we introduce
|
||||
# a standalone py_library target.
|
||||
py_library(
|
||||
@ -664,6 +680,6 @@ py_library(
|
||||
# `if_cuda_is_configured` will default to `[]`.
|
||||
deps = if_cuda_is_configured([
|
||||
":cuda_gpu_support",
|
||||
"//jaxlib:cuda_plugin_extension",
|
||||
":cuda_plugin_extension",
|
||||
]),
|
||||
)
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#include "jaxlib/gpu_plugin_extension.h"
|
||||
#include "jaxlib/gpu/gpu_plugin_extension.h"
|
||||
#include "xla/pjrt/status_casters.h"
|
||||
|
||||
namespace nb = nanobind;
|
@ -90,3 +90,32 @@ xla_py_proto_library(
|
||||
visibility = jax_visibility("triton_proto_py_users"),
|
||||
deps = [":triton_proto"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gpu_plugin_extension",
|
||||
srcs = ["gpu_plugin_extension.cc"],
|
||||
hdrs = ["gpu_plugin_extension.h"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
"//jaxlib:kernel_nanobind_helpers",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/strings:string_view",
|
||||
"@nanobind",
|
||||
"@xla//xla:util",
|
||||
"@xla//xla/ffi/api:c_api",
|
||||
"@xla//xla/pjrt:status_casters",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_helpers",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs",
|
||||
"@xla//xla/python:py_client_gpu",
|
||||
"@xla//xla/tsl/python/lib/core:numpy",
|
||||
],
|
||||
)
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/gpu_plugin_extension.h"
|
||||
#include "jaxlib/gpu/gpu_plugin_extension.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef JAXLIB_GPU_PLUGIN_EXTENSION_H_
|
||||
#define JAXLIB_GPU_PLUGIN_EXTENSION_H_
|
||||
#ifndef JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_
|
||||
#define JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_
|
||||
|
||||
#include "nanobind/nanobind.h"
|
||||
|
||||
@ -24,4 +24,4 @@ void BuildGpuPluginExtension(nanobind::module_& m);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // JAXLIB_GPU_PLUGIN_EXTENSION_H_
|
||||
#endif // JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_
|
@ -143,7 +143,7 @@ mlir::FailureOr<mlir::OpPassManager> GetPassPipeline(
|
||||
mlir::memref::registerMemRefPasses();
|
||||
mlir::registerConvertToLLVMPass();
|
||||
mlir::registerGPUPasses();
|
||||
mlir::registerGpuLaunchSinkIndexComputations();
|
||||
mlir::registerGpuLaunchSinkIndexComputationsPass();
|
||||
mosaic::gpu::registerGpuLaunchLoweringPass();
|
||||
mosaic::gpu::registerConvertGpuToLLVMPass();
|
||||
mosaic::gpu::registerByvalInsertionPass();
|
||||
|
@ -555,11 +555,25 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
nanobind_extension(
|
||||
name = "rocm_plugin_extension",
|
||||
srcs = ["rocm_plugin_extension.cc"],
|
||||
module_name = "rocm_plugin_extension",
|
||||
deps = [
|
||||
"//jaxlib/gpu:gpu_plugin_extension",
|
||||
"@com_google_absl//absl/log",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@local_config_rocm//rocm:hip",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@nanobind",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "gpu_only_test_deps",
|
||||
# `if_rocm_is_configured` will default to `[]`.
|
||||
deps = if_rocm_is_configured([
|
||||
":rocm_gpu_support",
|
||||
"//jaxlib:rocm_plugin_extension",
|
||||
":rocm_plugin_extension",
|
||||
]),
|
||||
)
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "rocm/include/hip/hip_runtime.h"
|
||||
#include "jaxlib/gpu_plugin_extension.h"
|
||||
#include "jaxlib/gpu/gpu_plugin_extension.h"
|
||||
|
||||
namespace nb = nanobind;
|
||||
|
@ -143,16 +143,16 @@ py_binary(
|
||||
data = [
|
||||
"LICENSE.txt",
|
||||
] + if_cuda([
|
||||
"//jaxlib/mosaic/gpu:mosaic_gpu",
|
||||
"//jaxlib:cuda_plugin_extension",
|
||||
"//jaxlib:version",
|
||||
"//jaxlib/mosaic/gpu:mosaic_gpu",
|
||||
"//jaxlib/cuda:cuda_plugin_extension",
|
||||
"//jaxlib/cuda:cuda_gpu_support",
|
||||
"//jax_plugins/cuda:plugin_pyproject.toml",
|
||||
"//jax_plugins/cuda:plugin_setup.py",
|
||||
"@local_config_cuda//cuda:cuda-nvvm",
|
||||
]) + if_rocm([
|
||||
"//jaxlib:rocm_plugin_extension",
|
||||
"//jaxlib:version",
|
||||
"//jaxlib/rocm:rocm_plugin_extension",
|
||||
"//jaxlib/rocm:rocm_gpu_support",
|
||||
"//jax_plugins/rocm:plugin_pyproject.toml",
|
||||
"//jax_plugins/rocm:plugin_setup.py",
|
||||
|
@ -110,7 +110,7 @@ def prepare_wheel_cuda(
|
||||
f"__main__/jaxlib/cuda/_triton.{pyext}",
|
||||
f"__main__/jaxlib/cuda/_hybrid.{pyext}",
|
||||
f"__main__/jaxlib/cuda/_versions.{pyext}",
|
||||
f"__main__/jaxlib/cuda_plugin_extension.{pyext}",
|
||||
f"__main__/jaxlib/cuda/cuda_plugin_extension.{pyext}",
|
||||
f"__main__/jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}",
|
||||
"__main__/jaxlib/mosaic/gpu/libmosaic_gpu_runtime.so",
|
||||
"__main__/jaxlib/version.py",
|
||||
@ -148,7 +148,7 @@ def prepare_wheel_rocm(
|
||||
f"__main__/jaxlib/rocm/_hybrid.{pyext}",
|
||||
f"__main__/jaxlib/rocm/_rnn.{pyext}",
|
||||
f"__main__/jaxlib/rocm/_triton.{pyext}",
|
||||
f"__main__/jaxlib/rocm_plugin_extension.{pyext}",
|
||||
f"__main__/jaxlib/rocm/rocm_plugin_extension.{pyext}",
|
||||
"__main__/jaxlib/version.py",
|
||||
],
|
||||
)
|
||||
|
@ -29,16 +29,23 @@ def call_kernel(
|
||||
kernel,
|
||||
grid: tuple[int, int],
|
||||
transpose_grid: bool,
|
||||
*args
|
||||
key: jax.Array,
|
||||
total_size: tuple[int, int],
|
||||
block_size: tuple[int, int],
|
||||
tile_size: tuple[int, int],
|
||||
):
|
||||
"""Calls a kernel over a grid and concatenates results to a single array."""
|
||||
if transpose_grid:
|
||||
grid = (grid[1], grid[0])
|
||||
m, n = grid
|
||||
return jnp.concatenate([
|
||||
jnp.concatenate([
|
||||
kernel((i, j), *args) for j in range(n)], axis=1)
|
||||
for i in range(m)], axis=0)
|
||||
samples = jnp.concatenate([
|
||||
jnp.concatenate([
|
||||
kernel((i, j), key, total_size, block_size, tile_size)
|
||||
for j in range(n)], axis=1)
|
||||
for i in range(m)], axis=0)
|
||||
# Slice out the padding.
|
||||
samples = samples[:total_size[0], :total_size[1]]
|
||||
return samples
|
||||
|
||||
|
||||
def call_kernel_3d(
|
||||
@ -73,10 +80,10 @@ def uniform_kernel(block_index, key, total_size, block_size, tile_size):
|
||||
block_size=block_size,
|
||||
tile_size=tile_size)
|
||||
return blocked_sampler.sample_block(jax.random.uniform,
|
||||
keys,
|
||||
block_size=block_size,
|
||||
tile_size=tile_size,
|
||||
minval=0.0, maxval=1.0)
|
||||
keys,
|
||||
block_size=block_size,
|
||||
tile_size=tile_size,
|
||||
minval=0.0, maxval=1.0)
|
||||
|
||||
|
||||
class BlockedSamplerTest(jtu.JaxTestCase):
|
||||
@ -94,16 +101,25 @@ class BlockedSamplerTest(jtu.JaxTestCase):
|
||||
dict(testcase_name='16x256_vs_32x128', total_size=(32, 256),
|
||||
block_size_a=(16, 256), block_size_b=(32, 128),
|
||||
tile_size=(8, 128), transpose_grid=False),
|
||||
dict(testcase_name='128x128_vs_128x256_padding',
|
||||
total_size=(256, 128), block_size_a=(128, 128),
|
||||
block_size_b=(128, 256), tile_size=(128, 128), transpose_grid=False),
|
||||
dict(testcase_name='128x128_vs_128x256_padding2',
|
||||
total_size=(257, 129), block_size_a=(128, 128),
|
||||
block_size_b=(128, 256), tile_size=(128, 128), transpose_grid=False),
|
||||
)
|
||||
def test_block_shape_invariance(self, total_size, block_size_a,
|
||||
block_size_b, tile_size, transpose_grid):
|
||||
global_key = jax.random.key(0)
|
||||
grid_a = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_a))
|
||||
ceil_div = lambda x, y: (x + y - 1) // y
|
||||
grid_a = tuple(ceil_div(_tot, _blk)
|
||||
for _tot, _blk in zip(total_size, block_size_a))
|
||||
result_a = call_kernel(
|
||||
uniform_kernel, grid_a, transpose_grid, global_key,
|
||||
total_size, block_size_a, tile_size)
|
||||
|
||||
grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b))
|
||||
grid_b = tuple(ceil_div(_tot, _blk)
|
||||
for _tot, _blk in zip(total_size, block_size_b))
|
||||
result_b = call_kernel(
|
||||
uniform_kernel, grid_b, transpose_grid, global_key,
|
||||
total_size, block_size_b, tile_size)
|
||||
|
@ -15,9 +15,9 @@
|
||||
import contextlib
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import shlex
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
@ -78,6 +78,31 @@ def capture_jax_logs():
|
||||
logger.removeHandler(handler)
|
||||
|
||||
|
||||
# Saves and runs script from the file in order to fix the problem with
|
||||
# `tsl::Env::Default()->GetExecutablePath()` not working properly with
|
||||
# command flag.
|
||||
def _run(program, env_var=None):
  """Runs `program` as a script in a fresh Python subprocess.

  The source is written to a temporary `.py` file in the current working
  directory and executed with `sys.executable`, instead of being passed to the
  interpreter through `-c`.

  Args:
    program: Python source code. Leading whitespace is stripped from every
      line, so the program must not depend on indentation.
    env_var: optional mapping of environment variables for the child process.
      They are applied on top of the current environment and take precedence
      over it. The mapping is not modified.

  Returns:
    An object exposing the child's captured `stdout` and `stderr` as text.
  """
  # Strip the leading whitespace from the program script. `\s+` also matches
  # newlines, so blank lines are dropped and no further dedenting is needed.
  program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)

  python = sys.executable
  assert "python" in python

  # Start from a copy of the parent environment and let the requested
  # variables win. Updating the caller's dict with `os.environ` would both
  # mutate the argument and discard any override already set in the parent.
  env = dict(os.environ)
  if env_var:
    env.update(env_var)

  with tempfile.NamedTemporaryFile(
      mode="w+", encoding="utf-8", suffix=".py", dir=os.getcwd()
  ) as f:
    f.write(program)
    # The child reads the file by name, so the contents must be on disk first.
    f.flush()
    p = subprocess.run(
        [python, f.name], env=env, capture_output=True, text=True
    )

  # Anonymous holder with just the two attributes the tests read.
  return type("", (object,), {"stdout": p.stdout, "stderr": p.stderr})
|
||||
|
||||
|
||||
class LoggingTest(jtu.JaxTestCase):
|
||||
|
||||
@unittest.skipIf(platform.system() == "Windows",
|
||||
@ -90,36 +115,25 @@ class LoggingTest(jtu.JaxTestCase):
|
||||
if sys.executable is None:
|
||||
raise self.skipTest("test requires access to python binary")
|
||||
|
||||
# Save script in file to fix the problem with
|
||||
# `tsl::Env::Default()->GetExecutablePath()` not working properly with
|
||||
# command flag.
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w+", encoding="utf-8", suffix=".py"
|
||||
) as f:
|
||||
f.write(textwrap.dedent("""
|
||||
o = _run("""
|
||||
import jax
|
||||
jax.device_count()
|
||||
f = jax.jit(lambda x: x + 1)
|
||||
f(1)
|
||||
f(2)
|
||||
jax.numpy.add(1, 1)
|
||||
"""))
|
||||
python = sys.executable
|
||||
assert "python" in python
|
||||
# Make sure C++ logging is at default level for the test process.
|
||||
proc = subprocess.run([python, f.name], capture_output=True)
|
||||
""")
|
||||
|
||||
lines = proc.stdout.split(b"\n")
|
||||
lines.extend(proc.stderr.split(b"\n"))
|
||||
allowlist = [
|
||||
b"",
|
||||
(
|
||||
b"An NVIDIA GPU may be present on this machine, but a"
|
||||
b" CUDA-enabled jaxlib is not installed. Falling back to cpu."
|
||||
),
|
||||
]
|
||||
lines = [l for l in lines if l not in allowlist]
|
||||
self.assertEmpty(lines)
|
||||
lines = o.stdout.split("\n")
|
||||
lines.extend(o.stderr.split("\n"))
|
||||
allowlist = [
|
||||
(
|
||||
"An NVIDIA GPU may be present on this machine, but a"
|
||||
" CUDA-enabled jaxlib is not installed. Falling back to cpu."
|
||||
),
|
||||
]
|
||||
lines = [l for l in lines if l in allowlist]
|
||||
self.assertEmpty(lines)
|
||||
|
||||
def test_debug_logging(self):
|
||||
# Warmup so we don't get "No GPU/TPU" warning later.
|
||||
@ -164,19 +178,12 @@ class LoggingTest(jtu.JaxTestCase):
|
||||
if sys.executable is None:
|
||||
raise self.skipTest("test requires access to python binary")
|
||||
|
||||
program = """
|
||||
import jax # this prints INFO logging from backend imports
|
||||
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
|
||||
"""
|
||||
o = _run("""
|
||||
import jax # this prints INFO logging from backend imports
|
||||
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
|
||||
""", { "JAX_LOGGING_LEVEL": "INFO" })
|
||||
|
||||
# strip the leading whitespace from the program script
|
||||
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
|
||||
|
||||
# test INFO
|
||||
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=INFO {sys.executable} -c"
|
||||
f" '{program}'")
|
||||
p = subprocess.run(cmd, capture_output=True, text=True)
|
||||
log_output = p.stderr
|
||||
log_output = o.stderr
|
||||
info_lines = log_output.split("\n")
|
||||
self.assertGreater(len(info_lines), 0)
|
||||
self.assertIn("INFO", log_output)
|
||||
@ -194,22 +201,14 @@ class LoggingTest(jtu.JaxTestCase):
|
||||
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
|
||||
"""
|
||||
|
||||
# strip the leading whitespace from the program script
|
||||
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
|
||||
o = _run(program, { "JAX_LOGGING_LEVEL": "DEBUG" })
|
||||
|
||||
# test DEBUG
|
||||
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c"
|
||||
f" '{program}'")
|
||||
p = subprocess.run(cmd, capture_output=True, text=True)
|
||||
log_output = p.stderr
|
||||
log_output = o.stderr
|
||||
self.assertIn("INFO", log_output)
|
||||
self.assertIn("DEBUG", log_output)
|
||||
|
||||
# test JAX_DEBUG_MODULES
|
||||
cmd = shlex.split(f"env JAX_DEBUG_LOG_MODULES=jax {sys.executable} -c"
|
||||
f" '{program}'")
|
||||
p = subprocess.run(cmd, capture_output=True, text=True)
|
||||
log_output = p.stderr
|
||||
o = _run(program, { "JAX_DEBUG_LOG_MODULES": "jax" })
|
||||
log_output = o.stderr
|
||||
self.assertIn("DEBUG", log_output)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
@ -220,22 +219,15 @@ class LoggingTest(jtu.JaxTestCase):
|
||||
raise self.skipTest("test requires access to python binary")
|
||||
|
||||
_separator = "---------------------------"
|
||||
program = f"""
|
||||
o = _run(f"""
|
||||
import sys
|
||||
import jax # this prints INFO logging from backend imports
|
||||
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
|
||||
jax.config.update("jax_logging_level", None)
|
||||
sys.stderr.write("{_separator}")
|
||||
jax.jit(lambda x: x)(1) # should not log anything now
|
||||
"""
|
||||
|
||||
# strip the leading whitespace from the program script
|
||||
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
|
||||
|
||||
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c"
|
||||
f" '{program}'")
|
||||
p = subprocess.run(cmd, capture_output=True, text=True)
|
||||
log_output = p.stderr
|
||||
""", {"JAX_LOGGING_LEVEL": "DEBUG"})
|
||||
log_output = o.stderr
|
||||
m = re.search(_separator, log_output)
|
||||
self.assertTrue(m is not None)
|
||||
log_output_verbose = log_output[:m.start()]
|
||||
@ -252,19 +244,13 @@ class LoggingTest(jtu.JaxTestCase):
|
||||
if sys.executable is None:
|
||||
raise self.skipTest("test requires access to python binary")
|
||||
|
||||
program = """
|
||||
o = _run("""
|
||||
import jax # this prints INFO logging from backend imports
|
||||
jax.config.update("jax_debug_log_modules", "jax._src.compiler,jax._src.dispatch")
|
||||
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
|
||||
"""
|
||||
""", { "JAX_LOGGING_LEVEL": "DEBUG" })
|
||||
|
||||
# strip the leading whitespace from the program script
|
||||
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
|
||||
|
||||
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c"
|
||||
f" '{program}'")
|
||||
p = subprocess.run(cmd, capture_output=True, text=True)
|
||||
log_output = p.stderr
|
||||
log_output = o.stderr
|
||||
self.assertNotEmpty(log_output)
|
||||
log_lines = log_output.strip().split("\n")
|
||||
# only one tracing line should be printed, if there's more than one
|
||||
@ -285,31 +271,19 @@ class LoggingTest(jtu.JaxTestCase):
|
||||
jax.distributed.initialize("127.0.0.1:12345", num_processes=1, process_id=0)
|
||||
"""
|
||||
|
||||
# strip the leading whitespace from the program script
|
||||
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
|
||||
o = _run(program, { "JAX_LOGGING_LEVEL": "DEBUG" })
|
||||
self.assertIn("Initializing CoordinationService", o.stderr)
|
||||
|
||||
# verbose logging: DEBUG, VERBOSE
|
||||
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c"
|
||||
f" '{program}'")
|
||||
p = subprocess.run(cmd, capture_output=True, text=True)
|
||||
self.assertIn("Initializing CoordinationService", p.stderr)
|
||||
|
||||
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=INFO {sys.executable} -c"
|
||||
f" '{program}'")
|
||||
p = subprocess.run(cmd, capture_output=True, text=True)
|
||||
self.assertIn("Initializing CoordinationService", p.stderr)
|
||||
o = _run(program, { "JAX_LOGGING_LEVEL": "INFO" })
|
||||
self.assertIn("Initializing CoordinationService", o.stderr)
|
||||
|
||||
# verbose logging: WARNING, None
|
||||
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=WARNING {sys.executable} -c"
|
||||
f" '{program}'")
|
||||
p = subprocess.run(cmd, capture_output=True, text=True)
|
||||
self.assertNotIn("Initializing CoordinationService", p.stderr)
|
||||
o = _run(program, { "JAX_LOGGING_LEVEL": "WARNING" })
|
||||
self.assertNotIn("Initializing CoordinationService", o.stderr)
|
||||
|
||||
cmd = shlex.split(f"{sys.executable} -c"
|
||||
f" '{program}'")
|
||||
p = subprocess.run(cmd, capture_output=True, text=True)
|
||||
o = _run(program)
|
||||
if int(_default_TF_CPP_MIN_LOG_LEVEL) >= 1:
|
||||
self.assertNotIn("Initializing CoordinationService", p.stderr)
|
||||
self.assertNotIn("Initializing CoordinationService", o.stderr)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -1664,31 +1664,30 @@ class StreamAnnotationTest(jtu.JaxTestCase):
|
||||
def test_stream_annotation_inside_shmap(self):
|
||||
if not jtu.test_device_matches(["gpu"]):
|
||||
self.skipTest("Stream annotation is only supported on GPU.")
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
np_inp = np.ones((8, 8))
|
||||
mesh = jtu.create_mesh((2,), ('x',))
|
||||
s = NamedSharding(mesh, P('x'))
|
||||
np_inp = np.ones((8,))
|
||||
arr1 = jax.device_put(np_inp, s)
|
||||
arr2 = jax.device_put(np_inp, s)
|
||||
|
||||
@compute_on('gpu_stream:1')
|
||||
@jax.jit
|
||||
def g(x, y):
|
||||
return x @ y
|
||||
return x * y
|
||||
|
||||
@compute_on('gpu_stream:2')
|
||||
@jax.jit
|
||||
def h(x, y):
|
||||
return x @ y
|
||||
return x * y
|
||||
|
||||
def f(x, y):
|
||||
z = g(x, y)
|
||||
w = h(3 * x, 2 * y)
|
||||
return z + w
|
||||
|
||||
out = jax.jit(shard_map(f, mesh=mesh,
|
||||
in_specs=(P('x', 'y'), P('x', 'y')),
|
||||
out_specs=P('x', 'y')))(arr1, arr2)
|
||||
self.assertArraysEqual(out, arr1 * 28)
|
||||
out = jax.jit(shard_map(f, mesh=mesh, in_specs=(P('x'), P('x')),
|
||||
out_specs=P('x')))(arr1, arr2)
|
||||
self.assertArraysEqual(out, arr1 * 7)
|
||||
|
||||
|
||||
class ActivationOffloadingTest(jtu.JaxTestCase):
|
||||
|
@ -55,6 +55,7 @@ else:
|
||||
from jax.experimental.mosaic.gpu import launch_context
|
||||
from jax.experimental.mosaic.gpu import utils as utils
|
||||
from jax.experimental.mosaic.gpu import profiler
|
||||
from jax.experimental.mosaic.gpu import inference_utils
|
||||
from jax.experimental.mosaic.gpu.utils import * # noqa: F403
|
||||
from jax._src.lib.mlir.dialects import gpu
|
||||
from jax._src.lib.mlir.dialects import llvm
|
||||
@ -1945,6 +1946,21 @@ class FragmentedArrayTest(TestCase):
|
||||
)(inp)
|
||||
np.testing.assert_array_equal(inp, result)
|
||||
|
||||
@parameterized.product(in_shape=((128,), (64,)))
def test_wgmma_row_load_store_with_layout(self, in_shape):
  """Round-trips a 1D array through `load_wgmma_row` / `store_untiled`."""

  def kernel(ctx, *args):
    gmem_in, gmem_out, smem = args
    smem_in, smem_out = smem
    # GMEM -> SMEM, load into a fragmented array, store back, SMEM -> GMEM.
    copy(gmem_in, smem_in)
    row = mgpu.FragmentedArray.load_wgmma_row(smem_in)
    row.store_untiled(smem_out)
    copy(smem_out, gmem_out)

  # The same array serves as the input, the output spec and both scratch specs.
  data = self.prng.uniform(-1, 1, in_shape).astype(jnp.float32)
  launch = mgpu.as_gpu_kernel(
      kernel, (1, 1, 1), (128, 1, 1), (data,), data, [data, data],
  )
  result = launch(data)
  np.testing.assert_array_equal(data, result)
|
||||
|
||||
def test_warp_tree_reduce(self):
|
||||
def kernel(ctx, out, *_):
|
||||
del ctx
|
||||
@ -2405,25 +2421,21 @@ class Swizzle:
|
||||
return mgpu_dialect.SwizzleTransformAttr.get(self.swizzle)
|
||||
|
||||
|
||||
def memref_with_transforms(
|
||||
mem_ref: ir.Value,
|
||||
transforms: Sequence[Tile | Transpose | Swizzle],
|
||||
) -> ir.Value:
|
||||
"""Casts the memref to one that has a layout with the given transforms."""
|
||||
mem_ref_type = ir.MemRefType(mem_ref.type)
|
||||
def set_in_transforms(
|
||||
op: ir.OpView, transforms: Sequence[Sequence[Tile | Transpose | Swizzle]],
|
||||
) -> None:
|
||||
"""Annotates an op with in_transforms."""
|
||||
if not transforms:
|
||||
return
|
||||
|
||||
transform_attr = [t.attr() for t in transforms]
|
||||
if not transform_attr:
|
||||
return mem_ref
|
||||
in_transforms = []
|
||||
smem_refs = filter(inference_utils.is_transformable_smem_memref, op.operands) # pylint: disable=undefined-variable
|
||||
for _, result_transforms in jax.util.safe_zip(smem_refs, transforms):
|
||||
in_transforms.append(
|
||||
ir.ArrayAttr.get([t.attr() for t in result_transforms])
|
||||
)
|
||||
|
||||
layout = mgpu_dialect.LayoutAttr.get(mem_ref_type.rank, transform_attr)
|
||||
memref_new_type = ir.MemRefType.get(
|
||||
mem_ref_type.shape,
|
||||
mem_ref_type.element_type,
|
||||
layout,
|
||||
mem_ref_type.memory_space,
|
||||
)
|
||||
return memref.cast(memref_new_type, mem_ref)
|
||||
op.attributes["in_transforms"] = ir.ArrayAttr.get(in_transforms)
|
||||
|
||||
|
||||
class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
@ -2556,7 +2568,6 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
):
|
||||
del ctx
|
||||
smem_ref, tma_barrier = smem
|
||||
smem_ref = memref_with_transforms(smem_ref, test_case.transforms)
|
||||
dialect_barrier = tma_barrier.as_dialect_barrier_memref()
|
||||
|
||||
elt_type = ir.MemRefType(in_gmem_ref.type).element_type
|
||||
@ -2571,7 +2582,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
slice_indices = [arith.constant(i32, i) for i in test_case.slice_indices]
|
||||
|
||||
# GMEM -> SMEM
|
||||
mgpu_dialect.async_load(
|
||||
load_op = mgpu_dialect.AsyncLoadOp(
|
||||
source=in_gmem_ref,
|
||||
destination=smem_ref,
|
||||
barrier=dialect_barrier,
|
||||
@ -2579,6 +2590,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
slice_lengths=test_case.slice_lengths,
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
)
|
||||
set_in_transforms(load_op, [test_case.transforms])
|
||||
|
||||
parities = memref.load(tma_barrier.phases, [])
|
||||
parity, _ = tma_barrier.update_parities(parities)
|
||||
@ -2623,58 +2635,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
(x[input_slice]).reshape(test_case.shape_sliced),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def pointwise_kernel_with_tma_cases(dtype: jnp.dtype):
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TestCaseInput:
|
||||
shape: tuple[int, ...]
|
||||
transforms: tuple[Tile | Transpose | Swizzle, ...] = ()
|
||||
|
||||
result = []
|
||||
for swizzle in mgpu_dialect.SwizzlingMode:
|
||||
n = swizzle * 8 // jnp.finfo(dtype).bits
|
||||
if swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle:
|
||||
# We need at least one case with no transforms, as this is handled
|
||||
# differently.
|
||||
result.append(TestCaseInput(shape=[128, n]))
|
||||
result.extend([
|
||||
TestCaseInput(
|
||||
shape=[128, n],
|
||||
transforms=[Swizzle(swizzle)],
|
||||
),
|
||||
TestCaseInput(
|
||||
shape=[2, 3, 64, n],
|
||||
transforms=[Transpose([0, 1, 2, 3]), Swizzle(swizzle)],
|
||||
),
|
||||
TestCaseInput(
|
||||
shape=[2, 3, 64, n],
|
||||
transforms=[
|
||||
Transpose([1, 0, 2, 3]),
|
||||
Transpose([1, 0, 2, 3]),
|
||||
Swizzle(swizzle),
|
||||
],
|
||||
),
|
||||
TestCaseInput(
|
||||
shape=[2, 3, 64, n],
|
||||
transforms=[Transpose([1, 0, 2, 3]), Swizzle(swizzle)],
|
||||
),
|
||||
TestCaseInput(
|
||||
shape=[128, n],
|
||||
transforms=[Tile([64, n]), Swizzle(swizzle)],
|
||||
),
|
||||
TestCaseInput(
|
||||
shape=[2 * 64, 3 * n],
|
||||
transforms=[
|
||||
Tile([64, n]),
|
||||
Transpose([1, 0, 2, 3]),
|
||||
Swizzle(swizzle),
|
||||
],
|
||||
),
|
||||
])
|
||||
return result
|
||||
|
||||
@parameterized.parameters(pointwise_kernel_with_tma_cases(jnp.bfloat16))
|
||||
def test_pointwise_kernel_with_tma(self, test_case):
|
||||
def test_pointwise_kernel_with_tma(self):
|
||||
def add(
|
||||
ctx: launch_context.LaunchContext,
|
||||
a_gmem_ref: ir.Value,
|
||||
@ -2701,9 +2662,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
# GMEM -> SMEM
|
||||
mgpu_dialect.async_load(
|
||||
source=a_gmem_ref,
|
||||
destination=memref_with_transforms(
|
||||
a_smem_ref, test_case.transforms
|
||||
),
|
||||
destination=a_smem_ref,
|
||||
barrier=dialect_barrier,
|
||||
indices=zero_slice_indices,
|
||||
slice_lengths=shape,
|
||||
@ -2711,9 +2670,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
)
|
||||
mgpu_dialect.async_load(
|
||||
source=b_gmem_ref,
|
||||
destination=memref_with_transforms(
|
||||
b_smem_ref, test_case.transforms
|
||||
),
|
||||
destination=b_smem_ref,
|
||||
barrier=dialect_barrier,
|
||||
indices=zero_slice_indices,
|
||||
slice_lengths=shape,
|
||||
@ -2740,9 +2697,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
|
||||
# SMEM -> GMEM
|
||||
mgpu_dialect.async_store(
|
||||
source=memref_with_transforms(
|
||||
result_smem_ref, test_case.transforms
|
||||
),
|
||||
source=result_smem_ref,
|
||||
destination=result_gmem_ref,
|
||||
indices=zero_slice_indices,
|
||||
slice_lengths=shape,
|
||||
@ -2752,114 +2707,76 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
|
||||
dtype = jnp.bfloat16
|
||||
|
||||
jax_shape = jax.ShapeDtypeStruct(test_case.shape, dtype)
|
||||
spec = jax.ShapeDtypeStruct((2, 3, 4, 64), dtype)
|
||||
kernel = mgpu.as_gpu_kernel(
|
||||
add,
|
||||
grid=(1, 1, 1),
|
||||
block=(128, 1, 1),
|
||||
in_shape=(jax_shape, jax_shape),
|
||||
out_shape=jax_shape,
|
||||
in_shape=(spec, spec),
|
||||
out_shape=spec,
|
||||
smem_scratch_shape=[
|
||||
jax_shape,
|
||||
jax_shape,
|
||||
jax_shape,
|
||||
spec,
|
||||
spec,
|
||||
spec,
|
||||
core.TMABarrier(1),
|
||||
],
|
||||
thread_semantics=mgpu.ThreadSemantics.Warpgroup,
|
||||
)
|
||||
|
||||
x = self.prng.uniform(-1, 1, test_case.shape).astype(dtype)
|
||||
y = self.prng.uniform(-1, 1, test_case.shape).astype(dtype)
|
||||
x = self.prng.uniform(-1, 1, spec.shape).astype(dtype)
|
||||
y = self.prng.uniform(-1, 1, spec.shape).astype(dtype)
|
||||
|
||||
self.assertArraysEqual(jax.jit(kernel)(x, y), x + y + y)
|
||||
|
||||
|
||||
class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
|
||||
|
||||
@staticmethod
|
||||
def wgmma_kernel_with_tma_cases(abtype: jnp.dtype):
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TestCaseInput:
|
||||
shape_a: tuple[int, ...] = ()
|
||||
shape_b: tuple[int, ...] = ()
|
||||
shape_res: tuple[int, ...] = ()
|
||||
transforms_a: tuple[Tile | Transpose | Swizzle, ...] = ()
|
||||
transforms_b: tuple[Tile | Transpose | Swizzle, ...] = ()
|
||||
transpose_a: bool = False
|
||||
transpose_b: bool = False
|
||||
load_a_in_registers: bool = False
|
||||
@parameterized.named_parameters(
|
||||
(
|
||||
f"swizzle={int(swizzle)}_{transpose_lhs=}_{transpose_rhs=}_{lhs_in_registers=}",
|
||||
swizzle,
|
||||
transpose_lhs,
|
||||
transpose_rhs,
|
||||
lhs_in_registers,
|
||||
)
|
||||
for swizzle in mgpu_dialect.SwizzlingMode
|
||||
for transpose_lhs in [False, True]
|
||||
for transpose_rhs in [False, True]
|
||||
for lhs_in_registers in [False, True]
|
||||
)
|
||||
def test_wgmma_kernel_with_tma(
|
||||
self, swizzle, transpose_lhs, transpose_rhs, load_a_in_registers
|
||||
):
|
||||
if swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle:
|
||||
self.skipTest("No swizzle is not supported by wgmma")
|
||||
|
||||
result = []
|
||||
for swizzle in [
|
||||
# TODO(dasenov): Add a test for kNoSwizzle, i.e. all swizzling modes.
|
||||
mgpu_dialect.SwizzlingMode.k32ByteSwizzle,
|
||||
mgpu_dialect.SwizzlingMode.k64ByteSwizzle,
|
||||
mgpu_dialect.SwizzlingMode.k128ByteSwizzle,
|
||||
]:
|
||||
k = swizzle // np.dtype(abtype).itemsize
|
||||
groups_m = 4
|
||||
groups_n = 1
|
||||
groups_k = 1
|
||||
result.extend([
|
||||
TestCaseInput(
|
||||
shape_a=[groups_m * 64, groups_k * k],
|
||||
shape_b=[groups_k * k, groups_n * k],
|
||||
shape_res=[groups_m * 64, groups_n * k],
|
||||
),
|
||||
TestCaseInput(
|
||||
shape_a=[groups_m * 64, groups_k * k],
|
||||
shape_b=[groups_n * k, groups_k * k],
|
||||
shape_res=[groups_m * 64, groups_n * k],
|
||||
transpose_b=True,
|
||||
),
|
||||
TestCaseInput(
|
||||
shape_a=[groups_m * 64, groups_k * k],
|
||||
shape_b=[groups_k * k, groups_n * k],
|
||||
shape_res=[groups_m * 64, groups_n * k],
|
||||
transforms_a=[Tile([64, k]), Swizzle(swizzle)],
|
||||
transforms_b=[Tile([k, k]), Swizzle(swizzle)],
|
||||
),
|
||||
TestCaseInput(
|
||||
shape_a=[groups_m * 64, groups_k * k],
|
||||
shape_b=[groups_k * k, groups_n * k],
|
||||
shape_res=[groups_m * 64, groups_n * k],
|
||||
transforms_a=[Tile([64, k]), Swizzle(swizzle)],
|
||||
load_a_in_registers=True,
|
||||
),
|
||||
])
|
||||
# The below only works for 128-byte swizzling. Regardless of transposing,
|
||||
# TMA needs the size of the last dimension to be compatible with the
|
||||
# swizzle.
|
||||
if swizzle == mgpu_dialect.SwizzlingMode.k128ByteSwizzle:
|
||||
result.append(
|
||||
TestCaseInput(
|
||||
shape_a=[groups_k * k, groups_m * 64],
|
||||
shape_b=[groups_k * k, groups_n * k],
|
||||
shape_res=[groups_m * 64, groups_n * k],
|
||||
transpose_a=True,
|
||||
)
|
||||
)
|
||||
return result
|
||||
if transpose_lhs or transpose_rhs:
|
||||
self.skipTest("Transposes are not supported by transform inference yet.")
|
||||
|
||||
@parameterized.parameters(wgmma_kernel_with_tma_cases(jnp.bfloat16))
|
||||
def test_wgmma_kernel_with_tma(self, test_case):
|
||||
swizzle_elems = swizzle // np.dtype(jnp.bfloat16).itemsize
|
||||
tiling_m, tiling_n, tiling_k = 64, swizzle_elems, swizzle_elems
|
||||
|
||||
groups_m, groups_n, groups_k = 4, 1, 1
|
||||
m, n, k = groups_m * tiling_m, groups_n * tiling_n, groups_k * tiling_k
|
||||
|
||||
lhs_shape = (k, m) if transpose_lhs else (m, k)
|
||||
rhs_shape = (n, k) if transpose_rhs else (k, n)
|
||||
out_shape = (m, n)
|
||||
|
||||
def matmul(
|
||||
ctx: launch_context.LaunchContext,
|
||||
a_gmem_ref: ir.Value,
|
||||
b_gmem_ref: ir.Value,
|
||||
lhs_gmem_ref: ir.Value,
|
||||
rhs_gmem_ref: ir.Value,
|
||||
result_gmem_ref: ir.Value,
|
||||
smem: list[ir.Value],
|
||||
):
|
||||
del ctx
|
||||
a_smem_ref, b_smem_ref, result_smem_ref, tma_barrier = smem
|
||||
a_smem_ref = memref_with_transforms(a_smem_ref, test_case.transforms_a)
|
||||
b_smem_ref = memref_with_transforms(b_smem_ref, test_case.transforms_b)
|
||||
lhs_smem_ref, rhs_smem_ref, result_smem_ref, tma_barrier = smem
|
||||
dialect_barrier = tma_barrier.as_dialect_barrier_memref()
|
||||
|
||||
ab_elt_type = ir.MemRefType(a_gmem_ref.type).element_type
|
||||
bytes_a = utils.bytewidth(ab_elt_type) * math.prod(test_case.shape_a)
|
||||
bytes_b = utils.bytewidth(ab_elt_type) * math.prod(test_case.shape_b)
|
||||
operand_elt_type = ir.MemRefType(lhs_gmem_ref.type).element_type
|
||||
bytes_a = utils.bytewidth(operand_elt_type) * math.prod(lhs_shape)
|
||||
bytes_b = utils.bytewidth(operand_elt_type) * math.prod(rhs_shape)
|
||||
|
||||
mgpu_dialect.arrive_expect_tx(
|
||||
barrier=dialect_barrier,
|
||||
@ -2869,19 +2786,19 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
|
||||
zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0)
|
||||
# GMEM -> SMEM
|
||||
mgpu_dialect.async_load(
|
||||
source=a_gmem_ref,
|
||||
destination=a_smem_ref,
|
||||
source=lhs_gmem_ref,
|
||||
destination=lhs_smem_ref,
|
||||
barrier=dialect_barrier,
|
||||
indices=[zero_i32] * len(test_case.shape_a),
|
||||
slice_lengths=test_case.shape_a,
|
||||
indices=[zero_i32] * len(lhs_shape),
|
||||
slice_lengths=lhs_shape,
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
)
|
||||
mgpu_dialect.async_load(
|
||||
source=b_gmem_ref,
|
||||
destination=b_smem_ref,
|
||||
source=rhs_gmem_ref,
|
||||
destination=rhs_smem_ref,
|
||||
barrier=dialect_barrier,
|
||||
indices=[zero_i32] * len(test_case.shape_b),
|
||||
slice_lengths=test_case.shape_b,
|
||||
indices=[zero_i32] * len(rhs_shape),
|
||||
slice_lengths=rhs_shape,
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
)
|
||||
|
||||
@ -2889,29 +2806,34 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
|
||||
parity, _ = tma_barrier.update_parities(parities)
|
||||
mgpu_dialect.wait(dialect_barrier, parity)
|
||||
|
||||
# SMEM -> Registers
|
||||
a_operand = a_smem_ref
|
||||
zero_index = arith.constant(ir.IndexType.get(), 0)
|
||||
if test_case.load_a_in_registers:
|
||||
a_vector_type = ir.VectorType.get(test_case.shape_a, ab_elt_type)
|
||||
zero_vector_indices = [zero_index] * len(test_case.shape_a)
|
||||
a_operand = vector.load(a_vector_type, a_smem_ref, zero_vector_indices)
|
||||
|
||||
# Computation
|
||||
shape_result = ir.MemRefType(result_gmem_ref.type).shape
|
||||
result_elt_type = ir.MemRefType(result_gmem_ref.type).element_type
|
||||
acc_elt_type = ir.F32Type.get()
|
||||
acc_type = ir.VectorType.get(shape_result, acc_elt_type)
|
||||
zero_acc = arith.constant(
|
||||
result_elt_type, ir.FloatAttr.get(result_elt_type, 0.0)
|
||||
)
|
||||
accumulator = vector.splat(
|
||||
ir.VectorType.get(shape_result, result_elt_type), zero_acc
|
||||
result_elt_type, ir.FloatAttr.get(acc_elt_type, 0.0)
|
||||
)
|
||||
accumulator = vector.splat(acc_type, zero_acc)
|
||||
|
||||
if transpose_lhs:
|
||||
lhs_smem_ref = utils.memref_transpose(lhs_smem_ref, (1, 0))
|
||||
if transpose_rhs:
|
||||
rhs_smem_ref = utils.memref_transpose(rhs_smem_ref, (1, 0))
|
||||
|
||||
zero_index = arith.constant(ir.IndexType.get(), 0)
|
||||
if load_a_in_registers:
|
||||
# SMEM -> Registers
|
||||
lhs_ty = ir.VectorType.get(lhs_shape, operand_elt_type)
|
||||
zero_vector_indices = [zero_index] * len(lhs_shape)
|
||||
lhs_operand = vector.load(lhs_ty, lhs_smem_ref, zero_vector_indices)
|
||||
else:
|
||||
lhs_operand = lhs_smem_ref
|
||||
|
||||
result = mgpu_dialect.wgmma(
|
||||
accumulator,
|
||||
a_operand,
|
||||
b_smem_ref,
|
||||
transpose_a=test_case.transpose_a,
|
||||
transpose_b=test_case.transpose_b,
|
||||
lhs_operand,
|
||||
rhs_smem_ref,
|
||||
)
|
||||
|
||||
nvvm.wgmma_commit_group_sync_aligned()
|
||||
@ -2929,38 +2851,41 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
|
||||
)
|
||||
nvvm.cp_async_bulk_wait_group(0)
|
||||
|
||||
abtype = jnp.bfloat16
|
||||
operand_type = jnp.bfloat16
|
||||
acctype = jnp.float32
|
||||
a_jax_shape = jax.ShapeDtypeStruct(test_case.shape_a, abtype)
|
||||
b_jax_shape = jax.ShapeDtypeStruct(test_case.shape_b, abtype)
|
||||
result_jax_shape = jax.ShapeDtypeStruct(test_case.shape_res, acctype)
|
||||
lhs_jax_shape = jax.ShapeDtypeStruct(lhs_shape, operand_type)
|
||||
rhs_jax_shape = jax.ShapeDtypeStruct(rhs_shape, operand_type)
|
||||
result_jax_shape = jax.ShapeDtypeStruct(out_shape, acctype)
|
||||
kernel = mgpu.as_gpu_kernel(
|
||||
matmul,
|
||||
grid=(1, 1, 1),
|
||||
block=(128, 1, 1),
|
||||
in_shape=(a_jax_shape, b_jax_shape),
|
||||
in_shape=(lhs_jax_shape, rhs_jax_shape),
|
||||
out_shape=result_jax_shape,
|
||||
smem_scratch_shape=[
|
||||
a_jax_shape,
|
||||
b_jax_shape,
|
||||
lhs_jax_shape,
|
||||
rhs_jax_shape,
|
||||
result_jax_shape,
|
||||
core.TMABarrier(1),
|
||||
],
|
||||
thread_semantics=mgpu.ThreadSemantics.Warpgroup,
|
||||
)
|
||||
|
||||
x = self.prng.uniform(-1, 1, test_case.shape_a).astype(abtype)
|
||||
y = self.prng.uniform(-1, 1, test_case.shape_b).astype(abtype)
|
||||
prng_key = jax.random.key(1234)
|
||||
k0, k1 = jax.random.split(prng_key, 2)
|
||||
|
||||
x = jax.random.randint(k0, lhs_shape, 0, 2).astype(operand_type)
|
||||
y = jax.random.randint(k1, rhs_shape, 0, 2).astype(operand_type)
|
||||
|
||||
transpose = lambda x, t: x.T if t else x
|
||||
self.assertArraysAllClose(
|
||||
jax.jit(kernel)(x, y),
|
||||
np.matmul(
|
||||
transpose(x.reshape(test_case.shape_a), test_case.transpose_a),
|
||||
transpose(y.reshape(test_case.shape_b), test_case.transpose_b),
|
||||
transpose(x, transpose_lhs),
|
||||
transpose(y, transpose_rhs)
|
||||
),
|
||||
atol=1e-5,
|
||||
rtol=1e-5,
|
||||
atol=0,
|
||||
rtol=0,
|
||||
)
|
||||
|
||||
|
||||
|
@ -2501,7 +2501,8 @@ class SymbolicPallasTest(PallasBaseTest):
|
||||
)
|
||||
assert exported_module is not None
|
||||
self.assertIn(
|
||||
"tensor<?x?xf32>, %arg6: tensor<?x?xf32>, %arg7: tensor<?x?xf32>",
|
||||
"%arg0: tensor<?x?xf32> loc(unknown), %arg1: tensor<?x?xf32>"
|
||||
" loc(unknown), %arg2: tensor<?x?xf32>",
|
||||
str(exported_module),
|
||||
)
|
||||
x = jax.ShapeDtypeStruct((128, 1024), jax.numpy.float32)
|
||||
@ -2512,7 +2513,7 @@ class SymbolicPallasTest(PallasBaseTest):
|
||||
)
|
||||
assert exported_module is not None
|
||||
self.assertIn(
|
||||
"@sym_matmul(%arg0: tensor<128x1024xf32>, %arg1: tensor<1024x512xf32>",
|
||||
"call @sym_matmul(%arg0, %arg1)",
|
||||
str(exported_module),
|
||||
)
|
||||
|
||||
|
@ -156,5 +156,36 @@ class InterpretTest(jtu.JaxTestCase):
|
||||
lowered = jax.jit(matmul).lower(x, y).as_text(dialect="stablehlo")
|
||||
self.assertNotIn("dot_general", lowered)
|
||||
|
||||
@parameterized.parameters('nan', 'zero')
|
||||
def test_uninitialized_memory(self, uninitialized_memory):
|
||||
def kernel(o1_ref, o2_ref, o3_ref, t1_ref, t2_ref):
|
||||
o1_ref[...] = t1_ref[...]
|
||||
o2_ref[...] = t2_ref[...]
|
||||
|
||||
x, y, z = pl.pallas_call(
|
||||
kernel,
|
||||
out_shape=[
|
||||
jax.ShapeDtypeStruct((8, 128), jnp.bfloat16),
|
||||
jax.ShapeDtypeStruct((8, 128), jnp.int16),
|
||||
jax.ShapeDtypeStruct((8, 128), jnp.float32),
|
||||
],
|
||||
in_specs=[],
|
||||
scratch_shapes=[
|
||||
pltpu.VMEM((8, 128), jnp.bfloat16),
|
||||
pltpu.VMEM((8, 128), jnp.int16),
|
||||
],
|
||||
interpret=mosaic_interpret.TPUInterpretParams(
|
||||
uninitialized_memory=uninitialized_memory),
|
||||
)()
|
||||
if uninitialized_memory == 'nan':
|
||||
self.assertTrue(jnp.isnan(x).all())
|
||||
np.testing.assert_equal(np.array(y), 32767)
|
||||
self.assertTrue(jnp.isnan(z).all())
|
||||
if uninitialized_memory == 'zero':
|
||||
np.testing.assert_equal(np.array(x), 0)
|
||||
np.testing.assert_equal(np.array(y), 0)
|
||||
np.testing.assert_equal(np.array(z), 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -1680,6 +1680,47 @@ class CustomPartitionerTest(jtu.JaxTestCase):
|
||||
jit_f = jax.jit(f, in_shardings=s, out_shardings=s)
|
||||
self.assertArraysEqual(x, jit_f(x))
|
||||
|
||||
@jtu.with_mesh([('x', 4), ('y', 2)])
|
||||
def test_custom_partitioner_pytree_inputs(self):
|
||||
self.skip_if_custom_partitioning_not_supported()
|
||||
|
||||
def partition(mesh, arg_shapes, result_shape):
|
||||
def lower_fn(xs):
|
||||
x, y, z = xs
|
||||
return x + y + z
|
||||
|
||||
return (
|
||||
mesh,
|
||||
lower_fn,
|
||||
arg_shapes[0][0].sharding,
|
||||
jax.tree.map(lambda x: x.sharding, arg_shapes),
|
||||
)
|
||||
|
||||
def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
|
||||
return arg_shapes[0][0].sharding
|
||||
|
||||
def propagate_user_sharding(mesh, user_shape):
|
||||
return user_shape.sharding
|
||||
|
||||
@custom_partitioning
|
||||
def f(xs):
|
||||
x, y, z = xs
|
||||
return x + y + z
|
||||
|
||||
f.def_partition(
|
||||
infer_sharding_from_operands=infer_sharding_from_operands,
|
||||
partition=partition,
|
||||
propagate_user_sharding=propagate_user_sharding,
|
||||
sharding_rule='i j, i j, i j -> i j',
|
||||
)
|
||||
|
||||
def f2(a):
|
||||
return a + f((a, a, a))
|
||||
|
||||
pjit_f = pjit(f2, in_shardings=(P(None, 'x')), out_shardings=P('x'))
|
||||
x = np.asarray(np.random.randint(0, 20, (32, 16)), dtype=np.float32)
|
||||
self.assertArraysEqual(x * 4, pjit_f(x))
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
class AutoShardingPjitTest(jtu.JaxTestCase):
|
||||
@ -7096,16 +7137,12 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
|
||||
def test_set_mesh(self):
|
||||
mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisType.Explicit,))
|
||||
prev_mesh = config.device_context.value
|
||||
prev_abstract_mesh = config.abstract_mesh_context_manager.value
|
||||
try:
|
||||
jax.sharding.set_mesh(mesh)
|
||||
|
||||
prev_mesh = jax.sharding.set_mesh(mesh)
|
||||
out = reshard(np.arange(8), P('x'))
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
|
||||
finally:
|
||||
config.device_context.set_local(prev_mesh)
|
||||
config.abstract_mesh_context_manager.set_local(prev_abstract_mesh)
|
||||
jax.sharding.set_mesh(prev_mesh)
|
||||
|
||||
@jtu.with_user_mesh((2,), ('x',))
|
||||
def test_auto_axes_late_bind(self, mesh):
|
||||
|
4
third_party/xla/workspace.bzl
vendored
4
third_party/xla/workspace.bzl
vendored
@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
|
||||
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
|
||||
# and update XLA_SHA256 with the result.
|
||||
|
||||
XLA_COMMIT = "3bb765472122548cc227b8bd2990f00bd533f438"
|
||||
XLA_SHA256 = "72126aac7602153aee985ca20f73d11c39e3ba9cfb8027492951e787559d0497"
|
||||
XLA_COMMIT = "df971129bd82e381954da0185b534220e21798a4"
|
||||
XLA_SHA256 = "11e9a568320cf7e7d61819620fd369927527ecefb68d5d1154b1521456bbdb72"
|
||||
|
||||
def repo():
|
||||
tf_http_archive(
|
||||
|
Loading…
x
Reference in New Issue
Block a user