Fix integer overflow in workspace size computations for experimental.rnn.*.

PiperOrigin-RevId: 736139471
Dan Foreman-Mackey 2025-03-12 08:21:16 -07:00 committed by jax authors
parent e33f3fc48b
commit 8b7cfcb33c
4 changed files with 41 additions and 8 deletions
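For context: the RNN descriptor previously stored the cuDNN workspace and reserve-space byte counts in 32-bit ints, so sufficiently large problem sizes wrapped those counts around before they reached the kernel. A rough sketch of that failure mode follows; the byte count used here is a made-up value for illustration, not one computed by cuDNN.

# Illustration only: models what storing an oversized byte count in a C `int` does.
INT32_MAX = 2**31 - 1

def wrap_to_int32(nbytes):
  # Two's-complement wrap-around of a 64-bit value truncated to a 32-bit int.
  return (nbytes + 2**31) % 2**32 - 2**31

reserve_space_bytes = 3 * 1024**3          # ~3 GiB; hypothetical, larger than INT32_MAX
print(reserve_space_bytes > INT32_MAX)     # True
print(wrap_to_int32(reserve_space_bytes))  # -1073741824: a nonsense size

Switching the descriptor fields and the size-computation return type to size_t keeps the full 64-bit byte counts intact end to end.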


@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include <cstddef>
+
 #include "nanobind/nanobind.h"
 #include "nanobind/stl/pair.h"
 #include "jaxlib/absl_status_casters.h"
@@ -29,7 +31,7 @@ namespace nb = nanobind;
 nb::bytes BuildRnnDescriptor(int input_size, int hidden_size, int num_layers,
                              int batch_size, int max_seq_length, float dropout,
                              bool bidirectional, bool cudnn_allow_tf32,
-                             int workspace_size, int reserve_space_size) {
+                             size_t workspace_size, size_t reserve_space_size) {
   return PackDescriptor(RnnDescriptor{
       input_size, hidden_size, num_layers, batch_size, max_seq_length, dropout,
       bidirectional, cudnn_allow_tf32, workspace_size, reserve_space_size});


@@ -15,6 +15,7 @@ limitations under the License.
 #include "jaxlib/gpu/rnn_kernels.h"
 
+#include <cstddef>
 #include <utility>
 #include <vector>
@@ -71,7 +72,7 @@ template <>
 namespace JAX_GPU_NAMESPACE {
 
-static absl::StatusOr<std::pair<int, int>>
+static absl::StatusOr<std::pair<size_t, size_t>>
 DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
                                        int num_layers, int batch_size,
                                        int max_seq_length, float dropout,
@@ -174,7 +175,7 @@ DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
   return std::make_pair(workSpaceSize, reserveSpaceSize);
 }
 
-absl::StatusOr<std::pair<int, int>> RnnComputeWorkspaceReserveSpaceSizes(
+absl::StatusOr<std::pair<size_t, size_t>> RnnComputeWorkspaceReserveSpaceSizes(
     int input_size, int hidden_size, int num_layers, int batch_size,
     int max_seq_length, float dropout, bool bidirectional,
     bool cudnn_allow_tf32) {


@@ -16,6 +16,8 @@ limitations under the License.
 #ifndef JAXLIB_GPU_RNN_KERNELS_H_
 #define JAXLIB_GPU_RNN_KERNELS_H_
 
+#include <cstddef>
+
 #include "absl/status/statusor.h"
 #include "jaxlib/gpu/vendor.h"
 #include "xla/ffi/api/ffi.h"
@@ -34,12 +36,12 @@ struct RnnDescriptor {
   float dropout;
   int bidirectional;
   int cudnn_allow_tf32;
-  int workspace_size;
-  int reserve_space_size;
+  size_t workspace_size;
+  size_t reserve_space_size;
 };
 
 // Return (workspace size, reserve space size).
-absl::StatusOr<std::pair<int, int>> RnnComputeWorkspaceReserveSpaceSizes(
+absl::StatusOr<std::pair<size_t, size_t>> RnnComputeWorkspaceReserveSpaceSizes(
     int input_size, int hidden_size, int num_layers, int batch_size,
     int max_seq_length, float dropout, bool bidirectional,
     bool cudnn_allow_tf32);


@@ -213,8 +213,36 @@ class RnnTest(jtu.JaxTestCase):
     k = jax.random.split(jax.random.PRNGKey(1), 4)
     stablehlo = jax.jit(f).lower(*k).as_text("stablehlo")
-    self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"',
-                  stablehlo)
+    if jtu.jaxlib_version() <= (0, 5, 2):
+      self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"',
+                    stablehlo)
+    else:
+      self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00\\00\\00\\00\\00@\\01\\00\\00\\00\\00\\00\\00"',
+                    stablehlo)
+
+  @jtu.run_on_devices("cuda")
+  def test_no_workspace_overflow(self):
+    if jtu.jaxlib_version() <= (0, 5, 2):
+      self.skipTest("Older versions fail because of integer overflow.")
+
+    # Problem sizes known to cause overflows on older versions.
+    batch_size, max_seq_length, input_size = 256, 500, 512
+    num_layers, hidden_size = 1, 256
+
+    num_params = rnn.get_num_params_in_lstm(
+        input_size, hidden_size, num_layers, True)
+    x = jax.ShapeDtypeStruct(
+        (batch_size, max_seq_length, input_size), jnp.float32)
+    h_0 = jax.ShapeDtypeStruct(
+        (2 * num_layers, batch_size, hidden_size), jnp.float32)
+    c_0 = jax.ShapeDtypeStruct(
+        (2 * num_layers, batch_size, hidden_size), jnp.float32)
+    weights = jax.ShapeDtypeStruct((num_params,), jnp.float32)
+    seq_lengths = jax.ShapeDtypeStruct((batch_size,), jnp.int32)
+    fun = jax.jit(partial(
+        rnn.lstm, input_size=input_size, hidden_size=hidden_size,
+        num_layers=num_layers, dropout=0.0, bidirectional=True))
+    fun.lower(x, h_0, c_0, weights, seq_lengths)  # Doesn't crash.
 
 
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
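For reference, the shapes exercised by the new regression test correspond to a user-level call along these lines. This is a hedged sketch, assuming a CUDA device and float32 inputs; it is not a script taken from the commit, and the random-data setup is illustrative only.

# Hypothetical end-to-end repro of the overflow via the public API; requires a CUDA GPU.
from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental import rnn

batch_size, max_seq_length, input_size = 256, 500, 512
num_layers, hidden_size = 1, 256

keys = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(keys[0], (batch_size, max_seq_length, input_size), jnp.float32)
h_0 = jax.random.normal(keys[1], (2 * num_layers, batch_size, hidden_size), jnp.float32)
c_0 = jax.random.normal(keys[2], (2 * num_layers, batch_size, hidden_size), jnp.float32)
num_params = rnn.get_num_params_in_lstm(input_size, hidden_size, num_layers, True)
weights = jax.random.normal(keys[3], (num_params,), jnp.float32)
seq_lengths = jnp.full((batch_size,), max_seq_length, dtype=jnp.int32)

lstm = jax.jit(partial(rnn.lstm, input_size=input_size, hidden_size=hidden_size,
                       num_layers=num_layers, dropout=0.0, bidirectional=True))
out = lstm(x, h_0, c_0, weights, seq_lengths)  # Failed on affected jaxlib versions when the workspace sizes were computed.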