31 Commits

Author SHA1 Message Date
Will Froom
dc16721b52 [XLA:CPU] Use central difference to calculate numerical gradient
PiperOrigin-RevId: 718383754
2025-01-22 07:49:43 -08:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Jake VanderPlas
cddee4654c tests: access tree utilities via jax.tree.* 2024-02-26 14:17:18 -08:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
jax authors
9dac458e85 Merge pull request #11077 from DanPuzzuoli:ode_dt_max
PiperOrigin-RevId: 484639938
2022-10-28 16:05:25 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
5e97011dc8 [x64] make ode_test pass with strict dtype promotion 2022-06-17 15:21:10 -07:00
DanPuzzuoli
f835716085 changing dtmax to hmax 2022-06-13 06:38:34 -07:00
DanPuzzuoli
85dfca26fe slight test changes 2022-06-12 17:27:05 -07:00
DanPuzzuoli
0935b29865 added dtmax arg and test verifying correct behaviour - errors now with differentation 2022-06-12 16:51:25 -07:00
Peter Hawkins
c339330bc1 [XLA:CPU] Relax test tolerances for tests using XLA:CPU.
An upcoming change to XLA:CPU will disable reassociation on floating point operators by default which is an unsound fast math optimization. This change is being made to fix numerical errors in softmax computations caused by reassocation. After that change, we will enable reassociation only in reduction operators where it is very important for performance and the XLA operator contract allows that.

Since this change alters the order of operations, it may cause small numerical changes leading to test failures. This change relaxes test tolerances to make tests pass.

PiperOrigin-RevId: 431453240
2022-02-28 09:26:54 -08:00
Florian Hopfmueller
38e98d2d7b Fix a bug that promoted t to a complex in odeint, and modify a test so it would have caught it
In odeint, raise error if t is not an array of floats
2022-01-07 12:10:35 -05:00
Jake VanderPlas
df0969961b Testing: avoid hard-coding random seeds 2021-12-10 10:32:09 -08:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Philipp Thölke
603f0c1e58
Fix scan carry types in gradient of complex ODE (#4130)
* Cast t_bar from potential complex to float in ode.py

* Add test case for complex odeint (currently failing)

* Wrap odeint into complex-to-real function in test case

* fixup

Co-authored-by: Stephan Hoyer <shoyer@google.com>
2020-08-24 13:50:44 -07:00
Matthew Johnson
c564aca777
skip more ode tests on gpu, b/c slow to compile (#4028) 2020-08-11 20:36:51 -07:00
Matthew Johnson
ba1b5ce8de
skip some ode tests on gpu for speed (#3629) 2020-07-01 11:26:44 -07:00
Jake Vanderplas
09d128edb3
Cleanup: remove some test interdependence (#3600) 2020-06-29 16:22:05 -07:00
Jake VanderPlas
afce718eb1 Add ability to specify individual test targets 2020-06-29 11:08:57 -07:00
Matthew Johnson
57915dc1d9
odeint: don't hoist non-differentiable consts (#3587)
fixes #3584

This could use further revision! Left a todo.

The issue is that in #3562 we started closure-converting the dynamics
function (by tracing it to a jaxpr up-front) so as to handle closed-over
constants with respect to which we want to differentiate the odeint
call. But if the dynamics function closes over integer-valued constants,
then we can no longer call `vjp` on the closure-converted function
without getting an error.

One fix would be to support (trivial) differentiation with respect to
integer-valued inputs. That would work if we supperss the error message
for integer-valued inputs in `vjp` and add a trivial tangent space
for integer-valued arrays. Since that's potentially a further-reaching
change, this commit instead just applies a local fix to avoid adding
integer-valued inputs to the dynamics function by adapting the
closure-conversion code.
2020-06-28 14:27:07 -07:00
Matthew Johnson
26c6c3a457
fix error when doing forward-mode of odeint (#3566)
fixes #3558
2020-06-25 20:57:34 -07:00
Matthew Johnson
db80ca5dd8
allow closures for odeint dynamics functions (#3562)
* allow closures for odeint dynamics functions

fixes #2718, #3557

* add tests for odeint dynamics closing over tracers
2020-06-25 17:36:17 -07:00
samuela
1eb7f1b13d
Use onp instead of np in ode_test (#3288)
* Use onp instead of np in ode_test

* other ode_test.py fixes

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-06-02 09:54:51 -07:00
samuela
03e2971216
Add a pytree odeint test (#3268) 2020-06-01 10:47:39 -07:00
samuela
9cbd63e86f
Remove unused unittest import (#3269) 2020-06-01 17:23:28 +03:00
Matthew Johnson
42b425d8e5
fix disable_jit logic in lax.cond and lax.while_loop (#3156)
* fix disable_jit logic in lax.cond, fixes #3093

* fix disable_jit logic in lax.while_loop, fix #2823

* add test for issue #3093

* add test for #2823

* add test for #2598
2020-05-19 18:14:10 -07:00
Peter Hawkins
b543652332
Replace np -> jnp, onp -> np in tests. (#2969) 2020-05-05 14:59:16 -04:00
Jacob Kelly
cc0e9a3189
refactor ode tests, add scipy benchmark (#2824)
* refactor ode tests, add scipy benchmark

remove double import

rename to scipy merge vmap test properly

* clean up more global trace state after errors

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-04-27 21:53:38 -07:00
Matthew Johnson
11d7fb051c
add more ode tests (#2819) 2020-04-24 01:47:20 -07:00
Matthew Johnson
6ad2908f8d
add ode test file (#2818)
* add ode test file

* control test tolerances based on precision
2020-04-24 01:21:27 -07:00