mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
MAINT Fixed new mypy errors
mypy seems to handle lambdas and named functions differently. So, I had to promote a few helpers to named functions to get them to type check.
This commit is contained in:
parent
be140981ac
commit
c5d3ece6f5
@ -29,8 +29,13 @@ from jax._src.numpy import lax_numpy as jnp
|
||||
from jax._src.numpy.util import _wraps, _promote_dtypes_inexact
|
||||
from jax._src.util import canonicalize_axis
|
||||
|
||||
_T = lambda x: jnp.swapaxes(x, -1, -2)
|
||||
_H = lambda x: jnp.conjugate(jnp.swapaxes(x, -1, -2))
|
||||
|
||||
def _T(x):
|
||||
return jnp.swapaxes(x, -1, -2)
|
||||
|
||||
|
||||
def _H(x):
|
||||
return jnp.conjugate(jnp.swapaxes(x, -1, -2))
|
||||
|
||||
|
||||
@_wraps(np.linalg.cholesky)
|
||||
@ -500,7 +505,8 @@ def norm(x, ord=None, axis : Union[None, Tuple[int, ...], int] = None,
|
||||
elif ord == -2:
|
||||
reducer = jnp.amin
|
||||
else:
|
||||
reducer = jnp.sum
|
||||
# `sum` takes an extra dtype= argument, unlike `amax` and `amin`.
|
||||
reducer = jnp.sum # type: ignore[assignment]
|
||||
y = reducer(svd(x, compute_uv=False), axis=-1)
|
||||
if keepdims:
|
||||
y = jnp.expand_dims(y, axis)
|
||||
|
@ -471,7 +471,8 @@ def nansum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
|
||||
initial=initial, where=where)
|
||||
|
||||
# Work around a sphinx documentation warning in NumPy 1.22.
|
||||
nansum.__doc__ = nansum.__doc__.replace("\n\n\n", "\n\n")
|
||||
if nansum.__doc__ is not None:
|
||||
nansum.__doc__ = nansum.__doc__.replace("\n\n\n", "\n\n")
|
||||
|
||||
@_wraps(np.nanprod, skip_params=['out'])
|
||||
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
|
||||
|
@ -53,7 +53,9 @@ The JAX version only accepts real-valued inputs.""")
|
||||
def digamma(x):
|
||||
x, = _promote_args_inexact("digamma", x)
|
||||
return lax.digamma(x)
|
||||
ad.defjvp(lax.digamma_p, lambda g, x: lax.mul(g, polygamma(1, x)))
|
||||
ad.defjvp(
|
||||
lax.digamma_p,
|
||||
lambda g, x: lax.mul(g, polygamma(1, x))) # type: ignore[has-type]
|
||||
|
||||
|
||||
@_wraps(osp_special.gammainc, update_doc=False)
|
||||
|
@ -1551,7 +1551,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
# poly_axes=[None, 0]),
|
||||
[
|
||||
_make_harness("reduce", reduce_op.__name__,
|
||||
lambda x: reduce_op(x, axis=-1, keepdims=True),
|
||||
lambda x: reduce_op(x, axis=-1, keepdims=True), # type: ignore
|
||||
[RandArg((3, 5), _f32)],
|
||||
poly_axes=[0])
|
||||
for reduce_op in [jnp.all, jnp.any, jnp.max, jnp.min, jnp.prod, jnp.sum]
|
||||
|
Loading…
x
Reference in New Issue
Block a user