mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Make LOBPCG test plots compatible with bazel.
bazel test invocations would previously not work, because the lobpcg_test did not include the appropriate flag parsing and absl test invocations when run as a script. This change fixes that, and in addition shards tests and removes needless and redundant slow tests with larger matrix sizes to make the tests finish in a smaller amount of time. Now, generated pngs with debug information are properly reported via the undeclared outputs directory when the environment variable to emit them, LOBPCG_EMIT_DEBUG_PLOTS, is set to a non-falsy value. PiperOrigin-RevId: 465465731
This commit is contained in:
parent
07da502323
commit
269067e3e8
@ -1,5 +1,6 @@
|
||||
cloudpickle
|
||||
colorama>=0.4.4
|
||||
matplotlib
|
||||
pillow>=8.3.1
|
||||
pytest-benchmark
|
||||
pytest-xdist
|
||||
|
@ -78,10 +78,9 @@ def lobpcg_standard(
|
||||
to the float epsilon of `A.dtype`.
|
||||
|
||||
Returns:
|
||||
`theta, U, i [, diagnostics]`, where `theta` is a `(k,)` array
|
||||
`theta, U, i`, where `theta` is a `(k,)` array
|
||||
of eigenvalues, `U` is a `(n, k)` array of eigenvectors, `i` is the
|
||||
number of iterations performed, and `diagnostics` is a dictionary with debug
|
||||
information, which is only returned if `debug` is set to true.
|
||||
number of iterations performed.
|
||||
|
||||
Raises:
|
||||
ValueError : if `A,X` dtypes or `n` dimensions do not match, or `k` is too
|
||||
|
@ -147,8 +147,15 @@ jax_test(
|
||||
jax_test(
|
||||
name = "lobpcg_test",
|
||||
srcs = ["lobpcg_test.py"],
|
||||
env = {"LOBPCG_EMIT_DEBUG_PLOTS": "1"},
|
||||
shard_count = {
|
||||
"cpu": 48,
|
||||
"gpu": 48,
|
||||
"tpu": 48,
|
||||
},
|
||||
deps = [
|
||||
"//jax:experimental_sparse",
|
||||
"//third_party/py/matplotlib",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -22,14 +22,16 @@ import functools
|
||||
import re
|
||||
import os
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
import scipy.linalg as sla
|
||||
import scipy.sparse as sps
|
||||
|
||||
import jax
|
||||
from jax.config import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.config import flags
|
||||
from jax.experimental.sparse import linalg, bcoo
|
||||
import jax.numpy as jnp
|
||||
|
||||
@ -41,20 +43,25 @@ def _make_concrete_cases(f64):
|
||||
example_names = list(_concrete_generators(dtype))
|
||||
cases = []
|
||||
for name in example_names:
|
||||
nkm = [(100, 10, 20)]
|
||||
if not flags.FLAGS.jax_skip_slow_tests:
|
||||
nkm.append((1000, 100, 200))
|
||||
for n, k, m in nkm:
|
||||
if name == 'ring laplacian':
|
||||
m *= 3
|
||||
if name.startswith('linear'):
|
||||
m *= 2
|
||||
if f64:
|
||||
m *= 2
|
||||
case = [('matrix_name', name), ('n', n), ('k', k), ('m', m)]
|
||||
clean_matrix_name = _clean_matrix_name(name)
|
||||
case.append(('testcase_name', f'{clean_matrix_name}_n{n}'))
|
||||
cases.append(dict(case))
|
||||
n, k, m, tol = 100, 10, 20, None
|
||||
if name == 'ring laplacian':
|
||||
m *= 3
|
||||
if name.startswith('linear'):
|
||||
m *= 2
|
||||
if f64:
|
||||
m *= 2
|
||||
if name.startswith('cluster') and not f64:
|
||||
tol = 2e-6
|
||||
clean_matrix_name = _clean_matrix_name(name)
|
||||
case = {
|
||||
'matrix_name': name,
|
||||
'n': n,
|
||||
'k': k,
|
||||
'm': m,
|
||||
'tol': tol,
|
||||
'testcase_name': f'{clean_matrix_name}_n{n}'
|
||||
}
|
||||
cases.append(case)
|
||||
|
||||
assert len({c['testcase_name'] for c in cases}) == len(cases)
|
||||
return cases
|
||||
@ -160,21 +167,22 @@ def _make_sparse_fn(n, fill):
|
||||
def _callable_generators(dtype):
|
||||
n = 100
|
||||
topk = 10
|
||||
d = {'id': _make_id_fn(n),
|
||||
'linear cond=1k': _make_diag_fn(np.linspace(1, 1000, n), 40),
|
||||
'linear cond=100k':_make_diag_fn(np.linspace(1, 100 * 1000, n), 40),
|
||||
'geom cond=1k': _make_diag_fn(np.logspace(0, 3, n), 20),
|
||||
'geom cond=100k': _make_diag_fn(np.logspace(0, 5, n), 20),
|
||||
'ring laplacian': _make_ring_fn(n, 40),
|
||||
'randn': _make_randn_fn(n, topk, 40),
|
||||
'sparse 1%': _make_sparse_fn(n, 0.01),
|
||||
'sparse 10%': _make_sparse_fn(n, 0.10),
|
||||
}
|
||||
d = {
|
||||
'id': _make_id_fn(n),
|
||||
'linear cond=1k': _make_diag_fn(np.linspace(1, 1000, n), 40),
|
||||
'linear cond=100k': _make_diag_fn(np.linspace(1, 100 * 1000, n), 40),
|
||||
'geom cond=1k': _make_diag_fn(np.logspace(0, 3, n), 20),
|
||||
'geom cond=100k': _make_diag_fn(np.logspace(0, 5, n), 20),
|
||||
'ring laplacian': _make_ring_fn(n, 40),
|
||||
'randn': _make_randn_fn(n, topk, 40),
|
||||
'sparse 1%': _make_sparse_fn(n, 0.01),
|
||||
'sparse 10%': _make_sparse_fn(n, 0.10),
|
||||
}
|
||||
|
||||
ret = {}
|
||||
for k, (vec_mul_fn, eigs, m) in d.items():
|
||||
if jtu.num_float_bits(dtype) > 32:
|
||||
m *= 3
|
||||
m *= 3
|
||||
eigs.sort()
|
||||
|
||||
# Note we must lift the vector multiply into matmul
|
||||
@ -183,6 +191,7 @@ def _callable_generators(dtype):
|
||||
ret[k] = (fn, eigs[::-1][:topk].astype(dtype), n, m)
|
||||
return ret
|
||||
|
||||
|
||||
@jtu.with_config(
|
||||
jax_enable_checks=True,
|
||||
jax_debug_nans=True,
|
||||
@ -190,12 +199,12 @@ def _callable_generators(dtype):
|
||||
jax_traceback_filtering='off')
|
||||
class LobpcgTest(jtu.JaxTestCase):
|
||||
|
||||
def checkLobpcgConsistency(self, matrix_name, n, k, m, dtype):
|
||||
def checkLobpcgConsistency(self, matrix_name, n, k, m, tol, dtype):
|
||||
A, eigs = _concrete_generators(dtype)[matrix_name](n, k)
|
||||
X = self.rng().standard_normal(size=(n, k)).astype(dtype)
|
||||
|
||||
A, X = (jnp.array(i, dtype=dtype) for i in (A, X))
|
||||
theta, U, i = linalg.lobpcg_standard(A, X, m)
|
||||
theta, U, i = linalg.lobpcg_standard(A, X, m, tol)
|
||||
|
||||
self.assertDtypesMatch(theta, A)
|
||||
self.assertDtypesMatch(U, A)
|
||||
@ -219,7 +228,7 @@ class LobpcgTest(jtu.JaxTestCase):
|
||||
vector_norm = np.linalg.norm(Au)
|
||||
adjusted_error = resid_norm / n / (t + vector_norm) / 10
|
||||
|
||||
eps = float(jnp.finfo(dtype).eps)
|
||||
eps = float(jnp.finfo(dtype).eps) if tol is None else tol
|
||||
self.assertLessEqual(
|
||||
adjusted_error,
|
||||
eps,
|
||||
@ -236,7 +245,8 @@ class LobpcgTest(jtu.JaxTestCase):
|
||||
f' for eigenvalue {i} (actual {float(theta[i])}, '
|
||||
f'expected {float(eigs[i])})')
|
||||
|
||||
def checkLobpcgMonotonicity(self, matrix_name, n, k, m, dtype):
|
||||
def checkLobpcgMonotonicity(self, matrix_name, n, k, m, tol, dtype):
|
||||
del tol
|
||||
A, eigs = _concrete_generators(dtype)[matrix_name](n, k)
|
||||
X = self.rng().standard_normal(size=(n, k)).astype(dtype)
|
||||
|
||||
@ -261,15 +271,17 @@ class LobpcgTest(jtu.JaxTestCase):
|
||||
self._possibly_plot(A, eigs, X, m, matrix_name)
|
||||
|
||||
def _possibly_plot(self, A, eigs, X, m, matrix_name):
|
||||
plot_dir = os.getenv('LOBPCG_DEBUG_PLOT_DIR')
|
||||
if plot_dir is not None:
|
||||
if isinstance(A, (np.ndarray, jnp.ndarray)):
|
||||
lobpcg = linalg._lobpcg_standard_matrix
|
||||
else:
|
||||
lobpcg = linalg._lobpcg_standard_callable
|
||||
_theta, _U, _i, info = lobpcg(A, X, m, tol=0, debug=True)
|
||||
self._debug_plots(
|
||||
X, eigs, info, matrix_name, plot_dir)
|
||||
if not os.getenv('LOBPCG_EMIT_DEBUG_PLOTS'):
|
||||
return
|
||||
|
||||
if isinstance(A, (np.ndarray, jnp.ndarray)):
|
||||
lobpcg = linalg._lobpcg_standard_matrix
|
||||
else:
|
||||
lobpcg = linalg._lobpcg_standard_callable
|
||||
_theta, _U, _i, info = lobpcg(A, X, m, tol=0, debug=True)
|
||||
plot_dir = os.getenv('TEST_UNDECLARED_OUTPUTS_DIR')
|
||||
assert plot_dir, 'expected TEST_UNDECLARED_OUTPUTS_DIR for lobpcg plots'
|
||||
self._debug_plots(X, eigs, info, matrix_name, plot_dir)
|
||||
|
||||
def _debug_plots(self, X, eigs, info, matrix_name, lobpcg_debug_plot_dir):
|
||||
os.makedirs(lobpcg_debug_plot_dir, exist_ok=True)
|
||||
@ -280,7 +292,6 @@ class LobpcgTest(jtu.JaxTestCase):
|
||||
lobpcg_debug_plot_dir,
|
||||
f'{clean_matrix_name}_n{n}_k{k}_{dt}.png')
|
||||
|
||||
from matplotlib import pyplot as plt
|
||||
plt.switch_backend('Agg')
|
||||
|
||||
fig, (ax0, ax1, ax2, ax3) = plt.subplots(1, 4, figsize=(24, 4))
|
||||
@ -313,6 +324,7 @@ class LobpcgTest(jtu.JaxTestCase):
|
||||
ax3.set_title(prefix + rf' $\lambda_{{\max}}=\ ${eigs[0]:.1e}')
|
||||
|
||||
fig.savefig(figpath, bbox_inches='tight')
|
||||
plt.close(fig)
|
||||
|
||||
def checkApproxEigs(self, example_name, dtype):
|
||||
fn, eigs, n, m = _callable_generators(dtype)[example_name]
|
||||
@ -346,6 +358,7 @@ class LobpcgTest(jtu.JaxTestCase):
|
||||
|
||||
self._possibly_plot(fn, eigs, X, m, 'callable_' + example_name)
|
||||
|
||||
|
||||
class F32LobpcgTest(LobpcgTest):
|
||||
|
||||
def testLobpcgValidatesArguments(self):
|
||||
@ -367,13 +380,13 @@ class F32LobpcgTest(LobpcgTest):
|
||||
|
||||
@parameterized.named_parameters(_make_concrete_cases(f64=False))
|
||||
@jtu.skip_on_devices("gpu")
|
||||
def testLobpcgConsistencyF32(self, matrix_name, n, k, m):
|
||||
self.checkLobpcgConsistency(matrix_name, n, k, m, jnp.float32)
|
||||
def testLobpcgConsistencyF32(self, matrix_name, n, k, m, tol):
|
||||
self.checkLobpcgConsistency(matrix_name, n, k, m, tol, jnp.float32)
|
||||
|
||||
@parameterized.named_parameters(_make_concrete_cases(f64=False))
|
||||
@jtu.skip_on_devices("rocm") # see SWDEV-321073
|
||||
def testLobpcgMonotonicityF32(self, matrix_name, n, k, m):
|
||||
self.checkLobpcgMonotonicity(matrix_name, n, k, m, jnp.float32)
|
||||
def testLobpcgMonotonicityF32(self, matrix_name, n, k, m, tol):
|
||||
self.checkLobpcgMonotonicity(matrix_name, n, k, m, tol, jnp.float32)
|
||||
|
||||
@parameterized.named_parameters(_make_callable_cases(f64=False))
|
||||
def testCallableMatricesF32(self, matrix_name):
|
||||
@ -385,15 +398,20 @@ class F64LobpcgTest(LobpcgTest):
|
||||
|
||||
@parameterized.named_parameters(_make_concrete_cases(f64=True))
|
||||
@jtu.skip_on_devices("tpu", "iree", "gpu")
|
||||
def testLobpcgConsistencyF64(self, matrix_name, n, k, m):
|
||||
self.checkLobpcgConsistency(matrix_name, n, k, m, jnp.float64)
|
||||
def testLobpcgConsistencyF64(self, matrix_name, n, k, m, tol):
|
||||
self.checkLobpcgConsistency(matrix_name, n, k, m, tol, jnp.float64)
|
||||
|
||||
@parameterized.named_parameters(_make_concrete_cases(f64=True))
|
||||
@jtu.skip_on_devices("tpu", "iree", "gpu")
|
||||
def testLobpcgMonotonicityF64(self, matrix_name, n, k, m):
|
||||
self.checkLobpcgMonotonicity(matrix_name, n, k, m, jnp.float64)
|
||||
def testLobpcgMonotonicityF64(self, matrix_name, n, k, m, tol):
|
||||
self.checkLobpcgMonotonicity(matrix_name, n, k, m, tol, jnp.float64)
|
||||
|
||||
@parameterized.named_parameters(_make_callable_cases(f64=True))
|
||||
@jtu.skip_on_devices("tpu", "iree", "gpu")
|
||||
def testCallableMatricesF64(self, matrix_name):
|
||||
self.checkApproxEigs(matrix_name, jnp.float64)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config.parse_flags_with_absl()
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user