# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the LAPAX linear algebra module."""

from functools import partial
import unittest

import numpy as np
import scipy as osp

from absl.testing import absltest
from absl.testing import parameterized

import jax
import jax.lib
from jax import jit, grad, jvp, vmap
from jax import lax
from jax import numpy as jnp
from jax import scipy as jsp
from jax import test_util as jtu

from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
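# Helper used throughout these tests: a (batched) matrix transpose that swaps
# the last two axes and leaves any leading batch dimensions alone.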
T = lambda x: np.swapaxes(x, -1, -2)

float_types = [np.float32, np.float64]
complex_types = [np.complex64, np.complex128]


class NumpyLinalgTest(jtu.JaxTestCase):

  def testNotImplemented(self):
    for name in jnp.linalg._NOT_IMPLEMENTED:
      func = getattr(jnp.linalg, name)
      with self.assertRaises(NotImplementedError):
        func()

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (1000, 0, 0)]
      for dtype in float_types + complex_types))
  def testCholesky(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
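    # args_maker builds a Hermitian positive (semi)definite input, the only
    # kind cholesky is defined for: for an arbitrary (..., n, 2n) factor a,
    # the product a @ conj(T(a)) is Hermitian and (almost surely) positive
    # definite.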
    def args_maker():
      factor_shape = shape[:-1] + (2 * shape[-1],)
      a = rng(factor_shape, dtype)
      return [np.matmul(a, jnp.conj(T(a)))]

    self._CheckAgainstNumpy(np.linalg.cholesky, jnp.linalg.cholesky, args_maker,
                            tol=1e-3)
    self._CompileAndCheck(jnp.linalg.cholesky, args_maker)

    if jnp.finfo(dtype).bits == 64:
      jtu.check_grads(jnp.linalg.cholesky, args_maker(), order=2)

  def testCholeskyGradPrecision(self):
    rng = jtu.rand_default(self.rng())
    a = rng((3, 3), np.float32)
    a = np.dot(a, a.T)
    jtu.assert_dot_precision(
        lax.Precision.HIGHEST, partial(jvp, jnp.linalg.cholesky), (a,), (a,))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_n={}".format(jtu.format_shape_dtype_string((n, n), dtype)),
       "n": n, "dtype": dtype}
      for n in [0, 4, 5, 25]  # TODO(mattjj): complex64 unstable on large sizes?
      for dtype in float_types + complex_types))
  def testDet(self, n, dtype):
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    args_maker = lambda: [rng((n, n), dtype)]

    self._CheckAgainstNumpy(np.linalg.det, jnp.linalg.det, args_maker, tol=1e-3)
    self._CompileAndCheck(jnp.linalg.det, args_maker,
                          rtol={np.float64: 1e-13, np.complex128: 1e-13})

  def testDetOfSingularMatrix(self):
    x = jnp.array([[-1., 3./2], [2./3, -1.]], dtype=np.float32)
    self.assertAllClose(np.float32(0), jsp.linalg.det(x))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(1, 1), (3, 3), (2, 4, 4)]
      for dtype in float_types))
  @jtu.skip_on_devices("tpu")
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def testDetGrad(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    a = rng(shape, dtype)
    jtu.check_grads(jnp.linalg.det, (a,), 2, atol=1e-1, rtol=1e-1)
    # make sure there are no NaNs when a matrix is zero
    if len(shape) == 2:
      jtu.check_grads(
          jnp.linalg.det, (jnp.zeros_like(a),), 1, atol=1e-1, rtol=1e-1)
    else:
      a[0] = 0
      jtu.check_grads(jnp.linalg.det, (a,), 1, atol=1e-1, rtol=1e-1)

  def testDetGradIssue6121(self):
    f = lambda x: jnp.linalg.det(x).sum()
    x = jnp.ones((16, 1, 1))
    jax.grad(f)(x)
    jtu.check_grads(f, (x,), 2, atol=1e-1, rtol=1e-1)

  def testDetGradOfSingularMatrixCorank1(self):
    # Rank 2 matrix with nonzero gradient
    a = jnp.array([[ 50, -30,  45],
                   [-30,  90, -81],
                   [ 45, -81,  81]], dtype=jnp.float32)
    jtu.check_grads(jnp.linalg.det, (a,), 1, atol=1e-1, rtol=1e-1)

  def testDetGradOfSingularMatrixCorank2(self):
    # Rank 1 matrix with zero gradient
    b = jnp.array([[ 36, -42,  18],
                   [-42,  49, -21],
                   [ 18, -21,   9]], dtype=jnp.float32)
    jtu.check_grads(jnp.linalg.det, (b,), 1, atol=1e-1, rtol=1e-1)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_m={}_n={}_q={}".format(
           jtu.format_shape_dtype_string((m,), dtype),
           jtu.format_shape_dtype_string((nq[0],), dtype),
           jtu.format_shape_dtype_string(nq[1], dtype)),
       "m": m, "nq": nq, "dtype": dtype}
      for m in [1, 5, 7, 23]
      for nq in zip([2, 4, 6, 36], [(1, 2), (2, 2), (1, 2, 3), (3, 3, 1, 4)])
      for dtype in float_types))
  def testTensorsolve(self, m, nq, dtype):
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)

    # According to the numpy docs the shapes are as follows: the coefficient
    # tensor (a) has shape b.shape + Q, and prod(Q) == prod(b.shape).
    # Therefore, n = prod(q).
    n, q = nq
    b_shape = (n, m)
    # To accomplish prod(Q) == prod(b.shape) we append the extra dimension m
    # to the shape Q.
    Q = q + (m,)
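    # For example, with m=5 and nq=(4, (2, 2)): b has shape (4, 5), a has
    # shape (4, 5, 2, 2, 5), and the solution x of a x = b has shape (2, 2, 5).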
    args_maker = lambda: [
        rng(b_shape + Q, dtype),  # = a
        rng(b_shape, dtype)]      # = b
    a, b = args_maker()
    result = jnp.linalg.tensorsolve(*args_maker())
    self.assertEqual(result.shape, Q)

    self._CheckAgainstNumpy(np.linalg.tensorsolve,
                            jnp.linalg.tensorsolve, args_maker,
                            tol={np.float32: 1e-2, np.float64: 1e-3})
    self._CompileAndCheck(jnp.linalg.tensorsolve,
                          args_maker,
                          rtol={np.float64: 1e-13})

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(0, 0), (1, 1), (3, 3), (4, 4), (10, 10), (200, 200),
                    (2, 2, 2), (2, 3, 3), (3, 2, 2)]
      for dtype in float_types + complex_types))
  def testSlogdet(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    args_maker = lambda: [rng(shape, dtype)]

    self._CheckAgainstNumpy(np.linalg.slogdet, jnp.linalg.slogdet, args_maker,
                            tol=1e-3)
    self._CompileAndCheck(jnp.linalg.slogdet, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(1, 1), (4, 4), (5, 5), (2, 7, 7)]
      for dtype in float_types + complex_types))
  @jtu.skip_on_devices("tpu")
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def testSlogdetGrad(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    a = rng(shape, dtype)
    jtu.check_grads(jnp.linalg.slogdet, (a,), 2, atol=1e-1, rtol=1e-1)

  def testIssue1213(self):
    for n in range(5):
      mat = jnp.array([np.diag(np.ones([5], dtype=np.float32)) * (-.01)] * 2)
      args_maker = lambda: [mat]
      self._CheckAgainstNumpy(np.linalg.slogdet, jnp.linalg.slogdet, args_maker,
                              tol=1e-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_leftvectors={}_rightvectors={}".format(
           jtu.format_shape_dtype_string(shape, dtype),
           compute_left_eigenvectors, compute_right_eigenvectors),
       "shape": shape, "dtype": dtype,
       "compute_left_eigenvectors": compute_left_eigenvectors,
       "compute_right_eigenvectors": compute_right_eigenvectors}
      for shape in [(0, 0), (4, 4), (5, 5), (50, 50), (2, 6, 6)]
      for dtype in float_types + complex_types
      for compute_left_eigenvectors, compute_right_eigenvectors in [
          (False, False),
          (True, False),
          (False, True),
          (True, True)
      ]))
  # TODO(phawkins): enable when there is an eigendecomposition implementation
  # for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testEig(self, shape, dtype, compute_left_eigenvectors,
              compute_right_eigenvectors):
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    n = shape[-1]
    args_maker = lambda: [rng(shape, dtype)]

    # Norm, adjusted for dimension and type.
    def norm(x):
      norm = np.linalg.norm(x, axis=(-2, -1))
      return norm / ((n + 1) * jnp.finfo(dtype).eps)

    def check_right_eigenvectors(a, w, vr):
      self.assertTrue(
        np.all(norm(np.matmul(a, vr) - w[..., None, :] * vr) < 100))
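
    # A left eigenvector of a is a right eigenvector of conj(a).T with the
    # conjugated eigenvalue, which is what the check below relies on.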
    def check_left_eigenvectors(a, w, vl):
      rank = len(a.shape)
      aH = jnp.conj(a.transpose(list(range(rank - 2)) + [rank - 1, rank - 2]))
      wC = jnp.conj(w)
      check_right_eigenvectors(aH, wC, vl)

    a, = args_maker()
    results = lax.linalg.eig(a, compute_left_eigenvectors,
                             compute_right_eigenvectors)
    w = results[0]

    if compute_left_eigenvectors:
      check_left_eigenvectors(a, w, results[1])
    if compute_right_eigenvectors:
      check_right_eigenvectors(a, w, results[1 + compute_left_eigenvectors])

    self._CompileAndCheck(partial(jnp.linalg.eig), args_maker,
                          rtol=1e-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(
           jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(4, 4), (5, 5), (8, 8), (7, 6, 6)]
      for dtype in float_types + complex_types))
  # TODO(phawkins): enable when there is an eigendecomposition implementation
  # for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testEigvalsGrad(self, shape, dtype):
    # This test sometimes fails for large matrices. I (@j-towns) suspect, but
    # haven't checked, that might be because of perturbations causing the
    # ordering of eigenvalues to change, which will trip up check_grads. So we
    # just test on small-ish matrices.
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    args_maker = lambda: [rng(shape, dtype)]
    a, = args_maker()
    tol = 1e-4 if dtype in (np.float64, np.complex128) else 1e-1
    jtu.check_grads(lambda x: jnp.linalg.eigvals(x), (a,), order=1,
                    modes=['fwd', 'rev'], rtol=tol, atol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(
           jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(4, 4), (5, 5), (50, 50)]
      for dtype in float_types + complex_types))
  # TODO: enable when there is an eigendecomposition implementation
  # for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testEigvals(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    args_maker = lambda: [rng(shape, dtype)]
    a, = args_maker()
    w1, _ = jnp.linalg.eig(a)
    w2 = jnp.linalg.eigvals(a)
    self.assertAllClose(w1, w2)

  @jtu.skip_on_devices("gpu", "tpu")
  def testEigvalsInf(self):
    # https://github.com/google/jax/issues/2661
    x = jnp.array([[jnp.inf]])
    self.assertTrue(jnp.all(jnp.isnan(jnp.linalg.eigvals(x))))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(1, 1), (4, 4), (5, 5)]
      for dtype in float_types + complex_types))
  @jtu.skip_on_devices("gpu", "tpu")
  def testEigBatching(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    shape = (10,) + shape
    args = rng(shape, dtype)
    ws, vs = vmap(jnp.linalg.eig)(args)
    self.assertTrue(np.all(np.linalg.norm(
        np.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_n={}_lower={}".format(
           jtu.format_shape_dtype_string((n, n), dtype), lower),
       "n": n, "dtype": dtype, "lower": lower}
      for n in [0, 4, 5, 50]
      for dtype in float_types + complex_types
      for lower in [False, True]))
  def testEigh(self, n, dtype, lower):
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    tol = 1e-4
    if jtu.device_under_test() == "tpu":
      if jnp.issubdtype(dtype, np.complexfloating):
        raise unittest.SkipTest("No complex eigh on TPU")
    args_maker = lambda: [rng((n, n), dtype)]

    uplo = "L" if lower else "U"

    a, = args_maker()
    a = (a + np.conj(a.T)) / 2
    w, v = jnp.linalg.eigh(np.tril(a) if lower else np.triu(a),
                           UPLO=uplo, symmetrize_input=False)
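    # Two defining properties of the Hermitian eigendecomposition are checked
    # below: the eigenvector matrix v is unitary (conj(v).T @ v ~= I), and
    # each eigenpair satisfies a @ v ~= w * v, up to a norm-relative tolerance.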
    self.assertLessEqual(
        np.linalg.norm(np.eye(n) - np.matmul(np.conj(T(v)), v)), 1e-3)
    self.assertLessEqual(np.linalg.norm(np.matmul(a, v) - w * v),
                         tol * np.linalg.norm(a))

    self._CompileAndCheck(partial(jnp.linalg.eigh, UPLO=uplo), args_maker,
                          rtol=1e-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(
           jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(4, 4), (5, 5), (50, 50)]
      for dtype in float_types + complex_types))
  def testEigvalsh(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    if jtu.device_under_test() == "tpu":
      if jnp.issubdtype(dtype, jnp.complexfloating):
        raise unittest.SkipTest("No complex eigh on TPU")
    n = shape[-1]
    def args_maker():
      a = rng((n, n), dtype)
      a = (a + np.conj(a.T)) / 2
      return [a]
    self._CheckAgainstNumpy(np.linalg.eigvalsh, jnp.linalg.eigvalsh, args_maker,
                            tol=1e-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_lower={}".format(jtu.format_shape_dtype_string(shape, dtype),
                                   lower),
       "shape": shape, "dtype": dtype, "lower": lower}
      for shape in [(1, 1), (4, 4), (5, 5), (50, 50), (2, 10, 10)]
      for dtype in float_types + complex_types
      for lower in [True, False]))
  def testEighGrad(self, shape, dtype, lower):
    rng = jtu.rand_default(self.rng())
    self.skipTest("Test fails with numeric errors.")
    uplo = "L" if lower else "U"
    a = rng(shape, dtype)
    a = (a + np.conj(T(a))) / 2
    ones = np.ones((a.shape[-1], a.shape[-1]), dtype=dtype)
    a *= np.tril(ones) if lower else np.triu(ones)
    # Gradient checks will fail without symmetrization as the eigh jvp rule
    # is only correct for tangents in the symmetric subspace, whereas the
    # checker checks against unconstrained (co)tangents.
    if dtype not in complex_types:
      f = partial(jnp.linalg.eigh, UPLO=uplo, symmetrize_input=True)
    else:  # only check eigenvalue grads for complex matrices
      f = lambda a: partial(jnp.linalg.eigh, UPLO=uplo, symmetrize_input=True)(a)[0]
    jtu.check_grads(f, (a,), 2, rtol=1e-1)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_lower={}".format(jtu.format_shape_dtype_string(shape, dtype),
                                   lower),
       "shape": shape, "dtype": dtype, "lower": lower, "eps": eps}
      for shape in [(1, 1), (4, 4), (5, 5), (50, 50)]
      for dtype in complex_types
      for lower in [True, False]
      for eps in [1e-4]))
  # TODO(phawkins): enable when there is a complex eigendecomposition
  # implementation for TPU.
  @jtu.skip_on_devices("tpu")
  def testEighGradVectorComplex(self, shape, dtype, lower, eps):
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    # Special case to test for complex eigenvector grad correctness. Exact
    # eigenvector coordinate gradients are hard to test numerically for
    # complex eigensystem solvers given the extra degrees of per-eigenvector
    # phase freedom. Instead, we numerically verify the eigensystem properties
    # on the perturbed eigenvectors. You only ever want to optimize
    # eigenvector directions, not coordinates!
    uplo = "L" if lower else "U"
    a = rng(shape, dtype)
    a = (a + np.conj(a.T)) / 2
    a = np.tril(a) if lower else np.triu(a)
    a_dot = eps * rng(shape, dtype)
    a_dot = (a_dot + np.conj(a_dot.T)) / 2
    a_dot = np.tril(a_dot) if lower else np.triu(a_dot)
    # Evaluate the eigenvector gradient and the ground-truth eigensystem for
    # the perturbed input matrix.
    f = partial(jnp.linalg.eigh, UPLO=uplo)
    (w, v), (dw, dv) = jvp(f, primals=(a,), tangents=(a_dot,))
    self.assertTrue(jnp.issubdtype(w.dtype, jnp.floating))
    self.assertTrue(jnp.issubdtype(dw.dtype, jnp.floating))
    new_a = a + a_dot
    new_w, new_v = f(new_a)
    new_a = (new_a + np.conj(new_a.T)) / 2
    # Assert rtol eigenvalue delta between perturbed eigenvectors vs new true
    # eigenvalues.
    RTOL = 1e-2
    assert np.max(
      np.abs((np.diag(np.dot(np.conj((v+dv).T), np.dot(new_a, (v+dv)))) - new_w) / new_w)) < RTOL
    # Redundant to the above, but also assert rtol for the eigenvector
    # property with the new true eigenvalues.
    assert np.max(
      np.linalg.norm(np.abs(new_w*(v+dv) - np.dot(new_a, (v+dv))), axis=0) /
      np.linalg.norm(np.abs(new_w*(v+dv)), axis=0)
    ) < RTOL

  def testEighGradPrecision(self):
    rng = jtu.rand_default(self.rng())
    a = rng((3, 3), np.float32)
    jtu.assert_dot_precision(
        lax.Precision.HIGHEST, partial(jvp, jnp.linalg.eigh), (a,), (a,))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(1, 1), (4, 4), (5, 5)]
      for dtype in float_types + complex_types))
  def testEighBatching(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    if (jtu.device_under_test() == "tpu" and
        jnp.issubdtype(dtype, np.complexfloating)):
      raise unittest.SkipTest("No complex eigh on TPU")
    shape = (10,) + shape
    args = rng(shape, dtype)
    args = (args + np.conj(T(args))) / 2
    ws, vs = vmap(jsp.linalg.eigh)(args)
    self.assertTrue(np.all(np.linalg.norm(
        np.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(1,), (4,), (5,)]
      for dtype in (np.int32,)))
  def testLuPivotsToPermutation(self, shape, dtype):
    jtu.skip_if_unsupported_type(dtype)
    pivots_size = shape[-1]
    permutation_size = 2 * pivots_size

    pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, dtype=dtype)
    pivots = jnp.broadcast_to(pivots, shape)
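    # For pivots_size n and permutation_size 2n the pivots are
    # [2n-1, 2n-2, ..., n]; applying swap(i, pivots[i]) for i = 0..n-1 to the
    # identity permutation [0, 1, ..., 2n-1] reverses it, so the expected
    # output is [2n-1, ..., 1, 0].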
    actual = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size)
    expected = jnp.arange(permutation_size - 1, -1, -1, dtype=dtype)
    expected = jnp.broadcast_to(expected, actual.shape)
    self.assertArraysEqual(actual, expected)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(1,), (4,), (5,)]
      for dtype in (np.int32,)))
  def testLuPivotsToPermutationBatching(self, shape, dtype):
    jtu.skip_if_unsupported_type(dtype)
    shape = (10,) + shape
    pivots_size = shape[-1]
    permutation_size = 2 * pivots_size

    pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, dtype=dtype)
    pivots = jnp.broadcast_to(pivots, shape)
    batched_fn = vmap(
        lambda x: lax.linalg.lu_pivots_to_permutation(x, permutation_size))
    actual = batched_fn(pivots)
    expected = jnp.arange(permutation_size - 1, -1, -1, dtype=dtype)
    expected = jnp.broadcast_to(expected, actual.shape)
    self.assertArraysEqual(actual, expected)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_ord={}_axis={}_keepdims={}".format(
           jtu.format_shape_dtype_string(shape, dtype), ord, axis, keepdims),
       "shape": shape, "dtype": dtype, "axis": axis, "keepdims": keepdims,
       "ord": ord}
      for axis, shape in [
          (None, (1,)), (None, (7,)), (None, (5, 8)),
          (0, (9,)), (0, (4, 5)), ((1,), (10, 7, 3)), ((-2,), (4, 8)),
          (-1, (6, 3)), ((0, 2), (3, 4, 5)), ((2, 0), (7, 8, 9)),
          (None, (7, 8, 11))]
      for keepdims in [False, True]
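      # Vector norms (axis is None on a 1D input, an int, or a 1-tuple) accept
      # integer orders and +/-inf; matrix norms accept 'fro', 'nuc', 1, 2, and
      # +/-inf; stacked matrices (axis None on a >2D input) only support None.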
      for ord in (
          [None] if axis is None and len(shape) > 2
          else [None, 0, 1, 2, 3, -1, -2, -3, jnp.inf, -jnp.inf]
          if (axis is None and len(shape) == 1) or
             isinstance(axis, int) or
             (isinstance(axis, tuple) and len(axis) == 1)
          else [None, 'fro', 1, 2, -1, -2, jnp.inf, -jnp.inf, 'nuc'])
      for dtype in float_types + complex_types))  # type: ignore
  def testNorm(self, shape, dtype, ord, axis, keepdims):
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    if (ord in ('nuc', 2, -2) and (
        jtu.device_under_test() != "cpu" or
        (isinstance(axis, tuple) and len(axis) == 2))):
      raise unittest.SkipTest("No adequate SVD implementation available")

    args_maker = lambda: [rng(shape, dtype)]
    np_fn = partial(np.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
    jnp_fn = partial(jnp.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(jnp_fn, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_n={}_full_matrices={}_compute_uv={}".format(
           jtu.format_shape_dtype_string(b + (m, n), dtype), full_matrices,
           compute_uv),
       "b": b, "m": m, "n": n, "dtype": dtype, "full_matrices": full_matrices,
       "compute_uv": compute_uv}
      for b in [(), (3,), (2, 3)]
      for m in [0, 2, 7, 29, 53]
      for n in [0, 2, 7, 29, 53]
      for dtype in float_types + complex_types
      for full_matrices in [False, True]
      for compute_uv in [False, True]))
  def testSVD(self, b, m, n, dtype, full_matrices, compute_uv):
    if (jnp.issubdtype(dtype, np.complexfloating) and
        jtu.device_under_test() == "tpu"):
      raise unittest.SkipTest("No complex SVD implementation")
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    args_maker = lambda: [rng(b + (m, n), dtype)]

    # Norm, adjusted for dimension and type.
    def norm(x):
      norm = np.linalg.norm(x, axis=(-2, -1))
      return norm / (max(1, m, n) * jnp.finfo(dtype).eps)
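    # When compute_uv is set, the checks below verify that u * s @ vh
    # reconstructs a (using only the k = min(m, n) retained columns/rows when
    # full_matrices=True) and that the singular-vector matrices are unitary;
    # otherwise only the singular values are compared against numpy's.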
    a, = args_maker()
    out = jnp.linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
    if compute_uv:
      # Check the reconstructed matrices
      if full_matrices:
        k = min(m, n)
        if m < n:
          self.assertTrue(np.all(
              norm(a - np.matmul(out[1][..., None, :] * out[0], out[2][..., :k, :])) < 50))
        else:
          self.assertTrue(np.all(
              norm(a - np.matmul(out[1][..., None, :] * out[0][..., :, :k], out[2])) < 350))
      else:
        self.assertTrue(np.all(
            norm(a - np.matmul(out[1][..., None, :] * out[0], out[2])) < 350))

      # Check the unitary properties of the singular vector matrices.
      self.assertTrue(np.all(
          norm(np.eye(out[0].shape[-1]) - np.matmul(np.conj(T(out[0])), out[0])) < 15))
      if m >= n:
        self.assertTrue(np.all(
            norm(np.eye(out[2].shape[-1]) - np.matmul(np.conj(T(out[2])), out[2])) < 10))
      else:
        self.assertTrue(np.all(
            norm(np.eye(out[2].shape[-2]) - np.matmul(out[2], np.conj(T(out[2])))) < 20))

    else:
      self.assertTrue(np.allclose(np.linalg.svd(a, compute_uv=False),
                                  np.asarray(out), atol=1e-4, rtol=1e-4))

    self._CompileAndCheck(partial(jnp.linalg.svd, full_matrices=full_matrices,
                                  compute_uv=compute_uv),
                          args_maker)
    if not compute_uv:
      svd = partial(jnp.linalg.svd, full_matrices=full_matrices,
                    compute_uv=compute_uv)
      # TODO(phawkins): these tolerances seem very loose.
      if dtype == np.complex128:
        jtu.check_jvp(svd, partial(jvp, svd), (a,), rtol=1e-4, atol=1e-4,
                      eps=1e-8)
      else:
        jtu.check_jvp(svd, partial(jvp, svd), (a,), rtol=5e-2, atol=2e-1)

    if jtu.device_under_test() == "tpu":
      raise unittest.SkipTest("TPU matmul does not have enough precision")
    # TODO(frederikwilde): Find the appropriate precision to use for this test
    # on TPUs.

    if compute_uv and (not full_matrices):
      b, = args_maker()
      def f(x):
        u, s, v = jnp.linalg.svd(
            a + x * b,
            full_matrices=full_matrices,
            compute_uv=compute_uv)
        vdiag = jnp.vectorize(jnp.diag, signature='(k)->(k,k)')
        return jnp.matmul(jnp.matmul(u, vdiag(s)), v).real
      _, t_out = jvp(f, (1.,), (1.,))
      if dtype == np.complex128:
        atol = 1e-13
      else:
        atol = 5e-4
      self.assertArraysAllClose(t_out, b.real, atol=atol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_fullmatrices={}".format(
           jtu.format_shape_dtype_string(shape, dtype), full_matrices),
       "shape": shape, "dtype": dtype, "full_matrices": full_matrices}
      for shape in [(1, 1), (3, 3), (3, 4), (2, 10, 5), (2, 200, 100)]
      for dtype in float_types + complex_types
      for full_matrices in [False, True]))
  def testQr(self, shape, dtype, full_matrices):
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    m, n = shape[-2:]
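    # numpy's qr modes: "complete" returns q with shape (..., m, m) and r with
    # shape (..., m, n); "reduced" returns q with shape (..., m, k) and r with
    # shape (..., k, n), where k = min(m, n).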
    if full_matrices:
      mode, k = "complete", m
    else:
      mode, k = "reduced", min(m, n)

    a = rng(shape, dtype)
    lq, lr = jnp.linalg.qr(a, mode=mode)

    # np.linalg.qr doesn't support batch dimensions. But it seems like an
    # inevitable extension so we support it in our version.
    nq = np.zeros(shape[:-2] + (m, k), dtype)
    nr = np.zeros(shape[:-2] + (k, n), dtype)
    for index in np.ndindex(*shape[:-2]):
      nq[index], nr[index] = np.linalg.qr(a[index], mode=mode)

    max_rank = max(m, n)

    # Norm, adjusted for dimension and type.
    def norm(x):
      n = np.linalg.norm(x, axis=(-2, -1))
      return n / (max_rank * jnp.finfo(dtype).eps)

    def compare_orthogonal(q1, q2):
      # Q is unique up to a per-column sign (a phase, in the complex case),
      # so normalize those out before comparing.
      sum_of_ratios = np.sum(np.divide(q1, q2), axis=-2, keepdims=True)
      phases = np.divide(sum_of_ratios, np.abs(sum_of_ratios))
      q1 *= phases
      self.assertTrue(np.all(norm(q1 - q2) < 30))

    # Check a ~= qr
    self.assertTrue(np.all(norm(a - np.matmul(lq, lr)) < 30))

    # Compare the first 'k' vectors of Q; the remainder form an arbitrary
    # orthonormal basis for the null space.
    compare_orthogonal(nq[..., :k], lq[..., :k])

    # Check that q is close to unitary.
    self.assertTrue(np.all(
        norm(np.eye(k) - np.matmul(np.conj(T(lq)), lq)) < 5))

    if not full_matrices and m >= n:
      jtu.check_jvp(jnp.linalg.qr, partial(jvp, jnp.linalg.qr), (a,), atol=3e-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(
           jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(10, 4, 5), (5, 3, 3), (7, 6, 4)]
      for dtype in float_types + complex_types))
  def testQrBatching(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    args = rng(shape, jnp.float32)
    qs, rs = vmap(jsp.linalg.qr)(args)
    self.assertTrue(np.all(np.linalg.norm(args - np.matmul(qs, rs)) < 1e-3))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_pnorm={}".format(jtu.format_shape_dtype_string(shape, dtype),
                                   pnorm),
       "shape": shape, "pnorm": pnorm, "dtype": dtype}
      for shape in [(1, 1), (4, 4), (2, 3, 5), (5, 5, 5), (20, 20), (5, 10)]
      for pnorm in [jnp.inf, -jnp.inf, 1, -1, 2, -2, 'fro']
      for dtype in float_types + complex_types))
  @jtu.skip_on_devices("gpu")  # TODO(#2203): numerical errors
  def testCond(self, shape, pnorm, dtype):
    jtu.skip_if_unsupported_type(dtype)
    if (jnp.issubdtype(dtype, np.complexfloating) and
        jtu.device_under_test() == "tpu"):
      raise unittest.SkipTest("No complex SVD implementation")

    def gen_mat():
      # arr_gen = jtu.rand_some_nan(self.rng())
      arr_gen = jtu.rand_default(self.rng())
      res = arr_gen(shape, dtype)
      return res

    def args_gen(p):
      def _args_gen():
        return [gen_mat(), p]
      return _args_gen

    args_maker = args_gen(pnorm)
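    # Condition numbers for norms other than +/-2 require inverting the
    # matrix, which is only defined for square inputs, so numpy (and jax)
    # raise LinAlgError for non-square shapes there; the 2-norm versions go
    # through the SVD and work for any shape.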
    if pnorm not in [2, -2] and len(set(shape[-2:])) != 1:
      with self.assertRaises(np.linalg.LinAlgError):
        jnp.linalg.cond(*args_maker())
    else:
      self._CheckAgainstNumpy(np.linalg.cond, jnp.linalg.cond, args_maker,
                              check_dtypes=False, tol=1e-3)
      partial_norm = partial(jnp.linalg.cond, p=pnorm)
      self._CompileAndCheck(partial_norm, lambda: [gen_mat()],
                            check_dtypes=False, rtol=1e-03, atol=1e-03)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(1, 1), (4, 4), (200, 200), (7, 7, 7, 7)]
      for dtype in float_types))
  def testTensorinv(self, shape, dtype):
    jtu.skip_if_unsupported_type(dtype)
    rng = jtu.rand_default(self.rng())

    def tensor_maker():
      invertible = False
      while not invertible:
        a = rng(shape, dtype)
        try:
          np.linalg.inv(a)
          invertible = True
        except np.linalg.LinAlgError:
          pass
      return a

    args_maker = lambda: [tensor_maker(), int(np.floor(len(shape) / 2))]
    self._CheckAgainstNumpy(np.linalg.tensorinv, jnp.linalg.tensorinv,
                            args_maker, check_dtypes=False, tol=1e-3)
    partial_inv = partial(jnp.linalg.tensorinv, ind=int(np.floor(len(shape) / 2)))
    self._CompileAndCheck(partial_inv, lambda: [tensor_maker()],
                          check_dtypes=False, rtol=1e-03, atol=1e-03)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs={}_rhs={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype)),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype}
      for lhs_shape, rhs_shape in [
          ((1, 1), (1, 1)),
          ((4, 4), (4,)),
          ((8, 8), (8, 4)),
          ((1, 2, 2), (3, 2)),
          ((2, 1, 3, 3), (1, 4, 3, 4)),
          ((1, 0, 0), (1, 0, 2)),
      ]
      for dtype in float_types + complex_types))
  def testSolve(self, lhs_shape, rhs_shape, dtype):
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

    self._CheckAgainstNumpy(np.linalg.solve, jnp.linalg.solve, args_maker,
                            tol=1e-3)
    self._CompileAndCheck(jnp.linalg.solve, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (5, 5, 5), (0, 0)]
      for dtype in float_types))
  def testInv(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    if jtu.device_under_test() == "gpu" and shape == (200, 200):
      raise unittest.SkipTest("Test is flaky on GPU")
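    # Rejection-sample until numpy accepts the matrix as invertible, so that
    # comparing against np.linalg.inv below is well defined.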
    def args_maker():
      invertible = False
      while not invertible:
        a = rng(shape, dtype)
        try:
          np.linalg.inv(a)
          invertible = True
        except np.linalg.LinAlgError:
          pass
      return [a]

    self._CheckAgainstNumpy(np.linalg.inv, jnp.linalg.inv, args_maker,
                            tol=1e-3)
    self._CompileAndCheck(jnp.linalg.inv, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(1, 1), (4, 4), (2, 70, 7), (2000, 7), (7, 1000), (70, 7, 2),
                    (2, 0, 0), (3, 0, 2), (1, 0)]
      for dtype in float_types + complex_types))
  def testPinv(self, shape, dtype):
    if (jnp.issubdtype(dtype, np.complexfloating) and
        jtu.device_under_test() == "tpu"):
      raise unittest.SkipTest("No complex SVD implementation")
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    args_maker = lambda: [rng(shape, dtype)]

    self._CheckAgainstNumpy(np.linalg.pinv, jnp.linalg.pinv, args_maker,
                            tol=1e-2)
    self._CompileAndCheck(jnp.linalg.pinv, args_maker)
    if jtu.device_under_test() != "tpu":
      # TODO(phawkins): 1e-1 seems like a very loose tolerance.
      jtu.check_grads(jnp.linalg.pinv, args_maker(), 2, rtol=1e-1, atol=2e-1)

  def testPinvGradIssue2792(self):
    def f(p):
      a = jnp.array([[0., 0.], [-p, 1.]], jnp.float32) * 1 / (1 + p ** 2)
      return jnp.linalg.pinv(a)
    j = jax.jacobian(f)(jnp.float32(2.))
    self.assertAllClose(jnp.array([[0., -1.], [0., 0.]], jnp.float32), j)

    expected = jnp.array([[[[-1., 0.], [0., 0.]], [[0., -1.], [0., 0.]]],
                          [[[0., 0.], [-1., 0.]], [[0., 0.], [0., -1.]]]],
                         dtype=jnp.float32)
    self.assertAllClose(
        expected, jax.jacobian(jnp.linalg.pinv)(jnp.eye(2, dtype=jnp.float32)))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_n={}".format(
          jtu.format_shape_dtype_string(shape, dtype), n),
       "shape": shape, "dtype": dtype, "n": n}
      for shape in [(1, 1), (2, 2), (4, 4), (5, 5),
                    (1, 2, 2), (2, 3, 3), (2, 5, 5)]
      for dtype in float_types + complex_types
      for n in [-5, -2, -1, 0, 1, 2, 3, 4, 5, 10]))
  @jtu.skip_on_devices("tpu")  # TODO(b/149870255): Bug in XLA:TPU?
  def testMatrixPower(self, shape, dtype, n):
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    args_maker = lambda: [rng(shape, dtype)]
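    # n spans negative, zero, and positive exponents; following the
    # np.linalg.matrix_power semantics, negative n inverts the matrix first
    # and then raises the inverse to abs(n).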
    tol = 1e-1 if jtu.device_under_test() == "tpu" else 1e-3
    self._CheckAgainstNumpy(partial(np.linalg.matrix_power, n=n),
                            partial(jnp.linalg.matrix_power, n=n),
                            args_maker, tol=tol)
    self._CompileAndCheck(partial(jnp.linalg.matrix_power, n=n), args_maker,
                          rtol=1e-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(
          jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(3,), (1, 2), (8, 5), (4, 4), (5, 5), (50, 50)]
      for dtype in float_types + complex_types))
  def testMatrixRank(self, shape, dtype):
    if (jnp.issubdtype(dtype, np.complexfloating) and
        jtu.device_under_test() == "tpu"):
      raise unittest.SkipTest("No complex SVD implementation")
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(np.linalg.matrix_rank, jnp.linalg.matrix_rank,
                            args_maker, check_dtypes=False, tol=1e-3)
    self._CompileAndCheck(jnp.linalg.matrix_rank, args_maker,
                          check_dtypes=False, rtol=1e-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shapes={}".format(
          ','.join(jtu.format_shape_dtype_string(s, dtype) for s in shapes)),
       "shapes": shapes, "dtype": dtype}
      for shapes in [
          [(3,), (3, 1)],  # quick-out codepath
          [(1, 3), (3, 5), (5, 2)],  # multi_dot_three codepath
          [(1, 3), (3, 5), (5, 2), (2, 7), (7,)],  # dynamic programming codepath
      ]
      for dtype in float_types + complex_types))
  def testMultiDot(self, shapes, dtype):
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    args_maker = lambda: [[rng(shape, dtype) for shape in shapes]]

    np_fun = np.linalg.multi_dot
    jnp_fun = partial(jnp.linalg.multi_dot, precision=lax.Precision.HIGHEST)
    tol = {np.float32: 1e-4, np.float64: 1e-10,
           np.complex64: 1e-4, np.complex128: 1e-10}

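    # multi_dot picks the association order that minimizes scalar multiplies.
    # For the chain (1, 3) @ (3, 5) @ (5, 2) above, multiplying left-to-right
    # costs 1*3*5 + 1*5*2 = 25 multiplies versus 3*5*2 + 1*3*2 = 36
    # right-to-left.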
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol)
    self._CompileAndCheck(jnp_fun, args_maker,
                          atol=tol, rtol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs={}_rhs={}_lowrank={}_rcond={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype),
           lowrank, rcond),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "lowrank": lowrank, "rcond": rcond}
      for lhs_shape, rhs_shape in [
          ((1, 1), (1, 1)),
          ((4, 6), (4,)),
          ((6, 6), (6, 1)),
          ((8, 6), (8, 4)),
      ]
      for lowrank in [True, False]
      for rcond in [-1, None, 0.5]
      for dtype in float_types + complex_types))
  @jtu.skip_on_devices("tpu")  # SVD not implemented on TPU.
  @jtu.skip_on_devices("cpu", "gpu")  # TODO(jakevdp): Test fails numerically
  def testLstsq(self, lhs_shape, rhs_shape, dtype, lowrank, rcond):
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    np_fun = partial(np.linalg.lstsq, rcond=rcond)
    jnp_fun = partial(jnp.linalg.lstsq, rcond=rcond)
    jnp_fun_numpy_resid = partial(jnp.linalg.lstsq, rcond=rcond, numpy_resid=True)
    tol = {np.float32: 1e-6, np.float64: 1e-12,
           np.complex64: 1e-6, np.complex128: 1e-12}

    def args_maker():
      lhs = rng(lhs_shape, dtype)
      if lowrank and lhs_shape[1] > 1:
        # Make the last column a linear combination of the others so that the
        # matrix is rank-deficient.
        lhs[:, -1] = lhs[:, :-1].mean(1)
      return [lhs, rng(rhs_shape, dtype)]

    self._CheckAgainstNumpy(np_fun, jnp_fun_numpy_resid, args_maker,
                            check_dtypes=False, tol=tol)
    self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol)

    # Disabled because grad is flaky for low-rank inputs.
    # TODO:
    # jtu.check_grads(lambda *args: jnp_fun(*args)[0], args_maker(), order=2,
    #                 atol=1e-2, rtol=1e-2)

  # Regression test for incorrect type for eigenvalues of a complex matrix.
  @jtu.skip_on_devices("tpu")  # TODO(phawkins): No complex eigh implementation on TPU.
  def testIssue669(self):
    def test(x):
      val, vec = jnp.linalg.eigh(x)
      return jnp.real(jnp.sum(val))

    grad_test_jc = jit(grad(jit(test)))
    xc = np.eye(3, dtype=np.complex64)
    self.assertAllClose(xc, grad_test_jc(xc))

  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def testIssue1151(self):
    rng = self.rng()
    A = jnp.array(rng.randn(100, 3, 3), dtype=jnp.float32)
    b = jnp.array(rng.randn(100, 3), dtype=jnp.float32)
    x = jnp.linalg.solve(A, b)
    self.assertAllClose(vmap(jnp.dot)(A, x), b, atol=2e-3, rtol=1e-2)

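    # Smoke-test differentiation of solve with respect to each argument, in
    # both batched and unbatched form; the values are discarded, and we only
    # require that tracing and evaluation succeed without error.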
    _ = jax.jacobian(jnp.linalg.solve, argnums=0)(A, b)
    _ = jax.jacobian(jnp.linalg.solve, argnums=1)(A, b)

    _ = jax.jacobian(jnp.linalg.solve, argnums=0)(A[0], b[0])
    _ = jax.jacobian(jnp.linalg.solve, argnums=1)(A[0], b[0])

  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def testIssue1383(self):
    seed = jax.random.PRNGKey(0)
    tmp = jax.random.uniform(seed, (2, 2))
    a = jnp.dot(tmp, tmp.T)

    def f(inp):
      val, vec = jnp.linalg.eigh(inp)
      return jnp.dot(jnp.dot(vec, inp), vec.T)

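    # Issue 1383 reported NaNs in higher-order forward-mode derivatives through
    # eigh; stacking jacfwd three times reproduces that configuration.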
    grad_func = jax.jacfwd(f)
    hess_func = jax.jacfwd(grad_func)
    cube_func = jax.jacfwd(hess_func)
    self.assertFalse(np.any(np.isnan(cube_func(a))))


class ScipyLinalgTest(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_i={}".format(i), "args": args}
      for i, args in enumerate([
          (),
          (1,),
          (7, -2),
          (3, 4, 5),
          (np.ones((3, 4), dtype=jnp.float_), 5,
           np.random.randn(5, 2).astype(jnp.float_)),
      ])))
  def testBlockDiag(self, args):
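    # scipy's block_diag treats bare scalars as 1x1 blocks, so the cases above
    # mix scalars with explicit 2-D arrays.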
    args_maker = lambda: args
    self._CheckAgainstNumpy(osp.linalg.block_diag, jsp.linalg.block_diag,
                            args_maker)
    self._CompileAndCheck(jsp.linalg.block_diag, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(1, 1), (4, 5), (10, 5), (50, 50)]
      for dtype in float_types + complex_types))
  def testLu(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    args_maker = lambda: [rng(shape, dtype)]
    x, = args_maker()
    p, l, u = jsp.linalg.lu(x)
    self.assertAllClose(x, np.matmul(p, np.matmul(l, u)),
                        rtol={np.float32: 1e-3, np.float64: 1e-12,
                              np.complex64: 1e-3, np.complex128: 1e-12})
    self._CompileAndCheck(jsp.linalg.lu, args_maker)

  def testLuOfSingularMatrix(self):
    # A singular matrix still has an LU factorization; check that
    # P @ L @ U reproduces the input exactly.
    x = jnp.array([[-1., 3./2], [2./3, -1.]], dtype=np.float32)
    p, l, u = jsp.linalg.lu(x)
    self.assertAllClose(x, np.matmul(p, np.matmul(l, u)))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(1, 1), (4, 5), (10, 5), (10, 10), (6, 7, 7)]
      for dtype in float_types + complex_types))
  @jtu.skip_on_devices("tpu")  # TODO(phawkins): precision problems on TPU.
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def testLuGrad(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    a = rng(shape, dtype)
    lu = vmap(jsp.linalg.lu) if len(shape) > 2 else jsp.linalg.lu
    jtu.check_grads(lu, (a,), 2, atol=5e-2, rtol=3e-1)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(4, 5), (6, 5)]
      for dtype in [jnp.float32]))
  def testLuBatching(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    args = [rng(shape, jnp.float32) for _ in range(10)]
    expected = list(osp.linalg.lu(x) for x in args)
    ps = np.stack([out[0] for out in expected])
    ls = np.stack([out[1] for out in expected])
    us = np.stack([out[2] for out in expected])

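    # vmap over the stacked leading axis should match a Python loop of
    # per-matrix scipy factorizations.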
    actual_ps, actual_ls, actual_us = vmap(jsp.linalg.lu)(jnp.stack(args))
    self.assertAllClose(ps, actual_ps)
    self.assertAllClose(ls, actual_ls, rtol=5e-6)
    self.assertAllClose(us, actual_us)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_n={}".format(jtu.format_shape_dtype_string((n, n), dtype)),
       "n": n, "dtype": dtype}
      for n in [1, 4, 5, 200]
      for dtype in float_types + complex_types))
  def testLuFactor(self, n, dtype):
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    args_maker = lambda: [rng((n, n), dtype)]

    x, = args_maker()
    lu, piv = jsp.linalg.lu_factor(x)
    l = np.tril(lu, -1) + np.eye(n, dtype=dtype)
    u = np.triu(lu)
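    # LAPACK pivot convention: piv[i] records that row i was swapped with row
    # piv[i], so applying the swaps in order permutes x the same way lu_factor
    # did internally before comparing against l @ u.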
    for i in range(n):
      x[[i, piv[i]],] = x[[piv[i], i],]
    self.assertAllClose(x, np.matmul(l, u), rtol=1e-3,
                        atol=1e-3)
    self._CompileAndCheck(jsp.linalg.lu_factor, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs={}_rhs={}_trans={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype),
           trans),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "trans": trans}
      for lhs_shape, rhs_shape in [
          ((1, 1), (1, 1)),
          ((4, 4), (4,)),
          ((8, 8), (8, 4)),
      ]
      for trans in [0, 1, 2]
      for dtype in float_types + complex_types))
  @jtu.skip_on_devices("cpu")  # TODO(frostig): Test fails on CPU sometimes
  def testLuSolve(self, lhs_shape, rhs_shape, dtype, trans):
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
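    # trans follows scipy's lu_solve convention: 0 solves a x = b, 1 solves
    # a^T x = b, and 2 solves a^H x = b, e.g.:
    #   x = jsp.linalg.lu_solve((lu, piv), b, trans=1)  # solves a.T @ x = b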
    osp_fun = lambda lu, piv, rhs: osp.linalg.lu_solve((lu, piv), rhs, trans=trans)
    jsp_fun = lambda lu, piv, rhs: jsp.linalg.lu_solve((lu, piv), rhs, trans=trans)

    def args_maker():
      a = rng(lhs_shape, dtype)
      lu, piv = osp.linalg.lu_factor(a)
      return [lu, piv, rng(rhs_shape, dtype)]

    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=1e-3)
    self._CompileAndCheck(jsp_fun, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs={}_rhs={}_sym_pos={}_lower={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype),
           sym_pos, lower),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "sym_pos": sym_pos, "lower": lower}
      for lhs_shape, rhs_shape in [
          ((1, 1), (1, 1)),
          ((4, 4), (4,)),
          ((8, 8), (8, 4)),
      ]
      for sym_pos, lower in [
          (False, False),
          (True, False),
          (True, True),
      ]
      for dtype in float_types + complex_types))
  def testSolve(self, lhs_shape, rhs_shape, dtype, sym_pos, lower):
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    osp_fun = lambda lhs, rhs: osp.linalg.solve(lhs, rhs, sym_pos=sym_pos, lower=lower)
    jsp_fun = lambda lhs, rhs: jsp.linalg.solve(lhs, rhs, sym_pos=sym_pos, lower=lower)

    def args_maker():
      a = rng(lhs_shape, dtype)
      if sym_pos:
        # Form a Hermitian positive-definite matrix, then keep only the
        # triangle that the sym_pos/lower flags say the solver will read.
        a = np.matmul(a, np.conj(T(a)))
        a = np.tril(a) if lower else np.triu(a)
      return [a, rng(rhs_shape, dtype)]

    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=1e-3)
    self._CompileAndCheck(jsp_fun, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs={}_rhs={}_lower={}_transposea={}_unit_diagonal={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype),
           lower, transpose_a, unit_diagonal),
       "lower": lower, "transpose_a": transpose_a,
       "unit_diagonal": unit_diagonal, "lhs_shape": lhs_shape,
       "rhs_shape": rhs_shape, "dtype": dtype}
      for lower in [False, True]
      for transpose_a in [False, True]
      for unit_diagonal in [False, True]
      for lhs_shape, rhs_shape in [
          ((4, 4), (4,)),
          ((4, 4), (4, 3)),
          ((2, 8, 8), (2, 8, 10)),
      ]
      for dtype in float_types))
  def testSolveTriangular(self, lower, transpose_a, unit_diagonal, lhs_shape,
                          rhs_shape, dtype):
    jtu.skip_if_unsupported_type(dtype)
    rng = jtu.rand_default(self.rng())
    k = rng(lhs_shape, dtype)
    l = np.linalg.cholesky(np.matmul(k, T(k))
                           + lhs_shape[-1] * np.eye(lhs_shape[-1]))
    l = l.astype(k.dtype)
    b = rng(rhs_shape, dtype)

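    # With unit_diagonal=True the solver ignores the stored diagonal, so the
    # reference matrix must be rebuilt with an explicit unit diagonal.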
    if unit_diagonal:
      a = np.tril(l, -1) + np.eye(lhs_shape[-1], dtype=dtype)
    else:
      a = l
    a = a if lower else T(a)

    inv = np.linalg.inv(T(a) if transpose_a else a).astype(a.dtype)
    if len(lhs_shape) == len(rhs_shape):
      np_ans = np.matmul(inv, b)
    else:
      np_ans = np.einsum("...ij,...j->...i", inv, b)

    # The standard scipy.linalg.solve_triangular doesn't support broadcasting.
    # But it seems like an inevitable extension so we support it.
    ans = jsp.linalg.solve_triangular(
        l if lower else T(l), b, trans=1 if transpose_a else 0, lower=lower,
        unit_diagonal=unit_diagonal)

    self.assertAllClose(np_ans, ans,
                        rtol={np.float32: 1e-4, np.float64: 1e-11})

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_A={}_B={}_lower={}_transposea={}_conja={}_unitdiag={}_leftside={}".format(
           jtu.format_shape_dtype_string(a_shape, dtype),
           jtu.format_shape_dtype_string(b_shape, dtype),
           lower, transpose_a, conjugate_a, unit_diagonal, left_side),
       "lower": lower, "transpose_a": transpose_a, "conjugate_a": conjugate_a,
       "unit_diagonal": unit_diagonal, "left_side": left_side,
       "a_shape": a_shape, "b_shape": b_shape, "dtype": dtype}
      for lower in [False, True]
      for unit_diagonal in [False, True]
      for dtype in float_types + complex_types
      for transpose_a in [False, True]
      for conjugate_a in (
          [False] if jnp.issubdtype(dtype, jnp.floating) else [False, True])
      for left_side, a_shape, b_shape in [
          (False, (4, 4), (4,)),
          (False, (4, 4), (1, 4)),
          (False, (3, 3), (4, 3)),
          (True, (4, 4), (4,)),
          (True, (4, 4), (4, 1)),
          (True, (4, 4), (4, 3)),
          (True, (2, 8, 8), (2, 8, 10)),
      ]))
  def testTriangularSolveGrad(
      self, lower, transpose_a, conjugate_a, unit_diagonal, left_side, a_shape,
      b_shape, dtype):
    jtu.skip_if_unsupported_type(dtype)
    rng = jtu.rand_default(self.rng())
    # Test lax.linalg.triangular_solve instead of scipy.linalg.solve_triangular
    # because it exposes more options.
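    # Adding 5 * I before taking the triangle keeps A well-conditioned, which
    # stabilizes the finite-difference comparisons inside check_grads.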
    A = jnp.tril(rng(a_shape, dtype) + 5 * np.eye(a_shape[-1], dtype=dtype))
    A = A if lower else T(A)
    B = rng(b_shape, dtype)
    f = partial(lax.linalg.triangular_solve, lower=lower, transpose_a=transpose_a,
                conjugate_a=conjugate_a, unit_diagonal=unit_diagonal,
                left_side=left_side)
    jtu.check_grads(f, (A, B), 2, rtol=4e-2, eps=1e-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_A={}_B={}_bdim={}_leftside={}".format(
           a_shape, b_shape, bdims, left_side),
       "left_side": left_side, "a_shape": a_shape, "b_shape": b_shape,
       "bdims": bdims}
      for left_side, a_shape, b_shape, bdims in [
          (False, (4, 4), (2, 3, 4), (None, 0)),
          (False, (2, 4, 4), (2, 2, 3, 4), (None, 0)),
          (False, (2, 4, 4), (3, 4), (0, None)),
          (False, (2, 4, 4), (2, 3, 4), (0, 0)),
          (True, (2, 4, 4), (2, 4, 3), (0, 0)),
          (True, (2, 4, 4), (2, 2, 4, 3), (None, 0)),
      ]))
  def testTriangularSolveBatching(self, left_side, a_shape, b_shape, bdims):
    rng = jtu.rand_default(self.rng())
    A = jnp.tril(rng(a_shape, np.float32)
                 + 5 * np.eye(a_shape[-1], dtype=np.float32))
    B = rng(b_shape, np.float32)
    solve = partial(lax.linalg.triangular_solve, lower=True, transpose_a=False,
                    conjugate_a=False, unit_diagonal=False, left_side=left_side)
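    # Rather than comparing against a reference solve, check the defining
    # residual of the solution: A @ X == B (left side) or X @ A == B (right).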
    X = vmap(solve, bdims)(A, B)
    matmul = partial(jnp.matmul, precision=lax.Precision.HIGHEST)
    Y = matmul(A, X) if left_side else matmul(X, A)
    np.testing.assert_allclose(Y - B, 0, atol=1e-4)

  def testTriangularSolveGradPrecision(self):
    rng = jtu.rand_default(self.rng())
    a = jnp.tril(rng((3, 3), np.float32))
    b = rng((1, 3), np.float32)
    jtu.assert_dot_precision(
        lax.Precision.HIGHEST,
        partial(jvp, lax.linalg.triangular_solve),
        (a, b),
        (a, b))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_n={}".format(jtu.format_shape_dtype_string((n, n), dtype)),
       "n": n, "dtype": dtype}
      for n in [1, 4, 5, 20, 50, 100]
      for dtype in float_types + complex_types))
  def testExpm(self, n, dtype):
    rng = jtu.rand_small(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    args_maker = lambda: [rng((n, n), dtype)]

    osp_fun = lambda a: osp.linalg.expm(a)
    jsp_fun = lambda a: jsp.linalg.expm(a)
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker)
    self._CompileAndCheck(jsp_fun, args_maker)

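    # upper_triangular=True asserts that the input is upper triangular, letting
    # expm use a cheaper algorithm; the result must still agree with scipy's
    # general-purpose expm.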
    args_maker_triu = lambda: [np.triu(rng((n, n), dtype))]
    jsp_fun_triu = lambda a: jsp.linalg.expm(a, upper_triangular=True)
    self._CheckAgainstNumpy(osp_fun, jsp_fun_triu, args_maker_triu)
    self._CompileAndCheck(jsp_fun_triu, args_maker_triu)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_n={}".format(jtu.format_shape_dtype_string((n, n), dtype)),
       "n": n, "dtype": dtype}
      for n in [1, 4, 5, 20, 50, 100]
      for dtype in float_types + complex_types))
  def testIssue2131(self, n, dtype):
    # Regression test for expm on the all-zeros matrix (expm(0) == I).
    args_maker_zeros = lambda: [np.zeros((n, n), dtype)]
    osp_fun = lambda a: osp.linalg.expm(a)
    jsp_fun = lambda a: jsp.linalg.expm(a)
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker_zeros)
    self._CompileAndCheck(jsp_fun, args_maker_zeros)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_lhs={}_rhs={}_lower={}".format(
          jtu.format_shape_dtype_string(lhs_shape, dtype),
          jtu.format_shape_dtype_string(rhs_shape, dtype),
          lower),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "lower": lower}
      for lhs_shape, rhs_shape in [
          [(1, 1), (1,)],
          [(4, 4), (4,)],
          [(4, 4), (4, 4)],
      ]
      for dtype in float_types
      for lower in [True, False]))
  def testChoSolve(self, lhs_shape, rhs_shape, dtype, lower):
    rng = jtu.rand_default(self.rng())
    jtu.skip_if_unsupported_type(dtype)
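    # cho_solve consumes the (factor, lower) pair that cho_factor would
    # produce; since only the indicated triangle is read, a random triangular
    # matrix stands in for a genuine Cholesky factor here.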
    def args_maker():
      b = rng(rhs_shape, dtype)
      if lower:
        L = np.tril(rng(lhs_shape, dtype))
        return [(L, lower), b]
      else:
        U = np.triu(rng(lhs_shape, dtype))
        return [(U, lower), b]
    self._CheckAgainstNumpy(osp.linalg.cho_solve, jsp.linalg.cho_solve,
                            args_maker, tol=1e-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_n={}".format(jtu.format_shape_dtype_string((n, n), dtype)),
       "n": n, "dtype": dtype}
      for n in [1, 4, 5, 20, 50, 100]
      for dtype in float_types + complex_types))
  def testExpmFrechet(self, n, dtype):
    rng = jtu.rand_small(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    if dtype == np.float64 or dtype == np.complex128:
      target_norms = [1.0e-2, 2.0e-1, 9.0e-1, 2.0, 3.0]
      # TODO(zhangqiaorjc): Reduce tol to default 1e-15.
      tol = {
          np.dtype(np.float64): 1e-14,
          np.dtype(np.complex128): 1e-14,
      }
    elif dtype == np.float32 or dtype == np.complex64:
      target_norms = [4.0e-1, 1.0, 3.0]
      tol = None
    else:
      raise TypeError("dtype={} is not supported.".format(dtype))
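    # Rescaling a to several target 1-norms exercises the different
    # scaling-and-squaring regimes inside expm_frechet.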
    for norm in target_norms:
      def args_maker():
        a = rng((n, n), dtype)
        a = a / np.linalg.norm(a, 1) * norm
        e = rng((n, n), dtype)
        return [a, e]

      # compute_expm=True: return both expm(a) and the Frechet derivative.
      osp_fun = lambda a, e: osp.linalg.expm_frechet(a, e, compute_expm=True)
      jsp_fun = lambda a, e: jsp.linalg.expm_frechet(a, e, compute_expm=True)
      self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker,
                              check_dtypes=False, tol=tol)
      self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=False)
      # compute_expm=False: return only the Frechet derivative.
      osp_fun = lambda a, e: osp.linalg.expm_frechet(a, e, compute_expm=False)
      jsp_fun = lambda a, e: jsp.linalg.expm_frechet(a, e, compute_expm=False)
      self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker,
                              check_dtypes=False, tol=tol)
      self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=False)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_n={}".format(jtu.format_shape_dtype_string((n, n), dtype)),
       "n": n, "dtype": dtype}
      for n in [1, 4, 5, 20, 50]
      for dtype in float_types + complex_types))
  def testExpmGrad(self, n, dtype):
    rng = jtu.rand_small(self.rng())
    jtu.skip_if_unsupported_type(dtype)
    a = rng((n, n), dtype)
    if dtype == np.float64 or dtype == np.complex128:
      target_norms = [1.0e-2, 2.0e-1, 9.0e-1, 2.0, 3.0]
    elif dtype == np.float32 or dtype == np.complex64:
      target_norms = [4.0e-1, 1.0, 3.0]
    else:
      raise TypeError("dtype={} is not supported.".format(dtype))
    # TODO(zhangqiaorjc): Reduce tol to default 1e-5.
    # Lower tolerance is due to 2nd order derivative.
    tol = {
        # Note that due to inner_product, float and complex tol are coupled.
        np.dtype(np.float32): 0.02,
        np.dtype(np.complex64): 0.02,
        np.dtype(np.float64): 1e-4,
        np.dtype(np.complex128): 1e-4,
    }
    for norm in target_norms:
      a = a / np.linalg.norm(a, 1) * norm
      def expm(x):
        return jsp.linalg.expm(x, upper_triangular=False, max_squarings=16)
      jtu.check_grads(expm, (a,), modes=["fwd", "rev"], order=1, atol=tol,
                      rtol=tol)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())