An upcoming change to XLA:CPU will disable reassociation on floating point operators by default which is an unsound fast math optimization. This change is being made to fix numerical errors in softmax computations caused by reassocation. After that change, we will enable reassociation only in reduction operators where it is very important for performance and the XLA operator contract allows that.
Since this change alters the order of operations, it may cause small numerical changes leading to test failures. This change relaxes test tolerances to make tests pass.
PiperOrigin-RevId: 431453240
* Cast t_bar from potential complex to float in ode.py
* Add test case for complex odeint (currently failing)
* Wrap odeint into complex-to-real function in test case
* fixup
Co-authored-by: Stephan Hoyer <shoyer@google.com>
fixes#3584
This could use further revision! Left a todo.
The issue is that in #3562 we started closure-converting the dynamics
function (by tracing it to a jaxpr up-front) so as to handle closed-over
constants with respect to which we want to differentiate the odeint
call. But if the dynamics function closes over integer-valued constants,
then we can no longer call `vjp` on the closure-converted function
without getting an error.
One fix would be to support (trivial) differentiation with respect to
integer-valued inputs. That would work if we supperss the error message
for integer-valued inputs in `vjp` and add a trivial tangent space
for integer-valued arrays. Since that's potentially a further-reaching
change, this commit instead just applies a local fix to avoid adding
integer-valued inputs to the dynamics function by adapting the
closure-conversion code.
* fix disable_jit logic in lax.cond, fixes#3093
* fix disable_jit logic in lax.while_loop, fix#2823
* add test for issue #3093
* add test for #2823
* add test for #2598
* refactor ode tests, add scipy benchmark
remove double import
rename to scipy merge vmap test properly
* clean up more global trace state after errors
Co-authored-by: Matthew Johnson <mattjj@google.com>