mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
404 lines
14 KiB
Python
404 lines
14 KiB
Python
# Copyright 2018 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from functools import partial
|
|
|
|
import scipy.linalg
|
|
import textwrap
|
|
|
|
from jax import jit
|
|
from .. import lax
|
|
from .. import lax_linalg
|
|
from ..numpy.lax_numpy import _wraps
|
|
from ..numpy import lax_numpy as np
|
|
from ..numpy import linalg as np_linalg
|
|
|
|
def _T(x):
  """Swaps the last two axes of ``x`` (a batched matrix transpose)."""
  return np.swapaxes(x, -1, -2)
|
|
|
|
@partial(jit, static_argnums=(1,))
def _cholesky(a, lower):
  """Returns the Cholesky factor of ``a``, lower or upper per ``lower``.

  The input is promoted to an inexact dtype and factored with
  ``lax_linalg.cholesky(..., symmetrize_input=False)``. The upper factor is
  obtained by factoring the conjugate transpose of ``a`` and conjugate
  transposing the resulting lower factor.
  """
  a = np_linalg._promote_arg_dtypes(np.asarray(a))
  if lower:
    return lax_linalg.cholesky(a, symmetrize_input=False)
  lower_factor = lax_linalg.cholesky(np.conj(_T(a)), symmetrize_input=False)
  return np.conj(_T(lower_factor))
|
|
|
|
@_wraps(scipy.linalg.cholesky)
def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
  # overwrite_a / check_finite exist only for SciPy signature parity and are
  # ignored. `lower` is passed positionally because _cholesky marks it static.
  del overwrite_a, check_finite
  return _cholesky(a, lower)
|
|
|
|
@_wraps(scipy.linalg.cho_factor)
def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
  # Pair the Cholesky factor with the flag recording which triangle it is;
  # cho_solve unpacks exactly this (factor, lower) tuple. The SciPy-only
  # keyword arguments are ignored.
  del overwrite_a, check_finite
  factor = cholesky(a, lower=lower)
  return factor, lower
|
|
|
|
@partial(jit, static_argnums=(2,))
def _cho_solve(c, b, lower):
  """Solves ``A x = b`` given a Cholesky factor ``c`` of ``A``.

  Args:
    c: array of shape (..., m, m) holding the factor. With ``lower`` true it is
      used as the lower factor (``A = c c^H``), otherwise as the upper factor
      (``A = c^H c``).
    b: right-hand side of shape (..., m, k) or (..., m).
    lower: static flag selecting the orientation of ``c``.

  Returns:
    ``x`` with the same shape as ``b``.

  Raises:
    ValueError: if ``c`` is not a (batch of) square matrix, or the rank of
      ``b`` is neither equal to nor one less than the rank of ``c``.
  """
  c, b = np_linalg._promote_arg_dtypes(np.asarray(c), np.asarray(b))
  c_shape, b_shape = np.shape(c), np.shape(b)
  c_ndims, b_ndims = len(c_shape), len(b_shape)
  c_is_square = c_ndims >= 2 and c_shape[-1] == c_shape[-2]
  if not c_is_square or b_ndims not in (c_ndims, c_ndims - 1):
    msg = ("The arguments to solve must have shapes a=[..., m, m] and "
           "b=[..., m, k] or b=[..., m]; got a={} and b={}")
    raise ValueError(msg.format(c_shape, b_shape))

  # TODO(phawkins): triangular_solve only supports matrices on the RHS, so a
  # vector b gets a dummy trailing dimension that is dropped at the end.
  b_is_vector = b_ndims != c_ndims
  x = b[..., None] if b_is_vector else b
  # Two substitutions: first with the factor on the left of A (c when lower,
  # c^H otherwise), then with its conjugate transpose.
  for adjoint in (not lower, lower):
    x = lax_linalg.triangular_solve(c, x, left_side=True, lower=lower,
                                    transpose_a=adjoint, conjugate_a=adjoint)
  return x[..., 0] if b_is_vector else x
|
|
|
|
@_wraps(scipy.linalg.cho_solve, update_doc=False)
def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
  # c_and_lower is the (factor, lower) pair produced by cho_factor. The
  # SciPy-only keyword arguments are ignored.
  del overwrite_b, check_finite
  factor, is_lower = c_and_lower
  return _cho_solve(factor, b, is_lower)
|
|
|
|
@_wraps(scipy.linalg.svd)
def svd(a, full_matrices=True, compute_uv=True, overwrite_a=False,
        check_finite=True, lapack_driver='gesdd'):
  # The LAPACK driver choice and the in-place / finiteness options have no
  # JAX equivalent and are ignored.
  del overwrite_a, check_finite, lapack_driver
  promoted = np_linalg._promote_arg_dtypes(np.asarray(a))
  return lax_linalg.svd(promoted, full_matrices, compute_uv)
|
|
|
|
|
|
@_wraps(scipy.linalg.det)
def det(a, overwrite_a=False, check_finite=True):
  # Delegates to the jax.numpy.linalg determinant. The SciPy-only keyword
  # arguments are ignored.
  del overwrite_a, check_finite
  return np_linalg.det(a)
|
|
|
|
|
|
@_wraps(scipy.linalg.eigh)
def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
         overwrite_b=False, turbo=True, eigvals=None, type=1,
         check_finite=True):
  # Only the standard Hermitian eigenproblem is supported. The generalized
  # problem (b), other problem types and eigenvalue subsets are rejected.
  del overwrite_a, overwrite_b, turbo, check_finite
  if b is not None:
    raise NotImplementedError("Only the b=None case of eigh is implemented")
  if type != 1:
    raise NotImplementedError("Only the type=1 case of eigh is implemented.")
  if eigvals is not None:
    raise NotImplementedError(
        "Only the eigvals=None case of eigh is implemented.")

  promoted = np_linalg._promote_arg_dtypes(np.asarray(a))
  # lax_linalg.eigh yields (eigenvectors, eigenvalues). SciPy's order is
  # eigenvalues first.
  vecs, vals = lax_linalg.eigh(promoted, lower=lower)
  return vals if eigvals_only else (vals, vecs)
|
|
|
|
|
|
|
|
@_wraps(scipy.linalg.inv)
def inv(a, overwrite_a=False, check_finite=True):
  # Delegates to the jax.numpy.linalg inverse. The SciPy-only keyword
  # arguments are ignored.
  del overwrite_a, check_finite
  return np_linalg.inv(a)
|
|
|
|
|
|
@_wraps(scipy.linalg.lu_factor)
def lu_factor(a, overwrite_a=False, check_finite=True):
  # Returns lax_linalg.lu of the promoted input, i.e. the (lu, pivots) pair
  # that lu_solve unpacks. The SciPy-only keyword arguments are ignored.
  del overwrite_a, check_finite
  return lax_linalg.lu(np_linalg._promote_arg_dtypes(np.asarray(a)))
|
|
|
|
@partial(jit, static_argnums=(3,))
def _lu_solve(lu, pivots, b, trans):
  """Solves a linear system from a packed LU factorization.

  Args:
    lu: square (m, m) matrix packing a unit-diagonal lower factor L (strictly
      below the diagonal) and an upper factor U (on and above it).
    pivots: pivot indices, turned into a row permutation with
      ``lax_linalg.lu_pivots_to_permutation``.
    b: right-hand side whose leading axis has length m. Trailing axes are
      flattened into columns for the solve and restored on return.
    trans: static int. 0 solves ``A x = b``, 1 solves ``A^T x = b`` and
      2 solves ``A^H x = b``.

  Returns:
    ``x`` with the same shape as ``b``.

  Raises:
    ValueError: if ``lu`` is not square, ``b`` is a scalar, the leading axis
      of ``b`` does not match ``lu``, or ``trans`` is not 0, 1 or 2.
  """
  # Promote both operands to a common inexact dtype, as _cho_solve and
  # _solve_triangular do, so that e.g. an integer b or a b of different
  # precision reaches the triangular solves with the same dtype as lu.
  lu, b = np_linalg._promote_arg_dtypes(np.asarray(lu), np.asarray(b))
  lu_shape = np.shape(lu)
  b_shape = np.shape(b)
  if len(lu_shape) != 2 or lu_shape[0] != lu_shape[1]:
    raise ValueError("LU decomposition must be a square matrix, got shape {}"
                     .format(lu_shape))
  if len(b_shape) < 1:
    raise ValueError("b matrix must have rank >= 1, got shape {}"
                     .format(b_shape))

  if b_shape[0] != lu_shape[0]:
    raise ValueError("Dimension of LU decomposition matrix (shape {}) must "
                     "match leading axis of b array (shape {})"
                     .format(lu_shape, b_shape))
  m = lu_shape[0]
  permutation = lax_linalg.lu_pivots_to_permutation(np.array(pivots), m)
  # Treat every trailing axis of b as an extra column of the RHS matrix.
  x = np.reshape(b, (m, -1))
  if trans == 0:
    # Permute the rows, then forward-substitute with L and back-substitute
    # with U.
    x = x[permutation, :]
    x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=True,
                                    unit_diagonal=True)
    x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
  elif trans == 1 or trans == 2:
    # Transposed (trans == 1) or conjugate-transposed (trans == 2) system:
    # solve with U^T then L^T, then undo the row permutation.
    conj = trans == 2
    x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False,
                                    transpose_a=True, conjugate_a=conj)
    x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=True,
                                    unit_diagonal=True, transpose_a=True,
                                    conjugate_a=conj)
    x = x[np.argsort(permutation), :]
  else:
    raise ValueError("'trans' value must be 0, 1, or 2, got {}".format(trans))
  return lax.reshape(x, b_shape)
|
|
|
|
@_wraps(scipy.linalg.lu_solve)
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
  # lu_and_piv is the (lu, pivots) pair produced by lu_factor. The SciPy-only
  # keyword arguments are ignored.
  del overwrite_b, check_finite
  factors, piv = lu_and_piv
  return _lu_solve(factors, piv, b, trans)
|
|
|
|
|
|
@partial(jit, static_argnums=(1,))
def _lu(a, permute_l):
  """Unpacks the LU factorization of a 2-D matrix ``a`` into explicit factors.

  Returns ``(p, l, u)``, or ``(p @ l, u)`` when ``permute_l`` is true. Here
  ``l`` is (m, k) unit lower triangular and ``u`` is (k, n) upper triangular,
  with ``k = min(m, n)``. ``p`` is the real part of a 0/1 permutation matrix
  built in the dtype of the promoted input.
  """
  a = np_linalg._promote_arg_dtypes(np.asarray(a))
  packed, pivots = lax_linalg.lu(a)
  dtype = lax.dtype(a)
  m, n = np.shape(a)
  k = min(m, n)
  perm = lax_linalg.lu_pivots_to_permutation(pivots, m)
  # Entry [i, j] of p is 1 exactly when perm[j] == i.
  p = np.real(np.array(perm == np.arange(m)[:, None], dtype=dtype))
  # The packed factor stores L's strict lower part; its unit diagonal is
  # implicit and added back here.
  l = np.tril(packed, -1)[:, :k] + np.eye(m, k, dtype=dtype)
  u = np.triu(packed)[:k, :]
  return (np.matmul(p, l), u) if permute_l else (p, l, u)
|
|
|
|
@_wraps(scipy.linalg.lu, update_doc=False)
def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
  # Thin wrapper over the jitted _lu. permute_l is passed positionally because
  # _lu marks it static. The SciPy-only keyword arguments are ignored.
  del overwrite_a, check_finite
  return _lu(a, permute_l)
|
|
|
|
@partial(jit, static_argnums=(1, 2))
def _qr(a, mode, pivoting):
  """QR decomposition of ``a`` for the SciPy ``mode`` values full/r/economic.

  ``"full"`` and ``"r"`` request full matrices from ``lax_linalg.qr``.
  ``"economic"`` requests the reduced form. ``"r"`` returns only ``r``;
  the other modes return ``(q, r)``.

  NOTE(review): scipy.linalg.qr appears to return a 1-tuple ``(R,)`` for
  mode='r' (without pivoting), while this returns ``r`` bare. Confirm before
  relying on SciPy parity.

  Raises:
    NotImplementedError: if ``pivoting`` is requested.
    ValueError: for any other ``mode``.
  """
  if pivoting:
    raise NotImplementedError(
        "The pivoting=True case of qr is not implemented.")
  if mode not in ("full", "r", "economic"):
    raise ValueError("Unsupported QR decomposition mode '{}'".format(mode))
  promoted = np_linalg._promote_arg_dtypes(np.asarray(a))
  q, r = lax_linalg.qr(promoted, mode != "economic")
  return r if mode == "r" else (q, r)
|
|
|
|
@_wraps(scipy.linalg.qr)
def qr(a, overwrite_a=False, lwork=None, mode="full", pivoting=False,
       check_finite=True):
  # Workspace size and in-place / finiteness options are LAPACK details with
  # no JAX equivalent. mode and pivoting go positionally to the static args.
  del overwrite_a, lwork, check_finite
  return _qr(a, mode, pivoting)
|
|
|
|
@partial(jit, static_argnums=(2, 3))
def _solve(a, b, sym_pos, lower):
  """Solves ``a x = b``.

  A Cholesky-based solve is used when ``sym_pos`` is set, with the factor
  orientation chosen by ``lower``. Otherwise jax.numpy.linalg.solve is used.
  """
  if sym_pos:
    a, b = np_linalg._promote_arg_dtypes(np.asarray(a), np.asarray(b))
    return cho_solve(cho_factor(a, lower=lower), b)
  return np_linalg.solve(a, b)
|
|
|
|
@_wraps(scipy.linalg.solve)
def solve(a, b, sym_pos=False, lower=False, overwrite_a=False, overwrite_b=False,
          debug=False, check_finite=True):
  # sym_pos selects a Cholesky solve, and lower picks the factor orientation
  # for it. The remaining SciPy options are ignored.
  del overwrite_a, overwrite_b, debug, check_finite
  return _solve(a, b, sym_pos, lower)
|
|
|
|
@partial(jit, static_argnums=(2, 3, 4))
def _solve_triangular(a, b, trans, lower, unit_diagonal):
  """Solves ``op(a) x = b`` for triangular ``a``.

  ``trans`` selects ``op``: 0/"N" is ``a``, 1/"T" is ``a^T`` and 2/"C" is
  ``a^H``. A ``b`` with one dimension fewer than ``a`` is treated as a vector
  and the result has the same shape as ``b``.

  Raises:
    ValueError: for any other ``trans`` value.
  """
  if trans in (0, "N"):
    transpose_a = conjugate_a = False
  elif trans in (1, "T"):
    transpose_a, conjugate_a = True, False
  elif trans in (2, "C"):
    transpose_a = conjugate_a = True
  else:
    raise ValueError("Invalid 'trans' value {}".format(trans))

  a, b = np_linalg._promote_arg_dtypes(np.asarray(a), np.asarray(b))

  # lax_linalg.triangular_solve only supports matrix 'b's at the moment, so a
  # vector gets a trailing unit axis that is stripped from the result.
  b_is_vector = np.ndim(a) == np.ndim(b) + 1
  rhs = b[..., None] if b_is_vector else b
  solution = lax_linalg.triangular_solve(a, rhs, left_side=True, lower=lower,
                                         transpose_a=transpose_a,
                                         conjugate_a=conjugate_a,
                                         unit_diagonal=unit_diagonal)
  return solution[..., 0] if b_is_vector else solution
|
|
|
|
@_wraps(scipy.linalg.solve_triangular)
def solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False,
                     overwrite_b=False, debug=None, check_finite=True):
  # trans, lower and unit_diagonal go positionally, since _solve_triangular
  # marks them static. The remaining SciPy options are ignored.
  del overwrite_b, debug, check_finite
  return _solve_triangular(a, b, trans, lower, unit_diagonal)
|
|
|
|
|
|
|
|
@_wraps(scipy.linalg.tril)
def tril(m, k=0):
  # Lower triangle of m (on and below diagonal k), via jax.numpy's tril.
  return np.tril(m, k=k)
|
|
|
|
|
|
@_wraps(scipy.linalg.triu)
def triu(m, k=0):
  # Upper triangle of m (on and above diagonal k), via jax.numpy's triu.
  return np.triu(m, k=k)
|
|
|
|
@_wraps(scipy.linalg.expm, lax_description=textwrap.dedent("""\
    In addition to the original SciPy argument(s) listed below,
    also supports the optional boolean argument ``upper_triangular``
    to specify whether the ``A`` matrix is upper triangular.
    """))
def expm(A, *, upper_triangular=False):
  # Matrix exponential by scaling and squaring with a Padé approximant.
  # upper_triangular=True makes the final P/Q solve use solve_triangular
  # instead of a general solve.
  return _expm(A, upper_triangular)
|
|
|
|
def _expm(A, upper_triangular=False):
  """Computes exp(A) by scaling and squaring.

  Builds the Padé numerator/denominator with _calc_P_Q, solves ``Q R = P``,
  and squares ``R`` ``n_squarings`` times to undo the scaling of ``A``.
  """
  numer, denom, n_squarings = _calc_P_Q(A)
  scaled_exp = _solve_P_Q(numer, denom, upper_triangular)
  return _squaring(scaled_exp, n_squarings)
|
|
|
|
@jit
def _calc_P_Q(A):
  """Builds the Padé numerator ``P`` and denominator ``Q`` used by ``expm``.

  The Padé degree is chosen from the 1-norm of ``A``, using thresholds that
  depend on the precision of ``A``. When the norm exceeds every threshold,
  the highest degree is evaluated on ``A / 2**n_squarings``. The caller then
  squares the solution ``n_squarings`` times to undo the scaling.

  Args:
    A: square 2-D matrix. Non-inexact (e.g. integer) inputs are promoted to an
      inexact dtype, like the other routines in this module.

  Returns:
    ``(P, Q, n_squarings)`` with ``P = U + V`` (numerator), ``Q = V - U``
    (denominator) and ``n_squarings`` a non-negative floating-point scalar.

  Raises:
    ValueError: if ``A`` is not a square 2-D matrix.
    TypeError: if the (promoted) dtype of ``A`` is not float32, float64,
      complex64 or complex128.
  """
  A = np_linalg._promote_arg_dtypes(np.asarray(A))
  if A.ndim != 2 or A.shape[0] != A.shape[1]:
    raise ValueError('expected A to be a square matrix')
  A_L1 = np_linalg.norm(A, 1)
  if A.dtype == 'float64' or A.dtype == 'complex128':
    U3, V3 = _pade3(A)
    U5, V5 = _pade5(A)
    U7, V7 = _pade7(A)
    U9, V9 = _pade9(A)
    maxnorm = 5.371920351148152
    # np.floor keeps log2(0) == -inf as -inf, which the maximum clamps to 0.
    # floor_divide(-inf, 1) would instead evaluate to NaN.
    n_squarings = np.maximum(0, np.floor(np.log2(A_L1 / maxnorm)))
    A = A / 2**n_squarings
    U13, V13 = _pade13(A)
    conds = np.array([1.495585217958292e-002, 2.539398330063230e-001,
                      9.504178996162932e-001, 2.097847961257068e+000])
    # Use the lowest degree whose threshold exceeds ||A||_1. U3..U9 were
    # computed from the unscaled A, which is consistent: every threshold is
    # below maxnorm, so n_squarings == 0 whenever one of them is selected.
    U = np.select((A_L1 < conds), (U3, U5, U7, U9), U13)
    V = np.select((A_L1 < conds), (V3, V5, V7, V9), V13)
  elif A.dtype == 'float32' or A.dtype == 'complex64':
    U3, V3 = _pade3(A)
    U5, V5 = _pade5(A)
    maxnorm = 3.925724783138660
    n_squarings = np.maximum(0, np.floor(np.log2(A_L1 / maxnorm)))
    A = A / 2**n_squarings
    U7, V7 = _pade7(A)
    conds = np.array([4.258730016922831e-001, 1.880152677804762e+000])
    # Same selection rule as above. Both thresholds are below maxnorm, so
    # U3/U5 are only chosen when no scaling took place.
    U = np.select((A_L1 < conds), (U3, U5), U7)
    V = np.select((A_L1 < conds), (V3, V5), V7)
  else:
    raise TypeError("A.dtype={} is not supported.".format(A.dtype))
  P = U + V  # p_m(A) : numerator
  Q = -U + V  # q_m(A) : denominator
  return P, Q, n_squarings
|
|
|
|
def _solve_P_Q(P, Q, upper_triangular=False):
  """Returns ``R`` solving ``Q R = P``.

  When ``upper_triangular`` is set, ``Q`` is solved with solve_triangular
  (upper triangle, no transpose). Otherwise a general solve is used.
  """
  solver = solve_triangular if upper_triangular else np_linalg.solve
  return solver(Q, P)
|
|
|
|
@jit
def _squaring(R, n_squarings):
  """Squares ``R`` ``n_squarings`` times, undoing expm's scaling step.

  Args:
    R: square matrix.
    n_squarings: scalar array with the number of squarings. Its dtype is also
      used for the loop counter.

  Returns:
    ``R`` raised to the power ``2**n_squarings`` (for integral counts).
  """
  def square(_, mat):
    return np.dot(mat, mat)

  start = np.zeros((), dtype=n_squarings.dtype)
  return lax.fori_loop(start, n_squarings, square, R)
|
|
|
|
def _pade3(A):
  """Returns the degree-3 Padé terms ``(U, V)`` of exp(A).

  ``U = A (b3 A^2 + b1 I)`` collects the odd powers and
  ``V = b2 A^2 + b0 I`` the even powers. _calc_P_Q combines them into
  ``U + V`` and ``V - U``.
  """
  coef = (120., 60., 12., 1.)
  eye = np.eye(*A.shape, dtype=A.dtype)
  A_sq = np.dot(A, A)
  odd_poly = coef[3] * A_sq + coef[1] * eye
  even_poly = coef[2] * A_sq + coef[0] * eye
  return np.dot(A, odd_poly), even_poly
|
|
|
|
def _pade5(A):
  """Returns the degree-5 Padé terms ``(U, V)`` of exp(A).

  ``U = A (b5 A^4 + b3 A^2 + b1 I)`` holds the odd powers and
  ``V = b4 A^4 + b2 A^2 + b0 I`` the even powers.
  """
  coef = (30240., 15120., 3360., 420., 30., 1.)
  eye = np.eye(*A.shape, dtype=A.dtype)
  A_sq = np.dot(A, A)
  A_4 = np.dot(A_sq, A_sq)
  odd_poly = coef[5] * A_4 + coef[3] * A_sq + coef[1] * eye
  even_poly = coef[4] * A_4 + coef[2] * A_sq + coef[0] * eye
  return np.dot(A, odd_poly), even_poly
|
|
|
|
def _pade7(A):
  """Returns the degree-7 Padé terms ``(U, V)`` of exp(A).

  ``U = A (b7 A^6 + b5 A^4 + b3 A^2 + b1 I)`` holds the odd powers and
  ``V = b6 A^6 + b4 A^4 + b2 A^2 + b0 I`` the even powers.
  """
  coef = (17297280., 8648640., 1995840., 277200., 25200., 1512., 56., 1.)
  eye = np.eye(*A.shape, dtype=A.dtype)
  A_sq = np.dot(A, A)
  A_4 = np.dot(A_sq, A_sq)
  A_6 = np.dot(A_4, A_sq)
  odd_poly = coef[7] * A_6 + coef[5] * A_4 + coef[3] * A_sq + coef[1] * eye
  even_poly = coef[6] * A_6 + coef[4] * A_4 + coef[2] * A_sq + coef[0] * eye
  return np.dot(A, odd_poly), even_poly
|
|
|
|
def _pade9(A):
  """Returns the degree-9 Padé terms ``(U, V)`` of exp(A).

  ``U = A (b9 A^8 + b7 A^6 + b5 A^4 + b3 A^2 + b1 I)`` holds the odd powers
  and ``V = b8 A^8 + b6 A^6 + b4 A^4 + b2 A^2 + b0 I`` the even powers.
  """
  coef = (17643225600., 8821612800., 2075673600., 302702400., 30270240.,
          2162160., 110880., 3960., 90., 1.)
  eye = np.eye(*A.shape, dtype=A.dtype)
  A_sq = np.dot(A, A)
  A_4 = np.dot(A_sq, A_sq)
  A_6 = np.dot(A_4, A_sq)
  A_8 = np.dot(A_6, A_sq)
  odd_poly = (coef[9] * A_8 + coef[7] * A_6 + coef[5] * A_4
              + coef[3] * A_sq + coef[1] * eye)
  even_poly = (coef[8] * A_8 + coef[6] * A_6 + coef[4] * A_4
               + coef[2] * A_sq + coef[0] * eye)
  return np.dot(A, odd_poly), even_poly
|
|
|
|
def _pade13(A):
  """Returns the degree-13 Padé terms ``(U, V)`` of exp(A).

  Only A^2, A^4 and A^6 are formed. The highest powers are factored through
  A^6:
  ``U = A (A^6 (b13 A^6 + b11 A^4 + b9 A^2) + b7 A^6 + b5 A^4 + b3 A^2 + b1 I)``
  ``V = A^6 (b12 A^6 + b10 A^4 + b8 A^2) + b6 A^6 + b4 A^4 + b2 A^2 + b0 I``
  """
  coef = (64764752532480000., 32382376266240000., 7771770303897600.,
          1187353796428800., 129060195264000., 10559470521600., 670442572800.,
          33522128640., 1323241920., 40840800., 960960., 16380., 182., 1.)
  eye = np.eye(*A.shape, dtype=A.dtype)
  A_sq = np.dot(A, A)
  A_4 = np.dot(A_sq, A_sq)
  A_6 = np.dot(A_4, A_sq)
  high_odd = coef[13] * A_6 + coef[11] * A_4 + coef[9] * A_sq
  high_even = coef[12] * A_6 + coef[10] * A_4 + coef[8] * A_sq
  odd_poly = (np.dot(A_6, high_odd) + coef[7] * A_6 + coef[5] * A_4
              + coef[3] * A_sq + coef[1] * eye)
  even_poly = (np.dot(A_6, high_even) + coef[6] * A_6 + coef[4] * A_4
               + coef[2] * A_sq + coef[0] * eye)
  return np.dot(A, odd_poly), even_poly
|
|
|
|
|
|
@_wraps(scipy.linalg.block_diag)
@jit
def block_diag(*arrs):
  # Lays the inputs (each at most 2-D) along the diagonal of one matrix, with
  # zeros elsewhere. With no inputs, a (1, 0) zeros array is used.
  if not arrs:
    arrs = [np.zeros((1, 0))]
  arrs = np._promote_dtypes(*arrs)
  bad_shapes = [idx for idx, arr in enumerate(arrs) if np.ndim(arr) > 2]
  if bad_shapes:
    first_bad = bad_shapes[0]
    raise ValueError("Arguments to jax.scipy.linalg.block_diag must have at "
                     "most 2 dimensions, got {} at argument {}."
                     .format(arrs[first_bad], first_bad))
  blocks = [np.atleast_2d(arr) for arr in arrs]
  result = blocks[0]
  zero = lax.dtype(result).type(0)
  for blk in blocks[1:]:
    n_cols = blk.shape[1]
    # Shift the new block right past every column accumulated so far, widen
    # the accumulator by the new block's columns, then stack vertically.
    shifted = lax.pad(blk, zero, ((0, 0, 0), (result.shape[-1], 0, 0)))
    widened = lax.pad(result, zero, ((0, 0, 0), (0, n_cols, 0)))
    result = lax.concatenate([widened, shifted], dimension=0)
  return result
|