mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add support for the hermitian option on jnp.linalg.pinv.
Improve the pinv implementation to avoid computing an unnecessary reduction: svd sorts its singular values so we don't need to use amax() to find the largest one. Avoid explicitly forming the identity matrix in the pinv JVP.
This commit is contained in:
parent
1e7e8e8d5c
commit
ab8cde9ed4
@ -7,6 +7,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
-->
|
||||
|
||||
## jax 0.3.25
|
||||
* Changes
|
||||
* {func}`jax.numpy.linalg.pinv` now supports the `hermitian` option.
|
||||
|
||||
## jaxlib 0.3.25
|
||||
|
||||
|
@ -41,6 +41,9 @@ def _H(x: ArrayLike) -> Array:
|
||||
return jnp.conjugate(jnp.swapaxes(x, -1, -2))
|
||||
|
||||
|
||||
def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2
|
||||
|
||||
|
||||
@_wraps(np.linalg.cholesky)
|
||||
@jit
|
||||
def cholesky(a: ArrayLike) -> Array:
|
||||
@ -406,27 +409,32 @@ def eigvalsh(a: ArrayLike, UPLO: Optional[str] = 'L') -> Array:
|
||||
return w
|
||||
|
||||
|
||||
@partial(custom_jvp, nondiff_argnums=(1,))
|
||||
@partial(custom_jvp, nondiff_argnums=(1, 2))
|
||||
@_wraps(np.linalg.pinv, lax_description=textwrap.dedent("""\
|
||||
It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
|
||||
default `rcond` is `1e-15`. Here the default is
|
||||
`10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps`.
|
||||
"""))
|
||||
@jit
|
||||
def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None) -> Array:
|
||||
@partial(jit, static_argnames=('hermitian',))
|
||||
def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None,
|
||||
hermitian: bool = False) -> Array:
|
||||
# Uses same algorithm as
|
||||
# https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
|
||||
_check_arraylike("jnp.linalg.pinv", a)
|
||||
arr = jnp.conj(a)
|
||||
arr = jnp.asarray(a)
|
||||
m, n = arr.shape[-2:]
|
||||
if m == 0 or n == 0:
|
||||
return jnp.empty(arr.shape[:-2] + (n, m), arr.dtype)
|
||||
arr = jnp.conj(arr)
|
||||
if rcond is None:
|
||||
max_rows_cols = max(arr.shape[-2:])
|
||||
rcond = 10. * max_rows_cols * jnp.array(jnp.finfo(arr.dtype).eps)
|
||||
rcond = jnp.asarray(rcond)
|
||||
u, s, vh = svd(arr, full_matrices=False)
|
||||
u, s, vh = svd(arr, full_matrices=False, hermitian=hermitian)
|
||||
# Singular values less than or equal to ``rcond * largest_singular_value``
|
||||
# are set to zero.
|
||||
rcond = lax.expand_dims(rcond[..., jnp.newaxis], range(s.ndim - rcond.ndim - 1))
|
||||
cutoff = rcond * jnp.amax(s, axis=-1, keepdims=True, initial=-jnp.inf)
|
||||
cutoff = rcond * s[..., 0:1]
|
||||
s = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype)
|
||||
res = jnp.matmul(_T(vh), jnp.divide(_T(u), s[..., jnp.newaxis]),
|
||||
precision=lax.Precision.HIGHEST)
|
||||
@ -435,22 +443,24 @@ def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None) -> Array:
|
||||
|
||||
@pinv.defjvp
|
||||
@jax.default_matmul_precision("float32")
|
||||
def _pinv_jvp(rcond, primals, tangents):
|
||||
def _pinv_jvp(rcond, hermitian, primals, tangents):
|
||||
# The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems
|
||||
# Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM
|
||||
# Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432.
|
||||
# (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative)
|
||||
a, = primals
|
||||
a_dot, = tangents
|
||||
p = pinv(a, rcond=rcond)
|
||||
m, n = a.shape[-2:]
|
||||
# TODO(phawkins): on TPU, we would need to opt into high precision here.
|
||||
# TODO(phawkins): consider if this can be simplified in the Hermitian case.
|
||||
p_dot = -p @ a_dot @ p
|
||||
I_n = lax.expand_dims(jnp.eye(m, dtype=a.dtype), range(a.ndim - 2))
|
||||
p_dot = p_dot + p @ _H(p) @ _H(a_dot) @ (I_n - a @ p)
|
||||
I_m = lax.expand_dims(jnp.eye(n, dtype=a.dtype), range(a.ndim - 2))
|
||||
p_dot = p_dot + (I_m - p @ a) @ _H(a_dot) @ _H(p) @ p
|
||||
p = pinv(a, rcond=rcond, hermitian=hermitian)
|
||||
if hermitian:
|
||||
# svd(..., hermitian=True) symmetrizes its input, and the JVP must match.
|
||||
a = _symmetrize(a)
|
||||
a_dot = _symmetrize(a_dot)
|
||||
|
||||
# TODO(phawkins): this could be simplified in the Hermitian case if we
|
||||
# supported triangular matrix multiplication.
|
||||
s = p @ _H(p) @ _H(a_dot)
|
||||
t = _H(a_dot) @ _H(p) @ p
|
||||
p_dot = -p @ a_dot @ p + s - s @ a @ p + t - p @ a @ t
|
||||
return p, p_dot
|
||||
|
||||
|
||||
|
@ -798,21 +798,28 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(jnp.linalg.inv, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[(1, 1), (4, 4), (2, 70, 7), (2000, 7), (7, 1000), (70, 7, 2),
|
||||
(2, 0, 0), (3, 0, 2), (1, 0)],
|
||||
[dict(shape=shape, hermitian=hermitian)
|
||||
for shape in [(1, 1), (4, 4), (3, 10, 10), (2, 70, 7), (2000, 7),
|
||||
(7, 1000), (70, 7, 2), (2, 0, 0), (3, 0, 2), (1, 0)]
|
||||
for hermitian in ([False, True] if shape[-1] == shape[-2] else [False])],
|
||||
dtype=float_types + complex_types,
|
||||
)
|
||||
@jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1
|
||||
def testPinv(self, shape, dtype):
|
||||
def testPinv(self, shape, hermitian, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
||||
self._CheckAgainstNumpy(np.linalg.pinv, jnp.linalg.pinv, args_maker,
|
||||
tol=1e-2)
|
||||
self._CompileAndCheck(jnp.linalg.pinv, args_maker)
|
||||
jnp_fn = partial(jnp.linalg.pinv, hermitian=hermitian)
|
||||
def np_fn(a):
|
||||
# Symmetrize the input matrix to match the jnp behavior.
|
||||
if hermitian:
|
||||
a = (a + T(a.conj())) / 2
|
||||
return np.linalg.pinv(a, hermitian=hermitian)
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker)
|
||||
|
||||
# TODO(phawkins): 1e-1 seems like a very loose tolerance.
|
||||
jtu.check_grads(jnp.linalg.pinv, args_maker(), 2, rtol=1e-1, atol=2e-1)
|
||||
jtu.check_grads(jnp_fn, args_maker(), 1, rtol=3e-2, atol=1e-3)
|
||||
|
||||
def testPinvGradIssue2792(self):
|
||||
def f(p):
|
||||
|
Loading…
x
Reference in New Issue
Block a user