mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #13145 from hawkinsp:pinv
PiperOrigin-RevId: 486935918
This commit is contained in:
commit
3994ac30d5
@ -7,6 +7,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
-->
|
||||
|
||||
## jax 0.3.25
|
||||
* Changes
|
||||
* {func}`jax.numpy.linalg.pinv` now supports the `hermitian` option.
|
||||
|
||||
## jaxlib 0.3.25
|
||||
|
||||
|
@ -41,6 +41,9 @@ def _H(x: ArrayLike) -> Array:
|
||||
return jnp.conjugate(jnp.swapaxes(x, -1, -2))
|
||||
|
||||
|
||||
def _symmetrize(x: Array) -> Array:
  """Returns the average of ``x`` and ``_H(x)``.

  ``_H`` conjugates ``x`` and swaps its last two axes, so the result is the
  mean of ``x`` and its conjugate transpose over the trailing two dimensions.
  """
  conj_transposed = _H(x)
  return (x + conj_transposed) / 2
|
||||
|
||||
|
||||
@_wraps(np.linalg.cholesky)
|
||||
@jit
|
||||
def cholesky(a: ArrayLike) -> Array:
|
||||
@ -406,27 +409,32 @@ def eigvalsh(a: ArrayLike, UPLO: Optional[str] = 'L') -> Array:
|
||||
return w
|
||||
|
||||
|
||||
@partial(custom_jvp, nondiff_argnums=(1,))
|
||||
@partial(custom_jvp, nondiff_argnums=(1, 2))
|
||||
@_wraps(np.linalg.pinv, lax_description=textwrap.dedent("""\
|
||||
It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
|
||||
default `rcond` is `1e-15`. Here the default is
|
||||
`10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps`.
|
||||
"""))
|
||||
@jit
|
||||
def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None) -> Array:
|
||||
@partial(jit, static_argnames=('hermitian',))
|
||||
def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None,
|
||||
hermitian: bool = False) -> Array:
|
||||
# Uses same algorithm as
|
||||
# https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
|
||||
_check_arraylike("jnp.linalg.pinv", a)
|
||||
arr = jnp.conj(a)
|
||||
arr = jnp.asarray(a)
|
||||
m, n = arr.shape[-2:]
|
||||
if m == 0 or n == 0:
|
||||
return jnp.empty(arr.shape[:-2] + (n, m), arr.dtype)
|
||||
arr = jnp.conj(arr)
|
||||
if rcond is None:
|
||||
max_rows_cols = max(arr.shape[-2:])
|
||||
rcond = 10. * max_rows_cols * jnp.array(jnp.finfo(arr.dtype).eps)
|
||||
rcond = jnp.asarray(rcond)
|
||||
u, s, vh = svd(arr, full_matrices=False)
|
||||
u, s, vh = svd(arr, full_matrices=False, hermitian=hermitian)
|
||||
# Singular values less than or equal to ``rcond * largest_singular_value``
|
||||
# are set to zero.
|
||||
rcond = lax.expand_dims(rcond[..., jnp.newaxis], range(s.ndim - rcond.ndim - 1))
|
||||
cutoff = rcond * jnp.amax(s, axis=-1, keepdims=True, initial=-jnp.inf)
|
||||
cutoff = rcond * s[..., 0:1]
|
||||
s = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype)
|
||||
res = jnp.matmul(_T(vh), jnp.divide(_T(u), s[..., jnp.newaxis]),
|
||||
precision=lax.Precision.HIGHEST)
|
||||
@ -435,22 +443,24 @@ def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None) -> Array:
|
||||
|
||||
@pinv.defjvp
|
||||
@jax.default_matmul_precision("float32")
|
||||
def _pinv_jvp(rcond, primals, tangents):
|
||||
def _pinv_jvp(rcond, hermitian, primals, tangents):
|
||||
# The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems
|
||||
# Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM
|
||||
# Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432.
|
||||
# (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative)
|
||||
a, = primals
|
||||
a_dot, = tangents
|
||||
p = pinv(a, rcond=rcond)
|
||||
m, n = a.shape[-2:]
|
||||
# TODO(phawkins): on TPU, we would need to opt into high precision here.
|
||||
# TODO(phawkins): consider if this can be simplified in the Hermitian case.
|
||||
p_dot = -p @ a_dot @ p
|
||||
I_n = lax.expand_dims(jnp.eye(m, dtype=a.dtype), range(a.ndim - 2))
|
||||
p_dot = p_dot + p @ _H(p) @ _H(a_dot) @ (I_n - a @ p)
|
||||
I_m = lax.expand_dims(jnp.eye(n, dtype=a.dtype), range(a.ndim - 2))
|
||||
p_dot = p_dot + (I_m - p @ a) @ _H(a_dot) @ _H(p) @ p
|
||||
p = pinv(a, rcond=rcond, hermitian=hermitian)
|
||||
if hermitian:
|
||||
# svd(..., hermitian=True) symmetrizes its input, and the JVP must match.
|
||||
a = _symmetrize(a)
|
||||
a_dot = _symmetrize(a_dot)
|
||||
|
||||
# TODO(phawkins): this could be simplified in the Hermitian case if we
|
||||
# supported triangular matrix multiplication.
|
||||
s = p @ _H(p) @ _H(a_dot)
|
||||
t = _H(a_dot) @ _H(p) @ p
|
||||
p_dot = -p @ a_dot @ p + s - s @ a @ p + t - p @ a @ t
|
||||
return p, p_dot
|
||||
|
||||
|
||||
|
@ -798,21 +798,28 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(jnp.linalg.inv, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[(1, 1), (4, 4), (2, 70, 7), (2000, 7), (7, 1000), (70, 7, 2),
|
||||
(2, 0, 0), (3, 0, 2), (1, 0)],
|
||||
[dict(shape=shape, hermitian=hermitian)
|
||||
for shape in [(1, 1), (4, 4), (3, 10, 10), (2, 70, 7), (2000, 7),
|
||||
(7, 1000), (70, 7, 2), (2, 0, 0), (3, 0, 2), (1, 0)]
|
||||
for hermitian in ([False, True] if shape[-1] == shape[-2] else [False])],
|
||||
dtype=float_types + complex_types,
|
||||
)
|
||||
@jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1
|
||||
def testPinv(self, shape, dtype):
|
||||
def testPinv(self, shape, hermitian, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
||||
self._CheckAgainstNumpy(np.linalg.pinv, jnp.linalg.pinv, args_maker,
|
||||
tol=1e-2)
|
||||
self._CompileAndCheck(jnp.linalg.pinv, args_maker)
|
||||
jnp_fn = partial(jnp.linalg.pinv, hermitian=hermitian)
|
||||
def np_fn(a):
|
||||
# Symmetrize the input matrix to match the jnp behavior.
|
||||
if hermitian:
|
||||
a = (a + T(a.conj())) / 2
|
||||
return np.linalg.pinv(a, hermitian=hermitian)
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker)
|
||||
|
||||
# TODO(phawkins): 1e-1 seems like a very loose tolerance.
|
||||
jtu.check_grads(jnp.linalg.pinv, args_maker(), 2, rtol=1e-1, atol=2e-1)
|
||||
jtu.check_grads(jnp_fn, args_maker(), 1, rtol=3e-2, atol=1e-3)
|
||||
|
||||
def testPinvGradIssue2792(self):
|
||||
def f(p):
|
||||
|
Loading…
x
Reference in New Issue
Block a user