diff --git a/CHANGELOG.md b/CHANGELOG.md
index c623b212c..1d72975b5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,6 +14,10 @@ Remember to align the itemized text with the first line of an item within a list
 
 ## jaxlib 0.4.21
 
+* Changes
+  * On NVIDIA GPU, JAX now prefers a Jacobi SVD solver for matrices up to
+    1024x1024. The Jacobi solver appears faster than the non-Jacobi version.
+
 ## jax 0.4.20 (Nov 2, 2023)
 
 ## jaxlib 0.4.20 (Nov 2, 2023)
diff --git a/benchmarks/linalg_benchmark.py b/benchmarks/linalg_benchmark.py
new file mode 100644
index 000000000..11e59c8eb
--- /dev/null
+++ b/benchmarks/linalg_benchmark.py
@@ -0,0 +1,38 @@
+# Copyright 2023 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Benchmarks for JAX linear algebra functions."""
+
+import google_benchmark
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+
+@google_benchmark.register
+@google_benchmark.option.arg_names(['m', 'n'])
+@google_benchmark.option.args_product(
+    [[1, 2, 5, 10, 100, 500, 800, 1000], [1, 2, 5, 10, 100, 500, 800, 1000]]
+)
+def svd(state):
+  np.random.seed(1234)
+  m, n = state.range(0), state.range(1)
+  x = np.random.randn(m, n).astype(np.float32)
+  jax.block_until_ready(jnp.linalg.svd(x)[0])  # Warm up and compile.
+  while state:
+    jax.block_until_ready(jnp.linalg.svd(x)[0])
+
+
+if __name__ == '__main__':
+  google_benchmark.main()
diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py
index 4f850e6f4..f2dade564 100644
--- a/jaxlib/gpu_solver.py
+++ b/jaxlib/gpu_solver.py
@@ -369,7 +369,14 @@ def _gesvd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
   vector_layout = (num_bd,) + tuple(range(num_bd - 1, -1, -1))
   i32_type = ir.IntegerType.get_signless(32)
 
-  if have_jacobi_solver and m < 32 and n < 32:
+  # NVIDIA's batched Jacobi solver supports a maximum matrix size of 32x32, but
+  # the unbatched solver has no such limit. The unbatched solver appears to
+  # outperform gesvd for small-to-moderate matrices; see, e.g.,
+  # https://developer.download.nvidia.com/video/gputechconf/gtc/2019/presentation/s9226-fast-singular-value-decomposition-on-gpus-v2.pdf
+  # slide 5.
+  if have_jacobi_solver and (
+      (b == 1 and m <= 1024 and n <= 1024) or (m <= 32 and n <= 32)
+  ):
     # The batched kernel doesn't support "econ" mode.
     econ = not full_matrices and b == 1
     lwork, opaque = gpu_solver.build_gesvdj_descriptor(
diff --git a/tests/linalg_test.py b/tests/linalg_test.py
index 9886b3fa2..921857738 100644
--- a/tests/linalg_test.py
+++ b/tests/linalg_test.py
@@ -590,20 +590,26 @@ class NumpyLinalgTest(jtu.JaxTestCase):
       jnp.linalg.norm(jnp.array([1.0, 2.0, 3.0]), ord="inf")
 
   @jtu.sample_product(
-    [dict(m=m, n=n, full_matrices=full_matrices, hermitian=hermitian)
-     for (m, n), full_matrices in (
-         list(itertools.product(itertools.product([0, 2, 7, 29, 53], repeat=2),
-                                [False, True])) +
-         # Test cases that ensure we are economical when computing the SVD and
-         # its gradient. If we form a 400kx400k matrix explicitly we will OOM.
-         [((400000, 2), False),
-          ((2, 400000), False)]
-         )
-     for hermitian in ([False, True] if m == n else [False])
-    ],
-    b=[(), (3,), (2, 3)],
-    dtype=float_types + complex_types,
-    compute_uv=[False, True],
+      [
+          dict(m=m, n=n, full_matrices=full_matrices, hermitian=hermitian)
+          for (m, n), full_matrices in (
+              list(
+                  itertools.product(
+                      itertools.product([0, 2, 7, 29, 32, 53], repeat=2),
+                      [False, True],
+                  )
+              )
+              +
+              # Test cases that ensure we are economical when computing the SVD
+              # and its gradient. If we form a 400kx400k matrix explicitly we
+              # will OOM.
+              [((400000, 2), False), ((2, 400000), False)]
+          )
+          for hermitian in ([False, True] if m == n else [False])
+      ],
+      b=[(), (3,), (2, 3)],
+      dtype=float_types + complex_types,
+      compute_uv=[False, True],
   )
   @jax.default_matmul_precision("float32")
   def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, hermitian):
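
Note: a quick way to sanity-check the new dispatch cutoff is to time SVDs just below and above the 1024x1024 threshold. The sketch below is illustrative and not part of this change; it assumes an NVIDIA GPU backend, and the helper name time_svd, the matrix sizes, and the repeat count are arbitrary choices. With a batch size of 1, sizes up to 1024 should take the gesvdj (Jacobi) path added above, while 1025 falls back to gesvd.

import time

import jax
import jax.numpy as jnp
import numpy as np


def time_svd(m, n, repeats=5):
  # Build a random single-precision matrix and time repeated full SVDs.
  x = jnp.asarray(np.random.RandomState(0).randn(m, n).astype(np.float32))
  jax.block_until_ready(jnp.linalg.svd(x))  # Warm up and compile first.
  start = time.perf_counter()
  for _ in range(repeats):
    jax.block_until_ready(jnp.linalg.svd(x))
  return (time.perf_counter() - start) / repeats


for size in (512, 1024, 1025):  # 1025 exceeds the unbatched Jacobi cutoff.
  print(f'{size}x{size}: {time_svd(size, size):.4f}s per SVD')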