From 1c75eee1ffa07da19acda14284035232cbfb642e Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Sun, 3 Jul 2022 15:04:37 -0400 Subject: [PATCH] Document how to run tests using Bazel. * Add a new --configure_only option to build.py to allow build.py to generate a .bazelrc without necessarily building jaxlib. * Add a bazel flag that make the dependency of //jax on //jaxlib optional. If //jaxlib isn't built by bazel, then tests will implicitly use a preinstalled jaxlib. --- build/build.py | 84 ++++++++++++++++++++++++++---------------- docs/developer.md | 70 ++++++++++++++++++++++++++--------- examples/jax_cpp/BUILD | 1 + jax/BUILD | 23 ++++++++++-- jaxlib/jax.bzl | 6 +++ 5 files changed, 131 insertions(+), 53 deletions(-) diff --git a/build/build.py b/build/build.py index d3821b4e5..d8e902027 100755 --- a/build/build.py +++ b/build/build.py @@ -213,11 +213,13 @@ def get_bazel_version(bazel_path): return tuple(int(x) for x in match.group(1).split(".")) -def write_bazelrc(python_bin_path=None, remote_build=None, - cuda_toolkit_path=None, cudnn_install_path=None, - cuda_version=None, cudnn_version=None, rocm_toolkit_path=None, - cpu=None, cuda_compute_capabilities=None, - rocm_amdgpu_targets=None): +def write_bazelrc(*, python_bin_path, remote_build, + cuda_toolkit_path, cudnn_install_path, + cuda_version, cudnn_version, rocm_toolkit_path, + cpu, cuda_compute_capabilities, + rocm_amdgpu_targets, bazel_options, target_cpu_features, + wheel_cpu, enable_mkl_dnn, enable_cuda, enable_nccl, + enable_tpu, enable_remote_tpu, enable_rocm): tf_cuda_paths = [] with open("../.jax_configure.bazelrc", "w") as f: @@ -263,6 +265,32 @@ def write_bazelrc(python_bin_path=None, remote_build=None, else: f.write("build --distinct_host_configuration=false\n") + for o in bazel_options: + f.write(f"common {o}\n") + if target_cpu_features == "release": + if wheel_cpu == "x86_64": + f.write("build --config=avx_windows\n" if is_windows() + else "build --config=avx_posix\n") + elif target_cpu_features == "native": + if is_windows(): + print("--target_cpu_features=native is not supported on Windows; ignoring.") + else: + f.write("build --config=native_arch_posix\n") + + if enable_mkl_dnn: + f.write("build --config=mkl_open_source_only\n") + if enable_cuda: + f.write("build --config=cuda\n") + if not enable_nccl: + f.write("build --config=nonccl\n") + if enable_tpu: + f.write("build --config=tpu\n") + if enable_remote_tpu: + f.write("build --//build:enable_remote_tpu=true\n") + if enable_rocm: + f.write("build --config=rocm\n") + if not enable_nccl: + f.write("build --config=nonccl\n") BANNER = r""" _ _ __ __ @@ -362,7 +390,7 @@ def main(): parser, "remote_build", default=False, - help_str="Should we build with RBE.") + help_str="Should we build with RBE (Remote Build Environment)?") parser.add_argument( "--cuda_path", default=None, @@ -410,6 +438,11 @@ def main(): default=None, help="CPU platform to target. Default is the same as the host machine. " "Currently supported values are 'darwin_arm64' and 'darwin_x86_64'.") + add_boolean_argument( + parser, + "configure_only", + default=False, + help_str="If true, writes a .bazelrc file but does not build jaxlib.") args = parser.parse_args() if is_windows() and args.enable_cuda: @@ -491,38 +524,25 @@ def main(): cpu=args.target_cpu, cuda_compute_capabilities=args.cuda_compute_capabilities, rocm_amdgpu_targets=args.rocm_amdgpu_targets, + bazel_options=args.bazel_options, + target_cpu_features=args.target_cpu_features, + wheel_cpu=wheel_cpu, + enable_mkl_dnn=args.enable_mkl_dnn, + enable_cuda=args.enable_cuda, + enable_nccl=args.enable_nccl, + enable_tpu=args.enable_tpu, + enable_remote_tpu=args.enable_remote_tpu, + enable_rocm=args.enable_rocm, ) + if args.configure_only: + return + print("\nBuilding XLA and installing it in the jaxlib source tree...") - config_args = args.bazel_options - if args.target_cpu_features == "release": - if wheel_cpu == "x86_64": - config_args += ["--config=avx_windows" if is_windows() - else "--config=avx_posix"] - elif args.target_cpu_features == "native": - if is_windows(): - print("--target_cpu_features=native is not supported on Windows; ignoring.") - else: - config_args += ["--config=native_arch_posix"] - - if args.enable_mkl_dnn: - config_args += ["--config=mkl_open_source_only"] - if args.enable_cuda: - config_args += ["--config=cuda"] - if not args.enable_nccl: - config_args += ["--config=nonccl"] - if args.enable_tpu: - config_args += ["--config=tpu"] - if args.enable_remote_tpu: - config_args += ["--//build:enable_remote_tpu=true"] - if args.enable_rocm: - config_args += ["--config=rocm"] - if not args.enable_nccl: - config_args += ["--config=nonccl"] command = ([bazel_path] + args.bazel_startup_options + - ["run", "--verbose_failures=true"] + config_args + + ["run", "--verbose_failures=true"] + [":build_wheel", "--", f"--output_path={output_path}", f"--cpu={wheel_cpu}"]) diff --git a/docs/developer.md b/docs/developer.md index b9c45f63a..c983f88d4 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -147,24 +147,68 @@ sets up symbolic links from site-packages into the repository. # Running the tests -To run all the JAX tests, we recommend using `pytest-xdist`, which can run tests in -parallel. First, install `pytest-xdist` and `pytest-benchmark` by running -`pip install -r build/test-requirements.txt`. +There are two supported mechanisms for running the JAX tests, either using Bazel +or using pytest. + +## Using Bazel + +First, configure the JAX build by running: +``` +python build/build.py --configure_only +``` + +You may pass additional options to `build.py` to configure the build; see the +`jaxlib` build documentation for details. + +By default the Bazel build runs the JAX tests using `jaxlib` built form source. +To run JAX tests, run: + +``` +bazel test //tests/... +``` + +To use a preinstalled `jaxlib` instead of building `jaxlib` from source, run + +``` +bazel test --//jax:build_jaxlib=false //tests/... +``` + + +A number of test behaviors can be controlled using environment variables (see +below). Environment variables may be passed to JAX tests using the +`--test_env=FLAG=value` flag to Bazel. + +## Using pytest + +To run all the JAX tests using `pytest`, we recommend using `pytest-xdist`, +which can run tests in parallel. First, install `pytest-xdist` and +`pytest-benchmark` by running `pip install -r build/test-requirements.txt`. Then, from the repository root directory run: ``` pytest -n auto tests ``` -JAX generates test cases combinatorially, and you can control the number of -cases that are generated and checked for each test (default is 10). The automated tests -currently use 25: +## Controlling test behavior +JAX generates test cases combinatorially, and you can control the number of +cases that are generated and checked for each test (default is 10) using the +`JAX_NUM_GENERATED_CASES` environment variable. The automated tests +currently use 25 by default. + +For example, one might write ``` +# Bazel +bazel test //tests/... --test_env=JAX_NUM_GENERATED_CASES=25` +``` +or +``` +# pytest JAX_NUM_GENERATED_CASES=25 pytest -n auto tests ``` -The automated tests also run the tests with default 64-bit floats and ints: +The automated tests also run the tests with default 64-bit floats and ints +(`JAX_ENABLE_X64`): ``` JAX_ENABLE_X64=1 JAX_NUM_GENERATED_CASES=25 pytest -n auto tests @@ -179,7 +223,7 @@ file directly to see more detailed information about the cases being run: python tests/lax_numpy_test.py --num_generated_cases=5 ``` -You can skip a few tests known as slow, by passing environment variable +You can skip a few tests known to be slow, by passing environment variable JAX_SKIP_SLOW_TESTS=1. To specify a particular set of tests to run from a test file, you can pass a string @@ -192,16 +236,6 @@ python tests/lax_numpy_test.py --test_targets="testPad" The Colab notebooks are tested for errors as part of the documentation build. -Note that to run the full pmap tests on a (multi-core) CPU-only machine, you -can run: - -``` -pytest tests/pmap_tests.py -``` - -I.e. don't use the `-n auto` option, since that effectively runs each test on a -single-core worker. - ## Doctests JAX uses pytest in doctest mode to test the code examples within the documentation. You can run this using diff --git a/examples/jax_cpp/BUILD b/examples/jax_cpp/BUILD index c6fcb7904..33c773830 100644 --- a/examples/jax_cpp/BUILD +++ b/examples/jax_cpp/BUILD @@ -35,4 +35,5 @@ tf_cc_binary( "@org_tensorflow//tensorflow/core/platform:logging", "@org_tensorflow//tensorflow/core/platform:platform_port", ], + tags = ["manual"], ) diff --git a/jax/BUILD b/jax/BUILD index 6cc20d8e8..3f1c86724 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -29,10 +29,25 @@ load( "sharded_jit_visibility", ) +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") + licenses(["notice"]) package(default_visibility = [":internal"]) +bool_flag( + name = "build_jaxlib", + build_setting_default = True, +) + + +config_setting( + name = "enable_jaxlib_build", + flag_values = { + ":build_jaxlib": "True", + }, +) + exports_files([ "LICENSE", "version.py", @@ -104,9 +119,11 @@ py_library_providing_imports_info( ], lib_rule = pytype_library, visibility = ["//visibility:public"], - deps = [ - "//jaxlib", - ] + numpy_py_deps + scipy_py_deps + jax_extra_deps, + deps = select({ + ":enable_jaxlib_build": ["//jaxlib"], + "//conditions:default": [], + }) + + numpy_py_deps + scipy_py_deps + jax_extra_deps, ) py_library_providing_imports_info( diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 36b568929..76e996a4d 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -77,6 +77,7 @@ def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []): linkshared = 1, linkstatic = 1, deps = deps, + target_compatible_with = ["@platforms//os:windows"], ) # .def file with all symbols, not usable @@ -85,6 +86,7 @@ def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []): name = full_def_name, srcs = [dummy_library_name], output_group = "def_file", + target_compatible_with = ["@platforms//os:windows"], ) # filtered def_file, only the needed symbols are included @@ -95,6 +97,7 @@ def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []): srcs = [full_def_name], outs = [filtered_def_file], cmd = """echo 'LIBRARY {}\nEXPORTS ' > $@ && grep '^\\W*mlir' $(location :{}) >> $@""".format(out, full_def_name), + target_compatible_with = ["@platforms//os:windows"], ) # create the desired library @@ -103,6 +106,7 @@ def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []): linkshared = 1, deps = deps, win_def_file = filtered_def_file, + target_compatible_with = ["@platforms//os:windows"], ) # however, the created cc_library (a shared library) cannot be correctly @@ -112,6 +116,7 @@ def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []): name = interface_library_file, srcs = [out], output_group = "interface_library", + target_compatible_with = ["@platforms//os:windows"], ) # but this one can be correctly consumed, this is our final product @@ -119,6 +124,7 @@ def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []): name = name, interface_library = interface_library_file, shared_library = out, + target_compatible_with = ["@platforms//os:windows"], ) def jax_test(