Port Triangular Solve to XLA's FFI

This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 655484166
This commit is contained in:
Paweł Paruzel 2024-07-24 02:14:40 -07:00 committed by jax authors
parent 832eb2d8d2
commit 54fe6e68a0
6 changed files with 115 additions and 4 deletions

View File

@ -118,6 +118,10 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
// FFI Kernels
JAX_CPU_REGISTER_HANDLER(blas_strsm_ffi);
JAX_CPU_REGISTER_HANDLER(blas_dtrsm_ffi);
JAX_CPU_REGISTER_HANDLER(blas_ctrsm_ffi);
JAX_CPU_REGISTER_HANDLER(blas_ztrsm_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_sgetrf_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dgetrf_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_cgetrf_ffi);

View File

@ -69,6 +69,10 @@ void GetLapackKernelsFromScipy() {
AssignKernelFn<Trsm<double>>(blas_ptr("dtrsm"));
AssignKernelFn<Trsm<std::complex<float>>>(blas_ptr("ctrsm"));
AssignKernelFn<Trsm<std::complex<double>>>(blas_ptr("ztrsm"));
AssignKernelFn<TriMatrixEquationSolver<DataType::F32>>(blas_ptr("strsm"));
AssignKernelFn<TriMatrixEquationSolver<DataType::F64>>(blas_ptr("dtrsm"));
AssignKernelFn<TriMatrixEquationSolver<DataType::C64>>(blas_ptr("ctrsm"));
AssignKernelFn<TriMatrixEquationSolver<DataType::C128>>(blas_ptr("ztrsm"));
nb::module_ cython_lapack =
nb::module_::import_("scipy.linalg.cython_lapack");
@ -219,6 +223,10 @@ nb::dict Registrations() {
dict["lapack_zhetrd"] =
EncapsulateFunction(Sytrd<std::complex<double>>::Kernel);
dict["blas_strsm_ffi"] = EncapsulateFunction(blas_strsm_ffi);
dict["blas_dtrsm_ffi"] = EncapsulateFunction(blas_dtrsm_ffi);
dict["blas_ctrsm_ffi"] = EncapsulateFunction(blas_ctrsm_ffi);
dict["blas_ztrsm_ffi"] = EncapsulateFunction(blas_ztrsm_ffi);
dict["lapack_sgetrf_ffi"] = EncapsulateFunction(lapack_sgetrf_ffi);
dict["lapack_dgetrf_ffi"] = EncapsulateFunction(lapack_dgetrf_ffi);
dict["lapack_cgetrf_ffi"] = EncapsulateFunction(lapack_cgetrf_ffi);

View File

@ -23,6 +23,18 @@ namespace jax {
// FFI Definition Macros (by DataType)
// Defines an XLA FFI handler `name` bound to
// TriMatrixEquationSolver<data_type>::Kernel, i.e. a BLAS ?trsm triangular
// solve. Buffer arguments: x (the triangular matrix), y (the right-hand
// side), alpha (R0 scaling factor); result buffer: y_out. The four attrs
// (side/uplo/trans_x/diag) correspond to the ?trsm character flags.
// NOTE: no `//` comments inside the macro body — they would swallow the
// trailing line-continuation backslashes.
#define JAX_CPU_DEFINE_TRSM(name, data_type) \
XLA_FFI_DEFINE_HANDLER(name, TriMatrixEquationSolver<data_type>::Kernel, \
::xla::ffi::Ffi::Bind() \
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
.Arg<::xla::ffi::Buffer<data_type>>(/*y*/) \
.Arg<::xla::ffi::BufferR0<data_type>>(/*alpha*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*y_out*/) \
.Attr<MatrixParams::Side>("side") \
.Attr<MatrixParams::UpLo>("uplo") \
.Attr<MatrixParams::Transpose>("trans_x") \
.Attr<MatrixParams::Diag>("diag"))
#define JAX_CPU_DEFINE_GETRF(name, data_type) \
XLA_FFI_DEFINE_HANDLER( \
name, LuDecomposition<data_type>::Kernel, \
@ -92,6 +104,11 @@ namespace jax {
// FFI Handlers
JAX_CPU_DEFINE_TRSM(blas_strsm_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_TRSM(blas_dtrsm_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_TRSM(blas_ctrsm_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_TRSM(blas_ztrsm_ffi, ::xla::ffi::DataType::C128);
JAX_CPU_DEFINE_GETRF(lapack_sgetrf_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_GETRF(lapack_dgetrf_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_GETRF(lapack_cgetrf_ffi, ::xla::ffi::DataType::C64);
@ -117,6 +134,7 @@ JAX_CPU_DEFINE_GESDD(lapack_dgesdd_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_cgesdd_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_zgesdd_ffi, ::xla::ffi::DataType::C128);
#undef JAX_CPU_DEFINE_TRSM
#undef JAX_CPU_DEFINE_GETRF
#undef JAX_CPU_DEFINE_GEQRF
#undef JAX_CPU_DEFINE_ORGQR

View File

@ -160,6 +160,48 @@ template struct Trsm<double>;
template struct Trsm<std::complex<float>>;
template struct Trsm<std::complex<double>>;
// FFI Kernel
// Batched triangular solve via BLAS ?trsm: overwrites y_out with the
// solution of op(x) * out = alpha * y (or out * op(x) = alpha * y,
// depending on `side`). x holds `batch_count` square triangular matrices,
// y the matching right-hand sides.
template <ffi::DataType dtype>
ffi::Error TriMatrixEquationSolver<dtype>::Kernel(
    ffi::Buffer<dtype> x, ffi::Buffer<dtype> y, ffi::BufferR0<dtype> alpha,
    ffi::ResultBuffer<dtype> y_out, MatrixParams::Side side,
    MatrixParams::UpLo uplo, MatrixParams::Transpose trans_x,
    MatrixParams::Diag diag) {
  // trsm solves in place, so the RHS must first be materialized in y_out.
  CopyIfDiffBuffer(y, y_out);

  auto [batch_count, y_rows, y_cols] = SplitBatch2D(y.dimensions());
  FFI_ASSIGN_OR_RETURN(auto rows_v, MaybeCastNoOverflow<lapack_int>(y_rows));
  FFI_ASSIGN_OR_RETURN(auto cols_v, MaybeCastNoOverflow<lapack_int>(y_cols));

  // x is square; its order equals y's row count when multiplying from the
  // left and y's column count when multiplying from the right.
  lapack_int x_ld = side == MatrixParams::Side::kLeft ? y_rows : y_cols;
  lapack_int y_ld = y_rows;

  auto side_c = static_cast<char>(side);
  auto uplo_c = static_cast<char>(uplo);
  auto trans_c = static_cast<char>(trans_x);
  auto diag_c = static_cast<char>(diag);

  auto* alpha_data = alpha.typed_data();
  auto* x_ptr = x.typed_data();
  auto* out_ptr = y_out->typed_data();
  const int64_t out_stride = y_rows * y_cols;
  const int64_t x_stride = static_cast<int64_t>(x_ld) * x_ld;
  for (int64_t batch = 0; batch < batch_count; ++batch) {
    fn(&side_c, &uplo_c, &trans_c, &diag_c, &rows_v, &cols_v, alpha_data,
       x_ptr, &x_ld, out_ptr, &y_ld);
    x_ptr += x_stride;
    out_ptr += out_stride;
  }
  return ffi::Error::Success();
}

template struct TriMatrixEquationSolver<ffi::DataType::F32>;
template struct TriMatrixEquationSolver<ffi::DataType::F64>;
template struct TriMatrixEquationSolver<ffi::DataType::C64>;
template struct TriMatrixEquationSolver<ffi::DataType::C128>;
//== LU Decomposition ==//
// lapack getrf

View File

@ -109,6 +109,24 @@ struct Trsm {
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};
// FFI Kernel
// FFI wrapper for the BLAS ?trsm triangular matrix equation solver.
// One instantiation per dtype (F32/F64/C64/C128); the concrete BLAS symbol
// (strsm_/dtrsm_/ctrsm_/ztrsm_) is installed into `fn` at startup.
template <::xla::ffi::DataType dtype>
struct TriMatrixEquationSolver {
  using ValueType = ::xla::ffi::NativeType<dtype>;
  // Fortran ?trsm signature: solve op(a)*x = alpha*b (side-dependent),
  // overwriting b with the solution.
  using FnType = void(char* side, char* uplo, char* transa, char* diag,
                      lapack_int* m, lapack_int* n, ValueType* alpha,
                      ValueType* a, lapack_int* lda, ValueType* b,
                      lapack_int* ldb);
  // Set via AssignKernelFn before the kernel is invoked; nullptr until then.
  inline static FnType* fn = nullptr;
  // FFI entry point: x is the triangular matrix, y the right-hand side,
  // alpha the scalar factor; the solution is written to y_out.
  static ::xla::ffi::Error Kernel(
      ::xla::ffi::Buffer<dtype> x, ::xla::ffi::Buffer<dtype> y,
      ::xla::ffi::BufferR0<dtype> alpha, ::xla::ffi::ResultBuffer<dtype> y_out,
      MatrixParams::Side side, MatrixParams::UpLo uplo,
      MatrixParams::Transpose trans_x, MatrixParams::Diag diag);
};
//== LU Decomposition ==//
// lapack getrf

View File

@ -26,10 +26,10 @@ namespace ffi = xla::ffi;
extern "C" {
jax::Trsm<float>::FnType strsm_;
jax::Trsm<double>::FnType dtrsm_;
jax::Trsm<std::complex<float>>::FnType ctrsm_;
jax::Trsm<std::complex<double>>::FnType ztrsm_;
jax::TriMatrixEquationSolver<ffi::DataType::F32>::FnType strsm_;
jax::TriMatrixEquationSolver<ffi::DataType::F64>::FnType dtrsm_;
jax::TriMatrixEquationSolver<ffi::DataType::C64>::FnType ctrsm_;
jax::TriMatrixEquationSolver<ffi::DataType::C128>::FnType ztrsm_;
jax::LuDecomposition<ffi::DataType::F32>::FnType sgetrf_;
jax::LuDecomposition<ffi::DataType::F64>::FnType dgetrf_;
@ -87,6 +87,22 @@ namespace jax {
#define JAX_KERNEL_FNTYPE_MISMATCH_MSG "FFI Kernel FnType mismatch"
// The new FFI kernels and the legacy custom-call kernels share the same
// extern "C" BLAS symbols (strsm_/dtrsm_/ctrsm_/ztrsm_), so their declared
// function-pointer types must be identical for every dtype pairing.
static_assert(
    std::is_same_v<jax::TriMatrixEquationSolver<ffi::DataType::F32>::FnType,
                   jax::Trsm<float>::FnType>,
    JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(
    std::is_same_v<jax::TriMatrixEquationSolver<ffi::DataType::F64>::FnType,
                   jax::Trsm<double>::FnType>,
    JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(
    std::is_same_v<jax::TriMatrixEquationSolver<ffi::DataType::C64>::FnType,
                   jax::Trsm<std::complex<float>>::FnType>,
    JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(
    std::is_same_v<jax::TriMatrixEquationSolver<ffi::DataType::C128>::FnType,
                   jax::Trsm<std::complex<double>>::FnType>,
    JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(std::is_same_v<jax::LuDecomposition<ffi::DataType::F32>::FnType,
jax::Getrf<float>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
@ -218,6 +234,11 @@ static auto init = []() -> int {
// FFI Kernels
AssignKernelFn<TriMatrixEquationSolver<ffi::DataType::F32>>(strsm_);
AssignKernelFn<TriMatrixEquationSolver<ffi::DataType::F64>>(dtrsm_);
AssignKernelFn<TriMatrixEquationSolver<ffi::DataType::C64>>(ctrsm_);
AssignKernelFn<TriMatrixEquationSolver<ffi::DataType::C128>>(ztrsm_);
AssignKernelFn<LuDecomposition<ffi::DataType::F32>>(sgetrf_);
AssignKernelFn<LuDecomposition<ffi::DataType::F64>>(dgetrf_);
AssignKernelFn<LuDecomposition<ffi::DataType::C64>>(cgetrf_);