13235 Commits

Author SHA1 Message Date
jax authors
849f837b6a Merge pull request #12609 from skye:jax2tf_import
PiperOrigin-RevId: 478104104
jax-v0.3.21 jax-v0.3.21-rc 0.3.21
2022-09-30 16:25:17 -07:00
Skye Wanderman-Milne
0a69c9a27b Fix jax2tf import so it works with both the latest tensorflow release (2.10.0) and tf-nightly 2022-09-30 15:55:22 -07:00
Yash Katariya
fb8558cfdd Add jax_array coverage to debug_nans_test
PiperOrigin-RevId: 478079509
2022-09-30 14:21:32 -07:00
jax authors
ec41de2c9b Merge pull request #12603 from jbushago:patch-1
PiperOrigin-RevId: 478068347
2022-09-30 13:33:03 -07:00
jax authors
ea77c453cb Merge pull request #12602 from skye:colab
PiperOrigin-RevId: 478063554
2022-09-30 13:17:47 -07:00
jax authors
8a1e0ed13f Merge pull request #12594 from skye:cache_warnings
PiperOrigin-RevId: 478063392
2022-09-30 13:11:34 -07:00
Skye Wanderman-Milne
15e5f38a16 Make persistent compilation cache warn instead of raise an error on cache read/write failures
Fixes #12582. Setting the env var `JAX_RAISE_PERSISTENT_CACHE_ERRORS=true` will revert to the original behavior of raising exception instead of warning.

Also makes JAX_DUMP_IR_TO work when the persistent cache is enabled.
2022-09-30 18:38:22 +00:00
jbushago
2038988783
Fix typo in faq.rst.
Fixed a small typo in the FAQ: "inthe" -> "in the".
2022-09-30 14:14:05 -04:00
Skye Wanderman-Milne
0cc4066bb7 Pin default jax.tools.colab_tpu.setup_tpu driver version.
Prior to this change, we were defaulting to the TPU nightly driver
version. We should instead pin to the version associated with the
default jaxlib version that Colab uses.
2022-09-30 17:45:49 +00:00
Yash Katariya
aafc77d3c0 Improve the checks done in Array and apply them to all Shardings rather than just XLACompatibleSharding.
Also check the symmetric difference of sharding and `_arrays` devices.

PiperOrigin-RevId: 478017409
2022-09-30 09:56:16 -07:00
Jake VanderPlas
4e51d2d4fb Roll back https://github.com/google/jax/pull/12588 because of test failures
PiperOrigin-RevId: 477871341
2022-09-29 18:30:58 -07:00
jax authors
d498bd1eaf Merge pull request #12588 from jakevdp:random-annotations
PiperOrigin-RevId: 477855302
2022-09-29 16:52:42 -07:00
Yash Katariya
9ff570e6c3 Make debug_nans_test.py pass with jax_array=1. Both with enabled and disabled jax_array flag and --pdb_post_mortem, we fall to the same place.
PiperOrigin-RevId: 477850567
2022-09-29 16:29:58 -07:00
Yash Katariya
3c7d927a2c Disable dynamic_api_test and custom_object_test.py with jax.Array. Enable it back when support for it is added. Also don't use xla_shape since its deprecated.
PiperOrigin-RevId: 477833061
2022-09-29 15:09:55 -07:00
Jake VanderPlas
aed46f3312 [typing] use jax.Array annotations in random.py 2022-09-29 14:29:02 -07:00
Yash Katariya
eb0fa4028f Fix process_allgather to work with jax.Array.
PiperOrigin-RevId: 477793014
2022-09-29 12:32:21 -07:00
jax authors
4f90af91d3 Remove unused jax_unique_mhlo_module_names flag.
PiperOrigin-RevId: 477778135
2022-09-29 11:32:22 -07:00
jax authors
a770db04fa Merge pull request #12579 from jakevdp:gather-unique
PiperOrigin-RevId: 477767679
2022-09-29 10:56:14 -07:00
Yash Katariya
7b49a3f51d Run tests in multiprocess_gpu_test only if the backend is GPU.
PiperOrigin-RevId: 477750739
2022-09-29 09:54:32 -07:00
Jake VanderPlas
1bc161a67d random.permutation: use unique_indices=True for efficiency 2022-09-29 09:34:03 -07:00
Jake VanderPlas
d49c5c37ea jnp.take: add optional arguments forwarded to lax.gather 2022-09-29 09:33:38 -07:00
Mehdi Amini
137384d856 Update xla_sharding import path to new location
We are moving the TensorFlow APIs outside of XLA and will remove the old
path soon.

PiperOrigin-RevId: 477701988
2022-09-29 05:40:56 -07:00
jax authors
de5dd1a4a5 Merge pull request #12444 from LenaMartens:checkify-switch
PiperOrigin-RevId: 477688623
2022-09-29 04:18:18 -07:00
lenamartens
0639aced5b Raise cond index into tracing context in case of effects.
So even if the cond is not data dependent at all, it's included in the
dynamic trace, and effects can be discharged.
2022-09-29 11:36:04 +01:00
jax authors
48b89560e5 Merge pull request #12566 from mattjj:djax-slice-sick
PiperOrigin-RevId: 477626935
2022-09-28 21:23:38 -07:00
Yash Katariya
163b7e22d2 Convert shardings in jit path to OpShardingSharding to avoid recompilation when semantically similar shardings are used in jit.
PiperOrigin-RevId: 477626548
2022-09-28 21:17:29 -07:00
Yash Katariya
500f8b7f9c Add HLOSharding's repr to OpShardingSharding since its more compact.
PiperOrigin-RevId: 477587916
2022-09-28 17:00:16 -07:00
Yash Katariya
84768d2d49 Replace jax.xla.DeviceArray private type with the new public type jax.Array.
PiperOrigin-RevId: 477582562
2022-09-28 16:34:10 -07:00
Matthew Johnson
a8826e672b [dynamic-shapes] Add basic slicing support
If e.g. `x : f32[10, n]` then we want to handle Python expressions like `x[0]`.
To do that, we can use a generalized version of `dynamic_slice` which allows
dynamic slice sizes (where the result shape depends on those slice sizes).

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-09-28 15:55:51 -07:00
jax authors
33dbf0ea1c Merge pull request #12565 from hawkinsp:release
PiperOrigin-RevId: 477549228
2022-09-28 14:14:30 -07:00
Peter Hawkins
b49e31a012 Update version numbers after release. 2022-09-28 18:49:22 +00:00
Yash Katariya
c89cb5d8a4 Use Array in __repr__ instead of the class name which is ArrayImpl.
PiperOrigin-RevId: 477465432
2022-09-28 08:57:53 -07:00
jax authors
0282b4bfad Merge pull request #12538 from jakevdp:bundle-pyi
PiperOrigin-RevId: 477453094
jax-v0.3.20 jaxlib-v0.3.20 jax-v0.3.20-rc 0.3.20
2022-09-28 08:00:20 -07:00
jax authors
aafc70d293 Merge pull request #12556 from hawkinsp:rocm
PiperOrigin-RevId: 477440001
2022-09-28 06:50:19 -07:00
jax authors
5fe7a5440f Merge pull request #12555 from hawkinsp:release
PiperOrigin-RevId: 477439236
2022-09-28 06:50:06 -07:00
jax authors
39eabe878d Merge pull request #12552 from hawkinsp:nccl
PiperOrigin-RevId: 477439228
2022-09-28 06:43:45 -07:00
Peter Hawkins
f7bafb3d4c Disable multiprocess_gpu_test that fails on ROCm. 2022-09-28 13:40:57 +00:00
Peter Hawkins
8d8643664c jax/jaxlib 0.3.20 release candidate. 2022-09-28 13:33:52 +00:00
Peter Hawkins
eabb91e53f Fix test failure in GPU CI if NCCL_DEBUG is enabled.
If NCCL_DEBUG is enabled, NCCL prints extra status information. Make
test accept this.
2022-09-28 13:06:04 +00:00
jax authors
96abd9ac75 Merge pull request #12540 from sharadmv:cond-lowering-fix
PiperOrigin-RevId: 477358889
2022-09-27 22:33:12 -07:00
Yash Katariya
96a85bd59a Make addressable_shards a property like local_shards
PiperOrigin-RevId: 477358276
2022-09-27 22:27:19 -07:00
jax authors
948906885d Merge pull request #12546 from mattjj:issue12542
PiperOrigin-RevId: 477356925
2022-09-27 22:16:02 -07:00
Sharad Vikram
ddeaa8dbbc Fix lowering bug in effectful batched cond and add tests 2022-09-27 22:12:13 -07:00
Yash Katariya
b4e1d0af8a Propagate name through ExecuteReplicated for dispatch.check_special
PiperOrigin-RevId: 477351323
2022-09-27 21:32:32 -07:00
Matthew Johnson
b175e11731 [c++ jit] only set use_fastpath in cache_miss if all args are DeviceArrays
fixes #12542

Co-authored-by: Peter Hawkins <phawkins@google.com>
Co-authored-by: Kuangyuan Chen <chky@google.com>
2022-09-27 20:51:07 -07:00
Yash Katariya
933b6a2fa4 Fix the bug where XLA doesn't provide shardings for all the outputs if all the elements in the output tuple have the same sharding. XLA decides to run the FusionTupleDeduplicator to put the sharding on ROOT instead of the tuple.
PiperOrigin-RevId: 477343328
2022-09-27 20:27:39 -07:00
Yash Katariya
c8bff11d1b Add addressable_ counterparts of local_ to GDA to make it easier for users to move to Array as both will have the same API.
PiperOrigin-RevId: 477332697
2022-09-27 19:19:29 -07:00
Yash Katariya
e4f2bff0a3 Disintegrate Array into DeviceBuffers inside GDA. This is required for backwards compatibility changes as users can create GDAs and pass that to pjit even when Array is switched on.
PiperOrigin-RevId: 477297406
2022-09-27 16:02:23 -07:00
jax authors
0919a6776a Merge pull request #12534 from google:update-pypi
PiperOrigin-RevId: 477260550
2022-09-27 13:31:05 -07:00
Jake VanderPlas
6e6fb10ca3 setup: bundle *.pyi files with distribution 2022-09-27 12:55:42 -07:00