126 Commits

Author SHA1 Message Date
Qazalbash
42b64fc06c
feat(gh-13291): Add exponential distribution functions: cdf, logcdf, sf, logsf, and ppf 2025-02-01 12:51:11 +05:00
Dan Foreman-Mackey
96c012990d Fix false positive debug_nans error caused by NaNs that are properly handled in jax.scipy.stats.gamma
As reported in https://github.com/jax-ml/jax/issues/24939, even though the implementation of `jax.scipy.stats.gamma.logpdf` handles invalid inputs (e.g. `x < loc`) by returning `-inf`, the existing implementation incorrectly triggers the NaN checks introduced by JAX's debug NaNs mode. This change updates the implementation to no longer produce internal NaNs.

Fixes https://github.com/jax-ml/jax/issues/24939

PiperOrigin-RevId: 698833589
2024-11-21 10:33:29 -08:00
Peter Hawkins
a0e4448393 Remove warning filters from pyproject.toml, add local warning
suppressions.

We want to support running Bazel tests with PYTHONWARNINGS=error. In
preparation for that change, move warning suppressions from
pyproject.toml into the individual test cases that generate them, which
is a reasonable cleanup anyway.
2024-09-24 01:38:24 +00:00
Vadym Matsishevskyi
2199685437 Ignore scipy.stats._axis_nan_policy.SmallSampleWarning for LaxBackedScipyStatsTests.testMode
It is to fix our CI, the warning itself started occurring on scipy 1.14 due to this change https://github.com/scipy/scipy/pull/20694, which introduced SmallSampleWarning and started emitting it if the input is an empty array (the `a` variable in the randomized parametrized test LaxBackedScipyStatsTests.testMode sometimes happens to be an empty array).

Note, the actual ignored warning is RungimeWarning (the superclass of SmallSampleWarning) to make it backward compatible (scipy.stats._axis_nan_policy.SmallSampleWarning does not exist in scipy prior 1.14, not to mention it being under private declared in a private (_axis_nan_policy) namespace.

PiperOrigin-RevId: 677629866
2024-09-22 22:26:33 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
rajasekharporeddy
3a0e4376cd Fix betabinom.logpmf and binom.logpmf for JAX to emulate SciPy's behavior when k=n=0 2024-07-31 07:58:43 +05:30
rajasekharporeddy
edde7d9762 Fix the behavior of jax.scipy.stats.sem when keepdims=True 2024-06-22 02:39:00 +05:30
jax authors
8b1418244b Merge pull request #20885 from rajasekharporeddy:test_branch4
PiperOrigin-RevId: 627486343
2024-04-23 13:29:40 -07:00
rajasekharporeddy
c536eea1e5 Fix jax.scipy.stats.beta.logpdf to emulate scipy.stats.beta.logpdf 2024-04-24 01:24:09 +05:30
rajasekharporeddy
95ed0538fd Fix jax.scipy.stats.poisson.logpmf to emulate scipy.stats.poisson.logpmf for non-integer values of k 2024-04-24 00:29:52 +05:30
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Jake VanderPlas
8949a63ce1 [key reuse] rename flag to jax_debug_key_reuse 2024-03-22 05:37:30 -07:00
Jake VanderPlas
d08e9a03d8 [key reuse] add eager checks 2024-02-29 15:30:19 -08:00
Jake VanderPlas
cddee4654c tests: access tree utilities via jax.tree.* 2024-02-26 14:17:18 -08:00
ilay menahem
390e90361a Add .hypothesis/ directory to .gitignore
and ppf and cdf to scipy.stats.uniform
2024-01-16 18:59:52 +00:00
jax authors
94b2da6a3b Merge pull request #19302 from carlosgmartin:scipy-stats-sem
PiperOrigin-RevId: 598884144
2024-01-16 10:34:45 -08:00
carlosgmartin
18ecd2e4fd Add scipy.stats.sem. 2024-01-13 22:17:21 -05:00
Jake VanderPlas
1870eee062 Test: make scipy version parsing compatible with pre-releases 2024-01-12 14:35:28 -08:00
Jake VanderPlas
77258cd6bd stats.binom.pmf: return zero for k > n 2024-01-02 10:53:44 -08:00
jax authors
3778265e2e Merge pull request #18126 from niqodea:wrapcauchy
PiperOrigin-RevId: 574572631
2023-10-18 13:18:20 -07:00
Nicola De Angeli
890b762a3e feat: add wrapcauchy logpdf and pdf 2023-10-18 13:47:10 +02:00
Peter Hawkins
1885c4933c Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.

Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().

PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00
Peter Hawkins
2fd6df45e4 Fix test failures under SciPy 1.11 for scipy.stats.mode. 2023-09-23 20:15:51 +00:00
Peter Hawkins
4f805c2d8f [JAX] Change jax.test_util utilities to have identical tolerances on all platforms.
In cases where this causes TPU tests to fail, relax test tolerances in the test cases themselves.

TPUs are less precise only for specific operations, notably matrix multiplication (for which usually enabling higher-precision matrix multiplication is the right choice if precision is needed), and certain special functions (e.g., log/exp/pow).

The net effect of this change is mostly to tighten up many test tolerances on TPU.

PiperOrigin-RevId: 562953488
2023-09-05 18:48:55 -07:00
Jake VanderPlas
2f878a7168 Tests: set jax_legacy_prng_key='error' 2023-08-28 10:56:09 -07:00
Jake VanderPlas
d1c2277bfc jax.scipy.stats: add logsf & make sf more accurate near zero 2023-08-22 14:45:18 -07:00
Jake VanderPlas
cf11f8da8a stats.norm: add logsf & make sf more accurate near zero 2023-08-21 16:48:39 -07:00
Jake VanderPlas
30d1a8a80f Add jax.scipy.stats.binom 2023-06-27 03:41:38 -07:00
Jake VanderPlas
ad35702934 Drop support for numpy 1.21
This is in accordance with NEP 29 and https://jax.readthedocs.io/en/latest/deprecation.html
2023-06-23 10:28:26 -07:00
Peter Hawkins
0adfafe293 Relax test tolerances.
This makes the tests pass on CPU with a slightly different seed (+ 1).

PiperOrigin-RevId: 542877795
2023-06-23 09:22:11 -07:00
Peter Hawkins
f7f1ddbb1e Temporarily disable LaxBackedScipyStatsTests.testTruncnormPdf.
This test started failing at LLVM head.

PiperOrigin-RevId: 532095958
2023-05-15 06:52:46 -07:00
Benjamin Kramer
545c483e50 Re-enable testTruncNormPdf on CPU
Breaking change was reverted in LLVM 3b8bc83527

PiperOrigin-RevId: 529072697
2023-05-03 06:31:59 -07:00
Peter Hawkins
57e62ca03c Reenable scipy_stats_test in CI.
Disable testTruncNormPdf on CPU, which is failing after an LLVM update.

PiperOrigin-RevId: 528884880
2023-05-02 14:11:08 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Yash Katariya
16ca0ca15c Relax the tolerance of testCauchyLogCdf
PiperOrigin-RevId: 520741306
2023-03-30 14:19:50 -07:00
Peter Hawkins
67a28ce30f Relax test tolerances for testLogisticPpf.
Fixes a test failure in CI.

PiperOrigin-RevId: 520649225
2023-03-30 08:41:56 -07:00
Peter Hawkins
b7375b316b Increase minimum NumPy version to 1.21.
Also increase minimum SciPy version to 1.7, which was released just before NumPy 1.21.
2023-03-23 21:15:10 -04:00
Misha
83b3f5b759 Fix loc and scale parameters in scipy.logistic. Add CDF and SF for several distributions. 2023-03-21 00:16:13 +01:00
Parker Schuh
d62fc88fb1 Roll back #14792
Breaks tests. lax.sub requires arguments to have the same dtypes, got float32, float64. (Tip: jnp.subtract is a similar function that does automatic type promotion on inputs).

PiperOrigin-RevId: 514897538
2023-03-07 18:31:19 -08:00
Misha
feb9ab33af Fixed loc and scale parameters for logistic distribution. CDF and SF have been added for several distributions, including cauchy, gamma, logistic, chi2 and beta. ISF and PPF have also been added for cauchy and logistic. 2023-03-07 07:56:47 +01:00
Peter Hawkins
33bed1e520 Opt into higher matmul precision for A100 and TPU tests.
PiperOrigin-RevId: 509598465
2023-02-14 12:03:12 -08:00
carlosgmartin
8251957025 Added scipy.stats.rankdata 2023-02-07 12:07:00 -05:00
Peter Hawkins
27da460f25 Fix test failures under SciPy 1.10.0. 2023-01-31 14:51:38 +00:00
harryjulian
c0d4ae0cc3 Added scipy.stats.bernoulli cdf and ppf. 2022-12-22 18:12:25 +00:00
harryjulian
351e1874ab Added vonmises pdf, logpdf & respective tests.
Added vonmises pdf, logpdf & respective tests.

Altered type-hinting, added pi as a _lax_const

Changed lax constant pi to be created in _pdf instead of passed arg.

Changed name in __init__.py

Fixed bug in tests.

Review related alterations.

Review related changes.

Added vonmises pdf, logpdf & respective tests.

Added vonmises pdf, logpdf & respective tests.

Altered type-hinting, added pi as a _lax_const

Changed lax constant pi to be created in _pdf instead of passed arg.

Changed name in __init__.py

Fixed bug in tests.

Review related alterations.

PR

PR

PR
2022-12-14 16:08:37 +00:00
Peter Hawkins
73de02d5ce Make JAX tests pass under NumPy 1.24.0rc2.
* allow rc2 in numpy versions when parsed by tests.
* don't cast np.empty(), which can lead to cast errors.
* NumPy 1.24 now warns on overflowing scalar int to array casts in more
places.
2022-12-08 19:46:10 +00:00
Adrian Price-Whelan
5784d61048 implement truncnorm in jax.scipy.stats
fix some shape and type issues

import into namespace

imports into non-_src library

working logpdf test

cleanup

working tests for cdf and sf after fixing select

relax need for x to be in (a, b)

ensure behavior with invalid input matches scipy

remove enforcing valid parameters in tests

added truncnorm to docs

whoops alphabetical

fix linter error

fix circular import issue
2022-10-22 15:48:20 -04:00
Yann Lamidon
ccbc3059b0 Add JAX equivalent of scipy.stats.mode 2022-10-18 20:45:02 +01:00
Peter Hawkins
72f4f389be Migrate remaining tests from jtu.cases_from_list to jtu.sample_product.
Delete jtu.cases_from_list.
2022-10-12 15:20:53 +00:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00