diff --git a/jax/_src/ops/special.py b/jax/_src/ops/special.py index d856f8899..545844ad3 100644 --- a/jax/_src/ops/special.py +++ b/jax/_src/ops/special.py @@ -17,7 +17,7 @@ from typing import overload, Literal, Optional, Union import jax from jax import lax from jax import numpy as jnp -from jax._src.numpy.reductions import _reduction_dims +from jax._src.numpy.reductions import _reduction_dims, Axis from jax._src.numpy.util import promote_args_inexact from jax._src.typing import Array, ArrayLike import numpy as np @@ -27,18 +27,18 @@ import numpy as np # unnecessary scipy dependencies. @overload -def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] = None, +def logsumexp(a: ArrayLike, axis: Axis = None, b: Optional[ArrayLike] = None, keepdims: bool = False, return_sign: Literal[False] = False) -> Array: ... @overload -def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] = None, +def logsumexp(a: ArrayLike, axis: Axis = None, b: Optional[ArrayLike] = None, keepdims: bool = False, *, return_sign: Literal[True]) -> tuple[Array, Array]: ... @overload -def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] = None, +def logsumexp(a: ArrayLike, axis: Axis = None, b: Optional[ArrayLike] = None, keepdims: bool = False, return_sign: bool = False) -> Union[Array, tuple[Array, Array]]: ... -def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] = None, +def logsumexp(a: ArrayLike, axis: Axis = None, b: Optional[ArrayLike] = None, keepdims: bool = False, return_sign: bool = False) -> Union[Array, tuple[Array, Array]]: r"""Log-sum-exp reduction.