1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-20 13:56:07 +00:00

Fix C++ registration of FFI handlers and consolidate gpu/linalg kernel implementation.

This change does a few things (arguably too many):

1. The key change here is that it fixes the handler registration in `jaxlib/gpu/gpu_kernels.cc` for the two handlers that use the XLA FFI API. A previous attempt at this change caused downstream issues because of duplicate registrations, but we were able to fix that directly in XLA.

2. A second related change is to declare and define the XLA FFI handlers consistently using the `XLA_FFI_DECLARE_HANDLER_SYMBOL` and `XLA_FFI_DEFINE_HANDLER_SYMBOL` macros. We need to use these macros instead of the `XLA_FFI_DEFINE_HANDLER` version which produces a lambda, so that when XLA checks the address of the handler during registration it is consistent. Without this change, the downstream tests would continue to fail.

3. The final change is to consolidate the `cholesky_update_kernel` and `lu_pivot_kernels` implementations into a common `linalg_kernels` target. This makes the implementation of the `_linalg` nanobind module consistent with the other targets within `jaxlib/gpu`, and (I think!) makes the details easier to follow. This last change is less urgent, but it was what I set out to do so that's why I'm suggesting them all together, but I can split this in two if that would be preferred.

PiperOrigin-RevId: 651107659
This commit is contained in:
Dan Foreman-Mackey 2024-07-10 12:08:30 -07:00 committed by jax authors
parent 44359cb30a
commit 4f394828e1
17 changed files with 293 additions and 424 deletions

@ -123,6 +123,17 @@ cc_library(
],
)
cc_library(
name = "ffi_helpers",
hdrs = ["ffi_helpers.h"],
features = ["-use_header_modules"],
deps = [
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
],
)
cc_library(
name = "kernel_nanobind_helpers",
hdrs = ["kernel_nanobind_helpers.h"],
@ -133,6 +144,7 @@ cc_library(
features = ["-use_header_modules"],
deps = [
":kernel_helpers",
"@xla//xla/ffi/api:c_api",
"@xla//xla/tsl/python/lib/core:numpy",
"@com_google_absl//absl/base",
"@nanobind",

@ -272,17 +272,21 @@ pybind_extension(
)
cc_library(
name = "cuda_lu_pivot_kernels",
name = "cuda_linalg_kernels",
srcs = [
"//jaxlib/gpu:lu_pivot_kernels.cc",
"//jaxlib/gpu:linalg_kernels.cc",
],
hdrs = ["//jaxlib/gpu:lu_pivot_kernels.h"],
hdrs = ["//jaxlib/gpu:linalg_kernels.h"],
features = ["-use_header_modules"],
deps = [
":cuda_gpu_kernel_helpers",
":cuda_lu_pivot_kernels_impl",
":cuda_linalg_kernels_impl",
":cuda_vendor",
"//jaxlib:ffi_helpers",
"//jaxlib:kernel_helpers",
"@xla//xla/ffi/api:c_api",
"@xla//xla/ffi/api:ffi",
"@xla//xla/service:custom_call_status",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
@ -292,11 +296,13 @@ cc_library(
)
cuda_library(
name = "cuda_lu_pivot_kernels_impl",
name = "cuda_linalg_kernels_impl",
srcs = [
"//jaxlib/gpu:lu_pivot_kernels.cu.cc",
"//jaxlib/gpu:linalg_kernels.cu.cc",
],
hdrs = [
"//jaxlib/gpu:linalg_kernels.h",
],
hdrs = ["//jaxlib/gpu:lu_pivot_kernels.h"],
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
@ -306,43 +312,6 @@ cuda_library(
],
)
cc_library(
name = "cholesky_update_kernel",
srcs = [
"//jaxlib/gpu:cholesky_update_kernel.cc",
],
hdrs = ["//jaxlib/gpu:cholesky_update_kernel.h"],
features = ["-use_header_modules"],
deps = [
":cholesky_update_kernel_impl",
":cuda_gpu_kernel_helpers",
":cuda_vendor",
":cusolver_kernels",
"//jaxlib:kernel_helpers",
"@xla//xla/service:custom_call_status",
"@com_google_absl//absl/status",
"@local_config_cuda//cuda:cuda_headers",
],
)
cuda_library(
name = "cholesky_update_kernel_impl",
srcs = [
"//jaxlib/gpu:cholesky_update_kernel.cu.cc",
],
hdrs = [
"//jaxlib/gpu:cholesky_update_kernel.h",
],
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
":cusolver_kernels",
"//jaxlib:kernel_helpers",
"@xla//xla/service:custom_call_status",
"@local_config_cuda//cuda:cuda_headers",
],
)
pybind_extension(
name = "_linalg",
srcs = ["//jaxlib/gpu:linalg.cc"],
@ -353,10 +322,8 @@ pybind_extension(
features = ["-use_header_modules"],
module_name = "_linalg",
deps = [
":cholesky_update_kernel",
":cuda_gpu_kernel_helpers",
":cuda_lu_pivot_kernels",
":cuda_lu_pivot_kernels_impl",
":cuda_linalg_kernels",
":cuda_vendor",
"//jaxlib:kernel_nanobind_helpers",
"@xla//xla/tsl/cuda:cudart",
@ -381,6 +348,7 @@ cc_library(
"@xla//xla/ffi/api:ffi",
"@xla//xla/service:custom_call_status",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status",
"@local_config_cuda//cuda:cuda_headers",
],
)
@ -395,7 +363,6 @@ cuda_library(
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"//jaxlib:kernel_helpers",
"@xla//xla/ffi/api:c_api",
"@xla//xla/ffi/api:ffi",
"@xla//xla/service:custom_call_status",
"@local_config_cuda//cuda:cuda_headers",
@ -426,9 +393,8 @@ cc_library(
srcs = ["//jaxlib/gpu:gpu_kernels.cc"],
visibility = ["//visibility:public"],
deps = [
":cholesky_update_kernel",
":cublas_kernels",
":cuda_lu_pivot_kernels",
":cuda_linalg_kernels",
":cuda_prng_kernels",
":cuda_vendor",
":cudnn_rnn_kernels",

32
jaxlib/ffi_helpers.h Normal file

@ -0,0 +1,32 @@
#ifndef JAXLIB_FFI_HELPERS_H_
#define JAXLIB_FFI_HELPERS_H_
#include <cstdint>
#include <limits>
#include <string>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
namespace jax {
template <typename T>
inline absl::StatusOr<T> MaybeCastNoOverflow(
std::int64_t value, const std::string& source = __FILE__) {
if constexpr (sizeof(T) == sizeof(std::int64_t)) {
return value;
} else {
if (value > std::numeric_limits<T>::max()) [[unlikely]] {
return absl::InvalidArgumentError(absl::StrFormat(
"%s: Value (=%d) exceeds the maximum representable value of the "
"desired type",
source, value));
}
return static_cast<T>(value);
}
}
} // namespace jax
#endif // JAXLIB_FFI_HELPERS_H_

@ -27,16 +27,13 @@ exports_files(srcs = [
"blas.cc",
"blas_kernels.cc",
"blas_kernels.h",
"cholesky_update_kernel.cc",
"cholesky_update_kernel.cu.cc",
"cholesky_update_kernel.h",
"gpu_kernel_helpers.cc",
"gpu_kernel_helpers.h",
"gpu_kernels.cc",
"linalg.cc",
"lu_pivot_kernels.cc",
"lu_pivot_kernels.cu.cc",
"lu_pivot_kernels.h",
"linalg_kernels.cc",
"linalg_kernels.cu.cc",
"linalg_kernels.h",
"prng.cc",
"prng_kernels.cc",
"prng_kernels.cu.cc",

@ -1,53 +0,0 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/gpu/cholesky_update_kernel.h"
#include <cstddef>
#include <string_view>
#include "absl/status/status.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/kernel_helpers.h"
#include "xla/service/custom_call_status.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {
namespace {
absl::Status CholeskyUpdateImpl(gpuStream_t stream, void** buffers,
const char* opaque, std::size_t opaque_len) {
auto s = UnpackDescriptor<CholeskyUpdateDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const CholeskyUpdateDescriptor& d = **s;
LaunchCholeskyUpdateKernel(stream, buffers, d);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError()));
return absl::OkStatus();
}
} // namespace
void CholeskyUpdate(gpuStream_t stream, void** buffers,
const char* opaque, size_t opaque_len,
XlaCustomCallStatus* status) {
auto s = CholeskyUpdateImpl(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
std::string_view message = s.message();
XlaCustomCallStatusSetFailure(status, message.data(), message.length());
}
}
} // namespace JAX_GPU_NAMESPACE
} // namespace jax

@ -1,136 +0,0 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/gpu/cholesky_update_kernel.h"
#include <stdio.h>
#ifdef JAX_GPU_HIP
#include "rocm/include/hip/amd_detail/amd_hip_cooperative_groups.h"
#else // JAX_GPU_CUDA
#include "third_party/gpus/cuda/include/cooperative_groups.h"
#endif
namespace cg = cooperative_groups;
namespace jax {
namespace JAX_GPU_NAMESPACE {
namespace {
template <typename T>
__device__ void drotg(T* da, T* db, T* c, T* s) {
if (*db == 0) {
*c = 1.;
*s = 0.;
return;
}
T denominator = max(abs(*da), abs(*db));
T a = *da / denominator;
T b = *db / denominator;
T rh = rhypot(a, b);
*c = a * rh;
*s = -(b * rh);
return;
}
template <typename T>
__global__ void CholeskyUpdateKernel(
T* rMatrix, T* uVector,
int nSize) {
cg::grid_group grid = cg::this_grid();
int k = grid.thread_rank();
T c, s;
for (int step = 0; step < 2 * nSize; ++step) {
grid.sync();
int i = step - k;
if (i < k || i >= nSize || k >= nSize) {
continue;
}
if (i == k) {
drotg(
rMatrix + k * nSize + k,
uVector + k,
&c,
&s);
}
T r_i = c * rMatrix[k * nSize + i] - s * uVector[i];
uVector[i] = s * rMatrix[k * nSize + i] + c * uVector[i];
rMatrix[k * nSize + i] = r_i;
}
}
} // namespace
template <typename T>
void LaunchCholeskyUpdateKernelBody(
gpuStream_t stream, void** buffers,
int grid_dim, int block_dim, int nSize) {
T* rMatrix = reinterpret_cast<T*>(buffers[2]);
T* uVector = reinterpret_cast<T*>(buffers[3]);
void* arg_ptrs[3] = {
reinterpret_cast<void*>(&rMatrix),
reinterpret_cast<void*>(&uVector),
reinterpret_cast<void*>(&nSize),
};
#ifdef JAX_GPU_HIP
hipLaunchCooperativeKernel(
(void*) CholeskyUpdateKernel<T>, grid_dim, block_dim, arg_ptrs,
/*dynamic_shared_mem_bytes=*/ 0, stream);
#else // JAX_GPU_CUDA
cudaLaunchCooperativeKernel(
(void*) CholeskyUpdateKernel<T>, grid_dim, block_dim, arg_ptrs,
/*dynamic_shared_mem_bytes=*/ 0, stream);
#endif
}
void LaunchCholeskyUpdateKernel(
gpuStream_t stream, void** buffers,
CholeskyUpdateDescriptor descriptor) {
int nSize = descriptor.matrix_size;
LinalgType type = descriptor.linalg_type;
int dev = 0;
#ifdef JAX_GPU_HIP
hipDeviceProp_t deviceProp;
hipGetDeviceProperties(&deviceProp, dev);
#else // JAX_GPU_CUDA
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, dev);
#endif
int block_dim = deviceProp.maxThreadsPerBlock;
int grid_dim = deviceProp.multiProcessorCount;
switch (type) {
case LinalgType::F64:
LaunchCholeskyUpdateKernelBody<double>(
stream, buffers, grid_dim, block_dim, nSize);
break;
case LinalgType::F32:
LaunchCholeskyUpdateKernelBody<float>(
stream, buffers, grid_dim, block_dim, nSize);
break;
}
}
} // namespace JAX_GPU_NAMESPACE
} // namespace jax

@ -1,50 +0,0 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_GPU_CHOLESKY_UPDATE_KERNEL_H_
#define JAXLIB_GPU_CHOLESKY_UPDATE_KERNEL_H_
#include <cstddef>
#include <cstdint>
#include <string>
#include "jaxlib/gpu/vendor.h"
#include "xla/service/custom_call_status.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {
enum LinalgType {
F32 = 0,
F64 = 1,
};
struct CholeskyUpdateDescriptor {
LinalgType linalg_type;
std::int64_t matrix_size; // leading dim (N) for a square (NxN)matrix
};
void LaunchCholeskyUpdateKernel(
gpuStream_t stream, void** buffers, CholeskyUpdateDescriptor descriptor);
void CholeskyUpdate(gpuStream_t stream, void** buffers,
const char* opaque, size_t opaque_len,
XlaCustomCallStatus* status);
} // namespace JAX_GPU_NAMESPACE
} // namespace jax
#endif // JAXLIB_GPU_CHOLESKY_UPDATE_KERNEL_H_

@ -17,8 +17,7 @@ limitations under the License.
// JAX-generated HLO code from outside of JAX.
#include "jaxlib/gpu/blas_kernels.h"
#include "jaxlib/gpu/cholesky_update_kernel.h"
#include "jaxlib/gpu/lu_pivot_kernels.h"
#include "jaxlib/gpu/linalg_kernels.h"
#include "jaxlib/gpu/prng_kernels.h"
#include "jaxlib/gpu/rnn_kernels.h"
#include "jaxlib/gpu/solver_kernels.h"
@ -39,16 +38,10 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_geqrf_batched", GeqrfBatched,
"CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn", RNNForward, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn_bwd", RNNBackward, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_cholesky_update",
CholeskyUpdate, "CUDA");
// TODO(b/350111820): use the new FFI registration mechanism
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_lu_pivots_to_permutation",
LuPivotsToPermutation, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_cholesky_update", CholeskyUpdate,
"CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_threefry2x32", ThreeFry2x32,
"CUDA");
// TODO(b/350111820): use the new FFI registration mechanism
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_threefry2x32_ffi",
ThreeFry2x32Ffi, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_getrf", Getrf, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_geqrf", Geqrf, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_csrlsvqr", Csrlsvqr, "CUDA");
@ -59,6 +52,11 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_sytrd", Sytrd, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvd", Gesvd, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvdj", Gesvdj, "CUDA");
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cu_lu_pivots_to_permutation",
"CUDA", LuPivotsToPermutation);
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cu_threefry2x32_ffi", "CUDA",
ThreeFry2x32Ffi);
#if JAX_CUSPARSE_11300
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_todense", CsrToDense,
"CUDA");

@ -14,9 +14,7 @@ limitations under the License.
==============================================================================*/
#include "nanobind/nanobind.h"
#include "jaxlib/gpu/cholesky_update_kernel.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/lu_pivot_kernels.h"
#include "jaxlib/gpu/linalg_kernels.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/kernel_nanobind_helpers.h"
#include "xla/tsl/python/lib/core/numpy.h"
@ -27,13 +25,10 @@ namespace {
namespace nb = nanobind;
nb::bytes BuildCholeskyUpdateDescriptor(
dtype np_type,
std::int64_t matrix_size) {
LinalgType linalg_type = (
np_type.itemsize() == 4 ? LinalgType::F32 : LinalgType::F64);
nb::bytes BuildCholeskyUpdateDescriptor(dtype np_type,
std::int64_t matrix_size) {
LinalgType linalg_type =
(np_type.itemsize() == 4 ? LinalgType::F32 : LinalgType::F64);
return PackDescriptor(CholeskyUpdateDescriptor{linalg_type, matrix_size});
}
@ -43,8 +38,9 @@ NB_MODULE(_linalg, m) {
m.def("registrations", []() {
nb::dict dict;
dict[JAX_GPU_PREFIX "_lu_pivots_to_permutation"] =
nb::capsule(reinterpret_cast<void*>(+LuPivotsToPermutation));
dict["cu_cholesky_update"] = EncapsulateFunction(CholeskyUpdate);
EncapsulateFfiHandler(LuPivotsToPermutation);
dict[JAX_GPU_PREFIX "_cholesky_update"] =
EncapsulateFunction(CholeskyUpdate);
return dict;
});
m.def("build_cholesky_update_descriptor", &BuildCholeskyUpdateDescriptor);

@ -13,43 +13,54 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/gpu/lu_pivot_kernels.h"
#include "jaxlib/gpu/linalg_kernels.h"
#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <string>
#include <string_view>
#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "jaxlib/ffi_helpers.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/kernel_helpers.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"
#include "xla/service/custom_call_status.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {
namespace ffi = xla::ffi;
template <typename T>
inline absl::StatusOr<T> MaybeCastNoOverflow(
std::int64_t value, const std::string& source = __FILE__) {
if constexpr (sizeof(T) == sizeof(std::int64_t)) {
return value;
} else {
if (value > std::numeric_limits<T>::max()) [[unlikely]] {
return absl::InvalidArgumentError(absl::StrFormat(
"%s: Value (=%d) exceeds the maximum representable value of the "
"desired type",
source, value));
}
return static_cast<T>(value);
namespace {
absl::Status CholeskyUpdateImpl(gpuStream_t stream, void** buffers,
const char* opaque, std::size_t opaque_len) {
auto s = UnpackDescriptor<CholeskyUpdateDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const CholeskyUpdateDescriptor& d = **s;
LaunchCholeskyUpdateKernel(stream, buffers, d);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError()));
return absl::OkStatus();
}
} // namespace
void CholeskyUpdate(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CholeskyUpdateImpl(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
std::string_view message = s.message();
XlaCustomCallStatusSetFailure(status, message.data(), message.length());
}
}
namespace {
ffi::Error LuPivotsToPermutationImpl(
gpuStream_t stream, std::int32_t permutation_size,
ffi::Buffer<ffi::DataType::S32> pivots,
@ -81,6 +92,14 @@ ffi::Error LuPivotsToPermutationImpl(
}
return ffi::Error::Success();
}
} // namespace
XLA_FFI_DEFINE_HANDLER_SYMBOL(LuPivotsToPermutation, LuPivotsToPermutationImpl,
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<gpuStream_t>>()
.Attr<std::int32_t>("permutation_size")
.Arg<ffi::Buffer<ffi::DataType::S32>>()
.Ret<ffi::Buffer<ffi::DataType::S32>>());
} // namespace JAX_GPU_NAMESPACE
} // namespace jax

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/gpu/lu_pivot_kernels.h"
#include "jaxlib/gpu/linalg_kernels.h"
#include <array>
#include <cstdint>
@ -21,8 +21,110 @@ limitations under the License.
#include "jaxlib/gpu/vendor.h"
#ifdef JAX_GPU_HIP
#include "rocm/include/hip/amd_detail/amd_hip_cooperative_groups.h"
#else // JAX_GPU_CUDA
#include "third_party/gpus/cuda/include/cooperative_groups.h"
#endif
namespace cg = cooperative_groups;
namespace jax {
namespace JAX_GPU_NAMESPACE {
namespace {
template <typename T>
__device__ void drotg(T* da, T* db, T* c, T* s) {
if (*db == 0) {
*c = 1.;
*s = 0.;
return;
}
T denominator = max(abs(*da), abs(*db));
T a = *da / denominator;
T b = *db / denominator;
T rh = rhypot(a, b);
*c = a * rh;
*s = -(b * rh);
return;
}
template <typename T>
__global__ void CholeskyUpdateKernel(T* rMatrix, T* uVector, int nSize) {
cg::grid_group grid = cg::this_grid();
int k = grid.thread_rank();
T c, s;
for (int step = 0; step < 2 * nSize; ++step) {
grid.sync();
int i = step - k;
if (i < k || i >= nSize || k >= nSize) {
continue;
}
if (i == k) {
drotg(rMatrix + k * nSize + k, uVector + k, &c, &s);
}
T r_i = c * rMatrix[k * nSize + i] - s * uVector[i];
uVector[i] = s * rMatrix[k * nSize + i] + c * uVector[i];
rMatrix[k * nSize + i] = r_i;
}
}
} // namespace
template <typename T>
void LaunchCholeskyUpdateKernelBody(gpuStream_t stream, void** buffers,
int grid_dim, int block_dim, int nSize) {
T* rMatrix = reinterpret_cast<T*>(buffers[2]);
T* uVector = reinterpret_cast<T*>(buffers[3]);
void* arg_ptrs[3] = {
reinterpret_cast<void*>(&rMatrix),
reinterpret_cast<void*>(&uVector),
reinterpret_cast<void*>(&nSize),
};
#ifdef JAX_GPU_HIP
hipLaunchCooperativeKernel((void*)CholeskyUpdateKernel<T>, grid_dim,
block_dim, arg_ptrs,
/*dynamic_shared_mem_bytes=*/0, stream);
#else // JAX_GPU_CUDA
cudaLaunchCooperativeKernel((void*)CholeskyUpdateKernel<T>, grid_dim,
block_dim, arg_ptrs,
/*dynamic_shared_mem_bytes=*/0, stream);
#endif
}
void LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers,
CholeskyUpdateDescriptor descriptor) {
int nSize = descriptor.matrix_size;
LinalgType type = descriptor.linalg_type;
int dev = 0;
#ifdef JAX_GPU_HIP
hipDeviceProp_t deviceProp;
hipGetDeviceProperties(&deviceProp, dev);
#else // JAX_GPU_CUDA
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, dev);
#endif
int block_dim = deviceProp.maxThreadsPerBlock;
int grid_dim = deviceProp.multiProcessorCount;
switch (type) {
case LinalgType::F64:
LaunchCholeskyUpdateKernelBody<double>(stream, buffers, grid_dim,
block_dim, nSize);
break;
case LinalgType::F32:
LaunchCholeskyUpdateKernelBody<float>(stream, buffers, grid_dim,
block_dim, nSize);
break;
}
}
namespace {
__device__ void ComputePermutation(const std::int32_t* pivots,

@ -13,19 +13,37 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_GPU_LU_PIVOT_KERNELS_H_
#define JAXLIB_GPU_LU_PIVOT_KERNELS_H_
#ifndef JAXLIB_GPU_LINALG_KERNELS_H_
#define JAXLIB_GPU_LINALG_KERNELS_H_
#include <cstddef>
#include <cstdint>
#include "jaxlib/gpu/vendor.h"
#include "xla/ffi/api/ffi.h"
#include "xla/service/custom_call_status.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {
namespace ffi = xla::ffi;
enum LinalgType {
F32 = 0,
F64 = 1,
};
struct CholeskyUpdateDescriptor {
LinalgType linalg_type;
std::int64_t matrix_size; // leading dim (N) for a square (NxN)matrix
};
void LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers,
CholeskyUpdateDescriptor descriptor);
void CholeskyUpdate(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
void LaunchLuPivotsToPermutationKernel(gpuStream_t stream,
std::int64_t batch_size,
std::int32_t pivot_size,
@ -33,19 +51,9 @@ void LaunchLuPivotsToPermutationKernel(gpuStream_t stream,
const std::int32_t* pivots,
std::int32_t* permutation);
ffi::Error LuPivotsToPermutationImpl(
gpuStream_t stream, std::int32_t permutation_size,
ffi::Buffer<ffi::DataType::S32> pivots,
ffi::Result<ffi::Buffer<ffi::DataType::S32>> permutation);
XLA_FFI_DEFINE_HANDLER(LuPivotsToPermutation, LuPivotsToPermutationImpl,
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<gpuStream_t>>()
.Attr<std::int32_t>("permutation_size")
.Arg<ffi::Buffer<ffi::DataType::S32>>()
.Ret<ffi::Buffer<ffi::DataType::S32>>());
XLA_FFI_DECLARE_HANDLER_SYMBOL(LuPivotsToPermutation);
} // namespace JAX_GPU_NAMESPACE
} // namespace jax
#endif // JAXLIB_GPU_LU_PIVOT_KERNELS_H_
#endif // JAXLIB_GPU_LINALG_KERNELS_H_

@ -29,7 +29,7 @@ std::string BuildThreeFry2x32Descriptor(std::int64_t n) {
nb::dict Registrations() {
nb::dict dict;
dict[JAX_GPU_PREFIX "_threefry2x32_ffi"] =
EncapsulateFunction(ThreeFry2x32Ffi);
EncapsulateFfiHandler(ThreeFry2x32Ffi);
// TODO(b/338022728): remove after 3 weeks
dict[JAX_GPU_PREFIX "_threefry2x32"] = EncapsulateFunction(ThreeFry2x32);
return dict;

@ -21,6 +21,7 @@ limitations under the License.
#include <string_view>
#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/kernel_helpers.h"
#include "xla/ffi/api/c_api.h"
@ -56,32 +57,36 @@ void ThreeFry2x32(gpuStream_t stream, void** buffers, const char* opaque,
}
}
XLA_FFI_Error* ThreeFry2x32Ffi(XLA_FFI_CallFrame* call_frame) {
static const auto* kImpl =
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<gpuStream_t>>()
.Arg<ffi::Buffer<ffi::DataType::U32>>()
.Arg<ffi::Buffer<ffi::DataType::U32>>()
.Arg<ffi::Buffer<ffi::DataType::U32>>()
.Arg<ffi::Buffer<ffi::DataType::U32>>()
.Ret<ffi::Buffer<ffi::DataType::U32>>()
.Ret<ffi::Buffer<ffi::DataType::U32>>()
.To([](gpuStream_t stream, auto keys0, auto keys1, auto data0,
auto data1, auto out0, auto out1) -> ffi::Error {
std::int64_t n = out0->element_count();
LaunchThreeFry2x32KernelFfi(stream, n, keys0.typed_data(),
keys1.typed_data(), data0.typed_data(),
data1.typed_data(), out0->typed_data(),
out1->typed_data());
if (auto status = JAX_AS_STATUS(gpuGetLastError()); !status.ok()) {
return ffi::Error(static_cast<XLA_FFI_Error_Code>(status.code()),
std::string(status.message()));
}
return ffi::Error::Success();
})
.release();
return kImpl->Call(call_frame);
namespace {
ffi::Error ThreeFry2x32Impl(gpuStream_t stream,
ffi::Buffer<ffi::DataType::U32> keys0,
ffi::Buffer<ffi::DataType::U32> keys1,
ffi::Buffer<ffi::DataType::U32> data0,
ffi::Buffer<ffi::DataType::U32> data1,
ffi::Result<ffi::Buffer<ffi::DataType::U32>> out0,
ffi::Result<ffi::Buffer<ffi::DataType::U32>> out1) {
std::int64_t n =
absl::c_accumulate(out0->dimensions(), 1, std::multiplies<int64_t>());
LaunchThreeFry2x32KernelFfi(stream, n, keys0.typed_data(), keys1.typed_data(),
data0.typed_data(), data1.typed_data(),
out0->typed_data(), out1->typed_data());
if (auto status = JAX_AS_STATUS(gpuGetLastError()); !status.ok()) {
return ffi::Error(static_cast<XLA_FFI_Error_Code>(status.code()),
std::string(status.message()));
}
return ffi::Error::Success();
}
} // namespace
XLA_FFI_DEFINE_HANDLER_SYMBOL(ThreeFry2x32Ffi, ThreeFry2x32Impl,
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<gpuStream_t>>()
.Arg<ffi::Buffer<ffi::DataType::U32>>()
.Arg<ffi::Buffer<ffi::DataType::U32>>()
.Arg<ffi::Buffer<ffi::DataType::U32>>()
.Arg<ffi::Buffer<ffi::DataType::U32>>()
.Ret<ffi::Buffer<ffi::DataType::U32>>()
.Ret<ffi::Buffer<ffi::DataType::U32>>());
} // namespace JAX_GPU_NAMESPACE
} // namespace jax

@ -18,10 +18,9 @@ limitations under the License.
#include <cstddef>
#include <cstdint>
#include <string>
#include "jaxlib/gpu/vendor.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"
#include "xla/service/custom_call_status.h"
namespace jax {
@ -40,14 +39,14 @@ void LaunchThreeFry2x32Kernel(gpuStream_t stream, void** buffers,
void ThreeFry2x32(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
XLA_FFI_Error* ThreeFry2x32Ffi(XLA_FFI_CallFrame* call_frame);
void LaunchThreeFry2x32KernelFfi(gpuStream_t stream,
std::int64_t n,
std::uint32_t *keys0, std::uint32_t *keys1,
std::uint32_t *data0, std::uint32_t *data1,
std::uint32_t *out0, std::uint32_t *out1);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ThreeFry2x32Ffi);
} // namespace JAX_GPU_NAMESPACE
} // namespace jax

@ -17,10 +17,12 @@ limitations under the License.
#define JAXLIB_KERNEL_NANOBIND_HELPERS_H_
#include <string>
#include <type_traits>
#include "nanobind/nanobind.h"
#include "absl/base/casts.h"
#include "jaxlib/kernel_helpers.h"
#include "xla/ffi/api/c_api.h"
#include "xla/tsl/python/lib/core/numpy.h" // NOLINT
namespace jax {
@ -58,6 +60,13 @@ nanobind::capsule EncapsulateFunction(T* fn) {
"xla._CUSTOM_CALL_TARGET");
}
template <typename T>
nanobind::capsule EncapsulateFfiHandler(T* fn) {
static_assert(std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>,
"Encapsulated function must be an XLA FFI handler");
return nanobind::capsule(absl::bit_cast<void*>(fn));
}
} // namespace jax
#endif // JAXLIB_KERNEL_NANOBIND_HELPERS_H_

@ -194,13 +194,16 @@ pybind_extension(
)
cc_library(
name = "hip_lu_pivot_kernels",
srcs = ["//jaxlib/gpu:lu_pivot_kernels.cc"],
hdrs = ["//jaxlib/gpu:lu_pivot_kernels.h"],
name = "hip_linalg_kernels",
srcs = ["//jaxlib/gpu:linalg_kernels.cc"],
hdrs = ["//jaxlib/gpu:linalg_kernels.h"],
features = ["-use_header_modules"],
deps = [
":hip_gpu_kernel_helpers",
":hip_lu_pivot_kernels_impl",
":hip_linalg_kernels_impl",
":hip_vendor",
"//jaxlib:ffi_helpers",
"//jaxlib:kernel_helpers",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
@ -212,55 +215,18 @@ cc_library(
)
rocm_library(
name = "hip_lu_pivot_kernels_impl",
srcs = ["//jaxlib/gpu:lu_pivot_kernels.cu.cc"],
hdrs = ["//jaxlib/gpu:lu_pivot_kernels.h"],
name = "hip_linalg_kernels_impl",
srcs = ["//jaxlib/gpu:linalg_kernels.cu.cc"],
hdrs = ["//jaxlib/gpu:linalg_kernels.h"],
deps = [
":hip_gpu_kernel_helpers",
":hip_vendor",
"@local_config_rocm//rocm:rocm_headers",
"@xla//xla/ffi/api:ffi",
],
)
cc_library(
name = "cholesky_update_kernel",
srcs = [
"//jaxlib/gpu:cholesky_update_kernel.cc",
],
hdrs = ["//jaxlib/gpu:cholesky_update_kernel.h"],
features = ["-use_header_modules"],
deps = [
":cholesky_update_kernel_impl",
":hip_gpu_kernel_helpers",
":hip_vendor",
":hipsolver_kernels",
"//jaxlib:kernel_helpers",
"@xla//xla/service:custom_call_status",
"@com_google_absl//absl/status",
"@local_config_rocm//rocm:rocm_headers",
],
)
rocm_library(
name = "cholesky_update_kernel_impl",
srcs = [
"//jaxlib/gpu:cholesky_update_kernel.cu.cc",
],
hdrs = [
"//jaxlib/gpu:cholesky_update_kernel.h",
],
deps = [
":hip_gpu_kernel_helpers",
":hip_vendor",
":hipsolver_kernels",
"//jaxlib:kernel_helpers",
"@xla//xla/service:custom_call_status",
"@local_config_rocm//rocm:rocm_headers",
],
)
pybind_extension(
name = "_linalg",
srcs = ["//jaxlib/gpu:linalg.cc"],
@ -271,10 +237,8 @@ pybind_extension(
features = ["-use_header_modules"],
module_name = "_linalg",
deps = [
":cholesky_update_kernel",
":hip_gpu_kernel_helpers",
":hip_lu_pivot_kernels",
":hip_lu_pivot_kernels_impl",
":hip_linalg_kernels",
":hip_vendor",
"//jaxlib:kernel_nanobind_helpers",
"@xla//xla/tsl/python/lib/core:numpy",
@ -291,11 +255,13 @@ cc_library(
":hip_gpu_kernel_helpers",
":hip_prng_kernels_impl",
":hip_vendor",
"//jaxlib:ffi_helpers",
"//jaxlib:kernel_helpers",
"@xla//xla/ffi/api:c_api",
"@xla//xla/ffi/api:ffi",
"@local_config_rocm//rocm:rocm_headers",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status",
"@xla//xla/service:custom_call_status",
],
)
@ -308,7 +274,6 @@ rocm_library(
":hip_gpu_kernel_helpers",
":hip_vendor",
"//jaxlib:kernel_helpers",
"@xla//xla/ffi/api:c_api",
"@xla//xla/ffi/api:ffi",
"@local_config_rocm//rocm:rocm_headers",
"@xla//xla/service:custom_call_status",