2018-12-12 08:44:07 -05:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2018-12-11 12:44:02 -05:00
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
|
|
|
|
2019-08-07 09:21:07 -04:00
|
|
|
from functools import partial
|
Add Cholesky, QR, and Triangular solve implementations.
* Adds lax.{cholesky,triangular_solve,qr}. Adds a JVP for Cholesky.
* Adds a transpose rule for add_p, needed by the Cholesky JVP.
* Adds np.linalg.{cholesky,qr,dot,matmul,trace}.
* Adds scipy.linalg.{cholesky,qr,solve_triangular,tril,triu}.
Pair programmed with mattjj.
2018-12-13 13:03:08 -05:00
|
|
|
|
2018-12-11 12:44:02 -05:00
|
|
|
import numpy as onp
|
2018-12-17 17:39:46 -05:00
|
|
|
import warnings
|
2019-12-09 09:56:26 -08:00
|
|
|
import textwrap
|
2020-01-18 08:26:23 -05:00
|
|
|
from typing import Tuple, Union, cast
|
2018-12-11 12:44:02 -05:00
|
|
|
|
2019-08-07 09:21:07 -04:00
|
|
|
from jax import jit
|
2018-12-20 15:37:34 -05:00
|
|
|
from .. import lax
|
Add Cholesky, QR, and Triangular solve implementations.
* Adds lax.{cholesky,triangular_solve,qr}. Adds a JVP for Cholesky.
* Adds a transpose rule for add_p, needed by the Cholesky JVP.
* Adds np.linalg.{cholesky,qr,dot,matmul,trace}.
* Adds scipy.linalg.{cholesky,qr,solve_triangular,tril,triu}.
Pair programmed with mattjj.
2018-12-13 13:03:08 -05:00
|
|
|
from .. import lax_linalg
|
2019-11-15 10:02:51 -05:00
|
|
|
from .. import dtypes
|
2018-12-11 12:44:02 -05:00
|
|
|
from .lax_numpy import _not_implemented
|
Add Cholesky, QR, and Triangular solve implementations.
* Adds lax.{cholesky,triangular_solve,qr}. Adds a JVP for Cholesky.
* Adds a transpose rule for add_p, needed by the Cholesky JVP.
* Adds np.linalg.{cholesky,qr,dot,matmul,trace}.
* Adds scipy.linalg.{cholesky,qr,solve_triangular,tril,triu}.
Pair programmed with mattjj.
2018-12-13 13:03:08 -05:00
|
|
|
from .lax_numpy import _wraps
|
|
|
|
from . import lax_numpy as np
|
2019-09-05 15:22:36 +01:00
|
|
|
from ..api import custom_transforms, defjvp
|
Add Cholesky, QR, and Triangular solve implementations.
* Adds lax.{cholesky,triangular_solve,qr}. Adds a JVP for Cholesky.
* Adds a transpose rule for add_p, needed by the Cholesky JVP.
* Adds np.linalg.{cholesky,qr,dot,matmul,trace}.
* Adds scipy.linalg.{cholesky,qr,solve_triangular,tril,triu}.
Pair programmed with mattjj.
2018-12-13 13:03:08 -05:00
|
|
|
from ..util import get_module_functions
|
|
|
|
|
|
|
|
|
2018-12-13 19:28:05 -05:00
|
|
|
_T = lambda x: np.swapaxes(x, -1, -2)
|
Add Cholesky, QR, and Triangular solve implementations.
* Adds lax.{cholesky,triangular_solve,qr}. Adds a JVP for Cholesky.
* Adds a transpose rule for add_p, needed by the Cholesky JVP.
* Adds np.linalg.{cholesky,qr,dot,matmul,trace}.
* Adds scipy.linalg.{cholesky,qr,solve_triangular,tril,triu}.
Pair programmed with mattjj.
2018-12-13 13:03:08 -05:00
|
|
|
|
2019-01-07 18:10:08 -05:00
|
|
|
|
|
|
|
def _promote_arg_dtypes(*args):
  """Promotes `args` to a common inexact type.

  Each argument's dtype is kept if it is already inexact and replaced by
  `np.float_` otherwise; the canonicalized result type of those dtypes is then
  applied to every argument.

  Returns:
    The single converted value when exactly one argument is given, otherwise a
    list with one converted value per argument.
  """
  def _as_inexact(dt):
    if np.issubdtype(dt, np.inexact):
      return dt
    return np.float_

  candidates = [_as_inexact(np._dtype(arg)) for arg in args]
  common = dtypes.canonicalize_dtype(np.result_type(*candidates))
  converted = [lax.convert_element_type(arg, common) for arg in args]
  return converted[0] if len(converted) == 1 else converted
|
|
|
|
|
|
|
|
|
Add Cholesky, QR, and Triangular solve implementations.
* Adds lax.{cholesky,triangular_solve,qr}. Adds a JVP for Cholesky.
* Adds a transpose rule for add_p, needed by the Cholesky JVP.
* Adds np.linalg.{cholesky,qr,dot,matmul,trace}.
* Adds scipy.linalg.{cholesky,qr,solve_triangular,tril,triu}.
Pair programmed with mattjj.
2018-12-13 13:03:08 -05:00
|
|
|
@_wraps(onp.linalg.cholesky)
def cholesky(a):
  # Convert to an array of inexact dtype, then defer to lax_linalg.cholesky.
  promoted = _promote_arg_dtypes(np.asarray(a))
  return lax_linalg.cholesky(promoted)
|
|
|
|
|
2018-12-20 15:37:34 -05:00
|
|
|
|
2019-01-05 11:13:08 +05:30
|
|
|
@_wraps(onp.linalg.svd)
def svd(a, full_matrices=True, compute_uv=True):
  # Convert to an array of inexact dtype, then forward both flags positionally
  # to lax_linalg.svd.
  promoted = _promote_arg_dtypes(np.asarray(a))
  return lax_linalg.svd(promoted, full_matrices, compute_uv)
|
|
|
|
|
|
|
|
|
2019-09-17 18:58:34 +01:00
|
|
|
# TODO(pfau): make this work for complex types
def _jvp_slogdet(g, ans, x):
  """JVP rule registered for `slogdet`; returns `(jvp_sign, jvp_logdet)`.

  Args:
    g: presumably the tangent of the input matrix `x` (verify against the
      `defjvp` calling convention).
    ans: unused by this rule.
    x: the matrix (or batch of matrices) passed to `slogdet`.
  """
  # The log-determinant tangent is trace(x^{-1} g), computed via a linear solve
  # rather than an explicit inverse.
  logdet_tangent = np.trace(solve(x, g), axis1=-1, axis2=-2)
  # The sign output gets an all-zero tangent, one entry per batch element.
  sign_tangent = np.zeros(x.shape[:-2])
  return sign_tangent, logdet_tangent
|
|
|
|
|
|
|
|
|
2018-12-20 22:18:20 -05:00
|
|
|
@_wraps(onp.linalg.slogdet)
@custom_transforms
@jit
def slogdet(a):
  # Computes (sign, log|det|) from the diagonal and pivots of an LU
  # factorization of `a`.
  a = _promote_arg_dtypes(np.asarray(a))
  shape = np.shape(a)
  if not (len(shape) >= 2 and shape[-1] == shape[-2]):
    msg = "Argument to slogdet() must have shape [..., n, n], got {}"
    raise ValueError(msg.format(shape))

  dtype = lax.dtype(a)
  zero = np.array(0, dtype=dtype)
  lu, pivot = lax_linalg.lu(a)
  diag = np.diagonal(lu, axis1=-2, axis2=-1)
  # A zero on the LU diagonal means the determinant is zero.
  singular = np.any(diag == zero, axis=-1)
  # Count pivot entries that differ from their own index (presumably one row
  # interchange each); the parity of this count flips the sign.
  flips = np.count_nonzero(pivot != np.arange(shape[-1]), axis=-1)
  if np.iscomplexobj(a):
    # Complex case: the sign is the product of the unit-modulus phases.
    phase = np.prod(diag / np.abs(diag), axis=-1)
  else:
    # Real case: every negative diagonal entry contributes one more sign flip.
    phase = np.array(1, dtype=dtype)
    flips = flips + np.count_nonzero(diag < 0, axis=-1)

  parity_sign = np.array(-2 * (flips % 2) + 1, dtype=dtype)
  sign = np.where(singular, zero, phase * parity_sign)
  log_abs_det = np.sum(np.log(np.abs(diag)), axis=-1)
  logdet = np.where(singular, np.array(-np.inf, dtype=dtype), log_abs_det)
  return sign, np.real(logdet)

defjvp(slogdet, _jvp_slogdet)
|
2019-09-14 14:30:45 +01:00
|
|
|
|
|
|
|
|
2018-12-20 22:18:20 -05:00
|
|
|
@_wraps(onp.linalg.det)
def det(a):
  # Rebuild the determinant from slogdet's (sign, log|det|) pair.
  sgn, log_abs_det = slogdet(a)
  magnitude = np.exp(log_abs_det)
  return sgn * magnitude
|
2018-12-20 15:37:34 -05:00
|
|
|
|
|
|
|
|
2019-05-13 15:59:58 -04:00
|
|
|
@_wraps(onp.linalg.eig)
def eig(a):
  # lax_linalg.eig returns three values; only the first and the last are
  # returned to the caller, the middle one is dropped.
  promoted = _promote_arg_dtypes(np.asarray(a))
  eigenvalues, _, right_vectors = lax_linalg.eig(promoted)
  return eigenvalues, right_vectors
|
|
|
|
|
|
|
|
|
2019-10-30 19:29:56 -07:00
|
|
|
@_wraps(onp.linalg.eigvals)
def eigvals(a):
  # Eigenvalues only: keep the first element of eig's two-element result.
  return eig(a)[0]
|
|
|
|
|
|
|
|
|
2019-01-07 18:10:08 -05:00
|
|
|
@_wraps(onp.linalg.eigh)
def eigh(a, UPLO=None, symmetrize_input=True):
  # `UPLO` is translated into the boolean `lower` flag of lax_linalg.eigh;
  # None is treated the same as "L".
  lower = UPLO is None or UPLO == "L"
  if not lower and UPLO != "U":
    raise ValueError(
        "UPLO must be one of None, 'L', or 'U', got {}".format(UPLO))

  matrix = _promote_arg_dtypes(np.asarray(a))
  # lax_linalg.eigh yields (v, w); the pair is returned in swapped order.
  v, w = lax_linalg.eigh(matrix, lower=lower,
                         symmetrize_input=symmetrize_input)
  return w, v
|
|
|
|
|
|
|
|
|
2019-10-30 19:29:56 -07:00
|
|
|
@_wraps(onp.linalg.eigvalsh)
def eigvalsh(a, UPLO='L'):
  # Eigenvalues only: keep the first element of eigh's two-element result.
  return eigh(a, UPLO)[0]
|
|
|
|
|
|
|
|
|
2019-12-09 09:56:26 -08:00
|
|
|
@_wraps(onp.linalg.pinv, lax_description=textwrap.dedent("""\
    It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
    default `rcond` is `1e-15`. Here the default is
    `10. * max(num_rows, num_cols) * np.finfo(dtype).eps`.
    """))
def pinv(a, rcond=None):
  # Computes the pseudo-inverse from the SVD of conj(a), inverting only the
  # singular values above `rcond * largest_singular_value`.
  #
  # ported from https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
  #
  # Promote to an inexact dtype first: `np.finfo` below is only defined for
  # inexact dtypes, and the final cast back to `a.dtype` must not truncate the
  # result to an integer type when the caller passes integers.
  a = np.conj(_promote_arg_dtypes(np.asarray(a)))
  # copied from https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/math/linalg.py#L442
  if rcond is None:
    max_rows_cols = max(a.shape[-2:])
    rcond = 10. * max_rows_cols * np.finfo(a.dtype).eps
  rcond = np.asarray(rcond)
  u, s, v = svd(a, full_matrices=False)
  # Singular values less than or equal to ``rcond * largest_singular_value``
  # are set to zero.
  cutoff = rcond[..., np.newaxis] * np.amax(s, axis=-1, keepdims=True)
  large = s > cutoff
  # The reciprocal is taken everywhere and then masked, so entries at or below
  # the cutoff contribute exactly zero.
  s = np.divide(1, s)
  s = np.where(large, s, 0)
  vT = np.swapaxes(v, -1, -2)
  uT = np.swapaxes(u, -1, -2)
  res = np.matmul(vT, np.multiply(s[..., np.newaxis], uT))
  return lax.convert_element_type(res, a.dtype)
|
|
|
|
|
|
|
|
|
2018-12-13 19:28:05 -05:00
|
|
|
@_wraps(onp.linalg.inv)
def inv(a):
  # Inverts by solving against an identity matrix broadcast over the batch
  # dimensions of `a`.
  is_square = np.ndim(a) >= 2 and a.shape[-1] == a.shape[-2]
  if not is_square:
    raise ValueError("Argument to inv must have shape [..., n, n], got {}."
                     .format(np.shape(a)))
  n = a.shape[-1]
  identity = np.eye(n, dtype=lax.dtype(a))
  return solve(a, lax.broadcast(identity, a.shape[:-2]))
|
2018-12-13 19:28:05 -05:00
|
|
|
|
Add Cholesky, QR, and Triangular solve implementations.
* Adds lax.{cholesky,triangular_solve,qr}. Adds a JVP for Cholesky.
* Adds a transpose rule for add_p, needed by the Cholesky JVP.
* Adds np.linalg.{cholesky,qr,dot,matmul,trace}.
* Adds scipy.linalg.{cholesky,qr,solve_triangular,tril,triu}.
Pair programmed with mattjj.
2018-12-13 13:03:08 -05:00
|
|
|
|
2019-08-07 09:21:07 -04:00
|
|
|
@partial(jit, static_argnums=(1, 2, 3))
def _norm(x, ord, axis: Union[None, Tuple[int, ...], int], keepdims):
  """Jitted implementation behind `norm`.

  `ord`, `axis` and `keepdims` are static arguments of the jit, so the Python
  branching below is resolved per distinct combination of those values.

  Args:
    x: input; converted to an array of inexact dtype.
    ord: norm order; the accepted values depend on whether one or two axes are
      reduced.
    axis: None, a single axis, or a tuple of axes.
    keepdims: whether reduced axes are kept with size one.

  Raises:
    ValueError: for an unsupported matrix `ord`, or when the number of axes is
      neither one nor two.
  """
  x = _promote_arg_dtypes(np.asarray(x))
  x_shape = np.shape(x)
  ndim = len(x_shape)

  def euclidean(reduce_axes):
    # sqrt(sum(|x|^2)), with |x|^2 formed as real(x * conj(x)).
    return np.sqrt(np.sum(np.real(x * np.conj(x)), axis=reduce_axes,
                          keepdims=keepdims))

  if axis is None:
    # NumPy has an undocumented behavior that admits arbitrary rank inputs if
    # `ord` is None: https://github.com/numpy/numpy/issues/14215
    if ord is None:
      return euclidean(None)
    axis = tuple(range(ndim))
  elif isinstance(axis, tuple):
    axis = tuple(np._canonicalize_axis(ax, ndim) for ax in axis)
  else:
    axis = (np._canonicalize_axis(axis, ndim),)

  num_axes = len(axis)

  if num_axes == 1:
    # Vector norms.
    if ord is None or ord == 2:
      return euclidean(axis)
    if ord == np.inf:
      return np.amax(np.abs(x), axis=axis, keepdims=keepdims)
    if ord == -np.inf:
      return np.amin(np.abs(x), axis=axis, keepdims=keepdims)
    if ord == 0:
      # Count of non-zero entries, accumulated in a floating-point dtype.
      return np.sum(x != 0, dtype=np.finfo(lax.dtype(x)).dtype,
                    axis=axis, keepdims=keepdims)
    if ord == 1:
      # Numpy has a special case for ord == 1 as an optimization. We don't
      # really need the optimization (XLA could do it for us), but the Numpy
      # code has slightly different type promotion semantics, so we need a
      # special case too.
      return np.sum(np.abs(x), axis=axis, keepdims=keepdims)
    # General p-norm: (sum |x|^p)^(1/p).
    magnitude = np.abs(x)
    p = lax._const(magnitude, ord)
    total = np.sum(magnitude ** p, axis=axis, keepdims=keepdims)
    return np.power(total, 1. / p)

  if num_axes == 2:
    # Matrix norms.
    row_axis, col_axis = cast(Tuple[int, ...], axis)
    if ord is None or ord in ('f', 'fro'):
      return euclidean(axis)
    if ord in (1, -1):
      # Sum |x| over the row axis, then take the max (ord=1) or min (ord=-1)
      # over the column axis. Without keepdims the first reduction removes an
      # axis, so a later column axis shifts down by one.
      if not keepdims and col_axis > row_axis:
        col_axis -= 1
      pick = np.amax if ord == 1 else np.amin
      return pick(np.sum(np.abs(x), axis=row_axis, keepdims=keepdims),
                  axis=col_axis, keepdims=keepdims)
    if ord in (np.inf, -np.inf):
      # Same as above with the roles of the row and column axes exchanged.
      if not keepdims and row_axis > col_axis:
        row_axis -= 1
      pick = np.amax if ord == np.inf else np.amin
      return pick(np.sum(np.abs(x), axis=col_axis, keepdims=keepdims),
                  axis=row_axis, keepdims=keepdims)
    if ord in ('nuc', 2, -2):
      # Singular-value based norms: max (2), min (-2) or sum ('nuc').
      moved = np.moveaxis(x, axis, (-2, -1))
      reducer = {2: np.amax, -2: np.amin}.get(ord, np.sum)
      y = reducer(svd(moved, compute_uv=False), axis=-1)
      if keepdims:
        result_shape = list(x_shape)
        for reduced in axis:
          result_shape[reduced] = 1
        y = np.reshape(y, result_shape)
      return y
    raise ValueError("Invalid order '{}' for matrix norm.".format(ord))

  raise ValueError(
      "Invalid axis values ({}) for np.linalg.norm.".format(axis))
|
|
|
|
|
2019-08-07 09:21:07 -04:00
|
|
|
@_wraps(onp.linalg.norm)
def norm(x, ord=None, axis=None, keepdims=False):
  # Public entry point: supplies the default values and forwards all four
  # arguments positionally to `_norm`, whose jit decorator marks positions 1-3
  # (`ord`, `axis`, `keepdims`) as static.
  return _norm(x, ord, axis, keepdims)
|
|
|
|
|
2019-02-07 10:51:55 -05:00
|
|
|
|
Add Cholesky, QR, and Triangular solve implementations.
* Adds lax.{cholesky,triangular_solve,qr}. Adds a JVP for Cholesky.
* Adds a transpose rule for add_p, needed by the Cholesky JVP.
* Adds np.linalg.{cholesky,qr,dot,matmul,trace}.
* Adds scipy.linalg.{cholesky,qr,solve_triangular,tril,triu}.
Pair programmed with mattjj.
2018-12-13 13:03:08 -05:00
|
|
|
@_wraps(onp.linalg.qr)
def qr(a, mode="reduced"):
  # Only "complete" requests full matrices; "reduced", "r" and "full" all use
  # the reduced factorization. The mode is validated before `a` is touched.
  if mode == "complete":
    full_matrices = True
  elif mode in ("reduced", "r", "full"):
    full_matrices = False
  else:
    raise ValueError("Unsupported QR decomposition mode '{}'".format(mode))

  promoted = _promote_arg_dtypes(np.asarray(a))
  q, r = lax_linalg.qr(promoted, full_matrices)
  # Mode "r" returns only the second factor.
  return r if mode == "r" else (q, r)
|
2018-12-11 12:44:02 -05:00
|
|
|
|
2018-12-21 16:29:45 -05:00
|
|
|
|
|
|
|
@_wraps(onp.linalg.solve)
@jit
def solve(a, b):
  # Solves a @ x = b through an LU factorization of `a`: the rows of `b` are
  # permuted according to the pivots, followed by a unit-lower and an upper
  # triangular solve.
  a, b = _promote_arg_dtypes(np.asarray(a), np.asarray(b))
  a_shape = np.shape(a)
  b_shape = np.shape(b)
  a_ndims = len(a_shape)
  b_ndims = len(b_shape)
  if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2] and b_ndims >= 1):
    msg = ("The arguments to solve must have shapes a=[..., m, m] and "
           "b=[..., m, k] or b=[..., m]; got a={} and b={}")
    raise ValueError(msg.format(a_shape, b_shape))
  lu, pivots = lax_linalg.lu(a)

  m = a_shape[-1]

  # Numpy treats the RHS as a (batched) vector if the number of dimensions
  # differ by 1. Otherwise, broadcasting rules apply.
  x = b[..., None] if a_ndims == b_ndims + 1 else b

  batch_dims = lax.broadcast_shapes(lu.shape[:-2], x.shape[:-2])
  x = np.broadcast_to(x, batch_dims + x.shape[-2:])
  lu = np.broadcast_to(lu, batch_dims + lu.shape[-2:])

  permutation = lax_linalg.lu_pivots_to_permutation(pivots, m)
  permutation = np.broadcast_to(permutation, batch_dims + (m,))
  # One open-mesh index array per batch dimension; the extra trailing size-1
  # entry only pads their rank so they broadcast against `permutation`, and is
  # dropped from the index below.
  iotas = np.ix_(*(lax.iota(np.int32, size) for size in batch_dims + (1,)))
  x = x[iotas[:-1] + (permutation, slice(None))]

  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=True,
                                  unit_diagonal=True)
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)

  # Drop the column axis that was added for a vector right-hand side.
  return x[..., 0] if a_ndims == b_ndims + 1 else x
|
2018-12-21 16:29:45 -05:00
|
|
|
|
|
|
|
|
2018-12-13 11:52:41 -08:00
|
|
|
# For every function yielded by get_module_functions(onp.linalg) whose name is
# not already defined in this module, bind that name to the result of
# `_not_implemented(func)`.
for func in get_module_functions(onp.linalg):
  if func.__name__ not in globals():
    globals()[func.__name__] = _not_implemented(func)
|