diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py
index abd104293..c674401fb 100644
--- a/jax/_src/lax/linalg.py
+++ b/jax/_src/lax/linalg.py
@@ -44,6 +44,10 @@ from jax._src.lax import lax as lax_internal
 from jax._src.lax import svd as lax_svd
 from jax._src.lax import utils as lax_utils
 from jax._src.lax.lax import _float, _complex, _int
+from jax._src.lib import gpu_linalg
+from jax._src.lib import gpu_solver
+from jax._src.lib import gpu_sparse
+from jax._src.lib import lapack
 from jax._src.lib import version as jaxlib_version
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import chlo
@@ -51,12 +55,23 @@ from jax._src.lib.mlir.dialects import hlo
 from jax._src.partition_spec import PartitionSpec as P
 from jax._src.typing import Array, ArrayLike
 
-# The following imports may be unused but they are needed to register the
-# custom call targets defined in each module.
-from jax._src.lib import gpu_linalg  # pylint:disable=unused-import  # noqa: F401
-from jax._src.lib import gpu_solver  # pylint:disable=unused-import  # noqa: F401
-from jax._src.lib import gpu_sparse  # pylint:disable=unused-import  # noqa: F401
-from jax._src.lib import lapack  # pylint:disable=unused-import  # noqa: F401
+
+def register_module_custom_calls(module):
+  if hasattr(module, "registrations"):
+    for platform, targets in module.registrations().items():
+      for name, value, api_version in targets:
+        ffi.register_ffi_target(
+            name, value, platform=platform, api_version=api_version
+        )
+  if hasattr(module, "batch_partitionable_targets"):
+    for name in module.batch_partitionable_targets():
+      ffi.register_ffi_target_as_batch_partitionable(name)
+
+
+register_module_custom_calls(gpu_linalg)
+register_module_custom_calls(gpu_solver)
+register_module_custom_calls(gpu_sparse)
+register_module_custom_calls(lapack)
 
 
 # Top-level functions in alphabetical order.
diff --git a/jax/experimental/sparse/_base.py b/jax/experimental/sparse/_base.py
index 36d84cb0d..7739af029 100644
--- a/jax/experimental/sparse/_base.py
+++ b/jax/experimental/sparse/_base.py
@@ -19,8 +19,18 @@ import math
 
 import jax
 from jax._src import core
+from jax._src import ffi
 from jax._src import util
 from jax._src.typing import Array
+from jax._src.lib import gpu_sparse
+
+
+if hasattr(gpu_sparse, "registrations"):
+  for platform, targets in gpu_sparse.registrations().items():
+    for name, value, api_version in targets:
+      ffi.register_ffi_target(
+          name, value, platform=platform, api_version=api_version
+      )
 
 
 class JAXSparse(util.StrictABC):
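For reference, here is a minimal sketch (not part of the patch) of the module shape that `register_module_custom_calls` above expects: a `registrations()` callable keyed by platform, yielding `(target_name, capsule, api_version)` tuples, plus an optional `batch_partitionable_targets()` list. The module `fake_backend` and the target name `fake_lu_pivots_ffi` are hypothetical placeholders, and the loop only prints instead of calling `ffi.register_ffi_target`.

```python
# Hypothetical stand-in for a jaxlib backend module; the real modules return
# PyCapsules wrapping kernel pointers where "capsule-placeholder" appears.
from types import SimpleNamespace

fake_backend = SimpleNamespace(
    registrations=lambda: {
        "CUDA": [("fake_lu_pivots_ffi", "capsule-placeholder", 1)],
        "ROCM": [],
    },
    batch_partitionable_targets=lambda: ["fake_lu_pivots_ffi"],
)

# Walk the tables the same way register_module_custom_calls does, but just
# print what would be registered.
for platform, targets in fake_backend.registrations().items():
  for name, value, api_version in targets:
    print(f"would register {name} on {platform} (api_version={api_version})")
for name in fake_backend.batch_partitionable_targets():
  print(f"would mark {name} as batch partitionable")
```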
diff --git a/jaxlib/gpu_linalg.py b/jaxlib/gpu_linalg.py
index 1acfbaf22..c747c0abb 100644
--- a/jaxlib/gpu_linalg.py
+++ b/jaxlib/gpu_linalg.py
@@ -12,25 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from jaxlib import xla_client
+from typing import Any
 
 from .plugin_support import import_from_plugin
 
 _cuda_linalg = import_from_plugin("cuda", "_linalg")
 _hip_linalg = import_from_plugin("rocm", "_linalg")
 
-if _cuda_linalg:
-  for _name, _value in _cuda_linalg.registrations().items():
-    xla_client.register_custom_call_target(
-        _name, _value, platform="CUDA", api_version=1
-    )
-  xla_client.register_custom_call_as_batch_partitionable(
-      "cu_lu_pivots_to_permutation")
+def registrations() -> dict[str, list[tuple[str, Any, int]]]:
+  registrations = {"CUDA": [], "ROCM": []}
+  for platform, module in [("CUDA", _cuda_linalg), ("ROCM", _hip_linalg)]:
+    if module:
+      registrations[platform].extend(
+          (*i, 1) for i in module.registrations().items())
+  return registrations  # pytype: disable=bad-return-type
 
-if _hip_linalg:
-  for _name, _value in _hip_linalg.registrations().items():
-    xla_client.register_custom_call_target(
-        _name, _value, platform="ROCM", api_version=1
-    )
-  xla_client.register_custom_call_as_batch_partitionable(
-      "hip_lu_pivots_to_permutation")
+
+def batch_partitionable_targets() -> list[str]:
+  targets = []
+  if _cuda_linalg:
+    targets.append("cu_lu_pivots_to_permutation")
+  if _hip_linalg:
+    targets.append("hip_lu_pivots_to_permutation")
+  return targets
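A side note on the `(*i, 1)` comprehension used above: `dict.items()` yields `(name, value)` pairs, and unpacking each pair with a trailing constant produces the `(name, value, api_version)` triples. A standalone illustration, where `kernels` is a hypothetical placeholder for the dict returned by a plugin's `registrations()`:

```python
# kernels stands in for the dict returned by a plugin's registrations();
# real values are PyCapsules, not strings.
kernels = {"cu_lu_pivots_to_permutation": "capsule-placeholder"}

rows = [(*item, 1) for item in kernels.items()]
print(rows)  # [('cu_lu_pivots_to_permutation', 'capsule-placeholder', 1)]
```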
diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py
index a40c6bf93..efb58f9a4 100644
--- a/jaxlib/gpu_solver.py
+++ b/jaxlib/gpu_solver.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from jaxlib import xla_client
+from typing import Any
 
 from .plugin_support import import_from_plugin
 
@@ -24,45 +24,39 @@ _hipblas = import_from_plugin("rocm", "_blas")
 _hipsolver = import_from_plugin("rocm", "_solver")
 _hiphybrid = import_from_plugin("rocm", "_hybrid")
 
-if _cublas:
-  for _name, _value in _cublas.registrations().items():
-    xla_client.register_custom_call_target(_name, _value, platform="CUDA")
-
-if _cusolver:
-  for _name, _value in _cusolver.registrations().items():
-    # TODO(danfm): Clean up after all legacy custom calls are ported.
-    api_version = 0
-    if _name.endswith("_ffi"):
-      api_version = 1
-      xla_client.register_custom_call_as_batch_partitionable(_name)
-    xla_client.register_custom_call_target(_name, _value, platform="CUDA",
-                                           api_version=api_version)
+
+def registrations() -> dict[str, list[tuple[str, Any, int]]]:
+  registrations = {"CUDA": [], "ROCM": []}
+  for platform, module in [("CUDA", _cublas), ("ROCM", _hipblas)]:
+    if module:
+      registrations[platform].extend(
+          (*i, 0) for i in module.registrations().items())
+  for platform, module in [("CUDA", _cusolver), ("ROCM", _hipsolver)]:
+    if module:
+      registrations[platform].extend(
+          (name, value, int(name.endswith("_ffi")))
+          for name, value in module.registrations().items()
+      )
+  for platform, module in [("CUDA", _cuhybrid), ("ROCM", _hiphybrid)]:
+    if module:
+      registrations[platform].extend(
+          (*i, 1) for i in module.registrations().items())
+  return registrations  # pytype: disable=bad-return-type
 
-if _cuhybrid:
-  for _name, _value in _cuhybrid.registrations().items():
-    xla_client.register_custom_call_as_batch_partitionable(_name)
-    xla_client.register_custom_call_target(_name, _value, platform="CUDA",
-                                           api_version=1)
-
-if _hipblas:
-  for _name, _value in _hipblas.registrations().items():
-    xla_client.register_custom_call_target(_name, _value, platform="ROCM")
+
+def batch_partitionable_targets() -> list[str]:
+  targets = []
+  for module in [_cusolver, _hipsolver]:
+    if module:
+      targets.extend(
+          name for name in module.registrations()
+          if name.endswith("_ffi")
+      )
+  for module in [_cuhybrid, _hiphybrid]:
+    if module:
+      targets.extend(name for name in module.registrations())
+  return targets
 
-if _hipsolver:
-  for _name, _value in _hipsolver.registrations().items():
-    # TODO(danfm): Clean up after all legacy custom calls are ported.
-    api_version = 0
-    if _name.endswith("_ffi"):
-      api_version = 1
-      xla_client.register_custom_call_as_batch_partitionable(_name)
-    xla_client.register_custom_call_target(_name, _value, platform="ROCM",
-                                           api_version=api_version)
-
-if _hiphybrid:
-  for _name, _value in _hiphybrid.registrations().items():
-    xla_client.register_custom_call_as_batch_partitionable(_name)
-    xla_client.register_custom_call_target(_name, _value, platform="ROCM",
-                                           api_version=1)
 
 def initialize_hybrid_kernels():
   if _cuhybrid:
@@ -70,6 +64,7 @@ def initialize_hybrid_kernels():
   if _hiphybrid:
     _hiphybrid.initialize()
 
+
 def has_magma():
   if _cuhybrid:
     return _cuhybrid.has_magma()
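The solver targets mix legacy custom calls (api_version 0) and typed-FFI calls (api_version 1), distinguished only by the `_ffi` name suffix, so `int(name.endswith("_ffi"))` reproduces the removed `api_version = 0 / 1` branching in one expression. A small check with hypothetical target names (the real ones come from the cusolver/hipsolver plugins):

```python
# Hypothetical target names used only to illustrate the suffix convention.
names = ["cusolver_getrf", "cusolver_getrf_ffi"]

for name in names:
  api_version = int(name.endswith("_ffi"))
  print(name, "->", api_version)  # legacy call -> 0, FFI call -> 1
```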
diff --git a/jaxlib/gpu_sparse.py b/jaxlib/gpu_sparse.py
index d397557df..d8645041c 100644
--- a/jaxlib/gpu_sparse.py
+++ b/jaxlib/gpu_sparse.py
@@ -17,13 +17,12 @@ cusparse wrappers for performing sparse matrix computations in JAX
 
 import math
 from functools import partial
+from typing import Any
 
 import jaxlib.mlir.ir as ir
 
 import numpy as np
 
-from jaxlib import xla_client
-
 from .hlo_helpers import custom_call, mk_result_types_and_shapes
 from .plugin_support import import_from_plugin
 
@@ -31,17 +30,14 @@ from .plugin_support import import_from_plugin
 _cusparse = import_from_plugin("cuda", "_sparse")
 _hipsparse = import_from_plugin("rocm", "_sparse")
 
-if _cusparse:
-  for _name, _value in _cusparse.registrations().items():
-    api_version = 1 if _name.endswith("_ffi") else 0
-    xla_client.register_custom_call_target(_name, _value, platform="CUDA",
-                                           api_version=api_version)
-
-if _hipsparse:
-  for _name, _value in _hipsparse.registrations().items():
-    api_version = 1 if _name.endswith("_ffi") else 0
-    xla_client.register_custom_call_target(_name, _value, platform="ROCM",
-                                           api_version=api_version)
+def registrations() -> dict[str, list[tuple[str, Any, int]]]:
+  registrations = {"CUDA": [], "ROCM": []}
+  for platform, module in [("CUDA", _cusparse), ("ROCM", _hipsparse)]:
+    if module:
+      registrations[platform].extend(
+          (name, value, int(name.endswith("_ffi")))
+          for name, value in module.registrations().items())
+  return registrations  # pytype: disable=bad-return-type
 
 
 cuda_is_supported = bool(_cusparse and _cusparse.sparse_supported)
diff --git a/jaxlib/lapack.py b/jaxlib/lapack.py
index c5a59e314..330fcb992 100644
--- a/jaxlib/lapack.py
+++ b/jaxlib/lapack.py
@@ -12,23 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import numpy as np
+from typing import Any
 
-from jaxlib import xla_client
+import numpy as np
 
 from .cpu import _lapack
 from .cpu._lapack import eig
 from .cpu._lapack import schur
 
-for _name, _value in _lapack.registrations().items():
-  api_version = 0
-  if _name.endswith("_ffi"):
-    api_version = 1
-    xla_client.register_custom_call_as_batch_partitionable(_name)
-  xla_client.register_custom_call_target(
-      _name, _value, platform="cpu", api_version=api_version
-  )
-
 EigComputationMode = eig.ComputationMode
 SchurComputationMode = schur.ComputationMode
 
@@ -43,6 +34,17 @@ LAPACK_DTYPE_PREFIX = {
 }
 
 
+def registrations() -> dict[str, list[tuple[str, Any, int]]]:
+  return {"cpu": [
+      (name, value, int(name.endswith("_ffi")))
+      for name, value in _lapack.registrations().items()
+  ]}
+
+
+def batch_partitionable_targets() -> list[str]:
+  return [name for name in _lapack.registrations() if name.endswith("_ffi")]
+
+
 def prepare_lapack_call(fn_base, dtype):
   """Initializes the LAPACK library and returns the LAPACK target name."""
   _lapack.initialize()
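Putting the pieces together: after this change `jaxlib.lapack` only describes its targets, and the import-time side effect moves to `register_module_custom_calls` in `jax/_src/lax/linalg.py`. Below is a hedged sketch of the resulting data flow, using hypothetical target names and a plain function in place of the compiled `_lapack` extension; real entries carry PyCapsules rather than strings.

```python
# _fake_lapack_registrations stands in for _lapack.registrations().
def _fake_lapack_registrations():
  return {"lapack_sgetrf": "capsule", "lapack_sgetrf_ffi": "capsule"}

# Mirrors the new jaxlib.lapack.registrations() / batch_partitionable_targets().
registrations = {"cpu": [
    (name, value, int(name.endswith("_ffi")))
    for name, value in _fake_lapack_registrations().items()
]}
batch_partitionable = [
    name for name in _fake_lapack_registrations() if name.endswith("_ffi")
]

print(registrations["cpu"])
# [('lapack_sgetrf', 'capsule', 0), ('lapack_sgetrf_ffi', 'capsule', 1)]
print(batch_partitionable)  # ['lapack_sgetrf_ffi']
```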