Decouple ShardedDeviceArray from _DeviceArray
parent 5b288b71a3
commit 10289390f3
@@ -5350,7 +5350,7 @@ setattr(ShapedArray, "nbytes", core.aval_property(_nbytes))
 
 # Forward operators, methods, and properties on DeviceArray to lax_numpy
 # functions (with no Tracers involved; this forwarding is direct)
-for device_array in [_DeviceArray, _CppDeviceArray]:
+for device_array in [DeviceArray]:
   for operator_name, function in _operators.items():
     setattr(device_array, "__{}__".format(operator_name), function)
   for method_name in _nondiff_methods + _diff_methods:
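For readers unfamiliar with this pattern: the loop above installs lax_numpy-level functions as dunder methods on the array class via setattr. A minimal, self-contained sketch of that mechanism, using placeholder names rather than the actual JAX tables:

# Sketch of setattr-based operator forwarding, as in the hunk above.
# _Array and _add are hypothetical stand-ins for DeviceArray and the
# lax_numpy implementation; only the forwarding mechanism is the point.
class _Array:
  def __init__(self, value):
    self.value = value

def _add(x, y):
  return _Array(x.value + y.value)

_operators = {"add": _add}

for operator_name, function in _operators.items():
  # installs _Array.__add__ = _add, so `a + b` dispatches to _add(a, b)
  setattr(_Array, "__{}__".format(operator_name), function)

assert (_Array(1) + _Array(2)).value == 3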
@@ -431,7 +431,7 @@ pxla_result_handlers[ConcreteArray] = array_result_handler
 
 ### lazy device-memory persistence and result handling
 
-class ShardedDeviceArray(xla._DeviceArray):
+class ShardedDeviceArray(xla.DeviceArray):  # type: ignore
   """A ShardedDeviceArray is an ndarray sharded across devices.
 
   The purpose of a ShardedDeviceArray is to reduce the number of transfers when
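As context for the class being rebased here: ShardedDeviceArrays are what jax.pmap typically produces, with one shard per participating device. A hedged usage sketch (assumes a working JAX install; with a single device there is simply one shard):

# Sketch: pmap leaves its result sharded across devices instead of
# gathering it back to the host, which is the transfer saving the
# docstring above refers to.
import jax
import jax.numpy as jnp

n = jax.local_device_count()
x = jnp.arange(n * 4.0).reshape(n, 4)   # one row per local device
y = jax.pmap(lambda v: v * 2.0)(x)      # result stays resident on the devices
print(type(y), y.shape)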
@@ -456,8 +456,10 @@ class ShardedDeviceArray(xla._DeviceArray):
   stored in the corresponding device buffer, i.e. `array[indices[i]] ==
   device_buffers[i].to_py()`.
   """
-  __slots__ = ["device_buffers", "sharding_spec", "indices",
-               "_one_replica_buffer_indices"]
+  __slots__ = [
+      "aval", "device_buffers", "sharding_spec", "indices",
+      "_one_replica_buffer_indices", "_npy_value"
+  ]
 
   # TODO(skye): expose PyLocalBuffers in xla_client
   def __init__(self,
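A note on the enlarged __slots__: with the base switched from xla._DeviceArray to xla.DeviceArray, the subclass now declares the "aval" and "_npy_value" slots itself, presumably because the new base no longer supplies them. A generic illustration (placeholder class names) of why slotted classes must list every instance attribute:

# Generic __slots__ illustration; Base and WithSlots are hypothetical.
class Base:
  __slots__ = []

class WithSlots(Base):
  __slots__ = ["aval", "_npy_value"]

obj = WithSlots()
obj.aval = object()     # allowed: declared in __slots__
obj._npy_value = None   # allowed
try:
  obj.other = 1         # rejected: no slot and no instance __dict__
except AttributeError as err:
  print(err)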
@@ -502,6 +504,22 @@ class ShardedDeviceArray(xla._DeviceArray):
     self._one_replica_buffer_indices = one_replica_indices
     return self._one_replica_buffer_indices
 
+  @property
+  def shape(self):
+    return self.aval.shape
+
+  @property
+  def dtype(self):
+    return self.aval.dtype
+
+  @property
+  def size(self):
+    return prod(self.aval.shape)
+
+  @property
+  def ndim(self):
+    return len(self.aval.shape)
+
   def copy_to_host_async(self):
     for buffer_index in self.one_replica_buffer_indices:
       self.device_buffers[buffer_index].copy_to_host_async()
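The properties added in this hunk read their metadata from the abstract value (self.aval, a ShapedArray) rather than from any device buffer. A stand-alone sketch of that delegation, assuming prod is numpy.prod (the actual import in pxla.py may differ):

# Minimal stand-in showing shape/size derived from an aval-like object.
from collections import namedtuple
from numpy import prod

AvalStub = namedtuple("AvalStub", ["shape", "dtype"])  # hypothetical stand-in

class ShardedLike:
  def __init__(self, aval):
    self.aval = aval

  @property
  def shape(self):
    return self.aval.shape

  @property
  def size(self):
    return int(prod(self.aval.shape))

x = ShardedLike(AvalStub(shape=(8, 4), dtype="float32"))
print(x.shape, x.size)   # (8, 4) 32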
@@ -546,7 +564,7 @@ class ShardedDeviceArray(xla._DeviceArray):
       buf = self.device_buffers[buf_idx]
       aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype)
       return xla.make_device_array(aval, None, lazy.array(aval.shape), buf)
-    return super(ShardedDeviceArray, self).__getitem__(idx)
+    return xla.DeviceArray.__getitem__(self, idx)
 
 
 def _hashable_index(idx):
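The __getitem__ change swaps super()-based dispatch for an explicit call to xla.DeviceArray.__getitem__, pinning the fallback to that specific base class. A generic sketch of the difference, with placeholder classes:

# Explicit base-class dispatch vs. super(): the explicit form always calls
# Base.__getitem__, independent of where the subclass sits in the MRO.
class Base:
  def __getitem__(self, idx):
    return ("base", idx)

class Child(Base):
  def __getitem__(self, idx):
    return Base.__getitem__(self, idx)   # explicit, like the new line above

print(Child()[3])   # ('base', 3)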
@@ -1172,7 +1172,7 @@ class _DeviceArray(DeviceArray):  # type: ignore
 
 # Adding methods dynamically to both _DeviceArray and _CppDeviceArray
 # pylint: disable=protected-access
-for device_array in [_DeviceArray, _CppDeviceArray]:
+for device_array in [DeviceArray]:
 
 
   def copy(self):