Move (most) jaxlib linalg custom call registration into JAX.

My motivation here is to fix the plugin support for batch partitionable custom calls. Since plugin support for custom call partitioners is provided via `register_plugin_callback` in `xla_bridge`, rather than by `xla_client` itself, it's much more straightforward to register the custom calls in JAX.

It would be possible to refactor things differently, but it actually seems like a reasonable choice to use the supported APIs from `jax.ffi` instead of `xla_client` so that we can take advantage of any new features we might add there in the future.

This is all still a little bit brittle and I'd eventually like to migrate to a version where the XLA FFI library provides a mechanism for exporting handlers, but this change is still compatible with any future changes like that.

PiperOrigin-RevId: 735381736
This commit is contained in:
Dan Foreman-Mackey 2025-03-10 08:17:07 -07:00 committed by jax authors
parent 91340ea0a7
commit 21884d4a14
6 changed files with 100 additions and 81 deletions

View File

@ -44,6 +44,10 @@ from jax._src.lax import lax as lax_internal
from jax._src.lax import svd as lax_svd
from jax._src.lax import utils as lax_utils
from jax._src.lax.lax import _float, _complex, _int
from jax._src.lib import gpu_linalg
from jax._src.lib import gpu_solver
from jax._src.lib import gpu_sparse
from jax._src.lib import lapack
from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
@ -51,12 +55,23 @@ from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec as P
from jax._src.typing import Array, ArrayLike
# The following imports may be unused but they are needed to register the
# custom call targets defined in each module.
from jax._src.lib import gpu_linalg # pylint:disable=unused-import # noqa: F401
from jax._src.lib import gpu_solver # pylint:disable=unused-import # noqa: F401
from jax._src.lib import gpu_sparse # pylint:disable=unused-import # noqa: F401
from jax._src.lib import lapack # pylint:disable=unused-import # noqa: F401
def register_module_custom_calls(module):
  """Register the custom call targets that `module` advertises.

  Two optional hooks on `module` are consulted, each only when the attribute
  is present:

  * `registrations()`: expected to return a mapping from platform name to an
    iterable of `(name, value, api_version)` triples. Every triple is passed
    to `ffi.register_ffi_target` for that platform.
  * `batch_partitionable_targets()`: expected to return an iterable of target
    names. Every name is passed to
    `ffi.register_ffi_target_as_batch_partitionable`.

  Args:
    module: the object to inspect for the two hooks described above.
  """
  if hasattr(module, "registrations"):
    targets_by_platform = module.registrations()
    for platform_name, platform_targets in targets_by_platform.items():
      for target_name, handler, version in platform_targets:
        ffi.register_ffi_target(
            target_name, handler, platform=platform_name, api_version=version
        )

  if hasattr(module, "batch_partitionable_targets"):
    for target_name in module.batch_partitionable_targets():
      ffi.register_ffi_target_as_batch_partitionable(target_name)
# Register the custom call targets advertised by each kernel module. These are
# top-level statements, so they run when this module is executed on import.
register_module_custom_calls(gpu_linalg)
register_module_custom_calls(gpu_solver)
register_module_custom_calls(gpu_sparse)
register_module_custom_calls(lapack)
# Top-level functions in alphabetical order.

View File

@ -19,8 +19,18 @@ import math
import jax
from jax._src import core
from jax._src import ffi
from jax._src import util
from jax._src.typing import Array
from jax._src.lib import gpu_sparse
# Register every target reported by gpu_sparse.registrations() through
# ffi.register_ffi_target. The hasattr guard skips registration entirely when
# gpu_sparse does not expose a `registrations` attribute.
# Each entry in `targets` is unpacked as a (name, value, api_version) triple.
if hasattr(gpu_sparse, "registrations"):
for platform, targets in gpu_sparse.registrations().items():
for name, value, api_version in targets:
ffi.register_ffi_target(
name, value, platform=platform, api_version=api_version
)
class JAXSparse(util.StrictABC):

View File

@ -12,25 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from jaxlib import xla_client
from typing import Any
from .plugin_support import import_from_plugin
_cuda_linalg = import_from_plugin("cuda", "_linalg")
_hip_linalg = import_from_plugin("rocm", "_linalg")
# When the CUDA linalg plugin module is truthy, register each of its targets
# with xla_client for platform "CUDA" using api_version=1.
# "cu_lu_pivots_to_permutation" is additionally marked as batch partitionable.
# NOTE(review): indentation is lost in this view, so the nesting of the final
# call is not visible here; presumably it sits inside the `if` -- confirm.
if _cuda_linalg:
for _name, _value in _cuda_linalg.registrations().items():
xla_client.register_custom_call_target(
_name, _value, platform="CUDA", api_version=1
)
xla_client.register_custom_call_as_batch_partitionable(
"cu_lu_pivots_to_permutation")
def registrations() -> dict[str, list[tuple[str, Any, int]]]:
  """Collect the linalg targets exported by the CUDA and ROCm plugin modules.

  Returns:
    A dict with the keys "CUDA" and "ROCM". Each value holds one
    `(name, value, api_version)` tuple per item of the matching plugin
    module's `registrations()`, with `api_version` always 1. A platform whose
    plugin module is falsy maps to an empty list.
  """
  plugins = (("CUDA", _cuda_linalg), ("ROCM", _hip_linalg))
  collected = {}
  for platform, plugin in plugins:
    if plugin:
      collected[platform] = [
          (name, value, 1) for name, value in plugin.registrations().items()
      ]
    else:
      collected[platform] = []
  return collected  # pytype: disable=bad-return-type
# When the ROCm linalg plugin module is truthy, register each of its targets
# with xla_client for platform "ROCM" using api_version=1.
# "hip_lu_pivots_to_permutation" is additionally marked as batch partitionable.
# NOTE(review): indentation is lost in this view, so the nesting of the final
# call is not visible here; presumably it sits inside the `if` -- confirm.
if _hip_linalg:
for _name, _value in _hip_linalg.registrations().items():
xla_client.register_custom_call_target(
_name, _value, platform="ROCM", api_version=1
)
xla_client.register_custom_call_as_batch_partitionable(
"hip_lu_pivots_to_permutation")
def batch_partitionable_targets() -> list[str]:
  """List the linalg target names to register as batch partitionable.

  Returns:
    "cu_lu_pivots_to_permutation" when the CUDA plugin module is truthy,
    followed by "hip_lu_pivots_to_permutation" when the ROCm plugin module is
    truthy. The list is empty when neither module is available.
  """
  candidates = (
      (_cuda_linalg, "cu_lu_pivots_to_permutation"),
      (_hip_linalg, "hip_lu_pivots_to_permutation"),
  )
  return [target for plugin, target in candidates if plugin]

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from jaxlib import xla_client
from typing import Any
from .plugin_support import import_from_plugin
@ -24,45 +24,39 @@ _hipblas = import_from_plugin("rocm", "_blas")
_hipsolver = import_from_plugin("rocm", "_solver")
_hiphybrid = import_from_plugin("rocm", "_hybrid")
# cuBLAS targets: registered for platform "CUDA" without an explicit
# api_version argument.
if _cublas:
for _name, _value in _cublas.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
# cuSOLVER targets: api_version starts at 0 and is switched to 1 for names
# ending in "_ffi"; those "_ffi" names are also marked as batch partitionable.
# NOTE(review): indentation is lost in this view; presumably both the
# api_version = 1 assignment and the batch-partitionable call are guarded by
# the endswith check -- confirm.
if _cusolver:
for _name, _value in _cusolver.registrations().items():
# TODO(danfm): Clean up after all legacy custom calls are ported.
api_version = 0
if _name.endswith("_ffi"):
api_version = 1
xla_client.register_custom_call_as_batch_partitionable(_name)
xla_client.register_custom_call_target(_name, _value, platform="CUDA",
api_version=api_version)
def registrations() -> dict[str, list[tuple[str, Any, int]]]:
  """Collect the BLAS, solver and hybrid targets from the GPU plugin modules.

  Plugin modules are visited in the order BLAS, solver, hybrid, and within
  each kind CUDA before ROCm. Falsy plugin modules are skipped.

  Returns:
    A dict with the keys "CUDA" and "ROCM". Each value is a list of
    `(name, value, api_version)` tuples where `api_version` is:
      * 0 for every BLAS target,
      * 1 for solver targets whose name ends in "_ffi" and 0 otherwise,
      * 1 for every hybrid target.
  """
  def _always_zero(_name):
    return 0

  def _one_if_ffi(name):
    return int(name.endswith("_ffi"))

  def _always_one(_name):
    return 1

  plugin_groups = (
      (_always_zero, (("CUDA", _cublas), ("ROCM", _hipblas))),
      (_one_if_ffi, (("CUDA", _cusolver), ("ROCM", _hipsolver))),
      (_always_one, (("CUDA", _cuhybrid), ("ROCM", _hiphybrid))),
  )

  collected = {"CUDA": [], "ROCM": []}
  for api_version_for, plugins in plugin_groups:
    for platform, plugin in plugins:
      if not plugin:
        continue
      collected[platform] += [
          (name, value, api_version_for(name))
          for name, value in plugin.registrations().items()
      ]
  return collected  # pytype: disable=bad-return-type
# CUDA hybrid targets: each name is marked as batch partitionable and
# registered for platform "CUDA" with api_version=1.
if _cuhybrid:
for _name, _value in _cuhybrid.registrations().items():
xla_client.register_custom_call_as_batch_partitionable(_name)
xla_client.register_custom_call_target(_name, _value, platform="CUDA",
api_version=1)
# hipBLAS targets: registered for platform "ROCM" without an explicit
# api_version argument.
if _hipblas:
for _name, _value in _hipblas.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
def batch_partitionable_targets() -> list[str]:
  """List the solver and hybrid target names that are batch partitionable.

  Returns:
    The names from the CUDA and ROCm solver plugin modules that end in
    "_ffi", followed by every name from the CUDA and ROCm hybrid plugin
    modules. Falsy plugin modules contribute nothing.
  """
  names = []
  for plugin in (_cusolver, _hipsolver):
    if not plugin:
      continue
    names += [name for name in plugin.registrations() if name.endswith("_ffi")]
  for plugin in (_cuhybrid, _hiphybrid):
    if not plugin:
      continue
    # Iterating the registrations mapping yields its keys, i.e. target names.
    names += list(plugin.registrations())
  return names
# hipSOLVER targets: api_version starts at 0 and is switched to 1 for names
# ending in "_ffi"; those "_ffi" names are also marked as batch partitionable.
# NOTE(review): indentation is lost in this view; presumably both the
# api_version = 1 assignment and the batch-partitionable call are guarded by
# the endswith check -- confirm.
if _hipsolver:
for _name, _value in _hipsolver.registrations().items():
# TODO(danfm): Clean up after all legacy custom calls are ported.
api_version = 0
if _name.endswith("_ffi"):
api_version = 1
xla_client.register_custom_call_as_batch_partitionable(_name)
xla_client.register_custom_call_target(_name, _value, platform="ROCM",
api_version=api_version)
# ROCm hybrid targets: each name is marked as batch partitionable and
# registered for platform "ROCM" with api_version=1.
if _hiphybrid:
for _name, _value in _hiphybrid.registrations().items():
xla_client.register_custom_call_as_batch_partitionable(_name)
xla_client.register_custom_call_target(_name, _value, platform="ROCM",
api_version=1)
def initialize_hybrid_kernels():
if _cuhybrid:
@ -70,6 +64,7 @@ def initialize_hybrid_kernels():
if _hiphybrid:
_hiphybrid.initialize()
def has_magma():
if _cuhybrid:
return _cuhybrid.has_magma()

View File

@ -17,13 +17,12 @@ cusparse wrappers for performing sparse matrix computations in JAX
import math
from functools import partial
from typing import Any
import jaxlib.mlir.ir as ir
import numpy as np
from jaxlib import xla_client
from .hlo_helpers import custom_call, mk_result_types_and_shapes
from .plugin_support import import_from_plugin
@ -31,17 +30,14 @@ from .plugin_support import import_from_plugin
_cusparse = import_from_plugin("cuda", "_sparse")
_hipsparse = import_from_plugin("rocm", "_sparse")
# cuSPARSE targets: registered for platform "CUDA"; api_version is 1 for names
# ending in "_ffi" and 0 for all others.
if _cusparse:
for _name, _value in _cusparse.registrations().items():
api_version = 1 if _name.endswith("_ffi") else 0
xla_client.register_custom_call_target(_name, _value, platform="CUDA",
api_version=api_version)
# hipSPARSE targets: same api_version rule, registered for platform "ROCM".
if _hipsparse:
for _name, _value in _hipsparse.registrations().items():
api_version = 1 if _name.endswith("_ffi") else 0
xla_client.register_custom_call_target(_name, _value, platform="ROCM",
api_version=api_version)
def registrations() -> dict[str, list[tuple[str, Any, int]]]:
  """Collect the sparse targets exported by the CUDA and ROCm plugin modules.

  Returns:
    A dict with the keys "CUDA" and "ROCM". Each value holds one
    `(name, value, api_version)` tuple per item of the matching plugin
    module's `registrations()`, where `api_version` is 1 for names ending in
    "_ffi" and 0 otherwise. A platform whose plugin module is falsy maps to
    an empty list.
  """
  collected = {}
  for platform, plugin in (("CUDA", _cusparse), ("ROCM", _hipsparse)):
    entries = []
    if plugin:
      for name, value in plugin.registrations().items():
        api_version = 1 if name.endswith("_ffi") else 0
        entries.append((name, value, api_version))
    collected[platform] = entries
  return collected  # pytype: disable=bad-return-type
cuda_is_supported = bool(_cusparse and _cusparse.sparse_supported)

View File

@ -12,23 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from typing import Any
from jaxlib import xla_client
import numpy as np
from .cpu import _lapack
from .cpu._lapack import eig
from .cpu._lapack import schur
# Register every LAPACK target for platform "cpu". api_version starts at 0 and
# is switched to 1 for names ending in "_ffi"; those "_ffi" names are also
# marked as batch partitionable.
# NOTE(review): indentation is lost in this view; presumably both the
# api_version = 1 assignment and the batch-partitionable call are guarded by
# the endswith check -- confirm.
for _name, _value in _lapack.registrations().items():
api_version = 0
if _name.endswith("_ffi"):
api_version = 1
xla_client.register_custom_call_as_batch_partitionable(_name)
xla_client.register_custom_call_target(
_name, _value, platform="cpu", api_version=api_version
)
EigComputationMode = eig.ComputationMode
SchurComputationMode = schur.ComputationMode
@ -43,6 +34,17 @@ LAPACK_DTYPE_PREFIX = {
}
def registrations() -> dict[str, list[tuple[str, Any, int]]]:
  """Collect the LAPACK targets exported by `_lapack`.

  Returns:
    A dict with the single key "cpu" whose value holds one
    `(name, value, api_version)` tuple per item of `_lapack.registrations()`.
    `api_version` is 1 for names ending in "_ffi" and 0 otherwise.
  """
  cpu_targets = []
  for name, value in _lapack.registrations().items():
    api_version = 1 if name.endswith("_ffi") else 0
    cpu_targets.append((name, value, api_version))
  return {"cpu": cpu_targets}
def batch_partitionable_targets() -> list[str]:
  """List the LAPACK target names that are batch partitionable.

  Returns:
    The names from `_lapack.registrations()` that end in "_ffi", in the
    order the mapping yields them.
  """
  ffi_names = []
  for name in _lapack.registrations():
    if not name.endswith("_ffi"):
      continue
    ffi_names.append(name)
  return ffi_names
def prepare_lapack_call(fn_base, dtype):
"""Initializes the LAPACK library and returns the LAPACK target name."""
_lapack.initialize()