From ff4e0b1214c18aa37b1c36676a5b7170c5aae37a Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 29 Jul 2024 06:58:48 -0700 Subject: [PATCH] Rearrange the LAPACK handler definitions in jaxlib to avoid duplicate handler errors. When linking the jaxlib `cpu_kernels` target and importing JAX, we currently silently fail to instantiate the CPU backend. This refactor means that we only ever define one version of the handlers. PiperOrigin-RevId: 657186057 --- jaxlib/cpu/BUILD | 2 - jaxlib/cpu/cpu_kernels.cc | 1 - jaxlib/cpu/lapack.cc | 2 - jaxlib/cpu/lapack.h | 147 ----------------------------------- jaxlib/cpu/lapack_kernels.cc | 122 +++++++++++++++++++++++++++++ jaxlib/cpu/lapack_kernels.h | 26 +++++++ 6 files changed, 148 insertions(+), 152 deletions(-) delete mode 100644 jaxlib/cpu/lapack.h diff --git a/jaxlib/cpu/BUILD b/jaxlib/cpu/BUILD index d6dba6e31..a89d9ef95 100644 --- a/jaxlib/cpu/BUILD +++ b/jaxlib/cpu/BUILD @@ -54,7 +54,6 @@ cc_library( pybind_extension( name = "_lapack", srcs = ["lapack.cc"], - hdrs = ["lapack.h"], copts = [ "-fexceptions", "-fno-strict-aliasing", @@ -77,7 +76,6 @@ pybind_extension( cc_library( name = "cpu_kernels", srcs = ["cpu_kernels.cc"], - hdrs = ["lapack.h"], visibility = ["//visibility:public"], deps = [ ":lapack_kernels", diff --git a/jaxlib/cpu/cpu_kernels.cc b/jaxlib/cpu/cpu_kernels.cc index c7bc9dc4b..d6136a202 100644 --- a/jaxlib/cpu/cpu_kernels.cc +++ b/jaxlib/cpu/cpu_kernels.cc @@ -18,7 +18,6 @@ limitations under the License. #include -#include "jaxlib/cpu/lapack.h" #include "jaxlib/cpu/lapack_kernels.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/api/ffi.h" diff --git a/jaxlib/cpu/lapack.cc b/jaxlib/cpu/lapack.cc index 0aa924486..fa8a352e2 100644 --- a/jaxlib/cpu/lapack.cc +++ b/jaxlib/cpu/lapack.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/cpu/lapack.h" - #include #include "nanobind/nanobind.h" diff --git a/jaxlib/cpu/lapack.h b/jaxlib/cpu/lapack.h deleted file mode 100644 index 81a5acce5..000000000 --- a/jaxlib/cpu/lapack.h +++ /dev/null @@ -1,147 +0,0 @@ -/* Copyright 2024 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef JAXLIB_CPU_LAPACK_H_ -#define JAXLIB_CPU_LAPACK_H_ - -#include "jaxlib/cpu/lapack_kernels.h" -#include "xla/ffi/api/ffi.h" - -namespace jax { - -// FFI Definition Macros (by DataType) - -#define JAX_CPU_DEFINE_TRSM(name, data_type) \ - XLA_FFI_DEFINE_HANDLER(name, TriMatrixEquationSolver::Kernel, \ - ::xla::ffi::Ffi::Bind() \ - .Arg<::xla::ffi::Buffer>(/*x*/) \ - .Arg<::xla::ffi::Buffer>(/*y*/) \ - .Arg<::xla::ffi::BufferR0>(/*alpha*/) \ - .Ret<::xla::ffi::Buffer>(/*y_out*/) \ - .Attr("side") \ - .Attr("uplo") \ - .Attr("trans_x") \ - .Attr("diag")) - -#define JAX_CPU_DEFINE_GETRF(name, data_type) \ - XLA_FFI_DEFINE_HANDLER( \ - name, LuDecomposition::Kernel, \ - ::xla::ffi::Ffi::Bind() \ - .Arg<::xla::ffi::Buffer>(/*x*/) \ - .Ret<::xla::ffi::Buffer>(/*x_out*/) \ - .Ret<::xla::ffi::Buffer>(/*ipiv*/) \ - .Ret<::xla::ffi::Buffer>(/*info*/)) - -#define JAX_CPU_DEFINE_GEQRF(name, data_type) \ - XLA_FFI_DEFINE_HANDLER( \ - name, QrFactorization::Kernel, \ - ::xla::ffi::Ffi::Bind() \ - .Arg<::xla::ffi::Buffer>(/*x*/) \ - .Ret<::xla::ffi::Buffer>(/*x_out*/) \ - .Ret<::xla::ffi::Buffer>(/*tau*/) \ - .Ret<::xla::ffi::Buffer>(/*info*/) \ - .Ret<::xla::ffi::Buffer>(/*work*/)) - -#define JAX_CPU_DEFINE_ORGQR(name, data_type) \ - XLA_FFI_DEFINE_HANDLER( \ - name, OrthogonalQr::Kernel, \ - ::xla::ffi::Ffi::Bind() \ - .Arg<::xla::ffi::Buffer>(/*x*/) \ - .Arg<::xla::ffi::Buffer>(/*tau*/) \ - .Ret<::xla::ffi::Buffer>(/*x_out*/) \ - .Ret<::xla::ffi::Buffer>(/*info*/) \ - .Ret<::xla::ffi::Buffer>(/*work*/)) - -#define JAX_CPU_DEFINE_POTRF(name, data_type) \ - XLA_FFI_DEFINE_HANDLER( \ - name, CholeskyFactorization::Kernel, \ - ::xla::ffi::Ffi::Bind() \ - .Arg<::xla::ffi::Buffer>(/*x*/) \ - .Attr("uplo") \ - .Ret<::xla::ffi::Buffer>(/*x_out*/) \ - .Ret<::xla::ffi::Buffer>(/*info*/)) - -#define JAX_CPU_DEFINE_GESDD(name, data_type) \ - XLA_FFI_DEFINE_HANDLER( \ - name, SingularValueDecomposition::Kernel, \ - ::xla::ffi::Ffi::Bind() \ - .Arg<::xla::ffi::Buffer>(/*x*/) \ - .Ret<::xla::ffi::Buffer>(/*x_out*/) \ - .Ret<::xla::ffi::Buffer>(/*s*/) \ - .Ret<::xla::ffi::Buffer>(/*u*/) \ - .Ret<::xla::ffi::Buffer>(/*vt*/) \ - .Ret<::xla::ffi::Buffer>(/*info*/) \ - .Ret<::xla::ffi::Buffer>(/*iwork*/) \ - .Ret<::xla::ffi::Buffer>(/*work*/) \ - .Attr("mode")) - -#define JAX_CPU_DEFINE_GESDD_COMPLEX(name, data_type) \ - XLA_FFI_DEFINE_HANDLER( \ - name, SingularValueDecompositionComplex::Kernel, \ - ::xla::ffi::Ffi::Bind() \ - .Arg<::xla::ffi::Buffer>(/*x*/) \ - .Ret<::xla::ffi::Buffer>(/*x_out*/) \ - .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*s*/) \ - .Ret<::xla::ffi::Buffer>(/*u*/) \ - .Ret<::xla::ffi::Buffer>(/*vt*/) \ - .Ret<::xla::ffi::Buffer>(/*info*/) \ - .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*rwork*/) \ - .Ret<::xla::ffi::Buffer>(/*iwork*/) \ - .Ret<::xla::ffi::Buffer>(/*work*/) \ - .Attr("mode")) - -// FFI Handlers - -JAX_CPU_DEFINE_TRSM(blas_strsm_ffi, ::xla::ffi::DataType::F32); -JAX_CPU_DEFINE_TRSM(blas_dtrsm_ffi, ::xla::ffi::DataType::F64); -JAX_CPU_DEFINE_TRSM(blas_ctrsm_ffi, ::xla::ffi::DataType::C64); -JAX_CPU_DEFINE_TRSM(blas_ztrsm_ffi, ::xla::ffi::DataType::C128); - -JAX_CPU_DEFINE_GETRF(lapack_sgetrf_ffi, ::xla::ffi::DataType::F32); -JAX_CPU_DEFINE_GETRF(lapack_dgetrf_ffi, ::xla::ffi::DataType::F64); -JAX_CPU_DEFINE_GETRF(lapack_cgetrf_ffi, ::xla::ffi::DataType::C64); -JAX_CPU_DEFINE_GETRF(lapack_zgetrf_ffi, ::xla::ffi::DataType::C128); - -JAX_CPU_DEFINE_GEQRF(lapack_sgeqrf_ffi, ::xla::ffi::DataType::F32); -JAX_CPU_DEFINE_GEQRF(lapack_dgeqrf_ffi, ::xla::ffi::DataType::F64); -JAX_CPU_DEFINE_GEQRF(lapack_cgeqrf_ffi, ::xla::ffi::DataType::C64); -JAX_CPU_DEFINE_GEQRF(lapack_zgeqrf_ffi, ::xla::ffi::DataType::C128); - -JAX_CPU_DEFINE_ORGQR(lapack_sorgqr_ffi, ::xla::ffi::DataType::F32); -JAX_CPU_DEFINE_ORGQR(lapack_dorgqr_ffi, ::xla::ffi::DataType::F64); -JAX_CPU_DEFINE_ORGQR(lapack_cungqr_ffi, ::xla::ffi::DataType::C64); -JAX_CPU_DEFINE_ORGQR(lapack_zungqr_ffi, ::xla::ffi::DataType::C128); - -JAX_CPU_DEFINE_POTRF(lapack_spotrf_ffi, ::xla::ffi::DataType::F32); -JAX_CPU_DEFINE_POTRF(lapack_dpotrf_ffi, ::xla::ffi::DataType::F64); -JAX_CPU_DEFINE_POTRF(lapack_cpotrf_ffi, ::xla::ffi::DataType::C64); -JAX_CPU_DEFINE_POTRF(lapack_zpotrf_ffi, ::xla::ffi::DataType::C128); - -JAX_CPU_DEFINE_GESDD(lapack_sgesdd_ffi, ::xla::ffi::DataType::F32); -JAX_CPU_DEFINE_GESDD(lapack_dgesdd_ffi, ::xla::ffi::DataType::F64); -JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_cgesdd_ffi, ::xla::ffi::DataType::C64); -JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_zgesdd_ffi, ::xla::ffi::DataType::C128); - -#undef JAX_CPU_DEFINE_TRSM -#undef JAX_CPU_DEFINE_GETRF -#undef JAX_CPU_DEFINE_GEQRF -#undef JAX_CPU_DEFINE_ORGQR -#undef JAX_CPU_DEFINE_POTRF -#undef JAX_CPU_DEFINE_GESDD -#undef JAX_CPU_DEFINE_GESDD_COMPLEX - -} // namespace jax - -#endif // JAXLIB_CPU_LAPACK_H_ diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index 85fa02893..b3579c06f 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -1387,4 +1387,126 @@ template struct Sytrd; template struct Sytrd>; template struct Sytrd>; +// FFI Definition Macros (by DataType) + +#define JAX_CPU_DEFINE_TRSM(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, TriMatrixEquationSolver::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Arg<::xla::ffi::Buffer>(/*y*/) \ + .Arg<::xla::ffi::BufferR0>(/*alpha*/) \ + .Ret<::xla::ffi::Buffer>(/*y_out*/) \ + .Attr("side") \ + .Attr("uplo") \ + .Attr("trans_x") \ + .Attr("diag")) + +#define JAX_CPU_DEFINE_GETRF(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, LuDecomposition::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer>(/*ipiv*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/)) + +#define JAX_CPU_DEFINE_GEQRF(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, QrFactorization::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer>(/*tau*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/) \ + .Ret<::xla::ffi::Buffer>(/*work*/)) + +#define JAX_CPU_DEFINE_ORGQR(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, OrthogonalQr::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Arg<::xla::ffi::Buffer>(/*tau*/) \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/) \ + .Ret<::xla::ffi::Buffer>(/*work*/)) + +#define JAX_CPU_DEFINE_POTRF(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, CholeskyFactorization::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Attr("uplo") \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/)) + +#define JAX_CPU_DEFINE_GESDD(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, SingularValueDecomposition::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer>(/*s*/) \ + .Ret<::xla::ffi::Buffer>(/*u*/) \ + .Ret<::xla::ffi::Buffer>(/*vt*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/) \ + .Ret<::xla::ffi::Buffer>(/*iwork*/) \ + .Ret<::xla::ffi::Buffer>(/*work*/) \ + .Attr("mode")) + +#define JAX_CPU_DEFINE_GESDD_COMPLEX(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, SingularValueDecompositionComplex::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*s*/) \ + .Ret<::xla::ffi::Buffer>(/*u*/) \ + .Ret<::xla::ffi::Buffer>(/*vt*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/) \ + .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*rwork*/) \ + .Ret<::xla::ffi::Buffer>(/*iwork*/) \ + .Ret<::xla::ffi::Buffer>(/*work*/) \ + .Attr("mode")) + +// FFI Handlers + +JAX_CPU_DEFINE_TRSM(blas_strsm_ffi, ::xla::ffi::DataType::F32); +JAX_CPU_DEFINE_TRSM(blas_dtrsm_ffi, ::xla::ffi::DataType::F64); +JAX_CPU_DEFINE_TRSM(blas_ctrsm_ffi, ::xla::ffi::DataType::C64); +JAX_CPU_DEFINE_TRSM(blas_ztrsm_ffi, ::xla::ffi::DataType::C128); + +JAX_CPU_DEFINE_GETRF(lapack_sgetrf_ffi, ::xla::ffi::DataType::F32); +JAX_CPU_DEFINE_GETRF(lapack_dgetrf_ffi, ::xla::ffi::DataType::F64); +JAX_CPU_DEFINE_GETRF(lapack_cgetrf_ffi, ::xla::ffi::DataType::C64); +JAX_CPU_DEFINE_GETRF(lapack_zgetrf_ffi, ::xla::ffi::DataType::C128); + +JAX_CPU_DEFINE_GEQRF(lapack_sgeqrf_ffi, ::xla::ffi::DataType::F32); +JAX_CPU_DEFINE_GEQRF(lapack_dgeqrf_ffi, ::xla::ffi::DataType::F64); +JAX_CPU_DEFINE_GEQRF(lapack_cgeqrf_ffi, ::xla::ffi::DataType::C64); +JAX_CPU_DEFINE_GEQRF(lapack_zgeqrf_ffi, ::xla::ffi::DataType::C128); + +JAX_CPU_DEFINE_ORGQR(lapack_sorgqr_ffi, ::xla::ffi::DataType::F32); +JAX_CPU_DEFINE_ORGQR(lapack_dorgqr_ffi, ::xla::ffi::DataType::F64); +JAX_CPU_DEFINE_ORGQR(lapack_cungqr_ffi, ::xla::ffi::DataType::C64); +JAX_CPU_DEFINE_ORGQR(lapack_zungqr_ffi, ::xla::ffi::DataType::C128); + +JAX_CPU_DEFINE_POTRF(lapack_spotrf_ffi, ::xla::ffi::DataType::F32); +JAX_CPU_DEFINE_POTRF(lapack_dpotrf_ffi, ::xla::ffi::DataType::F64); +JAX_CPU_DEFINE_POTRF(lapack_cpotrf_ffi, ::xla::ffi::DataType::C64); +JAX_CPU_DEFINE_POTRF(lapack_zpotrf_ffi, ::xla::ffi::DataType::C128); + +JAX_CPU_DEFINE_GESDD(lapack_sgesdd_ffi, ::xla::ffi::DataType::F32); +JAX_CPU_DEFINE_GESDD(lapack_dgesdd_ffi, ::xla::ffi::DataType::F64); +JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_cgesdd_ffi, ::xla::ffi::DataType::C64); +JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_zgesdd_ffi, ::xla::ffi::DataType::C128); + +#undef JAX_CPU_DEFINE_TRSM +#undef JAX_CPU_DEFINE_GETRF +#undef JAX_CPU_DEFINE_GEQRF +#undef JAX_CPU_DEFINE_ORGQR +#undef JAX_CPU_DEFINE_POTRF +#undef JAX_CPU_DEFINE_GESDD +#undef JAX_CPU_DEFINE_GESDD_COMPLEX + } // namespace jax diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index fd9d8c975..811a5d4b3 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -475,6 +475,32 @@ struct Sytrd { static int64_t Workspace(lapack_int lda, lapack_int n); }; +// Declare all the handler symbols +XLA_FFI_DECLARE_HANDLER_SYMBOL(blas_strsm_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(blas_dtrsm_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(blas_ctrsm_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(blas_ztrsm_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgetrf_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgetrf_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgetrf_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgetrf_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgeqrf_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgeqrf_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgeqrf_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgeqrf_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sorgqr_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dorgqr_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cungqr_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zungqr_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_spotrf_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dpotrf_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cpotrf_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zpotrf_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgesdd_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgesdd_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgesdd_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgesdd_ffi); + } // namespace jax #endif // JAXLIB_CPU_LAPACK_KERNELS_H_