Merge pull request #26128 from jakevdp:norm-doc

PiperOrigin-RevId: 720243405
This commit is contained in:
jax authors 2025-01-27 11:24:57 -08:00
commit 763ffb3f73

View File

@ -1085,7 +1085,8 @@ def norm(x: ArrayLike, ord: int | str | None = None,
ord: specify the kind of norm to take. Default is Frobenius norm for matrices,
and the 2-norm for vectors. For other options, see Notes below.
axis: integer or sequence of integers specifying the axes over which the norm
will be computed. Defaults to all axes of ``x``.
will be computed. For a single axis, compute a vector norm. For two axes,
compute a matrix norm. Defaults to all axes of ``x``.
keepdims: if True, the output array will have the same number of dimensions as
the input, with the size of reduced axes replaced by ``1`` (default: False).
@ -1113,6 +1114,9 @@ def norm(x: ArrayLike, ord: int | str | None = None,
- ``ord=2`` computes the 2-norm, i.e. the largest singular value
- ``ord=-2`` computes the smallest singular value
In the special case of ``ord=None`` and ``axis=None``, this function accepts an
array of any dimension and computes the vector 2-norm of the flattened array.
Examples:
Vector norms:
@ -1201,8 +1205,8 @@ def norm(x: ArrayLike, ord: int | str | None = None,
else:
raise ValueError(f"Invalid order '{ord}' for matrix norm.")
else:
raise ValueError(
f"Invalid axis values ({axis}) for jnp.linalg.norm.")
raise ValueError(f"Improper number of axes for norm: {axis=}. Pass one axis to"
" compute a vector-norm, or two axes to compute a matrix-norm.")
@overload
def qr(a: ArrayLike, mode: Literal["r"]) -> Array: ...