Remove JAX custom call implementation of batched triangular solve.

XLA has supported batched triangular solve on GPU since February 2022, which predates the minimum jaxlib version we support. We can therefore delete our custom call implementation and simply use XLA's.

PiperOrigin-RevId: 482031830
Authored by Peter Hawkins on 2022-10-18 15:03:38 -07:00, committed by jax authors
parent d20b9fa498
commit 5617a02fa4
10 changed files with 0 additions and 384 deletions
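For context, batched triangular solves on GPU now go through XLA's TriangularSolve op for all batch sizes. A minimal sketch using only the public JAX API (the shapes and values below are illustrative, not taken from this change):

import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

# A batch of 8 lower-triangular 4x4 systems. With the custom call removed,
# this lowers to XLA's TriangularSolve, which handles the leading batch
# dimension directly on GPU.
a = jnp.tril(jnp.ones((8, 4, 4))) + 4.0 * jnp.eye(4)
b = jnp.ones((8, 4, 2))
x = jax.vmap(lambda ai, bi: solve_triangular(ai, bi, lower=True))(a, b)
print(x.shape)  # (8, 4, 2)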

@@ -915,39 +915,6 @@ mlir.register_lowering(triangular_solve_p, _triangular_solve_cpu_lower,
                       platform='cpu')


def _triangular_solve_gpu_lower(
    trsm_impl, ctx, a, b, *, left_side, lower, transpose_a, conjugate_a,
    unit_diagonal):
  a_aval, _ = ctx.avals_in
  m, n = a_aval.shape[-2:]
  batch = prod(a_aval.shape[:-2])
  if conjugate_a and not transpose_a:
    a = chlo.ConjOp(a).result
    conjugate_a = False
  if batch > 1 and m <= 256 and n <= 256:
    return [trsm_impl(a_aval.dtype, a, b, left_side, lower, transpose_a,
                      conjugate_a, unit_diagonal)]
  else:
    # Use the XLA implementation for unbatched triangular_solve.
    if transpose_a:
      transpose = "ADJOINT" if conjugate_a else "TRANSPOSE"
    else:
      transpose = "NO_TRANSPOSE"
    return mhlo.TriangularSolveOp(b.type, a, b, ir.BoolAttr.get(left_side),
                                  ir.BoolAttr.get(lower),
                                  ir.BoolAttr.get(unit_diagonal),
                                  mhlo.TransposeAttr.get(transpose)).results

mlir.register_lowering(
    triangular_solve_p,
    partial(_triangular_solve_gpu_lower, gpu_solver.cuda_trsm),
    platform='cuda')
mlir.register_lowering(
    triangular_solve_p,
    partial(_triangular_solve_gpu_lower, gpu_solver.rocm_trsm),
    platform='rocm')
# Support operation for LU decomposition: Transformation of the pivots returned
# by LU decomposition into permutations.
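With the GPU-specific registration above removed, CUDA and ROCm fall back to the default triangular_solve lowering, i.e. XLA's TriangularSolve op handles both the batched and unbatched cases. A rough sketch of that remaining path, modeled on the else-branch of the deleted function (the function name and import paths below are illustrative, not part of this diff):

from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo

def _triangular_solve_lowering_sketch(ctx, a, b, *, left_side, lower,
                                      transpose_a, conjugate_a, unit_diagonal):
  # Emit the MHLO TriangularSolve op and let XLA handle both the batched and
  # unbatched cases on GPU. (Conjugation without transposition would still
  # need an explicit conj of `a` first, as in the deleted code above.)
  if transpose_a:
    transpose = "ADJOINT" if conjugate_a else "TRANSPOSE"
  else:
    transpose = "NO_TRANSPOSE"
  return mhlo.TriangularSolveOp(b.type, a, b, ir.BoolAttr.get(left_side),
                                ir.BoolAttr.get(lower),
                                ir.BoolAttr.get(unit_diagonal),
                                mhlo.TransposeAttr.get(transpose)).results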

@@ -51,23 +51,6 @@ CublasType DtypeToCublasType(const py::dtype& np_type) {
return it->second;
}
// Returns the descriptor for a TrsmBatched operation.
std::pair<size_t, py::bytes> BuildTrsmBatchedDescriptor(
const py::dtype& dtype, int batch, int m, int n, bool left_side, bool lower,
bool trans_a, bool conj_a, bool unit_diagonal) {
size_t size = batch * sizeof(void*);
TrsmBatchedDescriptor desc;
desc.type = DtypeToCublasType(dtype);
desc.batch = batch;
desc.m = m;
desc.n = n;
desc.side = left_side ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
desc.uplo = lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
desc.trans = trans_a ? (conj_a ? CUBLAS_OP_C : CUBLAS_OP_T) : CUBLAS_OP_N;
desc.diag = unit_diagonal ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;
return {size, PackDescriptor(desc)};
}
// Returns the descriptor for a GetrfBatched operation.
std::pair<size_t, py::bytes> BuildGetrfBatchedDescriptor(const py::dtype& dtype,
int b, int n) {
@@ -86,7 +69,6 @@ std::pair<size_t, py::bytes> BuildGeqrfBatchedDescriptor(const py::dtype& dtype,
py::dict Registrations() {
py::dict dict;
dict["cublas_trsm_batched"] = EncapsulateFunction(TrsmBatched);
dict["cublas_getrf_batched"] = EncapsulateFunction(GetrfBatched);
dict["cublas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched);
return dict;
@@ -94,7 +76,6 @@ py::dict Registrations() {
PYBIND11_MODULE(_cublas, m) {
m.def("registrations", &Registrations);
m.def("build_trsm_batched_descriptor", &BuildTrsmBatchedDescriptor);
m.def("build_getrf_batched_descriptor", &BuildGetrfBatchedDescriptor);
m.def("build_geqrf_batched_descriptor", &BuildGeqrfBatchedDescriptor);
}

@@ -73,90 +73,6 @@ int SizeOfCublasType(CublasType type) {
} // namespace
// Batched triangular solve: trsmbatched
static absl::Status TrsmBatched_(cudaStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<TrsmBatchedDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const TrsmBatchedDescriptor& d = **s;
auto h = BlasHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[2] != buffers[1]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
buffers[2], buffers[1], SizeOfCublasType(d.type) * d.batch * d.m * d.n,
cudaMemcpyDeviceToDevice, stream)));
}
const int lda = d.side == CUBLAS_SIDE_LEFT ? d.m : d.n;
const int ldb = d.m;
auto a_batch_host = MakeBatchPointers(stream, buffers[0], buffers[3], d.batch,
SizeOfCublasType(d.type) * lda * lda);
JAX_RETURN_IF_ERROR(a_batch_host.status());
auto b_batch_host = MakeBatchPointers(stream, buffers[2], buffers[4], d.batch,
SizeOfCublasType(d.type) * d.m * d.n);
JAX_RETURN_IF_ERROR(b_batch_host.status());
// TODO(phawkins): ideally we would not need to synchronize here, but to
// avoid it we need a way to keep the host-side buffer alive until the copy
// completes.
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaStreamSynchronize(stream)));
switch (d.type) {
case CublasType::F32: {
float** a_batch_ptrs = static_cast<float**>(buffers[3]);
float** b_batch_ptrs = static_cast<float**>(buffers[4]);
// NOTE(phawkins): if alpha is in GPU memory, cuBlas seems to segfault.
const float alpha = 1.0f;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasStrsmBatched(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<const float**>(a_batch_ptrs), lda, b_batch_ptrs, ldb,
d.batch)));
break;
}
case CublasType::F64: {
double** a_batch_ptrs = static_cast<double**>(buffers[3]);
double** b_batch_ptrs = static_cast<double**>(buffers[4]);
const double alpha = 1.0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasDtrsmBatched(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<const double**>(a_batch_ptrs), lda, b_batch_ptrs, ldb,
d.batch)));
break;
}
case CublasType::C64: {
cuComplex** a_batch_ptrs = static_cast<cuComplex**>(buffers[3]);
cuComplex** b_batch_ptrs = static_cast<cuComplex**>(buffers[4]);
const cuComplex alpha = make_cuComplex(1.0f, 0.0f);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasCtrsmBatched(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<const cuComplex**>(a_batch_ptrs), lda, b_batch_ptrs, ldb,
d.batch)));
break;
}
case CublasType::C128: {
cuDoubleComplex** a_batch_ptrs =
static_cast<cuDoubleComplex**>(buffers[3]);
cuDoubleComplex** b_batch_ptrs =
static_cast<cuDoubleComplex**>(buffers[4]);
const cuDoubleComplex alpha = make_cuDoubleComplex(1.0f, 0.0f);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasZtrsmBatched(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<const cuDoubleComplex**>(a_batch_ptrs), lda, b_batch_ptrs,
ldb, d.batch)));
break;
}
}
return absl::OkStatus();
}
void TrsmBatched(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = TrsmBatched_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// Batched LU decomposition: getrfbatched
static absl::Status GetrfBatched_(cudaStream_t stream, void** buffers,

@@ -33,21 +33,6 @@ enum class CublasType {
C128,
};
// Batched triangular solve: trsmbatched
struct TrsmBatchedDescriptor {
CublasType type;
int batch, m, n;
cublasSideMode_t side;
cublasFillMode_t uplo;
cublasOperation_t trans;
cublasDiagType_t diag;
};
void TrsmBatched(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// Batched LU decomposition: getrfbatched
struct GetrfBatchedDescriptor {

@@ -26,8 +26,6 @@ limitations under the License.
namespace jax {
namespace {
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_trsm_batched", TrsmBatched,
"CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_getrf_batched", GetrfBatched,
"CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cuda_lu_pivots_to_permutation",

@@ -1,70 +0,0 @@
/* Copyright 2021 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file is not used by JAX itself, but exists to assist with running
// JAX-generated HLO code from outside of JAX.
#include "jaxlib/cuda/cublas_kernels.h"
#include "jaxlib/cuda/cuda_lu_pivot_kernels.h"
#include "jaxlib/cuda/cuda_prng_kernels.h"
#include "jaxlib/cuda/cusolver_kernels.h"
#include "jaxlib/cuda/cusparse_kernels.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
namespace jax {
namespace {
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_trsm_batched", TrsmBatched,
"CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_getrf_batched", GetrfBatched,
"CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cuda_lu_pivots_to_permutation",
CudaLuPivotsToPermutation, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cuda_threefry2x32", CudaThreeFry2x32,
"CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_potrf", Potrf, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_getrf", Getrf, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_geqrf", Geqrf, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_orgqr", Orgqr, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevd", Syevd, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevj", Syevj, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvd", Gesvd, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvdj", Gesvdj, "CUDA");
#if JAX_CUSPARSE_11300
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_todense", CsrToDense,
"CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_fromdense", CsrFromDense,
"CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_matvec", CsrMatvec,
"CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_matmat", CsrMatmat,
"CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_coo_todense", CooToDense,
"CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_coo_fromdense", CooFromDense,
"CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_coo_matvec", CooMatvec,
"CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_coo_matmat", CooMatmat,
"CUDA");
#endif
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_gtsv2_f32", gtsv2_f32,
"CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_gtsv2_f64", gtsv2_f64,
"CUDA");
} // namespace
} // namespace jax

@@ -63,49 +63,6 @@ def _real_type(dtype):
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)


def _trsm_mhlo(platform, gpu_blas, dtype, a, b, left_side=False, lower=False,
               trans_a=False, conj_a=False, diag=False):
  """Batched triangular solve.

  XLA implements unbatched triangular solve directly, so we need only implement
  the batched case."""
  b_type = ir.RankedTensorType(b.type)
  dims = b_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  batch = _prod(batch_dims)
  k = m if left_side else n

  a_type = ir.RankedTensorType(a.type)
  if (batch_dims + (k, k) != tuple(a_type.shape) or
      a_type.element_type != b_type.element_type):
    raise ValueError("Argument mismatch for trsm, got {} and {}".format(
        a_type, b_type))

  if conj_a and not trans_a:
    raise NotImplementedError("Conjugation without transposition not supported")

  lwork, opaque = gpu_blas.build_trsm_batched_descriptor(
      np.dtype(dtype), batch, m, n, left_side, lower, trans_a, conj_a, diag)
  layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
  work_type = ir.RankedTensorType.get([lwork], ir.IntegerType.get_signless(8))
  work_layout = [0]
  out = custom_call(
      f"{platform}blas_trsm_batched",
      [b_type, work_type, work_type],
      [a, b],
      backend_config=opaque,
      operand_layouts=[layout] * 2,
      result_layouts=[layout, work_layout, work_layout],
      operand_output_aliases={1: 0})
  return out[0]

cuda_trsm = partial(_trsm_mhlo, "cu", _cublas)
rocm_trsm = partial(_trsm_mhlo, "hip", _hipblas)


def _potrf_mhlo(platform, gpu_solver, dtype, a, lower):
  """Cholesky decomposition."""
  a_type = ir.RankedTensorType(a.type)

@@ -50,24 +50,6 @@ HipblasType DtypeToHipblasType(const py::dtype& np_type) {
return it->second;
}
// Returns the descriptor for a TrsmBatched operation.
std::pair<size_t, py::bytes>
BuildTrsmBatchedDescriptor(const py::dtype& dtype, int batch, int m, int n,
bool left_side, bool lower, bool trans_a,
bool conj_a, bool unit_diagonal) {
size_t size = batch * sizeof(void*);
TrsmBatchedDescriptor desc;
desc.type = DtypeToHipblasType(dtype);
desc.batch = batch;
desc.m = m;
desc.n = n;
desc.side = left_side ? HIPBLAS_SIDE_LEFT : HIPBLAS_SIDE_RIGHT;
desc.uplo = lower ? HIPBLAS_FILL_MODE_LOWER : HIPBLAS_FILL_MODE_UPPER;
desc.trans = trans_a ? (conj_a ? HIPBLAS_OP_C : HIPBLAS_OP_T) : HIPBLAS_OP_N;
desc.diag = unit_diagonal ? HIPBLAS_DIAG_UNIT : HIPBLAS_DIAG_NON_UNIT;
return {size, PackDescriptor(desc)};
}
// Returns the descriptor for a GetrfBatched operation.
std::pair<size_t, py::bytes> BuildGetrfBatchedDescriptor(const py::dtype& dtype,
int b, int n) {
@@ -86,7 +68,6 @@ std::pair<size_t, py::bytes> BuildGeqrfBatchedDescriptor(const py::dtype& dtype,
py::dict Registrations() {
py::dict dict;
dict["hipblas_trsm_batched"] = EncapsulateFunction(TrsmBatched);
dict["hipblas_getrf_batched"] = EncapsulateFunction(GetrfBatched);
dict["hipblas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched);
return dict;
@@ -94,7 +75,6 @@ py::dict Registrations() {
PYBIND11_MODULE(_hipblas, m) {
m.def("registrations", &Registrations);
m.def("build_trsm_batched_descriptor", &BuildTrsmBatchedDescriptor);
m.def("build_getrf_batched_descriptor", &BuildGetrfBatchedDescriptor);
m.def("build_geqrf_batched_descriptor", &BuildGeqrfBatchedDescriptor);
}

@@ -72,90 +72,6 @@ int SizeOfHipblasType(HipblasType type) {
} // namespace
// Batched triangular solve: trsmbatched
static absl::Status TrsmBatched_(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<TrsmBatchedDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const TrsmBatchedDescriptor& d = **s;
auto h = BlasHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[2] != buffers[1]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
buffers[2], buffers[1], SizeOfHipblasType(d.type) * d.batch * d.m * d.n,
hipMemcpyDeviceToDevice, stream)));
}
const int lda = d.side == HIPBLAS_SIDE_LEFT ? d.m : d.n;
const int ldb = d.m;
auto a_batch_host = MakeBatchPointers(stream, buffers[0], buffers[3], d.batch,
SizeOfHipblasType(d.type) * lda * lda);
JAX_RETURN_IF_ERROR(a_batch_host.status());
auto b_batch_host = MakeBatchPointers(stream, buffers[2], buffers[4], d.batch,
SizeOfHipblasType(d.type) * d.m * d.n);
JAX_RETURN_IF_ERROR(b_batch_host.status());
// TODO(phawkins): ideally we would not need to synchronize here, but to
// avoid it we need a way to keep the host-side buffer alive until the copy
// completes.
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipStreamSynchronize(stream)));
switch (d.type) {
case HipblasType::F32: {
float** a_batch_ptrs = static_cast<float**>(buffers[3]);
float** b_batch_ptrs = static_cast<float**>(buffers[4]);
// TODO(reza): is the following statement correct for rocm?
// NOTE(phawkins): if alpha is in GPU memory, cuBlas seems to segfault.
const float alpha = 1.0f;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasStrsmBatched(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<float**>(a_batch_ptrs), lda, b_batch_ptrs, ldb, d.batch)));
break;
}
case HipblasType::F64: {
double** a_batch_ptrs = static_cast<double**>(buffers[3]);
double** b_batch_ptrs = static_cast<double**>(buffers[4]);
const double alpha = 1.0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasDtrsmBatched(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<double**>(a_batch_ptrs), lda, b_batch_ptrs, ldb,
d.batch)));
break;
}
case HipblasType::C64: {
hipblasComplex** a_batch_ptrs = static_cast<hipblasComplex**>(buffers[3]);
hipblasComplex** b_batch_ptrs = static_cast<hipblasComplex**>(buffers[4]);
const hipblasComplex alpha = hipblasComplex(1.0f, 0.0f);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasCtrsmBatched(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<hipblasComplex**>(a_batch_ptrs), lda, b_batch_ptrs, ldb,
d.batch)));
break;
}
case HipblasType::C128: {
hipblasDoubleComplex** a_batch_ptrs =
static_cast<hipblasDoubleComplex**>(buffers[3]);
hipblasDoubleComplex** b_batch_ptrs =
static_cast<hipblasDoubleComplex**>(buffers[4]);
const hipblasDoubleComplex alpha = hipblasDoubleComplex(1.0f, 0.0f);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasZtrsmBatched(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<hipblasDoubleComplex**>(a_batch_ptrs), lda, b_batch_ptrs,
ldb, d.batch)));
break;
}
}
return absl::OkStatus();
}
void TrsmBatched(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = TrsmBatched_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// Batched LU decomposition: getrfbatched
static absl::Status GetrfBatched_(hipStream_t stream, void** buffers,

@@ -32,20 +32,6 @@ enum class HipblasType {
C128,
};
// Batched triangular solve: trsmbatched
struct TrsmBatchedDescriptor {
HipblasType type;
int batch, m, n;
hipblasSideMode_t side;
hipblasFillMode_t uplo;
hipblasOperation_t trans;
hipblasDiagType_t diag;
};
void TrsmBatched(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// Batched LU decomposition: getrfbatched
struct GetrfBatchedDescriptor {