Move jax.numpy internals into jax._src.numpy.

Commit aa107cf1f4 (parent 9ea1311c7d)
Peter Hawkins, 2020-10-16 18:08:20 -04:00
36 changed files with 954 additions and 864 deletions


@ -0,0 +1,13 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

jax/_src/numpy/fft.py (new file, 253 lines)

@ -0,0 +1,253 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax import lax
from jax.lib import xla_client
from .util import _wraps
from . import lax_numpy as jnp
from jax import ops as jaxops
def _fft_core(func_name, fft_type, a, s, axes, norm):
# TODO(skye): implement padding/cropping based on 's'.
full_name = "jax.numpy.fft." + func_name
if s is not None:
raise NotImplementedError("%s only supports s=None, got %s" % (full_name, s))
if norm is not None:
raise NotImplementedError("%s only supports norm=None, got %s" % (full_name, norm))
if s is not None and axes is not None and len(s) != len(axes):
# Same error as numpy.
raise ValueError("Shape and axes have different lengths.")
orig_axes = axes
if axes is None:
if s is None:
axes = range(a.ndim)
else:
axes = range(a.ndim - len(s), a.ndim)
if len(axes) != len(set(axes)):
raise ValueError(
"%s does not support repeated axes. Got axes %s." % (full_name, axes))
if len(axes) > 3:
# XLA does not support FFTs over more than 3 dimensions
raise ValueError(
"%s only supports 1D, 2D, and 3D FFTs. "
"Got axes %s with input rank %s." % (full_name, orig_axes, a.ndim))
# XLA only supports FFTs over the innermost axes, so rearrange if necessary.
if orig_axes is not None:
axes = tuple(range(a.ndim - len(axes), a.ndim))
a = jnp.moveaxis(a, orig_axes, axes)
if s is None:
if fft_type == xla_client.FftType.IRFFT:
s = [a.shape[axis] for axis in axes[:-1]]
if axes:
s += [max(0, 2 * (a.shape[axes[-1]] - 1))]
else:
s = [a.shape[axis] for axis in axes]
transformed = lax.fft(a, fft_type, s)
if orig_axes is not None:
transformed = jnp.moveaxis(transformed, axes, orig_axes)
return transformed
@_wraps(np.fft.fftn)
def fftn(a, s=None, axes=None, norm=None):
return _fft_core('fftn', xla_client.FftType.FFT, a, s, axes, norm)
@_wraps(np.fft.ifftn)
def ifftn(a, s=None, axes=None, norm=None):
return _fft_core('ifftn', xla_client.FftType.IFFT, a, s, axes, norm)
@_wraps(np.fft.rfftn)
def rfftn(a, s=None, axes=None, norm=None):
return _fft_core('rfftn', xla_client.FftType.RFFT, a, s, axes, norm)
@_wraps(np.fft.irfftn)
def irfftn(a, s=None, axes=None, norm=None):
return _fft_core('irfftn', xla_client.FftType.IRFFT, a, s, axes, norm)
def _axis_check_1d(func_name, axis):
full_name = "jax.numpy.fft." + func_name
if isinstance(axis, (list, tuple)):
raise ValueError(
"%s does not support multiple axes. Please use %sn. "
"Got axis = %r." % (full_name, full_name, axis)
)
def _fft_core_1d(func_name, fft_type, a, s, axis, norm):
_axis_check_1d(func_name, axis)
axes = None if axis is None else [axis]
return _fft_core(func_name, fft_type, a, s, axes, norm)
@_wraps(np.fft.fft)
def fft(a, n=None, axis=-1, norm=None):
return _fft_core_1d('fft', xla_client.FftType.FFT, a, s=n, axis=axis,
norm=norm)
@_wraps(np.fft.ifft)
def ifft(a, n=None, axis=-1, norm=None):
return _fft_core_1d('ifft', xla_client.FftType.IFFT, a, s=n, axis=axis,
norm=norm)
@_wraps(np.fft.rfft)
def rfft(a, n=None, axis=-1, norm=None):
return _fft_core_1d('rfft', xla_client.FftType.RFFT, a, s=n, axis=axis,
norm=norm)
@_wraps(np.fft.irfft)
def irfft(a, n=None, axis=-1, norm=None):
return _fft_core_1d('irfft', xla_client.FftType.IRFFT, a, s=n, axis=axis,
norm=norm)
@_wraps(np.fft.hfft)
def hfft(a, n=None, axis=-1, norm=None):
conj_a = jnp.conj(a)
_axis_check_1d('hfft', axis)
nn = (a.shape[axis] - 1) * 2 if n is None else n
return _fft_core_1d('hfft', xla_client.FftType.IRFFT, conj_a, s=n, axis=axis,
norm=norm) * nn
@_wraps(np.fft.ihfft)
def ihfft(a, n=None, axis=-1, norm=None):
_axis_check_1d('ihfft', axis)
nn = a.shape[axis] if n is None else n
output = _fft_core_1d('ihfft', xla_client.FftType.RFFT, a, s=n, axis=axis,
norm=norm)
return jnp.conj(output) * (1 / nn)
def _fft_core_2d(func_name, fft_type, a, s, axes, norm):
full_name = "jax.numpy.fft." + func_name
if len(axes) != 2:
raise ValueError(
"%s only supports 2 axes. Got axes = %r."
% (full_name, axes)
)
return _fft_core(func_name, fft_type, a, s, axes, norm)
@_wraps(np.fft.fft2)
def fft2(a, s=None, axes=(-2,-1), norm=None):
return _fft_core_2d('fft2', xla_client.FftType.FFT, a, s=s, axes=axes,
norm=norm)
@_wraps(np.fft.ifft2)
def ifft2(a, s=None, axes=(-2,-1), norm=None):
return _fft_core_2d('ifft2', xla_client.FftType.IFFT, a, s=s, axes=axes,
norm=norm)
@_wraps(np.fft.rfft2)
def rfft2(a, s=None, axes=(-2,-1), norm=None):
return _fft_core_2d('rfft2', xla_client.FftType.RFFT, a, s=s, axes=axes,
norm=norm)
@_wraps(np.fft.irfft2)
def irfft2(a, s=None, axes=(-2,-1), norm=None):
return _fft_core_2d('irfft2', xla_client.FftType.IRFFT, a, s=s, axes=axes,
norm=norm)
@_wraps(np.fft.fftfreq)
def fftfreq(n, d=1.0):
if isinstance(n, (list, tuple)):
raise ValueError(
"The n argument of jax.numpy.fft.fftfreq only takes an int. "
"Got n = %s." % list(n))
elif isinstance(d, (list, tuple)):
raise ValueError(
"The d argument of jax.numpy.fft.fftfreq only takes a single value. "
"Got d = %s." % list(d))
k = jnp.zeros(n)
if n % 2 == 0:
# k[0: n // 2] = jnp.arange(0, n // 2)
k = jaxops.index_update(k, jaxops.index[0: n // 2], jnp.arange(0, n // 2))
# k[n // 2:] = jnp.arange(-n // 2, 0)
k = jaxops.index_update(k, jaxops.index[n // 2:], jnp.arange(-n // 2, 0))
else:
# k[0: (n - 1) // 2 + 1] = jnp.arange(0, (n - 1) // 2 + 1)
k = jaxops.index_update(k, jaxops.index[0: (n - 1) // 2 + 1],
jnp.arange(0, (n - 1) // 2 + 1))
# k[(n - 1) // 2 + 1:] = jnp.arange(-(n - 1) // 2, 0)
k = jaxops.index_update(k, jaxops.index[(n - 1) // 2 + 1:],
jnp.arange(-(n - 1) // 2, 0))
return k / (d * n)
@_wraps(np.fft.rfftfreq)
def rfftfreq(n, d=1.0):
if isinstance(n, (list, tuple)):
raise ValueError(
"The n argument of jax.numpy.fft.rfftfreq only takes an int. "
"Got n = %s." % list(n))
elif isinstance(d, (list, tuple)):
raise ValueError(
"The d argument of jax.numpy.fft.rfftfreq only takes a single value. "
"Got d = %s." % list(d))
if n % 2 == 0:
k = jnp.arange(0, n // 2 + 1)
else:
k = jnp.arange(0, (n - 1) // 2 + 1)
return k / (d * n)
@_wraps(np.fft.fftshift)
def fftshift(x, axes=None):
x = jnp.asarray(x)
if axes is None:
axes = tuple(range(x.ndim))
shift = [dim // 2 for dim in x.shape]
elif isinstance(axes, int):
shift = x.shape[axes] // 2
else:
shift = [x.shape[ax] // 2 for ax in axes]
return jnp.roll(x, shift, axes)
@_wraps(np.fft.ifftshift)
def ifftshift(x, axes=None):
x = jnp.asarray(x)
if axes is None:
axes = tuple(range(x.ndim))
shift = [-(dim // 2) for dim in x.shape]
elif isinstance(axes, int):
shift = -(x.shape[axes] // 2)
else:
shift = [-(x.shape[ax] // 2) for ax in axes]
return jnp.roll(x, shift, axes)
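A minimal usage sketch (not part of the diff) of the module above, assuming a working JAX install; it sticks to functions defined in this file and to the supported s=None/norm=None signatures:

import numpy as np
import jax.numpy as jnp
x = jnp.asarray(np.random.randn(8, 8))
y = jnp.fft.fftn(x, axes=(0, 1))      # s and norm must be left at None here
x_back = jnp.fft.ifftn(y, axes=(0, 1))
freqs = jnp.fft.fftfreq(8, d=0.5)     # sample frequencies for an 8-point transform
centered = jnp.fft.fftshift(freqs)    # zero-frequency bin moved to the center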



@ -39,19 +39,19 @@ import opt_einsum
import jax
from jax import jit, custom_jvp
from .vectorize import vectorize
from ._util import _wraps
from .. import core
from .. import dtypes
from ..abstract_arrays import UnshapedArray, ShapedArray, ConcreteArray, canonicalize_shape
from ..config import flags, config
from ..interpreters.xla import DeviceArray
from ..interpreters.masking import Poly
from .. import lax
from ..lax.lax import _device_put_raw
from .. import ops
from ..util import (partial, unzip2, prod as _prod,
subvals, safe_zip, canonicalize_axis as _canonicalize_axis)
from ..tree_util import tree_leaves, tree_flatten
from .util import _wraps
from jax import core
from jax import dtypes
from jax.abstract_arrays import UnshapedArray, ShapedArray, ConcreteArray, canonicalize_shape
from jax.config import flags, config
from jax.interpreters.xla import DeviceArray
from jax.interpreters.masking import Poly
from jax import lax
from jax.lax.lax import _device_put_raw
from jax import ops
from jax.util import (partial, unzip2, prod as _prod,
subvals, safe_zip, canonicalize_axis as _canonicalize_axis)
from jax.tree_util import tree_leaves, tree_flatten
FLAGS = flags.FLAGS
flags.DEFINE_enum(

jax/_src/numpy/linalg.py (new file, 530 lines)

@ -0,0 +1,530 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import numpy as np
import textwrap
import operator
from typing import Tuple, Union, cast
from jax import jit, vmap, custom_jvp
from jax import lax
from jax import ops
from jax import lax_linalg
from jax import dtypes
from .util import _wraps
from .vectorize import vectorize
from . import lax_numpy as jnp
from jax.util import canonicalize_axis
from jax.third_party.numpy.linalg import cond, multi_dot, tensorinv, tensorsolve # noqa: F401
_T = lambda x: jnp.swapaxes(x, -1, -2)
_H = lambda x: jnp.conjugate(jnp.swapaxes(x, -1, -2))
def _promote_arg_dtypes(*args):
"""Promotes `args` to a common inexact type."""
def _to_inexact_type(type):
return type if jnp.issubdtype(type, jnp.inexact) else jnp.float_
inexact_types = [_to_inexact_type(jnp._dtype(arg)) for arg in args]
dtype = dtypes.canonicalize_dtype(jnp.result_type(*inexact_types))
args = [lax.convert_element_type(arg, dtype) for arg in args]
if len(args) == 1:
return args[0]
else:
return args
@_wraps(np.linalg.cholesky)
def cholesky(a):
a = _promote_arg_dtypes(jnp.asarray(a))
return lax_linalg.cholesky(a)
@_wraps(np.linalg.svd)
def svd(a, full_matrices=True, compute_uv=True):
a = _promote_arg_dtypes(jnp.asarray(a))
return lax_linalg.svd(a, full_matrices, compute_uv)
@_wraps(np.linalg.matrix_power)
def matrix_power(a, n):
a = _promote_arg_dtypes(jnp.asarray(a))
if a.ndim < 2:
raise TypeError("{}-dimensional array given. Array must be at least "
"two-dimensional".format(a.ndim))
if a.shape[-2] != a.shape[-1]:
raise TypeError("Last 2 dimensions of the array must be square")
try:
n = operator.index(n)
except TypeError as err:
raise TypeError("exponent must be an integer, got {}".format(n)) from err
if n == 0:
return jnp.broadcast_to(jnp.eye(a.shape[-2], dtype=a.dtype), a.shape)
elif n < 0:
a = inv(a)
n = np.abs(n)
if n == 1:
return a
elif n == 2:
return a @ a
elif n == 3:
return (a @ a) @ a
z = result = None
while n > 0:
z = a if z is None else (z @ z)
n, bit = divmod(n, 2)
if bit:
result = z if result is None else (result @ z)
return result
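A brief worked check (editor's sketch, on a small invertible input): the loop above is exponentiation by squaring, so for n = 5 (binary 101) the result is built from a and a @ a @ a @ a using three matrix multiplications rather than four.

import jax.numpy as jnp
a = jnp.array([[2., 1.], [0., 1.]])
assert jnp.allclose(jnp.linalg.matrix_power(a, 5), a @ a @ a @ a @ a)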
@_wraps(np.linalg.matrix_rank)
def matrix_rank(M, tol=None):
M = _promote_arg_dtypes(jnp.asarray(M))
if M.ndim > 2:
raise TypeError("array should have 2 or fewer dimensions")
if M.ndim < 2:
return jnp.any(M != 0).astype(jnp.int32)
S = svd(M, full_matrices=False, compute_uv=False)
if tol is None:
tol = S.max() * np.max(M.shape) * jnp.finfo(S.dtype).eps
return jnp.sum(S > tol)
@custom_jvp
@_wraps(np.linalg.slogdet)
@jit
def slogdet(a):
a = _promote_arg_dtypes(jnp.asarray(a))
dtype = lax.dtype(a)
a_shape = jnp.shape(a)
if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
msg = "Argument to slogdet() must have shape [..., n, n], got {}"
raise ValueError(msg.format(a_shape))
lu, pivot, _ = lax_linalg.lu(a)
diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
is_zero = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1)
parity = jnp.count_nonzero(pivot != jnp.arange(a_shape[-1]), axis=-1)
if jnp.iscomplexobj(a):
sign = jnp.prod(diag / jnp.abs(diag), axis=-1)
else:
sign = jnp.array(1, dtype=dtype)
parity = parity + jnp.count_nonzero(diag < 0, axis=-1)
sign = jnp.where(is_zero,
jnp.array(0, dtype=dtype),
sign * jnp.array(-2 * (parity % 2) + 1, dtype=dtype))
logdet = jnp.where(
is_zero, jnp.array(-jnp.inf, dtype=dtype),
jnp.sum(jnp.log(jnp.abs(diag)), axis=-1))
return sign, jnp.real(logdet)
@slogdet.defjvp
def _slogdet_jvp(primals, tangents):
x, = primals
g, = tangents
if jnp.issubdtype(jnp._dtype(x), jnp.complexfloating):
raise NotImplementedError # TODO(pfau): make this work for complex types
sign, ans = slogdet(x)
sign_dot, ans_dot = jnp.zeros_like(sign), jnp.trace(solve(x, g), axis1=-1, axis2=-2)
return (sign, ans), (sign_dot, ans_dot)
def _cofactor_solve(a, b):
"""Equivalent to det(a)*solve(a, b) for nonsingular mat.
Intermediate function used for jvp and vjp of det.
This function borrows heavily from jax.numpy.linalg.solve and
jax.numpy.linalg.slogdet to compute the gradient of the determinant
in a way that is well defined even for low rank matrices.
This function handles two different cases:
* rank(a) == n or n-1
* rank(a) < n-1
For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix.
Rather than computing det(a)*solve(a, b), which would return NaN, we work
directly with the LU decomposition. If a = p @ l @ u, then
det(a)*solve(a, b) =
prod(diag(u)) * u^-1 @ l^-1 @ p^-1 b =
prod(diag(u)) * triangular_solve(u, solve(p @ l, b))
If a is rank n-1, then the lower right corner of u will be zero and the
triangular_solve will fail.
Let x = solve(p @ l, b) and y = det(a)*solve(a, b).
Then y_{n} =
x_{n} / u_{nn} * prod_{i=1...n}(u_{ii}) =
x_{n} * prod_{i=1...n-1}(u_{ii})
So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1
we can avoid the triangular_solve failing.
To correctly compute the rest of y_{i} for i != n, we simply multiply
x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1.
For the second case, a check is done on the matrix to see if `solve`
returns NaN or Inf, and gives a matrix of zeros as a result, as the
gradient of the determinant of a matrix with rank less than n-1 is 0.
This will still return the correct value for rank n-1 matrices, as the check
is applied *after* the lower right corner of u has been updated.
Args:
a: A square matrix or batch of matrices, possibly singular.
b: A matrix, or batch of matrices of the same dimension as a.
Returns:
det(a) and cofactor(a)^T*b, aka adjugate(a)*b
"""
a = _promote_arg_dtypes(jnp.asarray(a))
b = _promote_arg_dtypes(jnp.asarray(b))
a_shape = jnp.shape(a)
b_shape = jnp.shape(b)
a_ndims = len(a_shape)
if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
and b_shape[-2:] == a_shape[-2:]):
msg = ("The arguments to _cofactor_solve must have shapes "
"a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
raise ValueError(msg.format(a_shape, b_shape))
if a_shape[-1] == 1:
return a[0, 0], b
# lu contains u in the upper triangular matrix and l in the strict lower
# triangular matrix.
# The diagonal of l is set to ones without loss of generality.
lu, pivots, permutation = lax_linalg.lu(a)
dtype = lax.dtype(a)
batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
x = jnp.broadcast_to(b, batch_dims + b.shape[-2:])
lu = jnp.broadcast_to(lu, batch_dims + lu.shape[-2:])
# Compute (partial) determinant, ignoring last diagonal of LU
diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
parity = jnp.count_nonzero(pivots != jnp.arange(a_shape[-1]), axis=-1)
sign = jnp.array(-2 * (parity % 2) + 1, dtype=dtype)
# partial_det[:, -1] contains the full determinant and
# partial_det[:, -2] contains det(u) / u_{nn}.
partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None]
lu = ops.index_update(lu, ops.index[..., -1, -1], 1.0 / partial_det[..., -2])
permutation = jnp.broadcast_to(permutation, batch_dims + (a_shape[-1],))
iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1,)))
# filter out any matrices that are not full rank
d = jnp.ones(x.shape[:-1], x.dtype)
d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
d = jnp.any(jnp.logical_or(jnp.isnan(d), jnp.isinf(d)), axis=-1)
d = jnp.tile(d[..., None, None], d.ndim*(1,) + x.shape[-2:])
x = jnp.where(d, jnp.zeros_like(x), x) # first filter
x = x[iotas[:-1] + (permutation, slice(None))]
x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=True,
unit_diagonal=True)
x = jnp.concatenate((x[..., :-1, :] * partial_det[..., -1, None, None],
x[..., -1:, :]), axis=-2)
x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
x = jnp.where(d, jnp.zeros_like(x), x) # second filter
return partial_det[..., -1], x
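A quick numerical restatement (editor's sketch) of the docstring's nonsingular case, using only public functions; _cofactor_solve(a, b) should return this determinant and this product up to rounding:

import jax.numpy as jnp
a = jnp.array([[3., 1.], [1., 2.]])
b = jnp.eye(2)
det_a = jnp.linalg.det(a)
adjugate_times_b = det_a * jnp.linalg.solve(a, b)   # equals adjugate(a) @ b when a is nonsingular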
@custom_jvp
@_wraps(np.linalg.det)
def det(a):
sign, logdet = slogdet(a)
return sign * jnp.exp(logdet)
@det.defjvp
def _det_jvp(primals, tangents):
x, = primals
g, = tangents
y, z = _cofactor_solve(x, g)
return y, jnp.trace(z, axis1=-1, axis2=-2)
@_wraps(np.linalg.eig)
def eig(a):
a = _promote_arg_dtypes(jnp.asarray(a))
return lax_linalg.eig(a, compute_left_eigenvectors=False)
@_wraps(np.linalg.eigvals)
def eigvals(a):
return lax_linalg.eig(a, compute_left_eigenvectors=False,
compute_right_eigenvectors=False)[0]
@_wraps(np.linalg.eigh)
def eigh(a, UPLO=None, symmetrize_input=True):
if UPLO is None or UPLO == "L":
lower = True
elif UPLO == "U":
lower = False
else:
msg = "UPLO must be one of None, 'L', or 'U', got {}".format(UPLO)
raise ValueError(msg)
a = _promote_arg_dtypes(jnp.asarray(a))
v, w = lax_linalg.eigh(a, lower=lower, symmetrize_input=symmetrize_input)
return w, v
@_wraps(np.linalg.eigvalsh)
def eigvalsh(a, UPLO='L'):
w, _ = eigh(a, UPLO)
return w
@partial(custom_jvp, nondiff_argnums=(1,))
@_wraps(np.linalg.pinv, lax_description=textwrap.dedent("""\
It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
default `rcond` is `1e-15`. Here the default is
`10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps`.
"""))
def pinv(a, rcond=None):
# Uses same algorithm as
# https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
a = jnp.conj(a)
if rcond is None:
max_rows_cols = max(a.shape[-2:])
rcond = 10. * max_rows_cols * jnp.finfo(a.dtype).eps
rcond = jnp.asarray(rcond)
u, s, v = svd(a, full_matrices=False)
# Singular values less than or equal to ``rcond * largest_singular_value``
# are set to zero.
cutoff = rcond[..., jnp.newaxis] * jnp.amax(s, axis=-1, keepdims=True)
s = jnp.where(s > cutoff, s, jnp.inf)
res = jnp.matmul(_T(v), jnp.divide(_T(u), s[..., jnp.newaxis]))
return lax.convert_element_type(res, a.dtype)
@pinv.defjvp
def _pinv_jvp(rcond, primals, tangents):
# The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems
# Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM
# Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432.
# (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative)
a, = primals
a_dot, = tangents
p = pinv(a, rcond=rcond)
m, n = a.shape[-2:]
# TODO(phawkins): on TPU, we would need to opt into high precision here.
# TODO(phawkins): consider if this can be simplified in the Hermitian case.
p_dot = -p @ a_dot @ p
p_dot = p_dot + p @ _H(p) @ _H(a_dot) @ (jnp.eye(m, dtype=a.dtype) - a @ p)
p_dot = p_dot + (jnp.eye(n, dtype=a.dtype) - p @ a) @ _H(a_dot) @ _H(p) @ p
return p, p_dot
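A sanity-check sketch (editor's addition, assuming real inputs and the default rcond): the custom JVP above can be compared against a central finite difference.

import numpy as np
import jax
import jax.numpy as jnp
a = jnp.asarray(np.random.randn(3, 2))
da = jnp.asarray(np.random.randn(3, 2))
p, p_dot = jax.jvp(jnp.linalg.pinv, (a,), (da,))
eps = 1e-4
fd = (jnp.linalg.pinv(a + eps * da) - jnp.linalg.pinv(a - eps * da)) / (2 * eps)
# p_dot and fd should agree to a few decimal places in float32.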
@_wraps(np.linalg.inv)
def inv(a):
if jnp.ndim(a) < 2 or a.shape[-1] != a.shape[-2]:
raise ValueError("Argument to inv must have shape [..., n, n], got {}."
.format(jnp.shape(a)))
return solve(
a, lax.broadcast(jnp.eye(a.shape[-1], dtype=lax.dtype(a)), a.shape[:-2]))
@partial(jit, static_argnums=(1, 2, 3))
def _norm(x, ord, axis: Union[None, Tuple[int, ...], int], keepdims):
x = _promote_arg_dtypes(jnp.asarray(x))
x_shape = jnp.shape(x)
ndim = len(x_shape)
if axis is None:
# NumPy has an undocumented behavior that admits arbitrary rank inputs if
# `ord` is None: https://github.com/numpy/numpy/issues/14215
if ord is None:
return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), keepdims=keepdims))
axis = tuple(range(ndim))
elif isinstance(axis, tuple):
axis = tuple(canonicalize_axis(x, ndim) for x in axis)
else:
axis = (canonicalize_axis(axis, ndim),)
num_axes = len(axis)
if num_axes == 1:
if ord is None or ord == 2:
return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis,
keepdims=keepdims))
elif ord == jnp.inf:
return jnp.amax(jnp.abs(x), axis=axis, keepdims=keepdims)
elif ord == -jnp.inf:
return jnp.amin(jnp.abs(x), axis=axis, keepdims=keepdims)
elif ord == 0:
return jnp.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype,
axis=axis, keepdims=keepdims)
elif ord == 1:
# Numpy has a special case for ord == 1 as an optimization. We don't
# really need the optimization (XLA could do it for us), but the Numpy
# code has slightly different type promotion semantics, so we need a
# special case too.
return jnp.sum(jnp.abs(x), axis=axis, keepdims=keepdims)
else:
abs_x = jnp.abs(x)
ord = lax._const(abs_x, ord)
out = jnp.sum(abs_x ** ord, axis=axis, keepdims=keepdims)
return jnp.power(out, 1. / ord)
elif num_axes == 2:
row_axis, col_axis = cast(Tuple[int, ...], axis)
if ord is None or ord in ('f', 'fro'):
return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis,
keepdims=keepdims))
elif ord == 1:
if not keepdims and col_axis > row_axis:
col_axis -= 1
return jnp.amax(jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims),
axis=col_axis, keepdims=keepdims)
elif ord == -1:
if not keepdims and col_axis > row_axis:
col_axis -= 1
return jnp.amin(jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims),
axis=col_axis, keepdims=keepdims)
elif ord == jnp.inf:
if not keepdims and row_axis > col_axis:
row_axis -= 1
return jnp.amax(jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims),
axis=row_axis, keepdims=keepdims)
elif ord == -jnp.inf:
if not keepdims and row_axis > col_axis:
row_axis -= 1
return jnp.amin(jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims),
axis=row_axis, keepdims=keepdims)
elif ord in ('nuc', 2, -2):
x = jnp.moveaxis(x, axis, (-2, -1))
if ord == 2:
reducer = jnp.amax
elif ord == -2:
reducer = jnp.amin
else:
reducer = jnp.sum
y = reducer(svd(x, compute_uv=False), axis=-1)
if keepdims:
result_shape = list(x_shape)
result_shape[axis[0]] = 1
result_shape[axis[1]] = 1
y = jnp.reshape(y, result_shape)
return y
else:
raise ValueError("Invalid order '{}' for matrix norm.".format(ord))
else:
raise ValueError(
"Invalid axis values ({}) for jnp.linalg.norm.".format(axis))
@_wraps(np.linalg.norm)
def norm(x, ord=None, axis=None, keepdims=False):
return _norm(x, ord, axis, keepdims)
@_wraps(np.linalg.qr)
def qr(a, mode="reduced"):
if mode in ("reduced", "r", "full"):
full_matrices = False
elif mode == "complete":
full_matrices = True
else:
raise ValueError("Unsupported QR decomposition mode '{}'".format(mode))
a = _promote_arg_dtypes(jnp.asarray(a))
q, r = lax_linalg.qr(a, full_matrices)
if mode == "r":
return r
return q, r
def _check_solve_shapes(a, b):
if not (a.ndim >= 2 and a.shape[-1] == a.shape[-2] and b.ndim >= 1):
msg = ("The arguments to solve must have shapes a=[..., m, m] and "
"b=[..., m, k] or b=[..., m]; got a={} and b={}")
raise ValueError(msg.format(a.shape, b.shape))
@partial(vectorize, signature='(n,m),(m)->(n)')
def _matvec_multiply(a, b):
return jnp.dot(a, b, precision=lax.Precision.HIGHEST)
@_wraps(np.linalg.solve)
@jit
def solve(a, b):
a, b = _promote_arg_dtypes(jnp.asarray(a), jnp.asarray(b))
_check_solve_shapes(a, b)
# With custom_linear_solve, we can reuse the same factorization when
# computing sensitivities. This is considerably faster.
lu, _, permutation = lax_linalg.lu(lax.stop_gradient(a))
custom_solve = partial(
lax.custom_linear_solve,
lambda x: _matvec_multiply(a, x),
solve=lambda _, x: lax_linalg.lu_solve(lu, permutation, x, trans=0),
transpose_solve=lambda _, x: lax_linalg.lu_solve(lu, permutation, x,
trans=1))
if a.ndim == b.ndim + 1:
# b.shape == [..., m]
return custom_solve(b)
else:
# b.shape == [..., m, k]
return vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
@_wraps(np.linalg.lstsq, lax_description=textwrap.dedent("""\
It has two important differences:
1. In `numpy.linalg.lstsq`, the default `rcond` is `-1`, and warns that in the future
the default will be `None`. Here, the default rcond is `None`.
2. In `np.linalg.lstsq` the returned residuals are empty for low-rank or over-determined
solutions. Here, the residuals are returned in all cases, to make the function
compatible with jit. The non-jit compatible numpy behavior can be recovered by
passing numpy_resid=True.
The lstsq function does not currently have a custom JVP rule, so the gradient is
poorly behaved for some inputs, particularly for low-rank `a`.
"""))
def lstsq(a, b, rcond=None, *, numpy_resid=False):
# TODO: add lstsq to lax_linalg and implement this function via those wrappers.
# TODO: add custom jvp rule for more robust lstsq differentiation
a, b = _promote_arg_dtypes(a, b)
if a.shape[0] != b.shape[0]:
raise ValueError("Leading dimensions of input arrays must match")
b_orig_ndim = b.ndim
if b_orig_ndim == 1:
b = b[:, None]
if a.ndim != 2:
raise TypeError(
f"{a.ndim}-dimensional array given. Array must be two-dimensional")
if b.ndim != 2:
raise TypeError(
f"{b.ndim}-dimensional array given. Array must be one or two-dimensional")
m, n = a.shape
dtype = a.dtype
if rcond is None:
rcond = jnp.finfo(dtype).eps * max(n, m)
elif rcond < 0:
rcond = jnp.finfo(dtype).eps
u, s, vt = svd(a, full_matrices=False)
mask = s >= rcond * s[0]
rank = mask.sum()
safe_s = jnp.where(mask, s, 1)
s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis]
uTb = jnp.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST)
x = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST)
# Numpy returns empty residuals in some cases. To allow compilation, we
# default to returning full residuals in all cases.
if numpy_resid and (rank < n or m <= n):
resid = jnp.asarray([])
else:
b_estimate = jnp.matmul(a, x, precision=lax.Precision.HIGHEST)
resid = norm(b - b_estimate, axis=0) ** 2
if b_orig_ndim == 1:
x = x.ravel()
return x, resid, rank, s
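For orientation, a small end-to-end sketch (not part of the diff) exercising the functions above through the public jax.numpy.linalg namespace; it assumes a working JAX install:

import numpy as np
import jax.numpy as jnp
a = jnp.asarray(np.random.randn(4, 4))
sign, logabsdet = jnp.linalg.slogdet(a)            # det(a) == sign * exp(logabsdet)
x = jnp.linalg.solve(a, jnp.ones(4))               # LU-based solve; factorization reused for gradients
x_ls, resid, rank, s = jnp.linalg.lstsq(a, jnp.ones(4))   # residuals returned in all cases (see note above)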


@ -14,15 +14,13 @@
import numpy as np
from .. import lax
from jax import lax
from . import lax_numpy as jnp
from jax import jit
from ._util import _wraps
from .lax_numpy import _not_implemented
from .util import _wraps
from .linalg import eigvals as _eigvals
from .. import ops as jaxops
from ..util import get_module_functions
from jax import ops as jaxops
def _to_inexact_type(type):
@ -104,10 +102,3 @@ def roots(p, *, strip_zeros=True):
# combine roots and zero roots
roots = jnp.hstack((roots, jnp.zeros(trailing_zeros, p.dtype)))
return roots
_NOT_IMPLEMENTED = []
for name, func in get_module_functions(np.polynomial).items():
if name not in globals():
_NOT_IMPLEMENTED.append(name)
globals()[name] = _not_implemented(func)


@ -11,14 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import re
from typing import Any, Callable, Dict, List, Tuple
from .. import api
from .. import lax
from jax import api
from jax import lax
from . import lax_numpy as jnp
from ..util import safe_map as map, safe_zip as zip
from jax.util import safe_map as map, safe_zip as zip
# See http://docs.scipy.org/doc/numpy/reference/c-api.generalized-ufuncs.html


@ -22,9 +22,9 @@ from jax import jit, vmap
from jax import api
from jax import lax
from jax import lax_linalg
from jax.numpy._util import _wraps
from jax.numpy import lax_numpy as jnp
from jax.numpy import linalg as np_linalg
from jax._src.numpy.util import _wraps
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import linalg as np_linalg
_T = lambda x: jnp.swapaxes(x, -1, -2)


@ -22,8 +22,8 @@ import scipy.ndimage
from jax import api
from jax import lax
from jax.numpy import lax_numpy as jnp
from jax.numpy._util import _wraps
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.util import _wraps
from jax.util import safe_zip as zip


@ -18,10 +18,10 @@ import warnings
import numpy as np
from jax import lax
from jax.numpy import lax_numpy as jnp
from jax.numpy import linalg
from jax.numpy.lax_numpy import _promote_dtypes_inexact
from jax.numpy._util import _wraps
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import linalg
from jax._src.numpy.lax_numpy import _promote_dtypes_inexact
from jax._src.numpy.util import _wraps
# Note: we do not re-use the code from jax.numpy.convolve here, because the handling


@ -20,10 +20,10 @@ import scipy.special as osp_special
from jax import lax
from jax import api
from jax.interpreters import ad
from jax.numpy import lax_numpy as jnp
from jax.numpy.lax_numpy import (asarray, _reduction_dims, _constant_like,
_promote_args_inexact)
from jax.numpy._util import _wraps
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.lax_numpy import (asarray, _reduction_dims, _constant_like,
_promote_args_inexact)
from jax._src.numpy.util import _wraps
@_wraps(osp_special.gammaln)


@ -16,8 +16,8 @@
import scipy.stats as osp_stats
from jax import lax
from jax.numpy import lax_numpy as jnp
from jax.numpy._util import _wraps
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.util import _wraps
from jax.scipy.special import xlogy, xlog1py


@ -15,9 +15,9 @@
import scipy.stats as osp_stats
from jax import lax
from jax.numpy._util import _wraps
from jax.numpy.lax_numpy import (_promote_args_inexact, _constant_like,
where, inf, logical_or)
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import (_promote_args_inexact, _constant_like,
where, inf, logical_or)
from jax.scipy.special import betaln


@ -17,8 +17,8 @@ import numpy as np
import scipy.stats as osp_stats
from jax import lax
from jax.numpy._util import _wraps
from jax.numpy.lax_numpy import _promote_args_inexact, _constant_like
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact, _constant_like
@_wraps(osp_stats.cauchy.logpdf, update_doc=False)


@ -17,8 +17,8 @@ import numpy as np
import scipy.stats as osp_stats
from jax import lax
from jax.numpy import lax_numpy as jnp
from jax.numpy._util import _wraps
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.util import _wraps
from jax.scipy.special import gammaln, xlogy


@ -15,8 +15,8 @@
import scipy.stats as osp_stats
from jax import lax
from jax.numpy._util import _wraps
from jax.numpy.lax_numpy import _promote_args_inexact, where, inf
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf
@_wraps(osp_stats.expon.logpdf, update_doc=False)


@ -15,9 +15,9 @@
import scipy.stats as osp_stats
from jax import lax
from jax.numpy._util import _wraps
from jax.numpy.lax_numpy import (_promote_args_inexact, _constant_like,
where, inf)
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import (_promote_args_inexact, _constant_like,
where, inf)
from jax.scipy.special import gammaln


@ -15,8 +15,8 @@
import scipy.stats as osp_stats
from jax import lax
from jax.numpy import lax_numpy as jnp
from jax.numpy._util import _wraps
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.util import _wraps
from jax.scipy.special import xlog1py
@_wraps(osp_stats.geom.logpmf, update_doc=False)


@ -15,8 +15,8 @@
import scipy.stats as osp_stats
from jax import lax
from jax.numpy._util import _wraps
from jax.numpy.lax_numpy import _promote_args_inexact, _constant_like
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact, _constant_like
@_wraps(osp_stats.laplace.logpdf, update_doc=False)


@ -16,7 +16,7 @@ import scipy.stats as osp_stats
from jax.scipy.special import expit, logit
from jax import lax
from jax.numpy._util import _wraps
from jax._src.numpy.util import _wraps
@_wraps(osp_stats.logistic.logpdf, update_doc=False)


@ -19,8 +19,8 @@ import scipy.stats as osp_stats
from jax import lax
from jax.lax_linalg import cholesky, triangular_solve
from jax import numpy as jnp
from jax.numpy._util import _wraps
from jax.numpy.lax_numpy import _promote_dtypes_inexact
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_dtypes_inexact
@_wraps(osp_stats.multivariate_normal.logpdf, update_doc=False)


@ -17,9 +17,9 @@ import numpy as np
import scipy.stats as osp_stats
from jax import lax
from jax.numpy import lax_numpy as jnp
from jax.numpy._util import _wraps
from jax.numpy.lax_numpy import _promote_args_inexact, _constant_like
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact, _constant_like
from jax.scipy import special
@_wraps(osp_stats.norm.logpdf, update_doc=False)


@ -16,8 +16,8 @@
import scipy.stats as osp_stats
from jax import lax
from jax.numpy._util import _wraps
from jax.numpy.lax_numpy import _promote_args_inexact, _constant_like, inf, where
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact, _constant_like, inf, where
@_wraps(osp_stats.pareto.logpdf, update_doc=False)


@ -16,8 +16,8 @@
import scipy.stats as osp_stats
from jax import lax
from jax.numpy._util import _wraps
from jax.numpy import lax_numpy as jnp
from jax._src.numpy.util import _wraps
from jax._src.numpy import lax_numpy as jnp
from jax.scipy.special import xlogy, gammaln


@ -17,8 +17,8 @@ import numpy as np
import scipy.stats as osp_stats
from jax import lax
from jax.numpy._util import _wraps
from jax.numpy.lax_numpy import _promote_args_inexact, _constant_like
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact, _constant_like
@_wraps(osp_stats.t.logpdf, update_doc=False)


@ -16,8 +16,8 @@
import scipy.stats as osp_stats
from jax import lax
from jax.numpy._util import _wraps
from jax.numpy.lax_numpy import _promote_args_inexact, where, inf, logical_or
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf, logical_or
@_wraps(osp_stats.uniform.logpdf, update_doc=False)


@ -35,7 +35,7 @@ from jax.util import partial, unzip2, prod
from jax.lib import xla_client as xc
from jax.lib import xla_bridge as xb
from jax.config import config
from jax.numpy import lax_numpy
from jax._src.numpy import lax_numpy
xops = xc.ops


@ -15,8 +15,8 @@
import numpy as np
from jax.numpy import lax_numpy as jnp
from jax.numpy.vectorize import vectorize
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.vectorize import vectorize
from jax import ad_util
from jax import api
from jax import lax


@ -16,9 +16,9 @@
from . import fft
from . import linalg
from ..interpreters.xla import DeviceArray
from jax.interpreters.xla import DeviceArray
from .lax_numpy import (
from jax._src.numpy.lax_numpy import (
ComplexWarning, NINF, NZERO, PZERO, abs, absolute, add, all, allclose,
alltrue, amax, amin, angle, any, append,
apply_along_axis, apply_over_axes, arange, arccos, arccosh, arcsin,
@ -63,16 +63,18 @@ from .lax_numpy import (
unpackbits, unravel_index, unsignedinteger, unwrap, vander, var, vdot, vsplit,
vstack, where, zeros, zeros_like, _NOT_IMPLEMENTED)
from .polynomial import roots
from .vectorize import vectorize
from jax._src.numpy.polynomial import roots
from jax._src.numpy.vectorize import vectorize
# TODO(phawkins): remove this import after fixing users.
from jax._src.numpy import lax_numpy
# Module initialization is encapsulated in a function to avoid accidental
# namespace pollution.
def _init():
import numpy as np
from . import lax_numpy
from .. import util
from jax._src.numpy import lax_numpy
from jax import util
# Builds a set of all unimplemented NumPy functions.
for name, func in util.get_module_functions(np).items():
if name not in globals():


@ -1,4 +1,4 @@
# Copyright 2018 Google LLC
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,251 +12,41 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa: F401
import numpy as np
from .. import lax
from ..lib import xla_client
from ..util import get_module_functions
from .lax_numpy import _not_implemented
from ._util import _wraps
from . import lax_numpy as jnp
from .. import ops as jaxops
def _fft_core(func_name, fft_type, a, s, axes, norm):
# TODO(skye): implement padding/cropping based on 's'.
full_name = "jax.numpy.fft." + func_name
if s is not None:
raise NotImplementedError("%s only supports s=None, got %s" % (full_name, s))
if norm is not None:
raise NotImplementedError("%s only supports norm=None, got %s" % (full_name, norm))
if s is not None and axes is not None and len(s) != len(axes):
# Same error as numpy.
raise ValueError("Shape and axes have different lengths.")
orig_axes = axes
if axes is None:
if s is None:
axes = range(a.ndim)
else:
axes = range(a.ndim - len(s), a.ndim)
if len(axes) != len(set(axes)):
raise ValueError(
"%s does not support repeated axes. Got axes %s." % (full_name, axes))
if len(axes) > 3:
# XLA does not support FFTs over more than 3 dimensions
raise ValueError(
"%s only supports 1D, 2D, and 3D FFTs. "
"Got axes %s with input rank %s." % (full_name, orig_axes, a.ndim))
# XLA only supports FFTs over the innermost axes, so rearrange if necessary.
if orig_axes is not None:
axes = tuple(range(a.ndim - len(axes), a.ndim))
a = jnp.moveaxis(a, orig_axes, axes)
if s is None:
if fft_type == xla_client.FftType.IRFFT:
s = [a.shape[axis] for axis in axes[:-1]]
if axes:
s += [max(0, 2 * (a.shape[axes[-1]] - 1))]
else:
s = [a.shape[axis] for axis in axes]
transformed = lax.fft(a, fft_type, s)
if orig_axes is not None:
transformed = jnp.moveaxis(transformed, axes, orig_axes)
return transformed
@_wraps(np.fft.fftn)
def fftn(a, s=None, axes=None, norm=None):
return _fft_core('fftn', xla_client.FftType.FFT, a, s, axes, norm)
@_wraps(np.fft.ifftn)
def ifftn(a, s=None, axes=None, norm=None):
return _fft_core('ifftn', xla_client.FftType.IFFT, a, s, axes, norm)
@_wraps(np.fft.rfftn)
def rfftn(a, s=None, axes=None, norm=None):
return _fft_core('rfftn', xla_client.FftType.RFFT, a, s, axes, norm)
@_wraps(np.fft.irfftn)
def irfftn(a, s=None, axes=None, norm=None):
return _fft_core('irfftn', xla_client.FftType.IRFFT, a, s, axes, norm)
def _axis_check_1d(func_name, axis):
full_name = "jax.numpy.fft." + func_name
if isinstance(axis, (list, tuple)):
raise ValueError(
"%s does not support multiple axes. Please use %sn. "
"Got axis = %r." % (full_name, full_name, axis)
)
def _fft_core_1d(func_name, fft_type, a, s, axis, norm):
_axis_check_1d(func_name, axis)
axes = None if axis is None else [axis]
return _fft_core(func_name, fft_type, a, s, axes, norm)
@_wraps(np.fft.fft)
def fft(a, n=None, axis=-1, norm=None):
return _fft_core_1d('fft', xla_client.FftType.FFT, a, s=n, axis=axis,
norm=norm)
@_wraps(np.fft.ifft)
def ifft(a, n=None, axis=-1, norm=None):
return _fft_core_1d('ifft', xla_client.FftType.IFFT, a, s=n, axis=axis,
norm=norm)
@_wraps(np.fft.rfft)
def rfft(a, n=None, axis=-1, norm=None):
return _fft_core_1d('rfft', xla_client.FftType.RFFT, a, s=n, axis=axis,
norm=norm)
@_wraps(np.fft.irfft)
def irfft(a, n=None, axis=-1, norm=None):
return _fft_core_1d('irfft', xla_client.FftType.IRFFT, a, s=n, axis=axis,
norm=norm)
@_wraps(np.fft.hfft)
def hfft(a, n=None, axis=-1, norm=None):
conj_a = jnp.conj(a)
_axis_check_1d('hfft', axis)
nn = (a.shape[axis] - 1) * 2 if n is None else n
return _fft_core_1d('hfft', xla_client.FftType.IRFFT, conj_a, s=n, axis=axis,
norm=norm) * nn
@_wraps(np.fft.ihfft)
def ihfft(a, n=None, axis=-1, norm=None):
_axis_check_1d('ihfft', axis)
nn = a.shape[axis] if n is None else n
output = _fft_core_1d('ihfft', xla_client.FftType.RFFT, a, s=n, axis=axis,
norm=norm)
return jnp.conj(output) * (1 / nn)
def _fft_core_2d(func_name, fft_type, a, s, axes, norm):
full_name = "jax.numpy.fft." + func_name
if len(axes) != 2:
raise ValueError(
"%s only supports 2 axes. Got axes = %r."
% (full_name, axes)
)
return _fft_core(func_name, fft_type, a, s, axes, norm)
@_wraps(np.fft.fft2)
def fft2(a, s=None, axes=(-2,-1), norm=None):
return _fft_core_2d('fft2', xla_client.FftType.FFT, a, s=s, axes=axes,
norm=norm)
@_wraps(np.fft.ifft2)
def ifft2(a, s=None, axes=(-2,-1), norm=None):
return _fft_core_2d('ifft2', xla_client.FftType.IFFT, a, s=s, axes=axes,
norm=norm)
@_wraps(np.fft.rfft2)
def rfft2(a, s=None, axes=(-2,-1), norm=None):
return _fft_core_2d('rfft2', xla_client.FftType.RFFT, a, s=s, axes=axes,
norm=norm)
@_wraps(np.fft.irfft2)
def irfft2(a, s=None, axes=(-2,-1), norm=None):
return _fft_core_2d('irfft2', xla_client.FftType.IRFFT, a, s=s, axes=axes,
norm=norm)
@_wraps(np.fft.fftfreq)
def fftfreq(n, d=1.0):
if isinstance(n, (list, tuple)):
raise ValueError(
"The n argument of jax.numpy.fft.fftfreq only takes an int. "
"Got n = %s." % list(n))
elif isinstance(d, (list, tuple)):
raise ValueError(
"The d argument of jax.numpy.fft.fftfreq only takes a single value. "
"Got d = %s." % list(d))
k = jnp.zeros(n)
if n % 2 == 0:
# k[0: n // 2 - 1] = jnp.arange(0, n // 2 - 1)
k = jaxops.index_update(k, jaxops.index[0: n // 2], jnp.arange(0, n // 2))
# k[n // 2:] = jnp.arange(-n // 2, -1)
k = jaxops.index_update(k, jaxops.index[n // 2:], jnp.arange(-n // 2, 0))
else:
# k[0: (n - 1) // 2] = jnp.arange(0, (n - 1) // 2)
k = jaxops.index_update(k, jaxops.index[0: (n - 1) // 2 + 1],
jnp.arange(0, (n - 1) // 2 + 1))
# k[(n - 1) // 2 + 1:] = jnp.arange(-(n - 1) // 2, -1)
k = jaxops.index_update(k, jaxops.index[(n - 1) // 2 + 1:],
jnp.arange(-(n - 1) // 2, 0))
return k / (d * n)
@_wraps(np.fft.rfftfreq)
def rfftfreq(n, d=1.0):
if isinstance(n, (list, tuple)):
raise ValueError(
"The n argument of jax.numpy.fft.rfftfreq only takes an int. "
"Got n = %s." % list(n))
elif isinstance(d, (list, tuple)):
raise ValueError(
"The d argument of jax.numpy.fft.rfftfreq only takes a single value. "
"Got d = %s." % list(d))
if n % 2 == 0:
k = jnp.arange(0, n // 2 + 1)
else:
k = jnp.arange(0, (n - 1) // 2 + 1)
return k / (d * n)
@_wraps(np.fft.fftshift)
def fftshift(x, axes=None):
x = jnp.asarray(x)
if axes is None:
axes = tuple(range(x.ndim))
shift = [dim // 2 for dim in x.shape]
elif isinstance(axes, int):
shift = x.shape[axes] // 2
else:
shift = [x.shape[ax] // 2 for ax in axes]
return jnp.roll(x, shift, axes)
@_wraps(np.fft.ifftshift)
def ifftshift(x, axes=None):
x = jnp.asarray(x)
if axes is None:
axes = tuple(range(x.ndim))
shift = [-(dim // 2) for dim in x.shape]
elif isinstance(axes, int):
shift = -(x.shape[axes] // 2)
else:
shift = [-(x.shape[ax] // 2) for ax in axes]
return jnp.roll(x, shift, axes)
from jax._src.numpy.fft import (
ifft,
ifft2,
ifftn,
ifftshift,
ihfft,
irfft,
irfft2,
irfftn,
fft,
fft2,
fftfreq,
fftn,
fftshift,
hfft,
rfft,
rfft2,
rfftfreq,
rfftn,
)
# Module initialization is encapsulated in a function to avoid accidental
# namespace pollution.
_NOT_IMPLEMENTED = []
for name, func in get_module_functions(np.fft).items():
if name not in globals():
_NOT_IMPLEMENTED.append(name)
globals()[name] = _not_implemented(func)
def _init():
import numpy as np
from jax._src.numpy import lax_numpy
from jax import util
# Builds a set of all unimplemented NumPy functions.
for name, func in util.get_module_functions(np.fft).items():
if name not in globals():
_NOT_IMPLEMENTED.append(name)
globals()[name] = lax_numpy._not_implemented(func)
_init()
del _init
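The _init hook above gives jax.numpy.fft the same surface as numpy.fft: any numpy.fft function missing from this module is exposed as a stub that raises when called, while the implemented functions remain plain re-exports of jax._src.numpy.fft, so `from jax.numpy.fft import fftn` keeps working after the move. A simplified sketch of the stub mechanism (editor's addition; `_stub` is a hypothetical stand-in for lax_numpy._not_implemented):

def _stub(fun):
  def wrapped(*args, **kwargs):
    raise NotImplementedError(f"numpy function {fun.__name__} is not yet implemented in jax")
  return wrapped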


@ -1,4 +1,4 @@
# Copyright 2018 Google LLC
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,527 +12,43 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import numpy as np
import textwrap
import operator
from typing import Tuple, Union, cast
from jax import jit, vmap, custom_jvp
from .. import lax
from .. import ops
from .. import lax_linalg
from .. import dtypes
from .lax_numpy import _not_implemented
from ._util import _wraps
from .vectorize import vectorize
from . import lax_numpy as jnp
from ..util import get_module_functions, canonicalize_axis
from ..third_party.numpy.linalg import cond, multi_dot, tensorinv, tensorsolve # noqa: F401
_T = lambda x: jnp.swapaxes(x, -1, -2)
_H = lambda x: jnp.conj(jnp.swapaxes(x, -1, -2))
def _promote_arg_dtypes(*args):
"""Promotes `args` to a common inexact type."""
def _to_inexact_type(type):
return type if jnp.issubdtype(type, jnp.inexact) else jnp.float_
inexact_types = [_to_inexact_type(jnp._dtype(arg)) for arg in args]
dtype = dtypes.canonicalize_dtype(jnp.result_type(*inexact_types))
args = [lax.convert_element_type(arg, dtype) for arg in args]
if len(args) == 1:
return args[0]
else:
return args
@_wraps(np.linalg.cholesky)
def cholesky(a):
a = _promote_arg_dtypes(jnp.asarray(a))
return lax_linalg.cholesky(a)
@_wraps(np.linalg.svd)
def svd(a, full_matrices=True, compute_uv=True):
a = _promote_arg_dtypes(jnp.asarray(a))
return lax_linalg.svd(a, full_matrices, compute_uv)
@_wraps(np.linalg.matrix_power)
def matrix_power(a, n):
a = _promote_arg_dtypes(jnp.asarray(a))
if a.ndim < 2:
raise TypeError("{}-dimensional array given. Array must be at least "
"two-dimensional".format(a.ndim))
if a.shape[-2] != a.shape[-1]:
raise TypeError("Last 2 dimensions of the array must be square")
try:
n = operator.index(n)
except TypeError as err:
raise TypeError("exponent must be an integer, got {}".format(n)) from err
if n == 0:
return jnp.broadcast_to(jnp.eye(a.shape[-2], dtype=a.dtype), a.shape)
elif n < 0:
a = inv(a)
n = np.abs(n)
if n == 1:
return a
elif n == 2:
return a @ a
elif n == 3:
return (a @ a) @ a
z = result = None
while n > 0:
z = a if z is None else (z @ z)
n, bit = divmod(n, 2)
if bit:
result = z if result is None else (result @ z)
return result
@_wraps(np.linalg.matrix_rank)
def matrix_rank(M, tol=None):
M = _promote_arg_dtypes(jnp.asarray(M))
if M.ndim > 2:
raise TypeError("array should have 2 or fewer dimensions")
if M.ndim < 2:
return jnp.any(M != 0).astype(jnp.int32)
S = svd(M, full_matrices=False, compute_uv=False)
if tol is None:
tol = S.max() * np.max(M.shape) * jnp.finfo(S.dtype).eps
return jnp.sum(S > tol)
@custom_jvp
@_wraps(np.linalg.slogdet)
@jit
def slogdet(a):
a = _promote_arg_dtypes(jnp.asarray(a))
dtype = lax.dtype(a)
a_shape = jnp.shape(a)
if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
msg = "Argument to slogdet() must have shape [..., n, n], got {}"
raise ValueError(msg.format(a_shape))
lu, pivot, _ = lax_linalg.lu(a)
diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
is_zero = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1)
parity = jnp.count_nonzero(pivot != jnp.arange(a_shape[-1]), axis=-1)
if jnp.iscomplexobj(a):
sign = jnp.prod(diag / jnp.abs(diag), axis=-1)
else:
sign = jnp.array(1, dtype=dtype)
parity = parity + jnp.count_nonzero(diag < 0, axis=-1)
sign = jnp.where(is_zero,
jnp.array(0, dtype=dtype),
sign * jnp.array(-2 * (parity % 2) + 1, dtype=dtype))
logdet = jnp.where(
is_zero, jnp.array(-jnp.inf, dtype=dtype),
jnp.sum(jnp.log(jnp.abs(diag)), axis=-1))
return sign, jnp.real(logdet)
@slogdet.defjvp
def _slogdet_jvp(primals, tangents):
x, = primals
g, = tangents
if jnp.issubdtype(jnp._dtype(x), jnp.complexfloating):
raise NotImplementedError # TODO(pfau): make this work for complex types
sign, ans = slogdet(x)
sign_dot, ans_dot = jnp.zeros_like(sign), jnp.trace(solve(x, g), axis1=-1, axis2=-2)
return (sign, ans), (sign_dot, ans_dot)
def _cofactor_solve(a, b):
"""Equivalent to det(a)*solve(a, b) for nonsingular mat.
Intermediate function used for jvp and vjp of det.
This function borrows heavily from jax.numpy.linalg.solve and
jax.numpy.linalg.slogdet to compute the gradient of the determinant
in a way that is well defined even for low rank matrices.
This function handles two different cases:
* rank(a) == n or n-1
* rank(a) < n-1
For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix.
Rather than computing det(a)*solve(a, b), which would return NaN, we work
directly with the LU decomposition. If a = p @ l @ u, then
det(a)*solve(a, b) =
prod(diag(u)) * u^-1 @ l^-1 @ p^-1 b =
prod(diag(u)) * triangular_solve(u, solve(p @ l, b))
If a is rank n-1, then the lower right corner of u will be zero and the
triangular_solve will fail.
Let x = solve(p @ l, b) and y = det(a)*solve(a, b).
Then y_{n}
x_{n} / u_{nn} * prod_{i=1...n}(u_{ii}) =
x_{n} * prod_{i=1...n-1}(u_{ii})
So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1
we can avoid the triangular_solve failing.
To correctly compute the rest of y_{i} for i != n, we simply multiply
x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1.
For the second case, a check is done on the matrix to see if `solve`
returns NaN or Inf, and gives a matrix of zeros as a result, as the
gradient of the determinant of a matrix with rank less than n-1 is 0.
This will still return the correct value for rank n-1 matrices, as the check
is applied *after* the lower right corner of u has been updated.
Args:
a: A square matrix or batch of matrices, possibly singular.
b: A matrix, or batch of matrices of the same dimension as a.
Returns:
det(a) and cofactor(a)^T*b, aka adjugate(a)*b
"""
a = _promote_arg_dtypes(jnp.asarray(a))
b = _promote_arg_dtypes(jnp.asarray(b))
a_shape = jnp.shape(a)
b_shape = jnp.shape(b)
a_ndims = len(a_shape)
if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
and b_shape[-2:] == a_shape[-2:]):
msg = ("The arguments to _cofactor_solve must have shapes "
"a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
raise ValueError(msg.format(a_shape, b_shape))
if a_shape[-1] == 1:
return a[0, 0], b
# lu contains u in the upper triangular matrix and l in the strict lower
# triangular matrix.
# The diagonal of l is set to ones without loss of generality.
lu, pivots, permutation = lax_linalg.lu(a)
dtype = lax.dtype(a)
batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
x = jnp.broadcast_to(b, batch_dims + b.shape[-2:])
lu = jnp.broadcast_to(lu, batch_dims + lu.shape[-2:])
# Compute (partial) determinant, ignoring last diagonal of LU
diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
parity = jnp.count_nonzero(pivots != jnp.arange(a_shape[-1]), axis=-1)
sign = jnp.array(-2 * (parity % 2) + 1, dtype=dtype)
# partial_det[:, -1] contains the full determinant and
# partial_det[:, -2] contains det(u) / u_{nn}.
partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None]
lu = ops.index_update(lu, ops.index[..., -1, -1], 1.0 / partial_det[..., -2])
permutation = jnp.broadcast_to(permutation, batch_dims + (a_shape[-1],))
iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1,)))
# filter out any matrices that are not full rank
d = jnp.ones(x.shape[:-1], x.dtype)
d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
d = jnp.any(jnp.logical_or(jnp.isnan(d), jnp.isinf(d)), axis=-1)
d = jnp.tile(d[..., None, None], d.ndim*(1,) + x.shape[-2:])
x = jnp.where(d, jnp.zeros_like(x), x) # first filter
x = x[iotas[:-1] + (permutation, slice(None))]
x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=True,
unit_diagonal=True)
x = jnp.concatenate((x[..., :-1, :] * partial_det[..., -1, None, None],
x[..., -1:, :]), axis=-2)
x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
x = jnp.where(d, jnp.zeros_like(x), x) # second filter
return partial_det[..., -1], x
@custom_jvp
@_wraps(np.linalg.det)
def det(a):
sign, logdet = slogdet(a)
return sign * jnp.exp(logdet)
@det.defjvp
def _det_jvp(primals, tangents):
x, = primals
g, = tangents
y, z = _cofactor_solve(x, g)
return y, jnp.trace(z, axis1=-1, axis2=-2)
@_wraps(np.linalg.eig)
def eig(a):
a = _promote_arg_dtypes(jnp.asarray(a))
return lax_linalg.eig(a, compute_left_eigenvectors=False)
@_wraps(np.linalg.eigvals)
def eigvals(a):
return lax_linalg.eig(a, compute_left_eigenvectors=False,
compute_right_eigenvectors=False)[0]
@_wraps(np.linalg.eigh)
def eigh(a, UPLO=None, symmetrize_input=True):
if UPLO is None or UPLO == "L":
lower = True
elif UPLO == "U":
lower = False
else:
msg = "UPLO must be one of None, 'L', or 'U', got {}".format(UPLO)
raise ValueError(msg)
a = _promote_arg_dtypes(jnp.asarray(a))
v, w = lax_linalg.eigh(a, lower=lower, symmetrize_input=symmetrize_input)
return w, v
@_wraps(np.linalg.eigvalsh)
def eigvalsh(a, UPLO='L'):
w, _ = eigh(a, UPLO)
return w
@partial(custom_jvp, nondiff_argnums=(1,))
@_wraps(np.linalg.pinv, lax_description=textwrap.dedent("""\
It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
default `rcond` is `1e-15`. Here the default is
`10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps`.
"""))
def pinv(a, rcond=None):
# Uses same algorithm as
# https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
a = jnp.conj(a)
if rcond is None:
max_rows_cols = max(a.shape[-2:])
rcond = 10. * max_rows_cols * jnp.finfo(a.dtype).eps
rcond = jnp.asarray(rcond)
u, s, v = svd(a, full_matrices=False)
# Singular values less than or equal to ``rcond * largest_singular_value``
# are set to zero.
cutoff = rcond[..., jnp.newaxis] * jnp.amax(s, axis=-1, keepdims=True)
s = jnp.where(s > cutoff, s, jnp.inf)
res = jnp.matmul(_T(v), jnp.divide(_T(u), s[..., jnp.newaxis]))
return lax.convert_element_type(res, a.dtype)
@pinv.defjvp
def _pinv_jvp(rcond, primals, tangents):
# The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems
# Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM
# Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432.
# (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative)
a, = primals
a_dot, = tangents
p = pinv(a, rcond=rcond)
m, n = a.shape[-2:]
# TODO(phawkins): on TPU, we would need to opt into high precision here.
# TODO(phawkins): consider if this can be simplified in the Hermitian case.
p_dot = -p @ a_dot @ p
p_dot = p_dot + p @ _H(p) @ _H(a_dot) @ (jnp.eye(m, dtype=a.dtype) - a @ p)
p_dot = p_dot + (jnp.eye(n, dtype=a.dtype) - p @ a) @ _H(a_dot) @ _H(p) @ p
return p, p_dot
@_wraps(np.linalg.inv)
def inv(a):
if jnp.ndim(a) < 2 or a.shape[-1] != a.shape[-2]:
raise ValueError("Argument to inv must have shape [..., n, n], got {}."
.format(jnp.shape(a)))
return solve(
a, lax.broadcast(jnp.eye(a.shape[-1], dtype=lax.dtype(a)), a.shape[:-2]))
@partial(jit, static_argnums=(1, 2, 3))
def _norm(x, ord, axis: Union[None, Tuple[int, ...], int], keepdims):
x = _promote_arg_dtypes(jnp.asarray(x))
x_shape = jnp.shape(x)
ndim = len(x_shape)
if axis is None:
# NumPy has an undocumented behavior that admits arbitrary rank inputs if
# `ord` is None: https://github.com/numpy/numpy/issues/14215
if ord is None:
return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), keepdims=keepdims))
axis = tuple(range(ndim))
elif isinstance(axis, tuple):
axis = tuple(canonicalize_axis(x, ndim) for x in axis)
else:
axis = (canonicalize_axis(axis, ndim),)
num_axes = len(axis)
if num_axes == 1:
if ord is None or ord == 2:
return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis,
keepdims=keepdims))
elif ord == jnp.inf:
return jnp.amax(jnp.abs(x), axis=axis, keepdims=keepdims)
elif ord == -jnp.inf:
return jnp.amin(jnp.abs(x), axis=axis, keepdims=keepdims)
elif ord == 0:
return jnp.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype,
axis=axis, keepdims=keepdims)
elif ord == 1:
# Numpy has a special case for ord == 1 as an optimization. We don't
# really need the optimization (XLA could do it for us), but the Numpy
# code has slightly different type promotion semantics, so we need a
# special case too.
return jnp.sum(jnp.abs(x), axis=axis, keepdims=keepdims)
else:
abs_x = jnp.abs(x)
ord = lax._const(abs_x, ord)
out = jnp.sum(abs_x ** ord, axis=axis, keepdims=keepdims)
return jnp.power(out, 1. / ord)
elif num_axes == 2:
row_axis, col_axis = cast(Tuple[int, ...], axis)
if ord is None or ord in ('f', 'fro'):
return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis,
keepdims=keepdims))
elif ord == 1:
if not keepdims and col_axis > row_axis:
col_axis -= 1
return jnp.amax(jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims),
axis=col_axis, keepdims=keepdims)
elif ord == -1:
if not keepdims and col_axis > row_axis:
col_axis -= 1
return jnp.amin(jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims),
axis=col_axis, keepdims=keepdims)
elif ord == jnp.inf:
if not keepdims and row_axis > col_axis:
row_axis -= 1
return jnp.amax(jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims),
axis=row_axis, keepdims=keepdims)
elif ord == -jnp.inf:
if not keepdims and row_axis > col_axis:
row_axis -= 1
return jnp.amin(jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims),
axis=row_axis, keepdims=keepdims)
elif ord in ('nuc', 2, -2):
x = jnp.moveaxis(x, axis, (-2, -1))
if ord == 2:
reducer = jnp.amax
elif ord == -2:
reducer = jnp.amin
else:
reducer = jnp.sum
y = reducer(svd(x, compute_uv=False), axis=-1)
if keepdims:
result_shape = list(x_shape)
result_shape[axis[0]] = 1
result_shape[axis[1]] = 1
y = jnp.reshape(y, result_shape)
return y
else:
raise ValueError("Invalid order '{}' for matrix norm.".format(ord))
else:
raise ValueError(
"Invalid axis values ({}) for jnp.linalg.norm.".format(axis))
@_wraps(np.linalg.norm)
def norm(x, ord=None, axis=None, keepdims=False):
return _norm(x, ord, axis, keepdims)
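A short, illustrative sketch (not from the commit) exercising several of the `ord`/`axis` branches handled by `_norm` above.

# Illustrative only: a few vector and matrix norms on a small example.
import jax.numpy as jnp

x = jnp.array([[3., 4.], [0., 5.]])
print(jnp.linalg.norm(x))                 # Frobenius norm of the matrix: sqrt(50)
print(jnp.linalg.norm(x, ord=1))          # max column sum of |x|: 9
print(jnp.linalg.norm(x, ord=jnp.inf))    # max row sum of |x|: 7
print(jnp.linalg.norm(x, ord=2))          # largest singular value
print(jnp.linalg.norm(x[0], ord=0))       # number of nonzeros in the vector: 2
print(jnp.linalg.norm(x, axis=1))         # per-row vector 2-norms: [5., 5.]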
@_wraps(np.linalg.qr)
def qr(a, mode="reduced"):
if mode in ("reduced", "r", "full"):
full_matrices = False
elif mode == "complete":
full_matrices = True
else:
raise ValueError("Unsupported QR decomposition mode '{}'".format(mode))
a = _promote_arg_dtypes(jnp.asarray(a))
q, r = lax_linalg.qr(a, full_matrices)
if mode == "r":
return r
return q, r
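A minimal usage sketch (illustrative only) of the three `mode` values handled by `qr` above.

# Illustrative only: reduced, r-only, and complete QR factorizations.
import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), (6, 4))
q, r = jnp.linalg.qr(a)                              # "reduced": q is (6, 4), r is (4, 4)
r_only = jnp.linalg.qr(a, mode="r")                  # triangular factor only
q_full, r_full = jnp.linalg.qr(a, mode="complete")   # q_full is (6, 6), r_full is (6, 4)
print(q.shape, r.shape, r_only.shape, q_full.shape, r_full.shape)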
def _check_solve_shapes(a, b):
if not (a.ndim >= 2 and a.shape[-1] == a.shape[-2] and b.ndim >= 1):
msg = ("The arguments to solve must have shapes a=[..., m, m] and "
"b=[..., m, k] or b=[..., m]; got a={} and b={}")
raise ValueError(msg.format(a.shape, b.shape))
@partial(vectorize, signature='(n,m),(m)->(n)')
def _matvec_multiply(a, b):
return jnp.dot(a, b, precision=lax.Precision.HIGHEST)
@_wraps(np.linalg.solve)
@jit
def solve(a, b):
a, b = _promote_arg_dtypes(jnp.asarray(a), jnp.asarray(b))
_check_solve_shapes(a, b)
# With custom_linear_solve, we can reuse the same factorization when
# computing sensitivities. This is considerably faster.
lu, _, permutation = lax_linalg.lu(lax.stop_gradient(a))
custom_solve = partial(
lax.custom_linear_solve,
lambda x: _matvec_multiply(a, x),
solve=lambda _, x: lax_linalg.lu_solve(lu, permutation, x, trans=0),
transpose_solve=lambda _, x: lax_linalg.lu_solve(lu, permutation, x,
trans=1))
if a.ndim == b.ndim + 1:
# b.shape == [..., m]
return custom_solve(b)
else:
# b.shape == [..., m, k]
return vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
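A hedged usage sketch (not part of the commit) of `solve` with a single right-hand side, a batch of right-hand sides, and a gradient; the diagonal shift is only there to keep the example well conditioned.

# Illustrative only: solve against vector and matrix right-hand sides.
import jax
import jax.numpy as jnp

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (4, 4)) + 4. * jnp.eye(4)    # well conditioned
b_vec = jax.random.normal(key_b, (4,))                    # b.shape == [..., m]
b_mat = jax.random.normal(key_b, (4, 3))                  # b.shape == [..., m, k]
x_vec = jnp.linalg.solve(a, b_vec)
x_mat = jnp.linalg.solve(a, b_mat)
print(jnp.max(jnp.abs(a @ x_vec - b_vec)), jnp.max(jnp.abs(a @ x_mat - b_mat)))

# solve is differentiable; the custom_linear_solve wiring above lets the
# backward pass reuse the LU factorization via transposed lu_solve calls.
g = jax.grad(lambda m: jnp.sum(jnp.linalg.solve(m, b_vec)))(a)
print(g.shape)                                            # (4, 4)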
@_wraps(np.linalg.lstsq, lax_description=textwrap.dedent("""\
It has two important differences:
  1. In `numpy.linalg.lstsq`, the default `rcond` is `-1`, and NumPy warns that the
     default will change to `None` in the future. Here, the default `rcond` is `None`.
  2. In `np.linalg.lstsq` the returned residuals are empty for low-rank or under-determined
     solutions. Here, the residuals are returned in all cases, to make the function
     compatible with jit. The non-jit-compatible NumPy behavior can be recovered by
     passing numpy_resid=True.
The lstsq function does not currently have a custom JVP rule, so the gradient is
poorly behaved for some inputs, particularly for low-rank `a`.
"""))
def lstsq(a, b, rcond=None, *, numpy_resid=False):
# TODO: add lstsq to lax_linalg and implement this function via those wrappers.
# TODO: add custom jvp rule for more robust lstsq differentiation
a, b = _promote_arg_dtypes(a, b)
if a.shape[0] != b.shape[0]:
raise ValueError("Leading dimensions of input arrays must match")
b_orig_ndim = b.ndim
if b_orig_ndim == 1:
b = b[:, None]
if a.ndim != 2:
raise TypeError(
f"{a.ndim}-dimensional array given. Array must be two-dimensional")
if b.ndim != 2:
raise TypeError(
f"{b.ndim}-dimensional array given. Array must be one or two-dimensional")
m, n = a.shape
dtype = a.dtype
if rcond is None:
rcond = jnp.finfo(dtype).eps * max(n, m)
elif rcond < 0:
rcond = jnp.finfo(dtype).eps
u, s, vt = svd(a, full_matrices=False)
mask = s >= rcond * s[0]
rank = mask.sum()
safe_s = jnp.where(mask, s, 1)
s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis]
uTb = jnp.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST)
x = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST)
# Numpy returns empty residuals in some cases. To allow compilation, we
# default to returning full residuals in all cases.
if numpy_resid and (rank < n or m <= n):
resid = jnp.asarray([])
else:
b_estimate = jnp.matmul(a, x, precision=lax.Precision.HIGHEST)
resid = norm(b - b_estimate, axis=0) ** 2
if b_orig_ndim == 1:
x = x.ravel()
return x, resid, rank, s
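A hedged sketch (not from the commit) of `lstsq` showing the jit-friendly residual behavior described in the docstring above; the shapes and keys are arbitrary illustrative choices.

# Illustrative only: lstsq on over- and under-determined systems.
import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), (8, 3))       # over-determined
b = jax.random.normal(jax.random.PRNGKey(1), (8,))
x, resid, rank, s = jnp.linalg.lstsq(a, b)
print(x.shape, resid.shape, rank, s.shape)

# For an under-determined system, residuals are still computed by default;
# numpy_resid=True restores NumPy's empty-residual behavior, which is not
# jit-compatible (see the note above).
a2 = jax.random.normal(jax.random.PRNGKey(2), (2, 4))
b2 = jax.random.normal(jax.random.PRNGKey(3), (2,))
_, resid_jax, _, _ = jnp.linalg.lstsq(a2, b2)
_, resid_np, _, _ = jnp.linalg.lstsq(a2, b2, numpy_resid=True)
print(resid_jax.shape, resid_np.shape)                      # (1,) vs (0,)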
# flake8: noqa: F401
from jax._src.numpy.linalg import (
cholesky,
cond,
det,
eig,
eigh,
eigvals,
eigvalsh,
inv,
lstsq,
matrix_power,
matrix_rank,
multi_dot,
norm,
pinv,
qr,
slogdet,
solve,
svd,
tensorinv,
tensorsolve,
)
# Module initialization is encapsulated in a function to avoid accidental
# namespace pollution.
_NOT_IMPLEMENTED = []
for name, func in get_module_functions(np.linalg).items():
if name not in globals():
_NOT_IMPLEMENTED.append(name)
globals()[name] = _not_implemented(func)
def _init():
import numpy as np
from jax._src.numpy import lax_numpy
from jax import util
# Builds a set of all unimplemented NumPy functions.
for name, func in util.get_module_functions(np.linalg).items():
if name not in globals():
_NOT_IMPLEMENTED.append(name)
globals()[name] = lax_numpy._not_implemented(func)
_init()
del _init
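A standalone, illustrative stand-in for the `_not_implemented` fallback that `_init` installs above (the real helper lives in `jax._src.numpy.lax_numpy`); it only shows the pattern of wrapping a NumPy function so that calling it raises `NotImplementedError`.

# Illustrative sketch of the fallback pattern; not the actual helper.
import functools

def _not_implemented_sketch(fun):
  @functools.wraps(fun)
  def wrapped(*args, **kwargs):
    raise NotImplementedError(
        f"numpy.linalg.{fun.__name__} is not implemented by jax.numpy.linalg")
  return wrapped

# Wrapping any np.linalg function this way yields a stub that raises when called.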

View File

@@ -16,7 +16,7 @@
from .. import lax
from ..numpy import lax_numpy as jnp
from jax._src.numpy import lax_numpy as jnp
from .. import util

View File

@@ -49,7 +49,7 @@ from . import lax
from . import numpy as jnp
from . import dtypes
from .api import jit, vmap
from .numpy.lax_numpy import _constant_like, asarray
from jax._src.numpy.lax_numpy import _constant_like, asarray
from jax.lib import xla_bridge
from jax.lib import xla_client
from jax.lib import cuda_prng

View File

@@ -1,8 +1,8 @@
import numpy as np
from jax.numpy import lax_numpy as jnp
from jax.numpy import linalg as la
from jax.numpy._util import _wraps
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import linalg as la
from jax._src.numpy.util import _wraps
def _isEmpty2d(arr):

View File

@@ -23,7 +23,7 @@ from jax import lax
from jax import core
from jax import test_util as jtu
from jax.config import config
from jax.numpy.lax_numpy import _polymorphic_slice_indices
from jax._src.numpy.lax_numpy import _polymorphic_slice_indices
from jax.util import safe_map, safe_zip
from jax.tree_util import tree_flatten

View File

@@ -34,12 +34,6 @@ all_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex
class TestPolynomial(jtu.JaxTestCase):
def testNotImplemented(self):
for name in jnp.polynomial._NOT_IMPLEMENTED:
func = getattr(jnp.polynomial, name)
with self.assertRaises(NotImplementedError):
func()
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dtype={}_leading={}_trailing={}".format(
jtu.format_shape_dtype_string((length+leading+trailing,), dtype),