This lets us avoid bundling a whole another copy of LLVM with JAX packages
and so we can finally start building Mosaic GPU by default.
PiperOrigin-RevId: 638569750
At the moment we can export a JAX function into an Exported and we can invoke an Exported from another JAX function, but there is no way to serialize an Exported to be able to use it in another process.
Exported now has all the features we had in mind, so it is a reasonable time to add a serialization method. The intention is for the serialization to have backwards compatibility guarantees, meaning that we can deserialize an Exported that has been serialized with an older version of JAX. This PR does not add explicit support for versioning, nor backwards compatibility tests. Those will follow.
Here we add serialization and deserialization to bytearray, using the flatbuffers package. We use flatbuffers because it is simple, it has backwards and forwards compatibility guarantees, it is a lightweight dependency that does not require additional build steps, and it is fast (deserialization simply indexes into the bytearray rather than creating a Python structure).
In the process of implementing this we have done some small cleanup of the Exported structure:
* renamed serialization_version to mlir_module_serialization_version
* renamed disabled_checks to disabled_safety_checks
This code is tested by changing export_test.py to interpose a serialization followed by a deserialization every time we export.export.
There is a known bug with the serialization of effects, so I disabled one of the export tests. Will fix in a subsequent PR.
PiperOrigin-RevId: 590078785
With this change, `python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12` will generate three wheels:
| |size|wheel name |
|----------------------|----|-------------------------------------------------------------------------|
|jaxlib w/o cuda kernels|76M |jaxlib-0.4.20.dev20231101-cp310-cp310-manylinux2014_x86_64.whl |
|cuda pjrt |73M|jax_cuda12_pjrt-0.4.20.dev20231101-py3-none-manylinux2014_x86_64.whl |
|cuda kernels |6.6M|jax_cuda12_plugin-0.4.20.dev20231101-cp310-cp310-manylinux2014_x86_64.whl|
The size of jaxlib with cuda kernels and pjrt is 119M.
The cuda kernel wheel contains all the cuda kernels. A plugin_setup.py and plugin_pyproject.toml are added for this new pacakge.
PiperOrigin-RevId: 579861480
Add a --verbose option that logs all shell() commands run by the script.
Remove some Python 2 backward compatibility logic related to urllib and shutil.
Enable debug logging on Windows wheel builds.
Also include setuptools in the build requirements and test for its presence in build.py.
- Add a py extension to call the custom call C API.
- Change the implementation of register_custom_call_target to store handlers for the custom call targets and delays the registration until the handler for a xla platform is registered.
- Change register_plugin to load PJRT plugin when register_pluin is called (instead of when a client is created), and let it return the PJRT_Api* loaded.
- Delay calling discover_pjrt_plugins() and register_pjrt_plugin_factories_from_env() until the first time backends() is called.
PiperOrigin-RevId: 568265745
Add a build wheel, pyproject.toml and setup.py.
The directory structure in jax repo is:
jax/
└── plugins/
└── cuda/
├── __init__.py
├── pyproject.toml
└── setup.py
Installed package structure is:
jax_plugins/
└── xla_cuda_cu12/
├── __init__.py
└── xla_cuda_plugin.so
The major cuda version will be part of the package name.
The plugin wheel can be built with command:
python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12 --bazel_options="--override_repository=xla=$HOME/xla"
PiperOrigin-RevId: 565187954
`profiler_test.py:ProfilerTest.test_remote_profiler` fails with the
protobuf upgrade. However, I was seeing mysterious hangs without this,
and in general I think we should be testing with up-to-date deps given
that we don't pin. I'm gonna continue working on getting the Cloud TPU
CI green.
The jaxlib wheel build currently uses `python setup.py bdist_wheel` to construct the wheel. Change it to use `python -m build -w` instead.
To avoid Python getting confused between the directory named `build` in the bazel tree and the Python `build` module, move `build_wheel.py` into `jaxlib/tools`.
PiperOrigin-RevId: 548133811
* Add deps to test requirements, including in new
`collect-profile-requirements.txt` (to avoid adding tensorflow to
`test-requirements.txt`).
* Use correct Python executable `ProfilerTest.test_remote_profiler`
(`python` sometimes defaults to python2)
* Run computations for longer in `ProfilerTest.test_remote_profiler`,
othewise `collect_profile` sometimes misses it.
The Windows CI currently installs all of the test requirements before building jaxlib, but NumPy is needed to build jaxlib.
Previously this came transitively via matplotlib.
In the Windows CI, we seem to be hitting the following error:
```
=================================== ERRORS ====================================
____________________ ERROR collecting tests/lobpcg_test.py ____________________
tests\lobpcg_test.py:28: in <module>
from matplotlib import pyplot as plt
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\pyplot.py:52: in <module>
import matplotlib.colorbar
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\colorbar.py:19: in <module>
from matplotlib import _api, cbook, collections, cm, colors, contour, ticker
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\contour.py:13: in <module>
from matplotlib.backend_bases import MouseButton
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\backend_bases.py:45: in <module>
from matplotlib import (
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\text.py:16: in <module>
from .font_manager import FontProperties
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\font_manager.py:1548: in <module>
fontManager = _load_fontmanager()
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\font_manager.py:1543: in _load_fontmanager
json_dump(fm, fm_path)
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\font_manager.py:957: in json_dump
with cbook._lock_path(filename), open(filename, 'w') as fh:
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\contextlib.py:119: in __enter__
return next(self.gen)
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\cbook\__init__.py:1804: in _lock_path
with lock_path.open("xb"):
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\pathlib.py:1252: in open
return io.open(self, mode, buffering, encoding, errors, newline,
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\pathlib.py:1120: in _opener
return self._accessor.open(self, flags, mode)
E PermissionError: [Errno 13] Permission denied: 'C:\\Users\\runneradmin\\.matplotlib\\fontlist-v330.json.matplotlib-lock'
```
The use of matplotlib is only for an optional debugging feature anyway, so just make it an optional dependency.