Mirror of https://github.com/ROCm/jax.git (synced 2025-04-15 19:36:06 +00:00)
Pmap should output SDA like Arrays to maintain the current behavior exactly. Split the shard_arg_handler for Array based on whether the mode is pmap or pjit. Why do this? The doc below explains more about the context.
PiperOrigin-RevId: 466849614
commit 33c4fc4fe2
parent 0a783ca156
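The shape of the change is easiest to see in isolation: a minimal, self-contained sketch of a type-keyed shard-arg registry whose handlers take an explicit mode enum and branch on it, so the pmap path can diverge from the pjit/xmap path. The toy handler and string device labels below are invented stand-ins for illustration, not JAX's real pxla machinery.

import enum
from typing import Any, Callable, Dict, Sequence


class InputsHandlerMode(enum.Enum):
  pmap = 0
  pjit_or_xmap = 1


# Handlers are registered per argument type; each one now also receives the
# mode so the pmap path can diverge from the pjit/xmap path.
shard_arg_handlers: Dict[type, Callable[..., Sequence[Any]]] = {}


def _shard_list(x, devices, indices, mode):
  if mode == InputsHandlerMode.pmap:
    # pmap: pair each shard with the device it is assigned to.
    return [(d, x[i]) for d, i in zip(devices, indices)]
  # pjit/xmap: assume the pieces are already where they belong.
  return [x[i] for i in indices]


shard_arg_handlers[list] = _shard_list


def shard_args(devices, indices, mode, args):
  # Look up the handler by argument type and thread the mode through.
  return [shard_arg_handlers[type(a)](a, devices, indices[i], mode)
          for i, a in enumerate(args)]


print(shard_args(["dev0", "dev1"], [[0, 2]], InputsHandlerMode.pmap,
                 [[10, 20, 30, 40]]))   # [[('dev0', 10), ('dev1', 30)]]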
@@ -1897,39 +1897,41 @@ class PmapCallInfo(NamedTuple):
 def _check_in_pmap_sharding_with_arrays(args, in_axes_flat, in_devices):
-  from jax.experimental import sharding
+  from jax.experimental.sharding import PmapSharding
   from jax.experimental.array import Array

   if not args:
     return

-  if in_devices is not None:
-    in_devices = np.array(in_devices)
-
-  first_arr_devices = args[0].sharding.devices
+  first_device_assignment = None
   for a, i in safe_zip(args, in_axes_flat):
-    assert isinstance(a.sharding, sharding.PmapSharding)
+    if not isinstance(a, Array):
+      continue
+    if not isinstance(a.sharding, PmapSharding):
+      raise NotImplementedError('pmap only works with PmapSharding.')
+    if first_device_assignment is None:
+      first_device_assignment = a.sharding._device_assignment
     arr_sharding = a.sharding.sharded_dim
-    arr_devices = a.sharding.devices
+    arr_device_assignment = a.sharding._device_assignment
     if arr_sharding != i:
       raise ValueError('Array and pmap sharding does not match. Got pmap '
                        f'sharding: {i}, Array sharding: {arr_sharding} for '
                        f'arg: {a}')
     if (in_devices is not None and
-        arr_devices is not None and
-        not np.array_equal(arr_devices, in_devices)):
+        arr_device_assignment is not None and
+        arr_device_assignment != in_devices):
       raise ValueError('Devices passed to pmap and Array should be equal. '
-                       f'Got pmap devices: {devices}, Array devices: '
-                       f'{arr_devices} for arg: {a}')
+                       f'Got pmap devices: {in_devices}, Array devices: '
+                       f'{arr_device_assignment} for arg: {a}')
     if (in_devices is None and
-        not np.array_equal(arr_devices, first_arr_devices)):
+        arr_device_assignment != first_device_assignment):
       raise ValueError('Devices of all `Array` inputs should be the same. '
-                       f'Got array device: {arr_devices}, '
-                       f'another array device: {first_arr_devices}')
-  return first_arr_devices
+                       f'Got array device: {arr_device_assignment}, '
+                       f'another array device: {first_device_assignment}')


 def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
-                  donate_tuple, global_arg_shapes, devices, args, kwargs):
+                  donate_tuple, global_arg_shapes, in_devices, args, kwargs):
   f = lu.wrap_init(fun)
   if static_broadcasted_tuple:
     if max(static_broadcasted_tuple) >= len(args):
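The check above enforces two invariants: when devices are passed to pmap, every input's device assignment must match them, and when they are not, all inputs must agree with each other. The stripped-down sketch below restates that pattern; ShardedInput is a made-up stand-in for an Array carrying a PmapSharding, not a JAX type.

from dataclasses import dataclass
from typing import List, Optional, Sequence


@dataclass
class ShardedInput:            # toy stand-in for an Array with a PmapSharding
  device_assignment: List[str]


def check_device_assignments(args: Sequence[ShardedInput],
                             in_devices: Optional[List[str]]) -> None:
  first_assignment = None
  for a in args:
    if first_assignment is None:
      first_assignment = a.device_assignment
    if in_devices is not None and a.device_assignment != in_devices:
      raise ValueError('Devices passed to pmap and the input should be equal. '
                       f'Got pmap devices: {in_devices}, input devices: '
                       f'{a.device_assignment}')
    if in_devices is None and a.device_assignment != first_assignment:
      raise ValueError('Devices of all inputs should be the same. '
                       f'Got {a.device_assignment} vs {first_assignment}')


check_device_assignments(
    [ShardedInput(["tpu:0", "tpu:1"]), ShardedInput(["tpu:0", "tpu:1"])],
    in_devices=None)           # passes: the assignments agree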
@@ -1971,13 +1973,7 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
   flat_fun, out_tree = flatten_fun(f, in_tree)

   if config.jax_array:
-    from jax.experimental.array import Array
-    if any(not isinstance(a, Array) for a in args):
-      raise ValueError('All arguments to pmap when `config.jax_array` is '
-                       'enabled should be `Array`s.')
-    arr_devices = _check_in_pmap_sharding_with_arrays(args, in_axes_flat, devices)
-    if devices is None and arr_devices is not None:
-      devices = arr_devices
+    _check_in_pmap_sharding_with_arrays(args, in_axes_flat, in_devices)

   if any(out_axis is None for out_axis in tree_flatten(out_axes)):
     raise NotImplementedError("None out_axes in pmap are not supported yet")
@@ -2011,7 +2007,7 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
       local_axis_size=local_axis_size,
       global_arg_shapes_flat=global_arg_shapes_flat,
       out_axes_thunk=out_axes_thunk,
-      devices=None if devices is None else tuple(devices))
+      devices=None if in_devices is None else tuple(in_devices))


 def _get_f_mapped(
@@ -2199,8 +2195,9 @@ def _cpp_pmap(
     return out, fastpath_data

-  cpp_mapped_f = pmap_lib.pmap(fun, cache_miss,
-                               static_broadcasted_tuple, pxla._shard_arg)
+  cpp_mapped_f = pmap_lib.pmap(
+      fun, cache_miss, static_broadcasted_tuple,
+      partial(pxla._shard_arg, mode=pxla.InputsHandlerMode.pmap))

   pmap_f = wraps(fun)(cpp_mapped_f)

@@ -21,7 +21,7 @@ from jax import core
 from jax._src import api_util
 from jax._src import dispatch
 from jax._src.config import config
-from jax._src.util import prod
+from jax._src.util import prod, safe_zip
 from jax._src.lib import xla_client as xc
 from jax._src.api import device_put
 from jax.interpreters import pxla, xla
@@ -261,12 +261,26 @@ def _device_put_array(x, device: Optional[Device]):
 dispatch.device_put_handlers[Array] = _device_put_array


-def _array_shard_arg(x, devices, indices):
-  return x._arrays
+def _array_shard_arg(x, devices, indices, mode):
+  # TODO(yashkatariya): Remove the `mode` handling and try to consolidate the
+  # code paths.
+  if mode == pxla.InputsHandlerMode.pmap:
+    # sharding mismatch between `Array` and pmap sharding is checked in api.py's
+    # `_check_in_pmap_sharding_with_arrays` function.
+    return [buf if buf.device() == d else buf.copy_to_device(d)
+            for buf, d in safe_zip(x._arrays, devices)]
+  else:
+    return x._arrays
 pxla.shard_arg_handlers[Array] = _array_shard_arg


-def _array_result_handler(global_aval, out_sharding):
+def _array_global_result_handler(global_aval, out_sharding):
   return lambda bufs: Array(global_aval.shape, out_sharding, bufs, committed=True)
-pxla.global_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_result_handler
-pxla.global_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_result_handler
+pxla.global_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_global_result_handler
+pxla.global_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_global_result_handler
+
+
+def _array_local_result_handler(aval, sharding, indices):
+  return lambda bufs: Array(aval.shape, sharding, bufs, committed=True)
+pxla.local_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_local_result_handler
+pxla.local_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_local_result_handler
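The pmap branch of _array_shard_arg above keeps a buffer in place when it already lives on the device assigned to that shard and copies it otherwise. A hedged, self-contained sketch of that per-buffer check follows, using invented Buffer and Device stand-ins rather than jaxlib types.

from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class Device:        # stand-in for a jaxlib device handle
  id: int


@dataclass
class Buffer:        # stand-in for a per-device buffer
  data: list
  _device: Device

  def device(self) -> Device:
    return self._device

  def copy_to_device(self, d: Device) -> "Buffer":
    # Pretend this performs a device-to-device transfer.
    return Buffer(list(self.data), d)


def shard_for_pmap(buffers: List[Buffer], devices: List[Device]) -> List[Buffer]:
  # Mirror of the pmap branch: keep a buffer where it is if it already sits on
  # the device assigned to that shard, otherwise copy it over.
  return [buf if buf.device() == d else buf.copy_to_device(d)
          for buf, d in zip(buffers, devices)]


bufs = [Buffer([0, 1], Device(0)), Buffer([2, 3], Device(2))]
out = shard_for_pmap(bufs, [Device(0), Device(1)])
print([b.device() for b in out])   # [Device(id=0), Device(id=1)]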
@@ -561,7 +561,9 @@ xla.canonicalize_dtype_handlers[GlobalDeviceArray] = pxla.identity
 api_util._shaped_abstractify_handlers[GlobalDeviceArray] = \
     lambda x: core.ShapedArray(x.shape, x.dtype)

-def _gda_shard_arg(x, devices, indices):
+def _gda_shard_arg(x, devices, indices, mode):
+  if mode == pxla.InputsHandlerMode.pmap:
+    raise RuntimeError('GDA is not supported with pmap.')
   return x._device_buffers
 pxla.shard_arg_handlers[GlobalDeviceArray] = _gda_shard_arg

@@ -368,7 +368,7 @@ def spec_to_indices(shape: Tuple[int, ...],

 def identity(x): return x

-def _shard_arg(arg, devices, arg_indices):
+def _shard_arg(arg, devices, arg_indices, mode):
   """Returns a list of size len(devices) containing per-device buffers.

   For the C++ pmap path, we fallback to Python (this function) to shard
@@ -378,6 +378,7 @@ def _shard_arg(arg, devices, arg_indices):
     arg: The Python argument.
     devices: The list of devices to shard over.
     arg_indices: A list of `len(devices)` indices to use to shard the argument.
+    mode: An enum telling whether shard_arg is executed via pmap or pjit/xmap.
   """
   if isinstance(arg, ShardedDeviceArray) and arg_indices == arg.indices:
     # The shard_arg_handlers allow an extensible set of types to be sharded, but
@@ -390,12 +391,13 @@ def _shard_arg(arg, devices, arg_indices):
     ]
   else:
     arg = xla.canonicalize_dtype(arg)
-    return shard_arg_handlers[type(arg)](arg, devices, arg_indices)
+    return shard_arg_handlers[type(arg)](arg, devices, arg_indices, mode)


 @profiler.annotate_function
 def shard_args(devices: Sequence[xb.xla_client.Device],
                indices: Sequence[Sequence[Index]],
+               mode: InputsHandlerMode,
                args) -> Sequence[Sequence[xb.xla_client.Buffer]]:
   """Shard each argument data array along its leading axis.

@@ -411,16 +413,16 @@ def shard_args(devices: Sequence[xb.xla_client.Device],
     A list of length matching args, containing lists of per-device buffers
     for each argument.
   """
-  return [_shard_arg(arg, devices, indices[i]) for i, arg in enumerate(args)]
+  return [_shard_arg(arg, devices, indices[i], mode) for i, arg in enumerate(args)]


-shard_arg_handlers: Dict[Any, Callable[[Any, Any, Any], Sequence[Any]]] = {}
-def _shard_array(x, devices, indices):
+shard_arg_handlers: Dict[Any, Callable[[Any, Any, Any, InputsHandlerMode], Sequence[Any]]] = {}
+def _shard_array(x, devices, indices, mode):
   return device_put([x[i] for i in indices], devices)
 for _t in array_types:
   shard_arg_handlers[_t] = _shard_array

-def _shard_device_array(x, devices, indices):
+def _shard_device_array(x, devices, indices, mode):
   start_indices, limit_indices, removed_dims = unzip3(
       _as_slice_indices(x, idx) for idx in indices)
   shards = x._multi_slice(start_indices, limit_indices, removed_dims)
@@ -524,9 +526,15 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
   return PartitionSpec(*partitions)


+class OutputType(enum.Enum):
+  Array = 0
+  GlobalDeviceArray = 1
+  ShardedDeviceArray = 2
+
+
 def local_aval_to_result_handler(
     aval: core.AbstractValue,
-    sharding_spec: Optional[ShardingSpec],
+    sharding: XLACompatibleSharding,
     indices: Optional[Tuple[Index]],
 ) -> Callable[[List[xb.xla_client.Buffer]], Any]:
   """Returns a function for handling the raw buffers of a single output aval.
@@ -543,24 +551,25 @@
     for this output. The function will return an object suitable for returning
     to the user, e.g. a ShardedDeviceArray.
   """
+  if config.jax_array:
+    output_type = OutputType.Array
+  else:
+    output_type = OutputType.ShardedDeviceArray
   try:
-    return local_result_handlers[type(aval)](aval, sharding_spec, indices)
+    return local_result_handlers[(type(aval), output_type)](aval, sharding, indices)
   except KeyError as err:
     raise TypeError(
         f"No pxla_result_handler for type: {type(aval)}") from err

 PxlaResultHandler = Callable[..., Callable[[List[xb.xla_client.Buffer]], Any]]
-local_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
-def sda_array_result_handler(aval: ShapedArray, sharding_spec, indices):
+local_result_handlers: Dict[Tuple[Type[core.AbstractValue], OutputType], PxlaResultHandler] = {}
+
+def sda_array_result_handler(aval: ShapedArray, sharding, indices):
+  sharding_spec = _get_sharding_specs([sharding], [aval])[0]
   return lambda bufs: make_sharded_device_array(aval, sharding_spec, bufs,
                                                 indices)
-local_result_handlers[ShapedArray] = sda_array_result_handler
-local_result_handlers[ConcreteArray] = sda_array_result_handler
-
-
-class OutputType(enum.Enum):
-  Array = 0
-  GlobalDeviceArray = 1
+local_result_handlers[(ShapedArray, OutputType.ShardedDeviceArray)] = sda_array_result_handler
+local_result_handlers[(ConcreteArray, OutputType.ShardedDeviceArray)] = sda_array_result_handler


 def global_aval_to_result_handler(
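To see why local_result_handlers is now keyed on an (aval type, OutputType) tuple, here is a small stand-alone sketch: the same aval type can map to different result objects depending on the configured output mode, and a missing combination surfaces as a TypeError, as above. The ShapedAval class and handler bodies are invented for illustration and are not the real pxla types.

import enum
from typing import Callable, Dict, Tuple, Type


class OutputType(enum.Enum):
  Array = 0
  GlobalDeviceArray = 1
  ShardedDeviceArray = 2


class ShapedAval:          # toy stand-in for an abstract value
  def __init__(self, shape):
    self.shape = shape


# Handlers are keyed on (aval type, output type), so the same aval type can
# produce different result objects depending on the configured output mode.
result_handlers: Dict[Tuple[Type, OutputType], Callable] = {}


def _make_sda(aval, bufs):
  return ("SDA", aval.shape, bufs)


def _make_array(aval, bufs):
  return ("Array", aval.shape, bufs)


result_handlers[(ShapedAval, OutputType.ShardedDeviceArray)] = _make_sda
result_handlers[(ShapedAval, OutputType.Array)] = _make_array


def aval_to_result_handler(aval, jax_array_enabled: bool):
  output_type = (OutputType.Array if jax_array_enabled
                 else OutputType.ShardedDeviceArray)
  try:
    handler = result_handlers[(type(aval), output_type)]
  except KeyError as err:
    raise TypeError(f"No result handler for type: {type(aval)}") from err
  return lambda bufs: handler(aval, bufs)


print(aval_to_result_handler(ShapedAval((8,)), jax_array_enabled=False)([b"buf"]))
# ('SDA', (8,), [b'buf'])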
@@ -839,7 +848,7 @@ def _hashable_index(idx):

 # The fast path is handled directly in shard_args().
 # TODO(skye): is there a simpler way to rewrite this using sharding_spec?
-def _shard_sharded_device_array_slow_path(x, devices, indices):
+def _shard_sharded_device_array_slow_path(x, devices, indices, mode):
   candidates = defaultdict(list)
   for buf, idx in safe_zip(x.device_buffers, x.indices):
     candidates[_hashable_index(idx)].append(buf)
@@ -851,7 +860,7 @@ def _shard_sharded_device_array_slow_path(x, devices, indices):
     if not candidates_list:
       # This array isn't sharded correctly. Reshard it via host roundtrip.
       # TODO(skye): more efficient reshard?
-      return shard_arg_handlers[type(x._value)](x._value, devices, indices)
+      return shard_arg_handlers[type(x._value)](x._value, devices, indices, mode)
     # Try to find a candidate buffer already on the correct device,
     # otherwise copy one of them.
     for buf in candidates_list:
@@ -1293,8 +1302,6 @@ class PmapExecutable(stages.XlaExecutable):
       ])

     local_arg_parts_ = parts.local_arg_parts or [None] * len(pci.avals)
-    # TODO(yashkatariya): Fix the input handling of `Array`s that span over
-    # multiple processes. Add multi-process tests for pmap.
     input_sharding_specs = [
         _pmap_sharding_spec(replicas.num_local_replicas, pci.axis_size,
                             parts.local_num_partitions, arg_parts, aval, in_axis)
@@ -1312,40 +1319,26 @@ class PmapExecutable(stages.XlaExecutable):
     if parts.local_out_parts is None:
       local_out_parts = (None,) * nouts

-    if config.jax_array:
-      global_unmapped_avals = [
-          core.unmapped_aval(pci.axis_size, pci.axis_name, out_axis, aval)
-          if out_axis is not None else aval
-          for aval, out_axis in safe_zip(shards.out_sharded_avals, pci.out_axes)]
-      global_out_specs = [
-          _pmap_sharding_spec(replicas.num_global_replicas, pci.axis_size,
-                              parts.num_partitions, op, aval, out_axis)
-          for op, aval, out_axis in safe_zip(
-              out_parts, shards.out_sharded_avals, pci.out_axes)]
-      pmap_shardings = _get_pmap_sharding(device_assignment, global_out_specs)
-      handle_outs = global_avals_to_results_handler(
-          global_unmapped_avals, pmap_shardings)
-    else:
-      local_out_avals = [
-          get_local_aval(aval, parts, lparts)
-          for aval, parts, lparts
-          in safe_zip(shards.out_sharded_avals, out_parts, local_out_parts)]
-      local_unmapped_avals = [
-          core.unmapped_aval(pci.axis_size, pci.axis_name, out_axis, aval)
-          if out_axis is not None else aval
-          for aval, out_axis in safe_zip(local_out_avals, pci.out_axes)]
-      out_specs = [
-          _pmap_sharding_spec(replicas.num_local_replicas, pci.axis_size,
-                              parts.local_num_partitions, out_parts, aval, out_axis)
-          for out_parts, aval, out_axis in safe_zip(
-              local_out_parts, local_out_avals, pci.out_axes)]
-      pmap_shardings = _get_pmap_sharding(local_device_assignment, out_specs)
-      handle_outs = local_avals_to_results_handler(local_unmapped_avals, pmap_shardings)
+    local_out_avals = [
+        get_local_aval(aval, parts, lparts)
+        for aval, parts, lparts
+        in safe_zip(shards.out_sharded_avals, out_parts, local_out_parts)]
+    local_unmapped_avals = [
+        core.unmapped_aval(pci.axis_size, pci.axis_name, out_axis, aval)
+        if out_axis is not None else aval
+        for aval, out_axis in safe_zip(local_out_avals, pci.out_axes)]
+    out_specs = [
+        _pmap_sharding_spec(replicas.num_local_replicas, pci.axis_size,
+                            parts.local_num_partitions, out_parts, aval, out_axis)
+        for out_parts, aval, out_axis in safe_zip(
+            local_out_parts, local_out_avals, pci.out_axes)]
+    pmap_shardings = _get_pmap_sharding(local_device_assignment, out_specs)
+    handle_outs = local_avals_to_results_handler(local_unmapped_avals, pmap_shardings)

     if hasattr(pci.backend, "compile_replicated"):
       execute_fun = pci.backend.compile_replicated(
           xla_computation, compile_options, pci.avals, input_indices,
-          in_shardings, handle_outs)
+          in_shardings, InputsHandlerMode.pmap, handle_outs)
       # TODO(frostig): need `compile_replicated` to give us the XLA executable
       return PmapExecutable(None, execute_fun, None, pci.avals)

@@ -1354,7 +1347,7 @@ class PmapExecutable(stages.XlaExecutable):
     compiled = dispatch.compile_or_get_cached(
         pci.backend, xla_computation, compile_options, host_callbacks)
     handle_args = InputsHandler(
-        compiled.local_devices(), in_shardings, input_indices)
+        compiled.local_devices(), in_shardings, input_indices, InputsHandlerMode.pmap)
     execute_fun = ExecuteReplicated(compiled, pci.backend, handle_args,
                                     handle_outs, unordered_effects, keepalive)
     fingerprint = getattr(compiled, "fingerprint", None)
@@ -1520,14 +1513,21 @@ def _safe_div(x, y):
   return result


-class InputsHandler:
-  __slots__ = ("handler", "local_devices", "in_shardings", "input_indices")
+class InputsHandlerMode(enum.Enum):
+  pmap = 0
+  pjit_or_xmap = 1

-  def __init__(self, local_devices, in_shardings, input_indices):
-    self.handler = partial(shard_args, local_devices, input_indices)
+
+class InputsHandler:
+  __slots__ = ("handler", "local_devices", "in_shardings", "input_indices",
+               "mode")
+
+  def __init__(self, local_devices, in_shardings, input_indices, mode):
+    self.handler = partial(shard_args, local_devices, input_indices, mode)
     self.local_devices = local_devices
     self.in_shardings = in_shardings
     self.input_indices = input_indices
+    self.mode = mode

   def __call__(self, input_buffers):
     return self.handler(input_buffers)
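A rough sketch of the wrapper pattern above: the devices, indices, and mode are baked into a functools.partial at construction time, so the call site only supplies the arguments. The shard_args stand-in below is a toy that tags shards with the mode rather than producing device buffers; none of this is JAX's real implementation.

import enum
from functools import partial


class InputsHandlerMode(enum.Enum):
  pmap = 0
  pjit_or_xmap = 1


def shard_args(devices, indices, mode, args):
  # Toy stand-in: tag each argument's shards with the mode they were
  # prepared under, instead of producing real device buffers.
  return [[(mode.name, d, a[i]) for d, i in zip(devices, idx)]
          for idx, a in zip(indices, args)]


class InputsHandler:
  __slots__ = ("handler", "local_devices", "input_indices", "mode")

  def __init__(self, local_devices, input_indices, mode):
    # Everything except the arguments themselves is fixed up front.
    self.handler = partial(shard_args, local_devices, input_indices, mode)
    self.local_devices = local_devices
    self.input_indices = input_indices
    self.mode = mode

  def __call__(self, args):
    return self.handler(args)


handle = InputsHandler(["dev0", "dev1"], [[0, 1]], InputsHandlerMode.pmap)
print(handle([[10, 20]]))   # [[('pmap', 'dev0', 10), ('pmap', 'dev1', 20)]]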
@@ -1536,7 +1536,8 @@ class InputsHandler:
     return ("InputsHandler(\n"
             f"local_devices={self.local_devices},\n"
             f"in_shardings={self.in_shardings},\n"
-            f"input_indices={self.input_indices})")
+            f"input_indices={self.input_indices})\n"
+            f"mode={self.mode}")


 class ResultsHandler:
@@ -1572,13 +1573,11 @@ def _get_sharding_specs(
 def local_avals_to_results_handler(
     unmapped_local_out_avals: Sequence[Optional[ShapedArray]],
     local_shardings: Sequence[XLACompatibleSharding]) -> ResultsHandler:
-  local_out_specs = _get_sharding_specs(
-      local_shardings, cast(Sequence[ShapedArray], unmapped_local_out_avals))
   out_indices = [tuple(s.devices_indices_map(aval.shape).values())
                  for s, aval in safe_zip(local_shardings, unmapped_local_out_avals)]
   handlers = [
-      local_aval_to_result_handler(aval, spec, idcs)
-      for aval, spec, idcs in safe_zip(unmapped_local_out_avals, local_out_specs, out_indices)
+      local_aval_to_result_handler(aval, s, idcs)
+      for aval, s, idcs in safe_zip(unmapped_local_out_avals, local_shardings, out_indices)
   ]
   return ResultsHandler(handlers, local_shardings, unmapped_local_out_avals)

@@ -2761,7 +2760,7 @@ class MeshExecutable(stages.XlaExecutable):
           global_out_avals, out_shardings)  # type: ignore  # arg-type
       unsafe_call = backend.compile_replicated(
           computation, compile_options, input_avals, input_indices,
-          in_shardings, handle_outs)
+          in_shardings, InputsHandlerMode.pjit_or_xmap, handle_outs)
       xla_executable = None
     else:
       with dispatch.log_elapsed_time(f"Finished XLA compilation of {name} "
@@ -2784,7 +2783,7 @@ class MeshExecutable(stages.XlaExecutable):
     handle_outs = global_avals_to_results_handler(
         global_out_avals, out_shardings)  # type: ignore  # arg-type
     handle_args = InputsHandler(xla_executable.local_devices(), in_shardings,
-                                input_indices)
+                                input_indices, InputsHandlerMode.pjit_or_xmap)
     unsafe_call = ExecuteReplicated(xla_executable, backend, handle_args,
                                     handle_outs, unordered_effects, keepalive)

@@ -189,6 +189,11 @@ jax_test(
 jax_test(
     name = "pjit_test",
     srcs = ["pjit_test.py"],
+    shard_count = {
+        "cpu": 5,
+        "gpu": 5,
+        "tpu": 5,
+    },
     tags = ["multiaccelerator"],
     deps = [
         "//jax:experimental",
@@ -2901,7 +2901,7 @@ class ShardArgsTest(jtu.JaxTestCase):
     x = np.arange(prod(shape)).reshape(shape)
     arg = make_arg(x)
     bufs = pxla.shard_args(jax.devices()[:nshards],
-                           [indices], [arg])
+                           [indices], pxla.InputsHandlerMode.pmap, [arg])
     self.assertEqual(len(bufs), 1)
     self.assertEqual(len(bufs[0]), nshards)
     for buf, idx in zip(bufs[0], indices):
Loading…
x
Reference in New Issue
Block a user