MAINT Fixed new mypy errors

mypy seems to handle lambdas and named functions differently. So, I had to
promote a few helpers to named functions to get them to type check.
This commit is contained in:
Sergei Lebedev 2022-05-23 20:14:11 +01:00
parent be140981ac
commit c5d3ece6f5
4 changed files with 15 additions and 6 deletions

View File

@ -29,8 +29,13 @@ from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.util import _wraps, _promote_dtypes_inexact
from jax._src.util import canonicalize_axis
_T = lambda x: jnp.swapaxes(x, -1, -2)
_H = lambda x: jnp.conjugate(jnp.swapaxes(x, -1, -2))
def _T(x):
return jnp.swapaxes(x, -1, -2)
def _H(x):
return jnp.conjugate(jnp.swapaxes(x, -1, -2))
@_wraps(np.linalg.cholesky)
@ -500,7 +505,8 @@ def norm(x, ord=None, axis : Union[None, Tuple[int, ...], int] = None,
elif ord == -2:
reducer = jnp.amin
else:
reducer = jnp.sum
# `sum` takes an extra dtype= argument, unlike `amax` and `amin`.
reducer = jnp.sum # type: ignore[assignment]
y = reducer(svd(x, compute_uv=False), axis=-1)
if keepdims:
y = jnp.expand_dims(y, axis)

View File

@ -471,7 +471,8 @@ def nansum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
initial=initial, where=where)
# Work around a sphinx documentation warning in NumPy 1.22.
nansum.__doc__ = nansum.__doc__.replace("\n\n\n", "\n\n")
if nansum.__doc__ is not None:
nansum.__doc__ = nansum.__doc__.replace("\n\n\n", "\n\n")
@_wraps(np.nanprod, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))

View File

@ -53,7 +53,9 @@ The JAX version only accepts real-valued inputs.""")
def digamma(x):
x, = _promote_args_inexact("digamma", x)
return lax.digamma(x)
ad.defjvp(lax.digamma_p, lambda g, x: lax.mul(g, polygamma(1, x)))
ad.defjvp(
lax.digamma_p,
lambda g, x: lax.mul(g, polygamma(1, x))) # type: ignore[has-type]
@_wraps(osp_special.gammainc, update_doc=False)

View File

@ -1551,7 +1551,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
# poly_axes=[None, 0]),
[
_make_harness("reduce", reduce_op.__name__,
lambda x: reduce_op(x, axis=-1, keepdims=True),
lambda x: reduce_op(x, axis=-1, keepdims=True), # type: ignore
[RandArg((3, 5), _f32)],
poly_axes=[0])
for reduce_op in [jnp.all, jnp.any, jnp.max, jnp.min, jnp.prod, jnp.sum]