mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #6933 from lgeiger:multi-slice
PiperOrigin-RevId: 383608401
This commit is contained in:
commit
78a689bb09
@ -6026,7 +6026,7 @@ def _multi_slice(arr,
|
||||
for starts, limits, removed in safe_zip(start_indices, limit_indices, removed_dims):
|
||||
sliced = lax.slice(arr, starts, limits)
|
||||
if removed:
|
||||
sliced = sliced.reshape(np.delete(sliced.shape, removed_dims))
|
||||
sliced = lax.squeeze(sliced, removed)
|
||||
results.append(sliced)
|
||||
return results
|
||||
setattr(_DeviceArray, "_multi_slice", _multi_slice)
|
||||
|
@ -348,8 +348,8 @@ for _t in array_types:
|
||||
shard_arg_handlers[_t] = _shard_array
|
||||
|
||||
def _shard_device_array(x, devices, indices):
|
||||
start_indices, limit_indices, removed_dims = map(tuple, unzip3(
|
||||
_as_slice_indices(x, idx) for idx in indices))
|
||||
start_indices, limit_indices, removed_dims = unzip3(
|
||||
_as_slice_indices(x, idx) for idx in indices)
|
||||
shards = x._multi_slice(start_indices, limit_indices, removed_dims)
|
||||
return device_put(shards, devices)
|
||||
shard_arg_handlers[xla._DeviceArray] = _shard_device_array
|
||||
|
Loading…
x
Reference in New Issue
Block a user