This lets us avoid bundling a whole another copy of LLVM with JAX packages
and so we can finally start building Mosaic GPU by default.
PiperOrigin-RevId: 638569750
* Added a noop config_tags_overrides parameter to jax_test()
* Updated BUILD files necessary to run Pallas tests via Bazel
* Changed PallasTest to skip "large" test cases
PiperOrigin-RevId: 608534008
Until now the backwards compatibility tests for exporting JAX functions with custom calls were part of the jax2tf test suite. But these tests are independent of TF, and we need to write such tests for Pallas and other projects that should not depend on jax2tf.
Here we move the test utilities out of jax2tf.
This is needed to enable writing Pallas backwards compatibility tests.
We rename back_compat_test_util.py to export_back_compat_test_util.py for clarity.
In a subsequent move we will move the actual backwards compatibility tests themselves out of jax2tf.
PiperOrigin-RevId: 583312085
When build_cuda_plugin_from_source is true, it will build cuda plugin from source, and it is used for the case of `bazel test` without preinstall jax cuda packages.
PiperOrigin-RevId: 583057751
- Add build target for jax_plugins/ and jax_plugins/cuda for bazel test.
- Update jax_plugins/cuda/__init__.py to fallback to local `.so` file path.
- Add a flag --//jax:build_cuda_plugin to control whether to link in local cuda plugin.
The following command will test with cuda plugin:
```
bazel test tests:python_callback_test_gpu --test_output=all --test_filter=PythonCallbackTest.test_send_zero_dim_arrays_pure --config=tensorflow_testing_rbe_linux --config=rbe_linux_cuda12.2_nvcc_py3.9 --//jax:build_cuda_plugin=false
```
Default behavior (without `--//jax:build_cuda_plugin=false`) remains unchanged.
PiperOrigin-RevId: 582728477
The primitive_harness.py defines a set of about 7000 test harnesses, each with a JAX callable and a recipe for generating the arguments for the callable. Note that the test harness does not define any expected behavior. The test harnesses can be used in several kinds of tests.
Initially these harnesses were designed to test the completeness of the jax2tf lowering: for each test harness we convert it to TF and then we test that the result of invoking it is the same as for JAX native. Since then we have found other uses of test harnesses.
* E.g., shape_poly_test.py tests that we can apply `jax.vmap` to each test harness and that we get a JAX callable that can be traced shape polymorphically, using a dimension variable for the batch dimension.
* E.g., multi_platform_lowering_test.py tests that we can generate multi-platform lowering for each test harnesse.
* E.g., the TFLite team is using the test harnesses to check the completeness of the TFLite lowering.
Since the test harnesses are useful for non-jax2tf uses we hereby moved them to jax._src.internal_test_util.test_harnesses. (We also renamed the module from primitive_harness to test_harnesses.)
This change is necessary to move some tests out of jax2tf: multi_platform_lowering_test.py, shape_poly_test.py.
PiperOrigin-RevId: 581016785
But run the continuous builds by building on RBE and testing locally so as to run the multiaccelerator tests too. Locally we have 4 GPUs available.
Also make GPU presubmits blocking for JAX (re-enabled it).
PiperOrigin-RevId: 491647775
Add a parametric py_deps() macro for adding Python package dependencies for Bazel rules.
Fix build failure with dangling matplotlib reference.
PiperOrigin-RevId: 465562141
* Add a new --configure_only option to build.py to allow build.py to generate a .bazelrc without necessarily building jaxlib.
* Add a bazel flag that make the dependency of //jax on //jaxlib optional. If //jaxlib isn't built by bazel, then tests will implicitly use a preinstalled jaxlib.
This is an alternative method for running the tests that some users may prefer: pytest is and will remain fully supported.
To use this, one creates a .bazelrc by running the existing `build.py` script, and then one can run the tests by running:
```
bazel test -c opt //tests/...
```
Issue #7323
PiperOrigin-RevId: 458551208
Change in preparation for allowing JAX tests to run under Bazel.
Remove code to patch paths in xla_client.py in the wheel build script; the patch is no longer used.
PiperOrigin-RevId: 458522398
The crashes on Mac were, as best we can tell, unrelated to this PR.
Original description:
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.
Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.
PiperOrigin-RevId: 457819042
This change appears to be causing crashes on Mac.
Original description:
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.
Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.
PiperOrigin-RevId: 457559793
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.
Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.
PiperOrigin-RevId: 457460347