Mirror of https://github.com/ROCm/jax.git, synced 2025-04-17 04:16:07 +00:00.

Add a wheel build for the CUDA plugin: pyproject.toml and setup.py.

The directory structure in the jax repo is:

    jax/
    └── plugins/
        └── cuda/
            ├── __init__.py
            ├── pyproject.toml
            └── setup.py

The installed package structure is:

    jax_plugins/
    └── xla_cuda_cu12/
        ├── __init__.py
        └── xla_cuda_plugin.so

The major CUDA version is part of the package name (for example, xla_cuda_cu12 for CUDA 12).

The plugin wheel can be built with:

    python3 build/build.py --enable_cuda --build_gpu_plugin \
        --gpu_plugin_cuda_version=12 \
        --bazel_options="--override_repository=xla=$HOME/xla"

PiperOrigin-RevId: 565187954
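To illustrate how the major CUDA version ends up in the package name, here is a minimal sketch of the naming and packaging logic. It is not the setup.py from this change. It assumes the version arrives as a string like "12" (as passed via --gpu_plugin_cuda_version) and that names follow the xla_cuda_cu<major> pattern seen in xla_cuda_cu12. The helper names (cuda_major_version, plugin_package_name, installed_files, setup_kwargs) are hypothetical. Only the standard setuptools keywords packages and package_data are filled in; the distribution name and version are left out because they are not given here.

```python
"""Hypothetical sketch of the CUDA plugin's naming and packaging logic.

Assumption: the package name follows the pattern xla_cuda_cu<major>, as in
the xla_cuda_cu12 example. This is not the project's actual setup.py.
"""
from __future__ import annotations


def cuda_major_version(cuda_version: str) -> int:
    """Return the major component of a CUDA version string such as "12" or "12.2"."""
    major = cuda_version.strip().split(".")[0]
    if not major.isdigit():
        raise ValueError(f"invalid CUDA version: {cuda_version!r}")
    return int(major)


def plugin_package_name(cuda_version: str) -> str:
    """Return the plugin package name, e.g. "xla_cuda_cu12" for CUDA 12."""
    return f"xla_cuda_cu{cuda_major_version(cuda_version)}"


def installed_files(cuda_version: str) -> list[str]:
    """Return the files the wheel would install under the jax_plugins namespace."""
    pkg = plugin_package_name(cuda_version)
    return [
        f"jax_plugins/{pkg}/__init__.py",
        f"jax_plugins/{pkg}/xla_cuda_plugin.so",
    ]


def setup_kwargs(cuda_version: str) -> dict:
    """Hypothetical packaging keywords one might pass to setuptools.setup().

    Only packages and package_data are shown; name and version are omitted
    because they are not specified in this change.
    """
    pkg = f"jax_plugins.{plugin_package_name(cuda_version)}"
    return {
        # Ship the plugin as a subpackage of the jax_plugins namespace.
        "packages": [pkg],
        # Include the prebuilt shared library alongside __init__.py.
        "package_data": {pkg: ["xla_cuda_plugin.so"]},
    }
```

For example, plugin_package_name("12") and plugin_package_name("12.2") both yield "xla_cuda_cu12", so a minor CUDA update would keep the same package name while a new major version would produce a separate package.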