mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 21:36:05 +00:00
add pascal matrix
This commit is contained in:
parent
51560bf3f5
commit
81abbac536
@ -69,6 +69,7 @@ jax.scipy.linalg
|
||||
lu
|
||||
lu_factor
|
||||
lu_solve
|
||||
pascal
|
||||
polar
|
||||
qr
|
||||
rsf2csf
|
||||
|
@ -2182,3 +2182,64 @@ def hilbert(n: int) -> Array:
|
||||
"""
|
||||
a = lax.broadcasted_iota(jnp.float64, (n, 1), 0)
|
||||
return 1/(a + a.T + 1)
|
||||
|
||||
@partial(jit, static_argnames=("n", "kind",))
|
||||
def pascal(n: int, kind: str | None = None) -> Array:
|
||||
r"""Create a Pascal matrix approximation of order n.
|
||||
|
||||
JAX implementation of :func:`scipy.linalg.pascal`.
|
||||
|
||||
The elements of the Pascal matrix approximate the binomial coefficents. This
|
||||
implementation is not exact as JAX does not support exact factorials.
|
||||
|
||||
Args:
|
||||
n: the size of the matrix to create.
|
||||
kind: (optional) must be one of ``lower``, ``upper``, or ``symmetric`` (default).
|
||||
|
||||
Returns:
|
||||
A Pascal matrix of shape ``(n, n)``
|
||||
|
||||
Examples:
|
||||
>>> with jnp.printoptions(precision=3):
|
||||
... print(jax.scipy.linalg.pascal(3, kind="lower"))
|
||||
... print(jax.scipy.linalg.pascal(4, kind="upper"))
|
||||
... print(jax.scipy.linalg.pascal(5))
|
||||
[[1. 0. 0.]
|
||||
[1. 1. 0.]
|
||||
[1. 2. 1.]]
|
||||
[[1. 1. 1. 1.]
|
||||
[0. 1. 2. 3.]
|
||||
[0. 0. 1. 3.]
|
||||
[0. 0. 0. 1.]]
|
||||
[[ 1. 1. 1. 1. 1.]
|
||||
[ 1. 2. 3. 4. 5.]
|
||||
[ 1. 3. 6. 10. 15.]
|
||||
[ 1. 4. 10. 20. 35.]
|
||||
[ 1. 5. 15. 35. 70.]]
|
||||
"""
|
||||
if kind is None:
|
||||
kind = "symmetric"
|
||||
|
||||
valid_kind = ["symmetric", "lower", "upper"]
|
||||
|
||||
if kind not in valid_kind:
|
||||
raise ValueError(f"Expected kind to be on of: {valid_kind}; got {kind}")
|
||||
|
||||
a = jnp.arange(n, dtype=jnp.float32)
|
||||
|
||||
L_n = _binom(a[:, None], a[None, :])
|
||||
|
||||
if kind == "lower":
|
||||
return L_n
|
||||
|
||||
if kind == "upper":
|
||||
return L_n.T
|
||||
|
||||
return jnp.dot(L_n, L_n.T)
|
||||
|
||||
@jit
|
||||
def _binom(n, k):
|
||||
a = lax.lgamma(n + 1.0)
|
||||
b = lax.lgamma(n - k + 1.0)
|
||||
c = lax.lgamma(k + 1.0)
|
||||
return lax.exp(a - b - c)
|
||||
|
@ -31,6 +31,7 @@ from jax._src.scipy.linalg import (
|
||||
lu as lu,
|
||||
lu_factor as lu_factor,
|
||||
lu_solve as lu_solve,
|
||||
pascal as pascal,
|
||||
polar as polar,
|
||||
qr as qr,
|
||||
rsf2csf as rsf2csf,
|
||||
|
@ -2329,6 +2329,22 @@ class LaxLinalgTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(
|
||||
new_product_with_batching, old_product, atol=atol)
|
||||
|
||||
@jtu.sample_product(
|
||||
n=[0, 1, 5, 10, 20],
|
||||
kind=["symmetric", "lower", "upper"],
|
||||
)
|
||||
@jax.default_matmul_precision("float32")
|
||||
def testPascal(self, n, kind):
|
||||
args_maker = lambda: []
|
||||
osp_fun = partial(osp.linalg.pascal, n=n, kind=kind, exact=False)
|
||||
jsp_fun = partial(jsp.linalg.pascal, n=n, kind=kind)
|
||||
self._CheckAgainstNumpy(osp_fun,
|
||||
jsp_fun, args_maker,
|
||||
atol=1e-3,
|
||||
rtol=1e-2 if jtu.test_device_matches(['tpu']) else 1e-3,
|
||||
check_dtypes=False)
|
||||
self._CompileAndCheck(jsp_fun, args_maker)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user