47 Commits

Author SHA1 Message Date
Roy Frostig
c8b9280fb3 partitionable threefry PRNG random bits implementation
the cost is 2x overgeneration of bits

Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-10-28 10:07:14 -07:00
Matthew Johnson
95eb4249bb tweaks to DevicesSharding
1. rename DevicesSharding -> ReshapeableDevicesSharding
2. fix repr to print device order faithfully
3. respect shape of np.ndarray argument to __init__
2022-10-25 14:28:48 -07:00
Yash Katariya
adcb0f58e8 Add __repr__ and __str__ to PmapSharding.
Fixes https://github.com/google/jax/issues/12971

PiperOrigin-RevId: 483707874
2022-10-25 10:13:02 -07:00
Yash Katariya
1a0affddd8 Move is_deleted() to C++ so that we can check if an Array is deleted without materializing _arrays.
Also raise a better error message when doing operations of a deleted Array rather than the current thing which says: `NoneType has no len()`. Now it says: `Array has been deleted`.

PiperOrigin-RevId: 482497114
2022-10-20 08:29:25 -07:00
Matthew Johnson
43098f906a initial commit of DevicesSharding (fka SimpleSharding)
need to add tests!

Co-authored-by: Yash Katariya <yashkatariya@google.com>
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-10-18 21:10:24 -07:00
Yash Katariya
335f45ebb2 Use _rewriting_take and _chunk_iter path during __getitem__ and __iter__ respectively when the Array is fully replicated
For example:

```
k1, k2 = jax.random.split(key, 2) # where key is fully replicated on 8 devices
```

Then `k1` and `k2` should also maintain the sharding of `key` since `key` is fully replicated.

PiperOrigin-RevId: 480434272
2022-10-11 13:09:33 -07:00
Yash Katariya
9b3e864731 Add weak_type attribute to Array since it exists on DA (but doesn't exist on SDA).
PiperOrigin-RevId: 480223116
2022-10-10 18:11:11 -07:00
Yash Katariya
75b2d05989 Make is_fully_replicated and is_fully_addressble a property rather than a method.
Why?

1. Because it's easy to cache a property than a method with only the `self` argument. (See below for article)

2. There's no harm in making them a property because both of them return a bool without any side-effects and are cached (so its fast). Why cache `is_fully_addressable`? Because its very expensive to calculate when you have 1000s of devices.

PiperOrigin-RevId: 479850850
2022-10-08 19:24:12 -07:00
jax authors
e8ba61d82b Merge pull request #12677 from mattjj:jit-pjit-lower-sharding
PiperOrigin-RevId: 479669125
2022-10-07 14:28:51 -07:00
Matthew Johnson
bcca6fb57a add test, small fixes
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-10-06 16:45:34 -07:00
Yash Katariya
37f9db77f7 Create Arrays from __getitem__ and __iter__. This is done by device_putting from the host to default device which is suboptimal. But there is a TODO to fix this!
PiperOrigin-RevId: 478691051
2022-10-03 22:29:03 -07:00
Yash Katariya
aafc77d3c0 Improve the checks done in Array and apply them to all Shardings rather than just XLACompatibleSharding.
Also check the symmetric difference of sharding and `_arrays` devices.

PiperOrigin-RevId: 478017409
2022-09-30 09:56:16 -07:00
Yash Katariya
500f8b7f9c Add HLOSharding's repr to OpShardingSharding since its more compact.
PiperOrigin-RevId: 477587916
2022-09-28 17:00:16 -07:00
Yash Katariya
c89cb5d8a4 Use Array in __repr__ instead of the class name which is ArrayImpl.
PiperOrigin-RevId: 477465432
2022-09-28 08:57:53 -07:00
Yash Katariya
9e4114f0f1 Move array.py and sharding.py from experimental/ to _src/.
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00
Yash Katariya
cbf34cb609 Rename the concrete class Array to ArrayImpl
PiperOrigin-RevId: 477017236
2022-09-26 16:18:30 -07:00
Yash Katariya
1fa0dda760 Return single device Arrays from .device_buffer and .device_buffers.
PiperOrigin-RevId: 476449591
2022-09-23 13:30:26 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Yash Katariya
52476d1ab5 Add addressable_data to Array (similar to GDA) to aid in transition and also in auto spmd partitioner mode, always convert to MeshPspecSharding.
PiperOrigin-RevId: 475972534
2022-09-21 18:19:35 -07:00
Jake VanderPlas
bdb264e60f jax.Array: add issubclass test analogous to existing DeviceArray test. 2022-09-14 16:29:38 -07:00
Jake VanderPlas
8eb44fd195 jax_array_test: set config once & fix X64 failure 2022-09-13 14:06:38 -07:00
Yash Katariya
120b2801fd Bounce to host for any sharding that's not PmapSharding or a sharding with a single device for __iter__ and __getitem__.
PiperOrigin-RevId: 473402857
2022-09-09 20:41:40 -07:00
Yash Katariya
d7726e7b26 Make __getitem__ work for PmapSharding just like SDA works. DA is already covered with the current implementation.
Added TODOs to take fast path for indices wherever it is possible to do that. If a correct index is passed during getitem and if that index exists on `Array`, then the fast path is taken (see the test in this CL).

PiperOrigin-RevId: 473342504
2022-09-09 14:25:22 -07:00
Yash Katariya
4746a39e59 Show the correct sharding in is_compatible_aval error in MeshPspecSharding when created via _from_parsed_pspec. Preserve the original PartitionSpec from ParsedPartitionSpec if it exists, else calculate it.
PiperOrigin-RevId: 473267905
2022-09-09 09:13:42 -07:00
Yash Katariya
751be205d8 Add a test for jnp.array(..., copy=True) for Array.
PiperOrigin-RevId: 473041718
2022-09-08 11:29:51 -07:00
Yash Katariya
7fbf8ec669 Fix Forward. The fix is on the user's end. Original PR: https://github.com/google/jax/pull/12217
Co-authored-by: Matthew Johnson <mattjj@google.com>
Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 472999907
2022-09-08 08:49:40 -07:00
jax authors
14f1a345a1 roll back breakage
PiperOrigin-RevId: 472949225
2022-09-08 03:59:54 -07:00
Yash Katariya
b7e4e44cbf DCE jaxpr and trivial_jaxpr support for lower_sharding_computation
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 471274989
2022-09-06 14:09:10 -07:00
Yash Katariya
3e54ac0af0 Make __iter__ of Array behave like DA when there is a SingleDeviceSharding and like SDA when there is a non-trivial sharding.
This is important because when `Array` contains more than 1 shard, each shard can be on a different device and those things need to be preserved when iterating over `Array`.

PiperOrigin-RevId: 471695841
2022-09-01 19:54:34 -07:00
Yash Katariya
2f7951b3dc Add __hash__ and __eq__ to PmapSharding
PiperOrigin-RevId: 471356052
2022-08-31 14:27:16 -07:00
Yash Katariya
da24b99d30 Some minor changes to make_array_from_callback to use the device_indices_map method and calculate the indices just once. Also set the _committed attribute of shards to what the parent Array has.
PiperOrigin-RevId: 471167295
2022-08-30 21:57:21 -07:00
Yash Katariya
fc7a71dc89 Remove device_replica_id_map from the Sharding interface because the standalone function should be more than enough to use. The major use-case of this is for checkpointing and accessing addressable_shards which accesses the standalone function makes it work.
PiperOrigin-RevId: 470820443
2022-08-29 14:49:48 -07:00
Yash Katariya
70a7ee22cc Check if the buffer shape matches the excepted shard shape by Array.
PiperOrigin-RevId: 470732792
2022-08-29 09:00:41 -07:00
Peter Hawkins
78cb9f8492 Avoid more direct references to jax._src without imports.
Change in preparation for not exporting jax._src by default.

PiperOrigin-RevId: 469725340
2022-08-24 07:51:28 -07:00
Yash Katariya
e8ec454ae8 Enable fast path in the Array constructor. This means that the rearranging of _arrays according to the device_assignment won't happen when fastpath is enabled because we assume that jax transformations will return the right arrangement.
PiperOrigin-RevId: 469492283
2022-08-23 10:20:26 -07:00
Peter Hawkins
335b2cfb26 [JAX] Prepare not to export jax._src by default.
Currently
```
import jax
```
populates `jax._src` in the names exported from JAX. This change prepares for not exporting `jax._src` by default.

In particular, explicitly import modules from jax._src and refer to those imports rather than assuming jax._src contents will be around later. This is a common pattern in tests.

This change does not yet remove any exported names.

Issue https://github.com/google/jax/issues/11951

PiperOrigin-RevId: 469480816
2022-08-23 09:36:47 -07:00
Yash Katariya
f42151c3dc Take the pjit XLA compilation path for Arrays. In the test, astype happens in a sharded fashion without the round trip to host.
PiperOrigin-RevId: 468510366
2022-08-18 11:44:39 -07:00
Yash Katariya
acdae7c237 Add weak type support to Array. Also make all api_test.py tests pass with Array. I have disabled the float0 test for now until I investigate.
PiperOrigin-RevId: 468264910
2022-08-17 12:25:49 -07:00
Yash Katariya
9040d5c1f7 Add support for returning arrays from full_like that match the sharding of the input Array.
PiperOrigin-RevId: 468053219
2022-08-16 16:25:21 -07:00
Yash Katariya
647186a4aa Make _device_assignment and _addressable_device_assignment a cached property because having a lru_cache as a decorator on the method with only self leaks memory. See the document below. In this case, its okay to make them a property rather than a method. Its also consistent with device_set.
PiperOrigin-RevId: 465339349
2022-08-04 10:00:01 -07:00
Yash Katariya
2178f339bd Add OpShardingSharding and add a function that can calculate indices from an opsharding proto
PiperOrigin-RevId: 464123009
2022-07-29 11:37:39 -07:00
Yash Katariya
6d8c6f8fac Make astype work for Array that are sharded. The current behavior is the same as SDA i.e. it round trips via host.
PiperOrigin-RevId: 457797458
2022-06-28 12:49:12 -07:00
Yash Katariya
e32373c3ea Make jnp.array return jax.Array. Add input and result handlers for jax.Array. Also added tests for add under jit.
TODO:
* Don't allow `x + y` if `jax.Array` is not fully addressable.
* Figure out how to use the already written tests with Array. Might be able to follow the path taken by SDA.
PiperOrigin-RevId: 457034779
2022-06-24 10:05:06 -07:00
Yash Katariya
1b21d2c3f5 Return Array from jax.device_put if config.jax_array is enabled.
PiperOrigin-RevId: 456531510
2022-06-22 09:20:56 -07:00
Yash Katariya
b2e1d814ac Add __repr__ to Array. It works exactly as it does for DA and SDA when it is fully addressable. Otherwise it works like GDA.
TODO is adding weak_type support in general and to `__repr__`.

PiperOrigin-RevId: 455680796
2022-06-17 13:16:49 -07:00
Yash Katariya
a7160653ce Add __array__ (for device_get), _npy_value, block_until_ready, delete and _check_if_deleted to Array.
PiperOrigin-RevId: 454741685
2022-06-13 18:08:31 -07:00
Yash Katariya
123413751c Adding jax.Array to jax.experimental. Its pretty much the same as GDA (without the performance optimization for now).
Currently, jax.Array takes DeviceArrays in `assemble_array` because device_put returns a DA. In the future (with IFRT), it will return an `Array`.

`addressable_shards` wraps DA into jax.Array with a `SingleDeviceSharding`.

PiperOrigin-RevId: 453319811
2022-06-06 17:32:00 -07:00