Merge pull request #6933 from lgeiger:multi-slice

PiperOrigin-RevId: 383608401
This commit is contained in:
jax authors 2021-07-08 04:54:03 -07:00
commit 78a689bb09
2 changed files with 3 additions and 3 deletions

View File

@ -6026,7 +6026,7 @@ def _multi_slice(arr,
for starts, limits, removed in safe_zip(start_indices, limit_indices, removed_dims):
sliced = lax.slice(arr, starts, limits)
if removed:
sliced = sliced.reshape(np.delete(sliced.shape, removed_dims))
sliced = lax.squeeze(sliced, removed)
results.append(sliced)
return results
setattr(_DeviceArray, "_multi_slice", _multi_slice)

View File

@ -348,8 +348,8 @@ for _t in array_types:
shard_arg_handlers[_t] = _shard_array
def _shard_device_array(x, devices, indices):
start_indices, limit_indices, removed_dims = map(tuple, unzip3(
_as_slice_indices(x, idx) for idx in indices))
start_indices, limit_indices, removed_dims = unzip3(
_as_slice_indices(x, idx) for idx in indices)
shards = x._multi_slice(start_indices, limit_indices, removed_dims)
return device_put(shards, devices)
shard_arg_handlers[xla._DeviceArray] = _shard_device_array