mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add typehints and point to the correct endpoint of Mesh and PartitionSpec in the args section.
PiperOrigin-RevId: 493035898
This commit is contained in:
parent
401fbb61a9
commit
b8b6e272d3
@ -18,6 +18,7 @@ import itertools as it
|
||||
from typing import Optional
|
||||
import zlib
|
||||
|
||||
from typing import Any
|
||||
import jax
|
||||
from jax.tree_util import tree_flatten, tree_map, tree_unflatten
|
||||
from jax._src import dispatch
|
||||
@ -251,7 +252,9 @@ def _device_put(x, device):
|
||||
raise TypeError(f"No device_put handler for type: {type(x)}") from err
|
||||
|
||||
|
||||
def host_local_array_to_global_array(local_inputs, global_mesh, pspecs):
|
||||
def host_local_array_to_global_array(local_inputs: Any,
|
||||
global_mesh: jax.sharding.Mesh,
|
||||
pspecs: Any):
|
||||
"""Converts a host local value to a globally sharded `jax.Array`.
|
||||
|
||||
You can use this function to transition to `jax.Array`. Using `jax.Array` with
|
||||
@ -279,8 +282,8 @@ def host_local_array_to_global_array(local_inputs, global_mesh, pspecs):
|
||||
|
||||
Args:
|
||||
local_inputs: A Pytree of host local values.
|
||||
global_mesh: The global mesh.
|
||||
pspecs: A Pytree of PartitionSpecs.
|
||||
global_mesh: A ``jax.sharding.Mesh`` object.
|
||||
pspecs: A Pytree of ``jax.sharding.PartitionSpec``s.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If `jax.config.jax_array` not previously enabled.
|
||||
@ -332,7 +335,9 @@ def host_local_array_to_global_array(local_inputs, global_mesh, pspecs):
|
||||
return tree_unflatten(in_tree, out)
|
||||
|
||||
|
||||
def global_array_to_host_local_array(global_inputs, global_mesh, pspecs):
|
||||
def global_array_to_host_local_array(global_inputs: Any,
|
||||
global_mesh: jax.sharding.Mesh,
|
||||
pspecs: Any):
|
||||
"""Converts a global `jax.Array` to a host local `jax.Array`.
|
||||
|
||||
You can use this function to transition to `jax.Array`. Using `jax.Array` with
|
||||
@ -361,8 +366,8 @@ def global_array_to_host_local_array(global_inputs, global_mesh, pspecs):
|
||||
|
||||
Args:
|
||||
global_inputs: A Pytree of global `jax.Array`s.
|
||||
global_mesh: The global mesh.
|
||||
pspecs: A Pytree of PartitionSpecs.
|
||||
global_mesh: A ``jax.sharding.Mesh`` object.
|
||||
pspecs: A Pytree of ``jax.sharding.PartitionSpec``s.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If `jax.config.jax_array` not previously enabled.
|
||||
|
Loading…
x
Reference in New Issue
Block a user