mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Try a more proper fix
This commit is contained in:
parent
2af115782d
commit
3e9c9b8691
@ -267,14 +267,16 @@ def host_local_array_to_global_array(
|
||||
function to convert your host local values to global Arrays and then pass that
|
||||
to pjit. Example usage.
|
||||
|
||||
>> from jax.experimental import multihost_utils
|
||||
>>
|
||||
>> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs)
|
||||
>>
|
||||
>> with mesh:
|
||||
>> global_out = pjitted_fun(global_inputs)
|
||||
>>
|
||||
>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs)
|
||||
```
|
||||
>>> from jax.experimental import multihost_utils
|
||||
>>>
|
||||
>>> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs)
|
||||
>>>
|
||||
>>> with mesh:
|
||||
>>> global_out = pjitted_fun(global_inputs)
|
||||
>>>
|
||||
>>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs)
|
||||
```
|
||||
|
||||
Args:
|
||||
local_inputs: A Pytree of host local values.
|
||||
|
Loading…
x
Reference in New Issue
Block a user