13203 Commits

Author SHA1 Message Date
jax authors
0282b4bfad Merge pull request #12538 from jakevdp:bundle-pyi
PiperOrigin-RevId: 477453094
jax-v0.3.20 jaxlib-v0.3.20 jax-v0.3.20-rc 0.3.20
2022-09-28 08:00:20 -07:00
jax authors
aafc70d293 Merge pull request #12556 from hawkinsp:rocm
PiperOrigin-RevId: 477440001
2022-09-28 06:50:19 -07:00
jax authors
5fe7a5440f Merge pull request #12555 from hawkinsp:release
PiperOrigin-RevId: 477439236
2022-09-28 06:50:06 -07:00
jax authors
39eabe878d Merge pull request #12552 from hawkinsp:nccl
PiperOrigin-RevId: 477439228
2022-09-28 06:43:45 -07:00
Peter Hawkins
f7bafb3d4c Disable multiprocess_gpu_test that fails on ROCm. 2022-09-28 13:40:57 +00:00
Peter Hawkins
8d8643664c jax/jaxlib 0.3.20 release candidate. 2022-09-28 13:33:52 +00:00
Peter Hawkins
eabb91e53f Fix test failure in GPU CI if NCCL_DEBUG is enabled.
If NCCL_DEBUG is enabled, NCCL prints extra status information. Make
test accept this.
2022-09-28 13:06:04 +00:00
jax authors
96abd9ac75 Merge pull request #12540 from sharadmv:cond-lowering-fix
PiperOrigin-RevId: 477358889
2022-09-27 22:33:12 -07:00
Yash Katariya
96a85bd59a Make addressable_shards a property like local_shards
PiperOrigin-RevId: 477358276
2022-09-27 22:27:19 -07:00
jax authors
948906885d Merge pull request #12546 from mattjj:issue12542
PiperOrigin-RevId: 477356925
2022-09-27 22:16:02 -07:00
Sharad Vikram
ddeaa8dbbc Fix lowering bug in effectful batched cond and add tests 2022-09-27 22:12:13 -07:00
Yash Katariya
b4e1d0af8a Propagate name through ExecuteReplicated for dispatch.check_special
PiperOrigin-RevId: 477351323
2022-09-27 21:32:32 -07:00
Matthew Johnson
b175e11731 [c++ jit] only set use_fastpath in cache_miss if all args are DeviceArrays
fixes #12542

Co-authored-by: Peter Hawkins <phawkins@google.com>
Co-authored-by: Kuangyuan Chen <chky@google.com>
2022-09-27 20:51:07 -07:00
Yash Katariya
933b6a2fa4 Fix the bug where XLA doesn't provide shardings for all the outputs if all the elements in the output tuple have the same sharding. XLA decides to run the FusionTupleDeduplicator to put the sharding on ROOT instead of the tuple.
PiperOrigin-RevId: 477343328
2022-09-27 20:27:39 -07:00
Yash Katariya
c8bff11d1b Add addressable_ counterparts of local_ to GDA to make it easier for users to move to Array as both will have the same API.
PiperOrigin-RevId: 477332697
2022-09-27 19:19:29 -07:00
Yash Katariya
e4f2bff0a3 Disintegrate Array into DeviceBuffers inside GDA. This is required for backwards compatibility changes as users can create GDAs and pass that to pjit even when Array is switched on.
PiperOrigin-RevId: 477297406
2022-09-27 16:02:23 -07:00
jax authors
0919a6776a Merge pull request #12534 from google:update-pypi
PiperOrigin-RevId: 477260550
2022-09-27 13:31:05 -07:00
Jake VanderPlas
6e6fb10ca3 setup: bundle *.pyi files with distribution 2022-09-27 12:55:42 -07:00
Skye Wanderman-Milne
d028d93983 Update version and changelog for jax 0.3.19 release 2022-09-27 11:00:27 -07:00
Yash Katariya
9e4114f0f1 Move array.py and sharding.py from experimental/ to _src/.
PiperOrigin-RevId: 477201711
jax-v0.3.19
2022-09-27 10:06:52 -07:00
jax authors
0e116888ea Merge pull request #12382 from jakevdp:reduction-dtype
PiperOrigin-RevId: 477179725
2022-09-27 08:38:46 -07:00
jax authors
1bcf8d646d Merge pull request #12497 from mattjj:djax-dag-fix1
PiperOrigin-RevId: 477038279
2022-09-26 18:14:56 -07:00
jax authors
e42247bffb Merge pull request #12524 from sharadmv:lax-import-fix
PiperOrigin-RevId: 477038211
2022-09-26 18:08:45 -07:00
Yash Katariya
389a2e570d Add a backwards compat path for op_sharding.clone() because it doesn't exist with the latest jaxlib on pypi
PiperOrigin-RevId: 477034758
2022-09-26 17:50:19 -07:00
Matthew Johnson
1e7ca8f77a fix bug in djax type signature inference logic
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-09-26 17:48:25 -07:00
Sharad Vikram
1d895b2c85 Fix lax imports 2022-09-26 17:32:44 -07:00
Yash Katariya
cbf34cb609 Rename the concrete class Array to ArrayImpl
PiperOrigin-RevId: 477017236
2022-09-26 16:18:30 -07:00
Tianjian Lu
71bcabe499 [sparse] Add BCSR format template.
PiperOrigin-RevId: 477013899
2022-09-26 16:02:16 -07:00
jax authors
82636b0bcd Merge pull request #12523 from jakevdp:fix-build
PiperOrigin-RevId: 477005157
2022-09-26 15:23:02 -07:00
Jake VanderPlas
6cae54f82d Fix bazel build alias 2022-09-26 15:13:12 -07:00
Peter Hawkins
d63a9442bb Change jax_jit_test to be a jax_test() under Bazel that works across backends.
Make it pass under TPU if x64 types are enabled.

PiperOrigin-RevId: 476994286
2022-09-26 14:38:35 -07:00
Jake VanderPlas
265b39d23f Add pytype_srcs to main jax BUILD rule
PiperOrigin-RevId: 476989241
2022-09-26 14:18:13 -07:00
jax authors
ddd8581f38 Merge pull request #12480 from google:bug-template-gpu-smi
PiperOrigin-RevId: 476979981
2022-09-26 13:41:31 -07:00
Jake VanderPlas
1860f6d839 [x64] add promote_integers argument to jnp.prod & jnp.sum 2022-09-26 13:31:43 -07:00
jax authors
69d1a2c063 Merge pull request #12517 from skye:update-pypi
PiperOrigin-RevId: 476969287
2022-09-26 13:00:31 -07:00
Yash Katariya
b2b60d943e Add make_array_from_single_device_arrays to prepare to rename of the concrete Array to ArrayImpl.
PiperOrigin-RevId: 476965287
2022-09-26 12:43:59 -07:00
Skye Wanderman-Milne
3c0d280bc0 Update version and changelog for jax 0.3.18 release 2022-09-26 12:43:39 -07:00
Roy Frostig
2a7b3197e0 add nvidia-smi question to bug template 2022-09-26 11:06:29 -07:00
jax authors
e034432872 Merge pull request #12513 from inoryy:patch-4
PiperOrigin-RevId: 476923412
jax-v0.3.18
2022-09-26 10:04:14 -07:00
jax authors
7962b01f5d Merge pull request #12485 from LenaMartens:checkify-lower
PiperOrigin-RevId: 476922387
2022-09-26 09:53:40 -07:00
lenamartens
27e3981d52 lowerable errors behind a config flag. 2022-09-26 17:34:27 +01:00
Roman Ring
8bcf358fde
Remove unused _remat_static_argnums import. 2022-09-26 17:14:09 +01:00
lenamartens
78ecc1442c Lowerable checks!! 2022-09-26 16:54:18 +01:00
jax authors
28672cca0e Merge pull request #12496 from mattjj:improve-leak-checker-2
PiperOrigin-RevId: 476907407
2022-09-26 08:50:13 -07:00
jax authors
9c66569514 Merge pull request #12468 from LenaMartens:checkify-but-better
PiperOrigin-RevId: 476901601
2022-09-26 08:23:02 -07:00
jax authors
2df61b1aa1 Merge pull request #12421 from jakevdp:jax-array
PiperOrigin-RevId: 476898184
2022-09-26 08:07:11 -07:00
Jake VanderPlas
0cb233eec9 Add initial jax.Array base class for instance checks & annotation 2022-09-26 07:48:43 -07:00
jax authors
2180710c2a Merge pull request #12511 from hawkinsp:release
PiperOrigin-RevId: 476889960
jax-v0.3.18-rc
2022-09-26 07:24:44 -07:00
Peter Hawkins
bcd36d8eb2 Jax and jaxlib 0.3.18 release candidate. 2022-09-26 14:10:57 +00:00
jax authors
53de057748 Merge pull request #12510 from hawkinsp:context
PiperOrigin-RevId: 476884674
2022-09-26 06:58:46 -07:00