51 Commits

Author SHA1 Message Date
Andrey Portnoy
15dccd458c Add data argument to jax_test Bazel rule, forward to py_test 2024-06-04 11:17:30 -04:00
Adam Paszke
cfe64cd5ce [Mosaic GPU] Integrate the ExecutionEngine with the jaxlib GPU plugin
This lets us avoid bundling a whole another copy of LLVM with JAX packages
and so we can finally start building Mosaic GPU by default.

PiperOrigin-RevId: 638569750
2024-05-30 01:46:23 -07:00
jax authors
c4559115ec Internal BUILD file change
PiperOrigin-RevId: 634713068
2024-05-17 04:30:21 -07:00
Vadym Matsishevskyi
517e299a9d Use hermetic Python in JAX, see "Managing hermetic Python" in developer.md for details
PiperOrigin-RevId: 634146391
2024-05-15 18:20:56 -07:00
Adam Paszke
8e3f5b1018 Initial commit for Mosaic GPU
Moving this to JAX to make it easier to explore Pallas integration.

PiperOrigin-RevId: 625982382
2024-04-18 04:04:10 -07:00
David Dunleavy
d18323f3c4 Move tsl/BUILD, tsl.bzl, and tsl.default.bzl to XLA
PiperOrigin-RevId: 623215553
2024-04-09 10:47:06 -07:00
Sergei Lebedev
37f313ab22 Fixed internal CI builds
* Added a noop config_tags_overrides parameter to jax_test()
* Updated BUILD files necessary to run Pallas tests via Bazel
* Changed PallasTest to skip "large" test cases

PiperOrigin-RevId: 608534008
2024-02-20 02:42:14 -08:00
Peter Hawkins
720ff42cbf [bazel] Add a macro if_building_jaxlib() to guard dependencies that should only be present if building jaxlib.
Cleanup only, NFC intended.

PiperOrigin-RevId: 588074047
2023-12-05 08:05:17 -08:00
George Necula
c1f54d447e Move back_compat_test_util.py to jax._src.internal_test_util.
Until now the backwards compatibility tests for exporting JAX functions with custom calls were part of the jax2tf test suite. But these tests are independent of TF, and we need to write such tests for Pallas and other projects that should not depend on jax2tf.

Here we move the test utilities out of jax2tf.
This is needed to enable writing Pallas backwards compatibility tests.

We rename back_compat_test_util.py to export_back_compat_test_util.py for clarity.

In a subsequent move we will move the actual backwards compatibility tests themselves out of jax2tf.

PiperOrigin-RevId: 583312085
2023-11-17 02:05:30 -08:00
Jieying Luo
43732e3fd4 Change the definition of the config to run bazel test for cuda plugin to match //jax:build_jaxlib.
When build_cuda_plugin_from_source is true, it will build cuda plugin from source, and it is used for the case of `bazel test` without preinstall jax cuda packages.

PiperOrigin-RevId: 583057751
2023-11-16 08:44:22 -08:00
Jieying Luo
88685d8de0 Support bazel test without bazel build for CUDA PJRT plugin.
- Add build target for jax_plugins/ and jax_plugins/cuda for bazel test.
- Update jax_plugins/cuda/__init__.py to fallback to local `.so` file path.
- Add a flag --//jax:build_cuda_plugin to control whether to link in local cuda plugin.

The following command will test with cuda plugin:
```
bazel test tests:python_callback_test_gpu --test_output=all --test_filter=PythonCallbackTest.test_send_zero_dim_arrays_pure --config=tensorflow_testing_rbe_linux --config=rbe_linux_cuda12.2_nvcc_py3.9 --//jax:build_cuda_plugin=false
```

Default behavior (without `--//jax:build_cuda_plugin=false`) remains unchanged.

PiperOrigin-RevId: 582728477
2023-11-15 10:38:19 -08:00
George Necula
5001a21bad Move primitive_harness.py to jax._src.internal_test_util.test_harnesses.
The primitive_harness.py defines a set of about 7000 test harnesses, each with a JAX callable and a recipe for generating the arguments for the callable. Note that the test harness does not define any expected behavior. The test harnesses can be used in several kinds of tests.

Initially these harnesses were designed to test the completeness of the jax2tf lowering: for each test harness we convert it to TF and then we test that the result of invoking it is the same as for JAX native. Since then we have found other uses of test harnesses.

  * E.g., shape_poly_test.py tests that we can apply `jax.vmap` to each test harness and that we get a JAX callable that can be traced shape polymorphically, using a dimension variable for the batch dimension.
  * E.g., multi_platform_lowering_test.py tests that we can generate multi-platform lowering for each test harnesse.
  * E.g., the TFLite team is using the test harnesses to check the completeness of the TFLite lowering.

Since the test harnesses are useful for non-jax2tf uses we hereby moved them to jax._src.internal_test_util.test_harnesses. (We also renamed the module from primitive_harness to test_harnesses.)

This change is necessary to move some tests out of jax2tf: multi_platform_lowering_test.py, shape_poly_test.py.

PiperOrigin-RevId: 581016785
2023-11-09 13:58:00 -08:00
Peter Hawkins
dbf13252f0 Copybara import of the project:
--
3905d6123bdc22f505934242363fda426c99c4cf by Peter Hawkins <phawkins@google.com>:

Update flatbuffers.

Use upstream flatbuffer bazel scripts, with a couple of small patches to fix:
* https://github.com/google/flatbuffers/issues/8087 (remove npm references)
* https://github.com/google/flatbuffers/pull/8088 (fix flatc build failure due to main() removal by linker)

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/17502 from hawkinsp:fb 3905d6123bdc22f505934242363fda426c99c4cf
PiperOrigin-RevId: 563543954
2023-09-07 14:27:25 -07:00
Roy Frostig
a71c0e6ecc create jax.extend.random as a copy of jax.prng
Co-authored-by: Jake Vanderplas <jakevdp@google.com>
PiperOrigin-RevId: 559874051
2023-08-24 14:41:56 -07:00
Sharad Vikram
d872812a35 [Pallas] Upstream pallas to JAX
PiperOrigin-RevId: 552963029
2023-08-01 16:43:13 -07:00
Sharad Vikram
3d556b7a19 Add Mosaic to Jaxlib and expose bindings in jax.experimental.mosaic
PiperOrigin-RevId: 549801858
2023-07-20 18:28:51 -07:00
Chris Jones
f238667492 Make JAX-Triton calls serializable.
PiperOrigin-RevId: 542524794
2023-06-22 04:57:14 -07:00
Peter Hawkins
88c2898e36 Use pytype_strict_library() in Bazel build rules.
PiperOrigin-RevId: 519757928
2023-03-27 10:16:08 -07:00
Yash Katariya
88584290aa Remove GDA tests from JAX since GDA is deprecated. There are jax.Array tests for all the corresponding GDA tests
PiperOrigin-RevId: 516881635
2023-03-15 11:34:57 -07:00
jax authors
42ef649e65 Merge pull request #14475 from hawkinsp:openxla
PiperOrigin-RevId: 516316330
2023-03-13 14:04:41 -07:00
Peter Hawkins
172a831219 Switch JAX to use the OpenXLA repository. 2023-03-13 18:38:26 +00:00
Peter Hawkins
e4b154b660 Split basearray into separate Bazel module.
Move the definition of ArrayLike into basearray to avoid a cyclic dependency between array.py and basearray.

PiperOrigin-RevId: 516264828
2023-03-13 11:14:41 -07:00
Peter Hawkins
d58be3d4df Split source_info_util into its own Bazel target.
PiperOrigin-RevId: 515646269
2023-03-10 08:41:06 -08:00
Peter Hawkins
0e05a7987f Split some submodules out of //jax under Bazel.
Add separate BUILD targets
* :version - for version.py
* _src/lib - wrapping the jaxlib shims.
* :util - for util.py
* :config - for config.py

PiperOrigin-RevId: 515307923
2023-03-09 05:27:34 -08:00
Peter Hawkins
f7734fd6a4 Limit visibility of Bazel target jax:global_device_array.
PiperOrigin-RevId: 510521459
2023-02-17 14:30:05 -08:00
Jake VanderPlas
936e4ae101 Add new argument to jax_test rule
PiperOrigin-RevId: 509952902
2023-02-15 15:45:47 -08:00
Yash Katariya
3e5a5053f4 Run GPU presubmits via bazel test on the RBE cluster. This speeds up the build + testing significantly (upto 10x).
But run the continuous builds by building on RBE and testing locally so as to run the multiaccelerator tests too. Locally we have 4 GPUs available.

Also make GPU presubmits blocking for JAX (re-enabled it).

PiperOrigin-RevId: 491647775
2022-11-29 08:45:58 -08:00
Yash Katariya
a4e8df76ab Use the remote_gpu tag which is inserted by TF's workspace2 when REMOTE_GPU_TESTING=1
PiperOrigin-RevId: 490553133
2022-11-23 11:50:50 -08:00
Yash Katariya
8e270575f8 Set tf_exec_properties on OSS tests to use TF's gpu pool in the RBE cluster.
PiperOrigin-RevId: 490542399
2022-11-23 11:00:53 -08:00
Jake VanderPlas
6cae54f82d Fix bazel build alias 2022-09-26 15:13:12 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
jax authors
fd90f40c45 Merge pull request #12443 from cloudhan:fix-mlir-chlo-stablehlo-symbols
PiperOrigin-RevId: 475808753
2022-09-21 06:12:44 -07:00
Cloud Han
3fa2c933f4 Fix linker error due to chlo and stablehol symbols are not exported in mlir dll 2022-09-21 17:26:21 +08:00
Jake VanderPlas
13a7034e6a Internal change
PiperOrigin-RevId: 474331907
2022-09-14 10:39:38 -07:00
Skye Wanderman-Milne
031f0b1a10 Add missing Google-internal option to jax_test 2022-09-07 18:31:02 -07:00
Mehdi Amini
ae6e0e0950 Move tensorflow/core/platform/{default, google, windows} to tensorflow/tsl/platform/...
PiperOrigin-RevId: 468025286
2022-08-16 14:33:22 -07:00
Peter Hawkins
b865111996 Refactor BUILD files to avoid individually naming Python dependencies.
Add a parametric py_deps() macro for adding Python package dependencies for Bazel rules.

Fix build failure with dangling matplotlib reference.

PiperOrigin-RevId: 465562141
2022-08-05 07:49:20 -07:00
Yash Katariya
f0b6478b3e Plumb env through jax_test.
PiperOrigin-RevId: 465473378
2022-08-04 21:05:28 -07:00
Yash Katariya
9a5af235da Delete sharded_jit
PiperOrigin-RevId: 464081692
2022-07-29 08:19:52 -07:00
Peter Hawkins
a48f4e116e Change Bazel test rules to generate per-backend test suites. 2022-07-08 14:19:05 +00:00
Peter Hawkins
1c75eee1ff Document how to run tests using Bazel.
* Add a new --configure_only option to build.py to allow build.py to generate a .bazelrc without necessarily building jaxlib.
* Add a bazel flag that make the dependency of //jax on //jaxlib optional. If //jaxlib isn't built by bazel, then tests will implicitly use a preinstalled jaxlib.
2022-07-06 08:30:35 -04:00
Peter Hawkins
1fc9afd03a Add support for running JAX tests under Bazel.
This is an alternative method for running the tests that some users may prefer: pytest is and will remain fully supported.

To use this, one creates a .bazelrc by running the existing `build.py` script, and then one can run the tests by running:
```
bazel test -c opt //tests/...
```

Issue #7323

PiperOrigin-RevId: 458551208
2022-07-01 15:07:22 -07:00
Peter Hawkins
7c49864fdf Symlink xla_client and xla_extension into jaxlib rather than copying them into place in the wheel build.
Change in preparation for allowing JAX tests to run under Bazel.

Remove code to patch paths in xla_client.py in the wheel build script; the patch is no longer used.

PiperOrigin-RevId: 458522398
2022-07-01 12:31:42 -07:00
Peter Hawkins
47f2f091bc Reapply: Drop flatbuffers as a Python dependency of JAX.
The crashes on Mac were, as best we can tell, unrelated to this PR.

Original description:
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457819042
2022-06-28 14:25:14 -07:00
Peter Hawkins
5b576cb03e Revert: Drop flatbuffers as a Python dependency of JAX.
This change appears to be causing crashes on Mac.

Original description:
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457559793
2022-06-27 13:56:32 -07:00
Peter Hawkins
efefeac450 Drop flatbuffers as a Python dependency of JAX.
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457460347
2022-06-27 06:14:07 -07:00
jax authors
cf9a900d78 Merge pull request #9584 from ROCmSoftwarePlatform:rocm_refactor_jaxlib
PiperOrigin-RevId: 432236852
2022-03-03 11:11:02 -08:00
Reza Rahimi
a0d9d81f92 Update JAX to use new math libraries in ROCm-5.0. 2022-03-01 20:02:15 +00:00
Cloud Han
317edcdacd fix mlir capi dll building and linking 2021-11-25 00:07:25 +08:00
Peter Hawkins
11f6c535ae Add MLIR:Python bindings to jaxlib build.
PiperOrigin-RevId: 407657331
2021-11-04 13:29:58 -07:00