diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py index 3a2810498..bf099344b 100644 --- a/jax/_src/internal_test_util/test_harnesses.py +++ b/jax/_src/internal_test_util/test_harnesses.py @@ -1800,25 +1800,35 @@ for dtype in jtu.dtypes.all_floating + jtu.dtypes.complex: for shape in [(2, 2), (2, 7), (29, 29), (2, 3, 53), (2, 3, 29, 7)]: for full_matrices in [False, True]: for compute_uv in [False, True]: + subset_by_index = None define( lax.linalg.svd_p, f"shape={jtu.format_shape_dtype_string(shape, dtype)}_fullmatrices={full_matrices}_computeuv={compute_uv}", lambda *args: lax.linalg.svd_p.bind( - args[0], full_matrices=args[1], compute_uv=args[2]), [ - RandArg(shape, dtype), - StaticArg(full_matrices), - StaticArg(compute_uv) - ], + args[0], + full_matrices=args[1], + compute_uv=args[2], + subset_by_index=args[3], + ), + [ + RandArg(shape, dtype), + StaticArg(full_matrices), + StaticArg(compute_uv), + StaticArg(subset_by_index), + ], jax_unimplemented=[ Limitation( "unimplemented", devices=("cpu", "gpu"), - dtypes=[np.float16, dtypes.bfloat16]), + dtypes=[np.float16, dtypes.bfloat16], + ), ], shape=shape, dtype=dtype, full_matrices=full_matrices, - compute_uv=compute_uv) + compute_uv=compute_uv, + subset_by_index=subset_by_index, + ) for dtype in jtu.dtypes.all_inexact: for shape in [(0, 0), (5, 5), (2, 6, 6)]: @@ -2666,7 +2676,6 @@ for dtype in (np.float32, np.float64): dtype=dtype) - def wrap_and_split(): key = jax.random.key(42) result = jax.random.split(key, 2) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 729cb8171..80162a204 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -298,25 +298,61 @@ def qr(x: ArrayLike, *, full_matrices: bool = True) -> tuple[Array, Array]: q, r = qr_p.bind(x, full_matrices=full_matrices) return q, r -@overload -def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: Literal[True]) -> tuple[Array, Array, Array]: ... @overload -def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: Literal[False]) -> Array: ... +def svd( + x: ArrayLike, + *, + full_matrices: bool = True, + compute_uv: Literal[True], + subset_by_index: tuple[int, int] | None = None, +) -> tuple[Array, Array, Array]: + ... + @overload -def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: bool = True) -> Array | tuple[Array, Array, Array]: ... +def svd( + x: ArrayLike, + *, + full_matrices: bool = True, + compute_uv: Literal[False], + subset_by_index: tuple[int, int] | None = None, +) -> Array: + ... + + +@overload +def svd( + x: ArrayLike, + *, + full_matrices: bool = True, + compute_uv: bool = True, + subset_by_index: tuple[int, int] | None = None, +) -> Array | tuple[Array, Array, Array]: + ... + # TODO: Add `max_qdwh_iterations` to the function signature for TPU SVD. @_warn_on_positional_kwargs -def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: bool = True) -> Array | tuple[Array, Array, Array]: +def svd( + x: ArrayLike, + *, + full_matrices: bool = True, + compute_uv: bool = True, + subset_by_index: tuple[int, int] | None = None, +) -> Array | tuple[Array, Array, Array]: """Singular value decomposition. Returns the singular values if compute_uv is False, otherwise returns a triple containing the left singular vectors, the singular values and the adjoint of the right singular vectors. 
""" - result = svd_p.bind(x, full_matrices=full_matrices, compute_uv=compute_uv) + result = svd_p.bind( + x, + full_matrices=full_matrices, + compute_uv=compute_uv, + subset_by_index=subset_by_index, + ) if compute_uv: s, u, v = result return u, s, v @@ -324,6 +360,7 @@ def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: bool = True) -> s, = result return s + @_warn_on_positional_kwargs def triangular_solve(a: ArrayLike, b: ArrayLike, *, left_side: bool = False, lower: bool = False, @@ -1043,7 +1080,6 @@ mlir.register_lowering(triangular_solve_p, _triangular_solve_cpu_lower, # Support operation for LU decomposition: Transformation of the pivots returned # by LU decomposition into permutations. - # Define this outside lu_pivots_to_permutation to ensure fori_loop cache hits def _lu_pivots_body_fn(i, permutation_and_swaps): permutation, swaps = permutation_and_swaps @@ -1138,7 +1174,6 @@ mlir.register_lowering( gpu_linalg.hip_lu_pivots_to_permutation), platform='rocm') - # LU decomposition # Computes a pivoted LU decomposition such that @@ -1745,35 +1780,50 @@ mlir.register_lowering(qr_p, mlir.lower_fun(_qr_lowering)) # Singular value decomposition +def _svd_impl(operand, *, full_matrices, compute_uv, subset_by_index=None): + return dispatch.apply_primitive( + svd_p, + operand, + full_matrices=full_matrices, + compute_uv=compute_uv, + subset_by_index=subset_by_index, + ) -def _svd_impl(operand, *, full_matrices, compute_uv): - return dispatch.apply_primitive(svd_p, operand, full_matrices=full_matrices, - compute_uv=compute_uv) -def _svd_abstract_eval(operand, *, full_matrices, compute_uv): +def _svd_abstract_eval(operand, *, full_matrices, compute_uv, subset_by_index): if isinstance(operand, ShapedArray): - if operand.ndim < 2: - raise ValueError("Argument to singular value decomposition must have ndims >= 2") - batch_dims = operand.shape[:-2] m = operand.shape[-2] n = operand.shape[-1] - s = operand.update(shape=batch_dims + (min(m, n),), - dtype=lax_internal._complex_basetype(operand.dtype)) + rank = min(m, n) + if subset_by_index is not None: + if full_matrices and subset_by_index != (0, rank): + raise ValueError("full_matrices and subset_by_index cannot both be set") + rank = min(rank, subset_by_index[1] - subset_by_index[0]) + + s = operand.update( + shape=batch_dims + (rank,), + dtype=lax_internal._complex_basetype(operand.dtype), + ) if compute_uv: - u = operand.update(shape=batch_dims + (m, m if full_matrices else min(m, n))) - vt = operand.update(shape=batch_dims + (n if full_matrices else min(m, n), n)) + u = operand.update(shape=batch_dims + (m, m if full_matrices else rank)) + vt = operand.update(shape=batch_dims + (n if full_matrices else rank, n)) return s, u, vt else: return s, else: raise NotImplementedError + @jax.default_matmul_precision("float32") -def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv): +def _svd_jvp_rule( + primals, tangents, *, full_matrices, compute_uv, subset_by_index +): A, = primals dA, = tangents - s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True) + s, U, Vt = svd_p.bind( + A, full_matrices=False, compute_uv=True, subset_by_index=subset_by_index + ) if compute_uv and full_matrices: # TODO: implement full matrices case, documented here: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf @@ -1812,6 +1862,7 @@ def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv): return (s, U, Vt), (ds, dU, _H(dV)) + def _empty_svd(a, *, full_matrices, compute_uv): batch_shape = a.shape[:-2] m, n = a.shape[-2:] @@ 
-1828,8 +1879,17 @@ def _empty_svd(a, *, full_matrices, compute_uv): u, v = v, u return s, u, v -def _svd_cpu_gpu_lowering(gesvd_impl, ctx, operand, *, full_matrices, - compute_uv, platform: str): + +def _svd_cpu_gpu_lowering( + gesvd_impl, + ctx, + operand, + *, + full_matrices, + compute_uv, + subset_by_index, + platform: str, +): operand_aval, = ctx.avals_in s_aval = ctx.avals_out[0] m, n = operand_aval.shape[-2:] @@ -1841,9 +1901,16 @@ def _svd_cpu_gpu_lowering(gesvd_impl, ctx, operand, *, full_matrices, f"implemented only for the batch dimensions: {operand_aval.shape}") batch_dims = operand_aval.shape[:-2] + if not (subset_by_index is None or subset_by_index == (0, min(m, n))): + raise NotImplementedError("subset_by_index not implemented for CPU and GPU") + if m == 0 or n == 0: return mlir.lower_fun(_empty_svd, multiple_results=True)( - ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv) + ctx, + operand, + full_matrices=full_matrices, + compute_uv=compute_uv, + ) if platform in ["cuda", "rocm"]: if not is_constant_shape(operand_aval.shape): @@ -1891,10 +1958,16 @@ def _svd_cpu_gpu_lowering(gesvd_impl, ctx, operand, *, full_matrices, return result -def _svd_tpu(a, *, full_matrices, compute_uv): + +def _svd_tpu(a, *, full_matrices, compute_uv, subset_by_index): batch_dims = a.shape[:-2] - fn = partial(lax_svd.svd, full_matrices=full_matrices, compute_uv=compute_uv) + fn = partial( + lax_svd.svd, + full_matrices=full_matrices, + compute_uv=compute_uv, + subset_by_index=subset_by_index, + ) for _ in range(len(batch_dims)): fn = api.vmap(fn) @@ -1905,28 +1978,49 @@ def _svd_tpu(a, *, full_matrices, compute_uv): s = fn(a) return [s] -def _svd_tpu_lowering_rule(ctx, operand, *, full_matrices, compute_uv): + +def _svd_tpu_lowering_rule( + ctx, operand, *, full_matrices, compute_uv, subset_by_index +): operand_aval, = ctx.avals_in m, n = operand_aval.shape[-2:] if m == 0 or n == 0: return mlir.lower_fun(_empty_svd, multiple_results=True)( - ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv) + ctx, + operand, + full_matrices=full_matrices, + compute_uv=compute_uv, + ) return mlir.lower_fun(_svd_tpu, multiple_results=True)( - ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv) + ctx, + operand, + full_matrices=full_matrices, + compute_uv=compute_uv, + subset_by_index=subset_by_index, + ) -def _svd_batching_rule(batched_args, batch_dims, *, full_matrices, compute_uv): + +def _svd_batching_rule( + batched_args, batch_dims, *, full_matrices, compute_uv, subset_by_index +): x, = batched_args bd, = batch_dims x = batching.moveaxis(x, bd, 0) - outs = svd_p.bind(x, full_matrices=full_matrices, compute_uv=compute_uv) + outs = svd_p.bind( + x, + full_matrices=full_matrices, + compute_uv=compute_uv, + subset_by_index=subset_by_index, + ) if compute_uv: return outs, (0, 0, 0) else: return outs, (0,) + svd_p = Primitive('svd') svd_p.multiple_results = True svd_p.def_impl(_svd_impl) diff --git a/jax/_src/lax/svd.py b/jax/_src/lax/svd.py index 1e3472e25..cbe229bc6 100644 --- a/jax/_src/lax/svd.py +++ b/jax/_src/lax/svd.py @@ -37,21 +37,29 @@ from __future__ import annotations from collections.abc import Sequence import functools +import operator from typing import Any import jax -import jax.numpy as jnp from jax import lax from jax._src import core +import jax.numpy as jnp -@functools.partial(jax.jit, static_argnums=(2, 3)) +@functools.partial(jax.jit, static_argnums=(2, 3, 4)) def _constant_svd( - a: Any, return_nan: bool, full_matrices: bool, compute_uv: bool = 
True
+    a: Any,
+    return_nan: bool,
+    full_matrices: bool,
+    compute_uv: bool = True,
+    subset_by_index: tuple[int, int] | None = None,
 ) -> Any | Sequence[Any]:
   """SVD on matrix of all zeros."""
   m, n = a.shape
   k = min(m, n)
+  if subset_by_index is not None:
+    k = min(k, subset_by_index[1] - subset_by_index[0])
+
   s = jnp.where(
       return_nan,
       jnp.full(shape=(k,), fill_value=jnp.nan, dtype=a.real.dtype),
@@ -90,9 +98,13 @@ def _constant_svd(
   return s


-@functools.partial(jax.jit, static_argnums=(1, 2, 3))
+@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
 def _svd_tall_and_square_input(
-    a: Any, hermitian: bool, compute_uv: bool, max_iterations: int
+    a: Any,
+    hermitian: bool,
+    compute_uv: bool,
+    max_iterations: int,
+    subset_by_index: tuple[int, int] | None = None,
 ) -> Any | Sequence[Any]:
   """Singular value decomposition for m x n matrix and m >= n.

@@ -113,7 +125,7 @@ def _svd_tall_and_square_input(
                       max_iterations=max_iterations)

   # TODO: Uses `eigvals_only=True` if `compute_uv=False`.
-  v, s = lax.linalg.eigh(h)
+  v, s = lax.linalg.eigh(h, subset_by_index=subset_by_index)
   # Singular values are non-negative by definition. But eigh could return small
   # negative values, so we clamp them to zero.
   s = jnp.maximum(s, 0.0)
@@ -148,12 +160,15 @@ def _svd_tall_and_square_input(
   return (u_out, s_out, v_out)


-@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
-def _qdwh_svd(a: Any,
-              full_matrices: bool,
-              compute_uv: bool = True,
-              hermitian: bool = False,
-              max_iterations: int = 10) -> Any | Sequence[Any]:
+@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4, 5))
+def _qdwh_svd(
+    a: Any,
+    full_matrices: bool,
+    compute_uv: bool = True,
+    hermitian: bool = False,
+    max_iterations: int = 10,
+    subset_by_index: tuple[int, int] | None = None,
+) -> Any | Sequence[Any]:
   """Singular value decomposition.

   Args:
@@ -196,12 +211,14 @@ def _qdwh_svd(a: Any,

   if not compute_uv:
     with jax.default_matmul_precision('float32'):
-      return _svd_tall_and_square_input(a, hermitian, compute_uv,
-                                        max_iterations)
+      return _svd_tall_and_square_input(
+          a, hermitian, compute_uv, max_iterations, subset_by_index
+      )

   with jax.default_matmul_precision('float32'):
     u_out, s_out, v_out = _svd_tall_and_square_input(
-        a, hermitian, compute_uv, max_iterations)
+        a, hermitian, compute_uv, max_iterations, subset_by_index
+    )
     if reduce_to_square:
       u_out = q @ u_out
@@ -214,12 +231,15 @@ def _qdwh_svd(a: Any,
   return (u_out, s_out, v_out.T.conj())


-@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
-def svd(a: Any,
-        full_matrices: bool,
-        compute_uv: bool = True,
-        hermitian: bool = False,
-        max_iterations: int = 10) -> Any | Sequence[Any]:
+@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4, 5))
+def svd(
+    a: Any,
+    full_matrices: bool,
+    compute_uv: bool = True,
+    hermitian: bool = False,
+    max_iterations: int = 10,
+    subset_by_index: tuple[int, int] | None = None,
+) -> Any | Sequence[Any]:
   """Singular value decomposition.

   Args:
@@ -230,6 +250,10 @@ def svd(a: Any,
     compute_uv: Whether to compute also `u` and `v` in addition to `s`.
     hermitian: True if `a` is Hermitian.
     max_iterations: The predefined maximum number of iterations of QDWH.
+    subset_by_index: Optional 2-tuple [start, end] indicating the range of
+      indices of singular components to compute. For example, if
+      ``subset_by_index`` = (0, 2), then ``svd`` computes the two largest
+      singular values (and their singular vectors if `compute_uv` is true).
Returns:
     A 3-tuple (`u`, `s`, `vh`), where `u` and `vh` are unitary matrices,
@@ -247,12 +271,43 @@ def svd(a: Any,
       'specified to use `svd` within JAX transformations.')

   hermitian = core.concrete_or_error(
-      bool, hermitian, 'The `hermitian` argument must be statically '
-      'specified to use `qdwh` within JAX transformations.')
+      bool,
+      hermitian,
+      'The `hermitian` argument must be statically '
+      'specified to use `svd` within JAX transformations.',
+  )

   max_iterations = core.concrete_or_error(
-      int, max_iterations, 'The `max_iterations` argument must be statically '
-      'specified to use `qdwh` within JAX transformations.')
+      int,
+      max_iterations,
+      'The `max_iterations` argument must be statically '
+      'specified to use `svd` within JAX transformations.',
+  )
+
+  if subset_by_index is not None:
+    if len(subset_by_index) != 2:
+      raise ValueError('subset_by_index must be a tuple of size 2.')
+    # Make sure subset_by_index is a concrete tuple.
+    subset_by_index = (
+        operator.index(subset_by_index[0]),
+        operator.index(subset_by_index[1]),
+    )
+    if subset_by_index[0] >= subset_by_index[1]:
+      raise ValueError('Got empty index range in subset_by_index.')
+    if subset_by_index[0] < 0:
+      raise ValueError('Indices in subset_by_index must be non-negative.')
+    m, n = a.shape
+    rank = n if n < m else m
+    if subset_by_index[1] > rank:
+      raise ValueError('Index in subset_by_index[1] exceeds matrix size.')
+    if full_matrices and subset_by_index != (0, rank):
+      raise ValueError(
+          'full_matrices and subset_by_index cannot both be set.'
+      )
+    # By convention, eigenvalues are numbered in non-decreasing order, while
+    # singular values are numbered in non-increasing order, so change
+    # subset_by_index accordingly.
+    subset_by_index = (rank - subset_by_index[1], rank - subset_by_index[0])

   # QDWH algorithm fails at zero-matrix `A` and produces all NaNs, which can
   # be seen from a dynamically weighted Halley (DWH) iteration:
@@ -268,6 +323,7 @@ def svd(a: Any,
           return_nan=non_finite,
           full_matrices=full_matrices,
           compute_uv=compute_uv,
+          subset_by_index=subset_by_index,
       ),
       functools.partial(
           _qdwh_svd,
@@ -275,6 +331,7 @@ def svd(a: Any,
           compute_uv=compute_uv,
           hermitian=hermitian,
           max_iterations=max_iterations,
+          subset_by_index=subset_by_index,
       ),
       operand=(a),
   )
diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py
index 1543ac0c0..2ce86faa1 100644
--- a/jax/_src/numpy/linalg.py
+++ b/jax/_src/numpy/linalg.py
@@ -72,30 +72,85 @@ def cholesky(a: ArrayLike, *, upper: bool = False) -> Array:
   L = lax_linalg.cholesky(a)
   return L.mT.conj() if upper else L

+
 @overload
-def svd(a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[True],
-        hermitian: bool = False) -> SVDResult: ...
+def svd(
+    a: ArrayLike,
+    full_matrices: bool = True,
+    *,
+    compute_uv: Literal[True],
+    hermitian: bool = False,
+    subset_by_index: tuple[int, int] | None = None,
+) -> SVDResult:
+  ...
+
+
 @overload
-def svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[True],
-        hermitian: bool = False) -> SVDResult: ...
+def svd(
+    a: ArrayLike,
+    full_matrices: bool,
+    compute_uv: Literal[True],
+    hermitian: bool = False,
+    subset_by_index: tuple[int, int] | None = None,
+) -> SVDResult:
+  ...
+
+
 @overload
-def svd(a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[False],
-        hermitian: bool = False) -> Array: ...
+def svd(
+    a: ArrayLike,
+    full_matrices: bool = True,
+    *,
+    compute_uv: Literal[False],
+    hermitian: bool = False,
+    subset_by_index: tuple[int, int] | None = None,
+) -> Array:
+  ...
+
+
 @overload
-def svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[False],
-        hermitian: bool = False) -> Array: ...
+def svd(
+    a: ArrayLike,
+    full_matrices: bool,
+    compute_uv: Literal[False],
+    hermitian: bool = False,
+    subset_by_index: tuple[int, int] | None = None,
+) -> Array:
+  ...
+
+
 @overload
-def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
-        hermitian: bool = False) -> Array | SVDResult: ...
+def svd(
+    a: ArrayLike,
+    full_matrices: bool = True,
+    compute_uv: bool = True,
+    hermitian: bool = False,
+    subset_by_index: tuple[int, int] | None = None,
+) -> Array | SVDResult:
+  ...
+
 @implements(np.linalg.svd)
-@partial(jit, static_argnames=('full_matrices', 'compute_uv', 'hermitian'))
-def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
-        hermitian: bool = False) -> Array | SVDResult:
+@partial(
+    jit,
+    static_argnames=(
+        "full_matrices",
+        "compute_uv",
+        "hermitian",
+        "subset_by_index",
+    ),
+)
+def svd(
+    a: ArrayLike,
+    full_matrices: bool = True,
+    compute_uv: bool = True,
+    hermitian: bool = False,
+    subset_by_index: tuple[int, int] | None = None,
+) -> Array | SVDResult:
   check_arraylike("jnp.linalg.svd", a)
   a, = promote_dtypes_inexact(jnp.asarray(a))
   if hermitian:
-    w, v = lax_linalg.eigh(a)
+    w, v = lax_linalg.eigh(a, subset_by_index=subset_by_index)
     s = lax.abs(v)
     if compute_uv:
       sign = lax.sign(v)
@@ -111,10 +166,20 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
       return lax.rev(lax.sort(s, dimension=-1), dimensions=[s.ndim-1])

   if compute_uv:
-    u, s, vh = lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=True)
+    u, s, vh = lax_linalg.svd(
+        a,
+        full_matrices=full_matrices,
+        compute_uv=True,
+        subset_by_index=subset_by_index,
+    )
     return SVDResult(u, s, vh)
   else:
-    return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=False)
+    return lax_linalg.svd(
+        a,
+        full_matrices=full_matrices,
+        compute_uv=False,
+        subset_by_index=subset_by_index,
+    )

 @implements(np.linalg.matrix_power)
diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index e57e60583..4d8973a8a 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -3249,7 +3249,18 @@ def _qr(operand, full_matrices):
 tf_impl[lax.linalg.qr_p] = _qr

-def _svd(operand, full_matrices, compute_uv):
+def _svd(
+    operand: TfVal,
+    full_matrices: bool,
+    compute_uv: bool,
+    subset_by_index: tuple[int, int] | None = None,
+):
+  if not (
+      subset_by_index is None
+      or subset_by_index == (0, min(operand.shape[-1], operand.shape[-2]))
+  ):
+    raise NotImplementedError("subset_by_index is not implemented")
+
   result = tf.linalg.svd(operand, full_matrices, compute_uv)
   if not compute_uv:
     return result,
diff --git a/tests/svd_test.py b/tests/svd_test.py
index a963861a4..7284ff60d 100644
--- a/tests/svd_test.py
+++ b/tests/svd_test.py
@@ -252,6 +252,42 @@ class SvdTest(jtu.JaxTestCase):
     self.assertAllClose(expected_s, jnp.real(actual_s), rtol=_SVD_RTOL,
                         atol=1E-6)

+  @jtu.sample_product(
+      start=[0, 1, 64, 126, 127],
+      end=[1, 2, 65, 127, 128],
+  )
+  @jtu.run_on_devices('tpu')  # TODO(rmlarsen): enable on other devices
+  def testSvdSubsetByIndex(self, start, end):
+    if start >= end:
+      return
+    dtype = np.float32
+    m = 256
+    n = 128
+    rng = jtu.rand_default(self.rng())
+    tol = np.maximum(n, 80) * np.finfo(dtype).eps
+    args_maker = lambda: [rng((m, n), dtype)]
+    subset_by_index = (start, end)
+    k = end - start
+    (a,) = args_maker()
+
+    u, s, vt = jnp.linalg.svd(
+        a,
full_matrices=False, subset_by_index=subset_by_index
+    )
+    self.assertEqual(u.shape, (m, k))
+    self.assertEqual(s.shape, (k,))
+    self.assertEqual(vt.shape, (k, n))
+
+    with jax.numpy_rank_promotion('allow'):
+      self.assertLessEqual(
+          np.linalg.norm(np.matmul(a, vt.T) - u * s), tol * np.linalg.norm(a)
+      )
+
+    # Test that we get approximately the same singular values when
+    # slicing the full SVD.
+    _, full_s, _ = jnp.linalg.svd(a, full_matrices=False)
+    s_slice = full_s[start:end]
+    self.assertAllClose(s_slice, s, atol=tol, rtol=tol)
+

 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
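
Reviewer note: a minimal usage sketch of the new `subset_by_index` parameter, assuming this patch is applied and the program runs on a TPU backend (the CPU/GPU lowerings above raise NotImplementedError for any proper subset, and jax2tf likewise rejects it). The shapes mirror testSvdSubsetByIndex; the tolerances are illustrative and not taken from the patch.

import jax.numpy as jnp
import numpy as np

# A 256 x 128 float32 test matrix.
a = jnp.asarray(
    np.random.default_rng(0).standard_normal((256, 128)).astype(np.float32))

# Thin SVD restricted to the two largest singular triplets: with
# full_matrices=False and subset_by_index=(0, 2), u has shape (256, 2),
# s has shape (2,), and vt has shape (2, 128).
u, s, vt = jnp.linalg.svd(a, full_matrices=False, subset_by_index=(0, 2))
assert u.shape == (256, 2) and s.shape == (2,) and vt.shape == (2, 128)

# The restricted values match the leading slice of the full spectrum.
_, full_s, _ = jnp.linalg.svd(a, full_matrices=False)
np.testing.assert_allclose(s, full_s[0:2], rtol=1e-4, atol=1e-4)

Design note: `subset_by_index` uses the singular-value convention (non-increasing order), so `jax._src.lax.svd.svd` remaps (start, end) to (rank - end, rank - start) before calling `eigh`, which numbers eigenvalues in non-decreasing order; that remapping is why (0, 2) above selects the two largest singular values.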