mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[JAX] Add an option subset_by_index that allows computing a contiguous subset of singular components from svd.
PiperOrigin-RevId: 607493941
parent 0203d15485
commit 7e7094c82d
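In user-facing terms, the new keyword threads from `jnp.linalg.svd` and `lax.linalg.svd` down to the `svd_p` primitive. A minimal usage sketch (example shapes are mine, not from the patch; at this revision the subset path is only implemented on TPU, per the lowering rules below):

    import jax.numpy as jnp

    a = jnp.ones((256, 128))
    # Compute only the two largest singular components (indices [0, 2)).
    u, s, vt = jnp.linalg.svd(a, full_matrices=False, subset_by_index=(0, 2))
    print(u.shape, s.shape, vt.shape)  # (256, 2) (2,) (2, 128)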
@@ -1800,25 +1800,35 @@ for dtype in jtu.dtypes.all_floating + jtu.dtypes.complex:
   for shape in [(2, 2), (2, 7), (29, 29), (2, 3, 53), (2, 3, 29, 7)]:
     for full_matrices in [False, True]:
       for compute_uv in [False, True]:
+        subset_by_index = None
         define(
             lax.linalg.svd_p,
             f"shape={jtu.format_shape_dtype_string(shape, dtype)}_fullmatrices={full_matrices}_computeuv={compute_uv}",
             lambda *args: lax.linalg.svd_p.bind(
-                args[0], full_matrices=args[1], compute_uv=args[2]), [
-                    RandArg(shape, dtype),
-                    StaticArg(full_matrices),
-                    StaticArg(compute_uv)
-                ],
+                args[0],
+                full_matrices=args[1],
+                compute_uv=args[2],
+                subset_by_index=args[3],
+            ),
+            [
+                RandArg(shape, dtype),
+                StaticArg(full_matrices),
+                StaticArg(compute_uv),
+                StaticArg(subset_by_index),
+            ],
             jax_unimplemented=[
                 Limitation(
                     "unimplemented",
                     devices=("cpu", "gpu"),
-                    dtypes=[np.float16, dtypes.bfloat16]),
+                    dtypes=[np.float16, dtypes.bfloat16],
+                ),
             ],
             shape=shape,
             dtype=dtype,
             full_matrices=full_matrices,
-            compute_uv=compute_uv)
+            compute_uv=compute_uv,
+            subset_by_index=subset_by_index,
+        )
 
 for dtype in jtu.dtypes.all_inexact:
   for shape in [(0, 0), (5, 5), (2, 6, 6)]:
@@ -2666,7 +2676,6 @@ for dtype in (np.float32, np.float64):
       dtype=dtype)
 
 
-
 def wrap_and_split():
   key = jax.random.key(42)
   result = jax.random.split(key, 2)
@@ -298,25 +298,61 @@ def qr(x: ArrayLike, *, full_matrices: bool = True) -> tuple[Array, Array]:
   q, r = qr_p.bind(x, full_matrices=full_matrices)
   return q, r
 
 
 @overload
-def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: Literal[True]) -> tuple[Array, Array, Array]: ...
+def svd(
+    x: ArrayLike,
+    *,
+    full_matrices: bool = True,
+    compute_uv: Literal[True],
+    subset_by_index: tuple[int, int] | None = None,
+) -> tuple[Array, Array, Array]:
+  ...
 
 
 @overload
-def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: Literal[False]) -> Array: ...
+def svd(
+    x: ArrayLike,
+    *,
+    full_matrices: bool = True,
+    compute_uv: Literal[False],
+    subset_by_index: tuple[int, int] | None = None,
+) -> Array:
+  ...
 
 
 @overload
-def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: bool = True) -> Array | tuple[Array, Array, Array]: ...
+def svd(
+    x: ArrayLike,
+    *,
+    full_matrices: bool = True,
+    compute_uv: bool = True,
+    subset_by_index: tuple[int, int] | None = None,
+) -> Array | tuple[Array, Array, Array]:
+  ...
 
 
 # TODO: Add `max_qdwh_iterations` to the function signature for TPU SVD.
 @_warn_on_positional_kwargs
-def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: bool = True) -> Array | tuple[Array, Array, Array]:
+def svd(
+    x: ArrayLike,
+    *,
+    full_matrices: bool = True,
+    compute_uv: bool = True,
+    subset_by_index: tuple[int, int] | None = None,
+) -> Array | tuple[Array, Array, Array]:
   """Singular value decomposition.
 
   Returns the singular values if compute_uv is False, otherwise returns a triple
   containing the left singular vectors, the singular values and the adjoint of
   the right singular vectors.
   """
-  result = svd_p.bind(x, full_matrices=full_matrices, compute_uv=compute_uv)
+  result = svd_p.bind(
+      x,
+      full_matrices=full_matrices,
+      compute_uv=compute_uv,
+      subset_by_index=subset_by_index,
+  )
   if compute_uv:
     s, u, v = result
     return u, s, v
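As the docstring above says, the wrapper's return type depends on `compute_uv`. A quick sketch of the two call forms (example shapes are mine):

    import jax.numpy as jnp
    from jax import lax

    a = jnp.ones((4, 3))
    s = lax.linalg.svd(a, full_matrices=False, compute_uv=False)  # singular values only
    u, s, vt = lax.linalg.svd(a, full_matrices=False)             # (u, s, v^H) triple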
@@ -324,6 +360,7 @@ def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: bool = True) ->
   s, = result
   return s
 
+
 @_warn_on_positional_kwargs
 def triangular_solve(a: ArrayLike, b: ArrayLike, *,
                      left_side: bool = False, lower: bool = False,
@@ -1043,7 +1080,6 @@ mlir.register_lowering(triangular_solve_p, _triangular_solve_cpu_lower,
 # Support operation for LU decomposition: Transformation of the pivots returned
 # by LU decomposition into permutations.
 
-
 # Define this outside lu_pivots_to_permutation to ensure fori_loop cache hits
 def _lu_pivots_body_fn(i, permutation_and_swaps):
   permutation, swaps = permutation_and_swaps
@@ -1138,7 +1174,6 @@ mlir.register_lowering(
         gpu_linalg.hip_lu_pivots_to_permutation),
     platform='rocm')
 
-
 # LU decomposition
 
 # Computes a pivoted LU decomposition such that
@@ -1745,35 +1780,50 @@ mlir.register_lowering(qr_p, mlir.lower_fun(_qr_lowering))
 
 
 # Singular value decomposition
-def _svd_impl(operand, *, full_matrices, compute_uv):
-  return dispatch.apply_primitive(svd_p, operand, full_matrices=full_matrices,
-                                  compute_uv=compute_uv)
+def _svd_impl(operand, *, full_matrices, compute_uv, subset_by_index=None):
+  return dispatch.apply_primitive(
+      svd_p,
+      operand,
+      full_matrices=full_matrices,
+      compute_uv=compute_uv,
+      subset_by_index=subset_by_index,
+  )
+
 
-def _svd_abstract_eval(operand, *, full_matrices, compute_uv):
+def _svd_abstract_eval(operand, *, full_matrices, compute_uv, subset_by_index):
   if isinstance(operand, ShapedArray):
     if operand.ndim < 2:
       raise ValueError("Argument to singular value decomposition must have ndims >= 2")
 
     batch_dims = operand.shape[:-2]
     m = operand.shape[-2]
     n = operand.shape[-1]
-    s = operand.update(shape=batch_dims + (min(m, n),),
-                       dtype=lax_internal._complex_basetype(operand.dtype))
+    rank = min(m, n)
+    if subset_by_index is not None:
+      if full_matrices and subset_by_index != (0, rank):
+        raise ValueError("full_matrices and subset_by_index cannot both be set")
+      rank = min(rank, subset_by_index[1] - subset_by_index[0])
+
+    s = operand.update(
+        shape=batch_dims + (rank,),
+        dtype=lax_internal._complex_basetype(operand.dtype),
+    )
     if compute_uv:
-      u = operand.update(shape=batch_dims + (m, m if full_matrices else min(m, n)))
-      vt = operand.update(shape=batch_dims + (n if full_matrices else min(m, n), n))
+      u = operand.update(shape=batch_dims + (m, m if full_matrices else rank))
+      vt = operand.update(shape=batch_dims + (n if full_matrices else rank, n))
       return s, u, vt
     else:
       return s,
   else:
     raise NotImplementedError
 
 
 @jax.default_matmul_precision("float32")
-def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv):
+def _svd_jvp_rule(
+    primals, tangents, *, full_matrices, compute_uv, subset_by_index
+):
   A, = primals
   dA, = tangents
-  s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)
+  s, U, Vt = svd_p.bind(
+      A, full_matrices=False, compute_uv=True, subset_by_index=subset_by_index
+  )
 
   if compute_uv and full_matrices:
     # TODO: implement full matrices case, documented here: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
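The abstract-eval change above carries the shape semantics: with `subset_by_index=(start, end)`, the primitive produces at most `end - start` singular components. A standalone sketch of that rule (a hypothetical helper written for illustration, not library code):

    # Mirrors the shape logic of _svd_abstract_eval above; illustration only.
    def svd_output_shapes(m, n, compute_uv, full_matrices, subset_by_index=None):
        rank = min(m, n)
        if subset_by_index is not None:
            rank = min(rank, subset_by_index[1] - subset_by_index[0])
        s_shape = (rank,)
        if not compute_uv:
            return (s_shape,)
        u_shape = (m, m if full_matrices else rank)
        vt_shape = (n if full_matrices else rank, n)
        return (s_shape, u_shape, vt_shape)

    assert svd_output_shapes(256, 128, True, False, (0, 2)) == ((2,), (256, 2), (2, 128))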
@@ -1812,6 +1862,7 @@ def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv):
 
   return (s, U, Vt), (ds, dU, _H(dV))
 
+
 def _empty_svd(a, *, full_matrices, compute_uv):
   batch_shape = a.shape[:-2]
   m, n = a.shape[-2:]
@@ -1828,8 +1879,17 @@ def _empty_svd(a, *, full_matrices, compute_uv):
     u, v = v, u
   return s, u, v
 
-def _svd_cpu_gpu_lowering(gesvd_impl, ctx, operand, *, full_matrices,
-                          compute_uv, platform: str):
+
+def _svd_cpu_gpu_lowering(
+    gesvd_impl,
+    ctx,
+    operand,
+    *,
+    full_matrices,
+    compute_uv,
+    subset_by_index,
+    platform: str,
+):
   operand_aval, = ctx.avals_in
   s_aval = ctx.avals_out[0]
   m, n = operand_aval.shape[-2:]
@@ -1841,9 +1901,16 @@ def _svd_cpu_gpu_lowering(gesvd_impl, ctx, operand, *, full_matrices,
         f"implemented only for the batch dimensions: {operand_aval.shape}")
   batch_dims = operand_aval.shape[:-2]
 
+  if not (subset_by_index is None or subset_by_index == (0, min(m, n))):
+    raise NotImplementedError("subset_by_index not implemented for CPU and GPU")
+
   if m == 0 or n == 0:
     return mlir.lower_fun(_empty_svd, multiple_results=True)(
-        ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv)
+        ctx,
+        operand,
+        full_matrices=full_matrices,
+        compute_uv=compute_uv,
+    )
 
   if platform in ["cuda", "rocm"]:
     if not is_constant_shape(operand_aval.shape):
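On CPU and GPU the gesvd-based lowerings always compute the full spectrum, so only `None` or the trivial range `(0, min(m, n))` is accepted there; anything narrower fails at lowering time. A sketch of what a caller would see (my example values, assuming a CPU backend):

    import jax
    import jax.numpy as jnp

    a = jnp.ones((8, 4))
    with jax.default_device(jax.devices("cpu")[0]):
        # Trivial subset equals (0, min(m, n)): accepted.
        jnp.linalg.svd(a, full_matrices=False, subset_by_index=(0, 4))
        try:
            jnp.linalg.svd(a, full_matrices=False, subset_by_index=(0, 2))
        except NotImplementedError as e:
            print(e)  # subset_by_index not implemented for CPU and GPU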
@@ -1891,10 +1958,16 @@ def _svd_cpu_gpu_lowering(gesvd_impl, ctx, operand, *, full_matrices,
 
   return result
 
-def _svd_tpu(a, *, full_matrices, compute_uv):
+
+def _svd_tpu(a, *, full_matrices, compute_uv, subset_by_index):
   batch_dims = a.shape[:-2]
 
-  fn = partial(lax_svd.svd, full_matrices=full_matrices, compute_uv=compute_uv)
+  fn = partial(
+      lax_svd.svd,
+      full_matrices=full_matrices,
+      compute_uv=compute_uv,
+      subset_by_index=subset_by_index,
+  )
   for _ in range(len(batch_dims)):
     fn = api.vmap(fn)
 
@@ -1905,28 +1978,49 @@ def _svd_tpu(a, *, full_matrices, compute_uv):
     s = fn(a)
   return [s]
 
-def _svd_tpu_lowering_rule(ctx, operand, *, full_matrices, compute_uv):
+
+def _svd_tpu_lowering_rule(
+    ctx, operand, *, full_matrices, compute_uv, subset_by_index
+):
   operand_aval, = ctx.avals_in
   m, n = operand_aval.shape[-2:]
 
   if m == 0 or n == 0:
     return mlir.lower_fun(_empty_svd, multiple_results=True)(
-        ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv)
+        ctx,
+        operand,
+        full_matrices=full_matrices,
+        compute_uv=compute_uv,
+    )
 
   return mlir.lower_fun(_svd_tpu, multiple_results=True)(
-      ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv)
+      ctx,
+      operand,
+      full_matrices=full_matrices,
+      compute_uv=compute_uv,
+      subset_by_index=subset_by_index,
+  )
 
-def _svd_batching_rule(batched_args, batch_dims, *, full_matrices, compute_uv):
+
+def _svd_batching_rule(
+    batched_args, batch_dims, *, full_matrices, compute_uv, subset_by_index
+):
   x, = batched_args
   bd, = batch_dims
   x = batching.moveaxis(x, bd, 0)
-  outs = svd_p.bind(x, full_matrices=full_matrices, compute_uv=compute_uv)
+  outs = svd_p.bind(
+      x,
+      full_matrices=full_matrices,
+      compute_uv=compute_uv,
+      subset_by_index=subset_by_index,
+  )
 
   if compute_uv:
     return outs, (0, 0, 0)
   else:
     return outs, (0,)
 
 
 svd_p = Primitive('svd')
 svd_p.multiple_results = True
 svd_p.def_impl(_svd_impl)
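Because the batching rule rebinds the primitive with the same static `subset_by_index`, the option composes with `vmap`. A sketch (my example values; TPU assumed, since that is where the subset path lowers):

    import jax
    import jax.numpy as jnp

    batch = jnp.ones((3, 16, 8))  # three 16x8 matrices
    u, s, vt = jax.vmap(
        lambda a: jnp.linalg.svd(a, full_matrices=False, subset_by_index=(0, 2))
    )(batch)
    print(s.shape)  # (3, 2): two leading singular values per matrix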
@@ -37,21 +37,29 @@ from __future__ import annotations
 
 from collections.abc import Sequence
 import functools
+import operator
 from typing import Any
 
 import jax
-import jax.numpy as jnp
 from jax import lax
 from jax._src import core
+import jax.numpy as jnp
 
 
-@functools.partial(jax.jit, static_argnums=(2, 3))
+@functools.partial(jax.jit, static_argnums=(2, 3, 4))
 def _constant_svd(
-    a: Any, return_nan: bool, full_matrices: bool, compute_uv: bool = True
+    a: Any,
+    return_nan: bool,
+    full_matrices: bool,
+    compute_uv: bool = True,
+    subset_by_index: tuple[int, int] | None = None,
 ) -> Any | Sequence[Any]:
   """SVD on matrix of all zeros."""
   m, n = a.shape
   k = min(m, n)
+  if subset_by_index is not None:
+    k = min(k, subset_by_index[1] - subset_by_index[0])
 
   s = jnp.where(
       return_nan,
       jnp.full(shape=(k,), fill_value=jnp.nan, dtype=a.real.dtype),
@@ -90,9 +98,13 @@ def _constant_svd(
   return s
 
 
-@functools.partial(jax.jit, static_argnums=(1, 2, 3))
+@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
 def _svd_tall_and_square_input(
-    a: Any, hermitian: bool, compute_uv: bool, max_iterations: int
+    a: Any,
+    hermitian: bool,
+    compute_uv: bool,
+    max_iterations: int,
+    subset_by_index: tuple[int, int] | None = None,
 ) -> Any | Sequence[Any]:
   """Singular value decomposition for m x n matrix and m >= n.
 
@@ -113,7 +125,7 @@ def _svd_tall_and_square_input(
       max_iterations=max_iterations)
 
   # TODO: Use `eigvals_only=True` if `compute_uv=False`.
-  v, s = lax.linalg.eigh(h)
+  v, s = lax.linalg.eigh(h, subset_by_index=subset_by_index)
   # Singular values are non-negative by definition. But eigh could return small
   # negative values, so we clamp them to zero.
   s = jnp.maximum(s, 0.0)
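Forwarding `subset_by_index` to `eigh` is sufficient because the QDWH-based SVD reduces to a Hermitian eigenproblem: `qdwh` yields a polar factorization of `a`, and the eigendecomposition of the Hermitian factor supplies the singular values and right vectors. The underlying identity (standard linear algebra, not part of the patch):

    % Polar decomposition followed by a Hermitian eigendecomposition:
    A = U_p H, \qquad H = H^{*} \succeq 0, \qquad H = V \Sigma V^{*}
    \;\Longrightarrow\; A = (U_p V)\, \Sigma\, V^{*}

So the eigenvalues of `H` are exactly the singular values of `A`, and selecting an index range of eigenpairs of `H` selects the corresponding singular triplets of `A`.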
@@ -148,12 +160,15 @@ def _svd_tall_and_square_input(
   return (u_out, s_out, v_out)
 
 
-@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
-def _qdwh_svd(a: Any,
-              full_matrices: bool,
-              compute_uv: bool = True,
-              hermitian: bool = False,
-              max_iterations: int = 10) -> Any | Sequence[Any]:
+@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4, 5))
+def _qdwh_svd(
+    a: Any,
+    full_matrices: bool,
+    compute_uv: bool = True,
+    hermitian: bool = False,
+    max_iterations: int = 10,
+    subset_by_index: tuple[int, int] | None = None,
+) -> Any | Sequence[Any]:
   """Singular value decomposition.
 
   Args:
@@ -196,12 +211,14 @@ def _qdwh_svd(a: Any,
 
   if not compute_uv:
     with jax.default_matmul_precision('float32'):
-      return _svd_tall_and_square_input(a, hermitian, compute_uv,
-                                        max_iterations)
+      return _svd_tall_and_square_input(
+          a, hermitian, compute_uv, max_iterations, subset_by_index
+      )
 
   with jax.default_matmul_precision('float32'):
     u_out, s_out, v_out = _svd_tall_and_square_input(
-        a, hermitian, compute_uv, max_iterations)
+        a, hermitian, compute_uv, max_iterations, subset_by_index
+    )
   if reduce_to_square:
     u_out = q @ u_out
 
@@ -214,12 +231,15 @@ def _qdwh_svd(a: Any,
   return (u_out, s_out, v_out.T.conj())
 
 
-@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
-def svd(a: Any,
-        full_matrices: bool,
-        compute_uv: bool = True,
-        hermitian: bool = False,
-        max_iterations: int = 10) -> Any | Sequence[Any]:
+@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4, 5))
+def svd(
+    a: Any,
+    full_matrices: bool,
+    compute_uv: bool = True,
+    hermitian: bool = False,
+    max_iterations: int = 10,
+    subset_by_index: tuple[int, int] | None = None,
+) -> Any | Sequence[Any]:
   """Singular value decomposition.
 
   Args:
@@ -230,6 +250,10 @@ def svd(a: Any,
     compute_uv: Whether to compute also `u` and `v` in addition to `s`.
     hermitian: True if `a` is Hermitian.
     max_iterations: The predefined maximum number of iterations of QDWH.
+    subset_by_index: Optional 2-tuple [start, end] indicating the range of
+      indices of singular components to compute. For example, if
+      ``subset_by_index`` = [0, 2], then ``svd`` computes the two largest
+      singular values (and their singular vectors if `compute_uv` is true).
 
   Returns:
     A 3-tuple (`u`, `s`, `vh`), where `u` and `vh` are unitary matrices,
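Concretely, per the docstring above, `subset_by_index=(0, 2)` selects the two largest singular values. A cross-check against NumPy (my tolerances and shapes; assumes a backend with the subset path, i.e. TPU at this revision):

    import jax.numpy as jnp
    import numpy as np

    a = np.random.default_rng(0).normal(size=(32, 16)).astype(np.float32)
    u, s, vt = jnp.linalg.svd(a, full_matrices=False, subset_by_index=(0, 2))
    full_s = np.linalg.svd(a, compute_uv=False)  # descending order
    np.testing.assert_allclose(s, full_s[:2], rtol=1e-4, atol=1e-5)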
@@ -247,12 +271,43 @@ def svd(a: Any,
       'specified to use `svd` within JAX transformations.')
 
   hermitian = core.concrete_or_error(
-      bool, hermitian, 'The `hermitian` argument must be statically '
-      'specified to use `qdwh` within JAX transformations.')
+      bool,
+      hermitian,
+      'The `hermitian` argument must be statically '
+      'specified to use `svd` within JAX transformations.',
+  )
 
   max_iterations = core.concrete_or_error(
-      int, max_iterations, 'The `max_iterations` argument must be statically '
-      'specified to use `qdwh` within JAX transformations.')
+      int,
+      max_iterations,
+      'The `max_iterations` argument must be statically '
+      'specified to use `svd` within JAX transformations.',
+  )
+
+  if subset_by_index is not None:
+    if len(subset_by_index) != 2:
+      raise ValueError('subset_by_index must be a tuple of size 2.')
+    # Make sure subset_by_index is a concrete tuple.
+    subset_by_index = (
+        operator.index(subset_by_index[0]),
+        operator.index(subset_by_index[1]),
+    )
+    if subset_by_index[0] >= subset_by_index[1]:
+      raise ValueError('Got empty index range in subset_by_index.')
+    if subset_by_index[0] < 0:
+      raise ValueError('Indices in subset_by_index must be non-negative.')
+    m, n = a.shape
+    rank = n if n < m else m
+    if subset_by_index[1] > rank:
+      raise ValueError('Index subset_by_index[1] exceeds the matrix size.')
+    if full_matrices and subset_by_index != (0, rank):
+      raise ValueError('full_matrices and subset_by_index cannot both be set.')
+    # By convention, eigenvalues are numbered in non-decreasing order, while
+    # singular values are numbered in non-increasing order, so change
+    # subset_by_index accordingly.
+    subset_by_index = (rank - subset_by_index[1], rank - subset_by_index[0])
 
   # QDWH algorithm fails at zero-matrix `A` and produces all NaNs, which can
   # be seen from a dynamically weighted Halley (DWH) iteration:
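The final flip is the subtle step, so a worked instance (numbers mine): for a 256 x 128 input, `rank = 128`; the singular-value range `(0, 2)` (two largest singular values) becomes the eigenvalue range `(126, 128)` (two largest eigenvalues), because `eigh` orders eigenvalues in non-decreasing order while singular values are reported in non-increasing order. A sketch:

    # Illustration of the index flip above; not library code.
    def flip_subset(subset_by_index, rank):
        # singular values: non-increasing; eigenvalues: non-decreasing
        start, end = subset_by_index
        return (rank - end, rank - start)

    assert flip_subset((0, 2), 128) == (126, 128)
    assert flip_subset((126, 128), 128) == (0, 2)  # the two smallest singular values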
@@ -268,6 +323,7 @@ def svd(a: Any,
           return_nan=non_finite,
           full_matrices=full_matrices,
           compute_uv=compute_uv,
+          subset_by_index=subset_by_index,
       ),
       functools.partial(
           _qdwh_svd,
@@ -275,6 +331,7 @@ def svd(a: Any,
           compute_uv=compute_uv,
           hermitian=hermitian,
           max_iterations=max_iterations,
+          subset_by_index=subset_by_index,
       ),
       operand=(a),
   )
@@ -72,30 +72,85 @@ def cholesky(a: ArrayLike, *, upper: bool = False) -> Array:
   L = lax_linalg.cholesky(a)
   return L.mT.conj() if upper else L
 
 
 @overload
-def svd(a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[True],
-        hermitian: bool = False) -> SVDResult: ...
+def svd(
+    a: ArrayLike,
+    full_matrices: bool = True,
+    *,
+    compute_uv: Literal[True],
+    hermitian: bool = False,
+    subset_by_index: tuple[int, int] | None = None,
+) -> SVDResult:
+  ...
+
+
 @overload
-def svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[True],
-        hermitian: bool = False) -> SVDResult: ...
+def svd(
+    a: ArrayLike,
+    full_matrices: bool,
+    compute_uv: Literal[True],
+    hermitian: bool = False,
+    subset_by_index: tuple[int, int] | None = None,
+) -> SVDResult:
+  ...
+
+
 @overload
-def svd(a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[False],
-        hermitian: bool = False) -> Array: ...
+def svd(
+    a: ArrayLike,
+    full_matrices: bool = True,
+    *,
+    compute_uv: Literal[False],
+    hermitian: bool = False,
+    subset_by_index: tuple[int, int] | None = None,
+) -> Array:
+  ...
+
+
 @overload
-def svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[False],
-        hermitian: bool = False) -> Array: ...
+def svd(
+    a: ArrayLike,
+    full_matrices: bool,
+    compute_uv: Literal[False],
+    hermitian: bool = False,
+    subset_by_index: tuple[int, int] | None = None,
+) -> Array:
+  ...
+
+
 @overload
-def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
-        hermitian: bool = False) -> Array | SVDResult: ...
+def svd(
+    a: ArrayLike,
+    full_matrices: bool = True,
+    compute_uv: bool = True,
+    hermitian: bool = False,
+    subset_by_index: tuple[int, int] | None = None,
+) -> Array | SVDResult:
+  ...
 
 
 @implements(np.linalg.svd)
-@partial(jit, static_argnames=('full_matrices', 'compute_uv', 'hermitian'))
-def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
-        hermitian: bool = False) -> Array | SVDResult:
+@partial(
+    jit,
+    static_argnames=(
+        "full_matrices",
+        "compute_uv",
+        "hermitian",
+        "subset_by_index",
+    ),
+)
+def svd(
+    a: ArrayLike,
+    full_matrices: bool = True,
+    compute_uv: bool = True,
+    hermitian: bool = False,
+    subset_by_index: tuple[int, int] | None = None,
+) -> Array | SVDResult:
   check_arraylike("jnp.linalg.svd", a)
   a, = promote_dtypes_inexact(jnp.asarray(a))
   if hermitian:
-    w, v = lax_linalg.eigh(a)
+    w, v = lax_linalg.eigh(a, subset_by_index=subset_by_index)
     s = lax.abs(v)
     if compute_uv:
       sign = lax.sign(v)
@@ -111,10 +166,20 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
     return lax.rev(lax.sort(s, dimension=-1), dimensions=[s.ndim-1])
 
   if compute_uv:
-    u, s, vh = lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=True)
+    u, s, vh = lax_linalg.svd(
+        a,
+        full_matrices=full_matrices,
+        compute_uv=True,
+        subset_by_index=subset_by_index,
+    )
     return SVDResult(u, s, vh)
   else:
-    return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=False)
+    return lax_linalg.svd(
+        a,
+        full_matrices=full_matrices,
+        compute_uv=False,
+        subset_by_index=subset_by_index,
+    )
 
 
 @implements(np.linalg.matrix_power)
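Note that `subset_by_index` joins the `static_argnames`, so like `full_matrices` and `compute_uv` it is baked into the trace, and a new range triggers recompilation. The values-only path, sketched with my example values (again assuming a backend where the subset path is implemented):

    import jax.numpy as jnp

    a = jnp.arange(12.0).reshape(4, 3)
    s_top2 = jnp.linalg.svd(a, compute_uv=False, subset_by_index=(0, 2))
    print(s_top2.shape)  # (2,)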
@@ -3249,7 +3249,18 @@ def _qr(operand, full_matrices):
 tf_impl[lax.linalg.qr_p] = _qr
 
 
-def _svd(operand, full_matrices, compute_uv):
+def _svd(
+    operand: TfVal,
+    full_matrices: bool,
+    compute_uv: bool,
+    subset_by_index: tuple[int, int] | None = None,
+):
+  if not (
+      subset_by_index is None
+      or subset_by_index == (0, min(operand.shape[-1], operand.shape[-2]))
+  ):
+    raise NotImplementedError("subset_by_index is not implemented")
+
   result = tf.linalg.svd(operand, full_matrices, compute_uv)
   if not compute_uv:
     return result,
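The jax2tf bridge simply guards and delegates, since `tf.linalg.svd` has no subset option; only the trivial range passes through. For reference, the TF call it reduces to (my example shapes; note TF returns `s` first and `v` untransposed):

    import tensorflow as tf

    a = tf.ones((4, 3))
    s = tf.linalg.svd(a, full_matrices=False, compute_uv=False)  # shape (3,)
    s, u, v = tf.linalg.svd(a, full_matrices=False)              # v, not v^H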
@@ -252,6 +252,42 @@ class SvdTest(jtu.JaxTestCase):
     self.assertAllClose(expected_s, jnp.real(actual_s), rtol=_SVD_RTOL,
                         atol=1E-6)
 
+  @jtu.sample_product(
+      start=[0, 1, 64, 126, 127],
+      end=[1, 2, 65, 127, 128],
+  )
+  @jtu.run_on_devices('tpu')  # TODO(rmlarsen): enable on other devices.
+  def testSvdSubsetByIndex(self, start, end):
+    if start >= end:
+      return
+    dtype = np.float32
+    m = 256
+    n = 128
+    rng = jtu.rand_default(self.rng())
+    tol = np.maximum(n, 80) * np.finfo(dtype).eps
+    args_maker = lambda: [rng((m, n), dtype)]
+    subset_by_index = (start, end)
+    k = end - start
+    (a,) = args_maker()
+
+    u, s, vt = jnp.linalg.svd(
+        a, full_matrices=False, subset_by_index=subset_by_index
+    )
+    self.assertEqual(u.shape, (m, k))
+    self.assertEqual(s.shape, (k,))
+    self.assertEqual(vt.shape, (k, n))
+
+    with jax.numpy_rank_promotion('allow'):
+      self.assertLessEqual(
+          np.linalg.norm(np.matmul(a, vt.T) - u * s), tol * np.linalg.norm(a)
+      )
+
+    # Test that we get approximately the same singular values when
+    # slicing the full SVD.
+    _, full_s, _ = jnp.linalg.svd(a, full_matrices=False)
+    s_slice = full_s[start:end]
+    self.assertAllClose(s_slice, s, atol=tol, rtol=tol)
+
 
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())