From 080cf47002f2c4d0815b7d92054ee9492e1422dc Mon Sep 17 00:00:00 2001 From: Rohit Santhanam Date: Sat, 25 Jun 2022 21:19:19 +0000 Subject: [PATCH] [ROCm] Fixes for compilation failures caused by compiler changes in ROCm Tensorflow fork. --- .bazelrc | 1 + jaxlib/rocm/hipsolver_kernels.cc | 16 ++++++++-------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/.bazelrc b/.bazelrc index da136a8d7..be7f6fdbd 100644 --- a/.bazelrc +++ b/.bazelrc @@ -99,6 +99,7 @@ build:windows --host_linkopt=/OPT:ICF build:windows --experimental_strict_action_env=true build:linux --config=posix +build:linux --copt=-Wno-unknown-warning-option # Workaround for gcc 10+ warnings related to upb. # See https://github.com/tensorflow/tensorflow/issues/39467 build:linux --copt=-Wno-stringop-truncation diff --git a/jaxlib/rocm/hipsolver_kernels.cc b/jaxlib/rocm/hipsolver_kernels.cc index 43f111bd6..aa2b7c9f2 100644 --- a/jaxlib/rocm/hipsolver_kernels.cc +++ b/jaxlib/rocm/hipsolver_kernels.cc @@ -126,30 +126,30 @@ static absl::Status Potrf_(hipStream_t stream, void** buffers, case HipsolverType::F32: { JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverSpotrfBatched( handle.get(), d.uplo, d.n, static_cast(workspace), d.n, - static_cast(workspace + (d.batch * sizeof(float*))), d.lwork, - info, d.batch))); + reinterpret_cast(static_cast(workspace) + d.batch), + d.lwork, info, d.batch))); break; } case HipsolverType::F64: { JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverDpotrfBatched( handle.get(), d.uplo, d.n, static_cast(workspace), d.n, - static_cast(workspace + (d.batch * sizeof(double*))), d.lwork, - info, d.batch))); + reinterpret_cast(static_cast(workspace) + d.batch), + d.lwork, info, d.batch))); break; } case HipsolverType::C64: { JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCpotrfBatched( handle.get(), d.uplo, d.n, static_cast(workspace), d.n, - static_cast(workspace + (d.batch * sizeof(hipFloatComplex*))),d.lwork, - info, d.batch))); + reinterpret_cast(static_cast(workspace) + + d.batch), d.lwork, info, d.batch))); break; } case HipsolverType::C128: { hipDoubleComplex* a = static_cast(buffers[1]); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZpotrfBatched( handle.get(), d.uplo, d.n, static_cast(workspace), d.n, - static_cast(workspace + (d.batch * sizeof(hipDoubleComplex*))), d.lwork, - info, d.batch))); + reinterpret_cast(static_cast(workspace) + + d.batch), d.lwork, info, d.batch))); break; } }