mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Merge pull request #16150 from jakevdp:loop-error
PiperOrigin-RevId: 536913485
This commit is contained in:
commit
15299eb2ee
@ -1471,8 +1471,8 @@ def _while_partial_eval_custom(saveable, unks_in, inst_in, eqn):
|
||||
|
||||
def _while_transpose_error(*_, **kwargs):
|
||||
raise ValueError("Reverse-mode differentiation does not work for "
|
||||
"lax.while_loop or lax.fori_loop. "
|
||||
"Try using lax.scan instead.")
|
||||
"lax.while_loop or lax.fori_loop with dynamic start/stop values. "
|
||||
"Try using lax.scan, or using fori_loop with static start/stop.")
|
||||
|
||||
# For a while loop with ordered effects in the cond, we need a special
|
||||
# lowering. Fundamentally, we'd like to rewrite a while loop that looks like
|
||||
|
Loading…
x
Reference in New Issue
Block a user