mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Use absl::Status::message() instead of error_message().
PiperOrigin-RevId: 389810033
This commit is contained in:
parent
2afba31f71
commit
a93eaf3c9e
@ -23,15 +23,15 @@ limitations under the License.
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "include/pybind11/numpy.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "include/pybind11/stl.h"
|
||||
#include "jaxlib/cuda_gpu_kernel_helpers.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
#include "jaxlib/kernel_pybind11_helpers.h"
|
||||
#include "third_party/gpus/cuda/include/cublas_v2.h"
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||
#include "jaxlib/cuda_gpu_kernel_helpers.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
#include "jaxlib/kernel_pybind11_helpers.h"
|
||||
#include "include/pybind11/numpy.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "include/pybind11/stl.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
@ -209,8 +209,8 @@ void TrsmBatched(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = TrsmBatched_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
@ -289,8 +289,8 @@ void GetrfBatched(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = GetrfBatched_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -27,10 +27,10 @@ limitations under the License.
|
||||
|
||||
#define JAX_AS_STATUS(expr) jax::AsStatus(expr, __FILE__, __LINE__, #expr)
|
||||
|
||||
#define JAX_THROW_IF_ERROR(expr) \
|
||||
{ \
|
||||
auto s___ = (expr); \
|
||||
if (!s___.ok()) throw std::runtime_error(s___.error_message()); \
|
||||
#define JAX_THROW_IF_ERROR(expr) \
|
||||
{ \
|
||||
auto s___ = (expr); \
|
||||
if (!s___.ok()) throw std::runtime_error(std::string(s___.message())); \
|
||||
}
|
||||
|
||||
#define JAX_RETURN_IF_ERROR(expr) \
|
||||
|
@ -102,8 +102,8 @@ void CudaLuPivotsToPermutation(cudaStream_t stream, void** buffers,
|
||||
XlaCustomCallStatus* status) {
|
||||
auto s = CudaLuPivotsToPermutation_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -136,8 +136,8 @@ void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CudaThreeFry2x32_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -25,15 +25,15 @@ limitations under the License.
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "include/pybind11/numpy.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "include/pybind11/stl.h"
|
||||
#include "jaxlib/cuda_gpu_kernel_helpers.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
#include "jaxlib/kernel_pybind11_helpers.h"
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||
#include "third_party/gpus/cuda/include/cusolverDn.h"
|
||||
#include "jaxlib/cuda_gpu_kernel_helpers.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
#include "jaxlib/kernel_pybind11_helpers.h"
|
||||
#include "include/pybind11/numpy.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "include/pybind11/stl.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
@ -245,8 +245,8 @@ void Potrf(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = Potrf_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
@ -370,8 +370,8 @@ void Getrf(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = Getrf_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
@ -498,8 +498,8 @@ void Geqrf(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = Geqrf_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
@ -630,8 +630,8 @@ void Orgqr(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = Orgqr_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
@ -757,8 +757,8 @@ void Syevd(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = Syevd_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
@ -939,8 +939,8 @@ void Syevj(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = Syevj_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
@ -1089,8 +1089,8 @@ void Gesvd(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = Gesvd_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
@ -1299,8 +1299,8 @@ void Gesvdj(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = Gesvdj_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/gpus/cuda/include/cusparse.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <stdexcept>
|
||||
@ -27,7 +29,6 @@ limitations under the License.
|
||||
#include "third_party/gpus/cuda/include/cuComplex.h"
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||
#include "third_party/gpus/cuda/include/cusparse.h"
|
||||
#include "jaxlib/cuda_gpu_kernel_helpers.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
#include "jaxlib/kernel_pybind11_helpers.h"
|
||||
@ -312,8 +313,8 @@ void CsrToDense(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CsrToDense_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
@ -387,8 +388,8 @@ void CsrFromDense(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CsrFromDense_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
@ -495,8 +496,8 @@ void CsrMatvec(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CsrMatvec_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
@ -606,8 +607,8 @@ void CsrMatmat(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CsrMatmat_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
@ -681,8 +682,8 @@ void CooToDense(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CooToDense_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
@ -756,8 +757,8 @@ void CooFromDense(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CooFromDense_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
@ -863,8 +864,8 @@ void CooMatvec(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CooMatvec_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
@ -973,8 +974,8 @@ void CooMatmat(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CooMatmat_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
#endif // if JAX_CUSPARSE_11030
|
||||
@ -1029,8 +1030,8 @@ void gtsv2_f32(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
std::size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = gtsv2<float>(cusparseSgtsv2, stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
@ -1038,8 +1039,8 @@ void gtsv2_f64(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
std::size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = gtsv2<double>(cusparseDgtsv2, stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -13,29 +13,28 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "rocm/include/rocblas.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <stdexcept>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
|
||||
#include "rocm/include/hip/hip_runtime.h"
|
||||
#include "rocm/include/hip/hip_runtime_api.h"
|
||||
#include "rocm/include/rocsolver.h"
|
||||
#include "absl/base/casts.h"
|
||||
#include "absl/base/thread_annotations.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "include/pybind11/numpy.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "include/pybind11/stl.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
#include "jaxlib/kernel_pybind11_helpers.h"
|
||||
#include "jaxlib/rocm_gpu_kernel_helpers.h"
|
||||
#include "include/pybind11/numpy.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "include/pybind11/stl.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
#include "rocm/include/hip/hip_runtime.h"
|
||||
#include "rocm/include/hip/hip_runtime_api.h"
|
||||
#include "rocm/include/rocblas.h"
|
||||
#include "rocm/include/rocsolver.h"
|
||||
|
||||
|
||||
namespace jax {
|
||||
|
||||
@ -280,8 +279,8 @@ void Trsm(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = Trsm_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
@ -399,8 +398,8 @@ void Potrf(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = Potrf_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
@ -519,8 +518,8 @@ void Getrf(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = Getrf_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
@ -649,8 +648,8 @@ void Geqrf(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = Geqrf_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
@ -746,8 +745,8 @@ void Orgqr(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = Orgqr_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
@ -942,8 +941,8 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = Gesvd_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user