Reenable pytype for numpy ufuncs.

Add a few type annotations to ufuncs so the exported types are more precise.

PiperOrigin-RevId: 513798060
This commit is contained in:
Peter Hawkins 2023-03-03 05:00:29 -08:00 committed by jax authors
parent 4c13ade81f
commit 7b6321cc09
2 changed files with 15 additions and 18 deletions

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# pytype: skip-file
"""
Implements ufuncs for jax.numpy.
"""
@ -24,18 +23,17 @@ from typing import Any, Callable, Tuple, Union, overload
import numpy as np
from jax import lax
from jax._src import core
from jax._src import dtypes
from jax._src.api import jit, custom_jvp
from jax._src.lax import lax as lax_internal
from jax._src.lax import lax
from jax._src.typing import Array, ArrayLike
from jax._src.numpy.util import (
_asarray, _check_arraylike, _promote_args, _promote_args_inexact,
_promote_args_numeric, _promote_dtypes_inexact, _promote_dtypes_numeric,
_promote_shapes, _where, _wraps)
_lax_const = lax_internal._const
_lax_const = lax._const
_INT_DTYPES = {
16: np.int16,
@ -52,7 +50,7 @@ def _constant_like(x, const):
def _replace_inf(x: ArrayLike) -> Array:
return lax.select(isposinf(real(x)), lax_internal._zeros(x), x)
return lax.select(isposinf(real(x)), lax._zeros(x), x)
def _one_to_one_unop(
@ -183,11 +181,10 @@ greater = _comparison_op(np.greater, lax.gt)
less_equal = _comparison_op(np.less_equal, lax.le)
less = _comparison_op(np.less, lax.lt)
logical_and = _logical_op(np.logical_and, lax.bitwise_and)
logical_not = _logical_op(np.logical_not, lax.bitwise_not)
logical_or = _logical_op(np.logical_or, lax.bitwise_or)
logical_xor = _logical_op(np.logical_xor, lax.bitwise_xor)
logical_and: BinOp = _logical_op(np.logical_and, lax.bitwise_and)
logical_not: UnOp = _logical_op(np.logical_not, lax.bitwise_not)
logical_or: BinOp = _logical_op(np.logical_or, lax.bitwise_or)
logical_xor: BinOp = _logical_op(np.logical_xor, lax.bitwise_xor)
@_wraps(np.arccosh, module='numpy')
@jit
@ -355,7 +352,7 @@ def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
amax = lax.max(x1, x2)
if dtypes.issubdtype(x1.dtype, np.floating):
delta = lax.sub(x1, x2)
return lax.select(lax_internal._isnan(delta),
return lax.select(lax._isnan(delta),
lax.add(x1, x2), # NaNs or infinities of the same sign.
lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta))))))
else:
@ -393,7 +390,7 @@ def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
amax = lax.max(x1, x2)
if dtypes.issubdtype(x1.dtype, np.floating):
delta = lax.sub(x1, x2)
return lax.select(lax_internal._isnan(delta),
return lax.select(lax._isnan(delta),
lax.add(x1, x2), # NaNs or infinities of the same sign.
lax.add(amax, lax.div(lax.log1p(exp2(lax.neg(lax.abs(delta)))),
_constant_like(x1, np.log(2)))))
@ -541,7 +538,7 @@ def frexp(x: ArrayLike, /) -> Tuple[Array, Array]:
x1 = lax.bitcast_convert_type(x1, dtype)
cond = isinf(x) | isnan(x) | (x == 0)
x2 = _where(cond, lax_internal._zeros(x2), x2)
x2 = _where(cond, lax._zeros(x2), x2)
return _where(cond, x, x1), lax.convert_element_type(x2, np.int32)
@ -551,7 +548,7 @@ def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1, x2 = _promote_args_numeric("remainder", x1, x2)
zero = _constant_like(x1, 0)
if dtypes.issubdtype(x2.dtype, np.integer):
x2 = _where(x2 == 0, lax_internal._ones(x2), x2)
x2 = _where(x2 == 0, lax._ones(x2), x2)
trunc_mod = lax.rem(x1, x2)
trunc_mod_not_zero = lax.ne(trunc_mod, zero)
do_plus = lax.bitwise_and(
@ -565,7 +562,7 @@ mod = _wraps(np.mod, module='numpy')(remainder)
def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array:
_check_arraylike("fmod", x1, x2)
if dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer):
x2 = _where(x2 == 0, lax_internal._ones(x2), x2)
x2 = _where(x2 == 0, lax._ones(x2), x2)
return lax.rem(*_promote_args_numeric("fmod", x1, x2))
@ -623,7 +620,7 @@ def modf(x: ArrayLike, /, out=None) -> Tuple[Array, Array]:
x, = _promote_dtypes_inexact(x)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.modf is not supported.")
whole = _where(lax.ge(x, lax_internal._zero(x)), floor(x), ceil(x))
whole = _where(lax.ge(x, lax._zero(x)), floor(x), ceil(x))
return x - whole, whole
@ -703,7 +700,7 @@ def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1 = lax.abs(x1)
x2 = lax.abs(x2)
x1, x2 = maximum(x1, x2), minimum(x1, x2)
return lax.select(x1 == 0, x1, x1 * lax.sqrt(1 + lax.square(lax.div(x2, lax.select(x1 == 0, lax_internal._ones(x1), x1)))))
return lax.select(x1 == 0, x1, x1 * lax.sqrt(1 + lax.square(lax.div(x2, lax.select(x1 == 0, lax._ones(x1), x1)))))
@_wraps(np.reciprocal, module='numpy')

View File

@ -116,7 +116,7 @@ def minimize(
if method.lower() == 'l-bfgs-experimental-do-not-rely-on-this':
results = _minimize_lbfgs(fun_with_args, x0, **options)
success = results.converged & (~results.failed)
success = results.converged & jnp.logical_not(results.failed)
return OptimizeResults(x=results.x_k,
success=success,
status=results.status,