13924 Commits

Author SHA1 Message Date
Peter Hawkins
7495a9e370 [JAX] Enable/disable tests that timed out in CI.
Reenable pmap_test since it was recently sped up.

PiperOrigin-RevId: 491650701
2022-11-29 09:02:16 -08:00
jax authors
21bab5efab Merge pull request #13431 from jakevdp:fix-sparse-matmul-warning
PiperOrigin-RevId: 491649932
2022-11-29 08:55:23 -08:00
Yash Katariya
3e5a5053f4 Run GPU presubmits via bazel test on the RBE cluster. This speeds up the build + testing significantly (upto 10x).
But run the continuous builds by building on RBE and testing locally so as to run the multiaccelerator tests too. Locally we have 4 GPUs available.

Also make GPU presubmits blocking for JAX (re-enabled it).

PiperOrigin-RevId: 491647775
2022-11-29 08:45:58 -08:00
Jake VanderPlas
885441b2c8 [sparse] ignore GPU warning in tests 2022-11-29 08:40:12 -08:00
Qiao Zhang
c54bc90bf4 Fix cudnn_header OSS BUILD dep.
PiperOrigin-RevId: 491465703
2022-11-28 15:58:55 -08:00
Qiao Zhang
4d1c4bc761 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 491445515
2022-11-28 14:31:48 -08:00
jax authors
702a1084b8 Merge pull request #13422 from jakevdp:ci-timeout
PiperOrigin-RevId: 491435108
2022-11-28 13:53:15 -08:00
Jake VanderPlas
1647c5960e CI: bump timeout for pre-commit 2022-11-28 13:26:44 -08:00
jax authors
441d53b766 Merge pull request #13253 from NeilGirdhar:keypaths
PiperOrigin-RevId: 491424268
2022-11-28 13:10:20 -08:00
jax authors
87408c769a Merge pull request #13421 from sharadmv:fix-rtd
PiperOrigin-RevId: 491404001
2022-11-28 11:53:17 -08:00
jax authors
2745cc2cfa Merge pull request #13420 from jakevdp:sparse-eye-jit
PiperOrigin-RevId: 491399990
2022-11-28 11:37:25 -08:00
jax authors
f211c9d0bb Merge pull request #13413 from Tixxx:tixxx/fix_mem_fraction_doc
PiperOrigin-RevId: 491396608
2022-11-28 11:24:38 -08:00
Sharad Vikram
c0c8eed6fa Pin IPython version in docs build to avoid RTD warning 2022-11-28 11:22:41 -08:00
Jake VanderPlas
dce6a9f8ce [sparse] fix bug in sparse.eye under JIT 2022-11-28 10:54:09 -08:00
TJ
5fb0215d4d updated jaxlib CHANGELOG 2022-11-28 10:37:42 -08:00
TJ
7456b66d35 changed both "currently available" to "total" for mem allocation doc 2022-11-28 09:30:25 -08:00
TJ
4011d17965 Change documentation to state the correct usage of XLA_PYTHON_CLIENT_MEM_FRACTION 2022-11-27 20:05:38 -08:00
Johannes Reifferscheid
cc1d2aaaed Disable more {cost,memory}_analysis tests when MLIR lowering is enabled.
PiperOrigin-RevId: 490898616
2022-11-25 06:25:56 -08:00
Johannes Reifferscheid
575c2f3783 Skip unsupported tests on XLA:CPU MLIR.
PiperOrigin-RevId: 490754048
2022-11-24 09:56:59 -08:00
Adam Paszke
a711166569 Make eager pmap take advantage of pmap cache
The current strategy of creating a `partial(primitive.bind, **params)` has the downside
of completely confusing the pmap cache and resulting in a new compilation for every single
primitive. Replacing it with a `HashableFunction` should fix it.

Also, pmap_test is now 2x faster!

PiperOrigin-RevId: 490749153
2022-11-24 09:13:51 -08:00
jax authors
8788a9438f Merge pull request #13382 from levskaya:fasterci
PiperOrigin-RevId: 490600867
2022-11-23 15:54:37 -08:00
Anselm Levskaya
074e4ec813 Enable faster test-runners for PR/push CI runs. 2022-11-23 14:07:08 -08:00
jax authors
ac77286dbf Merge pull request #13381 from skye:logger
PiperOrigin-RevId: 490562425
2022-11-23 12:34:11 -08:00
Johannes Reifferscheid
da4108d5e0 mhlo.all_to_all: support tuple form in importer/exporter.
PiperOrigin-RevId: 490560403
2022-11-23 12:24:12 -08:00
Yash Katariya
03aa266e45 Build GPU wheel only for Tesla (t4) when running on RBE. THis should in theory speed up the builds
PiperOrigin-RevId: 490553589
2022-11-23 11:57:24 -08:00
Yash Katariya
a4e8df76ab Use the remote_gpu tag which is inserted by TF's workspace2 when REMOTE_GPU_TESTING=1
PiperOrigin-RevId: 490553133
2022-11-23 11:50:50 -08:00
Skye Wanderman-Milne
6d79e2f485 Fix some logging to use logger instead of logging 2022-11-23 19:44:01 +00:00
Parker Schuh
c00821ea57 Support AOT serialization of pmap.
PiperOrigin-RevId: 490547612
2022-11-23 11:24:59 -08:00
Yash Katariya
8e270575f8 Set tf_exec_properties on OSS tests to use TF's gpu pool in the RBE cluster.
PiperOrigin-RevId: 490542399
2022-11-23 11:00:53 -08:00
Yash Katariya
51e4b017bb Use the cuda 11.8 image for CPU builds to reduce the docker image churn from the JAX side in TF's RBE cluster
PiperOrigin-RevId: 490522647
2022-11-23 09:33:17 -08:00
Adam Paszke
7dd958947d Make testManyArgs actually test pmap with many args
For some reason the test has always been passing a single array since it was added,
which seems contradictory with its purpose.

PiperOrigin-RevId: 490519068
2022-11-23 09:15:29 -08:00
George Necula
93674a9d75 Transition default use of XlaCallModule to StableHLO (version 2).
PiperOrigin-RevId: 490505972
2022-11-23 08:25:14 -08:00
Peter Hawkins
c1e1d64e66 [TPU] Add cutoff for nearly diagonal matrices in QDWH-eigh algorithm.
Fixes a bug where eigh returned NaNs for diagonal matrices, e.g., the identity matrix.

Nakatsukasa and Higham mention this stopping criterion in section 5.2 of Stable and Efficient Spectral Divide and Conquer Algorithms for the Symmetric Eigenvalue Decomposition and the SVD.

PiperOrigin-RevId: 490505832
2022-11-23 08:17:50 -08:00
jax authors
d1fbdbc1cf Rollback of "Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module."
PiperOrigin-RevId: 490499003
2022-11-23 07:48:05 -08:00
Adam Paszke
fe56a19904 Shard fft tests to avoid timeouts
PiperOrigin-RevId: 490486632
2022-11-23 06:33:13 -08:00
jax authors
d045fe2f95 Merge pull request #13366 from froystig:custom-vmap-consts
PiperOrigin-RevId: 490412416
2022-11-22 22:09:57 -08:00
Qiao Zhang
78963b6020 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 490387796
2022-11-22 18:53:29 -08:00
jax authors
f33d5514c9 Merge pull request #13367 from froystig:custom-derivatives-docfix
PiperOrigin-RevId: 490383906
2022-11-22 18:30:42 -08:00
Roy Frostig
fcce6b102c remove cotangent negation in custom VJP example
This was originally intended to show that we can change the VJP by
customizing it, but the algebraic incorrectness is confusing.
2022-11-22 17:55:22 -08:00
Roy Frostig
ef9b2fe4a1 custom vmap: support closure and staged constants
The `custom_vmap` primitive stages out its wrapped function at call
time. It might extract closed-over or otherwise constant values
("consts") in doing so. To handle these, we can reduce back to the
empty closure setting: convert the consts to formal arguments, both in
the target function and in the custom vmap rule, and ignore them in
the latter.

We only need to play this trick once, on initial entry. After that, we
can resume in assuming an empty closure.
2022-11-22 17:44:08 -08:00
jax authors
f341b273fe Merge pull request #13361 from froystig:threefry-partitionable-jit-cache-key
PiperOrigin-RevId: 490373768
2022-11-22 17:22:26 -08:00
Igor Saprykin
be527b62d7 Make it clear that fun is a function rather than a noun.
PiperOrigin-RevId: 490370522
2022-11-22 17:02:21 -08:00
jax authors
dd902fde21 Merge pull request #13317 from google:xdist_tpu
PiperOrigin-RevId: 490366370
2022-11-22 16:40:00 -08:00
Roy Frostig
6a52339dcc include jax_threefry_partitionable setting in staging cache key 2022-11-22 15:20:01 -08:00
jax authors
7128bb4ac9 Merge pull request #13358 from froystig:keyarray-jaxtype
PiperOrigin-RevId: 490343151
2022-11-22 14:49:59 -08:00
Roy Frostig
671e91d02d reduce relative tolerance in small-alpha Dirichlet test 2022-11-22 14:10:14 -08:00
Peter Hawkins
61aa415356 Disable sparse_test_cpu under msan due to CI timeouts.
PiperOrigin-RevId: 490312188
2022-11-22 12:48:34 -08:00
Roy Frostig
f8ecab8f9a fix Threefry split/fold_in symmetry test under key arrays mode 2022-11-22 09:59:13 -08:00
jax authors
518fe6656c Pickling of Sharding classes: use module level functions when deserializing.
This avoids having to pickle the sharding class (which references the module and the Python source file) in the serialized bytes, which happens when deserializing using `classmethod`s.

PiperOrigin-RevId: 490249959
2022-11-22 08:31:16 -08:00
jax authors
d9383fc80d Merge pull request #13343 from froystig:ci-tests-rng-partitionable
PiperOrigin-RevId: 490102616
2022-11-21 16:45:48 -08:00