rocm_jax/jax/lax_linalg.py
Peter Hawkins ed3e2308c1 Add support for linear algebra ops on GPU using Cusolver:
* LU decomposition
* Symmetric (Hermitian) eigendecomposition
* Singular value decomposition.

Make LU decomposition tests less sensitive to the exact decomposition; check that we have a decomposition, not precisely the same one scipy returns.
2019-08-02 11:16:15 -04:00

770 lines
27 KiB
Python

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from functools import partial
import numpy as onp
from jax.numpy import lax_numpy as np
from jax import ad_util
from jax import api
from jax import api_util
from jax import core
from jax import lax
from jax import ops
from jax.interpreters import xla
from jax.interpreters import ad
from jax.interpreters import batching
from jax.util import partial
from jax.abstract_arrays import ShapedArray
from jax.core import Primitive
from jax.lax import (standard_primitive, standard_unop, binop_dtype_rule,
_float, _complex, _input_dtype, _broadcasting_select)
from jax.lib import lapack
from jax.lib import cusolver
# traceables
def cholesky(x, symmetrize_input=True):
  """Returns the lower-triangular Cholesky factor of `x`.

  When `symmetrize_input` is true, `x` is first replaced by (x + x^H) / 2.
  The raw output of `cholesky_p` is masked to its lower triangle.
  """
  operand = symmetrize(x) if symmetrize_input else x
  factor = cholesky_p.bind(operand)
  return np.tril(factor)
def eig(x):
  """Nonsymmetric eigendecomposition; returns the triple (w, vl, vr) of `eig_p`."""
  values, vecs_l, vecs_r = eig_p.bind(x)
  return values, vecs_l, vecs_r
def eigh(x, lower=True, symmetrize_input=True):
  """Symmetric/Hermitian eigendecomposition of `x`.

  Returns the pair (v, w) produced by `eigh_p`. `v` has shape [..., n, n]
  and `w` has shape [..., n]. `lower` is forwarded to the primitive. When
  `symmetrize_input` is true, `x` is first replaced by (x + x^H) / 2.
  """
  operand = symmetrize(x) if symmetrize_input else x
  vectors, values = eigh_p.bind(operand, lower=lower)
  return vectors, values
def lu(x):
  """Pivoted LU decomposition; returns (lu, pivots) as produced by `lu_p`."""
  packed_factors, row_pivots = lu_p.bind(x)
  return packed_factors, row_pivots
def qr(x, full_matrices=True):
  """QR decomposition; returns the pair (q, r) produced by `qr_p`."""
  q_factor, r_factor = qr_p.bind(x, full_matrices=full_matrices)
  return q_factor, r_factor
def svd(x, full_matrices=True, compute_uv=True):
  """Singular value decomposition of `x`.

  Returns (u, s, v) when `compute_uv` is true, otherwise only `s`. The
  primitive yields (s, u, v); this wrapper reorders the outputs.
  """
  singular_values, u, v = svd_p.bind(
      x, full_matrices=full_matrices, compute_uv=compute_uv)
  if not compute_uv:
    return singular_values
  return u, singular_values, v
def triangular_solve(a, b, left_side=False, lower=False, transpose_a=False,
                     conjugate_a=False, unit_diagonal=False):
  """Binds `triangular_solve_p` for the triangular matrix `a` and RHS `b`.

  A truthy `conjugate_a` is kept only when `a` has a complex dtype; for real
  dtypes it becomes False. A falsy value is passed through unchanged.
  """
  if conjugate_a:
    conjugate_a = np.issubdtype(lax.dtype(a), np.complexfloating)
  return triangular_solve_p.bind(
      a, b,
      left_side=left_side,
      lower=lower,
      transpose_a=transpose_a,
      conjugate_a=conjugate_a,
      unit_diagonal=unit_diagonal)
# utilities
def _T(x): return np.swapaxes(x, -1, -2)
def _H(x): return np.conj(_T(x))
def symmetrize(x):
  """Returns the Hermitian part (x + x^H) / 2 of `x` over its last two axes."""
  x_conj_t = np.conj(np.swapaxes(x, -2, -1))
  return (x + x_conj_t) / 2
def _unpack_tuple(f, n):
def g(c, *args, **kwargs):
t = f(c, *args, **kwargs)
return (c.GetTupleElement(t, i) for i in range(n))
return g
# primitives
_cpu_lapack_types = {np.float32, np.float64, np.complex64, np.complex128}
# Cholesky decomposition
def cholesky_jvp_rule(primals, tangents):
  """Forward-mode derivative of `cholesky_p`.

  With L the lower Cholesky factor of x, computes
    L_dot = L . phi(L^{-1} . sigma_dot . L^{-H})
  where phi takes the lower triangle and halves its diagonal.

  Returns:
    (L, L_dot)
  """
  x, = primals
  sigma_dot, = tangents
  L = np.tril(cholesky_p.bind(x))

  # Forward-mode rule from https://arxiv.org/pdf/1602.07527.pdf
  def phi(X):
    # Lower triangle of X with its diagonal divided by 2 (off-diagonal
    # entries are divided by 1, diagonal entries by 1 + 1).
    l = np.tril(X)
    return l / (np._constant_like(X, 1) + np.eye(X.shape[-1], dtype=X.dtype))

  # tmp = sigma_dot . L^{-H}: right-side solve against the conjugate
  # transpose of L.
  tmp = triangular_solve(L, sigma_dot, left_side=False, transpose_a=True,
                         conjugate_a=True, lower=True)
  # L^{-1} . tmp via a left-side solve, then map through phi and multiply by L.
  L_dot = lax.batch_matmul(L, phi(triangular_solve(
      L, tmp, left_side=True, transpose_a=False, lower=True)))
  return L, L_dot
def cholesky_batching_rule(batched_args, batch_dims):
  """Batching rule for `cholesky_p`: moves the batch axis to the front.

  The operand reaching the primitive is not symmetrized again here. The
  `cholesky` wrapper has already done so before binding, and symmetrizing
  again would make the batched primitive compute something different from
  the unbatched one, as well as waste work.
  """
  x, = batched_args
  bd, = batch_dims
  x = batching.bdim_at_front(x, bd)
  return cholesky(x, symmetrize_input=False), 0
# `cholesky_p` is a unary primitive over the `_float | _complex` dtype set,
# with the JVP and batching rules defined above.
cholesky_p = standard_unop(_float | _complex, 'cholesky')
ad.primitive_jvps[cholesky_p] = cholesky_jvp_rule
batching.primitive_batchers[cholesky_p] = cholesky_batching_rule
def _nan_like(c, operand):
  """Builds an XLA array of NaNs with the same shape and element type as `operand`.

  For complex element types both the real and imaginary parts are NaN.
  """
  shape = c.GetShape(operand)
  dtype = shape.element_type()
  is_complex = onp.issubdtype(dtype, onp.complexfloating)
  fill_value = onp.nan * (1. + 1j) if is_complex else onp.nan
  scalar = c.Constant(onp.array(fill_value, dtype=dtype))
  return c.Broadcast(scalar, shape.dimensions())
# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
# 0.1.23.
# Newer jaxlib exposes `lapack.potrf` directly; older versions only provide
# `jax_potrf`, whose tuple-valued result is split into 2 outputs.
_cpu_potrf = (lapack.potrf if hasattr(lapack, "potrf")
              else _unpack_tuple(lapack.jax_potrf, 2))
def cholesky_cpu_translation_rule(c, operand):
  """CPU lowering of `cholesky_p`.

  Unbatched (2-D) operands with a LAPACK-supported element type use potrf,
  and the result becomes NaN when potrf reports a nonzero `info`. All other
  operands fall back to XLA's Cholesky HLO.
  """
  shape = c.GetShape(operand)
  dtype = shape.element_type().type
  if len(shape.dimensions()) != 2 or dtype not in _cpu_lapack_types:
    # Fall back to the HLO implementation for batched Cholesky decomposition or
    # unsupported types.
    # TODO(phawkins): support LAPACK primitives in batched mode.
    return c.Cholesky(operand)
  factor, info = _cpu_potrf(c, operand, lower=True)
  succeeded = c.Eq(info, c.ConstantS32Scalar(0))
  return c.Select(succeeded, factor, _nan_like(c, factor))
# On the CPU backend, lower `cholesky_p` with the LAPACK-backed rule above.
xla.backend_specific_translations['cpu'][cholesky_p] = cholesky_cpu_translation_rule
# Asymmetric eigendecomposition
def eig_impl(operand):
  """Eager implementation of `eig_p`, delegating to `xla.apply_primitive`."""
  result = xla.apply_primitive(eig_p, operand)
  return result
def eig_translation_rule(c, operand):
  """Default lowering for `eig_p`, used where no backend rule exists; always raises."""
  message = ("Nonsymmetric eigendecomposition is only implemented on the CPU "
             "backend")
  raise NotImplementedError(message)
def eig_abstract_eval(operand):
  """Abstract evaluation rule for `eig_p`.

  For an operand of shape [..., n, n], returns abstract values for
  (w, vl, vr) with shapes [..., n], [..., n, n] and [..., n, n].

  The eigenvalues and eigenvectors of a nonsymmetric matrix are complex in
  general, and LAPACK geev yields complex outputs even for real inputs. All
  three outputs therefore use the complex dtype matching the operand's
  precision: float32/complex64 -> complex64, float64/complex128 -> complex128.

  Raises:
    ValueError: if the operand is not a (batch of) square matrices.
  """
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
      raise ValueError("Argument to nonsymmetric eigendecomposition must have "
                       "shape [..., n, n], got shape {}".format(operand.shape))

    batch_dims = operand.shape[:-2]
    n = operand.shape[-1]
    dtype = onp.result_type(operand.dtype, onp.complex64)
    vl = vr = ShapedArray(batch_dims + (n, n), dtype)
    w = ShapedArray(batch_dims + (n,), dtype)
  else:
    w = vl = vr = operand
  return core.AbstractTuple((w, vl, vr))
# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
# 0.1.23.
# Newer jaxlib exposes `lapack.geev` directly; older versions only provide
# `jax_geev`, whose tuple-valued result is split into 4 outputs.
_cpu_geev = (lapack.geev if hasattr(lapack, "geev")
             else _unpack_tuple(lapack.jax_geev, 4))
def eig_cpu_translation_rule(c, operand):
  """CPU lowering of `eig_p` via LAPACK geev.

  Batch entries whose `info` is nonzero are replaced with NaNs in all three
  outputs (w, vl, vr).
  """
  batch_dims = c.GetShape(operand).dimensions()[:-2]
  w, vl, vr, info = _cpu_geev(c, operand)
  succeeded = c.Eq(info, c.ConstantS32Scalar(0))

  def nan_on_failure(value, trailing_ones):
    mask = c.Reshape(succeeded, None, batch_dims + trailing_ones)
    return _broadcasting_select(c, mask, value, _nan_like(c, value))

  w = nan_on_failure(w, (1,))
  vl = nan_on_failure(vl, (1, 1))
  vr = nan_on_failure(vr, (1, 1))
  return c.Tuple(w, vl, vr)
def eig_batching_rule(batched_args, batch_dims):
  """Batching rule for `eig_p`: moves the batch axis to the front and rebinds."""
  (operand,), (bdim,) = batched_args, batch_dims
  return eig_p.bind(batching.bdim_at_front(operand, bdim)), 0
# Register `eig_p`. The generic translation raises; the CPU backend lowers it
# with LAPACK geev.
eig_p = Primitive('eig')
eig_p.def_impl(eig_impl)
eig_p.def_abstract_eval(eig_abstract_eval)
xla.translations[eig_p] = eig_translation_rule
xla.backend_specific_translations['cpu'][eig_p] = eig_cpu_translation_rule
batching.primitive_batchers[eig_p] = eig_batching_rule
# Symmetric/Hermitian eigendecomposition
def eigh_impl(operand, lower):
  """Eager implementation of `eigh_p`, delegating to `xla.apply_primitive`."""
  eigvecs, eigvals = xla.apply_primitive(eigh_p, operand, lower=lower)
  return core.pack((eigvecs, eigvals))
def eigh_translation_rule(c, operand, lower):
  """Default lowering for `eigh_p`, used when no backend-specific rule exists.

  CPU (LAPACK syevd) and GPU (cuSolver syevd) register their own rules, so
  reaching this function means the backend is unsupported.

  Raises:
    NotImplementedError: always.
  """
  raise NotImplementedError(
      "Symmetric eigendecomposition is only implemented on the CPU and GPU "
      "backends")
def eigh_abstract_eval(operand, lower):
  """Abstract evaluation rule for `eigh_p`.

  For an operand of shape [..., n, n], returns abstract values for (v, w).
  `v` has shape [..., n, n] and the operand's dtype. `w` has shape [..., n]
  and the real base dtype, since the eigenvalues of a Hermitian matrix are
  real.

  Raises:
    ValueError: if the operand is not a (batch of) square matrices.
  """
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
      raise ValueError(
          "Argument to symmetric eigendecomposition must have shape [..., n, n], "
          "got shape {}".format(operand.shape))

    batch_dims = operand.shape[:-2]
    n = operand.shape[-1]
    v = ShapedArray(batch_dims + (n, n), operand.dtype)
    w = ShapedArray(batch_dims + (n,), lax.lax._complex_basetype(operand.dtype))
  else:
    v, w = operand, operand
  return core.AbstractTuple((v, w))
def _eigh_cpu_gpu_translation_rule(syevd_impl, c, operand, lower):
  """Lowers `eigh_p` using a native syevd implementation `syevd_impl`.

  Batch entries whose `info` is nonzero get NaN eigenvectors and eigenvalues.
  """
  batch_dims = c.GetShape(operand).dimensions()[:-2]
  v, w, info = syevd_impl(c, operand, lower=lower)
  succeeded = c.Eq(info, c.ConstantS32Scalar(0))

  def nan_on_failure(value, trailing_ones):
    mask = c.Reshape(succeeded, None, batch_dims + trailing_ones)
    return _broadcasting_select(c, mask, value, _nan_like(c, value))

  return c.Tuple(nan_on_failure(v, (1, 1)), nan_on_failure(w, (1,)))
def eigh_jvp_rule(primals, tangents, lower):
  # Derivative for eigh in the simplest case of distinct eigenvalues.
  # This is classic nondegenerate perturbation theory, but also see
  # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
  # The general solution treating the case of degenerate eigenvalues is
  # considerably more complicated. Ambitious readers may refer to the general
  # methods below or refer to degenerate perturbation theory in physics.
  # https://www.win.tue.nl/analysis/reports/rana06-33.pdf and
  # https://people.orie.cornell.edu/aslewis/publications/99-clarke.pdf
  #
  # Returns ((v, w), (dv, dw)). The primal `w` keeps the real dtype declared
  # by `eigh_abstract_eval`, and `dw` is real as well.
  a, = primals
  a_dot, = tangents

  v, w_real = eigh_p.bind(symmetrize(a), lower=lower)

  # for complex numbers we need eigenvalues to be full dtype of v, a:
  w = w_real.astype(a.dtype)
  eye_n = np.eye(a.shape[-1], dtype=a.dtype)
  # Carefully build the reciprocal delta-eigenvalue matrix, avoiding NaNs:
  # Fmat[..., i, j] = 1 / (w[j] - w[i]) off the diagonal and 0 on it. The
  # explicit row/column axes keep this correct when `a` has batch dimensions.
  Fmat = (np.reciprocal(eye_n + w[..., np.newaxis, :] - w[..., np.newaxis])
          - eye_n)
  # Use batched matmul when `a` carries batch dimensions.
  dot = lax.dot if a.ndim == 2 else lax.batch_matmul
  vdag_adot_v = dot(dot(_H(v), a_dot), v)
  dv = dot(v, np.multiply(Fmat, vdag_adot_v))
  # Eigenvalues of a Hermitian matrix are real, so their tangent is real too.
  # Take the diagonal over the last two axes so batch dimensions survive.
  dw = np.real(np.diagonal(vdag_adot_v, axis1=-2, axis2=-1))
  return core.pack((v, w_real)), core.pack((dv, dw))
def eigh_batching_rule(batched_args, batch_dims, lower):
  """Batching rule for `eigh_p`: moves the batch axis to the front and rebinds."""
  (operand,), (bdim,) = batched_args, batch_dims
  leading = batching.bdim_at_front(operand, bdim)
  return eigh_p.bind(leading, lower=lower), 0
# Register `eigh_p`. The generic translation raises; CPU/GPU rules are
# installed below.
eigh_p = Primitive('eigh')
eigh_p.def_impl(eigh_impl)
eigh_p.def_abstract_eval(eigh_abstract_eval)
xla.translations[eigh_p] = eigh_translation_rule
ad.primitive_jvps[eigh_p] = eigh_jvp_rule
# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
# 0.1.23.
# Newer jaxlib exposes `lapack.syevd` directly; older versions only provide
# `jax_syevd`, whose tuple-valued result is split into 3 outputs.
_cpu_syevd = (lapack.syevd if hasattr(lapack, "syevd")
              else _unpack_tuple(lapack.jax_syevd, 3))
xla.backend_specific_translations['cpu'][eigh_p] = partial(
    _eigh_cpu_gpu_translation_rule, _cpu_syevd)
# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
# 0.1.23.
if cusolver:
  # The GPU lowering via cuSolver's syevd is registered only when the
  # `cusolver` module object is truthy.
  xla.backend_specific_translations['gpu'][eigh_p] = partial(
    _eigh_cpu_gpu_translation_rule, cusolver.syevd)
batching.primitive_batchers[eigh_p] = eigh_batching_rule
# Both operands must be float or complex; the dtype logic is delegated to
# `binop_dtype_rule`.
triangular_solve_dtype_rule = partial(
    binop_dtype_rule,
    _input_dtype,
    (_float | _complex,) * 2,
    'triangular_solve')
def triangular_solve_shape_rule(a, b, left_side=False, **unused_kwargs):
  """Shape rule for `triangular_solve_p`; the result has the shape of `b`.

  Requires `a` to be a (batch of) square matrices whose batch dimensions
  equal `b`'s. The size of `a`'s last axis must match `b.shape[-2]` when
  `left_side` is set, and `b.shape[-1]` otherwise.

  Raises:
    TypeError: on any shape incompatibility.
  """
  if a.ndim < 2:
    raise TypeError(
        "triangular_solve requires a.ndim to be at least 2, got {}."
        .format(a.ndim))
  if a.shape[-2] != a.shape[-1]:
    raise TypeError(
        "triangular_solve requires the last two dimensions of a to be equal "
        "in size, got a.shape of {}.".format(a.shape))
  if a.shape[:-2] != b.shape[:-2]:
    raise TypeError(
        "triangular_solve requires both arguments to have the same number "
        "of dimensions and equal batch dimensions, got {} and {}."
        .format(a.shape, b.shape))
  b_contracted = b.shape[-2] if left_side else b.shape[-1]
  if b_contracted != a.shape[-1]:
    raise TypeError(
        "Incompatible shapes for arguments to triangular_solve: {} and {}."
        .format(a.shape, b.shape))
  return b.shape
def triangular_solve_jvp_rule_a(
    g_a, ans, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
  """JVP of `triangular_solve_p` with respect to `a`.

  The tangent `g_a` is restricted to the triangle that the solve reads (its
  diagonal is excluded when `unit_diagonal` is set). It is negated, has the
  same transpose/conjugation as `a` applied, is solved against `a`, and is
  finally multiplied with the primal answer `ans` on the appropriate side.
  """
  diag_offset = 1 if unit_diagonal else 0
  if lower:
    tangent = np.tril(g_a, k=-diag_offset)
  else:
    tangent = np.triu(g_a, k=diag_offset)
  tangent = lax.neg(tangent)
  if transpose_a:
    tangent = np.swapaxes(tangent, -1, -2)
  if conjugate_a:
    tangent = np.conj(tangent)
  solved = triangular_solve(a, tangent, left_side, lower, transpose_a,
                            conjugate_a, unit_diagonal)
  matmul = lax.dot if tangent.ndim == 2 else lax.batch_matmul
  return matmul(solved, ans) if left_side else matmul(ans, solved)
def triangular_solve_transpose_rule(
    cotangent, a, b, left_side, lower, transpose_a, conjugate_a,
    unit_diagonal):
  """Transpose rule for `triangular_solve_p` with respect to `b`.

  Triangular solve is nonlinear in its first argument and linear in its
  second, analogous to `div` but with the roles swapped. The cotangent for
  `b` is therefore another solve against `a`, with `transpose_a` flipped.
  """
  # Only `b` may be the linear (undefined) input; `a` must be known.
  assert a is not None and b is None
  flipped = not transpose_a
  b_cotangent = triangular_solve(a, cotangent, left_side, lower, flipped,
                                 conjugate_a, unit_diagonal)
  return [None, b_cotangent]
def triangular_solve_batching_rule(batched_args, batch_dims, left_side,
                                   lower, transpose_a, conjugate_a,
                                   unit_diagonal):
  """Batching rule for `triangular_solve_p`.

  Moves both operands' batch axes to the front (with `force_broadcast=True`
  for an operand that is not batched) and re-issues the solve.
  """
  a, b = batched_args
  a_bdim, b_bdim = batch_dims
  # The batch size comes from the first operand that carries a batch axis.
  batch_size = next(arg.shape[dim]
                    for arg, dim in zip(batched_args, batch_dims)
                    if dim is not None)
  a = batching.bdim_at_front(a, a_bdim, batch_size, force_broadcast=True)
  b = batching.bdim_at_front(b, b_bdim, batch_size, force_broadcast=True)
  solution = triangular_solve(a, b, left_side=left_side, lower=lower,
                              transpose_a=transpose_a,
                              conjugate_a=conjugate_a,
                              unit_diagonal=unit_diagonal)
  return solution, 0
triangular_solve_p = standard_primitive(
    triangular_solve_shape_rule, triangular_solve_dtype_rule,
    'triangular_solve')
# JVP w.r.t. `a` uses triangular_solve_jvp_rule_a. The op is linear in `b`,
# so the JVP w.r.t. `b` is simply another solve of the tangent against `a`.
ad.defjvp2(triangular_solve_p,
           triangular_solve_jvp_rule_a,
           lambda g_b, _, a, b, **kws: triangular_solve(a, g_b, **kws))
ad.primitive_transposes[triangular_solve_p] = triangular_solve_transpose_rule
batching.primitive_batchers[triangular_solve_p] = triangular_solve_batching_rule
def triangular_solve_cpu_translation_rule(
    c, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
  """CPU lowering of `triangular_solve_p`.

  Unbatched (2-D) operands with a LAPACK-supported element type call BLAS
  trsm with alpha = 1. Everything else falls back to XLA's TriangularSolve.
  """
  shape = c.GetShape(a)
  dtype = shape.element_type().type
  if len(shape.dimensions()) != 2 or dtype not in _cpu_lapack_types:
    # Fall back to the HLO implementation for batched triangular_solve or
    # unsupported types.
    # TODO(phawkins): support BLAS primitives in batched mode.
    return c.TriangularSolve(a, b, left_side, lower, transpose_a, conjugate_a,
                             unit_diagonal)
  if conjugate_a and not transpose_a:
    # Conjugation without transposition is applied explicitly to `a` rather
    # than being passed to trsm (presumably trsm has no conjugate-only mode).
    a = c.Conj(a)
    conjugate_a = False
  alpha = c.Constant(onp.array(1, dtype=dtype))
  return lapack.jax_trsm(c, alpha, a, b, left_side, lower, transpose_a,
                         conjugate_a, unit_diagonal)
# On the CPU backend, lower `triangular_solve_p` with the BLAS-backed rule above.
xla.backend_specific_translations['cpu'][triangular_solve_p] = triangular_solve_cpu_translation_rule
# LU decomposition
# Computes a pivoted LU decomposition such that
# PA = LU
# In the style of LAPACK, LU are stored in the same matrix.
def _lu_unblocked(a):
  """Unblocked LU decomposition, as a rolled loop.

  Args:
    a: a single (m, n) matrix.

  Returns:
    (pivot, perm, a, error):
      pivot: int32 vector of length min(m, n); pivot[k] is the row swapped
        with row k at step k.
      perm: int32 vector of length m holding the resulting row permutation.
      a: L (strictly lower part) and U packed together.
      error: boolean scalar, true if a zero pivot was encountered.
  """
  m, n = a.shape
  def body(k, state):
    pivot, perm, a, error = state
    m_idx = np.arange(m)
    n_idx = np.arange(n)

    # Partial pivoting: choose the row at or below k with the largest
    # magnitude in column k (|re| + |im| for complex).
    if np.issubdtype(a.dtype, np.complexfloating):
      t = a[:, k]
      magnitude = np.abs(np.real(t)) + np.abs(np.imag(t))
    else:
      magnitude = np.abs(a[:, k])
    i = np.argmax(np.where(m_idx >= k, magnitude, -np.inf))
    pivot = ops.index_update(pivot, ops.index[k], i)

    a = ops.index_update(a, ops.index[[k, i],], a[[i, k],])

    perm = ops.index_update(perm, ops.index[[i, k],], perm[[k, i],])

    # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes
    x = a[k, k]
    error = error | lax.eq(x, np._constant_like(a, 0))
    a = ops.index_update(a, ops.index[:, k],
                         np.where(m_idx > k, a[:, k] / x, a[:, k]))

    # a[k+1:, k+1:] -= np.outer(a[k+1:, k], a[k, k+1:])
    a = a - np.where((m_idx[:, None] > k) & (n_idx > k),
                     np.outer(a[:, k], a[k, :]), np.array(0, dtype=a.dtype))
    return pivot, perm, a, error

  pivot = np.zeros((min(m, n),), dtype=np.int32)
  perm = np.arange(m, dtype=np.int32)
  error = np.array(False, np.bool_)
  if m == 0 or n == 0:
    # If either dimension is empty the loop body never executes, but tracing
    # it to a jaxpr still fails because indexing into a size-0 axis cannot
    # succeed. Return the initial (already correct) state instead.
    return (pivot, perm, a, error)
  return lax.fori_loop(0, min(m, n), body, (pivot, perm, a, error))
def _lu_blocked(a, block_size=32):
  """Blocked LU decomposition, as an unrolled loop.

  Factors the single (m, n) matrix `a` in column panels of width
  `block_size`. Each panel is factored with `_lu_unblocked`, then the
  trailing submatrix is updated. The Python `for` loop is unrolled at trace
  time.

  Returns:
    (pivot, a): `pivot` is an int32 vector of length min(m, n) of row swaps
    (global row indices). `a` holds L (strictly lower part; its unit diagonal
    is implied) and U packed together. If any panel hit a zero pivot, the
    whole returned `a` is NaN.
  """
  m, n = a.shape
  r = min(m, n)
  pivot = np.zeros((r,), dtype=np.int32)
  error = np.array(False, np.bool_)
  for k in range(0, r, block_size):
    b = min(r - k, block_size)
    # Factor the current panel: rows k:, columns k:k+b.
    block_pivot, perm, lu_block, block_error = _lu_unblocked(a[k:, k:k+b])
    error = error | block_error
    a = ops.index_update(a, ops.index[k:, k:k+b], lu_block)

    # Apply the panel's row permutation to the already-factored columns on
    # the left.
    a = ops.index_update(a, ops.index[k:, :k], a[perm + k, :k])
    # Panel pivots are relative to row k; shift them to global row indices.
    pivot = ops.index_update(pivot, ops.index[k:k+b], block_pivot + k)

    if k + b < n:
      # Permute the trailing columns, compute the U block row with a
      # unit-lower triangular solve, then apply the Schur-complement update
      # a[k+b:, k+b:] -= L21 . U12.
      a = ops.index_update(a, ops.index[k:, k+b:], a[perm + k, k+b:])
      a = ops.index_update(
        a, ops.index[k:k+b, k+b:],
        triangular_solve(a[k:k+b, k:k+b], a[k:k+b, k+b:],
                         left_side=True, lower=True, unit_diagonal=True))
      a = ops.index_add(
        a, ops.index[k+b:, k+b:],
        -lax.dot(a[k+b:, k:k+b], a[k:k+b, k+b:],
                 precision=lax.Precision.HIGHEST))
  # Any zero pivot anywhere poisons the whole factorization with NaNs.
  a = np.where(error, lax.full_like(a, np.nan), a)
  return pivot, a
def _lu_python(x):
  """Default LU decomposition in Python, where no better version exists.

  Leading batch dimensions are flattened into one axis and handled with
  `api.vmap` over `_lu_blocked`, then restored. Returns the packed pair
  (lu, pivot).
  """
  batch_dims = x.shape[:-2]
  m, n = x.shape[-2:]
  if not batch_dims:
    pivot, lu = _lu_blocked(x)
    return core.pack((lu, pivot))

  batch_size = onp.prod(batch_dims, dtype=onp.int64)
  flat = lax.reshape(x, (batch_size, m, n))
  pivot, lu = api.vmap(_lu_blocked)(flat)
  pivot = lax.reshape(pivot, batch_dims + (min(m, n),))
  lu = lax.reshape(lu, batch_dims + (m, n))
  return core.pack((lu, pivot))
def _lu_impl(operand):
  """Eager implementation of `lu_p`, delegating to `xla.apply_primitive`."""
  factors, pivots = xla.apply_primitive(lu_p, operand)
  return core.pack((factors, pivots))
def _lu_abstract_eval(operand):
  """Abstract evaluation rule for `lu_p`.

  The packed LU output has the operand's abstract value. The pivots are int32
  with shape [..., min(m, n)].

  Raises:
    ValueError: if the operand has fewer than 2 dimensions.
  """
  if not isinstance(operand, ShapedArray):
    return core.AbstractTuple((operand, operand))
  if operand.ndim < 2:
    raise ValueError("Argument to LU decomposition must have ndims >= 2")
  batch_dims = operand.shape[:-2]
  m, n = operand.shape[-2], operand.shape[-1]
  pivot = ShapedArray(batch_dims + (min(m, n),), np.int32)
  return core.AbstractTuple((operand, pivot))
def _lu_jvp_rule(primals, tangents):
  """Forward-mode derivative of `lu_p`.

  Supports leading batch dimensions. Returns ((lu, pivots),
  (lu_dot, zero)); the integer pivots have a zero tangent.
  """
  a, = primals
  a_dot, = tangents
  lu, pivots = lu_p.bind(a)

  a_shape = np.shape(a)
  m, n = a_shape[-2:]
  dtype = lax.dtype(a)
  k = min(m, n)

  permutation = lu_pivots_to_permutation(pivots, m)
  # Permute the rows of each batch element by its own permutation. The
  # trailing (1,) gives the batch index arrays a singleton axis that
  # broadcasts against the m-axis of `permutation`. Without batch dimensions
  # this reduces to a_dot[permutation, :].
  batch_dims = tuple(a_shape[:-2])
  iotas = np.ix_(*(lax.iota(np.int32, b) for b in batch_dims + (1,)))
  x = a_dot[tuple(iotas[:-1]) + (permutation, slice(None))]

  # Differentiation of Matrix Functionals Using Triangular Factorization
  # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
  #
  #     LU = A
  # ==> L'U + LU' = A'
  # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
  # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
  #     U' = triu(inv(L) . A' . inv(U)) . U

  ndims = len(a_shape)
  # Build square, invertible L (m x m, unit diagonal) and U (n x n) from the
  # packed factors, padding with identity where the factors are rectangular.
  l_padding = [(0, 0, 0)] * ndims
  l_padding[-1] = (0, m - k, 0)
  zero = np._constant_like(lu, 0)
  l = lax.pad(np.tril(lu[..., :, :k], -1), zero, l_padding)
  l = l + np.eye(m, m, dtype=dtype)

  u_eye = lax.pad(np.eye(n - k, n - k, dtype=dtype), zero,
                  ((k, 0, 0), (k, 0, 0)))
  u_padding = [(0, 0, 0)] * ndims
  u_padding[-2] = (0, n - k, 0)
  u = lax.pad(np.triu(lu[..., :k, :]), zero, u_padding) + u_eye

  la = triangular_solve(l, x, left_side=True, transpose_a=False, lower=True,
                        unit_diagonal=True)
  lau = triangular_solve(u, la, left_side=False, transpose_a=False,
                         lower=False)

  l_dot = np.matmul(l, np.tril(lau, -1))
  u_dot = np.matmul(np.triu(lau), u)
  lu_dot = l_dot + u_dot
  return core.pack((lu, pivots)), ad.TangentTuple((lu_dot, ad_util.zero))
def _lu_batching_rule(batched_args, batch_dims):
  """Batching rule for `lu_p`: moves the batch axis to the front and rebinds."""
  (operand,), (bdim,) = batched_args, batch_dims
  return lu_p.bind(batching.bdim_at_front(operand, bdim)), 0
def _lu_cpu_gpu_translation_rule(getrf_impl, c, operand):
  """Lowers `lu_p` using a native getrf implementation `getrf_impl`.

  Pivots are converted to 0-based indices. Batch entries whose `info` is
  nonzero get an all-NaN LU factor.
  """
  batch_dims = c.GetShape(operand).dimensions()[:-2]
  lu, pivot, info = getrf_impl(c, operand)
  # Subtract 1 from the pivot to get 0-based indices.
  zero_based_pivot = c.Sub(pivot, c.ConstantS32Scalar(1))
  succeeded = c.Eq(info, c.ConstantS32Scalar(0))
  mask = c.Reshape(succeeded, None, batch_dims + (1, 1))
  lu = _broadcasting_select(c, mask, lu, _nan_like(c, lu))
  return c.Tuple(lu, zero_based_pivot)
lu_p = Primitive('lu')
lu_p.def_impl(_lu_impl)
lu_p.def_abstract_eval(_lu_abstract_eval)
# The generic translation lowers the pure-Python blocked algorithm
# (`_lu_python`); CPU/GPU override it with native getrf below.
xla.translations[lu_p] = xla.lower_fun(_lu_python, instantiate=True)
ad.primitive_jvps[lu_p] = _lu_jvp_rule
batching.primitive_batchers[lu_p] = _lu_batching_rule
# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
# 0.1.23.
# Newer jaxlib exposes `lapack.getrf` directly; older versions only provide
# `jax_getrf`, whose tuple-valued result is split into 3 outputs.
_cpu_getrf = (lapack.getrf if hasattr(lapack, "getrf")
              else _unpack_tuple(lapack.jax_getrf, 3))
xla.backend_specific_translations['cpu'][lu_p] = partial(
    _lu_cpu_gpu_translation_rule, _cpu_getrf)
if cusolver:
  # The GPU lowering via cuSolver's getrf is registered only when the
  # `cusolver` module object is truthy.
  xla.backend_specific_translations['gpu'][lu_p] = partial(
    _lu_cpu_gpu_translation_rule, cusolver.getrf)
def lu_pivots_to_permutation(swaps, m):
  """Converts the pivots (row swaps) returned by LU to a permutation.

  We build a permutation rather than applying `swaps` directly to the rows
  of a matrix because lax loops aren't differentiable.

  Args:
    swaps: an array of shape (..., k) of row swaps to perform; step i swaps
      entry i with entry swaps[..., i].
    m: the size of the output permutation. m should be >= k.
  Returns:
    An int32 array of shape (..., m).
  """
  assert len(swaps.shape) >= 1
  batch_dims = swaps.shape[:-1]
  k = swaps.shape[-1]

  def body_fn(i, permutation):
    # Swap entries i and swaps[..., i] of each batch element's permutation.
    j = swaps[..., i]
    # Open-mesh index arrays over the batch dimensions, so that
    # `iotas + (j,)` addresses permutation[b..., j[b...]] per batch element.
    iotas = np.ix_(*(lax.iota(np.int32, b) for b in batch_dims))
    x = permutation[..., i]
    y = permutation[iotas + (j,)]
    permutation = ops.index_update(permutation, ops.index[..., i], y)
    return ops.index_update(permutation, ops.index[iotas + (j,)], x)

  # Start from the identity permutation along the last axis.
  permutation = lax.broadcasted_iota(np.int32, batch_dims + (m,),
                                     len(batch_dims))
  return lax.fori_loop(
      onp.array(0, onp.int32), onp.array(k, onp.int32), body_fn, permutation)
# QR decomposition
def qr_impl(operand, full_matrices):
  """Eager implementation of `qr_p`, delegating to `xla.apply_primitive`."""
  q_factor, r_factor = xla.apply_primitive(
      qr_p, operand, full_matrices=full_matrices)
  return core.pack((q_factor, r_factor))
def qr_translation_rule(c, operand, full_matrices):
  """Generic lowering of `qr_p` to the XLA builder's QR operation."""
  decomposition = c.QR(operand, full_matrices=full_matrices)
  return decomposition
def qr_abstract_eval(operand, full_matrices):
  """Abstract evaluation rule for `qr_p`.

  For an operand of shape [..., m, n], q has shape [..., m, k] and r has
  shape [..., k, n], where k = m if `full_matrices` else min(m, n).

  Raises:
    ValueError: if the operand has fewer than 2 dimensions.
  """
  if not isinstance(operand, ShapedArray):
    return core.AbstractTuple((operand, operand))
  if operand.ndim < 2:
    raise ValueError("Argument to QR decomposition must have ndims >= 2")
  batch_dims = operand.shape[:-2]
  m, n = operand.shape[-2], operand.shape[-1]
  k = m if full_matrices else min(m, n)
  q = ShapedArray(batch_dims + (m, k), operand.dtype)
  r = ShapedArray(batch_dims + (k, n), operand.dtype)
  return core.AbstractTuple((q, r))
def qr_jvp_rule(primals, tangents, full_matrices):
  """Forward-mode derivative of the reduced QR decomposition (m >= n).

  Handles real and complex inputs. For complex inputs, Q^H replaces Q^T, and
  the imaginary part of the diagonal of Q^H dX R^{-1} is added to the
  skew-Hermitian term. That correction is exactly zero for real inputs.

  Raises:
    NotImplementedError: for `full_matrices=True` or wide (m < n) inputs.
  """
  # See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.
  x, = primals
  if full_matrices or np.shape(x)[-2] < np.shape(x)[-1]:
    raise NotImplementedError(
        "Unimplemented case of QR decomposition derivative")
  dx, = tangents
  n = np.shape(x)[-1]
  q, r = qr_p.bind(x, full_matrices=False)
  dx_rinv = triangular_solve(r, dx)  # Right side solve by default
  qt_dx_rinv = np.matmul(_H(q), dx_rinv)
  qt_dx_rinv_lower = np.tril(qt_dx_rinv, -1)
  domega = qt_dx_rinv_lower - _H(qt_dx_rinv_lower)  # This is skew-Hermitian
  # Correction needed for complex inputs; zero for real inputs.
  domega = domega + (np.eye(n, dtype=domega.dtype)
                     * (qt_dx_rinv - np.real(qt_dx_rinv)))
  dq = np.matmul(q, domega - qt_dx_rinv) + dx_rinv
  dr = np.matmul(qt_dx_rinv - domega, r)
  return core.pack((q, r)), core.pack((dq, dr))
def qr_batching_rule(batched_args, batch_dims, full_matrices):
  """Batching rule for `qr_p`: moves the batch axis to the front and rebinds."""
  (operand,), (bdim,) = batched_args, batch_dims
  leading = batching.bdim_at_front(operand, bdim)
  return qr_p.bind(leading, full_matrices=full_matrices), 0
# Register `qr_p`; every backend uses the generic XLA QR lowering.
qr_p = Primitive('qr')
qr_p.def_impl(qr_impl)
qr_p.def_abstract_eval(qr_abstract_eval)
xla.translations[qr_p] = qr_translation_rule
ad.primitive_jvps[qr_p] = qr_jvp_rule
batching.primitive_batchers[qr_p] = qr_batching_rule
# Singular value decomposition
def svd_impl(operand, full_matrices, compute_uv):
  """Eager implementation of `svd_p`, delegating to `xla.apply_primitive`."""
  singular_values, left, right_t = xla.apply_primitive(
      svd_p, operand, full_matrices=full_matrices, compute_uv=compute_uv)
  return core.pack((singular_values, left, right_t))
def svd_translation_rule(c, operand, full_matrices, compute_uv):
  """Default lowering for `svd_p`, used when no backend-specific rule exists.

  CPU (LAPACK gesdd) and GPU (cuSolver gesvd) register their own rules, so
  reaching this function means the backend is unsupported.

  Raises:
    NotImplementedError: always.
  """
  raise NotImplementedError(
      "Singular value decomposition is only implemented on the CPU and GPU "
      "backends")
def svd_abstract_eval(operand, full_matrices, compute_uv):
  """Abstract evaluation rule for `svd_p`.

  For an operand of shape [..., m, n] with k = min(m, n):
    s:  [..., k], with the real base dtype (singular values are real even for
        complex inputs, matching what LAPACK/cuSolver return)
    u:  [..., m, m if full_matrices else k], operand dtype
    vt: [..., n if full_matrices else k, n], operand dtype

  Raises:
    ValueError: if the operand has fewer than 2 dimensions.
  """
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2:
      raise ValueError("Argument to singular value decomposition must have ndims >= 2")

    batch_dims = operand.shape[:-2]
    m = operand.shape[-2]
    n = operand.shape[-1]
    k = min(m, n)
    s = ShapedArray(batch_dims + (k,),
                    lax.lax._complex_basetype(operand.dtype))
    u = ShapedArray(batch_dims + (m, m if full_matrices else k), operand.dtype)
    vt = ShapedArray(batch_dims + (n if full_matrices else k, n), operand.dtype)
  else:
    s = operand
    u = operand
    vt = operand
  return core.AbstractTuple((s, u, vt))
def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
  """Forward-mode derivative of the reduced SVD of a single matrix.

  Handles real and complex inputs. The rule uses `.T`, `np.dot` and
  `np.diag`, so it only supports unbatched 2-D operands. Returns
  ((s, U, Vt), (ds, dU, dVt)).

  Raises:
    NotImplementedError: if `full_matrices` is true.
  """
  if full_matrices:
    #TODO: implement full matrices case, documented here: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
    raise NotImplementedError("Singular value decomposition JVP not implemented for full matrices")

  A, = primals
  dA, = tangents
  s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

  k = s.shape[-1]
  Ut, V = np.conj(U).T, np.conj(Vt).T
  s_dim = s[..., None, :]
  dS = np.dot(np.dot(Ut, dA), V)
  ds = np.real(np.diag(dS))

  # Identity matrices in the input's precision, so x64 mode doesn't promote
  # float32 computations to float64.
  eye_k = np.eye(k, dtype=s.dtype)
  F = 1 / (np.square(s_dim) - np.square(s_dim.T) + eye_k) - eye_k
  dSS = s_dim * dS
  SdS = s_dim.T * dS

  # For complex inputs, the anti-Hermitian diagonal of dS rotates the phases
  # of the singular vectors. This term is exactly zero for real inputs.
  s_zeros = (s == 0).astype(s.dtype)
  s_inv = 1 / (s + s_zeros) - s_zeros
  dUdV_diag = .5 * (dS - np.conj(dS).T) * np.diag(s_inv)
  dU = np.dot(U, F * (dSS + np.conj(dSS).T) + dUdV_diag)
  dV = np.dot(V, F * (SdS + np.conj(SdS).T))

  m, n = A.shape[-2], A.shape[-1]
  if m > n:
    dU = dU + np.dot(np.eye(m, dtype=A.dtype) - np.dot(U, Ut),
                     np.dot(dA, V)) / s_dim
  if n > m:
    dV = dV + np.dot(np.eye(n, dtype=A.dtype) - np.dot(V, Vt),
                     np.dot(np.conj(dA).T, U)) / s_dim
  # Vt = V^H, so its tangent is dV^H.
  return core.pack((s, U, Vt)), core.pack((ds, dU, np.conj(dV).T))
def _svd_cpu_gpu_translation_rule(gesvd_impl, c, operand, full_matrices, compute_uv):
  """Lowers `svd_p` with a native SVD implementation `gesvd_impl`.

  This rule is shared by the CPU (LAPACK) and GPU (cuSolver) registrations.
  Outputs are replaced with NaNs when the native routine reports a nonzero
  `info`.

  Raises:
    NotImplementedError: for batched operands, or element types outside
      `_cpu_lapack_types`.
  """
  shape = c.GetShape(operand)
  dtype = shape.element_type().type
  if len(shape.dimensions()) == 2 and dtype in _cpu_lapack_types:
    s, u, vt, info = gesvd_impl(c, operand, full_matrices=full_matrices,
                                compute_uv=compute_uv)
    ok = c.Eq(info, c.ConstantS32Scalar(0))
    s = _broadcasting_select(c, c.Reshape(ok, None, (1,)), s,
                             _nan_like(c, s))
    u = _broadcasting_select(c, c.Reshape(ok, None, (1, 1)), u,
                             _nan_like(c, u))
    vt = _broadcasting_select(c, c.Reshape(ok, None, (1, 1)), vt,
                              _nan_like(c, vt))
    return c.Tuple(s, u, vt)
  else:
    raise NotImplementedError(
        "Only unbatched singular value decomposition of float32, float64, "
        "complex64 and complex128 matrices is implemented on CPU and GPU")
def svd_batching_rule(batched_args, batch_dims, full_matrices, compute_uv):
  """Batching rule for `svd_p`: moves the batch axis to the front and rebinds."""
  (operand,), (bdim,) = batched_args, batch_dims
  leading = batching.bdim_at_front(operand, bdim)
  outputs = svd_p.bind(leading, full_matrices=full_matrices,
                       compute_uv=compute_uv)
  return outputs, 0
# Register `svd_p`. The generic translation raises; CPU/GPU rules are
# installed below.
svd_p = Primitive('svd')
svd_p.def_impl(svd_impl)
svd_p.def_abstract_eval(svd_abstract_eval)
ad.primitive_jvps[svd_p] = svd_jvp_rule
batching.primitive_batchers[svd_p] = svd_batching_rule
xla.translations[svd_p] = svd_translation_rule
# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
# 0.1.23.
# Newer jaxlib exposes `lapack.gesdd` directly; older versions only provide
# `jax_gesdd`, whose tuple-valued result is split into 4 outputs.
_cpu_gesdd = (lapack.gesdd if hasattr(lapack, "gesdd")
              else _unpack_tuple(lapack.jax_gesdd, 4))
xla.backend_specific_translations['cpu'][svd_p] = partial(
    _svd_cpu_gpu_translation_rule, _cpu_gesdd)
# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
# 0.1.23.
if cusolver:
  # The GPU lowering via cuSolver's gesvd is registered only when the
  # `cusolver` module object is truthy.
  xla.backend_specific_translations['gpu'][svd_p] = partial(
    _svd_cpu_gpu_translation_rule, cusolver.gesvd)