rocm_jax/jax_plugins/cuda/__init__.py
Peter Hawkins 3f91b4b43a Move jaxlib/{cuda,rocm}_plugin_extension into jaxlib/{cuda/rocm}/
Move the common jaxlib/gpu_plugin_extension into jaxlib/gpu/

Cleanup only, no functional changes intended.

PiperOrigin-RevId: 738183402
2025-03-18 16:29:37 -07:00

111 lines
3.2 KiB
Python

# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import importlib
import logging
import os
import pathlib
from jax._src.lib import triton
from jax._src.lib import xla_client
import jax._src.xla_bridge as xb
# Locate the native `cuda_plugin_extension` module. It is looked up first in
# the preinstalled `jax_cuda12_plugin` package and then in `jaxlib.cuda`; the
# `jaxlib.cuda` fallback exists for testing without a preinstalled jax CUDA
# plugin package. If neither import succeeds, the name is bound to None.
for pkg_name in ('jax_cuda12_plugin', 'jaxlib.cuda'):
  try:
    cuda_plugin_extension = importlib.import_module(
        f'{pkg_name}.cuda_plugin_extension'
    )
  except ImportError:
    # Not available from this package; try the next candidate.
    continue
  break
else:
  # Loop finished without `break`: no candidate package provided it.
  cuda_plugin_extension = None

logger = logging.getLogger(__name__)
def _get_library_path():
  """Returns the path of the CUDA PJRT plugin shared library, or None.

  Candidates are tried in order:

  1. ``xla_cuda_plugin.so`` in the directory containing this file (after
     resolving symlinks). Returned as a ``pathlib.Path``.
  2. ``pjrt_c_api_gpu_plugin.so`` in the directory containing this file.
  3. Only if (2) is absent and ``RUNFILES_DIR`` is set to a non-empty value:
     ``<RUNFILES_DIR>/xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so``.

  Candidates (2) and (3) are returned as ``str``, so callers needing a string
  should wrap the result in ``str()``. A debug message is logged whenever the
  installed library (1) is missing, whether or not a fallback is found.
  """
  plugin_dir = pathlib.Path(__file__).resolve().parent
  installed_lib = plugin_dir / 'xla_cuda_plugin.so'
  if installed_lib.exists():
    return installed_lib

  # Fallback used by local test builds: the generic GPU PJRT plugin, either
  # next to this file or inside a runfiles tree.
  test_lib = os.path.join(os.path.dirname(__file__), 'pjrt_c_api_gpu_plugin.so')
  if not os.path.exists(test_lib):
    runfiles = os.getenv('RUNFILES_DIR', None)
    if runfiles:
      test_lib = os.path.join(
          runfiles, 'xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so'
      )

  if not os.path.exists(test_lib):
    # NOTE(review): emitted at debug level despite the "WARNING" prefix in the
    # text; confirm the low level is intentional.
    logger.debug(
        'WARNING: Native library %s and local test library path %s do not'
        ' exist. This most likely indicates an issue with how %s was built or'
        ' installed or missing src files.',
        installed_lib,
        test_lib,
        __package__,
    )
    return None

  logger.debug(
      'Native library %s does not exist. This most likely indicates an issue'
      ' with how %s was built or installed. Fallback to local test'
      ' library %s',
      installed_lib,
      __package__,
      test_lib,
  )
  return test_lib
def initialize():
  """Registers the CUDA PJRT plugin with JAX's XLA bridge.

  This is a no-op when ``_get_library_path()`` finds no shared library.
  Otherwise the library is registered under the name ``'cuda'`` with
  priority 500 and the options produced by
  ``xla_client.generate_pjrt_gpu_plugin_options()``.

  When ``cuda_plugin_extension`` was imported, its hooks are bound (via
  ``functools.partial``) to the C API handle returned by the registration.
  They are then installed for the ``"CUDA"`` platform:

  - the custom-call handler;
  - every custom-call target listed by ``registrations()``;
  - the custom type-id handler;
  - the Triton compilation handler.

  Without the extension, only a warning is logged.
  """
  lib_path = _get_library_path()
  if lib_path is None:
    return

  plugin_options = xla_client.generate_pjrt_gpu_plugin_options()
  c_api = xb.register_plugin(
      'cuda', priority=500, library_path=str(lib_path), options=plugin_options
  )

  ext = cuda_plugin_extension
  if ext is None:
    logger.warning('cuda_plugin_extension is not found.')
    return

  platform = "CUDA"
  # Each extension hook takes the plugin's C API handle as its first
  # argument, so it is pre-bound here.
  xla_client.register_custom_call_handler(
      platform, functools.partial(ext.register_custom_call_target, c_api)
  )
  for target_name, target in ext.registrations().items():
    xla_client.register_custom_call_target(
        target_name, target, platform=platform
    )
  xla_client.register_custom_type_id_handler(
      platform, functools.partial(ext.register_custom_type_id, c_api)
  )
  triton.register_compilation_handler(
      platform, functools.partial(ext.compile_triton_to_asm, c_api)
  )