[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"

[tool.mypy]
show_error_codes = true
disable_error_code = "attr-defined, name-defined, annotation-unchecked"
no_implicit_optional = true

# Modules for which unresolved imports / missing stubs are not reported.
[[tool.mypy.overrides]]
module = [
    "absl.*",
    "colorama.*",
    "importlib_metadata.*",
    "IPython.*",
    "numpy.*",
    "opt_einsum.*",
    "scipy.*",
    "libtpu.*",
    "jaxlib.mlir.*",
    "iree.*",
    "rich.*",
    "optax.*",
    "flatbuffers.*",
    "flax.*",
    "tensorflow.*",
    "tensorflowjs.*",
    "tensorflow.io.*",
    "tensorstore.*",
    "web_pdb.*",
    "etils.*",
    "google.colab.*",
    "pygments.*",
    "jraph.*",
    "matplotlib.*",
    "tensorboard_plugin_profile.convert.*",
    "jaxlib.*",
    "pytest.*",
    "zstandard.*",
    "jax.experimental.jax2tf.tests.flax_models",
    "jax.experimental.jax2tf.tests.back_compat_testdata",
    "setuptools.*",
]
ignore_missing_imports = true

# Modules in which mypy errors are suppressed entirely.
[[tool.mypy.overrides]]
module = [
    "jax.interpreters.autospmd",
    "jax.lax.lax_parallel",
    "jax._src.internal_test_util.test_harnesses",
]
ignore_errors = true

[tool.pytest.ini_options]
markers = [
    "multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators",
    "SlurmMultiNodeGpuTest: mark a test for Slurm multinode GPU nightly CI",
]
# The leading "error" entry turns warnings into test failures; the entries
# after it list the warnings that are ignored or merely displayed instead.
filterwarnings = [
    "error",
    "ignore:The hookimpl.*:DeprecationWarning",
    "ignore:No GPU/TPU found, falling back to CPU.:UserWarning",
    "ignore:xmap is an experimental feature and probably has bugs!",
    "ignore:the imp module is deprecated in favour of importlib.*:DeprecationWarning",
    "ignore:can't resolve package from __spec__ or __package__:ImportWarning",
    "ignore:Using or importing the ABCs.*:DeprecationWarning",
    "ignore:numpy.ufunc size changed",
    "ignore:.*experimental feature",
    "ignore:The distutils.* is deprecated.*:DeprecationWarning",
    "default:Error reading persistent compilation cache entry for 'jit_equal'",
    "default:Error reading persistent compilation cache entry for 'jit__lambda_'",
    "default:Error writing persistent compilation cache entry for 'jit_equal'",
    "default:Error writing persistent compilation cache entry for 'jit__lambda_'",
    "ignore:backend and device argument on jit is deprecated.*:DeprecationWarning",
    # TODO(skyewm): remove when jaxlib >= 0.4.12 is released (needs
    # https://github.com/openxla/xla/commit/fb9dc3db0999bf14c78d95cb7c3aa6815221ddc7)
    "ignore:ml_dtypes.float8_e4m3b11 is deprecated.",
    "ignore:JAX_USE_PJRT_C_API_ON_TPU=false will no longer be supported.*:UserWarning",
    "ignore:np.find_common_type is deprecated.*:DeprecationWarning",
    "ignore:jax.numpy.in1d is deprecated.*:DeprecationWarning",
    # TODO(jakevdp): remove when array_api_tests stabilize
    # start array_api_tests-related warnings
    "ignore:The numpy.array_api submodule is still experimental.*:UserWarning",
    "ignore:case not machine-readable.*:UserWarning",
    "ignore:not machine-readable.*:UserWarning",
    "ignore:Special cases found for .* but none were parsed.*:UserWarning",
    # end array_api_tests-related warnings
    "ignore:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning",
    "ignore:jax.experimental.maps and .* are deprecated.*:DeprecationWarning",
    "ignore:The host_callback APIs are deprecated .*:DeprecationWarning",
]
doctest_optionflags = [
    "NUMBER",
    "NORMALIZE_WHITESPACE",
]
addopts = "--doctest-glob='*.rst'"

[tool.pylint.master]
extension-pkg-whitelist = "numpy"

[tool.pylint."messages control"]
disable = [
    "missing-docstring",
    "too-many-locals",
    "invalid-name",
    "redefined-outer-name",
    "redefined-builtin",
    "protected-name",
    "no-else-return",
    "fixme",
    "protected-access",
    "too-many-arguments",
    "blacklisted-name",
    "too-few-public-methods",
    "unnecessary-lambda",
]
enable = "c-extension-no-member"

[tool.pylint.format]
# Two spaces per indentation level, matching `indent-width = 2` in [tool.ruff].
indent-string = "  "

[tool.ruff]
preview = true
exclude = [
    ".git",
    "build",
    "__pycache__",
]
line-length = 88
indent-width = 2
target-version = "py39"

[tool.ruff.lint]
ignore = [
    # Unnecessary collection call
    "C408",
    # Unnecessary map usage
    "C417",
    # Object names too complex
    "C901",
    # Local variable is assigned to but never used
    "F841",
    # Raise with from clause inside except block
    "B904",
]
select = [
    "B9",
    "C",
    "F",
    "W",
    "YTT",
    "ASYNC",
    "E225",
    "E227",
    "E228",
]

[tool.ruff.lint.mccabe]
max-complexity = 18

[tool.ruff.lint.per-file-ignores]
# F811: Redefinition of unused name.
"docs/autodidax.py" = ["F811"]
# F401: Unused import.
# Note: we don't use jax/*.py because this matches contents of jax/_src
"__init__.py" = ["F401"]
"jax/abstract_arrays.py" = ["F401"]
"jax/ad_checkpoint.py" = ["F401"]
"jax/api_util.py" = ["F401"]
"jax/cloud_tpu_init.py" = ["F401"]
"jax/core.py" = ["F401"]
"jax/custom_batching.py" = ["F401"]
"jax/custom_derivatives.py" = ["F401"]
"jax/custom_transpose.py" = ["F401"]
"jax/debug.py" = ["F401"]
"jax/distributed.py" = ["F401"]
"jax/dlpack.py" = ["F401"]
"jax/dtypes.py" = ["F401"]
"jax/errors.py" = ["F401"]
"jax/experimental/*.py" = ["F401"]
"jax/extend/*.py" = ["F401"]
"jax/flatten_util.py" = ["F401"]
"jax/interpreters/ad.py" = ["F401"]
"jax/interpreters/batching.py" = ["F401"]
"jax/interpreters/mlir.py" = ["F401"]
"jax/interpreters/partial_eval.py" = ["F401"]
"jax/interpreters/pxla.py" = ["F401"]
"jax/interpreters/xla.py" = ["F401"]
"jax/lax/*.py" = ["F401"]
"jax/linear_util.py" = ["F401"]
"jax/monitoring.py" = ["F401"]
"jax/nn/*.py" = ["F401"]
"jax/numpy/*.py" = ["F401"]
"jax/prng.py" = ["F401"]
"jax/profiler.py" = ["F401"]
"jax/random.py" = ["F401"]
"jax/scipy/*.py" = ["F401"]
"jax/sharding.py" = ["F401"]
"jax/stages.py" = ["F401"]
"jax/test_util.py" = ["F401"]
"jax/tree_util.py" = ["F401"]
"jax/typing.py" = ["F401"]
"jax/util.py" = ["F401"]
# F821: Undefined name.
"jax/numpy/__init__.pyi" = ["F821"]