Try another sphinx fix.

This commit is contained in:
Frederic Bastien 2023-05-18 13:21:50 -07:00
parent 3e9c9b8691
commit 8ca40b2af6

View File

@ -267,16 +267,14 @@ def host_local_array_to_global_array(
function to convert your host local values to global Arrays and then pass that
to pjit. Example usage.
```
>>> from jax.experimental import multihost_utils
>>> from jax.experimental import multihost_utils # doctest: +SKIP
>>>
>>> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs)
>>> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs) # doctest: +SKIP
>>>
>>> with mesh:
>>> global_out = pjitted_fun(global_inputs)
>>> with mesh: # doctest: +SKIP
>>> global_out = pjitted_fun(global_inputs) # doctest: +SKIP
>>>
>>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs)
```
>>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs) # doctest: +SKIP
Args:
local_inputs: A Pytree of host local values.
@ -364,14 +362,14 @@ def global_array_to_host_local_array(
from pjit to host local values again so that the transition to jax.Array can
be a mechanical change. Example usage
>> from jax.experimental import multihost_utils
>> from jax.experimental import multihost_utils # doctest: +SKIP
>>
>> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs)
>> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs) # doctest: +SKIP
>>
>> with mesh:
>> global_out = pjitted_fun(global_inputs)
>> with mesh: # doctest: +SKIP
>> global_out = pjitted_fun(global_inputs) # doctest: +SKIP
>>
>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs)
>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs) # doctest: +SKIP
Args:
global_inputs: A Pytree of global jax.Array's.