rocm_jax/pyproject.toml
Sergei Lebedev 8d33a6c9a6 Bumped jaxlib version mypy uses on the CI
I also enabled unnecessary cast checking, because turns out we have quite
a few of those.
2024-07-26 11:22:39 +01:00

184 lines
4.6 KiB
TOML

[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"
[tool.mypy]
show_error_codes = true
disable_error_code = "attr-defined, name-defined, annotation-unchecked"
no_implicit_optional = true
warn_redundant_casts = true
warn_unused_ignores = true
[[tool.mypy.overrides]]
module = [
"absl.*",
"colorama.*",
"filelock.*",
"IPython.*",
"numpy.*",
"opt_einsum.*",
"scipy.*",
"libtpu.*",
"jaxlib.mlir.*",
"rich.*",
"optax.*",
"flatbuffers.*",
"flax.*",
"tensorflow.*",
"tensorflowjs.*",
"tensorflow.io.*",
"tensorstore.*",
"web_pdb.*",
"etils.*",
"google.colab.*",
"pygments.*",
"jraph.*",
"matplotlib.*",
"tensorboard_plugin_profile.convert.*",
"jaxlib.*",
"pytest.*",
"zstandard.*",
"jax.experimental.jax2tf.tests.flax_models",
"jax.experimental.jax2tf.tests.back_compat_testdata",
"setuptools.*",
"jax_cuda12_plugin.*",
]
ignore_missing_imports = true
[tool.pytest.ini_options]
markers = [
"multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators",
"SlurmMultiNodeGpuTest: mark a test for Slurm multinode GPU nightly CI"
]
filterwarnings = [
"error",
"default:Error (reading|writing) persistent compilation cache entry for 'jit_equal'",
"default:Error (reading|writing) persistent compilation cache entry for 'jit__lambda_'",
"default:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning",
# TODO(jakevdp): remove when array_api_tests stabilize
"default:.*not machine-readable.*:UserWarning",
"default:Special cases found for .* but none were parsed.*:UserWarning",
"default:.*is not JSON-serializable. Using the repr instead.",
# These are transitive warnings coming from TensorFlow dependencies.
# TODO(slebedev): Remove once we bump the minimum TensorFlow version.
"default:The key path API is deprecated .*",
"default:jax.xla_computation is deprecated.*:DeprecationWarning",
]
doctest_optionflags = [
"NUMBER",
"NORMALIZE_WHITESPACE"
]
addopts = "--doctest-glob='*.rst'"
[tool.pylint.master]
extension-pkg-whitelist = "numpy"
[tool.pylint."messages control"]
disable = [
"missing-docstring",
"too-many-locals",
"invalid-name",
"redefined-outer-name",
"redefined-builtin",
"protected-name",
"no-else-return",
"fixme",
"protected-access",
"too-many-arguments",
"blacklisted-name",
"too-few-public-methods",
"unnecessary-lambda"
]
enable = "c-extension-no-member"
[tool.pylint.format]
indent-string=" "
[tool.ruff]
preview = true
exclude = [
".git",
"build",
"__pycache__",
]
line-length = 88
indent-width = 2
target-version = "py310"
[tool.ruff.lint]
ignore = [
# Unnecessary collection call
"C408",
# Unnecessary map usage
"C417",
# Object names too complex
"C901",
# Local variable is assigned to but never used
"F841",
# Raise with from clause inside except block
"B904",
# Zip without explicit strict parameter
"B905",
]
select = [
"B9",
"C",
"F",
"W",
"YTT",
"ASYNC",
"E225",
"E227",
"E228",
]
[tool.ruff.lint.mccabe]
max-complexity = 18
[tool.ruff.lint.per-file-ignores]
# F811: Redefinition of unused name.
"docs/autodidax.py" = ["F811"]
# Note: we don't use jax/*.py because this matches contents of jax/_src
"__init__.py" = ["F401"]
"jax/abstract_arrays.py" = ["F401"]
"jax/ad_checkpoint.py" = ["F401"]
"jax/api_util.py" = ["F401"]
"jax/cloud_tpu_init.py" = ["F401"]
"jax/core.py" = ["F401"]
"jax/custom_batching.py" = ["F401"]
"jax/custom_derivatives.py" = ["F401"]
"jax/custom_transpose.py" = ["F401"]
"jax/debug.py" = ["F401"]
"jax/distributed.py" = ["F401"]
"jax/dlpack.py" = ["F401"]
"jax/dtypes.py" = ["F401"]
"jax/errors.py" = ["F401"]
"jax/experimental/*.py" = ["F401"]
"jax/extend/*.py" = ["F401"]
"jax/flatten_util.py" = ["F401"]
"jax/interpreters/ad.py" = ["F401"]
"jax/interpreters/batching.py" = ["F401"]
"jax/interpreters/mlir.py" = ["F401"]
"jax/interpreters/partial_eval.py" = ["F401"]
"jax/interpreters/pxla.py" = ["F401"]
"jax/interpreters/xla.py" = ["F401"]
"jax/lax/*.py" = ["F401"]
"jax/linear_util.py" = ["F401"]
"jax/monitoring.py" = ["F401"]
"jax/nn/*.py" = ["F401"]
"jax/numpy/*.py" = ["F401"]
"jax/prng.py" = ["F401"]
"jax/profiler.py" = ["F401"]
"jax/random.py" = ["F401"]
"jax/scipy/*.py" = ["F401"]
"jax/sharding.py" = ["F401"]
"jax/stages.py" = ["F401"]
"jax/test_util.py" = ["F401"]
"jax/tree_util.py" = ["F401"]
"jax/typing.py" = ["F401"]
"jax/util.py" = ["F401"]
# F821: Undefined name.
"jax/numpy/__init__.pyi" = ["F821"]