Fix integer overflow in workspace size computations for experimental.rnn.*.
PiperOrigin-RevId: 736139471
commit 8b7cfcb33c
parent e33f3fc48b
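For context on the failure mode: cuDNN reports workspace and reserve-space requirements as byte counts, and for problem shapes like those in the regression test added below they can exceed INT_MAX (2,147,483,647), so carrying them through int wraps them before they ever reach the kernel. A minimal magnitude check (a sketch; the trailing factor of 5 is illustrative, not cuDNN's actual workspace formula):

    #include <climits>
    #include <cstdio>

    int main() {
      // Shapes from the regression test below: bidirectional LSTM,
      // batch 256, seq 500, input 512, float32. The factor of 5 is
      // illustrative only, not cuDNN's actual workspace formula.
      unsigned long long bytes = 2ULL * 256 * 500 * 512 * 4 * 5;
      std::printf("approx bytes: %llu\n", bytes);  // 2621440000
      std::printf("INT_MAX:      %d\n", INT_MAX);  // 2147483647
      return 0;
    }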
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <cstddef>
+
 #include "nanobind/nanobind.h"
 #include "nanobind/stl/pair.h"
 #include "jaxlib/absl_status_casters.h"
@@ -29,7 +31,7 @@ namespace nb = nanobind;
 nb::bytes BuildRnnDescriptor(int input_size, int hidden_size, int num_layers,
                              int batch_size, int max_seq_length, float dropout,
                              bool bidirectional, bool cudnn_allow_tf32,
-                             int workspace_size, int reserve_space_size) {
+                             size_t workspace_size, size_t reserve_space_size) {
   return PackDescriptor(RnnDescriptor{
       input_size, hidden_size, num_layers, batch_size, max_seq_length, dropout,
       bidirectional, cudnn_allow_tf32, workspace_size, reserve_space_size});
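The binding-level effect of the signature change above, as a hedged sketch (hypothetical module name demo; relies on nanobind's overflow-checked integer casters): an int parameter rejects Python ints above 2**31 - 1 with a TypeError, while size_t accepts the full unsigned 64-bit range on LP64 platforms, so BuildRnnDescriptor can now be handed the large sizes computed on the Python side.

    #include <cstddef>

    #include "nanobind/nanobind.h"

    namespace nb = nanobind;

    NB_MODULE(demo, m) {
      // demo.take_int(2**31) raises TypeError at the conversion;
      // demo.take_size_t(2**31) round-trips the value unchanged.
      m.def("take_int", [](int n) { return n; });
      m.def("take_size_t", [](std::size_t n) { return n; });
    }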
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "jaxlib/gpu/rnn_kernels.h"
 
+#include <cstddef>
 #include <utility>
 #include <vector>
 
@@ -71,7 +72,7 @@ template <>
 
 namespace JAX_GPU_NAMESPACE {
 
-static absl::StatusOr<std::pair<int, int>>
+static absl::StatusOr<std::pair<size_t, size_t>>
 DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
                                        int num_layers, int batch_size,
                                        int max_seq_length, float dropout,
@@ -174,7 +175,7 @@ DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
   return std::make_pair(workSpaceSize, reserveSpaceSize);
 }
 
-absl::StatusOr<std::pair<int, int>> RnnComputeWorkspaceReserveSpaceSizes(
+absl::StatusOr<std::pair<size_t, size_t>> RnnComputeWorkspaceReserveSpaceSizes(
     int input_size, int hidden_size, int num_layers, int batch_size,
     int max_seq_length, float dropout, bool bidirectional,
     bool cudnn_allow_tf32) {
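Why the old return type lost data silently, as a minimal standalone sketch (the locals mirror the make_pair call above; the value is illustrative, and this is not the file's actual code): converting std::pair<size_t, size_t> to std::pair<int, int> goes through pair's converting constructor, which narrows each element with no diagnostic at default warning levels.

    #include <cstddef>
    #include <cstdio>
    #include <utility>

    // Old shape of the return path: size_t byte counts squeezed into
    // std::pair<int, int>. The conversion compiles cleanly and truncates.
    static std::pair<int, int> old_style() {
      std::size_t workSpaceSize = 2'621'440'000;  // > INT_MAX, illustrative
      std::size_t reserveSpaceSize = 320;
      return std::make_pair(workSpaceSize, reserveSpaceSize);
    }

    // New shape: the byte counts survive intact.
    static std::pair<std::size_t, std::size_t> new_style() {
      return std::make_pair(std::size_t{2'621'440'000}, std::size_t{320});
    }

    int main() {
      std::printf("old: %d\n", old_style().first);   // negative on LP64
      std::printf("new: %zu\n", new_style().first);  // 2621440000
    }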
@@ -16,6 +16,8 @@ limitations under the License.
 #ifndef JAXLIB_GPU_RNN_KERNELS_H_
 #define JAXLIB_GPU_RNN_KERNELS_H_
 
+#include <cstddef>
+
 #include "absl/status/statusor.h"
 #include "jaxlib/gpu/vendor.h"
 #include "xla/ffi/api/ffi.h"
@@ -34,12 +36,12 @@ struct RnnDescriptor {
   float dropout;
   int bidirectional;
   int cudnn_allow_tf32;
-  int workspace_size;
-  int reserve_space_size;
+  size_t workspace_size;
+  size_t reserve_space_size;
 };
 
 // Return (workspace size, reserve space size).
-absl::StatusOr<std::pair<int, int>> RnnComputeWorkspaceReserveSpaceSizes(
+absl::StatusOr<std::pair<size_t, size_t>> RnnComputeWorkspaceReserveSpaceSizes(
     int input_size, int hidden_size, int num_layers, int batch_size,
     int max_seq_length, float dropout, bool bidirectional,
     bool cudnn_allow_tf32);
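A knock-on effect of widening the two struct fields (hypothetical mirror structs below; the leading fields follow the PackDescriptor initializer above, and the sizes assume 4-byte int/float and 8-byte size_t with no interior padding): the packed descriptor grows from 40 to 48 bytes, which is exactly why the expected descriptor string in the StableHLO test below gains two four-byte runs of \00.

    #include <cstddef>
    #include <cstdio>

    struct DescriptorOld {  // mirror of the pre-change layout
      int input_size, hidden_size, num_layers, batch_size, max_seq_length;
      float dropout;
      int bidirectional, cudnn_allow_tf32;
      int workspace_size;             // 4 bytes
      int reserve_space_size;         // 4 bytes
    };

    struct DescriptorNew {  // mirror of the post-change layout
      int input_size, hidden_size, num_layers, batch_size, max_seq_length;
      float dropout;
      int bidirectional, cudnn_allow_tf32;
      std::size_t workspace_size;     // 8 bytes
      std::size_t reserve_space_size; // 8 bytes
    };

    int main() {
      std::printf("old: %zu bytes, new: %zu bytes\n",
                  sizeof(DescriptorOld), sizeof(DescriptorNew));  // 40, 48
    }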
@@ -213,8 +213,36 @@ class RnnTest(jtu.JaxTestCase):
 
     k = jax.random.split(jax.random.PRNGKey(1), 4)
     stablehlo = jax.jit(f).lower(*k).as_text("stablehlo")
-    self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"',
-                  stablehlo)
+    if jtu.jaxlib_version() <= (0, 5, 2):
+      self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"',
+                    stablehlo)
+    else:
+      self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00\\00\\00\\00\\00@\\01\\00\\00\\00\\00\\00\\00"',
+                    stablehlo)
+
+  @jtu.run_on_devices("cuda")
+  def test_no_workspace_overflow(self):
+    if jtu.jaxlib_version() <= (0, 5, 2):
+      self.skipTest("Older versions fail because of integer overflow.")
+
+    # Problem sizes known to cause overflows on older versions.
+    batch_size, max_seq_length, input_size = 256, 500, 512
+    num_layers, hidden_size = 1, 256
+    num_params = rnn.get_num_params_in_lstm(
+        input_size, hidden_size, num_layers, True)
+    x = jax.ShapeDtypeStruct(
+        (batch_size, max_seq_length, input_size), jnp.float32)
+    h_0 = jax.ShapeDtypeStruct(
+        (2 * num_layers, batch_size, hidden_size), jnp.float32)
+    c_0 = jax.ShapeDtypeStruct(
+        (2 * num_layers, batch_size, hidden_size), jnp.float32)
+    weights = jax.ShapeDtypeStruct((num_params,), jnp.float32)
+    seq_lengths = jax.ShapeDtypeStruct((batch_size,), jnp.int32)
+    fun = jax.jit(partial(
+        rnn.lstm, input_size=input_size, hidden_size=hidden_size,
+        num_layers=num_layers, dropout=0.0, bidirectional=True))
+    fun.lower(x, h_0, c_0, weights, seq_lengths)  # Doesn't crash.
 
 
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
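Decoding the expectations in the test above (a sketch; the byte values are copied from the assertions and read little-endian, matching how the descriptor is packed on x86): both versions encode workspace_size = 8389440 and reserve_space_size = 320; only the field width changes from 4 to 8 bytes.

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    int main() {
      // Tail of the old expectation: "...@\03\80\00@\01\00\00".
      const unsigned char old_tail[8] = {0x40, 0x03, 0x80, 0x00,
                                         0x40, 0x01, 0x00, 0x00};
      std::int32_t ws32, rs32;
      std::memcpy(&ws32, old_tail, 4);      // assumes a little-endian host
      std::memcpy(&rs32, old_tail + 4, 4);
      std::printf("old: workspace=%d reserve=%d\n", ws32, rs32);  // 8389440 320

      // Tail of the new expectation: same values, now 8 bytes per field.
      const unsigned char new_tail[16] = {0x40, 0x03, 0x80, 0x00, 0, 0, 0, 0,
                                          0x40, 0x01, 0x00, 0x00, 0, 0, 0, 0};
      std::int64_t ws64, rs64;
      std::memcpy(&ws64, new_tail, 8);
      std::memcpy(&rs64, new_tail + 8, 8);
      std::printf("new: workspace=%lld reserve=%lld\n",
                  static_cast<long long>(ws64), static_cast<long long>(rs64));
    }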