Updates doc for host_local_array_to_global_array to reflect a few use case patterns and adds a few extra tests.
PiperOrigin-RevId: 596739035
This commit is contained in:
parent da96633f11
commit 773e1499f1
@@ -261,13 +261,56 @@ def host_local_array_to_global_array(
local_inputs: Any, global_mesh: jax.sharding.Mesh, pspecs: Any):
r"""Converts a host local value to a globally sharded jax.Array.

This function takes host-local data (which might be different across hosts)
and populates a global array with it, so that each device on each host gets
the appropriate slice of the data according to the sharding defined by
global_mesh/pspecs.

For example:

>>> global_mesh = jax.sharding.Mesh(jax.devices(), 'x')
>>> pspecs = jax.sharding.PartitionSpec('x')
>>> host_id = jax.process_index()
>>> arr = host_local_array_to_global_array(np.arange(4) * host_id, global_mesh, pspecs)  # NB: assumes jax.local_device_count() divides 4.  # doctest: +SKIP

The resulting array will have shape (4 * num_processes,) and will hold the
distributed values (0, 1, 2, 3, 0, 2, 4, 6, 0, 3, 6, 9, ...), where each slice
np.arange(4) * host_id is partitioned across the corresponding host's devices.

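As an added illustration (not in the original docstring), and assuming the
example above has been run on every process, the per-host placement can be
inspected through the array's addressable shards; each process only sees the
shards stored on its own local devices:

>>> for shard in arr.addressable_shards:  # doctest: +SKIP
...   print(shard.device, shard.index, shard.data.shape)
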
Similarly:

>>> mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(jax.process_count(), jax.local_device_count()), ['host', 'dev'])
>>> pspecs = jax.sharding.PartitionSpec('host')
>>> host_id = jax.process_index()
>>> arr = host_local_array_to_global_array(np.arange(4) * host_id, mesh, pspecs)  # doctest: +SKIP

will create the same distributed values (0, 1, 2, 3, 0, 2, 4, 6, ...); however,
each slice np.arange(4) * host_id will be *replicated* across the corresponding
host's devices.

On the other hand, if pspecs = PartitionSpec(), which means replication across
all axes, then this snippet:

>>> pspecs = jax.sharding.PartitionSpec()
>>> arr = host_local_array_to_global_array(np.arange(4), mesh, pspecs)  # doctest: +SKIP

will produce an array of shape (4,) whose value (0, 1, 2, 3) is replicated
across all hosts and devices.

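A quick sanity check (an added illustration, not from the original text) is
that the resulting sharding reports itself as fully replicated:

>>> arr.sharding.is_fully_replicated  # doctest: +SKIP
True
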
It is undefined behavior to pass local_inputs that are not identical across
hosts when the pspecs indicate data replication.

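One way to guarantee identical inputs in that situation (a sketch added here
for illustration; it is not part of the original docstring) is to broadcast the
value from process 0 first, for example with
jax.experimental.multihost_utils.broadcast_one_to_all:

>>> from jax.experimental import multihost_utils  # doctest: +SKIP
>>> same_on_all_hosts = multihost_utils.broadcast_one_to_all(np.arange(4))  # value taken from process 0  # doctest: +SKIP
>>> arr = host_local_array_to_global_array(same_on_all_hosts, mesh, jax.sharding.PartitionSpec())  # doctest: +SKIP
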
You can use this function to transition to jax.Array. Using jax.Array with
pjit has the same semantics as using GDA with pjit, i.e. all jax.Array inputs
to pjit should be globally shaped.

If you are currently passing host local values to pjit, you can use this
function to convert your host local values to global Arrays and then pass those
to pjit.

Example usage.

>>> from jax.experimental import multihost_utils  # doctest: +SKIP
>>>
@@ -278,10 +321,20 @@ def host_local_array_to_global_array(
>>>
>>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs)  # doctest: +SKIP

Please note this function requires the global mesh to be a contiguous mesh,
meaning that the devices that belong to each host must form a subcube in the
mesh. To move local data to a global array with a non-contiguous mesh, use
jax.make_array_from_callback or jax.make_array_from_single_device_arrays
instead.

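For the non-contiguous case, a rough sketch of the jax.make_array_from_callback
route is shown below (added for illustration, not from the original docstring);
chunk_for is a hypothetical helper that returns this host's data for a
requested global index:

>>> sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('host'))  # doctest: +SKIP
>>> global_shape = (4 * jax.process_count(),)  # doctest: +SKIP
>>> # chunk_for(index) is hypothetical; the callback is only invoked for this host's devices.
>>> arr = jax.make_array_from_callback(global_shape, sharding, lambda index: chunk_for(index))  # doctest: +SKIP
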
Args:
  local_inputs: A Pytree of host local values.
  global_mesh: A jax.sharding.Mesh object. The mesh must be a contiguous mesh,
    that is, each host's devices must form a subcube in this mesh.
  pspecs: A Pytree of jax.sharding.PartitionSpec's.

Returns:
  A pytree of global arrays.
"""
flat_inps, in_tree = tree_flatten(local_inputs)
in_pspecs = _flatten_pspecs('input pspecs', in_tree,