Peter Hawkins
7495a9e370
[JAX] Enable/disable tests that timed out in CI.
...
Reenable pmap_test since it was recently sped up.
PiperOrigin-RevId: 491650701
2022-11-29 09:02:16 -08:00
jax authors
21bab5efab
Merge pull request #13431 from jakevdp:fix-sparse-matmul-warning
...
PiperOrigin-RevId: 491649932
2022-11-29 08:55:23 -08:00
Yash Katariya
3e5a5053f4
Run GPU presubmits via bazel test on the RBE cluster. This speeds up the build + testing significantly (upto 10x).
...
But run the continuous builds by building on RBE and testing locally so as to run the multiaccelerator tests too. Locally we have 4 GPUs available.
Also make GPU presubmits blocking for JAX (re-enabled it).
PiperOrigin-RevId: 491647775
2022-11-29 08:45:58 -08:00
Jake VanderPlas
885441b2c8
[sparse] ignore GPU warning in tests
2022-11-29 08:40:12 -08:00
Qiao Zhang
c54bc90bf4
Fix cudnn_header OSS BUILD dep.
...
PiperOrigin-RevId: 491465703
2022-11-28 15:58:55 -08:00
Qiao Zhang
4d1c4bc761
Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
...
PiperOrigin-RevId: 491445515
2022-11-28 14:31:48 -08:00
jax authors
702a1084b8
Merge pull request #13422 from jakevdp:ci-timeout
...
PiperOrigin-RevId: 491435108
2022-11-28 13:53:15 -08:00
Jake VanderPlas
1647c5960e
CI: bump timeout for pre-commit
2022-11-28 13:26:44 -08:00
jax authors
441d53b766
Merge pull request #13253 from NeilGirdhar:keypaths
...
PiperOrigin-RevId: 491424268
2022-11-28 13:10:20 -08:00
jax authors
87408c769a
Merge pull request #13421 from sharadmv:fix-rtd
...
PiperOrigin-RevId: 491404001
2022-11-28 11:53:17 -08:00
jax authors
2745cc2cfa
Merge pull request #13420 from jakevdp:sparse-eye-jit
...
PiperOrigin-RevId: 491399990
2022-11-28 11:37:25 -08:00
jax authors
f211c9d0bb
Merge pull request #13413 from Tixxx:tixxx/fix_mem_fraction_doc
...
PiperOrigin-RevId: 491396608
2022-11-28 11:24:38 -08:00
Sharad Vikram
c0c8eed6fa
Pin IPython version in docs build to avoid RTD warning
2022-11-28 11:22:41 -08:00
Jake VanderPlas
dce6a9f8ce
[sparse] fix bug in sparse.eye under JIT
2022-11-28 10:54:09 -08:00
TJ
5fb0215d4d
updated jaxlib CHANGELOG
2022-11-28 10:37:42 -08:00
TJ
7456b66d35
changed both "currently available" to "total" for mem allocation doc
2022-11-28 09:30:25 -08:00
TJ
4011d17965
Change documentation to state the correct usage of XLA_PYTHON_CLIENT_MEM_FRACTION
2022-11-27 20:05:38 -08:00
Johannes Reifferscheid
cc1d2aaaed
Disable more {cost,memory}_analysis tests when MLIR lowering is enabled.
...
PiperOrigin-RevId: 490898616
2022-11-25 06:25:56 -08:00
Johannes Reifferscheid
575c2f3783
Skip unsupported tests on XLA:CPU MLIR.
...
PiperOrigin-RevId: 490754048
2022-11-24 09:56:59 -08:00
Adam Paszke
a711166569
Make eager pmap take advantage of pmap cache
...
The current strategy of creating a `partial(primitive.bind, **params)` has the downside
of completely confusing the pmap cache and resulting in a new compilation for every single
primitive. Replacing it with a `HashableFunction` should fix it.
Also, pmap_test is now 2x faster!
PiperOrigin-RevId: 490749153
2022-11-24 09:13:51 -08:00
jax authors
8788a9438f
Merge pull request #13382 from levskaya:fasterci
...
PiperOrigin-RevId: 490600867
2022-11-23 15:54:37 -08:00
Anselm Levskaya
074e4ec813
Enable faster test-runners for PR/push CI runs.
2022-11-23 14:07:08 -08:00
jax authors
ac77286dbf
Merge pull request #13381 from skye:logger
...
PiperOrigin-RevId: 490562425
2022-11-23 12:34:11 -08:00
Johannes Reifferscheid
da4108d5e0
mhlo.all_to_all: support tuple form in importer/exporter.
...
PiperOrigin-RevId: 490560403
2022-11-23 12:24:12 -08:00
Yash Katariya
03aa266e45
Build GPU wheel only for Tesla (t4) when running on RBE. THis should in theory speed up the builds
...
PiperOrigin-RevId: 490553589
2022-11-23 11:57:24 -08:00
Yash Katariya
a4e8df76ab
Use the remote_gpu
tag which is inserted by TF's workspace2 when REMOTE_GPU_TESTING=1
...
PiperOrigin-RevId: 490553133
2022-11-23 11:50:50 -08:00
Skye Wanderman-Milne
6d79e2f485
Fix some logging to use logger
instead of logging
2022-11-23 19:44:01 +00:00
Parker Schuh
c00821ea57
Support AOT serialization of pmap.
...
PiperOrigin-RevId: 490547612
2022-11-23 11:24:59 -08:00
Yash Katariya
8e270575f8
Set tf_exec_properties
on OSS tests to use TF's gpu pool in the RBE cluster.
...
PiperOrigin-RevId: 490542399
2022-11-23 11:00:53 -08:00
Yash Katariya
51e4b017bb
Use the cuda 11.8 image for CPU builds to reduce the docker image churn from the JAX side in TF's RBE cluster
...
PiperOrigin-RevId: 490522647
2022-11-23 09:33:17 -08:00
Adam Paszke
7dd958947d
Make testManyArgs actually test pmap with many args
...
For some reason the test has always been passing a single array since it was added,
which seems contradictory with its purpose.
PiperOrigin-RevId: 490519068
2022-11-23 09:15:29 -08:00
George Necula
93674a9d75
Transition default use of XlaCallModule to StableHLO (version 2).
...
PiperOrigin-RevId: 490505972
2022-11-23 08:25:14 -08:00
Peter Hawkins
c1e1d64e66
[TPU] Add cutoff for nearly diagonal matrices in QDWH-eigh algorithm.
...
Fixes a bug where eigh returned NaNs for diagonal matrices, e.g., the identity matrix.
Nakatsukasa and Higham mention this stopping criterion in section 5.2 of Stable and Efficient Spectral Divide and Conquer Algorithms for the Symmetric Eigenvalue Decomposition and the SVD.
PiperOrigin-RevId: 490505832
2022-11-23 08:17:50 -08:00
jax authors
d1fbdbc1cf
Rollback of "Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module."
...
PiperOrigin-RevId: 490499003
2022-11-23 07:48:05 -08:00
Adam Paszke
fe56a19904
Shard fft tests to avoid timeouts
...
PiperOrigin-RevId: 490486632
2022-11-23 06:33:13 -08:00
jax authors
d045fe2f95
Merge pull request #13366 from froystig:custom-vmap-consts
...
PiperOrigin-RevId: 490412416
2022-11-22 22:09:57 -08:00
Qiao Zhang
78963b6020
Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
...
PiperOrigin-RevId: 490387796
2022-11-22 18:53:29 -08:00
jax authors
f33d5514c9
Merge pull request #13367 from froystig:custom-derivatives-docfix
...
PiperOrigin-RevId: 490383906
2022-11-22 18:30:42 -08:00
Roy Frostig
fcce6b102c
remove cotangent negation in custom VJP example
...
This was originally intended to show that we can change the VJP by
customizing it, but the algebraic incorrectness is confusing.
2022-11-22 17:55:22 -08:00
Roy Frostig
ef9b2fe4a1
custom vmap: support closure and staged constants
...
The `custom_vmap` primitive stages out its wrapped function at call
time. It might extract closed-over or otherwise constant values
("consts") in doing so. To handle these, we can reduce back to the
empty closure setting: convert the consts to formal arguments, both in
the target function and in the custom vmap rule, and ignore them in
the latter.
We only need to play this trick once, on initial entry. After that, we
can resume in assuming an empty closure.
2022-11-22 17:44:08 -08:00
jax authors
f341b273fe
Merge pull request #13361 from froystig:threefry-partitionable-jit-cache-key
...
PiperOrigin-RevId: 490373768
2022-11-22 17:22:26 -08:00
Igor Saprykin
be527b62d7
Make it clear that fun
is a function rather than a noun.
...
PiperOrigin-RevId: 490370522
2022-11-22 17:02:21 -08:00
jax authors
dd902fde21
Merge pull request #13317 from google:xdist_tpu
...
PiperOrigin-RevId: 490366370
2022-11-22 16:40:00 -08:00
Roy Frostig
6a52339dcc
include jax_threefry_partitionable
setting in staging cache key
2022-11-22 15:20:01 -08:00
jax authors
7128bb4ac9
Merge pull request #13358 from froystig:keyarray-jaxtype
...
PiperOrigin-RevId: 490343151
2022-11-22 14:49:59 -08:00
Roy Frostig
671e91d02d
reduce relative tolerance in small-alpha Dirichlet test
2022-11-22 14:10:14 -08:00
Peter Hawkins
61aa415356
Disable sparse_test_cpu under msan due to CI timeouts.
...
PiperOrigin-RevId: 490312188
2022-11-22 12:48:34 -08:00
Roy Frostig
f8ecab8f9a
fix Threefry split
/fold_in
symmetry test under key arrays mode
2022-11-22 09:59:13 -08:00
jax authors
518fe6656c
Pickling of Sharding classes: use module level functions when deserializing.
...
This avoids having to pickle the sharding class (which references the module and the Python source file) in the serialized bytes, which happens when deserializing using `classmethod`s.
PiperOrigin-RevId: 490249959
2022-11-22 08:31:16 -08:00
jax authors
d9383fc80d
Merge pull request #13343 from froystig:ci-tests-rng-partitionable
...
PiperOrigin-RevId: 490102616
2022-11-21 16:45:48 -08:00