diff --git a/jax/_src/numpy/__init__.py b/jax/_src/numpy/__init__.py new file mode 100644 index 000000000..b0c7da3d7 --- /dev/null +++ b/jax/_src/numpy/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jax/_src/numpy/fft.py b/jax/_src/numpy/fft.py new file mode 100644 index 000000000..7ae86713a --- /dev/null +++ b/jax/_src/numpy/fft.py @@ -0,0 +1,253 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np + +from jax import lax +from jax.lib import xla_client +from .util import _wraps +from . import lax_numpy as jnp +from jax import ops as jaxops + + +def _fft_core(func_name, fft_type, a, s, axes, norm): + # TODO(skye): implement padding/cropping based on 's'. + full_name = "jax.numpy.fft." + func_name + if s is not None: + raise NotImplementedError("%s only supports s=None, got %s" % (full_name, s)) + if norm is not None: + raise NotImplementedError("%s only supports norm=None, got %s" % (full_name, norm)) + if s is not None and axes is not None and len(s) != len(axes): + # Same error as numpy. + raise ValueError("Shape and axes have different lengths.") + + orig_axes = axes + if axes is None: + if s is None: + axes = range(a.ndim) + else: + axes = range(a.ndim - len(s), a.ndim) + + if len(axes) != len(set(axes)): + raise ValueError( + "%s does not support repeated axes. Got axes %s." % (full_name, axes)) + + if len(axes) > 3: + # XLA does not support FFTs over more than 3 dimensions + raise ValueError( + "%s only supports 1D, 2D, and 3D FFTs. " + "Got axes %s with input rank %s." % (full_name, orig_axes, a.ndim)) + + # XLA only supports FFTs over the innermost axes, so rearrange if necessary. 
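As an aside, the rearrangement described in the comment above can be pictured with a standalone sketch (a hypothetical helper that mirrors the moveaxis round trip in _fft_core, not part of the diff): the requested axes are moved to the innermost positions, lax.fft runs there, and the result is moved back.

import jax.numpy as jnp
from jax import lax
from jax.lib import xla_client

def _fft_innermost(a, orig_axes):
    # Move the transform axes to the innermost positions, as XLA requires.
    axes = tuple(range(a.ndim - len(orig_axes), a.ndim))
    a = jnp.moveaxis(a, orig_axes, axes)
    out = lax.fft(a, xla_client.FftType.FFT, tuple(a.shape[ax] for ax in axes))
    # Restore the caller's axis order on the way out.
    return jnp.moveaxis(out, axes, orig_axes)

x = jnp.ones((3, 4, 5), dtype=jnp.complex64)
y = _fft_innermost(x, (0, 1))  # should match jnp.fft.fftn(x, axes=(0, 1))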
+ if orig_axes is not None: + axes = tuple(range(a.ndim - len(axes), a.ndim)) + a = jnp.moveaxis(a, orig_axes, axes) + + if s is None: + if fft_type == xla_client.FftType.IRFFT: + s = [a.shape[axis] for axis in axes[:-1]] + if axes: + s += [max(0, 2 * (a.shape[axes[-1]] - 1))] + else: + s = [a.shape[axis] for axis in axes] + + transformed = lax.fft(a, fft_type, s) + + if orig_axes is not None: + transformed = jnp.moveaxis(transformed, axes, orig_axes) + return transformed + + +@_wraps(np.fft.fftn) +def fftn(a, s=None, axes=None, norm=None): + return _fft_core('fftn', xla_client.FftType.FFT, a, s, axes, norm) + + +@_wraps(np.fft.ifftn) +def ifftn(a, s=None, axes=None, norm=None): + return _fft_core('ifftn', xla_client.FftType.IFFT, a, s, axes, norm) + + +@_wraps(np.fft.rfftn) +def rfftn(a, s=None, axes=None, norm=None): + return _fft_core('rfftn', xla_client.FftType.RFFT, a, s, axes, norm) + + +@_wraps(np.fft.irfftn) +def irfftn(a, s=None, axes=None, norm=None): + return _fft_core('irfftn', xla_client.FftType.IRFFT, a, s, axes, norm) + + +def _axis_check_1d(func_name, axis): + full_name = "jax.numpy.fft." + func_name + if isinstance(axis, (list, tuple)): + raise ValueError( + "%s does not support multiple axes. Please use %sn. " + "Got axis = %r." % (full_name, full_name, axis) + ) + +def _fft_core_1d(func_name, fft_type, a, s, axis, norm): + _axis_check_1d(func_name, axis) + axes = None if axis is None else [axis] + return _fft_core(func_name, fft_type, a, s, axes, norm) + + +@_wraps(np.fft.fft) +def fft(a, n=None, axis=-1, norm=None): + return _fft_core_1d('fft', xla_client.FftType.FFT, a, s=n, axis=axis, + norm=norm) + +@_wraps(np.fft.ifft) +def ifft(a, n=None, axis=-1, norm=None): + return _fft_core_1d('ifft', xla_client.FftType.IFFT, a, s=n, axis=axis, + norm=norm) + +@_wraps(np.fft.rfft) +def rfft(a, n=None, axis=-1, norm=None): + return _fft_core_1d('rfft', xla_client.FftType.RFFT, a, s=n, axis=axis, + norm=norm) + +@_wraps(np.fft.irfft) +def irfft(a, n=None, axis=-1, norm=None): + return _fft_core_1d('irfft', xla_client.FftType.IRFFT, a, s=n, axis=axis, + norm=norm) + +@_wraps(np.fft.hfft) +def hfft(a, n=None, axis=-1, norm=None): + conj_a = jnp.conj(a) + _axis_check_1d('hfft', axis) + nn = (a.shape[axis] - 1) * 2 if n is None else n + return _fft_core_1d('hfft', xla_client.FftType.IRFFT, conj_a, s=n, axis=axis, + norm=norm) * nn + +@_wraps(np.fft.ihfft) +def ihfft(a, n=None, axis=-1, norm=None): + _axis_check_1d('ihfft', axis) + nn = a.shape[axis] if n is None else n + output = _fft_core_1d('ihfft', xla_client.FftType.RFFT, a, s=n, axis=axis, + norm=norm) + return jnp.conj(output) * (1 / nn) + + +def _fft_core_2d(func_name, fft_type, a, s, axes, norm): + full_name = "jax.numpy.fft." + func_name + if len(axes) != 2: + raise ValueError( + "%s only supports 2 axes. Got axes = %r." 
+ % (full_name, axes) + ) + return _fft_core(func_name, fft_type, a, s, axes, norm) + + +@_wraps(np.fft.fft2) +def fft2(a, s=None, axes=(-2,-1), norm=None): + return _fft_core_2d('fft2', xla_client.FftType.FFT, a, s=s, axes=axes, + norm=norm) + +@_wraps(np.fft.ifft2) +def ifft2(a, s=None, axes=(-2,-1), norm=None): + return _fft_core_2d('ifft2', xla_client.FftType.IFFT, a, s=s, axes=axes, + norm=norm) + +@_wraps(np.fft.rfft2) +def rfft2(a, s=None, axes=(-2,-1), norm=None): + return _fft_core_2d('rfft2', xla_client.FftType.RFFT, a, s=s, axes=axes, + norm=norm) + +@_wraps(np.fft.irfft2) +def irfft2(a, s=None, axes=(-2,-1), norm=None): + return _fft_core_2d('irfft2', xla_client.FftType.IRFFT, a, s=s, axes=axes, + norm=norm) + + +@_wraps(np.fft.fftfreq) +def fftfreq(n, d=1.0): + if isinstance(n, (list, tuple)): + raise ValueError( + "The n argument of jax.numpy.fft.fftfreq only takes an int. " + "Got n = %s." % list(n)) + + elif isinstance(d, (list, tuple)): + raise ValueError( + "The d argument of jax.numpy.fft.fftfreq only takes a single value. " + "Got d = %s." % list(d)) + + k = jnp.zeros(n) + if n % 2 == 0: + # k[0: n // 2 - 1] = jnp.arange(0, n // 2 - 1) + k = jaxops.index_update(k, jaxops.index[0: n // 2], jnp.arange(0, n // 2)) + + # k[n // 2:] = jnp.arange(-n // 2, -1) + k = jaxops.index_update(k, jaxops.index[n // 2:], jnp.arange(-n // 2, 0)) + + else: + # k[0: (n - 1) // 2] = jnp.arange(0, (n - 1) // 2) + k = jaxops.index_update(k, jaxops.index[0: (n - 1) // 2 + 1], + jnp.arange(0, (n - 1) // 2 + 1)) + + # k[(n - 1) // 2 + 1:] = jnp.arange(-(n - 1) // 2, -1) + k = jaxops.index_update(k, jaxops.index[(n - 1) // 2 + 1:], + jnp.arange(-(n - 1) // 2, 0)) + + return k / (d * n) + + +@_wraps(np.fft.rfftfreq) +def rfftfreq(n, d=1.0): + if isinstance(n, (list, tuple)): + raise ValueError( + "The n argument of jax.numpy.fft.rfftfreq only takes an int. " + "Got n = %s." % list(n)) + + elif isinstance(d, (list, tuple)): + raise ValueError( + "The d argument of jax.numpy.fft.rfftfreq only takes a single value. " + "Got d = %s." % list(d)) + + if n % 2 == 0: + k = jnp.arange(0, n // 2 + 1) + + else: + k = jnp.arange(0, (n - 1) // 2 + 1) + + return k / (d * n) + + +@_wraps(np.fft.fftshift) +def fftshift(x, axes=None): + x = jnp.asarray(x) + if axes is None: + axes = tuple(range(x.ndim)) + shift = [dim // 2 for dim in x.shape] + elif isinstance(axes, int): + shift = x.shape[axes] // 2 + else: + shift = [x.shape[ax] // 2 for ax in axes] + + return jnp.roll(x, shift, axes) + + +@_wraps(np.fft.ifftshift) +def ifftshift(x, axes=None): + x = jnp.asarray(x) + if axes is None: + axes = tuple(range(x.ndim)) + shift = [-(dim // 2) for dim in x.shape] + elif isinstance(axes, int): + shift = -(x.shape[axes] // 2) + else: + shift = [-(x.shape[ax] // 2) for ax in axes] + + return jnp.roll(x, shift, axes) diff --git a/jax/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py similarity index 99% rename from jax/numpy/lax_numpy.py rename to jax/_src/numpy/lax_numpy.py index 9c233d9fc..0ce577772 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -39,19 +39,19 @@ import opt_einsum import jax from jax import jit, custom_jvp from .vectorize import vectorize -from ._util import _wraps -from .. import core -from .. import dtypes -from ..abstract_arrays import UnshapedArray, ShapedArray, ConcreteArray, canonicalize_shape -from ..config import flags, config -from ..interpreters.xla import DeviceArray -from ..interpreters.masking import Poly -from .. 
import lax -from ..lax.lax import _device_put_raw -from .. import ops -from ..util import (partial, unzip2, prod as _prod, - subvals, safe_zip, canonicalize_axis as _canonicalize_axis) -from ..tree_util import tree_leaves, tree_flatten +from .util import _wraps +from jax import core +from jax import dtypes +from jax.abstract_arrays import UnshapedArray, ShapedArray, ConcreteArray, canonicalize_shape +from jax.config import flags, config +from jax.interpreters.xla import DeviceArray +from jax.interpreters.masking import Poly +from jax import lax +from jax.lax.lax import _device_put_raw +from jax import ops +from jax.util import (partial, unzip2, prod as _prod, + subvals, safe_zip, canonicalize_axis as _canonicalize_axis) +from jax.tree_util import tree_leaves, tree_flatten FLAGS = flags.FLAGS flags.DEFINE_enum( diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py new file mode 100644 index 000000000..e0b2a4d24 --- /dev/null +++ b/jax/_src/numpy/linalg.py @@ -0,0 +1,530 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from functools import partial + +import numpy as np +import textwrap +import operator +from typing import Tuple, Union, cast + +from jax import jit, vmap, custom_jvp +from jax import lax +from jax import ops +from jax import lax_linalg +from jax import dtypes +from .util import _wraps +from .vectorize import vectorize +from . import lax_numpy as jnp +from jax.util import canonicalize_axis +from jax.third_party.numpy.linalg import cond, multi_dot, tensorinv, tensorsolve # noqa: F401 + +_T = lambda x: jnp.swapaxes(x, -1, -2) +_H = lambda x: jnp.conjugate(jnp.swapaxes(x, -1, -2)) + + +def _promote_arg_dtypes(*args): + """Promotes `args` to a common inexact type.""" + def _to_inexact_type(type): + return type if jnp.issubdtype(type, jnp.inexact) else jnp.float_ + inexact_types = [_to_inexact_type(jnp._dtype(arg)) for arg in args] + dtype = dtypes.canonicalize_dtype(jnp.result_type(*inexact_types)) + args = [lax.convert_element_type(arg, dtype) for arg in args] + if len(args) == 1: + return args[0] + else: + return args + + +@_wraps(np.linalg.cholesky) +def cholesky(a): + a = _promote_arg_dtypes(jnp.asarray(a)) + return lax_linalg.cholesky(a) + + +@_wraps(np.linalg.svd) +def svd(a, full_matrices=True, compute_uv=True): + a = _promote_arg_dtypes(jnp.asarray(a)) + return lax_linalg.svd(a, full_matrices, compute_uv) + + +@_wraps(np.linalg.matrix_power) +def matrix_power(a, n): + a = _promote_arg_dtypes(jnp.asarray(a)) + + if a.ndim < 2: + raise TypeError("{}-dimensional array given. 
Array must be at least " + "two-dimensional".format(a.ndim)) + if a.shape[-2] != a.shape[-1]: + raise TypeError("Last 2 dimensions of the array must be square") + try: + n = operator.index(n) + except TypeError as err: + raise TypeError("exponent must be an integer, got {}".format(n)) from err + + if n == 0: + return jnp.broadcast_to(jnp.eye(a.shape[-2], dtype=a.dtype), a.shape) + elif n < 0: + a = inv(a) + n = np.abs(n) + + if n == 1: + return a + elif n == 2: + return a @ a + elif n == 3: + return (a @ a) @ a + + z = result = None + while n > 0: + z = a if z is None else (z @ z) + n, bit = divmod(n, 2) + if bit: + result = z if result is None else (result @ z) + + return result + + +@_wraps(np.linalg.matrix_rank) +def matrix_rank(M, tol=None): + M = _promote_arg_dtypes(jnp.asarray(M)) + if M.ndim > 2: + raise TypeError("array should have 2 or fewer dimensions") + if M.ndim < 2: + return jnp.any(M != 0).astype(jnp.int32) + S = svd(M, full_matrices=False, compute_uv=False) + if tol is None: + tol = S.max() * np.max(M.shape) * jnp.finfo(S.dtype).eps + return jnp.sum(S > tol) + + +@custom_jvp +@_wraps(np.linalg.slogdet) +@jit +def slogdet(a): + a = _promote_arg_dtypes(jnp.asarray(a)) + dtype = lax.dtype(a) + a_shape = jnp.shape(a) + if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]: + msg = "Argument to slogdet() must have shape [..., n, n], got {}" + raise ValueError(msg.format(a_shape)) + lu, pivot, _ = lax_linalg.lu(a) + diag = jnp.diagonal(lu, axis1=-2, axis2=-1) + is_zero = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1) + parity = jnp.count_nonzero(pivot != jnp.arange(a_shape[-1]), axis=-1) + if jnp.iscomplexobj(a): + sign = jnp.prod(diag / jnp.abs(diag), axis=-1) + else: + sign = jnp.array(1, dtype=dtype) + parity = parity + jnp.count_nonzero(diag < 0, axis=-1) + sign = jnp.where(is_zero, + jnp.array(0, dtype=dtype), + sign * jnp.array(-2 * (parity % 2) + 1, dtype=dtype)) + logdet = jnp.where( + is_zero, jnp.array(-jnp.inf, dtype=dtype), + jnp.sum(jnp.log(jnp.abs(diag)), axis=-1)) + return sign, jnp.real(logdet) + +@slogdet.defjvp +def _slogdet_jvp(primals, tangents): + x, = primals + g, = tangents + if jnp.issubdtype(jnp._dtype(x), jnp.complexfloating): + raise NotImplementedError # TODO(pfau): make this work for complex types + sign, ans = slogdet(x) + sign_dot, ans_dot = jnp.zeros_like(sign), jnp.trace(solve(x, g), axis1=-1, axis2=-2) + return (sign, ans), (sign_dot, ans_dot) + + +def _cofactor_solve(a, b): + """Equivalent to det(a)*solve(a, b) for nonsingular mat. + + Intermediate function used for jvp and vjp of det. + This function borrows heavily from jax.numpy.linalg.solve and + jax.numpy.linalg.slogdet to compute the gradient of the determinant + in a way that is well defined even for low rank matrices. + + This function handles two different cases: + * rank(a) == n or n-1 + * rank(a) < n-1 + + For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix. + Rather than computing det(a)*solve(a, b), which would return NaN, we work + directly with the LU decomposition. If a = p @ l @ u, then + det(a)*solve(a, b) = + prod(diag(u)) * u^-1 @ l^-1 @ p^-1 b = + prod(diag(u)) * triangular_solve(u, solve(p @ l, b)) + If a is rank n-1, then the lower right corner of u will be zero and the + triangular_solve will fail. + Let x = solve(p @ l, b) and y = det(a)*solve(a, b). 
+ Then y_{n} + x_{n} / u_{nn} * prod_{i=1...n}(u_{ii}) = + x_{n} * prod_{i=1...n-1}(u_{ii}) + So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1 + we can avoid the triangular_solve failing. + To correctly compute the rest of y_{i} for i != n, we simply multiply + x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1. + + For the second case, a check is done on the matrix to see if `solve` + returns NaN or Inf, and gives a matrix of zeros as a result, as the + gradient of the determinant of a matrix with rank less than n-1 is 0. + This will still return the correct value for rank n-1 matrices, as the check + is applied *after* the lower right corner of u has been updated. + + Args: + a: A square matrix or batch of matrices, possibly singular. + b: A matrix, or batch of matrices of the same dimension as a. + + Returns: + det(a) and cofactor(a)^T*b, aka adjugate(a)*b + """ + a = _promote_arg_dtypes(jnp.asarray(a)) + b = _promote_arg_dtypes(jnp.asarray(b)) + a_shape = jnp.shape(a) + b_shape = jnp.shape(b) + a_ndims = len(a_shape) + if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2] + and b_shape[-2:] == a_shape[-2:]): + msg = ("The arguments to _cofactor_solve must have shapes " + "a=[..., m, m] and b=[..., m, m]; got a={} and b={}") + raise ValueError(msg.format(a_shape, b_shape)) + if a_shape[-1] == 1: + return a[0, 0], b + # lu contains u in the upper triangular matrix and l in the strict lower + # triangular matrix. + # The diagonal of l is set to ones without loss of generality. + lu, pivots, permutation = lax_linalg.lu(a) + dtype = lax.dtype(a) + batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2]) + x = jnp.broadcast_to(b, batch_dims + b.shape[-2:]) + lu = jnp.broadcast_to(lu, batch_dims + lu.shape[-2:]) + # Compute (partial) determinant, ignoring last diagonal of LU + diag = jnp.diagonal(lu, axis1=-2, axis2=-1) + parity = jnp.count_nonzero(pivots != jnp.arange(a_shape[-1]), axis=-1) + sign = jnp.array(-2 * (parity % 2) + 1, dtype=dtype) + # partial_det[:, -1] contains the full determinant and + # partial_det[:, -2] contains det(u) / u_{nn}. 
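Seen from the user's side, _cofactor_solve is what keeps the gradient of jnp.linalg.det finite for singular inputs of rank n-1, where it equals the (rank-1) cofactor matrix. A small sanity check, with the expected value taken from the 2x2 adjugate formula:

import jax
import jax.numpy as jnp

a = jnp.array([[1., 2.],
               [2., 4.]])        # rank 1, so det(a) == 0
g = jax.grad(jnp.linalg.det)(a)
# Expected adj(a)^T == [[4., -2.], [-2., 1.]]; computing det(a) * inv(a) directly
# would instead produce NaNs here.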
+ partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None] + lu = ops.index_update(lu, ops.index[..., -1, -1], 1.0 / partial_det[..., -2]) + permutation = jnp.broadcast_to(permutation, batch_dims + (a_shape[-1],)) + iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1,))) + # filter out any matrices that are not full rank + d = jnp.ones(x.shape[:-1], x.dtype) + d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False) + d = jnp.any(jnp.logical_or(jnp.isnan(d), jnp.isinf(d)), axis=-1) + d = jnp.tile(d[..., None, None], d.ndim*(1,) + x.shape[-2:]) + x = jnp.where(d, jnp.zeros_like(x), x) # first filter + x = x[iotas[:-1] + (permutation, slice(None))] + x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=True, + unit_diagonal=True) + x = jnp.concatenate((x[..., :-1, :] * partial_det[..., -1, None, None], + x[..., -1:, :]), axis=-2) + x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False) + x = jnp.where(d, jnp.zeros_like(x), x) # second filter + + return partial_det[..., -1], x + + +@custom_jvp +@_wraps(np.linalg.det) +def det(a): + sign, logdet = slogdet(a) + return sign * jnp.exp(logdet) + + +@det.defjvp +def _det_jvp(primals, tangents): + x, = primals + g, = tangents + y, z = _cofactor_solve(x, g) + return y, jnp.trace(z, axis1=-1, axis2=-2) + + +@_wraps(np.linalg.eig) +def eig(a): + a = _promote_arg_dtypes(jnp.asarray(a)) + return lax_linalg.eig(a, compute_left_eigenvectors=False) + + +@_wraps(np.linalg.eigvals) +def eigvals(a): + return lax_linalg.eig(a, compute_left_eigenvectors=False, + compute_right_eigenvectors=False)[0] + + +@_wraps(np.linalg.eigh) +def eigh(a, UPLO=None, symmetrize_input=True): + if UPLO is None or UPLO == "L": + lower = True + elif UPLO == "U": + lower = False + else: + msg = "UPLO must be one of None, 'L', or 'U', got {}".format(UPLO) + raise ValueError(msg) + + a = _promote_arg_dtypes(jnp.asarray(a)) + v, w = lax_linalg.eigh(a, lower=lower, symmetrize_input=symmetrize_input) + return w, v + + +@_wraps(np.linalg.eigvalsh) +def eigvalsh(a, UPLO='L'): + w, _ = eigh(a, UPLO) + return w + + +@partial(custom_jvp, nondiff_argnums=(1,)) +@_wraps(np.linalg.pinv, lax_description=textwrap.dedent("""\ + It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the + default `rcond` is `1e-15`. Here the default is + `10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps`. + """)) +def pinv(a, rcond=None): + # Uses same algorithm as + # https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979 + a = jnp.conj(a) + if rcond is None: + max_rows_cols = max(a.shape[-2:]) + rcond = 10. * max_rows_cols * jnp.finfo(a.dtype).eps + rcond = jnp.asarray(rcond) + u, s, v = svd(a, full_matrices=False) + # Singular values less than or equal to ``rcond * largest_singular_value`` + # are set to zero. + cutoff = rcond[..., jnp.newaxis] * jnp.amax(s, axis=-1, keepdims=True) + s = jnp.where(s > cutoff, s, jnp.inf) + res = jnp.matmul(_T(v), jnp.divide(_T(u), s[..., jnp.newaxis])) + return lax.convert_element_type(res, a.dtype) + + +@pinv.defjvp +def _pinv_jvp(rcond, primals, tangents): + # The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems + # Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM + # Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432. 
+ # (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative) + a, = primals + a_dot, = tangents + p = pinv(a, rcond=rcond) + m, n = a.shape[-2:] + # TODO(phawkins): on TPU, we would need to opt into high precision here. + # TODO(phawkins): consider if this can be simplified in the Hermitian case. + p_dot = -p @ a_dot @ p + p_dot = p_dot + p @ _H(p) @ _H(a_dot) @ (jnp.eye(m, dtype=a.dtype) - a @ p) + p_dot = p_dot + (jnp.eye(n, dtype=a.dtype) - p @ a) @ _H(a_dot) @ _H(p) @ p + return p, p_dot + + +@_wraps(np.linalg.inv) +def inv(a): + if jnp.ndim(a) < 2 or a.shape[-1] != a.shape[-2]: + raise ValueError("Argument to inv must have shape [..., n, n], got {}." + .format(jnp.shape(a))) + return solve( + a, lax.broadcast(jnp.eye(a.shape[-1], dtype=lax.dtype(a)), a.shape[:-2])) + + +@partial(jit, static_argnums=(1, 2, 3)) +def _norm(x, ord, axis: Union[None, Tuple[int, ...], int], keepdims): + x = _promote_arg_dtypes(jnp.asarray(x)) + x_shape = jnp.shape(x) + ndim = len(x_shape) + + if axis is None: + # NumPy has an undocumented behavior that admits arbitrary rank inputs if + # `ord` is None: https://github.com/numpy/numpy/issues/14215 + if ord is None: + return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), keepdims=keepdims)) + axis = tuple(range(ndim)) + elif isinstance(axis, tuple): + axis = tuple(canonicalize_axis(x, ndim) for x in axis) + else: + axis = (canonicalize_axis(axis, ndim),) + + num_axes = len(axis) + if num_axes == 1: + if ord is None or ord == 2: + return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis, + keepdims=keepdims)) + elif ord == jnp.inf: + return jnp.amax(jnp.abs(x), axis=axis, keepdims=keepdims) + elif ord == -jnp.inf: + return jnp.amin(jnp.abs(x), axis=axis, keepdims=keepdims) + elif ord == 0: + return jnp.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype, + axis=axis, keepdims=keepdims) + elif ord == 1: + # Numpy has a special case for ord == 1 as an optimization. We don't + # really need the optimization (XLA could do it for us), but the Numpy + # code has slightly different type promotion semantics, so we need a + # special case too. + return jnp.sum(jnp.abs(x), axis=axis, keepdims=keepdims) + else: + abs_x = jnp.abs(x) + ord = lax._const(abs_x, ord) + out = jnp.sum(abs_x ** ord, axis=axis, keepdims=keepdims) + return jnp.power(out, 1. 
/ ord) + + elif num_axes == 2: + row_axis, col_axis = cast(Tuple[int, ...], axis) + if ord is None or ord in ('f', 'fro'): + return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis, + keepdims=keepdims)) + elif ord == 1: + if not keepdims and col_axis > row_axis: + col_axis -= 1 + return jnp.amax(jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims), + axis=col_axis, keepdims=keepdims) + elif ord == -1: + if not keepdims and col_axis > row_axis: + col_axis -= 1 + return jnp.amin(jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims), + axis=col_axis, keepdims=keepdims) + elif ord == jnp.inf: + if not keepdims and row_axis > col_axis: + row_axis -= 1 + return jnp.amax(jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims), + axis=row_axis, keepdims=keepdims) + elif ord == -jnp.inf: + if not keepdims and row_axis > col_axis: + row_axis -= 1 + return jnp.amin(jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims), + axis=row_axis, keepdims=keepdims) + elif ord in ('nuc', 2, -2): + x = jnp.moveaxis(x, axis, (-2, -1)) + if ord == 2: + reducer = jnp.amax + elif ord == -2: + reducer = jnp.amin + else: + reducer = jnp.sum + y = reducer(svd(x, compute_uv=False), axis=-1) + if keepdims: + result_shape = list(x_shape) + result_shape[axis[0]] = 1 + result_shape[axis[1]] = 1 + y = jnp.reshape(y, result_shape) + return y + else: + raise ValueError("Invalid order '{}' for matrix norm.".format(ord)) + else: + raise ValueError( + "Invalid axis values ({}) for jnp.linalg.norm.".format(axis)) + +@_wraps(np.linalg.norm) +def norm(x, ord=None, axis=None, keepdims=False): + return _norm(x, ord, axis, keepdims) + + +@_wraps(np.linalg.qr) +def qr(a, mode="reduced"): + if mode in ("reduced", "r", "full"): + full_matrices = False + elif mode == "complete": + full_matrices = True + else: + raise ValueError("Unsupported QR decomposition mode '{}'".format(mode)) + a = _promote_arg_dtypes(jnp.asarray(a)) + q, r = lax_linalg.qr(a, full_matrices) + if mode == "r": + return r + return q, r + + +def _check_solve_shapes(a, b): + if not (a.ndim >= 2 and a.shape[-1] == a.shape[-2] and b.ndim >= 1): + msg = ("The arguments to solve must have shapes a=[..., m, m] and " + "b=[..., m, k] or b=[..., m]; got a={} and b={}") + raise ValueError(msg.format(a.shape, b.shape)) + + +@partial(vectorize, signature='(n,m),(m)->(n)') +def _matvec_multiply(a, b): + return jnp.dot(a, b, precision=lax.Precision.HIGHEST) + + +@_wraps(np.linalg.solve) +@jit +def solve(a, b): + a, b = _promote_arg_dtypes(jnp.asarray(a), jnp.asarray(b)) + _check_solve_shapes(a, b) + + # With custom_linear_solve, we can reuse the same factorization when + # computing sensitivities. This is considerably faster. + lu, _, permutation = lax_linalg.lu(lax.stop_gradient(a)) + custom_solve = partial( + lax.custom_linear_solve, + lambda x: _matvec_multiply(a, x), + solve=lambda _, x: lax_linalg.lu_solve(lu, permutation, x, trans=0), + transpose_solve=lambda _, x: lax_linalg.lu_solve(lu, permutation, x, + trans=1)) + if a.ndim == b.ndim + 1: + # b.shape == [..., m] + return custom_solve(b) + else: + # b.shape == [..., m, k] + return vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b) + + +@_wraps(np.linalg.lstsq, lax_description=textwrap.dedent("""\ + It has two important differences: + + 1. In `numpy.linalg.lstsq`, the default `rcond` is `-1`, and warns that in the future + the default will be `None`. Here, the default rcond is `None`. + 2. In `np.linalg.lstsq` the returned residuals are empty for low-rank or over-determined + solutions. 
Here, the residuals are returned in all cases, to make the function + compatible with jit. The non-jit compatible numpy behavior can be recovered by + passing numpy_resid=True. + + The lstsq function does not currently have a custom JVP rule, so the gradient is + poorly behaved for some inputs, particularly for low-rank `a`. + """)) +def lstsq(a, b, rcond=None, *, numpy_resid=False): + # TODO: add lstsq to lax_linalg and implement this function via those wrappers. + # TODO: add custom jvp rule for more robust lstsq differentiation + a, b = _promote_arg_dtypes(a, b) + if a.shape[0] != b.shape[0]: + raise ValueError("Leading dimensions of input arrays must match") + b_orig_ndim = b.ndim + if b_orig_ndim == 1: + b = b[:, None] + if a.ndim != 2: + raise TypeError( + f"{a.ndim}-dimensional array given. Array must be two-dimensional") + if b.ndim != 2: + raise TypeError( + f"{b.ndim}-dimensional array given. Array must be one or two-dimensional") + m, n = a.shape + dtype = a.dtype + if rcond is None: + rcond = jnp.finfo(dtype).eps * max(n, m) + elif rcond < 0: + rcond = jnp.finfo(dtype).eps + u, s, vt = svd(a, full_matrices=False) + mask = s >= rcond * s[0] + rank = mask.sum() + safe_s = jnp.where(mask, s, 1) + s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis] + uTb = jnp.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST) + x = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST) + # Numpy returns empty residuals in some cases. To allow compilation, we + # default to returning full residuals in all cases. + if numpy_resid and (rank < n or m <= n): + resid = jnp.asarray([]) + else: + b_estimate = jnp.matmul(a, x, precision=lax.Precision.HIGHEST) + resid = norm(b - b_estimate, axis=0) ** 2 + if b_orig_ndim == 1: + x = x.ravel() + return x, resid, rank, s diff --git a/jax/numpy/polynomial.py b/jax/_src/numpy/polynomial.py similarity index 90% rename from jax/numpy/polynomial.py rename to jax/_src/numpy/polynomial.py index eb75fd356..9ceff4739 100644 --- a/jax/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -14,15 +14,13 @@ import numpy as np -from .. import lax +from jax import lax from . import lax_numpy as jnp from jax import jit -from ._util import _wraps -from .lax_numpy import _not_implemented +from .util import _wraps from .linalg import eigvals as _eigvals -from .. import ops as jaxops -from ..util import get_module_functions +from jax import ops as jaxops def _to_inexact_type(type): @@ -104,10 +102,3 @@ def roots(p, *, strip_zeros=True): # combine roots and zero roots roots = jnp.hstack((roots, jnp.zeros(trailing_zeros, p.dtype))) return roots - - -_NOT_IMPLEMENTED = [] -for name, func in get_module_functions(np.polynomial).items(): - if name not in globals(): - _NOT_IMPLEMENTED.append(name) - globals()[name] = _not_implemented(func) diff --git a/jax/numpy/_util.py b/jax/_src/numpy/util.py similarity index 100% rename from jax/numpy/_util.py rename to jax/_src/numpy/util.py diff --git a/jax/numpy/vectorize.py b/jax/_src/numpy/vectorize.py similarity index 99% rename from jax/numpy/vectorize.py rename to jax/_src/numpy/vectorize.py index 21bcb0766..400bcfb40 100644 --- a/jax/numpy/vectorize.py +++ b/jax/_src/numpy/vectorize.py @@ -11,14 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import functools import re from typing import Any, Callable, Dict, List, Tuple -from .. import api -from .. 
import lax +from jax import api +from jax import lax from . import lax_numpy as jnp -from ..util import safe_map as map, safe_zip as zip +from jax.util import safe_map as map, safe_zip as zip # See http://docs.scipy.org/doc/numpy/reference/c-api.generalized-ufuncs.html diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index 999d301d1..adde27c34 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -22,9 +22,9 @@ from jax import jit, vmap from jax import api from jax import lax from jax import lax_linalg -from jax.numpy._util import _wraps -from jax.numpy import lax_numpy as jnp -from jax.numpy import linalg as np_linalg +from jax._src.numpy.util import _wraps +from jax._src.numpy import lax_numpy as jnp +from jax._src.numpy import linalg as np_linalg _T = lambda x: jnp.swapaxes(x, -1, -2) diff --git a/jax/_src/scipy/ndimage.py b/jax/_src/scipy/ndimage.py index d17a9f7ec..255d35e38 100644 --- a/jax/_src/scipy/ndimage.py +++ b/jax/_src/scipy/ndimage.py @@ -22,8 +22,8 @@ import scipy.ndimage from jax import api from jax import lax -from jax.numpy import lax_numpy as jnp -from jax.numpy._util import _wraps +from jax._src.numpy import lax_numpy as jnp +from jax._src.numpy.util import _wraps from jax.util import safe_zip as zip diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py index 29fca948b..1f537d8dc 100644 --- a/jax/_src/scipy/signal.py +++ b/jax/_src/scipy/signal.py @@ -18,10 +18,10 @@ import warnings import numpy as np from jax import lax -from jax.numpy import lax_numpy as jnp -from jax.numpy import linalg -from jax.numpy.lax_numpy import _promote_dtypes_inexact -from jax.numpy._util import _wraps +from jax._src.numpy import lax_numpy as jnp +from jax._src.numpy import linalg +from jax._src.numpy.lax_numpy import _promote_dtypes_inexact +from jax._src.numpy.util import _wraps # Note: we do not re-use the code from jax.numpy.convolve here, because the handling diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index 9286fb7e5..f7cf43048 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -20,10 +20,10 @@ import scipy.special as osp_special from jax import lax from jax import api from jax.interpreters import ad -from jax.numpy import lax_numpy as jnp -from jax.numpy.lax_numpy import (asarray, _reduction_dims, _constant_like, - _promote_args_inexact) -from jax.numpy._util import _wraps +from jax._src.numpy import lax_numpy as jnp +from jax._src.numpy.lax_numpy import (asarray, _reduction_dims, _constant_like, + _promote_args_inexact) +from jax._src.numpy.util import _wraps @_wraps(osp_special.gammaln) diff --git a/jax/_src/scipy/stats/bernoulli.py b/jax/_src/scipy/stats/bernoulli.py index bc2305c0a..aa99b96fe 100644 --- a/jax/_src/scipy/stats/bernoulli.py +++ b/jax/_src/scipy/stats/bernoulli.py @@ -16,8 +16,8 @@ import scipy.stats as osp_stats from jax import lax -from jax.numpy import lax_numpy as jnp -from jax.numpy._util import _wraps +from jax._src.numpy import lax_numpy as jnp +from jax._src.numpy.util import _wraps from jax.scipy.special import xlogy, xlog1py diff --git a/jax/_src/scipy/stats/beta.py b/jax/_src/scipy/stats/beta.py index 5c658e886..64a9e59cb 100644 --- a/jax/_src/scipy/stats/beta.py +++ b/jax/_src/scipy/stats/beta.py @@ -15,9 +15,9 @@ import scipy.stats as osp_stats from jax import lax -from jax.numpy._util import _wraps -from jax.numpy.lax_numpy import (_promote_args_inexact, _constant_like, - where, inf, logical_or) +from jax._src.numpy.util import _wraps +from 
jax._src.numpy.lax_numpy import (_promote_args_inexact, _constant_like, + where, inf, logical_or) from jax.scipy.special import betaln diff --git a/jax/_src/scipy/stats/cauchy.py b/jax/_src/scipy/stats/cauchy.py index 158e16490..68624a422 100644 --- a/jax/_src/scipy/stats/cauchy.py +++ b/jax/_src/scipy/stats/cauchy.py @@ -17,8 +17,8 @@ import numpy as np import scipy.stats as osp_stats from jax import lax -from jax.numpy._util import _wraps -from jax.numpy.lax_numpy import _promote_args_inexact, _constant_like +from jax._src.numpy.util import _wraps +from jax._src.numpy.lax_numpy import _promote_args_inexact, _constant_like @_wraps(osp_stats.cauchy.logpdf, update_doc=False) diff --git a/jax/_src/scipy/stats/dirichlet.py b/jax/_src/scipy/stats/dirichlet.py index c1cdacc06..f58aaf26b 100644 --- a/jax/_src/scipy/stats/dirichlet.py +++ b/jax/_src/scipy/stats/dirichlet.py @@ -17,8 +17,8 @@ import numpy as np import scipy.stats as osp_stats from jax import lax -from jax.numpy import lax_numpy as jnp -from jax.numpy._util import _wraps +from jax._src.numpy import lax_numpy as jnp +from jax._src.numpy.util import _wraps from jax.scipy.special import gammaln, xlogy diff --git a/jax/_src/scipy/stats/expon.py b/jax/_src/scipy/stats/expon.py index 8dd5a344c..bf4a7afb4 100644 --- a/jax/_src/scipy/stats/expon.py +++ b/jax/_src/scipy/stats/expon.py @@ -15,8 +15,8 @@ import scipy.stats as osp_stats from jax import lax -from jax.numpy._util import _wraps -from jax.numpy.lax_numpy import _promote_args_inexact, where, inf +from jax._src.numpy.util import _wraps +from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf @_wraps(osp_stats.expon.logpdf, update_doc=False) diff --git a/jax/_src/scipy/stats/gamma.py b/jax/_src/scipy/stats/gamma.py index 172254c7b..5c2a33010 100644 --- a/jax/_src/scipy/stats/gamma.py +++ b/jax/_src/scipy/stats/gamma.py @@ -15,9 +15,9 @@ import scipy.stats as osp_stats from jax import lax -from jax.numpy._util import _wraps -from jax.numpy.lax_numpy import (_promote_args_inexact, _constant_like, - where, inf) +from jax._src.numpy.util import _wraps +from jax._src.numpy.lax_numpy import (_promote_args_inexact, _constant_like, + where, inf) from jax.scipy.special import gammaln diff --git a/jax/_src/scipy/stats/geom.py b/jax/_src/scipy/stats/geom.py index 00f164243..442800c4d 100644 --- a/jax/_src/scipy/stats/geom.py +++ b/jax/_src/scipy/stats/geom.py @@ -15,8 +15,8 @@ import scipy.stats as osp_stats from jax import lax -from jax.numpy import lax_numpy as jnp -from jax.numpy._util import _wraps +from jax._src.numpy import lax_numpy as jnp +from jax._src.numpy.util import _wraps from jax.scipy.special import xlog1py @_wraps(osp_stats.geom.logpmf, update_doc=False) diff --git a/jax/_src/scipy/stats/laplace.py b/jax/_src/scipy/stats/laplace.py index acfa5d4f0..4b9179752 100644 --- a/jax/_src/scipy/stats/laplace.py +++ b/jax/_src/scipy/stats/laplace.py @@ -15,8 +15,8 @@ import scipy.stats as osp_stats from jax import lax -from jax.numpy._util import _wraps -from jax.numpy.lax_numpy import _promote_args_inexact, _constant_like +from jax._src.numpy.util import _wraps +from jax._src.numpy.lax_numpy import _promote_args_inexact, _constant_like @_wraps(osp_stats.laplace.logpdf, update_doc=False) diff --git a/jax/_src/scipy/stats/logistic.py b/jax/_src/scipy/stats/logistic.py index 32ede9c5f..196f49c22 100644 --- a/jax/_src/scipy/stats/logistic.py +++ b/jax/_src/scipy/stats/logistic.py @@ -16,7 +16,7 @@ import scipy.stats as osp_stats from jax.scipy.special import expit, logit from 
jax import lax -from jax.numpy._util import _wraps +from jax._src.numpy.util import _wraps @_wraps(osp_stats.logistic.logpdf, update_doc=False) diff --git a/jax/_src/scipy/stats/multivariate_normal.py b/jax/_src/scipy/stats/multivariate_normal.py index 407dce28c..aebee2c3b 100644 --- a/jax/_src/scipy/stats/multivariate_normal.py +++ b/jax/_src/scipy/stats/multivariate_normal.py @@ -19,8 +19,8 @@ import scipy.stats as osp_stats from jax import lax from jax.lax_linalg import cholesky, triangular_solve from jax import numpy as jnp -from jax.numpy._util import _wraps -from jax.numpy.lax_numpy import _promote_dtypes_inexact +from jax._src.numpy.util import _wraps +from jax._src.numpy.lax_numpy import _promote_dtypes_inexact @_wraps(osp_stats.multivariate_normal.logpdf, update_doc=False) diff --git a/jax/_src/scipy/stats/norm.py b/jax/_src/scipy/stats/norm.py index d4b689c85..63a073128 100644 --- a/jax/_src/scipy/stats/norm.py +++ b/jax/_src/scipy/stats/norm.py @@ -17,9 +17,9 @@ import numpy as np import scipy.stats as osp_stats from jax import lax -from jax.numpy import lax_numpy as jnp -from jax.numpy._util import _wraps -from jax.numpy.lax_numpy import _promote_args_inexact, _constant_like +from jax._src.numpy import lax_numpy as jnp +from jax._src.numpy.util import _wraps +from jax._src.numpy.lax_numpy import _promote_args_inexact, _constant_like from jax.scipy import special @_wraps(osp_stats.norm.logpdf, update_doc=False) diff --git a/jax/_src/scipy/stats/pareto.py b/jax/_src/scipy/stats/pareto.py index ee20605c7..4bdd7e32e 100644 --- a/jax/_src/scipy/stats/pareto.py +++ b/jax/_src/scipy/stats/pareto.py @@ -16,8 +16,8 @@ import scipy.stats as osp_stats from jax import lax -from jax.numpy._util import _wraps -from jax.numpy.lax_numpy import _promote_args_inexact, _constant_like, inf, where +from jax._src.numpy.util import _wraps +from jax._src.numpy.lax_numpy import _promote_args_inexact, _constant_like, inf, where @_wraps(osp_stats.pareto.logpdf, update_doc=False) diff --git a/jax/_src/scipy/stats/poisson.py b/jax/_src/scipy/stats/poisson.py index efb5e59b1..a25ea88dc 100644 --- a/jax/_src/scipy/stats/poisson.py +++ b/jax/_src/scipy/stats/poisson.py @@ -16,8 +16,8 @@ import scipy.stats as osp_stats from jax import lax -from jax.numpy._util import _wraps -from jax.numpy import lax_numpy as jnp +from jax._src.numpy.util import _wraps +from jax._src.numpy import lax_numpy as jnp from jax.scipy.special import xlogy, gammaln diff --git a/jax/_src/scipy/stats/t.py b/jax/_src/scipy/stats/t.py index fda23f9a2..a56059f38 100644 --- a/jax/_src/scipy/stats/t.py +++ b/jax/_src/scipy/stats/t.py @@ -17,8 +17,8 @@ import numpy as np import scipy.stats as osp_stats from jax import lax -from jax.numpy._util import _wraps -from jax.numpy.lax_numpy import _promote_args_inexact, _constant_like +from jax._src.numpy.util import _wraps +from jax._src.numpy.lax_numpy import _promote_args_inexact, _constant_like @_wraps(osp_stats.t.logpdf, update_doc=False) diff --git a/jax/_src/scipy/stats/uniform.py b/jax/_src/scipy/stats/uniform.py index 23595a347..47cb1e1d9 100644 --- a/jax/_src/scipy/stats/uniform.py +++ b/jax/_src/scipy/stats/uniform.py @@ -16,8 +16,8 @@ import scipy.stats as osp_stats from jax import lax -from jax.numpy._util import _wraps -from jax.numpy.lax_numpy import _promote_args_inexact, where, inf, logical_or +from jax._src.numpy.util import _wraps +from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf, logical_or @_wraps(osp_stats.uniform.logpdf, update_doc=False) diff --git 
a/jax/lax/lax_parallel.py b/jax/lax/lax_parallel.py index 9fc82bd5b..79de3cd71 100644 --- a/jax/lax/lax_parallel.py +++ b/jax/lax/lax_parallel.py @@ -35,7 +35,7 @@ from jax.util import partial, unzip2, prod from jax.lib import xla_client as xc from jax.lib import xla_bridge as xb from jax.config import config -from jax.numpy import lax_numpy +from jax._src.numpy import lax_numpy xops = xc.ops diff --git a/jax/lax_linalg.py b/jax/lax_linalg.py index 06c732d90..73d8e00e6 100644 --- a/jax/lax_linalg.py +++ b/jax/lax_linalg.py @@ -15,8 +15,8 @@ import numpy as np -from jax.numpy import lax_numpy as jnp -from jax.numpy.vectorize import vectorize +from jax._src.numpy import lax_numpy as jnp +from jax._src.numpy.vectorize import vectorize from jax import ad_util from jax import api from jax import lax diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 68eb6339a..fe57765d1 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -16,9 +16,9 @@ from . import fft from . import linalg -from ..interpreters.xla import DeviceArray +from jax.interpreters.xla import DeviceArray -from .lax_numpy import ( +from jax._src.numpy.lax_numpy import ( ComplexWarning, NINF, NZERO, PZERO, abs, absolute, add, all, allclose, alltrue, amax, amin, angle, any, append, apply_along_axis, apply_over_axes, arange, arccos, arccosh, arcsin, @@ -63,16 +63,18 @@ from .lax_numpy import ( unpackbits, unravel_index, unsignedinteger, unwrap, vander, var, vdot, vsplit, vstack, where, zeros, zeros_like, _NOT_IMPLEMENTED) -from .polynomial import roots -from .vectorize import vectorize +from jax._src.numpy.polynomial import roots +from jax._src.numpy.vectorize import vectorize +# TODO(phawkins): remove this import after fixing users. +from jax._src.numpy import lax_numpy # Module initialization is encapsulated in a function to avoid accidental # namespace pollution. def _init(): import numpy as np - from . import lax_numpy - from .. import util + from jax._src.numpy import lax_numpy + from jax import util # Builds a set of all unimplemented NumPy functions. for name, func in util.get_module_functions(np).items(): if name not in globals(): diff --git a/jax/numpy/fft.py b/jax/numpy/fft.py index 00b656369..c4bad037b 100644 --- a/jax/numpy/fft.py +++ b/jax/numpy/fft.py @@ -1,4 +1,4 @@ -# Copyright 2018 Google LLC +# Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,251 +12,41 @@ # See the License for the specific language governing permissions and # limitations under the License. +# flake8: noqa: F401 -import numpy as np - -from .. import lax -from ..lib import xla_client -from ..util import get_module_functions -from .lax_numpy import _not_implemented -from ._util import _wraps -from . import lax_numpy as jnp -from .. import ops as jaxops - - -def _fft_core(func_name, fft_type, a, s, axes, norm): - # TODO(skye): implement padding/cropping based on 's'. - full_name = "jax.numpy.fft." + func_name - if s is not None: - raise NotImplementedError("%s only supports s=None, got %s" % (full_name, s)) - if norm is not None: - raise NotImplementedError("%s only supports norm=None, got %s" % (full_name, norm)) - if s is not None and axes is not None and len(s) != len(axes): - # Same error as numpy. 
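This and the following hunks replace the old jax/numpy modules with thin re-export shims over jax._src.numpy. For downstream code the move is intended to be invisible: the public jax.numpy.fft functions keep working and still agree with NumPy. A quick check, with tolerances loose enough for the default single-precision mode:

import numpy as np
import jax.numpy as jnp

x = np.linspace(0., 1., 12).reshape(3, 4).astype(np.complex64)
np.testing.assert_allclose(np.asarray(jnp.fft.fft2(x)), np.fft.fft2(x),
                           rtol=1e-4, atol=1e-4)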
- raise ValueError("Shape and axes have different lengths.") - - orig_axes = axes - if axes is None: - if s is None: - axes = range(a.ndim) - else: - axes = range(a.ndim - len(s), a.ndim) - - if len(axes) != len(set(axes)): - raise ValueError( - "%s does not support repeated axes. Got axes %s." % (full_name, axes)) - - if len(axes) > 3: - # XLA does not support FFTs over more than 3 dimensions - raise ValueError( - "%s only supports 1D, 2D, and 3D FFTs. " - "Got axes %s with input rank %s." % (full_name, orig_axes, a.ndim)) - - # XLA only supports FFTs over the innermost axes, so rearrange if necessary. - if orig_axes is not None: - axes = tuple(range(a.ndim - len(axes), a.ndim)) - a = jnp.moveaxis(a, orig_axes, axes) - - if s is None: - if fft_type == xla_client.FftType.IRFFT: - s = [a.shape[axis] for axis in axes[:-1]] - if axes: - s += [max(0, 2 * (a.shape[axes[-1]] - 1))] - else: - s = [a.shape[axis] for axis in axes] - - transformed = lax.fft(a, fft_type, s) - - if orig_axes is not None: - transformed = jnp.moveaxis(transformed, axes, orig_axes) - return transformed - - -@_wraps(np.fft.fftn) -def fftn(a, s=None, axes=None, norm=None): - return _fft_core('fftn', xla_client.FftType.FFT, a, s, axes, norm) - - -@_wraps(np.fft.ifftn) -def ifftn(a, s=None, axes=None, norm=None): - return _fft_core('ifftn', xla_client.FftType.IFFT, a, s, axes, norm) - - -@_wraps(np.fft.rfftn) -def rfftn(a, s=None, axes=None, norm=None): - return _fft_core('rfftn', xla_client.FftType.RFFT, a, s, axes, norm) - - -@_wraps(np.fft.irfftn) -def irfftn(a, s=None, axes=None, norm=None): - return _fft_core('irfftn', xla_client.FftType.IRFFT, a, s, axes, norm) - - -def _axis_check_1d(func_name, axis): - full_name = "jax.numpy.fft." + func_name - if isinstance(axis, (list, tuple)): - raise ValueError( - "%s does not support multiple axes. Please use %sn. " - "Got axis = %r." % (full_name, full_name, axis) - ) - -def _fft_core_1d(func_name, fft_type, a, s, axis, norm): - _axis_check_1d(func_name, axis) - axes = None if axis is None else [axis] - return _fft_core(func_name, fft_type, a, s, axes, norm) - - -@_wraps(np.fft.fft) -def fft(a, n=None, axis=-1, norm=None): - return _fft_core_1d('fft', xla_client.FftType.FFT, a, s=n, axis=axis, - norm=norm) - -@_wraps(np.fft.ifft) -def ifft(a, n=None, axis=-1, norm=None): - return _fft_core_1d('ifft', xla_client.FftType.IFFT, a, s=n, axis=axis, - norm=norm) - -@_wraps(np.fft.rfft) -def rfft(a, n=None, axis=-1, norm=None): - return _fft_core_1d('rfft', xla_client.FftType.RFFT, a, s=n, axis=axis, - norm=norm) - -@_wraps(np.fft.irfft) -def irfft(a, n=None, axis=-1, norm=None): - return _fft_core_1d('irfft', xla_client.FftType.IRFFT, a, s=n, axis=axis, - norm=norm) - -@_wraps(np.fft.hfft) -def hfft(a, n=None, axis=-1, norm=None): - conj_a = jnp.conj(a) - _axis_check_1d('hfft', axis) - nn = (a.shape[axis] - 1) * 2 if n is None else n - return _fft_core_1d('hfft', xla_client.FftType.IRFFT, conj_a, s=n, axis=axis, - norm=norm) * nn - -@_wraps(np.fft.ihfft) -def ihfft(a, n=None, axis=-1, norm=None): - _axis_check_1d('ihfft', axis) - nn = a.shape[axis] if n is None else n - output = _fft_core_1d('ihfft', xla_client.FftType.RFFT, a, s=n, axis=axis, - norm=norm) - return jnp.conj(output) * (1 / nn) - - -def _fft_core_2d(func_name, fft_type, a, s, axes, norm): - full_name = "jax.numpy.fft." + func_name - if len(axes) != 2: - raise ValueError( - "%s only supports 2 axes. Got axes = %r." 
- % (full_name, axes) - ) - return _fft_core(func_name, fft_type, a, s, axes, norm) - - -@_wraps(np.fft.fft2) -def fft2(a, s=None, axes=(-2,-1), norm=None): - return _fft_core_2d('fft2', xla_client.FftType.FFT, a, s=s, axes=axes, - norm=norm) - -@_wraps(np.fft.ifft2) -def ifft2(a, s=None, axes=(-2,-1), norm=None): - return _fft_core_2d('ifft2', xla_client.FftType.IFFT, a, s=s, axes=axes, - norm=norm) - -@_wraps(np.fft.rfft2) -def rfft2(a, s=None, axes=(-2,-1), norm=None): - return _fft_core_2d('rfft2', xla_client.FftType.RFFT, a, s=s, axes=axes, - norm=norm) - -@_wraps(np.fft.irfft2) -def irfft2(a, s=None, axes=(-2,-1), norm=None): - return _fft_core_2d('irfft2', xla_client.FftType.IRFFT, a, s=s, axes=axes, - norm=norm) - - -@_wraps(np.fft.fftfreq) -def fftfreq(n, d=1.0): - if isinstance(n, (list, tuple)): - raise ValueError( - "The n argument of jax.numpy.fft.fftfreq only takes an int. " - "Got n = %s." % list(n)) - - elif isinstance(d, (list, tuple)): - raise ValueError( - "The d argument of jax.numpy.fft.fftfreq only takes a single value. " - "Got d = %s." % list(d)) - - k = jnp.zeros(n) - if n % 2 == 0: - # k[0: n // 2 - 1] = jnp.arange(0, n // 2 - 1) - k = jaxops.index_update(k, jaxops.index[0: n // 2], jnp.arange(0, n // 2)) - - # k[n // 2:] = jnp.arange(-n // 2, -1) - k = jaxops.index_update(k, jaxops.index[n // 2:], jnp.arange(-n // 2, 0)) - - else: - # k[0: (n - 1) // 2] = jnp.arange(0, (n - 1) // 2) - k = jaxops.index_update(k, jaxops.index[0: (n - 1) // 2 + 1], - jnp.arange(0, (n - 1) // 2 + 1)) - - # k[(n - 1) // 2 + 1:] = jnp.arange(-(n - 1) // 2, -1) - k = jaxops.index_update(k, jaxops.index[(n - 1) // 2 + 1:], - jnp.arange(-(n - 1) // 2, 0)) - - return k / (d * n) - - -@_wraps(np.fft.rfftfreq) -def rfftfreq(n, d=1.0): - if isinstance(n, (list, tuple)): - raise ValueError( - "The n argument of jax.numpy.fft.rfftfreq only takes an int. " - "Got n = %s." % list(n)) - - elif isinstance(d, (list, tuple)): - raise ValueError( - "The d argument of jax.numpy.fft.rfftfreq only takes a single value. " - "Got d = %s." % list(d)) - - if n % 2 == 0: - k = jnp.arange(0, n // 2 + 1) - - else: - k = jnp.arange(0, (n - 1) // 2 + 1) - - return k / (d * n) - - -@_wraps(np.fft.fftshift) -def fftshift(x, axes=None): - x = jnp.asarray(x) - if axes is None: - axes = tuple(range(x.ndim)) - shift = [dim // 2 for dim in x.shape] - elif isinstance(axes, int): - shift = x.shape[axes] // 2 - else: - shift = [x.shape[ax] // 2 for ax in axes] - - return jnp.roll(x, shift, axes) - - -@_wraps(np.fft.ifftshift) -def ifftshift(x, axes=None): - x = jnp.asarray(x) - if axes is None: - axes = tuple(range(x.ndim)) - shift = [-(dim // 2) for dim in x.shape] - elif isinstance(axes, int): - shift = -(x.shape[axes] // 2) - else: - shift = [-(x.shape[ax] // 2) for ax in axes] - - return jnp.roll(x, shift, axes) - +from jax._src.numpy.fft import ( + ifft, + ifft2, + ifftn, + ifftshift, + ihfft, + irfft, + irfft2, + irfftn, + fft, + fft2, + fftfreq, + fftn, + fftshift, + hfft, + rfft, + rfft2, + rfftfreq, + rfftn, +) +# Module initialization is encapsulated in a function to avoid accidental +# namespace pollution. _NOT_IMPLEMENTED = [] -for name, func in get_module_functions(np.fft).items(): - if name not in globals(): - _NOT_IMPLEMENTED.append(name) - globals()[name] = _not_implemented(func) +def _init(): + import numpy as np + from jax._src.numpy import lax_numpy + from jax import util + # Builds a set of all unimplemented NumPy functions. 
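The loop that completes _init() (continued in the next hunk) leans on lax_numpy._not_implemented, which fills any numpy.fft name JAX does not provide with a stub that fails loudly when called. Roughly, as a paraphrase rather than the verbatim helper:

def _not_implemented(fun):
    def wrapped(*args, **kwargs):
        raise NotImplementedError(
            "Numpy function {} not yet implemented".format(fun))
    return wrapped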
+ for name, func in util.get_module_functions(np.fft).items(): + if name not in globals(): + _NOT_IMPLEMENTED.append(name) + globals()[name] = lax_numpy._not_implemented(func) + +_init() +del _init diff --git a/jax/numpy/linalg.py b/jax/numpy/linalg.py index 5a55f09de..92575619b 100644 --- a/jax/numpy/linalg.py +++ b/jax/numpy/linalg.py @@ -1,4 +1,4 @@ -# Copyright 2018 Google LLC +# Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,527 +12,43 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from functools import partial - -import numpy as np -import textwrap -import operator -from typing import Tuple, Union, cast - -from jax import jit, vmap, custom_jvp -from .. import lax -from .. import ops -from .. import lax_linalg -from .. import dtypes -from .lax_numpy import _not_implemented -from ._util import _wraps -from .vectorize import vectorize -from . import lax_numpy as jnp -from ..util import get_module_functions, canonicalize_axis -from ..third_party.numpy.linalg import cond, multi_dot, tensorinv, tensorsolve # noqa: F401 - -_T = lambda x: jnp.swapaxes(x, -1, -2) -_H = lambda x: jnp.conj(jnp.swapaxes(x, -1, -2)) - - -def _promote_arg_dtypes(*args): - """Promotes `args` to a common inexact type.""" - def _to_inexact_type(type): - return type if jnp.issubdtype(type, jnp.inexact) else jnp.float_ - inexact_types = [_to_inexact_type(jnp._dtype(arg)) for arg in args] - dtype = dtypes.canonicalize_dtype(jnp.result_type(*inexact_types)) - args = [lax.convert_element_type(arg, dtype) for arg in args] - if len(args) == 1: - return args[0] - else: - return args - - -@_wraps(np.linalg.cholesky) -def cholesky(a): - a = _promote_arg_dtypes(jnp.asarray(a)) - return lax_linalg.cholesky(a) - - -@_wraps(np.linalg.svd) -def svd(a, full_matrices=True, compute_uv=True): - a = _promote_arg_dtypes(jnp.asarray(a)) - return lax_linalg.svd(a, full_matrices, compute_uv) - - -@_wraps(np.linalg.matrix_power) -def matrix_power(a, n): - a = _promote_arg_dtypes(jnp.asarray(a)) - - if a.ndim < 2: - raise TypeError("{}-dimensional array given. 
Array must be at least " - "two-dimensional".format(a.ndim)) - if a.shape[-2] != a.shape[-1]: - raise TypeError("Last 2 dimensions of the array must be square") - try: - n = operator.index(n) - except TypeError as err: - raise TypeError("exponent must be an integer, got {}".format(n)) from err - - if n == 0: - return jnp.broadcast_to(jnp.eye(a.shape[-2], dtype=a.dtype), a.shape) - elif n < 0: - a = inv(a) - n = np.abs(n) - - if n == 1: - return a - elif n == 2: - return a @ a - elif n == 3: - return (a @ a) @ a - - z = result = None - while n > 0: - z = a if z is None else (z @ z) - n, bit = divmod(n, 2) - if bit: - result = z if result is None else (result @ z) - - return result - - -@_wraps(np.linalg.matrix_rank) -def matrix_rank(M, tol=None): - M = _promote_arg_dtypes(jnp.asarray(M)) - if M.ndim > 2: - raise TypeError("array should have 2 or fewer dimensions") - if M.ndim < 2: - return jnp.any(M != 0).astype(jnp.int32) - S = svd(M, full_matrices=False, compute_uv=False) - if tol is None: - tol = S.max() * np.max(M.shape) * jnp.finfo(S.dtype).eps - return jnp.sum(S > tol) - - -@custom_jvp -@_wraps(np.linalg.slogdet) -@jit -def slogdet(a): - a = _promote_arg_dtypes(jnp.asarray(a)) - dtype = lax.dtype(a) - a_shape = jnp.shape(a) - if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]: - msg = "Argument to slogdet() must have shape [..., n, n], got {}" - raise ValueError(msg.format(a_shape)) - lu, pivot, _ = lax_linalg.lu(a) - diag = jnp.diagonal(lu, axis1=-2, axis2=-1) - is_zero = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1) - parity = jnp.count_nonzero(pivot != jnp.arange(a_shape[-1]), axis=-1) - if jnp.iscomplexobj(a): - sign = jnp.prod(diag / jnp.abs(diag), axis=-1) - else: - sign = jnp.array(1, dtype=dtype) - parity = parity + jnp.count_nonzero(diag < 0, axis=-1) - sign = jnp.where(is_zero, - jnp.array(0, dtype=dtype), - sign * jnp.array(-2 * (parity % 2) + 1, dtype=dtype)) - logdet = jnp.where( - is_zero, jnp.array(-jnp.inf, dtype=dtype), - jnp.sum(jnp.log(jnp.abs(diag)), axis=-1)) - return sign, jnp.real(logdet) - -@slogdet.defjvp -def _slogdet_jvp(primals, tangents): - x, = primals - g, = tangents - if jnp.issubdtype(jnp._dtype(x), jnp.complexfloating): - raise NotImplementedError # TODO(pfau): make this work for complex types - sign, ans = slogdet(x) - sign_dot, ans_dot = jnp.zeros_like(sign), jnp.trace(solve(x, g), axis1=-1, axis2=-2) - return (sign, ans), (sign_dot, ans_dot) - - -def _cofactor_solve(a, b): - """Equivalent to det(a)*solve(a, b) for nonsingular mat. - - Intermediate function used for jvp and vjp of det. - This function borrows heavily from jax.numpy.linalg.solve and - jax.numpy.linalg.slogdet to compute the gradient of the determinant - in a way that is well defined even for low rank matrices. - - This function handles two different cases: - * rank(a) == n or n-1 - * rank(a) < n-1 - - For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix. - Rather than computing det(a)*solve(a, b), which would return NaN, we work - directly with the LU decomposition. If a = p @ l @ u, then - det(a)*solve(a, b) = - prod(diag(u)) * u^-1 @ l^-1 @ p^-1 b = - prod(diag(u)) * triangular_solve(u, solve(p @ l, b)) - If a is rank n-1, then the lower right corner of u will be zero and the - triangular_solve will fail. - Let x = solve(p @ l, b) and y = det(a)*solve(a, b). 
- Then y_{n} - x_{n} / u_{nn} * prod_{i=1...n}(u_{ii}) = - x_{n} * prod_{i=1...n-1}(u_{ii}) - So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1 - we can avoid the triangular_solve failing. - To correctly compute the rest of y_{i} for i != n, we simply multiply - x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1. - - For the second case, a check is done on the matrix to see if `solve` - returns NaN or Inf, and gives a matrix of zeros as a result, as the - gradient of the determinant of a matrix with rank less than n-1 is 0. - This will still return the correct value for rank n-1 matrices, as the check - is applied *after* the lower right corner of u has been updated. - - Args: - a: A square matrix or batch of matrices, possibly singular. - b: A matrix, or batch of matrices of the same dimension as a. - - Returns: - det(a) and cofactor(a)^T*b, aka adjugate(a)*b - """ - a = _promote_arg_dtypes(jnp.asarray(a)) - b = _promote_arg_dtypes(jnp.asarray(b)) - a_shape = jnp.shape(a) - b_shape = jnp.shape(b) - a_ndims = len(a_shape) - if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2] - and b_shape[-2:] == a_shape[-2:]): - msg = ("The arguments to _cofactor_solve must have shapes " - "a=[..., m, m] and b=[..., m, m]; got a={} and b={}") - raise ValueError(msg.format(a_shape, b_shape)) - if a_shape[-1] == 1: - return a[0, 0], b - # lu contains u in the upper triangular matrix and l in the strict lower - # triangular matrix. - # The diagonal of l is set to ones without loss of generality. - lu, pivots, permutation = lax_linalg.lu(a) - dtype = lax.dtype(a) - batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2]) - x = jnp.broadcast_to(b, batch_dims + b.shape[-2:]) - lu = jnp.broadcast_to(lu, batch_dims + lu.shape[-2:]) - # Compute (partial) determinant, ignoring last diagonal of LU - diag = jnp.diagonal(lu, axis1=-2, axis2=-1) - parity = jnp.count_nonzero(pivots != jnp.arange(a_shape[-1]), axis=-1) - sign = jnp.array(-2 * (parity % 2) + 1, dtype=dtype) - # partial_det[:, -1] contains the full determinant and - # partial_det[:, -2] contains det(u) / u_{nn}. 
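(This removal mirrors the slogdet/_cofactor_solve code added under jax/_src/numpy/linalg.py earlier in the diff; only its location changes.) The public behaviour is unchanged: slogdet still returns a (sign, log|det|) pair, the numerically safer route to determinants of large or ill-scaled matrices.

import jax.numpy as jnp

a = jnp.array([[2., 1.],
               [1., 3.]])
sign, logabsdet = jnp.linalg.slogdet(a)
# det(a) == sign * exp(logabsdet); for this matrix the determinant is 5.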
-
-
-@custom_jvp
-@_wraps(np.linalg.det)
-def det(a):
-  sign, logdet = slogdet(a)
-  return sign * jnp.exp(logdet)
-
-
-@det.defjvp
-def _det_jvp(primals, tangents):
-  x, = primals
-  g, = tangents
-  y, z = _cofactor_solve(x, g)
-  return y, jnp.trace(z, axis1=-1, axis2=-2)
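The JVP above is Jacobi's formula, d det(x)·g = trace(adjugate(x) @ g), routed through `_cofactor_solve` so it stays finite for rank n-1 inputs. A quick numerical check on a nonsingular input (illustrative sketch; the quoted values are approximate):

```python
import jax
import jax.numpy as jnp

x = jnp.array([[2.0, 0.0],
               [1.0, 3.0]])
g = jnp.array([[1.0, 0.5],
               [0.0, 1.0]])
primal, tangent = jax.jvp(jnp.linalg.det, (x,), (g,))
# For invertible x, trace(adjugate(x) @ g) == det(x) * trace(solve(x, g)).
expected = jnp.linalg.det(x) * jnp.trace(jnp.linalg.solve(x, g))
# primal ~ 6.0, tangent ~ expected ~ 4.5
```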
-
-
-@_wraps(np.linalg.eig)
-def eig(a):
-  a = _promote_arg_dtypes(jnp.asarray(a))
-  return lax_linalg.eig(a, compute_left_eigenvectors=False)
-
-
-@_wraps(np.linalg.eigvals)
-def eigvals(a):
-  return lax_linalg.eig(a, compute_left_eigenvectors=False,
-                        compute_right_eigenvectors=False)[0]
-
-
-@_wraps(np.linalg.eigh)
-def eigh(a, UPLO=None, symmetrize_input=True):
-  if UPLO is None or UPLO == "L":
-    lower = True
-  elif UPLO == "U":
-    lower = False
-  else:
-    msg = "UPLO must be one of None, 'L', or 'U', got {}".format(UPLO)
-    raise ValueError(msg)
-
-  a = _promote_arg_dtypes(jnp.asarray(a))
-  v, w = lax_linalg.eigh(a, lower=lower, symmetrize_input=symmetrize_input)
-  return w, v
-
-
-@_wraps(np.linalg.eigvalsh)
-def eigvalsh(a, UPLO='L'):
-  w, _ = eigh(a, UPLO)
-  return w
-
-
-@partial(custom_jvp, nondiff_argnums=(1,))
-@_wraps(np.linalg.pinv, lax_description=textwrap.dedent("""\
-    It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
-    default `rcond` is `1e-15`. Here the default is
-    `10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps`.
-    """))
-def pinv(a, rcond=None):
-  # Uses same algorithm as
-  # https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
-  a = jnp.conj(a)
-  if rcond is None:
-    max_rows_cols = max(a.shape[-2:])
-    rcond = 10. * max_rows_cols * jnp.finfo(a.dtype).eps
-  rcond = jnp.asarray(rcond)
-  u, s, v = svd(a, full_matrices=False)
-  # Singular values less than or equal to ``rcond * largest_singular_value``
-  # are set to zero.
-  cutoff = rcond[..., jnp.newaxis] * jnp.amax(s, axis=-1, keepdims=True)
-  s = jnp.where(s > cutoff, s, jnp.inf)
-  res = jnp.matmul(_T(v), jnp.divide(_T(u), s[..., jnp.newaxis]))
-  return lax.convert_element_type(res, a.dtype)
-
-
-@pinv.defjvp
-def _pinv_jvp(rcond, primals, tangents):
-  # The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems
-  # Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM
-  # Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432.
-  # (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative)
-  a, = primals
-  a_dot, = tangents
-  p = pinv(a, rcond=rcond)
-  m, n = a.shape[-2:]
-  # TODO(phawkins): on TPU, we would need to opt into high precision here.
-  # TODO(phawkins): consider if this can be simplified in the Hermitian case.
-  p_dot = -p @ a_dot @ p
-  p_dot = p_dot + p @ _H(p) @ _H(a_dot) @ (jnp.eye(m, dtype=a.dtype) - a @ p)
-  p_dot = p_dot + (jnp.eye(n, dtype=a.dtype) - p @ a) @ _H(a_dot) @ _H(p) @ p
-  return p, p_dot
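One practical consequence of the `rcond` default described in the pinv note above (a usage sketch under the default float32 configuration, not part of the patch): singular values below `rcond * s_max` are zeroed, so `jnp.linalg.pinv` and `numpy.linalg.pinv` can disagree on nearly rank-deficient inputs unless `rcond` is passed explicitly.

```python
import numpy as np
import jax.numpy as jnp

a = jnp.array([[1.0, 0.0],
               [0.0, 1e-7]])
# JAX's default cutoff is 10 * max(m, n) * eps of the dtype (~2.4e-6 for
# float32 here), so the 1e-7 singular value is treated as zero...
p_default = jnp.linalg.pinv(a)
# ...while passing NumPy's default rcond keeps it.
p_matched = jnp.linalg.pinv(a, rcond=1e-15)
p_numpy = np.linalg.pinv(np.asarray(a, dtype=np.float64))
```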
-
-
-@_wraps(np.linalg.inv)
-def inv(a):
-  if jnp.ndim(a) < 2 or a.shape[-1] != a.shape[-2]:
-    raise ValueError("Argument to inv must have shape [..., n, n], got {}."
-                     .format(jnp.shape(a)))
-  return solve(
-    a, lax.broadcast(jnp.eye(a.shape[-1], dtype=lax.dtype(a)), a.shape[:-2]))
-
-
-@partial(jit, static_argnums=(1, 2, 3))
-def _norm(x, ord, axis: Union[None, Tuple[int, ...], int], keepdims):
-  x = _promote_arg_dtypes(jnp.asarray(x))
-  x_shape = jnp.shape(x)
-  ndim = len(x_shape)
-
-  if axis is None:
-    # NumPy has an undocumented behavior that admits arbitrary rank inputs if
-    # `ord` is None: https://github.com/numpy/numpy/issues/14215
-    if ord is None:
-      return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), keepdims=keepdims))
-    axis = tuple(range(ndim))
-  elif isinstance(axis, tuple):
-    axis = tuple(canonicalize_axis(x, ndim) for x in axis)
-  else:
-    axis = (canonicalize_axis(axis, ndim),)
-
-  num_axes = len(axis)
-  if num_axes == 1:
-    if ord is None or ord == 2:
-      return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis,
-                              keepdims=keepdims))
-    elif ord == jnp.inf:
-      return jnp.amax(jnp.abs(x), axis=axis, keepdims=keepdims)
-    elif ord == -jnp.inf:
-      return jnp.amin(jnp.abs(x), axis=axis, keepdims=keepdims)
-    elif ord == 0:
-      return jnp.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype,
-                     axis=axis, keepdims=keepdims)
-    elif ord == 1:
-      # Numpy has a special case for ord == 1 as an optimization. We don't
-      # really need the optimization (XLA could do it for us), but the Numpy
-      # code has slightly different type promotion semantics, so we need a
-      # special case too.
-      return jnp.sum(jnp.abs(x), axis=axis, keepdims=keepdims)
-    else:
-      abs_x = jnp.abs(x)
-      ord = lax._const(abs_x, ord)
-      out = jnp.sum(abs_x ** ord, axis=axis, keepdims=keepdims)
-      return jnp.power(out, 1. / ord)
-
-  elif num_axes == 2:
-    row_axis, col_axis = cast(Tuple[int, ...], axis)
-    if ord is None or ord in ('f', 'fro'):
-      return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis,
-                              keepdims=keepdims))
-    elif ord == 1:
-      if not keepdims and col_axis > row_axis:
-        col_axis -= 1
-      return jnp.amax(jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims),
-                      axis=col_axis, keepdims=keepdims)
-    elif ord == -1:
-      if not keepdims and col_axis > row_axis:
-        col_axis -= 1
-      return jnp.amin(jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims),
-                      axis=col_axis, keepdims=keepdims)
-    elif ord == jnp.inf:
-      if not keepdims and row_axis > col_axis:
-        row_axis -= 1
-      return jnp.amax(jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims),
-                      axis=row_axis, keepdims=keepdims)
-    elif ord == -jnp.inf:
-      if not keepdims and row_axis > col_axis:
-        row_axis -= 1
-      return jnp.amin(jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims),
-                      axis=row_axis, keepdims=keepdims)
-    elif ord in ('nuc', 2, -2):
-      x = jnp.moveaxis(x, axis, (-2, -1))
-      if ord == 2:
-        reducer = jnp.amax
-      elif ord == -2:
-        reducer = jnp.amin
-      else:
-        reducer = jnp.sum
-      y = reducer(svd(x, compute_uv=False), axis=-1)
-      if keepdims:
-        result_shape = list(x_shape)
-        result_shape[axis[0]] = 1
-        result_shape[axis[1]] = 1
-        y = jnp.reshape(y, result_shape)
-      return y
-    else:
-      raise ValueError("Invalid order '{}' for matrix norm.".format(ord))
-  else:
-    raise ValueError(
-        "Invalid axis values ({}) for jnp.linalg.norm.".format(axis))
-
-@_wraps(np.linalg.norm)
-def norm(x, ord=None, axis=None, keepdims=False):
-  return _norm(x, ord, axis, keepdims)
-
-
-@_wraps(np.linalg.qr)
-def qr(a, mode="reduced"):
-  if mode in ("reduced", "r", "full"):
-    full_matrices = False
-  elif mode == "complete":
-    full_matrices = True
-  else:
-    raise ValueError("Unsupported QR decomposition mode '{}'".format(mode))
-  a = _promote_arg_dtypes(jnp.asarray(a))
-  q, r = lax_linalg.qr(a, full_matrices)
-  if mode == "r":
-    return r
-  return q, r
-
-
-def _check_solve_shapes(a, b):
-  if not (a.ndim >= 2 and a.shape[-1] == a.shape[-2] and b.ndim >= 1):
-    msg = ("The arguments to solve must have shapes a=[..., m, m] and "
-           "b=[..., m, k] or b=[..., m]; got a={} and b={}")
-    raise ValueError(msg.format(a.shape, b.shape))
-
-
-@partial(vectorize, signature='(n,m),(m)->(n)')
-def _matvec_multiply(a, b):
-  return jnp.dot(a, b, precision=lax.Precision.HIGHEST)
-
-
-@_wraps(np.linalg.solve)
-@jit
-def solve(a, b):
-  a, b = _promote_arg_dtypes(jnp.asarray(a), jnp.asarray(b))
-  _check_solve_shapes(a, b)
-
-  # With custom_linear_solve, we can reuse the same factorization when
-  # computing sensitivities. This is considerably faster.
-  lu, _, permutation = lax_linalg.lu(lax.stop_gradient(a))
-  custom_solve = partial(
-      lax.custom_linear_solve,
-      lambda x: _matvec_multiply(a, x),
-      solve=lambda _, x: lax_linalg.lu_solve(lu, permutation, x, trans=0),
-      transpose_solve=lambda _, x: lax_linalg.lu_solve(lu, permutation, x,
-                                                       trans=1))
-  if a.ndim == b.ndim + 1:
-    # b.shape == [..., m]
-    return custom_solve(b)
-  else:
-    # b.shape == [..., m, k]
-    return vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
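The branch on `a.ndim == b.ndim + 1` above is what lets `solve` accept either a single right-hand-side vector or a matrix of stacked columns; both call forms are sketched below (illustration only). Because the factorization is wrapped in `lax.custom_linear_solve`, the same LU factors are reused when differentiating through the solve.

```python
import jax.numpy as jnp

a = jnp.array([[3.0, 1.0],
               [1.0, 2.0]])
b_vec = jnp.array([9.0, 8.0])          # shape (2,)   -> solution shape (2,)
b_mat = jnp.array([[9.0, 1.0],
                   [8.0, 0.0]])        # shape (2, 2) -> solution shape (2, 2)

x_vec = jnp.linalg.solve(a, b_vec)     # ~[2.0, 3.0]
x_mat = jnp.linalg.solve(a, b_mat)
# a @ x_vec ~ b_vec and a @ x_mat ~ b_mat, up to floating-point error.
```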
-
-
-@_wraps(np.linalg.lstsq, lax_description=textwrap.dedent("""\
-    It has two important differences:
-
-    1. In `numpy.linalg.lstsq`, the default `rcond` is `-1`, and warns that in the future
-       the default will be `None`. Here, the default rcond is `None`.
-    2. In `np.linalg.lstsq` the returned residuals are empty for low-rank or over-determined
-       solutions. Here, the residuals are returned in all cases, to make the function
-       compatible with jit. The non-jit compatible numpy behavior can be recovered by
-       passing numpy_resid=True.
-
-    The lstsq function does not currently have a custom JVP rule, so the gradient is
-    poorly behaved for some inputs, particularly for low-rank `a`.
-    """))
-def lstsq(a, b, rcond=None, *, numpy_resid=False):
-  # TODO: add lstsq to lax_linalg and implement this function via those wrappers.
-  # TODO: add custom jvp rule for more robust lstsq differentiation
-  a, b = _promote_arg_dtypes(a, b)
-  if a.shape[0] != b.shape[0]:
-    raise ValueError("Leading dimensions of input arrays must match")
-  b_orig_ndim = b.ndim
-  if b_orig_ndim == 1:
-    b = b[:, None]
-  if a.ndim != 2:
-    raise TypeError(
-      f"{a.ndim}-dimensional array given. Array must be two-dimensional")
-  if b.ndim != 2:
-    raise TypeError(
-      f"{b.ndim}-dimensional array given. Array must be one or two-dimensional")
-  m, n = a.shape
-  dtype = a.dtype
-  if rcond is None:
-    rcond = jnp.finfo(dtype).eps * max(n, m)
-  elif rcond < 0:
-    rcond = jnp.finfo(dtype).eps
-  u, s, vt = svd(a, full_matrices=False)
-  mask = s >= rcond * s[0]
-  rank = mask.sum()
-  safe_s = jnp.where(mask, s, 1)
-  s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis]
-  uTb = jnp.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST)
-  x = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST)
-  # Numpy returns empty residuals in some cases. To allow compilation, we
-  # default to returning full residuals in all cases.
-  if numpy_resid and (rank < n or m <= n):
-    resid = jnp.asarray([])
-  else:
-    b_estimate = jnp.matmul(a, x, precision=lax.Precision.HIGHEST)
-    resid = norm(b - b_estimate, axis=0) ** 2
-  if b_orig_ndim == 1:
-    x = x.ravel()
-  return x, resid, rank, s
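A short usage sketch of the behaviour described in the lstsq note above (illustrative only; the fitted coefficients are approximate): residuals are always returned by default so the function can be jit-compiled, while `numpy_resid=True` restores NumPy's convention of an empty residual array for rank-deficient or under-determined systems.

```python
import jax.numpy as jnp

# Over-determined system: fit y ~ c0 + c1 * t by least squares.
t = jnp.array([0.0, 1.0, 2.0, 3.0])
y = jnp.array([1.0, 2.9, 5.1, 7.0])
A = jnp.stack([jnp.ones_like(t), t], axis=1)   # shape (4, 2)

coeffs, resid, rank, sv = jnp.linalg.lstsq(A, y)
# coeffs ~ [1.0, 2.0]; resid holds the squared residual norm; rank == 2.

# NumPy-style residual handling (empty for rank-deficient or wide systems);
# this path is not compatible with jit:
_, resid_np, _, _ = jnp.linalg.lstsq(A, y, numpy_resid=True)
```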
-
-
+# flake8: noqa: F401
+
+from jax._src.numpy.linalg import (
+  cholesky,
+  cond,
+  det,
+  eig,
+  eigh,
+  eigvals,
+  eigvalsh,
+  inv,
+  lstsq,
+  matrix_power,
+  matrix_rank,
+  multi_dot,
+  norm,
+  pinv,
+  qr,
+  slogdet,
+  solve,
+  svd,
+  tensorinv,
+  tensorsolve,
+)
+
+# Module initialization is encapsulated in a function to avoid accidental
+# namespace pollution.
 _NOT_IMPLEMENTED = []
-for name, func in get_module_functions(np.linalg).items():
-  if name not in globals():
-    _NOT_IMPLEMENTED.append(name)
-    globals()[name] = _not_implemented(func)
+def _init():
+  import numpy as np
+  from jax._src.numpy import lax_numpy
+  from jax import util
+  # Builds a set of all unimplemented NumPy functions.
+  for name, func in util.get_module_functions(np.linalg).items():
+    if name not in globals():
+      _NOT_IMPLEMENTED.append(name)
+      globals()[name] = lax_numpy._not_implemented(func)
+
+_init()
+del _init
diff --git a/jax/ops/scatter.py b/jax/ops/scatter.py
index d6dde5ed2..8f6b5def0 100644
--- a/jax/ops/scatter.py
+++ b/jax/ops/scatter.py
@@ -16,7 +16,7 @@ from .. import lax
-from ..numpy import lax_numpy as jnp
+from jax._src.numpy import lax_numpy as jnp
 from .. import util
diff --git a/jax/random.py b/jax/random.py
index 7571bf61d..e84893ae8 100644
--- a/jax/random.py
+++ b/jax/random.py
@@ -49,7 +49,7 @@ from . import lax
 from . import numpy as jnp
 from . import dtypes
 from .api import jit, vmap
-from .numpy.lax_numpy import _constant_like, asarray
+from jax._src.numpy.lax_numpy import _constant_like, asarray
 from jax.lib import xla_bridge
 from jax.lib import xla_client
 from jax.lib import cuda_prng
diff --git a/jax/third_party/numpy/linalg.py b/jax/third_party/numpy/linalg.py
index 9d1bc3bbd..a1dd33f4a 100644
--- a/jax/third_party/numpy/linalg.py
+++ b/jax/third_party/numpy/linalg.py
@@ -1,8 +1,8 @@
 import numpy as np
 
-from jax.numpy import lax_numpy as jnp
-from jax.numpy import linalg as la
-from jax.numpy._util import _wraps
+from jax._src.numpy import lax_numpy as jnp
+from jax._src.numpy import linalg as la
+from jax._src.numpy.util import _wraps
 
 def _isEmpty2d(arr):
diff --git a/tests/masking_test.py b/tests/masking_test.py
index 31c849ce5..ddb873f30 100644
--- a/tests/masking_test.py
+++ b/tests/masking_test.py
@@ -23,7 +23,7 @@ from jax import lax
 from jax import core
 from jax import test_util as jtu
 from jax.config import config
-from jax.numpy.lax_numpy import _polymorphic_slice_indices
+from jax._src.numpy.lax_numpy import _polymorphic_slice_indices
 from jax.util import safe_map, safe_zip
 from jax.tree_util import tree_flatten
diff --git a/tests/polynomial_test.py b/tests/polynomial_test.py
index fde0e51b3..858b5e5b3 100644
--- a/tests/polynomial_test.py
+++ b/tests/polynomial_test.py
@@ -34,12 +34,6 @@ all_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex
 
 class TestPolynomial(jtu.JaxTestCase):
 
-  def testNotImplemented(self):
-    for name in jnp.polynomial._NOT_IMPLEMENTED:
-      func = getattr(jnp.polynomial, name)
-      with self.assertRaises(NotImplementedError):
-        func()
-
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_dtype={}_leading={}_trailing={}".format(
           jtu.format_shape_dtype_string((length+leading+trailing,), dtype),