Document how to run tests using Bazel.

* Add a new --configure_only option to build.py to allow build.py to generate a .bazelrc without necessarily building jaxlib.
* Add a bazel flag that make the dependency of //jax on //jaxlib optional. If //jaxlib isn't built by bazel, then tests will implicitly use a preinstalled jaxlib.
This commit is contained in:
Peter Hawkins 2022-07-03 15:04:37 -04:00
parent 118db407f2
commit 1c75eee1ff
5 changed files with 131 additions and 53 deletions

View File

@ -213,11 +213,13 @@ def get_bazel_version(bazel_path):
return tuple(int(x) for x in match.group(1).split("."))
def write_bazelrc(python_bin_path=None, remote_build=None,
cuda_toolkit_path=None, cudnn_install_path=None,
cuda_version=None, cudnn_version=None, rocm_toolkit_path=None,
cpu=None, cuda_compute_capabilities=None,
rocm_amdgpu_targets=None):
def write_bazelrc(*, python_bin_path, remote_build,
cuda_toolkit_path, cudnn_install_path,
cuda_version, cudnn_version, rocm_toolkit_path,
cpu, cuda_compute_capabilities,
rocm_amdgpu_targets, bazel_options, target_cpu_features,
wheel_cpu, enable_mkl_dnn, enable_cuda, enable_nccl,
enable_tpu, enable_remote_tpu, enable_rocm):
tf_cuda_paths = []
with open("../.jax_configure.bazelrc", "w") as f:
@ -263,6 +265,32 @@ def write_bazelrc(python_bin_path=None, remote_build=None,
else:
f.write("build --distinct_host_configuration=false\n")
for o in bazel_options:
f.write(f"common {o}\n")
if target_cpu_features == "release":
if wheel_cpu == "x86_64":
f.write("build --config=avx_windows\n" if is_windows()
else "build --config=avx_posix\n")
elif target_cpu_features == "native":
if is_windows():
print("--target_cpu_features=native is not supported on Windows; ignoring.")
else:
f.write("build --config=native_arch_posix\n")
if enable_mkl_dnn:
f.write("build --config=mkl_open_source_only\n")
if enable_cuda:
f.write("build --config=cuda\n")
if not enable_nccl:
f.write("build --config=nonccl\n")
if enable_tpu:
f.write("build --config=tpu\n")
if enable_remote_tpu:
f.write("build --//build:enable_remote_tpu=true\n")
if enable_rocm:
f.write("build --config=rocm\n")
if not enable_nccl:
f.write("build --config=nonccl\n")
BANNER = r"""
_ _ __ __
@ -362,7 +390,7 @@ def main():
parser,
"remote_build",
default=False,
help_str="Should we build with RBE.")
help_str="Should we build with RBE (Remote Build Environment)?")
parser.add_argument(
"--cuda_path",
default=None,
@ -410,6 +438,11 @@ def main():
default=None,
help="CPU platform to target. Default is the same as the host machine. "
"Currently supported values are 'darwin_arm64' and 'darwin_x86_64'.")
add_boolean_argument(
parser,
"configure_only",
default=False,
help_str="If true, writes a .bazelrc file but does not build jaxlib.")
args = parser.parse_args()
if is_windows() and args.enable_cuda:
@ -491,38 +524,25 @@ def main():
cpu=args.target_cpu,
cuda_compute_capabilities=args.cuda_compute_capabilities,
rocm_amdgpu_targets=args.rocm_amdgpu_targets,
bazel_options=args.bazel_options,
target_cpu_features=args.target_cpu_features,
wheel_cpu=wheel_cpu,
enable_mkl_dnn=args.enable_mkl_dnn,
enable_cuda=args.enable_cuda,
enable_nccl=args.enable_nccl,
enable_tpu=args.enable_tpu,
enable_remote_tpu=args.enable_remote_tpu,
enable_rocm=args.enable_rocm,
)
if args.configure_only:
return
print("\nBuilding XLA and installing it in the jaxlib source tree...")
config_args = args.bazel_options
if args.target_cpu_features == "release":
if wheel_cpu == "x86_64":
config_args += ["--config=avx_windows" if is_windows()
else "--config=avx_posix"]
elif args.target_cpu_features == "native":
if is_windows():
print("--target_cpu_features=native is not supported on Windows; ignoring.")
else:
config_args += ["--config=native_arch_posix"]
if args.enable_mkl_dnn:
config_args += ["--config=mkl_open_source_only"]
if args.enable_cuda:
config_args += ["--config=cuda"]
if not args.enable_nccl:
config_args += ["--config=nonccl"]
if args.enable_tpu:
config_args += ["--config=tpu"]
if args.enable_remote_tpu:
config_args += ["--//build:enable_remote_tpu=true"]
if args.enable_rocm:
config_args += ["--config=rocm"]
if not args.enable_nccl:
config_args += ["--config=nonccl"]
command = ([bazel_path] + args.bazel_startup_options +
["run", "--verbose_failures=true"] + config_args +
["run", "--verbose_failures=true"] +
[":build_wheel", "--",
f"--output_path={output_path}",
f"--cpu={wheel_cpu}"])

View File

@ -147,24 +147,68 @@ sets up symbolic links from site-packages into the repository.
# Running the tests
To run all the JAX tests, we recommend using `pytest-xdist`, which can run tests in
parallel. First, install `pytest-xdist` and `pytest-benchmark` by running
`pip install -r build/test-requirements.txt`.
There are two supported mechanisms for running the JAX tests, either using Bazel
or using pytest.
## Using Bazel
First, configure the JAX build by running:
```
python build/build.py --configure_only
```
You may pass additional options to `build.py` to configure the build; see the
`jaxlib` build documentation for details.
By default the Bazel build runs the JAX tests using `jaxlib` built form source.
To run JAX tests, run:
```
bazel test //tests/...
```
To use a preinstalled `jaxlib` instead of building `jaxlib` from source, run
```
bazel test --//jax:build_jaxlib=false //tests/...
```
A number of test behaviors can be controlled using environment variables (see
below). Environment variables may be passed to JAX tests using the
`--test_env=FLAG=value` flag to Bazel.
## Using pytest
To run all the JAX tests using `pytest`, we recommend using `pytest-xdist`,
which can run tests in parallel. First, install `pytest-xdist` and
`pytest-benchmark` by running `pip install -r build/test-requirements.txt`.
Then, from the repository root directory run:
```
pytest -n auto tests
```
JAX generates test cases combinatorially, and you can control the number of
cases that are generated and checked for each test (default is 10). The automated tests
currently use 25:
## Controlling test behavior
JAX generates test cases combinatorially, and you can control the number of
cases that are generated and checked for each test (default is 10) using the
`JAX_NUM_GENERATED_CASES` environment variable. The automated tests
currently use 25 by default.
For example, one might write
```
# Bazel
bazel test //tests/... --test_env=JAX_NUM_GENERATED_CASES=25`
```
or
```
# pytest
JAX_NUM_GENERATED_CASES=25 pytest -n auto tests
```
The automated tests also run the tests with default 64-bit floats and ints:
The automated tests also run the tests with default 64-bit floats and ints
(`JAX_ENABLE_X64`):
```
JAX_ENABLE_X64=1 JAX_NUM_GENERATED_CASES=25 pytest -n auto tests
@ -179,7 +223,7 @@ file directly to see more detailed information about the cases being run:
python tests/lax_numpy_test.py --num_generated_cases=5
```
You can skip a few tests known as slow, by passing environment variable
You can skip a few tests known to be slow, by passing environment variable
JAX_SKIP_SLOW_TESTS=1.
To specify a particular set of tests to run from a test file, you can pass a string
@ -192,16 +236,6 @@ python tests/lax_numpy_test.py --test_targets="testPad"
The Colab notebooks are tested for errors as part of the documentation build.
Note that to run the full pmap tests on a (multi-core) CPU-only machine, you
can run:
```
pytest tests/pmap_tests.py
```
I.e. don't use the `-n auto` option, since that effectively runs each test on a
single-core worker.
## Doctests
JAX uses pytest in doctest mode to test the code examples within the documentation.
You can run this using

View File

@ -35,4 +35,5 @@ tf_cc_binary(
"@org_tensorflow//tensorflow/core/platform:logging",
"@org_tensorflow//tensorflow/core/platform:platform_port",
],
tags = ["manual"],
)

View File

@ -29,10 +29,25 @@ load(
"sharded_jit_visibility",
)
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
licenses(["notice"])
package(default_visibility = [":internal"])
bool_flag(
name = "build_jaxlib",
build_setting_default = True,
)
config_setting(
name = "enable_jaxlib_build",
flag_values = {
":build_jaxlib": "True",
},
)
exports_files([
"LICENSE",
"version.py",
@ -104,9 +119,11 @@ py_library_providing_imports_info(
],
lib_rule = pytype_library,
visibility = ["//visibility:public"],
deps = [
"//jaxlib",
] + numpy_py_deps + scipy_py_deps + jax_extra_deps,
deps = select({
":enable_jaxlib_build": ["//jaxlib"],
"//conditions:default": [],
}) +
numpy_py_deps + scipy_py_deps + jax_extra_deps,
)
py_library_providing_imports_info(

View File

@ -77,6 +77,7 @@ def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []):
linkshared = 1,
linkstatic = 1,
deps = deps,
target_compatible_with = ["@platforms//os:windows"],
)
# .def file with all symbols, not usable
@ -85,6 +86,7 @@ def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []):
name = full_def_name,
srcs = [dummy_library_name],
output_group = "def_file",
target_compatible_with = ["@platforms//os:windows"],
)
# filtered def_file, only the needed symbols are included
@ -95,6 +97,7 @@ def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []):
srcs = [full_def_name],
outs = [filtered_def_file],
cmd = """echo 'LIBRARY {}\nEXPORTS ' > $@ && grep '^\\W*mlir' $(location :{}) >> $@""".format(out, full_def_name),
target_compatible_with = ["@platforms//os:windows"],
)
# create the desired library
@ -103,6 +106,7 @@ def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []):
linkshared = 1,
deps = deps,
win_def_file = filtered_def_file,
target_compatible_with = ["@platforms//os:windows"],
)
# however, the created cc_library (a shared library) cannot be correctly
@ -112,6 +116,7 @@ def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []):
name = interface_library_file,
srcs = [out],
output_group = "interface_library",
target_compatible_with = ["@platforms//os:windows"],
)
# but this one can be correctly consumed, this is our final product
@ -119,6 +124,7 @@ def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []):
name = name,
interface_library = interface_library_file,
shared_library = out,
target_compatible_with = ["@platforms//os:windows"],
)
def jax_test(