# rocm_jax/pyproject.toml

[build-system]
# Build backend declaration: the project is built with setuptools, and both
# setuptools and wheel must be installed in the build environment.
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"
[tool.mypy]
# Global mypy settings; per-module exceptions live in the
# [[tool.mypy.overrides]] tables.
show_error_codes = true
# Error codes that are switched off for the whole project.
disable_error_code = "attr-defined, name-defined, annotation-unchecked"
no_implicit_optional = true
# Flag `# type: ignore` comments that no longer suppress anything.
warn_unused_ignores = true
[[tool.mypy.overrides]]
# Modules whose imports mypy should not report as missing (mostly third-party
# packages, plus a couple of jax2tf test-data modules).
module = [
    "absl.*",
    "colorama.*",
    "importlib_metadata.*",
    "IPython.*",
    "numpy.*",
    "opt_einsum.*",
    "scipy.*",
    "libtpu.*",
    # NOTE(review): looks subsumed by "jaxlib.*" below — confirm before removing.
    "jaxlib.mlir.*",
    "rich.*",
    "optax.*",
    # Used for serialization/deserialization of Exported objects.
    "flatbuffers.*",
    "flax.*",
    "tensorflow.*",
    "tensorflowjs.*",
    # NOTE(review): looks subsumed by "tensorflow.*" above — confirm before removing.
    "tensorflow.io.*",
    "tensorstore.*",
    "web_pdb.*",
    "etils.*",
    "google.colab.*",
    "pygments.*",
    "jraph.*",
    "matplotlib.*",
    "tensorboard_plugin_profile.convert.*",
    "jaxlib.*",
    "pytest.*",
    "zstandard.*",
    "jax.experimental.jax2tf.tests.flax_models",
    "jax.experimental.jax2tf.tests.back_compat_testdata",
    "setuptools.*",
]
ignore_missing_imports = true
[[tool.mypy.overrides]]
# Modules for which mypy errors are suppressed entirely.
module = [
    "jax.interpreters.autospmd",
    "jax.lax.lax_parallel",
    "jax._src.internal_test_util.test_harnesses",
]
ignore_errors = true
[tool.pytest.ini_options]
# Custom markers used by the test suite ("name: description").
markers = [
    "multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators",
    "SlurmMultiNodeGpuTest: mark a test for Slurm multinode GPU nightly CI",
]
# Warning filters in "action:message-regex:category" form. The leading "error"
# entry escalates every warning to an error; the entries after it carve out
# exceptions. Keep the order: when several filters match, the later one wins.
filterwarnings = [
    "error",
    "ignore:The hookimpl.*:DeprecationWarning",
    "ignore:No GPU/TPU found, falling back to CPU.:UserWarning",
    "ignore:xmap is an experimental feature and probably has bugs!",
    "ignore:the imp module is deprecated in favour of importlib.*:DeprecationWarning",
    "ignore:can't resolve package from __spec__ or __package__:ImportWarning",
    "ignore:Using or importing the ABCs.*:DeprecationWarning",
    "ignore:numpy.ufunc size changed",
    "ignore:.*experimental feature",
    "ignore:The distutils.* is deprecated.*:DeprecationWarning",
    "default:Error reading persistent compilation cache entry for 'jit_equal'",
    "default:Error reading persistent compilation cache entry for 'jit__lambda_'",
    "default:Error writing persistent compilation cache entry for 'jit_equal'",
    "default:Error writing persistent compilation cache entry for 'jit__lambda_'",
    "ignore:backend and device argument on jit is deprecated.*:DeprecationWarning",
    # TODO(skyewm): remove when jaxlib >= 0.4.12 is released (needs
    # https://github.com/openxla/xla/commit/fb9dc3db0999bf14c78d95cb7c3aa6815221ddc7)
    "ignore:ml_dtypes.float8_e4m3b11 is deprecated.",
    "ignore:JAX_USE_PJRT_C_API_ON_TPU=false will no longer be supported.*:UserWarning",
    "ignore:np.find_common_type is deprecated.*:DeprecationWarning",
    "ignore:jax.numpy.in1d is deprecated.*:DeprecationWarning",
    # TODO(jakevdp): remove when array_api_tests stabilize
    # start array_api_tests-related warnings
    "ignore:The numpy.array_api submodule is still experimental.*:UserWarning",
    "ignore:case not machine-readable.*:UserWarning",
    "ignore:not machine-readable.*:UserWarning",
    "ignore:Special cases found for .* but none were parsed.*:UserWarning",
    # end array_api_tests-related warnings
    "ignore:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning",
    "ignore:.*is not JSON-serializable. Using the repr instead.",
]
doctest_optionflags = [
    "NUMBER",
    "NORMALIZE_WHITESPACE",
]
# Also collect doctests from .rst files.
addopts = "--doctest-glob='*.rst'"
[tool.pylint.master]
# Packages listed here are treated by pylint as C-extension packages.
# NOTE(review): newer pylint releases spell this option
# "extension-pkg-allow-list" — confirm against the pinned pylint version.
extension-pkg-whitelist = "numpy"
[tool.pylint."messages control"]
# Checks that are silenced project-wide.
disable = [
    "missing-docstring",
    "too-many-locals",
    "invalid-name",
    "redefined-outer-name",
    "redefined-builtin",
    # NOTE(review): "protected-name" does not look like a pylint message name
    # ("protected-access" is listed below) — confirm and drop if unknown.
    "protected-name",
    "no-else-return",
    "fixme",
    "protected-access",
    "too-many-arguments",
    "blacklisted-name",
    "too-few-public-methods",
    "unnecessary-lambda",
]
# Checks that are explicitly switched on.
enable = "c-extension-no-member"
[tool.pylint.format]
# One indentation unit is two spaces, matching `indent-width = 2` under
# [tool.ruff]. The previous single-space value contradicted that setting.
indent-string = "  "
[tool.ruff]
# General Ruff settings; lint rule selection lives under [tool.ruff.lint].
preview = true
# Directories Ruff should not look into.
exclude = [".git", "build", "__pycache__"]
line-length = 88
indent-width = 2
target-version = "py39"
[tool.ruff.lint]
# Rules removed from the selected set below.
ignore = [
    # Unnecessary collection call
    "C408",
    # Unnecessary map usage
    "C417",
    # Object names too complex
    "C901",
    # Local variable is assigned to but never used
    "F841",
    # Raise with from clause inside except block
    "B904",
]
# Rule codes / code prefixes that are enabled.
select = [
    "B9",
    "C",
    "F",
    "W",
    "YTT",
    "ASYNC",
    "E225",
    "E227",
    "E228",
]
[tool.ruff.lint.mccabe]
# Complexity threshold for the mccabe check.
max-complexity = 18
[tool.ruff.lint.per-file-ignores]
# Maps a file pattern to the rule codes ignored for matching files.
# F811: Redefinition of unused name.
"docs/autodidax.py" = ["F811"]
# F401 is ignored for each file listed below.
# Note: we don't use jax/*.py because this matches contents of jax/_src
"__init__.py" = ["F401"]
"jax/abstract_arrays.py" = ["F401"]
"jax/ad_checkpoint.py" = ["F401"]
"jax/api_util.py" = ["F401"]
"jax/cloud_tpu_init.py" = ["F401"]
"jax/core.py" = ["F401"]
"jax/custom_batching.py" = ["F401"]
"jax/custom_derivatives.py" = ["F401"]
"jax/custom_transpose.py" = ["F401"]
"jax/debug.py" = ["F401"]
"jax/distributed.py" = ["F401"]
"jax/dlpack.py" = ["F401"]
"jax/dtypes.py" = ["F401"]
"jax/errors.py" = ["F401"]
"jax/experimental/*.py" = ["F401"]
"jax/extend/*.py" = ["F401"]
"jax/flatten_util.py" = ["F401"]
"jax/interpreters/ad.py" = ["F401"]
"jax/interpreters/batching.py" = ["F401"]
"jax/interpreters/mlir.py" = ["F401"]
"jax/interpreters/partial_eval.py" = ["F401"]
"jax/interpreters/pxla.py" = ["F401"]
"jax/interpreters/xla.py" = ["F401"]
"jax/lax/*.py" = ["F401"]
"jax/linear_util.py" = ["F401"]
"jax/monitoring.py" = ["F401"]
"jax/nn/*.py" = ["F401"]
"jax/numpy/*.py" = ["F401"]
"jax/prng.py" = ["F401"]
"jax/profiler.py" = ["F401"]
"jax/random.py" = ["F401"]
"jax/scipy/*.py" = ["F401"]
"jax/sharding.py" = ["F401"]
"jax/stages.py" = ["F401"]
"jax/test_util.py" = ["F401"]
"jax/tree_util.py" = ["F401"]
"jax/typing.py" = ["F401"]
"jax/util.py" = ["F401"]
# F821: Undefined name.
"jax/numpy/__init__.pyi" = ["F821"]