diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 037704462..27f0a8e2a 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -283,6 +283,10 @@ cc_library( ":cuda_vendor", "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cuda_headers", ], ) @@ -296,7 +300,7 @@ cuda_library( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_vendor", - "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_status", "@local_config_cuda//cuda:cuda_headers", ], diff --git a/jaxlib/gpu/lu_pivot_kernels.cc b/jaxlib/gpu/lu_pivot_kernels.cc index dc5b71716..b2c636227 100644 --- a/jaxlib/gpu/lu_pivot_kernels.cc +++ b/jaxlib/gpu/lu_pivot_kernels.cc @@ -16,8 +16,14 @@ limitations under the License. #include "jaxlib/gpu/lu_pivot_kernels.h" #include +#include +#include #include +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/c_api.h" @@ -28,29 +34,51 @@ namespace JAX_GPU_NAMESPACE { namespace ffi = xla::ffi; -XLA_FFI_Error* LuPivotsToPermutation(XLA_FFI_CallFrame* call_frame) { - static const auto* kImpl = - ffi::Ffi::Bind() - .Ctx>() - .Attr("batch_size") - .Attr("pivot_size") - .Attr("permutation_size") - .Arg>() - .Ret>() - .To([](gpuStream_t stream, std::int64_t batch_size, - std::int32_t pivot_size, std::int32_t permutation_size, - auto pivots, auto permutation) -> ffi::Error { - LaunchLuPivotsToPermutationKernel(stream, batch_size, pivot_size, - permutation_size, pivots.data, - permutation->data); - if (auto status = JAX_AS_STATUS(gpuGetLastError()); !status.ok()) { - return ffi::Error(static_cast(status.code()), - std::string(status.message())); - } - return ffi::Error::Success(); - }) - .release(); - return kImpl->Call(call_frame); +template +inline absl::StatusOr MaybeCastNoOverflow( + std::int64_t value, const std::string& source = __FILE__) { + if constexpr (sizeof(T) == sizeof(std::int64_t)) { + return value; + } else { + if (value > std::numeric_limits::max()) [[unlikely]] { + return absl::InvalidArgumentError(absl::StrFormat( + "%s: Value (=%d) exceeds the maximum representable value of the " + "desired type", + source, value)); + } + return static_cast(value); + } +} + +ffi::Error LuPivotsToPermutationImpl( + gpuStream_t stream, std::int32_t permutation_size, + ffi::Buffer pivots, + ffi::Result> permutation) { + auto dims = pivots.dimensions; + if (dims.size() < 1) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "pivots must have at least one dimension"); + } + auto maybe_pivot_size = MaybeCastNoOverflow(dims.back()); + if (!maybe_pivot_size.ok()) { + return ffi::Error( + static_cast(maybe_pivot_size.status().code()), + std::string(maybe_pivot_size.status().message())); + } + std::int32_t pivot_size = maybe_pivot_size.value(); + std::int64_t batch_size = 1; + if (dims.size() >= 2) { + batch_size = + absl::c_accumulate(dims.first(dims.size() - 1), 1, std::multiplies<>()); + } + LaunchLuPivotsToPermutationKernel(stream, batch_size, pivot_size, + permutation_size, pivots.data, + permutation->data); + if (auto status = JAX_AS_STATUS(gpuGetLastError()); !status.ok()) { + return ffi::Error(static_cast(status.code()), + std::string(status.message())); + } + return ffi::Error::Success(); } } // namespace JAX_GPU_NAMESPACE diff --git a/jaxlib/gpu/lu_pivot_kernels.h b/jaxlib/gpu/lu_pivot_kernels.h index a4af440d5..b2cceb883 100644 --- a/jaxlib/gpu/lu_pivot_kernels.h +++ b/jaxlib/gpu/lu_pivot_kernels.h @@ -19,11 +19,13 @@ limitations under the License. #include #include "jaxlib/gpu/vendor.h" -#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" namespace jax { namespace JAX_GPU_NAMESPACE { +namespace ffi = xla::ffi; + void LaunchLuPivotsToPermutationKernel(gpuStream_t stream, std::int64_t batch_size, std::int32_t pivot_size, @@ -31,7 +33,17 @@ void LaunchLuPivotsToPermutationKernel(gpuStream_t stream, const std::int32_t* pivots, std::int32_t* permutation); -XLA_FFI_Error* LuPivotsToPermutation(XLA_FFI_CallFrame* call_frame); +ffi::Error LuPivotsToPermutationImpl( + gpuStream_t stream, std::int32_t permutation_size, + ffi::Buffer pivots, + ffi::Result> permutation); + +XLA_FFI_DEFINE_HANDLER(LuPivotsToPermutation, LuPivotsToPermutationImpl, + ffi::Ffi::Bind() + .Ctx>() + .Attr("permutation_size") + .Arg>() + .Ret>()); } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu_linalg.py b/jaxlib/gpu_linalg.py index a1ce5fa4d..af79b3ae7 100644 --- a/jaxlib/gpu_linalg.py +++ b/jaxlib/gpu_linalg.py @@ -66,13 +66,9 @@ def _lu_pivots_to_permutation_hlo(platform, gpu_linalg, pivots, *, permutation_s typ = ir.RankedTensorType(pivots.type) dims = typ.shape i32_type = ir.IntegerType.get_signless(32) - i64_type = ir.IntegerType.get_signless(64) assert typ.element_type == i32_type, typ - batch_size = _prod(dims[:-1]) - pivot_size = dims[-1] - if not gpu_linalg: raise GpuLibNotLinkedError() @@ -87,8 +83,6 @@ def _lu_pivots_to_permutation_hlo(platform, gpu_linalg, pivots, *, permutation_s result_types=[permutations_type], operands=[pivots], backend_config=dict( - batch_size=ir.IntegerAttr.get(i64_type, batch_size), - pivot_size=ir.IntegerAttr.get(i32_type, pivot_size), permutation_size=ir.IntegerAttr.get(i32_type, permutation_size), ), operand_layouts=[pivots_layout], diff --git a/jaxlib/rocm/BUILD.bazel b/jaxlib/rocm/BUILD.bazel index 74d5ef30b..13e632086 100644 --- a/jaxlib/rocm/BUILD.bazel +++ b/jaxlib/rocm/BUILD.bazel @@ -201,6 +201,10 @@ cc_library( ":hip_gpu_kernel_helpers", ":hip_lu_pivot_kernels_impl", ":hip_vendor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", "@local_config_rocm//rocm:rocm_headers", "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", @@ -215,7 +219,7 @@ rocm_library( ":hip_gpu_kernel_helpers", ":hip_vendor", "@local_config_rocm//rocm:rocm_headers", - "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", ], )