Change the contract of lax.linalg.tridiagonal to return the d and e vectors as well. Since we only just added this function and have never released JAX with it we can make this change without breaking compatibility.
Also fix wrong dtypes for d and e values in the CPU lapack sytrd wrapper.
PiperOrigin-RevId: 487621469
parse_flags_with_absl() only parses flags that start with --jax_. Other flags are only parsed when absl.app's main function runs. But that's too late for test cases: test cases need to have the number of generated cases chosen at module initialization time. Hence the --num_generated_cases flag wasn't doing anything. Oops. By renaming it it works once again.
It might make sense to stop using flags for the number of generated cases and only use environment variables. We defer that to a future change.
Fix many test cases that were shown to be broken with a larger number of test cases enabled.
PiperOrigin-RevId: 487406670
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.
None of these primitives are differentiable at the moment.
PiperOrigin-RevId: 487224934
fix some shape and type issues
import into namespace
imports into non-_src library
working logpdf test
cleanup
working tests for cdf and sf after fixing select
relax need for x to be in (a, b)
ensure behavior with invalid input matches scipy
remove enforcing valid parameters in tests
added truncnorm to docs
whoops alphabetical
fix linter error
fix circular import issue
Unlike the previous attempt, we don't try to use mhlo.logistic as the lowering of the new primitive yet. Instead, we lower to the old implementation of `expit`. This means that this change should be a no-op numerically and we can work on changing its implementation in a subsequent change.
PiperOrigin-RevId: 472705623
.to_py() was something of an accidental export from the JAX array classes. There are other mechanisms to turn a JAX array into a NumPy array, including `np.asarray(x)` and `jax.device_get(x)`. Deprecate this mechanism because it is redundant.
PiperOrigin-RevId: 469984029