Merge pull request #299 from ROCm/ci-upstream-sync-152_1

CI: 03/19/25 upstream sync
rocm-repo-management-api-2[bot] 2025-03-19 07:20:19 -05:00 committed by GitHub
commit b505df9973
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
42 changed files with 1048 additions and 649 deletions

View File

@ -1,8 +1,8 @@
diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt
index dfefaf042..2700e140e 100644
index e7a2968e9..d37e11ee3 100644
--- a/build/requirements_lock_3_13_ft.txt
+++ b/build/requirements_lock_3_13_ft.txt
@@ -4,6 +4,12 @@
@@ -4,6 +4,11 @@
#
# pip-compile --allow-unsafe --generate-hashes --output-file=build/requirements_lock_3_13_ft.txt build/requirements.in
#
@ -10,12 +10,11 @@ index dfefaf042..2700e140e 100644
+--pre
+--extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple
+numpy
+
+
absl-py==2.1.0 \
--hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \
--hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff
@@ -328,68 +334,6 @@ mpmath==1.3.0 \
@@ -328,68 +333,6 @@ mpmath==1.3.0 \
--hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \
--hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c
# via -r build/test-requirements.txt
@ -81,6 +80,6 @@ index dfefaf042..2700e140e 100644
- # matplotlib
- # ml-dtypes
- # scipy
opt-einsum==3.4.0 \
--hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \
--hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac
nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \
--hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \
--hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \

View File

@ -173,12 +173,18 @@ jobs:
--bazel_options=--copt=-g \
--clang_path=/usr/bin/clang-18
# Update the patch to use TSAN instrumented numpy
# Patch build/requirements_lock_3_13_ft.txt to use TSAN instrumented NumPy
sed -i "s|+--extra-index-url.*|+--extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/|" .github/workflows/requirements_lock_3_13_ft.patch
cat .github/workflows/requirements_lock_3_13_ft.patch
git apply .github/workflows/requirements_lock_3_13_ft.patch || exit 1
# Apply a patch to numpy in requirements lock 3.13 ft to use the nightly version
git apply .github/workflows/requirements_lock_3_13_ft.patch
# Display the content for debugging in logs
cat build/requirements_lock_3_13_ft.txt | head -15
# Check the patch
cat build/requirements_lock_3_13_ft.txt | head -15 | grep -E "(--pre|.*${GITHUB_WORKSPACE}/wheelhouse/|numpy)"
if [ "$?" == "1" ]; then echo "Could not find the patch in the requirements_lock_3_13_ft.txt"; exit 1; fi
cat build/requirements_lock_3_13_ft.txt | grep -E "(numpy==)"
if [ "$?" == "0" ]; then "Found original numpy dependency in the requirements_lock_3_13_ft.txt"; exit 1; fi
echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
@ -188,6 +194,13 @@ jobs:
bazel_exec=($(ls bazel-*))
ln -s ${bazel_exec} bazel
# Check python version
./bazel run --@rules_python//python/config_settings:py_freethreaded="yes" @python//:python3 -- -VV
# Check numpy version
./bazel cquery @pypi_numpy//:* | grep whl
# Build JAX and run tests
./bazel test \
--test_env=JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES \
--test_env=JAX_ENABLE_X64=$JAX_ENABLE_X64 \

View File

@ -33,7 +33,6 @@ from jax._src import errors
from jax._src import profiler
from jax._src import util
from jax._src import xla_bridge
from jax._src.mesh import use_concrete_mesh
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
@ -43,7 +42,8 @@ from jax._src.lib import xla_extension as xe
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
PmapSharding, SingleDeviceSharding,
device_replica_id_map, hashed_index, num_addressable_indices, local_to_global_shape) # pyformat: disable
device_replica_id_map, hashed_index, num_addressable_indices,
local_to_global_shape, use_concrete_mesh) # pyformat: disable
from jax._src.typing import ArrayLike, DLDeviceType, DTypeLike
from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method, cache
import numpy as np

View File

@ -29,8 +29,8 @@ class SampleFn(Protocol):
def _compute_tile_index(block_index: Sequence[int],
total_size_in_blocks: Shape,
block_size_in_tiles: Shape,
total_size_in_tiles: Shape,
tile_index_in_block: Sequence[int]) -> int:
ndims = len(block_index)
dim_size = 1
@ -38,7 +38,7 @@ def _compute_tile_index(block_index: Sequence[int],
for i in range(ndims-1, -1, -1):
dim_idx = tile_index_in_block[i] + block_index[i] * block_size_in_tiles[i]
total_idx += dim_idx * dim_size
dim_size *= total_size_in_blocks[i] * block_size_in_tiles[i]
dim_size *= total_size_in_tiles[i]
return total_idx
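The signature change above swaps `total_size_in_blocks` for `total_size_in_tiles`, so tiles are numbered row-major over the rounded-up total tile count rather than over blocks times tiles-per-block. A small self-contained sketch of the fixed indexing, with made-up sizes:
def compute_tile_index(block_index, block_size_in_tiles, total_size_in_tiles,
                       tile_index_in_block):
  total_idx = 0
  dim_size = 1
  for i in range(len(block_index) - 1, -1, -1):
    dim_idx = tile_index_in_block[i] + block_index[i] * block_size_in_tiles[i]
    total_idx += dim_idx * dim_size
    dim_size *= total_size_in_tiles[i]
  return total_idx

total_size, block_size, tile_size = (9, 9), (4, 4), (2, 2)
block_size_in_tiles = tuple(b // t for b, t in zip(block_size, tile_size))
total_size_in_tiles = tuple(-(-s // t) for s, t in zip(total_size, tile_size))  # ceil
print(total_size_in_tiles)                                  # (5, 5)
print(compute_tile_index((2, 2), block_size_in_tiles,
                         total_size_in_tiles, (0, 0)))      # 24: the last corner tile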
@ -103,15 +103,17 @@ def blocked_fold_in(
_shape // _element for _shape, _element in zip(block_size, tile_size)
)
total_size_in_blocks = tuple(
_shape // _element for _shape, _element in zip(total_size, block_size)
# Round up to make sure every tile is numbered.
total_size_in_tiles = tuple(
(_shape + _element - 1) // _element
for _shape, _element in zip(total_size, tile_size)
)
def _keygen_loop(axis, prefix):
if axis == len(block_size_in_tiles):
subtile_key = jax.random.fold_in(
global_key, _compute_tile_index(
block_index, total_size_in_blocks, block_size_in_tiles, prefix))
block_index, block_size_in_tiles, total_size_in_tiles, prefix))
return subtile_key
else:
keys = []

View File

@ -179,7 +179,7 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
for sharding, s in zip(result_shardings, result_shapes)
]
closed_jaxpr = jax.make_jaxpr(lower_fn, axis_env=list(mesh.shape.items()))(
*tiled_args
*info.in_tree.unflatten(tiled_args)
)
if ([(o.shape, o.dtype) for o in closed_jaxpr.out_avals] !=
[(t.shape, t.dtype) for t in tiled_results]):

View File

@ -41,7 +41,7 @@ from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval,
JaxprEqn, Primitive, ShapedArray, DShapedArray,
mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
InputType, OutputType, get_referent, JaxprEqnContext)
from jax._src.state.types import AbstractRef
from jax._src.state.types import AbstractRef, ReadEffect
from jax._src.tree_util import (PyTreeDef, treedef_tuple,
tree_flatten, tree_structure)
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
@ -1423,7 +1423,8 @@ def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool],
def has_effects(eqn: JaxprEqn) -> bool:
effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)}
effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)
and not isinstance(e, ReadEffect)}
return bool(effs)
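The effect filter now also ignores read effects when deciding whether an equation must be kept during DCE. A toy sketch of that predicate, using stand-in effect classes (these are hypothetical placeholders, not the real jax._src classes):
class NamedAxisEffect: ...
class ReadEffect: ...
class IOEffect: ...

def has_effects(effects) -> bool:
    effs = {e for e in effects
            if not isinstance(e, (NamedAxisEffect, ReadEffect))}
    return bool(effs)

print(has_effects({NamedAxisEffect(), ReadEffect()}))  # False: safe to DCE
print(has_effects({ReadEffect(), IOEffect()}))         # True: the IO effect must be kept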

View File

@ -615,8 +615,23 @@ def tanh(x: ArrayLike) -> Array:
"""
return tanh_p.bind(x)
@export
def logistic(x: ArrayLike) -> Array:
r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`."""
r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`.
There is no HLO logistic/sigmoid primitive, so this lowers to a sequence
of HLO arithmetic operations.
Args:
x: input array. Must have floating point or complex dtype.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
logistic/sigmoid function.
See also:
- :func:`jax.nn.sigmoid`: an alternative API for this functionality.
"""
return logistic_p.bind(x)
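A quick, illustrative check of the equivalence documented above between jax.lax.logistic and jax.nn.sigmoid:
import jax
import jax.numpy as jnp

x = jnp.array([-2.0, 0.0, 3.5])
print(jax.lax.logistic(x))                # elementwise 1 / (1 + exp(-x))
print(jax.nn.sigmoid(x))                  # the jax.nn API mentioned above
print(1.0 / (1.0 + jnp.exp(-x)))          # reference formula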
@export
@ -1018,12 +1033,45 @@ def bitwise_xor(x: ArrayLike, y: ArrayLike) -> Array:
"""
return xor_p.bind(x, y)
@export
def population_count(x: ArrayLike) -> Array:
r"""Elementwise popcount, count the number of set bits in each element."""
r"""Elementwise popcount, count the number of set bits in each element.
This function lowers directly to the `stablehlo.popcnt`_ operation.
Args:
x: Input array. Must have integer dtype.
Returns:
An array of the same shape and dtype as ``x``, containing the number of
set bits in the input.
See also:
- :func:`jax.lax.clz`: Elementwise count leading zeros.
- :func:`jax.numpy.bitwise_count`: More flexible NumPy-style API for bit counts.
.. _stablehlo.popcnt: https://openxla.org/stablehlo/spec#popcnt
"""
return population_count_p.bind(x)
@export
def clz(x: ArrayLike) -> Array:
r"""Elementwise count-leading-zeros."""
r"""Elementwise count-leading-zeros.
This function lowers directly to the `stablehlo.count_leading_zeros`_ operation.
Args:
x: Input array. Must have integer dtype.
Returns:
An array of the same shape and dtype as ``x``, containing the number of
leading zeros in the input.
See also:
- :func:`jax.lax.population_count`: Count the number of set bits in each element.
.. _stablehlo.count_leading_zeros: https://openxla.org/stablehlo/spec#count_leading_zeros
"""
return clz_p.bind(x)
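A short illustration of the two bit-counting primitives documented above (the value is chosen arbitrarily):
import jax.numpy as jnp
from jax import lax

x = jnp.uint8(0b00010111)
print(lax.population_count(x))   # 4 set bits
print(lax.clz(x))                # 3 leading zeros in an 8-bit value
print(jnp.bitwise_count(x))      # the NumPy-style API mentioned in the docstring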
@export
@ -1124,31 +1172,81 @@ def div(x: ArrayLike, y: ArrayLike) -> Array:
"""
return div_p.bind(x, y)
@export
def rem(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise remainder: :math:`x \bmod y`.
The sign of the result is taken from the dividend,
and the absolute value of the result is always
less than the divisor's absolute value.
This function lowers directly to the `stablehlo.remainder`_ operation.
The sign of the result is taken from the dividend, and the absolute value
of the result is always less than the divisor's absolute value.
Integer division overflow
(remainder by zero or remainder of INT_SMIN with -1)
Integer division overflow (remainder by zero or remainder of INT_SMIN with -1)
produces an implementation defined value.
Args:
x, y: Input arrays. Must have matching int or float dtypes. If neither
is a scalar, ``x`` and ``y`` must have the same number of dimensions
and be broadcast compatible.
Returns:
An array of the same dtype as ``x`` and ``y`` containing the remainder.
See also:
- :func:`jax.numpy.remainder`: NumPy-style remainder with different
sign semantics.
.. _stablehlo.remainder: https://openxla.org/stablehlo/spec#remainder
"""
return rem_p.bind(x, y)
@export
def max(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise maximum: :math:`\mathrm{max}(x, y)`
r"""Elementwise maximum: :math:`\mathrm{max}(x, y)`.
For complex numbers, uses a lexicographic comparison on the
`(real, imaginary)` pairs."""
This function lowers directly to the `stablehlo.maximum`_ operation for
non-complex inputs. For complex numbers, this uses a lexicographic
comparison on the `(real, imaginary)` pairs.
Args:
x, y: Input arrays. Must have matching dtypes. If neither is a scalar,
``x`` and ``y`` must have the same rank and be broadcast compatible.
Returns:
An array of the same dtype as ``x`` and ``y`` containing the elementwise
maximum.
See also:
- :func:`jax.numpy.maximum`: more flexible NumPy-style maximum.
- :func:`jax.lax.reduce_max`: maximum along an axis of an array.
- :func:`jax.lax.min`: elementwise minimum.
.. _stablehlo.maximum: https://openxla.org/stablehlo/spec#maximum
"""
return max_p.bind(x, y)
@export
def min(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise minimum: :math:`\mathrm{min}(x, y)`
r"""Elementwise minimum: :math:`\mathrm{min}(x, y)`
For complex numbers, uses a lexicographic comparison on the
`(real, imaginary)` pairs."""
This function lowers directly to the `stablehlo.minimum`_ operation for
non-complex inputs. For complex numbers, this uses a lexicographic
comparison on the `(real, imaginary)` pairs.
Args:
x, y: Input arrays. Must have matching dtypes. If neither is a scalar,
``x`` and ``y`` must have the same rank and be broadcast compatible.
Returns:
An array of the same dtype as ``x`` and ``y`` containing the elementwise
minimum.
See also:
- :func:`jax.numpy.minimum`: more flexible NumPy-style minimum.
- :func:`jax.lax.reduce_min`: minimum along an axis of an array.
- :func:`jax.lax.max`: elementwise maximum.
.. _stablehlo.minimum: https://openxla.org/stablehlo/spec#minimum
"""
return min_p.bind(x, y)
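The sign convention called out in the rem docstring above is the main difference from the NumPy-style API; a minimal comparison:
import jax.numpy as jnp
from jax import lax

print(lax.rem(-7, 3))            # -1: sign follows the dividend (C-style remainder)
print(jnp.remainder(-7, 3))      #  2: sign follows the divisor (NumPy-style)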
@export
@ -1408,21 +1506,38 @@ def lt(x: ArrayLike, y: ArrayLike) -> Array:
"""
return lt_p.bind(x, y)
@export
def convert_element_type(operand: ArrayLike,
new_dtype: DTypeLike | dtypes.ExtendedDType) -> Array:
"""Elementwise cast.
Wraps XLA's `ConvertElementType
<https://www.tensorflow.org/xla/operation_semantics#convertelementtype>`_
operator, which performs an elementwise conversion from one type to another.
Similar to a C++ `static_cast`.
This function lowers directly to the `stablehlo.convert`_ operation, which
performs an elementwise conversion from one type to another, similar to a
C++ ``static_cast``.
Args:
operand: an array or scalar value to be cast.
new_dtype: a NumPy dtype representing the target type.
new_dtype: a dtype-like object (e.g. a :class:`numpy.dtype`, a scalar type,
or a valid dtype name) representing the target dtype.
Returns:
An array with the same shape as `operand`, cast elementwise to `new_dtype`.
An array with the same shape as ``operand``, cast elementwise to ``new_dtype``.
.. note::
If ``new_dtype`` is a 64-bit type and `x64 mode`_ is not enabled,
the appropriate 32-bit type will be used in its place.
If the input is a JAX array and the input dtype and output dtype match, then
the input array will be returned unmodified.
See also:
- :func:`jax.numpy.astype`: NumPy-style dtype casting API.
- :meth:`jax.Array.astype`: dtype casting as an array method.
- :func:`jax.lax.bitcast_convert_type`: cast bits directly to a new dtype.
.. _stablehlo.convert: https://openxla.org/stablehlo/spec#convert
.. _x64 mode: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
"""
return _convert_element_type(operand, new_dtype, weak_type=False) # type: ignore[unused-ignore,bad-return-type]
@ -1500,12 +1615,11 @@ def _convert_element_type(
operand, new_dtype=new_dtype, weak_type=bool(weak_type),
sharding=sharding)
@export
def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
"""Elementwise bitcast.
Wraps XLA's `BitcastConvertType
<https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype>`_
operator, which performs a bit cast from one type to another.
This function lowers directly to the `stablehlo.bitcast_convert`_ operation.
The output shape depends on the size of the input and output dtypes with
the following logic::
@ -1525,6 +1639,12 @@ def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
Returns:
An array of shape `output_shape` (see above) and type `new_dtype`,
constructed from the same bits as operand.
See also:
- :func:`jax.lax.convert_element_type`: value-preserving dtype conversion.
- :func:`jax.Array.view`: NumPy-style API for bitcast type conversion.
.. _stablehlo.bitcast_convert: https://openxla.org/stablehlo/spec#bitcast_convert
"""
new_dtype = dtypes.canonicalize_dtype(new_dtype)
return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype)
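A small check of the shape rule described above for bitcast_convert_type: a target dtype of the same bit width keeps the shape, while a narrower target adds a trailing dimension:
import jax.numpy as jnp
from jax import lax

x = jnp.arange(4, dtype=jnp.float32)
print(lax.bitcast_convert_type(x, jnp.uint32).shape)  # (4,): same bit width
print(lax.bitcast_convert_type(x, jnp.uint8).shape)   # (4, 4): 4 bytes per float32
print(lax.bitcast_convert_type(x, jnp.uint32)[1])     # 1065353216 == bit pattern of 1.0f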

View File

@ -1301,6 +1301,7 @@ def _ragged_all_to_all_transpose(
mask = jax.numpy.cumsum(
jax.numpy.zeros(t.shape[0], dtype='int32').at[output_offsets_].set(1)\
.at[output_offsets_ + recv_sizes].add(-1))
mask = jax.numpy.expand_dims(mask, (*range(1, t.ndim),))
output_t = jax.numpy.where(mask, 0, t)
return [operand_t, output_t] + [None] * 4
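The added expand_dims makes the 1-D row mask broadcast against the full rank of `t` before the where(). A standalone sketch with hypothetical offsets and sizes:
import jax.numpy as jnp

t = jnp.ones((8, 3))                     # stand-in for the transposed output buffer
output_offsets_ = jnp.array([1, 5])      # hypothetical per-peer offsets
recv_sizes = jnp.array([2, 2])           # hypothetical per-peer sizes

# 1 at each segment start, -1 one past each segment end; cumsum gives 1 inside segments.
mask = jnp.cumsum(
    jnp.zeros(t.shape[0], dtype='int32')
      .at[output_offsets_].set(1)
      .at[output_offsets_ + recv_sizes].add(-1))
# Broadcast the row mask across the remaining dimensions of t before where().
mask = jnp.expand_dims(mask, tuple(range(1, t.ndim)))
print(jnp.where(mask, 0, t).shape)       # (8, 3)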

View File

@ -543,6 +543,10 @@ class UseAbstractMeshContextManager:
__slots__ = ['mesh', 'prev']
def __init__(self, mesh: AbstractMesh):
if not isinstance(mesh, AbstractMesh):
raise ValueError(
"Expected mesh of type `jax.sharding.AbstractMesh`. Got type:"
f" {type(mesh)}")
self.mesh = mesh
def __enter__(self):
@ -557,13 +561,5 @@ def get_abstract_mesh():
val = jax_config.abstract_mesh_context_manager.value
return empty_abstract_mesh if val is None else val
@contextlib.contextmanager
def use_concrete_mesh(mesh: Mesh | None):
prev_val = jax_config.device_context.swap_local(mesh)
try:
yield
finally:
jax_config.device_context.set_local(prev_val)
def get_concrete_mesh() -> Mesh | None:
return jax_config.device_context.value

View File

@ -71,7 +71,7 @@ import numpy as np
export = set_module('jax.numpy')
for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib']:
for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib.cuda']:
try:
cuda_plugin_extension = importlib.import_module(
f'{pkg_name}.cuda_plugin_extension'

View File

@ -35,6 +35,7 @@ from jax._src import linear_util as lu
from jax._src import state
from jax._src import tree_util
from jax._src import util
from jax._src.export._export import export
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.state import discharge as state_discharge
@ -1165,14 +1166,16 @@ jax_core.custom_typechecks[core_map_p] = _core_map_typecheck_rule
def lower_as_mlir(
f, *args, dynamic_shapes=False, device=None, **kwargs
f, *args, dynamic_shapes=False, device=None, static_argnames=(), **kwargs
) -> mlir.ir.Module:
with pallas_export_experimental(dynamic_shapes):
lowered = jax.jit(f, device=device).lower(*args, **kwargs)
stablehlo = lowered.compiler_ir(dialect="stablehlo")
f = jax.jit(f, device=device, static_argnames=static_argnames)
exported = export(f, platforms=["tpu"])(*args, **kwargs)
stablehlo = exported.mlir_module()
return stablehlo # type: ignore[return-value]
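lower_as_mlir now goes through the export machinery instead of jit(...).lower(...).compiler_ir(). A minimal sketch of that path via the public jax.export module (the toy function and the "cpu" platform are assumptions; the Pallas code above targets "tpu"):
import jax
import jax.numpy as jnp
from jax import export

def f(x):
    return jnp.sin(x) * 2.0

exported = export.export(jax.jit(f), platforms=["cpu"])(jnp.ones((8,), jnp.float32))
print(exported.mlir_module()[:200])       # StableHLO text for the exported function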
_out_shape_to_aval_mapping: dict[
type[Any], Callable[[Any], jax_core.AbstractValue]
] = {}

View File

@ -244,6 +244,13 @@ def pull_block_spec(
_unwrap_block_spec_scalar_prefetch, out_block_specs
)
flat_block_specs, out_tree = jax.tree.flatten(block_specs_)
jaxpr, used_consts, used_invars = pe.dce_jaxpr_consts(
jaxpr,
used_outputs=[True] * len(jaxpr.outvars),
instantiate=True,
)
assert all(used_invars)
assert all(used_consts)
in_block_specs, env, read_usage_env = _pull_block_spec(
jaxpr,
tuple(flat_block_specs),

View File

@ -83,10 +83,15 @@ class TPUInterpretParams:
replaced with arrays all of `jnp.inf`. Additionally, any floating point
operands to any operation will be replaced with (arrays of) `jnp.inf`.
Default: False.
uninitialized_memory: If "nan", allocated buffers are initialized to
to contain all NaNs (or to their maximum possible value for integers).
If "zero", allocated buffers are initialized to all zeros.
Default: "nan".
"""
dma_execution_mode: Literal["eager", "on_wait"] = "on_wait"
detect_races: bool = False
skip_floating_point_ops: bool = False
uninitialized_memory: Literal["nan", "zero"] = "nan"
VectorClock = np.ndarray
@ -1114,7 +1119,8 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params):
jax.ShapeDtypeStruct((), jnp.int16),
device_id,
TPU_MEMORY_SPACE_IDXS[v.aval.memory_space],
primitives.uninitialized_value(v.aval.shape, v.aval.dtype),
_uninitialized_value(
v.aval.shape, v.aval.dtype, interpret_params),
ordered=True))
out = _interpret(eqn.params['jaxpr'], *deferred_invals(), *allocs)
@ -1279,16 +1285,19 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params):
def _initialize_output_vals(
block_mappings_output: Iterable[BlockMapping],
input_args, input_output_aliases) -> Sequence[jax.Array]:
input_args, input_output_aliases,
interpret_params: TPUInterpretParams,
) -> Sequence[jax.Array]:
oi_map = {v: k for k, v in input_output_aliases}
output_vals = []
for i, bm in enumerate(block_mappings_output):
if i in oi_map:
output_vals.append(input_args[oi_map[i]])
else:
output_vals.append(primitives.uninitialized_value(
output_vals.append(_uninitialized_value(
bm.array_shape_dtype.shape,
bm.array_shape_dtype.dtype))
bm.array_shape_dtype.dtype,
interpret_params))
return output_vals
def _compute_start_indices(block_mapping, loop_idx, *args):
@ -1319,7 +1328,20 @@ def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing):
dtype=np.bool_)])
return lax.squeeze(output, squeeze_dims)
def _pad_to_block_dimension(value, block_shape):
def _uninitialized_value(shape, dtype, interpret_params):
if interpret_params.uninitialized_memory == 'nan':
if jnp.issubdtype(dtype, jnp.floating):
return jnp.full(shape, jnp.nan, dtype)
elif jnp.issubdtype(dtype, jnp.integer):
return jnp.full(shape, jnp.iinfo(dtype).max, dtype)
elif jnp.issubdtype(dtype, jnp.bool):
return jnp.full(shape, False, dtype)
if interpret_params.uninitialized_memory == 'zero':
return jnp.full(shape, 0, dtype)
raise NotImplementedError(
interpret_params.uninitialized_memory + ' + ' + str(dtype))
def _pad_to_block_dimension(value, block_shape, interpret_params):
"""Pads values so the shape evenly divides into block dimensions.
For example, if values has a shape of (33, 2, 5) with a block_shape of
@ -1338,7 +1360,7 @@ def _pad_to_block_dimension(value, block_shape):
)
if padded_shape != value.shape:
pad_width = tuple((0, a-b) for a, b in zip(padded_shape, value.shape))
pad_value = primitives.uninitialized_value(shape=(), dtype=value.dtype)
pad_value = _uninitialized_value((), value.dtype, interpret_params)
value = jnp.pad(value, pad_width, constant_values=pad_value)
return value
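A standalone sketch of the fill policy that _uninitialized_value implements (the helper name debug_fill and the dtypes used here are illustrative, not part of the diff):
import jax.numpy as jnp

def debug_fill(shape, dtype, mode="nan"):
    # Illustrative helper mirroring the interpreter's fill policy above.
    if mode == "nan":
        if jnp.issubdtype(dtype, jnp.floating):
            return jnp.full(shape, jnp.nan, dtype)
        if jnp.issubdtype(dtype, jnp.integer):
            return jnp.full(shape, jnp.iinfo(dtype).max, dtype)
        if jnp.issubdtype(dtype, jnp.bool_):
            return jnp.full(shape, False, dtype)
    return jnp.zeros(shape, dtype)        # mode == "zero"

print(debug_fill((2, 2), jnp.float32))            # all NaN
print(debug_fill((2, 2), jnp.int8))               # all 127
print(debug_fill((2, 2), jnp.int8, mode="zero"))  # all 0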
@ -1397,7 +1419,7 @@ def interpret_pallas_call(
]
num_inputs = grid_mapping.num_inputs
input_args = [
_pad_to_block_dimension(a, bs)
_pad_to_block_dimension(a, bs, interpret_params)
for a, bs in zip(input_args, block_shapes[:num_inputs])
]
@ -1407,11 +1429,12 @@ def interpret_pallas_call(
output_vals = _initialize_output_vals(
grid_mapping.block_mappings_output,
scalars + input_args,
input_output_aliases)
input_output_aliases,
interpret_params)
num_outputs = grid_mapping.num_outputs
output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs]
for out_val, bs in zip(output_vals, output_block_shapes):
padded_val = _pad_to_block_dimension(out_val, bs)
padded_val = _pad_to_block_dimension(out_val, bs, interpret_params)
output_buffer_shapes.append(padded_val.shape)
output_buffer_ids.append(callback.io_callback(
_allocate_buffer,
@ -1466,7 +1489,8 @@ def interpret_pallas_call(
jax.ShapeDtypeStruct((), jnp.int16),
device_id,
TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
primitives.uninitialized_value(var.aval.shape, var.aval.dtype),
_uninitialized_value(
var.aval.shape, var.aval.dtype, interpret_params),
ordered=True))
_, input_ids, kernel_output_ids, _ = split_list(

View File

@ -40,6 +40,7 @@ from jax._src import source_info_util
from jax._src import state
from jax._src import traceback_util
from jax._src.cloud_tpu_init import is_cloud_tpu_older_than
from jax._src.export._export import export
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal
@ -89,6 +90,11 @@ BOOL_MEMREF_TYPE = np.dtype('int32')
# The value interpreted as a dynamic dimension by MLIR.
MLIR_DYNAMIC = -9223372036854775808
# TODO(mvoz): Find a way to make this a contract we can share with the
# export specialization step in XLA export.
DIM_UPPER_BOUND = np.iinfo(np.int32).max
DIM_LOWER_BOUND = DIM_UPPER_BOUND - 128
partial = functools.partial
map, unsafe_map = safe_map, map # pylint: disable=redefined-builtin
zip, unsafe_zip = safe_zip, zip # pylint: disable=redefined-builtin
@ -102,17 +108,49 @@ class MeshContext:
# Note - On Export Placeholders
#
# Mosaic uses vector IR, which does not have a concept of dynamic
# dimensions. We need to come up with a way to represent dynamic dimensions in
# vector IR, and so we use placeholders, which are later replaced during
# specialization.
# Since the vector dialect used by Mosaic does not support dynamic shapes,
# we replace all top-level symbolic dimensions with placeholder
# constants (between max(int32) - 128 and max(int32)) and we keep a
# mapping from the placeholder constants to SHLO functions that encode
# the symbolic dimension expression, as a function of the dimension
# variables.
#
# The calling convention of the produced MLIR module is the same as
# regular mosaic module, except we add on two new attributes to the custom call
# *per* intermediary placeholder dimension.
#
# The attributes are:
#
# tpu.dynamic_dimension_mapping_arg_name_<placeholder>
# tpu.dynamic_dimension_mapping_module_<placeholder>
#
# The first attribute is a comma-separated list of the dimension variables
# that are used to compute the symbolic dimension expression for the
# placeholder. The second attribute is the MLIR module that contains the
# SHLO functions that compute the symbolic dimension expression for the
# placeholder.
class LoweringDynamicShapeEnv:
dim_expr_to_placeholder: dict[Any, ir.Value] = {}
dim_expr_to_placeholder: dict[shape_poly._DimExpr, int] = {}
placeholder_to_dim_expr: dict[int, shape_poly._DimExpr] = {}
def to_placeholder(self, dim_expr: Any) -> ir.Value:
if jax_core.is_constant_dim(dim_expr):
# avoid ints, these are not dynamic
return dim_expr
if dim_expr not in self.dim_expr_to_placeholder:
next_val = np.iinfo(np.int32).max - len(self.dim_expr_to_placeholder)
next_val = DIM_UPPER_BOUND - len(self.dim_expr_to_placeholder)
if next_val < DIM_LOWER_BOUND:
# In practice, even with the largest of programs, we rarely see
# anything even close to this limit. It is arbitrary, and can be safely
# increased if needed.
raise ValueError(
"Too many dynamic shapes in the input. Mosaic currently only"
" supports up to 128 dynamic dimension values."
)
self.dim_expr_to_placeholder[dim_expr] = next_val
# Reverse mapping - this is consumed to generate a table that is either
# input<>placeholder or intermediary computation<>placeholder.
self.placeholder_to_dim_expr[next_val] = dim_expr
return self.dim_expr_to_placeholder[dim_expr]
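A pure-Python sketch of the placeholder scheme described in the note above: placeholders count down from max(int32), are deduplicated per dimension expression, and are capped at 128 values (the cap mirrors the error message; the exact bound is an assumption):
import numpy as np

DIM_UPPER_BOUND = np.iinfo(np.int32).max
DIM_LOWER_BOUND = DIM_UPPER_BOUND - 128   # assumed limit of 128 placeholders

class PlaceholderEnv:
  def __init__(self):
    self.expr_to_placeholder = {}
    self.placeholder_to_expr = {}

  def to_placeholder(self, expr):
    if expr not in self.expr_to_placeholder:
      val = DIM_UPPER_BOUND - len(self.expr_to_placeholder)
      if val < DIM_LOWER_BOUND:
        raise ValueError("too many dynamic dimensions")
      self.expr_to_placeholder[expr] = val
      self.placeholder_to_expr[val] = expr   # reverse mapping for the module attrs
    return self.expr_to_placeholder[expr]

env = PlaceholderEnv()
print(env.to_placeholder("b"))        # 2147483647
print(env.to_placeholder("b * 128"))  # 2147483646
print(env.to_placeholder("b"))        # 2147483647 again (deduplicated)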
@ -622,6 +660,7 @@ def lower_jaxpr_to_module(
"Pallas TPU requires a libTPU version that's at most a month old"
)
debug_info = jaxpr.debug_info
_mosaic_lowering_dynamic_shape_env = None
if dynamic_shape_replacement_enabled:
_mosaic_lowering_dynamic_shape_env = LoweringDynamicShapeEnv()
@ -663,10 +702,12 @@ def lower_jaxpr_to_module(
for_verification=for_verification,
forward_compatible=lowering_context.is_forward_compat(),
dynamic_shape_replacement_fn=dynamic_shape_replacement_fn,
dynamic_shape_replacement_enabled=dynamic_shape_replacement_enabled,
)
m.body.append(func_op)
sym_tab.insert(func_op)
window_params = []
static_grid = None
grid = mosaic_grid_mapping.grid
if grid:
for i, bm in enumerate(grid_mapping.block_mappings):
@ -738,7 +779,6 @@ def lower_jaxpr_to_module(
]
static_grid = dynamic_shape_replacement_fn(static_grid)
func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(static_grid)
func_op.attributes["scalar_prefetch"] = ir.IntegerAttr.get(
ir.IntegerType.get_signless(64), len(mosaic_grid_mapping.scalar_prefetch_types))
func_op.attributes["scratch_operands"] = ir.IntegerAttr.get(
@ -746,6 +786,60 @@ def lower_jaxpr_to_module(
func_op.attributes["dimension_semantics"] = (
mosaic_grid_mapping.get_dimension_semantics()
)
if dynamic_shape_replacement_enabled:
if _mosaic_lowering_dynamic_shape_env is None:
raise ValueError(
"Dynamic shape env is None, invariant violated. Unreachable?"
)
# Now we can use jax to compute the dynamic shape graph
if static_grid is not None:
grid_vars = [
_mosaic_lowering_dynamic_shape_env.placeholder_to_dim_expr.get(g, g)
for g in static_grid
]
else:
grid_vars = []
invars = [invar.aval for invar in jaxpr.invars]
# Faux shape for grid, just to get the avals
invars.append(jax.ShapeDtypeStruct(grid_vars, jax.numpy.int32))
args_dimvars = shape_poly.all_dim_vars(invars)
# This is dimexpr var -> placeholder value for when we jit the dim expr
env: dict[str, int] = {}
for aval in args_dimvars:
env[aval] = _mosaic_lowering_dynamic_shape_env.to_placeholder(aval)
for (
placeholder,
dim_expr,
) in _mosaic_lowering_dynamic_shape_env.placeholder_to_dim_expr.items():
top_level_names = list(env.keys())
if dim_expr not in top_level_names:
jitted_eval = jax.jit(
jax_core.evaluate_shape,
static_argnames=(
"shape",
"dim_vars",
),
keep_unused=True,
)
stablehlo = export(
jitted_eval, platforms=[str(jax.devices()[0].platform)]
)(
(dim_expr,), tuple(args_dimvars), *(env[v] for v in args_dimvars)
).mlir_module()
arg_name = args_dimvars
# See Note - On Export Placeholders for more details.
m.operation.attributes[
"tpu.dynamic_dimension_mapping_module_" + str(placeholder)
] = ir.StringAttr.get(str(stablehlo))
arg_name_str = ",".join(arg_name)
m.operation.attributes[
"tpu.dynamic_dimension_mapping_arg_name_" + str(placeholder)
] = ir.StringAttr.get(arg_name_str)
return m, mosaic_grid_mapping.get_extra_args()
@ -828,6 +922,7 @@ def lower_jaxpr_to_func(
dynamic_shape_replacement_fn: (
Callable[[tuple[jax.DimSize, ...]], tuple[int, ...]] | None
) = None,
dynamic_shape_replacement_enabled: bool = False,
) -> func.FuncOp:
num_grid = len(mosaic_grid_mapping.grid_types)
num_scalar_prefetch = len(mosaic_grid_mapping.scalar_prefetch_types)
@ -874,6 +969,12 @@ def lower_jaxpr_to_func(
)
body_func.__name__ = name
body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func)
if dynamic_shape_replacement_enabled:
# Skip verification for dynamic shape replacement - we can potentially
# produce IR such as add(x[placeholder_0, placeholder_1], y[128, 128]),
# which is not valid, but that is fine since we run the verifier again
# after the dynamic shape replacement pass.
return body.func_op
try:
body.func_op.verify()
except ir.MLIRError as e:
@ -3851,3 +3952,15 @@ def _platform_index_lowering(
lowering_rules[jax._src.lax.control_flow.platform_index_p] = _platform_index_lowering
def _dim_as_value_lowering(ctx: mlir.LoweringRuleContext, *, dim):
placeholder = ctx.lowering_context.dynamic_shape_replacement_fn((dim,))[0]
return ir_constant(
placeholder, mlir_type=_dtype_to_ir_type(jnp.dtype("int32"))
)
import jax._src.export.shape_poly as shape_poly
lowering_rules[shape_poly.dim_as_value_p] = _dim_as_value_lowering

View File

@ -1196,9 +1196,8 @@ def emit_pipeline(
schedule = map_brefs(
lambda _, x: get_pipeline_schedule(x), allocations, schedule)
def loop_body(step, indices):
nonlocal allocations
scheduler = Scheduler(
def make_scheduler(step, indices):
return Scheduler(
step,
indices,
grid,
@ -1208,13 +1207,15 @@ def emit_pipeline(
init_accumulators=init_accumulators,
trace_scopes=trace_scopes,
)
def loop_body(step, indices):
scheduler = make_scheduler(step, indices)
with scheduler.grid_env():
# prepare any local VMEM aliases
brefs = map_brefs(scheduler.alias_local_refs, allocations, refs)
# loop input handling phase
map_brefs(scheduler.initialize, brefs, refs, schedule)
map_brefs(scheduler.copy_in, brefs, refs, schedule)
map_brefs(scheduler.wait_in, brefs, refs, schedule)
@ -1243,12 +1244,24 @@ def emit_pipeline(
lambda: None)
map_brefs(scheduler.swap_slots, brefs, refs, schedule)
map_brefs(scheduler.finalize, brefs, refs, schedule)
return _next_index(indices, grid)
# run pipeline
lax.fori_loop(0, num_steps, loop_body, (0,) * len(grid))
@pl.when(num_steps > 0)
def _():
# pipeline prologue
initial_indices = (0,) * len(grid)
scheduler = make_scheduler(0, initial_indices)
brefs = map_brefs(scheduler.alias_local_refs, allocations, refs)
map_brefs(scheduler.initialize, brefs, refs, schedule)
# pipeline loop
next_indices = lax.fori_loop(0, num_steps, loop_body, initial_indices)
# pipeline epilogue
final_indices = _prev_index(next_indices, grid)
scheduler = make_scheduler(num_steps - 1, final_indices)
brefs = map_brefs(scheduler.alias_local_refs, allocations, refs)
map_brefs(scheduler.finalize, brefs, refs, schedule)
return pipeline
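The restructuring above hoists `initialize` out of the loop as a prologue and `finalize` out as an epilogue, leaving only steady-state work in the fori_loop body. A toy analogue of that shape (not the Pallas emitter itself):
import jax.numpy as jnp
from jax import lax

def running_mean(xs):
    acc0 = jnp.zeros((), xs.dtype)        # "prologue": set up state once

    def body(step, acc):                  # steady-state work only
        return acc + xs[step]

    acc = lax.fori_loop(0, xs.shape[0], body, acc0)
    return acc / xs.shape[0]              # "epilogue": finalize once

print(running_mean(jnp.arange(8.0)))      # 3.5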

View File

@ -1387,16 +1387,38 @@ def use_mesh(mesh: mesh_lib.Mesh):
# if not core.trace_state_clean():
# raise ValueError('`use_mesh` can only be used outside of `jax.jit`')
with (mesh_lib.use_abstract_mesh(mesh.abstract_mesh),
mesh_lib.use_concrete_mesh(mesh)):
with mesh_lib.use_abstract_mesh(mesh.abstract_mesh), use_concrete_mesh(mesh):
yield
def set_mesh(mesh: mesh_lib.Mesh) -> None:
if not isinstance(mesh, mesh_lib.Mesh):
def set_mesh(mesh: mesh_lib.Mesh | None) -> mesh_lib.Mesh | None:
"""Sets the given concrete mesh globally and returns the previous concrete
mesh."""
if mesh is not None and not isinstance(mesh, mesh_lib.Mesh):
raise ValueError(
f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}")
if not core.trace_state_clean():
raise ValueError('`set_mesh` can only be used outside of `jax.jit`.')
config.abstract_mesh_context_manager.set_local(mesh.abstract_mesh)
config.device_context.set_local(mesh)
if mesh is None:
config.abstract_mesh_context_manager.set_global(mesh_lib.empty_abstract_mesh) # type: ignore
else:
config.abstract_mesh_context_manager.set_global(mesh.abstract_mesh) # type: ignore
prev_mesh = config.device_context.get_global()
config.device_context.set_global(mesh)
return prev_mesh
@contextlib.contextmanager
def use_concrete_mesh(mesh: mesh_lib.Mesh | None):
if mesh is not None and not isinstance(mesh, mesh_lib.Mesh):
raise ValueError(
f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}")
# TODO(yashkatariya): Enable this.
# if not core.trace_state_clean():
# raise ValueError('`use_concrete_mesh` can only be used outside of `jax.jit`.')
prev_val = config.device_context.swap_local(mesh)
try:
yield
finally:
config.device_context.set_local(prev_val)
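A sketch of the intended set_mesh usage: it installs a concrete mesh globally and returns the previous one so the caller can restore it. The import path below is an assumption for this snapshot; the function may or may not be re-exported publicly:
import jax
import numpy as np
from jax.sharding import Mesh
# Hypothetical import path; in this snapshot set_mesh is defined alongside use_mesh.
from jax._src.sharding_impls import set_mesh

mesh = Mesh(np.array(jax.devices()[:1]), ("x",))
prev = set_mesh(mesh)                     # install the concrete mesh globally
try:
    y = jax.jit(lambda a: a * 2)(np.arange(4.0))
finally:
    set_mesh(prev)                        # restore the previous mesh (possibly None)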

View File

@ -41,20 +41,16 @@ from jaxlib.mlir.dialects import memref
from jaxlib.mlir.dialects import nvvm
import numpy as np
if dialect is not None:
from . import dialect_lowering
from . import layout_inference
else:
dialect_lowering = None
layout_inference = None
from . import profiler
from . import utils
from . import launch_context
from . import tcgen05
# mypy: ignore-errors
from . import dialect_lowering
from . import launch_context
from . import layout_inference
from . import profiler
from . import tcgen05
from . import transform_inference
from . import utils
# MLIR can't find libdevice unless we point it to the CUDA path
# TODO(apaszke): Unify with jax._src.lib.cuda_path
CUDA_ROOT = "/usr/local/cuda"
@ -584,6 +580,7 @@ def as_gpu_kernel(
# Run Python lowering passes. The remaining passes will be run in C++ in
# jax/jaxlib/mosaic/gpu/custom_call.cc
layout_inference.infer_layout(module) # pytype: disable=attribute-error
transform_inference.infer_transforms(module) # pytype: disable=attribute-error
dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error
_initialize_scratch(launch_ctx, scratch_arr)
@ -666,6 +663,7 @@ def as_torch_gpu_kernel(
# Run Python lowering passes. The remaining passes will be run in C++ in
# jax/jaxlib/mosaic/gpu/custom_call.cc
layout_inference.infer_layout(module) # pytype: disable=attribute-error
transform_inference.infer_transforms(module) # pytype: disable=attribute-error
dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error
_initialize_scratch(launch_ctx, scratch_arr)

View File

@ -17,6 +17,7 @@
from collections.abc import Callable
import dataclasses
import functools
import itertools
import operator
from typing import Any, Sequence, Type, cast
@ -58,7 +59,7 @@ class LoweringContext:
if not _should_lower(op):
return
if (name := op.OPERATION_NAME) not in _lowerings:
if (name := op.OPERATION_NAME) not in _lowerings: # pytype: disable=attribute-error
raise NotImplementedError(f"Missing lowering rule for {op}")
lowering_rule = _lowerings[name]
@ -227,6 +228,60 @@ def _arith_constant_op_lowering_rule(
]
def _check_transforms_and_swizzle_are_supported(
ref_ty: ir.MemRefType,
transforms: Sequence[launch_context.MemRefTransform],
swizzle: mgpu.SwizzlingMode,
minimum_swizzle: mgpu.SwizzlingMode = mgpu.SwizzlingMode.kNoSwizzle,
):
"""Checks that the list of provided transforms and swizzle are supported.
Currently, we allow the following:
- any swizzle that is larger than or equal to `minimum_swizzle`;
- optionally, a single tile transform (with rank equal to the rank of the
memref being annotated);
- optionally, a single transpose transform.
"""
if swizzle < minimum_swizzle:
raise NotImplementedError(
f"Unsupported swizzle {swizzle} smaller than {minimum_swizzle}."
)
partitioned_transforms = {
k: list(v)
for k, v in itertools.groupby(
transforms, lambda t: isinstance(t, launch_context.TileTransform)
)
}
tile_transforms = partitioned_transforms.get(True, [])
other_transforms = partitioned_transforms.get(False, [])
if len(tile_transforms) > 1:
raise NotImplementedError(
f"{tile_transforms} contains more than one tile transform."
)
if len(tile_transforms) == 1:
if len(tile_transforms[0].tiling) != len(ref_ty.shape):
raise NotImplementedError(
f"Only tile transforms with rank equal to the rank of the memref "
f"being annotated are supported but got {tile_transforms[0]} for "
f"{ref_ty}."
)
if len(other_transforms) > 1:
raise NotImplementedError(
f"{other_transforms} contains more than one transform."
)
if len(other_transforms) == 1:
if not isinstance(other_transforms[0], launch_context.TransposeTransform):
raise NotImplementedError(
f"{other_transforms[0]} is not a transpose transform."
)
@_register_lowering(vector.LoadOp)
def _vector_load_op_lowering_rule(
_: LoweringContext, vector_load_op: vector.LoadOp
@ -260,8 +315,11 @@ def _vector_load_op_lowering_rule(
vec_size=strided_layout.vec_size,
)
elif layouts.from_layout_attr(out_layout_attr) == fa.WGMMA_LAYOUT:
layout = ir.MemRefType(vector_load_op.base.type).layout
swizzle, transforms = memref_layout_to_swizzle_and_transforms(layout)
swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
inference_utils.in_transforms(vector_load_op)[0]
)
ref_ty = ir.MemRefType(vector_load_op.base.type)
_check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle)
transformed_ref = transform_memref(vector_load_op.base, transforms)
fragmented_array = fa.FragmentedArray.load_tiled(
transformed_ref,
@ -297,8 +355,22 @@ def _vector_store_op_lowering_rule(
vector_store_op.valueToStore, to_store_layout
)
# TODO(dasenov): This is not efficient for WGMMA layouts
fragmented_array.store_untiled(vector_store_op.base)
if fragmented_array.layout == fa.WGMMA_LAYOUT:
swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
inference_utils.in_transforms(vector_store_op)[0]
)
ref_ty = ir.MemRefType(vector_store_op.base.type)
_check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle)
fragmented_array.store_tiled(
transform_memref(vector_store_op.base, transforms), swizzle
)
elif (isinstance(fragmented_array.layout, fa.WGStridedFragLayout) or
isinstance(fragmented_array.layout, fa.WGSplatFragLayout)):
fragmented_array.store_untiled(vector_store_op.base)
else:
raise ValueError(
f"{vector_store_op} has an unsupported layout: {to_store_layout}"
)
return []
@ -362,39 +434,43 @@ def _vector_reduction_op_lowering_rule(
return [_fragmented_array_to_ir(result, op.result.type)]
def memref_layout_to_swizzle_and_transforms(
layout: ir.Attribute,
def swizzle_and_transforms_from_transforms_attr(
transforms: ir.ArrayAttr,
) -> tuple[mgpu.SwizzlingMode, tuple[launch_context.MemRefTransform, ...]]:
"""Returns the swizzle and transforms that are encoded in the given layout.
"""Returns the swizzle and MemrefTransforms for the given transforms.
If the layout is not a LayoutAttr, the swizzle is kNoSwizzle and the
transforms are empty. Otherwise, the layout may have at most one swizzle
transform and any combination of tiling and transpose transforms.
Args:
transforms: a list of transform attributes.
Returns:
A tuple containing the swizzle mode and MemRefTransforms corresponding to
the parameter transforms. If `transforms` is empty, or does not contain
any swizzling transform, the swizzle mode is assumed to be kNoSwizzle.
Raises:
ValueError: if a swizzling transform is followed by any other transform.
"""
swizzle = None
gmem_transforms: list[launch_context.MemRefTransform] = []
if mgpu.LayoutAttr.isinstance(layout):
transforms_attr = mgpu.LayoutAttr(layout).transforms
for transform in transforms_attr:
if swizzle is not None:
raise ValueError(f"{layout} contains more transforms after the initial swizzle.")
if mgpu.SwizzleTransformAttr.isinstance(transform):
# TODO(dasenov): Swizzling can change if the ref is sliced in certain
# ways. We might want to enforce some restrictions here.
swizzle = mgpu.SwizzleTransformAttr(transform).swizzle
elif mgpu.TileTransformAttr.isinstance(transform):
tiling = mgpu.TileTransformAttr(transform).tiling
tiling_transform = launch_context.TileTransform(tuple(tiling))
gmem_transforms.append(tiling_transform)
elif mgpu.TransposeTransformAttr.isinstance(transform):
permutation = mgpu.TransposeTransformAttr(transform).permutation
transpose_transform = launch_context.TransposeTransform(
tuple(permutation)
)
gmem_transforms.append(transpose_transform)
else:
raise ValueError(f"{layout} has an unsupported transform: {transform}")
for transform in transforms:
if swizzle is not None:
raise ValueError(f"{transforms} contain more transforms after swizzle.")
if mgpu.SwizzleTransformAttr.isinstance(transform):
# TODO(dasenov): Swizzling can change if the ref is sliced in certain
# ways. We might want to enforce some restrictions here.
swizzle = mgpu.SwizzleTransformAttr(transform).swizzle
elif mgpu.TileTransformAttr.isinstance(transform):
tiling = mgpu.TileTransformAttr(transform).tiling
tiling_transform = launch_context.TileTransform(tuple(tiling))
gmem_transforms.append(tiling_transform)
elif mgpu.TransposeTransformAttr.isinstance(transform):
permutation = mgpu.TransposeTransformAttr(transform).permutation
transpose_transform = launch_context.TransposeTransform(
tuple(permutation)
)
gmem_transforms.append(transpose_transform)
else:
raise ValueError("Unknown transform: {transform}")
return swizzle or mgpu.SwizzlingMode.kNoSwizzle, tuple(gmem_transforms)
@ -434,8 +510,14 @@ def _mgpu_async_load_op_lowering_rule(
assert ctx.launch_context is not None
barrier = utils.BarrierRef.from_dialect_barrier_memref(load_op.barrier)
dst_layout = ir.MemRefType(load_op.destination.type).layout
swizzle, transforms = memref_layout_to_swizzle_and_transforms(dst_layout)
if inference_utils.has_in_transforms_set(load_op):
[transforms] = inference_utils.in_transforms(load_op)
swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
transforms
)
else:
swizzle = mgpu.SwizzlingMode.kNoSwizzle
transforms = ()
gmem_slice = []
for idx_i32, size in zip(load_op.indices, load_op.slice_lengths):
@ -464,8 +546,14 @@ def _mgpu_async_store_op_lowering_rule(
) -> Sequence[ir.Value]:
assert ctx.launch_context is not None
src_layout = ir.MemRefType(store_op.source.type).layout
swizzle, transforms = memref_layout_to_swizzle_and_transforms(src_layout)
if inference_utils.has_in_transforms_set(store_op):
[transforms] = inference_utils.in_transforms(store_op)
swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
transforms
)
else:
swizzle = mgpu.SwizzlingMode.kNoSwizzle
transforms = ()
gmem_slice = []
for idx_i32, size in zip(store_op.indices, store_op.slice_lengths):
@ -673,6 +761,9 @@ def _bitcast_op_lowering_rule(
def _mgpu_wgmma_op_lowering_rule(
_: LoweringContext, wgmma_op: mgpu.WGMMAOp
) -> Sequence[ir.Value]:
if wgmma_op.transpose_a or wgmma_op.transpose_b:
raise ValueError("Transpose arguments are to be deleted.")
fa_layouts = (
*inference_utils.in_layouts(wgmma_op),
*inference_utils.out_layouts(wgmma_op),
@ -691,25 +782,38 @@ def _mgpu_wgmma_op_lowering_rule(
regs = acc_in.to_layout(fa.WGMMA_LAYOUT)
acc = wgmma.WGMMAAccumulator.from_registers(regs)
b_layout = ir.MemRefType(wgmma_op.b.type).layout
b_swizzle, b_transforms = memref_layout_to_swizzle_and_transforms(b_layout)
if ir.VectorType.isinstance(wgmma_op.a.type):
a_transforms = None
b_transforms = inference_utils.in_transforms(wgmma_op)[0]
else:
a_transforms, b_transforms = inference_utils.in_transforms(wgmma_op)
b_swizzle, b_transforms = swizzle_and_transforms_from_transforms_attr(
b_transforms
)
minimum_swizzle = mgpu.SwizzlingMode.k32ByteSwizzle
ref_ty = ir.MemRefType(wgmma_op.b.type)
_check_transforms_and_swizzle_are_supported(
ref_ty, b_transforms, b_swizzle, minimum_swizzle
)
b_operand = transform_memref(wgmma_op.b, b_transforms)
if wgmma_op.transpose_b:
b_operand = utils.memref_transpose(b_operand, (0, 1, 3, 2))
if ir.VectorType.isinstance(wgmma_op.a.type):
a_operand = _fragmented_array_from_ir(wgmma_op.a, wgmma_layout)
else:
a_layout = ir.MemRefType(wgmma_op.a.type).layout
a_swizzle, a_transforms = memref_layout_to_swizzle_and_transforms(a_layout)
a_swizzle, a_transforms = swizzle_and_transforms_from_transforms_attr(
a_transforms
)
ref_ty = ir.MemRefType(wgmma_op.a.type)
_check_transforms_and_swizzle_are_supported(
ref_ty, a_transforms, a_swizzle, minimum_swizzle
)
if a_swizzle != b_swizzle:
raise ValueError(
f"Non-matching swizzles of operands a and b in WGMMA: {a_swizzle} !="
f" {b_swizzle}"
)
a_operand = transform_memref(wgmma_op.a, a_transforms)
if wgmma_op.transpose_a:
a_operand = utils.memref_transpose(a_operand, (0, 1, 3, 2))
new_acc = wgmma.wgmma(acc, a_operand, b_operand, swizzle=b_swizzle)
@ -902,7 +1006,7 @@ def single_thread_predicates(module: ir.Module) -> tuple[ir.Value, ir.Value]:
def _should_lower(op: ir.OpView) -> bool:
"""Returns 'true' if the operation should be lowered."""
return (
op.OPERATION_NAME.startswith("mosaic_gpu.")
op.OPERATION_NAME.startswith("mosaic_gpu.") # pytype: disable=attribute-error
or inference_utils.should_have_layout(op)
or any(bool(b) for r in op.regions for b in r) # Does it have subblocks?
)

View File

@ -387,7 +387,21 @@ class WGMMARowFragLayout:
"""[m] matrix, where m % 64 == 0."""
def thread_idxs(self, shape):
raise NotImplementedError
index = ir.IndexType.get()
assert len(shape) == 1
assert shape[0] % 64 == 0
tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx())
tid_wg = arith.remui(tid, c(WARPGROUP_SIZE, index))
warp_idx = arith.divui(tid_wg, c(32, index))
lane_id = arith.remui(tid_wg, c(32, index))
row_base = arith.addi(
arith.divui(lane_id, c(4, index)), arith.muli(warp_idx, c(16, index))
)
for row_group in range(0, shape[0], 64):
for row_subgroup in (0, 8):
row = arith.addi(row_base, c(row_group + row_subgroup, index))
yield (row,)
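The generator above enumerates, per thread, the rows it owns under the WGMMA row layout: 128 threads per warpgroup, 32 per warp, and every group of 4 consecutive lanes holding the same row value. A pure-Python rendering of the same arithmetic:
def wgmma_row_indices(tid_in_warpgroup, m):
  assert m % 64 == 0
  warp_idx, lane_id = divmod(tid_in_warpgroup % 128, 32)
  row_base = lane_id // 4 + warp_idx * 16
  for row_group in range(0, m, 64):
    for row_subgroup in (0, 8):
      yield row_base + row_group + row_subgroup

print(list(wgmma_row_indices(0, 128)))    # [0, 8, 64, 72]
print(list(wgmma_row_indices(37, 128)))   # thread 37: warp 1, lane 5 -> [17, 25, 81, 89]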
@dataclasses.dataclass(frozen=True)
@ -660,6 +674,31 @@ class FragmentedArray:
vecs = [vector.load(vec_ty, ref, vec_idx) for vec_idx in layout.thread_idxs(shape)]
return cls(_registers=np.array(vecs), _layout=layout, _is_signed=is_signed)
@classmethod
def load_wgmma_row(
cls,
ref: ir.Value,
*,
is_signed: bool | None = None,
):
if not ir.MemRefType.isinstance(ref.type):
raise TypeError(ref.type)
ref_ty = ir.MemRefType(ref.type)
shape = tuple(ref_ty.shape)
if len(shape) != 1:
raise ValueError("WGMMARowFragLayout requires a 1D shape")
if shape[0] % 64:
raise ValueError(
"WGMMARowFragLayout requires shape[0] to be a multiple of 64"
)
layout = WGMMARowFragLayout()
registers = [memref.load(ref, [idx]) for (idx,) in layout.thread_idxs(shape)]
registers = np.array(registers).reshape(-1, 2)
return cls(_registers=registers, _layout=layout, _is_signed=is_signed)
@classmethod
def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None):
layout = layout or WGSplatFragLayout(shape)
@ -1743,6 +1782,8 @@ class FragmentedArray:
)
match self.layout:
case WGMMARowFragLayout():
self._store_untiled_wgmma_row(ref)
case WGSplatFragLayout():
vs_unsupported()
self._store_untiled_splat(ref)
@ -1789,6 +1830,23 @@ class FragmentedArray:
for idx, reg in zip(idxs, self.registers.flat):
vector.store(reg, ref_, idx)
def _store_untiled_wgmma_row(self, ref: ir.Value):
"""Stores an array with a WGMMA row layout."""
assert self.layout == WGMMA_ROW_LAYOUT
index = ir.IndexType.get()
tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx())
is_first = arith.cmpi(
arith.CmpIPredicate.eq, arith.remui(tid, c(4, index)), c(0, index)
)
# Consecutive groups of 4 threads hold the same value in this layout,
# therefore we only need to transfer data from one of them.
with utils.when(is_first):
for (idx,), value in zip(
self.layout.thread_idxs(self.shape), self.registers.flatten()
):
memref.store(value, ref, [idx])
def _store_untiled_tiled(self, ref: ir.Value, *, vector_store: bool = True):
"""Stores an array with a tiled layout. Not optimized at the moment."""
if utils.bitwidth(self.mlir_dtype) < 8:

View File

@ -383,89 +383,6 @@ def _infer_wgmma_op_layout(wgmma_op: mgpu.WGMMAOp) -> OptionalLayouts:
return [layout], [layout]
@dataclasses.dataclass()
class WGMMATransforms:
swizzle: mgpu.SwizzlingMode
a_tile: tuple[int, ...]
a_transpose: bool
b_tile: tuple[int, ...]
b_transpose: bool
def infer_wgmma_transforms(wgmma_op: mgpu.WGMMAOp) -> WGMMATransforms:
a_shape = cast(ir.ShapedType, wgmma_op.a.type).shape
k = a_shape[0] if wgmma_op.transpose_a else a_shape[1]
bitwidth = cast(ir.ShapedType, wgmma_op.a.type).element_type.width
# Try tiling with all swizzling modes starting from the largest one.
for swizzle in [
mgpu.SwizzlingMode.k128ByteSwizzle,
mgpu.SwizzlingMode.k64ByteSwizzle,
mgpu.SwizzlingMode.k32ByteSwizzle,
]:
s = swizzle * 8 // bitwidth
if k % s == 0:
return WGMMATransforms(
swizzle=swizzle,
a_tile=(s, 64) if wgmma_op.transpose_a else (64, s),
a_transpose=wgmma_op.transpose_a,
b_tile=(s, s),
b_transpose=wgmma_op.transpose_b,
)
raise ValueError(
"Could not infer layouts for memref feeding into WGMMA. The "
"non-contracting dimension ({k}) must be a multiple of "
"s = swizzle * (8 / bitwidth) where swizzle is a valid swizzle "
f"(32, 64, or 128) and bitwidth ({bitwidth}) is the element size of "
"`a` and `b`."
)
def _layout_for_memref_view(view_op: memref.ViewOp) -> ir.Attribute | None:
wgmma_use = None
uses = cast(ir.OpResult, view_op.result).uses
for use in uses:
user = use.owner
if isinstance(user, memref.CastOp):
# This memref is already cast, so we don't need to do anything.
return None
if isinstance(user, mgpu.WGMMAOp):
if wgmma_use is not None:
raise NotImplementedError(f"Multiple WGMMA consumers of {view_op}.")
wgmma_use = use
break
if (
not isinstance(user, mgpu.AsyncLoadOp)
and not isinstance(user, mgpu.AsyncStoreOp)
and not isinstance(user, vector.LoadOp)
and not isinstance(user, vector.StoreOp)
):
raise NotImplementedError(f"Unsupported user {user} of {view_op}.")
if wgmma_use is None:
# This memref is not used by a WGMMA operation, so we don't need to do
# anything.
return None
transforms = infer_wgmma_transforms(wgmma_use.owner)
if wgmma_use.operand_number == 1:
tile = transforms.a_tile
transpose = transforms.a_transpose
else:
tile = transforms.b_tile
transpose = transforms.b_transpose
transpose_attr = (
[mgpu.TransposeTransformAttr.get([1, 0, 2, 3])] if transpose else []
)
layout = mgpu.LayoutAttr.get(
2,
[mgpu.TileTransformAttr.get(tile)]
+ transpose_attr
+ [mgpu.SwizzleTransformAttr.get(transforms.swizzle)],
)
return layout
def _earliest_use(regions: list[ir.Region], uses: Sequence[ir.OpOperand]) -> ir.OpView:
owners = [use.owner for use in uses]
@ -607,11 +524,3 @@ def infer_layout(module: ir.Module):
for op in module.body:
traverse_op(op, set_default_layout)
def infer_memref_layouts_and_insert_casts(op: ir.OpView):
if op.name == "memref.view":
if layout := _layout_for_memref_view(op):
_insert_memref_layout_cast(layout, op)
for op in module.body:
traverse_op(op, infer_memref_layouts_and_insert_casts)

View File

@ -26,6 +26,9 @@ from typing import cast
from jax._src.lib import mosaic_gpu_dialect as mgpu
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import builtin
from jax._src.lib.mlir.dialects import gpu
from jax._src.lib.mlir.dialects import memref
from jax._src.lib.mlir.dialects import vector
from . import fragmented_array as fa
@ -169,7 +172,6 @@ def _infer_vector_load_store_transforms(
return None
# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2.
SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None)
@ -196,6 +198,60 @@ def _infer_slice_smem_transforms(op: SliceSMEMOp) -> OptionalTransforms:
return None if transforms is None else ([], [transforms])
# TODO(bchetioui,apaszke): this empty rule is necessary while Mosaic doesn't use
# the dialect in all cases.
# The rule is necessary in order to handle the lowering of `utils.memref_ptr`
# which is used in `_construct_smem_reftree`.
@partial(_add_transform_inference_rule, builtin.UnrealizedConversionCastOp)
def _infer_unrealized_conversion_cast_transforms(
_: builtin.UnrealizedConversionCastOp,
) -> OptionalTransforms:
return None
@partial(_add_transform_inference_rule, memref.ViewOp)
def _infer_memref_view_transforms(op: memref.ViewOp) -> OptionalTransforms:
if not isinstance(op.source.owner.opview, gpu.DynamicSharedMemoryOp):
raise NotImplementedError(
"Memref view transforms are only inferred when the op is a direct user "
f"of a DynamicSharedMemoryOp but got {op}."
)
transforms = inference_utils.value_transforms(op.source)
if transforms is not None:
raise NotImplementedError(
"memref view with in_transforms aren't yet supported"
)
uses = cast(ir.OpResult, op.result).uses
for op_operand_use in uses:
consumer = op_operand_use.owner
op_user = consumer.operands[op_operand_use.operand_number]
out_transforms = inference_utils.in_transforms_for_operand(
consumer, op_user
)
if transforms is not None and out_transforms is not None:
if transforms != out_transforms:
raise ValueError(
f"Conflicting transforms for {op_user} in {op}: "
f"{transforms} != {out_transforms}."
)
elif out_transforms is not None:
transforms = out_transforms
# TODO(bchetioui): do we actually need to assign a transform to the input of
# the view op? Presumably, it'll only be used to access scratch memory.
return None if transforms is None else ([], [transforms])
# TODO(bchetioui,apaszke): this empty rule is necessary while Mosaic doesn't use
# the dialect in all cases.
@partial(_add_transform_inference_rule, gpu.DynamicSharedMemoryOp)
def _infer_dynamic_smem_transforms(
_: gpu.DynamicSharedMemoryOp,
) -> OptionalTransforms:
return None
def _should_have_transforms(op: ir.OpView) -> bool:
"""Returns 'True' if the operation should be assigned in/out transforms."""
return any(
@ -218,7 +274,6 @@ def infer_transforms(module: ir.Module):
specified. We error out if two distinct sets of transforms are competing to
annotate the same memref.
"""
def inference_step(op: ir.Operation):
if not _should_have_transforms(op):
return

View File

@ -24,7 +24,7 @@ import jax._src.xla_bridge as xb
# cuda_plugin_extension locates inside jaxlib. `jaxlib` is for testing without
# preinstalled jax cuda plugin packages.
for pkg_name in ['jax_cuda12_plugin', 'jaxlib']:
for pkg_name in ['jax_cuda12_plugin', 'jaxlib.cuda']:
try:
cuda_plugin_extension = importlib.import_module(
f'{pkg_name}.cuda_plugin_extension'

View File

@ -23,7 +23,7 @@ import jax._src.xla_bridge as xb
# rocm_plugin_extension locates inside jaxlib. `jaxlib` is for testing without
# preinstalled jax rocm plugin packages.
for pkg_name in ['jax_rocm60_plugin', 'jaxlib']:
for pkg_name in ['jax_rocm60_plugin', 'jaxlib.cuda']:
try:
rocm_plugin_extension = importlib.import_module(
f'{pkg_name}.rocm_plugin_extension'

View File

@ -222,84 +222,3 @@ nanobind_extension(
"@xla//third_party/python_runtime:headers",
],
)
cc_library(
name = "gpu_plugin_extension",
srcs = ["gpu_plugin_extension.cc"],
hdrs = ["gpu_plugin_extension.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
deps = [
":kernel_nanobind_helpers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@nanobind",
"@xla//xla:util",
"@xla//xla/ffi/api:c_api",
"@xla//xla/pjrt:status_casters",
"@xla//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_helpers",
"@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs",
"@xla//xla/python:py_client_gpu",
"@xla//xla/tsl/python/lib/core:numpy",
],
)
nanobind_extension(
name = "cuda_plugin_extension",
srcs = ["cuda_plugin_extension.cc"],
module_name = "cuda_plugin_extension",
deps = [
":gpu_plugin_extension",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@local_config_cuda//cuda:cuda_headers",
"@nanobind",
"@xla//xla/pjrt:status_casters",
"@xla//xla/tsl/cuda:cublas",
"@xla//xla/tsl/cuda:cudart",
],
)
nanobind_extension(
name = "rocm_plugin_extension",
srcs = ["rocm_plugin_extension.cc"],
module_name = "rocm_plugin_extension",
deps = [
":gpu_plugin_extension",
"@com_google_absl//absl/log",
"@com_google_absl//absl/strings",
"@local_config_rocm//rocm:hip",
"@local_config_rocm//rocm:rocm_headers",
"@nanobind",
],
)
# CPU kernels
# TODO(phawkins): Remove this forwarding target.
cc_library(
name = "cpu_kernels",
visibility = ["//visibility:public"],
deps = [
"//jaxlib/cpu:cpu_kernels",
],
alwayslink = 1,
)
# TODO(phawkins): Remove this forwarding target.
cc_library(
name = "gpu_kernels",
visibility = ["//visibility:public"],
deps = [
"//jaxlib/cuda:cuda_gpu_kernels",
],
alwayslink = 1,
)

View File

@ -657,6 +657,22 @@ py_library(
],
)
nanobind_extension(
name = "cuda_plugin_extension",
srcs = ["cuda_plugin_extension.cc"],
module_name = "cuda_plugin_extension",
deps = [
"//jaxlib/gpu:gpu_plugin_extension",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@local_config_cuda//cuda:cuda_headers",
"@nanobind",
"@xla//xla/pjrt:status_casters",
"@xla//xla/tsl/cuda:cublas",
"@xla//xla/tsl/cuda:cudart",
],
)
# We cannot nest select and if_cuda_is_configured so we introduce
# a standalone py_library target.
py_library(
@ -664,6 +680,6 @@ py_library(
# `if_cuda_is_configured` will default to `[]`.
deps = if_cuda_is_configured([
":cuda_gpu_support",
"//jaxlib:cuda_plugin_extension",
":cuda_plugin_extension",
]),
)

View File

@ -20,7 +20,7 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "jaxlib/gpu_plugin_extension.h"
#include "jaxlib/gpu/gpu_plugin_extension.h"
#include "xla/pjrt/status_casters.h"
namespace nb = nanobind;

View File

@ -90,3 +90,32 @@ xla_py_proto_library(
visibility = jax_visibility("triton_proto_py_users"),
deps = [":triton_proto"],
)
cc_library(
name = "gpu_plugin_extension",
srcs = ["gpu_plugin_extension.cc"],
hdrs = ["gpu_plugin_extension.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
deps = [
"//jaxlib:kernel_nanobind_helpers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@nanobind",
"@xla//xla:util",
"@xla//xla/ffi/api:c_api",
"@xla//xla/pjrt:status_casters",
"@xla//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_helpers",
"@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs",
"@xla//xla/python:py_client_gpu",
"@xla//xla/tsl/python/lib/core:numpy",
],
)

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/gpu_plugin_extension.h"
#include "jaxlib/gpu/gpu_plugin_extension.h"
#include <cstddef>
#include <cstdint>

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_GPU_PLUGIN_EXTENSION_H_
#define JAXLIB_GPU_PLUGIN_EXTENSION_H_
#ifndef JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_
#define JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_
#include "nanobind/nanobind.h"
@ -24,4 +24,4 @@ void BuildGpuPluginExtension(nanobind::module_& m);
} // namespace xla
#endif // JAXLIB_GPU_PLUGIN_EXTENSION_H_
#endif // JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_

View File

@ -143,7 +143,7 @@ mlir::FailureOr<mlir::OpPassManager> GetPassPipeline(
mlir::memref::registerMemRefPasses();
mlir::registerConvertToLLVMPass();
mlir::registerGPUPasses();
mlir::registerGpuLaunchSinkIndexComputations();
mlir::registerGpuLaunchSinkIndexComputationsPass();
mosaic::gpu::registerGpuLaunchLoweringPass();
mosaic::gpu::registerConvertGpuToLLVMPass();
mosaic::gpu::registerByvalInsertionPass();

View File

@ -555,11 +555,25 @@ py_library(
],
)
nanobind_extension(
name = "rocm_plugin_extension",
srcs = ["rocm_plugin_extension.cc"],
module_name = "rocm_plugin_extension",
deps = [
"//jaxlib/gpu:gpu_plugin_extension",
"@com_google_absl//absl/log",
"@com_google_absl//absl/strings",
"@local_config_rocm//rocm:hip",
"@local_config_rocm//rocm:rocm_headers",
"@nanobind",
],
)
py_library(
name = "gpu_only_test_deps",
# `if_rocm_is_configured` will default to `[]`.
deps = if_rocm_is_configured([
":rocm_gpu_support",
"//jaxlib:rocm_plugin_extension",
":rocm_plugin_extension",
]),
)

View File

@ -20,7 +20,7 @@ limitations under the License.
#include "absl/log/log.h"
#include "absl/strings/str_cat.h"
#include "rocm/include/hip/hip_runtime.h"
#include "jaxlib/gpu_plugin_extension.h"
#include "jaxlib/gpu/gpu_plugin_extension.h"
namespace nb = nanobind;

View File

@ -143,16 +143,16 @@ py_binary(
data = [
"LICENSE.txt",
] + if_cuda([
"//jaxlib/mosaic/gpu:mosaic_gpu",
"//jaxlib:cuda_plugin_extension",
"//jaxlib:version",
"//jaxlib/mosaic/gpu:mosaic_gpu",
"//jaxlib/cuda:cuda_plugin_extension",
"//jaxlib/cuda:cuda_gpu_support",
"//jax_plugins/cuda:plugin_pyproject.toml",
"//jax_plugins/cuda:plugin_setup.py",
"@local_config_cuda//cuda:cuda-nvvm",
]) + if_rocm([
"//jaxlib:rocm_plugin_extension",
"//jaxlib:version",
"//jaxlib/rocm:rocm_plugin_extension",
"//jaxlib/rocm:rocm_gpu_support",
"//jax_plugins/rocm:plugin_pyproject.toml",
"//jax_plugins/rocm:plugin_setup.py",

View File

@ -110,7 +110,7 @@ def prepare_wheel_cuda(
f"__main__/jaxlib/cuda/_triton.{pyext}",
f"__main__/jaxlib/cuda/_hybrid.{pyext}",
f"__main__/jaxlib/cuda/_versions.{pyext}",
f"__main__/jaxlib/cuda_plugin_extension.{pyext}",
f"__main__/jaxlib/cuda/cuda_plugin_extension.{pyext}",
f"__main__/jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}",
"__main__/jaxlib/mosaic/gpu/libmosaic_gpu_runtime.so",
"__main__/jaxlib/version.py",
@ -148,7 +148,7 @@ def prepare_wheel_rocm(
f"__main__/jaxlib/rocm/_hybrid.{pyext}",
f"__main__/jaxlib/rocm/_rnn.{pyext}",
f"__main__/jaxlib/rocm/_triton.{pyext}",
f"__main__/jaxlib/rocm_plugin_extension.{pyext}",
f"__main__/jaxlib/rocm/rocm_plugin_extension.{pyext}",
"__main__/jaxlib/version.py",
],
)

View File

@ -29,16 +29,23 @@ def call_kernel(
kernel,
grid: tuple[int, int],
transpose_grid: bool,
*args
key: jax.Array,
total_size: tuple[int, int],
block_size: tuple[int, int],
tile_size: tuple[int, int],
):
"""Calls a kernel over a grid and concatenates results to a single array."""
if transpose_grid:
grid = (grid[1], grid[0])
m, n = grid
return jnp.concatenate([
jnp.concatenate([
kernel((i, j), *args) for j in range(n)], axis=1)
for i in range(m)], axis=0)
samples = jnp.concatenate([
jnp.concatenate([
kernel((i, j), key, total_size, block_size, tile_size)
for j in range(n)], axis=1)
for i in range(m)], axis=0)
# Slice out the padding.
samples = samples[:total_size[0], :total_size[1]]
return samples
def call_kernel_3d(
@ -73,10 +80,10 @@ def uniform_kernel(block_index, key, total_size, block_size, tile_size):
block_size=block_size,
tile_size=tile_size)
return blocked_sampler.sample_block(jax.random.uniform,
keys,
block_size=block_size,
tile_size=tile_size,
minval=0.0, maxval=1.0)
keys,
block_size=block_size,
tile_size=tile_size,
minval=0.0, maxval=1.0)
class BlockedSamplerTest(jtu.JaxTestCase):
@ -94,16 +101,25 @@ class BlockedSamplerTest(jtu.JaxTestCase):
dict(testcase_name='16x256_vs_32x128', total_size=(32, 256),
block_size_a=(16, 256), block_size_b=(32, 128),
tile_size=(8, 128), transpose_grid=False),
dict(testcase_name='128x128_vs_128x256_padding',
total_size=(256, 128), block_size_a=(128, 128),
block_size_b=(128, 256), tile_size=(128, 128), transpose_grid=False),
dict(testcase_name='128x128_vs_128x256_padding2',
total_size=(257, 129), block_size_a=(128, 128),
block_size_b=(128, 256), tile_size=(128, 128), transpose_grid=False),
)
def test_block_shape_invariance(self, total_size, block_size_a,
block_size_b, tile_size, transpose_grid):
global_key = jax.random.key(0)
grid_a = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_a))
ceil_div = lambda x, y: (x + y - 1) // y
grid_a = tuple(ceil_div(_tot, _blk)
for _tot, _blk in zip(total_size, block_size_a))
result_a = call_kernel(
uniform_kernel, grid_a, transpose_grid, global_key,
total_size, block_size_a, tile_size)
grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b))
grid_b = tuple(ceil_div(_tot, _blk)
for _tot, _blk in zip(total_size, block_size_b))
result_b = call_kernel(
uniform_kernel, grid_b, transpose_grid, global_key,
total_size, block_size_b, tile_size)
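The updated call_kernel rounds the grid up with a ceiling division and slices the concatenated result back down to total_size, which is what makes the new padding test cases (e.g. total_size=(257, 129) with 128x128 blocks) well defined. A quick NumPy check of that arithmetic, independent of the blocked sampler itself:

import numpy as np

ceil_div = lambda x, y: (x + y - 1) // y

total_size, block_size = (257, 129), (128, 128)
grid = tuple(ceil_div(t, b) for t, b in zip(total_size, block_size))
assert grid == (3, 2)  # padded output is 3*128 x 2*128 = 384 x 256

padded = np.zeros((grid[0] * block_size[0], grid[1] * block_size[1]))
sliced = padded[:total_size[0], :total_size[1]]  # "Slice out the padding."
assert sliced.shape == (257, 129)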

View File

@ -15,9 +15,9 @@
import contextlib
import io
import logging
import os
import platform
import re
import shlex
import subprocess
import sys
import tempfile
@ -78,6 +78,31 @@ def capture_jax_logs():
logger.removeHandler(handler)
# Saves the script to a file and runs it from there, to work around
# `tsl::Env::Default()->GetExecutablePath()` not working properly when the
# program is passed via the -c command flag.
def _run(program, env_var = {}):
# strip the leading whitespace from the program script
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
with tempfile.NamedTemporaryFile(
mode="w+", encoding="utf-8", suffix=".py", dir=os.getcwd()
) as f:
f.write(textwrap.dedent(program))
f.flush()
python = sys.executable
assert "python" in python
if env_var:
env_var.update(os.environ)
else:
env_var = os.environ
# Make sure C++ logging is at default level for the test process.
p = subprocess.run([python, f.name], env=env_var, capture_output=True, text=True)
return type("", (object,), { "stdout": p.stdout, "stderr": p.stderr })
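The comment above explains the motivation for _run: the program is written to a real file and executed from there instead of being passed with `python -c`, so that tsl::Env::Default()->GetExecutablePath() resolves correctly in the child process. A stripped-down sketch of just that workaround, using only the standard library (the real helper additionally dedents the program and merges environment variables):

# Minimal sketch of the run-from-a-file workaround; not the test's _run helper.
import os
import subprocess
import sys
import tempfile

program = "print('running from a real file, not -c')\n"
with tempfile.NamedTemporaryFile(
    mode="w+", encoding="utf-8", suffix=".py", dir=os.getcwd()
) as f:
  f.write(program)
  f.flush()
  proc = subprocess.run(
      [sys.executable, f.name], capture_output=True, text=True)
print(proc.stdout)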
class LoggingTest(jtu.JaxTestCase):
@unittest.skipIf(platform.system() == "Windows",
@ -90,36 +115,25 @@ class LoggingTest(jtu.JaxTestCase):
if sys.executable is None:
raise self.skipTest("test requires access to python binary")
# Save script in file to fix the problem with
# `tsl::Env::Default()->GetExecutablePath()` not working properly with
# command flag.
with tempfile.NamedTemporaryFile(
mode="w+", encoding="utf-8", suffix=".py"
) as f:
f.write(textwrap.dedent("""
o = _run("""
import jax
jax.device_count()
f = jax.jit(lambda x: x + 1)
f(1)
f(2)
jax.numpy.add(1, 1)
"""))
python = sys.executable
assert "python" in python
# Make sure C++ logging is at default level for the test process.
proc = subprocess.run([python, f.name], capture_output=True)
""")
lines = proc.stdout.split(b"\n")
lines.extend(proc.stderr.split(b"\n"))
allowlist = [
b"",
(
b"An NVIDIA GPU may be present on this machine, but a"
b" CUDA-enabled jaxlib is not installed. Falling back to cpu."
),
]
lines = [l for l in lines if l not in allowlist]
self.assertEmpty(lines)
lines = o.stdout.split("\n")
lines.extend(o.stderr.split("\n"))
allowlist = [
(
"An NVIDIA GPU may be present on this machine, but a"
" CUDA-enabled jaxlib is not installed. Falling back to cpu."
),
]
lines = [l for l in lines if l in allowlist]
self.assertEmpty(lines)
def test_debug_logging(self):
# Warmup so we don't get "No GPU/TPU" warning later.
@ -164,19 +178,12 @@ class LoggingTest(jtu.JaxTestCase):
if sys.executable is None:
raise self.skipTest("test requires access to python binary")
program = """
import jax # this prints INFO logging from backend imports
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
"""
o = _run("""
import jax # this prints INFO logging from backend imports
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
""", { "JAX_LOGGING_LEVEL": "INFO" })
# strip the leading whitespace from the program script
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
# test INFO
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=INFO {sys.executable} -c"
f" '{program}'")
p = subprocess.run(cmd, capture_output=True, text=True)
log_output = p.stderr
log_output = o.stderr
info_lines = log_output.split("\n")
self.assertGreater(len(info_lines), 0)
self.assertIn("INFO", log_output)
@ -194,22 +201,14 @@ class LoggingTest(jtu.JaxTestCase):
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
"""
# strip the leading whitespace from the program script
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
o = _run(program, { "JAX_LOGGING_LEVEL": "DEBUG" })
# test DEBUG
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c"
f" '{program}'")
p = subprocess.run(cmd, capture_output=True, text=True)
log_output = p.stderr
log_output = o.stderr
self.assertIn("INFO", log_output)
self.assertIn("DEBUG", log_output)
# test JAX_DEBUG_MODULES
cmd = shlex.split(f"env JAX_DEBUG_LOG_MODULES=jax {sys.executable} -c"
f" '{program}'")
p = subprocess.run(cmd, capture_output=True, text=True)
log_output = p.stderr
o = _run(program, { "JAX_DEBUG_LOG_MODULES": "jax" })
log_output = o.stderr
self.assertIn("DEBUG", log_output)
@jtu.skip_on_devices("tpu")
@ -220,22 +219,15 @@ class LoggingTest(jtu.JaxTestCase):
raise self.skipTest("test requires access to python binary")
_separator = "---------------------------"
program = f"""
o = _run(f"""
import sys
import jax # this prints INFO logging from backend imports
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
jax.config.update("jax_logging_level", None)
sys.stderr.write("{_separator}")
jax.jit(lambda x: x)(1) # should not log anything now
"""
# strip the leading whitespace from the program script
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c"
f" '{program}'")
p = subprocess.run(cmd, capture_output=True, text=True)
log_output = p.stderr
""", {"JAX_LOGGING_LEVEL": "DEBUG"})
log_output = o.stderr
m = re.search(_separator, log_output)
self.assertTrue(m is not None)
log_output_verbose = log_output[:m.start()]
@ -252,19 +244,13 @@ class LoggingTest(jtu.JaxTestCase):
if sys.executable is None:
raise self.skipTest("test requires access to python binary")
program = """
o = _run("""
import jax # this prints INFO logging from backend imports
jax.config.update("jax_debug_log_modules", "jax._src.compiler,jax._src.dispatch")
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
"""
""", { "JAX_LOGGING_LEVEL": "DEBUG" })
# strip the leading whitespace from the program script
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c"
f" '{program}'")
p = subprocess.run(cmd, capture_output=True, text=True)
log_output = p.stderr
log_output = o.stderr
self.assertNotEmpty(log_output)
log_lines = log_output.strip().split("\n")
# only one tracing line should be printed, if there's more than one
@ -285,31 +271,19 @@ class LoggingTest(jtu.JaxTestCase):
jax.distributed.initialize("127.0.0.1:12345", num_processes=1, process_id=0)
"""
# strip the leading whitespace from the program script
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
o = _run(program, { "JAX_LOGGING_LEVEL": "DEBUG" })
self.assertIn("Initializing CoordinationService", o.stderr)
# verbose logging: DEBUG, VERBOSE
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c"
f" '{program}'")
p = subprocess.run(cmd, capture_output=True, text=True)
self.assertIn("Initializing CoordinationService", p.stderr)
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=INFO {sys.executable} -c"
f" '{program}'")
p = subprocess.run(cmd, capture_output=True, text=True)
self.assertIn("Initializing CoordinationService", p.stderr)
o = _run(program, { "JAX_LOGGING_LEVEL": "INFO" })
self.assertIn("Initializing CoordinationService", o.stderr)
# verbose logging: WARNING, None
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=WARNING {sys.executable} -c"
f" '{program}'")
p = subprocess.run(cmd, capture_output=True, text=True)
self.assertNotIn("Initializing CoordinationService", p.stderr)
o = _run(program, { "JAX_LOGGING_LEVEL": "WARNING" })
self.assertNotIn("Initializing CoordinationService", o.stderr)
cmd = shlex.split(f"{sys.executable} -c"
f" '{program}'")
p = subprocess.run(cmd, capture_output=True, text=True)
o = _run(program)
if int(_default_TF_CPP_MIN_LOG_LEVEL) >= 1:
self.assertNotIn("Initializing CoordinationService", p.stderr)
self.assertNotIn("Initializing CoordinationService", o.stderr)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -1664,31 +1664,30 @@ class StreamAnnotationTest(jtu.JaxTestCase):
def test_stream_annotation_inside_shmap(self):
if not jtu.test_device_matches(["gpu"]):
self.skipTest("Stream annotation is only supported on GPU.")
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
np_inp = np.ones((8, 8))
mesh = jtu.create_mesh((2,), ('x',))
s = NamedSharding(mesh, P('x'))
np_inp = np.ones((8,))
arr1 = jax.device_put(np_inp, s)
arr2 = jax.device_put(np_inp, s)
@compute_on('gpu_stream:1')
@jax.jit
def g(x, y):
return x @ y
return x * y
@compute_on('gpu_stream:2')
@jax.jit
def h(x, y):
return x @ y
return x * y
def f(x, y):
z = g(x, y)
w = h(3 * x, 2 * y)
return z + w
out = jax.jit(shard_map(f, mesh=mesh,
in_specs=(P('x', 'y'), P('x', 'y')),
out_specs=P('x', 'y')))(arr1, arr2)
self.assertArraysEqual(out, arr1 * 28)
out = jax.jit(shard_map(f, mesh=mesh, in_specs=(P('x'), P('x')),
out_specs=P('x')))(arr1, arr2)
self.assertArraysEqual(out, arr1 * 7)
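For reference, the new expected value follows from element-wise arithmetic on the all-ones inputs: g(x, y) = x*y contributes 1 per element, h(3*x, 2*y) = 6*x*y contributes 6, so f returns 7 per element. A quick NumPy check of that arithmetic, ignoring sharding and streams:

import numpy as np

x = y = np.ones((8,))
z = x * y              # g(x, y): all ones
w = (3 * x) * (2 * y)  # h(3*x, 2*y): all sixes
np.testing.assert_array_equal(z + w, x * 7)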
class ActivationOffloadingTest(jtu.JaxTestCase):

View File

@ -55,6 +55,7 @@ else:
from jax.experimental.mosaic.gpu import launch_context
from jax.experimental.mosaic.gpu import utils as utils
from jax.experimental.mosaic.gpu import profiler
from jax.experimental.mosaic.gpu import inference_utils
from jax.experimental.mosaic.gpu.utils import * # noqa: F403
from jax._src.lib.mlir.dialects import gpu
from jax._src.lib.mlir.dialects import llvm
@ -1945,6 +1946,21 @@ class FragmentedArrayTest(TestCase):
)(inp)
np.testing.assert_array_equal(inp, result)
@parameterized.product(in_shape=((128,), (64,)))
def test_wgmma_row_load_store_with_layout(self, in_shape):
def kernel(ctx, *args):
gmem_input, gmem_output, (smem_input, smem_output) = args
copy(gmem_input, smem_input)
t = mgpu.FragmentedArray.load_wgmma_row(smem_input)
t.store_untiled(smem_output)
copy(smem_output, gmem_output)
inp = out = self.prng.uniform(-1, 1, in_shape).astype(jnp.float32)
result = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (inp,), out, [inp, out],
)(inp)
np.testing.assert_array_equal(inp, result)
def test_warp_tree_reduce(self):
def kernel(ctx, out, *_):
del ctx
@ -2405,25 +2421,21 @@ class Swizzle:
return mgpu_dialect.SwizzleTransformAttr.get(self.swizzle)
def memref_with_transforms(
mem_ref: ir.Value,
transforms: Sequence[Tile | Transpose | Swizzle],
) -> ir.Value:
"""Casts the memref to one that has a layout with the given transforms."""
mem_ref_type = ir.MemRefType(mem_ref.type)
def set_in_transforms(
op: ir.OpView, transforms: Sequence[Sequence[Tile | Transpose | Swizzle]],
) -> None:
"""Annotates an op with in_transforms."""
if not transforms:
return
transform_attr = [t.attr() for t in transforms]
if not transform_attr:
return mem_ref
in_transforms = []
smem_refs = filter(inference_utils.is_transformable_smem_memref, op.operands) # pylint: disable=undefined-variable
for _, result_transforms in jax.util.safe_zip(smem_refs, transforms):
in_transforms.append(
ir.ArrayAttr.get([t.attr() for t in result_transforms])
)
layout = mgpu_dialect.LayoutAttr.get(mem_ref_type.rank, transform_attr)
memref_new_type = ir.MemRefType.get(
mem_ref_type.shape,
mem_ref_type.element_type,
layout,
mem_ref_type.memory_space,
)
return memref.cast(memref_new_type, mem_ref)
op.attributes["in_transforms"] = ir.ArrayAttr.get(in_transforms)
class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
@ -2556,7 +2568,6 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
):
del ctx
smem_ref, tma_barrier = smem
smem_ref = memref_with_transforms(smem_ref, test_case.transforms)
dialect_barrier = tma_barrier.as_dialect_barrier_memref()
elt_type = ir.MemRefType(in_gmem_ref.type).element_type
@ -2571,7 +2582,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
slice_indices = [arith.constant(i32, i) for i in test_case.slice_indices]
# GMEM -> SMEM
mgpu_dialect.async_load(
load_op = mgpu_dialect.AsyncLoadOp(
source=in_gmem_ref,
destination=smem_ref,
barrier=dialect_barrier,
@ -2579,6 +2590,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
slice_lengths=test_case.slice_lengths,
collective=ir.ArrayAttr.get([]),
)
set_in_transforms(load_op, [test_case.transforms])
parities = memref.load(tma_barrier.phases, [])
parity, _ = tma_barrier.update_parities(parities)
@ -2623,58 +2635,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
(x[input_slice]).reshape(test_case.shape_sliced),
)
@staticmethod
def pointwise_kernel_with_tma_cases(dtype: jnp.dtype):
@dataclasses.dataclass(frozen=True)
class TestCaseInput:
shape: tuple[int, ...]
transforms: tuple[Tile | Transpose | Swizzle, ...] = ()
result = []
for swizzle in mgpu_dialect.SwizzlingMode:
n = swizzle * 8 // jnp.finfo(dtype).bits
if swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle:
# We need at least one case with no transforms, as this is handled
# differently.
result.append(TestCaseInput(shape=[128, n]))
result.extend([
TestCaseInput(
shape=[128, n],
transforms=[Swizzle(swizzle)],
),
TestCaseInput(
shape=[2, 3, 64, n],
transforms=[Transpose([0, 1, 2, 3]), Swizzle(swizzle)],
),
TestCaseInput(
shape=[2, 3, 64, n],
transforms=[
Transpose([1, 0, 2, 3]),
Transpose([1, 0, 2, 3]),
Swizzle(swizzle),
],
),
TestCaseInput(
shape=[2, 3, 64, n],
transforms=[Transpose([1, 0, 2, 3]), Swizzle(swizzle)],
),
TestCaseInput(
shape=[128, n],
transforms=[Tile([64, n]), Swizzle(swizzle)],
),
TestCaseInput(
shape=[2 * 64, 3 * n],
transforms=[
Tile([64, n]),
Transpose([1, 0, 2, 3]),
Swizzle(swizzle),
],
),
])
return result
@parameterized.parameters(pointwise_kernel_with_tma_cases(jnp.bfloat16))
def test_pointwise_kernel_with_tma(self, test_case):
def test_pointwise_kernel_with_tma(self):
def add(
ctx: launch_context.LaunchContext,
a_gmem_ref: ir.Value,
@ -2701,9 +2662,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
# GMEM -> SMEM
mgpu_dialect.async_load(
source=a_gmem_ref,
destination=memref_with_transforms(
a_smem_ref, test_case.transforms
),
destination=a_smem_ref,
barrier=dialect_barrier,
indices=zero_slice_indices,
slice_lengths=shape,
@ -2711,9 +2670,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
)
mgpu_dialect.async_load(
source=b_gmem_ref,
destination=memref_with_transforms(
b_smem_ref, test_case.transforms
),
destination=b_smem_ref,
barrier=dialect_barrier,
indices=zero_slice_indices,
slice_lengths=shape,
@ -2740,9 +2697,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
# SMEM -> GMEM
mgpu_dialect.async_store(
source=memref_with_transforms(
result_smem_ref, test_case.transforms
),
source=result_smem_ref,
destination=result_gmem_ref,
indices=zero_slice_indices,
slice_lengths=shape,
@ -2752,114 +2707,76 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
dtype = jnp.bfloat16
jax_shape = jax.ShapeDtypeStruct(test_case.shape, dtype)
spec = jax.ShapeDtypeStruct((2, 3, 4, 64), dtype)
kernel = mgpu.as_gpu_kernel(
add,
grid=(1, 1, 1),
block=(128, 1, 1),
in_shape=(jax_shape, jax_shape),
out_shape=jax_shape,
in_shape=(spec, spec),
out_shape=spec,
smem_scratch_shape=[
jax_shape,
jax_shape,
jax_shape,
spec,
spec,
spec,
core.TMABarrier(1),
],
thread_semantics=mgpu.ThreadSemantics.Warpgroup,
)
x = self.prng.uniform(-1, 1, test_case.shape).astype(dtype)
y = self.prng.uniform(-1, 1, test_case.shape).astype(dtype)
x = self.prng.uniform(-1, 1, spec.shape).astype(dtype)
y = self.prng.uniform(-1, 1, spec.shape).astype(dtype)
self.assertArraysEqual(jax.jit(kernel)(x, y), x + y + y)
class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
@staticmethod
def wgmma_kernel_with_tma_cases(abtype: jnp.dtype):
@dataclasses.dataclass(frozen=True)
class TestCaseInput:
shape_a: tuple[int, ...] = ()
shape_b: tuple[int, ...] = ()
shape_res: tuple[int, ...] = ()
transforms_a: tuple[Tile | Transpose | Swizzle, ...] = ()
transforms_b: tuple[Tile | Transpose | Swizzle, ...] = ()
transpose_a: bool = False
transpose_b: bool = False
load_a_in_registers: bool = False
@parameterized.named_parameters(
(
f"swizzle={int(swizzle)}_{transpose_lhs=}_{transpose_rhs=}_{lhs_in_registers=}",
swizzle,
transpose_lhs,
transpose_rhs,
lhs_in_registers,
)
for swizzle in mgpu_dialect.SwizzlingMode
for transpose_lhs in [False, True]
for transpose_rhs in [False, True]
for lhs_in_registers in [False, True]
)
def test_wgmma_kernel_with_tma(
self, swizzle, transpose_lhs, transpose_rhs, load_a_in_registers
):
if swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle:
self.skipTest("No swizzle is not supported by wgmma")
result = []
for swizzle in [
# TODO(dasenov): Add a test for kNoSwizzle, i.e. all swizzling modes.
mgpu_dialect.SwizzlingMode.k32ByteSwizzle,
mgpu_dialect.SwizzlingMode.k64ByteSwizzle,
mgpu_dialect.SwizzlingMode.k128ByteSwizzle,
]:
k = swizzle // np.dtype(abtype).itemsize
groups_m = 4
groups_n = 1
groups_k = 1
result.extend([
TestCaseInput(
shape_a=[groups_m * 64, groups_k * k],
shape_b=[groups_k * k, groups_n * k],
shape_res=[groups_m * 64, groups_n * k],
),
TestCaseInput(
shape_a=[groups_m * 64, groups_k * k],
shape_b=[groups_n * k, groups_k * k],
shape_res=[groups_m * 64, groups_n * k],
transpose_b=True,
),
TestCaseInput(
shape_a=[groups_m * 64, groups_k * k],
shape_b=[groups_k * k, groups_n * k],
shape_res=[groups_m * 64, groups_n * k],
transforms_a=[Tile([64, k]), Swizzle(swizzle)],
transforms_b=[Tile([k, k]), Swizzle(swizzle)],
),
TestCaseInput(
shape_a=[groups_m * 64, groups_k * k],
shape_b=[groups_k * k, groups_n * k],
shape_res=[groups_m * 64, groups_n * k],
transforms_a=[Tile([64, k]), Swizzle(swizzle)],
load_a_in_registers=True,
),
])
# The below only works for 128-byte swizzling. Regardless of transposing,
# TMA needs the size of the last dimension to be compatible with the
# swizzle.
if swizzle == mgpu_dialect.SwizzlingMode.k128ByteSwizzle:
result.append(
TestCaseInput(
shape_a=[groups_k * k, groups_m * 64],
shape_b=[groups_k * k, groups_n * k],
shape_res=[groups_m * 64, groups_n * k],
transpose_a=True,
)
)
return result
if transpose_lhs or transpose_rhs:
self.skipTest("Transposes are not supported by transform inference yet.")
@parameterized.parameters(wgmma_kernel_with_tma_cases(jnp.bfloat16))
def test_wgmma_kernel_with_tma(self, test_case):
swizzle_elems = swizzle // np.dtype(jnp.bfloat16).itemsize
tiling_m, tiling_n, tiling_k = 64, swizzle_elems, swizzle_elems
groups_m, groups_n, groups_k = 4, 1, 1
m, n, k = groups_m * tiling_m, groups_n * tiling_n, groups_k * tiling_k
lhs_shape = (k, m) if transpose_lhs else (m, k)
rhs_shape = (n, k) if transpose_rhs else (k, n)
out_shape = (m, n)
def matmul(
ctx: launch_context.LaunchContext,
a_gmem_ref: ir.Value,
b_gmem_ref: ir.Value,
lhs_gmem_ref: ir.Value,
rhs_gmem_ref: ir.Value,
result_gmem_ref: ir.Value,
smem: list[ir.Value],
):
del ctx
a_smem_ref, b_smem_ref, result_smem_ref, tma_barrier = smem
a_smem_ref = memref_with_transforms(a_smem_ref, test_case.transforms_a)
b_smem_ref = memref_with_transforms(b_smem_ref, test_case.transforms_b)
lhs_smem_ref, rhs_smem_ref, result_smem_ref, tma_barrier = smem
dialect_barrier = tma_barrier.as_dialect_barrier_memref()
ab_elt_type = ir.MemRefType(a_gmem_ref.type).element_type
bytes_a = utils.bytewidth(ab_elt_type) * math.prod(test_case.shape_a)
bytes_b = utils.bytewidth(ab_elt_type) * math.prod(test_case.shape_b)
operand_elt_type = ir.MemRefType(lhs_gmem_ref.type).element_type
bytes_a = utils.bytewidth(operand_elt_type) * math.prod(lhs_shape)
bytes_b = utils.bytewidth(operand_elt_type) * math.prod(rhs_shape)
mgpu_dialect.arrive_expect_tx(
barrier=dialect_barrier,
@ -2869,19 +2786,19 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0)
# GMEM -> SMEM
mgpu_dialect.async_load(
source=a_gmem_ref,
destination=a_smem_ref,
source=lhs_gmem_ref,
destination=lhs_smem_ref,
barrier=dialect_barrier,
indices=[zero_i32] * len(test_case.shape_a),
slice_lengths=test_case.shape_a,
indices=[zero_i32] * len(lhs_shape),
slice_lengths=lhs_shape,
collective=ir.ArrayAttr.get([]),
)
mgpu_dialect.async_load(
source=b_gmem_ref,
destination=b_smem_ref,
source=rhs_gmem_ref,
destination=rhs_smem_ref,
barrier=dialect_barrier,
indices=[zero_i32] * len(test_case.shape_b),
slice_lengths=test_case.shape_b,
indices=[zero_i32] * len(rhs_shape),
slice_lengths=rhs_shape,
collective=ir.ArrayAttr.get([]),
)
@ -2889,29 +2806,34 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
parity, _ = tma_barrier.update_parities(parities)
mgpu_dialect.wait(dialect_barrier, parity)
# SMEM -> Registers
a_operand = a_smem_ref
zero_index = arith.constant(ir.IndexType.get(), 0)
if test_case.load_a_in_registers:
a_vector_type = ir.VectorType.get(test_case.shape_a, ab_elt_type)
zero_vector_indices = [zero_index] * len(test_case.shape_a)
a_operand = vector.load(a_vector_type, a_smem_ref, zero_vector_indices)
# Computation
shape_result = ir.MemRefType(result_gmem_ref.type).shape
result_elt_type = ir.MemRefType(result_gmem_ref.type).element_type
acc_elt_type = ir.F32Type.get()
acc_type = ir.VectorType.get(shape_result, acc_elt_type)
zero_acc = arith.constant(
result_elt_type, ir.FloatAttr.get(result_elt_type, 0.0)
)
accumulator = vector.splat(
ir.VectorType.get(shape_result, result_elt_type), zero_acc
result_elt_type, ir.FloatAttr.get(acc_elt_type, 0.0)
)
accumulator = vector.splat(acc_type, zero_acc)
if transpose_lhs:
lhs_smem_ref = utils.memref_transpose(lhs_smem_ref, (1, 0))
if transpose_rhs:
rhs_smem_ref = utils.memref_transpose(rhs_smem_ref, (1, 0))
zero_index = arith.constant(ir.IndexType.get(), 0)
if load_a_in_registers:
# SMEM -> Registers
lhs_ty = ir.VectorType.get(lhs_shape, operand_elt_type)
zero_vector_indices = [zero_index] * len(lhs_shape)
lhs_operand = vector.load(lhs_ty, lhs_smem_ref, zero_vector_indices)
else:
lhs_operand = lhs_smem_ref
result = mgpu_dialect.wgmma(
accumulator,
a_operand,
b_smem_ref,
transpose_a=test_case.transpose_a,
transpose_b=test_case.transpose_b,
lhs_operand,
rhs_smem_ref,
)
nvvm.wgmma_commit_group_sync_aligned()
@ -2929,38 +2851,41 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
)
nvvm.cp_async_bulk_wait_group(0)
abtype = jnp.bfloat16
operand_type = jnp.bfloat16
acctype = jnp.float32
a_jax_shape = jax.ShapeDtypeStruct(test_case.shape_a, abtype)
b_jax_shape = jax.ShapeDtypeStruct(test_case.shape_b, abtype)
result_jax_shape = jax.ShapeDtypeStruct(test_case.shape_res, acctype)
lhs_jax_shape = jax.ShapeDtypeStruct(lhs_shape, operand_type)
rhs_jax_shape = jax.ShapeDtypeStruct(rhs_shape, operand_type)
result_jax_shape = jax.ShapeDtypeStruct(out_shape, acctype)
kernel = mgpu.as_gpu_kernel(
matmul,
grid=(1, 1, 1),
block=(128, 1, 1),
in_shape=(a_jax_shape, b_jax_shape),
in_shape=(lhs_jax_shape, rhs_jax_shape),
out_shape=result_jax_shape,
smem_scratch_shape=[
a_jax_shape,
b_jax_shape,
lhs_jax_shape,
rhs_jax_shape,
result_jax_shape,
core.TMABarrier(1),
],
thread_semantics=mgpu.ThreadSemantics.Warpgroup,
)
x = self.prng.uniform(-1, 1, test_case.shape_a).astype(abtype)
y = self.prng.uniform(-1, 1, test_case.shape_b).astype(abtype)
prng_key = jax.random.key(1234)
k0, k1 = jax.random.split(prng_key, 2)
x = jax.random.randint(k0, lhs_shape, 0, 2).astype(operand_type)
y = jax.random.randint(k1, rhs_shape, 0, 2).astype(operand_type)
transpose = lambda x, t: x.T if t else x
self.assertArraysAllClose(
jax.jit(kernel)(x, y),
np.matmul(
transpose(x.reshape(test_case.shape_a), test_case.transpose_a),
transpose(y.reshape(test_case.shape_b), test_case.transpose_b),
transpose(x, transpose_lhs),
transpose(y, transpose_rhs)
),
atol=1e-5,
rtol=1e-5,
atol=0,
rtol=0,
)
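The shapes in the rewritten wgmma test are all derived from the swizzle width: bf16 elements are 2 bytes, so swizzle_elems = swizzle // 2, the tilings are (64, swizzle_elems, swizzle_elems), and with groups (4, 1, 1) the problem size is m, n, k = 256, swizzle_elems, swizzle_elems. A quick check of that arithmetic for the 128-byte swizzle case, assuming (as the test does) that the SwizzlingMode value equals the swizzle width in bytes:

import numpy as np
import jax.numpy as jnp

swizzle = 128  # assumed value of mgpu_dialect.SwizzlingMode.k128ByteSwizzle
swizzle_elems = swizzle // np.dtype(jnp.bfloat16).itemsize  # 128 // 2 == 64
tiling_m, tiling_n, tiling_k = 64, swizzle_elems, swizzle_elems
groups_m, groups_n, groups_k = 4, 1, 1
m, n, k = groups_m * tiling_m, groups_n * tiling_n, groups_k * tiling_k
assert (m, n, k) == (256, 64, 64)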

View File

@ -2501,7 +2501,8 @@ class SymbolicPallasTest(PallasBaseTest):
)
assert exported_module is not None
self.assertIn(
"tensor<?x?xf32>, %arg6: tensor<?x?xf32>, %arg7: tensor<?x?xf32>",
"%arg0: tensor<?x?xf32> loc(unknown), %arg1: tensor<?x?xf32>"
" loc(unknown), %arg2: tensor<?x?xf32>",
str(exported_module),
)
x = jax.ShapeDtypeStruct((128, 1024), jax.numpy.float32)
@ -2512,7 +2513,7 @@ class SymbolicPallasTest(PallasBaseTest):
)
assert exported_module is not None
self.assertIn(
"@sym_matmul(%arg0: tensor<128x1024xf32>, %arg1: tensor<1024x512xf32>",
"call @sym_matmul(%arg0, %arg1)",
str(exported_module),
)

View File

@ -156,5 +156,36 @@ class InterpretTest(jtu.JaxTestCase):
lowered = jax.jit(matmul).lower(x, y).as_text(dialect="stablehlo")
self.assertNotIn("dot_general", lowered)
@parameterized.parameters('nan', 'zero')
def test_uninitialized_memory(self, uninitialized_memory):
def kernel(o1_ref, o2_ref, o3_ref, t1_ref, t2_ref):
o1_ref[...] = t1_ref[...]
o2_ref[...] = t2_ref[...]
x, y, z = pl.pallas_call(
kernel,
out_shape=[
jax.ShapeDtypeStruct((8, 128), jnp.bfloat16),
jax.ShapeDtypeStruct((8, 128), jnp.int16),
jax.ShapeDtypeStruct((8, 128), jnp.float32),
],
in_specs=[],
scratch_shapes=[
pltpu.VMEM((8, 128), jnp.bfloat16),
pltpu.VMEM((8, 128), jnp.int16),
],
interpret=mosaic_interpret.TPUInterpretParams(
uninitialized_memory=uninitialized_memory),
)()
if uninitialized_memory == 'nan':
self.assertTrue(jnp.isnan(x).all())
np.testing.assert_equal(np.array(y), 32767)
self.assertTrue(jnp.isnan(z).all())
if uninitialized_memory == 'zero':
np.testing.assert_equal(np.array(x), 0)
np.testing.assert_equal(np.array(y), 0)
np.testing.assert_equal(np.array(z), 0)
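The assertions above suggest how 'nan' mode handles integer buffers: floats are filled with NaN, while int16, which has no NaN, appears to be filled with the int16 maximum (32767). That reading is inferred from the test expectations, not from documented TPUInterpretParams behaviour; a quick check of the constant:

import numpy as np

assert np.iinfo(np.int16).max == 32767  # the value asserted for y in 'nan' mode
assert np.isnan(np.float32('nan'))      # float buffers can carry a real NaN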
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -1680,6 +1680,47 @@ class CustomPartitionerTest(jtu.JaxTestCase):
jit_f = jax.jit(f, in_shardings=s, out_shardings=s)
self.assertArraysEqual(x, jit_f(x))
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_custom_partitioner_pytree_inputs(self):
self.skip_if_custom_partitioning_not_supported()
def partition(mesh, arg_shapes, result_shape):
def lower_fn(xs):
x, y, z = xs
return x + y + z
return (
mesh,
lower_fn,
arg_shapes[0][0].sharding,
jax.tree.map(lambda x: x.sharding, arg_shapes),
)
def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
return arg_shapes[0][0].sharding
def propagate_user_sharding(mesh, user_shape):
return user_shape.sharding
@custom_partitioning
def f(xs):
x, y, z = xs
return x + y + z
f.def_partition(
infer_sharding_from_operands=infer_sharding_from_operands,
partition=partition,
propagate_user_sharding=propagate_user_sharding,
sharding_rule='i j, i j, i j -> i j',
)
def f2(a):
return a + f((a, a, a))
pjit_f = pjit(f2, in_shardings=(P(None, 'x')), out_shardings=P('x'))
x = np.asarray(np.random.randint(0, 20, (32, 16)), dtype=np.float32)
self.assertArraysEqual(x * 4, pjit_f(x))
@jtu.pytest_mark_if_available('multiaccelerator')
class AutoShardingPjitTest(jtu.JaxTestCase):
@ -7096,16 +7137,12 @@ class ShardingInTypesTest(jtu.JaxTestCase):
def test_set_mesh(self):
mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisType.Explicit,))
prev_mesh = config.device_context.value
prev_abstract_mesh = config.abstract_mesh_context_manager.value
try:
jax.sharding.set_mesh(mesh)
prev_mesh = jax.sharding.set_mesh(mesh)
out = reshard(np.arange(8), P('x'))
self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
finally:
config.device_context.set_local(prev_mesh)
config.abstract_mesh_context_manager.set_local(prev_abstract_mesh)
jax.sharding.set_mesh(prev_mesh)
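The rewritten test relies on jax.sharding.set_mesh returning the previously active mesh, so a single try/finally restores whatever was set before. A minimal sketch of that set/restore pattern; the return-value behaviour and the make_mesh arguments are inferred from this diff and the test's mesh construction rather than from documentation.

import jax
from jax.sharding import AxisType

# Assumes set_mesh returns the mesh that was active before the call.
mesh = jax.make_mesh((len(jax.devices()),), ('x',),
                     axis_types=(AxisType.Explicit,))
prev_mesh = jax.sharding.set_mesh(mesh)
try:
  pass  # code that relies on `mesh` being the ambient mesh goes here
finally:
  jax.sharding.set_mesh(prev_mesh)  # restore the previous mesh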
@jtu.with_user_mesh((2,), ('x',))
def test_auto_axes_late_bind(self, mesh):

View File

@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update XLA_SHA256 with the result.
XLA_COMMIT = "3bb765472122548cc227b8bd2990f00bd533f438"
XLA_SHA256 = "72126aac7602153aee985ca20f73d11c39e3ba9cfb8027492951e787559d0497"
XLA_COMMIT = "df971129bd82e381954da0185b534220e21798a4"
XLA_SHA256 = "11e9a568320cf7e7d61819620fd369927527ecefb68d5d1154b1521456bbdb72"
def repo():
tf_http_archive(