Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)

Commit b361f4cd0c
Merge pull request #15169 from cgarciae:fix-lstm
PiperOrigin-RevId: 521616002
@@ -289,6 +289,13 @@ def lstm_ref(x: Array, h_0: Array, c_0: Array, W_ih: Dict[int, Array],
    h = o * tanh(c)
    return (h, c), h

  # here we also output the carry so that we can later slice
  # the correct carry according to seq_lengths, while this takes more memory
  # it is faster than using 'jnp.where' inside the scan loop
  def scan_fn(cell, carry, x):
    carry, y = cell(carry, x)
    return carry, (carry, y)

  seq_first_y = x.transpose(1, 0, 2)
  if not bidirectional:
    final_h = []
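The comment above motivates scan_fn: the scan emits the carry at every step so that, after the scan, the carry at position seq_lengths - 1 can be sliced out. A toy, standalone sketch of that idea (the cumulative-sum cell and the values are illustrative only, not part of this diff):

import jax
import jax.numpy as jnp

def toy_cell(carry, x):
  # stand-in for lstm_cell: the carry is just a running sum
  carry = carry + x
  return carry, carry

def scan_fn(carry, x):
  # emit the carry alongside the per-step output, as in the diff
  carry, y = toy_cell(carry, x)
  return carry, (carry, y)

xs = jnp.array([1., 2., 3., 0., 0.])  # padded to length 5, true length 3
_, (carries, ys) = jax.lax.scan(scan_fn, 0., xs)
print(carries[3 - 1])  # 6.0 -- the carry at the last valid step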
@@ -296,8 +303,10 @@ def lstm_ref(x: Array, h_0: Array, c_0: Array, W_ih: Dict[int, Array],
    for l in range(num_layers):
      cell = partial(
          lstm_cell, W_ih=W_ih[l], W_hh=W_hh[l], b_ih=b_ih[l], b_hh=b_hh[l])
      (h_t, c_t), seq_first_y = jax.lax.scan(cell, (h_0[l], c_0[l]),
      cell_fn = partial(scan_fn, cell)
      out = jax.lax.scan(cell_fn, (h_0[l], c_0[l]),
                         seq_first_y)
      (h_t, c_t), seq_first_y = _extract_output(seq_lengths, out)
      final_h.append(h_t)
      final_c.append(c_t)
    h_n = jnp.stack(final_h)
@@ -310,12 +319,19 @@ def lstm_ref(x: Array, h_0: Array, c_0: Array, W_ih: Dict[int, Array],
    for l in range(num_layers * 2):
      cell = partial(
          lstm_cell, W_ih=W_ih[l], W_hh=W_hh[l], b_ih=b_ih[l], b_hh=b_hh[l])
      cell_fn = partial(scan_fn, cell)
      if l % 2 == 0:
        (h_t, c_t), seq_first_y_fwd = jax.lax.scan(cell, (h_0[l], c_0[l]),
        out = jax.lax.scan(cell_fn, (h_0[l], c_0[l]),
                           seq_first_y)
        (h_t, c_t), seq_first_y_fwd = _extract_output(seq_lengths, out)
      else:
        (h_t, c_t), seq_first_y_bwd = jax.lax.scan(
            cell, (h_0[l], c_0[l]), seq_first_y, reverse=True)
        # reverse sequence while keeping padding at the end
        seq_first_y_reversed = _flip_sequence(seq_first_y, seq_lengths)
        out = jax.lax.scan(
            cell_fn, (h_0[l], c_0[l]), seq_first_y_reversed)
        (h_t, c_t), seq_first_y_bwd = _extract_output(seq_lengths, out)
        # align reversed sequence with original sequence
        seq_first_y_bwd = _flip_sequence(seq_first_y_bwd, seq_lengths)
      # Inputs to next layer are concat'ed from fwd and bwd.
      seq_first_y = jnp.concatenate([seq_first_y_fwd, seq_first_y_bwd], axis=-1)  # pytype: disable=name-error
      final_h.append(h_t)
@@ -324,6 +340,30 @@ def lstm_ref(x: Array, h_0: Array, c_0: Array, W_ih: Dict[int, Array],
    c_n = jnp.stack(final_c)
  return seq_first_y.transpose(1, 0, 2), h_n, c_n

def _extract_output(seq_lengths: Array, out) -> Tuple[Tuple[Array, Array], Array]:
  _, ((hs, cs), seq_first_y) = out
  h_t = _select_last_carry(hs, seq_lengths)
  c_t = _select_last_carry(cs, seq_lengths)

  # [seq_len, batch]       [1, batch]        [seq_len, 1]
  mask = seq_lengths[None] > jnp.arange(seq_first_y.shape[0], dtype=jnp.int32)[:, None]
  # [batch, seq_len, hidden_size]
  seq_first_y = jnp.where(
      mask[..., None],  # [seq_len, batch, 1]
      seq_first_y,      # [seq_len, batch, hidden_size]
      0)
  return (h_t, c_t), seq_first_y

def _select_last_carry(carry_seq: Array, seq_lengths: Array):
  return carry_seq[seq_lengths - 1, jnp.arange(carry_seq.shape[1])]

def _flip_sequence(sequences: Array, seq_lengths: Array) -> Array:
  max_steps = sequences.shape[0]
  roll_amounts = max_steps - seq_lengths
  # roll initially puts padding at the front so when the sequence is reversed
  # (via [::-1]) the padding stays at the end
  return jax.vmap(partial(jnp.roll, axis=0), in_axes=(1, 0),
                  out_axes=1)(sequences, roll_amounts)[::-1]

def lstm_fwd(x: Array, h_0: Array, c_0: Array, w: Array, seq_lengths: Array,
             input_size: int, hidden_size: int, num_layers: int, dropout: float,
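A quick standalone check of the _flip_sequence idea added above: rolling each batch column forward by max_steps - seq_length moves the padding to the front, so the final [::-1] reversal leaves the padding at the end. The toy inputs below are illustrative only, not taken from the diff:

from functools import partial
import jax
import jax.numpy as jnp

def flip_sequence(sequences, seq_lengths):
  # same construction as _flip_sequence in the diff
  max_steps = sequences.shape[0]
  roll_amounts = max_steps - seq_lengths
  return jax.vmap(partial(jnp.roll, axis=0), in_axes=(1, 0),
                  out_axes=1)(sequences, roll_amounts)[::-1]

# seq-first layout [seq_len=4, batch=2]; zeros are padding
seqs = jnp.array([[1., 10.],
                  [2., 20.],
                  [3.,  0.],
                  [0.,  0.]])
print(flip_sequence(seqs, jnp.array([3, 2])))
# [[ 3. 20.]
#  [ 2. 10.]
#  [ 1.  0.]
#  [ 0.  0.]]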
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from absl.testing import absltest
import numpy as np
import jax
@@ -37,16 +38,16 @@ class RnnTest(jtu.JaxTestCase):
  @jtu.skip_on_devices("cpu", "tpu", "rocm")
  def test_lstm(self, batch_size: int, seq_len: int, input_size: int,
                hidden_size: int, num_layers: int, bidirectional: bool):
    batch_size = 6
    seq_len = 7
    input_size = 8
    hidden_size = 12
    num_layers = 5
    if lib.version < (0, 4, 7):
      # TODO(sharadmv, zhangqiaorjc): remove this when minimum jaxlib version is
      # bumped
      self.skipTest("Need latest jaxlib for this test to pass.")
    num_directions = 2 if bidirectional else 1
    seq_length_key, root_key = jax.random.split(jax.random.PRNGKey(0))

    seq_lengths = jnp.ones((batch_size,), dtype=jnp.int32) * seq_len
    seq_lengths = jax.random.randint(
        seq_length_key, (batch_size,), 1, seq_len, dtype=jnp.int32)

    root_key = jax.random.PRNGKey(1)
    k1, k2, k3, k4 = jax.random.split(root_key, 4)
    x = jax.random.normal(
        k1, (batch_size, seq_len, input_size), dtype=jnp.float32)
@@ -58,60 +59,75 @@ class RnnTest(jtu.JaxTestCase):
        dtype=jnp.float32)
    weights = rnn.init_lstm_weight(k4, input_size, hidden_size, num_layers,
                                   bidirectional)

    def f(x, h_0, c_0, weights):
      return rnn.lstm(
          x,
          h_0,
          c_0,
          weights,
          seq_lengths=seq_lengths,
          input_size=input_size,
          hidden_size=hidden_size,
          num_layers=num_layers,
          dropout=False,
          bidirectional=bidirectional)

    y, h_n, c_n = f(x, h_0, c_0, weights)
    jtu.check_grads(f, (x, h_0, c_0, weights), modes=['rev'], order=1)

    W_ih, W_hh, b_ih, b_hh = rnn.unpack_lstm_weights(weights, input_size,
                                                     hidden_size, num_layers,
                                                     bidirectional)
    y_ref, h_n_ref, c_n_ref = rnn.lstm_ref(
    def f(weights, x, h_0, c_0):
      y, h, c = rnn.lstm(
          x,
          h_0,
          c_0,
          W_ih,
          W_hh,
          b_ih,
          b_hh,
          weights,
          seq_lengths=seq_lengths,
          input_size=input_size,
          hidden_size=hidden_size,
          num_layers=num_layers,
          dropout=False,
          bidirectional=bidirectional)
      seq_length_mask = jnp.tile(jnp.arange(seq_len, dtype=jnp.int32)[None],
                                 [batch_size, 1]) < seq_lengths[:, None]
      loss = jnp.sum(jnp.where(seq_length_mask[..., None], y, 0.))
      return loss, (y, h, c)

    jtu.check_grads(f, (weights, x, h_0, c_0), modes=["rev"], order=1)

    (loss, (y, h_n, c_n)), weights_grad = jax.value_and_grad(f, has_aux=True)(
        weights, x, h_0, c_0)

    def g(weights, x, h_0, c_0):
      W_ih, W_hh, b_ih, b_hh = rnn.unpack_lstm_weights(weights, input_size,
                                                       hidden_size, num_layers,
                                                       bidirectional)
      y_ref, h_n_ref, c_n_ref = rnn.lstm_ref(
          x,
          h_0,
          c_0,
          W_ih,
          W_hh,
          b_ih,
          b_hh,
          seq_lengths=seq_lengths,
          input_size=input_size,
          hidden_size=hidden_size,
          num_layers=num_layers,
          dropout=False,
          bidirectional=bidirectional)
      seq_length_mask = jnp.tile(jnp.arange(seq_len, dtype=jnp.int32)[None],
                                 [batch_size, 1]) < seq_lengths[:, None]
      loss = jnp.sum(jnp.where(seq_length_mask[..., None], y_ref, 0.))
      return loss, (y_ref, h_n_ref, c_n_ref)

    (loss_ref, (y_ref, h_n_ref, c_n_ref)), weights_grad_ref = (
        jax.value_and_grad(g, has_aux=True)(weights, x, h_0, c_0))

    self.assertAllClose(weights_grad_ref, weights_grad, rtol=1e-5, atol=1e-5)
    np.testing.assert_allclose(loss_ref, loss, rtol=1e-05, atol=1e-5)
    np.testing.assert_allclose(y_ref, y, rtol=1e-05, atol=1e-5)
    np.testing.assert_allclose(h_n_ref, h_n, rtol=1e-05, atol=1e-5)
    np.testing.assert_allclose(c_n_ref, c_n, rtol=1e-05, atol=1e-5)

  @jtu.skip_on_devices("cpu", "tpu", "rocm")
  def test_lstm_with_varying_seq_lens(self):
    if lib.version < (0, 4, 7):
      # TODO(sharadmv, zhangqiaorjc): remove this when minimum jaxlib version is
      # bumped
      self.skipTest("Need latest jaxlib for this test to pass.")
    batch_size = 6
    seq_len = 7
    input_size = 8
    hidden_size = 12
    num_layers = 5
    bidirectional = False
  @jtu.sample_product(
      batch_size=[1, 4],
      seq_len=[1, 4],
      input_size=[1, 2],
      hidden_size=[1, 6],
      num_layers=[1, 4],
      bidirectional=[True, False],
  )
  def test_lstm_ref(self, batch_size: int, seq_len: int, input_size: int,
                    hidden_size: int, num_layers: int, bidirectional: bool):

    num_directions = 2 if bidirectional else 1

    seq_lengths = jnp.array([4, 5, 1, 1, 1, 1], dtype=jnp.dtype("int32"))
    seq_lengths = jax.random.randint(
        jax.random.PRNGKey(0), (batch_size,), 0, seq_len, dtype=jnp.int32)

    root_key = jax.random.PRNGKey(1)
    k1, k2, k3, k4 = jax.random.split(root_key, 4)
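Both f and g above mask the loss so padded positions do not contribute to the gradient. A small illustration of the seq_length_mask construction, with made-up lengths rather than the values used in the test:

import jax.numpy as jnp

batch_size, seq_len = 3, 5
seq_lengths = jnp.array([2, 5, 1], dtype=jnp.int32)
# True where the time step is within the example's valid length
seq_length_mask = jnp.tile(jnp.arange(seq_len, dtype=jnp.int32)[None],
                           [batch_size, 1]) < seq_lengths[:, None]
print(seq_length_mask.astype(jnp.int32))
# [[1 1 0 0 0]
#  [1 1 1 1 1]
#  [1 0 0 0 0]]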
@@ -126,13 +142,19 @@ class RnnTest(jtu.JaxTestCase):
    weights = rnn.init_lstm_weight(k4, input_size, hidden_size, num_layers,
                                   bidirectional)

    @jax.jit
    def f(x, h_0, c_0, weights):
      return rnn.lstm(
    @partial(jax.value_and_grad, has_aux=True)
    def f(weights, x, h_0, c_0):
      W_ih, W_hh, b_ih, b_hh = rnn.unpack_lstm_weights(weights, input_size,
                                                       hidden_size, num_layers,
                                                       bidirectional)
      y_ref, h_n_ref, c_n_ref = rnn.lstm_ref(
          x,
          h_0,
          c_0,
          weights,
          W_ih,
          W_hh,
          b_ih,
          b_hh,
          seq_lengths=seq_lengths,
          input_size=input_size,
          hidden_size=hidden_size,
@@ -140,31 +162,20 @@ class RnnTest(jtu.JaxTestCase):
          dropout=False,
          bidirectional=bidirectional)

    jtu.check_grads(f, (x, h_0, c_0, weights), modes=['rev'], order=1)
      loss = jnp.sum(y_ref)

    # TODO(sharadmv): enable when lstm_ref works with seq_lengths
    # W_ih, W_hh, b_ih, b_hh = rnn.unpack_lstm_weights(weights, input_size,
    #                                                  hidden_size, num_layers,
    #                                                  bidirectional)
    # y_ref, h_n_ref, c_n_ref = rnn.lstm_ref(
    #     x,
    #     h_0,
    #     c_0,
    #     W_ih,
    #     W_hh,
    #     b_ih,
    #     b_hh,
    #     seq_lengths=seq_lengths,
    #     input_size=input_size,
    #     hidden_size=hidden_size,
    #     num_layers=num_layers,
    #     dropout=False,
    #     bidirectional=bidirectional)
      return loss, (y_ref, h_n_ref, c_n_ref)

    # np.testing.assert_allclose(y_ref, y, rtol=1e-05, atol=1e-5)
    # np.testing.assert_allclose(h_n_ref, h_n, rtol=1e-05, atol=1e-5)
    # np.testing.assert_allclose(c_n_ref, c_n, rtol=1e-05, atol=1e-5)
    (loss_ref, (y_ref, h_n_ref, c_n_ref)), grad_ref = f(weights, x, h_0, c_0)

    self.assertFalse(np.isnan(loss_ref))
    self.assertFalse(np.isnan(grad_ref).any())

    self.assertEqual(y_ref.shape, (batch_size, seq_len, num_directions * hidden_size))

    for i in range(batch_size):
      y_padded = y_ref[i, seq_lengths[i]:]
      np.testing.assert_allclose(y_padded, jnp.zeros_like(y_padded))


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())
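The reference test above builds f with @partial(jax.value_and_grad, has_aux=True), so a single call returns ((loss, aux), grad). A minimal sketch of that pattern with a toy function (not the test code itself):

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.value_and_grad, has_aux=True)
def loss_fn(w, x):
  y = w * x
  return jnp.sum(y ** 2), y  # (scalar loss, auxiliary output)

(loss, y), grad = loss_fn(2.0, jnp.arange(3.))
print(loss, grad)  # 20.0 20.0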