Copy seq_lengths before creating descriptor

PiperOrigin-RevId: 519771897
Sharad Vikram 2023-03-27 10:59:14 -07:00 committed by jax authors
parent 88c2898e36
commit 3c3fa042e3
3 changed files with 72 additions and 1 deletion
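Background for the change: cudnnSetRNNDataDescriptor takes its seqLengthArray argument as a host-side int array, while the kernel receives seq_lengths as a device buffer (buffers[11] in the hunks below), so the actual lengths have to be copied device-to-host before the RNN data descriptor is built. A minimal standalone sketch of that pattern; the helper name BuildRNNDataDesc and its parameters are illustrative, not the jaxlib API:

#include <cstdint>
#include <vector>

#include <cuda_runtime.h>
#include <cudnn.h>

// Hypothetical helper, not the jaxlib kernel: builds a cuDNN RNN data
// descriptor from a *device* buffer holding int32 per-batch sequence lengths.
cudnnStatus_t BuildRNNDataDesc(cudnnRNNDataDescriptor_t* desc,
                               const void* seq_lengths_buf,  // device buffer
                               int batch_size, int max_seq_length,
                               int input_size) {
  // Copy the sequence lengths to the host first; cuDNN does not accept a
  // device pointer for seqLengthArray.
  std::vector<int32_t> seq_length_vector(batch_size, max_seq_length);
  if (cudaMemcpy(seq_length_vector.data(), seq_lengths_buf,
                 seq_length_vector.size() * sizeof(int32_t),
                 cudaMemcpyDeviceToHost) != cudaSuccess) {
    return CUDNN_STATUS_INTERNAL_ERROR;
  }

  cudnnStatus_t status = cudnnCreateRNNDataDescriptor(desc);
  if (status != CUDNN_STATUS_SUCCESS) return status;

  // The host-side length array is read here, which is why the copy above must
  // happen before the descriptor is set.
  float padding = 0.0f;
  return cudnnSetRNNDataDescriptor(
      *desc, CUDNN_DATA_FLOAT, CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED,
      max_seq_length, batch_size, input_size,
      seq_length_vector.data(), &padding);
}

The backward-kernel hunk below applies the same ordering inside DnnRNNBackward_: it reads buffers[11] and performs the gpuMemcpy into seq_length_vector before input_data_desc is created.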

View File

@@ -261,6 +261,8 @@ def lstm_ref(x: Array, h_0: Array, c_0: Array, W_ih: Dict[int, Array],
  See https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#lstm
  https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnRNNMode_t
  """
  if seq_lengths.dtype != jnp.dtype("int32"):
    raise NotImplementedError("`seq_lengths` can only be int32.")
  if dropout != 0.0:
    raise NotImplementedError(
        'Dropout not supported in LSTM reference because we cannot determine CUDNN dropout mask.'
@@ -326,6 +328,8 @@ def lstm_ref(x: Array, h_0: Array, c_0: Array, W_ih: Dict[int, Array],
def lstm_fwd(x: Array, h_0: Array, c_0: Array, w: Array, seq_lengths: Array,
             input_size: int, hidden_size: int, num_layers: int, dropout: float,
             bidirectional: bool):
  if seq_lengths.dtype != jnp.dtype("int32"):
    raise NotImplementedError("`seq_lengths` can only be int32.")
  y, h_n, c_n, workspace, reserve_space = rnn_fwd_p.bind(
      x,
      h_0,

View File

@@ -316,8 +316,12 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers,
  cudnnRNNDataLayout_t layout = CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED;
  float padding = 0.0f;
  auto seq_lengths_buf = buffers[11];
  std::vector<int32_t> seq_length_vector(d.batch_size, d.max_seq_length);
  int32_t* seq_length_array = &seq_length_vector[0];
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpy(
      seq_length_array, seq_lengths_buf,
      seq_length_vector.size() * sizeof(int32_t), gpuMemcpyDeviceToHost)));
  cudnnRNNDataDescriptor_t input_data_desc;
  JAX_RETURN_IF_ERROR(
@@ -367,7 +371,6 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers,
  auto workspace_buf = buffers[8];
  auto reserve_space_buf = buffers[9];
  auto zeroed_dw_buf = buffers[10];
  auto seq_lengths_buf = buffers[11];
  auto dx_buf = buffers[12];
  auto dh_0_buf = buffers[13];
  auto dc_0_buf = buffers[14];

View File

@@ -96,6 +96,70 @@ class RnnTest(jtu.JaxTestCase):
    np.testing.assert_allclose(h_n_ref, h_n, rtol=1e-05, atol=1e-5)
    np.testing.assert_allclose(c_n_ref, c_n, rtol=1e-05, atol=1e-5)

  @jtu.skip_on_devices("cpu", "tpu", "rocm")
  def test_lstm_with_varying_seq_lens(self):
    batch_size = 6
    seq_len = 7
    input_size = 8
    hidden_size = 12
    num_layers = 5
    bidirectional = False
    num_directions = 2 if bidirectional else 1

    seq_lengths = jnp.array([4, 5, 1, 1, 1, 1], dtype=jnp.dtype("int32"))
    root_key = jax.random.PRNGKey(1)
    k1, k2, k3, k4 = jax.random.split(root_key, 4)
    x = jax.random.normal(
        k1, (batch_size, seq_len, input_size), dtype=jnp.float32)
    h_0 = jax.random.normal(
        k2, (num_directions * num_layers, batch_size, hidden_size),
        dtype=jnp.float32)
    c_0 = jax.random.normal(
        k3, (num_directions * num_layers, batch_size, hidden_size),
        dtype=jnp.float32)
    weights = rnn.init_lstm_weight(k4, input_size, hidden_size, num_layers,
                                   bidirectional)

    @jax.jit
    def f(x, h_0, c_0, weights):
      return rnn.lstm(
          x,
          h_0,
          c_0,
          weights,
          seq_lengths=seq_lengths,
          input_size=input_size,
          hidden_size=hidden_size,
          num_layers=num_layers,
          dropout=False,
          bidirectional=bidirectional)

    jtu.check_grads(f, (x, h_0, c_0, weights), modes=['rev'], order=1)

    # TODO(sharadmv): enable when lstm_ref works with seq_lengths
    # W_ih, W_hh, b_ih, b_hh = rnn.unpack_lstm_weights(weights, input_size,
    #                                                  hidden_size, num_layers,
    #                                                  bidirectional)
    # y_ref, h_n_ref, c_n_ref = rnn.lstm_ref(
    #     x,
    #     h_0,
    #     c_0,
    #     W_ih,
    #     W_hh,
    #     b_ih,
    #     b_hh,
    #     seq_lengths=seq_lengths,
    #     input_size=input_size,
    #     hidden_size=hidden_size,
    #     num_layers=num_layers,
    #     dropout=False,
    #     bidirectional=bidirectional)
    # np.testing.assert_allclose(y_ref, y, rtol=1e-05, atol=1e-5)
    # np.testing.assert_allclose(h_n_ref, h_n, rtol=1e-05, atol=1e-5)
    # np.testing.assert_allclose(c_n_ref, c_n, rtol=1e-05, atol=1e-5)

if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())