Merge pull request #27804 from vfdev-5:ft-adapt-state-test-2

PiperOrigin-RevId: 745341315
This commit is contained in:
jax authors 2025-04-08 16:48:54 -07:00
commit 373ac2ef7e
6 changed files with 25 additions and 97 deletions

View File

@ -69,3 +69,7 @@ race:gemm_oncopy
# https://github.com/python/cpython/issues/129547
# Maybe fixed?
# race:type_get_annotations
# https://github.com/python/cpython/issues/132013
# Fixed on 3.14 and not backported to 3.13
race_top:frozenset_hash

View File

@ -14,6 +14,7 @@ on:
paths:
- '**/workflows/tsan.yaml'
- '**/workflows/tsan-suppressions*.txt'
- '**/workflows/requirements_lock_3_13_ft.patch'
jobs:
tsan:

View File

@ -10,98 +10,12 @@ absl-py==2.1.0
attrs==24.3.0
auditwheel==6.2.0
build==1.2.2.post1
cloudpickle==3.1.1 # version 3.1.0 leads to recursion error
colorama==0.4.6
contourpy==1.3.1
cycler==0.12.1
etils[epath,epy]==1.11.0
execnet==2.1.1
filelock==3.16.1
flatbuffers==24.12.23
fonttools==4.56.0
fsspec==2024.12.0
hypothesis==6.123.9
importlib-resources==6.5.2
iniconfig==2.0.0
kiwisolver==1.4.8
markdown-it-py==3.0.0
matplotlib==3.10.1
mdurl==0.1.2
ml-dtypes==0.5.1
mpmath==1.3.0
nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux"
nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux"
nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux"
nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux"
nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux"
nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux"
nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux"
nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux"
nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux"
nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux"
opt-einsum==3.4.0
packaging==24.2
pillow==11.1.0
pluggy==1.5.0
portpicker==1.6.0
psutil==6.1.1
pyelftools==0.31
pygments==2.19.1
pyparsing==3.2.2 # version 3.2.1 fails with SyntaxError(originally SyntaxWarning): 'return' in a 'finally' block in pyparsing/core.py", line 5716
pyproject-hooks==1.2.0
pytest==8.3.4
pytest-xdist==3.6.1
python-dateutil==2.9.0.post0
rich==13.9.4
six==1.17.0
sortedcontainers==2.4.0
typing-extensions==4.12.2
flatbuffers==24.12.23
wheel==0.45.1
ml-dtypes==0.5.1
zipp==3.21.0
# python 3.14t can't compile 0.23.0
# due to https://github.com/indygreg/python-zstandard/issues/231
# zstandard==0.23.0
setuptools==70.3.0
opt-einsum==3.4.0

View File

@ -91,7 +91,7 @@ _MAX_CASES_SAMPLING_RETRIES = config.int_flag(
'sampling process is terminated.'
)
_SKIP_SLOW_TESTS = config.bool_flag(
SKIP_SLOW_TESTS = config.bool_flag(
'jax_skip_slow_tests',
config.bool_env('JAX_SKIP_SLOW_TESTS', False),
help='Skip tests marked as slow (> 5 sec).'

View File

@ -82,22 +82,27 @@ def get_zstandard():
return []
return ["@pypi_zstandard//:pkg"]
def get_optional_dep(package, excluded_py_versions = ["3.14", "3.14-ft"]):
if HERMETIC_PYTHON_VERSION in excluded_py_versions:
return []
return [package]
_py_deps = {
"absl/logging": ["@pypi_absl_py//:pkg"],
"absl/testing": ["@pypi_absl_py//:pkg"],
"absl/flags": ["@pypi_absl_py//:pkg"],
"cloudpickle": ["@pypi_cloudpickle//:pkg"],
"colorama": ["@pypi_colorama//:pkg"],
"epath": ["@pypi_etils//:pkg"], # etils.epath
"filelock": ["@pypi_filelock//:pkg"],
"cloudpickle": get_optional_dep("@pypi_cloudpickle//:pkg"),
"colorama": get_optional_dep("@pypi_colorama//:pkg"),
"epath": get_optional_dep("@pypi_etils//:pkg"), # etils.epath
"filelock": get_optional_dep("@pypi_filelock//:pkg"),
"flatbuffers": ["@pypi_flatbuffers//:pkg"],
"hypothesis": ["@pypi_hypothesis//:pkg"],
"magma": [],
"matplotlib": ["@pypi_matplotlib//:pkg"],
"matplotlib": get_optional_dep("@pypi_matplotlib//:pkg"),
"mpmath": [],
"opt_einsum": ["@pypi_opt_einsum//:pkg"],
"pil": ["@pypi_pillow//:pkg"],
"portpicker": ["@pypi_portpicker//:pkg"],
"pil": get_optional_dep("@pypi_pillow//:pkg"),
"portpicker": get_optional_dep("@pypi_portpicker//:pkg"),
"ml_dtypes": ["@pypi_ml_dtypes//:pkg"],
"numpy": ["@pypi_numpy//:pkg"],
"scipy": ["@pypi_scipy//:pkg"],

View File

@ -1826,6 +1826,10 @@ if CAN_USE_HYPOTHESIS:
y2 = random.normal(jax.random.clone(k1), y.shape)
self.assertAllClose(impl_vjp(t), ref_vjp(t))
if jtu.SKIP_SLOW_TESTS.value:
# Skip second order tests if JAX_SKIP_SLOW_TESTS=true
return
# Second order
key, k1, k2 = random.split(key, 3)
t2 = random.normal(k2, t.shape)