Add jax.numpy.fill_diagonal.

This commit is contained in:
carlosgmartin 2023-10-20 16:47:46 -04:00
parent dfcbfc3915
commit 3cb504c583
6 changed files with 51 additions and 0 deletions

View File

@ -11,6 +11,7 @@ Remember to align the itemized text with the first line of an item within a list
* New Features
* Added {obj}`jax.typing.DTypeLike`, which can be used to annotate objects that
are convertible to JAX dtypes.
* Added `jax.numpy.fill_diagonal`.
* Changes
* JAX now requires SciPy 1.9 or newer.

View File

@ -156,6 +156,7 @@ namespace; they are listed below.
extract
eye
fabs
fill_diagonal
finfo
fix
flatnonzero

View File

@ -2813,6 +2813,34 @@ def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]:
return tril_indices(arr_shape[-2], k=k, m=arr_shape[-1])
@util._wraps(np.fill_diagonal, lax_description="""
The semantics of :func:`numpy.fill_diagonal` is to modify arrays in-place, which
JAX cannot do because JAX arrays are immutable. Thus :func:`jax.numpy.fill_diagonal`
adds the ``inplace`` parameter, which must be set to ``False`` by the user as a
reminder of this API difference.
""", extra_params="""
inplace : bool, default=True
If left to its default value of True, JAX will raise an error. This is because
the semantics of :func:`numpy.fill_diagonal` are to modify the array in-place,
which is not possible in JAX due to the immutability of JAX arrays.
""")
def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = False, *, inplace: bool = True) -> Array:
if inplace:
raise NotImplementedError("JAX arrays are immutable, must use inplace=False")
if wrap:
raise NotImplementedError("wrap=True is not implemented, must use wrap=False")
util.check_arraylike("fill_diagonal", a, val)
a = asarray(a)
val = asarray(val)
if a.ndim < 2:
raise ValueError("array must be at least 2-d")
if a.ndim > 2 and not all(n == a.shape[0] for n in a.shape[1:]):
raise ValueError("All dimensions of input must be of equal length")
n = min(a.shape)
idx = diag_indices(n, a.ndim)
return a.at[idx].set(val if val.ndim == 0 else _tile_to_size(val.ravel(), n))
@util._wraps(np.diag_indices)
def diag_indices(n: int, ndim: int = 2) -> tuple[Array, ...]:
n = core.concrete_or_error(operator.index, n, "'n' argument of jnp.diag_indices()")

View File

@ -96,6 +96,7 @@ from jax._src.numpy.lax_numpy import (
expand_dims as expand_dims,
extract as extract,
eye as eye,
fill_diagonal as fill_diagonal,
finfo as finfo,
fix as fix,
flatnonzero as flatnonzero,

View File

@ -770,6 +770,7 @@ def tril_indices(
n: int, k: int = ..., m: Optional[int] = ...
) -> tuple[Array, Array]: ...
def tril_indices_from(arr: ArrayLike, k: int = ...) -> tuple[Array, Array]: ...
def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = ..., *, inplace: bool = ...) -> Array: ...
def trim_zeros(filt: ArrayLike, trim: str = ...) -> Array: ...
def triu(m: ArrayLike, k: int = ...) -> Array: ...
def triu_indices(

View File

@ -2090,6 +2090,24 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype), k]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@jtu.sample_product(
dtype=default_dtypes,
a_shape=[(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2), (1, 2), (0, 2), (2, 3), (2, 2, 2), (2, 2, 2, 2)],
val_shape=[(), (1,), (2,), (1, 2), (3, 2)],
)
def testFillDiagonal(self, dtype, a_shape, val_shape):
rng = jtu.rand_default(self.rng())
def np_fun(a, val):
a_copy = a.copy()
np.fill_diagonal(a_copy, val)
return a_copy
jnp_fun = partial(jnp.fill_diagonal, inplace=False)
args_maker = lambda : [rng(a_shape, dtype), rng(val_shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
ndim=[0, 1, 4],
n=[0, 1, 7],
@ -5350,6 +5368,7 @@ class NumpySignaturesTest(jtu.JaxTestCase):
'einsum': ['subscripts', 'precision'],
'einsum_path': ['subscripts'],
'take_along_axis': ['mode'],
'fill_diagonal': ['inplace'],
}
mismatches = {}