# rocm_jax/pyproject.toml

# PEP 517/518 build configuration: build with the setuptools backend.
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"

[tool.mypy]
show_error_codes = true
# Error codes suppressed for the whole project.
disable_error_code = ["attr-defined", "name-defined", "annotation-unchecked"]
no_implicit_optional = true
# Report `# type: ignore` comments that no longer suppress any error.
warn_unused_ignores = true

# Modules for which mypy should not report unresolvable imports
# (`ignore_missing_imports`). A trailing `.*` also covers all submodules.
[[tool.mypy.overrides]]
module = [
  "absl.*",
  "colorama.*",
  "filelock.*",
  "IPython.*",
  "numpy.*",
  "opt_einsum.*",
  "scipy.*",
  "libtpu.*",
  "jaxlib.mlir.*",
  "rich.*",
  "optax.*",
  # flatbuffers backs serialization/deserialization of jax export `Exported`.
  "flatbuffers.*",
  "flax.*",
  "tensorflow.*",
  "tensorflowjs.*",
  "tensorflow.io.*",
  "tensorstore.*",
  "web_pdb.*",
  "etils.*",
  "google.colab.*",
  "pygments.*",
  "jraph.*",
  "matplotlib.*",
  "tensorboard_plugin_profile.convert.*",
  "jaxlib.*",
  "pytest.*",
  "zstandard.*",
  "jax.experimental.jax2tf.tests.flax_models",
  "jax.experimental.jax2tf.tests.back_compat_testdata",
  "setuptools.*",
  "jax_cuda12_plugin.*",
]
ignore_missing_imports = true

[tool.pytest.ini_options]
markers = [
  "multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators",
  "SlurmMultiNodeGpuTest: mark a test for Slurm multinode GPU nightly CI",
]
# "error" turns every warning into a test failure. pytest applies the LAST
# matching filter, so the "default:..." entries below override "error" for the
# specific warnings they match: those are reported instead of failing the run.
filterwarnings = [
  "error",
  "default:Error (reading|writing) persistent compilation cache entry for 'jit_equal'",
  "default:Error (reading|writing) persistent compilation cache entry for 'jit__lambda_'",
  "default:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning",
  "default:jax.xla_computation is deprecated. Please use the AOT APIs.*:DeprecationWarning",
  # TODO(jakevdp): remove when array_api_tests stabilize
  # start array_api_tests-related warnings
  "default:.*not machine-readable.*:UserWarning",
  "default:Special cases found for .* but none were parsed.*:UserWarning",
  "default:.*is not JSON-serializable. Using the repr instead.",
  # end array_api_tests-related warnings
]
doctest_optionflags = [
  "NUMBER",
  "NORMALIZE_WHITESPACE",
]
# Also collect doctests from .rst files.
addopts = "--doctest-glob='*.rst'"

[tool.pylint.master]
# Packages whose C extensions pylint may import for introspection.
extension-pkg-whitelist = ["numpy"]

[tool.pylint."messages control"]
# Checks turned off project-wide.
disable = [
  "missing-docstring",
  "too-many-locals",
  "invalid-name",
  "redefined-outer-name",
  "redefined-builtin",
  "no-else-return",
  "fixme",
  "protected-access",
  "too-many-arguments",
  # Current name of the check formerly called "blacklisted-name".
  "disallowed-name",
  "too-few-public-methods",
  "unnecessary-lambda",
]
enable = "c-extension-no-member"

[tool.pylint.format]
# One indentation level is two spaces, matching ruff's `indent-width = 2`.
indent-string = "  "

[tool.ruff]
target-version = "py39"
line-length = 88
indent-width = 2
# Enables Ruff preview features. Presumably needed for the preview-only
# pycodestyle rules (E225/E227/E228) selected in [tool.ruff.lint]; verify.
preview = true
# NOTE: `exclude` replaces Ruff's default exclusion list rather than adding
# to it (`extend-exclude` would add to the defaults instead).
exclude = [".git", "build", "__pycache__"]

[tool.ruff.lint]
ignore = [
  # Unnecessary `dict()`/`list()`/`tuple()` call (rewrite as a literal)
  "C408",
  # Unnecessary `map()` usage (rewrite as a comprehension/generator)
  "C417",
  # Function is too complex (McCabe complexity)
  "C901",
  # Local variable is assigned to but never used
  "F841",
  # `raise` without `from` inside an `except` block
  "B904",
]
select = [
  # flake8-bugbear B9xx (the plugin's opinionated checks)
  "B9",
  # flake8-comprehensions (C4xx) and mccabe (C90x)
  "C",
  # Pyflakes
  "F",
  # pycodestyle warnings
  "W",
  # flake8-2020 (misuse of sys.version / sys.version_info)
  "YTT",
  # flake8-async
  "ASYNC",
  # Missing whitespace around operator / bitwise-or-shift operator /
  # modulo operator
  "E225",
  "E227",
  "E228",
]

[tool.ruff.lint.mccabe]
# Complexity threshold for rule C901. C901 is listed in the `ignore` list of
# [tool.ruff.lint], so this limit is not enforced unless that rule is
# re-enabled.
max-complexity = 18

[tool.ruff.lint.per-file-ignores]
# F811: Redefinition of unused name.
"docs/autodidax.py" = ["F811"]
# F401: Imported but unused. Presumably these modules re-export names as
# public API, so their imports look unused to the linter.
# Note: we don't use jax/*.py because this matches contents of jax/_src
"__init__.py" = ["F401"]
"jax/abstract_arrays.py" = ["F401"]
"jax/ad_checkpoint.py" = ["F401"]
"jax/api_util.py" = ["F401"]
"jax/cloud_tpu_init.py" = ["F401"]
"jax/core.py" = ["F401"]
"jax/custom_batching.py" = ["F401"]
"jax/custom_derivatives.py" = ["F401"]
"jax/custom_transpose.py" = ["F401"]
"jax/debug.py" = ["F401"]
"jax/distributed.py" = ["F401"]
"jax/dlpack.py" = ["F401"]
"jax/dtypes.py" = ["F401"]
"jax/errors.py" = ["F401"]
"jax/experimental/*.py" = ["F401"]
"jax/extend/*.py" = ["F401"]
"jax/flatten_util.py" = ["F401"]
"jax/interpreters/ad.py" = ["F401"]
"jax/interpreters/batching.py" = ["F401"]
"jax/interpreters/mlir.py" = ["F401"]
"jax/interpreters/partial_eval.py" = ["F401"]
"jax/interpreters/pxla.py" = ["F401"]
"jax/interpreters/xla.py" = ["F401"]
"jax/lax/*.py" = ["F401"]
"jax/linear_util.py" = ["F401"]
"jax/monitoring.py" = ["F401"]
"jax/nn/*.py" = ["F401"]
"jax/numpy/*.py" = ["F401"]
"jax/prng.py" = ["F401"]
"jax/profiler.py" = ["F401"]
"jax/random.py" = ["F401"]
"jax/scipy/*.py" = ["F401"]
"jax/sharding.py" = ["F401"]
"jax/stages.py" = ["F401"]
"jax/test_util.py" = ["F401"]
"jax/tree_util.py" = ["F401"]
"jax/typing.py" = ["F401"]
"jax/util.py" = ["F401"]
# F821: Undefined name.
"jax/numpy/__init__.pyi" = ["F821"]