Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00
[linalg] Add tpu svd lowering rule.
PiperOrigin-RevId: 445533767
This commit is contained in:
parent b90df4bf4d
commit 020849076c
@@ -13,6 +13,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.7...main).
* Changes
  * {func}`jax.numpy.linalg.svd` on TPUs uses a qdwh-svd solver.
  * {func}`jax.numpy.linalg.cond` on TPUs now accepts complex input.
  * {func}`jax.numpy.linalg.pinv` on TPUs now accepts complex input.
  * {func}`jax.numpy.linalg.matrix_rank` on TPUs now accepts complex input.
  * {func}`jax.scipy.cluster.vq.vq` has been added.
  * `jax.experimental.maps.mesh` has been deleted. Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh
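To illustrate the changelog entries above (a hedged sketch of my own, not part of the diff): with the qdwh-svd path, these `jax.numpy.linalg` functions accept complex input on TPU. The array and seed below are arbitrary.

```python
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)
a = jnp.asarray((rng.standard_normal((3, 3))
                 + 1j * rng.standard_normal((3, 3))).astype(np.complex64))

s = jnp.linalg.svd(a, compute_uv=False)  # singular values via the SVD solver
print(jnp.linalg.cond(a))                # condition number, computed from the SVD
print(jnp.linalg.matrix_rank(a))         # rank, counted from the singular values
print(jnp.linalg.pinv(a).shape)          # (3, 3) Moore-Penrose pseudo-inverse
```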
@@ -34,6 +34,7 @@ from jax._src.lax.lax import (
    standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex,
    _input_dtype)
from jax._src.lax import lax as lax_internal
from jax._src.lax import svd as lax_svd
from jax._src.lib import lapack

from jax._src.lib import cuda_linalg
@@ -202,6 +203,7 @@ def qr(x, full_matrices: bool = True):
  q, r = qr_p.bind(x, full_matrices=full_matrices)
  return q, r

# TODO: Add `max_qdwh_iterations` to the function signature for TPU SVD.
def svd(x, full_matrices=True, compute_uv=True):
  """Singular value decomposition.
@@ -1295,32 +1297,6 @@ def _eye_like_xla(c, aval):
              xops.Iota(c, iota_shape, len(aval.shape) - 2))
  return xops.ConvertElementType(x, xla.dtype_to_primitive_type(aval.dtype))

def _svd_translation_rule(ctx, avals_in, avals_out, operand, *, full_matrices,
                          compute_uv):
  operand_aval, = avals_in
  shape = operand_aval.shape
  m, n = shape[-2:]
  if m == 0 or n == 0:
    out = [_zeros_like_xla(ctx.builder, avals_out[0])]
    if compute_uv:
      out.append(_eye_like_xla(ctx.builder, avals_out[1]))
      out.append(_eye_like_xla(ctx.builder, avals_out[2]))
    return out

  u, s, v = xops.SVD(operand)
  permutation = list(range(len(shape)))
  permutation[-1], permutation[-2] = permutation[-2], permutation[-1]
  vt = xops.Transpose(v, permutation)
  if not full_matrices and m != n:
    u = xops.SliceInDim(u, 0, min(m, n), stride=1, dimno=len(shape) - 1)
    vt = xops.SliceInDim(vt, 0, min(m, n), stride=1, dimno=len(shape) - 2)

  if not compute_uv:
    return [s]
  else:
    return [s, u, vt]


def svd_abstract_eval(operand, full_matrices, compute_uv):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2:
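An aside on the translation rule removed above: it builds `vt` from `v` by swapping the last two axes with a permutation. A minimal standalone sketch of that same trick in plain `jax.numpy` (illustrative only, not code from the diff):

```python
import jax.numpy as jnp

v = jnp.arange(2 * 3 * 4).reshape(2, 3, 4)  # batched "V" factor
perm = list(range(v.ndim))
perm[-1], perm[-2] = perm[-2], perm[-1]     # swap the last two dimensions
vt = jnp.transpose(v, perm)
print(vt.shape)  # (2, 4, 3)
```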
@@ -1440,6 +1416,31 @@ def _svd_cpu_gpu_lowering(gesvd_impl, ctx, operand, *, full_matrices,

  return result

def _svd_tpu(a, *, full_matrices, compute_uv):
  batch_dims = a.shape[:-2]

  fn = partial(lax_svd.svd, full_matrices=full_matrices, compute_uv=compute_uv)
  for _ in range(len(batch_dims)):
    fn = api.vmap(fn)

  if compute_uv:
    u, s, vh = fn(a)
    return [s, u, vh]
  else:
    s = fn(a)
    return [s]

def _svd_tpu_lowering_rule(ctx, operand, *, full_matrices, compute_uv):
  operand_aval, = ctx.avals_in
  m, n = operand_aval.shape[-2:]

  if m == 0 or n == 0:
    return mlir.lower_fun(_empty_svd, multiple_results=True)(
        ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv)

  return mlir.lower_fun(_svd_tpu, multiple_results=True)(
      ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv)

def svd_batching_rule(batched_args, batch_dims, full_matrices, compute_uv):
  x, = batched_args
  bd, = batch_dims
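The `_svd_tpu` helper added above maps the per-matrix QDWH SVD over any leading batch dimensions by wrapping it in one `vmap` per batch axis. A self-contained sketch of that pattern with a stand-in per-matrix function (names here are illustrative, not from the diff):

```python
import jax
import jax.numpy as jnp

def per_matrix_fn(x):          # stands in for the per-matrix solver
  return jnp.linalg.norm(x)    # Frobenius norm of one 2-D matrix

a = jnp.ones((2, 5, 3, 4))     # two leading batch dimensions over 3x4 matrices
fn = per_matrix_fn
for _ in range(a.ndim - 2):    # one vmap per batch dimension, as in _svd_tpu
  fn = jax.vmap(fn)
print(fn(a).shape)             # (2, 5): one result per batched matrix
```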
@@ -1457,7 +1458,6 @@ svd_p.def_impl(svd_impl)
svd_p.def_abstract_eval(svd_abstract_eval)
ad.primitive_jvps[svd_p] = svd_jvp_rule
batching.primitive_batchers[svd_p] = svd_batching_rule
xla.register_translation(svd_p, _svd_translation_rule)

mlir.register_lowering(
    svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_mhlo),
@@ -1468,6 +1468,7 @@ if solver_apis is not None:
      svd_p, partial(_svd_cpu_gpu_lowering, solver_apis.gesvd_mhlo),
      platform='gpu')

mlir.register_lowering(svd_p, _svd_tpu_lowering_rule)

def _tridiagonal_solve_gpu_lowering(ctx, dl, d, du, b, *, m, n, ldb, t):
  return [sparse_apis.gtsv2_mhlo(dl, d, du, b, m=m, n=n, ldb=ldb, t=t)]
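For context on the registration above: `mlir.lower_fun` turns a plain JAX-traceable function into a lowering rule, and `mlir.register_lowering` attaches that rule to a primitive. A hedged toy sketch of the same wiring with a made-up primitive (`double_p` is hypothetical and not from the diff; the SVD registration works the same way through `_svd_tpu_lowering_rule`):

```python
import jax
from jax import core
from jax.interpreters import mlir

double_p = core.Primitive("double")            # toy primitive for illustration
double_p.def_impl(lambda x: x * 2)             # eager (op-by-op) implementation
double_p.def_abstract_eval(lambda aval: aval)  # output shape/dtype match the input

# Lower the primitive by tracing an ordinary JAX function, just as the diff
# lowers svd_p through _svd_tpu; a platform kwarg can restrict the rule.
mlir.register_lowering(
    double_p, mlir.lower_fun(lambda x: x * 2, multiple_results=False))

print(double_p.bind(3.0))           # 6.0, via def_impl
print(jax.jit(double_p.bind)(3.0))  # 6.0, via the registered lowering
```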
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License

"""A JIT-compatible library for QDWH-based SVD decomposition.
"""A JIT-compatible library for QDWH-based singular value decomposition.

QDWH is short for QR-based dynamically weighted Halley iteration. The Halley
iteration implemented through QR decompositions is numerically stable and does
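For readers of this hunk, a brief restatement of the algorithm the docstring names (my paraphrase of the QDWH literature, not text from the diff): starting from X_0 = A / alpha, the dynamically weighted Halley iteration computes

    X_{k+1} = X_k (a_k I + b_k X_k^* X_k) (I + c_k X_k^* X_k)^{-1},

implemented with QR factorizations rather than explicit inverses. X_k converges to the unitary polar factor U_p of A; the SVD then follows from the polar decomposition A = U_p H by an eigendecomposition H = V Sigma V^*, giving A = (U_p V) Sigma V^*.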
@@ -35,8 +35,7 @@ https://epubs.siam.org/doi/abs/10.1137/090774999

import functools

from typing import Sequence, Union
from typing import Any, Sequence, Union
import jax
from jax import core
from jax import lax
@@ -44,10 +43,10 @@ import jax.numpy as jnp


@functools.partial(jax.jit, static_argnums=(1, 2, 3))
def _svd(a: jnp.ndarray,
def _svd(a: Any,
         hermitian: bool,
         compute_uv: bool,
         max_iterations: int) -> Union[jnp.ndarray, Sequence[jnp.ndarray]]:
         max_iterations: int) -> Union[Any, Sequence[Any]]:
  """Singular value decomposition for m x n matrix and m >= n.

  Args:
@@ -99,11 +98,11 @@ def _svd(a: jnp.ndarray,


@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
def svd(a: jnp.ndarray,
def svd(a: Any,
        full_matrices: bool,
        compute_uv: bool = True,
        hermitian: bool = False,
        max_iterations: int = 10) -> Union[jnp.ndarray, Sequence[jnp.ndarray]]:
        max_iterations: int = 10) -> Union[Any, Sequence[Any]]:
  """Singular value decomposition.

  Args:
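A hedged usage sketch of the wrapper whose signature changes above (`jax._src.lax.svd.svd` is an internal API, so treat this as illustrative only): with `compute_uv=True` it returns `(u, s, vh)`, which is also how `_svd_tpu` in linalg.py unpacks it.

```python
import jax.numpy as jnp
from jax._src.lax import svd as lax_svd

a = jnp.array([[3.0, 0.0],
               [4.0, 5.0]], dtype=jnp.float32)
u, s, vh = lax_svd.svd(a, full_matrices=False, compute_uv=True)
print(jnp.allclose(a, (u * s) @ vh, atol=1e-4))  # reconstruction check -> True
```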
@@ -1071,9 +1071,9 @@ class Jax2TfLimitation(primitive_harness.Limitation):
      # than O(\sqrt(eps)), which is likely a property of the SVD algorithms
      # in question; revisit with better understanding of the SVD algorithms.
      if x.dtype in [np.float32, np.complex64]:
        slack_factor = 1E4
        slack_factor = 2E4
      elif x.dtype in [np.float64, np.complex128]:
        slack_factor = 1E9
        slack_factor = 2E9

      np.testing.assert_array_less(angular_diff,
                                   slack_factor * error_bound)
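A side note on the widened slack factors (my note, not from the diff): the preceding comment frames the tolerance in terms of machine epsilon for the dtype under test, so for reference these are the eps values for the dtypes involved.

```python
import numpy as np

print(np.finfo(np.float32).eps)  # ~1.1920929e-07
print(np.finfo(np.float64).eps)  # ~2.220446049250313e-16
```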
@@ -1131,17 +1131,18 @@ class Jax2TfLimitation(primitive_harness.Limitation):
            dtypes=[np.complex64, np.complex128],
            devices=("cpu", "gpu"),
            modes=("compiled",)),
        Jax2TfLimitation(
            "Large numerical discrepancy",
            dtypes=[np.float16],
            devices=("tpu"),
            modes=("eager", "graph", "compiled"),
            skip_comparison=True),
        missing_tf_kernel(dtypes=[dtypes.bfloat16], devices="tpu"),
        custom_numeric(
            tol=1e-4,
            dtypes=[np.float32, np.complex64],
            devices=("cpu", "gpu"),
            modes=("eager", "graph", "compiled")),
        custom_numeric(
            tol=1e-2,
            dtypes=[np.float16],
            devices=("tpu"),
            modes=("eager", "graph", "compiled")),
        # TODO: this is very low tolerance for f64
        custom_numeric(
            tol=1e-4,
@@ -1149,11 +1150,20 @@ class Jax2TfLimitation(primitive_harness.Limitation):
            devices=("cpu", "gpu"),
            modes=("eager", "graph", "compiled")),
        custom_numeric(
            tol=1e-4,
            description="custom numeric comparison when compute_uv on CPU/GPU",
            custom_assert=custom_assert,
            devices=("cpu", "gpu"),
            modes=("eager", "graph", "compiled"),
            enabled=(compute_uv == True)),
        custom_numeric(
            tol=1e-2,
            description="custom numeric comparison when compute_uv on TPU",
            dtypes=[np.float32, np.float64, np.complex64, np.complex128],
            custom_assert=custom_assert,
            devices=("tpu"),
            modes=("eager", "graph", "compiled"),
            enabled=(compute_uv == True)),
    ]

  @classmethod
@@ -544,10 +544,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
      for hermitian in ([False, True] if m == n else [False])))
  @jtu.skip_on_devices("rocm")  # will be fixed in ROCm-5.1
  def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, hermitian):
    # TODO: enable after linking lax.svd to lax.linalg.svd
    if (jnp.issubdtype(dtype, np.complexfloating) and
        jtu.device_under_test() == "tpu"):
      raise unittest.SkipTest("No complex SVD implementation")
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(b + (m, n), dtype)]
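Independent of the skip removal above, a minimal sketch in the spirit of testSVD (my own example, not code from the test file): reconstruct a batched input from the factors that `jnp.linalg.svd` returns.

```python
import numpy as np
import jax.numpy as jnp

rng = np.random.default_rng(0)
a = jnp.asarray(rng.standard_normal((3, 5, 4)).astype(np.float32))  # batch of 3
u, s, vh = jnp.linalg.svd(a, full_matrices=False)
recon = jnp.einsum('...ik,...k,...kj->...ij', u, s, vh)
np.testing.assert_allclose(a, recon, rtol=1e-4, atol=1e-4)
```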
@@ -744,11 +740,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
      for dtype in float_types + complex_types))
  @jtu.skip_on_devices("gpu")  # TODO(#2203): numerical errors
  def testCond(self, shape, pnorm, dtype):
    # TODO: enable after linking lax.svd to lax.linalg.svd
    if (jnp.issubdtype(dtype, np.complexfloating) and
        jtu.device_under_test() == "tpu"):
      raise unittest.SkipTest("No complex SVD implementation")

    def gen_mat():
      # arr_gen = jtu.rand_some_nan(self.rng())
      arr_gen = jtu.rand_default(self.rng())
@@ -855,10 +846,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
      for dtype in float_types + complex_types))
  @jtu.skip_on_devices("rocm")  # will be fixed in ROCm-5.1
  def testPinv(self, shape, dtype):
    # TODO: enable after linking lax.svd to lax.linalg.svd
    if (jnp.issubdtype(dtype, np.complexfloating) and
        jtu.device_under_test() == "tpu"):
      raise unittest.SkipTest("No complex SVD implementation")
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
@@ -910,10 +897,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
      for dtype in float_types + complex_types))
  @jtu.skip_on_devices("rocm")  # will be fixed in ROCm-5.1
  def testMatrixRank(self, shape, dtype):
    # TODO: enable after linking lax.svd to lax.linalg.svd
    if (jnp.issubdtype(dtype, np.complexfloating) and
        jtu.device_under_test() == "tpu"):
      raise unittest.SkipTest("No complex SVD implementation")
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    a, = args_maker()
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License

"""Tests for the library of QDWH-based SVD decomposition."""
"""Tests for the library of QDWH-based singular value decomposition."""
import functools

import jax