Add Hilbert matrix to jax.scipy.linalg

This commit is contained in:
rajasekharporeddy 2024-03-20 22:55:03 +05:30
parent c5869feb92
commit 4d6a53fb63
4 changed files with 18 additions and 0 deletions

View File

@ -43,6 +43,7 @@ jax.scipy.linalg
expm_frechet
funm
hessenberg
hilbert
inv
lu
lu_factor

View File

@ -1035,3 +1035,9 @@ def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array:
(nrows,), (1,), 'VALID', dimension_numbers=('NTC', 'IOT', 'NTC'),
precision=lax.Precision.HIGHEST)[0]
return jnp.flip(patches, axis=0)
@implements(scipy.linalg.hilbert)
@partial(jit, static_argnames=("n",))
def hilbert(n: int) -> Array:
a = lax.broadcasted_iota(jnp.float64, (n, 1), 0)
return 1/(a + a.T + 1)

View File

@ -26,6 +26,7 @@ from jax._src.scipy.linalg import (
expm as expm,
expm_frechet as expm_frechet,
hessenberg as hessenberg,
hilbert as hilbert,
inv as inv,
lu as lu,
lu_factor as lu_factor,

View File

@ -2085,6 +2085,16 @@ class LaxLinalgTest(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
n=[0, 1, 5, 10, 20],
)
def testHilbert(self, n):
args_maker = lambda: []
osp_fun = partial(osp.linalg.hilbert, n=n)
jsp_fun = partial(jsp.linalg.hilbert, n=n)
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker)
self._CompileAndCheck(jsp_fun, args_maker)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())