From 2416d154355f19e77b5c1ddf1de1f8552e4a98ad Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Thu, 27 Oct 2022 09:01:28 -0700
Subject: [PATCH] Call _check_arraylike for jnp.linalg & jnp.fft functions

---
 CHANGELOG.md                         |   3 +
 jax/_src/numpy/fft.py                |   7 +-
 jax/_src/numpy/linalg.py             |  20 ++-
 jax/_src/third_party/numpy/linalg.py | 178 ++++++++++++++-------------
 4 files changed, 115 insertions(+), 93 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4fa11c71c..2754dc637 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -21,6 +21,9 @@ Remember to align the itemized text with the first line of an item within a list
 * Breaking Changes
   * {func}`jax.numpy.gradient` now behaves like most other functions in {mod}`jax.numpy`,
     and forbids passing lists or tuples in place of arrays ({jax-issue}`#12958`)
+  * Functions in {mod}`jax.numpy.linalg` and {mod}`jax.numpy.fft` now uniformly
+    require inputs to be array-like: i.e. lists and tuples cannot be used in place
+    of arrays. Part of {jax-issue}`#7737`.
 
 ## jaxlib 0.3.24
 * Changes
diff --git a/jax/_src/numpy/fft.py b/jax/_src/numpy/fft.py
index 49baa4474..04fea7986 100644
--- a/jax/_src/numpy/fft.py
+++ b/jax/_src/numpy/fft.py
@@ -21,7 +21,7 @@ from jax import dtypes
 from jax import lax
 from jax._src.lib import xla_client
 from jax._src.util import safe_zip
-from jax._src.numpy.util import _wraps
+from jax._src.numpy.util import _check_arraylike, _wraps
 from jax._src.numpy import lax_numpy as jnp
 from jax._src.typing import Array, ArrayLike
 
@@ -42,8 +42,7 @@ def _fft_core(func_name: str, fft_type: xla_client.FftType, a: ArrayLike,
               s: Optional[Shape], axes: Optional[Sequence[int]],
               norm: Optional[str]) -> Array:
   full_name = "jax.numpy.fft." + func_name
-
-  # TODO(jakevdp): call check_arraylike
+  _check_arraylike(full_name, a)
   arr = jnp.asarray(a)
 
   if s is not None:
@@ -285,6 +284,7 @@ def rfftfreq(n: int, d: ArrayLike = 1.0) -> Array:
 
 @_wraps(np.fft.fftshift)
 def fftshift(x: ArrayLike, axes: Union[None, int, Sequence[int]] = None) -> Array:
+  _check_arraylike("fftshift", x)
   x = jnp.asarray(x)
   shift: Union[int, Sequence[int]]
   if axes is None:
@@ -300,6 +300,7 @@ def fftshift(x: ArrayLike, axes: Union[None, int, Sequence[int]] = None) -> Array:
 
 @_wraps(np.fft.ifftshift)
 def ifftshift(x: ArrayLike, axes: Union[None, int, Sequence[int]] = None) -> Array:
+  _check_arraylike("ifftshift", x)
   x = jnp.asarray(x)
   shift: Union[int, Sequence[int]]
   if axes is None:
diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py
index 0c5f07df4..c604efb2d 100644
--- a/jax/_src/numpy/linalg.py
+++ b/jax/_src/numpy/linalg.py
@@ -28,7 +28,7 @@ from jax import lax
 from jax._src.lax import lax as lax_internal
 from jax._src.lax import linalg as lax_linalg
 from jax._src.numpy import lax_numpy as jnp
-from jax._src.numpy.util import _wraps, _promote_dtypes_inexact
+from jax._src.numpy.util import _wraps, _promote_dtypes_inexact, _check_arraylike
 from jax._src.util import canonicalize_axis
 from jax._src.typing import ArrayLike, Array
 
@@ -44,6 +44,7 @@ def _H(x: ArrayLike) -> Array:
 @_wraps(np.linalg.cholesky)
 @jit
 def cholesky(a: ArrayLike) -> Array:
+  _check_arraylike("jnp.linalg.cholesky", a)
   a, = _promote_dtypes_inexact(jnp.asarray(a))
   return lax_linalg.cholesky(a)
 
@@ -67,6 +68,7 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
 @partial(jit, static_argnames=('full_matrices', 'compute_uv', 'hermitian'))
 def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
         hermitian: bool = False) -> Union[Array, Tuple[Array, Array, Array]]:
+  _check_arraylike("jnp.linalg.svd", a)
   a, = _promote_dtypes_inexact(jnp.asarray(a))
   if hermitian:
     w, v = lax_linalg.eigh(a)
@@ -90,7 +92,7 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
 @_wraps(np.linalg.matrix_power)
 @partial(jit, static_argnames=('n',))
 def matrix_power(a: ArrayLike, n: int) -> Array:
-  # TODO(jakevdp): call _check_arraylike
+  _check_arraylike("jnp.linalg.matrix_power", a)
   arr, = _promote_dtypes_inexact(jnp.asarray(a))
 
   if arr.ndim < 2:
@@ -129,6 +131,7 @@ def matrix_power(a: ArrayLike, n: int) -> Array:
 @_wraps(np.linalg.matrix_rank)
 @jit
 def matrix_rank(M: ArrayLike, tol: Optional[ArrayLike] = None) -> Array:
+  _check_arraylike("jnp.linalg.matrix_rank", M)
   M, = _promote_dtypes_inexact(jnp.asarray(M))
   if M.ndim < 2:
     return jnp.any(M != 0).astype(jnp.int32)
@@ -191,6 +194,7 @@ def _slogdet_qr(a: Array) -> Tuple[Array, Array]:
 """))
 @partial(jit, static_argnames=('method',))
 def slogdet(a: ArrayLike, *, method: Optional[str] = None) -> Tuple[Array, Array]:
+  _check_arraylike("jnp.linalg.slogdet", a)
   a, = _promote_dtypes_inexact(jnp.asarray(a))
   a_shape = jnp.shape(a)
   if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
@@ -329,6 +333,7 @@ def _det_3x3(a: Array) -> Array:
 @_wraps(np.linalg.det)
 @jit
 def det(a: ArrayLike) -> Array:
+  _check_arraylike("jnp.linalg.det", a)
   a, = _promote_dtypes_inexact(jnp.asarray(a))
   a_shape = jnp.shape(a)
   if len(a_shape) >= 2 and a_shape[-1] == 2 and a_shape[-2] == 2:
@@ -361,6 +366,7 @@ backend. However eigendecomposition for symmetric/Hermitian matrices is
 implemented more widely (see :func:`jax.numpy.linalg.eigh`).
 """)
 def eig(a: ArrayLike) -> Tuple[Array, Array]:
+  _check_arraylike("jnp.linalg.eig", a)
   a, = _promote_dtypes_inexact(jnp.asarray(a))
   w, v = lax_linalg.eig(a, compute_left_eigenvectors=False)
   return w, v
@@ -369,6 +375,7 @@ def eig(a: ArrayLike) -> Tuple[Array, Array]:
 @_wraps(np.linalg.eigvals)
 @jit
 def eigvals(a: ArrayLike) -> Array:
+  _check_arraylike("jnp.linalg.eigvals", a)
   return lax_linalg.eig(a, compute_left_eigenvectors=False,
                         compute_right_eigenvectors=False)[0]
 
@@ -377,6 +384,7 @@ def eigvals(a: ArrayLike) -> Array:
 @partial(jit, static_argnames=('UPLO', 'symmetrize_input'))
 def eigh(a: ArrayLike, UPLO: Optional[str] = None,
          symmetrize_input: bool = True) -> Tuple[Array, Array]:
+  _check_arraylike("jnp.linalg.eigh", a)
   if UPLO is None or UPLO == "L":
     lower = True
   elif UPLO == "U":
@@ -393,6 +401,7 @@ def eigh(a: ArrayLike, UPLO: Optional[str] = None,
 @_wraps(np.linalg.eigvalsh)
 @partial(jit, static_argnames=('UPLO',))
 def eigvalsh(a: ArrayLike, UPLO: Optional[str] = 'L') -> Array:
+  _check_arraylike("jnp.linalg.eigvalsh", a)
   w, _ = eigh(a, UPLO)
   return w
 
@@ -407,6 +416,7 @@ def eigvalsh(a: ArrayLike, UPLO: Optional[str] = 'L') -> Array:
 def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None) -> Array:
   # Uses same algorithm as
   # https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
+  _check_arraylike("jnp.linalg.pinv", a)
   arr = jnp.conj(a)
   if rcond is None:
     max_rows_cols = max(arr.shape[-2:])
@@ -447,7 +457,7 @@ def _pinv_jvp(rcond, primals, tangents):
 @_wraps(np.linalg.inv)
 @jit
 def inv(a: ArrayLike) -> Array:
-  # TODO(jakevdp): call _check_arraylike
+  _check_arraylike("jnp.linalg.inv", a)
   arr = jnp.asarray(a)
   if arr.ndim < 2 or arr.shape[-1] != arr.shape[-2]:
     raise ValueError(
@@ -461,6 +471,7 @@ def inv(a: ArrayLike) -> Array:
 def norm(x: ArrayLike, ord: Union[int, str, None] = None,
          axis: Union[None, Tuple[int, ...], int] = None,
          keepdims: bool = False) -> Array:
+  _check_arraylike("jnp.linalg.norm", x)
   x, = _promote_dtypes_inexact(jnp.asarray(x))
   x_shape = jnp.shape(x)
   ndim = len(x_shape)
@@ -560,6 +571,7 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, Tuple[Array, Array]]:
 @_wraps(np.linalg.qr)
 @partial(jit, static_argnames=('mode',))
 def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, Tuple[Array, Array]]:
+  _check_arraylike("jnp.linalg.qr", a)
   a, = _promote_dtypes_inexact(jnp.asarray(a))
   if mode == "raw":
     a, taus = lax_linalg.geqrf(a)
@@ -579,6 +591,7 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, Tuple[Array, Array]]:
 @_wraps(np.linalg.solve)
 @jit
 def solve(a: ArrayLike, b: ArrayLike) -> Array:
+  _check_arraylike("jnp.linalg.solve", a, b)
   a, b = _promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
   return lax_linalg._solve(a, b)
 
@@ -645,6 +658,7 @@ _jit_lstsq = jit(partial(_lstsq, numpy_resid=False))
 """))
 def lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float] = None, *,
           numpy_resid: bool = False) -> Tuple[Array, Array, Array, Array]:
+  _check_arraylike("jnp.linalg.lstsq", a, b)
   if numpy_resid:
     return _lstsq(a, b, rcond, numpy_resid=True)
   return _jit_lstsq(a, b, rcond)
diff --git a/jax/_src/third_party/numpy/linalg.py b/jax/_src/third_party/numpy/linalg.py
index a1dd33f4a..c59f85eb9 100644
--- a/jax/_src/third_party/numpy/linalg.py
+++ b/jax/_src/third_party/numpy/linalg.py
@@ -2,7 +2,7 @@ import numpy as np
 
 from jax._src.numpy import lax_numpy as jnp
 from jax._src.numpy import linalg as la
-from jax._src.numpy.util import _wraps
+from jax._src.numpy.util import _check_arraylike, _wraps
 
 
 def _isEmpty2d(arr):
@@ -33,14 +33,15 @@ def _assertNdSquareness(*arrays):
 
 
 def _assert2d(*arrays):
-    for a in arrays:
-        if a.ndim != 2:
-            raise ValueError(f'{a.ndim}-dimensional array given. '
-                             'Array must be two-dimensional')
+  for a in arrays:
+    if a.ndim != 2:
+      raise ValueError(f'{a.ndim}-dimensional array given. '
+                       'Array must be two-dimensional')
 
 
 @_wraps(np.linalg.cond)
 def cond(x, p=None):
+  _check_arraylike('jnp.linalg.cond', x)
   _assertNoEmpty2d(x)
   if p in (None, 2):
     s = la.svd(x, compute_uv=False)
@@ -63,6 +64,7 @@ def cond(x, p=None):
 
 @_wraps(np.linalg.tensorinv)
 def tensorinv(a, ind=2):
+  _check_arraylike('jnp.linalg.tensorinv', a)
   a = jnp.asarray(a)
   oldshape = a.shape
   prod = 1
@@ -79,6 +81,7 @@ def tensorinv(a, ind=2):
 
 @_wraps(np.linalg.tensorsolve)
 def tensorsolve(a, b, axes=None):
+  _check_arraylike('jnp.linalg.tensorsolve', a, b)
   a = jnp.asarray(a)
   b = jnp.asarray(b)
   an = a.ndim
@@ -107,101 +110,102 @@
 
 @_wraps(np.linalg.multi_dot)
 def multi_dot(arrays, *, precision=None):
-    n = len(arrays)
-    # optimization only makes sense for len(arrays) > 2
-    if n < 2:
-        raise ValueError("Expecting at least two arrays.")
-    elif n == 2:
-        return jnp.dot(arrays[0], arrays[1], precision=precision)
+  _check_arraylike('jnp.linalg.multi_dot', *arrays)
+  n = len(arrays)
+  # optimization only makes sense for len(arrays) > 2
+  if n < 2:
+    raise ValueError("Expecting at least two arrays.")
+  elif n == 2:
+    return jnp.dot(arrays[0], arrays[1], precision=precision)
 
-    arrays = [jnp.asarray(a) for a in arrays]
+  arrays = [jnp.asarray(a) for a in arrays]
 
-    # save original ndim to reshape the result array into the proper form later
-    ndim_first, ndim_last = arrays[0].ndim, arrays[-1].ndim
-    # Explicitly convert vectors to 2D arrays to keep the logic of the internal
-    # _multi_dot_* functions as simple as possible.
-    if arrays[0].ndim == 1:
-        arrays[0] = jnp.atleast_2d(arrays[0])
-    if arrays[-1].ndim == 1:
-        arrays[-1] = jnp.atleast_2d(arrays[-1]).T
-    _assert2d(*arrays)
+  # save original ndim to reshape the result array into the proper form later
+  ndim_first, ndim_last = arrays[0].ndim, arrays[-1].ndim
+  # Explicitly convert vectors to 2D arrays to keep the logic of the internal
+  # _multi_dot_* functions as simple as possible.
+  if arrays[0].ndim == 1:
+    arrays[0] = jnp.atleast_2d(arrays[0])
+  if arrays[-1].ndim == 1:
+    arrays[-1] = jnp.atleast_2d(arrays[-1]).T
+  _assert2d(*arrays)
 
-    # _multi_dot_three is much faster than _multi_dot_matrix_chain_order
-    if n == 3:
-        result = _multi_dot_three(*arrays, precision)
-    else:
-        order = _multi_dot_matrix_chain_order(arrays)
-        result = _multi_dot(arrays, order, 0, n - 1, precision)
+  # _multi_dot_three is much faster than _multi_dot_matrix_chain_order
+  if n == 3:
+    result = _multi_dot_three(*arrays, precision)
+  else:
+    order = _multi_dot_matrix_chain_order(arrays)
+    result = _multi_dot(arrays, order, 0, n - 1, precision)
 
-    # return proper shape
-    if ndim_first == 1 and ndim_last == 1:
-        return result[0, 0]  # scalar
-    elif ndim_first == 1 or ndim_last == 1:
-        return result.ravel()  # 1-D
-    else:
-        return result
+  # return proper shape
+  if ndim_first == 1 and ndim_last == 1:
+    return result[0, 0]  # scalar
+  elif ndim_first == 1 or ndim_last == 1:
+    return result.ravel()  # 1-D
+  else:
+    return result
 
 
 def _multi_dot_three(A, B, C, precision):
-    """
-    Find the best order for three arrays and do the multiplication.
-    For three arguments `_multi_dot_three` is approximately 15 times faster
-    than `_multi_dot_matrix_chain_order`
-    """
-    a0, a1b0 = A.shape
-    b1c0, c1 = C.shape
-    # cost1 = cost((AB)C) = a0*a1b0*b1c0 + a0*b1c0*c1
-    cost1 = a0 * b1c0 * (a1b0 + c1)
-    # cost2 = cost(A(BC)) = a1b0*b1c0*c1 + a0*a1b0*c1
-    cost2 = a1b0 * c1 * (a0 + b1c0)
+  """
+  Find the best order for three arrays and do the multiplication.
+  For three arguments `_multi_dot_three` is approximately 15 times faster
+  than `_multi_dot_matrix_chain_order`
+  """
+  a0, a1b0 = A.shape
+  b1c0, c1 = C.shape
+  # cost1 = cost((AB)C) = a0*a1b0*b1c0 + a0*b1c0*c1
+  cost1 = a0 * b1c0 * (a1b0 + c1)
+  # cost2 = cost(A(BC)) = a1b0*b1c0*c1 + a0*a1b0*c1
+  cost2 = a1b0 * c1 * (a0 + b1c0)
 
-    if cost1 < cost2:
-        return jnp.dot(jnp.dot(A, B, precision=precision), C, precision=precision)
-    else:
-        return jnp.dot(A, jnp.dot(B, C, precision=precision), precision=precision)
+  if cost1 < cost2:
+    return jnp.dot(jnp.dot(A, B, precision=precision), C, precision=precision)
+  else:
+    return jnp.dot(A, jnp.dot(B, C, precision=precision), precision=precision)
 
 
 def _multi_dot_matrix_chain_order(arrays, return_costs=False):
-    """
-    Return a jnp.array that encodes the optimal order of mutiplications.
-    The optimal order array is then used by `_multi_dot()` to do the
-    multiplication.
-    Also return the cost matrix if `return_costs` is `True`
-    The implementation CLOSELY follows Cormen, "Introduction to Algorithms",
-    Chapter 15.2, p. 370-378.  Note that Cormen uses 1-based indices.
-        cost[i, j] = min([
-            cost[prefix] + cost[suffix] + cost_mult(prefix, suffix)
-            for k in range(i, j)])
-    """
-    n = len(arrays)
-    # p stores the dimensions of the matrices
-    # Example for p: A_{10x100}, B_{100x5}, C_{5x50} --> p = [10, 100, 5, 50]
-    p = [a.shape[0] for a in arrays] + [arrays[-1].shape[1]]
-    # m is a matrix of costs of the subproblems
-    # m[i,j]: min number of scalar multiplications needed to compute A_{i..j}
-    m = np.zeros((n, n), dtype=np.double)
-    # s is the actual ordering
-    # s[i, j] is the value of k at which we split the product A_i..A_j
-    s = np.empty((n, n), dtype=np.intp)
+  """
+  Return a jnp.array that encodes the optimal order of mutiplications.
+  The optimal order array is then used by `_multi_dot()` to do the
+  multiplication.
+  Also return the cost matrix if `return_costs` is `True`
+  The implementation CLOSELY follows Cormen, "Introduction to Algorithms",
+  Chapter 15.2, p. 370-378.  Note that Cormen uses 1-based indices.
+      cost[i, j] = min([
+          cost[prefix] + cost[suffix] + cost_mult(prefix, suffix)
+          for k in range(i, j)])
+  """
+  n = len(arrays)
+  # p stores the dimensions of the matrices
+  # Example for p: A_{10x100}, B_{100x5}, C_{5x50} --> p = [10, 100, 5, 50]
+  p = [a.shape[0] for a in arrays] + [arrays[-1].shape[1]]
+  # m is a matrix of costs of the subproblems
+  # m[i,j]: min number of scalar multiplications needed to compute A_{i..j}
+  m = np.zeros((n, n), dtype=np.double)
+  # s is the actual ordering
+  # s[i, j] is the value of k at which we split the product A_i..A_j
+  s = np.empty((n, n), dtype=np.intp)
 
-    for l in range(1, n):
-        for i in range(n - l):
-            j = i + l
-            m[i, j] = jnp.inf
-            for k in range(i, j):
-                q = m[i, k] + m[k+1, j] + p[i]*p[k+1]*p[j+1]
-                if q < m[i, j]:
-                    m[i, j] = q
-                    s[i, j] = k  # Note that Cormen uses 1-based index
+  for l in range(1, n):
+    for i in range(n - l):
+      j = i + l
+      m[i, j] = jnp.inf
+      for k in range(i, j):
+        q = m[i, k] + m[k+1, j] + p[i]*p[k+1]*p[j+1]
+        if q < m[i, j]:
+          m[i, j] = q
+          s[i, j] = k  # Note that Cormen uses 1-based index
 
-    return (s, m) if return_costs else s
+  return (s, m) if return_costs else s
 
 
 def _multi_dot(arrays, order, i, j, precision):
-    """Actually do the multiplication with the given order."""
-    if i == j:
-        return arrays[i]
-    else:
-        return jnp.dot(_multi_dot(arrays, order, i, order[i, j], precision),
-                       _multi_dot(arrays, order, order[i, j] + 1, j, precision),
-                       precision=precision)
+  """Actually do the multiplication with the given order."""
+  if i == j:
+    return arrays[i]
+  else:
+    return jnp.dot(_multi_dot(arrays, order, i, order[i, j], precision),
+                   _multi_dot(arrays, order, order[i, j] + 1, j, precision),
+                   precision=precision)
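
A quick illustration of the breaking change described in the CHANGELOG entry above. This example is not part of the patch; the input values and printed result are illustrative, and the error text is an approximation of `_check_arraylike`'s message format rather than a quote from this diff:

```python
import jax.numpy as jnp

# Before this change, list inputs were silently converted via jnp.asarray:
#   jnp.linalg.det([[1., 2.], [3., 4.]])   # previously returned -2.0

# After this change, a list input raises a TypeError, roughly:
#   TypeError: jnp.linalg.det requires ndarray or scalar arguments,
#   got <class 'list'> at position 0.

# The fix is an explicit conversion at the call site:
x = jnp.asarray([[1., 2.], [3., 4.]])
print(jnp.linalg.det(x))  # -2.0
```

The same pattern applies to every function touched by this patch, e.g. `jnp.fft.fftshift` and `jnp.linalg.solve`.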
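For context, `_check_arraylike` is the pre-existing helper in `jax/_src/numpy/util.py` that the patch calls; the diff does not add it. A minimal sketch of its behavior, assuming the usual JAX conventions for what counts as array-like (the `_arraylike` predicate below is an approximation, not the exact source):

```python
import numpy as np

def _check_arraylike(fun_name, *args):
  """Sketch: raise TypeError if any argument is not array-like."""
  def _arraylike(x):
    # Approximation: numpy arrays and scalars, Python scalars, and objects
    # exposing __jax_array__ (jax.Array values and tracers) are accepted;
    # lists and tuples are not.
    return (isinstance(x, (np.ndarray, np.generic)) or np.isscalar(x)
            or hasattr(x, '__jax_array__'))
  for pos, arg in enumerate(args):
    if not _arraylike(arg):
      raise TypeError(f"{fun_name} requires ndarray or scalar arguments, "
                      f"got {type(arg)} at position {pos}.")
```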