[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"

[tool.mypy]
show_error_codes = true
disable_error_code = "attr-defined, name-defined, annotation-unchecked"
no_implicit_optional = true
warn_redundant_casts = true
warn_unused_ignores = true

# Modules for which mypy should not report missing imports.
[[tool.mypy.overrides]]
module = [
    "IPython.*",
    "absl.*",
    "colorama.*",
    "etils.*",
    "filelock.*",
    "flatbuffers.*",
    "flax.*",
    "google.colab.*",
    "hypothesis.*",
    "jax.experimental.jax2tf.tests.back_compat_testdata",
    "jax.experimental.jax2tf.tests.flax_models",
    "jax_cuda12_plugin.*",
    "jaxlib.*",
    "jaxlib.mlir.*",
    "jraph.*",
    "libtpu.*",
    "matplotlib.*",
    "numpy.*",
    "opt_einsum.*",
    "optax.*",
    "pygments.*",
    "pytest.*",
    "rich.*",
    "scipy.*",
    "setuptools.*",
    "tensorboard_plugin_profile.convert.*",
    "tensorflow.*",
    "tensorflow.io.*",
    "tensorflowjs.*",
    "tensorstore.*",
    "web_pdb.*",
    "zstandard.*",
]
ignore_missing_imports = true

[tool.pytest.ini_options]
markers = [
    "multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators",
    "SlurmMultiNodeGpuTest: mark a test for Slurm multinode GPU nightly CI",
]
# Warnings are escalated to errors by the first entry; the entries after it
# reset specific, known warnings back to the "default" action.
filterwarnings = [
    "error",
    "default:Error (reading|writing) persistent compilation cache entry for 'jit_equal'",
    "default:Error (reading|writing) persistent compilation cache entry for 'jit__lambda_'",
    "default:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning",
    # TODO(jakevdp): remove when array_api_tests stabilize
    "default:.*not machine-readable.*:UserWarning",
    "default:Special cases found for .* but none were parsed.*:UserWarning",
    "default:.*is not JSON-serializable. Using the repr instead.",
    # These are transitive warnings coming from TensorFlow dependencies.
    # TODO(slebedev): Remove once we bump the minimum TensorFlow version.
    "default:The key path API is deprecated .*",
    "default:jax.xla_computation is deprecated.*:DeprecationWarning",
]
doctest_optionflags = ["NUMBER", "NORMALIZE_WHITESPACE"]
addopts = "--doctest-glob='*.rst'"

[tool.pylint.master]
extension-pkg-whitelist = "numpy"

[tool.pylint."messages control"]
disable = [
    "missing-docstring",
    "too-many-locals",
    "invalid-name",
    "redefined-outer-name",
    "redefined-builtin",
    "protected-name",
    "no-else-return",
    "fixme",
    "protected-access",
    "too-many-arguments",
    "blacklisted-name",
    "too-few-public-methods",
    "unnecessary-lambda",
]
enable = "c-extension-no-member"

[tool.pylint.format]
# Two spaces per indentation level, matching `indent-width = 2` under
# [tool.ruff] so that both linters agree on the project's indentation.
indent-string = "  "

[tool.ruff]
preview = true
exclude = [
    ".git",
    "build",
    "__pycache__",
]
line-length = 88
indent-width = 2
target-version = "py310"

[tool.ruff.lint]
ignore = [
    # Unnecessary collection call
    "C408",
    # Unnecessary map usage
    "C417",
    # Function is too complex (McCabe complexity)
    "C901",
    # Local variable is assigned to but never used
    "F841",
    # Raise with from clause inside except block
    "B904",
    # Zip without explicit strict parameter
    "B905",
]
select = [
    "B9",
    "C",
    "F",
    "W",
    "YTT",
    "ASYNC",
    "E225",
    "E227",
    "E228",
]

[tool.ruff.lint.mccabe]
# NOTE(review): C901 is listed under `ignore` above, so this threshold
# presumably only takes effect if C901 is re-enabled -- confirm.
max-complexity = 18

[tool.ruff.lint.per-file-ignores]
# F811: Redefinition of unused name.
"docs/autodidax.py" = ["F811"]
# F401: Module imported but unused.
# Note: we don't use jax/*.py because this matches contents of jax/_src
"__init__.py" = ["F401"]
"jax/abstract_arrays.py" = ["F401"]
"jax/ad_checkpoint.py" = ["F401"]
"jax/api_util.py" = ["F401"]
"jax/cloud_tpu_init.py" = ["F401"]
"jax/core.py" = ["F401"]
"jax/custom_batching.py" = ["F401"]
"jax/custom_derivatives.py" = ["F401"]
"jax/custom_transpose.py" = ["F401"]
"jax/debug.py" = ["F401"]
"jax/distributed.py" = ["F401"]
"jax/dlpack.py" = ["F401"]
"jax/dtypes.py" = ["F401"]
"jax/errors.py" = ["F401"]
"jax/experimental/*.py" = ["F401"]
"jax/extend/*.py" = ["F401"]
"jax/flatten_util.py" = ["F401"]
"jax/interpreters/ad.py" = ["F401"]
"jax/interpreters/batching.py" = ["F401"]
"jax/interpreters/mlir.py" = ["F401"]
"jax/interpreters/partial_eval.py" = ["F401"]
"jax/interpreters/pxla.py" = ["F401"]
"jax/interpreters/xla.py" = ["F401"]
"jax/lax/*.py" = ["F401"]
"jax/linear_util.py" = ["F401"]
"jax/monitoring.py" = ["F401"]
"jax/nn/*.py" = ["F401"]
"jax/numpy/*.py" = ["F401"]
"jax/prng.py" = ["F401"]
"jax/profiler.py" = ["F401"]
"jax/random.py" = ["F401"]
"jax/scipy/*.py" = ["F401"]
"jax/sharding.py" = ["F401"]
"jax/stages.py" = ["F401"]
"jax/test_util.py" = ["F401"]
"jax/tree_util.py" = ["F401"]
"jax/typing.py" = ["F401"]
"jax/util.py" = ["F401"]
# F821: Undefined name.
"jax/numpy/__init__.pyi" = ["F821"]