handle seq_lengths in lstm_ref

Cristian Garcia 2023-04-03 22:22:54 +00:00
parent d743d23859
commit aa12e3597b
2 changed files with 154 additions and 35 deletions


@@ -289,6 +289,13 @@ def lstm_ref(x: Array, h_0: Array, c_0: Array, W_ih: Dict[int, Array],
    h = o * tanh(c)
    return (h, c), h

  # Here we also output the carry so that we can later slice out the correct
  # carry according to seq_lengths. While this takes more memory, it is
  # faster than using `jnp.where` inside the scan loop.
  def scan_fn(cell, carry, x):
    carry, y = cell(carry, x)
    return carry, (carry, y)

  seq_first_y = x.transpose(1, 0, 2)
  if not bidirectional:
    final_h = []
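The scan_fn wrapper above stacks the per-step carry alongside the regular outputs, so the carry at each sequence's last valid step can be gathered after the scan rather than masked inside it. Below is a minimal, self-contained sketch of that pattern; toy_cell, xs, and the shapes are made up for illustration and stand in for the real lstm_cell.

import jax
import jax.numpy as jnp
from functools import partial

def toy_cell(carry, x):
  # Stand-in for lstm_cell: anything with a (carry, x) -> (carry, y) signature.
  new_carry = carry + x
  return new_carry, new_carry

def scan_fn(cell, carry, x):
  carry, y = cell(carry, x)
  return carry, (carry, y)  # also stack the carry over time

xs = jnp.ones((5, 3))                       # [seq_len, batch]
seq_lengths = jnp.array([2, 5, 3])
_, (carries, ys) = jax.lax.scan(partial(scan_fn, toy_cell),
                                jnp.zeros((3,)), xs)
# Gather the carry at the last valid step of every sequence.
last_carry = carries[seq_lengths - 1, jnp.arange(carries.shape[1])]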
@@ -296,8 +303,10 @@ def lstm_ref(x: Array, h_0: Array, c_0: Array, W_ih: Dict[int, Array],
for l in range(num_layers):
cell = partial(
lstm_cell, W_ih=W_ih[l], W_hh=W_hh[l], b_ih=b_ih[l], b_hh=b_hh[l])
(h_t, c_t), seq_first_y = jax.lax.scan(cell, (h_0[l], c_0[l]),
cell_fn = partial(scan_fn, cell)
out = jax.lax.scan(cell_fn, (h_0[l], c_0[l]),
seq_first_y)
(h_t, c_t), seq_first_y = _extract_output(seq_lengths, out)
final_h.append(h_t)
final_c.append(c_t)
h_n = jnp.stack(final_h)
@@ -310,12 +319,19 @@ def lstm_ref(x: Array, h_0: Array, c_0: Array, W_ih: Dict[int, Array],
for l in range(num_layers * 2):
cell = partial(
lstm_cell, W_ih=W_ih[l], W_hh=W_hh[l], b_ih=b_ih[l], b_hh=b_hh[l])
cell_fn = partial(scan_fn, cell)
if l % 2 == 0:
(h_t, c_t), seq_first_y_fwd = jax.lax.scan(cell, (h_0[l], c_0[l]),
out = jax.lax.scan(cell_fn, (h_0[l], c_0[l]),
seq_first_y)
(h_t, c_t), seq_first_y_fwd = _extract_output(seq_lengths, out)
else:
(h_t, c_t), seq_first_y_bwd = jax.lax.scan(
cell, (h_0[l], c_0[l]), seq_first_y, reverse=True)
# reverse sequence while keeping padding at the end
seq_first_y_reversed = _flip_sequence(seq_first_y, seq_lengths)
out = jax.lax.scan(
cell_fn, (h_0[l], c_0[l]), seq_first_y_reversed)
(h_t, c_t), seq_first_y_bwd = _extract_output(seq_lengths, out)
# align reversed sequence with original sequence
seq_first_y_bwd = _flip_sequence(seq_first_y_bwd, seq_lengths)
# Inputs to next layer are concat'ed from fwd and bwd.
seq_first_y = jnp.concatenate([seq_first_y_fwd, seq_first_y_bwd], axis=-1) # pytype: disable=name-error
final_h.append(h_t)
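For the backward direction, each example's valid steps must be reversed while the padding stays at the trailing positions; that is what the _flip_sequence calls above achieve with a per-column roll followed by a time reversal. The sketch below mirrors that idea as a standalone function on a tiny zero-padded batch; the data and names are made up for illustration.

import jax
import jax.numpy as jnp
from functools import partial

def flip_sequence(sequences, seq_lengths):
  # Roll each batch column so its padding moves to the front, then reverse
  # along time, which leaves the padding at the end again.
  max_steps = sequences.shape[0]
  roll_amounts = max_steps - seq_lengths
  return jax.vmap(partial(jnp.roll, axis=0), in_axes=(1, 0),
                  out_axes=1)(sequences, roll_amounts)[::-1]

x = jnp.array([[1., 2.],
               [3., 4.],
               [5., 0.],
               [0., 0.]])                   # [seq_len=4, batch=2], zero-padded
seq_lengths = jnp.array([3, 2])
flipped = flip_sequence(x, seq_lengths)
# Column 0 becomes [5, 3, 1, 0]; column 1 becomes [4, 2, 0, 0].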
@@ -324,6 +340,30 @@ def lstm_ref(x: Array, h_0: Array, c_0: Array, W_ih: Dict[int, Array],
c_n = jnp.stack(final_c)
return seq_first_y.transpose(1, 0, 2), h_n, c_n
def _extract_output(seq_lengths: Array, out) -> Tuple[Tuple[Array, Array], Array]:
  _, ((hs, cs), seq_first_y) = out
  h_t = _select_last_carry(hs, seq_lengths)
  c_t = _select_last_carry(cs, seq_lengths)
  # mask: [seq_len, batch], from broadcasting [1, batch] > [seq_len, 1]
  mask = seq_lengths[None] > jnp.arange(seq_first_y.shape[0], dtype=jnp.int32)[:, None]
  # seq_first_y: [seq_len, batch, hidden_size]
  seq_first_y = jnp.where(
      mask[..., None],  # [seq_len, batch, 1]
      seq_first_y,      # [seq_len, batch, hidden_size]
      0)
  return (h_t, c_t), seq_first_y

def _select_last_carry(carry_seq: Array, seq_lengths: Array):
  return carry_seq[seq_lengths - 1, jnp.arange(carry_seq.shape[1])]

def _flip_sequence(sequences: Array, seq_lengths: Array) -> Array:
  max_steps = sequences.shape[0]
  roll_amounts = max_steps - seq_lengths
  # Rolling first puts the padding at the front, so when the sequence is
  # reversed (via [::-1]) the padding stays at the end.
  return jax.vmap(partial(jnp.roll, axis=0), in_axes=(1, 0),
                  out_axes=1)(sequences, roll_amounts)[::-1]
def lstm_fwd(x: Array, h_0: Array, c_0: Array, w: Array, seq_lengths: Array,
input_size: int, hidden_size: int, num_layers: int, dropout: float,

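The masking in _extract_output zeroes every time step at or beyond a sequence's length by broadcasting a [1, batch] row of lengths against a [seq_len, 1] column of step indices. A small sketch of that broadcast, with toy shapes assumed for illustration:

import jax.numpy as jnp

seq_first_y = jnp.ones((4, 2, 3))           # [seq_len, batch, hidden_size]
seq_lengths = jnp.array([3, 2], dtype=jnp.int32)

# mask[t, b] is True while t < seq_lengths[b]; shape [seq_len, batch].
mask = seq_lengths[None] > jnp.arange(seq_first_y.shape[0], dtype=jnp.int32)[:, None]
masked = jnp.where(mask[..., None], seq_first_y, 0)
# masked[3:, 0] and masked[2:, 1] are now all zeros.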

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from absl.testing import absltest
import numpy as np
import jax
@@ -37,14 +38,95 @@ class RnnTest(jtu.JaxTestCase):
@jtu.skip_on_devices("cpu", "tpu", "rocm")
def test_lstm(self, batch_size: int, seq_len: int, input_size: int,
hidden_size: int, num_layers: int, bidirectional: bool):
batch_size = 6
seq_len = 7
input_size = 8
hidden_size = 12
num_layers = 5
num_directions = 2 if bidirectional else 1
seq_lengths = jnp.ones((batch_size,), dtype=jnp.int32) * seq_len
seq_lengths = jax.random.randint(
jax.random.PRNGKey(0), (batch_size,), 0, seq_len, dtype=jnp.int32)
root_key = jax.random.PRNGKey(1)
k1, k2, k3, k4 = jax.random.split(root_key, 4)
x = jax.random.normal(
k1, (batch_size, seq_len, input_size), dtype=jnp.float32)
h_0 = jax.random.normal(
k2, (num_directions * num_layers, batch_size, hidden_size),
dtype=jnp.float32)
c_0 = jax.random.normal(
k3, (num_directions * num_layers, batch_size, hidden_size),
dtype=jnp.float32)
weights = rnn.init_lstm_weight(k4, input_size, hidden_size, num_layers,
bidirectional)
@partial(jax.value_and_grad, has_aux=True)
def f(weights, x, h_0, c_0):
y, h, c = rnn.lstm(
x,
h_0,
c_0,
weights,
seq_lengths=seq_lengths,
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=False,
bidirectional=bidirectional)
loss = jnp.sum(y)
return loss, (y, h, c)
(loss, (y, h_n, c_n)), grad = f(weights, x, h_0, c_0)
jtu.check_grads(f, (x, h_0, c_0, weights), modes=['rev'], order=1)
self.assertFalse(np.isnan(loss))
self.assertFalse(np.isnan(grad).any())
@partial(jax.value_and_grad, has_aux=True)
def g(weights, x, h_0, c_0):
W_ih, W_hh, b_ih, b_hh = rnn.unpack_lstm_weights(weights, input_size,
hidden_size, num_layers,
bidirectional)
y_ref, h_n_ref, c_n_ref = rnn.lstm_ref(
x,
h_0,
c_0,
W_ih,
W_hh,
b_ih,
b_hh,
seq_lengths=seq_lengths,
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=False,
bidirectional=bidirectional)
loss = jnp.sum(y_ref)
return loss, (y_ref, h_n_ref, c_n_ref)
(loss_ref, (y_ref, h_n_ref, c_n_ref)), grad_ref = g(weights, x, h_0, c_0)
self.assertFalse(np.isnan(loss_ref))
self.assertFalse(np.isnan(grad_ref).any())
np.testing.assert_allclose(y_ref, y, rtol=1e-05, atol=1e-5)
np.testing.assert_allclose(h_n_ref, h_n, rtol=1e-05, atol=1e-5)
np.testing.assert_allclose(c_n_ref, c_n, rtol=1e-05, atol=1e-5)
@jtu.sample_product(
batch_size=[1, 4],
seq_len=[1, 4],
input_size=[1, 2],
hidden_size=[1, 6],
num_layers=[1, 4],
bidirectional=[True, False],
)
def test_lstm_ref(self, batch_size: int, seq_len: int, input_size: int,
hidden_size: int, num_layers: int, bidirectional: bool):
num_directions = 2 if bidirectional else 1
seq_lengths = jax.random.randint(
jax.random.PRNGKey(0), (batch_size,), 0, seq_len, dtype=jnp.int32)
root_key = jax.random.PRNGKey(1)
k1, k2, k3, k4 = jax.random.split(root_key, 4)
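The tests above wrap the loss in jax.value_and_grad with has_aux=True, so the LSTM outputs come back alongside the scalar loss while gradients are taken with respect to the first argument. A minimal sketch of that calling convention; the function f and its inputs here are throwaway stand-ins, not the test code itself.

import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.value_and_grad, has_aux=True)
def f(w, x):
  y = w * x
  return jnp.sum(y), y                      # (loss, aux)

(loss, y), grad = f(jnp.array(2.0), jnp.arange(3.0))
# grad is d(loss)/d(w), i.e. jnp.sum(x) here; y is returned unchanged as aux.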
@@ -59,12 +141,19 @@ class RnnTest(jtu.JaxTestCase):
weights = rnn.init_lstm_weight(k4, input_size, hidden_size, num_layers,
bidirectional)
def f(x, h_0, c_0, weights):
return rnn.lstm(
@partial(jax.value_and_grad, has_aux=True)
def f(weights, x, h_0, c_0):
W_ih, W_hh, b_ih, b_hh = rnn.unpack_lstm_weights(weights, input_size,
hidden_size, num_layers,
bidirectional)
y_ref, h_n_ref, c_n_ref = rnn.lstm_ref(
x,
h_0,
c_0,
weights,
W_ih,
W_hh,
b_ih,
b_hh,
seq_lengths=seq_lengths,
input_size=input_size,
hidden_size=hidden_size,
@@ -72,30 +161,20 @@ class RnnTest(jtu.JaxTestCase):
dropout=False,
bidirectional=bidirectional)
y, h_n, c_n = f(x, h_0, c_0, weights)
jtu.check_grads(f, (x, h_0, c_0, weights), modes=['rev'], order=1)
loss = jnp.sum(y_ref)
W_ih, W_hh, b_ih, b_hh = rnn.unpack_lstm_weights(weights, input_size,
hidden_size, num_layers,
bidirectional)
y_ref, h_n_ref, c_n_ref = rnn.lstm_ref(
x,
h_0,
c_0,
W_ih,
W_hh,
b_ih,
b_hh,
seq_lengths=seq_lengths,
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=False,
bidirectional=bidirectional)
return loss, (y_ref, h_n_ref, c_n_ref)
np.testing.assert_allclose(y_ref, y, rtol=1e-05, atol=1e-5)
np.testing.assert_allclose(h_n_ref, h_n, rtol=1e-05, atol=1e-5)
np.testing.assert_allclose(c_n_ref, c_n, rtol=1e-05, atol=1e-5)
(loss_ref, (y_ref, h_n_ref, c_n_ref)), grad_ref = f(weights, x, h_0, c_0)
self.assertFalse(np.isnan(loss_ref))
self.assertFalse(np.isnan(grad_ref).any())
self.assertEqual(y_ref.shape, (batch_size, seq_len, num_directions * hidden_size))
for i in range(batch_size):
y_padded = y_ref[i, seq_lengths[i]:]
np.testing.assert_allclose(y_padded, jnp.zeros_like(y_padded))
@jtu.skip_on_devices("cpu", "tpu", "rocm")
def test_lstm_with_varying_seq_lens(self):