From 9145d617b8ade86daec508ecd17b08f20f2144b4 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 12 Mar 2025 17:50:00 +0100 Subject: [PATCH 01/20] Added exit 1 if git patch is failed + other checks --- .../workflows/requirements_lock_3_13_ft.patch | 13 ++++++------- .github/workflows/tsan.yaml | 19 ++++++++++++++++--- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/.github/workflows/requirements_lock_3_13_ft.patch b/.github/workflows/requirements_lock_3_13_ft.patch index bf4531182..0b63cb5b8 100644 --- a/.github/workflows/requirements_lock_3_13_ft.patch +++ b/.github/workflows/requirements_lock_3_13_ft.patch @@ -1,8 +1,8 @@ diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt -index dfefaf042..2700e140e 100644 +index e7a2968e9..d37e11ee3 100644 --- a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt -@@ -4,6 +4,12 @@ +@@ -4,6 +4,11 @@ # # pip-compile --allow-unsafe --generate-hashes --output-file=build/requirements_lock_3_13_ft.txt build/requirements.in # @@ -10,12 +10,11 @@ index dfefaf042..2700e140e 100644 +--pre +--extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple +numpy -+ + absl-py==2.1.0 \ --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff -@@ -328,68 +334,6 @@ mpmath==1.3.0 \ +@@ -328,68 +333,6 @@ mpmath==1.3.0 \ --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ --hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c # via -r build/test-requirements.txt @@ -81,6 +80,6 @@ index dfefaf042..2700e140e 100644 - # matplotlib - # ml-dtypes - # scipy - opt-einsum==3.4.0 \ - --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ - --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac + nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ + --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ + --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ diff --git a/.github/workflows/tsan.yaml b/.github/workflows/tsan.yaml index 2940d3dd2..7d93707e4 100644 --- a/.github/workflows/tsan.yaml +++ b/.github/workflows/tsan.yaml @@ -173,12 +173,18 @@ jobs: --bazel_options=--copt=-g \ --clang_path=/usr/bin/clang-18 - # Update the patch to use TSAN instrumented numpy + # Patch build/requirements_lock_3_13_ft.txt to use TSAN instrumented NumPy sed -i "s|+--extra-index-url.*|+--extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/|" .github/workflows/requirements_lock_3_13_ft.patch cat .github/workflows/requirements_lock_3_13_ft.patch + git apply .github/workflows/requirements_lock_3_13_ft.patch || exit 1 - # Apply a patch to numpy in requirements lock 3.13 ft to use the nightly version - git apply .github/workflows/requirements_lock_3_13_ft.patch + # Display the content for debugging in logs + cat build/requirements_lock_3_13_ft.txt | head -15 + # Check the patch + cat build/requirements_lock_3_13_ft.txt | head -15 | grep -E "(--pre|.*${GITHUB_WORKSPACE}/wheelhouse/|numpy)" + if [ "$?" == "1" ]; then echo "Could not find the patch in the requirements_lock_3_13_ft.txt"; exit 1; fi + cat build/requirements_lock_3_13_ft.txt | grep -E "(numpy==)" + if [ "$?" 
== "0" ]; then "Found original numpy dependency in the requirements_lock_3_13_ft.txt"; exit 1; fi echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES" echo "JAX_ENABLE_X64=$JAX_ENABLE_X64" @@ -188,6 +194,13 @@ jobs: bazel_exec=($(ls bazel-*)) ln -s ${bazel_exec} bazel + # Check python version + ./bazel run --@rules_python//python/config_settings:py_freethreaded="yes" @python//:python3 -- -VV + + # Check numpy version + ./bazel cquery @pypi_numpy//:* | grep whl + + # Build JAX and run tests ./bazel test \ --test_env=JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES \ --test_env=JAX_ENABLE_X64=$JAX_ENABLE_X64 \ From 8b46e53a4f8af705fc7218ec135acf95df9152b0 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 18 Mar 2025 08:55:38 -0700 Subject: [PATCH 02/20] jax.lax: improve docs for several APIs --- jax/_src/lax/lax.py | 166 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 143 insertions(+), 23 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 186a915e0..86a75ada6 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -615,8 +615,23 @@ def tanh(x: ArrayLike) -> Array: """ return tanh_p.bind(x) +@export def logistic(x: ArrayLike) -> Array: - r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`.""" + r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`. + + There is no HLO logistic/sigmoid primitive, so this lowers to a sequence + of HLO arithmetic operations. + + Args: + x: input array. Must have floating point or complex dtype. + + Returns: + Array of the same shape and dtype as ``x`` containing the element-wise + logistic/sigmoid function. + + See also: + - :func:`jax.nn.sigmoid`: an alternative API for this functionality. + """ return logistic_p.bind(x) @export @@ -1018,12 +1033,45 @@ def bitwise_xor(x: ArrayLike, y: ArrayLike) -> Array: """ return xor_p.bind(x, y) +@export def population_count(x: ArrayLike) -> Array: - r"""Elementwise popcount, count the number of set bits in each element.""" + r"""Elementwise popcount, count the number of set bits in each element. + + This function lowers directly to the `stablehlo.popcnt`_ operation. + + Args: + x: Input array. Must have integer dtype. + + Returns: + An array of the same shape and dtype as ``x``, containing the number of + set bits in the input. + + See also: + - :func:`jax.lax.clz`: Elementwise count leading zeros. + - :func:`jax.numpy.bitwise_count`: More flexible NumPy-style API for bit counts. + + .. _stablehlo.popcnt: https://openxla.org/stablehlo/spec#popcnt + """ return population_count_p.bind(x) +@export def clz(x: ArrayLike) -> Array: - r"""Elementwise count-leading-zeros.""" + r"""Elementwise count-leading-zeros. + + This function lowers directly to the `stablehlo.count_leading_zeros`_ operation. + + Args: + x: Input array. Must have integer dtype. + + Returns: + An array of the same shape and dtype as ``x``, containing the number of + set bits in the input. + + See also: + - :func:`jax.lax.population_count`: Count the number of set bits in each element. + + .. _stablehlo.count_leading_zeros: https://openxla.org/stablehlo/spec#count_leading_zeros + """ return clz_p.bind(x) @export @@ -1124,31 +1172,81 @@ def div(x: ArrayLike, y: ArrayLike) -> Array: """ return div_p.bind(x, y) +@export def rem(x: ArrayLike, y: ArrayLike) -> Array: r"""Elementwise remainder: :math:`x \bmod y`. - The sign of the result is taken from the dividend, - and the absolute value of the result is always - less than the divisor's absolute value. 
+ This function lowers directly to the `stablehlo.remainder`_ operation. + The sign of the result is taken from the dividend, and the absolute value + of the result is always less than the divisor's absolute value. - Integer division overflow - (remainder by zero or remainder of INT_SMIN with -1) + Integer division overflow (remainder by zero or remainder of INT_SMIN with -1) produces an implementation defined value. + + Args: + x, y: Input arrays. Must have matching int or float dtypes. If neither + is a scalar, ``x`` and ``y`` must have the same number of dimensions + and be broadcast compatible. + + Returns: + An array of the same dtype as ``x`` and ``y`` containing the remainder. + + See also: + - :func:`jax.numpy.remainder`: NumPy-style remainder with different + sign semantics. + + .. _stablehlo.remainder: https://openxla.org/stablehlo/spec#remainder """ return rem_p.bind(x, y) +@export def max(x: ArrayLike, y: ArrayLike) -> Array: - r"""Elementwise maximum: :math:`\mathrm{max}(x, y)` + r"""Elementwise maximum: :math:`\mathrm{max}(x, y)`. - For complex numbers, uses a lexicographic comparison on the - `(real, imaginary)` pairs.""" + This function lowers directly to the `stablehlo.maximum`_ operation for + non-complex inputs. For complex numbers, this uses a lexicographic + comparison on the `(real, imaginary)` pairs. + + Args: + x, y: Input arrays. Must have matching dtypes. If neither is a scalar, + ``x`` and ``y`` must have the same rank and be broadcast compatible. + + Returns: + An array of the same dtype as ``x`` and ``y`` containing the elementwise + maximum. + + See also: + - :func:`jax.numpy.maximum`: more flexibly NumPy-style maximum. + - :func:`jax.lax.reduce_max`: maximum along an axis of an array. + - :func:`jax.lax.min`: elementwise minimum. + + .. _stablehlo.maximum: https://openxla.org/stablehlo/spec#maximum + """ return max_p.bind(x, y) +@export def min(x: ArrayLike, y: ArrayLike) -> Array: - r"""Elementwise minimum: :math:`\mathrm{min}(x, y)` + r"""Elementwise minimum: :math:`\mathrm{min}(x, y)` - For complex numbers, uses a lexicographic comparison on the - `(real, imaginary)` pairs.""" + This function lowers directly to the `stablehlo.minimum`_ operation for + non-complex inputs. For complex numbers, this uses a lexicographic + comparison on the `(real, imaginary)` pairs. + + Args: + x, y: Input arrays. Must have matching dtypes. If neither is a scalar, + ``x`` and ``y`` must have the same rank and be broadcast compatible. + + Returns: + An array of the same dtype as ``x`` and ``y`` containing the elementwise + minimum. + + See also: + - :func:`jax.numpy.minimum`: more flexibly NumPy-style minimum. + - :func:`jax.lax.reduce_min`: minimum along an axis of an array. + - :func:`jax.lax.max`: elementwise maximum. + + .. _stablehlo.minimum: https://openxla.org/stablehlo/spec#minimum + """ return min_p.bind(x, y) @export @@ -1408,21 +1506,38 @@ def lt(x: ArrayLike, y: ArrayLike) -> Array: """ return lt_p.bind(x, y) +@export def convert_element_type(operand: ArrayLike, new_dtype: DTypeLike | dtypes.ExtendedDType) -> Array: """Elementwise cast. - Wraps XLA's `ConvertElementType - `_ - operator, which performs an elementwise conversion from one type to another. - Similar to a C++ `static_cast`. + This function lowers directly to the `stablehlo.convert`_ operation, which + performs an elementwise conversion from one type to another, similar to a + C++ ``static_cast``. Args: operand: an array or scalar value to be cast. - new_dtype: a NumPy dtype representing the target type. 
+ new_dtype: a dtype-like object (e.g. a :class:`numpy.dtype`, a scalar type, + or a valid dtype name) representing the target dtype. Returns: - An array with the same shape as `operand`, cast elementwise to `new_dtype`. + An array with the same shape as ``operand``, cast elementwise to ``new_dtype``. + + .. note:: + + If ``new_dtype`` is a 64-bit type and `x64 mode`_ is not enabled, + the appropriate 32-bit type will be used in its place. + + If the input is a JAX array and the input dtype and output dtype match, then + the input array will be returned unmodified. + + See also: + - :func:`jax.numpy.astype`: NumPy-style dtype casting API. + - :meth:`jax.Array.astype`: dtype casting as an array method. + - :func:`jax.lax.bitcast_convert_type`: cast bits directly to a new dtype. + + .. _stablehlo.convert: https://openxla.org/stablehlo/spec#convert + .. _x64 mode: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision """ return _convert_element_type(operand, new_dtype, weak_type=False) # type: ignore[unused-ignore,bad-return-type] @@ -1500,12 +1615,11 @@ def _convert_element_type( operand, new_dtype=new_dtype, weak_type=bool(weak_type), sharding=sharding) +@export def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array: """Elementwise bitcast. - Wraps XLA's `BitcastConvertType - `_ - operator, which performs a bit cast from one type to another. + This function lowers directly to the `stablehlo.bitcast_convert`_ operation. The output shape depends on the size of the input and output dtypes with the following logic:: @@ -1525,6 +1639,12 @@ def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array: Returns: An array of shape `output_shape` (see above) and type `new_dtype`, constructed from the same bits as operand. + + See also: + - :func:`jax.lax.convert_element_type`: value-preserving dtype conversion. + - :func:`jax.Array.view`: NumPy-style API for bitcast type conversion. + + .. _stablehlo.bitcast_convert: https://openxla.org/stablehlo/spec#bitcast_convert """ new_dtype = dtypes.canonicalize_dtype(new_dtype) return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype) From 13541e9f12d1589890a9384f35e26b51e2111cc8 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 18 Mar 2025 09:11:28 -0700 Subject: [PATCH 03/20] Make blocked_fold_in consistent when the block sizes induce padding Add coverage for padded shapes to unit tests. 
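To make the rounding behaviour concrete, the following standalone sketch (plain Python; the names and example sizes are chosen for illustration only) mirrors the padded tile numbering this change introduces. With total_size=(257, 129) and tile_size=(128, 128) the tile grid is rounded up to (3, 2), so every partially covered tile still receives a unique row-major index:

    def ceil_div(a, b):
        # Round up so that partially covered tiles are still counted.
        return (a + b - 1) // b

    total_size = (257, 129)
    tile_size = (128, 128)
    total_size_in_tiles = tuple(
        ceil_div(s, t) for s, t in zip(total_size, tile_size))
    assert total_size_in_tiles == (3, 2)

    def tile_index(tile_coords, total_size_in_tiles):
        # Row-major index over the rounded-up tile grid, analogous to
        # _compute_tile_index; tile_coords are global tile coordinates,
        # i.e. block_index * block_size_in_tiles + tile_index_in_block.
        idx, dim_size = 0, 1
        for i in range(len(tile_coords) - 1, -1, -1):
            idx += tile_coords[i] * dim_size
            dim_size *= total_size_in_tiles[i]
        return idx

    assert tile_index((2, 1), total_size_in_tiles) == 5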
PiperOrigin-RevId: 738029476 --- jax/_src/blocked_sampler.py | 12 ++++++----- tests/blocked_sampler_test.py | 38 +++++++++++++++++++++++++---------- 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/jax/_src/blocked_sampler.py b/jax/_src/blocked_sampler.py index e4d2e2855..3021b6a16 100644 --- a/jax/_src/blocked_sampler.py +++ b/jax/_src/blocked_sampler.py @@ -29,8 +29,8 @@ class SampleFn(Protocol): def _compute_tile_index(block_index: Sequence[int], - total_size_in_blocks: Shape, block_size_in_tiles: Shape, + total_size_in_tiles: Shape, tile_index_in_block: Sequence[int]) -> int: ndims = len(block_index) dim_size = 1 @@ -38,7 +38,7 @@ def _compute_tile_index(block_index: Sequence[int], for i in range(ndims-1, -1, -1): dim_idx = tile_index_in_block[i] + block_index[i] * block_size_in_tiles[i] total_idx += dim_idx * dim_size - dim_size *= total_size_in_blocks[i] * block_size_in_tiles[i] + dim_size *= total_size_in_tiles[i] return total_idx @@ -103,15 +103,17 @@ def blocked_fold_in( _shape // _element for _shape, _element in zip(block_size, tile_size) ) - total_size_in_blocks = tuple( - _shape // _element for _shape, _element in zip(total_size, block_size) + # Round up to make sure every tile is numbered. + total_size_in_tiles = tuple( + (_shape + _element - 1) // _element + for _shape, _element in zip(total_size, tile_size) ) def _keygen_loop(axis, prefix): if axis == len(block_size_in_tiles): subtile_key = jax.random.fold_in( global_key, _compute_tile_index( - block_index, total_size_in_blocks, block_size_in_tiles, prefix)) + block_index, block_size_in_tiles, total_size_in_tiles, prefix)) return subtile_key else: keys = [] diff --git a/tests/blocked_sampler_test.py b/tests/blocked_sampler_test.py index 4c27e850c..b5f87fe05 100644 --- a/tests/blocked_sampler_test.py +++ b/tests/blocked_sampler_test.py @@ -29,16 +29,23 @@ def call_kernel( kernel, grid: tuple[int, int], transpose_grid: bool, - *args + key: jax.Array, + total_size: tuple[int, int], + block_size: tuple[int, int], + tile_size: tuple[int, int], ): """Calls a kernel over a grid and concatenates results to a single array.""" if transpose_grid: grid = (grid[1], grid[0]) m, n = grid - return jnp.concatenate([ - jnp.concatenate([ - kernel((i, j), *args) for j in range(n)], axis=1) - for i in range(m)], axis=0) + samples = jnp.concatenate([ + jnp.concatenate([ + kernel((i, j), key, total_size, block_size, tile_size) + for j in range(n)], axis=1) + for i in range(m)], axis=0) + # Slice out the padding. 
+ samples = samples[:total_size[0], :total_size[1]] + return samples def call_kernel_3d( @@ -73,10 +80,10 @@ def uniform_kernel(block_index, key, total_size, block_size, tile_size): block_size=block_size, tile_size=tile_size) return blocked_sampler.sample_block(jax.random.uniform, - keys, - block_size=block_size, - tile_size=tile_size, - minval=0.0, maxval=1.0) + keys, + block_size=block_size, + tile_size=tile_size, + minval=0.0, maxval=1.0) class BlockedSamplerTest(jtu.JaxTestCase): @@ -94,16 +101,25 @@ class BlockedSamplerTest(jtu.JaxTestCase): dict(testcase_name='16x256_vs_32x128', total_size=(32, 256), block_size_a=(16, 256), block_size_b=(32, 128), tile_size=(8, 128), transpose_grid=False), + dict(testcase_name='128x128_vs_128x256_padding', + total_size=(256, 128), block_size_a=(128, 128), + block_size_b=(128, 256), tile_size=(128, 128), transpose_grid=False), + dict(testcase_name='128x128_vs_128x256_padding2', + total_size=(257, 129), block_size_a=(128, 128), + block_size_b=(128, 256), tile_size=(128, 128), transpose_grid=False), ) def test_block_shape_invariance(self, total_size, block_size_a, block_size_b, tile_size, transpose_grid): global_key = jax.random.key(0) - grid_a = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_a)) + ceil_div = lambda x, y: (x + y - 1) // y + grid_a = tuple(ceil_div(_tot, _blk) + for _tot, _blk in zip(total_size, block_size_a)) result_a = call_kernel( uniform_kernel, grid_a, transpose_grid, global_key, total_size, block_size_a, tile_size) - grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b)) + grid_b = tuple(ceil_div(_tot, _blk) + for _tot, _blk in zip(total_size, block_size_b)) result_b = call_kernel( uniform_kernel, grid_b, transpose_grid, global_key, total_size, block_size_b, tile_size) From 7c5871f464df0db16c4dcc44b20a382885db1075 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 18 Mar 2025 09:38:12 -0700 Subject: [PATCH 04/20] [Pallas TPU] Hoist prologue and epilogue outside of pipeline loop PiperOrigin-RevId: 738038138 --- jax/_src/pallas/mosaic/pipeline.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 2044d3d18..184b1497a 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -1196,9 +1196,8 @@ def emit_pipeline( schedule = map_brefs( lambda _, x: get_pipeline_schedule(x), allocations, schedule) - def loop_body(step, indices): - nonlocal allocations - scheduler = Scheduler( + def make_scheduler(step, indices): + return Scheduler( step, indices, grid, @@ -1208,13 +1207,15 @@ def emit_pipeline( init_accumulators=init_accumulators, trace_scopes=trace_scopes, ) + + def loop_body(step, indices): + scheduler = make_scheduler(step, indices) with scheduler.grid_env(): # prepare any local VMEM aliases brefs = map_brefs(scheduler.alias_local_refs, allocations, refs) # loop input handling phase - map_brefs(scheduler.initialize, brefs, refs, schedule) map_brefs(scheduler.copy_in, brefs, refs, schedule) map_brefs(scheduler.wait_in, brefs, refs, schedule) @@ -1243,12 +1244,24 @@ def emit_pipeline( lambda: None) map_brefs(scheduler.swap_slots, brefs, refs, schedule) - map_brefs(scheduler.finalize, brefs, refs, schedule) - return _next_index(indices, grid) - # run pipeline - lax.fori_loop(0, num_steps, loop_body, (0,) * len(grid)) + @pl.when(num_steps > 0) + def _(): + # pipeline prologue + initial_indices = (0,) * len(grid) + scheduler = 
make_scheduler(0, initial_indices) + brefs = map_brefs(scheduler.alias_local_refs, allocations, refs) + map_brefs(scheduler.initialize, brefs, refs, schedule) + + # pipeline loop + next_indices = lax.fori_loop(0, num_steps, loop_body, initial_indices) + + # pipeline epilogue + final_indices = _prev_index(next_indices, grid) + scheduler = make_scheduler(num_steps - 1, final_indices) + brefs = map_brefs(scheduler.alias_local_refs, allocations, refs) + map_brefs(scheduler.finalize, brefs, refs, schedule) return pipeline From a5c0f200e72d28ae41730be5b244adf82ab4fce1 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 18 Mar 2025 09:42:42 -0700 Subject: [PATCH 05/20] `set_mesh` should return the prev_mesh instead of nothing. Users can choose to use the return value or ignore it. PiperOrigin-RevId: 738039559 --- jax/_src/sharding_impls.py | 16 ++++++++++++---- tests/pjit_test.py | 8 ++------ 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 60e8c54a4..51c4ad639 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -1391,12 +1391,20 @@ def use_mesh(mesh: mesh_lib.Mesh): mesh_lib.use_concrete_mesh(mesh)): yield -def set_mesh(mesh: mesh_lib.Mesh) -> None: - if not isinstance(mesh, mesh_lib.Mesh): +def set_mesh(mesh: mesh_lib.Mesh | None) -> mesh_lib.Mesh | None: + """Sets the given concrete mesh globally and returns the previous concrete + mesh.""" + if mesh is not None and not isinstance(mesh, mesh_lib.Mesh): raise ValueError( f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}") if not core.trace_state_clean(): raise ValueError('`set_mesh` can only be used outside of `jax.jit`.') - config.abstract_mesh_context_manager.set_local(mesh.abstract_mesh) - config.device_context.set_local(mesh) + if mesh is None: + config.abstract_mesh_context_manager.set_global(mesh_lib.empty_abstract_mesh) # type: ignore + else: + config.abstract_mesh_context_manager.set_global(mesh.abstract_mesh) # type: ignore + + prev_mesh = config.device_context.get_global() + config.device_context.set_global(mesh) + return prev_mesh diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 7687bf110..76336920b 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7096,16 +7096,12 @@ class ShardingInTypesTest(jtu.JaxTestCase): def test_set_mesh(self): mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisType.Explicit,)) - prev_mesh = config.device_context.value - prev_abstract_mesh = config.abstract_mesh_context_manager.value try: - jax.sharding.set_mesh(mesh) - + prev_mesh = jax.sharding.set_mesh(mesh) out = reshard(np.arange(8), P('x')) self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) finally: - config.device_context.set_local(prev_mesh) - config.abstract_mesh_context_manager.set_local(prev_abstract_mesh) + jax.sharding.set_mesh(prev_mesh) @jtu.with_user_mesh((2,), ('x',)) def test_auto_axes_late_bind(self, mesh): From 547d602760595429b2fccb1287a5615d2d004fb2 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 18 Mar 2025 11:38:21 -0700 Subject: [PATCH 06/20] Remove //jaxlib:cpu_kernels and //jaxlib:gpu_kernels forwarding Bazel targets. These were temporary forwarding targets that are no longer needed; use //jaxlib/cpu:cpu_kernels and //jaxlib/cuda:cuda_gpu_kernels instead. 
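For the `set_mesh` change above ("`set_mesh` should return the prev_mesh instead of nothing"), the returned value enables a simple save/restore pattern like the one sketched below (illustration only; the single-device mesh and the body of the try block are placeholders):

    import jax
    import numpy as np
    from jax.sharding import Mesh

    mesh = Mesh(np.array(jax.devices()[:1]), ('x',))
    prev = jax.sharding.set_mesh(mesh)   # returns the previously set mesh (or None)
    try:
        ...  # code that should run with `mesh` set as the global concrete mesh
    finally:
        jax.sharding.set_mesh(prev)      # restore whatever was set before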
PiperOrigin-RevId: 738085234 --- jaxlib/BUILD | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 93c2b483c..a61bf7c88 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -281,25 +281,3 @@ nanobind_extension( "@nanobind", ], ) - -# CPU kernels - -# TODO(phawkins): Remove this forwarding target. -cc_library( - name = "cpu_kernels", - visibility = ["//visibility:public"], - deps = [ - "//jaxlib/cpu:cpu_kernels", - ], - alwayslink = 1, -) - -# TODO(phawkins): Remove this forwarding target. -cc_library( - name = "gpu_kernels", - visibility = ["//visibility:public"], - deps = [ - "//jaxlib/cuda:cuda_gpu_kernels", - ], - alwayslink = 1, -) From 875099b25dd6dea5353cf6a38b8ccd0ebaf0a0ed Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 18 Mar 2025 11:50:58 -0700 Subject: [PATCH 07/20] [Mosaic GPU] Enable the new transform inference pass in the warpgroup lowering. A couple of dummy transform inference rules needed to be added in order to contend with parts of the lowering that do not use the dialect yet, along with a transform inference rule for `memref.view`. PiperOrigin-RevId: 738089782 --- jax/experimental/mosaic/gpu/core.py | 22 +- .../mosaic/gpu/dialect_lowering.py | 194 ++++++++--- .../mosaic/gpu/layout_inference.py | 91 ------ .../mosaic/gpu/transform_inference.py | 59 +++- tests/mosaic/gpu_test.py | 306 +++++++----------- 5 files changed, 324 insertions(+), 348 deletions(-) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 66e19bb5f..b255893e2 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -41,20 +41,16 @@ from jaxlib.mlir.dialects import memref from jaxlib.mlir.dialects import nvvm import numpy as np -if dialect is not None: - from . import dialect_lowering - from . import layout_inference -else: - dialect_lowering = None - layout_inference = None - -from . import profiler -from . import utils -from . import launch_context -from . import tcgen05 - # mypy: ignore-errors +from . import dialect_lowering +from . import launch_context +from . import layout_inference +from . import profiler +from . import tcgen05 +from . import transform_inference +from . import utils + # MLIR can't find libdevice unless we point it to the CUDA path # TODO(apaszke): Unify with jax._src.lib.cuda_path CUDA_ROOT = "/usr/local/cuda" @@ -584,6 +580,7 @@ def as_gpu_kernel( # Run Python lowering passes. The remaining passes will be run in C++ in # jax/jaxlib/mosaic/gpu/custom_call.cc layout_inference.infer_layout(module) # pytype: disable=attribute-error + transform_inference.infer_transforms(module) # pytype: disable=attribute-error dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error _initialize_scratch(launch_ctx, scratch_arr) @@ -666,6 +663,7 @@ def as_torch_gpu_kernel( # Run Python lowering passes. 
The remaining passes will be run in C++ in # jax/jaxlib/mosaic/gpu/custom_call.cc layout_inference.infer_layout(module) # pytype: disable=attribute-error + transform_inference.infer_transforms(module) # pytype: disable=attribute-error dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error _initialize_scratch(launch_ctx, scratch_arr) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 8098d14f0..fedde5a00 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -17,6 +17,7 @@ from collections.abc import Callable import dataclasses import functools +import itertools import operator from typing import Any, Sequence, Type, cast @@ -58,7 +59,7 @@ class LoweringContext: if not _should_lower(op): return - if (name := op.OPERATION_NAME) not in _lowerings: + if (name := op.OPERATION_NAME) not in _lowerings: # pytype: disable=attribute-error raise NotImplementedError(f"Missing lowering rule for {op}") lowering_rule = _lowerings[name] @@ -227,6 +228,60 @@ def _arith_constant_op_lowering_rule( ] +def _check_transforms_and_swizzle_are_supported( + ref_ty: ir.MemRefType, + transforms: Sequence[launch_context.MemRefTransform], + swizzle: mgpu.SwizzlingMode, + minimum_swizzle: mgpu.SwizzlingMode = mgpu.SwizzlingMode.kNoSwizzle, +): + """Checks that the list of provided transforms and swizzle are supported. + + Currently, we allow the following: + - any swizzle that is larger than or equal to `minimum_swizzle`; + - optionally, a single tile transform (with rank equal to the rank of the + memref being annotated); + - optionally, a single transpose transform. + """ + if swizzle < minimum_swizzle: + raise NotImplementedError( + f"Unsupported swizzle {swizzle} smaller than {minimum_swizzle}." + ) + + partitioned_transforms = { + k: list(v) + for k, v in itertools.groupby( + transforms, lambda t: isinstance(t, launch_context.TileTransform) + ) + } + + tile_transforms = partitioned_transforms.get(True, []) + other_transforms = partitioned_transforms.get(False, []) + + if len(tile_transforms) > 1: + raise NotImplementedError( + f"{tile_transforms} contains more than one tile transform." + ) + + if len(tile_transforms) == 1: + if len(tile_transforms[0].tiling) != len(ref_ty.shape): + raise NotImplementedError( + f"Only tile transforms with rank equal to the rank of the memref " + f"being annotated are supported but got {tile_transforms[0]} for " + f"{ref_ty}." + ) + + if len(other_transforms) > 1: + raise NotImplementedError( + f"{other_transforms} contains more than one transform." + ) + + if len(other_transforms) == 1: + if not isinstance(other_transforms[0], launch_context.TransposeTransform): + raise NotImplementedError( + f"{other_transforms[0]} is not a transpose transform." 
+ ) + + @_register_lowering(vector.LoadOp) def _vector_load_op_lowering_rule( _: LoweringContext, vector_load_op: vector.LoadOp @@ -260,8 +315,11 @@ def _vector_load_op_lowering_rule( vec_size=strided_layout.vec_size, ) elif layouts.from_layout_attr(out_layout_attr) == fa.WGMMA_LAYOUT: - layout = ir.MemRefType(vector_load_op.base.type).layout - swizzle, transforms = memref_layout_to_swizzle_and_transforms(layout) + swizzle, transforms = swizzle_and_transforms_from_transforms_attr( + inference_utils.in_transforms(vector_load_op)[0] + ) + ref_ty = ir.MemRefType(vector_load_op.base.type) + _check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle) transformed_ref = transform_memref(vector_load_op.base, transforms) fragmented_array = fa.FragmentedArray.load_tiled( transformed_ref, @@ -297,8 +355,22 @@ def _vector_store_op_lowering_rule( vector_store_op.valueToStore, to_store_layout ) - # TODO(dasenov): This is not efficient for WGMMA layouts - fragmented_array.store_untiled(vector_store_op.base) + if fragmented_array.layout == fa.WGMMA_LAYOUT: + swizzle, transforms = swizzle_and_transforms_from_transforms_attr( + inference_utils.in_transforms(vector_store_op)[0] + ) + ref_ty = ir.MemRefType(vector_store_op.base.type) + _check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle) + fragmented_array.store_tiled( + transform_memref(vector_store_op.base, transforms), swizzle + ) + elif (isinstance(fragmented_array.layout, fa.WGStridedFragLayout) or + isinstance(fragmented_array.layout, fa.WGSplatFragLayout)): + fragmented_array.store_untiled(vector_store_op.base) + else: + raise ValueError( + f"{vector_store_op} has an unsupported layout: {to_store_layout}" + ) return [] @@ -362,39 +434,43 @@ def _vector_reduction_op_lowering_rule( return [_fragmented_array_to_ir(result, op.result.type)] -def memref_layout_to_swizzle_and_transforms( - layout: ir.Attribute, +def swizzle_and_transforms_from_transforms_attr( + transforms: ir.ArrayAttr, ) -> tuple[mgpu.SwizzlingMode, tuple[launch_context.MemRefTransform, ...]]: - """Returns the swizzle and transforms that are encoded in the given layout. + """Returns the swizzle and MemrefTransforms for the given transforms. - If the layout is not a LayoutAttr, the swizzle is kNoSwizzle and the - transforms are empty. Otherwise, the layout may have at most one swizzle - transform and any combination of tiling and transpose transforms. + Args: + transforms: a list of transform attributes. + + Returns: + A tuple containing the swizzle mode and MemRefTransforms corresponding to + the parameter transforms. If `transforms` is empty, or does not contain + any swizzling transform, the swizzle mode is assumed to be kNoSwizzle. + Raises: + ValueError: if a swizzling transform is followed by any transform. """ swizzle = None gmem_transforms: list[launch_context.MemRefTransform] = [] - if mgpu.LayoutAttr.isinstance(layout): - transforms_attr = mgpu.LayoutAttr(layout).transforms - for transform in transforms_attr: - if swizzle is not None: - raise ValueError(f"{layout} contains more transforms after the initial swizzle.") - if mgpu.SwizzleTransformAttr.isinstance(transform): - # TODO(dasenov): Swizzling can change if the ref is sliced in certain - # ways. We might want to enforce some restrictions here. 
- swizzle = mgpu.SwizzleTransformAttr(transform).swizzle - elif mgpu.TileTransformAttr.isinstance(transform): - tiling = mgpu.TileTransformAttr(transform).tiling - tiling_transform = launch_context.TileTransform(tuple(tiling)) - gmem_transforms.append(tiling_transform) - elif mgpu.TransposeTransformAttr.isinstance(transform): - permutation = mgpu.TransposeTransformAttr(transform).permutation - transpose_transform = launch_context.TransposeTransform( - tuple(permutation) - ) - gmem_transforms.append(transpose_transform) - else: - raise ValueError(f"{layout} has an unsupported transform: {transform}") + for transform in transforms: + if swizzle is not None: + raise ValueError(f"{transforms} contain more transforms after swizzle.") + if mgpu.SwizzleTransformAttr.isinstance(transform): + # TODO(dasenov): Swizzling can change if the ref is sliced in certain + # ways. We might want to enforce some restrictions here. + swizzle = mgpu.SwizzleTransformAttr(transform).swizzle + elif mgpu.TileTransformAttr.isinstance(transform): + tiling = mgpu.TileTransformAttr(transform).tiling + tiling_transform = launch_context.TileTransform(tuple(tiling)) + gmem_transforms.append(tiling_transform) + elif mgpu.TransposeTransformAttr.isinstance(transform): + permutation = mgpu.TransposeTransformAttr(transform).permutation + transpose_transform = launch_context.TransposeTransform( + tuple(permutation) + ) + gmem_transforms.append(transpose_transform) + else: + raise ValueError("Unknown transform: {transform}") return swizzle or mgpu.SwizzlingMode.kNoSwizzle, tuple(gmem_transforms) @@ -434,8 +510,14 @@ def _mgpu_async_load_op_lowering_rule( assert ctx.launch_context is not None barrier = utils.BarrierRef.from_dialect_barrier_memref(load_op.barrier) - dst_layout = ir.MemRefType(load_op.destination.type).layout - swizzle, transforms = memref_layout_to_swizzle_and_transforms(dst_layout) + if inference_utils.has_in_transforms_set(load_op): + [transforms] = inference_utils.in_transforms(load_op) + swizzle, transforms = swizzle_and_transforms_from_transforms_attr( + transforms + ) + else: + swizzle = mgpu.SwizzlingMode.kNoSwizzle + transforms = () gmem_slice = [] for idx_i32, size in zip(load_op.indices, load_op.slice_lengths): @@ -464,8 +546,14 @@ def _mgpu_async_store_op_lowering_rule( ) -> Sequence[ir.Value]: assert ctx.launch_context is not None - src_layout = ir.MemRefType(store_op.source.type).layout - swizzle, transforms = memref_layout_to_swizzle_and_transforms(src_layout) + if inference_utils.has_in_transforms_set(store_op): + [transforms] = inference_utils.in_transforms(store_op) + swizzle, transforms = swizzle_and_transforms_from_transforms_attr( + transforms + ) + else: + swizzle = mgpu.SwizzlingMode.kNoSwizzle + transforms = () gmem_slice = [] for idx_i32, size in zip(store_op.indices, store_op.slice_lengths): @@ -673,6 +761,9 @@ def _bitcast_op_lowering_rule( def _mgpu_wgmma_op_lowering_rule( _: LoweringContext, wgmma_op: mgpu.WGMMAOp ) -> Sequence[ir.Value]: + if wgmma_op.transpose_a or wgmma_op.transpose_b: + raise ValueError("Transpose arguments are to be deleted.") + fa_layouts = ( *inference_utils.in_layouts(wgmma_op), *inference_utils.out_layouts(wgmma_op), @@ -691,25 +782,38 @@ def _mgpu_wgmma_op_lowering_rule( regs = acc_in.to_layout(fa.WGMMA_LAYOUT) acc = wgmma.WGMMAAccumulator.from_registers(regs) - b_layout = ir.MemRefType(wgmma_op.b.type).layout - b_swizzle, b_transforms = memref_layout_to_swizzle_and_transforms(b_layout) + if ir.VectorType.isinstance(wgmma_op.a.type): + a_transforms = None + 
b_transforms = inference_utils.in_transforms(wgmma_op)[0] + else: + a_transforms, b_transforms = inference_utils.in_transforms(wgmma_op) + + b_swizzle, b_transforms = swizzle_and_transforms_from_transforms_attr( + b_transforms + ) + minimum_swizzle = mgpu.SwizzlingMode.k32ByteSwizzle + ref_ty = ir.MemRefType(wgmma_op.b.type) + _check_transforms_and_swizzle_are_supported( + ref_ty, b_transforms, b_swizzle, minimum_swizzle + ) b_operand = transform_memref(wgmma_op.b, b_transforms) - if wgmma_op.transpose_b: - b_operand = utils.memref_transpose(b_operand, (0, 1, 3, 2)) if ir.VectorType.isinstance(wgmma_op.a.type): a_operand = _fragmented_array_from_ir(wgmma_op.a, wgmma_layout) else: - a_layout = ir.MemRefType(wgmma_op.a.type).layout - a_swizzle, a_transforms = memref_layout_to_swizzle_and_transforms(a_layout) + a_swizzle, a_transforms = swizzle_and_transforms_from_transforms_attr( + a_transforms + ) + ref_ty = ir.MemRefType(wgmma_op.a.type) + _check_transforms_and_swizzle_are_supported( + ref_ty, a_transforms, a_swizzle, minimum_swizzle + ) if a_swizzle != b_swizzle: raise ValueError( f"Non-matching swizzles of operands a and b in WGMMA: {a_swizzle} !=" f" {b_swizzle}" ) a_operand = transform_memref(wgmma_op.a, a_transforms) - if wgmma_op.transpose_a: - a_operand = utils.memref_transpose(a_operand, (0, 1, 3, 2)) new_acc = wgmma.wgmma(acc, a_operand, b_operand, swizzle=b_swizzle) @@ -902,7 +1006,7 @@ def single_thread_predicates(module: ir.Module) -> tuple[ir.Value, ir.Value]: def _should_lower(op: ir.OpView) -> bool: """Returns 'true' if the operation should be lowered.""" return ( - op.OPERATION_NAME.startswith("mosaic_gpu.") + op.OPERATION_NAME.startswith("mosaic_gpu.") # pytype: disable=attribute-error or inference_utils.should_have_layout(op) or any(bool(b) for r in op.regions for b in r) # Does it have subblocks? ) diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index c9479e0f1..0d2811bb5 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -383,89 +383,6 @@ def _infer_wgmma_op_layout(wgmma_op: mgpu.WGMMAOp) -> OptionalLayouts: return [layout], [layout] -@dataclasses.dataclass() -class WGMMATransforms: - swizzle: mgpu.SwizzlingMode - a_tile: tuple[int, ...] - a_transpose: bool - b_tile: tuple[int, ...] - b_transpose: bool - - -def infer_wgmma_transforms(wgmma_op: mgpu.WGMMAOp) -> WGMMATransforms: - a_shape = cast(ir.ShapedType, wgmma_op.a.type).shape - k = a_shape[0] if wgmma_op.transpose_a else a_shape[1] - bitwidth = cast(ir.ShapedType, wgmma_op.a.type).element_type.width - - # Try tiling with all swizzling modes starting from the largest one. - for swizzle in [ - mgpu.SwizzlingMode.k128ByteSwizzle, - mgpu.SwizzlingMode.k64ByteSwizzle, - mgpu.SwizzlingMode.k32ByteSwizzle, - ]: - s = swizzle * 8 // bitwidth - if k % s == 0: - return WGMMATransforms( - swizzle=swizzle, - a_tile=(s, 64) if wgmma_op.transpose_a else (64, s), - a_transpose=wgmma_op.transpose_a, - b_tile=(s, s), - b_transpose=wgmma_op.transpose_b, - ) - raise ValueError( - "Could not infer layouts for memref feeding into WGMMA. The " - "non-contracting dimension ({k}) must be a multiple of " - "s = swizzle * (8 / bitwidth) where swizzle is a valid swizzle " - f"(32, 64, or 128) and bitwidth ({bitwidth}) is the element size of " - "`a` and `b`." 
- ) - -def _layout_for_memref_view(view_op: memref.ViewOp) -> ir.Attribute | None: - wgmma_use = None - uses = cast(ir.OpResult, view_op.result).uses - for use in uses: - user = use.owner - if isinstance(user, memref.CastOp): - # This memref is already cast, so we don't need to do anything. - return None - if isinstance(user, mgpu.WGMMAOp): - if wgmma_use is not None: - raise NotImplementedError(f"Multiple WGMMA consumers of {view_op}.") - wgmma_use = use - break - if ( - not isinstance(user, mgpu.AsyncLoadOp) - and not isinstance(user, mgpu.AsyncStoreOp) - and not isinstance(user, vector.LoadOp) - and not isinstance(user, vector.StoreOp) - ): - raise NotImplementedError(f"Unsupported user {user} of {view_op}.") - - if wgmma_use is None: - # This memref is not used by a WGMMA operation, so we don't need to do - # anything. - return None - - transforms = infer_wgmma_transforms(wgmma_use.owner) - if wgmma_use.operand_number == 1: - tile = transforms.a_tile - transpose = transforms.a_transpose - else: - tile = transforms.b_tile - transpose = transforms.b_transpose - transpose_attr = ( - [mgpu.TransposeTransformAttr.get([1, 0, 2, 3])] if transpose else [] - ) - - layout = mgpu.LayoutAttr.get( - 2, - [mgpu.TileTransformAttr.get(tile)] - + transpose_attr - + [mgpu.SwizzleTransformAttr.get(transforms.swizzle)], - ) - - return layout - def _earliest_use(regions: list[ir.Region], uses: Sequence[ir.OpOperand]) -> ir.OpView: owners = [use.owner for use in uses] @@ -607,11 +524,3 @@ def infer_layout(module: ir.Module): for op in module.body: traverse_op(op, set_default_layout) - - def infer_memref_layouts_and_insert_casts(op: ir.OpView): - if op.name == "memref.view": - if layout := _layout_for_memref_view(op): - _insert_memref_layout_cast(layout, op) - - for op in module.body: - traverse_op(op, infer_memref_layouts_and_insert_casts) diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py index be3f2c381..ef2d36616 100644 --- a/jax/experimental/mosaic/gpu/transform_inference.py +++ b/jax/experimental/mosaic/gpu/transform_inference.py @@ -26,6 +26,9 @@ from typing import cast from jax._src.lib import mosaic_gpu_dialect as mgpu from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith +from jax._src.lib.mlir.dialects import builtin +from jax._src.lib.mlir.dialects import gpu +from jax._src.lib.mlir.dialects import memref from jax._src.lib.mlir.dialects import vector from . import fragmented_array as fa @@ -169,7 +172,6 @@ def _infer_vector_load_store_transforms( return None - # TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2. SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None) @@ -196,6 +198,60 @@ def _infer_slice_smem_transforms(op: SliceSMEMOp) -> OptionalTransforms: return None if transforms is None else ([], [transforms]) +# TODO(bchetioui,apaszke): this empty rule is necessary while Mosaic doesn't use +# the dialect in all cases. +# The rule is necessary in order to handle the lowering of `utils.memref_ptr` +# which is used in `_construct_smem_reftree`. 
+@partial(_add_transform_inference_rule, builtin.UnrealizedConversionCastOp) +def _infer_unrealized_conversion_cast_transforms( + _: builtin.UnrealizedConversionCastOp, +) -> OptionalTransforms: + return None + + +@partial(_add_transform_inference_rule, memref.ViewOp) +def _infer_memref_view_transforms(op: memref.ViewOp) -> OptionalTransforms: + if not isinstance(op.source.owner.opview, gpu.DynamicSharedMemoryOp): + raise NotImplementedError( + "Memref view transforms are only inferred when the op is a direct user " + f"of a DynamicSharedMemoryOp but got {op}." + ) + transforms = inference_utils.value_transforms(op.source) + if transforms is not None: + raise NotImplementedError( + "memref view with in_transforms aren't yet supported" + ) + uses = cast(ir.OpResult, op.result).uses + + for op_operand_use in uses: + consumer = op_operand_use.owner + op_user = consumer.operands[op_operand_use.operand_number] + out_transforms = inference_utils.in_transforms_for_operand( + consumer, op_user + ) + if transforms is not None and out_transforms is not None: + if transforms != out_transforms: + raise ValueError( + f"Conflicting transforms for {op_user} in {op}: " + f"{transforms} != {out_transforms}." + ) + elif out_transforms is not None: + transforms = out_transforms + + # TODO(bchetioui): do we actually need to assign a transform to the input of + # the view op? Presumably, it'll only be used to access scratch memory. + return None if transforms is None else ([], [transforms]) + + +# TODO(bchetioui,apaszke): this empty rule is necessary while Mosaic doesn't use +# the dialect in all cases. +@partial(_add_transform_inference_rule, gpu.DynamicSharedMemoryOp) +def _infer_dynamic_smem_transforms( + _: gpu.DynamicSharedMemoryOp, +) -> OptionalTransforms: + return None + + def _should_have_transforms(op: ir.OpView) -> bool: """Returns 'True' if the operation should be assigned in/out transforms.""" return any( @@ -218,7 +274,6 @@ def infer_transforms(module: ir.Module): specified. We error out if two distinct sets of transforms are competing to annotate the same memref. 
""" - def inference_step(op: ir.Operation): if not _should_have_transforms(op): return diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index bc56f21d0..1fcd68641 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -55,6 +55,7 @@ else: from jax.experimental.mosaic.gpu import launch_context from jax.experimental.mosaic.gpu import utils as utils from jax.experimental.mosaic.gpu import profiler + from jax.experimental.mosaic.gpu import inference_utils from jax.experimental.mosaic.gpu.utils import * # noqa: F403 from jax._src.lib.mlir.dialects import gpu from jax._src.lib.mlir.dialects import llvm @@ -2405,25 +2406,21 @@ class Swizzle: return mgpu_dialect.SwizzleTransformAttr.get(self.swizzle) -def memref_with_transforms( - mem_ref: ir.Value, - transforms: Sequence[Tile | Transpose | Swizzle], -) -> ir.Value: - """Casts the memref to one that has a layout with the given transforms.""" - mem_ref_type = ir.MemRefType(mem_ref.type) +def set_in_transforms( + op: ir.OpView, transforms: Sequence[Sequence[Tile | Transpose | Swizzle]], +) -> None: + """Annotates an op with in_transforms.""" + if not transforms: + return - transform_attr = [t.attr() for t in transforms] - if not transform_attr: - return mem_ref + in_transforms = [] + smem_refs = filter(inference_utils.is_transformable_smem_memref, op.operands) # pylint: disable=undefined-variable + for _, result_transforms in jax.util.safe_zip(smem_refs, transforms): + in_transforms.append( + ir.ArrayAttr.get([t.attr() for t in result_transforms]) + ) - layout = mgpu_dialect.LayoutAttr.get(mem_ref_type.rank, transform_attr) - memref_new_type = ir.MemRefType.get( - mem_ref_type.shape, - mem_ref_type.element_type, - layout, - mem_ref_type.memory_space, - ) - return memref.cast(memref_new_type, mem_ref) + op.attributes["in_transforms"] = ir.ArrayAttr.get(in_transforms) class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase): @@ -2556,7 +2553,6 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase): ): del ctx smem_ref, tma_barrier = smem - smem_ref = memref_with_transforms(smem_ref, test_case.transforms) dialect_barrier = tma_barrier.as_dialect_barrier_memref() elt_type = ir.MemRefType(in_gmem_ref.type).element_type @@ -2571,7 +2567,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase): slice_indices = [arith.constant(i32, i) for i in test_case.slice_indices] # GMEM -> SMEM - mgpu_dialect.async_load( + load_op = mgpu_dialect.AsyncLoadOp( source=in_gmem_ref, destination=smem_ref, barrier=dialect_barrier, @@ -2579,6 +2575,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase): slice_lengths=test_case.slice_lengths, collective=ir.ArrayAttr.get([]), ) + set_in_transforms(load_op, [test_case.transforms]) parities = memref.load(tma_barrier.phases, []) parity, _ = tma_barrier.update_parities(parities) @@ -2623,58 +2620,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase): (x[input_slice]).reshape(test_case.shape_sliced), ) - @staticmethod - def pointwise_kernel_with_tma_cases(dtype: jnp.dtype): - @dataclasses.dataclass(frozen=True) - class TestCaseInput: - shape: tuple[int, ...] - transforms: tuple[Tile | Transpose | Swizzle, ...] = () - - result = [] - for swizzle in mgpu_dialect.SwizzlingMode: - n = swizzle * 8 // jnp.finfo(dtype).bits - if swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle: - # We need at least one case with no transforms, as this is handled - # differently. 
- result.append(TestCaseInput(shape=[128, n])) - result.extend([ - TestCaseInput( - shape=[128, n], - transforms=[Swizzle(swizzle)], - ), - TestCaseInput( - shape=[2, 3, 64, n], - transforms=[Transpose([0, 1, 2, 3]), Swizzle(swizzle)], - ), - TestCaseInput( - shape=[2, 3, 64, n], - transforms=[ - Transpose([1, 0, 2, 3]), - Transpose([1, 0, 2, 3]), - Swizzle(swizzle), - ], - ), - TestCaseInput( - shape=[2, 3, 64, n], - transforms=[Transpose([1, 0, 2, 3]), Swizzle(swizzle)], - ), - TestCaseInput( - shape=[128, n], - transforms=[Tile([64, n]), Swizzle(swizzle)], - ), - TestCaseInput( - shape=[2 * 64, 3 * n], - transforms=[ - Tile([64, n]), - Transpose([1, 0, 2, 3]), - Swizzle(swizzle), - ], - ), - ]) - return result - - @parameterized.parameters(pointwise_kernel_with_tma_cases(jnp.bfloat16)) - def test_pointwise_kernel_with_tma(self, test_case): + def test_pointwise_kernel_with_tma(self): def add( ctx: launch_context.LaunchContext, a_gmem_ref: ir.Value, @@ -2701,9 +2647,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase): # GMEM -> SMEM mgpu_dialect.async_load( source=a_gmem_ref, - destination=memref_with_transforms( - a_smem_ref, test_case.transforms - ), + destination=a_smem_ref, barrier=dialect_barrier, indices=zero_slice_indices, slice_lengths=shape, @@ -2711,9 +2655,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase): ) mgpu_dialect.async_load( source=b_gmem_ref, - destination=memref_with_transforms( - b_smem_ref, test_case.transforms - ), + destination=b_smem_ref, barrier=dialect_barrier, indices=zero_slice_indices, slice_lengths=shape, @@ -2740,9 +2682,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase): # SMEM -> GMEM mgpu_dialect.async_store( - source=memref_with_transforms( - result_smem_ref, test_case.transforms - ), + source=result_smem_ref, destination=result_gmem_ref, indices=zero_slice_indices, slice_lengths=shape, @@ -2752,114 +2692,76 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase): dtype = jnp.bfloat16 - jax_shape = jax.ShapeDtypeStruct(test_case.shape, dtype) + spec = jax.ShapeDtypeStruct((2, 3, 4, 64), dtype) kernel = mgpu.as_gpu_kernel( add, grid=(1, 1, 1), block=(128, 1, 1), - in_shape=(jax_shape, jax_shape), - out_shape=jax_shape, + in_shape=(spec, spec), + out_shape=spec, smem_scratch_shape=[ - jax_shape, - jax_shape, - jax_shape, + spec, + spec, + spec, core.TMABarrier(1), ], thread_semantics=mgpu.ThreadSemantics.Warpgroup, ) - x = self.prng.uniform(-1, 1, test_case.shape).astype(dtype) - y = self.prng.uniform(-1, 1, test_case.shape).astype(dtype) + x = self.prng.uniform(-1, 1, spec.shape).astype(dtype) + y = self.prng.uniform(-1, 1, spec.shape).astype(dtype) self.assertArraysEqual(jax.jit(kernel)(x, y), x + y + y) class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase): - @staticmethod - def wgmma_kernel_with_tma_cases(abtype: jnp.dtype): - @dataclasses.dataclass(frozen=True) - class TestCaseInput: - shape_a: tuple[int, ...] = () - shape_b: tuple[int, ...] = () - shape_res: tuple[int, ...] = () - transforms_a: tuple[Tile | Transpose | Swizzle, ...] = () - transforms_b: tuple[Tile | Transpose | Swizzle, ...] 
= () - transpose_a: bool = False - transpose_b: bool = False - load_a_in_registers: bool = False + @parameterized.named_parameters( + ( + f"swizzle={int(swizzle)}_{transpose_lhs=}_{transpose_rhs=}_{lhs_in_registers=}", + swizzle, + transpose_lhs, + transpose_rhs, + lhs_in_registers, + ) + for swizzle in mgpu_dialect.SwizzlingMode + for transpose_lhs in [False, True] + for transpose_rhs in [False, True] + for lhs_in_registers in [False, True] + ) + def test_wgmma_kernel_with_tma( + self, swizzle, transpose_lhs, transpose_rhs, load_a_in_registers + ): + if swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle: + self.skipTest("No swizzle is not supported by wgmma") - result = [] - for swizzle in [ - # TODO(dasenov): Add a test for kNoSwizzle, i.e. all swizzling modes. - mgpu_dialect.SwizzlingMode.k32ByteSwizzle, - mgpu_dialect.SwizzlingMode.k64ByteSwizzle, - mgpu_dialect.SwizzlingMode.k128ByteSwizzle, - ]: - k = swizzle // np.dtype(abtype).itemsize - groups_m = 4 - groups_n = 1 - groups_k = 1 - result.extend([ - TestCaseInput( - shape_a=[groups_m * 64, groups_k * k], - shape_b=[groups_k * k, groups_n * k], - shape_res=[groups_m * 64, groups_n * k], - ), - TestCaseInput( - shape_a=[groups_m * 64, groups_k * k], - shape_b=[groups_n * k, groups_k * k], - shape_res=[groups_m * 64, groups_n * k], - transpose_b=True, - ), - TestCaseInput( - shape_a=[groups_m * 64, groups_k * k], - shape_b=[groups_k * k, groups_n * k], - shape_res=[groups_m * 64, groups_n * k], - transforms_a=[Tile([64, k]), Swizzle(swizzle)], - transforms_b=[Tile([k, k]), Swizzle(swizzle)], - ), - TestCaseInput( - shape_a=[groups_m * 64, groups_k * k], - shape_b=[groups_k * k, groups_n * k], - shape_res=[groups_m * 64, groups_n * k], - transforms_a=[Tile([64, k]), Swizzle(swizzle)], - load_a_in_registers=True, - ), - ]) - # The below only works for 128-byte swizzling. Regardless of transposing, - # TMA needs the size of the last dimension to be compatible with the - # swizzle. 
- if swizzle == mgpu_dialect.SwizzlingMode.k128ByteSwizzle: - result.append( - TestCaseInput( - shape_a=[groups_k * k, groups_m * 64], - shape_b=[groups_k * k, groups_n * k], - shape_res=[groups_m * 64, groups_n * k], - transpose_a=True, - ) - ) - return result + if transpose_lhs or transpose_rhs: + self.skipTest("Transposes are not supported by transform inference yet.") - @parameterized.parameters(wgmma_kernel_with_tma_cases(jnp.bfloat16)) - def test_wgmma_kernel_with_tma(self, test_case): + swizzle_elems = swizzle // np.dtype(jnp.bfloat16).itemsize + tiling_m, tiling_n, tiling_k = 64, swizzle_elems, swizzle_elems + + groups_m, groups_n, groups_k = 4, 1, 1 + m, n, k = groups_m * tiling_m, groups_n * tiling_n, groups_k * tiling_k + + lhs_shape = (k, m) if transpose_lhs else (m, k) + rhs_shape = (n, k) if transpose_rhs else (k, n) + out_shape = (m, n) def matmul( ctx: launch_context.LaunchContext, - a_gmem_ref: ir.Value, - b_gmem_ref: ir.Value, + lhs_gmem_ref: ir.Value, + rhs_gmem_ref: ir.Value, result_gmem_ref: ir.Value, smem: list[ir.Value], ): del ctx - a_smem_ref, b_smem_ref, result_smem_ref, tma_barrier = smem - a_smem_ref = memref_with_transforms(a_smem_ref, test_case.transforms_a) - b_smem_ref = memref_with_transforms(b_smem_ref, test_case.transforms_b) + lhs_smem_ref, rhs_smem_ref, result_smem_ref, tma_barrier = smem dialect_barrier = tma_barrier.as_dialect_barrier_memref() - ab_elt_type = ir.MemRefType(a_gmem_ref.type).element_type - bytes_a = utils.bytewidth(ab_elt_type) * math.prod(test_case.shape_a) - bytes_b = utils.bytewidth(ab_elt_type) * math.prod(test_case.shape_b) + operand_elt_type = ir.MemRefType(lhs_gmem_ref.type).element_type + bytes_a = utils.bytewidth(operand_elt_type) * math.prod(lhs_shape) + bytes_b = utils.bytewidth(operand_elt_type) * math.prod(rhs_shape) mgpu_dialect.arrive_expect_tx( barrier=dialect_barrier, @@ -2869,19 +2771,19 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase): zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0) # GMEM -> SMEM mgpu_dialect.async_load( - source=a_gmem_ref, - destination=a_smem_ref, + source=lhs_gmem_ref, + destination=lhs_smem_ref, barrier=dialect_barrier, - indices=[zero_i32] * len(test_case.shape_a), - slice_lengths=test_case.shape_a, + indices=[zero_i32] * len(lhs_shape), + slice_lengths=lhs_shape, collective=ir.ArrayAttr.get([]), ) mgpu_dialect.async_load( - source=b_gmem_ref, - destination=b_smem_ref, + source=rhs_gmem_ref, + destination=rhs_smem_ref, barrier=dialect_barrier, - indices=[zero_i32] * len(test_case.shape_b), - slice_lengths=test_case.shape_b, + indices=[zero_i32] * len(rhs_shape), + slice_lengths=rhs_shape, collective=ir.ArrayAttr.get([]), ) @@ -2889,29 +2791,34 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase): parity, _ = tma_barrier.update_parities(parities) mgpu_dialect.wait(dialect_barrier, parity) - # SMEM -> Registers - a_operand = a_smem_ref - zero_index = arith.constant(ir.IndexType.get(), 0) - if test_case.load_a_in_registers: - a_vector_type = ir.VectorType.get(test_case.shape_a, ab_elt_type) - zero_vector_indices = [zero_index] * len(test_case.shape_a) - a_operand = vector.load(a_vector_type, a_smem_ref, zero_vector_indices) - # Computation shape_result = ir.MemRefType(result_gmem_ref.type).shape result_elt_type = ir.MemRefType(result_gmem_ref.type).element_type + acc_elt_type = ir.F32Type.get() + acc_type = ir.VectorType.get(shape_result, acc_elt_type) zero_acc = arith.constant( - result_elt_type, ir.FloatAttr.get(result_elt_type, 0.0) - ) - 
accumulator = vector.splat( - ir.VectorType.get(shape_result, result_elt_type), zero_acc + result_elt_type, ir.FloatAttr.get(acc_elt_type, 0.0) ) + accumulator = vector.splat(acc_type, zero_acc) + + if transpose_lhs: + lhs_smem_ref = utils.memref_transpose(lhs_smem_ref, (1, 0)) + if transpose_rhs: + rhs_smem_ref = utils.memref_transpose(rhs_smem_ref, (1, 0)) + + zero_index = arith.constant(ir.IndexType.get(), 0) + if load_a_in_registers: + # SMEM -> Registers + lhs_ty = ir.VectorType.get(lhs_shape, operand_elt_type) + zero_vector_indices = [zero_index] * len(lhs_shape) + lhs_operand = vector.load(lhs_ty, lhs_smem_ref, zero_vector_indices) + else: + lhs_operand = lhs_smem_ref + result = mgpu_dialect.wgmma( accumulator, - a_operand, - b_smem_ref, - transpose_a=test_case.transpose_a, - transpose_b=test_case.transpose_b, + lhs_operand, + rhs_smem_ref, ) nvvm.wgmma_commit_group_sync_aligned() @@ -2929,38 +2836,41 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase): ) nvvm.cp_async_bulk_wait_group(0) - abtype = jnp.bfloat16 + operand_type = jnp.bfloat16 acctype = jnp.float32 - a_jax_shape = jax.ShapeDtypeStruct(test_case.shape_a, abtype) - b_jax_shape = jax.ShapeDtypeStruct(test_case.shape_b, abtype) - result_jax_shape = jax.ShapeDtypeStruct(test_case.shape_res, acctype) + lhs_jax_shape = jax.ShapeDtypeStruct(lhs_shape, operand_type) + rhs_jax_shape = jax.ShapeDtypeStruct(rhs_shape, operand_type) + result_jax_shape = jax.ShapeDtypeStruct(out_shape, acctype) kernel = mgpu.as_gpu_kernel( matmul, grid=(1, 1, 1), block=(128, 1, 1), - in_shape=(a_jax_shape, b_jax_shape), + in_shape=(lhs_jax_shape, rhs_jax_shape), out_shape=result_jax_shape, smem_scratch_shape=[ - a_jax_shape, - b_jax_shape, + lhs_jax_shape, + rhs_jax_shape, result_jax_shape, core.TMABarrier(1), ], thread_semantics=mgpu.ThreadSemantics.Warpgroup, ) - x = self.prng.uniform(-1, 1, test_case.shape_a).astype(abtype) - y = self.prng.uniform(-1, 1, test_case.shape_b).astype(abtype) + prng_key = jax.random.key(1234) + k0, k1 = jax.random.split(prng_key, 2) + + x = jax.random.randint(k0, lhs_shape, 0, 2).astype(operand_type) + y = jax.random.randint(k1, rhs_shape, 0, 2).astype(operand_type) transpose = lambda x, t: x.T if t else x self.assertArraysAllClose( jax.jit(kernel)(x, y), np.matmul( - transpose(x.reshape(test_case.shape_a), test_case.transpose_a), - transpose(y.reshape(test_case.shape_b), test_case.transpose_b), + transpose(x, transpose_lhs), + transpose(y, transpose_rhs) ), - atol=1e-5, - rtol=1e-5, + atol=0, + rtol=0, ) From 47e8effdcea5c17dd9f974f1020cfd6bf4630f76 Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Tue, 18 Mar 2025 10:59:17 -0700 Subject: [PATCH 08/20] Adds option to initialize buffers to NaNs or zeros in TPU interpret mode. --- jax/_src/pallas/mosaic/interpret.py | 44 +++++++++++++++++------ tests/pallas/tpu_pallas_interpret_test.py | 31 ++++++++++++++++ 2 files changed, 65 insertions(+), 10 deletions(-) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index e92de91f4..1ad7be815 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -83,10 +83,15 @@ class TPUInterpretParams: replaced with arrays all of `jnp.inf`. Additionaly any floating point operands to any operation will be replaced with (arrays of) `jnp.inf`. Default: False. + uninitialized_memory: If "nan", allocated buffers are initialized to + to contain all NaNs (or to their maximum possible value for integers). 
+ If "zero", allocated buffers are initialized to all zeros. + Default: "nan". """ dma_execution_mode: Literal["eager", "on_wait"] = "on_wait" detect_races: bool = False skip_floating_point_ops: bool = False + uninitialized_memory: Literal["nan", "zero"] = "nan" VectorClock = np.ndarray @@ -1114,7 +1119,8 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params): jax.ShapeDtypeStruct((), jnp.int16), device_id, TPU_MEMORY_SPACE_IDXS[v.aval.memory_space], - primitives.uninitialized_value(v.aval.shape, v.aval.dtype), + _uninitialized_value( + v.aval.shape, v.aval.dtype, interpret_params), ordered=True)) out = _interpret(eqn.params['jaxpr'], *deferred_invals(), *allocs) @@ -1279,16 +1285,19 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params): def _initialize_output_vals( block_mappings_output: Iterable[BlockMapping], - input_args, input_output_aliases) -> Sequence[jax.Array]: + input_args, input_output_aliases, + interpret_params: TPUInterpretParams, +) -> Sequence[jax.Array]: oi_map = {v: k for k, v in input_output_aliases} output_vals = [] for i, bm in enumerate(block_mappings_output): if i in oi_map: output_vals.append(input_args[oi_map[i]]) else: - output_vals.append(primitives.uninitialized_value( + output_vals.append(_uninitialized_value( bm.array_shape_dtype.shape, - bm.array_shape_dtype.dtype)) + bm.array_shape_dtype.dtype, + interpret_params)) return output_vals def _compute_start_indices(block_mapping, loop_idx, *args): @@ -1319,7 +1328,20 @@ def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing): dtype=np.bool_)]) return lax.squeeze(output, squeeze_dims) -def _pad_to_block_dimension(value, block_shape): +def _uninitialized_value(shape, dtype, interpret_params): + if interpret_params.uninitialized_memory == 'nan': + if jnp.issubdtype(dtype, jnp.floating): + return jnp.full(shape, jnp.nan, dtype) + elif jnp.issubdtype(dtype, jnp.integer): + return jnp.full(shape, jnp.iinfo(dtype).max, dtype) + elif jnp.issubdtype(dtype, jnp.bool): + return jnp.full(shape, False, dtype) + if interpret_params.uninitialized_memory == 'zero': + return jnp.full(shape, 0, dtype) + raise NotImplementedError( + interpret_params.uninitialized_memory + ' + ' + str(dtype)) + +def _pad_to_block_dimension(value, block_shape, interpret_params): """Pads values so the shape evenly divides into block dimensions. 
For example, if values has a shape of (33, 2, 5) with a block_shape of @@ -1338,7 +1360,7 @@ def _pad_to_block_dimension(value, block_shape): ) if padded_shape != value.shape: pad_width = tuple((0, a-b) for a, b in zip(padded_shape, value.shape)) - pad_value = primitives.uninitialized_value(shape=(), dtype=value.dtype) + pad_value = _uninitialized_value((), value.dtype, interpret_params) value = jnp.pad(value, pad_width, constant_values=pad_value) return value @@ -1397,7 +1419,7 @@ def interpret_pallas_call( ] num_inputs = grid_mapping.num_inputs input_args = [ - _pad_to_block_dimension(a, bs) + _pad_to_block_dimension(a, bs, interpret_params) for a, bs in zip(input_args, block_shapes[:num_inputs]) ] @@ -1407,11 +1429,12 @@ def interpret_pallas_call( output_vals = _initialize_output_vals( grid_mapping.block_mappings_output, scalars + input_args, - input_output_aliases) + input_output_aliases, + interpret_params) num_outputs = grid_mapping.num_outputs output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs] for out_val, bs in zip(output_vals, output_block_shapes): - padded_val = _pad_to_block_dimension(out_val, bs) + padded_val = _pad_to_block_dimension(out_val, bs, interpret_params) output_buffer_shapes.append(padded_val.shape) output_buffer_ids.append(callback.io_callback( _allocate_buffer, @@ -1466,7 +1489,8 @@ def interpret_pallas_call( jax.ShapeDtypeStruct((), jnp.int16), device_id, TPU_MEMORY_SPACE_IDXS[var.aval.memory_space], - primitives.uninitialized_value(var.aval.shape, var.aval.dtype), + _uninitialized_value( + var.aval.shape, var.aval.dtype, interpret_params), ordered=True)) _, input_ids, kernel_output_ids, _ = split_list( diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index 71e91a697..bc589855b 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -156,5 +156,36 @@ class InterpretTest(jtu.JaxTestCase): lowered = jax.jit(matmul).lower(x, y).as_text(dialect="stablehlo") self.assertNotIn("dot_general", lowered) + @parameterized.parameters('nan', 'zero') + def test_uninitialized_memory(self, uninitialized_memory): + def kernel(o1_ref, o2_ref, o3_ref, t1_ref, t2_ref): + o1_ref[...] = t1_ref[...] + o2_ref[...] = t2_ref[...] 
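+      # o3_ref is intentionally left unwritten, and t1_ref/t2_ref are scratch
+      # buffers that are never initialized, so the outputs checked below
+      # expose the configured fill value (NaN / iinfo max for "nan", zeros
+      # for "zero").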
+ + x, y, z = pl.pallas_call( + kernel, + out_shape=[ + jax.ShapeDtypeStruct((8, 128), jnp.bfloat16), + jax.ShapeDtypeStruct((8, 128), jnp.int16), + jax.ShapeDtypeStruct((8, 128), jnp.float32), + ], + in_specs=[], + scratch_shapes=[ + pltpu.VMEM((8, 128), jnp.bfloat16), + pltpu.VMEM((8, 128), jnp.int16), + ], + interpret=mosaic_interpret.TPUInterpretParams( + uninitialized_memory=uninitialized_memory), + )() + if uninitialized_memory == 'nan': + self.assertTrue(jnp.isnan(x).all()) + np.testing.assert_equal(np.array(y), 32767) + self.assertTrue(jnp.isnan(z).all()) + if uninitialized_memory == 'zero': + np.testing.assert_equal(np.array(x), 0) + np.testing.assert_equal(np.array(y), 0) + np.testing.assert_equal(np.array(z), 0) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 942ff38e3676b8936a860700e3b0e5184ba2c5f3 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Tue, 18 Mar 2025 12:50:36 -0700 Subject: [PATCH 09/20] fix to ragged_all_to_all transpose PiperOrigin-RevId: 738110447 --- jax/_src/lax/parallel.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 764e4dcbe..221fe2a9e 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -1301,6 +1301,7 @@ def _ragged_all_to_all_transpose( mask = jax.numpy.cumsum( jax.numpy.zeros(t.shape[0], dtype='int32').at[output_offsets_].set(1)\ .at[output_offsets_ + recv_sizes].add(-1)) + mask = jax.numpy.expand_dims(mask, (*range(1, t.ndim),)) output_t = jax.numpy.where(mask, 0, t) return [operand_t, output_t] + [None] * 4 From 76d9890bb7b2fca21a1061af08a915b9d1b275ef Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 18 Mar 2025 13:00:57 -0700 Subject: [PATCH 10/20] Run the stream annotation tests on 2 devices so that it can be tested in TAP PiperOrigin-RevId: 738113725 --- tests/memories_test.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/memories_test.py b/tests/memories_test.py index a08c5f36c..0ca973c4d 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1664,31 +1664,30 @@ class StreamAnnotationTest(jtu.JaxTestCase): def test_stream_annotation_inside_shmap(self): if not jtu.test_device_matches(["gpu"]): self.skipTest("Stream annotation is only supported on GPU.") - mesh = jtu.create_mesh((2, 2), ('x', 'y')) - s = NamedSharding(mesh, P('x', 'y')) - np_inp = np.ones((8, 8)) + mesh = jtu.create_mesh((2,), ('x',)) + s = NamedSharding(mesh, P('x')) + np_inp = np.ones((8,)) arr1 = jax.device_put(np_inp, s) arr2 = jax.device_put(np_inp, s) @compute_on('gpu_stream:1') @jax.jit def g(x, y): - return x @ y + return x * y @compute_on('gpu_stream:2') @jax.jit def h(x, y): - return x @ y + return x * y def f(x, y): z = g(x, y) w = h(3 * x, 2 * y) return z + w - out = jax.jit(shard_map(f, mesh=mesh, - in_specs=(P('x', 'y'), P('x', 'y')), - out_specs=P('x', 'y')))(arr1, arr2) - self.assertArraysEqual(out, arr1 * 28) + out = jax.jit(shard_map(f, mesh=mesh, in_specs=(P('x'), P('x')), + out_specs=P('x')))(arr1, arr2) + self.assertArraysEqual(out, arr1 * 7) class ActivationOffloadingTest(jtu.JaxTestCase): From 54691b125ab4b6f88c751dae460e4d51f5cf834a Mon Sep 17 00:00:00 2001 From: Gleb Pobudzey Date: Tue, 18 Mar 2025 13:22:10 -0700 Subject: [PATCH 11/20] [Mosaic GPU] Support reads/writes from SMEM to WGMMARowFragLayout arrays. 
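
For reference, the per-thread row ownership that the new
WGMMARowFragLayout.thread_idxs implements can be restated in plain Python
(an illustrative sketch of the indexing in this change, not part of the
patch; tid_in_warpgroup is the thread id modulo WARPGROUP_SIZE == 128):

    def owned_rows(tid_in_warpgroup, m):  # m % 64 == 0
      warp_idx, lane_id = divmod(tid_in_warpgroup, 32)
      row_base = lane_id // 4 + warp_idx * 16
      # Each thread holds two rows (offsets 0 and 8) per 64-row group, and
      # groups of 4 consecutive lanes hold the same rows, which is why the
      # new _store_untiled_wgmma_row only writes from lanes with tid % 4 == 0.
      return [row_base + g + s for g in range(0, m, 64) for s in (0, 8)]
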
PiperOrigin-RevId: 738121106 --- .../mosaic/gpu/fragmented_array.py | 60 ++++++++++++++++++- tests/mosaic/gpu_test.py | 15 +++++ 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 8b8fdaceb..5daed8416 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -387,7 +387,21 @@ class WGMMARowFragLayout: """[m] matrix, where m % 64 == 0.""" def thread_idxs(self, shape): - raise NotImplementedError + index = ir.IndexType.get() + assert len(shape) == 1 + assert shape[0] % 64 == 0 + tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx()) + tid_wg = arith.remui(tid, c(WARPGROUP_SIZE, index)) + warp_idx = arith.divui(tid_wg, c(32, index)) + lane_id = arith.remui(tid_wg, c(32, index)) + row_base = arith.addi( + arith.divui(lane_id, c(4, index)), arith.muli(warp_idx, c(16, index)) + ) + + for row_group in range(0, shape[0], 64): + for row_subgroup in (0, 8): + row = arith.addi(row_base, c(row_group + row_subgroup, index)) + yield (row,) @dataclasses.dataclass(frozen=True) @@ -660,6 +674,31 @@ class FragmentedArray: vecs = [vector.load(vec_ty, ref, vec_idx) for vec_idx in layout.thread_idxs(shape)] return cls(_registers=np.array(vecs), _layout=layout, _is_signed=is_signed) + @classmethod + def load_wgmma_row( + cls, + ref: ir.Value, + *, + is_signed: bool | None = None, + ): + if not ir.MemRefType.isinstance(ref.type): + raise TypeError(ref.type) + + ref_ty = ir.MemRefType(ref.type) + shape = tuple(ref_ty.shape) + if len(shape) != 1: + raise ValueError("WGMMARowFragLayout requires a 1D shape") + if shape[0] % 64: + raise ValueError( + "WGMMARowFragLayout requires shape[0] to be a multiple of 64" + ) + + layout = WGMMARowFragLayout() + registers = [memref.load(ref, [idx]) for (idx,) in layout.thread_idxs(shape)] + registers = np.array(registers).reshape(-1, 2) + return cls(_registers=registers, _layout=layout, _is_signed=is_signed) + + @classmethod def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None): layout = layout or WGSplatFragLayout(shape) @@ -1743,6 +1782,8 @@ class FragmentedArray: ) match self.layout: + case WGMMARowFragLayout(): + self._store_untiled_wgmma_row(ref) case WGSplatFragLayout(): vs_unsupported() self._store_untiled_splat(ref) @@ -1789,6 +1830,23 @@ class FragmentedArray: for idx, reg in zip(idxs, self.registers.flat): vector.store(reg, ref_, idx) + def _store_untiled_wgmma_row(self, ref: ir.Value): + """Stores an array with a WGMMA row layout.""" + assert self.layout == WGMMA_ROW_LAYOUT + index = ir.IndexType.get() + tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx()) + + is_first = arith.cmpi( + arith.CmpIPredicate.eq, arith.remui(tid, c(4, index)), c(0, index) + ) + # Consecutive groups of 4 threads hold the same value in this layout, + # therefore we only need to transfer data from one of them. + with utils.when(is_first): + for (idx,), value in zip( + self.layout.thread_idxs(self.shape), self.registers.flatten() + ): + memref.store(value, ref, [idx]) + def _store_untiled_tiled(self, ref: ir.Value, *, vector_store: bool = True): """Stores an array with a tiled layout. 
Not optimized at the moment.""" if utils.bitwidth(self.mlir_dtype) < 8: diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 1fcd68641..e7bd7fad3 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1946,6 +1946,21 @@ class FragmentedArrayTest(TestCase): )(inp) np.testing.assert_array_equal(inp, result) + @parameterized.product(in_shape=((128,), (64,))) + def test_wgmma_row_load_store_with_layout(self, in_shape): + def kernel(ctx, *args): + gmem_input, gmem_output, (smem_input, smem_output) = args + copy(gmem_input, smem_input) + t = mgpu.FragmentedArray.load_wgmma_row(smem_input) + t.store_untiled(smem_output) + copy(smem_output, gmem_output) + + inp = out = self.prng.uniform(-1, 1, in_shape).astype(jnp.float32) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (inp,), out, [inp, out], + )(inp) + np.testing.assert_array_equal(inp, result) + def test_warp_tree_reduce(self): def kernel(ctx, out, *_): del ctx From 080804c78dcf9695396c298cd3760ea8bda778ee Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 18 Mar 2025 14:50:49 -0700 Subject: [PATCH 12/20] Fix logging_test fails on Linux with NVIDIA Driver only. Some GPU tests in //tests/logging_test fail on Linux with NVIDIA driver only when we use hermetic CUDA (CUDA isn't installed on Linux). Reason: method tsl::Env::Default()->GetExecutablePath()` doesn't work properly with command flag (-c). As result subprocessor couldn't get path to logging_test.py file and convert it to path of runtime where CUDA hermetic libraries are placed. Solution: Save python program to file in runtime directory then run script from the file. PiperOrigin-RevId: 738152663 --- tests/logging_test.py | 148 +++++++++++++++++------------------------- 1 file changed, 61 insertions(+), 87 deletions(-) diff --git a/tests/logging_test.py b/tests/logging_test.py index a83058095..cfe10c5a9 100644 --- a/tests/logging_test.py +++ b/tests/logging_test.py @@ -15,9 +15,9 @@ import contextlib import io import logging +import os import platform import re -import shlex import subprocess import sys import tempfile @@ -78,6 +78,31 @@ def capture_jax_logs(): logger.removeHandler(handler) +# Saves and runs script from the file in order to fix the problem with +# `tsl::Env::Default()->GetExecutablePath()` not working properly with +# command flag. +def _run(program, env_var = {}): + # strip the leading whitespace from the program script + program = re.sub(r"^\s+", "", program, flags=re.MULTILINE) + + with tempfile.NamedTemporaryFile( + mode="w+", encoding="utf-8", suffix=".py", dir=os.getcwd() + ) as f: + f.write(textwrap.dedent(program)) + f.flush() + python = sys.executable + assert "python" in python + if env_var: + env_var.update(os.environ) + else: + env_var = os.environ + + # Make sure C++ logging is at default level for the test process. + p = subprocess.run([python, f.name], env=env_var, capture_output=True, text=True) + + return type("", (object,), { "stdout": p.stdout, "stderr": p.stderr }) + + class LoggingTest(jtu.JaxTestCase): @unittest.skipIf(platform.system() == "Windows", @@ -90,36 +115,25 @@ class LoggingTest(jtu.JaxTestCase): if sys.executable is None: raise self.skipTest("test requires access to python binary") - # Save script in file to fix the problem with - # `tsl::Env::Default()->GetExecutablePath()` not working properly with - # command flag. 
- with tempfile.NamedTemporaryFile( - mode="w+", encoding="utf-8", suffix=".py" - ) as f: - f.write(textwrap.dedent(""" + o = _run(""" import jax jax.device_count() f = jax.jit(lambda x: x + 1) f(1) f(2) jax.numpy.add(1, 1) - """)) - python = sys.executable - assert "python" in python - # Make sure C++ logging is at default level for the test process. - proc = subprocess.run([python, f.name], capture_output=True) + """) - lines = proc.stdout.split(b"\n") - lines.extend(proc.stderr.split(b"\n")) - allowlist = [ - b"", - ( - b"An NVIDIA GPU may be present on this machine, but a" - b" CUDA-enabled jaxlib is not installed. Falling back to cpu." - ), - ] - lines = [l for l in lines if l not in allowlist] - self.assertEmpty(lines) + lines = o.stdout.split("\n") + lines.extend(o.stderr.split("\n")) + allowlist = [ + ( + "An NVIDIA GPU may be present on this machine, but a" + " CUDA-enabled jaxlib is not installed. Falling back to cpu." + ), + ] + lines = [l for l in lines if l in allowlist] + self.assertEmpty(lines) def test_debug_logging(self): # Warmup so we don't get "No GPU/TPU" warning later. @@ -164,19 +178,12 @@ class LoggingTest(jtu.JaxTestCase): if sys.executable is None: raise self.skipTest("test requires access to python binary") - program = """ - import jax # this prints INFO logging from backend imports - jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation) - """ + o = _run(""" + import jax # this prints INFO logging from backend imports + jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation) + """, { "JAX_LOGGING_LEVEL": "INFO" }) - # strip the leading whitespace from the program script - program = re.sub(r"^\s+", "", program, flags=re.MULTILINE) - - # test INFO - cmd = shlex.split(f"env JAX_LOGGING_LEVEL=INFO {sys.executable} -c" - f" '{program}'") - p = subprocess.run(cmd, capture_output=True, text=True) - log_output = p.stderr + log_output = o.stderr info_lines = log_output.split("\n") self.assertGreater(len(info_lines), 0) self.assertIn("INFO", log_output) @@ -194,22 +201,14 @@ class LoggingTest(jtu.JaxTestCase): jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation) """ - # strip the leading whitespace from the program script - program = re.sub(r"^\s+", "", program, flags=re.MULTILINE) + o = _run(program, { "JAX_LOGGING_LEVEL": "DEBUG" }) - # test DEBUG - cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c" - f" '{program}'") - p = subprocess.run(cmd, capture_output=True, text=True) - log_output = p.stderr + log_output = o.stderr self.assertIn("INFO", log_output) self.assertIn("DEBUG", log_output) - # test JAX_DEBUG_MODULES - cmd = shlex.split(f"env JAX_DEBUG_LOG_MODULES=jax {sys.executable} -c" - f" '{program}'") - p = subprocess.run(cmd, capture_output=True, text=True) - log_output = p.stderr + o = _run(program, { "JAX_DEBUG_LOG_MODULES": "jax" }) + log_output = o.stderr self.assertIn("DEBUG", log_output) @jtu.skip_on_devices("tpu") @@ -220,22 +219,15 @@ class LoggingTest(jtu.JaxTestCase): raise self.skipTest("test requires access to python binary") _separator = "---------------------------" - program = f""" + o = _run(f""" import sys import jax # this prints INFO logging from backend imports jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation) jax.config.update("jax_logging_level", None) sys.stderr.write("{_separator}") jax.jit(lambda x: x)(1) # should not log anything now - """ - - # strip the leading whitespace from the program script - program = re.sub(r"^\s+", "", program, 
flags=re.MULTILINE) - - cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c" - f" '{program}'") - p = subprocess.run(cmd, capture_output=True, text=True) - log_output = p.stderr + """, {"JAX_LOGGING_LEVEL": "DEBUG"}) + log_output = o.stderr m = re.search(_separator, log_output) self.assertTrue(m is not None) log_output_verbose = log_output[:m.start()] @@ -252,19 +244,13 @@ class LoggingTest(jtu.JaxTestCase): if sys.executable is None: raise self.skipTest("test requires access to python binary") - program = """ + o = _run(""" import jax # this prints INFO logging from backend imports jax.config.update("jax_debug_log_modules", "jax._src.compiler,jax._src.dispatch") jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation) - """ + """, { "JAX_LOGGING_LEVEL": "DEBUG" }) - # strip the leading whitespace from the program script - program = re.sub(r"^\s+", "", program, flags=re.MULTILINE) - - cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c" - f" '{program}'") - p = subprocess.run(cmd, capture_output=True, text=True) - log_output = p.stderr + log_output = o.stderr self.assertNotEmpty(log_output) log_lines = log_output.strip().split("\n") # only one tracing line should be printed, if there's more than one @@ -285,31 +271,19 @@ class LoggingTest(jtu.JaxTestCase): jax.distributed.initialize("127.0.0.1:12345", num_processes=1, process_id=0) """ - # strip the leading whitespace from the program script - program = re.sub(r"^\s+", "", program, flags=re.MULTILINE) + o = _run(program, { "JAX_LOGGING_LEVEL": "DEBUG" }) + self.assertIn("Initializing CoordinationService", o.stderr) - # verbose logging: DEBUG, VERBOSE - cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c" - f" '{program}'") - p = subprocess.run(cmd, capture_output=True, text=True) - self.assertIn("Initializing CoordinationService", p.stderr) - - cmd = shlex.split(f"env JAX_LOGGING_LEVEL=INFO {sys.executable} -c" - f" '{program}'") - p = subprocess.run(cmd, capture_output=True, text=True) - self.assertIn("Initializing CoordinationService", p.stderr) + o = _run(program, { "JAX_LOGGING_LEVEL": "INFO" }) + self.assertIn("Initializing CoordinationService", o.stderr) # verbose logging: WARNING, None - cmd = shlex.split(f"env JAX_LOGGING_LEVEL=WARNING {sys.executable} -c" - f" '{program}'") - p = subprocess.run(cmd, capture_output=True, text=True) - self.assertNotIn("Initializing CoordinationService", p.stderr) + o = _run(program, { "JAX_LOGGING_LEVEL": "WARNING" }) + self.assertNotIn("Initializing CoordinationService", o.stderr) - cmd = shlex.split(f"{sys.executable} -c" - f" '{program}'") - p = subprocess.run(cmd, capture_output=True, text=True) + o = _run(program) if int(_default_TF_CPP_MIN_LOG_LEVEL) >= 1: - self.assertNotIn("Initializing CoordinationService", p.stderr) + self.assertNotIn("Initializing CoordinationService", o.stderr) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 0fb59747f0a1e6d90bd07b85323ab2a2e8868c91 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Tue, 18 Mar 2025 14:56:23 -0700 Subject: [PATCH 13/20] Support tuples in custom_partitioning. 
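
With this change the partition callback's lower_fn is traced with
info.in_tree.unflatten(tiled_args) instead of the flat list of tiled
arguments, so lower_fn sees the same pytree (e.g. tuple) structure the
caller passed in. Usage sketch, condensed from the new
test_custom_partitioner_pytree_inputs (the partition,
infer_sharding_from_operands and propagate_user_sharding callbacks are
elided here; see the test below for their full definitions):

    @custom_partitioning
    def f(xs):
      x, y, z = xs        # xs arrives as the original tuple
      return x + y + z

    f.def_partition(
        infer_sharding_from_operands=infer_sharding_from_operands,
        partition=partition,
        propagate_user_sharding=propagate_user_sharding,
        sharding_rule='i j, i j, i j -> i j',
    )

    f((a, a, a))          # tuple inputs now reach lower_fn as a tuple
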
PiperOrigin-RevId: 738154413 --- jax/_src/custom_partitioning.py | 2 +- tests/pjit_test.py | 41 +++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py index 658a6f7a2..537407151 100644 --- a/jax/_src/custom_partitioning.py +++ b/jax/_src/custom_partitioning.py @@ -179,7 +179,7 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape, for sharding, s in zip(result_shardings, result_shapes) ] closed_jaxpr = jax.make_jaxpr(lower_fn, axis_env=list(mesh.shape.items()))( - *tiled_args + *info.in_tree.unflatten(tiled_args) ) if ([(o.shape, o.dtype) for o in closed_jaxpr.out_avals] != [(t.shape, t.dtype) for t in tiled_results]): diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 76336920b..293b37a9f 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -1680,6 +1680,47 @@ class CustomPartitionerTest(jtu.JaxTestCase): jit_f = jax.jit(f, in_shardings=s, out_shardings=s) self.assertArraysEqual(x, jit_f(x)) + @jtu.with_mesh([('x', 4), ('y', 2)]) + def test_custom_partitioner_pytree_inputs(self): + self.skip_if_custom_partitioning_not_supported() + + def partition(mesh, arg_shapes, result_shape): + def lower_fn(xs): + x, y, z = xs + return x + y + z + + return ( + mesh, + lower_fn, + arg_shapes[0][0].sharding, + jax.tree.map(lambda x: x.sharding, arg_shapes), + ) + + def infer_sharding_from_operands(mesh, arg_shapes, result_shape): + return arg_shapes[0][0].sharding + + def propagate_user_sharding(mesh, user_shape): + return user_shape.sharding + + @custom_partitioning + def f(xs): + x, y, z = xs + return x + y + z + + f.def_partition( + infer_sharding_from_operands=infer_sharding_from_operands, + partition=partition, + propagate_user_sharding=propagate_user_sharding, + sharding_rule='i j, i j, i j -> i j', + ) + + def f2(a): + return a + f((a, a, a)) + + pjit_f = pjit(f2, in_shardings=(P(None, 'x')), out_shardings=P('x')) + x = np.asarray(np.random.randint(0, 20, (32, 16)), dtype=np.float32) + self.assertArraysEqual(x * 4, pjit_f(x)) + @jtu.pytest_mark_if_available('multiaccelerator') class AutoShardingPjitTest(jtu.JaxTestCase): From 01a110c4c9b4d19f897f073d65d84a319387ccb6 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 18 Mar 2025 15:50:27 -0700 Subject: [PATCH 14/20] Better mosaic lowering for dynamic shapes, extend an interpreter into shape_poly dimexpr and lower them alongside the graph if we are in a dynamic export regime. 
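
The placeholder bookkeeping added here can be sketched as follows
(illustration only: it mirrors the new LoweringDynamicShapeEnv, with
jax_core.is_constant_dim simplified to an int check and the bound check
against DIM_LOWER_BOUND omitted):

    import numpy as np

    DIM_UPPER_BOUND = np.iinfo(np.int32).max  # placeholders count down from here

    class DynShapeEnv:
      def __init__(self):
        self.dim_expr_to_placeholder = {}
        self.placeholder_to_dim_expr = {}

      def to_placeholder(self, dim_expr):
        if isinstance(dim_expr, int):  # constant dims pass through unchanged
          return dim_expr
        if dim_expr not in self.dim_expr_to_placeholder:
          next_val = DIM_UPPER_BOUND - len(self.dim_expr_to_placeholder)
          self.dim_expr_to_placeholder[dim_expr] = next_val
          self.placeholder_to_dim_expr[next_val] = dim_expr  # reverse mapping
        return self.dim_expr_to_placeholder[dim_expr]

Each recorded placeholder is later described by a pair of module
attributes, tpu.dynamic_dimension_mapping_arg_name_<placeholder> and
tpu.dynamic_dimension_mapping_module_<placeholder>, holding the dimension
variable names and an exported StableHLO module that recomputes the
symbolic expression from them (see the Note - On Export Placeholders in
lowering.py).
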
PiperOrigin-RevId: 738171437 --- jax/_src/pallas/core.py | 9 +- jax/_src/pallas/mosaic/lowering.py | 127 +++++++++++++++++++++++++++-- tests/pallas/pallas_test.py | 5 +- 3 files changed, 129 insertions(+), 12 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 5342a6946..206c2a73f 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -35,6 +35,7 @@ from jax._src import linear_util as lu from jax._src import state from jax._src import tree_util from jax._src import util +from jax._src.export._export import export from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.state import discharge as state_discharge @@ -1165,14 +1166,16 @@ jax_core.custom_typechecks[core_map_p] = _core_map_typecheck_rule def lower_as_mlir( - f, *args, dynamic_shapes=False, device=None, **kwargs + f, *args, dynamic_shapes=False, device=None, static_argnames=(), **kwargs ) -> mlir.ir.Module: with pallas_export_experimental(dynamic_shapes): - lowered = jax.jit(f, device=device).lower(*args, **kwargs) - stablehlo = lowered.compiler_ir(dialect="stablehlo") + f = jax.jit(f, device=device, static_argnames=static_argnames) + exported = export(f, platforms=["tpu"])(*args, **kwargs) + stablehlo = exported.mlir_module() return stablehlo # type: ignore[return-value] + _out_shape_to_aval_mapping: dict[ type[Any], Callable[[Any], jax_core.AbstractValue] ] = {} diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 4efb2b276..10b9de748 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -40,6 +40,7 @@ from jax._src import source_info_util from jax._src import state from jax._src import traceback_util from jax._src.cloud_tpu_init import is_cloud_tpu_older_than +from jax._src.export._export import export from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax as lax_internal @@ -89,6 +90,11 @@ BOOL_MEMREF_TYPE = np.dtype('int32') # The value interpreted as a dynamic dimension by MLIR. MLIR_DYNAMIC = -9223372036854775808 +# TODO(mvoz): Find a way to make this a contract we can share with the +# export specialization step in XLA export. +DIM_UPPER_BOUND = np.iinfo(np.int32).max +DIM_LOWER_BOUND = -128 + partial = functools.partial map, unsafe_map = safe_map, map # pylint: disable=redefined-builtin zip, unsafe_zip = safe_zip, zip # pylint: disable=redefined-builtin @@ -102,17 +108,49 @@ class MeshContext: # Note - On Export Placeholders # -# Mosaic uses vector IR, which does not have a concept of dynamic -# dimensions. We need to come up with a way to represent dynamic dimensions in -# vector IR, and so we use placeholders, which are later replaced during -# specialization. +# Since the vector dialect used by Mosaic does not support dynamic shapes, +# we replace all top-level symbolic dimensions with placeholder +# constants (between max(int32) - 128 and max(int32)) and we keep a +# mapping from the placeholder constants to SHLO functions that encode +# the symbolic dimension expression, as a function of the dimension +# variables. +# +# The calling convention of the produced MLIR module is the same as +# regular mosaic module, except we add on two new attributes to the custom call +# *per* intermediary placeholder dimension. 
+# +# The attributes are: +# +# tpu.dynamic_dimension_mapping_arg_name_ +# tpu.dynamic_dimension_mapping_module_ +# +# The first attribute is a comma-separated list of the dimension variables +# that are used to compute the symbolic dimension expression for the +# placeholder. The second attribute is the MLIR module that contains the +# SHLO functions that compute the symbolic dimension expression for the +# placeholder. class LoweringDynamicShapeEnv: - dim_expr_to_placeholder: dict[Any, ir.Value] = {} + dim_expr_to_placeholder: dict[shape_poly._DimExpr, int] = {} + placeholder_to_dim_expr: dict[int, shape_poly._DimExpr] = {} def to_placeholder(self, dim_expr: Any) -> ir.Value: + if jax_core.is_constant_dim(dim_expr): + # avoid ints, these are not dynamic + return dim_expr if dim_expr not in self.dim_expr_to_placeholder: - next_val = np.iinfo(np.int32).max - len(self.dim_expr_to_placeholder) + next_val = DIM_UPPER_BOUND - len(self.dim_expr_to_placeholder) + if next_val < DIM_LOWER_BOUND: + # In practice, even with the largest of programs, we see rarely see + # anything even close to this limit. It is arbitrary, and can be safely + # increased if needed. + raise ValueError( + "Too many dynamic shapes in the input. Mosaic currently only" + " supports up to 128 dynamic dimension values." + ) self.dim_expr_to_placeholder[dim_expr] = next_val + # Reverse mapping - this is consumed to generate a table that is either + # input<>placeholder or intermediary computation<>placeholder. + self.placeholder_to_dim_expr[next_val] = dim_expr return self.dim_expr_to_placeholder[dim_expr] @@ -622,6 +660,7 @@ def lower_jaxpr_to_module( "Pallas TPU requires a libTPU version that's at most a month old" ) debug_info = jaxpr.debug_info + _mosaic_lowering_dynamic_shape_env = None if dynamic_shape_replacement_enabled: _mosaic_lowering_dynamic_shape_env = LoweringDynamicShapeEnv() @@ -663,10 +702,12 @@ def lower_jaxpr_to_module( for_verification=for_verification, forward_compatible=lowering_context.is_forward_compat(), dynamic_shape_replacement_fn=dynamic_shape_replacement_fn, + dynamic_shape_replacement_enabled=dynamic_shape_replacement_enabled, ) m.body.append(func_op) sym_tab.insert(func_op) window_params = [] + static_grid = None grid = mosaic_grid_mapping.grid if grid: for i, bm in enumerate(grid_mapping.block_mappings): @@ -738,7 +779,6 @@ def lower_jaxpr_to_module( ] static_grid = dynamic_shape_replacement_fn(static_grid) func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(static_grid) - func_op.attributes["scalar_prefetch"] = ir.IntegerAttr.get( ir.IntegerType.get_signless(64), len(mosaic_grid_mapping.scalar_prefetch_types)) func_op.attributes["scratch_operands"] = ir.IntegerAttr.get( @@ -746,6 +786,60 @@ def lower_jaxpr_to_module( func_op.attributes["dimension_semantics"] = ( mosaic_grid_mapping.get_dimension_semantics() ) + if dynamic_shape_replacement_enabled: + if _mosaic_lowering_dynamic_shape_env is None: + raise ValueError( + "Dynamic shape env is None, invariant violated. Unreachable?" 
+ ) + + # Now we can use jax to compute the dynamic shape graph + + if static_grid is not None: + grid_vars = [ + _mosaic_lowering_dynamic_shape_env.placeholder_to_dim_expr.get(g, g) + for g in static_grid + ] + else: + grid_vars = [] + + invars = [invar.aval for invar in jaxpr.invars] + # Faux shape for grid, just to get the avals + invars.append(jax.ShapeDtypeStruct(grid_vars, jax.numpy.int32)) + args_dimvars = shape_poly.all_dim_vars(invars) + + # This is dimexpr var -> placeholder value for when we jit the dim expr + env: dict[str, int] = {} + for aval in args_dimvars: + env[aval] = _mosaic_lowering_dynamic_shape_env.to_placeholder(aval) + + for ( + placeholder, + dim_expr, + ) in _mosaic_lowering_dynamic_shape_env.placeholder_to_dim_expr.items(): + top_level_names = list(env.keys()) + if dim_expr not in top_level_names: + jitted_eval = jax.jit( + jax_core.evaluate_shape, + static_argnames=( + "shape", + "dim_vars", + ), + keep_unused=True, + ) + stablehlo = export( + jitted_eval, platforms=[str(jax.devices()[0].platform)] + )( + (dim_expr,), tuple(args_dimvars), *(env[v] for v in args_dimvars) + ).mlir_module() + arg_name = args_dimvars + # See Note - On Export Placeholders for more details. + m.operation.attributes[ + "tpu.dynamic_dimension_mapping_module_" + str(placeholder) + ] = ir.StringAttr.get(str(stablehlo)) + arg_name_str = ",".join(arg_name) + m.operation.attributes[ + "tpu.dynamic_dimension_mapping_arg_name_" + str(placeholder) + ] = ir.StringAttr.get(arg_name_str) return m, mosaic_grid_mapping.get_extra_args() @@ -828,6 +922,7 @@ def lower_jaxpr_to_func( dynamic_shape_replacement_fn: ( Callable[[tuple[jax.DimSize, ...]], tuple[int, ...]] | None ) = None, + dynamic_shape_replacement_enabled: bool = False, ) -> func.FuncOp: num_grid = len(mosaic_grid_mapping.grid_types) num_scalar_prefetch = len(mosaic_grid_mapping.scalar_prefetch_types) @@ -874,6 +969,12 @@ def lower_jaxpr_to_func( ) body_func.__name__ = name body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func) + if dynamic_shape_replacement_enabled: + # Skip verification for dynamic shape replacement - you can potentially + # produce ir like ex: add(x[placeholder_0, placeholder_1], y[128, 128]) + # which is not valid, but we don't care since we'll run the verifier again + # after the dynamic shape replacement pass. 
+ return body.func_op try: body.func_op.verify() except ir.MLIRError as e: @@ -3851,3 +3952,15 @@ def _platform_index_lowering( lowering_rules[jax._src.lax.control_flow.platform_index_p] = _platform_index_lowering + + +def _dim_as_value_lowering(ctx: mlir.LoweringRuleContext, *, dim): + placeholder = ctx.lowering_context.dynamic_shape_replacement_fn((dim,))[0] + return ir_constant( + placeholder, mlir_type=_dtype_to_ir_type(jnp.dtype("int32")) + ) + + +import jax._src.export.shape_poly as shape_poly + +lowering_rules[shape_poly.dim_as_value_p] = _dim_as_value_lowering diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 78c34d404..745c30ba9 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -2501,7 +2501,8 @@ class SymbolicPallasTest(PallasBaseTest): ) assert exported_module is not None self.assertIn( - "tensor, %arg6: tensor, %arg7: tensor", + "%arg0: tensor loc(unknown), %arg1: tensor" + " loc(unknown), %arg2: tensor", str(exported_module), ) x = jax.ShapeDtypeStruct((128, 1024), jax.numpy.float32) @@ -2512,7 +2513,7 @@ class SymbolicPallasTest(PallasBaseTest): ) assert exported_module is not None self.assertIn( - "@sym_matmul(%arg0: tensor<128x1024xf32>, %arg1: tensor<1024x512xf32>", + "call @sym_matmul(%arg0, %arg1)", str(exported_module), ) From 3f91b4b43a422acd6c52526738261ee6c1419f9b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 18 Mar 2025 16:28:00 -0700 Subject: [PATCH 15/20] Move jaxlib/{cuda,rocm}_plugin_extension into jaxlib/{cuda/rocm}/ Move the common jaxlib/gpu_plugin_extension into jaxlib/gpu/ Cleanup only, no functional changes intended. PiperOrigin-RevId: 738183402 --- jax/_src/numpy/lax_numpy.py | 2 +- jax_plugins/cuda/__init__.py | 2 +- jax_plugins/rocm/__init__.py | 2 +- jaxlib/BUILD | 59 ---------------------- jaxlib/cuda/BUILD | 18 ++++++- jaxlib/{ => cuda}/cuda_plugin_extension.cc | 2 +- jaxlib/gpu/BUILD | 29 +++++++++++ jaxlib/{ => gpu}/gpu_plugin_extension.cc | 2 +- jaxlib/{ => gpu}/gpu_plugin_extension.h | 6 +-- jaxlib/rocm/BUILD | 16 +++++- jaxlib/{ => rocm}/rocm_plugin_extension.cc | 2 +- jaxlib/tools/BUILD.bazel | 6 +-- jaxlib/tools/build_gpu_kernels_wheel.py | 4 +- 13 files changed, 75 insertions(+), 75 deletions(-) rename jaxlib/{ => cuda}/cuda_plugin_extension.cc (97%) rename jaxlib/{ => gpu}/gpu_plugin_extension.cc (99%) rename jaxlib/{ => gpu}/gpu_plugin_extension.h (85%) rename jaxlib/{ => rocm}/rocm_plugin_extension.cc (98%) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index b2b828cf3..96efc4806 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -71,7 +71,7 @@ import numpy as np export = set_module('jax.numpy') -for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib']: +for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib.cuda']: try: cuda_plugin_extension = importlib.import_module( f'{pkg_name}.cuda_plugin_extension' diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py index 68281f4f3..f6540e986 100644 --- a/jax_plugins/cuda/__init__.py +++ b/jax_plugins/cuda/__init__.py @@ -24,7 +24,7 @@ import jax._src.xla_bridge as xb # cuda_plugin_extension locates inside jaxlib. `jaxlib` is for testing without # preinstalled jax cuda plugin packages. 
-for pkg_name in ['jax_cuda12_plugin', 'jaxlib']: +for pkg_name in ['jax_cuda12_plugin', 'jaxlib.cuda']: try: cuda_plugin_extension = importlib.import_module( f'{pkg_name}.cuda_plugin_extension' diff --git a/jax_plugins/rocm/__init__.py b/jax_plugins/rocm/__init__.py index b16806e39..c48a681bf 100644 --- a/jax_plugins/rocm/__init__.py +++ b/jax_plugins/rocm/__init__.py @@ -23,7 +23,7 @@ import jax._src.xla_bridge as xb # rocm_plugin_extension locates inside jaxlib. `jaxlib` is for testing without # preinstalled jax rocm plugin packages. -for pkg_name in ['jax_rocm60_plugin', 'jaxlib']: +for pkg_name in ['jax_rocm60_plugin', 'jaxlib.cuda']: try: rocm_plugin_extension = importlib.import_module( f'{pkg_name}.rocm_plugin_extension' diff --git a/jaxlib/BUILD b/jaxlib/BUILD index a61bf7c88..a35eabc9a 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -222,62 +222,3 @@ nanobind_extension( "@xla//third_party/python_runtime:headers", ], ) - -cc_library( - name = "gpu_plugin_extension", - srcs = ["gpu_plugin_extension.cc"], - hdrs = ["gpu_plugin_extension.h"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - deps = [ - ":kernel_nanobind_helpers", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@nanobind", - "@xla//xla:util", - "@xla//xla/ffi/api:c_api", - "@xla//xla/pjrt:status_casters", - "@xla//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs", - "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs", - "@xla//xla/pjrt/c:pjrt_c_api_hdrs", - "@xla//xla/pjrt/c:pjrt_c_api_helpers", - "@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs", - "@xla//xla/python:py_client_gpu", - "@xla//xla/tsl/python/lib/core:numpy", - ], -) - -nanobind_extension( - name = "cuda_plugin_extension", - srcs = ["cuda_plugin_extension.cc"], - module_name = "cuda_plugin_extension", - deps = [ - ":gpu_plugin_extension", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@local_config_cuda//cuda:cuda_headers", - "@nanobind", - "@xla//xla/pjrt:status_casters", - "@xla//xla/tsl/cuda:cublas", - "@xla//xla/tsl/cuda:cudart", - ], -) - -nanobind_extension( - name = "rocm_plugin_extension", - srcs = ["rocm_plugin_extension.cc"], - module_name = "rocm_plugin_extension", - deps = [ - ":gpu_plugin_extension", - "@com_google_absl//absl/log", - "@com_google_absl//absl/strings", - "@local_config_rocm//rocm:hip", - "@local_config_rocm//rocm:rocm_headers", - "@nanobind", - ], -) diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 9a0315266..a9bd35b77 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -657,6 +657,22 @@ py_library( ], ) +nanobind_extension( + name = "cuda_plugin_extension", + srcs = ["cuda_plugin_extension.cc"], + module_name = "cuda_plugin_extension", + deps = [ + "//jaxlib/gpu:gpu_plugin_extension", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@local_config_cuda//cuda:cuda_headers", + "@nanobind", + "@xla//xla/pjrt:status_casters", + "@xla//xla/tsl/cuda:cublas", + "@xla//xla/tsl/cuda:cudart", + ], +) + # We cannot nest select and if_cuda_is_configured so we introduce # a standalone py_library target. py_library( @@ -664,6 +680,6 @@ py_library( # `if_cuda_is_configured` will default to `[]`. 
deps = if_cuda_is_configured([ ":cuda_gpu_support", - "//jaxlib:cuda_plugin_extension", + ":cuda_plugin_extension", ]), ) diff --git a/jaxlib/cuda_plugin_extension.cc b/jaxlib/cuda/cuda_plugin_extension.cc similarity index 97% rename from jaxlib/cuda_plugin_extension.cc rename to jaxlib/cuda/cuda_plugin_extension.cc index 34cf462d6..8d8514bd2 100644 --- a/jaxlib/cuda_plugin_extension.cc +++ b/jaxlib/cuda/cuda_plugin_extension.cc @@ -20,7 +20,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "third_party/gpus/cuda/include/cuda.h" -#include "jaxlib/gpu_plugin_extension.h" +#include "jaxlib/gpu/gpu_plugin_extension.h" #include "xla/pjrt/status_casters.h" namespace nb = nanobind; diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index abaed291a..b5292746d 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -90,3 +90,32 @@ xla_py_proto_library( visibility = jax_visibility("triton_proto_py_users"), deps = [":triton_proto"], ) + +cc_library( + name = "gpu_plugin_extension", + srcs = ["gpu_plugin_extension.cc"], + hdrs = ["gpu_plugin_extension.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "//jaxlib:kernel_nanobind_helpers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@nanobind", + "@xla//xla:util", + "@xla//xla/ffi/api:c_api", + "@xla//xla/pjrt:status_casters", + "@xla//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_helpers", + "@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs", + "@xla//xla/python:py_client_gpu", + "@xla//xla/tsl/python/lib/core:numpy", + ], +) diff --git a/jaxlib/gpu_plugin_extension.cc b/jaxlib/gpu/gpu_plugin_extension.cc similarity index 99% rename from jaxlib/gpu_plugin_extension.cc rename to jaxlib/gpu/gpu_plugin_extension.cc index d666ef6cc..b56cb8337 100644 --- a/jaxlib/gpu_plugin_extension.cc +++ b/jaxlib/gpu/gpu_plugin_extension.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/gpu_plugin_extension.h" +#include "jaxlib/gpu/gpu_plugin_extension.h" #include #include diff --git a/jaxlib/gpu_plugin_extension.h b/jaxlib/gpu/gpu_plugin_extension.h similarity index 85% rename from jaxlib/gpu_plugin_extension.h rename to jaxlib/gpu/gpu_plugin_extension.h index ae8cd73db..70c74454e 100644 --- a/jaxlib/gpu_plugin_extension.h +++ b/jaxlib/gpu/gpu_plugin_extension.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef JAXLIB_GPU_PLUGIN_EXTENSION_H_ -#define JAXLIB_GPU_PLUGIN_EXTENSION_H_ +#ifndef JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_ +#define JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_ #include "nanobind/nanobind.h" @@ -24,4 +24,4 @@ void BuildGpuPluginExtension(nanobind::module_& m); } // namespace xla -#endif // JAXLIB_GPU_PLUGIN_EXTENSION_H_ +#endif // JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_ diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 9774708ad..9a25a795f 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -555,11 +555,25 @@ py_library( ], ) +nanobind_extension( + name = "rocm_plugin_extension", + srcs = ["rocm_plugin_extension.cc"], + module_name = "rocm_plugin_extension", + deps = [ + "//jaxlib/gpu:gpu_plugin_extension", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@local_config_rocm//rocm:hip", + "@local_config_rocm//rocm:rocm_headers", + "@nanobind", + ], +) + py_library( name = "gpu_only_test_deps", # `if_rocm_is_configured` will default to `[]`. deps = if_rocm_is_configured([ ":rocm_gpu_support", - "//jaxlib:rocm_plugin_extension", + ":rocm_plugin_extension", ]), ) diff --git a/jaxlib/rocm_plugin_extension.cc b/jaxlib/rocm/rocm_plugin_extension.cc similarity index 98% rename from jaxlib/rocm_plugin_extension.cc rename to jaxlib/rocm/rocm_plugin_extension.cc index f28b5c9b4..1dd1f1943 100644 --- a/jaxlib/rocm_plugin_extension.cc +++ b/jaxlib/rocm/rocm_plugin_extension.cc @@ -20,7 +20,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "rocm/include/hip/hip_runtime.h" -#include "jaxlib/gpu_plugin_extension.h" +#include "jaxlib/gpu/gpu_plugin_extension.h" namespace nb = nanobind; diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 5b24d2359..afa5866e2 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -143,16 +143,16 @@ py_binary( data = [ "LICENSE.txt", ] + if_cuda([ - "//jaxlib/mosaic/gpu:mosaic_gpu", - "//jaxlib:cuda_plugin_extension", "//jaxlib:version", + "//jaxlib/mosaic/gpu:mosaic_gpu", + "//jaxlib/cuda:cuda_plugin_extension", "//jaxlib/cuda:cuda_gpu_support", "//jax_plugins/cuda:plugin_pyproject.toml", "//jax_plugins/cuda:plugin_setup.py", "@local_config_cuda//cuda:cuda-nvvm", ]) + if_rocm([ - "//jaxlib:rocm_plugin_extension", "//jaxlib:version", + "//jaxlib/rocm:rocm_plugin_extension", "//jaxlib/rocm:rocm_gpu_support", "//jax_plugins/rocm:plugin_pyproject.toml", "//jax_plugins/rocm:plugin_setup.py", diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index 09a55d3c3..2f81eacbd 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -110,7 +110,7 @@ def prepare_wheel_cuda( f"__main__/jaxlib/cuda/_triton.{pyext}", f"__main__/jaxlib/cuda/_hybrid.{pyext}", f"__main__/jaxlib/cuda/_versions.{pyext}", - f"__main__/jaxlib/cuda_plugin_extension.{pyext}", + f"__main__/jaxlib/cuda/cuda_plugin_extension.{pyext}", f"__main__/jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}", "__main__/jaxlib/mosaic/gpu/libmosaic_gpu_runtime.so", "__main__/jaxlib/version.py", @@ -148,7 +148,7 @@ def prepare_wheel_rocm( f"__main__/jaxlib/rocm/_hybrid.{pyext}", f"__main__/jaxlib/rocm/_rnn.{pyext}", f"__main__/jaxlib/rocm/_triton.{pyext}", - f"__main__/jaxlib/rocm_plugin_extension.{pyext}", + f"__main__/jaxlib/rocm/rocm_plugin_extension.{pyext}", "__main__/jaxlib/version.py", ], ) From 663ef7ae0120fe0b91cb32bb0ad8b1ae5b847f12 Mon Sep 
17 00:00:00 2001 From: Yash Katariya Date: Tue, 18 Mar 2025 16:56:50 -0700 Subject: [PATCH 16/20] Check the type of mesh in `use_abstract_mesh` and `use_concrete_mesh` PiperOrigin-RevId: 738190879 --- jax/_src/array.py | 4 ++-- jax/_src/mesh.py | 12 ++++-------- jax/_src/sharding_impls.py | 18 ++++++++++++++++-- 3 files changed, 22 insertions(+), 12 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index 2f10d8de8..b0793d2c3 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -33,7 +33,6 @@ from jax._src import errors from jax._src import profiler from jax._src import util from jax._src import xla_bridge -from jax._src.mesh import use_concrete_mesh from jax._src.interpreters import mlir from jax._src.interpreters import pxla from jax._src.interpreters import xla @@ -43,7 +42,8 @@ from jax._src.lib import xla_extension as xe from jax._src.sharding import Sharding from jax._src.sharding_impls import ( PmapSharding, SingleDeviceSharding, - device_replica_id_map, hashed_index, num_addressable_indices, local_to_global_shape) # pyformat: disable + device_replica_id_map, hashed_index, num_addressable_indices, + local_to_global_shape, use_concrete_mesh) # pyformat: disable from jax._src.typing import ArrayLike, DLDeviceType, DTypeLike from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method, cache import numpy as np diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 4cb8ba0af..b490febf7 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -543,6 +543,10 @@ class UseAbstractMeshContextManager: __slots__ = ['mesh', 'prev'] def __init__(self, mesh: AbstractMesh): + if not isinstance(mesh, AbstractMesh): + raise ValueError( + "Expected mesh of type `jax.sharding.AbstractMesh`. Got type:" + f" {type(mesh)}") self.mesh = mesh def __enter__(self): @@ -557,13 +561,5 @@ def get_abstract_mesh(): val = jax_config.abstract_mesh_context_manager.value return empty_abstract_mesh if val is None else val -@contextlib.contextmanager -def use_concrete_mesh(mesh: Mesh | None): - prev_val = jax_config.device_context.swap_local(mesh) - try: - yield - finally: - jax_config.device_context.set_local(prev_val) - def get_concrete_mesh() -> Mesh | None: return jax_config.device_context.value diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 51c4ad639..2bbf91378 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -1387,8 +1387,7 @@ def use_mesh(mesh: mesh_lib.Mesh): # if not core.trace_state_clean(): # raise ValueError('`use_mesh` can only be used outside of `jax.jit`') - with (mesh_lib.use_abstract_mesh(mesh.abstract_mesh), - mesh_lib.use_concrete_mesh(mesh)): + with mesh_lib.use_abstract_mesh(mesh.abstract_mesh), use_concrete_mesh(mesh): yield def set_mesh(mesh: mesh_lib.Mesh | None) -> mesh_lib.Mesh | None: @@ -1408,3 +1407,18 @@ def set_mesh(mesh: mesh_lib.Mesh | None) -> mesh_lib.Mesh | None: prev_mesh = config.device_context.get_global() config.device_context.set_global(mesh) return prev_mesh + +@contextlib.contextmanager +def use_concrete_mesh(mesh: mesh_lib.Mesh | None): + if mesh is not None and not isinstance(mesh, mesh_lib.Mesh): + raise ValueError( + f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}") + # TODO(yashkatariya): Enable this. 
+ # if not core.trace_state_clean(): + # raise ValueError('`use_concrete_mesh` can only be used outside of `jax.jit`.') + + prev_val = config.device_context.swap_local(mesh) + try: + yield + finally: + config.device_context.set_local(prev_val) From 8c7a55ea82e99e96f27ebdec8a4251bdcd43c110 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 18 Mar 2025 18:32:28 -0700 Subject: [PATCH 17/20] Update XLA dependency to use revision http://github.com/openxla/xla/commit/df971129bd82e381954da0185b534220e21798a4. PiperOrigin-RevId: 738213047 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 9f2f77500..73bf2eb38 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "3bb765472122548cc227b8bd2990f00bd533f438" -XLA_SHA256 = "72126aac7602153aee985ca20f73d11c39e3ba9cfb8027492951e787559d0497" +XLA_COMMIT = "df971129bd82e381954da0185b534220e21798a4" +XLA_SHA256 = "11e9a568320cf7e7d61819620fd369927527ecefb68d5d1154b1521456bbdb72" def repo(): tf_http_archive( From 4d715753c45fbc09e02cbb1a5e254e364ee9b896 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Tue, 18 Mar 2025 18:41:35 -0700 Subject: [PATCH 18/20] Make sure to DCE read effects PiperOrigin-RevId: 738215055 --- jax/_src/interpreters/partial_eval.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 95c09ae94..07c516fd9 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -41,7 +41,7 @@ from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval, JaxprEqn, Primitive, ShapedArray, DShapedArray, mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx, InputType, OutputType, get_referent, JaxprEqnContext) -from jax._src.state.types import AbstractRef +from jax._src.state.types import AbstractRef, ReadEffect from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_flatten, tree_structure) from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, @@ -1423,7 +1423,8 @@ def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool], def has_effects(eqn: JaxprEqn) -> bool: - effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)} + effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect) + and not isinstance(e, ReadEffect)} return bool(effs) From e949effcda6ccc806b5ca00cd3d7bf27927a3447 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Tue, 18 Mar 2025 18:59:50 -0700 Subject: [PATCH 19/20] [Pallas/Fuser] DCE fusion jaxprs before pulling (to avoid unnecessary computations being staged out in block functions) PiperOrigin-RevId: 738218113 --- jax/_src/pallas/fuser/block_spec.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/jax/_src/pallas/fuser/block_spec.py b/jax/_src/pallas/fuser/block_spec.py index d0767aeeb..de0cdd204 100644 --- a/jax/_src/pallas/fuser/block_spec.py +++ b/jax/_src/pallas/fuser/block_spec.py @@ -244,6 +244,13 @@ def pull_block_spec( _unwrap_block_spec_scalar_prefetch, out_block_specs ) flat_block_specs, out_tree = jax.tree.flatten(block_specs_) + jaxpr, used_consts, used_invars = pe.dce_jaxpr_consts( + jaxpr, + used_outputs=[True] * len(jaxpr.outvars), + 
instantiate=True, + ) + assert all(used_invars) + assert all(used_consts) in_block_specs, env, read_usage_env = _pull_block_spec( jaxpr, tuple(flat_block_specs), From f3b7c5cb9ef443f59593f3dddc7bfd56985bfd0f Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 18 Mar 2025 19:22:27 -0700 Subject: [PATCH 20/20] Integrate LLVM at llvm/llvm-project@0230d63b4a8b Updates LLVM usage to match [0230d63b4a8b](https://github.com/llvm/llvm-project/commit/0230d63b4a8b) PiperOrigin-RevId: 738222096 --- jaxlib/mosaic/gpu/custom_call.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 361c839b6..402e099c8 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -143,7 +143,7 @@ mlir::FailureOr GetPassPipeline( mlir::memref::registerMemRefPasses(); mlir::registerConvertToLLVMPass(); mlir::registerGPUPasses(); - mlir::registerGpuLaunchSinkIndexComputations(); + mlir::registerGpuLaunchSinkIndexComputationsPass(); mosaic::gpu::registerGpuLaunchLoweringPass(); mosaic::gpu::registerConvertGpuToLLVMPass(); mosaic::gpu::registerByvalInsertionPass();