mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax2tf] Fix jax2tf tolerance for Cholesky, needed for newer TPUs
PiperOrigin-RevId: 652751771
This commit is contained in:
parent
d34a6e9ce2
commit
e7be205a39
@ -372,6 +372,13 @@ class Jax2TfLimitation(test_harnesses.Limitation):
|
||||
tol=5e-2,
|
||||
devices=("cpu", "gpu"),
|
||||
modes=("eager", "graph", "compiled")),
|
||||
custom_numeric(
|
||||
dtypes=[dtypes.bfloat16],
|
||||
tol=5e-5,
|
||||
# Error for GL
|
||||
devices=("tpu",),
|
||||
modes=("eager", "graph", "compiled"),
|
||||
native_serialization=Jax2TfLimitation.FOR_NATIVE),
|
||||
custom_numeric(
|
||||
custom_assert=custom_assert,
|
||||
description=(
|
||||
|
Loading…
x
Reference in New Issue
Block a user