Merge pull request #18127 from rwitten:rwitten_make_array_from_single_device_arrays_docs

PiperOrigin-RevId: 599940102
jax authors 2024-01-19 14:35:50 -08:00
commit aaac4f93a8


@@ -716,54 +716,81 @@ def make_array_from_callback(
def make_array_from_single_device_arrays(
shape: Shape, sharding: Sharding, arrays: Sequence[basearray.Array]
) -> ArrayImpl:
r"""Returns a ``jax.Array`` from a sequence of ``jax.Array``\s on a single device.
You can use this function if you have already ``jax.device_put`` the value on
a single device and want to create a global Array. The smaller ``jax.Array``\s should be
addressable and belong to the current process.
r"""Returns a ``jax.Array`` from a sequence of ``jax.Array``\s each on a single device.
Every device in input ``sharding``\'s mesh must have an array in ``arrays``\s.
Args:
shape: Shape of the output ``jax.Array``. This conveys information already included in
  ``sharding`` and ``arrays`` and serves as a double check.
sharding: A global ``Sharding`` instance which describes how the output ``jax.Array``
  is laid out across devices.
arrays: Sequence of ``jax.Array``\s that are each addressable by a single device.
  ``len(arrays)`` must equal ``len(sharding.addressable_devices)`` and every array
  must have the same shape. For multiprocess code, each process calls with a
  different ``arrays`` argument corresponding to that process's data. These arrays
  are commonly created via ``jax.device_put``.
Returns:
A global ``jax.Array``, sharded according to ``sharding``, with shape equal to
``shape``, and with per-device contents matching ``arrays``.
Examples:
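As a minimal sketch of the degenerate case, a single device-local array can be wrapped with a ``jax.sharding.SingleDeviceSharding`` (this tiny example is illustrative only; the recipes below cover real multi-device use):
>>> import numpy as np
>>> device = jax.devices()[0]
>>> single_sharding = jax.sharding.SingleDeviceSharding(device)
>>> x = jax.device_put(np.arange(4.0), device)  # a single-device jax.Array
>>> y = jax.make_array_from_single_device_arrays((4,), single_sharding, [x])
>>> assert y.shape == (4,)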
In this single-process example, we use ``make_array_from_single_device_arrays`` to create a
global array.
>>> import math
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> mesh_rows = 2
>>> mesh_cols = jax.device_count() // 2
...
>>> global_shape = (8, 8)
>>> mesh = Mesh(np.array(jax.devices()).reshape(mesh_rows, mesh_cols), ('x', 'y'))
>>> sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
>>> inp_data = np.arange(math.prod(global_shape)).reshape(global_shape)
...
>>> arrays = [
... jax.device_put(inp_data[index], d)
... for d, index in sharding.addressable_devices_indices_map(global_shape).items()]
...
>>> arr = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays)
>>> arr.addressable_data(0).shape
(4, 2)
>>> assert arr.shape == (8,8) # arr.shape is (8,8) regardless of jax.device_count()
When using multiple processes, a common data pipeline has data parallelism across devices,
with each device receiving at least one example. In this case, the following recipe uses
``make_array_from_single_device_arrays`` to create a global ``jax.Array``.
First, we create the per-host data as NumPy arrays.
>>> sharding = jax.sharding.NamedSharding(mesh, P(('x', 'y'),))
>>> rows_per_device = 2
>>> feature_length = 32
>>> per_device_shape = (rows_per_device, feature_length)
>>> per_host_shape = (rows_per_device * len(mesh.local_devices), feature_length)
>>> per_host_generator = lambda: np.arange(np.prod(per_host_shape)).reshape(per_host_shape)
>>> per_host_data = per_host_generator() # replace with your own per-host data pipeline that outputs numpy arrays
Second, we put the NumPy data onto the local devices as single-device JAX arrays. Then we call
``make_array_from_single_device_arrays`` to make the global array.
>>> global_shape = (rows_per_device * len(sharding.device_set),) + per_device_shape[1:]
>>> per_device_data = np.split(per_host_data, len(mesh.local_devices), axis=0)  # per-device data, but still on host
>>> per_device_data_on_device = jax.device_put(per_device_data, mesh.local_devices)  # per-device data, now on device
>>> output_global_array = jax.make_array_from_single_device_arrays(global_shape, sharding, per_device_data_on_device)
...
>>> assert output_global_array.addressable_data(0).shape == per_device_shape
>>> assert output_global_array.shape == global_shape
When using tensor parallelism (equivalent to sharding across both rows and columns in the
example above), the recipe above does not produce the data in the sharding that you plan
to consume it with. The most common fix is to simply load the data in this data-parallel
sharding and let the reshard happen automatically within the downstream jitted function,
as sketched below. Depending on your use case, you might prefer to load already-sharded
data directly; ``make_array_from_single_device_arrays`` supports this, but it requires
your data loading pipeline to produce each shard in the matching sharding. Loading in a
data-parallel format is typically fully satisfactory for LLM data loading.
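For example, a minimal sketch of that fix, assuming a hypothetical jitted consumer ``f`` and a hypothetical tensor-parallel sharding ``tp_sharding`` (neither appears in the recipe above):
>>> tp_sharding = jax.sharding.NamedSharding(mesh, P(None, ('x', 'y')))
>>> @jax.jit
... def f(x):
...   # The reshard from the data-parallel layout to tp_sharding happens
...   # automatically inside the jitted function.
...   return jax.lax.with_sharding_constraint(x, tp_sharding)
...
>>> resharded = f(output_global_array)
>>> assert resharded.shape == output_global_array.shape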
"""
# All input arrays should be committed. Checking it is expensive on
# single-controller systems.