158 Commits

Author SHA1 Message Date
Sunita Nadampalli
e370deee0f add mkldnn+acl build config for aarch64 platform 2024-12-09 16:03:14 +00:00
Nitin Srinivasan
83c64b2379 Add a flag to enable detailed timestamped logging of subprocess commands.
This adds a new command-line flag, `--detailed_timestamped_log`, that enables detailed logging of Bazel build commands. When disabled (the default), logging mirrors the output you'd see when running the command directly in your terminal.

When this flag is enabled:
- Bazel's output is captured line by line.
- Each line is timestamped for improved traceability.
- The complete log is stored for potential use as an artifact.

The flag is disabled by default and enabled only in CI builds. If you run locally and enable `detailed_timestamped_log`, you may notice that Bazel's output is not colored. To force colored output, include `--bazel_options=--color=yes` in your command.
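
The line-by-line capture described above can be sketched roughly as follows. This is only an illustration, not the code in `build.py`: the function name `run_with_timestamps`, the UTC timestamp format, and the use of `subprocess.Popen` are all assumptions. Because the child's output goes through a pipe rather than a terminal, tools that enable color only for terminals will typically print plain text, which is likely why the color option has to be forced.

```python
import subprocess
import sys
from datetime import datetime, timezone


def run_with_timestamps(cmd: list[str]) -> tuple[int, list[str]]:
    """Runs cmd, echoing every output line with a timestamp prefix.

    Hypothetical helper: returns the exit code and the full timestamped log
    so that a caller could store the log as a build artifact.
    """
    log: list[str] = []
    proc = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr so line ordering is preserved
        text=True,
    )
    assert proc.stdout is not None
    for line in proc.stdout:  # yields one line at a time as the child writes
        stamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
        entry = f"[{stamp}] {line.rstrip()}"
        print(entry)
        log.append(entry)
    return proc.wait(), log


if __name__ == "__main__":
    code, _ = run_with_timestamps([sys.executable, "-c", "print('hello')"])
    sys.exit(code)
```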

PiperOrigin-RevId: 703581368
2024-12-06 12:32:24 -08:00
Nitin Srinivasan
d449f12a2e Fix early exiting when building multiple wheels
PiperOrigin-RevId: 700711389
2024-11-27 08:35:51 -08:00
Nitin Srinivasan
c6866d05db Add a check for return codes of executor.run so that we propagate error codes correctly
PiperOrigin-RevId: 700518396
2024-11-26 18:18:06 -08:00
Nitin Srinivasan
6761512658 Re-factor build CLI to a subcommand based approach
This commit reworks the JAX build CLI to a subcommand based approach where CLI use cases are now defined as subcommands. Two subcommands are defined: build and requirements_update. "build" is to be used when wanting to build a JAX wheel package. "requirements_update" is to be used when wanting to update the requirements_lock.txt files. The new structure offers a clear and organized CLI that enables users to execute specific build tasks without having to navigate through a monolithic script.

Each subcommand has specific arguments that apply to its respective build process. In addition, arguments are separated into groups, which gives a cleaner separation and improves readability when the CLI subcommands are run with `--help`. Grouping also makes it clear which parts of the build each argument affects. For example, CUDA arguments only apply to CUDA builds, ROCm arguments only apply to ROCm builds, and so on. This reduces complexity and the potential for errors during the build process. Segregating functionality into distinct subcommands also simplifies the code, which should help with maintenance and future extensions.
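
A minimal sketch of this subcommand-and-groups layout using `argparse` is shown below. The subcommand and flag names are taken from the usage examples in this message; the defaults, help strings, and the function name `make_parser` are assumptions, and the real CLI defines many more arguments.

```python
import argparse


def make_parser() -> argparse.ArgumentParser:
    """Builds a parser with one subparser per CLI use case (illustrative only)."""
    parser = argparse.ArgumentParser(prog="build.py")
    subparsers = parser.add_subparsers(dest="command", required=True)

    build = subparsers.add_parser("build", help="Build wheel packages")
    build.add_argument("--wheels", default="jaxlib",
                       help="Comma-separated list of wheels to build")
    build.add_argument("--python_version", default="3.10")

    # Argument groups only affect how --help is rendered; they keep
    # backend-specific flags visually separate.
    cuda = build.add_argument_group("CUDA options")
    cuda.add_argument("--cuda_version")
    cuda.add_argument("--cudnn_version")

    rocm = build.add_argument_group("ROCm options")
    rocm.add_argument("--rocm_version")
    rocm.add_argument("--rocm_path")

    update = subparsers.add_parser(
        "requirements_update", help="Update requirements_lock.txt files")
    update.add_argument("--python_version", default="3.10")
    return parser


if __name__ == "__main__":
    print(make_parser().parse_args())
```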

There is also a transition from using `subprocess.check_output` to `asyncio.create_subprocess_shell` for executing the build commands which allows for streaming logs and helps in showing the build progress in real time.
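
To illustrate why this helps, here is a rough sketch of streaming a command's output with `asyncio.create_subprocess_shell`. Unlike `subprocess.check_output`, which returns the output only after the command finishes, each line is printed as soon as it arrives. The function name `run_streaming` is hypothetical, and the actual executor in the build scripts may be structured differently.

```python
import asyncio
import sys


async def run_streaming(cmd: str) -> int:
    """Runs a shell command, printing output lines as they are produced."""
    proc = await asyncio.create_subprocess_shell(
        cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT,  # interleave stderr with stdout
    )
    assert proc.stdout is not None
    async for raw in proc.stdout:  # StreamReader yields one line per iteration
        print(raw.decode(errors="replace").rstrip())
    # Return the exit code so the caller can propagate failures.
    return await proc.wait()


if __name__ == "__main__":
    sys.exit(asyncio.run(run_streaming("echo building")))
```

Returning the exit code rather than discarding it lets the caller stop with the same status when a build step fails.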

Usage:
* Building `jaxlib`:
```
python build/build.py build --wheels=jaxlib --python_version=3.10
```
* Building `jax-cuda-plugin`:
```
python build/build.py build --wheels=jax-cuda-plugin --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10
```
* Building multiple packages:
```
python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10
```
* Building `jax-rocm-pjrt`:
```
python build/build.py build --wheels=jax-rocm-pjrt --rocm_version=60 --rocm_path=/path/to/rocm
```
* Using a local XLA path:
```
python build/build.py build --wheels=jaxlib --local_xla_path=/path/to/xla
```
* Updating requirements_lock.txt files:
```
python build/build.py requirements_update --python_version=3.10
```

For more details on each argument and to see available options, run:
```
python build/build.py build --help
```
or
```
python build/build.py requirements_update --help
```

PiperOrigin-RevId: 700075411
2024-11-25 13:03:04 -08:00
Nitin Srinivasan
da994d3552 Move utility functions in build.py to utils.py
This commit is the first step towards re-working the build CLI. It moves all the auxiliary functions used by the CLI into a separate script for easier maintenance and readability.

PiperOrigin-RevId: 691458051
2024-10-30 10:00:32 -07:00
jax authors
e4629f6a4c Merge pull request #24232 from ROCm:ci_rv_clang_clean
PiperOrigin-RevId: 684891301
2024-10-11 11:00:55 -07:00
David Dunleavy
ee312afe86 Corresponding build.py updates after 5132a188f7
PiperOrigin-RevId: 684805296
2024-10-11 05:41:28 -07:00
Ruturaj4
33bcd0cb7a [ROCm] Bring up clang support for JAX+XLA
* Add clang path

* bazelrc env fixes

* Fix wheelhouse installation and preserve wheels

* dockerfile changes

* Add target.lst

* Change target architectures

* Install bzip2 and sqlite packages
2024-10-10 16:31:26 -05:00
jax authors
f1b3251bf9 Change CLANG_CUDA_COMPILER_PATH set order. Add --config=cuda_clang to build.py
Set `--action_env=CLANG_CUDA_COMPILER_PATH` after the cuda_nvcc configuration.
Add `--config=cuda_clang` when the `--nouse_cuda_nvcc` flag is set.
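
As an illustration only, a paired `--use_cuda_nvcc` / `--nouse_cuda_nvcc` flag and the resulting Bazel option could be modeled as below. The helper names `add_cuda_nvcc_flag` and `cuda_compiler_options` are hypothetical, the `argparse` wiring is an assumption about how such a pair might be declared, and any options added when NVCC is used are deliberately left out because this message does not state them.

```python
import argparse


def add_cuda_nvcc_flag(parser: argparse.ArgumentParser) -> None:
    """Registers a positive/negative flag pair sharing one destination."""
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--use_cuda_nvcc", dest="use_cuda_nvcc",
                       action="store_true", default=True)
    group.add_argument("--nouse_cuda_nvcc", dest="use_cuda_nvcc",
                       action="store_false")


def cuda_compiler_options(use_cuda_nvcc: bool) -> list[str]:
    """Returns extra Bazel options for the chosen CUDA compiler.

    Only the clang case comes from the commit message; the NVCC case is
    left empty here as a placeholder.
    """
    return [] if use_cuda_nvcc else ["--config=cuda_clang"]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    add_cuda_nvcc_flag(parser)
    args = parser.parse_args()
    print(cuda_compiler_options(args.use_cuda_nvcc))
```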

PiperOrigin-RevId: 678873849
2024-09-25 15:39:44 -07:00
jax authors
6e116491c1 Add --use_cuda_nvcc flag to enable or disable compilation of CUDA code using NVCC.
If the `--use_cuda_nvcc` flag is set, the NVCC compiler driver is used to build the CUDA code (the default behavior). If `--nouse_cuda_nvcc` is set instead, only the Clang compiler is used to build the CUDA code, effectively disabling NVCC.

Mark `--use_clang` flag as deprecated.

Refactor `.bazelrc` configs to match the new flag and to clean up all previous confusing names.

PiperOrigin-RevId: 678332548
2024-09-24 11:37:00 -07:00
jax authors
3a6ff86df9 Merge pull request #23088 from vfdev-5:update-build-build-py
PiperOrigin-RevId: 663876121
2024-08-16 14:50:57 -07:00
vfdev-5
12e8bf4525 Pass bazel options to requirements_update and requirements_nightly_update commands 2024-08-16 11:47:21 +02:00
jax authors
a498c1e668 Set Clang as the default compiler in the build script.
PiperOrigin-RevId: 663433112
2024-08-15 13:36:15 -07:00
jax authors
599c13aa09 Introduce hermetic CUDA in Google ML projects.
1) Hermetic CUDA rules allow building wheels with GPU support on a machine without GPUs, as well as running Bazel GPU tests on a machine with only GPUs and the NVIDIA driver installed. When `--config=cuda` is provided in Bazel options, Bazel will download CUDA, CUDNN and NCCL redistributions into the cache, and use them during build and test phases.

    [Default location of CUDNN redistributions](https://developer.download.nvidia.com/compute/cudnn/redist/)

    [Default location of CUDA redistributions](https://developer.download.nvidia.com/compute/cuda/redist/)

    [Default location of NCCL redistributions](https://pypi.org/project/nvidia-nccl-cu12/#history)

2) To include hermetic CUDA rules in your project, add the following in the WORKSPACE of the downstream project dependent on XLA.

   Note: use `@local_tsl` instead of `@tsl` in the TensorFlow project.

   ```
   load(
      "@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
      "cuda_json_init_repository",
   )

   cuda_json_init_repository()

   load(
      "@cuda_redist_json//:distributions.bzl",
      "CUDA_REDISTRIBUTIONS",
      "CUDNN_REDISTRIBUTIONS",
   )
   load(
      "@tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
      "cuda_redist_init_repositories",
      "cudnn_redist_init_repository",
   )

   cuda_redist_init_repositories(
      cuda_redistributions = CUDA_REDISTRIBUTIONS,
   )

   cudnn_redist_init_repository(
      cudnn_redistributions = CUDNN_REDISTRIBUTIONS,
   )

   load(
      "@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
      "cuda_configure",
   )

   cuda_configure(name = "local_config_cuda")

   load(
      "@tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
      "nccl_redist_init_repository",
   )

   nccl_redist_init_repository()

   load(
      "@tsl//third_party/nccl/hermetic:nccl_configure.bzl",
      "nccl_configure",
   )

   nccl_configure(name = "local_config_nccl")
   ```

PiperOrigin-RevId: 662981325
2024-08-14 10:58:43 -07:00
Jieying Luo
9c2caedab1 Add subdirectories to the output path when building editable wheels for jaxlib and GPU plugin.
When `build_gpu_plugin` is true, three wheels will be produced (jaxlib, jax-cuda-pjrt and jax-cuda-plugin). If they are editable, they need to be placed in separate subdirectories so that they do not overwrite one another.

Tested on GPU. After the editable wheels are built, they can be installed with `pip install -e /jax/dist/jax_gpu_pjrt /jax/dist/jaxlib /jax/dist/jax_gpu_plugin`.

PiperOrigin-RevId: 660984311
2024-08-08 14:34:55 -07:00
Jieying Luo
abe7982d65 Remove enable_gpu and xla_python_enable_gpu from jax .bazelrc.
The plugin is released and the flag is no longer needed.

Also set default value of enable_gpu to False. enable_gpu will be removed in the next change.

PiperOrigin-RevId: 660059432
2024-08-06 12:39:45 -07:00
Jieying Luo
bc0229a61f Rollback as it broke some tests.
Reverts ff17b76e3eec3e573788f64fafe23fabcfc09ce2

PiperOrigin-RevId: 658557091
2024-08-01 15:21:42 -07:00
Jieying Luo
ff17b76e3e Cleanup. Remove build:cuda_plugin and set enable_gpu and xla_python_enable_gpu to false in build:cuda.
JAX already migrated from jaxlib[cuda] to cuda plugin.

PiperOrigin-RevId: 658508037
2024-08-01 12:59:12 -07:00
Peter Hawkins
945fde41e4 Update minimum Python version to 3.10. 2024-06-26 13:47:14 -04:00
Dan Foreman-Mackey
1c640b25cb Fix pre-commit formatting errors
It looks like #21916 introduced some whitespace issues that are causing
our pre-commits to fail.
2024-06-26 06:38:51 -04:00
Ruturaj Vaidya
385283c50b Update build.py 2024-06-25 11:31:18 -05:00
Ruturaj4
a00d030248 [ROCM] nits and fixes 2024-06-18 20:21:23 +00:00
Ruturaj4
99c2b7b4e9 [ROCm] Bring-up pjrt support 2024-06-17 16:49:22 +00:00
Adam Paszke
cfe64cd5ce [Mosaic GPU] Integrate the ExecutionEngine with the jaxlib GPU plugin
This lets us avoid bundling yet another copy of LLVM with JAX packages,
so we can finally start building Mosaic GPU by default.

PiperOrigin-RevId: 638569750
2024-05-30 01:46:23 -07:00
Vadym Matsishevskyi
517e299a9d Use hermetic Python in JAX, see "Managing hermetic Python" in developer.md for details
PiperOrigin-RevId: 634146391
2024-05-15 18:20:56 -07:00
jax authors
174405c953 The Bazel version used in JAX is bumped from 6.1.2 to 6.5.0.
The update is needed for Windows/Clang builds and for the future hermetic CUDA implementation.

PiperOrigin-RevId: 631519200
2024-05-07 12:58:37 -07:00
jax authors
8ba5c64794 Pass bazel_options directly to the Bazel command, instead of into .bazelrc.
PiperOrigin-RevId: 631099970
2024-05-06 10:05:19 -07:00
Jieying Luo
16b4f69769 Rename arg in build script to be more clear.
The flag skips the GPU plugin extension in jaxlib.

PiperOrigin-RevId: 627203738
2024-04-22 17:22:24 -07:00
Adam Paszke
8e3f5b1018 Initial commit for Mosaic GPU
Moving this to JAX to make it easier to explore Pallas integration.

PiperOrigin-RevId: 625982382
2024-04-18 04:04:10 -07:00
David Dunleavy
6928465b87 Add --use_clang and --clang_path options to build.py
PiperOrigin-RevId: 603837975
2024-02-02 18:20:44 -08:00
Parker Schuh
23b9c2a22f Add the githash that the jaxlib was built at to __init__.py. This is to allow identifying the githash of nightlies.
PiperOrigin-RevId: 595529249
2024-01-03 16:12:23 -08:00
Jieying Luo
c8b3567e82 Add two flags to support only building cuda kernel plugin or cuda pjrt plugin.
PiperOrigin-RevId: 591274120
2023-12-15 09:15:46 -08:00
Jieying Luo
0ce7c7b7bd Register plugin profiler for TPU and remove --config=tpu/--enable_tpu in jaxlib.
PiperOrigin-RevId: 580561059
2023-11-08 09:40:28 -08:00
Jieying Luo
462ef165c4 [PJRT C API] Change build wheel script to build a separate package for cuda kernels.
With this change, `python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12` will generate three wheels:

|                       |size|wheel name                                                               |
|-----------------------|----|-------------------------------------------------------------------------|
|jaxlib w/o cuda kernels|76M |jaxlib-0.4.20.dev20231101-cp310-cp310-manylinux2014_x86_64.whl           |
|cuda pjrt              |73M |jax_cuda12_pjrt-0.4.20.dev20231101-py3-none-manylinux2014_x86_64.whl     |
|cuda kernels           |6.6M|jax_cuda12_plugin-0.4.20.dev20231101-cp310-cp310-manylinux2014_x86_64.whl|

The size of jaxlib with cuda kernels and pjrt is 119M.

The cuda kernel wheel contains all the cuda kernels. A plugin_setup.py and plugin_pyproject.toml are added for this new package.

PiperOrigin-RevId: 579861480
2023-11-06 09:13:44 -08:00
Jieying Luo
0290150c4c Build jaxlib without PJRT GPU deps when plugin will be built.
PiperOrigin-RevId: 573844805
2023-10-16 09:59:07 -07:00
Jieying Luo
432506f1ae [PJRT C API] Fixed pjrt_c_api_gpu and remove noincompatible_remove_legacy_whole_archive
PiperOrigin-RevId: 573094387
2023-10-12 21:25:25 -07:00
Peter Hawkins
73db6ecf2f Set -P when testing whether a package is installed during build.py.
(Only on Python 3.11+)

The test for the "build" package being installed always succeeded because of the subdirectory named "build".
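
The check can be sketched as follows. This is not the code in `build.py`; the helper name `is_package_installed` is hypothetical. With `-c`, Python normally puts the current directory first on `sys.path`, so a local directory called `build` can be imported as a namespace package and mask a missing installation. The `-P` option (Python 3.11+) suppresses that path entry.

```python
import subprocess
import sys


def is_package_installed(name: str) -> bool:
    """Returns True if `import name` succeeds in a fresh interpreter."""
    cmd = [sys.executable]
    if sys.version_info >= (3, 11):
        # -P: do not prepend the current directory to sys.path, so a local
        # directory with the same name cannot satisfy the import.
        cmd.append("-P")
    cmd += ["-c", f"import {name}"]
    return subprocess.run(cmd, capture_output=True).returncode == 0


if __name__ == "__main__":
    print(is_package_installed("build"))
```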
2023-10-10 10:30:37 -04:00
Peter Hawkins
6e5409c008 Add missing raise to build.py 2023-09-26 22:11:01 -04:00
Peter Hawkins
cf28e2c5fa Small improvements to build/build.py
Add a --verbose option that logs all shell() commands run by the script.
Remove some Python 2 backward compatibility logic related to urllib and shutil.

Enable debug logging on Windows wheel builds.

Also include setuptools in the build requirements and test for its presence in build.py.
2023-09-26 21:55:45 -04:00
Jieying Luo
3dbd3649ab Do not share the build command between building jaxlib and gpu plugin as their commands diverge.
PiperOrigin-RevId: 568625959
2023-09-26 13:03:00 -07:00
Jieying Luo
ea01085522 Fix build.py to set include_gpu_plugin_extension as bool flag.
PiperOrigin-RevId: 568572158
2023-09-26 10:01:55 -07:00
Jieying Luo
c7f60fa6eb [PJRT C API] Implement framework side change for registering a custom call.
- Add a py extension to call the custom call C API.
- Change the implementation of register_custom_call_target to store handlers for the custom call targets and delay the registration until the handler for an XLA platform is registered.
- Change register_plugin to load the PJRT plugin when register_plugin is called (instead of when a client is created), and let it return the loaded PJRT_Api*.
- Delay calling discover_pjrt_plugins() and register_pjrt_plugin_factories_from_env() until the first time backends() is called.
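
The deferred-registration idea in the second bullet can be illustrated with a small, self-contained sketch. Apart from the name `register_custom_call_target`, everything here (the `CustomCallRegistry` class, `register_handler`, and the handler signature) is hypothetical and does not reflect the actual xla_client code.

```python
from collections import defaultdict
from typing import Any, Callable

# A handler performs the real registration for one platform.
Handler = Callable[[str, Any], None]


class CustomCallRegistry:
    """Queues custom call targets until a platform handler is available."""

    def __init__(self) -> None:
        self._handlers: dict[str, Handler] = {}
        self._pending: dict[str, list[tuple[str, Any]]] = defaultdict(list)

    def register_custom_call_target(self, name: str, fn: Any, platform: str) -> None:
        handler = self._handlers.get(platform)
        if handler is not None:
            handler(name, fn)  # platform is ready: register immediately
        else:
            self._pending[platform].append((name, fn))  # defer until ready

    def register_handler(self, platform: str, handler: Handler) -> None:
        self._handlers[platform] = handler
        # Flush everything that was queued before the handler existed.
        for name, fn in self._pending.pop(platform, []):
            handler(name, fn)


if __name__ == "__main__":
    registry = CustomCallRegistry()
    registry.register_custom_call_target("my_kernel", object(), "CUDA")
    registry.register_handler("CUDA", lambda name, fn: print("registered", name))
```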

PiperOrigin-RevId: 568265745
2023-09-25 10:52:29 -07:00
Jieying Luo
91fbf9da26 [PJRT C API] Set up jax xla cuda package.
Add a build wheel, pyproject.toml and setup.py.

The directory structure in jax repo is:
jax/
└── plugins/
     └── cuda/
          ├── __init__.py
          ├── pyproject.toml
          └── setup.py

Installed package structure is:
jax_plugins/
     └── xla_cuda_cu12/
           ├── __init__.py
           └── xla_cuda_plugin.so

The major cuda version will be part of the package name.

The plugin wheel can be built with command:
python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12 --bazel_options="--override_repository=xla=$HOME/xla"

PiperOrigin-RevId: 565187954
2023-09-13 16:03:53 -07:00
Peter Hawkins
3c4527b6b0 Check build and wheel are installed before building jaxlib. 2023-07-26 11:46:11 -07:00
Peter Hawkins
f540ae4338 Fix warning about direct invocation of setup.py during jaxlib build.
The jaxlib wheel build currently uses `python setup.py bdist_wheel` to construct the wheel. Change it to use `python -m build -w` instead.

To avoid Python getting confused between the directory named `build` in the bazel tree and the Python `build` module, move `build_wheel.py` into `jaxlib/tools`.

PiperOrigin-RevId: 548133811
2023-07-14 08:31:16 -07:00
Peter Hawkins
1d4b10b775 Remove --distinct_host_configuration from Bazel flags.
This flag does nothing under Bazel 6 and will be removed in Bazel 7.
2023-07-11 11:38:05 -04:00
Jake VanderPlas
ad35702934 Drop support for numpy 1.21
This is in accordance with NEP 29 and https://jax.readthedocs.io/en/latest/deprecation.html
2023-06-23 10:28:26 -07:00
Yash Katariya
fc0dcd15a2 Copybara import of the project:
--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:

[Roll forward] Update required Python version to 3.9

PiperOrigin-RevId: 542728213
2023-06-22 18:58:30 -07:00
Peter Hawkins
119661ce6b Remove older plugin device integration.
Users of this mechanism should migrate to the newer PJRT plugin registration mechanism (see the comments on discover_plugins() in this file).
2023-06-14 15:26:58 -04:00