mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix example code snippet and docstring of global_array_to_host_local_array
This commit is contained in:
parent
41a1f2cfdc
commit
c45bcee04d
@ -424,21 +424,27 @@ def global_array_to_host_local_array(
|
||||
|
||||
You can use this function to convert the globally shaped `jax.Array` output
|
||||
from pjit to host local values again so that the transition to jax.Array can
|
||||
be a mechanical change. Example usage
|
||||
be a mechanical change.
|
||||
|
||||
>> from jax.experimental import multihost_utils # doctest: +SKIP
|
||||
>>
|
||||
>> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs) # doctest: +SKIP
|
||||
>>
|
||||
>> with mesh: # doctest: +SKIP
|
||||
>> global_out = pjitted_fun(global_inputs) # doctest: +SKIP
|
||||
>>
|
||||
>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs) # doctest: +SKIP
|
||||
Example usage:
|
||||
|
||||
>>> from jax.experimental import multihost_utils # doctest: +SKIP
|
||||
>>>
|
||||
>>> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs) # doctest: +SKIP
|
||||
>>>
|
||||
>>> with mesh: # doctest: +SKIP
|
||||
... global_out = pjitted_fun(global_inputs) # doctest: +SKIP
|
||||
>>>
|
||||
>>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs) # doctest: +SKIP
|
||||
|
||||
Args:
|
||||
global_inputs: A Pytree of global jax.Array's.
|
||||
global_mesh: A jax.sharding.Mesh object.
|
||||
pspecs: A Pytree of jax.sharding.PartitionSpec's.
|
||||
global_mesh: A :class:`jax.sharding.Mesh` object. The mesh must be contiguous
|
||||
meaning all local devices of the host must form a subcube.
|
||||
pspecs: A Pytree of :class:`jax.sharding.PartitionSpec` objects.
|
||||
|
||||
Returns:
|
||||
A Pytree of host local arrays.
|
||||
"""
|
||||
flat_inps, out_tree = tree_flatten(global_inputs)
|
||||
out_pspecs = _flatten_pspecs('output pspecs', out_tree,
|
||||
|
Loading…
x
Reference in New Issue
Block a user