mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Move logic about when to dispatch to batched LU decomposition algorithm on GPU into the kernel.
This simplifies the lowering logic, and means that we don't get hit with a performance penalty when exporting with shape polymorphism. PiperOrigin-RevId: 662945116
This commit is contained in:
parent
bab70dda97
commit
ad1bd38790
@ -114,22 +114,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cublas_kernels_ffi",
|
||||
srcs = ["//jaxlib/gpu:blas_kernels_ffi.cc"],
|
||||
hdrs = ["//jaxlib/gpu:blas_kernels_ffi.h"],
|
||||
deps = [
|
||||
":cuda_blas_handle_pool",
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:ffi_helpers",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
"@xla//xla/tsl/cuda:cublas",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@com_google_absl//absl/status",
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_blas",
|
||||
srcs = ["//jaxlib/gpu:blas.cc"],
|
||||
@ -148,7 +132,6 @@ pybind_extension(
|
||||
module_name = "_blas",
|
||||
deps = [
|
||||
":cublas_kernels",
|
||||
":cublas_kernels_ffi",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:kernel_nanobind_helpers",
|
||||
"@xla//xla/tsl/cuda:cublas",
|
||||
@ -238,11 +221,13 @@ cc_library(
|
||||
srcs = ["//jaxlib/gpu:solver_kernels_ffi.cc"],
|
||||
hdrs = ["//jaxlib/gpu:solver_kernels_ffi.h"],
|
||||
deps = [
|
||||
":cuda_blas_handle_pool",
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_solver_handle_pool",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:ffi_helpers",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
"@xla//xla/tsl/cuda:cublas",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/cuda:cusolver",
|
||||
"@com_google_absl//absl/status",
|
||||
@ -274,6 +259,7 @@ pybind_extension(
|
||||
":cusolver_kernels",
|
||||
":cusolver_kernels_ffi",
|
||||
"//jaxlib:kernel_nanobind_helpers",
|
||||
"@xla//xla/tsl/cuda:cublas",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/cuda:cusolver",
|
||||
"@xla//xla/tsl/python/lib/core:numpy",
|
||||
@ -466,7 +452,6 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":cublas_kernels",
|
||||
":cublas_kernels_ffi",
|
||||
":cuda_linalg_kernels",
|
||||
":cuda_prng_kernels",
|
||||
":cuda_vendor",
|
||||
|
@ -29,8 +29,6 @@ exports_files(srcs = [
|
||||
"blas_handle_pool.h",
|
||||
"blas_kernels.cc",
|
||||
"blas_kernels.h",
|
||||
"blas_kernels_ffi.cc",
|
||||
"blas_kernels_ffi.h",
|
||||
"gpu_kernel_helpers.cc",
|
||||
"gpu_kernel_helpers.h",
|
||||
"gpu_kernels.cc",
|
||||
|
@ -22,7 +22,6 @@ limitations under the License.
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "jaxlib/gpu/blas_kernels.h"
|
||||
#include "jaxlib/gpu/blas_kernels_ffi.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "jaxlib/kernel_nanobind_helpers.h"
|
||||
#include "xla/tsl/python/lib/core/numpy.h"
|
||||
@ -70,9 +69,6 @@ nb::dict Registrations() {
|
||||
nb::dict dict;
|
||||
dict[JAX_GPU_PREFIX "blas_getrf_batched"] = EncapsulateFunction(GetrfBatched);
|
||||
dict[JAX_GPU_PREFIX "blas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched);
|
||||
|
||||
dict[JAX_GPU_PREFIX "blas_getrf_batched_ffi"] =
|
||||
EncapsulateFfiHandler(GetrfBatchedFfi);
|
||||
return dict;
|
||||
}
|
||||
|
||||
|
@ -1,133 +0,0 @@
|
||||
/* Copyright 2024 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/gpu/blas_kernels_ffi.h"
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "jaxlib/ffi_helpers.h"
|
||||
#include "jaxlib/gpu/blas_handle_pool.h"
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "xla/ffi/api/ffi.h"
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
|
||||
namespace ffi = ::xla::ffi;
|
||||
|
||||
namespace {
|
||||
#define GETRF_BATCHED_KERNEL_IMPL(type, name) \
|
||||
template <> \
|
||||
struct GetrfBatchedKernel<type> { \
|
||||
static absl::Status Run(gpublasHandle_t handle, int n, type** a, int lda, \
|
||||
int* ipiv, int* info, int batch) { \
|
||||
return JAX_AS_STATUS(name(handle, n, a, lda, ipiv, info, batch)); \
|
||||
} \
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct GetrfBatchedKernel;
|
||||
GETRF_BATCHED_KERNEL_IMPL(float, gpublasSgetrfBatched);
|
||||
GETRF_BATCHED_KERNEL_IMPL(double, gpublasDgetrfBatched);
|
||||
GETRF_BATCHED_KERNEL_IMPL(gpublasComplex, gpublasCgetrfBatched);
|
||||
GETRF_BATCHED_KERNEL_IMPL(gpublasDoubleComplex, gpublasZgetrfBatched);
|
||||
#undef GETRF_BATCHED_KERNEL_IMPL
|
||||
|
||||
template <typename T>
|
||||
ffi::Error GetrfBatchedImpl(gpuStream_t stream, ffi::ScratchAllocator& scratch,
|
||||
ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
|
||||
ffi::Result<ffi::Buffer<ffi::DataType::S32>> ipiv,
|
||||
ffi::Result<ffi::Buffer<ffi::DataType::S32>> info) {
|
||||
FFI_RETURN_IF_ERROR(CheckMatrixDimensions(a.dimensions()));
|
||||
auto [batch, rows, cols] = SplitBatch2D(a.dimensions());
|
||||
if (rows != cols) {
|
||||
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
|
||||
"getrf_batched only supports square matrices");
|
||||
}
|
||||
FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols));
|
||||
FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream));
|
||||
|
||||
auto maybe_workspace = scratch.Allocate(sizeof(void*) * batch);
|
||||
if (!maybe_workspace.has_value()) {
|
||||
return ffi::Error(ffi::ErrorCode::kUnknown,
|
||||
"Unable to allocate workspace for batched getrf");
|
||||
}
|
||||
auto workspace = maybe_workspace.value();
|
||||
|
||||
auto a_data = a.untyped_data();
|
||||
auto out_data = out->untyped_data();
|
||||
auto ipiv_data = ipiv->typed_data();
|
||||
auto info_data = info->typed_data();
|
||||
if (a_data != out_data) {
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(
|
||||
gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * cols * cols,
|
||||
gpuMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
FFI_ASSIGN_OR_RETURN(
|
||||
auto a_ptrs_host,
|
||||
MakeBatchPointers(stream, out_data, workspace, batch, sizeof(T) * n * n));
|
||||
// TODO(phawkins, danfm): ideally we would not need to synchronize here, but
|
||||
// to avoid it we need a way to keep the host-side buffer alive until the copy
|
||||
// completes.
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
|
||||
|
||||
auto batch_ptrs = static_cast<T**>(workspace);
|
||||
FFI_RETURN_IF_ERROR_STATUS(GetrfBatchedKernel<T>::Run(
|
||||
handle.get(), n, batch_ptrs, n, ipiv_data, info_data, batch));
|
||||
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
ffi::Error GetrfBatchedDispatch(
|
||||
gpuStream_t stream, ffi::ScratchAllocator scratch, ffi::AnyBuffer a,
|
||||
ffi::Result<ffi::AnyBuffer> out,
|
||||
ffi::Result<ffi::Buffer<ffi::DataType::S32>> ipiv,
|
||||
ffi::Result<ffi::Buffer<ffi::DataType::S32>> info) {
|
||||
auto dataType = a.element_type();
|
||||
if (dataType != out->element_type()) {
|
||||
return ffi::Error(
|
||||
ffi::ErrorCode::kInvalidArgument,
|
||||
"Input and output to getrf_batched must have the same element type");
|
||||
}
|
||||
if (dataType == ffi::DataType::F32) {
|
||||
return GetrfBatchedImpl<float>(stream, scratch, a, out, ipiv, info);
|
||||
} else if (dataType == ffi::DataType::F64) {
|
||||
return GetrfBatchedImpl<double>(stream, scratch, a, out, ipiv, info);
|
||||
} else if (dataType == ffi::DataType::C64) {
|
||||
return GetrfBatchedImpl<gpublasComplex>(stream, scratch, a, out, ipiv,
|
||||
info);
|
||||
} else if (dataType == ffi::DataType::C128) {
|
||||
return GetrfBatchedImpl<gpublasDoubleComplex>(stream, scratch, a, out, ipiv,
|
||||
info);
|
||||
}
|
||||
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
|
||||
"Unsupported element type for getrf");
|
||||
}
|
||||
} // namespace
|
||||
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL(
|
||||
GetrfBatchedFfi, GetrfBatchedDispatch,
|
||||
ffi::Ffi::Bind()
|
||||
.Ctx<ffi::PlatformStream<gpuStream_t>>()
|
||||
.Ctx<ffi::ScratchAllocator>()
|
||||
.Arg<ffi::AnyBuffer>() // a
|
||||
.Ret<ffi::AnyBuffer>() // out
|
||||
.Ret<ffi::Buffer<ffi::DataType::S32>>() // ipiv
|
||||
.Ret<ffi::Buffer<ffi::DataType::S32>>() // info
|
||||
);
|
||||
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
@ -1,30 +0,0 @@
|
||||
/* Copyright 2024 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef JAXLIB_GPU_BLAS_KERNELS_FFI_H_
|
||||
#define JAXLIB_GPU_BLAS_KERNELS_FFI_H_
|
||||
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "xla/ffi/api/ffi.h"
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(GetrfBatchedFfi);
|
||||
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
||||
|
||||
#endif // JAXLIB_GPU_BLAS_KERNELS_FFI_H_
|
@ -17,7 +17,6 @@ limitations under the License.
|
||||
// JAX-generated HLO code from outside of JAX.
|
||||
|
||||
#include "jaxlib/gpu/blas_kernels.h"
|
||||
#include "jaxlib/gpu/blas_kernels_ffi.h"
|
||||
#include "jaxlib/gpu/linalg_kernels.h"
|
||||
#include "jaxlib/gpu/prng_kernels.h"
|
||||
#include "jaxlib/gpu/rnn_kernels.h"
|
||||
@ -36,8 +35,6 @@ namespace {
|
||||
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_getrf_batched", GetrfBatched,
|
||||
"CUDA");
|
||||
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cublas_getrf_batched_ffi", "CUDA",
|
||||
GetrfBatchedFfi);
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_geqrf_batched", GeqrfBatched,
|
||||
"CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn", RNNForward, "CUDA");
|
||||
|
@ -16,10 +16,12 @@ limitations under the License.
|
||||
#include "jaxlib/gpu/solver_kernels_ffi.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "jaxlib/ffi_helpers.h"
|
||||
#include "jaxlib/gpu/blas_handle_pool.h"
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/solver_handle_pool.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
@ -58,12 +60,11 @@ GETRF_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZgetrf);
|
||||
#undef GETRF_KERNEL_IMPL
|
||||
|
||||
template <typename T>
|
||||
ffi::Error GetrfImpl(gpuStream_t stream, ffi::ScratchAllocator& scratch,
|
||||
ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols,
|
||||
gpuStream_t stream, ffi::ScratchAllocator& scratch,
|
||||
ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
|
||||
ffi::Result<ffi::Buffer<ffi::DataType::S32>> ipiv,
|
||||
ffi::Result<ffi::Buffer<ffi::DataType::S32>> info) {
|
||||
FFI_RETURN_IF_ERROR(CheckMatrixDimensions(a.dimensions()));
|
||||
auto [batch, rows, cols] = SplitBatch2D(a.dimensions());
|
||||
FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow<int>(rows));
|
||||
FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols));
|
||||
|
||||
@ -98,6 +99,64 @@ ffi::Error GetrfImpl(gpuStream_t stream, ffi::ScratchAllocator& scratch,
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
#define GETRF_BATCHED_KERNEL_IMPL(type, name) \
|
||||
template <> \
|
||||
struct GetrfBatchedKernel<type> { \
|
||||
static absl::Status Run(gpublasHandle_t handle, int n, type** a, int lda, \
|
||||
int* ipiv, int* info, int batch) { \
|
||||
return JAX_AS_STATUS(name(handle, n, a, lda, ipiv, info, batch)); \
|
||||
} \
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct GetrfBatchedKernel;
|
||||
GETRF_BATCHED_KERNEL_IMPL(float, gpublasSgetrfBatched);
|
||||
GETRF_BATCHED_KERNEL_IMPL(double, gpublasDgetrfBatched);
|
||||
GETRF_BATCHED_KERNEL_IMPL(gpublasComplex, gpublasCgetrfBatched);
|
||||
GETRF_BATCHED_KERNEL_IMPL(gpublasDoubleComplex, gpublasZgetrfBatched);
|
||||
#undef GETRF_BATCHED_KERNEL_IMPL
|
||||
|
||||
template <typename T>
|
||||
ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream,
|
||||
ffi::ScratchAllocator& scratch, ffi::AnyBuffer a,
|
||||
ffi::Result<ffi::AnyBuffer> out,
|
||||
ffi::Result<ffi::Buffer<ffi::DataType::S32>> ipiv,
|
||||
ffi::Result<ffi::Buffer<ffi::DataType::S32>> info) {
|
||||
FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols));
|
||||
FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream));
|
||||
|
||||
auto maybe_workspace = scratch.Allocate(sizeof(void*) * batch);
|
||||
if (!maybe_workspace.has_value()) {
|
||||
return ffi::Error(ffi::ErrorCode::kUnknown,
|
||||
"Unable to allocate workspace for batched getrf");
|
||||
}
|
||||
auto workspace = maybe_workspace.value();
|
||||
|
||||
auto a_data = a.untyped_data();
|
||||
auto out_data = out->untyped_data();
|
||||
auto ipiv_data = ipiv->typed_data();
|
||||
auto info_data = info->typed_data();
|
||||
if (a_data != out_data) {
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(
|
||||
gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * cols * cols,
|
||||
gpuMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
FFI_ASSIGN_OR_RETURN(
|
||||
auto a_ptrs_host,
|
||||
MakeBatchPointers(stream, out_data, workspace, batch, sizeof(T) * n * n));
|
||||
// TODO(phawkins, danfm): ideally we would not need to synchronize here, but
|
||||
// to avoid it we need a way to keep the host-side buffer alive until the copy
|
||||
// completes.
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
|
||||
|
||||
auto batch_ptrs = static_cast<T**>(workspace);
|
||||
FFI_RETURN_IF_ERROR_STATUS(GetrfBatchedKernel<T>::Run(
|
||||
handle.get(), n, batch_ptrs, n, ipiv_data, info_data, batch));
|
||||
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
|
||||
ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
|
||||
ffi::Result<ffi::Buffer<ffi::DataType::S32>> ipiv,
|
||||
@ -108,14 +167,36 @@ ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
|
||||
ffi::ErrorCode::kInvalidArgument,
|
||||
"The input and output to getrf must have the same element type");
|
||||
}
|
||||
FFI_RETURN_IF_ERROR(CheckMatrixDimensions(a.dimensions()));
|
||||
auto [batch, rows, cols] = SplitBatch2D(a.dimensions());
|
||||
if (batch > 1 && rows == cols && rows / batch <= 128) {
|
||||
if (dataType == ffi::DataType::F32) {
|
||||
return GetrfImpl<float>(stream, scratch, a, out, ipiv, info);
|
||||
return GetrfBatchedImpl<float>(batch, cols, stream, scratch, a, out, ipiv,
|
||||
info);
|
||||
} else if (dataType == ffi::DataType::F64) {
|
||||
return GetrfImpl<double>(stream, scratch, a, out, ipiv, info);
|
||||
return GetrfBatchedImpl<double>(batch, cols, stream, scratch, a, out,
|
||||
ipiv, info);
|
||||
} else if (dataType == ffi::DataType::C64) {
|
||||
return GetrfImpl<gpuComplex>(stream, scratch, a, out, ipiv, info);
|
||||
return GetrfBatchedImpl<gpublasComplex>(batch, cols, stream, scratch, a,
|
||||
out, ipiv, info);
|
||||
} else if (dataType == ffi::DataType::C128) {
|
||||
return GetrfImpl<gpuDoubleComplex>(stream, scratch, a, out, ipiv, info);
|
||||
return GetrfBatchedImpl<gpublasDoubleComplex>(
|
||||
batch, cols, stream, scratch, a, out, ipiv, info);
|
||||
}
|
||||
} else {
|
||||
if (dataType == ffi::DataType::F32) {
|
||||
return GetrfImpl<float>(batch, rows, cols, stream, scratch, a, out, ipiv,
|
||||
info);
|
||||
} else if (dataType == ffi::DataType::F64) {
|
||||
return GetrfImpl<double>(batch, rows, cols, stream, scratch, a, out, ipiv,
|
||||
info);
|
||||
} else if (dataType == ffi::DataType::C64) {
|
||||
return GetrfImpl<gpuComplex>(batch, rows, cols, stream, scratch, a, out,
|
||||
ipiv, info);
|
||||
} else if (dataType == ffi::DataType::C128) {
|
||||
return GetrfImpl<gpuDoubleComplex>(batch, rows, cols, stream, scratch, a,
|
||||
out, ipiv, info);
|
||||
}
|
||||
}
|
||||
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
|
||||
"Unsupported element type for getrf");
|
||||
|
@ -43,10 +43,7 @@ except ImportError:
|
||||
|
||||
if _cublas:
|
||||
for _name, _value in _cublas.registrations().items():
|
||||
# TODO(danfm): Clean up after all legacy custom calls are ported.
|
||||
api_version = 1 if _name.endswith("_ffi") else 0
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA",
|
||||
api_version=api_version)
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
|
||||
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
|
||||
try:
|
||||
@ -78,10 +75,7 @@ except ImportError:
|
||||
|
||||
if _hipblas:
|
||||
for _name, _value in _hipblas.registrations().items():
|
||||
# TODO(danfm): Clean up after all legacy custom calls are ported.
|
||||
api_version = 1 if _name.endswith("_ffi") else 0
|
||||
xla_client.register_custom_call_target(_name, _value, platform="ROCM",
|
||||
api_version=api_version)
|
||||
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
|
||||
|
||||
for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
|
||||
try:
|
||||
@ -115,15 +109,14 @@ def _getrf_hlo(platform, gpu_blas, gpu_solver, ctx, dtype, a):
|
||||
num_bd = len(batch_dims)
|
||||
i32_type = ir.IntegerType.get_signless(32)
|
||||
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
|
||||
batch = math.prod(batch_dims)
|
||||
use_batched = batch > 1 and m == n and m // batch <= 128
|
||||
|
||||
# TODO(b/357034884): Remove after 3 week forward compatibility window.
|
||||
if ctx.is_forward_compat():
|
||||
if not gpu_blas:
|
||||
raise GpuLibNotLinkedError()
|
||||
|
||||
if use_batched:
|
||||
batch = math.prod(batch_dims)
|
||||
if batch > 1 and m == n and m // batch <= 128:
|
||||
lwork, opaque = gpu_blas.build_getrf_batched_descriptor(
|
||||
np.dtype(dtype), batch, m)
|
||||
workspace = ir.RankedTensorType.get([lwork], ir.IntegerType.get_signless(8))
|
||||
@ -154,9 +147,8 @@ def _getrf_hlo(platform, gpu_blas, gpu_solver, ctx, dtype, a):
|
||||
operand_output_aliases={0: 0}).results
|
||||
return out[:3]
|
||||
|
||||
target = "blas_getrf_batched_ffi" if use_batched else "solver_getrf_ffi"
|
||||
return custom_call(
|
||||
f"{platform}{target}",
|
||||
f"{platform}solver_getrf_ffi",
|
||||
result_types=[
|
||||
a.type,
|
||||
ir.RankedTensorType.get(batch_dims + (min(m, n),), i32_type),
|
||||
|
@ -98,23 +98,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hipblas_kernels_ffi",
|
||||
srcs = ["//jaxlib/gpu:blas_kernels_ffi.cc"],
|
||||
hdrs = ["//jaxlib/gpu:blas_kernels_ffi.h"],
|
||||
deps = [
|
||||
":hip_blas_handle_pool",
|
||||
":hip_gpu_kernel_helpers",
|
||||
":hip_vendor",
|
||||
"//jaxlib:ffi_helpers",
|
||||
"@com_google_absl//absl/status",
|
||||
"@local_config_rocm//rocm:hipblas",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_blas",
|
||||
srcs = ["//jaxlib/gpu:blas.cc"],
|
||||
@ -127,7 +110,6 @@ pybind_extension(
|
||||
deps = [
|
||||
":hip_vendor",
|
||||
":hipblas_kernels",
|
||||
":hipblas_kernels_ffi",
|
||||
"//jaxlib:kernel_nanobind_helpers",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
@ -176,12 +158,14 @@ cc_library(
|
||||
srcs = ["//jaxlib/gpu:solver_kernels_ffi.cc"],
|
||||
hdrs = ["//jaxlib/gpu:solver_kernels_ffi.h"],
|
||||
deps = [
|
||||
":hip_blas_handle_pool",
|
||||
":hip_gpu_kernel_helpers",
|
||||
":hip_solver_handle_pool",
|
||||
":hip_vendor",
|
||||
"//jaxlib:ffi_helpers",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@local_config_rocm//rocm:hipblas",
|
||||
"@local_config_rocm//rocm:hipsolver",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
@ -199,14 +183,15 @@ pybind_extension(
|
||||
features = ["-use_header_modules"],
|
||||
module_name = "_solver",
|
||||
deps = [
|
||||
":hip_solver_handle_pool",
|
||||
":hip_gpu_kernel_helpers",
|
||||
":hip_solver_handle_pool",
|
||||
":hip_vendor",
|
||||
":hipsolver_kernels",
|
||||
":hipsolver_kernels_ffi",
|
||||
"//jaxlib:kernel_nanobind_helpers",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@local_config_rocm//rocm:hipblas",
|
||||
"@local_config_rocm//rocm:hipsolver",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@nanobind",
|
||||
|
Loading…
x
Reference in New Issue
Block a user