364 Commits

Author SHA1 Message Date
Jean-Eric Campagne
4beee13ba0 Add implementation of jax.scipy.fftconvolve 2023-04-07 17:19:08 +02:00
Peter Hawkins
c1f65fc8b2 Avoid imports from the public jax.* namespace in more places internally.
This change is in preparation for more cycle breaking in the Bazel dependency graph.

PiperOrigin-RevId: 521822756
2023-04-04 11:41:40 -07:00
Peter Hawkins
abf1acf76c Replace references to jax.interpreters with jax._src.interpreters in JAX core.
PiperOrigin-RevId: 520933067
2023-03-31 08:58:00 -07:00
Jake VanderPlas
ad0fc8979b jax.scipy.linalg.expm: support batched inputs 2023-03-27 16:39:48 -07:00
Misha
83b3f5b759 Fix loc and scale parameters in scipy.logistic. Add CDF and SF for several distributions. 2023-03-21 00:16:13 +01:00
Jake VanderPlas
760deb310e Remove leading underscores in jax._src.numpy.util 2023-03-13 12:18:36 -07:00
Jake VanderPlas
c8c269f5f5 internal: avoid unused imports in lax_numpy 2023-03-08 10:29:04 -08:00
Parker Schuh
d62fc88fb1 Roll back #14792
Breaks tests. lax.sub requires arguments to have the same dtypes, got float32, float64. (Tip: jnp.subtract is a similar function that does automatic type promotion on inputs).

PiperOrigin-RevId: 514897538
2023-03-07 18:31:19 -08:00
Misha
feb9ab33af Fixed loc and scale parameters for logistic distribution. CDF and SF have been added for several distributions, including cauchy, gamma, logistic, chi2 and beta. ISF and PPF have also been added for cauchy and logistic. 2023-03-07 07:56:47 +01:00
Peter Hawkins
7b6321cc09 Reenable pytype for numpy ufuncs.
Add a few type annotations to ufuncs so the exported types are more precise.

PiperOrigin-RevId: 513798060
2023-03-03 05:01:03 -08:00
jiayaobo
8b8d7ffd12 fix typo 2023-03-02 13:58:18 +08:00
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Øyvind Sigmundson Schøyen
c7ddd2a7fa
DOC: fix typo in sph_harm
:math:\theta` -> :math:`\theta`
2023-02-16 15:34:18 +01:00
Peter Hawkins
cd0533cab0 Replace uses of jnp.ndarray with jax.Array inside JAX.
PiperOrigin-RevId: 509939691
2023-02-15 14:53:00 -08:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Lucas Hofer
4636276214 added scipy special spence
added dtype to arrays in the _spence_poly function
2023-02-10 20:33:47 +00:00
carlosgmartin
8251957025 Added scipy.stats.rankdata 2023-02-07 12:07:00 -05:00
Roy Frostig
e199b35f4e Revert "Merge pull request #14113 from botev:main"
This reverts commit 69d18cc7b58ae4ed82246605d66ed07a49fad676, reversing
changes made to 13e875f8b8d8dd9152045c7e3b5045a9bb0d7db0.

Reverting until we address https://github.com/google/jax/issues/14249
2023-02-01 19:50:27 -08:00
tttc3
96707f09b1 Removed deprecated polar_unitary as per comment. 2023-01-27 07:24:55 +00:00
botev
73ed511d39 Adding info to CG and BICGSTAB 2023-01-22 21:47:34 +00:00
Jiayu Chang
8df97ff2a3 Add Optional to boundary arg of stft. 2022-12-30 22:42:45 -05:00
harryjulian
c0d4ae0cc3 Added scipy.stats.bernoulli cdf and ppf. 2022-12-22 18:12:25 +00:00
harryjulian
351e1874ab Added vonmises pdf, logpdf & respective tests.
Added vonmises pdf, logpdf & respective tests.

Altered type-hinting, added pi as a _lax_const

Changed lax constant pi to be created in _pdf instead of passed arg.

Changed name in __init__.py

Fixed bug in tests.

Review related alterations.

Review related changes.

Added vonmises pdf, logpdf & respective tests.

Added vonmises pdf, logpdf & respective tests.

Altered type-hinting, added pi as a _lax_const

Changed lax constant pi to be created in _pdf instead of passed arg.

Changed name in __init__.py

Fixed bug in tests.

Review related alterations.

PR

PR

PR
2022-12-14 16:08:37 +00:00
Yotaro Kubo
1ade5f8592 Add jax.scipy.linalg.toeplitz. 2022-12-09 01:03:21 +09:00
Jake VanderPlas
4389216d0c Remove typing_extensions dependency 2022-12-05 15:42:26 -08:00
Jake VanderPlas
924894fdd6 [x64] make tests more type-safe 2022-12-02 13:21:35 -08:00
jax authors
5927032664 Merge pull request #13482 from jakevdp:x64-signal
PiperOrigin-RevId: 492367133
2022-12-01 20:36:41 -08:00
Jake VanderPlas
37acc6e426 [x64] more type safety in scipy.optimize.line_search 2022-12-01 14:04:39 -08:00
Jake VanderPlas
d25a96caea [x64] more type safety in jax.scipy.signal 2022-12-01 13:43:07 -08:00
Jake VanderPlas
26d9837b36 Switch to new-style f-strings 2022-12-01 09:14:16 -08:00
Ian Horn
a35fe206a1 Added more accurate version of the betaln function. 2022-11-29 11:56:07 -08:00
Peter Hawkins
1cead779a3 Add support for Hessenberg and tridiagonal matrix reductions on CPU.
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.

None of these primitives are differentiable at the moment.

PiperOrigin-RevId: 487224934
2022-11-09 06:23:55 -08:00
Tianjian Lu
3b1ddf2881 [linalg] Add jax.scipy.special.bessel_jn (Bessel function of the first kind).
PiperOrigin-RevId: 487146250
2022-11-08 23:03:21 -08:00
Peter Hawkins
cd84eb10a6 Add a number of missing function cross-references in the docs. 2022-11-07 12:00:26 -05:00
Jake VanderPlas
709ffd7e77 [typing] annotate jax.numpy reduction operations 2022-10-26 13:33:15 -07:00
Jake VanderPlas
2f27d516d7 [typing] annotate next part of lax_numpy.py 2022-10-25 12:36:26 -07:00
jax authors
8f2f9f4563 Merge pull request #12646 from adrn:truncnorm
PiperOrigin-RevId: 483425197
2022-10-24 10:41:51 -07:00
Adrian Price-Whelan
5784d61048 implement truncnorm in jax.scipy.stats
fix some shape and type issues

import into namespace

imports into non-_src library

working logpdf test

cleanup

working tests for cdf and sf after fixing select

relax need for x to be in (a, b)

ensure behavior with invalid input matches scipy

remove enforcing valid parameters in tests

added truncnorm to docs

whoops alphabetical

fix linter error

fix circular import issue
2022-10-22 15:48:20 -04:00
Jake VanderPlas
7f89fd40a2 Cleanup: remove unused imports in private modules
Also improve our flake8 filter rules to avoid ignoring these.
2022-10-20 14:37:21 -07:00
Jake VanderPlas
6d308653e4 [typing] annotate jax.numpy ufuncs 2022-10-20 11:22:04 -07:00
Jake VanderPlas
524745f322 TMP: annotate util.safe_zip 2022-10-19 10:29:53 -07:00
Peter Hawkins
807269990e Enable more GPU and TPU tests that pass at head.
Increase precision of matmuls in LU decompositions, pseudo-inverse solves, and their gradients. It is unlikely users want to use low precision for these operations and high precision is probably the right default.

PiperOrigin-RevId: 482071629
2022-10-18 18:09:44 -07:00
Yann Lamidon
ccbc3059b0 Add JAX equivalent of scipy.stats.mode 2022-10-18 20:45:02 +01:00
Jake VanderPlas
8ac9ea312a [typing] annotate jax.scipy.special 2022-10-13 12:16:12 -07:00
Jake VanderPlas
512e7004bc [typing] annotage jax.scipy.signal 2022-10-13 12:15:40 -07:00
jax authors
7d4ea9bb8e Merge pull request #12757 from jakevdp:erf-doc
PiperOrigin-RevId: 480943111
2022-10-13 11:39:35 -07:00
jax authors
3950c4c39d Merge pull request #12777 from jakevdp:scipy-stats-typing
PiperOrigin-RevId: 480943061
2022-10-13 11:32:32 -07:00
Jake VanderPlas
b1119be830 [typing] annotage jax.scipy.stats distributions 2022-10-12 13:42:11 -07:00
Peter Hawkins
9ab88071a7 Avoid loading scipy eagerly.
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
2022-10-12 19:51:09 +00:00
Jake VanderPlas
ffafd7e220 DOC: switch order of custom_jvp and _wraps 2022-10-12 10:13:55 -07:00