Add version check to jaxlib plugin imports.

For the CUDA and ROCM plugins, we only support exact matches between the plugin and jaxlib version, and bad things can happen if we try to load mismatched versions. This change issues a warning and skips importing a plugin when there is a version mismatch.

There are a handful of other places where plugins are imported throughout the JAX codebase (e.g. in lax_numpy, mosaic_gpu, and in the plugins themselves). In a follow-up, it would be good to add version checking there too, but let's start with just these ones.

PiperOrigin-RevId: 731808733
This commit is contained in:
Dan Foreman-Mackey 2025-02-27 11:51:39 -08:00 committed by jax authors
parent c94ec0eb0d
commit c7ed1bd3a8
9 changed files with 138 additions and 161 deletions

View File

@ -49,6 +49,7 @@ py_library_providing_imports_info(
"hlo_helpers.py",
"init.py",
"lapack.py",
"plugin_support.py",
":version",
":xla_client",
":xla_extension_py",

View File

@ -12,29 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from jaxlib import xla_client
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
try:
_cuda_linalg = importlib.import_module(
f"{cuda_module_name}._linalg", package="jaxlib"
)
except ImportError:
_cuda_linalg = None
else:
break
from .plugin_support import import_from_plugin
for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
try:
_hip_linalg = importlib.import_module(
f"{rocm_module_name}._linalg", package="jaxlib"
)
except ImportError:
_hip_linalg = None
else:
break
_cuda_linalg = import_from_plugin("cuda", "_linalg")
_hip_linalg = import_from_plugin("rocm", "_linalg")
if _cuda_linalg:
for _name, _value in _cuda_linalg.registrations().items():

View File

@ -14,7 +14,6 @@
from __future__ import annotations
from functools import partial
import importlib
import itertools
import jaxlib.mlir.ir as ir
@ -22,17 +21,10 @@ import jaxlib.mlir.ir as ir
from jaxlib import xla_client
from .hlo_helpers import custom_call
from .plugin_support import import_from_plugin
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
try:
_cuda_prng = importlib.import_module(
f"{cuda_module_name}._prng", package="jaxlib"
)
except ImportError:
_cuda_prng = None
else:
break
_cuda_prng = import_from_plugin("cuda", "_prng")
_hip_prng = import_from_plugin("rocm", "_prng")
if _cuda_prng:
for _name, _value in _cuda_prng.registrations().items():
@ -41,16 +33,6 @@ if _cuda_prng:
xla_client.register_custom_call_target(_name, _value, platform="CUDA",
api_version=api_version)
for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
try:
_hip_prng = importlib.import_module(
f"{rocm_module_name}._prng", package="jaxlib"
)
except ImportError:
_hip_prng = None
else:
break
if _hip_prng:
for _name, _value in _hip_prng.registrations().items():
# TODO(danfm): remove after JAX 0.5.1 release

View File

@ -13,7 +13,6 @@
# limitations under the License.
from functools import partial
import importlib
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.stablehlo as hlo
@ -22,14 +21,10 @@ import numpy as np
from jaxlib import xla_client
from .gpu_common_utils import GpuLibNotLinkedError
from .plugin_support import import_from_plugin
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
try:
_cuda_rnn = importlib.import_module(f"{cuda_module_name}._rnn", package="jaxlib")
except ImportError:
_cuda_rnn = None
else:
break
_cuda_rnn = import_from_plugin("cuda", "_rnn")
_hip_rnn = import_from_plugin("rocm", "_rnn")
if _cuda_rnn:
for _name, _value in _cuda_rnn.registrations().items():
@ -38,15 +33,6 @@ if _cuda_rnn:
api_version=api_version)
compute_rnn_workspace_reserve_space_sizes = _cuda_rnn.compute_rnn_workspace_reserve_space_sizes
for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
try:
_hip_rnn = importlib.import_module(f"{rocm_module_name}._rnn", package="jaxlib")
except ImportError:
_hip_rnn = None
else:
break
if _hip_rnn:
for _name, _value in _hip_rnn.registrations().items():
api_version = 1 if _name.endswith("_ffi") else 0

View File

@ -12,35 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from jaxlib import xla_client
try:
from .cuda import _blas as _cublas # pytype: disable=import-error
except ImportError:
for cuda_module_name in ["jax_cuda12_plugin"]:
try:
_cublas = importlib.import_module(f"{cuda_module_name}._blas")
except ImportError:
_cublas = None
else:
break
from .plugin_support import import_from_plugin
_cublas = import_from_plugin("cuda", "_blas")
_cusolver = import_from_plugin("cuda", "_solver")
_cuhybrid = import_from_plugin("cuda", "_hybrid")
_hipblas = import_from_plugin("rocm", "_blas")
_hipsolver = import_from_plugin("rocm", "_solver")
_hiphybrid = import_from_plugin("rocm", "_hybrid")
if _cublas:
for _name, _value in _cublas.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
try:
_cusolver = importlib.import_module(
f"{cuda_module_name}._solver", package="jaxlib"
)
except ImportError:
_cusolver = None
else:
break
if _cusolver:
for _name, _value in _cusolver.registrations().items():
# TODO(danfm): Clean up after all legacy custom calls are ported.
@ -51,47 +38,16 @@ if _cusolver:
xla_client.register_custom_call_target(_name, _value, platform="CUDA",
api_version=api_version)
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
try:
_cuhybrid = importlib.import_module(
f"{cuda_module_name}._hybrid", package="jaxlib"
)
except ImportError:
_cuhybrid = None
else:
break
if _cuhybrid:
for _name, _value in _cuhybrid.registrations().items():
xla_client.register_custom_call_as_batch_partitionable(_name)
xla_client.register_custom_call_target(_name, _value, platform="CUDA",
api_version=1)
try:
from .rocm import _blas as _hipblas # pytype: disable=import-error
except ImportError:
for rocm_module_name in ["jax_rocm60_plugin"]:
try:
_hipblas = importlib.import_module(f"{rocm_module_name}._blas")
except:
_hipblas = None
else:
break
if _hipblas:
for _name, _value in _hipblas.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
try:
_hipsolver = importlib.import_module(
f"{rocm_module_name}._solver", package="jaxlib"
)
except ImportError:
_hipsolver = None
else:
break
if _hipsolver:
for _name, _value in _hipsolver.registrations().items():
# TODO(danfm): Clean up after all legacy custom calls are ported.
@ -102,16 +58,6 @@ if _hipsolver:
xla_client.register_custom_call_target(_name, _value, platform="ROCM",
api_version=api_version)
for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
try:
_hiphybrid = importlib.import_module(
f"{rocm_module_name}._hybrid", package="jaxlib"
)
except ImportError:
_hiphybrid = None
else:
break
if _hiphybrid:
for _name, _value in _hiphybrid.registrations().items():
xla_client.register_custom_call_as_batch_partitionable(_name)

View File

@ -17,7 +17,6 @@ cusparse wrappers for performing sparse matrix computations in JAX
import math
from functools import partial
import importlib
import jaxlib.mlir.ir as ir
@ -27,15 +26,10 @@ from jaxlib import xla_client
from .hlo_helpers import custom_call, mk_result_types_and_shapes
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
try:
_cusparse = importlib.import_module(
f"{cuda_module_name}._sparse", package="jaxlib"
)
except ImportError:
_cusparse = None
else:
break
from .plugin_support import import_from_plugin
_cusparse = import_from_plugin("cuda", "_sparse")
_hipsparse = import_from_plugin("rocm", "_sparse")
if _cusparse:
for _name, _value in _cusparse.registrations().items():
@ -43,16 +37,6 @@ if _cusparse:
xla_client.register_custom_call_target(_name, _value, platform="CUDA",
api_version=api_version)
for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
try:
_hipsparse = importlib.import_module(
f"{rocm_module_name}._sparse", package="jaxlib"
)
except ImportError:
_hipsparse = None
else:
break
if _hipsparse:
for _name, _value in _hipsparse.registrations().items():
api_version = 1 if _name.endswith("_ffi") else 0

View File

@ -11,19 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from jaxlib import xla_client
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
try:
_cuda_triton = importlib.import_module(
f"{cuda_module_name}._triton", package="jaxlib"
)
except ImportError:
_cuda_triton = None
else:
break
from .plugin_support import import_from_plugin
_cuda_triton = import_from_plugin("cuda", "_triton")
_hip_triton = import_from_plugin("rocm", "_triton")
if _cuda_triton:
xla_client.register_custom_call_target(
@ -39,16 +33,6 @@ if _cuda_triton:
get_custom_call = _cuda_triton.get_custom_call
get_serialized_metadata = _cuda_triton.get_serialized_metadata
for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
try:
_hip_triton = importlib.import_module(
f"{rocm_module_name}._triton", package="jaxlib"
)
except ImportError:
_hip_triton = None
else:
break
if _hip_triton:
xla_client.register_custom_call_target(
"triton_kernel_call", _hip_triton.get_custom_call(),

110
jaxlib/plugin_support.py Normal file
View File

@ -0,0 +1,110 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
import importlib
import re
from types import ModuleType
import warnings
from .version import __version__ as jaxlib_version
# Maps the short plugin name accepted by `import_from_plugin` ("cuda" or
# "rocm") to the top-level module name of the separately installed plugin
# package that is tried when importing that plugin's submodules.
_PLUGIN_MODULE_NAME = {
"cuda": "jax_cuda12_plugin",
"rocm": "jax_rocm60_plugin",
}
def import_from_plugin(
    plugin_name: str, submodule_name: str, *, check_version: bool = True
) -> ModuleType | None:
  """Load a submodule of one of the known accelerator plugins.

  Two locations are tried, in order: the relative module `.<plugin_name>`
  (resolved against the `jaxlib` package by `maybe_import_plugin_submodule`)
  and then the standalone plugin package registered for that name in
  `_PLUGIN_MODULE_NAME`.

  Args:
    plugin_name: Which plugin to load from; must be a key of
      `_PLUGIN_MODULE_NAME` ("cuda" or "rocm").
    submodule_name: The submodule to load, e.g. "_triton".
    check_version: If true, a plugin whose version does not match the jaxlib
      version is skipped with a warning instead of being returned.

  Returns:
    The imported submodule, or None when the plugin is not installed or its
    version is incompatible with jaxlib.

  Raises:
    ValueError: If `plugin_name` is not a known plugin.
  """
  external_module_name = _PLUGIN_MODULE_NAME.get(plugin_name)
  if external_module_name is None:
    raise ValueError(f"Unknown plugin: {plugin_name}")

  candidates = [f".{plugin_name}", external_module_name]
  return maybe_import_plugin_submodule(
      candidates, submodule_name, check_version=check_version
  )
def check_plugin_version(
    plugin_name: str, jaxlib_version: str, plugin_version: str
) -> bool:
  """Report whether a plugin's version matches the jaxlib version.

  Only the leading dotted-numeric prefix of each version string is compared
  (e.g. the "0.1.23" of "0.1.23.dev4"); whatever follows that prefix, such as
  the non-numeric suffixes PEP 440 permits, is ignored. Epochs are not
  supported.

  Args:
    plugin_name: Name of the plugin, used only in the warning message.
    jaxlib_version: Version string of the installed jaxlib.
    plugin_version: Version string of the installed plugin.

  Returns:
    True if the numeric prefixes are equal. Otherwise a `RuntimeWarning` is
    issued and False is returned.

  Raises:
    ValueError: If either version string does not begin with a number.
  """

  def _release_prefix(version: str) -> tuple[int, ...]:
    # Match is anchored at the start of the string, so a version that does
    # not begin with a digit is rejected.
    found = re.match(r"[0-9]+(?:\.[0-9]+)*", version)
    if found is None:
      raise ValueError(f"Unable to parse version string '{version}'")
    return tuple(map(int, found.group(0).split(".")))

  if _release_prefix(jaxlib_version) == _release_prefix(plugin_version):
    return True

  warnings.warn(
      f"JAX plugin {plugin_name} version {plugin_version} is installed, but "
      "it is not compatible with the installed jaxlib version "
      f"{jaxlib_version}, so it will not be used.",
      RuntimeWarning,
  )
  return False
def maybe_import_plugin_submodule(
plugin_module_names: Sequence[str],
submodule_name: str,
*,
check_version: bool = True,
) -> ModuleType | None:
for plugin_module_name in plugin_module_names:
try:
module = importlib.import_module(
f"{plugin_module_name}.{submodule_name}",
package="jaxlib",
)
except ImportError:
continue
else:
if not check_version:
return module
try:
version_module = importlib.import_module(
f"{plugin_module_name}.version",
package="jaxlib",
)
except ImportError:
return module
plugin_version = getattr(version_module, "__version__", "")
if check_plugin_version(
plugin_module_name, jaxlib_version, plugin_version
):
return module
return None

View File

@ -195,6 +195,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu):
"__main__/jaxlib/gpu_common_utils.py",
"__main__/jaxlib/gpu_solver.py",
"__main__/jaxlib/gpu_sparse.py",
"__main__/jaxlib/plugin_support.py",
"__main__/jaxlib/version.py",
"__main__/jaxlib/xla_client.py",
f"xla/xla/python/xla_extension.{pyext}",