mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 04:16:07 +00:00
special.logsumexp: fix incorrect annotation
This commit is contained in:
parent
a14d64b160
commit
27324ff18c
@ -17,7 +17,7 @@ from typing import overload, Literal, Optional, Union
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax._src.numpy.reductions import _reduction_dims
|
||||
from jax._src.numpy.reductions import _reduction_dims, Axis
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
import numpy as np
|
||||
@ -27,18 +27,18 @@ import numpy as np
|
||||
# unnecessary scipy dependencies.
|
||||
|
||||
@overload
|
||||
def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] = None,
|
||||
def logsumexp(a: ArrayLike, axis: Axis = None, b: Optional[ArrayLike] = None,
|
||||
keepdims: bool = False, return_sign: Literal[False] = False) -> Array: ...
|
||||
|
||||
@overload
|
||||
def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] = None,
|
||||
def logsumexp(a: ArrayLike, axis: Axis = None, b: Optional[ArrayLike] = None,
|
||||
keepdims: bool = False, *, return_sign: Literal[True]) -> tuple[Array, Array]: ...
|
||||
|
||||
@overload
|
||||
def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] = None,
|
||||
def logsumexp(a: ArrayLike, axis: Axis = None, b: Optional[ArrayLike] = None,
|
||||
keepdims: bool = False, return_sign: bool = False) -> Union[Array, tuple[Array, Array]]: ...
|
||||
|
||||
def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] = None,
|
||||
def logsumexp(a: ArrayLike, axis: Axis = None, b: Optional[ArrayLike] = None,
|
||||
keepdims: bool = False, return_sign: bool = False) -> Union[Array, tuple[Array, Array]]:
|
||||
r"""Log-sum-exp reduction.
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user