From 4a82fe94ded3ae96acbbe354536de3e51815b11f Mon Sep 17 00:00:00 2001 From: Martin Muller Date: Fri, 14 Mar 2025 15:13:13 +0100 Subject: [PATCH 01/34] Use `lax.top_k` instead of `jnp.argsort` in Gumbel top-k trick for weighted sampling without replacement in `jax.random.choice` --- jax/_src/random.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index c91d2f786..4c1436e3f 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -670,8 +670,8 @@ def choice(key: ArrayLike, ind = jnp.searchsorted(p_cuml, r).astype(int) else: # Gumbel top-k trick: https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/ - g = -gumbel(key, (n_inputs,), dtype=p_arr.dtype) - jnp.log(p_arr) - ind = jnp.argsort(g)[:n_draws] + g = gumbel(key, (n_inputs,), dtype=p_arr.dtype) + jnp.log(p_arr) + ind = lax.top_k(g, k=n_draws)[1].astype(int) result = ind if arr.ndim == 0 else jnp.take(arr, ind, axis) return result.reshape(shape if arr.ndim == 0 else
From dadc68b6c1ac490aa62670a3bdb64a1027f48058 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 14 Mar 2025 22:40:41 +0000 Subject: [PATCH 02/34] add experimental lax.optimization_barrier autodiff rules --- jax/_src/lax/lax.py | 10 ++++++++++ tests/lax_test.py | 9 +++++++++ 2 files changed, 19 insertions(+) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 22d703945..76b3fb9ec 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -8422,3 +8422,13 @@ mlir.register_lowering(optimization_barrier_p, def _optimization_barrier_batcher(batched_args, batch_dims, **params): return optimization_barrier_p.bind(*batched_args, **params), batch_dims batching.primitive_batchers[optimization_barrier_p] = _optimization_barrier_batcher + +def _opt_barrier_jvp(primals, tangents): + tangents = [ad.instantiate_zeros(t) for t in tangents] + return optimization_barrier(primals), optimization_barrier(tangents) +ad.primitive_jvps[optimization_barrier_p] = _opt_barrier_jvp + +def _opt_barrier_transpose(cts, *primals): + cts = [ad.instantiate_zeros(ct) for ct in cts] + return optimization_barrier(cts) +ad.primitive_transposes[optimization_barrier_p] = _opt_barrier_transpose diff --git a/tests/lax_test.py b/tests/lax_test.py index 4b67819be..8764caeb2 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -3618,6 +3618,15 @@ class LaxTest(jtu.JaxTestCase): x = lax.optimization_barrier((2, 3)) self.assertEqual((2, 3), x) + def test_optimization_barrier_autodiff(self): + def f(x): + y = 1. * x + x, y = lax.optimization_barrier((x, y)) + z = 2. * x + return y + z + g = jax.grad(f)(5.) # doesn't crash + self.assertAllClose(g, 3., check_dtypes=False) + class LazyConstantTest(jtu.JaxTestCase): def _Check(self, make_const, expected):
From 14cb7453f07ec5285b28d6d49cb8c052117c2aca Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 14 Mar 2025 16:03:45 -0700 Subject: [PATCH 03/34] Add a C++ implementation of a topological sort. This is an exact port of the current Python implementation to C++ for speed. I am being careful not to change the topological order we return in any way in this change, although we may do so in a future change.
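For illustration, a minimal sketch of the contract this port preserves, assuming only that each node exposes a `parents` list. The `Node` helper below is hypothetical (the tests added in this patch define a similar one), and `toposort` is the internal helper in jax/_src/util.py touched by this patch:

    # Usage sketch: toposort returns every parent before any of its children.
    from jax._src import util

    class Node:
      def __init__(self, parents=()):
        self.parents = list(parents)

    a = Node()
    b = Node([a])
    c = Node([a, b])
    order = util.toposort([c])  # here: [a, b, c], since c depends on b and both depend on a
    seen = set()
    for node in order:
      assert all(id(p) in seen for p in node.parents)
      seen.add(id(node))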
PiperOrigin-RevId: 737014989 --- jax/_src/interpreters/partial_eval.py | 2 +- jax/_src/util.py | 90 +++++++++++++++------------ jaxlib/BUILD | 2 + jaxlib/utils.cc | 74 ++++++++++++++++++++++ tests/util_test.py | 44 +++++++++++++ 5 files changed, 171 insertions(+), 41 deletions(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 16e090dd4..95c09ae94 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -797,7 +797,7 @@ def tracers_to_jaxpr( processed_eqn_ids = set() eqns: list[core.JaxprEqn] = [] - for t in toposort([*in_tracers, *out_tracers]): + for t in toposort((*in_tracers, *out_tracers)): r = t.recipe if isinstance(r, JaxprEqnRecipe): # TODO broadcast_in_dim can create a new tracer, not present in parents diff --git a/jax/_src/util.py b/jax/_src/util.py index 408106c12..3e52a2584 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -244,52 +244,62 @@ def curry(f): """ return wraps(f)(partial(partial, f)) -def toposort(end_nodes): - if not end_nodes: return [] - end_nodes = _remove_duplicates(end_nodes) +# TODO(phawkins): make this unconditional after jaxlib 0.5.3 is the minimum. +toposort: Callable[[Iterable[Any]], list[Any]] +if hasattr(jaxlib_utils, "topological_sort"): + toposort = partial(jaxlib_utils.topological_sort, "parents") +else: - child_counts = {} - stack = list(end_nodes) - while stack: - node = stack.pop() - if id(node) in child_counts: - child_counts[id(node)] += 1 - else: - child_counts[id(node)] = 1 - stack.extend(node.parents) - for node in end_nodes: - child_counts[id(node)] -= 1 + def toposort(end_nodes): + if not end_nodes: + return [] + end_nodes = _remove_duplicates(end_nodes) - sorted_nodes = [] - childless_nodes = [node for node in end_nodes if child_counts[id(node)] == 0] - assert childless_nodes - while childless_nodes: - node = childless_nodes.pop() - sorted_nodes.append(node) - for parent in node.parents: - if child_counts[id(parent)] == 1: - childless_nodes.append(parent) + child_counts = {} + stack = list(end_nodes) + while stack: + node = stack.pop() + if id(node) in child_counts: + child_counts[id(node)] += 1 else: - child_counts[id(parent)] -= 1 - sorted_nodes = sorted_nodes[::-1] + child_counts[id(node)] = 1 + stack.extend(node.parents) + for node in end_nodes: + child_counts[id(node)] -= 1 - check_toposort(sorted_nodes) - return sorted_nodes + sorted_nodes = [] + childless_nodes = [ + node for node in end_nodes if child_counts[id(node)] == 0 + ] + assert childless_nodes + while childless_nodes: + node = childless_nodes.pop() + sorted_nodes.append(node) + for parent in node.parents: + if child_counts[id(parent)] == 1: + childless_nodes.append(parent) + else: + child_counts[id(parent)] -= 1 + sorted_nodes = sorted_nodes[::-1] -def check_toposort(nodes): - visited = set() - for node in nodes: - assert all(id(parent) in visited for parent in node.parents) - visited.add(id(node)) + check_toposort(sorted_nodes) + return sorted_nodes + + def check_toposort(nodes): + visited = set() + for node in nodes: + assert all(id(parent) in visited for parent in node.parents) + visited.add(id(node)) + + def _remove_duplicates(node_list): + seen = set() + out = [] + for n in node_list: + if id(n) not in seen: + seen.add(id(n)) + out.append(n) + return out -def _remove_duplicates(node_list): - seen = set() - out = [] - for n in node_list: - if id(n) not in seen: - seen.add(id(n)) - out.append(n) - return out def split_merge(predicate, xs): sides = list(map(predicate, xs)) 
diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 33ea0e07d..93c2b483c 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -214,6 +214,8 @@ nanobind_extension( module_name = "utils", deps = [ "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/synchronization", "@nanobind", diff --git a/jaxlib/utils.cc b/jaxlib/utils.cc index 8b673ef68..bf50b3a52 100644 --- a/jaxlib/utils.cc +++ b/jaxlib/utils.cc @@ -16,9 +16,13 @@ limitations under the License. #include #include +#include +#include #include "nanobind/nanobind.h" #include "absl/cleanup/cleanup.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/synchronization/mutex.h" @@ -293,6 +297,69 @@ PyMethodDef safe_zip_def = { METH_FASTCALL, }; +nb::list TopologicalSort(nb::str parents_attr, + nb::iterable end_nodes_iterable) { + // This is a direct conversion of the original Python implementation. + // More efficient implementations of a topological sort are possible (and + // indeed, easier to write), but changing the choice of topological order + // would break existing tests. + std::vector end_nodes; + absl::flat_hash_set seen; + for (nb::handle n : end_nodes_iterable) { + nb::object node = nb::borrow(n); + if (seen.insert(node.ptr()).second) { + end_nodes.push_back(node); + } + } + + nb::list sorted_nodes; + if (end_nodes.empty()) { + return sorted_nodes; + } + + std::vector stack = end_nodes; + absl::flat_hash_map child_counts; + while (!stack.empty()) { + nb::object node = std::move(stack.back()); + stack.pop_back(); + auto& count = child_counts[node.ptr()]; + if (count == 0) { + for (nb::handle parent : node.attr(parents_attr)) { + stack.push_back(nb::borrow(parent)); + } + } + ++count; + } + + for (nb::handle n : end_nodes) { + child_counts[n.ptr()] -= 1; + } + + std::vector childless_nodes; + childless_nodes.reserve(end_nodes.size()); + for (nb::handle n : end_nodes) { + if (child_counts[n.ptr()] == 0) { + childless_nodes.push_back(nb::borrow(n)); + } + } + + while (!childless_nodes.empty()) { + nb::object node = std::move(childless_nodes.back()); + childless_nodes.pop_back(); + sorted_nodes.append(node); + for (nb::handle parent : node.attr(parents_attr)) { + auto& count = child_counts[parent.ptr()]; + if (count == 1) { + childless_nodes.push_back(nb::borrow(parent)); + } else { + --count; + } + } + } + sorted_nodes.reverse(); + return sorted_nodes; +} + } // namespace NB_MODULE(utils, m) { @@ -304,6 +371,13 @@ NB_MODULE(utils, m) { m.attr("safe_zip") = nb::steal( PyCFunction_NewEx(&safe_zip_def, /*self=*/nullptr, module_name.ptr())); + m.def("topological_sort", &TopologicalSort, nb::arg("parents_attr"), + nb::arg("end_nodes"), + "Computes a topological sort of a graph of objects. parents_attr is " + "the name of the attribute on each object that contains the list of " + "parent objects. end_nodes is an iterable of objects from which we " + "should start a backwards search."); + // Python has no reader-writer lock in its standard library, so we expose // bindings around absl::Mutex. 
nb::class_(m, "Mutex") diff --git a/tests/util_test.py b/tests/util_test.py index cb803d66b..53414dae9 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -201,5 +201,49 @@ class SafeZipTest(jtu.JaxTestCase): util.safe_zip((), range(3)) +class Node: + def __init__(self, parents): + self.parents = parents + + +class TopologicalSortTest(jtu.JaxTestCase): + + def _check_topological_sort(self, nodes, order): + self.assertEqual(sorted(nodes, key=id), sorted(order, key=id)) + visited = set() + for node in nodes: + self.assertTrue(all(id(parent) in visited for parent in node.parents)) + visited.add(id(node)) + + def test_basic(self): + a = Node([]) + b = Node([a]) + c = Node([a]) + d = Node([a, c]) + e = Node([b, c]) + out = util.toposort([a, d, e]) + self._check_topological_sort([a, b, c, d, e], out) + + def test_stick(self): + a = Node([]) + b = Node([a]) + c = Node([b]) + d = Node([c]) + e = Node([d]) + out = util.toposort([e]) + self._check_topological_sort([a, b, c, d, e], out) + + def test_diamonds(self): + a = Node([]) + b = Node([a]) + c = Node([a]) + d = Node([b, c]) + e = Node([d]) + f = Node([d]) + g = Node([e, f]) + out = util.toposort([g]) + self._check_topological_sort([a, b, c, d, e, f, g], out) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 3c0027af3bc0a1e8034c56b12162946a73aa092a Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 14 Mar 2025 18:04:05 -0700 Subject: [PATCH 04/34] mixing modes --- docs/notebooks/explicit-sharding.ipynb | 138 +++++++++++++++++++------ docs/notebooks/explicit-sharding.md | 64 +++++++++--- 2 files changed, 153 insertions(+), 49 deletions(-) diff --git a/docs/notebooks/explicit-sharding.ipynb b/docs/notebooks/explicit-sharding.ipynb index 850de2541..d656e12d4 100644 --- a/docs/notebooks/explicit-sharding.ipynb +++ b/docs/notebooks/explicit-sharding.ipynb @@ -49,13 +49,9 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 7, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "hVi6mApuVw3r", - "outputId": "a64bcbcb-27f8-4c57-8931-8091c9bb8ebf" + "id": "hVi6mApuVw3r" }, "outputs": [], "source": [ @@ -84,13 +80,13 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "mzDIDvj7Vw0k", - "outputId": "417b8453-9c86-4e76-a886-4fa9fdb16434" + "outputId": "09ef049b-461f-47db-bf58-dc10b42fe40a" }, "outputs": [ { @@ -119,13 +115,13 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "IyPx_-IBVwxr", - "outputId": "7d6e4fcb-f6a8-4ed8-ae41-61cf478fa499" + "outputId": "0cd3122f-e579-45d7-868d-e42bb0eacddb" }, "outputs": [ { @@ -141,7 +137,7 @@ "Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32)" ] }, - "execution_count": 3, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -172,13 +168,13 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "NO2ulM_QW7a8", - "outputId": "ea313610-146c-41f4-95b4-c5a5b2b407cb" + "outputId": "d888371b-080e-4bff-be5d-ea56beda3aac" }, "outputs": [ { @@ -208,13 +204,13 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "1-TzmA0AXCAf", - "outputId": "15b33b6d-3915-4725-da6d-4f31fb78fe71" + "outputId": "1c7cc3ac-4b0e-42b7-facc-c706af10d7d2" 
}, "outputs": [ { @@ -256,13 +252,13 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 12, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Gy7ABds3XND3", - "outputId": "4ced73ed-5872-45f3-a4a6-2138f942e01b" + "outputId": "0d72dad2-381a-4e96-f771-40d705da1376" }, "outputs": [ { @@ -297,13 +293,13 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "grCcotr-XQjY", - "outputId": "9a9f381d-5111-4824-9bc0-cb2472cb8e6a" + "outputId": "c2db656c-809f-49a6-c948-629d6420360c" }, "outputs": [ { @@ -324,7 +320,7 @@ " [ 3, 4, 5, 6, 7, 8, 9, 10]], dtype=int32)" ] }, - "execution_count": 7, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -460,13 +456,13 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "fpFEaMBcXsJG", - "outputId": "d28a69eb-260f-4fc5-8f19-2cc64cc70660" + "outputId": "5b84b1d1-d7b2-4e9a-ba98-3dd34a5465ef" }, "outputs": [ { @@ -479,13 +475,6 @@ "We're in auto-sharding mode here. This is the current mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto))\n", "Result type: ShapedArray(int32[4@X,4])\n" ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Result type: ShapedArray(int32[4@X,4])\n" - ] } ], "source": [ @@ -550,13 +539,13 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "geptWrdYX0OM", - "outputId": "c0e62eb1-9f79-4d1c-e708-526165ca680f" + "outputId": "b8c3813f-60bb-4ccf-9da7-73462c57963f" }, "outputs": [ { @@ -588,7 +577,88 @@ { "cell_type": "markdown", "metadata": { - "id": "AQQjzUeGX4P6" + "id": "LZWjgiMZ7uSS" + }, + "source": [ + "You can use the `auto_axes` API to be `Auto` over some mesh axes while being `Explicit` over other. For example:" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "IVzPSkp77uCF", + "outputId": "db80a604-98ac-4343-8677-23729adf7ffc" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mesh inside f: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))\n", + "x.sharding: ShapedArray(float32[4@X,4@Y])\n", + "\n", + "mesh inside g: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Explicit))\n", + "y.sharding inside g: jax.typeof(y) = ShapedArray(float32[4,4@Y])\n", + "\n", + "z.sharding: ShapedArray(float32[4@X,4@Y])\n", + "\n" + ] + }, + { + "data": { + "text/plain": [ + "Array([[ 1. 
, 2.682942 , 2.818595 , 1.28224 ],\n", + " [-0.513605 , -0.9178486 , 0.44116902, 2.3139732 ],\n", + " [ 2.9787164 , 1.824237 , -0.08804226, -0.99998045],\n", + " [-0.07314587, 1.840334 , 2.9812148 , 2.3005757 ]], dtype=float32)" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import functools\n", + "\n", + "@functools.partial(auto_axes, axes='X')\n", + "def g(y):\n", + " print(f'mesh inside g: {get_abstract_mesh()}')\n", + " print(f'y.sharding inside g: {jax.typeof(y) = }', end='\\n\\n')\n", + " return y * 2\n", + "\n", + "@jax.jit\n", + "def f(arr1):\n", + " print(f'mesh inside f: {get_abstract_mesh()}')\n", + " x = jnp.sin(arr1)\n", + " print(f'x.sharding: {jax.typeof(x)}', end='\\n\\n')\n", + "\n", + " z = g(x, out_shardings=P(\"X\", \"Y\"))\n", + "\n", + " print(f'z.sharding: {jax.typeof(z)}', end=\"\\n\\n\")\n", + " return z + 1\n", + "\n", + "some_x = reshard(np.arange(16).reshape(4, 4), P(\"X\", \"Y\"))\n", + "f(some_x)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_3sfJjRq8w9f" + }, + "source": [ + "As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sJcWbfAh7UcO" }, "source": [ "## Concrete array shardings can mention `Auto` mesh axis\n", @@ -606,7 +676,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -708,5 +778,5 @@ } }, "nbformat": 4, - "nbformat_minor": 4 + "nbformat_minor": 0 } diff --git a/docs/notebooks/explicit-sharding.md b/docs/notebooks/explicit-sharding.md index b7368b5eb..7c59a675d 100644 --- a/docs/notebooks/explicit-sharding.md +++ b/docs/notebooks/explicit-sharding.md @@ -50,12 +50,8 @@ expect there to be bugs and unimplemented cases. Please let us know when you find something that doesn't work! ```{code-cell} ipython3 ---- -colab: - base_uri: https://localhost:8080/ -id: hVi6mApuVw3r -outputId: a64bcbcb-27f8-4c57-8931-8091c9bb8ebf ---- +:id: hVi6mApuVw3r + import jax import numpy as np import jax.numpy as jnp @@ -79,7 +75,7 @@ scalar) using `jax.typeof`: colab: base_uri: https://localhost:8080/ id: mzDIDvj7Vw0k -outputId: 417b8453-9c86-4e76-a886-4fa9fdb16434 +outputId: 09ef049b-461f-47db-bf58-dc10b42fe40a --- some_array = np.arange(8) print(f"JAX-level type of some_array: {jax.typeof(some_array)}") @@ -96,7 +92,7 @@ under a jit). colab: base_uri: https://localhost:8080/ id: IyPx_-IBVwxr -outputId: 7d6e4fcb-f6a8-4ed8-ae41-61cf478fa499 +outputId: 0cd3122f-e579-45d7-868d-e42bb0eacddb --- @jax.jit def foo(x): @@ -121,7 +117,7 @@ mesh afterwards then you can use the context manager `jax.sharding.use_mesh` ins colab: base_uri: https://localhost:8080/ id: NO2ulM_QW7a8 -outputId: ea313610-146c-41f4-95b4-c5a5b2b407cb +outputId: d888371b-080e-4bff-be5d-ea56beda3aac --- mesh = jax.make_mesh((2, 4), ("X", "Y"), axis_types=(AxisType.Explicit, AxisType.Explicit)) @@ -139,7 +135,7 @@ Now we can create some sharded arrays using `reshard`: colab: base_uri: https://localhost:8080/ id: 1-TzmA0AXCAf -outputId: 15b33b6d-3915-4725-da6d-4f31fb78fe71 +outputId: 1c7cc3ac-4b0e-42b7-facc-c706af10d7d2 --- replicated_array = np.arange(8).reshape(4, 2) sharded_array = reshard(replicated_array, P("X", None)) @@ -163,7 +159,7 @@ These shardings associated with JAX-level types propagate through operations. 
Fo colab: base_uri: https://localhost:8080/ id: Gy7ABds3XND3 -outputId: 4ced73ed-5872-45f3-a4a6-2138f942e01b +outputId: 0d72dad2-381a-4e96-f771-40d705da1376 --- arg0 = reshard(np.arange(4).reshape(4, 1), P("X", None)) arg1 = reshard(np.arange(8).reshape(1, 8), P(None, "Y")) @@ -184,7 +180,7 @@ We can do the same type querying under a jit: colab: base_uri: https://localhost:8080/ id: grCcotr-XQjY -outputId: 9a9f381d-5111-4824-9bc0-cb2472cb8e6a +outputId: c2db656c-809f-49a6-c948-629d6420360c --- @jax.jit def add_arrays(x, y): @@ -294,7 +290,7 @@ the first axis only, like `f32[4@X, 4]`. You can do this as follows: colab: base_uri: https://localhost:8080/ id: fpFEaMBcXsJG -outputId: d28a69eb-260f-4fc5-8f19-2cc64cc70660 +outputId: 5b84b1d1-d7b2-4e9a-ba98-3dd34a5465ef --- some_x = reshard(np.arange(16).reshape(4, 4), P("X", None)) some_y = reshard(np.arange(16).reshape(4, 4), P(None, "X")) @@ -355,7 +351,7 @@ The current mesh tells us which sharding mode we're in. We can query it with colab: base_uri: https://localhost:8080/ id: geptWrdYX0OM -outputId: c0e62eb1-9f79-4d1c-e708-526165ca680f +outputId: b8c3813f-60bb-4ccf-9da7-73462c57963f --- print(f"Current mesh is: {get_abstract_mesh()}") ``` @@ -369,7 +365,45 @@ sharding mode for each mesh axis. Shardings (on JAX-level types) can only mention _explicit_ mesh axes and collective operations like `psum` can only mention _manual_ mesh axes. -+++ {"id": "AQQjzUeGX4P6"} ++++ {"id": "LZWjgiMZ7uSS"} + +You can use the `auto_axes` API to be `Auto` over some mesh axes while being `Explicit` over other. For example: + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: IVzPSkp77uCF +outputId: db80a604-98ac-4343-8677-23729adf7ffc +--- +import functools + +@functools.partial(auto_axes, axes='X') +def g(y): + print(f'mesh inside g: {get_abstract_mesh()}') + print(f'y.sharding inside g: {jax.typeof(y) = }', end='\n\n') + return y * 2 + +@jax.jit +def f(arr1): + print(f'mesh inside f: {get_abstract_mesh()}') + x = jnp.sin(arr1) + print(f'x.sharding: {jax.typeof(x)}', end='\n\n') + + z = g(x, out_shardings=P("X", "Y")) + + print(f'z.sharding: {jax.typeof(z)}', end="\n\n") + return z + 1 + +some_x = reshard(np.arange(16).reshape(4, 4), P("X", "Y")) +f(some_x) +``` + ++++ {"id": "_3sfJjRq8w9f"} + +As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`. 
+ ++++ {"id": "sJcWbfAh7UcO"} ## Concrete array shardings can mention `Auto` mesh axis From 9b0ace4a1112a8ce8b85b4aeae504919b8b65905 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Fri, 14 Mar 2025 18:57:39 -0700 Subject: [PATCH 05/34] Support error checking in explicit mode PiperOrigin-RevId: 737051146 --- jax/_src/error_check.py | 89 ++++++++++++++++++++++++++++++++++++--- jax/_src/mesh.py | 2 +- tests/error_check_test.py | 19 +++++++++ 3 files changed, 103 insertions(+), 7 deletions(-) diff --git a/jax/_src/error_check.py b/jax/_src/error_check.py index 11e65a7dd..60dc2f76a 100644 --- a/jax/_src/error_check.py +++ b/jax/_src/error_check.py @@ -14,13 +14,17 @@ from __future__ import annotations +from functools import partial import threading import jax from jax._src import core from jax._src import source_info_util from jax._src import traceback_util +import jax._src.mesh as mesh_lib +from jax.experimental.shard_map import shard_map import jax.numpy as jnp +from jax.sharding import NamedSharding, PartitionSpec as P Traceback = source_info_util.Traceback @@ -54,17 +58,61 @@ _error_storage = _ErrorStorage() def _initialize_error_code_ref() -> None: - """Initialize error_code_ref in the current thread.""" + """Initialize error_code_ref in the current thread. + + The size of the error code array is determined by the mesh in the context. In + single-device environment, the array is a scalar. In multi-device + environment, the array has the same shape as the mesh. + """ with core.eval_context(): - error_code = jnp.uint32(_NO_ERROR) + # Get mesh from the context. + mesh = mesh_lib.get_concrete_mesh() + + if mesh is None: # single-device case. + error_code = jnp.uint32(_NO_ERROR) + + else: # multi-device case. + sharding = NamedSharding(mesh, P(*mesh.axis_names)) + error_code = jnp.full( + mesh.axis_sizes, + jnp.uint32(_NO_ERROR), + device=sharding, + ) + _error_storage.ref = core.mutable_array(error_code) -def set_error_if(pred: jax.Array, msg: str) -> None: +class error_checking_context: + """Redefine the error checking state based on the mesh in the context. + + This context manager should be used when starting a multi-device + computation, and whenever the mesh is changed. + + When exiting the context, the error checking state will be reset to the + original state. + """ + + __slots__ = ("old_ref",) + + def __init__(self): + self.old_ref = None + + def __enter__(self): + self.old_ref = _error_storage.ref + _initialize_error_code_ref() + return self + + def __exit__(self, exc_type, exc_value, traceback): + _error_storage.ref = self.old_ref + + +def set_error_if(pred: jax.Array, /, msg: str) -> None: """Set error if any element of pred is true. If the error is already set, the new error will be ignored. It will not override the existing error. + + In auto mode, this function does not work under jit. """ if _error_storage.ref is None: _initialize_error_code_ref() @@ -76,7 +124,32 @@ def set_error_if(pred: jax.Array, msg: str) -> None: new_error_code = jnp.uint32(len(_error_list)) _error_list.append((msg, traceback)) - pred = pred.any() + out_sharding = core.typeof(_error_storage.ref).sharding + in_sharding: NamedSharding = core.typeof(pred).sharding + + if out_sharding.mesh.shape_tuple == (): # single-device case. + pred = pred.any() + else: # multi-device case. + has_auto_axes = mesh_lib.AxisType.Auto in in_sharding.mesh.axis_types + if has_auto_axes: + raise NotImplementedError( + "Error checking in auto mode is not supported yet. Please use" + " explicit mode." 
+ ) + if out_sharding.mesh != in_sharding.mesh: + raise ValueError( + "The error code state and the predicate must be on the same mesh, " + f"but got {out_sharding.mesh} and {in_sharding.mesh} respectively. " + "Please use `with error_checking_context()` to redefine the error " + "code state based on the mesh." + ) + pred = shard_map( + partial(jnp.any, keepdims=True), + mesh=out_sharding.mesh, + in_specs=in_sharding.spec, + out_specs=out_sharding.spec, + )(pred) # perform per-device reduction + error_code = _error_storage.ref[...] should_update = jnp.logical_and(pred, error_code == jnp.uint32(_NO_ERROR)) error_code = jnp.where(should_update, new_error_code, error_code) @@ -93,7 +166,7 @@ def raise_if_error() -> None: if _error_storage.ref is None: # if not initialized, do nothing return - error_code = _error_storage.ref[...] + error_code = _error_storage.ref[...].min() # reduce to a single error code if isinstance(error_code, core.Tracer): raise ValueError( "raise_if_error() should not be called within a traced context, such as" @@ -101,7 +174,11 @@ def raise_if_error() -> None: ) if error_code == jnp.uint32(_NO_ERROR): return - _error_storage.ref[...] = jnp.uint32(_NO_ERROR) + _error_storage.ref[...] = jnp.full( + _error_storage.ref.shape, + jnp.uint32(_NO_ERROR), + device=_error_storage.ref.sharding, + ) # clear the error code msg, traceback = _error_list[error_code] exc = JaxValueError(msg) diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 94e27a2ba..4cb8ba0af 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -565,5 +565,5 @@ def use_concrete_mesh(mesh: Mesh | None): finally: jax_config.device_context.set_local(prev_val) -def get_concrete_mesh(): +def get_concrete_mesh() -> Mesh | None: return jax_config.device_context.value diff --git a/tests/error_check_test.py b/tests/error_check_test.py index 5cdde30b1..b96c62814 100644 --- a/tests/error_check_test.py +++ b/tests/error_check_test.py @@ -20,12 +20,14 @@ from jax._src import config from jax._src import error_check from jax._src import test_util as jtu import jax.numpy as jnp +from jax.sharding import NamedSharding, PartitionSpec as P JaxValueError = error_check.JaxValueError config.parse_flags_with_absl() +jtu.request_cpu_devices(4) @jtu.with_config(jax_check_tracer_leaks=True) @@ -190,6 +192,23 @@ class ErrorCheckTests(jtu.JaxTestCase): ): jax.jit(error_check.raise_if_error)() + @parameterized.product(jit=[True, False]) + @jtu.with_user_mesh((2, 2), ("x", "y")) + def test_error_check_explicit_mode(self, mesh, jit): + def f(x): + error_check.set_error_if(x <= 0, "x must be greater than 0") + return x + 1 + + if jit: + f = jax.jit(f) + + sharding = NamedSharding(mesh, P("x", "y")) + x = jnp.full((4, 4), -1, dtype=jnp.int32, device=sharding) + with error_check.error_checking_context(): + f(x) + with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): + error_check.raise_if_error() + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From f360e1919492c4329ba5badd03164bebeeabdc4d Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 15 Mar 2025 05:09:59 -0700 Subject: [PATCH 06/34] Update XLA dependency to use revision http://github.com/openxla/xla/commit/f52d5e03ce1cf26142a234087dc1d6c3fd919b6f. 
PiperOrigin-RevId: 737143950 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index abae0dfa1..8a1d5c978 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "4c4aa96f9ffec4bb963b50c50192aeab4da9dc4a" -XLA_SHA256 = "c373e52b2f8b4175c69e99e636ad64b3bcf33fb44d1b7ad6ef8f4162c9052af8" +XLA_COMMIT = "f52d5e03ce1cf26142a234087dc1d6c3fd919b6f" +XLA_SHA256 = "55239fb9087e71c0a5504446538bb61a661b007ebdfb81d9a0c8574b2d6b9c1a" def repo(): tf_http_archive( From de8b0564ce88776768f0a1a1fb8be01312a9f345 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Sat, 15 Mar 2025 11:49:51 -0700 Subject: [PATCH 07/34] Better docs for jax.lax add/sub/mul/div --- jax/_src/lax/lax.py | 89 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 83 insertions(+), 6 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 22d703945..6e3405ec9 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1026,24 +1026,101 @@ def clz(x: ArrayLike) -> Array: r"""Elementwise count-leading-zeros.""" return clz_p.bind(x) +@export def add(x: ArrayLike, y: ArrayLike) -> Array: - r"""Elementwise addition: :math:`x + y`.""" + r"""Elementwise addition: :math:`x + y`. + + This function lowers directly to the `stablehlo.add`_ operation. + + Args: + x, y: Input arrays. Must have matching numerical dtypes. If neither + is a scalar, ``x`` and ``y`` must have the same number of dimensions + and be broadcast compatible. + + Returns: + An array of the same dtype as ``x`` and ``y`` containing the sum + of each pair of broadcasted entries. + + See also: + - :func:`jax.numpy.add`: NumPy-style addition supporting inputs + with mixed dtypes and ranks. + + .. _stablehlo.add: https://openxla.org/stablehlo/spec#add + """ return add_p.bind(x, y) +@export def sub(x: ArrayLike, y: ArrayLike) -> Array: - r"""Elementwise subtraction: :math:`x - y`.""" + r"""Elementwise subtraction: :math:`x - y`. + + This function lowers directly to the `stablehlo.subtract`_ operation. + + Args: + x, y: Input arrays. Must have matching numerical dtypes. If neither + is a scalar, ``x`` and ``y`` must have the same number of dimensions + and be broadcast compatible. + + Returns: + An array of the same dtype as ``x`` and ``y`` containing the difference + of each pair of broadcasted entries. + + See also: + - :func:`jax.numpy.subtract`: NumPy-style subtraction supporting + inputs with mixed dtypes and ranks. + + .. _stablehlo.subtract: https://openxla.org/stablehlo/spec#subtract + """ return sub_p.bind(x, y) +@export def mul(x: ArrayLike, y: ArrayLike) -> Array: - r"""Elementwise multiplication: :math:`x \times y`.""" + r"""Elementwise multiplication: :math:`x \times y`. + + This function lowers directly to the `stablehlo.multiply`_ operation. + + Args: + x, y: Input arrays. Must have matching numerical dtypes. If neither + is a scalar, ``x`` and ``y`` must have the same number of dimensions + and be broadcast compatible. + + Returns: + An array of the same dtype as ``x`` and ``y`` containing the product + of each pair of broadcasted entries. + + See also: + - :func:`jax.numpy.multiply`: NumPy-style multiplication supporting + inputs with mixed dtypes and ranks. + + .. 
_stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply + """ return mul_p.bind(x, y) +@export def div(x: ArrayLike, y: ArrayLike) -> Array: r"""Elementwise division: :math:`x \over y`. - Integer division overflow - (division by zero or signed division of INT_SMIN with -1) - produces an implementation defined value. + This function lowers directly to the `stablehlo.divide`_ operation. + + Integer division overflow (division by zero or signed division of + INT_SMIN with -1) produces an implementation defined value. + + Args: + x, y: Input arrays. Must have matching numerical dtypes. If neither + is a scalar, ``x`` and ``y`` must have the same number of dimensions + and be broadcast compatible. + + Returns: + An array of the same dtype as ``x`` and ``y`` containing the quotient + of each pair of broadcasted entries. For integer inputs, any fractional + part is discarded. + + See also: + - :func:`jax.numpy.divide`: NumPy-style true division supporting + inputs with mixed dtypes and ranks. + - :func:`jax.numpy.floor_divide`: NumPy-style floor division supporting + inputs with mixed dtypes and ranks. + + .. _stablehlo.divide: https://openxla.org/stablehlo/spec#divide """ return div_p.bind(x, y)
From 466ef6a132f06d8b6a19b2c21b7d05cc7b39172f Mon Sep 17 00:00:00 2001 From: Joan Puigcerver Date: Sat, 15 Mar 2025 22:57:56 -0700 Subject: [PATCH 08/34] Change the way that batching.spec_types is updated. There's no reason why two custom vmappable types cannot share the same spec_type. However, spec_types was a set, which can cause bugs / exceptions. Suppose that I register two vmappable data_types sharing the same spec_type, and then unregister one of the two. Then, the spec_type is no longer in the set to support the second data_type. Also, an exception will be raised if I try to unregister the two vmappable types (the second call to spec_types.remove). When unregistering a data type, instead of removing its spec_type from the set, we regenerate the set from the remaining vmappable types. PiperOrigin-RevId: 737280270 --- jax/_src/interpreters/batching.py | 7 +++++-- tests/batching_test.py | 26 ++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-)
diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 40dbe0018..03c9a9510 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -322,12 +322,15 @@ vmappables: dict[type, tuple[type, type]] = {} spec_types: set[type] = {JumbleAxis} def unregister_vmappable(data_type: type) -> None: - spec_type, axis_size_type = vmappables.pop(data_type) - spec_types.remove(spec_type) + _, axis_size_type = vmappables.pop(data_type) del to_elt_handlers[data_type] del from_elt_handlers[data_type] if axis_size_type in make_iota_handlers: del make_iota_handlers[axis_size_type] + global spec_types + spec_types = ( + {JumbleAxis} | {spec_type for spec_type, _ in vmappables.values()} + ) def is_vmappable(x: Any) -> bool: return type(x) is Jumble or type(x) in vmappables
diff --git a/tests/batching_test.py b/tests/batching_test.py index bab18ce53..f2a4e8c34 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -1356,6 +1356,32 @@ class VmappableTest(jtu.JaxTestCase): self.assertEqual(ans.names, expected.names) self.assertAllClose(ans.data, expected.data) + def test_types_with_same_spec(self): + # We register NamedArray.
+ batching.register_vmappable(NamedArray, NamedMapSpec, int, + named_to_elt, named_from_elt, None) + + # We then register another type that uses NamedMapSpec as the spec_type too, + # and immediately unregister it. + class Foo: + pass + batching.register_vmappable(Foo, NamedMapSpec, int, + named_to_elt, named_from_elt, None) + batching.unregister_vmappable(Foo) + + # We should still be able to use vmap on NamedArray. + def f(x): + return named_mul(x, x) + + x = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4)) + ans = jax.jit(f)(x) + expected = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4) ** 2) + + self.assertEqual(ans.names, expected.names) + self.assertAllClose(ans.data, expected.data) + + # And unregister NamedArray without exceptions. + batching.unregister_vmappable(NamedArray) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From e8b683aee04b940527d215942e3bdc3a17707caa Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 16 Mar 2025 05:36:58 -0700 Subject: [PATCH 09/34] Update XLA dependency to use revision http://github.com/openxla/xla/commit/936a727db7cefa30027b727b7056b1b5c6064145. PiperOrigin-RevId: 737338103 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 8a1d5c978..e5029d052 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "f52d5e03ce1cf26142a234087dc1d6c3fd919b6f" -XLA_SHA256 = "55239fb9087e71c0a5504446538bb61a661b007ebdfb81d9a0c8574b2d6b9c1a" +XLA_COMMIT = "936a727db7cefa30027b727b7056b1b5c6064145" +XLA_SHA256 = "b27db91ad7c4a1badb57bbcf92f8894ceebb709b2197cfd2e830d121e3d20dc7" def repo(): tf_http_archive( From 2bdd9c879783fd1a9a33ab9d2622e0d2d1c0ab50 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 17 Mar 2025 02:51:30 -0700 Subject: [PATCH 10/34] [Mosaic GPU] Add support for fast WGMMA layout changes after 8- to 16-bit upcast PiperOrigin-RevId: 737542885 --- .../mosaic/gpu/fragmented_array.py | 79 ++++++++++++++----- jax/experimental/mosaic/gpu/utils.py | 22 +++++- tests/mosaic/gpu_test.py | 39 ++++++++- 3 files changed, 117 insertions(+), 23 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 2cdaaad6a..29650e1a4 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -382,21 +382,6 @@ def _tiled_wgmma_layout(shape: tuple[int, ...]): return WGMMA_LAYOUT -def _tiled_wgmma_layout_for_upcast(shape: tuple[int, ...]): - """Returns a tiled layout that is easy to relayout to WGMMA layout after doubling the bitwidth.""" - if len(shape) != 2: - raise ValueError(f"Shape {shape} is not 2D") - if shape[0] % 64 != 0 or shape[1] % 8 != 0: - raise ValueError(f"Shape {shape} is not a multiple of 64x8") - t = Tiling(((64, 16), (16, 16), (8, 16), (4,), (2, 1))) - return TiledLayout( - t, - warp_dim=-9, - lane_dims=(-5, -2, -4), - vector_dim=-3, - ) - - @dataclasses.dataclass(frozen=True) class WGMMARowFragLayout: """[m] matrix, where m % 64 == 0.""" @@ -505,13 +490,43 @@ WGMMA_ROW_LAYOUT = WGMMARowFragLayout() # The tiled layout is equivalent to one described here in PTX documentation: # 
https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d +# In this layout, we partition the 64x8 tiles over 4 warpgroups into 16x8 tiles. +# Then, we further split the 16x8 tiles into 8x8 submatrices which are the unit +# of data that is split across a warp. Since 8*8 = 64, but a warp has only 32 +# threads, we vectorize pairs of elements along columns. +# The assignment of elements to warp lanes is as follows: +# +# 0 0 1 1 2 2 3 3 +# 4 4 5 5 6 6 7 7 +# 8 8 9 9 10 10 11 11 +# 12 12 13 13 14 14 15 15 +# ... WGMMA_LAYOUT = TiledLayout( Tiling(((64, 8), (16, 8), (8, 8), (1, 2))), warp_dim=-8, lane_dims=(-4, -3), vector_dim=-1, ) -# This tiled layout is similar to the one above. Above, each warp stores a 8x8 +# This tiled layout is similar to the WGMMA layout, only the unit at which we +# assign submatrices to warps grows from 8x8 to 8x16. The elements within each +# submatrix are assigned to threads in the following way: +# +# 0 0 0 0 2 2 2 2 1 1 1 1 3 3 3 3 +# 4 4 4 4 6 6 6 6 5 5 5 5 7 7 7 7 +# ... +# +# Our vector length is twice the size of that of WGMMA_LAYOUT, which lets us use +# 32-bit SMEM loads/stores when dealing with 8-bit values. The conversion +# to the WGMMA layout only requires communication between with index differing +# in their 2 bit (i.e. 0 and 1, 2 and 4), so the conversion to WGMMA_LAYOUT +# only requires a single warp shuffle (plus permutes local to each thread). +WGMMA_LAYOUT_UPCAST_2X = TiledLayout( + Tiling(((64, 16), (16, 16), (8, 16), (8,), (4,))), + warp_dim=-8, + lane_dims=(-4, -2, -3), + vector_dim=-1, +) +# This tiled layout is similar to WGMMA_LAYOUT. There, each warp stores a 8x8 # submatrix in the following way (we only show the first 4 rows for brevity): # # 0 0 1 1 2 2 3 3 @@ -697,6 +712,7 @@ class FragmentedArray: At the moment, only conversions from ``WGSplatFragLayout`` are supported. """ i32 = ir.IntegerType.get_signless(32) + c = lambda x: arith.constant(i32, x) if self.layout == new_layout: return self shape = self.shape @@ -707,10 +723,10 @@ class FragmentedArray: ): is_even_row = arith.cmpi( arith.CmpIPredicate.eq, - arith.remui(arith.divui(utils.thread_idx(), c(4, i32)), c(2, i32)), - c(0, i32), + arith.remui(arith.divui(utils.thread_idx(), c(4)), c(2)), + c(0), ) - perm = arith.select(is_even_row, c(0x5410, i32), c(0x3276, i32)) + perm = arith.select(is_even_row, c(0x5410), c(0x3276)) new_regs = [] for reg in self.registers.flat: reg_ty = reg.type @@ -725,6 +741,31 @@ class FragmentedArray: _layout=new_layout, _is_signed=self.is_signed, ) + if len(shape) == 2 and shape[0] % 64 == 0 and shape[1] % 16 == 0: + if ( + self.layout == WGMMA_LAYOUT_UPCAST_2X + and new_layout == WGMMA_LAYOUT + and utils.bytewidth(self.mlir_dtype) == 2 + ): + assert shape[1] % 16 == 0 # Should be implied by the layout + new_registers = np.empty(new_layout.registers_shape(shape), dtype=object) + is_even = arith.cmpi(arith.CmpIPredicate.eq, arith.remui(utils.thread_idx(), c(2)), c(0)) + for idx, reg in np.ndenumerate(self.registers): + assert ir.VectorType(reg.type).shape == [4] + # Each slice is exactly 32-bits (we checked bitwidth == 2 above) + low = utils.vector_slice(reg, slice(0, 2)) + high = utils.vector_slice(reg, slice(2, 4)) + to_exchange = arith.select(is_even, high, low) + # Exchange values between even and odd threads. 
+ exchanged = utils.shfl_bfly(to_exchange, 1) + low = arith.select(is_even, low, exchanged) + high = arith.select(is_even, exchanged, high) + new_registers[(idx[0], idx[1] * 2, *idx[2:-1])] = low + new_registers[(idx[0], idx[1] * 2 + 1, *idx[2:-1])] = high + assert all(r is not None for r in new_registers) + return FragmentedArray( + _registers=new_registers, _layout=new_layout, _is_signed=self.is_signed, + ) if not isinstance(self.layout, WGSplatFragLayout): raise NotImplementedError( f"Cannot convert from {self.layout} to {new_layout}" diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 080397bbb..053ba4b4b 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -1180,10 +1180,14 @@ def shfl_bfly(x: ir.Value, distance: int | ir.Value): i32 = ir.IntegerType.get_signless(32) if isinstance(distance, int): distance = c(distance, i32) - assert x.type == i32 - return nvvm.shfl_sync( + if (result_type := x.type) != i32: + x = bitcast(x, i32) + y = nvvm.shfl_sync( i32, c(0xFFFFFFFF, i32), x, distance, c(0x1F, i32), nvvm.ShflKind.bfly, ) + if result_type != i32: + y = bitcast(y, result_type) + return y def bitcast(x: ir.Value, new_type: ir.Type): @@ -1205,3 +1209,17 @@ def bitcast(x: ir.Value, new_type: ir.Type): def ceil_div(x: int, y: int): return (x + y - 1) // y + + +def vector_slice(v: ir.Value, s: slice): + i32 = ir.IntegerType.get_signless(32) + v_ty = ir.VectorType(v.type) + if len(v_ty.shape) != 1: + raise NotImplementedError + [v_len] = v_ty.shape + it = range(v_len)[s] + result = llvm.mlir_undef(ir.VectorType.get((len(it),), v_ty.element_type)) + for tgt, src in enumerate(it): + elem = llvm.extractelement(v, c(src, i32)) + result = llvm.insertelement(result, elem, c(tgt, i32)) + return result diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 56365a134..cfd5b28a7 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -2214,13 +2214,13 @@ class LayoutTest(TestCase): col_tiling = swizzle // bytewidth(utils.dtype_to_ir_type(dtype)) m, n = 128, col_tiling * 2 tiling = (64, col_tiling) - tiled_layout = fa._tiled_wgmma_layout_for_upcast((m, n)) + layout = fa.WGMMA_LAYOUT_UPCAST_2X def kernel(ctx, in_, out, smems): smem_in, smem_out, barrier = smems ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier) barrier.wait() t = mgpu.FragmentedArray.load_tiled( - smem_in, swizzle=swizzle, is_signed=True, layout=tiled_layout + smem_in, swizzle=swizzle, is_signed=True, layout=layout ) t.store_tiled(smem_out, swizzle=swizzle) mgpu.commit_shared() @@ -2275,6 +2275,41 @@ class LayoutTest(TestCase): )(x) np.testing.assert_array_equal(y, y_ref) + def test_upcast_to_wgmma(self): + in_dtype = jnp.dtype(jnp.int8) + out_dtype = jnp.dtype(jnp.int16) + swizzle = 128 + in_col_tiling = 8 * swizzle // jnp.iinfo(in_dtype).bits + in_tiling = (8, in_col_tiling) + out_col_tiling = swizzle // out_dtype.itemsize + out_tiling = (8, out_col_tiling) + m, n = 128, in_col_tiling * 2 + def kernel(ctx, in_, out, smems): + smem_in, smem_out, barrier = smems + ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier) + barrier.wait() + t = mgpu.FragmentedArray.load_tiled( + smem_in, swizzle=swizzle, is_signed=True, layout=fa.WGMMA_LAYOUT_UPCAST_2X + ) + t = t.astype(ir.IntegerType.get_signless(16), is_signed=True) + t = t.to_layout(fa.WGMMA_LAYOUT) + t.store_tiled(smem_out, swizzle=swizzle) + mgpu.commit_shared() + ctx.async_copy(src_ref=smem_out, dst_ref=out, 
swizzle=swizzle) + ctx.await_async_copy(0) + def tile(x, tiling): + return x.reshape( + x.shape[0] // tiling[0], tiling[0], x.shape[1] // tiling[1], tiling[1] + ).transpose(0, 2, 1, 3) + x = jax.random.randint(jax.random.key(42), (m, n), -128, 127, dtype=in_dtype) + xt = tile(x, in_tiling) + y = x.astype(out_dtype) + yt = tile(y, out_tiling) + f = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), xt, yt, [xt, yt, mgpu.TMABarrier()], + ) + np.testing.assert_array_equal(f(xt), yt) + @dataclasses.dataclass(frozen=True) class Tile: From 89b21de62ad5b9ab5b8cc12a2d8a133fe6b15d86 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 17 Mar 2025 03:26:06 -0700 Subject: [PATCH 11/34] [Mosaic GPU] Add support for changing the layout before the upcast This lets us save on 2 ALU instructions (3x select becomes 1x prmt). PiperOrigin-RevId: 737550598 --- .../mosaic/gpu/fragmented_array.py | 51 ++++++++++++++----- jax/experimental/mosaic/gpu/utils.py | 22 ++++++-- tests/mosaic/gpu_test.py | 10 +++- 3 files changed, 65 insertions(+), 18 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 29650e1a4..ed66269b5 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -732,9 +732,7 @@ class FragmentedArray: reg_ty = reg.type reg = utils.bitcast(reg, i32) reg_shfl = utils.shfl_bfly(reg, 4) - new_reg = llvm.inline_asm( - i32, [reg, reg_shfl, perm], "prmt.b32 $0, $1, $2, $3;", "=r,r,r,r" - ) + new_reg = utils.prmt(reg, reg_shfl, perm) new_regs.append(utils.bitcast(new_reg, reg_ty)) return FragmentedArray( _registers=np.asarray(new_regs, dtype=object).reshape(new_layout.registers_shape(shape)), @@ -745,21 +743,48 @@ class FragmentedArray: if ( self.layout == WGMMA_LAYOUT_UPCAST_2X and new_layout == WGMMA_LAYOUT - and utils.bytewidth(self.mlir_dtype) == 2 + and (dtype_bitwidth := utils.bitwidth(self.mlir_dtype)) in {8, 16} ): assert shape[1] % 16 == 0 # Should be implied by the layout new_registers = np.empty(new_layout.registers_shape(shape), dtype=object) - is_even = arith.cmpi(arith.CmpIPredicate.eq, arith.remui(utils.thread_idx(), c(2)), c(0)) + is_even = arith.cmpi( + arith.CmpIPredicate.eq, arith.remui(utils.thread_idx(), c(2)), c(0) + ) for idx, reg in np.ndenumerate(self.registers): assert ir.VectorType(reg.type).shape == [4] - # Each slice is exactly 32-bits (we checked bitwidth == 2 above) - low = utils.vector_slice(reg, slice(0, 2)) - high = utils.vector_slice(reg, slice(2, 4)) - to_exchange = arith.select(is_even, high, low) - # Exchange values between even and odd threads. - exchanged = utils.shfl_bfly(to_exchange, 1) - low = arith.select(is_even, low, exchanged) - high = arith.select(is_even, exchanged, high) + if dtype_bitwidth == 16: + # A single vector is 64-bits, but shuffles are only 32-bit wide. + # We only shuffle the half that needs to go to other thread. + low = utils.vector_slice(reg, slice(0, 2)) + high = utils.vector_slice(reg, slice(2, 4)) + to_exchange = arith.select(is_even, high, low) + # Exchange values between even and odd threads. + exchanged = utils.shfl_bfly(to_exchange, 1) + low = arith.select(is_even, low, exchanged) + high = arith.select(is_even, exchanged, high) + elif dtype_bitwidth == 8: + # The vector is 32-bits, so we just shuffle the whole thing and + # use prmt to blend it with the local register. + exchanged = utils.shfl_bfly(reg, 1) + # Consider lanes 0 and 1, because the situation is symmetric for + # each pair. 
If we feed reg[lane] and exchanged[lane] (which is + # really the same as reg of the other lane) to prmt, we can index + # the elements of the result using the following indices: + # reg[0]: 0 1 2 3 reg[1]: 8 9 10 11 + # prmt[0]: 0 1 2 3 4 5 6 7 + # prmt[1]: 4 5 6 7 0 1 2 3 + # The expected outputs and their respective permutations are: + # out[0]: 0 1 8 9 out[1]: 2 3 10 11 + # prmt[0]: 0 1 4 5 prmt[1]: 6 7 2 3 + # Note that the patterns still need to be flipped, since we listed + # bytes with LSB on the left, which is the opposite of how the + # numeric constants are spelled in Python (LSB on the right). + perm = arith.select(is_even, c(0x5410), c(0x3276)) + blend = utils.prmt(reg, exchanged, perm) + low = utils.vector_slice(blend, slice(0, 2)) + high = utils.vector_slice(blend, slice(2, 4)) + else: + raise NotImplementedError(dtype_bitwidth) new_registers[(idx[0], idx[1] * 2, *idx[2:-1])] = low new_registers[(idx[0], idx[1] * 2 + 1, *idx[2:-1])] = high assert all(r is not None for r in new_registers) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 053ba4b4b..1807449f9 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -1185,12 +1185,28 @@ def shfl_bfly(x: ir.Value, distance: int | ir.Value): y = nvvm.shfl_sync( i32, c(0xFFFFFFFF, i32), x, distance, c(0x1F, i32), nvvm.ShflKind.bfly, ) - if result_type != i32: - y = bitcast(y, result_type) - return y + return bitcast(y, result_type) + + +def prmt(high: ir.Value, low: ir.Value, permutation: ir.Value): + i32 = ir.IntegerType.get_signless(32) + if (result_type := high.type) != low.type: + raise ValueError(f"Types must match, got {high.type} and {low.type}") + if high.type != i32: + high = bitcast(high, i32) + if low.type != i32: + low = bitcast(low, i32) + if permutation.type != i32: + permutation = bitcast(permutation, i32) + result = llvm.inline_asm( + i32, [high, low, permutation], "prmt.b32 $0, $1, $2, $3;", "=r,r,r,r" + ) + return bitcast(result, result_type) def bitcast(x: ir.Value, new_type: ir.Type): + if x.type == new_type: + return x if ir.VectorType.isinstance(x.type) and ir.IntegerType.isinstance(new_type): new_type = ir.IntegerType(new_type) x_ty = ir.VectorType(x.type) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index cfd5b28a7..6426f8006 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -2275,7 +2275,10 @@ class LayoutTest(TestCase): )(x) np.testing.assert_array_equal(y, y_ref) - def test_upcast_to_wgmma(self): + @parameterized.product( + upcast_before_layout_change=[True, False], + ) + def test_upcast_to_wgmma(self, upcast_before_layout_change): in_dtype = jnp.dtype(jnp.int8) out_dtype = jnp.dtype(jnp.int16) swizzle = 128 @@ -2291,8 +2294,11 @@ class LayoutTest(TestCase): t = mgpu.FragmentedArray.load_tiled( smem_in, swizzle=swizzle, is_signed=True, layout=fa.WGMMA_LAYOUT_UPCAST_2X ) - t = t.astype(ir.IntegerType.get_signless(16), is_signed=True) + if upcast_before_layout_change: + t = t.astype(ir.IntegerType.get_signless(16), is_signed=True) t = t.to_layout(fa.WGMMA_LAYOUT) + if not upcast_before_layout_change: + t = t.astype(ir.IntegerType.get_signless(16), is_signed=True) t.store_tiled(smem_out, swizzle=swizzle) mgpu.commit_shared() ctx.async_copy(src_ref=smem_out, dst_ref=out, swizzle=swizzle) From a7e5eaee56a5f60eaa6c80b69151efa19b7e9c69 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 17 Mar 2025 05:31:15 -0700 Subject: [PATCH 12/34] [pallas:mosaic_gpu] `jnp.reduce_sum` now works 
for >1D arrays PiperOrigin-RevId: 737578598 --- jax/_src/pallas/mosaic_gpu/lowering.py | 7 +++-- .../mosaic/gpu/dialect_lowering.py | 14 +++++++++ .../mosaic/gpu/fragmented_array.py | 16 +++++----- .../mosaic/gpu/layout_inference.py | 31 +++++++++++++++++++ tests/mosaic/gpu_layout_inference_test.py | 31 +++++++++++++++++++ tests/pallas/mosaic_gpu_test.py | 2 +- 6 files changed, 89 insertions(+), 12 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index cc7ef88df..5c863baf6 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1559,9 +1559,10 @@ def _reduce_lowering_rule_wg( if not out_aval.shape: # Special-case: reducing to a scalar. if x_aval.ndim != 1: - # TODO(slebedev): Flatten to 1D, since vector.reduction only supports - # 1D inputs. - raise NotImplementedError("Only 1D inputs are supported") + # Flatten to 1D, since vector.reduction only supports 1D inputs. + x = vector_dialect.shape_cast( + ir.VectorType.get([x_aval.size], out_type), x + ) return vector_dialect.ReductionOp(out_type, kind, x) acc = vector_dialect.splat( ir.VectorType.get(out_aval.shape, out_type), diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index c9ddce106..8098d14f0 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -320,6 +320,20 @@ def _vector_splat_op_lowering_rule( return [_fragmented_array_to_ir(fragmented_array, out_vec_ty)] +@_register_lowering(vector.ShapeCastOp) +def _vector_shape_cast_op_lowering_rule( + _: LoweringContext, op: vector.ShapeCastOp +) -> Sequence[ir.Value]: + [layout] = inference_utils.in_layouts(op) + out_vec_ty = ir.VectorType(op.result.type) + assert out_vec_ty.has_static_shape + is_signed = ( + False if ir.IntegerType.isinstance(out_vec_ty.element_type) else None + ) + a = _fragmented_array_from_ir(op.source, layout, is_signed) + return [_fragmented_array_to_ir(a.reshape(out_vec_ty.shape), out_vec_ty)] + + @_register_lowering(vector.ReductionOp) def _vector_reduction_op_lowering_rule( ctx: LoweringContext, op: vector.ReductionOp diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index ed66269b5..d325b22c1 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1538,17 +1538,17 @@ class FragmentedArray: def reshape(self, shape): if self.shape == shape: return self - - if not isinstance(self.layout, WGSplatFragLayout): - raise NotImplementedError(self.layout) - - if np.prod(shape) != np.prod(self.shape): + if math.prod(shape) != math.prod(self.shape): raise ValueError(f"Can't reshape {self.shape} to {shape}") + match self.layout: + case WGSplatFragLayout() | WGStridedFragLayout(): + new_layout = dataclasses.replace(self.layout, shape=shape) + case _: + raise NotImplementedError(self.layout) + return FragmentedArray( - _registers=self.registers, - _layout=WGSplatFragLayout(shape), - _is_signed=self.is_signed, + _registers=self.registers, _layout=new_layout, _is_signed=self.is_signed ) def broadcast_minor(self, n): diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index 42128aff8..c9479e0f1 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -336,6 +336,37 @@ def _infer_splat_op_layout(splat_op: vector.SplatOp) 
-> OptionalLayouts: return [], [layout] + +def _update_layout_shape( + layout: ir.Attribute, shape: Sequence[int], origin: str +) -> ir.Attribute: + if layouts_lib.is_splat_fragmented_layout( + layout + ) or layouts_lib.is_strided_fragmented_layout(layout): + return layouts_lib.to_layout_attr( + dataclasses.replace(layouts_lib.from_layout_attr(layout), shape=shape) + ) + raise NotImplementedError(f"Unsupported {origin} layout: {layout}.") + + +@partial(_add_layout_inference_rule, vector.ShapeCastOp) +def _infer_shape_cast_op_layout(op: vector.ShapeCastOp) -> OptionalLayouts: + in_layout = inference_utils.value_layout(op.source) + if in_layout is None: + out_layout = inference_utils.value_layout(op.result) + if out_layout is None: + return None + in_layout = _update_layout_shape( + out_layout, ir.VectorType(op.source.type).shape, "source" + ) + return [in_layout], [out_layout] + + out_layout = _update_layout_shape( + in_layout, ir.VectorType(op.result.type).shape, "result" + ) + return [in_layout], [out_layout] + + @partial(_add_layout_inference_rule, vector.ReductionOp) def _infer_reduction_op_layout(op: vector.ReductionOp) -> OptionalLayouts: if layout := inference_utils.value_layout(op.vector): diff --git a/tests/mosaic/gpu_layout_inference_test.py b/tests/mosaic/gpu_layout_inference_test.py index e0b8ab27c..36c8ff9cf 100644 --- a/tests/mosaic/gpu_layout_inference_test.py +++ b/tests/mosaic/gpu_layout_inference_test.py @@ -74,6 +74,37 @@ class LayoutInferenceTest(parameterized.TestCase): self.assertSequenceEqual(add.attributes["in_layouts"], [layout, layout]) self.assertSequenceEqual(add.attributes["out_layouts"], [layout]) + def test_infer_strided_layout_from_shape_cast(self): + shape = (16, 8) + elt_type = ir.BF16Type.get() + src_type = ir.VectorType.get(shape, elt_type) + dst_type = ir.VectorType.get([*reversed(shape)], elt_type) + op = None + + def body(x): + nonlocal op + op = vector.ShapeCastOp(dst_type, x) + + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func(src_type)(body) + + mgpu.infer_layout(self.module) + + in_layout = layouts.to_layout_attr( + mgpu.WGStridedFragLayout.from_shaped_type(src_type) + ) + out_layout = layouts.to_layout_attr( + mgpu.WGStridedFragLayout.from_shaped_type(dst_type) + ) + + self.assertSequenceEqual(op.attributes["in_layouts"], [in_layout]) + self.assertSequenceEqual(op.attributes["out_layouts"], [out_layout]) + + # Ensure that we can recover the original layout. + del op.attributes["in_layouts"] + mgpu.infer_layout(self.module) + self.assertSequenceEqual(op.attributes["in_layouts"], [in_layout]) + def test_infer_splat_layout_for_splat_constants(self): shape = (16, 8) elt_type = ir.BF16Type.get() diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index c165878e8..b3c3ddb84 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -185,7 +185,7 @@ class PallasCallTest(PallasTest): np.testing.assert_array_equal(kernel(x, y), x + y[0]) @parameterized.product( - shape=[(128,)], thread_semantics=[*plgpu.ThreadSemantics] + shape=[(128,), (128, 128)], thread_semantics=[*plgpu.ThreadSemantics] ) def test_reduce_sum(self, shape, thread_semantics): @functools.partial( From 55812c5d02d621c9c1c185298efb51ea562da9d6 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 17 Mar 2025 05:43:43 -0700 Subject: [PATCH 13/34] Update XLA dependency to use revision http://github.com/openxla/xla/commit/fcf97e619e26fcb19cffa060df2d0246f6a7ece7. 
PiperOrigin-RevId: 737581187 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index e5029d052..b374e3a29 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "936a727db7cefa30027b727b7056b1b5c6064145" -XLA_SHA256 = "b27db91ad7c4a1badb57bbcf92f8894ceebb709b2197cfd2e830d121e3d20dc7" +XLA_COMMIT = "fcf97e619e26fcb19cffa060df2d0246f6a7ece7" +XLA_SHA256 = "ed7fb9863ea1e20a16bdfb135e48ea39c4b232ef2fd49e173de4e2e43fa76e09" def repo(): tf_http_archive( From 0ff234049b6a630ac91ca2010e27e6b97d897c27 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 17 Mar 2025 07:49:01 -0700 Subject: [PATCH 14/34] Removed trivial docstrings from JAX tests These docstrings do not make the tests any more clear and typically just duplicate the test module name. PiperOrigin-RevId: 737611977 --- jax/experimental/array_serialization/serialization_test.py | 1 - jax/experimental/jax2tf/examples/saved_model_main_test.py | 1 - jax/experimental/jax2tf/tests/call_tf_test.py | 1 - jax/experimental/jax2tf/tests/control_flow_ops_test.py | 1 - jax/experimental/jax2tf/tests/shape_poly_test.py | 1 - tests/aot_test.py | 1 - tests/api_util_test.py | 1 - tests/array_test.py | 1 - tests/clear_backends_test.py | 1 - tests/debug_nans_test.py | 2 -- tests/garbage_collection_guard_test.py | 1 - tests/lax_numpy_ufuncs_test.py | 2 -- tests/linalg_test.py | 2 -- tests/mesh_utils_test.py | 1 - tests/mosaic/gpu_test.py | 1 - tests/mosaic/profiler_cupti_test.py | 1 - tests/nn_test.py | 2 -- tests/optimizers_test.py | 2 -- tests/pallas/fuser_block_spec_test.py | 2 -- tests/pallas/indexing_test.py | 2 -- tests/pallas/ops_test.py | 2 -- tests/pallas/pallas_error_handling_test.py | 1 - tests/pallas/tpu_ops_test.py | 1 - tests/pallas/tpu_pallas_distributed_test.py | 2 -- tests/pallas/tpu_pallas_random_test.py | 1 - tests/pallas/tpu_pallas_state_test.py | 1 - tests/pallas/tpu_splash_attention_kernel_test.py | 1 - tests/pallas/tpu_splash_attention_mask_test.py | 1 - tests/pickle_test.py | 1 - tests/qdwh_test.py | 1 - tests/shape_poly_test.py | 1 - tests/stack_test.py | 3 --- tests/stax_test.py | 2 -- tests/svd_test.py | 1 - tests/transfer_guard_test.py | 1 - 35 files changed, 47 deletions(-) diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index 6ec621d68..9f4539fc6 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for serialization and deserialization of GDA.""" import asyncio import math diff --git a/jax/experimental/jax2tf/examples/saved_model_main_test.py b/jax/experimental/jax2tf/examples/saved_model_main_test.py index aa6be0cff..5d6982179 100644 --- a/jax/experimental/jax2tf/examples/saved_model_main_test.py +++ b/jax/experimental/jax2tf/examples/saved_model_main_test.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -"""Tests for mnist_lib, saved_model_lib, saved_model_main.""" import os from absl import flags diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py index 0cde96aeb..4647a16d7 100644 --- a/jax/experimental/jax2tf/tests/call_tf_test.py +++ b/jax/experimental/jax2tf/tests/call_tf_test.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for call_tf.""" from collections.abc import Callable import contextlib diff --git a/jax/experimental/jax2tf/tests/control_flow_ops_test.py b/jax/experimental/jax2tf/tests/control_flow_ops_test.py index c66a6d696..3b39c8752 100644 --- a/jax/experimental/jax2tf/tests/control_flow_ops_test.py +++ b/jax/experimental/jax2tf/tests/control_flow_ops_test.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for the jax2tf conversion for control-flow primitives.""" from absl.testing import absltest diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index 3f8320fd8..09da97e84 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for the shape-polymorphic jax2tf conversion.""" from __future__ import annotations diff --git a/tests/aot_test.py b/tests/aot_test.py index 1245967f8..daaeb8417 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for AOT compilation.""" import contextlib import unittest diff --git a/tests/api_util_test.py b/tests/api_util_test.py index e34611c6e..26cca4e74 100644 --- a/tests/api_util_test.py +++ b/tests/api_util_test.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for jax.api_util.""" import itertools as it from absl.testing import absltest diff --git a/tests/array_test.py b/tests/array_test.py index 4184c835a..cc8990828 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for Array.""" import contextlib import math diff --git a/tests/clear_backends_test.py b/tests/clear_backends_test.py index 9ea9cac3a..6e98e7293 100644 --- a/tests/clear_backends_test.py +++ b/tests/clear_backends_test.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Tests for release_backend_clients.""" from absl.testing import absltest import jax diff --git a/tests/debug_nans_test.py b/tests/debug_nans_test.py index d3dcfb2e7..c80d23c41 100644 --- a/tests/debug_nans_test.py +++ b/tests/debug_nans_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for --debug_nans.""" - from absl.testing import absltest import jax diff --git a/tests/garbage_collection_guard_test.py b/tests/garbage_collection_guard_test.py index 5c34c6de2..f833c9e00 100644 --- a/tests/garbage_collection_guard_test.py +++ b/tests/garbage_collection_guard_test.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for garbage allocation guard.""" import gc import weakref diff --git a/tests/lax_numpy_ufuncs_test.py b/tests/lax_numpy_ufuncs_test.py index 20a1a58a9..fd5050a58 100644 --- a/tests/lax_numpy_ufuncs_test.py +++ b/tests/lax_numpy_ufuncs_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for jax.numpy.ufunc and its methods.""" - import itertools from functools import partial diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 2355449fc..feab105cc 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for the LAPAX linear algebra module.""" - from functools import partial import itertools from typing import Iterator diff --git a/tests/mesh_utils_test.py b/tests/mesh_utils_test.py index 4f1b1fb03..136b50794 100644 --- a/tests/mesh_utils_test.py +++ b/tests/mesh_utils_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for mesh utils.""" import collections from collections.abc import Sequence diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 6426f8006..1f43b46dc 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for Mosaic GPU DSL functions and utilities.""" from collections.abc import Sequence import dataclasses diff --git a/tests/mosaic/profiler_cupti_test.py b/tests/mosaic/profiler_cupti_test.py index c6f3d23b4..f3dcf71ca 100644 --- a/tests/mosaic/profiler_cupti_test.py +++ b/tests/mosaic/profiler_cupti_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for Mosaic GPU CUPTI-based profiler.""" from absl.testing import absltest, parameterized import jax diff --git a/tests/nn_test.py b/tests/nn_test.py index 9df315967..ed016ec34 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Tests for nn module.""" - import collections from functools import partial import itertools diff --git a/tests/optimizers_test.py b/tests/optimizers_test.py index c4eca0707..2e027e615 100644 --- a/tests/optimizers_test.py +++ b/tests/optimizers_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for the optimizers module.""" - import functools from absl.testing import absltest diff --git a/tests/pallas/fuser_block_spec_test.py b/tests/pallas/fuser_block_spec_test.py index d87bfd30e..1b3a21587 100644 --- a/tests/pallas/fuser_block_spec_test.py +++ b/tests/pallas/fuser_block_spec_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for pull block spec.""" - from absl.testing import absltest from absl.testing import parameterized import jax diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py index 13f9880e6..c3f3fa6e8 100644 --- a/tests/pallas/indexing_test.py +++ b/tests/pallas/indexing_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for Pallas indexing logic and abstractions.""" - from __future__ import annotations import sys import unittest diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 27ec2dc48..0fc375bf6 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for common JAX operations within pallas_call.""" - from collections.abc import Sequence import functools import itertools diff --git a/tests/pallas/pallas_error_handling_test.py b/tests/pallas/pallas_error_handling_test.py index 34b0ff149..cd5ceecfc 100644 --- a/tests/pallas/pallas_error_handling_test.py +++ b/tests/pallas/pallas_error_handling_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for Pallas error handling.""" import functools import traceback diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index c4d600d23..c8def2627 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for TPU specific operations within pallas_call.""" import functools import math diff --git a/tests/pallas/tpu_pallas_distributed_test.py b/tests/pallas/tpu_pallas_distributed_test.py index a329478a1..f7d7daf18 100644 --- a/tests/pallas/tpu_pallas_distributed_test.py +++ b/tests/pallas/tpu_pallas_distributed_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for distributed pallas TPU operations.""" - import functools import os import tempfile diff --git a/tests/pallas/tpu_pallas_random_test.py b/tests/pallas/tpu_pallas_random_test.py index c6027cd04..ca8edf7a2 100644 --- a/tests/pallas/tpu_pallas_random_test.py +++ b/tests/pallas/tpu_pallas_random_test.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Tests for random ops in Pallas + Mosaic.""" from absl.testing import absltest from absl.testing import parameterized diff --git a/tests/pallas/tpu_pallas_state_test.py b/tests/pallas/tpu_pallas_state_test.py index ab3a82dab..46f98c087 100644 --- a/tests/pallas/tpu_pallas_state_test.py +++ b/tests/pallas/tpu_pallas_state_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for Pallas mesh API.""" import functools from absl.testing import absltest import jax diff --git a/tests/pallas/tpu_splash_attention_kernel_test.py b/tests/pallas/tpu_splash_attention_kernel_test.py index c5943fa27..dfe0bcc0d 100644 --- a/tests/pallas/tpu_splash_attention_kernel_test.py +++ b/tests/pallas/tpu_splash_attention_kernel_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for splash_attention.""" from __future__ import annotations from collections.abc import Callable diff --git a/tests/pallas/tpu_splash_attention_mask_test.py b/tests/pallas/tpu_splash_attention_mask_test.py index f23baf2a0..f39b4d839 100644 --- a/tests/pallas/tpu_splash_attention_mask_test.py +++ b/tests/pallas/tpu_splash_attention_mask_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for splash_attention_masks.""" from __future__ import annotations from absl.testing import absltest diff --git a/tests/pickle_test.py b/tests/pickle_test.py index d2bc89f44..185eebd90 100644 --- a/tests/pickle_test.py +++ b/tests/pickle_test.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for interoperability between JAX and pickling libraries.""" import pickle import unittest diff --git a/tests/qdwh_test.py b/tests/qdwh_test.py index 705c96f00..91cc3a51f 100644 --- a/tests/qdwh_test.py +++ b/tests/qdwh_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License -"""Tests for the library of QDWH-based polar decomposition.""" import functools from absl.testing import absltest diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index 67f8e49b3..6d1ffe744 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for the shape-polymorphic export.""" from __future__ import annotations diff --git a/tests/stack_test.py b/tests/stack_test.py index 655a42571..aa1a02793 100644 --- a/tests/stack_test.py +++ b/tests/stack_test.py @@ -12,9 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - -"""Tests for stack.""" - from absl.testing import absltest import jax diff --git a/tests/stax_test.py b/tests/stax_test.py index e21300ddd..8c38820d2 100644 --- a/tests/stax_test.py +++ b/tests/stax_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Tests for Stax library.""" - from absl.testing import absltest import numpy as np diff --git a/tests/svd_test.py b/tests/svd_test.py index b349c3ca6..97f8176f8 100644 --- a/tests/svd_test.py +++ b/tests/svd_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License -"""Tests for the library of QDWH-based singular value decomposition.""" import functools import jax diff --git a/tests/transfer_guard_test.py b/tests/transfer_guard_test.py index 6a255b0a1..740ef55ff 100644 --- a/tests/transfer_guard_test.py +++ b/tests/transfer_guard_test.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for transfer guards.""" import contextlib import pickle From 3649da56fc1d79d274105b9c2882a857c31c143f Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 17 Mar 2025 08:36:36 -0700 Subject: [PATCH 15/34] [Mosaic GPU] Make the s4 -> bf16 upcast more flexible when it comes to vector length We can now perform the conversion in groups of 2, 4 or even 8 elements at a time. PiperOrigin-RevId: 737626600 --- jax/BUILD | 3 +- .../mosaic/gpu/fragmented_array.py | 59 ++++++++++++------- jax/experimental/mosaic/gpu/utils.py | 33 +++++++++++ tests/mosaic/gpu_test.py | 17 +++--- 4 files changed, 81 insertions(+), 31 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index d6f100581..12eae4afd 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -799,7 +799,7 @@ pytype_strict_library( ) # This target only supports sm_90 GPUs. -py_library( +py_library_providing_imports_info( name = "mosaic_gpu", srcs = glob(["experimental/mosaic/gpu/*.py"]), visibility = [ @@ -824,6 +824,7 @@ py_library( "//jaxlib/mlir:pass_manager", "//jaxlib/mlir:scf_dialect", "//jaxlib/mlir:vector_dialect", + "//jaxlib/mosaic/python:gpu_dialect", ] + py_deps("absl/flags") + py_deps("numpy"), ) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index d325b22c1..c6d7c02fb 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1244,11 +1244,10 @@ class FragmentedArray: is_vector_reg = ir.VectorType.isinstance(reg_type) reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else (1,) [vector_len] = reg_shape # This is meant to be a 1D assertion. - if cur_dtype == i4 and self.is_signed and new_dtype == bf16 and vector_len == 2: + if cur_dtype == i4 and self.is_signed and new_dtype == bf16: new_registers = np.empty_like(self.registers) - empty_vec_32 = llvm.mlir_undef(ir.VectorType.get((1,), i32)) + out_vec_ty = ir.VectorType.get((vector_len,), new_dtype) for idx, reg in np.ndenumerate(self.registers): - reg_8 = vector.bitcast(ir.VectorType.get((1,), i8), reg) # The algorithm here is largely the same as CUTLASS's # NumericArrayConverter specialization for int4 -> bf16 casts. # We modify it slightly, because we only extract 2 values. @@ -1262,25 +1261,41 @@ class FragmentedArray: # positive int4s will end up larger than negative int4s, with a bias of # 8. Use use the sub to subtract the base (our initial exponent) and the # bias coming from flipping the sign bit which is 136 (0x4308 as bits). 
- new_reg_32 = llvm.inline_asm( - i32, - [reg_8], - """ - { - .reg .b32 s<4>; - shr.s32 s0, $1, 4; - prmt.b32 s1, $1, s0, 0xF4F0; - lop3.b32 s2, s1, 0x000F000F, 0x43084308, (0xf0 & 0xcc) ^ 0xaa; - mov.b32 s3, 0x43084308; - sub.bf16x2 $0, s2, s3; - } - """, - "=r,r", - ) - new_vec_32 = llvm.insertelement(empty_vec_32, new_reg_32, c(0, i32)) - new_registers[idx] = vector.bitcast( - ir.VectorType.get((vector_len,), new_dtype), new_vec_32 - ) + def upcast_to_bf16(reg: ir.Value, reg_shr: ir.Value, part: int): + assert 0 <= part < 4 + return llvm.inline_asm( + i32, + [reg, reg_shr], + f""" + {{ + .reg .b32 s<4>; + prmt.b32 s1, $1, $2, 0xF{part + 4}F{part}; + lop3.b32 s2, s1, 0x000F000F, 0x43084308, (0xf0 & 0xcc) ^ 0xaa; + mov.b32 s3, 0x43084308; + sub.bf16x2 $0, s2, s3; + }} + """, + "=r,r,r", + ) + offset = 0 + out_int_regs = [] + for group_size in (8, 4, 2): + int_ty = ir.IntegerType.get_signless(group_size * 4) + while vector_len - offset >= group_size: + reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size)) + reg_slice_int = arith.extsi(i32, utils.bitcast(reg_slice, int_ty)) + reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32)) + out_int_regs.extend( + upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part) + for part in range(group_size // 2) + ) + offset += group_size + assert offset == vector_len + out_vec_int = utils.vector_concat([ + vector.splat(ir.VectorType.get((1,), i32), reg) + for reg in out_int_regs + ]) + new_registers[idx] = utils.bitcast(out_vec_int, out_vec_ty) return FragmentedArray( _registers=new_registers, _layout=self.layout, _is_signed=None ) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 1807449f9..91cb19746 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -348,6 +348,9 @@ def bitwidth_impl(ty: ir.Type): return ir.FloatType(ty).width if dialect is not None and ir.Type.parse("!mosaic_gpu.barrier"): return MBARRIER_BYTES * 8 + if ir.VectorType.isinstance(ty): + vty = ir.VectorType(ty) + return math.prod(vty.shape) * bitwidth(vty.element_type) raise NotImplementedError(ty) @@ -1220,6 +1223,12 @@ def bitcast(x: ir.Value, new_type: ir.Type): x_ty = ir.IntegerType(x.type) assert x_ty.width == bitwidth(new_type.element_type) * math.prod(new_type.shape) return vector.bitcast(new_type, vector.splat(ir.VectorType.get((1,), x_ty), x)) + if ir.VectorType.isinstance(x.type) and ir.VectorType.isinstance(new_type): + x_ty = ir.VectorType(x.type) + new_ty = ir.VectorType(new_type) + if bitwidth(x_ty) != bitwidth(new_ty): + raise ValueError(f"Can't bitcast {x.type} to {new_type}") + return vector.bitcast(new_type, x) raise ValueError(f"Can't bitcast {x.type} to {new_type}") @@ -1239,3 +1248,27 @@ def vector_slice(v: ir.Value, s: slice): elem = llvm.extractelement(v, c(src, i32)) result = llvm.insertelement(result, elem, c(tgt, i32)) return result + + +def vector_concat(vectors: Sequence[ir.Value]) -> ir.Value: + index = ir.IndexType.get() + if not vectors: + raise ValueError("Cannot concatenate an empty list of vectors") + vty = vectors[0].type + if not ir.VectorType.isinstance(vty): + raise ValueError("Cannot concatenate non-vector values") + if vty.rank != 1: + raise NotImplementedError("Only 1D vectors are supported") + for v in vectors: + if v.type != vty: + raise ValueError("Cannot concatenate vectors of different types") + result = llvm.mlir_undef( + ir.VectorType.get((vty.shape[0] * len(vectors),), vty.element_type) + ) + offset = 0 + for v in vectors: + for 
i in range(vty.shape[0]): + elem = vector.extractelement(v, position=c(i, index)) + result = vector.insertelement(elem, result, position=c(offset + i, index)) + offset += vty.shape[0] + return result diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 1f43b46dc..574299ab1 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -518,14 +518,15 @@ class WGMMALayoutTest(TestCase): )() np.testing.assert_array_equal(iota, expected) - @parameterized.named_parameters( - ("bf16_i8", jnp.bfloat16, jnp.int8), - ("i8_bf16", jnp.int8, jnp.bfloat16), - ("i8_i8", jnp.int8, jnp.int8), - ("i4_i4", jnp.int4, jnp.int4), - ("i4_bf16", jnp.int4, jnp.bfloat16), + @parameterized.product( + jax_dtype_from_to=( + (jnp.int8, jnp.bfloat16), + (jnp.int4, jnp.bfloat16), + ), + layout=(fa.WGMMA_LAYOUT, fa.WGMMA_LAYOUT_UPCAST_2X), ) - def test_convert_tiled(self, jax_dtype_from, jax_dtype_to): + def test_optimized_conversion(self, jax_dtype_from_to, layout): + jax_dtype_from, jax_dtype_to = jax_dtype_from_to mlir_dtype_from = utils.dtype_to_ir_type(jax_dtype_from) mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to) m = 128 @@ -538,7 +539,7 @@ class WGMMALayoutTest(TestCase): smem_from, swizzle=128, is_signed=utils.is_signed(jax_dtype_from), - layout=fa._tiled_wgmma_layout((m, n)) + layout=layout, ) t = t.astype(mlir_dtype_to, is_signed=utils.is_signed(jax_dtype_to)) t.store_tiled(smem_to, swizzle=128) From 031614c22b8aa8c60652d379e54baea187c427a0 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 17 Mar 2025 08:58:18 -0700 Subject: [PATCH 16/34] Pin numpy~=2.1.0 in workflow file instead of test-requirements.txt PiperOrigin-RevId: 737632771 --- .github/workflows/pytest_cpu.yml | 5 +++++ build/test-requirements.txt | 5 +---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index d59f5606a..137f49c6d 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -118,6 +118,11 @@ jobs: run: | $JAXCI_PYTHON -m pip install uv~=0.5.30 $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt + + # CPU Pytests crash with NumPy 2.2+ on Linux Aarch64; b/399168632 + if [[ $OS == "linux" && $ARCH == "aarch64" ]]; then + $JAXCI_PYTHON -m uv pip install numpy~=2.1.0 + fi # Halt for testing - name: Wait For Connection uses: google-ml-infra/actions/ci_connection@main diff --git a/build/test-requirements.txt b/build/test-requirements.txt index 84cd01d82..f0b315771 100644 --- a/build/test-requirements.txt +++ b/build/test-requirements.txt @@ -18,7 +18,4 @@ setuptools matplotlib~=3.8.4; python_version=="3.10" matplotlib; python_version>="3.11" opt-einsum -auditwheel - -# CPU Pytests crash with NumPy 2.2+ on Linux Aarch64; b/399168632 -numpy~=2.1.0; platform_system == "Linux" and platform_machine == "aarch64" +auditwheel \ No newline at end of file From 3f59fa688899f95fac391a2ef1332afdada6663d Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Thu, 13 Mar 2025 19:00:19 -0400 Subject: [PATCH 17/34] Add replace option to random.categorical to enable sampling without replacement. --- CHANGELOG.md | 2 ++ jax/_src/random.py | 61 ++++++++++++++++++++++++++++++---------- tests/random_lax_test.py | 32 +++++++++++++++++++++ 3 files changed, 80 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 140f66c30..c30877eca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. 
{func}`jax.lax.dynamic_update_slice` and related functions. The default is
    true, matching the current behavior. If set to false, JAX does not need to
    emit code clamping negative indices, which improves code size.
+  * Added a `replace` option to {func}`jax.random.categorical` to enable sampling
+    without replacement.
 
 ## jax 0.5.2 (Mar 4, 2025)
 
diff --git a/jax/_src/random.py b/jax/_src/random.py
index 4c1436e3f..094268c65 100644
--- a/jax/_src/random.py
+++ b/jax/_src/random.py
@@ -1548,12 +1548,18 @@ def _gumbel(key, shape, dtype, mode) -> Array:
           _uniform(key, shape, dtype, minval=jnp.finfo(dtype).tiny, maxval=1.)))
 
 
-def categorical(key: ArrayLike,
-                logits: RealArray,
-                axis: int = -1,
-                shape: Shape | None = None) -> Array:
+def categorical(
+    key: ArrayLike,
+    logits: RealArray,
+    axis: int = -1,
+    shape: Shape | None = None,
+    replace: bool = True,
+) -> Array:
   """Sample random values from categorical distributions.
 
+  Sampling with replacement uses the Gumbel max trick. Sampling without replacement uses
+  the Gumbel top-k trick. See [1] for reference.
+
   Args:
     key: a PRNG key used as the random key.
     logits: Unnormalized log probabilities of the categorical distribution(s) to sample from,
@@ -1562,32 +1568,57 @@ def categorical(key: ArrayLike,
     shape: Optional, a tuple of nonnegative integers representing the result shape.
       Must be broadcast-compatible with ``np.delete(logits.shape, axis)``.
       The default (None) produces a result shape equal to ``np.delete(logits.shape, axis)``.
+    replace: If True (default), perform sampling with replacement. If False, perform
+      sampling without replacement.
 
   Returns:
     A random array with int dtype and shape given by ``shape`` if ``shape`` is not None, or else
     ``np.delete(logits.shape, axis)``.
+
+  References:
+    .. [1] Wouter Kool, Herke van Hoof, Max Welling. "Stochastic Beams and Where to Find
+      Them: The Gumbel-Top-k Trick for Sampling Sequences Without Replacement".
+      Proceedings of the 36th International Conference on Machine Learning, PMLR
+      97:3499-3508, 2019. https://proceedings.mlr.press/v97/kool19a.html.
   """
   key, _ = _check_prng_key("categorical", key)
   check_arraylike("categorical", logits)
   logits_arr = jnp.asarray(logits)
-
-  if axis >= 0:
-    axis -= len(logits_arr.shape)
-
   batch_shape = tuple(np.delete(logits_arr.shape, axis))
   if shape is None:
     shape = batch_shape
   else:
     shape = core.canonicalize_shape(shape)
     _check_shape("categorical", shape, batch_shape)
-  shape_prefix = shape[:len(shape)-len(batch_shape)]
-  logits_shape = list(shape[len(shape) - len(batch_shape):])
-  logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis])
-  return jnp.argmax(
-      gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype) +
-      lax.expand_dims(logits_arr, tuple(range(len(shape_prefix)))),
-      axis=axis)
+
+  shape_prefix = shape[:len(shape)-len(batch_shape)]
+
+  if replace:
+    if axis >= 0:
+      axis -= len(logits_arr.shape)
+
+    logits_shape = list(shape[len(shape) - len(batch_shape):])
+    logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis])
+    return jnp.argmax(
+        gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype) +
+        lax.expand_dims(logits_arr, tuple(range(len(shape_prefix)))),
+        axis=axis)
+  else:
+    logits_arr += gumbel(key, logits_arr.shape, logits_arr.dtype)
+    k = math.prod(shape_prefix)
+    if k > logits_arr.shape[axis]:
+      raise ValueError(
+          f"Number of samples without replacement ({k}) cannot exceed number of "
+          f"categories ({logits_arr.shape[axis]})."
+ ) + + _, indices = lax.top_k(jnp.moveaxis(logits_arr, axis, -1), k) + assert indices.shape == batch_shape + (k,) + assert shape == shape_prefix + batch_shape + + dimensions = (indices.ndim - 1, *range(indices.ndim - 1)) + indices = lax.reshape(indices, shape, dimensions) + assert indices.shape == shape + return indices def laplace(key: ArrayLike, diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py index 366f5ab3c..b6f8b4f13 100644 --- a/tests/random_lax_test.py +++ b/tests/random_lax_test.py @@ -365,6 +365,38 @@ class LaxRandomTest(jtu.JaxTestCase): pmf = lambda x: np.where(x < len(p), p[np.minimum(len(p) - 1, x)], 0.0) self._CheckChiSquared(samples, pmf=pmf) + @jtu.sample_product( + logits_shape=[(7,), (8, 9), (10, 11, 12)], + prefix_shape=[(2,), (3, 4), (5, 6)], + ) + def testCategoricalWithoutReplacement(self, logits_shape, prefix_shape): + key = random.key(0) + + key, subkey = random.split(key) + logits = random.normal(subkey, logits_shape) + + key, subkey = random.split(key) + axis = random.randint(subkey, (), -len(logits_shape), len(logits_shape)) + + dists_shape = tuple(np.delete(logits_shape, axis)) + n_categories = logits_shape[axis] + shape = prefix_shape + dists_shape + prefix_size = math.prod(prefix_shape) + + if n_categories < prefix_size: + with self.assertRaisesRegex(ValueError, "Number of samples without replacement"): + random.categorical(key, logits, axis=axis, shape=shape, replace=False) + + else: + output = random.categorical(key, logits, axis=axis, shape=shape, replace=False) + self.assertEqual(output.shape, shape) + assert (0 <= output).all() + assert (output < n_categories).all() + flat = output.reshape((prefix_size, math.prod(dists_shape))) + counts = jax.vmap(partial(jnp.bincount, length=n_categories), 1)(flat) + assert (counts <= 1).all() + + def testBernoulliShape(self): key = self.make_key(0) with jax.numpy_rank_promotion('allow'): From 9a686e0bf3ef2b6c359aa4b15acee51cbf2753c3 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Mon, 17 Mar 2025 12:07:08 -0700 Subject: [PATCH 18/34] [Mosaic GPU] Add initial transform inference rules for `vector.{load,store}`. PiperOrigin-RevId: 737703568 --- .../mosaic/gpu/transform_inference.py | 59 ++++++ tests/mosaic/gpu_transform_inference_test.py | 184 ++++++++++++++++++ 2 files changed, 243 insertions(+) diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py index 65936f943..a3919ea1d 100644 --- a/jax/experimental/mosaic/gpu/transform_inference.py +++ b/jax/experimental/mosaic/gpu/transform_inference.py @@ -25,8 +25,12 @@ from typing import cast from jax._src.lib import mosaic_gpu_dialect as mgpu from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import arith +from jax._src.lib.mlir.dialects import vector +from . import fragmented_array as fa from . import inference_utils +from . import layouts as layouts_lib from . 
import utils # mypy: ignore-errors @@ -40,6 +44,7 @@ def _add_transform_inference_rule( op: type[ir.OpView], rule: TransformInferenceRule ): _transform_inference_rules[op.OPERATION_NAME] = rule # pytype: disable=attribute-error + return rule def _set_transform_attributes( @@ -110,6 +115,60 @@ def _infer_async_load_transforms(op: mgpu.AsyncLoadOp) -> OptionalTransforms: return None if in_transforms is None else ([in_transforms], []) +@partial(_add_transform_inference_rule, vector.LoadOp) +@partial(_add_transform_inference_rule, vector.StoreOp) +def _infer_vector_load_store_transforms( + op: vector.LoadOp | vector.StoreOp, +) -> OptionalTransforms: + for i in op.indices: + index_defining_op = i.owner.opview + if ( + not isinstance(index_defining_op, arith.ConstantOp) + or index_defining_op.literal_value != 0 + ): + # TODO(bchetioui): handle slicing. + raise NotImplementedError( + f"Only constants with value 0 are supported as indices for {op}" + ) + + if isinstance(op, vector.LoadOp): + [layout_attr] = inference_utils.out_layouts(op) + else: + assert isinstance(op, vector.StoreOp) + [layout_attr] = inference_utils.in_layouts(op) + + layout = layouts_lib.from_layout_attr(layout_attr) + transforms = inference_utils.value_transforms(op.base) + + if layout == fa.WGMMA_LAYOUT: + layout_transforms = infer_transforms_for_wgmma_ref( + ir.MemRefType(op.base.type) + ) + elif (isinstance(layout, fa.WGStridedFragLayout) or + isinstance(layout, fa.WGSplatFragLayout)): + layout_transforms = None + else: + raise NotImplementedError( + f"Got layout {layout} which is not yet supported" + ) + + if transforms is not None and layout_transforms is not None: + if transforms != layout_transforms: + raise NotImplementedError( + f"Conflicting transforms for {op.base} in {op}: " + f"{transforms} != {layout_transforms}." 
+ ) + return [transforms], [] + + if transforms is not None: + return [transforms], [] + + if layout_transforms is not None: + return [layout_transforms], [] + + return None + + def _should_have_transforms(op: ir.OpView) -> bool: """Returns 'True' if the operation should be assigned in/out transforms.""" return any( diff --git a/tests/mosaic/gpu_transform_inference_test.py b/tests/mosaic/gpu_transform_inference_test.py index 65c851dda..2618c22ac 100644 --- a/tests/mosaic/gpu_transform_inference_test.py +++ b/tests/mosaic/gpu_transform_inference_test.py @@ -25,8 +25,11 @@ from jax._src.interpreters import mlir as mlir_interpreter from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith from jax._src.lib.mlir.dialects import func +from jax._src.lib.mlir.dialects import vector import jax.experimental.mosaic.gpu as mgpu +from jax.experimental.mosaic.gpu import fragmented_array as fa from jax.experimental.mosaic.gpu import inference_utils +from jax.experimental.mosaic.gpu import layouts as layouts_lib import numpy as np @@ -162,6 +165,187 @@ class TransformInferenceTest(parameterized.TestCase): ) self.assertEmpty(inference_utils.out_transforms(async_store_op)) + def test_infer_transforms_for_vector_load_op_derives_from_destination(self): + vector_load_op = None + shape = (64, 64) + elt_ty = ir.BF16Type.get() + + def body(smem_ref): + nonlocal vector_load_op + zero = arith.constant(ir.IntegerType.get_signless(32), 0) + vector_load_op = vector.LoadOp( + ir.VectorType.get(shape, elt_ty), smem_ref, [zero] * len(shape) + ) + + with ir.InsertionPoint(self.module.body): + smem = ir.Attribute.parse("#gpu.address_space") + smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) + func.FuncOp.from_py_func(smem_ty)(body) + + vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get( + [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)] + ) + + mgpu.infer_transforms(self.module) + + expected_transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((8, 64)), + mgpu.dialect.SwizzleTransformAttr.get(128), + ]) + + self.assertSequenceEqual( + inference_utils.in_transforms(vector_load_op), [expected_transforms] + ) + self.assertEmpty(inference_utils.out_transforms(vector_load_op)) + + def test_infer_transforms_for_vector_load_op_derives_from_source(self): + vector_load_op = None + shape = (64, 64) + elt_ty = ir.BF16Type.get() + + def body(smem_ref): + nonlocal vector_load_op + zero = arith.constant(ir.IntegerType.get_signless(32), 0) + vector_load_op = vector.LoadOp( + ir.VectorType.get(shape, elt_ty), smem_ref, [zero] * len(shape) + ) + + with ir.InsertionPoint(self.module.body): + smem = ir.Attribute.parse("#gpu.address_space") + smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) + f = func.FuncOp.from_py_func(smem_ty)(body).func_op + + vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get( + [layouts_lib.to_layout_attr(fa.WGStridedFragLayout(shape, vec_size=4))] + ) + transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))]) + f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) + + mgpu.infer_transforms(self.module) + + self.assertSequenceEqual( + inference_utils.in_transforms(vector_load_op), [transforms] + ) + self.assertEmpty(inference_utils.out_transforms(vector_load_op)) + + def test_infer_transforms_for_vector_load_op_raises_on_mismatches(self): + vector_load_op = None + shape = (64, 64) + elt_ty = ir.BF16Type.get() + + def body(smem_ref): + nonlocal vector_load_op + zero = arith.constant(ir.IntegerType.get_signless(32), 
0) + vector_load_op = vector.LoadOp( + ir.VectorType.get(shape, elt_ty), smem_ref, [zero] * len(shape) + ) + + with ir.InsertionPoint(self.module.body): + smem = ir.Attribute.parse("#gpu.address_space") + smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) + f = func.FuncOp.from_py_func(smem_ty)(body).func_op + + vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get( + [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)] + ) + transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))]) + f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) + + with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"): + mgpu.infer_transforms(self.module) + + def test_infer_transforms_for_vector_store_op_derives_from_destination(self): + vector_store_op = None + shape = (64, 64) + elt_ty = ir.BF16Type.get() + + def body(smem_ref, value_to_store): + nonlocal vector_store_op + zero = arith.constant(ir.IntegerType.get_signless(32), 0) + vector_store_op = vector.StoreOp( + value_to_store, smem_ref, [zero] * len(shape) + ) + + with ir.InsertionPoint(self.module.body): + smem = ir.Attribute.parse("#gpu.address_space") + smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) + value_ty = ir.VectorType.get(shape, elt_ty) + func.FuncOp.from_py_func(smem_ty, value_ty)(body) + + vector_store_op.attributes["in_layouts"] = ir.ArrayAttr.get( + [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)] + ) + + mgpu.infer_transforms(self.module) + + expected_transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((8, 64)), + mgpu.dialect.SwizzleTransformAttr.get(128), + ]) + + self.assertSequenceEqual( + inference_utils.in_transforms(vector_store_op), [expected_transforms] + ) + self.assertEmpty(inference_utils.out_transforms(vector_store_op)) + + def test_infer_transforms_for_vector_store_op_derives_from_source(self): + vector_store_op = None + shape = (64, 64) + elt_ty = ir.BF16Type.get() + + def body(smem_ref, value_to_store): + nonlocal vector_store_op + zero = arith.constant(ir.IntegerType.get_signless(32), 0) + vector_store_op = vector.StoreOp( + value_to_store, smem_ref, [zero] * len(shape) + ) + + with ir.InsertionPoint(self.module.body): + smem = ir.Attribute.parse("#gpu.address_space") + smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) + value_ty = ir.VectorType.get(shape, elt_ty) + f = func.FuncOp.from_py_func(smem_ty, value_ty)(body).func_op + + vector_store_op.attributes["in_layouts"] = ir.ArrayAttr.get( + [layouts_lib.to_layout_attr(fa.WGStridedFragLayout(shape, vec_size=4))] + ) + transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))]) + f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) + + mgpu.infer_transforms(self.module) + + self.assertSequenceEqual( + inference_utils.in_transforms(vector_store_op), [transforms] + ) + self.assertEmpty(inference_utils.out_transforms(vector_store_op)) + + def test_infer_transforms_for_vector_store_op_raises_on_mismatches(self): + vector_store_op = None + shape = (64, 64) + elt_ty = ir.BF16Type.get() + + def body(smem_ref, value_to_store): + nonlocal vector_store_op + zero = arith.constant(ir.IntegerType.get_signless(32), 0) + vector_store_op = vector.StoreOp( + value_to_store, smem_ref, [zero] * len(shape) + ) + + with ir.InsertionPoint(self.module.body): + smem = ir.Attribute.parse("#gpu.address_space") + smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) + value_ty = ir.VectorType.get(shape, elt_ty) + f = func.FuncOp.from_py_func(smem_ty, 
value_ty)(body).func_op + + vector_store_op.attributes["in_layouts"] = ir.ArrayAttr.get( + [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)] + ) + transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))]) + f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) + + with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"): + mgpu.infer_transforms(self.module) + if __name__ == "__main__": parameterized.absltest.main(testLoader=jtu.JaxTestLoader()) From be5d13af77e28a18ca13ed193135a0157f1ac0cb Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 17 Mar 2025 12:42:17 -0700 Subject: [PATCH 19/34] Remove code that preserved _original_py_fns on C++ classes. This no longer appears to be used. PiperOrigin-RevId: 737715578 --- jax/_src/util.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/jax/_src/util.py b/jax/_src/util.py index 3e52a2584..0e28aea04 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -668,17 +668,12 @@ def use_cpp_class(cpp_cls: type[Any]) -> Callable[[type[T]], type[T]]: exclude_methods = {'__module__', '__dict__', '__doc__'} - originals = {} for attr_name, attr in cls.__dict__.items(): if attr_name not in exclude_methods: - if hasattr(_original_func(attr), "_use_cpp"): - originals[attr_name] = attr - else: + if not hasattr(_original_func(attr), "_use_cpp"): setattr(cpp_cls, attr_name, attr) cpp_cls.__doc__ = cls.__doc__ - # TODO(pschuh): Remove once fastpath is gone. - cpp_cls._original_py_fns = originals return cpp_cls return wrapper From 20658fabb3a2c01ddfec648a6df91bfaa7c27050 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 17 Mar 2025 13:16:52 -0700 Subject: [PATCH 20/34] Replace cached function get_replicated_hlo_sharding() with a constant. Small cleanup, no functional changes intended. PiperOrigin-RevId: 737727727 --- jax/_src/debugging.py | 2 +- jax/_src/sharding_impls.py | 12 +++++------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index 54ac2d5fd..b61b28e12 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -446,7 +446,7 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *, if len(devices) == 1: # If we only have one device in our computation, we can construct a # replicated HloSharding and call it right now. 
- _hlo_sharding_callback(sharding_impls.get_replicated_hlo_sharding()) + _hlo_sharding_callback(sharding_impls.replicated_hlo_sharding) return [] key = xc.encode_inspect_sharding_callback(_hlo_sharding_callback) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 019411c77..60e8c54a4 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -114,9 +114,7 @@ def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh): return sdy_sharding -@util.cache(max_size=128, trace_context_in_key=False) -def get_replicated_hlo_sharding(): - return xc.HloSharding.replicate() +replicated_hlo_sharding = xc.HloSharding.replicate() @use_cpp_class(xc.SingleDeviceSharding) @@ -183,7 +181,7 @@ class SingleDeviceSharding(jsharding.Sharding): return (self._device,) def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: - return get_replicated_hlo_sharding() + return replicated_hlo_sharding def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding: sdy_dim_sharding = [SdyDimSharding(axes=[], is_closed=True) @@ -401,7 +399,7 @@ def _op_sharding_to_pos_sharding( def _positional_sharding_to_xla_hlo_sharding( self, num_dimensions: int) -> xc.HloSharding: if self.shape == (1,) * self.ndim: - return get_replicated_hlo_sharding() + return replicated_hlo_sharding pbuf = xc.OpSharding() shape = self.shape[self.ndim - num_dimensions:] # 'rank promotion' of val @@ -603,7 +601,7 @@ class GSPMDSharding(jsharding.Sharding): @functools.cached_property def _hlo_sharding_hash(self): if self.is_fully_replicated: - return hash(get_replicated_hlo_sharding()) + return hash(replicated_hlo_sharding) return hash(self._hlo_sharding) def __eq__(self, other): @@ -669,7 +667,7 @@ class GSPMDSharding(jsharding.Sharding): @classmethod def get_replicated(cls, device_assignment, *, memory_kind: str | None = None): - return cls(tuple(device_assignment), get_replicated_hlo_sharding(), + return cls(tuple(device_assignment), replicated_hlo_sharding, memory_kind=memory_kind) From 4f704713100017c739b90785e99d2a1921d6529d Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 17 Mar 2025 13:17:34 -0700 Subject: [PATCH 21/34] Fix error in pallas tutorial PiperOrigin-RevId: 737727935 --- docs/pallas/tpu/sparse.ipynb | 10 +++++----- docs/pallas/tpu/sparse.md | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/docs/pallas/tpu/sparse.ipynb b/docs/pallas/tpu/sparse.ipynb index 5b37e7b05..ac3a0dad2 100644 --- a/docs/pallas/tpu/sparse.ipynb +++ b/docs/pallas/tpu/sparse.ipynb @@ -299,7 +299,7 @@ " ):\n", " \"\"\"A DSD (Dense = Sparse @ Dense) matmul kernel.\"\"\"\n", " del idxs_k_ref\n", - " blk_idx = pl.program_id(0)\n", + " blk_idx = pl.program_id(1)\n", " is_start = blk_idx == 0\n", " changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])\n", " @pl.when(is_start | changed_blocks)\n", @@ -314,13 +314,13 @@ " o_ref[...] 
= accum_scratch[...].astype(o_ref.dtype)\n", "\n", "\n", - "def x_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n", + "def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n", " del j, blk_idxs_i, blk_idxs_k\n", " return (blk_idx, 0, 0)\n", - "def y_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n", + "def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n", " del blk_idxs_i\n", " return (blk_idxs_k[blk_idx], j)\n", - "def o_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n", + "def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n", " del blk_idxs_k\n", " return (blk_idxs_i[blk_idx], j)\n", "\n", @@ -335,7 +335,7 @@ " num_scalar_prefetch=2,\n", " # Note that while num_blocks is static here, Pallas does support\n", " # dynamic grid sizes.\n", - " grid=(num_blocks, N // blk_N),\n", + " grid=(N // blk_N, num_blocks),\n", " in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),\n", " pl.BlockSpec((blk_K, blk_N), y_map),\n", " # Placeholder for a zeros-array used by input_output_aliases.\n", diff --git a/docs/pallas/tpu/sparse.md b/docs/pallas/tpu/sparse.md index 36a6e07e9..113f31d8b 100644 --- a/docs/pallas/tpu/sparse.md +++ b/docs/pallas/tpu/sparse.md @@ -239,7 +239,7 @@ def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs. ): """A DSD (Dense = Sparse @ Dense) matmul kernel.""" del idxs_k_ref - blk_idx = pl.program_id(0) + blk_idx = pl.program_id(1) is_start = blk_idx == 0 changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)]) @pl.when(is_start | changed_blocks) @@ -254,13 +254,13 @@ def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs. o_ref[...] = accum_scratch[...].astype(o_ref.dtype) -def x_map(blk_idx, j, blk_idxs_i, blk_idxs_k): +def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k): del j, blk_idxs_i, blk_idxs_k return (blk_idx, 0, 0) -def y_map(blk_idx, j, blk_idxs_i, blk_idxs_k): +def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k): del blk_idxs_i return (blk_idxs_k[blk_idx], j) -def o_map(blk_idx, j, blk_idxs_i, blk_idxs_k): +def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k): del blk_idxs_k return (blk_idxs_i[blk_idx], j) @@ -275,7 +275,7 @@ grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=2, # Note that while num_blocks is static here, Pallas does support # dynamic grid sizes. - grid=(num_blocks, N // blk_N), + grid=(N // blk_N, num_blocks), in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map), pl.BlockSpec((blk_K, blk_N), y_map), # Placeholder for a zeros-array used by input_output_aliases. From ecf7fde7147ee3d6575c74584fe497b457818249 Mon Sep 17 00:00:00 2001 From: Michael Hudgins Date: Mon, 17 Mar 2025 20:19:20 +0000 Subject: [PATCH 22/34] Add B200 testing to continuous workflow --- .github/workflows/pytest_cuda.yml | 3 ++- .github/workflows/wheel_tests_continuous.yml | 24 +++++++++++++++----- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index c3420d6b5..ae74da53e 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -54,7 +54,8 @@ jobs: runs-on: ${{ inputs.runner }} # TODO: Update to the generic ML ecosystem test containers when they are ready. 
container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest') ||
-                 (contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.1:latest') }}
+                 (contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.1:latest') ||
+                 (contains(inputs.cuda, '12.8') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest') }}
     name: "Pytest CUDA (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"
 
     env:
diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml
index fc4321724..ecdf43b13 100644
--- a/.github/workflows/wheel_tests_continuous.yml
+++ b/.github/workflows/wheel_tests_continuous.yml
@@ -110,18 +110,30 @@ jobs:
       fail-fast: false # don't cancel all jobs on failure
       matrix:
         # Python values need to match the matrix stategy in the artifact build jobs above
-        runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu"]
+        # See exclusions for what is fully tested
+        runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu","linux-x86-a4-224-b200-1gpu"]
         python: ["3.10",]
-        cuda: ["12.3", "12.1"]
+        cuda: ["12.1","12.3","12.8"]
         enable-x64: [1, 0]
         exclude:
-          # Run only a single configuration on H100 to save resources
+          # L4 does not run on cuda 12.8 but tests other configs
+          - runner: "linux-x86-g2-48-l4-4gpu"
+            cuda: "12.8"
+          # H100 runs only a single config, CUDA 12.3 Enable x64 1
+          - runner: "linux-x86-a3-8g-h100-8gpu"
+            cuda: "12.8"
           - runner: "linux-x86-a3-8g-h100-8gpu"
-            python: "3.10"
             cuda: "12.1"
           - runner: "linux-x86-a3-8g-h100-8gpu"
-            python: "3.10"
-            enable-x64: 0
+            enable-x64: "0"
+          # B200 runs only a single config, CUDA 12.8 Enable x64 1
+          - runner: "linux-x86-a4-224-b200-1gpu"
+            enable-x64: "0"
+          - runner: "linux-x86-a4-224-b200-1gpu"
+            cuda: "12.1"
+          - runner: "linux-x86-a4-224-b200-1gpu"
+            cuda: "12.3"
+
       name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }})"
     with:
       runner: ${{ matrix.runner }}

From b4966130a355d64759bcfae66a81153f32b68c89 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Mon, 17 Mar 2025 15:45:39 -0700
Subject: [PATCH 23/34] Compute tile index using tile-based coordinates

This reduces the chances of overflowing a 32-bit integer when computing
tile indices.

Add unit test to reproduce the overflow with the previous implementation
of `blocked_fold_in`.

PiperOrigin-RevId: 737778853
---
 jax/_src/blocked_sampler.py   | 31 +++++++++------
 tests/blocked_sampler_test.py | 68 ++++++++++++++++++++++++++++++-----
 2 files changed, 78 insertions(+), 21 deletions(-)

diff --git a/jax/_src/blocked_sampler.py b/jax/_src/blocked_sampler.py
index 3bc592d88..e4d2e2855 100644
--- a/jax/_src/blocked_sampler.py
+++ b/jax/_src/blocked_sampler.py
@@ -28,17 +28,17 @@ class SampleFn(Protocol):
   ...

-def _compute_scalar_index(iteration_index: Sequence[int], - total_size: Shape, - block_size: Shape, - block_index: Sequence[int]) -> int: - ndims = len(iteration_index) +def _compute_tile_index(block_index: Sequence[int], + total_size_in_blocks: Shape, + block_size_in_tiles: Shape, + tile_index_in_block: Sequence[int]) -> int: + ndims = len(block_index) dim_size = 1 total_idx = 0 for i in range(ndims-1, -1, -1): - dim_idx = block_index[i] + iteration_index[i] * block_size[i] + dim_idx = tile_index_in_block[i] + block_index[i] * block_size_in_tiles[i] total_idx += dim_idx * dim_size - dim_size *= total_size[i] + dim_size *= total_size_in_blocks[i] * block_size_in_tiles[i] return total_idx @@ -99,18 +99,23 @@ def blocked_fold_in( An N-dimensional nested list of keys required to sample the tiles corresponding to the block specified by `block_index`. """ - size_in_blocks = tuple( - _shape // _element for _shape, _element in zip(block_size, tile_size)) + block_size_in_tiles = tuple( + _shape // _element for _shape, _element in zip(block_size, tile_size) + ) + + total_size_in_blocks = tuple( + _shape // _element for _shape, _element in zip(total_size, block_size) + ) def _keygen_loop(axis, prefix): - if axis == len(size_in_blocks): + if axis == len(block_size_in_tiles): subtile_key = jax.random.fold_in( - global_key, _compute_scalar_index( - block_index, total_size, size_in_blocks, prefix)) + global_key, _compute_tile_index( + block_index, total_size_in_blocks, block_size_in_tiles, prefix)) return subtile_key else: keys = [] - for i in range(size_in_blocks[axis]): + for i in range(block_size_in_tiles[axis]): keys.append(_keygen_loop(axis+1, prefix+(i,))) return keys return _keygen_loop(0, tuple()) diff --git a/tests/blocked_sampler_test.py b/tests/blocked_sampler_test.py index 1f8f2b645..4c27e850c 100644 --- a/tests/blocked_sampler_test.py +++ b/tests/blocked_sampler_test.py @@ -37,18 +37,41 @@ def call_kernel( m, n = grid return jnp.concatenate([ jnp.concatenate([ - kernel(i, j, *args) for j in range(n)], axis=1) + kernel((i, j), *args) for j in range(n)], axis=1) for i in range(m)], axis=0) -def uniform_kernel(i: int, j: int, total_size, block_size, tile_size): - """Uniform random sampling kernel function.""" - global_key = jax.random.key(0) - keys = blocked_sampler.blocked_fold_in(global_key, +def call_kernel_3d( + kernel, + grid: tuple[int, int], + *args + ): + """Calls a kernel over a 3D grid and concatenates results to a single array.""" + depth, rows, cols = grid + return jnp.concatenate([ + jnp.concatenate([ + jnp.concatenate([ + jnp.array(kernel((i, j, k), *args)) + for k in range(cols)], axis=2) + for j in range(rows)], axis=1) + for i in range(depth)], axis=0) + + +def blocked_fold_in(block_index, key, total_size, block_size, tile_size): + """Folds in block_index into global_key.""" + return blocked_sampler.blocked_fold_in(key, total_size=total_size, block_size=block_size, tile_size=tile_size, - block_index=(i, j)) + block_index=block_index) + + +def uniform_kernel(block_index, key, total_size, block_size, tile_size): + """Uniform random sampling kernel function.""" + keys = blocked_fold_in(block_index, key, + total_size=total_size, + block_size=block_size, + tile_size=tile_size) return blocked_sampler.sample_block(jax.random.uniform, keys, block_size=block_size, @@ -74,17 +97,46 @@ class BlockedSamplerTest(jtu.JaxTestCase): ) def test_block_shape_invariance(self, total_size, block_size_a, block_size_b, tile_size, transpose_grid): + global_key = jax.random.key(0) grid_a = tuple(_tot 
// _blk for _tot, _blk in zip(total_size, block_size_a)) result_a = call_kernel( - uniform_kernel, grid_a, transpose_grid, + uniform_kernel, grid_a, transpose_grid, global_key, total_size, block_size_a, tile_size) grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b)) result_b = call_kernel( - uniform_kernel, grid_b, transpose_grid, + uniform_kernel, grid_b, transpose_grid, global_key, total_size, block_size_b, tile_size) np.testing.assert_array_equal(result_a, result_b) +class BlockedFoldInTest(jtu.JaxTestCase): + @parameterized.named_parameters( + # Check that sampling a tensor of total size > jnp.iinfo(jnp.uint32).max works + # as expected. Specifically, blocked key folding does not depend on the total + # size of the tensor, but only the total number of tiles. + # Using a 3D grid (with very large inner dimensions) triggers an overflow in a + # previous implementation of blocked_fold_in. + dict(testcase_name='4096x512_vs_1024x2048', + total_size=(2, 64 * 1024, 64 * 1024), block_size_a=(1, 4096, 512), + block_size_b=(1, 1024, 2048), tile_size=(1, 1024, 512)), + ) + def test_blocked_fold_in_shape_invariance(self, total_size, block_size_a, + block_size_b, tile_size): + global_key = jax.random.key(0) + grid_a = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_a)) + result_a = call_kernel_3d( + blocked_fold_in, grid_a, global_key, total_size, + block_size_a, tile_size) + + grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b)) + result_b = call_kernel_3d( + blocked_fold_in, grid_b, global_key, total_size, + block_size_b, tile_size) + np.testing.assert_array_equal(jax.random.key_data(result_a), + jax.random.key_data(result_b)) + + + if __name__ == "__main__": absltest.main() From 051687dc4c899df3d95c30b812ade401d8b31166 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 17 Mar 2025 16:29:02 -0700 Subject: [PATCH 24/34] [pallas] `pallas_call_p` is now parameterized by a mesh The mesh is necessary to add support for clusters to the Mosaic GPU backend. PiperOrigin-RevId: 737792129 --- jax/_src/pallas/core.py | 28 +++++-- jax/_src/pallas/hlo_interpreter.py | 3 +- jax/_src/pallas/mosaic/core.py | 8 +- jax/_src/pallas/mosaic/interpret.py | 3 +- .../pallas/mosaic/pallas_call_registration.py | 16 ++-- jax/_src/pallas/mosaic_gpu/core.py | 16 ++-- jax/_src/pallas/mosaic_gpu/lowering.py | 10 ++- .../mosaic_gpu/pallas_call_registration.py | 2 + jax/_src/pallas/pallas_call.py | 82 ++++++++++++++++--- .../pallas/triton/pallas_call_registration.py | 3 + 10 files changed, 133 insertions(+), 38 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 466f6037a..5342a6946 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -15,6 +15,7 @@ """Module for pallas-core functionality.""" from __future__ import annotations +import collections from collections.abc import Callable, Iterable, Iterator, Sequence import contextlib import copy @@ -1068,6 +1069,17 @@ def _core_map_abstract_eval(*args, jaxpr, mesh, **_): return [], effs +class Mesh(Protocol): + + @property + def backend(self) -> str: + ... + + @property + def shape(self) -> collections.OrderedDict[object, int]: + ... 
+ + _core_map_mesh_rules: dict[type[Any], Callable[..., Any]] = {} @@ -1075,9 +1087,8 @@ def default_mesh_discharge_rule( in_avals, out_avals, *args, - grid, + mesh, compiler_params, - backend, jaxpr, debug, interpret, @@ -1100,19 +1111,22 @@ def default_mesh_discharge_rule( if isinstance(eff, state_types.WriteEffect) ) any_spec = BlockSpec(memory_space=MemorySpace.ANY) + grid_spec = GridSpec( + grid=tuple(mesh.shape.items()), + in_specs=[any_spec] * len(in_avals), + out_specs=[any_spec] * len(modified_idxs), + ) from jax._src.pallas import pallas_call # Avoid circular dependency. - outs = pallas_call.pallas_call( + outs = pallas_call._pallas_call( body, name=name, out_shape=[in_avals[idx] for idx in modified_idxs], - in_specs=[any_spec] * len(in_avals), - out_specs=[any_spec] * len(modified_idxs), input_output_aliases={ in_idx: out_idx for out_idx, in_idx in enumerate(modified_idxs) }, - grid=grid, + grid_spec=grid_spec, + mesh=mesh, compiler_params=compiler_params, - backend=backend, interpret=interpret, debug=debug, cost_estimate=cost_estimate, diff --git a/jax/_src/pallas/hlo_interpreter.py b/jax/_src/pallas/hlo_interpreter.py index 8d7543b31..6fbe5e914 100644 --- a/jax/_src/pallas/hlo_interpreter.py +++ b/jax/_src/pallas/hlo_interpreter.py @@ -340,11 +340,12 @@ def pallas_call_hlo_interpret( debug: bool, input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: GridMapping, + mesh: pallas_core.Mesh | None, compiler_params: Any, cost_estimate: CostEstimate, out_avals: tuple[jax_core.AbstractValue, ...], ): - del compiler_params, cost_estimate, out_avals + del mesh, compiler_params, cost_estimate, out_avals debug_info = jaxpr.debug_info # If we're in interpret mode, we *scan* over the grid and eval the # discharged jaxpr. diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 3e60e471d..f582248ee 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -211,6 +211,10 @@ class TensorCoreMesh: devices: np.ndarray axis_names: Sequence[str] + @property + def backend(self) -> str: + return "mosaic_tpu" + @property def shape(self): return collections.OrderedDict(zip(self.axis_names, self.devices.shape)) @@ -259,7 +263,6 @@ def _tensorcore_mesh_discharge_rule( compiler_params = TPUCompilerParams() if len(mesh.shape) > 1: raise NotImplementedError("Mesh must be 1D") - core_axis_name, num_cores = list(mesh.shape.items())[0] if compiler_params.dimension_semantics is not None: raise ValueError( "dimension_semantics must be None for TensorCoreMesh" @@ -269,13 +272,12 @@ def _tensorcore_mesh_discharge_rule( out_avals, *args, jaxpr=jaxpr, - grid=((core_axis_name, num_cores),), + mesh=mesh, compiler_params=compiler_params.replace( dimension_semantics=(PARALLEL,) ), debug=debug, interpret=interpret, - backend="mosaic_tpu", cost_estimate=cost_estimate, name=name, ) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index a731bfdfd..e92de91f4 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -1351,12 +1351,13 @@ def interpret_pallas_call( debug: bool, input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: GridMapping, + mesh: pallas_core.Mesh | None, compiler_params: Any, cost_estimate: CostEstimate, out_avals: tuple[jax_core.AbstractValue, ...], interpret_params: TPUInterpretParams, ): - del debug, cost_estimate, out_avals + del debug, mesh, cost_estimate, out_avals # args contains: *dynamic_grid_sizes, *index, *inputs. (No consts?) 
dynamic_grid_args, scalars, input_args = split_list( diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 887c9629a..896af0c46 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -108,6 +108,7 @@ def pallas_call_tpu_lowering_rule( *in_nodes, jaxpr: jax_core.Jaxpr, grid_mapping: core.GridMapping, + mesh: pallas_core.Mesh | None, input_output_aliases: tuple[tuple[int, int], ...], debug: bool, interpret: bool, @@ -116,7 +117,8 @@ def pallas_call_tpu_lowering_rule( out_avals: tuple[jax_core.AbstractValue, ...], ): """Lowers a pallas_call to a Mosaic TPU custom call.""" - del interpret + del mesh, interpret # Unused. + debug_info = jaxpr._debug_info if debug: print(f"\nThe kernel jaxpr for pallas_call {debug_info.func_src_info}:") @@ -126,11 +128,11 @@ def pallas_call_tpu_lowering_rule( else: mosaic_params = {} - mesh = None + jax_mesh = None axis_context = ctx.module_context.axis_context if axis_context is not None: if isinstance(axis_context, sharding_impls.SPMDAxisContext): - mesh = axis_context.mesh + jax_mesh = axis_context.mesh mlir_ctx = mlir.JaxIrContext() mlir_ctx.append_dialect_registry(mlir.upstream_dialects) mlir_ctx.load_all_available_dialects() @@ -147,7 +149,7 @@ def pallas_call_tpu_lowering_rule( grid_mapping, jaxpr, dimension_semantics=dimension_semantics, - mesh=mesh, + mesh=jax_mesh, for_verification=for_verification, dynamic_shape_replacement_enabled=pallas_core.dynamic_shapes_export_enabled(), ) @@ -164,11 +166,11 @@ def pallas_call_tpu_lowering_rule( ) if promela_dump_path := _DUMP_PROMELA_TO.value: - num_devices = 1 if mesh is None else mesh.devices.size + num_devices = 1 if jax_mesh is None else jax_mesh.devices.size num_cores = ( jax.devices()[0].num_cores - if mesh is None - else mesh.devices[0].num_cores + if jax_mesh is None + else jax_mesh.devices[0].num_cores ) verification_module, _ = lower_module(for_verification=True) model = verification.export_promela_model( diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 2e491074c..630c1b8f4 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -18,7 +18,7 @@ from __future__ import annotations import abc import collections -from collections.abc import Sequence +from collections.abc import Iterable, Sequence import dataclasses import enum import itertools as it @@ -519,9 +519,16 @@ class GPUMesh: ) @property - def shape(self): + def backend(self) -> str: + return "mosaic_gpu" + + @property + def shape(self) -> collections.OrderedDict[object, int]: + pairs: Iterable[tuple[object, int]] if self.num_threads is not None: - pairs = zip(self.axis_names, (*self.grid, *self.cluster, self.num_threads)) + pairs = zip( + self.axis_names, (*self.grid, *self.cluster, self.num_threads) + ) else: pairs = tuple( zip( @@ -563,8 +570,7 @@ def _gpu_mesh_discharge_rule( out_avals, *args, jaxpr=jaxpr, - grid=tuple(mesh.shape.items()), - backend="mosaic_gpu", + mesh=mesh, compiler_params=compiler_params, debug=debug, interpret=interpret, diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 5c863baf6..6b06e6b7d 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -450,6 +450,7 @@ def _block_spec_from_block_mapping( def lower_pipelined_jaxpr_to_module( grid_mapping: pallas_core.GridMapping, + mesh: pallas_core.Mesh | None, jaxpr: 
jax_core.Jaxpr, compiler_params: dict[str, Any], cost_estimate: pallas_core.CostEstimate | None, @@ -473,7 +474,10 @@ def lower_pipelined_jaxpr_to_module( block_mappings, [grid_mapping.num_inputs] ) - if grid_mapping.grid_names: # Last dim corresponds to the warpgroup count + if mesh is not None: + assert isinstance(mesh, gpu_core.GPUMesh) + if mesh and mesh.num_threads is not None: + # Last dim corresponds to the warpgroup count. block = (128 * grid_mapping.grid[-1], 1, 1) grid = grid_mapping.grid[:-1] else: @@ -566,6 +570,7 @@ def lower_pipelined_jaxpr_to_module( parallel_grid, grid_mapping.grid_names, block, + mesh.cluster if mesh is not None else (), [bm.array_shape_dtype for bm in in_block_mappings], [bm.array_shape_dtype for bm in out_block_mappings], new_jaxpr, @@ -578,6 +583,7 @@ def lower_jaxpr_to_module( grid: Sequence[int], grid_names: Sequence[str], block: Sequence[int], + cluster: Sequence[int], in_shapes: Sequence[jax.ShapeDtypeStruct], out_shapes: Sequence[jax.ShapeDtypeStruct], jaxpr: jax_core.Jaxpr, @@ -640,7 +646,7 @@ def lower_jaxpr_to_module( mgpu_core._lower_as_gpu_kernel( body, grid=parallel_grid, - cluster=(), + cluster=cluster, block=block, in_shapes=in_shapes, out_shape=out_shapes, diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index a9e5ead8d..d506349fe 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -38,6 +38,7 @@ def pallas_call_lowering( debug: bool, input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: pallas_core.GridMapping, + mesh: pallas_core.Mesh | None, compiler_params: dict[str, Any], cost_estimate: pallas_core.CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], @@ -63,6 +64,7 @@ def pallas_call_lowering( lowering_result = lowering.lower_pipelined_jaxpr_to_module( grid_mapping, + mesh, jaxpr, compiler_params, cost_estimate, diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index b1e1da34f..d0b74b2e5 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -20,7 +20,7 @@ import dataclasses import enum from functools import partial, reduce import types -from typing import Any, Literal +from typing import Any, Literal, cast import jax from jax import lax @@ -119,6 +119,7 @@ def _pallas_call_jvp_rule( jaxpr: jax_core.Jaxpr, input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: GridMapping, + mesh: pallas_core.Mesh | None, debug: bool, interpret: bool, compiler_params: Any, @@ -133,6 +134,8 @@ def _pallas_call_jvp_rule( raise NotImplementedError if input_output_aliases: raise NotImplementedError("JVP with aliasing not supported.") + if mesh is not None: + raise NotImplementedError("pallas_call with a mesh does not support JVP") nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents] tangents = [t for t in tangents if type(t) is not ad_util.Zero] nonzero_tangents_with_outputs = nonzero_tangents + [True] * grid_mapping.num_outputs @@ -181,6 +184,7 @@ def _pallas_call_jvp_rule( *tangents, jaxpr=jvp_jaxpr, grid_mapping=jvp_grid_mapping, + mesh=mesh, interpret=interpret, debug=debug, input_output_aliases=(), @@ -317,6 +321,7 @@ def _batch_with_explicit_loop( *, jaxpr: jax_core.Jaxpr, grid_mapping: GridMapping, + mesh: pallas_core.Mesh | None, input_output_aliases: tuple[tuple[int, int], ...], debug: bool, interpret: bool, @@ -384,6 +389,7 @@ def _batch_with_explicit_loop( *batch_args, jaxpr=jaxpr, 
grid_mapping=grid_mapping, + mesh=mesh, input_output_aliases=input_output_aliases, debug=debug, interpret=interpret, @@ -413,6 +419,7 @@ def _pallas_call_batching_rule( *, jaxpr: jax_core.Jaxpr, grid_mapping: GridMapping, + mesh: pallas_core.Mesh | None, input_output_aliases: tuple[tuple[int, int], ...], debug: bool, interpret: bool, @@ -421,6 +428,11 @@ def _pallas_call_batching_rule( out_avals: tuple[jax_core.AbstractValue, ...], backend: _Backend | None, ): + if mesh is not None: + raise NotImplementedError( + "pallas_call with a mesh does not support batching" + ) + def _maybe_squeeze_out_bdim( x: jax.Array, bdim: int | batching.NotMapped ) -> jax.Array: @@ -445,6 +457,7 @@ def _pallas_call_batching_rule( *args, jaxpr=jaxpr, grid_mapping=grid_mapping, + mesh=mesh, input_output_aliases=input_output_aliases, debug=debug, interpret=interpret, @@ -478,6 +491,7 @@ def _pallas_call_batching_rule( dims=dynamic_grid_dims + dims, jaxpr=jaxpr, grid_mapping=grid_mapping, + mesh=mesh, input_output_aliases=input_output_aliases, debug=debug, interpret=interpret, @@ -512,6 +526,7 @@ def _pallas_call_batching_rule( dims=scalar_bdims + bdims, jaxpr=jaxpr, grid_mapping=grid_mapping, + mesh=mesh, input_output_aliases=input_output_aliases, debug=debug, interpret=interpret, @@ -890,6 +905,7 @@ def _pallas_call_batching_rule( *args, jaxpr=jaxpr, grid_mapping=batched_grid_mapping, + mesh=mesh, input_output_aliases=input_output_aliases, debug=debug, interpret=interpret, @@ -1339,12 +1355,13 @@ def _pallas_call_state_discharge_rule( jaxpr: jax_core.Jaxpr, input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: GridMapping, + mesh: pallas_core.Mesh | None, debug: bool, interpret: bool, compiler_params: Any, cost_estimate: CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], - backend: _Backend | None = None + backend: _Backend | None = None, ): del avals_out assert all(isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars) @@ -1440,6 +1457,7 @@ def _pallas_call_state_discharge_rule( jaxpr=new_jaxpr, input_output_aliases=new_input_output_aliases, grid_mapping=new_grid_mapping, + mesh=mesh, debug=debug, interpret=interpret, compiler_params=compiler_params, @@ -1526,16 +1544,6 @@ def pallas_call( invoke the Pallas kernel. """ - if compiler_params is None: - compiler_params = {} - if isinstance(compiler_params, pallas_core.CompilerParams): - if compiler_params.PLATFORM not in ["mosaic", "mosaic_gpu", "triton"]: - raise ValueError( - f"Unknown platform in compiler params: {compiler_params.PLATFORM}") - compiler_params = { - compiler_params.PLATFORM: dataclasses.asdict(compiler_params) - } - if grid_spec is None: grid_spec = GridSpec(grid, in_specs, out_specs, scratch_shapes) else: @@ -1556,6 +1564,55 @@ def pallas_call( "If `grid_spec` is specified, then `scratch_shapes` must " f"be `()`. 
It is {scratch_shapes}") del grid, in_specs, out_specs + return _pallas_call( + kernel, + out_shape, + grid_spec=grid_spec, + input_output_aliases=input_output_aliases, + debug=debug, + interpret=interpret, + name=name, + compiler_params=compiler_params, + cost_estimate=cost_estimate, + backend=backend, + ) + + +def _pallas_call( + kernel: Callable[..., None], + out_shape: Any, + *, + grid_spec: GridSpec, + mesh: pallas_core.Mesh | None = None, + input_output_aliases: dict[int, int] = {}, + debug: bool = False, + interpret: bool = False, + name: str | None = None, + compiler_params: dict[str, Any] | pallas_core.CompilerParams | None = None, + cost_estimate: CostEstimate | None = None, + backend: _Backend | None = None, +): + if compiler_params is None: + compiler_params = {} + if isinstance(compiler_params, pallas_core.CompilerParams): + if compiler_params.PLATFORM not in ["mosaic", "mosaic_gpu", "triton"]: + raise ValueError( + f"Unknown platform in compiler params: {compiler_params.PLATFORM}" + ) + compiler_params = { + compiler_params.PLATFORM: dataclasses.asdict(compiler_params) + } + + if mesh is not None: + if tuple(mesh.shape.values()) != grid_spec.grid: + raise ValueError( + f"Mesh shape {tuple(mesh.shape.values())} does not match grid " + f"shape {grid_spec.grid}." + ) + if backend is not None: + raise ValueError("If `mesh` is specified, then `backend` must be `None`.") + backend = cast(_Backend, mesh.backend) + grid_spec, dynamic_grid_bounds = pallas_core.unzip_dynamic_grid_bounds(grid_spec) # TODO(necula): this canonicalization may be convenient for some usage # but it is lossy, because it prevents expressing functions that return @@ -1643,6 +1700,7 @@ def pallas_call( debug=debug, interpret=interpret, grid_mapping=grid_mapping, + mesh=mesh, input_output_aliases=tuple(input_output_aliases.items()), compiler_params=compiler_params, cost_estimate=cost_estimate, diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index 4e3bd0697..4e8775e51 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -50,6 +50,7 @@ def pallas_call_lowering( debug: bool, input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: pallas_core.GridMapping, + mesh: pallas_core.Mesh | None, compiler_params: dict[str, Any], cost_estimate: pallas_core.CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], @@ -64,6 +65,8 @@ def pallas_call_lowering( raise NotImplementedError( "scalar prefetch not implemented in the Triton backend" ) + if mesh is not None: + raise NotImplementedError("mesh is not supported in the Triton backend") triton_params = compiler_params.get("triton", compiler_params) num_warps = triton_params.get("num_warps", 4) num_warps = 4 if num_warps is None else num_warps From 8c351917256ffbf48e34d983104b58d2fa2f3e92 Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Mon, 17 Mar 2025 16:48:57 -0700 Subject: [PATCH 25/34] Enable `jax.device_put` to a sharding with no local devices. 
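
A rough sketch of the call pattern this change is building towards (illustrative
only; it assumes at least two processes with `jax.distributed` initialized and the
empty-arrays config referenced in the diff turned on -- the exact flag spelling is
not asserted here):

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # Build a mesh out of the *other* processes' devices only, so that this
    # process holds none of the shards of the resulting sharding.
    remote = [d for d in jax.devices() if d.process_index != jax.process_index()]
    mesh = Mesh(np.array(remote), ("x",))
    s = NamedSharding(mesh, P("x"))
    y = jax.device_put(np.arange(8), s)

Before this change, `device_put` to a sharding that is not fully addressable
asserted that every process passed the same host value; with the empty-arrays
config enabled that check is skipped.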
PiperOrigin-RevId: 737797815 --- jax/_src/dispatch.py | 13 ++++++++----- jax/_src/interpreters/pxla.py | 2 +- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 050d6c394..2330f7628 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -466,11 +466,14 @@ def _device_put_sharding_impl(x, aval, device, copy): if not s.is_fully_addressable: if ((isinstance(x, array.ArrayImpl) and not x._committed) or type(x) in array_types): - multihost_utils.assert_equal( - x, fail_message=( - f"{type(x)} passed to device_put is not the same on each" - " process. Make sure you are passing the same value of" - f" {type(x)} on each process.")) + # TODO(emilyaf): Remove this condition when jit works when a sharding + # has no local devices. + if not config.enable_empty_arrays.value: + multihost_utils.assert_equal( + x, fail_message=( + f"{type(x)} passed to device_put is not the same on each" + " process. Make sure you are passing the same value of" + f" {type(x)} on each process.")) return _DeferredShardArg(x, s, aval, True, copy) # TODO(yashkatariya,mattjj): Link to a doc about McJAX and jax.Array. raise ValueError( diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index e28896802..c06eda521 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -237,7 +237,7 @@ def batched_device_put(aval: core.ShapedArray, if (isinstance(x, array.ArrayImpl) and dispatch.is_single_device_sharding(x.sharding) and x.devices() == {d})] - if len(bufs) == len(xs): + if len(bufs) == len(xs) > 0: return array.ArrayImpl( aval, sharding, bufs, committed=committed, _skip_checks=True) return xc.batched_device_put(aval, sharding, xs, list(devices), committed) From f174b00f23cb0a402a6bfe682188e6063b72928b Mon Sep 17 00:00:00 2001 From: Changhui Lin Date: Mon, 17 Mar 2025 17:17:28 -0700 Subject: [PATCH 26/34] Replace the uses of `PjRtClient::Compile()` with `PjRtClient::CompileAndLoad()`. This is to prepare for updating `PjRtClient::Compile()` to return an unloaded executable [1/N] PiperOrigin-RevId: 737805623 --- examples/jax_cpp/main.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/jax_cpp/main.cc b/examples/jax_cpp/main.cc index defbf1938..0a1d3a63a 100644 --- a/examples/jax_cpp/main.cc +++ b/examples/jax_cpp/main.cc @@ -81,7 +81,7 @@ int main(int argc, char** argv) { xla::XlaComputation xla_computation(test_module_proto); xla::CompileOptions compile_options; std::unique_ptr executable = - client->Compile(xla_computation, compile_options).value(); + client->CompileAndLoad(xla_computation, compile_options).value(); // Prepare inputs. 
xla::Literal literal_x = From 549973dec67c40c065274b05e889bcf3073a8870 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 17 Mar 2025 17:47:21 -0700 Subject: [PATCH 27/34] Allow pspec to be passed to device_put if there is a mesh in the surrounding context PiperOrigin-RevId: 737812111 --- jax/_src/api.py | 20 +++++++++++++++++--- tests/pjit_test.py | 13 +++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 4b14d8096..cdcc3e534 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -67,7 +67,9 @@ from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc from jax._src.lib import pmap_lib from jax._src.sharding import Sharding -from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind +from jax._src.mesh import get_concrete_mesh +from jax._src.sharding_impls import ( + PmapSharding, TransferToMemoryKind, PartitionSpec as P, NamedSharding) from jax._src.layout import Layout, AutoLayout from jax._src.traceback_util import api_boundary from jax._src import tree_util @@ -2280,11 +2282,20 @@ def _check_sharding(aval, s): (s,), (aval,), ("",), "device_put args", allow_uneven_sharding=False) s.shard_shape(aval.shape) # should raise an Error if incompatible +def pspec_to_sharding(val): + if isinstance(val, P): + mesh = get_concrete_mesh() + if mesh is None: + raise ValueError( + "Please set a mesh via `jax.sharding.use_mesh` if a PartitionSpec is" + " passed to device_put") + return NamedSharding(mesh, val) + return val def device_put( x, - device: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None, - *, src: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None, + device: None | xc.Device | Sharding | P | Layout | Any | TransferToMemoryKind = None, + *, src: None | xc.Device | Sharding | P | Layout | Any | TransferToMemoryKind = None, donate: bool | Any = False, may_alias: bool | None | Any = None): """Transfers ``x`` to ``device``. @@ -2333,6 +2344,9 @@ def device_put( src_flat = flatten_axes("device_put source", treedef, src) src_flat = list(map(_infer_src_sharding, src_flat, x_flat)) + device_flat = map(pspec_to_sharding, device_flat) + src_flat = map(pspec_to_sharding, src_flat) + if isinstance(donate, bool): donate_flat = [donate] * len(x_flat) else: diff --git a/tests/pjit_test.py b/tests/pjit_test.py index e3ac00133..7687bf110 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -6138,6 +6138,19 @@ class ShardingInTypesTest(jtu.JaxTestCase): self.assertDictEqual(out.sharding.mesh._axis_types_dict, {AxisType.Auto: ('x',)}) + @jtu.with_user_mesh((2,), 'x') + def test_device_put_use_mesh(self, mesh): + out = jax.device_put(np.arange(8), P('x')) + self.assertArraysEqual(out, np.arange(8)) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + def test_device_put_no_use_mesh_error(self): + with self.assertRaisesRegex( + ValueError, + 'Please set a mesh via `jax.sharding.use_mesh` if a PartitionSpec is' + ' passed to device_put'): + jax.device_put(np.arange(8), P('x')) + @jtu.with_user_mesh((2,), 'x') def test_inputs_different_context(self, mesh): np_inp = np.arange(16).reshape(8, 2) From 34cd5b0d747e3aa82e4ca9c60c20197ddb38dfb7 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 18 Mar 2025 03:12:23 -0700 Subject: [PATCH 28/34] [Mosaic GPU] Remove sub-byte conversion restriction XLA:GPU recently changed its endianness to little endian to better match LLVM and the rest of the CUDA ecosystem, so we can lift the earlier restrictions. 
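
For reference, a tiny sketch of the nibble order this assumes (plain Python, not
part of the change itself):

    # Little-endian sub-byte packing: element 0 of an int4 pair occupies the
    # low nibble of the byte, element 1 the high nibble.
    e0, e1 = 0x3, 0x5
    byte = (e1 << 4) | e0
    assert byte == 0x53            # big-endian packing would have produced 0x35
    assert byte & 0xF == e0 and byte >> 4 == e1

The generic conversion path previously had to reject sub-byte types because XLA
packed elements into bytes in big-endian order while LLVM assumes the
little-endian order of the target machine.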
PiperOrigin-RevId: 737934373 --- .../mosaic/gpu/fragmented_array.py | 10 +++++----- tests/mosaic/gpu_test.py | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index c6d7c02fb..dc5ad48c4 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1244,6 +1244,11 @@ class FragmentedArray: is_vector_reg = ir.VectorType.isinstance(reg_type) reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else (1,) [vector_len] = reg_shape # This is meant to be a 1D assertion. + if (new_reg_bitwidth := utils.bitwidth(new_dtype) * vector_len) % 8: + raise ValueError( + "Register bitwidth in target type must be divisible by 8, got" + f" {new_reg_bitwidth}" + ) if cur_dtype == i4 and self.is_signed and new_dtype == bf16: new_registers = np.empty_like(self.registers) out_vec_ty = ir.VectorType.get((vector_len,), new_dtype) @@ -1344,11 +1349,6 @@ class FragmentedArray: _registers=new_registers, _layout=self.layout, _is_signed=is_signed ) # Generic path. - # XLA packs elements into bytes in big-endian order, while LLVM assumes the - # same endianness as the target machine (which is little for NVIDIA GPUs). - # We'll need to add specialized casting routines that flip the endianness. - if 1 < utils.bitwidth(cur_dtype) < 8 or 1 < utils.bitwidth(new_dtype) < 8: - raise NotImplementedError("Conversion involving sub-byte types unsupported") from_float = ir.FloatType.isinstance(cur_dtype) to_float = ir.FloatType.isinstance(new_dtype) from_integer = ir.IntegerType.isinstance(cur_dtype) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 574299ab1..91644be5c 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -518,6 +518,25 @@ class WGMMALayoutTest(TestCase): )() np.testing.assert_array_equal(iota, expected) + @parameterized.parameters(jnp.int8, jnp.int16, jnp.int32) + def test_sub_byte_conversion(self, jax_dtype_to): + jax_dtype_from = jnp.int4 + def kernel(ctx, inp, out, smem): + del ctx # Unused. + smem_inp, smem_out = smem + copy(inp, smem_inp, swizzle=16) + t = mgpu.FragmentedArray.load_tiled(smem_inp, is_signed=True, swizzle=16) + t = t.astype(utils.dtype_to_ir_type(jax_dtype_to), is_signed=True) + t.store_tiled(smem_out, swizzle=32 * jnp.dtype(jax_dtype_to).itemsize) + copy(smem_out, out, swizzle=32 * jnp.dtype(jax_dtype_to).itemsize) + + x = self.prng.integers( + low=-8, high=7, size=(1, 1, 64, 64), dtype=np.int32 + ).astype(jax_dtype_from) + y = x.astype(jax_dtype_to) + f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, y, (x, y)) + np.testing.assert_array_equal(f(x), y) + @parameterized.product( jax_dtype_from_to=( (jnp.int8, jnp.bfloat16), From 38d52a19efb84ad54a2d30b5cdddb8e70b23cac2 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Tue, 18 Mar 2025 03:34:17 -0700 Subject: [PATCH 29/34] [mosaic_gpu] Force flush all cupti activity, then unsubscribe. With default flushing, it is possible for events to be missed. We should only unsubscribe after we are finished with cupti. 
PiperOrigin-RevId: 737939327 --- jaxlib/mosaic/gpu/mosaic_gpu_ext.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc index 2f415912f..4f804c9e2 100644 --- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc +++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc @@ -238,11 +238,12 @@ NB_MODULE(_mosaic_gpu_ext, m) { "failed to enable tracking of kernel activity by CUPTI"); }); m.def("_cupti_get_timings", []() { + THROW_IF_CUPTI_ERROR( + cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED), + "failed to flush CUPTI activity buffers"); + THROW_IF_CUPTI_ERROR(cuptiFinalize(), "failed to detach CUPTI"); THROW_IF_CUPTI_ERROR(cuptiUnsubscribe(profiler_state.subscriber), "failed to unsubscribe from CUPTI"); - THROW_IF_CUPTI_ERROR(cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_NONE), - "failed to flush CUPTI activity buffers"); - THROW_IF_CUPTI_ERROR(cuptiFinalize(), "failed to detach CUPTI"); return profiler_state.timings; }); } From d4bd2570ae32fe9c7329520c8d768b042910bc77 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 18 Mar 2025 04:47:04 -0700 Subject: [PATCH 30/34] [Mosaic GPU] Add a specialized layout for loading 4-bit inputs in WGMMA friendly layouts PiperOrigin-RevId: 737956598 --- .../mosaic/gpu/fragmented_array.py | 188 +++++++++++++----- tests/mosaic/gpu_test.py | 79 +++++--- 2 files changed, 190 insertions(+), 77 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index dc5ad48c4..ded17d5d4 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -526,6 +526,18 @@ WGMMA_LAYOUT_UPCAST_2X = TiledLayout( lane_dims=(-4, -2, -3), vector_dim=-1, ) +# This layout should be used when upcasting 4-bit elements to 16-bit, for the +# purpose of passing them into WGMMA later. The core matrices stored by a warp +# are 8x32, because each of the 4 threads in a row holds 8 elements in a single +# vector. Note that unlike WGMMA_LAYOUT_UPCAST_2X, we assign columns to each +# group of 4 threads in order (as opposed to the swapping between 1 and 2, +# 5 and 6, etc. that WGMMA_LAYOUT_UPCAST_2X does). +WGMMA_LAYOUT_UPCAST_4X = TiledLayout( + Tiling(((64, 32), (16, 32), (8, 32), (8,))), + warp_dim=-7, + lane_dims=(-3, -2), + vector_dim=-1, +) # This tiled layout is similar to WGMMA_LAYOUT. 
There, each warp stores a 8x8 # submatrix in the following way (we only show the first 4 rows for brevity): # @@ -739,58 +751,132 @@ class FragmentedArray: _layout=new_layout, _is_signed=self.is_signed, ) - if len(shape) == 2 and shape[0] % 64 == 0 and shape[1] % 16 == 0: - if ( - self.layout == WGMMA_LAYOUT_UPCAST_2X - and new_layout == WGMMA_LAYOUT - and (dtype_bitwidth := utils.bitwidth(self.mlir_dtype)) in {8, 16} - ): - assert shape[1] % 16 == 0 # Should be implied by the layout - new_registers = np.empty(new_layout.registers_shape(shape), dtype=object) - is_even = arith.cmpi( - arith.CmpIPredicate.eq, arith.remui(utils.thread_idx(), c(2)), c(0) + if ( + self.layout == WGMMA_LAYOUT_UPCAST_2X + and new_layout == WGMMA_LAYOUT + and (dtype_bitwidth := utils.bitwidth(self.mlir_dtype)) <= 16 + ): + assert shape[1] % 16 == 0 # Should be implied by the layout + new_registers = np.empty(new_layout.registers_shape(shape), dtype=object) + is_even = arith.cmpi( + arith.CmpIPredicate.eq, arith.remui(utils.thread_idx(), c(2)), c(0) + ) + registers = self.registers + if dtype_bitwidth == 4: + if registers.shape[1] % 2: + raise NotImplementedError( + "This relayout implementation requires an even number of column" + " tiles (to pack pairs of them for efficiency)" + ) + # We pair up the consecutive column tiles, so each register is 32-bit. + # If this layout originated from a WGMMA_LAYOUT_UPCAST_4X layout, + # LLVM will realize that the paired up vectors actually came from the + # same 32-bit register and it will become a no-op. + col_minor_registers = np.moveaxis(registers, 1, -1) + flat_registers = [ + utils.vector_concat((l, h)) + for l, h in zip( + col_minor_registers.flat[::2], col_minor_registers.flat[1::2] + ) + ] + registers = np.asarray(flat_registers, dtype=object).reshape( + *col_minor_registers.shape[:-1], col_minor_registers.shape[-1] // 2 ) - for idx, reg in np.ndenumerate(self.registers): - assert ir.VectorType(reg.type).shape == [4] - if dtype_bitwidth == 16: - # A single vector is 64-bits, but shuffles are only 32-bit wide. - # We only shuffle the half that needs to go to other thread. - low = utils.vector_slice(reg, slice(0, 2)) - high = utils.vector_slice(reg, slice(2, 4)) - to_exchange = arith.select(is_even, high, low) - # Exchange values between even and odd threads. - exchanged = utils.shfl_bfly(to_exchange, 1) - low = arith.select(is_even, low, exchanged) - high = arith.select(is_even, exchanged, high) - elif dtype_bitwidth == 8: - # The vector is 32-bits, so we just shuffle the whole thing and - # use prmt to blend it with the local register. - exchanged = utils.shfl_bfly(reg, 1) - # Consider lanes 0 and 1, because the situation is symmetric for - # each pair. If we feed reg[lane] and exchanged[lane] (which is - # really the same as reg of the other lane) to prmt, we can index - # the elements of the result using the following indices: - # reg[0]: 0 1 2 3 reg[1]: 8 9 10 11 - # prmt[0]: 0 1 2 3 4 5 6 7 - # prmt[1]: 4 5 6 7 0 1 2 3 - # The expected outputs and their respective permutations are: - # out[0]: 0 1 8 9 out[1]: 2 3 10 11 - # prmt[0]: 0 1 4 5 prmt[1]: 6 7 2 3 - # Note that the patterns still need to be flipped, since we listed - # bytes with LSB on the left, which is the opposite of how the - # numeric constants are spelled in Python (LSB on the right). 
- perm = arith.select(is_even, c(0x5410), c(0x3276)) - blend = utils.prmt(reg, exchanged, perm) - low = utils.vector_slice(blend, slice(0, 2)) - high = utils.vector_slice(blend, slice(2, 4)) - else: - raise NotImplementedError(dtype_bitwidth) + registers = np.moveaxis(registers, -1, 1) + for idx, reg in np.ndenumerate(registers): + if dtype_bitwidth == 16: + assert reg.type.shape == [4] + # A single vector is 64-bits, but shuffles are only 32-bit wide. + # We only shuffle the half that needs to go to other thread. + low = utils.vector_slice(reg, slice(0, 2)) + high = utils.vector_slice(reg, slice(2, 4)) + to_exchange = arith.select(is_even, high, low) + # Exchange values between even and odd threads. + exchanged = utils.shfl_bfly(to_exchange, 1) + low = arith.select(is_even, low, exchanged) + high = arith.select(is_even, exchanged, high) new_registers[(idx[0], idx[1] * 2, *idx[2:-1])] = low new_registers[(idx[0], idx[1] * 2 + 1, *idx[2:-1])] = high - assert all(r is not None for r in new_registers) - return FragmentedArray( - _registers=new_registers, _layout=new_layout, _is_signed=self.is_signed, - ) + elif dtype_bitwidth == 8: + assert reg.type.shape == [4] + # The vector is 32-bits, so we just shuffle the whole thing and + # use prmt to blend it with the local register. + exchanged = utils.shfl_bfly(reg, 1) + # Consider lanes 0 and 1, because the situation is symmetric for + # each pair. If we feed reg[lane] and exchanged[lane] (which is + # really the same as reg of the other lane) to prmt, we can index + # the elements of the result using the following indices: + # reg[0]: 0 1 2 3 reg[1]: 8 9 10 11 + # prmt[0]: 0 1 2 3 4 5 6 7 + # prmt[1]: 4 5 6 7 0 1 2 3 + # The expected outputs and their respective permutations are: + # out[0]: 0 1 8 9 out[1]: 2 3 10 11 + # prmt[0]: 0 1 4 5 prmt[1]: 6 7 2 3 + # Note that the patterns still need to be flipped, since we listed + # bytes with LSB on the left, which is the opposite of how the + # numeric constants are spelled in Python (LSB on the right). + perm = arith.select(is_even, c(0x5410), c(0x3276)) + blend = utils.prmt(reg, exchanged, perm) + for i in range(2): + reg = utils.vector_slice(blend, slice(i * 2, i * 2 + 2)) + new_registers[(idx[0], idx[1] * 2 + i, *idx[2:-1])] = reg + else: + assert dtype_bitwidth == 4 + assert reg.type.shape == [8] # We paired up the registers above. + exchanged = utils.shfl_bfly(reg, 1) + # See comment above for a more complete explanation. 
+ # reg[0]: 0 1 2 3 16 17 18 19 reg[1]: 8 9 10 11 24 25 26 27 + # prmt[0]: -0- -1- --2-- --3-- -4- --5-- --6-- --7-- + # prmt[1]: -4- -5- --6-- --7-- -0- --1-- --2-- --3-- + # The expected outputs and their respective permutations are: + # out[0]: 0 1 8 9 16 17 24 25 out[1]: 2 3 10 11 18 19 26 27 + # prmt[0]: -0- -4- --2-- --6-- prmt[1]: -5- --1-- --7-- --3-- + perm = arith.select(is_even, c(0x6240), c(0x3715)) + blend = utils.prmt(reg, exchanged, perm) + for i in range(4): + reg = utils.vector_slice(blend, slice(i * 2, i * 2 + 2)) + new_registers[(idx[0], idx[1] * 4 + i, *idx[2:-1])] = reg + assert all(r is not None for r in new_registers) + return FragmentedArray( + _registers=new_registers, _layout=new_layout, _is_signed=self.is_signed, + ) + if ( + self.layout == WGMMA_LAYOUT_UPCAST_4X + and new_layout == WGMMA_LAYOUT_UPCAST_2X + and utils.bitwidth(self.mlir_dtype) == 4 + ): + assert shape[0] % 64 == 0 # Should be implied by the layout + assert shape[1] % 32 == 0 # Should be implied by the layout + new_registers = np.empty(new_layout.registers_shape(shape), dtype=object) + i32 = ir.IntegerType.get_signless(32) + c = lambda x: arith.constant(i32, x) + is_01 = arith.cmpi( + arith.CmpIPredicate.ult, arith.remui(utils.thread_idx(), c(4)), c(2) + ) + for idx, reg in np.ndenumerate(self.registers): + assert ir.VectorType(reg.type).shape == [8] + # The vector is 32-bits, so we just shuffle the whole thing and + # use prmt to blend it with the local register. + exchanged = utils.shfl_bfly(reg, 2) + # See comments above for conventions. Here we exchange data between + # threads with lane index related by flipping 2nd bit (e.g. 0 and 2). + # reg[0]: 0 1 2 3 4 5 6 7 reg[2]: 16 17 18 19 20 21 22 23 + # prmt[0]: -0- -1- -2- -3- --4-- --5-- --6-- --7-- + # prmt[1]: -4- -5- -6- -7- --0-- --1-- --2-- --3-- + # The expected outputs and their respective permutations are: + # out[0]: 0 1 2 3 16 17 18 19 out[2]: 4 5 6 7 20 21 22 23 + # prmt[0]: -0- -1- --4-- --5-- prmt[2]: -6- -7- --2-- --3-- + perm = arith.select(is_01, c(0x5410), c(0x3276)) + blend = utils.prmt(reg, exchanged, perm) + for i in range(2): + reg = utils.vector_slice(blend, slice(i * 4, i * 4 + 4)) + new_registers[(idx[0], idx[1] * 2 + i, *idx[2:-1])] = reg + assert all(r is not None for r in new_registers) + return FragmentedArray( + _registers=new_registers, _layout=new_layout, _is_signed=self.is_signed, + ) + if self.layout == WGMMA_LAYOUT_UPCAST_4X and new_layout == WGMMA_LAYOUT: + return self.to_layout(WGMMA_LAYOUT_UPCAST_2X).to_layout(new_layout) if not isinstance(self.layout, WGSplatFragLayout): raise NotImplementedError( f"Cannot convert from {self.layout} to {new_layout}" @@ -1288,7 +1374,9 @@ class FragmentedArray: int_ty = ir.IntegerType.get_signless(group_size * 4) while vector_len - offset >= group_size: reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size)) - reg_slice_int = arith.extsi(i32, utils.bitcast(reg_slice, int_ty)) + reg_slice_int = utils.bitcast(reg_slice, int_ty) + if int_ty != i32: + reg_slice_int = arith.extsi(i32, reg_slice_int) reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32)) out_int_regs.extend( upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 91644be5c..bc56f21d0 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -14,6 +14,7 @@ # ============================================================================== from collections.abc import Sequence +import contextlib import dataclasses 
import enum import itertools @@ -83,6 +84,20 @@ def mlir_sum(elems): return total +@contextlib.contextmanager +def get_sass(): + prev_dump = os.environ.get("MOSAIC_GPU_DUMP_SASS", None) + os.environ["MOSAIC_GPU_DUMP_SASS"] = "1" + try: + with jtu.capture_stdout() as output: + yield output + finally: + if prev_dump is not None: + os.environ["MOSAIC_GPU_DUMP_SASS"] = prev_dump + else: + del os.environ["MOSAIC_GPU_DUMP_SASS"] + + def copy(src: ir.Value, dst: ir.Value, swizzle: int | None = None): index = ir.IndexType.get() thread_id = gpu.thread_id(gpu.Dimension.x) @@ -542,7 +557,11 @@ class WGMMALayoutTest(TestCase): (jnp.int8, jnp.bfloat16), (jnp.int4, jnp.bfloat16), ), - layout=(fa.WGMMA_LAYOUT, fa.WGMMA_LAYOUT_UPCAST_2X), + layout=( + fa.WGMMA_LAYOUT, + fa.WGMMA_LAYOUT_UPCAST_2X, + fa.WGMMA_LAYOUT_UPCAST_4X, + ), ) def test_optimized_conversion(self, jax_dtype_from_to, layout): jax_dtype_from, jax_dtype_to = jax_dtype_from_to @@ -2194,19 +2213,11 @@ class LayoutTest(TestCase): .transpose(0, 2, 1, 3) ) - prev_dump = os.environ.get("MOSAIC_GPU_DUMP_SASS", None) - os.environ["MOSAIC_GPU_DUMP_SASS"] = "1" - try: - with jtu.capture_stdout() as get_sass: - iota = mgpu.as_gpu_kernel( - kernel, (1, 1, 1), (128, 1, 1), expected, expected, - [expected, expected, mgpu.TMABarrier()], - )(expected) - finally: - if prev_dump is not None: - os.environ["MOSAIC_GPU_DUMP_SASS"] = prev_dump - else: - del os.environ["MOSAIC_GPU_DUMP_SASS"] + with get_sass() as sass: + iota = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), expected, expected, + [expected, expected, mgpu.TMABarrier()], + )(expected) np.testing.assert_array_equal(iota, expected) # Verify that we don't use too many registers for the transfers. @@ -2219,7 +2230,7 @@ class LayoutTest(TestCase): expected_regs //= 2 for instr in ("STS", "LDS"): with self.subTest(instr + " count"): - addrs = re.findall(instr + r".* \[(.*)\]", get_sass()) + addrs = re.findall(instr + r".* \[(.*)\]", sass()) def get_reg(addr): if (pos := addr.find("+")) != -1: return addr[:pos] @@ -2294,30 +2305,38 @@ class LayoutTest(TestCase): )(x) np.testing.assert_array_equal(y, y_ref) - @parameterized.product( - upcast_before_layout_change=[True, False], + @parameterized.parameters( + (fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int8, jnp.int8, 1), + (fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int8, jnp.int16, 1), + (fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT_UPCAST_2X, jnp.int4, jnp.int4, 1), + (fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int4, jnp.int4, 0.5), + (fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT, jnp.int4, jnp.int4, 2), ) - def test_upcast_to_wgmma(self, upcast_before_layout_change): - in_dtype = jnp.dtype(jnp.int8) + def test_upcast_to_wgmma( + self, start_layout, end_layout, in_dtype, cast_dtype, shfl_per_reg + ): + in_dtype = jnp.dtype(in_dtype) out_dtype = jnp.dtype(jnp.int16) + out_dtype_mlir = utils.dtype_to_ir_type(out_dtype) swizzle = 128 in_col_tiling = 8 * swizzle // jnp.iinfo(in_dtype).bits in_tiling = (8, in_col_tiling) out_col_tiling = swizzle // out_dtype.itemsize out_tiling = (8, out_col_tiling) m, n = 128, in_col_tiling * 2 + regs_per_thread = None def kernel(ctx, in_, out, smems): + nonlocal regs_per_thread smem_in, smem_out, barrier = smems ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier) barrier.wait() t = mgpu.FragmentedArray.load_tiled( - smem_in, swizzle=swizzle, is_signed=True, layout=fa.WGMMA_LAYOUT_UPCAST_2X + smem_in, swizzle=swizzle, is_signed=True, layout=start_layout ) - if 
upcast_before_layout_change: - t = t.astype(ir.IntegerType.get_signless(16), is_signed=True) - t = t.to_layout(fa.WGMMA_LAYOUT) - if not upcast_before_layout_change: - t = t.astype(ir.IntegerType.get_signless(16), is_signed=True) + regs_per_thread = t.registers.size + t = t.astype(utils.dtype_to_ir_type(cast_dtype), is_signed=True) + t = t.to_layout(end_layout) + t = t.astype(out_dtype_mlir, is_signed=True) t.store_tiled(smem_out, swizzle=swizzle) mgpu.commit_shared() ctx.async_copy(src_ref=smem_out, dst_ref=out, swizzle=swizzle) @@ -2326,14 +2345,20 @@ class LayoutTest(TestCase): return x.reshape( x.shape[0] // tiling[0], tiling[0], x.shape[1] // tiling[1], tiling[1] ).transpose(0, 2, 1, 3) - x = jax.random.randint(jax.random.key(42), (m, n), -128, 127, dtype=in_dtype) + in_iinfo = jnp.iinfo(in_dtype) + x = jax.random.randint( + jax.random.key(42), (m, n), in_iinfo.min, in_iinfo.max, dtype=jnp.int32 + ).astype(in_dtype) xt = tile(x, in_tiling) y = x.astype(out_dtype) yt = tile(y, out_tiling) f = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), xt, yt, [xt, yt, mgpu.TMABarrier()], ) - np.testing.assert_array_equal(f(xt), yt) + with get_sass() as sass: + yt_kernel = f(xt) + np.testing.assert_array_equal(yt_kernel, yt) + self.assertEqual(sass().count("SHFL.BFLY"), regs_per_thread * shfl_per_reg) @dataclasses.dataclass(frozen=True) From ba2f7c9ad96c77a88c8cc7eb2d0fd859f517a43a Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 18 Mar 2025 04:53:00 -0700 Subject: [PATCH 31/34] [Mosaic GPU] Add transform inference rule for `mgpu.slice_smem`. PiperOrigin-RevId: 737957778 --- .../mosaic/gpu/transform_inference.py | 29 +++++++- tests/mosaic/gpu_transform_inference_test.py | 72 +++++++++++++++++++ 2 files changed, 100 insertions(+), 1 deletion(-) diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py index a3919ea1d..be3f2c381 100644 --- a/jax/experimental/mosaic/gpu/transform_inference.py +++ b/jax/experimental/mosaic/gpu/transform_inference.py @@ -43,7 +43,8 @@ _transform_inference_rules: dict[str, TransformInferenceRule] = {} def _add_transform_inference_rule( op: type[ir.OpView], rule: TransformInferenceRule ): - _transform_inference_rules[op.OPERATION_NAME] = rule # pytype: disable=attribute-error + if op is not None: + _transform_inference_rules[op.OPERATION_NAME] = rule # pytype: disable=attribute-error return rule @@ -169,6 +170,32 @@ def _infer_vector_load_store_transforms( return None +# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2. +SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None) + +@partial(_add_transform_inference_rule, SliceSMEMOp) +def _infer_slice_smem_transforms(op: SliceSMEMOp) -> OptionalTransforms: + transforms = None + uses = cast(ir.OpResult, op.result).uses + + for op_operand_use in uses: + consumer = op_operand_use.owner + op_user = consumer.operands[op_operand_use.operand_number] + out_transforms = inference_utils.in_transforms_for_operand( + consumer, op_user + ) + if transforms is not None and out_transforms is not None: + if transforms != out_transforms: + raise NotImplementedError( + f"Conflicting transforms for {op_user} in {op}: " + f"{transforms} != {out_transforms}." 
+ ) + elif out_transforms is not None: + transforms = out_transforms + + return None if transforms is None else ([], [transforms]) + + def _should_have_transforms(op: ir.OpView) -> bool: """Returns 'True' if the operation should be assigned in/out transforms.""" return any( diff --git a/tests/mosaic/gpu_transform_inference_test.py b/tests/mosaic/gpu_transform_inference_test.py index 2618c22ac..b7cd146df 100644 --- a/tests/mosaic/gpu_transform_inference_test.py +++ b/tests/mosaic/gpu_transform_inference_test.py @@ -346,6 +346,78 @@ class TransformInferenceTest(parameterized.TestCase): with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"): mgpu.infer_transforms(self.module) + def test_infer_transforms_for_slice_smem_op_derives_from_user(self): + slice_smem_op = vector_load_op = None + shape = (64, 64) + elt_ty = ir.BF16Type.get() + smem = ir.Attribute.parse("#gpu.address_space") + + def body(offset): + nonlocal slice_smem_op, vector_load_op + slice_smem_op = mgpu.dialect.SliceSMEMOp( + ir.MemRefType.get(shape, elt_ty, memory_space=smem), offset + ) + zero = arith.constant(ir.IntegerType.get_signless(32), 0) + load_offsets = [zero] * len(shape) + vector_load_op = vector.LoadOp( + ir.VectorType.get(shape, elt_ty), slice_smem_op.result, load_offsets + ) + + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func(ir.IntegerType.get_signless(32))(body) + + vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get( + [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)] + ) + + mgpu.infer_transforms(self.module) + + expected_transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((8, 64)), + mgpu.dialect.SwizzleTransformAttr.get(128), + ]) + + self.assertEmpty(inference_utils.in_transforms(slice_smem_op)) + self.assertSequenceEqual( + inference_utils.out_transforms(slice_smem_op), [expected_transforms] + ) + + def test_infer_transforms_for_slice_smem_op_raises_on_mismatches(self): + slice_smem_op = vector_load_op1 = vector_load_op2 = None + shape = (64, 64) + elt_ty = ir.BF16Type.get() + smem = ir.Attribute.parse("#gpu.address_space") + + def body(offset): + nonlocal slice_smem_op, vector_load_op1, vector_load_op2 + slice_smem_op = mgpu.dialect.SliceSMEMOp( + ir.MemRefType.get(shape, elt_ty, memory_space=smem), offset + ) + zero = arith.constant(ir.IntegerType.get_signless(32), 0) + load_offsets = [zero] * len(shape) + vector_load_op1 = vector.LoadOp( + ir.VectorType.get(shape, elt_ty), slice_smem_op.result, load_offsets + ) + vector_load_op2 = vector.LoadOp( + ir.VectorType.get(shape, elt_ty), slice_smem_op.result, load_offsets + ) + + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func(ir.IntegerType.get_signless(32))(body) + + vector_load_op1.attributes["out_layouts"] = ir.ArrayAttr.get( + [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)] + ) + vector_load_op2.attributes["out_layouts"] = ir.ArrayAttr.get( + [layouts_lib.to_layout_attr(fa.WGStridedFragLayout(shape, vec_size=4))] + ) + vector_load_op2.attributes["in_transforms"] = ir.ArrayAttr.get( + [ir.ArrayAttr.get([mgpu.dialect.TransposeTransformAttr.get((1, 0))])] + ) + + with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"): + mgpu.infer_transforms(self.module) + if __name__ == "__main__": parameterized.absltest.main(testLoader=jtu.JaxTestLoader()) From 7a459f0ed1016a702f5cf0918079d6ee29abbeff Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 18 Mar 2025 05:00:46 -0700 Subject: [PATCH 32/34] Update XLA dependency to use revision 
http://github.com/openxla/xla/commit/3bb765472122548cc227b8bd2990f00bd533f438. PiperOrigin-RevId: 737959582 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index b374e3a29..9f2f77500 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "fcf97e619e26fcb19cffa060df2d0246f6a7ece7" -XLA_SHA256 = "ed7fb9863ea1e20a16bdfb135e48ea39c4b232ef2fd49e173de4e2e43fa76e09" +XLA_COMMIT = "3bb765472122548cc227b8bd2990f00bd533f438" +XLA_SHA256 = "72126aac7602153aee985ca20f73d11c39e3ba9cfb8027492951e787559d0497" def repo(): tf_http_archive( From 8da93249d25f08eb95de8f19682c6aded8176d41 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 18 Mar 2025 05:37:52 -0700 Subject: [PATCH 33/34] [Mosaic GPU] Fuse slicing into s4 -> bf16 upcasts This allows us to significantly simplify the generated PTX/SASS, which is currently cluttered with LLVM trying to align slices to start at bit 0 and failing to CSE the right shifts. PiperOrigin-RevId: 737967890 --- .../mosaic/gpu/fragmented_array.py | 33 +++++++++---- jax/experimental/mosaic/gpu/utils.py | 16 +++---- jaxlib/mosaic/gpu/BUILD | 1 + jaxlib/mosaic/gpu/passes.cc | 46 +++++++++++++++++++ 4 files changed, 78 insertions(+), 18 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index ded17d5d4..8b8fdaceb 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1373,15 +1373,30 @@ class FragmentedArray: for group_size in (8, 4, 2): int_ty = ir.IntegerType.get_signless(group_size * 4) while vector_len - offset >= group_size: - reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size)) - reg_slice_int = utils.bitcast(reg_slice, int_ty) - if int_ty != i32: - reg_slice_int = arith.extsi(i32, reg_slice_int) - reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32)) - out_int_regs.extend( - upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part) - for part in range(group_size // 2) - ) + # If the vector originates from a slice (common after relayouts), we + # can fuse the slicing into the conversion and prevent LLVM from + # generating a bunch of shifts to align the vector data to the LSB. + # This also lets us share the right shift among more vectors. 
+        if (isinstance(slice_op := reg.owner.opview, vector.ExtractStridedSliceOp)
+            and utils.bitwidth(slice_op.vector.type) == 32
+            and slice_op.strides[0].value == 1):
+          slice_offset = slice_op.offsets[0].value + offset
+          reg_int = utils.bitcast(slice_op.vector, i32)
+          reg_int_shr = arith.shrui(reg_int, c(4, i32))
+          out_int_regs.extend(
+              upcast_to_bf16(reg_int, reg_int_shr, part=(slice_offset // 2 + part))
+              for part in range(group_size // 2)
+          )
+        else:
+          reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size))
+          reg_slice_int = utils.bitcast(reg_slice, int_ty)
+          if int_ty != i32:
+            reg_slice_int = arith.extsi(i32, reg_slice_int)
+          reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32))
+          out_int_regs.extend(
+              upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part)
+              for part in range(group_size // 2)
+          )
         offset += group_size
     assert offset == vector_len
     out_vec_int = utils.vector_concat([

diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py
index 91cb19746..28534cf40 100644
--- a/jax/experimental/mosaic/gpu/utils.py
+++ b/jax/experimental/mosaic/gpu/utils.py
@@ -346,7 +346,7 @@ def bitwidth_impl(ty: ir.Type):
     return ir.IntegerType(ty).width
   if ir.FloatType.isinstance(ty):
     return ir.FloatType(ty).width
-  if dialect is not None and ir.Type.parse("!mosaic_gpu.barrier"):
+  if dialect is not None and ty == ir.Type.parse("!mosaic_gpu.barrier"):
     return MBARRIER_BYTES * 8
   if ir.VectorType.isinstance(ty):
     vty = ir.VectorType(ty)
@@ -1237,17 +1237,15 @@ def ceil_div(x: int, y: int):
 
 
 def vector_slice(v: ir.Value, s: slice):
-  i32 = ir.IntegerType.get_signless(32)
   v_ty = ir.VectorType(v.type)
   if len(v_ty.shape) != 1:
-    raise NotImplementedError
+    raise NotImplementedError(v_ty)
   [v_len] = v_ty.shape
-  it = range(v_len)[s]
-  result = llvm.mlir_undef(ir.VectorType.get((len(it),), v_ty.element_type))
-  for tgt, src in enumerate(it):
-    elem = llvm.extractelement(v, c(src, i32))
-    result = llvm.insertelement(result, elem, c(tgt, i32))
-  return result
+  slice_length = len(range(v_len)[s])
+  return vector.extract_strided_slice(
+      ir.VectorType.get((slice_length,), v_ty.element_type),
+      v, [s.start or 0], [slice_length], [1],
+  )
 
 
 def vector_concat(vectors: Sequence[ir.Value]) -> ir.Value:

diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD
index 0686db098..9249ae256 100644
--- a/jaxlib/mosaic/gpu/BUILD
+++ b/jaxlib/mosaic/gpu/BUILD
@@ -65,6 +65,7 @@ cc_library(
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TransformUtils",
+        "@llvm-project//mlir:VectorDialect",
     ],
 )

diff --git a/jaxlib/mosaic/gpu/passes.cc b/jaxlib/mosaic/gpu/passes.cc
index cee34ddae..b8c3fbb74 100644
--- a/jaxlib/mosaic/gpu/passes.cc
+++ b/jaxlib/mosaic/gpu/passes.cc
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include "jaxlib/mosaic/gpu/passes.h"
+#include <cstdint>
 #include
 #include
 #include "mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/include/mlir/IR/BuiltinAttributes.h"
 #include "mlir/include/mlir/IR/BuiltinOps.h"
 #include "mlir/include/mlir/IR/SymbolTable.h"
@@ -36,6 +38,49 @@ namespace gpu {
 namespace {
 
+// Upstream MLIR does not implement an LLVM lowering pattern for this op.
+struct ConvertExtractStridedSlicePattern final
+    : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
+  using OpConversionPattern::OpConversionPattern;
+  mlir::LogicalResult matchAndRewrite(
+      mlir::vector::ExtractStridedSliceOp op, OpAdaptor subst,
+      mlir::ConversionPatternRewriter &rewriter) const override {
+    auto vty = op.getSourceVectorType();
+    if (vty.getRank() != 1) {
+      return rewriter.notifyMatchFailure(op, "only 1-D vectors are supported");
+    }
+    int64_t size =
+        (*op.getSizes().getAsRange<mlir::IntegerAttr>().begin()).getSInt();
+    if (size < 0) {
+      return rewriter.notifyMatchFailure(op, "size is negative");
+    }
+    int64_t start =
+        (*op.getOffsets().getAsRange<mlir::IntegerAttr>().begin()).getSInt();
+    int64_t stride =
+        (*op.getStrides().getAsRange<mlir::IntegerAttr>().begin()).getSInt();
+    if (stride != 1) {
+      return rewriter.notifyMatchFailure(op, "only stride 1 is supported");
+    }
+    if (start < 0 || start + size > vty.getShape()[0]) {
+      return rewriter.notifyMatchFailure(op, "slice is out of bounds");
+    }
+    mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(
+        op.getLoc(), op.getResult().getType());
+    for (int64_t i = 0; i < size; ++i) {
+      result = rewriter.create<mlir::LLVM::InsertElementOp>(
+          op.getLoc(), result,
+          rewriter.create<mlir::LLVM::ExtractElementOp>(
+              op.getLoc(), subst.getVector(),
+              rewriter.create<mlir::LLVM::ConstantOp>(
+                  op.getLoc(), rewriter.getI32IntegerAttr(i + start))),
+          rewriter.create<mlir::LLVM::ConstantOp>(
+              op.getLoc(), rewriter.getI32IntegerAttr(i)));
+    }
+    rewriter.replaceOp(op, result);
+    return mlir::success();
+  }
+};
+
 class ConvertGpuToLLVMPass : public jaxlib::mlir::Pass {
  public:
@@ -58,6 +103,7 @@ class ConvertGpuToLLVMPass
     });
     auto symtab = mlir::SymbolTable(getOperation());
     mlir::populateGpuToLLVMConversionPatterns(converter, patterns, false);
+    patterns.insert<ConvertExtractStridedSlicePattern>(&getContext());
     if (mlir::applyPartialConversion(getOperation(), target, std::move(patterns))
             .failed()) {

From 1e36cbe59708764a87a6f72babf4bb5e8dd00c74 Mon Sep 17 00:00:00 2001
From: Benjamin Chetioui
Date: Tue, 18 Mar 2025 06:28:05 -0700
Subject: [PATCH 34/34] [Mosaic GPU] Raise a `NotImplementedError` if `swizzle=16`.

Unswizzled MMAs don't lower correctly, and are not currently intended to be
supported.

PiperOrigin-RevId: 737981373
---
 jax/experimental/mosaic/gpu/tcgen05.py | 2 ++
 jax/experimental/mosaic/gpu/wgmma.py | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py
index e5a2d3aa5..3330500cd 100644
--- a/jax/experimental/mosaic/gpu/tcgen05.py
+++ b/jax/experimental/mosaic/gpu/tcgen05.py
@@ -83,6 +83,8 @@ def mma(
     accumulate: ir.Value | bool = True,
     collective: bool = False,
 ):
+  if a_swizzle == 16 or b_swizzle == 16:
+    raise NotImplementedError("No swizzle is not supported")
   i32 = ir.IntegerType.get_signless(32)
   i64 = ir.IntegerType.get_signless(64)
   if isinstance(accumulate, bool):

diff --git a/jax/experimental/mosaic/gpu/wgmma.py b/jax/experimental/mosaic/gpu/wgmma.py
index ce0c5946a..8baa16d8a 100644
--- a/jax/experimental/mosaic/gpu/wgmma.py
+++ b/jax/experimental/mosaic/gpu/wgmma.py
@@ -259,6 +259,8 @@ def wgmma(
   The refs must be contiguous or be contiguous except for having their two
   minor dimensions swapped.
   """
+  if swizzle == 16:
+    raise NotImplementedError("No swizzle is not supported")
   # Step 1. Establish the shape and element type of the operation.
   if not ir.MemRefType.isinstance(b.type):
     raise ValueError(f"B must be a memref, got: {b.type}")
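
To make the intent of [PATCH 33/34] above concrete: fusing the slice into the
s4 -> bf16 upcast means indexing nibbles of the original 32-bit register at
slice_offset + i, so one shared right shift replaces the per-slice realignment
to bit 0 that LLVM would otherwise emit. The following pure-Python model is an
illustration only; the actual change manipulates MLIR vectors and i32 registers
through the LLVM dialect, and the nibble helper below is hypothetical.

    # Pure-Python model of the fusion in [PATCH 33/34]; illustrative only.
    def nibble(word: int, idx: int) -> int:
      """Returns the idx-th 4-bit field of a 32-bit word."""
      return (word >> (4 * idx)) & 0xF

    packed = 0x87654321          # eight 4-bit fields packed into one i32
    slice_offset, group_size = 2, 4

    # Unfused: materialize the slice at bit 0 first (extra shift per slice)...
    sliced = (packed >> (4 * slice_offset)) & ((1 << (4 * group_size)) - 1)
    unfused = [nibble(sliced, i) for i in range(group_size)]

    # ...fused: index into the original word at slice_offset + i instead.
    fused = [nibble(packed, slice_offset + i) for i in range(group_size)]

    assert unfused == fused == [3, 4, 5, 6]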
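And for [PATCH 34/34]: swizzle=16 is Mosaic GPU's encoding of "no swizzle", so
the new checks reject unswizzled MMA operands up front rather than producing
incorrect lowerings. A minimal sketch of the guard's behaviour follows; the
helper name is made up for illustration and is not part of the patch.

    # Sketch of the check added to tcgen05.mma and wgmma; names illustrative.
    def _reject_unswizzled(a_swizzle: int, b_swizzle: int) -> None:
      if a_swizzle == 16 or b_swizzle == 16:  # 16 encodes "no swizzle"
        raise NotImplementedError("No swizzle is not supported")

    _reject_unswizzled(128, 128)   # swizzled operands: accepted
    try:
      _reject_unswizzled(16, 128)  # unswizzled A operand: rejected
    except NotImplementedError as e:
      print(e)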