mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Make rnn_bwd_abstract_eval backwards compatible by guarding it agains the jaxlib version
PiperOrigin-RevId: 529260653
This commit is contained in:
parent
c15f30f22e
commit
47fc23d7ba
@ -476,11 +476,18 @@ def lstm_bwd(input_size: int, hidden_size: int, num_layers: int, dropout: float,
|
||||
return (dx, dh_0, dc_0, dw, jnp.zeros_like(seq_lengths))
|
||||
|
||||
|
||||
def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval,
|
||||
w_aval, y_aval, reserve_space_aval,
|
||||
if jax._src.lib.version < (0, 4, 9):
|
||||
def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval,
|
||||
w_aval, y_aval, workspace_aval, reserve_space_aval,
|
||||
seq_lengths_aval, input_size: int, hidden_size: int,
|
||||
num_layers: int, dropout: float, bidirectional: bool):
|
||||
return x_aval, h0_aval, c0_aval, w_aval
|
||||
return x_aval, h0_aval, c0_aval, w_aval
|
||||
else:
|
||||
def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval, # type: ignore
|
||||
w_aval, y_aval, reserve_space_aval,
|
||||
seq_lengths_aval, input_size: int, hidden_size: int,
|
||||
num_layers: int, dropout: float, bidirectional: bool):
|
||||
return x_aval, h0_aval, c0_aval, w_aval
|
||||
|
||||
|
||||
rnn_bwd_p = core.Primitive('rnn_bwd')
|
||||
|
Loading…
x
Reference in New Issue
Block a user