Copybara import of the project:

--
ed4916f7c8e29fb65af3c1d7ce41b8c19a26f2f4 by Peter Hawkins <phawkins@google.com>:

Fix jax2tf eigh tests in preparation for enabling complex eigh lowering in XLA.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/5892 from hawkinsp:eigh2 ed4916f7c8e29fb65af3c1d7ce41b8c19a26f2f4
PiperOrigin-RevId: 360327930
This commit is contained in:
Peter Hawkins 2021-03-01 18:57:38 -08:00 committed by jax authors
parent ad260af96b
commit f4507d366d

View File

@ -546,13 +546,13 @@ class Jax2TfLimitation(primitive_harness.Limitation):
# TODO(bchetioui): tolerance needs to be very high in compiled mode,
# specifically for eigenvectors.
if dtype == np.float64:
tol = 1e-6
tol = 2e-5
elif dtype == np.float32:
tol = 1e-2
elif dtype in [dtypes.bfloat16, np.complex64]:
tol = 1e-3
elif dtype == np.complex128:
tol = 1e-13
tol = 2e-5
tst.assertAllClose(
np.matmul(a, vr) - w[..., None, :] * vr,
np.zeros(a.shape, dtype=vr.dtype),
@ -563,7 +563,7 @@ class Jax2TfLimitation(primitive_harness.Limitation):
if dtype in [dtypes.bfloat16, np.float32, np.complex64]:
tol = 1e-3
elif dtype in [np.float64, np.complex128]:
tol = 1e-11
tol = 1e-5
closest_diff = min(abs(eigenvalues_array - eigenvalue))
tst.assertAllClose(
closest_diff, np.array(0., closest_diff.dtype), atol=tol)
@ -581,18 +581,18 @@ class Jax2TfLimitation(primitive_harness.Limitation):
return [
# See https://github.com/google/jax/pull/3775#issuecomment-659407824;
# TODO(b/181414529): enable after XLA/GPU bug is fixed.
Jax2TfLimitation(
"function not compilable",
"XLA lowering bug",
dtypes=[np.complex64, np.complex128],
devices=("gpu",),
modes="compiled",
skip_tf_run=True),
Jax2TfLimitation(
"function not yet compilable",
dtypes=[np.complex64, np.complex128],
modes="compiled",
enabled=(shape[0] > 0)),
Jax2TfLimitation(
"TODO: numeric discrepancies",
dtypes=[np.float64],
modes="compiled",
devices=("cpu", "gpu"),
expect_tf_error=False,
skip_comparison=True),
skip_tf_run=True),
Jax2TfLimitation(
"TODO: numeric discrepancies",
dtypes=[np.float16],