Add docstrings for jax.Array APIs `make_array_from_callback` and
`make_array_from_single_device_arrays`.

PiperOrigin-RevId: 487929688

This commit is contained in:
  parent 19d76a7818
  commit 6897d37562
@@ -77,6 +77,14 @@ Automatic differentiation
   closure_convert
   checkpoint
 
+jax.Array (:code:`jax.Array`)
+-----------------------------
+
+.. autosummary::
+  :toctree: _autosummary
+
+  make_array_from_callback
+  make_array_from_single_device_arrays
 
 Vectorization (:code:`vmap`)
 ----------------------------
@@ -140,12 +140,12 @@ Please use `addressable_shards` and `addressable_data` which are compatible with
 
 All JAX functions will output `jax.Array` when the `jax_array` flag is True. If
 you were using `GlobalDeviceArray.from_callback` or `make_sharded_device_array`
 or `make_device_array` functions to explicitly create the respective JAX data
-types, you will need to switch them to use `jax.make_array_from_callback` or
-`jax.make_array_from_single_device_arrays`.
+types, you will need to switch them to use {func}`jax.make_array_from_callback`
+or {func}`jax.make_array_from_single_device_arrays`.
 
 **For GDA:**
 
 `GlobalDeviceArray.from_callback(shape, mesh, pspec, callback)` can become
 `jax.make_array_from_callback(shape, jax.sharding.NamedSharding(mesh, pspec), callback)`
 in a 1:1 switch.
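For readers following the migration note above, here is a minimal sketch of the
1:1 switch, not part of the commit. It assumes JAX as of this commit (where GDA
still lives under `jax.experimental.global_device_array`); `math.prod` and the
2x4 device mesh are illustrative choices of mine, not taken from the diff.

# Hedged sketch of the 1:1 GDA -> jax.Array switch; not part of the commit.
# Assumes 8 devices and the import paths of JAX at the time of this commit.
import math
import jax
import numpy as np
from jax.experimental.maps import Mesh
from jax.experimental import PartitionSpec as P
from jax.experimental.global_device_array import GlobalDeviceArray

shape = (8, 8)
mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
pspec = P('x', 'y')
data = np.arange(math.prod(shape)).reshape(shape)
callback = lambda index: data[index]

# Before: explicit GDA construction.
gda = GlobalDeviceArray.from_callback(shape, mesh, pspec, callback)

# After: identical arguments, with (mesh, pspec) folded into a NamedSharding.
arr = jax.make_array_from_callback(
    shape, jax.sharding.NamedSharding(mesh, pspec), callback)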
@@ -510,6 +510,40 @@ setattr(ArrayImpl, "__array_priority__", 100)
 def make_array_from_callback(
     shape: Shape, sharding: Sharding,
     data_callback: Callable[[Optional[Index]], ArrayLike]) -> ArrayImpl:
+  """Returns a ``jax.Array`` via data fetched from ``data_callback``.
+
+  ``data_callback`` is used to fetch the data for each addressable shard of the
+  returned ``jax.Array``.
+
+  Args:
+    shape: Shape of the ``jax.Array``.
+    sharding: A ``Sharding`` instance which describes how the ``jax.Array`` is
+      laid out across devices.
+    data_callback: Callback that takes indices into the global array value as
+      input and returns the corresponding data of the global array value. The
+      data can be returned as any array-like object, e.g. a ``numpy.ndarray``.
+
+  Returns:
+    A ``jax.Array`` via data fetched from ``data_callback``.
+
+  Example:
+
+    >>> from jax.experimental.maps import Mesh
+    >>> from jax.experimental import PartitionSpec as P
+    >>> import numpy as np
+    ...
+    >>> input_shape = (8, 8)
+    >>> global_input_data = np.arange(prod(input_shape)).reshape(input_shape)
+    >>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
+    >>> inp_sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y'))
+    ...
+    >>> def cb(index):
+    ...  return global_input_data[index]
+    ...
+    >>> arr = jax.make_array_from_callback(input_shape, inp_sharding, cb)
+    >>> arr.addressable_data(0).shape
+    (4, 2)
+  """
   device_to_index_map = sharding.devices_indices_map(shape)
   # Use addressable_devices here instead of `_addressable_device_assignment`
   # because `_addressable_device_assignment` is only available on
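The doctest above assumes eight devices and a `prod` helper that is in scope
inside `array.py`. As a hedged, self-contained variant (the `XLA_FLAGS`
host-device trick and `math.prod` are my substitutions, not part of the
commit), it can be run on a CPU-only machine like this:

# Self-contained sketch of the docstring example; not part of the commit.
# Assumption: no accelerators attached, so ask XLA for 8 host devices
# before importing jax.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import math
import jax
import numpy as np
from jax.experimental.maps import Mesh
from jax.experimental import PartitionSpec as P

input_shape = (8, 8)
global_input_data = np.arange(math.prod(input_shape)).reshape(input_shape)
global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
inp_sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y'))

# Each addressable shard is filled by indexing the global value.
arr = jax.make_array_from_callback(
    input_shape, inp_sharding, lambda index: global_input_data[index])
assert arr.addressable_data(0).shape == (4, 2)  # 8/2 rows, 8/4 columns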
@@ -525,6 +559,41 @@ def make_array_from_callback(
 
 def make_array_from_single_device_arrays(
     shape: Shape, sharding: Sharding, arrays: Sequence[ArrayImpl]) -> ArrayImpl:
+  r"""Returns a ``jax.Array`` from a sequence of ``jax.Array``\s on a single device.
+
+  A ``jax.Array`` on a single device is analogous to a ``DeviceArray``. You can
+  use this function if you have already used ``jax.device_put`` to place values
+  on single devices and want to assemble them into a global ``jax.Array``. The
+  smaller ``jax.Array``\s should be addressable and belong to the current process.
+
+  Args:
+    shape: Shape of the ``jax.Array``.
+    sharding: A ``Sharding`` instance which describes how the ``jax.Array`` is
+      laid out across devices.
+    arrays: Sequence of ``jax.Array``\s that are each on a single device.
+
+  Returns:
+    A ``jax.Array`` assembled from the given single-device ``jax.Array``\s.
+
+  Example:
+
+    >>> from jax.experimental.maps import Mesh
+    >>> from jax.experimental import PartitionSpec as P
+    >>> import numpy as np
+    ...
+    >>> shape = (8, 8)
+    >>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
+    >>> sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y'))
+    >>> inp_data = np.arange(prod(shape)).reshape(shape)
+    ...
+    >>> arrays = [
+    ...   jax.device_put(inp_data[index], d)
+    ...   for d, index in sharding.addressable_devices_indices_map(shape).items()]
+    ...
+    >>> arr = jax.make_array_from_single_device_arrays(shape, sharding, arrays)
+    >>> arr.addressable_data(0).shape
+    (4, 2)
+  """
   # All input arrays should be committed. Checking it is expensive on
   # single-controller systems.
   aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False)
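Similarly for `make_array_from_single_device_arrays`, a hedged, self-contained
variant of the docstring example (again, `XLA_FLAGS` and `math.prod` are my
substitutions), which also checks that the assembled array round-trips the
input data:

# Self-contained sketch of the docstring example; not part of the commit.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import math
import jax
import numpy as np
from jax.experimental.maps import Mesh
from jax.experimental import PartitionSpec as P

shape = (8, 8)
global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y'))
inp_data = np.arange(math.prod(shape)).reshape(shape)

# One committed single-device shard per addressable device.
arrays = [jax.device_put(inp_data[index], d)
          for d, index in sharding.addressable_devices_indices_map(shape).items()]

arr = jax.make_array_from_single_device_arrays(shape, sharding, arrays)
np.testing.assert_array_equal(np.asarray(arr), inp_data)  # data round-trips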