Updated make_array_from_single_device_arrays docs
commit 03a8e5885b
parent 08837a9919
@@ -700,54 +700,76 @@ def make_array_from_callback(
 def make_array_from_single_device_arrays(
     shape: Shape, sharding: Sharding, arrays: Sequence[basearray.Array]
 ) -> ArrayImpl:
-  r"""Returns a ``jax.Array`` from a sequence of ``jax.Array``\s on a single device.
-
-  You can use this function if you have already ``jax.device_put`` the value on
-  a single device and want to create a global Array. The smaller ``jax.Array``\s should be
-  addressable and belong to the current process.
+  r"""Returns a ``jax.Array`` from a sequence of ``jax.Array``\s each on a single device.
+
+  Every device in input ``sharding``'s mesh must have an array in ``arrays``.

   Args:
-    shape : Shape of the ``jax.Array``.
-    sharding: A ``Sharding`` instance which describes how the ``jax.Array`` is
-      laid out across devices.
-    arrays: Sequence of ``jax.Array``\s that are on a single device.
+    shape: Shape of the output ``jax.Array``. This conveys information already included with
+      ``sharding`` and ``arrays`` and serves as a double check.
+    sharding: A global ``Sharding`` instance which describes how the output ``jax.Array`` is
+      laid out across devices.
+    arrays: Sequence of ``jax.Array``\s, each on a single addressable device. ``len(arrays)``
+      must equal ``len(sharding.addressable_devices)`` and the shape of each array must be the same. For multiprocess code,
+      each process will call with a different ``arrays`` argument that corresponds to that process's data.
+      These arrays are commonly created via ``jax.device_put``.

   Returns:
-    A ``jax.Array`` from a sequence of ``jax.Array``\s on a single device.
+    A global ``jax.Array``, sharded as ``sharding``, with shape equal to ``shape``, and with per-device
+      contents matching ``arrays``.

-  Example:
+  Examples:

+    In this single-process example, we use ``make_array_from_single_device_arrays`` to create
+    a global array.
+
     >>> import math
     >>> from jax.sharding import Mesh
     >>> from jax.sharding import PartitionSpec as P
     >>> import numpy as np
     ...
+    >>> mesh_rows = 2
+    >>> mesh_cols = jax.device_count() // 2
+    ...
     >>> global_shape = (8, 8)
-    >>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
-    >>> sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y'))
+    >>> mesh = Mesh(np.array(jax.devices()).reshape(mesh_rows, mesh_cols), ('x', 'y'))
+    >>> sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
     >>> inp_data = np.arange(math.prod(global_shape)).reshape(global_shape)
     ...
     >>> arrays = [
     ...   jax.device_put(inp_data[index], d)
     ...     for d, index in sharding.addressable_devices_indices_map(global_shape).items()]
     ...
     >>> arr = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays)
     >>> arr.addressable_data(0).shape
     (4, 2)
+    >>> assert arr.shape == (8, 8)  # arr.shape is (8, 8) regardless of jax.device_count()

-    In multi-process case, if the input is process local and data parallel
-    i.e. each process receives a different part of the data, then you can use
-    `make_array_from_single_device_arrays` to create a global jax.Array
+    When using multiple processes, a common data pipeline is to have data parallelism across devices,
+    with each device receiving at least one example. In this case, the following recipe will use
+    ``make_array_from_single_device_arrays`` to create a global ``jax.Array``.

+    >>> sharding = jax.sharding.NamedSharding(mesh, P(('x', 'y'),))
+    >>> rows_per_device = 2
+    >>> feature_length = 32
+    >>> per_device_shape = (rows_per_device, feature_length)
+    >>> per_host_shape = (rows_per_device * len(mesh.local_devices), feature_length)
+    >>> global_shape = (rows_per_device * jax.device_count(), ) + per_device_shape[1:]
+    >>> per_host_generator = lambda: np.arange(np.prod(per_host_shape)).reshape(per_host_shape)
+    ...
+    >>> per_host_data = per_host_generator()  # replace with your own per-host data pipeline that outputs numpy arrays
+    >>> per_device_data = np.split(per_host_data, len(mesh.local_devices), axis=0)  # per-device data, still on the host
+    >>> per_device_data_on_device = jax.device_put(per_device_data, mesh.local_devices)  # per-device data, now on device
+    >>> output_global_array = jax.make_array_from_single_device_arrays(global_shape, sharding, per_device_data_on_device)
+    ...
+    >>> assert output_global_array.addressable_data(0).shape == (rows_per_device, feature_length)
+    >>> assert output_global_array.shape == (rows_per_device * jax.device_count(), feature_length)
+
+    When using tensor parallelism (equivalent to sharding across both rows and columns in the
+    example above), this recipe does not generate the data in the sharding that you plan
+    to consume it with. The most common fix is to simply load the data in this data-parallel sharding
+    and have the reshard happen automatically within the downstream jitted function.
+    Depending on your use case, you might prefer to load sharded data directly, something that
+    ``make_array_from_single_device_arrays`` can do, but this requires your data loading pipeline
+    to also load in the matching sharding. Loading in a data-parallel format is typically
+    fully satisfactory for data loading for LLM use cases.

-    >>> local_shape = (8, 2)
-    >>> global_shape = (jax.process_count() * local_shape[0], ) + local_shape[1:]
-    >>> local_array = np.arange(math.prod(local_shape)).reshape(local_shape)
-    >>> arrays = jax.device_put(
-    ...   np.split(local_array, len(global_mesh.local_devices), axis = 0), global_mesh.local_devices)
-    >>> sharding = jax.sharding.NamedSharding(global_mesh, P(('x', 'y'), ))
-    >>> arr = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays)
-    >>> arr.addressable_data(0).shape
-    (1, 2)
"""
|
||||
# All input arrays should be committed. Checking it is expensive on
|
||||
# single-controller systems.
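
The Args entries above spell out a concrete contract: one input array per addressable device, each with the per-device shape implied by ``sharding``. As a supplementary sketch (not part of the commit above), the snippet below verifies that contract with ``Sharding.shard_shape`` before assembling the global array; it assumes a single-process run whose device count evenly divides the leading dimension of the global shape.

import math
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D-over-devices mesh; assumes jax.device_count() divides 8.
global_shape = (8, 8)
mesh = Mesh(np.array(jax.devices()).reshape(jax.device_count(), 1), ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))

# One piece per addressable device, each with the shard shape implied by `sharding`.
data = np.arange(math.prod(global_shape)).reshape(global_shape)
arrays = [jax.device_put(data[index], d)
          for d, index in sharding.addressable_devices_indices_map(global_shape).items()]
assert len(arrays) == len(sharding.addressable_devices)
assert all(a.shape == sharding.shard_shape(global_shape) for a in arrays)

arr = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays)
assert arr.shape == global_shape

On eight devices the shard shape here is (1, 8); the assembly code itself is unchanged for any device count that divides 8.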
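The tensor-parallelism note in the updated docstring suggests loading the batch in a data-parallel sharding and letting the reshard happen inside the downstream jitted function. The sketch below shows one way to express that with ``jax.lax.with_sharding_constraint``; the 2x4 mesh, the axis names 'data' and 'model', and the batch shape are illustrative assumptions, not taken from the docstring.

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative 2x4 mesh: 'data' for batch rows, 'model' for feature columns.
mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('data', 'model'))
load_sharding = NamedSharding(mesh, P(('data', 'model'),))   # data-parallel: rows over all devices
compute_sharding = NamedSharding(mesh, P('data', 'model'))   # tensor-parallel: rows x columns

# Load the batch in the data-parallel layout, as in the recipe above.
batch = np.arange(16 * 32, dtype=np.float32).reshape(16, 32)
arrays = [jax.device_put(batch[index], d)
          for d, index in load_sharding.addressable_devices_indices_map(batch.shape).items()]
x = jax.make_array_from_single_device_arrays(batch.shape, load_sharding, arrays)

@jax.jit
def f(a):
  # Ask for the tensor-parallel layout; the reshard from load_sharding
  # happens inside the compiled computation.
  a = jax.lax.with_sharding_constraint(a, compute_sharding)
  return a * 2.0

y = f(x)
assert y.shape == batch.shape

Because the constraint sits inside ``jit``, the transfer from the load-time layout to the compute layout is compiled into the same program rather than performed eagerly beforehand.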