Update lu_pivots_to_permutation to use FFI dimensions on GPU.

The XLA FFI interface provides metadata about buffer dimensions, so quantities
like the batch size can be computed on the backend instead of being passed as
attributes. This change has the added benefit of allowing this FFI call to
support "vectorized" vmap and dynamic shapes.

PiperOrigin-RevId: 647343656
Dan Foreman-Mackey, 2024-06-27 09:24:15 -07:00, committed by jax authors
parent 43dc4c1ff8
commit 9ae1c56c44
5 changed files with 75 additions and 33 deletions
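
For intuition, here is a minimal sketch (not part of the diff) of the pattern
this change adopts: instead of binding the batch size as a compile-time
attribute, the handler derives it from the leading dimensions of the pivots
buffer that the FFI call frame already carries. The sketch assumes the same
xla::ffi API and .dimensions member used in the diffs below; the BatchSize
helper is hypothetical and for illustration only.

// Sketch only: derive the batch size from an FFI buffer's dimensions rather
// than from an attribute. BatchSize is a hypothetical helper.
#include <cstdint>
#include <functional>
#include <numeric>

#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

std::int64_t BatchSize(const ffi::Buffer<ffi::DataType::S32>& buf) {
  auto dims = buf.dimensions;       // span of extents, known at call time
  if (dims.size() < 1) return 1;    // rank-0 is rejected by the real handler
  // Every dimension except the trailing (pivot) one is a batch dimension;
  // the empty product for a 1-D buffer is 1.
  return std::accumulate(dims.begin(), dims.end() - 1, std::int64_t{1},
                         std::multiplies<>());
}

Because the batch size is read off the buffer at call time rather than baked
into the lowered HLO, the same handler works under "vectorized" vmap and with
dynamic shapes.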


@@ -283,6 +283,10 @@ cc_library(
        ":cuda_vendor",
        "@xla//xla/ffi/api:c_api",
        "@xla//xla/ffi/api:ffi",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:str_format",
        "@local_config_cuda//cuda:cuda_headers",
    ],
)
@@ -296,7 +300,7 @@ cuda_library(
    deps = [
        ":cuda_gpu_kernel_helpers",
        ":cuda_vendor",
        "@xla//xla/ffi/api:c_api",
        "@xla//xla/ffi/api:ffi",
        "@xla//xla/service:custom_call_status",
        "@local_config_cuda//cuda:cuda_headers",
    ],


@@ -16,8 +16,14 @@ limitations under the License.
#include "jaxlib/gpu/lu_pivot_kernels.h"

#include <cstdint>
#include <functional>
#include <limits>
#include <string>

#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/vendor.h"
#include "xla/ffi/api/c_api.h"
@@ -28,29 +34,51 @@ namespace JAX_GPU_NAMESPACE {

namespace ffi = xla::ffi;

XLA_FFI_Error* LuPivotsToPermutation(XLA_FFI_CallFrame* call_frame) {
  static const auto* kImpl =
      ffi::Ffi::Bind()
          .Ctx<ffi::PlatformStream<gpuStream_t>>()
          .Attr<std::int64_t>("batch_size")
          .Attr<std::int32_t>("pivot_size")
          .Attr<std::int32_t>("permutation_size")
          .Arg<ffi::Buffer<ffi::DataType::S32>>()
          .Ret<ffi::Buffer<ffi::DataType::S32>>()
          .To([](gpuStream_t stream, std::int64_t batch_size,
                 std::int32_t pivot_size, std::int32_t permutation_size,
                 auto pivots, auto permutation) -> ffi::Error {
            LaunchLuPivotsToPermutationKernel(stream, batch_size, pivot_size,
                                              permutation_size, pivots.data,
                                              permutation->data);
            if (auto status = JAX_AS_STATUS(gpuGetLastError()); !status.ok()) {
              return ffi::Error(static_cast<XLA_FFI_Error_Code>(status.code()),
                                std::string(status.message()));
            }
            return ffi::Error::Success();
          })
          .release();
  return kImpl->Call(call_frame);
}

template <typename T>
inline absl::StatusOr<T> MaybeCastNoOverflow(
    std::int64_t value, const std::string& source = __FILE__) {
  if constexpr (sizeof(T) == sizeof(std::int64_t)) {
    return value;
  } else {
    if (value > std::numeric_limits<T>::max()) [[unlikely]] {
      return absl::InvalidArgumentError(absl::StrFormat(
          "%s: Value (=%d) exceeds the maximum representable value of the "
          "desired type",
          source, value));
    }
    return static_cast<T>(value);
  }
}

ffi::Error LuPivotsToPermutationImpl(
    gpuStream_t stream, std::int32_t permutation_size,
    ffi::Buffer<ffi::DataType::S32> pivots,
    ffi::Result<ffi::Buffer<ffi::DataType::S32>> permutation) {
  auto dims = pivots.dimensions;
  if (dims.size() < 1) {
    return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                      "pivots must have at least one dimension");
  }
  auto maybe_pivot_size = MaybeCastNoOverflow<std::int32_t>(dims.back());
  if (!maybe_pivot_size.ok()) {
    return ffi::Error(
        static_cast<XLA_FFI_Error_Code>(maybe_pivot_size.status().code()),
        std::string(maybe_pivot_size.status().message()));
  }
  std::int32_t pivot_size = maybe_pivot_size.value();
  std::int64_t batch_size = 1;
  if (dims.size() >= 2) {
    batch_size =
        absl::c_accumulate(dims.first(dims.size() - 1), 1, std::multiplies<>());
  }
  LaunchLuPivotsToPermutationKernel(stream, batch_size, pivot_size,
                                    permutation_size, pivots.data,
                                    permutation->data);
  if (auto status = JAX_AS_STATUS(gpuGetLastError()); !status.ok()) {
    return ffi::Error(static_cast<XLA_FFI_Error_Code>(status.code()),
                      std::string(status.message()));
  }
  return ffi::Error::Success();
}
} // namespace JAX_GPU_NAMESPACE
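
As an aside on the MaybeCastNoOverflow helper added above, here is a
hypothetical standalone usage (not part of the commit) showing its behavior:

// Hypothetical usage of the MaybeCastNoOverflow helper defined above.
absl::StatusOr<std::int32_t> ok = MaybeCastNoOverflow<std::int32_t>(42);
// ok.value() == 42: the value fits, so it is simply cast.

absl::StatusOr<std::int32_t> bad =
    MaybeCastNoOverflow<std::int32_t>(std::int64_t{1} << 40);
// bad.ok() == false: 2^40 exceeds INT32_MAX, and the error message reports
// the source file and the offending value.

absl::StatusOr<std::int64_t> same =
    MaybeCastNoOverflow<std::int64_t>(std::int64_t{1} << 40);
// same.ok() == true: same-width "casts" pass through with no range check.

Note that the helper only checks the upper bound, which suffices here because
buffer extents are non-negative.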


@@ -19,11 +19,13 @@ limitations under the License.
#include <cstdint>

#include "jaxlib/gpu/vendor.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"

namespace jax {
namespace JAX_GPU_NAMESPACE {

namespace ffi = xla::ffi;

void LaunchLuPivotsToPermutationKernel(gpuStream_t stream,
                                       std::int64_t batch_size,
                                       std::int32_t pivot_size,
@@ -31,7 +33,17 @@ void LaunchLuPivotsToPermutationKernel(gpuStream_t stream,
                                       const std::int32_t* pivots,
                                       std::int32_t* permutation);

XLA_FFI_Error* LuPivotsToPermutation(XLA_FFI_CallFrame* call_frame);
ffi::Error LuPivotsToPermutationImpl(
    gpuStream_t stream, std::int32_t permutation_size,
    ffi::Buffer<ffi::DataType::S32> pivots,
    ffi::Result<ffi::Buffer<ffi::DataType::S32>> permutation);

XLA_FFI_DEFINE_HANDLER(LuPivotsToPermutation, LuPivotsToPermutationImpl,
                       ffi::Ffi::Bind()
                           .Ctx<ffi::PlatformStream<gpuStream_t>>()
                           .Attr<std::int32_t>("permutation_size")
                           .Arg<ffi::Buffer<ffi::DataType::S32>>()
                           .Ret<ffi::Buffer<ffi::DataType::S32>>());

} // namespace JAX_GPU_NAMESPACE
} // namespace jax
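
The XLA_FFI_DEFINE_HANDLER macro takes over the role of the hand-written entry
point that the .cc file previously defined. As a rough illustration (this
approximates the idea, not the macro's exact output), it expands to something
like:

// Approximate expansion of XLA_FFI_DEFINE_HANDLER, for intuition only.
XLA_FFI_Error* LuPivotsToPermutation(XLA_FFI_CallFrame* call_frame) {
  static const auto* kHandler =
      ffi::Ffi::Bind()
          .Ctx<ffi::PlatformStream<gpuStream_t>>()
          .Attr<std::int32_t>("permutation_size")
          .Arg<ffi::Buffer<ffi::DataType::S32>>()
          .Ret<ffi::Buffer<ffi::DataType::S32>>()
          .To(LuPivotsToPermutationImpl)
          .release();
  return kHandler->Call(call_frame);
}

Compared with the removed lambda-based handler in the .cc diff, the binding no
longer declares batch_size or pivot_size attributes; both are recovered from
the pivots buffer's dimensions inside LuPivotsToPermutationImpl.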


@@ -66,13 +66,9 @@ def _lu_pivots_to_permutation_hlo(platform, gpu_linalg, pivots, *, permutation_s
  typ = ir.RankedTensorType(pivots.type)
  dims = typ.shape
  i32_type = ir.IntegerType.get_signless(32)
  i64_type = ir.IntegerType.get_signless(64)
  assert typ.element_type == i32_type, typ
  batch_size = _prod(dims[:-1])
  pivot_size = dims[-1]

  if not gpu_linalg:
    raise GpuLibNotLinkedError()
@@ -87,8 +83,6 @@ def _lu_pivots_to_permutation_hlo(platform, gpu_linalg, pivots, *, permutation_s
      result_types=[permutations_type],
      operands=[pivots],
      backend_config=dict(
          batch_size=ir.IntegerAttr.get(i64_type, batch_size),
          pivot_size=ir.IntegerAttr.get(i32_type, pivot_size),
          permutation_size=ir.IntegerAttr.get(i32_type, permutation_size),
      ),
      operand_layouts=[pivots_layout],


@@ -201,6 +201,10 @@ cc_library(
        ":hip_gpu_kernel_helpers",
        ":hip_lu_pivot_kernels_impl",
        ":hip_vendor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:str_format",
        "@local_config_rocm//rocm:rocm_headers",
        "@xla//xla/ffi/api:c_api",
        "@xla//xla/ffi/api:ffi",
@@ -215,7 +219,7 @@ rocm_library(
    deps = [
        ":hip_gpu_kernel_helpers",
        ":hip_vendor",
        "@local_config_rocm//rocm:rocm_headers",
        "@xla//xla/ffi/api:c_api",
        "@xla//xla/ffi/api:ffi",
    ],
)