mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
111 lines
3.5 KiB
Python
111 lines
3.5 KiB
Python
![]() |
# Copyright 2025 The JAX Authors.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
from collections.abc import Sequence
|
||
|
import importlib
|
||
|
import re
|
||
|
from types import ModuleType
|
||
|
import warnings
|
||
|
|
||
|
from .version import __version__ as jaxlib_version
|
||
|
|
||
|
|
||
|
# Supported plugin names mapped to the top-level module name that is tried
# when importing a submodule of that plugin.
_PLUGIN_MODULE_NAME: dict[str, str] = dict(
    cuda="jax_cuda12_plugin",
    rocm="jax_rocm60_plugin",
)
|
||
|
|
||
|
|
||
|
def import_from_plugin(
    plugin_name: str, submodule_name: str, *, check_version: bool = True
) -> ModuleType | None:
  """Load `submodule_name` from one of the known accelerator plugins.

  Two candidate locations are handed to `maybe_import_plugin_submodule`, in
  this order: the jaxlib-relative module `.{plugin_name}` and then the module
  registered for `plugin_name` in `_PLUGIN_MODULE_NAME`.

  Args:
    plugin_name: Which plugin to load from. Must be a key of
      `_PLUGIN_MODULE_NAME` ("cuda" or "rocm").
    submodule_name: The submodule to import from the plugin, e.g. "_triton".
    check_version: When True, the plugin's version is compared against the
      jaxlib version; an incompatible plugin triggers a warning and is not
      used.

  Returns:
    The imported submodule, or None when the plugin is not installed or its
    version is incompatible.

  Raises:
    ValueError: If `plugin_name` is not a known plugin.
  """
  external_module_name = _PLUGIN_MODULE_NAME.get(plugin_name)
  if external_module_name is None:
    raise ValueError(f"Unknown plugin: {plugin_name}")

  # The relative (jaxlib-bundled) location is tried before the external one.
  candidates = [f".{plugin_name}", external_module_name]
  return maybe_import_plugin_submodule(
      candidates, submodule_name, check_version=check_version
  )
|
||
|
|
||
|
|
||
|
def check_plugin_version(
    plugin_name: str, jaxlib_version: str, plugin_version: str
) -> bool:
  """Report whether a plugin's version matches the jaxlib version.

  Only the leading dotted numeric part of each version string is compared
  (e.g. "0.1.23" out of "0.1.23.dev4"); any suffix after it is ignored. The
  two numeric prefixes must be exactly equal, component for component.

  Args:
    plugin_name: Name of the plugin; used only in the warning text.
    jaxlib_version: Version string of the installed jaxlib.
    plugin_version: Version string reported by the plugin.

  Returns:
    True when the numeric prefixes are equal. Otherwise a RuntimeWarning is
    issued and False is returned.

  Raises:
    ValueError: If either version string does not start with a digit.
  """
  # Dotted numeric prefix of a PEP 440 version such as 0.1.23.456.789.
  # Non-numeric suffixes are tolerated (and ignored); epochs are not
  # supported.
  numeric_prefix = re.compile(r"[0-9]+(?:\.[0-9]+)*")

  def _release_tuple(version: str) -> tuple[int, ...]:
    found = numeric_prefix.match(version)
    if found is None:
      raise ValueError(f"Unable to parse version string '{version}'")
    return tuple(map(int, found.group(0).split(".")))

  expected = _release_tuple(jaxlib_version)
  actual = _release_tuple(plugin_version)
  if expected == actual:
    return True

  message = (
      f"JAX plugin {plugin_name} version {plugin_version} is installed, but "
      "it is not compatible with the installed jaxlib version "
      f"{jaxlib_version}, so it will not be used."
  )
  warnings.warn(message, RuntimeWarning)
  return False
|
||
|
|
||
|
|
||
|
def maybe_import_plugin_submodule(
|
||
|
plugin_module_names: Sequence[str],
|
||
|
submodule_name: str,
|
||
|
*,
|
||
|
check_version: bool = True,
|
||
|
) -> ModuleType | None:
|
||
|
for plugin_module_name in plugin_module_names:
|
||
|
try:
|
||
|
module = importlib.import_module(
|
||
|
f"{plugin_module_name}.{submodule_name}",
|
||
|
package="jaxlib",
|
||
|
)
|
||
|
except ImportError:
|
||
|
continue
|
||
|
else:
|
||
|
if not check_version:
|
||
|
return module
|
||
|
try:
|
||
|
version_module = importlib.import_module(
|
||
|
f"{plugin_module_name}.version",
|
||
|
package="jaxlib",
|
||
|
)
|
||
|
except ImportError:
|
||
|
return module
|
||
|
plugin_version = getattr(version_module, "__version__", "")
|
||
|
if check_plugin_version(
|
||
|
plugin_module_name, jaxlib_version, plugin_version
|
||
|
):
|
||
|
return module
|
||
|
return None
|