Merge pull request #13145 from hawkinsp:pinv

PiperOrigin-RevId: 486935918
This commit is contained in:
jax authors 2022-11-08 06:39:54 -08:00
commit 3994ac30d5
3 changed files with 42 additions and 23 deletions

View File

@ -7,6 +7,8 @@ Remember to align the itemized text with the first line of an item within a list
-->
## jax 0.3.25
* Changes
* {func}`jax.numpy.linalg.pinv` now supports the `hermitian` option.
## jaxlib 0.3.25

View File

@ -41,6 +41,9 @@ def _H(x: ArrayLike) -> Array:
return jnp.conjugate(jnp.swapaxes(x, -1, -2))
def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2
@_wraps(np.linalg.cholesky)
@jit
def cholesky(a: ArrayLike) -> Array:
@ -406,27 +409,32 @@ def eigvalsh(a: ArrayLike, UPLO: Optional[str] = 'L') -> Array:
return w
@partial(custom_jvp, nondiff_argnums=(1,))
@partial(custom_jvp, nondiff_argnums=(1, 2))
@_wraps(np.linalg.pinv, lax_description=textwrap.dedent("""\
It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
default `rcond` is `1e-15`. Here the default is
`10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps`.
"""))
@jit
def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None) -> Array:
@partial(jit, static_argnames=('hermitian',))
def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None,
hermitian: bool = False) -> Array:
# Uses same algorithm as
# https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
_check_arraylike("jnp.linalg.pinv", a)
arr = jnp.conj(a)
arr = jnp.asarray(a)
m, n = arr.shape[-2:]
if m == 0 or n == 0:
return jnp.empty(arr.shape[:-2] + (n, m), arr.dtype)
arr = jnp.conj(arr)
if rcond is None:
max_rows_cols = max(arr.shape[-2:])
rcond = 10. * max_rows_cols * jnp.array(jnp.finfo(arr.dtype).eps)
rcond = jnp.asarray(rcond)
u, s, vh = svd(arr, full_matrices=False)
u, s, vh = svd(arr, full_matrices=False, hermitian=hermitian)
# Singular values less than or equal to ``rcond * largest_singular_value``
# are set to zero.
rcond = lax.expand_dims(rcond[..., jnp.newaxis], range(s.ndim - rcond.ndim - 1))
cutoff = rcond * jnp.amax(s, axis=-1, keepdims=True, initial=-jnp.inf)
cutoff = rcond * s[..., 0:1]
s = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype)
res = jnp.matmul(_T(vh), jnp.divide(_T(u), s[..., jnp.newaxis]),
precision=lax.Precision.HIGHEST)
@ -435,22 +443,24 @@ def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None) -> Array:
@pinv.defjvp
@jax.default_matmul_precision("float32")
def _pinv_jvp(rcond, primals, tangents):
def _pinv_jvp(rcond, hermitian, primals, tangents):
# The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems
# Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM
# Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432.
# (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative)
a, = primals
a_dot, = tangents
p = pinv(a, rcond=rcond)
m, n = a.shape[-2:]
# TODO(phawkins): on TPU, we would need to opt into high precision here.
# TODO(phawkins): consider if this can be simplified in the Hermitian case.
p_dot = -p @ a_dot @ p
I_n = lax.expand_dims(jnp.eye(m, dtype=a.dtype), range(a.ndim - 2))
p_dot = p_dot + p @ _H(p) @ _H(a_dot) @ (I_n - a @ p)
I_m = lax.expand_dims(jnp.eye(n, dtype=a.dtype), range(a.ndim - 2))
p_dot = p_dot + (I_m - p @ a) @ _H(a_dot) @ _H(p) @ p
p = pinv(a, rcond=rcond, hermitian=hermitian)
if hermitian:
# svd(..., hermitian=True) symmetrizes its input, and the JVP must match.
a = _symmetrize(a)
a_dot = _symmetrize(a_dot)
# TODO(phawkins): this could be simplified in the Hermitian case if we
# supported triangular matrix multiplication.
s = p @ _H(p) @ _H(a_dot)
t = _H(a_dot) @ _H(p) @ p
p_dot = -p @ a_dot @ p + s - s @ a @ p + t - p @ a @ t
return p, p_dot

View File

@ -798,21 +798,28 @@ class NumpyLinalgTest(jtu.JaxTestCase):
self._CompileAndCheck(jnp.linalg.inv, args_maker)
@jtu.sample_product(
shape=[(1, 1), (4, 4), (2, 70, 7), (2000, 7), (7, 1000), (70, 7, 2),
(2, 0, 0), (3, 0, 2), (1, 0)],
[dict(shape=shape, hermitian=hermitian)
for shape in [(1, 1), (4, 4), (3, 10, 10), (2, 70, 7), (2000, 7),
(7, 1000), (70, 7, 2), (2, 0, 0), (3, 0, 2), (1, 0)]
for hermitian in ([False, True] if shape[-1] == shape[-2] else [False])],
dtype=float_types + complex_types,
)
@jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1
def testPinv(self, shape, dtype):
def testPinv(self, shape, hermitian, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np.linalg.pinv, jnp.linalg.pinv, args_maker,
tol=1e-2)
self._CompileAndCheck(jnp.linalg.pinv, args_maker)
jnp_fn = partial(jnp.linalg.pinv, hermitian=hermitian)
def np_fn(a):
# Symmetrize the input matrix to match the jnp behavior.
if hermitian:
a = (a + T(a.conj())) / 2
return np.linalg.pinv(a, hermitian=hermitian)
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker)
# TODO(phawkins): 1e-1 seems like a very loose tolerance.
jtu.check_grads(jnp.linalg.pinv, args_maker(), 2, rtol=1e-1, atol=2e-1)
jtu.check_grads(jnp_fn, args_maker(), 1, rtol=3e-2, atol=1e-3)
def testPinvGradIssue2792(self):
def f(p):