rocm_jax/jaxlib/plugin_support.py

111 lines
3.5 KiB
Python
Raw Normal View History

# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
import importlib
import re
from types import ModuleType
import warnings
from .version import __version__ as jaxlib_version
# Maps each supported plugin name to the top-level module name of its
# separately importable plugin package.
_PLUGIN_MODULE_NAME: dict[str, str] = {
    "cuda": "jax_cuda12_plugin",
    "rocm": "jax_rocm60_plugin",
}
def import_from_plugin(
    plugin_name: str, submodule_name: str, *, check_version: bool = True
) -> ModuleType | None:
    """Load a submodule from one of the known plugins, with a version check.

    Two candidate locations are handed to ``maybe_import_plugin_submodule``,
    in this order: the relative module ``.<plugin_name>`` and the plugin
    module name registered for ``plugin_name`` in ``_PLUGIN_MODULE_NAME``.

    Args:
      plugin_name: Which plugin to load from; must be a key of
        ``_PLUGIN_MODULE_NAME`` ("cuda" or "rocm").
      submodule_name: Name of the submodule to import, e.g. "_triton".
      check_version: When true, the plugin version is compared against the
        jaxlib version; an installed but incompatible plugin produces a
        warning and is not used.

    Returns:
      The imported submodule, or None when the plugin is not installed or its
      version is incompatible.

    Raises:
      ValueError: If ``plugin_name`` is not a known plugin.
    """
    standalone_module = _PLUGIN_MODULE_NAME.get(plugin_name)
    if standalone_module is None:
        raise ValueError(f"Unknown plugin: {plugin_name}")

    candidates = [f".{plugin_name}", standalone_module]
    return maybe_import_plugin_submodule(
        candidates, submodule_name, check_version=check_version
    )
def check_plugin_version(
    plugin_name: str, jaxlib_version: str, plugin_version: str
) -> bool:
    """Report whether a plugin version matches the jaxlib version.

    Only the leading dotted-numeric part of each version string is compared
    (e.g. "0.4.30" out of "0.4.30.dev20240101"); whatever follows it is
    ignored. The numeric parts must be equal component for component, so
    "0.4" and "0.4.0" do not match.

    Args:
      plugin_name: Name of the plugin, used only in the warning message.
      jaxlib_version: Version string of the installed jaxlib.
      plugin_version: Version string reported by the plugin.

    Returns:
      True if the numeric prefixes are equal. Otherwise a RuntimeWarning is
      issued and False is returned.

    Raises:
      ValueError: If either version string does not start with a number.
    """
    # Dotted numeric prefix (0.1.23.456.789) of a PEP440 version. PEP440 allows
    # a number of non-numeric suffixes, which are tolerated by matching only
    # the prefix. An epoch is not supported.
    numeric_prefix = re.compile(r"[0-9]+(?:\.[0-9]+)*")

    def _release_tuple(version: str) -> tuple[int, ...]:
        found = numeric_prefix.match(version)
        if found is None:
            raise ValueError(f"Unable to parse version string '{version}'")
        return tuple(map(int, found.group(0).split(".")))

    # jaxlib is parsed first so that, if both strings are malformed, the error
    # reports the jaxlib version.
    expected = _release_tuple(jaxlib_version)
    actual = _release_tuple(plugin_version)
    if expected == actual:
        return True

    warnings.warn(
        f"JAX plugin {plugin_name} version {plugin_version} is installed, but "
        "it is not compatible with the installed jaxlib version "
        f"{jaxlib_version}, so it will not be used.",
        RuntimeWarning,
    )
    return False
def maybe_import_plugin_submodule(
plugin_module_names: Sequence[str],
submodule_name: str,
*,
check_version: bool = True,
) -> ModuleType | None:
for plugin_module_name in plugin_module_names:
try:
module = importlib.import_module(
f"{plugin_module_name}.{submodule_name}",
package="jaxlib",
)
except ImportError:
continue
else:
if not check_version:
return module
try:
version_module = importlib.import_module(
f"{plugin_module_name}.version",
package="jaxlib",
)
except ImportError:
return module
plugin_version = getattr(version_module, "__version__", "")
if check_plugin_version(
plugin_module_name, jaxlib_version, plugin_version
):
return module
return None