mirror of https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00

Move jax.numpy internals into jax._src.numpy.

parent 9ea1311c7d
commit aa107cf1f4

jax/_src/numpy/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
jax/_src/numpy/fft.py (new file, 253 lines)
@@ -0,0 +1,253 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np

from jax import lax
from jax.lib import xla_client
from .util import _wraps
from . import lax_numpy as jnp
from jax import ops as jaxops


def _fft_core(func_name, fft_type, a, s, axes, norm):
  # TODO(skye): implement padding/cropping based on 's'.
  full_name = "jax.numpy.fft." + func_name
  if s is not None:
    raise NotImplementedError("%s only supports s=None, got %s" % (full_name, s))
  if norm is not None:
    raise NotImplementedError("%s only supports norm=None, got %s" % (full_name, norm))
  if s is not None and axes is not None and len(s) != len(axes):
    # Same error as numpy.
    raise ValueError("Shape and axes have different lengths.")

  orig_axes = axes
  if axes is None:
    if s is None:
      axes = range(a.ndim)
    else:
      axes = range(a.ndim - len(s), a.ndim)

  if len(axes) != len(set(axes)):
    raise ValueError(
        "%s does not support repeated axes. Got axes %s." % (full_name, axes))

  if len(axes) > 3:
    # XLA does not support FFTs over more than 3 dimensions
    raise ValueError(
        "%s only supports 1D, 2D, and 3D FFTs. "
        "Got axes %s with input rank %s." % (full_name, orig_axes, a.ndim))

  # XLA only supports FFTs over the innermost axes, so rearrange if necessary.
  if orig_axes is not None:
    axes = tuple(range(a.ndim - len(axes), a.ndim))
    a = jnp.moveaxis(a, orig_axes, axes)

  if s is None:
    if fft_type == xla_client.FftType.IRFFT:
      s = [a.shape[axis] for axis in axes[:-1]]
      if axes:
        s += [max(0, 2 * (a.shape[axes[-1]] - 1))]
    else:
      s = [a.shape[axis] for axis in axes]

  transformed = lax.fft(a, fft_type, s)

  if orig_axes is not None:
    transformed = jnp.moveaxis(transformed, axes, orig_axes)
  return transformed
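
Note (illustrative sketch, not part of the diff): _fft_core moves the requested axes to the innermost positions, transforms there, and moves the result back, so transforms over leading axes agree with NumPy. The array values below are made up.

# Illustration only: FFT over a non-innermost axis.
import numpy as np
import jax.numpy as jnp

x = np.arange(24.0).reshape(2, 3, 4)
out = jnp.fft.fftn(x, axes=[0])     # axis 0 is moved innermost and back
ref = np.fft.fftn(x, axes=[0])
assert np.allclose(out, ref, atol=1e-3)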

@_wraps(np.fft.fftn)
def fftn(a, s=None, axes=None, norm=None):
  return _fft_core('fftn', xla_client.FftType.FFT, a, s, axes, norm)


@_wraps(np.fft.ifftn)
def ifftn(a, s=None, axes=None, norm=None):
  return _fft_core('ifftn', xla_client.FftType.IFFT, a, s, axes, norm)


@_wraps(np.fft.rfftn)
def rfftn(a, s=None, axes=None, norm=None):
  return _fft_core('rfftn', xla_client.FftType.RFFT, a, s, axes, norm)


@_wraps(np.fft.irfftn)
def irfftn(a, s=None, axes=None, norm=None):
  return _fft_core('irfftn', xla_client.FftType.IRFFT, a, s, axes, norm)


def _axis_check_1d(func_name, axis):
  full_name = "jax.numpy.fft." + func_name
  if isinstance(axis, (list, tuple)):
    raise ValueError(
        "%s does not support multiple axes. Please use %sn. "
        "Got axis = %r." % (full_name, full_name, axis)
    )


def _fft_core_1d(func_name, fft_type, a, s, axis, norm):
  _axis_check_1d(func_name, axis)
  axes = None if axis is None else [axis]
  return _fft_core(func_name, fft_type, a, s, axes, norm)


@_wraps(np.fft.fft)
def fft(a, n=None, axis=-1, norm=None):
  return _fft_core_1d('fft', xla_client.FftType.FFT, a, s=n, axis=axis,
                      norm=norm)


@_wraps(np.fft.ifft)
def ifft(a, n=None, axis=-1, norm=None):
  return _fft_core_1d('ifft', xla_client.FftType.IFFT, a, s=n, axis=axis,
                      norm=norm)


@_wraps(np.fft.rfft)
def rfft(a, n=None, axis=-1, norm=None):
  return _fft_core_1d('rfft', xla_client.FftType.RFFT, a, s=n, axis=axis,
                      norm=norm)


@_wraps(np.fft.irfft)
def irfft(a, n=None, axis=-1, norm=None):
  return _fft_core_1d('irfft', xla_client.FftType.IRFFT, a, s=n, axis=axis,
                      norm=norm)


@_wraps(np.fft.hfft)
def hfft(a, n=None, axis=-1, norm=None):
  conj_a = jnp.conj(a)
  _axis_check_1d('hfft', axis)
  nn = (a.shape[axis] - 1) * 2 if n is None else n
  return _fft_core_1d('hfft', xla_client.FftType.IRFFT, conj_a, s=n, axis=axis,
                      norm=norm) * nn


@_wraps(np.fft.ihfft)
def ihfft(a, n=None, axis=-1, norm=None):
  _axis_check_1d('ihfft', axis)
  nn = a.shape[axis] if n is None else n
  output = _fft_core_1d('ihfft', xla_client.FftType.RFFT, a, s=n, axis=axis,
                        norm=norm)
  return jnp.conj(output) * (1 / nn)
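
Note (illustrative sketch, not part of the diff): hfft above is nn * irfft(conj(a)) and ihfft is conj(rfft(a)) / nn, so they invert each other up to floating-point error. The input values are made up.

# Illustration only: hfft/ihfft round trip for a length-3 half-spectrum.
import numpy as np
import jax.numpy as jnp

a = np.array([1.0 + 0.0j, 2.0 - 1.0j, 0.5 + 0.5j])
assert np.allclose(jnp.fft.hfft(a), np.fft.hfft(a), atol=1e-4)
assert np.allclose(jnp.fft.ihfft(jnp.fft.hfft(a)), a, atol=1e-5)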

def _fft_core_2d(func_name, fft_type, a, s, axes, norm):
  full_name = "jax.numpy.fft." + func_name
  if len(axes) != 2:
    raise ValueError(
        "%s only supports 2 axes. Got axes = %r."
        % (full_name, axes)
    )
  return _fft_core(func_name, fft_type, a, s, axes, norm)


@_wraps(np.fft.fft2)
def fft2(a, s=None, axes=(-2,-1), norm=None):
  return _fft_core_2d('fft2', xla_client.FftType.FFT, a, s=s, axes=axes,
                      norm=norm)


@_wraps(np.fft.ifft2)
def ifft2(a, s=None, axes=(-2,-1), norm=None):
  return _fft_core_2d('ifft2', xla_client.FftType.IFFT, a, s=s, axes=axes,
                      norm=norm)


@_wraps(np.fft.rfft2)
def rfft2(a, s=None, axes=(-2,-1), norm=None):
  return _fft_core_2d('rfft2', xla_client.FftType.RFFT, a, s=s, axes=axes,
                      norm=norm)


@_wraps(np.fft.irfft2)
def irfft2(a, s=None, axes=(-2,-1), norm=None):
  return _fft_core_2d('irfft2', xla_client.FftType.IRFFT, a, s=s, axes=axes,
                      norm=norm)


@_wraps(np.fft.fftfreq)
def fftfreq(n, d=1.0):
  if isinstance(n, (list, tuple)):
    raise ValueError(
        "The n argument of jax.numpy.fft.fftfreq only takes an int. "
        "Got n = %s." % list(n))

  elif isinstance(d, (list, tuple)):
    raise ValueError(
        "The d argument of jax.numpy.fft.fftfreq only takes a single value. "
        "Got d = %s." % list(d))

  k = jnp.zeros(n)
  if n % 2 == 0:
    # k[0: n // 2] = jnp.arange(0, n // 2)
    k = jaxops.index_update(k, jaxops.index[0: n // 2], jnp.arange(0, n // 2))

    # k[n // 2:] = jnp.arange(-n // 2, 0)
    k = jaxops.index_update(k, jaxops.index[n // 2:], jnp.arange(-n // 2, 0))

  else:
    # k[0: (n - 1) // 2 + 1] = jnp.arange(0, (n - 1) // 2 + 1)
    k = jaxops.index_update(k, jaxops.index[0: (n - 1) // 2 + 1],
                            jnp.arange(0, (n - 1) // 2 + 1))

    # k[(n - 1) // 2 + 1:] = jnp.arange(-(n - 1) // 2, 0)
    k = jaxops.index_update(k, jaxops.index[(n - 1) // 2 + 1:],
                            jnp.arange(-(n - 1) // 2, 0))

  return k / (d * n)
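
Note (illustrative check, not part of the diff): the two index_update branches reproduce NumPy's frequency ordering for even and odd n.

# Illustration only: compare against numpy for one even and one odd size.
import numpy as np
import jax.numpy as jnp

for n in (8, 9):
  assert np.allclose(jnp.fft.fftfreq(n, d=0.5), np.fft.fftfreq(n, d=0.5))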

@_wraps(np.fft.rfftfreq)
def rfftfreq(n, d=1.0):
  if isinstance(n, (list, tuple)):
    raise ValueError(
        "The n argument of jax.numpy.fft.rfftfreq only takes an int. "
        "Got n = %s." % list(n))

  elif isinstance(d, (list, tuple)):
    raise ValueError(
        "The d argument of jax.numpy.fft.rfftfreq only takes a single value. "
        "Got d = %s." % list(d))

  if n % 2 == 0:
    k = jnp.arange(0, n // 2 + 1)

  else:
    k = jnp.arange(0, (n - 1) // 2 + 1)

  return k / (d * n)


@_wraps(np.fft.fftshift)
def fftshift(x, axes=None):
  x = jnp.asarray(x)
  if axes is None:
    axes = tuple(range(x.ndim))
    shift = [dim // 2 for dim in x.shape]
  elif isinstance(axes, int):
    shift = x.shape[axes] // 2
  else:
    shift = [x.shape[ax] // 2 for ax in axes]

  return jnp.roll(x, shift, axes)


@_wraps(np.fft.ifftshift)
def ifftshift(x, axes=None):
  x = jnp.asarray(x)
  if axes is None:
    axes = tuple(range(x.ndim))
    shift = [-(dim // 2) for dim in x.shape]
  elif isinstance(axes, int):
    shift = -(x.shape[axes] // 2)
  else:
    shift = [-(x.shape[ax] // 2) for ax in axes]

  return jnp.roll(x, shift, axes)
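
Note (illustrative check, not part of the diff): fftshift and ifftshift use jnp.roll with opposite shift signs, so composing them is the identity.

# Illustration only: ifftshift undoes fftshift.
import numpy as np
import jax.numpy as jnp

freqs = jnp.fft.fftfreq(8)
assert np.allclose(jnp.fft.ifftshift(jnp.fft.fftshift(freqs)), freqs)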
@@ -39,19 +39,19 @@ import opt_einsum
 import jax
 from jax import jit, custom_jvp
 from .vectorize import vectorize
-from ._util import _wraps
-from .. import core
-from .. import dtypes
-from ..abstract_arrays import UnshapedArray, ShapedArray, ConcreteArray, canonicalize_shape
-from ..config import flags, config
-from ..interpreters.xla import DeviceArray
-from ..interpreters.masking import Poly
-from .. import lax
-from ..lax.lax import _device_put_raw
-from .. import ops
-from ..util import (partial, unzip2, prod as _prod,
-                    subvals, safe_zip, canonicalize_axis as _canonicalize_axis)
-from ..tree_util import tree_leaves, tree_flatten
+from .util import _wraps
+from jax import core
+from jax import dtypes
+from jax.abstract_arrays import UnshapedArray, ShapedArray, ConcreteArray, canonicalize_shape
+from jax.config import flags, config
+from jax.interpreters.xla import DeviceArray
+from jax.interpreters.masking import Poly
+from jax import lax
+from jax.lax.lax import _device_put_raw
+from jax import ops
+from jax.util import (partial, unzip2, prod as _prod,
+                      subvals, safe_zip, canonicalize_axis as _canonicalize_axis)
+from jax.tree_util import tree_leaves, tree_flatten
 
 FLAGS = flags.FLAGS
 flags.DEFINE_enum(
jax/_src/numpy/linalg.py (new file, 530 lines)
@@ -0,0 +1,530 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from functools import partial

import numpy as np
import textwrap
import operator
from typing import Tuple, Union, cast

from jax import jit, vmap, custom_jvp
from jax import lax
from jax import ops
from jax import lax_linalg
from jax import dtypes
from .util import _wraps
from .vectorize import vectorize
from . import lax_numpy as jnp
from jax.util import canonicalize_axis
from jax.third_party.numpy.linalg import cond, multi_dot, tensorinv, tensorsolve  # noqa: F401

_T = lambda x: jnp.swapaxes(x, -1, -2)
_H = lambda x: jnp.conjugate(jnp.swapaxes(x, -1, -2))


def _promote_arg_dtypes(*args):
  """Promotes `args` to a common inexact type."""
  def _to_inexact_type(type):
    return type if jnp.issubdtype(type, jnp.inexact) else jnp.float_
  inexact_types = [_to_inexact_type(jnp._dtype(arg)) for arg in args]
  dtype = dtypes.canonicalize_dtype(jnp.result_type(*inexact_types))
  args = [lax.convert_element_type(arg, dtype) for arg in args]
  if len(args) == 1:
    return args[0]
  else:
    return args
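
Note (illustrative sketch, not part of the diff; assumes the default float32 configuration — with x64 enabled the promoted dtype is float64): _promote_arg_dtypes converts integer inputs to an inexact dtype before any decomposition runs.

# Illustration only: integer input is promoted before the Cholesky factorization.
import numpy as np
import jax.numpy as jnp

a = jnp.array([[4, 0], [0, 9]])               # int32 input
L = jnp.linalg.cholesky(a)
assert jnp.issubdtype(L.dtype, jnp.floating)  # promoted to an inexact dtype
assert np.allclose(L, [[2.0, 0.0], [0.0, 3.0]])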

@_wraps(np.linalg.cholesky)
def cholesky(a):
  a = _promote_arg_dtypes(jnp.asarray(a))
  return lax_linalg.cholesky(a)


@_wraps(np.linalg.svd)
def svd(a, full_matrices=True, compute_uv=True):
  a = _promote_arg_dtypes(jnp.asarray(a))
  return lax_linalg.svd(a, full_matrices, compute_uv)


@_wraps(np.linalg.matrix_power)
def matrix_power(a, n):
  a = _promote_arg_dtypes(jnp.asarray(a))

  if a.ndim < 2:
    raise TypeError("{}-dimensional array given. Array must be at least "
                    "two-dimensional".format(a.ndim))
  if a.shape[-2] != a.shape[-1]:
    raise TypeError("Last 2 dimensions of the array must be square")
  try:
    n = operator.index(n)
  except TypeError as err:
    raise TypeError("exponent must be an integer, got {}".format(n)) from err

  if n == 0:
    return jnp.broadcast_to(jnp.eye(a.shape[-2], dtype=a.dtype), a.shape)
  elif n < 0:
    a = inv(a)
    n = np.abs(n)

  if n == 1:
    return a
  elif n == 2:
    return a @ a
  elif n == 3:
    return (a @ a) @ a

  z = result = None
  while n > 0:
    z = a if z is None else (z @ z)
    n, bit = divmod(n, 2)
    if bit:
      result = z if result is None else (result @ z)

  return result
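
Note (illustrative check, not part of the diff): the loop above is exponentiation by squaring — n = 5 is binary 101, so the result multiplies a (bit 0) by a**4 (bit 2), squaring z along the way.

# Illustration only: matrix_power matches repeated multiplication.
import numpy as np
import jax.numpy as jnp

a = jnp.array([[1.0, 1.0], [0.0, 1.0]])
expected = np.array([[1.0, 5.0], [0.0, 1.0]])  # [[1,1],[0,1]]**5
assert np.allclose(jnp.linalg.matrix_power(a, 5), expected, atol=1e-5)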

@_wraps(np.linalg.matrix_rank)
def matrix_rank(M, tol=None):
  M = _promote_arg_dtypes(jnp.asarray(M))
  if M.ndim > 2:
    raise TypeError("array should have 2 or fewer dimensions")
  if M.ndim < 2:
    return jnp.any(M != 0).astype(jnp.int32)
  S = svd(M, full_matrices=False, compute_uv=False)
  if tol is None:
    tol = S.max() * np.max(M.shape) * jnp.finfo(S.dtype).eps
  return jnp.sum(S > tol)


@custom_jvp
@_wraps(np.linalg.slogdet)
@jit
def slogdet(a):
  a = _promote_arg_dtypes(jnp.asarray(a))
  dtype = lax.dtype(a)
  a_shape = jnp.shape(a)
  if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
    msg = "Argument to slogdet() must have shape [..., n, n], got {}"
    raise ValueError(msg.format(a_shape))
  lu, pivot, _ = lax_linalg.lu(a)
  diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
  is_zero = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1)
  parity = jnp.count_nonzero(pivot != jnp.arange(a_shape[-1]), axis=-1)
  if jnp.iscomplexobj(a):
    sign = jnp.prod(diag / jnp.abs(diag), axis=-1)
  else:
    sign = jnp.array(1, dtype=dtype)
    parity = parity + jnp.count_nonzero(diag < 0, axis=-1)
  sign = jnp.where(is_zero,
                   jnp.array(0, dtype=dtype),
                   sign * jnp.array(-2 * (parity % 2) + 1, dtype=dtype))
  logdet = jnp.where(
      is_zero, jnp.array(-jnp.inf, dtype=dtype),
      jnp.sum(jnp.log(jnp.abs(diag)), axis=-1))
  return sign, jnp.real(logdet)


@slogdet.defjvp
def _slogdet_jvp(primals, tangents):
  x, = primals
  g, = tangents
  if jnp.issubdtype(jnp._dtype(x), jnp.complexfloating):
    raise NotImplementedError  # TODO(pfau): make this work for complex types
  sign, ans = slogdet(x)
  sign_dot, ans_dot = jnp.zeros_like(sign), jnp.trace(solve(x, g), axis1=-1, axis2=-2)
  return (sign, ans), (sign_dot, ans_dot)
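
Note (illustrative check, not part of the diff): det can be recovered as sign * exp(logdet); for singular inputs sign is 0 and logdet is -inf, so the product is still well defined.

# Illustration only: recover the determinant from slogdet.
import numpy as np
import jax.numpy as jnp

a = jnp.array([[1.0, 2.0], [3.0, 4.0]])
sign, logdet = jnp.linalg.slogdet(a)
assert np.allclose(sign * jnp.exp(logdet), -2.0, atol=1e-5)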

def _cofactor_solve(a, b):
  """Equivalent to det(a)*solve(a, b) for nonsingular mat.

  Intermediate function used for jvp and vjp of det.
  This function borrows heavily from jax.numpy.linalg.solve and
  jax.numpy.linalg.slogdet to compute the gradient of the determinant
  in a way that is well defined even for low rank matrices.

  This function handles two different cases:
  * rank(a) == n or n-1
  * rank(a) < n-1

  For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix.
  Rather than computing det(a)*solve(a, b), which would return NaN, we work
  directly with the LU decomposition. If a = p @ l @ u, then
  det(a)*solve(a, b) =
  prod(diag(u)) * u^-1 @ l^-1 @ p^-1 b =
  prod(diag(u)) * triangular_solve(u, solve(p @ l, b))
  If a is rank n-1, then the lower right corner of u will be zero and the
  triangular_solve will fail.
  Let x = solve(p @ l, b) and y = det(a)*solve(a, b).
  Then y_{n} =
    x_{n} / u_{nn} * prod_{i=1...n}(u_{ii}) =
    x_{n} * prod_{i=1...n-1}(u_{ii})
  So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1
  we can avoid the triangular_solve failing.
  To correctly compute the rest of y_{i} for i != n, we simply multiply
  x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1.

  For the second case, a check is done on the matrix to see if `solve`
  returns NaN or Inf, and gives a matrix of zeros as a result, as the
  gradient of the determinant of a matrix with rank less than n-1 is 0.
  This will still return the correct value for rank n-1 matrices, as the check
  is applied *after* the lower right corner of u has been updated.

  Args:
    a: A square matrix or batch of matrices, possibly singular.
    b: A matrix, or batch of matrices of the same dimension as a.

  Returns:
    det(a) and cofactor(a)^T*b, aka adjugate(a)*b
  """
  a = _promote_arg_dtypes(jnp.asarray(a))
  b = _promote_arg_dtypes(jnp.asarray(b))
  a_shape = jnp.shape(a)
  b_shape = jnp.shape(b)
  a_ndims = len(a_shape)
  if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
          and b_shape[-2:] == a_shape[-2:]):
    msg = ("The arguments to _cofactor_solve must have shapes "
           "a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
    raise ValueError(msg.format(a_shape, b_shape))
  if a_shape[-1] == 1:
    return a[0, 0], b
  # lu contains u in the upper triangular matrix and l in the strict lower
  # triangular matrix.
  # The diagonal of l is set to ones without loss of generality.
  lu, pivots, permutation = lax_linalg.lu(a)
  dtype = lax.dtype(a)
  batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
  x = jnp.broadcast_to(b, batch_dims + b.shape[-2:])
  lu = jnp.broadcast_to(lu, batch_dims + lu.shape[-2:])
  # Compute (partial) determinant, ignoring last diagonal of LU
  diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
  parity = jnp.count_nonzero(pivots != jnp.arange(a_shape[-1]), axis=-1)
  sign = jnp.array(-2 * (parity % 2) + 1, dtype=dtype)
  # partial_det[:, -1] contains the full determinant and
  # partial_det[:, -2] contains det(u) / u_{nn}.
  partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None]
  lu = ops.index_update(lu, ops.index[..., -1, -1], 1.0 / partial_det[..., -2])
  permutation = jnp.broadcast_to(permutation, batch_dims + (a_shape[-1],))
  iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1,)))
  # filter out any matrices that are not full rank
  d = jnp.ones(x.shape[:-1], x.dtype)
  d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
  d = jnp.any(jnp.logical_or(jnp.isnan(d), jnp.isinf(d)), axis=-1)
  d = jnp.tile(d[..., None, None], d.ndim*(1,) + x.shape[-2:])
  x = jnp.where(d, jnp.zeros_like(x), x)  # first filter
  x = x[iotas[:-1] + (permutation, slice(None))]
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=True,
                                  unit_diagonal=True)
  x = jnp.concatenate((x[..., :-1, :] * partial_det[..., -1, None, None],
                       x[..., -1:, :]), axis=-2)
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
  x = jnp.where(d, jnp.zeros_like(x), x)  # second filter

  return partial_det[..., -1], x
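
Note (illustrative check, not part of the diff): the helper above is what makes grad(det) finite on singular matrices, where a naive det(a)*solve(a, g) would produce NaN. By Jacobi's formula the gradient of det is the transposed adjugate; the rank-1 matrix below is made up.

# Illustration only: grad(det) on a singular (rank-1) matrix.
import numpy as np
import jax
import jax.numpy as jnp

a = jnp.array([[1.0, 2.0], [2.0, 4.0]])        # det(a) == 0
g = jax.grad(jnp.linalg.det)(a)                # adjugate(a)^T by Jacobi's formula
assert np.allclose(g, [[4.0, -2.0], [-2.0, 1.0]], atol=1e-5)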

@custom_jvp
@_wraps(np.linalg.det)
def det(a):
  sign, logdet = slogdet(a)
  return sign * jnp.exp(logdet)


@det.defjvp
def _det_jvp(primals, tangents):
  x, = primals
  g, = tangents
  y, z = _cofactor_solve(x, g)
  return y, jnp.trace(z, axis1=-1, axis2=-2)


@_wraps(np.linalg.eig)
def eig(a):
  a = _promote_arg_dtypes(jnp.asarray(a))
  return lax_linalg.eig(a, compute_left_eigenvectors=False)


@_wraps(np.linalg.eigvals)
def eigvals(a):
  return lax_linalg.eig(a, compute_left_eigenvectors=False,
                        compute_right_eigenvectors=False)[0]


@_wraps(np.linalg.eigh)
def eigh(a, UPLO=None, symmetrize_input=True):
  if UPLO is None or UPLO == "L":
    lower = True
  elif UPLO == "U":
    lower = False
  else:
    msg = "UPLO must be one of None, 'L', or 'U', got {}".format(UPLO)
    raise ValueError(msg)

  a = _promote_arg_dtypes(jnp.asarray(a))
  v, w = lax_linalg.eigh(a, lower=lower, symmetrize_input=symmetrize_input)
  return w, v


@_wraps(np.linalg.eigvalsh)
def eigvalsh(a, UPLO='L'):
  w, _ = eigh(a, UPLO)
  return w


@partial(custom_jvp, nondiff_argnums=(1,))
@_wraps(np.linalg.pinv, lax_description=textwrap.dedent("""\
    It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
    default `rcond` is `1e-15`. Here the default is
    `10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps`.
    """))
def pinv(a, rcond=None):
  # Uses same algorithm as
  # https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
  a = jnp.conj(a)
  if rcond is None:
    max_rows_cols = max(a.shape[-2:])
    rcond = 10. * max_rows_cols * jnp.finfo(a.dtype).eps
  rcond = jnp.asarray(rcond)
  u, s, v = svd(a, full_matrices=False)
  # Singular values less than or equal to ``rcond * largest_singular_value``
  # are set to zero.
  cutoff = rcond[..., jnp.newaxis] * jnp.amax(s, axis=-1, keepdims=True)
  s = jnp.where(s > cutoff, s, jnp.inf)
  res = jnp.matmul(_T(v), jnp.divide(_T(u), s[..., jnp.newaxis]))
  return lax.convert_element_type(res, a.dtype)
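
Note (illustrative check, not part of the diff): pinv satisfies the Moore-Penrose identity a @ pinv(a) @ a == a, here for a tall full-column-rank matrix.

# Illustration only: Moore-Penrose identity for a 3x2 matrix.
import numpy as np
import jax.numpy as jnp

a = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
p = jnp.linalg.pinv(a)
assert np.allclose(a @ p @ a, a, atol=1e-4)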

@pinv.defjvp
def _pinv_jvp(rcond, primals, tangents):
  # The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems
  # Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM
  # Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432.
  # (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative)
  a, = primals
  a_dot, = tangents
  p = pinv(a, rcond=rcond)
  m, n = a.shape[-2:]
  # TODO(phawkins): on TPU, we would need to opt into high precision here.
  # TODO(phawkins): consider if this can be simplified in the Hermitian case.
  p_dot = -p @ a_dot @ p
  p_dot = p_dot + p @ _H(p) @ _H(a_dot) @ (jnp.eye(m, dtype=a.dtype) - a @ p)
  p_dot = p_dot + (jnp.eye(n, dtype=a.dtype) - p @ a) @ _H(a_dot) @ _H(p) @ p
  return p, p_dot


@_wraps(np.linalg.inv)
def inv(a):
  if jnp.ndim(a) < 2 or a.shape[-1] != a.shape[-2]:
    raise ValueError("Argument to inv must have shape [..., n, n], got {}."
                     .format(jnp.shape(a)))
  return solve(
      a, lax.broadcast(jnp.eye(a.shape[-1], dtype=lax.dtype(a)), a.shape[:-2]))


@partial(jit, static_argnums=(1, 2, 3))
def _norm(x, ord, axis: Union[None, Tuple[int, ...], int], keepdims):
  x = _promote_arg_dtypes(jnp.asarray(x))
  x_shape = jnp.shape(x)
  ndim = len(x_shape)

  if axis is None:
    # NumPy has an undocumented behavior that admits arbitrary rank inputs if
    # `ord` is None: https://github.com/numpy/numpy/issues/14215
    if ord is None:
      return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), keepdims=keepdims))
    axis = tuple(range(ndim))
  elif isinstance(axis, tuple):
    axis = tuple(canonicalize_axis(x, ndim) for x in axis)
  else:
    axis = (canonicalize_axis(axis, ndim),)

  num_axes = len(axis)
  if num_axes == 1:
    if ord is None or ord == 2:
      return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis,
                              keepdims=keepdims))
    elif ord == jnp.inf:
      return jnp.amax(jnp.abs(x), axis=axis, keepdims=keepdims)
    elif ord == -jnp.inf:
      return jnp.amin(jnp.abs(x), axis=axis, keepdims=keepdims)
    elif ord == 0:
      return jnp.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype,
                     axis=axis, keepdims=keepdims)
    elif ord == 1:
      # Numpy has a special case for ord == 1 as an optimization. We don't
      # really need the optimization (XLA could do it for us), but the Numpy
      # code has slightly different type promotion semantics, so we need a
      # special case too.
      return jnp.sum(jnp.abs(x), axis=axis, keepdims=keepdims)
    else:
      abs_x = jnp.abs(x)
      ord = lax._const(abs_x, ord)
      out = jnp.sum(abs_x ** ord, axis=axis, keepdims=keepdims)
      return jnp.power(out, 1. / ord)

  elif num_axes == 2:
    row_axis, col_axis = cast(Tuple[int, ...], axis)
    if ord is None or ord in ('f', 'fro'):
      return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis,
                              keepdims=keepdims))
    elif ord == 1:
      if not keepdims and col_axis > row_axis:
        col_axis -= 1
      return jnp.amax(jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims),
                      axis=col_axis, keepdims=keepdims)
    elif ord == -1:
      if not keepdims and col_axis > row_axis:
        col_axis -= 1
      return jnp.amin(jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims),
                      axis=col_axis, keepdims=keepdims)
    elif ord == jnp.inf:
      if not keepdims and row_axis > col_axis:
        row_axis -= 1
      return jnp.amax(jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims),
                      axis=row_axis, keepdims=keepdims)
    elif ord == -jnp.inf:
      if not keepdims and row_axis > col_axis:
        row_axis -= 1
      return jnp.amin(jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims),
                      axis=row_axis, keepdims=keepdims)
    elif ord in ('nuc', 2, -2):
      x = jnp.moveaxis(x, axis, (-2, -1))
      if ord == 2:
        reducer = jnp.amax
      elif ord == -2:
        reducer = jnp.amin
      else:
        reducer = jnp.sum
      y = reducer(svd(x, compute_uv=False), axis=-1)
      if keepdims:
        result_shape = list(x_shape)
        result_shape[axis[0]] = 1
        result_shape[axis[1]] = 1
        y = jnp.reshape(y, result_shape)
      return y
    else:
      raise ValueError("Invalid order '{}' for matrix norm.".format(ord))
  else:
    raise ValueError(
        "Invalid axis values ({}) for jnp.linalg.norm.".format(axis))


@_wraps(np.linalg.norm)
def norm(x, ord=None, axis=None, keepdims=False):
  return _norm(x, ord, axis, keepdims)
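
Note (illustrative check, not part of the diff): a few standard identities exercised by the ord dispatch above; the matrix is diagonal so the answers are easy to verify by hand.

# Illustration only: Frobenius, spectral, and max-column-sum norms.
import numpy as np
import jax.numpy as jnp

x = jnp.array([[3.0, 0.0], [0.0, 4.0]])
assert np.allclose(jnp.linalg.norm(x), 5.0)          # sqrt(9 + 16)
assert np.allclose(jnp.linalg.norm(x, ord=2), 4.0)   # largest singular value
assert np.allclose(jnp.linalg.norm(x, ord=1), 4.0)   # max column abs sum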

@_wraps(np.linalg.qr)
def qr(a, mode="reduced"):
  if mode in ("reduced", "r", "full"):
    full_matrices = False
  elif mode == "complete":
    full_matrices = True
  else:
    raise ValueError("Unsupported QR decomposition mode '{}'".format(mode))
  a = _promote_arg_dtypes(jnp.asarray(a))
  q, r = lax_linalg.qr(a, full_matrices)
  if mode == "r":
    return r
  return q, r


def _check_solve_shapes(a, b):
  if not (a.ndim >= 2 and a.shape[-1] == a.shape[-2] and b.ndim >= 1):
    msg = ("The arguments to solve must have shapes a=[..., m, m] and "
           "b=[..., m, k] or b=[..., m]; got a={} and b={}")
    raise ValueError(msg.format(a.shape, b.shape))


@partial(vectorize, signature='(n,m),(m)->(n)')
def _matvec_multiply(a, b):
  return jnp.dot(a, b, precision=lax.Precision.HIGHEST)


@_wraps(np.linalg.solve)
@jit
def solve(a, b):
  a, b = _promote_arg_dtypes(jnp.asarray(a), jnp.asarray(b))
  _check_solve_shapes(a, b)

  # With custom_linear_solve, we can reuse the same factorization when
  # computing sensitivities. This is considerably faster.
  lu, _, permutation = lax_linalg.lu(lax.stop_gradient(a))
  custom_solve = partial(
      lax.custom_linear_solve,
      lambda x: _matvec_multiply(a, x),
      solve=lambda _, x: lax_linalg.lu_solve(lu, permutation, x, trans=0),
      transpose_solve=lambda _, x: lax_linalg.lu_solve(lu, permutation, x,
                                                       trans=1))
  if a.ndim == b.ndim + 1:
    # b.shape == [..., m]
    return custom_solve(b)
  else:
    # b.shape == [..., m, k]
    return vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
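
Note (illustrative check, not part of the diff): because solve is expressed with lax.custom_linear_solve, the backward pass reuses the forward LU factorization via transpose_solve rather than re-factoring a.

# Illustration only: solving and differentiating through solve.
import numpy as np
import jax
import jax.numpy as jnp

a = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([9.0, 8.0])
x = jnp.linalg.solve(a, b)
assert np.allclose(a @ x, b, atol=1e-5)

loss = lambda b: jnp.sum(jnp.linalg.solve(a, b))
g = jax.grad(loss)(b)               # gradient is a^{-T} @ ones
assert np.allclose(a.T @ g, np.ones(2), atol=1e-5)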

@_wraps(np.linalg.lstsq, lax_description=textwrap.dedent("""\
    It has two important differences:

    1. In `numpy.linalg.lstsq`, the default `rcond` is `-1`, and warns that in the future
       the default will be `None`. Here, the default rcond is `None`.
    2. In `np.linalg.lstsq` the returned residuals are empty for low-rank or over-determined
       solutions. Here, the residuals are returned in all cases, to make the function
       compatible with jit. The non-jit compatible numpy behavior can be recovered by
       passing numpy_resid=True.

    The lstsq function does not currently have a custom JVP rule, so the gradient is
    poorly behaved for some inputs, particularly for low-rank `a`.
    """))
def lstsq(a, b, rcond=None, *, numpy_resid=False):
  # TODO: add lstsq to lax_linalg and implement this function via those wrappers.
  # TODO: add custom jvp rule for more robust lstsq differentiation
  a, b = _promote_arg_dtypes(a, b)
  if a.shape[0] != b.shape[0]:
    raise ValueError("Leading dimensions of input arrays must match")
  b_orig_ndim = b.ndim
  if b_orig_ndim == 1:
    b = b[:, None]
  if a.ndim != 2:
    raise TypeError(
        f"{a.ndim}-dimensional array given. Array must be two-dimensional")
  if b.ndim != 2:
    raise TypeError(
        f"{b.ndim}-dimensional array given. Array must be one or two-dimensional")
  m, n = a.shape
  dtype = a.dtype
  if rcond is None:
    rcond = jnp.finfo(dtype).eps * max(n, m)
  elif rcond < 0:
    rcond = jnp.finfo(dtype).eps
  u, s, vt = svd(a, full_matrices=False)
  mask = s >= rcond * s[0]
  rank = mask.sum()
  safe_s = jnp.where(mask, s, 1)
  s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis]
  uTb = jnp.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST)
  x = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST)
  # Numpy returns empty residuals in some cases. To allow compilation, we
  # default to returning full residuals in all cases.
  if numpy_resid and (rank < n or m <= n):
    resid = jnp.asarray([])
  else:
    b_estimate = jnp.matmul(a, x, precision=lax.Precision.HIGHEST)
    resid = norm(b - b_estimate, axis=0) ** 2
  if b_orig_ndim == 1:
    x = x.ravel()
  return x, resid, rank, s
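
Note (illustrative check, not part of the diff): an overdetermined system; unlike numpy, residuals are returned here even when the problem is low rank, as documented above. The solution satisfies the normal equations.

# Illustration only: least squares for a 3x2 system.
import numpy as np
import jax.numpy as jnp

a = jnp.array([[1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])
b = jnp.array([6.0, 0.0, 0.0])
x, resid, rank, s = jnp.linalg.lstsq(a, b)
assert np.allclose(a.T @ (a @ x - b), 0.0, atol=1e-3)   # normal equations
assert int(rank) == 2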
@@ -14,15 +14,13 @@


 import numpy as np
-from .. import lax
+from jax import lax
 from . import lax_numpy as jnp
 
 from jax import jit
-from ._util import _wraps
-from .lax_numpy import _not_implemented
+from .util import _wraps
 from .linalg import eigvals as _eigvals
-from .. import ops as jaxops
-from ..util import get_module_functions
+from jax import ops as jaxops
 
 
 def _to_inexact_type(type):
@@ -104,10 +102,3 @@ def roots(p, *, strip_zeros=True):
   # combine roots and zero roots
   roots = jnp.hstack((roots, jnp.zeros(trailing_zeros, p.dtype)))
   return roots
-
-
-_NOT_IMPLEMENTED = []
-for name, func in get_module_functions(np.polynomial).items():
-  if name not in globals():
-    _NOT_IMPLEMENTED.append(name)
-    globals()[name] = _not_implemented(func)
@@ -11,14 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 import functools
 import re
 from typing import Any, Callable, Dict, List, Tuple
 
-from .. import api
-from .. import lax
+from jax import api
+from jax import lax
 from . import lax_numpy as jnp
-from ..util import safe_map as map, safe_zip as zip
+from jax.util import safe_map as map, safe_zip as zip
 
 
 # See http://docs.scipy.org/doc/numpy/reference/c-api.generalized-ufuncs.html
@@ -22,9 +22,9 @@ from jax import jit, vmap
 from jax import api
 from jax import lax
 from jax import lax_linalg
-from jax.numpy._util import _wraps
-from jax.numpy import lax_numpy as jnp
-from jax.numpy import linalg as np_linalg
+from jax._src.numpy.util import _wraps
+from jax._src.numpy import lax_numpy as jnp
+from jax._src.numpy import linalg as np_linalg
 
 _T = lambda x: jnp.swapaxes(x, -1, -2)
 
@@ -22,8 +22,8 @@ import scipy.ndimage
 
 from jax import api
 from jax import lax
-from jax.numpy import lax_numpy as jnp
-from jax.numpy._util import _wraps
+from jax._src.numpy import lax_numpy as jnp
+from jax._src.numpy.util import _wraps
 from jax.util import safe_zip as zip
 
 
@@ -18,10 +18,10 @@ import warnings
 import numpy as np
 
 from jax import lax
-from jax.numpy import lax_numpy as jnp
-from jax.numpy import linalg
-from jax.numpy.lax_numpy import _promote_dtypes_inexact
-from jax.numpy._util import _wraps
+from jax._src.numpy import lax_numpy as jnp
+from jax._src.numpy import linalg
+from jax._src.numpy.lax_numpy import _promote_dtypes_inexact
+from jax._src.numpy.util import _wraps
 
 
 # Note: we do not re-use the code from jax.numpy.convolve here, because the handling
@@ -20,10 +20,10 @@ import scipy.special as osp_special
 from jax import lax
 from jax import api
 from jax.interpreters import ad
-from jax.numpy import lax_numpy as jnp
-from jax.numpy.lax_numpy import (asarray, _reduction_dims, _constant_like,
-                                 _promote_args_inexact)
-from jax.numpy._util import _wraps
+from jax._src.numpy import lax_numpy as jnp
+from jax._src.numpy.lax_numpy import (asarray, _reduction_dims, _constant_like,
+                                      _promote_args_inexact)
+from jax._src.numpy.util import _wraps
 
 
 @_wraps(osp_special.gammaln)
@@ -16,8 +16,8 @@
 import scipy.stats as osp_stats
 
 from jax import lax
-from jax.numpy import lax_numpy as jnp
-from jax.numpy._util import _wraps
+from jax._src.numpy import lax_numpy as jnp
+from jax._src.numpy.util import _wraps
 from jax.scipy.special import xlogy, xlog1py
 
 
@@ -15,9 +15,9 @@
 import scipy.stats as osp_stats
 
 from jax import lax
-from jax.numpy._util import _wraps
-from jax.numpy.lax_numpy import (_promote_args_inexact, _constant_like,
-                                 where, inf, logical_or)
+from jax._src.numpy.util import _wraps
+from jax._src.numpy.lax_numpy import (_promote_args_inexact, _constant_like,
+                                      where, inf, logical_or)
 from jax.scipy.special import betaln
 
 
@@ -17,8 +17,8 @@ import numpy as np
 import scipy.stats as osp_stats
 
 from jax import lax
-from jax.numpy._util import _wraps
-from jax.numpy.lax_numpy import _promote_args_inexact, _constant_like
+from jax._src.numpy.util import _wraps
+from jax._src.numpy.lax_numpy import _promote_args_inexact, _constant_like
 
 
 @_wraps(osp_stats.cauchy.logpdf, update_doc=False)
@@ -17,8 +17,8 @@ import numpy as np
 import scipy.stats as osp_stats
 
 from jax import lax
-from jax.numpy import lax_numpy as jnp
-from jax.numpy._util import _wraps
+from jax._src.numpy import lax_numpy as jnp
+from jax._src.numpy.util import _wraps
 from jax.scipy.special import gammaln, xlogy
 
 
@@ -15,8 +15,8 @@
 import scipy.stats as osp_stats
 
 from jax import lax
-from jax.numpy._util import _wraps
-from jax.numpy.lax_numpy import _promote_args_inexact, where, inf
+from jax._src.numpy.util import _wraps
+from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf
 
 
 @_wraps(osp_stats.expon.logpdf, update_doc=False)
@@ -15,9 +15,9 @@
 import scipy.stats as osp_stats
 
 from jax import lax
-from jax.numpy._util import _wraps
-from jax.numpy.lax_numpy import (_promote_args_inexact, _constant_like,
-                                 where, inf)
+from jax._src.numpy.util import _wraps
+from jax._src.numpy.lax_numpy import (_promote_args_inexact, _constant_like,
+                                      where, inf)
 from jax.scipy.special import gammaln
 
 
@@ -15,8 +15,8 @@
 import scipy.stats as osp_stats
 
 from jax import lax
-from jax.numpy import lax_numpy as jnp
-from jax.numpy._util import _wraps
+from jax._src.numpy import lax_numpy as jnp
+from jax._src.numpy.util import _wraps
 from jax.scipy.special import xlog1py
 
 @_wraps(osp_stats.geom.logpmf, update_doc=False)
@@ -15,8 +15,8 @@
 import scipy.stats as osp_stats
 
 from jax import lax
-from jax.numpy._util import _wraps
-from jax.numpy.lax_numpy import _promote_args_inexact, _constant_like
+from jax._src.numpy.util import _wraps
+from jax._src.numpy.lax_numpy import _promote_args_inexact, _constant_like
 
 
 @_wraps(osp_stats.laplace.logpdf, update_doc=False)
@@ -16,7 +16,7 @@ import scipy.stats as osp_stats
 from jax.scipy.special import expit, logit
 
 from jax import lax
-from jax.numpy._util import _wraps
+from jax._src.numpy.util import _wraps
 
 
 @_wraps(osp_stats.logistic.logpdf, update_doc=False)
@@ -19,8 +19,8 @@ import scipy.stats as osp_stats
 from jax import lax
 from jax.lax_linalg import cholesky, triangular_solve
 from jax import numpy as jnp
-from jax.numpy._util import _wraps
-from jax.numpy.lax_numpy import _promote_dtypes_inexact
+from jax._src.numpy.util import _wraps
+from jax._src.numpy.lax_numpy import _promote_dtypes_inexact
 
 
 @_wraps(osp_stats.multivariate_normal.logpdf, update_doc=False)
@@ -17,9 +17,9 @@ import numpy as np
 import scipy.stats as osp_stats
 
 from jax import lax
-from jax.numpy import lax_numpy as jnp
-from jax.numpy._util import _wraps
-from jax.numpy.lax_numpy import _promote_args_inexact, _constant_like
+from jax._src.numpy import lax_numpy as jnp
+from jax._src.numpy.util import _wraps
+from jax._src.numpy.lax_numpy import _promote_args_inexact, _constant_like
 from jax.scipy import special
 
 @_wraps(osp_stats.norm.logpdf, update_doc=False)
@@ -16,8 +16,8 @@
 import scipy.stats as osp_stats
 
 from jax import lax
-from jax.numpy._util import _wraps
-from jax.numpy.lax_numpy import _promote_args_inexact, _constant_like, inf, where
+from jax._src.numpy.util import _wraps
+from jax._src.numpy.lax_numpy import _promote_args_inexact, _constant_like, inf, where
 
 
 @_wraps(osp_stats.pareto.logpdf, update_doc=False)
@@ -16,8 +16,8 @@
 import scipy.stats as osp_stats
 
 from jax import lax
-from jax.numpy._util import _wraps
-from jax.numpy import lax_numpy as jnp
+from jax._src.numpy.util import _wraps
+from jax._src.numpy import lax_numpy as jnp
 from jax.scipy.special import xlogy, gammaln
 
 
@@ -17,8 +17,8 @@ import numpy as np
 import scipy.stats as osp_stats
 
 from jax import lax
-from jax.numpy._util import _wraps
-from jax.numpy.lax_numpy import _promote_args_inexact, _constant_like
+from jax._src.numpy.util import _wraps
+from jax._src.numpy.lax_numpy import _promote_args_inexact, _constant_like
 
 
 @_wraps(osp_stats.t.logpdf, update_doc=False)
@@ -16,8 +16,8 @@
 import scipy.stats as osp_stats
 
 from jax import lax
-from jax.numpy._util import _wraps
-from jax.numpy.lax_numpy import _promote_args_inexact, where, inf, logical_or
+from jax._src.numpy.util import _wraps
+from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf, logical_or
 
 
 @_wraps(osp_stats.uniform.logpdf, update_doc=False)
@@ -35,7 +35,7 @@ from jax.util import partial, unzip2, prod
 from jax.lib import xla_client as xc
 from jax.lib import xla_bridge as xb
 from jax.config import config
-from jax.numpy import lax_numpy
+from jax._src.numpy import lax_numpy
 
 xops = xc.ops
 
@@ -15,8 +15,8 @@
 
 import numpy as np
 
-from jax.numpy import lax_numpy as jnp
-from jax.numpy.vectorize import vectorize
+from jax._src.numpy import lax_numpy as jnp
+from jax._src.numpy.vectorize import vectorize
 from jax import ad_util
 from jax import api
 from jax import lax
@@ -16,9 +16,9 @@
 from . import fft
 from . import linalg
 
-from ..interpreters.xla import DeviceArray
+from jax.interpreters.xla import DeviceArray
 
-from .lax_numpy import (
+from jax._src.numpy.lax_numpy import (
   ComplexWarning, NINF, NZERO, PZERO, abs, absolute, add, all, allclose,
   alltrue, amax, amin, angle, any, append,
   apply_along_axis, apply_over_axes, arange, arccos, arccosh, arcsin,
@@ -63,16 +63,18 @@ from .lax_numpy import (
   unpackbits, unravel_index, unsignedinteger, unwrap, vander, var, vdot, vsplit,
   vstack, where, zeros, zeros_like, _NOT_IMPLEMENTED)
 
-from .polynomial import roots
-from .vectorize import vectorize
+from jax._src.numpy.polynomial import roots
+from jax._src.numpy.vectorize import vectorize
 
+# TODO(phawkins): remove this import after fixing users.
+from jax._src.numpy import lax_numpy
 
 # Module initialization is encapsulated in a function to avoid accidental
 # namespace pollution.
 def _init():
   import numpy as np
-  from . import lax_numpy
-  from .. import util
+  from jax._src.numpy import lax_numpy
+  from jax import util
   # Builds a set of all unimplemented NumPy functions.
   for name, func in util.get_module_functions(np).items():
     if name not in globals():
jax/numpy/fft.py (282 lines)
@@ -1,4 +1,4 @@
-# Copyright 2018 Google LLC
+# Copyright 2020 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,251 +12,41 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# flake8: noqa: F401
 
-import numpy as np
-
-from .. import lax
-from ..lib import xla_client
-from ..util import get_module_functions
-from .lax_numpy import _not_implemented
-from ._util import _wraps
-from . import lax_numpy as jnp
-from .. import ops as jaxops
-
-
-[... remaining removed lines omitted: identical to the module body added in jax/_src/numpy/fft.py above ...]
+from jax._src.numpy.fft import (
+  ifft,
+  ifft2,
+  ifftn,
+  ifftshift,
+  ihfft,
+  irfft,
+  irfft2,
+  irfftn,
+  fft,
+  fft2,
+  fftfreq,
+  fftn,
+  fftshift,
+  hfft,
+  rfft,
+  rfft2,
+  rfftfreq,
+  rfftn,
+)
+
+# Module initialization is encapsulated in a function to avoid accidental
+# namespace pollution.
 _NOT_IMPLEMENTED = []
-for name, func in get_module_functions(np.fft).items():
-  if name not in globals():
-    _NOT_IMPLEMENTED.append(name)
-    globals()[name] = _not_implemented(func)
+def _init():
+  import numpy as np
+  from jax._src.numpy import lax_numpy
+  from jax import util
+  # Builds a set of all unimplemented NumPy functions.
+  for name, func in util.get_module_functions(np.fft).items():
+    if name not in globals():
+      _NOT_IMPLEMENTED.append(name)
+      globals()[name] = lax_numpy._not_implemented(func)
+
+_init()
+del _init
@@ -1,4 +1,4 @@
-# Copyright 2018 Google LLC
+# Copyright 2020 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,527 +12,43 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-from functools import partial
-
-import numpy as np
-import textwrap
-import operator
-from typing import Tuple, Union, cast
-
-from jax import jit, vmap, custom_jvp
-from .. import lax
-from .. import ops
-from .. import lax_linalg
-from .. import dtypes
-from .lax_numpy import _not_implemented
-from ._util import _wraps
-from .vectorize import vectorize
-from . import lax_numpy as jnp
-from ..util import get_module_functions, canonicalize_axis
-from ..third_party.numpy.linalg import cond, multi_dot, tensorinv, tensorsolve  # noqa: F401
-
_T = lambda x: jnp.swapaxes(x, -1, -2)
|
||||
_H = lambda x: jnp.conj(jnp.swapaxes(x, -1, -2))
|
||||
|
||||
|
||||
def _promote_arg_dtypes(*args):
|
||||
"""Promotes `args` to a common inexact type."""
|
||||
def _to_inexact_type(type):
|
||||
return type if jnp.issubdtype(type, jnp.inexact) else jnp.float_
|
||||
inexact_types = [_to_inexact_type(jnp._dtype(arg)) for arg in args]
|
||||
dtype = dtypes.canonicalize_dtype(jnp.result_type(*inexact_types))
|
||||
args = [lax.convert_element_type(arg, dtype) for arg in args]
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
else:
|
||||
return args
|
||||
|
||||
|
||||
@_wraps(np.linalg.cholesky)
|
||||
def cholesky(a):
|
||||
a = _promote_arg_dtypes(jnp.asarray(a))
|
||||
return lax_linalg.cholesky(a)
|
||||
|
||||
|
||||
@_wraps(np.linalg.svd)
|
||||
def svd(a, full_matrices=True, compute_uv=True):
|
||||
a = _promote_arg_dtypes(jnp.asarray(a))
|
||||
return lax_linalg.svd(a, full_matrices, compute_uv)


@_wraps(np.linalg.matrix_power)
def matrix_power(a, n):
  a = _promote_arg_dtypes(jnp.asarray(a))

  if a.ndim < 2:
    raise TypeError("{}-dimensional array given. Array must be at least "
                    "two-dimensional".format(a.ndim))
  if a.shape[-2] != a.shape[-1]:
    raise TypeError("Last 2 dimensions of the array must be square")
  try:
    n = operator.index(n)
  except TypeError as err:
    raise TypeError("exponent must be an integer, got {}".format(n)) from err

  if n == 0:
    return jnp.broadcast_to(jnp.eye(a.shape[-2], dtype=a.dtype), a.shape)
  elif n < 0:
    a = inv(a)
    n = np.abs(n)

  if n == 1:
    return a
  elif n == 2:
    return a @ a
  elif n == 3:
    return (a @ a) @ a

  z = result = None
  while n > 0:
    z = a if z is None else (z @ z)
    n, bit = divmod(n, 2)
    if bit:
      result = z if result is None else (result @ z)

  return result
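
# The loop above is exponentiation by squaring: z walks through a, a**2,
# a**4, ... and is folded into result whenever the corresponding bit of n is
# set, so a**13 costs 5 products rather than 12. A scalar sketch of the same
# recurrence (an editor's example, assuming nothing beyond the loop above):
def _pow_by_squaring(a, n):
  z = result = None
  while n > 0:
    z = a if z is None else z * z       # z == a**(2**k) on iteration k
    n, bit = divmod(n, 2)
    if bit:
      result = z if result is None else result * z
  return result

assert _pow_by_squaring(3, 13) == 3 ** 13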


@_wraps(np.linalg.matrix_rank)
def matrix_rank(M, tol=None):
  M = _promote_arg_dtypes(jnp.asarray(M))
  if M.ndim > 2:
    raise TypeError("array should have 2 or fewer dimensions")
  if M.ndim < 2:
    return jnp.any(M != 0).astype(jnp.int32)
  S = svd(M, full_matrices=False, compute_uv=False)
  if tol is None:
    tol = S.max() * np.max(M.shape) * jnp.finfo(S.dtype).eps
  return jnp.sum(S > tol)
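
# A hedged check of the default tolerance, which mirrors NumPy's
# S.max() * max(M.shape) * eps rule (an editor's example):
import jax.numpy as jnp

m = jnp.array([[1., 2.], [2., 4.], [3., 6.]])   # second column is 2x the first
assert int(jnp.linalg.matrix_rank(m)) == 1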


@custom_jvp
@_wraps(np.linalg.slogdet)
@jit
def slogdet(a):
  a = _promote_arg_dtypes(jnp.asarray(a))
  dtype = lax.dtype(a)
  a_shape = jnp.shape(a)
  if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
    msg = "Argument to slogdet() must have shape [..., n, n], got {}"
    raise ValueError(msg.format(a_shape))
  lu, pivot, _ = lax_linalg.lu(a)
  diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
  is_zero = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1)
  parity = jnp.count_nonzero(pivot != jnp.arange(a_shape[-1]), axis=-1)
  if jnp.iscomplexobj(a):
    sign = jnp.prod(diag / jnp.abs(diag), axis=-1)
  else:
    sign = jnp.array(1, dtype=dtype)
    parity = parity + jnp.count_nonzero(diag < 0, axis=-1)
  sign = jnp.where(is_zero,
                   jnp.array(0, dtype=dtype),
                   sign * jnp.array(-2 * (parity % 2) + 1, dtype=dtype))
  logdet = jnp.where(
      is_zero, jnp.array(-jnp.inf, dtype=dtype),
      jnp.sum(jnp.log(jnp.abs(diag)), axis=-1))
  return sign, jnp.real(logdet)
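
# Since a = P @ L @ U with unit-diagonal L, det(a) is prod(diag(U)) times the
# permutation sign (-1)**parity; slogdet splits that product into a sign and
# a sum of logs to avoid overflow. A hedged numeric cross-check (an editor's
# example):
import numpy as np
import jax.numpy as jnp

a = np.array([[0., 2.], [3., 4.]])     # pivoting makes the parity nontrivial
sign, logabsdet = jnp.linalg.slogdet(a)
np_sign, np_logabsdet = np.linalg.slogdet(a)
assert np.isclose(float(sign), np_sign)            # -1, since det(a) == -6
assert np.isclose(float(logabsdet), np_logabsdet)  # log(6)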

@slogdet.defjvp
def _slogdet_jvp(primals, tangents):
  x, = primals
  g, = tangents
  if jnp.issubdtype(jnp._dtype(x), jnp.complexfloating):
    raise NotImplementedError  # TODO(pfau): make this work for complex types
  sign, ans = slogdet(x)
  sign_dot, ans_dot = jnp.zeros_like(sign), jnp.trace(solve(x, g), axis1=-1, axis2=-2)
  return (sign, ans), (sign_dot, ans_dot)


def _cofactor_solve(a, b):
  """Equivalent to det(a)*solve(a, b) for nonsingular mat.

  Intermediate function used for jvp and vjp of det.
  This function borrows heavily from jax.numpy.linalg.solve and
  jax.numpy.linalg.slogdet to compute the gradient of the determinant
  in a way that is well defined even for low rank matrices.

  This function handles two different cases:
  * rank(a) == n or n-1
  * rank(a) < n-1

  For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix.
  Rather than computing det(a)*solve(a, b), which would return NaN, we work
  directly with the LU decomposition. If a = p @ l @ u, then
  det(a)*solve(a, b) =
    prod(diag(u)) * u^-1 @ l^-1 @ p^-1 b =
    prod(diag(u)) * triangular_solve(u, solve(p @ l, b))
  If a is rank n-1, then the lower right corner of u will be zero and the
  triangular_solve will fail.
  Let x = solve(p @ l, b) and y = det(a)*solve(a, b).
  Then y_{n} =
    x_{n} / u_{nn} * prod_{i=1...n}(u_{ii}) =
    x_{n} * prod_{i=1...n-1}(u_{ii})
  So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1
  we can avoid the triangular_solve failing.
  To correctly compute the rest of y_{i} for i != n, we simply multiply
  x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1.

  For the second case, a check is done on the matrix to see if `solve`
  returns NaN or Inf, and gives a matrix of zeros as a result, as the
  gradient of the determinant of a matrix with rank less than n-1 is 0.
  This will still return the correct value for rank n-1 matrices, as the check
  is applied *after* the lower right corner of u has been updated.

  Args:
    a: A square matrix or batch of matrices, possibly singular.
    b: A matrix, or batch of matrices of the same dimension as a.

  Returns:
    det(a) and cofactor(a)^T*b, aka adjugate(a)*b
  """
  a = _promote_arg_dtypes(jnp.asarray(a))
  b = _promote_arg_dtypes(jnp.asarray(b))
  a_shape = jnp.shape(a)
  b_shape = jnp.shape(b)
  a_ndims = len(a_shape)
  if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
          and b_shape[-2:] == a_shape[-2:]):
    msg = ("The arguments to _cofactor_solve must have shapes "
           "a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
    raise ValueError(msg.format(a_shape, b_shape))
  if a_shape[-1] == 1:
    return a[0, 0], b
  # lu contains u in the upper triangular matrix and l in the strict lower
  # triangular matrix.
  # The diagonal of l is set to ones without loss of generality.
  lu, pivots, permutation = lax_linalg.lu(a)
  dtype = lax.dtype(a)
  batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
  x = jnp.broadcast_to(b, batch_dims + b.shape[-2:])
  lu = jnp.broadcast_to(lu, batch_dims + lu.shape[-2:])
  # Compute (partial) determinant, ignoring last diagonal of LU
  diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
  parity = jnp.count_nonzero(pivots != jnp.arange(a_shape[-1]), axis=-1)
  sign = jnp.array(-2 * (parity % 2) + 1, dtype=dtype)
  # partial_det[:, -1] contains the full determinant and
  # partial_det[:, -2] contains det(u) / u_{nn}.
  partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None]
  lu = ops.index_update(lu, ops.index[..., -1, -1], 1.0 / partial_det[..., -2])
  permutation = jnp.broadcast_to(permutation, batch_dims + (a_shape[-1],))
  iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1,)))
  # filter out any matrices that are not full rank
  d = jnp.ones(x.shape[:-1], x.dtype)
  d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
  d = jnp.any(jnp.logical_or(jnp.isnan(d), jnp.isinf(d)), axis=-1)
  d = jnp.tile(d[..., None, None], d.ndim*(1,) + x.shape[-2:])
  x = jnp.where(d, jnp.zeros_like(x), x)  # first filter
  x = x[iotas[:-1] + (permutation, slice(None))]
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=True,
                                  unit_diagonal=True)
  x = jnp.concatenate((x[..., :-1, :] * partial_det[..., -1, None, None],
                       x[..., -1:, :]), axis=-2)
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
  x = jnp.where(d, jnp.zeros_like(x), x)  # second filter

  return partial_det[..., -1], x


@custom_jvp
@_wraps(np.linalg.det)
def det(a):
  sign, logdet = slogdet(a)
  return sign * jnp.exp(logdet)


@det.defjvp
def _det_jvp(primals, tangents):
  x, = primals
  g, = tangents
  y, z = _cofactor_solve(x, g)
  return y, jnp.trace(z, axis1=-1, axis2=-2)
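
# This is Jacobi's formula: d det(x)[g] = tr(adjugate(x) @ g), with
# _cofactor_solve supplying the adjugate in a form that stays finite for
# rank n-1 inputs. A hedged consistency check for a full-rank x (an editor's
# example):
import numpy as np
import jax
import jax.numpy as jnp

x = jnp.array([[2., 1.], [1., 3.]])
g = jnp.array([[0.1, 0.], [0., -0.2]])
_, det_dot = jax.jvp(jnp.linalg.det, (x,), (g,))
expected = jnp.linalg.det(x) * jnp.trace(jnp.linalg.solve(x, g))
assert np.isclose(float(det_dot), float(expected), rtol=1e-5)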


@_wraps(np.linalg.eig)
def eig(a):
  a = _promote_arg_dtypes(jnp.asarray(a))
  return lax_linalg.eig(a, compute_left_eigenvectors=False)


@_wraps(np.linalg.eigvals)
def eigvals(a):
  return lax_linalg.eig(a, compute_left_eigenvectors=False,
                        compute_right_eigenvectors=False)[0]


@_wraps(np.linalg.eigh)
def eigh(a, UPLO=None, symmetrize_input=True):
  if UPLO is None or UPLO == "L":
    lower = True
  elif UPLO == "U":
    lower = False
  else:
    msg = "UPLO must be one of None, 'L', or 'U', got {}".format(UPLO)
    raise ValueError(msg)

  a = _promote_arg_dtypes(jnp.asarray(a))
  v, w = lax_linalg.eigh(a, lower=lower, symmetrize_input=symmetrize_input)
  return w, v


@_wraps(np.linalg.eigvalsh)
def eigvalsh(a, UPLO='L'):
  w, _ = eigh(a, UPLO)
  return w


@partial(custom_jvp, nondiff_argnums=(1,))
@_wraps(np.linalg.pinv, lax_description=textwrap.dedent("""\
    It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
    default `rcond` is `1e-15`. Here the default is
    `10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps`.
    """))
def pinv(a, rcond=None):
  # Uses same algorithm as
  # https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
  a = jnp.conj(a)
  if rcond is None:
    max_rows_cols = max(a.shape[-2:])
    rcond = 10. * max_rows_cols * jnp.finfo(a.dtype).eps
  rcond = jnp.asarray(rcond)
  u, s, v = svd(a, full_matrices=False)
  # Singular values less than or equal to ``rcond * largest_singular_value``
  # are set to zero.
  cutoff = rcond[..., jnp.newaxis] * jnp.amax(s, axis=-1, keepdims=True)
  s = jnp.where(s > cutoff, s, jnp.inf)
  res = jnp.matmul(_T(v), jnp.divide(_T(u), s[..., jnp.newaxis]))
  return lax.convert_element_type(res, a.dtype)
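
# Filtered singular values are set to inf (not zero) so their reciprocals in
# the divide become exactly 0 without any branching, which keeps pinv
# jit-friendly. A hedged check of the Moore-Penrose conditions (an editor's
# example):
import numpy as np
import jax.numpy as jnp

a = jnp.array([[1., 0.], [0., 0.], [0., 0.]])   # rank-1, shape (3, 2)
p = jnp.linalg.pinv(a)
assert np.allclose(a @ p @ a, a, atol=1e-6)
assert np.allclose(p @ a @ p, p, atol=1e-6)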


@pinv.defjvp
def _pinv_jvp(rcond, primals, tangents):
  # The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems
  # Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM
  # Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432.
  # (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative)
  a, = primals
  a_dot, = tangents
  p = pinv(a, rcond=rcond)
  m, n = a.shape[-2:]
  # TODO(phawkins): on TPU, we would need to opt into high precision here.
  # TODO(phawkins): consider if this can be simplified in the Hermitian case.
  p_dot = -p @ a_dot @ p
  p_dot = p_dot + p @ _H(p) @ _H(a_dot) @ (jnp.eye(m, dtype=a.dtype) - a @ p)
  p_dot = p_dot + (jnp.eye(n, dtype=a.dtype) - p @ a) @ _H(a_dot) @ _H(p) @ p
  return p, p_dot


@_wraps(np.linalg.inv)
def inv(a):
  if jnp.ndim(a) < 2 or a.shape[-1] != a.shape[-2]:
    raise ValueError("Argument to inv must have shape [..., n, n], got {}."
                     .format(jnp.shape(a)))
  return solve(
      a, lax.broadcast(jnp.eye(a.shape[-1], dtype=lax.dtype(a)), a.shape[:-2]))


@partial(jit, static_argnums=(1, 2, 3))
def _norm(x, ord, axis: Union[None, Tuple[int, ...], int], keepdims):
  x = _promote_arg_dtypes(jnp.asarray(x))
  x_shape = jnp.shape(x)
  ndim = len(x_shape)

  if axis is None:
    # NumPy has an undocumented behavior that admits arbitrary rank inputs if
    # `ord` is None: https://github.com/numpy/numpy/issues/14215
    if ord is None:
      return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), keepdims=keepdims))
    axis = tuple(range(ndim))
  elif isinstance(axis, tuple):
    axis = tuple(canonicalize_axis(ax, ndim) for ax in axis)
  else:
    axis = (canonicalize_axis(axis, ndim),)

  num_axes = len(axis)
  if num_axes == 1:
    if ord is None or ord == 2:
      return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis,
                              keepdims=keepdims))
    elif ord == jnp.inf:
      return jnp.amax(jnp.abs(x), axis=axis, keepdims=keepdims)
    elif ord == -jnp.inf:
      return jnp.amin(jnp.abs(x), axis=axis, keepdims=keepdims)
    elif ord == 0:
      return jnp.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype,
                     axis=axis, keepdims=keepdims)
    elif ord == 1:
      # Numpy has a special case for ord == 1 as an optimization. We don't
      # really need the optimization (XLA could do it for us), but the Numpy
      # code has slightly different type promotion semantics, so we need a
      # special case too.
      return jnp.sum(jnp.abs(x), axis=axis, keepdims=keepdims)
    else:
      abs_x = jnp.abs(x)
      ord = lax._const(abs_x, ord)
      out = jnp.sum(abs_x ** ord, axis=axis, keepdims=keepdims)
      return jnp.power(out, 1. / ord)

  elif num_axes == 2:
    row_axis, col_axis = cast(Tuple[int, ...], axis)
    if ord is None or ord in ('f', 'fro'):
      return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis,
                              keepdims=keepdims))
    elif ord == 1:
      if not keepdims and col_axis > row_axis:
        col_axis -= 1
      return jnp.amax(jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims),
                      axis=col_axis, keepdims=keepdims)
    elif ord == -1:
      if not keepdims and col_axis > row_axis:
        col_axis -= 1
      return jnp.amin(jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims),
                      axis=col_axis, keepdims=keepdims)
    elif ord == jnp.inf:
      if not keepdims and row_axis > col_axis:
        row_axis -= 1
      return jnp.amax(jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims),
                      axis=row_axis, keepdims=keepdims)
    elif ord == -jnp.inf:
      if not keepdims and row_axis > col_axis:
        row_axis -= 1
      return jnp.amin(jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims),
                      axis=row_axis, keepdims=keepdims)
    elif ord in ('nuc', 2, -2):
      x = jnp.moveaxis(x, axis, (-2, -1))
      if ord == 2:
        reducer = jnp.amax
      elif ord == -2:
        reducer = jnp.amin
      else:
        reducer = jnp.sum
      y = reducer(svd(x, compute_uv=False), axis=-1)
      if keepdims:
        result_shape = list(x_shape)
        result_shape[axis[0]] = 1
        result_shape[axis[1]] = 1
        y = jnp.reshape(y, result_shape)
      return y
    else:
      raise ValueError("Invalid order '{}' for matrix norm.".format(ord))
  else:
    raise ValueError(
        "Invalid axis values ({}) for jnp.linalg.norm.".format(axis))

@_wraps(np.linalg.norm)
def norm(x, ord=None, axis=None, keepdims=False):
  return _norm(x, ord, axis, keepdims)
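
# ord, axis, and keepdims are static arguments, so each combination compiles
# once. A few representative cases, as a hedged sketch (an editor's example):
import numpy as np
import jax.numpy as jnp

m = jnp.array([[1., -2.], [3., 4.]])
assert np.isclose(jnp.linalg.norm(m), np.sqrt(30.))     # Frobenius norm
assert np.isclose(jnp.linalg.norm(m, ord=1), 6.)        # max column abs-sum
assert np.isclose(jnp.linalg.norm(m, ord=np.inf), 7.)   # max row abs-sum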


@_wraps(np.linalg.qr)
def qr(a, mode="reduced"):
  if mode in ("reduced", "r", "full"):
    full_matrices = False
  elif mode == "complete":
    full_matrices = True
  else:
    raise ValueError("Unsupported QR decomposition mode '{}'".format(mode))
  a = _promote_arg_dtypes(jnp.asarray(a))
  q, r = lax_linalg.qr(a, full_matrices)
  if mode == "r":
    return r
  return q, r


def _check_solve_shapes(a, b):
  if not (a.ndim >= 2 and a.shape[-1] == a.shape[-2] and b.ndim >= 1):
    msg = ("The arguments to solve must have shapes a=[..., m, m] and "
           "b=[..., m, k] or b=[..., m]; got a={} and b={}")
    raise ValueError(msg.format(a.shape, b.shape))


@partial(vectorize, signature='(n,m),(m)->(n)')
def _matvec_multiply(a, b):
  return jnp.dot(a, b, precision=lax.Precision.HIGHEST)


@_wraps(np.linalg.solve)
@jit
def solve(a, b):
  a, b = _promote_arg_dtypes(jnp.asarray(a), jnp.asarray(b))
  _check_solve_shapes(a, b)

  # With custom_linear_solve, we can reuse the same factorization when
  # computing sensitivities. This is considerably faster.
  lu, _, permutation = lax_linalg.lu(lax.stop_gradient(a))
  custom_solve = partial(
      lax.custom_linear_solve,
      lambda x: _matvec_multiply(a, x),
      solve=lambda _, x: lax_linalg.lu_solve(lu, permutation, x, trans=0),
      transpose_solve=lambda _, x: lax_linalg.lu_solve(lu, permutation, x,
                                                       trans=1))
  if a.ndim == b.ndim + 1:
    # b.shape == [..., m]
    return custom_solve(b)
  else:
    # b.shape == [..., m, k]
    return vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
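
# custom_linear_solve ties the forward solve and its transpose to the one LU
# factorization computed above, so gradients of solve() re-trigger lu_solve
# rather than a second factorization. A hedged sketch (an editor's example):
import jax
import jax.numpy as jnp

a = jnp.array([[4., 1.], [1., 3.]])
b = jnp.array([1., 2.])

def loss(a, b):
  return jnp.sum(jnp.linalg.solve(a, b) ** 2)

# Both the primal solve and the transpose solve inside the gradient reuse
# the LU factors of a.
g_a, g_b = jax.grad(loss, argnums=(0, 1))(a, b)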


@_wraps(np.linalg.lstsq, lax_description=textwrap.dedent("""\
    It has two important differences:

    1. In `numpy.linalg.lstsq`, the default `rcond` is `-1`, and warns that in the future
       the default will be `None`. Here, the default rcond is `None`.
    2. In `np.linalg.lstsq` the returned residuals are empty for low-rank or
       under-determined solutions. Here, the residuals are returned in all cases, to make
       the function compatible with jit. The non-jit compatible numpy behavior can be
       recovered by passing numpy_resid=True.

    The lstsq function does not currently have a custom JVP rule, so the gradient is
    poorly behaved for some inputs, particularly for low-rank `a`.
    """))
def lstsq(a, b, rcond=None, *, numpy_resid=False):
  # TODO: add lstsq to lax_linalg and implement this function via those wrappers.
  # TODO: add custom jvp rule for more robust lstsq differentiation
  a, b = _promote_arg_dtypes(a, b)
  if a.shape[0] != b.shape[0]:
    raise ValueError("Leading dimensions of input arrays must match")
  b_orig_ndim = b.ndim
  if b_orig_ndim == 1:
    b = b[:, None]
  if a.ndim != 2:
    raise TypeError(
        f"{a.ndim}-dimensional array given. Array must be two-dimensional")
  if b.ndim != 2:
    raise TypeError(
        f"{b.ndim}-dimensional array given. Array must be one or two-dimensional")
  m, n = a.shape
  dtype = a.dtype
  if rcond is None:
    rcond = jnp.finfo(dtype).eps * max(n, m)
  elif rcond < 0:
    rcond = jnp.finfo(dtype).eps
  u, s, vt = svd(a, full_matrices=False)
  mask = s >= rcond * s[0]
  rank = mask.sum()
  safe_s = jnp.where(mask, s, 1)
  s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis]
  uTb = jnp.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST)
  x = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST)
  # Numpy returns empty residuals in some cases. To allow compilation, we
  # default to returning full residuals in all cases.
  if numpy_resid and (rank < n or m <= n):
    resid = jnp.asarray([])
  else:
    b_estimate = jnp.matmul(a, x, precision=lax.Precision.HIGHEST)
    resid = norm(b - b_estimate, axis=0) ** 2
  if b_orig_ndim == 1:
    x = x.ravel()
  return x, resid, rank, s
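
# A hedged usage sketch (an editor's example): a least-squares line fit
# through three points. With the default numpy_resid=False, residuals come
# back in all cases, which keeps the function jit-compatible.
import numpy as np
import jax.numpy as jnp

a = jnp.array([[1., 1.], [1., 2.], [1., 3.]])   # 3 equations, 2 unknowns
b = jnp.array([1., 2., 2.])
x, resid, rank, s = jnp.linalg.lstsq(a, b)
assert int(rank) == 2
assert np.allclose(x, [2/3, 1/2], atol=1e-5)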


# flake8: noqa: F401

from jax._src.numpy.linalg import (
  cholesky,
  cond,
  det,
  eig,
  eigh,
  eigvals,
  eigvalsh,
  inv,
  lstsq,
  matrix_power,
  matrix_rank,
  multi_dot,
  norm,
  pinv,
  qr,
  slogdet,
  solve,
  svd,
  tensorinv,
  tensorsolve,
)

# Module initialization is encapsulated in a function to avoid accidental
# namespace pollution.
_NOT_IMPLEMENTED = []
def _init():
  import numpy as np
  from jax._src.numpy import lax_numpy
  from jax import util
  # Builds a set of all unimplemented NumPy functions.
  for name, func in util.get_module_functions(np.linalg).items():
    if name not in globals():
      _NOT_IMPLEMENTED.append(name)
      globals()[name] = lax_numpy._not_implemented(func)

_init()
del _init

@@ -16,7 +16,7 @@
from .. import lax
from ..numpy import lax_numpy as jnp
from jax._src.numpy import lax_numpy as jnp
from .. import util

@@ -49,7 +49,7 @@ from . import lax
from . import numpy as jnp
from . import dtypes
from .api import jit, vmap
from .numpy.lax_numpy import _constant_like, asarray
from jax._src.numpy.lax_numpy import _constant_like, asarray
from jax.lib import xla_bridge
from jax.lib import xla_client
from jax.lib import cuda_prng

6
jax/third_party/numpy/linalg.py
vendored
@@ -1,8 +1,8 @@
import numpy as np

from jax.numpy import lax_numpy as jnp
from jax.numpy import linalg as la
from jax.numpy._util import _wraps
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import linalg as la
from jax._src.numpy.util import _wraps


def _isEmpty2d(arr):

@@ -23,7 +23,7 @@ from jax import lax
from jax import core
from jax import test_util as jtu
from jax.config import config
from jax.numpy.lax_numpy import _polymorphic_slice_indices
from jax._src.numpy.lax_numpy import _polymorphic_slice_indices
from jax.util import safe_map, safe_zip
from jax.tree_util import tree_flatten

@@ -34,12 +34,6 @@ all_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex

class TestPolynomial(jtu.JaxTestCase):

  def testNotImplemented(self):
    for name in jnp.polynomial._NOT_IMPLEMENTED:
      func = getattr(jnp.polynomial, name)
      with self.assertRaises(NotImplementedError):
        func()

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dtype={}_leading={}_trailing={}".format(
          jtu.format_shape_dtype_string((length+leading+trailing,), dtype),