mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Add version check to jaxlib plugin imports.
For the CUDA and ROCM plugins, we only support exact matches between the plugin and jaxlib version, and bad things can happen if we try and load mismatched versions. This change issues a warning and skips importing a plugin when there is a version mismatch. There are a handful of other places where plugins are imported throughout the JAX codebase (e.g. in lax_numpy, mosaic_gpu, and in the plugins themselves). In a follow up it would be good to add version checking there too, but let's start with just these ones. PiperOrigin-RevId: 731808733
This commit is contained in:
parent
c94ec0eb0d
commit
c7ed1bd3a8
@ -49,6 +49,7 @@ py_library_providing_imports_info(
|
||||
"hlo_helpers.py",
|
||||
"init.py",
|
||||
"lapack.py",
|
||||
"plugin_support.py",
|
||||
":version",
|
||||
":xla_client",
|
||||
":xla_extension_py",
|
||||
|
@ -12,29 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
|
||||
from jaxlib import xla_client
|
||||
|
||||
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
|
||||
try:
|
||||
_cuda_linalg = importlib.import_module(
|
||||
f"{cuda_module_name}._linalg", package="jaxlib"
|
||||
)
|
||||
except ImportError:
|
||||
_cuda_linalg = None
|
||||
else:
|
||||
break
|
||||
from .plugin_support import import_from_plugin
|
||||
|
||||
for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
|
||||
try:
|
||||
_hip_linalg = importlib.import_module(
|
||||
f"{rocm_module_name}._linalg", package="jaxlib"
|
||||
)
|
||||
except ImportError:
|
||||
_hip_linalg = None
|
||||
else:
|
||||
break
|
||||
_cuda_linalg = import_from_plugin("cuda", "_linalg")
|
||||
_hip_linalg = import_from_plugin("rocm", "_linalg")
|
||||
|
||||
if _cuda_linalg:
|
||||
for _name, _value in _cuda_linalg.registrations().items():
|
||||
|
@ -14,7 +14,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
from functools import partial
|
||||
import importlib
|
||||
import itertools
|
||||
|
||||
import jaxlib.mlir.ir as ir
|
||||
@ -22,17 +21,10 @@ import jaxlib.mlir.ir as ir
|
||||
from jaxlib import xla_client
|
||||
|
||||
from .hlo_helpers import custom_call
|
||||
from .plugin_support import import_from_plugin
|
||||
|
||||
|
||||
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
|
||||
try:
|
||||
_cuda_prng = importlib.import_module(
|
||||
f"{cuda_module_name}._prng", package="jaxlib"
|
||||
)
|
||||
except ImportError:
|
||||
_cuda_prng = None
|
||||
else:
|
||||
break
|
||||
_cuda_prng = import_from_plugin("cuda", "_prng")
|
||||
_hip_prng = import_from_plugin("rocm", "_prng")
|
||||
|
||||
if _cuda_prng:
|
||||
for _name, _value in _cuda_prng.registrations().items():
|
||||
@ -41,16 +33,6 @@ if _cuda_prng:
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA",
|
||||
api_version=api_version)
|
||||
|
||||
for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
|
||||
try:
|
||||
_hip_prng = importlib.import_module(
|
||||
f"{rocm_module_name}._prng", package="jaxlib"
|
||||
)
|
||||
except ImportError:
|
||||
_hip_prng = None
|
||||
else:
|
||||
break
|
||||
|
||||
if _hip_prng:
|
||||
for _name, _value in _hip_prng.registrations().items():
|
||||
# TODO(danfm): remove after JAX 0.5.1 release
|
||||
|
@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
import importlib
|
||||
|
||||
import jaxlib.mlir.ir as ir
|
||||
import jaxlib.mlir.dialects.stablehlo as hlo
|
||||
@ -22,14 +21,10 @@ import numpy as np
|
||||
|
||||
from jaxlib import xla_client
|
||||
from .gpu_common_utils import GpuLibNotLinkedError
|
||||
from .plugin_support import import_from_plugin
|
||||
|
||||
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
|
||||
try:
|
||||
_cuda_rnn = importlib.import_module(f"{cuda_module_name}._rnn", package="jaxlib")
|
||||
except ImportError:
|
||||
_cuda_rnn = None
|
||||
else:
|
||||
break
|
||||
_cuda_rnn = import_from_plugin("cuda", "_rnn")
|
||||
_hip_rnn = import_from_plugin("rocm", "_rnn")
|
||||
|
||||
if _cuda_rnn:
|
||||
for _name, _value in _cuda_rnn.registrations().items():
|
||||
@ -38,15 +33,6 @@ if _cuda_rnn:
|
||||
api_version=api_version)
|
||||
compute_rnn_workspace_reserve_space_sizes = _cuda_rnn.compute_rnn_workspace_reserve_space_sizes
|
||||
|
||||
|
||||
for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
|
||||
try:
|
||||
_hip_rnn = importlib.import_module(f"{rocm_module_name}._rnn", package="jaxlib")
|
||||
except ImportError:
|
||||
_hip_rnn = None
|
||||
else:
|
||||
break
|
||||
|
||||
if _hip_rnn:
|
||||
for _name, _value in _hip_rnn.registrations().items():
|
||||
api_version = 1 if _name.endswith("_ffi") else 0
|
||||
|
@ -12,35 +12,22 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
|
||||
from jaxlib import xla_client
|
||||
|
||||
try:
|
||||
from .cuda import _blas as _cublas # pytype: disable=import-error
|
||||
except ImportError:
|
||||
for cuda_module_name in ["jax_cuda12_plugin"]:
|
||||
try:
|
||||
_cublas = importlib.import_module(f"{cuda_module_name}._blas")
|
||||
except ImportError:
|
||||
_cublas = None
|
||||
else:
|
||||
break
|
||||
from .plugin_support import import_from_plugin
|
||||
|
||||
_cublas = import_from_plugin("cuda", "_blas")
|
||||
_cusolver = import_from_plugin("cuda", "_solver")
|
||||
_cuhybrid = import_from_plugin("cuda", "_hybrid")
|
||||
|
||||
_hipblas = import_from_plugin("rocm", "_blas")
|
||||
_hipsolver = import_from_plugin("rocm", "_solver")
|
||||
_hiphybrid = import_from_plugin("rocm", "_hybrid")
|
||||
|
||||
if _cublas:
|
||||
for _name, _value in _cublas.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
|
||||
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
|
||||
try:
|
||||
_cusolver = importlib.import_module(
|
||||
f"{cuda_module_name}._solver", package="jaxlib"
|
||||
)
|
||||
except ImportError:
|
||||
_cusolver = None
|
||||
else:
|
||||
break
|
||||
|
||||
if _cusolver:
|
||||
for _name, _value in _cusolver.registrations().items():
|
||||
# TODO(danfm): Clean up after all legacy custom calls are ported.
|
||||
@ -51,47 +38,16 @@ if _cusolver:
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA",
|
||||
api_version=api_version)
|
||||
|
||||
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
|
||||
try:
|
||||
_cuhybrid = importlib.import_module(
|
||||
f"{cuda_module_name}._hybrid", package="jaxlib"
|
||||
)
|
||||
except ImportError:
|
||||
_cuhybrid = None
|
||||
else:
|
||||
break
|
||||
|
||||
if _cuhybrid:
|
||||
for _name, _value in _cuhybrid.registrations().items():
|
||||
xla_client.register_custom_call_as_batch_partitionable(_name)
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA",
|
||||
api_version=1)
|
||||
|
||||
try:
|
||||
from .rocm import _blas as _hipblas # pytype: disable=import-error
|
||||
except ImportError:
|
||||
for rocm_module_name in ["jax_rocm60_plugin"]:
|
||||
try:
|
||||
_hipblas = importlib.import_module(f"{rocm_module_name}._blas")
|
||||
except:
|
||||
_hipblas = None
|
||||
else:
|
||||
break
|
||||
|
||||
if _hipblas:
|
||||
for _name, _value in _hipblas.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
|
||||
|
||||
for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
|
||||
try:
|
||||
_hipsolver = importlib.import_module(
|
||||
f"{rocm_module_name}._solver", package="jaxlib"
|
||||
)
|
||||
except ImportError:
|
||||
_hipsolver = None
|
||||
else:
|
||||
break
|
||||
|
||||
if _hipsolver:
|
||||
for _name, _value in _hipsolver.registrations().items():
|
||||
# TODO(danfm): Clean up after all legacy custom calls are ported.
|
||||
@ -102,16 +58,6 @@ if _hipsolver:
|
||||
xla_client.register_custom_call_target(_name, _value, platform="ROCM",
|
||||
api_version=api_version)
|
||||
|
||||
for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
|
||||
try:
|
||||
_hiphybrid = importlib.import_module(
|
||||
f"{rocm_module_name}._hybrid", package="jaxlib"
|
||||
)
|
||||
except ImportError:
|
||||
_hiphybrid = None
|
||||
else:
|
||||
break
|
||||
|
||||
if _hiphybrid:
|
||||
for _name, _value in _hiphybrid.registrations().items():
|
||||
xla_client.register_custom_call_as_batch_partitionable(_name)
|
||||
|
@ -17,7 +17,6 @@ cusparse wrappers for performing sparse matrix computations in JAX
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
import importlib
|
||||
|
||||
import jaxlib.mlir.ir as ir
|
||||
|
||||
@ -27,15 +26,10 @@ from jaxlib import xla_client
|
||||
|
||||
from .hlo_helpers import custom_call, mk_result_types_and_shapes
|
||||
|
||||
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
|
||||
try:
|
||||
_cusparse = importlib.import_module(
|
||||
f"{cuda_module_name}._sparse", package="jaxlib"
|
||||
)
|
||||
except ImportError:
|
||||
_cusparse = None
|
||||
else:
|
||||
break
|
||||
from .plugin_support import import_from_plugin
|
||||
|
||||
_cusparse = import_from_plugin("cuda", "_sparse")
|
||||
_hipsparse = import_from_plugin("rocm", "_sparse")
|
||||
|
||||
if _cusparse:
|
||||
for _name, _value in _cusparse.registrations().items():
|
||||
@ -43,16 +37,6 @@ if _cusparse:
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA",
|
||||
api_version=api_version)
|
||||
|
||||
for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
|
||||
try:
|
||||
_hipsparse = importlib.import_module(
|
||||
f"{rocm_module_name}._sparse", package="jaxlib"
|
||||
)
|
||||
except ImportError:
|
||||
_hipsparse = None
|
||||
else:
|
||||
break
|
||||
|
||||
if _hipsparse:
|
||||
for _name, _value in _hipsparse.registrations().items():
|
||||
api_version = 1 if _name.endswith("_ffi") else 0
|
||||
|
@ -11,19 +11,13 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
|
||||
from jaxlib import xla_client
|
||||
|
||||
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
|
||||
try:
|
||||
_cuda_triton = importlib.import_module(
|
||||
f"{cuda_module_name}._triton", package="jaxlib"
|
||||
)
|
||||
except ImportError:
|
||||
_cuda_triton = None
|
||||
else:
|
||||
break
|
||||
from .plugin_support import import_from_plugin
|
||||
|
||||
_cuda_triton = import_from_plugin("cuda", "_triton")
|
||||
_hip_triton = import_from_plugin("rocm", "_triton")
|
||||
|
||||
if _cuda_triton:
|
||||
xla_client.register_custom_call_target(
|
||||
@ -39,16 +33,6 @@ if _cuda_triton:
|
||||
get_custom_call = _cuda_triton.get_custom_call
|
||||
get_serialized_metadata = _cuda_triton.get_serialized_metadata
|
||||
|
||||
for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
|
||||
try:
|
||||
_hip_triton = importlib.import_module(
|
||||
f"{rocm_module_name}._triton", package="jaxlib"
|
||||
)
|
||||
except ImportError:
|
||||
_hip_triton = None
|
||||
else:
|
||||
break
|
||||
|
||||
if _hip_triton:
|
||||
xla_client.register_custom_call_target(
|
||||
"triton_kernel_call", _hip_triton.get_custom_call(),
|
||||
|
110
jaxlib/plugin_support.py
Normal file
110
jaxlib/plugin_support.py
Normal file
@ -0,0 +1,110 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Sequence
|
||||
import importlib
|
||||
import re
|
||||
from types import ModuleType
|
||||
import warnings
|
||||
|
||||
from .version import __version__ as jaxlib_version
|
||||
|
||||
|
||||
_PLUGIN_MODULE_NAME = {
|
||||
"cuda": "jax_cuda12_plugin",
|
||||
"rocm": "jax_rocm60_plugin",
|
||||
}
|
||||
|
||||
|
||||
def import_from_plugin(
|
||||
plugin_name: str, submodule_name: str, *, check_version: bool = True
|
||||
) -> ModuleType | None:
|
||||
"""Import a submodule from a known plugin with version checking.
|
||||
|
||||
Args:
|
||||
plugin_name: The name of the plugin. The supported values are "cuda" or
|
||||
"rocm".
|
||||
submodule_name: The name of the submodule to import, e.g. "_triton".
|
||||
check_version: Whether to check that the plugin version is compatible with
|
||||
the jaxlib version. If the plugin is installed but the versions are not
|
||||
compatible, this function produces a warning and returns None.
|
||||
|
||||
Returns:
|
||||
The imported submodule, or None if the plugin is not installed or if the
|
||||
versions are incompatible.
|
||||
"""
|
||||
if plugin_name not in _PLUGIN_MODULE_NAME:
|
||||
raise ValueError(f"Unknown plugin: {plugin_name}")
|
||||
return maybe_import_plugin_submodule(
|
||||
[f".{plugin_name}", _PLUGIN_MODULE_NAME[plugin_name]],
|
||||
submodule_name,
|
||||
check_version=check_version,
|
||||
)
|
||||
|
||||
|
||||
def check_plugin_version(
|
||||
plugin_name: str, jaxlib_version: str, plugin_version: str
|
||||
) -> bool:
|
||||
# Regex to match a dotted version prefix 0.1.23.456.789 of a PEP440 version.
|
||||
# PEP440 allows a number of non-numeric suffixes, which we allow also.
|
||||
# We currently do not allow an epoch.
|
||||
version_regex = re.compile(r"[0-9]+(?:\.[0-9]+)*")
|
||||
|
||||
def _parse_version(v: str) -> tuple[int, ...]:
|
||||
m = version_regex.match(v)
|
||||
if m is None:
|
||||
raise ValueError(f"Unable to parse version string '{v}'")
|
||||
return tuple(int(x) for x in m.group(0).split("."))
|
||||
|
||||
if _parse_version(jaxlib_version) != _parse_version(plugin_version):
|
||||
warnings.warn(
|
||||
f"JAX plugin {plugin_name} version {plugin_version} is installed, but "
|
||||
"it is not compatible with the installed jaxlib version "
|
||||
f"{jaxlib_version}, so it will not be used.",
|
||||
RuntimeWarning,
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def maybe_import_plugin_submodule(
|
||||
plugin_module_names: Sequence[str],
|
||||
submodule_name: str,
|
||||
*,
|
||||
check_version: bool = True,
|
||||
) -> ModuleType | None:
|
||||
for plugin_module_name in plugin_module_names:
|
||||
try:
|
||||
module = importlib.import_module(
|
||||
f"{plugin_module_name}.{submodule_name}",
|
||||
package="jaxlib",
|
||||
)
|
||||
except ImportError:
|
||||
continue
|
||||
else:
|
||||
if not check_version:
|
||||
return module
|
||||
try:
|
||||
version_module = importlib.import_module(
|
||||
f"{plugin_module_name}.version",
|
||||
package="jaxlib",
|
||||
)
|
||||
except ImportError:
|
||||
return module
|
||||
plugin_version = getattr(version_module, "__version__", "")
|
||||
if check_plugin_version(
|
||||
plugin_module_name, jaxlib_version, plugin_version
|
||||
):
|
||||
return module
|
||||
return None
|
@ -195,6 +195,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu):
|
||||
"__main__/jaxlib/gpu_common_utils.py",
|
||||
"__main__/jaxlib/gpu_solver.py",
|
||||
"__main__/jaxlib/gpu_sparse.py",
|
||||
"__main__/jaxlib/plugin_support.py",
|
||||
"__main__/jaxlib/version.py",
|
||||
"__main__/jaxlib/xla_client.py",
|
||||
f"xla/xla/python/xla_extension.{pyext}",
|
||||
|
Loading…
x
Reference in New Issue
Block a user