From 21884d4a14d364c3b82b312a668079c668cc2836 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 10 Mar 2025 08:17:07 -0700 Subject: [PATCH] Move (most) jaxlib linalg custom call registration into JAX. My motivation here is to fix the plugin support for batch partitionable custom calls. Since plugin support for custom call partitioners is provided via register_plugin_callback in xla_bridge, instead of xla_client itself, it's much more straightforward to register the custom calls in JAX. It would be possible to refactor things differently, but it actually seems like a reasonable choice to use the supported APIs from `jax.ffi` instead of `xla_client` so that we can take advantage of any new features we might add there in the future. This is all still a little bit brittle and I'd eventually like to migrate to a version where the XLA FFI library provides a mechanism for exporting handlers, but this change is still compatible with any future changes like that. PiperOrigin-RevId: 735381736 --- jax/_src/lax/linalg.py | 27 ++++++++++--- jax/experimental/sparse/_base.py | 10 +++++ jaxlib/gpu_linalg.py | 31 ++++++++------- jaxlib/gpu_solver.py | 67 +++++++++++++++----------------- jaxlib/gpu_sparse.py | 22 +++++------ jaxlib/lapack.py | 24 ++++++------ 6 files changed, 100 insertions(+), 81 deletions(-) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index abd104293..c674401fb 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -44,6 +44,10 @@ from jax._src.lax import lax as lax_internal from jax._src.lax import svd as lax_svd from jax._src.lax import utils as lax_utils from jax._src.lax.lax import _float, _complex, _int +from jax._src.lib import gpu_linalg +from jax._src.lib import gpu_solver +from jax._src.lib import gpu_sparse +from jax._src.lib import lapack from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo @@ -51,12 +55,23 @@ from jax._src.lib.mlir.dialects import hlo from jax._src.partition_spec import PartitionSpec as P from jax._src.typing import Array, ArrayLike -# The following imports may be unused but they are needed to register the -# custom call targets defined in each module. -from jax._src.lib import gpu_linalg # pylint:disable=unused-import # noqa: F401 -from jax._src.lib import gpu_solver # pylint:disable=unused-import # noqa: F401 -from jax._src.lib import gpu_sparse # pylint:disable=unused-import # noqa: F401 -from jax._src.lib import lapack # pylint:disable=unused-import # noqa: F401 + +def register_module_custom_calls(module): + if hasattr(module, "registrations"): + for platform, targets in module.registrations().items(): + for name, value, api_version in targets: + ffi.register_ffi_target( + name, value, platform=platform, api_version=api_version + ) + if hasattr(module, "batch_partitionable_targets"): + for name in module.batch_partitionable_targets(): + ffi.register_ffi_target_as_batch_partitionable(name) + + +register_module_custom_calls(gpu_linalg) +register_module_custom_calls(gpu_solver) +register_module_custom_calls(gpu_sparse) +register_module_custom_calls(lapack) # Top-level functions in alphabetical order. diff --git a/jax/experimental/sparse/_base.py b/jax/experimental/sparse/_base.py index 36d84cb0d..7739af029 100644 --- a/jax/experimental/sparse/_base.py +++ b/jax/experimental/sparse/_base.py @@ -19,8 +19,18 @@ import math import jax from jax._src import core +from jax._src import ffi from jax._src import util from jax._src.typing import Array +from jax._src.lib import gpu_sparse + + +if hasattr(gpu_sparse, "registrations"): + for platform, targets in gpu_sparse.registrations().items(): + for name, value, api_version in targets: + ffi.register_ffi_target( + name, value, platform=platform, api_version=api_version + ) class JAXSparse(util.StrictABC): diff --git a/jaxlib/gpu_linalg.py b/jaxlib/gpu_linalg.py index 1acfbaf22..c747c0abb 100644 --- a/jaxlib/gpu_linalg.py +++ b/jaxlib/gpu_linalg.py @@ -12,25 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jaxlib import xla_client +from typing import Any from .plugin_support import import_from_plugin _cuda_linalg = import_from_plugin("cuda", "_linalg") _hip_linalg = import_from_plugin("rocm", "_linalg") -if _cuda_linalg: - for _name, _value in _cuda_linalg.registrations().items(): - xla_client.register_custom_call_target( - _name, _value, platform="CUDA", api_version=1 - ) - xla_client.register_custom_call_as_batch_partitionable( - "cu_lu_pivots_to_permutation") +def registrations() -> dict[str, list[tuple[str, Any, int]]]: + registrations = {"CUDA": [], "ROCM": []} + for platform, module in [("CUDA", _cuda_linalg), ("ROCM", _hip_linalg)]: + if module: + registrations[platform].extend( + (*i, 1) for i in module.registrations().items()) + return registrations # pytype: disable=bad-return-type -if _hip_linalg: - for _name, _value in _hip_linalg.registrations().items(): - xla_client.register_custom_call_target( - _name, _value, platform="ROCM", api_version=1 - ) - xla_client.register_custom_call_as_batch_partitionable( - "hip_lu_pivots_to_permutation") + +def batch_partitionable_targets() -> list[str]: + targets = [] + if _cuda_linalg: + targets.append("cu_lu_pivots_to_permutation") + if _hip_linalg: + targets.append("hip_lu_pivots_to_permutation") + return targets diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index a40c6bf93..efb58f9a4 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jaxlib import xla_client +from typing import Any from .plugin_support import import_from_plugin @@ -24,45 +24,39 @@ _hipblas = import_from_plugin("rocm", "_blas") _hipsolver = import_from_plugin("rocm", "_solver") _hiphybrid = import_from_plugin("rocm", "_hybrid") -if _cublas: - for _name, _value in _cublas.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="CUDA") -if _cusolver: - for _name, _value in _cusolver.registrations().items(): - # TODO(danfm): Clean up after all legacy custom calls are ported. - api_version = 0 - if _name.endswith("_ffi"): - api_version = 1 - xla_client.register_custom_call_as_batch_partitionable(_name) - xla_client.register_custom_call_target(_name, _value, platform="CUDA", - api_version=api_version) +def registrations() -> dict[str, list[tuple[str, Any, int]]]: + registrations = {"CUDA": [], "ROCM": []} + for platform, module in [("CUDA", _cublas), ("ROCM", _hipblas)]: + if module: + registrations[platform].extend( + (*i, 0) for i in module.registrations().items()) + for platform, module in [("CUDA", _cusolver), ("ROCM", _hipsolver)]: + if module: + registrations[platform].extend( + (name, value, int(name.endswith("_ffi"))) + for name, value in module.registrations().items() + ) + for platform, module in [("CUDA", _cuhybrid), ("ROCM", _hiphybrid)]: + if module: + registrations[platform].extend( + (*i, 1) for i in module.registrations().items()) + return registrations # pytype: disable=bad-return-type -if _cuhybrid: - for _name, _value in _cuhybrid.registrations().items(): - xla_client.register_custom_call_as_batch_partitionable(_name) - xla_client.register_custom_call_target(_name, _value, platform="CUDA", - api_version=1) -if _hipblas: - for _name, _value in _hipblas.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="ROCM") +def batch_partitionable_targets() -> list[str]: + targets = [] + for module in [_cusolver, _hipsolver]: + if module: + targets.extend( + name for name in module.registrations() + if name.endswith("_ffi") + ) + for module in [_cuhybrid, _hiphybrid]: + if module: + targets.extend(name for name in module.registrations()) + return targets -if _hipsolver: - for _name, _value in _hipsolver.registrations().items(): - # TODO(danfm): Clean up after all legacy custom calls are ported. - api_version = 0 - if _name.endswith("_ffi"): - api_version = 1 - xla_client.register_custom_call_as_batch_partitionable(_name) - xla_client.register_custom_call_target(_name, _value, platform="ROCM", - api_version=api_version) - -if _hiphybrid: - for _name, _value in _hiphybrid.registrations().items(): - xla_client.register_custom_call_as_batch_partitionable(_name) - xla_client.register_custom_call_target(_name, _value, platform="ROCM", - api_version=1) def initialize_hybrid_kernels(): if _cuhybrid: @@ -70,6 +64,7 @@ def initialize_hybrid_kernels(): if _hiphybrid: _hiphybrid.initialize() + def has_magma(): if _cuhybrid: return _cuhybrid.has_magma() diff --git a/jaxlib/gpu_sparse.py b/jaxlib/gpu_sparse.py index d397557df..d8645041c 100644 --- a/jaxlib/gpu_sparse.py +++ b/jaxlib/gpu_sparse.py @@ -17,13 +17,12 @@ cusparse wrappers for performing sparse matrix computations in JAX import math from functools import partial +from typing import Any import jaxlib.mlir.ir as ir import numpy as np -from jaxlib import xla_client - from .hlo_helpers import custom_call, mk_result_types_and_shapes from .plugin_support import import_from_plugin @@ -31,17 +30,14 @@ from .plugin_support import import_from_plugin _cusparse = import_from_plugin("cuda", "_sparse") _hipsparse = import_from_plugin("rocm", "_sparse") -if _cusparse: - for _name, _value in _cusparse.registrations().items(): - api_version = 1 if _name.endswith("_ffi") else 0 - xla_client.register_custom_call_target(_name, _value, platform="CUDA", - api_version=api_version) - -if _hipsparse: - for _name, _value in _hipsparse.registrations().items(): - api_version = 1 if _name.endswith("_ffi") else 0 - xla_client.register_custom_call_target(_name, _value, platform="ROCM", - api_version=api_version) +def registrations() -> dict[str, list[tuple[str, Any, int]]]: + registrations = {"CUDA": [], "ROCM": []} + for platform, module in [("CUDA", _cusparse), ("ROCM", _hipsparse)]: + if module: + registrations[platform].extend( + (name, value, int(name.endswith("_ffi"))) + for name, value in module.registrations().items()) + return registrations # pytype: disable=bad-return-type cuda_is_supported = bool(_cusparse and _cusparse.sparse_supported) diff --git a/jaxlib/lapack.py b/jaxlib/lapack.py index c5a59e314..330fcb992 100644 --- a/jaxlib/lapack.py +++ b/jaxlib/lapack.py @@ -12,23 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np +from typing import Any -from jaxlib import xla_client +import numpy as np from .cpu import _lapack from .cpu._lapack import eig from .cpu._lapack import schur -for _name, _value in _lapack.registrations().items(): - api_version = 0 - if _name.endswith("_ffi"): - api_version = 1 - xla_client.register_custom_call_as_batch_partitionable(_name) - xla_client.register_custom_call_target( - _name, _value, platform="cpu", api_version=api_version - ) - EigComputationMode = eig.ComputationMode SchurComputationMode = schur.ComputationMode @@ -43,6 +34,17 @@ LAPACK_DTYPE_PREFIX = { } +def registrations() -> dict[str, list[tuple[str, Any, int]]]: + return {"cpu": [ + (name, value, int(name.endswith("_ffi"))) + for name, value in _lapack.registrations().items() + ]} + + +def batch_partitionable_targets() -> list[str]: + return [name for name in _lapack.registrations() if name.endswith("_ffi")] + + def prepare_lapack_call(fn_base, dtype): """Initializes the LAPACK library and returns the LAPACK target name.""" _lapack.initialize()