Update lu_pivots_to_permutation to use FFI dimensions on GPU.

The XLA FFI interface provides metadata about buffer dimensions, so quantities like the batch dimension can be computed on the backend instead of being passed as custom-call attributes. This change has the added benefit of allowing this FFI call to support "vectorized" vmap and dynamic shapes.

PiperOrigin-RevId: 647343656
commit 9ae1c56c44
parent 43dc4c1ff8
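
The core idea, as a minimal standalone C++ sketch (illustrative only, not code from this commit; BatchSize and the dims vector are stand-ins for the FFI buffer's dimension metadata): the trailing extent of the pivots buffer gives the pivot count, and the product of the leading extents gives the batch size, so neither value has to travel as a custom-call attribute.

    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    // Sketch: with the FFI calling convention, the handler sees the pivot
    // buffer's extents directly. The trailing extent is the pivot count and
    // the product of the leading extents is the batch size. `dims` stands in
    // for the buffer's dimension span; callers must check that dims is
    // non-empty first, as the committed handler does.
    std::int64_t BatchSize(const std::vector<std::int64_t>& dims) {
      return std::accumulate(dims.begin(), dims.end() - 1, std::int64_t{1},
                             std::multiplies<>());
    }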
@@ -283,6 +283,10 @@ cc_library(
         ":cuda_vendor",
         "@xla//xla/ffi/api:c_api",
         "@xla//xla/ffi/api:ffi",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:str_format",
         "@local_config_cuda//cuda:cuda_headers",
     ],
 )
@@ -296,7 +300,7 @@ cuda_library(
     deps = [
         ":cuda_gpu_kernel_helpers",
         ":cuda_vendor",
         "@xla//xla/ffi/api:c_api",
+        "@xla//xla/ffi/api:ffi",
-        "@xla//xla/service:custom_call_status",
         "@local_config_cuda//cuda:cuda_headers",
     ],
@@ -16,8 +16,14 @@ limitations under the License.
 #include "jaxlib/gpu/lu_pivot_kernels.h"
 
 #include <cstdint>
+#include <functional>
+#include <limits>
 #include <string>
 
+#include "absl/algorithm/container.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_format.h"
 #include "jaxlib/gpu/gpu_kernel_helpers.h"
 #include "jaxlib/gpu/vendor.h"
 #include "xla/ffi/api/c_api.h"
@@ -28,29 +34,51 @@ namespace JAX_GPU_NAMESPACE {
 
 namespace ffi = xla::ffi;
 
-XLA_FFI_Error* LuPivotsToPermutation(XLA_FFI_CallFrame* call_frame) {
-  static const auto* kImpl =
-      ffi::Ffi::Bind()
-          .Ctx<ffi::PlatformStream<gpuStream_t>>()
-          .Attr<std::int64_t>("batch_size")
-          .Attr<std::int32_t>("pivot_size")
-          .Attr<std::int32_t>("permutation_size")
-          .Arg<ffi::Buffer<ffi::DataType::S32>>()
-          .Ret<ffi::Buffer<ffi::DataType::S32>>()
-          .To([](gpuStream_t stream, std::int64_t batch_size,
-                 std::int32_t pivot_size, std::int32_t permutation_size,
-                 auto pivots, auto permutation) -> ffi::Error {
-            LaunchLuPivotsToPermutationKernel(stream, batch_size, pivot_size,
-                                              permutation_size, pivots.data,
-                                              permutation->data);
-            if (auto status = JAX_AS_STATUS(gpuGetLastError()); !status.ok()) {
-              return ffi::Error(static_cast<XLA_FFI_Error_Code>(status.code()),
-                                std::string(status.message()));
-            }
-            return ffi::Error::Success();
-          })
-          .release();
-  return kImpl->Call(call_frame);
+template <typename T>
+inline absl::StatusOr<T> MaybeCastNoOverflow(
+    std::int64_t value, const std::string& source = __FILE__) {
+  if constexpr (sizeof(T) == sizeof(std::int64_t)) {
+    return value;
+  } else {
+    if (value > std::numeric_limits<T>::max()) [[unlikely]] {
+      return absl::InvalidArgumentError(absl::StrFormat(
+          "%s: Value (=%d) exceeds the maximum representable value of the "
+          "desired type",
+          source, value));
+    }
+    return static_cast<T>(value);
+  }
 }
 
+ffi::Error LuPivotsToPermutationImpl(
+    gpuStream_t stream, std::int32_t permutation_size,
+    ffi::Buffer<ffi::DataType::S32> pivots,
+    ffi::Result<ffi::Buffer<ffi::DataType::S32>> permutation) {
+  auto dims = pivots.dimensions;
+  if (dims.size() < 1) {
+    return ffi::Error(ffi::ErrorCode::kInvalidArgument,
+                      "pivots must have at least one dimension");
+  }
+  auto maybe_pivot_size = MaybeCastNoOverflow<std::int32_t>(dims.back());
+  if (!maybe_pivot_size.ok()) {
+    return ffi::Error(
+        static_cast<XLA_FFI_Error_Code>(maybe_pivot_size.status().code()),
+        std::string(maybe_pivot_size.status().message()));
+  }
+  std::int32_t pivot_size = maybe_pivot_size.value();
+  std::int64_t batch_size = 1;
+  if (dims.size() >= 2) {
+    batch_size =
+        absl::c_accumulate(dims.first(dims.size() - 1), 1, std::multiplies<>());
+  }
+  LaunchLuPivotsToPermutationKernel(stream, batch_size, pivot_size,
+                                    permutation_size, pivots.data,
+                                    permutation->data);
+  if (auto status = JAX_AS_STATUS(gpuGetLastError()); !status.ok()) {
+    return ffi::Error(static_cast<XLA_FFI_Error_Code>(status.code()),
+                      std::string(status.message()));
+  }
+  return ffi::Error::Success();
+}
+
 }  // namespace JAX_GPU_NAMESPACE
@@ -19,11 +19,13 @@ limitations under the License.
 #include <cstdint>
 
 #include "jaxlib/gpu/vendor.h"
-#include "xla/ffi/api/c_api.h"
+#include "xla/ffi/api/ffi.h"
 
 namespace jax {
 namespace JAX_GPU_NAMESPACE {
 
+namespace ffi = xla::ffi;
+
 void LaunchLuPivotsToPermutationKernel(gpuStream_t stream,
                                        std::int64_t batch_size,
                                        std::int32_t pivot_size,
@@ -31,7 +33,17 @@ void LaunchLuPivotsToPermutationKernel(gpuStream_t stream,
                                        const std::int32_t* pivots,
                                        std::int32_t* permutation);
 
-XLA_FFI_Error* LuPivotsToPermutation(XLA_FFI_CallFrame* call_frame);
+ffi::Error LuPivotsToPermutationImpl(
+    gpuStream_t stream, std::int32_t permutation_size,
+    ffi::Buffer<ffi::DataType::S32> pivots,
+    ffi::Result<ffi::Buffer<ffi::DataType::S32>> permutation);
+
+XLA_FFI_DEFINE_HANDLER(LuPivotsToPermutation, LuPivotsToPermutationImpl,
+                       ffi::Ffi::Bind()
+                           .Ctx<ffi::PlatformStream<gpuStream_t>>()
+                           .Attr<std::int32_t>("permutation_size")
+                           .Arg<ffi::Buffer<ffi::DataType::S32>>()
+                           .Ret<ffi::Buffer<ffi::DataType::S32>>());
 
 }  // namespace JAX_GPU_NAMESPACE
 }  // namespace jax
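
For readers unfamiliar with XLA_FFI_DEFINE_HANDLER: judging from the hand-written handler this commit removes (the lu_pivot_kernels.cc hunk above), the macro-generated entry point is morally equivalent to the following sketch. This is a paraphrase under that assumption, not the macro's literal expansion:

    // Sketch of the entry point the macro generates: build the binding once,
    // then dispatch every call frame through it.
    XLA_FFI_Error* LuPivotsToPermutation(XLA_FFI_CallFrame* call_frame) {
      static const auto* kHandler =
          ffi::Ffi::Bind()
              .Ctx<ffi::PlatformStream<gpuStream_t>>()
              .Attr<std::int32_t>("permutation_size")
              .Arg<ffi::Buffer<ffi::DataType::S32>>()
              .Ret<ffi::Buffer<ffi::DataType::S32>>()
              .To(LuPivotsToPermutationImpl)
              .release();
      return kHandler->Call(call_frame);
    }

Registering the binding this way lets the typed implementation, LuPivotsToPermutationImpl, remain a plain function that is easy to test and reuse.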
@@ -66,13 +66,9 @@ def _lu_pivots_to_permutation_hlo(platform, gpu_linalg, pivots, *, permutation_s
   typ = ir.RankedTensorType(pivots.type)
   dims = typ.shape
   i32_type = ir.IntegerType.get_signless(32)
-  i64_type = ir.IntegerType.get_signless(64)
 
   assert typ.element_type == i32_type, typ
 
-  batch_size = _prod(dims[:-1])
-  pivot_size = dims[-1]
-
   if not gpu_linalg:
     raise GpuLibNotLinkedError()
 
@@ -87,8 +83,6 @@ def _lu_pivots_to_permutation_hlo(platform, gpu_linalg, pivots, *, permutation_s
       result_types=[permutations_type],
       operands=[pivots],
       backend_config=dict(
-          batch_size=ir.IntegerAttr.get(i64_type, batch_size),
-          pivot_size=ir.IntegerAttr.get(i32_type, pivot_size),
           permutation_size=ir.IntegerAttr.get(i32_type, permutation_size),
       ),
       operand_layouts=[pivots_layout],
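
Note how the Python lowering mirrors the C++ change: since the handler now derives batch_size and pivot_size from the operand's shape, backend_config only needs to carry permutation_size.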
@@ -201,6 +201,10 @@ cc_library(
         ":hip_gpu_kernel_helpers",
         ":hip_lu_pivot_kernels_impl",
        ":hip_vendor",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:str_format",
         "@local_config_rocm//rocm:rocm_headers",
         "@xla//xla/ffi/api:c_api",
         "@xla//xla/ffi/api:ffi",
@@ -215,7 +219,7 @@ rocm_library(
         ":hip_gpu_kernel_helpers",
         ":hip_vendor",
         "@local_config_rocm//rocm:rocm_headers",
         "@xla//xla/ffi/api:c_api",
+        "@xla//xla/ffi/api:ffi",
-        "@xla//xla/service:custom_call_status",
     ],
 )