Yash Katariya
fc7a71dc89
Remove device_replica_id_map
from the Sharding interface because the standalone function should be more than enough to use. The major use-case of this is for checkpointing and accessing addressable_shards which accesses the standalone function makes it work.
...
PiperOrigin-RevId: 470820443
2022-08-29 14:49:48 -07:00
Yash Katariya
70a7ee22cc
Check if the buffer shape matches the excepted shard shape by Array.
...
PiperOrigin-RevId: 470732792
2022-08-29 09:00:41 -07:00
Peter Hawkins
78cb9f8492
Avoid more direct references to jax._src without imports.
...
Change in preparation for not exporting jax._src by default.
PiperOrigin-RevId: 469725340
2022-08-24 07:51:28 -07:00
Yash Katariya
e8ec454ae8
Enable fast path in the Array constructor. This means that the rearranging of _arrays
according to the device_assignment won't happen when fastpath is enabled because we assume that jax transformations will return the right arrangement.
...
PiperOrigin-RevId: 469492283
2022-08-23 10:20:26 -07:00
Peter Hawkins
335b2cfb26
[JAX] Prepare not to export jax._src by default.
...
Currently
```
import jax
```
populates `jax._src` in the names exported from JAX. This change prepares for not exporting `jax._src` by default.
In particular, explicitly import modules from jax._src and refer to those imports rather than assuming jax._src contents will be around later. This is a common pattern in tests.
This change does not yet remove any exported names.
Issue https://github.com/google/jax/issues/11951
PiperOrigin-RevId: 469480816
2022-08-23 09:36:47 -07:00
Yash Katariya
f42151c3dc
Take the pjit XLA compilation path for Arrays
. In the test, astype
happens in a sharded fashion without the round trip to host.
...
PiperOrigin-RevId: 468510366
2022-08-18 11:44:39 -07:00
Yash Katariya
acdae7c237
Add weak type support to Array. Also make all api_test.py tests pass with Array. I have disabled the float0
test for now until I investigate.
...
PiperOrigin-RevId: 468264910
2022-08-17 12:25:49 -07:00
Yash Katariya
9040d5c1f7
Add support for returning arrays from full_like
that match the sharding of the input Array.
...
PiperOrigin-RevId: 468053219
2022-08-16 16:25:21 -07:00
Yash Katariya
647186a4aa
Make _device_assignment and _addressable_device_assignment a cached property because having a lru_cache as a decorator on the method with only self
leaks memory. See the document below. In this case, its okay to make them a property rather than a method. Its also consistent with device_set.
...
PiperOrigin-RevId: 465339349
2022-08-04 10:00:01 -07:00
Yash Katariya
2178f339bd
Add OpShardingSharding
and add a function that can calculate indices from an opsharding proto
...
PiperOrigin-RevId: 464123009
2022-07-29 11:37:39 -07:00
Yash Katariya
6d8c6f8fac
Make astype
work for Array
that are sharded. The current behavior is the same as SDA i.e. it round trips via host.
...
PiperOrigin-RevId: 457797458
2022-06-28 12:49:12 -07:00
Yash Katariya
e32373c3ea
Make jnp.array
return jax.Array
. Add input and result handlers for jax.Array
. Also added tests for add
under jit.
...
TODO:
* Don't allow `x + y` if `jax.Array` is not fully addressable.
* Figure out how to use the already written tests with Array. Might be able to follow the path taken by SDA.
PiperOrigin-RevId: 457034779
2022-06-24 10:05:06 -07:00
Yash Katariya
1b21d2c3f5
Return Array
from jax.device_put
if config.jax_array
is enabled.
...
PiperOrigin-RevId: 456531510
2022-06-22 09:20:56 -07:00
Yash Katariya
b2e1d814ac
Add __repr__
to Array
. It works exactly as it does for DA and SDA when it is fully addressable. Otherwise it works like GDA.
...
TODO is adding weak_type support in general and to `__repr__`.
PiperOrigin-RevId: 455680796
2022-06-17 13:16:49 -07:00
Yash Katariya
a7160653ce
Add __array__
(for device_get), _npy_value
, block_until_ready
, delete
and _check_if_deleted
to Array.
...
PiperOrigin-RevId: 454741685
2022-06-13 18:08:31 -07:00
Yash Katariya
123413751c
Adding jax.Array
to jax.experimental. Its pretty much the same as GDA (without the performance optimization for now).
...
Currently, jax.Array takes DeviceArrays in `assemble_array` because device_put returns a DA. In the future (with IFRT), it will return an `Array`.
`addressable_shards` wraps DA into jax.Array with a `SingleDeviceSharding`.
PiperOrigin-RevId: 453319811
2022-06-06 17:32:00 -07:00