lax.linalg.qr: allow jvp when m == n and full_matrices=True

This commit is contained in:
Jake VanderPlas 2022-04-26 10:34:50 -07:00
parent 04b6f15cdb
commit 67e0fdda82
2 changed files with 4 additions and 3 deletions

View File

@ -1173,7 +1173,7 @@ def qr_jvp_rule(primals, tangents, full_matrices):
dx, = tangents
q, r = qr_p.bind(x, full_matrices=False)
*_, m, n = x.shape
if full_matrices or m < n:
if m < n or (full_matrices and m != n):
raise NotImplementedError(
"Unimplemented case of QR decomposition derivative")
dx_rinv = triangular_solve(r, dx) # Right side solve by default

View File

@ -671,8 +671,9 @@ class NumpyLinalgTest(jtu.JaxTestCase):
self.assertTrue(np.all(
norm(np.eye(k) - np.matmul(np.conj(T(lq)), lq)) < 5))
if not full_matrices and m >= n:
jtu.check_jvp(jnp.linalg.qr, partial(jvp, jnp.linalg.qr), (a,), atol=3e-3)
if m == n or (m > n and not full_matrices):
qr = partial(jnp.linalg.qr, mode=mode)
jtu.check_jvp(qr, partial(jvp, qr), (a,), atol=3e-3)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(