Replace pjrt_c_api_gpu_plugin.so symlink with XLA dependency.

The runfiles of the original targets were lost when the symlinked files were used.

This change is needed for future Hermetic CUDA implementation. Bazel will download CUDA distributives in cache, and CUDA executables and libraries will be added in the runfiles of the targets. When pjrt_c_api_gpu_plugin.so is simlinked, the content of the runfiles is lost. With proper XLA target dependency the runfiles are preserved.

PiperOrigin-RevId: 662197057
This commit is contained in:
jax authors 2024-08-12 13:00:29 -07:00 committed by jax authors
parent ee31e95ecd
commit e5eaff84bd
4 changed files with 23 additions and 24 deletions

View File

@ -14,7 +14,6 @@
licenses(["notice"])
load("//jaxlib:symlink_files.bzl", "symlink_files")
load(
"//jaxlib:jax.bzl",
"if_windows",
@ -35,22 +34,15 @@ exports_files([
"setup.py",
])
symlink_files(
name = "pjrt_c_api_gpu_plugin",
srcs = if_windows(
["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"],
["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"],
),
dst = ".",
flatten = True,
)
py_library_providing_imports_info(
name = "cuda_plugin",
srcs = [
"__init__.py",
],
data = [":pjrt_c_api_gpu_plugin"],
data = if_windows(
["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"],
["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"],
),
lib_rule = pytype_library,
)

View File

@ -48,6 +48,13 @@ def _get_library_path():
local_path = os.path.join(
os.path.dirname(__file__), 'pjrt_c_api_gpu_plugin.so'
)
if not os.path.exists(local_path):
runfiles_dir = os.getenv('RUNFILES_DIR', None)
if runfiles_dir:
local_path = os.path.join(
runfiles_dir, 'xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so'
)
if os.path.exists(local_path):
logger.debug(
'Native library %s does not exist. This most likely indicates an issue'

View File

@ -14,7 +14,6 @@
licenses(["notice"])
load("//jaxlib:symlink_files.bzl", "symlink_files")
load(
"//jaxlib:jax.bzl",
"if_windows",
@ -35,21 +34,14 @@ exports_files([
"setup.py",
])
symlink_files(
name = "pjrt_c_api_gpu_plugin",
srcs = if_windows(
["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"],
["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"],
),
dst = ".",
flatten = True,
)
py_library_providing_imports_info(
name = "rocm_plugin",
srcs = [
"__init__.py",
],
data = [":pjrt_c_api_gpu_plugin"],
data = if_windows(
["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"],
["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"],
),
lib_rule = pytype_library,
)

View File

@ -15,6 +15,7 @@
import functools
import importlib
import logging
import os
import pathlib
import platform
@ -47,6 +48,13 @@ def _get_library_path():
local_path = (
base_path / 'pjrt_c_api_gpu_plugin.so'
)
if not local_path.exists():
runfiles_dir = os.getenv('RUNFILES_DIR', None)
if runfiles_dir:
local_path = pathlib.Path(
os.path.join(runfiles_dir, 'xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so')
)
if local_path.exists():
logger.debug(
'Native library %s does not exist. This most likely indicates an issue'