Merge pull request #17202 from jakevdp:logsumexp-types

PiperOrigin-RevId: 558873117
This commit is contained in:
jax authors 2023-08-21 12:47:45 -07:00
commit 2f28848c7c

View File

@ -17,7 +17,7 @@ from typing import overload, Literal, Optional, Union
import jax
from jax import lax
from jax import numpy as jnp
from jax._src.numpy.reductions import _reduction_dims
from jax._src.numpy.reductions import _reduction_dims, Axis
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike
import numpy as np
@ -27,18 +27,18 @@ import numpy as np
# unnecessary scipy dependencies.
@overload
def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] = None,
def logsumexp(a: ArrayLike, axis: Axis = None, b: Optional[ArrayLike] = None,
keepdims: bool = False, return_sign: Literal[False] = False) -> Array: ...
@overload
def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] = None,
def logsumexp(a: ArrayLike, axis: Axis = None, b: Optional[ArrayLike] = None,
keepdims: bool = False, *, return_sign: Literal[True]) -> tuple[Array, Array]: ...
@overload
def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] = None,
def logsumexp(a: ArrayLike, axis: Axis = None, b: Optional[ArrayLike] = None,
keepdims: bool = False, return_sign: bool = False) -> Union[Array, tuple[Array, Array]]: ...
def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] = None,
def logsumexp(a: ArrayLike, axis: Axis = None, b: Optional[ArrayLike] = None,
keepdims: bool = False, return_sign: bool = False) -> Union[Array, tuple[Array, Array]]:
r"""Log-sum-exp reduction.