224 Commits

Author SHA1 Message Date
Jake VanderPlas
222b951b19 Use new matrix_transpose in linalg code 2023-05-25 09:32:14 -07:00
Jake VanderPlas
9ac3781c7e grad(entr)(0.0): return inf instead of NaN 2023-04-25 08:32:37 -07:00
jax authors
1de4d14da8 Merge pull request #15656 from laqua-stack:add-special-gamma-fcn
PiperOrigin-RevId: 525566749
2023-04-19 15:28:36 -07:00
Jake VanderPlas
dd023e266e jax.scipy.special: fix gradient for xlogy & xlog1py 2023-04-18 15:56:32 -07:00
laqua-stack
d742733bea feat (scipy.special): Add a xla version of scipy.special.gamma function
- Add gamma fcn api in scipy.special
- Add tests for this purpose
- Add function to the docs

Currently, there is no implementation of the gamma function in jax
but there is one in scipy.special. This breaks some higher level
jit-compilation like in the blackjax backend for pymc. This commit
adds the missing gamma function.

Resolves: #15409
2023-04-18 21:10:22 +02:00
Vaishaal Shankar
add15aca25 implement idct and idctn + add function to scipy.rst 2023-04-17 12:12:51 -07:00
jax authors
0fd5b2ca61 Remove use of int casting in STFT collapse of batch dimensions.
PiperOrigin-RevId: 524115535
2023-04-13 15:15:11 -07:00
Jake VanderPlas
5521423d92 Change np.prod->math.prod
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
2023-04-13 11:48:11 -07:00
Jake VanderPlas
3ca7d67e8d Fully implement and test axes argument to jax.scipy.signal.fftconvolve
PiperOrigin-RevId: 523707411
2023-04-12 08:31:30 -07:00
Jake VanderPlas
d0ed619101 jax.scipy.signal.convolve: support method='fft' 2023-04-10 14:54:15 -07:00
Jean-Eric Campagne
4beee13ba0 Add implementation of jax.scipy.fftconvolve 2023-04-07 17:19:08 +02:00
Peter Hawkins
c1f65fc8b2 Avoid imports from the public jax.* namespace in more places internally.
This change is in preparation for more cycle breaking in the Bazel dependency graph.

PiperOrigin-RevId: 521822756
2023-04-04 11:41:40 -07:00
Peter Hawkins
abf1acf76c Replace references to jax.interpreters with jax._src.interpreters in JAX core.
PiperOrigin-RevId: 520933067
2023-03-31 08:58:00 -07:00
Jake VanderPlas
ad0fc8979b jax.scipy.linalg.expm: support batched inputs 2023-03-27 16:39:48 -07:00
Misha
83b3f5b759 Fix loc and scale parameters in scipy.logistic. Add CDF and SF for several distributions. 2023-03-21 00:16:13 +01:00
Jake VanderPlas
760deb310e Remove leading underscores in jax._src.numpy.util 2023-03-13 12:18:36 -07:00
Jake VanderPlas
c8c269f5f5 internal: avoid unused imports in lax_numpy 2023-03-08 10:29:04 -08:00
Parker Schuh
d62fc88fb1 Roll back #14792
Breaks tests. lax.sub requires arguments to have the same dtypes, got float32, float64. (Tip: jnp.subtract is a similar function that does automatic type promotion on inputs).

PiperOrigin-RevId: 514897538
2023-03-07 18:31:19 -08:00
Misha
feb9ab33af Fixed loc and scale parameters for logistic distribution. CDF and SF have been added for several distributions, including cauchy, gamma, logistic, chi2 and beta. ISF and PPF have also been added for cauchy and logistic. 2023-03-07 07:56:47 +01:00
Peter Hawkins
7b6321cc09 Reenable pytype for numpy ufuncs.
Add a few type annotations to ufuncs so the exported types are more precise.

PiperOrigin-RevId: 513798060
2023-03-03 05:01:03 -08:00
jiayaobo
8b8d7ffd12 fix typo 2023-03-02 13:58:18 +08:00
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Øyvind Sigmundson Schøyen
c7ddd2a7fa
DOC: fix typo in sph_harm
:math:\theta` -> :math:`\theta`
2023-02-16 15:34:18 +01:00
Peter Hawkins
cd0533cab0 Replace uses of jnp.ndarray with jax.Array inside JAX.
PiperOrigin-RevId: 509939691
2023-02-15 14:53:00 -08:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Lucas Hofer
4636276214 added scipy special spence
added dtype to arrays in the _spence_poly function
2023-02-10 20:33:47 +00:00
carlosgmartin
8251957025 Added scipy.stats.rankdata 2023-02-07 12:07:00 -05:00
Roy Frostig
e199b35f4e Revert "Merge pull request #14113 from botev:main"
This reverts commit 69d18cc7b58ae4ed82246605d66ed07a49fad676, reversing
changes made to 13e875f8b8d8dd9152045c7e3b5045a9bb0d7db0.

Reverting until we address https://github.com/google/jax/issues/14249
2023-02-01 19:50:27 -08:00
tttc3
96707f09b1 Removed deprecated polar_unitary as per comment. 2023-01-27 07:24:55 +00:00
botev
73ed511d39 Adding info to CG and BICGSTAB 2023-01-22 21:47:34 +00:00
Jiayu Chang
8df97ff2a3 Add Optional to boundary arg of stft. 2022-12-30 22:42:45 -05:00
harryjulian
c0d4ae0cc3 Added scipy.stats.bernoulli cdf and ppf. 2022-12-22 18:12:25 +00:00
harryjulian
351e1874ab Added vonmises pdf, logpdf & respective tests.
Added vonmises pdf, logpdf & respective tests.

Altered type-hinting, added pi as a _lax_const

Changed lax constant pi to be created in _pdf instead of passed arg.

Changed name in __init__.py

Fixed bug in tests.

Review related alterations.

Review related changes.

Added vonmises pdf, logpdf & respective tests.

Added vonmises pdf, logpdf & respective tests.

Altered type-hinting, added pi as a _lax_const

Changed lax constant pi to be created in _pdf instead of passed arg.

Changed name in __init__.py

Fixed bug in tests.

Review related alterations.

PR

PR

PR
2022-12-14 16:08:37 +00:00
Yotaro Kubo
1ade5f8592 Add jax.scipy.linalg.toeplitz. 2022-12-09 01:03:21 +09:00
Jake VanderPlas
4389216d0c Remove typing_extensions dependency 2022-12-05 15:42:26 -08:00
Jake VanderPlas
924894fdd6 [x64] make tests more type-safe 2022-12-02 13:21:35 -08:00
jax authors
5927032664 Merge pull request #13482 from jakevdp:x64-signal
PiperOrigin-RevId: 492367133
2022-12-01 20:36:41 -08:00
Jake VanderPlas
37acc6e426 [x64] more type safety in scipy.optimize.line_search 2022-12-01 14:04:39 -08:00
Jake VanderPlas
d25a96caea [x64] more type safety in jax.scipy.signal 2022-12-01 13:43:07 -08:00
Jake VanderPlas
26d9837b36 Switch to new-style f-strings 2022-12-01 09:14:16 -08:00
Ian Horn
a35fe206a1 Added more accurate version of the betaln function. 2022-11-29 11:56:07 -08:00
Peter Hawkins
1cead779a3 Add support for Hessenberg and tridiagonal matrix reductions on CPU.
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.

None of these primitives are differentiable at the moment.

PiperOrigin-RevId: 487224934
2022-11-09 06:23:55 -08:00
Tianjian Lu
3b1ddf2881 [linalg] Add jax.scipy.special.bessel_jn (Bessel function of the first kind).
PiperOrigin-RevId: 487146250
2022-11-08 23:03:21 -08:00
Peter Hawkins
cd84eb10a6 Add a number of missing function cross-references in the docs. 2022-11-07 12:00:26 -05:00
Jake VanderPlas
709ffd7e77 [typing] annotate jax.numpy reduction operations 2022-10-26 13:33:15 -07:00
Jake VanderPlas
2f27d516d7 [typing] annotate next part of lax_numpy.py 2022-10-25 12:36:26 -07:00
jax authors
8f2f9f4563 Merge pull request #12646 from adrn:truncnorm
PiperOrigin-RevId: 483425197
2022-10-24 10:41:51 -07:00
Adrian Price-Whelan
5784d61048 implement truncnorm in jax.scipy.stats
fix some shape and type issues

import into namespace

imports into non-_src library

working logpdf test

cleanup

working tests for cdf and sf after fixing select

relax need for x to be in (a, b)

ensure behavior with invalid input matches scipy

remove enforcing valid parameters in tests

added truncnorm to docs

whoops alphabetical

fix linter error

fix circular import issue
2022-10-22 15:48:20 -04:00
Jake VanderPlas
7f89fd40a2 Cleanup: remove unused imports in private modules
Also improve our flake8 filter rules to avoid ignoring these.
2022-10-20 14:37:21 -07:00
Jake VanderPlas
6d308653e4 [typing] annotate jax.numpy ufuncs 2022-10-20 11:22:04 -07:00