mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #22653 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 655980588
This commit is contained in:
commit
9ea79c61f4
@ -5072,13 +5072,17 @@ def tril_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array
|
||||
@util.implements(np.triu_indices_from)
|
||||
def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]:
|
||||
arr_shape = shape(arr)
|
||||
return triu_indices(arr_shape[-2], k=k, m=arr_shape[-1])
|
||||
if len(arr_shape) != 2:
|
||||
raise ValueError("Only 2-D inputs are accepted")
|
||||
return triu_indices(arr_shape[0], k=k, m=arr_shape[1])
|
||||
|
||||
|
||||
@util.implements(np.tril_indices_from)
|
||||
def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]:
|
||||
arr_shape = shape(arr)
|
||||
return tril_indices(arr_shape[-2], k=k, m=arr_shape[-1])
|
||||
if len(arr_shape) != 2:
|
||||
raise ValueError("Only 2-D inputs are accepted")
|
||||
return tril_indices(arr_shape[0], k=k, m=arr_shape[1])
|
||||
|
||||
|
||||
@util.implements(np.fill_diagonal, lax_description="""
|
||||
|
Loading…
x
Reference in New Issue
Block a user