363 Commits

Author SHA1 Message Date
Pearu Peterson
82b2591b21 Fix scipy.special.gammainc/gammaincc evaluation at boundary points 2025-03-11 21:18:47 +02:00
Jake VanderPlas
b441b2b7a5 Prevent tracer leaks in scipy.special.expn 2025-03-06 14:38:11 -08:00
Jan Naumann
e03fe3a06d Implement SVD algorithm based on QR for CPU targets
In a recent jax release the SvdAlgorithm parameter has been added
to the jax.lax.linalg.svd function. Currently, for CPU targets
still only the divide and conquer algorithm from LAPACK is
supported (gesdd).

This commits adds the functionality to select the QR based
algorithm on CPU as well. Mainly it addes the wrapper code
to call the gesvd function of LAPACK using the FFI interface.

Signed-off-by: Jan Naumann <j.naumann@fu-berlin.de>
2025-02-22 15:24:57 +01:00
tttc3
b1b56ea0b0 Enable pivoted QR on GPU via MAGMA.
Originally noted in #20282, this commit provides a GPU compatible
implementation of `geqp3` via MAGMA.
2025-02-12 16:12:42 +00:00
Nikolas Klug
af794721f9 Document behavior of linalg.solve in case the system matrix is singular 2025-02-08 18:09:13 +01:00
Jake VanderPlas
b4f98eef7e refactor: move scalar type defs out of lax_numpy.py 2025-02-06 14:48:10 -08:00
Qazalbash
8561f90f8c
fix: simplify logcdf implementation by removing unnecessary argument promotion 2025-02-04 23:43:15 +05:00
Qazalbash
9a324bfb14 fix: update logcdf implementation to use sf 2025-02-04 09:38:02 +05:00
Qazalbash
f9c5b0d21b fix: correct argument name in promote_args_inexact for logcdf function 2025-02-04 03:06:03 +05:00
Qazalbash
a3f7307333 chore: optimize calculations in exponential distribution functions for cdf, logcdf, sf, and ppf 2025-02-04 03:04:26 +05:00
Qazalbash
42b64fc06c
feat(gh-13291): Add exponential distribution functions: cdf, logcdf, sf, logsf, and ppf 2025-02-01 12:51:11 +05:00
Roy Frostig
a60ead6fd1 enable partitionable threefry by default
PiperOrigin-RevId: 715242560
2025-01-13 22:46:24 -08:00
jax authors
564b6b0d72 Merge pull request #20282 from tttc3:pivoted-qr
PiperOrigin-RevId: 714053620
2025-01-10 08:02:02 -08:00
tttc3
c89be05b5b Enable pivoted QR on CPU devices.
A pivoted QR factorization is possible in `scipy.linalg.qr`, thanks
to the `geqp3` routine of LAPACK. To provide the same functionality
in JAX, we implement a new primitive `geqp3_p` which calls the LAPACK
routine via the FFI on CPU devices.

Both `jax.scipy.linalg.qr` and `jax.lax.linalg.qr` now support the
use of column-pivoting on CPU devices.

To provide a GPU implementation of `geqp3` may require using MAGMA,
due to the lack of a `geqp3` implementation in `cuSolver` -  see
ccb331707e80b16d89de6e5c9f2f89b87c1682ed (`jax.lax.linalg.eig`) for
an example of using MAGMA in GPU lowerings. Such a GPU implementation
can be considered in the future.
2025-01-09 20:44:45 +00:00
Dan Foreman-Mackey
5f3e0d9e5e Add sph_harm_y to jax.scipy.special and deprecate sph_harm. 2025-01-09 12:53:00 -05:00
Brecht Ooms
257b033e8c Add Rotation return type hint to Rotation.__mul__()
Without this type hint, some tools (including PyCharm) infer the more
generic return type from typing.NamedTuple.
To improve user experience, I've added a narrower type hint.

However, the typing of this method is still 'flawed' as the only properly supported
input is another Rotation. This is a narrower input type and therefore
violates the Liskov substitution principle. Therefore I left the input
parameter untyped.

For more info:
https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overrides
2024-12-11 12:48:02 +01:00
Nitin Srinivasan
47d1960926 Update the render documentation job to use the new self-hosted runners
PiperOrigin-RevId: 700550934
2024-11-26 21:01:50 -08:00
Dan Foreman-Mackey
96c012990d Fix false positive debug_nans error caused by NaNs that are properly handled in jax.scipy.stats.gamma
As reported in https://github.com/jax-ml/jax/issues/24939, even though the implementation of `jax.scipy.stats.gamma.logpdf` handles invalid inputs (e.g. `x < loc`) by returning `-inf`, the existing implementation incorrectly triggers the NaN checks introduced by JAX's debug NaNs mode. This change updates the implementation to no longer produce internal NaNs.

Fixes https://github.com/jax-ml/jax/issues/24939

PiperOrigin-RevId: 698833589
2024-11-21 10:33:29 -08:00
Peter Hawkins
c5e8ae80f9 Update jax.scipy.special.gamma and gammasgn to return NaN for negative integer inputs.
Change to match upstream scipy: https://github.com/scipy/scipy/pull/21827.

Fixes #24875
2024-11-18 20:33:27 -05:00
Jake VanderPlas
3f98c57f7b jax.scipy.linalg.toeplitz: support implicit batching 2024-11-11 15:32:43 -08:00
jax authors
cbaafbbe99 Merge pull request #24723 from jakevdp:beta-dep
PiperOrigin-RevId: 693759757
2024-11-06 09:42:29 -08:00
jax authors
542cb2e57e Fix a bug in jax.scipy.stats.rankdata leading to breakage with shape polymorphism.
PiperOrigin-RevId: 693755546
2024-11-06 09:31:43 -08:00
Jake VanderPlas
d698da610a scipy.special.beta: remove deprecated x and y parameters 2024-11-06 09:01:27 -08:00
Jake VanderPlas
5f90f63d19 Improve efficiency of jax.scipy.stats.rankdata 2024-11-05 05:13:57 -08:00
Dan Foreman-Mackey
96268dcae6 Fix dtype bug in jax.scipy.fft.idct 2024-09-25 12:55:43 -04:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Filippo Luca Ferretti
2ff26ff3e0 Add scalar_first argument to jax.scipy.spatial.transform.Rotation.as_quat 2024-09-16 21:57:55 +02:00
jax authors
9789056061 Fix a small typo for the condition of scipy.entr.
PiperOrigin-RevId: 674205855
2024-09-13 02:06:40 -07:00
Sergei Lebedev
51eb0d27c7 Fixed some type errors under pyright
These are mostly due to relience on submodule import side-effects, which
AFAIU are unchecked by both pytype and mypy.
2024-09-05 09:56:38 +01:00
Jake VanderPlas
7b41583414 refactor jax.lax to not depend on jax.numpy 2024-09-01 07:49:49 -07:00
jax authors
efba5f61b5 Merge pull request #22812 from superbobry:maint
PiperOrigin-RevId: 658751187
2024-08-02 04:43:33 -07:00
Sergei Lebedev
fb1dbf15df Bumped mypy to 1.11.0 and jaxlib to 0.4.31 on the CI 2024-08-01 22:30:24 +01:00
Sergei Lebedev
92b1f71314 Removed various ununsed functions
To rerun the analysis do

    python -m vulture jax/_src --ignore-names "[A-Za-z]*" --ignore-decorators "*"
2024-08-01 11:18:19 +01:00
rajasekharporeddy
3a0e4376cd Fix betabinom.logpmf and binom.logpmf for JAX to emulate SciPy's behavior when k=n=0 2024-07-31 07:58:43 +05:30
Pavel Sountsov
5ba26953be Add canonical arg to Rotation.as_quat() and switch .inv() to use the quaternion conjugate.
This matches scipy behavior as of 1.11.

I also went through the tests and enabled a bunch of disabled tests which appear to pass now(?).

PiperOrigin-RevId: 655719643
2024-07-24 15:29:19 -07:00
Jake VanderPlas
2efd1ec011 jax.scipy.fft.dct: implement & test norm='backward' 2024-07-22 11:18:35 -07:00
Jake VanderPlas
326559ca47 jax.scipy.fft: error for unsupported norm argument 2024-07-22 10:32:03 -07:00
Jake VanderPlas
3833c46d10 jnp.vectorize: respect numpy_rank_promotion config 2024-07-08 09:03:03 -07:00
tilakrayal
3bbd141d3d
Fixing the naming conventions in linalg.py 2024-06-27 13:22:39 +05:30
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
rajasekharporeddy
da334e37d0 Add code examples to jax.scipy.stats.sem docs 2024-06-24 20:47:49 +05:30
Neil Girdhar
56fdb42e9d Copy nn.{softmax,log_softmax} to scipy.special 2024-06-22 09:32:14 -04:00
jax authors
56e8fe630e Merge pull request #22028 from rajasekharporeddy:stats-sem
PiperOrigin-RevId: 645518083
2024-06-21 15:30:04 -07:00
rajasekharporeddy
edde7d9762 Fix the behavior of jax.scipy.stats.sem when keepdims=True 2024-06-22 02:39:00 +05:30
jax authors
4a7b293bd9 Merge pull request #22027 from rajasekharporeddy:testbranch5
PiperOrigin-RevId: 645437879
2024-06-21 10:51:05 -07:00
rajasekharporeddy
8cb5fb5f7c Add code examples to jax.scipy.stats.mode docs 2024-06-21 22:12:48 +05:30
Dan Foreman-Mackey
6d35b109fd Rename "Example" to "Examples" in docstrings.
This PR updates all docstrings that previously had a section heading
called "Example" and replaces that with "Examples" to be consistent.
2024-06-21 11:43:16 -04:00
Jake VanderPlas
0a86e9a929 Deprecate hashing of tracers 2024-06-13 13:14:27 -07:00
tilakrayal
3ef89a2113
Fixing the naming conventions in signal.py 2024-06-13 12:21:25 +05:30
Jake VanderPlas
aa1452375b Register beta args deprecation
PiperOrigin-RevId: 642427224
2024-06-11 16:19:14 -07:00