This PR is a follow-up to #18881.
The changes were generated by adding
```
from __future__ import annotations
```
to the files that did not already have it and then running
```
pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
```
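For illustration, this is the kind of rewrite the future import plus pyupgrade produces; the module and function below are made up and are not files from this PR:
```
# Hypothetical module after the rewrite. With the future import in place,
# annotations are not evaluated at runtime, so pyupgrade can rewrite e.g.
# typing.Dict[str, int] -> dict[str, int] and typing.Optional[X] -> X | None
# while the code still runs on Python 3.9.
from __future__ import annotations

def tally(counts: dict[str, int] | None) -> list[str]:
    # Previously annotated as Optional[Dict[str, int]] -> List[str].
    return sorted(counts or {})

print(tally({"b": 2, "a": 1}))  # ['a', 'b']
```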
JAX isn't using this, and in fact our code to build it wasn't including the C++ parts, so it was broken anyway. Remove it until someone actually needs it for something.
PiperOrigin-RevId: 587323808
We infer that cuDNN is missing if cudnnGetVersion() returns 0, since the stub implementation in TSL returns 0 when the library isn't found (10a378f499/third_party/tsl/tsl/cuda/cudnn_stub.cc (L58)).
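A minimal sketch of that inference, assuming a hypothetical Python-level binding named cudnn_get_version() that forwards to cudnnGetVersion():
```
# Hypothetical sketch: treat a zero version as "cuDNN not installed", since the
# TSL stub returns 0 from cudnnGetVersion() when the library cannot be loaded.
def cudnn_version_or_none(cudnn_get_version):
    version = cudnn_get_version()
    return None if version == 0 else version

# Fake bindings standing in for the real extension module.
print(cudnn_version_or_none(lambda: 0))     # None -> cuDNN treated as missing
print(cudnn_version_or_none(lambda: 8902))  # 8902
```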
PiperOrigin-RevId: 587056454
Also remove the vector-avoiding specialization. For some reason `is_same<ssize_t, int64_t>` evaluates to true on macOS, but then the compiler complains that `int64_t` is a `long long`, while `ssize_t` is only a `long`.
Although the TODO says to return failure, this is actually done at the end of the function (and this way we also handle the case of ops without vector args).
PiperOrigin-RevId: 584575120
The argument to the cast is of type `ssize_t`. A mismatch between `int64_t` and `ssize_t` occurs on macOS and causes the build to fail:
`error: const_cast from 'const pybind11::ssize_t *' (aka 'const long *') to 'int64_t *' (aka 'long long *') is not allowed`
PiperOrigin-RevId: 584457599
Split the code that determines CUDA library versions out of the py_extension() module and into a cc_library(), because this fixes a linking problem in Google's internal build. (Long story, not worth the details.)
Fixes https://github.com/google/jax/issues/8289
PiperOrigin-RevId: 583544218
Until now the backwards compatibility tests for exporting JAX functions with custom calls were part of the jax2tf test suite. But these tests are independent of TF, and we need to write such tests for Pallas and other projects that should not depend on jax2tf.
Here we move the test utilities out of jax2tf.
This is needed to enable writing Pallas backwards compatibility tests.
We rename back_compat_test_util.py to export_back_compat_test_util.py for clarity.
In a subsequent change we will move the actual backwards compatibility tests themselves out of jax2tf.
PiperOrigin-RevId: 583312085
When build_cuda_plugin_from_source is true, the CUDA plugin is built from source; this is used for `bazel test` runs without preinstalled JAX CUDA packages.
PiperOrigin-RevId: 583057751
- Add build targets for jax_plugins/ and jax_plugins/cuda for bazel test.
- Update jax_plugins/cuda/__init__.py to fall back to a local `.so` file path.
- Add a flag --//jax:build_cuda_plugin to control whether to link in the locally built CUDA plugin.
The following command runs a test with the CUDA plugin:
```
bazel test tests:python_callback_test_gpu --test_output=all --test_filter=PythonCallbackTest.test_send_zero_dim_arrays_pure --config=tensorflow_testing_rbe_linux --config=rbe_linux_cuda12.2_nvcc_py3.9 --//jax:build_cuda_plugin=false
```
Default behavior (without `--//jax:build_cuda_plugin=false`) remains unchanged.
PiperOrigin-RevId: 582728477
With this change, the existing plugin discovery mechanism can discover local plugins without a pip install.
Update jax_plugins/cuda/__init__.py to return without registering the plugin if the `.so` file does not exist.
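A minimal sketch of that early return, with assumed file and helper names rather than the module's real structure:
```
# Hypothetical sketch of jax_plugins/cuda/__init__.py behavior: skip plugin
# registration when no local .so is present. The path and the registration
# helper below are assumptions for illustration only.
import logging
import pathlib

logger = logging.getLogger(__name__)

def initialize():
    so_path = pathlib.Path(__file__).parent / "xla_cuda_plugin.so"  # assumed name
    if not so_path.exists():
        logger.debug("CUDA PJRT plugin .so not found at %s; skipping registration.",
                     so_path)
        return  # return without registering the plugin
    register_plugin_at(so_path)

def register_plugin_at(path):
    # Placeholder standing in for the real PJRT plugin registration call.
    logger.info("Would register the CUDA PJRT plugin from %s", path)
```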
PiperOrigin-RevId: 582431300
If gesvdj() is preferable to gesvd() absent a batch dimension, then even when there is a batch dimension we should prefer a loop of gesvdj() calls over a loop of gesvd() calls.
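A hedged illustration of the selection rule this implies; the threshold and function below are placeholders, not the real jaxlib dispatch:
```
# Hypothetical illustration: if gesvdj() wins for a single matrix of this size,
# keep preferring it when a batch forces us to loop over matrices one by one.
def pick_svd_kernel(rows, cols, batched):
    def prefer_gesvdj(m, n):
        # Placeholder heuristic; the real cutoff lives in jaxlib's SVD lowering.
        return max(m, n) <= 1024

    use_gesvdj = prefer_gesvdj(rows, cols)
    if not batched:
        return "gesvdj" if use_gesvdj else "gesvd"
    # Previously a batch could fall back to a loop of gesvd(); with this change
    # the unbatched preference carries over to the looped case.
    return "loop of gesvdj" if use_gesvdj else "loop of gesvd"

print(pick_svd_kernel(256, 256, batched=True))    # loop of gesvdj
print(pick_svd_kernel(4096, 4096, batched=True))  # loop of gesvd
```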
PiperOrigin-RevId: 582279549
primitive_harness.py defines a set of about 7000 test harnesses, each pairing a JAX callable with a recipe for generating the arguments for that callable. Note that a test harness does not define any expected behavior. The test harnesses can be used in several kinds of tests.
Initially these harnesses were designed to test the completeness of the jax2tf lowering: for each test harness we convert it to TF and then check that the result of invoking it is the same as for native JAX. Since then we have found other uses for the test harnesses.
* E.g., shape_poly_test.py tests that we can apply `jax.vmap` to each test harness and that we get a JAX callable that can be traced shape polymorphically, using a dimension variable for the batch dimension.
* E.g., multi_platform_lowering_test.py tests that we can generate a multi-platform lowering for each test harness.
* E.g., the TFLite team is using the test harnesses to check the completeness of the TFLite lowering.
Since the test harnesses are useful beyond jax2tf, we move them here to jax._src.internal_test_util.test_harnesses. (We also rename the module from primitive_harness to test_harnesses.)
This change is necessary to move some tests out of jax2tf: multi_platform_lowering_test.py, shape_poly_test.py.
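A hedged sketch of how a consumer might drive the harnesses; all_harnesses, dyn_args_maker, and dyn_fun are assumed attribute names, used only to illustrate the callable-plus-argument-recipe structure:
```
# Hypothetical consumer of the relocated harness module; attribute names are
# assumptions. Each harness supplies a callable and an argument recipe, and the
# consumer decides what "expected behavior" means for its own test.
import numpy as np
import jax
from jax._src.internal_test_util import test_harnesses

def run_all_harnesses():
    rng = np.random.RandomState(42)
    for harness in test_harnesses.all_harnesses:   # assumed registry of harnesses
        args = harness.dyn_args_maker(rng)         # assumed argument recipe
        func = harness.dyn_fun                     # assumed JAX callable
        expected = func(*args)                     # plain JAX as the reference
        actual = jax.jit(func)(*args)              # consumer-specific transform
        np.testing.assert_allclose(actual, expected, rtol=1e-5)
```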
PiperOrigin-RevId: 581016785
With this change, `python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12` will generate three wheels:
|wheel                   |size|wheel name                                                                 |
|------------------------|----|---------------------------------------------------------------------------|
|jaxlib w/o CUDA kernels |76M |jaxlib-0.4.20.dev20231101-cp310-cp310-manylinux2014_x86_64.whl              |
|CUDA PJRT               |73M |jax_cuda12_pjrt-0.4.20.dev20231101-py3-none-manylinux2014_x86_64.whl        |
|CUDA kernels            |6.6M|jax_cuda12_plugin-0.4.20.dev20231101-cp310-cp310-manylinux2014_x86_64.whl   |
The size of jaxlib with the CUDA kernels and PJRT bundled in is 119M.
The CUDA kernel wheel contains all of the CUDA kernels. A plugin_setup.py and plugin_pyproject.toml are added for this new package.
PiperOrigin-RevId: 579861480