507 Commits

Author SHA1 Message Date
Sergei Lebedev
3b1b5fda81 Added filelock to test-requirements.txt and requirements lock files
This is a follow up to #21741.
2024-06-11 11:53:10 +01:00
Adam Paszke
cfe64cd5ce [Mosaic GPU] Integrate the ExecutionEngine with the jaxlib GPU plugin
This lets us avoid bundling a whole another copy of LLVM with JAX packages
and so we can finally start building Mosaic GPU by default.

PiperOrigin-RevId: 638569750
2024-05-30 01:46:23 -07:00
Peter Hawkins
5968db592c Pin matplotlib < 3.9 for Python 3.10 and earlier.
matplotlib 3.9.0 pins NumPy 1.23 or newer, which is incompatible with
our minimum Numpy pin.
2024-05-22 15:07:09 +00:00
Vadym Matsishevskyi
45a7c22e93 fix: Update hermetic python dependencies to numpy=2.0.0rc2 and scipy=1.13.0 for all python version
Also install built jaxlib in hermetic python to support //jax:build_jaxlib=false tests.

PiperOrigin-RevId: 635169327
2024-05-18 23:39:09 -07:00
Vadym Matsishevskyi
517e299a9d Use hermetic Python in JAX, see "Managing hermetic Python" in developer.md for details
PiperOrigin-RevId: 634146391
2024-05-15 18:20:56 -07:00
jax authors
174405c953 The Bazel version used in JAX is bumped from 6.1.2 to 6.5.0.
The update is needed for Windows/Clang builds and for the future hermetic CUDA implementation.

PiperOrigin-RevId: 631519200
2024-05-07 12:58:37 -07:00
jax authors
8ba5c64794 Pass bazel_options directly to the Bazel command, instead of into .bazelrc.
PiperOrigin-RevId: 631099970
2024-05-06 10:05:19 -07:00
Jieying Luo
16b4f69769 Rename arg in build script to be more clear.
The flag means skips GPU plugin extension in jaxlib.

PiperOrigin-RevId: 627203738
2024-04-22 17:22:24 -07:00
Adam Paszke
8e3f5b1018 Initial commit for Mosaic GPU
Moving this to JAX to make it easier to explore Pallas integration.

PiperOrigin-RevId: 625982382
2024-04-18 04:04:10 -07:00
Pearu Peterson
fdb5015909 Evaluate the correctness of JAX complex functions using mpmath as a reference 2024-03-21 23:35:29 +02:00
David Dunleavy
6928465b87 Add --use_clang and --clang_path options to build.py
PiperOrigin-RevId: 603837975
2024-02-02 18:20:44 -08:00
zahiqbal
ef7694f26a [ROCM]: Generating pytest html logs from unit-tests. 2024-01-24 15:08:35 +00:00
Rahul Batra
b7a7f0bd80 [ROCm]: Dockerfile updates 2024-01-22 16:08:37 +00:00
Parker Schuh
23b9c2a22f Add the githash that the jaxlib was built at to __init__.py. This is to allow identifying the githash of nightlies.
PiperOrigin-RevId: 595529249
2024-01-03 16:12:23 -08:00
Jieying Luo
c8b3567e82 Add two flags to support only building cuda kernel plugin or cuda pjrt plugin.
PiperOrigin-RevId: 591274120
2023-12-15 09:15:46 -08:00
George Necula
b077483bfa [export] Add support for serialization and deserialization of Exported
At the moment we can export a JAX function into an Exported and we can invoke an Exported from another JAX function, but there is no way to serialize an Exported to be able to use it in another process.

Exported now has all the features we had in mind, so it is a reasonable time to add a serialization method. The intention is for the serialization to have backwards compatibility guarantees, meaning that we can deserialize an Exported that has been serialized with an older version of JAX. This PR does not add explicit support for versioning, nor backwards compatibility tests. Those will follow.

Here we add serialization and deserialization to bytearray, using the flatbuffers package. We use flatbuffers because it is simple, it has backwards and forwards compatibility guarantees, it is a lightweight dependency that does not require additional build steps, and it is fast (deserialization simply indexes into the bytearray rather than creating a Python structure).

In the process of implementing this we have done some small cleanup of the Exported structure:

  * renamed serialization_version to mlir_module_serialization_version
  * renamed disabled_checks to disabled_safety_checks

This code is tested by changing export_test.py to interpose a serialization followed by a deserialization every time we export.export.

There is a known bug with the serialization of effects, so I disabled one of the export tests. Will fix in a subsequent PR.

PiperOrigin-RevId: 590078785
2023-12-11 23:23:02 -08:00
Jieying Luo
0ce7c7b7bd Register plugin profiler for TPU and remove --config=tpu/--enable_tpu in jaxlib.
PiperOrigin-RevId: 580561059
2023-11-08 09:40:28 -08:00
Jieying Luo
462ef165c4 [PJRT C API] Change build wheel script to build a separate package for cuda kernels.
With this change, `python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12` will generate three wheels:

|                      |size|wheel name                                                               |
|----------------------|----|-------------------------------------------------------------------------|
|jaxlib w/o cuda kernels|76M |jaxlib-0.4.20.dev20231101-cp310-cp310-manylinux2014_x86_64.whl           |
|cuda pjrt              |73M|jax_cuda12_pjrt-0.4.20.dev20231101-py3-none-manylinux2014_x86_64.whl                    |
|cuda kernels           |6.6M|jax_cuda12_plugin-0.4.20.dev20231101-cp310-cp310-manylinux2014_x86_64.whl|

The size of jaxlib with cuda kernels and pjrt is 119M.

The cuda kernel wheel contains all the cuda kernels. A plugin_setup.py and plugin_pyproject.toml are added for this new pacakge.

PiperOrigin-RevId: 579861480
2023-11-06 09:13:44 -08:00
Jieying Luo
0290150c4c Build jaxlib without PJRT GPU deps when plugin will be built.
PiperOrigin-RevId: 573844805
2023-10-16 09:59:07 -07:00
Jieying Luo
432506f1ae [PJRT C API] Fixed pjrt_c_api_gpu and remove noincompatible_remove_legacy_whole_archive
PiperOrigin-RevId: 573094387
2023-10-12 21:25:25 -07:00
Peter Hawkins
73db6ecf2f Set -P when testing whether a package is installed during build.py.
(Only on Python 3.11+)

The test for the "build" package being installed always succeeded because of the subdirectory named "build".
2023-10-10 10:30:37 -04:00
Peter Hawkins
6e5409c008 Add missing raise to build.py 2023-09-26 22:11:01 -04:00
Peter Hawkins
cf28e2c5fa Small improvements to build/build.py
Add a --verbose option that logs all shell() commands run by the script.
Remove some Python 2 backward compatibility logic related to urllib and shutil.

Enable debug logging on Windows wheel builds.

Also include setuptools in the build requirements and test for its presence in build.py.
2023-09-26 21:55:45 -04:00
Jieying Luo
3dbd3649ab Do not share the build command between building jaxlib and gpu plugin as their commands diverge.
PiperOrigin-RevId: 568625959
2023-09-26 13:03:00 -07:00
Jieying Luo
ea01085522 Fix build.py to set include_gpu_plugin_extension as bool flag.
PiperOrigin-RevId: 568572158
2023-09-26 10:01:55 -07:00
Jieying Luo
c7f60fa6eb [PJRT C API] Implement framework side change for registering a custom call.
- Add a py extension to call the custom call C API.
- Change the implementation of register_custom_call_target to store handlers for the custom call targets and delays the registration until the handler for a xla platform is registered.
- Change register_plugin to load PJRT plugin when register_pluin is called (instead of when a client is created), and let it return the PJRT_Api* loaded.
- Delay calling discover_pjrt_plugins() and register_pjrt_plugin_factories_from_env() until the first time backends() is called.

PiperOrigin-RevId: 568265745
2023-09-25 10:52:29 -07:00
Jieying Luo
91fbf9da26 [PJRT C API] Set up jax xla cuda package.
Add a build wheel, pyproject.toml and setup.py.

The directory structure in jax repo is:
jax/
└── plugins/
     └── cuda/
          ├── __init__.py
          ├── pyproject.toml
          └── setup.py

Installed package structure is:
jax_plugins/
     └── xla_cuda_cu12/
           ├── __init__.py
           └── xla_cuda_plugin.so

The major cuda version will be part of the package name.

The plugin wheel can be built with command:
python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12 --bazel_options="--override_repository=xla=$HOME/xla"

PiperOrigin-RevId: 565187954
2023-09-13 16:03:53 -07:00
Rahul Batra
ef79c19093 [ROCm]: Dockerfile and build script updates
Add hipblaslt in Dockerfile
	Update docker file to default to ROCm5.6
 	CI scripts update to handle multiple ROCm versions
2023-09-08 22:56:59 +00:00
Sharad Vikram
d872812a35 [Pallas] Upstream pallas to JAX
PiperOrigin-RevId: 552963029
2023-08-01 16:43:13 -07:00
Peter Hawkins
3c4527b6b0 Check build and wheel are installed before building jaxlib. 2023-07-26 11:46:11 -07:00
Skye Wanderman-Milne
6c909760d5 Cloud TPU CI: make sure we update test deps and upgrade protobuf version
`profiler_test.py:ProfilerTest.test_remote_profiler` fails with the
protobuf upgrade. However, I was seeing mysterious hangs without this,
and in general I think we should be testing with up-to-date deps given
that we don't pin. I'm gonna continue working on getting the Cloud TPU
CI green.
2023-07-19 16:17:47 -07:00
Peter Hawkins
1aa09dbd3e Fix broken jaxlib wheel build by moving LICENSE.txt to the correct location.
PiperOrigin-RevId: 548189924
2023-07-14 12:13:28 -07:00
Peter Hawkins
f540ae4338 Fix warning about direct invocation of setup.py during jaxlib build.
The jaxlib wheel build currently uses `python setup.py bdist_wheel` to construct the wheel. Change it to use `python -m build -w` instead.

To avoid Python getting confused between the directory named `build` in the bazel tree and the Python `build` module, move `build_wheel.py` into `jaxlib/tools`.

PiperOrigin-RevId: 548133811
2023-07-14 08:31:16 -07:00
Peter Hawkins
1d4b10b775 Remove --distinct_host_configuration from Bazel flags.
This flag does nothing under Bazel 6 and will be removed in Bazel 7.
2023-07-11 11:38:05 -04:00
jax authors
60d481078e Merge pull request #16523 from ROCmSoftwarePlatform:rocm-update-build-doc
PiperOrigin-RevId: 547185925
2023-07-11 07:41:51 -07:00
Jake VanderPlas
ad35702934 Drop support for numpy 1.21
This is in accordance with NEP 29 and https://jax.readthedocs.io/en/latest/deprecation.html
2023-06-23 10:28:26 -07:00
Peter Hawkins
bfa113ba60 Remove references to Python 3.8.
Remove the old build scripts/Dockerfile, since they are unused and broken.

PiperOrigin-RevId: 542870354
2023-06-23 08:48:57 -07:00
Yash Katariya
fc0dcd15a2 Copybara import of the project:
--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:

[Roll forward] Update required Python version to 3.9

PiperOrigin-RevId: 542728213
2023-06-22 18:58:30 -07:00
Rahul Batra
80bd361364 [ROCm]: Update README.md 2023-06-22 15:53:04 +00:00
jax authors
961d918883 Merge pull request #16333 from skye:profiler_test
PiperOrigin-RevId: 542069606
2023-06-20 15:40:10 -07:00
Skye Wanderman-Milne
2ca151ef5b profiler_test.py fixes and add coverage to Cloud TPU CI
* Add deps to test requirements, including in new
  `collect-profile-requirements.txt` (to avoid adding tensorflow to
  `test-requirements.txt`).
* Use correct Python executable `ProfilerTest.test_remote_profiler`
  (`python` sometimes defaults to python2)
* Run computations for longer in `ProfilerTest.test_remote_profiler`,
  othewise `collect_profile` sometimes misses it.
2023-06-20 22:25:17 +00:00
Rahul Batra
88a2c8ca5e [ROCm]: Updates defaults in build script 2023-06-20 20:25:40 +00:00
jax authors
632c79093e Merge pull request #16437 from ROCmSoftwarePlatform:rocm-updates-build
PiperOrigin-RevId: 541896270
2023-06-20 06:19:59 -07:00
Rahul Batra
b0e541a730 [ROCm]: Updates for container and build script
-Updated dockerfile.ms
	-Updated build script to switch building against XLA repo
  	-Update CI script
	-Update jaxlib setup.py to add rocm version
2023-06-19 18:13:28 +00:00
Peter Hawkins
119661ce6b Remove older plugin device integration.
Users of this mechanism should migrate to the newer PJRT plugin registration mechanism (see the comments on discover_plugins() in this file).
2023-06-14 15:26:58 -04:00
Peter Hawkins
0c441574c4 Add NumPy as a test requirement.
The Windows CI currently installs all of the test requirements before building jaxlib, but NumPy is needed to build jaxlib.
Previously this came transitively via matplotlib.
2023-06-14 10:14:56 -04:00
Peter Hawkins
6b76937c53 Remove matplotlib from the test requirements.
In the Windows CI, we seem to be hitting the following error:

```
=================================== ERRORS ====================================
____________________ ERROR collecting tests/lobpcg_test.py ____________________
tests\lobpcg_test.py:28: in <module>
    from matplotlib import pyplot as plt
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\pyplot.py:52: in <module>
    import matplotlib.colorbar
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\colorbar.py:19: in <module>
    from matplotlib import _api, cbook, collections, cm, colors, contour, ticker
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\contour.py:13: in <module>
    from matplotlib.backend_bases import MouseButton
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\backend_bases.py:45: in <module>
    from matplotlib import (
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\text.py:16: in <module>
    from .font_manager import FontProperties
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\font_manager.py:1548: in <module>
    fontManager = _load_fontmanager()
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\font_manager.py:1543: in _load_fontmanager
    json_dump(fm, fm_path)
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\font_manager.py:957: in json_dump
    with cbook._lock_path(filename), open(filename, 'w') as fh:
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\contextlib.py:119: in __enter__
    return next(self.gen)
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\cbook\__init__.py:1804: in _lock_path
    with lock_path.open("xb"):
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\pathlib.py:1252: in open
    return io.open(self, mode, buffering, encoding, errors, newline,
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\pathlib.py:1120: in _opener
    return self._accessor.open(self, flags, mode)
E   PermissionError: [Errno 13] Permission denied: 'C:\\Users\\runneradmin\\.matplotlib\\fontlist-v330.json.matplotlib-lock'
```

The use of matplotlib is only for an optional debugging feature anyway, so just make it an optional dependency.
2023-06-14 09:02:49 -04:00
Samuel Ainsworth
82d894e41c
build/BUILD.bazel: remove unused import 2023-06-12 18:36:32 -04:00
Sharad Vikram
bf8ed6a543 Move triton_kernel_call_lib to jaxlib
PiperOrigin-RevId: 534934592
2023-05-24 12:11:21 -07:00
Peter Hawkins
a18e82b28b Update bazel version to 6.1.2.
Several of our CI builds are already using 6.1.2, so it's probably best to upgrade for consistency.
2023-05-10 10:57:29 -04:00