135 Commits

Author SHA1 Message Date
Ruturaj4
99c2b7b4e9 [ROCm] Bring-up pjrt support 2024-06-17 16:49:22 +00:00
Adam Paszke
cfe64cd5ce [Mosaic GPU] Integrate the ExecutionEngine with the jaxlib GPU plugin
This lets us avoid bundling a whole another copy of LLVM with JAX packages
and so we can finally start building Mosaic GPU by default.

PiperOrigin-RevId: 638569750
2024-05-30 01:46:23 -07:00
Vadym Matsishevskyi
517e299a9d Use hermetic Python in JAX, see "Managing hermetic Python" in developer.md for details
PiperOrigin-RevId: 634146391
2024-05-15 18:20:56 -07:00
jax authors
174405c953 The Bazel version used in JAX is bumped from 6.1.2 to 6.5.0.
The update is needed for Windows/Clang builds and for the future hermetic CUDA implementation.

PiperOrigin-RevId: 631519200
2024-05-07 12:58:37 -07:00
jax authors
8ba5c64794 Pass bazel_options directly to the Bazel command, instead of into .bazelrc.
PiperOrigin-RevId: 631099970
2024-05-06 10:05:19 -07:00
Jieying Luo
16b4f69769 Rename arg in build script to be more clear.
The flag means skips GPU plugin extension in jaxlib.

PiperOrigin-RevId: 627203738
2024-04-22 17:22:24 -07:00
Adam Paszke
8e3f5b1018 Initial commit for Mosaic GPU
Moving this to JAX to make it easier to explore Pallas integration.

PiperOrigin-RevId: 625982382
2024-04-18 04:04:10 -07:00
David Dunleavy
6928465b87 Add --use_clang and --clang_path options to build.py
PiperOrigin-RevId: 603837975
2024-02-02 18:20:44 -08:00
Parker Schuh
23b9c2a22f Add the githash that the jaxlib was built at to __init__.py. This is to allow identifying the githash of nightlies.
PiperOrigin-RevId: 595529249
2024-01-03 16:12:23 -08:00
Jieying Luo
c8b3567e82 Add two flags to support only building cuda kernel plugin or cuda pjrt plugin.
PiperOrigin-RevId: 591274120
2023-12-15 09:15:46 -08:00
Jieying Luo
0ce7c7b7bd Register plugin profiler for TPU and remove --config=tpu/--enable_tpu in jaxlib.
PiperOrigin-RevId: 580561059
2023-11-08 09:40:28 -08:00
Jieying Luo
462ef165c4 [PJRT C API] Change build wheel script to build a separate package for cuda kernels.
With this change, `python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12` will generate three wheels:

|                      |size|wheel name                                                               |
|----------------------|----|-------------------------------------------------------------------------|
|jaxlib w/o cuda kernels|76M |jaxlib-0.4.20.dev20231101-cp310-cp310-manylinux2014_x86_64.whl           |
|cuda pjrt              |73M|jax_cuda12_pjrt-0.4.20.dev20231101-py3-none-manylinux2014_x86_64.whl                    |
|cuda kernels           |6.6M|jax_cuda12_plugin-0.4.20.dev20231101-cp310-cp310-manylinux2014_x86_64.whl|

The size of jaxlib with cuda kernels and pjrt is 119M.

The cuda kernel wheel contains all the cuda kernels. A plugin_setup.py and plugin_pyproject.toml are added for this new pacakge.

PiperOrigin-RevId: 579861480
2023-11-06 09:13:44 -08:00
Jieying Luo
0290150c4c Build jaxlib without PJRT GPU deps when plugin will be built.
PiperOrigin-RevId: 573844805
2023-10-16 09:59:07 -07:00
Jieying Luo
432506f1ae [PJRT C API] Fixed pjrt_c_api_gpu and remove noincompatible_remove_legacy_whole_archive
PiperOrigin-RevId: 573094387
2023-10-12 21:25:25 -07:00
Peter Hawkins
73db6ecf2f Set -P when testing whether a package is installed during build.py.
(Only on Python 3.11+)

The test for the "build" package being installed always succeeded because of the subdirectory named "build".
2023-10-10 10:30:37 -04:00
Peter Hawkins
6e5409c008 Add missing raise to build.py 2023-09-26 22:11:01 -04:00
Peter Hawkins
cf28e2c5fa Small improvements to build/build.py
Add a --verbose option that logs all shell() commands run by the script.
Remove some Python 2 backward compatibility logic related to urllib and shutil.

Enable debug logging on Windows wheel builds.

Also include setuptools in the build requirements and test for its presence in build.py.
2023-09-26 21:55:45 -04:00
Jieying Luo
3dbd3649ab Do not share the build command between building jaxlib and gpu plugin as their commands diverge.
PiperOrigin-RevId: 568625959
2023-09-26 13:03:00 -07:00
Jieying Luo
ea01085522 Fix build.py to set include_gpu_plugin_extension as bool flag.
PiperOrigin-RevId: 568572158
2023-09-26 10:01:55 -07:00
Jieying Luo
c7f60fa6eb [PJRT C API] Implement framework side change for registering a custom call.
- Add a py extension to call the custom call C API.
- Change the implementation of register_custom_call_target to store handlers for the custom call targets and delays the registration until the handler for a xla platform is registered.
- Change register_plugin to load PJRT plugin when register_pluin is called (instead of when a client is created), and let it return the PJRT_Api* loaded.
- Delay calling discover_pjrt_plugins() and register_pjrt_plugin_factories_from_env() until the first time backends() is called.

PiperOrigin-RevId: 568265745
2023-09-25 10:52:29 -07:00
Jieying Luo
91fbf9da26 [PJRT C API] Set up jax xla cuda package.
Add a build wheel, pyproject.toml and setup.py.

The directory structure in jax repo is:
jax/
└── plugins/
     └── cuda/
          ├── __init__.py
          ├── pyproject.toml
          └── setup.py

Installed package structure is:
jax_plugins/
     └── xla_cuda_cu12/
           ├── __init__.py
           └── xla_cuda_plugin.so

The major cuda version will be part of the package name.

The plugin wheel can be built with command:
python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12 --bazel_options="--override_repository=xla=$HOME/xla"

PiperOrigin-RevId: 565187954
2023-09-13 16:03:53 -07:00
Peter Hawkins
3c4527b6b0 Check build and wheel are installed before building jaxlib. 2023-07-26 11:46:11 -07:00
Peter Hawkins
f540ae4338 Fix warning about direct invocation of setup.py during jaxlib build.
The jaxlib wheel build currently uses `python setup.py bdist_wheel` to construct the wheel. Change it to use `python -m build -w` instead.

To avoid Python getting confused between the directory named `build` in the bazel tree and the Python `build` module, move `build_wheel.py` into `jaxlib/tools`.

PiperOrigin-RevId: 548133811
2023-07-14 08:31:16 -07:00
Peter Hawkins
1d4b10b775 Remove --distinct_host_configuration from Bazel flags.
This flag does nothing under Bazel 6 and will be removed in Bazel 7.
2023-07-11 11:38:05 -04:00
Jake VanderPlas
ad35702934 Drop support for numpy 1.21
This is in accordance with NEP 29 and https://jax.readthedocs.io/en/latest/deprecation.html
2023-06-23 10:28:26 -07:00
Yash Katariya
fc0dcd15a2 Copybara import of the project:
--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:

[Roll forward] Update required Python version to 3.9

PiperOrigin-RevId: 542728213
2023-06-22 18:58:30 -07:00
Peter Hawkins
119661ce6b Remove older plugin device integration.
Users of this mechanism should migrate to the newer PJRT plugin registration mechanism (see the comments on discover_plugins() in this file).
2023-06-14 15:26:58 -04:00
Peter Hawkins
a18e82b28b Update bazel version to 6.1.2.
Several of our CI builds are already using 6.1.2, so it's probably best to upgrade for consistency.
2023-05-10 10:57:29 -04:00
Yash Katariya
6506ee2a40 Copybara import of the project:
--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:

[Rollback] Update required Python version to 3.9

PiperOrigin-RevId: 528905991
2023-05-02 15:33:29 -07:00
Jake VanderPlas
57af5360a1 Update required Python version to 3.9 2023-05-01 10:00:57 -07:00
jax authors
6715736583 Merge pull request #15205 from yhtang:editable-jaxlib-build
PiperOrigin-RevId: 519704474
2023-03-27 06:33:31 -07:00
Yu-Hang 'Maxin' Tang
caaa0a2669 add build option to create editable jaxlib
Co-authored-by: Yonghao Zhuang <zhuangyh@sjtu.edu.cn>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
2023-03-24 21:25:26 +00:00
Peter Hawkins
6ed66ada0f Delete remote TPU support.
TPU VMs are the only supported way to use TPUs as of JAX 0.4.0.

PiperOrigin-RevId: 519211267
2023-03-24 12:33:33 -07:00
Peter Hawkins
b7375b316b Increase minimum NumPy version to 1.21.
Also increase minimum SciPy version to 1.7, which was released just before NumPy 1.21.
2023-03-23 21:15:10 -04:00
Jake VanderPlas
e7f53479e2 Some cleanups related to dropping Python 3.7 2022-11-29 15:54:49 -08:00
jax authors
254dc24a8b Merge pull request #11961 from jakeh-gc:plugin_device
PiperOrigin-RevId: 476363760
2022-09-23 07:29:17 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake
21f82c6c0d Use the pjrt plugin device client. 2022-08-17 14:34:07 +01:00
Peter Hawkins
03876bd702 build.py fixes.
* Add aarch64 as a known target_cpu value.
* Only pass --bazel_options to build actions since they can make "bazel
  shutdown" fail.
* Pass the bazel startup options to "bazel shutdown".

Issue https://github.com/google/jax/issues/7097
Fixes https://github.com/google/jax/issues/7639
2022-08-16 15:47:15 +00:00
Peter Hawkins
c735c6bf0e Increase minimum NumPy version to 1.20.
Per NEP 29, support for 1.19 ended on Jun 21, 2022.
2022-08-06 14:51:14 +00:00
Peter Hawkins
1c75eee1ff Document how to run tests using Bazel.
* Add a new --configure_only option to build.py to allow build.py to generate a .bazelrc without necessarily building jaxlib.
* Add a bazel flag that make the dependency of //jax on //jaxlib optional. If //jaxlib isn't built by bazel, then tests will implicitly use a preinstalled jaxlib.
2022-07-06 08:30:35 -04:00
Peter Hawkins
22304eeb2e Add a build flag that allows disabling remote TPU builds.
Disable remote TPU by default.
2022-06-23 21:14:52 +00:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Peter Hawkins
9fb9e12169 Don't include PTX for older GPU generations.
See: https://github.com/tensorflow/tensorflow/pull/55613

For a CUDA build at head with the default compute capabilities, reduces wheel size from 141MB to 112MB.

Don't redundantly specify default compute capabilities in .bazelrc and in
build.py.
2022-05-02 20:27:37 +00:00
Yash Katariya
6ba9fb699d Upgrade the bazel version to 5.1.1
PiperOrigin-RevId: 441338363
2022-04-12 17:48:09 -07:00
Yash Katariya
aa5d6b4a58 Fix the breakage by including --experimental_cc_shared_library as done by TF.
PiperOrigin-RevId: 438746867
2022-03-31 23:07:42 -07:00
Reza Rahimi
a0d9d81f92 Update JAX to use new math libraries in ROCm-5.0. 2022-03-01 20:02:15 +00:00
Peter Hawkins
2e0cfe8e42 Update the list of default CUDA capabilities used for wheel builds to match build.py. 2022-02-15 09:23:28 -05:00
Peter Hawkins
2388e353da Increase bazel version to 5.0.0 to match TensorFlow
(8871926b0a).
2022-01-28 21:11:02 +00:00
Peter Hawkins
04369a3588 Drop support for NumPy 1.18.
Per NEP-29, we can drop NumPy 1.18 support on Dec 22, 2021.

The next NumPy deprecation will be 1.19 on Jun 21, 2022.

PiperOrigin-RevId: 419651428
2022-01-04 12:11:38 -08:00