490 Commits

Author SHA1 Message Date
jax authors
140955dce0 Merge pull request #23224 from ROCm:ci_build_script
PiperOrigin-RevId: 668061975
2024-08-27 11:12:05 -07:00
Ayaka
859eacb5a1 Fix mypy error 2024-08-27 16:57:53 +01:00
Ruturaj4
9ce8de5fb0 [ROCm] add build file. 2024-08-23 18:11:48 -05:00
Mathew Odden
5c2ffa893f * Add conditional docker interactive mode
Interactive causes bazel to output more
useful info when running locally.

* Fix issue with rocm el8 repo urls

Work around quirk with rocm version
when it ends with 0

* Fix package name conflict

Ubu22 and higher have a package name conflict
between the debian versions and the AMD provided
versions.

* [ROCm] Use clang env
2024-08-22 10:08:41 -05:00
Ruturaj4
e06be544d4 [ROCm] improve gpu script 2024-08-19 15:13:58 -05:00
jax authors
3a6ff86df9 Merge pull request #23088 from vfdev-5:update-build-build-py
PiperOrigin-RevId: 663876121
2024-08-16 14:50:57 -07:00
jax authors
3d942ef22e Merge pull request #23086 from ROCm:ci_build_fix
PiperOrigin-RevId: 663769517
2024-08-16 10:17:04 -07:00
vfdev-5
12e8bf4525 Pass bazel options to requirements_update and requirements_nightly_update commands 2024-08-16 11:47:21 +02:00
Ruturaj4
fd7c52d213 [ROCm] Fix python in rocm ci_build script. 2024-08-15 21:33:51 -05:00
jax authors
a498c1e668 Set Clang as the default compiler in the build script.
PiperOrigin-RevId: 663433112
2024-08-15 13:36:15 -07:00
jax authors
599c13aa09 Introduce hermetic CUDA in Google ML projects.
1) Hermetic CUDA rules allow building wheels with GPU support on a machine without GPUs, as well as running Bazel GPU tests on a machine with only GPUs and NVIDIA driver installed. When `--config=cuda` is provided in Bazel options, Bazel will download CUDA, CUDNN and NCCL redistributions in the cache, and use them during build and test phases.

    [Default location of CUNN redistributions](https://developer.download.nvidia.com/compute/cudnn/redist/)

    [Default location of CUDA redistributions](https://developer.download.nvidia.com/compute/cuda/redist/)

    [Default location of NCCL redistributions](https://pypi.org/project/nvidia-nccl-cu12/#history)

2) To include hermetic CUDA rules in your project, add the following in the WORKSPACE of the downstream project dependent on XLA.

   Note: use `@local_tsl` instead of `@tsl` in Tensorflow project.

   ```
   load(
      "@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
      "cuda_json_init_repository",
   )

   cuda_json_init_repository()

   load(
      "@cuda_redist_json//:distributions.bzl",
      "CUDA_REDISTRIBUTIONS",
      "CUDNN_REDISTRIBUTIONS",
   )
   load(
      "@tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
      "cuda_redist_init_repositories",
      "cudnn_redist_init_repository",
   )

   cuda_redist_init_repositories(
      cuda_redistributions = CUDA_REDISTRIBUTIONS,
   )

   cudnn_redist_init_repository(
      cudnn_redistributions = CUDNN_REDISTRIBUTIONS,
   )

   load(
      "@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
      "cuda_configure",
   )

   cuda_configure(name = "local_config_cuda")

   load(
      "@tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
      "nccl_redist_init_repository",
   )

   nccl_redist_init_repository()

   load(
      "@tsl//third_party/nccl/hermetic:nccl_configure.bzl",
      "nccl_configure",
   )

   nccl_configure(name = "local_config_nccl")
   ```

PiperOrigin-RevId: 662981325
2024-08-14 10:58:43 -07:00
Mathew Odden
fafa03c60f Add missing CPython build deps for pyenv 2024-08-12 15:01:34 -05:00
Mathew Odden
701cda8ebd Fix not finding wheels in bazel output 2024-08-12 15:01:34 -05:00
Mathew Odden
df2d140f51 Fix jenkins notty issue 2024-08-12 15:01:34 -05:00
Mathew Odden
319ebf81c1 Add defaults for ROCm build vars 2024-08-12 15:01:34 -05:00
Mathew Odden
abe44f6d9e Add copyright and license headers to new files 2024-08-12 15:01:34 -05:00
Mathew Odden
a1a0a4ecdd Add support for ROCm development builds
Use get_rocm.py changes in ci_build to pull in
development builds for ROCm.

Specify ROCM_BUILD_JOB and ROCM_BUILD_NUM for
activating the development build path.
2024-08-12 15:01:34 -05:00
Mathew Odden
3175f13c59 Add internal release support to get_rocm.py 2024-08-12 15:01:34 -05:00
Mathew Odden
1e58d76772 [ROCm] Change ROCm builds to manylinux wheels 2024-08-12 15:01:34 -05:00
Jieying Luo
9c2caedab1 Add subdirectories to the output path when building editable wheels for jaxlib and GPU plugin.
When `build_gpu_plugin` is true, three wheels will be produced (jaxlib, jax-cuda-pjrt and jax-cuda-plugin). If they are editable, they need to be placed in subdirectories to avoid overwrite.

Tested on GPU. After the editable wheels are built, they can be installed with `pip install -e /jax/dist/jax_gpu_pjrt /jax/dist/jaxlib /jax/dist/jax_gpu_plugin`.

PiperOrigin-RevId: 660984311
2024-08-08 14:34:55 -07:00
Jieying Luo
abe7982d65 Remove enable_gpu and xla_python_enable_gpu from jax .bazelrc.
The plugin is released and the flag is no longer needed.

Also set default value of enable_gpu to False. enable_gpu will be removed in the next change.

PiperOrigin-RevId: 660059432
2024-08-06 12:39:45 -07:00
Yue Sheng
eb571c984a Fix lint in run_single_gpu.py
PiperOrigin-RevId: 658933291
2024-08-02 16:03:03 -07:00
Rahul Batra
7d6fa3c05b [ROCm]: Add support to continue on fail, fix script paths and update Dockerfile to add necessary packages 2024-08-01 17:55:15 -05:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Peter Hawkins
945fde41e4 Update minimum Python version to 3.10. 2024-06-26 13:47:14 -04:00
Dan Foreman-Mackey
1c640b25cb Fix pre-commit formatting errors
It looks like #21916 introduced some whitespace issues that are causing
our pre-commits to fail.
2024-06-26 06:38:51 -04:00
jax authors
96cf5d53c8 Merge pull request #21916 from ROCm:ci_pjrt
PiperOrigin-RevId: 646793145
2024-06-26 02:43:21 -07:00
Ruturaj Vaidya
385283c50b
Update build.py 2024-06-25 11:31:18 -05:00
Ruturaj4
a00d030248 [ROCM] nits and fixes 2024-06-18 20:21:23 +00:00
dependabot[bot]
b26af70949
Bump scipy from 1.13.0 to 1.13.1
Bumps [scipy](https://github.com/scipy/scipy) from 1.13.0 to 1.13.1.
- [Release notes](https://github.com/scipy/scipy/releases)
- [Commits](https://github.com/scipy/scipy/compare/v1.13.0...v1.13.1)

---
updated-dependencies:
- dependency-name: scipy
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-06-17 19:58:11 +00:00
Peter Hawkins
160e09e235 Use NumPy 2.0.0 and SciPy 1.13.1 in builds.
Don't override the XLA repository in the nightly Windows CI builds,
which should be building JAX as it exists in the source repository.
2024-06-17 19:35:08 +00:00
Ruturaj4
690dba546c [JAX] change rocm build file for pjrt plugin 2024-06-17 16:49:28 +00:00
Ruturaj4
99c2b7b4e9 [ROCm] Bring-up pjrt support 2024-06-17 16:49:22 +00:00
Sergei Lebedev
3b1b5fda81 Added filelock to test-requirements.txt and requirements lock files
This is a follow up to #21741.
2024-06-11 11:53:10 +01:00
Adam Paszke
cfe64cd5ce [Mosaic GPU] Integrate the ExecutionEngine with the jaxlib GPU plugin
This lets us avoid bundling a whole another copy of LLVM with JAX packages
and so we can finally start building Mosaic GPU by default.

PiperOrigin-RevId: 638569750
2024-05-30 01:46:23 -07:00
Peter Hawkins
5968db592c Pin matplotlib < 3.9 for Python 3.10 and earlier.
matplotlib 3.9.0 pins NumPy 1.23 or newer, which is incompatible with
our minimum Numpy pin.
2024-05-22 15:07:09 +00:00
Vadym Matsishevskyi
45a7c22e93 fix: Update hermetic python dependencies to numpy=2.0.0rc2 and scipy=1.13.0 for all python version
Also install built jaxlib in hermetic python to support //jax:build_jaxlib=false tests.

PiperOrigin-RevId: 635169327
2024-05-18 23:39:09 -07:00
Vadym Matsishevskyi
517e299a9d Use hermetic Python in JAX, see "Managing hermetic Python" in developer.md for details
PiperOrigin-RevId: 634146391
2024-05-15 18:20:56 -07:00
jax authors
174405c953 The Bazel version used in JAX is bumped from 6.1.2 to 6.5.0.
The update is needed for Windows/Clang builds and for the future hermetic CUDA implementation.

PiperOrigin-RevId: 631519200
2024-05-07 12:58:37 -07:00
jax authors
8ba5c64794 Pass bazel_options directly to the Bazel command, instead of into .bazelrc.
PiperOrigin-RevId: 631099970
2024-05-06 10:05:19 -07:00
Jieying Luo
16b4f69769 Rename arg in build script to be more clear.
The flag means skips GPU plugin extension in jaxlib.

PiperOrigin-RevId: 627203738
2024-04-22 17:22:24 -07:00
Adam Paszke
8e3f5b1018 Initial commit for Mosaic GPU
Moving this to JAX to make it easier to explore Pallas integration.

PiperOrigin-RevId: 625982382
2024-04-18 04:04:10 -07:00
Pearu Peterson
fdb5015909 Evaluate the correctness of JAX complex functions using mpmath as a reference 2024-03-21 23:35:29 +02:00
David Dunleavy
6928465b87 Add --use_clang and --clang_path options to build.py
PiperOrigin-RevId: 603837975
2024-02-02 18:20:44 -08:00
zahiqbal
ef7694f26a [ROCM]: Generating pytest html logs from unit-tests. 2024-01-24 15:08:35 +00:00
Rahul Batra
b7a7f0bd80 [ROCm]: Dockerfile updates 2024-01-22 16:08:37 +00:00
Parker Schuh
23b9c2a22f Add the githash that the jaxlib was built at to __init__.py. This is to allow identifying the githash of nightlies.
PiperOrigin-RevId: 595529249
2024-01-03 16:12:23 -08:00
Jieying Luo
c8b3567e82 Add two flags to support only building cuda kernel plugin or cuda pjrt plugin.
PiperOrigin-RevId: 591274120
2023-12-15 09:15:46 -08:00
George Necula
b077483bfa [export] Add support for serialization and deserialization of Exported
At the moment we can export a JAX function into an Exported and we can invoke an Exported from another JAX function, but there is no way to serialize an Exported to be able to use it in another process.

Exported now has all the features we had in mind, so it is a reasonable time to add a serialization method. The intention is for the serialization to have backwards compatibility guarantees, meaning that we can deserialize an Exported that has been serialized with an older version of JAX. This PR does not add explicit support for versioning, nor backwards compatibility tests. Those will follow.

Here we add serialization and deserialization to bytearray, using the flatbuffers package. We use flatbuffers because it is simple, it has backwards and forwards compatibility guarantees, it is a lightweight dependency that does not require additional build steps, and it is fast (deserialization simply indexes into the bytearray rather than creating a Python structure).

In the process of implementing this we have done some small cleanup of the Exported structure:

  * renamed serialization_version to mlir_module_serialization_version
  * renamed disabled_checks to disabled_safety_checks

This code is tested by changing export_test.py to interpose a serialization followed by a deserialization every time we export.export.

There is a known bug with the serialization of effects, so I disabled one of the export tests. Will fix in a subsequent PR.

PiperOrigin-RevId: 590078785
2023-12-11 23:23:02 -08:00
Jieying Luo
0ce7c7b7bd Register plugin profiler for TPU and remove --config=tpu/--enable_tpu in jaxlib.
PiperOrigin-RevId: 580561059
2023-11-08 09:40:28 -08:00