mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Fix struct string encoding non-determinism in the RNN descriptor.
Boolean fields in the descriptor struct led to padding, which let random bytes leak into the string representation of the struct and caused the HLO to vary from run to run.
This commit is contained in:
parent
640cb009f1
commit
f0e1c3cf36
@ -31,25 +31,26 @@ struct RnnDescriptor {
|
||||
int batch_size;
|
||||
int max_seq_length;
|
||||
float dropout;
|
||||
bool bidirectional;
|
||||
bool cudnn_allow_tf32;
|
||||
int bidirectional;
|
||||
int cudnn_allow_tf32;
|
||||
int workspace_size;
|
||||
int reserve_space_size;
|
||||
};
|
||||
|
||||
// Return (workspace size, reserve space size).
// The pair is wrapped in absl::StatusOr, so callers must check for an error
// status before reading the two sizes.
// Note that `bidirectional` and `cudnn_allow_tf32` are `bool` parameters here,
// while RnnDescriptor stores the fields of the same names as `int`.
// NOTE(review): the remaining parameters presumably correspond to the
// like-named RnnDescriptor fields -- confirm against the implementation.
|
||||
absl::StatusOr<std::pair<int, int>> RnnComputeWorkspaceReserveSpaceSizes(
|
||||
int input_size, int hidden_size, int num_layers, int batch_size,
|
||||
int max_seq_length, float dropout, bool bidirectional,
|
||||
bool cudnn_allow_tf32);
|
||||
// Same prototype as the one directly above; only the line wrapping differs.
absl::StatusOr<std::pair<int, int>>
|
||||
RnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
|
||||
int num_layers, int batch_size,
|
||||
int max_seq_length, float dropout,
|
||||
bool bidirectional, bool cudnn_allow_tf32);
|
||||
|
||||
// Entry point for the RNN forward pass. It receives a GPU stream, an array of
// buffer pointers, an opaque byte string together with its length, and a
// pointer to an XlaCustomCallStatus.
// NOTE(review): `opaque` presumably holds the serialized RnnDescriptor and
// `status` is presumably used to report failures -- confirm against the
// implementation.
void RNNForward(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
// Same prototype as the one directly above; only the pointer spacing differs.
void RNNForward(gpuStream_t stream, void **buffers, const char *opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus *status);
|
||||
|
||||
// Entry point for the RNN backward pass. It takes the same argument list as
// RNNForward: a GPU stream, an array of buffer pointers, an opaque byte string
// together with its length, and a pointer to an XlaCustomCallStatus.
// NOTE(review): `opaque` presumably holds the serialized RnnDescriptor and
// `status` is presumably used to report failures -- confirm against the
// implementation.
void RNNBackward(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
// Same prototype as the one directly above; only the pointer spacing differs.
void RNNBackward(gpuStream_t stream, void **buffers, const char *opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus *status);
|
||||
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
||||
|
||||
#endif // JAXLIB_GPU_RNN_KERNELS_H_
|
||||
#endif // JAXLIB_GPU_RNN_KERNELS_H_
|
||||
|
@ -178,5 +178,44 @@ class RnnTest(jtu.JaxTestCase):
|
||||
y_padded = y_ref[i, seq_lengths[i]:]
|
||||
np.testing.assert_allclose(y_padded, jnp.zeros_like(y_padded))
|
||||
|
||||
@jtu.run_on_devices("cuda")
def test_struct_encoding_determinism(self):
  """Pins the byte string that the cudnn_rnn custom call carries.

  A one-layer, unidirectional LSTM with every size set to 1 is lowered to
  StableHLO, and the text must contain the custom call with one exact
  `backend_config` value.
  """
  # All sizes are 1 so the expected byte string below stays short.
  batch_size = seq_len = input_size = hidden_size = num_layers = 1
  bidirectional = False
  num_directions = 2 if bidirectional else 1
  # h_0 and c_0 share the same shape.
  state_shape = (num_directions * num_layers, batch_size, hidden_size)

  def f(k1, k2, k3, k4):
    # The statement order is kept on purpose: the assertion below names
    # specific SSA values (%0, %1, %2, %6, %5) of the lowered module.
    x = jax.random.normal(
        k1, (batch_size, seq_len, input_size), dtype=jnp.float32)
    h_0 = jax.random.normal(k2, state_shape, dtype=jnp.float32)
    c_0 = jax.random.normal(k3, state_shape, dtype=jnp.float32)
    seq_lengths = jnp.ones((batch_size,), dtype=jnp.int32) * seq_len
    weights = rnn.init_lstm_weight(
        k4, input_size, hidden_size, num_layers, bidirectional)
    return rnn.lstm(
        x, h_0, c_0, weights,
        seq_lengths=seq_lengths,
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        dropout=False,
        bidirectional=bidirectional)

  keys = jax.random.split(jax.random.PRNGKey(1), 4)
  lowered_text = jax.jit(f).lower(*keys).as_text("stablehlo")

  # NOTE(review): the backend_config bytes look like ten 4-byte fields,
  # presumably the RnnDescriptor members in declaration order -- confirm
  # against the struct definition.
  expected_call = (
      'stablehlo.custom_call @cudnn_rnn(%0, %1, %2, %6, %5) '
      '{api_version = 2 : i32, backend_config = '
      '"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"}')
  self.assertIn(expected_call, lowered_text)
|
||||
|
||||
# When run as a script, hand control to absltest using the JaxTestLoader
# provided by jtu.
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user