jax authors
96abd9ac75
Merge pull request #12540 from sharadmv:cond-lowering-fix
...
PiperOrigin-RevId: 477358889
2022-09-27 22:33:12 -07:00
Yash Katariya
96a85bd59a
Make addressable_shards a property like local_shards
...
PiperOrigin-RevId: 477358276
2022-09-27 22:27:19 -07:00
jax authors
948906885d
Merge pull request #12546 from mattjj:issue12542
...
PiperOrigin-RevId: 477356925
2022-09-27 22:16:02 -07:00
Sharad Vikram
ddeaa8dbbc
Fix lowering bug in effectful batched cond and add tests
2022-09-27 22:12:13 -07:00
Yash Katariya
b4e1d0af8a
Propagate name
through ExecuteReplicated for dispatch.check_special
...
PiperOrigin-RevId: 477351323
2022-09-27 21:32:32 -07:00
Matthew Johnson
b175e11731
[c++ jit] only set use_fastpath in cache_miss if all args are DeviceArrays
...
fixes #12542
Co-authored-by: Peter Hawkins <phawkins@google.com>
Co-authored-by: Kuangyuan Chen <chky@google.com>
2022-09-27 20:51:07 -07:00
Yash Katariya
933b6a2fa4
Fix the bug where XLA doesn't provide shardings for all the outputs if all the elements in the output tuple have the same sharding. XLA decides to run the FusionTupleDeduplicator
to put the sharding on ROOT instead of the tuple.
...
PiperOrigin-RevId: 477343328
2022-09-27 20:27:39 -07:00
Yash Katariya
c8bff11d1b
Add addressable_
counterparts of local_
to GDA to make it easier for users to move to Array as both will have the same API.
...
PiperOrigin-RevId: 477332697
2022-09-27 19:19:29 -07:00
Yash Katariya
e4f2bff0a3
Disintegrate Array
into DeviceBuffers inside GDA. This is required for backwards compatibility changes as users can create GDAs and pass that to pjit even when Array is switched on.
...
PiperOrigin-RevId: 477297406
2022-09-27 16:02:23 -07:00
Skye Wanderman-Milne
d028d93983
Update version and changelog for jax 0.3.19 release
2022-09-27 11:00:27 -07:00
Yash Katariya
9e4114f0f1
Move array.py
and sharding.py
from experimental/
to _src/
.
...
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00
jax authors
0e116888ea
Merge pull request #12382 from jakevdp:reduction-dtype
...
PiperOrigin-RevId: 477179725
2022-09-27 08:38:46 -07:00
jax authors
1bcf8d646d
Merge pull request #12497 from mattjj:djax-dag-fix1
...
PiperOrigin-RevId: 477038279
2022-09-26 18:14:56 -07:00
jax authors
e42247bffb
Merge pull request #12524 from sharadmv:lax-import-fix
...
PiperOrigin-RevId: 477038211
2022-09-26 18:08:45 -07:00
Yash Katariya
389a2e570d
Add a backwards compat path for op_sharding.clone()
because it doesn't exist with the latest jaxlib on pypi
...
PiperOrigin-RevId: 477034758
2022-09-26 17:50:19 -07:00
Matthew Johnson
1e7ca8f77a
fix bug in djax type signature inference logic
...
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-09-26 17:48:25 -07:00
Sharad Vikram
1d895b2c85
Fix lax imports
2022-09-26 17:32:44 -07:00
Yash Katariya
cbf34cb609
Rename the concrete class Array
to ArrayImpl
...
PiperOrigin-RevId: 477017236
2022-09-26 16:18:30 -07:00
Tianjian Lu
71bcabe499
[sparse] Add BCSR format template.
...
PiperOrigin-RevId: 477013899
2022-09-26 16:02:16 -07:00
Jake VanderPlas
265b39d23f
Add pytype_srcs to main jax BUILD rule
...
PiperOrigin-RevId: 476989241
2022-09-26 14:18:13 -07:00
Jake VanderPlas
1860f6d839
[x64] add promote_integers argument to jnp.prod & jnp.sum
2022-09-26 13:31:43 -07:00
jax authors
69d1a2c063
Merge pull request #12517 from skye:update-pypi
...
PiperOrigin-RevId: 476969287
2022-09-26 13:00:31 -07:00
Yash Katariya
b2b60d943e
Add make_array_from_single_device_arrays
to prepare to rename of the concrete Array
to ArrayImpl
.
...
PiperOrigin-RevId: 476965287
2022-09-26 12:43:59 -07:00
Skye Wanderman-Milne
3c0d280bc0
Update version and changelog for jax 0.3.18 release
2022-09-26 12:43:39 -07:00
jax authors
e034432872
Merge pull request #12513 from inoryy:patch-4
...
PiperOrigin-RevId: 476923412
2022-09-26 10:04:14 -07:00
lenamartens
27e3981d52
lowerable errors behind a config flag.
2022-09-26 17:34:27 +01:00
Roman Ring
8bcf358fde
Remove unused _remat_static_argnums import.
2022-09-26 17:14:09 +01:00
lenamartens
78ecc1442c
Lowerable checks!!
2022-09-26 16:54:18 +01:00
jax authors
9c66569514
Merge pull request #12468 from LenaMartens:checkify-but-better
...
PiperOrigin-RevId: 476901601
2022-09-26 08:23:02 -07:00
Jake VanderPlas
0cb233eec9
Add initial jax.Array base class for instance checks & annotation
2022-09-26 07:48:43 -07:00
jax authors
ec15e83018
- Wraps calls to lax.xeinsum and _einsum in a named call with their 'spec', the string specifying the computation. Makes xprof traces more interpretable.
...
PiperOrigin-RevId: 476796185
2022-09-25 20:54:17 -07:00
Yash Katariya
7c85ca38f4
Only look at hlo_modules for output sharding if there is more than 1 device because if there is only 1 device, the spmd partitioner won't run.
...
PiperOrigin-RevId: 476497929
2022-09-23 17:31:33 -07:00
Yash Katariya
1fa0dda760
Return single device Arrays from .device_buffer
and .device_buffers
.
...
PiperOrigin-RevId: 476449591
2022-09-23 13:30:26 -07:00
jax authors
43bbce0cc6
Merge pull request #12486 from hawkinsp:debugging
...
PiperOrigin-RevId: 476445041
2022-09-23 13:09:26 -07:00
jax authors
737327a42d
Merge pull request #12490 from mattjj:improve-leak-checker
...
PiperOrigin-RevId: 476442352
2022-09-23 12:58:03 -07:00
Matthew Johnson
b6ef90ffdd
fix leak checker internal error
...
The issue was that partial_eval.py's _memoize, used in custom_jvp, was made
into an identity function by enabling config.jax_check_tracer_leaks (from
references to the main trace (needed for the jvp_jaxpr thunk) and hence trigger
the leak checker (which would see if any references to the main trace persisted
after finishing tracing of the user function).
But after #7345 , the leak checker should only trigger when actual Tracers are
leaked. So disabling the memoization when jax_check_tracer_leaks is no longer
active shouldn't be necessary. (These PR numbers seem out of order! We're not
sure why.)
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-09-23 12:33:45 -07:00
Yash Katariya
ecb27a9b24
Update the _check_special
code to not use xla_shape since its deprecated and does not work with Array.
...
PiperOrigin-RevId: 476422732
2022-09-23 11:40:32 -07:00
jax authors
d078f3f5fc
Merge pull request #12478 from sharadmv:sharding-docs
...
PiperOrigin-RevId: 476420315
2022-09-23 11:31:37 -07:00
Ke Wu
c823151771
Allow transpose axes to be negative to match (undocumented) NumPy behavior
2022-09-23 10:18:23 -07:00
Peter Hawkins
38fb8ed22f
Fix copyright attribution for some newly added files.
...
PiperOrigin-RevId: 476390902
2022-09-23 09:32:47 -07:00
Yash Katariya
c8f55414fc
Convert the devices in the Mesh
constructor to a numpy array if its a list, tuple, etc.
...
PiperOrigin-RevId: 476380496
2022-09-23 08:48:31 -07:00
Peter Hawkins
a88c5ad789
Fix xla extension version test in debugging.py
...
The custom call partitioner callback was not present in version 94 but is present in version 95.
2022-09-23 10:53:50 -04:00
jax authors
254dc24a8b
Merge pull request #11961 from jakeh-gc:plugin_device
...
PiperOrigin-RevId: 476363760
2022-09-23 07:29:17 -07:00
lenamartens
7078f81dd0
Checkify: misc improvements.
...
- err.throw == check_error(err) -> meaning they have the same behavior
under checkify now
- "divided by zero" -> "division by zero"
- add validation that check_error only takes args of type Error
2022-09-23 14:33:06 +01:00
Peter Hawkins
eed327914e
Improve documentation for unique_indices.
2022-09-23 09:11:15 -04:00
Tianjian Lu
67b7ae259f
[sparse] Move _bcoo_nse
to sparse util.
...
PiperOrigin-RevId: 476263483
2022-09-22 20:22:06 -07:00
Sharad Vikram
99d4d8b89a
Update debugging docs to have sharding visualization
2022-09-22 19:42:36 -07:00
Sharad Vikram
805073f36a
Add inspect_array_sharding, enabling looking at shardings in pjit-ted functions
...
PiperOrigin-RevId: 476237731
2022-09-22 17:36:56 -07:00
Jake VanderPlas
3d23592cf6
[array] full_like: only match sharding if shape==None
2022-09-22 15:28:59 -07:00
jax authors
dfdf00c2eb
Merge pull request #12472 from google:sharadmv-patch-2
...
PiperOrigin-RevId: 476190259
2022-09-22 14:00:07 -07:00