mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00

Move the common jaxlib/gpu_plugin_extension into jaxlib/gpu/ Cleanup only, no functional changes intended. PiperOrigin-RevId: 738183402
111 lines
3.2 KiB
Python
111 lines
3.2 KiB
Python
# Copyright 2023 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import functools
|
|
import importlib
|
|
import logging
|
|
import os
|
|
import pathlib
|
|
|
|
from jax._src.lib import triton
|
|
from jax._src.lib import xla_client
|
|
import jax._src.xla_bridge as xb
|
|
|
|
logger = logging.getLogger(__name__)

# Locate the optional cuda_plugin_extension module. The standalone plugin
# package is tried first; `jaxlib.cuda` is a fallback for testing without a
# preinstalled jax CUDA plugin package. The first candidate that imports
# successfully wins; if none does, cuda_plugin_extension is left as None.
cuda_plugin_extension = None
for pkg_name in ['jax_cuda12_plugin', 'jaxlib.cuda']:
  try:
    cuda_plugin_extension = importlib.import_module(
        f'{pkg_name}.cuda_plugin_extension'
    )
  except ImportError:
    # Not fatal: fall through to the next candidate. The failure is logged
    # with its traceback because an ImportError here can also mean the module
    # exists but failed to load (e.g. an unresolved native dependency), which
    # would otherwise be indistinguishable from the module being absent.
    logger.debug(
        'Could not import %s.cuda_plugin_extension', pkg_name, exc_info=True
    )
    cuda_plugin_extension = None
  else:
    break
|
|
|
|
|
|
def _get_library_path():
  """Locate the CUDA PJRT plugin shared library.

  Candidates are checked in this order:

  1. ``xla_cuda_plugin.so`` in the directory of this file, with ``__file__``
     resolved through symlinks (the installed-package layout).
  2. ``pjrt_c_api_gpu_plugin.so`` in the directory of ``__file__`` as given
     (not resolved). If that is missing and ``RUNFILES_DIR`` is set to a
     non-empty value, ``$RUNFILES_DIR/xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so``
     is checked instead (local/test build fallback).

  Returns:
    A ``pathlib.Path`` to the first library that exists, or ``None`` if no
    candidate exists. The fallback hit and the not-found case are both logged
    at debug level.
  """
  installed_path = (
      pathlib.Path(__file__).resolve().parent / 'xla_cuda_plugin.so'
  )
  if installed_path.exists():
    return installed_path

  # Unlike installed_path, the fallback is relative to the *unresolved*
  # __file__ directory, matching the original lookup.
  local_path = pathlib.Path(__file__).parent / 'pjrt_c_api_gpu_plugin.so'
  # os.path.exists (rather than Path.exists) treats any OSError, such as a
  # permission error, as "missing" instead of raising.
  if not os.path.exists(local_path):
    runfiles_dir = os.getenv('RUNFILES_DIR')
    if runfiles_dir:
      local_path = (
          pathlib.Path(runfiles_dir)
          / 'xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so'
      )

  if os.path.exists(local_path):
    logger.debug(
        'Native library %s does not exist. This most likely indicates an issue'
        ' with how %s was built or installed. Fallback to local test'
        ' library %s',
        installed_path,
        __package__,
        local_path,
    )
    return local_path

  logger.debug(
      'WARNING: Native library %s and local test library path %s do not'
      ' exist. This most likely indicates an issue with how %s was built or'
      ' installed or missing src files.',
      installed_path,
      local_path,
      __package__,
  )
  return None
|
|
|
|
|
|
def initialize():
  """Register the CUDA PJRT plugin, plus its extension hooks, with JAX.

  If `_get_library_path()` finds no plugin library, this returns without
  doing anything. Otherwise the library is registered with `xla_bridge` as
  the 'cuda' plugin (priority 500), using the options produced by
  `xla_client.generate_pjrt_gpu_plugin_options()`.

  When `cuda_plugin_extension` was imported, it then registers the
  extension's custom-call handler, each entry from its `registrations()`
  mapping as a custom-call target, its custom-type-id handler, and its Triton
  compilation handler, all under the "CUDA" platform. Each handler is bound
  to the C API object returned by the plugin registration. When the extension
  is unavailable, a warning is logged after the plugin has been registered.
  """
  library_path = _get_library_path()
  if library_path is None:
    return

  c_api = xb.register_plugin(
      'cuda',
      priority=500,
      library_path=str(library_path),
      options=xla_client.generate_pjrt_gpu_plugin_options(),
  )

  ext = cuda_plugin_extension
  if not ext:
    logger.warning('cuda_plugin_extension is not found.')
    return

  xla_client.register_custom_call_handler(
      "CUDA", functools.partial(ext.register_custom_call_target, c_api)
  )
  # Presumably (name -> target) pairs; they are forwarded to xla_client as-is.
  for target_name, target in ext.registrations().items():
    xla_client.register_custom_call_target(target_name, target, platform="CUDA")
  xla_client.register_custom_type_id_handler(
      "CUDA", functools.partial(ext.register_custom_type_id, c_api)
  )
  triton.register_compilation_handler(
      "CUDA", functools.partial(ext.compile_triton_to_asm, c_api)
  )
|