mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add Hilbert matrix to jax.scipy.linalg
This commit is contained in:
parent
c5869feb92
commit
4d6a53fb63
@ -43,6 +43,7 @@ jax.scipy.linalg
|
||||
expm_frechet
|
||||
funm
|
||||
hessenberg
|
||||
hilbert
|
||||
inv
|
||||
lu
|
||||
lu_factor
|
||||
|
@ -1035,3 +1035,9 @@ def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array:
|
||||
(nrows,), (1,), 'VALID', dimension_numbers=('NTC', 'IOT', 'NTC'),
|
||||
precision=lax.Precision.HIGHEST)[0]
|
||||
return jnp.flip(patches, axis=0)
|
||||
|
||||
@implements(scipy.linalg.hilbert)
|
||||
@partial(jit, static_argnames=("n",))
|
||||
def hilbert(n: int) -> Array:
|
||||
a = lax.broadcasted_iota(jnp.float64, (n, 1), 0)
|
||||
return 1/(a + a.T + 1)
|
||||
|
@ -26,6 +26,7 @@ from jax._src.scipy.linalg import (
|
||||
expm as expm,
|
||||
expm_frechet as expm_frechet,
|
||||
hessenberg as hessenberg,
|
||||
hilbert as hilbert,
|
||||
inv as inv,
|
||||
lu as lu,
|
||||
lu_factor as lu_factor,
|
||||
|
@ -2085,6 +2085,16 @@ class LaxLinalgTest(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
n=[0, 1, 5, 10, 20],
|
||||
)
|
||||
def testHilbert(self, n):
|
||||
args_maker = lambda: []
|
||||
osp_fun = partial(osp.linalg.hilbert, n=n)
|
||||
jsp_fun = partial(jsp.linalg.hilbert, n=n)
|
||||
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker)
|
||||
self._CompileAndCheck(jsp_fun, args_maker)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user