rocm_jax/jax_plugins/rocm/__init__.py
Peter Hawkins 3f91b4b43a Move jaxlib/{cuda,rocm}_plugin_extension into jaxlib/{cuda/rocm}/
Move the common jaxlib/gpu_plugin_extension into jaxlib/gpu/

Cleanup only, no functional changes intended.

PiperOrigin-RevId: 738183402
2025-03-18 16:29:37 -07:00

105 lines
3.0 KiB
Python

# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import importlib
import logging
import os
import pathlib
from jax._src.lib import xla_client
import jax._src.xla_bridge as xb
# rocm_plugin_extension locates inside jaxlib. `jaxlib` is for testing without
# preinstalled jax rocm plugin packages.
for pkg_name in ['jax_rocm60_plugin', 'jaxlib.cuda']:
try:
rocm_plugin_extension = importlib.import_module(
f'{pkg_name}.rocm_plugin_extension'
)
except ImportError:
rocm_plugin_extension = None
else:
break
logger = logging.getLogger(__name__)
def _get_library_path():
base_path = pathlib.Path(__file__).resolve().parent
installed_path = (
base_path / 'xla_rocm_plugin.so'
)
if installed_path.exists():
return installed_path
local_path = (
base_path / 'pjrt_c_api_gpu_plugin.so'
)
if not local_path.exists():
runfiles_dir = os.getenv('RUNFILES_DIR', None)
if runfiles_dir:
local_path = pathlib.Path(
os.path.join(runfiles_dir, 'xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so')
)
if local_path.exists():
logger.debug(
'Native library %s does not exist. This most likely indicates an issue'
' with how %s was built or installed. Fallback to local test'
' library %s',
installed_path,
__package__,
local_path,
)
return local_path
logger.debug(
'WARNING: Native library %s and local test library path %s do not'
' exist. This most likely indicates an issue with how %s was built or'
' installed or missing src files.',
installed_path,
local_path,
__package__,
)
return None
def initialize():
path = _get_library_path()
if path is None:
return
options = xla_client.generate_pjrt_gpu_plugin_options()
options["platform_name"] = "ROCM"
c_api = xb.register_plugin(
'rocm', priority=500, library_path=str(path), options=options
)
if rocm_plugin_extension:
xla_client.register_custom_call_handler(
"ROCM",
functools.partial(
rocm_plugin_extension.register_custom_call_target, c_api
),
)
for _name, _value in rocm_plugin_extension.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
xla_client.register_custom_type_id_handler(
"ROCM",
functools.partial(
rocm_plugin_extension.register_custom_type_id, c_api
),
)
else:
logger.warning('rocm_plugin_extension is not found.')