Add docstrings for jax.Array APIs make_array_from_callback and make_array_from_single_device_arrays.

PiperOrigin-RevId: 487929688
This commit is contained in:
Yash Katariya 2022-11-11 15:20:27 -08:00 committed by jax authors
parent 19d76a7818
commit 6897d37562
3 changed files with 80 additions and 3 deletions

View File

@ -77,6 +77,14 @@ Automatic differentiation
closure_convert
checkpoint
jax.Array (:code:`jax.Array`)
-----------------------------
.. autosummary::
:toctree: _autosummary
make_array_from_callback
make_array_from_single_device_arrays
Vectorization (:code:`vmap`)
----------------------------

View File

@ -140,12 +140,12 @@ Please use `addressable_shards` and `addressable_data` which are compatible with
All JAX functions will output `jax.Array` when the `jax_array` flag is True. If
you were using `GlobalDeviceArray.from_callback` or `make_sharded_device_array`
or `make_device_array` functions to explicitly create the respective JAX data
types, you will need to switch them to use `jax.make_array_from_callback` or
`jax.make_array_from_single_device_arrays`.
types, you will need to switch them to use {func}`jax.make_array_from_callback`
or {func}`jax.make_array_from_single_device_arrays`.
**For GDA:**
`GlobalDeviceArray.from_callback(shape, mesh, pspec, callback)` can become
`GlobalDeviceArray.from_callback(shape, mesh, pspec, callback)` can become
`jax.make_array_from_callback(shape, jax.sharding.NamedSharding(mesh, pspec), callback)`
in a 1:1 switch.

View File

@ -510,6 +510,40 @@ setattr(ArrayImpl, "__array_priority__", 100)
def make_array_from_callback(
shape: Shape, sharding: Sharding,
data_callback: Callable[[Optional[Index]], ArrayLike]) -> ArrayImpl:
"""Returns a ``jax.Array`` via data fetched from ``data_callback``.
``data_callback`` is used to fetch the data for each addressable shard of the
returned ``jax.Array``.
Args:
shape : Shape of the ``jax.Array``.
sharding: A ``Sharding`` instance which describes how the ``jax.Array`` is
laid out across devices.
data_callback : Callback that takes indices into the global array value as
input and returns the corresponding data of the global array value.
The data can be returned as any array-like object, e.g. a ``numpy.ndarray``.
Returns:
A ``jax.Array`` built from the data fetched via ``data_callback``.
Example:
>>> from jax.experimental.maps import Mesh
>>> from jax.experimental import PartitionSpec as P
>>> import numpy as np
...
>>> input_shape = (8, 8)
>>> global_input_data = np.arange(np.prod(input_shape)).reshape(input_shape)
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> inp_sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y'))
...
>>> def cb(index):
... return global_input_data[index]
...
>>> arr = jax.make_array_from_callback(input_shape, inp_sharding, cb)
>>> arr.addressable_data(0).shape
(4, 2)
"""
device_to_index_map = sharding.devices_indices_map(shape)
# Use addressable_devices here instead of `_addressable_device_assignment`
# because `_addressable_device_assignment` is only available on
@ -525,6 +559,41 @@ def make_array_from_callback(
def make_array_from_single_device_arrays(
shape: Shape, sharding: Sharding, arrays: Sequence[ArrayImpl]) -> ArrayImpl:
r"""Returns a ``jax.Array`` from a sequence of ``jax.Array``\s on a single device.
A ``jax.Array`` on a single device is analogous to a ``DeviceArray``. You can
use this function if you have already placed the values on single devices with
``jax.device_put`` and want to create a global Array from them. The smaller
``jax.Array``\s should be addressable and belong to the current process.
Args:
shape : Shape of the ``jax.Array``.
sharding: A ``Sharding`` instance which describes how the ``jax.Array`` is
laid out across devices.
arrays: Sequence of ``jax.Array``\s, each of which is on a single device.
Returns:
A ``jax.Array`` from a sequence of ``jax.Array``\s on a single device.
Example:
>>> from jax.experimental.maps import Mesh
>>> from jax.experimental import PartitionSpec as P
>>> import numpy as np
...
>>> shape = (8, 8)
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y'))
>>> inp_data = np.arange(np.prod(shape)).reshape(shape)
...
>>> arrays = [
... jax.device_put(inp_data[index], d)
... for d, index in sharding.addressable_devices_indices_map(shape).items()]
...
>>> arr = jax.make_array_from_single_device_arrays(shape, sharding, arrays)
>>> arr.addressable_data(0).shape
(4, 2)
"""
# All input arrays should be committed. Checking it is expensive on
# single-controller systems.
aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False)