mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add jax.numpy.fill_diagonal.
This commit is contained in:
parent
dfcbfc3915
commit
3cb504c583
@ -11,6 +11,7 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
* New Features
|
* New Features
|
||||||
* Added {obj}`jax.typing.DTypeLike`, which can be used to annotate objects that
|
* Added {obj}`jax.typing.DTypeLike`, which can be used to annotate objects that
|
||||||
are convertible to JAX dtypes.
|
are convertible to JAX dtypes.
|
||||||
|
* Added `jax.numpy.fill_diagonal`.
|
||||||
|
|
||||||
* Changes
|
* Changes
|
||||||
* JAX now requires SciPy 1.9 or newer.
|
* JAX now requires SciPy 1.9 or newer.
|
||||||
|
@ -156,6 +156,7 @@ namespace; they are listed below.
|
|||||||
extract
|
extract
|
||||||
eye
|
eye
|
||||||
fabs
|
fabs
|
||||||
|
fill_diagonal
|
||||||
finfo
|
finfo
|
||||||
fix
|
fix
|
||||||
flatnonzero
|
flatnonzero
|
||||||
|
@ -2813,6 +2813,34 @@ def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]:
|
|||||||
return tril_indices(arr_shape[-2], k=k, m=arr_shape[-1])
|
return tril_indices(arr_shape[-2], k=k, m=arr_shape[-1])
|
||||||
|
|
||||||
|
|
||||||
|
@util._wraps(np.fill_diagonal, lax_description="""
|
||||||
|
The semantics of :func:`numpy.fill_diagonal` is to modify arrays in-place, which
|
||||||
|
JAX cannot do because JAX arrays are immutable. Thus :func:`jax.numpy.fill_diagonal`
|
||||||
|
adds the ``inplace`` parameter, which must be set to ``False`` by the user as a
|
||||||
|
reminder of this API difference.
|
||||||
|
""", extra_params="""
|
||||||
|
inplace : bool, default=True
|
||||||
|
If left to its default value of True, JAX will raise an error. This is because
|
||||||
|
the semantics of :func:`numpy.fill_diagonal` are to modify the array in-place,
|
||||||
|
which is not possible in JAX due to the immutability of JAX arrays.
|
||||||
|
""")
|
||||||
|
def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = False, *, inplace: bool = True) -> Array:
|
||||||
|
if inplace:
|
||||||
|
raise NotImplementedError("JAX arrays are immutable, must use inplace=False")
|
||||||
|
if wrap:
|
||||||
|
raise NotImplementedError("wrap=True is not implemented, must use wrap=False")
|
||||||
|
util.check_arraylike("fill_diagonal", a, val)
|
||||||
|
a = asarray(a)
|
||||||
|
val = asarray(val)
|
||||||
|
if a.ndim < 2:
|
||||||
|
raise ValueError("array must be at least 2-d")
|
||||||
|
if a.ndim > 2 and not all(n == a.shape[0] for n in a.shape[1:]):
|
||||||
|
raise ValueError("All dimensions of input must be of equal length")
|
||||||
|
n = min(a.shape)
|
||||||
|
idx = diag_indices(n, a.ndim)
|
||||||
|
return a.at[idx].set(val if val.ndim == 0 else _tile_to_size(val.ravel(), n))
|
||||||
|
|
||||||
|
|
||||||
@util._wraps(np.diag_indices)
|
@util._wraps(np.diag_indices)
|
||||||
def diag_indices(n: int, ndim: int = 2) -> tuple[Array, ...]:
|
def diag_indices(n: int, ndim: int = 2) -> tuple[Array, ...]:
|
||||||
n = core.concrete_or_error(operator.index, n, "'n' argument of jnp.diag_indices()")
|
n = core.concrete_or_error(operator.index, n, "'n' argument of jnp.diag_indices()")
|
||||||
|
@ -96,6 +96,7 @@ from jax._src.numpy.lax_numpy import (
|
|||||||
expand_dims as expand_dims,
|
expand_dims as expand_dims,
|
||||||
extract as extract,
|
extract as extract,
|
||||||
eye as eye,
|
eye as eye,
|
||||||
|
fill_diagonal as fill_diagonal,
|
||||||
finfo as finfo,
|
finfo as finfo,
|
||||||
fix as fix,
|
fix as fix,
|
||||||
flatnonzero as flatnonzero,
|
flatnonzero as flatnonzero,
|
||||||
|
@ -770,6 +770,7 @@ def tril_indices(
|
|||||||
n: int, k: int = ..., m: Optional[int] = ...
|
n: int, k: int = ..., m: Optional[int] = ...
|
||||||
) -> tuple[Array, Array]: ...
|
) -> tuple[Array, Array]: ...
|
||||||
def tril_indices_from(arr: ArrayLike, k: int = ...) -> tuple[Array, Array]: ...
|
def tril_indices_from(arr: ArrayLike, k: int = ...) -> tuple[Array, Array]: ...
|
||||||
|
def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = ..., *, inplace: bool = ...) -> Array: ...
|
||||||
def trim_zeros(filt: ArrayLike, trim: str = ...) -> Array: ...
|
def trim_zeros(filt: ArrayLike, trim: str = ...) -> Array: ...
|
||||||
def triu(m: ArrayLike, k: int = ...) -> Array: ...
|
def triu(m: ArrayLike, k: int = ...) -> Array: ...
|
||||||
def triu_indices(
|
def triu_indices(
|
||||||
|
@ -2090,6 +2090,24 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
args_maker = lambda: [rng(shape, dtype), k]
|
args_maker = lambda: [rng(shape, dtype), k]
|
||||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||||
|
|
||||||
|
@jtu.sample_product(
|
||||||
|
dtype=default_dtypes,
|
||||||
|
a_shape=[(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2), (1, 2), (0, 2), (2, 3), (2, 2, 2), (2, 2, 2, 2)],
|
||||||
|
val_shape=[(), (1,), (2,), (1, 2), (3, 2)],
|
||||||
|
)
|
||||||
|
def testFillDiagonal(self, dtype, a_shape, val_shape):
|
||||||
|
rng = jtu.rand_default(self.rng())
|
||||||
|
|
||||||
|
def np_fun(a, val):
|
||||||
|
a_copy = a.copy()
|
||||||
|
np.fill_diagonal(a_copy, val)
|
||||||
|
return a_copy
|
||||||
|
|
||||||
|
jnp_fun = partial(jnp.fill_diagonal, inplace=False)
|
||||||
|
args_maker = lambda : [rng(a_shape, dtype), rng(val_shape, dtype)]
|
||||||
|
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||||
|
self._CompileAndCheck(jnp_fun, args_maker)
|
||||||
|
|
||||||
@jtu.sample_product(
|
@jtu.sample_product(
|
||||||
ndim=[0, 1, 4],
|
ndim=[0, 1, 4],
|
||||||
n=[0, 1, 7],
|
n=[0, 1, 7],
|
||||||
@ -5350,6 +5368,7 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
|||||||
'einsum': ['subscripts', 'precision'],
|
'einsum': ['subscripts', 'precision'],
|
||||||
'einsum_path': ['subscripts'],
|
'einsum_path': ['subscripts'],
|
||||||
'take_along_axis': ['mode'],
|
'take_along_axis': ['mode'],
|
||||||
|
'fill_diagonal': ['inplace'],
|
||||||
}
|
}
|
||||||
|
|
||||||
mismatches = {}
|
mismatches = {}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user