Move _array_shard_arg helpers from pxla into array.

Refactoring only; this fixes a TODO.

Add a canonicalize argument to pxla.shard_arg so we can call that API from array.py yet avoid double canonicalization.

PiperOrigin-RevId: 549658117
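For context, a minimal runnable sketch of what canonicalization means here (uses a JAX-internal module path as of this commit, and assumes the default jax_enable_x64=False config):

import numpy as np
from jax._src.interpreters import xla

# Canonicalization narrows 64-bit inputs to JAX's canonical dtypes, e.g.
# float64 -> float32 under the default 32-bit mode. Running it twice on the
# same value is wasted work, hence the new opt-out flag.
x = np.arange(4, dtype=np.float64)
print(xla.canonicalize_dtype(x).dtype)  # float32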
Peter Hawkins 2023-07-20 09:43:40 -07:00 committed by jax authors
parent 08366b21a1
commit fe30d3fd4b
2 changed files with 84 additions and 86 deletions

jax/_src/array.py

@@ -14,6 +14,7 @@
from __future__ import annotations
from collections import defaultdict
import math
import operator as op
import numpy as np
@@ -29,6 +30,7 @@ from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import profiler
from jax._src import tree_util
from jax._src import xla_bridge
from jax._src.config import config
from jax._src.lib import xla_client as xc
@@ -40,7 +42,7 @@ from jax._src.sharding_impls import (
SingleDeviceSharding, XLACompatibleSharding, PmapSharding,
device_replica_id_map, hashed_index)
from jax._src.typing import ArrayLike
from jax._src.util import use_cpp_class, use_cpp_method
from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method
Shape = tuple[int, ...]
Device = xc.Device
@@ -681,6 +683,80 @@ def _array_mlir_constant_handler(val, canonicalize_types=True):
mlir.register_constant_handler(ArrayImpl, _array_mlir_constant_handler)
# NOTE(skye): we could refactor to generate _multi_slice parameters directly
# from the input ShardingSpec, rather than the indices. However, this would
# require duplicating the ordering logic of spec_to_indices, which is more
# subtle and more likely to change than the index logic we have to support here.
def as_slice_indices(arr: Any, idx: Index) -> tuple[
tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
"""Returns start_indices, limit_indices, removed_dims"""
start_indices = [0] * arr.ndim
limit_indices = list(arr.shape)
removed_dims = []
tuple_idx = idx if isinstance(idx, tuple) else (idx,)
for dim, sub_idx in enumerate(tuple_idx):
if isinstance(sub_idx, int):
start_indices[dim] = sub_idx
limit_indices[dim] = sub_idx + 1
removed_dims.append(dim)
elif sub_idx == slice(None):
continue
else:
assert isinstance(sub_idx, slice), sub_idx
assert isinstance(sub_idx.start, int), sub_idx
assert isinstance(sub_idx.stop, int), sub_idx
start_indices[dim] = sub_idx.start
limit_indices[dim] = sub_idx.stop
return tuple(start_indices), tuple(limit_indices), tuple(removed_dims) # type: ignore
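A worked example of the index translation above, with hypothetical values and pure NumPy (not part of the diff):

import numpy as np

arr = np.arange(12).reshape(3, 4)
# Per-device index arr[0, 1:3]: the integer index drops dim 0, the slice
# keeps part of dim 1. as_slice_indices would yield:
start, limit, removed = (0, 1), (1, 3), (0,)
shard = arr[tuple(slice(s, l) for s, l in zip(start, limit))]
shard = shard.squeeze(axis=removed)  # drop the integer-indexed dimension
print(shard)  # [1 2], i.e. arr[0, 1:3]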
def shard_device_array(x, devices, indices, sharding):
start_indices, limit_indices, removed_dims = unzip3(
as_slice_indices(x, idx) for idx in indices)
shards = x._multi_slice(start_indices, limit_indices, removed_dims)
aval = api_util.shaped_abstractify(x)
out = pxla.batched_device_put(aval, sharding, shards, devices)
return out
def _hashable_index(idx):
return tree_util.tree_map(
lambda x: (x.start, x.stop) if type(x) == slice else x, idx)
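This helper exists because slice objects were unhashable before Python 3.12, so a raw per-shard index cannot key the candidates dict below; replacing each slice with its (start, stop) pair gives an equivalent hashable key. A standalone illustration (not part of the diff):

idx = (slice(0, 2), 3)
key = tuple((s.start, s.stop) if type(s) == slice else s for s in idx)
print(key)        # ((0, 2), 3)
d = {key: "buf"}  # usable as a dict key; a raw slice was not hashable before Python 3.12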
# The fast path is handled directly in shard_args().
def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
candidates = defaultdict(list)
if isinstance(x, ArrayImpl):
bufs = [buf.data for buf in x.addressable_shards]
arr_indices = tuple(x.sharding.devices_indices_map(x.shape).values())
else:
bufs = x.device_buffers
arr_indices = x.indices
for buf, idx in safe_zip(bufs, arr_indices):
candidates[_hashable_index(idx)].append(buf)
bufs = []
for idx, device in safe_zip(indices, devices):
# Look up all buffers that contain the correct slice of the logical array.
candidates_list = candidates[_hashable_index(idx)]
if not candidates_list:
# This array isn't sharded correctly. Reshard it via host roundtrip.
# TODO(skye): more efficient reshard?
return pxla.shard_arg(x._value, devices, indices, sharding,
canonicalize=False)
# Try to find a candidate buffer already on the correct device,
# otherwise copy one of them.
for buf in candidates_list:
if buf.device() == device:
bufs.append(buf)
break
else:
bufs.append(buf)
return pxla.batched_device_put(x.aval, sharding, bufs, devices)
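Schematically, the matching loop above prefers a candidate shard that already lives on the target device and only falls back to copying when none does. A toy stand-in with (device_id, index_key) pairs instead of real device buffers (hypothetical values, not part of the diff):

from collections import defaultdict

candidates = defaultdict(list)
for dev, key in [(0, ((0, 2),)), (1, ((0, 2),)), (2, ((2, 4),))]:
    candidates[key].append(dev)

wanted_key, wanted_dev = ((0, 2),), 1
for dev in candidates[wanted_key]:
    if dev == wanted_dev:
        chosen = dev  # shard already on the right device: no copy needed
        break
else:
    chosen = candidates[wanted_key][-1]  # any candidate; a copy is needed
print(chosen)  # 1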
def _array_shard_arg(x, devices, indices, sharding):
x._check_if_deleted()
@@ -697,10 +773,9 @@ def _array_shard_arg(x, devices, indices, sharding):
x, list(devices), sharding)
# Resharding starts here:
if dispatch.is_single_device_sharding(x.sharding):
return pxla.shard_device_array(x, devices, indices, sharding)
return shard_device_array(x, devices, indices, sharding)
else:
return pxla.shard_sharded_device_array_slow_path(
x, devices, indices, sharding)
return shard_sharded_device_array_slow_path(x, devices, indices, sharding)
pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg
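This registration is how shard_arg dispatches on argument type: shard_arg_handlers maps a Python type to a function of (arg, devices, indices, sharding). A hypothetical handler for a custom wrapper type would follow the same delegation pattern that _shard_darray uses for core.DArray in pxla.py:

from jax._src.interpreters import pxla  # internal module path as of this commit

# Sketch only; MyWrapper is a made-up type, not part of JAX.
class MyWrapper:
    def __init__(self, data):
        self.data = data  # payload JAX already knows how to shard

def _shard_my_wrapper(x, devices, indices, sharding):
    return pxla.shard_arg(x.data, devices, indices, sharding)

pxla.shard_arg_handlers[MyWrapper] = _shard_my_wrapper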

jax/_src/interpreters/pxla.py

@@ -17,7 +17,7 @@ from __future__ import annotations
import enum
from contextlib import contextmanager
from collections import defaultdict, namedtuple
from collections import namedtuple
import dataclasses
from functools import partial, lru_cache, cached_property
import itertools as it
@@ -30,7 +30,6 @@ import numpy as np
import jax
from jax.errors import JAXTypeError
from jax.tree_util import tree_map
from jax._src import api_util
from jax._src import core
@@ -67,7 +66,7 @@ from jax._src.sharding_impls import (
AUTO, UnspecifiedValue, UNSPECIFIED,
get_array_mapping as _get_array_mapping, is_auto, is_unspecified
)
from jax._src.util import (unzip3, safe_map, safe_zip, partition_list,
from jax._src.util import (safe_map, safe_zip, partition_list,
wrap_name, tuple_delete, distributed_debug_log,
unzip2, HashableFunction, weakref_lru_cache)
@@ -103,7 +102,7 @@ ShardingSpec = sharding_specs.ShardingSpec
def identity(x): return x
def shard_arg(arg, devices, arg_indices, sharding):
def shard_arg(arg, devices, arg_indices, sharding, canonicalize=True):
"""Returns a list of size len(devices) containing per-device buffers.
For the C++ pmap path, we fallback to Python (this function) to shard
@@ -114,7 +113,8 @@ def shard_arg(arg, devices, arg_indices, sharding):
devices: The list of devices to shard over.
arg_indices: A list of `len(devices)` indices to use to shard the argument.
"""
arg = xla.canonicalize_dtype(arg)
if canonicalize:
arg = xla.canonicalize_dtype(arg)
return shard_arg_handlers[type(arg)](arg, devices, arg_indices, sharding)
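The resulting call pattern, as a sketch (arg, devices, arg_indices, sharding, and x assumed in scope):

# Top-level entry: canonicalize the dtype exactly once.
bufs = shard_arg(arg, devices, arg_indices, sharding)

# Re-entrant call from array.py's slow resharding path: the value is already
# canonical, so skip the second pass.
bufs = shard_arg(x._value, devices, arg_indices, sharding, canonicalize=False)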
@@ -169,14 +169,6 @@ def _shard_darray(x, devices, indices, sharding):
return shard_arg(x._data, devices, indices, sharding)
shard_arg_handlers[core.DArray] = _shard_darray
def shard_device_array(x, devices, indices, sharding):
start_indices, limit_indices, removed_dims = unzip3(
as_slice_indices(x, idx) for idx in indices)
shards = x._multi_slice(start_indices, limit_indices, removed_dims)
aval = api_util.shaped_abstractify(x)
out = batched_device_put(aval, sharding, shards, devices)
return out
def batched_device_put(aval: core.ShapedArray,
sharding: jax.sharding.Sharding, xs: Sequence[Any],
devices: Sequence[jax.Device], committed: bool = True):
@@ -191,36 +183,6 @@ def batched_device_put(aval: core.ShapedArray,
aval, sharding, bufs, committed=committed, _skip_checks=True)
return xc.batched_device_put(aval, sharding, xs, devices, committed) # type: ignore
# NOTE(skye): we could refactor to generate _multi_slice parameters directly
# from the input ShardingSpec, rather than the indices. However, this would
# require duplicating the ordering logic of spec_to_indices, which is more
# subtle and more likely to change than the index logic we have to support here.
def as_slice_indices(arr: Any, idx: Index) -> tuple[
tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
"""Returns start_indices, limit_indices, removed_dims"""
start_indices = [0] * arr.ndim
limit_indices = list(arr.shape)
removed_dims = []
tuple_idx = idx if isinstance(idx, tuple) else (idx,)
for dim, sub_idx in enumerate(tuple_idx):
if isinstance(sub_idx, int):
start_indices[dim] = sub_idx
limit_indices[dim] = sub_idx + 1
removed_dims.append(dim)
elif sub_idx == slice(None):
continue
else:
assert isinstance(sub_idx, slice), sub_idx
assert isinstance(sub_idx.start, int), sub_idx
assert isinstance(sub_idx.stop, int), sub_idx
start_indices[dim] = sub_idx.start
limit_indices[dim] = sub_idx.stop
return tuple(start_indices), tuple(limit_indices), tuple(removed_dims) # type: ignore
def shard_aval(size, axis: int, aval):
try:
return shard_aval_handlers[type(aval)](size, axis, aval)
@@ -298,45 +260,6 @@ global_result_handlers: dict[type[core.AbstractValue], PxlaResultHandler] = {}
### lazy device-memory persistence and result handling
def _hashable_index(idx):
return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x, idx)
# The fast path is handled directly in shard_args().
# TODO(yashkatariya): Move this to array.py when SDA is deleted. The local
# import of Array should go away at that time.
def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
from jax._src.array import ArrayImpl
candidates = defaultdict(list)
if isinstance(x, ArrayImpl):
bufs = [buf.data for buf in x.addressable_shards]
arr_indices = tuple(x.sharding.devices_indices_map(x.shape).values())
else:
bufs = x.device_buffers
arr_indices = x.indices
for buf, idx in safe_zip(bufs, arr_indices):
candidates[_hashable_index(idx)].append(buf)
bufs = []
for idx, device in safe_zip(indices, devices):
# Look up all buffers that contain the correct slice of the logical array.
candidates_list = candidates[_hashable_index(idx)]
if not candidates_list:
# This array isn't sharded correctly. Reshard it via host roundtrip.
# TODO(skye): more efficient reshard?
return shard_arg_handlers[type(x._value)](
x._value, devices, indices, sharding)
# Try to find a candidate buffer already on the correct device,
# otherwise copy one of them.
for buf in candidates_list:
if buf.device() == device:
bufs.append(buf)
break
else:
bufs.append(buf)
return batched_device_put(x.aval, sharding, bufs, devices)
### the xla_pmap primitive and its rules are comparable to xla_call in xla.py