26079 Commits

Author SHA1 Message Date
jax authors
8f57b8167b Add build targets for jax-rocm-plugin and jax-rocm-pjrt wheels.
PiperOrigin-RevId: 732149495
2025-02-28 08:36:46 -08:00
Dan Foreman-Mackey
bb9aed5eec Reimplement custom_vjp.optimize_remat using custom_dce. 2025-02-28 10:00:28 -05:00
Adam Paszke
bb96226dd8 [Mosaic GPU] Add support for small RHS tile sizes in WGMMA
This is useful for more fine-grained autotuning and can help avoid
wave quantization effects.

PiperOrigin-RevId: 732105219
2025-02-28 05:41:30 -08:00
Benjamin Chetioui
1bc36e623b [Mosaic GPU][NFC] Delete workaround for dialect bindings before jaxlib 0.5.1.
PiperOrigin-RevId: 732102282
2025-02-28 05:25:53 -08:00
Benjamin Chetioui
7c46480eab [Mosaic GPU] Fix as_dialect_barrier_memref to take into account BarrierRef's offset.
PiperOrigin-RevId: 732098299
2025-02-28 05:06:57 -08:00
Benjamin Chetioui
a9ab614123 [Pallas/Mosaic GPU] Add an abstraction to obtain a slice of dynamic shared memory when using waprgroup semantics.
Explicitly make the assumption that `runtime_smem` starts at `0` in the Pallas
module context---which should be enforced by Mosaic GPU.

This is in preparation of changes implementing transform inference.

PiperOrigin-RevId: 732091266
2025-02-28 04:38:25 -08:00
Kanglan Tang
55263ce485 Add linux python 3.13t nightly tests
* Python wheels follow a naming convention: standard wheels use the pattern `*-cp<python_version>-cp<python_version>-*`, while free-threaded wheels use `*-cp<python_version>-cp<python_version>t-*`. Update the pytest workflows to look for free-threaded wheels and ensure that standard wheel tests exclude free-threaded wheels.

* Skip zstandard for python3.13-nogil due to compilation failure https://github.com/indygreg/python-zstandard/issues/231.

PiperOrigin-RevId: 732070585
2025-02-28 03:13:39 -08:00
Benjamin Chetioui
abfe2d080e [Mosaic GPU][NFC] Move some functions to a new file called inference_utils.py.
The intent is to move utils that are useful for both layout inference and
transform inference to a shared location.

PiperOrigin-RevId: 732067659
2025-02-28 03:02:59 -08:00
jax authors
5a770701ae Update XLA dependency to use revision
9fa90d72eb.

PiperOrigin-RevId: 732060951
2025-02-28 02:36:43 -08:00
Adam Paszke
832f5a3aff [Mosaic GPU] Remove TMA inputs
This test configuration dates back to the time when we were very unsure
about how to use TMA. At this point we have plenty of experience and it
makes more sense to focus the test in question on verifying WGMMA. This
also simplifies adding support for smaller RHS tiling.

PiperOrigin-RevId: 732040900
2025-02-28 01:19:28 -08:00
Adam Paszke
092ea35301 [Mosaic GPU][NFC] Start refactoring the MMA parameter inference
The CUDA 12.8 release significantly improved the MMA docs, letting us
improve upon the previously used "magic number" scheme. Sadly, the docs
are still incorrect, but at least I can begin to make some sense of those
parameters.

PiperOrigin-RevId: 732033585
2025-02-28 00:50:20 -08:00
jax authors
d8953e5311 Remove spurious zero_to_zero conversion used exclusively for backcompat types as a way of supporting (best effort) unsupported types on hardware to make it easier to debug.
In general, this is a good feature, but it assumed that the packing type utilized here was exclusively for backcompat, and so always applied the adjustment.

PiperOrigin-RevId: 731954456
2025-02-27 19:21:37 -08:00
Skye Wanderman-Milne
1b87ee07c4 Update setup.py to automatically pick up libtpu patch releases 2025-02-27 17:22:29 -08:00
Dan Foreman-Mackey
810445b10d Temporarily skip linalg sharding tests on GPU.
PiperOrigin-RevId: 731877477
2025-02-27 14:56:16 -08:00
Kanglan Tang
d839e441b7 Reduce pytest workers for asan to resolve memory usage causing OOM
This fixes the current OOM error: https://github.com/jax-ml/jax/actions/runs/13565999206.

PiperOrigin-RevId: 731876781
2025-02-27 14:54:29 -08:00
Yash Katariya
dda62f576f Make sure default layout is None for input and output layout in all codepaths
PiperOrigin-RevId: 731865511
2025-02-27 14:26:25 -08:00
jax authors
c7ca35fe32 Merge pull request #26345 from wenscarl:scaled_matmul
PiperOrigin-RevId: 731865430
2025-02-27 14:24:48 -08:00
jax authors
6a7736754f Reverts 0f0d5e90ef1c3d60f35020141710ea350d17816b
PiperOrigin-RevId: 731844119
2025-02-27 13:27:32 -08:00
Sharad Vikram
6f57410e12 [Pallas TPU] Use grid_env for pipeline body so we can query num_programs/program_id inside the block spec
PiperOrigin-RevId: 731831543
2025-02-27 12:53:02 -08:00
Yash Katariya
07f192cd48 Merge _check_mesh_resource_axis and _check_axis_type_consistency into 1 function.
PiperOrigin-RevId: 731830347
2025-02-27 12:51:25 -08:00
Kanglan Tang
8bb3100019 Add -Wl,-undefined,dynamic_lookup as linkopt on macOS
This is needed to fix mac arm64 nightly failures after Bazel 7.4.1 upgrade: https://github.com/bazelbuild/bazel/pull/16414.

PiperOrigin-RevId: 731830060
2025-02-27 12:49:39 -08:00
Yash Katariya
c265568530 Remove parsed_pspec from NamedSharding constructor
PiperOrigin-RevId: 731820173
2025-02-27 12:24:17 -08:00
Peter Hawkins
1e5d9a9158 Add an allow_negative_indices option to lax.dynamic_slice and lax.dynamic_update_slice.
The goal of this change is to avoid generating code to wrap negative indices back into range in cases where we know it doesn't matter. Change scan to pass allow_negative_indices=False to avoid emitting index wrapping code for each scan argument.

PiperOrigin-RevId: 731812827
2025-02-27 12:04:28 -08:00
jax authors
3450e2cee0 Disable certain tests on V4 and below.
PiperOrigin-RevId: 731812726
2025-02-27 12:02:52 -08:00
Dan Foreman-Mackey
c7ed1bd3a8 Add version check to jaxlib plugin imports.
For the CUDA and ROCM plugins, we only support exact matches between the plugin and jaxlib version, and bad things can happen if we try and load mismatched versions. This change issues a warning and skips importing a plugin when there is a version mismatch.

There are a handful of other places where plugins are imported throughout the JAX codebase (e.g. in lax_numpy, mosaic_gpu, and in the plugins themselves). In a follow up it would be good to add version checking there too, but let's start with just these ones.

PiperOrigin-RevId: 731808733
2025-02-27 11:52:17 -08:00
Yash Katariya
c94ec0eb0d Use batched_device_put for token shard_arg handler
PiperOrigin-RevId: 731800613
2025-02-27 11:30:22 -08:00
jax authors
da39b6f3d4 Comment change
PiperOrigin-RevId: 731792151
2025-02-27 11:07:59 -08:00
Yash Katariya
d69da3b012 More cleanups around ParsedPartitionSpec. In a follow up CL, I can remove it from NamedSharding constructor. Deleting ParsedPartitionSpec is remaining but that's after 0.5.2 release.
PiperOrigin-RevId: 731785005
2025-02-27 10:51:04 -08:00
Yash Katariya
034a827a4d Remove _parsed_pspec from everywhere in JAX except for NamedSharding constructor. I'll do that in the next CL since that has a dependency on C++ so needs guards.
PiperOrigin-RevId: 731772222
2025-02-27 10:17:06 -08:00
Yash Katariya
177e1f6ed9 Canonicalize PartitionSpec so that we can delete ParsedPartitionSpec. We need to do this after sharding-in-types to speed up NamedSharding construction and remove a lot of tech debt and unnecessary complexity.
* `_partitions` is now canonicalized and only contains `tuples`, `singular strings`, `None` or `UNCONSTRAINED`. No more empty tuples (`P((), 'x')`) and singleton tuples.

* Cache the creating of sharding on ShapedArray since it's expensive to do it a lot of times

* Change the `__hash__` and `__eq__` of `NamedSharding` to depend on `self.spec` instead of `self._parsed_pspec`.

PiperOrigin-RevId: 731745062
2025-02-27 08:59:25 -08:00
jax authors
401d315091 Add targets for jaxlib, jax-cuda-plugin and jax-cuda-pjrt editable wheels.
PiperOrigin-RevId: 731737119
2025-02-27 08:33:40 -08:00
Dan Foreman-Mackey
f93c2a1aa5 Add and test support for partitioning of batch dimensions in lax.linalg.
On CPU and GPU, almost all of the primitives in lax.linalg are backed by custom calls that support simple semantics when batch dimensions are sharded. Before this change, all linalg operations on CPU and GPU will insert an `all-gather` before being executed when called on sharded inputs, even when that shouldn't be necessary. This change adds support for this type of partitioning, to cover a wide range of use cases.

There are a few remaining GPU ops that don't support partitioning either because they are backed by HLO ops that don't partition properly (Cholesky factorization and triangular solves), or because they're still using descriptors with problem dimensions in kernel. I'm going to fix these in follow up changes.

PiperOrigin-RevId: 731732301
2025-02-27 08:16:16 -08:00
Adrian Kuegel
de4d047852 Change int4 packing from big-endian to little-endian
LLVM uses little-endian format for int4 packing. To avoid converting between
these formats, we should also use little-endian in XLA.

PiperOrigin-RevId: 731731530
2025-02-27 08:13:43 -08:00
Nitin Srinivasan
5ae0e58a4a Update the calculation for num_processes and num_test_jobs that are used in CUDA test jobs
We need to set them as `min(num_cpu_cores, num_gpus * max_tests_per_gpu, total ram in GB/6)` where max_tests_per_gpu = (GPU memory / 2GB)

PiperOrigin-RevId: 731730857
2025-02-27 08:11:51 -08:00
Bart Chrzaszcz
4997e45743 #sdy close any partially sharded dimensions if using auto axes in a shard_map.
PiperOrigin-RevId: 731724837
2025-02-27 07:53:18 -08:00
jax authors
4eb782e402 Update jax_wheel target to produce both wheel and source distribution files.
This change replicates the old method of building `jax` wheel via `python -m build`, which produced `.tar.gz` and `.whl` files.

PiperOrigin-RevId: 731721522
2025-02-27 07:41:13 -08:00
jax authors
a8738a069e Merge pull request #26804 from hawkinsp:tsan
PiperOrigin-RevId: 731721390
2025-02-27 07:39:35 -08:00
jax authors
07f5d7a475 Reverts f3fade3b70443b6cf87f01f360e6a1cb85d4b1fb
PiperOrigin-RevId: 731658204
2025-02-27 03:26:37 -08:00
Peter Hawkins
6e73637888 Fix a test failure under multi-threading.
Remove a tsan suppression for a CPython race that is fixed.
2025-02-27 06:07:05 -05:00
jax authors
0fbc453d94 Update XLA dependency to use revision
fb6241ad51.

PiperOrigin-RevId: 731643649
2025-02-27 02:27:17 -08:00
Henning Becker
b3f7c93cb2 Fix cudnn version skipping in fused_attention_stablehlo_test
The CUDNN_VERSION is defined as (CUDNN_MAJOR * 10000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL).

Therefore cuDNN 9.1.0 is represented as 90100 - not as 91000.

PiperOrigin-RevId: 731641814
2025-02-27 02:19:43 -08:00
Chris Jones
d6752e9267 [pallas:triton] Generate more efficient code for loading contiguous slices of int4 values.
The existing `int4` loading code is very generic. When reading contiguous data, it will read with offsets like `0, 0, 1, 1, ...`. Triton doesn't consider these to be contiguous in memory and emits much less efficient code than when reading contiguous blocks.

PiperOrigin-RevId: 731635736
2025-02-27 01:57:47 -08:00
Tom Hennigan
1becb57ac9 Add jax.copy_to_host_async(tree).
A relatively common pattern I've observed is the following:

```python
_, metrics = some_jax_function()

with profiler.Trace('compute_metrics'):
  jax.block_until_ready(metrics)

with profiler.Trace('copy_to_host'):
  metrics = jax.device_get(metrics)
```

We are missing an opportunity here to more eagerly begin the h2d copy of
the metrics (e.g. overlap it with closing the "compute_metrics" context
manager etc. The intention of `jax.copy_to_host_async(x)` is to make it
simple to begin h2d transfers as early as possible. Adapting the above code:

```python
_, metrics = some_jax_function()

# Begin D2H copies as early as we can.
jax.copy_to_host_async(metrics)

with profiler.Trace('compute_metrics'):
  jax.block_until_ready(metrics)

with profiler.Trace('copy_to_host'):
  metrics = jax.device_get(metrics)
```

PiperOrigin-RevId: 731626446
2025-02-27 01:22:15 -08:00
Sharad Vikram
2646b8d4ad [Pallas TPU] Add support for GridDimensionSemantics to pallas_call
PiperOrigin-RevId: 731543938
2025-02-26 19:34:36 -08:00
Sharad Vikram
b5fcffadd4 Add swap as method to TransformedRef
PiperOrigin-RevId: 731541165
2025-02-26 19:19:10 -08:00
Sharad Vikram
1ecbac9702 [Pallas] Add name parameter to core_map
PiperOrigin-RevId: 731536152
2025-02-26 18:59:01 -08:00
Sharad Vikram
0f0d5e90ef Add support for TPU v5 2x2 tray configuration
PiperOrigin-RevId: 731529917
2025-02-26 18:33:49 -08:00
Emily Fertig
82124da5cd Redefine is_fully_addressable in shardings to support zero local devices for McJAX.
PiperOrigin-RevId: 731526750
2025-02-26 18:17:35 -08:00
Emily Fertig
7f9e7473cf Rolling back a commit that caused a 50-90% performance regression in most MaxText workloads.
Reverts 9d421c9149a1db006444adeea87464bd6b8c0743

PiperOrigin-RevId: 731506280
2025-02-26 16:57:18 -08:00
jax authors
615219b1f6 Remove tensorstore dependency from //jax/experimental/array_serialization:serialization in OSS (see https://github.com/google/tensorstore/issues/218)
Disable serialization_test in OSS.

PiperOrigin-RevId: 731463136
2025-02-26 14:47:16 -08:00