mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #27804 from vfdev-5:ft-adapt-state-test-2
PiperOrigin-RevId: 745341315
This commit is contained in:
commit
373ac2ef7e
4
.github/workflows/tsan-suppressions_3.13.txt
vendored
4
.github/workflows/tsan-suppressions_3.13.txt
vendored
@ -69,3 +69,7 @@ race:gemm_oncopy
|
||||
# https://github.com/python/cpython/issues/129547
|
||||
# Maybe fixed?
|
||||
# race:type_get_annotations
|
||||
|
||||
# https://github.com/python/cpython/issues/132013
|
||||
# Fixed on 3.14 and not backported to 3.13
|
||||
race_top:frozenset_hash
|
1
.github/workflows/tsan.yaml
vendored
1
.github/workflows/tsan.yaml
vendored
@ -14,6 +14,7 @@ on:
|
||||
paths:
|
||||
- '**/workflows/tsan.yaml'
|
||||
- '**/workflows/tsan-suppressions*.txt'
|
||||
- '**/workflows/requirements_lock_3_13_ft.patch'
|
||||
|
||||
jobs:
|
||||
tsan:
|
||||
|
@ -10,98 +10,12 @@ absl-py==2.1.0
|
||||
|
||||
attrs==24.3.0
|
||||
|
||||
auditwheel==6.2.0
|
||||
|
||||
build==1.2.2.post1
|
||||
|
||||
cloudpickle==3.1.1 # version 3.1.0 leads to recursion error
|
||||
|
||||
colorama==0.4.6
|
||||
|
||||
contourpy==1.3.1
|
||||
|
||||
cycler==0.12.1
|
||||
|
||||
etils[epath,epy]==1.11.0
|
||||
|
||||
execnet==2.1.1
|
||||
|
||||
filelock==3.16.1
|
||||
|
||||
flatbuffers==24.12.23
|
||||
|
||||
fonttools==4.56.0
|
||||
|
||||
fsspec==2024.12.0
|
||||
|
||||
hypothesis==6.123.9
|
||||
|
||||
importlib-resources==6.5.2
|
||||
|
||||
iniconfig==2.0.0
|
||||
|
||||
kiwisolver==1.4.8
|
||||
|
||||
markdown-it-py==3.0.0
|
||||
|
||||
matplotlib==3.10.1
|
||||
|
||||
mdurl==0.1.2
|
||||
|
||||
ml-dtypes==0.5.1
|
||||
|
||||
mpmath==1.3.0
|
||||
|
||||
nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux"
|
||||
|
||||
nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux"
|
||||
nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux"
|
||||
nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux"
|
||||
nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux"
|
||||
nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux"
|
||||
nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux"
|
||||
nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux"
|
||||
nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux"
|
||||
|
||||
nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux"
|
||||
opt-einsum==3.4.0
|
||||
|
||||
packaging==24.2
|
||||
|
||||
pillow==11.1.0
|
||||
pluggy==1.5.0
|
||||
|
||||
portpicker==1.6.0
|
||||
|
||||
psutil==6.1.1
|
||||
pyelftools==0.31
|
||||
|
||||
pygments==2.19.1
|
||||
|
||||
pyparsing==3.2.2 # version 3.2.1 fails with SyntaxError(originally SyntaxWarning): 'return' in a 'finally' block in pyparsing/core.py", line 5716
|
||||
|
||||
pyproject-hooks==1.2.0
|
||||
|
||||
pytest==8.3.4
|
||||
|
||||
pytest-xdist==3.6.1
|
||||
|
||||
python-dateutil==2.9.0.post0
|
||||
|
||||
rich==13.9.4
|
||||
|
||||
six==1.17.0
|
||||
|
||||
sortedcontainers==2.4.0
|
||||
|
||||
typing-extensions==4.12.2
|
||||
flatbuffers==24.12.23
|
||||
|
||||
wheel==0.45.1
|
||||
ml-dtypes==0.5.1
|
||||
|
||||
zipp==3.21.0
|
||||
|
||||
# python 3.14t can't compile 0.23.0
|
||||
# due to https://github.com/indygreg/python-zstandard/issues/231
|
||||
# zstandard==0.23.0
|
||||
|
||||
setuptools==70.3.0
|
||||
opt-einsum==3.4.0
|
||||
|
@ -91,7 +91,7 @@ _MAX_CASES_SAMPLING_RETRIES = config.int_flag(
|
||||
'sampling process is terminated.'
|
||||
)
|
||||
|
||||
_SKIP_SLOW_TESTS = config.bool_flag(
|
||||
SKIP_SLOW_TESTS = config.bool_flag(
|
||||
'jax_skip_slow_tests',
|
||||
config.bool_env('JAX_SKIP_SLOW_TESTS', False),
|
||||
help='Skip tests marked as slow (> 5 sec).'
|
||||
|
@ -82,22 +82,27 @@ def get_zstandard():
|
||||
return []
|
||||
return ["@pypi_zstandard//:pkg"]
|
||||
|
||||
def get_optional_dep(package, excluded_py_versions = ["3.14", "3.14-ft"]):
|
||||
if HERMETIC_PYTHON_VERSION in excluded_py_versions:
|
||||
return []
|
||||
return [package]
|
||||
|
||||
_py_deps = {
|
||||
"absl/logging": ["@pypi_absl_py//:pkg"],
|
||||
"absl/testing": ["@pypi_absl_py//:pkg"],
|
||||
"absl/flags": ["@pypi_absl_py//:pkg"],
|
||||
"cloudpickle": ["@pypi_cloudpickle//:pkg"],
|
||||
"colorama": ["@pypi_colorama//:pkg"],
|
||||
"epath": ["@pypi_etils//:pkg"], # etils.epath
|
||||
"filelock": ["@pypi_filelock//:pkg"],
|
||||
"cloudpickle": get_optional_dep("@pypi_cloudpickle//:pkg"),
|
||||
"colorama": get_optional_dep("@pypi_colorama//:pkg"),
|
||||
"epath": get_optional_dep("@pypi_etils//:pkg"), # etils.epath
|
||||
"filelock": get_optional_dep("@pypi_filelock//:pkg"),
|
||||
"flatbuffers": ["@pypi_flatbuffers//:pkg"],
|
||||
"hypothesis": ["@pypi_hypothesis//:pkg"],
|
||||
"magma": [],
|
||||
"matplotlib": ["@pypi_matplotlib//:pkg"],
|
||||
"matplotlib": get_optional_dep("@pypi_matplotlib//:pkg"),
|
||||
"mpmath": [],
|
||||
"opt_einsum": ["@pypi_opt_einsum//:pkg"],
|
||||
"pil": ["@pypi_pillow//:pkg"],
|
||||
"portpicker": ["@pypi_portpicker//:pkg"],
|
||||
"pil": get_optional_dep("@pypi_pillow//:pkg"),
|
||||
"portpicker": get_optional_dep("@pypi_portpicker//:pkg"),
|
||||
"ml_dtypes": ["@pypi_ml_dtypes//:pkg"],
|
||||
"numpy": ["@pypi_numpy//:pkg"],
|
||||
"scipy": ["@pypi_scipy//:pkg"],
|
||||
|
@ -1826,6 +1826,10 @@ if CAN_USE_HYPOTHESIS:
|
||||
y2 = random.normal(jax.random.clone(k1), y.shape)
|
||||
self.assertAllClose(impl_vjp(t), ref_vjp(t))
|
||||
|
||||
if jtu.SKIP_SLOW_TESTS.value:
|
||||
# Skip second order tests if JAX_SKIP_SLOW_TESTS=true
|
||||
return
|
||||
|
||||
# Second order
|
||||
key, k1, k2 = random.split(key, 3)
|
||||
t2 = random.normal(k2, t.shape)
|
||||
|
Loading…
x
Reference in New Issue
Block a user