202 Commits

Author SHA1 Message Date
YouJiacheng
b485b8e5ce implement scipy.cluster.vq.vq
also add no check_finite and overwrite_* docstring for some scipy.linalg functions
2022-04-23 03:14:32 +08:00
Jake VanderPlas
5782210174 CI: fix flake8 ignore declarations 2022-04-21 13:44:12 -07:00
Alex Riley
869596fc2c Add jax.scipy.linalg.rsf2csf 2022-04-06 21:06:23 +01:00
Yotaro Kubo
a7fd751acf Add istft to jax.scipy.signal. 2022-04-01 14:28:53 +09:00
jax authors
54a6e4dad3 Merge pull request #9422 from yotarok:signal_stft
PiperOrigin-RevId: 429377655
2022-02-17 12:46:12 -08:00
Yotaro Kubo
e085370ec4 Add some functions for spectral analysis.
This commit adds "stft", "csd", and "welch" functions in scipy.signal.
2022-02-17 15:59:24 +09:00
Leello Tadesse Dadi
cb732323f3 adds jax.scipy.linalg.sqrtm 2022-02-16 22:33:47 +01:00
Leello Tadesse Dadi
514d8883ce adds jax.scipy.schur 2022-02-16 22:33:37 +01:00
jax authors
16c809ce7f Merge pull request #8625 from Edenhofer:regular_grid_interpolator
PiperOrigin-RevId: 423876494
2022-01-24 12:02:34 -08:00
Gordian Edenhofer
2c5fe8c40d Implement SciPy's RegularGridInterpolator
Resolves #8572 .
2022-01-04 22:10:36 +01:00
Peter Hawkins
4e21922055 Use imports relative to the jax package consistently, rather than .-relative imports.
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.

PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
Peter Hawkins
256e7220ff [JAX] Fix pylint errors.
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...

PiperOrigin-RevId: 400858477
2021-10-04 17:54:46 -07:00
Jake VanderPlas
33e2bed1b4 Fix package exports 2021-09-14 13:55:55 -07:00
jax authors
5c58afa647 Merge pull request #7855 from cyprienc:scipy-stats-nbinom
PiperOrigin-RevId: 396617949
2021-09-14 09:49:50 -07:00
Cyprien
8c8f0a8c71 Feat: scipy.stats.nbinom implementation
fix: increasing tolerance check for testNBinomLogPmf in scipy_stats_test.py
2021-09-14 08:05:42 +01:00
Jake VanderPlas
245581411e Add PEP484-compatible export for jax and its subpackages 2021-09-13 14:08:48 -07:00
Jonathan Terhorst
fec72e1852 add support for scipy.special.{expn,expi,exp1} 2021-08-24 16:36:10 -04:00
Julius Kunze
6d83027b69 Support scipy.fft.dct/dctn type=2 2021-08-17 18:56:44 +02:00
Peter Hawkins
0dfd76af97 Remove additional info return value from jax.scipy.linalg.polar(). 2021-07-20 13:13:31 -04:00
Adam Lewis
a2073ffcc2 Adds an implementation of a QR-based Dynamically Weighted Halley iteration. 2021-07-20 11:30:36 -04:00
tlu7
d97b393694 Adds spherical harmonics.
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
2021-07-02 10:42:29 -07:00
tlu7
095e6507b9 Support value computation of associated Legendre functions.
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
2021-06-14 14:51:37 -07:00
tlu7
a02bf59233 Adds associated Legendre functions of the first kind.
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
2021-06-02 11:37:37 -07:00
Peter Hawkins
97e89bde18 Add a tridiagonal eigh solver. 2021-05-04 20:43:41 -04:00
sunilkpai
997ad31670 added bicgstab to new jax repo
fixed some bugs in the bicgstab method and adjusted tolerance for scipy comparison

fixed flake8

added some tests for gradients, fixed symmetry checks, modified lax.cond -> jnp.where

comment out gmres grad check, to be addressed on future PR

increasing tolerance for bicgstab grad test

change to order 1 checks for bicgstab (gmres still fails in order 1) for internal CI check

remove grad checks for now

changing tolerance to pass numpy comparison test
2021-02-18 18:01:28 -08:00
Peter Hawkins
dbeb1a5273 Remove message field from OptimizeResults to allow for vmapping.
Add OptimizeResults to the documentation.
2021-02-16 20:43:53 -05:00
Demetri
a3ad787402 Add betabinomial logpmf/pmf and tests
Squash all changes to single commit.  Add betabinom

Add tests for betabinom. nan where undefefined

squash
2021-02-05 15:13:09 -05:00
Demetri
48864a665b Add logdf and pdf for chisquare distribution
Add tests

Lint with flake8 fails.  Should pass now

newline at end of file for flake8

docs and changes

remove whitespace in changeloc
2021-02-03 15:09:21 -05:00
Jonathan Terhorst
1524b82189 add support for scipy.stats.poisson.cdf 2021-01-24 16:15:31 +00:00
Stephan Hoyer
6cc5b28327 Cleanup/fixup jax.scipy.sparse.linalg.gmres and expose it publicly. 2020-12-03 09:23:00 -08:00
Peter Hawkins
94cd2046fa [JAX] Move implementation of jax.scipy.sparse.linalg into jax._src.
PiperOrigin-RevId: 343276958
2020-11-19 06:18:09 -08:00
Stephan Hoyer
7e62270e5a More unit-tests + mark gmres as internal for now 2020-11-10 21:02:51 -08:00
Stephan Hoyer
c5c71a0a37 Cleanup _safe_normalize 2020-11-08 17:04:00 -08:00
Adam GM Lewis
7ed9fe70ea Corrections to GMRES - now gives correct result.
Co-authored-by: gehring <clement.gehring@gmail.com>

Co-authored-by: Stephan Hoyer <shoyer@google.com>
2020-11-08 15:37:50 -08:00
gehring
342cc36051 Initial implementation of GMRES 2020-11-08 15:34:56 -08:00
Peter Hawkins
c0b480bd51 Add missing jax.scipy.stats distributions to the docs.
Alphabetize import order.
2020-10-20 09:47:35 -04:00
Peter Hawkins
6acb46516e Move most of the implementation of jax.scipy into jax._src.scipy. 2020-10-16 17:04:25 -04:00
Qiao Zhang
49a01d36c8 Use jvp(expm) to compute expm_frechet. 2020-09-17 21:25:40 -07:00
Qiao Zhang
b3a098747a
Make expm transposable and remove custom_jvp rule. (#4314)
* Make expm transposable and remove custom_jvp rule.

* Add check_grads for up to 2nd order derivative.
2020-09-17 21:10:54 -07:00
Peter Hawkins
cf65f6b24e
Change lax_linalg.lu to return a permutation representation of the partial pivoting information. (#4241)
The permutation is more efficiently computed during the decomposition on TPU, and the only use case that would not require us to compute it would be for evaluating determinants.
2020-09-10 11:16:35 -04:00
Jake Vanderplas
d551cec6e8
Add Bessel functions in jax.numpy & jax.scipy.special (#4007) 2020-08-10 10:10:59 -07:00
Stephan Hoyer
242b3249c6
Add missing license headers (#3899)
Oops!
2020-07-29 14:22:21 -07:00
Joshua George Albert
02009e0cf0
BFGS algorithm (#3101)
* BFGS algorithm
Addressing https://github.com/google/jax/issues/1400

* * addresses @shoyer comments of PR

* * skip dtype checks

* * backslash in docstring

* * increase closeness tol

* * increase closeness atol to 1.6e-6

* * addresses jakevdp comments

* * same line search as scipy
* same results format
* same (and more) testing as in scipy for line search and bfgs
* 2 spacing
* documenting
* analytic hessian non default but still available
* NamedTuple classes

* * small fix in setup_method

* * small doc string addition

* * increase atol to 2e-5 for comparison

* * removed experimental analytic_hessian
* using jnp.where for all binary replace operations
* removed _nojit as this is what disable_jit does

* * fix indentation mangling
* remove remaining _nojit

* * fixing more indentation mangling

* * segregate third_party test

* * use parametrise

* * use parametrise

* * minor nitpicking

* * fix some errors

* * use _CompileAndCheck

* * replace f_0 and g_0 for (ugly) scipy variable names

* * remove unused function

* * fix spacing

* * add args argument to minimize
* adhere fmin_bfgs to scipy api

* * remove unused function

* * ignore F401

* * look into unittest

* * fix unittest error

* * delete unused function
* more adherence to scipy's api
* add scipy's old_old_fval arg though unused
* increase line_search default maxiter to 20 (10 not enough in some cases)

* * remove unused imports

* * add ord=norm to the initial convergence check

* * remove helper function

* * merge jax/master

* * Resolve a remnant conflict from merging master to solve ReadTheDocs issue.

* * Add an informative termination message and status number.

* Revert changes to unrelated files

* cleanup bfgs_minimize

* cleanup minimize.py

* Move minimize_bfgs.py to _bfgs.py

* Move more modules around

* improve docs

* high precision einsum

* Formatting in line search

* fixup

* Type checking

* fix mypy failures

* minor fixup

Co-authored-by: Stephan Hoyer <shoyer@google.com>
2020-07-29 14:14:40 -07:00
Du Phan
0a3a5bbb16
address nan issue (#3777) 2020-07-22 09:17:06 -07:00
Claudio Fantacci
150d028d9d
Update scipy.ndimage.map_coordinates docstring (#3762) 2020-07-15 11:19:40 -07:00
Stephan Hoyer
36eb137dd3
Refine argument validation inside jax.scipy.sparse.linalg.cg (#3630)
Now we check tree structure and leaf shapes separately. This allow us to
support pytrees that either don't define equality or that define it
inconsistently (e.g., elementwise like NumPy) with builtin data structures like
list/dict.
2020-07-06 09:24:44 -07:00
James Bradbury
039b6e2ed9
Use precision=HIGHEST in expm repeated squaring (#3601)
* Use precision=HIGHEST in expm repeated squaring

This will improve the output accuracy on TPUs, and doesn't currently affect other platforms.

Also remove a spurious duplicate line.

* Also use HIGHEST precision in Pade approximants

* Update linalg.py

* Update linalg.py
2020-06-29 15:46:38 -07:00
Sri Hari Krishna Narayanan
7b57dc8c80
Issue1635 expm frechet (#2062)
* Implement Frechet derivatives for expm.

* Update expm to use the current custom gradients API.

Make some stylistic fixes.

Co-authored-by: Peter Hawkins <phawkins@google.com>
2020-06-28 12:11:12 -04:00
David Pfau
9d173c6225
Support b and return_sign in scipy.special.logsumexp (#3488) 2020-06-23 15:36:45 -07:00
Jake Vanderplas
33c455a1a8
Add jax.scipy.signal.detrend (#3516) 2020-06-22 19:49:00 -07:00