Call _check_arraylike for jnp.linalg & jnp.fft functions

This commit is contained in:
Jake VanderPlas 2022-10-27 09:01:28 -07:00
parent 32a0ea80ef
commit 2416d15435
4 changed files with 115 additions and 93 deletions

View File

@ -21,6 +21,9 @@ Remember to align the itemized text with the first line of an item within a list
* Breaking Changes
* {func}`jax.numpy.gradient` now behaves like most other functions in {mod}`jax.numpy`,
and forbids passing lists or tuples in place of arrays ({jax-issue}`#12958`)
* Functions in {mod}`jax.numpy.linalg` and {mod}`jax.numpy.fft` now uniformly
require inputs to be array-like: i.e. lists and tuples cannot be used in place
of arrays. Part of {jax-issue}`#7737`.
## jaxlib 0.3.24
* Changes

View File

@ -21,7 +21,7 @@ from jax import dtypes
from jax import lax
from jax._src.lib import xla_client
from jax._src.util import safe_zip
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import _check_arraylike, _wraps
from jax._src.numpy import lax_numpy as jnp
from jax._src.typing import Array, ArrayLike
@ -42,8 +42,7 @@ def _fft_core(func_name: str, fft_type: xla_client.FftType, a: ArrayLike,
s: Optional[Shape], axes: Optional[Sequence[int]],
norm: Optional[str]) -> Array:
full_name = "jax.numpy.fft." + func_name
# TODO(jakevdp): call check_arraylike
_check_arraylike(full_name, a)
arr = jnp.asarray(a)
if s is not None:
@ -285,6 +284,7 @@ def rfftfreq(n: int, d: ArrayLike = 1.0) -> Array:
@_wraps(np.fft.fftshift)
def fftshift(x: ArrayLike, axes: Union[None, int, Sequence[int]] = None) -> Array:
_check_arraylike("fftshift", x)
x = jnp.asarray(x)
shift: Union[int, Sequence[int]]
if axes is None:
@ -300,6 +300,7 @@ def fftshift(x: ArrayLike, axes: Union[None, int, Sequence[int]] = None) -> Arra
@_wraps(np.fft.ifftshift)
def ifftshift(x: ArrayLike, axes: Union[None, int, Sequence[int]] = None) -> Array:
_check_arraylike("ifftshift", x)
x = jnp.asarray(x)
shift: Union[int, Sequence[int]]
if axes is None:

View File

@ -28,7 +28,7 @@ from jax import lax
from jax._src.lax import lax as lax_internal
from jax._src.lax import linalg as lax_linalg
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.util import _wraps, _promote_dtypes_inexact
from jax._src.numpy.util import _wraps, _promote_dtypes_inexact, _check_arraylike
from jax._src.util import canonicalize_axis
from jax._src.typing import ArrayLike, Array
@ -44,6 +44,7 @@ def _H(x: ArrayLike) -> Array:
@_wraps(np.linalg.cholesky)
@jit
def cholesky(a: ArrayLike) -> Array:
_check_arraylike("jnp.linalg.cholesky", a)
a, = _promote_dtypes_inexact(jnp.asarray(a))
return lax_linalg.cholesky(a)
@ -67,6 +68,7 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
@partial(jit, static_argnames=('full_matrices', 'compute_uv', 'hermitian'))
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
hermitian: bool = False) -> Union[Array, Tuple[Array, Array, Array]]:
_check_arraylike("jnp.linalg.svd", a)
a, = _promote_dtypes_inexact(jnp.asarray(a))
if hermitian:
w, v = lax_linalg.eigh(a)
@ -90,7 +92,7 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
@_wraps(np.linalg.matrix_power)
@partial(jit, static_argnames=('n',))
def matrix_power(a: ArrayLike, n: int) -> Array:
# TODO(jakevdp): call _check_arraylike
_check_arraylike("jnp.linalg.matrix_power", a)
arr, = _promote_dtypes_inexact(jnp.asarray(a))
if arr.ndim < 2:
@ -129,6 +131,7 @@ def matrix_power(a: ArrayLike, n: int) -> Array:
@_wraps(np.linalg.matrix_rank)
@jit
def matrix_rank(M: ArrayLike, tol: Optional[ArrayLike] = None) -> Array:
_check_arraylike("jnp.linalg.matrix_rank", M)
M, = _promote_dtypes_inexact(jnp.asarray(M))
if M.ndim < 2:
return jnp.any(M != 0).astype(jnp.int32)
@ -191,6 +194,7 @@ def _slogdet_qr(a: Array) -> Tuple[Array, Array]:
"""))
@partial(jit, static_argnames=('method',))
def slogdet(a: ArrayLike, *, method: Optional[str] = None) -> Tuple[Array, Array]:
_check_arraylike("jnp.linalg.slogdet", a)
a, = _promote_dtypes_inexact(jnp.asarray(a))
a_shape = jnp.shape(a)
if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
@ -329,6 +333,7 @@ def _det_3x3(a: Array) -> Array:
@_wraps(np.linalg.det)
@jit
def det(a: ArrayLike) -> Array:
_check_arraylike("jnp.linalg.det", a)
a, = _promote_dtypes_inexact(jnp.asarray(a))
a_shape = jnp.shape(a)
if len(a_shape) >= 2 and a_shape[-1] == 2 and a_shape[-2] == 2:
@ -361,6 +366,7 @@ backend. However eigendecomposition for symmetric/Hermitian matrices is
implemented more widely (see :func:`jax.numpy.linalg.eigh`).
""")
def eig(a: ArrayLike) -> Tuple[Array, Array]:
_check_arraylike("jnp.linalg.eig", a)
a, = _promote_dtypes_inexact(jnp.asarray(a))
w, v = lax_linalg.eig(a, compute_left_eigenvectors=False)
return w, v
@ -369,6 +375,7 @@ def eig(a: ArrayLike) -> Tuple[Array, Array]:
@_wraps(np.linalg.eigvals)
@jit
def eigvals(a: ArrayLike) -> Array:
_check_arraylike("jnp.linalg.eigvals", a)
return lax_linalg.eig(a, compute_left_eigenvectors=False,
compute_right_eigenvectors=False)[0]
@ -377,6 +384,7 @@ def eigvals(a: ArrayLike) -> Array:
@partial(jit, static_argnames=('UPLO', 'symmetrize_input'))
def eigh(a: ArrayLike, UPLO: Optional[str] = None,
symmetrize_input: bool = True) -> Tuple[Array, Array]:
_check_arraylike("jnp.linalg.eigh", a)
if UPLO is None or UPLO == "L":
lower = True
elif UPLO == "U":
@ -393,6 +401,7 @@ def eigh(a: ArrayLike, UPLO: Optional[str] = None,
@_wraps(np.linalg.eigvalsh)
@partial(jit, static_argnames=('UPLO',))
def eigvalsh(a: ArrayLike, UPLO: Optional[str] = 'L') -> Array:
_check_arraylike("jnp.linalg.eigvalsh", a)
w, _ = eigh(a, UPLO)
return w
@ -407,6 +416,7 @@ def eigvalsh(a: ArrayLike, UPLO: Optional[str] = 'L') -> Array:
def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None) -> Array:
# Uses same algorithm as
# https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
_check_arraylike("jnp.linalg.pinv", a)
arr = jnp.conj(a)
if rcond is None:
max_rows_cols = max(arr.shape[-2:])
@ -447,7 +457,7 @@ def _pinv_jvp(rcond, primals, tangents):
@_wraps(np.linalg.inv)
@jit
def inv(a: ArrayLike) -> Array:
# TODO(jakevdp): call _check_arraylike
_check_arraylike("jnp.linalg.inv", a)
arr = jnp.asarray(a)
if arr.ndim < 2 or arr.shape[-1] != arr.shape[-2]:
raise ValueError(
@ -461,6 +471,7 @@ def inv(a: ArrayLike) -> Array:
def norm(x: ArrayLike, ord: Union[int, str, None] = None,
axis: Union[None, Tuple[int, ...], int] = None,
keepdims: bool = False) -> Array:
_check_arraylike("jnp.linalg.norm", x)
x, = _promote_dtypes_inexact(jnp.asarray(x))
x_shape = jnp.shape(x)
ndim = len(x_shape)
@ -560,6 +571,7 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, Tuple[Array, Array]]
@_wraps(np.linalg.qr)
@partial(jit, static_argnames=('mode',))
def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, Tuple[Array, Array]]:
_check_arraylike("jnp.linalg.qr", a)
a, = _promote_dtypes_inexact(jnp.asarray(a))
if mode == "raw":
a, taus = lax_linalg.geqrf(a)
@ -579,6 +591,7 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, Tuple[Array, Array]]
@_wraps(np.linalg.solve)
@jit
def solve(a: ArrayLike, b: ArrayLike) -> Array:
_check_arraylike("jnp.linalg.solve", a, b)
a, b = _promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
return lax_linalg._solve(a, b)
@ -645,6 +658,7 @@ _jit_lstsq = jit(partial(_lstsq, numpy_resid=False))
"""))
def lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float] = None, *,
numpy_resid: bool = False) -> Tuple[Array, Array, Array, Array]:
_check_arraylike("jnp.linalg.lstsq", a, b)
if numpy_resid:
return _lstsq(a, b, rcond, numpy_resid=True)
return _jit_lstsq(a, b, rcond)

View File

@ -2,7 +2,7 @@ import numpy as np
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import linalg as la
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import _check_arraylike, _wraps
def _isEmpty2d(arr):
@ -41,6 +41,7 @@ def _assert2d(*arrays):
@_wraps(np.linalg.cond)
def cond(x, p=None):
_check_arraylike('jnp.linalg.cond', x)
_assertNoEmpty2d(x)
if p in (None, 2):
s = la.svd(x, compute_uv=False)
@ -63,6 +64,7 @@ def cond(x, p=None):
@_wraps(np.linalg.tensorinv)
def tensorinv(a, ind=2):
_check_arraylike('jnp.linalg.tensorinv', a)
a = jnp.asarray(a)
oldshape = a.shape
prod = 1
@ -79,6 +81,7 @@ def tensorinv(a, ind=2):
@_wraps(np.linalg.tensorsolve)
def tensorsolve(a, b, axes=None):
_check_arraylike('jnp.linalg.tensorsolve', a, b)
a = jnp.asarray(a)
b = jnp.asarray(b)
an = a.ndim
@ -107,6 +110,7 @@ def tensorsolve(a, b, axes=None):
@_wraps(np.linalg.multi_dot)
def multi_dot(arrays, *, precision=None):
_check_arraylike('jnp.linalg.multi_dot', *arrays)
n = len(arrays)
# optimization only makes sense for len(arrays) > 2
if n < 2: