From 5617a02fa4fc3deaaa0090681510d8c0ae311a30 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 18 Oct 2022 15:03:38 -0700 Subject: [PATCH] Remove JAX custom call implementation of batched triangular solve. XLA supports batched triangular solve on GPU and has since February 2022, which is older than the minimum jaxlib version. We can therefore delete our implementation and just use XLA's implementation. PiperOrigin-RevId: 482031830 --- jax/_src/lax/linalg.py | 33 ------------- jaxlib/cuda/cublas.cc | 19 -------- jaxlib/cuda/cublas_kernels.cc | 84 --------------------------------- jaxlib/cuda/cublas_kernels.h | 15 ------ jaxlib/cuda/cuda_gpu_kernels.cc | 2 - jaxlib/cuda/cuda_kernels.cc | 70 --------------------------- jaxlib/gpu_solver.py | 43 ----------------- jaxlib/rocm/hipblas.cc | 20 -------- jaxlib/rocm/hipblas_kernels.cc | 84 --------------------------------- jaxlib/rocm/hipblas_kernels.h | 14 ------ 10 files changed, 384 deletions(-) delete mode 100644 jaxlib/cuda/cuda_kernels.cc diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index b24361e87..f18227634 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -915,39 +915,6 @@ mlir.register_lowering(triangular_solve_p, _triangular_solve_cpu_lower, platform='cpu') -def _triangular_solve_gpu_lower( - trsm_impl, ctx, a, b, *, left_side, lower, transpose_a, conjugate_a, - unit_diagonal): - a_aval, _ = ctx.avals_in - m, n = a_aval.shape[-2:] - batch = prod(a_aval.shape[:-2]) - if conjugate_a and not transpose_a: - a = chlo.ConjOp(a).result - conjugate_a = False - if batch > 1 and m <= 256 and n <= 256: - return [trsm_impl(a_aval.dtype, a, b, left_side, lower, transpose_a, - conjugate_a, unit_diagonal)] - else: - # Use the XLA implementation for unbatched triangular_solve. - if transpose_a: - transpose = "ADJOINT" if conjugate_a else "TRANSPOSE" - else: - transpose = "NO_TRANSPOSE" - return mhlo.TriangularSolveOp(b.type, a, b, ir.BoolAttr.get(left_side), - ir.BoolAttr.get(lower), - ir.BoolAttr.get(unit_diagonal), - mhlo.TransposeAttr.get(transpose)).results - -mlir.register_lowering( - triangular_solve_p, - partial(_triangular_solve_gpu_lower, gpu_solver.cuda_trsm), - platform='cuda') -mlir.register_lowering( - triangular_solve_p, - partial(_triangular_solve_gpu_lower, gpu_solver.rocm_trsm), - platform='rocm') - - # Support operation for LU decomposition: Transformation of the pivots returned # by LU decomposition into permutations. diff --git a/jaxlib/cuda/cublas.cc b/jaxlib/cuda/cublas.cc index d3539bcd1..6679ddc12 100644 --- a/jaxlib/cuda/cublas.cc +++ b/jaxlib/cuda/cublas.cc @@ -51,23 +51,6 @@ CublasType DtypeToCublasType(const py::dtype& np_type) { return it->second; } -// Returns the descriptor for a TrsmBatched operation. -std::pair BuildTrsmBatchedDescriptor( - const py::dtype& dtype, int batch, int m, int n, bool left_side, bool lower, - bool trans_a, bool conj_a, bool unit_diagonal) { - size_t size = batch * sizeof(void*); - TrsmBatchedDescriptor desc; - desc.type = DtypeToCublasType(dtype); - desc.batch = batch; - desc.m = m; - desc.n = n; - desc.side = left_side ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT; - desc.uplo = lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; - desc.trans = trans_a ? (conj_a ? CUBLAS_OP_C : CUBLAS_OP_T) : CUBLAS_OP_N; - desc.diag = unit_diagonal ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT; - return {size, PackDescriptor(desc)}; -} - // Returns the descriptor for a GetrfBatched operation. std::pair BuildGetrfBatchedDescriptor(const py::dtype& dtype, int b, int n) { @@ -86,7 +69,6 @@ std::pair BuildGeqrfBatchedDescriptor(const py::dtype& dtype, py::dict Registrations() { py::dict dict; - dict["cublas_trsm_batched"] = EncapsulateFunction(TrsmBatched); dict["cublas_getrf_batched"] = EncapsulateFunction(GetrfBatched); dict["cublas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched); return dict; @@ -94,7 +76,6 @@ py::dict Registrations() { PYBIND11_MODULE(_cublas, m) { m.def("registrations", &Registrations); - m.def("build_trsm_batched_descriptor", &BuildTrsmBatchedDescriptor); m.def("build_getrf_batched_descriptor", &BuildGetrfBatchedDescriptor); m.def("build_geqrf_batched_descriptor", &BuildGeqrfBatchedDescriptor); } diff --git a/jaxlib/cuda/cublas_kernels.cc b/jaxlib/cuda/cublas_kernels.cc index 2669ae06c..d171ac45c 100644 --- a/jaxlib/cuda/cublas_kernels.cc +++ b/jaxlib/cuda/cublas_kernels.cc @@ -73,90 +73,6 @@ int SizeOfCublasType(CublasType type) { } // namespace -// Batched triangular solve: trsmbatched - -static absl::Status TrsmBatched_(cudaStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const TrsmBatchedDescriptor& d = **s; - auto h = BlasHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[2] != buffers[1]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync( - buffers[2], buffers[1], SizeOfCublasType(d.type) * d.batch * d.m * d.n, - cudaMemcpyDeviceToDevice, stream))); - } - const int lda = d.side == CUBLAS_SIDE_LEFT ? d.m : d.n; - const int ldb = d.m; - auto a_batch_host = MakeBatchPointers(stream, buffers[0], buffers[3], d.batch, - SizeOfCublasType(d.type) * lda * lda); - JAX_RETURN_IF_ERROR(a_batch_host.status()); - auto b_batch_host = MakeBatchPointers(stream, buffers[2], buffers[4], d.batch, - SizeOfCublasType(d.type) * d.m * d.n); - JAX_RETURN_IF_ERROR(b_batch_host.status()); - // TODO(phawkins): ideally we would not need to synchronize here, but to - // avoid it we need a way to keep the host-side buffer alive until the copy - // completes. - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaStreamSynchronize(stream))); - switch (d.type) { - case CublasType::F32: { - float** a_batch_ptrs = static_cast(buffers[3]); - float** b_batch_ptrs = static_cast(buffers[4]); - // NOTE(phawkins): if alpha is in GPU memory, cuBlas seems to segfault. - const float alpha = 1.0f; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasStrsmBatched( - handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha, - const_cast(a_batch_ptrs), lda, b_batch_ptrs, ldb, - d.batch))); - break; - } - case CublasType::F64: { - double** a_batch_ptrs = static_cast(buffers[3]); - double** b_batch_ptrs = static_cast(buffers[4]); - const double alpha = 1.0; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasDtrsmBatched( - handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha, - const_cast(a_batch_ptrs), lda, b_batch_ptrs, ldb, - d.batch))); - break; - } - case CublasType::C64: { - cuComplex** a_batch_ptrs = static_cast(buffers[3]); - cuComplex** b_batch_ptrs = static_cast(buffers[4]); - const cuComplex alpha = make_cuComplex(1.0f, 0.0f); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasCtrsmBatched( - handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha, - const_cast(a_batch_ptrs), lda, b_batch_ptrs, ldb, - d.batch))); - break; - } - case CublasType::C128: { - cuDoubleComplex** a_batch_ptrs = - static_cast(buffers[3]); - cuDoubleComplex** b_batch_ptrs = - static_cast(buffers[4]); - const cuDoubleComplex alpha = make_cuDoubleComplex(1.0f, 0.0f); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasZtrsmBatched( - handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha, - const_cast(a_batch_ptrs), lda, b_batch_ptrs, - ldb, d.batch))); - break; - } - } - return absl::OkStatus(); -} - -void TrsmBatched(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = TrsmBatched_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // Batched LU decomposition: getrfbatched static absl::Status GetrfBatched_(cudaStream_t stream, void** buffers, diff --git a/jaxlib/cuda/cublas_kernels.h b/jaxlib/cuda/cublas_kernels.h index 0fcc8c1eb..218f950d3 100644 --- a/jaxlib/cuda/cublas_kernels.h +++ b/jaxlib/cuda/cublas_kernels.h @@ -33,21 +33,6 @@ enum class CublasType { C128, }; - -// Batched triangular solve: trsmbatched - -struct TrsmBatchedDescriptor { - CublasType type; - int batch, m, n; - cublasSideMode_t side; - cublasFillMode_t uplo; - cublasOperation_t trans; - cublasDiagType_t diag; -}; - -void TrsmBatched(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - // Batched LU decomposition: getrfbatched struct GetrfBatchedDescriptor { diff --git a/jaxlib/cuda/cuda_gpu_kernels.cc b/jaxlib/cuda/cuda_gpu_kernels.cc index 37aa49f46..c3ef2ebcd 100644 --- a/jaxlib/cuda/cuda_gpu_kernels.cc +++ b/jaxlib/cuda/cuda_gpu_kernels.cc @@ -26,8 +26,6 @@ limitations under the License. namespace jax { namespace { -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_trsm_batched", TrsmBatched, - "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_getrf_batched", GetrfBatched, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cuda_lu_pivots_to_permutation", diff --git a/jaxlib/cuda/cuda_kernels.cc b/jaxlib/cuda/cuda_kernels.cc deleted file mode 100644 index 935fda53c..000000000 --- a/jaxlib/cuda/cuda_kernels.cc +++ /dev/null @@ -1,70 +0,0 @@ -/* Copyright 2021 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file is not used by JAX itself, but exists to assist with running -// JAX-generated HLO code from outside of JAX. - -#include "jaxlib/cuda/cublas_kernels.h" -#include "jaxlib/cuda/cuda_lu_pivot_kernels.h" -#include "jaxlib/cuda/cuda_prng_kernels.h" -#include "jaxlib/cuda/cusolver_kernels.h" -#include "jaxlib/cuda/cusparse_kernels.h" -#include "tensorflow/compiler/xla/service/custom_call_target_registry.h" - -namespace jax { -namespace { - -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_trsm_batched", TrsmBatched, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_getrf_batched", GetrfBatched, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cuda_lu_pivots_to_permutation", - CudaLuPivotsToPermutation, "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cuda_threefry2x32", CudaThreeFry2x32, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_potrf", Potrf, "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_getrf", Getrf, "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_geqrf", Geqrf, "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_orgqr", Orgqr, "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevd", Syevd, "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevj", Syevj, "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvd", Gesvd, "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvdj", Gesvdj, "CUDA"); - -#if JAX_CUSPARSE_11300 -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_todense", CsrToDense, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_fromdense", CsrFromDense, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_matvec", CsrMatvec, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_matmat", CsrMatmat, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_coo_todense", CooToDense, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_coo_fromdense", CooFromDense, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_coo_matvec", CooMatvec, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_coo_matmat", CooMatmat, - "CUDA"); -#endif -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_gtsv2_f32", gtsv2_f32, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_gtsv2_f64", gtsv2_f64, - "CUDA"); - -} // namespace -} // namespace jax diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index 012e95b62..048d575c0 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -63,49 +63,6 @@ def _real_type(dtype): _prod = lambda xs: functools.reduce(operator.mul, xs, 1) -def _trsm_mhlo(platform, gpu_blas, dtype, a, b, left_side=False, lower=False, - trans_a=False, conj_a=False, diag=False): - """Batched triangular solve. - - XLA implements unbatched triangular solve directly, so we need only implement - the batched case.""" - b_type = ir.RankedTensorType(b.type) - dims = b_type.shape - assert len(dims) >= 2 - m, n = dims[-2:] - batch_dims = tuple(dims[:-2]) - num_bd = len(batch_dims) - batch = _prod(batch_dims) - k = m if left_side else n - - a_type = ir.RankedTensorType(a.type) - if (batch_dims + (k, k) != tuple(a_type.shape) or - a_type.element_type != b_type.element_type): - raise ValueError("Argument mismatch for trsm, got {} and {}".format( - a_type, b_type)) - - if conj_a and not trans_a: - raise NotImplementedError("Conjugation without transposition not supported") - - lwork, opaque = gpu_blas.build_trsm_batched_descriptor( - np.dtype(dtype), batch, m, n, left_side, lower, trans_a, conj_a, diag) - layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - work_type = ir.RankedTensorType.get([lwork], ir.IntegerType.get_signless(8)) - work_layout = [0] - out = custom_call( - f"{platform}blas_trsm_batched", - [b_type, work_type, work_type], - [a, b], - backend_config=opaque, - operand_layouts=[layout] * 2, - result_layouts=[layout, work_layout, work_layout], - operand_output_aliases={1: 0}) - return out[0] - -cuda_trsm = partial(_trsm_mhlo, "cu", _cublas) -rocm_trsm = partial(_trsm_mhlo, "hip", _hipblas) - - def _potrf_mhlo(platform, gpu_solver, dtype, a, lower): """Cholesky decomposition.""" a_type = ir.RankedTensorType(a.type) diff --git a/jaxlib/rocm/hipblas.cc b/jaxlib/rocm/hipblas.cc index 3bb093cbb..aa880f38c 100644 --- a/jaxlib/rocm/hipblas.cc +++ b/jaxlib/rocm/hipblas.cc @@ -50,24 +50,6 @@ HipblasType DtypeToHipblasType(const py::dtype& np_type) { return it->second; } -// Returns the descriptor for a TrsmBatched operation. -std::pair -BuildTrsmBatchedDescriptor(const py::dtype& dtype, int batch, int m, int n, - bool left_side, bool lower, bool trans_a, - bool conj_a, bool unit_diagonal) { - size_t size = batch * sizeof(void*); - TrsmBatchedDescriptor desc; - desc.type = DtypeToHipblasType(dtype); - desc.batch = batch; - desc.m = m; - desc.n = n; - desc.side = left_side ? HIPBLAS_SIDE_LEFT : HIPBLAS_SIDE_RIGHT; - desc.uplo = lower ? HIPBLAS_FILL_MODE_LOWER : HIPBLAS_FILL_MODE_UPPER; - desc.trans = trans_a ? (conj_a ? HIPBLAS_OP_C : HIPBLAS_OP_T) : HIPBLAS_OP_N; - desc.diag = unit_diagonal ? HIPBLAS_DIAG_UNIT : HIPBLAS_DIAG_NON_UNIT; - return {size, PackDescriptor(desc)}; -} - // Returns the descriptor for a GetrfBatched operation. std::pair BuildGetrfBatchedDescriptor(const py::dtype& dtype, int b, int n) { @@ -86,7 +68,6 @@ std::pair BuildGeqrfBatchedDescriptor(const py::dtype& dtype, py::dict Registrations() { py::dict dict; - dict["hipblas_trsm_batched"] = EncapsulateFunction(TrsmBatched); dict["hipblas_getrf_batched"] = EncapsulateFunction(GetrfBatched); dict["hipblas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched); return dict; @@ -94,7 +75,6 @@ py::dict Registrations() { PYBIND11_MODULE(_hipblas, m) { m.def("registrations", &Registrations); - m.def("build_trsm_batched_descriptor", &BuildTrsmBatchedDescriptor); m.def("build_getrf_batched_descriptor", &BuildGetrfBatchedDescriptor); m.def("build_geqrf_batched_descriptor", &BuildGeqrfBatchedDescriptor); } diff --git a/jaxlib/rocm/hipblas_kernels.cc b/jaxlib/rocm/hipblas_kernels.cc index 948e7dad8..10ab69500 100644 --- a/jaxlib/rocm/hipblas_kernels.cc +++ b/jaxlib/rocm/hipblas_kernels.cc @@ -72,90 +72,6 @@ int SizeOfHipblasType(HipblasType type) { } // namespace -// Batched triangular solve: trsmbatched - -static absl::Status TrsmBatched_(hipStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const TrsmBatchedDescriptor& d = **s; - auto h = BlasHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[2] != buffers[1]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync( - buffers[2], buffers[1], SizeOfHipblasType(d.type) * d.batch * d.m * d.n, - hipMemcpyDeviceToDevice, stream))); - } - const int lda = d.side == HIPBLAS_SIDE_LEFT ? d.m : d.n; - const int ldb = d.m; - auto a_batch_host = MakeBatchPointers(stream, buffers[0], buffers[3], d.batch, - SizeOfHipblasType(d.type) * lda * lda); - JAX_RETURN_IF_ERROR(a_batch_host.status()); - auto b_batch_host = MakeBatchPointers(stream, buffers[2], buffers[4], d.batch, - SizeOfHipblasType(d.type) * d.m * d.n); - JAX_RETURN_IF_ERROR(b_batch_host.status()); - // TODO(phawkins): ideally we would not need to synchronize here, but to - // avoid it we need a way to keep the host-side buffer alive until the copy - // completes. - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipStreamSynchronize(stream))); - switch (d.type) { - case HipblasType::F32: { - float** a_batch_ptrs = static_cast(buffers[3]); - float** b_batch_ptrs = static_cast(buffers[4]); - // TODO(reza): is the following statement correct for rocm? - // NOTE(phawkins): if alpha is in GPU memory, cuBlas seems to segfault. - const float alpha = 1.0f; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasStrsmBatched( - handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha, - const_cast(a_batch_ptrs), lda, b_batch_ptrs, ldb, d.batch))); - break; - } - case HipblasType::F64: { - double** a_batch_ptrs = static_cast(buffers[3]); - double** b_batch_ptrs = static_cast(buffers[4]); - const double alpha = 1.0; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasDtrsmBatched( - handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha, - const_cast(a_batch_ptrs), lda, b_batch_ptrs, ldb, - d.batch))); - break; - } - case HipblasType::C64: { - hipblasComplex** a_batch_ptrs = static_cast(buffers[3]); - hipblasComplex** b_batch_ptrs = static_cast(buffers[4]); - const hipblasComplex alpha = hipblasComplex(1.0f, 0.0f); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasCtrsmBatched( - handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha, - const_cast(a_batch_ptrs), lda, b_batch_ptrs, ldb, - d.batch))); - break; - } - case HipblasType::C128: { - hipblasDoubleComplex** a_batch_ptrs = - static_cast(buffers[3]); - hipblasDoubleComplex** b_batch_ptrs = - static_cast(buffers[4]); - const hipblasDoubleComplex alpha = hipblasDoubleComplex(1.0f, 0.0f); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasZtrsmBatched( - handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha, - const_cast(a_batch_ptrs), lda, b_batch_ptrs, - ldb, d.batch))); - break; - } - } - return absl::OkStatus(); -} - -void TrsmBatched(hipStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = TrsmBatched_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // Batched LU decomposition: getrfbatched static absl::Status GetrfBatched_(hipStream_t stream, void** buffers, diff --git a/jaxlib/rocm/hipblas_kernels.h b/jaxlib/rocm/hipblas_kernels.h index 1b7f691ed..f377665dc 100644 --- a/jaxlib/rocm/hipblas_kernels.h +++ b/jaxlib/rocm/hipblas_kernels.h @@ -32,20 +32,6 @@ enum class HipblasType { C128, }; -// Batched triangular solve: trsmbatched - -struct TrsmBatchedDescriptor { - HipblasType type; - int batch, m, n; - hipblasSideMode_t side; - hipblasFillMode_t uplo; - hipblasOperation_t trans; - hipblasDiagType_t diag; -}; - -void TrsmBatched(hipStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - // Batched LU decomposition: getrfbatched struct GetrfBatchedDescriptor {