[JAX] Add an option subset_by_index that allows computing a contiguous subset of singular components from svd.

PiperOrigin-RevId: 607493941
jax authors 2024-02-15 16:32:25 -08:00
parent 0203d15485
commit 7e7094c82d
6 changed files with 353 additions and 81 deletions
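Before the per-file diffs, a minimal usage sketch of the new option (an illustration, not part of the commit): subset_by_index is a (start, end) tuple selecting a contiguous range of singular components, with index 0 being the largest singular value. Per the lowering rules below, non-trivial ranges are only implemented on TPU; the CPU/GPU lowering raises NotImplementedError.

import jax.numpy as jnp
import numpy as np

a = jnp.asarray(np.random.randn(256, 128).astype(np.float32))

# Compute only the two largest singular components, i.e. indices [0, 2).
# On CPU/GPU this raises NotImplementedError; the TPU path uses QDWH + eigh.
u, s, vt = jnp.linalg.svd(a, full_matrices=False, subset_by_index=(0, 2))
print(u.shape, s.shape, vt.shape)  # (256, 2) (2,) (2, 128)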

View File

@@ -1800,25 +1800,35 @@ for dtype in jtu.dtypes.all_floating + jtu.dtypes.complex:
for shape in [(2, 2), (2, 7), (29, 29), (2, 3, 53), (2, 3, 29, 7)]:
for full_matrices in [False, True]:
for compute_uv in [False, True]:
subset_by_index = None
define(
lax.linalg.svd_p,
f"shape={jtu.format_shape_dtype_string(shape, dtype)}_fullmatrices={full_matrices}_computeuv={compute_uv}",
lambda *args: lax.linalg.svd_p.bind(
args[0], full_matrices=args[1], compute_uv=args[2]), [
RandArg(shape, dtype),
StaticArg(full_matrices),
StaticArg(compute_uv)
],
args[0],
full_matrices=args[1],
compute_uv=args[2],
subset_by_index=args[3],
),
[
RandArg(shape, dtype),
StaticArg(full_matrices),
StaticArg(compute_uv),
StaticArg(subset_by_index),
],
jax_unimplemented=[
Limitation(
"unimplemented",
devices=("cpu", "gpu"),
dtypes=[np.float16, dtypes.bfloat16]),
dtypes=[np.float16, dtypes.bfloat16],
),
],
shape=shape,
dtype=dtype,
full_matrices=full_matrices,
compute_uv=compute_uv)
compute_uv=compute_uv,
subset_by_index=subset_by_index,
)
for dtype in jtu.dtypes.all_inexact:
for shape in [(0, 0), (5, 5), (2, 6, 6)]:
@@ -2666,7 +2676,6 @@ for dtype in (np.float32, np.float64):
dtype=dtype)
def wrap_and_split():
key = jax.random.key(42)
result = jax.random.split(key, 2)

View File

@@ -298,25 +298,61 @@ def qr(x: ArrayLike, *, full_matrices: bool = True) -> tuple[Array, Array]:
q, r = qr_p.bind(x, full_matrices=full_matrices)
return q, r
@overload
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: Literal[True]) -> tuple[Array, Array, Array]: ...
@overload
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: Literal[False]) -> Array: ...
def svd(
x: ArrayLike,
*,
full_matrices: bool = True,
compute_uv: Literal[True],
subset_by_index: tuple[int, int] | None = None,
) -> tuple[Array, Array, Array]:
...
@overload
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: bool = True) -> Array | tuple[Array, Array, Array]: ...
def svd(
x: ArrayLike,
*,
full_matrices: bool = True,
compute_uv: Literal[False],
subset_by_index: tuple[int, int] | None = None,
) -> Array:
...
@overload
def svd(
x: ArrayLike,
*,
full_matrices: bool = True,
compute_uv: bool = True,
subset_by_index: tuple[int, int] | None = None,
) -> Array | tuple[Array, Array, Array]:
...
# TODO: Add `max_qdwh_iterations` to the function signature for TPU SVD.
@_warn_on_positional_kwargs
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: bool = True) -> Array | tuple[Array, Array, Array]:
def svd(
x: ArrayLike,
*,
full_matrices: bool = True,
compute_uv: bool = True,
subset_by_index: tuple[int, int] | None = None,
) -> Array | tuple[Array, Array, Array]:
"""Singular value decomposition.
Returns the singular values if compute_uv is False, otherwise returns a triple
containing the left singular vectors, the singular values, and the adjoint of
the right singular vectors. If subset_by_index is given, only the contiguous
range of singular components with indices [start, end) is computed, counting
from the largest singular value.
"""
result = svd_p.bind(x, full_matrices=full_matrices, compute_uv=compute_uv)
result = svd_p.bind(
x,
full_matrices=full_matrices,
compute_uv=compute_uv,
subset_by_index=subset_by_index,
)
if compute_uv:
s, u, v = result
return u, s, v
@@ -324,6 +360,7 @@ def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: bool = True) ->
s, = result
return s
@_warn_on_positional_kwargs
def triangular_solve(a: ArrayLike, b: ArrayLike, *,
left_side: bool = False, lower: bool = False,
@@ -1043,7 +1080,6 @@ mlir.register_lowering(triangular_solve_p, _triangular_solve_cpu_lower,
# Support operation for LU decomposition: Transformation of the pivots returned
# by LU decomposition into permutations.
# Define this outside lu_pivots_to_permutation to ensure fori_loop cache hits
def _lu_pivots_body_fn(i, permutation_and_swaps):
permutation, swaps = permutation_and_swaps
@@ -1138,7 +1174,6 @@ mlir.register_lowering(
gpu_linalg.hip_lu_pivots_to_permutation),
platform='rocm')
# LU decomposition
# Computes a pivoted LU decomposition such that
@@ -1745,35 +1780,50 @@ mlir.register_lowering(qr_p, mlir.lower_fun(_qr_lowering))
# Singular value decomposition
def _svd_impl(operand, *, full_matrices, compute_uv, subset_by_index=None):
return dispatch.apply_primitive(
svd_p,
operand,
full_matrices=full_matrices,
compute_uv=compute_uv,
subset_by_index=subset_by_index,
)
def _svd_impl(operand, *, full_matrices, compute_uv):
return dispatch.apply_primitive(svd_p, operand, full_matrices=full_matrices,
compute_uv=compute_uv)
def _svd_abstract_eval(operand, *, full_matrices, compute_uv):
def _svd_abstract_eval(operand, *, full_matrices, compute_uv, subset_by_index):
if isinstance(operand, ShapedArray):
if operand.ndim < 2:
raise ValueError("Argument to singular value decomposition must have ndims >= 2")
batch_dims = operand.shape[:-2]
m = operand.shape[-2]
n = operand.shape[-1]
s = operand.update(shape=batch_dims + (min(m, n),),
dtype=lax_internal._complex_basetype(operand.dtype))
rank = min(m, n)
if subset_by_index is not None:
if full_matrices and subset_by_index != (0, rank):
raise ValueError("full_matrices and subset_by_index cannot both be set")
rank = min(rank, subset_by_index[1] - subset_by_index[0])
s = operand.update(
shape=batch_dims + (rank,),
dtype=lax_internal._complex_basetype(operand.dtype),
)
if compute_uv:
u = operand.update(shape=batch_dims + (m, m if full_matrices else min(m, n)))
vt = operand.update(shape=batch_dims + (n if full_matrices else min(m, n), n))
u = operand.update(shape=batch_dims + (m, m if full_matrices else rank))
vt = operand.update(shape=batch_dims + (n if full_matrices else rank, n))
return s, u, vt
else:
return s,
else:
raise NotImplementedError
@jax.default_matmul_precision("float32")
def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv):
def _svd_jvp_rule(
primals, tangents, *, full_matrices, compute_uv, subset_by_index
):
A, = primals
dA, = tangents
s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)
s, U, Vt = svd_p.bind(
A, full_matrices=False, compute_uv=True, subset_by_index=subset_by_index
)
if compute_uv and full_matrices:
# TODO: implement full matrices case, documented here: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
@@ -1812,6 +1862,7 @@ def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv):
return (s, U, Vt), (ds, dU, _H(dV))
def _empty_svd(a, *, full_matrices, compute_uv):
batch_shape = a.shape[:-2]
m, n = a.shape[-2:]
@@ -1828,8 +1879,17 @@ def _empty_svd(a, *, full_matrices, compute_uv):
u, v = v, u
return s, u, v
def _svd_cpu_gpu_lowering(gesvd_impl, ctx, operand, *, full_matrices,
compute_uv, platform: str):
def _svd_cpu_gpu_lowering(
gesvd_impl,
ctx,
operand,
*,
full_matrices,
compute_uv,
subset_by_index,
platform: str,
):
operand_aval, = ctx.avals_in
s_aval = ctx.avals_out[0]
m, n = operand_aval.shape[-2:]
@@ -1841,9 +1901,16 @@ def _svd_cpu_gpu_lowering(gesvd_impl, ctx, operand, *, full_matrices,
f"implemented only for the batch dimensions: {operand_aval.shape}")
batch_dims = operand_aval.shape[:-2]
if not (subset_by_index is None or subset_by_index == (0, min(m, n))):
raise NotImplementedError("subset_by_index not implemented for CPU and GPU")
if m == 0 or n == 0:
return mlir.lower_fun(_empty_svd, multiple_results=True)(
ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv)
ctx,
operand,
full_matrices=full_matrices,
compute_uv=compute_uv,
)
if platform in ["cuda", "rocm"]:
if not is_constant_shape(operand_aval.shape):
@@ -1891,10 +1958,16 @@ def _svd_cpu_gpu_lowering(gesvd_impl, ctx, operand, *, full_matrices,
return result
def _svd_tpu(a, *, full_matrices, compute_uv):
def _svd_tpu(a, *, full_matrices, compute_uv, subset_by_index):
batch_dims = a.shape[:-2]
fn = partial(lax_svd.svd, full_matrices=full_matrices, compute_uv=compute_uv)
fn = partial(
lax_svd.svd,
full_matrices=full_matrices,
compute_uv=compute_uv,
subset_by_index=subset_by_index,
)
for _ in range(len(batch_dims)):
fn = api.vmap(fn)
@@ -1905,28 +1978,49 @@ def _svd_tpu(a, *, full_matrices, compute_uv):
s = fn(a)
return [s]
def _svd_tpu_lowering_rule(ctx, operand, *, full_matrices, compute_uv):
def _svd_tpu_lowering_rule(
ctx, operand, *, full_matrices, compute_uv, subset_by_index
):
operand_aval, = ctx.avals_in
m, n = operand_aval.shape[-2:]
if m == 0 or n == 0:
return mlir.lower_fun(_empty_svd, multiple_results=True)(
ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv)
ctx,
operand,
full_matrices=full_matrices,
compute_uv=compute_uv,
)
return mlir.lower_fun(_svd_tpu, multiple_results=True)(
ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv)
ctx,
operand,
full_matrices=full_matrices,
compute_uv=compute_uv,
subset_by_index=subset_by_index,
)
def _svd_batching_rule(batched_args, batch_dims, *, full_matrices, compute_uv):
def _svd_batching_rule(
batched_args, batch_dims, *, full_matrices, compute_uv, subset_by_index
):
x, = batched_args
bd, = batch_dims
x = batching.moveaxis(x, bd, 0)
outs = svd_p.bind(x, full_matrices=full_matrices, compute_uv=compute_uv)
outs = svd_p.bind(
x,
full_matrices=full_matrices,
compute_uv=compute_uv,
subset_by_index=subset_by_index,
)
if compute_uv:
return outs, (0, 0, 0)
else:
return outs, (0,)
svd_p = Primitive('svd')
svd_p.multiple_results = True
svd_p.def_impl(_svd_impl)
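To make the new shape rule in _svd_abstract_eval concrete, here is a small plain-Python sketch mirroring the logic above (values chosen to match the test added in this commit):

m, n = 256, 128
subset_by_index = (0, 2)
rank = min(m, n)  # 128 when no subset is requested
if subset_by_index is not None:
    rank = min(rank, subset_by_index[1] - subset_by_index[0])  # 2
# With full_matrices=False the outputs then have shapes
#   s: batch + (2,), u: batch + (256, 2), vt: batch + (2, 128)
assert rank == 2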

View File

@@ -37,21 +37,29 @@ from __future__ import annotations
from collections.abc import Sequence
import functools
import operator
from typing import Any
import jax
import jax.numpy as jnp
from jax import lax
from jax._src import core
import jax.numpy as jnp
@functools.partial(jax.jit, static_argnums=(2, 3))
@functools.partial(jax.jit, static_argnums=(2, 3, 4))
def _constant_svd(
a: Any, return_nan: bool, full_matrices: bool, compute_uv: bool = True
a: Any,
return_nan: bool,
full_matrices: bool,
compute_uv: bool = True,
subset_by_index: tuple[int, int] | None = None,
) -> Any | Sequence[Any]:
"""SVD on matrix of all zeros."""
m, n = a.shape
k = min(m, n)
if subset_by_index is not None:
k = min(k, subset_by_index[1] - subset_by_index[0])
s = jnp.where(
return_nan,
jnp.full(shape=(k,), fill_value=jnp.nan, dtype=a.real.dtype),
@@ -90,9 +98,13 @@ def _constant_svd(
return s
@functools.partial(jax.jit, static_argnums=(1, 2, 3))
@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
def _svd_tall_and_square_input(
a: Any, hermitian: bool, compute_uv: bool, max_iterations: int
a: Any,
hermitian: bool,
compute_uv: bool,
max_iterations: int,
subset_by_index: tuple[int, int] | None = None,
) -> Any | Sequence[Any]:
"""Singular value decomposition for m x n matrix and m >= n.
@@ -113,7 +125,7 @@ def _svd_tall_and_square_input(
max_iterations=max_iterations)
# TODO: Use `eigvals_only=True` if `compute_uv=False`.
v, s = lax.linalg.eigh(h)
v, s = lax.linalg.eigh(h, subset_by_index=subset_by_index)
# Singular values are non-negative by definition. But eigh could return small
# negative values, so we clamp them to zero.
s = jnp.maximum(s, 0.0)
@@ -148,12 +160,15 @@ def _svd_tall_and_square_input(
return (u_out, s_out, v_out)
@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
def _qdwh_svd(a: Any,
full_matrices: bool,
compute_uv: bool = True,
hermitian: bool = False,
max_iterations: int = 10) -> Any | Sequence[Any]:
@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4, 5))
def _qdwh_svd(
a: Any,
full_matrices: bool,
compute_uv: bool = True,
hermitian: bool = False,
max_iterations: int = 10,
subset_by_index: tuple[int, int] | None = None,
) -> Any | Sequence[Any]:
"""Singular value decomposition.
Args:
@@ -196,12 +211,14 @@ def _qdwh_svd(a: Any,
if not compute_uv:
with jax.default_matmul_precision('float32'):
return _svd_tall_and_square_input(a, hermitian, compute_uv,
max_iterations)
return _svd_tall_and_square_input(
a, hermitian, compute_uv, max_iterations, subset_by_index
)
with jax.default_matmul_precision('float32'):
u_out, s_out, v_out = _svd_tall_and_square_input(
a, hermitian, compute_uv, max_iterations)
a, hermitian, compute_uv, max_iterations, subset_by_index
)
if reduce_to_square:
u_out = q @ u_out
@@ -214,12 +231,15 @@ def _qdwh_svd(a: Any,
return (u_out, s_out, v_out.T.conj())
@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
def svd(a: Any,
full_matrices: bool,
compute_uv: bool = True,
hermitian: bool = False,
max_iterations: int = 10) -> Any | Sequence[Any]:
@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4, 5))
def svd(
a: Any,
full_matrices: bool,
compute_uv: bool = True,
hermitian: bool = False,
max_iterations: int = 10,
subset_by_index: tuple[int, int] | None = None,
) -> Any | Sequence[Any]:
"""Singular value decomposition.
Args:
@@ -230,6 +250,10 @@ def svd(a: Any,
compute_uv: Whether to compute also `u` and `v` in addition to `s`.
hermitian: True if `a` is Hermitian.
max_iterations: The predefined maximum number of iterations of QDWH.
subset_by_index: Optional 2-tuple [start, end] indicating the range of
indices of singular components to compute. For example, if
``subset_by_index`` = [0,2], then ``svd`` computes the two largest
singular values (and their singular vectors if `compute_uv` is true).
Returns:
A 3-tuple (`u`, `s`, `vh`), where `u` and `vh` are unitary matrices,
@@ -247,12 +271,43 @@ def svd(a: Any,
'specified to use `svd` within JAX transformations.')
hermitian = core.concrete_or_error(
bool, hermitian, 'The `hermitian` argument must be statically '
'specified to use `qdwh` within JAX transformations.')
bool,
hermitian,
'The `hermitian` argument must be statically '
'specified to use `svd` within JAX transformations.',
)
max_iterations = core.concrete_or_error(
int, max_iterations, 'The `max_iterations` argument must be statically '
'specified to use `qdwh` within JAX transformations.')
int,
max_iterations,
'The `max_iterations` argument must be statically '
'specified to use `svd` within JAX transformations.',
)
if subset_by_index is not None:
if len(subset_by_index) != 2:
raise ValueError('subset_by_index must be a tuple of size 2.')
# Make sure subset_by_index is a concrete tuple.
subset_by_index = (
operator.index(subset_by_index[0]),
operator.index(subset_by_index[1]),
)
if subset_by_index[0] >= subset_by_index[1]:
raise ValueError('Got empty index range in subset_by_index.')
if subset_by_index[0] < 0:
raise ValueError('Indices in subset_by_index must be non-negative.')
m, n = a.shape
rank = n if n < m else m
if subset_by_index[1] > rank:
raise ValueError('Index in subset_by_index[1] exceeds matrix size.')
if full_matrices and subset_by_index != (0, rank):
raise ValueError(
'full_matrices and subset_by_index cannot both be set.'
)
# By convention, eigenvalues are numbered in non-decreasing order, while
# singular values are numbered in non-increasing order, so change
# subset_by_index accordingly.
subset_by_index = (rank - subset_by_index[1], rank - subset_by_index[0])
# QDWH algorithm fails at zero-matrix `A` and produces all NaNs, which can
# be seen from a dynamically weighted Halley (DWH) iteration:
@@ -268,6 +323,7 @@ def svd(a: Any,
return_nan=non_finite,
full_matrices=full_matrices,
compute_uv=compute_uv,
subset_by_index=subset_by_index,
),
functools.partial(
_qdwh_svd,
@@ -275,6 +331,7 @@ def svd(a: Any,
compute_uv=compute_uv,
hermitian=hermitian,
max_iterations=max_iterations,
subset_by_index=subset_by_index,
),
operand=(a),
)

View File

@@ -72,30 +72,85 @@ def cholesky(a: ArrayLike, *, upper: bool = False) -> Array:
L = lax_linalg.cholesky(a)
return L.mT.conj() if upper else L
@overload
def svd(a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[True],
hermitian: bool = False) -> SVDResult: ...
def svd(
a: ArrayLike,
full_matrices: bool = True,
*,
compute_uv: Literal[True],
hermitian: bool = False,
subset_by_index: tuple[int, int] | None = None,
) -> SVDResult:
...
@overload
def svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[True],
hermitian: bool = False) -> SVDResult: ...
def svd(
a: ArrayLike,
full_matrices: bool,
compute_uv: Literal[True],
hermitian: bool = False,
subset_by_index: tuple[int, int] | None = None,
) -> SVDResult:
...
@overload
def svd(a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[False],
hermitian: bool = False) -> Array: ...
def svd(
a: ArrayLike,
full_matrices: bool = True,
*,
compute_uv: Literal[False],
hermitian: bool = False,
subset_by_index: tuple[int, int] | None = None,
) -> Array:
...
@overload
def svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[False],
hermitian: bool = False) -> Array: ...
def svd(
a: ArrayLike,
full_matrices: bool,
compute_uv: Literal[False],
hermitian: bool = False,
subset_by_index: tuple[int, int] | None = None,
) -> Array:
...
@overload
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
hermitian: bool = False) -> Array | SVDResult: ...
def svd(
a: ArrayLike,
full_matrices: bool = True,
compute_uv: bool = True,
hermitian: bool = False,
subset_by_index: tuple[int, int] | None = None,
) -> Array | SVDResult:
...
@implements(np.linalg.svd)
@partial(jit, static_argnames=('full_matrices', 'compute_uv', 'hermitian'))
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
hermitian: bool = False) -> Array | SVDResult:
@partial(
jit,
static_argnames=(
"full_matrices",
"compute_uv",
"hermitian",
"subset_by_index",
),
)
def svd(
a: ArrayLike,
full_matrices: bool = True,
compute_uv: bool = True,
hermitian: bool = False,
subset_by_index: tuple[int, int] | None = None,
) -> Array | SVDResult:
check_arraylike("jnp.linalg.svd", a)
a, = promote_dtypes_inexact(jnp.asarray(a))
if hermitian:
w, v = lax_linalg.eigh(a)
w, v = lax_linalg.eigh(a, subset_by_index=subset_by_index)
s = lax.abs(v)
if compute_uv:
sign = lax.sign(v)
@@ -111,10 +166,20 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
return lax.rev(lax.sort(s, dimension=-1), dimensions=[s.ndim-1])
if compute_uv:
u, s, vh = lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=True)
u, s, vh = lax_linalg.svd(
a,
full_matrices=full_matrices,
compute_uv=True,
subset_by_index=subset_by_index,
)
return SVDResult(u, s, vh)
else:
return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=False)
return lax_linalg.svd(
a,
full_matrices=full_matrices,
compute_uv=False,
subset_by_index=subset_by_index,
)
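A short sketch of the Hermitian fast path above, assuming a backend where eigh supports the requested range. Note what the code actually does: the range is forwarded to eigh unchanged (so it indexes eigh's non-decreasing eigenvalue ordering), and the absolute values of the selected eigenvalues are returned sorted in descending order.

import jax.numpy as jnp
import numpy as np

h = np.random.randn(64, 64).astype(np.float32)
h = jnp.asarray(h + h.T)  # symmetric, so hermitian=True applies

# Singular values only: abs of the eigenvalues selected by eigh,
# sorted in descending order.
s = jnp.linalg.svd(h, compute_uv=False, hermitian=True, subset_by_index=(0, 3))
print(s.shape)  # (3,)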
@implements(np.linalg.matrix_power)

View File

@@ -3249,7 +3249,18 @@ def _qr(operand, full_matrices):
tf_impl[lax.linalg.qr_p] = _qr
def _svd(operand, full_matrices, compute_uv):
def _svd(
operand: TfVal,
full_matrices: bool,
compute_uv: bool,
subset_by_index: tuple[int, int] | None = None,
):
if not (
subset_by_index is None
or subset_by_index == (0, min(operand.shape[-1], operand.shape[-2]))
):
raise NotImplementedError("subset_by_index is not implemented")
result = tf.linalg.svd(operand, full_matrices, compute_uv)
if not compute_uv:
return result,
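For context on the guard above: tf.linalg.svd has no partial-spectrum option and always returns all min(m, n) singular values, so the jax2tf shim can only honor the trivial full range. A small sketch of the call it wraps:

import numpy as np
import tensorflow as tf

x = tf.constant(np.random.randn(5, 3).astype(np.float32))
# Equivalent to subset_by_index=(0, 3) here, the only range the shim accepts.
s, u, v = tf.linalg.svd(x, full_matrices=False, compute_uv=True)
assert s.shape == (3,)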

View File

@@ -252,6 +252,42 @@ class SvdTest(jtu.JaxTestCase):
self.assertAllClose(expected_s, jnp.real(actual_s), rtol=_SVD_RTOL,
atol=1E-6)
@jtu.sample_product(
start=[0, 1, 64, 126, 127],
end=[1, 2, 65, 127, 128],
)
@jtu.run_on_devices('tpu')  # TODO(rmlarsen): enable on other devices
def testSvdSubsetByIndex(self, start, end):
if start >= end:
return
dtype = np.float32
m = 256
n = 128
rng = jtu.rand_default(self.rng())
tol = np.maximum(n, 80) * np.finfo(dtype).eps
args_maker = lambda: [rng((m, n), dtype)]
subset_by_index = (start, end)
k = end - start
(a,) = args_maker()
u, s, vt = jnp.linalg.svd(
a, full_matrices=False, subset_by_index=subset_by_index
)
self.assertEqual(u.shape, (m, k))
self.assertEqual(s.shape, (k,))
self.assertEqual(vt.shape, (k, n))
with jax.numpy_rank_promotion('allow'):
self.assertLessEqual(
np.linalg.norm(np.matmul(a, vt.T) - u * s), tol * np.linalg.norm(a)
)
# Test that we get approximately the same singular values when
# slicing the full SVD.
_, full_s, _ = jnp.linalg.svd(a, full_matrices=False)
s_slice = full_s[start:end]
self.assertAllClose(s_slice, s, atol=tol, rtol=tol)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())