mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
test digitize corner case and fix it
This commit is contained in:
parent
9a00721a54
commit
fb91b51320
@ -5373,7 +5373,7 @@ def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False) -> Array:
|
||||
if bins_arr.ndim != 1:
|
||||
raise ValueError(f"digitize: bins must be a 1-dimensional array; got {bins=}")
|
||||
if bins_arr.shape[0] == 0:
|
||||
return zeros(x, dtype=dtypes.canonicalize_dtype(int_))
|
||||
return zeros_like(x, dtype=int32)
|
||||
side = 'right' if not right else 'left'
|
||||
return where(
|
||||
bins_arr[-1] >= bins_arr[0],
|
||||
|
@ -2597,7 +2597,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(
|
||||
xshape=[(20,), (5, 4)],
|
||||
binshape=[(1,), (5,)],
|
||||
binshape=[(0,), (1,), (5,)],
|
||||
right=[True, False],
|
||||
reverse=[True, False],
|
||||
dtype=default_dtypes,
|
||||
|
Loading…
x
Reference in New Issue
Block a user