Peter Hawkins
1cead779a3
Add support for Hessenberg and tridiagonal matrix reductions on CPU.
...
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.
None of these primitives are differentiable at the moment.
PiperOrigin-RevId: 487224934
2022-11-09 06:23:55 -08:00
Tianjian Lu
3b1ddf2881
[linalg] Add jax.scipy.special.bessel_jn (Bessel function of the first kind).
...
PiperOrigin-RevId: 487146250
2022-11-08 23:03:21 -08:00
Adrian Price-Whelan
5784d61048
implement truncnorm in jax.scipy.stats
...
fix some shape and type issues
import into namespace
imports into non-_src library
working logpdf test
cleanup
working tests for cdf and sf after fixing select
relax need for x to be in (a, b)
ensure behavior with invalid input matches scipy
remove enforcing valid parameters in tests
added truncnorm to docs
whoops alphabetical
fix linter error
fix circular import issue
2022-10-22 15:48:20 -04:00
Yann Lamidon
ccbc3059b0
Add JAX equivalent of scipy.stats.mode
2022-10-18 20:45:02 +01:00
Peter Hawkins
9ab88071a7
Avoid loading scipy eagerly.
...
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
2022-10-12 19:51:09 +00:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Wonhyeong Seo
3f6eb40698
JAX implementation of scipy.stats.multinomial pmf & logpmf
...
Co-authored-by: harryjulian <harry.julian@peak.ai>
2022-09-15 13:21:44 -07:00
Dan F-M
0788d5708a
Implementation of jax.scipy.stats.gaussian_kde
2022-06-28 15:17:12 -04:00
carlosgmartin
57b89ba7cb
Added scipy.stats.gennorm.
2022-06-14 13:38:24 -04:00
Peter Hawkins
7ba36fc178
Change implementation of jax.scipy.linalg.polar() and jax._src.scipy.eigh to use the QDWH decomposition from jax._src.lax.qdwh.
...
Remove jax._src.lax.polar.
PiperOrigin-RevId: 448241206
2022-05-12 07:20:52 -07:00
Alex Riley
372371cec6
Add jax.scipy.linalg.funm
2022-05-02 21:46:41 +01:00
YouJiacheng
b485b8e5ce
implement scipy.cluster.vq.vq
...
also add no check_finite and overwrite_* docstring for some scipy.linalg functions
2022-04-23 03:14:32 +08:00
Jake VanderPlas
5782210174
CI: fix flake8 ignore declarations
2022-04-21 13:44:12 -07:00
Alex Riley
869596fc2c
Add jax.scipy.linalg.rsf2csf
2022-04-06 21:06:23 +01:00
Yotaro Kubo
a7fd751acf
Add istft to jax.scipy.signal.
2022-04-01 14:28:53 +09:00
jax authors
54a6e4dad3
Merge pull request #9422 from yotarok:signal_stft
...
PiperOrigin-RevId: 429377655
2022-02-17 12:46:12 -08:00
Yotaro Kubo
e085370ec4
Add some functions for spectral analysis.
...
This commit adds "stft", "csd", and "welch" functions in scipy.signal.
2022-02-17 15:59:24 +09:00
Leello Tadesse Dadi
cb732323f3
adds jax.scipy.linalg.sqrtm
2022-02-16 22:33:47 +01:00
Leello Tadesse Dadi
514d8883ce
adds jax.scipy.schur
2022-02-16 22:33:37 +01:00
jax authors
16c809ce7f
Merge pull request #8625 from Edenhofer:regular_grid_interpolator
...
PiperOrigin-RevId: 423876494
2022-01-24 12:02:34 -08:00
Gordian Edenhofer
2c5fe8c40d
Implement SciPy's RegularGridInterpolator
...
Resolves #8572 .
2022-01-04 22:10:36 +01:00
Peter Hawkins
4e21922055
Use imports relative to the jax
package consistently, rather than .
-relative imports.
...
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.
PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
Peter Hawkins
256e7220ff
[JAX] Fix pylint errors.
...
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...
PiperOrigin-RevId: 400858477
2021-10-04 17:54:46 -07:00
Jake VanderPlas
33e2bed1b4
Fix package exports
2021-09-14 13:55:55 -07:00
jax authors
5c58afa647
Merge pull request #7855 from cyprienc:scipy-stats-nbinom
...
PiperOrigin-RevId: 396617949
2021-09-14 09:49:50 -07:00
Cyprien
8c8f0a8c71
Feat: scipy.stats.nbinom implementation
...
fix: increasing tolerance check for testNBinomLogPmf in scipy_stats_test.py
2021-09-14 08:05:42 +01:00
Jake VanderPlas
245581411e
Add PEP484-compatible export for jax and its subpackages
2021-09-13 14:08:48 -07:00
Jonathan Terhorst
fec72e1852
add support for scipy.special.{expn,expi,exp1}
2021-08-24 16:36:10 -04:00
Julius Kunze
6d83027b69
Support scipy.fft.dct/dctn type=2
2021-08-17 18:56:44 +02:00
Peter Hawkins
0dfd76af97
Remove additional info return value from jax.scipy.linalg.polar().
2021-07-20 13:13:31 -04:00
Adam Lewis
a2073ffcc2
Adds an implementation of a QR-based Dynamically Weighted Halley iteration.
2021-07-20 11:30:36 -04:00
tlu7
d97b393694
Adds spherical harmonics.
...
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
2021-07-02 10:42:29 -07:00
tlu7
095e6507b9
Support value computation of associated Legendre functions.
...
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
2021-06-14 14:51:37 -07:00
tlu7
a02bf59233
Adds associated Legendre functions of the first kind.
...
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
2021-06-02 11:37:37 -07:00
Peter Hawkins
97e89bde18
Add a tridiagonal eigh solver.
2021-05-04 20:43:41 -04:00
sunilkpai
997ad31670
added bicgstab to new jax repo
...
fixed some bugs in the bicgstab method and adjusted tolerance for scipy comparison
fixed flake8
added some tests for gradients, fixed symmetry checks, modified lax.cond -> jnp.where
comment out gmres grad check, to be addressed on future PR
increasing tolerance for bicgstab grad test
change to order 1 checks for bicgstab (gmres still fails in order 1) for internal CI check
remove grad checks for now
changing tolerance to pass numpy comparison test
2021-02-18 18:01:28 -08:00
Peter Hawkins
dbeb1a5273
Remove message field from OptimizeResults to allow for vmapping.
...
Add OptimizeResults to the documentation.
2021-02-16 20:43:53 -05:00
Demetri
a3ad787402
Add betabinomial logpmf/pmf and tests
...
Squash all changes to single commit. Add betabinom
Add tests for betabinom. nan where undefefined
squash
2021-02-05 15:13:09 -05:00
Demetri
48864a665b
Add logdf and pdf for chisquare distribution
...
Add tests
Lint with flake8 fails. Should pass now
newline at end of file for flake8
docs and changes
remove whitespace in changeloc
2021-02-03 15:09:21 -05:00
Jonathan Terhorst
1524b82189
add support for scipy.stats.poisson.cdf
2021-01-24 16:15:31 +00:00
Stephan Hoyer
6cc5b28327
Cleanup/fixup jax.scipy.sparse.linalg.gmres and expose it publicly.
2020-12-03 09:23:00 -08:00
Peter Hawkins
94cd2046fa
[JAX] Move implementation of jax.scipy.sparse.linalg into jax._src.
...
PiperOrigin-RevId: 343276958
2020-11-19 06:18:09 -08:00
Stephan Hoyer
7e62270e5a
More unit-tests + mark gmres as internal for now
2020-11-10 21:02:51 -08:00
Stephan Hoyer
c5c71a0a37
Cleanup _safe_normalize
2020-11-08 17:04:00 -08:00
Adam GM Lewis
7ed9fe70ea
Corrections to GMRES - now gives correct result.
...
Co-authored-by: gehring <clement.gehring@gmail.com>
Co-authored-by: Stephan Hoyer <shoyer@google.com>
2020-11-08 15:37:50 -08:00
gehring
342cc36051
Initial implementation of GMRES
2020-11-08 15:34:56 -08:00
Peter Hawkins
c0b480bd51
Add missing jax.scipy.stats distributions to the docs.
...
Alphabetize import order.
2020-10-20 09:47:35 -04:00
Peter Hawkins
6acb46516e
Move most of the implementation of jax.scipy into jax._src.scipy.
2020-10-16 17:04:25 -04:00
Qiao Zhang
49a01d36c8
Use jvp(expm) to compute expm_frechet.
2020-09-17 21:25:40 -07:00
Qiao Zhang
b3a098747a
Make expm transposable and remove custom_jvp rule. ( #4314 )
...
* Make expm transposable and remove custom_jvp rule.
* Add check_grads for up to 2nd order derivative.
2020-09-17 21:10:54 -07:00