mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Rollback of array fix again for perf regression.
PiperOrigin-RevId: 511879030
This commit is contained in:
parent
35a27359d0
commit
b5026207bc
@ -351,7 +351,7 @@ class ArrayImpl(basearray.Array):
|
||||
'named_shape': self.aval.named_shape}
|
||||
return (_reconstruct_array, (fun, args, arr_state, aval_state))
|
||||
|
||||
@use_cpp_method(xla_extension_version >= 128)
|
||||
@use_cpp_method(xla_extension_version >= 128 and xla_extension_version <= 129)
|
||||
def unsafe_buffer_pointer(self):
|
||||
if len(self._arrays) != 1:
|
||||
raise ValueError("unsafe_buffer_pointer() is supported only for unsharded"
|
||||
@ -443,7 +443,7 @@ class ArrayImpl(basearray.Array):
|
||||
out.append(Shard(global_d, self.sharding, self.shape, array))
|
||||
return out
|
||||
|
||||
@use_cpp_method(xla_extension_version >= 128)
|
||||
@use_cpp_method(xla_extension_version >= 128 and xla_extension_version <= 129)
|
||||
def delete(self):
|
||||
if self._arrays is None:
|
||||
return
|
||||
@ -489,7 +489,7 @@ class ArrayImpl(basearray.Array):
|
||||
|
||||
for s in self.addressable_shards:
|
||||
if not replica_id_exists or s.replica_id == 0:
|
||||
if xla_extension_version >= 128:
|
||||
if xla_extension_version >= 128 and xla_extension_version <= 129:
|
||||
s.data.copy_to_host_async() # pytype: disable=attribute-error
|
||||
else:
|
||||
s.data._arrays[0].copy_to_host_async() # pytype: disable=attribute-error
|
||||
|
@ -29,7 +29,6 @@ from jax._src import profiler
|
||||
from jax._src import util
|
||||
from jax._src.config import config
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.typing import Array
|
||||
|
||||
### device-persistent data
|
||||
@ -60,14 +59,6 @@ def make_device_array(
|
||||
This is to be used only within JAX. It will return either a PythonDeviceArray
|
||||
or a C++ equivalent implementation.
|
||||
"""
|
||||
from jax._src import array
|
||||
|
||||
if jax.config.jax_array and xla_extension_version >= 128:
|
||||
if isinstance(device_buffer, xc.Buffer):
|
||||
return array._single_device_array_from_buf(
|
||||
device_buffer, False if device_buffer._device is None else True)
|
||||
elif isinstance(device_buffer, array.ArrayImpl):
|
||||
return device_buffer
|
||||
|
||||
if isinstance(device_buffer, xc.Buffer):
|
||||
|
||||
@ -135,6 +126,7 @@ class _DeviceArray(DeviceArray): # type: ignore
|
||||
aval, npy_value.shape, npy_value.dtype)
|
||||
assert (device is None) or device is device_buffer.device()
|
||||
|
||||
|
||||
def _check_if_deleted(self):
|
||||
if self.device_buffer is deleted_buffer:
|
||||
raise RuntimeError("DeviceArray has been deleted.")
|
||||
|
@ -1305,7 +1305,7 @@ for t in device_array.device_array_types:
|
||||
|
||||
def _device_put_jax_array(x, device: Optional[Device]):
|
||||
if is_single_device_sharding(x.sharding):
|
||||
if xla_extension_version >= 128:
|
||||
if xla_extension_version >= 128 and xla_extension_version <= 129:
|
||||
return (_copy_array_to_device(x, device),)
|
||||
else:
|
||||
x = _copy_device_array_to_device(_set_aval(x._arrays[0]), device)
|
||||
@ -1331,10 +1331,16 @@ def _copy_device_array_to_device(
|
||||
# source and target platforms are the same
|
||||
if x.device_buffer.device() == device:
|
||||
# no copying to be done because source equals target
|
||||
if x.device == device:
|
||||
return x
|
||||
if xla_extension_version <= 129:
|
||||
if x.device == device:
|
||||
return x
|
||||
else:
|
||||
moved_buf = x.device_buffer # We need to change stickyness
|
||||
else:
|
||||
moved_buf = x.device_buffer # We need to change stickyness
|
||||
if x._device == device:
|
||||
return x
|
||||
else:
|
||||
moved_buf = x.device_buffer # We need to change stickyness
|
||||
else:
|
||||
# move the buffer with a device-to-device copy
|
||||
moved_buf = x.device_buffer.copy_to_device(device)
|
||||
@ -1352,7 +1358,7 @@ def _copy_array_to_device(x: jax.Array, device: Optional[xc.Device]) -> jax.Arra
|
||||
return x
|
||||
|
||||
arr = x._arrays[0]
|
||||
if xla_extension_version >= 128 and isinstance(arr, array.ArrayImpl):
|
||||
if xla_extension_version >= 128 and isinstance(arr, array.ArrayImpl) and xla_extension_version <= 129:
|
||||
# buffers from different XLA backends are passed through the host.
|
||||
if xb.get_device_backend(device).platform != arr.platform():
|
||||
backend = xb.get_device_backend(device)
|
||||
|
Loading…
x
Reference in New Issue
Block a user