mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Roll-back https://github.com/google/jax/pull/14526 because it breaks view()
on scalar inputs
PiperOrigin-RevId: 510281592
This commit is contained in:
parent
c467d84eea
commit
e1333f3de0
@ -4978,54 +4978,47 @@ def _view(arr: Array, dtype: DTypeLike = None, type: None = None) -> Array:
|
||||
|
||||
This is fuller-featured wrapper around :func:`jax.lax.bitcast_convert_type`.
|
||||
"""
|
||||
if type is not None:
|
||||
raise NotImplementedError("`type` argument of array.view() is not supported.")
|
||||
|
||||
_check_arraylike("view", arr)
|
||||
arr = asarray(arr)
|
||||
|
||||
lax_internal._check_user_dtype_supported(dtype, "view")
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
|
||||
if (arr.shape[-1] * arr.dtype.itemsize) % dtype.itemsize != 0:
|
||||
if type is not None:
|
||||
raise NotImplementedError("`type` argument of array.view()")
|
||||
if dtype is None:
|
||||
return arr
|
||||
arr_dtype = _dtype(arr)
|
||||
if arr_dtype == dtype:
|
||||
return arr
|
||||
# bool is implemented as lax:PRED, which is not compatible with lax.bitcast_convert_type.
|
||||
# We work around this by casting bool to uint8.
|
||||
if arr_dtype == bool_:
|
||||
arr = arr.astype(uint8)
|
||||
nbits_in = 8 * arr_dtype.itemsize
|
||||
nbits_out = 8 * np.dtype(dtype).itemsize
|
||||
if nbits_in == nbits_out:
|
||||
if dtype == bool_:
|
||||
return lax.bitcast_convert_type(arr, uint8).astype(dtype)
|
||||
return lax.bitcast_convert_type(arr, dtype)
|
||||
if nbits_out > nbits_in and (shape(arr)[-1] * nbits_in) % nbits_out != 0:
|
||||
raise ValueError("When changing to a larger dtype, its size must be a divisor "
|
||||
"of the total size in bytes of the last axis of the array.")
|
||||
|
||||
if arr.dtype == dtype:
|
||||
return arr
|
||||
|
||||
# lax.bitcast_convert_type does not support bool or complex; in these cases we
|
||||
# cast to a compatible type and recursively call _view for simplicity.
|
||||
if arr.dtype == bool:
|
||||
return _view(arr.astype('uint8'), dtype)
|
||||
|
||||
if issubdtype(arr.dtype, complexfloating):
|
||||
new_shape = (*arr.shape[:-1], arr.shape[-1] * 2)
|
||||
new_dtype = finfo(arr.dtype).dtype
|
||||
arr = (zeros(new_shape, new_dtype)
|
||||
.at[..., 0::2].set(arr.real)
|
||||
.at[..., 1::2].set(arr.imag))
|
||||
return _view(arr, dtype)
|
||||
|
||||
if dtype == bool:
|
||||
return _view(arr, uint8).astype(bool)
|
||||
|
||||
if issubdtype(dtype, complexfloating):
|
||||
out = _view(arr, finfo(dtype).dtype).astype(dtype)
|
||||
return out[..., 0::2] + 1j * out[..., 1::2]
|
||||
|
||||
# lax.bitcast_convert_type adds or subtracts dimensions depending on the
|
||||
# relative bitwidths of the dtypes; we account for that with reshapes.
|
||||
if arr.dtype.itemsize < dtype.itemsize:
|
||||
factor = dtype.itemsize // arr.dtype.itemsize
|
||||
arr = arr.reshape(*arr.shape[:-1], arr.shape[-1] // factor, factor)
|
||||
return lax.bitcast_convert_type(arr, dtype)
|
||||
|
||||
if arr.dtype.itemsize > dtype.itemsize:
|
||||
out = lax.bitcast_convert_type(arr, dtype)
|
||||
return out.reshape(*out.shape[:-2], out.shape[-2] * out.shape[-1])
|
||||
|
||||
return lax.bitcast_convert_type(arr, dtype)
|
||||
byte_dtypes: Dict[int, DType] = {8: np.dtype('uint8'), 16: np.dtype('uint16'),
|
||||
32: np.dtype('uint32'), 64: np.dtype('uint64')}
|
||||
if nbits_in not in byte_dtypes:
|
||||
raise NotImplementedError(f"arr.view() for arr.dtype={arr_dtype}")
|
||||
if nbits_out not in byte_dtypes:
|
||||
raise NotImplementedError(f"arr.view(dtype) for {dtype=}")
|
||||
dt_in = byte_dtypes[nbits_in]
|
||||
dt_out = byte_dtypes[nbits_out]
|
||||
arr_bytes = lax.bitcast_convert_type(arr, dt_in)
|
||||
if nbits_in < nbits_out:
|
||||
arr_bytes = arr_bytes.reshape(arr.shape[:-1] + (-1, nbits_out // nbits_in)).astype(dt_out)
|
||||
shifts = expand_dims(arange(0, nbits_out, nbits_in, dtype=dt_out), tuple(range(arr_bytes.ndim - 1)))
|
||||
arr_bytes = (arr_bytes << shifts).sum(-1).astype(dt_out)
|
||||
else:
|
||||
shifts = lax.expand_dims(arange(0, nbits_in, nbits_out, dtype=dt_in), tuple(range(arr_bytes.ndim)))
|
||||
arr_bytes = ((arr_bytes[..., newaxis] >> shifts) & iinfo(dt_out).max).astype(dt_out)
|
||||
arr_bytes = arr_bytes.reshape(arr_bytes.shape[:-2] + (-1,))
|
||||
if dtype == bool_:
|
||||
return lax.bitcast_convert_type(arr_bytes, uint8).astype(dtype)
|
||||
return lax.bitcast_convert_type(arr_bytes, dtype)
|
||||
|
||||
def _notimplemented_flat(self):
|
||||
raise NotImplementedError("JAX DeviceArrays do not implement the arr.flat property: "
|
||||
|
@ -3430,15 +3430,17 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
# Final dimension must be a multiple of 16 to ensure compatibilty of all dtype pairs.
|
||||
shape=[(0,), (32,), (2, 16)],
|
||||
a_dtype=all_dtypes,
|
||||
dtype=(*all_dtypes, None) if config.x64_enabled else all_dtypes,
|
||||
shape=[(8,), (3, 8)], # last dim = 8 to ensure shape compatibility
|
||||
a_dtype=default_dtypes + unsigned_dtypes + bool_dtypes,
|
||||
dtype=default_dtypes + unsigned_dtypes + bool_dtypes,
|
||||
)
|
||||
def testView(self, shape, a_dtype, dtype):
|
||||
if jtu.device_under_test() == 'tpu':
|
||||
if jnp.dtype(a_dtype).itemsize in [1, 2] or jnp.dtype(dtype).itemsize in [1, 2]:
|
||||
self.skipTest("arr.view() not supported on TPU for 8- or 16-bit types.")
|
||||
if not config.x64_enabled:
|
||||
if jnp.dtype(a_dtype).itemsize == 8 or jnp.dtype(dtype).itemsize == 8:
|
||||
self.skipTest("x64 types are disabled by jax_enable_x64")
|
||||
rng = jtu.rand_fullrange(self.rng())
|
||||
args_maker = lambda: [rng(shape, a_dtype)]
|
||||
np_op = lambda x: np.asarray(x).view(dtype)
|
||||
|
Loading…
x
Reference in New Issue
Block a user