mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #19606 from jakevdp:cholesky-upper
PiperOrigin-RevId: 603172800
This commit is contained in:
commit
44a7d022f8
2
.github/workflows/jax-array-api.yml
vendored
2
.github/workflows/jax-array-api.yml
vendored
@ -34,7 +34,7 @@ jobs:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install .[cpu]
|
||||
python -m pip install .[ci]
|
||||
python -m pip install -r array-api-tests/requirements.txt
|
||||
- name: Run the test suite
|
||||
env:
|
||||
|
@ -64,10 +64,12 @@ def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2
|
||||
|
||||
|
||||
@implements(np.linalg.cholesky)
|
||||
@jit
|
||||
def cholesky(a: ArrayLike) -> Array:
|
||||
@partial(jit, static_argnames=['upper'])
|
||||
def cholesky(a: ArrayLike, *, upper: bool = False) -> Array:
|
||||
check_arraylike("jnp.linalg.cholesky", a)
|
||||
a, = promote_dtypes_inexact(jnp.asarray(a))
|
||||
if upper:
|
||||
a = jax.numpy.matrix_transpose(a)
|
||||
return lax_linalg.cholesky(a)
|
||||
|
||||
@overload
|
||||
|
@ -21,7 +21,7 @@ def cholesky(x, /, *, upper=False):
|
||||
"""
|
||||
Returns the lower (upper) Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix x.
|
||||
"""
|
||||
return jax.numpy.linalg.cholesky(jax.numpy.matrix_transpose(x) if upper else x)
|
||||
return jax.numpy.linalg.cholesky(x, upper=upper)
|
||||
|
||||
def cross(x1, x2, /, *, axis=-1):
|
||||
"""
|
||||
|
@ -59,17 +59,30 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
@jtu.sample_product(
|
||||
shape=[(1, 1), (4, 4), (2, 5, 5), (200, 200), (1000, 0, 0)],
|
||||
dtype=float_types + complex_types,
|
||||
upper=[True, False]
|
||||
)
|
||||
def testCholesky(self, shape, dtype):
|
||||
def testCholesky(self, shape, dtype, upper):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
def args_maker():
|
||||
factor_shape = shape[:-1] + (2 * shape[-1],)
|
||||
a = rng(factor_shape, dtype)
|
||||
return [np.matmul(a, jnp.conj(T(a)))]
|
||||
|
||||
self._CheckAgainstNumpy(np.linalg.cholesky, jnp.linalg.cholesky, args_maker,
|
||||
jnp_fun = partial(jnp.linalg.cholesky, upper=upper)
|
||||
|
||||
def np_fun(x, upper=upper):
|
||||
# Upper argument added in NumPy 2.0.0
|
||||
if jtu.numpy_version() >= (2, 0, 0):
|
||||
return np.linalg.cholesky(x, upper=upper)
|
||||
if upper:
|
||||
axes = list(range(x.ndim))
|
||||
axes[-1], axes[-2] = axes[-2], axes[-1]
|
||||
x = np.transpose(x, axes)
|
||||
return np.linalg.cholesky(x)
|
||||
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(jnp.linalg.cholesky, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
if jnp.finfo(dtype).bits == 64:
|
||||
jtu.check_grads(jnp.linalg.cholesky, args_maker(), order=2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user