mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #26426 from NKlug:nklug-linalg-solve-docu
PiperOrigin-RevId: 725296564
This commit is contained in:
commit
6a638ac832
@ -1300,21 +1300,25 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult:
|
||||
@export
|
||||
@jit
|
||||
def solve(a: ArrayLike, b: ArrayLike) -> Array:
|
||||
"""Solve a linear system of equations
|
||||
"""Solve a linear system of equations.
|
||||
|
||||
JAX implementation of :func:`numpy.linalg.solve`.
|
||||
|
||||
This solves a (batched) linear system of equations ``a @ x = b``
|
||||
for ``x`` given ``a`` and ``b``.
|
||||
|
||||
If ``a`` is singular, this will return ``nan`` or ``inf`` values.
|
||||
|
||||
Args:
|
||||
a: array of shape ``(..., N, N)``.
|
||||
b: array of shape ``(N,)`` (for 1-dimensional right-hand-side) or
|
||||
``(..., N, M)`` (for batched 2-dimensional right-hand-side).
|
||||
|
||||
Returns:
|
||||
An array containing the result of the linear solve. The result has shape ``(..., N)``
|
||||
if ``b`` is of shape ``(N,)``, and has shape ``(..., N, M)`` otherwise.
|
||||
An array containing the result of the linear solve if ``a`` is non-singular.
|
||||
The result has shape ``(..., N)`` if ``b`` is of shape ``(N,)``, and has
|
||||
shape ``(..., N, M)`` otherwise.
|
||||
If ``a`` is singular, the result contains ``nan`` or ``inf`` values.
|
||||
|
||||
See also:
|
||||
- :func:`jax.scipy.linalg.solve`: SciPy-style API for solving linear systems.
|
||||
|
@ -1016,13 +1016,15 @@ def _solve(a: ArrayLike, b: ArrayLike, assume_a: str, lower: bool) -> Array:
|
||||
def solve(a: ArrayLike, b: ArrayLike, lower: bool = False,
|
||||
overwrite_a: bool = False, overwrite_b: bool = False, debug: bool = False,
|
||||
check_finite: bool = True, assume_a: str = 'gen') -> Array:
|
||||
"""Solve a linear system of equations
|
||||
"""Solve a linear system of equations.
|
||||
|
||||
JAX implementation of :func:`scipy.linalg.solve`.
|
||||
|
||||
This solves a (batched) linear system of equations ``a @ x = b`` for ``x``
|
||||
given ``a`` and ``b``.
|
||||
|
||||
If ``a`` is singular, this will return ``nan`` or ``inf`` values.
|
||||
|
||||
Args:
|
||||
a: array of shape ``(..., N, N)``.
|
||||
b: array of shape ``(..., N)`` or ``(..., N, M)``
|
||||
@ -1041,7 +1043,9 @@ def solve(a: ArrayLike, b: ArrayLike, lower: bool = False,
|
||||
check_finite: unused by JAX
|
||||
|
||||
Returns:
|
||||
An array of the same shape as ``b`` containing the solution to the linear system.
|
||||
An array of the same shape as ``b`` containing the solution to the linear
|
||||
system if ``a`` is non-singular.
|
||||
If ``a`` is singular, the result contains ``nan`` or ``inf`` values.
|
||||
|
||||
See also:
|
||||
- :func:`jax.scipy.linalg.lu_solve`: Solve via LU factorization.
|
||||
|
Loading…
x
Reference in New Issue
Block a user