262 Commits

Author SHA1 Message Date
carlosgmartin
18ecd2e4fd Add scipy.stats.sem. 2024-01-13 22:17:21 -05:00
Jake VanderPlas
77258cd6bd stats.binom.pmf: return zero for k > n 2024-01-02 10:53:44 -08:00
Jake VanderPlas
d7d2b767f1 integrate.trapezoid: fix function name in error message 2023-12-18 14:38:27 -08:00
Jake VanderPlas
e356d76913 Remove a number of deprecated APIs
All of these were deprecated prior to the JAX 0.4.16 release, on Sept 18 2023.
As of Monday Dec 18, we have met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).

PiperOrigin-RevId: 591933493
2023-12-18 10:08:47 -08:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Jake VanderPlas
70d0f60ce1 Add special.factorial function 2023-12-04 06:13:14 -08:00
Jake VanderPlas
01fde43fce Fix sign of jax.scipy.special.gamma for negative inputs 2023-11-27 14:08:02 -08:00
Peter Hawkins
84c1e825c0 Make jax.numpy.where()'s condition, x, y arguments positional-only to match numpy.where.
PiperOrigin-RevId: 584377134
2023-11-21 11:10:12 -08:00
jax authors
7657a0fb15 Merge pull request #18539 from NeilGirdhar:ruff
PiperOrigin-RevId: 583105786
2023-11-16 11:15:19 -08:00
Neil Girdhar
3c920c0120 Switch from flake8 to Ruff 2023-11-15 22:35:52 -05:00
Lukas Geiger
52d7f4911c Prefer expand_dims over reshape 2023-11-16 01:15:48 +00:00
sdupourque
47ca51f474 implementation of poch and hyp1f1 2023-11-15 20:01:00 +01:00
Jake VanderPlas
340e655ac2 Remove deprecated sym_pos argument from jax.scipy.linalg.solve
PiperOrigin-RevId: 580940755
2023-11-09 09:53:37 -08:00
Ben West
02f6fcb9da Add beta function 2023-11-05 15:37:38 -08:00
Nicola De Angeli
890b762a3e feat: add wrapcauchy logpdf and pdf 2023-10-18 13:47:10 +02:00
jax authors
c3e73c67aa Merge pull request #17760 from superbobry:array-any
PiperOrigin-RevId: 570400629
2023-10-03 08:50:07 -07:00
Sergei Lebedev
5ab05e42c9 MAINT Clean up leftover Array = Any aliases in jax/_src/**.py
I had to revert to using `Any` for `RaggedAxis.ragged_axes` because pytype
found more latent type errors, which require the understanding of ragedness
and dynamic shapes internals to fix properly.
2023-10-01 12:19:21 +01:00
jax authors
9f91df725d Merge pull request #17733 from superbobry:dict-literals
PiperOrigin-RevId: 568560963
2023-09-26 09:22:37 -07:00
Sergei Lebedev
eca10f5a3d ENH Use {} and () instead of dict() and tuple() 2023-09-25 11:53:33 +01:00
Peter Hawkins
2fd6df45e4 Fix test failures under SciPy 1.11 for scipy.stats.mode. 2023-09-23 20:15:51 +00:00
Jake VanderPlas
4a5bd9e046 Fix typos across the package 2023-09-22 14:54:31 -07:00
Sergei Lebedev
df7f6a06c0 MAINT Use a generator expression in tuple([... for ... in ...])
In a few cases I also replaced tuple([*xs, *ys]) with (*xs, ys), because
tuple literals support unpacking as well.
2023-09-21 22:25:38 +01:00
Peter Hawkins
975dae34a4 Deprecate jax.numpy.trapz.
Expose the current implementation of jax.numpy.trapz as jax.scipy.integrate.trapezoid instead.

Fixes https://github.com/google/jax/issues/17244
2023-08-25 09:04:13 -06:00
Jake VanderPlas
042111eb08 Add jax.scipy.special.bernoulli 2023-08-23 12:58:37 -07:00
Jake VanderPlas
d1c2277bfc jax.scipy.stats: add logsf & make sf more accurate near zero 2023-08-22 14:45:18 -07:00
Peter Hawkins
3082109a59 Add a type stub for jax.numpy.
This type stub is intended to match what pytype currently infers for jax.numpy, which is not particularly accurate in many cases. Future changes will add more accurate types to this stub.

Fix a number of new type errors this reveals to mypy.

PiperOrigin-RevId: 559179804
2023-08-22 11:50:49 -07:00
Jake VanderPlas
cf11f8da8a stats.norm: add logsf & make sf more accurate near zero 2023-08-21 16:48:39 -07:00
jax authors
209b6b02f4 Merge pull request #17144 from jakevdp:zeta
PiperOrigin-RevId: 558193896
2023-08-18 11:04:43 -07:00
Jake VanderPlas
6cd467fd57 Create lax.zeta with native HLO lowering 2023-08-16 13:43:41 -07:00
Jake VanderPlas
0ad6196ff0 Create lax.polygamma with native HLO lowering 2023-08-16 11:57:05 -07:00
Peter Hawkins
640ee1e815 Remove superfluous double-where in xlogy and xlog1py.
These functions have custom derivatives, so there seems to be no point to using the double-where guard on the primal function: the implementation can never be differentiated!

PiperOrigin-RevId: 551843160
2023-07-28 07:11:24 -07:00
salamandercrossing
4e42adb599 Add kl_div and rel_entr functions
Their behavior is the same as functions in scipy.special. The only small
difference is in rel_entr function, which unlike scipy.special does not
take the optional parameter 'out'.

Resolves #16630
2023-07-27 21:34:55 +00:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Jake VanderPlas
30d1a8a80f Add jax.scipy.stats.binom 2023-06-27 03:41:38 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Peter Hawkins
ef3f2abfd2 Fix test failures in JAX under NumPy 1.25.0rc1.
`jnp.finfo(...)` of an Array type yields:

```
TypeError: unhashable type: 'ArrayImpl'
```

However, `np.finfo(...)` no longer accepts NumPy arrays as input either, so it would be consistent to require the user to pass a dtype where they are currently passing an array.

PiperOrigin-RevId: 539174254
2023-06-09 14:10:35 -07:00
Chris Flesher
5be17ed90c Added scipy.spatial.transform Rotation and Slerp classes 2023-06-08 07:51:32 -05:00
Nicolas Tessore
a835cafdad
Fix incorrect wrapped docstring of jax.scipy.special.gamma
Fixes the docstring `jax.scipy.special.gamma`, which was wrapping `scipy.special.gammaln` by mistake. Also adds a note that the function currently only accepts real inputs.
2023-06-01 20:13:37 +01:00
Jake VanderPlas
222b951b19 Use new matrix_transpose in linalg code 2023-05-25 09:32:14 -07:00
Jake VanderPlas
9ac3781c7e grad(entr)(0.0): return inf instead of NaN 2023-04-25 08:32:37 -07:00
jax authors
1de4d14da8 Merge pull request #15656 from laqua-stack:add-special-gamma-fcn
PiperOrigin-RevId: 525566749
2023-04-19 15:28:36 -07:00
Jake VanderPlas
dd023e266e jax.scipy.special: fix gradient for xlogy & xlog1py 2023-04-18 15:56:32 -07:00
laqua-stack
d742733bea feat (scipy.special): Add a xla version of scipy.special.gamma function
- Add gamma fcn api in scipy.special
- Add tests for this purpose
- Add function to the docs

Currently, there is no implementation of the gamma function in jax
but there is one in scipy.special. This breaks some higher level
jit-compilation like in the blackjax backend for pymc. This commit
adds the missing gamma function.

Resolves: #15409
2023-04-18 21:10:22 +02:00
Vaishaal Shankar
add15aca25 implement idct and idctn + add function to scipy.rst 2023-04-17 12:12:51 -07:00
jax authors
0fd5b2ca61 Remove use of int casting in STFT collapse of batch dimensions.
PiperOrigin-RevId: 524115535
2023-04-13 15:15:11 -07:00
Jake VanderPlas
5521423d92 Change np.prod->math.prod
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
2023-04-13 11:48:11 -07:00
Jake VanderPlas
3ca7d67e8d Fully implement and test axes argument to jax.scipy.signal.fftconvolve
PiperOrigin-RevId: 523707411
2023-04-12 08:31:30 -07:00
Jake VanderPlas
d0ed619101 jax.scipy.signal.convolve: support method='fft' 2023-04-10 14:54:15 -07:00
Jean-Eric Campagne
4beee13ba0 Add implementation of jax.scipy.fftconvolve 2023-04-07 17:19:08 +02:00
Peter Hawkins
c1f65fc8b2 Avoid imports from the public jax.* namespace in more places internally.
This change is in preparation for more cycle breaking in the Bazel dependency graph.

PiperOrigin-RevId: 521822756
2023-04-04 11:41:40 -07:00