rocm_jax/jaxlib/lapack.py
Dan Foreman-Mackey 21884d4a14 Move (most) jaxlib linalg custom call registration into JAX.
My motivation here is to fix the plugin support for batch partitionable custom calls. Since plugin support for custom call partitioners is provided via register_plugin_callback in xla_bridge, instead of xla_client itself, it's much more straightforward to register the custom calls in JAX.

It would be possible to refactor things differently, but it actually seems like a reasonable choice to use the supported APIs from `jax.ffi` instead of `xla_client` so that we can take advantage of any new features we might add there in the future.

This is all still a little bit brittle and I'd eventually like to migrate to a version where the XLA FFI library provides a mechanism for exporting handlers, but this change is still compatible with any future changes like that.

PiperOrigin-RevId: 735381736
2025-03-10 08:17:44 -07:00

63 lines
1.8 KiB
Python

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import numpy as np
from .cpu import _lapack
from .cpu._lapack import eig
from .cpu._lapack import schur
# Module-level aliases for enums exposed by the compiled ``_lapack`` extension
# (``eig`` and ``schur`` submodules), so callers need not reach into ``.cpu``.
SchurSort = schur.Sort
SchurComputationMode = schur.ComputationMode
EigComputationMode = eig.ComputationMode
# Maps a NumPy scalar type to the one-letter prefix used in LAPACK custom call
# target names: s = float32, d = float64, c = complex64, z = complex128.
LAPACK_DTYPE_PREFIX: dict[type, str] = dict(
  zip((np.float32, np.float64, np.complex64, np.complex128), "sdcz")
)
def registrations() -> dict[str, list[tuple[str, Any, int]]]:
  """Returns the LAPACK custom call targets to register, keyed by platform.

  Every target reported by ``_lapack.registrations()`` is listed under the
  ``"cpu"`` key as a ``(name, value, flag)`` tuple. ``value`` is passed through
  unchanged from ``_lapack.registrations()``. ``flag`` is 1 when ``name`` ends
  in ``"_ffi"`` and 0 otherwise; presumably this is the custom call API
  version expected by the registration code (TODO confirm against caller).
  """
  cpu_targets = []
  for target_name, target_value in _lapack.registrations().items():
    # Targets with the "_ffi" suffix are distinguished by a flag of 1.
    ffi_flag = 1 if target_name.endswith("_ffi") else 0
    cpu_targets.append((target_name, target_value, ffi_flag))
  return {"cpu": cpu_targets}
def batch_partitionable_targets() -> list[str]:
  """Returns the registered LAPACK target names that end in ``"_ffi"``.

  The names come from the keys of ``_lapack.registrations()``, in the order
  that mapping yields them.
  """
  ffi_suffix = "_ffi"
  registered = _lapack.registrations()
  return [target for target in registered.keys() if target.endswith(ffi_suffix)]
def prepare_lapack_call(fn_base, dtype):
  """Initializes the LAPACK library and returns the LAPACK target name.

  ``_lapack.initialize()`` runs before the target name is built, so it is
  invoked even when ``build_lapack_fn_target`` then rejects ``dtype``.

  Args:
    fn_base: the routine name without its type prefix.
    dtype: the element type, forwarded to ``build_lapack_fn_target``.

  Returns:
    The custom call target name from ``build_lapack_fn_target``.
  """
  _lapack.initialize()
  target_name = build_lapack_fn_target(fn_base, dtype)
  return target_name
def build_lapack_fn_target(fn_base: str, dtype) -> str:
  """Builds the target name for a LAPACK function custom call.

  The result is ``f"lapack_{prefix}{fn_base}"``. ``prefix`` is the one-letter
  type prefix looked up in ``LAPACK_DTYPE_PREFIX``, first by ``dtype`` itself
  and then by ``dtype.type``.

  Args:
    fn_base: the routine name without its type prefix.
    dtype: a NumPy scalar type that is a key of ``LAPACK_DTYPE_PREFIX``, or an
      object (such as an ``np.dtype``) whose ``.type`` attribute is one.

  Returns:
    The custom call target name.

  Raises:
    NotImplementedError: if ``dtype`` does not map to a supported LAPACK type.
      This includes values with no ``.type`` attribute and unhashable values.
  """
  try:
    # Direct key lookup covers scalar types (np.float32, ...). The ``.type``
    # fallback covers np.dtype instances. Prefixes are non-empty strings, so
    # ``or`` only falls through on a miss.
    prefix = LAPACK_DTYPE_PREFIX.get(dtype) or LAPACK_DTYPE_PREFIX[dtype.type]
  except (KeyError, AttributeError, TypeError) as err:
    # AttributeError: ``dtype`` has no ``.type`` (e.g. a str or None).
    # TypeError: ``dtype`` is unhashable. Both are unsupported dtypes and must
    # surface as the documented NotImplementedError, not leak out raw. The
    # original cause stays available via ``from err``.
    raise NotImplementedError(f"Unsupported dtype {dtype}.") from err
  return f"lapack_{prefix}{fn_base}"