Updates doc for host_local_array_to_global_array to reflect a few use case patterns and adds a few extra tests.

PiperOrigin-RevId: 596739035
Mark Sandler 2024-01-08 16:50:49 -08:00 committed by jax authors
parent da96633f11
commit 773e1499f1


@@ -261,13 +261,56 @@ def host_local_array_to_global_array(
local_inputs: Any, global_mesh: jax.sharding.Mesh, pspecs: Any):
r"""Converts a host local value to a globally sharded jax.Array.
This function takes host-local data (which might differ across hosts) and
populates a global array with this data, where each device on each host
gets the appropriate slice of the data according to the sharding defined
by global_mesh/pspecs.
For example:
>>> global_mesh = jax.sharding.Mesh(jax.devices(), 'x')
>>> pspecs = jax.sharding.PartitionSpec('x')
>>> host_id = jax.process_index()
>>> arr = host_local_array_to_global_array(np.arange(4) * host_id, global_mesh, pspecs) # NB: assumes jax.local_device_count() divides 4. # doctest: +SKIP
The resulting array will have shape (4 * num_processes,) and the
distributed value (0, 0, 0, 0, 0, 1, 2, 3, 0, 2, 4, 6, 0, 3, 6, 9, ...),
where each slice np.arange(4) * host_id is partitioned across the
corresponding host's devices.
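To see how the per-host slices assemble into that global value, here is a
pure-NumPy sketch (num_processes = 4 is a hypothetical process count, used
for illustration only):
>>> num_processes = 4  # hypothetical process count, for illustration only
>>> slices = [np.arange(4) * host_id for host_id in range(num_processes)]
>>> np.concatenate(slices)
array([0, 0, 0, 0, 0, 1, 2, 3, 0, 2, 4, 6, 0, 3, 6, 9])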
Similarly:
>>> mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(jax.process_count(), jax.local_device_count()), ['host', 'dev'])
>>> pspecs = jax.sharding.PartitionSpec('host')
>>> host_id = jax.process_index()
>>> arr = host_local_array_to_global_array(np.arange(4) * host_id, mesh, pspecs) # doctest: +SKIP
will create the same distributed value (0, 0, 0, 0, 0, 1, 2, 3, 0, 2, 4, 6, ...);
however, each slice np.arange(4) * host_id will be *replicated* across the
corresponding host's devices.
On the other hand, if pspecs = PartitionSpec(), which means
replication across all axes, then this snippet:
>>> pspecs = jax.sharding.PartitionSpec()
>>> arr = host_local_array_to_global_array(np.arange(4), mesh, pspecs) # doctest: +SKIP
will produce an array of shape (4,), with the value (0, 1, 2, 3) replicated
across all hosts and devices.
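You can verify the replication by inspecting the array's addressable
shards (a sketch that only runs in an actual multi-host setup):
>>> for shard in arr.addressable_shards:  # doctest: +SKIP
...   print(shard.device, shard.data)  # each local device holds the full (0, 1, 2, 3)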
It is undefined behavior to pass non-identical local_inputs when the
pspecs indicate data replication.
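To stay on the safe side, you can make the inputs identical before the
conversion. A minimal sketch using multihost_utils.broadcast_one_to_all
(here local_input is a hypothetical host-local value, and process 0's copy
is assumed to be the one you want everywhere):
>>> from jax.experimental import multihost_utils  # doctest: +SKIP
>>> local_input = multihost_utils.broadcast_one_to_all(local_input)  # process 0's value everywhere # doctest: +SKIP
>>> arr = host_local_array_to_global_array(local_input, mesh, jax.sharding.PartitionSpec())  # doctest: +SKIP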
You can use this function to transition to jax.Array. Using jax.Array with
pjit has the same semantics as using GDA with pjit, i.e. all jax.Array
inputs to pjit should be globally shaped.
If you are currently passing host local values to pjit, you can use this
function to convert your host local values to global Arrays and then pass that
to pjit.
Example usage.
>>> from jax.experimental import multihost_utils # doctest: +SKIP
>>>
@@ -278,10 +321,20 @@ def host_local_array_to_global_array(
>>>
>>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs) # doctest: +SKIP
Please note this function requires the global mesh to be a contiguous mesh,
meaning that the devices belonging to each host must form a subcube in the
mesh. To move local data to a global array with a non-contiguous mesh, use
jax.make_array_from_callback or jax.make_array_from_single_device_arrays
instead.
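For instance, here is a rough sketch of the
jax.make_array_from_single_device_arrays route (global_shape, sharding, and
global_data are illustrative placeholders; each host only needs to be able
to produce the data for its own addressable devices):
>>> sharding = jax.sharding.NamedSharding(mesh, pspecs)  # doctest: +SKIP
>>> idx_map = sharding.addressable_devices_indices_map(global_shape)  # global index per local device # doctest: +SKIP
>>> arrays = [jax.device_put(global_data[idx], d) for d, idx in idx_map.items()]  # doctest: +SKIP
>>> arr = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays)  # doctest: +SKIP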
Args:
local_inputs: A Pytree of host local values.
global_mesh: A jax.sharding.Mesh object. The mesh must be contiguous,
that is, each host's devices must form a subcube in this mesh.
pspecs: A Pytree of jax.sharding.PartitionSpec's.
Returns:
A pytree of global arrays.
"""
flat_inps, in_tree = tree_flatten(local_inputs)
in_pspecs = _flatten_pspecs('input pspecs', in_tree,