mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #17202 from jakevdp:logsumexp-types
PiperOrigin-RevId: 558873117
This commit is contained in:
commit
2f28848c7c
@ -17,7 +17,7 @@ from typing import overload, Literal, Optional, Union
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax._src.numpy.reductions import _reduction_dims
|
||||
from jax._src.numpy.reductions import _reduction_dims, Axis
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
import numpy as np
|
||||
@ -27,18 +27,18 @@ import numpy as np
|
||||
# unnecessary scipy dependencies.
|
||||
|
||||
@overload
|
||||
def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] = None,
|
||||
def logsumexp(a: ArrayLike, axis: Axis = None, b: Optional[ArrayLike] = None,
|
||||
keepdims: bool = False, return_sign: Literal[False] = False) -> Array: ...
|
||||
|
||||
@overload
|
||||
def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] = None,
|
||||
def logsumexp(a: ArrayLike, axis: Axis = None, b: Optional[ArrayLike] = None,
|
||||
keepdims: bool = False, *, return_sign: Literal[True]) -> tuple[Array, Array]: ...
|
||||
|
||||
@overload
|
||||
def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] = None,
|
||||
def logsumexp(a: ArrayLike, axis: Axis = None, b: Optional[ArrayLike] = None,
|
||||
keepdims: bool = False, return_sign: bool = False) -> Union[Array, tuple[Array, Array]]: ...
|
||||
|
||||
def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] = None,
|
||||
def logsumexp(a: ArrayLike, axis: Axis = None, b: Optional[ArrayLike] = None,
|
||||
keepdims: bool = False, return_sign: bool = False) -> Union[Array, tuple[Array, Array]]:
|
||||
r"""Log-sum-exp reduction.
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user