[ROCm] Fixes for compilation failures caused by compiler changes in ROCm Tensorflow fork.

This commit is contained in:
Rohit Santhanam 2022-06-25 21:19:19 +00:00
parent 406a61cf52
commit 080cf47002
2 changed files with 9 additions and 8 deletions

View File

@ -99,6 +99,7 @@ build:windows --host_linkopt=/OPT:ICF
build:windows --experimental_strict_action_env=true
build:linux --config=posix
build:linux --copt=-Wno-unknown-warning-option
# Workaround for gcc 10+ warnings related to upb.
# See https://github.com/tensorflow/tensorflow/issues/39467
build:linux --copt=-Wno-stringop-truncation

View File

@ -126,30 +126,30 @@ static absl::Status Potrf_(hipStream_t stream, void** buffers,
case HipsolverType::F32: {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverSpotrfBatched(
handle.get(), d.uplo, d.n, static_cast<float**>(workspace), d.n,
static_cast<float*>(workspace + (d.batch * sizeof(float*))), d.lwork,
info, d.batch)));
reinterpret_cast<float*>(static_cast<float**>(workspace) + d.batch),
d.lwork, info, d.batch)));
break;
}
case HipsolverType::F64: {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverDpotrfBatched(
handle.get(), d.uplo, d.n, static_cast<double**>(workspace), d.n,
static_cast<double*>(workspace + (d.batch * sizeof(double*))), d.lwork,
info, d.batch)));
reinterpret_cast<double*>(static_cast<double**>(workspace) + d.batch),
d.lwork, info, d.batch)));
break;
}
case HipsolverType::C64: {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCpotrfBatched(
handle.get(), d.uplo, d.n, static_cast<hipFloatComplex**>(workspace), d.n,
static_cast<hipFloatComplex*>(workspace + (d.batch * sizeof(hipFloatComplex*))),d.lwork,
info, d.batch)));
reinterpret_cast<hipFloatComplex*>(static_cast<hipFloatComplex**>(workspace) +
d.batch), d.lwork, info, d.batch)));
break;
}
case HipsolverType::C128: {
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZpotrfBatched(
handle.get(), d.uplo, d.n, static_cast<hipDoubleComplex**>(workspace), d.n,
static_cast<hipDoubleComplex*>(workspace + (d.batch * sizeof(hipDoubleComplex*))), d.lwork,
info, d.batch)));
reinterpret_cast<hipDoubleComplex*>(static_cast<hipDoubleComplex**>(workspace) +
d.batch), d.lwork, info, d.batch)));
break;
}
}