Qiao Zhang
78963b6020
Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
...
PiperOrigin-RevId: 490387796
2022-11-22 18:53:29 -08:00
jax authors
f33d5514c9
Merge pull request #13367 from froystig:custom-derivatives-docfix
...
PiperOrigin-RevId: 490383906
2022-11-22 18:30:42 -08:00
Roy Frostig
fcce6b102c
remove cotangent negation in custom VJP example
...
This was originally intended to show that we can change the VJP by
customizing it, but the algebraic incorrectness is confusing.
2022-11-22 17:55:22 -08:00
jax authors
f341b273fe
Merge pull request #13361 from froystig:threefry-partitionable-jit-cache-key
...
PiperOrigin-RevId: 490373768
2022-11-22 17:22:26 -08:00
Igor Saprykin
be527b62d7
Make it clear that fun
is a function rather than a noun.
...
PiperOrigin-RevId: 490370522
2022-11-22 17:02:21 -08:00
jax authors
dd902fde21
Merge pull request #13317 from google:xdist_tpu
...
PiperOrigin-RevId: 490366370
2022-11-22 16:40:00 -08:00
Roy Frostig
6a52339dcc
include jax_threefry_partitionable
setting in staging cache key
2022-11-22 15:20:01 -08:00
jax authors
7128bb4ac9
Merge pull request #13358 from froystig:keyarray-jaxtype
...
PiperOrigin-RevId: 490343151
2022-11-22 14:49:59 -08:00
Roy Frostig
671e91d02d
reduce relative tolerance in small-alpha Dirichlet test
2022-11-22 14:10:14 -08:00
Peter Hawkins
61aa415356
Disable sparse_test_cpu under msan due to CI timeouts.
...
PiperOrigin-RevId: 490312188
2022-11-22 12:48:34 -08:00
Roy Frostig
f8ecab8f9a
fix Threefry split
/fold_in
symmetry test under key arrays mode
2022-11-22 09:59:13 -08:00
jax authors
518fe6656c
Pickling of Sharding classes: use module level functions when deserializing.
...
This avoids having to pickle the sharding class (which references the module and the Python source file) in the serialized bytes, which happens when deserializing using `classmethod`s.
PiperOrigin-RevId: 490249959
2022-11-22 08:31:16 -08:00
jax authors
d9383fc80d
Merge pull request #13343 from froystig:ci-tests-rng-partitionable
...
PiperOrigin-RevId: 490102616
2022-11-21 16:45:48 -08:00
jax authors
92ee87dcbf
Merge pull request #13341 from froystig:rng-split-fold-symmetry
...
PiperOrigin-RevId: 490097891
2022-11-21 16:22:48 -08:00
Roy Frostig
35634fcc2a
exercise config.jax_threefry_partitionable
in one of the CI runs
2022-11-21 15:30:58 -08:00
Roy Frostig
a412d27519
test threefry split
consistency with vmapped fold_in
of lax.axis_index
2022-11-21 15:24:48 -08:00
Roy Frostig
dab2909a31
make threefry split
and fold_in
symmetric
...
Namely, make it so that `split(key, n)[i]` equals `fold_in(key, i)`
for any key and for `0 <= i < n`.
This change affects the observed random bits for a fixed key (indirectly
through splits and folds), so here we guard it behind
`jax.config.jax_threefry_partitionable`. It's not described very well
by the flag name, but it makes for a simple way to bundle together
several random-bit-altering changes as part of the same upgrade cycle.
2022-11-21 15:24:48 -08:00
Peter Hawkins
42e367af9c
Fix typo in "nomsan" tag on pmap_test.
...
PiperOrigin-RevId: 489978468
2022-11-21 07:46:13 -08:00
George Necula
ab47c648c7
[jax2tf] Updated the Keras example to allow training.
...
The bug that prevented training has been fixed.
Fixes : #13329
PiperOrigin-RevId: 489952924
2022-11-21 05:08:49 -08:00
George Necula
de7237d226
[jax2tf] Upgrade examples to use optax instead of flax.optim
...
PiperOrigin-RevId: 489939357
2022-11-21 03:42:02 -08:00
jax authors
7890ec8164
Generalize TPU mesh computations.
...
PiperOrigin-RevId: 489936718
2022-11-21 03:22:24 -08:00
Yash Katariya
928dee415f
Optimize host_local_array_to_global_array
by caching the local to global conversion and flattening of axis resources. Also take a fast path for device_put which does not do abstractify
and only canonicalize_dtype on the entire array once (instead of doing it for every shard).
...
This results in a 5x speedup!
Before:
```
---------------------------------------------------------------------------
Benchmark Time CPU Iterations
---------------------------------------------------------------------------
host_local_array_to_global_array 3.03 ms 3.02 ms 220
```
After:
```
---------------------------------------------------------------------------
Benchmark Time CPU Iterations
---------------------------------------------------------------------------
host_local_array_to_global_array 0.673 ms 0.671 ms 985
```
PiperOrigin-RevId: 489880547
2022-11-20 20:53:02 -08:00
Yash Katariya
6d11567142
Add cuda 11.8 configs to .bazelrc
...
PiperOrigin-RevId: 489591457
2022-11-18 16:55:28 -08:00
Skye Wanderman-Milne
120125f3dd
Make pytest-xdist work on TPU and update Cloud TPU CI.
...
This change also marks multiaccelerator test files in a way pytest can
understand (if pytest is installed).
By running single-device tests on a single TPU chip, running the test
suite goes from 1hr 45m to 35m (both timings are running slow tests).
I tried using bazel at first, which already supported parallel
execution across TPU cores, but somehow it still takes 2h 20m! I'm not
sure why it's so slow. It appears that bazel creates many new test
processes over time, vs. pytest reuses the number of processes
initially specified, and starting and stopping the TPU runtime takes a
few seconds so that may be adding up. It also appears that
single-process bazel is slower than single-process pytest, which I
haven't looked into yet.
2022-11-18 22:05:13 +00:00
Yash Katariya
d918fe85f0
Use positional only arguments and extract self._params
from args
since it will always be the first argument.
...
PiperOrigin-RevId: 489527497
2022-11-18 11:51:27 -08:00
Yash Katariya
b6fa77cb60
Fix forward (Add deprecation warnings to DA, SDA and GDA): By raising the warnings in the hook of the jax_array config.
...
PiperOrigin-RevId: 489503583
2022-11-18 10:12:40 -08:00
jax authors
e8201f961e
Merge pull request #13313 from google:yashk2810-patch-20
...
PiperOrigin-RevId: 489479102
2022-11-18 08:16:52 -08:00
Yash Katariya
1824be772e
Update jax_array_migration.md
2022-11-18 08:11:22 -08:00
Yash Katariya
29d75324a3
Add a date till which jax.Array can be disabled
2022-11-18 08:09:31 -08:00
Peter Hawkins
9f2a6acb61
Revert: Add deprecation warnings to DA, SDA and GDA.
...
This change is currently overly noisy for users.
PiperOrigin-RevId: 489455729
2022-11-18 06:06:13 -08:00
jax authors
7a3dbcf94e
Change params to _params to avoid clashes with downstream users.
...
PiperOrigin-RevId: 489441483
2022-11-18 04:18:41 -08:00
Parker Schuh
91634e0da4
Refactor create_cpp_call to be a method on MeshExecutable
...
rather then being passed all the way down from pjit.py.
PiperOrigin-RevId: 489353681
2022-11-17 18:05:10 -08:00
Tianjian Lu
bf21480534
[sparse] Disable gpu bcoo_matmul test when type promotion is required.
...
PiperOrigin-RevId: 489351869
2022-11-17 17:53:22 -08:00
jax authors
4cc163d3c4
Merge pull request #13300 from jakevdp:sparse-neg
...
PiperOrigin-RevId: 489340385
2022-11-17 16:47:10 -08:00
Jake VanderPlas
e673f1fd44
[sparse] avoid re-indexing for linear unary ops
2022-11-17 16:31:46 -08:00
Parker Schuh
0324cac888
Remove unused potrf kernels.
...
PiperOrigin-RevId: 489322021
2022-11-17 15:22:13 -08:00
Yash Katariya
52a2428073
Add deprecation warnings to DA, SDA and GDA.
...
PiperOrigin-RevId: 489314189
2022-11-17 14:51:29 -08:00
Parker Schuh
4d418fb45e
Remove opt-barrier fallbacks.
...
PiperOrigin-RevId: 489285590
2022-11-17 12:57:57 -08:00
Yash Katariya
13ec2338bb
Add shape to the error message along with the value.
...
PiperOrigin-RevId: 489254448
2022-11-17 10:58:38 -08:00
Peter Hawkins
88379603e0
[PJRT] Delete the old :cpu_device target that uses StreamExecutor.
...
The TFRT CPU client is better in every way and the SE CPU client is unmaintained and has not been used by JAX in many months.
PiperOrigin-RevId: 489246256
2022-11-17 10:29:03 -08:00
Peter Hawkins
ebee4f4bfd
Disable test variants that time out in CI.
...
PiperOrigin-RevId: 489214464
2022-11-17 08:14:07 -08:00
jax authors
b0d5052928
Merge pull request #13286 from jakevdp:fix-jaxlib-build
...
PiperOrigin-RevId: 489210938
2022-11-17 07:56:45 -08:00
Jake VanderPlas
e66fe1dff4
Update XLA commit.
...
Fixes build error: no such target '@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:python/MlirHloModule.cc'
2022-11-16 15:51:28 -08:00
Jake VanderPlas
e7f4fe043e
jaxlib: fix mlir_hlo build rule
2022-11-16 15:42:05 -08:00
jax authors
2f3a384243
Merge pull request #13278 from google:chat_notif
...
PiperOrigin-RevId: 489033149
2022-11-16 14:04:53 -08:00
Skye Wanderman-Milne
0a886c34fa
Include which jaxlib/libtpu version failed (latest or nightly) in TPU CI chat notification
2022-11-16 21:38:36 +00:00
jax authors
3a837c8069
Merge pull request #13281 from google:tpu_notify_main_only
...
PiperOrigin-RevId: 489025620
2022-11-16 13:36:49 -08:00
Skye Wanderman-Milne
b4564a2a57
TPU CI: don't notify when testing the workflow from a branch
2022-11-16 21:27:24 +00:00
jax authors
58af581e3b
Merge pull request #13280 from jakevdp:sparse-test-util
...
PiperOrigin-RevId: 489019436
2022-11-16 13:13:04 -08:00
Parker Schuh
da765a2e54
Allow compiling and then serializing jax.stages.Lowered.
...
This adds experimental APIs to `serialize_executable.py`:
`compile_and_serialize(lowered)`
and
`load_compiled(serialized, in_tree, out_tree)`
for serializing and deserializing executables.
PiperOrigin-RevId: 489014705
2022-11-16 12:54:10 -08:00