Merge pull request #7228 from jakevdp:packbits-rank-promotion

PiperOrigin-RevId: 384257098
This commit is contained in:
jax authors 2021-07-12 09:58:28 -07:00
commit 24df92c61e
2 changed files with 2 additions and 0 deletions

View File

@ -4718,6 +4718,7 @@ def packbits(a, axis: Optional[int] = None, bitorder='big'):
(a.ndim - 1) * [(0, 0, 0)] + [(0, 8 - remainder, 0)])
a = a.reshape(a.shape[:-1] + (a.shape[-1] // 8, 8))
bits = expand_dims(bits, tuple(range(a.ndim - 1)))
packed = (a << bits).sum(-1).astype('uint8')
return swapaxes(packed, axis, -1)

View File

@ -3971,6 +3971,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
for bitorder in ['big', 'little']
for shape in [(1, 2, 3, 4)]
for axis in [None, 0, 1, -2, -1]))
@jax.numpy_rank_promotion('raise')
def testPackbits(self, shape, dtype, axis, bitorder):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, dtype)]