From cb182b8b225a15ab11911acb32392fb12d8255b6 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Mon, 13 Nov 2023 12:03:36 -0800
Subject: [PATCH] Use a Jacobi SVD solver for unbatched SVDs up to 1024x1024
 on NVIDIA GPUs.

The unbatched Jacobi solver is faster than the default gesvd solver for
small-to-moderate matrices, and unlike NVIDIA's batched Jacobi kernel it has
no size restrictions.
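As an illustration (a minimal sketch, not part of the patch): an unbatched
float32 SVD within the new limit now takes the Jacobi (gesvdj) path on
NVIDIA GPUs:

  import jax
  import jax.numpy as jnp
  import numpy as np

  # 512x512 is above the old 32x32 Jacobi cutoff but within the new
  # 1024x1024 unbatched limit, so per the heuristic below it now uses
  # the Jacobi kernel rather than gesvd.
  x = np.random.randn(512, 512).astype(np.float32)
  u, s, vh = jax.block_until_ready(jnp.linalg.svd(x))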
Timings on T4 GPU:

Before:
------------------------------------------------------------
Benchmark                  Time             CPU   Iterations
------------------------------------------------------------
svd/m:1/n:1             263587 ns      242274 ns     2780
svd/m:2/n:1             335561 ns      298238 ns     2303
svd/m:5/n:1             337784 ns      299841 ns     2304
svd/m:10/n:1            339184 ns      300703 ns     2311
svd/m:100/n:1           359826 ns      320088 ns     2159
svd/m:500/n:1           376124 ns      338660 ns     2076
svd/m:800/n:1           375779 ns      335590 ns     2060
svd/m:1000/n:1          419171 ns      341487 ns     2072
svd/m:1/n:2             307564 ns      270663 ns     2544
svd/m:2/n:2             320928 ns      283601 ns     2487
svd/m:5/n:2             377373 ns      344228 ns     2035
svd/m:10/n:2            380557 ns      349412 ns     1953
svd/m:100/n:2           435465 ns      403496 ns     1722
svd/m:500/n:2           444610 ns      410913 ns     1680
svd/m:800/n:2           454493 ns      416495 ns     1665
svd/m:1000/n:2          492110 ns      420539 ns     1665
svd/m:1/n:5             307316 ns      275833 ns     2531
svd/m:2/n:5             374318 ns      341432 ns     2086
svd/m:5/n:5             512928 ns      470293 ns     1361
svd/m:10/n:5            589330 ns      537070 ns     1353
svd/m:100/n:5           620164 ns      580166 ns     1193
svd/m:500/n:5           636424 ns      593692 ns     1180
svd/m:800/n:5           635545 ns      595016 ns     1181
svd/m:1000/n:5          672443 ns      597387 ns     1115
svd/m:1/n:10            310013 ns      273998 ns     2520
svd/m:2/n:10            370451 ns      334489 ns     2105
svd/m:5/n:10            560037 ns      522223 ns     1274
svd/m:10/n:10           572868 ns      535388 ns     1304
svd/m:100/n:10          959802 ns      918258 ns      765
svd/m:500/n:10          955958 ns      909778 ns      758
svd/m:800/n:10          924104 ns      879512 ns      777
svd/m:1000/n:10         950140 ns      883493 ns      775
svd/m:1/n:100           351237 ns      315554 ns     2198
svd/m:2/n:100           426883 ns      390089 ns     1792
svd/m:5/n:100           601557 ns      564493 ns     1255
svd/m:10/n:100          920819 ns      880011 ns      787
svd/m:100/n:100        7902281 ns     7229220 ns       95
svd/m:500/n:100        9720727 ns     9040679 ns       79
svd/m:800/n:100        9856378 ns     8998050 ns       79
svd/m:1000/n:100       9721017 ns     9086414 ns       79
svd/m:1/n:500           371171 ns      334217 ns     2117
svd/m:2/n:500           449165 ns      411499 ns     1700
svd/m:5/n:500           620354 ns      581866 ns     1185
svd/m:10/n:500          892375 ns      847239 ns      833
svd/m:100/n:500        9564810 ns     8867540 ns       79
svd/m:500/n:500      111924035 ns   104078023 ns        7
svd/m:800/n:500      147777319 ns   142730412 ns        5
svd/m:1000/n:500     154205084 ns   149740209 ns        5
svd/m:1/n:800           372122 ns      334212 ns     2119
svd/m:2/n:800           456672 ns      419260 ns     1680
svd/m:5/n:800           691208 ns      626003 ns     1190
svd/m:10/n:800         1017694 ns      941480 ns      730
svd/m:100/n:800        9892683 ns     9091043 ns       76
svd/m:500/n:800      144134235 ns   139129722 ns        5
svd/m:800/n:800      342790246 ns   333299774 ns        2
svd/m:1000/n:800     432820082 ns   427978978 ns        2
svd/m:1/n:1000          372785 ns      335745 ns     1805
svd/m:2/n:1000          451946 ns      413341 ns     1668
svd/m:5/n:1000          618475 ns      577213 ns     1169
svd/m:10/n:1000         907729 ns      863335 ns      808
svd/m:100/n:1000       9868543 ns     9116870 ns       76
svd/m:500/n:1000     156777811 ns   152042065 ns        5
svd/m:800/n:1000     429704070 ns   424677592 ns        2
svd/m:1000/n:1000    654864311 ns   642693162 ns        1

After:
------------------------------------------------------------
Benchmark                  Time             CPU   Iterations
------------------------------------------------------------
svd/m:1/n:1             265980 ns      245433 ns     2791
svd/m:2/n:1             340203 ns      302783 ns     2288
svd/m:5/n:1             337807 ns      301916 ns     2286
svd/m:10/n:1            338064 ns      302441 ns     2297
svd/m:100/n:1           335444 ns      298440 ns     2327
svd/m:500/n:1           338025 ns      302096 ns     2272
svd/m:800/n:1           328382 ns      291740 ns     2252
svd/m:1000/n:1          397494 ns      310905 ns     2239
svd/m:1/n:2             310464 ns      274507 ns     2535
svd/m:2/n:2             319999 ns      284247 ns     2515
svd/m:5/n:2             373435 ns      335919 ns     2069
svd/m:10/n:2            376327 ns      339327 ns     2056
svd/m:100/n:2           385061 ns      349258 ns     2003
svd/m:500/n:2           392352 ns      355735 ns     1932
svd/m:800/n:2           410736 ns      370677 ns     1881
svd/m:1000/n:2          494326 ns      405603 ns     1721
svd/m:1/n:5             316735 ns      277292 ns     2538
svd/m:2/n:5             383748 ns      342218 ns     2077
svd/m:5/n:5             494204 ns      454309 ns     1476
svd/m:10/n:5            547017 ns      508184 ns     1371
svd/m:100/n:5           514537 ns      476761 ns     1460
svd/m:500/n:5           544656 ns      504877 ns     1381
svd/m:800/n:5           642590 ns      599314 ns     1159
svd/m:1000/n:5          706166 ns      621209 ns     1106
svd/m:1/n:10            310825 ns      274374 ns     2511
svd/m:2/n:10            381316 ns      344202 ns     2094
svd/m:5/n:10            565469 ns      526759 ns     1266
svd/m:10/n:10           576111 ns      537286 ns     1299
svd/m:100/n:10          653250 ns      613392 ns     1137
svd/m:500/n:10          690532 ns      645828 ns     1080
svd/m:800/n:10          763924 ns      723677 ns      959
svd/m:1000/n:10         940342 ns      855517 ns      818
svd/m:1/n:100           306134 ns      271533 ns     2526
svd/m:2/n:100           374680 ns      339298 ns     2071
svd/m:5/n:100           576926 ns      539062 ns     1228
svd/m:10/n:100          656806 ns      615171 ns     1123
svd/m:100/n:100        3295164 ns     3138621 ns      223
svd/m:500/n:100        4269347 ns     4166000 ns      168
svd/m:800/n:100        4656541 ns     4522247 ns      154
svd/m:1000/n:100       6479223 ns     6354578 ns      112
svd/m:1/n:500           329966 ns      289083 ns     2440
svd/m:2/n:500           407535 ns      366794 ns     1947
svd/m:5/n:500           567367 ns      522809 ns     1336
svd/m:10/n:500          712307 ns      657608 ns     1065
svd/m:100/n:500        4262986 ns     4169907 ns      167
svd/m:500/n:500       28824720 ns    28650258 ns       25
svd/m:800/n:500       29330139 ns    28677269 ns       25
svd/m:1000/n:500      30848037 ns    30089216 ns       23
svd/m:1/n:800           328620 ns      289181 ns     2329
svd/m:2/n:800           419052 ns      379483 ns     1876
svd/m:5/n:800           587366 ns      546979 ns     1269
svd/m:10/n:800          830762 ns      787923 ns      893
svd/m:100/n:800        4763633 ns     4595738 ns      152
svd/m:500/n:800       30447861 ns    29949714 ns       24
svd/m:800/n:800       94188958 ns    93488372 ns        8
svd/m:1000/n:800      94701529 ns    93394677 ns        7
svd/m:1/n:1000          351102 ns      313099 ns     2218
svd/m:2/n:1000          446543 ns      407807 ns     1708
svd/m:5/n:1000          661152 ns      616174 ns     1129
svd/m:10/n:1000         915743 ns      873397 ns      802
svd/m:100/n:1000       6434730 ns     6282779 ns      113
svd/m:500/n:1000      30244321 ns    29684290 ns       24
svd/m:800/n:1000      92727423 ns    91477078 ns        8
svd/m:1000/n:1000    169500709 ns   168358420 ns        4

PiperOrigin-RevId: 582041508
---
 CHANGELOG.md                   |  4 ++++
 benchmarks/linalg_benchmark.py | 38 ++++++++++++++++++++++++++++++++++
 jaxlib/gpu_solver.py           |  9 +++++++-
 tests/linalg_test.py           | 34 +++++++++++++++++-------------
 4 files changed, 70 insertions(+), 15 deletions(-)
 create mode 100644 benchmarks/linalg_benchmark.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c623b212c..1d72975b5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,6 +14,10 @@ Remember to align the itemized text with the first line of an item within a list
 
 ## jaxlib 0.4.21
 
+* Changes
+  * On NVIDIA GPU, JAX now prefers a Jacobi SVD solver for matrices up to
+    1024x1024. The Jacobi solver appears faster than the non-Jacobi version.
+
 ## jax 0.4.20 (Nov 2, 2023)
 
 ## jaxlib 0.4.20 (Nov 2, 2023)
diff --git a/benchmarks/linalg_benchmark.py b/benchmarks/linalg_benchmark.py
new file mode 100644
index 000000000..11e59c8eb
--- /dev/null
+++ b/benchmarks/linalg_benchmark.py
@@ -0,0 +1,38 @@
+# Copyright 2020 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Benchmarks for JAX linear algebra functions."""
+
+import google_benchmark
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+
+@google_benchmark.register
+@google_benchmark.option.arg_names(['m', 'n'])
+@google_benchmark.option.args_product(
+    [[1, 2, 5, 10, 100, 500, 800, 1000], [1, 2, 5, 10, 100, 500, 800, 1000]]
+)
+def svd(state):
+  np.random.seed(1234)
+  m, n = state.range(0), state.range(1)
+  x = np.random.randn(m, n).astype(np.float32)
+  jax.block_until_ready(jnp.linalg.svd(x)[0])
+  while state:
+    jax.block_until_ready(jnp.linalg.svd(x)[0])
+
+
+if __name__ == '__main__':
+  google_benchmark.main()
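For reference, the new benchmark can be run on its own, assuming the
google_benchmark Python package is installed (google_benchmark.main() should
also accept the standard benchmark flags, e.g. --benchmark_filter=svd):

  python benchmarks/linalg_benchmark.py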
diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py
index 4f850e6f4..f2dade564 100644
--- a/jaxlib/gpu_solver.py
+++ b/jaxlib/gpu_solver.py
@@ -369,7 +369,14 @@ def _gesvd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
   vector_layout = (num_bd,) + tuple(range(num_bd - 1, -1, -1))
   i32_type = ir.IntegerType.get_signless(32)
 
-  if have_jacobi_solver and m < 32 and n < 32:
+  # NVIDIA's batched Jacobi solver supports a maximum matrix size of 32x32, but
+  # the unbatched solver has no such limit. The unbatched solver appears to
+  # outperform gesvd for small-moderate matrices, e.g., see:
+  # https://developer.download.nvidia.com/video/gputechconf/gtc/2019/presentation/s9226-fast-singular-value-decomposition-on-gpus-v2.pdf
+  # slide 5.
+  if have_jacobi_solver and (
+      (b == 1 and m <= 1024 and n <= 1024) or (m <= 32 and n <= 32)
+  ):
     # The batched kernel doesn't support "econ" mode.
     econ = not full_matrices and b == 1
     lwork, opaque = gpu_solver.build_gesvdj_descriptor(
diff --git a/tests/linalg_test.py b/tests/linalg_test.py
index 9886b3fa2..921857738 100644
--- a/tests/linalg_test.py
+++ b/tests/linalg_test.py
@@ -590,20 +590,26 @@ class NumpyLinalgTest(jtu.JaxTestCase):
       jnp.linalg.norm(jnp.array([1.0, 2.0, 3.0]), ord="inf")
 
   @jtu.sample_product(
-    [dict(m=m, n=n, full_matrices=full_matrices, hermitian=hermitian)
-     for (m, n), full_matrices in (
-         list(itertools.product(itertools.product([0, 2, 7, 29, 53], repeat=2),
-                                [False, True])) +
-         # Test cases that ensure we are economical when computing the SVD and
-         # its gradient. If we form a 400kx400k matrix explicitly we will OOM.
-         [((400000, 2), False),
-          ((2, 400000), False)]
-     )
-     for hermitian in ([False, True] if m == n else [False])],
-    b=[(), (3,), (2, 3)],
-    dtype=float_types + complex_types,
-    compute_uv=[False, True],
+      [
+          dict(m=m, n=n, full_matrices=full_matrices, hermitian=hermitian)
+          for (m, n), full_matrices in (
+              list(
+                  itertools.product(
+                      itertools.product([0, 2, 7, 29, 32, 53], repeat=2),
+                      [False, True],
+                  )
+              )
+              +
+              # Test cases that ensure we are economical when computing the SVD
+              # and its gradient. If we form a 400kx400k matrix explicitly we
+              # will OOM.
+              [((400000, 2), False), ((2, 400000), False)]
+          )
+          for hermitian in ([False, True] if m == n else [False])
+      ],
+      b=[(), (3,), (2, 3)],
+      dtype=float_types + complex_types,
+      compute_uv=[False, True],
   )
   @jax.default_matmul_precision("float32")
   def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, hermitian):
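For readers tracing the dispatch change in _gesvd_hlo above, the new condition
is equivalent to the following standalone predicate (an illustrative
restatement only; the name use_jacobi_solver is hypothetical, not part of the
patch):

  def use_jacobi_solver(have_jacobi_solver: bool, b: int, m: int, n: int) -> bool:
    # Mirrors the gesvdj-vs-gesvd choice made in the patched _gesvd_hlo.
    if not have_jacobi_solver:
      return False
    # Unbatched case: the unbatched Jacobi kernel has no size restriction,
    # and is preferred up to 1024x1024.
    if b == 1 and m <= 1024 and n <= 1024:
      return True
    # Batched case: NVIDIA's batched Jacobi kernel supports at most 32x32.
    return m <= 32 and n <= 32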