diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index b2b828cf3..96efc4806 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -71,7 +71,7 @@ import numpy as np export = set_module('jax.numpy') -for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib']: +for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib.cuda']: try: cuda_plugin_extension = importlib.import_module( f'{pkg_name}.cuda_plugin_extension' diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py index 68281f4f3..f6540e986 100644 --- a/jax_plugins/cuda/__init__.py +++ b/jax_plugins/cuda/__init__.py @@ -24,7 +24,7 @@ import jax._src.xla_bridge as xb # cuda_plugin_extension locates inside jaxlib. `jaxlib` is for testing without # preinstalled jax cuda plugin packages. -for pkg_name in ['jax_cuda12_plugin', 'jaxlib']: +for pkg_name in ['jax_cuda12_plugin', 'jaxlib.cuda']: try: cuda_plugin_extension = importlib.import_module( f'{pkg_name}.cuda_plugin_extension' diff --git a/jax_plugins/rocm/__init__.py b/jax_plugins/rocm/__init__.py index b16806e39..c48a681bf 100644 --- a/jax_plugins/rocm/__init__.py +++ b/jax_plugins/rocm/__init__.py @@ -23,7 +23,7 @@ import jax._src.xla_bridge as xb # rocm_plugin_extension locates inside jaxlib. `jaxlib` is for testing without # preinstalled jax rocm plugin packages. -for pkg_name in ['jax_rocm60_plugin', 'jaxlib']: +for pkg_name in ['jax_rocm60_plugin', 'jaxlib.cuda']: try: rocm_plugin_extension = importlib.import_module( f'{pkg_name}.rocm_plugin_extension' diff --git a/jaxlib/BUILD b/jaxlib/BUILD index a61bf7c88..a35eabc9a 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -222,62 +222,3 @@ nanobind_extension( "@xla//third_party/python_runtime:headers", ], ) - -cc_library( - name = "gpu_plugin_extension", - srcs = ["gpu_plugin_extension.cc"], - hdrs = ["gpu_plugin_extension.h"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - deps = [ - ":kernel_nanobind_helpers", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@nanobind", - "@xla//xla:util", - "@xla//xla/ffi/api:c_api", - "@xla//xla/pjrt:status_casters", - "@xla//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs", - "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs", - "@xla//xla/pjrt/c:pjrt_c_api_hdrs", - "@xla//xla/pjrt/c:pjrt_c_api_helpers", - "@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs", - "@xla//xla/python:py_client_gpu", - "@xla//xla/tsl/python/lib/core:numpy", - ], -) - -nanobind_extension( - name = "cuda_plugin_extension", - srcs = ["cuda_plugin_extension.cc"], - module_name = "cuda_plugin_extension", - deps = [ - ":gpu_plugin_extension", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@local_config_cuda//cuda:cuda_headers", - "@nanobind", - "@xla//xla/pjrt:status_casters", - "@xla//xla/tsl/cuda:cublas", - "@xla//xla/tsl/cuda:cudart", - ], -) - -nanobind_extension( - name = "rocm_plugin_extension", - srcs = ["rocm_plugin_extension.cc"], - module_name = "rocm_plugin_extension", - deps = [ - ":gpu_plugin_extension", - "@com_google_absl//absl/log", - "@com_google_absl//absl/strings", - "@local_config_rocm//rocm:hip", - "@local_config_rocm//rocm:rocm_headers", - "@nanobind", - ], -) diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 9a0315266..a9bd35b77 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -657,6 +657,22 @@ py_library( ], ) +nanobind_extension( + name = "cuda_plugin_extension", + srcs = ["cuda_plugin_extension.cc"], + module_name = "cuda_plugin_extension", + deps = [ + "//jaxlib/gpu:gpu_plugin_extension", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@local_config_cuda//cuda:cuda_headers", + "@nanobind", + "@xla//xla/pjrt:status_casters", + "@xla//xla/tsl/cuda:cublas", + "@xla//xla/tsl/cuda:cudart", + ], +) + # We cannot nest select and if_cuda_is_configured so we introduce # a standalone py_library target. py_library( @@ -664,6 +680,6 @@ py_library( # `if_cuda_is_configured` will default to `[]`. deps = if_cuda_is_configured([ ":cuda_gpu_support", - "//jaxlib:cuda_plugin_extension", + ":cuda_plugin_extension", ]), ) diff --git a/jaxlib/cuda_plugin_extension.cc b/jaxlib/cuda/cuda_plugin_extension.cc similarity index 97% rename from jaxlib/cuda_plugin_extension.cc rename to jaxlib/cuda/cuda_plugin_extension.cc index 34cf462d6..8d8514bd2 100644 --- a/jaxlib/cuda_plugin_extension.cc +++ b/jaxlib/cuda/cuda_plugin_extension.cc @@ -20,7 +20,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "third_party/gpus/cuda/include/cuda.h" -#include "jaxlib/gpu_plugin_extension.h" +#include "jaxlib/gpu/gpu_plugin_extension.h" #include "xla/pjrt/status_casters.h" namespace nb = nanobind; diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index abaed291a..b5292746d 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -90,3 +90,32 @@ xla_py_proto_library( visibility = jax_visibility("triton_proto_py_users"), deps = [":triton_proto"], ) + +cc_library( + name = "gpu_plugin_extension", + srcs = ["gpu_plugin_extension.cc"], + hdrs = ["gpu_plugin_extension.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "//jaxlib:kernel_nanobind_helpers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@nanobind", + "@xla//xla:util", + "@xla//xla/ffi/api:c_api", + "@xla//xla/pjrt:status_casters", + "@xla//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_helpers", + "@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs", + "@xla//xla/python:py_client_gpu", + "@xla//xla/tsl/python/lib/core:numpy", + ], +) diff --git a/jaxlib/gpu_plugin_extension.cc b/jaxlib/gpu/gpu_plugin_extension.cc similarity index 99% rename from jaxlib/gpu_plugin_extension.cc rename to jaxlib/gpu/gpu_plugin_extension.cc index d666ef6cc..b56cb8337 100644 --- a/jaxlib/gpu_plugin_extension.cc +++ b/jaxlib/gpu/gpu_plugin_extension.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/gpu_plugin_extension.h" +#include "jaxlib/gpu/gpu_plugin_extension.h" #include #include diff --git a/jaxlib/gpu_plugin_extension.h b/jaxlib/gpu/gpu_plugin_extension.h similarity index 85% rename from jaxlib/gpu_plugin_extension.h rename to jaxlib/gpu/gpu_plugin_extension.h index ae8cd73db..70c74454e 100644 --- a/jaxlib/gpu_plugin_extension.h +++ b/jaxlib/gpu/gpu_plugin_extension.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_GPU_PLUGIN_EXTENSION_H_ -#define JAXLIB_GPU_PLUGIN_EXTENSION_H_ +#ifndef JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_ +#define JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_ #include "nanobind/nanobind.h" @@ -24,4 +24,4 @@ void BuildGpuPluginExtension(nanobind::module_& m); } // namespace xla -#endif // JAXLIB_GPU_PLUGIN_EXTENSION_H_ +#endif // JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_ diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 9774708ad..9a25a795f 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -555,11 +555,25 @@ py_library( ], ) +nanobind_extension( + name = "rocm_plugin_extension", + srcs = ["rocm_plugin_extension.cc"], + module_name = "rocm_plugin_extension", + deps = [ + "//jaxlib/gpu:gpu_plugin_extension", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@local_config_rocm//rocm:hip", + "@local_config_rocm//rocm:rocm_headers", + "@nanobind", + ], +) + py_library( name = "gpu_only_test_deps", # `if_rocm_is_configured` will default to `[]`. deps = if_rocm_is_configured([ ":rocm_gpu_support", - "//jaxlib:rocm_plugin_extension", + ":rocm_plugin_extension", ]), ) diff --git a/jaxlib/rocm_plugin_extension.cc b/jaxlib/rocm/rocm_plugin_extension.cc similarity index 98% rename from jaxlib/rocm_plugin_extension.cc rename to jaxlib/rocm/rocm_plugin_extension.cc index f28b5c9b4..1dd1f1943 100644 --- a/jaxlib/rocm_plugin_extension.cc +++ b/jaxlib/rocm/rocm_plugin_extension.cc @@ -20,7 +20,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "rocm/include/hip/hip_runtime.h" -#include "jaxlib/gpu_plugin_extension.h" +#include "jaxlib/gpu/gpu_plugin_extension.h" namespace nb = nanobind; diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 5b24d2359..afa5866e2 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -143,16 +143,16 @@ py_binary( data = [ "LICENSE.txt", ] + if_cuda([ - "//jaxlib/mosaic/gpu:mosaic_gpu", - "//jaxlib:cuda_plugin_extension", "//jaxlib:version", + "//jaxlib/mosaic/gpu:mosaic_gpu", + "//jaxlib/cuda:cuda_plugin_extension", "//jaxlib/cuda:cuda_gpu_support", "//jax_plugins/cuda:plugin_pyproject.toml", "//jax_plugins/cuda:plugin_setup.py", "@local_config_cuda//cuda:cuda-nvvm", ]) + if_rocm([ - "//jaxlib:rocm_plugin_extension", "//jaxlib:version", + "//jaxlib/rocm:rocm_plugin_extension", "//jaxlib/rocm:rocm_gpu_support", "//jax_plugins/rocm:plugin_pyproject.toml", "//jax_plugins/rocm:plugin_setup.py", diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index 09a55d3c3..2f81eacbd 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -110,7 +110,7 @@ def prepare_wheel_cuda( f"__main__/jaxlib/cuda/_triton.{pyext}", f"__main__/jaxlib/cuda/_hybrid.{pyext}", f"__main__/jaxlib/cuda/_versions.{pyext}", - f"__main__/jaxlib/cuda_plugin_extension.{pyext}", + f"__main__/jaxlib/cuda/cuda_plugin_extension.{pyext}", f"__main__/jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}", "__main__/jaxlib/mosaic/gpu/libmosaic_gpu_runtime.so", "__main__/jaxlib/version.py", @@ -148,7 +148,7 @@ def prepare_wheel_rocm( f"__main__/jaxlib/rocm/_hybrid.{pyext}", f"__main__/jaxlib/rocm/_rnn.{pyext}", f"__main__/jaxlib/rocm/_triton.{pyext}", - f"__main__/jaxlib/rocm_plugin_extension.{pyext}", + f"__main__/jaxlib/rocm/rocm_plugin_extension.{pyext}", "__main__/jaxlib/version.py", ], )