Add implementation of numpy packbits() and unpackbits() (#2695)

* Add implementation of numpy packbits() and unpackbits()

* Fix packbits() under x64

* Add packbits & unpackbits to docs
This commit is contained in:
Jake Vanderplas 2020-04-13 11:57:18 -07:00 committed by GitHub
parent 8c2901cf4a
commit ea54c0664f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 83 additions and 0 deletions

View File

@ -188,6 +188,7 @@ Not every function in NumPy is implemented; contributions are welcome!
ones
ones_like
outer
packbits
pad
percentile
polyval
@ -244,6 +245,7 @@ Not every function in NumPy is implemented; contributions are welcome!
triu
triu_indices
true_divide
unpackbits
vander
var
vdot

View File

@ -2852,6 +2852,50 @@ def rollaxis(a, axis, start=0):
return moveaxis(a, axis, start)
@_wraps(onp.packbits)
def packbits(a, axis=None, bitorder='big'):
a = asarray(a)
if not (issubdtype(dtype(a), integer) or issubdtype(dtype(a), bool_)):
raise TypeError('Expected an input array of integer or boolean data type')
if bitorder not in ['little', 'big']:
raise ValueError("'order' must be either 'little' or 'big'")
a = (a > 0).astype('uint8')
bits = arange(8, dtype='uint8')
if bitorder == 'big':
bits = bits[::-1]
if axis is None:
a = ravel(a)
axis = 0
a = swapaxes(a, axis, -1)
remainder = a.shape[-1] % 8
if remainder:
a = pad(a, (a.ndim - 1) * [(0, 0)] + [(0, 8 - remainder)])
a = a.reshape(a.shape[:-1] + (a.shape[-1] // 8, 8))
packed = (a << bits).sum(-1).astype('uint8')
return swapaxes(packed, axis, -1)
@_wraps(onp.unpackbits)
def unpackbits(a, axis=None, count=None, bitorder='big'):
a = asarray(a)
if dtype(a) != uint8:
raise TypeError("Expected an input array of unsigned byte data type")
if bitorder not in ['little', 'big']:
raise ValueError("'order' must be either 'little' or 'big'")
bits = asarray(1) << arange(8, dtype='uint8')
if bitorder == 'big':
bits = bits[::-1]
if axis is None:
a = a.ravel()
axis = 0
a = swapaxes(a, axis, -1)
unpacked = ((a[..., None] & bits) > 0).astype('uint8')
unpacked = unpacked.reshape(unpacked.shape[:-2] + (-1,))[..., :count]
return swapaxes(unpacked, axis, -1)
@_wraps(onp.take)
def take(a, indices, axis=None, out=None, mode=None):
if out:

View File

@ -2029,6 +2029,43 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(jnp_op, onp_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}_bitorder={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis, bitorder),
"rng_factory": rng_factory, "shape": shape, "dtype": dtype, "axis": axis,
"bitorder": bitorder}
for dtype in [onp.uint8, onp.bool_]
for bitorder in ['big', 'little']
for shape in [(1, 2, 3, 4)]
for axis in [None, 0, 1, -2, -1]
for rng_factory in [jtu.rand_some_zero]))
def testPackbits(self, shape, dtype, axis, bitorder, rng_factory):
rng = rng_factory()
args_maker = lambda: [rng(shape, dtype)]
jnp_op = partial(jnp.packbits, axis=axis, bitorder=bitorder)
onp_op = partial(onp.packbits, axis=axis, bitorder=bitorder)
self._CheckAgainstNumpy(jnp_op, onp_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}_bitorder={}_count={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis, bitorder, count),
"rng_factory": rng_factory, "shape": shape, "dtype": dtype, "axis": axis,
"bitorder": bitorder, "count": count}
for dtype in [onp.uint8]
for bitorder in ['big', 'little']
for shape in [(1, 2, 3, 4)]
for axis in [None, 0, 1, -2, -1]
for count in [None, 20]
for rng_factory in [jtu.rand_int]))
def testUnpackbits(self, shape, dtype, axis, bitorder, count, rng_factory):
rng = rng_factory(0, 256)
args_maker = lambda: [rng(shape, dtype)]
jnp_op = partial(jnp.unpackbits, axis=axis, bitorder=bitorder)
onp_op = partial(onp.unpackbits, axis=axis, bitorder=bitorder)
self._CheckAgainstNumpy(jnp_op, onp_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_index={}_axis={}_mode={}".format(
jtu.format_shape_dtype_string(shape, dtype),