Merge pull request #22653 from rajasekharporeddy:testbranch1

PiperOrigin-RevId: 655980588
This commit is contained in:
jax authors 2024-07-25 09:29:07 -07:00
commit 9ea79c61f4

View File

@ -5072,13 +5072,17 @@ def tril_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array
@util.implements(np.triu_indices_from)
def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]:
arr_shape = shape(arr)
return triu_indices(arr_shape[-2], k=k, m=arr_shape[-1])
if len(arr_shape) != 2:
raise ValueError("Only 2-D inputs are accepted")
return triu_indices(arr_shape[0], k=k, m=arr_shape[1])
@util.implements(np.tril_indices_from)
def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]:
arr_shape = shape(arr)
return tril_indices(arr_shape[-2], k=k, m=arr_shape[-1])
if len(arr_shape) != 2:
raise ValueError("Only 2-D inputs are accepted")
return tril_indices(arr_shape[0], k=k, m=arr_shape[1])
@util.implements(np.fill_diagonal, lax_description="""