[linalg] Add tpu svd lowering rule.

PiperOrigin-RevId: 445533767
This commit is contained in:
Tianjian Lu 2022-04-29 16:42:08 -07:00 committed by jax authors
parent b90df4bf4d
commit 020849076c
6 changed files with 56 additions and 59 deletions

View File

@ -13,6 +13,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* [GitHub * [GitHub
commits](https://github.com/google/jax/compare/jax-v0.3.7...main). commits](https://github.com/google/jax/compare/jax-v0.3.7...main).
* Changes * Changes
* {func}`jax.numpy.linalg.svd` on TPUs uses a qdwh-svd solver.
* {func}`jax.numpy.linalg.cond` on TPUs now accepts complex input.
* {func}`jax.numpy.linalg.pinv` on TPUs now accepts complex input.
* {func}`jax.numpy.linalg.matrix_rank` on TPUs now accepts complex input.
* {func}`jax.scipy.cluster.vq.vq` has been added. * {func}`jax.scipy.cluster.vq.vq` has been added.
* `jax.experimental.maps.mesh` has been deleted. * `jax.experimental.maps.mesh` has been deleted.
Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh

View File

@ -34,6 +34,7 @@ from jax._src.lax.lax import (
standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex, standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex,
_input_dtype) _input_dtype)
from jax._src.lax import lax as lax_internal from jax._src.lax import lax as lax_internal
from jax._src.lax import svd as lax_svd
from jax._src.lib import lapack from jax._src.lib import lapack
from jax._src.lib import cuda_linalg from jax._src.lib import cuda_linalg
@ -202,6 +203,7 @@ def qr(x, full_matrices: bool = True):
q, r = qr_p.bind(x, full_matrices=full_matrices) q, r = qr_p.bind(x, full_matrices=full_matrices)
return q, r return q, r
# TODO: Add `max_qdwh_iterations` to the function signature for TPU SVD.
def svd(x, full_matrices=True, compute_uv=True): def svd(x, full_matrices=True, compute_uv=True):
"""Singular value decomposition. """Singular value decomposition.
@ -1295,32 +1297,6 @@ def _eye_like_xla(c, aval):
xops.Iota(c, iota_shape, len(aval.shape) - 2)) xops.Iota(c, iota_shape, len(aval.shape) - 2))
return xops.ConvertElementType(x, xla.dtype_to_primitive_type(aval.dtype)) return xops.ConvertElementType(x, xla.dtype_to_primitive_type(aval.dtype))
def _svd_translation_rule(ctx, avals_in, avals_out, operand, *, full_matrices,
                          compute_uv):
  """XLA translation rule for the SVD primitive.

  Args:
    ctx: translation context; only `ctx.builder` is used here.
    avals_in: one-element sequence holding the abstract value of `operand`.
    avals_out: abstract values of the outputs, indexed as (s, u, vt).
    operand: the XLA op to decompose.
    full_matrices: when false and the trailing two dimensions differ, `u` and
      `vt` are sliced down to `min(m, n)` columns / rows respectively.
    compute_uv: when false only the singular values are returned.

  Returns:
    `[s]` if `compute_uv` is false, otherwise `[s, u, vt]`.
  """
  (operand_aval,) = avals_in
  dims = operand_aval.shape
  rank = len(dims)
  m, n = dims[-2:]

  # Degenerate case: at least one of the trailing matrix dimensions is empty,
  # so the outputs are built directly without calling xops.SVD.
  if m == 0 or n == 0:
    results = [_zeros_like_xla(ctx.builder, avals_out[0])]
    if compute_uv:
      results.append(_eye_like_xla(ctx.builder, avals_out[1]))
      results.append(_eye_like_xla(ctx.builder, avals_out[2]))
    return results

  u, s, v = xops.SVD(operand)

  # Identity permutation with the last two axes exchanged, i.e. a transpose
  # of the trailing matrix dimensions of `v`.
  perm = list(range(rank))
  perm[-2], perm[-1] = perm[-1], perm[-2]
  vt = xops.Transpose(v, perm)

  if not full_matrices and m != n:
    u = xops.SliceInDim(u, 0, min(m, n), stride=1, dimno=rank - 1)
    vt = xops.SliceInDim(vt, 0, min(m, n), stride=1, dimno=rank - 2)

  if compute_uv:
    return [s, u, vt]
  return [s]
def svd_abstract_eval(operand, full_matrices, compute_uv): def svd_abstract_eval(operand, full_matrices, compute_uv):
if isinstance(operand, ShapedArray): if isinstance(operand, ShapedArray):
if operand.ndim < 2: if operand.ndim < 2:
@ -1440,6 +1416,31 @@ def _svd_cpu_gpu_lowering(gesvd_impl, ctx, operand, *, full_matrices,
return result return result
def _svd_tpu(a, *, full_matrices, compute_uv):
  """Runs `lax_svd.svd` on `a`, mapping over every leading batch dimension.

  Args:
    a: array whose last two dimensions are the matrix dimensions; all
      dimensions before them are treated as batch dimensions.
    full_matrices: forwarded unchanged to `lax_svd.svd`.
    compute_uv: forwarded to `lax_svd.svd`; also selects the return layout.

  Returns:
    `[s, u, vh]` when `compute_uv` is true, otherwise `[s]`. Note that the
    order differs from the `(u, s, vh)` order produced by `lax_svd.svd`.
  """
  num_batch_dims = len(a.shape[:-2])
  svd_fn = partial(lax_svd.svd, full_matrices=full_matrices,
                   compute_uv=compute_uv)
  # One vmap wrapper per batch dimension, so the innermost call sees only the
  # trailing two dimensions.
  for _ in range(num_batch_dims):
    svd_fn = api.vmap(svd_fn)

  result = svd_fn(a)
  if not compute_uv:
    return [result]
  u, s, vh = result
  return [s, u, vh]
def _svd_tpu_lowering_rule(ctx, operand, *, full_matrices, compute_uv):
  """MLIR lowering rule for SVD that lowers a Python implementation.

  Picks `_empty_svd` when either trailing dimension of the operand is zero
  and `_svd_tpu` otherwise, then lowers the chosen function with
  `mlir.lower_fun`.

  Args:
    ctx: lowering context; `ctx.avals_in` must hold exactly one abstract value.
    operand: the value to decompose.
    full_matrices: forwarded to the chosen implementation.
    compute_uv: forwarded to the chosen implementation.

  Returns:
    Whatever the lowered implementation returns (multiple results).
  """
  (operand_aval,) = ctx.avals_in
  rows, cols = operand_aval.shape[-2:]
  impl = _empty_svd if (rows == 0 or cols == 0) else _svd_tpu
  lowered = mlir.lower_fun(impl, multiple_results=True)
  return lowered(
      ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv)
def svd_batching_rule(batched_args, batch_dims, full_matrices, compute_uv): def svd_batching_rule(batched_args, batch_dims, full_matrices, compute_uv):
x, = batched_args x, = batched_args
bd, = batch_dims bd, = batch_dims
@ -1457,7 +1458,6 @@ svd_p.def_impl(svd_impl)
svd_p.def_abstract_eval(svd_abstract_eval) svd_p.def_abstract_eval(svd_abstract_eval)
ad.primitive_jvps[svd_p] = svd_jvp_rule ad.primitive_jvps[svd_p] = svd_jvp_rule
batching.primitive_batchers[svd_p] = svd_batching_rule batching.primitive_batchers[svd_p] = svd_batching_rule
xla.register_translation(svd_p, _svd_translation_rule)
mlir.register_lowering( mlir.register_lowering(
svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_mhlo), svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_mhlo),
@ -1468,6 +1468,7 @@ if solver_apis is not None:
svd_p, partial(_svd_cpu_gpu_lowering, solver_apis.gesvd_mhlo), svd_p, partial(_svd_cpu_gpu_lowering, solver_apis.gesvd_mhlo),
platform='gpu') platform='gpu')
mlir.register_lowering(svd_p, _svd_tpu_lowering_rule)
def _tridiagonal_solve_gpu_lowering(ctx, dl, d, du, b, *, m, n, ldb, t): def _tridiagonal_solve_gpu_lowering(ctx, dl, d, du, b, *, m, n, ldb, t):
return [sparse_apis.gtsv2_mhlo(dl, d, du, b, m=m, n=n, ldb=ldb, t=t)] return [sparse_apis.gtsv2_mhlo(dl, d, du, b, m=m, n=n, ldb=ldb, t=t)]

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
"""A JIT-compatible library for QDWH-based SVD decomposition. """A JIT-compatible library for QDWH-based singular value decomposition.
QDWH is short for QR-based dynamically weighted Halley iteration. The Halley QDWH is short for QR-based dynamically weighted Halley iteration. The Halley
iteration implemented through QR decompositions is numerically stable and does iteration implemented through QR decompositions is numerically stable and does
@ -35,8 +35,7 @@ https://epubs.siam.org/doi/abs/10.1137/090774999
import functools import functools
from typing import Sequence, Union from typing import Any, Sequence, Union
import jax import jax
from jax import core from jax import core
from jax import lax from jax import lax
@ -44,10 +43,10 @@ import jax.numpy as jnp
@functools.partial(jax.jit, static_argnums=(1, 2, 3)) @functools.partial(jax.jit, static_argnums=(1, 2, 3))
def _svd(a: jnp.ndarray, def _svd(a: Any,
hermitian: bool, hermitian: bool,
compute_uv: bool, compute_uv: bool,
max_iterations: int) -> Union[jnp.ndarray, Sequence[jnp.ndarray]]: max_iterations: int) -> Union[Any, Sequence[Any]]:
"""Singular value decomposition for m x n matrix and m >= n. """Singular value decomposition for m x n matrix and m >= n.
Args: Args:
@ -99,11 +98,11 @@ def _svd(a: jnp.ndarray,
@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4)) @functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
def svd(a: jnp.ndarray, def svd(a: Any,
full_matrices: bool, full_matrices: bool,
compute_uv: bool = True, compute_uv: bool = True,
hermitian: bool = False, hermitian: bool = False,
max_iterations: int = 10) -> Union[jnp.ndarray, Sequence[jnp.ndarray]]: max_iterations: int = 10) -> Union[Any, Sequence[Any]]:
"""Singular value decomposition. """Singular value decomposition.
Args: Args:

View File

@ -1071,9 +1071,9 @@ class Jax2TfLimitation(primitive_harness.Limitation):
# than O(\sqrt(eps)), which is likely a property of the SVD algorithms # than O(\sqrt(eps)), which is likely a property of the SVD algorithms
# in question; revisit with better understanding of the SVD algorithms. # in question; revisit with better understanding of the SVD algorithms.
if x.dtype in [np.float32, np.complex64]: if x.dtype in [np.float32, np.complex64]:
slack_factor = 1E4 slack_factor = 2E4
elif x.dtype in [np.float64, np.complex128]: elif x.dtype in [np.float64, np.complex128]:
slack_factor = 1E9 slack_factor = 2E9
np.testing.assert_array_less(angular_diff, np.testing.assert_array_less(angular_diff,
slack_factor * error_bound) slack_factor * error_bound)
@ -1131,17 +1131,18 @@ class Jax2TfLimitation(primitive_harness.Limitation):
dtypes=[np.complex64, np.complex128], dtypes=[np.complex64, np.complex128],
devices=("cpu", "gpu"), devices=("cpu", "gpu"),
modes=("compiled",)), modes=("compiled",)),
Jax2TfLimitation(
"Large numerical discrepancy",
dtypes=[np.float16],
devices=("tpu"),
modes=("eager", "graph", "compiled"),
skip_comparison=True),
missing_tf_kernel(dtypes=[dtypes.bfloat16], devices="tpu"), missing_tf_kernel(dtypes=[dtypes.bfloat16], devices="tpu"),
custom_numeric( custom_numeric(
tol=1e-4, tol=1e-4,
dtypes=[np.float32, np.complex64], dtypes=[np.float32, np.complex64],
devices=("cpu", "gpu"), devices=("cpu", "gpu"),
modes=("eager", "graph", "compiled")), modes=("eager", "graph", "compiled")),
custom_numeric(
tol=1e-2,
dtypes=[np.float16],
devices=("tpu"),
modes=("eager", "graph", "compiled")),
# TODO: this is very low tolerance for f64 # TODO: this is very low tolerance for f64
custom_numeric( custom_numeric(
tol=1e-4, tol=1e-4,
@ -1149,11 +1150,20 @@ class Jax2TfLimitation(primitive_harness.Limitation):
devices=("cpu", "gpu"), devices=("cpu", "gpu"),
modes=("eager", "graph", "compiled")), modes=("eager", "graph", "compiled")),
custom_numeric( custom_numeric(
tol=1e-4,
description="custom numeric comparison when compute_uv on CPU/GPU", description="custom numeric comparison when compute_uv on CPU/GPU",
custom_assert=custom_assert, custom_assert=custom_assert,
devices=("cpu", "gpu"), devices=("cpu", "gpu"),
modes=("eager", "graph", "compiled"), modes=("eager", "graph", "compiled"),
enabled=(compute_uv == True)), enabled=(compute_uv == True)),
custom_numeric(
tol=1e-2,
description="custom numeric comparison when compute_uv on TPU",
dtypes=[np.float32, np.float64, np.complex64, np.complex128],
custom_assert=custom_assert,
devices=("tpu"),
modes=("eager", "graph", "compiled"),
enabled=(compute_uv == True)),
] ]
@classmethod @classmethod

View File

@ -544,10 +544,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for hermitian in ([False, True] if m == n else [False]))) for hermitian in ([False, True] if m == n else [False])))
@jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1 @jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1
def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, hermitian): def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, hermitian):
# TODO: enable after linking lax.svd to lax.linalg.svd
if (jnp.issubdtype(dtype, np.complexfloating) and
jtu.device_under_test() == "tpu"):
raise unittest.SkipTest("No complex SVD implementation")
rng = jtu.rand_default(self.rng()) rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(b + (m, n), dtype)] args_maker = lambda: [rng(b + (m, n), dtype)]
@ -744,11 +740,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for dtype in float_types + complex_types)) for dtype in float_types + complex_types))
@jtu.skip_on_devices("gpu") # TODO(#2203): numerical errors @jtu.skip_on_devices("gpu") # TODO(#2203): numerical errors
def testCond(self, shape, pnorm, dtype): def testCond(self, shape, pnorm, dtype):
# TODO: enable after linking lax.svd to lax.linalg.svd
if (jnp.issubdtype(dtype, np.complexfloating) and
jtu.device_under_test() == "tpu"):
raise unittest.SkipTest("No complex SVD implementation")
def gen_mat(): def gen_mat():
# arr_gen = jtu.rand_some_nan(self.rng()) # arr_gen = jtu.rand_some_nan(self.rng())
arr_gen = jtu.rand_default(self.rng()) arr_gen = jtu.rand_default(self.rng())
@ -855,10 +846,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for dtype in float_types + complex_types)) for dtype in float_types + complex_types))
@jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1 @jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1
def testPinv(self, shape, dtype): def testPinv(self, shape, dtype):
# TODO: enable after linking lax.svd to lax.linalg.svd
if (jnp.issubdtype(dtype, np.complexfloating) and
jtu.device_under_test() == "tpu"):
raise unittest.SkipTest("No complex SVD implementation")
rng = jtu.rand_default(self.rng()) rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)] args_maker = lambda: [rng(shape, dtype)]
@ -910,10 +897,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for dtype in float_types + complex_types)) for dtype in float_types + complex_types))
@jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1 @jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1
def testMatrixRank(self, shape, dtype): def testMatrixRank(self, shape, dtype):
# TODO: enable after linking lax.svd to lax.linalg.svd
if (jnp.issubdtype(dtype, np.complexfloating) and
jtu.device_under_test() == "tpu"):
raise unittest.SkipTest("No complex SVD implementation")
rng = jtu.rand_default(self.rng()) rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)] args_maker = lambda: [rng(shape, dtype)]
a, = args_maker() a, = args_maker()

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
"""Tests for the library of QDWH-based SVD decomposition.""" """Tests for the library of QDWH-based singular value decomposition."""
import functools import functools
import jax import jax