# rocm_jax/pyproject.toml
# Peter Hawkins 6ae01247f0 Fix pytest failures from compilation cache test.
# The names of the functions in the compilation cache tests changed, causing warnings emitted by that test to become errors.
# 2024-04-29 11:08:07 -04:00

# 205 lines
# 5.9 KiB
# TOML

# PEP 517/518 build configuration: packages needed to build, and the backend
# that performs the build.
[build-system]
requires = [
    "setuptools",
    "wheel",
]
build-backend = "setuptools.build_meta"
# Global mypy settings; per-module exceptions are in the
# [[tool.mypy.overrides]] entries.
[tool.mypy]
# Print the error code (e.g. [attr-defined]) next to each reported error.
show_error_codes = true
# Error codes switched off for the whole project (comma-separated list).
disable_error_code = "attr-defined, name-defined, annotation-unchecked"
# A parameter with a `None` default is not implicitly treated as Optional.
no_implicit_optional = true
# Modules whose imports mypy may fail to resolve (no stubs or not installed);
# missing-import errors are silenced for every pattern listed here.
[[tool.mypy.overrides]]
module = [
    "absl.*",
    "colorama.*",
    "importlib_metadata.*",
    "IPython.*",
    "numpy.*",
    "opt_einsum.*",
    "scipy.*",
    "libtpu.*",
    "jaxlib.mlir.*",
    "iree.*",
    "rich.*",
    "optax.*",
    "flatbuffers.*",
    "flax.*",
    "tensorflow.*",
    "tensorflowjs.*",
    "tensorflow.io.*",
    "tensorstore.*",
    "web_pdb.*",
    "etils.*",
    "google.colab.*",
    "pygments.*",
    "jraph.*",
    "matplotlib.*",
    "tensorboard_plugin_profile.convert.*",
    "jaxlib.*",
    "pytest.*",
    "zstandard.*",
    "jax.experimental.jax2tf.tests.flax_models",
    "jax.experimental.jax2tf.tests.back_compat_testdata",
    "setuptools.*",
]
ignore_missing_imports = true
# Modules that are excluded from type-error reporting altogether.
[[tool.mypy.overrides]]
module = [
    "jax.interpreters.autospmd",
    "jax.lax.lax_parallel",
    "jax._src.internal_test_util.test_harnesses",
]
ignore_errors = true
# pytest configuration: custom markers, warning filters and doctest settings.
[tool.pytest.ini_options]
# Custom markers used by the test suite, registered here with a description.
markers = [
    "multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators",
    "SlurmMultiNodeGpuTest: mark a test for Slurm multinode GPU nightly CI",
]
# Warning filters in `action:message-regex:category` form. The leading
# "error" entry turns every warning into a test failure; the entries after it
# carve out exceptions. Order is significant, so keep "error" first.
filterwarnings = [
    "error",
    "ignore:The hookimpl.*:DeprecationWarning",
    "ignore:No GPU/TPU found, falling back to CPU.:UserWarning",
    "ignore:xmap is an experimental feature and probably has bugs!",
    "ignore:the imp module is deprecated in favour of importlib.*:DeprecationWarning",
    "ignore:can't resolve package from __spec__ or __package__:ImportWarning",
    "ignore:Using or importing the ABCs.*:DeprecationWarning",
    "ignore:numpy.ufunc size changed",
    "ignore:.*experimental feature",
    "ignore:The distutils.* is deprecated.*:DeprecationWarning",
    # Compilation-cache read/write warnings use the "default" action, so they
    # are shown rather than raised as errors. The messages embed the jitted
    # function's name, so these must track the names used by the tests.
    "default:Error reading persistent compilation cache entry for 'jit_equal'",
    "default:Error reading persistent compilation cache entry for 'jit__lambda_'",
    "default:Error writing persistent compilation cache entry for 'jit_equal'",
    "default:Error writing persistent compilation cache entry for 'jit__lambda_'",
    "ignore:backend and device argument on jit is deprecated.*:DeprecationWarning",
    # TODO(skyewm): remove when jaxlib >= 0.4.12 is released (needs
    # https://github.com/openxla/xla/commit/fb9dc3db0999bf14c78d95cb7c3aa6815221ddc7)
    "ignore:ml_dtypes.float8_e4m3b11 is deprecated.",
    "ignore:JAX_USE_PJRT_C_API_ON_TPU=false will no longer be supported.*:UserWarning",
    "ignore:np.find_common_type is deprecated.*:DeprecationWarning",
    "ignore:jax.numpy.in1d is deprecated.*:DeprecationWarning",
    # TODO(jakevdp): remove when array_api_tests stabilize
    # start array_api_tests-related warnings
    "ignore:The numpy.array_api submodule is still experimental.*:UserWarning",
    "ignore:case not machine-readable.*:UserWarning",
    "ignore:not machine-readable.*:UserWarning",
    "ignore:Special cases found for .* but none were parsed.*:UserWarning",
    # end array_api_tests-related warnings
    "ignore:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning",
    "ignore:jax.experimental.maps and .* are deprecated.*:DeprecationWarning",
    "ignore:The host_callback APIs are deprecated .*:DeprecationWarning",
]
# Option flags applied to every doctest.
doctest_optionflags = [
    "NUMBER",
    "NORMALIZE_WHITESPACE",
]
# Also collect doctests from reStructuredText files.
addopts = "--doctest-glob='*.rst'"
# Pylint top-level options.
[tool.pylint.master]
# Packages whose C extensions pylint is allowed to load for introspection.
extension-pkg-whitelist = "numpy"
# Pylint checks that are switched off (and the one that is explicitly
# switched on) for the whole project.
[tool.pylint."messages control"]
disable = [
    "missing-docstring",
    "too-many-locals",
    "invalid-name",
    "redefined-outer-name",
    "redefined-builtin",
    # NOTE(review): "protected-name" does not look like a pylint message name
    # (compare "protected-access" below) - confirm whether it can be dropped.
    "protected-name",
    "no-else-return",
    "fixme",
    "protected-access",
    "too-many-arguments",
    "blacklisted-name",
    "too-few-public-methods",
    "unnecessary-lambda",
]
enable = "c-extension-no-member"
# Pylint formatting checks.
[tool.pylint.format]
# One indentation level is two spaces, matching `indent-width = 2` under
# [tool.ruff]. A single-space value would make pylint's indentation check
# disagree with the rest of the tooling.
indent-string = "  "
# Ruff settings shared by the linter and formatter.
[tool.ruff]
# Opt in to Ruff's preview (not yet stable) behaviour.
preview = true
# Paths that Ruff skips entirely.
exclude = [
    ".git",
    "build",
    "__pycache__",
]
line-length = 88
# Python sources in this project are indented with two spaces.
indent-width = 2
target-version = "py39"
# Ruff lint rule selection: `select` enables rules or rule prefixes, and
# `ignore` removes individual rules from that selection.
[tool.ruff.lint]
ignore = [
    # Unnecessary collection call
    "C408",
    # Unnecessary map usage
    "C417",
    # Object names too complex
    "C901",
    # Local variable is assigned to but never used
    "F841",
    # Raise with from clause inside except block
    "B904",
]
select = [
    "B9",
    "C",
    "F",
    "W",
    "YTT",
    "ASYNC",
    "E225",
    "E227",
    "E228",
]
# Settings for the McCabe complexity check (rule C901, which is currently
# listed under `ignore` in [tool.ruff.lint]).
[tool.ruff.lint.mccabe]
# Complexity threshold above which a function would be reported.
max-complexity = 18
# Per-file rule exemptions: each key is a file path or glob pattern, each
# value the list of rule codes suppressed for matching files.
[tool.ruff.lint.per-file-ignores]
# F811: Redefinition of unused name.
"docs/autodidax.py" = ["F811"]
# F401: Imported but unused. The modules below carry this exemption, which
# suggests they exist mainly to re-export names.
# Note: we don't use jax/*.py because this matches contents of jax/_src
"__init__.py" = ["F401"]
"jax/abstract_arrays.py" = ["F401"]
"jax/ad_checkpoint.py" = ["F401"]
"jax/api_util.py" = ["F401"]
"jax/cloud_tpu_init.py" = ["F401"]
"jax/core.py" = ["F401"]
"jax/custom_batching.py" = ["F401"]
"jax/custom_derivatives.py" = ["F401"]
"jax/custom_transpose.py" = ["F401"]
"jax/debug.py" = ["F401"]
"jax/distributed.py" = ["F401"]
"jax/dlpack.py" = ["F401"]
"jax/dtypes.py" = ["F401"]
"jax/errors.py" = ["F401"]
"jax/experimental/*.py" = ["F401"]
"jax/extend/*.py" = ["F401"]
"jax/flatten_util.py" = ["F401"]
"jax/interpreters/ad.py" = ["F401"]
"jax/interpreters/batching.py" = ["F401"]
"jax/interpreters/mlir.py" = ["F401"]
"jax/interpreters/partial_eval.py" = ["F401"]
"jax/interpreters/pxla.py" = ["F401"]
"jax/interpreters/xla.py" = ["F401"]
"jax/lax/*.py" = ["F401"]
"jax/linear_util.py" = ["F401"]
"jax/monitoring.py" = ["F401"]
"jax/nn/*.py" = ["F401"]
"jax/numpy/*.py" = ["F401"]
"jax/prng.py" = ["F401"]
"jax/profiler.py" = ["F401"]
"jax/random.py" = ["F401"]
"jax/scipy/*.py" = ["F401"]
"jax/sharding.py" = ["F401"]
"jax/stages.py" = ["F401"]
"jax/test_util.py" = ["F401"]
"jax/tree_util.py" = ["F401"]
"jax/typing.py" = ["F401"]
"jax/util.py" = ["F401"]
# F821: Undefined name.
"jax/numpy/__init__.pyi" = ["F821"]