(jax-array-migration)=
# jax.Array migration
<!--* freshness: { reviewed: '2023-03-17' } *-->
**yashkatariya@**
## TL;DR
JAX switched its default array implementation to the new `jax.Array` as of version 0.4.1.
This guide explains the reasoning behind this, the impact it might have on your code,
and how to (temporarily) switch back to the old behavior.
### What's going on?
`jax.Array` is a unified array type that subsumes `DeviceArray`, `ShardedDeviceArray`,
and `GlobalDeviceArray` types in JAX. The `jax.Array` type helps make parallelism a
core feature of JAX, simplifies and unifies JAX internals, and allows us to
unify `jit` and `pjit`. If your code doesn't mention `DeviceArray` vs
`ShardedDeviceArray` vs `GlobalDeviceArray`, no changes are needed. But code that
depends on details of these separate classes may need to be tweaked to work with
the unified `jax.Array`.

After the migration is complete, `jax.Array` will be the only type of array in
JAX.
This doc explains how to migrate existing codebases to `jax.Array`. For more information on using `jax.Array` and JAX parallelism APIs, see the [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) tutorial.
### How to enable jax.Array?
You can enable `jax.Array` by:
* setting the shell environment variable `JAX_ARRAY` to something true-like
(e.g., `1`);
* setting the boolean flag `jax_array` to something true-like if your code
parses flags with absl;
* using this statement at the top of your main file:
```
import jax
jax.config.update('jax_array', True)
```
### How do I know if jax.Array broke my code?
The easiest way to tell if `jax.Array` is responsible for any problems is to
disable `jax.Array` and see if the issues go away.
### How can I disable jax.Array for now?
Through **March 15, 2023** it will be possible to disable `jax.Array` by:
* setting the shell environment variable `JAX_ARRAY` to something falsey
(e.g., `0`);
* setting the boolean flag `jax_array` to something falsey if your code parses
flags with absl;
* using this statement at the top of your main file:
```
import jax
jax.config.update('jax_array', False)
```
## Why create jax.Array?
Currently JAX has three types: `DeviceArray`, `ShardedDeviceArray` and
`GlobalDeviceArray`. `jax.Array` merges these three types and cleans up JAX's
internals while adding new parallelism features.
We also introduce a new `Sharding` abstraction that describes how a logical
Array is physically sharded out across one or more devices, such as TPUs or
GPUs. The change also upgrades, simplifies and merges the parallelism features
of `pjit` into `jit`. Functions decorated with `jit` will be able to operate
over sharded arrays without copying data onto a single device.
Features you get with `jax.Array`:
* C++ `pjit` dispatch path
* Op-by-op parallelism (even if the array is distributed across multiple
  devices across multiple hosts)
* Simpler batch data parallelism with `pjit`/`jit`.
* Ways to create `Sharding`s that don't necessarily consist of a mesh and
  partition spec; you can fully utilize the flexibility of `OpSharding`, or
  use any other `Sharding` you want.
* and many more
Example:
```
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
import numpy as np
x = jnp.arange(8)
# Let's say there are 8 devices in jax.devices()
mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x'))
sharded_x = jax.device_put(x, sharding)
# `matmul_sharded_x` and `sin_sharded_x` are sharded. `jit` is able to operate over a
# sharded array without copying data to a single device.
matmul_sharded_x = sharded_x @ sharded_x.T
sin_sharded_x = jnp.sin(sharded_x)
# Even jnp.copy preserves the sharding on the output.
copy_sharded_x = jnp.copy(sharded_x)
# double_out is also sharded
double_out = jax.jit(lambda x: x * 2)(sharded_x)
```
## What issues can arise when jax.Array is switched on?
### New public type named jax.Array
All `isinstance(..., jnp.DeviceArray)` or `isinstance(..., jax.xla.DeviceArray)`
checks, and other variants of `DeviceArray`, should be switched to
`isinstance(..., jax.Array)`.
Since `jax.Array` can represent DA, SDA and GDA, you can differentiate those
three types via the following checks (sketched in code below):
* `x.is_fully_addressable and len(x.sharding.device_set) == 1` -- this means
  that the `jax.Array` is like a DA
* `x.is_fully_addressable and len(x.sharding.device_set) > 1` -- this means
  that the `jax.Array` is like an SDA
* `not x.is_fully_addressable` -- this means that the `jax.Array` is like a GDA
  and spans across multiple processes
For `ShardedDeviceArray`, you can replace `isinstance(...,
pxla.ShardedDeviceArray)` with `isinstance(..., jax.Array) and
x.is_fully_addressable and len(x.sharding.device_set) > 1`.
In general it is not possible to differentiate a `ShardedDeviceArray` on one
device from any other kind of single-device array.
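As an illustration, here is a minimal sketch of these checks. The `describe`
helper is hypothetical (not a JAX API); the property and attribute accesses are
the ones listed above:
```
import jax
import jax.numpy as jnp

def describe(x):
  # Hypothetical helper classifying a jax.Array with the checks above.
  assert isinstance(x, jax.Array)
  if not x.is_fully_addressable:
    return 'GDA-like: spans multiple processes'
  if len(x.sharding.device_set) > 1:
    return 'SDA-like: sharded across multiple local devices'
  return 'DA-like: lives on a single device'

x = jnp.arange(8)   # a single-device array by default
print(describe(x))  # DA-like: lives on a single device
```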
### GDA's API name changes
GDA's `local_shards` and `local_data` have been deprecated.
Please use `addressable_shards` and `addressable_data`, which are compatible
with both `jax.Array` and `GDA`.
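For example, assuming `arr` is an existing `jax.Array` (or `GDA`), the renamed
accessors are drop-in replacements:
```
# Instead of the deprecated GDA-only accessors:
#   arr.local_shards, arr.local_data(0)
# use the equivalents that work for both jax.Array and GDA:
for shard in arr.addressable_shards:  # shards addressable by this process
  print(shard.device, shard.data.shape)
first_shard_data = arr.addressable_data(0)  # data of addressable shard 0
```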
### Creating jax.Array
All JAX functions will output `jax.Array` when the `jax_array` flag is True. If
you were using `GlobalDeviceArray.from_callback` or `make_sharded_device_array`
or `make_device_array` functions to explicitly create the respective JAX data
types, you will need to switch them to use {func}`jax.make_array_from_callback`
or {func}`jax.make_array_from_single_device_arrays`.
**For GDA:**
`GlobalDeviceArray.from_callback(shape, mesh, pspec, callback)` can become
`jax.make_array_from_callback(shape, jax.sharding.NamedSharding(mesh, pspec), callback)`
in a 1:1 switch.
If you were using the raw GDA constructor to create GDAs, then do this:
`GlobalDeviceArray(shape, mesh, pspec, buffers)` can become
`jax.make_array_from_single_device_arrays(shape, jax.sharding.NamedSharding(mesh, pspec), buffers)`.
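As a concrete sketch of the callback-based replacement (assuming 8 available
devices; `global_data` stands in for whatever your callback used to read):
```
import jax
import numpy as np
from jax.sharding import PartitionSpec as P

mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
pspec = P('x', 'y')
shape = (8, 2)
global_data = np.arange(np.prod(shape)).reshape(shape)

# The callback receives the index (a tuple of slices) of each shard and
# returns the corresponding piece of the global array.
arr = jax.make_array_from_callback(
    shape, jax.sharding.NamedSharding(mesh, pspec),
    lambda index: global_data[index])
```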
**For SDA:**
`make_sharded_device_array(aval, sharding_spec, device_buffers, indices)` can
become `jax.make_array_from_single_device_arrays(shape, sharding, device_buffers)`.
What the sharding should be depends on why you were creating the SDAs:
* If it was created as an input to `pmap`, then the sharding can be
  `jax.sharding.PmapSharding(devices, sharding_spec)`.
* If it was created as an input to `pjit`, then the sharding can be
  `jax.sharding.NamedSharding(mesh, pspec)`; a sketch of this case follows
  below.
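Here is a minimal sketch of the `pjit` case, assuming one shard per local
device; `global_data` again stands in for your real data:
```
import jax
import numpy as np
from jax.sharding import PartitionSpec as P

mesh = jax.sharding.Mesh(np.array(jax.devices()), ('x',))
sharding = jax.sharding.NamedSharding(mesh, P('x'))
shape = (len(jax.devices()), 2)
global_data = np.arange(np.prod(shape)).reshape(shape)

# Build one single-device array per shard, placed on its device.
buffers = [
    jax.device_put(global_data[index], device)
    for device, index in sharding.addressable_devices_indices_map(shape).items()]
arr = jax.make_array_from_single_device_arrays(shape, sharding, buffers)
```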
### Breaking change for pjit after switching to jax.Array for host local inputs
**If you are exclusively using GDA arguments to pjit, you can skip this section!
🎉**
With `jax.Array` enabled, all inputs to `pjit` must be globally shaped. This is
a breaking change from the previous behavior where `pjit` would concatenate
process-local arguments into a global value; this concatenation no longer
occurs.
Why are we making this breaking change? Each array now says explicitly how its
local shards fit into a global whole, rather than leaving it implicit. The more
explicit representation also unlocks additional flexibility, for example the use
of non-contiguous meshes with `pjit` which can improve efficiency on some TPU
models.
Running **multi-process pjit computation** and passing host-local inputs when
`jax.Array` is enabled can lead to an error similar to this:
Example:
Mesh = `{'x': 2, 'y': 2, 'z': 2}` and host local input shape == `(4,)` and
pspec = `P(('x', 'y', 'z'))`
Since `pjit` doesn't lift host-local shapes to global shapes with `jax.Array`,
you get the following error:
Note: You will only see this error if your host local shape is smaller than the
shape of the mesh.
```
ValueError: One of pjit arguments was given the sharding of
NamedSharding(mesh={'x': 2, 'y': 2, 'z': 2}, partition_spec=PartitionSpec(('x', 'y', 'z'),)),
which implies that the global size of its dimension 0 should be divisible by 8,
but it is equal to 4
```
The error makes sense because you can't shard dimension 0 eight ways when the
size of dimension 0 is 4.
How can you migrate if you still pass host local inputs to `pjit`? We are
providing transitional APIs to help you migrate:
Note: You don't need these utilities if you run your pjitted computation on a
single process.
```
from jax.experimental import multihost_utils

global_inps = multihost_utils.host_local_array_to_global_array(
    local_inputs, mesh, in_pspecs)
global_outputs = pjit(f, in_shardings=in_pspecs,
                      out_shardings=out_pspecs)(global_inps)
local_outs = multihost_utils.global_array_to_host_local_array(
    global_outputs, mesh, out_pspecs)
```
`host_local_array_to_global_array` is a type cast: it looks at a value with
only local shards and changes its local shape to the global shape that `pjit`
would previously have assumed if that value had been passed in before the change.
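To make the shape semantics concrete, here is a sketch under an assumed
two-process setup, with one device per process and a one-axis mesh:
```
import jax
import numpy as np
from jax.experimental import multihost_utils
from jax.sharding import PartitionSpec as P

# Assumed setup: 2 processes, 1 device per process, mesh axis 'x' of size 2.
mesh = jax.sharding.Mesh(np.array(jax.devices()), ('x',))

# Each process holds a host-local array of shape (4,).
local_inp = np.arange(4.0) + jax.process_index() * 4

global_inp = multihost_utils.host_local_array_to_global_array(
    local_inp, mesh, P('x'))
# global_inp.shape == (8,): the local shapes are concatenated along 'x'.
# It is a type cast, not a data transfer: only the (shape, sharding) view changes.
```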
Passing in fully replicated inputs, i.e. inputs with the same shape on each
process and `P(None)` as `in_axis_resources`, is still supported. In this case
you do not have to use `host_local_array_to_global_array` because the shape is
already global.
```
key = jax.random.PRNGKey(1)

# Using host_local_array_to_global_array is not required here since
# in_axis_resources says that the input is fully replicated via P(None).
pjit(f, in_shardings=None, out_shardings=None)(key)

# Mixing inputs
global_inp = multihost_utils.host_local_array_to_global_array(
    local_inp, mesh, P('data'))
global_out = pjit(f, in_shardings=(P(None), P('data')),
                  out_shardings=...)(key, global_inp)
```
### FROM_GDA and jax.Array
If you were using `FROM_GDA` in the `in_axis_resources` argument to `pjit`,
then with `jax.Array` there is no need to pass anything to `in_axis_resources`,
as `jax.Array` follows **computation follows sharding** semantics.
For example:
```
# Before:
pjit(f, in_shardings=FROM_GDA, out_shardings=...)
# After:
pjit(f, out_shardings=...)
```
If you have `PartitionSpec`s mixed in with `FROM_GDA` for inputs like numpy
arrays, etc., then use `host_local_array_to_global_array` to convert them to
`jax.Array`s.
For example:
If you had this:
```
pjitted_f = pjit(
    f, in_shardings=(FROM_GDA, P('x'), FROM_GDA, P(None)),
    out_shardings=...)
pjitted_f(gda1, np_array1, gda2, np_array2)
```
then you can replace it with:
```
pjitted_f = pjit(f, out_shardings=...)
array1, array2 = multihost_utils.host_local_array_to_global_array(
    (np_array1, np_array2), mesh, (P('x'), P(None)))
pjitted_f(gda1, array1, gda2, array2)
```
### live_buffers replaced with live_arrays
The `live_buffers` attribute on jax `Device`
has been deprecated. Please use `jax.live_arrays()` instead, which is
compatible with `jax.Array`.
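A minimal sketch of the replacement:
```
import jax
import jax.numpy as jnp

x = jnp.ones((4, 4))

# Previously: jax.devices()[0].live_buffers
live = jax.live_arrays()  # all live jax.Arrays in the current process
print(any(a is x for a in live))  # True: x is still alive
```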
### Handling of host-local inputs to pjit, like batches
If you are passing host-local inputs to `pjit` in a **multi-process
environment**, then please use
`multihost_utils.host_local_array_to_global_array` to convert them to global
`jax.Array`s before passing them to `pjit`.
The most common example of such a host-local input is a **batch of input
data**, but this applies to any host-local input, not just batches.
```
from jax.experimental import multihost_utils

batch = multihost_utils.host_local_array_to_global_array(
    batch, mesh, batch_partition_spec)
```
See the pjit section above for more details about this change and more examples.
### RecursionError: Recursively calling jit
This happens when some part of your code has `jax.Array` disabled and you then
enable it only for some other part. For example, if you use some third-party
code which has `jax.Array` disabled and you get a `DeviceArray` from that
library, then enable `jax.Array` in your own library and pass that
`DeviceArray` to JAX functions, it will lead to a `RecursionError`.
This error should go away when `jax.Array` is enabled by default so that all
libraries return `jax.Array` unless they explicitly disable it.