29 Commits

Author SHA1 Message Date
Peter Hawkins
c5e8ae80f9 Update jax.scipy.special.gamma and gammasgn to return NaN for negative integer inputs.
Change to match upstream scipy: https://github.com/scipy/scipy/pull/21827.

Fixes #24875
2024-11-18 20:33:27 -05:00
Jake VanderPlas
d698da610a scipy.special.beta: remove deprecated x and y parameters 2024-11-06 09:01:27 -08:00
Jérome Eertmans
f9cb95ca08
feat(lib): add real-valued implementation of jax.scipy.special.fresnel
Add implementation, documentation, and tests, for both single-precision and double-precision floating-point arithmetic.
2024-09-03 09:50:19 +02:00
Neil Girdhar
56fdb42e9d Copy nn.{softmax,log_softmax} to scipy.special 2024-06-22 09:32:14 -04:00
Jake VanderPlas
aa1452375b Register beta args deprecation
PiperOrigin-RevId: 642427224
2024-06-11 16:19:14 -07:00
Jake VanderPlas
990b475b77 jax.scipy.special.beta: deprecate x,y in favor of a,b 2024-06-10 09:01:39 -07:00
Seonghyeon
818e7d92a4 Fix rel_entr behavior at boundary value 2024-05-28 13:17:28 +00:00
Jake VanderPlas
568db105ea Add jax.scipy.special.gammasgn 2024-04-18 16:14:55 -07:00
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Jake VanderPlas
97b15cc64b BUG: fix sign of beta() 2024-04-08 11:06:08 -07:00
rajasekharporeddy
4f3d27bde8 Update jax.scipy.special.ndtri to return NaN for the values outside of the range [0, 1] 2024-04-03 11:45:21 +05:30
Adam Paszke
bb0405e548 Use explicit test case names instead of the numbering system of parameterized
`@parameterized.parameters` interacts very badly with `JaxTestCase.rng`:
the RNG seed is derived from the test method and the test method name
can change if additional test cases are inserted before it. This can cause
CI failures in functions that are completely unrelated to the change that
introduces the breakage.

We should seriously reconsider this strategy. Either all instances of
`parameters` + `self.rng` should be removed or we should find an alternative
strategy for seeding.

PiperOrigin-RevId: 589798050
2023-12-11 05:55:27 -08:00
Jake VanderPlas
70d0f60ce1 Add special.factorial function 2023-12-04 06:13:14 -08:00
Jake VanderPlas
01fde43fce Fix sign of jax.scipy.special.gamma for negative inputs 2023-11-27 14:08:02 -08:00
sdupourque
47ca51f474 implementation of poch and hyp1f1 2023-11-15 20:01:00 +01:00
Ben West
02f6fcb9da Add beta function 2023-11-05 15:37:38 -08:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Peter Hawkins
1885c4933c Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.

Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().

PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00
Jake VanderPlas
75d12a2e21 Fix tolerance on bernoulli test 2023-08-23 16:59:27 -07:00
Jake VanderPlas
042111eb08 Add jax.scipy.special.bernoulli 2023-08-23 12:58:37 -07:00
Jake VanderPlas
6cd467fd57 Create lax.zeta with native HLO lowering 2023-08-16 13:43:41 -07:00
salamandercrossing
4e42adb599 Add kl_div and rel_entr functions
Their behavior is the same as functions in scipy.special. The only small
difference is in rel_entr function, which unlike scipy.special does not
take the optional parameter 'out'.

Resolves #16630
2023-07-27 21:34:55 +00:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
jax authors
1de4d14da8 Merge pull request #15656 from laqua-stack:add-special-gamma-fcn
PiperOrigin-RevId: 525566749
2023-04-19 15:28:36 -07:00
Jake VanderPlas
1b0106fd1e Make i0e gradient test more robust 2023-04-19 14:41:44 -07:00
laqua-stack
d742733bea feat (scipy.special): Add a xla version of scipy.special.gamma function
- Add gamma fcn api in scipy.special
- Add tests for this purpose
- Add function to the docs

Currently, there is no implementation of the gamma function in jax
but there is one in scipy.special. This breaks some higher level
jit-compilation like in the blackjax backend for pymc. This commit
adds the missing gamma function.

Resolves: #15409
2023-04-18 21:10:22 +02:00
pizzud
ef28dcf091 lax_scipy_test: Split into three targets, take 2.
The goal is to ensure that all shards fit into a medium timeout in sanitizer
configurations.

Running 256 entry vectors in spectral_dac is too slow, so let's replace that
with a smaller vector that isn't a power of 2. Avoiding a power of 2 requires
us to widen the tolerance a bit due to vectorization changes.

While here, specify deps a little more precisely as well.

PiperOrigin-RevId: 514440062
2023-03-06 09:53:23 -08:00
pizzud
0292f5d0a6 lax_scipy_test: Revert split into three targets.
Somehow the spectral_dac functionality is flaky on its own when run on CPU.

PiperOrigin-RevId: 512195860
2023-02-24 16:56:40 -08:00
pizzud
09afbac6ff lax_scipy_test: Split into three so that each target is small enough to fit within a medium timeout.
The spectral_dac tests are also shrunk because running the full suite on 256-entry vectors is too slow.

This allows them to run in ASAN in more situations.

While here, specify deps a little more precisely as well.

PiperOrigin-RevId: 511829646
2023-02-23 10:51:58 -08:00