From 269067e3e892f168ccfe34ea84c9f316d3f0f885 Mon Sep 17 00:00:00 2001 From: Vlad Feinberg Date: Thu, 4 Aug 2022 20:05:18 -0700 Subject: [PATCH] Make LOBPCG test plots compatible with bazel. bazel test invocations would previously not work, because the lobpcg_test did not include the appropriate flag parsing and absl test invocations when run as a script. This change fixes that, and in addition shards tests and removes needless and redundant slow tests with larger matrix sizes to make the tests finish in a smaller amount of time. Now, generated pngs with debug information are properly reported via the undeclared outputs directory when the environment variable to emit them, LOBPCG_EMIT_DEBUG_PLOTS, is set to a non-falsy value. PiperOrigin-RevId: 465465731 --- build/test-requirements.txt | 1 + jax/experimental/sparse/linalg.py | 5 +- tests/BUILD | 7 ++ tests/lobpcg_test.py | 114 +++++++++++++++++------------- 4 files changed, 76 insertions(+), 51 deletions(-) diff --git a/build/test-requirements.txt b/build/test-requirements.txt index f4344a908..7d9cb4500 100644 --- a/build/test-requirements.txt +++ b/build/test-requirements.txt @@ -1,5 +1,6 @@ cloudpickle colorama>=0.4.4 +matplotlib pillow>=8.3.1 pytest-benchmark pytest-xdist diff --git a/jax/experimental/sparse/linalg.py b/jax/experimental/sparse/linalg.py index c8c55c655..a47d98f63 100644 --- a/jax/experimental/sparse/linalg.py +++ b/jax/experimental/sparse/linalg.py @@ -78,10 +78,9 @@ def lobpcg_standard( to the float epsilon of `A.dtype`. Returns: - `theta, U, i [, diagnostics]`, where `theta` is a `(k,)` array + `theta, U, i`, where `theta` is a `(k,)` array of eigenvalues, `U` is a `(n, k)` array of eigenvectors, `i` is the - number of iterations performed, and `diagnostics` is a dictionary with debug - information, which is only returned if `debug` is set to true. + number of iterations performed. Raises: ValueError : if `A,X` dtypes or `n` dimensions do not match, or `k` is too diff --git a/tests/BUILD b/tests/BUILD index fba60496d..683af96fb 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -147,8 +147,15 @@ jax_test( jax_test( name = "lobpcg_test", srcs = ["lobpcg_test.py"], + env = {"LOBPCG_EMIT_DEBUG_PLOTS": "1"}, + shard_count = { + "cpu": 48, + "gpu": 48, + "tpu": 48, + }, deps = [ "//jax:experimental_sparse", + "//third_party/py/matplotlib", ], ) diff --git a/tests/lobpcg_test.py b/tests/lobpcg_test.py index 1c0e99e31..eb91fb391 100644 --- a/tests/lobpcg_test.py +++ b/tests/lobpcg_test.py @@ -22,14 +22,16 @@ import functools import re import os +from absl.testing import absltest from absl.testing import parameterized import numpy as np +from matplotlib import pyplot as plt import scipy.linalg as sla import scipy.sparse as sps import jax +from jax.config import config from jax._src import test_util as jtu -from jax._src.config import flags from jax.experimental.sparse import linalg, bcoo import jax.numpy as jnp @@ -41,20 +43,25 @@ def _make_concrete_cases(f64): example_names = list(_concrete_generators(dtype)) cases = [] for name in example_names: - nkm = [(100, 10, 20)] - if not flags.FLAGS.jax_skip_slow_tests: - nkm.append((1000, 100, 200)) - for n, k, m in nkm: - if name == 'ring laplacian': - m *= 3 - if name.startswith('linear'): - m *= 2 - if f64: - m *= 2 - case = [('matrix_name', name), ('n', n), ('k', k), ('m', m)] - clean_matrix_name = _clean_matrix_name(name) - case.append(('testcase_name', f'{clean_matrix_name}_n{n}')) - cases.append(dict(case)) + n, k, m, tol = 100, 10, 20, None + if name == 'ring laplacian': + m *= 3 + if name.startswith('linear'): + m *= 2 + if f64: + m *= 2 + if name.startswith('cluster') and not f64: + tol = 2e-6 + clean_matrix_name = _clean_matrix_name(name) + case = { + 'matrix_name': name, + 'n': n, + 'k': k, + 'm': m, + 'tol': tol, + 'testcase_name': f'{clean_matrix_name}_n{n}' + } + cases.append(case) assert len({c['testcase_name'] for c in cases}) == len(cases) return cases @@ -160,21 +167,22 @@ def _make_sparse_fn(n, fill): def _callable_generators(dtype): n = 100 topk = 10 - d = {'id': _make_id_fn(n), - 'linear cond=1k': _make_diag_fn(np.linspace(1, 1000, n), 40), - 'linear cond=100k':_make_diag_fn(np.linspace(1, 100 * 1000, n), 40), - 'geom cond=1k': _make_diag_fn(np.logspace(0, 3, n), 20), - 'geom cond=100k': _make_diag_fn(np.logspace(0, 5, n), 20), - 'ring laplacian': _make_ring_fn(n, 40), - 'randn': _make_randn_fn(n, topk, 40), - 'sparse 1%': _make_sparse_fn(n, 0.01), - 'sparse 10%': _make_sparse_fn(n, 0.10), - } + d = { + 'id': _make_id_fn(n), + 'linear cond=1k': _make_diag_fn(np.linspace(1, 1000, n), 40), + 'linear cond=100k': _make_diag_fn(np.linspace(1, 100 * 1000, n), 40), + 'geom cond=1k': _make_diag_fn(np.logspace(0, 3, n), 20), + 'geom cond=100k': _make_diag_fn(np.logspace(0, 5, n), 20), + 'ring laplacian': _make_ring_fn(n, 40), + 'randn': _make_randn_fn(n, topk, 40), + 'sparse 1%': _make_sparse_fn(n, 0.01), + 'sparse 10%': _make_sparse_fn(n, 0.10), + } ret = {} for k, (vec_mul_fn, eigs, m) in d.items(): if jtu.num_float_bits(dtype) > 32: - m *= 3 + m *= 3 eigs.sort() # Note we must lift the vector multiply into matmul @@ -183,6 +191,7 @@ def _callable_generators(dtype): ret[k] = (fn, eigs[::-1][:topk].astype(dtype), n, m) return ret + @jtu.with_config( jax_enable_checks=True, jax_debug_nans=True, @@ -190,12 +199,12 @@ def _callable_generators(dtype): jax_traceback_filtering='off') class LobpcgTest(jtu.JaxTestCase): - def checkLobpcgConsistency(self, matrix_name, n, k, m, dtype): + def checkLobpcgConsistency(self, matrix_name, n, k, m, tol, dtype): A, eigs = _concrete_generators(dtype)[matrix_name](n, k) X = self.rng().standard_normal(size=(n, k)).astype(dtype) A, X = (jnp.array(i, dtype=dtype) for i in (A, X)) - theta, U, i = linalg.lobpcg_standard(A, X, m) + theta, U, i = linalg.lobpcg_standard(A, X, m, tol) self.assertDtypesMatch(theta, A) self.assertDtypesMatch(U, A) @@ -219,7 +228,7 @@ class LobpcgTest(jtu.JaxTestCase): vector_norm = np.linalg.norm(Au) adjusted_error = resid_norm / n / (t + vector_norm) / 10 - eps = float(jnp.finfo(dtype).eps) + eps = float(jnp.finfo(dtype).eps) if tol is None else tol self.assertLessEqual( adjusted_error, eps, @@ -236,7 +245,8 @@ class LobpcgTest(jtu.JaxTestCase): f' for eigenvalue {i} (actual {float(theta[i])}, ' f'expected {float(eigs[i])})') - def checkLobpcgMonotonicity(self, matrix_name, n, k, m, dtype): + def checkLobpcgMonotonicity(self, matrix_name, n, k, m, tol, dtype): + del tol A, eigs = _concrete_generators(dtype)[matrix_name](n, k) X = self.rng().standard_normal(size=(n, k)).astype(dtype) @@ -261,15 +271,17 @@ class LobpcgTest(jtu.JaxTestCase): self._possibly_plot(A, eigs, X, m, matrix_name) def _possibly_plot(self, A, eigs, X, m, matrix_name): - plot_dir = os.getenv('LOBPCG_DEBUG_PLOT_DIR') - if plot_dir is not None: - if isinstance(A, (np.ndarray, jnp.ndarray)): - lobpcg = linalg._lobpcg_standard_matrix - else: - lobpcg = linalg._lobpcg_standard_callable - _theta, _U, _i, info = lobpcg(A, X, m, tol=0, debug=True) - self._debug_plots( - X, eigs, info, matrix_name, plot_dir) + if not os.getenv('LOBPCG_EMIT_DEBUG_PLOTS'): + return + + if isinstance(A, (np.ndarray, jnp.ndarray)): + lobpcg = linalg._lobpcg_standard_matrix + else: + lobpcg = linalg._lobpcg_standard_callable + _theta, _U, _i, info = lobpcg(A, X, m, tol=0, debug=True) + plot_dir = os.getenv('TEST_UNDECLARED_OUTPUTS_DIR') + assert plot_dir, 'expected TEST_UNDECLARED_OUTPUTS_DIR for lobpcg plots' + self._debug_plots(X, eigs, info, matrix_name, plot_dir) def _debug_plots(self, X, eigs, info, matrix_name, lobpcg_debug_plot_dir): os.makedirs(lobpcg_debug_plot_dir, exist_ok=True) @@ -280,7 +292,6 @@ class LobpcgTest(jtu.JaxTestCase): lobpcg_debug_plot_dir, f'{clean_matrix_name}_n{n}_k{k}_{dt}.png') - from matplotlib import pyplot as plt plt.switch_backend('Agg') fig, (ax0, ax1, ax2, ax3) = plt.subplots(1, 4, figsize=(24, 4)) @@ -313,6 +324,7 @@ class LobpcgTest(jtu.JaxTestCase): ax3.set_title(prefix + rf' $\lambda_{{\max}}=\ ${eigs[0]:.1e}') fig.savefig(figpath, bbox_inches='tight') + plt.close(fig) def checkApproxEigs(self, example_name, dtype): fn, eigs, n, m = _callable_generators(dtype)[example_name] @@ -346,6 +358,7 @@ class LobpcgTest(jtu.JaxTestCase): self._possibly_plot(fn, eigs, X, m, 'callable_' + example_name) + class F32LobpcgTest(LobpcgTest): def testLobpcgValidatesArguments(self): @@ -367,13 +380,13 @@ class F32LobpcgTest(LobpcgTest): @parameterized.named_parameters(_make_concrete_cases(f64=False)) @jtu.skip_on_devices("gpu") - def testLobpcgConsistencyF32(self, matrix_name, n, k, m): - self.checkLobpcgConsistency(matrix_name, n, k, m, jnp.float32) + def testLobpcgConsistencyF32(self, matrix_name, n, k, m, tol): + self.checkLobpcgConsistency(matrix_name, n, k, m, tol, jnp.float32) @parameterized.named_parameters(_make_concrete_cases(f64=False)) @jtu.skip_on_devices("rocm") # see SWDEV-321073 - def testLobpcgMonotonicityF32(self, matrix_name, n, k, m): - self.checkLobpcgMonotonicity(matrix_name, n, k, m, jnp.float32) + def testLobpcgMonotonicityF32(self, matrix_name, n, k, m, tol): + self.checkLobpcgMonotonicity(matrix_name, n, k, m, tol, jnp.float32) @parameterized.named_parameters(_make_callable_cases(f64=False)) def testCallableMatricesF32(self, matrix_name): @@ -385,15 +398,20 @@ class F64LobpcgTest(LobpcgTest): @parameterized.named_parameters(_make_concrete_cases(f64=True)) @jtu.skip_on_devices("tpu", "iree", "gpu") - def testLobpcgConsistencyF64(self, matrix_name, n, k, m): - self.checkLobpcgConsistency(matrix_name, n, k, m, jnp.float64) + def testLobpcgConsistencyF64(self, matrix_name, n, k, m, tol): + self.checkLobpcgConsistency(matrix_name, n, k, m, tol, jnp.float64) @parameterized.named_parameters(_make_concrete_cases(f64=True)) @jtu.skip_on_devices("tpu", "iree", "gpu") - def testLobpcgMonotonicityF64(self, matrix_name, n, k, m): - self.checkLobpcgMonotonicity(matrix_name, n, k, m, jnp.float64) + def testLobpcgMonotonicityF64(self, matrix_name, n, k, m, tol): + self.checkLobpcgMonotonicity(matrix_name, n, k, m, tol, jnp.float64) @parameterized.named_parameters(_make_callable_cases(f64=True)) @jtu.skip_on_devices("tpu", "iree", "gpu") def testCallableMatricesF64(self, matrix_name): self.checkApproxEigs(matrix_name, jnp.float64) + + +if __name__ == '__main__': + config.parse_flags_with_absl() + absltest.main(testLoader=jtu.JaxTestLoader())