Fix example code snippet and docstring of global_array_to_host_local_array

This commit is contained in:
rajasekharporeddy 2024-06-14 01:19:53 +05:30
parent 41a1f2cfdc
commit c45bcee04d

View File

@ -424,21 +424,27 @@ def global_array_to_host_local_array(
You can use this function to convert the globally shaped `jax.Array` output
from pjit to host local values again so that the transition to jax.Array can
be a mechanical change. Example usage
be a mechanical change.
>> from jax.experimental import multihost_utils # doctest: +SKIP
>>
>> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs) # doctest: +SKIP
>>
>> with mesh: # doctest: +SKIP
>> global_out = pjitted_fun(global_inputs) # doctest: +SKIP
>>
>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs) # doctest: +SKIP
Example usage:
>>> from jax.experimental import multihost_utils # doctest: +SKIP
>>>
>>> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs) # doctest: +SKIP
>>>
>>> with mesh: # doctest: +SKIP
... global_out = pjitted_fun(global_inputs) # doctest: +SKIP
>>>
>>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs) # doctest: +SKIP
Args:
global_inputs: A Pytree of global jax.Array's.
global_mesh: A jax.sharding.Mesh object.
pspecs: A Pytree of jax.sharding.PartitionSpec's.
global_mesh: A :class:`jax.sharding.Mesh` object. The mesh must be contiguous
meaning all local devices of the host must form a subcube.
pspecs: A Pytree of :class:`jax.sharding.PartitionSpec` objects.
Returns:
A Pytree of host local arrays.
"""
flat_inps, out_tree = tree_flatten(global_inputs)
out_pspecs = _flatten_pspecs('output pspecs', out_tree,