mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Move host_local_array_to_global_array
and global_array_to_host_local_array
to multihost_utils.py
.
PiperOrigin-RevId: 480662569
This commit is contained in:
parent
6ec35c1629
commit
84e740e744
@ -18,11 +18,13 @@ from typing import Optional
|
||||
import zlib
|
||||
|
||||
import jax
|
||||
from jax.tree_util import tree_flatten, tree_map, tree_unflatten
|
||||
from jax._src import array
|
||||
from jax._src import sharding
|
||||
from jax.tree_util import PyTreeDef
|
||||
from jax.interpreters import pxla
|
||||
from jax.experimental import maps
|
||||
from jax.experimental import pjit as pjit_lib
|
||||
from jax.experimental.pjit import pjit, FROM_GDA
|
||||
from jax.interpreters.pxla import PartitionSpec as P
|
||||
from jax.experimental.global_device_array import GlobalDeviceArray
|
||||
@ -224,3 +226,103 @@ def reached_preemption_sync_point(step_id: int) -> bool:
|
||||
if sync_manager is None:
|
||||
raise RuntimeError("Preemption sync manager has not been initialized.")
|
||||
return sync_manager.reached_sync_point(step_id)
|
||||
|
||||
|
||||
def host_local_array_to_global_array(local_inputs, global_mesh, pspecs):
|
||||
"""Converts a host local value to a globally sharded `jax.Array`.
|
||||
|
||||
You can use this function to transition to `jax.Array`. Using `jax.Array` with
|
||||
`pjit` has the same semantics of using GDA with pjit i.e. all `jax.Array`
|
||||
inputs to pjit should be globally shaped.
|
||||
|
||||
If you are currently passing host local values to pjit, you can use this
|
||||
function to convert your host local values to global Arrays and then pass that
|
||||
to pjit.
|
||||
|
||||
Example usage:
|
||||
|
||||
```
|
||||
from jax.experimental import multihost_utils
|
||||
|
||||
global_inputs = multihost_utils.host_local_array_to_global_array(
|
||||
host_local_inputs, global_mesh, in_pspecs)
|
||||
|
||||
with mesh:
|
||||
global_out = pjitted_fun(global_inputs)
|
||||
|
||||
host_local_output = multihost_utils.global_array_to_host_local_array(
|
||||
global_out, mesh, out_pspecs)
|
||||
```
|
||||
|
||||
Args:
|
||||
local_inputs: A Pytree of host local values.
|
||||
global_mesh: The global mesh.
|
||||
pspecs: A Pytree of PartitionSpecs.
|
||||
"""
|
||||
def _convert(arr, pspec):
|
||||
if isinstance(arr, array.ArrayImpl) and isinstance(
|
||||
arr.sharding, jax.sharding.PmapSharding):
|
||||
arr = np.array(arr)
|
||||
local_sharding = jax.sharding.MeshPspecSharding(global_mesh.local_mesh, pspec)
|
||||
arrays = [
|
||||
jax.device_put(arr[index], d)
|
||||
for d, index in local_sharding.devices_indices_map(arr.shape).items()
|
||||
]
|
||||
global_aval = global_mesh._local_to_global(
|
||||
pxla._get_array_mapping(pspec),
|
||||
jax.ShapedArray(arr.shape, arrays[0].dtype))
|
||||
return array.ArrayImpl(
|
||||
global_aval, jax.sharding.MeshPspecSharding(global_mesh, pspec),
|
||||
arrays, committed=True)
|
||||
|
||||
flattened_inps, in_tree = tree_flatten(local_inputs)
|
||||
in_pspecs = pjit_lib.flatten_axis_resources(
|
||||
'input pspecs', in_tree, pspecs, tupled_args=True)
|
||||
out = tree_map(_convert, tuple(flattened_inps), in_pspecs)
|
||||
return tree_unflatten(in_tree, out)
|
||||
|
||||
|
||||
def global_array_to_host_local_array(global_inputs, global_mesh, pspecs):
|
||||
"""Converts a global `jax.Array` to a host local `jax.Array`.
|
||||
|
||||
You can use this function to transition to `jax.Array`. Using `jax.Array` with
|
||||
`pjit` has the same semantics of using GDA with pjit i.e. all `jax.Array`
|
||||
inputs to pjit should be globally shaped and the output from `pjit` will also
|
||||
be globally shaped `jax.Array`s
|
||||
|
||||
You can use this function to convert the globally shaped `jax.Array` output
|
||||
from pjit to host local values again so that the transition to jax.Array can
|
||||
be a mechanical change.
|
||||
|
||||
Example usage:
|
||||
|
||||
```
|
||||
from jax.experimental import multihost_utils
|
||||
|
||||
global_inputs = multihost_utils.host_local_array_to_global_array(
|
||||
host_local_inputs, global_mesh, in_pspecs)
|
||||
|
||||
with mesh:
|
||||
global_out = pjitted_fun(global_inputs)
|
||||
|
||||
host_local_output = multihost_utils.global_array_to_host_local_array(
|
||||
global_out, mesh, out_pspecs)
|
||||
```
|
||||
|
||||
Args:
|
||||
global_inputs: A Pytree of global `jax.Array`s.
|
||||
global_mesh: The global mesh.
|
||||
pspecs: A Pytree of PartitionSpecs.
|
||||
"""
|
||||
def _convert(arr, pspec):
|
||||
local_aval = global_mesh._global_to_local(
|
||||
pxla._get_array_mapping(pspec), arr.aval)
|
||||
return array.ArrayImpl(
|
||||
local_aval, jax.sharding.MeshPspecSharding(global_mesh.local_mesh, pspec),
|
||||
arr._arrays, committed=True)
|
||||
|
||||
flattened_inps, out_tree = tree_flatten(global_inputs)
|
||||
out_pspecs = pjit_lib.flatten_axis_resources(
|
||||
'output pspecs', out_tree, pspecs, tupled_args=True)
|
||||
out = tree_map(_convert, tuple(flattened_inps), out_pspecs)
|
||||
return tree_unflatten(out_tree, out)
|
||||
|
@ -1781,97 +1781,3 @@ def _get_pspec_from_executable(
|
||||
out_partition_spec = _get_partition_spec(out_ppspec)
|
||||
in_partition_spec = _get_partition_spec(in_ppspec)
|
||||
return tuple(in_partition_spec), tuple(out_partition_spec)
|
||||
|
||||
|
||||
def host_local_array_to_global_array(local_inputs, global_mesh, pspecs):
|
||||
"""Converts a host local value to a globally sharded `jax.Array`.
|
||||
|
||||
You can use this function to transition to `jax.Array`. Using `jax.Array` with
|
||||
`pjit` has the same semantics of using GDA with pjit i.e. all `jax.Array`
|
||||
inputs to pjit should be globally shaped.
|
||||
|
||||
If you are currently passing host local values to pjit, you can use this
|
||||
function to convert your host local values to global Arrays and then pass that
|
||||
to pjit.
|
||||
|
||||
Example usage:
|
||||
|
||||
```
|
||||
global_inputs = jax.experimental.pjit.host_local_array_to_global_array(
|
||||
host_local_inputs, global_mesh, in_pspecs)
|
||||
|
||||
with mesh:
|
||||
global_out = pjitted_fun(global_inputs)
|
||||
|
||||
host_local_output = jax.experimental.pjit.global_array_to_host_local_array(
|
||||
global_out, mesh, out_pspecs)
|
||||
```
|
||||
|
||||
Args:
|
||||
local_inputs: A Pytree of host local values.
|
||||
global_mesh: The global mesh.
|
||||
pspecs: A Pytree of PartitionSpecs.
|
||||
"""
|
||||
def _convert(arr, pspec):
|
||||
if isinstance(arr, array.ArrayImpl) and isinstance(arr.sharding, PmapSharding):
|
||||
arr = np.array(arr)
|
||||
local_sharding = MeshPspecSharding(global_mesh.local_mesh, pspec)
|
||||
arrays = [
|
||||
device_put(arr[index], d)
|
||||
for d, index in local_sharding.devices_indices_map(arr.shape).items()
|
||||
]
|
||||
global_aval = global_mesh._local_to_global(
|
||||
pxla._get_array_mapping(pspec),
|
||||
core.ShapedArray(arr.shape, arrays[0].dtype))
|
||||
return array.ArrayImpl(global_aval, MeshPspecSharding(global_mesh, pspec),
|
||||
arrays, committed=True)
|
||||
|
||||
flattened_inps, in_tree = tree_flatten(local_inputs)
|
||||
in_pspecs = flatten_axis_resources(
|
||||
'input pspecs', in_tree, pspecs, tupled_args=True)
|
||||
out = tree_map(_convert, tuple(flattened_inps), in_pspecs)
|
||||
return tree_unflatten(in_tree, out)
|
||||
|
||||
|
||||
def global_array_to_host_local_array(global_inputs, global_mesh, pspecs):
|
||||
"""Converts a global `jax.Array` to a host local `jax.Array`.
|
||||
|
||||
You can use this function to transition to `jax.Array`. Using `jax.Array` with
|
||||
`pjit` has the same semantics of using GDA with pjit i.e. all `jax.Array`
|
||||
inputs to pjit should be globally shaped and the output from `pjit` will also
|
||||
be globally shaped `jax.Array`s
|
||||
|
||||
You can use this function to convert the globally shaped `jax.Array` output
|
||||
from pjit to host local values again so that the transition to jax.Array can
|
||||
be a mechanical change.
|
||||
|
||||
Example usage:
|
||||
|
||||
```
|
||||
global_inputs = jax.experimental.pjit.host_local_array_to_global_array(
|
||||
host_local_inputs, global_mesh, in_pspecs)
|
||||
|
||||
with mesh:
|
||||
global_out = pjitted_fun(global_inputs)
|
||||
|
||||
host_local_output = jax.experimental.pjit.global_array_to_host_local_array(
|
||||
global_out, mesh, out_pspecs)
|
||||
```
|
||||
|
||||
Args:
|
||||
global_inputs: A Pytree of global `jax.Array`s.
|
||||
global_mesh: The global mesh.
|
||||
pspecs: A Pytree of PartitionSpecs.
|
||||
"""
|
||||
def _convert(arr, pspec):
|
||||
local_aval = global_mesh._global_to_local(
|
||||
pxla._get_array_mapping(pspec), arr.aval)
|
||||
return array.ArrayImpl(
|
||||
local_aval, MeshPspecSharding(global_mesh.local_mesh, pspec),
|
||||
arr._arrays, committed=True)
|
||||
|
||||
flattened_inps, out_tree = tree_flatten(global_inputs)
|
||||
out_pspecs = flatten_axis_resources(
|
||||
'output pspecs', out_tree, pspecs, tupled_args=True)
|
||||
out = tree_map(_convert, tuple(flattened_inps), out_pspecs)
|
||||
return tree_unflatten(out_tree, out)
|
||||
|
Loading…
x
Reference in New Issue
Block a user