Fix struct string encoding non-determinism in the RNN descriptor.

Boolean fields in the descriptor struct led to padding, which let random
bytes into the string representation of the struct and caused the HLO to
vary from run to run.
This commit is contained in:
Ilia Sergachev 2024-11-01 23:42:26 +00:00
parent 640cb009f1
commit f0e1c3cf36
2 changed files with 53 additions and 13 deletions

View File

@ -31,25 +31,26 @@ struct RnnDescriptor {
int batch_size;
int max_seq_length;
float dropout;
bool bidirectional;
bool cudnn_allow_tf32;
int bidirectional;
int cudnn_allow_tf32;
int workspace_size;
int reserve_space_size;
};
// Computes the workspace and reserve-space sizes for an RNN with the given
// configuration (input/hidden sizes, layer count, batch size, maximum
// sequence length, dropout, directionality and TF32 setting).
//
// Returns (workspace size, reserve space size) on success; failures are
// reported through the absl::StatusOr wrapper.
// NOTE(review): the units of the two sizes (presumably bytes) are not visible
// in this header -- confirm against the implementation.
absl::StatusOr<std::pair<int, int>> RnnComputeWorkspaceReserveSpaceSizes(
int input_size, int hidden_size, int num_layers, int batch_size,
int max_seq_length, float dropout, bool bidirectional,
bool cudnn_allow_tf32);
absl::StatusOr<std::pair<int, int>>
RnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
int num_layers, int batch_size,
int max_seq_length, float dropout,
bool bidirectional, bool cudnn_allow_tf32);
// RNN forward entry point. Takes the GPU stream to run on, an array of buffer
// pointers, an opaque byte string of `opaque_len` bytes, and a status
// out-parameter.
// NOTE(review): `opaque` presumably carries a serialized RnnDescriptor and
// `buffers` the operand/result pointers -- confirm in the implementation.
void RNNForward(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
void RNNForward(gpuStream_t stream, void **buffers, const char *opaque,
size_t opaque_len, XlaCustomCallStatus *status);
// RNN backward entry point. Same parameter list as RNNForward: GPU stream,
// array of buffer pointers, opaque byte string with its length, and a status
// out-parameter.
// NOTE(review): `opaque` presumably carries a serialized RnnDescriptor --
// confirm in the implementation.
void RNNBackward(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
void RNNBackward(gpuStream_t stream, void **buffers, const char *opaque,
size_t opaque_len, XlaCustomCallStatus *status);
} // namespace JAX_GPU_NAMESPACE
} // namespace jax
} // namespace JAX_GPU_NAMESPACE
} // namespace jax
#endif // JAXLIB_GPU_RNN_KERNELS_H_
#endif // JAXLIB_GPU_RNN_KERNELS_H_

View File

@ -178,5 +178,44 @@ class RnnTest(jtu.JaxTestCase):
y_padded = y_ref[i, seq_lengths[i]:]
np.testing.assert_allclose(y_padded, jnp.zeros_like(y_padded))
# Regression test: the descriptor bytes embedded in the lowered RNN custom
# call must be exactly reproducible. Lowers a minimal LSTM call to StableHLO
# text and checks that it contains the cudnn_rnn custom call with one exact
# backend_config byte string. The decorator limits the test to "cuda" devices.
@jtu.run_on_devices("cuda")
def test_struct_encoding_determinism(self):
# f builds the smallest possible LSTM problem (every dimension is 1,
# unidirectional, single layer) from four PRNG keys and calls rnn.lstm.
def f(k1, k2, k3, k4):
batch_size = 1
seq_len = 1
input_size = 1
hidden_size = 1
bidirectional = False
num_directions = 2 if bidirectional else 1
num_layers = 1
x = jax.random.normal(k1, (batch_size, seq_len, input_size), dtype=jnp.float32)
h_0 = jax.random.normal(
k2, (num_directions * num_layers, batch_size, hidden_size),
dtype=jnp.float32)
c_0 = jax.random.normal(
k3, (num_directions * num_layers, batch_size, hidden_size),
dtype=jnp.float32)
# Every sequence in the batch uses the full length seq_len.
seq_lengths = jnp.ones((batch_size,), dtype=jnp.int32) * seq_len
weights = rnn.init_lstm_weight(k4, input_size, hidden_size, num_layers,
bidirectional)
# NOTE(review): dropout is passed as the bool False rather than a float
# such as 0.0 -- confirm that this is intended.
return rnn.lstm(
x,
h_0,
c_0,
weights,
seq_lengths=seq_lengths,
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=False,
bidirectional=bidirectional)
# Lower (without executing) and inspect the textual StableHLO.
k = jax.random.split(jax.random.PRNGKey(1), 4)
stablehlo = jax.jit(f).lower(*k).as_text("stablehlo")
# The expected backend_config consists of ten 4-byte groups; presumably these
# are the RnnDescriptor fields in declaration order -- TODO confirm against
# the struct definition. The expected text also pins the operand names
# (%0, %1, %2, %6, %5), so changes to f will likely require updating it.
self.assertIn('stablehlo.custom_call @cudnn_rnn(%0, %1, %2, %6, %5) '
'{api_version = 2 : i32, backend_config = '
'"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"}',
stablehlo)
# When executed as a script, run the tests through absltest with the JAX test
# loader.
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())