364 Commits

Author SHA1 Message Date
Jake VanderPlas
8a8a48a926 multivariate_normal.logpdf: add (unimplemented) allow_singular argument 2020-12-09 11:39:06 -08:00
Jake VanderPlas
f74235cdae X32 tests: fail on dtype warnings 2020-12-08 13:03:30 -08:00
Stephan Hoyer
cd9f6cccbf Support ndarrays as arguments to cg and gmres
This is consistent with SciPy, and makes things a little bit less
surprising for users.
2020-12-04 12:53:45 -08:00
Stephan Hoyer
6cc5b28327 Cleanup/fixup jax.scipy.sparse.linalg.gmres and expose it publicly. 2020-12-03 09:23:00 -08:00
Jake VanderPlas
c43cfbd8d1 Better error for jsp.special.multigammaln 2020-12-01 13:31:10 -08:00
Jake VanderPlas
9d2f6148ed Call asarray() rather than array() to avoid host round-trips. 2020-11-24 16:05:48 -08:00
Jamie Townsend
5fccc89a42 Add derivatives for eigenvalues (not eigenvectors)
We aren't supporting eigenvectors for now because eigenvectors are not
uniquely determined by the input matrix, they're only determined up to
'gauge' (that is multiplication by a complex scalar with absolute value
1). Note, this means that second derivatives aren't supported, because
they involve differentiating the eigvals jvp, which itself depends on
eigenvectors.
2020-11-20 16:41:40 +00:00
Peter Hawkins
94cd2046fa [JAX] Move implementation of jax.scipy.sparse.linalg into jax._src.
PiperOrigin-RevId: 343276958
2020-11-19 06:18:09 -08:00
ayush-1506
a3c729b97a Fix #4775 + additional fixes 2020-11-09 10:40:14 +05:30
Peter Hawkins
575a8e0668 Move lax linear algebra routines into a jax.lax.linalg module.
PiperOrigin-RevId: 340717634
2020-11-04 13:36:28 -08:00
Peter Hawkins
c57bbb3cea [JAX] Move jax/lax_linalg.py to jax/_src/lax/linalg.py.
Because we now have a facade around the lax library, we can expose the lax_linalg primitives directly in lax without creating circular dependency problems.

Leave a few forwarding stubs to be removed later.

PiperOrigin-RevId: 340658800
2020-11-04 08:59:36 -08:00
joshuaalbert
231168d480 all changes plus test verifcation on TPU squashed 2020-10-26 23:58:09 +01:00
Peter Hawkins
aa107cf1f4 Move jax.numpy internals into jax._src.numpy. 2020-10-16 20:35:19 -04:00
Peter Hawkins
6acb46516e Move most of the implementation of jax.scipy into jax._src.scipy. 2020-10-16 17:04:25 -04:00