# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as onp

from jax.numpy import lax_numpy as np
from jax import core
from jax import lax
from jax import ad_util
from jax.interpreters import xla
from jax.interpreters import ad
from jax.util import partial
from jax.abstract_arrays import ShapedArray
from jax.core import Primitive
from jax.lax import (standard_primitive, standard_unop, binop_dtype_rule,
                     _float, _complex, _input_dtype)
from jaxlib import lapack


# traceables

def cholesky(x): return cholesky_p.bind(x)

def eigh(x, lower=True): return eigh_p.bind(x, lower=lower)

def lu(x): return lu_p.bind(x)

def qr(x, full_matrices=True):
  q, r = qr_p.bind(x, full_matrices=full_matrices)
  return q, r

def svd(x, full_matrices=True, compute_uv=True):
  s, u, v = svd_p.bind(x, full_matrices=full_matrices, compute_uv=compute_uv)
  if compute_uv:
    return u, s, v
  else:
    return s
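
# Note: like numpy.linalg.svd, svd above returns (u, s, v), where the third
# output is the transposed right factor V^T as produced by the underlying
# primitive (see svd_impl below, which names it vt).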

def triangular_solve(a, b, left_side=False, lower=False, transpose_a=False,
                     conjugate_a=False):
  return triangular_solve_p.bind(
      a, b, left_side=left_side, lower=lower, transpose_a=transpose_a,
      conjugate_a=conjugate_a)
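
# A minimal usage sketch (illustrative only, not part of the API): with
# left_side=True this solves a @ x == b for x, and with left_side=False it
# solves x @ a == b; transpose_a/conjugate_a apply to `a` before solving.
#
#   a = np.array([[2., 0.], [1., 3.]])   # lower triangular
#   b = np.array([[2.], [7.]])
#   x = triangular_solve(a, b, left_side=True, lower=True)
#   # x == [[1.], [2.]], since a @ [[1.], [2.]] == b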


# utilities

def _T(x):
  return np.swapaxes(x, -1, -2)


# primitives

_cpu_lapack_types = {np.float32, np.float64, np.complex64, np.complex128}

# Cholesky decomposition

def cholesky_jvp_rule(primals, tangents):
  x, = primals
  sigma_dot, = tangents
  L = cholesky_p.bind(x)

  # Forward-mode rule from https://arxiv.org/pdf/1602.07527.pdf
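  # In outline: with Sigma = L L^T and phi(X) taking the lower triangle of X
  # with its diagonal halved, the paper gives
  #   L_dot = L . phi(inv(L) . Sigma_dot . inv(L)^T),
  # which the code below computes with two triangular solves instead of
  # explicit inverses.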
  sigma_dot = (sigma_dot + _T(sigma_dot)) / 2
  phi = lambda X: np.tril(X) / (1 + np.eye(x.shape[-1]))
  tmp = triangular_solve(L, sigma_dot,
                         left_side=False, transpose_a=True, lower=True)
  L_dot = lax.dot(L, phi(triangular_solve(
      L, tmp, left_side=True, transpose_a=False, lower=True)))
  return L, L_dot

cholesky_p = standard_unop(_float | _complex, 'cholesky')
ad.primitive_jvps[cholesky_p] = cholesky_jvp_rule

def cholesky_cpu_translation_rule(c, operand):
  shape = c.GetShape(operand)
  dtype = shape.element_type().type
  if len(shape.dimensions()) == 2 and dtype in _cpu_lapack_types:
    return c.GetTupleElement(lapack.jax_potrf(c, operand, lower=True), 0)
  else:
    # Fall back to the HLO implementation for batched Cholesky decomposition or
    # unsupported types.
    # TODO(phawkins): support LAPACK primitives in batched mode.
    return c.Cholesky(operand)

xla.backend_specific_translations['Host'][cholesky_p] = cholesky_cpu_translation_rule


# Symmetric/Hermitian eigendecomposition
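# The primitive returns a pair (v, w): eigenvectors as the columns of v and
# eigenvalues in w (in the ascending order produced by LAPACK's *syevd).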

def eigh_impl(operand, lower):
  v, w = xla.apply_primitive(eigh_p, operand, lower=lower)
  return core.pack((v, w))

def eigh_translation_rule(c, operand, lower):
  raise NotImplementedError(
      "Symmetric eigendecomposition is only implemented on the CPU backend")

def eigh_abstract_eval(operand, lower):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
      raise ValueError(
        "Argument to symmetric eigendecomposition must have shape [..., n, n]")

    batch_dims = operand.shape[:-2]
    n = operand.shape[-1]
    v = ShapedArray(batch_dims + (n, n), operand.dtype)
    w = ShapedArray(batch_dims + (n,), operand.dtype)
  else:
    v, w = operand, operand
  return core.AbstractTuple((v, w))

def eigh_cpu_translation_rule(c, operand, lower):
  shape = c.GetShape(operand)
  dtype = shape.element_type().type
  if len(shape.dimensions()) == 2 and dtype in _cpu_lapack_types:
    out = lapack.jax_syevd(c, operand, lower=lower)
    return c.Tuple(c.GetTupleElement(out, 0), c.GetTupleElement(out, 1))
  else:
    raise NotImplementedError(
        "Only unbatched eigendecomposition is implemented on CPU")

eigh_p = Primitive('eigh')
eigh_p.def_impl(eigh_impl)
eigh_p.def_abstract_eval(eigh_abstract_eval)
xla.translations[eigh_p] = eigh_translation_rule
xla.backend_specific_translations['Host'][eigh_p] = eigh_cpu_translation_rule


# Triangular solve

triangular_solve_dtype_rule = partial(
    binop_dtype_rule, _input_dtype, (_float | _complex, _float | _complex),
    'triangular_solve')

def triangular_solve_shape_rule(a, b, left_side=False, **unused_kwargs):
  if a.ndim < 2:
    msg = "triangular_solve requires a.ndim to be at least 2, got {}."
    raise TypeError(msg.format(a.ndim))
  if a.shape[-1] != a.shape[-2]:
    msg = ("triangular_solve requires the last two dimensions of a to be equal "
           "in size, got a.shape of {}.")
    raise TypeError(msg.format(a.shape))
  if a.shape[:-2] != b.shape[:-2]:
    msg = ("triangular_solve requires both arguments to have the same number "
           "of dimensions and equal batch dimensions, got {} and {}.")
    raise TypeError(msg.format(a.shape, b.shape))
  common_dim = -2 if left_side else -1
  if a.shape[-1] != b.shape[common_dim]:
    msg = "Incompatible shapes for arguments to triangular_solve: {} and {}."
    raise TypeError(msg.format(a.shape, b.shape))
  return b.shape
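
# For x = inv(a) . b (the left_side=True case without transposition), the
# derivative in a is dx = -inv(a) . da . inv(a) . b = -inv(a) . da . x, i.e.
# one more triangular solve applied to -da times the primal answer; the
# transposed and right-side variants below follow the same pattern.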

def triangular_solve_jvp_rule_a(
    g_a, ans, a, b, left_side, lower, transpose_a, conjugate_a):
  g_a = lax.neg(g_a)
  g_a = np.swapaxes(g_a, -1, -2) if transpose_a else g_a
  tmp = triangular_solve(a, g_a, left_side, lower, transpose_a, conjugate_a)
  dot = lax.dot if g_a.ndim == 2 else lax.batch_matmul
  if left_side:
    return dot(tmp, ans)
  else:
    return dot(ans, tmp)

def triangular_solve_transpose_rule(
    cotangent, a, b, left_side, lower, transpose_a, conjugate_a):
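  # The solve is linear in b, and the transpose of x = inv(op(a)) . b with
  # respect to b is b_ct = inv(op(a))^T . x_ct, i.e. another triangular solve
  # with transpose_a flipped.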
  assert a is not None and b is None
  cotangent_b = triangular_solve(a, cotangent, left_side, lower,
                                 not transpose_a, conjugate_a)
  return [None, cotangent_b]

triangular_solve_p = standard_primitive(
    triangular_solve_shape_rule, triangular_solve_dtype_rule,
    'triangular_solve')
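
# The JVP with respect to b is just another solve of the same system, since
# the primal is linear in b.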
ad.defjvp2(triangular_solve_p,
           triangular_solve_jvp_rule_a,
           lambda g_b, _, a, b, **kws: triangular_solve(a, g_b, **kws))
ad.primitive_transposes[triangular_solve_p] = triangular_solve_transpose_rule

def triangular_solve_cpu_translation_rule(
    c, a, b, left_side, lower, transpose_a, conjugate_a):
  shape = c.GetShape(a)
  dtype = shape.element_type().type
  if len(shape.dimensions()) == 2 and dtype in _cpu_lapack_types:
    return lapack.jax_trsm(
        c, c.Constant(onp.array(1, dtype=dtype)), a, b, left_side, lower,
        transpose_a, conjugate_a)
  else:
    # Fall back to the HLO implementation for batched triangular_solve or
    # unsupported types.
    # TODO(phawkins): support BLAS primitives in batched mode.
    return c.TriangularSolve(a, b, left_side, lower, transpose_a, conjugate_a)

xla.backend_specific_translations['Host'][triangular_solve_p] = triangular_solve_cpu_translation_rule


# LU decomposition

# Computes a pivoted LU decomposition such that
#   PA = LU
# In the style of LAPACK, L and U are stored in the same matrix.
# TODO(phawkins): add a mechanism to report errors for singular matrices.

def lu_impl(operand):
  lu, pivot = xla.apply_primitive(lu_p, operand)
  return core.pack((lu, pivot))

def lu_translation_rule(c, operand):
  raise NotImplementedError(
      "LU decomposition is only implemented on the CPU backend")

def lu_abstract_eval(operand):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2:
      raise ValueError("Argument to LU decomposition must have ndims >= 2")

    batch_dims = operand.shape[:-2]
    m = operand.shape[-2]
    n = operand.shape[-1]
    pivot = ShapedArray(batch_dims + (min(m, n),), np.int32)
  else:
    pivot = operand
  return core.AbstractTuple((operand, pivot))

def lu_jvp_rule(primals, tangents):
  a, = primals
  a_dot, = tangents
  lu, pivots = lu_p.bind(a)

  a_shape = np.shape(a)
  m, n = a_shape[-2:]
  dtype = lax._dtype(a)
  k = min(m, n)

  # TODO(phawkins): use a gather rather than a matrix multiplication here.
  permutation = lu_pivots_to_permutation(pivots, m)
  p = np.array(permutation[:, None] == np.arange(m), dtype=dtype)
  x = np.matmul(p, a_dot)

  # Differentiation of Matrix Functionals Using Triangular Factorization
  # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
  #
  #     LU = A
  # ==> L'U + LU' = A'
  # ==> inv(L) . L' + U' . inv(U) = inv(L) . A' . inv(U)
  # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
  #     U' = triu(inv(L) . A' . inv(U)) . U
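
  # Below, the packed LU matrix is split back out: l is m x m unit lower
  # triangular (zero-padded on the right when m > k, unit diagonal added), and
  # u is n x n upper triangular (padded with an identity block when n > k).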
  ndims = len(a_shape)
  l_padding = [(0, 0, 0)] * ndims
  l_padding[-1] = (0, m - k, 0)
  zero = np._constant_like(lu, 0)
  l = lax.pad(np.tril(lu[..., :, :k], -1), zero, l_padding)
  l = l + np.eye(m, m, dtype=dtype)

  u_eye = lax.pad(np.eye(n - k, n - k, dtype=dtype), zero,
                  ((k, 0, 0), (k, 0, 0)))
  u_padding = [(0, 0, 0)] * ndims
  u_padding[-2] = (0, n - k, 0)
  u = lax.pad(np.triu(lu[..., :k, :]), zero, u_padding) + u_eye

  la = triangular_solve(l, x, left_side=True, transpose_a=False, lower=True)
  lau = triangular_solve(u, la, left_side=False, transpose_a=False,
                         lower=False)

  l_dot = np.matmul(l, np.tril(lau, -1))
  u_dot = np.matmul(np.triu(lau), u)
  lu_dot = l_dot + u_dot
  return core.pack((lu, pivots)), ad.TangentTuple((lu_dot, ad_util.zero))

lu_p = Primitive('lu')
lu_p.def_impl(lu_impl)
lu_p.def_abstract_eval(lu_abstract_eval)
xla.translations[lu_p] = lu_translation_rule
ad.primitive_jvps[lu_p] = lu_jvp_rule

def lu_cpu_translation_rule(c, operand):
  shape = c.GetShape(operand)
  dtype = shape.element_type().type
  if len(shape.dimensions()) == 2 and dtype in _cpu_lapack_types:
    out = lapack.jax_getrf(c, operand)
    lu = c.GetTupleElement(out, 0)
    # Subtract 1 from the pivot to get 0-based indices.
    pivot = c.Sub(c.GetTupleElement(out, 1), c.ConstantS32Scalar(1))
    # Throw away the `info` value, because we have no way to report errors.
    return c.Tuple(lu, pivot)
  else:
    raise NotImplementedError("Only unbatched LU decomposition is implemented")

# TODO(phawkins): The hasattr() test here is to avoid incompatibilities between
# jax and an older jaxlib. Remove after a jaxlib release includes jax_getrf.
if hasattr(lapack, "jax_getrf"):
  xla.backend_specific_translations['Host'][lu_p] = lu_cpu_translation_rule


def lu_pivots_to_permutation(swaps, k):
  """Converts the pivots (row swaps) returned by LU to a permutation."""

  def body_fn(i, loop_carry):
    swaps, permutation = loop_carry
    j = swaps[i]
    x, y = np.ravel(permutation[i]), np.ravel(permutation[j])
    permutation = lax.dynamic_update_index_in_dim(permutation, y, i, axis=0)
    permutation = lax.dynamic_update_index_in_dim(permutation, x, j, axis=0)
    return swaps, permutation

  n, = np.shape(swaps)
  permutation = np.arange(k)
  _, permutation = lax.fori_loop(onp.array(0, onp.int32), onp.array(n, onp.int32),
                                 body_fn, (swaps, permutation))
  return permutation


# QR decomposition

def qr_impl(operand, full_matrices):
  q, r = xla.apply_primitive(qr_p, operand, full_matrices=full_matrices)
  return core.pack((q, r))

def qr_translation_rule(c, operand, full_matrices):
  return c.QR(operand, full_matrices=full_matrices)

def qr_abstract_eval(operand, full_matrices):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2:
      raise ValueError("Argument to QR decomposition must have ndims >= 2")
    batch_dims = operand.shape[:-2]
    m = operand.shape[-2]
    n = operand.shape[-1]
    k = m if full_matrices else min(m, n)
    q = ShapedArray(batch_dims + (m, k), operand.dtype)
    r = ShapedArray(batch_dims + (k, n), operand.dtype)
  else:
    q = operand
    r = operand
  return core.AbstractTuple((q, r))

def qr_jvp_rule(primals, tangents, full_matrices):
  # See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.
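  # Sketch: A = QR gives dA = dQ.R + Q.dR, so Q^T.dA.R^{-1} = Q^T.dQ + dR.R^{-1},
  # where Q^T.dQ is skew-symmetric and dR.R^{-1} is upper triangular; the
  # strictly lower triangle therefore determines both terms below.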
  x, = primals
  if full_matrices or np.shape(x)[-2] < np.shape(x)[-1]:
    raise NotImplementedError
  dx, = tangents
  q, r = qr_p.bind(x, full_matrices=False)
  dx_rinv = triangular_solve(r, dx)  # Right side solve by default
  qt_dx_rinv = np.matmul(_T(q), dx_rinv)
  qt_dx_rinv_lower = np.tril(qt_dx_rinv, -1)
  domega = qt_dx_rinv_lower - _T(qt_dx_rinv_lower)  # This is skew-symmetric
  dq = np.matmul(q, domega - qt_dx_rinv) + dx_rinv
  dr = np.matmul(qt_dx_rinv - domega, r)
  return core.pack((q, r)), core.pack((dq, dr))

qr_p = Primitive('qr')
qr_p.def_impl(qr_impl)
qr_p.def_abstract_eval(qr_abstract_eval)
xla.translations[qr_p] = qr_translation_rule
ad.primitive_jvps[qr_p] = qr_jvp_rule


# Singular value decomposition

def svd_impl(operand, full_matrices, compute_uv):
  s, u, vt = xla.apply_primitive(svd_p, operand, full_matrices=full_matrices,
                                 compute_uv=compute_uv)
  return core.pack((s, u, vt))

def svd_translation_rule(c, operand, full_matrices, compute_uv):
  raise NotImplementedError(
      "Singular value decomposition is only implemented on the CPU backend")

def svd_abstract_eval(operand, full_matrices, compute_uv):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2:
      raise ValueError(
          "Argument to singular value decomposition must have ndims >= 2")

    batch_dims = operand.shape[:-2]
    m = operand.shape[-2]
    n = operand.shape[-1]
    s = ShapedArray(batch_dims + (min(m, n),), operand.dtype)
    u = ShapedArray(batch_dims + (m, m if full_matrices else min(m, n)),
                    operand.dtype)
    vt = ShapedArray(batch_dims + (n if full_matrices else min(m, n), n),
                     operand.dtype)
  else:
    s = operand
    u = operand
    vt = operand
  return core.AbstractTuple((s, u, vt))

def svd_cpu_translation_rule(c, operand, full_matrices, compute_uv):
  shape = c.GetShape(operand)
  dtype = shape.element_type().type
  if len(shape.dimensions()) == 2 and dtype in _cpu_lapack_types:
    out = lapack.jax_gesdd(c, operand, full_matrices=full_matrices,
                           compute_uv=compute_uv)
    return c.Tuple(c.GetTupleElement(out, 0),
                   c.GetTupleElement(out, 1),
                   c.GetTupleElement(out, 2))
  else:
    raise NotImplementedError(
        "Only unbatched singular value decomposition is implemented on CPU")

svd_p = Primitive('svd')
svd_p.def_impl(svd_impl)
svd_p.def_abstract_eval(svd_abstract_eval)
xla.translations[svd_p] = svd_translation_rule
xla.backend_specific_translations['Host'][svd_p] = svd_cpu_translation_rule