mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax2tf] Support bfloat16 for reduce_window when enable_xla=False
PiperOrigin-RevId: 599841265
This commit is contained in:
parent
ab3c1b5146
commit
017a0d83a9
@ -2521,7 +2521,9 @@ def requires_xla_for_reduce(name, dtype):
|
||||
return True
|
||||
if name == "min" and dtype in [np.uint8, np.uint16]:
|
||||
return True
|
||||
if name == "add" and dtype not in [np.float16, np.float32, np.float64]:
|
||||
if name == "add" and dtype not in [
|
||||
dtypes.bfloat16, np.float16, np.float32, np.float64
|
||||
]:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
@ -547,6 +547,7 @@ def _validate_reduce_window_inputs(operand_shape, computation_name, dtype,
|
||||
# tf.math.reduce_min.
|
||||
raise _reduce_error(f"Min pool does not support operands of type {dtype}")
|
||||
if computation_name == "add" and dtype not in [
|
||||
tf.bfloat16,
|
||||
tf.float16,
|
||||
tf.float32,
|
||||
tf.float64,
|
||||
|
@ -1076,7 +1076,7 @@ class Jax2TfLimitation(test_harnesses.Limitation):
|
||||
tol=3e-5),
|
||||
Jax2TfLimitation(
|
||||
"Large deviations on TPU for enable_xla=False",
|
||||
dtypes=[np.float16, np.float32],
|
||||
dtypes=[dtypes.bfloat16, np.float16, np.float32],
|
||||
devices="tpu",
|
||||
modes=("eager", "graph", "compiled"),
|
||||
expect_tf_error=False,
|
||||
@ -1086,6 +1086,8 @@ class Jax2TfLimitation(test_harnesses.Limitation):
|
||||
modes=("eager", "graph", "compiled",), tol=1e-5),
|
||||
custom_numeric(devices=("cpu", "gpu"), dtypes=[np.float16],
|
||||
modes=("eager", "graph", "compiled",), tol=5e-3),
|
||||
custom_numeric(devices=("cpu", "gpu"), dtypes=[dtypes.bfloat16],
|
||||
modes=("eager", "graph", "compiled",), tol=5e-1),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
|
Loading…
x
Reference in New Issue
Block a user