[jax2tf] Support bfloat16 for reduce_window when enable_xla=False

PiperOrigin-RevId: 599841265
Author: Kevin Chen, 2024-01-19 08:24:30 -08:00 (committed by jax authors)
parent ab3c1b5146
commit 017a0d83a9
3 changed files with 7 additions and 2 deletions
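For context, a minimal sketch of the conversion this change enables. The pooling shape, function names, and the expectation that this exact pattern hits the no-XLA reduce_window lowering are illustrative assumptions, not taken from the commit (note also that enable_xla=False was later deprecated in jax2tf):

import numpy as np
import jax.numpy as jnp
from jax import lax
from jax.experimental import jax2tf

def sum_pool_2x2(x):
  # reduce_window with an "add" computation: a 2x2 sum pool over NHWC input.
  return lax.reduce_window(x, jnp.array(0, x.dtype), lax.add,
                           window_dimensions=(1, 2, 2, 1),
                           window_strides=(1, 2, 2, 1),
                           padding="VALID")

# Before this change, converting with enable_xla=False rejected bfloat16 here.
tf_fn = jax2tf.convert(sum_pool_2x2, enable_xla=False)
x = np.ones((1, 4, 4, 1), dtype=jnp.bfloat16)
print(tf_fn(x))  # each output element should be ~4.0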


@@ -2521,7 +2521,9 @@ def requires_xla_for_reduce(name, dtype):
     return True
   if name == "min" and dtype in [np.uint8, np.uint16]:
     return True
-  if name == "add" and dtype not in [np.float16, np.float32, np.float64]:
+  if name == "add" and dtype not in [
+      dtypes.bfloat16, np.float16, np.float32, np.float64
+  ]:
     return True
   return False
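The test harness presumably consults this predicate to skip no-XLA coverage for reducer/dtype pairs the no-XLA path cannot handle, so removing bfloat16 from the "add" case opts those harnesses into the no-XLA tests. For illustration, a standalone, trimmed copy of the predicate's post-change behavior (the direct calls are hypothetical; ml_dtypes is the package backing JAX's dtypes.bfloat16):

import numpy as np
import ml_dtypes

def requires_xla_for_reduce(name, dtype):
  # Trimmed copy of the logic above, for illustration only.
  if name == "min" and dtype in [np.uint8, np.uint16]:
    return True
  if name == "add" and dtype not in [
      ml_dtypes.bfloat16, np.float16, np.float32, np.float64
  ]:
    return True
  return False

assert not requires_xla_for_reduce("add", ml_dtypes.bfloat16)  # newly allowed
assert requires_xla_for_reduce("add", np.int32)                # still needs XLA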


@@ -547,6 +547,7 @@ def _validate_reduce_window_inputs(operand_shape, computation_name, dtype,
     # tf.math.reduce_min.
     raise _reduce_error(f"Min pool does not support operands of type {dtype}")
   if computation_name == "add" and dtype not in [
+      tf.bfloat16,
       tf.float16,
       tf.float32,
       tf.float64,

@@ -1076,7 +1076,7 @@ class Jax2TfLimitation(test_harnesses.Limitation):
             tol=3e-5),
         Jax2TfLimitation(
             "Large deviations on TPU for enable_xla=False",
-            dtypes=[np.float16, np.float32],
+            dtypes=[dtypes.bfloat16, np.float16, np.float32],
             devices="tpu",
             modes=("eager", "graph", "compiled"),
             expect_tf_error=False,
@@ -1086,6 +1086,8 @@ class Jax2TfLimitation(test_harnesses.Limitation):
             modes=("eager", "graph", "compiled",), tol=1e-5),
         custom_numeric(devices=("cpu", "gpu"), dtypes=[np.float16],
                        modes=("eager", "graph", "compiled",), tol=5e-3),
+        custom_numeric(devices=("cpu", "gpu"), dtypes=[dtypes.bfloat16],
+                       modes=("eager", "graph", "compiled",), tol=5e-1),
     ]

   @classmethod
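The loose 5e-1 tolerance for bfloat16 tracks its precision: with only 7 stored mantissa bits (machine epsilon 2**-7), a windowed sum accumulates far more rounding error than float16, which gets the tighter 5e-3 above. A quick check of those epsilons (using the ml_dtypes package that backs JAX's bfloat16):

import numpy as np
import ml_dtypes

print(np.finfo(ml_dtypes.bfloat16).eps)  # 0.0078125    == 2**-7
print(np.finfo(np.float16).eps)          # 0.0009765625 == 2**-10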