Breaks tests. lax.sub requires arguments to have the same dtypes, got float32, float64. (Tip: jnp.subtract is a similar function that does automatic type promotion on inputs).
PiperOrigin-RevId: 514897538
This reverts commit 69d18cc7b58ae4ed82246605d66ed07a49fad676, reversing
changes made to 13e875f8b8d8dd9152045c7e3b5045a9bb0d7db0.
Reverting until we address https://github.com/google/jax/issues/14249
Added vonmises pdf, logpdf & respective tests.
Altered type-hinting, added pi as a _lax_const
Changed lax constant pi to be created in _pdf instead of passed arg.
Changed name in __init__.py
Fixed bug in tests.
Review related alterations.
Review related changes.
Added vonmises pdf, logpdf & respective tests.
Added vonmises pdf, logpdf & respective tests.
Altered type-hinting, added pi as a _lax_const
Changed lax constant pi to be created in _pdf instead of passed arg.
Changed name in __init__.py
Fixed bug in tests.
Review related alterations.
PR
PR
PR
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.
None of these primitives are differentiable at the moment.
PiperOrigin-RevId: 487224934
fix some shape and type issues
import into namespace
imports into non-_src library
working logpdf test
cleanup
working tests for cdf and sf after fixing select
relax need for x to be in (a, b)
ensure behavior with invalid input matches scipy
remove enforcing valid parameters in tests
added truncnorm to docs
whoops alphabetical
fix linter error
fix circular import issue
Increase precision of matmuls in LU decompositions, pseudo-inverse solves, and their gradients. It is unlikely users want to use low precision for these operations and high precision is probably the right default.
PiperOrigin-RevId: 482071629
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.