diff --git a/CHANGELOG.md b/CHANGELOG.md
index 38e5d4206..580040889 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
 * [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.7...main).
 * Changes
+  * {func}`jax.numpy.linalg.svd` on TPUs now uses a QDWH-based SVD solver.
+  * {func}`jax.numpy.linalg.cond` on TPUs now accepts complex input.
+  * {func}`jax.numpy.linalg.pinv` on TPUs now accepts complex input.
+  * {func}`jax.numpy.linalg.matrix_rank` on TPUs now accepts complex input.
   * {func}`jax.scipy.cluster.vq.vq` has been added.
   * `jax.experimental.maps.mesh` has been deleted. Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh
diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py
index 7777dc94c..22e7258ba 100644
--- a/jax/_src/lax/linalg.py
+++ b/jax/_src/lax/linalg.py
@@ -34,6 +34,7 @@ from jax._src.lax.lax import (
     standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex,
     _input_dtype)
 from jax._src.lax import lax as lax_internal
+from jax._src.lax import svd as lax_svd
 from jax._src.lib import lapack
 
 from jax._src.lib import cuda_linalg
@@ -202,6 +203,7 @@ def qr(x, full_matrices: bool = True):
   q, r = qr_p.bind(x, full_matrices=full_matrices)
   return q, r
 
+# TODO: Add `max_qdwh_iterations` to the function signature for TPU SVD.
 def svd(x, full_matrices=True, compute_uv=True):
   """Singular value decomposition.
 
@@ -1295,32 +1297,6 @@ def _eye_like_xla(c, aval):
                  xops.Iota(c, iota_shape, len(aval.shape) - 2))
   return xops.ConvertElementType(x, xla.dtype_to_primitive_type(aval.dtype))
 
-def _svd_translation_rule(ctx, avals_in, avals_out, operand, *, full_matrices,
-                          compute_uv):
-  operand_aval, = avals_in
-  shape = operand_aval.shape
-  m, n = shape[-2:]
-  if m == 0 or n == 0:
-    out = [_zeros_like_xla(ctx.builder, avals_out[0])]
-    if compute_uv:
-      out.append(_eye_like_xla(ctx.builder, avals_out[1]))
-      out.append(_eye_like_xla(ctx.builder, avals_out[2]))
-    return out
-
-  u, s, v = xops.SVD(operand)
-  permutation = list(range(len(shape)))
-  permutation[-1], permutation[-2] = permutation[-2], permutation[-1]
-  vt = xops.Transpose(v, permutation)
-  if not full_matrices and m != n:
-    u = xops.SliceInDim(u, 0, min(m, n), stride=1, dimno=len(shape) - 1)
-    vt = xops.SliceInDim(vt, 0, min(m, n), stride=1, dimno=len(shape) - 2)
-
-  if not compute_uv:
-    return [s]
-  else:
-    return [s, u, vt]
-
-
 def svd_abstract_eval(operand, full_matrices, compute_uv):
   if isinstance(operand, ShapedArray):
     if operand.ndim < 2:
@@ -1440,6 +1416,31 @@ def _svd_cpu_gpu_lowering(gesvd_impl, ctx, operand, *, full_matrices,
 
   return result
 
+def _svd_tpu(a, *, full_matrices, compute_uv):
+  batch_dims = a.shape[:-2]
+
+  fn = partial(lax_svd.svd, full_matrices=full_matrices, compute_uv=compute_uv)
+  for _ in range(len(batch_dims)):
+    fn = api.vmap(fn)
+
+  if compute_uv:
+    u, s, vh = fn(a)
+    return [s, u, vh]
+  else:
+    s = fn(a)
+    return [s]
+
+def _svd_tpu_lowering_rule(ctx, operand, *, full_matrices, compute_uv):
+  operand_aval, = ctx.avals_in
+  m, n = operand_aval.shape[-2:]
+
+  if m == 0 or n == 0:
+    return mlir.lower_fun(_empty_svd, multiple_results=True)(
+        ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv)
+
+  return mlir.lower_fun(_svd_tpu, multiple_results=True)(
+      ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv)
+
 def svd_batching_rule(batched_args, batch_dims, full_matrices, compute_uv):
   x, = batched_args
   bd, = batch_dims
@@ -1457,7 +1458,6 @@ svd_p.def_impl(svd_impl)
 svd_p.def_abstract_eval(svd_abstract_eval)
 ad.primitive_jvps[svd_p] = svd_jvp_rule
 batching.primitive_batchers[svd_p] = svd_batching_rule
-xla.register_translation(svd_p, _svd_translation_rule)
 
 mlir.register_lowering(
     svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_mhlo),
@@ -1468,6 +1468,7 @@ if solver_apis is not None:
       svd_p, partial(_svd_cpu_gpu_lowering, solver_apis.gesvd_mhlo),
       platform='gpu')
 
+mlir.register_lowering(svd_p, _svd_tpu_lowering_rule)
 
 def _tridiagonal_solve_gpu_lowering(ctx, dl, d, du, b, *, m, n, ldb, t):
   return [sparse_apis.gtsv2_mhlo(dl, d, du, b, m=m, n=n, ldb=ldb, t=t)]
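For context on the `_svd_tpu` lowering above: it handles batched inputs by wrapping the single-matrix QDWH solver in one `vmap` per leading batch dimension. A minimal, self-contained sketch of that pattern, not part of the patch (`batch_over_leading_dims` is a hypothetical helper, and `jnp.linalg.svd` stands in here for `lax_svd.svd`):

    # Sketch of the vmap-per-batch-dimension pattern used by _svd_tpu.
    from functools import partial
    import jax
    import jax.numpy as jnp

    def batch_over_leading_dims(fn, a):
      # Apply one vmap per leading batch dimension, then call on `a`.
      for _ in range(a.ndim - 2):
        fn = jax.vmap(fn)
      return fn(a)

    a = jnp.ones((3, 4, 5, 5))  # two batch dimensions of 5x5 matrices
    s = batch_over_leading_dims(partial(jnp.linalg.svd, compute_uv=False), a)
    print(s.shape)  # (3, 4, 5)
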
diff --git a/jax/_src/lax/svd.py b/jax/_src/lax/svd.py
index d0173d658..dc9d49dfc 100644
--- a/jax/_src/lax/svd.py
+++ b/jax/_src/lax/svd.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 
-"""A JIT-compatible library for QDWH-based SVD decomposition.
+"""A JIT-compatible library for QDWH-based singular value decomposition.
 
 QDWH is short for QR-based dynamically weighted Halley iteration. The Halley
 iteration implemented through QR decompositions is numerically stable and does
@@ -35,8 +35,7 @@ https://epubs.siam.org/doi/abs/10.1137/090774999
 
 import functools
-from typing import Sequence, Union
-
+from typing import Any, Sequence, Union
 import jax
 from jax import core
 from jax import lax
@@ -44,10 +43,10 @@ import jax.numpy as jnp
 
 
 @functools.partial(jax.jit, static_argnums=(1, 2, 3))
-def _svd(a: jnp.ndarray,
+def _svd(a: Any,
          hermitian: bool,
         compute_uv: bool,
-         max_iterations: int) -> Union[jnp.ndarray, Sequence[jnp.ndarray]]:
+         max_iterations: int) -> Union[Any, Sequence[Any]]:
   """Singular value decomposition for m x n matrix and m >= n.
 
   Args:
@@ -99,11 +98,11 @@
 
 
 @functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
-def svd(a: jnp.ndarray,
+def svd(a: Any,
         full_matrices: bool,
         compute_uv: bool = True,
         hermitian: bool = False,
-        max_iterations: int = 10) -> Union[jnp.ndarray, Sequence[jnp.ndarray]]:
+        max_iterations: int = 10) -> Union[Any, Sequence[Any]]:
   """Singular value decomposition.
 
   Args:
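For intuition on the algorithm this module's docstring describes: QDWH-SVD first computes a polar decomposition a = u_p @ h, then an eigendecomposition of the Hermitian factor h, and the two combine into an SVD. A hedged sketch of that math, assuming `jax.scipy.linalg.polar` is available; it illustrates the idea only, not this module's actual code path:

    # SVD via polar decomposition + Hermitian eigendecomposition.
    import jax.numpy as jnp
    from jax.scipy.linalg import polar

    a = jnp.array([[3., 1.], [1., 3.], [0., 2.]])
    u_p, h = polar(a, side='right')  # a == u_p @ h, h Hermitian PSD
    s, v = jnp.linalg.eigh(h)        # h == v @ diag(s) @ v^H
    u = u_p @ v                      # hence a == u @ diag(s) @ v^H
    print(jnp.allclose(a, u @ jnp.diag(s) @ v.conj().T, atol=1e-5))

Note that eigh returns the values in ascending order here; a production SVD additionally sorts the singular values into the conventional descending order.
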
diff --git a/jax/experimental/jax2tf/tests/jax2tf_limitations.py b/jax/experimental/jax2tf/tests/jax2tf_limitations.py
index ab077f1bd..f220d24ef 100644
--- a/jax/experimental/jax2tf/tests/jax2tf_limitations.py
+++ b/jax/experimental/jax2tf/tests/jax2tf_limitations.py
@@ -1071,9 +1071,9 @@ class Jax2TfLimitation(primitive_harness.Limitation):
       # than O(\sqrt(eps)), which is likely a property of the SVD algorithms
       # in question; revisit with better understanding of the SVD algorithms.
       if x.dtype in [np.float32, np.complex64]:
-        slack_factor = 1E4
+        slack_factor = 2E4
       elif x.dtype in [np.float64, np.complex128]:
-        slack_factor = 1E9
+        slack_factor = 2E9
 
       np.testing.assert_array_less(angular_diff, slack_factor * error_bound)
@@ -1131,17 +1131,18 @@ class Jax2TfLimitation(primitive_harness.Limitation):
             dtypes=[np.complex64, np.complex128],
             devices=("cpu", "gpu"),
             modes=("compiled",)),
+        Jax2TfLimitation(
+            "Large numerical discrepancy",
+            dtypes=[np.float16],
+            devices=("tpu"),
+            modes=("eager", "graph", "compiled"),
+            skip_comparison=True),
         missing_tf_kernel(dtypes=[dtypes.bfloat16], devices="tpu"),
         custom_numeric(
             tol=1e-4,
             dtypes=[np.float32, np.complex64],
             devices=("cpu", "gpu"),
             modes=("eager", "graph", "compiled")),
-        custom_numeric(
-            tol=1e-2,
-            dtypes=[np.float16],
-            devices=("tpu"),
-            modes=("eager", "graph", "compiled")),
         # TODO: this is very low tolerance for f64
         custom_numeric(
             tol=1e-4,
@@ -1149,11 +1150,20 @@ class Jax2TfLimitation(primitive_harness.Limitation):
             devices=("cpu", "gpu"),
             modes=("eager", "graph", "compiled")),
         custom_numeric(
+            tol=1e-4,
             description="custom numeric comparison when compute_uv on CPU/GPU",
             custom_assert=custom_assert,
             devices=("cpu", "gpu"),
             modes=("eager", "graph", "compiled"),
             enabled=(compute_uv == True)),
+        custom_numeric(
+            tol=1e-2,
+            description="custom numeric comparison when compute_uv on TPU",
+            dtypes=[np.float32, np.float64, np.complex64, np.complex128],
+            custom_assert=custom_assert,
+            devices=("tpu"),
+            modes=("eager", "graph", "compiled"),
+            enabled=(compute_uv == True)),
     ]
 
   @classmethod
diff --git a/tests/linalg_test.py b/tests/linalg_test.py
index e378e238f..575954774 100644
--- a/tests/linalg_test.py
+++ b/tests/linalg_test.py
@@ -544,10 +544,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
            for hermitian in ([False, True] if m == n else [False])))
   @jtu.skip_on_devices("rocm")  # will be fixed in ROCm-5.1
   def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, hermitian):
-    # TODO: enable after linking lax.svd to lax.linalg.svd
-    if (jnp.issubdtype(dtype, np.complexfloating) and
-        jtu.device_under_test() == "tpu"):
-      raise unittest.SkipTest("No complex SVD implementation")
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [rng(b + (m, n), dtype)]
@@ -744,11 +740,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
            for dtype in float_types + complex_types))
   @jtu.skip_on_devices("gpu")  # TODO(#2203): numerical errors
   def testCond(self, shape, pnorm, dtype):
-    # TODO: enable after linking lax.svd to lax.linalg.svd
-    if (jnp.issubdtype(dtype, np.complexfloating) and
-        jtu.device_under_test() == "tpu"):
-      raise unittest.SkipTest("No complex SVD implementation")
-
     def gen_mat():
       # arr_gen = jtu.rand_some_nan(self.rng())
       arr_gen = jtu.rand_default(self.rng())
@@ -855,10 +846,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
            for dtype in float_types + complex_types))
   @jtu.skip_on_devices("rocm")  # will be fixed in ROCm-5.1
   def testPinv(self, shape, dtype):
-    # TODO: enable after linking lax.svd to lax.linalg.svd
-    if (jnp.issubdtype(dtype, np.complexfloating) and
-        jtu.device_under_test() == "tpu"):
-      raise unittest.SkipTest("No complex SVD implementation")
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [rng(shape, dtype)]
@@ -910,10 +897,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
            for dtype in float_types + complex_types))
   @jtu.skip_on_devices("rocm")  # will be fixed in ROCm-5.1
   def testMatrixRank(self, shape, dtype):
-    # TODO: enable after linking lax.svd to lax.linalg.svd
-    if (jnp.issubdtype(dtype, np.complexfloating) and
-        jtu.device_under_test() == "tpu"):
-      raise unittest.SkipTest("No complex SVD implementation")
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [rng(shape, dtype)]
     a, = args_maker()
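The removed TPU skips above correspond to what now works: complex-valued inputs to the SVD-backed jnp.linalg routines on TPU. A hedged usage sketch (it runs on any backend; the shapes and key names are purely illustrative):

    # Complex input to SVD-backed routines, now also supported on TPU.
    import jax.numpy as jnp
    from jax import random

    kr, ki = random.split(random.PRNGKey(0))
    a = (random.normal(kr, (4, 3))
         + 1j * random.normal(ki, (4, 3))).astype(jnp.complex64)

    u, s, vh = jnp.linalg.svd(a, full_matrices=False)
    print(jnp.allclose(a, (u * s) @ vh, atol=1e-4))  # reconstruction check
    print(jnp.linalg.matrix_rank(a), jnp.linalg.cond(a))
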
diff --git a/tests/svd_test.py b/tests/svd_test.py
index b38587623..d170f17dc 100644
--- a/tests/svd_test.py
+++ b/tests/svd_test.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 
-"""Tests for the library of QDWH-based SVD decomposition."""
+"""Tests for the library of QDWH-based singular value decomposition."""
 
 import functools
 
 import jax