16 Commits

Author SHA1 Message Date
Yash Katariya
fc7a71dc89 Remove device_replica_id_map from the Sharding interface because the standalone function should be more than enough to use. The major use-case of this is for checkpointing and accessing addressable_shards which accesses the standalone function makes it work.
PiperOrigin-RevId: 470820443
2022-08-29 14:49:48 -07:00
Yash Katariya
70a7ee22cc Check if the buffer shape matches the excepted shard shape by Array.
PiperOrigin-RevId: 470732792
2022-08-29 09:00:41 -07:00
Peter Hawkins
78cb9f8492 Avoid more direct references to jax._src without imports.
Change in preparation for not exporting jax._src by default.

PiperOrigin-RevId: 469725340
2022-08-24 07:51:28 -07:00
Yash Katariya
e8ec454ae8 Enable fast path in the Array constructor. This means that the rearranging of _arrays according to the device_assignment won't happen when fastpath is enabled because we assume that jax transformations will return the right arrangement.
PiperOrigin-RevId: 469492283
2022-08-23 10:20:26 -07:00
Peter Hawkins
335b2cfb26 [JAX] Prepare not to export jax._src by default.
Currently
```
import jax
```
populates `jax._src` in the names exported from JAX. This change prepares for not exporting `jax._src` by default.

In particular, explicitly import modules from jax._src and refer to those imports rather than assuming jax._src contents will be around later. This is a common pattern in tests.

This change does not yet remove any exported names.

Issue https://github.com/google/jax/issues/11951

PiperOrigin-RevId: 469480816
2022-08-23 09:36:47 -07:00
Yash Katariya
f42151c3dc Take the pjit XLA compilation path for Arrays. In the test, astype happens in a sharded fashion without the round trip to host.
PiperOrigin-RevId: 468510366
2022-08-18 11:44:39 -07:00
Yash Katariya
acdae7c237 Add weak type support to Array. Also make all api_test.py tests pass with Array. I have disabled the float0 test for now until I investigate.
PiperOrigin-RevId: 468264910
2022-08-17 12:25:49 -07:00
Yash Katariya
9040d5c1f7 Add support for returning arrays from full_like that match the sharding of the input Array.
PiperOrigin-RevId: 468053219
2022-08-16 16:25:21 -07:00
Yash Katariya
647186a4aa Make _device_assignment and _addressable_device_assignment a cached property because having a lru_cache as a decorator on the method with only self leaks memory. See the document below. In this case, its okay to make them a property rather than a method. Its also consistent with device_set.
PiperOrigin-RevId: 465339349
2022-08-04 10:00:01 -07:00
Yash Katariya
2178f339bd Add OpShardingSharding and add a function that can calculate indices from an opsharding proto
PiperOrigin-RevId: 464123009
2022-07-29 11:37:39 -07:00
Yash Katariya
6d8c6f8fac Make astype work for Array that are sharded. The current behavior is the same as SDA i.e. it round trips via host.
PiperOrigin-RevId: 457797458
2022-06-28 12:49:12 -07:00
Yash Katariya
e32373c3ea Make jnp.array return jax.Array. Add input and result handlers for jax.Array. Also added tests for add under jit.
TODO:
* Don't allow `x + y` if `jax.Array` is not fully addressable.
* Figure out how to use the already written tests with Array. Might be able to follow the path taken by SDA.
PiperOrigin-RevId: 457034779
2022-06-24 10:05:06 -07:00
Yash Katariya
1b21d2c3f5 Return Array from jax.device_put if config.jax_array is enabled.
PiperOrigin-RevId: 456531510
2022-06-22 09:20:56 -07:00
Yash Katariya
b2e1d814ac Add __repr__ to Array. It works exactly as it does for DA and SDA when it is fully addressable. Otherwise it works like GDA.
TODO is adding weak_type support in general and to `__repr__`.

PiperOrigin-RevId: 455680796
2022-06-17 13:16:49 -07:00
Yash Katariya
a7160653ce Add __array__ (for device_get), _npy_value, block_until_ready, delete and _check_if_deleted to Array.
PiperOrigin-RevId: 454741685
2022-06-13 18:08:31 -07:00
Yash Katariya
123413751c Adding jax.Array to jax.experimental. Its pretty much the same as GDA (without the performance optimization for now).
Currently, jax.Array takes DeviceArrays in `assemble_array` because device_put returns a DA. In the future (with IFRT), it will return an `Array`.

`addressable_shards` wraps DA into jax.Array with a `SingleDeviceSharding`.

PiperOrigin-RevId: 453319811
2022-06-06 17:32:00 -07:00