Move (most) jaxlib linalg custom call registration into JAX.
My motivation here is to fix the plugin support for batch partitionable custom calls. Since plugin support for custom call partitioners is provided via register_plugin_callback in xla_bridge, rather than by xla_client itself, it's much more straightforward to register the custom calls in JAX. It would be possible to refactor things differently, but it actually seems like a reasonable choice to use the supported APIs from `jax.ffi` instead of `xla_client`, so that we can take advantage of any new features we might add there in the future.

This is all still a little brittle, and I'd eventually like to migrate to a version where the XLA FFI library provides a mechanism for exporting handlers, but this change is compatible with any future change along those lines.

PiperOrigin-RevId: 735381736
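For context, a minimal sketch of the `jax.ffi` registration path this message refers to, assuming the public `jax.ffi.register_ffi_target` API; the target table and its shape are hypothetical here, mirroring the `registrations()` hooks introduced below:

    import jax

    def register_targets(target_table):
      # target_table: platform -> [(name, capsule, api_version), ...], i.e. the
      # shape of the per-module registrations() hooks added in this commit.
      for platform, targets in target_table.items():
        for name, value, api_version in targets:
          jax.ffi.register_ffi_target(
              name, value, platform=platform, api_version=api_version
          )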
Parent: 91340ea0a7
Commit: 21884d4a14
@@ -44,6 +44,10 @@ from jax._src.lax import lax as lax_internal
 from jax._src.lax import svd as lax_svd
 from jax._src.lax import utils as lax_utils
 from jax._src.lax.lax import _float, _complex, _int
+from jax._src.lib import gpu_linalg
+from jax._src.lib import gpu_solver
+from jax._src.lib import gpu_sparse
+from jax._src.lib import lapack
 from jax._src.lib import version as jaxlib_version
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import chlo
@@ -51,12 +55,23 @@ from jax._src.lib.mlir.dialects import hlo
 from jax._src.partition_spec import PartitionSpec as P
 from jax._src.typing import Array, ArrayLike

-# The following imports may be unused but they are needed to register the
-# custom call targets defined in each module.
-from jax._src.lib import gpu_linalg  # pylint:disable=unused-import  # noqa: F401
-from jax._src.lib import gpu_solver  # pylint:disable=unused-import  # noqa: F401
-from jax._src.lib import gpu_sparse  # pylint:disable=unused-import  # noqa: F401
-from jax._src.lib import lapack  # pylint:disable=unused-import  # noqa: F401
+
+def register_module_custom_calls(module):
+  if hasattr(module, "registrations"):
+    for platform, targets in module.registrations().items():
+      for name, value, api_version in targets:
+        ffi.register_ffi_target(
+            name, value, platform=platform, api_version=api_version
+        )
+  if hasattr(module, "batch_partitionable_targets"):
+    for name in module.batch_partitionable_targets():
+      ffi.register_ffi_target_as_batch_partitionable(name)
+
+
+register_module_custom_calls(gpu_linalg)
+register_module_custom_calls(gpu_solver)
+register_module_custom_calls(gpu_sparse)
+register_module_custom_calls(lapack)


 # Top-level functions in alphabetical order.
@@ -19,8 +19,18 @@ import math

 import jax
 from jax._src import core
+from jax._src import ffi
 from jax._src import util
 from jax._src.typing import Array
+from jax._src.lib import gpu_sparse
+
+
+if hasattr(gpu_sparse, "registrations"):
+  for platform, targets in gpu_sparse.registrations().items():
+    for name, value, api_version in targets:
+      ffi.register_ffi_target(
+          name, value, platform=platform, api_version=api_version
+      )


 class JAXSparse(util.StrictABC):
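A brief aside on the `hasattr` guards (here and in `register_module_custom_calls` above): they keep import-time registration tolerant of older jaxlib wheels that predate these hooks. A minimal sketch with a hypothetical stand-in module:

    # Hypothetical stand-in for a jaxlib module that predates registrations().
    class _OldStyleModule:
      pass

    def count_registered_targets(module):
      # Only take the table-driven path when the hook exists; otherwise assume
      # the module registered its own targets at import time (the old behavior).
      if hasattr(module, "registrations"):
        return sum(len(targets) for targets in module.registrations().values())
      return 0

    assert count_registered_targets(_OldStyleModule()) == 0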
@@ -12,25 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from jaxlib import xla_client
+from typing import Any

 from .plugin_support import import_from_plugin

 _cuda_linalg = import_from_plugin("cuda", "_linalg")
 _hip_linalg = import_from_plugin("rocm", "_linalg")

-if _cuda_linalg:
-  for _name, _value in _cuda_linalg.registrations().items():
-    xla_client.register_custom_call_target(
-        _name, _value, platform="CUDA", api_version=1
-    )
-  xla_client.register_custom_call_as_batch_partitionable(
-      "cu_lu_pivots_to_permutation")
-
-if _hip_linalg:
-  for _name, _value in _hip_linalg.registrations().items():
-    xla_client.register_custom_call_target(
-        _name, _value, platform="ROCM", api_version=1
-    )
-  xla_client.register_custom_call_as_batch_partitionable(
-      "hip_lu_pivots_to_permutation")
+def registrations() -> dict[str, list[tuple[str, Any, int]]]:
+  registrations = {"CUDA": [], "ROCM": []}
+  for platform, module in [("CUDA", _cuda_linalg), ("ROCM", _hip_linalg)]:
+    if module:
+      registrations[platform].extend(
+          (*i, 1) for i in module.registrations().items())
+  return registrations  # pytype: disable=bad-return-type
+
+
+def batch_partitionable_targets() -> list[str]:
+  targets = []
+  if _cuda_linalg:
+    targets.append("cu_lu_pivots_to_permutation")
+  if _hip_linalg:
+    targets.append("hip_lu_pivots_to_permutation")
+  return targets
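One idiom worth calling out in the hunk above: `(*i, 1) for i in module.registrations().items()` unpacks each `(name, value)` item and appends the API version, yielding the `(name, value, api_version)` triples the JAX side expects. A standalone illustration (the capsule value is a hypothetical stand-in):

    # (*i, 1) expands each (name, value) dict item into a (name, value, 1) triple.
    regs = {"cu_lu_pivots_to_permutation": object()}  # object() stands in for a capsule
    triples = [(*i, 1) for i in regs.items()]
    name, value, api_version = triples[0]
    assert name == "cu_lu_pivots_to_permutation" and api_version == 1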
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from jaxlib import xla_client
+from typing import Any

 from .plugin_support import import_from_plugin

@@ -24,45 +24,39 @@ _hipblas = import_from_plugin("rocm", "_blas")
 _hipsolver = import_from_plugin("rocm", "_solver")
 _hiphybrid = import_from_plugin("rocm", "_hybrid")

-if _cublas:
-  for _name, _value in _cublas.registrations().items():
-    xla_client.register_custom_call_target(_name, _value, platform="CUDA")
-
-if _cusolver:
-  for _name, _value in _cusolver.registrations().items():
-    # TODO(danfm): Clean up after all legacy custom calls are ported.
-    api_version = 0
-    if _name.endswith("_ffi"):
-      api_version = 1
-      xla_client.register_custom_call_as_batch_partitionable(_name)
-    xla_client.register_custom_call_target(_name, _value, platform="CUDA",
-                                           api_version=api_version)
-
-if _cuhybrid:
-  for _name, _value in _cuhybrid.registrations().items():
-    xla_client.register_custom_call_as_batch_partitionable(_name)
-    xla_client.register_custom_call_target(_name, _value, platform="CUDA",
-                                           api_version=1)
-
-if _hipblas:
-  for _name, _value in _hipblas.registrations().items():
-    xla_client.register_custom_call_target(_name, _value, platform="ROCM")
-
-if _hipsolver:
-  for _name, _value in _hipsolver.registrations().items():
-    # TODO(danfm): Clean up after all legacy custom calls are ported.
-    api_version = 0
-    if _name.endswith("_ffi"):
-      api_version = 1
-      xla_client.register_custom_call_as_batch_partitionable(_name)
-    xla_client.register_custom_call_target(_name, _value, platform="ROCM",
-                                           api_version=api_version)
-
-if _hiphybrid:
-  for _name, _value in _hiphybrid.registrations().items():
-    xla_client.register_custom_call_as_batch_partitionable(_name)
-    xla_client.register_custom_call_target(_name, _value, platform="ROCM",
-                                           api_version=1)
+def registrations() -> dict[str, list[tuple[str, Any, int]]]:
+  registrations = {"CUDA": [], "ROCM": []}
+  for platform, module in [("CUDA", _cublas), ("ROCM", _hipblas)]:
+    if module:
+      registrations[platform].extend(
+          (*i, 0) for i in module.registrations().items())
+  for platform, module in [("CUDA", _cusolver), ("ROCM", _hipsolver)]:
+    if module:
+      registrations[platform].extend(
+          (name, value, int(name.endswith("_ffi")))
+          for name, value in module.registrations().items()
+      )
+  for platform, module in [("CUDA", _cuhybrid), ("ROCM", _hiphybrid)]:
+    if module:
+      registrations[platform].extend(
+          (*i, 1) for i in module.registrations().items())
+  return registrations  # pytype: disable=bad-return-type
+
+
+def batch_partitionable_targets() -> list[str]:
+  targets = []
+  for module in [_cusolver, _hipsolver]:
+    if module:
+      targets.extend(
+          name for name in module.registrations()
+          if name.endswith("_ffi")
+      )
+  for module in [_cuhybrid, _hiphybrid]:
+    if module:
+      targets.extend(name for name in module.registrations())
+  return targets

 def initialize_hybrid_kernels():
   if _cuhybrid:
@@ -70,6 +64,7 @@ def initialize_hybrid_kernels():
   if _hiphybrid:
     _hiphybrid.initialize()

+
 def has_magma():
   if _cuhybrid:
     return _cuhybrid.has_magma()
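As the removed TODO notes, the solver modules still mix legacy custom calls (API version 0) with typed-FFI calls (API version 1), distinguished only by an `_ffi` suffix on the target name. The convention, restated on its own (target names are illustrative):

    # Legacy targets -> api_version 0; typed-FFI targets ("*_ffi") -> api_version 1.
    versions = {name: int(name.endswith("_ffi")) for name in ["potrf", "getrf_ffi"]}
    assert versions == {"potrf": 0, "getrf_ffi": 1}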
@@ -17,13 +17,12 @@ cusparse wrappers for performing sparse matrix computations in JAX

 import math
 from functools import partial
+from typing import Any

 import jaxlib.mlir.ir as ir

 import numpy as np

-from jaxlib import xla_client
-
 from .hlo_helpers import custom_call, mk_result_types_and_shapes

 from .plugin_support import import_from_plugin
@@ -31,17 +30,14 @@ from .plugin_support import import_from_plugin
 _cusparse = import_from_plugin("cuda", "_sparse")
 _hipsparse = import_from_plugin("rocm", "_sparse")

-if _cusparse:
-  for _name, _value in _cusparse.registrations().items():
-    api_version = 1 if _name.endswith("_ffi") else 0
-    xla_client.register_custom_call_target(_name, _value, platform="CUDA",
-                                           api_version=api_version)
-
-if _hipsparse:
-  for _name, _value in _hipsparse.registrations().items():
-    api_version = 1 if _name.endswith("_ffi") else 0
-    xla_client.register_custom_call_target(_name, _value, platform="ROCM",
-                                           api_version=api_version)
+def registrations() -> dict[str, list[tuple[str, Any, int]]]:
+  registrations = {"CUDA": [], "ROCM": []}
+  for platform, module in [("CUDA", _cusparse), ("ROCM", _hipsparse)]:
+    if module:
+      registrations[platform].extend(
+          (name, value, int(name.endswith("_ffi")))
+          for name, value in module.registrations().items())
+  return registrations  # pytype: disable=bad-return-type


 cuda_is_supported = bool(_cusparse and _cusparse.sparse_supported)
@@ -12,23 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import numpy as np
+from typing import Any

-from jaxlib import xla_client
+import numpy as np

 from .cpu import _lapack
 from .cpu._lapack import eig
 from .cpu._lapack import schur

-for _name, _value in _lapack.registrations().items():
-  api_version = 0
-  if _name.endswith("_ffi"):
-    api_version = 1
-    xla_client.register_custom_call_as_batch_partitionable(_name)
-  xla_client.register_custom_call_target(
-      _name, _value, platform="cpu", api_version=api_version
-  )
-

 EigComputationMode = eig.ComputationMode
 SchurComputationMode = schur.ComputationMode
@@ -43,6 +34,17 @@ LAPACK_DTYPE_PREFIX = {
 }


+def registrations() -> dict[str, list[tuple[str, Any, int]]]:
+  return {"cpu": [
+      (name, value, int(name.endswith("_ffi")))
+      for name, value in _lapack.registrations().items()
+  ]}
+
+
+def batch_partitionable_targets() -> list[str]:
+  return [name for name in _lapack.registrations() if name.endswith("_ffi")]
+
+
 def prepare_lapack_call(fn_base, dtype):
   """Initializes the LAPACK library and returns the LAPACK target name."""
   _lapack.initialize()
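Taken together with the `register_module_custom_calls` helper in the first hunk, the new flow for the LAPACK targets reduces to roughly the following (a sketch of internal APIs exactly as they appear in this diff, not a stable public interface):

    from jax._src import ffi
    from jax._src.lib import lapack

    # Register every (name, value, api_version) triple the module advertises...
    for platform, targets in lapack.registrations().items():  # {"cpu": [...]}
      for name, value, api_version in targets:
        ffi.register_ffi_target(name, value, platform=platform,
                                api_version=api_version)
    # ...and mark the typed-FFI targets as batch partitionable.
    for name in lapack.batch_partitionable_targets():
      ffi.register_ffi_target_as_batch_partitionable(name)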