diff --git a/jaxlib/gpu/rnn.cc b/jaxlib/gpu/rnn.cc
index c88b164e6..eaa815d33 100644
--- a/jaxlib/gpu/rnn.cc
+++ b/jaxlib/gpu/rnn.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <cstddef>
+
 #include "nanobind/nanobind.h"
 #include "nanobind/stl/pair.h"
 #include "jaxlib/absl_status_casters.h"
@@ -29,7 +31,7 @@ namespace nb = nanobind;
 nb::bytes BuildRnnDescriptor(int input_size, int hidden_size, int num_layers,
                              int batch_size, int max_seq_length, float dropout,
                              bool bidirectional, bool cudnn_allow_tf32,
-                             int workspace_size, int reserve_space_size) {
+                             size_t workspace_size, size_t reserve_space_size) {
   return PackDescriptor(RnnDescriptor{
       input_size, hidden_size, num_layers, batch_size, max_seq_length, dropout,
       bidirectional, cudnn_allow_tf32, workspace_size, reserve_space_size});
diff --git a/jaxlib/gpu/rnn_kernels.cc b/jaxlib/gpu/rnn_kernels.cc
index 89a6d0a30..e9820bc31 100644
--- a/jaxlib/gpu/rnn_kernels.cc
+++ b/jaxlib/gpu/rnn_kernels.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "jaxlib/gpu/rnn_kernels.h"
 
+#include <cstddef>
 #include <utility>
 #include <vector>
 
@@ -71,7 +72,7 @@ template <>
 
 namespace JAX_GPU_NAMESPACE {
 
-static absl::StatusOr<std::pair<int, int>>
+static absl::StatusOr<std::pair<size_t, size_t>>
 DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
                                        int num_layers, int batch_size,
                                        int max_seq_length, float dropout,
@@ -174,7 +175,7 @@ DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
   return std::make_pair(workSpaceSize, reserveSpaceSize);
 }
 
-absl::StatusOr<std::pair<int, int>> RnnComputeWorkspaceReserveSpaceSizes(
+absl::StatusOr<std::pair<size_t, size_t>> RnnComputeWorkspaceReserveSpaceSizes(
     int input_size, int hidden_size, int num_layers, int batch_size,
     int max_seq_length, float dropout, bool bidirectional,
     bool cudnn_allow_tf32) {
diff --git a/jaxlib/gpu/rnn_kernels.h b/jaxlib/gpu/rnn_kernels.h
index 468c02eac..e95b77883 100644
--- a/jaxlib/gpu/rnn_kernels.h
+++ b/jaxlib/gpu/rnn_kernels.h
@@ -16,6 +16,8 @@ limitations under the License.
 #ifndef JAXLIB_GPU_RNN_KERNELS_H_
 #define JAXLIB_GPU_RNN_KERNELS_H_
 
+#include <cstddef>
+
 #include "absl/status/statusor.h"
 #include "jaxlib/gpu/vendor.h"
 #include "xla/ffi/api/ffi.h"
@@ -34,12 +36,12 @@ struct RnnDescriptor {
   float dropout;
   int bidirectional;
   int cudnn_allow_tf32;
-  int workspace_size;
-  int reserve_space_size;
+  size_t workspace_size;
+  size_t reserve_space_size;
 };
 
 // Return (workspace size, reserve space size).
-absl::StatusOr<std::pair<int, int>> RnnComputeWorkspaceReserveSpaceSizes(
+absl::StatusOr<std::pair<size_t, size_t>> RnnComputeWorkspaceReserveSpaceSizes(
     int input_size, int hidden_size, int num_layers, int batch_size,
     int max_seq_length, float dropout, bool bidirectional,
     bool cudnn_allow_tf32);
diff --git a/tests/experimental_rnn_test.py b/tests/experimental_rnn_test.py
index 376a9b1a1..7fa3b93f3 100644
--- a/tests/experimental_rnn_test.py
+++ b/tests/experimental_rnn_test.py
@@ -213,8 +213,36 @@ class RnnTest(jtu.JaxTestCase):
     k = jax.random.split(jax.random.PRNGKey(1), 4)
     stablehlo = jax.jit(f).lower(*k).as_text("stablehlo")
-    self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"',
-                  stablehlo)
+    if jtu.jaxlib_version() <= (0, 5, 2):
+      self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"',
+                    stablehlo)
+    else:
+      self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00\\00\\00\\00\\00@\\01\\00\\00\\00\\00\\00\\00"',
+                    stablehlo)
+
+  @jtu.run_on_devices("cuda")
+  def test_no_workspace_overflow(self):
+    if jtu.jaxlib_version() <= (0, 5, 2):
+      self.skipTest("Older versions fail because of integer overflow.")
+
+    # Problem sizes known to cause overflows on older versions.
+    batch_size, max_seq_length, input_size = 256, 500, 512
+    num_layers, hidden_size = 1, 256
+    num_params = rnn.get_num_params_in_lstm(
+        input_size, hidden_size, num_layers, True)
+    x = jax.ShapeDtypeStruct(
+        (batch_size, max_seq_length, input_size), jnp.float32)
+    h_0 = jax.ShapeDtypeStruct(
+        (2 * num_layers, batch_size, hidden_size), jnp.float32)
+    c_0 = jax.ShapeDtypeStruct(
+        (2 * num_layers, batch_size, hidden_size), jnp.float32)
+    weights = jax.ShapeDtypeStruct((num_params,), jnp.float32)
+    seq_lengths = jax.ShapeDtypeStruct((batch_size,), jnp.int32)
+    fun = jax.jit(partial(
+        rnn.lstm, input_size=input_size, hidden_size=hidden_size,
+        num_layers=num_layers, dropout=0.0, bidirectional=True))
+    fun.lower(x, h_0, c_0, weights, seq_lengths)  # Doesn't crash.
+
 
 
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
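
Context for the fix (not part of the patch): cuDNN reports RNN scratch sizes as size_t (e.g. via cudnnGetRNNTempSpaceSizes in the v8 API), but the old descriptor narrowed them to int, so any byte count above INT_MAX (2^31 - 1, about 2 GiB) wrapped around. The bidirectional LSTM in test_no_workspace_overflow is sized so that its reserve space plausibly lands in that range. The standalone C++ sketch below only illustrates the narrowing hazard; it assumes nothing about jaxlib internals beyond the int fields visible in the diff, and the size value is a stand-in, not a measured cuDNN result.

// Illustrative sketch only (not jaxlib code): narrowing a size_t byte count
// to int corrupts any value above INT_MAX, which is the overflow the patch
// fixes by widening the descriptor fields and the StatusOr pair to size_t.
#include <climits>
#include <cstddef>
#include <cstdio>

int main() {
  // Stand-in for a reserve-space size that a size query reports as size_t
  // and that exceeds INT_MAX for large problem sizes.
  size_t reserve_space_size = static_cast<size_t>(INT_MAX) + 513;

  // What the pre-fix int fields effectively did. The conversion is
  // implementation-defined before C++20 and wraps modularly since C++20;
  // on typical two's-complement targets the result is negative.
  int truncated = static_cast<int>(reserve_space_size);

  std::printf("reported:  %zu bytes\n", reserve_space_size);
  std::printf("as int:    %d bytes\n", truncated);
  return 0;
}

With the size_t plumbing in place end to end, PackDescriptor round-trips the exact byte counts the library reported, and lowering the problem sizes in test_no_workspace_overflow no longer trips the overflow that the skip message describes for jaxlib <= 0.5.2.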