# coding=utf-8
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
from jax.numpy import lax_numpy as jnp
from jax.numpy.vectorize import vectorize
from jax import ad_util
from jax import api
from jax import lax
from jax import ops
from jax import dtypes
from jax.interpreters import xla
from jax.interpreters import ad
from jax.interpreters import batching
from jax.util import partial, prod
from jax.abstract_arrays import ShapedArray
from jax.core import Primitive
from jax.lax import (standard_primitive, standard_unop, naryop_dtype_rule,
                     _float, _complex, _input_dtype, _broadcasting_select)
from jax.lib import lapack
from jax.lib import cusolver
from jax.lib import xla_client
from jax.lib import xla_bridge as xb

xops = xla_client.ops


# traceables

def cholesky(x, symmetrize_input=True):
  if symmetrize_input:
    x = symmetrize(x)
  return jnp.tril(cholesky_p.bind(x))
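
# Illustrative usage (a sketch, not a library test): for a Hermitian
# positive-definite `a`, the factor satisfies a = l @ _H(l):
#
#   a = jnp.array([[4., 2.], [2., 3.]])
#   l = cholesky(a)  # lower triangular
#   # jnp.allclose(jnp.matmul(l, _H(l)), a) holds up to rounding.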
|
Add Cholesky, QR, and Triangular solve implementations.
* Adds lax.{cholesky,triangular_solve,qr}. Adds a JVP for Cholesky.
* Adds a transpose rule for add_p, needed by the Cholesky JVP.
* Adds np.linalg.{cholesky,qr,dot,matmul,trace}.
* Adds scipy.linalg.{cholesky,qr,solve_triangular,tril,triu}.
Pair programmed with mattjj.
2018-12-13 13:03:08 -05:00
|
|
|
|
2019-05-13 15:59:58 -04:00
|
|
|
def eig(x):
|
|
|
|
w, vl, vr = eig_p.bind(x)
|
|
|
|
return w, vl, vr

def eigh(x, lower=True, symmetrize_input=True):
  if symmetrize_input:
    x = symmetrize(x)
  v, w = eigh_p.bind(x, lower=lower)
  return v, w

def lu(x):
  lu, pivots = lu_p.bind(x)
  return lu, pivots

def qr(x, full_matrices=True):
  q, r = qr_p.bind(x, full_matrices=full_matrices)
  return q, r

def svd(x, full_matrices=True, compute_uv=True):
  s, u, v = svd_p.bind(x, full_matrices=full_matrices, compute_uv=compute_uv)
  if compute_uv:
    return u, s, v
  else:
    return s

def triangular_solve(a, b, left_side=False, lower=False, transpose_a=False,
                     conjugate_a=False, unit_diagonal=False):
  conjugate_a = conjugate_a and jnp.issubdtype(lax.dtype(a), jnp.complexfloating)
  singleton = jnp.ndim(b) == jnp.ndim(a) - 1
  if singleton:
    b = jnp.expand_dims(b, -1 if left_side else -2)
  out = triangular_solve_p.bind(
      a, b, left_side=left_side, lower=lower, transpose_a=transpose_a,
      conjugate_a=conjugate_a, unit_diagonal=unit_diagonal)
  if singleton:
    out = out[..., 0] if left_side else out[..., 0, :]
  return out
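
# Illustrative usage (a sketch): with a lower-triangular `l`, solving
# l @ x == b for x mirrors scipy.linalg.solve_triangular(l, b, lower=True):
#
#   x = triangular_solve(l, b, left_side=True, lower=True)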


# utilities

def _T(x): return jnp.swapaxes(x, -1, -2)
def _H(x): return jnp.conj(_T(x))
def symmetrize(x): return (x + _H(x)) / 2

def _unpack_tuple(f, n):
  def g(c, *args, **kwargs):
    t = f(c, *args, **kwargs)
    return (xops.GetTupleElement(t, i) for i in range(n))
  return g


# primitives

_cpu_lapack_types = {np.dtype(np.float32), np.dtype(np.float64),
                     np.dtype(np.complex64), np.dtype(np.complex128)}


# Cholesky decomposition

def cholesky_jvp_rule(primals, tangents):
  x, = primals
  sigma_dot, = tangents
  L = jnp.tril(cholesky_p.bind(x))

  # Forward-mode rule from https://arxiv.org/pdf/1602.07527.pdf
  def phi(X):
    l = jnp.tril(X)
    return l / (jnp._constant_like(X, 1) + jnp.eye(X.shape[-1], dtype=X.dtype))
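
  # With sigma = L @ _H(L), the referenced rule gives
  #   L_dot = L @ phi(inv(L) @ sigma_dot @ inv(_H(L))),
  # where phi keeps the lower triangle and halves its diagonal. The two
  # triangular solves below compute sigma_dot @ inv(_H(L)) and then apply
  # inv(L) to that product.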
  tmp = triangular_solve(L, sigma_dot, left_side=False, transpose_a=True,
                         conjugate_a=True, lower=True)
  L_dot = lax.batch_matmul(L, phi(triangular_solve(
      L, tmp, left_side=True, transpose_a=False, lower=True)),
      precision=lax.Precision.HIGHEST)
  return L, L_dot


def cholesky_batching_rule(batched_args, batch_dims):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return cholesky(x), 0


cholesky_p = standard_unop(_float | _complex, 'cholesky')
ad.primitive_jvps[cholesky_p] = cholesky_jvp_rule
batching.primitive_batchers[cholesky_p] = cholesky_batching_rule
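

# Helper for the translation rules below: builds an all-NaN array shaped like
# `operand`, used in place of outputs whose LAPACK/cusolver `info` status
# signals a failed factorization.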
def _nan_like(c, operand):
  shape = c.get_shape(operand)
  dtype = shape.element_type()
  if jnp.issubdtype(dtype, np.complexfloating):
    nan = xb.constant(c, np.array(np.nan * (1. + 1j), dtype=dtype))
  else:
    nan = xb.constant(c, np.array(np.nan, dtype=dtype))
  return xops.Broadcast(nan, shape.dimensions())


def _cholesky_cpu_gpu_translation_rule(potrf_impl, c, operand):
  shape = c.get_shape(operand)
  batch_dims = shape.dimensions()[:-2]
  result, info = potrf_impl(c, operand, lower=True)
  ok = xops.Eq(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
  return _broadcasting_select(c,
                              xops.Reshape(ok, batch_dims + (1, 1)), result,
                              _nan_like(c, result))

xla.backend_specific_translations['cpu'][cholesky_p] = partial(
    _cholesky_cpu_gpu_translation_rule, lapack.potrf)

xla.backend_specific_translations['gpu'][cholesky_p] = partial(
    _cholesky_cpu_gpu_translation_rule, cusolver.potrf)


# Asymmetric eigendecomposition

def eig_impl(operand):
  return xla.apply_primitive(eig_p, operand)

def eig_translation_rule(c, operand):
  raise NotImplementedError(
    "Nonsymmetric eigendecomposition is only implemented on the CPU backend")

def eig_abstract_eval(operand):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
      raise ValueError("Argument to nonsymmetric eigendecomposition must have "
                       "shape [..., n, n], got shape {}".format(operand.shape))

    batch_dims = operand.shape[:-2]
    n = operand.shape[-1]
    dtype = np.complex64 if dtypes.finfo(operand.dtype).bits == 32 else np.complex128
    dtype = dtypes.canonicalize_dtype(dtype)
    vl = vr = ShapedArray(batch_dims + (n, n), dtype)
    w = ShapedArray(batch_dims + (n,), dtype)
  else:
    raise NotImplementedError
  return w, vl, vr

_cpu_geev = lapack.geev

def eig_cpu_translation_rule(c, operand):
  shape = c.get_shape(operand)
  batch_dims = shape.dimensions()[:-2]
  w, vl, vr, info = _cpu_geev(c, operand)
  ok = xops.Eq(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
  w = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1,)), w,
                           _nan_like(c, w))
  vl = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), vl,
                            _nan_like(c, vl))
  vr = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), vr,
                            _nan_like(c, vr))
  return xops.Tuple(c, [w, vl, vr])

def eig_batching_rule(batched_args, batch_dims):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return eig_p.bind(x), (0, 0, 0)

eig_p = Primitive('eig')
eig_p.multiple_results = True
eig_p.def_impl(eig_impl)
eig_p.def_abstract_eval(eig_abstract_eval)
xla.translations[eig_p] = eig_translation_rule
xla.backend_specific_translations['cpu'][eig_p] = eig_cpu_translation_rule
batching.primitive_batchers[eig_p] = eig_batching_rule


# Symmetric/Hermitian eigendecomposition

def eigh_impl(operand, lower):
  v, w = xla.apply_primitive(eigh_p, operand, lower=lower)
  return v, w

def eigh_translation_rule(c, operand, lower):
  shape = c.get_shape(operand)
  dims = shape.dimensions()
  if dims[-1] == 0:
    return xops.Tuple(c, [operand, xops.Reshape(operand, dims[:-1])])
  if not lower:
    n = len(dims)
    operand = xops.Transpose(operand, list(range(n - 2)) + [n - 1, n - 2])
  return xops.Tuple(c, xops.Eigh(operand))

def eigh_abstract_eval(operand, lower):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
      raise ValueError(
        "Argument to symmetric eigendecomposition must have shape [..., n, n], "
        "got shape {}".format(operand.shape))

    batch_dims = operand.shape[:-2]
    n = operand.shape[-1]
    v = ShapedArray(batch_dims + (n, n), operand.dtype)
    w = ShapedArray(batch_dims + (n,), lax.lax._complex_basetype(operand.dtype))
  else:
    v, w = operand, operand
  return v, w

def _eigh_cpu_gpu_translation_rule(syevd_impl, c, operand, lower):
  shape = c.get_shape(operand)
  batch_dims = shape.dimensions()[:-2]
  v, w, info = syevd_impl(c, operand, lower=lower)
  ok = xops.Eq(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
  v = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), v,
                           _nan_like(c, v))
  w = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1,)), w,
                           _nan_like(c, w))
  return xops.Tuple(c, [v, w])

def eigh_jvp_rule(primals, tangents, lower):
  # Derivative for eigh in the simplest case of distinct eigenvalues.
  # This is classic nondegenerate perturbation theory, but also see
  # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
  # The general solution treating the case of degenerate eigenvalues is
  # considerably more complicated. Ambitious readers may refer to the general
  # methods below or refer to degenerate perturbation theory in physics.
  # https://www.win.tue.nl/analysis/reports/rana06-33.pdf and
  # https://people.orie.cornell.edu/aslewis/publications/99-clarke.pdf
  a, = primals
  a_dot, = tangents

  v, w = eigh_p.bind(symmetrize(a), lower=lower)

  # for complex numbers we need eigenvalues to be full dtype of v, a:
  w = w.astype(a.dtype)
  eye_n = jnp.eye(a.shape[-1], dtype=a.dtype)
  # carefully build reciprocal delta-eigenvalue matrix, avoiding NaNs.
  Fmat = jnp.reciprocal(eye_n + w[..., jnp.newaxis, :] - w[..., jnp.newaxis]) - eye_n
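  # Off the diagonal, Fmat[..., i, j] = 1 / (w[j] - w[i]); the eye_n terms
  # shift the diagonal away from zero before taking the reciprocal and then
  # zero it out afterwards, which is what avoids generating NaNs.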
  # eigh impl doesn't support batch dims, but future-proof the grad.
  dot = partial(lax.dot if a.ndim == 2 else lax.batch_matmul,
                precision=lax.Precision.HIGHEST)
  vdag_adot_v = dot(dot(_H(v), a_dot), v)
  dv = dot(v, jnp.multiply(Fmat, vdag_adot_v))
  dw = jnp.diagonal(vdag_adot_v, axis1=-2, axis2=-1)
  return (v, w), (dv, dw)


def eigh_batching_rule(batched_args, batch_dims, lower):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return eigh_p.bind(x, lower=lower), (0, 0)

eigh_p = Primitive('eigh')
eigh_p.multiple_results = True
eigh_p.def_impl(eigh_impl)
eigh_p.def_abstract_eval(eigh_abstract_eval)
xla.translations[eigh_p] = eigh_translation_rule
ad.primitive_jvps[eigh_p] = eigh_jvp_rule
batching.primitive_batchers[eigh_p] = eigh_batching_rule

_cpu_syevd = lapack.syevd

xla.backend_specific_translations['cpu'][eigh_p] = partial(
    _eigh_cpu_gpu_translation_rule, _cpu_syevd)

xla.backend_specific_translations['gpu'][eigh_p] = partial(
    _eigh_cpu_gpu_translation_rule, cusolver.syevd)
|
2019-10-08 16:09:50 -04:00
|
|
|
|
2018-12-22 14:54:26 -05:00
|
|
|
|
|
|
|
|
|
|
|
|
Add Cholesky, QR, and Triangular solve implementations.
* Adds lax.{cholesky,triangular_solve,qr}. Adds a JVP for Cholesky.
* Adds a transpose rule for add_p, needed by the Cholesky JVP.
* Adds np.linalg.{cholesky,qr,dot,matmul,trace}.
* Adds scipy.linalg.{cholesky,qr,solve_triangular,tril,triu}.
Pair programmed with mattjj.
2018-12-13 13:03:08 -05:00
|
|
|
triangular_solve_dtype_rule = partial(
|
2020-01-15 13:13:11 -08:00
|
|
|
naryop_dtype_rule, _input_dtype, (_float | _complex, _float | _complex),
|
Add Cholesky, QR, and Triangular solve implementations.
* Adds lax.{cholesky,triangular_solve,qr}. Adds a JVP for Cholesky.
* Adds a transpose rule for add_p, needed by the Cholesky JVP.
* Adds np.linalg.{cholesky,qr,dot,matmul,trace}.
* Adds scipy.linalg.{cholesky,qr,solve_triangular,tril,triu}.
Pair programmed with mattjj.
2018-12-13 13:03:08 -05:00
|
|
|
'triangular_solve')
|
|
|
|
|
|
|
|
def triangular_solve_shape_rule(a, b, left_side=False, **unused_kwargs):
|
|
|
|
if a.ndim < 2:
|
|
|
|
msg = "triangular_solve requires a.ndim to be at least 2, got {}."
|
|
|
|
raise TypeError(msg.format(a.ndim))
  if b.ndim < 2:
    msg = "triangular_solve requires b.ndim to be at least 2, got {}."
    raise TypeError(msg.format(b.ndim))
  if a.shape[-1] != a.shape[-2]:
    msg = ("triangular_solve requires the last two dimensions of a to be equal "
           "in size, got a.shape of {}.")
    raise TypeError(msg.format(a.shape))
  if a.shape[:-2] != b.shape[:-2]:
    msg = ("triangular_solve requires both arguments to have the same number "
           "of dimensions and equal batch dimensions, got {} and {}.")
    raise TypeError(msg.format(a.shape, b.shape))
  common_dim = -2 if left_side else -1
  if a.shape[-1] != b.shape[common_dim]:
    msg = "Incompatible shapes for arguments to triangular_solve: {} and {}."
    raise TypeError(msg.format(a.shape, b.shape))
  return b.shape


def triangular_solve_jvp_rule_a(
    g_a, ans, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
  m, n = b.shape[-2:]
  k = 1 if unit_diagonal else 0
  g_a = jnp.tril(g_a, k=-k) if lower else jnp.triu(g_a, k=k)
  g_a = lax.neg(g_a)
  g_a = jnp.swapaxes(g_a, -1, -2) if transpose_a else g_a
  g_a = jnp.conj(g_a) if conjugate_a else g_a
  dot = partial(lax.dot if g_a.ndim == 2 else lax.batch_matmul,
                precision=lax.Precision.HIGHEST)

  def a_inverse(rhs):
    return triangular_solve(a, rhs, left_side, lower, transpose_a, conjugate_a,
                            unit_diagonal)

  # triangular_solve is about the same cost as matrix multiplication (~n^2 FLOPs
  # for matrix/vector inputs). Order these operations in whichever order is
  # cheaper.
  if left_side:
    assert g_a.shape[-2:] == a.shape[-2:] == (m, m) and ans.shape[-2:] == (m, n)
    if m > n:
      return a_inverse(dot(g_a, ans))  # A^{-1} (∂A X)
    else:
      return dot(a_inverse(g_a), ans)  # (A^{-1} ∂A) X
  else:
    assert g_a.shape[-2:] == a.shape[-2:] == (n, n) and ans.shape[-2:] == (m, n)
    if m < n:
      return a_inverse(dot(ans, g_a))  # (X ∂A) A^{-1}
    else:
      return dot(ans, a_inverse(g_a))  # X (∂A A^{-1})
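
# The rule above follows from differentiating A X = B with B held fixed:
# dA X + A dX = 0, so dX = -inv(A) dA X. The lax.neg(g_a) supplies the sign
# flip, and the tril/triu masking keeps only the triangle A actually reads.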


def triangular_solve_transpose_rule(
    cotangent, a, b, left_side, lower, transpose_a, conjugate_a,
    unit_diagonal):
  # Triangular solve is nonlinear in its first argument and linear in its second
  # argument, analogous to `div` but swapped.
  assert not ad.is_undefined_primal(a) and ad.is_undefined_primal(b)
  if type(cotangent) is ad_util.Zero:
    cotangent_b = ad_util.Zero(b.aval)
  else:
    cotangent_b = triangular_solve(a, cotangent, left_side, lower,
                                   not transpose_a, conjugate_a, unit_diagonal)
  return [None, cotangent_b]


def triangular_solve_batching_rule(batched_args, batch_dims, left_side,
                                   lower, transpose_a, conjugate_a,
                                   unit_diagonal):
  x, y = batched_args
  bx, by = batch_dims
  if bx is batching.not_mapped:
    if left_side:
      y = batching.moveaxis(y, by, -1)
      y_flat = y.reshape(y.shape[:-2] + (y.shape[-2] * y.shape[-1],))
      bdim_out = y.ndim - 1
    else:
      y = batching.moveaxis(y, by, -2)
      y_flat = y.reshape(y.shape[:-3] + (y.shape[-3] * y.shape[-2], y.shape[-1]))
      bdim_out = y.ndim - 2
    out_flat = triangular_solve(
        x, y_flat, left_side=left_side, lower=lower,
        transpose_a=transpose_a, conjugate_a=conjugate_a,
        unit_diagonal=unit_diagonal)
    return out_flat.reshape(y.shape), bdim_out
  else:
    size = next(t.shape[i] for t, i in zip(batched_args, batch_dims)
                if i is not None)
    x = batching.bdim_at_front(x, bx, size)
    y = batching.bdim_at_front(y, by, size)
    return triangular_solve(x, y, left_side=left_side, lower=lower,
                            transpose_a=transpose_a, conjugate_a=conjugate_a,
                            unit_diagonal=unit_diagonal), 0
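
# When only `b` is batched, the batch of right-hand sides is folded into the
# column (or row) dimension of a single solve: triangular solve natively
# handles multiple right-hand-side vectors, so this avoids materializing a
# batched copy of the triangular matrix.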


def _triangular_solve_translation_rule(
    c, a, b, *, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
  if conjugate_a and not transpose_a:
    a = xops.Conj(a)
    conjugate_a = False
  if not transpose_a:
    transpose = xops.TriangularSolveOptions_Transpose.NO_TRANSPOSE
  else:
    transpose = (xops.TriangularSolveOptions_Transpose.ADJOINT if conjugate_a
                 else xops.TriangularSolveOptions_Transpose.TRANSPOSE)
  return xops.TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose)


triangular_solve_p = standard_primitive(
    triangular_solve_shape_rule, triangular_solve_dtype_rule,
    'triangular_solve', translation_rule=_triangular_solve_translation_rule)
ad.defjvp2(triangular_solve_p,
           triangular_solve_jvp_rule_a,
           lambda g_b, _, a, b, **kws: triangular_solve(a, g_b, **kws))
ad.primitive_transposes[triangular_solve_p] = triangular_solve_transpose_rule
batching.primitive_batchers[triangular_solve_p] = triangular_solve_batching_rule


def _triangular_solve_cpu_translation_rule(
    c, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
  shape = c.get_shape(a)
  dtype = shape.element_type().type

  if conjugate_a and not transpose_a:
    a = xops.Conj(a)
    conjugate_a = False
  if len(shape.dimensions()) == 2 and np.dtype(dtype) in _cpu_lapack_types:
    return lapack.jax_trsm(
      c, xb.constant(c, np.array(1, dtype=dtype)),
      a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal)
  else:
    # Fall back to the HLO implementation for unsupported types or batching.
    # TODO: Consider swapping XLA for LAPACK in batched case
    if not transpose_a:
      transpose = xops.TriangularSolveOptions_Transpose.NO_TRANSPOSE
    else:
      transpose = (xops.TriangularSolveOptions_Transpose.ADJOINT if conjugate_a
                   else xops.TriangularSolveOptions_Transpose.TRANSPOSE)
    return xops.TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose)

xla.backend_specific_translations['cpu'][triangular_solve_p] = \
  _triangular_solve_cpu_translation_rule

def _triangular_solve_gpu_translation_rule(
    c, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
  shape = c.get_shape(a)
  dims = shape.dimensions()
  m, n = dims[-2:]
  batch = prod(dims[:-2])
  if conjugate_a and not transpose_a:
    a = xops.Conj(a)
    conjugate_a = False
  if batch > 1 and m <= 32 and n <= 32:
    return cusolver.trsm(
      c, a, b, left_side, lower, transpose_a,
      conjugate_a, unit_diagonal)
  else:
    # Use the XLA implementation for unbatched triangular_solve.
    if not transpose_a:
      transpose = xops.TriangularSolveOptions_Transpose.NO_TRANSPOSE
    else:
      transpose = (xops.TriangularSolveOptions_Transpose.ADJOINT if conjugate_a
                   else xops.TriangularSolveOptions_Transpose.TRANSPOSE)
    return xops.TriangularSolve(a, b, left_side, lower, unit_diagonal,
                                transpose)

xla.backend_specific_translations['gpu'][triangular_solve_p] = \
  _triangular_solve_gpu_translation_rule


# LU decomposition

# Computes a pivoted LU decomposition such that
# PA = LU
# In the style of LAPACK, LU are stored in the same matrix.

def _lu_unblocked(a):
  """Unblocked LU decomposition, as a rolled loop."""
  m, n = a.shape
  def body(k, state):
    pivot, perm, a = state
    m_idx = jnp.arange(m)
    n_idx = jnp.arange(n)

    if jnp.issubdtype(a.dtype, jnp.complexfloating):
      t = a[:, k]
      magnitude = jnp.abs(jnp.real(t)) + jnp.abs(jnp.imag(t))
    else:
      magnitude = jnp.abs(a[:, k])
    i = jnp.argmax(jnp.where(m_idx >= k, magnitude, -jnp.inf))
    pivot = ops.index_update(pivot, ops.index[k], i)

    a = ops.index_update(a, ops.index[[k, i],], a[[i, k],])

    perm = ops.index_update(perm, ops.index[[i, k],], perm[[k, i],])

    # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes
    x = a[k, k]
    a = ops.index_update(a, ops.index[:, k],
                         jnp.where(m_idx > k, a[:, k] / x, a[:, k]))

    # a[k+1:, k+1:] -= jnp.outer(a[k+1:, k], a[k, k+1:])
    a = a - jnp.where((m_idx[:, None] > k) & (n_idx > k),
                      jnp.outer(a[:, k], a[k, :]), jnp.array(0, dtype=a.dtype))
    return pivot, perm, a

  pivot = jnp.zeros((min(m, n),), dtype=jnp.int32)
  perm = jnp.arange(m, dtype=jnp.int32)
  if m == 0 and n == 0:
    # If the array is empty, the loop body never executes but tracing it to a
    # jaxpr fails because the indexing cannot succeed.
    return (pivot, perm, a)
  return lax.fori_loop(0, min(m, n), body, (pivot, perm, a))


def _lu_blocked(a, block_size=128):
  """Blocked LU decomposition, as an unrolled loop."""
  m, n = a.shape
  r = min(m, n)
  pivot = jnp.zeros((r,), dtype=jnp.int32)
  for k in range(0, r, block_size):
    b = min(r - k, block_size)
    block_pivot, perm, lu_block = _lu_unblocked(a[k:, k:k+b])

    a = ops.index_update(a, ops.index[k:, :], a[perm + k, :])
    a = ops.index_update(a, ops.index[k:, k:k+b], lu_block)
    pivot = ops.index_update(pivot, ops.index[k:k+b], block_pivot + k)

    if k + b < n:
      a = ops.index_update(
        a, ops.index[k:k+b, k+b:],
        triangular_solve(a[k:k+b, k:k+b], a[k:k+b, k+b:],
                         left_side=True, lower=True, unit_diagonal=True))
      a = ops.index_add(
        a, ops.index[k+b:, k+b:],
        -lax.dot(a[k+b:, k:k+b], a[k:k+b, k+b:],
                 precision=lax.Precision.HIGHEST))
  return pivot, a
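
# Each iteration factors a tall panel with _lu_unblocked, applies the panel's
# row swaps to the trailing columns, solves a triangular system for the block
# row of U, and then applies a rank-b Schur-complement update to the trailing
# submatrix, so most of the arithmetic lands in one large matmul per block.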


def _lu_python(x):
  """Default LU decomposition in Python, where no better version exists."""
  m, n = x.shape[-2:]
  batch_dims = x.shape[:-2]
  if len(batch_dims) > 0:
    batch_size = np.prod(batch_dims, dtype=np.int64)
    pivot, lu = api.vmap(_lu_blocked)(lax.reshape(x, (batch_size, m, n)))
    pivot = lax.reshape(pivot, batch_dims + (min(m, n),))
    lu = lax.reshape(lu, batch_dims + (m, n))
  else:
    pivot, lu = _lu_blocked(x)
  return lu, pivot

def _lu_impl(operand):
  lu, pivot = xla.apply_primitive(lu_p, operand)
  return lu, pivot

def _lu_abstract_eval(operand):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2:
      raise ValueError("Argument to LU decomposition must have ndims >= 2")

    batch_dims = operand.shape[:-2]
    m = operand.shape[-2]
    n = operand.shape[-1]
    pivot = ShapedArray(batch_dims + (min(m, n),), jnp.int32)
  else:
    pivot = operand
  return operand, pivot

def _lu_jvp_rule(primals, tangents):
  a, = primals
  a_dot, = tangents
  lu, pivots = lu_p.bind(a)

  a_shape = jnp.shape(a)
  m, n = a_shape[-2:]
  dtype = lax.dtype(a)
  k = min(m, n)

  permutation = lu_pivots_to_permutation(pivots, m)
  batch_dims = a_shape[:-2]
  iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1,)))
  x = a_dot[iotas[:-1] + (permutation, slice(None))]

  # Differentiation of Matrix Functionals Using Triangular Factorization
  # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
  #
  #     LU = A
  # ==> L'U + LU' = A'
  # ==> inv(L) . L' + U' . inv(U) = inv(L) . A' . inv(U)
  # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
  #     U' = triu(inv(L) . A' . inv(U)) . U

  ndims = len(a_shape)
  l_padding = [(0, 0, 0)] * ndims
  l_padding[-1] = (0, m - k, 0)
  zero = jnp._constant_like(lu, 0)
  l = lax.pad(jnp.tril(lu[..., :, :k], -1), zero, l_padding)
  l = l + jnp.eye(m, m, dtype=dtype)

  u_eye = lax.pad(jnp.eye(n - k, n - k, dtype=dtype), zero,
                  ((k, 0, 0), (k, 0, 0)))
  u_padding = [(0, 0, 0)] * ndims
  u_padding[-2] = (0, n - k, 0)
  u = lax.pad(jnp.triu(lu[..., :k, :]), zero, u_padding) + u_eye

  la = triangular_solve(l, x, left_side=True, transpose_a=False, lower=True,
                        unit_diagonal=True)
  lau = triangular_solve(u, la, left_side=False, transpose_a=False,
                         lower=False)

  l_dot = jnp.matmul(l, jnp.tril(lau, -1))
  u_dot = jnp.matmul(jnp.triu(lau), u)
  lu_dot = l_dot + u_dot
  return (lu, pivots), (lu_dot, ad_util.Zero.from_value(pivots))


def _lu_batching_rule(batched_args, batch_dims):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return lu_p.bind(x), (0, 0)

def _lu_cpu_gpu_translation_rule(getrf_impl, c, operand):
  shape = c.get_shape(operand)
  batch_dims = shape.dimensions()[:-2]
  lu, pivot, info = getrf_impl(c, operand)
  # Subtract 1 from the pivot to get 0-based indices.
  pivot = xops.Sub(pivot, xops.ConstantLiteral(c, np.array(1, np.int32)))
  ok = xops.Ge(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
  lu = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), lu,
                            _nan_like(c, lu))
  return xops.Tuple(c, [lu, pivot])


lu_p = Primitive('lu')
lu_p.multiple_results = True
lu_p.def_impl(_lu_impl)
lu_p.def_abstract_eval(_lu_abstract_eval)
xla.translations[lu_p] = xla.lower_fun(_lu_python)
ad.primitive_jvps[lu_p] = _lu_jvp_rule
batching.primitive_batchers[lu_p] = _lu_batching_rule

xla.backend_specific_translations['cpu'][lu_p] = partial(
    _lu_cpu_gpu_translation_rule, lapack.getrf)

xla.backend_specific_translations['gpu'][lu_p] = partial(
    _lu_cpu_gpu_translation_rule, cusolver.getrf)


# Define this outside lu_pivots_to_permutation to ensure fori_loop cache hits
def _lu_pivots_body_fn(i, permutation_and_swaps):
  permutation, swaps = permutation_and_swaps
  batch_dims = swaps.shape[:-1]
  j = swaps[..., i]
  iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims))
  x = permutation[..., i]
  y = permutation[iotas + (j,)]
  permutation = ops.index_update(permutation, ops.index[..., i], y)
  return ops.index_update(permutation, ops.index[iotas + (j,)], x), swaps
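
# Each step exchanges entries i and swaps[..., i] of the permutation, batched
# over the leading dimensions via the ix_/iota index grids.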


@partial(api.jit, static_argnums=(1,))
def lu_pivots_to_permutation(swaps, m):
  """Converts the pivots (row swaps) returned by LU to a permutation.

  We build a permutation rather than applying `swaps` directly to the rows
  of a matrix because lax loops aren't differentiable.

  Args:
    swaps: an array of shape (..., k) of row swaps to perform
    m: the size of the output permutation. m should be >= k.
  Returns:
    An int32 array of shape (..., m).
  """
  assert len(swaps.shape) >= 1
  batch_dims = swaps.shape[:-1]
  k = swaps.shape[-1]

  permutation = lax.broadcasted_iota(jnp.int32, batch_dims + (m,),
                                     len(batch_dims))
  result, _ = lax.fori_loop(np.array(0, np.int32), np.array(k, np.int32),
                            _lu_pivots_body_fn, (permutation, swaps))
  return result


@partial(vectorize, excluded={3}, signature='(n,n),(n),(n,k)->(n,k)')
def _lu_solve_core(lu, pivots, b, trans):
  m = lu.shape[0]
  permutation = lu_pivots_to_permutation(pivots, m)
  x = jnp.reshape(b, (m, -1))
  if trans == 0:
    x = x[permutation, :]
    x = triangular_solve(lu, x, left_side=True, lower=True, unit_diagonal=True)
    x = triangular_solve(lu, x, left_side=True, lower=False)
  elif trans == 1 or trans == 2:
    conj = trans == 2
    x = triangular_solve(lu, x, left_side=True, lower=False, transpose_a=True,
                         conjugate_a=conj)
    x = triangular_solve(lu, x, left_side=True, lower=True, unit_diagonal=True,
                         transpose_a=True, conjugate_a=conj)
    x = x[jnp.argsort(permutation), :]
  else:
    raise ValueError("'trans' value must be 0, 1, or 2, got {}".format(trans))
  return lax.reshape(x, b.shape)
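
# `trans` follows the LAPACK/scipy convention: 0 solves a x = b, 1 solves
# a.T x = b, and 2 solves a.conj().T x = b, where a is the factored matrix;
# the permutation is applied on the side matching the chosen form.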


@partial(api.jit, static_argnums=(3,))
def _lu_solve(lu, pivots, b, trans):
  if len(lu.shape) < 2 or lu.shape[-1] != lu.shape[-2]:
    raise ValueError("last two dimensions of LU decomposition must be equal, "
                     "got shape {}".format(lu.shape))
  if len(b.shape) < 1:
    raise ValueError("b matrix must have rank >= 1, got shape {}"
                     .format(b.shape))
  # Broadcasting follows NumPy's convention for linalg.solve: the RHS is
  # treated as a (batched) vector if the number of dimensions differ by 1.
  # Otherwise, broadcasting rules apply.
  rhs_vector = lu.ndim == b.ndim + 1
  if rhs_vector:
    if b.shape[-1] != lu.shape[-1]:
      raise ValueError("When LU decomposition matrix and b have the same "
                       "number of dimensions, last axis of LU decomposition "
                       "matrix (shape {}) and b array (shape {}) must match"
                       .format(lu.shape, b.shape))
    b = b[..., jnp.newaxis]
  else:
    if b.shape[-2] != lu.shape[-1]:
      raise ValueError("When LU decomposition matrix and b have different "
                       "numbers of dimensions, last axis of LU decomposition "
                       "matrix (shape {}) and second to last axis of b array "
                       "(shape {}) must match"
                       .format(lu.shape, b.shape))
  x = _lu_solve_core(lu, pivots, b, trans)
  return x[..., 0] if rhs_vector else x


def lu_solve(lu, pivots, b, trans=0):
  """LU solve with broadcasting."""
  return _lu_solve(lu, pivots, b, trans)


# QR decomposition
|
|
|
|
|
Add Cholesky, QR, and Triangular solve implementations.
* Adds lax.{cholesky,triangular_solve,qr}. Adds a JVP for Cholesky.
* Adds a transpose rule for add_p, needed by the Cholesky JVP.
* Adds np.linalg.{cholesky,qr,dot,matmul,trace}.
* Adds scipy.linalg.{cholesky,qr,solve_triangular,tril,triu}.
Pair programmed with mattjj.
2018-12-13 13:03:08 -05:00
|
|
|
def qr_impl(operand, full_matrices):
|
|
|
|
q, r = xla.apply_primitive(qr_p, operand, full_matrices=full_matrices)
|
2019-07-27 15:46:14 -07:00
|
|
|
return q, r


def qr_translation_rule(c, operand, full_matrices):
  return xops.Tuple(c, xops.QR(operand, full_matrices))


def qr_abstract_eval(operand, full_matrices):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2:
      raise ValueError("Argument to QR decomposition must have ndims >= 2")
    batch_dims = operand.shape[:-2]
    m = operand.shape[-2]
    n = operand.shape[-1]
    k = m if full_matrices else min(m, n)
    q = ShapedArray(batch_dims + (m, k), operand.dtype)
    r = ShapedArray(batch_dims + (k, n), operand.dtype)
  else:
    q = operand
    r = operand
  return q, r
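
# Shape example (illustrative): for a batched operand of shape (4, 5, 3),
# m = 5 and n = 3, so the reduced factorization (full_matrices=False) gives
# q: (4, 5, 3) and r: (4, 3, 3), while full_matrices=True gives q: (4, 5, 5)
# and r: (4, 5, 3).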


def qr_jvp_rule(primals, tangents, full_matrices):
  # See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.
  x, = primals
  dx, = tangents
  q, r = qr_p.bind(x, full_matrices=False)
  *_, m, n = x.shape
  if full_matrices or m < n:
    raise NotImplementedError(
        "Unimplemented case of QR decomposition derivative")
  dx_rinv = triangular_solve(r, dx)  # Right side solve by default
  qt_dx_rinv = jnp.matmul(_H(q), dx_rinv)
  qt_dx_rinv_lower = jnp.tril(qt_dx_rinv, -1)
  do = qt_dx_rinv_lower - _H(qt_dx_rinv_lower)  # This is skew-Hermitian
                                                # (skew-symmetric for real x)
  # The following correction is necessary for complex inputs
  do = do + jnp.eye(n, dtype=do.dtype) * (qt_dx_rinv - jnp.real(qt_dx_rinv))
  dq = jnp.matmul(q, do - qt_dx_rinv) + dx_rinv
  dr = jnp.matmul(qt_dx_rinv - do, r)
  return (q, r), (dq, dr)
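
# Example check (a sketch, not part of the module): the rule above is
# exercised through jnp.linalg.qr under jax.jvp, and can be compared against
# a finite-difference estimate:
#
#   import jax
#   import jax.numpy as jnp
#   x = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
#   dx = 0.1 * jnp.ones_like(x)
#   (q, r), (dq, dr) = jax.jvp(jnp.linalg.qr, (x,), (dx,))
#   eps = 1e-4
#   q2, r2 = jnp.linalg.qr(x + eps * dx)
#   # q2 is approximately q + eps * dq; r2 is approximately r + eps * dr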


def qr_batching_rule(batched_args, batch_dims, full_matrices):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return qr_p.bind(x, full_matrices=full_matrices), (0, 0)
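
# Example (illustrative): the batching rule is what lets vmap map QR over a
# stack of matrices:
#
#   import jax
#   import jax.numpy as jnp
#   xs = jnp.stack([jnp.eye(3), 2. * jnp.eye(3)])
#   qs, rs = jax.vmap(jnp.linalg.qr)(xs)   # qs: (2, 3, 3), rs: (2, 3, 3)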


def _qr_cpu_gpu_translation_rule(geqrf_impl, orgqr_impl, c, operand,
                                 full_matrices):
  shape = c.get_shape(operand)
  dims = shape.dimensions()
  m, n = dims[-2:]
  batch_dims = dims[:-2]
  r, tau, info_geqrf = geqrf_impl(c, operand)
  if m < n:
    q = xops.Slice(r, [0] * len(dims), list(batch_dims) + [m, m],
                   [1] * len(dims))
    q, info_orgqr = orgqr_impl(c, q, tau)
  elif not full_matrices:
    q, info_orgqr = orgqr_impl(c, r, tau)
    r = xops.Slice(r, [0] * len(dims), list(batch_dims) + [n, n],
                   [1] * len(dims))
  else:
    padding_config = [(0, 0, 0)] * len(dims)
    padding_config[-1] = (0, m - n, 0)
    q = xops.Pad(r, xops.Constant(c, np.array(0, dtype=shape.element_type())),
                 xla_client.make_padding_config(padding_config))
    q, info_orgqr = orgqr_impl(c, q, tau)

  ok = xops.And(
      xops.Eq(info_geqrf, xops.ConstantLiteral(c, np.array(0, np.int32))),
      xops.Eq(info_orgqr, xops.ConstantLiteral(c, np.array(0, np.int32))))
  q = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), q,
                           _nan_like(c, q))
  r = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), r,
                           _nan_like(c, r))
  r = xla.lower_fun(jnp.triu, multiple_results=False)(c, r)
  return xops.Tuple(c, [q, r])
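
# On CPU/GPU the factorization runs in two LAPACK-style steps: *geqrf
# returns R in the upper triangle plus Householder reflectors packed below
# the diagonal (with their tau coefficients), and *orgqr expands the
# reflectors into an explicit Q. The slicing/padding above arranges the
# reflector matrix into the shape orgqr expects in each of the three cases.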


qr_p = Primitive('qr')
qr_p.multiple_results = True
qr_p.def_impl(qr_impl)
qr_p.def_abstract_eval(qr_abstract_eval)
xla.translations[qr_p] = qr_translation_rule
ad.primitive_jvps[qr_p] = qr_jvp_rule
batching.primitive_batchers[qr_p] = qr_batching_rule

xla.backend_specific_translations['cpu'][qr_p] = partial(
    _qr_cpu_gpu_translation_rule, lapack.geqrf, lapack.orgqr)

xla.backend_specific_translations['gpu'][qr_p] = partial(
    _qr_cpu_gpu_translation_rule, cusolver.geqrf, cusolver.orgqr)
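
# Example usage (a sketch): the primitive is reachable through
# jnp.linalg.qr, and differentiates and batches via the rules registered
# above:
#
#   import jax.numpy as jnp
#   a = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
#   q, r = jnp.linalg.qr(a)                             # q (3, 2), r (2, 2)
#   q_full, r_full = jnp.linalg.qr(a, mode='complete')  # q (3, 3), r (3, 2)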


# Singular value decomposition

def svd_impl(operand, full_matrices, compute_uv):
  s, u, vt = xla.apply_primitive(svd_p, operand, full_matrices=full_matrices,
                                 compute_uv=compute_uv)
  return s, u, vt


def svd_translation_rule(c, operand, full_matrices, compute_uv):
  shape = c.get_shape(operand).dimensions()
  m, n = shape[-2:]
  u, s, v = xops.SVD(operand)
  permutation = list(range(len(shape)))
  permutation[-1], permutation[-2] = permutation[-2], permutation[-1]
  vt = xops.Transpose(v, permutation)
  if not full_matrices and m != n:
    u = xops.SliceInDim(u, 0, min(m, n), stride=1, dimno=len(shape) - 1)
    vt = xops.SliceInDim(vt, 0, min(m, n), stride=1, dimno=len(shape) - 2)
  return xops.Tuple(c, [s, u, vt])


def svd_abstract_eval(operand, full_matrices, compute_uv):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2:
      raise ValueError("Argument to singular value decomposition must have "
                       "ndims >= 2")
    batch_dims = operand.shape[:-2]
    m = operand.shape[-2]
    n = operand.shape[-1]
    s = ShapedArray(batch_dims + (min(m, n),),
                    lax.lax._complex_basetype(operand.dtype))
    u = ShapedArray(batch_dims + (m, m if full_matrices else min(m, n)),
                    operand.dtype)
    vt = ShapedArray(batch_dims + (n if full_matrices else min(m, n), n),
                     operand.dtype)
  else:
    raise NotImplementedError
  return s, u, vt
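
# Shape example (illustrative): for operand shape (2, 5, 3), the singular
# values s have shape (2, 3); with full_matrices=False, u is (2, 5, 3) and
# vt is (2, 3, 3); with full_matrices=True, u is (2, 5, 5) and vt is
# (2, 3, 3). The dtype of s is always real, even for complex operands.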


def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
  A, = primals
  dA, = tangents
  s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

  if compute_uv and full_matrices:
    # TODO: implement full matrices case, documented here: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
    raise NotImplementedError(
        "Singular value decomposition JVP not implemented for full matrices")

  k = s.shape[-1]
  Ut, V = _H(U), _H(Vt)
  s_dim = s[..., None, :]
  dS = jnp.matmul(jnp.matmul(Ut, dA), V)
  ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))
  F = 1 / (jnp.square(s_dim) - jnp.square(_T(s_dim)) +
           jnp.eye(k, dtype=A.dtype))
  F = F - jnp.eye(k, dtype=A.dtype)
  dSS = s_dim * dS
  SdS = _T(s_dim) * dS
  dU = jnp.matmul(U, F * (dSS + _T(dSS)))
  dV = jnp.matmul(V, F * (SdS + _T(SdS)))

  m, n = A.shape[-2:]
  if m > n:
    dU = dU + jnp.matmul(jnp.eye(m, dtype=A.dtype) - jnp.matmul(U, Ut),
                         jnp.matmul(dA, V)) / s_dim
  if n > m:
    dV = dV + jnp.matmul(jnp.eye(n, dtype=A.dtype) - jnp.matmul(V, Vt),
                         jnp.matmul(_H(dA), U)) / s_dim
  return (s, U, Vt), (ds, dU, _T(dV))
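
# A brief sketch of where F comes from: writing A = U S V^H and
# differentiating gives dS = U^H dA V, whose (real) diagonal yields ds. The
# rotations of U and V are recovered from the off-diagonal entries by
# solving the 2x2 systems that couple entries (i, j) and (j, i), which
# introduces the factor F_ij = 1 / (s_j^2 - s_i^2) for i != j; the eye()
# terms above avoid a divide-by-zero on the diagonal, which is then masked
# out.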


def _svd_cpu_gpu_translation_rule(gesvd_impl, c, operand, full_matrices,
                                  compute_uv):
  shape = c.get_shape(operand)
  batch_dims = shape.dimensions()[:-2]
  s, u, vt, info = gesvd_impl(c, operand,
                              full_matrices=full_matrices,
                              compute_uv=compute_uv)
  ok = xops.Eq(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
  s = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1,)), s,
                           _nan_like(c, s))
  u = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), u,
                           _nan_like(c, u))
  vt = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), vt,
                            _nan_like(c, vt))
  return xops.Tuple(c, [s, u, vt])


def svd_batching_rule(batched_args, batch_dims, full_matrices, compute_uv):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  outs = svd_p.bind(x, full_matrices=full_matrices, compute_uv=compute_uv)
  return outs, (0, 0, 0)


svd_p = Primitive('svd')
svd_p.multiple_results = True
svd_p.def_impl(svd_impl)
svd_p.def_abstract_eval(svd_abstract_eval)
ad.primitive_jvps[svd_p] = svd_jvp_rule
batching.primitive_batchers[svd_p] = svd_batching_rule
xla.translations[svd_p] = svd_translation_rule

xla.backend_specific_translations['cpu'][svd_p] = partial(
    _svd_cpu_gpu_translation_rule, lapack.gesdd)

xla.backend_specific_translations['gpu'][svd_p] = partial(
    _svd_cpu_gpu_translation_rule, cusolver.gesvd)
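
# Example usage (a sketch): the primitive backs jnp.linalg.svd, with
# full_matrices and compute_uv exposed directly:
#
#   import jax.numpy as jnp
#   a = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
#   u, s, vt = jnp.linalg.svd(a, full_matrices=False)
#   s_only = jnp.linalg.svd(a, compute_uv=False)
#   # reconstruction: a is approximately (u * s) @ vt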