Make LOBPCG test plots compatible with bazel.

bazel test invocations previously failed because lobpcg_test did not parse flags or invoke absltest when run as a script. This change fixes that, shards the tests, and removes redundant slow variants with larger matrix sizes so the suite finishes sooner. Generated PNGs with debug information are now properly reported via the undeclared outputs directory whenever the environment variable LOBPCG_EMIT_DEBUG_PLOTS is set to a non-falsy value.

PiperOrigin-RevId: 465465731
Commit: 269067e3e8 (parent: 07da502323)
Author: Vlad Feinberg
Date: 2022-08-04 20:05:18 -07:00
Committed by: jax authors

4 changed files with 76 additions and 51 deletions

--- a/build/test-requirements.txt
+++ b/build/test-requirements.txt

@@ -1,5 +1,6 @@
 cloudpickle
 colorama>=0.4.4
+matplotlib
 pillow>=8.3.1
 pytest-benchmark
 pytest-xdist

--- a/jax/experimental/sparse/linalg.py
+++ b/jax/experimental/sparse/linalg.py

@@ -78,10 +78,9 @@ def lobpcg_standard(
       to the float epsilon of `A.dtype`.

   Returns:
-    `theta, U, i [, diagnostics]`, where `theta` is a `(k,)` array
+    `theta, U, i`, where `theta` is a `(k,)` array
     of eigenvalues, `U` is a `(n, k)` array of eigenvectors, `i` is the
-    number of iterations performed, and `diagnostics` is a dictionary with debug
-    information, which is only returned if `debug` is set to true.
+    number of iterations performed.

   Raises:
     ValueError : if `A,X` dtypes or `n` dimensions do not match, or `k` is too
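
For reference, a minimal usage sketch consistent with the updated docstring; the identity matrix, the sizes, and m=100 below are illustrative assumptions, not taken from this change:

    import numpy as np
    import jax.numpy as jnp
    from jax.experimental.sparse import linalg

    n, k = 100, 10
    A = jnp.eye(n, dtype=jnp.float32)  # stand-in for any symmetric PSD matrix
    X = jnp.asarray(np.random.randn(n, k), dtype=jnp.float32)  # initial guess
    theta, U, i = linalg.lobpcg_standard(A, X, m=100)
    print(theta.shape, U.shape, int(i))  # (10,) (100, 10) <iterations used>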

--- a/tests/BUILD
+++ b/tests/BUILD

@@ -147,8 +147,15 @@ jax_test(
 jax_test(
     name = "lobpcg_test",
     srcs = ["lobpcg_test.py"],
+    env = {"LOBPCG_EMIT_DEBUG_PLOTS": "1"},
+    shard_count = {
+        "cpu": 48,
+        "gpu": 48,
+        "tpu": 48,
+    },
     deps = [
         "//jax:experimental_sparse",
+        "//third_party/py/matplotlib",
     ],
 )
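
Note on shard_count: bazel splits the target into 48 parallel test actions per backend and advertises the split to the test binary through environment variables; absltest implements that protocol, so the test source needs no sharding logic of its own. A rough sketch of the contract, with round-robin assignment assumed:

    import os

    # Set by bazel for each shard of a sharded test action.
    shard_index = int(os.environ.get('TEST_SHARD_INDEX', '0'))
    total_shards = int(os.environ.get('TEST_TOTAL_SHARDS', '1'))

    def runs_in_this_shard(test_ordinal):
      # Round-robin partition of the test list across shards (assumed scheme).
      return test_ordinal % total_shards == shard_index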

--- a/tests/lobpcg_test.py
+++ b/tests/lobpcg_test.py

@@ -22,14 +22,16 @@ import functools
 import re
 import os

+from absl.testing import absltest
 from absl.testing import parameterized
 import numpy as np
+from matplotlib import pyplot as plt
 import scipy.linalg as sla
 import scipy.sparse as sps

 import jax
+from jax.config import config
 from jax._src import test_util as jtu
-from jax._src.config import flags
 from jax.experimental.sparse import linalg, bcoo
 import jax.numpy as jnp
@@ -41,20 +43,25 @@ def _make_concrete_cases(f64):
   example_names = list(_concrete_generators(dtype))
   cases = []
   for name in example_names:
-    nkm = [(100, 10, 20)]
-    if not flags.FLAGS.jax_skip_slow_tests:
-      nkm.append((1000, 100, 200))
-    for n, k, m in nkm:
-      if name == 'ring laplacian':
-        m *= 3
-      if name.startswith('linear'):
-        m *= 2
-      if f64:
-        m *= 2
-      case = [('matrix_name', name), ('n', n), ('k', k), ('m', m)]
-      clean_matrix_name = _clean_matrix_name(name)
-      case.append(('testcase_name', f'{clean_matrix_name}_n{n}'))
-      cases.append(dict(case))
+    n, k, m, tol = 100, 10, 20, None
+    if name == 'ring laplacian':
+      m *= 3
+    if name.startswith('linear'):
+      m *= 2
+    if f64:
+      m *= 2
+    if name.startswith('cluster') and not f64:
+      tol = 2e-6
+    clean_matrix_name = _clean_matrix_name(name)
+    case = {
+        'matrix_name': name,
+        'n': n,
+        'k': k,
+        'm': m,
+        'tol': tol,
+        'testcase_name': f'{clean_matrix_name}_n{n}'
+    }
+    cases.append(case)
   assert len({c['testcase_name'] for c in cases}) == len(cases)
   return cases
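
These case dicts feed absl's parameterized decorator in the test classes below: each dict must carry a unique testcase_name, and the remaining keys become keyword arguments of the test method. A self-contained sketch of that mechanism:

    from absl.testing import absltest, parameterized

    class ExampleTest(parameterized.TestCase):

      @parameterized.named_parameters(
          {'testcase_name': 'id_n100', 'matrix_name': 'id',
           'n': 100, 'k': 10, 'm': 20, 'tol': None},
      )
      def test_case(self, matrix_name, n, k, m, tol):
        self.assertEqual((n, k), (100, 10))

    if __name__ == '__main__':
      absltest.main()
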
@@ -160,21 +167,22 @@ def _make_sparse_fn(n, fill):
 def _callable_generators(dtype):
   n = 100
   topk = 10
-  d = {'id': _make_id_fn(n),
-       'linear cond=1k': _make_diag_fn(np.linspace(1, 1000, n), 40),
-       'linear cond=100k':_make_diag_fn(np.linspace(1, 100 * 1000, n), 40),
-       'geom cond=1k': _make_diag_fn(np.logspace(0, 3, n), 20),
-       'geom cond=100k': _make_diag_fn(np.logspace(0, 5, n), 20),
-       'ring laplacian': _make_ring_fn(n, 40),
-       'randn': _make_randn_fn(n, topk, 40),
-       'sparse 1%': _make_sparse_fn(n, 0.01),
-       'sparse 10%': _make_sparse_fn(n, 0.10),
-       }
+  d = {
+      'id': _make_id_fn(n),
+      'linear cond=1k': _make_diag_fn(np.linspace(1, 1000, n), 40),
+      'linear cond=100k': _make_diag_fn(np.linspace(1, 100 * 1000, n), 40),
+      'geom cond=1k': _make_diag_fn(np.logspace(0, 3, n), 20),
+      'geom cond=100k': _make_diag_fn(np.logspace(0, 5, n), 20),
+      'ring laplacian': _make_ring_fn(n, 40),
+      'randn': _make_randn_fn(n, topk, 40),
+      'sparse 1%': _make_sparse_fn(n, 0.01),
+      'sparse 10%': _make_sparse_fn(n, 0.10),
+  }
   ret = {}
   for k, (vec_mul_fn, eigs, m) in d.items():
     if jtu.num_float_bits(dtype) > 32:
-       m *= 3
+      m *= 3
     eigs.sort()

     # Note we must lift the vector multiply into matmul
@@ -183,6 +191,7 @@ def _callable_generators(dtype):
     ret[k] = (fn, eigs[::-1][:topk].astype(dtype), n, m)
   return ret

+
 @jtu.with_config(
     jax_enable_checks=True,
     jax_debug_nans=True,
@@ -190,12 +199,12 @@ def _callable_generators(dtype):
     jax_traceback_filtering='off')
 class LobpcgTest(jtu.JaxTestCase):

-  def checkLobpcgConsistency(self, matrix_name, n, k, m, dtype):
+  def checkLobpcgConsistency(self, matrix_name, n, k, m, tol, dtype):
     A, eigs = _concrete_generators(dtype)[matrix_name](n, k)
     X = self.rng().standard_normal(size=(n, k)).astype(dtype)
     A, X = (jnp.array(i, dtype=dtype) for i in (A, X))

-    theta, U, i = linalg.lobpcg_standard(A, X, m)
+    theta, U, i = linalg.lobpcg_standard(A, X, m, tol)
     self.assertDtypesMatch(theta, A)
     self.assertDtypesMatch(U, A)
@@ -219,7 +228,7 @@ class LobpcgTest(jtu.JaxTestCase):
       vector_norm = np.linalg.norm(Au)
       adjusted_error = resid_norm / n / (t + vector_norm) / 10

-      eps = float(jnp.finfo(dtype).eps)
+      eps = float(jnp.finfo(dtype).eps) if tol is None else tol
       self.assertLessEqual(
          adjusted_error,
          eps,
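
Context for the eps/tol switch: the adjusted residual error was always compared against the float epsilon of the dtype, which is too tight for the clustered-eigenvalue float32 cases; those now get tol = 2e-6. A quick worked check of the scale:

    import numpy as np

    eps32 = float(np.finfo(np.float32).eps)
    print(eps32)         # ~1.1920929e-07
    print(2e-6 / eps32)  # ~16.8, i.e. tol=2e-6 is roughly 17x looser than eps
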
@@ -236,7 +245,8 @@ class LobpcgTest(jtu.JaxTestCase):
           f' for eigenvalue {i} (actual {float(theta[i])}, '
           f'expected {float(eigs[i])})')

-  def checkLobpcgMonotonicity(self, matrix_name, n, k, m, dtype):
+  def checkLobpcgMonotonicity(self, matrix_name, n, k, m, tol, dtype):
+    del tol
     A, eigs = _concrete_generators(dtype)[matrix_name](n, k)
     X = self.rng().standard_normal(size=(n, k)).astype(dtype)
@@ -261,15 +271,17 @@ class LobpcgTest(jtu.JaxTestCase):
     self._possibly_plot(A, eigs, X, m, matrix_name)

   def _possibly_plot(self, A, eigs, X, m, matrix_name):
-    plot_dir = os.getenv('LOBPCG_DEBUG_PLOT_DIR')
-    if plot_dir is not None:
-      if isinstance(A, (np.ndarray, jnp.ndarray)):
-        lobpcg = linalg._lobpcg_standard_matrix
-      else:
-        lobpcg = linalg._lobpcg_standard_callable
-      _theta, _U, _i, info = lobpcg(A, X, m, tol=0, debug=True)
-      self._debug_plots(
-          X, eigs, info, matrix_name, plot_dir)
+    if not os.getenv('LOBPCG_EMIT_DEBUG_PLOTS'):
+      return
+
+    if isinstance(A, (np.ndarray, jnp.ndarray)):
+      lobpcg = linalg._lobpcg_standard_matrix
+    else:
+      lobpcg = linalg._lobpcg_standard_callable
+    _theta, _U, _i, info = lobpcg(A, X, m, tol=0, debug=True)
+    plot_dir = os.getenv('TEST_UNDECLARED_OUTPUTS_DIR')
+    assert plot_dir, 'expected TEST_UNDECLARED_OUTPUTS_DIR for lobpcg plots'
+    self._debug_plots(X, eigs, info, matrix_name, plot_dir)

   def _debug_plots(self, X, eigs, info, matrix_name, lobpcg_debug_plot_dir):
     os.makedirs(lobpcg_debug_plot_dir, exist_ok=True)
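
TEST_UNDECLARED_OUTPUTS_DIR is set by bazel for every test action, and anything written under it is archived as test.outputs/outputs.zip next to the test log, which is how the generated PNGs become retrievable after the run. A standalone sketch of the same contract the assert above relies on:

    import os

    out_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR')
    if out_dir:  # present under `bazel test`, absent in a bare local run
      with open(os.path.join(out_dir, 'debug_note.txt'), 'w') as f:
        f.write('collected into outputs.zip when the test finishes\n')
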
@@ -280,7 +292,6 @@ class LobpcgTest(jtu.JaxTestCase):
         lobpcg_debug_plot_dir,
         f'{clean_matrix_name}_n{n}_k{k}_{dt}.png')

-    from matplotlib import pyplot as plt
     plt.switch_backend('Agg')
     fig, (ax0, ax1, ax2, ax3) = plt.subplots(1, 4, figsize=(24, 4))
@@ -313,6 +324,7 @@ class LobpcgTest(jtu.JaxTestCase):
     ax3.set_title(prefix + rf' $\lambda_{{\max}}=\ ${eigs[0]:.1e}')

     fig.savefig(figpath, bbox_inches='tight')
+    plt.close(fig)

   def checkApproxEigs(self, example_name, dtype):
     fn, eigs, n, m = _callable_generators(dtype)[example_name]
@@ -346,6 +358,7 @@ class LobpcgTest(jtu.JaxTestCase):
     self._possibly_plot(fn, eigs, X, m, 'callable_' + example_name)

+
 class F32LobpcgTest(LobpcgTest):

   def testLobpcgValidatesArguments(self):
@@ -367,13 +380,13 @@ class F32LobpcgTest(LobpcgTest):
   @parameterized.named_parameters(_make_concrete_cases(f64=False))
   @jtu.skip_on_devices("gpu")
-  def testLobpcgConsistencyF32(self, matrix_name, n, k, m):
-    self.checkLobpcgConsistency(matrix_name, n, k, m, jnp.float32)
+  def testLobpcgConsistencyF32(self, matrix_name, n, k, m, tol):
+    self.checkLobpcgConsistency(matrix_name, n, k, m, tol, jnp.float32)

   @parameterized.named_parameters(_make_concrete_cases(f64=False))
   @jtu.skip_on_devices("rocm")  # see SWDEV-321073
-  def testLobpcgMonotonicityF32(self, matrix_name, n, k, m):
-    self.checkLobpcgMonotonicity(matrix_name, n, k, m, jnp.float32)
+  def testLobpcgMonotonicityF32(self, matrix_name, n, k, m, tol):
+    self.checkLobpcgMonotonicity(matrix_name, n, k, m, tol, jnp.float32)

   @parameterized.named_parameters(_make_callable_cases(f64=False))
   def testCallableMatricesF32(self, matrix_name):
@@ -385,15 +398,20 @@ class F64LobpcgTest(LobpcgTest):
   @parameterized.named_parameters(_make_concrete_cases(f64=True))
   @jtu.skip_on_devices("tpu", "iree", "gpu")
-  def testLobpcgConsistencyF64(self, matrix_name, n, k, m):
-    self.checkLobpcgConsistency(matrix_name, n, k, m, jnp.float64)
+  def testLobpcgConsistencyF64(self, matrix_name, n, k, m, tol):
+    self.checkLobpcgConsistency(matrix_name, n, k, m, tol, jnp.float64)

   @parameterized.named_parameters(_make_concrete_cases(f64=True))
   @jtu.skip_on_devices("tpu", "iree", "gpu")
-  def testLobpcgMonotonicityF64(self, matrix_name, n, k, m):
-    self.checkLobpcgMonotonicity(matrix_name, n, k, m, jnp.float64)
+  def testLobpcgMonotonicityF64(self, matrix_name, n, k, m, tol):
+    self.checkLobpcgMonotonicity(matrix_name, n, k, m, tol, jnp.float64)

   @parameterized.named_parameters(_make_callable_cases(f64=True))
   @jtu.skip_on_devices("tpu", "iree", "gpu")
   def testCallableMatricesF64(self, matrix_name):
     self.checkApproxEigs(matrix_name, jnp.float64)
+
+
+if __name__ == '__main__':
+  config.parse_flags_with_absl()
+  absltest.main(testLoader=jtu.JaxTestLoader())
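
The new __main__ block is what makes running the file as a plain script work: config.parse_flags_with_absl() registers JAX's configuration flags with absl and parses argv, and absltest.main runs the suite under JAX's test loader. Outside bazel, emitting plots presumably requires setting both variables by hand, e.g. (assumed invocation and paths):

    LOBPCG_EMIT_DEBUG_PLOTS=1 TEST_UNDECLARED_OUTPUTS_DIR=/tmp/lobpcg_plots \
        python tests/lobpcg_test.py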