Always suppress the differing_executors Hypothesis health check

It's only relevant to notify about potential key collisions in the example
database, but we explicitly disable it, so it doesn't matter.

PiperOrigin-RevId: 707914664
This commit is contained in:
Adam Paszke 2024-12-19 08:00:12 -08:00 committed by jax authors
parent 3915f4a147
commit 23000a3842
2 changed files with 17 additions and 5 deletions

View File

@ -2147,6 +2147,22 @@ def setup_hypothesis(max_examples=30) -> None:
except (ModuleNotFoundError, ImportError):
return
# In our tests we often use subclasses with slightly different class variables
# to generate whole suites of parameterized tests, but this approach does not
# work well with Hypothesis databases, which use some function of the method
# identity to generate keys. But, if the method is defined in a superclass,
# all subclasses share the same key. This key collision can lead to confusing
# false positives in other health checks.
#
# Still, as far as I understand, for as long as we don't use the example
# database, it should be perfectly safe to suppress this health check. This
# seems simpler than rewriting our tests that trigger this behavior. See
# the end of https://github.com/HypothesisWorks/hypothesis/issues/3446 for
# more context.
suppressed_checks = []
if hasattr(hp.HealthCheck, "differing_executors"):
suppressed_checks.append(hp.HealthCheck.differing_executors)
hp.settings.register_profile(
"deterministic",
database=None,
@ -2154,6 +2170,7 @@ def setup_hypothesis(max_examples=30) -> None:
deadline=None,
max_examples=max_examples,
print_blob=True,
suppress_health_check=suppressed_checks,
)
hp.settings.register_profile(
"interactive",

View File

@ -51,10 +51,6 @@ Slice = indexing.Slice
NDIndexer = indexing.NDIndexer
ds = indexing.ds
HP_DIFFERING_EXECUTORS = []
if hasattr(hp.HealthCheck, "differing_executors"):
HP_DIFFERING_EXECUTORS = [hp.HealthCheck.differing_executors]
_INDEXING_TEST_CASES = [
((4, 8, 128), (...,), (4, 8, 128)),
@ -375,7 +371,6 @@ class IndexerOpsTest(PallasBaseTest):
np.testing.assert_array_equal(left_out_np, left_out)
np.testing.assert_array_equal(right_out_np, right_out)
@hp.settings(suppress_health_check=HP_DIFFERING_EXECUTORS)
@hp.given(hps.data())
def test_vmap_nd_indexing(self, data):
self.skipTest("TODO(necula): enable this test; was in jax_triton.")