From fe30d3fd4bbf92b890c97a75c4c47f4275ab77d1 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Thu, 20 Jul 2023 09:43:40 -0700
Subject: [PATCH] Move _array_shard_arg helpers from pxla into array.

Refactoring only; this fixes a TODO.

Add a canonicalize argument to pxla.shard_arg so we can call that API
from array while avoiding double canonicalization.

PiperOrigin-RevId: 549658117
---
 jax/_src/array.py             | 83 +++++++++++++++++++++++++++++++--
 jax/_src/interpreters/pxla.py | 87 ++---------------------------------
 2 files changed, 84 insertions(+), 86 deletions(-)

diff --git a/jax/_src/array.py b/jax/_src/array.py
index 9505e4d6f..4aae6754c 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -14,6 +14,7 @@
 
 from __future__ import annotations
 
+from collections import defaultdict
 import math
 import operator as op
 import numpy as np
@@ -29,6 +30,7 @@ from jax._src import core
 from jax._src import dispatch
 from jax._src import dtypes
 from jax._src import profiler
+from jax._src import tree_util
 from jax._src import xla_bridge
 from jax._src.config import config
 from jax._src.lib import xla_client as xc
@@ -40,7 +42,7 @@ from jax._src.sharding_impls import (
     SingleDeviceSharding, XLACompatibleSharding, PmapSharding,
     device_replica_id_map, hashed_index)
 from jax._src.typing import ArrayLike
-from jax._src.util import use_cpp_class, use_cpp_method
+from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method
 
 Shape = tuple[int, ...]
 Device = xc.Device
@@ -681,6 +683,80 @@ def _array_mlir_constant_handler(val, canonicalize_types=True):
 mlir.register_constant_handler(ArrayImpl, _array_mlir_constant_handler)
 
 
+# NOTE(skye): we could refactor to generate _multi_slice parameters directly
+# from the input ShardingSpec, rather than the indices. However, this would
+# require duplicating the ordering logic of spec_to_indices, which is more
+# subtle and more likely to change than the index logic we have to support here.
+def as_slice_indices(arr: Any, idx: Index) -> tuple[
+    tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
+  """Returns start_indices, limit_indices, removed_dims"""
+  start_indices = [0] * arr.ndim
+  limit_indices = list(arr.shape)
+  removed_dims = []
+
+  tuple_idx = idx if isinstance(idx, tuple) else (idx,)
+  for dim, sub_idx in enumerate(tuple_idx):
+    if isinstance(sub_idx, int):
+      start_indices[dim] = sub_idx
+      limit_indices[dim] = sub_idx + 1
+      removed_dims.append(dim)
+    elif sub_idx == slice(None):
+      continue
+    else:
+      assert isinstance(sub_idx, slice), sub_idx
+      assert isinstance(sub_idx.start, int), sub_idx
+      assert isinstance(sub_idx.stop, int), sub_idx
+      start_indices[dim] = sub_idx.start
+      limit_indices[dim] = sub_idx.stop
+
+  return tuple(start_indices), tuple(limit_indices), tuple(removed_dims)  # type: ignore
+
+
+def shard_device_array(x, devices, indices, sharding):
+  start_indices, limit_indices, removed_dims = unzip3(
+      as_slice_indices(x, idx) for idx in indices)
+  shards = x._multi_slice(start_indices, limit_indices, removed_dims)
+  aval = api_util.shaped_abstractify(x)
+  out = pxla.batched_device_put(aval, sharding, shards, devices)
+  return out
+
+def _hashable_index(idx):
+  return tree_util.tree_map(
+      lambda x: (x.start, x.stop) if type(x) == slice else x, idx)
+
+# The fast path is handled directly in shard_args().
+def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
+  candidates = defaultdict(list)
+  if isinstance(x, ArrayImpl):
+    bufs = [buf.data for buf in x.addressable_shards]
+    arr_indices = tuple(x.sharding.devices_indices_map(x.shape).values())
+  else:
+    bufs = x.device_buffers
+    arr_indices = x.indices
+  for buf, idx in safe_zip(bufs, arr_indices):
+    candidates[_hashable_index(idx)].append(buf)
+
+  bufs = []
+  for idx, device in safe_zip(indices, devices):
+    # Look up all buffers that contain the correct slice of the logical array.
+    candidates_list = candidates[_hashable_index(idx)]
+    if not candidates_list:
+      # This array isn't sharded correctly. Reshard it via host roundtrip.
+      # TODO(skye): more efficient reshard?
+      return pxla.shard_arg(x._value, devices, indices, sharding,
+                            canonicalize=False)
+    # Try to find a candidate buffer already on the correct device,
+    # otherwise copy one of them.
+    for buf in candidates_list:
+      if buf.device() == device:
+        bufs.append(buf)
+        break
+    else:
+      bufs.append(buf)
+
+  return pxla.batched_device_put(x.aval, sharding, bufs, devices)
+
+
 def _array_shard_arg(x, devices, indices, sharding):
   x._check_if_deleted()
 
@@ -697,10 +773,9 @@ def _array_shard_arg(x, devices, indices, sharding):
       x, list(devices), sharding)
   # Resharding starts here:
   if dispatch.is_single_device_sharding(x.sharding):
-    return pxla.shard_device_array(x, devices, indices, sharding)
+    return shard_device_array(x, devices, indices, sharding)
   else:
-    return pxla.shard_sharded_device_array_slow_path(
-        x, devices, indices, sharding)
+    return shard_sharded_device_array_slow_path(x, devices, indices, sharding)
 
 
 pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 46a8fa496..f987640fb 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -17,7 +17,7 @@ from __future__ import annotations
 
 import enum
 from contextlib import contextmanager
-from collections import defaultdict, namedtuple
+from collections import namedtuple
 import dataclasses
 from functools import partial, lru_cache, cached_property
 import itertools as it
@@ -30,7 +30,6 @@ import numpy as np
 
 import jax
 from jax.errors import JAXTypeError
-from jax.tree_util import tree_map
 
 from jax._src import api_util
 from jax._src import core
@@ -67,7 +66,7 @@ from jax._src.sharding_impls import (
     AUTO, UnspecifiedValue, UNSPECIFIED,
     get_array_mapping as _get_array_mapping, is_auto, is_unspecified
 )
-from jax._src.util import (unzip3, safe_map, safe_zip, partition_list,
+from jax._src.util import (safe_map, safe_zip, partition_list,
                            wrap_name, tuple_delete, distributed_debug_log,
                            unzip2, HashableFunction, weakref_lru_cache)
 
@@ -103,19 +102,20 @@ ShardingSpec = sharding_specs.ShardingSpec
 
 def identity(x): return x
 
-def shard_arg(arg, devices, arg_indices, sharding):
+def shard_arg(arg, devices, arg_indices, sharding, canonicalize=True):
   """Returns a list of size len(devices) containing per-device buffers.
 
   For the C++ pmap path, we fallback to Python (this function) to shard
   arguments that are not supported by the C++ `ShardArg`.
 
   Args:
     arg: The Python argument.
     devices: The list of devices to shard over.
    arg_indices: A list of `len(devices)` indices to use to shard the
       argument.
""" - arg = xla.canonicalize_dtype(arg) + if canonicalize: + arg = xla.canonicalize_dtype(arg) return shard_arg_handlers[type(arg)](arg, devices, arg_indices, sharding) @@ -169,14 +169,6 @@ def _shard_darray(x, devices, indices, sharding): return shard_arg(x._data, devices, indices, sharding) shard_arg_handlers[core.DArray] = _shard_darray -def shard_device_array(x, devices, indices, sharding): - start_indices, limit_indices, removed_dims = unzip3( - as_slice_indices(x, idx) for idx in indices) - shards = x._multi_slice(start_indices, limit_indices, removed_dims) - aval = api_util.shaped_abstractify(x) - out = batched_device_put(aval, sharding, shards, devices) - return out - def batched_device_put(aval: core.ShapedArray, sharding: jax.sharding.Sharding, xs: Sequence[Any], devices: Sequence[jax.Device], committed: bool = True): @@ -191,36 +183,6 @@ def batched_device_put(aval: core.ShapedArray, aval, sharding, bufs, committed=committed, _skip_checks=True) return xc.batched_device_put(aval, sharding, xs, devices, committed) # type: ignore - -# NOTE(skye): we could refactor to generate _multi_slice parameters directly -# from the input ShardingSpec, rather than the indices. However, this would -# require duplicating the ordering logic of spec_to_indices, which is more -# subtle and more likely to change than the index logic we have to support here. -def as_slice_indices(arr: Any, idx: Index) -> tuple[ - tuple[int, ...], tuple[int, ...], tuple[int, ...]]: - """Returns start_indices, limit_indices, removed_dims""" - start_indices = [0] * arr.ndim - limit_indices = list(arr.shape) - removed_dims = [] - - tuple_idx = idx if isinstance(idx, tuple) else (idx,) - for dim, sub_idx in enumerate(tuple_idx): - if isinstance(sub_idx, int): - start_indices[dim] = sub_idx - limit_indices[dim] = sub_idx + 1 - removed_dims.append(dim) - elif sub_idx == slice(None): - continue - else: - assert isinstance(sub_idx, slice), sub_idx - assert isinstance(sub_idx.start, int), sub_idx - assert isinstance(sub_idx.stop, int), sub_idx - start_indices[dim] = sub_idx.start - limit_indices[dim] = sub_idx.stop - - return tuple(start_indices), tuple(limit_indices), tuple(removed_dims) # type: ignore - - def shard_aval(size, axis: int, aval): try: return shard_aval_handlers[type(aval)](size, axis, aval) @@ -298,45 +260,6 @@ global_result_handlers: dict[type[core.AbstractValue], PxlaResultHandler] = {} ### lazy device-memory persistence and result handling -def _hashable_index(idx): - return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x, idx) - -# The fast path is handled directly in shard_args(). -# TODO(yashkatariya): Move this to array.py when SDA is deleted. The local -# import of Array should go away at that time. -def shard_sharded_device_array_slow_path(x, devices, indices, sharding): - from jax._src.array import ArrayImpl - - candidates = defaultdict(list) - if isinstance(x, ArrayImpl): - bufs = [buf.data for buf in x.addressable_shards] - arr_indices = tuple(x.sharding.devices_indices_map(x.shape).values()) - else: - bufs = x.device_buffers - arr_indices = x.indices - for buf, idx in safe_zip(bufs, arr_indices): - candidates[_hashable_index(idx)].append(buf) - - bufs = [] - for idx, device in safe_zip(indices, devices): - # Look up all buffers that contain the correct slice of the logical array. - candidates_list = candidates[_hashable_index(idx)] - if not candidates_list: - # This array isn't sharded correctly. Reshard it via host roundtrip. - # TODO(skye): more efficient reshard? 
-      return shard_arg_handlers[type(x._value)](
-          x._value, devices, indices, sharding)
-    # Try to find a candidate buffer already on the correct device,
-    # otherwise copy one of them.
-    for buf in candidates_list:
-      if buf.device() == device:
-        bufs.append(buf)
-        break
-    else:
-      bufs.append(buf)
-
-  return batched_device_put(x.aval, sharding, bufs, devices)
-
 
 
 ### the xla_pmap primitive and its rules are comparable to xla_call in xla.py
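
Reviewer note (not part of the patch): the moved as_slice_indices helper turns
one device's index into the (start_indices, limit_indices, removed_dims)
triple that _multi_slice consumes. A standalone sketch of its behavior,
duplicating the function above against plain NumPy:

    from typing import Any

    import numpy as np

    def as_slice_indices(arr: Any, idx) -> tuple[
        tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
      """Returns start_indices, limit_indices, removed_dims (copy of the above)."""
      start_indices = [0] * arr.ndim
      limit_indices = list(arr.shape)
      removed_dims = []

      tuple_idx = idx if isinstance(idx, tuple) else (idx,)
      for dim, sub_idx in enumerate(tuple_idx):
        if isinstance(sub_idx, int):
          # An integer index selects a single element and drops that dimension.
          start_indices[dim] = sub_idx
          limit_indices[dim] = sub_idx + 1
          removed_dims.append(dim)
        elif sub_idx == slice(None):
          continue  # A full slice leaves the dimension untouched.
        else:
          assert isinstance(sub_idx, slice), sub_idx
          assert isinstance(sub_idx.start, int), sub_idx
          assert isinstance(sub_idx.stop, int), sub_idx
          start_indices[dim] = sub_idx.start
          limit_indices[dim] = sub_idx.stop

      return tuple(start_indices), tuple(limit_indices), tuple(removed_dims)

    x = np.zeros((4, 8))
    # The first of two row blocks: both dimensions survive.
    print(as_slice_indices(x, (slice(0, 2), slice(None))))  # ((0, 0), (2, 8), ())
    # An integer index removes the dimension it selects.
    print(as_slice_indices(x, 1))                           # ((1, 0), (2, 8), (0,))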
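
Reviewer note (not part of the patch): the slow path's buffer matching is a
small algorithm of its own: group the existing shards by the index they cover,
then for each requested (index, device) pair prefer a shard already resident
on that device, else copy a shard holding the right data (the for/else falls
back to the last candidate). A toy model with integers standing in for device
buffers; all names here are hypothetical:

    from collections import defaultdict

    def pick_buffers(existing, wanted):
      """existing: list of (device_id, index); wanted: list of (index, device_id)."""
      # Group existing shards by the index they cover, as the real code
      # does with _hashable_index.
      candidates = defaultdict(list)
      for dev, idx in existing:
        candidates[idx].append(dev)

      out = []
      for idx, device in wanted:
        cands = candidates[idx]
        if not cands:
          raise ValueError("not sharded correctly; the patch reshards via host")
        for dev in cands:
          if dev == device:  # A shard already lives on the right device: reuse it.
            out.append(dev)
            break
        else:                # No resident shard: copy from the last candidate.
          out.append(dev)
      return out

    # The (0, 2) block lives on devices 0 and 2; device 3 needs a copy.
    existing = [(0, (0, 2)), (1, (2, 4)), (2, (0, 2))]
    print(pick_buffers(existing, [((0, 2), 0), ((2, 4), 1), ((0, 2), 3)]))
    # -> [0, 1, 2]: devices 0 and 1 reuse local shards; device 3 copies from 2.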
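
Reviewer note (not part of the patch): the canonicalize flag exists because the
slow path above calls pxla.shard_arg on x._value, data materialized from an
array that was already canonicalized on device, so canonicalizing again would
be wasted work. A minimal sketch of the pattern; canonicalize_dtype below is a
hypothetical stand-in for the real xla.canonicalize_dtype:

    import numpy as np

    def canonicalize_dtype(x):
      # Stand-in: with 64-bit mode disabled, JAX canonicalizes
      # float64 arrays down to float32.
      return x.astype(np.float32) if x.dtype == np.float64 else x

    def shard_arg(arg, canonicalize=True):
      # Mirrors the patched signature: a caller holding already-canonical
      # data can skip the (re-)conversion.
      if canonicalize:
        arg = canonicalize_dtype(arg)
      return arg  # The real function then dispatches to shard_arg_handlers.

    x = canonicalize_dtype(np.arange(3, dtype=np.float64))  # canonicalize once
    shard_arg(x, canonicalize=False)                        # and do not repeat it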