21889 Commits

Author SHA1 Message Date
Ruturaj4
ec7d625c36 [ROCM] add pytest reportlog rocm-jax-stable-2024_07_16 2024-07-16 10:17:36 -05:00
Ruturaj4
7b9d47c2fc [JAX] Fix run_single_gpu.py to include correct paths 2024-07-16 10:15:54 -05:00
Rahul Batra
aa316a4b16 [ROCm]: Add support to continue on fail 2024-07-16 10:15:47 -05:00
Kaixi Hou
09531d2ff8 PR #22330: [NVIDIA] Remove logic of combining bias and mask
Imported from GitHub PR https://github.com/google/jax/pull/22330

The cudnn API has already supported the combination of bias and mask from [this PR](https://github.com/google/jax/pull/22078). We are removing the logic from the public sdpa API and pass the mask directly.

cc. @Cjkkkk
Copybara import of the project:

--
0f75f58a9d81c0ae0a83701a71998c940318732a by kaixih <kaixih@nvidia.com>:

Remove logic of combining bias and mask

Merging this change closes #22330

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/22330 from kaixih:remove_combine_bias_mask 0f75f58a9d81c0ae0a83701a71998c940318732a
PiperOrigin-RevId: 652830016
2024-07-16 07:19:01 -07:00
Sergei Lebedev
7a62b8dd18 Re-enabled PallasCallPrintTest on Cloud TPUs
PiperOrigin-RevId: 652823653
2024-07-16 06:55:20 -07:00
George Necula
e7be205a39 [jax2tf] Fix jax2tf tolerance for Cholesky, needed for newer TPUs
PiperOrigin-RevId: 652751771
2024-07-16 02:13:11 -07:00
George Necula
d34a6e9ce2 [jax2tf] Deprecate jax2tf with native_serialization=False or enable_xla=False.
Also disable many of the non-native-serialization jax2tf tests.
In particular, I am disabling the thousands of primitives tests in
graph serialization mode.
I kept jax2tf_test running in both native and graph serialization mode.

PiperOrigin-RevId: 652749891
2024-07-16 02:05:43 -07:00
Christos Perivolaropoulos
28ffa25496 [MosaicGPU] Move parity computations to a separate function to allow the user to use wait_parity without duplicate code.
PiperOrigin-RevId: 652665738
2024-07-15 19:08:10 -07:00
Justin Fu
0690988626 [Pallas] Add limited boolean memref support for scalars.
PiperOrigin-RevId: 652653003
2024-07-15 17:59:05 -07:00
jax authors
21e9dabdad Merge pull request #22465 from ROCm:ci_multiprocess_gpu_test
PiperOrigin-RevId: 652602616
2024-07-15 14:38:58 -07:00
Ruturaj4
fb5c516405 [ROCM] test_computation_follows_data fix for rocm and cuda 2024-07-15 16:28:22 -05:00
jax authors
3e09040cf7 Merge pull request #22463 from hawkinsp:tpuci2
PiperOrigin-RevId: 652597728
2024-07-15 14:25:15 -07:00
jax authors
fc3654792b Merge pull request #22450 from kaixih:update_sdpa_doc
PiperOrigin-RevId: 652597691
2024-07-15 14:21:27 -07:00
Yash Katariya
bb7a6995f9 Remove the spmd_mode check. It's disabled in OSS since a long time.
PiperOrigin-RevId: 652591122
2024-07-15 13:58:23 -07:00
Peter Hawkins
f488c4cc31 Disable some tests that fail on Cloud TPU. 2024-07-15 16:00:58 -04:00
Peter Hawkins
2f45cd725a Bump some SVD test tolerances.
These just barely fail on recent TPUs.

PiperOrigin-RevId: 652571985
2024-07-15 12:54:00 -07:00
jax authors
5d3392927a Update XLA dependency to use revision
83e38528e0.

PiperOrigin-RevId: 652563753
2024-07-15 12:27:15 -07:00
Kevin Gleason
5e897c61f5 Integrate StableHLO at openxla/stablehlo@8817ff1d
PiperOrigin-RevId: 652528759
2024-07-15 10:38:09 -07:00
kaixih
0d387e0839 Update jax doc sdpa 2024-07-15 17:30:54 +00:00
jax authors
26ec43f9e5 Merge pull request #22445 from dfm:numpy-nightly-unique
PiperOrigin-RevId: 652520243
2024-07-15 10:16:03 -07:00
jax authors
cab1f85f09 Merge pull request #22448 from hawkinsp:tpuci
PiperOrigin-RevId: 652520107
2024-07-15 10:12:39 -07:00
jax authors
7255ab146b Merge pull request #22440 from gnecula:pallas_test_clean
PiperOrigin-RevId: 652513116
2024-07-15 09:55:38 -07:00
jax authors
2b29a94255 Merge pull request #22375 from jakevdp:mypy-docs
PiperOrigin-RevId: 652511749
2024-07-15 09:52:07 -07:00
jax authors
5216719996 Merge pull request #22405 from gnecula:poly_pad
PiperOrigin-RevId: 652511693
2024-07-15 09:48:21 -07:00
Yash Katariya
b6264e99b1 Skip layout tests because they require xla_extension_version >= 274
PiperOrigin-RevId: 652510288
2024-07-15 09:43:03 -07:00
Peter Hawkins
3019467992 Update references to jaxlib_nightly_releases (which is a legacy index and no longer updated) to jax_nightly_releases.
Fixes CI jobs using old jaxlib nightly wheels.
2024-07-15 12:42:50 -04:00
Peter Hawkins
a1f69713f5 Disable Pallas vmap test that is very slow under tsan.
PiperOrigin-RevId: 652505878
2024-07-15 09:28:35 -07:00
jax authors
9c72e67711 Merge pull request #22169 from gspschmid:gschmid/async_serialize-overlap-shard-copies
PiperOrigin-RevId: 652502425
2024-07-15 09:15:21 -07:00
jax authors
86f4bb4346 Added more Mosaic bug reproducers.
PiperOrigin-RevId: 652498944
2024-07-15 09:02:48 -07:00
Dan Foreman-Mackey
7857bd3319 Fix compatibility of jnp.unique with numpy nightly
In https://github.com/numpy/numpy/pull/26914, the behavior of the
`return_inverse` argument to `np.unique` was partially reverted to the
pre-v2.0 behavior. The PR brings JAX's implementation compatible with
the `numpy>2.0.0` behavior.
2024-07-15 11:41:00 -04:00
jax authors
2e7e700090 Merge pull request #22402 from superbobry:maint
PiperOrigin-RevId: 652487969
2024-07-15 08:20:56 -07:00
George Necula
791743c296 [pallas] Some minor test housekeeping.
Add missing interpreter tests in tpu_pallas_test.
2024-07-15 17:05:15 +02:00
George Necula
7817b6785b [shape_poly] Expand the support for shape polymorphism for jnp.pad
Handle several new padding modes: wrap, reflect, symmetric, linear_ramp, maximum.
Not all situations are handled; try to give a clear error for the unsupported
cases.

While implementing this, I needed to add shape polymorphism support
also for jnp.linspace.

And I discovered a bug in the implementation of `divmod(0, b)`.
2024-07-15 17:04:54 +02:00
jax authors
b966c74fae Merge pull request #22443 from gnecula:docs_checkify
PiperOrigin-RevId: 652482390
2024-07-15 08:01:31 -07:00
George Necula
be8e83adc1 [docs] Fix docs building error
The checkify APIs were mentioned in the jax.experimental.rst and also
in jax.experimental.checkify.rst.
2024-07-15 15:42:33 +01:00
Georg Stefan Schmid
0428871c82 Adapt test case 2024-07-15 09:37:59 +00:00
Georg Stefan Schmid
cdd17cabba Also overlap staged transfers (ba to hostseline case when jax memories or pinned_host is unavailable) 2024-07-15 09:37:53 +00:00
Georg Stefan Schmid
6b0136b829 async_serialize: Overlap shard copies 2024-07-15 09:37:47 +00:00
Georg Stefan Schmid
b8b9d2878c [memories] Transfer to pinned_host fast path in async_serialize 2024-07-15 09:35:43 +00:00
jax authors
a8b425cac5 Update XLA dependency to use revision
c50e9f58f0.

PiperOrigin-RevId: 652100151
2024-07-13 13:14:57 -07:00
jax authors
764ec92118 Add support for elementwise op canonicalization in fp32 for older hardware.
PiperOrigin-RevId: 651959463
2024-07-12 19:58:55 -07:00
Yash Katariya
0dfb206088 Reference make_array_from_process_local_data in make_array_from_single_device_arrays docstring.
PiperOrigin-RevId: 651937263
2024-07-12 18:10:15 -07:00
Jevin Jiang
aa16485457 [XLA:Mosaic] Support memref shapecast.
This cl supports memref shapecast:
1. if tile is (1, 128), we support shapecast on any dim.
2. if shapecast on sublane dim, we only support tile aligned shape.
3. if shapecast on non-tiling dim, we support any shapecast.
4. all other cases would be considered as invalid memref shapecast.

PiperOrigin-RevId: 651924552
2024-07-12 17:05:03 -07:00
Sharad Vikram
f3c1cbc709 Add custom rules for str_eqn_compact
PiperOrigin-RevId: 651911281
2024-07-12 16:03:27 -07:00
jax authors
ea8de20d45 Update XLA dependency to use revision
a8425caae5.

PiperOrigin-RevId: 651881048
2024-07-12 14:05:52 -07:00
Sharad Vikram
7016ca4829 [Mosaic] Strengthen check on return types from RegionOp
PiperOrigin-RevId: 651879359
2024-07-12 13:59:50 -07:00
Yash Katariya
60fccc2aac Disable test_source_file_prefix_removal test because there is cross-contamination of metadata information from different call sites because of cached jitted functions
PiperOrigin-RevId: 651847347
2024-07-12 12:07:24 -07:00
jax authors
17591f5c19 Merge pull request #22359 from selamw1:isandallclose_doc
PiperOrigin-RevId: 651826321
2024-07-12 11:01:35 -07:00
jax authors
1699d04f66 Merge pull request #22419 from ROCm:ci_rocm_version_fix
PiperOrigin-RevId: 651818904
2024-07-12 10:38:51 -07:00
jax authors
e59fdc5089 More Pallas bug reproducers.
PiperOrigin-RevId: 651799590
2024-07-12 09:33:58 -07:00