mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Move _array_shard_arg helpers from pxla into array.
Refactoring only; this resolves a TODO. Add a `canonicalize` argument to pxla.shard_arg so that array can call that API without canonicalizing twice. PiperOrigin-RevId: 549658117
This commit is contained in:
parent
08366b21a1
commit
fe30d3fd4b
@ -14,6 +14,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
import math
|
||||
import operator as op
|
||||
import numpy as np
|
||||
@ -29,6 +30,7 @@ from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import profiler
|
||||
from jax._src import tree_util
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.config import config
|
||||
from jax._src.lib import xla_client as xc
|
||||
@ -40,7 +42,7 @@ from jax._src.sharding_impls import (
|
||||
SingleDeviceSharding, XLACompatibleSharding, PmapSharding,
|
||||
device_replica_id_map, hashed_index)
|
||||
from jax._src.typing import ArrayLike
|
||||
from jax._src.util import use_cpp_class, use_cpp_method
|
||||
from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method
|
||||
|
||||
Shape = tuple[int, ...]
|
||||
Device = xc.Device
|
||||
@ -681,6 +683,80 @@ def _array_mlir_constant_handler(val, canonicalize_types=True):
|
||||
mlir.register_constant_handler(ArrayImpl, _array_mlir_constant_handler)
|
||||
|
||||
|
||||
# NOTE(skye): we could refactor to generate _multi_slice parameters directly
|
||||
# from the input ShardingSpec, rather than the indices. However, this would
|
||||
# require duplicating the ordering logic of spec_to_indices, which is more
|
||||
# subtle and more likely to change than the index logic we have to support here.
|
||||
def as_slice_indices(arr: Any, idx: Index) -> tuple[
    tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
  """Returns start_indices, limit_indices, removed_dims

  Args:
    arr: any object exposing ``ndim`` and ``shape``.
    idx: a single index, or a tuple of per-dimension indices. Each entry must
      be an ``int`` or a ``slice``. Dimensions not covered by ``idx`` are
      taken whole.

  Returns:
    A ``(start_indices, limit_indices, removed_dims)`` triple. The first two
    have one entry per dimension of ``arr``; ``removed_dims`` lists the
    dimensions that were indexed with an ``int``.

  Raises:
    TypeError: if an entry is neither an ``int`` nor a ``slice``, or a slice
      bound is neither ``None`` nor an ``int``.
    ValueError: if a slice has a step other than ``None`` or ``1``, since a
      stride cannot be expressed as a (start, limit) pair.
  """
  start_indices = [0] * arr.ndim
  limit_indices = list(arr.shape)
  removed_dims = []

  tuple_idx = idx if isinstance(idx, tuple) else (idx,)
  for dim, sub_idx in enumerate(tuple_idx):
    if isinstance(sub_idx, int):
      start_indices[dim] = sub_idx
      limit_indices[dim] = sub_idx + 1
      removed_dims.append(dim)
      continue

    # Explicit checks instead of `assert`: asserts vanish under `python -O`,
    # which would let malformed indices flow into the result unnoticed.
    if not isinstance(sub_idx, slice):
      raise TypeError(
          f"index entries must be int or slice, got {sub_idx!r}")
    if sub_idx.step not in (None, 1):
      raise ValueError(f"strided slices are not supported, got {sub_idx!r}")
    for bound in (sub_idx.start, sub_idx.stop):
      if bound is not None and not isinstance(bound, int):
        raise TypeError(
            f"slice bounds must be int or None, got {sub_idx!r}")

    # An omitted bound keeps the default for that side (0 / the dim size), so
    # `slice(None)` leaves the dimension untouched.
    if sub_idx.start is not None:
      start_indices[dim] = sub_idx.start
    if sub_idx.stop is not None:
      limit_indices[dim] = sub_idx.stop

  return tuple(start_indices), tuple(limit_indices), tuple(removed_dims)  # type: ignore
|
||||
|
||||
|
||||
def shard_device_array(x, devices, indices, sharding):
  """Slices `x` once per entry of `indices` and batches the pieces.

  Each index is converted to (start, limit, removed_dims) form with
  `as_slice_indices`, all slices are taken in a single `x._multi_slice` call,
  and the resulting shards are handed to `pxla.batched_device_put` together
  with `sharding` and `devices`.
  """
  slice_params = [as_slice_indices(x, idx) for idx in indices]
  starts, limits, dropped_dims = unzip3(slice_params)
  per_device_shards = x._multi_slice(starts, limits, dropped_dims)
  return pxla.batched_device_put(
      api_util.shaped_abstractify(x), sharding, per_device_shards, devices)
|
||||
|
||||
def _hashable_index(idx):
|
||||
return tree_util.tree_map(
|
||||
lambda x: (x.start, x.stop) if type(x) == slice else x, idx)
|
||||
|
||||
# The fast path is handled directly in shard_args().
|
||||
def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
  """Re-uses the existing buffers of `x` to build per-device shards.

  The buffers of `x` are grouped by the index they hold. For every requested
  (index, device) pair, a buffer holding that index is selected, preferring
  one whose `device()` equals the requested device. If no buffer holds a
  requested index, the whole operation falls back to `pxla.shard_arg` on
  `x._value`.
  """
  if isinstance(x, ArrayImpl):
    src_bufs = [shard.data for shard in x.addressable_shards]
    src_indices = tuple(x.sharding.devices_indices_map(x.shape).values())
  else:
    # NOTE(review): presumably a legacy sharded array type; confirm.
    src_bufs = x.device_buffers
    src_indices = x.indices

  bufs_by_index = defaultdict(list)
  for src_buf, src_idx in safe_zip(src_bufs, src_indices):
    bufs_by_index[_hashable_index(src_idx)].append(src_buf)

  out_bufs = []
  for wanted_idx, wanted_device in safe_zip(indices, devices):
    # Look up all buffers that contain the correct slice of the logical array.
    matches = bufs_by_index[_hashable_index(wanted_idx)]
    if not matches:
      # This array isn't sharded correctly. Reshard it via host roundtrip.
      # TODO(skye): more efficient reshard?
      return pxla.shard_arg(x._value, devices, indices, sharding,
                            canonicalize=False)
    # Try to find a candidate buffer already on the correct device,
    # otherwise copy one of them (the last candidate).
    out_bufs.append(
        next((b for b in matches if b.device() == wanted_device), matches[-1]))

  return pxla.batched_device_put(x.aval, sharding, out_bufs, devices)
|
||||
|
||||
|
||||
def _array_shard_arg(x, devices, indices, sharding):
|
||||
x._check_if_deleted()
|
||||
|
||||
@ -697,10 +773,9 @@ def _array_shard_arg(x, devices, indices, sharding):
|
||||
x, list(devices), sharding)
|
||||
# Resharding starts here:
|
||||
if dispatch.is_single_device_sharding(x.sharding):
|
||||
return pxla.shard_device_array(x, devices, indices, sharding)
|
||||
return shard_device_array(x, devices, indices, sharding)
|
||||
else:
|
||||
return pxla.shard_sharded_device_array_slow_path(
|
||||
x, devices, indices, sharding)
|
||||
return shard_sharded_device_array_slow_path(x, devices, indices, sharding)
|
||||
|
||||
|
||||
pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg
|
||||
|
@ -17,7 +17,7 @@ from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from contextlib import contextmanager
|
||||
from collections import defaultdict, namedtuple
|
||||
from collections import namedtuple
|
||||
import dataclasses
|
||||
from functools import partial, lru_cache, cached_property
|
||||
import itertools as it
|
||||
@ -30,7 +30,6 @@ import numpy as np
|
||||
|
||||
import jax
|
||||
from jax.errors import JAXTypeError
|
||||
from jax.tree_util import tree_map
|
||||
|
||||
from jax._src import api_util
|
||||
from jax._src import core
|
||||
@ -67,7 +66,7 @@ from jax._src.sharding_impls import (
|
||||
AUTO, UnspecifiedValue, UNSPECIFIED,
|
||||
get_array_mapping as _get_array_mapping, is_auto, is_unspecified
|
||||
)
|
||||
from jax._src.util import (unzip3, safe_map, safe_zip, partition_list,
|
||||
from jax._src.util import (safe_map, safe_zip, partition_list,
|
||||
wrap_name, tuple_delete, distributed_debug_log,
|
||||
unzip2, HashableFunction, weakref_lru_cache)
|
||||
|
||||
@ -103,7 +102,7 @@ ShardingSpec = sharding_specs.ShardingSpec
|
||||
|
||||
def identity(x): return x
|
||||
|
||||
def shard_arg(arg, devices, arg_indices, sharding):
|
||||
def shard_arg(arg, devices, arg_indices, sharding, canonicalize=True):
|
||||
"""Returns a list of size len(devices) containing per-device buffers.
|
||||
|
||||
For the C++ pmap path, we fallback to Python (this function) to shard
|
||||
@ -114,7 +113,8 @@ def shard_arg(arg, devices, arg_indices, sharding):
|
||||
devices: The list of devices to shard over.
|
||||
arg_indices: A list of `len(devices)` indices to use to shard the argument.
|
||||
"""
|
||||
arg = xla.canonicalize_dtype(arg)
|
||||
if canonicalize:
|
||||
arg = xla.canonicalize_dtype(arg)
|
||||
return shard_arg_handlers[type(arg)](arg, devices, arg_indices, sharding)
|
||||
|
||||
|
||||
@ -169,14 +169,6 @@ def _shard_darray(x, devices, indices, sharding):
|
||||
return shard_arg(x._data, devices, indices, sharding)
|
||||
shard_arg_handlers[core.DArray] = _shard_darray
|
||||
|
||||
def shard_device_array(x, devices, indices, sharding):
  """Slices `x` once per entry of `indices` and batches the pieces.

  Each index is converted to (start, limit, removed_dims) form with
  `as_slice_indices`, all slices are taken in a single `x._multi_slice` call,
  and the resulting shards are handed to `batched_device_put` together with
  `sharding` and `devices`.
  """
  slice_params = [as_slice_indices(x, idx) for idx in indices]
  starts, limits, dropped_dims = unzip3(slice_params)
  per_device_shards = x._multi_slice(starts, limits, dropped_dims)
  return batched_device_put(
      api_util.shaped_abstractify(x), sharding, per_device_shards, devices)
|
||||
|
||||
def batched_device_put(aval: core.ShapedArray,
|
||||
sharding: jax.sharding.Sharding, xs: Sequence[Any],
|
||||
devices: Sequence[jax.Device], committed: bool = True):
|
||||
@ -191,36 +183,6 @@ def batched_device_put(aval: core.ShapedArray,
|
||||
aval, sharding, bufs, committed=committed, _skip_checks=True)
|
||||
return xc.batched_device_put(aval, sharding, xs, devices, committed) # type: ignore
|
||||
|
||||
|
||||
# NOTE(skye): we could refactor to generate _multi_slice parameters directly
|
||||
# from the input ShardingSpec, rather than the indices. However, this would
|
||||
# require duplicating the ordering logic of spec_to_indices, which is more
|
||||
# subtle and more likely to change than the index logic we have to support here.
|
||||
def as_slice_indices(arr: Any, idx: Index) -> tuple[
    tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
  """Returns start_indices, limit_indices, removed_dims

  Args:
    arr: any object exposing ``ndim`` and ``shape``.
    idx: a single index, or a tuple of per-dimension indices. Each entry must
      be an ``int`` or a ``slice``. Dimensions not covered by ``idx`` are
      taken whole.

  Returns:
    A ``(start_indices, limit_indices, removed_dims)`` triple. The first two
    have one entry per dimension of ``arr``; ``removed_dims`` lists the
    dimensions that were indexed with an ``int``.

  Raises:
    TypeError: if an entry is neither an ``int`` nor a ``slice``, or a slice
      bound is neither ``None`` nor an ``int``.
    ValueError: if a slice has a step other than ``None`` or ``1``, since a
      stride cannot be expressed as a (start, limit) pair.
  """
  start_indices = [0] * arr.ndim
  limit_indices = list(arr.shape)
  removed_dims = []

  tuple_idx = idx if isinstance(idx, tuple) else (idx,)
  for dim, sub_idx in enumerate(tuple_idx):
    if isinstance(sub_idx, int):
      start_indices[dim] = sub_idx
      limit_indices[dim] = sub_idx + 1
      removed_dims.append(dim)
      continue

    # Explicit checks instead of `assert`: asserts vanish under `python -O`,
    # which would let malformed indices flow into the result unnoticed.
    if not isinstance(sub_idx, slice):
      raise TypeError(
          f"index entries must be int or slice, got {sub_idx!r}")
    if sub_idx.step not in (None, 1):
      raise ValueError(f"strided slices are not supported, got {sub_idx!r}")
    for bound in (sub_idx.start, sub_idx.stop):
      if bound is not None and not isinstance(bound, int):
        raise TypeError(
            f"slice bounds must be int or None, got {sub_idx!r}")

    # An omitted bound keeps the default for that side (0 / the dim size), so
    # `slice(None)` leaves the dimension untouched.
    if sub_idx.start is not None:
      start_indices[dim] = sub_idx.start
    if sub_idx.stop is not None:
      limit_indices[dim] = sub_idx.stop

  return tuple(start_indices), tuple(limit_indices), tuple(removed_dims)  # type: ignore
|
||||
|
||||
|
||||
def shard_aval(size, axis: int, aval):
|
||||
try:
|
||||
return shard_aval_handlers[type(aval)](size, axis, aval)
|
||||
@ -298,45 +260,6 @@ global_result_handlers: dict[type[core.AbstractValue], PxlaResultHandler] = {}
|
||||
|
||||
### lazy device-memory persistence and result handling
|
||||
|
||||
def _hashable_index(idx):
|
||||
return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x, idx)
|
||||
|
||||
# The fast path is handled directly in shard_args().
|
||||
# TODO(yashkatariya): Move this to array.py when SDA is deleted. The local
|
||||
# import of Array should go away at that time.
|
||||
def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
  """Re-uses the existing buffers of `x` to build per-device shards.

  The buffers of `x` are grouped by the index they hold. For every requested
  (index, device) pair, a buffer holding that index is selected, preferring
  one whose `device()` equals the requested device. If no buffer holds a
  requested index, the whole operation is delegated to the `shard_arg_handlers`
  entry registered for the type of `x._value`.
  """
  # Imported locally rather than at module level (see the TODO above this
  # function about moving it to array.py).
  from jax._src.array import ArrayImpl

  if isinstance(x, ArrayImpl):
    src_bufs = [shard.data for shard in x.addressable_shards]
    src_indices = tuple(x.sharding.devices_indices_map(x.shape).values())
  else:
    src_bufs = x.device_buffers
    src_indices = x.indices

  bufs_by_index = defaultdict(list)
  for src_buf, src_idx in safe_zip(src_bufs, src_indices):
    bufs_by_index[_hashable_index(src_idx)].append(src_buf)

  out_bufs = []
  for wanted_idx, wanted_device in safe_zip(indices, devices):
    # Look up all buffers that contain the correct slice of the logical array.
    matches = bufs_by_index[_hashable_index(wanted_idx)]
    if not matches:
      # This array isn't sharded correctly. Reshard it via host roundtrip.
      # TODO(skye): more efficient reshard?
      return shard_arg_handlers[type(x._value)](
          x._value, devices, indices, sharding)
    # Try to find a candidate buffer already on the correct device,
    # otherwise copy one of them (the last candidate).
    out_bufs.append(
        next((b for b in matches if b.device() == wanted_device), matches[-1]))

  return batched_device_put(x.aval, sharding, out_bufs, devices)
|
||||
|
||||
### the xla_pmap primitive and its rules are comparable to xla_call in xla.py
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user