12 Commits

Author SHA1 Message Date
jax authors
40a3d0c78d Create the test targets for the wheel size verification.
Add the tests to the Bazel presubmit RBE jobs (except `arm64`/`aarch64` jobs that use RBE cross-compilation).

PiperOrigin-RevId: 742724458
2025-04-01 09:11:56 -07:00
jax authors
1b7c8e8d08 Add editable jax wheel target.
The set of editable wheels (`jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt`) was used as dependencies in `requirements.in` file together with `:build_jaxlib=false` flag.

After [adding `jax` wheel dependencies](f5a4d1a85c) to the tests when `:build_jaxlib=false` is used, we need an editable `jax` wheel target as well to get the tests passing.

PiperOrigin-RevId: 740840736
2025-03-26 11:25:52 -07:00
jax authors
9f40440d47 Add missing jax wheel dependencies.
PiperOrigin-RevId: 740767116
2025-03-26 07:57:28 -07:00
jax authors
49aad1b97f Add the missing flatbuffers dependency for the tests that run under :build_jaxlib=false.
PiperOrigin-RevId: 740115575
2025-03-24 16:40:22 -07:00
jax authors
f5a4d1a85c Enable jax wheel testing via Bazel.
Remove jax dependencies from the Bazel test targets for `:build_jaxlib=false` and `:build_jaxlib=wheel`.

`internal_test_util` is removed from the `jax` wheel. To use this package in Bazel py_test, we need to copy it to the unpacked wheel folder. This is done by adding `wheel_deps` value to `py_import` Jax targets.

This change concludes ML Wheels design implementation in JAX and enables testing of all wheels via Bazel command.

PiperOrigin-RevId: 740037952
2025-03-24 12:48:19 -07:00
jax authors
16dc0ad1dd Add jax_source_package macros and target to generate a source package .tar.gz.
Refactor `jax_wheel` macros, so it outputs a `.whl` file only.

When the macros returns one output object only, it allows all downstream dependencies consume it easily without the need to filter the macros outputs.

The previous implementation design (when `jax_wheel` returned `.tar.gz` and `.whl` files) required one of two options: either create a new target that produces `.whl` only, or to implement filename filtering in the downstream rules. With the new implementation we can just depend on `//:jax_wheel` target that produces the `.whl`.

PiperOrigin-RevId: 738547491
2025-03-19 14:48:36 -07:00
jax authors
8fbe3b1333 Remove internal_test_util folder and packages from jax wheel.
PiperOrigin-RevId: 736861450
2025-03-14 07:52:03 -07:00
David Dunleavy
1a19d5594a Update all uses of @tsl//third_party to @xla//third_party
PiperOrigin-RevId: 733495240
2025-03-04 15:55:23 -08:00
jax authors
4eb782e402 Update jax_wheel target to produce both wheel and source distribution files.
This change replicates the old method of building `jax` wheel via `python -m build`, which produced `.tar.gz` and `.whl` files.

PiperOrigin-RevId: 731721522
2025-02-27 07:41:13 -08:00
Dan Foreman-Mackey
525cb4bde4 Rename top level build file to BUILD.bazel.
PiperOrigin-RevId: 730957694
2025-02-25 11:13:17 -08:00
Peter Hawkins
e63765a7a6 Use symlink_files() to add version.py to jaxlib, rather than copying it in as part of the wheel assembly process.
Change in preparation for supporting running JAX tests under Bazel. This change allows the Bazel py_library() to see version.py.

Update symlink_files Bazel macro to a newer version.

PiperOrigin-RevId: 458481396
2022-07-01 09:07:03 -07:00
Peter Hawkins
1e171ccd10 Unify jax and jaxlib versions.
Currently jax and jaxlib have separate version numbers in the JAX source
tree. It is tedious and confusing to bump both version numbers.

However, there is a simpler way to think of things: it is the source
tree that is versioned using a single version number, and jax/jaxlib
releases are made using that unified source version number.

PiperOrigin-RevId: 458041752
2022-06-29 12:51:01 -07:00