112 Commits

Author SHA1 Message Date
ilay menahem
390e90361a Add .hypothesis/ directory to .gitignore
and ppf and cdf to scipy.stats.uniform
2024-01-16 18:59:52 +00:00
jax authors
94b2da6a3b Merge pull request #19302 from carlosgmartin:scipy-stats-sem
PiperOrigin-RevId: 598884144
2024-01-16 10:34:45 -08:00
carlosgmartin
18ecd2e4fd Add scipy.stats.sem. 2024-01-13 22:17:21 -05:00
Jake VanderPlas
1870eee062 Test: make scipy version parsing compatible with pre-releases 2024-01-12 14:35:28 -08:00
Jake VanderPlas
77258cd6bd stats.binom.pmf: return zero for k > n 2024-01-02 10:53:44 -08:00
jax authors
3778265e2e Merge pull request #18126 from niqodea:wrapcauchy
PiperOrigin-RevId: 574572631
2023-10-18 13:18:20 -07:00
Nicola De Angeli
890b762a3e feat: add wrapcauchy logpdf and pdf 2023-10-18 13:47:10 +02:00
Peter Hawkins
1885c4933c Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.

Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().

PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00
Peter Hawkins
2fd6df45e4 Fix test failures under SciPy 1.11 for scipy.stats.mode. 2023-09-23 20:15:51 +00:00
Peter Hawkins
4f805c2d8f [JAX] Change jax.test_util utilities to have identical tolerances on all platforms.
In cases where this causes TPU tests to fail, relax test tolerances in the test cases themselves.

TPUs are less precise only for specific operations, notably matrix multiplication (for which usually enabling higher-precision matrix multiplication is the right choice if precision is needed), and certain special functions (e.g., log/exp/pow).

The net effect of this change is mostly to tighten up many test tolerances on TPU.

PiperOrigin-RevId: 562953488
2023-09-05 18:48:55 -07:00
Jake VanderPlas
2f878a7168 Tests: set jax_legacy_prng_key='error' 2023-08-28 10:56:09 -07:00
Jake VanderPlas
d1c2277bfc jax.scipy.stats: add logsf & make sf more accurate near zero 2023-08-22 14:45:18 -07:00
Jake VanderPlas
cf11f8da8a stats.norm: add logsf & make sf more accurate near zero 2023-08-21 16:48:39 -07:00
Jake VanderPlas
30d1a8a80f Add jax.scipy.stats.binom 2023-06-27 03:41:38 -07:00
Jake VanderPlas
ad35702934 Drop support for numpy 1.21
This is in accordance with NEP 29 and https://jax.readthedocs.io/en/latest/deprecation.html
2023-06-23 10:28:26 -07:00
Peter Hawkins
0adfafe293 Relax test tolerances.
This makes the tests pass on CPU with a slightly different seed (+ 1).

PiperOrigin-RevId: 542877795
2023-06-23 09:22:11 -07:00
Peter Hawkins
f7f1ddbb1e Temporarily disable LaxBackedScipyStatsTests.testTruncnormPdf.
This test started failing at LLVM head.

PiperOrigin-RevId: 532095958
2023-05-15 06:52:46 -07:00
Benjamin Kramer
545c483e50 Re-enable testTruncNormPdf on CPU
Breaking change was reverted in LLVM 3b8bc83527

PiperOrigin-RevId: 529072697
2023-05-03 06:31:59 -07:00
Peter Hawkins
57e62ca03c Reenable scipy_stats_test in CI.
Disable testTruncNormPdf on CPU, which is failing after an LLVM update.

PiperOrigin-RevId: 528884880
2023-05-02 14:11:08 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Yash Katariya
16ca0ca15c Relax the tolerance of testCauchyLogCdf
PiperOrigin-RevId: 520741306
2023-03-30 14:19:50 -07:00
Peter Hawkins
67a28ce30f Relax test tolerances for testLogisticPpf.
Fixes a test failure in CI.

PiperOrigin-RevId: 520649225
2023-03-30 08:41:56 -07:00
Peter Hawkins
b7375b316b Increase minimum NumPy version to 1.21.
Also increase minimum SciPy version to 1.7, which was released just before NumPy 1.21.
2023-03-23 21:15:10 -04:00
Misha
83b3f5b759 Fix loc and scale parameters in scipy.logistic. Add CDF and SF for several distributions. 2023-03-21 00:16:13 +01:00
Parker Schuh
d62fc88fb1 Roll back #14792
Breaks tests. lax.sub requires arguments to have the same dtypes, got float32, float64. (Tip: jnp.subtract is a similar function that does automatic type promotion on inputs).

PiperOrigin-RevId: 514897538
2023-03-07 18:31:19 -08:00
Misha
feb9ab33af Fixed loc and scale parameters for logistic distribution. CDF and SF have been added for several distributions, including cauchy, gamma, logistic, chi2 and beta. ISF and PPF have also been added for cauchy and logistic. 2023-03-07 07:56:47 +01:00
Peter Hawkins
33bed1e520 Opt into higher matmul precision for A100 and TPU tests.
PiperOrigin-RevId: 509598465
2023-02-14 12:03:12 -08:00
carlosgmartin
8251957025 Added scipy.stats.rankdata 2023-02-07 12:07:00 -05:00
Peter Hawkins
27da460f25 Fix test failures under SciPy 1.10.0. 2023-01-31 14:51:38 +00:00
harryjulian
c0d4ae0cc3 Added scipy.stats.bernoulli cdf and ppf. 2022-12-22 18:12:25 +00:00
harryjulian
351e1874ab Added vonmises pdf, logpdf & respective tests.
Added vonmises pdf, logpdf & respective tests.

Altered type-hinting, added pi as a _lax_const

Changed lax constant pi to be created in _pdf instead of passed arg.

Changed name in __init__.py

Fixed bug in tests.

Review related alterations.

Review related changes.

Added vonmises pdf, logpdf & respective tests.

Added vonmises pdf, logpdf & respective tests.

Altered type-hinting, added pi as a _lax_const

Changed lax constant pi to be created in _pdf instead of passed arg.

Changed name in __init__.py

Fixed bug in tests.

Review related alterations.

PR

PR

PR
2022-12-14 16:08:37 +00:00
Peter Hawkins
73de02d5ce Make JAX tests pass under NumPy 1.24.0rc2.
* allow rc2 in numpy versions when parsed by tests.
* don't cast np.empty(), which can lead to cast errors.
* NumPy 1.24 now warns on overflowing scalar int to array casts in more
places.
2022-12-08 19:46:10 +00:00
Adrian Price-Whelan
5784d61048 implement truncnorm in jax.scipy.stats
fix some shape and type issues

import into namespace

imports into non-_src library

working logpdf test

cleanup

working tests for cdf and sf after fixing select

relax need for x to be in (a, b)

ensure behavior with invalid input matches scipy

remove enforcing valid parameters in tests

added truncnorm to docs

whoops alphabetical

fix linter error

fix circular import issue
2022-10-22 15:48:20 -04:00
Yann Lamidon
ccbc3059b0 Add JAX equivalent of scipy.stats.mode 2022-10-18 20:45:02 +01:00
Peter Hawkins
72f4f389be Migrate remaining tests from jtu.cases_from_list to jtu.sample_product.
Delete jtu.cases_from_list.
2022-10-12 15:20:53 +00:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Wonhyeong Seo
3f6eb40698 JAX implementation of scipy.stats.multinomial pmf & logpmf
Co-authored-by: harryjulian <harry.julian@peak.ai>
2022-09-15 13:21:44 -07:00
Peter Hawkins
57b5acf1b6 Roll forward: Upgrade logistic into a primitive.
Unlike the previous attempt, we don't try to use mhlo.logistic as the lowering of the new primitive yet. Instead, we lower to the old implementation of `expit`. This means that this change should be a no-op numerically and we can work on changing its implementation in a subsequent change.

PiperOrigin-RevId: 472705623
2022-09-07 06:06:56 -07:00
jax authors
9c16c83234 Rollback of upgrade logistic (sigmoid) function into a lax primitive.
PiperOrigin-RevId: 471105650
2022-08-30 15:30:43 -07:00
Peter Hawkins
f68f1c0cd0 Upgrade logistic (sigmoid) function into a lax primitive.
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.

PiperOrigin-RevId: 470300985
2022-08-26 11:58:28 -07:00
jax authors
3e3542b0d6 Upgrade logistic (sigmoid) function into a lax primitive.
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.

PiperOrigin-RevId: 469841487
2022-08-24 15:39:37 -07:00
Peter Hawkins
6276194e1c Upgrade logistic (sigmoid) function into a lax primitive.
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.

PiperOrigin-RevId: 469789339
2022-08-24 12:04:01 -07:00
Dan F-M
dc2a50ff21 looser TPU precision 2022-06-28 17:07:30 -04:00
Dan F-M
0788d5708a Implementation of jax.scipy.stats.gaussian_kde 2022-06-28 15:17:12 -04:00
Peter Hawkins
a560a29e12 Increase the minimum scipy version to 1.5.
We don't have a formal support policy for scipy versions, but 1.5 dates from around the same date as the oldest supported NumPy release NEP-29 would have us support (1.20).
2022-06-24 15:07:09 -04:00
Jake VanderPlas
4c0d61a143 Add jtu.strict_promotion_if_dtypes_match utility 2022-06-16 13:59:53 -07:00
Jake VanderPlas
f00d706a6d [x64] make scipy_stats_test.py compatible with strict dtype promotion 2022-06-14 14:47:58 -07:00
carlosgmartin
57b89ba7cb Added scipy.stats.gennorm. 2022-06-14 13:38:24 -04:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
YouJiacheng
4695dd919c Fix#10219 2022-04-13 04:04:11 +08:00