493 Commits

Author SHA1 Message Date
Jieying Luo
c8b3567e82 Add two flags to support only building cuda kernel plugin or cuda pjrt plugin.
PiperOrigin-RevId: 591274120
2023-12-15 09:15:46 -08:00
George Necula
b077483bfa [export] Add support for serialization and deserialization of Exported
At the moment we can export a JAX function into an Exported and we can invoke an Exported from another JAX function, but there is no way to serialize an Exported to be able to use it in another process.

Exported now has all the features we had in mind, so it is a reasonable time to add a serialization method. The intention is for the serialization to have backwards compatibility guarantees, meaning that we can deserialize an Exported that has been serialized with an older version of JAX. This PR does not add explicit support for versioning, nor backwards compatibility tests. Those will follow.

Here we add serialization and deserialization to bytearray, using the flatbuffers package. We use flatbuffers because it is simple, it has backwards and forwards compatibility guarantees, it is a lightweight dependency that does not require additional build steps, and it is fast (deserialization simply indexes into the bytearray rather than creating a Python structure).

In the process of implementing this we have done some small cleanup of the Exported structure:

  * renamed serialization_version to mlir_module_serialization_version
  * renamed disabled_checks to disabled_safety_checks

This code is tested by changing export_test.py to interpose a serialization followed by a deserialization every time we export.export.

There is a known bug with the serialization of effects, so I disabled one of the export tests. Will fix in a subsequent PR.

PiperOrigin-RevId: 590078785
2023-12-11 23:23:02 -08:00
Jieying Luo
0ce7c7b7bd Register plugin profiler for TPU and remove --config=tpu/--enable_tpu in jaxlib.
PiperOrigin-RevId: 580561059
2023-11-08 09:40:28 -08:00
Jieying Luo
462ef165c4 [PJRT C API] Change build wheel script to build a separate package for cuda kernels.
With this change, `python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12` will generate three wheels:

|                      |size|wheel name                                                               |
|----------------------|----|-------------------------------------------------------------------------|
|jaxlib w/o cuda kernels|76M |jaxlib-0.4.20.dev20231101-cp310-cp310-manylinux2014_x86_64.whl           |
|cuda pjrt              |73M|jax_cuda12_pjrt-0.4.20.dev20231101-py3-none-manylinux2014_x86_64.whl                    |
|cuda kernels           |6.6M|jax_cuda12_plugin-0.4.20.dev20231101-cp310-cp310-manylinux2014_x86_64.whl|

The size of jaxlib with cuda kernels and pjrt is 119M.

The cuda kernel wheel contains all the cuda kernels. A plugin_setup.py and plugin_pyproject.toml are added for this new pacakge.

PiperOrigin-RevId: 579861480
2023-11-06 09:13:44 -08:00
Jieying Luo
0290150c4c Build jaxlib without PJRT GPU deps when plugin will be built.
PiperOrigin-RevId: 573844805
2023-10-16 09:59:07 -07:00
Jieying Luo
432506f1ae [PJRT C API] Fixed pjrt_c_api_gpu and remove noincompatible_remove_legacy_whole_archive
PiperOrigin-RevId: 573094387
2023-10-12 21:25:25 -07:00
Peter Hawkins
73db6ecf2f Set -P when testing whether a package is installed during build.py.
(Only on Python 3.11+)

The test for the "build" package being installed always succeeded because of the subdirectory named "build".
2023-10-10 10:30:37 -04:00
Peter Hawkins
6e5409c008 Add missing raise to build.py 2023-09-26 22:11:01 -04:00
Peter Hawkins
cf28e2c5fa Small improvements to build/build.py
Add a --verbose option that logs all shell() commands run by the script.
Remove some Python 2 backward compatibility logic related to urllib and shutil.

Enable debug logging on Windows wheel builds.

Also include setuptools in the build requirements and test for its presence in build.py.
2023-09-26 21:55:45 -04:00
Jieying Luo
3dbd3649ab Do not share the build command between building jaxlib and gpu plugin as their commands diverge.
PiperOrigin-RevId: 568625959
2023-09-26 13:03:00 -07:00
Jieying Luo
ea01085522 Fix build.py to set include_gpu_plugin_extension as bool flag.
PiperOrigin-RevId: 568572158
2023-09-26 10:01:55 -07:00
Jieying Luo
c7f60fa6eb [PJRT C API] Implement framework side change for registering a custom call.
- Add a py extension to call the custom call C API.
- Change the implementation of register_custom_call_target to store handlers for the custom call targets and delays the registration until the handler for a xla platform is registered.
- Change register_plugin to load PJRT plugin when register_pluin is called (instead of when a client is created), and let it return the PJRT_Api* loaded.
- Delay calling discover_pjrt_plugins() and register_pjrt_plugin_factories_from_env() until the first time backends() is called.

PiperOrigin-RevId: 568265745
2023-09-25 10:52:29 -07:00
Jieying Luo
91fbf9da26 [PJRT C API] Set up jax xla cuda package.
Add a build wheel, pyproject.toml and setup.py.

The directory structure in jax repo is:
jax/
└── plugins/
     └── cuda/
          ├── __init__.py
          ├── pyproject.toml
          └── setup.py

Installed package structure is:
jax_plugins/
     └── xla_cuda_cu12/
           ├── __init__.py
           └── xla_cuda_plugin.so

The major cuda version will be part of the package name.

The plugin wheel can be built with command:
python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12 --bazel_options="--override_repository=xla=$HOME/xla"

PiperOrigin-RevId: 565187954
2023-09-13 16:03:53 -07:00
Rahul Batra
ef79c19093 [ROCm]: Dockerfile and build script updates
Add hipblaslt in Dockerfile
	Update docker file to default to ROCm5.6
 	CI scripts update to handle multiple ROCm versions
2023-09-08 22:56:59 +00:00
Sharad Vikram
d872812a35 [Pallas] Upstream pallas to JAX
PiperOrigin-RevId: 552963029
2023-08-01 16:43:13 -07:00
Peter Hawkins
3c4527b6b0 Check build and wheel are installed before building jaxlib. 2023-07-26 11:46:11 -07:00
Skye Wanderman-Milne
6c909760d5 Cloud TPU CI: make sure we update test deps and upgrade protobuf version
`profiler_test.py:ProfilerTest.test_remote_profiler` fails with the
protobuf upgrade. However, I was seeing mysterious hangs without this,
and in general I think we should be testing with up-to-date deps given
that we don't pin. I'm gonna continue working on getting the Cloud TPU
CI green.
2023-07-19 16:17:47 -07:00
Peter Hawkins
1aa09dbd3e Fix broken jaxlib wheel build by moving LICENSE.txt to the correct location.
PiperOrigin-RevId: 548189924
2023-07-14 12:13:28 -07:00
Peter Hawkins
f540ae4338 Fix warning about direct invocation of setup.py during jaxlib build.
The jaxlib wheel build currently uses `python setup.py bdist_wheel` to construct the wheel. Change it to use `python -m build -w` instead.

To avoid Python getting confused between the directory named `build` in the bazel tree and the Python `build` module, move `build_wheel.py` into `jaxlib/tools`.

PiperOrigin-RevId: 548133811
2023-07-14 08:31:16 -07:00
Peter Hawkins
1d4b10b775 Remove --distinct_host_configuration from Bazel flags.
This flag does nothing under Bazel 6 and will be removed in Bazel 7.
2023-07-11 11:38:05 -04:00
jax authors
60d481078e Merge pull request #16523 from ROCmSoftwarePlatform:rocm-update-build-doc
PiperOrigin-RevId: 547185925
2023-07-11 07:41:51 -07:00
Jake VanderPlas
ad35702934 Drop support for numpy 1.21
This is in accordance with NEP 29 and https://jax.readthedocs.io/en/latest/deprecation.html
2023-06-23 10:28:26 -07:00
Peter Hawkins
bfa113ba60 Remove references to Python 3.8.
Remove the old build scripts/Dockerfile, since they are unused and broken.

PiperOrigin-RevId: 542870354
2023-06-23 08:48:57 -07:00
Yash Katariya
fc0dcd15a2 Copybara import of the project:
--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:

[Roll forward] Update required Python version to 3.9

PiperOrigin-RevId: 542728213
2023-06-22 18:58:30 -07:00
Rahul Batra
80bd361364 [ROCm]: Update README.md 2023-06-22 15:53:04 +00:00
jax authors
961d918883 Merge pull request #16333 from skye:profiler_test
PiperOrigin-RevId: 542069606
2023-06-20 15:40:10 -07:00
Skye Wanderman-Milne
2ca151ef5b profiler_test.py fixes and add coverage to Cloud TPU CI
* Add deps to test requirements, including in new
  `collect-profile-requirements.txt` (to avoid adding tensorflow to
  `test-requirements.txt`).
* Use correct Python executable `ProfilerTest.test_remote_profiler`
  (`python` sometimes defaults to python2)
* Run computations for longer in `ProfilerTest.test_remote_profiler`,
  othewise `collect_profile` sometimes misses it.
2023-06-20 22:25:17 +00:00
Rahul Batra
88a2c8ca5e [ROCm]: Updates defaults in build script 2023-06-20 20:25:40 +00:00
jax authors
632c79093e Merge pull request #16437 from ROCmSoftwarePlatform:rocm-updates-build
PiperOrigin-RevId: 541896270
2023-06-20 06:19:59 -07:00
Rahul Batra
b0e541a730 [ROCm]: Updates for container and build script
-Updated dockerfile.ms
	-Updated build script to switch building against XLA repo
  	-Update CI script
	-Update jaxlib setup.py to add rocm version
2023-06-19 18:13:28 +00:00
Peter Hawkins
119661ce6b Remove older plugin device integration.
Users of this mechanism should migrate to the newer PJRT plugin registration mechanism (see the comments on discover_plugins() in this file).
2023-06-14 15:26:58 -04:00
Peter Hawkins
0c441574c4 Add NumPy as a test requirement.
The Windows CI currently installs all of the test requirements before building jaxlib, but NumPy is needed to build jaxlib.
Previously this came transitively via matplotlib.
2023-06-14 10:14:56 -04:00
Peter Hawkins
6b76937c53 Remove matplotlib from the test requirements.
In the Windows CI, we seem to be hitting the following error:

```
=================================== ERRORS ====================================
____________________ ERROR collecting tests/lobpcg_test.py ____________________
tests\lobpcg_test.py:28: in <module>
    from matplotlib import pyplot as plt
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\pyplot.py:52: in <module>
    import matplotlib.colorbar
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\colorbar.py:19: in <module>
    from matplotlib import _api, cbook, collections, cm, colors, contour, ticker
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\contour.py:13: in <module>
    from matplotlib.backend_bases import MouseButton
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\backend_bases.py:45: in <module>
    from matplotlib import (
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\text.py:16: in <module>
    from .font_manager import FontProperties
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\font_manager.py:1548: in <module>
    fontManager = _load_fontmanager()
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\font_manager.py:1543: in _load_fontmanager
    json_dump(fm, fm_path)
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\font_manager.py:957: in json_dump
    with cbook._lock_path(filename), open(filename, 'w') as fh:
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\contextlib.py:119: in __enter__
    return next(self.gen)
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\cbook\__init__.py:1804: in _lock_path
    with lock_path.open("xb"):
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\pathlib.py:1252: in open
    return io.open(self, mode, buffering, encoding, errors, newline,
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\pathlib.py:1120: in _opener
    return self._accessor.open(self, flags, mode)
E   PermissionError: [Errno 13] Permission denied: 'C:\\Users\\runneradmin\\.matplotlib\\fontlist-v330.json.matplotlib-lock'
```

The use of matplotlib is only for an optional debugging feature anyway, so just make it an optional dependency.
2023-06-14 09:02:49 -04:00
Samuel Ainsworth
82d894e41c
build/BUILD.bazel: remove unused import 2023-06-12 18:36:32 -04:00
Sharad Vikram
bf8ed6a543 Move triton_kernel_call_lib to jaxlib
PiperOrigin-RevId: 534934592
2023-05-24 12:11:21 -07:00
Peter Hawkins
a18e82b28b Update bazel version to 6.1.2.
Several of our CI builds are already using 6.1.2, so it's probably best to upgrade for consistency.
2023-05-10 10:57:29 -04:00
Yash Katariya
6506ee2a40 Copybara import of the project:
--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:

[Rollback] Update required Python version to 3.9

PiperOrigin-RevId: 528905991
2023-05-02 15:33:29 -07:00
Jake VanderPlas
57af5360a1 Update required Python version to 3.9 2023-05-01 10:00:57 -07:00
Peter Hawkins
3bb7386149 [JAX] Improve handling of metadata in compilation cache.
Metadata, in particular code location information is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which begs the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.

There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.

This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.

PiperOrigin-RevId: 525534901
2023-04-19 13:27:04 -07:00
Yash Katariya
b38e85b3a4 Package utils.cc properly in jaxlib so that if jaxlib nightly is installed and then used, jaxlib_utils can be accessed.
PiperOrigin-RevId: 523374835
2023-04-11 05:38:30 -07:00
Rahul Batra
13e45c8953 [ROCm]: Run pmap test on specific number of GPUs 2023-03-30 18:34:47 +00:00
jax authors
6715736583 Merge pull request #15205 from yhtang:editable-jaxlib-build
PiperOrigin-RevId: 519704474
2023-03-27 06:33:31 -07:00
Yu-Hang 'Maxin' Tang
caaa0a2669 add build option to create editable jaxlib
Co-authored-by: Yonghao Zhuang <zhuangyh@sjtu.edu.cn>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
2023-03-24 21:25:26 +00:00
Peter Hawkins
6ed66ada0f Delete remote TPU support.
TPU VMs are the only supported way to use TPUs as of JAX 0.4.0.

PiperOrigin-RevId: 519211267
2023-03-24 12:33:33 -07:00
Peter Hawkins
b7375b316b Increase minimum NumPy version to 1.21.
Also increase minimum SciPy version to 1.7, which was released just before NumPy 1.21.
2023-03-23 21:15:10 -04:00
Peter Hawkins
172a831219 Switch JAX to use the OpenXLA repository. 2023-03-13 18:38:26 +00:00
jax authors
edff87eb07 Merge pull request #13613 from ROCmSoftwarePlatform:rocm_rt_build
PiperOrigin-RevId: 510440289
2023-02-17 08:40:28 -08:00
Chao
0dde7a0fb1
Update Dockerfile.ms
update to ROCm5.4
2023-02-17 14:33:33 +00:00
Stella Laurenzo
c1e13bdf3f A few developer workflow enhancements for working with jaxlib.
It seems to me that jaxlib development must be mostly happening on CI, because some basics are pretty essential. Here are a few things I've been typing/carrying for a while in my flow:

* Add .bazelrc.user to .gitignore so it doesn't accidentally get checked in.
* Add configs for 'debug_symbols' and 'debug' that make some things minimally workable under a debugger (or to get backtraces, etc).
* Add `--force-reinstall` to the copy/paste command to update a built jaxlib wheel (without this, if you are iterating, it fairly quietly does nothing).
2023-02-10 21:03:21 -08:00
Rahul Batra
023226e181 [ROCm]: Move dockerfile to ROCm5.4 2023-02-09 20:08:35 +00:00