rocm_jax/benchmarks/linalg_benchmark.py
Peter Hawkins cb182b8b22 Use a Jacobi SVD solver for unbatched SVDs up to 1024x1024 on NVIDIA GPUs.
The unbatched Jacobi solver is faster for small-to-moderate matrices, and the unbatched kernel doesn't have size restrictions.

Timings on T4 GPU:

Before:

------------------------------------------------------------
Benchmark                  Time             CPU   Iterations
------------------------------------------------------------
svd/m:1/n:1           263587 ns       242274 ns         2780
svd/m:2/n:1           335561 ns       298238 ns         2303
svd/m:5/n:1           337784 ns       299841 ns         2304
svd/m:10/n:1          339184 ns       300703 ns         2311
svd/m:100/n:1         359826 ns       320088 ns         2159
svd/m:500/n:1         376124 ns       338660 ns         2076
svd/m:800/n:1         375779 ns       335590 ns         2060
svd/m:1000/n:1        419171 ns       341487 ns         2072
svd/m:1/n:2           307564 ns       270663 ns         2544
svd/m:2/n:2           320928 ns       283601 ns         2487
svd/m:5/n:2           377373 ns       344228 ns         2035
svd/m:10/n:2          380557 ns       349412 ns         1953
svd/m:100/n:2         435465 ns       403496 ns         1722
svd/m:500/n:2         444610 ns       410913 ns         1680
svd/m:800/n:2         454493 ns       416495 ns         1665
svd/m:1000/n:2        492110 ns       420539 ns         1665
svd/m:1/n:5           307316 ns       275833 ns         2531
svd/m:2/n:5           374318 ns       341432 ns         2086
svd/m:5/n:5           512928 ns       470293 ns         1361
svd/m:10/n:5          589330 ns       537070 ns         1353
svd/m:100/n:5         620164 ns       580166 ns         1193
svd/m:500/n:5         636424 ns       593692 ns         1180
svd/m:800/n:5         635545 ns       595016 ns         1181
svd/m:1000/n:5        672443 ns       597387 ns         1115
svd/m:1/n:10          310013 ns       273998 ns         2520
svd/m:2/n:10          370451 ns       334489 ns         2105
svd/m:5/n:10          560037 ns       522223 ns         1274
svd/m:10/n:10         572868 ns       535388 ns         1304
svd/m:100/n:10        959802 ns       918258 ns          765
svd/m:500/n:10        955958 ns       909778 ns          758
svd/m:800/n:10        924104 ns       879512 ns          777
svd/m:1000/n:10       950140 ns       883493 ns          775
svd/m:1/n:100         351237 ns       315554 ns         2198
svd/m:2/n:100         426883 ns       390089 ns         1792
svd/m:5/n:100         601557 ns       564493 ns         1255
svd/m:10/n:100        920819 ns       880011 ns          787
svd/m:100/n:100      7902281 ns      7229220 ns           95
svd/m:500/n:100      9720727 ns      9040679 ns           79
svd/m:800/n:100      9856378 ns      8998050 ns           79
svd/m:1000/n:100     9721017 ns      9086414 ns           79
svd/m:1/n:500         371171 ns       334217 ns         2117
svd/m:2/n:500         449165 ns       411499 ns         1700
svd/m:5/n:500         620354 ns       581866 ns         1185
svd/m:10/n:500        892375 ns       847239 ns          833
svd/m:100/n:500      9564810 ns      8867540 ns           79
svd/m:500/n:500    111924035 ns    104078023 ns            7
svd/m:800/n:500    147777319 ns    142730412 ns            5
svd/m:1000/n:500   154205084 ns    149740209 ns            5
svd/m:1/n:800         372122 ns       334212 ns         2119
svd/m:2/n:800         456672 ns       419260 ns         1680
svd/m:5/n:800         691208 ns       626003 ns         1190
svd/m:10/n:800       1017694 ns       941480 ns          730
svd/m:100/n:800      9892683 ns      9091043 ns           76
svd/m:500/n:800    144134235 ns    139129722 ns            5
svd/m:800/n:800    342790246 ns    333299774 ns            2
svd/m:1000/n:800   432820082 ns    427978978 ns            2
svd/m:1/n:1000        372785 ns       335745 ns         1805
svd/m:2/n:1000        451946 ns       413341 ns         1668
svd/m:5/n:1000        618475 ns       577213 ns         1169
svd/m:10/n:1000       907729 ns       863335 ns          808
svd/m:100/n:1000     9868543 ns      9116870 ns           76
svd/m:500/n:1000   156777811 ns    152042065 ns            5
svd/m:800/n:1000   429704070 ns    424677592 ns            2
svd/m:1000/n:1000  654864311 ns    642693162 ns            1

After:
------------------------------------------------------------
Benchmark                  Time             CPU   Iterations
------------------------------------------------------------
svd/m:1/n:1           265980 ns       245433 ns         2791
svd/m:2/n:1           340203 ns       302783 ns         2288
svd/m:5/n:1           337807 ns       301916 ns         2286
svd/m:10/n:1          338064 ns       302441 ns         2297
svd/m:100/n:1         335444 ns       298440 ns         2327
svd/m:500/n:1         338025 ns       302096 ns         2272
svd/m:800/n:1         328382 ns       291740 ns         2252
svd/m:1000/n:1        397494 ns       310905 ns         2239
svd/m:1/n:2           310464 ns       274507 ns         2535
svd/m:2/n:2           319999 ns       284247 ns         2515
svd/m:5/n:2           373435 ns       335919 ns         2069
svd/m:10/n:2          376327 ns       339327 ns         2056
svd/m:100/n:2         385061 ns       349258 ns         2003
svd/m:500/n:2         392352 ns       355735 ns         1932
svd/m:800/n:2         410736 ns       370677 ns         1881
svd/m:1000/n:2        494326 ns       405603 ns         1721
svd/m:1/n:5           316735 ns       277292 ns         2538
svd/m:2/n:5           383748 ns       342218 ns         2077
svd/m:5/n:5           494204 ns       454309 ns         1476
svd/m:10/n:5          547017 ns       508184 ns         1371
svd/m:100/n:5         514537 ns       476761 ns         1460
svd/m:500/n:5         544656 ns       504877 ns         1381
svd/m:800/n:5         642590 ns       599314 ns         1159
svd/m:1000/n:5        706166 ns       621209 ns         1106
svd/m:1/n:10          310825 ns       274374 ns         2511
svd/m:2/n:10          381316 ns       344202 ns         2094
svd/m:5/n:10          565469 ns       526759 ns         1266
svd/m:10/n:10         576111 ns       537286 ns         1299
svd/m:100/n:10        653250 ns       613392 ns         1137
svd/m:500/n:10        690532 ns       645828 ns         1080
svd/m:800/n:10        763924 ns       723677 ns          959
svd/m:1000/n:10       940342 ns       855517 ns          818
svd/m:1/n:100         306134 ns       271533 ns         2526
svd/m:2/n:100         374680 ns       339298 ns         2071
svd/m:5/n:100         576926 ns       539062 ns         1228
svd/m:10/n:100        656806 ns       615171 ns         1123
svd/m:100/n:100      3295164 ns      3138621 ns          223
svd/m:500/n:100      4269347 ns      4166000 ns          168
svd/m:800/n:100      4656541 ns      4522247 ns          154
svd/m:1000/n:100     6479223 ns      6354578 ns          112
svd/m:1/n:500         329966 ns       289083 ns         2440
svd/m:2/n:500         407535 ns       366794 ns         1947
svd/m:5/n:500         567367 ns       522809 ns         1336
svd/m:10/n:500        712307 ns       657608 ns         1065
svd/m:100/n:500      4262986 ns      4169907 ns          167
svd/m:500/n:500     28824720 ns     28650258 ns           25
svd/m:800/n:500     29330139 ns     28677269 ns           25
svd/m:1000/n:500    30848037 ns     30089216 ns           23
svd/m:1/n:800         328620 ns       289181 ns         2329
svd/m:2/n:800         419052 ns       379483 ns         1876
svd/m:5/n:800         587366 ns       546979 ns         1269
svd/m:10/n:800        830762 ns       787923 ns          893
svd/m:100/n:800      4763633 ns      4595738 ns          152
svd/m:500/n:800     30447861 ns     29949714 ns           24
svd/m:800/n:800     94188958 ns     93488372 ns            8
svd/m:1000/n:800    94701529 ns     93394677 ns            7
svd/m:1/n:1000        351102 ns       313099 ns         2218
svd/m:2/n:1000        446543 ns       407807 ns         1708
svd/m:5/n:1000        661152 ns       616174 ns         1129
svd/m:10/n:1000       915743 ns       873397 ns          802
svd/m:100/n:1000     6434730 ns      6282779 ns          113
svd/m:500/n:1000    30244321 ns     29684290 ns           24
svd/m:800/n:1000    92727423 ns     91477078 ns            8
svd/m:1000/n:1000  169500709 ns    168358420 ns            4
PiperOrigin-RevId: 582041508
2023-11-13 12:04:13 -08:00

39 lines
1.2 KiB
Python

# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Benchmarks for JAX linear algebra functions."""
import google_benchmark
import jax
import jax.numpy as jnp
import numpy as np
@google_benchmark.register
@google_benchmark.option.arg_names(['m', 'n'])
@google_benchmark.option.args_product(
    [[1, 2, 5, 10, 100, 500, 800, 1000], [1, 2, 5, 10, 100, 500, 800, 1000]]
)
def svd(state):
    """Benchmark `jnp.linalg.svd` on a random float32 matrix.

    The matrix shape is taken from the benchmark arguments: `m` rows
    (`state.range(0)`) by `n` columns (`state.range(1)`), swept over the
    product of the two size lists given in the decorator above.

    Args:
      state: the google_benchmark state object; it supplies the `m`/`n`
        arguments and controls how long the measurement loop runs.
    """
    # Fixed seed so every run of a given (m, n) case sees the same input.
    np.random.seed(1234)
    num_rows = state.range(0)
    num_cols = state.range(1)
    matrix = np.random.randn(num_rows, num_cols).astype(np.float32)

    # One call before the measurement loop; presumably this keeps one-time
    # costs (e.g. compilation) out of the timed iterations.
    jax.block_until_ready(jnp.linalg.svd(matrix)[0])

    # Measurement loop: runs for as long as `state` is truthy. Blocking on the
    # first element of the SVD result makes each iteration wait for the
    # computation to finish rather than only for it to be dispatched.
    while state:
        jax.block_until_ready(jnp.linalg.svd(matrix)[0])
if __name__ == '__main__':
    # Script entry point: hand control to google_benchmark's main().
    google_benchmark.main()