Temporarily disable flaky nextafter tests

These are currently failing at HEAD due to 72f10f7eb5

We can re-enable once b9483d30a7 is integrated.

PiperOrigin-RevId: 601788984
This commit is contained in:
Jake VanderPlas 2024-01-26 09:35:30 -08:00 committed by jax authors
parent f34bcc326b
commit 1ae054b003
2 changed files with 11 additions and 9 deletions

View File

@ -160,13 +160,14 @@ def lax_ops():
op_record("floor", 1, float_dtypes, test_util.rand_small),
op_record("ceil", 1, float_dtypes, test_util.rand_small),
op_record("round", 1, float_dtypes, test_util.rand_default),
op_record(
"nextafter",
2,
[f for f in float_dtypes if f != dtypes.bfloat16],
test_util.rand_default,
tol=0,
),
# TODO(b/322390905) re-enable this test
# op_record(
# "nextafter",
# 2,
# [f for f in float_dtypes if f != dtypes.bfloat16],
# test_util.rand_default,
# tol=0,
# ),
op_record("is_finite", 1, float_dtypes, test_util.rand_small),
op_record("exp", 1, float_dtypes + complex_dtypes, test_util.rand_small),
op_record("exp2", 1, float_dtypes + complex_dtypes, test_util.rand_small),

View File

@ -124,8 +124,9 @@ JAX_ONE_TO_ONE_OP_RECORDS = [
op_record("minimum", 2, all_dtypes, all_shapes, jtu.rand_some_inf, []),
op_record("multiply", 2, all_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("nextafter", 2, [f for f in float_dtypes if f != jnp.bfloat16],
all_shapes, jtu.rand_default, ["rev"], inexact=True, tolerance=0),
# TODO(b/322390905) re-enable this test
# op_record("nextafter", 2, [f for f in float_dtypes if f != jnp.bfloat16],
# all_shapes, jtu.rand_default, ["rev"], inexact=True, tolerance=0),
op_record("not_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
op_record("array_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
op_record("array_equiv", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),