Move host_local_array_to_global_array and global_array_to_host_local_array to multihost_utils.py.

PiperOrigin-RevId: 480662569
This commit is contained in:
Yash Katariya 2022-10-12 10:41:32 -07:00 committed by jax authors
parent 6ec35c1629
commit 84e740e744
2 changed files with 102 additions and 94 deletions

View File

@ -18,11 +18,13 @@ from typing import Optional
import zlib
import jax
from jax.tree_util import tree_flatten, tree_map, tree_unflatten
from jax._src import array
from jax._src import sharding
from jax.tree_util import PyTreeDef
from jax.interpreters import pxla
from jax.experimental import maps
from jax.experimental import pjit as pjit_lib
from jax.experimental.pjit import pjit, FROM_GDA
from jax.interpreters.pxla import PartitionSpec as P
from jax.experimental.global_device_array import GlobalDeviceArray
@ -224,3 +226,103 @@ def reached_preemption_sync_point(step_id: int) -> bool:
if sync_manager is None:
raise RuntimeError("Preemption sync manager has not been initialized.")
return sync_manager.reached_sync_point(step_id)
def host_local_array_to_global_array(local_inputs, global_mesh, pspecs):
"""Converts a host local value to a globally sharded `jax.Array`.
You can use this function to transition to `jax.Array`. Using `jax.Array` with
`pjit` has the same semantics of using GDA with pjit i.e. all `jax.Array`
inputs to pjit should be globally shaped.
If you are currently passing host local values to pjit, you can use this
function to convert your host local values to global Arrays and then pass that
to pjit.
Example usage:
```
from jax.experimental import multihost_utils
global_inputs = multihost_utils.host_local_array_to_global_array(
host_local_inputs, global_mesh, in_pspecs)
with mesh:
global_out = pjitted_fun(global_inputs)
host_local_output = multihost_utils.global_array_to_host_local_array(
global_out, mesh, out_pspecs)
```
Args:
local_inputs: A Pytree of host local values.
global_mesh: The global mesh.
pspecs: A Pytree of PartitionSpecs.
"""
def _convert(arr, pspec):
if isinstance(arr, array.ArrayImpl) and isinstance(
arr.sharding, jax.sharding.PmapSharding):
arr = np.array(arr)
local_sharding = jax.sharding.MeshPspecSharding(global_mesh.local_mesh, pspec)
arrays = [
jax.device_put(arr[index], d)
for d, index in local_sharding.devices_indices_map(arr.shape).items()
]
global_aval = global_mesh._local_to_global(
pxla._get_array_mapping(pspec),
jax.ShapedArray(arr.shape, arrays[0].dtype))
return array.ArrayImpl(
global_aval, jax.sharding.MeshPspecSharding(global_mesh, pspec),
arrays, committed=True)
flattened_inps, in_tree = tree_flatten(local_inputs)
in_pspecs = pjit_lib.flatten_axis_resources(
'input pspecs', in_tree, pspecs, tupled_args=True)
out = tree_map(_convert, tuple(flattened_inps), in_pspecs)
return tree_unflatten(in_tree, out)
def global_array_to_host_local_array(global_inputs, global_mesh, pspecs):
"""Converts a global `jax.Array` to a host local `jax.Array`.
You can use this function to transition to `jax.Array`. Using `jax.Array` with
`pjit` has the same semantics of using GDA with pjit i.e. all `jax.Array`
inputs to pjit should be globally shaped and the output from `pjit` will also
be globally shaped `jax.Array`s
You can use this function to convert the globally shaped `jax.Array` output
from pjit to host local values again so that the transition to jax.Array can
be a mechanical change.
Example usage:
```
from jax.experimental import multihost_utils
global_inputs = multihost_utils.host_local_array_to_global_array(
host_local_inputs, global_mesh, in_pspecs)
with mesh:
global_out = pjitted_fun(global_inputs)
host_local_output = multihost_utils.global_array_to_host_local_array(
global_out, mesh, out_pspecs)
```
Args:
global_inputs: A Pytree of global `jax.Array`s.
global_mesh: The global mesh.
pspecs: A Pytree of PartitionSpecs.
"""
def _convert(arr, pspec):
local_aval = global_mesh._global_to_local(
pxla._get_array_mapping(pspec), arr.aval)
return array.ArrayImpl(
local_aval, jax.sharding.MeshPspecSharding(global_mesh.local_mesh, pspec),
arr._arrays, committed=True)
flattened_inps, out_tree = tree_flatten(global_inputs)
out_pspecs = pjit_lib.flatten_axis_resources(
'output pspecs', out_tree, pspecs, tupled_args=True)
out = tree_map(_convert, tuple(flattened_inps), out_pspecs)
return tree_unflatten(out_tree, out)

View File

@ -1781,97 +1781,3 @@ def _get_pspec_from_executable(
out_partition_spec = _get_partition_spec(out_ppspec)
in_partition_spec = _get_partition_spec(in_ppspec)
return tuple(in_partition_spec), tuple(out_partition_spec)
def host_local_array_to_global_array(local_inputs, global_mesh, pspecs):
"""Converts a host local value to a globally sharded `jax.Array`.
You can use this function to transition to `jax.Array`. Using `jax.Array` with
`pjit` has the same semantics of using GDA with pjit i.e. all `jax.Array`
inputs to pjit should be globally shaped.
If you are currently passing host local values to pjit, you can use this
function to convert your host local values to global Arrays and then pass that
to pjit.
Example usage:
```
global_inputs = jax.experimental.pjit.host_local_array_to_global_array(
host_local_inputs, global_mesh, in_pspecs)
with mesh:
global_out = pjitted_fun(global_inputs)
host_local_output = jax.experimental.pjit.global_array_to_host_local_array(
global_out, mesh, out_pspecs)
```
Args:
local_inputs: A Pytree of host local values.
global_mesh: The global mesh.
pspecs: A Pytree of PartitionSpecs.
"""
def _convert(arr, pspec):
if isinstance(arr, array.ArrayImpl) and isinstance(arr.sharding, PmapSharding):
arr = np.array(arr)
local_sharding = MeshPspecSharding(global_mesh.local_mesh, pspec)
arrays = [
device_put(arr[index], d)
for d, index in local_sharding.devices_indices_map(arr.shape).items()
]
global_aval = global_mesh._local_to_global(
pxla._get_array_mapping(pspec),
core.ShapedArray(arr.shape, arrays[0].dtype))
return array.ArrayImpl(global_aval, MeshPspecSharding(global_mesh, pspec),
arrays, committed=True)
flattened_inps, in_tree = tree_flatten(local_inputs)
in_pspecs = flatten_axis_resources(
'input pspecs', in_tree, pspecs, tupled_args=True)
out = tree_map(_convert, tuple(flattened_inps), in_pspecs)
return tree_unflatten(in_tree, out)
def global_array_to_host_local_array(global_inputs, global_mesh, pspecs):
"""Converts a global `jax.Array` to a host local `jax.Array`.
You can use this function to transition to `jax.Array`. Using `jax.Array` with
`pjit` has the same semantics of using GDA with pjit i.e. all `jax.Array`
inputs to pjit should be globally shaped and the output from `pjit` will also
be globally shaped `jax.Array`s
You can use this function to convert the globally shaped `jax.Array` output
from pjit to host local values again so that the transition to jax.Array can
be a mechanical change.
Example usage:
```
global_inputs = jax.experimental.pjit.host_local_array_to_global_array(
host_local_inputs, global_mesh, in_pspecs)
with mesh:
global_out = pjitted_fun(global_inputs)
host_local_output = jax.experimental.pjit.global_array_to_host_local_array(
global_out, mesh, out_pspecs)
```
Args:
global_inputs: A Pytree of global `jax.Array`s.
global_mesh: The global mesh.
pspecs: A Pytree of PartitionSpecs.
"""
def _convert(arr, pspec):
local_aval = global_mesh._global_to_local(
pxla._get_array_mapping(pspec), arr.aval)
return array.ArrayImpl(
local_aval, MeshPspecSharding(global_mesh.local_mesh, pspec),
arr._arrays, committed=True)
flattened_inps, out_tree = tree_flatten(global_inputs)
out_pspecs = flatten_axis_resources(
'output pspecs', out_tree, pspecs, tupled_args=True)
out = tree_map(_convert, tuple(flattened_inps), out_pspecs)
return tree_unflatten(out_tree, out)