Merge pull request #19606 from jakevdp:cholesky-upper

PiperOrigin-RevId: 603172800
This commit is contained in:
jax authors 2024-01-31 15:13:02 -08:00
commit 44a7d022f8
4 changed files with 22 additions and 7 deletions

View File

@ -34,7 +34,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install .[cpu]
python -m pip install .[ci]
python -m pip install -r array-api-tests/requirements.txt
- name: Run the test suite
env:

View File

@ -64,10 +64,12 @@ def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2
@implements(np.linalg.cholesky)
@jit
def cholesky(a: ArrayLike) -> Array:
@partial(jit, static_argnames=['upper'])
def cholesky(a: ArrayLike, *, upper: bool = False) -> Array:
check_arraylike("jnp.linalg.cholesky", a)
a, = promote_dtypes_inexact(jnp.asarray(a))
if upper:
a = jax.numpy.matrix_transpose(a)
return lax_linalg.cholesky(a)
@overload

View File

@ -21,7 +21,7 @@ def cholesky(x, /, *, upper=False):
"""
Returns the lower (upper) Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix x.
"""
return jax.numpy.linalg.cholesky(jax.numpy.matrix_transpose(x) if upper else x)
return jax.numpy.linalg.cholesky(x, upper=upper)
def cross(x1, x2, /, *, axis=-1):
"""

View File

@ -59,17 +59,30 @@ class NumpyLinalgTest(jtu.JaxTestCase):
@jtu.sample_product(
shape=[(1, 1), (4, 4), (2, 5, 5), (200, 200), (1000, 0, 0)],
dtype=float_types + complex_types,
upper=[True, False]
)
def testCholesky(self, shape, dtype):
def testCholesky(self, shape, dtype, upper):
rng = jtu.rand_default(self.rng())
def args_maker():
factor_shape = shape[:-1] + (2 * shape[-1],)
a = rng(factor_shape, dtype)
return [np.matmul(a, jnp.conj(T(a)))]
self._CheckAgainstNumpy(np.linalg.cholesky, jnp.linalg.cholesky, args_maker,
jnp_fun = partial(jnp.linalg.cholesky, upper=upper)
def np_fun(x, upper=upper):
# Upper argument added in NumPy 2.0.0
if jtu.numpy_version() >= (2, 0, 0):
return np.linalg.cholesky(x, upper=upper)
if upper:
axes = list(range(x.ndim))
axes[-1], axes[-2] = axes[-2], axes[-1]
x = np.transpose(x, axes)
return np.linalg.cholesky(x)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
tol=1e-3)
self._CompileAndCheck(jnp.linalg.cholesky, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
if jnp.finfo(dtype).bits == 64:
jtu.check_grads(jnp.linalg.cholesky, args_maker(), order=2)