[linalg] Add tpu svd lowering rule.

PiperOrigin-RevId: 445533767
This commit is contained in:
Tianjian Lu 2022-04-29 16:42:08 -07:00 committed by jax authors
parent b90df4bf4d
commit 020849076c
6 changed files with 56 additions and 59 deletions

View File

@ -13,6 +13,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.3.7...main).
* Changes
* {func}`jax.numpy.linalg.svd` on TPUs uses a qdwh-svd solver.
* {func}`jax.numpy.linalg.cond` on TPUs now accepts complex input.
* {func}`jax.numpy.linalg.pinv` on TPUs now accepts complex input.
* {func}`jax.numpy.linalg.matrix_rank` on TPUs now accepts complex input.
* {func}`jax.scipy.cluster.vq.vq` has been added.
* `jax.experimental.maps.mesh` has been deleted.
Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh

View File

@ -34,6 +34,7 @@ from jax._src.lax.lax import (
standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex,
_input_dtype)
from jax._src.lax import lax as lax_internal
from jax._src.lax import svd as lax_svd
from jax._src.lib import lapack
from jax._src.lib import cuda_linalg
@ -202,6 +203,7 @@ def qr(x, full_matrices: bool = True):
q, r = qr_p.bind(x, full_matrices=full_matrices)
return q, r
# TODO: Add `max_qdwh_iterations` to the function signature for TPU SVD.
def svd(x, full_matrices=True, compute_uv=True):
"""Singular value decomposition.
@ -1295,32 +1297,6 @@ def _eye_like_xla(c, aval):
xops.Iota(c, iota_shape, len(aval.shape) - 2))
return xops.ConvertElementType(x, xla.dtype_to_primitive_type(aval.dtype))
def _svd_translation_rule(ctx, avals_in, avals_out, operand, *, full_matrices,
compute_uv):
operand_aval, = avals_in
shape = operand_aval.shape
m, n = shape[-2:]
if m == 0 or n == 0:
out = [_zeros_like_xla(ctx.builder, avals_out[0])]
if compute_uv:
out.append(_eye_like_xla(ctx.builder, avals_out[1]))
out.append(_eye_like_xla(ctx.builder, avals_out[2]))
return out
u, s, v = xops.SVD(operand)
permutation = list(range(len(shape)))
permutation[-1], permutation[-2] = permutation[-2], permutation[-1]
vt = xops.Transpose(v, permutation)
if not full_matrices and m != n:
u = xops.SliceInDim(u, 0, min(m, n), stride=1, dimno=len(shape) - 1)
vt = xops.SliceInDim(vt, 0, min(m, n), stride=1, dimno=len(shape) - 2)
if not compute_uv:
return [s]
else:
return [s, u, vt]
def svd_abstract_eval(operand, full_matrices, compute_uv):
if isinstance(operand, ShapedArray):
if operand.ndim < 2:
@ -1440,6 +1416,31 @@ def _svd_cpu_gpu_lowering(gesvd_impl, ctx, operand, *, full_matrices,
return result
def _svd_tpu(a, *, full_matrices, compute_uv):
batch_dims = a.shape[:-2]
fn = partial(lax_svd.svd, full_matrices=full_matrices, compute_uv=compute_uv)
for _ in range(len(batch_dims)):
fn = api.vmap(fn)
if compute_uv:
u, s, vh = fn(a)
return [s, u, vh]
else:
s = fn(a)
return [s]
def _svd_tpu_lowering_rule(ctx, operand, *, full_matrices, compute_uv):
operand_aval, = ctx.avals_in
m, n = operand_aval.shape[-2:]
if m == 0 or n == 0:
return mlir.lower_fun(_empty_svd, multiple_results=True)(
ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv)
return mlir.lower_fun(_svd_tpu, multiple_results=True)(
ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv)
def svd_batching_rule(batched_args, batch_dims, full_matrices, compute_uv):
x, = batched_args
bd, = batch_dims
@ -1457,7 +1458,6 @@ svd_p.def_impl(svd_impl)
svd_p.def_abstract_eval(svd_abstract_eval)
ad.primitive_jvps[svd_p] = svd_jvp_rule
batching.primitive_batchers[svd_p] = svd_batching_rule
xla.register_translation(svd_p, _svd_translation_rule)
mlir.register_lowering(
svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_mhlo),
@ -1468,6 +1468,7 @@ if solver_apis is not None:
svd_p, partial(_svd_cpu_gpu_lowering, solver_apis.gesvd_mhlo),
platform='gpu')
mlir.register_lowering(svd_p, _svd_tpu_lowering_rule)
def _tridiagonal_solve_gpu_lowering(ctx, dl, d, du, b, *, m, n, ldb, t):
return [sparse_apis.gtsv2_mhlo(dl, d, du, b, m=m, n=n, ldb=ldb, t=t)]

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License
"""A JIT-compatible library for QDWH-based SVD decomposition.
"""A JIT-compatible library for QDWH-based singular value decomposition.
QDWH is short for QR-based dynamically weighted Halley iteration. The Halley
iteration implemented through QR decmopositions is numerically stable and does
@ -35,8 +35,7 @@ https://epubs.siam.org/doi/abs/10.1137/090774999
import functools
from typing import Sequence, Union
from typing import Any, Sequence, Union
import jax
from jax import core
from jax import lax
@ -44,10 +43,10 @@ import jax.numpy as jnp
@functools.partial(jax.jit, static_argnums=(1, 2, 3))
def _svd(a: jnp.ndarray,
def _svd(a: Any,
hermitian: bool,
compute_uv: bool,
max_iterations: int) -> Union[jnp.ndarray, Sequence[jnp.ndarray]]:
max_iterations: int) -> Union[Any, Sequence[Any]]:
"""Singular value decomposition for m x n matrix and m >= n.
Args:
@ -99,11 +98,11 @@ def _svd(a: jnp.ndarray,
@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
def svd(a: jnp.ndarray,
def svd(a: Any,
full_matrices: bool,
compute_uv: bool = True,
hermitian: bool = False,
max_iterations: int = 10) -> Union[jnp.ndarray, Sequence[jnp.ndarray]]:
max_iterations: int = 10) -> Union[Any, Sequence[Any]]:
"""Singular value decomposition.
Args:

View File

@ -1071,9 +1071,9 @@ class Jax2TfLimitation(primitive_harness.Limitation):
# than O(\sqrt(eps)), which is likely a property of the SVD algorithms
# in question; revisit with better understanding of the SVD algorithms.
if x.dtype in [np.float32, np.complex64]:
slack_factor = 1E4
slack_factor = 2E4
elif x.dtype in [np.float64, np.complex128]:
slack_factor = 1E9
slack_factor = 2E9
np.testing.assert_array_less(angular_diff,
slack_factor * error_bound)
@ -1131,17 +1131,18 @@ class Jax2TfLimitation(primitive_harness.Limitation):
dtypes=[np.complex64, np.complex128],
devices=("cpu", "gpu"),
modes=("compiled",)),
Jax2TfLimitation(
"Large numerical discrepancy",
dtypes=[np.float16],
devices=("tpu"),
modes=("eager", "graph", "compiled"),
skip_comparison=True),
missing_tf_kernel(dtypes=[dtypes.bfloat16], devices="tpu"),
custom_numeric(
tol=1e-4,
dtypes=[np.float32, np.complex64],
devices=("cpu", "gpu"),
modes=("eager", "graph", "compiled")),
custom_numeric(
tol=1e-2,
dtypes=[np.float16],
devices=("tpu"),
modes=("eager", "graph", "compiled")),
# TODO: this is very low tolerance for f64
custom_numeric(
tol=1e-4,
@ -1149,11 +1150,20 @@ class Jax2TfLimitation(primitive_harness.Limitation):
devices=("cpu", "gpu"),
modes=("eager", "graph", "compiled")),
custom_numeric(
tol=1e-4,
description="custom numeric comparison when compute_uv on CPU/GPU",
custom_assert=custom_assert,
devices=("cpu", "gpu"),
modes=("eager", "graph", "compiled"),
enabled=(compute_uv == True)),
custom_numeric(
tol=1e-2,
description="custom numeric comparison when compute_uv on TPU",
dtypes=[np.float32, np.float64, np.complex64, np.complex128],
custom_assert=custom_assert,
devices=("tpu"),
modes=("eager", "graph", "compiled"),
enabled=(compute_uv == True)),
]
@classmethod

View File

@ -544,10 +544,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for hermitian in ([False, True] if m == n else [False])))
@jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1
def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, hermitian):
# TODO: enable after linking lax.svd to lax.linalg.svd
if (jnp.issubdtype(dtype, np.complexfloating) and
jtu.device_under_test() == "tpu"):
raise unittest.SkipTest("No complex SVD implementation")
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(b + (m, n), dtype)]
@ -744,11 +740,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for dtype in float_types + complex_types))
@jtu.skip_on_devices("gpu") # TODO(#2203): numerical errors
def testCond(self, shape, pnorm, dtype):
# TODO: enable after linking lax.svd to lax.linalg.svd
if (jnp.issubdtype(dtype, np.complexfloating) and
jtu.device_under_test() == "tpu"):
raise unittest.SkipTest("No complex SVD implementation")
def gen_mat():
# arr_gen = jtu.rand_some_nan(self.rng())
arr_gen = jtu.rand_default(self.rng())
@ -855,10 +846,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for dtype in float_types + complex_types))
@jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1
def testPinv(self, shape, dtype):
# TODO: enable after linking lax.svd to lax.linalg.svd
if (jnp.issubdtype(dtype, np.complexfloating) and
jtu.device_under_test() == "tpu"):
raise unittest.SkipTest("No complex SVD implementation")
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
@ -910,10 +897,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for dtype in float_types + complex_types))
@jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1
def testMatrixRank(self, shape, dtype):
# TODO: enable after linking lax.svd to lax.linalg.svd
if (jnp.issubdtype(dtype, np.complexfloating) and
jtu.device_under_test() == "tpu"):
raise unittest.SkipTest("No complex SVD implementation")
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
a, = args_maker()

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License
"""Tests for the library of QDWH-based SVD decomposition."""
"""Tests for the library of QDWH-based singular value decomposition."""
import functools
import jax