mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #15585 from SauravMaheshkar:main
PiperOrigin-RevId: 524896086
This commit is contained in:
commit
9c92ec5d90
@ -33,6 +33,7 @@ repos:
|
||||
files: (jax/|tests/typing_test\.py)
|
||||
exclude: jax/_src/basearray.py # Use pyi instead
|
||||
additional_dependencies: [types-requests==2.28.11, jaxlib==0.4.6, ml_dtypes==0.0.3, numpy==1.21.6, scipy==1.7.3]
|
||||
args: [--config=pyproject.toml]
|
||||
|
||||
- repo: https://github.com/mwouts/jupytext
|
||||
rev: v1.14.4
|
||||
|
@ -30,7 +30,7 @@ guidance on pip installation (e.g., for GPU and TPU support).
|
||||
|
||||
To build `jaxlib` from source, you must also install some prerequisites:
|
||||
|
||||
* a C++ compiler (g++, clang, or MSVC)
|
||||
- a C++ compiler (g++, clang, or MSVC)
|
||||
|
||||
On Ubuntu or Debian you can install the necessary prerequisites with:
|
||||
|
||||
@ -42,7 +42,8 @@ To build `jaxlib` from source, you must also install some prerequisites:
|
||||
are installed.
|
||||
|
||||
See below for Windows build instructions.
|
||||
* Python packages: `numpy`, `wheel`.
|
||||
|
||||
- Python packages: `numpy`, `wheel`.
|
||||
|
||||
You can install the necessary Python dependencies using `pip`:
|
||||
|
||||
@ -74,13 +75,14 @@ By default JAX uses a pinned copy of the XLA repository, but we often
|
||||
want to use a locally-modified copy of XLA when working on JAX. There are two
|
||||
ways to do this:
|
||||
|
||||
* use Bazel's `override_repository` feature, which you can pass as a command
|
||||
- use Bazel's `override_repository` feature, which you can pass as a command
|
||||
line flag to `build.py` as follows:
|
||||
|
||||
```
|
||||
python build/build.py --bazel_options=--override_repository=xla=/path/to/xla
|
||||
```
|
||||
* modify the `WORKSPACE` file in the root of the JAX source tree to point to
|
||||
|
||||
- modify the `WORKSPACE` file in the root of the JAX source tree to point to
|
||||
a different XLA tree.
|
||||
|
||||
To contribute changes back to XLA, send PRs to the XLA repository.
|
||||
@ -112,6 +114,7 @@ for more details. Install the following packages:
|
||||
```
|
||||
pacman -S patch coreutils
|
||||
```
|
||||
|
||||
Once coreutils is installed, the realpath command should be present in your shell's path.
|
||||
|
||||
Once everything is installed. Open PowerShell, and make sure MSYS2 is in the
|
||||
@ -144,12 +147,14 @@ sudo apt install miopen-hip hipfft-dev rocrand-dev hipsparse-dev hipsolver-dev \
|
||||
AMD's fork of the XLA repository may include fixes
|
||||
not present in the upstream repository. To use AMD's fork, you should clone
|
||||
their repository:
|
||||
|
||||
```
|
||||
git clone https://github.com/ROCmSoftwarePlatform/tensorflow-upstream.git
|
||||
```
|
||||
|
||||
To build jaxlib with ROCM support, you can run the following build command,
|
||||
suitably adjusted for your paths and ROCM version.
|
||||
|
||||
```
|
||||
python build/build.py --enable_rocm --rocm_path=/opt/rocm-5.3.0 \
|
||||
--bazel_options=--override_repository=xla=/path/to/xla-upstream
|
||||
@ -178,6 +183,7 @@ or using pytest.
|
||||
### Using Bazel
|
||||
|
||||
First, configure the JAX build by running:
|
||||
|
||||
```
|
||||
python build/build.py --configure_only
|
||||
```
|
||||
@ -200,7 +206,6 @@ To use a preinstalled `jaxlib` instead of building `jaxlib` from source, run
|
||||
bazel test --//jax:build_jaxlib=false //tests:cpu_tests //tests:backend_independent_tests
|
||||
```
|
||||
|
||||
|
||||
A number of test behaviors can be controlled using environment variables (see
|
||||
below). Environment variables may be passed to JAX tests using the
|
||||
`--test_env=FLAG=value` flag to Bazel.
|
||||
@ -240,11 +245,14 @@ cases that are generated and checked for each test (default is 10) using the
|
||||
currently use 25 by default.
|
||||
|
||||
For example, one might write
|
||||
|
||||
```
|
||||
# Bazel
|
||||
bazel test //tests/... --test_env=JAX_NUM_GENERATED_CASES=25`
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```
|
||||
# pytest
|
||||
JAX_NUM_GENERATED_CASES=25 pytest -n auto tests
|
||||
@ -283,14 +291,18 @@ The Colab notebooks are tested for errors as part of the documentation build.
|
||||
|
||||
JAX uses pytest in doctest mode to test the code examples within the documentation.
|
||||
You can run this using
|
||||
|
||||
```
|
||||
pytest docs
|
||||
```
|
||||
|
||||
Additionally, JAX runs pytest in `doctest-modules` mode to ensure code examples in
|
||||
function docstrings will run correctly. You can run this locally using, for example:
|
||||
|
||||
```
|
||||
pytest --doctest-modules jax/_src/numpy/lax_numpy.py
|
||||
```
|
||||
|
||||
Keep in mind that there are several files that are marked to be skipped when the
|
||||
doctest command is run on the full package; you can see the details in
|
||||
[`ci-build.yaml`](https://github.com/google/jax/blob/main/.github/workflows/ci-build.yaml)
|
||||
@ -302,7 +314,7 @@ as the CI checks them:
|
||||
|
||||
```
|
||||
pip install mypy
|
||||
mypy --config=mypy.ini --show-error-codes jax
|
||||
mypy --config=pyproject.toml --show-error-codes jax
|
||||
```
|
||||
|
||||
Alternatively, you can use the [pre-commit](https://pre-commit.com/) framework to run this
|
||||
@ -334,18 +346,24 @@ pre-commit run flake8
|
||||
## Update documentation
|
||||
|
||||
To rebuild the documentation, install several packages:
|
||||
|
||||
```
|
||||
pip install -r docs/requirements.txt
|
||||
```
|
||||
|
||||
And then run:
|
||||
|
||||
```
|
||||
sphinx-build -b html docs docs/build/html -j auto
|
||||
```
|
||||
|
||||
This can take a long time because it executes many of the notebooks in the documentation source;
|
||||
if you'd prefer to build the docs without executing the notebooks, you can run:
|
||||
|
||||
```
|
||||
sphinx-build -b html -D nb_execution_mode=off docs docs/build/html -j auto
|
||||
```
|
||||
|
||||
You can then see the generated documentation in `docs/build/html/index.html`.
|
||||
|
||||
The `-j auto` option controls the parallelism of the build. You can use a number
|
||||
|
55
mypy.ini
55
mypy.ini
@ -1,55 +0,0 @@
|
||||
[mypy]
|
||||
show_error_codes = True
|
||||
disable_error_code = attr-defined
|
||||
no_implicit_optional = True
|
||||
|
||||
[mypy-absl.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-colorama.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-numpy.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-opt_einsum.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-scipy.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-jax.interpreters.autospmd]
|
||||
ignore_errors = True
|
||||
[mypy-jax.lax.lax_parallel]
|
||||
ignore_errors = True
|
||||
[mypy-jax.experimental.jax2tf.tests.primitive_harness]
|
||||
ignore_errors = True
|
||||
[mypy-libtpu.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-jaxlib.mlir.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-iree.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-rich.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-optax.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-flax.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-tensorflow.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-tensorflowjs.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-tensorflow.io.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-tensorstore.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-web_pdb.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-etils.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-google.colab.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-pygments.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-jraph.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-matplotlib.*]
|
||||
ignore_missing_imports = True
|
||||
[mypy-tensorboard_plugin_profile.convert.*]
|
||||
ignore_missing_imports = True
|
46
pylintrc
46
pylintrc
@ -1,46 +0,0 @@
|
||||
[MASTER]
|
||||
|
||||
# A comma-separated list of package or module names from where C extensions may
|
||||
# be loaded. Extensions are loading into the active Python interpreter and may
|
||||
# run arbitrary code
|
||||
extension-pkg-whitelist=numpy
|
||||
|
||||
|
||||
[MESSAGES CONTROL]
|
||||
|
||||
# Disable the message, report, category or checker with the given id(s). You
|
||||
# can either give multiple identifiers separated by comma (,) or put this
|
||||
# option multiple times (only on the command line, not in the configuration
|
||||
# file where it should appear only once).You can also use "--disable=all" to
|
||||
# disable everything first and then reenable specific checks. For example, if
|
||||
# you want to run only the similarities checker, you can use "--disable=all
|
||||
# --enable=similarities". If you want to run only the classes checker, but have
|
||||
# no Warning level messages displayed, use"--disable=all --enable=classes
|
||||
# --disable=W"
|
||||
disable=missing-docstring,
|
||||
too-many-locals,
|
||||
invalid-name,
|
||||
redefined-outer-name,
|
||||
redefined-builtin,
|
||||
protected-name,
|
||||
no-else-return,
|
||||
fixme,
|
||||
protected-access,
|
||||
too-many-arguments,
|
||||
blacklisted-name,
|
||||
too-few-public-methods,
|
||||
unnecessary-lambda,
|
||||
|
||||
|
||||
# Enable the message, report, category or checker with the given id(s). You can
|
||||
# either give multiple identifier separated by comma (,) or put this option
|
||||
# multiple time (only on the command line, not in the configuration file where
|
||||
# it should appear only once). See also the "--disable" option for examples.
|
||||
enable=c-extension-no-member
|
||||
|
||||
|
||||
[FORMAT]
|
||||
|
||||
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
|
||||
# tab).
|
||||
indent-string=" "
|
104
pyproject.toml
104
pyproject.toml
@ -1,3 +1,107 @@
|
||||
[build-system]
|
||||
requires = ["setuptools", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.mypy]
|
||||
show_error_codes = true
|
||||
disable_error_code = "attr-defined"
|
||||
no_implicit_optional = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = [
|
||||
"absl.*",
|
||||
"colorama.*",
|
||||
"IPython.*",
|
||||
"numpy.*",
|
||||
"opt_einsum.*",
|
||||
"scipy.*",
|
||||
"libtpu.*",
|
||||
"jaxlib.mlir.*",
|
||||
"iree.*",
|
||||
"rich.*",
|
||||
"optax.*",
|
||||
"flax.*",
|
||||
"tensorflow.*",
|
||||
"tensorflowjs.*",
|
||||
"tensorflow.io.*",
|
||||
"tensorstore.*",
|
||||
"web_pdb.*",
|
||||
"etils.*",
|
||||
"google.colab.*",
|
||||
"pygments.*",
|
||||
"jraph.*",
|
||||
"matplotlib.*",
|
||||
"tensorboard_plugin_profile.convert.*",
|
||||
"jaxlib.*",
|
||||
"pytest.*",
|
||||
"ml_dtypes",
|
||||
"jax.experimental.jax2tf.tests.flax_models",
|
||||
"jax.experimental.jax2tf.tests.back_compat_testdata"
|
||||
]
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = [
|
||||
"jax.interpreters.autospmd",
|
||||
"jax.lax.lax_parallel",
|
||||
"jax.experimental.jax2tf.tests.primitive_harness"
|
||||
]
|
||||
ignore_errors = true
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
markers = [
|
||||
"multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators",
|
||||
"SlurmMultiNodeGpuTest: mark a test for Slurm multinode GPU nightly CI"
|
||||
]
|
||||
filterwarnings = [
|
||||
"error",
|
||||
"ignore:No GPU/TPU found, falling back to CPU.:UserWarning",
|
||||
"ignore:outfeed_receiver is unnecessary and deprecated:DeprecationWarning",
|
||||
"ignore:xmap is an experimental feature and probably has bugs!",
|
||||
"ignore:the imp module is deprecated in favour of importlib.*:DeprecationWarning",
|
||||
"ignore:can't resolve package from __spec__ or __package__:ImportWarning",
|
||||
"ignore:Using or importing the ABCs.*:DeprecationWarning",
|
||||
"ignore:numpy.ufunc size changed",
|
||||
"ignore:.*experimental feature",
|
||||
"ignore:index.*is deprecated.*:DeprecationWarning",
|
||||
"ignore:jax.experimental.* is deprecated, import jax.example_libraries.* instead:FutureWarning",
|
||||
"ignore:The distutils.* is deprecated.*:DeprecationWarning",
|
||||
"ignore:`sharded_jit` is deprecated. Please use `pjit` instead.*:DeprecationWarning",
|
||||
"default:Error reading persistent compilation cache entry for 'jit__lambda_'",
|
||||
"default:Error writing persistent compilation cache entry for 'jit__lambda_'",
|
||||
"ignore:DeviceArray, ShardedDeviceArray, and GlobalDeviceArray have been deprecated.*:DeprecationWarning",
|
||||
"ignore:backend and device argument on jit is deprecated.*:DeprecationWarning",
|
||||
"ignore:GlobalDeviceArray has been deprecated.*:DeprecationWarning",
|
||||
"ignore:jax.interpreters.pxla.make_sharded_device_array is deprecated.*:DeprecationWarning",
|
||||
# TODO(skyewm, yashkatariya): Remove after jaxlib 0.4.7 is released.
|
||||
"ignore:jax.interpreters.pxla.ShardedDeviceArray is deprecated.*:DeprecationWarning"
|
||||
]
|
||||
doctest_optionflags = [
|
||||
"NUMBER",
|
||||
"NORMALIZE_WHITESPACE"
|
||||
]
|
||||
addopts = "--doctest-glob='*.rst'"
|
||||
|
||||
[tool.pylint.master]
|
||||
extension-pkg-whitelist = "numpy"
|
||||
|
||||
[tool.pylint."messages control"]
|
||||
disable = [
|
||||
"missing-docstring",
|
||||
"too-many-locals",
|
||||
"invalid-name",
|
||||
"redefined-outer-name",
|
||||
"redefined-builtin",
|
||||
"protected-name",
|
||||
"no-else-return",
|
||||
"fixme",
|
||||
"protected-access",
|
||||
"too-many-arguments",
|
||||
"blacklisted-name",
|
||||
"too-few-public-methods",
|
||||
"unnecessary-lambda"
|
||||
]
|
||||
enable = "c-extension-no-member"
|
||||
|
||||
[tool.pylint.format]
|
||||
indent-string=" "
|
||||
|
33
pytest.ini
33
pytest.ini
@ -1,33 +0,0 @@
|
||||
[pytest]
|
||||
markers =
|
||||
multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators
|
||||
SlurmMultiNodeGpuTest: mark a test for Slurm multinode GPU nightly CI
|
||||
filterwarnings =
|
||||
error
|
||||
ignore:No GPU/TPU found, falling back to CPU.:UserWarning
|
||||
ignore:outfeed_receiver is unnecessary and deprecated:DeprecationWarning
|
||||
# xmap
|
||||
ignore:xmap is an experimental feature and probably has bugs!
|
||||
# The rest are for experimental/jax_to_tf
|
||||
ignore:the imp module is deprecated in favour of importlib.*:DeprecationWarning
|
||||
ignore:can't resolve package from __spec__ or __package__:ImportWarning
|
||||
ignore:Using or importing the ABCs.*:DeprecationWarning
|
||||
# jax2tf tests due to mix of JAX and TF
|
||||
ignore:numpy.ufunc size changed
|
||||
ignore:.*experimental feature
|
||||
ignore:index.*is deprecated.*:DeprecationWarning
|
||||
ignore:jax.experimental.* is deprecated, import jax.example_libraries.* instead:FutureWarning
|
||||
# numpy uses distutils which is deprecated
|
||||
ignore:The distutils.* is deprecated.*:DeprecationWarning
|
||||
ignore:`sharded_jit` is deprecated. Please use `pjit` instead.*:DeprecationWarning
|
||||
# Print message for compilation_cache_test.py::CompilationCacheTest::test_cache_read/write_warning
|
||||
default:Error reading persistent compilation cache entry for 'jit__lambda_'
|
||||
default:Error writing persistent compilation cache entry for 'jit__lambda_'
|
||||
ignore:DeviceArray, ShardedDeviceArray, and GlobalDeviceArray have been deprecated.*:DeprecationWarning
|
||||
ignore:backend and device argument on jit is deprecated.*:DeprecationWarning
|
||||
ignore:GlobalDeviceArray has been deprecated.*:DeprecationWarning
|
||||
ignore:jax.interpreters.pxla.make_sharded_device_array is deprecated.*:DeprecationWarning
|
||||
# TODO(skyewm, yashkatariya): Remove after jaxlib 0.4.7 is released.
|
||||
ignore:jax.interpreters.pxla.ShardedDeviceArray is deprecated.*:DeprecationWarning
|
||||
doctest_optionflags = NUMBER NORMALIZE_WHITESPACE
|
||||
addopts = --doctest-glob="*.rst"
|
Loading…
x
Reference in New Issue
Block a user