mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
lax.linalg.qr: allow jvp when m == n and full_matrices=True
This commit is contained in:
parent
04b6f15cdb
commit
67e0fdda82
@ -1173,7 +1173,7 @@ def qr_jvp_rule(primals, tangents, full_matrices):
|
||||
dx, = tangents
|
||||
q, r = qr_p.bind(x, full_matrices=False)
|
||||
*_, m, n = x.shape
|
||||
if full_matrices or m < n:
|
||||
if m < n or (full_matrices and m != n):
|
||||
raise NotImplementedError(
|
||||
"Unimplemented case of QR decomposition derivative")
|
||||
dx_rinv = triangular_solve(r, dx) # Right side solve by default
|
||||
|
@ -671,8 +671,9 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
self.assertTrue(np.all(
|
||||
norm(np.eye(k) - np.matmul(np.conj(T(lq)), lq)) < 5))
|
||||
|
||||
if not full_matrices and m >= n:
|
||||
jtu.check_jvp(jnp.linalg.qr, partial(jvp, jnp.linalg.qr), (a,), atol=3e-3)
|
||||
if m == n or (m > n and not full_matrices):
|
||||
qr = partial(jnp.linalg.qr, mode=mode)
|
||||
jtu.check_jvp(qr, partial(jvp, qr), (a,), atol=3e-3)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}".format(
|
||||
|
Loading…
x
Reference in New Issue
Block a user