Merge pull request #16150 from jakevdp:loop-error

PiperOrigin-RevId: 536913485
This commit is contained in:
jax authors 2023-05-31 21:54:25 -07:00
commit 15299eb2ee

View File

@ -1471,8 +1471,8 @@ def _while_partial_eval_custom(saveable, unks_in, inst_in, eqn):
def _while_transpose_error(*_, **kwargs):
raise ValueError("Reverse-mode differentiation does not work for "
"lax.while_loop or lax.fori_loop. "
"Try using lax.scan instead.")
"lax.while_loop or lax.fori_loop with dynamic start/stop values. "
"Try using lax.scan, or using fori_loop with static start/stop.")
# For a while loop with ordered effects in the cond, we need a special
# lowering. Fundamentally, we'd like to rewrite a while loop that looks like