mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #8236 from jakevdp:fix-bincount
PiperOrigin-RevId: 403514221
This commit is contained in:
commit
69d7a813e7
@ -2165,11 +2165,12 @@ def select(condlist, choicelist, default=0):
|
||||
|
||||
@_wraps(np.bincount, lax_description="""\
|
||||
Jax adds the optional `length` parameter which specifies the output length, and
|
||||
defaults to ``x.max() + 1``. It must be specified for bincount to be compilable.
|
||||
Values larger than the specified length will be discarded.
|
||||
defaults to ``x.max() + 1``. It must be specified for bincount to be compiled
|
||||
with non-static operands. Values larger than the specified length will be discarded.
|
||||
If `length` is specified, `minlength` will be ignored.
|
||||
|
||||
Additionally, while ``np.bincount`` raises an error if the input array contains
|
||||
negative values, ``jax.numpy.bincount`` treats negative values as zero.
|
||||
negative values, ``jax.numpy.bincount`` clips negative values to zero.
|
||||
""")
|
||||
def bincount(x, weights=None, minlength=0, *, length=None):
|
||||
_check_arraylike("bincount", x)
|
||||
@ -2184,11 +2185,10 @@ def bincount(x, weights=None, minlength=0, *, length=None):
|
||||
x = core.concrete_or_error(asarray, x,
|
||||
"The error occured because of argument 'x' of jnp.bincount. "
|
||||
"To avoid this error, pass a static `length` argument.")
|
||||
length = max(x, initial=-1) + 1
|
||||
length = _max(minlength, x.size and x.max() + 1)
|
||||
else:
|
||||
length = core.concrete_or_error(operator.index, length,
|
||||
"The error occurred because of argument 'length' of jnp.bincount.")
|
||||
length = _max(length, minlength)
|
||||
if weights is None:
|
||||
weights = 1
|
||||
elif shape(x) != shape(weights):
|
||||
|
@ -4578,19 +4578,23 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
for dtype in int_dtypes
|
||||
for weights in [True, False]
|
||||
for minlength in [0, 20]
|
||||
for length in [None, 10]
|
||||
for length in [None, 8]
|
||||
))
|
||||
def testBincount(self, shape, dtype, weights, minlength, length):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype), (rng(shape, 'float32') if weights else None))
|
||||
|
||||
np_fun = partial(np.bincount, minlength=minlength)
|
||||
def np_fun(x, *args):
|
||||
x = np.clip(x, 0, None) # jnp.bincount clips negative values to zero.
|
||||
out = np.bincount(x, *args, minlength=minlength)
|
||||
if length and length > out.size:
|
||||
return np.pad(out, (0, length - out.size))
|
||||
return out[:length]
|
||||
jnp_fun = partial(jnp.bincount, minlength=minlength, length=length)
|
||||
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
if length is not None:
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
if length is None:
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
|
||||
def testBincountNegative(self):
|
||||
# Test that jnp.bincount ignores negative values.
|
||||
|
Loading…
x
Reference in New Issue
Block a user