213 Commits

Author SHA1 Message Date
Peter Hawkins
1cead779a3 Add support for Hessenberg and tridiagonal matrix reductions on CPU.
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.

None of these primitives are differentiable at the moment.

PiperOrigin-RevId: 487224934
2022-11-09 06:23:55 -08:00
Tianjian Lu
3b1ddf2881 [linalg] Add jax.scipy.special.bessel_jn (Bessel function of the first kind).
PiperOrigin-RevId: 487146250
2022-11-08 23:03:21 -08:00
Adrian Price-Whelan
5784d61048 implement truncnorm in jax.scipy.stats
fix some shape and type issues

import into namespace

imports into non-_src library

working logpdf test

cleanup

working tests for cdf and sf after fixing select

relax need for x to be in (a, b)

ensure behavior with invalid input matches scipy

remove enforcing valid parameters in tests

added truncnorm to docs

whoops alphabetical

fix linter error

fix circular import issue
2022-10-22 15:48:20 -04:00
Yann Lamidon
ccbc3059b0 Add JAX equivalent of scipy.stats.mode 2022-10-18 20:45:02 +01:00
Peter Hawkins
9ab88071a7 Avoid loading scipy eagerly.
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
2022-10-12 19:51:09 +00:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Wonhyeong Seo
3f6eb40698 JAX implementation of scipy.stats.multinomial pmf & logpmf
Co-authored-by: harryjulian <harry.julian@peak.ai>
2022-09-15 13:21:44 -07:00
Dan F-M
0788d5708a Implementation of jax.scipy.stats.gaussian_kde 2022-06-28 15:17:12 -04:00
carlosgmartin
57b89ba7cb Added scipy.stats.gennorm. 2022-06-14 13:38:24 -04:00
Peter Hawkins
7ba36fc178 Change implementation of jax.scipy.linalg.polar() and jax._src.scipy.eigh to use the QDWH decomposition from jax._src.lax.qdwh.
Remove jax._src.lax.polar.

PiperOrigin-RevId: 448241206
2022-05-12 07:20:52 -07:00
Alex Riley
372371cec6 Add jax.scipy.linalg.funm 2022-05-02 21:46:41 +01:00
YouJiacheng
b485b8e5ce implement scipy.cluster.vq.vq
also add no check_finite and overwrite_* docstring for some scipy.linalg functions
2022-04-23 03:14:32 +08:00
Jake VanderPlas
5782210174 CI: fix flake8 ignore declarations 2022-04-21 13:44:12 -07:00
Alex Riley
869596fc2c Add jax.scipy.linalg.rsf2csf 2022-04-06 21:06:23 +01:00
Yotaro Kubo
a7fd751acf Add istft to jax.scipy.signal. 2022-04-01 14:28:53 +09:00
jax authors
54a6e4dad3 Merge pull request #9422 from yotarok:signal_stft
PiperOrigin-RevId: 429377655
2022-02-17 12:46:12 -08:00
Yotaro Kubo
e085370ec4 Add some functions for spectral analysis.
This commit adds "stft", "csd", and "welch" functions in scipy.signal.
2022-02-17 15:59:24 +09:00
Leello Tadesse Dadi
cb732323f3 adds jax.scipy.linalg.sqrtm 2022-02-16 22:33:47 +01:00
Leello Tadesse Dadi
514d8883ce adds jax.scipy.schur 2022-02-16 22:33:37 +01:00
jax authors
16c809ce7f Merge pull request #8625 from Edenhofer:regular_grid_interpolator
PiperOrigin-RevId: 423876494
2022-01-24 12:02:34 -08:00
Gordian Edenhofer
2c5fe8c40d Implement SciPy's RegularGridInterpolator
Resolves #8572 .
2022-01-04 22:10:36 +01:00
Peter Hawkins
4e21922055 Use imports relative to the jax package consistently, rather than .-relative imports.
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.

PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
Peter Hawkins
256e7220ff [JAX] Fix pylint errors.
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...

PiperOrigin-RevId: 400858477
2021-10-04 17:54:46 -07:00
Jake VanderPlas
33e2bed1b4 Fix package exports 2021-09-14 13:55:55 -07:00
jax authors
5c58afa647 Merge pull request #7855 from cyprienc:scipy-stats-nbinom
PiperOrigin-RevId: 396617949
2021-09-14 09:49:50 -07:00
Cyprien
8c8f0a8c71 Feat: scipy.stats.nbinom implementation
fix: increasing tolerance check for testNBinomLogPmf in scipy_stats_test.py
2021-09-14 08:05:42 +01:00
Jake VanderPlas
245581411e Add PEP484-compatible export for jax and its subpackages 2021-09-13 14:08:48 -07:00
Jonathan Terhorst
fec72e1852 add support for scipy.special.{expn,expi,exp1} 2021-08-24 16:36:10 -04:00
Julius Kunze
6d83027b69 Support scipy.fft.dct/dctn type=2 2021-08-17 18:56:44 +02:00
Peter Hawkins
0dfd76af97 Remove additional info return value from jax.scipy.linalg.polar(). 2021-07-20 13:13:31 -04:00
Adam Lewis
a2073ffcc2 Adds an implementation of a QR-based Dynamically Weighted Halley iteration. 2021-07-20 11:30:36 -04:00
tlu7
d97b393694 Adds spherical harmonics.
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
2021-07-02 10:42:29 -07:00
tlu7
095e6507b9 Support value computation of associated Legendre functions.
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
2021-06-14 14:51:37 -07:00
tlu7
a02bf59233 Adds associated Legendre functions of the first kind.
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
2021-06-02 11:37:37 -07:00
Peter Hawkins
97e89bde18 Add a tridiagonal eigh solver. 2021-05-04 20:43:41 -04:00
sunilkpai
997ad31670 added bicgstab to new jax repo
fixed some bugs in the bicgstab method and adjusted tolerance for scipy comparison

fixed flake8

added some tests for gradients, fixed symmetry checks, modified lax.cond -> jnp.where

comment out gmres grad check, to be addressed on future PR

increasing tolerance for bicgstab grad test

change to order 1 checks for bicgstab (gmres still fails in order 1) for internal CI check

remove grad checks for now

changing tolerance to pass numpy comparison test
2021-02-18 18:01:28 -08:00
Peter Hawkins
dbeb1a5273 Remove message field from OptimizeResults to allow for vmapping.
Add OptimizeResults to the documentation.
2021-02-16 20:43:53 -05:00
Demetri
a3ad787402 Add betabinomial logpmf/pmf and tests
Squash all changes to single commit.  Add betabinom

Add tests for betabinom. nan where undefefined

squash
2021-02-05 15:13:09 -05:00
Demetri
48864a665b Add logdf and pdf for chisquare distribution
Add tests

Lint with flake8 fails.  Should pass now

newline at end of file for flake8

docs and changes

remove whitespace in changeloc
2021-02-03 15:09:21 -05:00
Jonathan Terhorst
1524b82189 add support for scipy.stats.poisson.cdf 2021-01-24 16:15:31 +00:00
Stephan Hoyer
6cc5b28327 Cleanup/fixup jax.scipy.sparse.linalg.gmres and expose it publicly. 2020-12-03 09:23:00 -08:00
Peter Hawkins
94cd2046fa [JAX] Move implementation of jax.scipy.sparse.linalg into jax._src.
PiperOrigin-RevId: 343276958
2020-11-19 06:18:09 -08:00
Stephan Hoyer
7e62270e5a More unit-tests + mark gmres as internal for now 2020-11-10 21:02:51 -08:00
Stephan Hoyer
c5c71a0a37 Cleanup _safe_normalize 2020-11-08 17:04:00 -08:00
Adam GM Lewis
7ed9fe70ea Corrections to GMRES - now gives correct result.
Co-authored-by: gehring <clement.gehring@gmail.com>

Co-authored-by: Stephan Hoyer <shoyer@google.com>
2020-11-08 15:37:50 -08:00
gehring
342cc36051 Initial implementation of GMRES 2020-11-08 15:34:56 -08:00
Peter Hawkins
c0b480bd51 Add missing jax.scipy.stats distributions to the docs.
Alphabetize import order.
2020-10-20 09:47:35 -04:00
Peter Hawkins
6acb46516e Move most of the implementation of jax.scipy into jax._src.scipy. 2020-10-16 17:04:25 -04:00
Qiao Zhang
49a01d36c8 Use jvp(expm) to compute expm_frechet. 2020-09-17 21:25:40 -07:00
Qiao Zhang
b3a098747a
Make expm transposable and remove custom_jvp rule. (#4314)
* Make expm transposable and remove custom_jvp rule.

* Add check_grads for up to 2nd order derivative.
2020-09-17 21:10:54 -07:00