mirror of
https://github.com/ROCm/jax.git
synced 2025-04-13 02:16:06 +00:00
Move jaxlib/{cuda,rocm}_plugin_extension into jaxlib/{cuda/rocm}/
Move the common jaxlib/gpu_plugin_extension into jaxlib/gpu/ Cleanup only, no functional changes intended. PiperOrigin-RevId: 738183402
This commit is contained in:
parent
01a110c4c9
commit
3f91b4b43a
@ -71,7 +71,7 @@ import numpy as np
|
||||
|
||||
export = set_module('jax.numpy')
|
||||
|
||||
for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib']:
|
||||
for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib.cuda']:
|
||||
try:
|
||||
cuda_plugin_extension = importlib.import_module(
|
||||
f'{pkg_name}.cuda_plugin_extension'
|
||||
|
@ -24,7 +24,7 @@ import jax._src.xla_bridge as xb
|
||||
|
||||
# cuda_plugin_extension locates inside jaxlib. `jaxlib` is for testing without
|
||||
# preinstalled jax cuda plugin packages.
|
||||
for pkg_name in ['jax_cuda12_plugin', 'jaxlib']:
|
||||
for pkg_name in ['jax_cuda12_plugin', 'jaxlib.cuda']:
|
||||
try:
|
||||
cuda_plugin_extension = importlib.import_module(
|
||||
f'{pkg_name}.cuda_plugin_extension'
|
||||
|
@ -23,7 +23,7 @@ import jax._src.xla_bridge as xb
|
||||
|
||||
# rocm_plugin_extension locates inside jaxlib. `jaxlib` is for testing without
|
||||
# preinstalled jax rocm plugin packages.
|
||||
for pkg_name in ['jax_rocm60_plugin', 'jaxlib']:
|
||||
for pkg_name in ['jax_rocm60_plugin', 'jaxlib.cuda']:
|
||||
try:
|
||||
rocm_plugin_extension = importlib.import_module(
|
||||
f'{pkg_name}.rocm_plugin_extension'
|
||||
|
59
jaxlib/BUILD
59
jaxlib/BUILD
@ -222,62 +222,3 @@ nanobind_extension(
|
||||
"@xla//third_party/python_runtime:headers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gpu_plugin_extension",
|
||||
srcs = ["gpu_plugin_extension.cc"],
|
||||
hdrs = ["gpu_plugin_extension.h"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
":kernel_nanobind_helpers",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/strings:string_view",
|
||||
"@nanobind",
|
||||
"@xla//xla:util",
|
||||
"@xla//xla/ffi/api:c_api",
|
||||
"@xla//xla/pjrt:status_casters",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_helpers",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs",
|
||||
"@xla//xla/python:py_client_gpu",
|
||||
"@xla//xla/tsl/python/lib/core:numpy",
|
||||
],
|
||||
)
|
||||
|
||||
nanobind_extension(
|
||||
name = "cuda_plugin_extension",
|
||||
srcs = ["cuda_plugin_extension.cc"],
|
||||
module_name = "cuda_plugin_extension",
|
||||
deps = [
|
||||
":gpu_plugin_extension",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@nanobind",
|
||||
"@xla//xla/pjrt:status_casters",
|
||||
"@xla//xla/tsl/cuda:cublas",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
],
|
||||
)
|
||||
|
||||
nanobind_extension(
|
||||
name = "rocm_plugin_extension",
|
||||
srcs = ["rocm_plugin_extension.cc"],
|
||||
module_name = "rocm_plugin_extension",
|
||||
deps = [
|
||||
":gpu_plugin_extension",
|
||||
"@com_google_absl//absl/log",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@local_config_rocm//rocm:hip",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@nanobind",
|
||||
],
|
||||
)
|
||||
|
@ -657,6 +657,22 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
nanobind_extension(
|
||||
name = "cuda_plugin_extension",
|
||||
srcs = ["cuda_plugin_extension.cc"],
|
||||
module_name = "cuda_plugin_extension",
|
||||
deps = [
|
||||
"//jaxlib/gpu:gpu_plugin_extension",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@nanobind",
|
||||
"@xla//xla/pjrt:status_casters",
|
||||
"@xla//xla/tsl/cuda:cublas",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
],
|
||||
)
|
||||
|
||||
# We cannot nest select and if_cuda_is_configured so we introduce
|
||||
# a standalone py_library target.
|
||||
py_library(
|
||||
@ -664,6 +680,6 @@ py_library(
|
||||
# `if_cuda_is_configured` will default to `[]`.
|
||||
deps = if_cuda_is_configured([
|
||||
":cuda_gpu_support",
|
||||
"//jaxlib:cuda_plugin_extension",
|
||||
":cuda_plugin_extension",
|
||||
]),
|
||||
)
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#include "jaxlib/gpu_plugin_extension.h"
|
||||
#include "jaxlib/gpu/gpu_plugin_extension.h"
|
||||
#include "xla/pjrt/status_casters.h"
|
||||
|
||||
namespace nb = nanobind;
|
@ -90,3 +90,32 @@ xla_py_proto_library(
|
||||
visibility = jax_visibility("triton_proto_py_users"),
|
||||
deps = [":triton_proto"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gpu_plugin_extension",
|
||||
srcs = ["gpu_plugin_extension.cc"],
|
||||
hdrs = ["gpu_plugin_extension.h"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
"//jaxlib:kernel_nanobind_helpers",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/strings:string_view",
|
||||
"@nanobind",
|
||||
"@xla//xla:util",
|
||||
"@xla//xla/ffi/api:c_api",
|
||||
"@xla//xla/pjrt:status_casters",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_helpers",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs",
|
||||
"@xla//xla/python:py_client_gpu",
|
||||
"@xla//xla/tsl/python/lib/core:numpy",
|
||||
],
|
||||
)
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/gpu_plugin_extension.h"
|
||||
#include "jaxlib/gpu/gpu_plugin_extension.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef JAXLIB_GPU_PLUGIN_EXTENSION_H_
|
||||
#define JAXLIB_GPU_PLUGIN_EXTENSION_H_
|
||||
#ifndef JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_
|
||||
#define JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_
|
||||
|
||||
#include "nanobind/nanobind.h"
|
||||
|
||||
@ -24,4 +24,4 @@ void BuildGpuPluginExtension(nanobind::module_& m);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // JAXLIB_GPU_PLUGIN_EXTENSION_H_
|
||||
#endif // JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_
|
@ -555,11 +555,25 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
nanobind_extension(
|
||||
name = "rocm_plugin_extension",
|
||||
srcs = ["rocm_plugin_extension.cc"],
|
||||
module_name = "rocm_plugin_extension",
|
||||
deps = [
|
||||
"//jaxlib/gpu:gpu_plugin_extension",
|
||||
"@com_google_absl//absl/log",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@local_config_rocm//rocm:hip",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@nanobind",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "gpu_only_test_deps",
|
||||
# `if_rocm_is_configured` will default to `[]`.
|
||||
deps = if_rocm_is_configured([
|
||||
":rocm_gpu_support",
|
||||
"//jaxlib:rocm_plugin_extension",
|
||||
":rocm_plugin_extension",
|
||||
]),
|
||||
)
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "rocm/include/hip/hip_runtime.h"
|
||||
#include "jaxlib/gpu_plugin_extension.h"
|
||||
#include "jaxlib/gpu/gpu_plugin_extension.h"
|
||||
|
||||
namespace nb = nanobind;
|
||||
|
@ -143,16 +143,16 @@ py_binary(
|
||||
data = [
|
||||
"LICENSE.txt",
|
||||
] + if_cuda([
|
||||
"//jaxlib/mosaic/gpu:mosaic_gpu",
|
||||
"//jaxlib:cuda_plugin_extension",
|
||||
"//jaxlib:version",
|
||||
"//jaxlib/mosaic/gpu:mosaic_gpu",
|
||||
"//jaxlib/cuda:cuda_plugin_extension",
|
||||
"//jaxlib/cuda:cuda_gpu_support",
|
||||
"//jax_plugins/cuda:plugin_pyproject.toml",
|
||||
"//jax_plugins/cuda:plugin_setup.py",
|
||||
"@local_config_cuda//cuda:cuda-nvvm",
|
||||
]) + if_rocm([
|
||||
"//jaxlib:rocm_plugin_extension",
|
||||
"//jaxlib:version",
|
||||
"//jaxlib/rocm:rocm_plugin_extension",
|
||||
"//jaxlib/rocm:rocm_gpu_support",
|
||||
"//jax_plugins/rocm:plugin_pyproject.toml",
|
||||
"//jax_plugins/rocm:plugin_setup.py",
|
||||
|
@ -110,7 +110,7 @@ def prepare_wheel_cuda(
|
||||
f"__main__/jaxlib/cuda/_triton.{pyext}",
|
||||
f"__main__/jaxlib/cuda/_hybrid.{pyext}",
|
||||
f"__main__/jaxlib/cuda/_versions.{pyext}",
|
||||
f"__main__/jaxlib/cuda_plugin_extension.{pyext}",
|
||||
f"__main__/jaxlib/cuda/cuda_plugin_extension.{pyext}",
|
||||
f"__main__/jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}",
|
||||
"__main__/jaxlib/mosaic/gpu/libmosaic_gpu_runtime.so",
|
||||
"__main__/jaxlib/version.py",
|
||||
@ -148,7 +148,7 @@ def prepare_wheel_rocm(
|
||||
f"__main__/jaxlib/rocm/_hybrid.{pyext}",
|
||||
f"__main__/jaxlib/rocm/_rnn.{pyext}",
|
||||
f"__main__/jaxlib/rocm/_triton.{pyext}",
|
||||
f"__main__/jaxlib/rocm_plugin_extension.{pyext}",
|
||||
f"__main__/jaxlib/rocm/rocm_plugin_extension.{pyext}",
|
||||
"__main__/jaxlib/version.py",
|
||||
],
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user