Try a more proper fix

This commit is contained in:
Frederic Bastien 2023-05-08 10:13:01 -07:00
parent 2af115782d
commit 3e9c9b8691

View File

@ -267,14 +267,16 @@ def host_local_array_to_global_array(
function to convert your host local values to global Arrays and then pass that
to pjit. Example usage.
>> from jax.experimental import multihost_utils
>>
>> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs)
>>
>> with mesh:
>> global_out = pjitted_fun(global_inputs)
>>
>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs)
```
>>> from jax.experimental import multihost_utils
>>>
>>> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs)
>>>
>>> with mesh:
>>> global_out = pjitted_fun(global_inputs)
>>>
>>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs)
```
Args:
local_inputs: A Pytree of host local values.