jax authors
849f837b6a
Merge pull request #12609 from skye:jax2tf_import
...
PiperOrigin-RevId: 478104104
jax-v0.3.21
jax-v0.3.21-rc
0.3.21
2022-09-30 16:25:17 -07:00
Skye Wanderman-Milne
0a69c9a27b
Fix jax2tf import so it works with both the latest tensorflow release (2.10.0) and tf-nightly
2022-09-30 15:55:22 -07:00
Yash Katariya
fb8558cfdd
Add jax_array coverage to debug_nans_test
...
PiperOrigin-RevId: 478079509
2022-09-30 14:21:32 -07:00
jax authors
ec41de2c9b
Merge pull request #12603 from jbushago:patch-1
...
PiperOrigin-RevId: 478068347
2022-09-30 13:33:03 -07:00
jax authors
ea77c453cb
Merge pull request #12602 from skye:colab
...
PiperOrigin-RevId: 478063554
2022-09-30 13:17:47 -07:00
jax authors
8a1e0ed13f
Merge pull request #12594 from skye:cache_warnings
...
PiperOrigin-RevId: 478063392
2022-09-30 13:11:34 -07:00
Skye Wanderman-Milne
15e5f38a16
Make persistent compilation cache warn instead of raise an error on cache read/write failures
...
Fixes #12582 . Setting the env var `JAX_RAISE_PERSISTENT_CACHE_ERRORS=true` will revert to the original behavior of raising exception instead of warning.
Also makes JAX_DUMP_IR_TO work when the persistent cache is enabled.
2022-09-30 18:38:22 +00:00
jbushago
2038988783
Fix typo in faq.rst.
...
Fixed a small typo in the FAQ: "inthe" -> "in the".
2022-09-30 14:14:05 -04:00
Skye Wanderman-Milne
0cc4066bb7
Pin default jax.tools.colab_tpu.setup_tpu driver version.
...
Prior to this change, we were defaulting to the TPU nightly driver
version. We should instead pin to the version associated with the
default jaxlib version that Colab uses.
2022-09-30 17:45:49 +00:00
Yash Katariya
aafc77d3c0
Improve the checks done in Array
and apply them to all Sharding
s rather than just XLACompatibleSharding
.
...
Also check the symmetric difference of sharding and `_arrays` devices.
PiperOrigin-RevId: 478017409
2022-09-30 09:56:16 -07:00
Jake VanderPlas
4e51d2d4fb
Roll back https://github.com/google/jax/pull/12588 because of test failures
...
PiperOrigin-RevId: 477871341
2022-09-29 18:30:58 -07:00
jax authors
d498bd1eaf
Merge pull request #12588 from jakevdp:random-annotations
...
PiperOrigin-RevId: 477855302
2022-09-29 16:52:42 -07:00
Yash Katariya
9ff570e6c3
Make debug_nans_test.py pass with jax_array=1. Both with enabled and disabled jax_array flag and --pdb_post_mortem, we fall to the same place.
...
PiperOrigin-RevId: 477850567
2022-09-29 16:29:58 -07:00
Yash Katariya
3c7d927a2c
Disable dynamic_api_test and custom_object_test.py with jax.Array. Enable it back when support for it is added. Also don't use xla_shape since its deprecated.
...
PiperOrigin-RevId: 477833061
2022-09-29 15:09:55 -07:00
Jake VanderPlas
aed46f3312
[typing] use jax.Array annotations in random.py
2022-09-29 14:29:02 -07:00
Yash Katariya
eb0fa4028f
Fix process_allgather
to work with jax.Array
.
...
PiperOrigin-RevId: 477793014
2022-09-29 12:32:21 -07:00
jax authors
4f90af91d3
Remove unused jax_unique_mhlo_module_names flag.
...
PiperOrigin-RevId: 477778135
2022-09-29 11:32:22 -07:00
jax authors
a770db04fa
Merge pull request #12579 from jakevdp:gather-unique
...
PiperOrigin-RevId: 477767679
2022-09-29 10:56:14 -07:00
Yash Katariya
7b49a3f51d
Run tests in multiprocess_gpu_test only if the backend is GPU.
...
PiperOrigin-RevId: 477750739
2022-09-29 09:54:32 -07:00
Jake VanderPlas
1bc161a67d
random.permutation: use unique_indices=True for efficiency
2022-09-29 09:34:03 -07:00
Jake VanderPlas
d49c5c37ea
jnp.take: add optional arguments forwarded to lax.gather
2022-09-29 09:33:38 -07:00
Mehdi Amini
137384d856
Update xla_sharding import path to new location
...
We are moving the TensorFlow APIs outside of XLA and will remove the old
path soon.
PiperOrigin-RevId: 477701988
2022-09-29 05:40:56 -07:00
jax authors
de5dd1a4a5
Merge pull request #12444 from LenaMartens:checkify-switch
...
PiperOrigin-RevId: 477688623
2022-09-29 04:18:18 -07:00
lenamartens
0639aced5b
Raise cond index into tracing context in case of effects.
...
So even if the cond is not data dependent at all, it's included in the
dynamic trace, and effects can be discharged.
2022-09-29 11:36:04 +01:00
jax authors
48b89560e5
Merge pull request #12566 from mattjj:djax-slice-sick
...
PiperOrigin-RevId: 477626935
2022-09-28 21:23:38 -07:00
Yash Katariya
163b7e22d2
Convert shardings in jit
path to OpShardingSharding to avoid recompilation when semantically similar shardings are used in jit
.
...
PiperOrigin-RevId: 477626548
2022-09-28 21:17:29 -07:00
Yash Katariya
500f8b7f9c
Add HLOSharding's repr to OpShardingSharding since its more compact.
...
PiperOrigin-RevId: 477587916
2022-09-28 17:00:16 -07:00
Yash Katariya
84768d2d49
Replace jax.xla.DeviceArray
private type with the new public type jax.Array
.
...
PiperOrigin-RevId: 477582562
2022-09-28 16:34:10 -07:00
Matthew Johnson
a8826e672b
[dynamic-shapes] Add basic slicing support
...
If e.g. `x : f32[10, n]` then we want to handle Python expressions like `x[0]`.
To do that, we can use a generalized version of `dynamic_slice` which allows
dynamic slice sizes (where the result shape depends on those slice sizes).
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-09-28 15:55:51 -07:00
jax authors
33dbf0ea1c
Merge pull request #12565 from hawkinsp:release
...
PiperOrigin-RevId: 477549228
2022-09-28 14:14:30 -07:00
Peter Hawkins
b49e31a012
Update version numbers after release.
2022-09-28 18:49:22 +00:00
Yash Katariya
c89cb5d8a4
Use Array
in __repr__
instead of the class name which is ArrayImpl
.
...
PiperOrigin-RevId: 477465432
2022-09-28 08:57:53 -07:00
jax authors
0282b4bfad
Merge pull request #12538 from jakevdp:bundle-pyi
...
PiperOrigin-RevId: 477453094
jax-v0.3.20
jaxlib-v0.3.20
jax-v0.3.20-rc
0.3.20
2022-09-28 08:00:20 -07:00
jax authors
aafc70d293
Merge pull request #12556 from hawkinsp:rocm
...
PiperOrigin-RevId: 477440001
2022-09-28 06:50:19 -07:00
jax authors
5fe7a5440f
Merge pull request #12555 from hawkinsp:release
...
PiperOrigin-RevId: 477439236
2022-09-28 06:50:06 -07:00
jax authors
39eabe878d
Merge pull request #12552 from hawkinsp:nccl
...
PiperOrigin-RevId: 477439228
2022-09-28 06:43:45 -07:00
Peter Hawkins
f7bafb3d4c
Disable multiprocess_gpu_test that fails on ROCm.
2022-09-28 13:40:57 +00:00
Peter Hawkins
8d8643664c
jax/jaxlib 0.3.20 release candidate.
2022-09-28 13:33:52 +00:00
Peter Hawkins
eabb91e53f
Fix test failure in GPU CI if NCCL_DEBUG is enabled.
...
If NCCL_DEBUG is enabled, NCCL prints extra status information. Make
test accept this.
2022-09-28 13:06:04 +00:00
jax authors
96abd9ac75
Merge pull request #12540 from sharadmv:cond-lowering-fix
...
PiperOrigin-RevId: 477358889
2022-09-27 22:33:12 -07:00
Yash Katariya
96a85bd59a
Make addressable_shards a property like local_shards
...
PiperOrigin-RevId: 477358276
2022-09-27 22:27:19 -07:00
jax authors
948906885d
Merge pull request #12546 from mattjj:issue12542
...
PiperOrigin-RevId: 477356925
2022-09-27 22:16:02 -07:00
Sharad Vikram
ddeaa8dbbc
Fix lowering bug in effectful batched cond and add tests
2022-09-27 22:12:13 -07:00
Yash Katariya
b4e1d0af8a
Propagate name
through ExecuteReplicated for dispatch.check_special
...
PiperOrigin-RevId: 477351323
2022-09-27 21:32:32 -07:00
Matthew Johnson
b175e11731
[c++ jit] only set use_fastpath in cache_miss if all args are DeviceArrays
...
fixes #12542
Co-authored-by: Peter Hawkins <phawkins@google.com>
Co-authored-by: Kuangyuan Chen <chky@google.com>
2022-09-27 20:51:07 -07:00
Yash Katariya
933b6a2fa4
Fix the bug where XLA doesn't provide shardings for all the outputs if all the elements in the output tuple have the same sharding. XLA decides to run the FusionTupleDeduplicator
to put the sharding on ROOT instead of the tuple.
...
PiperOrigin-RevId: 477343328
2022-09-27 20:27:39 -07:00
Yash Katariya
c8bff11d1b
Add addressable_
counterparts of local_
to GDA to make it easier for users to move to Array as both will have the same API.
...
PiperOrigin-RevId: 477332697
2022-09-27 19:19:29 -07:00
Yash Katariya
e4f2bff0a3
Disintegrate Array
into DeviceBuffers inside GDA. This is required for backwards compatibility changes as users can create GDAs and pass that to pjit even when Array is switched on.
...
PiperOrigin-RevId: 477297406
2022-09-27 16:02:23 -07:00
jax authors
0919a6776a
Merge pull request #12534 from google:update-pypi
...
PiperOrigin-RevId: 477260550
2022-09-27 13:31:05 -07:00
Jake VanderPlas
6e6fb10ca3
setup: bundle *.pyi files with distribution
2022-09-27 12:55:42 -07:00