Roll back the array fix again due to a performance regression.

PiperOrigin-RevId: 511879030
This commit is contained in:
Parker Schuh 2023-02-23 13:59:11 -08:00 committed by jax authors
parent 35a27359d0
commit b5026207bc
3 changed files with 15 additions and 17 deletions

View File

@@ -351,7 +351,7 @@ class ArrayImpl(basearray.Array):
'named_shape': self.aval.named_shape}
return (_reconstruct_array, (fun, args, arr_state, aval_state))
@use_cpp_method(xla_extension_version >= 128)
@use_cpp_method(xla_extension_version >= 128 and xla_extension_version <= 129)
def unsafe_buffer_pointer(self):
if len(self._arrays) != 1:
raise ValueError("unsafe_buffer_pointer() is supported only for unsharded"
@@ -443,7 +443,7 @@ class ArrayImpl(basearray.Array):
out.append(Shard(global_d, self.sharding, self.shape, array))
return out
@use_cpp_method(xla_extension_version >= 128)
@use_cpp_method(xla_extension_version >= 128 and xla_extension_version <= 129)
def delete(self):
if self._arrays is None:
return
@@ -489,7 +489,7 @@ class ArrayImpl(basearray.Array):
for s in self.addressable_shards:
if not replica_id_exists or s.replica_id == 0:
if xla_extension_version >= 128:
if xla_extension_version >= 128 and xla_extension_version <= 129:
s.data.copy_to_host_async() # pytype: disable=attribute-error
else:
s.data._arrays[0].copy_to_host_async() # pytype: disable=attribute-error

View File

@@ -29,7 +29,6 @@ from jax._src import profiler
from jax._src import util
from jax._src.config import config
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.typing import Array
### device-persistent data
@@ -60,14 +59,6 @@ def make_device_array(
This is to be used only within JAX. It will return either a PythonDeviceArray
or a C++ equivalent implementation.
"""
from jax._src import array
if jax.config.jax_array and xla_extension_version >= 128:
if isinstance(device_buffer, xc.Buffer):
return array._single_device_array_from_buf(
device_buffer, False if device_buffer._device is None else True)
elif isinstance(device_buffer, array.ArrayImpl):
return device_buffer
if isinstance(device_buffer, xc.Buffer):
@@ -135,6 +126,7 @@ class _DeviceArray(DeviceArray):  # type: ignore
aval, npy_value.shape, npy_value.dtype)
assert (device is None) or device is device_buffer.device()
def _check_if_deleted(self):
    """Fail fast when this DeviceArray no longer owns a live buffer.

    Raises:
      RuntimeError: if the underlying buffer has been deleted.
    """
    # `deleted_buffer` is a module-level sentinel (defined elsewhere in this
    # file); identity with it presumably marks a deleted array — an equality
    # check would be wrong here, only identity signals the sentinel.
    is_deleted = self.device_buffer is deleted_buffer
    if is_deleted:
        raise RuntimeError("DeviceArray has been deleted.")

View File

@@ -1305,7 +1305,7 @@ for t in device_array.device_array_types:
def _device_put_jax_array(x, device: Optional[Device]):
if is_single_device_sharding(x.sharding):
if xla_extension_version >= 128:
if xla_extension_version >= 128 and xla_extension_version <= 129:
return (_copy_array_to_device(x, device),)
else:
x = _copy_device_array_to_device(_set_aval(x._arrays[0]), device)
@@ -1331,10 +1331,16 @@ def _copy_device_array_to_device(
# source and target platforms are the same
if x.device_buffer.device() == device:
# no copying to be done because source equals target
if x.device == device:
return x
if xla_extension_version <= 129:
if x.device == device:
return x
else:
moved_buf = x.device_buffer  # We need to change stickiness
else:
moved_buf = x.device_buffer  # We need to change stickiness
if x._device == device:
return x
else:
moved_buf = x.device_buffer  # We need to change stickiness
else:
# move the buffer with a device-to-device copy
moved_buf = x.device_buffer.copy_to_device(device)
@@ -1352,7 +1358,7 @@ def _copy_array_to_device(x: jax.Array, device: Optional[xc.Device]) -> jax.Array:
return x
arr = x._arrays[0]
if xla_extension_version >= 128 and isinstance(arr, array.ArrayImpl):
if xla_extension_version >= 128 and isinstance(arr, array.ArrayImpl) and xla_extension_version <= 129:
# buffers from different XLA backends are passed through the host.
if xb.get_device_backend(device).platform != arr.platform():
backend = xb.get_device_backend(device)