# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the LAPAX linear algebra module."""

from functools import partial
import itertools

import numpy as np
import scipy
import scipy.linalg
import scipy as osp

from absl.testing import absltest

import jax
from jax import jit, grad, jvp, vmap
from jax import lax
from jax import numpy as jnp
from jax import scipy as jsp
from jax._src import config
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.numpy.util import promote_dtypes_inexact

config.parse_flags_with_absl()

scipy_version = tuple(map(int, scipy.version.version.split('.')[:3]))
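
# Matrix transpose over the last two axes (a batched matrix transpose).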
T = lambda x: np.swapaxes(x, -1, -2)

float_types = jtu.dtypes.floating
complex_types = jtu.dtypes.complex
int_types = jtu.dtypes.all_integer


def _is_required_cuda_version_satisfied(cuda_version):
  version = xla_bridge.get_backend().platform_version
  if version == "<unknown>" or version.split()[0] == "rocm":
    return False
  else:
    return int(version.split()[-1]) >= cuda_version


class NumpyLinalgTest(jtu.JaxTestCase):

  @jtu.sample_product(
    shape=[(1, 1), (4, 4), (2, 5, 5), (200, 200), (1000, 0, 0)],
    dtype=float_types + complex_types,
  )
  def testCholesky(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    def args_maker():
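      # a @ conj(a.T) is Hermitian positive semi-definite, so its Cholesky
      # factorization is well defined.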
      factor_shape = shape[:-1] + (2 * shape[-1],)
      a = rng(factor_shape, dtype)
      return [np.matmul(a, jnp.conj(T(a)))]

    self._CheckAgainstNumpy(np.linalg.cholesky, jnp.linalg.cholesky, args_maker,
                            tol=1e-3)
    self._CompileAndCheck(jnp.linalg.cholesky, args_maker)

    if jnp.finfo(dtype).bits == 64:
      jtu.check_grads(jnp.linalg.cholesky, args_maker(), order=2)

  def testCholeskyGradPrecision(self):
    rng = jtu.rand_default(self.rng())
    a = rng((3, 3), np.float32)
    a = np.dot(a, a.T)
    jtu.assert_dot_precision(
        lax.Precision.HIGHEST, partial(jvp, jnp.linalg.cholesky), (a,), (a,))

  @jtu.sample_product(
    n=[0, 2, 3, 4, 5, 25],  # TODO(mattjj): complex64 unstable on large sizes?
    dtype=float_types + complex_types,
  )
  def testDet(self, n, dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng((n, n), dtype)]

    self._CheckAgainstNumpy(np.linalg.det, jnp.linalg.det, args_maker, tol=1e-3)
    self._CompileAndCheck(jnp.linalg.det, args_maker,
                          rtol={np.float64: 1e-13, np.complex128: 1e-13})

  def testDetOfSingularMatrix(self):
    x = jnp.array([[-1., 3./2], [2./3, -1.]], dtype=np.float32)
    self.assertAllClose(np.float32(0), jsp.linalg.det(x))

  @jtu.sample_product(
    shape=[(1, 1), (3, 3), (2, 4, 4)],
    dtype=float_types,
  )
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  @jtu.skip_on_devices("tpu")
  def testDetGrad(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    a = rng(shape, dtype)
    jtu.check_grads(jnp.linalg.det, (a,), 2, atol=1e-1, rtol=1e-1)
    # Make sure there are no NaNs when a matrix is zero.
    if len(shape) == 2:
      jtu.check_grads(
        jnp.linalg.det, (jnp.zeros_like(a),), 1, atol=1e-1, rtol=1e-1)
    else:
      a[0] = 0
      jtu.check_grads(jnp.linalg.det, (a,), 1, atol=1e-1, rtol=1e-1)

  def testDetGradIssue6121(self):
    f = lambda x: jnp.linalg.det(x).sum()
    x = jnp.ones((16, 1, 1))
    jax.grad(f)(x)
    jtu.check_grads(f, (x,), 2, atol=1e-1, rtol=1e-1)

  def testDetGradOfSingularMatrixCorank1(self):
    # Rank 2 matrix with nonzero gradient
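    # (The gradient of det at a corank-1 matrix is the adjugate, which is
    # nonzero here, so check_grads has a meaningful signal to verify.)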
    a = jnp.array([[ 50, -30, 45],
                   [-30, 90, -81],
                   [ 45, -81, 81]], dtype=jnp.float32)
    jtu.check_grads(jnp.linalg.det, (a,), 1, atol=1e-1, rtol=1e-1)

  # TODO(phawkins): Test sometimes produces NaNs on TPU.
  @jtu.skip_on_devices("tpu")
  def testDetGradOfSingularMatrixCorank2(self):
    # Rank 1 matrix with zero gradient
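    # (At corank 2 the adjugate, and hence the gradient of det, is identically
    # zero; the check mainly guards against NaNs.)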
    b = jnp.array([[ 36, -42, 18],
                   [-42, 49, -21],
                   [ 18, -21, 9]], dtype=jnp.float32)
    jtu.check_grads(jnp.linalg.det, (b,), 1, atol=1e-1, rtol=1e-1, eps=1e-1)

  @jtu.sample_product(
    m=[1, 5, 7, 23],
    nq=zip([2, 4, 6, 36], [(1, 2), (2, 2), (1, 2, 3), (3, 3, 1, 4)]),
    dtype=float_types,
  )
  def testTensorsolve(self, m, nq, dtype):
    rng = jtu.rand_default(self.rng())

    # According to the numpy docs, the coefficient tensor a has shape
    # b.shape + Q with prod(Q) == prod(b.shape), i.e. n == prod(q).
    n, q = nq
    b_shape = (n, m)
    # To satisfy prod(Q) == prod(b.shape) we append the extra dimension m to Q.
    Q = q + (m,)
    args_maker = lambda: [
        rng(b_shape + Q, dtype),  # = a
        rng(b_shape, dtype)]      # = b
    a, b = args_maker()
    result = jnp.linalg.tensorsolve(*args_maker())
    self.assertEqual(result.shape, Q)

    self._CheckAgainstNumpy(np.linalg.tensorsolve,
                            jnp.linalg.tensorsolve, args_maker,
                            tol={np.float32: 1e-2, np.float64: 1e-3})
    self._CompileAndCheck(jnp.linalg.tensorsolve,
                          args_maker,
                          rtol={np.float64: 1e-13})

  @jtu.sample_product(
    [dict(dtype=dtype, method=method)
     for dtype in float_types + complex_types
     for method in (["lu"] if jnp.issubdtype(dtype, jnp.complexfloating)
                    else ["lu", "qr"])
    ],
    shape=[(0, 0), (1, 1), (3, 3), (4, 4), (10, 10), (200, 200), (2, 2, 2),
           (2, 3, 3), (3, 2, 2)],
  )
  def testSlogdet(self, shape, dtype, method):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    slogdet = partial(jnp.linalg.slogdet, method=method)
    self._CheckAgainstNumpy(np.linalg.slogdet, slogdet, args_maker,
                            tol=1e-3)
    self._CompileAndCheck(slogdet, args_maker)

  @jtu.sample_product(
    shape=[(1, 1), (4, 4), (5, 5), (2, 7, 7)],
    dtype=float_types + complex_types,
  )
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def testSlogdetGrad(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    a = rng(shape, dtype)
    jtu.check_grads(jnp.linalg.slogdet, (a,), 2, atol=1e-1, rtol=2e-1)

  def testIssue1213(self):
    for n in range(5):
      mat = jnp.array([np.diag(np.ones([5], dtype=np.float32)) * (-.01)] * 2)
      args_maker = lambda: [mat]
      self._CheckAgainstNumpy(np.linalg.slogdet, jnp.linalg.slogdet, args_maker,
                              tol=1e-3)

  @jtu.sample_product(
    shape=[(0, 0), (4, 4), (5, 5), (50, 50), (2, 6, 6)],
    dtype=float_types + complex_types,
    compute_left_eigenvectors=[False, True],
    compute_right_eigenvectors=[False, True],
  )
  # TODO(phawkins): enable when there is an eigendecomposition implementation
  # for GPU/TPU.
  @jtu.run_on_devices("cpu")
  def testEig(self, shape, dtype, compute_left_eigenvectors,
              compute_right_eigenvectors):
    rng = jtu.rand_default(self.rng())
    n = shape[-1]
    args_maker = lambda: [rng(shape, dtype)]

    # Norm, adjusted for dimension and type.
    def norm(x):
      norm = np.linalg.norm(x, axis=(-2, -1))
      return norm / ((n + 1) * jnp.finfo(dtype).eps)

    def check_right_eigenvectors(a, w, vr):
      self.assertTrue(
        np.all(norm(np.matmul(a, vr) - w[..., None, :] * vr) < 100))
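
    # A left eigenvector of a matrix is a right eigenvector of its conjugate
    # transpose with the conjugated eigenvalue, so we reuse the check above.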
    def check_left_eigenvectors(a, w, vl):
      rank = len(a.shape)
      aH = jnp.conj(a.transpose(list(range(rank - 2)) + [rank - 1, rank - 2]))
      wC = jnp.conj(w)
      check_right_eigenvectors(aH, wC, vl)

    a, = args_maker()
    results = lax.linalg.eig(
        a, compute_left_eigenvectors=compute_left_eigenvectors,
        compute_right_eigenvectors=compute_right_eigenvectors)
    w = results[0]

    if compute_left_eigenvectors:
      check_left_eigenvectors(a, w, results[1])
    if compute_right_eigenvectors:
      check_right_eigenvectors(a, w, results[1 + compute_left_eigenvectors])

    self._CompileAndCheck(partial(jnp.linalg.eig), args_maker,
                          rtol=1e-3)

  @jtu.sample_product(
    shape=[(4, 4), (5, 5), (8, 8), (7, 6, 6)],
    dtype=float_types + complex_types,
  )
  # TODO(phawkins): enable when there is an eigendecomposition implementation
  # for GPU/TPU.
  @jtu.run_on_devices("cpu")
  def testEigvalsGrad(self, shape, dtype):
    # This test sometimes fails for large matrices. I (@j-towns) suspect, but
    # haven't checked, that it might be because perturbations cause the
    # ordering of eigenvalues to change, which will trip up check_grads. So we
    # just test on small-ish matrices.
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    a, = args_maker()
    tol = 1e-4 if dtype in (np.float64, np.complex128) else 1e-1
    jtu.check_grads(lambda x: jnp.linalg.eigvals(x), (a,), order=1,
                    modes=['fwd', 'rev'], rtol=tol, atol=tol)

  @jtu.sample_product(
    shape=[(4, 4), (5, 5), (50, 50)],
    dtype=float_types + complex_types,
  )
  # TODO: enable when there is an eigendecomposition implementation
  # for GPU/TPU.
  @jtu.run_on_devices("cpu")
  def testEigvals(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    a, = args_maker()
    w1, _ = jnp.linalg.eig(a)
    w2 = jnp.linalg.eigvals(a)
    self.assertAllClose(w1, w2, rtol={np.complex64: 1e-5, np.complex128: 1e-14})

  @jtu.run_on_devices("cpu")
  def testEigvalsInf(self):
    # https://github.com/google/jax/issues/2661
    x = jnp.array([[jnp.inf]])
    self.assertTrue(jnp.all(jnp.isnan(jnp.linalg.eigvals(x))))

  @jtu.sample_product(
    shape=[(1, 1), (4, 4), (5, 5)],
    dtype=float_types + complex_types,
  )
  @jtu.run_on_devices("cpu")
  def testEigBatching(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    shape = (10,) + shape
    args = rng(shape, dtype)
    ws, vs = vmap(jnp.linalg.eig)(args)
    self.assertTrue(np.all(np.linalg.norm(
        np.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3))

  @jtu.sample_product(
    n=[0, 4, 5, 50, 512],
    dtype=float_types + complex_types,
    lower=[True, False],
  )
  def testEigh(self, n, dtype, lower):
    rng = jtu.rand_default(self.rng())
    tol = 0.5 * np.maximum(n, 80) * np.finfo(dtype).eps
    args_maker = lambda: [rng((n, n), dtype)]

    uplo = "L" if lower else "U"

    a, = args_maker()
    a = (a + np.conj(a.T)) / 2
    w, v = jnp.linalg.eigh(np.tril(a) if lower else np.triu(a),
                           UPLO=uplo, symmetrize_input=False)
    w = w.astype(v.dtype)
    self.assertLessEqual(
        np.linalg.norm(np.eye(n) - np.matmul(np.conj(T(v)), v)), 4 * tol
    )
    with jax.numpy_rank_promotion('allow'):
      self.assertLessEqual(np.linalg.norm(np.matmul(a, v) - w * v),
                           tol * np.linalg.norm(a))

    self._CompileAndCheck(
        partial(jnp.linalg.eigh, UPLO=uplo), args_maker, rtol=tol
    )

  @jtu.sample_product(
    start=[0, 1, 63, 64, 65, 255],
    end=[1, 63, 64, 65, 256],
  )
  @jtu.run_on_devices("tpu")  # TODO(rmlarsen): enable on other devices
  def testEighSubsetByIndex(self, start, end):
    if start >= end:
      return
    dtype = np.float32
    n = 256
    rng = jtu.rand_default(self.rng())
    tol = np.maximum(n, 80) * np.finfo(dtype).eps
    args_maker = lambda: [rng((n, n), dtype)]
    subset_by_index = (start, end)
    k = end - start
    (a,) = args_maker()
    a = (a + np.conj(a.T)) / 2
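
    # Note: lax.linalg.eigh returns (eigenvectors, eigenvalues), the reverse of
    # the (eigenvalues, eigenvectors) order used by jnp.linalg.eigh.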
    v, w = lax.linalg.eigh(
        a, symmetrize_input=False, subset_by_index=subset_by_index
    )
    w = w.astype(v.dtype)

    self.assertEqual(v.shape, (n, k))
    self.assertEqual(w.shape, (k,))
    self.assertLessEqual(
        np.linalg.norm(np.eye(k) - np.matmul(np.conj(T(v)), v)), 3 * tol
    )
    with jax.numpy_rank_promotion("allow"):
      self.assertLessEqual(
          np.linalg.norm(np.matmul(a, v) - w * v), tol * np.linalg.norm(a)
      )

    self._CompileAndCheck(partial(jnp.linalg.eigh), args_maker, rtol=tol)

    # Compare eigenvalues against Numpy. We do not compare eigenvectors because
    # they are not uniquely defined, but the two checks above guarantee that
    # they satisfy the conditions for being eigenvectors.
    w_np = np.linalg.eigvalsh(a)[subset_by_index[0] : subset_by_index[1]]
    self.assertAllClose(w_np, w, atol=tol, rtol=tol)

  def testEighZeroDiagonal(self):
    a = np.array([[0., -1., -1., 1.],
                  [-1., 0., 1., -1.],
                  [-1., 1., 0., -1.],
                  [1., -1., -1., 0.]], dtype=np.float32)
    w, v = jnp.linalg.eigh(a)
    w = w.astype(v.dtype)
    eps = jnp.finfo(a.dtype).eps
    with jax.numpy_rank_promotion('allow'):
      self.assertLessEqual(
          np.linalg.norm(np.matmul(a, v) - w * v), 2 * eps * np.linalg.norm(a)
      )

  def testEighTinyNorm(self):
    rng = jtu.rand_default(self.rng())
    a = rng((300, 300), dtype=np.float32)
    eps = jnp.finfo(a.dtype).eps
    a = eps * (a + np.conj(a.T))
    w, v = jnp.linalg.eigh(a)
    w = w.astype(v.dtype)
    with jax.numpy_rank_promotion("allow"):
      self.assertLessEqual(
          np.linalg.norm(np.matmul(a, v) - w * v), 20 * eps * np.linalg.norm(a)
      )

  @jtu.sample_product(
    rank=[1, 3, 299],
  )
  def testEighRankDeficient(self, rank):
    rng = jtu.rand_default(self.rng())
    eps = jnp.finfo(np.float32).eps
    a = rng((300, rank), dtype=np.float32)
    a = a @ np.conj(a.T)
    w, v = jnp.linalg.eigh(a)
    w = w.astype(v.dtype)
    with jax.numpy_rank_promotion("allow"):
      self.assertLessEqual(
          np.linalg.norm(np.matmul(a, v) - w * v),
          81 * eps * np.linalg.norm(a),
      )

  @jtu.sample_product(
    n=[0, 4, 5, 50, 512],
    dtype=float_types + complex_types,
    lower=[True, False],
  )
  def testEighIdentity(self, n, dtype, lower):
    tol = np.finfo(dtype).eps
    uplo = "L" if lower else "U"

    a = jnp.eye(n, dtype=dtype)
    w, v = jnp.linalg.eigh(a, UPLO=uplo, symmetrize_input=False)
    w = w.astype(v.dtype)
    self.assertLessEqual(
        np.linalg.norm(np.eye(n) - np.matmul(np.conj(T(v)), v)), tol
    )
    with jax.numpy_rank_promotion('allow'):
      self.assertLessEqual(np.linalg.norm(np.matmul(a, v) - w * v),
                           tol * np.linalg.norm(a))

  @jtu.sample_product(
    shape=[(4, 4), (5, 5), (50, 50)],
    dtype=float_types + complex_types,
  )
  def testEigvalsh(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    n = shape[-1]
    def args_maker():
      a = rng((n, n), dtype)
      a = (a + np.conj(a.T)) / 2
      return [a]
    self._CheckAgainstNumpy(
        np.linalg.eigvalsh, jnp.linalg.eigvalsh, args_maker, tol=2e-5
    )

  @jtu.sample_product(
    shape=[(1, 1), (4, 4), (5, 5), (50, 50), (2, 10, 10)],
    dtype=float_types + complex_types,
    lower=[True, False],
  )
  def testEighGrad(self, shape, dtype, lower):
    rng = jtu.rand_default(self.rng())
    self.skipTest("Test fails with numeric errors.")
    uplo = "L" if lower else "U"
    a = rng(shape, dtype)
    a = (a + np.conj(T(a))) / 2
    ones = np.ones((a.shape[-1], a.shape[-1]), dtype=dtype)
    a *= np.tril(ones) if lower else np.triu(ones)
    # Gradient checks will fail without symmetrization as the eigh jvp rule
    # is only correct for tangents in the symmetric subspace, whereas the
    # checker checks against unconstrained (co)tangents.
    if dtype not in complex_types:
      f = partial(jnp.linalg.eigh, UPLO=uplo, symmetrize_input=True)
    else:  # only check eigenvalue grads for complex matrices
      f = lambda a: partial(jnp.linalg.eigh, UPLO=uplo, symmetrize_input=True)(a)[0]
    jtu.check_grads(f, (a,), 2, rtol=1e-5)

  @jtu.sample_product(
    shape=[(1, 1), (4, 4), (5, 5), (50, 50)],
    dtype=complex_types,
    lower=[True, False],
    eps=[1e-5],
  )
  def testEighGradVectorComplex(self, shape, dtype, lower, eps):
    rng = jtu.rand_default(self.rng())
    # Special case to test for complex eigenvector grad correctness. Exact
    # eigenvector coordinate gradients are hard to test numerically for complex
    # eigensystem solvers given the extra degrees of per-eigenvector phase
    # freedom. Instead, we numerically verify the eigensystem properties on the
    # perturbed eigenvectors; you only ever want to optimize eigenvector
    # directions, not coordinates!
    uplo = "L" if lower else "U"
    a = rng(shape, dtype)
    a = (a + np.conj(a.T)) / 2
    a = np.tril(a) if lower else np.triu(a)
    a_dot = eps * rng(shape, dtype)
    a_dot = (a_dot + np.conj(a_dot.T)) / 2
    a_dot = np.tril(a_dot) if lower else np.triu(a_dot)
    # Evaluate the eigenvector gradient and the ground-truth eigensystem for
    # the perturbed input matrix.
    f = partial(jnp.linalg.eigh, UPLO=uplo)
    (w, v), (dw, dv) = jvp(f, primals=(a,), tangents=(a_dot,))
    self.assertTrue(jnp.issubdtype(w.dtype, jnp.floating))
    self.assertTrue(jnp.issubdtype(dw.dtype, jnp.floating))
    new_a = a + a_dot
    new_w, new_v = f(new_a)
    new_a = (new_a + np.conj(new_a.T)) / 2
    new_w = new_w.astype(new_a.dtype)
    # Assert rtol eigenvalue delta between perturbed eigenvectors vs. new true
    # eigenvalues.
    RTOL = 1e-2
    with jax.numpy_rank_promotion('allow'):
      assert np.max(
        np.abs((np.diag(np.dot(np.conj((v+dv).T), np.dot(new_a, (v+dv)))) - new_w) / new_w)) < RTOL
      # Redundant with the above, but also assert rtol for the eigenvector
      # property with the new true eigenvalues.
      assert np.max(
        np.linalg.norm(np.abs(new_w*(v+dv) - np.dot(new_a, (v+dv))), axis=0) /
        np.linalg.norm(np.abs(new_w*(v+dv)), axis=0)
      ) < RTOL

  def testEighGradPrecision(self):
    rng = jtu.rand_default(self.rng())
    a = rng((3, 3), np.float32)
    jtu.assert_dot_precision(
        lax.Precision.HIGHEST, partial(jvp, jnp.linalg.eigh), (a,), (a,))

  @jtu.sample_product(
    shape=[(1, 1), (4, 4), (5, 5), (300, 300)],
    dtype=float_types + complex_types,
  )
  def testEighBatching(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    shape = (10,) + shape
    args = rng(shape, dtype)
    args = (args + np.conj(T(args))) / 2
    ws, vs = vmap(jsp.linalg.eigh)(args)
    ws = ws.astype(vs.dtype)
    norm = np.max(np.linalg.norm(np.matmul(args, vs) - ws[..., None, :] * vs))
    self.assertLess(norm, 1e-2)

  @jtu.sample_product(
    shape=[(1,), (4,), (5,)],
    dtype=(np.int32,),
  )
  def testLuPivotsToPermutation(self, shape, dtype):
    pivots_size = shape[-1]
    permutation_size = 2 * pivots_size
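
    # LAPACK-style pivots: row i is swapped with row pivots[i]. The pivots
    # below swap row i with row (permutation_size - 1 - i), so the resulting
    # permutation is the full reversal checked against `expected`.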
    pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, dtype=dtype)
    pivots = jnp.broadcast_to(pivots, shape)
    actual = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size)
    expected = jnp.arange(permutation_size - 1, -1, -1, dtype=dtype)
    expected = jnp.broadcast_to(expected, actual.shape)
    self.assertArraysEqual(actual, expected)

  @jtu.sample_product(
    shape=[(1,), (4,), (5,)],
    dtype=(np.int32,),
  )
  def testLuPivotsToPermutationBatching(self, shape, dtype):
    shape = (10,) + shape
    pivots_size = shape[-1]
    permutation_size = 2 * pivots_size

    pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, dtype=dtype)
    pivots = jnp.broadcast_to(pivots, shape)
    batched_fn = vmap(
        lambda x: lax.linalg.lu_pivots_to_permutation(x, permutation_size))
    actual = batched_fn(pivots)
    expected = jnp.arange(permutation_size - 1, -1, -1, dtype=dtype)
    expected = jnp.broadcast_to(expected, actual.shape)
    self.assertArraysEqual(actual, expected)

  @jtu.sample_product(
    [dict(axis=axis, shape=shape, ord=ord)
     for axis, shape in [
       (None, (1,)), (None, (7,)), (None, (5, 8)),
       (0, (9,)), (0, (4, 5)), ((1,), (10, 7, 3)), ((-2,), (4, 8)),
       (-1, (6, 3)), ((0, 2), (3, 4, 5)), ((2, 0), (7, 8, 9)),
       (None, (7, 8, 11))]
     for ord in (
       [None] if axis is None and len(shape) > 2
       else [None, 0, 1, 2, 3, -1, -2, -3, jnp.inf, -jnp.inf]
       if (axis is None and len(shape) == 1) or
          isinstance(axis, int) or
          (isinstance(axis, tuple) and len(axis) == 1)
       else [None, 'fro', 1, 2, -1, -2, jnp.inf, -jnp.inf, 'nuc'])
    ],
    keepdims=[False, True],
    dtype=float_types + complex_types,
  )
  def testNorm(self, shape, dtype, ord, axis, keepdims):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    np_fn = partial(np.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
    jnp_fn = partial(jnp.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(jnp_fn, args_maker)

  def testStringInfNorm(self):
    err, msg = ValueError, r"Invalid order 'inf' for vector norm."
    with self.assertRaisesRegex(err, msg):
      jnp.linalg.norm(jnp.array([1.0, 2.0, 3.0]), ord="inf")

  @jtu.sample_product(
    [
      dict(m=m, n=n, full_matrices=full_matrices, hermitian=hermitian)
      for (m, n), full_matrices in (
          list(
              itertools.product(
                  itertools.product([0, 2, 7, 29, 32, 53], repeat=2),
                  [False, True],
              )
          )
          # Test cases that ensure we are economical when computing the SVD
          # and its gradient. If we form a 400k x 400k matrix explicitly we
          # will OOM.
          + [((400000, 2), False), ((2, 400000), False)]
      )
      for hermitian in ([False, True] if m == n else [False])
    ],
    b=[(), (3,), (2, 3)],
    dtype=float_types + complex_types,
    compute_uv=[False, True],
  )
  @jax.default_matmul_precision("float32")
  def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, hermitian):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(b + (m, n), dtype)]
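
    # Relative backward error ||A - A_reconstructed|| / ||A||, maximized over
    # the batch.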
    def compute_max_backward_error(operand, reconstructed_operand):
      error_norm = np.linalg.norm(operand - reconstructed_operand,
                                  axis=(-2, -1))
      backward_error = (error_norm /
                        np.linalg.norm(operand, axis=(-2, -1)))
      max_backward_error = np.amax(backward_error)
      return max_backward_error

    tol = 80 * jnp.finfo(dtype).eps
    reconstruction_tol = 2 * tol
    unitariness_tol = tol

    a, = args_maker()
    if hermitian:
      a = a + np.conj(T(a))
    out = jnp.linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv,
                         hermitian=hermitian)
    if compute_uv:
      # Check the reconstructed matrices.
      out = list(out)
      out[1] = out[1].astype(out[0].dtype)  # for strict dtype promotion.
      if m and n:
        if full_matrices:
          k = min(m, n)
          if m < n:
            max_backward_error = compute_max_backward_error(
                a, np.matmul(out[1][..., None, :] * out[0], out[2][..., :k, :]))
            self.assertLess(max_backward_error, reconstruction_tol)
          else:
            max_backward_error = compute_max_backward_error(
                a, np.matmul(out[1][..., None, :] * out[0][..., :, :k], out[2]))
            self.assertLess(max_backward_error, reconstruction_tol)
        else:
          max_backward_error = compute_max_backward_error(
              a, np.matmul(out[1][..., None, :] * out[0], out[2]))
          self.assertLess(max_backward_error, reconstruction_tol)

      # Check the unitary properties of the singular vector matrices.
      unitary_mat = np.real(np.matmul(np.conj(T(out[0])), out[0]))
      eye_slice = np.eye(out[0].shape[-1], dtype=unitary_mat.dtype)
      self.assertAllClose(np.broadcast_to(eye_slice, b + eye_slice.shape),
                          unitary_mat, rtol=unitariness_tol,
                          atol=unitariness_tol)
      if m >= n:
        unitary_mat = np.real(np.matmul(np.conj(T(out[2])), out[2]))
        eye_slice = np.eye(out[2].shape[-1], dtype=unitary_mat.dtype)
        self.assertAllClose(np.broadcast_to(eye_slice, b + eye_slice.shape),
                            unitary_mat, rtol=unitariness_tol,
                            atol=unitariness_tol)
      else:
        unitary_mat = np.real(np.matmul(out[2], np.conj(T(out[2]))))
        eye_slice = np.eye(out[2].shape[-2], dtype=unitary_mat.dtype)
        self.assertAllClose(np.broadcast_to(eye_slice, b + eye_slice.shape),
                            unitary_mat, rtol=unitariness_tol,
                            atol=unitariness_tol)
    else:
      self.assertTrue(np.allclose(np.linalg.svd(a, compute_uv=False),
                                  np.asarray(out), atol=1e-4, rtol=1e-4))

    self._CompileAndCheck(partial(jnp.linalg.svd, full_matrices=full_matrices,
                                  compute_uv=compute_uv),
                          args_maker)

    if not compute_uv and a.size < 100000:
      svd = partial(jnp.linalg.svd, full_matrices=full_matrices,
                    compute_uv=compute_uv)
      # TODO(phawkins): these tolerances seem very loose.
      if dtype == np.complex128:
        jtu.check_jvp(svd, partial(jvp, svd), (a,), rtol=1e-4, atol=1e-4,
                      eps=1e-8)
      else:
        jtu.check_jvp(svd, partial(jvp, svd), (a,), rtol=5e-2, atol=2e-1)

    if compute_uv and (not full_matrices):
      b, = args_maker()
      def f(x):
        u, s, v = jnp.linalg.svd(
          a + x * b,
          full_matrices=full_matrices,
          compute_uv=compute_uv)
        vdiag = jnp.vectorize(jnp.diag, signature='(k)->(k,k)')
        return jnp.matmul(jnp.matmul(u, vdiag(s).astype(u.dtype)), v).real
      _, t_out = jvp(f, (1.,), (1.,))
      if dtype == np.complex128:
        atol = 2e-13
      else:
        atol = 5e-4
      self.assertArraysAllClose(t_out, b.real, atol=atol)

  def testJspSVDBasic(self):
    # Since jax.scipy.linalg.svd is almost the same as jax.numpy.linalg.svd,
    # we do not check its functionality here.
    jsp.linalg.svd(np.ones((2, 2), dtype=np.float32))

  @jtu.sample_product(
    shape=[(0, 2), (2, 0), (3, 4), (3, 3), (4, 3)],
    dtype=[np.float32],
    mode=["reduced", "r", "full", "complete", "raw"],
  )
  def testNumpyQrModes(self, shape, dtype, mode):
    rng = jtu.rand_default(self.rng())
    jnp_func = partial(jax.numpy.linalg.qr, mode=mode)
    np_func = partial(np.linalg.qr, mode=mode)
    if mode == "full":
      np_func = jtu.ignore_warning(category=DeprecationWarning,
                                   message="The 'full' option.*")(np_func)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(np_func, jnp_func, args_maker, rtol=1e-5, atol=1e-5,
                            check_dtypes=(mode != "raw"))
    self._CompileAndCheck(jnp_func, args_maker)

  @jtu.sample_product(
    shape=[(0, 0), (2, 0), (0, 2), (3, 3), (3, 4), (2, 10, 5),
           (2, 200, 100), (64, 16, 5), (33, 7, 3), (137, 9, 5), (20000, 2, 2)],
    dtype=float_types + complex_types,
    full_matrices=[False, True],
  )
  @jax.default_matmul_precision("float32")
  def testQr(self, shape, dtype, full_matrices):
    if (jtu.test_device_matches(["cuda"]) and
        _is_required_cuda_version_satisfied(12000)):
      self.skipTest("Triggers a bug in cuda-12 b/287345077")
    rng = jtu.rand_default(self.rng())
    m, n = shape[-2:]

    if full_matrices:
      mode, k = "complete", m
    else:
      mode, k = "reduced", min(m, n)

    a = rng(shape, dtype)
    lq, lr = jnp.linalg.qr(a, mode=mode)

    # np.linalg.qr doesn't support batch dimensions. But it seems like an
    # inevitable extension so we support it in our version.
    nq = np.zeros(shape[:-2] + (m, k), dtype)
    nr = np.zeros(shape[:-2] + (k, n), dtype)
    for index in np.ndindex(*shape[:-2]):
      nq[index], nr[index] = np.linalg.qr(a[index], mode=mode)

    max_rank = max(m, n)

    # Norm, adjusted for dimension and type.
    def norm(x):
      n = np.linalg.norm(x, axis=(-2, -1))
      return n / (max(1, max_rank) * jnp.finfo(dtype).eps)

    def compare_orthogonal(q1, q2):
      # Q is unique up to sign, so normalize the sign first.
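      # For complex dtypes the ambiguity is a unit-modulus phase per column;
      # estimate that phase from the ratio of the two Q factors and remove it.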
      ratio = np.divide(np.where(q2 == 0, 0, q1), np.where(q2 == 0, 1, q2))
      sum_of_ratios = ratio.sum(axis=-2, keepdims=True)
      phases = np.divide(sum_of_ratios, np.abs(sum_of_ratios))
      q1 *= phases
      nm = norm(q1 - q2)
      self.assertTrue(np.all(nm < 160), msg=f"norm={np.amax(nm)}")

    # Check a ~= qr
    norm_error = norm(a - np.matmul(lq, lr))
    self.assertTrue(np.all(norm_error < 60), msg=np.amax(norm_error))

    # Compare the first 'k' vectors of Q; the remainder form an arbitrary
    # orthonormal basis for the null space.
    compare_orthogonal(nq[..., :k], lq[..., :k])

    # Check that q is close to unitary.
    self.assertTrue(np.all(
        norm(np.eye(k) - np.matmul(np.conj(T(lq)), lq)) < 10))

    # This expresses the identity function, which makes us robust to, e.g., the
    # tangents flipping the direction of vectors in Q.
    def qr_and_mul(a):
      q, r = jnp.linalg.qr(a, mode=mode)
      return q @ r

    if m == n or (m > n and not full_matrices):
      jtu.check_jvp(qr_and_mul, partial(jvp, qr_and_mul), (a,), atol=3e-3)

  @jtu.skip_on_devices("tpu")
  def testQrInvalidDtypeCPU(self, shape=(5, 6), dtype=np.float16):
    # Regression test for https://github.com/google/jax/issues/10530
    rng = jtu.rand_default(self.rng())
    arr = rng(shape, dtype)
    if jtu.test_device_matches(['cpu']):
      err, msg = NotImplementedError, "Unsupported dtype float16"
    else:
      err, msg = ValueError, r"Unsupported dtype dtype\('float16'\)"
    with self.assertRaisesRegex(err, msg):
      jnp.linalg.qr(arr)

  @jtu.sample_product(
    shape=[(10, 4, 5), (5, 3, 3), (7, 6, 4)],
    dtype=float_types + complex_types,
  )
  def testQrBatching(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    args = rng(shape, jnp.float32)
    qs, rs = vmap(jsp.linalg.qr)(args)
    self.assertTrue(np.all(np.linalg.norm(args - np.matmul(qs, rs)) < 1e-3))

  @jtu.sample_product(
    shape=[(1, 1), (4, 4), (2, 3, 5), (5, 5, 5), (20, 20), (5, 10)],
    pnorm=[jnp.inf, -jnp.inf, 1, -1, 2, -2, 'fro'],
    dtype=float_types + complex_types,
  )
  @jtu.skip_on_devices("gpu")  # TODO(#2203): numerical errors
  def testCond(self, shape, pnorm, dtype):
    def gen_mat():
      # arr_gen = jtu.rand_some_nan(self.rng())
      arr_gen = jtu.rand_default(self.rng())
      res = arr_gen(shape, dtype)
      return res

    def args_gen(p):
      def _args_gen():
        return [gen_mat(), p]
      return _args_gen

    args_maker = args_gen(pnorm)
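    # Norms other than +/-2 require inverting the matrix, which is only defined
    # for square inputs, so non-square shapes are expected to raise.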
if pnorm not in [2, -2] and len(set(shape[-2:])) != 1:
|
2020-05-05 16:40:41 -04:00
|
|
|
with self.assertRaises(np.linalg.LinAlgError):
|
|
|
|
jnp.linalg.cond(*args_maker())
|
2020-02-08 07:20:04 +11:00
|
|
|
else:
|
2020-05-05 16:40:41 -04:00
|
|
|
self._CheckAgainstNumpy(np.linalg.cond, jnp.linalg.cond, args_maker,
|
2020-02-08 07:20:04 +11:00
|
|
|
check_dtypes=False, tol=1e-3)
|
2020-05-05 16:40:41 -04:00
|
|
|
partial_norm = partial(jnp.linalg.cond, p=pnorm)
|
2020-02-08 07:20:04 +11:00
|
|
|
self._CompileAndCheck(partial_norm, lambda: [gen_mat()],
|
|
|
|
check_dtypes=False, rtol=1e-03, atol=1e-03)
|
|
|
|
|
2022-10-11 15:59:44 +00:00
|
|
|
@jtu.sample_product(
|
|
|
|
shape=[(1, 1), (4, 4), (200, 200), (7, 7, 7, 7)],
|
|
|
|
dtype=float_types,
|
|
|
|
)
|
2020-12-04 09:44:50 -08:00
|
|
|
def testTensorinv(self, shape, dtype):
|
|
|
|
rng = jtu.rand_default(self.rng())
|
2020-02-08 07:20:04 +11:00
|
|
|
|
|
|
|
def tensor_maker():
|
|
|
|
invertible = False
|
|
|
|
while not invertible:
|
|
|
|
a = rng(shape, dtype)
|
|
|
|
try:
|
2020-05-05 16:40:41 -04:00
|
|
|
np.linalg.inv(a)
|
2020-02-08 07:20:04 +11:00
|
|
|
invertible = True
|
2020-05-05 16:40:41 -04:00
|
|
|
except np.linalg.LinAlgError:
|
2020-02-08 07:20:04 +11:00
|
|
|
pass
|
|
|
|
return a
|
|
|
|
|
2020-05-05 16:40:41 -04:00
|
|
|
args_maker = lambda: [tensor_maker(), int(np.floor(len(shape) / 2))]
|
|
|
|
self._CheckAgainstNumpy(np.linalg.tensorinv, jnp.linalg.tensorinv, args_maker,
|
2020-02-08 07:20:04 +11:00
|
|
|
check_dtypes=False, tol=1e-3)
|
2020-05-05 16:40:41 -04:00
|
|
|
partial_inv = partial(jnp.linalg.tensorinv, ind=int(np.floor(len(shape) / 2)))
|
2020-02-08 07:20:04 +11:00
|
|
|
self._CompileAndCheck(partial_inv, lambda: [tensor_maker()], check_dtypes=False, rtol=1e-03, atol=1e-03)
|
|
|
|
|
2022-10-11 15:59:44 +00:00
|
|
|
@jtu.sample_product(
|
|
|
|
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
|
|
|
|
for lhs_shape, rhs_shape in [
|
|
|
|
((1, 1), (1, 1)),
|
|
|
|
((4, 4), (4,)),
|
|
|
|
((8, 8), (8, 4)),
|
|
|
|
((1, 2, 2), (3, 2)),
|
|
|
|
((2, 1, 3, 3), (1, 4, 3, 4)),
|
|
|
|
((1, 0, 0), (1, 0, 2)),
|
|
|
|
]
|
|
|
|
],
|
|
|
|
dtype=float_types + complex_types,
|
|
|
|
)
|
2020-12-04 09:44:50 -08:00
|
|
|
def testSolve(self, lhs_shape, rhs_shape, dtype):
|
|
|
|
rng = jtu.rand_default(self.rng())
|
2018-12-21 16:29:45 -05:00
|
|
|
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
|
|
|
|
|
2020-05-05 16:40:41 -04:00
|
|
|
self._CheckAgainstNumpy(np.linalg.solve, jnp.linalg.solve, args_maker,
|
2020-06-01 17:19:23 -04:00
|
|
|
tol=1e-3)
|
|
|
|
self._CompileAndCheck(jnp.linalg.solve, args_maker)
|
2018-12-21 16:29:45 -05:00
|
|
|
|
2022-10-11 15:59:44 +00:00
|
|
|
@jtu.sample_product(
|
|
|
|
shape=[(1, 1), (4, 4), (2, 5, 5), (100, 100), (5, 5, 5), (0, 0)],
|
|
|
|
dtype=float_types,
|
|
|
|
)
|
2020-12-04 09:44:50 -08:00
|
|
|
def testInv(self, shape, dtype):
|
|
|
|
rng = jtu.rand_default(self.rng())
|
2019-06-27 16:36:54 -04:00
|
|
|
|
2018-12-13 19:28:05 -05:00
|
|
|
def args_maker():
|
2018-12-13 19:33:02 -05:00
|
|
|
invertible = False
|
|
|
|
while not invertible:
|
|
|
|
a = rng(shape, dtype)
|
|
|
|
try:
|
2020-05-05 16:40:41 -04:00
|
|
|
np.linalg.inv(a)
|
2018-12-13 19:33:02 -05:00
|
|
|
invertible = True
|
2020-05-05 16:40:41 -04:00
|
|
|
except np.linalg.LinAlgError:
|
2018-12-13 19:33:02 -05:00
|
|
|
pass
|
2018-12-13 19:28:05 -05:00
|
|
|
return [a]
|
|
|
|
|
2020-05-05 16:40:41 -04:00
|
|
|
self._CheckAgainstNumpy(np.linalg.inv, jnp.linalg.inv, args_maker,
|
2020-06-01 17:19:23 -04:00
|
|
|
tol=1e-3)
|
|
|
|
self._CompileAndCheck(jnp.linalg.inv, args_maker)
|
2018-12-13 19:28:05 -05:00
|
|
|
|
2022-10-11 15:59:44 +00:00
|
|
|
@jtu.sample_product(
|
2022-11-07 11:39:19 -05:00
|
|
|
[dict(shape=shape, hermitian=hermitian)
|
|
|
|
for shape in [(1, 1), (4, 4), (3, 10, 10), (2, 70, 7), (2000, 7),
|
2022-11-11 11:44:22 -08:00
|
|
|
(7, 1000), (70, 7, 2), (2, 0, 0), (3, 0, 2), (1, 0),
|
|
|
|
(400000, 2), (2, 400000)]
|
2022-11-07 11:39:19 -05:00
|
|
|
for hermitian in ([False, True] if shape[-1] == shape[-2] else [False])],
|
2022-10-11 15:59:44 +00:00
|
|
|
dtype=float_types + complex_types,
|
|
|
|
)
|
2022-11-07 11:39:19 -05:00
|
|
|
def testPinv(self, shape, hermitian, dtype):
|
2020-12-04 09:44:50 -08:00
|
|
|
rng = jtu.rand_default(self.rng())
|
2019-12-03 11:15:39 -08:00
|
|
|
args_maker = lambda: [rng(shape, dtype)]
|
|
|
|
|
2022-11-07 11:39:19 -05:00
|
|
|
jnp_fn = partial(jnp.linalg.pinv, hermitian=hermitian)
|
|
|
|
def np_fn(a):
|
|
|
|
# Symmetrize the input matrix to match the jnp behavior.
|
|
|
|
if hermitian:
|
|
|
|
a = (a + T(a.conj())) / 2
|
|
|
|
return np.linalg.pinv(a, hermitian=hermitian)
|
|
|
|
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-4)
|
|
|
|
self._CompileAndCheck(jnp_fn, args_maker)
|
2022-10-18 18:09:02 -07:00
|
|
|
|
2022-11-09 18:57:28 -08:00
|
|
|
# TODO(phawkins): 6e-2 seems like a very loose tolerance.
|
|
|
|
jtu.check_grads(jnp_fn, args_maker(), 1, rtol=6e-2, atol=1e-3)
|
2020-04-22 20:15:04 -04:00
|
|
|
|
|
|
|
def testPinvGradIssue2792(self):
|
|
|
|
def f(p):
|
2020-05-05 16:40:41 -04:00
|
|
|
a = jnp.array([[0., 0.],[-p, 1.]], jnp.float32) * 1 / (1 + p**2)
|
|
|
|
return jnp.linalg.pinv(a)
|
|
|
|
j = jax.jacobian(f)(jnp.float32(2.))
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(jnp.array([[0., -1.], [ 0., 0.]], jnp.float32), j)
|
2020-04-22 20:15:04 -04:00
|
|
|
|
2020-05-05 16:40:41 -04:00
|
|
|
expected = jnp.array([[[[-1., 0.], [ 0., 0.]], [[0., -1.], [0., 0.]]],
|
2020-04-22 20:15:04 -04:00
|
|
|
[[[0., 0.], [-1., 0.]], [[0., 0.], [0., -1.]]]],
|
2020-05-05 16:40:41 -04:00
|
|
|
dtype=jnp.float32)
|
2020-04-22 20:15:04 -04:00
|
|
|
self.assertAllClose(
|
2020-06-01 17:19:23 -04:00
|
|
|
expected, jax.jacobian(jnp.linalg.pinv)(jnp.eye(2, dtype=jnp.float32)))
|
2019-12-03 11:15:39 -08:00
|
|
|
|
2022-10-11 15:59:44 +00:00
|
|
|
@jtu.sample_product(
|
|
|
|
shape=[(1, 1), (2, 2), (4, 4), (5, 5), (1, 2, 2), (2, 3, 3), (2, 5, 5)],
|
|
|
|
dtype=float_types + complex_types,
|
|
|
|
n=[-5, -2, -1, 0, 1, 2, 3, 4, 5, 10],
|
|
|
|
)
|
|
|
|
@jax.default_matmul_precision("float32")
|
2020-12-04 09:44:50 -08:00
|
|
|
def testMatrixPower(self, shape, dtype, n):
|
|
|
|
rng = jtu.rand_default(self.rng())
|
2020-01-24 16:52:40 -05:00
|
|
|
args_maker = lambda: [rng(shape, dtype)]
|
2020-05-05 16:40:41 -04:00
|
|
|
self._CheckAgainstNumpy(partial(np.linalg.matrix_power, n=n),
|
|
|
|
partial(jnp.linalg.matrix_power, n=n),
|
2022-10-11 15:59:44 +00:00
|
|
|
args_maker, tol=1e-3)
|
2020-05-05 16:40:41 -04:00
|
|
|
self._CompileAndCheck(partial(jnp.linalg.matrix_power, n=n), args_maker,
|
2020-06-01 17:19:23 -04:00
|
|
|
rtol=1e-3)
|
2020-01-24 16:52:40 -05:00
|
|
|
|
2022-10-11 15:59:44 +00:00
|
|
|
@jtu.sample_product(
|
|
|
|
shape=[(3, ), (1, 2), (8, 5), (4, 4), (5, 5), (50, 50), (3, 4, 5),
|
|
|
|
(2, 3, 4, 5)],
|
|
|
|
dtype=float_types + complex_types,
|
|
|
|
)
|
2020-12-04 09:44:50 -08:00
|
|
|
def testMatrixRank(self, shape, dtype):
|
|
|
|
rng = jtu.rand_default(self.rng())
|
2020-01-26 14:29:33 -05:00
|
|
|
args_maker = lambda: [rng(shape, dtype)]
|
|
|
|
a, = args_maker()
|
2020-05-05 16:40:41 -04:00
|
|
|
self._CheckAgainstNumpy(np.linalg.matrix_rank, jnp.linalg.matrix_rank,
|
2020-01-26 14:29:33 -05:00
|
|
|
args_maker, check_dtypes=False, tol=1e-3)
|
2020-05-05 16:40:41 -04:00
|
|
|
self._CompileAndCheck(jnp.linalg.matrix_rank, args_maker,
|
2020-01-26 14:29:33 -05:00
|
|
|
check_dtypes=False, rtol=1e-3)
|
|
|
|
|
2022-10-11 15:59:44 +00:00
|
|
|

  @jtu.sample_product(
      shapes=[
        [(3,), (3, 1)],  # quick-out codepath
        [(1, 3), (3, 5), (5, 2)],  # multi_dot_three codepath
        [(1, 3), (3, 5), (5, 2), (2, 7), (7,)],  # dynamic programming codepath
      ],
      dtype=float_types + complex_types,
  )
  def testMultiDot(self, shapes, dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [[rng(shape, dtype) for shape in shapes]]

    np_fun = np.linalg.multi_dot
    jnp_fun = partial(jnp.linalg.multi_dot, precision=lax.Precision.HIGHEST)
    tol = {np.float32: 1e-4, np.float64: 1e-10,
           np.complex64: 1e-4, np.complex128: 1e-10}

    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol)
    self._CompileAndCheck(jnp_fun, args_maker,
                          atol=tol, rtol=tol)
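
  # Illustrative sketch, not part of the original test suite: multi_dot takes a
  # whole chain of matrices and picks a multiplication order for it (the
  # codepaths named in the test above), but the result matches explicit
  # chaining with `@`. The helper name and shapes below are illustrative
  # assumptions, not existing APIs beyond jnp.linalg.multi_dot itself.
  def _multi_dot_sketch(self):
    a, b, c = jnp.ones((1, 3)), jnp.ones((3, 5)), jnp.ones((5, 2))
    chained = a @ b @ c                       # left-to-right association
    multi = jnp.linalg.multi_dot([a, b, c])   # order chosen by multi_dot
    return chained, multi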

  @jtu.sample_product(
      [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
       for lhs_shape, rhs_shape in [
           ((1, 1), (1, 1)),
           ((4, 6), (4,)),
           ((6, 6), (6, 1)),
           ((8, 6), (8, 4)),
           ((0, 3), (0,)),
           ((3, 0), (3,)),
           ((3, 1), (3, 0)),
       ]
      ],
      rcond=[-1, None, 0.5],
      dtype=float_types + complex_types,
  )
  def testLstsq(self, lhs_shape, rhs_shape, dtype, rcond):
    rng = jtu.rand_default(self.rng())
    np_fun = partial(np.linalg.lstsq, rcond=rcond)
    jnp_fun = partial(jnp.linalg.lstsq, rcond=rcond)
    jnp_fun_numpy_resid = partial(jnp.linalg.lstsq, rcond=rcond, numpy_resid=True)
    tol = {np.float32: 1e-4, np.float64: 1e-12,
           np.complex64: 1e-5, np.complex128: 1e-12}
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

    self._CheckAgainstNumpy(np_fun, jnp_fun_numpy_resid, args_maker, check_dtypes=False, tol=tol)
    self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol)

    # Disabled because grad is flaky for low-rank inputs.
    # TODO:
    # jtu.check_grads(lambda *args: jnp_fun(*args)[0], args_maker(), order=2, atol=1e-2, rtol=1e-2)

  # Regression test for incorrect type for eigenvalues of a complex matrix.
  def testIssue669(self):
    def test(x):
      val, vec = jnp.linalg.eigh(x)
      return jnp.real(jnp.sum(val))

    grad_test_jc = jit(grad(jit(test)))
    xc = np.eye(3, dtype=np.complex64)
    self.assertAllClose(xc, grad_test_jc(xc))

  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def testIssue1151(self):
    rng = self.rng()
    A = jnp.array(rng.randn(100, 3, 3), dtype=jnp.float32)
    b = jnp.array(rng.randn(100, 3), dtype=jnp.float32)
    x = jnp.linalg.solve(A, b)
    self.assertAllClose(vmap(jnp.dot)(A, x), b, atol=2e-3, rtol=1e-2)

    _ = jax.jacobian(jnp.linalg.solve, argnums=0)(A, b)
    _ = jax.jacobian(jnp.linalg.solve, argnums=1)(A, b)

    _ = jax.jacobian(jnp.linalg.solve, argnums=0)(A[0], b[0])
    _ = jax.jacobian(jnp.linalg.solve, argnums=1)(A[0], b[0])

  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  @jax.legacy_prng_key("allow")
  def testIssue1383(self):
    seed = jax.random.PRNGKey(0)
    tmp = jax.random.uniform(seed, (2, 2))
    a = jnp.dot(tmp, tmp.T)

    def f(inp):
      val, vec = jnp.linalg.eigh(inp)
      return jnp.dot(jnp.dot(vec, inp), vec.T)

    grad_func = jax.jacfwd(f)
    hess_func = jax.jacfwd(grad_func)
    cube_func = jax.jacfwd(hess_func)
    self.assertFalse(np.any(np.isnan(cube_func(a))))


class ScipyLinalgTest(jtu.JaxTestCase):

  @jtu.sample_product(
      args=[
        (),
        (1,),
        (7, -2),
        (3, 4, 5),
        (np.ones((3, 4), dtype=float), 5,
         np.random.randn(5, 2).astype(float)),
      ]
  )
  def testBlockDiag(self, args):
    args_maker = lambda: args
    self._CheckAgainstNumpy(osp.linalg.block_diag, jsp.linalg.block_diag,
                            args_maker, check_dtypes=False)
    self._CompileAndCheck(jsp.linalg.block_diag, args_maker)

  @jtu.sample_product(
      shape=[(1, 1), (4, 5), (10, 5), (50, 50)],
      dtype=float_types + complex_types,
  )
  def testLu(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    x, = args_maker()
    p, l, u = jsp.linalg.lu(x)
    self.assertAllClose(x, np.matmul(p, np.matmul(l, u)),
                        rtol={np.float32: 1e-3, np.float64: 5e-12,
                              np.complex64: 1e-3, np.complex128: 1e-12},
                        atol={np.float32: 1e-5})
    self._CompileAndCheck(jsp.linalg.lu, args_maker)

  def testLuOfSingularMatrix(self):
    x = jnp.array([[-1., 3./2], [2./3, -1.]], dtype=np.float32)
    p, l, u = jsp.linalg.lu(x)
    self.assertAllClose(x, np.matmul(p, np.matmul(l, u)))

  @jtu.sample_product(
      shape=[(1, 1), (4, 5), (10, 5), (10, 10), (6, 7, 7)],
      dtype=float_types + complex_types,
  )
  def testLuGrad(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    a = rng(shape, dtype)
    lu = vmap(jsp.linalg.lu) if len(shape) > 2 else jsp.linalg.lu
    jtu.check_grads(lu, (a,), 2, atol=5e-2, rtol=3e-1)

  @jtu.sample_product(
      shape=[(4, 5), (6, 5)],
      dtype=[jnp.float32],
  )
  def testLuBatching(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    args = [rng(shape, jnp.float32) for _ in range(10)]
    expected = list(osp.linalg.lu(x) for x in args)
    ps = np.stack([out[0] for out in expected])
    ls = np.stack([out[1] for out in expected])
    us = np.stack([out[2] for out in expected])

    actual_ps, actual_ls, actual_us = vmap(jsp.linalg.lu)(jnp.stack(args))
    self.assertAllClose(ps, actual_ps)
    self.assertAllClose(ls, actual_ls, rtol=5e-6)
    self.assertAllClose(us, actual_us)

  @jtu.skip_on_devices("cpu", "tpu")
  def testLuCPUBackendOnGPU(self):
    # tests running `lu` on cpu when a gpu is present.
    jit(jsp.linalg.lu, backend="cpu")(np.ones((2, 2)))  # does not crash

  @jtu.sample_product(
      n=[1, 4, 5, 200],
      dtype=float_types + complex_types,
  )
  def testLuFactor(self, n, dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng((n, n), dtype)]

    x, = args_maker()
    lu, piv = jsp.linalg.lu_factor(x)
    l = np.tril(lu, -1) + np.eye(n, dtype=dtype)
    u = np.triu(lu)
    # Apply the recorded row interchanges to x so that it should equal l @ u.
    for i in range(n):
      x[[i, piv[i]],] = x[[piv[i], i],]
    self.assertAllClose(x, np.matmul(l, u), rtol=1e-3,
                        atol=1e-3)
    self._CompileAndCheck(jsp.linalg.lu_factor, args_maker)

  @jtu.sample_product(
      [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
       for lhs_shape, rhs_shape in [
           ((1, 1), (1, 1)),
           ((4, 4), (4,)),
           ((8, 8), (8, 4)),
       ]
      ],
      trans=[0, 1, 2],
      dtype=float_types + complex_types,
  )
  @jtu.skip_on_devices("cpu")  # TODO(frostig): Test fails on CPU sometimes
  def testLuSolve(self, lhs_shape, rhs_shape, dtype, trans):
    rng = jtu.rand_default(self.rng())
    osp_fun = lambda lu, piv, rhs: osp.linalg.lu_solve((lu, piv), rhs, trans=trans)
    jsp_fun = lambda lu, piv, rhs: jsp.linalg.lu_solve((lu, piv), rhs, trans=trans)

    def args_maker():
      a = rng(lhs_shape, dtype)
      lu, piv = osp.linalg.lu_factor(a)
      return [lu, piv, rng(rhs_shape, dtype)]

    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=1e-3)
    self._CompileAndCheck(jsp_fun, args_maker)

  @jtu.sample_product(
      [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
       for lhs_shape, rhs_shape in [
           ((1, 1), (1, 1)),
           ((4, 4), (4,)),
           ((8, 8), (8, 4)),
       ]
      ],
      [dict(assume_a=assume_a, lower=lower)
       for assume_a, lower in [
           ('gen', False),
           ('pos', False),
           ('pos', True),
       ]
      ],
      dtype=float_types + complex_types,
  )
  def testSolve(self, lhs_shape, rhs_shape, dtype, assume_a, lower):
    rng = jtu.rand_default(self.rng())
    osp_fun = lambda lhs, rhs: osp.linalg.solve(lhs, rhs, assume_a=assume_a, lower=lower)
    jsp_fun = lambda lhs, rhs: jsp.linalg.solve(lhs, rhs, assume_a=assume_a, lower=lower)

    def args_maker():
      a = rng(lhs_shape, dtype)
      if assume_a == 'pos':
        a = np.matmul(a, np.conj(T(a)))
        a = np.tril(a) if lower else np.triu(a)
      return [a, rng(rhs_shape, dtype)]

    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=1e-3)
    self._CompileAndCheck(jsp_fun, args_maker)

  @jtu.sample_product(
      [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
       for lhs_shape, rhs_shape in [
           ((4, 4), (4,)),
           ((4, 4), (4, 3)),
           ((2, 8, 8), (2, 8, 10)),
       ]
      ],
      lower=[False, True],
      transpose_a=[False, True],
      unit_diagonal=[False, True],
      dtype=float_types,
  )
  def testSolveTriangular(self, lower, transpose_a, unit_diagonal, lhs_shape,
                          rhs_shape, dtype):
    rng = jtu.rand_default(self.rng())
    k = rng(lhs_shape, dtype)
    l = np.linalg.cholesky(np.matmul(k, T(k))
                           + lhs_shape[-1] * np.eye(lhs_shape[-1]))
    l = l.astype(k.dtype)
    b = rng(rhs_shape, dtype)

    if unit_diagonal:
      a = np.tril(l, -1) + np.eye(lhs_shape[-1], dtype=dtype)
    else:
      a = l
    a = a if lower else T(a)

    inv = np.linalg.inv(T(a) if transpose_a else a).astype(a.dtype)
    if len(lhs_shape) == len(rhs_shape):
      np_ans = np.matmul(inv, b)
    else:
      np_ans = np.einsum("...ij,...j->...i", inv, b)

    # The standard scipy.linalg.solve_triangular doesn't support broadcasting.
    # But it seems like an inevitable extension so we support it.
    ans = jsp.linalg.solve_triangular(
        l if lower else T(l), b, trans=1 if transpose_a else 0, lower=lower,
        unit_diagonal=unit_diagonal)

    self.assertAllClose(np_ans, ans,
                        rtol={np.float32: 1e-4, np.float64: 1e-11})
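
  # Illustrative sketch, not an actual test: the broadcasting extension noted
  # in the comment above means jsp.linalg.solve_triangular accepts leading
  # batch dimensions, unlike vanilla scipy.linalg.solve_triangular. The helper
  # name and the shapes below are illustrative assumptions.
  def _solve_triangular_broadcasting_sketch(self):
    a = jnp.tril(jnp.ones((2, 3, 3))) + 3 * jnp.eye(3)  # batch of lower-triangular matrices
    b = jnp.ones((2, 3, 4))                             # matching batch of right-hand sides
    x = jsp.linalg.solve_triangular(a, b, lower=True)   # solved batch, shape (2, 3, 4)
    return x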

  @jtu.sample_product(
      [dict(left_side=left_side, a_shape=a_shape, b_shape=b_shape)
       for left_side, a_shape, b_shape in [
           (False, (4, 4), (4,)),
           (False, (4, 4), (1, 4,)),
           (False, (3, 3), (4, 3)),
           (True, (4, 4), (4,)),
           (True, (4, 4), (4, 1)),
           (True, (4, 4), (4, 3)),
           (True, (2, 8, 8), (2, 8, 10)),
       ]
      ],
      [dict(dtype=dtype, conjugate_a=conjugate_a)
       for dtype in float_types + complex_types
       for conjugate_a in (
           [False] if jnp.issubdtype(dtype, jnp.floating) else [False, True])
      ],
      lower=[False, True],
      unit_diagonal=[False, True],
      transpose_a=[False, True],
  )
  def testTriangularSolveGrad(
      self, lower, transpose_a, conjugate_a, unit_diagonal, left_side, a_shape,
      b_shape, dtype):
    rng = jtu.rand_default(self.rng())
    # Test lax.linalg.triangular_solve instead of scipy.linalg.solve_triangular
    # because it exposes more options.
    A = jnp.tril(rng(a_shape, dtype) + 5 * np.eye(a_shape[-1], dtype=dtype))
    A = A if lower else T(A)
    B = rng(b_shape, dtype)
    f = partial(lax.linalg.triangular_solve, lower=lower, transpose_a=transpose_a,
                conjugate_a=conjugate_a, unit_diagonal=unit_diagonal,
                left_side=left_side)
    jtu.check_grads(f, (A, B), order=1, rtol=4e-2, eps=1e-3)

  @jtu.sample_product(
      [dict(left_side=left_side, a_shape=a_shape, b_shape=b_shape, bdims=bdims)
       for left_side, a_shape, b_shape, bdims in [
           (False, (4, 4), (2, 3, 4,), (None, 0)),
           (False, (2, 4, 4), (2, 2, 3, 4,), (None, 0)),
           (False, (2, 4, 4), (3, 4,), (0, None)),
           (False, (2, 4, 4), (2, 3, 4,), (0, 0)),
           (True, (2, 4, 4), (2, 4, 3), (0, 0)),
           (True, (2, 4, 4), (2, 2, 4, 3), (None, 0)),
       ]
      ],
  )
  def testTriangularSolveBatching(self, left_side, a_shape, b_shape, bdims):
    rng = jtu.rand_default(self.rng())
    A = jnp.tril(rng(a_shape, np.float32)
                 + 5 * np.eye(a_shape[-1], dtype=np.float32))
    B = rng(b_shape, np.float32)
    solve = partial(lax.linalg.triangular_solve, lower=True, transpose_a=False,
                    conjugate_a=False, unit_diagonal=False, left_side=left_side)
    X = vmap(solve, bdims)(A, B)
    matmul = partial(jnp.matmul, precision=lax.Precision.HIGHEST)
    Y = matmul(A, X) if left_side else matmul(X, A)
    self.assertArraysAllClose(Y, jnp.broadcast_to(B, Y.shape), atol=1e-4)

  def testTriangularSolveGradPrecision(self):
    rng = jtu.rand_default(self.rng())
    a = jnp.tril(rng((3, 3), np.float32))
    b = rng((1, 3), np.float32)
    jtu.assert_dot_precision(
        lax.Precision.HIGHEST,
        partial(jvp, lax.linalg.triangular_solve),
        (a, b),
        (a, b))

  @jtu.sample_product(
      n=[1, 4, 5, 20, 50, 100],
      batch_size=[(), (2,), (3, 4)] if scipy_version >= (1, 9, 0) else [()],
      dtype=int_types + float_types + complex_types,
  )
  def testExpm(self, n, batch_size, dtype):
    if (jtu.test_device_matches(["cuda"]) and
        _is_required_cuda_version_satisfied(12000)):
      self.skipTest("Triggers a bug in cuda-12 b/287345077")

    rng = jtu.rand_small(self.rng())
    args_maker = lambda: [rng((*batch_size, n, n), dtype)]

    # Compare to numpy with JAX type promotion semantics.
    def osp_fun(A):
      return osp.linalg.expm(np.array(*promote_dtypes_inexact(A)))
    jsp_fun = jsp.linalg.expm
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker)
    self._CompileAndCheck(jsp_fun, args_maker)

    args_maker_triu = lambda: [np.triu(rng((*batch_size, n, n), dtype))]
    jsp_fun_triu = lambda a: jsp.linalg.expm(a, upper_triangular=True)
    self._CheckAgainstNumpy(osp_fun, jsp_fun_triu, args_maker_triu)
    self._CompileAndCheck(jsp_fun_triu, args_maker_triu)

  @jtu.sample_product(
      # Skip empty shapes because scipy fails: https://github.com/scipy/scipy/issues/1532
      shape=[(3, 4), (3, 3), (4, 3)],
      dtype=[np.float32],
      mode=["full", "r", "economic"],
  )
  def testScipyQrModes(self, shape, dtype, mode):
    rng = jtu.rand_default(self.rng())
    jsp_func = partial(jax.scipy.linalg.qr, mode=mode)
    sp_func = partial(scipy.linalg.qr, mode=mode)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(sp_func, jsp_func, args_maker, rtol=1E-5, atol=1E-5)
    self._CompileAndCheck(jsp_func, args_maker)

  @jtu.sample_product(
      [dict(shape=shape, k=k)
       for shape in [(1, 1), (3, 4, 4), (10, 5)]
       # TODO(phawkins): there are some test failures on GPU for k=0
       for k in range(1, shape[-1] + 1)],
      dtype=float_types + complex_types,
  )
  def testHouseholderProduct(self, shape, k, dtype):

    @partial(np.vectorize, signature='(m,n),(k)->(m,n)')
    def reference_fn(a, taus):
      if dtype == np.float32:
        q, _, info = scipy.linalg.lapack.sorgqr(a, taus)
      elif dtype == np.float64:
        q, _, info = scipy.linalg.lapack.dorgqr(a, taus)
      elif dtype == np.complex64:
        q, _, info = scipy.linalg.lapack.cungqr(a, taus)
      elif dtype == np.complex128:
        q, _, info = scipy.linalg.lapack.zungqr(a, taus)
      else:
        assert False, dtype
      assert info == 0, info
      return q

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype), rng(shape[:-2] + (k,), dtype)]
    tol = {np.float32: 1e-5, np.complex64: 1e-5, np.float64: 1e-12,
           np.complex128: 1e-12}
    self._CheckAgainstNumpy(reference_fn, lax.linalg.householder_product,
                            args_maker, rtol=tol, atol=tol)
    self._CompileAndCheck(lax.linalg.householder_product, args_maker)

  @jtu.sample_product(
      shape=[(1, 1), (2, 4, 4), (0, 100, 100), (10, 10)],
      dtype=float_types + complex_types,
      calc_q=[False, True],
  )
  @jtu.run_on_devices("cpu")
  def testHessenberg(self, shape, dtype, calc_q):
    rng = jtu.rand_default(self.rng())
    jsp_func = partial(jax.scipy.linalg.hessenberg, calc_q=calc_q)
    if calc_q:
      sp_func = np.vectorize(partial(scipy.linalg.hessenberg, calc_q=True),
                             otypes=(dtype, dtype),
                             signature='(n,n)->(n,n),(n,n)')
    else:
      sp_func = np.vectorize(scipy.linalg.hessenberg, signature='(n,n)->(n,n)',
                             otypes=(dtype,))
    args_maker = lambda: [rng(shape, dtype)]
    # scipy.linalg.hessenberg sometimes returns a float Q matrix for complex
    # inputs
    self._CheckAgainstNumpy(sp_func, jsp_func, args_maker, rtol=1e-5, atol=1e-5,
                            check_dtypes=not calc_q)
    self._CompileAndCheck(jsp_func, args_maker)

  @jtu.sample_product(
      shape=[(1, 1), (2, 2, 2), (4, 4), (10, 10), (2, 5, 5)],
      dtype=float_types + complex_types,
      lower=[False, True],
  )
  @jtu.skip_on_devices("tpu", "rocm")
  def testTridiagonal(self, shape, dtype, lower):
    rng = jtu.rand_default(self.rng())
    def jax_func(a):
      return lax.linalg.tridiagonal(a, lower=lower)

    real_dtype = jnp.finfo(dtype).dtype
    @partial(np.vectorize, otypes=(dtype, real_dtype, real_dtype, dtype),
             signature='(n,n)->(n,n),(n),(k),(k)')
    def sp_func(a):
      if dtype == np.float32:
        c, d, e, tau, info = scipy.linalg.lapack.ssytrd(a, lower=lower)
      elif dtype == np.float64:
        c, d, e, tau, info = scipy.linalg.lapack.dsytrd(a, lower=lower)
      elif dtype == np.complex64:
        c, d, e, tau, info = scipy.linalg.lapack.chetrd(a, lower=lower)
      elif dtype == np.complex128:
        c, d, e, tau, info = scipy.linalg.lapack.zhetrd(a, lower=lower)
      else:
        assert False, dtype
      assert info == 0
      return c, d, e, tau

    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(sp_func, jax_func, args_maker, rtol=1e-4, atol=1e-4,
                            check_dtypes=False)

  @jtu.sample_product(
      n=[1, 4, 5, 20, 50, 100],
      dtype=float_types + complex_types,
  )
  def testIssue2131(self, n, dtype):
    args_maker_zeros = lambda: [np.zeros((n, n), dtype)]
    osp_fun = lambda a: osp.linalg.expm(a)
    jsp_fun = lambda a: jsp.linalg.expm(a)
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker_zeros)
    self._CompileAndCheck(jsp_fun, args_maker_zeros)

  @jtu.sample_product(
      [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
       for lhs_shape, rhs_shape in [
           [(1, 1), (1,)],
           [(4, 4), (4,)],
           [(4, 4), (4, 4)],
       ]
      ],
      dtype=float_types,
      lower=[True, False],
  )
  def testChoSolve(self, lhs_shape, rhs_shape, dtype, lower):
    rng = jtu.rand_default(self.rng())
    def args_maker():
      b = rng(rhs_shape, dtype)
      if lower:
        L = np.tril(rng(lhs_shape, dtype))
        return [(L, lower), b]
      else:
        U = np.triu(rng(lhs_shape, dtype))
        return [(U, lower), b]
    self._CheckAgainstNumpy(osp.linalg.cho_solve, jsp.linalg.cho_solve,
                            args_maker, tol=1e-3)

  @jtu.sample_product(
      n=[1, 4, 5, 20, 50, 100],
      dtype=float_types + complex_types,
  )
  def testExpmFrechet(self, n, dtype):
    rng = jtu.rand_small(self.rng())
    if dtype == np.float64 or dtype == np.complex128:
      target_norms = [1.0e-2, 2.0e-1, 9.0e-01, 2.0, 3.0]
      # TODO(zhangqiaorjc): Reduce tol to default 1e-15.
      tol = {
        np.dtype(np.float64): 1e-14,
        np.dtype(np.complex128): 1e-14,
      }
    elif dtype == np.float32 or dtype == np.complex64:
      target_norms = [4.0e-1, 1.0, 3.0]
      tol = None
    else:
      raise TypeError(f"{dtype=} is not supported.")
    for norm in target_norms:
      def args_maker():
        a = rng((n, n), dtype)
        a = a / np.linalg.norm(a, 1) * norm
        e = rng((n, n), dtype)
        return [a, e]

      # compute_expm is True
      osp_fun = lambda a, e: osp.linalg.expm_frechet(a, e, compute_expm=True)
      jsp_fun = lambda a, e: jsp.linalg.expm_frechet(a, e, compute_expm=True)
      self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker,
                              check_dtypes=False, tol=tol)
      self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=False)
      # compute_expm is False
      osp_fun = lambda a, e: osp.linalg.expm_frechet(a, e, compute_expm=False)
      jsp_fun = lambda a, e: jsp.linalg.expm_frechet(a, e, compute_expm=False)
      self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker,
                              check_dtypes=False, tol=tol)
      self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=False)
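
  # Illustrative sketch, not an actual test: expm_frechet follows the scipy
  # convention of returning expm(A) together with the Frechet derivative of
  # expm at A in direction E; with compute_expm=False only the derivative is
  # returned. The helper name and the small inputs below are illustrative
  # assumptions.
  def _expm_frechet_sketch(self):
    a = jnp.eye(2) * 0.1
    e = jnp.ones((2, 2)) * 0.01
    expm_a, frechet_ae = jsp.linalg.expm_frechet(a, e, compute_expm=True)
    frechet_only = jsp.linalg.expm_frechet(a, e, compute_expm=False)
    return expm_a, frechet_ae, frechet_only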

  @jtu.sample_product(
      n=[1, 4, 5, 20, 50],
      dtype=float_types + complex_types,
  )
  def testExpmGrad(self, n, dtype):
    rng = jtu.rand_small(self.rng())
    a = rng((n, n), dtype)
    if dtype == np.float64 or dtype == np.complex128:
      target_norms = [1.0e-2, 2.0e-1, 9.0e-01, 2.0, 3.0]
    elif dtype == np.float32 or dtype == np.complex64:
      target_norms = [4.0e-1, 1.0, 3.0]
    else:
      raise TypeError(f"{dtype=} is not supported.")
    # TODO(zhangqiaorjc): Reduce tol to default 1e-5.
    # Lower tolerance is due to 2nd order derivative.
    tol = {
      # Note that due to inner_product, float and complex tol are coupled.
      np.dtype(np.float32): 0.02,
      np.dtype(np.complex64): 0.02,
      np.dtype(np.float64): 1e-4,
      np.dtype(np.complex128): 1e-4,
    }
    for norm in target_norms:
      a = a / np.linalg.norm(a, 1) * norm
      def expm(x):
        return jsp.linalg.expm(x, upper_triangular=False, max_squarings=16)
      jtu.check_grads(expm, (a,), modes=["fwd", "rev"], order=1, atol=tol,
                      rtol=tol)

  @jtu.sample_product(
      shape=[(4, 4), (15, 15), (50, 50), (100, 100)],
      dtype=float_types + complex_types,
  )
  @jtu.run_on_devices("cpu")
  def testSchur(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]

    self._CheckAgainstNumpy(osp.linalg.schur, jsp.linalg.schur, args_maker)
    self._CompileAndCheck(jsp.linalg.schur, args_maker)

  @jtu.sample_product(
      shape=[(1, 1), (4, 4), (15, 15), (50, 50), (100, 100)],
      dtype=float_types + complex_types,
  )
  @jtu.run_on_devices("cpu")
  def testRsf2csf(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype), rng(shape, dtype)]
    tol = 3e-5
    self._CheckAgainstNumpy(osp.linalg.rsf2csf, jsp.linalg.rsf2csf,
                            args_maker, tol=tol)
    self._CompileAndCheck(jsp.linalg.rsf2csf, args_maker)

  @jtu.sample_product(
      shape=[(1, 1), (5, 5), (20, 20), (50, 50)],
      dtype=float_types + complex_types,
      disp=[True, False],
  )
  # funm uses jax.scipy.linalg.schur which is implemented for a CPU
  # backend only, so tests on GPU and TPU backends are skipped here
  @jtu.run_on_devices("cpu")
  def testFunm(self, shape, dtype, disp):
    def func(x):
      return x**-2.718
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    jnp_fun = lambda arr: jsp.linalg.funm(arr, func, disp=disp)
    scp_fun = lambda arr: osp.linalg.funm(arr, func, disp=disp)
    self._CheckAgainstNumpy(jnp_fun, scp_fun, args_maker, check_dtypes=False,
                            tol={np.complex64: 1e-5, np.complex128: 1e-6})
    self._CompileAndCheck(jnp_fun, args_maker, atol=2e-5)

  @jtu.sample_product(
      shape=[(4, 4), (15, 15), (50, 50), (100, 100)],
      dtype=float_types + complex_types,
  )
  @jtu.run_on_devices("cpu")
  def testSqrtmPSDMatrix(self, shape, dtype):
    # Checks against scipy.linalg.sqrtm when the principal square root
    # is guaranteed to be unique (i.e. no negative real eigenvalue).
    rng = jtu.rand_default(self.rng())
    arg = rng(shape, dtype)
    mat = arg @ arg.T
    args_maker = lambda: [mat]
    if dtype == np.float32 or dtype == np.complex64:
      tol = 1e-4
    else:
      tol = 1e-8
    self._CheckAgainstNumpy(osp.linalg.sqrtm,
                            jsp.linalg.sqrtm,
                            args_maker,
                            tol=tol,
                            check_dtypes=False)
    self._CompileAndCheck(jsp.linalg.sqrtm, args_maker)

  @jtu.sample_product(
      shape=[(4, 4), (15, 15), (50, 50), (100, 100)],
      dtype=float_types + complex_types,
  )
  @jtu.run_on_devices("cpu")
  def testSqrtmGenMatrix(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    arg = rng(shape, dtype)
    if dtype == np.float32 or dtype == np.complex64:
      tol = 2e-3
    else:
      tol = 1e-8
    R = jsp.linalg.sqrtm(arg)
    self.assertAllClose(R @ R, arg, atol=tol, check_dtypes=False)

  @jtu.sample_product(
      [dict(diag=diag, expected=expected)
       for diag, expected in [([1, 0, 0], [1, 0, 0]), ([0, 4, 0], [0, 2, 0]),
                              ([0, 0, 0, 9], [0, 0, 0, 3]),
                              ([0, 0, 9, 0, 0, 4], [0, 0, 3, 0, 0, 2])]
      ],
      dtype=float_types + complex_types,
  )
  @jtu.run_on_devices("cpu")
  def testSqrtmEdgeCase(self, diag, expected, dtype):
    """Tests the zero numerator condition."""
    mat = jnp.diag(jnp.array(diag)).astype(dtype)
    expected = jnp.diag(jnp.array(expected))
    root = jsp.linalg.sqrtm(mat)

    self.assertAllClose(root, expected, check_dtypes=False)

  @jtu.sample_product(
      cshape=[(), (4,), (8,), (3, 7), (0, 5, 1)],
      cdtype=float_types + complex_types,
      rshape=[(), (3,), (7,), (2, 1, 4), (19, 0)],
      rdtype=float_types + complex_types + int_types)
  def testToeplitzConstruction(self, rshape, rdtype, cshape, cdtype):
    if ((rdtype in [np.float64, np.complex128]
         or cdtype in [np.float64, np.complex128])
        and not config.enable_x64.value):
      self.skipTest("Only run float64 testcase when float64 is enabled.")

    int_types_excl_i8 = set(int_types) - {np.int8}
    if ((rdtype in int_types_excl_i8 or cdtype in int_types_excl_i8)
        and jtu.test_device_matches(["gpu"])):
      self.skipTest("Integer (except int8) toeplitz is not supported on GPU yet.")

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(cshape, cdtype), rng(rshape, rdtype)]
    with jtu.strict_promotion_if_dtypes_match([rdtype, cdtype]):
      self._CheckAgainstNumpy(jtu.promote_like_jnp(osp.linalg.toeplitz),
                              jsp.linalg.toeplitz, args_maker)
      self._CompileAndCheck(jsp.linalg.toeplitz, args_maker)

  @jtu.sample_product(
      shape=[(), (3,), (1, 4), (1, 5, 9), (11, 0, 13)],
      dtype=float_types + complex_types + int_types)
  @jtu.skip_on_devices("rocm")
  def testToeplitzSymmetricConstruction(self, shape, dtype):
    if (dtype in [np.float64, np.complex128]
        and not config.enable_x64.value):
      self.skipTest("Only run float64 testcase when float64 is enabled.")

    int_types_excl_i8 = set(int_types) - {np.int8}
    if (dtype in int_types_excl_i8
        and jtu.test_device_matches(["gpu"])):
      self.skipTest("Integer (except int8) toeplitz is not supported on GPU yet.")

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(jtu.promote_like_jnp(osp.linalg.toeplitz),
                            jsp.linalg.toeplitz, args_maker)
    self._CompileAndCheck(jsp.linalg.toeplitz, args_maker)

  def testToeplitzConstructionWithKnownCases(self):
    # Test with examples taken from the SciPy docs for the corresponding function:
    # https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.toeplitz.html
    ret = jsp.linalg.toeplitz(np.array([1.0, 2+3j, 4-1j]))
    self.assertAllClose(ret, np.array([
        [ 1.+0.j,  2.-3.j,  4.+1.j],
        [ 2.+3.j,  1.+0.j,  2.-3.j],
        [ 4.-1.j,  2.+3.j,  1.+0.j]]))
    ret = jsp.linalg.toeplitz(np.array([1, 2, 3], dtype=np.float32),
                              np.array([1, 4, 5, 6], dtype=np.float32))
    self.assertAllClose(ret, np.array([
        [1, 4, 5, 6],
        [2, 1, 4, 5],
        [3, 2, 1, 4]], dtype=np.float32))


class LaxLinalgTest(jtu.JaxTestCase):
  """Tests for lax.linalg primitives."""

  @jtu.sample_product(
      n=[0, 4, 5, 50],
      dtype=float_types + complex_types,
      lower=[True, False],
      sort_eigenvalues=[True, False],
  )
  def testEigh(self, n, dtype, lower, sort_eigenvalues):
    rng = jtu.rand_default(self.rng())
    tol = 1e-3
    args_maker = lambda: [rng((n, n), dtype)]

    a, = args_maker()
    a = (a + np.conj(a.T)) / 2
    v, w = lax.linalg.eigh(np.tril(a) if lower else np.triu(a),
                           lower=lower, symmetrize_input=False,
                           sort_eigenvalues=sort_eigenvalues)
    w = np.asarray(w)
    v = np.asarray(v)
    self.assertLessEqual(
        np.linalg.norm(np.eye(n) - np.matmul(np.conj(T(v)), v)), 1e-3)
    self.assertLessEqual(np.linalg.norm(np.matmul(a, v) - w * v),
                         tol * np.linalg.norm(a))

    w_expected, v_expected = np.linalg.eigh(np.asarray(a))
    self.assertAllClose(w_expected, w if sort_eigenvalues else np.sort(w),
                        rtol=1e-4, atol=1e-4)

  def run_eigh_tridiagonal_test(self, alpha, beta):
    n = alpha.shape[-1]
    # scipy.linalg.eigh_tridiagonal doesn't support complex inputs, so for
    # this we call the slower numpy.linalg.eigh.
    if np.issubdtype(alpha.dtype, np.complexfloating):
      tridiagonal = np.diag(alpha) + np.diag(beta, 1) + np.diag(
          np.conj(beta), -1)
      eigvals_expected, _ = np.linalg.eigh(tridiagonal)
    else:
      eigvals_expected = scipy.linalg.eigh_tridiagonal(
          alpha, beta, eigvals_only=True)
    eigvals = jax.scipy.linalg.eigh_tridiagonal(
        alpha, beta, eigvals_only=True)
    finfo = np.finfo(alpha.dtype)
    atol = 4 * np.sqrt(n) * finfo.eps * np.amax(np.abs(eigvals_expected))
    self.assertAllClose(eigvals_expected, eigvals, atol=atol, rtol=1e-4)

  @jtu.sample_product(
      n=[1, 2, 3, 7, 8, 100],
      dtype=float_types + complex_types,
  )
  def testToeplitz(self, n, dtype):
    for a, b in [[2, -1], [1, 0], [0, 1], [-1e10, 1e10], [-1e-10, 1e-10]]:
      alpha = a * np.ones([n], dtype=dtype)
      beta = b * np.ones([n - 1], dtype=dtype)
      self.run_eigh_tridiagonal_test(alpha, beta)

  @jtu.sample_product(
      n=[1, 2, 3, 7, 8, 100],
      dtype=float_types + complex_types,
  )
  def testRandomUniform(self, n, dtype):
    alpha = jtu.rand_uniform(self.rng())((n,), dtype)
    beta = jtu.rand_uniform(self.rng())((n - 1,), dtype)
    self.run_eigh_tridiagonal_test(alpha, beta)

  @jtu.sample_product(dtype=float_types + complex_types)
  def testSelect(self, dtype):
    n = 5
    alpha = jtu.rand_uniform(self.rng())((n,), dtype)
    beta = jtu.rand_uniform(self.rng())((n - 1,), dtype)
    eigvals_all = jax.scipy.linalg.eigh_tridiagonal(alpha, beta, select="a",
                                                    eigvals_only=True)
    eps = np.finfo(alpha.dtype).eps
    atol = 2 * n * eps
    for first in range(n - 1):
      for last in range(first + 1, n - 1):
        # Check that we get the expected eigenvalues by selecting by
        # index range.
        eigvals_index = jax.scipy.linalg.eigh_tridiagonal(
            alpha, beta, select="i", select_range=(first, last),
            eigvals_only=True)
        self.assertAllClose(
            eigvals_all[first:(last + 1)], eigvals_index, atol=atol)

  @jtu.sample_product(dtype=[np.float32, np.float64])
  @jtu.skip_on_devices("rocm")  # will be fixed in ROCm-5.1
  def test_tridiagonal_solve(self, dtype):
    dl = np.array([0.0, 2.0, 3.0], dtype=dtype)
    d = np.ones(3, dtype=dtype)
    du = np.array([1.0, 2.0, 0.0], dtype=dtype)
    m = 3
    B = np.ones([m, 1], dtype=dtype)
    X = lax.linalg.tridiagonal_solve(dl, d, du, B)
    # Rebuild the dense matrix from the three diagonals (dl is padded with a
    # leading zero and du with a trailing zero) and check the solution.
    A = np.eye(3, dtype=dtype)
    A[[1, 2], [0, 1]] = dl[1:]
    A[[0, 1], [1, 2]] = du[:-1]
    np.testing.assert_allclose(A @ X, B, rtol=1e-6, atol=1e-6)

  @jtu.sample_product(
      shape=[(4, 4), (15, 15), (50, 50), (100, 100)],
      dtype=float_types + complex_types,
  )
  @jtu.run_on_devices("cpu")
  def testSchur(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]

    self._CheckAgainstNumpy(osp.linalg.schur, lax.linalg.schur, args_maker)
    self._CompileAndCheck(lax.linalg.schur, args_maker)

  @jtu.sample_product(
      shape=[(2, 2), (4, 4), (15, 15), (50, 50), (100, 100)],
      dtype=float_types + complex_types,
  )
  @jtu.run_on_devices("cpu")
  def testSchurBatching(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    batch_size = 10
    shape = (batch_size,) + shape
    args = rng(shape, dtype)
    reconstruct = vmap(lambda S, T: S @ T @ jnp.conj(S.T))

    Ts, Ss = vmap(lax.linalg.schur)(args)
    self.assertAllClose(reconstruct(Ss, Ts), args, atol=1e-4)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())