1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 21:36:05 +00:00

add pascal matrix

This commit is contained in:
Matt Bahr 2025-03-25 06:36:28 +00:00
parent 51560bf3f5
commit 81abbac536
4 changed files with 79 additions and 0 deletions

@ -69,6 +69,7 @@ jax.scipy.linalg
lu
lu_factor
lu_solve
pascal
polar
qr
rsf2csf

@ -2182,3 +2182,64 @@ def hilbert(n: int) -> Array:
"""
a = lax.broadcasted_iota(jnp.float64, (n, 1), 0)
return 1/(a + a.T + 1)
@partial(jit, static_argnames=("n", "kind",))
def pascal(n: int, kind: str | None = None) -> Array:
r"""Create a Pascal matrix approximation of order n.
JAX implementation of :func:`scipy.linalg.pascal`.
The elements of the Pascal matrix approximate the binomial coefficents. This
implementation is not exact as JAX does not support exact factorials.
Args:
n: the size of the matrix to create.
kind: (optional) must be one of ``lower``, ``upper``, or ``symmetric`` (default).
Returns:
A Pascal matrix of shape ``(n, n)``
Examples:
>>> with jnp.printoptions(precision=3):
... print(jax.scipy.linalg.pascal(3, kind="lower"))
... print(jax.scipy.linalg.pascal(4, kind="upper"))
... print(jax.scipy.linalg.pascal(5))
[[1. 0. 0.]
[1. 1. 0.]
[1. 2. 1.]]
[[1. 1. 1. 1.]
[0. 1. 2. 3.]
[0. 0. 1. 3.]
[0. 0. 0. 1.]]
[[ 1. 1. 1. 1. 1.]
[ 1. 2. 3. 4. 5.]
[ 1. 3. 6. 10. 15.]
[ 1. 4. 10. 20. 35.]
[ 1. 5. 15. 35. 70.]]
"""
if kind is None:
kind = "symmetric"
valid_kind = ["symmetric", "lower", "upper"]
if kind not in valid_kind:
raise ValueError(f"Expected kind to be on of: {valid_kind}; got {kind}")
a = jnp.arange(n, dtype=jnp.float32)
L_n = _binom(a[:, None], a[None, :])
if kind == "lower":
return L_n
if kind == "upper":
return L_n.T
return jnp.dot(L_n, L_n.T)
@jit
def _binom(n, k):
a = lax.lgamma(n + 1.0)
b = lax.lgamma(n - k + 1.0)
c = lax.lgamma(k + 1.0)
return lax.exp(a - b - c)

@ -31,6 +31,7 @@ from jax._src.scipy.linalg import (
lu as lu,
lu_factor as lu_factor,
lu_solve as lu_solve,
pascal as pascal,
polar as polar,
qr as qr,
rsf2csf as rsf2csf,

@ -2329,6 +2329,22 @@ class LaxLinalgTest(jtu.JaxTestCase):
self.assertAllClose(
new_product_with_batching, old_product, atol=atol)
@jtu.sample_product(
n=[0, 1, 5, 10, 20],
kind=["symmetric", "lower", "upper"],
)
@jax.default_matmul_precision("float32")
def testPascal(self, n, kind):
args_maker = lambda: []
osp_fun = partial(osp.linalg.pascal, n=n, kind=kind, exact=False)
jsp_fun = partial(jsp.linalg.pascal, n=n, kind=kind)
self._CheckAgainstNumpy(osp_fun,
jsp_fun, args_maker,
atol=1e-3,
rtol=1e-2 if jtu.test_device_matches(['tpu']) else 1e-3,
check_dtypes=False)
self._CompileAndCheck(jsp_fun, args_maker)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())