From 3cb504c583317bc10da78d225e940be4e4381001 Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Fri, 20 Oct 2023 16:47:46 -0400 Subject: [PATCH] Add jax.numpy.fill_diagonal. --- CHANGELOG.md | 1 + docs/jax.numpy.rst | 1 + jax/_src/numpy/lax_numpy.py | 28 ++++++++++++++++++++++++++++ jax/numpy/__init__.py | 1 + jax/numpy/__init__.pyi | 1 + tests/lax_numpy_test.py | 19 +++++++++++++++++++ 6 files changed, 51 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a2bf4202..e3a5788d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ Remember to align the itemized text with the first line of an item within a list * New Features * Added {obj}`jax.typing.DTypeLike`, which can be used to annotate objects that are convertible to JAX dtypes. + * Added `jax.numpy.fill_diagonal`. * Changes * JAX now requires SciPy 1.9 or newer. diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index 4194baff1..e75204766 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -156,6 +156,7 @@ namespace; they are listed below. extract eye fabs + fill_diagonal finfo fix flatnonzero diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 850b44762..aebc293b6 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2813,6 +2813,34 @@ def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: return tril_indices(arr_shape[-2], k=k, m=arr_shape[-1]) +@util._wraps(np.fill_diagonal, lax_description=""" +The semantics of :func:`numpy.fill_diagonal` is to modify arrays in-place, which +JAX cannot do because JAX arrays are immutable. Thus :func:`jax.numpy.fill_diagonal` +adds the ``inplace`` parameter, which must be set to ``False`` by the user as a +reminder of this API difference. +""", extra_params=""" +inplace : bool, default=True + If left to its default value of True, JAX will raise an error. This is because + the semantics of :func:`numpy.fill_diagonal` are to modify the array in-place, + which is not possible in JAX due to the immutability of JAX arrays. +""") +def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = False, *, inplace: bool = True) -> Array: + if inplace: + raise NotImplementedError("JAX arrays are immutable, must use inplace=False") + if wrap: + raise NotImplementedError("wrap=True is not implemented, must use wrap=False") + util.check_arraylike("fill_diagonal", a, val) + a = asarray(a) + val = asarray(val) + if a.ndim < 2: + raise ValueError("array must be at least 2-d") + if a.ndim > 2 and not all(n == a.shape[0] for n in a.shape[1:]): + raise ValueError("All dimensions of input must be of equal length") + n = min(a.shape) + idx = diag_indices(n, a.ndim) + return a.at[idx].set(val if val.ndim == 0 else _tile_to_size(val.ravel(), n)) + + @util._wraps(np.diag_indices) def diag_indices(n: int, ndim: int = 2) -> tuple[Array, ...]: n = core.concrete_or_error(operator.index, n, "'n' argument of jnp.diag_indices()") diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 63839b26b..82786e4cf 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -96,6 +96,7 @@ from jax._src.numpy.lax_numpy import ( expand_dims as expand_dims, extract as extract, eye as eye, + fill_diagonal as fill_diagonal, finfo as finfo, fix as fix, flatnonzero as flatnonzero, diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 64f34af83..35b406ea0 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -770,6 +770,7 @@ def tril_indices( n: int, k: int = ..., m: Optional[int] = ... ) -> tuple[Array, Array]: ... def tril_indices_from(arr: ArrayLike, k: int = ...) -> tuple[Array, Array]: ... +def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = ..., *, inplace: bool = ...) -> Array: ... def trim_zeros(filt: ArrayLike, trim: str = ...) -> Array: ... def triu(m: ArrayLike, k: int = ...) -> Array: ... def triu_indices( diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 567cce8bf..5f598a203 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -2090,6 +2090,24 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = lambda: [rng(shape, dtype), k] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + @jtu.sample_product( + dtype=default_dtypes, + a_shape=[(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2), (1, 2), (0, 2), (2, 3), (2, 2, 2), (2, 2, 2, 2)], + val_shape=[(), (1,), (2,), (1, 2), (3, 2)], + ) + def testFillDiagonal(self, dtype, a_shape, val_shape): + rng = jtu.rand_default(self.rng()) + + def np_fun(a, val): + a_copy = a.copy() + np.fill_diagonal(a_copy, val) + return a_copy + + jnp_fun = partial(jnp.fill_diagonal, inplace=False) + args_maker = lambda : [rng(a_shape, dtype), rng(val_shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( ndim=[0, 1, 4], n=[0, 1, 7], @@ -5350,6 +5368,7 @@ class NumpySignaturesTest(jtu.JaxTestCase): 'einsum': ['subscripts', 'precision'], 'einsum_path': ['subscripts'], 'take_along_axis': ['mode'], + 'fill_diagonal': ['inplace'], } mismatches = {}