# coding=utf-8
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from functools import partial

import numpy as onp

from jax.numpy import lax_numpy as np
from jax import ad_util
from jax import api
from jax import api_util
from jax import core
from jax import lax
from jax import ops
from jax import dtypes
from jax.interpreters import xla
from jax.interpreters import ad
from jax.interpreters import batching
from jax.util import partial, prod
from jax.abstract_arrays import ShapedArray
from jax.core import Primitive
from jax.lax import (standard_primitive, standard_unop, naryop_dtype_rule,
                     _float, _complex, _input_dtype, _broadcasting_select)
from jax.lib import xla_client
from jax.lib import lapack
from jax.lib import cusolver


# traceables

def cholesky(x, symmetrize_input=True):
  if symmetrize_input:
    x = symmetrize(x)
  return np.tril(cholesky_p.bind(x))
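
# Example (an illustrative sketch, not part of the original module): the
# Cholesky factor can be paired with triangular_solve to solve a Hermitian
# positive definite system A x = b using two triangular solves:
#
#   L = cholesky(A)                                   # A == L @ _H(L)
#   y = triangular_solve(L, b, left_side=True, lower=True)
#   x = triangular_solve(L, y, left_side=True, lower=True,
#                        transpose_a=True, conjugate_a=True)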

def eig(x):
  w, vl, vr = eig_p.bind(x)
  return w, vl, vr

def eigh(x, lower=True, symmetrize_input=True):
  if symmetrize_input:
    x = symmetrize(x)
  v, w = eigh_p.bind(x, lower=lower)
  return v, w

def lu(x):
  lu, pivots = lu_p.bind(x)
  return lu, pivots

def qr(x, full_matrices=True):
  q, r = qr_p.bind(x, full_matrices=full_matrices)
  return q, np.triu(r)

def svd(x, full_matrices=True, compute_uv=True):
  s, u, v = svd_p.bind(x, full_matrices=full_matrices, compute_uv=compute_uv)
  if compute_uv:
    return u, s, v
  else:
    return s

def triangular_solve(a, b, left_side=False, lower=False, transpose_a=False,
                     conjugate_a=False, unit_diagonal=False):
  conjugate_a = conjugate_a and np.issubdtype(lax.dtype(a), np.complexfloating)
  return triangular_solve_p.bind(
      a, b, left_side=left_side, lower=lower, transpose_a=transpose_a,
      conjugate_a=conjugate_a, unit_diagonal=unit_diagonal)
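
# Example (illustrative): with left_side=True and lower=True this solves
# a @ x = b, reading only the lower triangle of `a`:
#
#   a = np.array([[2., 0.], [1., 3.]])
#   b = np.ones((2, 1))
#   x = triangular_solve(a, b, left_side=True, lower=True)
#   # np.allclose(np.matmul(a, x), b)  ->  True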


# utilities

def _T(x): return np.swapaxes(x, -1, -2)
def _H(x): return np.conj(_T(x))
def symmetrize(x): return (x + _H(x)) / 2

def _unpack_tuple(f, n):
  def g(c, *args, **kwargs):
    t = f(c, *args, **kwargs)
    return (c.GetTupleElement(t, i) for i in range(n))
  return g


# primitives

_cpu_lapack_types = {onp.dtype(onp.float32), onp.dtype(onp.float64),
                     onp.dtype(onp.complex64), onp.dtype(onp.complex128)}
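
# This set is consulted by the CPU triangular solve rule below to decide when
# the LAPACK trsm kernel applies; other dtypes fall back to XLA's own HLO
# implementation.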


# Cholesky decomposition

def cholesky_jvp_rule(primals, tangents):
  x, = primals
  sigma_dot, = tangents
  L = np.tril(cholesky_p.bind(x))

  # Forward-mode rule from https://arxiv.org/pdf/1602.07527.pdf
  def phi(X):
    l = np.tril(X)
    return l / (np._constant_like(X, 1) + np.eye(X.shape[-1], dtype=X.dtype))

  tmp = triangular_solve(L, sigma_dot, left_side=False, transpose_a=True,
                         conjugate_a=True, lower=True)
  L_dot = lax.batch_matmul(L, phi(triangular_solve(
      L, tmp, left_side=True, transpose_a=False, lower=True)),
      precision=lax.Precision.HIGHEST)
  return L, L_dot
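
# The rule above computes the identity dL = L * phi(L^{-1} dSigma L^{-H}),
# where phi takes the lower triangle and halves its diagonal (a restatement
# of the reference's forward-mode result, noted here for readability).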

def cholesky_batching_rule(batched_args, batch_dims):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return cholesky(x), 0

cholesky_p = standard_unop(_float | _complex, 'cholesky')
ad.primitive_jvps[cholesky_p] = cholesky_jvp_rule
batching.primitive_batchers[cholesky_p] = cholesky_batching_rule

def _nan_like(c, operand):
  shape = c.GetShape(operand)
  dtype = shape.element_type()
  if np.issubdtype(dtype, onp.complexfloating):
    nan = c.Constant(onp.array(onp.nan * (1. + 1j), dtype=dtype))
  else:
    nan = c.Constant(onp.array(onp.nan, dtype=dtype))
  return c.Broadcast(nan, shape.dimensions())
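
# _nan_like is used by the translation rules below to mask the outputs of a
# LAPACK/cuSOLVER call with NaNs whenever the returned `info` status reports
# that the factorization failed.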

# TODO(phawkins): remove supports_batching argument after the minimum jaxlib
# version is 0.1.38.
def _cholesky_cpu_gpu_translation_rule(potrf_impl, potrf_supports_batching, c,
                                       operand):
  shape = c.GetShape(operand)
  batch_dims = shape.dimensions()[:-2]
  dtype = shape.element_type().type
  if len(batch_dims) == 0 or potrf_supports_batching:
    result, info = potrf_impl(c, operand, lower=True)
    ok = c.Eq(info, c.ConstantS32Scalar(0))
    return _broadcasting_select(c,
                                c.Reshape(ok, None, batch_dims + (1, 1)), result,
                                _nan_like(c, result))
  else:
    # Fall back to the HLO implementation for batched Cholesky decomposition.
    return c.Cholesky(operand)

xla.backend_specific_translations['cpu'][cholesky_p] = partial(
    _cholesky_cpu_gpu_translation_rule, lapack.potrf,
    not hasattr(lapack, "jax_potrf"))

# TODO(phawkins): remove after the minimum jaxlib version is 0.1.38.
if hasattr(cusolver, "potrf"):
  xla.backend_specific_translations['gpu'][cholesky_p] = partial(
      _cholesky_cpu_gpu_translation_rule, cusolver.potrf, True)


# Asymmetric eigendecomposition

def eig_impl(operand):
  return xla.apply_primitive(eig_p, operand)

def eig_translation_rule(c, operand):
  raise NotImplementedError(
    "Nonsymmetric eigendecomposition is only implemented on the CPU backend")

def eig_abstract_eval(operand):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
      raise ValueError("Argument to nonsymmetric eigendecomposition must have "
                       "shape [..., n, n], got shape {}".format(operand.shape))

    batch_dims = operand.shape[:-2]
    n = operand.shape[-1]
    dtype = onp.complex64 if dtypes.finfo(operand.dtype).bits == 32 else onp.complex128
    dtype = dtypes.canonicalize_dtype(dtype)
    vl = vr = ShapedArray(batch_dims + (n, n), dtype)
    w = ShapedArray(batch_dims + (n,), dtype)
  else:
    raise NotImplementedError
  return w, vl, vr

_cpu_geev = lapack.geev

def eig_cpu_translation_rule(c, operand):
  shape = c.GetShape(operand)
  batch_dims = shape.dimensions()[:-2]
  w, vl, vr, info = _cpu_geev(c, operand)
  ok = c.Eq(info, c.ConstantS32Scalar(0))
  w = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1,)), w,
                           _nan_like(c, w))
  vl = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1, 1)), vl,
                            _nan_like(c, vl))
  vr = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1, 1)), vr,
                            _nan_like(c, vr))
  return c.Tuple(w, vl, vr)

def eig_batching_rule(batched_args, batch_dims):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return eig_p.bind(x), (0, 0, 0)

eig_p = Primitive('eig')
eig_p.multiple_results = True
eig_p.def_impl(eig_impl)
eig_p.def_abstract_eval(eig_abstract_eval)
xla.translations[eig_p] = eig_translation_rule
xla.backend_specific_translations['cpu'][eig_p] = eig_cpu_translation_rule
batching.primitive_batchers[eig_p] = eig_batching_rule


# Symmetric/Hermitian eigendecomposition

def eigh_impl(operand, lower):
  v, w = xla.apply_primitive(eigh_p, operand, lower=lower)
  return v, w

def eigh_translation_rule(c, operand, lower):
  shape = c.GetShape(operand)
  dims = shape.dimensions()
  if dims[-1] == 0:
    return c.Tuple(operand, c.Reshape(operand, None, dims[:-1]))
  if not lower:
    n = len(dims)
    operand = c.Transpose(operand, list(range(n - 2)) + [n - 1, n - 2])
  return c.Eigh(operand)

def eigh_abstract_eval(operand, lower):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
      raise ValueError(
        "Argument to symmetric eigendecomposition must have shape [..., n, n], "
        "got shape {}".format(operand.shape))

    batch_dims = operand.shape[:-2]
    n = operand.shape[-1]
    v = ShapedArray(batch_dims + (n, n), operand.dtype)
    w = ShapedArray(batch_dims + (n,), lax.lax._complex_basetype(operand.dtype))
  else:
    v, w = operand, operand
  return v, w

def _eigh_cpu_gpu_translation_rule(syevd_impl, c, operand, lower):
  shape = c.GetShape(operand)
  batch_dims = shape.dimensions()[:-2]
  v, w, info = syevd_impl(c, operand, lower=lower)
  ok = c.Eq(info, c.ConstantS32Scalar(0))
  v = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1, 1)), v,
                           _nan_like(c, v))
  w = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1,)), w,
                           _nan_like(c, w))
  return c.Tuple(v, w)

def eigh_jvp_rule(primals, tangents, lower):
  # Derivative for eigh in the simplest case of distinct eigenvalues.
  # This is classic nondegenerate perturbation theory, but also see
  # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
  # The general solution treating the case of degenerate eigenvalues is
  # considerably more complicated. Ambitious readers may refer to the general
  # methods below or refer to degenerate perturbation theory in physics.
  # https://www.win.tue.nl/analysis/reports/rana06-33.pdf and
  # https://people.orie.cornell.edu/aslewis/publications/99-clarke.pdf
  a, = primals
  a_dot, = tangents

  v, w = eigh_p.bind(symmetrize(a), lower=lower)

  # for complex numbers we need eigenvalues to be full dtype of v, a:
  w = w.astype(a.dtype)
  eye_n = np.eye(a.shape[-1], dtype=a.dtype)
  # carefully build reciprocal delta-eigenvalue matrix, avoiding NaNs.
  Fmat = np.reciprocal(eye_n + w[..., np.newaxis, :] - w[..., np.newaxis]) - eye_n
  # eigh impl doesn't support batch dims, but future-proof the grad.
  dot = partial(lax.dot if a.ndim == 2 else lax.batch_matmul,
                precision=lax.Precision.HIGHEST)
  vdag_adot_v = dot(dot(_H(v), a_dot), v)
  dv = dot(v, np.multiply(Fmat, vdag_adot_v))
  dw = np.diagonal(vdag_adot_v, axis1=-2, axis2=-1)
  return (v, w), (dv, dw)
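
# In the notation of the references above: with A = V W V^H and distinct
# eigenvalues, dw = diag(V^H dA V) and dv = V (F * (V^H dA V)), where
# F_ij = 1 / (w_j - w_i) off the diagonal and 0 on it -- exactly what the
# rule computes.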


def eigh_batching_rule(batched_args, batch_dims, lower):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return eigh_p.bind(x, lower=lower), (0, 0)

eigh_p = Primitive('eigh')
eigh_p.multiple_results = True
eigh_p.def_impl(eigh_impl)
eigh_p.def_abstract_eval(eigh_abstract_eval)
xla.translations[eigh_p] = eigh_translation_rule
ad.primitive_jvps[eigh_p] = eigh_jvp_rule
batching.primitive_batchers[eigh_p] = eigh_batching_rule

_cpu_syevd = lapack.syevd

xla.backend_specific_translations['cpu'][eigh_p] = partial(
    _eigh_cpu_gpu_translation_rule, _cpu_syevd)

xla.backend_specific_translations['gpu'][eigh_p] = partial(
    _eigh_cpu_gpu_translation_rule, cusolver.syevd)


# Triangular solve

triangular_solve_dtype_rule = partial(
    naryop_dtype_rule, _input_dtype, (_float | _complex, _float | _complex),
    'triangular_solve')

def triangular_solve_shape_rule(a, b, left_side=False, **unused_kwargs):
  if a.ndim < 2:
    msg = "triangular_solve requires a.ndim to be at least 2, got {}."
    raise TypeError(msg.format(a.ndim))
  if a.shape[-1] != a.shape[-2]:
    msg = ("triangular_solve requires the last two dimensions of a to be equal "
           "in size, got a.shape of {}.")
    raise TypeError(msg.format(a.shape))
  if a.shape[:-2] != b.shape[:-2]:
    msg = ("triangular_solve requires both arguments to have the same number "
           "of dimensions and equal batch dimensions, got {} and {}.")
    raise TypeError(msg.format(a.shape, b.shape))
  common_dim = -2 if left_side else -1
  if a.shape[-1] != b.shape[common_dim]:
    msg = "Incompatible shapes for arguments to triangular_solve: {} and {}."
    raise TypeError(msg.format(a.shape, b.shape))
  return b.shape

def triangular_solve_jvp_rule_a(
    g_a, ans, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
  m, n = b.shape[-2:]
  k = 1 if unit_diagonal else 0
  g_a = np.tril(g_a, k=-k) if lower else np.triu(g_a, k=k)
  g_a = lax.neg(g_a)
  g_a = np.swapaxes(g_a, -1, -2) if transpose_a else g_a
  g_a = np.conj(g_a) if conjugate_a else g_a
  dot = partial(lax.dot if g_a.ndim == 2 else lax.batch_matmul,
                precision=lax.Precision.HIGHEST)

  def a_inverse(rhs):
    return triangular_solve(a, rhs, left_side, lower, transpose_a, conjugate_a,
                            unit_diagonal)

  # triangular_solve is about the same cost as matrix multiplication (~n^2
  # FLOPs for matrix/vector inputs). Order these operations in whichever order
  # is cheaper.
  if left_side:
    assert g_a.shape[-2:] == a.shape[-2:] == (m, m) and ans.shape[-2:] == (m, n)
    if m > n:
      return a_inverse(dot(g_a, ans))  # A^{-1} (∂A X)
    else:
      return dot(a_inverse(g_a), ans)  # (A^{-1} ∂A) X
  else:
    assert g_a.shape[-2:] == a.shape[-2:] == (n, n) and ans.shape[-2:] == (m, n)
    if m < n:
      return a_inverse(dot(ans, g_a))  # (X ∂A) A^{-1}
    else:
      return dot(ans, a_inverse(g_a))  # X (∂A A^{-1})

def triangular_solve_transpose_rule(
    cotangent, a, b, left_side, lower, transpose_a, conjugate_a,
    unit_diagonal):
  # Triangular solve is nonlinear in its first argument and linear in its
  # second argument, analogous to `div` but swapped.
  assert a is not ad.undefined_primal and b is ad.undefined_primal
  if cotangent is ad_util.zero:
    cotangent_b = ad_util.zero
  else:
    cotangent_b = triangular_solve(a, cotangent, left_side, lower,
                                   not transpose_a, conjugate_a, unit_diagonal)
  return [None, cotangent_b]

def triangular_solve_batching_rule(batched_args, batch_dims, left_side,
                                   lower, transpose_a, conjugate_a,
                                   unit_diagonal):
  x, y = batched_args
  bx, by = batch_dims
  size = next(t.shape[i] for t, i in zip(batched_args, batch_dims)
              if i is not None)
  x = batching.bdim_at_front(x, bx, size)
  y = batching.bdim_at_front(y, by, size)
  return triangular_solve(x, y, left_side=left_side, lower=lower,
                          transpose_a=transpose_a, conjugate_a=conjugate_a,
                          unit_diagonal=unit_diagonal), 0

triangular_solve_p = standard_primitive(
    triangular_solve_shape_rule, triangular_solve_dtype_rule,
    'triangular_solve')
ad.defjvp2(triangular_solve_p,
           triangular_solve_jvp_rule_a,
           lambda g_b, _, a, b, **kws: triangular_solve(a, g_b, **kws))
ad.primitive_transposes[triangular_solve_p] = triangular_solve_transpose_rule
batching.primitive_batchers[triangular_solve_p] = triangular_solve_batching_rule


def _triangular_solve_cpu_translation_rule(
    c, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
  shape = c.GetShape(a)
  dtype = shape.element_type().type

  if len(shape.dimensions()) == 2 and onp.dtype(dtype) in _cpu_lapack_types:
    if conjugate_a and not transpose_a:
      a = c.Conj(a)
      conjugate_a = False
    return lapack.jax_trsm(
      c, c.Constant(onp.array(1, dtype=dtype)), a, b, left_side, lower,
      transpose_a, conjugate_a, unit_diagonal)
  else:
    # Fall back to the HLO implementation for unsupported types or batching.
    # TODO: Consider swapping XLA for LAPACK in batched case
    return c.TriangularSolve(a, b, left_side, lower, transpose_a, conjugate_a,
                             unit_diagonal)

xla.backend_specific_translations['cpu'][triangular_solve_p] = \
  _triangular_solve_cpu_translation_rule

def _triangular_solve_gpu_translation_rule(
    c, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
  shape = c.GetShape(a)
  dtype = shape.element_type().type
  dims = shape.dimensions()
  m, n = dims[-2:]
  batch = prod(dims[:-2])
  if batch > 1 and m <= 32 and n <= 32:
    if conjugate_a and not transpose_a:
      a = c.Conj(a)
      conjugate_a = False
    return cusolver.trsm(
      c, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal)
  else:
    # Use the XLA implementation for unbatched triangular_solve.
    return c.TriangularSolve(a, b, left_side, lower, transpose_a, conjugate_a,
                             unit_diagonal)

xla.backend_specific_translations['gpu'][triangular_solve_p] = \
  _triangular_solve_gpu_translation_rule


# LU decomposition

# Computes a pivoted LU decomposition such that
# PA = LU
# In the style of LAPACK, LU are stored in the same matrix.
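
# Example (illustrative): for a square matrix, the identity P @ A = L @ U can
# be checked from the outputs of lu():
#
#   lu_mat, pivots = lu(a)
#   l = np.tril(lu_mat, -1) + np.eye(a.shape[-1], dtype=a.dtype)
#   u = np.triu(lu_mat)
#   perm = lu_pivots_to_permutation(pivots, a.shape[-2])
#   # np.allclose(a[perm], np.matmul(l, u))  ->  True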

def _lu_unblocked(a):
  """Unblocked LU decomposition, as a rolled loop."""
  m, n = a.shape
  def body(k, state):
    pivot, perm, a = state
    m_idx = np.arange(m)
    n_idx = np.arange(n)

    if np.issubdtype(a.dtype, np.complexfloating):
      t = a[:, k]
      magnitude = np.abs(np.real(t)) + np.abs(np.imag(t))
    else:
      magnitude = np.abs(a[:, k])
    i = np.argmax(np.where(m_idx >= k, magnitude, -np.inf))
    pivot = ops.index_update(pivot, ops.index[k], i)

    a = ops.index_update(a, ops.index[[k, i],], a[[i, k],])

    perm = ops.index_update(perm, ops.index[[i, k],], perm[[k, i],])

    # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes
    x = a[k, k]
    a = ops.index_update(a, ops.index[:, k],
                         np.where(m_idx > k, a[:, k] / x, a[:, k]))

    # a[k+1:, k+1:] -= np.outer(a[k+1:, k], a[k, k+1:])
    a = a - np.where((m_idx[:, None] > k) & (n_idx > k),
                     np.outer(a[:, k], a[k, :]), np.array(0, dtype=a.dtype))
    return pivot, perm, a

  pivot = np.zeros((min(m, n),), dtype=np.int32)
  perm = np.arange(m, dtype=np.int32)
  if m == 0 and n == 0:
    # If the array is empty, the loop body never executes but tracing it to a
    # jaxpr fails because the indexing cannot succeed.
    return (pivot, perm, a)
  return lax.fori_loop(0, min(m, n), body, (pivot, perm, a))


def _lu_blocked(a, block_size=32):
  """Blocked LU decomposition, as an unrolled loop."""
  m, n = a.shape
  r = min(m, n)
  pivot = np.zeros((r,), dtype=np.int32)
  for k in range(0, r, block_size):
    b = min(r - k, block_size)
    block_pivot, perm, lu_block = _lu_unblocked(a[k:, k:k+b])
    a = ops.index_update(a, ops.index[k:, k:k+b], lu_block)

    a = ops.index_update(a, ops.index[k:, :k], a[perm + k, :k])
    pivot = ops.index_update(pivot, ops.index[k:k+b], block_pivot + k)

    if k + b < n:
      a = ops.index_update(a, ops.index[k:, k+b:], a[perm + k, k+b:])
      a = ops.index_update(
        a, ops.index[k:k+b, k+b:],
        triangular_solve(a[k:k+b, k:k+b], a[k:k+b, k+b:],
                         left_side=True, lower=True, unit_diagonal=True))
      a = ops.index_add(
        a, ops.index[k+b:, k+b:],
        -lax.dot(a[k+b:, k:k+b], a[k:k+b, k+b:],
                 precision=lax.Precision.HIGHEST))
  return pivot, a

def _lu_python(x):
  """Default LU decomposition in Python, where no better version exists."""
  m, n = x.shape[-2:]
  batch_dims = x.shape[:-2]
  if len(batch_dims) > 0:
    batch_size = onp.prod(batch_dims, dtype=onp.int64)
    pivot, lu = api.vmap(_lu_blocked)(lax.reshape(x, (batch_size, m, n)))
    pivot = lax.reshape(pivot, batch_dims + (min(m, n),))
    lu = lax.reshape(lu, batch_dims + (m, n))
  else:
    pivot, lu = _lu_blocked(x)
  return lu, pivot

def _lu_impl(operand):
  lu, pivot = xla.apply_primitive(lu_p, operand)
  return lu, pivot

def _lu_abstract_eval(operand):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2:
      raise ValueError("Argument to LU decomposition must have ndims >= 2")

    batch_dims = operand.shape[:-2]
    m = operand.shape[-2]
    n = operand.shape[-1]
    pivot = ShapedArray(batch_dims + (min(m, n),), np.int32)
  else:
    pivot = operand
  return operand, pivot

def _lu_jvp_rule(primals, tangents):
  a, = primals
  a_dot, = tangents
  lu, pivots = lu_p.bind(a)

  a_shape = np.shape(a)
  m, n = a_shape[-2:]
  dtype = lax.dtype(a)
  k = min(m, n)

  permutation = lu_pivots_to_permutation(pivots, m)
  batch_dims = a_shape[:-2]
  iotas = np.ix_(*(lax.iota(np.int32, b) for b in batch_dims + (1,)))
  x = a_dot[iotas[:-1] + (permutation, slice(None))]

  # Differentiation of Matrix Functionals Using Triangular Factorization
  # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
  #
  #     LU = A
  # ==> L'U + LU' = A'
  # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
  # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
  #     U' = triu(inv(L) . A' . inv(U)) . U

  ndims = len(a_shape)
  l_padding = [(0, 0, 0)] * ndims
  l_padding[-1] = (0, m - k, 0)
  zero = np._constant_like(lu, 0)
  l = lax.pad(np.tril(lu[..., :, :k], -1), zero, l_padding)
  l = l + np.eye(m, m, dtype=dtype)

  u_eye = lax.pad(np.eye(n - k, n - k, dtype=dtype), zero,
                  ((k, 0, 0), (k, 0, 0)))
  u_padding = [(0, 0, 0)] * ndims
  u_padding[-2] = (0, n - k, 0)
  u = lax.pad(np.triu(lu[..., :k, :]), zero, u_padding) + u_eye

  la = triangular_solve(l, x, left_side=True, transpose_a=False, lower=True,
                        unit_diagonal=True)
  lau = triangular_solve(u, la, left_side=False, transpose_a=False,
                         lower=False)

  l_dot = np.matmul(l, np.tril(lau, -1))
  u_dot = np.matmul(np.triu(lau), u)
  lu_dot = l_dot + u_dot
  return (lu, pivots), (lu_dot, ad_util.zero)
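
# The pivots are integer-valued and piecewise constant in the input, so their
# tangent in the rule above is the symbolic zero ad_util.zero.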

def _lu_batching_rule(batched_args, batch_dims):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return lu_p.bind(x), (0, 0)

def _lu_cpu_gpu_translation_rule(getrf_impl, c, operand):
  shape = c.GetShape(operand)
  batch_dims = shape.dimensions()[:-2]
  lu, pivot, info = getrf_impl(c, operand)
  # Subtract 1 from the pivot to get 0-based indices.
  pivot = c.Sub(pivot, c.ConstantS32Scalar(1))
  ok = c.Ge(info, c.ConstantS32Scalar(0))
  lu = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1, 1)), lu,
                            _nan_like(c, lu))
  return c.Tuple(lu, pivot)


lu_p = Primitive('lu')
lu_p.multiple_results = True
lu_p.def_impl(_lu_impl)
lu_p.def_abstract_eval(_lu_abstract_eval)
xla.translations[lu_p] = xla.lower_fun(_lu_python, instantiate=True)
ad.primitive_jvps[lu_p] = _lu_jvp_rule
batching.primitive_batchers[lu_p] = _lu_batching_rule

xla.backend_specific_translations['cpu'][lu_p] = partial(
    _lu_cpu_gpu_translation_rule, lapack.getrf)

xla.backend_specific_translations['gpu'][lu_p] = partial(
    _lu_cpu_gpu_translation_rule, cusolver.getrf)


# Define this outside lu_pivots_to_permutation to ensure fori_loop cache hits
def _lu_pivots_body_fn(i, permutation_and_swaps):
  permutation, swaps = permutation_and_swaps
  batch_dims = swaps.shape[:-1]
  j = swaps[..., i]
  iotas = np.ix_(*(lax.iota(np.int32, b) for b in batch_dims))
  x = permutation[..., i]
  y = permutation[iotas + (j,)]
  permutation = ops.index_update(permutation, ops.index[..., i], y)
  return ops.index_update(permutation, ops.index[iotas + (j,)], x), swaps

def lu_pivots_to_permutation(swaps, m):
  """Converts the pivots (row swaps) returned by LU to a permutation.

  We build a permutation rather than applying `swaps` directly to the rows
  of a matrix because lax loops aren't differentiable.

  Args:
    swaps: an array of shape (..., k) of row swaps to perform
    m: the size of the output permutation. m should be >= k.
  Returns:
    An int32 array of shape (..., m).
  """
  assert len(swaps.shape) >= 1
  batch_dims = swaps.shape[:-1]
  k = swaps.shape[-1]

  permutation = lax.broadcasted_iota(np.int32, batch_dims + (m,),
                                     len(batch_dims))
  result, _ = lax.fori_loop(onp.array(0, onp.int32), onp.array(k, onp.int32),
                            _lu_pivots_body_fn, (permutation, swaps))
  return result
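
# Example (illustrative): swaps = [2, 2] with m = 3 means "swap row 0 with
# row 2, then swap row 1 with row 2", giving the permutation [2, 0, 1]:
#
#   lu_pivots_to_permutation(np.array([2, 2], dtype=np.int32), 3)
#   # -> [2, 0, 1]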


# QR decomposition

def qr_impl(operand, full_matrices):
  q, r = xla.apply_primitive(qr_p, operand, full_matrices=full_matrices)
  return q, r

def qr_translation_rule(c, operand, full_matrices):
  return c.QR(operand, full_matrices=full_matrices)

def qr_abstract_eval(operand, full_matrices):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2:
      raise ValueError("Argument to QR decomposition must have ndims >= 2")
    batch_dims = operand.shape[:-2]
    m = operand.shape[-2]
    n = operand.shape[-1]
    k = m if full_matrices else min(m, n)
    q = ShapedArray(batch_dims + (m, k), operand.dtype)
    r = ShapedArray(batch_dims + (k, n), operand.dtype)
  else:
    q = operand
    r = operand
  return q, r

def qr_jvp_rule(primals, tangents, full_matrices):
  # See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.
  x, = primals
  dx, = tangents
  q, r = qr_p.bind(x, full_matrices=False)
  if full_matrices or np.shape(x)[-2] < np.shape(x)[-1]:
    raise NotImplementedError
  dx_rinv = triangular_solve(r, dx)  # Right side solve by default
  qt_dx_rinv = np.matmul(_H(q), dx_rinv)
  qt_dx_rinv_lower = np.tril(qt_dx_rinv, -1)
  domega = qt_dx_rinv_lower - _H(qt_dx_rinv_lower)  # This is skew-symmetric
  dq = np.matmul(q, domega - qt_dx_rinv) + dx_rinv
  dr = np.matmul(qt_dx_rinv - domega, r)
  return (q, r), (dq, dr)
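
# In the reference's notation: differentiating A = QR gives dA = dQ R + Q dR;
# the rule recovers dQ and dR from Q^H dA R^{-1} using the constraints that
# Q^H dQ is skew-Hermitian and that dR is upper triangular.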

def qr_batching_rule(batched_args, batch_dims, full_matrices):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return qr_p.bind(x, full_matrices=full_matrices), (0, 0)

def _qr_cpu_gpu_translation_rule(geqrf_impl, orgqr_impl, c, operand,
                                 full_matrices):
  shape = c.GetShape(operand)
  dims = shape.dimensions()
  m, n = dims[-2:]
  batch_dims = dims[:-2]
  r, tau, info_geqrf = geqrf_impl(c, operand)
  if m < n:
    q = c.Slice(r, [0] * len(dims), list(batch_dims) + [m, m])
    q, info_orgqr = orgqr_impl(c, q, tau)
  elif not full_matrices:
    q, info_orgqr = orgqr_impl(c, r, tau)
    r = c.Slice(r, [0] * len(dims), list(batch_dims) + [n, n])
  else:
    padding_config = [(0, 0, 0)] * len(dims)
    padding_config[-1] = (0, m - n, 0)
    q = c.Pad(r, c.Constant(onp.array(0, dtype=shape.element_type())),
              padding_config)
    q, info_orgqr = orgqr_impl(c, q, tau)

  ok = c.And(c.Eq(info_geqrf, c.ConstantS32Scalar(0)),
             c.Eq(info_orgqr, c.ConstantS32Scalar(0)))
  q = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1, 1)), q,
                           _nan_like(c, q))
  r = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1, 1)), r,
                           _nan_like(c, r))
  return c.Tuple(q, r)

qr_p = Primitive('qr')
qr_p.multiple_results = True
qr_p.def_impl(qr_impl)
qr_p.def_abstract_eval(qr_abstract_eval)
xla.translations[qr_p] = qr_translation_rule
ad.primitive_jvps[qr_p] = qr_jvp_rule
batching.primitive_batchers[qr_p] = qr_batching_rule

xla.backend_specific_translations['cpu'][qr_p] = partial(
    _qr_cpu_gpu_translation_rule, lapack.geqrf, lapack.orgqr)

xla.backend_specific_translations['gpu'][qr_p] = partial(
    _qr_cpu_gpu_translation_rule, cusolver.geqrf, cusolver.orgqr)


# Singular value decomposition

def svd_impl(operand, full_matrices, compute_uv):
  s, u, vt = xla.apply_primitive(svd_p, operand, full_matrices=full_matrices,
                                 compute_uv=compute_uv)
  return s, u, vt

def svd_translation_rule(c, operand, full_matrices, compute_uv):
  raise NotImplementedError(
    "Singular value decomposition is only implemented on the CPU and GPU backends")

def svd_abstract_eval(operand, full_matrices, compute_uv):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2:
      raise ValueError("Argument to singular value decomposition must have ndims >= 2")

    batch_dims = operand.shape[:-2]
    m = operand.shape[-2]
    n = operand.shape[-1]
    s = ShapedArray(batch_dims + (min(m, n),), lax.lax._complex_basetype(operand.dtype))
    u = ShapedArray(batch_dims + (m, m if full_matrices else min(m, n)), operand.dtype)
    vt = ShapedArray(batch_dims + (n if full_matrices else min(m, n), n), operand.dtype)
  else:
    raise NotImplementedError
  return s, u, vt

def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
  A, = primals
  dA, = tangents
  s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

  if full_matrices:
    # TODO: implement full matrices case, documented here:
    # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
    raise NotImplementedError(
      "Singular value decomposition JVP not implemented for full matrices")

  k = s.shape[-1]
  Ut, V = _H(U), _H(Vt)
  s_dim = s[..., None, :]
  dS = np.matmul(np.matmul(Ut, dA), V)
  ds = np.real(np.diagonal(dS, 0, -2, -1))
  F = 1 / (np.square(s_dim) - np.square(_T(s_dim)) + np.eye(k)) - np.eye(k)
  dSS = s_dim * dS
  SdS = _T(s_dim) * dS
  dU = np.matmul(U, F * (dSS + _T(dSS)))
  dV = np.matmul(V, F * (SdS + _T(SdS)))

  m, n = A.shape[-2], A.shape[-1]
  if m > n:
    dU = dU + np.matmul(np.eye(m) - np.matmul(U, Ut), np.matmul(dA, V)) / s_dim
  if n > m:
    dV = dV + np.matmul(np.eye(n) - np.matmul(V, Vt), np.matmul(_H(dA), U)) / s_dim
  return (s, U, Vt), (ds, dU, _T(dV))
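
# Here dS = U^H dA V, so ds = real(diag(dS)); dU and dV follow from the
# off-diagonal equations via F_ij = 1 / (s_j^2 - s_i^2) (zeroed on the
# diagonal), with the trailing projection terms handling m != n.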

def _svd_cpu_gpu_translation_rule(gesvd_impl, c, operand, full_matrices, compute_uv):
  shape = c.GetShape(operand)
  batch_dims = shape.dimensions()[:-2]
  s, u, vt, info = gesvd_impl(c, operand, full_matrices=full_matrices,
                              compute_uv=compute_uv)
  ok = c.Eq(info, c.ConstantS32Scalar(0))
  s = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1,)), s,
                           _nan_like(c, s))
  u = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1, 1)), u,
                           _nan_like(c, u))
  vt = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1, 1)), vt,
                            _nan_like(c, vt))
  return c.Tuple(s, u, vt)

def svd_batching_rule(batched_args, batch_dims, full_matrices, compute_uv):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  outs = svd_p.bind(x, full_matrices=full_matrices, compute_uv=compute_uv)
  return outs, (0, 0, 0)

svd_p = Primitive('svd')
svd_p.multiple_results = True
svd_p.def_impl(svd_impl)
svd_p.def_abstract_eval(svd_abstract_eval)
ad.primitive_jvps[svd_p] = svd_jvp_rule
batching.primitive_batchers[svd_p] = svd_batching_rule
xla.translations[svd_p] = svd_translation_rule

xla.backend_specific_translations['cpu'][svd_p] = partial(
    _svd_cpu_gpu_translation_rule, lapack.gesdd)

xla.backend_specific_translations['gpu'][svd_p] = partial(
    _svd_cpu_gpu_translation_rule, cusolver.gesvd)