mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[ROCm] Fixes for compilation failures caused by compiler changes in ROCm Tensorflow fork.
This commit is contained in:
parent
406a61cf52
commit
080cf47002
1
.bazelrc
1
.bazelrc
@ -99,6 +99,7 @@ build:windows --host_linkopt=/OPT:ICF
|
||||
build:windows --experimental_strict_action_env=true
|
||||
|
||||
build:linux --config=posix
|
||||
build:linux --copt=-Wno-unknown-warning-option
|
||||
# Workaround for gcc 10+ warnings related to upb.
|
||||
# See https://github.com/tensorflow/tensorflow/issues/39467
|
||||
build:linux --copt=-Wno-stringop-truncation
|
||||
|
@ -126,30 +126,30 @@ static absl::Status Potrf_(hipStream_t stream, void** buffers,
|
||||
case HipsolverType::F32: {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverSpotrfBatched(
|
||||
handle.get(), d.uplo, d.n, static_cast<float**>(workspace), d.n,
|
||||
static_cast<float*>(workspace + (d.batch * sizeof(float*))), d.lwork,
|
||||
info, d.batch)));
|
||||
reinterpret_cast<float*>(static_cast<float**>(workspace) + d.batch),
|
||||
d.lwork, info, d.batch)));
|
||||
break;
|
||||
}
|
||||
case HipsolverType::F64: {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverDpotrfBatched(
|
||||
handle.get(), d.uplo, d.n, static_cast<double**>(workspace), d.n,
|
||||
static_cast<double*>(workspace + (d.batch * sizeof(double*))), d.lwork,
|
||||
info, d.batch)));
|
||||
reinterpret_cast<double*>(static_cast<double**>(workspace) + d.batch),
|
||||
d.lwork, info, d.batch)));
|
||||
break;
|
||||
}
|
||||
case HipsolverType::C64: {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCpotrfBatched(
|
||||
handle.get(), d.uplo, d.n, static_cast<hipFloatComplex**>(workspace), d.n,
|
||||
static_cast<hipFloatComplex*>(workspace + (d.batch * sizeof(hipFloatComplex*))),d.lwork,
|
||||
info, d.batch)));
|
||||
reinterpret_cast<hipFloatComplex*>(static_cast<hipFloatComplex**>(workspace) +
|
||||
d.batch), d.lwork, info, d.batch)));
|
||||
break;
|
||||
}
|
||||
case HipsolverType::C128: {
|
||||
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZpotrfBatched(
|
||||
handle.get(), d.uplo, d.n, static_cast<hipDoubleComplex**>(workspace), d.n,
|
||||
static_cast<hipDoubleComplex*>(workspace + (d.batch * sizeof(hipDoubleComplex*))), d.lwork,
|
||||
info, d.batch)));
|
||||
reinterpret_cast<hipDoubleComplex*>(static_cast<hipDoubleComplex**>(workspace) +
|
||||
d.batch), d.lwork, info, d.batch)));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user