26442 Commits

Author SHA1 Message Date
Sharad Vikram
c6b164dc09 [Pallas/Fuser] Add custom evaluate to allow/disallow transposes
PiperOrigin-RevId: 735931978
2025-03-11 16:35:49 -07:00
Yash Katariya
f45cbf3342 Fix a bug where full and use_mesh outside jit did not work because the shard passed to make_array_from_callback was sharded on all devices instead of just 1 device.
This is because `convert_element_type` returning an output on all devices of the mesh because of the surrounding `use_mesh` context.

PiperOrigin-RevId: 735909962
2025-03-11 15:25:46 -07:00
Jevin Jiang
29bfd00f9c [Pallas TPU] Fix preferred_element_type propagation in dot_general with const
PiperOrigin-RevId: 735903687
2025-03-11 15:06:07 -07:00
jax authors
13eb8d3ae7 Upgrade ml-dtypes version in py3.10-py3.13 hermetic python lock files.
This change is needed to add testing of int2/uint2 dtypes via bazel in presubmit (see https://github.com/jax-ml/jax/pull/21395).

PiperOrigin-RevId: 735895293
2025-03-11 14:41:34 -07:00
Kanglan Tang
4df691ec00 Remove unsupported mac x86 CI build options
PiperOrigin-RevId: 735885305
2025-03-11 14:12:51 -07:00
jax authors
7ac088c14f Merge pull request #20699 from pearu:pearu/gammainc
PiperOrigin-RevId: 735878582
2025-03-11 13:53:20 -07:00
Dimitar (Mitko) Asenov
99c9106032 [Mosaic GPU] Replace WGMMAFragLayout with TiledLayout in the mlir dialect and use it in layout inference.
`WGMMAFragLayout` will be completely removed soon.

PiperOrigin-RevId: 735877661
2025-03-11 13:50:42 -07:00
Peter Hawkins
67aa997f84 Increase the number of iterations in a test that compares rolled versus unrolled HLO for length.
A change that avoids duplicating subcomputations in XLA causes this test to fail, but we can make it work again by increasing the number of iterations.

PiperOrigin-RevId: 735875835
2025-03-11 13:45:19 -07:00
jax authors
e0545a71eb Remove installation of NVIDIA wheels for CPU tests
PiperOrigin-RevId: 735875073
2025-03-11 13:43:13 -07:00
Jevin Jiang
eff612a3b6 Fix the assumption that pages_per_seq is already a multiple of num_kv_pages_per_blk.
PiperOrigin-RevId: 735851301
2025-03-11 12:36:33 -07:00
jax authors
0db14aa342 Add NVIDIA wheel requirements only for Linux builds.
PiperOrigin-RevId: 735850240
2025-03-11 12:33:54 -07:00
shuw
f9aef8a189 Support nvfp4 2025-03-11 19:33:25 +00:00
Pearu Peterson
82b2591b21 Fix scipy.special.gammainc/gammaincc evaluation at boundary points 2025-03-11 21:18:47 +02:00
github-actions[bot]
6ee76a8a6d
Merge pull request #271 from ROCm/ci-upstream-sync-142_1
CI: 03/11/25 upstream sync
2025-03-11 14:10:03 -05:00
Nitin Srinivasan
7ac6355262 Add TPU test jobs to the new CI continuous and nightly/release test workflows
Also, modify the TPU presubmit workflow to reuse the `build_artifacts.yml` and `pytest_tpu.yml`

PiperOrigin-RevId: 735832964
2025-03-11 11:42:21 -07:00
jax authors
c2c68c018f Merge pull request #27059 from jakevdp:fix-while-loop
PiperOrigin-RevId: 735828960
2025-03-11 11:32:00 -07:00
Gunhyun Park
d191927b24 Fix syntax error and typos for composite primitive docstring.
PiperOrigin-RevId: 735808000
2025-03-11 10:37:07 -07:00
Adam Paszke
6f7ce9d048 Skip ASAN tests for the big Mosaic GPU tests
They are timing out.

PiperOrigin-RevId: 735804647
2025-03-11 10:30:04 -07:00
JD
ce53e374fc
Deprecate obsolete gfx versions (#273) 2025-03-11 12:19:51 -05:00
Jake VanderPlas
4ae3211ea2 jax.disable_jit: ensure while_loop behaves similarly to non-disable_jit version 2025-03-11 09:53:34 -07:00
Adam Paszke
30a9e1b3bf [Mosaic GPU] Add support for .cta_group::2 MMA with n=512 on Blackwell
This one is particularly annoying, because we have to break up the MMA
into two collective N=256 MMAs. However, TensorCore only updates a contiguous
chunk of columns in TMEM and so after executing two of those we end up with
a TMEM layout that looks like this:

```
Contributing CTA |    0    |    1    |    0    |    1    |
N local          |   0:128 |   0:128 | 128:256 | 128:256 |
N                |   0:128 | 256:384 | 128:256 | 384:512 |
```

You can see that the TMEM columns no longer monotonically go over all
columns until N=512, but they include a number of jumps!

We could fix this on the load side, by ensuring that each CTA in the group
does a strided load along the tiled dimension, but that just seems more
trouble than it's worth (and is not that well supported by TMA unless we
increase the number of striding levels).

Instead, we encode this weirdness in the TMEM layout we use and make sure
to rearrange the data properly while loading the tiles into registers.

PiperOrigin-RevId: 735791426
2025-03-11 09:53:20 -07:00
Charles Hofer
fb89a4b427 Merge branch 'rocm-main' into ci-upstream-sync-142_1 2025-03-11 16:33:59 +00:00
jax authors
1aca76fc13 Update :build_jaxlib flag to control whether we should add py_import dependencies to the test targets.
This change enables testing the wheels produced by the build rules in the presubmit using one `bazel test` command only.

There are three options for running the tests:

1) `build_jaxlib=true`: the tests depend on JAX targets.
2) `build_jaxlib=false`: the tests depend on the wheel files located in the `dist` folder.
3) `build_jaxlib=wheel`: the tests depend on the py_import targets.

PiperOrigin-RevId: 735765819
2025-03-11 08:31:43 -07:00
Benjamin Chetioui
7fd32ecc04 [Pallas/Mosaic GPU] Explicitly disable ops_test on Mosaic GPU pre-Hopper.
PiperOrigin-RevId: 735744473
2025-03-11 07:11:09 -07:00
jax authors
b6da46ecda Update XLA dependency to use revision
fae64d49aa.

PiperOrigin-RevId: 735709684
2025-03-11 04:41:55 -07:00
Shraiysh
cb2eb15739 PR #22800: Change the default value of print_operand_shape_ to false and print_large_constants_ to true.
Imported from GitHub PR https://github.com/openxla/xla/pull/22800

Operand shape in long hlo text adds redundant information, which shouldn't be required. Changing the default value to off.

The large constants were also printed earlier by default print options, and it is required for parsability and reproducibility. Turning this on by default. This is still controlled by debug option and the default value of that flag disables the large constants, and that behavior is not changed. Just the default print options change here.

Copybara import of the project:

--
e30dea20489b3fb4d03d373fec0391d69486f4aa by Shraiysh Vaishay <svaishay@nvidia.com>:

Change the default value of print_operand_shape_ to false and print_large_constants_ to true.

Operand shape in long hlo text adds redundant information, which
shouldn't be required. Changing the default value to off.

The large constants were also printed earlier by default print options,
and it is required for parsability and reproducibility. Turning this on by default.
This is still controlled by debug option and the default value of that
flag disables the large constants, and that behavior is not changed. Just the
default print options change here.

--
7008af0dd0ce342ecbe9475f1d0e277319f1705a by Shraiysh Vaishay <svaishay@nvidia.com>:

Handle tests

--
b22d5f95cfb7e15f930a2198279a76c38593cc53 by Shraiysh Vaishay <svaishay@nvidia.com>:

Fix more tests

--
d51579cae7359c6426a87ad4a7ff1b4b0c80f74a by Shraiysh Vaishay <svaishay@nvidia.com>:

Fix more tests

Merging this change closes #22800

PiperOrigin-RevId: 735690598
2025-03-11 03:17:04 -07:00
Yash Katariya
76dec38286 Under pjit the with mesh: context will use use_mesh(mesh): jit instead of tracking separately using resource_env.
This would also make it easier to deprecate the `with mesh: pjit` path in the future from user code since the new path would be completely tested.
This will also allow us to remove `resource_env` from JAX and the internal API access of `resource_env.physical_mesh` spread throughout codebases internally and externally.

PiperOrigin-RevId: 735602187
2025-03-10 20:21:02 -07:00
jax authors
02505fa757 [Pallas TPU] Remove next_slot SMEM tensor from pipeline emitter
PiperOrigin-RevId: 735564365
2025-03-10 17:19:39 -07:00
Ayaka
988a1208a9 Better error message when raise_if_error() is called within a traced context
PiperOrigin-RevId: 735557928
2025-03-10 16:55:06 -07:00
jax authors
aceae84fab [Pallas] Enable skipping of floating-point operations when interpreting Pallas TPU kernels on CPU.
PiperOrigin-RevId: 735527650
2025-03-10 15:14:00 -07:00
Jacob Burnim
802cb33bf8 [Pallas] Increase tolerance in PallasOutOfBoundsInterpretTest.
PiperOrigin-RevId: 735519526
2025-03-10 14:49:34 -07:00
jax authors
d55879723e Merge pull request #26840 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 735513976
2025-03-10 14:33:14 -07:00
jax authors
b8590816bf Merge pull request #26839 from Sai-Suraj-27:fix_jax.debug.print
PiperOrigin-RevId: 735511953
2025-03-10 14:26:45 -07:00
Sharad Vikram
81dde225b0 [Pallas/Fuser] Add select_n push rule
PiperOrigin-RevId: 735510713
2025-03-10 14:23:01 -07:00
jax authors
261e6e5fdc Merge pull request #27038 from jakevdp:vmap-sentinel
PiperOrigin-RevId: 735510065
2025-03-10 14:21:11 -07:00
jax authors
c942b0fef0 Merge pull request #26977 from jakevdp:fix-expn
PiperOrigin-RevId: 735506133
2025-03-10 14:09:32 -07:00
Sharad Vikram
87272fbe93 [Pallas/Fuser] Add debug option to fuser.fuse that prints out jaxpr
PiperOrigin-RevId: 735505460
2025-03-10 14:07:26 -07:00
carlosgmartin
8b6ca56417 Fix the ValueError message for random.binomial (forgot to use string formatting). 2025-03-10 16:38:03 -04:00
jax authors
affe2e734e Rename dot_with_no_batch_dims_saveable to dots_with_no_batch_dims_saveable for internal consistency
PiperOrigin-RevId: 735484326
2025-03-10 13:04:49 -07:00
Praveen Narayanan
b6d4fe5387 Define lax.ragged_dot_general and express lax.ragged_dot in terms of it.
PiperOrigin-RevId: 735471245
2025-03-10 12:25:22 -07:00
jax authors
64beebbfb0 Merge pull request #27035 from vfdev-5:add-file-zip-to-tsan-ci-jobs
PiperOrigin-RevId: 735467348
2025-03-10 12:12:49 -07:00
jax authors
18f2f19c1a Merge pull request #26525 from wenscarl:e2m1fn
PiperOrigin-RevId: 735457804
2025-03-10 11:46:18 -07:00
Jacob Burnim
73d20cd62a [Pallas] Small fix to TPU interpret mode (input_output_aliases + scalar args).
PiperOrigin-RevId: 735455671
2025-03-10 11:40:10 -07:00
Jake VanderPlas
8ecadfdf9d Internal: make it easier to detect the vmap sentinel 2025-03-10 11:37:50 -07:00
jax authors
007fc7a6f1 Remove version limit for setuptools dependency.
PiperOrigin-RevId: 735453796
2025-03-10 11:36:17 -07:00
Nitin Srinivasan
d41e96835b Modify version test to consider "rc" versions as well
I was testing the RC promotion workflow and found that the version test failed as it does not consider pre-releases. Therefore, this commit modifies the `VERSION_PATTERN` to also consider "rc" wheels.

Fixes https://github.com/jax-ml/jax/actions/runs/13705984545/job/38331236497

PiperOrigin-RevId: 735444828
2025-03-10 11:10:18 -07:00
Michael Whittaker
5cb29949d4 Warn the user if transparent huge pages aren't enabled.
PiperOrigin-RevId: 735431881
2025-03-10 10:37:58 -07:00
jax authors
14b215fe76 Merge pull request #27032 from dfm:lax-dtype
PiperOrigin-RevId: 735424674
2025-03-10 10:18:58 -07:00
vfdev
1bab037ca0 Add file and zip to tsan.yaml 2025-03-10 16:51:05 +00:00
jax authors
ab0ce8a448 Merge pull request #26811 from dfm:direct-lin
PiperOrigin-RevId: 735388827
2025-03-10 08:39:49 -07:00