mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Move jaxlib/{cuda,rocm}_plugin_extension into jaxlib/{cuda/rocm}/
Move the common jaxlib/gpu_plugin_extension into jaxlib/gpu/ Cleanup only, no functional changes intended. PiperOrigin-RevId: 738183402
This commit is contained in:
parent
01a110c4c9
commit
3f91b4b43a
@ -71,7 +71,7 @@ import numpy as np
|
|||||||
|
|
||||||
export = set_module('jax.numpy')
|
export = set_module('jax.numpy')
|
||||||
|
|
||||||
for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib']:
|
for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib.cuda']:
|
||||||
try:
|
try:
|
||||||
cuda_plugin_extension = importlib.import_module(
|
cuda_plugin_extension = importlib.import_module(
|
||||||
f'{pkg_name}.cuda_plugin_extension'
|
f'{pkg_name}.cuda_plugin_extension'
|
||||||
|
@ -24,7 +24,7 @@ import jax._src.xla_bridge as xb
|
|||||||
|
|
||||||
# cuda_plugin_extension locates inside jaxlib. `jaxlib` is for testing without
|
# cuda_plugin_extension locates inside jaxlib. `jaxlib` is for testing without
|
||||||
# preinstalled jax cuda plugin packages.
|
# preinstalled jax cuda plugin packages.
|
||||||
for pkg_name in ['jax_cuda12_plugin', 'jaxlib']:
|
for pkg_name in ['jax_cuda12_plugin', 'jaxlib.cuda']:
|
||||||
try:
|
try:
|
||||||
cuda_plugin_extension = importlib.import_module(
|
cuda_plugin_extension = importlib.import_module(
|
||||||
f'{pkg_name}.cuda_plugin_extension'
|
f'{pkg_name}.cuda_plugin_extension'
|
||||||
|
@ -23,7 +23,7 @@ import jax._src.xla_bridge as xb
|
|||||||
|
|
||||||
# rocm_plugin_extension locates inside jaxlib. `jaxlib` is for testing without
|
# rocm_plugin_extension locates inside jaxlib. `jaxlib` is for testing without
|
||||||
# preinstalled jax rocm plugin packages.
|
# preinstalled jax rocm plugin packages.
|
||||||
for pkg_name in ['jax_rocm60_plugin', 'jaxlib']:
|
for pkg_name in ['jax_rocm60_plugin', 'jaxlib.cuda']:
|
||||||
try:
|
try:
|
||||||
rocm_plugin_extension = importlib.import_module(
|
rocm_plugin_extension = importlib.import_module(
|
||||||
f'{pkg_name}.rocm_plugin_extension'
|
f'{pkg_name}.rocm_plugin_extension'
|
||||||
|
59
jaxlib/BUILD
59
jaxlib/BUILD
@ -222,62 +222,3 @@ nanobind_extension(
|
|||||||
"@xla//third_party/python_runtime:headers",
|
"@xla//third_party/python_runtime:headers",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "gpu_plugin_extension",
|
|
||||||
srcs = ["gpu_plugin_extension.cc"],
|
|
||||||
hdrs = ["gpu_plugin_extension.h"],
|
|
||||||
copts = [
|
|
||||||
"-fexceptions",
|
|
||||||
"-fno-strict-aliasing",
|
|
||||||
],
|
|
||||||
features = ["-use_header_modules"],
|
|
||||||
deps = [
|
|
||||||
":kernel_nanobind_helpers",
|
|
||||||
"@com_google_absl//absl/status",
|
|
||||||
"@com_google_absl//absl/status:statusor",
|
|
||||||
"@com_google_absl//absl/strings:str_format",
|
|
||||||
"@com_google_absl//absl/strings:string_view",
|
|
||||||
"@nanobind",
|
|
||||||
"@xla//xla:util",
|
|
||||||
"@xla//xla/ffi/api:c_api",
|
|
||||||
"@xla//xla/pjrt:status_casters",
|
|
||||||
"@xla//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs",
|
|
||||||
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
|
|
||||||
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
|
|
||||||
"@xla//xla/pjrt/c:pjrt_c_api_helpers",
|
|
||||||
"@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs",
|
|
||||||
"@xla//xla/python:py_client_gpu",
|
|
||||||
"@xla//xla/tsl/python/lib/core:numpy",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
nanobind_extension(
|
|
||||||
name = "cuda_plugin_extension",
|
|
||||||
srcs = ["cuda_plugin_extension.cc"],
|
|
||||||
module_name = "cuda_plugin_extension",
|
|
||||||
deps = [
|
|
||||||
":gpu_plugin_extension",
|
|
||||||
"@com_google_absl//absl/status",
|
|
||||||
"@com_google_absl//absl/strings",
|
|
||||||
"@local_config_cuda//cuda:cuda_headers",
|
|
||||||
"@nanobind",
|
|
||||||
"@xla//xla/pjrt:status_casters",
|
|
||||||
"@xla//xla/tsl/cuda:cublas",
|
|
||||||
"@xla//xla/tsl/cuda:cudart",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
nanobind_extension(
|
|
||||||
name = "rocm_plugin_extension",
|
|
||||||
srcs = ["rocm_plugin_extension.cc"],
|
|
||||||
module_name = "rocm_plugin_extension",
|
|
||||||
deps = [
|
|
||||||
":gpu_plugin_extension",
|
|
||||||
"@com_google_absl//absl/log",
|
|
||||||
"@com_google_absl//absl/strings",
|
|
||||||
"@local_config_rocm//rocm:hip",
|
|
||||||
"@local_config_rocm//rocm:rocm_headers",
|
|
||||||
"@nanobind",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
@ -657,6 +657,22 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
nanobind_extension(
|
||||||
|
name = "cuda_plugin_extension",
|
||||||
|
srcs = ["cuda_plugin_extension.cc"],
|
||||||
|
module_name = "cuda_plugin_extension",
|
||||||
|
deps = [
|
||||||
|
"//jaxlib/gpu:gpu_plugin_extension",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@local_config_cuda//cuda:cuda_headers",
|
||||||
|
"@nanobind",
|
||||||
|
"@xla//xla/pjrt:status_casters",
|
||||||
|
"@xla//xla/tsl/cuda:cublas",
|
||||||
|
"@xla//xla/tsl/cuda:cudart",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# We cannot nest select and if_cuda_is_configured so we introduce
|
# We cannot nest select and if_cuda_is_configured so we introduce
|
||||||
# a standalone py_library target.
|
# a standalone py_library target.
|
||||||
py_library(
|
py_library(
|
||||||
@ -664,6 +680,6 @@ py_library(
|
|||||||
# `if_cuda_is_configured` will default to `[]`.
|
# `if_cuda_is_configured` will default to `[]`.
|
||||||
deps = if_cuda_is_configured([
|
deps = if_cuda_is_configured([
|
||||||
":cuda_gpu_support",
|
":cuda_gpu_support",
|
||||||
"//jaxlib:cuda_plugin_extension",
|
":cuda_plugin_extension",
|
||||||
]),
|
]),
|
||||||
)
|
)
|
||||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
|||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "third_party/gpus/cuda/include/cuda.h"
|
#include "third_party/gpus/cuda/include/cuda.h"
|
||||||
#include "jaxlib/gpu_plugin_extension.h"
|
#include "jaxlib/gpu/gpu_plugin_extension.h"
|
||||||
#include "xla/pjrt/status_casters.h"
|
#include "xla/pjrt/status_casters.h"
|
||||||
|
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
@ -90,3 +90,32 @@ xla_py_proto_library(
|
|||||||
visibility = jax_visibility("triton_proto_py_users"),
|
visibility = jax_visibility("triton_proto_py_users"),
|
||||||
deps = [":triton_proto"],
|
deps = [":triton_proto"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "gpu_plugin_extension",
|
||||||
|
srcs = ["gpu_plugin_extension.cc"],
|
||||||
|
hdrs = ["gpu_plugin_extension.h"],
|
||||||
|
copts = [
|
||||||
|
"-fexceptions",
|
||||||
|
"-fno-strict-aliasing",
|
||||||
|
],
|
||||||
|
features = ["-use_header_modules"],
|
||||||
|
deps = [
|
||||||
|
"//jaxlib:kernel_nanobind_helpers",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
"@com_google_absl//absl/status:statusor",
|
||||||
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
"@com_google_absl//absl/strings:string_view",
|
||||||
|
"@nanobind",
|
||||||
|
"@xla//xla:util",
|
||||||
|
"@xla//xla/ffi/api:c_api",
|
||||||
|
"@xla//xla/pjrt:status_casters",
|
||||||
|
"@xla//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs",
|
||||||
|
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
|
||||||
|
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
|
||||||
|
"@xla//xla/pjrt/c:pjrt_c_api_helpers",
|
||||||
|
"@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs",
|
||||||
|
"@xla//xla/python:py_client_gpu",
|
||||||
|
"@xla//xla/tsl/python/lib/core:numpy",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "jaxlib/gpu_plugin_extension.h"
|
#include "jaxlib/gpu/gpu_plugin_extension.h"
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef JAXLIB_GPU_PLUGIN_EXTENSION_H_
|
#ifndef JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_
|
||||||
#define JAXLIB_GPU_PLUGIN_EXTENSION_H_
|
#define JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_
|
||||||
|
|
||||||
#include "nanobind/nanobind.h"
|
#include "nanobind/nanobind.h"
|
||||||
|
|
||||||
@ -24,4 +24,4 @@ void BuildGpuPluginExtension(nanobind::module_& m);
|
|||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
||||||
#endif // JAXLIB_GPU_PLUGIN_EXTENSION_H_
|
#endif // JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_
|
@ -555,11 +555,25 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
nanobind_extension(
|
||||||
|
name = "rocm_plugin_extension",
|
||||||
|
srcs = ["rocm_plugin_extension.cc"],
|
||||||
|
module_name = "rocm_plugin_extension",
|
||||||
|
deps = [
|
||||||
|
"//jaxlib/gpu:gpu_plugin_extension",
|
||||||
|
"@com_google_absl//absl/log",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@local_config_rocm//rocm:hip",
|
||||||
|
"@local_config_rocm//rocm:rocm_headers",
|
||||||
|
"@nanobind",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "gpu_only_test_deps",
|
name = "gpu_only_test_deps",
|
||||||
# `if_rocm_is_configured` will default to `[]`.
|
# `if_rocm_is_configured` will default to `[]`.
|
||||||
deps = if_rocm_is_configured([
|
deps = if_rocm_is_configured([
|
||||||
":rocm_gpu_support",
|
":rocm_gpu_support",
|
||||||
"//jaxlib:rocm_plugin_extension",
|
":rocm_plugin_extension",
|
||||||
]),
|
]),
|
||||||
)
|
)
|
||||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
|||||||
#include "absl/log/log.h"
|
#include "absl/log/log.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "rocm/include/hip/hip_runtime.h"
|
#include "rocm/include/hip/hip_runtime.h"
|
||||||
#include "jaxlib/gpu_plugin_extension.h"
|
#include "jaxlib/gpu/gpu_plugin_extension.h"
|
||||||
|
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
|
|
@ -143,16 +143,16 @@ py_binary(
|
|||||||
data = [
|
data = [
|
||||||
"LICENSE.txt",
|
"LICENSE.txt",
|
||||||
] + if_cuda([
|
] + if_cuda([
|
||||||
"//jaxlib/mosaic/gpu:mosaic_gpu",
|
|
||||||
"//jaxlib:cuda_plugin_extension",
|
|
||||||
"//jaxlib:version",
|
"//jaxlib:version",
|
||||||
|
"//jaxlib/mosaic/gpu:mosaic_gpu",
|
||||||
|
"//jaxlib/cuda:cuda_plugin_extension",
|
||||||
"//jaxlib/cuda:cuda_gpu_support",
|
"//jaxlib/cuda:cuda_gpu_support",
|
||||||
"//jax_plugins/cuda:plugin_pyproject.toml",
|
"//jax_plugins/cuda:plugin_pyproject.toml",
|
||||||
"//jax_plugins/cuda:plugin_setup.py",
|
"//jax_plugins/cuda:plugin_setup.py",
|
||||||
"@local_config_cuda//cuda:cuda-nvvm",
|
"@local_config_cuda//cuda:cuda-nvvm",
|
||||||
]) + if_rocm([
|
]) + if_rocm([
|
||||||
"//jaxlib:rocm_plugin_extension",
|
|
||||||
"//jaxlib:version",
|
"//jaxlib:version",
|
||||||
|
"//jaxlib/rocm:rocm_plugin_extension",
|
||||||
"//jaxlib/rocm:rocm_gpu_support",
|
"//jaxlib/rocm:rocm_gpu_support",
|
||||||
"//jax_plugins/rocm:plugin_pyproject.toml",
|
"//jax_plugins/rocm:plugin_pyproject.toml",
|
||||||
"//jax_plugins/rocm:plugin_setup.py",
|
"//jax_plugins/rocm:plugin_setup.py",
|
||||||
|
@ -110,7 +110,7 @@ def prepare_wheel_cuda(
|
|||||||
f"__main__/jaxlib/cuda/_triton.{pyext}",
|
f"__main__/jaxlib/cuda/_triton.{pyext}",
|
||||||
f"__main__/jaxlib/cuda/_hybrid.{pyext}",
|
f"__main__/jaxlib/cuda/_hybrid.{pyext}",
|
||||||
f"__main__/jaxlib/cuda/_versions.{pyext}",
|
f"__main__/jaxlib/cuda/_versions.{pyext}",
|
||||||
f"__main__/jaxlib/cuda_plugin_extension.{pyext}",
|
f"__main__/jaxlib/cuda/cuda_plugin_extension.{pyext}",
|
||||||
f"__main__/jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}",
|
f"__main__/jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}",
|
||||||
"__main__/jaxlib/mosaic/gpu/libmosaic_gpu_runtime.so",
|
"__main__/jaxlib/mosaic/gpu/libmosaic_gpu_runtime.so",
|
||||||
"__main__/jaxlib/version.py",
|
"__main__/jaxlib/version.py",
|
||||||
@ -148,7 +148,7 @@ def prepare_wheel_rocm(
|
|||||||
f"__main__/jaxlib/rocm/_hybrid.{pyext}",
|
f"__main__/jaxlib/rocm/_hybrid.{pyext}",
|
||||||
f"__main__/jaxlib/rocm/_rnn.{pyext}",
|
f"__main__/jaxlib/rocm/_rnn.{pyext}",
|
||||||
f"__main__/jaxlib/rocm/_triton.{pyext}",
|
f"__main__/jaxlib/rocm/_triton.{pyext}",
|
||||||
f"__main__/jaxlib/rocm_plugin_extension.{pyext}",
|
f"__main__/jaxlib/rocm/rocm_plugin_extension.{pyext}",
|
||||||
"__main__/jaxlib/version.py",
|
"__main__/jaxlib/version.py",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user