Merge pull request #13197 from google:yashk2810-patch-17

PiperOrigin-RevId: 487687006
This commit is contained in:
jax authors 2022-11-10 17:56:29 -08:00
commit f9d7a6ae20

View File

@ -1,12 +1,10 @@
## jax.Array migration
# jax.Array migration
**yashkatariya@**
[TOC]
## TL;DR
#### TL;DR {#tl-dr}
##### Whats going on? {#whats-going-on}
### Whats going on?
`jax.Array` is a unified array type that subsumes `DeviceArray`, `ShardedDeviceArray`,
and `GlobalDeviceArray` types in JAX. The `jax.Array` type helps make parallelism a
@ -20,7 +18,7 @@ After the migration is complete `jax.Array` will be the only type of array in
JAX.
##### How to enable jax.Array? {#how-to-enable-jax-array}
### How to enable jax.Array?
You can enable `jax.Array` by:
@ -35,12 +33,12 @@ You can enable `jax.Array` by:
jax.config.update('jax_array', True)
```
##### How do I know if jax.Array broke my code? {#how-do-i-know-if-jax-array-broke-my-code}
### How do I know if jax.Array broke my code?
The easiest way to tell if `jax.Array` is responsible for any problems is to
disable `jax.Array` and see if the issues go away.
##### How can I disable jax.Array for now? {#how-can-i-disable-jax-array-for-now}
### How can I disable jax.Array for now?
You can disable `jax.Array` by: (After a certain date (TBD), the option to
disable jax.Array won't exist)
@ -56,7 +54,7 @@ disable jax.Array won't exist)
jax.config.update('jax_array', False)
```
#### Why create jax.Array? {#why-create-jax-array}
## Why create jax.Array?
Currently JAX has three types; `DeviceArray`, `ShardedDeviceArray` and
`GlobalDeviceArray`. `jax.Array` merges these three types and cleans up JAXs
@ -105,9 +103,9 @@ copy_sharded_x = jnp.copy(sharded_x)
double_out = jax.jit(lambda x: x * 2)(sharded_x)
```
#### What issues can arise when jax.Array is switched on? {#what-issues-can-arise-when-jax-array-is-switched-on}
## What issues can arise when jax.Array is switched on?
##### New public type named jax.Array {#new-public-type-named-jax-array}
### New public type named jax.Array
All `isinstance(..., jnp.DeviceArray)` or `isinstance(.., jax.xla.DeviceArray)`
and other variants of `DeviceArray` should be switched to using `isinstance(...,
@ -130,14 +128,14 @@ x.is_fully_addressable and len(x.sharding.device_set) > 1`.
In general it is not possible to differentiate a `ShardedDeviceArray` on 1
device from any other kind of single-device Array.
##### GDAs API name changes {#gdas-api-name-changes}
### GDAs API name changes
GDAs `local_shards` and `local_data` have been deprecated.
Please use `addressable_shards` and `addressable_data` which are compatible with
`jax.Array` and `GDA`.
##### Creating jax.Array {#creating-jax-array}
### Creating jax.Array
All JAX functions will output `jax.Array` when the `jax_array` flag is True. If
you were using `GlobalDeviceArray.from_callback` or `make_sharded_device_array`
@ -170,7 +168,7 @@ If it was created to give as an input to `pmap`, then sharding can be:
If it was created to give as an input
to `pjit`, then sharding can be `jax.sharding.NamedSharding(mesh, pspec)`.
##### Breaking change for pjit after switching to jax.Array for host local inputs {#breaking-change-for-pjit-after-switching-to-jax-array-for-host-local-inputs}
### Breaking change for pjit after switching to jax.Array for host local inputs
**If you are exclusively using GDA arguments to pjit, you can skip this section!
🎉**
@ -252,7 +250,7 @@ global_out = pjit(f, in_axis_resources=(P(None), P('data')),
out_axis_resources=...)(key, global_inp)
```
##### FROM_GDA and jax.Array {#from_gda-and-jax-array}
### FROM_GDA and jax.Array
If you were using `FROM_GDA` in `in_axis_resources` argument to `pjit`, then
with `jax.Array` there is no need to pass anything to `in_axis_resources` as
@ -291,14 +289,14 @@ array2, array3 = multihost_utils.host_local_array_to_global_array(
pjitted_f(array1, array2, array3, array4)
```
##### live\_buffers replaced with live\_arrays {#live_buffers-replaced-with-live_arrays}
### live_buffers replaced with live_arrays
`live_buffers` attribute on jax `Device`
has been deprecated. Please use `jax.live_arrays()` instead which is compatible
with `jax.Array`.
##### Handling of host local inputs to pjit like batch, etc {#handling-of-host-local-inputs-to-pjit-like-batch-etc}
### Handling of host local inputs to pjit like batch, etc
If you are passing host local inputs to `pjit` in a **multi-process
environment**, then please use
@ -320,7 +318,7 @@ batch = multihost_utils.host_local_array_to_global_array(
See the pjit section above for more details about this change and more examples.
##### RecursionError: Recursively calling jit {#recursionerror-recursively-calling-jit}
### RecursionError: Recursively calling jit
This happens when some part of your code has `jax.Array` disabled and then you
enable it only for some other part. For example, if you use some third\_party