When build_cuda_plugin_from_source is true, it will build cuda plugin from source, and it is used for the case of `bazel test` without preinstall jax cuda packages.
PiperOrigin-RevId: 583057751
An upcoming change to JAX will include non-local (addressable) CPU devices in jax.devices() when JAX is used multicontroller-style, where there are multiple Python processes.
This change preserves the current behavior by replacing uses of jax.devices("cpu"), which previously only returned local devices, with jax.local_devices("cpu"), which will return local devices both now and in the future.
This change is always be safe (i.e., it should always preserve the previous behavior) but it may sometimes be unnecessary if code is never used in a multicontroller setting.
PiperOrigin-RevId: 582786346
- Add build target for jax_plugins/ and jax_plugins/cuda for bazel test.
- Update jax_plugins/cuda/__init__.py to fallback to local `.so` file path.
- Add a flag --//jax:build_cuda_plugin to control whether to link in local cuda plugin.
The following command will test with cuda plugin:
```
bazel test tests:python_callback_test_gpu --test_output=all --test_filter=PythonCallbackTest.test_send_zero_dim_arrays_pure --config=tensorflow_testing_rbe_linux --config=rbe_linux_cuda12.2_nvcc_py3.9 --//jax:build_cuda_plugin=false
```
Default behavior (without `--//jax:build_cuda_plugin=false`) remains unchanged.
PiperOrigin-RevId: 582728477
We expose 3 modes:
* `SpecifiedLayout`: User specifies the `minor_to_major` field of the layout. Tiling not exposed yet.
* `DefaultLayout`: PJRT chooses the layout. It defaults to the current behavior.
* `AUTO`: Compiler chooses the layout. This field is not a layout per se. It's a request to get the layout from the compiler. This field cannot be on an Array or other data types. It can only be on jit.
Public API coming soon.
Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 582692036
With this change, existing plugin discovery mechanism can discover local plugins without pip install.
Update jax_plugins/cuda/__init__.py to return without registering the plugin if the .so file does not exist.
PiperOrigin-RevId: 582431300
Workloads that set the environment variable LIBTPU_INIT_ARGS
expect that the cache key will be invalidated if the value
of the variable changes between runs. Today, LIBTPU_INIT_ARGS
is not used in the cache key computation. The fix is to factor
it in similar to what is done with the XLA_FLAGS environment
variable.
Testing: new unit test; test workloads.
PiperOrigin-RevId: 582423420
If the gesvdj() is preferable to gesvd() absent a batch dimension, even if there is a batch dimension we should prefer a loop of gesvdj() over a loop of gesvd().
PiperOrigin-RevId: 582279549