Copy seq_lengths before creating descriptor
PiperOrigin-RevId: 519771897
commit 3c3fa042e3
parent 88c2898e36
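Why the copy matters: cuDNN's RNN data descriptor takes the per-example sequence lengths as a plain host-side int array (the seqLengthArray argument of cudnnSetRNNDataDescriptor), while the backward kernel receives seq_lengths as a device buffer. The hunks below therefore stage the device buffer into a host std::vector before the descriptor is created. The following is a minimal, illustrative sketch of that pattern using the raw CUDA and cuDNN APIs; the function name, parameters, and error handling are stand-ins and are not part of this change:

#include <cstdint>
#include <vector>

#include <cuda_runtime.h>
#include <cudnn.h>

// Illustrative sketch, not the jaxlib implementation: build an unpacked,
// batch-major RNN data descriptor from device-resident int32 sequence
// lengths.  cuDNN reads seqLengthArray from host memory, so the lengths are
// copied off the device first.
cudnnStatus_t MakeRnnDataDesc(cudnnRNNDataDescriptor_t* desc,
                              const int32_t* seq_lengths_device,
                              int batch_size, int max_seq_length,
                              int input_size) {
  // Pre-size the host vector, then overwrite it with the real lengths.
  std::vector<int32_t> seq_lengths_host(batch_size, max_seq_length);
  if (cudaMemcpy(seq_lengths_host.data(), seq_lengths_device,
                 seq_lengths_host.size() * sizeof(int32_t),
                 cudaMemcpyDeviceToHost) != cudaSuccess) {
    return CUDNN_STATUS_INTERNAL_ERROR;
  }

  float padding = 0.0f;
  cudnnStatus_t status = cudnnCreateRNNDataDescriptor(desc);
  if (status != CUDNN_STATUS_SUCCESS) return status;
  // seqLengthArray is read from host memory, which is why the lengths were
  // staged into a host vector above.
  return cudnnSetRNNDataDescriptor(
      *desc, CUDNN_DATA_FLOAT, CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED,
      max_seq_length, batch_size, input_size, seq_lengths_host.data(),
      &padding);
}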
@@ -261,6 +261,8 @@ def lstm_ref(x: Array, h_0: Array, c_0: Array, W_ih: Dict[int, Array],
   See https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#lstm
   https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnRNNMode_t
   """
+  if seq_lengths.dtype != jnp.dtype("int32"):
+    raise NotImplementedError("`seq_lengths` can only be int32.")
   if dropout != 0.0:
     raise NotImplementedError(
         'Dropout not supported in LSTM reference because we cannot determine CUDNN dropout mask.'
@@ -326,6 +328,8 @@ def lstm_ref(x: Array, h_0: Array, c_0: Array, W_ih: Dict[int, Array],
 def lstm_fwd(x: Array, h_0: Array, c_0: Array, w: Array, seq_lengths: Array,
              input_size: int, hidden_size: int, num_layers: int, dropout: float,
              bidirectional: bool):
+  if seq_lengths.dtype != jnp.dtype("int32"):
+    raise NotImplementedError("`seq_lengths` can only be int32.")
   y, h_n, c_n, workspace, reserve_space = rnn_fwd_p.bind(
       x,
       h_0,
@@ -316,8 +316,12 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers,
   cudnnRNNDataLayout_t layout = CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED;
   float padding = 0.0f;

+  auto seq_lengths_buf = buffers[11];
   std::vector<int32_t> seq_length_vector(d.batch_size, d.max_seq_length);
   int32_t* seq_length_array = &seq_length_vector[0];
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpy(
+      seq_length_array, seq_lengths_buf,
+      seq_length_vector.size() * sizeof(int32_t), gpuMemcpyDeviceToHost)));

   cudnnRNNDataDescriptor_t input_data_desc;
   JAX_RETURN_IF_ERROR(
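The gpuMemcpy above is wrapped in JAX_RETURN_IF_ERROR(JAX_AS_STATUS(...)), the pattern these kernels use to turn a failed GPU call into an early-returned absl::Status. Below is a standalone sketch of that conversion, with hypothetical AsStatus / RETURN_IF_ERROR stand-ins rather than the real jaxlib macros:

#include <cstdint>
#include <vector>

#include <cuda_runtime.h>
#include "absl/status/status.h"

// Map a CUDA runtime error onto absl::Status (stand-in for JAX_AS_STATUS).
inline absl::Status AsStatus(cudaError_t err) {
  if (err == cudaSuccess) return absl::OkStatus();
  return absl::InternalError(cudaGetErrorString(err));
}

// Early-return on failure (stand-in for JAX_RETURN_IF_ERROR).
#define RETURN_IF_ERROR(expr)        \
  do {                               \
    absl::Status s = (expr);         \
    if (!s.ok()) return s;           \
  } while (0)

// Stage device-resident int32 sequence lengths into a pre-sized host vector,
// propagating any copy failure as a status instead of ignoring it.
absl::Status CopySeqLengthsToHost(const void* seq_lengths_device,
                                  std::vector<int32_t>* host_lengths) {
  RETURN_IF_ERROR(AsStatus(cudaMemcpy(
      host_lengths->data(), seq_lengths_device,
      host_lengths->size() * sizeof(int32_t), cudaMemcpyDeviceToHost)));
  return absl::OkStatus();
}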
@@ -367,7 +371,6 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers,
   auto workspace_buf = buffers[8];
   auto reserve_space_buf = buffers[9];
   auto zeroed_dw_buf = buffers[10];
-  auto seq_lengths_buf = buffers[11];
   auto dx_buf = buffers[12];
   auto dh_0_buf = buffers[13];
   auto dc_0_buf = buffers[14];
@@ -96,6 +96,70 @@ class RnnTest(jtu.JaxTestCase):
     np.testing.assert_allclose(h_n_ref, h_n, rtol=1e-05, atol=1e-5)
     np.testing.assert_allclose(c_n_ref, c_n, rtol=1e-05, atol=1e-5)

+  @jtu.skip_on_devices("cpu", "tpu", "rocm")
+  def test_lstm_with_varying_seq_lens(self):
+    batch_size = 6
+    seq_len = 7
+    input_size = 8
+    hidden_size = 12
+    num_layers = 5
+    bidirectional = False
+    num_directions = 2 if bidirectional else 1
+
+    seq_lengths = jnp.array([4, 5, 1, 1, 1, 1], dtype=jnp.dtype("int32"))
+
+    root_key = jax.random.PRNGKey(1)
+    k1, k2, k3, k4 = jax.random.split(root_key, 4)
+    x = jax.random.normal(
+        k1, (batch_size, seq_len, input_size), dtype=jnp.float32)
+    h_0 = jax.random.normal(
+        k2, (num_directions * num_layers, batch_size, hidden_size),
+        dtype=jnp.float32)
+    c_0 = jax.random.normal(
+        k3, (num_directions * num_layers, batch_size, hidden_size),
+        dtype=jnp.float32)
+    weights = rnn.init_lstm_weight(k4, input_size, hidden_size, num_layers,
+                                   bidirectional)
+
+    @jax.jit
+    def f(x, h_0, c_0, weights):
+      return rnn.lstm(
+          x,
+          h_0,
+          c_0,
+          weights,
+          seq_lengths=seq_lengths,
+          input_size=input_size,
+          hidden_size=hidden_size,
+          num_layers=num_layers,
+          dropout=False,
+          bidirectional=bidirectional)
+
+    jtu.check_grads(f, (x, h_0, c_0, weights), modes=['rev'], order=1)
+
+    # TODO(sharadmv): enable when lstm_ref works with seq_lengths
+    # W_ih, W_hh, b_ih, b_hh = rnn.unpack_lstm_weights(weights, input_size,
+    #                                                  hidden_size, num_layers,
+    #                                                  bidirectional)
+    # y_ref, h_n_ref, c_n_ref = rnn.lstm_ref(
+    #     x,
+    #     h_0,
+    #     c_0,
+    #     W_ih,
+    #     W_hh,
+    #     b_ih,
+    #     b_hh,
+    #     seq_lengths=seq_lengths,
+    #     input_size=input_size,
+    #     hidden_size=hidden_size,
+    #     num_layers=num_layers,
+    #     dropout=False,
+    #     bidirectional=bidirectional)
+
+    # np.testing.assert_allclose(y_ref, y, rtol=1e-05, atol=1e-5)
+    # np.testing.assert_allclose(h_n_ref, h_n, rtol=1e-05, atol=1e-5)
+    # np.testing.assert_allclose(c_n_ref, c_n, rtol=1e-05, atol=1e-5)
+

 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())