175 Commits

Author SHA1 Message Date
Parker Schuh
568a93bcd1 Convert _arrays to return PyArray instead of PyBuffer.
PiperOrigin-RevId: 508769390
2023-02-10 15:32:57 -08:00
Peter Hawkins
8268cd562d Add infrastructure for managing deprecations.
Use it to deprecate jax.experimental.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.interpreters.pxla.Mesh.

PiperOrigin-RevId: 508349776
2023-02-09 05:48:40 -08:00
Peter Hawkins
6860cb8d2a Move jax.interpreters.xla to jax._src.interpreters.xla.
Replace jax.interpreters.xla with a shim that re-exports names that are likely to be used externally.

PiperOrigin-RevId: 507895040
2023-02-07 15:01:32 -08:00
Peter Hawkins
98b75cf27b Prune accidental exports from jax.interpreters.pxla.
These imports do not appear to have users outside JAX itself.

PiperOrigin-RevId: 507835295
2023-02-07 11:16:42 -08:00
Peter Hawkins
428189f8fb Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.

PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
lenamartens
0fe159b67e Make Shard.device and Shard.data read-only properties. 2023-01-05 14:27:17 +00:00
Yash Katariya
1fc9197c79 Simplify Array's shard_arg_handler by merging pmap and pjit/xmap paths
PiperOrigin-RevId: 497991966
2022-12-27 10:16:44 -08:00
Yash Katariya
dbc39449b7 Remove more checks now that the minimum jaxlib version corresponds to xla_extension_version == 109. Also remove usage of xc._version and replace it with xla_extension_version.
PiperOrigin-RevId: 496474494
2022-12-19 13:15:07 -08:00
Roy Frostig
d927a5dbf3 migrate internal dependencies from jax.core to jax._src.core
... in preparation for paring down `jax.core`'s exported symbols.

Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.

PiperOrigin-RevId: 496024782
2022-12-16 21:00:14 -08:00
Peter Hawkins
5e102c17d6 Implement .on_device_size_in_bytes() on jax.Array.
This is an array present in DeviceArray that is missing from Array.

PiperOrigin-RevId: 492571171
2022-12-02 15:11:27 -08:00
Yash Katariya
4443b861a5 Remove local imports of array.py. The remaining local imports are in pxla.py but I will chip away at them when we delete SDA and move some more APIs out of experimental.
PiperOrigin-RevId: 492033543
2022-11-30 15:26:03 -08:00
Yash Katariya
c4d91d203c Remove local_imports of sharding.py. Adding pxla local imports but then cleaning those up will be super easy since those will be the only ones left and restricted to sharding.py file only.
Also remove `maybe_cached_property` from this CL since we are dropping 3.7 support

PiperOrigin-RevId: 491769101
2022-11-29 16:42:03 -08:00
Yash Katariya
6897d37562 Add docstrings for jax.Array APIs make_array_from_callback and make_array_from_single_device_arrays.
PiperOrigin-RevId: 487929688
2022-11-11 15:21:10 -08:00
Yash Katariya
32a0ea80ef Add global_shards to jax.Array as it exists on GDA and is being used in various places.
PiperOrigin-RevId: 485065876
2022-10-31 09:08:03 -07:00
Yash Katariya
1a0affddd8 Move is_deleted() to C++ so that we can check if an Array is deleted without materializing _arrays.
Also raise a better error message when doing operations of a deleted Array rather than the current thing which says: `NoneType has no len()`. Now it says: `Array has been deleted`.

PiperOrigin-RevId: 482497114
2022-10-20 08:29:25 -07:00
Yash Katariya
335f45ebb2 Use _rewriting_take and _chunk_iter path during __getitem__ and __iter__ respectively when the Array is fully replicated
For example:

```
k1, k2 = jax.random.split(key, 2) # where key is fully replicated on 8 devices
```

Then `k1` and `k2` should also maintain the sharding of `key` since `key` is fully replicated.

PiperOrigin-RevId: 480434272
2022-10-11 13:09:33 -07:00
Yash Katariya
9b3e864731 Add weak_type attribute to Array since it exists on DA (but doesn't exist on SDA).
PiperOrigin-RevId: 480223116
2022-10-10 18:11:11 -07:00
Yash Katariya
76d8c08317 Fix the type annotation of return type of device_buffer and device_buffers which return ArrayImpl instead of DeviceArray.
PiperOrigin-RevId: 480181798
2022-10-10 14:45:12 -07:00
Yash Katariya
75b2d05989 Make is_fully_replicated and is_fully_addressble a property rather than a method.
Why?

1. Because it's easy to cache a property than a method with only the `self` argument. (See below for article)

2. There's no harm in making them a property because both of them return a bool without any side-effects and are cached (so its fast). Why cache `is_fully_addressable`? Because its very expensive to calculate when you have 1000s of devices.

PiperOrigin-RevId: 479850850
2022-10-08 19:24:12 -07:00
jax authors
674038ca47 Merge pull request #12705 from mattjj:fix-prng-key-array-device-put
PiperOrigin-RevId: 479813689
2022-10-08 11:39:05 -07:00
Matthew Johnson
0a0f492a3d make device_put(prngkeyarray, sharding) for Array
Co-authored-by: Yash Katariya <yashkatariya@google.com>
Co-authored-by: Roy Frostig <frostig@google.com>
2022-10-07 16:50:16 -07:00
Yash Katariya
37f9db77f7 Create Arrays from __getitem__ and __iter__. This is done by device_putting from the host to default device which is suboptimal. But there is a TODO to fix this!
PiperOrigin-RevId: 478691051
2022-10-03 22:29:03 -07:00
Yash Katariya
aafc77d3c0 Improve the checks done in Array and apply them to all Shardings rather than just XLACompatibleSharding.
Also check the symmetric difference of sharding and `_arrays` devices.

PiperOrigin-RevId: 478017409
2022-09-30 09:56:16 -07:00
Yash Katariya
c89cb5d8a4 Use Array in __repr__ instead of the class name which is ArrayImpl.
PiperOrigin-RevId: 477465432
2022-09-28 08:57:53 -07:00
Yash Katariya
9e4114f0f1 Move array.py and sharding.py from experimental/ to _src/.
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00