As reported in https://github.com/jax-ml/jax/issues/24939, even though the implementation of `jax.scipy.stats.gamma.logpdf` handles invalid inputs (e.g. `x < loc`) by returning `-inf`, the existing implementation incorrectly triggers the NaN checks introduced by JAX's debug NaNs mode. This change updates the implementation to no longer produce internal NaNs.
Fixes https://github.com/jax-ml/jax/issues/24939
PiperOrigin-RevId: 698833589
suppressions.
We want to support running Bazel tests with PYTHONWARNINGS=error. In
preparation for that change, move warning suppressions from
pyproject.toml into the individual test cases that generate them, which
is a reasonable cleanup anyway.
It is to fix our CI, the warning itself started occurring on scipy 1.14 due to this change https://github.com/scipy/scipy/pull/20694, which introduced SmallSampleWarning and started emitting it if the input is an empty array (the `a` variable in the randomized parametrized test LaxBackedScipyStatsTests.testMode sometimes happens to be an empty array).
Note, the actual ignored warning is RungimeWarning (the superclass of SmallSampleWarning) to make it backward compatible (scipy.stats._axis_nan_policy.SmallSampleWarning does not exist in scipy prior 1.14, not to mention it being under private declared in a private (_axis_nan_policy) namespace.
PiperOrigin-RevId: 677629866
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.
Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().
PiperOrigin-RevId: 568923117
In cases where this causes TPU tests to fail, relax test tolerances in the test cases themselves.
TPUs are less precise only for specific operations, notably matrix multiplication (for which usually enabling higher-precision matrix multiplication is the right choice if precision is needed), and certain special functions (e.g., log/exp/pow).
The net effect of this change is mostly to tighten up many test tolerances on TPU.
PiperOrigin-RevId: 562953488
Breaks tests. lax.sub requires arguments to have the same dtypes, got float32, float64. (Tip: jnp.subtract is a similar function that does automatic type promotion on inputs).
PiperOrigin-RevId: 514897538
Added vonmises pdf, logpdf & respective tests.
Altered type-hinting, added pi as a _lax_const
Changed lax constant pi to be created in _pdf instead of passed arg.
Changed name in __init__.py
Fixed bug in tests.
Review related alterations.
Review related changes.
Added vonmises pdf, logpdf & respective tests.
Added vonmises pdf, logpdf & respective tests.
Altered type-hinting, added pi as a _lax_const
Changed lax constant pi to be created in _pdf instead of passed arg.
Changed name in __init__.py
Fixed bug in tests.
Review related alterations.
PR
PR
PR
* allow rc2 in numpy versions when parsed by tests.
* don't cast np.empty(), which can lead to cast errors.
* NumPy 1.24 now warns on overflowing scalar int to array casts in more
places.
fix some shape and type issues
import into namespace
imports into non-_src library
working logpdf test
cleanup
working tests for cdf and sf after fixing select
relax need for x to be in (a, b)
ensure behavior with invalid input matches scipy
remove enforcing valid parameters in tests
added truncnorm to docs
whoops alphabetical
fix linter error
fix circular import issue