Merge pull request #26426 from NKlug:nklug-linalg-solve-docu

PiperOrigin-RevId: 725296564
This commit is contained in:
jax authors 2025-02-10 11:57:48 -08:00
commit 6a638ac832
2 changed files with 13 additions and 5 deletions

View File

@ -1300,21 +1300,25 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult:
@export
@jit
def solve(a: ArrayLike, b: ArrayLike) -> Array:
"""Solve a linear system of equations
"""Solve a linear system of equations.
JAX implementation of :func:`numpy.linalg.solve`.
This solves a (batched) linear system of equations ``a @ x = b``
for ``x`` given ``a`` and ``b``.
If ``a`` is singular, this will return ``nan`` or ``inf`` values.
Args:
a: array of shape ``(..., N, N)``.
b: array of shape ``(N,)`` (for 1-dimensional right-hand-side) or
``(..., N, M)`` (for batched 2-dimensional right-hand-side).
Returns:
An array containing the result of the linear solve. The result has shape ``(..., N)``
if ``b`` is of shape ``(N,)``, and has shape ``(..., N, M)`` otherwise.
An array containing the result of the linear solve if ``a`` is non-singular.
The result has shape ``(..., N)`` if ``b`` is of shape ``(N,)``, and has
shape ``(..., N, M)`` otherwise.
If ``a`` is singular, the result contains ``nan`` or ``inf`` values.
See also:
- :func:`jax.scipy.linalg.solve`: SciPy-style API for solving linear systems.

View File

@ -1016,13 +1016,15 @@ def _solve(a: ArrayLike, b: ArrayLike, assume_a: str, lower: bool) -> Array:
def solve(a: ArrayLike, b: ArrayLike, lower: bool = False,
overwrite_a: bool = False, overwrite_b: bool = False, debug: bool = False,
check_finite: bool = True, assume_a: str = 'gen') -> Array:
"""Solve a linear system of equations
"""Solve a linear system of equations.
JAX implementation of :func:`scipy.linalg.solve`.
This solves a (batched) linear system of equations ``a @ x = b`` for ``x``
given ``a`` and ``b``.
If ``a`` is singular, this will return ``nan`` or ``inf`` values.
Args:
a: array of shape ``(..., N, N)``.
b: array of shape ``(..., N)`` or ``(..., N, M)``
@ -1041,7 +1043,9 @@ def solve(a: ArrayLike, b: ArrayLike, lower: bool = False,
check_finite: unused by JAX
Returns:
An array of the same shape as ``b`` containing the solution to the linear system.
An array of the same shape as ``b`` containing the solution to the linear
system if ``a`` is non-singular.
If ``a`` is singular, the result contains ``nan`` or ``inf`` values.
See also:
- :func:`jax.scipy.linalg.lu_solve`: Solve via LU factorization.