# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from functools import partial
import operator

import jaxlib.mlir.ir as ir

from .hlo_helpers import custom_call

from jaxlib import xla_client

# If the CUDA plugin is available, register its linear-algebra kernels with
# XLA as custom-call targets.
try:
  from .cuda import _linalg as _cuda_linalg  # pytype: disable=import-error
  for _name, _value in _cuda_linalg.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:
  _cuda_linalg = None

# Likewise for the ROCm plugin.
try:
  from .rocm import _linalg as _hip_linalg  # pytype: disable=import-error
  for _name, _value in _hip_linalg.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="ROCM")
except ImportError:
  _hip_linalg = None

_prod = lambda xs: functools.reduce(operator.mul, xs, 1)


def _lu_pivots_to_permutation_hlo(platform, gpu_linalg, pivots, *,
                                  permutation_size):
  """Kernel for the transformation of pivots to permutations on GPU."""
  typ = ir.RankedTensorType(pivots.type)
  dims = typ.shape
  i32_type = ir.IntegerType.get_signless(32)

  assert typ.element_type == i32_type, typ

  # All leading dimensions are batch dimensions; the minor dimension holds
  # the pivots of a single factorization.
  batch_size = _prod(dims[:-1])
  pivot_size = dims[-1]

  opaque = gpu_linalg.lu_pivots_to_permutation_descriptor(
      batch_size, pivot_size, permutation_size)
  pivots_layout = tuple(range(len(dims) - 1, -1, -1))
  permutations_layout = pivots_layout
  # The output shares the batch dimensions of the input, but its minor
  # dimension has length permutation_size.
  permutations_dims = list(dims)
  permutations_dims[-1] = permutation_size
  permutations_type = ir.RankedTensorType.get(permutations_dims, i32_type)
  return custom_call(
      f"{platform}_lu_pivots_to_permutation",
      [permutations_type],
      [pivots],
      backend_config=opaque,
      operand_layouts=[pivots_layout],
      result_layouts=[permutations_layout])

cuda_lu_pivots_to_permutation = partial(_lu_pivots_to_permutation_hlo, "cu",
                                        _cuda_linalg)
hip_lu_pivots_to_permutation = partial(
    _lu_pivots_to_permutation_hlo, "hip", _hip_linalg)
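
# A minimal pure-Python reference sketch, for exposition only, of the
# transformation the GPU kernel performs (assuming JAX's 0-based pivot
# convention): pivots[i] records that row i was swapped with row pivots[i]
# during LU factorization, and the result is the explicit permutation of
# length permutation_size obtained by applying those swaps in order. The
# helper below is hypothetical and is not used or registered by this module.
def _reference_lu_pivots_to_permutation(pivots, permutation_size):
  """Expands a single (unbatched) sequence of pivots into a permutation."""
  permutation = list(range(permutation_size))
  for i, p in enumerate(pivots):
    # Apply the row swaps in the order the factorization performed them.
    permutation[i], permutation[p] = permutation[p], permutation[i]
  return permutation

# For example, _reference_lu_pivots_to_permutation([2, 2, 2], 3) == [2, 0, 1].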