4 Commits

Author SHA1 Message Date
Jieying Luo
1559d6495e Remove local version in jax-cuda-plugin and jax-cuda-pjrt package.
PiperOrigin-RevId: 591057013
2023-12-14 14:44:49 -08:00
Jieying Luo
d6c5910105 [PJRT C API] Move cuda_plugin_extension from jaxlib to jax-cuda-plugin (the package for cuda kernels).
PiperOrigin-RevId: 583406466
2023-11-17 09:11:46 -08:00
Jieying Luo
88685d8de0 Support bazel test without bazel build for CUDA PJRT plugin.
- Add build target for jax_plugins/ and jax_plugins/cuda for bazel test.
- Update jax_plugins/cuda/__init__.py to fallback to local `.so` file path.
- Add a flag --//jax:build_cuda_plugin to control whether to link in local cuda plugin.

The following command will test with cuda plugin:
```
bazel test tests:python_callback_test_gpu --test_output=all --test_filter=PythonCallbackTest.test_send_zero_dim_arrays_pure --config=tensorflow_testing_rbe_linux --config=rbe_linux_cuda12.2_nvcc_py3.9 --//jax:build_cuda_plugin=false
```

Default behavior (without `--//jax:build_cuda_plugin=false`) remains unchanged.

PiperOrigin-RevId: 582728477
2023-11-15 10:38:19 -08:00
Jieying Luo
ec21e04201 [PJRT C API] Rename the folder "plugins" to "jax_plugins".
With this change, existing plugin discovery mechanism can discover local plugins without pip install.

Update jax_plugins/cuda/__init__.py to return without registering the plugin if the .so file does not exist.

PiperOrigin-RevId: 582431300
2023-11-14 13:56:13 -08:00