2018-12-12 08:44:07 -05:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2018-12-11 12:44:02 -05:00
|
|
|
|
2019-08-07 09:21:07 -04:00
|
|
|
from functools import partial
|
Add Cholesky, QR, and Triangular solve implementations.
* Adds lax.{cholesky,triangular_solve,qr}. Adds a JVP for Cholesky.
* Adds a transpose rule for add_p, needed by the Cholesky JVP.
* Adds np.linalg.{cholesky,qr,dot,matmul,trace}.
* Adds scipy.linalg.{cholesky,qr,solve_triangular,tril,triu}.
Pair programmed with mattjj.
2018-12-13 13:03:08 -05:00
|
|
|
|
2018-12-11 12:44:02 -05:00
|
|
|
import numpy as onp
|
2019-12-09 09:56:26 -08:00
|
|
|
import textwrap
|
2020-01-24 16:52:40 -05:00
|
|
|
import operator
|
2020-01-18 08:26:23 -05:00
|
|
|
from typing import Tuple, Union, cast
|
2018-12-11 12:44:02 -05:00
|
|
|
|
2020-03-29 20:48:08 -07:00
|
|
|
from jax import jit, vmap, custom_jvp
|
2018-12-20 15:37:34 -05:00
|
|
|
from .. import lax
|
2020-01-15 15:00:38 -08:00
|
|
|
from .. import ops
|
Add Cholesky, QR, and Triangular solve implementations.
* Adds lax.{cholesky,triangular_solve,qr}. Adds a JVP for Cholesky.
* Adds a transpose rule for add_p, needed by the Cholesky JVP.
* Adds np.linalg.{cholesky,qr,dot,matmul,trace}.
* Adds scipy.linalg.{cholesky,qr,solve_triangular,tril,triu}.
Pair programmed with mattjj.
2018-12-13 13:03:08 -05:00
|
|
|
from .. import lax_linalg
|
2019-11-15 10:02:51 -05:00
|
|
|
from .. import dtypes
|
2018-12-11 12:44:02 -05:00
|
|
|
from .lax_numpy import _not_implemented
|
Add Cholesky, QR, and Triangular solve implementations.
* Adds lax.{cholesky,triangular_solve,qr}. Adds a JVP for Cholesky.
* Adds a transpose rule for add_p, needed by the Cholesky JVP.
* Adds np.linalg.{cholesky,qr,dot,matmul,trace}.
* Adds scipy.linalg.{cholesky,qr,solve_triangular,tril,triu}.
Pair programmed with mattjj.
2018-12-13 13:03:08 -05:00
|
|
|
from .lax_numpy import _wraps
|
Faster gradient rules for {numpy,scipy}.linalg.solve (#2220)
Fixes GH1747
The implicit function theorem (via `lax.custom_linear_solve`) lets us
_directly_ define gradients for linear solves, in contrast to the current
implementations of gradient for `solve` which rely upon differentiating matrix
factorization.
In **theory**, JVPs of `cholesky` and `lu` involve the equivalent of ~3 dense
matrix-matrix multiplications, which makes them rather expensive: time
`O(n**3)`. In contrast, with `custom_linear_solve` we don't need to
differentiate the factorization. The JVP and VJP rules for linear solve (for a
single right-hand-side vector) now only use matrix-vector products and
triangular solves, which is time `O(n**2)`. We should also have reduced memory
usage, because we don't need to save any intermediate outputs.
In **practice**, these new gradient rules seem to make solves with large
arrays ~3x faster:
from functools import partial
import jax.scipy as jsp
from jax import lax
import jax.numpy as np
import numpy as onp
import jax
def loss(solve):
def f(a, b):
return solve(a, b).sum()
return f
rs = onp.random.RandomState(0)
N = 500
K = 1
a = rs.randn(N, N)
a = jax.device_put(a.T @ a + 0.1 * np.eye(N))
b = jax.device_put(rs.randn(N, K))
# general matrix solve
grad = jax.jit(jax.grad(loss(np.linalg.solve)))
grad(a, b).block_until_ready()
%timeit grad(a, b).block_until_ready()
# N=500, K=1: 11.4 ms -> 3.63 ms
# positive definite solve
grad = jax.jit(jax.grad(loss(partial(jsp.linalg.solve, sym_pos=True))))
grad(a, b).block_until_ready()
%timeit grad(a, b).block_until_ready()
# N=500, K=1: 9.22 ms -> 2.83 ms
2020-02-18 17:41:38 -08:00
|
|
|
from .vectorize import vectorize
|
Add Cholesky, QR, and Triangular solve implementations.
* Adds lax.{cholesky,triangular_solve,qr}. Adds a JVP for Cholesky.
* Adds a transpose rule for add_p, needed by the Cholesky JVP.
* Adds np.linalg.{cholesky,qr,dot,matmul,trace}.
* Adds scipy.linalg.{cholesky,qr,solve_triangular,tril,triu}.
Pair programmed with mattjj.
2018-12-13 13:03:08 -05:00
|
|
|
from . import lax_numpy as np
|
|
|
|
from ..util import get_module_functions
|
2020-04-15 17:35:54 -07:00
|
|
|
from ..third_party.numpy.linalg import cond, multi_dot, tensorinv, tensorsolve
|
Add Cholesky, QR, and Triangular solve implementations.
* Adds lax.{cholesky,triangular_solve,qr}. Adds a JVP for Cholesky.
* Adds a transpose rule for add_p, needed by the Cholesky JVP.
* Adds np.linalg.{cholesky,qr,dot,matmul,trace}.
* Adds scipy.linalg.{cholesky,qr,solve_triangular,tril,triu}.
Pair programmed with mattjj.
2018-12-13 13:03:08 -05:00
|
|
|
|
2018-12-13 19:28:05 -05:00
|
|
|
_T = lambda x: np.swapaxes(x, -1, -2)
|
Add Cholesky, QR, and Triangular solve implementations.
* Adds lax.{cholesky,triangular_solve,qr}. Adds a JVP for Cholesky.
* Adds a transpose rule for add_p, needed by the Cholesky JVP.
* Adds np.linalg.{cholesky,qr,dot,matmul,trace}.
* Adds scipy.linalg.{cholesky,qr,solve_triangular,tril,triu}.
Pair programmed with mattjj.
2018-12-13 13:03:08 -05:00
|
|
|
|
2019-01-07 18:10:08 -05:00
|
|
|
|
|
|
|
def _promote_arg_dtypes(*args):
  """Promotes `args` to a common inexact type.

  Any argument whose dtype is not inexact (e.g. integer or bool) contributes
  `np.float_` to the dtype computation instead of its own dtype. The common
  dtype is canonicalized and every argument is converted to it.

  Returns the single converted array when exactly one argument is given,
  otherwise a list of converted arrays in the original order.
  """
  def _inexact_or_default(dt):
    if np.issubdtype(dt, np.inexact):
      return dt
    return np.float_

  candidate_dtypes = [_inexact_or_default(np._dtype(arg)) for arg in args]
  common = dtypes.canonicalize_dtype(np.result_type(*candidate_dtypes))
  converted = [lax.convert_element_type(arg, common) for arg in args]
  return converted[0] if len(converted) == 1 else converted
|
|
|
|
|
|
|
|
|
Add Cholesky, QR, and Triangular solve implementations.
* Adds lax.{cholesky,triangular_solve,qr}. Adds a JVP for Cholesky.
* Adds a transpose rule for add_p, needed by the Cholesky JVP.
* Adds np.linalg.{cholesky,qr,dot,matmul,trace}.
* Adds scipy.linalg.{cholesky,qr,solve_triangular,tril,triu}.
Pair programmed with mattjj.
2018-12-13 13:03:08 -05:00
|
|
|
@_wraps(onp.linalg.cholesky)
def cholesky(a):
  # Promote to an inexact dtype, then delegate the factorization to the LAX
  # primitive.
  promoted = _promote_arg_dtypes(np.asarray(a))
  return lax_linalg.cholesky(promoted)
|
|
|
|
|
2018-12-20 15:37:34 -05:00
|
|
|
|
2019-01-05 11:13:08 +05:30
|
|
|
@_wraps(onp.linalg.svd)
def svd(a, full_matrices=True, compute_uv=True):
  # Promote to an inexact dtype, then forward both flags positionally to the
  # LAX SVD primitive.
  promoted = _promote_arg_dtypes(np.asarray(a))
  return lax_linalg.svd(promoted, full_matrices, compute_uv)
|
|
|
|
|
|
|
|
|
2020-01-24 16:52:40 -05:00
|
|
|
@_wraps(onp.linalg.matrix_power)
def matrix_power(a, n):
  # Raises a square matrix (or a stack of them) to the integer power `n` by
  # repeated squaring. n == 0 returns identity matrices broadcast to a.shape;
  # a negative `n` inverts `a` first (via `inv`).
  # Raises TypeError for inputs with fewer than two dimensions, non-square
  # trailing dimensions, or an exponent that is not an integer.
  a = _promote_arg_dtypes(np.asarray(a))

  if a.ndim < 2:
    raise TypeError("{}-dimensional array given. Array must be at least "
                    "two-dimensional".format(a.ndim))
  if a.shape[-2] != a.shape[-1]:
    raise TypeError("Last 2 dimensions of the array must be square")
  try:
    n = operator.index(n)
  except TypeError as err:
    raise TypeError("exponent must be an integer, got {}".format(n)) from err

  if n == 0:
    return np.broadcast_to(np.eye(a.shape[-2], dtype=a.dtype), a.shape)
  elif n < 0:
    a = inv(a)
    # `n` is a plain Python int here (operator.index above) and it drives the
    # Python branching and divmod loop below, so negate it with Python
    # arithmetic instead of converting it into an array with np.abs.
    n = -n

  # Small powers are handled directly.
  if n == 1:
    return a
  elif n == 2:
    return a @ a
  elif n == 3:
    return (a @ a) @ a

  # Exponentiation by squaring: `z` takes the values a, a**2, a**4, ... and
  # the factors matching the set bits of `n` are multiplied into `result`.
  z = result = None
  while n > 0:
    z = a if z is None else (z @ z)
    n, bit = divmod(n, 2)
    if bit:
      result = z if result is None else (result @ z)

  return result
|
|
|
|
|
|
|
|
|
2020-01-26 14:29:33 -05:00
|
|
|
@_wraps(onp.linalg.matrix_rank)
def matrix_rank(M, tol=None):
  # Rank = number of singular values strictly greater than `tol`.
  # Inputs with more than two dimensions are treated as stacks of matrices over
  # the last two axes, and one rank is returned per matrix. For 2-D inputs the
  # result is a scalar, as before.
  M = _promote_arg_dtypes(np.asarray(M))
  if M.ndim < 2:
    # Scalars and vectors have rank 1 if any entry is nonzero, else 0.
    return np.any(M != 0).astype(np.int32)
  S = svd(M, full_matrices=False, compute_uv=False)
  if tol is None:
    # Default tolerance: largest singular value * max(rows, cols) * eps,
    # computed separately for each matrix in the stack.
    tol = (np.amax(S, axis=-1, keepdims=True) * np.max(M.shape[-2:])
           * np.finfo(S.dtype).eps)
  else:
    # A trailing axis lets a scalar or per-matrix tolerance broadcast
    # against the singular values.
    tol = np.asarray(tol)[..., np.newaxis]
  return np.sum(S > tol, axis=-1)
|
|
|
|
|
|
|
|
|
2020-03-29 20:48:08 -07:00
|
|
|
@custom_jvp
@_wraps(onp.linalg.slogdet)
@jit
def slogdet(a):
  # Computes (sign, log|det|) from an LU factorization: |det| is the product
  # of |U_ii|, and the sign comes from the row permutation (plus, for real
  # inputs, negative diagonal entries, or, for complex inputs, the phases of
  # the diagonal entries). A zero on the diagonal gives sign 0 and logdet -inf.
  # Raises ValueError unless `a` has shape [..., n, n].
  a = _promote_arg_dtypes(np.asarray(a))
  dtype = lax.dtype(a)
  a_shape = np.shape(a)
  is_square_stack = len(a_shape) >= 2 and a_shape[-1] == a_shape[-2]
  if not is_square_stack:
    msg = "Argument to slogdet() must have shape [..., n, n], got {}"
    raise ValueError(msg.format(a_shape))

  lu, pivot = lax_linalg.lu(a)
  diag = np.diagonal(lu, axis1=-2, axis2=-1)
  zero = np.array(0, dtype=dtype)
  is_zero = np.any(diag == zero, axis=-1)
  # Count the positions where the pivot vector departs from the identity.
  parity = np.count_nonzero(pivot != np.arange(a_shape[-1]), axis=-1)
  if np.iscomplexobj(a):
    sign = np.prod(diag / np.abs(diag), axis=-1)
  else:
    sign = np.array(1, dtype=dtype)
    parity = parity + np.count_nonzero(diag < 0, axis=-1)
  parity_factor = np.array(-2 * (parity % 2) + 1, dtype=dtype)
  sign = np.where(is_zero, zero, sign * parity_factor)
  logdet = np.where(is_zero,
                    np.array(-np.inf, dtype=dtype),
                    np.sum(np.log(np.abs(diag)), axis=-1))
  return sign, np.real(logdet)
|
2020-03-29 20:48:08 -07:00
|
|
|
|
|
|
|
@slogdet.defjvp
def _slogdet_jvp(primals, tangents):
  # JVP of slogdet. The tangent of log|det(x)| is trace(solve(x, g)); the
  # tangent of the sign is zero. Complex inputs raise NotImplementedError.
  (x,) = primals
  (g,) = tangents
  if np.issubdtype(np._dtype(x), np.complexfloating):
    raise NotImplementedError  # TODO(pfau): make this work for complex types
  sign, ans = slogdet(x)
  ans_dot = np.trace(solve(x, g), axis1=-1, axis2=-2)
  sign_dot = np.zeros_like(sign)
  return (sign, ans), (sign_dot, ans_dot)
|
2019-09-14 14:30:45 +01:00
|
|
|
|
|
|
|
|
2018-12-20 22:18:20 -05:00
|
|
|
@_wraps(onp.linalg.det)
def det(a):
  # Determinant reconstructed from slogdet as sign * exp(log|det|).
  sign, log_abs_det = slogdet(a)
  return sign * np.exp(log_abs_det)
|
2018-12-20 15:37:34 -05:00
|
|
|
|
|
|
|
|
2019-05-13 15:59:58 -04:00
|
|
|
@_wraps(onp.linalg.eig)
def eig(a):
  # lax_linalg.eig yields three outputs; the first (eigenvalues) and third
  # (eigenvectors) are returned as NumPy's (w, v) pair, and the second is
  # discarded.
  promoted = _promote_arg_dtypes(np.asarray(a))
  eigenvalues, _, eigenvectors = lax_linalg.eig(promoted)
  return eigenvalues, eigenvectors
|
|
|
|
|
|
|
|
|
2019-10-30 19:29:56 -07:00
|
|
|
@_wraps(onp.linalg.eigvals)
def eigvals(a):
  # Eigenvalues only: the first element of eig's (w, v) result.
  return eig(a)[0]
|
|
|
|
|
|
|
|
|
2019-01-07 18:10:08 -05:00
|
|
|
@_wraps(onp.linalg.eigh)
def eigh(a, UPLO=None, symmetrize_input=True):
  # Hermitian/symmetric eigendecomposition. UPLO selects the triangle that is
  # read: None or 'L' for lower, 'U' for upper; anything else raises
  # ValueError. Returns (eigenvalues, eigenvectors).
  if UPLO not in (None, "L", "U"):
    msg = "UPLO must be one of None, 'L', or 'U', got {}".format(UPLO)
    raise ValueError(msg)
  lower = UPLO != "U"

  promoted = _promote_arg_dtypes(np.asarray(a))
  # The LAX primitive returns eigenvectors first; NumPy's order is swapped.
  vectors, values = lax_linalg.eigh(
      promoted, lower=lower, symmetrize_input=symmetrize_input)
  return values, vectors
|
|
|
|
|
|
|
|
|
2019-10-30 19:29:56 -07:00
|
|
|
@_wraps(onp.linalg.eigvalsh)
def eigvalsh(a, UPLO='L'):
  # Eigenvalues only: the first element of eigh's (w, v) result.
  return eigh(a, UPLO)[0]
|
|
|
|
|
|
|
|
|
2019-12-09 09:56:26 -08:00
|
|
|
@_wraps(onp.linalg.pinv, lax_description=textwrap.dedent("""\
    It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
    default `rcond` is `1e-15`. Here the default is
    `10. * max(num_rows, num_cols) * np.finfo(dtype).eps`.
    """))
def pinv(a, rcond=None):
  # Moore-Penrose pseudo-inverse via the SVD, discarding singular values at or
  # below rcond * (largest singular value).
  # ported from https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
  # Promote first: integer/bool inputs need an inexact dtype both for np.finfo
  # below and so the result is not cast back to an integer type at the end.
  a = np.conj(_promote_arg_dtypes(np.asarray(a)))
  # copied from https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/math/linalg.py#L442
  if rcond is None:
    max_rows_cols = max(a.shape[-2:])
    rcond = 10. * max_rows_cols * np.finfo(a.dtype).eps
  rcond = np.asarray(rcond)
  u, s, v = svd(a, full_matrices=False)
  # Singular values less than or equal to ``rcond * largest_singular_value``
  # are set to zero.
  cutoff = rcond[..., np.newaxis] * np.amax(s, axis=-1, keepdims=True)
  # Replace discarded singular values with +inf *before* dividing, so their
  # reciprocal is exactly 0. Computing 1/s first and masking afterwards gives
  # the same forward value but produces NaN gradients when s contains zeros,
  # because the derivative of 1/s at 0 is infinite even on the masked branch.
  s = np.where(s > cutoff, s, np.inf)
  vT = np.swapaxes(v, -1, -2)
  uT = np.swapaxes(u, -1, -2)
  res = np.matmul(vT, np.divide(uT, s[..., np.newaxis]))
  return lax.convert_element_type(res, a.dtype)
|
|
|
|
|
|
|
|
|
2018-12-13 19:28:05 -05:00
|
|
|
@_wraps(onp.linalg.inv)
def inv(a):
  # Matrix inverse computed as solve(a, I), with the identity broadcast over
  # any leading batch dimensions. Raises ValueError unless `a` has shape
  # [..., n, n].
  # Convert up front: non-array inputs such as nested lists have no `.shape`,
  # which the check below (and the identity construction) relies on.
  a = np.asarray(a)
  if a.ndim < 2 or a.shape[-1] != a.shape[-2]:
    raise ValueError("Argument to inv must have shape [..., n, n], got {}."
                     .format(np.shape(a)))
  return solve(
      a, lax.broadcast(np.eye(a.shape[-1], dtype=lax.dtype(a)), a.shape[:-2]))
|
2018-12-13 19:28:05 -05:00
|
|
|
|
Add Cholesky, QR, and Triangular solve implementations.
* Adds lax.{cholesky,triangular_solve,qr}. Adds a JVP for Cholesky.
* Adds a transpose rule for add_p, needed by the Cholesky JVP.
* Adds np.linalg.{cholesky,qr,dot,matmul,trace}.
* Adds scipy.linalg.{cholesky,qr,solve_triangular,tril,triu}.
Pair programmed with mattjj.
2018-12-13 13:03:08 -05:00
|
|
|
|
2019-08-07 09:21:07 -04:00
|
|
|
@partial(jit, static_argnums=(1, 2, 3))
def _norm(x, ord, axis: Union[None, Tuple[int, ...], int], keepdims):
  # Jitted implementation behind `norm`. `ord`, `axis` and `keepdims` are
  # static arguments, so the Python branching below selects the computation at
  # trace time. A single reduction axis gives a vector norm; two axes give a
  # matrix norm. Raises ValueError for unsupported `ord`/`axis` combinations.
  x = _promote_arg_dtypes(np.asarray(x))
  x_shape = np.shape(x)
  ndim = len(x_shape)

  if axis is None:
    # NumPy has an undocumented behavior that admits arbitrary rank inputs if
    # `ord` is None: https://github.com/numpy/numpy/issues/14215
    if ord is None:
      return np.sqrt(np.sum(np.real(x * np.conj(x)), keepdims=keepdims))
    axis = tuple(range(ndim))
  elif isinstance(axis, tuple):
    axis = tuple(np._canonicalize_axis(ax, ndim) for ax in axis)
  else:
    axis = (np._canonicalize_axis(axis, ndim),)

  num_axes = len(axis)
  if num_axes == 1:
    # Vector norms.
    if isinstance(ord, str):
      # String orders ('fro', 'nuc', ...) only apply to matrices. Without this
      # check they would reach the generic p-norm branch below and fail inside
      # lax._const with an unhelpful conversion error.
      raise ValueError("Invalid order '{}' for vector norm.".format(ord))
    if ord is None or ord == 2:
      return np.sqrt(np.sum(np.real(x * np.conj(x)), axis=axis,
                            keepdims=keepdims))
    elif ord == np.inf:
      return np.amax(np.abs(x), axis=axis, keepdims=keepdims)
    elif ord == -np.inf:
      return np.amin(np.abs(x), axis=axis, keepdims=keepdims)
    elif ord == 0:
      return np.sum(x != 0, dtype=np.finfo(lax.dtype(x)).dtype,
                    axis=axis, keepdims=keepdims)
    elif ord == 1:
      # NumPy has a special case for ord == 1 as an optimization. We don't
      # really need the optimization (XLA could do it for us), but the NumPy
      # code has slightly different type promotion semantics, so we need a
      # special case too.
      return np.sum(np.abs(x), axis=axis, keepdims=keepdims)
    else:
      abs_x = np.abs(x)
      ord = lax._const(abs_x, ord)
      out = np.sum(abs_x ** ord, axis=axis, keepdims=keepdims)
      return np.power(out, 1. / ord)

  elif num_axes == 2:
    # Matrix norms.
    row_axis, col_axis = cast(Tuple[int, ...], axis)
    if ord is None or ord in ('f', 'fro'):
      return np.sqrt(np.sum(np.real(x * np.conj(x)), axis=axis,
                            keepdims=keepdims))
    elif ord == 1:
      # Reducing over row_axis first shifts col_axis down when keepdims is off.
      if not keepdims and col_axis > row_axis:
        col_axis -= 1
      return np.amax(np.sum(np.abs(x), axis=row_axis, keepdims=keepdims),
                     axis=col_axis, keepdims=keepdims)
    elif ord == -1:
      if not keepdims and col_axis > row_axis:
        col_axis -= 1
      return np.amin(np.sum(np.abs(x), axis=row_axis, keepdims=keepdims),
                     axis=col_axis, keepdims=keepdims)
    elif ord == np.inf:
      # Reducing over col_axis first shifts row_axis down when keepdims is off.
      if not keepdims and row_axis > col_axis:
        row_axis -= 1
      return np.amax(np.sum(np.abs(x), axis=col_axis, keepdims=keepdims),
                     axis=row_axis, keepdims=keepdims)
    elif ord == -np.inf:
      if not keepdims and row_axis > col_axis:
        row_axis -= 1
      return np.amin(np.sum(np.abs(x), axis=col_axis, keepdims=keepdims),
                     axis=row_axis, keepdims=keepdims)
    elif ord in ('nuc', 2, -2):
      # Singular-value based norms: move the matrix axes last for svd.
      x = np.moveaxis(x, axis, (-2, -1))
      if ord == 2:
        reducer = np.amax
      elif ord == -2:
        reducer = np.amin
      else:
        reducer = np.sum
      y = reducer(svd(x, compute_uv=False), axis=-1)
      if keepdims:
        result_shape = list(x_shape)
        result_shape[axis[0]] = 1
        result_shape[axis[1]] = 1
        y = np.reshape(y, result_shape)
      return y
    else:
      raise ValueError("Invalid order '{}' for matrix norm.".format(ord))
  else:
    raise ValueError(
        "Invalid axis values ({}) for np.linalg.norm.".format(axis))
|
|
|
|
|
2019-08-07 09:21:07 -04:00
|
|
|
@_wraps(onp.linalg.norm)
def norm(x, ord=None, axis=None, keepdims=False):
  # Public entry point; the work happens in the jitted `_norm`. The trailing
  # arguments must be passed positionally because `_norm` marks positions
  # 1-3 as static.
  static_args = (ord, axis, keepdims)
  return _norm(x, *static_args)
|
|
|
|
|
2019-02-07 10:51:55 -05:00
|
|
|
|
Add Cholesky, QR, and Triangular solve implementations.
* Adds lax.{cholesky,triangular_solve,qr}. Adds a JVP for Cholesky.
* Adds a transpose rule for add_p, needed by the Cholesky JVP.
* Adds np.linalg.{cholesky,qr,dot,matmul,trace}.
* Adds scipy.linalg.{cholesky,qr,solve_triangular,tril,triu}.
Pair programmed with mattjj.
2018-12-13 13:03:08 -05:00
|
|
|
@_wraps(onp.linalg.qr)
def qr(a, mode="reduced"):
  # QR decomposition. "complete" requests full-size matrices. "reduced" and
  # "full" both compute the reduced factorization. "r" also computes the
  # reduced factorization but returns only R. Any other mode raises ValueError.
  if mode == "complete":
    full_matrices = True
  elif mode in ("reduced", "r", "full"):
    full_matrices = False
  else:
    raise ValueError("Unsupported QR decomposition mode '{}'".format(mode))
  promoted = _promote_arg_dtypes(np.asarray(a))
  q, r = lax_linalg.qr(promoted, full_matrices)
  return r if mode == "r" else (q, r)
|
2018-12-11 12:44:02 -05:00
|
|
|
|
2018-12-21 16:29:45 -05:00
|
|
|
|
Faster gradient rules for {numpy,scipy}.linalg.solve (#2220)
Fixes GH1747
The implicit function theorem (via `lax.custom_linear_solve`) lets us
_directly_ define gradients for linear solves, in contrast to the current
implementations of gradient for `solve` which rely upon differentiating matrix
factorization.
In **theory**, JVPs of `cholesky` and `lu` involve the equivalent of ~3 dense
matrix-matrix multiplications, which makes them rather expensive: time
`O(n**3)`. In contrast, with `custom_linear_solve` we don't need to
differentiate the factorization. The JVP and VJP rules for linear solve (for a
single right-hand-side vector) now only use matrix-vector products and
triangular solves, which is time `O(n**2)`. We should also have reduced memory
usage, because we don't need to save any intermediate outputs.
In **practice**, these new gradient rules seem to make solves with large
arrays ~3x faster:
from functools import partial
import jax.scipy as jsp
from jax import lax
import jax.numpy as np
import numpy as onp
import jax
def loss(solve):
def f(a, b):
return solve(a, b).sum()
return f
rs = onp.random.RandomState(0)
N = 500
K = 1
a = rs.randn(N, N)
a = jax.device_put(a.T @ a + 0.1 * np.eye(N))
b = jax.device_put(rs.randn(N, K))
# general matrix solve
grad = jax.jit(jax.grad(loss(np.linalg.solve)))
grad(a, b).block_until_ready()
%timeit grad(a, b).block_until_ready()
# N=500, K=1: 11.4 ms -> 3.63 ms
# positive definite solve
grad = jax.jit(jax.grad(loss(partial(jsp.linalg.solve, sym_pos=True))))
grad(a, b).block_until_ready()
%timeit grad(a, b).block_until_ready()
# N=500, K=1: 9.22 ms -> 2.83 ms
2020-02-18 17:41:38 -08:00
|
|
|
def _check_solve_shapes(a, b):
|
2020-02-12 17:05:18 -08:00
|
|
|
if not (a.ndim >= 2 and a.shape[-1] == a.shape[-2] and b.ndim >= 1):
|
2018-12-21 16:29:45 -05:00
|
|
|
msg = ("The arguments to solve must have shapes a=[..., m, m] and "
|
|
|
|
"b=[..., m, k] or b=[..., m]; got a={} and b={}")
|
2020-02-12 17:05:18 -08:00
|
|
|
raise ValueError(msg.format(a.shape, b.shape))
|
Faster gradient rules for {numpy,scipy}.linalg.solve (#2220)
Fixes GH1747
The implicit function theorem (via `lax.custom_linear_solve`) lets us
_directly_ define gradients for linear solves, in contrast to the current
implementations of gradient for `solve` which rely upon differentiating matrix
factorization.
In **theory**, JVPs of `cholesky` and `lu` involve the equivalent of ~3 dense
matrix-matrix multiplications, which makes them rather expensive: time
`O(n**3)`. In contrast, with `custom_linear_solve` we don't need to
differentiate the factorization. The JVP and VJP rules for linear solve (for a
single right-hand-side vector) now only use matrix-vector products and
triangular solves, which is time `O(n**2)`. We should also have reduced memory
usage, because we don't need to save any intermediate outputs.
In **practice**, these new gradient rules seem to make solves with large
arrays ~3x faster:
from functools import partial
import jax.scipy as jsp
from jax import lax
import jax.numpy as np
import numpy as onp
import jax
def loss(solve):
def f(a, b):
return solve(a, b).sum()
return f
rs = onp.random.RandomState(0)
N = 500
K = 1
a = rs.randn(N, N)
a = jax.device_put(a.T @ a + 0.1 * np.eye(N))
b = jax.device_put(rs.randn(N, K))
# general matrix solve
grad = jax.jit(jax.grad(loss(np.linalg.solve)))
grad(a, b).block_until_ready()
%timeit grad(a, b).block_until_ready()
# N=500, K=1: 11.4 ms -> 3.63 ms
# positive definite solve
grad = jax.jit(jax.grad(loss(partial(jsp.linalg.solve, sym_pos=True))))
grad(a, b).block_until_ready()
%timeit grad(a, b).block_until_ready()
# N=500, K=1: 9.22 ms -> 2.83 ms
2020-02-18 17:41:38 -08:00
|
|
|
|
|
|
|
|
|
|
|
@partial(vectorize, signature='(n,m),(m)->(n)')
def _matvec_multiply(a, b):
  # Matrix-vector product with the core signature (n,m),(m)->(n), at the
  # highest available precision. `vectorize` presumably maps this over any
  # leading batch dimensions, as numpy.vectorize does; confirm in .vectorize.
  highest = lax.Precision.HIGHEST
  return np.dot(a, b, precision=highest)
|
|
|
|
|
|
|
|
|
|
|
|
@_wraps(onp.linalg.solve)
@jit
def solve(a, b):
  # Solves a @ x = b with an LU factorization wrapped in
  # lax.custom_linear_solve, so derivatives come from the linear system
  # (matvec with `a`) rather than from differentiating the factorization.
  # b may be [..., m] (one right-hand side) or [..., m, k].
  a, b = _promote_arg_dtypes(np.asarray(a), np.asarray(b))
  _check_solve_shapes(a, b)

  # With custom_linear_solve, we can reuse the same factorization when
  # computing sensitivities. This is considerably faster.
  # stop_gradient operates on array values, so it must wrap the input matrix.
  # Applying it to the `lu` function object does not stop gradients from
  # flowing through the factorization.
  lu, pivots = lax_linalg.lu(lax.stop_gradient(a))
  custom_solve = partial(
      lax.custom_linear_solve,
      lambda x: _matvec_multiply(a, x),
      solve=lambda _, x: lax_linalg.lu_solve(lu, pivots, x, trans=0),
      transpose_solve=lambda _, x: lax_linalg.lu_solve(lu, pivots, x, trans=1))
  if a.ndim == b.ndim + 1:
    # b.shape == [..., m]
    return custom_solve(b)
  else:
    # b.shape == [..., m, k]: solve each of the k columns independently.
    return vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
|
2018-12-21 16:29:45 -05:00
|
|
|
|
|
|
|
|
2018-12-13 11:52:41 -08:00
|
|
|
# Every function that get_module_functions reports for numpy.linalg and that
# this module does not already define is bound to `_not_implemented(func)`
# (presumably a stub that raises when called; see lax_numpy).
for func in get_module_functions(onp.linalg):
  if func.__name__ in globals():
    continue
  globals()[func.__name__] = _not_implemented(func)
|