37 Commits

Author SHA1 Message Date
Peter Hawkins
26e9ebcdae Move jax.api to jax._src.api.
PiperOrigin-RevId: 368233837
2021-04-13 09:43:24 -07:00
Matthew Johnson
af59542d00 Re-applying the changes in #6014, after they had to be rolled-back.
PiperOrigin-RevId: 364200195
2021-03-21 13:40:20 -07:00
jax authors
4f8814a760 Copybara import of the project:
--
bf15ba5310d5f9009571928f70548bcbc7e856c3 by Matthew Johnson <mattjj@google.com>:

don't device transfer in convert_element_type

Co-authored-by: Qiao Zhang <zhangqiaorjc@google.com>
PiperOrigin-RevId: 363995032
2021-03-19 16:35:37 -07:00
Matthew Johnson
bf15ba5310 don't device transfer in convert_element_type
Co-authored-by: Qiao Zhang <zhangqiaorjc@google.com>
2021-03-19 13:42:33 -07:00
Jake VanderPlas
7580876e31 jax.scipy.special.logsumexp: support integer input 2021-03-04 14:41:30 -08:00
Jake VanderPlas
2c623d5837 jax.scipy.special.logsumexp: fix b=0 corner case 2021-02-26 17:05:32 -08:00
Jake VanderPlas
d8074aca10 DOC: fix update_numpydoc to prevent deleting docstring content 2021-02-22 16:45:40 -08:00
sunilkpai
997ad31670 added bicgstab to new jax repo
fixed some bugs in the bicgstab method and adjusted tolerance for scipy comparison

fixed flake8

added some tests for gradients, fixed symmetry checks, modified lax.cond -> jnp.where

comment out gmres grad check, to be addressed on future PR

increasing tolerance for bicgstab grad test

change to order 1 checks for bicgstab (gmres still fails in order 1) for internal CI check

remove grad checks for now

changing tolerance to pass numpy comparison test
2021-02-18 18:01:28 -08:00
Jake VanderPlas
e66ae17e05 BUG: fix jax.scipy.stats.dirichlet implementation & add tests 2021-02-17 09:40:18 -08:00
Peter Hawkins
dbeb1a5273 Remove message field from OptimizeResults to allow for vmapping.
Add OptimizeResults to the documentation.
2021-02-16 20:43:53 -05:00
Peter Hawkins
1de5734d63 Enforce the the args argument to jax.scipy.optimize.minimize is a tuple. 2021-02-16 15:57:46 -05:00
jax authors
18f4f6910f Merge pull request #5619 from Dpananos:add_beta_binom
PiperOrigin-RevId: 355927750
2021-02-05 14:22:44 -08:00
Demetri
a3ad787402 Add betabinomial logpmf/pmf and tests
Squash all changes to single commit.  Add betabinom

Add tests for betabinom. nan where undefefined

squash
2021-02-05 15:13:09 -05:00
Demetri
48864a665b Add logdf and pdf for chisquare distribution
Add tests

Lint with flake8 fails.  Should pass now

newline at end of file for flake8

docs and changes

remove whitespace in changeloc
2021-02-03 15:09:21 -05:00
jax authors
9c7258230e Merge pull request #5584 from jakevdp:fix-multivariate-normal
PiperOrigin-RevId: 355212888
2021-02-02 11:31:05 -08:00
Adam Paszke
e11c4fffac Add support for axis names in jax.scipy.special.logsumexp 2021-02-02 13:11:16 +00:00
Jake VanderPlas
2076f42b19 Fix multivariate_normal.logpdf for batched computation 2021-02-01 10:38:02 -08:00
jax authors
149a744892 Merge pull request #5533 from apaszke:general-collectives
PiperOrigin-RevId: 354942612
2021-02-01 08:15:32 -08:00
Adam Paszke
f86bf12b5a Add support for axis names in jnp.{sum,min,max}
Similarly to `jnp.einsum`, whenever we encounter an extension to the
positional NumPy API (in the case of reductions, the extension is
whenever a non-integer axis is specified), we reroute the call to a
parallel primitive instead of the standard lax reductions.

Note that this makes the parallel primitives implement a strict subset
of functionality of the lax reductions so in the future (when we decide
that we want axes to be truly first class) we can always swap out the
implementation for the parallel version. But, it makes sense to keep
them separate for the ease of prototyping in the near future.
2021-02-01 11:41:05 +00:00
Jake VanderPlas
5d16ab03b5 Minor doc formatting fixes 2021-01-29 16:43:27 -08:00
Jake VanderPlas
cfe934c053 Fix some doc build warnings 2021-01-25 14:08:57 -08:00
Jonathan Terhorst
1524b82189 add support for scipy.stats.poisson.cdf 2021-01-24 16:15:31 +00:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Jake VanderPlas
8a8a48a926 multivariate_normal.logpdf: add (unimplemented) allow_singular argument 2020-12-09 11:39:06 -08:00
Jake VanderPlas
f74235cdae X32 tests: fail on dtype warnings 2020-12-08 13:03:30 -08:00
Stephan Hoyer
cd9f6cccbf Support ndarrays as arguments to cg and gmres
This is consistent with SciPy, and makes things a little bit less
surprising for users.
2020-12-04 12:53:45 -08:00
Stephan Hoyer
6cc5b28327 Cleanup/fixup jax.scipy.sparse.linalg.gmres and expose it publicly. 2020-12-03 09:23:00 -08:00
Jake VanderPlas
c43cfbd8d1 Better error for jsp.special.multigammaln 2020-12-01 13:31:10 -08:00
Jake VanderPlas
9d2f6148ed Call asarray() rather than array() to avoid host round-trips. 2020-11-24 16:05:48 -08:00
Jamie Townsend
5fccc89a42 Add derivatives for eigenvalues (not eigenvectors)
We aren't supporting eigenvectors for now because eigenvectors are not
uniquely determined by the input matrix, they're only determined up to
'gauge' (that is multiplication by a complex scalar with absolute value
1). Note, this means that second derivatives aren't supported, because
they involve differentiating the eigvals jvp, which itself depends on
eigenvectors.
2020-11-20 16:41:40 +00:00
Peter Hawkins
94cd2046fa [JAX] Move implementation of jax.scipy.sparse.linalg into jax._src.
PiperOrigin-RevId: 343276958
2020-11-19 06:18:09 -08:00
ayush-1506
a3c729b97a Fix #4775 + additional fixes 2020-11-09 10:40:14 +05:30
Peter Hawkins
575a8e0668 Move lax linear algebra routines into a jax.lax.linalg module.
PiperOrigin-RevId: 340717634
2020-11-04 13:36:28 -08:00
Peter Hawkins
c57bbb3cea [JAX] Move jax/lax_linalg.py to jax/_src/lax/linalg.py.
Because we now have a facade around the lax library, we can expose the lax_linalg primitives directly in lax without creating circular dependency problems.

Leave a few forwarding stubs to be removed later.

PiperOrigin-RevId: 340658800
2020-11-04 08:59:36 -08:00
joshuaalbert
231168d480 all changes plus test verifcation on TPU squashed 2020-10-26 23:58:09 +01:00
Peter Hawkins
aa107cf1f4 Move jax.numpy internals into jax._src.numpy. 2020-10-16 20:35:19 -04:00
Peter Hawkins
6acb46516e Move most of the implementation of jax.scipy into jax._src.scipy. 2020-10-16 17:04:25 -04:00