test digitize corner case and fix it

This commit is contained in:
Giacomo Petrillo 2024-03-14 16:55:06 -05:00
parent 9a00721a54
commit fb91b51320
2 changed files with 2 additions and 2 deletions

View File

@ -5373,7 +5373,7 @@ def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False) -> Array:
if bins_arr.ndim != 1:
raise ValueError(f"digitize: bins must be a 1-dimensional array; got {bins=}")
if bins_arr.shape[0] == 0:
return zeros(x, dtype=dtypes.canonicalize_dtype(int_))
return zeros_like(x, dtype=int32)
side = 'right' if not right else 'left'
return where(
bins_arr[-1] >= bins_arr[0],

View File

@ -2597,7 +2597,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
@jtu.sample_product(
xshape=[(20,), (5, 4)],
binshape=[(1,), (5,)],
binshape=[(0,), (1,), (5,)],
right=[True, False],
reverse=[True, False],
dtype=default_dtypes,