[jax2tf] Fix jax2tf tolerance for Cholesky, needed for newer TPUs

PiperOrigin-RevId: 652751771
This commit is contained in:
George Necula 2024-07-16 02:12:27 -07:00 committed by jax authors
parent d34a6e9ce2
commit e7be205a39

View File

@ -372,6 +372,13 @@ class Jax2TfLimitation(test_harnesses.Limitation):
tol=5e-2,
devices=("cpu", "gpu"),
modes=("eager", "graph", "compiled")),
custom_numeric(
dtypes=[dtypes.bfloat16],
tol=5e-5,
# Error for GL
devices=("tpu",),
modes=("eager", "graph", "compiled"),
native_serialization=Jax2TfLimitation.FOR_NATIVE),
custom_numeric(
custom_assert=custom_assert,
description=(