mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Replace pjrt_c_api_gpu_plugin.so
symlink with XLA dependency.
The runfiles of the original targets were lost when the symlinked files were used. This change is needed for future Hermetic CUDA implementation. Bazel will download CUDA distributives in cache, and CUDA executables and libraries will be added in the runfiles of the targets. When pjrt_c_api_gpu_plugin.so is simlinked, the content of the runfiles is lost. With proper XLA target dependency the runfiles are preserved. PiperOrigin-RevId: 662197057
This commit is contained in:
parent
ee31e95ecd
commit
e5eaff84bd
@ -14,7 +14,6 @@
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
load("//jaxlib:symlink_files.bzl", "symlink_files")
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"if_windows",
|
||||
@ -35,22 +34,15 @@ exports_files([
|
||||
"setup.py",
|
||||
])
|
||||
|
||||
symlink_files(
|
||||
name = "pjrt_c_api_gpu_plugin",
|
||||
srcs = if_windows(
|
||||
["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"],
|
||||
["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"],
|
||||
),
|
||||
dst = ".",
|
||||
flatten = True,
|
||||
)
|
||||
|
||||
py_library_providing_imports_info(
|
||||
name = "cuda_plugin",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
],
|
||||
data = [":pjrt_c_api_gpu_plugin"],
|
||||
data = if_windows(
|
||||
["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"],
|
||||
["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"],
|
||||
),
|
||||
lib_rule = pytype_library,
|
||||
)
|
||||
|
||||
|
@ -48,6 +48,13 @@ def _get_library_path():
|
||||
local_path = os.path.join(
|
||||
os.path.dirname(__file__), 'pjrt_c_api_gpu_plugin.so'
|
||||
)
|
||||
if not os.path.exists(local_path):
|
||||
runfiles_dir = os.getenv('RUNFILES_DIR', None)
|
||||
if runfiles_dir:
|
||||
local_path = os.path.join(
|
||||
runfiles_dir, 'xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so'
|
||||
)
|
||||
|
||||
if os.path.exists(local_path):
|
||||
logger.debug(
|
||||
'Native library %s does not exist. This most likely indicates an issue'
|
||||
|
@ -14,7 +14,6 @@
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
load("//jaxlib:symlink_files.bzl", "symlink_files")
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"if_windows",
|
||||
@ -35,21 +34,14 @@ exports_files([
|
||||
"setup.py",
|
||||
])
|
||||
|
||||
symlink_files(
|
||||
name = "pjrt_c_api_gpu_plugin",
|
||||
srcs = if_windows(
|
||||
["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"],
|
||||
["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"],
|
||||
),
|
||||
dst = ".",
|
||||
flatten = True,
|
||||
)
|
||||
|
||||
py_library_providing_imports_info(
|
||||
name = "rocm_plugin",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
],
|
||||
data = [":pjrt_c_api_gpu_plugin"],
|
||||
data = if_windows(
|
||||
["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"],
|
||||
["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"],
|
||||
),
|
||||
lib_rule = pytype_library,
|
||||
)
|
||||
|
@ -15,6 +15,7 @@
|
||||
import functools
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import platform
|
||||
|
||||
@ -47,6 +48,13 @@ def _get_library_path():
|
||||
local_path = (
|
||||
base_path / 'pjrt_c_api_gpu_plugin.so'
|
||||
)
|
||||
if not local_path.exists():
|
||||
runfiles_dir = os.getenv('RUNFILES_DIR', None)
|
||||
if runfiles_dir:
|
||||
local_path = pathlib.Path(
|
||||
os.path.join(runfiles_dir, 'xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so')
|
||||
)
|
||||
|
||||
if local_path.exists():
|
||||
logger.debug(
|
||||
'Native library %s does not exist. This most likely indicates an issue'
|
||||
|
Loading…
x
Reference in New Issue
Block a user