Roy Frostig
c8b9280fb3
partitionable threefry PRNG random bits implementation
...
the cost is 2x overgeneration of bits
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-10-28 10:07:14 -07:00
Matthew Johnson
95eb4249bb
tweaks to DevicesSharding
...
1. rename DevicesSharding -> ReshapeableDevicesSharding
2. fix repr to print device order faithfully
3. respect shape of np.ndarray argument to __init__
2022-10-25 14:28:48 -07:00
Yash Katariya
adcb0f58e8
Add __repr__
and __str__
to PmapSharding
.
...
Fixes https://github.com/google/jax/issues/12971
PiperOrigin-RevId: 483707874
2022-10-25 10:13:02 -07:00
Yash Katariya
1a0affddd8
Move is_deleted()
to C++ so that we can check if an Array is deleted without materializing _arrays
.
...
Also raise a better error message when doing operations of a deleted Array rather than the current thing which says: `NoneType has no len()`. Now it says: `Array has been deleted`.
PiperOrigin-RevId: 482497114
2022-10-20 08:29:25 -07:00
Matthew Johnson
43098f906a
initial commit of DevicesSharding (fka SimpleSharding)
...
need to add tests!
Co-authored-by: Yash Katariya <yashkatariya@google.com>
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-10-18 21:10:24 -07:00
Yash Katariya
335f45ebb2
Use _rewriting_take
and _chunk_iter
path during __getitem__
and __iter__
respectively when the Array is fully replicated
...
For example:
```
k1, k2 = jax.random.split(key, 2) # where key is fully replicated on 8 devices
```
Then `k1` and `k2` should also maintain the sharding of `key` since `key` is fully replicated.
PiperOrigin-RevId: 480434272
2022-10-11 13:09:33 -07:00
Yash Katariya
9b3e864731
Add weak_type attribute to Array
since it exists on DA (but doesn't exist on SDA).
...
PiperOrigin-RevId: 480223116
2022-10-10 18:11:11 -07:00
Yash Katariya
75b2d05989
Make is_fully_replicated
and is_fully_addressble
a property rather than a method.
...
Why?
1. Because it's easy to cache a property than a method with only the `self` argument. (See below for article)
2. There's no harm in making them a property because both of them return a bool without any side-effects and are cached (so its fast). Why cache `is_fully_addressable`? Because its very expensive to calculate when you have 1000s of devices.
PiperOrigin-RevId: 479850850
2022-10-08 19:24:12 -07:00
jax authors
e8ba61d82b
Merge pull request #12677 from mattjj:jit-pjit-lower-sharding
...
PiperOrigin-RevId: 479669125
2022-10-07 14:28:51 -07:00
Matthew Johnson
bcca6fb57a
add test, small fixes
...
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-10-06 16:45:34 -07:00
Yash Katariya
37f9db77f7
Create Array
s from __getitem__
and __iter__
. This is done by device_put
ting from the host to default device which is suboptimal. But there is a TODO to fix this!
...
PiperOrigin-RevId: 478691051
2022-10-03 22:29:03 -07:00
Yash Katariya
aafc77d3c0
Improve the checks done in Array
and apply them to all Sharding
s rather than just XLACompatibleSharding
.
...
Also check the symmetric difference of sharding and `_arrays` devices.
PiperOrigin-RevId: 478017409
2022-09-30 09:56:16 -07:00
Yash Katariya
500f8b7f9c
Add HLOSharding's repr to OpShardingSharding since its more compact.
...
PiperOrigin-RevId: 477587916
2022-09-28 17:00:16 -07:00
Yash Katariya
c89cb5d8a4
Use Array
in __repr__
instead of the class name which is ArrayImpl
.
...
PiperOrigin-RevId: 477465432
2022-09-28 08:57:53 -07:00
Yash Katariya
9e4114f0f1
Move array.py
and sharding.py
from experimental/
to _src/
.
...
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00
Yash Katariya
cbf34cb609
Rename the concrete class Array
to ArrayImpl
...
PiperOrigin-RevId: 477017236
2022-09-26 16:18:30 -07:00
Yash Katariya
1fa0dda760
Return single device Arrays from .device_buffer
and .device_buffers
.
...
PiperOrigin-RevId: 476449591
2022-09-23 13:30:26 -07:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Yash Katariya
52476d1ab5
Add addressable_data to Array (similar to GDA) to aid in transition and also in auto spmd partitioner mode, always convert to MeshPspecSharding.
...
PiperOrigin-RevId: 475972534
2022-09-21 18:19:35 -07:00
Jake VanderPlas
bdb264e60f
jax.Array: add issubclass test analogous to existing DeviceArray test.
2022-09-14 16:29:38 -07:00
Jake VanderPlas
8eb44fd195
jax_array_test: set config once & fix X64 failure
2022-09-13 14:06:38 -07:00
Yash Katariya
120b2801fd
Bounce to host for any sharding that's not PmapSharding or a sharding with a single device for __iter__
and __getitem__
.
...
PiperOrigin-RevId: 473402857
2022-09-09 20:41:40 -07:00
Yash Katariya
d7726e7b26
Make __getitem__
work for PmapSharding just like SDA works. DA is already covered with the current implementation.
...
Added TODOs to take fast path for indices wherever it is possible to do that. If a correct index is passed during getitem and if that index exists on `Array`, then the fast path is taken (see the test in this CL).
PiperOrigin-RevId: 473342504
2022-09-09 14:25:22 -07:00
Yash Katariya
4746a39e59
Show the correct sharding in is_compatible_aval
error in MeshPspecSharding
when created via _from_parsed_pspec
. Preserve the original PartitionSpec from ParsedPartitionSpec if it exists, else calculate it.
...
PiperOrigin-RevId: 473267905
2022-09-09 09:13:42 -07:00
Yash Katariya
751be205d8
Add a test for jnp.array(..., copy=True) for Array.
...
PiperOrigin-RevId: 473041718
2022-09-08 11:29:51 -07:00
Yash Katariya
7fbf8ec669
Fix Forward. The fix is on the user's end. Original PR: https://github.com/google/jax/pull/12217
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 472999907
2022-09-08 08:49:40 -07:00
jax authors
14f1a345a1
roll back breakage
...
PiperOrigin-RevId: 472949225
2022-09-08 03:59:54 -07:00
Yash Katariya
b7e4e44cbf
DCE jaxpr and trivial_jaxpr support for lower_sharding_computation
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 471274989
2022-09-06 14:09:10 -07:00
Yash Katariya
3e54ac0af0
Make __iter__
of Array
behave like DA when there is a SingleDeviceSharding and like SDA when there is a non-trivial sharding.
...
This is important because when `Array` contains more than 1 shard, each shard can be on a different device and those things need to be preserved when iterating over `Array`.
PiperOrigin-RevId: 471695841
2022-09-01 19:54:34 -07:00
Yash Katariya
2f7951b3dc
Add __hash__
and __eq__
to PmapSharding
...
PiperOrigin-RevId: 471356052
2022-08-31 14:27:16 -07:00
Yash Katariya
da24b99d30
Some minor changes to make_array_from_callback to use the device_indices_map method and calculate the indices just once. Also set the _committed
attribute of shards to what the parent Array has.
...
PiperOrigin-RevId: 471167295
2022-08-30 21:57:21 -07:00
Yash Katariya
fc7a71dc89
Remove device_replica_id_map
from the Sharding interface because the standalone function should be more than enough to use. The major use-case of this is for checkpointing and accessing addressable_shards which accesses the standalone function makes it work.
...
PiperOrigin-RevId: 470820443
2022-08-29 14:49:48 -07:00
Yash Katariya
70a7ee22cc
Check if the buffer shape matches the excepted shard shape by Array.
...
PiperOrigin-RevId: 470732792
2022-08-29 09:00:41 -07:00
Peter Hawkins
78cb9f8492
Avoid more direct references to jax._src without imports.
...
Change in preparation for not exporting jax._src by default.
PiperOrigin-RevId: 469725340
2022-08-24 07:51:28 -07:00
Yash Katariya
e8ec454ae8
Enable fast path in the Array constructor. This means that the rearranging of _arrays
according to the device_assignment won't happen when fastpath is enabled because we assume that jax transformations will return the right arrangement.
...
PiperOrigin-RevId: 469492283
2022-08-23 10:20:26 -07:00
Peter Hawkins
335b2cfb26
[JAX] Prepare not to export jax._src by default.
...
Currently
```
import jax
```
populates `jax._src` in the names exported from JAX. This change prepares for not exporting `jax._src` by default.
In particular, explicitly import modules from jax._src and refer to those imports rather than assuming jax._src contents will be around later. This is a common pattern in tests.
This change does not yet remove any exported names.
Issue https://github.com/google/jax/issues/11951
PiperOrigin-RevId: 469480816
2022-08-23 09:36:47 -07:00
Yash Katariya
f42151c3dc
Take the pjit XLA compilation path for Arrays
. In the test, astype
happens in a sharded fashion without the round trip to host.
...
PiperOrigin-RevId: 468510366
2022-08-18 11:44:39 -07:00
Yash Katariya
acdae7c237
Add weak type support to Array. Also make all api_test.py tests pass with Array. I have disabled the float0
test for now until I investigate.
...
PiperOrigin-RevId: 468264910
2022-08-17 12:25:49 -07:00
Yash Katariya
9040d5c1f7
Add support for returning arrays from full_like
that match the sharding of the input Array.
...
PiperOrigin-RevId: 468053219
2022-08-16 16:25:21 -07:00
Yash Katariya
647186a4aa
Make _device_assignment and _addressable_device_assignment a cached property because having a lru_cache as a decorator on the method with only self
leaks memory. See the document below. In this case, its okay to make them a property rather than a method. Its also consistent with device_set.
...
PiperOrigin-RevId: 465339349
2022-08-04 10:00:01 -07:00
Yash Katariya
2178f339bd
Add OpShardingSharding
and add a function that can calculate indices from an opsharding proto
...
PiperOrigin-RevId: 464123009
2022-07-29 11:37:39 -07:00
Yash Katariya
6d8c6f8fac
Make astype
work for Array
that are sharded. The current behavior is the same as SDA i.e. it round trips via host.
...
PiperOrigin-RevId: 457797458
2022-06-28 12:49:12 -07:00
Yash Katariya
e32373c3ea
Make jnp.array
return jax.Array
. Add input and result handlers for jax.Array
. Also added tests for add
under jit.
...
TODO:
* Don't allow `x + y` if `jax.Array` is not fully addressable.
* Figure out how to use the already written tests with Array. Might be able to follow the path taken by SDA.
PiperOrigin-RevId: 457034779
2022-06-24 10:05:06 -07:00
Yash Katariya
1b21d2c3f5
Return Array
from jax.device_put
if config.jax_array
is enabled.
...
PiperOrigin-RevId: 456531510
2022-06-22 09:20:56 -07:00
Yash Katariya
b2e1d814ac
Add __repr__
to Array
. It works exactly as it does for DA and SDA when it is fully addressable. Otherwise it works like GDA.
...
TODO is adding weak_type support in general and to `__repr__`.
PiperOrigin-RevId: 455680796
2022-06-17 13:16:49 -07:00
Yash Katariya
a7160653ce
Add __array__
(for device_get), _npy_value
, block_until_ready
, delete
and _check_if_deleted
to Array.
...
PiperOrigin-RevId: 454741685
2022-06-13 18:08:31 -07:00
Yash Katariya
123413751c
Adding jax.Array
to jax.experimental. Its pretty much the same as GDA (without the performance optimization for now).
...
Currently, jax.Array takes DeviceArrays in `assemble_array` because device_put returns a DA. In the future (with IFRT), it will return an `Array`.
`addressable_shards` wraps DA into jax.Array with a `SingleDeviceSharding`.
PiperOrigin-RevId: 453319811
2022-06-06 17:32:00 -07:00