diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index e491f8d7a..dee61c145 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -601,7 +601,7 @@ def ldexp(x: ArrayLike, y: ArrayLike, /) -> Array: ... def left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ... def less(x: ArrayLike, y: ArrayLike, /) -> Array: ... def less_equal(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def lexsort(keys: Sequence[ArrayLike], axis: int = ...) -> Array: ... +def lexsort(keys: Array | _np.ndarray | Sequence[ArrayLike], axis: int = ...) -> Array: ... @overload def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,