special.logsumexp: fix incorrect annotation

This commit is contained in:
Jake VanderPlas 2023-08-21 09:10:19 -07:00
parent a14d64b160
commit 27324ff18c

View File

@ -17,7 +17,7 @@ from typing import overload, Literal, Optional, Union
import jax
from jax import lax
from jax import numpy as jnp
from jax._src.numpy.reductions import _reduction_dims
from jax._src.numpy.reductions import _reduction_dims, Axis
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike
import numpy as np
@ -27,18 +27,18 @@ import numpy as np
# unnecessary scipy dependencies.
@overload
def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] = None,
def logsumexp(a: ArrayLike, axis: Axis = None, b: Optional[ArrayLike] = None,
keepdims: bool = False, return_sign: Literal[False] = False) -> Array: ...
@overload
def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] = None,
def logsumexp(a: ArrayLike, axis: Axis = None, b: Optional[ArrayLike] = None,
keepdims: bool = False, *, return_sign: Literal[True]) -> tuple[Array, Array]: ...
@overload
def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] = None,
def logsumexp(a: ArrayLike, axis: Axis = None, b: Optional[ArrayLike] = None,
keepdims: bool = False, return_sign: bool = False) -> Union[Array, tuple[Array, Array]]: ...
def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] = None,
def logsumexp(a: ArrayLike, axis: Axis = None, b: Optional[ArrayLike] = None,
keepdims: bool = False, return_sign: bool = False) -> Union[Array, tuple[Array, Array]]:
r"""Log-sum-exp reduction.