Add jax.Array migration doc to OSS

PiperOrigin-RevId: 487673643
Yash Katariya 2022-11-10 16:45:51 -08:00 committed by jax authors
parent 352b042fe9
commit b49a1bda15
2 changed files with 333 additions and 0 deletions


@@ -59,6 +59,7 @@ parallelize, Just-In-Time compile to GPU/TPU, and more.
:caption: Notes
api_compatibility
jax_array_migration
deprecation
concurrency
gpu_memory_allocation

docs/jax_array_migration.md (new file, 332 lines)

@@ -0,0 +1,332 @@
## jax.Array migration
**yashkatariya@**
[TOC]
#### TL;DR {#tl-dr}
##### What's going on? {#whats-going-on}
`jax.Array` is a unified array type that subsumes `DeviceArray`, `ShardedDeviceArray`,
and `GlobalDeviceArray` types in JAX. The `jax.Array` type helps make parallelism a
core feature of JAX, simplifies and unifies JAX internals, and allows us to
unify `jit` and `pjit`. If your code doesn't mention `DeviceArray` vs
`ShardedDeviceArray` vs `GlobalDeviceArray`, no changes are needed. But code that
depends on details of these separate classes may need to be tweaked to work with
the unified `jax.Array`.
After the migration is complete, `jax.Array` will be the only type of array in
JAX.
##### How to enable jax.Array? {#how-to-enable-jax-array}
You can enable `jax.Array` by:
* setting the shell environment variable `JAX_ARRAY` to something true-like
(e.g., `1`);
* setting the boolean flag `jax_array` to something true-like if your code
parses flags with absl;
* using this statement at the top of your main file:
```
import jax
jax.config.update('jax_array', True)
```
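To confirm that the flag took effect, one quick sanity check (a minimal sketch,
not an official recipe) is to verify that a freshly created array is an instance
of the unified type, since all JAX functions output `jax.Array` when the flag is
on:
```
import jax
import jax.numpy as jnp

jax.config.update('jax_array', True)

# With jax_array enabled, every JAX function outputs the unified jax.Array type.
assert isinstance(jnp.arange(4), jax.Array)
```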
##### How do I know if jax.Array broke my code? {#how-do-i-know-if-jax-array-broke-my-code}
The easiest way to tell if `jax.Array` is responsible for any problems is to
disable `jax.Array` and see if the issues go away.
##### How can I disable jax.Array for now? {#how-can-i-disable-jax-array-for-now}
You can disable `jax.Array` by the means below. (Note that after a certain date
(TBD), the option to disable `jax.Array` won't exist.)
* setting the shell environment variable `JAX_ARRAY` to something false-like
(e.g., `0`);
* setting the boolean flag `jax_array` to something false-like if your code
parses flags with absl;
* using this statement at the top of your main file:
```
import jax
jax.config.update('jax_array', False)
```
#### Why create jax.Array? {#why-create-jax-array}
Currently JAX has three types: `DeviceArray`, `ShardedDeviceArray`, and
`GlobalDeviceArray`. `jax.Array` merges these three types and cleans up JAX's
internals while adding new parallelism features.
We also introduce a new `Sharding` abstraction that describes how a logical
Array is physically sharded out across one or more devices, such as TPUs or
GPUs. The change also upgrades, simplifies and merges the parallelism features
of `pjit` into `jit`. Functions decorated with `jit` will be able to operate
over sharded arrays without copying data onto a single device.
Features you get with `jax.Array`:
* C++ `pjit` dispatch path
* Op-by-op parallelism (even if the array is distributed across multiple
devices across multiple hosts)
* Simpler batch data parallelism with `pjit`/`jit`.
* Ways to create `Sharding`s that don't necessarily consist of a mesh and
partition spec. You can fully utilize the flexibility of `OpSharding`, or
any other `Sharding` you want.
* and many more
Example:
```
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

x = jnp.arange(16).reshape(8, 2)

# Let's say there are 8 devices in jax.devices()
mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
sharding = NamedSharding(mesh, P('x'))

sharded_x = jax.device_put(x, sharding)

# `mul_sharded_x` and `sin_sharded_x` are sharded. `jit` is able to operate
# over a sharded array without copying data to a single device.
mul_sharded_x = sharded_x @ sharded_x.T
sin_sharded_x = jnp.sin(sharded_x)

# Even jnp.copy preserves the sharding on the output.
copy_sharded_x = jnp.copy(sharded_x)

# double_out is also sharded
double_out = jax.jit(lambda x: x * 2)(sharded_x)
```
#### What issues can arise when jax.Array is switched on? {#what-issues-can-arise-when-jax-array-is-switched-on}
##### New public type named jax.Array {#new-public-type-named-jax-array}
All `isinstance(..., jnp.DeviceArray)` or `isinstance(..., jax.xla.DeviceArray)`
and other variants of `DeviceArray` should be switched to
`isinstance(..., jax.Array)`.
Since `jax.Array` can represent a `DeviceArray` (DA), a `ShardedDeviceArray`
(SDA), or a `GlobalDeviceArray` (GDA), you can differentiate these three cases
via:
* `x.is_fully_addressable and len(x.sharding.device_set) == 1` -- this means
that `jax.Array` is like a DA
* `x.is_fully_addressable and len(x.sharding.device_set) > 1` -- this means
that `jax.Array` is like an SDA
* `not x.is_fully_addressable` -- this means that `jax.Array` is like a GDA
and spans across multiple processes
For `ShardedDeviceArray`, you can change `isinstance(...,
pxla.ShardedDeviceArray)` to `isinstance(..., jax.Array) and
x.is_fully_addressable and len(x.sharding.device_set) > 1`.
In general it is not possible to differentiate a `ShardedDeviceArray` on 1
device from any other kind of single-device Array.
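If you need these checks in several places, you can wrap them in small helper
functions. A minimal sketch (the helper names are hypothetical; the bodies
simply mirror the bullets above):
```
import jax

def is_like_device_array(x: jax.Array) -> bool:
  # Fully addressable and on a single device: behaves like a DeviceArray.
  return x.is_fully_addressable and len(x.sharding.device_set) == 1

def is_like_sharded_device_array(x: jax.Array) -> bool:
  # Fully addressable but spread over multiple devices: behaves like an SDA.
  return x.is_fully_addressable and len(x.sharding.device_set) > 1

def is_like_global_device_array(x: jax.Array) -> bool:
  # Not fully addressable: spans multiple processes, like a GDA.
  return not x.is_fully_addressable
```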
##### GDA's API name changes {#gdas-api-name-changes}
GDA's `local_shards` and `local_data` have been deprecated. Please use
`addressable_shards` and `addressable_data`, which are compatible with both
`jax.Array` and `GDA`.
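For example, on a single process (a minimal sketch; the same attributes work on
both `jax.Array` and `GDA`):
```
import jax
import jax.numpy as jnp

x = jnp.arange(8)

# addressable_shards lists the shards that live on this process;
# addressable_data(i) returns the data of the i-th addressable shard.
for shard in x.addressable_shards:
  print(shard.device, shard.data.shape)
first_shard_data = x.addressable_data(0)
```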
##### Creating jax.Array {#creating-jax-array}
All JAX functions will output `jax.Array` when the `jax_array` flag is True. If
you were using `GlobalDeviceArray.from_callback` or `make_sharded_device_array`
or `make_device_array` functions to explicitly create the respective JAX data
types, you will need to switch them to use `jax.make_array_from_callback` or
`jax.make_array_from_single_device_arrays`.
**For GDA:**
`GlobalDeviceArray.from_callback(shape, mesh, pspec, callback)` can become
`jax.make_array_from_callback(shape, jax.sharding.NamedSharding(mesh, pspec), callback)`
in a 1:1 switch.
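For instance, on a single process with eight devices (a minimal sketch; the
mesh shape, axis names, and input values are illustrative):
```
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

global_shape = (8, 2)
mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
inp = np.arange(np.prod(global_shape)).reshape(global_shape)

# The callback receives the index (a tuple of slices) of a shard of the
# global array and returns the data for that shard.
arr = jax.make_array_from_callback(global_shape, sharding, lambda idx: inp[idx])
```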
If you were using the raw GDA constructor to create GDAs, then do this:
`GlobalDeviceArray(shape, mesh, pspec, buffers)` can become
`jax.make_array_from_single_device_arrays(shape, jax.sharding.NamedSharding(mesh, pspec), buffers)`
**For SDA:**
`make_sharded_device_array(aval, sharding_spec, device_buffers, indices)` can
become `jax.make_array_from_single_device_arrays(shape, sharding, device_buffers)`.
The right choice of sharding depends on why you were creating the SDAs (a
sketch follows below):
If it was created to be passed as an input to `pmap`, then the sharding can be
`jax.sharding.PmapSharding(devices, sharding_spec)`.
If it was created to be passed as an input to `pjit`, then the sharding can be
`jax.sharding.NamedSharding(mesh, pspec)`.
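As a concrete single-process sketch of the `pjit` case (device count, shapes,
and values are illustrative):
```
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
mesh = Mesh(np.array(devices), ('x',))
sharding = NamedSharding(mesh, P('x'))
global_shape = (2 * len(devices),)

# One single-device buffer per device, laid out to match `sharding`:
# device i holds rows [2*i, 2*i + 2) of the global array.
buffers = [
    jax.device_put(np.arange(2 * i, 2 * i + 2), d)
    for i, d in enumerate(devices)
]
arr = jax.make_array_from_single_device_arrays(global_shape, sharding, buffers)
```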
##### Breaking change for pjit after switching to jax.Array for host local inputs {#breaking-change-for-pjit-after-switching-to-jax-array-for-host-local-inputs}
**If you are exclusively using GDA arguments to pjit, you can skip this section!
🎉**
With `jax.Array` enabled, all inputs to `pjit` must be globally shaped. This is
a breaking change from the previous behavior where `pjit` would concatenate
process-local arguments into a global value; this concatenation no longer
occurs.
Why are we making this breaking change? Each array now says explicitly how its
local shards fit into a global whole, rather than leaving it implicit. The more
explicit representation also unlocks additional flexibility, for example the use
of non-contiguous meshes with `pjit`, which can improve efficiency on some TPU
models.
Running **multi-process pjit computation** and passing host-local inputs when
`jax.Array` is enabled can lead to an error similar to this:
Example:
Mesh = `{'x': 2, 'y': 2, 'chips': 2}` and host local input shape == `(4,)` and
pspec = `P(('x', 'y', 'chips'))`
Since `pjit` doesn't lift host local shapes to global shapes with `jax.Array`,
you get the following error:
Note: You will only see this error if your host local shape is smaller than the
shape of the mesh.
```
ValueError: One of pjit arguments was given the sharding of
MeshPspecSharding(mesh={'x': 2, 'y': 2, 'chips': 2}, partition_spec=PartitionSpec(('x', 'y', 'chips'),)),
which implies that the global size of its dimension 0 should be divisible by 8,
but it is equal to 4
```
The error makes sense: you can't split dimension 0 eight ways when its size is
only 4.
How can you migrate if you still pass host local inputs to `pjit`? We are
providing transitional APIs to help you migrate:
Note: You don't need these utilities if you run your pjitted computation on a
single process.
```
from jax.experimental import multihost_utils
global_inps = multihost_utils.host_local_array_to_global_array(
    local_inputs, mesh, in_pspecs)

global_outputs = pjit(f, in_axis_resources=in_pspecs,
                      out_axis_resources=out_pspecs)(global_inps)

local_outs = multihost_utils.global_array_to_host_local_array(
    global_outputs, mesh, out_pspecs)
```
`host_local_array_to_global_array` is a type cast: it looks at a value with
only local shards and changes its local shape to the global shape that `pjit`
would previously have assumed for that value. For example, if each of two
processes passes a shard of shape `(4,)` with pspec `P('data')` over a mesh of
size 2, the resulting global `jax.Array` has shape `(8,)`.
Passing in fully replicated inputs, i.e. inputs with the same shape on each
process, with `P(None)` as `in_axis_resources` is still supported. In this case
you do not have to use `host_local_array_to_global_array` because the shape is
already global.
```
key = jax.random.PRNGKey(1)

# Using host_local_array_to_global_array is not required here since
# in_axis_resources says that the input is fully replicated via P(None).
pjit(f, in_axis_resources=None, out_axis_resources=None)(key)

# Mixing fully replicated and sharded inputs.
global_inp = multihost_utils.host_local_array_to_global_array(
    local_inp, mesh, P('data'))
global_out = pjit(f, in_axis_resources=(P(None), P('data')),
                  out_axis_resources=...)(key, global_inp)
```
##### FROM_GDA and jax.Array {#from_gda-and-jax-array}
If you were using `FROM_GDA` in the `in_axis_resources` argument to `pjit`,
then with `jax.Array` there is no need to pass anything to `in_axis_resources`,
as `jax.Array` follows **computation follows sharding** semantics.
For example:
```
pjit(f, in_axis_resources=FROM_GDA, out_axis_resources=...)
# can be replaced by
pjit(f, out_axis_resources=...)
```
If you have `PartitionSpec`s mixed in with `FROM_GDA` for inputs like numpy
arrays, etc., then use `host_local_array_to_global_array` to convert them to
`jax.Array`s.
For example:
If you had this:
```
pjitted_f = pjit(
    f, in_axis_resources=(FROM_GDA, P('x'), FROM_GDA, P(None)),
    out_axis_resources=...)
pjitted_f(gda1, np_array1, gda2, np_array2)
```
then you can replace it with:
```
pjitted_f = pjit(f, out_axis_resources=...)
array1, array2 = multihost_utils.host_local_array_to_global_array(
    (np_array1, np_array2), mesh, (P('x'), P(None)))

pjitted_f(gda1, array1, gda2, array2)
```
##### live\_buffers replaced with live\_arrays {#live_buffers-replaced-with-live_arrays}
The `live_buffers` attribute on `jax.Device` has been deprecated. Please use
`jax.live_arrays()` instead, which is compatible with `jax.Array`.
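A minimal sketch of the replacement (the old spelling in the comment is shown
only for comparison):
```
import jax
import jax.numpy as jnp

x = jnp.arange(8)

# Previously: sum(len(d.live_buffers()) for d in jax.devices())
# jax.live_arrays() returns all jax.Arrays that are currently alive.
print(len(jax.live_arrays()))
```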
##### Handling of host local inputs to pjit like batch, etc {#handling-of-host-local-inputs-to-pjit-like-batch-etc}
If you are passing host local inputs to `pjit` in a **multi-process
environment**, then please use
`multihost_utils.host_local_array_to_global_array` to convert the batch to a
global `jax.Array` and then pass that to `pjit`.
The most common example of such a host local input is a **batch of input
data**, but this works for any host local input (not just a batch). See
cl/486388872 as an example of how to do that.
```
from jax.experimental import multihost_utils
batch = multihost_utils.host_local_array_to_global_array(
    batch, mesh, batch_partition_spec)
```
See the pjit section above for more details about this change and more examples.
##### RecursionError: Recursively calling jit {#recursionerror-recursively-calling-jit}
This happens when some part of your code has `jax.Array` disabled and you then
enable it only for some other part. For example, if you use some third-party
code which has `jax.Array` disabled, get a `DeviceArray` from that library, and
then pass that `DeviceArray` to a JAX function after enabling `jax.Array` in
your own library, it will lead to a `RecursionError`.
This error should go away when `jax.Array` is enabled by default so that all
libraries return `jax.Array` unless they explicitly disable it.