Move jaxlib/{cuda,rocm}_plugin_extension into jaxlib/{cuda/rocm}/

Move the common jaxlib/gpu_plugin_extension into jaxlib/gpu/

Cleanup only, no functional changes intended.

PiperOrigin-RevId: 738183402
This commit is contained in:
Peter Hawkins 2025-03-18 16:28:00 -07:00 committed by jax authors
parent 01a110c4c9
commit 3f91b4b43a
13 changed files with 75 additions and 75 deletions

View File

@ -71,7 +71,7 @@ import numpy as np
export = set_module('jax.numpy')
for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib']:
for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib.cuda']:
try:
cuda_plugin_extension = importlib.import_module(
f'{pkg_name}.cuda_plugin_extension'

View File

@ -24,7 +24,7 @@ import jax._src.xla_bridge as xb
# cuda_plugin_extension locates inside jaxlib. `jaxlib` is for testing without
# preinstalled jax cuda plugin packages.
for pkg_name in ['jax_cuda12_plugin', 'jaxlib']:
for pkg_name in ['jax_cuda12_plugin', 'jaxlib.cuda']:
try:
cuda_plugin_extension = importlib.import_module(
f'{pkg_name}.cuda_plugin_extension'

View File

@ -23,7 +23,7 @@ import jax._src.xla_bridge as xb
# rocm_plugin_extension locates inside jaxlib. `jaxlib` is for testing without
# preinstalled jax rocm plugin packages.
for pkg_name in ['jax_rocm60_plugin', 'jaxlib']:
for pkg_name in ['jax_rocm60_plugin', 'jaxlib.cuda']:
try:
rocm_plugin_extension = importlib.import_module(
f'{pkg_name}.rocm_plugin_extension'

View File

@ -222,62 +222,3 @@ nanobind_extension(
"@xla//third_party/python_runtime:headers",
],
)
cc_library(
name = "gpu_plugin_extension",
srcs = ["gpu_plugin_extension.cc"],
hdrs = ["gpu_plugin_extension.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
deps = [
":kernel_nanobind_helpers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@nanobind",
"@xla//xla:util",
"@xla//xla/ffi/api:c_api",
"@xla//xla/pjrt:status_casters",
"@xla//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_helpers",
"@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs",
"@xla//xla/python:py_client_gpu",
"@xla//xla/tsl/python/lib/core:numpy",
],
)
nanobind_extension(
name = "cuda_plugin_extension",
srcs = ["cuda_plugin_extension.cc"],
module_name = "cuda_plugin_extension",
deps = [
":gpu_plugin_extension",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@local_config_cuda//cuda:cuda_headers",
"@nanobind",
"@xla//xla/pjrt:status_casters",
"@xla//xla/tsl/cuda:cublas",
"@xla//xla/tsl/cuda:cudart",
],
)
nanobind_extension(
name = "rocm_plugin_extension",
srcs = ["rocm_plugin_extension.cc"],
module_name = "rocm_plugin_extension",
deps = [
":gpu_plugin_extension",
"@com_google_absl//absl/log",
"@com_google_absl//absl/strings",
"@local_config_rocm//rocm:hip",
"@local_config_rocm//rocm:rocm_headers",
"@nanobind",
],
)

View File

@ -657,6 +657,22 @@ py_library(
],
)
nanobind_extension(
name = "cuda_plugin_extension",
srcs = ["cuda_plugin_extension.cc"],
module_name = "cuda_plugin_extension",
deps = [
"//jaxlib/gpu:gpu_plugin_extension",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@local_config_cuda//cuda:cuda_headers",
"@nanobind",
"@xla//xla/pjrt:status_casters",
"@xla//xla/tsl/cuda:cublas",
"@xla//xla/tsl/cuda:cudart",
],
)
# We cannot nest select and if_cuda_is_configured so we introduce
# a standalone py_library target.
py_library(
@ -664,6 +680,6 @@ py_library(
# `if_cuda_is_configured` will default to `[]`.
deps = if_cuda_is_configured([
":cuda_gpu_support",
"//jaxlib:cuda_plugin_extension",
":cuda_plugin_extension",
]),
)

View File

@ -20,7 +20,7 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "jaxlib/gpu_plugin_extension.h"
#include "jaxlib/gpu/gpu_plugin_extension.h"
#include "xla/pjrt/status_casters.h"
namespace nb = nanobind;

View File

@ -90,3 +90,32 @@ xla_py_proto_library(
visibility = jax_visibility("triton_proto_py_users"),
deps = [":triton_proto"],
)
cc_library(
name = "gpu_plugin_extension",
srcs = ["gpu_plugin_extension.cc"],
hdrs = ["gpu_plugin_extension.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
deps = [
"//jaxlib:kernel_nanobind_helpers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@nanobind",
"@xla//xla:util",
"@xla//xla/ffi/api:c_api",
"@xla//xla/pjrt:status_casters",
"@xla//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_helpers",
"@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs",
"@xla//xla/python:py_client_gpu",
"@xla//xla/tsl/python/lib/core:numpy",
],
)

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/gpu_plugin_extension.h"
#include "jaxlib/gpu/gpu_plugin_extension.h"
#include <cstddef>
#include <cstdint>

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_GPU_PLUGIN_EXTENSION_H_
#define JAXLIB_GPU_PLUGIN_EXTENSION_H_
#ifndef JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_
#define JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_
#include "nanobind/nanobind.h"
@ -24,4 +24,4 @@ void BuildGpuPluginExtension(nanobind::module_& m);
} // namespace xla
#endif // JAXLIB_GPU_PLUGIN_EXTENSION_H_
#endif // JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_

View File

@ -555,11 +555,25 @@ py_library(
],
)
nanobind_extension(
name = "rocm_plugin_extension",
srcs = ["rocm_plugin_extension.cc"],
module_name = "rocm_plugin_extension",
deps = [
"//jaxlib/gpu:gpu_plugin_extension",
"@com_google_absl//absl/log",
"@com_google_absl//absl/strings",
"@local_config_rocm//rocm:hip",
"@local_config_rocm//rocm:rocm_headers",
"@nanobind",
],
)
py_library(
name = "gpu_only_test_deps",
# `if_rocm_is_configured` will default to `[]`.
deps = if_rocm_is_configured([
":rocm_gpu_support",
"//jaxlib:rocm_plugin_extension",
":rocm_plugin_extension",
]),
)

View File

@ -20,7 +20,7 @@ limitations under the License.
#include "absl/log/log.h"
#include "absl/strings/str_cat.h"
#include "rocm/include/hip/hip_runtime.h"
#include "jaxlib/gpu_plugin_extension.h"
#include "jaxlib/gpu/gpu_plugin_extension.h"
namespace nb = nanobind;

View File

@ -143,16 +143,16 @@ py_binary(
data = [
"LICENSE.txt",
] + if_cuda([
"//jaxlib/mosaic/gpu:mosaic_gpu",
"//jaxlib:cuda_plugin_extension",
"//jaxlib:version",
"//jaxlib/mosaic/gpu:mosaic_gpu",
"//jaxlib/cuda:cuda_plugin_extension",
"//jaxlib/cuda:cuda_gpu_support",
"//jax_plugins/cuda:plugin_pyproject.toml",
"//jax_plugins/cuda:plugin_setup.py",
"@local_config_cuda//cuda:cuda-nvvm",
]) + if_rocm([
"//jaxlib:rocm_plugin_extension",
"//jaxlib:version",
"//jaxlib/rocm:rocm_plugin_extension",
"//jaxlib/rocm:rocm_gpu_support",
"//jax_plugins/rocm:plugin_pyproject.toml",
"//jax_plugins/rocm:plugin_setup.py",

View File

@ -110,7 +110,7 @@ def prepare_wheel_cuda(
f"__main__/jaxlib/cuda/_triton.{pyext}",
f"__main__/jaxlib/cuda/_hybrid.{pyext}",
f"__main__/jaxlib/cuda/_versions.{pyext}",
f"__main__/jaxlib/cuda_plugin_extension.{pyext}",
f"__main__/jaxlib/cuda/cuda_plugin_extension.{pyext}",
f"__main__/jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}",
"__main__/jaxlib/mosaic/gpu/libmosaic_gpu_runtime.so",
"__main__/jaxlib/version.py",
@ -148,7 +148,7 @@ def prepare_wheel_rocm(
f"__main__/jaxlib/rocm/_hybrid.{pyext}",
f"__main__/jaxlib/rocm/_rnn.{pyext}",
f"__main__/jaxlib/rocm/_triton.{pyext}",
f"__main__/jaxlib/rocm_plugin_extension.{pyext}",
f"__main__/jaxlib/rocm/rocm_plugin_extension.{pyext}",
"__main__/jaxlib/version.py",
],
)