mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #13197 from google:yashk2810-patch-17
PiperOrigin-RevId: 487687006
This commit is contained in:
commit
f9d7a6ae20
@ -1,12 +1,10 @@
|
||||
## jax.Array migration
|
||||
# jax.Array migration
|
||||
|
||||
**yashkatariya@**
|
||||
|
||||
[TOC]
|
||||
## TL;DR
|
||||
|
||||
#### TL;DR {#tl-dr}
|
||||
|
||||
##### What’s going on? {#what’s-going-on}
|
||||
### What’s going on?
|
||||
|
||||
`jax.Array` is a unified array type that subsumes `DeviceArray`, `ShardedDeviceArray`,
|
||||
and `GlobalDeviceArray` types in JAX. The `jax.Array` type helps make parallelism a
|
||||
@ -20,7 +18,7 @@ After the migration is complete `jax.Array` will be the only type of array in
|
||||
JAX.
|
||||
|
||||
|
||||
##### How to enable jax.Array? {#how-to-enable-jax-array}
|
||||
### How to enable jax.Array?
|
||||
|
||||
You can enable `jax.Array` by:
|
||||
|
||||
@ -35,12 +33,12 @@ You can enable `jax.Array` by:
|
||||
jax.config.update('jax_array', True)
|
||||
```
|
||||
|
||||
##### How do I know if jax.Array broke my code? {#how-do-i-know-if-jax-array-broke-my-code}
|
||||
### How do I know if jax.Array broke my code?
|
||||
|
||||
The easiest way to tell if `jax.Array` is responsible for any problems is to
|
||||
disable `jax.Array` and see if the issues go away.
|
||||
|
||||
##### How can I disable jax.Array for now? {#how-can-i-disable-jax-array-for-now}
|
||||
### How can I disable jax.Array for now?
|
||||
|
||||
You can disable `jax.Array` by: (After a certain date (TBD), the option to
|
||||
disable jax.Array won't exist)
|
||||
@ -56,7 +54,7 @@ disable jax.Array won't exist)
|
||||
jax.config.update('jax_array', False)
|
||||
```
|
||||
|
||||
#### Why create jax.Array? {#why-create-jax-array}
|
||||
## Why create jax.Array?
|
||||
|
||||
Currently JAX has three types; `DeviceArray`, `ShardedDeviceArray` and
|
||||
`GlobalDeviceArray`. `jax.Array` merges these three types and cleans up JAX’s
|
||||
@ -105,9 +103,9 @@ copy_sharded_x = jnp.copy(sharded_x)
|
||||
double_out = jax.jit(lambda x: x * 2)(sharded_x)
|
||||
```
|
||||
|
||||
#### What issues can arise when jax.Array is switched on? {#what-issues-can-arise-when-jax-array-is-switched-on}
|
||||
## What issues can arise when jax.Array is switched on?
|
||||
|
||||
##### New public type named jax.Array {#new-public-type-named-jax-array}
|
||||
### New public type named jax.Array
|
||||
|
||||
All `isinstance(..., jnp.DeviceArray)` or `isinstance(.., jax.xla.DeviceArray)`
|
||||
and other variants of `DeviceArray` should be switched to using `isinstance(...,
|
||||
@ -130,14 +128,14 @@ x.is_fully_addressable and len(x.sharding.device_set) > 1`.
|
||||
In general it is not possible to differentiate a `ShardedDeviceArray` on 1
|
||||
device from any other kind of single-device Array.
|
||||
|
||||
##### GDA’s API name changes {#gda’s-api-name-changes}
|
||||
### GDA’s API name changes
|
||||
|
||||
GDA’s `local_shards` and `local_data` have been deprecated.
|
||||
|
||||
Please use `addressable_shards` and `addressable_data` which are compatible with
|
||||
`jax.Array` and `GDA`.
|
||||
|
||||
##### Creating jax.Array {#creating-jax-array}
|
||||
### Creating jax.Array
|
||||
|
||||
All JAX functions will output `jax.Array` when the `jax_array` flag is True. If
|
||||
you were using `GlobalDeviceArray.from_callback` or `make_sharded_device_array`
|
||||
@ -170,7 +168,7 @@ If it was created to give as an input to `pmap`, then sharding can be:
|
||||
If it was created to give as an input
|
||||
to `pjit`, then sharding can be `jax.sharding.NamedSharding(mesh, pspec)`.
|
||||
|
||||
##### Breaking change for pjit after switching to jax.Array for host local inputs {#breaking-change-for-pjit-after-switching-to-jax-array-for-host-local-inputs}
|
||||
### Breaking change for pjit after switching to jax.Array for host local inputs
|
||||
|
||||
**If you are exclusively using GDA arguments to pjit, you can skip this section!
|
||||
🎉**
|
||||
@ -252,7 +250,7 @@ global_out = pjit(f, in_axis_resources=(P(None), P('data')),
|
||||
out_axis_resources=...)(key, global_inp)
|
||||
```
|
||||
|
||||
##### FROM_GDA and jax.Array {#from_gda-and-jax-array}
|
||||
### FROM_GDA and jax.Array
|
||||
|
||||
If you were using `FROM_GDA` in `in_axis_resources` argument to `pjit`, then
|
||||
with `jax.Array` there is no need to pass anything to `in_axis_resources` as
|
||||
@ -291,14 +289,14 @@ array2, array3 = multihost_utils.host_local_array_to_global_array(
|
||||
pjitted_f(array1, array2, array3, array4)
|
||||
```
|
||||
|
||||
##### live\_buffers replaced with live\_arrays {#live_buffers-replaced-with-live_arrays}
|
||||
### live_buffers replaced with live_arrays
|
||||
|
||||
`live_buffers` attribute on jax `Device`
|
||||
has been deprecated. Please use `jax.live_arrays()` instead which is compatible
|
||||
with `jax.Array`.
|
||||
|
||||
|
||||
##### Handling of host local inputs to pjit like batch, etc {#handling-of-host-local-inputs-to-pjit-like-batch-etc}
|
||||
### Handling of host local inputs to pjit like batch, etc
|
||||
|
||||
If you are passing host local inputs to `pjit` in a **multi-process
|
||||
environment**, then please use
|
||||
@ -320,7 +318,7 @@ batch = multihost_utils.host_local_array_to_global_array(
|
||||
|
||||
See the pjit section above for more details about this change and more examples.
|
||||
|
||||
##### RecursionError: Recursively calling jit {#recursionerror-recursively-calling-jit}
|
||||
### RecursionError: Recursively calling jit
|
||||
|
||||
This happens when some part of your code has `jax.Array` disabled and then you
|
||||
enable it only for some other part. For example, if you use some third\_party
|
||||
|
Loading…
x
Reference in New Issue
Block a user