mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Reenable pytype for numpy ufuncs.
Add a few type annotations to ufuncs so the exported types are more precise. PiperOrigin-RevId: 513798060
This commit is contained in:
parent
4c13ade81f
commit
7b6321cc09
@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# pytype: skip-file
|
||||
"""
|
||||
Implements ufuncs for jax.numpy.
|
||||
"""
|
||||
@ -24,18 +23,17 @@ from typing import Any, Callable, Tuple, Union, overload
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax import lax
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src.api import jit, custom_jvp
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import lax
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax._src.numpy.util import (
|
||||
_asarray, _check_arraylike, _promote_args, _promote_args_inexact,
|
||||
_promote_args_numeric, _promote_dtypes_inexact, _promote_dtypes_numeric,
|
||||
_promote_shapes, _where, _wraps)
|
||||
|
||||
_lax_const = lax_internal._const
|
||||
_lax_const = lax._const
|
||||
|
||||
_INT_DTYPES = {
|
||||
16: np.int16,
|
||||
@ -52,7 +50,7 @@ def _constant_like(x, const):
|
||||
|
||||
|
||||
def _replace_inf(x: ArrayLike) -> Array:
|
||||
return lax.select(isposinf(real(x)), lax_internal._zeros(x), x)
|
||||
return lax.select(isposinf(real(x)), lax._zeros(x), x)
|
||||
|
||||
|
||||
def _one_to_one_unop(
|
||||
@ -183,11 +181,10 @@ greater = _comparison_op(np.greater, lax.gt)
|
||||
less_equal = _comparison_op(np.less_equal, lax.le)
|
||||
less = _comparison_op(np.less, lax.lt)
|
||||
|
||||
logical_and = _logical_op(np.logical_and, lax.bitwise_and)
|
||||
logical_not = _logical_op(np.logical_not, lax.bitwise_not)
|
||||
logical_or = _logical_op(np.logical_or, lax.bitwise_or)
|
||||
logical_xor = _logical_op(np.logical_xor, lax.bitwise_xor)
|
||||
|
||||
logical_and: BinOp = _logical_op(np.logical_and, lax.bitwise_and)
|
||||
logical_not: UnOp = _logical_op(np.logical_not, lax.bitwise_not)
|
||||
logical_or: BinOp = _logical_op(np.logical_or, lax.bitwise_or)
|
||||
logical_xor: BinOp = _logical_op(np.logical_xor, lax.bitwise_xor)
|
||||
|
||||
@_wraps(np.arccosh, module='numpy')
|
||||
@jit
|
||||
@ -355,7 +352,7 @@ def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
amax = lax.max(x1, x2)
|
||||
if dtypes.issubdtype(x1.dtype, np.floating):
|
||||
delta = lax.sub(x1, x2)
|
||||
return lax.select(lax_internal._isnan(delta),
|
||||
return lax.select(lax._isnan(delta),
|
||||
lax.add(x1, x2), # NaNs or infinities of the same sign.
|
||||
lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta))))))
|
||||
else:
|
||||
@ -393,7 +390,7 @@ def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
amax = lax.max(x1, x2)
|
||||
if dtypes.issubdtype(x1.dtype, np.floating):
|
||||
delta = lax.sub(x1, x2)
|
||||
return lax.select(lax_internal._isnan(delta),
|
||||
return lax.select(lax._isnan(delta),
|
||||
lax.add(x1, x2), # NaNs or infinities of the same sign.
|
||||
lax.add(amax, lax.div(lax.log1p(exp2(lax.neg(lax.abs(delta)))),
|
||||
_constant_like(x1, np.log(2)))))
|
||||
@ -541,7 +538,7 @@ def frexp(x: ArrayLike, /) -> Tuple[Array, Array]:
|
||||
x1 = lax.bitcast_convert_type(x1, dtype)
|
||||
|
||||
cond = isinf(x) | isnan(x) | (x == 0)
|
||||
x2 = _where(cond, lax_internal._zeros(x2), x2)
|
||||
x2 = _where(cond, lax._zeros(x2), x2)
|
||||
return _where(cond, x, x1), lax.convert_element_type(x2, np.int32)
|
||||
|
||||
|
||||
@ -551,7 +548,7 @@ def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
x1, x2 = _promote_args_numeric("remainder", x1, x2)
|
||||
zero = _constant_like(x1, 0)
|
||||
if dtypes.issubdtype(x2.dtype, np.integer):
|
||||
x2 = _where(x2 == 0, lax_internal._ones(x2), x2)
|
||||
x2 = _where(x2 == 0, lax._ones(x2), x2)
|
||||
trunc_mod = lax.rem(x1, x2)
|
||||
trunc_mod_not_zero = lax.ne(trunc_mod, zero)
|
||||
do_plus = lax.bitwise_and(
|
||||
@ -565,7 +562,7 @@ mod = _wraps(np.mod, module='numpy')(remainder)
|
||||
def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
_check_arraylike("fmod", x1, x2)
|
||||
if dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer):
|
||||
x2 = _where(x2 == 0, lax_internal._ones(x2), x2)
|
||||
x2 = _where(x2 == 0, lax._ones(x2), x2)
|
||||
return lax.rem(*_promote_args_numeric("fmod", x1, x2))
|
||||
|
||||
|
||||
@ -623,7 +620,7 @@ def modf(x: ArrayLike, /, out=None) -> Tuple[Array, Array]:
|
||||
x, = _promote_dtypes_inexact(x)
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.modf is not supported.")
|
||||
whole = _where(lax.ge(x, lax_internal._zero(x)), floor(x), ceil(x))
|
||||
whole = _where(lax.ge(x, lax._zero(x)), floor(x), ceil(x))
|
||||
return x - whole, whole
|
||||
|
||||
|
||||
@ -703,7 +700,7 @@ def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
x1 = lax.abs(x1)
|
||||
x2 = lax.abs(x2)
|
||||
x1, x2 = maximum(x1, x2), minimum(x1, x2)
|
||||
return lax.select(x1 == 0, x1, x1 * lax.sqrt(1 + lax.square(lax.div(x2, lax.select(x1 == 0, lax_internal._ones(x1), x1)))))
|
||||
return lax.select(x1 == 0, x1, x1 * lax.sqrt(1 + lax.square(lax.div(x2, lax.select(x1 == 0, lax._ones(x1), x1)))))
|
||||
|
||||
|
||||
@_wraps(np.reciprocal, module='numpy')
|
||||
|
@ -116,7 +116,7 @@ def minimize(
|
||||
|
||||
if method.lower() == 'l-bfgs-experimental-do-not-rely-on-this':
|
||||
results = _minimize_lbfgs(fun_with_args, x0, **options)
|
||||
success = results.converged & (~results.failed)
|
||||
success = results.converged & jnp.logical_not(results.failed)
|
||||
return OptimizeResults(x=results.x_k,
|
||||
success=success,
|
||||
status=results.status,
|
||||
|
Loading…
x
Reference in New Issue
Block a user