# Copyright 2024 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import functools import importlib import logging import os import pathlib from jax._src.lib import xla_client import jax._src.xla_bridge as xb # rocm_plugin_extension locates inside jaxlib. `jaxlib` is for testing without # preinstalled jax rocm plugin packages. for pkg_name in ['jax_rocm60_plugin', 'jaxlib.cuda']: try: rocm_plugin_extension = importlib.import_module( f'{pkg_name}.rocm_plugin_extension' ) except ImportError: rocm_plugin_extension = None else: break logger = logging.getLogger(__name__) def _get_library_path(): base_path = pathlib.Path(__file__).resolve().parent installed_path = ( base_path / 'xla_rocm_plugin.so' ) if installed_path.exists(): return installed_path local_path = ( base_path / 'pjrt_c_api_gpu_plugin.so' ) if not local_path.exists(): runfiles_dir = os.getenv('RUNFILES_DIR', None) if runfiles_dir: local_path = pathlib.Path( os.path.join(runfiles_dir, 'xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so') ) if local_path.exists(): logger.debug( 'Native library %s does not exist. This most likely indicates an issue' ' with how %s was built or installed. Fallback to local test' ' library %s', installed_path, __package__, local_path, ) return local_path logger.debug( 'WARNING: Native library %s and local test library path %s do not' ' exist. This most likely indicates an issue with how %s was built or' ' installed or missing src files.', installed_path, local_path, __package__, ) return None def initialize(): path = _get_library_path() if path is None: return options = xla_client.generate_pjrt_gpu_plugin_options() options["platform_name"] = "ROCM" c_api = xb.register_plugin( 'rocm', priority=500, library_path=str(path), options=options ) if rocm_plugin_extension: xla_client.register_custom_call_handler( "ROCM", functools.partial( rocm_plugin_extension.register_custom_call_target, c_api ), ) for _name, _value in rocm_plugin_extension.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="ROCM") xla_client.register_custom_type_id_handler( "ROCM", functools.partial( rocm_plugin_extension.register_custom_type_id, c_api ), ) else: logger.warning('rocm_plugin_extension is not found.')