Decouple ShardedDeviceArray from _DeviceArray

This commit is contained in:
Parker Schuh 2021-02-25 12:52:33 -08:00
parent 5b288b71a3
commit 10289390f3
3 changed files with 24 additions and 6 deletions

View File

@ -5350,7 +5350,7 @@ setattr(ShapedArray, "nbytes", core.aval_property(_nbytes))
# Forward operators, methods, and properties on DeviceArray to lax_numpy
# functions (with no Tracers involved; this forwarding is direct)
for device_array in [_DeviceArray, _CppDeviceArray]:
for device_array in [DeviceArray]:
for operator_name, function in _operators.items():
setattr(device_array, "__{}__".format(operator_name), function)
for method_name in _nondiff_methods + _diff_methods:

View File

@ -431,7 +431,7 @@ pxla_result_handlers[ConcreteArray] = array_result_handler
### lazy device-memory persistence and result handling
class ShardedDeviceArray(xla._DeviceArray):
class ShardedDeviceArray(xla.DeviceArray): # type: ignore
"""A ShardedDeviceArray is an ndarray sharded across devices.
The purpose of a ShardedDeviceArray is to reduce the number of transfers when
@ -456,8 +456,10 @@ class ShardedDeviceArray(xla._DeviceArray):
stored in the corresponding device buffer, i.e. `array[indices[i]] ==
device_buffers[i].to_py()`.
"""
__slots__ = ["device_buffers", "sharding_spec", "indices",
"_one_replica_buffer_indices"]
__slots__ = [
"aval", "device_buffers", "sharding_spec", "indices",
"_one_replica_buffer_indices", "_npy_value"
]
# TODO(skye): expose PyLocalBuffers in xla_client
def __init__(self,
@ -502,6 +504,22 @@ class ShardedDeviceArray(xla._DeviceArray):
self._one_replica_buffer_indices = one_replica_indices
return self._one_replica_buffer_indices
# Shape of this array, forwarded from its abstract value `self.aval`.
@property
def shape(self):
return self.aval.shape
# Element dtype of this array, forwarded from its abstract value `self.aval`.
@property
def dtype(self):
return self.aval.dtype
# Result of applying `prod` to `self.aval.shape`.
# NOTE(review): `prod` is defined outside this view; presumably it multiplies
# the dimensions together to give the total element count -- confirm.
@property
def size(self):
return prod(self.aval.shape)
# Number of dimensions of this array: the length of `self.aval.shape`.
@property
def ndim(self):
return len(self.aval.shape)
def copy_to_host_async(self):
for buffer_index in self.one_replica_buffer_indices:
self.device_buffers[buffer_index].copy_to_host_async()
@ -546,7 +564,7 @@ class ShardedDeviceArray(xla._DeviceArray):
buf = self.device_buffers[buf_idx]
aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype)
return xla.make_device_array(aval, None, lazy.array(aval.shape), buf)
return super(ShardedDeviceArray, self).__getitem__(idx)
return xla.DeviceArray.__getitem__(self, idx)
def _hashable_index(idx):

View File

@ -1172,7 +1172,7 @@ class _DeviceArray(DeviceArray): # type: ignore
# Adding methods dynamically to DeviceArray (previously to both _DeviceArray and _CppDeviceArray)
# pylint: disable=protected-access
for device_array in [_DeviceArray, _CppDeviceArray]:
for device_array in [DeviceArray]:
def copy(self):