Make device_put work with inputs that are host-local when the sharding is a global sharding, i.e. a sharding spanning multiple hosts.

Use `multihost_utils.assert_equal` to check that the input is the same across all hosts. Also do some formatting fixes.

PiperOrigin-RevId: 647711853
This commit is contained in:
parent
ba88601b9c
commit
061f4df82a
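
For illustration only (not part of the commit): a minimal sketch of the behavior this change enables, assuming a multi-process run (e.g. started via jax.distributed.initialize) with 8 devices arranged as a 2x4 mesh. The mesh layout, axis names, and shapes below are assumptions.

    # Sketch only: run on every process of a multi-process JAX job.
    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
    global_sharding = NamedSharding(mesh, P('x', 'y'))  # spans devices of all processes

    # Host-local input: a plain numpy array with the same value on every process.
    x = np.arange(16 * 16, dtype=np.float32).reshape(16, 16)

    # Previously this raised because the sharding is not fully addressable from any
    # single process; with this change, device_put first checks (via
    # multihost_utils.assert_equal) that `x` is identical across processes and then
    # shards it globally.
    y = jax.device_put(x, global_sharding)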
@@ -40,6 +40,7 @@ from jax._src import util
 from jax._src import xla_bridge as xb
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
+from jax._src.abstract_arrays import array_types
 from jax._src.interpreters import mlir
 from jax._src.interpreters import xla
 from jax._src.interpreters import pxla
@@ -364,7 +365,8 @@ def _mcjax_reshard(x, target_sharding):

   new_x = array.make_array_from_single_device_arrays(
       x.shape,
-      GSPMDSharding(target_sharding._device_assignment, new_hlo_sharding),
+      GSPMDSharding(target_sharding._device_assignment, new_hlo_sharding,
+                    memory_kind=target_sharding.memory_kind),
       x._arrays,
   )

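
The `GSPMDSharding` built during the cross-host reshard now carries the target sharding's `memory_kind` instead of dropping it. A hedged sketch of a call that would exercise this path, assuming a multi-process run and a backend that exposes a 'pinned_host' memory kind (both assumptions):

    # Sketch only: 2x4 mesh and 'pinned_host' memory kind are illustrative.
    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
    y = jax.device_put(jnp.zeros((16, 16)), NamedSharding(mesh, P('x', 'y')))

    # Resharding an existing multi-host array to a sharding that names a memory
    # kind: the GSPMDSharding built inside _mcjax_reshard now inherits
    # memory_kind='pinned_host' from the target instead of discarding it.
    target = NamedSharding(mesh, P('x'), memory_kind='pinned_host')
    z = jax.device_put(y, target)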
@@ -398,25 +400,33 @@ class _DeferredShardArg:

 def _device_put_sharding_impl(x, aval, device):
   from jax._src import array
+  from jax.experimental import multihost_utils

   if isinstance(device, Sharding):
     s = device
     if getattr(x, 'sharding', None) == s and getattr(x, '_committed', False):
       return x

     if (not s.is_fully_addressable and
         isinstance(x, array.ArrayImpl) and not x.is_fully_addressable):
       # This has to be XLACompatible because _mcjax_reshard will run a
       # XLA computation.
       assert isinstance(s, Sharding)
       return _mcjax_reshard(x, s)

     if not s.is_fully_addressable:
+      if ((isinstance(x, array.ArrayImpl) and not x._committed) or
+          type(x) in array_types):
+        # TODO(yashkatariya): Move this check to `jit`.
+        multihost_utils.assert_equal(
+            x, fail_message=(
+                f"{type(x)} passed to device_put is not the same on each"
+                " process. Make sure you are passing the same value of"
+                f" {type(x)} on each process."))
+        return api.jit(_identity_fn, out_shardings=s)(x)
       # TODO(yashkatariya,mattjj): Link to a doc about McJAX and jax.Array.
       raise ValueError(
           "device_put's second argument must be a Device or a Sharding which"
-          f" represents addressable devices, but got {s}. You are probably"
-          " trying to use device_put in multi-controller JAX which is not"
-          " supported. Please use jax.make_array_from_single_device_arrays API"
-          " or pass device or Sharding which represents addressable devices.")
+          f" represents addressable devices, but got {s}. Please pass device or"
+          " Sharding which represents addressable devices.")
     return _DeferredShardArg(x, s, aval, True)

   # Only `Device` exists below. `Sharding` instance is handled above.
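
The guard added above relies on `jax.experimental.multihost_utils.assert_equal`, which every process must call collectively. A minimal sketch of that check in isolation (the value and message below are illustrative):

    # Sketch only: must be executed by every process in a multi-process run.
    import numpy as np
    from jax.experimental import multihost_utils

    val = np.arange(8)  # host-local value, expected to be identical everywhere
    # Raises with the given message on mismatch; this is the same check that
    # device_put now applies before sharding a host-local input with a
    # non-fully-addressable Sharding.
    multihost_utils.assert_equal(val, fail_message="value differs across processes")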
@@ -1296,10 +1296,7 @@ class NonUniformShardingError(ValueError):


 def get_process_index_and_count(
-    tensor_sharding: sharding.Sharding,
-    dim: int,
-    ndims: int,
-) -> tuple[int, int]:
+    tensor_sharding: sharding.Sharding, dim: int, ndims: int) -> tuple[int, int]:
   """Get current process index and number of unique processes for given dimension.

   This function facilitates mapping of process-level data to individual
@@ -1365,10 +1362,8 @@ def get_process_index_and_count(
   """
   # TODO(sandler, yashkatariya): Consider making this function public.

-  if (
-      tensor_sharding.is_fully_addressable
-      or tensor_sharding.is_fully_replicated
-  ):
+  if (tensor_sharding.is_fully_addressable or
+      tensor_sharding.is_fully_replicated):
     return (0, 1)
   num_devices = len(tensor_sharding.device_set)
   # Get device to indices map, we don't care about the concrete
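
For reference, a hedged illustration of what `get_process_index_and_count` returns, per its docstring above. This is an internal helper (see the TODO about making it public); the import path `jax._src.sharding_impls` is an assumption, as is the 2-process, 2x4-device mesh.

    # Illustrative sketch only; internal, unstable import path assumed.
    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
    from jax._src.sharding_impls import get_process_index_and_count  # assumed path

    mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
    s = NamedSharding(mesh, P('x', None))  # dim 0 split across the 2 process rows

    # Dim 0: each process owns one of 2 distinct slices -> (this process's slot, 2).
    print(get_process_index_and_count(s, dim=0, ndims=2))
    # Dim 1: replicated across processes -> a single process group, (0, 1).
    print(get_process_index_and_count(s, dim=1, ndims=2))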
@@ -1416,9 +1411,7 @@ def get_process_index_and_count(


 def local_to_global_shape(
-    sharding: sharding.Sharding,
-    local_shape: Shape,
-) -> tuple[int | None, ...]:
+    sharding: sharding.Sharding, local_shape: Shape) -> tuple[int | None, ...]:
   """Computes the global shape given the per process if possible.

   The returned shape will have the size of the global tensor in that dimension
@@ -1467,8 +1460,7 @@ def local_to_global_shape(
   for i, local_dim in enumerate(local_shape):
     try:
       _, shard_count = get_process_index_and_count(
-          sharding, i, ndims=len(local_shape)
-      )
+          sharding, i, ndims=len(local_shape))
       global_shape[i] = local_dim * shard_count
     except NonUniformShardingError:
       global_shape[i] = None
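
A hedged illustration of the loop above: each local dimension is scaled by the number of distinct process slices along it, and dimensions where that count is not uniform come back as None. Same assumed setup and internal import path as in the previous sketch.

    # Illustrative sketch only; internal, unstable import path assumed.
    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
    from jax._src.sharding_impls import local_to_global_shape  # assumed path

    mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'))

    # Each process holds an (8, 128) piece; dim 0 is split across the 2 processes,
    # dim 1 only across devices within a process, so the global shape is (16, 128).
    print(local_to_global_shape(s, local_shape=(8, 128)))  # expected: (16, 128)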
@@ -1478,10 +1470,7 @@ def local_to_global_shape(


 def num_addressable_indices(
-    tensor_sharding: sharding.Sharding,
-    dim: int,
-    global_shape: Shape,
-) -> int:
+    tensor_sharding: sharding.Sharding, dim: int, global_shape: Shape) -> int:
   """Returns the number of indices for given dimension this host has access to.

   Each host can have multiple number of devices that are spanning
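
And a matching hedged sketch for `num_addressable_indices`, which answers the inverse question per its docstring: how many indices along a dimension of the global shape this process's devices can address. Same assumed mesh and internal import path as above.

    # Illustrative sketch only; internal, unstable import path assumed.
    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
    from jax._src.sharding_impls import num_addressable_indices  # assumed path

    mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'))

    # Global shape (16, 128): dim 0 is split across the 2 processes, so each process
    # addresses 8 of the 16 indices; dim 1 stays entirely within one process (128).
    print(num_addressable_indices(s, dim=0, global_shape=(16, 128)))  # expected: 8
    print(num_addressable_indices(s, dim=1, global_shape=(16, 128)))  # expected: 128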