13886 Commits

Author SHA1 Message Date
Qiao Zhang
78963b6020 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 490387796
2022-11-22 18:53:29 -08:00
jax authors
f33d5514c9 Merge pull request #13367 from froystig:custom-derivatives-docfix
PiperOrigin-RevId: 490383906
2022-11-22 18:30:42 -08:00
Roy Frostig
fcce6b102c remove cotangent negation in custom VJP example
This was originally intended to show that we can change the VJP by
customizing it, but the algebraic incorrectness is confusing.
2022-11-22 17:55:22 -08:00
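For context, a custom VJP rule of the kind this commit touches can be sketched as follows. This is a minimal illustration built on the `jax.custom_vjp` API, not the exact text of the documentation example; the function `f` and the rule names `f_fwd` and `f_bwd` are chosen here for illustration. The backward rule returns the algebraically correct cotangents, with no negation.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def f(x, y):
    return jnp.sin(x) * y


def f_fwd(x, y):
    # Forward pass: return the primal output plus residuals for the backward pass.
    return f(x, y), (jnp.cos(x), jnp.sin(x), y)


def f_bwd(res, g):
    cos_x, sin_x, y = res
    # Cotangents for (x, y): df/dx = cos(x) * y and df/dy = sin(x), each scaled by g.
    return (cos_x * g * y, sin_x * g)


f.defvjp(f_fwd, f_bwd)

# Gradients with respect to both arguments at an arbitrary point.
grad_x, grad_y = jax.grad(f, argnums=(0, 1))(1.0, 2.0)
```

With a correct rule, `grad_x` and `grad_y` agree with what plain autodiff would give for `sin(x) * y`.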
jax authors
f341b273fe Merge pull request #13361 from froystig:threefry-partitionable-jit-cache-key
PiperOrigin-RevId: 490373768
2022-11-22 17:22:26 -08:00
Igor Saprykin
be527b62d7 Make it clear that fun is a function rather than a noun.
PiperOrigin-RevId: 490370522
2022-11-22 17:02:21 -08:00
jax authors
dd902fde21 Merge pull request #13317 from google:xdist_tpu
PiperOrigin-RevId: 490366370
2022-11-22 16:40:00 -08:00
Roy Frostig
6a52339dcc include jax_threefry_partitionable setting in staging cache key 2022-11-22 15:20:01 -08:00
jax authors
7128bb4ac9 Merge pull request #13358 from froystig:keyarray-jaxtype
PiperOrigin-RevId: 490343151
2022-11-22 14:49:59 -08:00
Roy Frostig
671e91d02d reduce relative tolerance in small-alpha Dirichlet test 2022-11-22 14:10:14 -08:00
Peter Hawkins
61aa415356 Disable sparse_test_cpu under msan due to CI timeouts.
PiperOrigin-RevId: 490312188
2022-11-22 12:48:34 -08:00
Roy Frostig
f8ecab8f9a fix Threefry split/fold_in symmetry test under key arrays mode 2022-11-22 09:59:13 -08:00
jax authors
518fe6656c Pickling of Sharding classes: use module level functions when deserializing.
This avoids having to pickle the sharding class (which references the module and the Python source file) in the serialized bytes, which happens when deserializing using `classmethod`s.

PiperOrigin-RevId: 490249959
2022-11-22 08:31:16 -08:00
jax authors
d9383fc80d Merge pull request #13343 from froystig:ci-tests-rng-partitionable
PiperOrigin-RevId: 490102616
2022-11-21 16:45:48 -08:00
jax authors
92ee87dcbf Merge pull request #13341 from froystig:rng-split-fold-symmetry
PiperOrigin-RevId: 490097891
2022-11-21 16:22:48 -08:00
Roy Frostig
35634fcc2a exercise config.jax_threefry_partitionable in one of the CI runs 2022-11-21 15:30:58 -08:00
Roy Frostig
a412d27519 test threefry split consistency with vmapped fold_in of lax.axis_index 2022-11-21 15:24:48 -08:00
Roy Frostig
dab2909a31 make threefry split and fold_in symmetric
Namely, make it so that `split(key, n)[i]` equals `fold_in(key, i)`
for any key and for `0 <= i < n`.

This change affects the observed random bits for a fixed key (indirectly
through splits and folds), so here we guard it behind
`jax.config.jax_threefry_partitionable`. It's not described very well
by the flag name, but it makes for a simple way to bundle together
several random-bit-altering changes as part of the same upgrade cycle.
2022-11-21 15:24:48 -08:00
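The property stated above can be checked with a short sketch. It assumes the default Threefry PRNG implementation with raw `uint32` keys from `jax.random.PRNGKey`, and it enables the flag named in the commit message. The seed and `n` are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

# The symmetric behavior is guarded by this flag, per the commit message.
jax.config.update("jax_threefry_partitionable", True)

key = jax.random.PRNGKey(0)
n = 4

# Keys produced by a single split into n parts.
split_keys = jax.random.split(key, n)

# Keys produced by folding each index 0..n-1 into the same parent key.
folded_keys = jnp.stack([jax.random.fold_in(key, i) for i in range(n)])

# True when split(key, n)[i] equals fold_in(key, i) for every 0 <= i < n.
symmetric = bool(jnp.array_equal(split_keys, folded_keys))
```

With the flag disabled, the two arrays are not expected to match, which is why the change alters the random bits observed for a fixed key.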
Peter Hawkins
42e367af9c Fix typo in "nomsan" tag on pmap_test.
PiperOrigin-RevId: 489978468
2022-11-21 07:46:13 -08:00
George Necula
ab47c648c7 [jax2tf] Updated the Keras example to allow training.
The bug that prevented training has been fixed.

Fixes: #13329
PiperOrigin-RevId: 489952924
2022-11-21 05:08:49 -08:00
George Necula
de7237d226 [jax2tf] Upgrade examples to use optax instead of flax.optim
PiperOrigin-RevId: 489939357
2022-11-21 03:42:02 -08:00
jax authors
7890ec8164 Generalize TPU mesh computations.
PiperOrigin-RevId: 489936718
2022-11-21 03:22:24 -08:00
Yash Katariya
928dee415f Optimize host_local_array_to_global_array by caching the local-to-global conversion and the flattening of axis resources. Also take a fast path for device_put that skips abstractify and runs canonicalize_dtype once on the entire array (instead of once per shard).
This results in a roughly 4.5x speedup (3.03 ms to 0.673 ms).

Before:

```
---------------------------------------------------------------------------
Benchmark                                 Time             CPU   Iterations
---------------------------------------------------------------------------
host_local_array_to_global_array       3.03 ms         3.02 ms          220
```

After:

```
---------------------------------------------------------------------------
Benchmark                                 Time             CPU   Iterations
---------------------------------------------------------------------------
host_local_array_to_global_array      0.673 ms        0.671 ms          985
```

PiperOrigin-RevId: 489880547
2022-11-20 20:53:02 -08:00
Yash Katariya
6d11567142 Add cuda 11.8 configs to .bazelrc
PiperOrigin-RevId: 489591457
2022-11-18 16:55:28 -08:00
Skye Wanderman-Milne
120125f3dd Make pytest-xdist work on TPU and update Cloud TPU CI.
This change also marks multiaccelerator test files in a way pytest can
understand (if pytest is installed).

Running single-device tests on a single TPU chip cuts the test suite's
run time from 1hr 45m to 35m (both timings include slow tests).

I tried using bazel at first, which already supported parallel
execution across TPU cores, but somehow it still takes 2h 20m! I'm not
sure why it's so slow. It appears that bazel creates many new test
processes over time, whereas pytest reuses the processes initially
specified. Starting and stopping the TPU runtime takes a few seconds,
so that overhead may be adding up. It also appears that
single-process bazel is slower than single-process pytest, which I
haven't looked into yet.
2022-11-18 22:05:13 +00:00
Yash Katariya
d918fe85f0 Use positional only arguments and extract self._params from args since it will always be the first argument.
PiperOrigin-RevId: 489527497
2022-11-18 11:51:27 -08:00
Yash Katariya
b6fa77cb60 Fix forward (Add deprecation warnings to DA, SDA and GDA): By raising the warnings in the hook of the jax_array config.
PiperOrigin-RevId: 489503583
2022-11-18 10:12:40 -08:00
jax authors
e8201f961e Merge pull request #13313 from google:yashk2810-patch-20
PiperOrigin-RevId: 489479102
2022-11-18 08:16:52 -08:00
Yash Katariya
1824be772e Update jax_array_migration.md 2022-11-18 08:11:22 -08:00
Yash Katariya
29d75324a3 Add a date till which jax.Array can be disabled 2022-11-18 08:09:31 -08:00
Peter Hawkins
9f2a6acb61 Revert: Add deprecation warnings to DA, SDA and GDA.
This change is currently overly noisy for users.

PiperOrigin-RevId: 489455729
2022-11-18 06:06:13 -08:00
jax authors
7a3dbcf94e Change params to _params to avoid clashes with downstream users.
PiperOrigin-RevId: 489441483
2022-11-18 04:18:41 -08:00
Parker Schuh
91634e0da4 Refactor create_cpp_call to be a method on MeshExecutable
rather than being passed all the way down from pjit.py.

PiperOrigin-RevId: 489353681
2022-11-17 18:05:10 -08:00
Tianjian Lu
bf21480534 [sparse] Disable gpu bcoo_matmul test when type promotion is required.
PiperOrigin-RevId: 489351869
2022-11-17 17:53:22 -08:00
jax authors
4cc163d3c4 Merge pull request #13300 from jakevdp:sparse-neg
PiperOrigin-RevId: 489340385
2022-11-17 16:47:10 -08:00
Jake VanderPlas
e673f1fd44 [sparse] avoid re-indexing for linear unary ops 2022-11-17 16:31:46 -08:00
Parker Schuh
0324cac888 Remove unused potrf kernels.
PiperOrigin-RevId: 489322021
2022-11-17 15:22:13 -08:00
Yash Katariya
52a2428073 Add deprecation warnings to DA, SDA and GDA.
PiperOrigin-RevId: 489314189
2022-11-17 14:51:29 -08:00
Parker Schuh
4d418fb45e Remove opt-barrier fallbacks.
PiperOrigin-RevId: 489285590
2022-11-17 12:57:57 -08:00
Yash Katariya
13ec2338bb Add shape to the error message along with the value.
PiperOrigin-RevId: 489254448
2022-11-17 10:58:38 -08:00
Peter Hawkins
88379603e0 [PJRT] Delete the old :cpu_device target that uses StreamExecutor.
The TFRT CPU client is better in every way and the SE CPU client is unmaintained and has not been used by JAX in many months.

PiperOrigin-RevId: 489246256
2022-11-17 10:29:03 -08:00
Peter Hawkins
ebee4f4bfd Disable test variants that time out in CI.
PiperOrigin-RevId: 489214464
2022-11-17 08:14:07 -08:00
jax authors
b0d5052928 Merge pull request #13286 from jakevdp:fix-jaxlib-build
PiperOrigin-RevId: 489210938
2022-11-17 07:56:45 -08:00
Jake VanderPlas
e66fe1dff4 Update XLA commit.
Fixes build error: no such target '@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:python/MlirHloModule.cc'
2022-11-16 15:51:28 -08:00
Jake VanderPlas
e7f4fe043e jaxlib: fix mlir_hlo build rule 2022-11-16 15:42:05 -08:00
jax authors
2f3a384243 Merge pull request #13278 from google:chat_notif
PiperOrigin-RevId: 489033149
2022-11-16 14:04:53 -08:00
Skye Wanderman-Milne
0a886c34fa Include which jaxlib/libtpu version failed (latest or nightly) in TPU CI chat notification 2022-11-16 21:38:36 +00:00
jax authors
3a837c8069 Merge pull request #13281 from google:tpu_notify_main_only
PiperOrigin-RevId: 489025620
2022-11-16 13:36:49 -08:00
Skye Wanderman-Milne
b4564a2a57 TPU CI: don't notify when testing the workflow from a branch 2022-11-16 21:27:24 +00:00
jax authors
58af581e3b Merge pull request #13280 from jakevdp:sparse-test-util
PiperOrigin-RevId: 489019436
2022-11-16 13:13:04 -08:00
Parker Schuh
da765a2e54 Allow compiling and then serializing jax.stages.Lowered.
This adds experimental APIs to `serialize_executable.py`:

`compile_and_serialize(lowered)`
and
`load_compiled(serialized, in_tree, out_tree)`

for serializing and deserializing executables.

PiperOrigin-RevId: 489014705
2022-11-16 12:54:10 -08:00