diff --git a/jaxlib/cpu/cpu_kernels.cc b/jaxlib/cpu/cpu_kernels.cc index c2e122c04..eb3029c62 100644 --- a/jaxlib/cpu/cpu_kernels.cc +++ b/jaxlib/cpu/cpu_kernels.cc @@ -149,6 +149,10 @@ JAX_CPU_REGISTER_HANDLER(lapack_sgeev_ffi); JAX_CPU_REGISTER_HANDLER(lapack_dgeev_ffi); JAX_CPU_REGISTER_HANDLER(lapack_cgeev_ffi); JAX_CPU_REGISTER_HANDLER(lapack_zgeev_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_ssytrd_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_dsytrd_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_chetrd_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_zhetrd_ffi); JAX_CPU_REGISTER_HANDLER(lapack_sgehrd_ffi); JAX_CPU_REGISTER_HANDLER(lapack_dgehrd_ffi); JAX_CPU_REGISTER_HANDLER(lapack_cgehrd_ffi); diff --git a/jaxlib/cpu/lapack.cc b/jaxlib/cpu/lapack.cc index 8fc480951..fc04963a0 100644 --- a/jaxlib/cpu/lapack.cc +++ b/jaxlib/cpu/lapack.cc @@ -151,6 +151,10 @@ void GetLapackKernelsFromScipy() { AssignKernelFn>(lapack_ptr("dsytrd")); AssignKernelFn>>(lapack_ptr("chetrd")); AssignKernelFn>>(lapack_ptr("zhetrd")); + AssignKernelFn>(lapack_ptr("ssytrd")); + AssignKernelFn>(lapack_ptr("dsytrd")); + AssignKernelFn>(lapack_ptr("chetrd")); + AssignKernelFn>(lapack_ptr("zhetrd")); initialized = true; } @@ -257,6 +261,10 @@ nb::dict Registrations() { dict["lapack_dgeev_ffi"] = EncapsulateFunction(lapack_dgeev_ffi); dict["lapack_cgeev_ffi"] = EncapsulateFunction(lapack_cgeev_ffi); dict["lapack_zgeev_ffi"] = EncapsulateFunction(lapack_zgeev_ffi); + dict["lapack_ssytrd_ffi"] = EncapsulateFunction(lapack_ssytrd_ffi); + dict["lapack_dsytrd_ffi"] = EncapsulateFunction(lapack_dsytrd_ffi); + dict["lapack_chetrd_ffi"] = EncapsulateFunction(lapack_chetrd_ffi); + dict["lapack_zhetrd_ffi"] = EncapsulateFunction(lapack_zhetrd_ffi); dict["lapack_sgehrd_ffi"] = EncapsulateFunction(lapack_sgehrd_ffi); dict["lapack_dgehrd_ffi"] = EncapsulateFunction(lapack_dgehrd_ffi); dict["lapack_cgehrd_ffi"] = EncapsulateFunction(lapack_cgehrd_ffi); diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index 7d5839522..4b2554033 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -1741,6 +1741,68 @@ template struct Sytrd; template struct Sytrd>; template struct Sytrd>; +// FFI Kernel + +template +ffi::Error TridiagonalReduction::Kernel( + ffi::Buffer x, MatrixParams::UpLo uplo, + ffi::ResultBuffer x_out, + ffi::ResultBuffer diagonal, + ffi::ResultBuffer off_diagonal, + ffi::ResultBuffer tau, ffi::ResultBuffer info) { + FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), + SplitBatch2D(x.dimensions())); + + CopyIfDiffBuffer(x, x_out); + + ValueType* x_out_data = x_out->typed_data(); + RealType* diagonal_data = diagonal->typed_data(); + RealType* off_diagonal_data = off_diagonal->typed_data(); + ValueType* tau_data = tau->typed_data(); + lapack_int* info_data = info->typed_data(); + + // Prepare LAPACK workspaces. + const auto work_size = GetWorkspaceSize(x_rows, x_cols); + auto work_data = AllocateScratchMemory(work_size); + + auto uplo_v = static_cast(uplo); + FFI_ASSIGN_OR_RETURN(auto x_leading_dim_v, + MaybeCastNoOverflow(x_rows)); + FFI_ASSIGN_OR_RETURN(auto work_size_v, + MaybeCastNoOverflow(work_size)); + FFI_ASSIGN_OR_RETURN(auto x_order_v, MaybeCastNoOverflow(x_cols)); + + int64_t x_size = x_rows * x_cols; + int64_t tau_step = {tau->dimensions().back()}; + for (int64_t i = 0; i < batch_count; ++i) { + fn(&uplo_v, &x_order_v, x_out_data, &x_leading_dim_v, diagonal_data, + off_diagonal_data, tau_data, work_data.get(), &work_size_v, info_data); + x_out_data += x_size; + diagonal_data += x_cols; + off_diagonal_data += x_cols - 1; + tau_data += tau_step; + ++info_data; + } + return ffi::Error::Success(); +} + +template +int64_t TridiagonalReduction::GetWorkspaceSize(lapack_int x_rows, + lapack_int x_cols) { + ValueType optimal_size = {}; + lapack_int workspace_query = -1; + lapack_int info = 0; + char uplo_v = 'L'; + fn(&uplo_v, &x_cols, nullptr, &x_rows, nullptr, nullptr, nullptr, + &optimal_size, &workspace_query, &info); + return info == 0 ? static_cast(std::real(optimal_size)) : -1; +} + +template struct TridiagonalReduction; +template struct TridiagonalReduction; +template struct TridiagonalReduction; +template struct TridiagonalReduction; + // FFI Definition Macros (by DataType) #define JAX_CPU_DEFINE_TRSM(name, data_type) \ @@ -1864,6 +1926,20 @@ template struct Sytrd>; .Ret<::xla::ffi::Buffer>(/*eigvecs_right*/) \ .Ret<::xla::ffi::Buffer>(/*info*/)) +#define JAX_CPU_DEFINE_SYTRD_HETRD(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, TridiagonalReduction::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Attr("uplo") \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \ + /*diagonal*/) \ + .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \ + /*off_diagonal*/) \ + .Ret<::xla::ffi::Buffer>(/*tau*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/)) + #define JAX_CPU_DEFINE_GEHRD(name, data_type) \ XLA_FFI_DEFINE_HANDLER_SYMBOL( \ name, HessenbergDecomposition::Kernel, \ @@ -1917,6 +1993,11 @@ JAX_CPU_DEFINE_GEEV(lapack_dgeev_ffi, ::xla::ffi::DataType::F64); JAX_CPU_DEFINE_GEEV_COMPLEX(lapack_cgeev_ffi, ::xla::ffi::DataType::C64); JAX_CPU_DEFINE_GEEV_COMPLEX(lapack_zgeev_ffi, ::xla::ffi::DataType::C128); +JAX_CPU_DEFINE_SYTRD_HETRD(lapack_ssytrd_ffi, ::xla::ffi::DataType::F32); +JAX_CPU_DEFINE_SYTRD_HETRD(lapack_dsytrd_ffi, ::xla::ffi::DataType::F64); +JAX_CPU_DEFINE_SYTRD_HETRD(lapack_chetrd_ffi, ::xla::ffi::DataType::C64); +JAX_CPU_DEFINE_SYTRD_HETRD(lapack_zhetrd_ffi, ::xla::ffi::DataType::C128); + JAX_CPU_DEFINE_GEHRD(lapack_sgehrd_ffi, ::xla::ffi::DataType::F32); JAX_CPU_DEFINE_GEHRD(lapack_dgehrd_ffi, ::xla::ffi::DataType::F64); JAX_CPU_DEFINE_GEHRD(lapack_cgehrd_ffi, ::xla::ffi::DataType::C64); @@ -1933,6 +2014,7 @@ JAX_CPU_DEFINE_GEHRD(lapack_zgehrd_ffi, ::xla::ffi::DataType::C128); #undef JAX_CPU_DEFINE_HEEVD #undef JAX_CPU_DEFINE_GEEV #undef JAX_CPU_DEFINE_GEEV_COMPLEX +#undef JAX_CPU_DEFINE_SYTRD_HETRD #undef JAX_CPU_DEFINE_GEHRD } // namespace jax diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index b4f54b923..e5fa9d354 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -617,6 +617,29 @@ struct Sytrd { static int64_t Workspace(lapack_int lda, lapack_int n); }; +// FFI Kernel + +template <::xla::ffi::DataType dtype> +struct TridiagonalReduction { + using ValueType = ::xla::ffi::NativeType; + using RealType = ::xla::ffi::NativeType<::xla::ffi::ToReal(dtype)>; + using FnType = void(char* uplo, lapack_int* n, ValueType* a, lapack_int* lda, + RealType* d, RealType* e, ValueType* tau, ValueType* work, + lapack_int* lwork, lapack_int* info); + + inline static FnType* fn = nullptr; + + static ::xla::ffi::Error Kernel( + ::xla::ffi::Buffer x, MatrixParams::UpLo uplo, + ::xla::ffi::ResultBuffer x_out, + ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> diagonal, + ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> off_diagonal, + ::xla::ffi::ResultBuffer tau, + ::xla::ffi::ResultBuffer info); + + static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols); +}; + // Declare all the handler symbols XLA_FFI_DECLARE_HANDLER_SYMBOL(blas_strsm_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(blas_dtrsm_ffi); @@ -650,6 +673,10 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgeev_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgeev_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgeev_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgeev_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_ssytrd_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dsytrd_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_chetrd_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zhetrd_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgehrd_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgehrd_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgehrd_ffi); diff --git a/jaxlib/cpu/lapack_kernels_using_lapack.cc b/jaxlib/cpu/lapack_kernels_using_lapack.cc index 9f13bb99d..48efcedeb 100644 --- a/jaxlib/cpu/lapack_kernels_using_lapack.cc +++ b/jaxlib/cpu/lapack_kernels_using_lapack.cc @@ -76,10 +76,10 @@ jax::HessenbergDecomposition::FnType dgehrd_; jax::HessenbergDecomposition::FnType cgehrd_; jax::HessenbergDecomposition::FnType zgehrd_; -jax::Sytrd::FnType ssytrd_; -jax::Sytrd::FnType dsytrd_; -jax::Sytrd>::FnType chetrd_; -jax::Sytrd>::FnType zhetrd_; +jax::TridiagonalReduction::FnType ssytrd_; +jax::TridiagonalReduction::FnType dsytrd_; +jax::TridiagonalReduction::FnType chetrd_; +jax::TridiagonalReduction::FnType zhetrd_; } // extern "C" @@ -211,6 +211,22 @@ static_assert( jax::EigenvalueDecompositionComplex::FnType, jax::ComplexGeev>::FnType>, JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::Sytrd::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::Sytrd::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::Sytrd>::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::Sytrd>::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); static_assert( std::is_same_v::FnType, jax::Gehrd::FnType>, @@ -331,6 +347,11 @@ static auto init = []() -> int { AssignKernelFn>(cgeev_); AssignKernelFn>(zgeev_); + AssignKernelFn>(ssytrd_); + AssignKernelFn>(dsytrd_); + AssignKernelFn>(chetrd_); + AssignKernelFn>(zhetrd_); + AssignKernelFn>(sgehrd_); AssignKernelFn>(dgehrd_); AssignKernelFn>(cgehrd_);