Add typehints and point to the correct endpoint of Mesh and PartitionSpec in the args section.

PiperOrigin-RevId: 493035898
This commit is contained in:
Yash Katariya 2022-12-05 09:48:46 -08:00 committed by jax authors
parent 401fbb61a9
commit b8b6e272d3

View File

@ -18,6 +18,7 @@ import itertools as it
from typing import Optional
import zlib
from typing import Any
import jax
from jax.tree_util import tree_flatten, tree_map, tree_unflatten
from jax._src import dispatch
@ -251,7 +252,9 @@ def _device_put(x, device):
raise TypeError(f"No device_put handler for type: {type(x)}") from err
def host_local_array_to_global_array(local_inputs, global_mesh, pspecs):
def host_local_array_to_global_array(local_inputs: Any,
global_mesh: jax.sharding.Mesh,
pspecs: Any):
"""Converts a host local value to a globally sharded `jax.Array`.
You can use this function to transition to `jax.Array`. Using `jax.Array` with
@ -279,8 +282,8 @@ def host_local_array_to_global_array(local_inputs, global_mesh, pspecs):
Args:
local_inputs: A Pytree of host local values.
global_mesh: The global mesh.
pspecs: A Pytree of PartitionSpecs.
global_mesh: A ``jax.sharding.Mesh`` object.
pspecs: A Pytree of ``jax.sharding.PartitionSpec``s.
Raises:
RuntimeError: If `jax.config.jax_array` not previously enabled.
@ -332,7 +335,9 @@ def host_local_array_to_global_array(local_inputs, global_mesh, pspecs):
return tree_unflatten(in_tree, out)
def global_array_to_host_local_array(global_inputs, global_mesh, pspecs):
def global_array_to_host_local_array(global_inputs: Any,
global_mesh: jax.sharding.Mesh,
pspecs: Any):
"""Converts a global `jax.Array` to a host local `jax.Array`.
You can use this function to transition to `jax.Array`. Using `jax.Array` with
@ -361,8 +366,8 @@ def global_array_to_host_local_array(global_inputs, global_mesh, pspecs):
Args:
global_inputs: A Pytree of global `jax.Array`s.
global_mesh: The global mesh.
pspecs: A Pytree of PartitionSpecs.
global_mesh: A ``jax.sharding.Mesh`` object.
pspecs: A Pytree of ``jax.sharding.PartitionSpec``s.
Raises:
RuntimeError: If `jax.config.jax_array` not previously enabled.