mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Revert: Use input-output aliasing for jaxlib GPU custom calls.
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need. This seems to be causing some test failures downstream, so reverting this for the moment until I can debug them. PiperOrigin-RevId: 479670565
This commit is contained in:
parent
e8ba61d82b
commit
2693afa263
@ -83,6 +83,11 @@ static absl::Status TrsmBatched_(cudaStream_t stream, void** buffers,
|
||||
auto h = BlasHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
if (buffers[2] != buffers[1]) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
|
||||
buffers[2], buffers[1], SizeOfCublasType(d.type) * d.batch * d.m * d.n,
|
||||
cudaMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
const int lda = d.side == CUBLAS_SIDE_LEFT ? d.m : d.n;
|
||||
const int ldb = d.m;
|
||||
auto a_batch_host = MakeBatchPointers(stream, buffers[0], buffers[3], d.batch,
|
||||
@ -162,6 +167,11 @@ static absl::Status GetrfBatched_(cudaStream_t stream, void** buffers,
|
||||
auto h = BlasHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
if (buffers[0] != buffers[1]) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
|
||||
buffers[1], buffers[0], SizeOfCublasType(d.type) * d.batch * d.n * d.n,
|
||||
cudaMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
int* ipiv = static_cast<int*>(buffers[2]);
|
||||
int* info = static_cast<int*>(buffers[3]);
|
||||
@ -220,6 +230,11 @@ static absl::Status GeqrfBatched_(cudaStream_t stream, void** buffers,
|
||||
auto h = BlasHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
if (buffers[0] != buffers[1]) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
|
||||
buffers[1], buffers[0], SizeOfCublasType(d.type) * d.batch * d.m * d.n,
|
||||
cudaMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
std::vector<int> info(d.batch);
|
||||
auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[3], d.batch,
|
||||
|
@ -94,6 +94,12 @@ static absl::Status Potrf_(cudaStream_t stream, void** buffers,
|
||||
auto h = SolverHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
if (buffers[1] != buffers[0]) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cudaMemcpyAsync(buffers[1], buffers[0],
|
||||
SizeOfCusolverType(d.type) * d.batch * d.n * d.n,
|
||||
cudaMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
int* info = static_cast<int*>(buffers[2]);
|
||||
void* workspace = buffers[3];
|
||||
@ -186,6 +192,13 @@ static absl::Status Getrf_(cudaStream_t stream, void** buffers,
|
||||
auto h = SolverHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
if (buffers[1] != buffers[0]) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
|
||||
buffers[1], buffers[0],
|
||||
SizeOfCusolverType(d.type) * static_cast<std::int64_t>(d.batch) *
|
||||
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
|
||||
cudaMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
int* ipiv = static_cast<int*>(buffers[2]);
|
||||
int* info = static_cast<int*>(buffers[3]);
|
||||
@ -262,6 +275,13 @@ static absl::Status Geqrf_(cudaStream_t stream, void** buffers,
|
||||
auto h = SolverHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
if (buffers[1] != buffers[0]) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
|
||||
buffers[1], buffers[0],
|
||||
SizeOfCusolverType(d.type) * static_cast<std::int64_t>(d.batch) *
|
||||
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
|
||||
cudaMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
int* info = static_cast<int*>(buffers[3]);
|
||||
void* workspace = buffers[4];
|
||||
@ -438,6 +458,13 @@ static absl::Status Orgqr_(cudaStream_t stream, void** buffers,
|
||||
auto h = SolverHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
if (buffers[2] != buffers[0]) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
|
||||
buffers[2], buffers[0],
|
||||
SizeOfCusolverType(d.type) * static_cast<std::int64_t>(d.batch) *
|
||||
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
|
||||
cudaMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
int* info = static_cast<int*>(buffers[3]);
|
||||
void* workspace = buffers[4];
|
||||
@ -517,6 +544,11 @@ static absl::Status Syevd_(cudaStream_t stream, void** buffers,
|
||||
auto h = SolverHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
|
||||
buffers[1], buffers[0],
|
||||
SizeOfCusolverType(d.type) * static_cast<std::int64_t>(d.batch) *
|
||||
static_cast<std::int64_t>(d.n) * static_cast<std::int64_t>(d.n),
|
||||
cudaMemcpyDeviceToDevice, stream)));
|
||||
cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR;
|
||||
int* info = static_cast<int*>(buffers[3]);
|
||||
void* work = buffers[4];
|
||||
@ -597,6 +629,13 @@ absl::Status Syevj_(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
auto h = SolverHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
if (buffers[1] != buffers[0]) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
|
||||
buffers[1], buffers[0],
|
||||
SizeOfCusolverType(d.type) * static_cast<std::int64_t>(d.batch) *
|
||||
static_cast<std::int64_t>(d.n) * static_cast<std::int64_t>(d.n),
|
||||
cudaMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
syevjInfo_t params;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCreateSyevjInfo(&params)));
|
||||
std::unique_ptr<syevjInfo, void (*)(syevjInfo*)> params_cleanup(
|
||||
@ -699,6 +738,11 @@ static absl::Status Gesvd_(cudaStream_t stream, void** buffers,
|
||||
auto h = SolverHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
|
||||
buffers[1], buffers[0],
|
||||
SizeOfCusolverType(d.type) * static_cast<std::int64_t>(d.batch) *
|
||||
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
|
||||
cudaMemcpyDeviceToDevice, stream)));
|
||||
int* info = static_cast<int*>(buffers[5]);
|
||||
void* work = buffers[6];
|
||||
int64_t k = d.jobu == 'A' ? d.m : d.n;
|
||||
@ -797,6 +841,11 @@ static absl::Status Gesvdj_(cudaStream_t stream, void** buffers,
|
||||
auto h = SolverHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
|
||||
buffers[1], buffers[0],
|
||||
SizeOfCusolverType(d.type) * static_cast<std::int64_t>(d.batch) *
|
||||
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
|
||||
cudaMemcpyDeviceToDevice, stream)));
|
||||
int* info = static_cast<int*>(buffers[5]);
|
||||
void* work = buffers[6];
|
||||
gesvdjInfo_t params;
|
||||
|
@ -520,24 +520,25 @@ static absl::Status CooMatmat_(cudaStream_t stream, void** buffers,
|
||||
cusparseDnMatDescr_t mat_b = 0;
|
||||
cusparseDnMatDescr_t mat_c = 0;
|
||||
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateCoo(
|
||||
&mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind, coo_values,
|
||||
d.A.index_type, CUSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseCooSetStridedBatch(mat_a, /*batchCount=*/d.A.batch_count,
|
||||
/*batchStride=*/d.A.batch_stride)));
|
||||
cusparseCooSetStridedBatch(mat_a, /*batchCount=*/d.A.batch_count,
|
||||
/*batchStride=*/d.A.batch_stride)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
|
||||
&mat_b, d.B.rows, d.B.cols,
|
||||
/*ld=*/d.B.cols, Bbuf, d.B.type, CUSPARSE_ORDER_ROW)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseDnMatSetStridedBatch(mat_b, /*batchCount=*/d.B.batch_count,
|
||||
/*batchStride=*/d.B.batch_stride)));
|
||||
cusparseDnMatSetStridedBatch(mat_b, /*batchCount=*/d.B.batch_count,
|
||||
/*batchStride=*/d.B.batch_stride)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
|
||||
&mat_c, d.C.rows, d.C.cols,
|
||||
/*ld=*/d.C.cols, Cbuf, d.C.type, CUSPARSE_ORDER_ROW)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseDnMatSetStridedBatch(mat_c, /*batchCount=*/d.C.batch_count,
|
||||
/*batchStride=*/d.C.batch_stride)));
|
||||
cusparseDnMatSetStridedBatch(mat_c, /*batchCount=*/d.C.batch_count,
|
||||
/*batchStride=*/d.C.batch_stride)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseSpMM(
|
||||
handle.get(), d.op_A, /*opB=*/CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
|
||||
mat_a, mat_b, &beta, mat_c, d.C.type, CUSPARSE_SPMM_ALG_DEFAULT, buf)));
|
||||
@ -579,10 +580,16 @@ static absl::Status gtsv2(F computeGtsv2, cudaStream_t stream, void** buffers,
|
||||
T* X = (T*)(buffers[4]);
|
||||
void* buffer = buffers[5];
|
||||
|
||||
// The solution X is written in place to B.
|
||||
// The solution X is written in place to B. We need to therefore copy the
|
||||
// contents of B into the output buffer X and pass that into the kernel as B.
|
||||
// Once copy insertion is supported for custom call aliasing, we could alias B
|
||||
// with X and avoid the copy, the code below is written defensively assuming B
|
||||
// and X might alias, but today we know they will not.
|
||||
// TODO(b/182906199): Update the comment here once copy insertion is WAI.
|
||||
if (X != B) {
|
||||
return absl::InvalidArgumentError(
|
||||
"Input and output buffers to gtsv2 must alias");
|
||||
size_t B_bytes = ldb * n * sizeof(T);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cudaMemcpyAsync(X, B, B_bytes, cudaMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
|
@ -98,8 +98,7 @@ def _trsm_mhlo(platform, gpu_blas, dtype, a, b, left_side=False, lower=False,
|
||||
[a, b],
|
||||
backend_config=opaque,
|
||||
operand_layouts=[layout] * 2,
|
||||
result_layouts=[layout, work_layout, work_layout],
|
||||
operand_output_aliases={1: 0})
|
||||
result_layouts=[layout, work_layout, work_layout])
|
||||
return out[0]
|
||||
|
||||
cuda_trsm = partial(_trsm_mhlo, "cu", _cublas)
|
||||
@ -134,8 +133,7 @@ def _potrf_mhlo(platform, gpu_solver, dtype, a, lower):
|
||||
[a],
|
||||
backend_config=opaque,
|
||||
operand_layouts=[layout],
|
||||
result_layouts=[layout, info_layout, work_layout],
|
||||
operand_output_aliases={0: 0})
|
||||
result_layouts=[layout, info_layout, work_layout])
|
||||
return out[:2]
|
||||
|
||||
cuda_potrf = partial(_potrf_mhlo, "cu", _cusolver)
|
||||
@ -181,8 +179,7 @@ def _getrf_mhlo(platform, gpu_blas, gpu_solver, dtype, a):
|
||||
tuple(range(num_bd, -1, -1)),
|
||||
tuple(range(num_bd - 1, -1, -1)),
|
||||
[0],
|
||||
],
|
||||
operand_output_aliases={0: 0})
|
||||
])
|
||||
return out[:3]
|
||||
|
||||
cuda_getrf = partial(_getrf_mhlo, "cu", _cublas, _cusolver)
|
||||
@ -220,8 +217,7 @@ def _geqrf_mhlo(platform, gpu_solver, dtype, a):
|
||||
tuple(range(num_bd, -1, -1)),
|
||||
tuple(range(num_bd - 1, -1, -1)),
|
||||
[0],
|
||||
],
|
||||
operand_output_aliases={0: 0})
|
||||
])
|
||||
return out[:3]
|
||||
|
||||
cuda_geqrf = partial(_geqrf_mhlo, "cu", _cusolver)
|
||||
@ -257,9 +253,7 @@ def _geqrf_batched_mhlo(platform, gpu_blas, dtype, a):
|
||||
tuple(range(num_bd, -1, -1)),
|
||||
[0],
|
||||
[0],
|
||||
],
|
||||
operand_output_aliases={0: 0}
|
||||
)
|
||||
])
|
||||
return out[:2]
|
||||
|
||||
cuda_geqrf_batched = partial(_geqrf_batched_mhlo, "cu", _cublas)
|
||||
@ -327,8 +321,7 @@ def _orgqr_mhlo(platform, gpu_solver, dtype, a, tau):
|
||||
layout,
|
||||
tuple(range(num_bd - 1, -1, -1)),
|
||||
[0],
|
||||
],
|
||||
operand_output_aliases={0: 0})
|
||||
])
|
||||
return out[:2]
|
||||
|
||||
cuda_orgqr = partial(_orgqr_mhlo, "cu", _cusolver)
|
||||
@ -379,8 +372,7 @@ def _syevd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
|
||||
tuple(range(num_bd, -1, -1)),
|
||||
tuple(range(num_bd - 1, -1, -1)),
|
||||
[0],
|
||||
],
|
||||
operand_output_aliases={0: 0})
|
||||
])
|
||||
return out[:3]
|
||||
|
||||
cuda_syevd = partial(_syevd_mhlo, "cu", _cusolver, True)
|
||||
@ -435,8 +427,7 @@ def _gesvd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
|
||||
matrix_layout,
|
||||
scalar_layout,
|
||||
[0],
|
||||
],
|
||||
operand_output_aliases={0: 0})
|
||||
])
|
||||
vt = mhlo.TransposeOp(
|
||||
v,
|
||||
ir.DenseIntElementsAttr.get(np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd)))).result
|
||||
@ -478,8 +469,7 @@ def _gesvd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
|
||||
matrix_layout,
|
||||
scalar_layout,
|
||||
[0],
|
||||
],
|
||||
operand_output_aliases={0: 0})
|
||||
])
|
||||
else:
|
||||
lwork, opaque = gpu_solver.build_gesvd_descriptor(
|
||||
np.dtype(dtype), b, m, n, compute_uv, full_matrices)
|
||||
@ -505,8 +495,7 @@ def _gesvd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
|
||||
matrix_layout,
|
||||
scalar_layout,
|
||||
[0],
|
||||
],
|
||||
operand_output_aliases={0: 0})
|
||||
])
|
||||
return s, u, vt, info
|
||||
|
||||
cuda_gesvd = partial(_gesvd_mhlo, "cu", _cusolver, True)
|
||||
|
@ -356,8 +356,7 @@ def _gtsv2_mhlo(platform, gpu_sparse, dl, d, du, B, *, m, n, ldb, t):
|
||||
[dl, d, du, B],
|
||||
backend_config=gpu_sparse.build_gtsv2_descriptor(m, n, ldb),
|
||||
operand_layouts=[[0]] * 3 + [[1, 0]],
|
||||
result_layouts=[[1, 0], [0]],
|
||||
operand_output_aliases={3: 0})
|
||||
result_layouts=[[1, 0], [0]])
|
||||
return out[0]
|
||||
|
||||
cuda_gtsv2 = partial(_gtsv2_mhlo, "cu", _cusparse)
|
||||
|
@ -14,37 +14,28 @@
|
||||
|
||||
# Helpers for building MHLO operators
|
||||
|
||||
from typing import Dict, Optional, Sequence, Union
|
||||
from typing import Optional, Sequence, Union
|
||||
import jaxlib.mlir.ir as ir
|
||||
import jaxlib.mlir.dialects.mhlo as mhlo
|
||||
import numpy as np
|
||||
|
||||
|
||||
def custom_call(
|
||||
call_target_name: str,
|
||||
out_types: Sequence[ir.Type],
|
||||
operands: Sequence[ir.Value],
|
||||
operand_layouts: Sequence[Sequence[int]],
|
||||
result_layouts: Sequence[Sequence[int]],
|
||||
backend_config: Optional[str] = None,
|
||||
has_side_effect: bool = False,
|
||||
api_version: int = 2,
|
||||
operand_output_aliases: Dict[int, int] = {},
|
||||
) -> Union[ir.Value, Sequence[ir.Value]]:
|
||||
def custom_call(call_target_name: str, out_types: Sequence[ir.Type],
|
||||
operands: Sequence[ir.Value],
|
||||
operand_layouts: Sequence[Sequence[int]],
|
||||
result_layouts: Sequence[Sequence[int]],
|
||||
backend_config: Optional[str] = None,
|
||||
has_side_effect: bool = False,
|
||||
api_version: int = 2,
|
||||
) -> Union[ir.Value, Sequence[ir.Value]]:
|
||||
"""Less-verbose helper for building an MHLO custom call op.
|
||||
|
||||
Once https://github.com/llvm/llvm-project/issues/54932 is fixed, this helper
|
||||
may be able to go away.
|
||||
|
||||
Args:
|
||||
...
|
||||
operand_output_alias: a dictionary mapping input numbers -> output numbers
|
||||
that must alias.
|
||||
"""
|
||||
i32_type = ir.IntegerType.get_signless(32)
|
||||
out = mhlo.CustomCallOp(
|
||||
(out_types
|
||||
if len(out_types) == 1 else [ir.TupleType.get_tuple(out_types)]),
|
||||
(out_types if len(out_types) == 1 else
|
||||
[ir.TupleType.get_tuple(out_types)]),
|
||||
operands,
|
||||
call_target_name=ir.StringAttr.get(call_target_name),
|
||||
has_side_effect=ir.BoolAttr.get(has_side_effect),
|
||||
@ -52,27 +43,18 @@ def custom_call(
|
||||
"" if backend_config is None else backend_config),
|
||||
api_version=ir.IntegerAttr.get(i32_type, api_version),
|
||||
called_computations=ir.ArrayAttr.get([]),
|
||||
operand_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(
|
||||
operand_layouts=ir.ArrayAttr.get(
|
||||
[ir.DenseIntElementsAttr.get(
|
||||
np.atleast_1d(np.asarray(l, dtype=np.int64)),
|
||||
type=ir.IndexType.get()) for l in operand_layouts
|
||||
]),
|
||||
result_layouts=ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(
|
||||
type=ir.IndexType.get())
|
||||
for l in operand_layouts]),
|
||||
result_layouts=ir.ArrayAttr.get(
|
||||
[ir.DenseIntElementsAttr.get(
|
||||
np.atleast_1d(np.asarray(l, dtype=np.int64)),
|
||||
type=ir.IndexType.get()) for l in result_layouts
|
||||
]),
|
||||
output_operand_aliases=ir.ArrayAttr.get([
|
||||
mhlo.OutputOperandAlias.get(
|
||||
output_tuple_indices=[] if len(out_types) == 1 else [output],
|
||||
operand_index=input,
|
||||
operand_tuple_indices=[])
|
||||
for input, output in operand_output_aliases.items()
|
||||
]))
|
||||
type=ir.IndexType.get())
|
||||
for l in result_layouts]))
|
||||
if len(out_types) == 1:
|
||||
return out.result
|
||||
else:
|
||||
return [
|
||||
mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
|
||||
for i in range(len(out_types))
|
||||
]
|
||||
return [mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
|
||||
for i in range(len(out_types))]
|
||||
|
@ -82,6 +82,11 @@ static absl::Status TrsmBatched_(hipStream_t stream, void** buffers,
|
||||
auto h = BlasHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
if (buffers[2] != buffers[1]) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
|
||||
buffers[2], buffers[1], SizeOfHipblasType(d.type) * d.batch * d.m * d.n,
|
||||
hipMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
const int lda = d.side == HIPBLAS_SIDE_LEFT ? d.m : d.n;
|
||||
const int ldb = d.m;
|
||||
auto a_batch_host = MakeBatchPointers(stream, buffers[0], buffers[3], d.batch,
|
||||
@ -161,6 +166,11 @@ static absl::Status GetrfBatched_(hipStream_t stream, void** buffers,
|
||||
auto h = BlasHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
if (buffers[0] != buffers[1]) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
|
||||
buffers[1], buffers[0], SizeOfHipblasType(d.type) * d.batch * d.n * d.n,
|
||||
hipMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
int* ipiv = static_cast<int*>(buffers[2]);
|
||||
int* info = static_cast<int*>(buffers[3]);
|
||||
@ -220,6 +230,11 @@ static absl::Status GeqrfBatched_(hipStream_t stream, void** buffers,
|
||||
auto h = BlasHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
if (buffers[0] != buffers[1]) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
|
||||
buffers[1], buffers[0], SizeOfHipblasType(d.type) * d.batch * d.m * d.n,
|
||||
hipMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
std::vector<int> info(d.batch);
|
||||
auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[3], d.batch,
|
||||
|
@ -74,6 +74,12 @@ static absl::Status Potrf_(hipStream_t stream, void** buffers,
|
||||
auto h = SolverHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
if (buffers[1] != buffers[0]) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipMemcpyAsync(buffers[1], buffers[0],
|
||||
SizeOfHipsolverType(d.type) * d.batch * d.n * d.n,
|
||||
hipMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
int* info = static_cast<int*>(buffers[2]);
|
||||
void* workspace = buffers[3];
|
||||
@ -169,6 +175,13 @@ static absl::Status Getrf_(hipStream_t stream, void** buffers,
|
||||
auto h = SolverHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
if (buffers[1] != buffers[0]) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
|
||||
buffers[1], buffers[0],
|
||||
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
|
||||
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
|
||||
hipMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
int* ipiv = static_cast<int*>(buffers[2]);
|
||||
int* info = static_cast<int*>(buffers[3]);
|
||||
@ -245,6 +258,13 @@ static absl::Status Geqrf_(hipStream_t stream, void** buffers,
|
||||
auto h = SolverHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
if (buffers[1] != buffers[0]) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
|
||||
buffers[1], buffers[0],
|
||||
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
|
||||
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
|
||||
hipMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
int* info = static_cast<int*>(buffers[3]);
|
||||
|
||||
@ -325,6 +345,13 @@ static absl::Status Orgqr_(hipStream_t stream, void** buffers,
|
||||
auto h = SolverHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
if (buffers[2] != buffers[0]) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
|
||||
buffers[2], buffers[0],
|
||||
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
|
||||
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
|
||||
hipMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
int* info = static_cast<int*>(buffers[3]);
|
||||
|
||||
@ -405,6 +432,11 @@ static absl::Status Syevd_(hipStream_t stream, void** buffers,
|
||||
auto h = SolverHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
|
||||
buffers[1], buffers[0],
|
||||
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
|
||||
static_cast<std::int64_t>(d.n) * static_cast<std::int64_t>(d.n),
|
||||
hipMemcpyDeviceToDevice, stream)));
|
||||
hipsolverEigMode_t jobz = HIPSOLVER_EIG_MODE_VECTOR;
|
||||
int* info = static_cast<int*>(buffers[3]);
|
||||
void* work = buffers[4];
|
||||
@ -485,6 +517,13 @@ absl::Status Syevj_(hipStream_t stream, void** buffers, const char* opaque,
|
||||
auto h = SolverHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
if (buffers[1] != buffers[0]) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
|
||||
buffers[1], buffers[0],
|
||||
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
|
||||
static_cast<std::int64_t>(d.n) * static_cast<std::int64_t>(d.n),
|
||||
hipMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
hipsolverSyevjInfo_t params;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCreateSyevjInfo(&params)));
|
||||
std::unique_ptr<void, void (*)(hipsolverSyevjInfo_t)> params_cleanup(
|
||||
@ -586,6 +625,11 @@ static absl::Status Gesvd_(hipStream_t stream, void** buffers,
|
||||
auto h = SolverHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
|
||||
buffers[1], buffers[0],
|
||||
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
|
||||
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
|
||||
hipMemcpyDeviceToDevice, stream)));
|
||||
int* info = static_cast<int*>(buffers[5]);
|
||||
void* work = buffers[6];
|
||||
switch (d.type) {
|
||||
|
@ -504,10 +504,16 @@ static absl::Status gtsv2(F computeGtsv2, hipStream_t stream, void** buffers,
|
||||
T* X = (T*)(buffers[4]);
|
||||
void* buffer = buffers[5];
|
||||
|
||||
// The solution X is written in place to B.
|
||||
// The solution X is written in place to B. We need to therefore copy the
|
||||
// contents of B into the output buffer X and pass that into the kernel as B.
|
||||
// Once copy insertion is supported for custom call aliasing, we could alias B
|
||||
// with X and avoid the copy, the code below is written defensively assuming B
|
||||
// and X might alias, but today we know they will not.
|
||||
// TODO(b/182906199): Update the comment here once copy insertion is WAI.
|
||||
if (X != B) {
|
||||
return absl::InvalidArgumentError(
|
||||
"Input and output buffers to gtsv2 must alias");
|
||||
size_t B_bytes = ldb * n * sizeof(T);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipMemcpyAsync(X, B, B_bytes, hipMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
|
Loading…
x
Reference in New Issue
Block a user