internal: avoid unused imports in lax_numpy

This commit is contained in:
Jake VanderPlas 2023-03-08 10:29:04 -08:00
parent 5daa4f3dc6
commit c8c269f5f5
49 changed files with 656 additions and 676 deletions

View File

@ -33,6 +33,7 @@ from typing import NamedTuple
import jax
import jax._src.numpy.lax_numpy as jnp
import jax._src.numpy.linalg as jnp_linalg
from jax._src.numpy import ufuncs
from jax import lax
from jax._src.lax import qdwh
from jax._src.lax import linalg as lax_linalg
@ -163,7 +164,7 @@ def _projector_subspace(P, H, n, rank, maxiter=2):
_, _, j, error = args
still_counting = j < maxiter
unconverged = error > thresh
return jnp.logical_and(still_counting, unconverged)[0]
return ufuncs.logical_and(still_counting, unconverged)[0]
def body_f(args):
V1, _, j, _ = args
@ -204,7 +205,7 @@ def split_spectrum(H, n, split_point, V0=None):
H_shift = H - (split_point * jnp.eye(N, dtype=split_point.dtype)).astype(H.dtype)
U, _, _, _ = qdwh.qdwh(H_shift, is_hermitian=True, dynamic_shape=(n, n))
P = -0.5 * (U - _mask(jnp.eye(N, dtype=H.dtype), (n, n)))
rank = jnp.round(jnp.trace(jnp.real(P))).astype(jnp.int32)
rank = jnp.round(jnp.trace(ufuncs.real(P))).astype(jnp.int32)
V_minus, V_plus = _projector_subspace(P, H, n, rank)
H_minus = (V_minus.conj().T @ H) @ V_minus
@ -359,7 +360,7 @@ def _eigh_work(H, n, termination_size=256):
def default_case(agenda, blocks, eigenvectors):
V = _slice(eigenvectors, (0, offset), (n, b), (N, B))
# TODO: Improve this?
split_point = jnp.nanmedian(_mask(jnp.diag(jnp.real(H)), (b,), jnp.nan))
split_point = jnp.nanmedian(_mask(jnp.diag(ufuncs.real(H)), (b,), jnp.nan))
H_minus, V_minus, H_plus, V_plus, rank = split_spectrum(
H, b, split_point, V0=V)
@ -381,7 +382,7 @@ def _eigh_work(H, n, termination_size=256):
norm = jnp_linalg.norm(H)
tol = jnp.asarray(10 * jnp.finfo(H.dtype).eps / 2, dtype=norm.dtype)
off_diag_norm = jnp_linalg.norm(
H - jnp.diag(jnp.diag(jnp.real(H)).astype(H.dtype)))
H - jnp.diag(jnp.diag(ufuncs.real(H)).astype(H.dtype)))
# We also handle nearly-all-zero matrices matrices here.
nearly_diagonal = (norm < tol) | (off_diag_norm / norm < tol)
return lax.cond(nearly_diagonal, nearly_diagonal_case, default_case,
@ -450,7 +451,7 @@ def eigh(H, *, precision="float32", termination_size=256, n=None,
n = N if n is None else n
with jax.default_matmul_precision(precision):
eig_vals, eig_vecs = _eigh_work(H, n, termination_size=termination_size)
eig_vals = _mask(jnp.real(eig_vals), (n,), jnp.nan)
eig_vals = _mask(ufuncs.real(eig_vals), (n,), jnp.nan)
if sort_eigenvalues:
sort_idxs = jnp.argsort(eig_vals)
eig_vals = eig_vals[sort_idxs]

View File

@ -49,6 +49,8 @@ from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import reductions
from jax._src.numpy import ufuncs
from jax._src.numpy.vectorize import vectorize
from jax._src.typing import Array, ArrayLike
@ -375,7 +377,7 @@ def _solve(a: Array, b: Array) -> Array:
return api.vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
def _T(x: Array) -> Array: return jnp.swapaxes(x, -1, -2)
def _H(x: Array) -> Array: return jnp.conj(_T(x))
def _H(x: Array) -> Array: return ufuncs.conj(_T(x))
def symmetrize(x: Array) -> Array: return (x + _H(x)) / 2
# primitives
@ -551,7 +553,7 @@ def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors,
a, = primals
da, = tangents
l, v = eig(a, compute_left_eigenvectors=False)
return [l], [jnp.sum(_solve(v, da.astype(v.dtype)) * _T(v), -1)]
return [l], [reductions.sum(_solve(v, da.astype(v.dtype)) * _T(v), -1)]
eig_p = Primitive('eig')
eig_p.multiple_results = True
@ -679,13 +681,13 @@ def _eigh_tpu_impl(x, *, lower, sort_eigenvalues):
if lower:
mask = jnp.tri(n, k=0, dtype=bool)
else:
mask = jnp.logical_not(jnp.tri(n, k=-1, dtype=bool))
mask = ufuncs.logical_not(jnp.tri(n, k=-1, dtype=bool))
if dtypes.issubdtype(x.dtype, jnp.complexfloating):
re = lax.select(mask, lax.real(x), _T(lax.real(x)))
if lower:
im_mask = jnp.tri(n, k=-1, dtype=bool)
else:
im_mask = jnp.logical_not(jnp.tri(n, k=0, dtype=bool))
im_mask = ufuncs.logical_not(jnp.tri(n, k=0, dtype=bool))
im = lax.select(im_mask, lax.imag(x), jnp.zeros_like(lax.imag(x)))
im = lax.select(mask, im, -_T(im))
x = lax.complex(re, im)
@ -717,13 +719,13 @@ def _eigh_jvp_rule(primals, tangents, *, lower, sort_eigenvalues):
w = w_real.astype(a.dtype)
eye_n = jnp.eye(a.shape[-1], dtype=a.dtype)
# carefully build reciprocal delta-eigenvalue matrix, avoiding NaNs.
Fmat = jnp.reciprocal(eye_n + w[..., jnp.newaxis, :] - w[..., jnp.newaxis]) - eye_n
Fmat = ufuncs.reciprocal(eye_n + w[..., jnp.newaxis, :] - w[..., jnp.newaxis]) - eye_n
# eigh impl doesn't support batch dims, but future-proof the grad.
dot = partial(lax.dot if a.ndim == 2 else lax.batch_matmul,
precision=lax.Precision.HIGHEST)
vdag_adot_v = dot(dot(_H(v), a_dot), v)
dv = dot(v, jnp.multiply(Fmat, vdag_adot_v))
dw = jnp.real(jnp.diagonal(vdag_adot_v, axis1=-2, axis2=-1))
dv = dot(v, ufuncs.multiply(Fmat, vdag_adot_v))
dw = ufuncs.real(jnp.diagonal(vdag_adot_v, axis1=-2, axis2=-1))
return (v, w_real), (dv, dw)
def _eigh_batching_rule(batched_args, batch_dims, *, lower, sort_eigenvalues):
@ -789,7 +791,7 @@ def _triangular_solve_jvp_rule_a(
g_a = jnp.tril(g_a, k=-k) if lower else jnp.triu(g_a, k=k)
g_a = lax.neg(g_a)
g_a = jnp.swapaxes(g_a, -1, -2) if transpose_a else g_a
g_a = jnp.conj(g_a) if conjugate_a else g_a
g_a = ufuncs.conj(g_a) if conjugate_a else g_a
dot = partial(lax.dot if g_a.ndim == 2 else lax.batch_matmul,
precision=lax.Precision.HIGHEST)
@ -1029,9 +1031,9 @@ def _lu_unblocked(a):
if jnp.issubdtype(a.dtype, jnp.complexfloating):
t = a[:, k]
magnitude = jnp.abs(jnp.real(t)) + jnp.abs(jnp.imag(t))
magnitude = ufuncs.abs(ufuncs.real(t)) + ufuncs.abs(ufuncs.imag(t))
else:
magnitude = jnp.abs(a[:, k])
magnitude = ufuncs.abs(a[:, k])
i = jnp.argmax(jnp.where(m_idx >= k, magnitude, -jnp.inf))
pivot = pivot.at[k].set(i)
a = a.at[[k, i],].set(a[[i, k],])
@ -1553,7 +1555,7 @@ def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv):
Ut, V = _H(U), _H(Vt)
s_dim = s[..., None, :]
dS = Ut @ dA @ V
ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))
ds = ufuncs.real(jnp.diagonal(dS, 0, -2, -1))
if not compute_uv:
return (s,), (ds,)

View File

@ -23,6 +23,7 @@ from jax._src.lib import xla_client
from jax._src.util import safe_zip
from jax._src.numpy.util import _check_arraylike, _wraps
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import ufuncs, reductions
from jax._src.typing import Array, ArrayLike
Shape = Sequence[int]
@ -31,9 +32,9 @@ def _fft_norm(s: Array, func_name: str, norm: str) -> Array:
if norm == "backward":
return jnp.array(1)
elif norm == "ortho":
return jnp.sqrt(jnp.prod(s)) if func_name.startswith('i') else 1/jnp.sqrt(jnp.prod(s))
return ufuncs.sqrt(reductions.prod(s)) if func_name.startswith('i') else 1/ufuncs.sqrt(reductions.prod(s))
elif norm == "forward":
return jnp.prod(s) if func_name.startswith('i') else 1/jnp.prod(s)
return reductions.prod(s) if func_name.startswith('i') else 1/reductions.prod(s)
raise ValueError(f'Invalid norm value {norm}; should be "backward",'
'"ortho" or "forward".')
@ -175,7 +176,7 @@ def irfft(a: ArrayLike, n: Optional[int] = None,
@_wraps(np.fft.hfft)
def hfft(a: ArrayLike, n: Optional[int] = None,
axis: int = -1, norm: Optional[str] = None) -> Array:
conj_a = jnp.conj(a)
conj_a = ufuncs.conj(a)
_axis_check_1d('hfft', axis)
nn = (conj_a.shape[axis] - 1) * 2 if n is None else n
return _fft_core_1d('hfft', xla_client.FftType.IRFFT, conj_a, n=n, axis=axis,
@ -189,7 +190,7 @@ def ihfft(a: ArrayLike, n: Optional[int] = None,
nn = arr.shape[axis] if n is None else n
output = _fft_core_1d('ihfft', xla_client.FftType.RFFT, arr, n=n, axis=axis,
norm=norm)
return jnp.conj(output) * (1 / nn)
return ufuncs.conj(output) * (1 / nn)
def _fft_core_2d(func_name: str, fft_type: xla_client.FftType, a: ArrayLike,

View File

@ -16,9 +16,11 @@ import abc
from typing import Any, Iterable, List, Tuple, Union
import jax
import jax._src.numpy.lax_numpy as jnp
from jax._src import core
from jax._src.numpy.util import _promote_dtypes
from jax._src.numpy.lax_numpy import (
arange, array, concatenate, expand_dims, linspace, meshgrid, stack, transpose
)
from jax._src.typing import Array, ArrayLike
import numpy as np
@ -35,9 +37,9 @@ def _make_1d_grid_from_slice(s: slice, op_name: str) -> Array:
step = core.concrete_or_error(None, s.step,
f"slice step of jnp.{op_name}") or 1
if np.iscomplex(step):
newobj = jnp.linspace(start, stop, int(abs(step)))
newobj = linspace(start, stop, int(abs(step)))
else:
newobj = jnp.arange(start, stop, step)
newobj = arange(start, stop, step)
return newobj
@ -53,12 +55,12 @@ class _IndexGrid(abc.ABC):
output: Iterable[Array] = (_make_1d_grid_from_slice(k, op_name=self.op_name) for k in key)
with jax.numpy_dtype_promotion('standard'):
output = _promote_dtypes(*output)
output_arr = jnp.meshgrid(*output, indexing='ij', sparse=self.sparse)
output_arr = meshgrid(*output, indexing='ij', sparse=self.sparse)
if self.sparse:
return output_arr
if len(output_arr) == 0:
return jnp.arange(0)
return jnp.stack(output_arr, 0)
return arange(0)
return stack(output_arr, 0)
class _Mgrid(_IndexGrid):
@ -178,10 +180,10 @@ class _AxisConcat(abc.ABC):
elif isinstance(item, str):
raise ValueError("string directive must be placed at the beginning")
else:
newobj = jnp.array(item, copy=False)
newobj = array(item, copy=False)
item_ndim = newobj.ndim
newobj = jnp.array(newobj, copy=False, ndmin=ndmin)
newobj = array(newobj, copy=False, ndmin=ndmin)
if trans1d != -1 and ndmin - item_ndim > 0:
shape_obj = tuple(range(ndmin))
@ -189,15 +191,15 @@ class _AxisConcat(abc.ABC):
num_lshifts = ndmin - abs(ndmin + trans1d + 1) % ndmin
shape_obj = tuple(shape_obj[num_lshifts:] + shape_obj[:num_lshifts])
newobj = jnp.transpose(newobj, shape_obj)
newobj = transpose(newobj, shape_obj)
output.append(newobj)
res = jnp.concatenate(tuple(output), axis=axis)
res = concatenate(tuple(output), axis=axis)
if matrix != -1 and res.ndim == 1:
# insert 2nd dim at axis 0 or 1
res = jnp.expand_dims(res, matrix)
res = expand_dims(res, matrix)
return res

File diff suppressed because it is too large Load Diff

View File

@ -27,6 +27,7 @@ from jax import lax
from jax._src.lax import lax as lax_internal
from jax._src.lax import linalg as lax_linalg
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import reductions, ufuncs
from jax._src.numpy.util import _wraps, _promote_dtypes_inexact, _check_arraylike
from jax._src.util import canonicalize_axis
from jax._src.typing import ArrayLike, Array
@ -37,7 +38,7 @@ def _T(x: ArrayLike) -> Array:
def _H(x: ArrayLike) -> Array:
return jnp.conjugate(jnp.swapaxes(x, -1, -2))
return ufuncs.conjugate(jnp.swapaxes(x, -1, -2))
def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2
@ -136,12 +137,12 @@ def matrix_rank(M: ArrayLike, tol: Optional[ArrayLike] = None) -> Array:
_check_arraylike("jnp.linalg.matrix_rank", M)
M, = _promote_dtypes_inexact(jnp.asarray(M))
if M.ndim < 2:
return jnp.any(M != 0).astype(jnp.int32)
return (M != 0).any().astype(jnp.int32)
S = svd(M, full_matrices=False, compute_uv=False)
if tol is None:
tol = S.max(-1) * np.max(M.shape[-2:]).astype(S.dtype) * jnp.finfo(S.dtype).eps
tol = jnp.expand_dims(tol, np.ndim(tol))
return jnp.sum(S > tol, axis=-1)
return reductions.sum(S > tol, axis=-1)
@custom_jvp
@ -149,22 +150,22 @@ def _slogdet_lu(a: Array) -> Tuple[Array, Array]:
dtype = lax.dtype(a)
lu, pivot, _ = lax_linalg.lu(a)
diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
is_zero = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1)
is_zero = reductions.any(diag == jnp.array(0, dtype=dtype), axis=-1)
iota = lax.expand_dims(jnp.arange(a.shape[-1], dtype=pivot.dtype),
range(pivot.ndim - 1))
parity = jnp.count_nonzero(pivot != iota, axis=-1)
parity = reductions.count_nonzero(pivot != iota, axis=-1)
if jnp.iscomplexobj(a):
sign = jnp.prod(diag / jnp.abs(diag).astype(diag.dtype), axis=-1)
sign = reductions.prod(diag / ufuncs.abs(diag).astype(diag.dtype), axis=-1)
else:
sign = jnp.array(1, dtype=dtype)
parity = parity + jnp.count_nonzero(diag < 0, axis=-1)
parity = parity + reductions.count_nonzero(diag < 0, axis=-1)
sign = jnp.where(is_zero,
jnp.array(0, dtype=dtype),
sign * jnp.array(-2 * (parity % 2) + 1, dtype=dtype))
logdet = jnp.where(
is_zero, jnp.array(-jnp.inf, dtype=dtype),
jnp.sum(jnp.log(jnp.abs(diag)).astype(dtype), axis=-1))
return sign, jnp.real(logdet)
reductions.sum(ufuncs.log(ufuncs.abs(diag)).astype(dtype), axis=-1))
return sign, ufuncs.real(logdet)
@custom_jvp
def _slogdet_qr(a: Array) -> Tuple[Array, Array]:
@ -179,11 +180,11 @@ def _slogdet_qr(a: Array) -> Tuple[Array, Array]:
# The determinant of a triangular matrix is the product of its diagonal
# elements. We are working in log space, so we compute the magnitude as the
# the trace of the log-absolute values, and we compute the sign separately.
log_abs_det = jnp.trace(jnp.log(jnp.abs(a)), axis1=-2, axis2=-1)
sign_diag = jnp.prod(jnp.sign(jnp.diagonal(a, axis1=-2, axis2=-1)), axis=-1)
log_abs_det = jnp.trace(ufuncs.log(ufuncs.abs(a)), axis1=-2, axis2=-1)
sign_diag = reductions.prod(ufuncs.sign(jnp.diagonal(a, axis1=-2, axis2=-1)), axis=-1)
# The determinant of a Householder reflector is -1. So whenever we actually
# made a reflection (tau != 0), multiply the result by -1.
sign_taus = jnp.prod(jnp.where(taus[..., :(n-1)] != 0, -1, 1), axis=-1).astype(sign_diag.dtype)
sign_taus = reductions.prod(jnp.where(taus[..., :(n-1)] != 0, -1, 1), axis=-1).astype(sign_diag.dtype)
return sign_diag * sign_taus, log_abs_det
@_wraps(
@ -217,8 +218,8 @@ def _slogdet_jvp(primals, tangents):
sign, ans = slogdet(x)
ans_dot = jnp.trace(solve(x, g), axis1=-1, axis2=-2)
if jnp.issubdtype(jnp._dtype(x), jnp.complexfloating):
sign_dot = (ans_dot - jnp.real(ans_dot).astype(ans_dot.dtype)) * sign
ans_dot = jnp.real(ans_dot)
sign_dot = (ans_dot - ufuncs.real(ans_dot).astype(ans_dot.dtype)) * sign
ans_dot = ufuncs.real(ans_dot)
else:
sign_dot = jnp.zeros_like(sign)
return (sign, ans), (sign_dot, ans_dot)
@ -292,18 +293,18 @@ def _cofactor_solve(a: ArrayLike, b: ArrayLike) -> Tuple[Array, Array]:
diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
iota = lax.expand_dims(jnp.arange(a_shape[-1], dtype=pivots.dtype),
range(pivots.ndim - 1))
parity = jnp.count_nonzero(pivots != iota, axis=-1)
parity = reductions.count_nonzero(pivots != iota, axis=-1)
sign = jnp.asarray(-2 * (parity % 2) + 1, dtype=dtype)
# partial_det[:, -1] contains the full determinant and
# partial_det[:, -2] contains det(u) / u_{nn}.
partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None]
partial_det = reductions.cumprod(diag, axis=-1) * sign[..., None]
lu = lu.at[..., -1, -1].set(1.0 / partial_det[..., -2])
permutation = jnp.broadcast_to(permutation, (*batch_dims, a_shape[-1]))
iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in (*batch_dims, 1)))
# filter out any matrices that are not full rank
d = jnp.ones(x.shape[:-1], x.dtype)
d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
d = jnp.any(jnp.logical_or(jnp.isnan(d), jnp.isinf(d)), axis=-1)
d = reductions.any(ufuncs.logical_or(ufuncs.isnan(d), ufuncs.isinf(d)), axis=-1)
d = jnp.tile(d[..., None, None], d.ndim*(1,) + x.shape[-2:])
x = jnp.where(d, jnp.zeros_like(x), x) # first filter
x = x[iotas[:-1] + (permutation, slice(None))]
@ -344,7 +345,7 @@ def det(a: ArrayLike) -> Array:
return _det_3x3(a)
elif len(a_shape) >= 2 and a_shape[-1] == a_shape[-2]:
sign, logdet = slogdet(a)
return sign * jnp.exp(logdet).astype(sign.dtype)
return sign * ufuncs.exp(logdet).astype(sign.dtype)
else:
msg = "Argument to _det() must have shape [..., n, n], got {}"
raise ValueError(msg.format(a_shape))
@ -424,7 +425,7 @@ def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None,
m, n = arr.shape[-2:]
if m == 0 or n == 0:
return jnp.empty(arr.shape[:-2] + (n, m), arr.dtype)
arr = jnp.conj(arr)
arr = ufuncs.conj(arr)
if rcond is None:
max_rows_cols = max(arr.shape[-2:])
rcond = 10. * max_rows_cols * jnp.array(jnp.finfo(arr.dtype).eps)
@ -435,7 +436,7 @@ def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None,
rcond = lax.expand_dims(rcond[..., jnp.newaxis], range(s.ndim - rcond.ndim - 1))
cutoff = rcond * s[..., 0:1]
s = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype)
res = jnp.matmul(_T(vh), jnp.divide(_T(u), s[..., jnp.newaxis]),
res = jnp.matmul(_T(vh), ufuncs.divide(_T(u), s[..., jnp.newaxis]),
precision=lax.Precision.HIGHEST)
return lax.convert_element_type(res, arr.dtype)
@ -495,7 +496,7 @@ def norm(x: ArrayLike, ord: Union[int, str, None] = None,
# NumPy has an undocumented behavior that admits arbitrary rank inputs if
# `ord` is None: https://github.com/numpy/numpy/issues/14215
if ord is None:
return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), keepdims=keepdims))
return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), keepdims=keepdims))
axis = tuple(range(ndim))
elif isinstance(axis, tuple):
axis = tuple(canonicalize_axis(x, ndim) for x in axis)
@ -505,21 +506,21 @@ def norm(x: ArrayLike, ord: Union[int, str, None] = None,
num_axes = len(axis)
if num_axes == 1:
if ord is None or ord == 2:
return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis,
keepdims=keepdims))
return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis,
keepdims=keepdims))
elif ord == jnp.inf:
return jnp.amax(jnp.abs(x), axis=axis, keepdims=keepdims)
return reductions.amax(ufuncs.abs(x), axis=axis, keepdims=keepdims)
elif ord == -jnp.inf:
return jnp.amin(jnp.abs(x), axis=axis, keepdims=keepdims)
return reductions.amin(ufuncs.abs(x), axis=axis, keepdims=keepdims)
elif ord == 0:
return jnp.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype,
axis=axis, keepdims=keepdims)
return reductions.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype,
axis=axis, keepdims=keepdims)
elif ord == 1:
# Numpy has a special case for ord == 1 as an optimization. We don't
# really need the optimization (XLA could do it for us), but the Numpy
# code has slightly different type promotion semantics, so we need a
# special case too.
return jnp.sum(jnp.abs(x), axis=axis, keepdims=keepdims)
return reductions.sum(ufuncs.abs(x), axis=axis, keepdims=keepdims)
elif isinstance(ord, str):
msg = f"Invalid order '{ord}' for vector norm."
if ord == "inf":
@ -528,46 +529,46 @@ def norm(x: ArrayLike, ord: Union[int, str, None] = None,
msg += "Use '-jax.numpy.inf' instead."
raise ValueError(msg)
else:
abs_x = jnp.abs(x)
abs_x = ufuncs.abs(x)
ord_arr = lax_internal._const(abs_x, ord)
ord_inv = lax_internal._const(abs_x, 1. / ord_arr)
out = jnp.sum(abs_x ** ord_arr, axis=axis, keepdims=keepdims)
return jnp.power(out, ord_inv)
out = reductions.sum(abs_x ** ord_arr, axis=axis, keepdims=keepdims)
return ufuncs.power(out, ord_inv)
elif num_axes == 2:
row_axis, col_axis = cast(Tuple[int, ...], axis)
if ord is None or ord in ('f', 'fro'):
return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis,
keepdims=keepdims))
return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis,
keepdims=keepdims))
elif ord == 1:
if not keepdims and col_axis > row_axis:
col_axis -= 1
return jnp.amax(jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims),
axis=col_axis, keepdims=keepdims)
return reductions.amax(reductions.sum(ufuncs.abs(x), axis=row_axis, keepdims=keepdims),
axis=col_axis, keepdims=keepdims)
elif ord == -1:
if not keepdims and col_axis > row_axis:
col_axis -= 1
return jnp.amin(jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims),
axis=col_axis, keepdims=keepdims)
return reductions.amin(reductions.sum(ufuncs.abs(x), axis=row_axis, keepdims=keepdims),
axis=col_axis, keepdims=keepdims)
elif ord == jnp.inf:
if not keepdims and row_axis > col_axis:
row_axis -= 1
return jnp.amax(jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims),
return reductions.amax(reductions.sum(ufuncs.abs(x), axis=col_axis, keepdims=keepdims),
axis=row_axis, keepdims=keepdims)
elif ord == -jnp.inf:
if not keepdims and row_axis > col_axis:
row_axis -= 1
return jnp.amin(jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims),
return reductions.amin(reductions.sum(ufuncs.abs(x), axis=col_axis, keepdims=keepdims),
axis=row_axis, keepdims=keepdims)
elif ord in ('nuc', 2, -2):
x = jnp.moveaxis(x, axis, (-2, -1))
if ord == 2:
reducer = jnp.amax
reducer = reductions.amax
elif ord == -2:
reducer = jnp.amin
reducer = reductions.amin
else:
# `sum` takes an extra dtype= argument, unlike `amax` and `amin`.
reducer = jnp.sum # type: ignore[assignment]
reducer = reductions.sum # type: ignore[assignment]
y = reducer(svd(x, compute_uv=False), axis=-1)
if keepdims:
y = jnp.expand_dims(y, axis)

View File

@ -24,9 +24,11 @@ from jax import lax
from jax._src import dtypes
from jax._src import core
from jax._src.numpy.lax_numpy import (
all, arange, argmin, array, asarray, atleast_1d, concatenate, convolve,
diag, dot, finfo, full, maximum, ones, outer, roll, sqrt, trim_zeros,
trim_zeros_tol, true_divide, vander, zeros)
arange, argmin, array, asarray, atleast_1d, concatenate, convolve,
diag, dot, finfo, full, ones, outer, roll, trim_zeros,
trim_zeros_tol, vander, zeros)
from jax._src.numpy.ufuncs import maximum, true_divide, sqrt
from jax._src.numpy.reductions import all
from jax._src.numpy import linalg
from jax._src.numpy.util import (
_check_arraylike, _promote_dtypes, _promote_dtypes_inexact, _where, _wraps)

View File

@ -27,9 +27,11 @@ from jax._src import core
from jax._src import dtypes
from jax._src.lax import lax as lax_internal
from jax._src.numpy.lax_numpy import (
any, append, arange, array, asarray, concatenate, cumsum, diff,
empty, full_like, isnan, lexsort, moveaxis, nonzero, ones, ravel,
append, arange, array, asarray, concatenate, diff,
empty, full_like, lexsort, moveaxis, nonzero, ones, ravel,
sort, where, zeros)
from jax._src.numpy.reductions import any, cumsum
from jax._src.numpy.ufuncs import isnan
from jax._src.numpy.util import _check_arraylike, _wraps
from jax._src.typing import Array, ArrayLike

View File

@ -27,6 +27,8 @@ from jax._src import dtypes
from jax._src import util
from jax._src.lax import lax as lax_internal
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import reductions
from jax._src.numpy.util import _check_arraylike, _promote_dtypes
Array = Any
@ -98,7 +100,7 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
if core.is_empty_shape(indexer.slice_shape):
return x
x, y = jnp._promote_dtypes(x, y)
x, y = _promote_dtypes(x, y)
# Broadcast `y` to the slice output shape.
y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
@ -155,7 +157,7 @@ def _segment_update(name: str,
bucket_size: Optional[int] = None,
reducer: Optional[Callable] = None,
mode: Optional[lax.GatherScatterMode] = None) -> Array:
jnp._check_arraylike(name, data, segment_ids)
_check_arraylike(name, data, segment_ids)
mode = lax.GatherScatterMode.FILL_OR_DROP if mode is None else mode
data = jnp.asarray(data)
segment_ids = jnp.asarray(segment_ids)
@ -239,7 +241,7 @@ def segment_sum(data: Array,
"""
return _segment_update(
"segment_sum", data, segment_ids, lax.scatter_add, num_segments,
indices_are_sorted, unique_indices, bucket_size, jnp.sum, mode=mode)
indices_are_sorted, unique_indices, bucket_size, reductions.sum, mode=mode)
def segment_prod(data: Array,
@ -295,7 +297,7 @@ def segment_prod(data: Array,
"""
return _segment_update(
"segment_prod", data, segment_ids, lax.scatter_mul, num_segments,
indices_are_sorted, unique_indices, bucket_size, jnp.prod, mode=mode)
indices_are_sorted, unique_indices, bucket_size, reductions.prod, mode=mode)
def segment_max(data: Array,
@ -350,7 +352,7 @@ def segment_max(data: Array,
"""
return _segment_update(
"segment_max", data, segment_ids, lax.scatter_max, num_segments,
indices_are_sorted, unique_indices, bucket_size, jnp.max, mode=mode)
indices_are_sorted, unique_indices, bucket_size, reductions.max, mode=mode)
def segment_min(data: Array,
@ -405,4 +407,4 @@ def segment_min(data: Array,
"""
return _segment_update(
"segment_min", data, segment_ids, lax.scatter_min, num_segments,
indices_are_sorted, unique_indices, bucket_size, jnp.min, mode=mode)
indices_are_sorted, unique_indices, bucket_size, reductions.min, mode=mode)

View File

@ -17,7 +17,8 @@ from typing import overload, Literal, Optional, Tuple, Union
import jax
from jax import lax
from jax import numpy as jnp
from jax._src.numpy.lax_numpy import _reduction_dims, _promote_args_inexact
from jax._src.numpy.reductions import _reduction_dims
from jax._src.numpy.util import _promote_args_inexact
from jax._src.typing import Array, ArrayLike
import numpy as np

View File

@ -42,7 +42,8 @@ from jax._src.lax import lax as lax_internal
from jax._src.lax import utils as lax_utils
from jax._src.lib import gpu_prng
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy import lax_numpy
from jax._src.numpy.lax_numpy import _set_device_array_base_attributes
from jax._src.numpy.util import _register_stackable
from jax._src.sharding import (
NamedSharding, PmapSharding, GSPMDSharding)
from jax._src.util import canonicalize_axis, safe_map, safe_zip
@ -256,10 +257,10 @@ class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
def flatten(self, *_, **__) -> 'PRNGKeyArray': assert False
lax_numpy._set_device_array_base_attributes(PRNGKeyArray, include=[
_set_device_array_base_attributes(PRNGKeyArray, include=[
'__getitem__', 'ravel', 'squeeze', 'swapaxes', 'take', 'reshape',
'transpose', 'flatten', 'T'])
lax_numpy._register_stackable(PRNGKeyArray)
_register_stackable(PRNGKeyArray)
basearray.Array.register(PRNGKeyArray)

View File

@ -37,9 +37,8 @@ from jax._src.core import NamedShape
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.lax import lax as lax_internal
from jax._src.numpy.lax_numpy import (
_arraylike, _check_arraylike, _convert_and_clip_integer,
_promote_dtypes_inexact)
from jax._src.numpy.lax_numpy import _convert_and_clip_integer
from jax._src.numpy.util import _arraylike, _check_arraylike, _promote_dtypes_inexact
from jax._src.typing import Array, ArrayLike, DTypeLike
from jax._src.util import canonicalize_axis

View File

@ -1,4 +1,4 @@
# Copyright 2022 The JAX Authors.
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -18,9 +18,8 @@ import scipy.cluster.vq
import textwrap
from jax import vmap
import jax.numpy as jnp
from jax._src.numpy.util import _wraps, _check_arraylike, _promote_dtypes_inexact
from jax._src.numpy.lax_numpy import argmin
from jax._src.numpy.linalg import norm
_no_chkfinite_doc = textwrap.dedent("""
@ -41,7 +40,7 @@ def vq(obs, code_book, check_finite=True):
raise ValueError("ndim different than 1 or 2 are not supported")
# explicitly rank promotion
dist = vmap(lambda ob: norm(ob[None] - code_book, axis=-1))(obs)
code = argmin(dist, axis=-1)
dist = vmap(lambda ob: jnp.linalg.norm(ob[None] - code_book, axis=-1))(obs)
code = jnp.argmin(dist, axis=-1)
dist_min = vmap(operator.getitem)(dist, code)
return code, dist_min

View File

@ -22,15 +22,15 @@ import warnings
from typing import cast, overload, Any, Literal, Optional, Tuple, Union
import jax
import jax.numpy as jnp
from jax import jit, vmap, jvp
from jax import lax
from jax._src import dtypes
from jax._src.lax import linalg as lax_linalg
from jax._src.lax import qdwh
from jax._src.numpy.lax_numpy import _check_arraylike
from jax._src.numpy.util import _wraps, _promote_dtypes_inexact, _promote_dtypes_complex
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import linalg as np_linalg
from jax._src.numpy.util import (
_check_arraylike, _wraps, _promote_dtypes, _promote_dtypes_inexact,
_promote_dtypes_complex)
from jax._src.typing import Array, ArrayLike
@ -125,7 +125,7 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite'))
def det(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> Array:
del overwrite_a, check_finite # unused
return np_linalg.det(a)
return jnp.linalg.det(a)
@overload
@ -210,7 +210,7 @@ def schur(a: ArrayLike, output: str = 'real') -> Tuple[Array, Array]:
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite'))
def inv(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> Array:
del overwrite_a, check_finite # unused
return np_linalg.inv(a)
return jnp.linalg.inv(a)
@_wraps(scipy.linalg.lu_factor,
@ -332,7 +332,7 @@ def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "
@partial(jit, static_argnames=('assume_a', 'lower'))
def _solve(a: ArrayLike, b: ArrayLike, assume_a: str, lower: bool) -> Array:
if assume_a != 'pos':
return np_linalg.solve(a, b)
return jnp.linalg.solve(a, b)
a, b = _promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
lax_linalg._check_solve_shapes(a, b)
@ -453,7 +453,7 @@ def _calc_P_Q(A: ArrayLike) -> Tuple[Array, Array, Array]:
A = jnp.asarray(A)
if A.ndim != 2 or A.shape[0] != A.shape[1]:
raise ValueError('expected A to be a square matrix')
A_L1 = np_linalg.norm(A,1)
A_L1 = jnp.linalg.norm(A,1)
n_squarings: Array
U: Array
V: Array
@ -484,7 +484,7 @@ def _solve_P_Q(P: ArrayLike, Q: ArrayLike, upper_triangular: bool = False) -> Ar
if upper_triangular:
return solve_triangular(Q, P)
else:
return np_linalg.solve(Q, P)
return jnp.linalg.solve(Q, P)
def _precise_dot(A: ArrayLike, B: ArrayLike) -> Array:
return jnp.dot(A, B, precision=lax.Precision.HIGHEST)
@ -609,7 +609,7 @@ def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: Optional[str] = None,
def block_diag(*arrs: ArrayLike) -> Array:
if len(arrs) == 0:
arrs = cast(Tuple[ArrayLike], (jnp.zeros((1, 0)),))
arrs = cast(Tuple[ArrayLike], jnp._promote_dtypes(*arrs))
arrs = cast(Tuple[ArrayLike], _promote_dtypes(*arrs))
bad_shapes = [i for i, a in enumerate(arrs) if jnp.ndim(a) > 2]
if bad_shapes:
raise ValueError("Arguments to jax.scipy.linalg.block_diag must have at "
@ -948,8 +948,8 @@ def rsf2csf(T: ArrayLike, Z: ArrayLike, check_finite: bool = True) -> Tuple[Arra
return T, Z
def _update_T_Z(m, T, Z):
mu = np_linalg.eigvals(lax.dynamic_slice(T, (m-1, m-1), (2, 2))) - T[m, m]
r = np_linalg.norm(jnp.array([mu[0], T[m, m-1]])).astype(T.dtype)
mu = jnp.linalg.eigvals(lax.dynamic_slice(T, (m-1, m-1), (2, 2))) - T[m, m]
r = jnp.linalg.norm(jnp.array([mu[0], T[m, m-1]])).astype(T.dtype)
c = mu[0] / r
s = T[m, m-1] / r
G = jnp.array([[c.conj(), s], [-s, c]], dtype=T.dtype)

View File

@ -24,7 +24,7 @@ import scipy.ndimage
from jax._src import api
from jax._src import util
from jax import lax
from jax._src.numpy import lax_numpy as jnp
import jax.numpy as jnp
from jax._src.numpy.util import _wraps
from jax._src.typing import ArrayLike, Array
from jax._src.util import safe_zip as zip

View File

@ -22,13 +22,13 @@ import scipy.signal as osp_signal
import jax
import jax.numpy.fft
import jax.numpy as jnp
from jax import lax
from jax._src import dtypes
from jax._src.lax.lax import PrecisionLike
from jax._src.numpy.lax_numpy import _check_arraylike
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import linalg
from jax._src.numpy.util import _wraps, _promote_dtypes_inexact, _promote_dtypes_complex
from jax._src.numpy.util import (
_check_arraylike, _wraps, _promote_dtypes_inexact, _promote_dtypes_complex)
from jax._src.third_party.scipy import signal_helper
from jax._src.typing import Array, ArrayLike
from jax._src.util import canonicalize_axis, tuple_delete, tuple_insert

View File

@ -29,8 +29,7 @@ from jax._src import api
from jax._src import core
from jax._src import dtypes
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.lax_numpy import (
moveaxis, _promote_args_inexact, _promote_dtypes_inexact)
from jax._src.numpy.util import _promote_args_inexact, _promote_dtypes_inexact
from jax._src.numpy.util import _wraps
from jax._src.ops import special as ops_special
from jax._src.third_party.scipy.betaln import betaln as _betaln_impl
@ -725,7 +724,7 @@ def bessel_jn(z: ArrayLike, *, v: int, n_iter: int=50) -> Array:
bessel_jn_fun = partial(_bessel_jn, v=v, n_iter=n_iter)
for _ in range(z.ndim):
bessel_jn_fun = vmap(bessel_jn_fun)
return moveaxis(bessel_jn_fun(z), -1, 0)
return jnp.moveaxis(bessel_jn_fun(z), -1, 0)
def _gen_recurrence_mask(

View File

@ -22,8 +22,7 @@ import jax.numpy as jnp
from jax import jit
from jax._src import dtypes
from jax._src.api import vmap
from jax._src.numpy.lax_numpy import _check_arraylike
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import _check_arraylike, _wraps
from jax._src.typing import ArrayLike, Array
from jax._src.util import canonicalize_axis

View File

@ -16,16 +16,16 @@
import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax.scipy.special import xlogy, xlog1py
@_wraps(osp_stats.bernoulli.logpmf, update_doc=False)
def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
k, p, loc = jnp._promote_args_inexact("bernoulli.logpmf", k, p, loc)
k, p, loc = _promote_args_inexact("bernoulli.logpmf", k, p, loc)
zero = _lax_const(k, 0)
one = _lax_const(k, 1)
x = lax.sub(k, loc)
@ -39,7 +39,7 @@ def pmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
@_wraps(osp_stats.bernoulli.cdf, update_doc=False)
def cdf(k: ArrayLike, p: ArrayLike) -> Array:
k, p = jnp._promote_args_inexact('bernoulli.cdf', k, p)
k, p = _promote_args_inexact('bernoulli.cdf', k, p)
zero, one = _lax_const(k, 0), _lax_const(k, 1)
conds = [
jnp.isnan(k) | jnp.isnan(p) | (p < zero) | (p > one),
@ -52,7 +52,7 @@ def cdf(k: ArrayLike, p: ArrayLike) -> Array:
@_wraps(osp_stats.bernoulli.ppf, update_doc=False)
def ppf(q: ArrayLike, p: ArrayLike) -> Array:
q, p = jnp._promote_args_inexact('bernoulli.ppf', q, p)
q, p = _promote_args_inexact('bernoulli.ppf', q, p)
zero, one = _lax_const(q, 0), _lax_const(q, 1)
return jnp.where(
jnp.isnan(q) | jnp.isnan(p) | (p < zero) | (p > one) | (q < zero) | (q > one),

View File

@ -15,9 +15,9 @@
import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf, logical_or
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax.scipy.special import betaln, xlogy, xlog1py
@ -32,8 +32,8 @@ def logpdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
log_linear_term = lax.add(xlogy(lax.sub(a, one), y),
xlog1py(lax.sub(b, one), lax.neg(y)))
log_probs = lax.sub(lax.add(shape_term, log_linear_term), lax.log(scale))
return where(logical_or(lax.gt(x, lax.add(loc, scale)),
lax.lt(x, loc)), -inf, log_probs)
return jnp.where(jnp.logical_or(lax.gt(x, lax.add(loc, scale)),
lax.lt(x, loc)), -jnp.inf, log_probs)
@_wraps(osp_stats.beta.pdf, update_doc=False)

View File

@ -16,9 +16,9 @@
import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf, logical_or, nan
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.scipy.special import betaln
from jax._src.typing import Array, ArrayLike
@ -34,10 +34,10 @@ def logpmf(k: ArrayLike, n: ArrayLike, a: ArrayLike, b: ArrayLike,
combiln = lax.neg(lax.add(lax.log1p(n), betaln(lax.add(lax.sub(n,y), one), lax.add(y,one))))
beta_lns = lax.sub(betaln(lax.add(y,a), lax.add(lax.sub(n,y),b)), betaln(a,b))
log_probs = lax.add(combiln, beta_lns)
y_cond = logical_or(lax.lt(y, lax.neg(loc)), lax.gt(y, lax.sub(n, loc)))
log_probs = where(y_cond, -inf, log_probs)
n_a_b_cond = logical_or(logical_or(lax.lt(n, one), lax.lt(a, zero)), lax.lt(b, zero))
return where(n_a_b_cond, nan, log_probs)
y_cond = jnp.logical_or(lax.lt(y, lax.neg(loc)), lax.gt(y, lax.sub(n, loc)))
log_probs = jnp.where(y_cond, -jnp.inf, log_probs)
n_a_b_cond = jnp.logical_or(jnp.logical_or(lax.lt(n, one), lax.lt(a, zero)), lax.lt(b, zero))
return jnp.where(n_a_b_cond, jnp.nan, log_probs)
@_wraps(osp_stats.betabinom.pmf, update_doc=False)

View File

@ -19,7 +19,7 @@ import scipy.stats as osp_stats
from jax import lax
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact
from jax._src.numpy.util import _promote_args_inexact
from jax._src.typing import Array, ArrayLike

View File

@ -16,9 +16,9 @@
import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.typing import Array, ArrayLike
@ -35,7 +35,7 @@ def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
nrml_cnst = lax.neg(lax.add(lax.lgamma(df_on_two),lax.div(lax.mul(lax.log(two), df),two)))
log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel)
return where(lax.lt(x, loc), -inf, log_probs)
return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs)
@_wraps(osp_stats.chi2.pdf, update_doc=False)
def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:

View File

@ -16,11 +16,10 @@
import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import _promote_dtypes_inexact, _wraps
from jax.scipy.special import gammaln, xlogy
from jax._src.numpy.lax_numpy import _promote_dtypes_inexact
from jax._src.typing import Array, ArrayLike

View File

@ -15,8 +15,8 @@
import scipy.stats as osp_stats
from jax import lax
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf
import jax.numpy as jnp
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.typing import Array, ArrayLike
@ -26,7 +26,7 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
log_scale = lax.log(scale)
linear_term = lax.div(lax.sub(x, loc), scale)
log_probs = lax.neg(lax.add(linear_term, log_scale))
return where(lax.lt(x, loc), -inf, log_probs)
return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs)
@_wraps(osp_stats.expon.pdf, update_doc=False)
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:

View File

@ -15,9 +15,9 @@
import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax.scipy.special import gammaln, xlogy
@ -30,7 +30,7 @@ def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1)
log_linear_term = lax.sub(xlogy(lax.sub(a, one), y), y)
shape_terms = lax.add(gammaln(a), lax.log(scale))
log_probs = lax.sub(log_linear_term, shape_terms)
return where(lax.lt(x, loc), -inf, log_probs)
return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs)
@_wraps(osp_stats.gamma.pdf, update_doc=False)
def pdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:

View File

@ -14,8 +14,7 @@
import scipy.stats as osp_stats
from jax import lax
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.typing import Array, ArrayLike

View File

@ -15,16 +15,16 @@
import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax.scipy.special import xlog1py
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.geom.logpmf, update_doc=False)
def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
k, p, loc = jnp._promote_args_inexact("geom.logpmf", k, p, loc)
k, p, loc = _promote_args_inexact("geom.logpmf", k, p, loc)
zero = _lax_const(k, 0)
one = _lax_const(k, 1)
x = lax.sub(k, loc)

View File

@ -21,8 +21,7 @@ import scipy.stats as osp_stats
import jax.numpy as jnp
from jax import jit, lax, random, vmap
from jax._src.numpy.lax_numpy import _check_arraylike, _promote_dtypes_inexact
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import _check_arraylike, _promote_dtypes_inexact, _wraps
from jax._src.tree_util import register_pytree_node_class
from jax.scipy import linalg, special

View File

@ -16,8 +16,7 @@ import scipy.stats as osp_stats
from jax import lax
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.typing import Array, ArrayLike

View File

@ -16,10 +16,9 @@ import scipy.stats as osp_stats
from jax.scipy.special import expit, logit
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.typing import Array, ArrayLike

View File

@ -15,7 +15,7 @@
import scipy.stats as osp_stats
from jax import lax
from jax._src.numpy import lax_numpy as jnp
import jax.numpy as jnp
from jax._src.numpy.util import _wraps, _promote_args_inexact, _promote_args_numeric
from jax._src.scipy.special import gammaln, xlogy
from jax._src.typing import Array, ArrayLike

View File

@ -19,8 +19,7 @@ import scipy.stats as osp_stats
from jax import lax
from jax import numpy as jnp
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_dtypes_inexact
from jax._src.numpy.util import _wraps, _promote_dtypes_inexact
from jax._src.typing import Array, ArrayLike

View File

@ -16,9 +16,9 @@
import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.scipy.special import gammaln, xlogy
from jax._src.typing import Array, ArrayLike
@ -34,7 +34,7 @@ def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Arra
)
log_linear_term = lax.add(xlogy(n, p), xlogy(y, lax.sub(one, p)))
log_probs = lax.add(comb_term, log_linear_term)
return where(lax.lt(k, loc), -inf, log_probs)
return jnp.where(lax.lt(k, loc), -jnp.inf, log_probs)
@_wraps(osp_stats.nbinom.pmf, update_doc=False)

View File

@ -18,10 +18,9 @@ import numpy as np
import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax.scipy import special

View File

@ -16,9 +16,9 @@
import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact, inf, where
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.typing import Array, ArrayLike
@ -29,7 +29,7 @@ def logpdf(x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1)
scaled_x = lax.div(lax.sub(x, loc), scale)
normalize_term = lax.log(lax.div(scale, b))
log_probs = lax.neg(lax.add(normalize_term, lax.mul(lax.add(b, one), lax.log(scaled_x))))
return where(lax.lt(x, lax.add(loc, scale)), -inf, log_probs)
return jnp.where(lax.lt(x, lax.add(loc, scale)), -jnp.inf, log_probs)
@_wraps(osp_stats.pareto.pdf, update_doc=False)
def pdf(x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:

View File

@ -16,16 +16,16 @@
import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax.scipy.special import xlogy, gammaln, gammaincc
@_wraps(osp_stats.poisson.logpmf, update_doc=False)
def logpmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
k, mu, loc = jnp._promote_args_inexact("poisson.logpmf", k, mu, loc)
k, mu, loc = _promote_args_inexact("poisson.logpmf", k, mu, loc)
zero = _lax_const(k, 0)
x = lax.sub(k, loc)
log_probs = xlogy(x, mu) - gammaln(x + 1) - mu
@ -37,7 +37,7 @@ def pmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
@_wraps(osp_stats.poisson.cdf, update_doc=False)
def cdf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
k, mu, loc = jnp._promote_args_inexact("poisson.logpmf", k, mu, loc)
k, mu, loc = _promote_args_inexact("poisson.logpmf", k, mu, loc)
zero = _lax_const(k, 0)
x = lax.sub(k, loc)
p = gammaincc(jnp.floor(1 + x), mu)

View File

@ -18,8 +18,7 @@ import scipy.stats as osp_stats
from jax import lax
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.typing import Array, ArrayLike

View File

@ -16,9 +16,8 @@
import scipy.stats as osp_stats
from jax import lax
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact
import jax.numpy as jnp
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.scipy.stats import norm
from jax._src.scipy.special import logsumexp, log_ndtr, ndtr

View File

@ -16,9 +16,9 @@
import scipy.stats as osp_stats
from jax import lax
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf, logical_or
from jax.numpy import where, inf, logical_or
from jax._src.typing import Array, ArrayLike
from jax._src.numpy.util import _wraps, _promote_args_inexact
@_wraps(osp_stats.uniform.logpdf, update_doc=False)

View File

@ -15,8 +15,8 @@
import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.util import _wraps, _promote_args_inexact
from jax._src.typing import Array, ArrayLike

View File

@ -46,7 +46,7 @@ from jax._src import dtypes as _dtypes
from jax._src.config import (flags, bool_env, config,
raise_persistent_cache_errors,
persistent_cache_min_compile_time_secs)
from jax._src.numpy.lax_numpy import _promote_dtypes, _promote_dtypes_inexact
from jax._src.numpy.util import _promote_dtypes, _promote_dtypes_inexact
from jax._src.util import unzip2
from jax._src.public_test_util import ( # noqa: F401
_assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,

View File

@ -1,7 +1,7 @@
import numpy as np
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import linalg as la
import jax.numpy as jnp
import jax.numpy.linalg as la
from jax._src.numpy.util import _check_arraylike, _wraps

View File

@ -1,7 +1,7 @@
from jax import lax
import jax.numpy as jnp
from jax._src.typing import Array, ArrayLike
from jax._src.numpy.lax_numpy import _promote_args_inexact
from jax._src.numpy.util import _promote_args_inexact
# Note: for mysterious reasons, annotating this leads to very slow mypy runs.
# def algdiv(a: ArrayLike, b: ArrayLike) -> Array:

View File

@ -1,11 +1,10 @@
from itertools import product
import scipy.interpolate as osp_interpolate
from jax.numpy import (asarray, broadcast_arrays, can_cast,
empty, nan, searchsorted, where, zeros)
from jax._src.tree_util import register_pytree_node
from jax._src.numpy.lax_numpy import (_check_arraylike, _promote_dtypes_inexact,
asarray, broadcast_arrays, can_cast,
empty, nan, searchsorted, where, zeros)
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import _check_arraylike, _promote_dtypes_inexact, _wraps
def _ndim_coords_from_arrays(points, ndim=None):

View File

@ -3,7 +3,7 @@ from typing import Callable, Tuple
import scipy.linalg
from jax import jit, lax
from jax._src.numpy import lax_numpy as jnp
import jax.numpy as jnp
from jax._src.numpy.linalg import norm
from jax._src.numpy.util import _wraps
from jax._src.scipy.linalg import rsf2csf, schur

View File

@ -4,7 +4,7 @@ import scipy.signal as osp_signal
from typing import Any, Optional, Tuple, Union
import warnings
from jax._src.numpy import lax_numpy as jnp
import jax.numpy as jnp
from jax._src.typing import Array, ArrayLike, DTypeLike

View File

@ -34,7 +34,7 @@ from jax._src.interpreters import ad
from jax._src.lax.lax import _const
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib import gpu_sparse
from jax._src.numpy.lax_numpy import _promote_dtypes
from jax._src.numpy.util import _promote_dtypes
from jax._src.typing import Array, ArrayLike, DTypeLike
import jax.numpy as jnp

View File

@ -34,7 +34,7 @@ from jax._src import dispatch
from jax._src.interpreters import ad
from jax._src.lax.lax import _const
from jax._src.lib import gpu_sparse
from jax._src.numpy.lax_numpy import _promote_dtypes
from jax._src.numpy.util import _promote_dtypes
from jax._src.typing import Array, ArrayLike, DTypeLike
import jax.numpy as jnp