From 2693afa263ed651404098fd98ea15b2a8c605a9e Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Fri, 7 Oct 2022 14:35:41 -0700
Subject: [PATCH] Revert: Use input-output aliasing for jaxlib GPU custom
 calls.

Previously we had no way to tell XLA that the inputs and outputs of GPU
custom calls must alias. This now works in XLA:GPU, so we can simply ask XLA
to enforce the aliasing we need.

This seems to be causing some test failures downstream, so I am reverting it
for the moment until I can debug them.

PiperOrigin-RevId: 479670565
---
 jaxlib/cuda/cublas_kernels.cc    | 15 ++++++++
 jaxlib/cuda/cusolver_kernels.cc  | 49 ++++++++++++++++++++++++++
 jaxlib/cuda/cusparse_kernels.cc  | 25 ++++++++-----
 jaxlib/gpu_solver.py             | 31 ++++++-----------
 jaxlib/gpu_sparse.py             |  3 +-
 jaxlib/mhlo_helpers.py           | 60 +++++++++++---------------------
 jaxlib/rocm/hipblas_kernels.cc   | 15 ++++++++
 jaxlib/rocm/hipsolver_kernels.cc | 44 +++++++++++++++++++++++
 jaxlib/rocm/hipsparse_kernels.cc | 12 +++++--
 9 files changed, 180 insertions(+), 74 deletions(-)
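[ Reviewer note, not part of the patch to apply: this revert removes the
  `operand_output_aliases` parameter from the `custom_call` helper and
  restores an explicit device-to-device copy in every kernel. As a sketch of
  the API being removed, mirroring the _potrf_mhlo hunk below; `a`, `a_type`,
  `info_type`, `work_type`, `layout`, `info_layout`, `work_layout`, and
  `opaque` are placeholders rather than values from jaxlib:

    out = custom_call(
        "cusolver_potrf",
        out_types=[a_type, info_type, work_type],
        operands=[a],
        backend_config=opaque,
        operand_layouts=[layout],
        result_layouts=[layout, info_layout, work_layout],
        # Pre-revert: operand 0 shares a buffer with result 0, so the kernel
        # could factor in place without an explicit copy.
        operand_output_aliases={0: 0})

  After this patch the argument is gone, and each C++ kernel instead copies
  its input buffer to its output buffer whenever the two do not alias. ]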
diff --git a/jaxlib/cuda/cublas_kernels.cc b/jaxlib/cuda/cublas_kernels.cc
index cc39f9995..2669ae06c 100644
--- a/jaxlib/cuda/cublas_kernels.cc
+++ b/jaxlib/cuda/cublas_kernels.cc
@@ -83,6 +83,11 @@ static absl::Status TrsmBatched_(cudaStream_t stream, void** buffers,
   auto h = BlasHandlePool::Borrow(stream);
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;
+  if (buffers[2] != buffers[1]) {
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
+        buffers[2], buffers[1], SizeOfCublasType(d.type) * d.batch * d.m * d.n,
+        cudaMemcpyDeviceToDevice, stream)));
+  }
   const int lda = d.side == CUBLAS_SIDE_LEFT ? d.m : d.n;
   const int ldb = d.m;
   auto a_batch_host = MakeBatchPointers(stream, buffers[0], buffers[3], d.batch,
@@ -162,6 +167,11 @@ static absl::Status GetrfBatched_(cudaStream_t stream, void** buffers,
   auto h = BlasHandlePool::Borrow(stream);
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;
+  if (buffers[0] != buffers[1]) {
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
+        buffers[1], buffers[0], SizeOfCublasType(d.type) * d.batch * d.n * d.n,
+        cudaMemcpyDeviceToDevice, stream)));
+  }
   int* ipiv = static_cast<int*>(buffers[2]);
   int* info = static_cast<int*>(buffers[3]);
 
@@ -220,6 +230,11 @@ static absl::Status GeqrfBatched_(cudaStream_t stream, void** buffers,
   auto h = BlasHandlePool::Borrow(stream);
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;
+  if (buffers[0] != buffers[1]) {
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
+        buffers[1], buffers[0], SizeOfCublasType(d.type) * d.batch * d.m * d.n,
+        cudaMemcpyDeviceToDevice, stream)));
+  }
   std::vector<int> info(d.batch);
   auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[3], d.batch,
diff --git a/jaxlib/cuda/cusolver_kernels.cc b/jaxlib/cuda/cusolver_kernels.cc
index 8f5baa3d6..7979eaceb 100644
--- a/jaxlib/cuda/cusolver_kernels.cc
+++ b/jaxlib/cuda/cusolver_kernels.cc
@@ -94,6 +94,12 @@ static absl::Status Potrf_(cudaStream_t stream, void** buffers,
   auto h = SolverHandlePool::Borrow(stream);
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;
+  if (buffers[1] != buffers[0]) {
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
+        cudaMemcpyAsync(buffers[1], buffers[0],
+                        SizeOfCusolverType(d.type) * d.batch * d.n * d.n,
+                        cudaMemcpyDeviceToDevice, stream)));
+  }
   int* info = static_cast<int*>(buffers[2]);
   void* workspace = buffers[3];
 
@@ -186,6 +192,13 @@ static absl::Status Getrf_(cudaStream_t stream, void** buffers,
   auto h = SolverHandlePool::Borrow(stream);
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;
+  if (buffers[1] != buffers[0]) {
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
+        buffers[1], buffers[0],
+        SizeOfCusolverType(d.type) * static_cast<std::int64_t>(d.batch) *
+            static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
+        cudaMemcpyDeviceToDevice, stream)));
+  }
   int* ipiv = static_cast<int*>(buffers[2]);
   int* info = static_cast<int*>(buffers[3]);
 
@@ -262,6 +275,13 @@ static absl::Status Geqrf_(cudaStream_t stream, void** buffers,
   auto h = SolverHandlePool::Borrow(stream);
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;
+  if (buffers[1] != buffers[0]) {
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
+        buffers[1], buffers[0],
+        SizeOfCusolverType(d.type) * static_cast<std::int64_t>(d.batch) *
+            static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
+        cudaMemcpyDeviceToDevice, stream)));
+  }
   int* info = static_cast<int*>(buffers[3]);
   void* workspace = buffers[4];
 
@@ -438,6 +458,13 @@ static absl::Status Orgqr_(cudaStream_t stream, void** buffers,
   auto h = SolverHandlePool::Borrow(stream);
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;
+  if (buffers[2] != buffers[0]) {
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
+        buffers[2], buffers[0],
+        SizeOfCusolverType(d.type) * static_cast<std::int64_t>(d.batch) *
+            static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
+        cudaMemcpyDeviceToDevice, stream)));
+  }
   int* info = static_cast<int*>(buffers[3]);
   void* workspace = buffers[4];
 
@@ -517,6 +544,11 @@ static absl::Status Syevd_(cudaStream_t stream, void** buffers,
   auto h = SolverHandlePool::Borrow(stream);
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
+      buffers[1], buffers[0],
+      SizeOfCusolverType(d.type) * static_cast<std::int64_t>(d.batch) *
+          static_cast<std::int64_t>(d.n) * static_cast<std::int64_t>(d.n),
+      cudaMemcpyDeviceToDevice, stream)));
   cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR;
   int* info = static_cast<int*>(buffers[3]);
   void* work = buffers[4];
@@ -597,6 +629,13 @@ absl::Status Syevj_(cudaStream_t stream, void** buffers, const char* opaque,
   auto h = SolverHandlePool::Borrow(stream);
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;
+  if (buffers[1] != buffers[0]) {
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
+        buffers[1], buffers[0],
+        SizeOfCusolverType(d.type) * static_cast<std::int64_t>(d.batch) *
+            static_cast<std::int64_t>(d.n) * static_cast<std::int64_t>(d.n),
+        cudaMemcpyDeviceToDevice, stream)));
+  }
   syevjInfo_t params;
   JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCreateSyevjInfo(&params)));
   std::unique_ptr<syevjInfo, void (*)(syevjInfo*)> params_cleanup(
@@ -699,6 +738,11 @@ static absl::Status Gesvd_(cudaStream_t stream, void** buffers,
   auto h = SolverHandlePool::Borrow(stream);
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
+      buffers[1], buffers[0],
+      SizeOfCusolverType(d.type) * static_cast<std::int64_t>(d.batch) *
+          static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
+      cudaMemcpyDeviceToDevice, stream)));
   int* info = static_cast<int*>(buffers[5]);
   void* work = buffers[6];
   int64_t k = d.jobu == 'A' ? d.m : d.n;
@@ -797,6 +841,11 @@ static absl::Status Gesvdj_(cudaStream_t stream, void** buffers,
   auto h = SolverHandlePool::Borrow(stream);
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
+      buffers[1], buffers[0],
+      SizeOfCusolverType(d.type) * static_cast<std::int64_t>(d.batch) *
+          static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
+      cudaMemcpyDeviceToDevice, stream)));
   int* info = static_cast<int*>(buffers[5]);
   void* work = buffers[6];
   gesvdjInfo_t params;
diff --git a/jaxlib/cuda/cusparse_kernels.cc b/jaxlib/cuda/cusparse_kernels.cc
index 0d6312736..44e2877d3 100644
--- a/jaxlib/cuda/cusparse_kernels.cc
+++ b/jaxlib/cuda/cusparse_kernels.cc
@@ -520,24 +520,25 @@ static absl::Status CooMatmat_(cudaStream_t stream, void** buffers,
   cusparseDnMatDescr_t mat_b = 0;
   cusparseDnMatDescr_t mat_c = 0;
+
   JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateCoo(
       &mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind,
       coo_values, d.A.index_type, CUSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
   JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
-       cusparseCooSetStridedBatch(mat_a, /*batchCount=*/d.A.batch_count,
-                                  /*batchStride=*/d.A.batch_stride)));
+      cusparseCooSetStridedBatch(mat_a, /*batchCount=*/d.A.batch_count,
+                                 /*batchStride=*/d.A.batch_stride)));
   JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
       &mat_b, d.B.rows, d.B.cols, /*ld=*/d.B.cols, Bbuf, d.B.type,
       CUSPARSE_ORDER_ROW)));
   JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
-       cusparseDnMatSetStridedBatch(mat_b, /*batchCount=*/d.B.batch_count,
-                                    /*batchStride=*/d.B.batch_stride)));
+      cusparseDnMatSetStridedBatch(mat_b, /*batchCount=*/d.B.batch_count,
+                                   /*batchStride=*/d.B.batch_stride)));
   JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
       &mat_c, d.C.rows, d.C.cols, /*ld=*/d.C.cols, Cbuf, d.C.type,
       CUSPARSE_ORDER_ROW)));
   JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
-       cusparseDnMatSetStridedBatch(mat_c, /*batchCount=*/d.C.batch_count,
-                                    /*batchStride=*/d.C.batch_stride)));
+      cusparseDnMatSetStridedBatch(mat_c, /*batchCount=*/d.C.batch_count,
+                                   /*batchStride=*/d.C.batch_stride)));
   JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseSpMM(
       handle.get(), d.op_A, /*opB=*/CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
       mat_a, mat_b, &beta, mat_c, d.C.type, CUSPARSE_SPMM_ALG_DEFAULT, buf)));
@@ -579,10 +580,16 @@ static absl::Status gtsv2(F computeGtsv2, cudaStream_t stream, void** buffers,
   T* X = (T*)(buffers[4]);
   void* buffer = buffers[5];
 
-  // The solution X is written in place to B.
+  // The solution X is written in place to B. We therefore need to copy the
+  // contents of B into the output buffer X and pass X to the kernel as B.
+  // Once copy insertion is supported for custom-call aliasing, we could alias
+  // B with X and avoid the copy; the code below is written defensively,
+  // assuming B and X might alias, even though today we know they do not.
+  // TODO(b/182906199): Update the comment here once copy insertion is WAI.
   if (X != B) {
-    return absl::InvalidArgumentError(
-        "Input and output buffers to gtsv2 must alias");
+    size_t B_bytes = ldb * n * sizeof(T);
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
+        cudaMemcpyAsync(X, B, B_bytes, cudaMemcpyDeviceToDevice, stream)));
   }
 
   JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py
index 012e95b62..b715c1e54 100644
--- a/jaxlib/gpu_solver.py
+++ b/jaxlib/gpu_solver.py
@@ -98,8 +98,7 @@ def _trsm_mhlo(platform, gpu_blas, dtype, a, b, left_side=False, lower=False,
       [a, b],
       backend_config=opaque,
       operand_layouts=[layout] * 2,
-      result_layouts=[layout, work_layout, work_layout],
-      operand_output_aliases={1: 0})
+      result_layouts=[layout, work_layout, work_layout])
   return out[0]
 
 cuda_trsm = partial(_trsm_mhlo, "cu", _cublas)
@@ -134,8 +133,7 @@ def _potrf_mhlo(platform, gpu_solver, dtype, a, lower):
       [a],
       backend_config=opaque,
       operand_layouts=[layout],
-      result_layouts=[layout, info_layout, work_layout],
-      operand_output_aliases={0: 0})
+      result_layouts=[layout, info_layout, work_layout])
   return out[:2]
 
 cuda_potrf = partial(_potrf_mhlo, "cu", _cusolver)
@@ -181,8 +179,7 @@ def _getrf_mhlo(platform, gpu_blas, gpu_solver, dtype, a):
           tuple(range(num_bd, -1, -1)),
           tuple(range(num_bd - 1, -1, -1)),
           [0],
-      ],
-      operand_output_aliases={0: 0})
+      ])
   return out[:3]
 
 cuda_getrf = partial(_getrf_mhlo, "cu", _cublas, _cusolver)
@@ -220,8 +217,7 @@ def _geqrf_mhlo(platform, gpu_solver, dtype, a):
           tuple(range(num_bd, -1, -1)),
           tuple(range(num_bd - 1, -1, -1)),
          [0],
-      ],
-      operand_output_aliases={0: 0})
+      ])
   return out[:3]
 
 cuda_geqrf = partial(_geqrf_mhlo, "cu", _cusolver)
@@ -257,9 +253,7 @@ def _geqrf_batched_mhlo(platform, gpu_blas, dtype, a):
           tuple(range(num_bd, -1, -1)),
           [0],
           [0],
-      ],
-      operand_output_aliases={0: 0}
-      )
+      ])
   return out[:2]
 
 cuda_geqrf_batched = partial(_geqrf_batched_mhlo, "cu", _cublas)
@@ -327,8 +321,7 @@ def _orgqr_mhlo(platform, gpu_solver, dtype, a, tau):
           layout,
           tuple(range(num_bd - 1, -1, -1)),
           [0],
-      ],
-      operand_output_aliases={0: 0})
+      ])
   return out[:2]
 
 cuda_orgqr = partial(_orgqr_mhlo, "cu", _cusolver)
@@ -379,8 +372,7 @@ def _syevd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
           tuple(range(num_bd, -1, -1)),
           tuple(range(num_bd - 1, -1, -1)),
           [0],
-      ],
-      operand_output_aliases={0: 0})
+      ])
   return out[:3]
 
 cuda_syevd = partial(_syevd_mhlo, "cu", _cusolver, True)
@@ -435,8 +427,7 @@ def _gesvd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
           matrix_layout,
           scalar_layout,
           [0],
-      ],
-      operand_output_aliases={0: 0})
+      ])
     vt = mhlo.TransposeOp(
         v,
         ir.DenseIntElementsAttr.get(np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd)))).result
@@ -478,8 +469,7 @@ def _gesvd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
           matrix_layout,
           scalar_layout,
           [0],
-      ],
-      operand_output_aliases={0: 0})
+      ])
   else:
     lwork, opaque = gpu_solver.build_gesvd_descriptor(
         np.dtype(dtype), b, m, n, compute_uv, full_matrices)
@@ -505,8 +495,7 @@ def _gesvd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
           matrix_layout,
           scalar_layout,
           [0],
-      ],
-      operand_output_aliases={0: 0})
+      ])
   return s, u, vt, info
 
 cuda_gesvd = partial(_gesvd_mhlo, "cu", _cusolver, True)
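[ Reviewer note, not part of the patch: across the bindings above, each
  removed `operand_output_aliases` dict maps an operand index to the result
  index that had to share its buffer. A compact summary as a sketch, with a
  made-up variable name (the gtsv2 entry comes from jaxlib/gpu_sparse.py
  below):

    # operand index -> result index that aliased it, pre-revert
    removed_aliases = {
        "trsm":  {1: 0},  # operand b aliased result 0, the solution
        "potrf": {0: 0},  # a aliased the factor; getrf, geqrf, and orgqr
        "syevd": {0: 0},  # follow the same pattern, as do syevj and gesvd
        "gtsv2": {3: 0},  # operand B aliased the solution X
    }
]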
diff --git a/jaxlib/gpu_sparse.py b/jaxlib/gpu_sparse.py
index 99d459d6a..1299953a5 100644
--- a/jaxlib/gpu_sparse.py
+++ b/jaxlib/gpu_sparse.py
@@ -356,8 +356,7 @@ def _gtsv2_mhlo(platform, gpu_sparse, dl, d, du, B, *, m, n, ldb, t):
       [dl, d, du, B],
       backend_config=gpu_sparse.build_gtsv2_descriptor(m, n, ldb),
       operand_layouts=[[0]] * 3 + [[1, 0]],
-      result_layouts=[[1, 0], [0]],
-      operand_output_aliases={3: 0})
+      result_layouts=[[1, 0], [0]])
   return out[0]
 
 cuda_gtsv2 = partial(_gtsv2_mhlo, "cu", _cusparse)
diff --git a/jaxlib/mhlo_helpers.py b/jaxlib/mhlo_helpers.py
index 904d9bba3..aa6d62dce 100644
--- a/jaxlib/mhlo_helpers.py
+++ b/jaxlib/mhlo_helpers.py
@@ -14,37 +14,28 @@
 # Helpers for building MHLO operators
 
-from typing import Dict, Optional, Sequence, Union
+from typing import Optional, Sequence, Union
 
 import jaxlib.mlir.ir as ir
 import jaxlib.mlir.dialects.mhlo as mhlo
 import numpy as np
 
-
-def custom_call(
-    call_target_name: str,
-    out_types: Sequence[ir.Type],
-    operands: Sequence[ir.Value],
-    operand_layouts: Sequence[Sequence[int]],
-    result_layouts: Sequence[Sequence[int]],
-    backend_config: Optional[str] = None,
-    has_side_effect: bool = False,
-    api_version: int = 2,
-    operand_output_aliases: Dict[int, int] = {},
-) -> Union[ir.Value, Sequence[ir.Value]]:
+def custom_call(call_target_name: str, out_types: Sequence[ir.Type],
+                operands: Sequence[ir.Value],
+                operand_layouts: Sequence[Sequence[int]],
+                result_layouts: Sequence[Sequence[int]],
+                backend_config: Optional[str] = None,
+                has_side_effect: bool = False,
+                api_version: int = 2,
+                ) -> Union[ir.Value, Sequence[ir.Value]]:
   """Less-verbose helper for building an MHLO custom call op.
 
   Once https://github.com/llvm/llvm-project/issues/54932 is fixed, this helper
   may be able to go away.
-
-  Args:
-    ...
-    operand_output_alias: a dictionary mapping input numbers -> output numbers
-      that must alias.
   """
   i32_type = ir.IntegerType.get_signless(32)
   out = mhlo.CustomCallOp(
-      (out_types
-       if len(out_types) == 1 else [ir.TupleType.get_tuple(out_types)]),
+      (out_types if len(out_types) == 1 else
+       [ir.TupleType.get_tuple(out_types)]),
       operands,
       call_target_name=ir.StringAttr.get(call_target_name),
       has_side_effect=ir.BoolAttr.get(has_side_effect),
@@ -52,27 +43,18 @@ def custom_call(call_target_name: str, out_types: Sequence[ir.Type],
           "" if backend_config is None else backend_config),
       api_version=ir.IntegerAttr.get(i32_type, api_version),
       called_computations=ir.ArrayAttr.get([]),
-      operand_layouts=ir.ArrayAttr.get([
-          ir.DenseIntElementsAttr.get(
-              np.atleast_1d(np.asarray(l, dtype=np.int64)),
-              type=ir.IndexType.get()) for l in operand_layouts
-      ]),
-      result_layouts=ir.ArrayAttr.get([
-          ir.DenseIntElementsAttr.get(
-              np.atleast_1d(np.asarray(l, dtype=np.int64)),
-              type=ir.IndexType.get()) for l in result_layouts
-      ]),
-      output_operand_aliases=ir.ArrayAttr.get([
-          mhlo.OutputOperandAlias.get(
-              output_tuple_indices=[] if len(out_types) == 1 else [output],
-              operand_index=input,
-              operand_tuple_indices=[])
-          for input, output in operand_output_aliases.items()
-      ]))
+      operand_layouts=ir.ArrayAttr.get(
+          [ir.DenseIntElementsAttr.get(
+              np.atleast_1d(np.asarray(l, dtype=np.int64)),
+              type=ir.IndexType.get())
+           for l in operand_layouts]),
+      result_layouts=ir.ArrayAttr.get(
+          [ir.DenseIntElementsAttr.get(
+              np.atleast_1d(np.asarray(l, dtype=np.int64)),
+              type=ir.IndexType.get())
+           for l in result_layouts]))
   if len(out_types) == 1:
     return out.result
   else:
-    return [
-        mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
-        for i in range(len(out_types))
-    ]
+    return [mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
+            for i in range(len(out_types))]
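[ Reviewer note, not part of the patch: after this change `custom_call`
  carries no aliasing information. A minimal call under the reverted helper
  signature looks like the following sketch; the target name, operand `x`,
  `out_type`, the layouts, and `opaque` are placeholders:

    result = custom_call(
        "my_gpu_kernel",           # hypothetical registered target name
        out_types=[out_type],      # one result type => one ir.Value returned
        operands=[x],
        operand_layouts=[[1, 0]],  # row-major layout for a rank-2 operand
        result_layouts=[[1, 0]],
        backend_config=opaque)     # serialized descriptor bytes, or None
]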
diff --git a/jaxlib/rocm/hipblas_kernels.cc b/jaxlib/rocm/hipblas_kernels.cc
index a0047e96e..948e7dad8 100644
--- a/jaxlib/rocm/hipblas_kernels.cc
+++ b/jaxlib/rocm/hipblas_kernels.cc
@@ -82,6 +82,11 @@ static absl::Status TrsmBatched_(hipStream_t stream, void** buffers,
   auto h = BlasHandlePool::Borrow(stream);
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;
+  if (buffers[2] != buffers[1]) {
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
+        buffers[2], buffers[1], SizeOfHipblasType(d.type) * d.batch * d.m * d.n,
+        hipMemcpyDeviceToDevice, stream)));
+  }
   const int lda = d.side == HIPBLAS_SIDE_LEFT ? d.m : d.n;
   const int ldb = d.m;
   auto a_batch_host = MakeBatchPointers(stream, buffers[0], buffers[3], d.batch,
@@ -161,6 +166,11 @@ static absl::Status GetrfBatched_(hipStream_t stream, void** buffers,
   auto h = BlasHandlePool::Borrow(stream);
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;
+  if (buffers[0] != buffers[1]) {
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
+        buffers[1], buffers[0], SizeOfHipblasType(d.type) * d.batch * d.n * d.n,
+        hipMemcpyDeviceToDevice, stream)));
+  }
   int* ipiv = static_cast<int*>(buffers[2]);
   int* info = static_cast<int*>(buffers[3]);
 
@@ -220,6 +230,11 @@ static absl::Status GeqrfBatched_(hipStream_t stream, void** buffers,
   auto h = BlasHandlePool::Borrow(stream);
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;
+  if (buffers[0] != buffers[1]) {
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
+        buffers[1], buffers[0], SizeOfHipblasType(d.type) * d.batch * d.m * d.n,
+        hipMemcpyDeviceToDevice, stream)));
+  }
   std::vector<int> info(d.batch);
   auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[3], d.batch,
diff --git a/jaxlib/rocm/hipsolver_kernels.cc b/jaxlib/rocm/hipsolver_kernels.cc
index cd7e33957..9f5f14128 100644
--- a/jaxlib/rocm/hipsolver_kernels.cc
+++ b/jaxlib/rocm/hipsolver_kernels.cc
@@ -74,6 +74,12 @@ static absl::Status Potrf_(hipStream_t stream, void** buffers,
   auto h = SolverHandlePool::Borrow(stream);
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;
+  if (buffers[1] != buffers[0]) {
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
+        hipMemcpyAsync(buffers[1], buffers[0],
+                       SizeOfHipsolverType(d.type) * d.batch * d.n * d.n,
+                       hipMemcpyDeviceToDevice, stream)));
+  }
   int* info = static_cast<int*>(buffers[2]);
   void* workspace = buffers[3];
 
@@ -169,6 +175,13 @@ static absl::Status Getrf_(hipStream_t stream, void** buffers,
   auto h = SolverHandlePool::Borrow(stream);
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;
+  if (buffers[1] != buffers[0]) {
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
+        buffers[1], buffers[0],
+        SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
+            static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
+        hipMemcpyDeviceToDevice, stream)));
+  }
   int* ipiv = static_cast<int*>(buffers[2]);
   int* info = static_cast<int*>(buffers[3]);
 
@@ -245,6 +258,13 @@ static absl::Status Geqrf_(hipStream_t stream, void** buffers,
   auto h = SolverHandlePool::Borrow(stream);
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;
+  if (buffers[1] != buffers[0]) {
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
+        buffers[1], buffers[0],
+        SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
+            static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
+        hipMemcpyDeviceToDevice, stream)));
+  }
   int* info = static_cast<int*>(buffers[3]);
 
@@ -325,6 +345,13 @@ static absl::Status Orgqr_(hipStream_t stream, void** buffers,
   auto h = SolverHandlePool::Borrow(stream);
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;
+  if (buffers[2] != buffers[0]) {
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
+        buffers[2], buffers[0],
+        SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
+            static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
+        hipMemcpyDeviceToDevice, stream)));
+  }
   int* info = static_cast<int*>(buffers[3]);
 
@@ -405,6 +432,11 @@ static absl::Status Syevd_(hipStream_t stream, void** buffers,
   auto h = SolverHandlePool::Borrow(stream);
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
+      buffers[1], buffers[0],
+      SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
+          static_cast<std::int64_t>(d.n) * static_cast<std::int64_t>(d.n),
+      hipMemcpyDeviceToDevice, stream)));
   hipsolverEigMode_t jobz = HIPSOLVER_EIG_MODE_VECTOR;
   int* info = static_cast<int*>(buffers[3]);
   void* work = buffers[4];
@@ -485,6 +517,13 @@ absl::Status Syevj_(hipStream_t stream, void** buffers, const char* opaque,
   auto h = SolverHandlePool::Borrow(stream);
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;
+  if (buffers[1] != buffers[0]) {
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
+        buffers[1], buffers[0],
+        SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
+            static_cast<std::int64_t>(d.n) * static_cast<std::int64_t>(d.n),
+        hipMemcpyDeviceToDevice, stream)));
+  }
   hipsolverSyevjInfo_t params;
   JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCreateSyevjInfo(&params)));
   std::unique_ptr<hipsolverSyevjInfo, void (*)(hipsolverSyevjInfo*)> params_cleanup(
@@ -586,6 +625,11 @@ static absl::Status Gesvd_(hipStream_t stream, void** buffers,
   auto h = SolverHandlePool::Borrow(stream);
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
+      buffers[1], buffers[0],
+      SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
+          static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
+      hipMemcpyDeviceToDevice, stream)));
   int* info = static_cast<int*>(buffers[5]);
   void* work = buffers[6];
   switch (d.type) {
diff --git a/jaxlib/rocm/hipsparse_kernels.cc b/jaxlib/rocm/hipsparse_kernels.cc
index c51d266c3..b9e061c9d 100644
--- a/jaxlib/rocm/hipsparse_kernels.cc
+++ b/jaxlib/rocm/hipsparse_kernels.cc
@@ -504,10 +504,16 @@ static absl::Status gtsv2(F computeGtsv2, hipStream_t stream, void** buffers,
   T* X = (T*)(buffers[4]);
   void* buffer = buffers[5];
 
-  // The solution X is written in place to B.
+  // The solution X is written in place to B. We therefore need to copy the
+  // contents of B into the output buffer X and pass X to the kernel as B.
+  // Once copy insertion is supported for custom-call aliasing, we could alias
+  // B with X and avoid the copy; the code below is written defensively,
+  // assuming B and X might alias, even though today we know they do not.
+  // TODO(b/182906199): Update the comment here once copy insertion is WAI.
   if (X != B) {
-    return absl::InvalidArgumentError(
-        "Input and output buffers to gtsv2 must alias");
+    size_t B_bytes = ldb * n * sizeof(T);
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
+        hipMemcpyAsync(X, B, B_bytes, hipMemcpyDeviceToDevice, stream)));
   }
 
   JAX_RETURN_IF_ERROR(JAX_AS_STATUS(