Add support for the hermitian option on jnp.linalg.pinv.

Improve the pinv implementation to avoid an unnecessary reduction: svd returns its singular values sorted in descending order, so we don't need amax() to find the largest one.
Avoid explicitly forming the identity matrices in the pinv JVP.
Peter Hawkins 2022-11-07 11:39:19 -05:00
parent 1e7e8e8d5c
commit ab8cde9ed4
3 changed files with 42 additions and 23 deletions
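
As a quick illustration of the new option, here is a minimal usage sketch (not part of the commit; the input, the tolerances, and the claim that the two code paths agree at this precision are assumptions):

```python
import jax.numpy as jnp
import numpy as np

# Construct a random Hermitian (here real symmetric) matrix; the values
# and tolerances are illustrative, not part of the commit.
x = np.random.RandomState(0).randn(5, 5)
a = (x + x.T) / 2

# hermitian=True lets pinv assume its input is Hermitian, so the
# underlying svd can take a cheaper code path.
p = jnp.linalg.pinv(a, hermitian=True)

# On a genuinely Hermitian input, the result should match the general path.
np.testing.assert_allclose(p, jnp.linalg.pinv(a), rtol=1e-4, atol=1e-4)
```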

CHANGELOG.md

@@ -7,6 +7,8 @@ Remember to align the itemized text with the first line of an item within a list
 -->
 ## jax 0.3.25
+* Changes
+  * {func}`jax.numpy.linalg.pinv` now supports the `hermitian` option.
 ## jaxlib 0.3.25

jax/_src/numpy/linalg.py

@@ -41,6 +41,9 @@ def _H(x: ArrayLike) -> Array:
   return jnp.conjugate(jnp.swapaxes(x, -1, -2))
 
+def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2
+
 
 @_wraps(np.linalg.cholesky)
 @jit
 def cholesky(a: ArrayLike) -> Array:
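
The new `_symmetrize` helper returns the Hermitian part of its argument. A self-contained sketch of the two helpers from this hunk, checking that the result is Hermitian and that symmetrizing is idempotent:

```python
import jax.numpy as jnp
import numpy as np

def _H(x):
  # Conjugate transpose over the last two axes.
  return jnp.conjugate(jnp.swapaxes(x, -1, -2))

def _symmetrize(x):
  # Hermitian part of x.
  return (x + _H(x)) / 2

x = jnp.asarray(np.random.RandomState(0).randn(2, 3, 3))
h = _symmetrize(x)
np.testing.assert_allclose(h, _H(h))           # result is Hermitian
np.testing.assert_allclose(_symmetrize(h), h)  # symmetrizing is idempotent
```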
@@ -406,27 +409,32 @@ def eigvalsh(a: ArrayLike, UPLO: Optional[str] = 'L') -> Array:
   return w
 
-@partial(custom_jvp, nondiff_argnums=(1,))
+@partial(custom_jvp, nondiff_argnums=(1, 2))
 @_wraps(np.linalg.pinv, lax_description=textwrap.dedent("""\
     It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
     default `rcond` is `1e-15`. Here the default is
     `10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps`.
     """))
-@jit
-def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None) -> Array:
+@partial(jit, static_argnames=('hermitian',))
+def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None,
+         hermitian: bool = False) -> Array:
   # Uses same algorithm as
   # https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
   _check_arraylike("jnp.linalg.pinv", a)
-  arr = jnp.conj(a)
+  arr = jnp.asarray(a)
+  m, n = arr.shape[-2:]
+  if m == 0 or n == 0:
+    return jnp.empty(arr.shape[:-2] + (n, m), arr.dtype)
+  arr = jnp.conj(arr)
   if rcond is None:
     max_rows_cols = max(arr.shape[-2:])
     rcond = 10. * max_rows_cols * jnp.array(jnp.finfo(arr.dtype).eps)
   rcond = jnp.asarray(rcond)
-  u, s, vh = svd(arr, full_matrices=False)
+  u, s, vh = svd(arr, full_matrices=False, hermitian=hermitian)
   # Singular values less than or equal to ``rcond * largest_singular_value``
   # are set to zero.
   rcond = lax.expand_dims(rcond[..., jnp.newaxis], range(s.ndim - rcond.ndim - 1))
-  cutoff = rcond * jnp.amax(s, axis=-1, keepdims=True, initial=-jnp.inf)
+  cutoff = rcond * s[..., 0:1]
   s = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype)
   res = jnp.matmul(_T(vh), jnp.divide(_T(u), s[..., jnp.newaxis]),
                    precision=lax.Precision.HIGHEST)
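
The `cutoff` change works because `svd` returns singular values sorted in descending order, so `s[..., 0:1]` is already the largest one. A small sketch checking the equivalence with the removed `amax` reduction (shapes are illustrative):

```python
import jax.numpy as jnp
import numpy as np

# Illustrative batch of rectangular matrices.
a = jnp.asarray(np.random.RandomState(0).randn(3, 5, 4))
_, s, _ = jnp.linalg.svd(a, full_matrices=False)

# svd sorts singular values in descending order, so the leading entry
# along the last axis is already the maximum; no reduction needed.
np.testing.assert_allclose(s[..., 0:1],
                           jnp.amax(s, axis=-1, keepdims=True))
```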
@@ -435,22 +443,24 @@ def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None) -> Array:
 
 @pinv.defjvp
 @jax.default_matmul_precision("float32")
-def _pinv_jvp(rcond, primals, tangents):
+def _pinv_jvp(rcond, hermitian, primals, tangents):
   # The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems
   # Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM
   # Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432.
   # (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative)
   a, = primals
   a_dot, = tangents
-  p = pinv(a, rcond=rcond)
-  m, n = a.shape[-2:]
-  # TODO(phawkins): on TPU, we would need to opt into high precision here.
-  # TODO(phawkins): consider if this can be simplified in the Hermitian case.
-  p_dot = -p @ a_dot @ p
-  I_n = lax.expand_dims(jnp.eye(m, dtype=a.dtype), range(a.ndim - 2))
-  p_dot = p_dot + p @ _H(p) @ _H(a_dot) @ (I_n - a @ p)
-  I_m = lax.expand_dims(jnp.eye(n, dtype=a.dtype), range(a.ndim - 2))
-  p_dot = p_dot + (I_m - p @ a) @ _H(a_dot) @ _H(p) @ p
+  p = pinv(a, rcond=rcond, hermitian=hermitian)
+  if hermitian:
+    # svd(..., hermitian=True) symmetrizes its input, and the JVP must match.
+    a = _symmetrize(a)
+    a_dot = _symmetrize(a_dot)
+  # TODO(phawkins): this could be simplified in the Hermitian case if we
+  # supported triangular matrix multiplication.
+  s = p @ _H(p) @ _H(a_dot)
+  t = _H(a_dot) @ _H(p) @ p
+  p_dot = -p @ a_dot @ p + s - s @ a @ p + t - p @ a @ t
   return p, p_dot
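
The new `p_dot` comes from distributing the products in the Golub-Pereyra expression: `p @ _H(p) @ _H(a_dot) @ (I - a @ p)` expands to `s - s @ a @ p` with `s = p @ _H(p) @ _H(a_dot)`, and symmetrically for `t`, so no identity matrix is ever materialized. A numeric sketch comparing the removed and the new forms (helper and variable names, shapes, and tolerances are illustrative):

```python
import jax.numpy as jnp
import numpy as np

def _H(x):
  # Conjugate transpose, mirroring the helper in the file above.
  return jnp.conjugate(jnp.swapaxes(x, -1, -2))

rng = np.random.RandomState(0)
a = jnp.asarray(rng.randn(5, 3))      # m x n primal
a_dot = jnp.asarray(rng.randn(5, 3))  # tangent of a
p = jnp.linalg.pinv(a)                # n x m
m, n = a.shape

# Removed form: materializes an m x m and an n x n identity matrix.
I_m = jnp.eye(m, dtype=a.dtype)
I_n = jnp.eye(n, dtype=a.dtype)
old = (-p @ a_dot @ p
       + p @ _H(p) @ _H(a_dot) @ (I_m - a @ p)
       + (I_n - p @ a) @ _H(a_dot) @ _H(p) @ p)

# New form: distributing the products removes the identities entirely.
s = p @ _H(p) @ _H(a_dot)
t = _H(a_dot) @ _H(p) @ p
new = -p @ a_dot @ p + s - s @ a @ p + t - p @ a @ t

np.testing.assert_allclose(old, new, rtol=1e-4, atol=1e-5)
```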

tests/linalg_test.py

@@ -798,21 +798,28 @@ class NumpyLinalgTest(jtu.JaxTestCase):
     self._CompileAndCheck(jnp.linalg.inv, args_maker)
 
   @jtu.sample_product(
-    shape=[(1, 1), (4, 4), (2, 70, 7), (2000, 7), (7, 1000), (70, 7, 2),
-           (2, 0, 0), (3, 0, 2), (1, 0)],
+    [dict(shape=shape, hermitian=hermitian)
+     for shape in [(1, 1), (4, 4), (3, 10, 10), (2, 70, 7), (2000, 7),
+                   (7, 1000), (70, 7, 2), (2, 0, 0), (3, 0, 2), (1, 0)]
+     for hermitian in ([False, True] if shape[-1] == shape[-2] else [False])],
     dtype=float_types + complex_types,
   )
   @jtu.skip_on_devices("rocm")  # will be fixed in ROCm-5.1
-  def testPinv(self, shape, dtype):
+  def testPinv(self, shape, hermitian, dtype):
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [rng(shape, dtype)]
-    self._CheckAgainstNumpy(np.linalg.pinv, jnp.linalg.pinv, args_maker,
-                            tol=1e-2)
-    self._CompileAndCheck(jnp.linalg.pinv, args_maker)
+    jnp_fn = partial(jnp.linalg.pinv, hermitian=hermitian)
+    def np_fn(a):
+      # Symmetrize the input matrix to match the jnp behavior.
+      if hermitian:
+        a = (a + T(a.conj())) / 2
+      return np.linalg.pinv(a, hermitian=hermitian)
+    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-4)
+    self._CompileAndCheck(jnp_fn, args_maker)
     # TODO(phawkins): 1e-1 seems like a very loose tolerance.
-    jtu.check_grads(jnp.linalg.pinv, args_maker(), 2, rtol=1e-1, atol=2e-1)
+    jtu.check_grads(jnp_fn, args_maker(), 1, rtol=3e-2, atol=1e-3)
 
   def testPinvGradIssue2792(self):
     def f(p):