Parker Schuh
568a93bcd1
Convert _arrays to return PyArray instead of PyBuffer.
...
PiperOrigin-RevId: 508769390
2023-02-10 15:32:57 -08:00
Peter Hawkins
8268cd562d
Add infrastructure for managing deprecations.
...
Use it to deprecate jax.experimental.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.interpreters.pxla.Mesh.
PiperOrigin-RevId: 508349776
2023-02-09 05:48:40 -08:00
Peter Hawkins
6860cb8d2a
Move jax.interpreters.xla to jax._src.interpreters.xla.
...
Replace jax.interpreters.xla with a shim that re-exports names that are likely to be used externally.
PiperOrigin-RevId: 507895040
2023-02-07 15:01:32 -08:00
Peter Hawkins
98b75cf27b
Prune accidental exports from jax.interpreters.pxla.
...
These imports do not appear to have users outside JAX itself.
PiperOrigin-RevId: 507835295
2023-02-07 11:16:42 -08:00
Peter Hawkins
428189f8fb
Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
...
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.
PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
lenamartens
0fe159b67e
Make Shard.device and Shard.data read-only properties.
2023-01-05 14:27:17 +00:00
Yash Katariya
1fc9197c79
Simplify Array's shard_arg_handler by merging pmap and pjit/xmap paths
...
PiperOrigin-RevId: 497991966
2022-12-27 10:16:44 -08:00
Yash Katariya
dbc39449b7
Remove more checks now that the minimum jaxlib version corresponds to xla_extension_version == 109. Also remove usage of xc._version
and replace it with xla_extension_version
.
...
PiperOrigin-RevId: 496474494
2022-12-19 13:15:07 -08:00
Roy Frostig
d927a5dbf3
migrate internal dependencies from jax.core
to jax._src.core
...
... in preparation for paring down `jax.core`'s exported symbols.
Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.
PiperOrigin-RevId: 496024782
2022-12-16 21:00:14 -08:00
Peter Hawkins
5e102c17d6
Implement .on_device_size_in_bytes() on jax.Array.
...
This is an array present in DeviceArray that is missing from Array.
PiperOrigin-RevId: 492571171
2022-12-02 15:11:27 -08:00
Yash Katariya
4443b861a5
Remove local imports of array.py. The remaining local imports are in pxla.py but I will chip away at them when we delete SDA and move some more APIs out of experimental.
...
PiperOrigin-RevId: 492033543
2022-11-30 15:26:03 -08:00
Yash Katariya
c4d91d203c
Remove local_imports of sharding.py. Adding pxla
local imports but then cleaning those up will be super easy since those will be the only ones left and restricted to sharding.py
file only.
...
Also remove `maybe_cached_property` from this CL since we are dropping 3.7 support
PiperOrigin-RevId: 491769101
2022-11-29 16:42:03 -08:00
Yash Katariya
6897d37562
Add docstrings for jax.Array APIs make_array_from_callback
and make_array_from_single_device_arrays
.
...
PiperOrigin-RevId: 487929688
2022-11-11 15:21:10 -08:00
Yash Katariya
32a0ea80ef
Add global_shards
to jax.Array
as it exists on GDA and is being used in various places.
...
PiperOrigin-RevId: 485065876
2022-10-31 09:08:03 -07:00
Yash Katariya
1a0affddd8
Move is_deleted()
to C++ so that we can check if an Array is deleted without materializing _arrays
.
...
Also raise a better error message when doing operations of a deleted Array rather than the current thing which says: `NoneType has no len()`. Now it says: `Array has been deleted`.
PiperOrigin-RevId: 482497114
2022-10-20 08:29:25 -07:00
Yash Katariya
335f45ebb2
Use _rewriting_take
and _chunk_iter
path during __getitem__
and __iter__
respectively when the Array is fully replicated
...
For example:
```
k1, k2 = jax.random.split(key, 2) # where key is fully replicated on 8 devices
```
Then `k1` and `k2` should also maintain the sharding of `key` since `key` is fully replicated.
PiperOrigin-RevId: 480434272
2022-10-11 13:09:33 -07:00
Yash Katariya
9b3e864731
Add weak_type attribute to Array
since it exists on DA (but doesn't exist on SDA).
...
PiperOrigin-RevId: 480223116
2022-10-10 18:11:11 -07:00
Yash Katariya
76d8c08317
Fix the type annotation of return type of device_buffer
and device_buffers
which return ArrayImpl
instead of DeviceArray.
...
PiperOrigin-RevId: 480181798
2022-10-10 14:45:12 -07:00
Yash Katariya
75b2d05989
Make is_fully_replicated
and is_fully_addressble
a property rather than a method.
...
Why?
1. Because it's easy to cache a property than a method with only the `self` argument. (See below for article)
2. There's no harm in making them a property because both of them return a bool without any side-effects and are cached (so its fast). Why cache `is_fully_addressable`? Because its very expensive to calculate when you have 1000s of devices.
PiperOrigin-RevId: 479850850
2022-10-08 19:24:12 -07:00
jax authors
674038ca47
Merge pull request #12705 from mattjj:fix-prng-key-array-device-put
...
PiperOrigin-RevId: 479813689
2022-10-08 11:39:05 -07:00
Matthew Johnson
0a0f492a3d
make device_put(prngkeyarray, sharding) for Array
...
Co-authored-by: Yash Katariya <yashkatariya@google.com>
Co-authored-by: Roy Frostig <frostig@google.com>
2022-10-07 16:50:16 -07:00
Yash Katariya
37f9db77f7
Create Array
s from __getitem__
and __iter__
. This is done by device_put
ting from the host to default device which is suboptimal. But there is a TODO to fix this!
...
PiperOrigin-RevId: 478691051
2022-10-03 22:29:03 -07:00
Yash Katariya
aafc77d3c0
Improve the checks done in Array
and apply them to all Sharding
s rather than just XLACompatibleSharding
.
...
Also check the symmetric difference of sharding and `_arrays` devices.
PiperOrigin-RevId: 478017409
2022-09-30 09:56:16 -07:00
Yash Katariya
c89cb5d8a4
Use Array
in __repr__
instead of the class name which is ArrayImpl
.
...
PiperOrigin-RevId: 477465432
2022-09-28 08:57:53 -07:00
Yash Katariya
9e4114f0f1
Move array.py
and sharding.py
from experimental/
to _src/
.
...
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00