Use absl::Status::message() instead of error_message().

PiperOrigin-RevId: 389810033
This commit is contained in:
Qiao Zhang 2021-08-09 23:44:14 -07:00 committed by jax authors
parent 2afba31f71
commit a93eaf3c9e
7 changed files with 82 additions and 82 deletions

View File

@ -23,15 +23,15 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "include/pybind11/numpy.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
#include "jaxlib/cuda_gpu_kernel_helpers.h"
#include "jaxlib/handle_pool.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "jaxlib/cuda_gpu_kernel_helpers.h"
#include "jaxlib/handle_pool.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "include/pybind11/numpy.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
@ -209,8 +209,8 @@ void TrsmBatched(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = TrsmBatched_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
@ -289,8 +289,8 @@ void GetrfBatched(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = GetrfBatched_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}

View File

@ -27,10 +27,10 @@ limitations under the License.
#define JAX_AS_STATUS(expr) jax::AsStatus(expr, __FILE__, __LINE__, #expr)
#define JAX_THROW_IF_ERROR(expr) \
{ \
auto s___ = (expr); \
if (!s___.ok()) throw std::runtime_error(s___.error_message()); \
#define JAX_THROW_IF_ERROR(expr) \
{ \
auto s___ = (expr); \
if (!s___.ok()) throw std::runtime_error(std::string(s___.message())); \
}
#define JAX_RETURN_IF_ERROR(expr) \

View File

@ -102,8 +102,8 @@ void CudaLuPivotsToPermutation(cudaStream_t stream, void** buffers,
XlaCustomCallStatus* status) {
auto s = CudaLuPivotsToPermutation_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}

View File

@ -136,8 +136,8 @@ void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CudaThreeFry2x32_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}

View File

@ -25,15 +25,15 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "include/pybind11/numpy.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
#include "jaxlib/cuda_gpu_kernel_helpers.h"
#include "jaxlib/handle_pool.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/cusolverDn.h"
#include "jaxlib/cuda_gpu_kernel_helpers.h"
#include "jaxlib/handle_pool.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "include/pybind11/numpy.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
@ -245,8 +245,8 @@ void Potrf(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Potrf_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
@ -370,8 +370,8 @@ void Getrf(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Getrf_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
@ -498,8 +498,8 @@ void Geqrf(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Geqrf_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
@ -630,8 +630,8 @@ void Orgqr(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Orgqr_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
@ -757,8 +757,8 @@ void Syevd(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Syevd_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
@ -939,8 +939,8 @@ void Syevj(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Syevj_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
@ -1089,8 +1089,8 @@ void Gesvd(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Gesvd_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
@ -1299,8 +1299,8 @@ void Gesvdj(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Gesvdj_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}

View File

@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/gpus/cuda/include/cusparse.h"
#include <algorithm>
#include <cstdint>
#include <stdexcept>
@ -27,7 +29,6 @@ limitations under the License.
#include "third_party/gpus/cuda/include/cuComplex.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/cusparse.h"
#include "jaxlib/cuda_gpu_kernel_helpers.h"
#include "jaxlib/handle_pool.h"
#include "jaxlib/kernel_pybind11_helpers.h"
@ -312,8 +313,8 @@ void CsrToDense(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CsrToDense_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
@ -387,8 +388,8 @@ void CsrFromDense(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CsrFromDense_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
@ -495,8 +496,8 @@ void CsrMatvec(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CsrMatvec_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
@ -606,8 +607,8 @@ void CsrMatmat(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CsrMatmat_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
@ -681,8 +682,8 @@ void CooToDense(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CooToDense_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
@ -756,8 +757,8 @@ void CooFromDense(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CooFromDense_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
@ -863,8 +864,8 @@ void CooMatvec(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CooMatvec_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
@ -973,8 +974,8 @@ void CooMatmat(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CooMatmat_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
#endif // if JAX_CUSPARSE_11030
@ -1029,8 +1030,8 @@ void gtsv2_f32(cudaStream_t stream, void** buffers, const char* opaque,
std::size_t opaque_len, XlaCustomCallStatus* status) {
auto s = gtsv2<float>(cusparseSgtsv2, stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
@ -1038,8 +1039,8 @@ void gtsv2_f64(cudaStream_t stream, void** buffers, const char* opaque,
std::size_t opaque_len, XlaCustomCallStatus* status) {
auto s = gtsv2<double>(cusparseDgtsv2, stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}

View File

@ -13,29 +13,28 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "rocm/include/rocblas.h"
#include <algorithm>
#include <stdexcept>
#include <utility>
#include <vector>
#include "rocm/include/hip/hip_runtime.h"
#include "rocm/include/hip/hip_runtime_api.h"
#include "rocm/include/rocsolver.h"
#include "absl/base/casts.h"
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "include/pybind11/numpy.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
#include "jaxlib/handle_pool.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "jaxlib/rocm_gpu_kernel_helpers.h"
#include "include/pybind11/numpy.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
#include "rocm/include/hip/hip_runtime.h"
#include "rocm/include/hip/hip_runtime_api.h"
#include "rocm/include/rocblas.h"
#include "rocm/include/rocsolver.h"
namespace jax {
@ -280,8 +279,8 @@ void Trsm(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Trsm_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
@ -399,8 +398,8 @@ void Potrf(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Potrf_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
@ -519,8 +518,8 @@ void Getrf(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Getrf_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
@ -649,8 +648,8 @@ void Geqrf(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Geqrf_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
@ -746,8 +745,8 @@ void Orgqr(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Orgqr_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
@ -942,8 +941,8 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Gesvd_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}