488 Commits

Author SHA1 Message Date
Charles Hofer
a1734fd31f Change to trigger CI 2025-01-06 15:50:02 +00:00
Nitin Srinivasan
0159bead97 Move output path to be inside the wheel build command execution loop
This was causing an issue when building multiple wheels in editable mode.

i.e instead of wheels being stored as:
```
# jax-cuda12-pjrt   0.4.36.dev20241125           ./dist/jax-cuda-pjrt
# jax-cuda12-plugin 0.4.36.dev20241125           ./dist/jax-cuda-plugin
# jaxlib            0.4.36.dev20241125           ./dist/jaxlib
```

they were being stored as:
```
# jaxlib            0.4.36.dev20241125           ./dist/jaxlib
# jax-cuda12-pjrt   0.4.36.dev20241125           ./dist/jaxlib/jax-cuda-pjrt
# jax-cuda12-plugin 0.4.36.dev20241125           ./dist/jaxlib/jax-cuda-plugin
```

PiperOrigin-RevId: 708468522
2024-12-20 17:27:51 -08:00
Nitin Srinivasan
6b096b0cb0 Use common set of build options when building jaxlib+plugin artifacts together
This commit modifies the behavior of the build CLI when building jaxlib and GPU plugin artifacts together (for instance `python build --wheels=jaxlib,jax-cuda-plugin`.

Before, CUDA/ROCm build options were only passed when building the CUDA/ROCm artifacts. However, this leads to inefficient use of the build cache as it looks like Bazel tries to rebuild some targets that has already been built in the previous run. This seems to be because the GPU plugin artifacts have a different set of build options compared to `jaxlib` which for some reason causes Bazel to invalidate/ignore certain cache hits. Therefore, this commit makes it so that the build options remain the same when the `jaxlib` and GPU artifacts are being built together so that we can better utilize the build cache.

As an example, this means that if `python build --wheels=jaxlib,jax-cuda-plugin` is run, the following build options will apply to both `jaxlib` and `jax-cuda-plugin` builds:
```
 /usr/local/bin/bazel run --repo_env=HERMETIC_PYTHON_VERSION=3.10 \
--verbose_failures=true --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-16/bin/clang" \
--repo_env=CC="/usr/lib/llvm-16/bin/clang" \
--repo_env=BAZEL_COMPILER="/usr/lib/llvm-16/bin/clang" \
--config=clang --config=mkl_open_source_only --config=avx_posix \
--config=cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-16/bin/clang" \
--config=build_cuda_with_nvcc
```

Note, this commit shouldn't affect the content of the wheel it self. It is only meant to give a performance boost when building `jalxib`+plugin aritfacts together.

Also, this removes code that was used to build (now deprecated) monolithic `jaxlib` build from `build_wheel.py`

PiperOrigin-RevId: 708035062
2024-12-19 14:29:24 -08:00
jax authors
bad6f0f0e9 Merge pull request #25440 from ROCm:jax_pypi-upstream
PiperOrigin-RevId: 707934347
2024-12-19 09:17:46 -08:00
jax authors
25713f95f2 Merge pull request #25148 from ROCm:ci_install_sys_libs-upstream
PiperOrigin-RevId: 707934273
2024-12-19 09:15:55 -08:00
jax authors
c78ca042e7 Add experimental support for building JAX CPU and GPU wheels with GCC.
The `build.py` script uses Clang compiler by default, and JAX doesn't support building with GCC officially. However, experimental GCC support is still present.

Command examples:

```
python build/build.py build --wheels=jaxlib,jax-cuda-plugin --use_clang=false
python build/build.py build --wheels=jaxlib,jax-cuda-plugin --use_clang=false --gcc_path=/use/bin/gcc
```

This change addresses the request in https://github.com/jax-ml/jax/issues/25488.

PiperOrigin-RevId: 707930913
2024-12-19 09:03:25 -08:00
Ruturaj4
fefebebea4 Update documentation and add setup.py pypi bindings 2024-12-19 10:21:28 -06:00
Ruturaj4
bfcace4933 [ROCm] ci build and dockerfile changes 2024-12-19 09:02:55 -06:00
Nitin Srinivasan
bcca77cd8e Enable --config=clang only on newer Clang versions
These flags disable Clang extensions that do things such as reject type definitions within offsetof or reject unknown arguments which does not seem to be needed on versions older than Clang 16

Also, fix a syntax error

Fixes https://github.com/jax-ml/jax/issues/25530

PiperOrigin-RevId: 707555651
2024-12-18 08:12:04 -08:00
Adam J. Stewart
1afed917fb
get_githash: fix support for missing git 2024-12-13 10:02:05 -05:00
Sunita Nadampalli
e370deee0f add mkldnn+acl build config for aarch64 platform 2024-12-09 16:03:14 +00:00
Nitin Srinivasan
83c64b2379 Add a flag to enable detailed timestamped logging of subprocess commands.
This adds a new command-line flag, `--detailed_timestamped_log`, that enables detailed logging of Bazel build commands. When disabled (the default), logging mirrors the output you'd see when running the command directly in your terminal.

When this flag is enabled:
- Bazel's output is captured line by line.
- Each line is timestamped for improved traceability.
- The complete log is stored for potential use as an artifact.

The flag is disabled by default and only enabled in the CI builds. If you're running locally and enable `detailed_timestamped_log`, you might notice that Bazel's output is not colored. To force a color output, include `--bazel_options=--color=yes` in your command.

PiperOrigin-RevId: 703581368
2024-12-06 12:32:24 -08:00
jax authors
7b32d88247 Merge pull request #25136 from ROCm:ci_dockerfile_arg_changes-upstream
PiperOrigin-RevId: 701959495
2024-12-02 07:21:15 -08:00
jax authors
ab79066bbe Merge pull request #25128 from ROCm:ci_fix_wheelhouse_relative_paths-upstream
PiperOrigin-RevId: 701143534
2024-11-28 18:51:49 -08:00
jax authors
b0df405250 Merge pull request #25130 from ROCm:ci_fix_set_options-upstream
PiperOrigin-RevId: 701143242
2024-11-28 18:49:47 -08:00
jax authors
385e2f4339 Merge pull request #25137 from ROCm:ci_enable_https-upstream
PiperOrigin-RevId: 701142951
2024-11-28 18:45:52 -08:00
jax authors
04a4f9bd8f Merge pull request #25096 from nitins17:update-rocm-ci-scripts
PiperOrigin-RevId: 700725187
2024-11-27 09:27:02 -08:00
Nitin Srinivasan
d449f12a2e Fix early exiting when building multiple wheels
PiperOrigin-RevId: 700711389
2024-11-27 08:35:51 -08:00
Nitin Srinivasan
c6866d05db Add a check for return codes of executor.run so that we propagate error codes correctly
PiperOrigin-RevId: 700518396
2024-11-26 18:18:06 -08:00
Ruturaj4
3d8063209e Update http to https in amd artifactory url. 2024-11-26 22:38:04 +00:00
Ruturaj4
8df2766466 Add argument to override base docker in dockerfile 2024-11-26 22:28:53 +00:00
Ruturaj4
694de6b64c [ROCm] Change run_multi_gpu set opts 2024-11-26 21:10:37 +00:00
Ruturaj4
d30ec2b5b3 [ROCm] fix jax and wheelhouse relative paths 2024-11-26 21:04:36 +00:00
Nitin Srinivasan
6761512658 Re-factor build CLI to a subcommand based approach
This commit reworks the JAX build CLI to a subcommand based approach where CLI use cases are now defined as subcommands. Two subcommands are defined: build and requirements_update. "build" is to be used when wanting to build a JAX wheel package. "requirements_update" is to be used when wanting to update the requirements_lock.txt files. The new structure offers a clear and organized CLI that enables users to execute specific build tasks without having to navigate through a monolithic script.

Each subcommand has specific arguments that apply to its respective build process. In addition, arguments are separated into groups to achieve a cleaner separation and improves the readability when the CLI subcommands are run with `--help`. It also makes it clear as to which parts of the build they affect. E.g: CUDA arguments only apply to CUDA builds, ROCM arguments only apply to ROCM builds, etc. This reduces the complexity and the potential for errors during the build process. Segregating functionalities into distinct subcommands also simplifies the code which should help with the maintenance and future extensions.

There is also a transition from using `subprocess.check_output` to `asyncio.create_subprocess_shell` for executing the build commands which allows for streaming logs and helps in showing the build progress in real time.

Usage:
* Building `jaxlib`:
```
python build/build.py build --wheels=jaxlib --python_version=3.10
```
* Building `jax-cuda-plugin`:
```
python build/build.py build --wheels=jax-cuda-plugin --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10
```
* Building multiple packages:
```
python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10
```
* Building `jax-rocm-pjrt`:
```
python build/build.py build --wheels=jax-rocm-pjrt --rocm_version=60 --rocm_path=/path/to/rocm
```
* Using a local XLA path:
```
python build/build.py build --wheels=jaxlib --local_xla_path=/path/to/xla
```
* Updating requirements_lock.txt files:
```
python build/build.py requirements_update --python_version=3.10
```

For more details on each argument and to see available options, run:
```
python build/build.py build --help
```
or
```
python build/build.py requirements_update --help
```

PiperOrigin-RevId: 700075411
2024-11-25 13:03:04 -08:00
Nitin Srinivasan
84dc9bab33 Update ROCm scripts to match new build.py usage 2024-11-25 19:25:08 +00:00
jax authors
a0b0a8e5a1 Set minimum supported Python version to 3.10 for matplotlib.
Temporary fixes an issue with `python -m build` that fails when python 3.8 is used because `matplotlib~=3.8.4` is unavailable for this python version.

We are working on creating Bazel build rule with the hermetic Python for JAX wheel ([we already have Jaxlib and plugins build rules ready](https://github.com/jax-ml/jax/pull/23276)). The required python modules are provided in requirements.in file, so when we implement Bazel build rule for JAX wheel, requirements.in will be the only source of dependencies, and test-requirements.txt won't be needed for building JAX wheel.

PiperOrigin-RevId: 692260046
2024-11-01 12:34:28 -07:00
Vadym Matsishevskyi
a75d94622c Reverts 72f9a493589a1046e6927a5f16d7dc71df530743
PiperOrigin-RevId: 691843537
2024-10-31 10:05:22 -07:00
Nitin Srinivasan
da994d3552 Move utility functions in build.py to utils.py
This commit is the first step towards re-working the build CLI. It moves all the auxiliary functions used by the CLI into a separate script for easier maintenance and readability.

PiperOrigin-RevId: 691458051
2024-10-30 10:00:32 -07:00
Peter Hawkins
72f9a49358 Reverts 6d8950c04f23ad15a0443006f1e5bd21bfa84156
PiperOrigin-RevId: 691222756
2024-10-29 17:46:55 -07:00
Vadym Matsishevskyi
6d8950c04f Cleanup requirements.in and test-requirements.txt
PiperOrigin-RevId: 691208596
2024-10-29 16:50:54 -07:00
Ruturaj4
bfd7075c39 [ROCm] ci build fixes 2024-10-25 05:01:44 -05:00
jax authors
3bdc57dd29 Merge pull request #24300 from ROCm:ci_rocm_readme
PiperOrigin-RevId: 686872994
2024-10-17 05:21:13 -07:00
Ruturaj4
3c3b08dfd6 [ROCm] Fix README.md to update AMD JAX installation instructions 2024-10-16 17:15:32 -05:00
Ruturaj4
a2824862f5 [ROCm] build script fix 2024-10-15 13:43:08 -05:00
Ruturaj4
937d79e3f2 [ROCm] apt update 2024-10-11 18:33:12 -05:00
jax authors
e4629f6a4c Merge pull request #24232 from ROCm:ci_rv_clang_clean
PiperOrigin-RevId: 684891301
2024-10-11 11:00:55 -07:00
David Dunleavy
ee312afe86 Corresponding build.py updates after 5132a188f7
PiperOrigin-RevId: 684805296
2024-10-11 05:41:28 -07:00
Ruturaj4
33bcd0cb7a [ROCm] Bring up clang support for JAX+XLA
* Add clang path

* bazelrc env fixes

* Fix wheelhouse installation and preserve wheels

* dockerfile changes

* Add target.lst

* Change target architectures

* Install bzip2 and sqlite packages
2024-10-10 16:31:26 -05:00
jax authors
53668b88eb Update rules_python.patch to support Python 3.13.0 and update python 3.13 packages in JAX.
The downloaded Python `tar.gz` files should have suffix `install_only`.

PiperOrigin-RevId: 684113192
2024-10-09 11:36:33 -07:00
jax authors
e212c77336 Merge pull request #23891 from ROCm:build-fixes-rollup
PiperOrigin-RevId: 681448694
2024-10-02 07:43:13 -07:00
Mathew Odden
9ff891dfa1 [ROCm] Remove broken legacy env vars
These env vars are no longer used or need and were
being set incorrectly.

[ROCm] Use specific amdgpu version for EL8 systems

We were always installing the latest driver versions
but this had some side effects when yum would try
to download index files from a URL with changing content.

[ROCm] Fix formatting on python files

Reformatted with black
2024-09-30 12:39:51 -05:00
jax authors
f1b3251bf9 Change CLANG_CUDA_COMPILER_PATH set order. Add --config=cuda_clang to build.py
Set `--action_env=CLANG_CUDA_COMPILER_PATH` after cuda_nvcc configuration
Add `--config=cuda_clang` when `--nouse_cuda_nvcc` flag set

PiperOrigin-RevId: 678873849
2024-09-25 15:39:44 -07:00
8bitmp3
60a06fd4c9
Update pillow version in JAX build test-requirements.txt 2024-09-25 14:55:46 +00:00
jax authors
6e116491c1 Add --use_cuda_nvcc flag to enable or disable compilation of CUDA code using NVCC.
If `--use_cuda_nvcc` flag is set the NVCC compiler driver will be used to build the CUDA code (default behavior). Otherwise, if the flag `--nouse_cuda_nvcc` is set, only the clang compiler will be used to build the CUDA code (effectively disabling NVCC).

Mark `--use_clang` flag as deprecated.

Refactor `.bazelrc` configs to match the new flag and to cleanup all previous confusing names.

PiperOrigin-RevId: 678332548
2024-09-24 11:37:00 -07:00
jax authors
0c7c71e640 Update python version from 3.12 to 3.13.0rc2 in Github presubmit jobs.
PiperOrigin-RevId: 676140293
2024-09-18 14:49:42 -07:00
jax authors
8bcdb12852 Add CI jobs for python 3.13.0rc2.
PiperOrigin-RevId: 675758096
2024-09-17 16:51:35 -07:00
Vadym Matsishevskyi
8804be0229 Add Python 3.130rc2 support to the build.
This PR depends on https://github.com/openxla/xla/pull/17169. The change does not fail existing builds, but to be able to use python 3.13 functionality in jax the corresponding XLA pr needs to land first and get integrated with JAX (happens automatically).

PiperOrigin-RevId: 675243989
2024-09-16 12:14:32 -07:00
Mathew Odden
a9e54b3e0a Add docker builds for ubu22 and 24 2024-08-27 16:40:37 -05:00
jax authors
140955dce0 Merge pull request #23224 from ROCm:ci_build_script
PiperOrigin-RevId: 668061975
2024-08-27 11:12:05 -07:00
Ayaka
859eacb5a1 Fix mypy error 2024-08-27 16:57:53 +01:00