--
d42fffd849a4bac0c0c11a3346c93f07f8c64c44 by Jake VanderPlas <jakevdp@google.com>:
JaxTestCase: set numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 427896974
* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs.
Default to check_dtypes=True.
Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense.
No functional changes intended.
* Fix a number of lax reference implementations to preserve types.
* Add jax.numpy.vectorize
This is basically a non-experimental version of the machinery in
`jax.experimental.vectorize`, except:
- It adds the `excluded` argument from NumPy, which works just like
`static_argnums` in `jax.jit`.
- It doesn't include the `axis` argument yet (which NumPy doesn't have).
Eventually we might want want to consolidate the specification of signatures
with signatures used by shape-checking machinery, but it's nice to emulate
NumPy's existing interface, and this is already useful (e.g., for writing
vectorized linear algebra routines).
* Add deprecation warning to jax.experimental.vectorize
* improve implementation