rocm_jax/tests/lobpcg_test.py
vfdev-5 576843283b Disabled default env var LOBPCG_EMIT_DEBUG_PLOTS=1
Description:
- Disabled default env var LOBPCG_EMIT_DEBUG_PLOTS=1
- When run inside the TSAN CI job with free-threaded CPython 3.14t and under multi-threading,
the test code from main fails with a `RecursionError: maximum recursion depth exceeded` error:
```
ERROR: testLobpcgMonotonicityF32cluster_k_1__n100 (__main__.F32LobpcgTest)
F32LobpcgTest.testLobpcgMonotonicityF32cluster_k_1__n100
testLobpcgMonotonicityF32cluster_k_1__n100(matrix_name='cluster(k-1)', n=100, k=10, m=20, tol=2e-06)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_absl_py/site-packages/absl/testing/parameterized.py", line 319, in bound_param_test
    return test_method(self, **testcase_params)
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/__main__/tests/lobpcg_test.py", line 408, in testLobpcgMonotonicityF32
    self.checkLobpcgMonotonicity(matrix_name, n, k, m, tol, jnp.float32)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/__main__/tests/lobpcg_test.py", line 272, in checkLobpcgMonotonicity
    self._possibly_plot(A, eigs, X, m, matrix_name)
    ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/__main__/tests/lobpcg_test.py", line 290, in _possibly_plot
    self._debug_plots(X, eigs, info, matrix_name, plot_dir)
    ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/__main__/tests/lobpcg_test.py", line 318, in _debug_plots
    ax0.legend()
    ~~~~~~~~~~^^
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/axes/_axes.py", line 337, in legend
    self.legend_ = mlegend.Legend(self, handles, labels, **kwargs)
                   ~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/legend.py", line 549, in __init__
    self._init_legend_box(handles, labels, markerfirst)
    ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/legend.py", line 896, in _init_legend_box
    handle_list.append(handler.legend_artist(self, orig_handle,
                       ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^
                                             fontsize, handlebox))
                                             ^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/legend_handler.py", line 129, in legend_artist
    artists = self.create_artists(legend, orig_handle,
                                  xdescent, ydescent, width, height,
                                  fontsize, handlebox.get_transform())
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/legend_handler.py", line 303, in create_artists
    self.update_prop(legline, orig_handle, legend)
    ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/legend_handler.py", line 88, in update_prop
    self._update_prop(legend_handle, orig_handle)
    ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/legend_handler.py", line 79, in _update_prop
    self._default_update_prop(legend_handle, orig_handle)
    ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/legend_handler.py", line 84, in _default_update_prop
    legend_handle.update_from(orig_handle)
    ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/lines.py", line 1358, in update_from
    self._marker = MarkerStyle(marker=other._marker)
                   ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/markers.py", line 248, in __init__
    self._set_marker(marker)
    ~~~~~~~~~~~~~~~~^^^^^^^^
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/markers.py", line 323, in _set_marker
    self.__dict__ = copy.deepcopy(marker.__dict__)
                    ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/external/python_x86_64-unknown-linux-gnu-freethreaded/lib/python3.14t/copy.py", line 131, in deepcopy
    y = copier(x, memo)
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/external/python_x86_64-unknown-linux-gnu-freethreaded/lib/python3.14t/copy.py", line 202, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ~~~~~~~~^^^^^^^^^^^^^
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/external/python_x86_64-unknown-linux-gnu-freethreaded/lib/python3.14t/copy.py", line 138, in deepcopy
    y = copier(memo)
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/path.py", line 285, in __deepcopy__
    p = copy.deepcopy(super(), memo)
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/external/python_x86_64-unknown-linux-gnu-freethreaded/lib/python3.14t/copy.py", line 157, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/external/python_x86_64-unknown-linux-gnu-freethreaded/lib/python3.14t/copy.py", line 234, in _reconstruct
    y = func(*args)
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/external/python_x86_64-unknown-linux-gnu-freethreaded/lib/python3.14t/copy.py", line 233, in <genexpr>
    args = (deepcopy(arg, memo) for arg in args)
            ~~~~~~~~^^^^^^^^^^^
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/external/python_x86_64-unknown-linux-gnu-freethreaded/lib/python3.14t/copy.py", line 138, in deepcopy
    y = copier(memo)
  File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/path.py", line 285, in __deepcopy__
    p = copy.deepcopy(super(), memo)
```
2025-04-02 15:12:12 +00:00

438 lines
15 KiB
Python

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for lobpcg routine.
If LOBPCG_EMIT_DEBUG_PLOTS=1 is set, exports debug visuals to the directory
named by TEST_UNDECLARED_OUTPUTS_DIR. Plotting requires matplotlib and is
skipped when matplotlib cannot be imported.
"""
import functools
import re
import os
import unittest
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import scipy.linalg as sla
import scipy.sparse as sps
import jax
from jax._src import test_util as jtu
from jax.experimental.sparse import linalg, bcoo
import jax.numpy as jnp
def _clean_matrix_name(name):
  """Returns `name` with every run of non-alphanumeric characters as '_'.

  Used to turn human-readable matrix names into test-case and file names.
  """
  non_alnum_run = '[^0-9a-zA-Z]+'
  return re.sub(non_alnum_run, '_', name)
def _make_concrete_cases(f64):
  """Builds named parameter dicts for the dense (concrete-matrix) tests.

  Args:
    f64: True to build the float64 cases, False for float32.

  Returns:
    One dict per `_concrete_generators` entry, with keys 'matrix_name', 'n',
    'k', 'm' (iteration budget), 'tol' and 'testcase_name'.
  """
  dtype = np.float64 if f64 else np.float32
  n, k = 100, 10
  cases = []
  for name in _concrete_generators(dtype):
    m = 20
    # Slower-converging spectra get a larger iteration budget.
    if name == 'ring laplacian':
      m *= 3
    if name.startswith('linear'):
      m *= 2
    if f64:
      m *= 2
    # The clustered spectra need a looser convergence tolerance in f32.
    tol = 2e-6 if (name.startswith('cluster') and not f64) else None
    cases.append({
        'matrix_name': name,
        'n': n,
        'k': k,
        'm': m,
        'tol': tol,
        'testcase_name': f'{_clean_matrix_name(name)}_n{n}',
    })
  case_names = [case['testcase_name'] for case in cases]
  assert len(set(case_names)) == len(case_names)
  return cases
def _make_callable_cases(f64):
  """Builds named parameter dicts for the matrix-free (callable) tests.

  Args:
    f64: True to build the float64 cases, False for float32.

  Returns:
    One {'testcase_name', 'matrix_name'} dict per `_callable_generators`
    entry.
  """
  generators = _callable_generators(np.float64 if f64 else np.float32)
  return [
      {'testcase_name': _clean_matrix_name(name), 'matrix_name': name}
      for name in generators
  ]
def _make_ring(n):
  """Returns a dense graph Laplacian and its eigenvalues, largest first.

  Adapted from the scipy lobpcg tests. Note that `sla.toeplitz` of a first
  column with a single 1 at index 1 only fills the first sub/super-diagonals
  (no wrap-around corner entries). So despite the name, this is the
  Laplacian of the n-node path graph, and the closed form below is that
  graph's spectrum.

  Args:
    n: number of nodes; must be at least 2.

  Returns:
    (L, w): the (n, n) Laplacian and its n eigenvalues in descending order.
  """
  first_col = np.zeros(n)
  first_col[1] = 1
  adjacency = sla.toeplitz(first_col)
  laplacian = np.diag(adjacency.sum(axis=1)) - adjacency
  # Full eigendecomposition in closed form, see e.g.
  # http://www.cs.yale.edu/homes/spielman/561/2009/lect02-09.pdf
  eigvals = 2 * (1 - np.cos(np.pi * np.arange(n) / n))
  return laplacian, np.sort(eigvals)[::-1]
def _make_diag(diag):
  """Returns a diagonal matrix and its eigenvalues in descending order.

  Args:
    diag: 1-D array of diagonal entries. It is not modified.

  Returns:
    (np.diag(d), d) where d holds the entries of `diag` sorted descending.
  """
  # np.sort returns a sorted copy, so the caller's array keeps its order
  # (an in-place `diag.sort()` would silently reorder the argument).
  diag = np.sort(diag)[::-1]
  return np.diag(diag), diag
def _make_cluster(to_cluster, n):
  """Diagonal n x n test matrix with `to_cluster` entries of 1000, rest 1.

  Returns the (matrix, descending eigenvalues) pair from `_make_diag`.
  """
  entries = [1000] * to_cluster + [1] * (n - to_cluster)
  return _make_diag(np.array(entries))
def _concrete_generators(dtype):
  """Returns dense test-matrix generators keyed by display name.

  Each generator takes (n, k) and returns the tuple (A, eigs) produced by
  `_make_diag`, `_make_ring` or `_make_cluster`, with every array converted
  to `dtype`. Only the 'cluster' generators use `k`; the rest ignore it.
  """
  to_dtype = functools.partial(np.array, dtype=dtype)

  def with_dtype(generator):
    def generate(n, k):
      return tuple(to_dtype(arr) for arr in generator(n, k))
    return generate

  raw = {
      'id': lambda n, _k: _make_diag(np.ones(n)),
      'linear cond=1k': lambda n, _k: _make_diag(np.linspace(1, 1000, n)),
      'linear cond=100k':
          lambda n, _k: _make_diag(np.linspace(1, 100 * 1000, n)),
      'geom cond=1k': lambda n, _k: _make_diag(np.logspace(0, 3, n)),
      'geom cond=100k': lambda n, _k: _make_diag(np.logspace(0, 5, n)),
      'ring laplacian': lambda n, _k: _make_ring(n),
      'cluster(k/2)': lambda n, k: _make_cluster(k // 2, n),
      'cluster(k-1)': lambda n, k: _make_cluster(k - 1, n),
      'cluster(k)': lambda n, k: _make_cluster(k, n),
  }
  return {name: with_dtype(gen) for name, gen in raw.items()}
def _make_id_fn(n):
  """Matrix-free identity operator on n-vectors.

  Returns:
    (fn, eigs, m): the identity function, n eigenvalues equal to 1, and an
    iteration budget of 5.
  """
  def identity(x):
    return x
  return identity, np.ones(n), 5
def _make_diag_fn(diagonal, m):
  """Matrix-free diagonal operator.

  Args:
    diagonal: 1-D array of diagonal entries (also the eigenvalues).
    m: iteration budget, passed through unchanged.

  Returns:
    (fn, diagonal, m) where fn(x) multiplies x elementwise by `diagonal`
    cast to x's dtype. `diagonal` is returned as-is, not copied.
  """
  def scale(x):
    return diagonal.astype(x.dtype) * x
  return scale, diagonal, m
def _make_ring_fn(n, m):
  """Matrix-free Laplacian of the n-node cycle (ring) graph.

  Args:
    n: number of nodes.
    m: iteration budget, passed through unchanged.

  Returns:
    (fn, eigs, m): fn maps a length-n vector x to L x, and eigs holds all n
    eigenvalues of L in descending order.
  """
  # The action wraps around through jnp.roll, so L is the circulant cycle
  # Laplacian. Its exact spectrum is 2 * (1 - cos(2 * pi * j / n)),
  # j = 0..n-1, and most eigenvalues occur twice. _make_ring's eigenvalues
  # must not be reused here: its toeplitz matrix has no wrap-around entries
  # (a path graph), so its spectrum differs from this operator's.
  eigs = 2 * (1 - np.cos(2 * np.pi * np.arange(n) / n))
  eigs = np.sort(eigs)[::-1]

  def ring_action(x):
    degree = 2 * x
    lnbr = jnp.roll(x, 1)
    rnbr = jnp.roll(x, -1)
    return degree - lnbr - rnbr

  return ring_action, eigs, m
def _make_randn_fn(n, k, m):
  """Low-rank operator x -> B (B^T x) for a fixed random (n, k) matrix B.

  B is drawn from a NumPy generator seeded with 1234, so it is the same on
  every call.

  Returns:
    (fn, eigs, m): fn applies B B^T at HIGHEST precision in x's dtype; eigs
    are the squared singular values of B as returned by np.linalg.svd.
  """
  factor = np.random.default_rng(1234).standard_normal((n, k))
  highest = jax.lax.Precision.HIGHEST

  def randn_action(x):
    b = jnp.array(factor, dtype=x.dtype)
    return b.dot(b.T.dot(x, precision=highest), precision=highest)

  singular_values = np.linalg.svd(factor, full_matrices=False)[1]
  return randn_action, singular_values ** 2, m
def _make_sparse_fn(n, fill):
  """Matrix-free operator x -> S^T (S x) for a random sparse (n, n) matrix S.

  S has max(int(n**2 * fill), 1) nonzeros at distinct positions, drawn from
  a NumPy generator seeded with 1234.

  Returns:
    (fn, eigs, 20): fn applies S^T S using BCOO products in x's dtype; eigs
    are the squared singular values of S; 20 is the iteration budget.
  """
  rng = np.random.default_rng(1234)
  num_slots = n ** 2
  num_filled = max(int(num_slots * fill), 1)
  flat_idx = rng.choice(num_slots, size=num_filled, replace=False)
  rows, cols = divmod(flat_idx, n)
  values = rng.standard_normal(len(flat_idx))
  coo = sps.coo_matrix((values, (rows, cols)), shape=(n, n))
  matmul_dims = (((1,), (0,)), ((), ()))  # Good old fashioned matmul.

  def sparse_action(x):
    mat = bcoo.BCOO.from_scipy_sparse(coo.astype(np.dtype(x.dtype)))
    y = bcoo.bcoo_dot_general(mat, x, dimension_numbers=matmul_dims)
    return bcoo.bcoo_dot_general(
        mat.transpose(), y, dimension_numbers=matmul_dims)

  singular_values = np.linalg.svd(coo.todense(), full_matrices=False)[1]
  return sparse_action, singular_values ** 2, 20
def _callable_generators(dtype):
  """Returns matrix-free test operators keyed by display name.

  Each value is a tuple (fn, eigs, n, m):
    fn: the factory's vector action lifted with jax.vmap so it acts on every
      column of an (n, k) block.
    eigs: the 10 largest reference eigenvalues reported by the factory, in
      descending order, cast to `dtype`.
    n: the operator dimension (100).
    m: the factory's iteration budget, tripled for dtypes wider than 32 bits.
  """
  n = 100
  topk = 10
  d = {
      'id': _make_id_fn(n),
      'linear cond=1k': _make_diag_fn(np.linspace(1, 1000, n), 40),
      'linear cond=100k': _make_diag_fn(np.linspace(1, 100 * 1000, n), 40),
      'geom cond=1k': _make_diag_fn(np.logspace(0, 3, n), 20),
      'geom cond=100k': _make_diag_fn(np.logspace(0, 5, n), 20),
      'ring laplacian': _make_ring_fn(n, 40),
      'randn': _make_randn_fn(n, topk, 40),
      'sparse 1%': _make_sparse_fn(n, 0.01),
      'sparse 10%': _make_sparse_fn(n, 0.10),
  }
  wide_dtype = jtu.num_float_bits(dtype) > 32
  ret = {}
  for name, (vec_mul_fn, eigs, m) in d.items():
    if wide_dtype:
      m *= 3
    # np.sort returns a copy. Sorting `eigs` in place would also reorder any
    # array the operator closes over -- _make_diag_fn, for example, returns
    # the very array its action multiplies by.
    top_eigs = np.sort(eigs)[::-1][:topk].astype(dtype)
    # Note we must lift the vector multiply into matmul
    fn = jax.vmap(vec_mul_fn, in_axes=1, out_axes=1)
    ret[name] = (fn, top_eigs, n, m)
  return ret
@jtu.with_config(
    jax_enable_checks=True,
    jax_debug_nans=True,
    jax_numpy_rank_promotion='raise',
    jax_traceback_filtering='off')
@jtu.thread_unsafe_test_class()  # matplotlib isn't thread-safe
class LobpcgTest(jtu.JaxTestCase):
  """Shared LOBPCG checks, invoked by dtype-specific test subclasses."""

  def checkLobpcgConsistency(self, matrix_name, n, k, m, tol, dtype):
    """Checks lobpcg_standard on a dense generator against its spectrum.

    Args:
      matrix_name: key into `_concrete_generators`.
      n: matrix dimension.
      k: block size (number of eigenpairs solved for).
      m: maximum number of lobpcg iterations; must not be exhausted.
      tol: lobpcg convergence tolerance, also used as the residual bound
        below; None means the bound is the dtype's machine epsilon.
      dtype: float dtype for the matrix and the starting block.
    """
    A, eigs = _concrete_generators(dtype)[matrix_name](n, k)
    X = self.rng().standard_normal(size=(n, k)).astype(dtype)
    A, X = (jnp.array(arr, dtype=dtype) for arr in (A, X))
    theta, U, iters = linalg.lobpcg_standard(A, X, m, tol)
    self.assertDtypesMatch(theta, A)
    self.assertDtypesMatch(U, A)
    self.assertLess(
        iters, m,
        msg=f'expected early convergence iters {int(iters)} < max {m}')
    issorted = theta[:-1] >= theta[1:]
    all_true = np.ones_like(issorted).astype(bool)
    self.assertArraysEqual(issorted, all_true)
    k = X.shape[1]
    relerr = np.abs(theta - eigs[:k]) / eigs[:k]
    # Both bounds are computed once, before the loop. Reusing the name `tol`
    # for the relative-error bound inside the loop would feed it back into
    # `eps`, loosening every check after the first eigenvalue.
    eps = float(jnp.finfo(dtype).eps) if tol is None else tol
    # There's no real guarantee we can be within x% of the true eigenvalue.
    # However, for these simple unit test examples this should be met.
    eig_tol = float(np.sqrt(eps)) * 10
    for j in range(k):
      # The self-consistency property should be ensured.
      u = np.asarray(U[:, j], dtype=A.dtype)
      t = float(theta[j])
      Au = A.dot(u)
      resid = Au - t * u
      resid_norm = np.linalg.norm(resid)
      vector_norm = np.linalg.norm(Au)
      adjusted_error = resid_norm / n / (t + vector_norm) / 10
      self.assertLessEqual(
          adjusted_error,
          eps,
          msg=f'convergence criterion for eigenvalue {j} not satisfied, '
          f'floating point error {adjusted_error} not <= {eps}')
      self.assertLessEqual(
          relerr[j],
          eig_tol,
          msg=f'expected relative error within {eig_tol}, '
          f'was {float(relerr[j])}'
          f' for eigenvalue {j} (actual {float(theta[j])}, '
          f'expected {float(eigs[j])})')

  def checkLobpcgMonotonicity(self, matrix_name, n, k, m, tol, dtype):
    """Checks that the debug eigenvalue history does not diverge.

    Runs the dense-matrix solver with tol=0 and debug info, then requires
    the late worst-case eigenvalue relative error to stay below 10x the
    early one. `tol` is accepted for signature parity and ignored.
    """
    del tol
    A, eigs = _concrete_generators(dtype)[matrix_name](n, k)
    X = self.rng().standard_normal(size=(n, k)).astype(dtype)
    _theta, _U, _i, info = linalg._lobpcg_standard_matrix(
        A, X, m, tol=0, debug=True)
    self.assertArraysEqual(info['X zeros'], jnp.zeros_like(info['X zeros']))
    # To check for any divergence, make sure that the last 20% of
    # steps have lower worst-case relerr than first 20% of steps,
    # at least up to an order of magnitude.
    #
    # This is non-trivial, as many implementations have catastrophic
    # cancellations at convergence for residual terms, and rely on
    # brittle locking tolerance to avoid divergence.
    eigs = eigs[:k]
    relerrs = np.abs(np.array(info['lambda history']) - eigs) / eigs
    few_steps = max(m // 5, 1)
    self.assertLess(
        relerrs[-few_steps:].max(axis=1).mean(),
        10 * relerrs[:few_steps].max(axis=1).mean())
    self._possibly_plot(A, eigs, X, m, matrix_name)

  def _possibly_plot(self, A, eigs, X, m, matrix_name):
    """Re-runs lobpcg with debug info and saves plots, if enabled.

    Does nothing unless the environment sets LOBPCG_EMIT_DEBUG_PLOTS=1. In
    that case TEST_UNDECLARED_OUTPUTS_DIR must name the output directory.
    """
    if os.getenv('LOBPCG_EMIT_DEBUG_PLOTS', '0') != '1':
      return
    if isinstance(A, (np.ndarray, jax.Array)):
      lobpcg = linalg._lobpcg_standard_matrix
    else:
      lobpcg = linalg._lobpcg_standard_callable
    _theta, _U, _i, info = lobpcg(A, X, m, tol=0, debug=True)
    plot_dir = os.getenv('TEST_UNDECLARED_OUTPUTS_DIR')
    assert plot_dir, 'expected TEST_UNDECLARED_OUTPUTS_DIR for lobpcg plots'
    self._debug_plots(X, eigs, info, matrix_name, plot_dir)

  def _debug_plots(self, X, eigs, info, matrix_name, lobpcg_debug_plot_dir):
    """Saves a 4-panel PNG of lobpcg debug diagnostics.

    Args:
      X: starting block; its shape and dtype label the figure and file name.
      eigs: reference eigenvalues; eigs[0] is labelled lambda_max.
      info: debug dict returned by lobpcg with debug=True.
      matrix_name: human-readable name, sanitized for the file name.
      lobpcg_debug_plot_dir: output directory, created if missing.
    """
    # We import matplotlib lazily because (a) it's faster this way, and
    # (b) concurrent imports of matplotlib appear to trigger some sort of
    # collision on the matplotlib cache lock on Windows.
    try:
      from matplotlib import pyplot as plt
    except ImportError:  # Also covers its subclass ModuleNotFoundError.
      return  # If matplotlib isn't available, don't emit plots.
    os.makedirs(lobpcg_debug_plot_dir, exist_ok=True)
    clean_matrix_name = _clean_matrix_name(matrix_name)
    n, k = X.shape
    dt = 'f32' if X.dtype == np.float32 else 'f64'
    figpath = os.path.join(
        lobpcg_debug_plot_dir,
        f'{clean_matrix_name}_n{n}_k{k}_{dt}.png')
    plt.switch_backend('Agg')
    fig, (ax0, ax1, ax2, ax3) = plt.subplots(1, 4, figsize=(24, 4))
    # Close the figure even if plotting or saving raises, so pyplot does not
    # keep accumulating open figures across test cases.
    try:
      fig.suptitle(fr'{matrix_name} ${n=},{k=}$, {dt}')
      line_styles = [':', '--', '-.', '-']

      for key, ls in zip(['X orth', 'P orth', 'P.X'], line_styles):
        ax0.semilogy(info[key], ls=ls, label=key)
      ax0.set_title('basis average orthogonality')
      ax0.legend()

      relerrs = np.abs(np.array(info['lambda history']) - eigs) / eigs
      keys = ['max', 'p50', 'min']
      fns = [np.max, np.median, np.min]
      for key, fn, ls in zip(keys, fns, line_styles):
        ax1.semilogy(fn(relerrs, axis=1), ls=ls, label=key)
      ax1.set_title('eigval relerr')
      ax1.legend()

      for key, ls in zip(['basis rank', 'converged', 'P zeros'], line_styles):
        ax2.plot(info[key], ls=ls, label=key)
      ax2.set_title('basis dimension counts')
      ax2.legend()

      prefix = 'adjusted residual'
      for key, ls in zip(keys, line_styles):
        ax3.semilogy(info[prefix + ' ' + key], ls=ls, label=key)
      ax3.axhline(np.finfo(X.dtype).eps, label='eps', c='k')
      ax3.legend()
      ax3.set_title(prefix + rf' $\lambda_{{\max}}=\ ${eigs[0]:.1e}')

      fig.savefig(figpath, bbox_inches='tight')
    finally:
      plt.close(fig)

  def checkApproxEigs(self, example_name, dtype):
    """Checks a matrix-free operator's top eigenpairs after all iterations.

    With tol=0 lobpcg must use every one of its m iterations. Eigenvalue
    relative errors, eigenvector norms and pairwise orthogonality are then
    checked in float64.
    """
    fn, eigs, n, m = _callable_generators(dtype)[example_name]
    k = len(eigs)
    X = self.rng().standard_normal(size=(n, k)).astype(dtype)
    theta, U, iters = linalg.lobpcg_standard(fn, X, m, tol=0.0)
    # Given tolerance is zero all iters should be used.
    self.assertEqual(iters, m)
    # Evaluate in f64.
    as64 = functools.partial(np.array, dtype=np.float64)
    theta, eigs, U = (as64(x) for x in (theta, eigs, U))
    relerr = np.abs(theta - eigs) / eigs
    UTU = U.T.dot(U)
    tol = np.sqrt(jnp.finfo(dtype).eps) * 100
    # The ring example uses a fixed bound instead of the dtype-derived one.
    if example_name == 'ring laplacian':
      tol = 1e-2
    for i in range(k):
      self.assertLessEqual(
          relerr[i], tol,
          msg=f'eigenvalue {i} (actual {theta[i]} expected {eigs[i]})')
      self.assertAllClose(UTU[i, i], 1.0, rtol=tol)
      # Zero the diagonal entry so the whole row must be ~0 (orthogonality).
      UTU[i, i] = 0
      self.assertArraysAllClose(UTU[i], np.zeros_like(UTU[i]), atol=tol)
    self._possibly_plot(fn, eigs, X, m, 'callable_' + example_name)
class F32LobpcgTest(LobpcgTest):
  """float32 LOBPCG tests; setUp skips the whole class on GPU."""

  def setUp(self):
    # TODO(phawkins): investigate this failure
    on_gpu = jtu.test_device_matches(["gpu"])
    if on_gpu:
      raise unittest.SkipTest("Test is failing on CUDA gpus")
    super().setUp()

  def testLobpcgValidatesArguments(self):
    """lobpcg_standard rejects malformed operator / starting-block inputs."""
    dim, block_size = 100, 10
    ident, _ = _concrete_generators(np.float32)['id'](dim, block_size)
    x0 = self.rng().standard_normal(size=(dim, block_size)).astype(np.float32)

    # Empty search block.
    with self.assertRaisesRegex(ValueError, 'search dim > 0'):
      linalg.lobpcg_standard(ident, x0[:, :0])
    # Operator output dtype differs from the starting block's dtype.
    with self.assertRaisesRegex(ValueError, 'A, X must have same dtypes'):
      linalg.lobpcg_standard(
          lambda v: jnp.array(ident).dot(v).astype(jnp.float16), x0)
    # Matrix shape inconsistent with the starting block.
    with self.assertRaisesRegex(ValueError, r'A must be \(100, 100\)'):
      linalg.lobpcg_standard(ident[:60, :], x0)
    # Block size too large for the matrix dimension.
    with self.assertRaisesRegex(ValueError, r'search dim \* 5 < matrix dim'):
      linalg.lobpcg_standard(ident[:50, :50], x0[:50])

  @parameterized.named_parameters(_make_concrete_cases(f64=False))
  @jtu.skip_on_devices("gpu")
  def testLobpcgConsistencyF32(self, matrix_name, n, k, m, tol):
    self.checkLobpcgConsistency(matrix_name, n, k, m, tol, dtype=jnp.float32)

  @parameterized.named_parameters(_make_concrete_cases(f64=False))
  def testLobpcgMonotonicityF32(self, matrix_name, n, k, m, tol):
    self.checkLobpcgMonotonicity(matrix_name, n, k, m, tol, dtype=jnp.float32)

  @parameterized.named_parameters(_make_callable_cases(f64=False))
  def testCallableMatricesF32(self, matrix_name):
    self.checkApproxEigs(matrix_name, dtype=jnp.float32)
@jtu.with_config(jax_enable_x64=True)
class F64LobpcgTest(LobpcgTest):
  """float64 LOBPCG tests (x64 enabled); setUp skips the class on GPU."""

  def setUp(self):
    # TODO(phawkins): investigate this failure
    on_gpu = jtu.test_device_matches(["gpu"])
    if on_gpu:
      raise unittest.SkipTest("Test is failing on CUDA gpus")
    super().setUp()

  @parameterized.named_parameters(_make_concrete_cases(f64=True))
  @jtu.skip_on_devices("tpu", "gpu")
  def testLobpcgConsistencyF64(self, matrix_name, n, k, m, tol):
    self.checkLobpcgConsistency(matrix_name, n, k, m, tol, dtype=jnp.float64)

  @parameterized.named_parameters(_make_concrete_cases(f64=True))
  @jtu.skip_on_devices("tpu", "gpu")
  def testLobpcgMonotonicityF64(self, matrix_name, n, k, m, tol):
    self.checkLobpcgMonotonicity(matrix_name, n, k, m, tol, dtype=jnp.float64)

  @parameterized.named_parameters(_make_callable_cases(f64=True))
  @jtu.skip_on_devices("tpu", "gpu")
  def testCallableMatricesF64(self, matrix_name):
    self.checkApproxEigs(matrix_name, dtype=jnp.float64)
if __name__ == '__main__':
  # Parse JAX config flags via absl before handing control to the runner.
  jax.config.parse_flags_with_absl()
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)