Mirror of https://github.com/ROCm/jax.git

commit 020849076c (parent b90df4bf4d)

    [linalg] Add tpu svd lowering rule.

    PiperOrigin-RevId: 445533767
diff --git a/CHANGELOG.md b/CHANGELOG.md
@@ -13,6 +13,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
 * [GitHub
   commits](https://github.com/google/jax/compare/jax-v0.3.7...main).
 * Changes
+  * {func}`jax.numpy.linalg.svd` on TPUs uses a qdwh-svd solver.
+  * {func}`jax.numpy.linalg.cond` on TPUs now accepts complex input.
+  * {func}`jax.numpy.linalg.pinv` on TPUs now accepts complex input.
+  * {func}`jax.numpy.linalg.matrix_rank` on TPUs now accepts complex input.
 * {func}`jax.scipy.cluster.vq.vq` has been added.
 * `jax.experimental.maps.mesh` has been deleted.
   Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh
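
The four changelog entries above are the user-visible effect of this commit: with a qdwh-svd solver registered for TPU, the complex paths of these functions stop being skipped there. A hedged usage sketch (any backend runs it; on TPU the complex case previously raised):

import jax.numpy as jnp

a = jnp.array([[1 + 1j, 2 - 1j],
               [0 + 3j, 4 + 0j]], dtype=jnp.complex64)
u, s, vh = jnp.linalg.svd(a)         # now backed by qdwh-svd on TPU
print(jnp.linalg.cond(a))            # complex input accepted on TPU
print(jnp.linalg.pinv(a))
print(jnp.linalg.matrix_rank(a))
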
diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py
@@ -34,6 +34,7 @@ from jax._src.lax.lax import (
   standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex,
   _input_dtype)
 from jax._src.lax import lax as lax_internal
+from jax._src.lax import svd as lax_svd
 from jax._src.lib import lapack
 
 from jax._src.lib import cuda_linalg
@@ -202,6 +203,7 @@ def qr(x, full_matrices: bool = True):
   q, r = qr_p.bind(x, full_matrices=full_matrices)
   return q, r
 
+# TODO: Add `max_qdwh_iterations` to the function signature for TPU SVD.
 def svd(x, full_matrices=True, compute_uv=True):
   """Singular value decomposition.
 
@@ -1295,32 +1297,6 @@ def _eye_like_xla(c, aval):
               xops.Iota(c, iota_shape, len(aval.shape) - 2))
   return xops.ConvertElementType(x, xla.dtype_to_primitive_type(aval.dtype))
 
-def _svd_translation_rule(ctx, avals_in, avals_out, operand, *, full_matrices,
-                          compute_uv):
-  operand_aval, = avals_in
-  shape = operand_aval.shape
-  m, n = shape[-2:]
-  if m == 0 or n == 0:
-    out = [_zeros_like_xla(ctx.builder, avals_out[0])]
-    if compute_uv:
-      out.append(_eye_like_xla(ctx.builder, avals_out[1]))
-      out.append(_eye_like_xla(ctx.builder, avals_out[2]))
-    return out
-
-  u, s, v = xops.SVD(operand)
-  permutation = list(range(len(shape)))
-  permutation[-1], permutation[-2] = permutation[-2], permutation[-1]
-  vt = xops.Transpose(v, permutation)
-  if not full_matrices and m != n:
-    u = xops.SliceInDim(u, 0, min(m, n), stride=1, dimno=len(shape) - 1)
-    vt = xops.SliceInDim(vt, 0, min(m, n), stride=1, dimno=len(shape) - 2)
-
-  if not compute_uv:
-    return [s]
-  else:
-    return [s, u, vt]
-
-
 def svd_abstract_eval(operand, full_matrices, compute_uv):
   if isinstance(operand, ShapedArray):
     if operand.ndim < 2:
@@ -1440,6 +1416,31 @@ def _svd_cpu_gpu_lowering(gesvd_impl, ctx, operand, *, full_matrices,
 
   return result
 
+def _svd_tpu(a, *, full_matrices, compute_uv):
+  batch_dims = a.shape[:-2]
+
+  fn = partial(lax_svd.svd, full_matrices=full_matrices, compute_uv=compute_uv)
+  for _ in range(len(batch_dims)):
+    fn = api.vmap(fn)
+
+  if compute_uv:
+    u, s, vh = fn(a)
+    return [s, u, vh]
+  else:
+    s = fn(a)
+    return [s]
+
+
+def _svd_tpu_lowering_rule(ctx, operand, *, full_matrices, compute_uv):
+  operand_aval, = ctx.avals_in
+  m, n = operand_aval.shape[-2:]
+
+  if m == 0 or n == 0:
+    return mlir.lower_fun(_empty_svd, multiple_results=True)(
+        ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv)
+
+  return mlir.lower_fun(_svd_tpu, multiple_results=True)(
+      ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv)
+
 def svd_batching_rule(batched_args, batch_dims, full_matrices, compute_uv):
   x, = batched_args
   bd, = batch_dims
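
Note how `_svd_tpu` handles batching: the QDWH solver works on a single matrix, so it is mapped over arbitrarily many leading batch dimensions by wrapping it in one `vmap` per dimension. A minimal self-contained sketch of that pattern, with a stand-in function in place of the solver:

import jax
import jax.numpy as jnp

def frob(mat):  # stand-in for any function of one (m, n) matrix
  return jnp.sqrt(jnp.sum(jnp.abs(mat) ** 2))

a = jnp.ones((2, 3, 4, 5))        # two batch dims; matrices are 4 x 5
fn = frob
for _ in range(a.ndim - 2):       # one vmap per leading batch dimension
  fn = jax.vmap(fn)
print(fn(a).shape)                # (2, 3): frob applied to each 4 x 5 matrix
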
@@ -1457,7 +1458,6 @@ svd_p.def_impl(svd_impl)
 svd_p.def_abstract_eval(svd_abstract_eval)
 ad.primitive_jvps[svd_p] = svd_jvp_rule
 batching.primitive_batchers[svd_p] = svd_batching_rule
-xla.register_translation(svd_p, _svd_translation_rule)
 
 mlir.register_lowering(
     svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_mhlo),
@@ -1468,6 +1468,7 @@ if solver_apis is not None:
     svd_p, partial(_svd_cpu_gpu_lowering, solver_apis.gesvd_mhlo),
     platform='gpu')
 
+mlir.register_lowering(svd_p, _svd_tpu_lowering_rule, platform='tpu')
 
 def _tridiagonal_solve_gpu_lowering(ctx, dl, d, du, b, *, m, n, ldb, t):
   return [sparse_apis.gtsv2_mhlo(dl, d, du, b, m=m, n=n, ldb=ldb, t=t)]
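
The TPU rule is registered under platform='tpu' (restored above; without the platform argument it would override the default CPU registration rather than complement it), so the LAPACK and cuSOLVER rules registered earlier stay in effect on their platforms. Dispatch is by lowering platform and invisible at the jnp level; a hedged sketch:

import jax
import jax.numpy as jnp

# On TPU this lowers through _svd_tpu_lowering_rule; on CPU through the
# LAPACK (gesdd) rule registered above. The call site is identical.
f = jax.jit(lambda x: jnp.linalg.svd(x, compute_uv=False))
print(f(jnp.ones((4, 3))))  # singular values; backend rule chosen automatically
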
diff --git a/jax/_src/lax/svd.py b/jax/_src/lax/svd.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 
-"""A JIT-compatible library for QDWH-based SVD decomposition.
+"""A JIT-compatible library for QDWH-based singular value decomposition.
 
 QDWH is short for QR-based dynamically weighted Halley iteration. The Halley
 iteration implemented through QR decompositions is numerically stable and does
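
The docstring names the algorithm but the mechanics are easy to miss: QDWH-SVD first computes the polar decomposition a = up @ h with a Halley-type iteration, then diagonalizes the Hermitian factor h. Below is a minimal sketch of that idea only, assuming a tall, well-conditioned matrix; it uses the fixed-weight Halley step and an explicit inverse, whereas the library applies QR-based, dynamically weighted steps:

import jax.numpy as jnp

def sketch_qdwh_svd(a, iters=30):
  n = a.shape[-1]                     # assumes a is m x n with m >= n
  x = a / jnp.linalg.norm(a)          # scale so every singular value is <= 1
  eye = jnp.eye(n, dtype=a.dtype)
  for _ in range(iters):              # fixed-weight Halley step; real QDWH
    xhx = x.conj().T @ x              # recomputes the weights each iteration
    x = x @ (3 * eye + xhx) @ jnp.linalg.inv(eye + 3 * xhx)
  up = x                              # orthonormal-column polar factor of a
  h = up.conj().T @ a                 # Hermitian positive semidefinite factor
  s, v = jnp.linalg.eigh(h)           # eigh returns ascending eigenvalues
  order = jnp.arange(n - 1, -1, -1)   # reorder to the descending convention
  return (up @ v)[:, order], s[order], v[:, order].conj().T

Since a = up @ h and h = v @ diag(s) @ v^H, the triple (up @ v, s, v^H) is an SVD of a once reordered to descending singular values.
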
@@ -35,8 +35,7 @@ https://epubs.siam.org/doi/abs/10.1137/090774999
 
 import functools
-
-from typing import Sequence, Union
+from typing import Any, Sequence, Union
 
 import jax
 from jax import core
 from jax import lax
@@ -44,10 +43,10 @@ import jax.numpy as jnp
 
 
 @functools.partial(jax.jit, static_argnums=(1, 2, 3))
-def _svd(a: jnp.ndarray,
+def _svd(a: Any,
          hermitian: bool,
         compute_uv: bool,
-         max_iterations: int) -> Union[jnp.ndarray, Sequence[jnp.ndarray]]:
+         max_iterations: int) -> Union[Any, Sequence[Any]]:
   """Singular value decomposition for m x n matrix and m >= n.
 
   Args:
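
The `hermitian` flag in `_svd` exists because, for Hermitian input, the singular values are simply the absolute eigenvalues, so the polar iteration can be skipped. A small numerical check of that identity (my example, not library code):

import jax.numpy as jnp

a = jnp.array([[2.0, 1.0],
               [1.0, -3.0]])                   # symmetric, indefinite
w = jnp.linalg.eigvalsh(a)                     # ascending eigenvalues
s = jnp.sort(jnp.abs(w))[::-1]                 # descending magnitudes
print(jnp.allclose(s, jnp.linalg.svd(a, compute_uv=False), atol=1e-6))
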
@@ -99,11 +98,11 @@ def _svd(a: jnp.ndarray,
 
 
 @functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
-def svd(a: jnp.ndarray,
+def svd(a: Any,
        full_matrices: bool,
        compute_uv: bool = True,
        hermitian: bool = False,
-        max_iterations: int = 10) -> Union[jnp.ndarray, Sequence[jnp.ndarray]]:
+        max_iterations: int = 10) -> Union[Any, Sequence[Any]]:
   """Singular value decomposition.
 
   Args:
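
A hedged usage sketch of the entry point above (it lives in an internal module, so the import path may change; argument names and the (u, s, vh) return order are taken from this diff):

import jax.numpy as jnp
from jax._src.lax import svd as lax_svd

a = jnp.asarray([[1.0, 2.0],
                 [3.0, 4.0],
                 [5.0, 6.0]])
u, s, vh = lax_svd.svd(a, full_matrices=False, compute_uv=True)
print(jnp.allclose(u @ jnp.diag(s) @ vh, a, atol=1e-4))  # reconstructs a
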
diff --git a/jax/experimental/jax2tf/tests/jax2tf_limitations.py b/jax/experimental/jax2tf/tests/jax2tf_limitations.py
@@ -1071,9 +1071,9 @@ class Jax2TfLimitation(primitive_harness.Limitation):
       # than O(\sqrt(eps)), which is likely a property of the SVD algorithms
       # in question; revisit with better understanding of the SVD algorithms.
       if x.dtype in [np.float32, np.complex64]:
-        slack_factor = 1E4
+        slack_factor = 2E4
       elif x.dtype in [np.float64, np.complex128]:
-        slack_factor = 1E9
+        slack_factor = 2E9
 
       np.testing.assert_array_less(angular_diff,
                                    slack_factor * error_bound)
@@ -1131,17 +1131,18 @@ class Jax2TfLimitation(primitive_harness.Limitation):
             dtypes=[np.complex64, np.complex128],
             devices=("cpu", "gpu"),
             modes=("compiled",)),
+        Jax2TfLimitation(
+            "Large numerical discrepancy",
+            dtypes=[np.float16],
+            devices=("tpu"),
+            modes=("eager", "graph", "compiled"),
+            skip_comparison=True),
         missing_tf_kernel(dtypes=[dtypes.bfloat16], devices="tpu"),
         custom_numeric(
             tol=1e-4,
             dtypes=[np.float32, np.complex64],
             devices=("cpu", "gpu"),
             modes=("eager", "graph", "compiled")),
-        custom_numeric(
-            tol=1e-2,
-            dtypes=[np.float16],
-            devices=("tpu"),
-            modes=("eager", "graph", "compiled")),
         # TODO: this is very low tolerance for f64
         custom_numeric(
             tol=1e-4,
@@ -1149,11 +1150,20 @@ class Jax2TfLimitation(primitive_harness.Limitation):
             devices=("cpu", "gpu"),
             modes=("eager", "graph", "compiled")),
         custom_numeric(
+            tol=1e-4,
             description="custom numeric comparison when compute_uv on CPU/GPU",
             custom_assert=custom_assert,
             devices=("cpu", "gpu"),
             modes=("eager", "graph", "compiled"),
             enabled=(compute_uv == True)),
+        custom_numeric(
+            tol=1e-2,
+            description="custom numeric comparison when compute_uv on TPU",
+            dtypes=[np.float32, np.float64, np.complex64, np.complex128],
+            custom_assert=custom_assert,
+            devices=("tpu"),
+            modes=("eager", "graph", "compiled"),
+            enabled=(compute_uv == True)),
     ]
 
   @classmethod
diff --git a/tests/linalg_test.py b/tests/linalg_test.py
@@ -544,10 +544,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
           for hermitian in ([False, True] if m == n else [False])))
   @jtu.skip_on_devices("rocm")  # will be fixed in ROCm-5.1
   def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, hermitian):
-    # TODO: enable after linking lax.svd to lax.linalg.svd
-    if (jnp.issubdtype(dtype, np.complexfloating) and
-        jtu.device_under_test() == "tpu"):
-      raise unittest.SkipTest("No complex SVD implementation")
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [rng(b + (m, n), dtype)]
 
@@ -744,11 +740,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
           for dtype in float_types + complex_types))
   @jtu.skip_on_devices("gpu")  # TODO(#2203): numerical errors
   def testCond(self, shape, pnorm, dtype):
-    # TODO: enable after linking lax.svd to lax.linalg.svd
-    if (jnp.issubdtype(dtype, np.complexfloating) and
-        jtu.device_under_test() == "tpu"):
-      raise unittest.SkipTest("No complex SVD implementation")
-
     def gen_mat():
       # arr_gen = jtu.rand_some_nan(self.rng())
       arr_gen = jtu.rand_default(self.rng())
@@ -855,10 +846,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
           for dtype in float_types + complex_types))
   @jtu.skip_on_devices("rocm")  # will be fixed in ROCm-5.1
   def testPinv(self, shape, dtype):
-    # TODO: enable after linking lax.svd to lax.linalg.svd
-    if (jnp.issubdtype(dtype, np.complexfloating) and
-        jtu.device_under_test() == "tpu"):
-      raise unittest.SkipTest("No complex SVD implementation")
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [rng(shape, dtype)]
 
@@ -910,10 +897,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
           for dtype in float_types + complex_types))
   @jtu.skip_on_devices("rocm")  # will be fixed in ROCm-5.1
   def testMatrixRank(self, shape, dtype):
-    # TODO: enable after linking lax.svd to lax.linalg.svd
-    if (jnp.issubdtype(dtype, np.complexfloating) and
-        jtu.device_under_test() == "tpu"):
-      raise unittest.SkipTest("No complex SVD implementation")
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [rng(shape, dtype)]
     a, = args_maker()
diff --git a/tests/svd_test.py b/tests/svd_test.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 
-"""Tests for the library of QDWH-based SVD decomposition."""
+"""Tests for the library of QDWH-based singular value decomposition."""
 import functools
 
 import jax