mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Port Triangular Solve to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks. PiperOrigin-RevId: 655484166
This commit is contained in:
parent
832eb2d8d2
commit
54fe6e68a0
@ -118,6 +118,10 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
|
||||
|
||||
// FFI Kernels
|
||||
|
||||
JAX_CPU_REGISTER_HANDLER(blas_strsm_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(blas_dtrsm_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(blas_ctrsm_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(blas_ztrsm_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_sgetrf_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_dgetrf_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_cgetrf_ffi);
|
||||
|
@ -69,6 +69,10 @@ void GetLapackKernelsFromScipy() {
|
||||
AssignKernelFn<Trsm<double>>(blas_ptr("dtrsm"));
|
||||
AssignKernelFn<Trsm<std::complex<float>>>(blas_ptr("ctrsm"));
|
||||
AssignKernelFn<Trsm<std::complex<double>>>(blas_ptr("ztrsm"));
|
||||
AssignKernelFn<TriMatrixEquationSolver<DataType::F32>>(blas_ptr("strsm"));
|
||||
AssignKernelFn<TriMatrixEquationSolver<DataType::F64>>(blas_ptr("dtrsm"));
|
||||
AssignKernelFn<TriMatrixEquationSolver<DataType::C64>>(blas_ptr("ctrsm"));
|
||||
AssignKernelFn<TriMatrixEquationSolver<DataType::C128>>(blas_ptr("ztrsm"));
|
||||
|
||||
nb::module_ cython_lapack =
|
||||
nb::module_::import_("scipy.linalg.cython_lapack");
|
||||
@ -219,6 +223,10 @@ nb::dict Registrations() {
|
||||
dict["lapack_zhetrd"] =
|
||||
EncapsulateFunction(Sytrd<std::complex<double>>::Kernel);
|
||||
|
||||
dict["blas_strsm_ffi"] = EncapsulateFunction(blas_strsm_ffi);
|
||||
dict["blas_dtrsm_ffi"] = EncapsulateFunction(blas_dtrsm_ffi);
|
||||
dict["blas_ctrsm_ffi"] = EncapsulateFunction(blas_ctrsm_ffi);
|
||||
dict["blas_ztrsm_ffi"] = EncapsulateFunction(blas_ztrsm_ffi);
|
||||
dict["lapack_sgetrf_ffi"] = EncapsulateFunction(lapack_sgetrf_ffi);
|
||||
dict["lapack_dgetrf_ffi"] = EncapsulateFunction(lapack_dgetrf_ffi);
|
||||
dict["lapack_cgetrf_ffi"] = EncapsulateFunction(lapack_cgetrf_ffi);
|
||||
|
@ -23,6 +23,18 @@ namespace jax {
|
||||
|
||||
// FFI Definition Macros (by DataType)
|
||||
|
||||
// Defines an XLA FFI handler called `name` that dispatches to
// TriMatrixEquationSolver<data_type>::Kernel. The binding declares, in order:
// two buffer arguments (x, y), a rank-0 buffer argument (alpha), one result
// buffer (y_out), and the attributes "side", "uplo", "trans_x" and "diag".
// The binding order has to match the parameter order of Kernel.
#define JAX_CPU_DEFINE_TRSM(name, data_type) \
|
||||
XLA_FFI_DEFINE_HANDLER(name, TriMatrixEquationSolver<data_type>::Kernel, \
|
||||
::xla::ffi::Ffi::Bind() \
|
||||
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
|
||||
.Arg<::xla::ffi::Buffer<data_type>>(/*y*/) \
|
||||
.Arg<::xla::ffi::BufferR0<data_type>>(/*alpha*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*y_out*/) \
|
||||
.Attr<MatrixParams::Side>("side") \
|
||||
.Attr<MatrixParams::UpLo>("uplo") \
|
||||
.Attr<MatrixParams::Transpose>("trans_x") \
|
||||
.Attr<MatrixParams::Diag>("diag"))
|
||||
|
||||
#define JAX_CPU_DEFINE_GETRF(name, data_type) \
|
||||
XLA_FFI_DEFINE_HANDLER( \
|
||||
name, LuDecomposition<data_type>::Kernel, \
|
||||
@ -92,6 +104,11 @@ namespace jax {
|
||||
|
||||
// FFI Handlers
|
||||
|
||||
JAX_CPU_DEFINE_TRSM(blas_strsm_ffi, ::xla::ffi::DataType::F32);
|
||||
JAX_CPU_DEFINE_TRSM(blas_dtrsm_ffi, ::xla::ffi::DataType::F64);
|
||||
JAX_CPU_DEFINE_TRSM(blas_ctrsm_ffi, ::xla::ffi::DataType::C64);
|
||||
JAX_CPU_DEFINE_TRSM(blas_ztrsm_ffi, ::xla::ffi::DataType::C128);
|
||||
|
||||
JAX_CPU_DEFINE_GETRF(lapack_sgetrf_ffi, ::xla::ffi::DataType::F32);
|
||||
JAX_CPU_DEFINE_GETRF(lapack_dgetrf_ffi, ::xla::ffi::DataType::F64);
|
||||
JAX_CPU_DEFINE_GETRF(lapack_cgetrf_ffi, ::xla::ffi::DataType::C64);
|
||||
@ -117,6 +134,7 @@ JAX_CPU_DEFINE_GESDD(lapack_dgesdd_ffi, ::xla::ffi::DataType::F64);
|
||||
JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_cgesdd_ffi, ::xla::ffi::DataType::C64);
|
||||
JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_zgesdd_ffi, ::xla::ffi::DataType::C128);
|
||||
|
||||
#undef JAX_CPU_DEFINE_TRSM
|
||||
#undef JAX_CPU_DEFINE_GETRF
|
||||
#undef JAX_CPU_DEFINE_GEQRF
|
||||
#undef JAX_CPU_DEFINE_ORGQR
|
||||
|
@ -160,6 +160,48 @@ template struct Trsm<double>;
|
||||
template struct Trsm<std::complex<float>>;
|
||||
template struct Trsm<std::complex<double>>;
|
||||
|
||||
// FFI Kernel
|
||||
|
||||
template <ffi::DataType dtype>
|
||||
ffi::Error TriMatrixEquationSolver<dtype>::Kernel(
|
||||
ffi::Buffer<dtype> x, ffi::Buffer<dtype> y, ffi::BufferR0<dtype> alpha,
|
||||
ffi::ResultBuffer<dtype> y_out, MatrixParams::Side side,
|
||||
MatrixParams::UpLo uplo, MatrixParams::Transpose trans_x,
|
||||
MatrixParams::Diag diag) {
|
||||
CopyIfDiffBuffer(y, y_out);
|
||||
|
||||
auto [batch_count, y_rows, y_cols] = SplitBatch2D(y.dimensions());
|
||||
auto* y_out_data = y_out->typed_data();
|
||||
lapack_int x_leading_dim_v =
|
||||
side == MatrixParams::Side::kLeft ? y_rows : y_cols;
|
||||
lapack_int y_leading_dim_v = y_rows;
|
||||
|
||||
auto side_v = static_cast<char>(side);
|
||||
auto uplo_v = static_cast<char>(uplo);
|
||||
auto trans_x_v = static_cast<char>(trans_x);
|
||||
auto diag_v = static_cast<char>(diag);
|
||||
FFI_ASSIGN_OR_RETURN(auto y_rows_v, MaybeCastNoOverflow<lapack_int>(y_rows));
|
||||
FFI_ASSIGN_OR_RETURN(auto y_cols_v, MaybeCastNoOverflow<lapack_int>(y_cols));
|
||||
|
||||
auto* x_data = x.typed_data();
|
||||
const int64_t y_out_step{y_rows * y_cols};
|
||||
const int64_t x_step{x_leading_dim_v * x_leading_dim_v};
|
||||
for (int64_t i = 0; i < batch_count; ++i) {
|
||||
fn(&side_v, &uplo_v, &trans_x_v, &diag_v, &y_rows_v, &y_cols_v,
|
||||
alpha.typed_data(), x_data, &x_leading_dim_v, y_out_data,
|
||||
&y_leading_dim_v);
|
||||
|
||||
y_out_data += y_out_step;
|
||||
x_data += x_step;
|
||||
}
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
// Explicit instantiations of the triangular-solve FFI kernel for the four
// element types that have a handler defined (F32, F64, C64, C128).
template struct TriMatrixEquationSolver<ffi::DataType::F32>;
|
||||
template struct TriMatrixEquationSolver<ffi::DataType::F64>;
|
||||
template struct TriMatrixEquationSolver<ffi::DataType::C64>;
|
||||
template struct TriMatrixEquationSolver<ffi::DataType::C128>;
|
||||
|
||||
//== LU Decomposition ==//
|
||||
|
||||
// lapack getrf
|
||||
|
@ -109,6 +109,24 @@ struct Trsm {
|
||||
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
|
||||
};
|
||||
|
||||
// FFI Kernel
|
||||
|
||||
// FFI counterpart of the Trsm kernel: holds the BLAS ?trsm entry point for
// one XLA element type and declares the FFI kernel that calls it.
template <::xla::ffi::DataType dtype>
|
||||
struct TriMatrixEquationSolver {
|
||||
// Native C++ element type corresponding to `dtype`.
using ValueType = ::xla::ffi::NativeType<dtype>;
|
||||
// Signature of the BLAS routine stored in `fn`; every parameter, including
// the scalar ones, is passed by pointer.
using FnType = void(char* side, char* uplo, char* transa, char* diag,
|
||||
lapack_int* m, lapack_int* n, ValueType* alpha,
|
||||
ValueType* a, lapack_int* lda, ValueType* b,
|
||||
lapack_int* ldb);
|
||||
|
||||
// BLAS routine used by Kernel. Null until it is assigned during
// initialization (see the AssignKernelFn<TriMatrixEquationSolver<...>> calls).
inline static FnType* fn = nullptr;
|
||||
// Kernel entry point; its parameter order must match the FFI binding
// (x, y, alpha, y_out, then the side/uplo/trans_x/diag attributes).
static ::xla::ffi::Error Kernel(
|
||||
::xla::ffi::Buffer<dtype> x, ::xla::ffi::Buffer<dtype> y,
|
||||
::xla::ffi::BufferR0<dtype> alpha, ::xla::ffi::ResultBuffer<dtype> y_out,
|
||||
MatrixParams::Side side, MatrixParams::UpLo uplo,
|
||||
MatrixParams::Transpose trans_x, MatrixParams::Diag diag);
|
||||
};
|
||||
|
||||
//== LU Decomposition ==//
|
||||
|
||||
// lapack getrf
|
||||
|
@ -26,10 +26,10 @@ namespace ffi = xla::ffi;
|
||||
|
||||
extern "C" {
|
||||
|
||||
jax::Trsm<float>::FnType strsm_;
|
||||
jax::Trsm<double>::FnType dtrsm_;
|
||||
jax::Trsm<std::complex<float>>::FnType ctrsm_;
|
||||
jax::Trsm<std::complex<double>>::FnType ztrsm_;
|
||||
jax::TriMatrixEquationSolver<ffi::DataType::F32>::FnType strsm_;
|
||||
jax::TriMatrixEquationSolver<ffi::DataType::F64>::FnType dtrsm_;
|
||||
jax::TriMatrixEquationSolver<ffi::DataType::C64>::FnType ctrsm_;
|
||||
jax::TriMatrixEquationSolver<ffi::DataType::C128>::FnType ztrsm_;
|
||||
|
||||
jax::LuDecomposition<ffi::DataType::F32>::FnType sgetrf_;
|
||||
jax::LuDecomposition<ffi::DataType::F64>::FnType dgetrf_;
|
||||
@ -87,6 +87,22 @@ namespace jax {
|
||||
|
||||
#define JAX_KERNEL_FNTYPE_MISMATCH_MSG "FFI Kernel FnType mismatch"
|
||||
|
||||
// Compile-time checks that each FFI triangular-solve kernel uses exactly the
// same BLAS function signature as the corresponding legacy Trsm<T> kernel.
static_assert(
|
||||
std::is_same_v<jax::TriMatrixEquationSolver<ffi::DataType::F32>::FnType,
|
||||
jax::Trsm<float>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(
|
||||
std::is_same_v<jax::TriMatrixEquationSolver<ffi::DataType::F64>::FnType,
|
||||
jax::Trsm<double>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(
|
||||
std::is_same_v<jax::TriMatrixEquationSolver<ffi::DataType::C64>::FnType,
|
||||
jax::Trsm<std::complex<float>>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(
|
||||
std::is_same_v<jax::TriMatrixEquationSolver<ffi::DataType::C128>::FnType,
|
||||
jax::Trsm<std::complex<double>>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(std::is_same_v<jax::LuDecomposition<ffi::DataType::F32>::FnType,
|
||||
jax::Getrf<float>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
@ -218,6 +234,11 @@ static auto init = []() -> int {
|
||||
|
||||
// FFI Kernels
|
||||
|
||||
AssignKernelFn<TriMatrixEquationSolver<ffi::DataType::F32>>(strsm_);
|
||||
AssignKernelFn<TriMatrixEquationSolver<ffi::DataType::F64>>(dtrsm_);
|
||||
AssignKernelFn<TriMatrixEquationSolver<ffi::DataType::C64>>(ctrsm_);
|
||||
AssignKernelFn<TriMatrixEquationSolver<ffi::DataType::C128>>(ztrsm_);
|
||||
|
||||
AssignKernelFn<LuDecomposition<ffi::DataType::F32>>(sgetrf_);
|
||||
AssignKernelFn<LuDecomposition<ffi::DataType::F64>>(dgetrf_);
|
||||
AssignKernelFn<LuDecomposition<ffi::DataType::C64>>(cgetrf_);
|
||||
|
Loading…
x
Reference in New Issue
Block a user