mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #7228 from jakevdp:packbits-rank-promotion
PiperOrigin-RevId: 384257098
This commit is contained in:
commit
24df92c61e
@ -4718,6 +4718,7 @@ def packbits(a, axis: Optional[int] = None, bitorder='big'):
|
||||
(a.ndim - 1) * [(0, 0, 0)] + [(0, 8 - remainder, 0)])
|
||||
|
||||
a = a.reshape(a.shape[:-1] + (a.shape[-1] // 8, 8))
|
||||
bits = expand_dims(bits, tuple(range(a.ndim - 1)))
|
||||
packed = (a << bits).sum(-1).astype('uint8')
|
||||
return swapaxes(packed, axis, -1)
|
||||
|
||||
|
@ -3971,6 +3971,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
for bitorder in ['big', 'little']
|
||||
for shape in [(1, 2, 3, 4)]
|
||||
for axis in [None, 0, 1, -2, -1]))
|
||||
@jax.numpy_rank_promotion('raise')
|
||||
def testPackbits(self, shape, dtype, axis, bitorder):
|
||||
rng = jtu.rand_some_zero(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
Loading…
x
Reference in New Issue
Block a user