[GPU] Add missing stream synchronization to tridiagonal_solve gtsv2 call.

May fix flaky failures in CI.

Make stream argument to Pool::Borrow() mandatory to minimize chance of forgetting it.

PiperOrigin-RevId: 530425766
This commit is contained in:
Peter Hawkins 2023-05-08 15:36:17 -07:00 committed by jax authors
parent 00b75aff82
commit f168a1560c
5 changed files with 20 additions and 20 deletions

View File

@ -75,7 +75,7 @@ DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
int num_layers, int batch_size,
int max_seq_length, float dropout,
bool bidirectional) {
auto h = DnnHandlePool::Borrow();
auto h = DnnHandlePool::Borrow(/*stream=*/nullptr);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;

View File

@ -57,7 +57,7 @@ SolverType DtypeToSolverType(const py::dtype& np_type) {
std::pair<int, py::bytes> BuildGetrfDescriptor(const py::dtype& dtype, int b,
int m, int n) {
SolverType type = DtypeToSolverType(dtype);
auto h = SolverHandlePool::Borrow();
auto h = SolverHandlePool::Borrow(/*stream=*/nullptr);
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
@ -96,7 +96,7 @@ std::pair<int, py::bytes> BuildGetrfDescriptor(const py::dtype& dtype, int b,
std::pair<int, py::bytes> BuildGeqrfDescriptor(const py::dtype& dtype, int b,
int m, int n) {
SolverType type = DtypeToSolverType(dtype);
auto h = SolverHandlePool::Borrow();
auto h = SolverHandlePool::Borrow(/*stream=*/nullptr);
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
@ -148,7 +148,7 @@ py::bytes BuildCsrlsvqrDescriptor(const py::dtype& dtype, int n, int nnzA,
std::pair<int, py::bytes> BuildOrgqrDescriptor(const py::dtype& dtype, int b,
int m, int n, int k) {
SolverType type = DtypeToSolverType(dtype);
auto h = SolverHandlePool::Borrow();
auto h = SolverHandlePool::Borrow(/*stream=*/nullptr);
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
@ -191,7 +191,7 @@ std::pair<int, py::bytes> BuildOrgqrDescriptor(const py::dtype& dtype, int b,
std::pair<int, py::bytes> BuildSyevdDescriptor(const py::dtype& dtype,
bool lower, int b, int n) {
SolverType type = DtypeToSolverType(dtype);
auto h = SolverHandlePool::Borrow();
auto h = SolverHandlePool::Borrow(/*stream=*/nullptr);
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
@ -230,7 +230,7 @@ std::pair<int, py::bytes> BuildSyevdDescriptor(const py::dtype& dtype,
std::pair<int, py::bytes> BuildSyevjDescriptor(const py::dtype& dtype,
bool lower, int batch, int n) {
SolverType type = DtypeToSolverType(dtype);
auto h = SolverHandlePool::Borrow();
auto h = SolverHandlePool::Borrow(/*stream=*/nullptr);
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
@ -298,7 +298,7 @@ std::pair<int, py::bytes> BuildGesvdDescriptor(const py::dtype& dtype, int b,
int m, int n, bool compute_uv,
bool full_matrices) {
SolverType type = DtypeToSolverType(dtype);
auto h = SolverHandlePool::Borrow();
auto h = SolverHandlePool::Borrow(/*stream=*/nullptr);
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
@ -343,7 +343,7 @@ std::pair<int, py::bytes> BuildGesvdjDescriptor(const py::dtype& dtype,
int batch, int m, int n,
bool compute_uv, int econ) {
SolverType type = DtypeToSolverType(dtype);
auto h = SolverHandlePool::Borrow();
auto h = SolverHandlePool::Borrow(/*stream=*/nullptr);
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
@ -426,7 +426,7 @@ std::pair<int, py::bytes> BuildGesvdjDescriptor(const py::dtype& dtype,
std::pair<int, py::bytes> BuildSytrdDescriptor(const py::dtype& dtype,
bool lower, int b, int n) {
SolverType type = DtypeToSolverType(dtype);
auto h = SolverHandlePool::Borrow();
auto h = SolverHandlePool::Borrow(/*stream=*/nullptr);
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;

View File

@ -108,7 +108,7 @@ DenseVecDescriptor BuildDenseVecDescriptor(const py::dtype& data_dtype,
std::pair<size_t, py::bytes> BuildCsrToDenseDescriptor(
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
int cols, int nnz) {
auto h = SparseHandlePool::Borrow();
auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
SparseMatDescriptor d =
@ -185,7 +185,7 @@ void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque,
std::pair<size_t, py::bytes> BuildCsrFromDenseDescriptor(
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
int cols, int nnz) {
auto h = SparseHandlePool::Borrow();
auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
SparseMatDescriptor d =
@ -262,7 +262,7 @@ std::pair<size_t, py::bytes> BuildCsrMatvecDescriptor(
const py::dtype& data_dtype, const py::dtype& x_dtype,
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
int cols, int nnz, bool transpose) {
auto h = SparseHandlePool::Borrow();
auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
SparseMatDescriptor A =
@ -310,7 +310,7 @@ std::pair<size_t, py::bytes> BuildCsrMatmatDescriptor(
const py::dtype& data_dtype, const py::dtype& b_dtype,
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
int cols, int BCcols, int nnz, bool transpose) {
auto h = SparseHandlePool::Borrow();
auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
SparseMatDescriptor A =
@ -361,7 +361,7 @@ std::pair<size_t, py::bytes> BuildCsrMatmatDescriptor(
std::pair<size_t, py::bytes> BuildCooToDenseDescriptor(
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
int cols, int nnz) {
auto h = SparseHandlePool::Borrow();
auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
SparseMatDescriptor d =
@ -398,7 +398,7 @@ std::pair<size_t, py::bytes> BuildCooToDenseDescriptor(
std::pair<size_t, py::bytes> BuildCooFromDenseDescriptor(
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
int cols, int nnz) {
auto h = SparseHandlePool::Borrow();
auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
SparseMatDescriptor d =
@ -435,7 +435,7 @@ std::pair<size_t, py::bytes> BuildCooMatvecDescriptor(
const py::dtype& data_dtype, const py::dtype& x_dtype,
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
int cols, int nnz, bool transpose) {
auto h = SparseHandlePool::Borrow();
auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
SparseMatDescriptor A =
@ -489,7 +489,7 @@ std::pair<size_t, py::bytes> BuildCooMatmatDescriptor(
// All three matrices A, B, and C must have the same batch count.
// Use batch stride to trigger individual mode, e.g.,
// `rhs_batch_stride = 0` for C_i = A_i B.
auto h = SparseHandlePool::Borrow();
auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
@ -554,7 +554,7 @@ py::bytes BuildGtsv2Descriptor(int m, int n, int ldb) {
template <typename F>
size_t Gtsv2BufferSize(F f, int m, int n, int ldb) {
auto h = SparseHandlePool::Borrow();
auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
size_t size;

View File

@ -550,7 +550,7 @@ void CooMatmat(gpuStream_t stream, void** buffers, const char* opaque,
template <typename T, typename F>
static absl::Status gtsv2(F computeGtsv2, gpuStream_t stream, void** buffers,
const char* opaque, std::size_t opaque_len) {
auto h = SparseHandlePool::Borrow();
auto h = SparseHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;

View File

@ -77,7 +77,7 @@ class HandlePool {
// Borrows a handle from the pool. If 'stream' is non-null, sets the stream
// associated with the handle.
static absl::StatusOr<Handle> Borrow(StreamType stream = nullptr);
static absl::StatusOr<Handle> Borrow(StreamType stream);
private:
static HandlePool<HandleType, StreamType>* Instance();