# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the LAPAX linear algebra module."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from functools import partial
import itertools
from unittest import SkipTest

import numpy as onp
import scipy as osp

from absl.testing import absltest
from absl.testing import parameterized

from jax import jvp
from jax import numpy as np
from jax import scipy as jsp
from jax import test_util as jtu
from jax.lib import xla_bridge
from jaxlib import lapack

from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

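# Transpose the last two axes of an array (a batched matrix transpose).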
T = lambda x: onp.swapaxes(x, -1, -2)
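# Note: xla_bridge.canonicalize_dtype folds 64-bit dtypes to their 32-bit
# counterparts unless 64-bit mode is enabled, so these sets reflect the
# dtypes the backend will actually use.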
def float_types():
  return {onp.dtype(xla_bridge.canonicalize_dtype(dtype))
          for dtype in [onp.float32, onp.float64]}


def complex_types():
  return {onp.dtype(xla_bridge.canonicalize_dtype(dtype))
          for dtype in [onp.complex64, onp.complex128]}


class NumpyLinalgTest(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype, "rng": rng}
      for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (1000, 0, 0)]
      for dtype in float_types()
      for rng in [jtu.rand_default()]))
  def testCholesky(self, shape, dtype, rng):
    def args_maker():
      factor_shape = shape[:-1] + (2 * shape[-1],)
      a = rng(factor_shape, dtype)
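      # a @ a^H is Hermitian positive semidefinite by construction, so it
      # always has a Cholesky factorization.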
      return [onp.matmul(a, np.conj(T(a)))]

    self._CheckAgainstNumpy(onp.linalg.cholesky, np.linalg.cholesky, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.cholesky, args_maker, check_dtypes=True)

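    # Only check gradients in double precision; finite-difference gradient
    # checks are too inaccurate in float32.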
    if onp.finfo(dtype).bits == 64:
      jtu.check_grads(np.linalg.cholesky, args_maker(), order=2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_n={}".format(jtu.format_shape_dtype_string((n,n), dtype)),
       "n": n, "dtype": dtype, "rng": rng}
      for n in [0, 4, 5, 25]  # TODO(mattjj): complex64 unstable on large sizes?
      for dtype in float_types() | complex_types()
      for rng in [jtu.rand_default()]))
  # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testDet(self, n, dtype, rng):
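    # lapack.jax_getrf wraps LAPACK's getrf (LU factorization); older jaxlib
    # builds may not provide it.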
    if not hasattr(lapack, "jax_getrf"):
      self.skipTest("No LU implementation available")
    args_maker = lambda: [rng((n, n), dtype)]

    self._CheckAgainstNumpy(onp.linalg.det, np.linalg.det, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.det, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_n={}".format(jtu.format_shape_dtype_string((n,n), dtype)),
       "n": n, "dtype": dtype, "rng": rng}
      for n in [0, 4, 10, 200]
      for dtype in float_types() | complex_types()
      for rng in [jtu.rand_default()]))
  @jtu.skip_on_devices("gpu", "tpu")
  def testSlogdet(self, n, dtype, rng):
    if not hasattr(lapack, "jax_getrf"):
      self.skipTest("No LU implementation available")
    args_maker = lambda: [rng((n, n), dtype)]

    self._CheckAgainstNumpy(onp.linalg.slogdet, np.linalg.slogdet, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.slogdet, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_n={}_lower={}".format(
           jtu.format_shape_dtype_string((n,n), dtype), lower),
       "n": n, "dtype": dtype, "lower": lower, "rng": rng}
      for n in [0, 4, 5, 50]
      for dtype in float_types() | complex_types()
      for lower in [False, True]
      for rng in [jtu.rand_default()]))
  # TODO(phawkins): enable when there is an eigendecomposition implementation
  # for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testEigh(self, n, dtype, lower, rng):
    if not hasattr(lapack, "jax_syevd"):
      self.skipTest("No symmetric eigendecomposition implementation available")
    args_maker = lambda: [rng((n, n), dtype)]

    uplo = "L" if lower else "U"

    # Norm, adjusted for dimension and type.
    def norm(x):
      norm = onp.linalg.norm(x, axis=(-2, -1))
      return norm / ((n + 1) * onp.finfo(dtype).eps)

    a, = args_maker()
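    # Symmetrize exactly so that the identities checked below hold; with
    # symmetrize_input=False, eigh reads only the triangle selected by UPLO.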
    a = (a + onp.conj(a.T)) / 2
    w, v = np.linalg.eigh(onp.tril(a) if lower else onp.triu(a),
                          UPLO=uplo, symmetrize_input=False)
    self.assertTrue(norm(onp.eye(n) - onp.matmul(onp.conj(T(v)), v)) < 5)
    self.assertTrue(norm(onp.matmul(a, v) - w * v) < 30)

    self._CompileAndCheck(partial(np.linalg.eigh, UPLO=uplo), args_maker,
                          check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_lower={}".format(jtu.format_shape_dtype_string(shape, dtype),
                                   lower),
       "shape": shape, "dtype": dtype, "rng": rng, "lower": lower}
      for shape in [(1, 1), (4, 4), (5, 5), (50, 50)]
      for dtype in float_types() | complex_types()
      for rng in [jtu.rand_default()]
      for lower in [True, False]))
  # TODO(phawkins): enable when there is an eigendecomposition implementation
  # for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testEighGrad(self, shape, dtype, rng, lower):
    if not hasattr(lapack, "jax_syevd"):
      self.skipTest("No symmetric eigendecomposition implementation available")
    uplo = "L" if lower else "U"
    a = rng(shape, dtype)
    a = (a + onp.conj(a.T)) / 2
    a = onp.tril(a) if lower else onp.triu(a)
    # Gradient checks will fail without symmetrization as the eigh jvp rule
    # is only correct for tangents in the symmetric subspace, whereas the
    # checker checks against unconstrained (co)tangents.
    if dtype not in complex_types():
      f = partial(np.linalg.eigh, UPLO=uplo, symmetrize_input=True)
    else:  # only check eigenvalue grads for complex matrices
      f = lambda a: np.linalg.eigh(a, UPLO=uplo, symmetrize_input=True)[0]
    jtu.check_grads(f, (a,), 2, rtol=1e-1)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_lower={}".format(jtu.format_shape_dtype_string(shape, dtype),
                                   lower),
       "shape": shape, "dtype": dtype, "rng": rng, "lower": lower, "eps": eps}
      for shape in [(1, 1), (4, 4), (5, 5), (50, 50)]
      for dtype in complex_types()
      for rng in [jtu.rand_default()]
      for lower in [True, False]
      for eps in [1e-4]))
  # TODO(phawkins): enable when there is an eigendecomposition implementation
  # for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testEighGradVectorComplex(self, shape, dtype, rng, lower, eps):
    # Special case to test for complex eigenvector grad correctness. Exact
    # eigenvector coordinate gradients are hard to test numerically for
    # complex eigensystem solvers, given the extra degrees of per-eigenvector
    # phase freedom. Instead, we numerically verify the eigensystem properties
    # on the perturbed eigenvectors; one only ever wants to optimize
    # eigenvector directions, not coordinates.
    if not hasattr(lapack, "jax_syevd"):
      self.skipTest("No symmetric eigendecomposition implementation available")
    uplo = "L" if lower else "U"
    a = rng(shape, dtype)
    a = (a + onp.conj(a.T)) / 2
    a = onp.tril(a) if lower else onp.triu(a)
    a_dot = eps * rng(shape, dtype)
    a_dot = (a_dot + onp.conj(a_dot.T)) / 2
    a_dot = onp.tril(a_dot) if lower else onp.triu(a_dot)
    # Evaluate the eigenvector gradient and the ground-truth eigensystem of
    # the perturbed input matrix.
    f = partial(np.linalg.eigh, UPLO=uplo)
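    # jvp returns the primal outputs (w, v) together with their directional
    # derivatives (dw, dv) along the tangent a_dot.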
    (w, v), (dw, dv) = jvp(f, primals=(a,), tangents=(a_dot,))
    new_a = a + a_dot
    new_w, new_v = f(new_a)
    new_a = (new_a + onp.conj(new_a.T)) / 2
    # Check that the Rayleigh quotients of the perturbed eigenvectors v + dv
    # against the perturbed matrix match the recomputed eigenvalues new_w, to
    # relative tolerance RTOL.
    RTOL = 1e-2
    assert onp.max(
      onp.abs((onp.diag(onp.dot(onp.conj((v+dv).T), onp.dot(new_a, (v+dv))))
               - new_w) / new_w)) < RTOL
    # Redundant with the above, but also check the eigenvector property
    # new_a @ (v + dv) ~= new_w * (v + dv) to the same relative tolerance.
    assert onp.max(
      onp.linalg.norm(onp.abs(new_w*(v+dv) - onp.dot(new_a, (v+dv))), axis=0) /
      onp.linalg.norm(onp.abs(new_w*(v+dv)), axis=0)
    ) < RTOL

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_ord={}_axis={}_keepdims={}".format(
          jtu.format_shape_dtype_string(shape, dtype), ord, axis, keepdims),
       "shape": shape, "dtype": dtype, "axis": axis, "keepdims": keepdims,
       "ord": ord, "rng": rng}
      for axis, shape in [
          (None, (1,)), (None, (7,)), (None, (5, 8)),
          (0, (9,)), (0, (4, 5)), ((1,), (10, 7, 3)), ((-2,), (4, 8)),
          (-1, (6, 3)), ((0, 2), (3, 4, 5)), ((2, 0), (7, 8, 9))]
      for keepdims in [False, True]
      for ord in (
          [None, 0, 1, 2, 3, -1, -2, -3, np.inf, -np.inf]
          if (axis is None and len(shape) == 1) or
             isinstance(axis, int) or
             (isinstance(axis, tuple) and len(axis) == 1)
          else [None, 'fro', 1, 2, -1, -2, np.inf, -np.inf, 'nuc'])
      for dtype in float_types() | complex_types()
      for rng in [jtu.rand_default()]))
  def testNorm(self, shape, dtype, ord, axis, keepdims, rng):
    # TODO(mattjj,phawkins): re-enable after checking internal tests
    self.skipTest("internal test failures")

    if (ord in ('nuc', 2, -2) and isinstance(axis, tuple) and len(axis) == 2 and
        (not FLAGS.jax_test_dut or not FLAGS.jax_test_dut.startswith("cpu") or
         len(shape) != 2)):
      raise SkipTest("No adequate SVD implementation available")

    args_maker = lambda: [rng(shape, dtype)]
    onp_fn = partial(onp.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
    np_fn = partial(np.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
    self._CheckAgainstNumpy(onp_fn, np_fn, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np_fn, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_n={}_full_matrices={}_compute_uv={}".format(
          jtu.format_shape_dtype_string((m, n), dtype), full_matrices, compute_uv),
       "m": m, "n": n, "dtype": dtype, "full_matrices": full_matrices,
       "compute_uv": compute_uv, "rng": rng}
      for m in [2, 7, 29, 53]
      for n in [2, 7, 29, 53]
      for dtype in float_types() | complex_types()
      for full_matrices in [False, True]
      for compute_uv in [False, True]
      for rng in [jtu.rand_default()]))
  @jtu.skip_on_devices("gpu", "tpu")
  def testSVD(self, m, n, dtype, full_matrices, compute_uv, rng):
    if not hasattr(lapack, "jax_gesdd"):
      self.skipTest("No singular value decomposition implementation available")

    args_maker = lambda: [rng((m, n), dtype)]

    # Norm, adjusted for dimension and type.
    def norm(x):
      norm = onp.linalg.norm(x, axis=(-2, -1))
      return norm / (max(m, n) * onp.finfo(dtype).eps)

    a, = args_maker()
    out = np.linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)

    if compute_uv:
      # Check the reconstructed matrices.
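      # out is (u, s, vh); multiplying u by s broadcasts the singular values
      # across columns, so the products below compute u @ diag(s) @ vh.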
      if full_matrices:
        k = min(m, n)
        if m < n:
          self.assertTrue(onp.all(
              norm(a - onp.matmul(out[1] * out[0], out[2][:k, :])) < 50))
        else:
          self.assertTrue(onp.all(
              norm(a - onp.matmul(out[1] * out[0][:, :k], out[2])) < 50))
      else:
        self.assertTrue(onp.all(
            norm(a - onp.matmul(out[1] * out[0], out[2])) < 50))

      # Check the unitary properties of the singular vector matrices.
      self.assertTrue(onp.all(
          norm(onp.eye(out[0].shape[1])
               - onp.matmul(onp.conj(T(out[0])), out[0])) < 10))
      if m >= n:
        self.assertTrue(onp.all(
            norm(onp.eye(out[2].shape[1])
                 - onp.matmul(onp.conj(T(out[2])), out[2])) < 10))
      else:
        self.assertTrue(onp.all(
            norm(onp.eye(out[2].shape[0])
                 - onp.matmul(out[2], onp.conj(T(out[2])))) < 20))

    else:
      self.assertTrue(onp.allclose(onp.linalg.svd(a, compute_uv=False),
                                   onp.asarray(out), atol=1e-4, rtol=1e-4))

    self._CompileAndCheck(partial(np.linalg.svd, full_matrices=full_matrices,
                                  compute_uv=compute_uv),
                          args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_fullmatrices={}".format(
          jtu.format_shape_dtype_string(shape, dtype), full_matrices),
       "shape": shape, "dtype": dtype, "full_matrices": full_matrices,
       "rng": rng}
      for shape in [(1, 1), (3, 4), (2, 10, 5), (2, 200, 100)]
      for dtype in float_types()
      for full_matrices in [False, True]
      for rng in [jtu.rand_default()]))
  @jtu.skip_on_devices("cpu")
  def testQr(self, shape, dtype, full_matrices, rng):
    m, n = shape[-2:]

    if full_matrices:
      mode, k = "complete", m
    else:
      mode, k = "reduced", min(m, n)

    a = rng(shape, dtype)
    lq, lr = np.linalg.qr(a, mode=mode)

    # onp.linalg.qr doesn't support broadcasting, but it seems like an
    # inevitable extension, so we support it in our version.
    nq = onp.zeros(shape[:-2] + (m, k), dtype)
    nr = onp.zeros(shape[:-2] + (k, n), dtype)
    for index in onp.ndindex(*shape[:-2]):
      nq[index], nr[index] = onp.linalg.qr(a[index], mode=mode)

    max_rank = max(m, n)

    # Norm, adjusted for dimension and type.
    def norm(x):
      n = onp.linalg.norm(x, axis=(-2, -1))
      return n / (max_rank * onp.finfo(dtype).eps)

    def compare_orthogonal(q1, q2):
      # Q is unique up to sign, so normalize the sign first.
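      # The per-column sign/phase factors are estimated from the ratio q1 / q2;
      # dividing the column sums by their magnitudes yields unit-modulus phases.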
      sum_of_ratios = onp.sum(onp.divide(q1, q2), axis=-2, keepdims=True)
      phases = onp.divide(sum_of_ratios, onp.abs(sum_of_ratios))
      q1 *= phases
      self.assertTrue(onp.all(norm(q1 - q2) < 30))

    # Check a ~= qr
    self.assertTrue(onp.all(norm(a - onp.matmul(lq, lr)) < 30))

    # Compare the first 'k' vectors of Q; the remainder form an arbitrary
    # orthonormal basis for the null space.
    compare_orthogonal(nq[..., :k], lq[..., :k])

    # Check that q is close to unitary.
    self.assertTrue(onp.all(norm(onp.eye(k) - onp.matmul(T(lq), lq)) < 5))

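    # The JVP check is restricted to the reduced factorization with m >= n,
    # where the QR decomposition is differentiable.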
    if not full_matrices and m >= n:
      jtu.check_jvp(np.linalg.qr, partial(jvp, np.linalg.qr), (a,))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs={}_rhs={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype)),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "rng": rng}
      for lhs_shape, rhs_shape in [
          ((1, 1), (1, 1)),
          ((4, 4), (4,)),
          ((8, 8), (8, 4)),
      ]
      for dtype in float_types() | complex_types()
      for rng in [jtu.rand_default()]))
  # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testSolve(self, lhs_shape, rhs_shape, dtype, rng):
    if not hasattr(lapack, "jax_getrf"):
      self.skipTest("No LU implementation available")
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

    self._CheckAgainstNumpy(onp.linalg.solve, np.linalg.solve, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.solve, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype, "rng": rng}
      for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (5, 5, 5)]
      for dtype in float_types()
      for rng in [jtu.rand_default()]))
  def testInv(self, shape, dtype, rng):
    def args_maker():
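      # Rejection-sample until numpy can invert the draw, guaranteeing a
      # nonsingular test input.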
      invertible = False
      while not invertible:
        a = rng(shape, dtype)
        try:
          onp.linalg.inv(a)
          invertible = True
        except onp.linalg.LinAlgError:
          pass
      return [a]

    self._CheckAgainstNumpy(onp.linalg.inv, np.linalg.inv, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.inv, args_maker, check_dtypes=True)


class ScipyLinalgTest(jtu.JaxTestCase):

  # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype, "rng": rng}
      for shape in [(1, 1), (4, 5), (10, 5), (50, 50)]
      for dtype in float_types() | complex_types()
      for rng in [jtu.rand_default()]))
  @jtu.skip_on_devices("gpu", "tpu")
  def testLu(self, shape, dtype, rng):
    args_maker = lambda: [rng(shape, dtype)]

    self._CheckAgainstNumpy(osp.linalg.lu, jsp.linalg.lu, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(jsp.linalg.lu, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype, "rng": rng}
      for shape in [(1, 1), (4, 5), (10, 5), (10, 10)]
      for dtype in float_types() | complex_types()
      for rng in [jtu.rand_default()]))
  # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testLuGrad(self, shape, dtype, rng):
    a = rng(shape, dtype)

    jtu.check_grads(jsp.linalg.lu, (a,), 2, rtol=1e-1)

  # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_n={}".format(jtu.format_shape_dtype_string((n,n), dtype)),
       "n": n, "dtype": dtype, "rng": rng}
      for n in [1, 4, 5, 200]
      for dtype in float_types() | complex_types()
      for rng in [jtu.rand_default()]))
  @jtu.skip_on_devices("gpu", "tpu")
  def testLuFactor(self, n, dtype, rng):
    args_maker = lambda: [rng((n, n), dtype)]

    self._CheckAgainstNumpy(osp.linalg.lu_factor, jsp.linalg.lu_factor,
                            args_maker, check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(jsp.linalg.lu_factor, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs={}_rhs={}_sym_pos={}_lower={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype),
           sym_pos, lower),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "sym_pos": sym_pos, "lower": lower, "rng": rng}
      for lhs_shape, rhs_shape in [
          ((1, 1), (1, 1)),
          ((4, 4), (4,)),
          ((8, 8), (8, 4)),
      ]
      for sym_pos, lower in [
          (False, False),
          (True, False),
          (True, True),
      ]
      for dtype in float_types() | complex_types()
      for rng in [jtu.rand_default()]))
  # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testSolve(self, lhs_shape, rhs_shape, dtype, sym_pos, lower, rng):
    if not hasattr(lapack, "jax_getrf"):
      self.skipTest("No LU implementation available")
    osp_fun = lambda lhs, rhs: osp.linalg.solve(lhs, rhs, sym_pos=sym_pos,
                                                lower=lower)
    jsp_fun = lambda lhs, rhs: jsp.linalg.solve(lhs, rhs, sym_pos=sym_pos,
                                                lower=lower)

    def args_maker():
      a = rng(lhs_shape, dtype)
      if sym_pos:
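        # scipy's solve with sym_pos=True reads only the triangle selected by
        # `lower`, so build a Hermitian positive definite matrix and zero out
        # the other triangle.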
        a = onp.matmul(a, onp.conj(T(a)))
        a = onp.tril(a) if lower else onp.triu(a)
      return [a, rng(rhs_shape, dtype)]

    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs={}_rhs={}_lower={}_transposea={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype),
           lower, transpose_a),
       "lower": lower, "transpose_a": transpose_a,
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "rng": rng}
      for lower, transpose_a in itertools.product([False, True], repeat=2)
      for lhs_shape, rhs_shape in [
          ((4, 4), (4,)),
          ((4, 4), (4, 3)),
          ((2, 8, 8), (2, 8, 10)),
      ]
      for dtype in float_types()
      for rng in [jtu.rand_default()]))
  def testSolveTriangular(self, lower, transpose_a, lhs_shape, rhs_shape, dtype,
                          rng):
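    # Build a well-conditioned triangular matrix: the Cholesky factor of
    # k @ k^T shifted by a multiple of the identity.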
    k = rng(lhs_shape, dtype)
    l = onp.linalg.cholesky(onp.matmul(k, T(k))
                            + lhs_shape[-1] * onp.eye(lhs_shape[-1]))
    l = l.astype(k.dtype)
    b = rng(rhs_shape, dtype)

    a = l if lower else T(l)
    inv = onp.linalg.inv(T(a) if transpose_a else a).astype(a.dtype)
    if len(lhs_shape) == len(rhs_shape):
      onp_ans = onp.matmul(inv, b)
    else:
      onp_ans = onp.einsum("...ij,...j->...i", inv, b)

    # The standard scipy.linalg.solve_triangular doesn't support broadcasting,
    # but it seems like an inevitable extension, so we support it.
    ans = jsp.linalg.solve_triangular(
        l if lower else T(l), b, trans=1 if transpose_a else 0, lower=lower)

    self.assertAllClose(onp_ans, ans, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs={}_rhs={}_lower={}_transposea={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype),
           lower, transpose_a),
       "lower": lower, "transpose_a": transpose_a,
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "rng": rng}
      for lower, transpose_a in itertools.product([False, True], repeat=2)
      for lhs_shape, rhs_shape in [
          ((4, 4), (4,)),
          ((4, 4), (4, 3)),
          ((2, 8, 8), (2, 8, 10)),
      ]
      for dtype in float_types()
      for rng in [jtu.rand_default()]))
  def testSolveTriangularGrad(self, lower, transpose_a, lhs_shape,
                              rhs_shape, dtype, rng):
    # TODO(frostig): change ensemble to support a bigger rtol
    self.skipTest("rtol does not cover all devices and precision modes")
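    # The diagonal shift of 5 keeps the triangular system well conditioned,
    # which stabilizes the finite-difference gradient check.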
    A = np.tril(rng(lhs_shape, dtype) + 5 * onp.eye(lhs_shape[-1], dtype=dtype))
    A = A if lower else T(A)
    B = rng(rhs_shape, dtype)
    f = partial(jsp.linalg.solve_triangular, lower=lower,
                trans=1 if transpose_a else 0)
    jtu.check_grads(f, (A, B), 2, rtol=1e-3)

if __name__ == "__main__":
|
|
|
|
absltest.main()
|