mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Document how to run tests using Bazel.
* Add a new --configure_only option to build.py to allow build.py to generate a .bazelrc without necessarily building jaxlib. * Add a bazel flag that make the dependency of //jax on //jaxlib optional. If //jaxlib isn't built by bazel, then tests will implicitly use a preinstalled jaxlib.
This commit is contained in:
parent
118db407f2
commit
1c75eee1ff
@ -213,11 +213,13 @@ def get_bazel_version(bazel_path):
|
||||
return tuple(int(x) for x in match.group(1).split("."))
|
||||
|
||||
|
||||
def write_bazelrc(python_bin_path=None, remote_build=None,
|
||||
cuda_toolkit_path=None, cudnn_install_path=None,
|
||||
cuda_version=None, cudnn_version=None, rocm_toolkit_path=None,
|
||||
cpu=None, cuda_compute_capabilities=None,
|
||||
rocm_amdgpu_targets=None):
|
||||
def write_bazelrc(*, python_bin_path, remote_build,
|
||||
cuda_toolkit_path, cudnn_install_path,
|
||||
cuda_version, cudnn_version, rocm_toolkit_path,
|
||||
cpu, cuda_compute_capabilities,
|
||||
rocm_amdgpu_targets, bazel_options, target_cpu_features,
|
||||
wheel_cpu, enable_mkl_dnn, enable_cuda, enable_nccl,
|
||||
enable_tpu, enable_remote_tpu, enable_rocm):
|
||||
tf_cuda_paths = []
|
||||
|
||||
with open("../.jax_configure.bazelrc", "w") as f:
|
||||
@ -263,6 +265,32 @@ def write_bazelrc(python_bin_path=None, remote_build=None,
|
||||
else:
|
||||
f.write("build --distinct_host_configuration=false\n")
|
||||
|
||||
for o in bazel_options:
|
||||
f.write(f"common {o}\n")
|
||||
if target_cpu_features == "release":
|
||||
if wheel_cpu == "x86_64":
|
||||
f.write("build --config=avx_windows\n" if is_windows()
|
||||
else "build --config=avx_posix\n")
|
||||
elif target_cpu_features == "native":
|
||||
if is_windows():
|
||||
print("--target_cpu_features=native is not supported on Windows; ignoring.")
|
||||
else:
|
||||
f.write("build --config=native_arch_posix\n")
|
||||
|
||||
if enable_mkl_dnn:
|
||||
f.write("build --config=mkl_open_source_only\n")
|
||||
if enable_cuda:
|
||||
f.write("build --config=cuda\n")
|
||||
if not enable_nccl:
|
||||
f.write("build --config=nonccl\n")
|
||||
if enable_tpu:
|
||||
f.write("build --config=tpu\n")
|
||||
if enable_remote_tpu:
|
||||
f.write("build --//build:enable_remote_tpu=true\n")
|
||||
if enable_rocm:
|
||||
f.write("build --config=rocm\n")
|
||||
if not enable_nccl:
|
||||
f.write("build --config=nonccl\n")
|
||||
|
||||
BANNER = r"""
|
||||
_ _ __ __
|
||||
@ -362,7 +390,7 @@ def main():
|
||||
parser,
|
||||
"remote_build",
|
||||
default=False,
|
||||
help_str="Should we build with RBE.")
|
||||
help_str="Should we build with RBE (Remote Build Environment)?")
|
||||
parser.add_argument(
|
||||
"--cuda_path",
|
||||
default=None,
|
||||
@ -410,6 +438,11 @@ def main():
|
||||
default=None,
|
||||
help="CPU platform to target. Default is the same as the host machine. "
|
||||
"Currently supported values are 'darwin_arm64' and 'darwin_x86_64'.")
|
||||
add_boolean_argument(
|
||||
parser,
|
||||
"configure_only",
|
||||
default=False,
|
||||
help_str="If true, writes a .bazelrc file but does not build jaxlib.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if is_windows() and args.enable_cuda:
|
||||
@ -491,38 +524,25 @@ def main():
|
||||
cpu=args.target_cpu,
|
||||
cuda_compute_capabilities=args.cuda_compute_capabilities,
|
||||
rocm_amdgpu_targets=args.rocm_amdgpu_targets,
|
||||
bazel_options=args.bazel_options,
|
||||
target_cpu_features=args.target_cpu_features,
|
||||
wheel_cpu=wheel_cpu,
|
||||
enable_mkl_dnn=args.enable_mkl_dnn,
|
||||
enable_cuda=args.enable_cuda,
|
||||
enable_nccl=args.enable_nccl,
|
||||
enable_tpu=args.enable_tpu,
|
||||
enable_remote_tpu=args.enable_remote_tpu,
|
||||
enable_rocm=args.enable_rocm,
|
||||
)
|
||||
|
||||
if args.configure_only:
|
||||
return
|
||||
|
||||
print("\nBuilding XLA and installing it in the jaxlib source tree...")
|
||||
|
||||
config_args = args.bazel_options
|
||||
if args.target_cpu_features == "release":
|
||||
if wheel_cpu == "x86_64":
|
||||
config_args += ["--config=avx_windows" if is_windows()
|
||||
else "--config=avx_posix"]
|
||||
elif args.target_cpu_features == "native":
|
||||
if is_windows():
|
||||
print("--target_cpu_features=native is not supported on Windows; ignoring.")
|
||||
else:
|
||||
config_args += ["--config=native_arch_posix"]
|
||||
|
||||
if args.enable_mkl_dnn:
|
||||
config_args += ["--config=mkl_open_source_only"]
|
||||
if args.enable_cuda:
|
||||
config_args += ["--config=cuda"]
|
||||
if not args.enable_nccl:
|
||||
config_args += ["--config=nonccl"]
|
||||
if args.enable_tpu:
|
||||
config_args += ["--config=tpu"]
|
||||
if args.enable_remote_tpu:
|
||||
config_args += ["--//build:enable_remote_tpu=true"]
|
||||
if args.enable_rocm:
|
||||
config_args += ["--config=rocm"]
|
||||
if not args.enable_nccl:
|
||||
config_args += ["--config=nonccl"]
|
||||
|
||||
command = ([bazel_path] + args.bazel_startup_options +
|
||||
["run", "--verbose_failures=true"] + config_args +
|
||||
["run", "--verbose_failures=true"] +
|
||||
[":build_wheel", "--",
|
||||
f"--output_path={output_path}",
|
||||
f"--cpu={wheel_cpu}"])
|
||||
|
@ -147,24 +147,68 @@ sets up symbolic links from site-packages into the repository.
|
||||
|
||||
# Running the tests
|
||||
|
||||
To run all the JAX tests, we recommend using `pytest-xdist`, which can run tests in
|
||||
parallel. First, install `pytest-xdist` and `pytest-benchmark` by running
|
||||
`pip install -r build/test-requirements.txt`.
|
||||
There are two supported mechanisms for running the JAX tests, either using Bazel
|
||||
or using pytest.
|
||||
|
||||
## Using Bazel
|
||||
|
||||
First, configure the JAX build by running:
|
||||
```
|
||||
python build/build.py --configure_only
|
||||
```
|
||||
|
||||
You may pass additional options to `build.py` to configure the build; see the
|
||||
`jaxlib` build documentation for details.
|
||||
|
||||
By default the Bazel build runs the JAX tests using `jaxlib` built form source.
|
||||
To run JAX tests, run:
|
||||
|
||||
```
|
||||
bazel test //tests/...
|
||||
```
|
||||
|
||||
To use a preinstalled `jaxlib` instead of building `jaxlib` from source, run
|
||||
|
||||
```
|
||||
bazel test --//jax:build_jaxlib=false //tests/...
|
||||
```
|
||||
|
||||
|
||||
A number of test behaviors can be controlled using environment variables (see
|
||||
below). Environment variables may be passed to JAX tests using the
|
||||
`--test_env=FLAG=value` flag to Bazel.
|
||||
|
||||
## Using pytest
|
||||
|
||||
To run all the JAX tests using `pytest`, we recommend using `pytest-xdist`,
|
||||
which can run tests in parallel. First, install `pytest-xdist` and
|
||||
`pytest-benchmark` by running `pip install -r build/test-requirements.txt`.
|
||||
Then, from the repository root directory run:
|
||||
|
||||
```
|
||||
pytest -n auto tests
|
||||
```
|
||||
|
||||
JAX generates test cases combinatorially, and you can control the number of
|
||||
cases that are generated and checked for each test (default is 10). The automated tests
|
||||
currently use 25:
|
||||
## Controlling test behavior
|
||||
|
||||
JAX generates test cases combinatorially, and you can control the number of
|
||||
cases that are generated and checked for each test (default is 10) using the
|
||||
`JAX_NUM_GENERATED_CASES` environment variable. The automated tests
|
||||
currently use 25 by default.
|
||||
|
||||
For example, one might write
|
||||
```
|
||||
# Bazel
|
||||
bazel test //tests/... --test_env=JAX_NUM_GENERATED_CASES=25`
|
||||
```
|
||||
or
|
||||
```
|
||||
# pytest
|
||||
JAX_NUM_GENERATED_CASES=25 pytest -n auto tests
|
||||
```
|
||||
|
||||
The automated tests also run the tests with default 64-bit floats and ints:
|
||||
The automated tests also run the tests with default 64-bit floats and ints
|
||||
(`JAX_ENABLE_X64`):
|
||||
|
||||
```
|
||||
JAX_ENABLE_X64=1 JAX_NUM_GENERATED_CASES=25 pytest -n auto tests
|
||||
@ -179,7 +223,7 @@ file directly to see more detailed information about the cases being run:
|
||||
python tests/lax_numpy_test.py --num_generated_cases=5
|
||||
```
|
||||
|
||||
You can skip a few tests known as slow, by passing environment variable
|
||||
You can skip a few tests known to be slow, by passing environment variable
|
||||
JAX_SKIP_SLOW_TESTS=1.
|
||||
|
||||
To specify a particular set of tests to run from a test file, you can pass a string
|
||||
@ -192,16 +236,6 @@ python tests/lax_numpy_test.py --test_targets="testPad"
|
||||
|
||||
The Colab notebooks are tested for errors as part of the documentation build.
|
||||
|
||||
Note that to run the full pmap tests on a (multi-core) CPU-only machine, you
|
||||
can run:
|
||||
|
||||
```
|
||||
pytest tests/pmap_tests.py
|
||||
```
|
||||
|
||||
I.e. don't use the `-n auto` option, since that effectively runs each test on a
|
||||
single-core worker.
|
||||
|
||||
## Doctests
|
||||
JAX uses pytest in doctest mode to test the code examples within the documentation.
|
||||
You can run this using
|
||||
|
@ -35,4 +35,5 @@ tf_cc_binary(
|
||||
"@org_tensorflow//tensorflow/core/platform:logging",
|
||||
"@org_tensorflow//tensorflow/core/platform:platform_port",
|
||||
],
|
||||
tags = ["manual"],
|
||||
)
|
||||
|
23
jax/BUILD
23
jax/BUILD
@ -29,10 +29,25 @@ load(
|
||||
"sharded_jit_visibility",
|
||||
)
|
||||
|
||||
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(default_visibility = [":internal"])
|
||||
|
||||
bool_flag(
|
||||
name = "build_jaxlib",
|
||||
build_setting_default = True,
|
||||
)
|
||||
|
||||
|
||||
config_setting(
|
||||
name = "enable_jaxlib_build",
|
||||
flag_values = {
|
||||
":build_jaxlib": "True",
|
||||
},
|
||||
)
|
||||
|
||||
exports_files([
|
||||
"LICENSE",
|
||||
"version.py",
|
||||
@ -104,9 +119,11 @@ py_library_providing_imports_info(
|
||||
],
|
||||
lib_rule = pytype_library,
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//jaxlib",
|
||||
] + numpy_py_deps + scipy_py_deps + jax_extra_deps,
|
||||
deps = select({
|
||||
":enable_jaxlib_build": ["//jaxlib"],
|
||||
"//conditions:default": [],
|
||||
}) +
|
||||
numpy_py_deps + scipy_py_deps + jax_extra_deps,
|
||||
)
|
||||
|
||||
py_library_providing_imports_info(
|
||||
|
@ -77,6 +77,7 @@ def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []):
|
||||
linkshared = 1,
|
||||
linkstatic = 1,
|
||||
deps = deps,
|
||||
target_compatible_with = ["@platforms//os:windows"],
|
||||
)
|
||||
|
||||
# .def file with all symbols, not usable
|
||||
@ -85,6 +86,7 @@ def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []):
|
||||
name = full_def_name,
|
||||
srcs = [dummy_library_name],
|
||||
output_group = "def_file",
|
||||
target_compatible_with = ["@platforms//os:windows"],
|
||||
)
|
||||
|
||||
# filtered def_file, only the needed symbols are included
|
||||
@ -95,6 +97,7 @@ def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []):
|
||||
srcs = [full_def_name],
|
||||
outs = [filtered_def_file],
|
||||
cmd = """echo 'LIBRARY {}\nEXPORTS ' > $@ && grep '^\\W*mlir' $(location :{}) >> $@""".format(out, full_def_name),
|
||||
target_compatible_with = ["@platforms//os:windows"],
|
||||
)
|
||||
|
||||
# create the desired library
|
||||
@ -103,6 +106,7 @@ def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []):
|
||||
linkshared = 1,
|
||||
deps = deps,
|
||||
win_def_file = filtered_def_file,
|
||||
target_compatible_with = ["@platforms//os:windows"],
|
||||
)
|
||||
|
||||
# however, the created cc_library (a shared library) cannot be correctly
|
||||
@ -112,6 +116,7 @@ def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []):
|
||||
name = interface_library_file,
|
||||
srcs = [out],
|
||||
output_group = "interface_library",
|
||||
target_compatible_with = ["@platforms//os:windows"],
|
||||
)
|
||||
|
||||
# but this one can be correctly consumed, this is our final product
|
||||
@ -119,6 +124,7 @@ def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []):
|
||||
name = name,
|
||||
interface_library = interface_library_file,
|
||||
shared_library = out,
|
||||
target_compatible_with = ["@platforms//os:windows"],
|
||||
)
|
||||
|
||||
def jax_test(
|
||||
|
Loading…
x
Reference in New Issue
Block a user