mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Optimize _create_copy_plan by iterating over only the shards that are needed for materialization
For arrays that are fully or partially replicated, it is more efficient to precompute the list of addressable array shards that participate in materialization than to iterate over all array shards. This is particularly useful for single-controller JAX. Concretely, `_cached_index_calc` now returns a list of `(array_index, index)` pairs for the unique shard indices, and callers index `self._arrays` directly instead of building a copy plan with `_create_copy_plan`. The implementation assumes that addressable arrays appear in the same order as the corresponding addressable devices in `sharding.addressable_devices_indices_map()`.

PiperOrigin-RevId: 624969222
This commit is contained in:
parent
3a09404426
commit
ac1a53d8e4
@ -125,23 +125,13 @@ def _reconstruct_array(fun, args, arr_state, aval_state):
|
||||
def _cached_index_calc(s, shape):
|
||||
map_ = s.addressable_devices_indices_map(shape)
|
||||
seen_h_indices = set()
|
||||
m = {}
|
||||
for d, index in map_.items():
|
||||
l = []
|
||||
for array_index, index in enumerate(map_.values()):
|
||||
h_index = hashed_index(index)
|
||||
if h_index not in seen_h_indices:
|
||||
seen_h_indices.add(h_index)
|
||||
m[d] = index
|
||||
return m
|
||||
|
||||
|
||||
def _create_copy_plan(arrays, s: Sharding, shape: Shape):
|
||||
di_map = _cached_index_calc(s, shape)
|
||||
copy_plan = []
|
||||
for a in arrays:
|
||||
ind = di_map.get(a.sharding._internal_device_list[0], None) # type:ignore
|
||||
if ind is not None:
|
||||
copy_plan.append((ind, a))
|
||||
return copy_plan
|
||||
l.append((array_index, index))
|
||||
return l
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=4096)
|
||||
@ -607,9 +597,8 @@ class ArrayImpl(basearray.Array):
|
||||
if self.is_fully_replicated:
|
||||
self._copy_single_device_array_to_host_async()
|
||||
return
|
||||
copy_plan = _create_copy_plan(self._arrays, self.sharding, self.shape)
|
||||
for _, arr in copy_plan:
|
||||
arr._copy_single_device_array_to_host_async()
|
||||
for i, _ in _cached_index_calc(self.sharding, self.shape):
|
||||
self._arrays[i]._copy_single_device_array_to_host_async()
|
||||
|
||||
@property
|
||||
@functools.partial(profiler.annotate_function, name="np.asarray(jax.Array)")
|
||||
@ -631,13 +620,12 @@ class ArrayImpl(basearray.Array):
|
||||
"`jax.experimental.multihost_utils.process_allgather` "
|
||||
"for this use case.")
|
||||
|
||||
copy_plan = _create_copy_plan(self._arrays, self.sharding, self.shape)
|
||||
for _, arr in copy_plan:
|
||||
arr._copy_single_device_array_to_host_async()
|
||||
for i, _ in _cached_index_calc(self.sharding, self.shape):
|
||||
self._arrays[i]._copy_single_device_array_to_host_async()
|
||||
|
||||
npy_value = np.empty(self.shape, self.dtype)
|
||||
for ind, arr in copy_plan:
|
||||
npy_value[ind] = arr._single_device_array_to_np_array()
|
||||
for i, ind in _cached_index_calc(self.sharding, self.shape):
|
||||
npy_value[ind] = self._arrays[i]._single_device_array_to_np_array()
|
||||
self._npy_value = npy_value # type: ignore
|
||||
self._npy_value.flags.writeable = False
|
||||
# https://docs.python.org/3/library/typing.html#typing.cast
|
||||
|
Loading…
x
Reference in New Issue
Block a user