Make `device_put` work with inputs that are host-local when the sharding is a global sharding, i.e. a sharding that spans multiple hosts.

Use `multihost_utils.assert_equal` to check that the input is the same across all hosts.
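
A minimal usage sketch of the new behavior (not taken from this change; the 1-D mesh and the use of jax.distributed.initialize below are illustrative assumptions for a multi-process run):

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumes a multi-process setup (e.g. after jax.distributed.initialize()),
# so jax.devices() spans every host while jax.local_devices() does not.
mesh = Mesh(np.array(jax.devices()), ('x',))
global_sharding = NamedSharding(mesh, P('x'))  # not fully addressable on any one process

# Each process passes the same host-local numpy value. device_put now
# verifies this with multihost_utils.assert_equal and then produces a
# global jax.Array laid out according to global_sharding.
x = np.arange(8.0 * jax.device_count())  # size divisible by the device count
y = jax.device_put(x, global_sharding)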

Do some formatting fixes too ;)

PiperOrigin-RevId: 647711853
Yash Katariya 2024-06-28 09:43:41 -07:00 committed by jax authors
parent ba88601b9c
commit 061f4df82a
2 changed files with 23 additions and 24 deletions


@@ -40,6 +40,7 @@ from jax._src import util
 from jax._src import xla_bridge as xb
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
+from jax._src.abstract_arrays import array_types
 from jax._src.interpreters import mlir
 from jax._src.interpreters import xla
 from jax._src.interpreters import pxla
@@ -364,7 +365,8 @@ def _mcjax_reshard(x, target_sharding):
   new_x = array.make_array_from_single_device_arrays(
       x.shape,
-      GSPMDSharding(target_sharding._device_assignment, new_hlo_sharding),
+      GSPMDSharding(target_sharding._device_assignment, new_hlo_sharding,
+                    memory_kind=target_sharding.memory_kind),
       x._arrays,
   )
@@ -398,25 +400,33 @@ class _DeferredShardArg:
 def _device_put_sharding_impl(x, aval, device):
   from jax._src import array
+  from jax.experimental import multihost_utils
   if isinstance(device, Sharding):
     s = device
     if getattr(x, 'sharding', None) == s and getattr(x, '_committed', False):
       return x
     if (not s.is_fully_addressable and
         isinstance(x, array.ArrayImpl) and not x.is_fully_addressable):
       # This has to be XLACompatible because _mcjax_reshard will run a
       # XLA computation.
       assert isinstance(s, Sharding)
       return _mcjax_reshard(x, s)
     if not s.is_fully_addressable:
+      if ((isinstance(x, array.ArrayImpl) and not x._committed) or
+          type(x) in array_types):
+        # TODO(yashkatariya): Move this check to `jit`.
+        multihost_utils.assert_equal(
+            x, fail_message=(
+                f"{type(x)} passed to device_put is not the same on each"
+                " process. Make sure you are passing the same value of"
+                f" {type(x)} on each process."))
+        return api.jit(_identity_fn, out_shardings=s)(x)
       # TODO(yashkatariya,mattjj): Link to a doc about McJAX and jax.Array.
       raise ValueError(
           "device_put's second argument must be a Device or a Sharding which"
-          f" represents addressable devices, but got {s}. You are probably"
-          " trying to use device_put in multi-controller JAX which is not"
-          " supported. Please use jax.make_array_from_single_device_arrays API"
-          " or pass device or Sharding which represents addressable devices.")
+          f" represents addressable devices, but got {s}. Please pass device or"
+          " Sharding which represents addressable devices.")
     return _DeferredShardArg(x, s, aval, True)
   # Only `Device` exists below. `Sharding` instance is handled above.

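The _mcjax_reshard change above also threads the target's memory kind through the intermediate GSPMDSharding. A hedged sketch of where that matters; the 'pinned_host' memory kind and resharding by memory kind are illustrative assumptions and backend-dependent:

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumes a multi-process run where global_arr is already a committed,
# non-fully-addressable jax.Array (e.g. produced as in the earlier sketch).
mesh = Mesh(np.array(jax.devices()), ('x',))
device_sharding = NamedSharding(mesh, P('x'))
global_arr = jax.device_put(np.arange(8.0 * jax.device_count()), device_sharding)

# Resharding a global array with device_put goes through _mcjax_reshard;
# the intermediate GSPMDSharding now carries memory_kind, so the result
# should land in the requested memory space rather than the default one.
host_sharding = device_sharding.with_memory_kind('pinned_host')
arr_on_host = jax.device_put(global_arr, host_sharding)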

@@ -1296,10 +1296,7 @@ class NonUniformShardingError(ValueError):
 def get_process_index_and_count(
-    tensor_sharding: sharding.Sharding,
-    dim: int,
-    ndims: int,
-) -> tuple[int, int]:
+    tensor_sharding: sharding.Sharding, dim: int, ndims: int) -> tuple[int, int]:
   """Get current process index and number of unique processes for given dimension.
 
   This function facilitates mapping of process-level data to individual
@@ -1365,10 +1362,8 @@ def get_process_index_and_count(
   """
   # TODO(sandler, yashkatariya): Consider making this function public.
-  if (
-      tensor_sharding.is_fully_addressable
-      or tensor_sharding.is_fully_replicated
-  ):
+  if (tensor_sharding.is_fully_addressable or
+      tensor_sharding.is_fully_replicated):
     return (0, 1)
   num_devices = len(tensor_sharding.device_set)
   # Get device to indices map, we don't care about the concrete
@@ -1416,9 +1411,7 @@ def get_process_index_and_count(
 def local_to_global_shape(
-    sharding: sharding.Sharding,
-    local_shape: Shape,
-) -> tuple[int | None, ...]:
+    sharding: sharding.Sharding, local_shape: Shape) -> tuple[int | None, ...]:
   """Computes the global shape given the per process if possible.
 
   The returned shape will have the size of the global tensor in that dimension
@@ -1467,8 +1460,7 @@ def local_to_global_shape(
   for i, local_dim in enumerate(local_shape):
     try:
       _, shard_count = get_process_index_and_count(
-          sharding, i, ndims=len(local_shape)
-      )
+          sharding, i, ndims=len(local_shape))
       global_shape[i] = local_dim * shard_count
     except NonUniformShardingError:
       global_shape[i] = None
@@ -1478,10 +1470,7 @@ def local_to_global_shape(
 def num_addressable_indices(
-    tensor_sharding: sharding.Sharding,
-    dim: int,
-    global_shape: Shape,
-) -> int:
+    tensor_sharding: sharding.Sharding, dim: int, global_shape: Shape) -> int:
   """Returns the number of indices for given dimension this host has access to.
 
   Each host can have multiple number of devices that are spanning