Merge pull request #22441 from gnecula:test_clean_hypothesis

PiperOrigin-RevId: 652919414
This commit is contained in:
jax authors 2024-07-16 11:32:46 -07:00
commit 5ddec63a47
5 changed files with 76 additions and 22 deletions

View File

@ -547,6 +547,22 @@ python tests/lax_numpy_test.py --test_targets="testPad"
The Colab notebooks are tested for errors as part of the documentation build.
### Hypothesis tests
Some of the tests use [hypothesis](https://hypothesis.readthedocs.io/en/latest).
Normally, hypothesis will test using multiple example inputs, and on a test failure
it will try to find a smaller example that still results in failure:
Look through the test failure for a line like the one below, and add the decorator
mentioned in the message:
```
You can reproduce this example by temporarily adding @reproduce_failure('6.97.4', b'AXicY2DAAAAAEwAB') as a decorator on your test case
```
For interactive development, you can set the environment variable
`JAX_HYPOTHESIS_PROFILE=interactive` (or the equivalent flag `--jax_hypothesis_profile=interactive`)
in order to set the number of examples to 1, and skip the example
minimization phase.
### Doctests
JAX uses pytest in doctest mode to test the code examples within the documentation.

View File

@ -22,6 +22,7 @@ import datetime
import functools
from functools import partial
import inspect
import logging
import math
import os
import re
@ -107,6 +108,13 @@ TEST_WITH_PERSISTENT_COMPILATION_CACHE = config.bool_flag(
help='If enabled, the persistent compilation cache will be enabled for all '
'test cases. This can be used to increase compilation cache coverage.')
HYPOTHESIS_PROFILE = config.string_flag(
'hypothesis_profile',
os.getenv('JAX_HYPOTHESIS_PROFILE', 'deterministic'),
help=('Select the hypothesis profile to use for testing. Available values: '
'deterministic, interactive'),
)
# We sanitize test names to ensure they work with "unitttest -k" and
# "pytest -k" test filtering. pytest accepts '[' and ']' but unittest -k
# does not. We replace sequences of problematic characters with a single '_'.
@ -2036,3 +2044,47 @@ class numpy_with_mpmath:
return worker(ctx, scale, exact, reference, value)
else:
assert 0 # unreachable
# Hypothesis testing support
def setup_hypothesis(max_examples=30) -> None:
"""Sets up the hypothesis profiles.
Sets up the hypothesis testing profiles, and selects the one specified by
the ``JAX_HYPOTHESIS_PROFILE`` environment variable (or the
``--jax_hypothesis_profile`` configuration.
Args:
max_examples: the maximum number of hypothesis examples to try, when using
the default "deterministic" profile.
"""
try:
import hypothesis as hp # type: ignore
except (ModuleNotFoundError, ImportError):
return
hp.settings.register_profile(
"deterministic",
database=None,
derandomize=True,
deadline=None,
max_examples=max_examples,
print_blob=True,
)
hp.settings.register_profile(
"interactive",
parent=hp.settings.load_profile("deterministic"),
max_examples=1,
report_multiple_bugs=False,
verbosity=hp.Verbosity.verbose,
# Don't try and shrink
phases=(
hp.Phase.explicit,
hp.Phase.reuse,
hp.Phase.generate,
hp.Phase.target,
hp.Phase.explain,
),
)
profile = HYPOTHESIS_PROFILE.value
logging.info("Using hypothesis profile: %s", profile)
hp.settings.load_profile(profile)

View File

@ -35,12 +35,11 @@ except (ModuleNotFoundError, ImportError):
import hypothesis.extra.numpy as hnp
import hypothesis.strategies as hps
hp.settings.register_profile(
"deterministic", database=None, derandomize=True, deadline=None,
max_examples=100, print_blob=True)
hp.settings.load_profile("deterministic")
jax.config.parse_flags_with_absl()
jtu.setup_hypothesis(max_examples=100)
Slice = indexing.Slice
NDIndexer = indexing.NDIndexer

View File

@ -35,30 +35,15 @@ import numpy as np
try:
import hypothesis as hp
import hypothesis.strategies as hps
CAN_USE_HYPOTHESIS = True
except (ModuleNotFoundError, ImportError):
CAN_USE_HYPOTHESIS = False
if not CAN_USE_HYPOTHESIS:
raise unittest.SkipTest("tests require hypothesis")
raise unittest.SkipTest("these tests require hypothesis")
jax.config.parse_flags_with_absl()
hp.settings.register_profile(
"deterministic",
database=None,
derandomize=True,
deadline=None,
max_examples=30,
print_blob=True,
)
hp.settings.load_profile("deterministic")
jtu.setup_hypothesis()
partial = functools.partial
Draw = TypeVar("Draw", bound=Callable[[hps.SearchStrategy[Any]], Any])
@hps.composite
def segment_ids_strategy(draw, seq_len: int) -> splash.SegmentIds:
boundaries = hps.sets(hps.integers(1, seq_len - 1), min_size=1, max_size=4)
@ -466,7 +451,7 @@ class SplashAttentionTest(AttentionTest):
q_seq_len, kv_seq_len, head_dim, dtype = data.draw(attention_strategy())
# Avoid segment ids for rectangular matrices, as its hard to enforce
# Avoid segment ids for rectangular matrices, as it's hard to enforce
# valid masks (non-0 rows).
hp.assume(q_seq_len == kv_seq_len or not is_segmented)

View File

@ -51,6 +51,8 @@ from jax._src.state.types import (shaped_array_ref, ReadEffect, WriteEffect,
AccumEffect, AbstractRef)
config.parse_flags_with_absl()
jtu.setup_hypothesis()
class StatePrimitivesTest(jtu.JaxTestCase):