mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Merge pull request #22441 from gnecula:test_clean_hypothesis
PiperOrigin-RevId: 652919414
This commit is contained in:
commit
5ddec63a47
@ -547,6 +547,22 @@ python tests/lax_numpy_test.py --test_targets="testPad"
|
||||
|
||||
The Colab notebooks are tested for errors as part of the documentation build.
|
||||
|
||||
### Hypothesis tests
|
||||
|
||||
Some of the tests use [hypothesis](https://hypothesis.readthedocs.io/en/latest).
|
||||
Normally, hypothesis will test using multiple example inputs, and on a test failure
|
||||
it will try to find a smaller example that still results in failure:
|
||||
Look through the test failure for a line like the one below, and add the decorator
|
||||
mentioned in the message:
|
||||
```
|
||||
You can reproduce this example by temporarily adding @reproduce_failure('6.97.4', b'AXicY2DAAAAAEwAB') as a decorator on your test case
|
||||
```
|
||||
|
||||
For interactive development, you can set the environment variable
|
||||
`JAX_HYPOTHESIS_PROFILE=interactive` (or the equivalent flag `--jax_hypothesis_profile=interactive`)
|
||||
in order to set the number of examples to 1, and skip the example
|
||||
minimization phase.
|
||||
|
||||
### Doctests
|
||||
|
||||
JAX uses pytest in doctest mode to test the code examples within the documentation.
|
||||
|
@ -22,6 +22,7 @@ import datetime
|
||||
import functools
|
||||
from functools import partial
|
||||
import inspect
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
@ -107,6 +108,13 @@ TEST_WITH_PERSISTENT_COMPILATION_CACHE = config.bool_flag(
|
||||
help='If enabled, the persistent compilation cache will be enabled for all '
|
||||
'test cases. This can be used to increase compilation cache coverage.')
|
||||
|
||||
HYPOTHESIS_PROFILE = config.string_flag(
|
||||
'hypothesis_profile',
|
||||
os.getenv('JAX_HYPOTHESIS_PROFILE', 'deterministic'),
|
||||
help=('Select the hypothesis profile to use for testing. Available values: '
|
||||
'deterministic, interactive'),
|
||||
)
|
||||
|
||||
# We sanitize test names to ensure they work with "unitttest -k" and
|
||||
# "pytest -k" test filtering. pytest accepts '[' and ']' but unittest -k
|
||||
# does not. We replace sequences of problematic characters with a single '_'.
|
||||
@ -2036,3 +2044,47 @@ class numpy_with_mpmath:
|
||||
return worker(ctx, scale, exact, reference, value)
|
||||
else:
|
||||
assert 0 # unreachable
|
||||
|
||||
# Hypothesis testing support
|
||||
def setup_hypothesis(max_examples=30) -> None:
|
||||
"""Sets up the hypothesis profiles.
|
||||
|
||||
Sets up the hypothesis testing profiles, and selects the one specified by
|
||||
the ``JAX_HYPOTHESIS_PROFILE`` environment variable (or the
|
||||
``--jax_hypothesis_profile`` configuration.
|
||||
|
||||
Args:
|
||||
max_examples: the maximum number of hypothesis examples to try, when using
|
||||
the default "deterministic" profile.
|
||||
"""
|
||||
try:
|
||||
import hypothesis as hp # type: ignore
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
return
|
||||
|
||||
hp.settings.register_profile(
|
||||
"deterministic",
|
||||
database=None,
|
||||
derandomize=True,
|
||||
deadline=None,
|
||||
max_examples=max_examples,
|
||||
print_blob=True,
|
||||
)
|
||||
hp.settings.register_profile(
|
||||
"interactive",
|
||||
parent=hp.settings.load_profile("deterministic"),
|
||||
max_examples=1,
|
||||
report_multiple_bugs=False,
|
||||
verbosity=hp.Verbosity.verbose,
|
||||
# Don't try and shrink
|
||||
phases=(
|
||||
hp.Phase.explicit,
|
||||
hp.Phase.reuse,
|
||||
hp.Phase.generate,
|
||||
hp.Phase.target,
|
||||
hp.Phase.explain,
|
||||
),
|
||||
)
|
||||
profile = HYPOTHESIS_PROFILE.value
|
||||
logging.info("Using hypothesis profile: %s", profile)
|
||||
hp.settings.load_profile(profile)
|
||||
|
@ -35,12 +35,11 @@ except (ModuleNotFoundError, ImportError):
|
||||
|
||||
import hypothesis.extra.numpy as hnp
|
||||
import hypothesis.strategies as hps
|
||||
hp.settings.register_profile(
|
||||
"deterministic", database=None, derandomize=True, deadline=None,
|
||||
max_examples=100, print_blob=True)
|
||||
hp.settings.load_profile("deterministic")
|
||||
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
jtu.setup_hypothesis(max_examples=100)
|
||||
|
||||
|
||||
Slice = indexing.Slice
|
||||
NDIndexer = indexing.NDIndexer
|
||||
|
@ -35,30 +35,15 @@ import numpy as np
|
||||
try:
|
||||
import hypothesis as hp
|
||||
import hypothesis.strategies as hps
|
||||
CAN_USE_HYPOTHESIS = True
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
CAN_USE_HYPOTHESIS = False
|
||||
|
||||
if not CAN_USE_HYPOTHESIS:
|
||||
raise unittest.SkipTest("tests require hypothesis")
|
||||
|
||||
raise unittest.SkipTest("these tests require hypothesis")
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
hp.settings.register_profile(
|
||||
"deterministic",
|
||||
database=None,
|
||||
derandomize=True,
|
||||
deadline=None,
|
||||
max_examples=30,
|
||||
print_blob=True,
|
||||
)
|
||||
hp.settings.load_profile("deterministic")
|
||||
jtu.setup_hypothesis()
|
||||
|
||||
partial = functools.partial
|
||||
Draw = TypeVar("Draw", bound=Callable[[hps.SearchStrategy[Any]], Any])
|
||||
|
||||
|
||||
@hps.composite
|
||||
def segment_ids_strategy(draw, seq_len: int) -> splash.SegmentIds:
|
||||
boundaries = hps.sets(hps.integers(1, seq_len - 1), min_size=1, max_size=4)
|
||||
@ -466,7 +451,7 @@ class SplashAttentionTest(AttentionTest):
|
||||
|
||||
q_seq_len, kv_seq_len, head_dim, dtype = data.draw(attention_strategy())
|
||||
|
||||
# Avoid segment ids for rectangular matrices, as its hard to enforce
|
||||
# Avoid segment ids for rectangular matrices, as it's hard to enforce
|
||||
# valid masks (non-0 rows).
|
||||
hp.assume(q_seq_len == kv_seq_len or not is_segmented)
|
||||
|
||||
|
@ -51,6 +51,8 @@ from jax._src.state.types import (shaped_array_ref, ReadEffect, WriteEffect,
|
||||
AccumEffect, AbstractRef)
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jtu.setup_hypothesis()
|
||||
|
||||
|
||||
class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user