mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)

Add jax.Array migration doc to OSS

PiperOrigin-RevId: 487673643

This commit is contained in: parent 352b042fe9, commit b49a1bda15

```
@@ -59,6 +59,7 @@ parallelize, Just-In-Time compile to GPU/TPU, and more.
    :caption: Notes

    api_compatibility
+   jax_array_migration
    deprecation
    concurrency
    gpu_memory_allocation
```

`docs/jax_array_migration.md` (new file, 332 lines):

## jax.Array migration

**yashkatariya@**

[TOC]

#### TL;DR {#tl-dr}

##### What’s going on? {#whats-going-on}

`jax.Array` is a unified array type that subsumes the `DeviceArray`,
`ShardedDeviceArray`, and `GlobalDeviceArray` types in JAX. The `jax.Array`
type helps make parallelism a core feature of JAX, simplifies and unifies JAX
internals, and allows us to unify `jit` and `pjit`. If your code doesn't
mention `DeviceArray` vs `ShardedDeviceArray` vs `GlobalDeviceArray`, no
changes are needed. But code that depends on details of these separate classes
may need to be tweaked to work with the unified `jax.Array`.

After the migration is complete, `jax.Array` will be the only type of array in
JAX.

##### How to enable jax.Array? {#how-to-enable-jax-array}

You can enable `jax.Array` by:

*   setting the shell environment variable `JAX_ARRAY` to something true-like
    (e.g., `1`);
*   setting the boolean flag `jax_array` to something true-like if your code
    parses flags with absl;
*   using this statement at the top of your main file:

```
import jax
jax.config.update('jax_array', True)
```

##### How do I know if jax.Array broke my code? {#how-do-i-know-if-jax-array-broke-my-code}

The easiest way to tell if `jax.Array` is responsible for any problems is to
disable `jax.Array` and see if the issues go away.

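For instance, a quick comparison in a standalone script (a sketch; this
assumes the flag can be flipped near the top of your program, as described in
the surrounding sections):

```
import jax

# Flip this between True and False and re-run the failing code; if the
# problem disappears with jax_array disabled, jax.Array is likely the cause.
jax.config.update('jax_array', False)
```
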
##### How can I disable jax.Array for now? {#how-can-i-disable-jax-array-for-now}

You can disable `jax.Array` by any of the following (note that after a certain
date, TBD, the option to disable `jax.Array` will no longer exist):

*   setting the shell environment variable `JAX_ARRAY` to something falsey
    (e.g., `0`);
*   setting the boolean flag `jax_array` to something falsey if your code
    parses flags with absl;
*   using this statement at the top of your main file:

```
import jax
jax.config.update('jax_array', False)
```

#### Why create jax.Array? {#why-create-jax-array}

Currently JAX has three types: `DeviceArray`, `ShardedDeviceArray` and
`GlobalDeviceArray`. `jax.Array` merges these three types and cleans up JAX’s
internals while adding new parallelism features.

We also introduce a new `Sharding` abstraction that describes how a logical
Array is physically sharded out across one or more devices, such as TPUs or
GPUs. The change also upgrades, simplifies and merges the parallelism features
of `pjit` into `jit`. Functions decorated with `jit` will be able to operate
over sharded arrays without copying data onto a single device.

Features you get with `jax.Array`:

*   C++ `pjit` dispatch path
*   Op-by-op parallelism (even if the array is distributed across multiple
    devices across multiple hosts)
*   Simpler batch data parallelism with `pjit`/`jit`
*   Ways to create `Sharding`s that don't necessarily consist of a mesh and
    partition spec. You can fully utilize the flexibility of `OpSharding`, or
    use any other `Sharding` that you want.
*   and many more

Example:

```
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

x = jnp.arange(8)

# Let's say there are 8 devices in jax.devices()
mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x'))

sharded_x = jax.device_put(x, sharding)

# `mul_sharded_x` and `sin_sharded_x` are sharded. `jit` is able to operate over a
# sharded array without copying data to a single device.
mul_sharded_x = sharded_x @ sharded_x.T
sin_sharded_x = jnp.sin(sharded_x)

# Even jnp.copy preserves the sharding on the output.
copy_sharded_x = jnp.copy(sharded_x)

# double_out is also sharded
double_out = jax.jit(lambda x: x * 2)(sharded_x)
```

#### What issues can arise when jax.Array is switched on? {#what-issues-can-arise-when-jax-array-is-switched-on}

##### New public type named jax.Array {#new-public-type-named-jax-array}

All `isinstance(..., jnp.DeviceArray)` or `isinstance(..., jax.xla.DeviceArray)`
checks, and other variants of `DeviceArray`, should be switched to
`isinstance(..., jax.Array)`.

Since `jax.Array` can represent DA, SDA and GDA, you can differentiate those 3
types in `jax.Array` via:

*   `x.is_fully_addressable and len(x.sharding.device_set) == 1` -- this means
    that `jax.Array` is like a DA
*   `x.is_fully_addressable and len(x.sharding.device_set) > 1` -- this means
    that `jax.Array` is like a SDA
*   `not x.is_fully_addressable` -- this means that `jax.Array` is like a GDA
    and spans across multiple processes

For `ShardedDeviceArray`, you can change `isinstance(...,
pxla.ShardedDeviceArray)` to `isinstance(..., jax.Array) and
x.is_fully_addressable and len(x.sharding.device_set) > 1`.

In general it is not possible to differentiate a `ShardedDeviceArray` on 1
device from any other kind of single-device array.

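As a rough sketch, these checks can be wrapped in a small helper (the
`array_kind` function below is hypothetical, not part of JAX):

```
import jax

def array_kind(x: jax.Array) -> str:
  # Classify a jax.Array by which legacy type it resembles.
  if not x.is_fully_addressable:
    return 'GDA-like'          # spans multiple processes
  if len(x.sharding.device_set) > 1:
    return 'SDA-like'          # sharded across multiple local devices
  return 'DeviceArray-like'    # lives on a single device
```
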
##### GDA’s API name changes {#gdas-api-name-changes}

GDA’s `local_shards` and `local_data` have been deprecated.

Please use `addressable_shards` and `addressable_data`, which are compatible
with both `jax.Array` and `GDA`.

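For example, a minimal sketch of the rename (assuming `arr` is an existing
`jax.Array` or `GDA`):

```
# Before (GDA-only, deprecated):
#   for shard in gda.local_shards: ...

# After (works for both jax.Array and GDA):
for shard in arr.addressable_shards:
  print(shard.device, shard.data.shape)
```
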
##### Creating jax.Array {#creating-jax-array}

All JAX functions will output `jax.Array` when the `jax_array` flag is True. If
you were using the `GlobalDeviceArray.from_callback`, `make_sharded_device_array`
or `make_device_array` functions to explicitly create the respective JAX data
types, you will need to switch to `jax.make_array_from_callback` or
`jax.make_array_from_single_device_arrays`.

**For GDA:**

`GlobalDeviceArray.from_callback(shape, mesh, pspec, callback)` can become
`jax.make_array_from_callback(shape, jax.sharding.NamedSharding(mesh, pspec), callback)`
in a 1:1 switch.

If you were using the raw GDA constructor to create GDAs, then do this:

`GlobalDeviceArray(shape, mesh, pspec, buffers)` can become
`jax.make_array_from_single_device_arrays(shape, jax.sharding.NamedSharding(mesh, pspec), buffers)`

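As an illustration, here is a minimal sketch of the new call (the 4x2 mesh,
8-device assumption and `global_data` are illustrative):

```
import numpy as np
import jax
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
pspec = P('x', 'y')
global_shape = (8, 2)
global_data = np.arange(np.prod(global_shape)).reshape(global_shape)

# The callback receives each shard's index (a tuple of slices) into the
# global array and returns that shard's data.
arr = jax.make_array_from_callback(
    global_shape, jax.sharding.NamedSharding(mesh, pspec),
    lambda index: global_data[index])
```
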
**For SDA:**

`make_sharded_device_array(aval, sharding_spec, device_buffers, indices)` can
become `jax.make_array_from_single_device_arrays(shape, sharding, device_buffers)`.

What the sharding should be depends on why you were creating the SDAs:

If it was created to give as an input to `pmap`, then the sharding can be
`jax.sharding.PmapSharding(devices, sharding_spec)`.

If it was created to give as an input to `pjit`, then the sharding can be
`jax.sharding.NamedSharding(mesh, pspec)`.

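Similarly, a minimal sketch of assembling a `jax.Array` from per-device
buffers (the 1-D mesh over all local devices and the example data are
illustrative assumptions):

```
import numpy as np
import jax
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('x',))
sharding = jax.sharding.NamedSharding(mesh, P('x'))
global_shape = (len(jax.devices()) * 2, 4)
global_data = np.arange(np.prod(global_shape)).reshape(global_shape)

# Put each device's slice of the global data onto that device.
buffers = [
    jax.device_put(global_data[index], d)
    for d, index in sharding.addressable_devices_indices_map(global_shape).items()
]
arr = jax.make_array_from_single_device_arrays(global_shape, sharding, buffers)
```
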
##### Breaking change for pjit after switching to jax.Array for host local inputs {#breaking-change-for-pjit-after-switching-to-jax-array-for-host-local-inputs}

**If you are exclusively using GDA arguments to pjit, you can skip this section!
🎉**

With `jax.Array` enabled, all inputs to `pjit` must be globally shaped. This is
a breaking change from the previous behavior where `pjit` would concatenate
process-local arguments into a global value; this concatenation no longer
occurs.

Why are we making this breaking change? Each array now says explicitly how its
local shards fit into a global whole, rather than leaving it implicit. The more
explicit representation also unlocks additional flexibility, for example the
use of non-contiguous meshes with `pjit`, which can improve efficiency on some
TPU models.

Running a **multi-process pjit computation** and passing host-local inputs when
`jax.Array` is enabled can lead to an error similar to this:

Example:

Mesh = `{'x': 2, 'y': 2, 'z': 2}`, host local input shape == `(4,)` and
pspec = `P(('x', 'y', 'z'))`

Since `pjit` doesn’t lift host local shapes to global shapes with `jax.Array`,
you get the following error:

Note: You will only see this error if your host local shape is smaller than the
shape of the mesh.

```
ValueError: One of pjit arguments was given the sharding of
MeshPspecSharding(mesh={'x': 2, 'y': 2, 'z': 2}, partition_spec=PartitionSpec(('x', 'y', 'z'),)),
which implies that the global size of its dimension 0 should be divisible by 8,
but it is equal to 4
```

The error makes sense: you can't shard dimension 0 eight ways when its size is
only 4.

How can you migrate if you still pass host local inputs to `pjit`? We are
providing transitional APIs to help you migrate:

Note: You don't need these utilities if you run your pjitted computation on a
single process.

```
from jax.experimental import multihost_utils

# Cast host-local inputs to global jax.Arrays before calling pjit.
global_inps = multihost_utils.host_local_array_to_global_array(
    local_inputs, mesh, in_pspecs)

global_outputs = pjit(f, in_axis_resources=in_pspecs,
                      out_axis_resources=out_pspecs)(global_inps)

# Cast the global outputs back to host-local arrays if needed.
local_outs = multihost_utils.global_array_to_host_local_array(
    global_outputs, mesh, out_pspecs)
```

`host_local_array_to_global_array` is a type cast: it looks at a value with
only local shards and changes its local shape to the shape that `pjit` would
have previously assumed for that value before the change. For example, with
two processes and pspec `P('x')`, a per-process input of shape `(4,)` becomes
a global `jax.Array` of shape `(8,)`.

Passing in fully replicated inputs, i.e. the same shape on each process with
`P(None)` as `in_axis_resources`, is still supported. In this case you do not
have to use `host_local_array_to_global_array` because the shape is already
global.

```
key = jax.random.PRNGKey(1)

# As you can see, using host_local_array_to_global_array is not required
# since in_axis_resources says that the input is fully replicated via P(None).
pjit(f, in_axis_resources=None, out_axis_resources=None)(key)

# Mixing inputs
global_inp = multihost_utils.host_local_array_to_global_array(
    local_inp, mesh, P('data'))
global_out = pjit(f, in_axis_resources=(P(None), P('data')),
                  out_axis_resources=...)(key, global_inp)
```

##### FROM_GDA and jax.Array {#from_gda-and-jax-array}

If you were using `FROM_GDA` in the `in_axis_resources` argument to `pjit`,
then with `jax.Array` there is no need to pass anything to `in_axis_resources`,
as `jax.Array` follows **computation follows sharding** semantics.

For example:

```
# Before:
pjit(f, in_axis_resources=FROM_GDA, out_axis_resources=...)
# After:
pjit(f, out_axis_resources=...)
```

If you have `PartitionSpec`s mixed in with `FROM_GDA` for inputs like numpy
arrays, etc., then use `host_local_array_to_global_array` to convert them to
`jax.Array`.

For example, if you had this:

```
pjitted_f = pjit(
    f, in_axis_resources=(FROM_GDA, P('x'), FROM_GDA, P(None)),
    out_axis_resources=...)
pjitted_f(gda1, np_array1, gda2, np_array2)
```

then you can replace it with:

```
pjitted_f = pjit(f, out_axis_resources=...)

# Convert the numpy inputs to global jax.Arrays; the GDA inputs
# already carry their sharding.
array2, array4 = multihost_utils.host_local_array_to_global_array(
    (np_array1, np_array2), mesh, (P('x'), P(None)))

pjitted_f(gda1, array2, gda2, array4)
```

##### live\_buffers replaced with live\_arrays {#live_buffers-replaced-with-live_arrays}

The `live_buffers` attribute on jax `Device` has been deprecated. Please use
`jax.live_arrays()` instead, which is compatible with `jax.Array`.

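For instance, a minimal sketch of the replacement (the count you see depends
on what is alive in your process):

```
import jax
import jax.numpy as jnp

x = jnp.arange(4)  # keep a reference so the array stays alive

# jax.live_arrays() returns the jax.Arrays currently alive in this process,
# replacing per-device live_buffers.
print(len(jax.live_arrays()))
```
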
##### Handling of host local inputs to pjit like batch, etc {#handling-of-host-local-inputs-to-pjit-like-batch-etc}

If you are passing host local inputs to `pjit` in a **multi-process
environment**, then please use
`multihost_utils.host_local_array_to_global_array` to convert the batch to a
global `jax.Array` and then pass that to `pjit`.

The most common example of such a host local input is a **batch of input data**.

See this cl/486388872 as an example of how to do that.

This will work for any host local input (not just a batch of input data).

```
from jax.experimental import multihost_utils

batch = multihost_utils.host_local_array_to_global_array(
    batch, mesh, batch_partition_spec)
```

See the pjit section above for more details about this change and more examples.

##### RecursionError: Recursively calling jit {#recursionerror-recursively-calling-jit}

This happens when some part of your code has `jax.Array` disabled and you
enable it only for some other part. For example, if you use some third\_party
code which has `jax.Array` disabled, get a `DeviceArray` from that library,
then enable `jax.Array` in your own library and pass that `DeviceArray` to a
JAX function, it will lead to a `RecursionError`.

This error should go away when `jax.Array` is enabled by default, so that all
libraries return `jax.Array` unless they explicitly disable it.