Merge pull request #13147 from hawkinsp:eyes

PiperOrigin-RevId: 486826532
This commit is contained in:
jax authors 2022-11-07 19:25:15 -08:00
commit 1e7e8e8d5c
2 changed files with 10 additions and 14 deletions

View File

@ -21,6 +21,7 @@ import warnings
import numpy as np
from typing_extensions import Literal
import jax
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.vectorize import vectorize
from jax._src import ad_util
@ -1551,6 +1552,7 @@ def _svd_abstract_eval(operand, *, full_matrices, compute_uv):
else:
raise NotImplementedError
@jax.default_matmul_precision("float32")
def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv):
A, = primals
dA, = tangents
@ -1563,7 +1565,7 @@ def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv):
Ut, V = _H(U), _H(Vt)
s_dim = s[..., None, :]
dS = jnp.matmul(jnp.matmul(Ut, dA), V)
dS = Ut @ dA @ V
ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))
if not compute_uv:
@ -1580,16 +1582,16 @@ def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv):
s_inv = 1 / (s + s_zeros) - s_zeros
s_inv_mat = jnp.vectorize(jnp.diag, signature='(k)->(k,k)')(s_inv)
dUdV_diag = .5 * (dS - _H(dS)) * s_inv_mat.astype(A.dtype)
dU = jnp.matmul(U, F.astype(A.dtype) * (dSS + _H(dSS)) + dUdV_diag)
dV = jnp.matmul(V, F.astype(A.dtype) * (SdS + _H(SdS)))
dU = U @ (F.astype(A.dtype) * (dSS + _H(dSS)) + dUdV_diag)
dV = V @ (F.astype(A.dtype) * (SdS + _H(SdS)))
m, n = A.shape[-2:]
if m > n:
I = lax.expand_dims(jnp.eye(m, dtype=A.dtype), range(U.ndim - 2))
dU = dU + jnp.matmul(I - jnp.matmul(U, Ut), jnp.matmul(dA, V)) / s_dim.astype(A.dtype)
dAV = dA @ V
dU = dU + (dAV - U @ Ut @ dAV) / s_dim.astype(A.dtype)
if n > m:
I = lax.expand_dims(jnp.eye(n, dtype=A.dtype), range(V.ndim - 2))
dV = dV + jnp.matmul(I - jnp.matmul(V, Vt), jnp.matmul(_H(dA), U)) / s_dim.astype(A.dtype)
dAHU = _H(dA) @ U
dV = dV + (dAHU - V @ Vt @ dAHU) / s_dim.astype(A.dtype)
return (s, U, Vt), (ds, dU, _H(dV))

View File

@ -15,7 +15,6 @@
"""Tests for the LAPAX linear algebra module."""
from functools import partial
import unittest
import numpy as np
import scipy
@ -478,11 +477,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
)
def testNorm(self, shape, dtype, ord, axis, keepdims):
rng = jtu.rand_default(self.rng())
if (ord in ('nuc', 2, -2) and (
jtu.device_under_test() != "cpu" or
(isinstance(axis, tuple) and len(axis) == 2))):
raise unittest.SkipTest("No adequate SVD implementation available")
args_maker = lambda: [rng(shape, dtype)]
np_fn = partial(np.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
jnp_fn = partial(jnp.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
@ -663,7 +657,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
phases = np.divide(sum_of_ratios, np.abs(sum_of_ratios))
q1 *= phases
nm = norm(q1 - q2)
self.assertTrue(np.all(nm < 140), msg=f"norm={np.amax(nm)}")
self.assertTrue(np.all(nm < 160), msg=f"norm={np.amax(nm)}")
# Check a ~= qr
self.assertTrue(np.all(norm(a - np.matmul(lq, lr)) < 40))