mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add implementation of numpy packbits() and unpackbits() (#2695)
* Add implementation of numpy packbits() and unpackbits() * Fix packbits() under x64 * Add packbits & unpackbits to docs
This commit is contained in:
parent
8c2901cf4a
commit
ea54c0664f
@ -188,6 +188,7 @@ Not every function in NumPy is implemented; contributions are welcome!
|
||||
ones
|
||||
ones_like
|
||||
outer
|
||||
packbits
|
||||
pad
|
||||
percentile
|
||||
polyval
|
||||
@ -244,6 +245,7 @@ Not every function in NumPy is implemented; contributions are welcome!
|
||||
triu
|
||||
triu_indices
|
||||
true_divide
|
||||
unpackbits
|
||||
vander
|
||||
var
|
||||
vdot
|
||||
|
@ -2852,6 +2852,50 @@ def rollaxis(a, axis, start=0):
|
||||
return moveaxis(a, axis, start)
|
||||
|
||||
|
||||
@_wraps(onp.packbits)
|
||||
def packbits(a, axis=None, bitorder='big'):
|
||||
a = asarray(a)
|
||||
if not (issubdtype(dtype(a), integer) or issubdtype(dtype(a), bool_)):
|
||||
raise TypeError('Expected an input array of integer or boolean data type')
|
||||
if bitorder not in ['little', 'big']:
|
||||
raise ValueError("'order' must be either 'little' or 'big'")
|
||||
a = (a > 0).astype('uint8')
|
||||
bits = arange(8, dtype='uint8')
|
||||
if bitorder == 'big':
|
||||
bits = bits[::-1]
|
||||
if axis is None:
|
||||
a = ravel(a)
|
||||
axis = 0
|
||||
a = swapaxes(a, axis, -1)
|
||||
|
||||
remainder = a.shape[-1] % 8
|
||||
if remainder:
|
||||
a = pad(a, (a.ndim - 1) * [(0, 0)] + [(0, 8 - remainder)])
|
||||
|
||||
a = a.reshape(a.shape[:-1] + (a.shape[-1] // 8, 8))
|
||||
packed = (a << bits).sum(-1).astype('uint8')
|
||||
return swapaxes(packed, axis, -1)
|
||||
|
||||
|
||||
@_wraps(onp.unpackbits)
|
||||
def unpackbits(a, axis=None, count=None, bitorder='big'):
|
||||
a = asarray(a)
|
||||
if dtype(a) != uint8:
|
||||
raise TypeError("Expected an input array of unsigned byte data type")
|
||||
if bitorder not in ['little', 'big']:
|
||||
raise ValueError("'order' must be either 'little' or 'big'")
|
||||
bits = asarray(1) << arange(8, dtype='uint8')
|
||||
if bitorder == 'big':
|
||||
bits = bits[::-1]
|
||||
if axis is None:
|
||||
a = a.ravel()
|
||||
axis = 0
|
||||
a = swapaxes(a, axis, -1)
|
||||
unpacked = ((a[..., None] & bits) > 0).astype('uint8')
|
||||
unpacked = unpacked.reshape(unpacked.shape[:-2] + (-1,))[..., :count]
|
||||
return swapaxes(unpacked, axis, -1)
|
||||
|
||||
|
||||
@_wraps(onp.take)
|
||||
def take(a, indices, axis=None, out=None, mode=None):
|
||||
if out:
|
||||
|
@ -2029,6 +2029,43 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(jnp_op, onp_op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_axis={}_bitorder={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), axis, bitorder),
|
||||
"rng_factory": rng_factory, "shape": shape, "dtype": dtype, "axis": axis,
|
||||
"bitorder": bitorder}
|
||||
for dtype in [onp.uint8, onp.bool_]
|
||||
for bitorder in ['big', 'little']
|
||||
for shape in [(1, 2, 3, 4)]
|
||||
for axis in [None, 0, 1, -2, -1]
|
||||
for rng_factory in [jtu.rand_some_zero]))
|
||||
def testPackbits(self, shape, dtype, axis, bitorder, rng_factory):
|
||||
rng = rng_factory()
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
jnp_op = partial(jnp.packbits, axis=axis, bitorder=bitorder)
|
||||
onp_op = partial(onp.packbits, axis=axis, bitorder=bitorder)
|
||||
self._CheckAgainstNumpy(jnp_op, onp_op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_axis={}_bitorder={}_count={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), axis, bitorder, count),
|
||||
"rng_factory": rng_factory, "shape": shape, "dtype": dtype, "axis": axis,
|
||||
"bitorder": bitorder, "count": count}
|
||||
for dtype in [onp.uint8]
|
||||
for bitorder in ['big', 'little']
|
||||
for shape in [(1, 2, 3, 4)]
|
||||
for axis in [None, 0, 1, -2, -1]
|
||||
for count in [None, 20]
|
||||
for rng_factory in [jtu.rand_int]))
|
||||
def testUnpackbits(self, shape, dtype, axis, bitorder, count, rng_factory):
|
||||
rng = rng_factory(0, 256)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
jnp_op = partial(jnp.unpackbits, axis=axis, bitorder=bitorder)
|
||||
onp_op = partial(onp.unpackbits, axis=axis, bitorder=bitorder)
|
||||
self._CheckAgainstNumpy(jnp_op, onp_op, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_index={}_axis={}_mode={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype),
|
||||
|
Loading…
x
Reference in New Issue
Block a user