mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Copybara import of the project:
-- ed4916f7c8e29fb65af3c1d7ce41b8c19a26f2f4 by Peter Hawkins <phawkins@google.com>: Fix jax2tf eigh tests in preparation for enabling complex eigh lowering in XLA. COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/5892 from hawkinsp:eigh2 ed4916f7c8e29fb65af3c1d7ce41b8c19a26f2f4 PiperOrigin-RevId: 360327930
This commit is contained in:
parent
ad260af96b
commit
f4507d366d
@ -546,13 +546,13 @@ class Jax2TfLimitation(primitive_harness.Limitation):
|
||||
# TODO(bchetioui): tolerance needs to be very high in compiled mode,
|
||||
# specifically for eigenvectors.
|
||||
if dtype == np.float64:
|
||||
tol = 1e-6
|
||||
tol = 2e-5
|
||||
elif dtype == np.float32:
|
||||
tol = 1e-2
|
||||
elif dtype in [dtypes.bfloat16, np.complex64]:
|
||||
tol = 1e-3
|
||||
elif dtype == np.complex128:
|
||||
tol = 1e-13
|
||||
tol = 2e-5
|
||||
tst.assertAllClose(
|
||||
np.matmul(a, vr) - w[..., None, :] * vr,
|
||||
np.zeros(a.shape, dtype=vr.dtype),
|
||||
@ -563,7 +563,7 @@ class Jax2TfLimitation(primitive_harness.Limitation):
|
||||
if dtype in [dtypes.bfloat16, np.float32, np.complex64]:
|
||||
tol = 1e-3
|
||||
elif dtype in [np.float64, np.complex128]:
|
||||
tol = 1e-11
|
||||
tol = 1e-5
|
||||
closest_diff = min(abs(eigenvalues_array - eigenvalue))
|
||||
tst.assertAllClose(
|
||||
closest_diff, np.array(0., closest_diff.dtype), atol=tol)
|
||||
@ -581,18 +581,18 @@ class Jax2TfLimitation(primitive_harness.Limitation):
|
||||
|
||||
return [
|
||||
# See https://github.com/google/jax/pull/3775#issuecomment-659407824;
|
||||
# TODO(b/181414529): enable after XLA/GPU bug is fixed.
|
||||
Jax2TfLimitation(
|
||||
"function not compilable",
|
||||
"XLA lowering bug",
|
||||
dtypes=[np.complex64, np.complex128],
|
||||
devices=("gpu",),
|
||||
modes="compiled",
|
||||
skip_tf_run=True),
|
||||
Jax2TfLimitation(
|
||||
"function not yet compilable",
|
||||
dtypes=[np.complex64, np.complex128],
|
||||
modes="compiled",
|
||||
enabled=(shape[0] > 0)),
|
||||
Jax2TfLimitation(
|
||||
"TODO: numeric discrepancies",
|
||||
dtypes=[np.float64],
|
||||
modes="compiled",
|
||||
devices=("cpu", "gpu"),
|
||||
expect_tf_error=False,
|
||||
skip_comparison=True),
|
||||
skip_tf_run=True),
|
||||
Jax2TfLimitation(
|
||||
"TODO: numeric discrepancies",
|
||||
dtypes=[np.float16],
|
||||
|
Loading…
x
Reference in New Issue
Block a user