Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00
Merge pull request #13147 from hawkinsp:eyes

PiperOrigin-RevId: 486826532

Commit 1e7e8e8d5c
jax/_src/lax/linalg.py

@@ -21,6 +21,7 @@ import warnings
 import numpy as np
 from typing_extensions import Literal
 
+import jax
 from jax._src.numpy import lax_numpy as jnp
 from jax._src.numpy.vectorize import vectorize
 from jax._src import ad_util
@@ -1551,6 +1552,7 @@ def _svd_abstract_eval(operand, *, full_matrices, compute_uv):
   else:
     raise NotImplementedError
 
+@jax.default_matmul_precision("float32")
 def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv):
   A, = primals
   dA, = tangents
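Note on the decorator added above: jax.default_matmul_precision overrides the backend's default matrix-multiply precision for everything inside the decorated function (TPUs, and some GPU paths, otherwise use faster low-precision matmuls). A minimal sketch of the same mechanism on a made-up function, not part of this commit:

    import jax
    import jax.numpy as jnp

    # Hypothetical example: pin matmul precision to float32 inside one
    # function, leaving global precision settings untouched.
    @jax.default_matmul_precision("float32")
    def gram(a):
      return a.T @ a  # runs at float32 precision even where bf16 is the default

    # The same control is available as a context manager:
    # with jax.default_matmul_precision("float32"):
    #   g = a.T @ a

Here it presumably keeps the SVD JVP rule accurate regardless of the caller's precision configuration.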
@@ -1563,7 +1565,7 @@ def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv):
 
   Ut, V = _H(U), _H(Vt)
   s_dim = s[..., None, :]
-  dS = jnp.matmul(jnp.matmul(Ut, dA), V)
+  dS = Ut @ dA @ V
   ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))
 
   if not compute_uv:
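The jnp.matmul calls are replaced with Python's @ operator throughout; on jax arrays the two dispatch to the same operation, so this part is a readability change. An illustrative check (array contents are arbitrary):

    import numpy as np
    import jax.numpy as jnp

    x = jnp.asarray(np.random.default_rng(0).standard_normal((2, 3)))
    y = jnp.asarray(np.random.default_rng(1).standard_normal((3, 4)))
    # @ on jax arrays is implemented in terms of matmul.
    assert jnp.allclose(x @ y, jnp.matmul(x, y))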
@@ -1580,16 +1582,16 @@ def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv):
   s_inv = 1 / (s + s_zeros) - s_zeros
   s_inv_mat = jnp.vectorize(jnp.diag, signature='(k)->(k,k)')(s_inv)
   dUdV_diag = .5 * (dS - _H(dS)) * s_inv_mat.astype(A.dtype)
-  dU = jnp.matmul(U, F.astype(A.dtype) * (dSS + _H(dSS)) + dUdV_diag)
-  dV = jnp.matmul(V, F.astype(A.dtype) * (SdS + _H(SdS)))
+  dU = U @ (F.astype(A.dtype) * (dSS + _H(dSS)) + dUdV_diag)
+  dV = V @ (F.astype(A.dtype) * (SdS + _H(SdS)))
 
   m, n = A.shape[-2:]
   if m > n:
-    I = lax.expand_dims(jnp.eye(m, dtype=A.dtype), range(U.ndim - 2))
-    dU = dU + jnp.matmul(I - jnp.matmul(U, Ut), jnp.matmul(dA, V)) / s_dim.astype(A.dtype)
+    dAV = dA @ V
+    dU = dU + (dAV - U @ Ut @ dAV) / s_dim.astype(A.dtype)
   if n > m:
-    I = lax.expand_dims(jnp.eye(n, dtype=A.dtype), range(V.ndim - 2))
-    dV = dV + jnp.matmul(I - jnp.matmul(V, Vt), jnp.matmul(_H(dA), U)) / s_dim.astype(A.dtype)
+    dAHU = _H(dA) @ U
+    dV = dV + (dAHU - V @ Vt @ dAHU) / s_dim.astype(A.dtype)
 
   return (s, U, Vt), (ds, dU, _H(dV))
 
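The rewritten correction terms drop the explicit identity matrices: the old code built a broadcast m-by-m (or n-by-n) eye just to form the projector I - U @ Ut, while the new code expands (I - U @ Ut) @ (dA @ V) into dAV - U @ Ut @ dAV. A small numerical check of that identity, written against the public jax.numpy API with hypothetical sizes:

    import numpy as np
    import jax.numpy as jnp

    rng = np.random.default_rng(0)
    m, n = 5, 3  # tall case, m > n
    A = jnp.asarray(rng.standard_normal((m, n)))
    dA = jnp.asarray(rng.standard_normal((m, n)))

    U, s, Vt = jnp.linalg.svd(A, full_matrices=False)  # U: (m, n), Vt: (n, n)
    V, Ut = Vt.T, U.T

    # Old form: materializes an m-by-m identity and projector.
    old = (jnp.eye(m, dtype=A.dtype) - U @ Ut) @ (dA @ V) / s[None, :]
    # New form: same value, no identity matrix constructed.
    dAV = dA @ V
    new = (dAV - U @ Ut @ dAV) / s[None, :]
    assert jnp.allclose(old, new, atol=1e-5)

Since @ is left-associative, grouping the new term as U @ (Ut @ dAV) would additionally keep every intermediate tall-and-skinny (O(mn^2) rather than O(m^2 n) work for m >> n); as written, the win is avoiding the materialized identity.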
tests/linalg_test.py

@@ -15,7 +15,6 @@
 """Tests for the LAPAX linear algebra module."""
 
 from functools import partial
-import unittest
 
 import numpy as np
 import scipy
@@ -478,11 +477,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
   )
   def testNorm(self, shape, dtype, ord, axis, keepdims):
     rng = jtu.rand_default(self.rng())
-    if (ord in ('nuc', 2, -2) and (
-          jtu.device_under_test() != "cpu" or
-          (isinstance(axis, tuple) and len(axis) == 2))):
-      raise unittest.SkipTest("No adequate SVD implementation available")
-
     args_maker = lambda: [rng(shape, dtype)]
     np_fn = partial(np.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
     jnp_fn = partial(jnp.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
@@ -663,7 +657,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
     phases = np.divide(sum_of_ratios, np.abs(sum_of_ratios))
     q1 *= phases
     nm = norm(q1 - q2)
-    self.assertTrue(np.all(nm < 140), msg=f"norm={np.amax(nm)}")
+    self.assertTrue(np.all(nm < 160), msg=f"norm={np.amax(nm)}")
 
     # Check a ~= qr
     self.assertTrue(np.all(norm(a - np.matmul(lq, lr)) < 40))