Revert: Use input-output aliasing for jaxlib GPU custom calls.

Previously we had no way to tell XLA that the inputs and outputs of GPU custom calls must alias. Input-output aliasing now works in XLA:GPU, so we can simply ask XLA to enforce the aliasing we need.

The aliasing change seems to be causing some test failures downstream, so I am reverting it for the moment until I can debug those failures.

PiperOrigin-RevId: 479670565
This commit is contained in:
Peter Hawkins 2022-10-07 14:35:41 -07:00 committed by jax authors
parent e8ba61d82b
commit 2693afa263
9 changed files with 180 additions and 74 deletions

View File

@ -83,6 +83,11 @@ static absl::Status TrsmBatched_(cudaStream_t stream, void** buffers,
auto h = BlasHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[2] != buffers[1]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
buffers[2], buffers[1], SizeOfCublasType(d.type) * d.batch * d.m * d.n,
cudaMemcpyDeviceToDevice, stream)));
}
const int lda = d.side == CUBLAS_SIDE_LEFT ? d.m : d.n;
const int ldb = d.m;
auto a_batch_host = MakeBatchPointers(stream, buffers[0], buffers[3], d.batch,
@ -162,6 +167,11 @@ static absl::Status GetrfBatched_(cudaStream_t stream, void** buffers,
auto h = BlasHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[0] != buffers[1]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
buffers[1], buffers[0], SizeOfCublasType(d.type) * d.batch * d.n * d.n,
cudaMemcpyDeviceToDevice, stream)));
}
int* ipiv = static_cast<int*>(buffers[2]);
int* info = static_cast<int*>(buffers[3]);
@ -220,6 +230,11 @@ static absl::Status GeqrfBatched_(cudaStream_t stream, void** buffers,
auto h = BlasHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[0] != buffers[1]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
buffers[1], buffers[0], SizeOfCublasType(d.type) * d.batch * d.m * d.n,
cudaMemcpyDeviceToDevice, stream)));
}
std::vector<int> info(d.batch);
auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[3], d.batch,

View File

@ -94,6 +94,12 @@ static absl::Status Potrf_(cudaStream_t stream, void** buffers,
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[1] != buffers[0]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cudaMemcpyAsync(buffers[1], buffers[0],
SizeOfCusolverType(d.type) * d.batch * d.n * d.n,
cudaMemcpyDeviceToDevice, stream)));
}
int* info = static_cast<int*>(buffers[2]);
void* workspace = buffers[3];
@ -186,6 +192,13 @@ static absl::Status Getrf_(cudaStream_t stream, void** buffers,
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[1] != buffers[0]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
buffers[1], buffers[0],
SizeOfCusolverType(d.type) * static_cast<std::int64_t>(d.batch) *
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
cudaMemcpyDeviceToDevice, stream)));
}
int* ipiv = static_cast<int*>(buffers[2]);
int* info = static_cast<int*>(buffers[3]);
@ -262,6 +275,13 @@ static absl::Status Geqrf_(cudaStream_t stream, void** buffers,
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[1] != buffers[0]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
buffers[1], buffers[0],
SizeOfCusolverType(d.type) * static_cast<std::int64_t>(d.batch) *
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
cudaMemcpyDeviceToDevice, stream)));
}
int* info = static_cast<int*>(buffers[3]);
void* workspace = buffers[4];
@ -438,6 +458,13 @@ static absl::Status Orgqr_(cudaStream_t stream, void** buffers,
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[2] != buffers[0]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
buffers[2], buffers[0],
SizeOfCusolverType(d.type) * static_cast<std::int64_t>(d.batch) *
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
cudaMemcpyDeviceToDevice, stream)));
}
int* info = static_cast<int*>(buffers[3]);
void* workspace = buffers[4];
@ -517,6 +544,11 @@ static absl::Status Syevd_(cudaStream_t stream, void** buffers,
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
buffers[1], buffers[0],
SizeOfCusolverType(d.type) * static_cast<std::int64_t>(d.batch) *
static_cast<std::int64_t>(d.n) * static_cast<std::int64_t>(d.n),
cudaMemcpyDeviceToDevice, stream)));
cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR;
int* info = static_cast<int*>(buffers[3]);
void* work = buffers[4];
@ -597,6 +629,13 @@ absl::Status Syevj_(cudaStream_t stream, void** buffers, const char* opaque,
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[1] != buffers[0]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
buffers[1], buffers[0],
SizeOfCusolverType(d.type) * static_cast<std::int64_t>(d.batch) *
static_cast<std::int64_t>(d.n) * static_cast<std::int64_t>(d.n),
cudaMemcpyDeviceToDevice, stream)));
}
syevjInfo_t params;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCreateSyevjInfo(&params)));
std::unique_ptr<syevjInfo, void (*)(syevjInfo*)> params_cleanup(
@ -699,6 +738,11 @@ static absl::Status Gesvd_(cudaStream_t stream, void** buffers,
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
buffers[1], buffers[0],
SizeOfCusolverType(d.type) * static_cast<std::int64_t>(d.batch) *
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
cudaMemcpyDeviceToDevice, stream)));
int* info = static_cast<int*>(buffers[5]);
void* work = buffers[6];
int64_t k = d.jobu == 'A' ? d.m : d.n;
@ -797,6 +841,11 @@ static absl::Status Gesvdj_(cudaStream_t stream, void** buffers,
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
buffers[1], buffers[0],
SizeOfCusolverType(d.type) * static_cast<std::int64_t>(d.batch) *
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
cudaMemcpyDeviceToDevice, stream)));
int* info = static_cast<int*>(buffers[5]);
void* work = buffers[6];
gesvdjInfo_t params;

View File

@ -520,24 +520,25 @@ static absl::Status CooMatmat_(cudaStream_t stream, void** buffers,
cusparseDnMatDescr_t mat_b = 0;
cusparseDnMatDescr_t mat_c = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateCoo(
&mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind, coo_values,
d.A.index_type, CUSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cusparseCooSetStridedBatch(mat_a, /*batchCount=*/d.A.batch_count,
/*batchStride=*/d.A.batch_stride)));
cusparseCooSetStridedBatch(mat_a, /*batchCount=*/d.A.batch_count,
/*batchStride=*/d.A.batch_stride)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
&mat_b, d.B.rows, d.B.cols,
/*ld=*/d.B.cols, Bbuf, d.B.type, CUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cusparseDnMatSetStridedBatch(mat_b, /*batchCount=*/d.B.batch_count,
/*batchStride=*/d.B.batch_stride)));
cusparseDnMatSetStridedBatch(mat_b, /*batchCount=*/d.B.batch_count,
/*batchStride=*/d.B.batch_stride)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
&mat_c, d.C.rows, d.C.cols,
/*ld=*/d.C.cols, Cbuf, d.C.type, CUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cusparseDnMatSetStridedBatch(mat_c, /*batchCount=*/d.C.batch_count,
/*batchStride=*/d.C.batch_stride)));
cusparseDnMatSetStridedBatch(mat_c, /*batchCount=*/d.C.batch_count,
/*batchStride=*/d.C.batch_stride)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseSpMM(
handle.get(), d.op_A, /*opB=*/CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
mat_a, mat_b, &beta, mat_c, d.C.type, CUSPARSE_SPMM_ALG_DEFAULT, buf)));
@ -579,10 +580,16 @@ static absl::Status gtsv2(F computeGtsv2, cudaStream_t stream, void** buffers,
T* X = (T*)(buffers[4]);
void* buffer = buffers[5];
// The solution X is written in place to B.
// The solution X is written in place to B. We need to therefore copy the
// contents of B into the output buffer X and pass that into the kernel as B.
// Once copy insertion is supported for custom call aliasing, we could alias B
// with X and avoid the copy, the code below is written defensively assuming B
// and X might alias, but today we know they will not.
// TODO(b/182906199): Update the comment here once copy insertion is WAI.
if (X != B) {
return absl::InvalidArgumentError(
"Input and output buffers to gtsv2 must alias");
size_t B_bytes = ldb * n * sizeof(T);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cudaMemcpyAsync(X, B, B_bytes, cudaMemcpyDeviceToDevice, stream)));
}
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(

View File

@ -98,8 +98,7 @@ def _trsm_mhlo(platform, gpu_blas, dtype, a, b, left_side=False, lower=False,
[a, b],
backend_config=opaque,
operand_layouts=[layout] * 2,
result_layouts=[layout, work_layout, work_layout],
operand_output_aliases={1: 0})
result_layouts=[layout, work_layout, work_layout])
return out[0]
cuda_trsm = partial(_trsm_mhlo, "cu", _cublas)
@ -134,8 +133,7 @@ def _potrf_mhlo(platform, gpu_solver, dtype, a, lower):
[a],
backend_config=opaque,
operand_layouts=[layout],
result_layouts=[layout, info_layout, work_layout],
operand_output_aliases={0: 0})
result_layouts=[layout, info_layout, work_layout])
return out[:2]
cuda_potrf = partial(_potrf_mhlo, "cu", _cusolver)
@ -181,8 +179,7 @@ def _getrf_mhlo(platform, gpu_blas, gpu_solver, dtype, a):
tuple(range(num_bd, -1, -1)),
tuple(range(num_bd - 1, -1, -1)),
[0],
],
operand_output_aliases={0: 0})
])
return out[:3]
cuda_getrf = partial(_getrf_mhlo, "cu", _cublas, _cusolver)
@ -220,8 +217,7 @@ def _geqrf_mhlo(platform, gpu_solver, dtype, a):
tuple(range(num_bd, -1, -1)),
tuple(range(num_bd - 1, -1, -1)),
[0],
],
operand_output_aliases={0: 0})
])
return out[:3]
cuda_geqrf = partial(_geqrf_mhlo, "cu", _cusolver)
@ -257,9 +253,7 @@ def _geqrf_batched_mhlo(platform, gpu_blas, dtype, a):
tuple(range(num_bd, -1, -1)),
[0],
[0],
],
operand_output_aliases={0: 0}
)
])
return out[:2]
cuda_geqrf_batched = partial(_geqrf_batched_mhlo, "cu", _cublas)
@ -327,8 +321,7 @@ def _orgqr_mhlo(platform, gpu_solver, dtype, a, tau):
layout,
tuple(range(num_bd - 1, -1, -1)),
[0],
],
operand_output_aliases={0: 0})
])
return out[:2]
cuda_orgqr = partial(_orgqr_mhlo, "cu", _cusolver)
@ -379,8 +372,7 @@ def _syevd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
tuple(range(num_bd, -1, -1)),
tuple(range(num_bd - 1, -1, -1)),
[0],
],
operand_output_aliases={0: 0})
])
return out[:3]
cuda_syevd = partial(_syevd_mhlo, "cu", _cusolver, True)
@ -435,8 +427,7 @@ def _gesvd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
matrix_layout,
scalar_layout,
[0],
],
operand_output_aliases={0: 0})
])
vt = mhlo.TransposeOp(
v,
ir.DenseIntElementsAttr.get(np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd)))).result
@ -478,8 +469,7 @@ def _gesvd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
matrix_layout,
scalar_layout,
[0],
],
operand_output_aliases={0: 0})
])
else:
lwork, opaque = gpu_solver.build_gesvd_descriptor(
np.dtype(dtype), b, m, n, compute_uv, full_matrices)
@ -505,8 +495,7 @@ def _gesvd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
matrix_layout,
scalar_layout,
[0],
],
operand_output_aliases={0: 0})
])
return s, u, vt, info
cuda_gesvd = partial(_gesvd_mhlo, "cu", _cusolver, True)

View File

@ -356,8 +356,7 @@ def _gtsv2_mhlo(platform, gpu_sparse, dl, d, du, B, *, m, n, ldb, t):
[dl, d, du, B],
backend_config=gpu_sparse.build_gtsv2_descriptor(m, n, ldb),
operand_layouts=[[0]] * 3 + [[1, 0]],
result_layouts=[[1, 0], [0]],
operand_output_aliases={3: 0})
result_layouts=[[1, 0], [0]])
return out[0]
cuda_gtsv2 = partial(_gtsv2_mhlo, "cu", _cusparse)

View File

@ -14,37 +14,28 @@
# Helpers for building MHLO operators
from typing import Dict, Optional, Sequence, Union
from typing import Optional, Sequence, Union
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as mhlo
import numpy as np
def custom_call(
call_target_name: str,
out_types: Sequence[ir.Type],
operands: Sequence[ir.Value],
operand_layouts: Sequence[Sequence[int]],
result_layouts: Sequence[Sequence[int]],
backend_config: Optional[str] = None,
has_side_effect: bool = False,
api_version: int = 2,
operand_output_aliases: Dict[int, int] = {},
) -> Union[ir.Value, Sequence[ir.Value]]:
def custom_call(call_target_name: str, out_types: Sequence[ir.Type],
operands: Sequence[ir.Value],
operand_layouts: Sequence[Sequence[int]],
result_layouts: Sequence[Sequence[int]],
backend_config: Optional[str] = None,
has_side_effect: bool = False,
api_version: int = 2,
) -> Union[ir.Value, Sequence[ir.Value]]:
"""Less-verbose helper for building an MHLO custom call op.
Once https://github.com/llvm/llvm-project/issues/54932 is fixed, this helper
may be able to go away.
Args:
...
operand_output_alias: a dictionary mapping input numbers -> output numbers
that must alias.
"""
i32_type = ir.IntegerType.get_signless(32)
out = mhlo.CustomCallOp(
(out_types
if len(out_types) == 1 else [ir.TupleType.get_tuple(out_types)]),
(out_types if len(out_types) == 1 else
[ir.TupleType.get_tuple(out_types)]),
operands,
call_target_name=ir.StringAttr.get(call_target_name),
has_side_effect=ir.BoolAttr.get(has_side_effect),
@ -52,27 +43,18 @@ def custom_call(
"" if backend_config is None else backend_config),
api_version=ir.IntegerAttr.get(i32_type, api_version),
called_computations=ir.ArrayAttr.get([]),
operand_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(
operand_layouts=ir.ArrayAttr.get(
[ir.DenseIntElementsAttr.get(
np.atleast_1d(np.asarray(l, dtype=np.int64)),
type=ir.IndexType.get()) for l in operand_layouts
]),
result_layouts=ir.ArrayAttr.get([
ir.DenseIntElementsAttr.get(
type=ir.IndexType.get())
for l in operand_layouts]),
result_layouts=ir.ArrayAttr.get(
[ir.DenseIntElementsAttr.get(
np.atleast_1d(np.asarray(l, dtype=np.int64)),
type=ir.IndexType.get()) for l in result_layouts
]),
output_operand_aliases=ir.ArrayAttr.get([
mhlo.OutputOperandAlias.get(
output_tuple_indices=[] if len(out_types) == 1 else [output],
operand_index=input,
operand_tuple_indices=[])
for input, output in operand_output_aliases.items()
]))
type=ir.IndexType.get())
for l in result_layouts]))
if len(out_types) == 1:
return out.result
else:
return [
mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
for i in range(len(out_types))
]
return [mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
for i in range(len(out_types))]

View File

@ -82,6 +82,11 @@ static absl::Status TrsmBatched_(hipStream_t stream, void** buffers,
auto h = BlasHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[2] != buffers[1]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
buffers[2], buffers[1], SizeOfHipblasType(d.type) * d.batch * d.m * d.n,
hipMemcpyDeviceToDevice, stream)));
}
const int lda = d.side == HIPBLAS_SIDE_LEFT ? d.m : d.n;
const int ldb = d.m;
auto a_batch_host = MakeBatchPointers(stream, buffers[0], buffers[3], d.batch,
@ -161,6 +166,11 @@ static absl::Status GetrfBatched_(hipStream_t stream, void** buffers,
auto h = BlasHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[0] != buffers[1]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
buffers[1], buffers[0], SizeOfHipblasType(d.type) * d.batch * d.n * d.n,
hipMemcpyDeviceToDevice, stream)));
}
int* ipiv = static_cast<int*>(buffers[2]);
int* info = static_cast<int*>(buffers[3]);
@ -220,6 +230,11 @@ static absl::Status GeqrfBatched_(hipStream_t stream, void** buffers,
auto h = BlasHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[0] != buffers[1]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
buffers[1], buffers[0], SizeOfHipblasType(d.type) * d.batch * d.m * d.n,
hipMemcpyDeviceToDevice, stream)));
}
std::vector<int> info(d.batch);
auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[3], d.batch,

View File

@ -74,6 +74,12 @@ static absl::Status Potrf_(hipStream_t stream, void** buffers,
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[1] != buffers[0]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipMemcpyAsync(buffers[1], buffers[0],
SizeOfHipsolverType(d.type) * d.batch * d.n * d.n,
hipMemcpyDeviceToDevice, stream)));
}
int* info = static_cast<int*>(buffers[2]);
void* workspace = buffers[3];
@ -169,6 +175,13 @@ static absl::Status Getrf_(hipStream_t stream, void** buffers,
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[1] != buffers[0]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
buffers[1], buffers[0],
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
hipMemcpyDeviceToDevice, stream)));
}
int* ipiv = static_cast<int*>(buffers[2]);
int* info = static_cast<int*>(buffers[3]);
@ -245,6 +258,13 @@ static absl::Status Geqrf_(hipStream_t stream, void** buffers,
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[1] != buffers[0]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
buffers[1], buffers[0],
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
hipMemcpyDeviceToDevice, stream)));
}
int* info = static_cast<int*>(buffers[3]);
@ -325,6 +345,13 @@ static absl::Status Orgqr_(hipStream_t stream, void** buffers,
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[2] != buffers[0]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
buffers[2], buffers[0],
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
hipMemcpyDeviceToDevice, stream)));
}
int* info = static_cast<int*>(buffers[3]);
@ -405,6 +432,11 @@ static absl::Status Syevd_(hipStream_t stream, void** buffers,
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
buffers[1], buffers[0],
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
static_cast<std::int64_t>(d.n) * static_cast<std::int64_t>(d.n),
hipMemcpyDeviceToDevice, stream)));
hipsolverEigMode_t jobz = HIPSOLVER_EIG_MODE_VECTOR;
int* info = static_cast<int*>(buffers[3]);
void* work = buffers[4];
@ -485,6 +517,13 @@ absl::Status Syevj_(hipStream_t stream, void** buffers, const char* opaque,
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[1] != buffers[0]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
buffers[1], buffers[0],
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
static_cast<std::int64_t>(d.n) * static_cast<std::int64_t>(d.n),
hipMemcpyDeviceToDevice, stream)));
}
hipsolverSyevjInfo_t params;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCreateSyevjInfo(&params)));
std::unique_ptr<void, void (*)(hipsolverSyevjInfo_t)> params_cleanup(
@ -586,6 +625,11 @@ static absl::Status Gesvd_(hipStream_t stream, void** buffers,
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
buffers[1], buffers[0],
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
hipMemcpyDeviceToDevice, stream)));
int* info = static_cast<int*>(buffers[5]);
void* work = buffers[6];
switch (d.type) {

View File

@ -504,10 +504,16 @@ static absl::Status gtsv2(F computeGtsv2, hipStream_t stream, void** buffers,
T* X = (T*)(buffers[4]);
void* buffer = buffers[5];
// The solution X is written in place to B.
// The solution X is written in place to B. We need to therefore copy the
// contents of B into the output buffer X and pass that into the kernel as B.
// Once copy insertion is supported for custom call aliasing, we could alias B
// with X and avoid the copy, the code below is written defensively assuming B
// and X might alias, but today we know they will not.
// TODO(b/182906199): Update the comment here once copy insertion is WAI.
if (X != B) {
return absl::InvalidArgumentError(
"Input and output buffers to gtsv2 must alias");
size_t B_bytes = ldb * n * sizeof(T);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipMemcpyAsync(X, B, B_bytes, hipMemcpyDeviceToDevice, stream)));
}
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(