mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add jax.numpy.fill_diagonal.
This commit is contained in:
parent
dfcbfc3915
commit
3cb504c583
@ -11,6 +11,7 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* New Features
|
||||
* Added {obj}`jax.typing.DTypeLike`, which can be used to annotate objects that
|
||||
are convertible to JAX dtypes.
|
||||
* Added `jax.numpy.fill_diagonal`.
|
||||
|
||||
* Changes
|
||||
* JAX now requires SciPy 1.9 or newer.
|
||||
|
@ -156,6 +156,7 @@ namespace; they are listed below.
|
||||
extract
|
||||
eye
|
||||
fabs
|
||||
fill_diagonal
|
||||
finfo
|
||||
fix
|
||||
flatnonzero
|
||||
|
@ -2813,6 +2813,34 @@ def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]:
|
||||
return tril_indices(arr_shape[-2], k=k, m=arr_shape[-1])
|
||||
|
||||
|
||||
@util._wraps(np.fill_diagonal, lax_description="""
|
||||
The semantics of :func:`numpy.fill_diagonal` is to modify arrays in-place, which
|
||||
JAX cannot do because JAX arrays are immutable. Thus :func:`jax.numpy.fill_diagonal`
|
||||
adds the ``inplace`` parameter, which must be set to ``False`` by the user as a
|
||||
reminder of this API difference.
|
||||
""", extra_params="""
|
||||
inplace : bool, default=True
|
||||
If left to its default value of True, JAX will raise an error. This is because
|
||||
the semantics of :func:`numpy.fill_diagonal` are to modify the array in-place,
|
||||
which is not possible in JAX due to the immutability of JAX arrays.
|
||||
""")
|
||||
def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = False, *, inplace: bool = True) -> Array:
|
||||
if inplace:
|
||||
raise NotImplementedError("JAX arrays are immutable, must use inplace=False")
|
||||
if wrap:
|
||||
raise NotImplementedError("wrap=True is not implemented, must use wrap=False")
|
||||
util.check_arraylike("fill_diagonal", a, val)
|
||||
a = asarray(a)
|
||||
val = asarray(val)
|
||||
if a.ndim < 2:
|
||||
raise ValueError("array must be at least 2-d")
|
||||
if a.ndim > 2 and not all(n == a.shape[0] for n in a.shape[1:]):
|
||||
raise ValueError("All dimensions of input must be of equal length")
|
||||
n = min(a.shape)
|
||||
idx = diag_indices(n, a.ndim)
|
||||
return a.at[idx].set(val if val.ndim == 0 else _tile_to_size(val.ravel(), n))
|
||||
|
||||
|
||||
@util._wraps(np.diag_indices)
|
||||
def diag_indices(n: int, ndim: int = 2) -> tuple[Array, ...]:
|
||||
n = core.concrete_or_error(operator.index, n, "'n' argument of jnp.diag_indices()")
|
||||
|
@ -96,6 +96,7 @@ from jax._src.numpy.lax_numpy import (
|
||||
expand_dims as expand_dims,
|
||||
extract as extract,
|
||||
eye as eye,
|
||||
fill_diagonal as fill_diagonal,
|
||||
finfo as finfo,
|
||||
fix as fix,
|
||||
flatnonzero as flatnonzero,
|
||||
|
@ -770,6 +770,7 @@ def tril_indices(
|
||||
n: int, k: int = ..., m: Optional[int] = ...
|
||||
) -> tuple[Array, Array]: ...
|
||||
def tril_indices_from(arr: ArrayLike, k: int = ...) -> tuple[Array, Array]: ...
|
||||
def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = ..., *, inplace: bool = ...) -> Array: ...
|
||||
def trim_zeros(filt: ArrayLike, trim: str = ...) -> Array: ...
|
||||
def triu(m: ArrayLike, k: int = ...) -> Array: ...
|
||||
def triu_indices(
|
||||
|
@ -2090,6 +2090,24 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(shape, dtype), k]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=default_dtypes,
|
||||
a_shape=[(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2), (1, 2), (0, 2), (2, 3), (2, 2, 2), (2, 2, 2, 2)],
|
||||
val_shape=[(), (1,), (2,), (1, 2), (3, 2)],
|
||||
)
|
||||
def testFillDiagonal(self, dtype, a_shape, val_shape):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
|
||||
def np_fun(a, val):
|
||||
a_copy = a.copy()
|
||||
np.fill_diagonal(a_copy, val)
|
||||
return a_copy
|
||||
|
||||
jnp_fun = partial(jnp.fill_diagonal, inplace=False)
|
||||
args_maker = lambda : [rng(a_shape, dtype), rng(val_shape, dtype)]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
ndim=[0, 1, 4],
|
||||
n=[0, 1, 7],
|
||||
@ -5350,6 +5368,7 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
'einsum': ['subscripts', 'precision'],
|
||||
'einsum_path': ['subscripts'],
|
||||
'take_along_axis': ['mode'],
|
||||
'fill_diagonal': ['inplace'],
|
||||
}
|
||||
|
||||
mismatches = {}
|
||||
|
Loading…
x
Reference in New Issue
Block a user