Merge pull request #15585 from SauravMaheshkar:main

PiperOrigin-RevId: 524896086
This commit is contained in:
jax authors 2023-04-17 11:08:46 -07:00
commit 9c92ec5d90
6 changed files with 129 additions and 140 deletions

View File

@ -33,6 +33,7 @@ repos:
files: (jax/|tests/typing_test\.py)
exclude: jax/_src/basearray.py # Use pyi instead
additional_dependencies: [types-requests==2.28.11, jaxlib==0.4.6, ml_dtypes==0.0.3, numpy==1.21.6, scipy==1.7.3]
args: [--config=pyproject.toml]
- repo: https://github.com/mwouts/jupytext
rev: v1.14.4

View File

@ -30,7 +30,7 @@ guidance on pip installation (e.g., for GPU and TPU support).
To build `jaxlib` from source, you must also install some prerequisites:
* a C++ compiler (g++, clang, or MSVC)
- a C++ compiler (g++, clang, or MSVC)
On Ubuntu or Debian you can install the necessary prerequisites with:
@ -42,7 +42,8 @@ To build `jaxlib` from source, you must also install some prerequisites:
are installed.
See below for Windows build instructions.
* Python packages: `numpy`, `wheel`.
- Python packages: `numpy`, `wheel`.
You can install the necessary Python dependencies using `pip`:
@ -74,13 +75,14 @@ By default JAX uses a pinned copy of the XLA repository, but we often
want to use a locally-modified copy of XLA when working on JAX. There are two
ways to do this:
* use Bazel's `override_repository` feature, which you can pass as a command
- use Bazel's `override_repository` feature, which you can pass as a command
line flag to `build.py` as follows:
```
python build/build.py --bazel_options=--override_repository=xla=/path/to/xla
```
* modify the `WORKSPACE` file in the root of the JAX source tree to point to
- modify the `WORKSPACE` file in the root of the JAX source tree to point to
a different XLA tree.
To contribute changes back to XLA, send PRs to the XLA repository.
@ -112,6 +114,7 @@ for more details. Install the following packages:
```
pacman -S patch coreutils
```
Once coreutils is installed, the realpath command should be present in your shell's path.
Once everything is installed. Open PowerShell, and make sure MSYS2 is in the
@ -144,12 +147,14 @@ sudo apt install miopen-hip hipfft-dev rocrand-dev hipsparse-dev hipsolver-dev \
AMD's fork of the XLA repository may include fixes
not present in the upstream repository. To use AMD's fork, you should clone
their repository:
```
git clone https://github.com/ROCmSoftwarePlatform/tensorflow-upstream.git
```
To build jaxlib with ROCM support, you can run the following build command,
suitably adjusted for your paths and ROCM version.
```
python build/build.py --enable_rocm --rocm_path=/opt/rocm-5.3.0 \
--bazel_options=--override_repository=xla=/path/to/xla-upstream
@ -178,6 +183,7 @@ or using pytest.
### Using Bazel
First, configure the JAX build by running:
```
python build/build.py --configure_only
```
@ -200,7 +206,6 @@ To use a preinstalled `jaxlib` instead of building `jaxlib` from source, run
bazel test --//jax:build_jaxlib=false //tests:cpu_tests //tests:backend_independent_tests
```
A number of test behaviors can be controlled using environment variables (see
below). Environment variables may be passed to JAX tests using the
`--test_env=FLAG=value` flag to Bazel.
@ -240,11 +245,14 @@ cases that are generated and checked for each test (default is 10) using the
currently use 25 by default.
For example, one might write
```
# Bazel
bazel test //tests/... --test_env=JAX_NUM_GENERATED_CASES=25`
```
or
```
# pytest
JAX_NUM_GENERATED_CASES=25 pytest -n auto tests
@ -283,14 +291,18 @@ The Colab notebooks are tested for errors as part of the documentation build.
JAX uses pytest in doctest mode to test the code examples within the documentation.
You can run this using
```
pytest docs
```
Additionally, JAX runs pytest in `doctest-modules` mode to ensure code examples in
function docstrings will run correctly. You can run this locally using, for example:
```
pytest --doctest-modules jax/_src/numpy/lax_numpy.py
```
Keep in mind that there are several files that are marked to be skipped when the
doctest command is run on the full package; you can see the details in
[`ci-build.yaml`](https://github.com/google/jax/blob/main/.github/workflows/ci-build.yaml)
@ -302,7 +314,7 @@ as the CI checks them:
```
pip install mypy
mypy --config=mypy.ini --show-error-codes jax
mypy --config=pyproject.toml --show-error-codes jax
```
Alternatively, you can use the [pre-commit](https://pre-commit.com/) framework to run this
@ -334,18 +346,24 @@ pre-commit run flake8
## Update documentation
To rebuild the documentation, install several packages:
```
pip install -r docs/requirements.txt
```
And then run:
```
sphinx-build -b html docs docs/build/html -j auto
```
This can take a long time because it executes many of the notebooks in the documentation source;
if you'd prefer to build the docs without executing the notebooks, you can run:
```
sphinx-build -b html -D nb_execution_mode=off docs docs/build/html -j auto
```
You can then see the generated documentation in `docs/build/html/index.html`.
The `-j auto` option controls the parallelism of the build. You can use a number

View File

@ -1,55 +0,0 @@
[mypy]
show_error_codes = True
disable_error_code = attr-defined
no_implicit_optional = True
[mypy-absl.*]
ignore_missing_imports = True
[mypy-colorama.*]
ignore_missing_imports = True
[mypy-numpy.*]
ignore_missing_imports = True
[mypy-opt_einsum.*]
ignore_missing_imports = True
[mypy-scipy.*]
ignore_missing_imports = True
[mypy-jax.interpreters.autospmd]
ignore_errors = True
[mypy-jax.lax.lax_parallel]
ignore_errors = True
[mypy-jax.experimental.jax2tf.tests.primitive_harness]
ignore_errors = True
[mypy-libtpu.*]
ignore_missing_imports = True
[mypy-jaxlib.mlir.*]
ignore_missing_imports = True
[mypy-iree.*]
ignore_missing_imports = True
[mypy-rich.*]
ignore_missing_imports = True
[mypy-optax.*]
ignore_missing_imports = True
[mypy-flax.*]
ignore_missing_imports = True
[mypy-tensorflow.*]
ignore_missing_imports = True
[mypy-tensorflowjs.*]
ignore_missing_imports = True
[mypy-tensorflow.io.*]
ignore_missing_imports = True
[mypy-tensorstore.*]
ignore_missing_imports = True
[mypy-web_pdb.*]
ignore_missing_imports = True
[mypy-etils.*]
ignore_missing_imports = True
[mypy-google.colab.*]
ignore_missing_imports = True
[mypy-pygments.*]
ignore_missing_imports = True
[mypy-jraph.*]
ignore_missing_imports = True
[mypy-matplotlib.*]
ignore_missing_imports = True
[mypy-tensorboard_plugin_profile.convert.*]
ignore_missing_imports = True

View File

@ -1,46 +0,0 @@
[MASTER]
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code
extension-pkg-whitelist=numpy
[MESSAGES CONTROL]
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once).You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use"--disable=all --enable=classes
# --disable=W"
disable=missing-docstring,
too-many-locals,
invalid-name,
redefined-outer-name,
redefined-builtin,
protected-name,
no-else-return,
fixme,
protected-access,
too-many-arguments,
blacklisted-name,
too-few-public-methods,
unnecessary-lambda,
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
enable=c-extension-no-member
[FORMAT]
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
indent-string=" "

View File

@ -1,3 +1,107 @@
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"
[tool.mypy]
show_error_codes = true
disable_error_code = "attr-defined"
no_implicit_optional = true
[[tool.mypy.overrides]]
module = [
"absl.*",
"colorama.*",
"IPython.*",
"numpy.*",
"opt_einsum.*",
"scipy.*",
"libtpu.*",
"jaxlib.mlir.*",
"iree.*",
"rich.*",
"optax.*",
"flax.*",
"tensorflow.*",
"tensorflowjs.*",
"tensorflow.io.*",
"tensorstore.*",
"web_pdb.*",
"etils.*",
"google.colab.*",
"pygments.*",
"jraph.*",
"matplotlib.*",
"tensorboard_plugin_profile.convert.*",
"jaxlib.*",
"pytest.*",
"ml_dtypes",
"jax.experimental.jax2tf.tests.flax_models",
"jax.experimental.jax2tf.tests.back_compat_testdata"
]
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = [
"jax.interpreters.autospmd",
"jax.lax.lax_parallel",
"jax.experimental.jax2tf.tests.primitive_harness"
]
ignore_errors = true
[tool.pytest.ini_options]
markers = [
"multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators",
"SlurmMultiNodeGpuTest: mark a test for Slurm multinode GPU nightly CI"
]
filterwarnings = [
"error",
"ignore:No GPU/TPU found, falling back to CPU.:UserWarning",
"ignore:outfeed_receiver is unnecessary and deprecated:DeprecationWarning",
"ignore:xmap is an experimental feature and probably has bugs!",
"ignore:the imp module is deprecated in favour of importlib.*:DeprecationWarning",
"ignore:can't resolve package from __spec__ or __package__:ImportWarning",
"ignore:Using or importing the ABCs.*:DeprecationWarning",
"ignore:numpy.ufunc size changed",
"ignore:.*experimental feature",
"ignore:index.*is deprecated.*:DeprecationWarning",
"ignore:jax.experimental.* is deprecated, import jax.example_libraries.* instead:FutureWarning",
"ignore:The distutils.* is deprecated.*:DeprecationWarning",
"ignore:`sharded_jit` is deprecated. Please use `pjit` instead.*:DeprecationWarning",
"default:Error reading persistent compilation cache entry for 'jit__lambda_'",
"default:Error writing persistent compilation cache entry for 'jit__lambda_'",
"ignore:DeviceArray, ShardedDeviceArray, and GlobalDeviceArray have been deprecated.*:DeprecationWarning",
"ignore:backend and device argument on jit is deprecated.*:DeprecationWarning",
"ignore:GlobalDeviceArray has been deprecated.*:DeprecationWarning",
"ignore:jax.interpreters.pxla.make_sharded_device_array is deprecated.*:DeprecationWarning",
# TODO(skyewm, yashkatariya): Remove after jaxlib 0.4.7 is released.
"ignore:jax.interpreters.pxla.ShardedDeviceArray is deprecated.*:DeprecationWarning"
]
doctest_optionflags = [
"NUMBER",
"NORMALIZE_WHITESPACE"
]
addopts = "--doctest-glob='*.rst'"
[tool.pylint.master]
extension-pkg-whitelist = "numpy"
[tool.pylint."messages control"]
disable = [
"missing-docstring",
"too-many-locals",
"invalid-name",
"redefined-outer-name",
"redefined-builtin",
"protected-name",
"no-else-return",
"fixme",
"protected-access",
"too-many-arguments",
"blacklisted-name",
"too-few-public-methods",
"unnecessary-lambda"
]
enable = "c-extension-no-member"
[tool.pylint.format]
indent-string=" "

View File

@ -1,33 +0,0 @@
[pytest]
markers =
multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators
SlurmMultiNodeGpuTest: mark a test for Slurm multinode GPU nightly CI
filterwarnings =
error
ignore:No GPU/TPU found, falling back to CPU.:UserWarning
ignore:outfeed_receiver is unnecessary and deprecated:DeprecationWarning
# xmap
ignore:xmap is an experimental feature and probably has bugs!
# The rest are for experimental/jax_to_tf
ignore:the imp module is deprecated in favour of importlib.*:DeprecationWarning
ignore:can't resolve package from __spec__ or __package__:ImportWarning
ignore:Using or importing the ABCs.*:DeprecationWarning
# jax2tf tests due to mix of JAX and TF
ignore:numpy.ufunc size changed
ignore:.*experimental feature
ignore:index.*is deprecated.*:DeprecationWarning
ignore:jax.experimental.* is deprecated, import jax.example_libraries.* instead:FutureWarning
# numpy uses distutils which is deprecated
ignore:The distutils.* is deprecated.*:DeprecationWarning
ignore:`sharded_jit` is deprecated. Please use `pjit` instead.*:DeprecationWarning
# Print message for compilation_cache_test.py::CompilationCacheTest::test_cache_read/write_warning
default:Error reading persistent compilation cache entry for 'jit__lambda_'
default:Error writing persistent compilation cache entry for 'jit__lambda_'
ignore:DeviceArray, ShardedDeviceArray, and GlobalDeviceArray have been deprecated.*:DeprecationWarning
ignore:backend and device argument on jit is deprecated.*:DeprecationWarning
ignore:GlobalDeviceArray has been deprecated.*:DeprecationWarning
ignore:jax.interpreters.pxla.make_sharded_device_array is deprecated.*:DeprecationWarning
# TODO(skyewm, yashkatariya): Remove after jaxlib 0.4.7 is released.
ignore:jax.interpreters.pxla.ShardedDeviceArray is deprecated.*:DeprecationWarning
doctest_optionflags = NUMBER NORMALIZE_WHITESPACE
addopts = --doctest-glob="*.rst"