mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Remove JAX custom call implementation of batched triangular solve.
XLA supports batched triangular solve on GPU and has since February 2022, which is older than the minimum jaxlib version. We can therefore delete our implementation and just use XLA's implementation. PiperOrigin-RevId: 482031830
This commit is contained in:
parent
d20b9fa498
commit
5617a02fa4
@ -915,39 +915,6 @@ mlir.register_lowering(triangular_solve_p, _triangular_solve_cpu_lower,
|
||||
platform='cpu')
|
||||
|
||||
|
||||
def _triangular_solve_gpu_lower(
|
||||
trsm_impl, ctx, a, b, *, left_side, lower, transpose_a, conjugate_a,
|
||||
unit_diagonal):
|
||||
a_aval, _ = ctx.avals_in
|
||||
m, n = a_aval.shape[-2:]
|
||||
batch = prod(a_aval.shape[:-2])
|
||||
if conjugate_a and not transpose_a:
|
||||
a = chlo.ConjOp(a).result
|
||||
conjugate_a = False
|
||||
if batch > 1 and m <= 256 and n <= 256:
|
||||
return [trsm_impl(a_aval.dtype, a, b, left_side, lower, transpose_a,
|
||||
conjugate_a, unit_diagonal)]
|
||||
else:
|
||||
# Use the XLA implementation for unbatched triangular_solve.
|
||||
if transpose_a:
|
||||
transpose = "ADJOINT" if conjugate_a else "TRANSPOSE"
|
||||
else:
|
||||
transpose = "NO_TRANSPOSE"
|
||||
return mhlo.TriangularSolveOp(b.type, a, b, ir.BoolAttr.get(left_side),
|
||||
ir.BoolAttr.get(lower),
|
||||
ir.BoolAttr.get(unit_diagonal),
|
||||
mhlo.TransposeAttr.get(transpose)).results
|
||||
|
||||
mlir.register_lowering(
|
||||
triangular_solve_p,
|
||||
partial(_triangular_solve_gpu_lower, gpu_solver.cuda_trsm),
|
||||
platform='cuda')
|
||||
mlir.register_lowering(
|
||||
triangular_solve_p,
|
||||
partial(_triangular_solve_gpu_lower, gpu_solver.rocm_trsm),
|
||||
platform='rocm')
|
||||
|
||||
|
||||
# Support operation for LU decomposition: Transformation of the pivots returned
|
||||
# by LU decomposition into permutations.
|
||||
|
||||
|
@ -51,23 +51,6 @@ CublasType DtypeToCublasType(const py::dtype& np_type) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Returns the descriptor for a TrsmBatched operation.
|
||||
std::pair<size_t, py::bytes> BuildTrsmBatchedDescriptor(
|
||||
const py::dtype& dtype, int batch, int m, int n, bool left_side, bool lower,
|
||||
bool trans_a, bool conj_a, bool unit_diagonal) {
|
||||
size_t size = batch * sizeof(void*);
|
||||
TrsmBatchedDescriptor desc;
|
||||
desc.type = DtypeToCublasType(dtype);
|
||||
desc.batch = batch;
|
||||
desc.m = m;
|
||||
desc.n = n;
|
||||
desc.side = left_side ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
|
||||
desc.uplo = lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
|
||||
desc.trans = trans_a ? (conj_a ? CUBLAS_OP_C : CUBLAS_OP_T) : CUBLAS_OP_N;
|
||||
desc.diag = unit_diagonal ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;
|
||||
return {size, PackDescriptor(desc)};
|
||||
}
|
||||
|
||||
// Returns the descriptor for a GetrfBatched operation.
|
||||
std::pair<size_t, py::bytes> BuildGetrfBatchedDescriptor(const py::dtype& dtype,
|
||||
int b, int n) {
|
||||
@ -86,7 +69,6 @@ std::pair<size_t, py::bytes> BuildGeqrfBatchedDescriptor(const py::dtype& dtype,
|
||||
|
||||
py::dict Registrations() {
|
||||
py::dict dict;
|
||||
dict["cublas_trsm_batched"] = EncapsulateFunction(TrsmBatched);
|
||||
dict["cublas_getrf_batched"] = EncapsulateFunction(GetrfBatched);
|
||||
dict["cublas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched);
|
||||
return dict;
|
||||
@ -94,7 +76,6 @@ py::dict Registrations() {
|
||||
|
||||
PYBIND11_MODULE(_cublas, m) {
|
||||
m.def("registrations", &Registrations);
|
||||
m.def("build_trsm_batched_descriptor", &BuildTrsmBatchedDescriptor);
|
||||
m.def("build_getrf_batched_descriptor", &BuildGetrfBatchedDescriptor);
|
||||
m.def("build_geqrf_batched_descriptor", &BuildGeqrfBatchedDescriptor);
|
||||
}
|
||||
|
@ -73,90 +73,6 @@ int SizeOfCublasType(CublasType type) {
|
||||
|
||||
} // namespace
|
||||
|
||||
// Batched triangular solve: trsmbatched
|
||||
|
||||
static absl::Status TrsmBatched_(cudaStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<TrsmBatchedDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const TrsmBatchedDescriptor& d = **s;
|
||||
auto h = BlasHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
if (buffers[2] != buffers[1]) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
|
||||
buffers[2], buffers[1], SizeOfCublasType(d.type) * d.batch * d.m * d.n,
|
||||
cudaMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
const int lda = d.side == CUBLAS_SIDE_LEFT ? d.m : d.n;
|
||||
const int ldb = d.m;
|
||||
auto a_batch_host = MakeBatchPointers(stream, buffers[0], buffers[3], d.batch,
|
||||
SizeOfCublasType(d.type) * lda * lda);
|
||||
JAX_RETURN_IF_ERROR(a_batch_host.status());
|
||||
auto b_batch_host = MakeBatchPointers(stream, buffers[2], buffers[4], d.batch,
|
||||
SizeOfCublasType(d.type) * d.m * d.n);
|
||||
JAX_RETURN_IF_ERROR(b_batch_host.status());
|
||||
// TODO(phawkins): ideally we would not need to synchronize here, but to
|
||||
// avoid it we need a way to keep the host-side buffer alive until the copy
|
||||
// completes.
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaStreamSynchronize(stream)));
|
||||
switch (d.type) {
|
||||
case CublasType::F32: {
|
||||
float** a_batch_ptrs = static_cast<float**>(buffers[3]);
|
||||
float** b_batch_ptrs = static_cast<float**>(buffers[4]);
|
||||
// NOTE(phawkins): if alpha is in GPU memory, cuBlas seems to segfault.
|
||||
const float alpha = 1.0f;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasStrsmBatched(
|
||||
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
|
||||
const_cast<const float**>(a_batch_ptrs), lda, b_batch_ptrs, ldb,
|
||||
d.batch)));
|
||||
break;
|
||||
}
|
||||
case CublasType::F64: {
|
||||
double** a_batch_ptrs = static_cast<double**>(buffers[3]);
|
||||
double** b_batch_ptrs = static_cast<double**>(buffers[4]);
|
||||
const double alpha = 1.0;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasDtrsmBatched(
|
||||
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
|
||||
const_cast<const double**>(a_batch_ptrs), lda, b_batch_ptrs, ldb,
|
||||
d.batch)));
|
||||
break;
|
||||
}
|
||||
case CublasType::C64: {
|
||||
cuComplex** a_batch_ptrs = static_cast<cuComplex**>(buffers[3]);
|
||||
cuComplex** b_batch_ptrs = static_cast<cuComplex**>(buffers[4]);
|
||||
const cuComplex alpha = make_cuComplex(1.0f, 0.0f);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasCtrsmBatched(
|
||||
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
|
||||
const_cast<const cuComplex**>(a_batch_ptrs), lda, b_batch_ptrs, ldb,
|
||||
d.batch)));
|
||||
break;
|
||||
}
|
||||
case CublasType::C128: {
|
||||
cuDoubleComplex** a_batch_ptrs =
|
||||
static_cast<cuDoubleComplex**>(buffers[3]);
|
||||
cuDoubleComplex** b_batch_ptrs =
|
||||
static_cast<cuDoubleComplex**>(buffers[4]);
|
||||
const cuDoubleComplex alpha = make_cuDoubleComplex(1.0f, 0.0f);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasZtrsmBatched(
|
||||
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
|
||||
const_cast<const cuDoubleComplex**>(a_batch_ptrs), lda, b_batch_ptrs,
|
||||
ldb, d.batch)));
|
||||
break;
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void TrsmBatched(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = TrsmBatched_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
// Batched LU decomposition: getrfbatched
|
||||
|
||||
static absl::Status GetrfBatched_(cudaStream_t stream, void** buffers,
|
||||
|
@ -33,21 +33,6 @@ enum class CublasType {
|
||||
C128,
|
||||
};
|
||||
|
||||
|
||||
// Batched triangular solve: trsmbatched
|
||||
|
||||
struct TrsmBatchedDescriptor {
|
||||
CublasType type;
|
||||
int batch, m, n;
|
||||
cublasSideMode_t side;
|
||||
cublasFillMode_t uplo;
|
||||
cublasOperation_t trans;
|
||||
cublasDiagType_t diag;
|
||||
};
|
||||
|
||||
void TrsmBatched(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// Batched LU decomposition: getrfbatched
|
||||
|
||||
struct GetrfBatchedDescriptor {
|
||||
|
@ -26,8 +26,6 @@ limitations under the License.
|
||||
namespace jax {
|
||||
namespace {
|
||||
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_trsm_batched", TrsmBatched,
|
||||
"CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_getrf_batched", GetrfBatched,
|
||||
"CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cuda_lu_pivots_to_permutation",
|
||||
|
@ -1,70 +0,0 @@
|
||||
/* Copyright 2021 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file is not used by JAX itself, but exists to assist with running
|
||||
// JAX-generated HLO code from outside of JAX.
|
||||
|
||||
#include "jaxlib/cuda/cublas_kernels.h"
|
||||
#include "jaxlib/cuda/cuda_lu_pivot_kernels.h"
|
||||
#include "jaxlib/cuda/cuda_prng_kernels.h"
|
||||
#include "jaxlib/cuda/cusolver_kernels.h"
|
||||
#include "jaxlib/cuda/cusparse_kernels.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
|
||||
|
||||
namespace jax {
|
||||
namespace {
|
||||
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_trsm_batched", TrsmBatched,
|
||||
"CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_getrf_batched", GetrfBatched,
|
||||
"CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cuda_lu_pivots_to_permutation",
|
||||
CudaLuPivotsToPermutation, "CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cuda_threefry2x32", CudaThreeFry2x32,
|
||||
"CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_potrf", Potrf, "CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_getrf", Getrf, "CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_geqrf", Geqrf, "CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_orgqr", Orgqr, "CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevd", Syevd, "CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevj", Syevj, "CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvd", Gesvd, "CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvdj", Gesvdj, "CUDA");
|
||||
|
||||
#if JAX_CUSPARSE_11300
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_todense", CsrToDense,
|
||||
"CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_fromdense", CsrFromDense,
|
||||
"CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_matvec", CsrMatvec,
|
||||
"CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_matmat", CsrMatmat,
|
||||
"CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_coo_todense", CooToDense,
|
||||
"CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_coo_fromdense", CooFromDense,
|
||||
"CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_coo_matvec", CooMatvec,
|
||||
"CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_coo_matmat", CooMatmat,
|
||||
"CUDA");
|
||||
#endif
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_gtsv2_f32", gtsv2_f32,
|
||||
"CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_gtsv2_f64", gtsv2_f64,
|
||||
"CUDA");
|
||||
|
||||
} // namespace
|
||||
} // namespace jax
|
@ -63,49 +63,6 @@ def _real_type(dtype):
|
||||
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)
|
||||
|
||||
|
||||
def _trsm_mhlo(platform, gpu_blas, dtype, a, b, left_side=False, lower=False,
|
||||
trans_a=False, conj_a=False, diag=False):
|
||||
"""Batched triangular solve.
|
||||
|
||||
XLA implements unbatched triangular solve directly, so we need only implement
|
||||
the batched case."""
|
||||
b_type = ir.RankedTensorType(b.type)
|
||||
dims = b_type.shape
|
||||
assert len(dims) >= 2
|
||||
m, n = dims[-2:]
|
||||
batch_dims = tuple(dims[:-2])
|
||||
num_bd = len(batch_dims)
|
||||
batch = _prod(batch_dims)
|
||||
k = m if left_side else n
|
||||
|
||||
a_type = ir.RankedTensorType(a.type)
|
||||
if (batch_dims + (k, k) != tuple(a_type.shape) or
|
||||
a_type.element_type != b_type.element_type):
|
||||
raise ValueError("Argument mismatch for trsm, got {} and {}".format(
|
||||
a_type, b_type))
|
||||
|
||||
if conj_a and not trans_a:
|
||||
raise NotImplementedError("Conjugation without transposition not supported")
|
||||
|
||||
lwork, opaque = gpu_blas.build_trsm_batched_descriptor(
|
||||
np.dtype(dtype), batch, m, n, left_side, lower, trans_a, conj_a, diag)
|
||||
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
|
||||
work_type = ir.RankedTensorType.get([lwork], ir.IntegerType.get_signless(8))
|
||||
work_layout = [0]
|
||||
out = custom_call(
|
||||
f"{platform}blas_trsm_batched",
|
||||
[b_type, work_type, work_type],
|
||||
[a, b],
|
||||
backend_config=opaque,
|
||||
operand_layouts=[layout] * 2,
|
||||
result_layouts=[layout, work_layout, work_layout],
|
||||
operand_output_aliases={1: 0})
|
||||
return out[0]
|
||||
|
||||
cuda_trsm = partial(_trsm_mhlo, "cu", _cublas)
|
||||
rocm_trsm = partial(_trsm_mhlo, "hip", _hipblas)
|
||||
|
||||
|
||||
def _potrf_mhlo(platform, gpu_solver, dtype, a, lower):
|
||||
"""Cholesky decomposition."""
|
||||
a_type = ir.RankedTensorType(a.type)
|
||||
|
@ -50,24 +50,6 @@ HipblasType DtypeToHipblasType(const py::dtype& np_type) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Returns the descriptor for a TrsmBatched operation.
|
||||
std::pair<size_t, py::bytes>
|
||||
BuildTrsmBatchedDescriptor(const py::dtype& dtype, int batch, int m, int n,
|
||||
bool left_side, bool lower, bool trans_a,
|
||||
bool conj_a, bool unit_diagonal) {
|
||||
size_t size = batch * sizeof(void*);
|
||||
TrsmBatchedDescriptor desc;
|
||||
desc.type = DtypeToHipblasType(dtype);
|
||||
desc.batch = batch;
|
||||
desc.m = m;
|
||||
desc.n = n;
|
||||
desc.side = left_side ? HIPBLAS_SIDE_LEFT : HIPBLAS_SIDE_RIGHT;
|
||||
desc.uplo = lower ? HIPBLAS_FILL_MODE_LOWER : HIPBLAS_FILL_MODE_UPPER;
|
||||
desc.trans = trans_a ? (conj_a ? HIPBLAS_OP_C : HIPBLAS_OP_T) : HIPBLAS_OP_N;
|
||||
desc.diag = unit_diagonal ? HIPBLAS_DIAG_UNIT : HIPBLAS_DIAG_NON_UNIT;
|
||||
return {size, PackDescriptor(desc)};
|
||||
}
|
||||
|
||||
// Returns the descriptor for a GetrfBatched operation.
|
||||
std::pair<size_t, py::bytes> BuildGetrfBatchedDescriptor(const py::dtype& dtype,
|
||||
int b, int n) {
|
||||
@ -86,7 +68,6 @@ std::pair<size_t, py::bytes> BuildGeqrfBatchedDescriptor(const py::dtype& dtype,
|
||||
|
||||
py::dict Registrations() {
|
||||
py::dict dict;
|
||||
dict["hipblas_trsm_batched"] = EncapsulateFunction(TrsmBatched);
|
||||
dict["hipblas_getrf_batched"] = EncapsulateFunction(GetrfBatched);
|
||||
dict["hipblas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched);
|
||||
return dict;
|
||||
@ -94,7 +75,6 @@ py::dict Registrations() {
|
||||
|
||||
PYBIND11_MODULE(_hipblas, m) {
|
||||
m.def("registrations", &Registrations);
|
||||
m.def("build_trsm_batched_descriptor", &BuildTrsmBatchedDescriptor);
|
||||
m.def("build_getrf_batched_descriptor", &BuildGetrfBatchedDescriptor);
|
||||
m.def("build_geqrf_batched_descriptor", &BuildGeqrfBatchedDescriptor);
|
||||
}
|
||||
|
@ -72,90 +72,6 @@ int SizeOfHipblasType(HipblasType type) {
|
||||
|
||||
} // namespace
|
||||
|
||||
// Batched triangular solve: trsmbatched
|
||||
|
||||
static absl::Status TrsmBatched_(hipStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<TrsmBatchedDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const TrsmBatchedDescriptor& d = **s;
|
||||
auto h = BlasHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
if (buffers[2] != buffers[1]) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
|
||||
buffers[2], buffers[1], SizeOfHipblasType(d.type) * d.batch * d.m * d.n,
|
||||
hipMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
const int lda = d.side == HIPBLAS_SIDE_LEFT ? d.m : d.n;
|
||||
const int ldb = d.m;
|
||||
auto a_batch_host = MakeBatchPointers(stream, buffers[0], buffers[3], d.batch,
|
||||
SizeOfHipblasType(d.type) * lda * lda);
|
||||
JAX_RETURN_IF_ERROR(a_batch_host.status());
|
||||
auto b_batch_host = MakeBatchPointers(stream, buffers[2], buffers[4], d.batch,
|
||||
SizeOfHipblasType(d.type) * d.m * d.n);
|
||||
JAX_RETURN_IF_ERROR(b_batch_host.status());
|
||||
// TODO(phawkins): ideally we would not need to synchronize here, but to
|
||||
// avoid it we need a way to keep the host-side buffer alive until the copy
|
||||
// completes.
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipStreamSynchronize(stream)));
|
||||
switch (d.type) {
|
||||
case HipblasType::F32: {
|
||||
float** a_batch_ptrs = static_cast<float**>(buffers[3]);
|
||||
float** b_batch_ptrs = static_cast<float**>(buffers[4]);
|
||||
// TODO(reza): is the following statement correct for rocm?
|
||||
// NOTE(phawkins): if alpha is in GPU memory, cuBlas seems to segfault.
|
||||
const float alpha = 1.0f;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasStrsmBatched(
|
||||
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
|
||||
const_cast<float**>(a_batch_ptrs), lda, b_batch_ptrs, ldb, d.batch)));
|
||||
break;
|
||||
}
|
||||
case HipblasType::F64: {
|
||||
double** a_batch_ptrs = static_cast<double**>(buffers[3]);
|
||||
double** b_batch_ptrs = static_cast<double**>(buffers[4]);
|
||||
const double alpha = 1.0;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasDtrsmBatched(
|
||||
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
|
||||
const_cast<double**>(a_batch_ptrs), lda, b_batch_ptrs, ldb,
|
||||
d.batch)));
|
||||
break;
|
||||
}
|
||||
case HipblasType::C64: {
|
||||
hipblasComplex** a_batch_ptrs = static_cast<hipblasComplex**>(buffers[3]);
|
||||
hipblasComplex** b_batch_ptrs = static_cast<hipblasComplex**>(buffers[4]);
|
||||
const hipblasComplex alpha = hipblasComplex(1.0f, 0.0f);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasCtrsmBatched(
|
||||
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
|
||||
const_cast<hipblasComplex**>(a_batch_ptrs), lda, b_batch_ptrs, ldb,
|
||||
d.batch)));
|
||||
break;
|
||||
}
|
||||
case HipblasType::C128: {
|
||||
hipblasDoubleComplex** a_batch_ptrs =
|
||||
static_cast<hipblasDoubleComplex**>(buffers[3]);
|
||||
hipblasDoubleComplex** b_batch_ptrs =
|
||||
static_cast<hipblasDoubleComplex**>(buffers[4]);
|
||||
const hipblasDoubleComplex alpha = hipblasDoubleComplex(1.0f, 0.0f);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasZtrsmBatched(
|
||||
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
|
||||
const_cast<hipblasDoubleComplex**>(a_batch_ptrs), lda, b_batch_ptrs,
|
||||
ldb, d.batch)));
|
||||
break;
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void TrsmBatched(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = TrsmBatched_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
// Batched LU decomposition: getrfbatched
|
||||
|
||||
static absl::Status GetrfBatched_(hipStream_t stream, void** buffers,
|
||||
|
@ -32,20 +32,6 @@ enum class HipblasType {
|
||||
C128,
|
||||
};
|
||||
|
||||
// Batched triangular solve: trsmbatched
|
||||
|
||||
struct TrsmBatchedDescriptor {
|
||||
HipblasType type;
|
||||
int batch, m, n;
|
||||
hipblasSideMode_t side;
|
||||
hipblasFillMode_t uplo;
|
||||
hipblasOperation_t trans;
|
||||
hipblasDiagType_t diag;
|
||||
};
|
||||
|
||||
void TrsmBatched(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// Batched LU decomposition: getrfbatched
|
||||
|
||||
struct GetrfBatchedDescriptor {
|
||||
|
Loading…
x
Reference in New Issue
Block a user