Add support for the hermitian option on jnp.linalg.pinv.

Improve the pinv implementation to avoid computing an unnecessary reduction: svd sorts its singular values so we don't need to use amax() to find the largest one.
Avoid explicitly forming the identity matrix in the pinv JVP.
This commit is contained in:
Peter Hawkins 2022-11-07 11:39:19 -05:00
parent 1e7e8e8d5c
commit ab8cde9ed4
3 changed files with 42 additions and 23 deletions

View File

@ -7,6 +7,8 @@ Remember to align the itemized text with the first line of an item within a list
--> -->
## jax 0.3.25 ## jax 0.3.25
* Changes
* {func}`jax.numpy.linalg.pinv` now supports the `hermitian` option.
## jaxlib 0.3.25 ## jaxlib 0.3.25

View File

@ -41,6 +41,9 @@ def _H(x: ArrayLike) -> Array:
return jnp.conjugate(jnp.swapaxes(x, -1, -2)) return jnp.conjugate(jnp.swapaxes(x, -1, -2))
def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2
@_wraps(np.linalg.cholesky) @_wraps(np.linalg.cholesky)
@jit @jit
def cholesky(a: ArrayLike) -> Array: def cholesky(a: ArrayLike) -> Array:
@ -406,27 +409,32 @@ def eigvalsh(a: ArrayLike, UPLO: Optional[str] = 'L') -> Array:
return w return w
@partial(custom_jvp, nondiff_argnums=(1,)) @partial(custom_jvp, nondiff_argnums=(1, 2))
@_wraps(np.linalg.pinv, lax_description=textwrap.dedent("""\ @_wraps(np.linalg.pinv, lax_description=textwrap.dedent("""\
It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
default `rcond` is `1e-15`. Here the default is default `rcond` is `1e-15`. Here the default is
`10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps`. `10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps`.
""")) """))
@jit @partial(jit, static_argnames=('hermitian',))
def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None) -> Array: def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None,
hermitian: bool = False) -> Array:
# Uses same algorithm as # Uses same algorithm as
# https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979 # https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
_check_arraylike("jnp.linalg.pinv", a) _check_arraylike("jnp.linalg.pinv", a)
arr = jnp.conj(a) arr = jnp.asarray(a)
m, n = arr.shape[-2:]
if m == 0 or n == 0:
return jnp.empty(arr.shape[:-2] + (n, m), arr.dtype)
arr = jnp.conj(arr)
if rcond is None: if rcond is None:
max_rows_cols = max(arr.shape[-2:]) max_rows_cols = max(arr.shape[-2:])
rcond = 10. * max_rows_cols * jnp.array(jnp.finfo(arr.dtype).eps) rcond = 10. * max_rows_cols * jnp.array(jnp.finfo(arr.dtype).eps)
rcond = jnp.asarray(rcond) rcond = jnp.asarray(rcond)
u, s, vh = svd(arr, full_matrices=False) u, s, vh = svd(arr, full_matrices=False, hermitian=hermitian)
# Singular values less than or equal to ``rcond * largest_singular_value`` # Singular values less than or equal to ``rcond * largest_singular_value``
# are set to zero. # are set to zero.
rcond = lax.expand_dims(rcond[..., jnp.newaxis], range(s.ndim - rcond.ndim - 1)) rcond = lax.expand_dims(rcond[..., jnp.newaxis], range(s.ndim - rcond.ndim - 1))
cutoff = rcond * jnp.amax(s, axis=-1, keepdims=True, initial=-jnp.inf) cutoff = rcond * s[..., 0:1]
s = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype) s = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype)
res = jnp.matmul(_T(vh), jnp.divide(_T(u), s[..., jnp.newaxis]), res = jnp.matmul(_T(vh), jnp.divide(_T(u), s[..., jnp.newaxis]),
precision=lax.Precision.HIGHEST) precision=lax.Precision.HIGHEST)
@ -435,22 +443,24 @@ def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None) -> Array:
@pinv.defjvp @pinv.defjvp
@jax.default_matmul_precision("float32") @jax.default_matmul_precision("float32")
def _pinv_jvp(rcond, primals, tangents): def _pinv_jvp(rcond, hermitian, primals, tangents):
# The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems # The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems
# Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM # Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM
# Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432. # Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432.
# (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative) # (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative)
a, = primals a, = primals
a_dot, = tangents a_dot, = tangents
p = pinv(a, rcond=rcond) p = pinv(a, rcond=rcond, hermitian=hermitian)
m, n = a.shape[-2:] if hermitian:
# TODO(phawkins): on TPU, we would need to opt into high precision here. # svd(..., hermitian=True) symmetrizes its input, and the JVP must match.
# TODO(phawkins): consider if this can be simplified in the Hermitian case. a = _symmetrize(a)
p_dot = -p @ a_dot @ p a_dot = _symmetrize(a_dot)
I_n = lax.expand_dims(jnp.eye(m, dtype=a.dtype), range(a.ndim - 2))
p_dot = p_dot + p @ _H(p) @ _H(a_dot) @ (I_n - a @ p) # TODO(phawkins): this could be simplified in the Hermitian case if we
I_m = lax.expand_dims(jnp.eye(n, dtype=a.dtype), range(a.ndim - 2)) # supported triangular matrix multiplication.
p_dot = p_dot + (I_m - p @ a) @ _H(a_dot) @ _H(p) @ p s = p @ _H(p) @ _H(a_dot)
t = _H(a_dot) @ _H(p) @ p
p_dot = -p @ a_dot @ p + s - s @ a @ p + t - p @ a @ t
return p, p_dot return p, p_dot

View File

@ -798,21 +798,28 @@ class NumpyLinalgTest(jtu.JaxTestCase):
self._CompileAndCheck(jnp.linalg.inv, args_maker) self._CompileAndCheck(jnp.linalg.inv, args_maker)
@jtu.sample_product( @jtu.sample_product(
shape=[(1, 1), (4, 4), (2, 70, 7), (2000, 7), (7, 1000), (70, 7, 2), [dict(shape=shape, hermitian=hermitian)
(2, 0, 0), (3, 0, 2), (1, 0)], for shape in [(1, 1), (4, 4), (3, 10, 10), (2, 70, 7), (2000, 7),
(7, 1000), (70, 7, 2), (2, 0, 0), (3, 0, 2), (1, 0)]
for hermitian in ([False, True] if shape[-1] == shape[-2] else [False])],
dtype=float_types + complex_types, dtype=float_types + complex_types,
) )
@jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1 @jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1
def testPinv(self, shape, dtype): def testPinv(self, shape, hermitian, dtype):
rng = jtu.rand_default(self.rng()) rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)] args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np.linalg.pinv, jnp.linalg.pinv, args_maker, jnp_fn = partial(jnp.linalg.pinv, hermitian=hermitian)
tol=1e-2) def np_fn(a):
self._CompileAndCheck(jnp.linalg.pinv, args_maker) # Symmetrize the input matrix to match the jnp behavior.
if hermitian:
a = (a + T(a.conj())) / 2
return np.linalg.pinv(a, hermitian=hermitian)
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker)
# TODO(phawkins): 1e-1 seems like a very loose tolerance. # TODO(phawkins): 1e-1 seems like a very loose tolerance.
jtu.check_grads(jnp.linalg.pinv, args_maker(), 2, rtol=1e-1, atol=2e-1) jtu.check_grads(jnp_fn, args_maker(), 1, rtol=3e-2, atol=1e-3)
def testPinvGradIssue2792(self): def testPinvGradIssue2792(self):
def f(p): def f(p):