Matthew Johnson e0d2736e37 add custom_jvp for jax.nn.softmax
This avoids saving the jnp.exp(...) value.
2023-04-22 11:28:03 -07:00
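A minimal sketch of the idea, assuming softmax over the last axis only. It is not the code from commit e0d2736e37, and the names `softmax_last_axis` and `_softmax_last_axis_jvp` are hypothetical. The tangent rule uses only the output `y`, since the softmax JVP is `y * (x_dot - sum(y * x_dot))`. Presumably this is why reverse mode can keep `y` as its residual instead of also saving the intermediate `jnp.exp(...)`.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def softmax_last_axis(x):
    """Softmax over the last axis (hypothetical helper; a sketch only)."""
    # Subtracting the row max leaves the result unchanged but avoids overflow in exp.
    shifted = x - jnp.max(x, axis=-1, keepdims=True)
    unnormalized = jnp.exp(shifted)
    return unnormalized / jnp.sum(unnormalized, axis=-1, keepdims=True)


@softmax_last_axis.defjvp
def _softmax_last_axis_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = softmax_last_axis(x)
    # Jacobian of softmax is diag(y) - y y^T, so J @ x_dot = y * (x_dot - <y, x_dot>).
    # The rule depends on y alone; the exp intermediate is not needed as a residual.
    y_dot = y * (x_dot - jnp.sum(y * x_dot, axis=-1, keepdims=True))
    return y, y_dot
```

Because the tangent `y_dot` is linear in `x_dot`, `jax.grad` can transpose this rule automatically, so no separate `custom_vjp` is needed.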