Merge pull request #8236 from jakevdp:fix-bincount

PiperOrigin-RevId: 403514221
This commit is contained in:
jax authors 2021-10-15 18:39:20 -07:00
commit 69d7a813e7
2 changed files with 14 additions and 10 deletions

View File

@ -2165,11 +2165,12 @@ def select(condlist, choicelist, default=0):
@_wraps(np.bincount, lax_description="""\
Jax adds the optional `length` parameter which specifies the output length, and
defaults to ``x.max() + 1``. It must be specified for bincount to be compilable.
Values larger than the specified length will be discarded.
defaults to ``x.max() + 1``. It must be specified for bincount to be compiled
with non-static operands. Values larger than the specified length will be discarded.
If `length` is specified, `minlength` will be ignored.
Additionally, while ``np.bincount`` raises an error if the input array contains
negative values, ``jax.numpy.bincount`` treats negative values as zero.
negative values, ``jax.numpy.bincount`` clips negative values to zero.
""")
def bincount(x, weights=None, minlength=0, *, length=None):
_check_arraylike("bincount", x)
@ -2184,11 +2185,10 @@ def bincount(x, weights=None, minlength=0, *, length=None):
x = core.concrete_or_error(asarray, x,
"The error occured because of argument 'x' of jnp.bincount. "
"To avoid this error, pass a static `length` argument.")
length = max(x, initial=-1) + 1
length = _max(minlength, x.size and x.max() + 1)
else:
length = core.concrete_or_error(operator.index, length,
"The error occurred because of argument 'length' of jnp.bincount.")
length = _max(length, minlength)
if weights is None:
weights = 1
elif shape(x) != shape(weights):

View File

@ -4578,19 +4578,23 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
for dtype in int_dtypes
for weights in [True, False]
for minlength in [0, 20]
for length in [None, 10]
for length in [None, 8]
))
def testBincount(self, shape, dtype, weights, minlength, length):
rng = jtu.rand_positive(self.rng())
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype), (rng(shape, 'float32') if weights else None))
np_fun = partial(np.bincount, minlength=minlength)
def np_fun(x, *args):
x = np.clip(x, 0, None) # jnp.bincount clips negative values to zero.
out = np.bincount(x, *args, minlength=minlength)
if length and length > out.size:
return np.pad(out, (0, length - out.size))
return out[:length]
jnp_fun = partial(jnp.bincount, minlength=minlength, length=length)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
if length is not None:
self._CompileAndCheck(jnp_fun, args_maker)
if length is None:
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
def testBincountNegative(self):
# Test that jnp.bincount ignores negative values.