Mirror of https://github.com/ROCm/jax.git
Fix test breakage in RNN test with old jaxlibs.
Remove some outdated version guards.
commit f52926e832
parent 512216056f
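Most of the changes below follow one pattern: a branch on the installed jaxlib version is collapsed to its modern side once the minimum supported jaxlib makes the old side dead code. A minimal, self-contained sketch of that pattern follows; _bind_stub and the fwd_* names are illustrative stand-ins, not the real rnn.py code.

import jax

def _bind_stub(x, num_outputs):
  # Stand-in for rnn_fwd_p.bind; the real primitive lowers to cuDNN's RNN kernel.
  return (x,) * num_outputs

# Before this commit: a branch kept alive for jaxlib < 0.4.9, whose forward
# pass also returned a separate cuDNN workspace buffer saved as a residual.
def fwd_with_guard(x):
  if jax._src.lib.version < (0, 4, 9):
    y, h_n, c_n, workspace, reserve_space = _bind_stub(x, 5)
    return (y, h_n, c_n), (workspace, reserve_space)
  else:
    y, h_n, c_n, reserve_space = _bind_stub(x, 4)
    return (y, h_n, c_n), (reserve_space,)

# After: the minimum supported jaxlib is past 0.4.9, so only one path survives.
def fwd_cleaned(x):
  y, h_n, c_n, reserve_space = _bind_stub(x, 4)
  return (y, h_n, c_n), (reserve_space,)

print(fwd_cleaned(1.0))

The jax._src.lib.version >= (0, 4, 17) checks around the lowering registrations are kept, presumably because the minimum supported jaxlib still predated 0.4.17 when this landed.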
@@ -397,8 +397,7 @@ def lstm_fwd(x: Array, h_0: Array, c_0: Array, w: Array, seq_lengths: Array,
if seq_lengths.dtype != jnp.dtype("int32"):
raise NotImplementedError("`seq_lengths` can only be int32.")
cudnn_allow_tf32 = _lstm_cudnn_allow_tf32(precision)
if jax._src.lib.version < (0, 4, 9):
y, h_n, c_n, workspace, reserve_space = rnn_fwd_p.bind(
y, h_n, c_n, reserve_space = rnn_fwd_p.bind(
x,
h_0,
c_0,
@@ -410,22 +409,7 @@ def lstm_fwd(x: Array, h_0: Array, c_0: Array, w: Array, seq_lengths: Array,
dropout=dropout,
bidirectional=bidirectional,
cudnn_allow_tf32=cudnn_allow_tf32)
return (y, h_n, c_n), (x, h_0, c_0, w, seq_lengths, y, workspace,
reserve_space)
else:
y, h_n, c_n, reserve_space = rnn_fwd_p.bind(
x,
h_0,
c_0,
w,
seq_lengths,
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
bidirectional=bidirectional,
cudnn_allow_tf32=cudnn_allow_tf32)
return (y, h_n, c_n), (x, h_0, c_0, w, seq_lengths, y, reserve_space)
return (y, h_n, c_n), (x, h_0, c_0, w, seq_lengths, y, reserve_space)


def rnn_abstract_eval(x_aval, h_0_aval, c_0_aval, w_aval, seq_lengths_aval,
@@ -436,94 +420,71 @@ def rnn_abstract_eval(x_aval, h_0_aval, c_0_aval, w_aval, seq_lengths_aval,
num_directions = 2 if bidirectional else 1
output_shape = (batch_size, max_seq_length, num_directions * hidden_size)
output_aval = core.ShapedArray(output_shape, x_aval.dtype)
if jax._src.lib.version < (0, 4, 9):
workspace_size, reserve_space_size = (
gpu_rnn.compute_rnn_workspace_reserve_space_sizes( # pytype: disable=attribute-error
input_size, hidden_size, num_layers, batch_size, max_seq_length,
dropout, bidirectional, cudnn_allow_tf32))
workspace_aval = core.ShapedArray((workspace_size,), jnp.float32)
reserve_space_aval = core.ShapedArray((reserve_space_size,), jnp.float32)
return output_aval, h_0_aval, c_0_aval, workspace_aval, reserve_space_aval
else:
if jax._src.lib.version >= (0, 4, 17):
_, reserve_space_size = (
gpu_rnn.compute_rnn_workspace_reserve_space_sizes( # pytype: disable=attribute-error
input_size, hidden_size, num_layers, batch_size, max_seq_length,
dropout, bidirectional, cudnn_allow_tf32))
reserve_space_aval = core.ShapedArray((reserve_space_size,), jnp.float32)
return output_aval, h_0_aval, c_0_aval, reserve_space_aval
else:
_, reserve_space_size = (
gpu_rnn.compute_rnn_workspace_reserve_space_sizes( # pytype: disable=attribute-error
input_size, hidden_size, num_layers, batch_size, max_seq_length,
dropout, bidirectional))
reserve_space_aval = core.ShapedArray((reserve_space_size,), jnp.float32)
return output_aval, h_0_aval, c_0_aval, reserve_space_aval


def _gpu_lowering_strip_tf32(fn, *args, cudnn_allow_tf32, **kw):
del cudnn_allow_tf32
return fn(*args, **kw)

rnn_fwd_p = core.Primitive('rnn_fwd')
rnn_fwd_p.multiple_results = True
rnn_fwd_p.def_impl(partial(xla.apply_primitive, rnn_fwd_p))
rnn_fwd_p.def_abstract_eval(rnn_abstract_eval)
if gpu_rnn:
mlir.register_lowering(rnn_fwd_p, gpu_rnn.cudnn_rnn_lowering, platform='cuda')
if jax._src.lib.version >= (0, 4, 17):
mlir.register_lowering(rnn_fwd_p, gpu_rnn.cudnn_rnn_lowering, platform='cuda')
else:
mlir.register_lowering(
rnn_fwd_p,
partial(_gpu_lowering_strip_tf32, gpu_rnn.cudnn_rnn_lowering),
platform='cuda'
)


def lstm_bwd(input_size: int, hidden_size: int, num_layers: int, dropout: float,
bidirectional: bool, precision: lax.PrecisionLike,
residuals, gradients):
cudnn_allow_tf32 = _lstm_cudnn_allow_tf32(precision)
if jax._src.lib.version < (0, 4, 9):
x, h_0, c_0, w, seq_lengths, y, workspace, reserve_space = residuals
dy, dh_n, dc_n = gradients
dx, dh_0, dc_0, dw = rnn_bwd_p.bind(
dy,
dh_n,
dc_n,
x,
h_0,
c_0,
w,
y,
workspace,
reserve_space,
seq_lengths,
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
bidirectional=bidirectional,
cudnn_allow_tf32=cudnn_allow_tf32)
return (dx, dh_0, dc_0, dw, jnp.zeros_like(seq_lengths))
else:
x, h_0, c_0, w, seq_lengths, y, reserve_space = residuals
dy, dh_n, dc_n = gradients
dx, dh_0, dc_0, dw = rnn_bwd_p.bind(
dy,
dh_n,
dc_n,
x,
h_0,
c_0,
w,
y,
reserve_space,
seq_lengths,
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
bidirectional=bidirectional,
cudnn_allow_tf32=cudnn_allow_tf32)
return (dx, dh_0, dc_0, dw, jnp.zeros_like(seq_lengths))
x, h_0, c_0, w, seq_lengths, y, reserve_space = residuals
dy, dh_n, dc_n = gradients
dx, dh_0, dc_0, dw = rnn_bwd_p.bind(
dy,
dh_n,
dc_n,
x,
h_0,
c_0,
w,
y,
reserve_space,
seq_lengths,
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
bidirectional=bidirectional,
cudnn_allow_tf32=cudnn_allow_tf32)
return (dx, dh_0, dc_0, dw, jnp.zeros_like(seq_lengths))


if jax._src.lib.version < (0, 4, 9):
def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval,
w_aval, y_aval, workspace_aval, reserve_space_aval,
def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval, # type: ignore
w_aval, y_aval, reserve_space_aval,
seq_lengths_aval, input_size: int, hidden_size: int,
num_layers: int, dropout: float, bidirectional: bool,
cudnn_allow_tf32: bool):
return x_aval, h0_aval, c0_aval, w_aval
else:
def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval, # type: ignore
w_aval, y_aval, reserve_space_aval,
seq_lengths_aval, input_size: int, hidden_size: int,
num_layers: int, dropout: float, bidirectional: bool,
cudnn_allow_tf32: bool):
return x_aval, h0_aval, c0_aval, w_aval
return x_aval, h0_aval, c0_aval, w_aval


rnn_bwd_p = core.Primitive('rnn_bwd')
@@ -531,7 +492,14 @@ rnn_bwd_p.multiple_results = True
rnn_bwd_p.def_impl(partial(xla.apply_primitive, rnn_bwd_p))
rnn_bwd_p.def_abstract_eval(rnn_bwd_abstract_eval)
if gpu_rnn:
mlir.register_lowering(
rnn_bwd_p, gpu_rnn.cudnn_rnn_bwd_lowering, platform='cuda')
if jax._src.lib.version >= (0, 4, 17):
mlir.register_lowering(
rnn_bwd_p, gpu_rnn.cudnn_rnn_bwd_lowering, platform='cuda')
else:
mlir.register_lowering(
rnn_bwd_p,
partial(_gpu_lowering_strip_tf32, gpu_rnn.cudnn_rnn_bwd_lowering),
platform='cuda'
)

lstm.defvjp(lstm_fwd, lstm_bwd)
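The one version check that survives in the lowering registrations above is the jaxlib 0.4.17 guard: older lowering rules do not accept a cudnn_allow_tf32 keyword, so they are wrapped with _gpu_lowering_strip_tf32 via functools.partial to discard it. Below is a small self-contained sketch of that wrapping pattern; old_style_lowering is a stand-in, not the real gpu_rnn.cudnn_rnn_bwd_lowering.

from functools import partial

def _gpu_lowering_strip_tf32(fn, *args, cudnn_allow_tf32, **kw):
  # Same shim as in the diff above: drop the keyword the wrapped rule does
  # not understand, then forward everything else unchanged.
  del cudnn_allow_tf32
  return fn(*args, **kw)

def old_style_lowering(x, *, dropout):
  # Stand-in for a pre-0.4.17 lowering rule with no cudnn_allow_tf32 parameter.
  return ("lowered", x, dropout)

wrapped = partial(_gpu_lowering_strip_tf32, old_style_lowering)
# Callers can now pass cudnn_allow_tf32 uniformly; the wrapper strips it.
print(wrapped(1.0, dropout=0.0, cudnn_allow_tf32=True))

Registering the wrapped rule keeps the primitive's parameter list uniform across jaxlib versions instead of branching at every bind site.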
@@ -16,7 +16,6 @@ from absl.testing import absltest
import numpy as np
import jax
import jax.numpy as jnp
from jax._src import lib
from jax._src import test_util as jtu
from jax.experimental import rnn

@@ -40,11 +39,6 @@ class RnnTest(jtu.JaxTestCase):
@jax.default_matmul_precision("float32")
def test_lstm(self, batch_size: int, seq_len: int, input_size: int,
hidden_size: int, num_layers: int, bidirectional: bool):
if lib.version < (0, 4, 7):
# TODO(sharadmv, zhangqiaorjc): remove this when minimum jaxlib version is
# bumped
self.skipTest("Need latest jaxlib for this test to pass.")

# TODO(phawkins): Partially disable this on cudnn version per b/281071013
if (batch_size == 1 and seq_len == 4 and input_size == 1 and
hidden_size == 6 and num_layers == 4 and bidirectional == False):
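The test change is the same cleanup on the other side: the version-gated skip for jaxlib < 0.4.7 is dead code once the minimum supported jaxlib is newer than that, so it goes away. For reference, a self-contained sketch of the skip pattern being removed; the test body here is a trivial placeholder, not the real RNN test.

from absl.testing import absltest
from jax._src import lib
from jax._src import test_util as jtu

class VersionGuardedTest(jtu.JaxTestCase):

  def test_with_old_jaxlib_guard(self):
    if lib.version < (0, 4, 7):
      # The pattern this commit deletes: skip on jaxlib builds that predate
      # the feature under test.
      self.skipTest("Need latest jaxlib for this test to pass.")
    self.assertEqual(1 + 1, 2)

if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())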