mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix sphinx syntax error.
This commit is contained in:
parent
de57b4fd36
commit
916ad35319
@ -257,35 +257,31 @@ def host_local_array_to_global_array_impl(
|
||||
|
||||
def host_local_array_to_global_array(
|
||||
local_inputs: Any, global_mesh: jax.sharding.Mesh, pspecs: Any):
|
||||
"""Converts a host local value to a globally sharded `jax.Array`.
|
||||
r"""Converts a host local value to a globally sharded jax.Array.
|
||||
|
||||
You can use this function to transition to `jax.Array`. Using `jax.Array` with
|
||||
`pjit` has the same semantics of using GDA with pjit i.e. all `jax.Array`
|
||||
You can use this function to transition to jax.Array. Using jax.Array with
|
||||
pjit has the same semantics of using GDA with pjit i.e. all jax.Array
|
||||
inputs to pjit should be globally shaped.
|
||||
|
||||
If you are currently passing host local values to pjit, you can use this
|
||||
function to convert your host local values to global Arrays and then pass that
|
||||
to pjit.
|
||||
to pjit. Example usage.
|
||||
|
||||
Example usage:
|
||||
>>> from jax.experimental import multihost_utils
|
||||
|
||||
```
|
||||
from jax.experimental import multihost_utils
|
||||
>>> global_inputs = multihost_utils.host_local_array_to_global_array(
|
||||
>>> host_local_inputs, global_mesh, in_pspecs)
|
||||
|
||||
global_inputs = multihost_utils.host_local_array_to_global_array(
|
||||
host_local_inputs, global_mesh, in_pspecs)
|
||||
>>> with mesh:
|
||||
>>> global_out = pjitted_fun(global_inputs)
|
||||
|
||||
with mesh:
|
||||
global_out = pjitted_fun(global_inputs)
|
||||
|
||||
host_local_output = multihost_utils.global_array_to_host_local_array(
|
||||
global_out, mesh, out_pspecs)
|
||||
```
|
||||
>>> host_local_output = multihost_utils.global_array_to_host_local_array(
|
||||
>>> global_out, mesh, out_pspecs)
|
||||
|
||||
Args:
|
||||
local_inputs: A Pytree of host local values.
|
||||
global_mesh: A ``jax.sharding.Mesh`` object.
|
||||
pspecs: A Pytree of ``jax.sharding.PartitionSpec``s.
|
||||
global_mesh: A jax.sharding.Mesh object.
|
||||
pspecs: A Pytree of jax.sharding.PartitionSpec's.
|
||||
"""
|
||||
flat_inps, in_tree = tree_flatten(local_inputs)
|
||||
in_pspecs = _flatten_pspecs('input pspecs', in_tree,
|
||||
@ -357,36 +353,32 @@ def global_array_to_host_local_array_impl(
|
||||
|
||||
def global_array_to_host_local_array(
|
||||
global_inputs: Any, global_mesh: jax.sharding.Mesh, pspecs: Any):
|
||||
"""Converts a global `jax.Array` to a host local `jax.Array`.
|
||||
r"""Converts a global `jax.Array` to a host local `jax.Array`.
|
||||
|
||||
You can use this function to transition to `jax.Array`. Using `jax.Array` with
|
||||
`pjit` has the same semantics of using GDA with pjit i.e. all `jax.Array`
|
||||
inputs to pjit should be globally shaped and the output from `pjit` will also
|
||||
be globally shaped `jax.Array`s
|
||||
pjit has the same semantics of using GDA with pjit i.e. all `jax.Array`
|
||||
inputs to pjit should be globally shaped and the output from pjit will also
|
||||
be globally shaped jax.Array's
|
||||
|
||||
You can use this function to convert the globally shaped `jax.Array` output
|
||||
from pjit to host local values again so that the transition to jax.Array can
|
||||
be a mechanical change.
|
||||
be a mechanical change. Example usage
|
||||
|
||||
Example usage:
|
||||
>>> from jax.experimental import multihost_utils
|
||||
|
||||
```
|
||||
from jax.experimental import multihost_utils
|
||||
>>> global_inputs = multihost_utils.host_local_array_to_global_array(
|
||||
>>> host_local_inputs, global_mesh, in_pspecs)
|
||||
|
||||
global_inputs = multihost_utils.host_local_array_to_global_array(
|
||||
host_local_inputs, global_mesh, in_pspecs)
|
||||
>>> with mesh:
|
||||
>>> global_out = pjitted_fun(global_inputs)
|
||||
|
||||
with mesh:
|
||||
global_out = pjitted_fun(global_inputs)
|
||||
|
||||
host_local_output = multihost_utils.global_array_to_host_local_array(
|
||||
global_out, mesh, out_pspecs)
|
||||
```
|
||||
>>> host_local_output = multihost_utils.global_array_to_host_local_array(
|
||||
>>> global_out, mesh, out_pspecs)
|
||||
|
||||
Args:
|
||||
global_inputs: A Pytree of global `jax.Array`s.
|
||||
global_mesh: A ``jax.sharding.Mesh`` object.
|
||||
pspecs: A Pytree of ``jax.sharding.PartitionSpec``s.
|
||||
global_inputs: A Pytree of global jax.Array's.
|
||||
global_mesh: A jax.sharding.Mesh object.
|
||||
pspecs: A Pytree of jax.sharding.PartitionSpec's.
|
||||
"""
|
||||
flat_inps, out_tree = tree_flatten(global_inputs)
|
||||
out_pspecs = _flatten_pspecs('output pspecs', out_tree,
|
||||
|
Loading…
x
Reference in New Issue
Block a user