25514 Commits

Author SHA1 Message Date
George Necula
a678396f44 Increase test shard_count for shape_poly_test on GPU
PiperOrigin-RevId: 723940915
2025-02-06 08:13:55 -08:00
Adam Paszke
026b6c9704 [Mosaic GPU] Take TMEM as a TMEMRef in tcgen05.mma, not as a raw address
PiperOrigin-RevId: 723936021
2025-02-06 07:59:58 -08:00
Sergei Lebedev
efbb0afd7a [pallas:triton] Temporarily reverted to the lowering using Triton IR
The new lowering caused a performance regression internally.

PiperOrigin-RevId: 723934141
2025-02-06 07:53:04 -08:00
jax authors
5d647ccfa1 Merge pull request #26348 from gnecula:debug_info_jaxpr_3
PiperOrigin-RevId: 723920031
2025-02-06 06:59:18 -08:00
Michael Hudgins
2e808f2836 Merge pull request #26279 from MichaelHudgins:tsan-resultstore
PiperOrigin-RevId: 723918760
2025-02-06 14:55:57 +00:00
George Necula
904b74860c [better_errors] Continue adding debug info to Jaxprs (step 3)
This follows after #26078, and #26313, adding `debug_info` to
more calls to `lu.wrap_init`.

As part of this I have changed the primitives `custom_vjp_call_jaxpr`
and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`,
which carries debug info. Previously, this was a `Callable`, but in
almost all cases if was really ` lu.WrappedFun.call_wrapped`.
2025-02-06 16:26:49 +02:00
Michael Hudgins
be0f3d8f1b Update tsan job to upload to resultstore 2025-02-05 20:14:05 +00:00
Yash Katariya
0fb278a0b9 If sharding is not None (that's passed to convert_element_type), only compare it with operand's sharding if the sharding is concrete. Otherwise doing getattr(operand, 'sharding') on a Tracer leads to weird timeouts.
PiperOrigin-RevId: 723595960
2025-02-05 11:50:45 -08:00
Hyeontaek Lim
f43d2b68d9 [JAX] Add a test verifying the behavior of module-level state accessed by colocated Python
A new test verifies that
* Python module-level variables can be created/set and read from a colocated Python function
* Python module-level variables are not pickled on the controller (JAX) or sent to executors via pickling

An API for defining user-defined state and accessing it from multiple colocated
Python functions (i.e., object support) will be added later. That will be a
recommended way to express user-defined state. The capability of accessing
Python module variables is still crucial because a lot of Python code
(including JAX) requires this behavior to implement caching.

PiperOrigin-RevId: 723595727
2025-02-05 11:49:07 -08:00
jax authors
10363663e8 Merge pull request #26339 from jakevdp:mean-doc
PiperOrigin-RevId: 723593813
2025-02-05 11:43:55 -08:00
jax authors
1eda5e2e6e Merge pull request #26259 from Qazalbash:scipy-expon
PiperOrigin-RevId: 723576962
2025-02-05 11:02:13 -08:00
jax authors
c46b0215b0 Merge pull request #26313 from gnecula:debug_info_vjp
PiperOrigin-RevId: 723575296
2025-02-05 10:58:10 -08:00
Jake VanderPlas
9b402ecdb7 doc: add note about f16 casting in jnp.mean 2025-02-05 10:46:07 -08:00
Qazalbash
7fc605f783
Merge branch 'main' into scipy-expon 2025-02-05 23:33:51 +05:00
Parker Schuh
da0827b7f1 Compute buffer aliasing on a per buffer basis.
PiperOrigin-RevId: 723561674
2025-02-05 10:25:04 -08:00
jax authors
d424f5b5b3 Refactor JAX wheel build rules to control the wheel filename and maintain reproducible wheel content and filename results.
This change is a part of the initiative to test the JAX wheels in the presubmit properly.

The list of the changes:
1. JAX wheel build rule verifies that `--@local_config_cuda//cuda:include_cuda_libs=false` during the wheel build. There is a way to pass the restriction by providing `--@local_config_cuda//cuda:override_include_cuda_libs=true`.

2. The JAX version number (which is also used in the wheel filenames) is stored in `_version` variable in the file [version.py](https://github.com/jax-ml/jax/blob/main/jax/version.py). The custom repository rule `jax_python_wheel_version_repository` saves this value in `wheel_version.bzl`, so it becomes available in Bazel build phase.

3. The version suffix of the wheel in the build rule output depends on the environment variables.

   The version suffix chunks that are not reproducible shouldn’t be calculated as a part of the wheel binary: for example, the current date changes every day, thus the wheels built today and tomorrow on the same code version will be technically different. To maintain reproducible wheel content, we need to pass suffix chunks in a form of environment variables.

4. Environment variables combinations for creating wheels with different versions:
  * `0.5.1.dev0+selfbuilt` (local build, default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
  * `0.5.1` (release): `--repo_env=ML_WHEEL_TYPE=release`
  * `0.5.1rc1` (release candidate): `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=rc1`
  * `0.5.1.dev20250128+3e75e20c7` (nightly build): `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=20250128 --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`

PiperOrigin-RevId: 723552265
2025-02-05 10:01:23 -08:00
Peter Buchlovsky
9f53dfae0b [pallas_mgpu] Fix emit_pipeline_with_wgmma test.
PiperOrigin-RevId: 723547617
2025-02-05 09:47:50 -08:00
Adam Paszke
6d6c8c2e6c [Mosaic GPU] Write a higher-level tcgen05.mma helper reusing WGMMA implementation
Hopper and Blackwell MMA instructions can share a lot of the same logic, which is
why I ended up splitting out a large fraction of WGMMA implementation into a common
utility. This should be an NFC for WGMMA, but it allows us to concisely implement
unrolling of MMAs of different sizes into a number of tcgen05.mma instructions.

PiperOrigin-RevId: 723544349
2025-02-05 09:38:29 -08:00
Emily Fertig
4ae7fcf376 Return arrays from ArrayImpl._check_and_rearrange.
This is in preparation for a larger change, so that input buffers can be checked before Array creation in XLA and the user gets more helpful JAX error messages instead of XLA errors.

Reverts 3b2410f77cdb0acc6951e1770c1229e6689b7409

PiperOrigin-RevId: 723539592
2025-02-05 09:24:22 -08:00
George Necula
abcaec7081 [better_errors] Add debug info to the Jaxprs formed for AD
Following #26078 , we add debug info to more calls of lu.wrap_init.
2025-02-05 19:21:02 +02:00
Jake VanderPlas
e4dac395a5 Roll back multinomial change from https://github.com/jax-ml/jax/pull/25688
This has test breakages on TPU: https://github.com/jax-ml/jax/actions/runs/13159081976/job/36723019653

Reverts 95535df13b422284043623ca3a6d2a5962116fb1

PiperOrigin-RevId: 723536107
2025-02-05 09:13:56 -08:00
Nitin Srinivasan
bc055569df Replace Python 3.12 with Python 3.13 and add Python 3.10 to the matrix
Expands test coverage to cover the oldest and newest Python versions that we support.

PiperOrigin-RevId: 723520699
2025-02-05 08:27:22 -08:00
Nitin Srinivasan
b5b913acb6 Run Bazel CPU/CUDA RBE jobs on pushes to main/release branches
This helps identify breaking commits easily.

PiperOrigin-RevId: 723519953
2025-02-05 08:24:37 -08:00
Bixia Zheng
ae74d3e527 #sdy Fix the format for the external link to jax-shardy-guide colab.
PiperOrigin-RevId: 723511467
2025-02-05 07:59:53 -08:00
Adam Paszke
f4dab0cf72 [Mosaic GPU] Add helpers for dealing with TMEM references + implement optimized loads
The previous example implementation loaded TMEM in a layout that was very hard to
efficiently store into SMEM or GMEM. With the new TMEMRef abstraction, we can implement
loads that yield a FragmentedArray with a new tiled layout that allows for efficient
swizzled stores to SMEM.

The new layout is very similar to the one we've been using for WGMMA on Hopper, only the
initial row tiling is increased to 128 (making each warp hold 32 rows, not 16 as previously).

PiperOrigin-RevId: 723506876
2025-02-05 07:41:40 -08:00
Christos Perivolaropoulos
eeace3ceba [pallas:mgpu] Cast all indices to i32 during lowering.
PiperOrigin-RevId: 723505268
2025-02-05 07:37:04 -08:00
jax authors
3a1c63c50f Update XLA dependency to use revision
46f8cf0390.

PiperOrigin-RevId: 723499604
2025-02-05 07:17:26 -08:00
Adam Paszke
b79ab01ee7 [Mosaic GPU] Refactor the Blackwell matmul example and make it runnable
The previous impelmentation depends on LLVM intrinsics that have not been submitted
yet. This replaces them with inline PTX (as far as I can tell there's no downside to
that) that's encapsulated into convenience functions.

PiperOrigin-RevId: 723498248
2025-02-05 07:11:03 -08:00
Adam Paszke
e7a4f89343 [Mosaic TPU] Add optimized casts for bf16->s4 in TPUv6
PiperOrigin-RevId: 723455843
2025-02-05 04:21:55 -08:00
Adam Paszke
1fbc4a15dd [Mosaic GPU] Infer whether A/B are row- or column-major from strides
There's no need to require extra arguments. This makes our calling convention
saner since the logical dimension order stays the same (e.g. for B it's always
k before n in the shape), only the in-memory representation changes.

Other than the API change, this is a NFC.

PiperOrigin-RevId: 723449720
2025-02-05 04:01:04 -08:00
jax authors
d6be2351d4 Merge pull request #26159 from andportnoy:aportnoy/mosaic-gpu-blackwell-simple-matmul
PiperOrigin-RevId: 723428689
2025-02-05 02:33:57 -08:00
Yash Katariya
c07b6b529a Skip broken tests at HEAD
PiperOrigin-RevId: 723321880
2025-02-04 19:42:45 -08:00
jax authors
9aa9813030 Merge pull request #26324 from mattjj:skip-tests-with-extra-requirements
PiperOrigin-RevId: 723304967
2025-02-04 18:31:23 -08:00
jax authors
781172c24b Merge pull request #26325 from hawkinsp:tpu2
PiperOrigin-RevId: 723298028
2025-02-04 18:07:51 -08:00
Peter Hawkins
b1a2c27aa0 Remove libtpu-nightly dependency from jax[tpu].
For several releases, libtpu-nightly has been a transitional empty package that does nothing. We remove the dependency in preparation for depending on libtpu from pypi instead of a GCS bucket in jax[tpu].
2025-02-04 20:59:30 -05:00
Matthew Johnson
1ae02bc069 skip tests with extra requirements 2025-02-05 01:48:28 +00:00
Yash Katariya
307006e194 Set the mesh as manual during partial_eval_custom in shard_map so that _add_reshapes happens under the correct mesh.
PiperOrigin-RevId: 723268798
2025-02-04 16:36:08 -08:00
Sharad Vikram
02f4531310 [Pallas TPU] Add helpers for writing collectives
PiperOrigin-RevId: 723250661
2025-02-04 15:39:10 -08:00
jax authors
fdcc04c3c4 Merge pull request #26281 from jakevdp:lax-docs
PiperOrigin-RevId: 723219794
2025-02-04 14:05:31 -08:00
jax authors
023f06f99c Merge pull request #26278 from jax-ml:dependabot/github_actions/actions/setup-python-5.4.0
PiperOrigin-RevId: 723208433
2025-02-04 13:33:01 -08:00
dependabot[bot]
333c4e7a0e
Bump actions/setup-python from 5.3.0 to 5.4.0
Bumps [actions/setup-python](https://github.com/actions/setup-python) from 5.3.0 to 5.4.0.
- [Release notes](https://github.com/actions/setup-python/releases)
- [Commits](0b93645e9f...42375524e2)

---
updated-dependencies:
- dependency-name: actions/setup-python
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-02-04 21:07:18 +00:00
jax authors
694572bcaf Merge pull request #26310 from dfm:use-uv-in-ci
PiperOrigin-RevId: 723199573
2025-02-04 13:06:09 -08:00
Sharad Vikram
782215b099 Add get/set methods to TransformedRef
PiperOrigin-RevId: 723188696
2025-02-04 12:34:34 -08:00
jax authors
414449e142 Merge pull request #26078 from gnecula:debug_info_jaxpr
PiperOrigin-RevId: 723151082
2025-02-04 10:54:26 -08:00
Qazalbash
8561f90f8c
fix: simplify logcdf implementation by removing unnecessary argument promotion 2025-02-04 23:43:15 +05:00
jax authors
b1b88a3613 Merge pull request #26307 from dfm:kill-ffi-cache
PiperOrigin-RevId: 723131658
2025-02-04 10:07:04 -08:00
Jake VanderPlas
f180353d78 jax.lax: improve docs for exp & log functions 2025-02-04 09:33:52 -08:00
jax authors
09ee37a41d Merge pull request #26302 from vfdev-5:readd-missed-sed-part-in-tsan
PiperOrigin-RevId: 723115097
2025-02-04 09:23:32 -08:00
jax authors
2b94444226 Merge pull request #26300 from vfdev-5:patch-2
PiperOrigin-RevId: 723113101
2025-02-04 09:17:14 -08:00
Dan Foreman-Mackey
5db5e0d5ca Use uv for dependency resolution on CI. 2025-02-04 11:48:05 -05:00