2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2019 The JAX Authors.
|
2019-08-02 11:16:15 -04:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2025-03-10 08:17:07 -07:00
|
|
|
from typing import Any
|
2019-08-02 11:16:15 -04:00
|
|
|
|
2025-02-27 11:51:39 -08:00
|
|
|
from .plugin_support import import_from_plugin
|
|
|
|
|
|
|
|
# Handles to the GPU linear-algebra extension modules, obtained through
# `import_from_plugin`. The CUDA handles are requested first, then the ROCm
# ones, each in the order _blas, _solver, _hybrid. A handle may be falsy when
# the corresponding plugin is unavailable (callers below test it with `if`).
_cublas, _cusolver, _cuhybrid = (
    import_from_plugin("cuda", _submodule)
    for _submodule in ("_blas", "_solver", "_hybrid")
)
_hipblas, _hipsolver, _hiphybrid = (
    import_from_plugin("rocm", _submodule)
    for _submodule in ("_blas", "_solver", "_hybrid")
)
|
2023-11-06 09:05:08 -08:00
|
|
|
|
2019-08-08 11:50:31 -04:00
|
|
|
|
2025-03-10 08:17:07 -07:00
|
|
|
def registrations() -> dict[str, list[tuple[str, Any, int]]]:
  """Collect the custom-call targets exported by the available GPU plugins.

  Each plugin module's ``registrations()`` mapping of name -> target is
  flattened into ``(name, target, flag)`` tuples. The integer flag is 0 for
  every BLAS target, 1 for every hybrid target, and for solver targets 1 only
  when the name ends with ``"_ffi"`` (0 otherwise). NOTE(review): the flag is
  presumably the custom-call API version (1 = FFI); confirm against the caller.

  Returns:
    A dict with keys ``"CUDA"`` and ``"ROCM"``. Within each list, BLAS entries
    come first, then solver entries, then hybrid entries. Plugins whose
    module handle is falsy contribute nothing.
  """
  def _blas_flag(_name):
    return 0

  def _solver_flag(name):
    return int(name.endswith("_ffi"))

  def _hybrid_flag(_name):
    return 1

  plugin_groups = (
      (_blas_flag, (("CUDA", _cublas), ("ROCM", _hipblas))),
      (_solver_flag, (("CUDA", _cusolver), ("ROCM", _hipsolver))),
      (_hybrid_flag, (("CUDA", _cuhybrid), ("ROCM", _hiphybrid))),
  )
  result = {"CUDA": [], "ROCM": []}
  for flag_of, plugins in plugin_groups:
    for platform, plugin in plugins:
      if not plugin:
        continue
      result[platform].extend(
          (target_name, target, flag_of(target_name))
          for target_name, target in plugin.registrations().items())
  return result  # pytype: disable=bad-return-type
|
2022-05-06 09:34:25 -07:00
|
|
|
|
2024-11-18 08:11:04 -08:00
|
|
|
|
2025-03-10 08:17:07 -07:00
|
|
|
def batch_partitionable_targets() -> list[str]:
  """List the target names this module reports as batch-partitionable.

  The result contains the solver targets whose names end with ``"_ffi"``
  (CUDA, then ROCm), followed by every hybrid target (CUDA, then ROCm).
  Plugins whose module handle is falsy contribute nothing.
  """
  def _is_ffi(name):
    return name.endswith("_ffi")

  def _keep_all(_name):
    return True

  selected = []
  for plugin, keep in ((_cusolver, _is_ffi), (_hipsolver, _is_ffi),
                       (_cuhybrid, _keep_all), (_hiphybrid, _keep_all)):
    if not plugin:
      continue
    # Iterating the registrations mapping yields its target names.
    selected += [name for name in plugin.registrations() if keep(name)]
  return selected
|
2024-06-12 18:45:01 -05:00
|
|
|
|
2024-11-18 08:11:04 -08:00
|
|
|
|
|
|
|
def initialize_hybrid_kernels():
  """Call ``initialize()`` on each available hybrid plugin (CUDA, then ROCm).

  Plugins whose module handle is falsy are skipped.
  """
  for plugin in (_cuhybrid, _hiphybrid):
    if plugin:
      plugin.initialize()
|
|
|
|
|
2025-03-10 08:17:07 -07:00
|
|
|
|
2024-11-18 08:11:04 -08:00
|
|
|
def has_magma():
  """Report MAGMA availability from the first available hybrid plugin.

  The CUDA hybrid plugin is consulted if present; otherwise the ROCm one.
  Only that first plugin's ``has_magma()`` answer is returned. If neither
  plugin is available, the result is ``False``.
  """
  for plugin in (_cuhybrid, _hiphybrid):
    if plugin:
      return plugin.has_magma()
  return False
|