21898 Commits

Author SHA1 Message Date
Sharad Vikram
39ec5dacb4 [Pallas TPU] Add matrix multiplication tutorial 2024-07-16 18:12:19 -07:00
Sharad Vikram
7069a5a2e1 [Pallas/Mosaic] Refactor how we lower Pallas to a custom call
This avoids a round-trip through lower_fun which creates unnecessary HLO ops and name stack entries (which can add noise to trace viewers).

PiperOrigin-RevId: 653018101
2024-07-16 16:21:00 -07:00
jax authors
778477b62b Update XLA dependency to use revision
d8e2b28463.

PiperOrigin-RevId: 652932751
2024-07-16 12:11:43 -07:00
jax authors
4907c38742 Merge pull request #22386 from rdyro:rdyro/explain_persistent_compilation
PiperOrigin-RevId: 652928829
2024-07-16 12:00:37 -07:00
jax authors
84827bf247 Merge pull request #22418 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 652921968
2024-07-16 11:40:06 -07:00
jax authors
5ddec63a47 Merge pull request #22441 from gnecula:test_clean_hypothesis
PiperOrigin-RevId: 652919414
2024-07-16 11:32:46 -07:00
Chris Jones
15c542228a [jax] Make the SupportsDType protocol runtime checkable.
This allows `DTypeLike` to be used as a type annotation for type-checked functions without triggering a warning.

PiperOrigin-RevId: 652905699
2024-07-16 10:57:15 -07:00
jax authors
661ecdd83e Merge pull request #22475 from dfm:fix-lint
PiperOrigin-RevId: 652865171
2024-07-16 09:01:31 -07:00
Dan Foreman-Mackey
556cc23fa5 Fix lint at head.
It looks like https://github.com/google/jax/pull/22330 introduced some
mypy lint. This PR fixes it.
2024-07-16 10:53:49 -04:00
Kaixi Hou
09531d2ff8 PR #22330: [NVIDIA] Remove logic of combining bias and mask
Imported from GitHub PR https://github.com/google/jax/pull/22330

The cudnn API has already supported the combination of bias and mask from [this PR](https://github.com/google/jax/pull/22078). We are removing the logic from the public sdpa API and pass the mask directly.

cc. @Cjkkkk
Copybara import of the project:

--
0f75f58a9d81c0ae0a83701a71998c940318732a by kaixih <kaixih@nvidia.com>:

Remove logic of combining bias and mask

Merging this change closes #22330

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/22330 from kaixih:remove_combine_bias_mask 0f75f58a9d81c0ae0a83701a71998c940318732a
PiperOrigin-RevId: 652830016
2024-07-16 07:19:01 -07:00
Sergei Lebedev
7a62b8dd18 Re-enabled PallasCallPrintTest on Cloud TPUs
PiperOrigin-RevId: 652823653
2024-07-16 06:55:20 -07:00
rajasekharporeddy
a00c94c5a2 Improved docs for jnp.median and nanmedian 2024-07-16 17:27:46 +05:30
George Necula
e7be205a39 [jax2tf] Fix jax2tf tolerance for Cholesky, needed for newer TPUs
PiperOrigin-RevId: 652751771
2024-07-16 02:13:11 -07:00
George Necula
d34a6e9ce2 [jax2tf] Deprecate jax2tf with native_serialization=False or enable_xla=False.
Also disable many of the non-native-serialization jax2tf tests.
In particular, I am disabling the thousands of primitives tests in
graph serialization mode.
I kept jax2tf_test running in both native and graph serialization mode.

PiperOrigin-RevId: 652749891
2024-07-16 02:05:43 -07:00
Christos Perivolaropoulos
28ffa25496 [MosaicGPU] Move parity computations to a separate function to allow the user to use wait_parity without duplicate code.
PiperOrigin-RevId: 652665738
2024-07-15 19:08:10 -07:00
Justin Fu
0690988626 [Pallas] Add limited boolean memref support for scalars.
PiperOrigin-RevId: 652653003
2024-07-15 17:59:05 -07:00
rdyro
c6d6207170 Unifying persistent cache messages
and moving them to WARNING logging when explain_cache_misses is true.
2024-07-16 00:47:53 +00:00
jax authors
21e9dabdad Merge pull request #22465 from ROCm:ci_multiprocess_gpu_test
PiperOrigin-RevId: 652602616
2024-07-15 14:38:58 -07:00
Ruturaj4
fb5c516405 [ROCM] test_computation_follows_data fix for rocm and cuda 2024-07-15 16:28:22 -05:00
jax authors
3e09040cf7 Merge pull request #22463 from hawkinsp:tpuci2
PiperOrigin-RevId: 652597728
2024-07-15 14:25:15 -07:00
jax authors
fc3654792b Merge pull request #22450 from kaixih:update_sdpa_doc
PiperOrigin-RevId: 652597691
2024-07-15 14:21:27 -07:00
Yash Katariya
bb7a6995f9 Remove the spmd_mode check. It's disabled in OSS since a long time.
PiperOrigin-RevId: 652591122
2024-07-15 13:58:23 -07:00
Peter Hawkins
f488c4cc31 Disable some tests that fail on Cloud TPU. 2024-07-15 16:00:58 -04:00
Peter Hawkins
2f45cd725a Bump some SVD test tolerances.
These just barely fail on recent TPUs.

PiperOrigin-RevId: 652571985
2024-07-15 12:54:00 -07:00
jax authors
5d3392927a Update XLA dependency to use revision
83e38528e0.

PiperOrigin-RevId: 652563753
2024-07-15 12:27:15 -07:00
Kevin Gleason
5e897c61f5 Integrate StableHLO at openxla/stablehlo@8817ff1d
PiperOrigin-RevId: 652528759
2024-07-15 10:38:09 -07:00
kaixih
0d387e0839 Update jax doc sdpa 2024-07-15 17:30:54 +00:00
jax authors
26ec43f9e5 Merge pull request #22445 from dfm:numpy-nightly-unique
PiperOrigin-RevId: 652520243
2024-07-15 10:16:03 -07:00
jax authors
cab1f85f09 Merge pull request #22448 from hawkinsp:tpuci
PiperOrigin-RevId: 652520107
2024-07-15 10:12:39 -07:00
jax authors
7255ab146b Merge pull request #22440 from gnecula:pallas_test_clean
PiperOrigin-RevId: 652513116
2024-07-15 09:55:38 -07:00
jax authors
2b29a94255 Merge pull request #22375 from jakevdp:mypy-docs
PiperOrigin-RevId: 652511749
2024-07-15 09:52:07 -07:00
jax authors
5216719996 Merge pull request #22405 from gnecula:poly_pad
PiperOrigin-RevId: 652511693
2024-07-15 09:48:21 -07:00
Yash Katariya
b6264e99b1 Skip layout tests because they require xla_extension_version >= 274
PiperOrigin-RevId: 652510288
2024-07-15 09:43:03 -07:00
Peter Hawkins
3019467992 Update references to jaxlib_nightly_releases (which is a legacy index and no longer updated) to jax_nightly_releases.
Fixes CI jobs using old jaxlib nightly wheels.
2024-07-15 12:42:50 -04:00
Peter Hawkins
a1f69713f5 Disable Pallas vmap test that is very slow under tsan.
PiperOrigin-RevId: 652505878
2024-07-15 09:28:35 -07:00
jax authors
9c72e67711 Merge pull request #22169 from gspschmid:gschmid/async_serialize-overlap-shard-copies
PiperOrigin-RevId: 652502425
2024-07-15 09:15:21 -07:00
jax authors
86f4bb4346 Added more Mosaic bug reproducers.
PiperOrigin-RevId: 652498944
2024-07-15 09:02:48 -07:00
Dan Foreman-Mackey
7857bd3319 Fix compatibility of jnp.unique with numpy nightly
In https://github.com/numpy/numpy/pull/26914, the behavior of the
`return_inverse` argument to `np.unique` was partially reverted to the
pre-v2.0 behavior. The PR brings JAX's implementation compatible with
the `numpy>2.0.0` behavior.
2024-07-15 11:41:00 -04:00
jax authors
2e7e700090 Merge pull request #22402 from superbobry:maint
PiperOrigin-RevId: 652487969
2024-07-15 08:20:56 -07:00
George Necula
d3454f374e Add some hypothesis testing utilities and developer documentation.
Add a helper function for setting up hypothesis testing,
with support for selecting an interactive hypothesis profile
that speeds up interactive development.
2024-07-15 17:05:32 +02:00
George Necula
791743c296 [pallas] Some minor test housekeeping.
Add missing interpreter tests in tpu_pallas_test.
2024-07-15 17:05:15 +02:00
George Necula
7817b6785b [shape_poly] Expand the support for shape polymorphism for jnp.pad
Handle several new padding modes: wrap, reflect, symmetric, linear_ramp, maximum.
Not all situations are handled; try to give a clear error for the unsupported
cases.

While implementing this, I needed to add shape polymorphism support
also for jnp.linspace.

And I discovered a bug in the implementation of `divmod(0, b)`.
2024-07-15 17:04:54 +02:00
jax authors
b966c74fae Merge pull request #22443 from gnecula:docs_checkify
PiperOrigin-RevId: 652482390
2024-07-15 08:01:31 -07:00
George Necula
be8e83adc1 [docs] Fix docs building error
The checkify APIs were mentioned in the jax.experimental.rst and also
in jax.experimental.checkify.rst.
2024-07-15 15:42:33 +01:00
Georg Stefan Schmid
0428871c82 Adapt test case 2024-07-15 09:37:59 +00:00
Georg Stefan Schmid
cdd17cabba Also overlap staged transfers (ba to hostseline case when jax memories or pinned_host is unavailable) 2024-07-15 09:37:53 +00:00
Georg Stefan Schmid
6b0136b829 async_serialize: Overlap shard copies 2024-07-15 09:37:47 +00:00
Georg Stefan Schmid
b8b9d2878c [memories] Transfer to pinned_host fast path in async_serialize 2024-07-15 09:35:43 +00:00
jax authors
a8b425cac5 Update XLA dependency to use revision
c50e9f58f0.

PiperOrigin-RevId: 652100151
2024-07-13 13:14:57 -07:00
jax authors
764ec92118 Add support for elementwise op canonicalization in fp32 for older hardware.
PiperOrigin-RevId: 651959463
2024-07-12 19:58:55 -07:00