mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Merge pull request #25650 from jakevdp:view-int4
PiperOrigin-RevId: 708468858
This commit is contained in:
commit
1c0dee8012
@ -209,14 +209,16 @@ def bit_width(dtype: DTypeLike) -> int:
|
||||
"""Number of bits per element for the dtype."""
|
||||
# Note: we cannot use dtype.itemsize here because this is
|
||||
# incorrect for sub-byte integer types.
|
||||
if dtype == bool:
|
||||
if dtype == np.dtype(bool):
|
||||
return 8 # physical bit layout for boolean dtype
|
||||
elif issubdtype(dtype, np.integer):
|
||||
return iinfo(dtype).bits
|
||||
elif issubdtype(dtype, np.inexact):
|
||||
elif issubdtype(dtype, np.floating):
|
||||
return finfo(dtype).bits
|
||||
elif issubdtype(dtype, np.complexfloating):
|
||||
return 2 * finfo(dtype).bits
|
||||
else:
|
||||
raise ValueError("unexpected input: {dtype=}")
|
||||
raise ValueError(f"unexpected input: {dtype=}")
|
||||
|
||||
# Trivial vectorspace datatype needed for tangent values of int/bool primals
|
||||
float0: np.dtype = np.dtype([('float0', np.void, 0)])
|
||||
|
@ -509,12 +509,15 @@ def _view(self: Array, dtype: DTypeLike | None = None, type: None = None) -> Arr
|
||||
dtypes.check_user_dtype_supported(dtype, "view")
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
|
||||
nbits_in = dtypes.bit_width(self.dtype)
|
||||
nbits_out = dtypes.bit_width(dtype)
|
||||
|
||||
if self.ndim == 0:
|
||||
if self.dtype.itemsize != dtype.itemsize:
|
||||
if nbits_in != nbits_out:
|
||||
raise ValueError("view() of a 0d array is only supported if the itemsize is unchanged.")
|
||||
return _view(lax.expand_dims(self, (0,)), dtype).squeeze()
|
||||
|
||||
if (self.shape[-1] * self.dtype.itemsize) % dtype.itemsize != 0:
|
||||
if (self.shape[-1] * nbits_in) % nbits_out != 0:
|
||||
raise ValueError("When changing to a larger dtype, its size must be a divisor "
|
||||
"of the total size in bytes of the last axis of the array.")
|
||||
|
||||
@ -543,16 +546,15 @@ def _view(self: Array, dtype: DTypeLike | None = None, type: None = None) -> Arr
|
||||
|
||||
# lax.bitcast_convert_type adds or subtracts dimensions depending on the
|
||||
# relative bitwidths of the dtypes; we account for that with reshapes.
|
||||
if self.dtype.itemsize < dtype.itemsize:
|
||||
factor = dtype.itemsize // self.dtype.itemsize
|
||||
if nbits_in < nbits_out:
|
||||
factor = nbits_out // nbits_in
|
||||
out = self.reshape(*self.shape[:-1], self.shape[-1] // factor, factor)
|
||||
return lax.bitcast_convert_type(out, dtype)
|
||||
|
||||
if self.dtype.itemsize > dtype.itemsize:
|
||||
elif nbits_in > nbits_out:
|
||||
out = lax.bitcast_convert_type(self, dtype)
|
||||
return out.reshape(*out.shape[:-2], out.shape[-2] * out.shape[-1])
|
||||
|
||||
return lax.bitcast_convert_type(self, dtype)
|
||||
else:
|
||||
return lax.bitcast_convert_type(self, dtype)
|
||||
|
||||
|
||||
def _notimplemented_flat(self):
|
||||
|
@ -87,6 +87,32 @@ python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_, jnp.complex_]
|
||||
# uint64 is problematic because with any uint type it promotes to float:
|
||||
int_dtypes_no_uint64 = [d for d in int_dtypes + unsigned_dtypes if d != np.uint64]
|
||||
|
||||
def _bitcast_uint4_to_uint8(operand):
|
||||
# Note: assumes little-endian byte order.
|
||||
assert operand.dtype == 'uint4'
|
||||
operand = operand.astype('uint8')
|
||||
return operand[..., ::2] + (operand[..., 1::2] << 4)
|
||||
|
||||
def _bitcast_uint8_to_uint4(operand):
|
||||
# Note: assumes little-endian byte order.
|
||||
assert operand.dtype == 'uint8'
|
||||
result = np.zeros((*operand.shape[:-1], operand.shape[-1] * 2), dtype='uint4')
|
||||
result[..., ::2] = (operand & 0b00001111).astype('uint4')
|
||||
result[..., 1::2] = ((operand & 0b11110000) >> 4).astype('uint4')
|
||||
return result
|
||||
|
||||
def np_view(arr, dtype):
|
||||
# Implementation of np.ndarray.view() that works for int4/uint4
|
||||
dtype = np.dtype(dtype)
|
||||
nbits_in = dtypes.bit_width(arr.dtype)
|
||||
nbits_out = dtypes.bit_width(dtype)
|
||||
if nbits_in == 4:
|
||||
arr = _bitcast_uint4_to_uint8(arr.view('uint4'))
|
||||
if nbits_out == 4:
|
||||
arr = _bitcast_uint8_to_uint4(arr.view('uint8'))
|
||||
return arr.view(dtype)
|
||||
|
||||
|
||||
def np_unique_backport(ar, return_index=False, return_inverse=False, return_counts=False,
|
||||
axis=None, **kwds):
|
||||
# Wrapper for np.unique, handling the change to inverse_indices in numpy 2.0
|
||||
@ -4244,9 +4270,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(
|
||||
# Final dimension must be a multiple of 16 to ensure compatibility of all dtype pairs.
|
||||
shape=[(0,), (32,), (2, 16)],
|
||||
a_dtype=all_dtypes,
|
||||
dtype=(*all_dtypes, None) if config.enable_x64.value else all_dtypes,
|
||||
shape=[(0,), (64,), (2, 32)],
|
||||
a_dtype=(jnp.int4, jnp.uint4, *all_dtypes),
|
||||
dtype=((jnp.int4, jnp.uint4, *all_dtypes, None)
|
||||
if config.enable_x64.value else (jnp.int4, jnp.uint4, *all_dtypes)),
|
||||
)
|
||||
def testView(self, shape, a_dtype, dtype):
|
||||
if jtu.test_device_matches(["tpu"]):
|
||||
@ -4259,7 +4286,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self.rng()
|
||||
)
|
||||
args_maker = lambda: [rng(shape, a_dtype)]
|
||||
np_op = lambda x: np.asarray(x).view(dtype)
|
||||
np_op = lambda x: np_view(x, dtype)
|
||||
jnp_op = lambda x: jnp.asarray(x).view(dtype)
|
||||
# Above may produce signaling nans; ignore warnings from invalid values.
|
||||
with np.errstate(invalid='ignore'):
|
||||
@ -4268,9 +4295,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product([
|
||||
{'a_dtype': a_dtype, 'dtype': dtype}
|
||||
for a_dtype in all_dtypes
|
||||
for dtype in all_dtypes
|
||||
if np.dtype(a_dtype).itemsize == np.dtype(dtype).itemsize
|
||||
for a_dtype in [jnp.int4, jnp.uint4, *all_dtypes]
|
||||
for dtype in [jnp.int4, jnp.uint4, *all_dtypes]
|
||||
if dtypes.bit_width(a_dtype) == dtypes.bit_width(dtype)
|
||||
])
|
||||
def testViewScalar(self, a_dtype, dtype):
|
||||
if jtu.test_device_matches(["tpu"]):
|
||||
|
Loading…
x
Reference in New Issue
Block a user