# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the LAPAX linear algebra module."""

from functools import partial
import itertools

import numpy as np
import scipy
import scipy.linalg
import scipy as osp

from absl.testing import absltest

import jax
from jax import jit, grad, jvp, vmap
from jax import lax
from jax import numpy as jnp
from jax import scipy as jsp
from jax._src import config
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.numpy.util import promote_dtypes_inexact

config.parse_flags_with_absl()

scipy_version = tuple(map(int, scipy.version.version.split('.')[:3]))

T = lambda x: np.swapaxes(x, -1, -2)
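# (T is the batched matrix transpose: it swaps the trailing two axes, so
# np.conj(T(a)) used throughout these tests is the conjugate transpose of each
# matrix in a batch.)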


float_types = jtu.dtypes.floating
complex_types = jtu.dtypes.complex
int_types = jtu.dtypes.all_integer

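# Helper: decides whether the attached CUDA runtime is at least `cuda_version`
# (callers pass an integer such as 12000) by parsing the trailing token of the
# backend's platform_version string; ROCm and unknown platforms report False.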
def _is_required_cuda_version_satisfied(cuda_version):
  version = xla_bridge.get_backend().platform_version
  if version == "<unknown>" or version.split()[0] == "rocm":
    return False
  else:
    return int(version.split()[-1]) >= cuda_version


class NumpyLinalgTest(jtu.JaxTestCase):

  @jtu.sample_product(
    shape=[(1, 1), (4, 4), (2, 5, 5), (200, 200), (1000, 0, 0)],
    dtype=float_types + complex_types,
  )
  def testCholesky(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    def args_maker():
      factor_shape = shape[:-1] + (2 * shape[-1],)
      a = rng(factor_shape, dtype)
      return [np.matmul(a, jnp.conj(T(a)))]

    self._CheckAgainstNumpy(np.linalg.cholesky, jnp.linalg.cholesky, args_maker,
                            tol=1e-3)
    self._CompileAndCheck(jnp.linalg.cholesky, args_maker)

    if jnp.finfo(dtype).bits == 64:
      jtu.check_grads(jnp.linalg.cholesky, args_maker(), order=2)

  def testCholeskyGradPrecision(self):
    rng = jtu.rand_default(self.rng())
    a = rng((3, 3), np.float32)
    a = np.dot(a, a.T)
    jtu.assert_dot_precision(
        lax.Precision.HIGHEST, partial(jvp, jnp.linalg.cholesky), (a,), (a,))

  @jtu.sample_product(
    n=[0, 2, 3, 4, 5, 25],  # TODO(mattjj): complex64 unstable on large sizes?
    dtype=float_types + complex_types,
  )
  def testDet(self, n, dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng((n, n), dtype)]

    self._CheckAgainstNumpy(np.linalg.det, jnp.linalg.det, args_maker, tol=1e-3)
    self._CompileAndCheck(jnp.linalg.det, args_maker,
                          rtol={np.float64: 1e-13, np.complex128: 1e-13})

  def testDetOfSingularMatrix(self):
    x = jnp.array([[-1., 3./2], [2./3, -1.]], dtype=np.float32)
    self.assertAllClose(np.float32(0), jsp.linalg.det(x))

  @jtu.sample_product(
    shape=[(1, 1), (3, 3), (2, 4, 4)],
    dtype=float_types,
  )
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  @jtu.skip_on_devices("tpu")
  def testDetGrad(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    a = rng(shape, dtype)
    jtu.check_grads(jnp.linalg.det, (a,), 2, atol=1e-1, rtol=1e-1)
    # make sure there are no NaNs when a matrix is zero
    if len(shape) == 2:
      pass
      jtu.check_grads(
        jnp.linalg.det, (jnp.zeros_like(a),), 1, atol=1e-1, rtol=1e-1)
    else:
      a[0] = 0
      jtu.check_grads(jnp.linalg.det, (a,), 1, atol=1e-1, rtol=1e-1)

  def testDetGradIssue6121(self):
    f = lambda x: jnp.linalg.det(x).sum()
    x = jnp.ones((16, 1, 1))
    jax.grad(f)(x)
    jtu.check_grads(f, (x,), 2, atol=1e-1, rtol=1e-1)

  def testDetGradOfSingularMatrixCorank1(self):
    # Rank 2 matrix with nonzero gradient
    a = jnp.array([[ 50, -30,  45],
                   [-30,  90, -81],
                   [ 45, -81,  81]], dtype=jnp.float32)
    jtu.check_grads(jnp.linalg.det, (a,), 1, atol=1e-1, rtol=1e-1)

  # TODO(phawkins): Test sometimes produces NaNs on TPU.
  @jtu.skip_on_devices("tpu")
  def testDetGradOfSingularMatrixCorank2(self):
    # Rank 1 matrix with zero gradient
    b = jnp.array([[ 36, -42,  18],
                   [-42,  49, -21],
                   [ 18, -21,   9]], dtype=jnp.float32)
    jtu.check_grads(jnp.linalg.det, (b,), 1, atol=1e-1, rtol=1e-1, eps=1e-1)

  @jtu.sample_product(
    m=[1, 5, 7, 23],
    nq=zip([2, 4, 6, 36], [(1, 2), (2, 2), (1, 2, 3), (3, 3, 1, 4)]),
    dtype=float_types,
  )
  def testTensorsolve(self, m, nq, dtype):
    rng = jtu.rand_default(self.rng())

    # According to numpy docs the shapes are as follows:
    # Coefficient tensor (a), of shape b.shape + Q.
    # And prod(Q) == prod(b.shape)
    # Therefore, n = prod(q)
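    # For example, with m=5 and nq=(6, (1, 2, 3)): b.shape == (6, 5),
    # Q == (1, 2, 3, 5), and prod(Q) == 30 == prod(b.shape).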
    n, q = nq
    b_shape = (n, m)
    # To accomplish prod(Q) == prod(b.shape) we append the m extra dim
    # to Q shape
    Q = q + (m,)
    args_maker = lambda: [
        rng(b_shape + Q, dtype),  # = a
        rng(b_shape, dtype)]      # = b
    a, b = args_maker()
    result = jnp.linalg.tensorsolve(*args_maker())
    self.assertEqual(result.shape, Q)

    self._CheckAgainstNumpy(np.linalg.tensorsolve,
                            jnp.linalg.tensorsolve, args_maker,
                            tol={np.float32: 1e-2, np.float64: 1e-3})
    self._CompileAndCheck(jnp.linalg.tensorsolve,
                          args_maker,
                          rtol={np.float64: 1e-13})

  @jtu.sample_product(
    [dict(dtype=dtype, method=method)
     for dtype in float_types + complex_types
     for method in (["lu"] if jnp.issubdtype(dtype, jnp.complexfloating)
                    else ["lu", "qr"])
    ],
    shape=[(0, 0), (1, 1), (3, 3), (4, 4), (10, 10), (200, 200), (2, 2, 2),
           (2, 3, 3), (3, 2, 2)],
  )
  def testSlogdet(self, shape, dtype, method):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    slogdet = partial(jnp.linalg.slogdet, method=method)
    self._CheckAgainstNumpy(np.linalg.slogdet, slogdet, args_maker,
                            tol=1e-3)
    self._CompileAndCheck(slogdet, args_maker)

  @jtu.sample_product(
    shape=[(1, 1), (4, 4), (5, 5), (2, 7, 7)],
    dtype=float_types + complex_types,
  )
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def testSlogdetGrad(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    a = rng(shape, dtype)
    jtu.check_grads(jnp.linalg.slogdet, (a,), 2, atol=1e-1, rtol=2e-1)

  def testIssue1213(self):
    for n in range(5):
      mat = jnp.array([np.diag(np.ones([5], dtype=np.float32))*(-.01)] * 2)
      args_maker = lambda: [mat]
      self._CheckAgainstNumpy(np.linalg.slogdet, jnp.linalg.slogdet, args_maker,
                              tol=1e-3)

  @jtu.sample_product(
    shape=[(0, 0), (4, 4), (5, 5), (50, 50), (2, 6, 6)],
    dtype=float_types + complex_types,
    compute_left_eigenvectors=[False, True],
    compute_right_eigenvectors=[False, True],
  )
  # TODO(phawkins): enable when there is an eigendecomposition implementation
  # for GPU/TPU.
  @jtu.run_on_devices("cpu")
  def testEig(self, shape, dtype, compute_left_eigenvectors,
              compute_right_eigenvectors):
    rng = jtu.rand_default(self.rng())
    n = shape[-1]
    args_maker = lambda: [rng(shape, dtype)]

    # Norm, adjusted for dimension and type.
    def norm(x):
      norm = np.linalg.norm(x, axis=(-2, -1))
      return norm / ((n + 1) * jnp.finfo(dtype).eps)
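    # (The division above expresses errors in units of (n + 1) * eps, so the
    # thresholds used below, e.g. `< 100`, scale with dimension and dtype.)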

    def check_right_eigenvectors(a, w, vr):
      self.assertTrue(
        np.all(norm(np.matmul(a, vr) - w[..., None, :] * vr) < 100))

    def check_left_eigenvectors(a, w, vl):
      rank = len(a.shape)
      aH = jnp.conj(a.transpose(list(range(rank - 2)) + [rank - 1, rank - 2]))
      wC = jnp.conj(w)
      check_right_eigenvectors(aH, wC, vl)

    a, = args_maker()
    results = lax.linalg.eig(
        a, compute_left_eigenvectors=compute_left_eigenvectors,
        compute_right_eigenvectors=compute_right_eigenvectors)
    w = results[0]

    if compute_left_eigenvectors:
      check_left_eigenvectors(a, w, results[1])
    if compute_right_eigenvectors:
      check_right_eigenvectors(a, w, results[1 + compute_left_eigenvectors])

    self._CompileAndCheck(partial(jnp.linalg.eig), args_maker,
                          rtol=1e-3)

  @jtu.sample_product(
    shape=[(4, 4), (5, 5), (8, 8), (7, 6, 6)],
    dtype=float_types + complex_types,
  )
  # TODO(phawkins): enable when there is an eigendecomposition implementation
  # for GPU/TPU.
  @jtu.run_on_devices("cpu")
  def testEigvalsGrad(self, shape, dtype):
    # This test sometimes fails for large matrices. I (@j-towns) suspect, but
    # haven't checked, that might be because of perturbations causing the
    # ordering of eigenvalues to change, which will trip up check_grads. So we
    # just test on small-ish matrices.
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    a, = args_maker()
    tol = 1e-4 if dtype in (np.float64, np.complex128) else 1e-1
    jtu.check_grads(lambda x: jnp.linalg.eigvals(x), (a,), order=1,
                    modes=['fwd', 'rev'], rtol=tol, atol=tol)

  @jtu.sample_product(
    shape=[(4, 4), (5, 5), (50, 50)],
    dtype=float_types + complex_types,
  )
  # TODO: enable when there is an eigendecomposition implementation
  # for GPU/TPU.
  @jtu.run_on_devices("cpu")
  def testEigvals(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    a, = args_maker()
    w1, _ = jnp.linalg.eig(a)
    w2 = jnp.linalg.eigvals(a)
    self.assertAllClose(w1, w2, rtol={np.complex64: 1e-5, np.complex128: 1e-14})

  @jtu.run_on_devices("cpu")
  def testEigvalsInf(self):
    # https://github.com/google/jax/issues/2661
    x = jnp.array([[jnp.inf]])
    self.assertTrue(jnp.all(jnp.isnan(jnp.linalg.eigvals(x))))

  @jtu.sample_product(
    shape=[(1, 1), (4, 4), (5, 5)],
    dtype=float_types + complex_types,
  )
  @jtu.run_on_devices("cpu")
  def testEigBatching(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    shape = (10,) + shape
    args = rng(shape, dtype)
    ws, vs = vmap(jnp.linalg.eig)(args)
    self.assertTrue(np.all(np.linalg.norm(
        np.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3))

  @jtu.sample_product(
    n=[0, 4, 5, 50, 512],
    dtype=float_types + complex_types,
    lower=[True, False],
  )
  def testEigh(self, n, dtype, lower):
    rng = jtu.rand_default(self.rng())
    tol = 0.5 * np.maximum(n, 80) * np.finfo(dtype).eps
    args_maker = lambda: [rng((n, n), dtype)]

    uplo = "L" if lower else "U"

    a, = args_maker()
    a = (a + np.conj(a.T)) / 2
    w, v = jnp.linalg.eigh(np.tril(a) if lower else np.triu(a),
                           UPLO=uplo, symmetrize_input=False)
    w = w.astype(v.dtype)
    self.assertLessEqual(
        np.linalg.norm(np.eye(n) - np.matmul(np.conj(T(v)), v)), 4 * tol
    )
    with jax.numpy_rank_promotion('allow'):
      self.assertLessEqual(np.linalg.norm(np.matmul(a, v) - w * v),
                           tol * np.linalg.norm(a))

    self._CompileAndCheck(
        partial(jnp.linalg.eigh, UPLO=uplo), args_maker, rtol=tol
    )

  @jtu.sample_product(
    start=[0, 1, 63, 64, 65, 255],
    end=[1, 63, 64, 65, 256],
  )
  @jtu.run_on_devices("tpu")  # TODO(rmlarsen): enable on other devices
  def testEighSubsetByIndex(self, start, end):
    if start >= end:
      return
    dtype = np.float32
    n = 256
    rng = jtu.rand_default(self.rng())
    tol = np.maximum(n, 80) * np.finfo(dtype).eps
    args_maker = lambda: [rng((n, n), dtype)]
    subset_by_index = (start, end)
    k = end - start
    (a,) = args_maker()
    a = (a + np.conj(a.T)) / 2

    v, w = lax.linalg.eigh(
        a, symmetrize_input=False, subset_by_index=subset_by_index
    )
    w = w.astype(v.dtype)

    self.assertEqual(v.shape, (n, k))
    self.assertEqual(w.shape, (k,))
    self.assertLessEqual(
        np.linalg.norm(np.eye(k) - np.matmul(np.conj(T(v)), v)), 3 * tol
    )
    with jax.numpy_rank_promotion("allow"):
      self.assertLessEqual(
          np.linalg.norm(np.matmul(a, v) - w * v), tol * np.linalg.norm(a)
      )

    self._CompileAndCheck(partial(jnp.linalg.eigh), args_maker, rtol=tol)

    # Compare eigenvalues against Numpy. We do not compare eigenvectors because
    # they are not uniquely defined, but the two checks above guarantee that
    # they satisfy the conditions for being eigenvectors.
    w_np = np.linalg.eigvalsh(a)[subset_by_index[0] : subset_by_index[1]]
    self.assertAllClose(w_np, w, atol=tol, rtol=tol)

  def testEighZeroDiagonal(self):
    a = np.array([[0., -1., -1., 1.],
                  [-1., 0., 1., -1.],
                  [-1., 1., 0., -1.],
                  [1., -1., -1., 0.]], dtype=np.float32)
    w, v = jnp.linalg.eigh(a)
    w = w.astype(v.dtype)
    eps = jnp.finfo(a.dtype).eps
    with jax.numpy_rank_promotion('allow'):
      self.assertLessEqual(
          np.linalg.norm(np.matmul(a, v) - w * v), 2 * eps * np.linalg.norm(a)
      )

  def testEighTinyNorm(self):
    rng = jtu.rand_default(self.rng())
    a = rng((300, 300), dtype=np.float32)
    eps = jnp.finfo(a.dtype).eps
    a = eps * (a + np.conj(a.T))
    w, v = jnp.linalg.eigh(a)
    w = w.astype(v.dtype)
    with jax.numpy_rank_promotion("allow"):
      self.assertLessEqual(
          np.linalg.norm(np.matmul(a, v) - w * v), 20 * eps * np.linalg.norm(a)
      )

  @jtu.sample_product(
    rank=[1, 3, 299],
  )
  def testEighRankDeficient(self, rank):
    rng = jtu.rand_default(self.rng())
    eps = jnp.finfo(np.float32).eps
    a = rng((300, rank), dtype=np.float32)
    a = a @ np.conj(a.T)
    w, v = jnp.linalg.eigh(a)
    w = w.astype(v.dtype)
    with jax.numpy_rank_promotion("allow"):
      self.assertLessEqual(
          np.linalg.norm(np.matmul(a, v) - w * v),
          81 * eps * np.linalg.norm(a),
      )

  @jtu.sample_product(
    n=[0, 4, 5, 50, 512],
    dtype=float_types + complex_types,
    lower=[True, False],
  )
  def testEighIdentity(self, n, dtype, lower):
    tol = np.finfo(dtype).eps
    uplo = "L" if lower else "U"

    a = jnp.eye(n, dtype=dtype)
    w, v = jnp.linalg.eigh(a, UPLO=uplo, symmetrize_input=False)
    w = w.astype(v.dtype)
    self.assertLessEqual(
        np.linalg.norm(np.eye(n) - np.matmul(np.conj(T(v)), v)), tol
    )
    with jax.numpy_rank_promotion('allow'):
      self.assertLessEqual(np.linalg.norm(np.matmul(a, v) - w * v),
                           tol * np.linalg.norm(a))

  @jtu.sample_product(
    shape=[(4, 4), (5, 5), (50, 50)],
    dtype=float_types + complex_types,
  )
  def testEigvalsh(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    n = shape[-1]
    def args_maker():
      a = rng((n, n), dtype)
      a = (a + np.conj(a.T)) / 2
      return [a]
    self._CheckAgainstNumpy(
        np.linalg.eigvalsh, jnp.linalg.eigvalsh, args_maker, tol=2e-5
    )

  @jtu.sample_product(
    shape=[(1, 1), (4, 4), (5, 5), (50, 50), (2, 10, 10)],
    dtype=float_types + complex_types,
    lower=[True, False],
  )
  def testEighGrad(self, shape, dtype, lower):
    rng = jtu.rand_default(self.rng())
    self.skipTest("Test fails with numeric errors.")
    uplo = "L" if lower else "U"
    a = rng(shape, dtype)
    a = (a + np.conj(T(a))) / 2
    ones = np.ones((a.shape[-1], a.shape[-1]), dtype=dtype)
    a *= np.tril(ones) if lower else np.triu(ones)
    # Gradient checks will fail without symmetrization as the eigh jvp rule
    # is only correct for tangents in the symmetric subspace, whereas the
    # checker checks against unconstrained (co)tangents.
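    # (With symmetrize_input=True we effectively differentiate
    # A -> eigh((A + A.conj().T) / 2), so only the symmetric part of a tangent
    # contributes, which matches what the finite-difference checker probes.)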
    if dtype not in complex_types:
      f = partial(jnp.linalg.eigh, UPLO=uplo, symmetrize_input=True)
    else:  # only check eigenvalue grads for complex matrices
      f = lambda a: partial(jnp.linalg.eigh, UPLO=uplo, symmetrize_input=True)(a)[0]
    jtu.check_grads(f, (a,), 2, rtol=1e-5)

  @jtu.sample_product(
    shape=[(1, 1), (4, 4), (5, 5), (50, 50)],
    dtype=complex_types,
    lower=[True, False],
    eps=[1e-5],
  )
  def testEighGradVectorComplex(self, shape, dtype, lower, eps):
    rng = jtu.rand_default(self.rng())
    # Special case to test for complex eigenvector grad correctness.
    # Exact eigenvector coordinate gradients are hard to test numerically for complex
    # eigensystem solvers given the extra degrees of per-eigenvector phase freedom.
    # Instead, we numerically verify the eigensystem properties on the perturbed
    # eigenvectors. You only ever want to optimize eigenvector directions, not coordinates!
    uplo = "L" if lower else "U"
    a = rng(shape, dtype)
    a = (a + np.conj(a.T)) / 2
    a = np.tril(a) if lower else np.triu(a)
    a_dot = eps * rng(shape, dtype)
    a_dot = (a_dot + np.conj(a_dot.T)) / 2
    a_dot = np.tril(a_dot) if lower else np.triu(a_dot)
    # evaluate eigenvector gradient and groundtruth eigensystem for perturbed input matrix
    f = partial(jnp.linalg.eigh, UPLO=uplo)
    (w, v), (dw, dv) = jvp(f, primals=(a,), tangents=(a_dot,))
    self.assertTrue(jnp.issubdtype(w.dtype, jnp.floating))
    self.assertTrue(jnp.issubdtype(dw.dtype, jnp.floating))
    new_a = a + a_dot
    new_w, new_v = f(new_a)
    new_a = (new_a + np.conj(new_a.T)) / 2
    new_w = new_w.astype(new_a.dtype)
    # Assert rtol eigenvalue delta between perturbed eigenvectors vs new true eigenvalues.
    RTOL = 1e-2
    with jax.numpy_rank_promotion('allow'):
      assert np.max(
        np.abs((np.diag(np.dot(np.conj((v+dv).T), np.dot(new_a,(v+dv)))) - new_w) / new_w)) < RTOL
      # Redundant to above, but also assert rtol for eigenvector property with new true eigenvalues.
      assert np.max(
        np.linalg.norm(np.abs(new_w*(v+dv) - np.dot(new_a, (v+dv))), axis=0) /
        np.linalg.norm(np.abs(new_w*(v+dv)), axis=0)
      ) < RTOL

  def testEighGradPrecision(self):
    rng = jtu.rand_default(self.rng())
    a = rng((3, 3), np.float32)
    jtu.assert_dot_precision(
        lax.Precision.HIGHEST, partial(jvp, jnp.linalg.eigh), (a,), (a,))

  @jtu.sample_product(
    shape=[(1, 1), (4, 4), (5, 5), (300, 300)],
    dtype=float_types + complex_types,
  )
  def testEighBatching(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    shape = (10,) + shape
    args = rng(shape, dtype)
    args = (args + np.conj(T(args))) / 2
    ws, vs = vmap(jsp.linalg.eigh)(args)
    ws = ws.astype(vs.dtype)
    norm = np.max(np.linalg.norm(np.matmul(args, vs) - ws[..., None, :] * vs))
    self.assertLess(norm, 1e-2)

  @jtu.sample_product(
    shape=[(1,), (4,), (5,)],
    dtype=(np.int32,),
  )
  def testLuPivotsToPermutation(self, shape, dtype):
    pivots_size = shape[-1]
    permutation_size = 2 * pivots_size

    pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, dtype=dtype)
    pivots = jnp.broadcast_to(pivots, shape)
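    # For pivots_size=3 this gives pivots == [5, 4, 3]: applying those row
    # swaps to [0, 1, 2, 3, 4, 5] exchanges 0<->5, 1<->4, 2<->3, producing the
    # fully reversed permutation that `expected` below is compared against.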
    actual = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size)
    expected = jnp.arange(permutation_size - 1, -1, -1, dtype=dtype)
    expected = jnp.broadcast_to(expected, actual.shape)
    self.assertArraysEqual(actual, expected)

  @jtu.sample_product(
    shape=[(1,), (4,), (5,)],
    dtype=(np.int32,),
  )
  def testLuPivotsToPermutationBatching(self, shape, dtype):
    shape = (10,) + shape
    pivots_size = shape[-1]
    permutation_size = 2 * pivots_size

    pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, dtype=dtype)
    pivots = jnp.broadcast_to(pivots, shape)
    batched_fn = vmap(
        lambda x: lax.linalg.lu_pivots_to_permutation(x, permutation_size))
    actual = batched_fn(pivots)
    expected = jnp.arange(permutation_size - 1, -1, -1, dtype=dtype)
    expected = jnp.broadcast_to(expected, actual.shape)
    self.assertArraysEqual(actual, expected)

  @jtu.sample_product(
    [dict(axis=axis, shape=shape, ord=ord)
     for axis, shape in [
       (None, (1,)), (None, (7,)), (None, (5, 8)),
       (0, (9,)), (0, (4, 5)), ((1,), (10, 7, 3)), ((-2,), (4, 8)),
       (-1, (6, 3)), ((0, 2), (3, 4, 5)), ((2, 0), (7, 8, 9)),
       (None, (7, 8, 11))]
     for ord in (
       [None] if axis is None and len(shape) > 2
       else [None, 0, 1, 2, 3, -1, -2, -3, jnp.inf, -jnp.inf]
       if (axis is None and len(shape) == 1) or
          isinstance(axis, int) or
          (isinstance(axis, tuple) and len(axis) == 1)
       else [None, 'fro', 1, 2, -1, -2, jnp.inf, -jnp.inf, 'nuc'])
    ],
    keepdims=[False, True],
    dtype=float_types + complex_types,
  )
  def testNorm(self, shape, dtype, ord, axis, keepdims):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    np_fn = partial(np.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
    jnp_fn = partial(jnp.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(jnp_fn, args_maker)

  def testStringInfNorm(self):
    err, msg = ValueError, r"Invalid order 'inf' for vector norm."
    with self.assertRaisesRegex(err, msg):
      jnp.linalg.norm(jnp.array([1.0, 2.0, 3.0]), ord="inf")

  @jtu.sample_product(
    [
      dict(m=m, n=n, full_matrices=full_matrices, hermitian=hermitian)
      for (m, n), full_matrices in (
        list(
          itertools.product(
            itertools.product([0, 2, 7, 29, 32, 53], repeat=2),
            [False, True],
          )
        )
        +
        # Test cases that ensure we are economical when computing the SVD
        # and its gradient. If we form a 400kx400k matrix explicitly we
        # will OOM.
        [((400000, 2), False), ((2, 400000), False)]
      )
      for hermitian in ([False, True] if m == n else [False])
    ],
    b=[(), (3,), (2, 3)],
    dtype=float_types + complex_types,
    compute_uv=[False, True],
  )
  @jax.default_matmul_precision("float32")
  def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, hermitian):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(b + (m, n), dtype)]

    def compute_max_backward_error(operand, reconstructed_operand):
      error_norm = np.linalg.norm(operand - reconstructed_operand,
                                  axis=(-2, -1))
      backward_error = (error_norm /
                        np.linalg.norm(operand, axis=(-2, -1)))
      max_backward_error = np.amax(backward_error)
      return max_backward_error
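    # (This is the relative backward error ||A - U diag(S) Vh||_F / ||A||_F,
    # computed per matrix in the batch and reduced with max.)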

    tol = 80 * jnp.finfo(dtype).eps
    reconstruction_tol = 2 * tol
    unitariness_tol = tol

    a, = args_maker()
    if hermitian:
      a = a + np.conj(T(a))
    out = jnp.linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv,
                         hermitian=hermitian)
    if compute_uv:
      # Check the reconstructed matrices
      out = list(out)
      out[1] = out[1].astype(out[0].dtype)  # for strict dtype promotion.
      if m and n:
        if full_matrices:
          k = min(m, n)
          if m < n:
            max_backward_error = compute_max_backward_error(
                a, np.matmul(out[1][..., None, :] * out[0], out[2][..., :k, :]))
            self.assertLess(max_backward_error, reconstruction_tol)
          else:
            max_backward_error = compute_max_backward_error(
                a, np.matmul(out[1][..., None, :] * out[0][..., :, :k], out[2]))
            self.assertLess(max_backward_error, reconstruction_tol)
        else:
          max_backward_error = compute_max_backward_error(
              a, np.matmul(out[1][..., None, :] * out[0], out[2]))
          self.assertLess(max_backward_error, reconstruction_tol)

        # Check the unitary properties of the singular vector matrices.
        unitary_mat = np.real(np.matmul(np.conj(T(out[0])), out[0]))
        eye_slice = np.eye(out[0].shape[-1], dtype=unitary_mat.dtype)
        self.assertAllClose(np.broadcast_to(eye_slice, b + eye_slice.shape),
                            unitary_mat, rtol=unitariness_tol,
                            atol=unitariness_tol)
        if m >= n:
          unitary_mat = np.real(np.matmul(np.conj(T(out[2])), out[2]))
          eye_slice = np.eye(out[2].shape[-1], dtype=unitary_mat.dtype)
          self.assertAllClose(np.broadcast_to(eye_slice, b + eye_slice.shape),
                              unitary_mat, rtol=unitariness_tol,
                              atol=unitariness_tol)
        else:
          unitary_mat = np.real(np.matmul(out[2], np.conj(T(out[2]))))
          eye_slice = np.eye(out[2].shape[-2], dtype=unitary_mat.dtype)
          self.assertAllClose(np.broadcast_to(eye_slice, b + eye_slice.shape),
                              unitary_mat, rtol=unitariness_tol,
                              atol=unitariness_tol)
    else:
      self.assertTrue(np.allclose(np.linalg.svd(a, compute_uv=False),
                                  np.asarray(out), atol=1e-4, rtol=1e-4))

    self._CompileAndCheck(partial(jnp.linalg.svd, full_matrices=full_matrices,
                                  compute_uv=compute_uv),
                          args_maker)

    if not compute_uv and a.size < 100000:
      svd = partial(jnp.linalg.svd, full_matrices=full_matrices,
                    compute_uv=compute_uv)
      # TODO(phawkins): these tolerances seem very loose.
      if dtype == np.complex128:
        jtu.check_jvp(svd, partial(jvp, svd), (a,), rtol=1e-4, atol=1e-4,
                      eps=1e-8)
      else:
        jtu.check_jvp(svd, partial(jvp, svd), (a,), rtol=5e-2, atol=2e-1)

    if compute_uv and (not full_matrices):
      b, = args_maker()
      def f(x):
        u, s, v = jnp.linalg.svd(
          a + x * b,
          full_matrices=full_matrices,
          compute_uv=compute_uv)
        vdiag = jnp.vectorize(jnp.diag, signature='(k)->(k,k)')
        return jnp.matmul(jnp.matmul(u, vdiag(s).astype(u.dtype)), v).real
      _, t_out = jvp(f, (1.,), (1.,))
      if dtype == np.complex128:
        atol = 2e-13
      else:
        atol = 5e-4
      self.assertArraysAllClose(t_out, b.real, atol=atol)

  def testJspSVDBasic(self):
    # since jax.scipy.linalg.svd is almost the same as jax.numpy.linalg.svd
    # we do not check its functionality here
    jsp.linalg.svd(np.ones((2, 2), dtype=np.float32))

  @jtu.sample_product(
    shape=[(0, 2), (2, 0), (3, 4), (3, 3), (4, 3)],
    dtype=[np.float32],
    mode=["reduced", "r", "full", "complete", "raw"],
  )
  def testNumpyQrModes(self, shape, dtype, mode):
    rng = jtu.rand_default(self.rng())
    jnp_func = partial(jax.numpy.linalg.qr, mode=mode)
    np_func = partial(np.linalg.qr, mode=mode)
    if mode == "full":
      np_func = jtu.ignore_warning(category=DeprecationWarning,
                                   message="The 'full' option.*")(np_func)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(np_func, jnp_func, args_maker, rtol=1e-5, atol=1e-5,
                            check_dtypes=(mode != "raw"))
    self._CompileAndCheck(jnp_func, args_maker)

  @jtu.sample_product(
    shape=[(0, 0), (2, 0), (0, 2), (3, 3), (3, 4), (2, 10, 5),
           (2, 200, 100), (64, 16, 5), (33, 7, 3), (137, 9, 5), (20000, 2, 2)],
    dtype=float_types + complex_types,
    full_matrices=[False, True],
  )
  @jax.default_matmul_precision("float32")
  def testQr(self, shape, dtype, full_matrices):
    if (jtu.test_device_matches(["cuda"]) and
        _is_required_cuda_version_satisfied(12000)):
      self.skipTest("Triggers a bug in cuda-12 b/287345077")
    rng = jtu.rand_default(self.rng())
    m, n = shape[-2:]

    if full_matrices:
      mode, k = "complete", m
    else:
      mode, k = "reduced", min(m, n)

    a = rng(shape, dtype)
    lq, lr = jnp.linalg.qr(a, mode=mode)

    # np.linalg.qr doesn't support batch dimensions. But it seems like an
    # inevitable extension so we support it in our version.
    nq = np.zeros(shape[:-2] + (m, k), dtype)
    nr = np.zeros(shape[:-2] + (k, n), dtype)
    for index in np.ndindex(*shape[:-2]):
      nq[index], nr[index] = np.linalg.qr(a[index], mode=mode)

    max_rank = max(m, n)

    # Norm, adjusted for dimension and type.
    def norm(x):
      n = np.linalg.norm(x, axis=(-2, -1))
      return n / (max(1, max_rank) * jnp.finfo(dtype).eps)

    def compare_orthogonal(q1, q2):
      # Q is unique up to sign, so normalize the sign first.
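      # (For complex dtypes the ambiguity is a per-column unit-modulus phase,
      # which is why a phase is estimated from the ratio q1/q2 and divided out
      # below, rather than just flipping signs.)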
      ratio = np.divide(np.where(q2 == 0, 0, q1), np.where(q2 == 0, 1, q2))
      sum_of_ratios = ratio.sum(axis=-2, keepdims=True)
      phases = np.divide(sum_of_ratios, np.abs(sum_of_ratios))
      q1 *= phases
      nm = norm(q1 - q2)
      self.assertTrue(np.all(nm < 160), msg=f"norm={np.amax(nm)}")

    # Check a ~= qr
    norm_error = norm(a - np.matmul(lq, lr))
    self.assertTrue(np.all(norm_error < 60), msg=np.amax(norm_error))

    # Compare the first 'k' vectors of Q; the remainder form an arbitrary
    # orthonormal basis for the null space.
    compare_orthogonal(nq[..., :k], lq[..., :k])

    # Check that q is close to unitary.
    self.assertTrue(np.all(
        norm(np.eye(k) - np.matmul(np.conj(T(lq)), lq)) < 10))

    # This expresses the identity function, which makes us robust to, e.g., the
    # tangents flipping the direction of vectors in Q.
    def qr_and_mul(a):
      q, r = jnp.linalg.qr(a, mode=mode)
      return q @ r

    if m == n or (m > n and not full_matrices):
      jtu.check_jvp(qr_and_mul, partial(jvp, qr_and_mul), (a,), atol=3e-3)

  @jtu.skip_on_devices("tpu")
  def testQrInvalidDtypeCPU(self, shape=(5, 6), dtype=np.float16):
    # Regression test for https://github.com/google/jax/issues/10530
    rng = jtu.rand_default(self.rng())
    arr = rng(shape, dtype)
    if jtu.test_device_matches(['cpu']):
      err, msg = NotImplementedError, "Unsupported dtype float16"
    else:
      err, msg = ValueError, r"Unsupported dtype dtype\('float16'\)"
    with self.assertRaisesRegex(err, msg):
      jnp.linalg.qr(arr)

  @jtu.sample_product(
    shape=[(10, 4, 5), (5, 3, 3), (7, 6, 4)],
    dtype=float_types + complex_types,
  )
  def testQrBatching(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    args = rng(shape, jnp.float32)
    qs, rs = vmap(jsp.linalg.qr)(args)
    self.assertTrue(np.all(np.linalg.norm(args - np.matmul(qs, rs)) < 1e-3))

  @jtu.sample_product(
    shape=[(1, 1), (4, 4), (2, 3, 5), (5, 5, 5), (20, 20), (5, 10)],
    pnorm=[jnp.inf, -jnp.inf, 1, -1, 2, -2, 'fro'],
    dtype=float_types + complex_types,
  )
  @jtu.skip_on_devices("gpu")  # TODO(#2203): numerical errors
  def testCond(self, shape, pnorm, dtype):
    def gen_mat():
      # arr_gen = jtu.rand_some_nan(self.rng())
      arr_gen = jtu.rand_default(self.rng())
      res = arr_gen(shape, dtype)
      return res

    def args_gen(p):
      def _args_gen():
        return [gen_mat(), p]
      return _args_gen

    args_maker = args_gen(pnorm)
    if pnorm not in [2, -2] and len(set(shape[-2:])) != 1:
      with self.assertRaises(np.linalg.LinAlgError):
        jnp.linalg.cond(*args_maker())
    else:
      self._CheckAgainstNumpy(np.linalg.cond, jnp.linalg.cond, args_maker,
                              check_dtypes=False, tol=1e-3)
      partial_norm = partial(jnp.linalg.cond, p=pnorm)
      self._CompileAndCheck(partial_norm, lambda: [gen_mat()],
                            check_dtypes=False, rtol=1e-03, atol=1e-03)

  @jtu.sample_product(
    shape=[(1, 1), (4, 4), (200, 200), (7, 7, 7, 7)],
    dtype=float_types,
  )
  def testTensorinv(self, shape, dtype):
    rng = jtu.rand_default(self.rng())

    def tensor_maker():
      invertible = False
      while not invertible:
        a = rng(shape, dtype)
        try:
          np.linalg.inv(a)
          invertible = True
        except np.linalg.LinAlgError:
          pass
      return a

    args_maker = lambda: [tensor_maker(), int(np.floor(len(shape) / 2))]
    self._CheckAgainstNumpy(np.linalg.tensorinv, jnp.linalg.tensorinv, args_maker,
                            check_dtypes=False, tol=1e-3)
    partial_inv = partial(jnp.linalg.tensorinv, ind=int(np.floor(len(shape) / 2)))
    self._CompileAndCheck(partial_inv, lambda: [tensor_maker()],
                          check_dtypes=False, rtol=1e-03, atol=1e-03)

  @jtu.sample_product(
    [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
     for lhs_shape, rhs_shape in [
       ((1, 1), (1, 1)),
       ((4, 4), (4,)),
       ((8, 8), (8, 4)),
       ((1, 2, 2), (3, 2)),
       ((2, 1, 3, 3), (1, 4, 3, 4)),
       ((1, 0, 0), (1, 0, 2)),
     ]
    ],
    dtype=float_types + complex_types,
  )
  def testSolve(self, lhs_shape, rhs_shape, dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

    self._CheckAgainstNumpy(np.linalg.solve, jnp.linalg.solve, args_maker,
                            tol=1e-3)
    self._CompileAndCheck(jnp.linalg.solve, args_maker)

  @jtu.sample_product(
    shape=[(1, 1), (4, 4), (2, 5, 5), (100, 100), (5, 5, 5), (0, 0)],
    dtype=float_types,
  )
  def testInv(self, shape, dtype):
    rng = jtu.rand_default(self.rng())

    def args_maker():
      invertible = False
      while not invertible:
        a = rng(shape, dtype)
        try:
          np.linalg.inv(a)
          invertible = True
        except np.linalg.LinAlgError:
          pass
      return [a]

    self._CheckAgainstNumpy(np.linalg.inv, jnp.linalg.inv, args_maker,
                            tol=1e-3)
    self._CompileAndCheck(jnp.linalg.inv, args_maker)

  @jtu.sample_product(
    [dict(shape=shape, hermitian=hermitian)
     for shape in [(1, 1), (4, 4), (3, 10, 10), (2, 70, 7), (2000, 7),
                   (7, 1000), (70, 7, 2), (2, 0, 0), (3, 0, 2), (1, 0),
                   (400000, 2), (2, 400000)]
     for hermitian in ([False, True] if shape[-1] == shape[-2] else [False])],
    dtype=float_types + complex_types,
  )
  def testPinv(self, shape, hermitian, dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]

    jnp_fn = partial(jnp.linalg.pinv, hermitian=hermitian)
    def np_fn(a):
      # Symmetrize the input matrix to match the jnp behavior.
      if hermitian:
        a = (a + T(a.conj())) / 2
      return np.linalg.pinv(a, hermitian=hermitian)
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-4)
    self._CompileAndCheck(jnp_fn, args_maker)

    # TODO(phawkins): 6e-2 seems like a very loose tolerance.
    jtu.check_grads(jnp_fn, args_maker(), 1, rtol=6e-2, atol=1e-3)

  def testPinvGradIssue2792(self):
    def f(p):
      a = jnp.array([[0., 0.], [-p, 1.]], jnp.float32) * 1 / (1 + p**2)
      return jnp.linalg.pinv(a)
    j = jax.jacobian(f)(jnp.float32(2.))
    self.assertAllClose(jnp.array([[0., -1.], [0., 0.]], jnp.float32), j)

    expected = jnp.array([[[[-1., 0.], [0., 0.]], [[0., -1.], [0., 0.]]],
                          [[[0., 0.], [-1., 0.]], [[0., 0.], [0., -1.]]]],
                         dtype=jnp.float32)
    self.assertAllClose(
        expected, jax.jacobian(jnp.linalg.pinv)(jnp.eye(2, dtype=jnp.float32)))

  @jtu.sample_product(
    shape=[(1, 1), (2, 2), (4, 4), (5, 5), (1, 2, 2), (2, 3, 3), (2, 5, 5)],
    dtype=float_types + complex_types,
    n=[-5, -2, -1, 0, 1, 2, 3, 4, 5, 10],
  )
  @jax.default_matmul_precision("float32")
  def testMatrixPower(self, shape, dtype, n):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(partial(np.linalg.matrix_power, n=n),
                            partial(jnp.linalg.matrix_power, n=n),
                            args_maker, tol=1e-3)
    self._CompileAndCheck(partial(jnp.linalg.matrix_power, n=n), args_maker,
                          rtol=1e-3)

  @jtu.sample_product(
    shape=[(3, ), (1, 2), (8, 5), (4, 4), (5, 5), (50, 50), (3, 4, 5),
           (2, 3, 4, 5)],
    dtype=float_types + complex_types,
  )
  def testMatrixRank(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    a, = args_maker()
    self._CheckAgainstNumpy(np.linalg.matrix_rank, jnp.linalg.matrix_rank,
                            args_maker, check_dtypes=False, tol=1e-3)
    self._CompileAndCheck(jnp.linalg.matrix_rank, args_maker,
                          check_dtypes=False, rtol=1e-3)

  @jtu.sample_product(
    shapes=[
      [(3, ), (3, 1)],  # quick-out codepath
      [(1, 3), (3, 5), (5, 2)],  # multi_dot_three codepath
      [(1, 3), (3, 5), (5, 2), (2, 7), (7, )]  # dynamic programming codepath
    ],
    dtype=float_types + complex_types,
  )
  def testMultiDot(self, shapes, dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [[rng(shape, dtype) for shape in shapes]]

    np_fun = np.linalg.multi_dot
    jnp_fun = partial(jnp.linalg.multi_dot, precision=lax.Precision.HIGHEST)
    tol = {np.float32: 1e-4, np.float64: 1e-10,
           np.complex64: 1e-4, np.complex128: 1e-10}

    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol)
    self._CompileAndCheck(jnp_fun, args_maker,
                          atol=tol, rtol=tol)

  @jtu.sample_product(
    [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
     for lhs_shape, rhs_shape in [
       ((1, 1), (1, 1)),
       ((4, 6), (4,)),
       ((6, 6), (6, 1)),
       ((8, 6), (8, 4)),
       ((0, 3), (0,)),
       ((3, 0), (3,)),
       ((3, 1), (3, 0)),
     ]
    ],
    rcond=[-1, None, 0.5],
    dtype=float_types + complex_types,
  )
  def testLstsq(self, lhs_shape, rhs_shape, dtype, rcond):
    rng = jtu.rand_default(self.rng())
    np_fun = partial(np.linalg.lstsq, rcond=rcond)
    jnp_fun = partial(jnp.linalg.lstsq, rcond=rcond)
    jnp_fun_numpy_resid = partial(jnp.linalg.lstsq, rcond=rcond, numpy_resid=True)
    tol = {np.float32: 1e-4, np.float64: 1e-12,
           np.complex64: 1e-5, np.complex128: 1e-12}
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

    self._CheckAgainstNumpy(np_fun, jnp_fun_numpy_resid, args_maker,
                            check_dtypes=False, tol=tol)
    self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol)

    # Disabled because grad is flaky for low-rank inputs.
    # TODO:
    # jtu.check_grads(lambda *args: jnp_fun(*args)[0], args_maker(), order=2, atol=1e-2, rtol=1e-2)

  # Regression test for incorrect type for eigenvalues of a complex matrix.
  def testIssue669(self):
    def test(x):
      val, vec = jnp.linalg.eigh(x)
      return jnp.real(jnp.sum(val))

    grad_test_jc = jit(grad(jit(test)))
    xc = np.eye(3, dtype=np.complex64)
    self.assertAllClose(xc, grad_test_jc(xc))

  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def testIssue1151(self):
    rng = self.rng()
    A = jnp.array(rng.randn(100, 3, 3), dtype=jnp.float32)
    b = jnp.array(rng.randn(100, 3), dtype=jnp.float32)
    x = jnp.linalg.solve(A, b)
    self.assertAllClose(vmap(jnp.dot)(A, x), b, atol=2e-3, rtol=1e-2)

    _ = jax.jacobian(jnp.linalg.solve, argnums=0)(A, b)
    _ = jax.jacobian(jnp.linalg.solve, argnums=1)(A, b)

    _ = jax.jacobian(jnp.linalg.solve, argnums=0)(A[0], b[0])
    _ = jax.jacobian(jnp.linalg.solve, argnums=1)(A[0], b[0])

  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  @jax.legacy_prng_key("allow")
  def testIssue1383(self):
    seed = jax.random.PRNGKey(0)
    tmp = jax.random.uniform(seed, (2, 2))
    a = jnp.dot(tmp, tmp.T)

    def f(inp):
      val, vec = jnp.linalg.eigh(inp)
      return jnp.dot(jnp.dot(vec, inp), vec.T)

    grad_func = jax.jacfwd(f)
    hess_func = jax.jacfwd(grad_func)
    cube_func = jax.jacfwd(hess_func)
    self.assertFalse(np.any(np.isnan(cube_func(a))))


class ScipyLinalgTest(jtu.JaxTestCase):

  @jtu.sample_product(
    args=[
      (),
      (1,),
      (7, -2),
      (3, 4, 5),
      (np.ones((3, 4), dtype=float), 5,
       np.random.randn(5, 2).astype(float)),
    ]
  )
  def testBlockDiag(self, args):
    args_maker = lambda: args
    self._CheckAgainstNumpy(osp.linalg.block_diag, jsp.linalg.block_diag,
                            args_maker, check_dtypes=False)
    self._CompileAndCheck(jsp.linalg.block_diag, args_maker)

  @jtu.sample_product(
    shape=[(1, 1), (4, 5), (10, 5), (50, 50)],
    dtype=float_types + complex_types,
  )
  def testLu(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    x, = args_maker()
    p, l, u = jsp.linalg.lu(x)
    self.assertAllClose(x, np.matmul(p, np.matmul(l, u)),
                        rtol={np.float32: 1e-3, np.float64: 5e-12,
                              np.complex64: 1e-3, np.complex128: 1e-12},
                        atol={np.float32: 1e-5})
    self._CompileAndCheck(jsp.linalg.lu, args_maker)

  def testLuOfSingularMatrix(self):
    x = jnp.array([[-1., 3./2], [2./3, -1.]], dtype=np.float32)
    p, l, u = jsp.linalg.lu(x)
    self.assertAllClose(x, np.matmul(p, np.matmul(l, u)))

  @jtu.sample_product(
    shape=[(1, 1), (4, 5), (10, 5), (10, 10), (6, 7, 7)],
    dtype=float_types + complex_types,
  )
  def testLuGrad(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    a = rng(shape, dtype)
    lu = vmap(jsp.linalg.lu) if len(shape) > 2 else jsp.linalg.lu
    jtu.check_grads(lu, (a,), 2, atol=5e-2, rtol=3e-1)

  @jtu.sample_product(
    shape=[(4, 5), (6, 5)],
    dtype=[jnp.float32],
  )
  def testLuBatching(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    args = [rng(shape, jnp.float32) for _ in range(10)]
    expected = list(osp.linalg.lu(x) for x in args)
    ps = np.stack([out[0] for out in expected])
    ls = np.stack([out[1] for out in expected])
    us = np.stack([out[2] for out in expected])

    actual_ps, actual_ls, actual_us = vmap(jsp.linalg.lu)(jnp.stack(args))
    self.assertAllClose(ps, actual_ps)
    self.assertAllClose(ls, actual_ls, rtol=5e-6)
    self.assertAllClose(us, actual_us)

  @jtu.skip_on_devices("cpu", "tpu")
  def testLuCPUBackendOnGPU(self):
    # tests running `lu` on cpu when a gpu is present.
    jit(jsp.linalg.lu, backend="cpu")(np.ones((2, 2)))  # does not crash

  @jtu.sample_product(
    n=[1, 4, 5, 200],
    dtype=float_types + complex_types,
  )
  def testLuFactor(self, n, dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng((n, n), dtype)]

    x, = args_maker()
    lu, piv = jsp.linalg.lu_factor(x)
    l = np.tril(lu, -1) + np.eye(n, dtype=dtype)
    u = np.triu(lu)
    for i in range(n):
      x[[i, piv[i]],] = x[[piv[i], i],]
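    # (Applying the recorded row interchanges to x reproduces the row order
    # that was factored, so the permuted x should now match l @ u.)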
    self.assertAllClose(x, np.matmul(l, u), rtol=1e-3,
                        atol=1e-3)
    self._CompileAndCheck(jsp.linalg.lu_factor, args_maker)

  @jtu.sample_product(
    [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
     for lhs_shape, rhs_shape in [
       ((1, 1), (1, 1)),
       ((4, 4), (4,)),
       ((8, 8), (8, 4)),
     ]
    ],
    trans=[0, 1, 2],
    dtype=float_types + complex_types,
  )
  @jtu.skip_on_devices("cpu")  # TODO(frostig): Test fails on CPU sometimes
  def testLuSolve(self, lhs_shape, rhs_shape, dtype, trans):
    rng = jtu.rand_default(self.rng())
    osp_fun = lambda lu, piv, rhs: osp.linalg.lu_solve((lu, piv), rhs, trans=trans)
    jsp_fun = lambda lu, piv, rhs: jsp.linalg.lu_solve((lu, piv), rhs, trans=trans)

    def args_maker():
      a = rng(lhs_shape, dtype)
      lu, piv = osp.linalg.lu_factor(a)
      return [lu, piv, rng(rhs_shape, dtype)]

    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=1e-3)
    self._CompileAndCheck(jsp_fun, args_maker)

  @jtu.sample_product(
    [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
     for lhs_shape, rhs_shape in [
       ((1, 1), (1, 1)),
       ((4, 4), (4,)),
       ((8, 8), (8, 4)),
     ]
    ],
    [dict(assume_a=assume_a, lower=lower)
     for assume_a, lower in [
       ('gen', False),
       ('pos', False),
       ('pos', True),
     ]
    ],
    dtype=float_types + complex_types,
  )
  def testSolve(self, lhs_shape, rhs_shape, dtype, assume_a, lower):
    rng = jtu.rand_default(self.rng())
    osp_fun = lambda lhs, rhs: osp.linalg.solve(lhs, rhs, assume_a=assume_a, lower=lower)
    jsp_fun = lambda lhs, rhs: jsp.linalg.solve(lhs, rhs, assume_a=assume_a, lower=lower)

    def args_maker():
      a = rng(lhs_shape, dtype)
      if assume_a == 'pos':
        a = np.matmul(a, np.conj(T(a)))
        a = np.tril(a) if lower else np.triu(a)
      return [a, rng(rhs_shape, dtype)]

    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=1e-3)
    self._CompileAndCheck(jsp_fun, args_maker)

  @jtu.sample_product(
    [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
     for lhs_shape, rhs_shape in [
       ((4, 4), (4,)),
       ((4, 4), (4, 3)),
       ((2, 8, 8), (2, 8, 10)),
     ]
    ],
    lower=[False, True],
    transpose_a=[False, True],
    unit_diagonal=[False, True],
    dtype=float_types,
  )
  def testSolveTriangular(self, lower, transpose_a, unit_diagonal, lhs_shape,
                          rhs_shape, dtype):
    rng = jtu.rand_default(self.rng())
    k = rng(lhs_shape, dtype)
    l = np.linalg.cholesky(np.matmul(k, T(k))
                           + lhs_shape[-1] * np.eye(lhs_shape[-1]))
    l = l.astype(k.dtype)
    b = rng(rhs_shape, dtype)

    if unit_diagonal:
      a = np.tril(l, -1) + np.eye(lhs_shape[-1], dtype=dtype)
    else:
      a = l
    a = a if lower else T(a)

    inv = np.linalg.inv(T(a) if transpose_a else a).astype(a.dtype)
    if len(lhs_shape) == len(rhs_shape):
      np_ans = np.matmul(inv, b)
    else:
      np_ans = np.einsum("...ij,...j->...i", inv, b)

    # The standard scipy.linalg.solve_triangular doesn't support broadcasting.
    # But it seems like an inevitable extension so we support it.
    ans = jsp.linalg.solve_triangular(
        l if lower else T(l), b, trans=1 if transpose_a else 0, lower=lower,
        unit_diagonal=unit_diagonal)

    self.assertAllClose(np_ans, ans,
                        rtol={np.float32: 1e-4, np.float64: 1e-11})

  @jtu.sample_product(
    [dict(left_side=left_side, a_shape=a_shape, b_shape=b_shape)
     for left_side, a_shape, b_shape in [
       (False, (4, 4), (4,)),
       (False, (4, 4), (1, 4,)),
       (False, (3, 3), (4, 3)),
       (True, (4, 4), (4,)),
       (True, (4, 4), (4, 1)),
       (True, (4, 4), (4, 3)),
       (True, (2, 8, 8), (2, 8, 10)),
     ]
    ],
    [dict(dtype=dtype, conjugate_a=conjugate_a)
     for dtype in float_types + complex_types
     for conjugate_a in (
       [False] if jnp.issubdtype(dtype, jnp.floating) else [False, True])
    ],
    lower=[False, True],
    unit_diagonal=[False, True],
    transpose_a=[False, True],
  )
  def testTriangularSolveGrad(
      self, lower, transpose_a, conjugate_a, unit_diagonal, left_side, a_shape,
      b_shape, dtype):
    rng = jtu.rand_default(self.rng())
    # Test lax.linalg.triangular_solve instead of scipy.linalg.solve_triangular
    # because it exposes more options.
    A = jnp.tril(rng(a_shape, dtype) + 5 * np.eye(a_shape[-1], dtype=dtype))
    A = A if lower else T(A)
    B = rng(b_shape, dtype)
    f = partial(lax.linalg.triangular_solve, lower=lower, transpose_a=transpose_a,
                conjugate_a=conjugate_a, unit_diagonal=unit_diagonal,
                left_side=left_side)
    jtu.check_grads(f, (A, B), order=1, rtol=4e-2, eps=1e-3)

  @jtu.sample_product(
    [dict(left_side=left_side, a_shape=a_shape, b_shape=b_shape, bdims=bdims)
     for left_side, a_shape, b_shape, bdims in [
       (False, (4, 4), (2, 3, 4,), (None, 0)),
       (False, (2, 4, 4), (2, 2, 3, 4,), (None, 0)),
       (False, (2, 4, 4), (3, 4,), (0, None)),
       (False, (2, 4, 4), (2, 3, 4,), (0, 0)),
       (True, (2, 4, 4), (2, 4, 3), (0, 0)),
       (True, (2, 4, 4), (2, 2, 4, 3), (None, 0)),
     ]
    ],
  )
  def testTriangularSolveBatching(self, left_side, a_shape, b_shape, bdims):
    rng = jtu.rand_default(self.rng())
    A = jnp.tril(rng(a_shape, np.float32)
                 + 5 * np.eye(a_shape[-1], dtype=np.float32))
    B = rng(b_shape, np.float32)
    solve = partial(lax.linalg.triangular_solve, lower=True, transpose_a=False,
                    conjugate_a=False, unit_diagonal=False, left_side=left_side)
    X = vmap(solve, bdims)(A, B)
    matmul = partial(jnp.matmul, precision=lax.Precision.HIGHEST)
    Y = matmul(A, X) if left_side else matmul(X, A)
    self.assertArraysAllClose(Y, jnp.broadcast_to(B, Y.shape), atol=1e-4)

  def testTriangularSolveGradPrecision(self):
    rng = jtu.rand_default(self.rng())
    a = jnp.tril(rng((3, 3), np.float32))
    b = rng((1, 3), np.float32)
    jtu.assert_dot_precision(
        lax.Precision.HIGHEST,
        partial(jvp, lax.linalg.triangular_solve),
        (a, b),
        (a, b))

  @jtu.sample_product(
    n=[1, 4, 5, 20, 50, 100],
    batch_size=[(), (2,), (3, 4)] if scipy_version >= (1, 9, 0) else [()],
    dtype=int_types + float_types + complex_types
  )
  def testExpm(self, n, batch_size, dtype):
    if (jtu.test_device_matches(["cuda"]) and
        _is_required_cuda_version_satisfied(12000)):
      self.skipTest("Triggers a bug in cuda-12 b/287345077")

    rng = jtu.rand_small(self.rng())
    args_maker = lambda: [rng((*batch_size, n, n), dtype)]

    # Compare to numpy with JAX type promotion semantics.
    def osp_fun(A):
      return osp.linalg.expm(np.array(*promote_dtypes_inexact(A)))
    jsp_fun = jsp.linalg.expm
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker)
    self._CompileAndCheck(jsp_fun, args_maker)

    args_maker_triu = lambda: [np.triu(rng((*batch_size, n, n), dtype))]
    jsp_fun_triu = lambda a: jsp.linalg.expm(a, upper_triangular=True)
    self._CheckAgainstNumpy(osp_fun, jsp_fun_triu, args_maker_triu)
    self._CompileAndCheck(jsp_fun_triu, args_maker_triu)

  @jtu.sample_product(
    # Skip empty shapes because scipy fails: https://github.com/scipy/scipy/issues/1532
    shape=[(3, 4), (3, 3), (4, 3)],
    dtype=[np.float32],
    mode=["full", "r", "economic"],
  )
  def testScipyQrModes(self, shape, dtype, mode):
    rng = jtu.rand_default(self.rng())
    jsp_func = partial(jax.scipy.linalg.qr, mode=mode)
    sp_func = partial(scipy.linalg.qr, mode=mode)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(sp_func, jsp_func, args_maker, rtol=1E-5, atol=1E-5)
    self._CompileAndCheck(jsp_func, args_maker)

  @jtu.sample_product(
    [dict(shape=shape, k=k)
     for shape in [(1, 1), (3, 4, 4), (10, 5)]
     # TODO(phawkins): there are some test failures on GPU for k=0
     for k in range(1, shape[-1] + 1)],
    dtype=float_types + complex_types,
  )
  def testHouseholderProduct(self, shape, k, dtype):

    @partial(np.vectorize, signature='(m,n),(k)->(m,n)')
    def reference_fn(a, taus):
      if dtype == np.float32:
        q, _, info = scipy.linalg.lapack.sorgqr(a, taus)
      elif dtype == np.float64:
        q, _, info = scipy.linalg.lapack.dorgqr(a, taus)
      elif dtype == np.complex64:
        q, _, info = scipy.linalg.lapack.cungqr(a, taus)
      elif dtype == np.complex128:
        q, _, info = scipy.linalg.lapack.zungqr(a, taus)
      else:
        assert False, dtype
      assert info == 0, info
      return q
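    # (The LAPACK ?orgqr/?ungqr routines above assemble the explicit Q factor
    # from elementary Householder reflectors and their tau scalars, which is
    # the same operation lax.linalg.householder_product performs.)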

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype), rng(shape[:-2] + (k,), dtype)]
    tol = {np.float32: 1e-5, np.complex64: 1e-5, np.float64: 1e-12,
           np.complex128: 1e-12}
    self._CheckAgainstNumpy(reference_fn, lax.linalg.householder_product,
                            args_maker, rtol=tol, atol=tol)
    self._CompileAndCheck(lax.linalg.householder_product, args_maker)

  @jtu.sample_product(
    shape=[(1, 1), (2, 4, 4), (0, 100, 100), (10, 10)],
    dtype=float_types + complex_types,
    calc_q=[False, True],
  )
  @jtu.run_on_devices("cpu")
  def testHessenberg(self, shape, dtype, calc_q):
    rng = jtu.rand_default(self.rng())
    jsp_func = partial(jax.scipy.linalg.hessenberg, calc_q=calc_q)
    if calc_q:
      sp_func = np.vectorize(partial(scipy.linalg.hessenberg, calc_q=True),
                             otypes=(dtype, dtype),
                             signature='(n,n)->(n,n),(n,n)')
    else:
      sp_func = np.vectorize(scipy.linalg.hessenberg, signature='(n,n)->(n,n)',
                             otypes=(dtype,))
    args_maker = lambda: [rng(shape, dtype)]
    # scipy.linalg.hessenberg sometimes returns a float Q matrix for complex
    # inputs
    self._CheckAgainstNumpy(sp_func, jsp_func, args_maker, rtol=1e-5, atol=1e-5,
                            check_dtypes=not calc_q)
    self._CompileAndCheck(jsp_func, args_maker)

  @jtu.sample_product(
    shape=[(1, 1), (2, 2, 2), (4, 4), (10, 10), (2, 5, 5)],
    dtype=float_types + complex_types,
    lower=[False, True],
  )
  @jtu.skip_on_devices("tpu", "rocm")
  def testTridiagonal(self, shape, dtype, lower):
    rng = jtu.rand_default(self.rng())
    def jax_func(a):
      return lax.linalg.tridiagonal(a, lower=lower)
|
|
|
|
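    # Reference: LAPACK's *sytrd/*hetrd reduce a symmetric (Hermitian) matrix
    # to real symmetric tridiagonal form, returning the reflector data c and
    # tau together with the diagonal d and off-diagonal e, which matches the
    # outputs of lax.linalg.tridiagonal checked below.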
    real_dtype = jnp.finfo(dtype).dtype
    @partial(np.vectorize, otypes=(dtype, real_dtype, real_dtype, dtype),
             signature='(n,n)->(n,n),(n),(k),(k)')
    def sp_func(a):
      if dtype == np.float32:
        c, d, e, tau, info = scipy.linalg.lapack.ssytrd(a, lower=lower)
      elif dtype == np.float64:
        c, d, e, tau, info = scipy.linalg.lapack.dsytrd(a, lower=lower)
      elif dtype == np.complex64:
        c, d, e, tau, info = scipy.linalg.lapack.chetrd(a, lower=lower)
      elif dtype == np.complex128:
        c, d, e, tau, info = scipy.linalg.lapack.zhetrd(a, lower=lower)
      else:
        assert False, dtype
      assert info == 0
      return c, d, e, tau

    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(sp_func, jax_func, args_maker, rtol=1e-4, atol=1e-4,
                            check_dtypes=False)

  @jtu.sample_product(
    n=[1, 4, 5, 20, 50, 100],
    dtype=float_types + complex_types,
  )
  def testIssue2131(self, n, dtype):
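    # Regression test: expm of an all-zero matrix must equal the identity and
    # agree with scipy.linalg.expm.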
    args_maker_zeros = lambda: [np.zeros((n, n), dtype)]
    osp_fun = lambda a: osp.linalg.expm(a)
    jsp_fun = lambda a: jsp.linalg.expm(a)
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker_zeros)
    self._CompileAndCheck(jsp_fun, args_maker_zeros)

  @jtu.sample_product(
    [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
     for lhs_shape, rhs_shape in [
       [(1, 1), (1,)],
       [(4, 4), (4,)],
       [(4, 4), (4, 4)],
     ]
    ],
    dtype=float_types,
    lower=[True, False],
  )
  def testChoSolve(self, lhs_shape, rhs_shape, dtype, lower):
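    # cho_solve takes (c, lower), a Cholesky factor in the form returned by
    # cho_factor, plus the right-hand side b. Here the factor is simply a
    # random triangular matrix, which is enough to check that JAX and SciPy
    # agree on the solve.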
    rng = jtu.rand_default(self.rng())
    def args_maker():
      b = rng(rhs_shape, dtype)
      if lower:
        L = np.tril(rng(lhs_shape, dtype))
        return [(L, lower), b]
      else:
        U = np.triu(rng(lhs_shape, dtype))
        return [(U, lower), b]
    self._CheckAgainstNumpy(osp.linalg.cho_solve, jsp.linalg.cho_solve,
                            args_maker, tol=1e-3)

  @jtu.sample_product(
    n=[1, 4, 5, 20, 50, 100],
    dtype=float_types + complex_types,
  )
  def testExpmFrechet(self, n, dtype):
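    # expm_frechet(a, e) computes the Fréchet derivative of the matrix
    # exponential at a in direction e (and, with compute_expm=True, also
    # expm(a)); both code paths are checked against SciPy below.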
    rng = jtu.rand_small(self.rng())
    if dtype == np.float64 or dtype == np.complex128:
      target_norms = [1.0e-2, 2.0e-1, 9.0e-01, 2.0, 3.0]
      # TODO(zhangqiaorjc): Reduce tol to default 1e-15.
      tol = {
          np.dtype(np.float64): 1e-14,
          np.dtype(np.complex128): 1e-14,
      }
    elif dtype == np.float32 or dtype == np.complex64:
      target_norms = [4.0e-1, 1.0, 3.0]
      tol = None
    else:
      raise TypeError(f"{dtype=} is not supported.")
    for norm in target_norms:
      def args_maker():
        a = rng((n, n), dtype)
        a = a / np.linalg.norm(a, 1) * norm
        e = rng((n, n), dtype)
        return [a, e]

      # compute_expm=True
      osp_fun = lambda a, e: osp.linalg.expm_frechet(a, e, compute_expm=True)
      jsp_fun = lambda a, e: jsp.linalg.expm_frechet(a, e, compute_expm=True)
      self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker,
                              check_dtypes=False, tol=tol)
      self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=False)
      # compute_expm=False
      osp_fun = lambda a, e: osp.linalg.expm_frechet(a, e, compute_expm=False)
      jsp_fun = lambda a, e: jsp.linalg.expm_frechet(a, e, compute_expm=False)
      self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker,
                              check_dtypes=False, tol=tol)
      self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=False)

  @jtu.sample_product(
    n=[1, 4, 5, 20, 50],
    dtype=float_types + complex_types,
  )
  def testExpmGrad(self, n, dtype):
    rng = jtu.rand_small(self.rng())
    a = rng((n, n), dtype)
    if dtype == np.float64 or dtype == np.complex128:
      target_norms = [1.0e-2, 2.0e-1, 9.0e-01, 2.0, 3.0]
    elif dtype == np.float32 or dtype == np.complex64:
      target_norms = [4.0e-1, 1.0, 3.0]
    else:
      raise TypeError(f"{dtype=} is not supported.")
    # TODO(zhangqiaorjc): Reduce tol to default 1e-5.
    # Lower tolerance is due to 2nd order derivative.
    tol = {
        # Note that due to inner_product, float and complex tol are coupled.
        np.dtype(np.float32): 0.02,
        np.dtype(np.complex64): 0.02,
        np.dtype(np.float64): 1e-4,
        np.dtype(np.complex128): 1e-4,
    }
    for norm in target_norms:
      a = a / np.linalg.norm(a, 1) * norm
      def expm(x):
        return jsp.linalg.expm(x, upper_triangular=False, max_squarings=16)
      jtu.check_grads(expm, (a,), modes=["fwd", "rev"], order=1, atol=tol,
                      rtol=tol)

  @jtu.sample_product(
    shape=[(4, 4), (15, 15), (50, 50), (100, 100)],
    dtype=float_types + complex_types,
  )
  @jtu.run_on_devices("cpu")
  def testSchur(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]

    self._CheckAgainstNumpy(osp.linalg.schur, jsp.linalg.schur, args_maker)
    self._CompileAndCheck(jsp.linalg.schur, args_maker)

  @jtu.sample_product(
    shape=[(1, 1), (4, 4), (15, 15), (50, 50), (100, 100)],
    dtype=float_types + complex_types,
  )
  @jtu.run_on_devices("cpu")
  def testRsf2csf(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype), rng(shape, dtype)]
    tol = 3e-5
    self._CheckAgainstNumpy(osp.linalg.rsf2csf, jsp.linalg.rsf2csf,
                            args_maker, tol=tol)
    self._CompileAndCheck(jsp.linalg.rsf2csf, args_maker)

  @jtu.sample_product(
    shape=[(1, 1), (5, 5), (20, 20), (50, 50)],
    dtype=float_types + complex_types,
    disp=[True, False],
  )
  # funm uses jax.scipy.linalg.schur, which is implemented for the CPU
  # backend only, so tests on GPU and TPU backends are skipped here.
  @jtu.run_on_devices("cpu")
  def testFunm(self, shape, dtype, disp):
    def func(x):
      return x**-2.718
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    jnp_fun = lambda arr: jsp.linalg.funm(arr, func, disp=disp)
    scp_fun = lambda arr: osp.linalg.funm(arr, func, disp=disp)
    self._CheckAgainstNumpy(jnp_fun, scp_fun, args_maker, check_dtypes=False,
                            tol={np.complex64: 1e-5, np.complex128: 1e-6})
    self._CompileAndCheck(jnp_fun, args_maker, atol=2e-5)

  @jtu.sample_product(
    shape=[(4, 4), (15, 15), (50, 50), (100, 100)],
    dtype=float_types + complex_types,
  )
  @jtu.run_on_devices("cpu")
  def testSqrtmPSDMatrix(self, shape, dtype):
    # Checks against scipy.linalg.sqrtm when the principal square root
    # is guaranteed to be unique (i.e. no negative real eigenvalue).
    rng = jtu.rand_default(self.rng())
    arg = rng(shape, dtype)
    mat = arg @ arg.T
    args_maker = lambda: [mat]
    if dtype == np.float32 or dtype == np.complex64:
      tol = 1e-4
    else:
      tol = 1e-8
    self._CheckAgainstNumpy(osp.linalg.sqrtm,
                            jsp.linalg.sqrtm,
                            args_maker,
                            tol=tol,
                            check_dtypes=False)
    self._CompileAndCheck(jsp.linalg.sqrtm, args_maker)

  @jtu.sample_product(
    shape=[(4, 4), (15, 15), (50, 50), (100, 100)],
    dtype=float_types + complex_types,
  )
  @jtu.run_on_devices("cpu")
  def testSqrtmGenMatrix(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    arg = rng(shape, dtype)
    if dtype == np.float32 or dtype == np.complex64:
      tol = 2e-3
    else:
      tol = 1e-8
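    # Rather than comparing against a SciPy reference, verify the defining
    # property of the matrix square root: R @ R should reconstruct the input
    # up to the tolerance above.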
    R = jsp.linalg.sqrtm(arg)
    self.assertAllClose(R @ R, arg, atol=tol, check_dtypes=False)

  @jtu.sample_product(
    [dict(diag=diag, expected=expected)
     for diag, expected in [([1, 0, 0], [1, 0, 0]), ([0, 4, 0], [0, 2, 0]),
                            ([0, 0, 0, 9], [0, 0, 0, 3]),
                            ([0, 0, 9, 0, 0, 4], [0, 0, 3, 0, 0, 2])]
    ],
    dtype=float_types + complex_types,
  )
  @jtu.run_on_devices("cpu")
  def testSqrtmEdgeCase(self, diag, expected, dtype):
    """Tests the zero-numerator condition: the square root of a diagonal
    matrix with zero entries is the elementwise square root of the diagonal."""
    mat = jnp.diag(jnp.array(diag)).astype(dtype)
    expected = jnp.diag(jnp.array(expected))
    root = jsp.linalg.sqrtm(mat)

    self.assertAllClose(root, expected, check_dtypes=False)

  @jtu.sample_product(
    cshape=[(), (4,), (8,), (3, 7), (0, 5, 1)],
    cdtype=float_types + complex_types,
    rshape=[(), (3,), (7,), (2, 1, 4), (19, 0)],
    rdtype=float_types + complex_types + int_types)
  def testToeplitzConstruction(self, rshape, rdtype, cshape, cdtype):
    if ((rdtype in [np.float64, np.complex128]
         or cdtype in [np.float64, np.complex128])
        and not config.enable_x64.value):
      self.skipTest("Only run float64 testcase when float64 is enabled.")

    int_types_excl_i8 = set(int_types) - {np.int8}
    if ((rdtype in int_types_excl_i8 or cdtype in int_types_excl_i8)
        and jtu.test_device_matches(["gpu"])):
      self.skipTest("Integer (except int8) toeplitz is not supported on GPU yet.")

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(cshape, cdtype), rng(rshape, rdtype)]
    with jtu.strict_promotion_if_dtypes_match([rdtype, cdtype]):
      self._CheckAgainstNumpy(jtu.promote_like_jnp(osp.linalg.toeplitz),
                              jsp.linalg.toeplitz, args_maker)
      self._CompileAndCheck(jsp.linalg.toeplitz, args_maker)

  @jtu.sample_product(
    shape=[(), (3,), (1, 4), (1, 5, 9), (11, 0, 13)],
    dtype=float_types + complex_types + int_types)
  @jtu.skip_on_devices("rocm")
  def testToeplitzSymmetricConstruction(self, shape, dtype):
    if (dtype in [np.float64, np.complex128]
        and not config.enable_x64.value):
      self.skipTest("Only run float64 testcase when float64 is enabled.")

    int_types_excl_i8 = set(int_types) - {np.int8}
    if (dtype in int_types_excl_i8
        and jtu.test_device_matches(["gpu"])):
      self.skipTest("Integer (except int8) toeplitz is not supported on GPU yet.")

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(jtu.promote_like_jnp(osp.linalg.toeplitz),
                            jsp.linalg.toeplitz, args_maker)
    self._CompileAndCheck(jsp.linalg.toeplitz, args_maker)

  def testToeplitzConstructionWithKnownCases(self):
    # Test with examples taken from the SciPy docs for the corresponding function:
    # https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.toeplitz.html
    ret = jsp.linalg.toeplitz(np.array([1.0, 2+3j, 4-1j]))
    self.assertAllClose(ret, np.array([
        [ 1.+0.j, 2.-3.j, 4.+1.j],
        [ 2.+3.j, 1.+0.j, 2.-3.j],
        [ 4.-1.j, 2.+3.j, 1.+0.j]]))
    ret = jsp.linalg.toeplitz(np.array([1, 2, 3], dtype=np.float32),
                              np.array([1, 4, 5, 6], dtype=np.float32))
    self.assertAllClose(ret, np.array([
        [1, 4, 5, 6],
        [2, 1, 4, 5],
        [3, 2, 1, 4]], dtype=np.float32))


class LaxLinalgTest(jtu.JaxTestCase):
  """Tests for lax.linalg primitives."""

  @jtu.sample_product(
    n=[0, 4, 5, 50],
    dtype=float_types + complex_types,
    lower=[True, False],
    sort_eigenvalues=[True, False],
  )
  def testEigh(self, n, dtype, lower, sort_eigenvalues):
    rng = jtu.rand_default(self.rng())
    tol = 1e-3
    args_maker = lambda: [rng((n, n), dtype)]

    a, = args_maker()
    a = (a + np.conj(a.T)) / 2
    v, w = lax.linalg.eigh(np.tril(a) if lower else np.triu(a),
                           lower=lower, symmetrize_input=False,
                           sort_eigenvalues=sort_eigenvalues)
    w = np.asarray(w)
    v = np.asarray(v)
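    # The eigenvectors should be orthonormal (V^H V close to I) and satisfy
    # the eigenvalue equation A V = V diag(w) up to a residual proportional
    # to the norm of A.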
    self.assertLessEqual(
        np.linalg.norm(np.eye(n) - np.matmul(np.conj(T(v)), v)), 1e-3)
    self.assertLessEqual(np.linalg.norm(np.matmul(a, v) - w * v),
                         tol * np.linalg.norm(a))

    w_expected, v_expected = np.linalg.eigh(np.asarray(a))
    self.assertAllClose(w_expected, w if sort_eigenvalues else np.sort(w),
                        rtol=1e-4, atol=1e-4)

  def run_eigh_tridiagonal_test(self, alpha, beta):
    n = alpha.shape[-1]
    # scipy.linalg.eigh_tridiagonal doesn't support complex inputs, so for
    # those we call the slower numpy.linalg.eigh instead.
    if np.issubdtype(alpha.dtype, np.complexfloating):
      tridiagonal = np.diag(alpha) + np.diag(beta, 1) + np.diag(
          np.conj(beta), -1)
      eigvals_expected, _ = np.linalg.eigh(tridiagonal)
    else:
      eigvals_expected = scipy.linalg.eigh_tridiagonal(
          alpha, beta, eigvals_only=True)
    eigvals = jax.scipy.linalg.eigh_tridiagonal(
        alpha, beta, eigvals_only=True)
    finfo = np.finfo(alpha.dtype)
    atol = 4 * np.sqrt(n) * finfo.eps * np.amax(np.abs(eigvals_expected))
    self.assertAllClose(eigvals_expected, eigvals, atol=atol, rtol=1e-4)

  @jtu.sample_product(
    n=[1, 2, 3, 7, 8, 100],
    dtype=float_types + complex_types,
  )
  def testToeplitz(self, n, dtype):
    for a, b in [[2, -1], [1, 0], [0, 1], [-1e10, 1e10], [-1e-10, 1e-10]]:
      alpha = a * np.ones([n], dtype=dtype)
      beta = b * np.ones([n - 1], dtype=dtype)
      self.run_eigh_tridiagonal_test(alpha, beta)

  @jtu.sample_product(
    n=[1, 2, 3, 7, 8, 100],
    dtype=float_types + complex_types,
  )
  def testRandomUniform(self, n, dtype):
    alpha = jtu.rand_uniform(self.rng())((n,), dtype)
    beta = jtu.rand_uniform(self.rng())((n - 1,), dtype)
    self.run_eigh_tridiagonal_test(alpha, beta)

  @jtu.sample_product(dtype=float_types + complex_types)
  def testSelect(self, dtype):
    n = 5
    alpha = jtu.rand_uniform(self.rng())((n,), dtype)
    beta = jtu.rand_uniform(self.rng())((n - 1,), dtype)
    eigvals_all = jax.scipy.linalg.eigh_tridiagonal(alpha, beta, select="a",
                                                    eigvals_only=True)
    eps = np.finfo(alpha.dtype).eps
    atol = 2 * n * eps
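    # select="a" computes all eigenvalues; select="i" with select_range
    # restricts the computation to an inclusive index range, so the two
    # results should agree on the overlapping slice.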
    for first in range(n - 1):
      for last in range(first + 1, n - 1):
        # Check that we get the expected eigenvalues by selecting by
        # index range.
        eigvals_index = jax.scipy.linalg.eigh_tridiagonal(
            alpha, beta, select="i", select_range=(first, last),
            eigvals_only=True)
        self.assertAllClose(
            eigvals_all[first:(last + 1)], eigvals_index, atol=atol)

  @jtu.sample_product(dtype=[np.float32, np.float64])
  @jtu.skip_on_devices("rocm")  # will be fixed in ROCm-5.1
  def test_tridiagonal_solve(self, dtype):
    dl = np.array([0.0, 2.0, 3.0], dtype=dtype)
    d = np.ones(3, dtype=dtype)
    du = np.array([1.0, 2.0, 0.0], dtype=dtype)
    m = 3
    B = np.ones([m, 1], dtype=dtype)
    X = lax.linalg.tridiagonal_solve(dl, d, du, B)
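    # Rebuild the dense operator to check the solution: with these inputs
    #   A = [[1, 1, 0],
    #        [2, 1, 2],
    #        [0, 3, 1]]
    # (unit diagonal d, sub-diagonal dl[1:], super-diagonal du[:-1]).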
    A = np.eye(3, dtype=dtype)
    A[[1, 2], [0, 1]] = dl[1:]
    A[[0, 1], [1, 2]] = du[:-1]
    np.testing.assert_allclose(A @ X, B, rtol=1e-6, atol=1e-6)

  @jtu.sample_product(
    shape=[(4, 4), (15, 15), (50, 50), (100, 100)],
    dtype=float_types + complex_types,
  )
  @jtu.run_on_devices("cpu")
  def testSchur(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]

    self._CheckAgainstNumpy(osp.linalg.schur, lax.linalg.schur, args_maker)
    self._CompileAndCheck(lax.linalg.schur, args_maker)

  @jtu.sample_product(
    shape=[(2, 2), (4, 4), (15, 15), (50, 50), (100, 100)],
    dtype=float_types + complex_types,
  )
  @jtu.run_on_devices("cpu")
  def testSchurBatching(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    batch_size = 10
    shape = (batch_size,) + shape
    args = rng(shape, dtype)
    reconstruct = vmap(lambda S, T: S @ T @ jnp.conj(S.T))

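    # The Schur decomposition satisfies A = S @ T @ S^H, so reconstructing
    # from the vmapped outputs should recover the original batch of matrices.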
    Ts, Ss = vmap(lax.linalg.schur)(args)
    self.assertAllClose(reconstruct(Ss, Ts), args, atol=1e-4)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())