Make rnn_bwd_abstract_eval backwards compatible by guarding it agains the jaxlib version

PiperOrigin-RevId: 529260653
This commit is contained in:
Yash Katariya 2023-05-03 19:28:08 -07:00 committed by jax authors
parent c15f30f22e
commit 47fc23d7ba

View File

@ -476,11 +476,18 @@ def lstm_bwd(input_size: int, hidden_size: int, num_layers: int, dropout: float,
return (dx, dh_0, dc_0, dw, jnp.zeros_like(seq_lengths))
def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval,
w_aval, y_aval, reserve_space_aval,
if jax._src.lib.version < (0, 4, 9):
def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval,
w_aval, y_aval, workspace_aval, reserve_space_aval,
seq_lengths_aval, input_size: int, hidden_size: int,
num_layers: int, dropout: float, bidirectional: bool):
return x_aval, h0_aval, c0_aval, w_aval
return x_aval, h0_aval, c0_aval, w_aval
else:
def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval, # type: ignore
w_aval, y_aval, reserve_space_aval,
seq_lengths_aval, input_size: int, hidden_size: int,
num_layers: int, dropout: float, bidirectional: bool):
return x_aval, h0_aval, c0_aval, w_aval
rnn_bwd_p = core.Primitive('rnn_bwd')