89 Commits

Author SHA1 Message Date
Jake VanderPlas
f8e18e9a00 [x64] minor weak_type changes to linalg.py 2021-12-07 16:27:29 -08:00
Jake VanderPlas
022f8ac2ee [x64] preserve weak types in jax.scipy.sparse solvers 2021-11-30 10:36:28 -08:00
jax authors
d81114deff Merge pull request #8668 from jakevdp:canonicalize-axis
PiperOrigin-RevId: 412088801
2021-11-24 10:30:14 -08:00
Peter Hawkins
4e21922055 Use imports relative to the jax package consistently, rather than .-relative imports.
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.

PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
Jake VanderPlas
f6e3f1b4ad Cleanup: remove duplicate canonicalize_axis utility 2021-11-23 16:54:02 -08:00
Jake VanderPlas
f2a959054a Document jax.lax.Precision 2021-11-08 14:15:31 -08:00
Jake VanderPlas
40d6f5ed90 Tighten up dtypes across the package 2021-10-29 13:50:30 -07:00
Jake VanderPlas
0a232b2237 stats.multivariate_normal: support broadcasted inputs 2021-10-19 16:58:36 -07:00
Peter Hawkins
f8ba024621 Fix JAX functions to work if the default gather mode is set to "fill".
These functions really do want "clip".
2021-09-30 14:21:05 -04:00
Peter Hawkins
2eb20357db Add @jit decorators to jax.numpy.linalg and jax.scipy.linalg. 2021-09-24 15:52:11 -04:00
Peter Hawkins
41a0f6a682 Remove a stale comment. 2021-09-21 17:23:05 -04:00
jax authors
4bc6b27021 Merge pull request #7966 from jakevdp:faster-conv
PiperOrigin-RevId: 398072150
2021-09-21 13:33:33 -07:00
Peter Hawkins
1163e218e8 Attempt to land https://github.com/google/jax/pull/6400 again.
This PR changes `jax.numpy.array()` to avoid creating any on-device arrays during tracing. As a consequence, calls to `jnp.array()` in a traced context, such as `jax.jit` will always be staged into the trace.

This change may break code that depends on the current (undocumented and unintentional) behavior of `jnp.array()` to perform shape or index calculations that must be known statically (at trace time). The workaround for such cases is to use classic NumPy to perform shape/index calculations.

PiperOrigin-RevId: 398008511
2021-09-21 09:06:40 -07:00
Jake VanderPlas
d7e94b9eef convolutions: use flip() to clean up reverse-indexing 2021-09-21 08:49:32 -07:00
jax authors
e67d49b8a8 Merge pull request #7838 from khdlr:map_coordinates_reflect_mirror
PiperOrigin-RevId: 397062070
2021-09-16 05:58:44 -07:00
jax authors
5c58afa647 Merge pull request #7855 from cyprienc:scipy-stats-nbinom
PiperOrigin-RevId: 396617949
2021-09-14 09:49:50 -07:00
Cyprien
8c8f0a8c71 Feat: scipy.stats.nbinom implementation
fix: increasing tolerance check for testNBinomLogPmf in scipy_stats_test.py
2021-09-14 08:05:42 +01:00
Peter Hawkins
8b2123968a Switch internal users of jax.util.partial to use functools.partial. 2021-09-13 21:09:58 -04:00
Peter Hawkins
a84426cb8f Switch internal users of jax.ops.index_... to use x.at[x].set() APIs. 2021-09-13 19:48:29 -04:00
Peter Hawkins
80599c0821 Replace uses of jax.partial with functools.partial, in preparation for removing jax.partial.
jax.partial is an alias for functools.partial, and functools.partial is a Python standard library API. There's no need for jax to export this function.

PiperOrigin-RevId: 396370975
2021-09-13 09:16:19 -07:00
Konrad Heidler
5ed619afbb
Implement 'reflect' and 'mirror' padding modes for scipy.ndimage.map_coordinates 2021-09-07 16:05:35 +02:00
Jonathan Terhorst
fec72e1852 add support for scipy.special.{expn,expi,exp1} 2021-08-24 16:36:10 -04:00
Jake VanderPlas
730ae33e03 logsumexp: fix issue with debug_nans 2021-08-18 13:57:00 -07:00
jax authors
3827a8b26a Merge pull request #7617 from juliuskunze:dct
PiperOrigin-RevId: 391540164
2021-08-18 09:06:02 -07:00
Jake VanderPlas
0e256ddeb7 Fix logsumexp issue with debug_nans and disable_jit 2021-08-17 13:47:06 -07:00
Julius Kunze
6d83027b69 Support scipy.fft.dct/dctn type=2 2021-08-17 18:56:44 +02:00
Jake VanderPlas
743a1c270d scipy.stats.beta: fix pdf for x=0, 1 2021-08-17 09:52:11 -07:00
elliotwaite
7392a57b75 DOC: many small fixes 2021-08-04 16:55:13 -07:00
jax authors
606cbe036a Merge pull request #7370 from slowy07:fixing
PiperOrigin-RevId: 388774232
2021-08-04 13:43:58 -07:00
Jake VanderPlas
20cef7eaa8 Fix rank promotion error in jsp.special.zeta 2021-08-03 14:33:57 -07:00
Jake VanderPlas
af16177659 Fix rank promotion error in jsp.special.multigammaln 2021-08-03 13:39:29 -07:00
Jake VanderPlas
e131343274 Fix issue with infinities in logsumexp 2021-08-02 15:27:24 -07:00
Peter Hawkins
b232d09440 Enable flake8 checks for spaces around operators. 2021-07-30 08:45:38 -04:00
slowy07
9eadb07bdc fix: miss typo codespell and documentation 2021-07-24 15:25:13 +07:00
Peter Hawkins
0dfd76af97 Remove additional info return value from jax.scipy.linalg.polar(). 2021-07-20 13:13:31 -04:00
Adam Lewis
a2073ffcc2 Adds an implementation of a QR-based Dynamically Weighted Halley iteration. 2021-07-20 11:30:36 -04:00
Peter Hawkins
3ddcec27f2 Update minimum jaxlib version to 0.1.69. 2021-07-15 17:00:13 -04:00
Jake VanderPlas
9e73972d0a Fix jax.scipy.stats.gamma.pdf() for x=0.0, a=1.0 2021-07-12 14:43:44 -07:00
tlu7
d97b393694 Adds spherical harmonics.
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
2021-07-02 10:42:29 -07:00
Peter Hawkins
d658108d36 Fix type errors with current mypy and NumPy.
Enable type stubs for jaxlib.

Fix a nondeterminism problem in jax2tf tests.
2021-06-24 10:51:06 -04:00
Luke Pfister
c33388b136 Support complex numbers in jax.scipy.signal.convolve/correlate 2021-06-18 13:07:00 -06:00
tlu7
095e6507b9 Support value computation of associated Legendre functions.
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
2021-06-14 14:51:37 -07:00
Jake VanderPlas
22dbe80255 DOC: state that digamma only accepts float 2021-06-08 10:47:27 -07:00
tlu7
a02bf59233 Adds associated Legendre functions of the first kind.
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
2021-06-02 11:37:37 -07:00
Stephan Hoyer
5b3d203453
Merge branch 'master' into master 2021-06-02 01:05:47 -07:00
jax authors
87abab713d Merge pull request #6785 from GregCT:changelist/373551581
PiperOrigin-RevId: 376922795
2021-06-01 14:47:15 -07:00
Filippo Vicentini
c0c8e0d0a3 make logsumexp work with complex numbers 2021-05-31 16:01:57 +02:00
Gregory Thornton
03a1ee9269 Update Jax linesearch to behave more like Scipy 2021-05-26 12:49:56 +01:00
Lukas Geiger
3a2e80ef51 Replace pow() with srqt() or square() where possible 2021-05-24 10:43:35 +01:00
Jakob Unfried
f0c7427000 add L-BFGS optimizer 2021-05-19 19:46:11 +02:00