1371 Commits

Author SHA1 Message Date
jax authors
f3b7c5cb9e Integrate LLVM at llvm/llvm-project@0230d63b4a
Updates LLVM usage to match
[0230d63b4a8b](https://github.com/llvm/llvm-project/commit/0230d63b4a8b)

PiperOrigin-RevId: 738222096
2025-03-18 19:23:20 -07:00
Peter Hawkins
3f91b4b43a Move jaxlib/{cuda,rocm}_plugin_extension into jaxlib/{cuda/rocm}/
Move the common jaxlib/gpu_plugin_extension into jaxlib/gpu/

Cleanup only, no functional changes intended.

PiperOrigin-RevId: 738183402
2025-03-18 16:29:37 -07:00
Peter Hawkins
547d602760 Remove //jaxlib:cpu_kernels and //jaxlib:gpu_kernels forwarding Bazel targets.
These were temporary forwarding targets that are no longer needed; use //jaxlib/cpu:cpu_kernels and //jaxlib/cuda:cuda_gpu_kernels instead.

PiperOrigin-RevId: 738085234
2025-03-18 11:39:00 -07:00
Adam Paszke
8da93249d2 [Mosaic GPU] Fuse slicing into s4 -> bf16 upcasts
This allows us to significantly simplify the generated PTX/SASS,
which is currently cluttered with LLVM trying to align slices to
start at bit 0 and failing to CSE the right shifts.

PiperOrigin-RevId: 737967890
2025-03-18 05:38:49 -07:00
Chris Jones
38d52a19ef [mosaic_gpu] Force flush all cupti activity, then unsubscribe.
With default flushing, it is possible for events to be missed. We should only unsubscribe after we are finished with cupti.

PiperOrigin-RevId: 737939327
2025-03-18 03:35:03 -07:00
Peter Hawkins
14cb7453f0 Add a C++ implementation of a toplogical sort.
This is an exact port of the current Python implementation to C++ for speed.

I am being careful not to change the topological order we return in any way in this change, although we may do so in a future change.

PiperOrigin-RevId: 737014989
2025-03-14 16:04:25 -07:00
Daniel Suo
39e8ee93b0 Add experimental/serialize_executable.py to BUILD.
PiperOrigin-RevId: 736975882
2025-03-14 13:54:39 -07:00
Peter Hawkins
8ab33669e2 Add a variant of safe_map() that has no return value, named foreach().
This avoids a bunch of list bookkeeping in cases where we are iterating only for a side effect and do not care about the results.

I would have named this iter() to match OCaml's list.iter(), but unfortunately iter() is a Python builtin.

PiperOrigin-RevId: 736859418
2025-03-14 07:42:48 -07:00
Tzu-Wei Sung
e235fb9760 [Mosaic] Allow part of x2 int casts.
This should at least allow int2 -> int4 for native tiling vregs. Skip many tests due to XLA compatibility.

PiperOrigin-RevId: 736710186
2025-03-13 18:57:36 -07:00
jax authors
538a2be7fe Reverts 74b4d868e3751c1b4efa315ff8cf771faeb0b663
PiperOrigin-RevId: 736650031
2025-03-13 14:59:09 -07:00
Tzu-Wei Sung
a0f1be123d [Mosaic] Improve error messages.
PiperOrigin-RevId: 736580673
2025-03-13 11:35:33 -07:00
Jevin Jiang
12c0987e2f [Mosaic TPU][NFC] Throw NYI error instead of crash when squeeze ref to 1d.
PiperOrigin-RevId: 736263705
2025-03-12 14:18:33 -07:00
Dan Foreman-Mackey
8b7cfcb33c Fix integer overflow in workspace size computations for experimental.rnn.*.
PiperOrigin-RevId: 736139471
2025-03-12 08:22:04 -07:00
Chris Jones
74b4d868e3 Add support for scratch buffers in jax_triton.
This is required to use device-side TMA descriptors.

PiperOrigin-RevId: 735985603
2025-03-11 20:49:33 -07:00
Dimitar (Mitko) Asenov
99c9106032 [Mosaic GPU] Replace WGMMAFragLayout with TiledLayout in the mlir dialect and use it in layout inference.
`WGMMAFragLayout` will be completely removed soon.

PiperOrigin-RevId: 735877661
2025-03-11 13:50:42 -07:00
jax authors
1aca76fc13 Update :build_jaxlib flag to control whether we should add py_import dependencies to the test targets.
This change enables testing the wheels produced by the build rules in the presubmit using one `bazel test` command only.

There are three options for running the tests:

1) `build_jaxlib=true`: the tests depend on JAX targets.
2) `build_jaxlib=false`: the tests depend on the wheel files located in the `dist` folder.
3) `build_jaxlib=wheel`: the tests depend on the py_import targets.

PiperOrigin-RevId: 735765819
2025-03-11 08:31:43 -07:00
Dan Foreman-Mackey
21884d4a14 Move (most) jaxlib linalg custom call registration into JAX.
My motivation here is to fix the plugin support for batch partitionable custom calls. Since plugin support for custom call partitioners is provided via register_plugin_callback in xla_bridge, instead of xla_client itself, it's much more straightforward to register the custom calls in JAX.

It would be possible to refactor things differently, but it actually seems like a reasonable choice to use the supported APIs from `jax.ffi` instead of `xla_client` so that we can take advantage of any new features we might add there in the future.

This is all still a little bit brittle and I'd eventually like to migrate to a version where the XLA FFI library provides a mechanism for exporting handlers, but this change is still compatible with any future changes like that.

PiperOrigin-RevId: 735381736
2025-03-10 08:17:44 -07:00
Jevin Jiang
0f0636afab [Mosaic TPU][Pallas] Add pl.reciprocal
PiperOrigin-RevId: 734749577
2025-03-07 18:29:30 -08:00
Sergei Lebedev
928caf83ee [pallas:mosaic_gpu] copy_smem_to_gmem now allows skipping cp.async.commit_group
This feature is necessary to fix the SMEM->GMEM waiting behavior in
`emit_pipeline`, which used a pessimistic condition prior to this change,
since every copy was its own commit group.

PiperOrigin-RevId: 734553668
2025-03-07 07:43:54 -08:00
Jevin Jiang
ff4310f640 [Mosaic TPU] Support fp8 upcast to f32
PiperOrigin-RevId: 734345644
2025-03-06 17:19:15 -08:00
jax authors
c16f37d89d Set USERPROFILE for Windows builds to fix CI issue.
This change fixes https://github.com/jax-ml/jax/actions/runs/13686468791/job/38270929632.

From the [documentation](https://docs.python.org/3/library/os.path.html#os.path.expanduser):
`On Windows, USERPROFILE will be used if set, otherwise a combination of HOMEPATH and HOMEDRIVE will be used.`

PiperOrigin-RevId: 733935305
2025-03-05 18:09:14 -08:00
jax authors
0913cd7583 Fix build rule for free-threaded python builds.
PiperOrigin-RevId: 733857126
2025-03-05 13:54:24 -08:00
jax authors
3edc068f8c Fix ambiguous cpu definition for JAX wheels.
Should fix the error in https://github.com/jax-ml/jax/actions/runs/13682579939/job/38258344926.

PiperOrigin-RevId: 733838895
2025-03-05 12:59:21 -08:00
jax authors
a13b3cedad Merge pull request #26691 from h-vetinari:packed
PiperOrigin-RevId: 733696873
2025-03-05 05:46:01 -08:00
David Dunleavy
1a19d5594a Update all uses of @tsl//third_party to @xla//third_party
PiperOrigin-RevId: 733495240
2025-03-04 15:55:23 -08:00
jax authors
ce3412e540 Remove redundant BUILD_TAG from JAX wheels build rule.
PiperOrigin-RevId: 733334423
2025-03-04 08:13:13 -08:00
Sharad Vikram
d32e282ff9 Add fuser to jax.experimental.pallas
Note that fuser is considered experimental within Pallas and APIs are subject to change

PiperOrigin-RevId: 733117882
2025-03-03 17:26:44 -08:00
Tzu-Wei Sung
5179642eb5 [Mosaic] Rename dep name.
PiperOrigin-RevId: 732985217
2025-03-03 11:01:25 -08:00
Dimitar (Mitko) Asenov
3b305c6617 [Mosaic GPU] Infer layouts (transforms) on memrefs that directly feed into the dialect wgmma op.
This change detects a situation where a gmem_memref is read via `async_load` and directly used in a wgmma. In such cases, we insert a cast before the load to add tile, transpose, and swizzle transformations.

PiperOrigin-RevId: 732618760
2025-03-02 03:17:13 -08:00
jax authors
8f57b8167b Add build targets for jax-rocm-plugin and jax-rocm-pjrt wheels.
PiperOrigin-RevId: 732149495
2025-02-28 08:36:46 -08:00
Benjamin Chetioui
a9ab614123 [Pallas/Mosaic GPU] Add an abstraction to obtain a slice of dynamic shared memory when using waprgroup semantics.
Explicitly make the assumption that `runtime_smem` starts at `0` in the Pallas
module context---which should be enforced by Mosaic GPU.

This is in preparation of changes implementing transform inference.

PiperOrigin-RevId: 732091266
2025-02-28 04:38:25 -08:00
Kanglan Tang
55263ce485 Add linux python 3.13t nightly tests
* Python wheels follow a naming convention: standard wheels use the pattern `*-cp<python_version>-cp<python_version>-*`, while free-threaded wheels use `*-cp<python_version>-cp<python_version>t-*`. Update the pytest workflows to look for free-threaded wheels and ensure that standard wheel tests exclude free-threaded wheels.

* Skip zstandard for python3.13-nogil due to compilation failure https://github.com/indygreg/python-zstandard/issues/231.

PiperOrigin-RevId: 732070585
2025-02-28 03:13:39 -08:00
Dan Foreman-Mackey
c7ed1bd3a8 Add version check to jaxlib plugin imports.
For the CUDA and ROCM plugins, we only support exact matches between the plugin and jaxlib version, and bad things can happen if we try and load mismatched versions. This change issues a warning and skips importing a plugin when there is a version mismatch.

There are a handful of other places where plugins are imported throughout the JAX codebase (e.g. in lax_numpy, mosaic_gpu, and in the plugins themselves). In a follow up it would be good to add version checking there too, but let's start with just these ones.

PiperOrigin-RevId: 731808733
2025-02-27 11:52:17 -08:00
jax authors
401d315091 Add targets for jaxlib, jax-cuda-plugin and jax-cuda-pjrt editable wheels.
PiperOrigin-RevId: 731737119
2025-02-27 08:33:40 -08:00
Dan Foreman-Mackey
f93c2a1aa5 Add and test support for partitioning of batch dimensions in lax.linalg.
On CPU and GPU, almost all of the primitives in lax.linalg are backed by custom calls that support simple semantics when batch dimensions are sharded. Before this change, all linalg operations on CPU and GPU will insert an `all-gather` before being executed when called on sharded inputs, even when that shouldn't be necessary. This change adds support for this type of partitioning, to cover a wide range of use cases.

There are a few remaining GPU ops that don't support partitioning either because they are backed by HLO ops that don't partition properly (Cholesky factorization and triangular solves), or because they're still using descriptors with problem dimensions in kernel. I'm going to fix these in follow up changes.

PiperOrigin-RevId: 731732301
2025-02-27 08:16:16 -08:00
jax authors
4eb782e402 Update jax_wheel target to produce both wheel and source distribution files.
This change replicates the old method of building `jax` wheel via `python -m build`, which produced `.tar.gz` and `.whl` files.

PiperOrigin-RevId: 731721522
2025-02-27 07:41:13 -08:00
jax authors
615219b1f6 Remove tensorstore dependency from //jax/experimental/array_serialization:serialization in OSS (see https://github.com/google/tensorstore/issues/218)
Disable serialization_test in OSS.

PiperOrigin-RevId: 731463136
2025-02-26 14:47:16 -08:00
William S. Moses
8262987a1c Fix build dependencies
PiperOrigin-RevId: 731330542
2025-02-26 08:38:31 -08:00
Adam Paszke
99a12ef9ea [Mosaic GPU] Add support for warpgroup lowering of loops with vector carries
PiperOrigin-RevId: 731260912
2025-02-26 04:29:36 -08:00
Adam Paszke
cb7402f6de Remove MemoryEffects annotations from async_{load/store} ops
The annotation on async_load didn't indicate its write to SMEM, allowing it
to be DCEd by MLIR canonicalization. We don't get much mileage out of those
annotations, so let's just delete them for simplicity.

PiperOrigin-RevId: 731003033
2025-02-25 13:15:00 -08:00
Dan Foreman-Mackey
2ce88c950a Deprecate alpha argument to trsm LAPACK kernel.
(Part of general cleanups of the lax.linalg submodule.)

This is always set to 1 and I don't see any benefit to keeping this argument around. This can be done in a forward and backward compatible way following these docs: https://docs.jax.dev/en/latest/export/export.html#ensuring-forward-and-backward-compatibility

We start by updating the FFI handler to remove the explicit alpha argument, but allow it to accept (but ignore) extra input arguments. Then we only pass alpha when lowering in forward compatibility mode, or when the jaxlib version is old (I'm using >0.5.1 as the cutoff assuming that this change doesn't make it into the upcoming release).

Then, the forward compatibility lowering can be removed after at least 21 days, and the kernel can be updated at least 180 days after 0.5.2 is released.

PiperOrigin-RevId: 730928808
2025-02-25 10:04:29 -08:00
jax authors
eb912ad0d9 Create jax wheel build target.
This change introduces a uniform way of building the artifacts and controlling the filename version suffixes (see the changes for `jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt` in https://github.com/jax-ml/jax/pull/25126)

Previously `jax` wheel was built via `python3 -m build` command. The resulting wheel contained the python packages files in `jax` folder (e.g. the files in the subdirs that have `__init__.py` file).

You can still build the `jax` wheel with `python3 -m build` command.

Bazel `jax` wheel target: `//:jax_wheel`

Environment variables combinations for creating wheels with different versions:
  * self-built wheel (default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
  * release: `--repo_env=ML_WHEEL_TYPE=release`
  * release candidate: `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=-rc1`
  * nightly build: `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=<YYYYmmdd> --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`

PiperOrigin-RevId: 730916743
2025-02-25 09:30:08 -08:00
jax authors
083ffd3717 [Easy][Mosaic] Tiny refactor for clarity in getTypeBitwidth
PiperOrigin-RevId: 730906329
2025-02-25 08:58:19 -08:00
H. Vetinari
91cae595e4 fix member access to packed CUDA struct 2025-02-24 08:03:07 +11:00
Jan Naumann
e03fe3a06d Implement SVD algorithm based on QR for CPU targets
In a recent jax release the SvdAlgorithm parameter has been added
to the jax.lax.linalg.svd function. Currently, for CPU targets
still only the divide and conquer algorithm from LAPACK is
supported (gesdd).

This commits adds the functionality to select the QR based
algorithm on CPU as well. Mainly it addes the wrapper code
to call the gesvd function of LAPACK using the FFI interface.

Signed-off-by: Jan Naumann <j.naumann@fu-berlin.de>
2025-02-22 15:24:57 +01:00
jax authors
b510127a13 Internal compatibility change
PiperOrigin-RevId: 729428478
2025-02-21 01:21:56 -08:00
jax authors
b7968474c2 [Pallas][Mosaic] Support float8_e4m3b11fnuz
PiperOrigin-RevId: 729169181
2025-02-20 10:44:33 -08:00
jax authors
37af0135b0 [Mosaic] Consider divisibility when doing large tiling
PiperOrigin-RevId: 728980108
2025-02-19 23:56:07 -08:00
Jevin Jiang
bb68124c33 [Mosaic TPU] Support mask concat
PiperOrigin-RevId: 728349788
2025-02-18 14:03:46 -08:00
jax authors
725087e13f Integrate LLVM at llvm/llvm-project@9d24f94379
Updates LLVM usage to match
[9d24f9437944](https://github.com/llvm/llvm-project/commit/9d24f9437944)

PiperOrigin-RevId: 728265165
2025-02-18 10:30:48 -08:00