9761 Commits

Author SHA1 Message Date
jax authors
1a8d537728 Merge pull request #26384 from gnecula:debug_info_jaxpr_4
PiperOrigin-RevId: 725210049
2025-02-10 07:42:57 -08:00
Adam Paszke
26d8e112e3 Add a missing jaxlib version check in ragged_collective_test
PiperOrigin-RevId: 725186144
2025-02-10 06:12:36 -08:00
Bart Chrzaszcz
6ed4c29c8a #sdy enable test_mem_kind_donation_pinned_host for Shardy.
PiperOrigin-RevId: 725142884
2025-02-10 03:15:25 -08:00
jax authors
6740165e4f [Pallas] Add pipeline mode to pltpu
PiperOrigin-RevId: 725133131
2025-02-10 02:36:44 -08:00
George Necula
817b3e5757 [better_errors] Continue adding debug info to Jaxprs (step 7)
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.

Fixes in jet, stateful code, key_reuse, ode, pallas, tests.
2025-02-09 18:14:33 +02:00
George Necula
1e813e1693 [better_errors] Continue adding debug info to Jaxprs (step 4)
This follows after #26078, #26313, and #26348, adding `debug_info` to more calls to `lu.wrap_init`.

As part of this I have changed the primitive `custom_transpose` to take the `transpose` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this was a `Callable`.

These changes ensure that all the `lu.wrap_init` calls and `Jaxpr` constructions carry debug_info in `api_test.py:CustomTransposeTest`.
2025-02-08 09:13:55 +02:00
jax authors
289035747e Merge pull request #26407 from jakevdp:printoptions-doc
PiperOrigin-RevId: 724487999
2025-02-07 15:22:52 -08:00
Gleb Pobudzey
cd0753751c Increase the absolute error tolerance to fix flaky tests.
PiperOrigin-RevId: 724424293
2025-02-07 11:59:13 -08:00
jax authors
c0ba36260e Merge pull request #26377 from mattjj:maintain-mutable-array-sharding
PiperOrigin-RevId: 724405629
2025-02-07 11:04:56 -08:00
Sergei Lebedev
e5058079c9 [pallas:mosaic_gpu] Fixed a bug in how delay_release is handled in emit_pipeline
PiperOrigin-RevId: 724395676
2025-02-07 10:37:21 -08:00
Matthew Johnson
719031c1fd [mutable-arrays] persist shardings through xla computations 2025-02-07 18:33:24 +00:00
jax authors
3b470b9530 Merge pull request #26383 from jakevdp:jnp-sorting
PiperOrigin-RevId: 724381260
2025-02-07 10:00:29 -08:00
Jake VanderPlas
08563842b9 DOC: make clear that printoptions are NumPy aliases 2025-02-07 09:56:52 -08:00
Dan Foreman-Mackey
c521bc6205 [xla:python] Add a mechanism for "batch partitioning" of FFI calls.
This is the first in a series of changes to add a simple API for supporting a set of common sharding and partitioning patterns for FFI calls. The high level motivation is that custom calls (including FFI calls) are opaque to the SPMD partitioner, and the only ways to customize the partitioning behavior are to (a) explicitly register an `xla::CustomCallPartitioner` with XLA, or (b) use the `jax.experimental.custom_partitioning` APIs. Option (a) isn't generally practical for most use cases where the FFI handler lives in an external binary. Option (b) is flexible, and supports all common use cases, but it requires embedding Python callbacks into the HLO, which can lead to issues including cache misses. Furthermore, `custom_partitioning` is overpowered for many use cases, which need only (what I will call) "batch partitioning".

In this case, "batch partitioning" refers to the behavior of many FFI calls where they can be trivially partitioned on some number of (leading) dimensions, with the same call being executed independently on each shard of data. If the data are sharded on non-batch dimensions, partitioning will still re-shard the data to be replicated on the non-batch dimensions. This kind of partitioning logic applies to all the LAPACK/cuSOLVER/etc.-backed linear algebra functions in jaxlib, as well as some external users of `custom_partitioning`.

The approach I'm taking here is to add a new registration function to the XLA client, which lets a user label their FFI call as batch partitionable. Then, when lowering the custom call, the user passes the number of batch dimensions as a frontend attribute, which is then interpreted by the SPMD partitioner (a sketch of these semantics follows this entry).

In parallel with this change, shardy has added support for sharding propagation across custom calls using a string representation that is similar in spirit to this approach, but somewhat more general. However, the shardy implementation still requires a Python callback for the partitioning step, so it doesn't (yet!) solve all of the relevant problems with the `custom_partitioning` approach. Ultimately, it should be possible to have the partitioner parse the shardy sharding rule representation, but I wanted to start with the minimal implementation.

PiperOrigin-RevId: 724367877
2025-02-07 09:14:06 -08:00
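
A minimal sketch of the batch-partitioning semantics described above, emulated
with jax.experimental.shard_map rather than the new XLA registration mechanism
(the mesh and function names here are illustrative, not part of the new API):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# One mesh axis over all devices; the leading array dimension is the
# "batch" dimension on which the call partitions trivially.
mesh = Mesh(jax.devices(), ("batch",))

def per_shard_solve(x):
    # Stand-in for an opaque FFI call (e.g. a LAPACK-backed solve) that
    # runs independently on each shard of the batch dimension.
    return jnp.linalg.inv(x)

# Batch dimension sharded; non-batch dimensions replicated, matching the
# re-sharding behavior described in the commit message.
batched_solve = shard_map(per_shard_solve, mesh=mesh,
                          in_specs=P("batch", None, None),
                          out_specs=P("batch", None, None))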
Jake VanderPlas
d3b3cd369f refactor: move sorting ops out of lax_numpy 2025-02-07 08:18:04 -08:00
jax authors
e56c7dc502 Merge pull request #26344 from Cjkkkk:disable_head_256_on_bw
PiperOrigin-RevId: 724333455
2025-02-07 07:08:18 -08:00
Sergei Lebedev
35351f95e4 [pallas:triton] Really revert to the lowering using Triton IR
PiperOrigin-RevId: 724329911
2025-02-07 06:55:14 -08:00
Adam Paszke
3b5e91b8a8 [Mosaic GPU] Add tests for various tcgen05.mma configurations
It would be good to add smaller tests that verify reads and writes to TMEM,
since we depend on it here, but that will come later.

PiperOrigin-RevId: 724328602
2025-02-07 06:50:11 -08:00
Gunhyun Park
2828bce2e6 Add check to lax.composite to prevent DynamicJaxprTracer type errors.
There's some confusion about whether jax arrays can be used inside the attributes, so I made the error more explicit (a sketch follows this entry).

PiperOrigin-RevId: 724316766
2025-02-07 06:02:24 -08:00
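
A hedged sketch of the constraint this check enforces, assuming the
lax.composite decorator API in which static keyword arguments become the
composite's attributes (the composite name below is illustrative):

import functools
import jax
import jax.numpy as jnp
from jax import lax

# `scale` is an attribute: it must be a static Python value. Passing a
# traced value (a DynamicJaxprTracer) now raises an explicit error.
@functools.partial(lax.composite, name="my.scaled_tanh")
def scaled_tanh(x, *, scale):
    return jnp.tanh(x) * scale

y = jax.jit(lambda x: scaled_tanh(x, scale=2.0))(jnp.ones(3))      # ok
# jax.jit(lambda x, s: scaled_tanh(x, scale=s))(jnp.ones(3), 2.0)  # error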
George Necula
000b92f539 [better_errors] Continue adding debug info to Jaxprs (step 5)
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.

These changes ensure that all the lu.wrap_init calls and Jaxpr constructions carry debug_info in
api_test.py:CustomTransposeTest, api_test.py:CustomVmapTest and api_test.py:RematTest.
2025-02-07 08:23:10 +02:00
George Necula
c1734f2253 Fix file path regexp for Windows
PiperOrigin-RevId: 724194012
2025-02-06 21:24:30 -08:00
Dan Foreman-Mackey
c6e83903de Update RNN kernels to use FFI.
PiperOrigin-RevId: 724151647
2025-02-06 18:27:58 -08:00
jax authors
4b86ff22e9 Merge pull request #25097 from jburnim:jburnim_pallas_interpret_mode
PiperOrigin-RevId: 724073443
2025-02-06 14:22:49 -08:00
Jacob Burnim
1c82484c9b Start a new TPU interpret mode for Pallas.
The goal of this interpret mode is to run a Pallas TPU kernel on CPU,
while simulating a TPU's shared memory, multiple devices/cores, remote
DMAs, and synchronization.

The basic approach is to execute the kernel's Jaxpr on CPU, but to
replace all load/store, DMA, and synchronization primitives with
io_callbacks to Python functions that simulate these primitives (a sketch follows this entry).
When this interpret mode is run inside of shard_map and jit, the
shards will run in parallel, simulating the parallel execution of the
kernel on multiple TPU devices.

The initial version in this PR can successfully interpret the examples
in https://jax.readthedocs.io/en/latest/pallas/tpu/distributed.html ,
but is still missing a lot of functionality, including:

 - Executing DMAs asynchronously.

 - Padding in pallas_call.

 - Propagating source info.
2025-02-06 13:04:14 -08:00
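
A minimal sketch of the io_callback mechanism described above; the simulated
memory and the store helper are hypothetical stand-ins, not the actual
implementation:

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

# Hypothetical simulated TPU shared memory, keyed by buffer id.
_shared_memory = {}

def _sim_store(buffer_id, value):
    # Runs as plain Python on CPU, standing in for a store primitive.
    _shared_memory[int(buffer_id)] = np.asarray(value)
    return np.int32(0)

def store(buffer_id, value):
    # The interpreter replaces a store in the kernel's Jaxpr with an
    # ordered io_callback into a simulator function like _sim_store.
    return io_callback(_sim_store, jax.ShapeDtypeStruct((), jnp.int32),
                       buffer_id, value, ordered=True)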
jax authors
1b7b04f7db Merge pull request #26369 from justinjfu:sourcemap_win
PiperOrigin-RevId: 724032624
2025-02-06 12:32:10 -08:00
Justin Fu
49fa1274b5 Temporarily disable source mapper on win32 2025-02-06 12:22:14 -08:00
Ayaka
afad924de7 [Pallas TPU] Remove obsolete skip condition
PiperOrigin-RevId: 723963888
2025-02-06 09:23:21 -08:00
George Necula
a678396f44 Increase test shard_count for shape_poly_test on GPU
PiperOrigin-RevId: 723940915
2025-02-06 08:13:55 -08:00
jax authors
5d647ccfa1 Merge pull request #26348 from gnecula:debug_info_jaxpr_3
PiperOrigin-RevId: 723920031
2025-02-06 06:59:18 -08:00
Michael Hudgins
2e808f2836 Merge pull request #26279 from MichaelHudgins:tsan-resultstore
PiperOrigin-RevId: 723918760
2025-02-06 14:55:57 +00:00
George Necula
904b74860c [better_errors] Continue adding debug info to Jaxprs (step 3)
This follows after #26078 and #26313, adding `debug_info` to
more calls to `lu.wrap_init`.

As part of this I have changed the primitives `custom_vjp_call_jaxpr`
and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`,
which carries debug info. Previously, this was a `Callable`, but in
almost all cases it was really `lu.WrappedFun.call_wrapped`.
2025-02-06 16:26:49 +02:00
Olli Lupton
d8f811e790 Remove test for CUPTI multi-subscriber error message that needed cupti-python and a subprocess. 2025-02-06 08:30:34 +00:00
Olli Lupton
1bba1ea2e2 Add JAX_COMPILATION_CACHE_EXPECT_PGLE option
This allows using external profiling tools, such as Nsight Systems,
with the automatic PGLE workflow supported by JAX, via a simple
two-step recipe:

export JAX_COMPILATION_CACHE_DIR=...
JAX_ENABLE_PGLE=yes python model.py
JAX_COMPILATION_CACHE_EXPECT_PGLE=yes nsys profile python model.py
2025-02-06 08:19:45 +00:00
shuw
061d4acbfb Scaled matmul for mxfp8 2025-02-05 23:25:51 +00:00
cjkkkk
553199e4dc disable head dim 256 on bw now 2025-02-05 22:07:48 +00:00
Hyeontaek Lim
f43d2b68d9 [JAX] Add a test verifying the behavior of module-level state accessed by colocated Python
A new test verifies that
* Python module-level variables can be created/set and read from a colocated Python function
* Python module-level variables are not pickled on the controller (JAX) or sent to executors via pickling

An API for defining user-defined state and accessing it from multiple colocated
Python functions (i.e., object support) will be added later. That will be the
recommended way to express user-defined state. The capability of accessing
Python module variables is still crucial because a lot of Python code
(including JAX) requires this behavior to implement caching.

PiperOrigin-RevId: 723595727
2025-02-05 11:49:07 -08:00
jax authors
1eda5e2e6e Merge pull request #26259 from Qazalbash:scipy-expon
PiperOrigin-RevId: 723576962
2025-02-05 11:02:13 -08:00
jax authors
c46b0215b0 Merge pull request #26313 from gnecula:debug_info_vjp
PiperOrigin-RevId: 723575296
2025-02-05 10:58:10 -08:00
Qazalbash
7fc605f783 Merge branch 'main' into scipy-expon 2025-02-05 23:33:51 +05:00
Parker Schuh
da0827b7f1 Compute buffer aliasing on a per buffer basis.
PiperOrigin-RevId: 723561674
2025-02-05 10:25:04 -08:00
jax authors
d424f5b5b3 Refactor JAX wheel build rules to control the wheel filename and maintain reproducible wheel content and filename results.
This change is a part of the initiative to test the JAX wheels in the presubmit properly.

The list of the changes:
1. JAX wheel build rule verifies that `--@local_config_cuda//cuda:include_cuda_libs=false` is set during the wheel build. There is a way to bypass the restriction by providing `--@local_config_cuda//cuda:override_include_cuda_libs=true`.

2. The JAX version number (which is also used in the wheel filenames) is stored in the `_version` variable in the file [version.py](https://github.com/jax-ml/jax/blob/main/jax/version.py). The custom repository rule `jax_python_wheel_version_repository` saves this value in `wheel_version.bzl`, so it becomes available in the Bazel build phase.

3. The version suffix of the wheel in the build rule output depends on the environment variables.

   The version suffix chunks that are not reproducible shouldn't be computed as part of the wheel build: for example, the current date changes every day, so wheels built today and tomorrow from the same code version would technically differ. To keep the wheel content reproducible, we pass the suffix chunks in the form of environment variables.

4. Environment variables combinations for creating wheels with different versions:
  * `0.5.1.dev0+selfbuilt` (local build, default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
  * `0.5.1` (release): `--repo_env=ML_WHEEL_TYPE=release`
  * `0.5.1rc1` (release candidate): `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=rc1`
  * `0.5.1.dev20250128+3e75e20c7` (nightly build): `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=20250128 --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`

PiperOrigin-RevId: 723552265
2025-02-05 10:01:23 -08:00
Peter Buchlovsky
9f53dfae0b [pallas_mgpu] Fix emit_pipeline_with_wgmma test.
PiperOrigin-RevId: 723547617
2025-02-05 09:47:50 -08:00
Emily Fertig
4ae7fcf376 Return arrays from ArrayImpl._check_and_rearrange.
This is in preparation for a larger change, so that input buffers can be checked before Array creation in XLA and the user gets more helpful JAX error messages instead of XLA errors.

Reverts 3b2410f77cdb0acc6951e1770c1229e6689b7409

PiperOrigin-RevId: 723539592
2025-02-05 09:24:22 -08:00
George Necula
abcaec7081 [better_errors] Add debug info to the Jaxprs formed for AD
Following #26078 , we add debug info to more calls of lu.wrap_init.
2025-02-05 19:21:02 +02:00
Jake VanderPlas
e4dac395a5 Roll back multinomial change from https://github.com/jax-ml/jax/pull/25688
This has test breakages on TPU: https://github.com/jax-ml/jax/actions/runs/13159081976/job/36723019653

Reverts 95535df13b422284043623ca3a6d2a5962116fb1

PiperOrigin-RevId: 723536107
2025-02-05 09:13:56 -08:00
Adam Paszke
1fbc4a15dd [Mosaic GPU] Infer whether A/B are row- or column-major from strides
There's no need to require extra arguments. This makes our calling convention
saner since the logical dimension order stays the same (e.g. for B it's always
k before n in the shape); only the in-memory representation changes.

Other than the API change, this is an NFC. (A small illustration of the
stride-based inference follows this entry.)

PiperOrigin-RevId: 723449720
2025-02-05 04:01:04 -08:00
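
A small illustration of the stride-based inference described above, in plain
Python rather than the Mosaic GPU implementation:

def infer_2d_layout(shape, strides):
    # Row-major: the last dimension is contiguous (stride 1).
    # Column-major: the first dimension is contiguous instead.
    rows, cols = shape
    if strides == (cols, 1):
        return "row-major"
    if strides == (1, rows):
        return "column-major"
    raise ValueError(f"unsupported strides {strides} for shape {shape}")

# For B with shape (k, n) the logical order stays k before n; only the
# strides reveal the in-memory representation.
assert infer_2d_layout((4, 8), (8, 1)) == "row-major"
assert infer_2d_layout((4, 8), (1, 4)) == "column-major"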
Yash Katariya
c07b6b529a Skip broken tests at HEAD
PiperOrigin-RevId: 723321880
2025-02-04 19:42:45 -08:00
Matthew Johnson
1ae02bc069 skip tests with extra requirements 2025-02-05 01:48:28 +00:00
jax authors
414449e142 Merge pull request #26078 from gnecula:debug_info_jaxpr
PiperOrigin-RevId: 723151082
2025-02-04 10:54:26 -08:00
Roy Frostig
58a202ad18 fix binomial sample test on TPU
The recent partitionable Threefry upgrade affects binomial sampling under the RBG PRNG scheme because the implementation of `jax.random.binomial` derives internal subkeys with a call to `split`. This led a randomized test to fail by pushing its numeric closeness check just beyond its current relative tolerance. This is very likely a false failure, so we update the rtol.

PiperOrigin-RevId: 723100174
2025-02-04 08:40:38 -08:00