12 Commits

Author SHA1 Message Date
Jake VanderPlas
97512e9e44 JaxTestCase: set jax_numpy_rank_promotion='raise' by default 2022-02-14 09:22:05 -08:00
jax authors
5691010d2f Copybara import of the project:
--
d42fffd849a4bac0c0c11a3346c93f07f8c64c44 by Jake VanderPlas <jakevdp@google.com>:

JaxTestCase: set numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 427896974
2022-02-10 19:08:29 -08:00
Jake VanderPlas
6324577a63 JaxTestCase: set numpy_rank_promotion='raise' by default 2022-02-10 16:54:31 -08:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Jake VanderPlas
6114e6a0d3 test_util: add decorator to set config values in test cases 2021-08-05 14:06:37 -07:00
Jake VanderPlas
768aba55f1 disable implicit rank promotion in lax_numpy_einsum/indexing/vectorize_test 2021-08-03 12:19:36 -07:00
Peter Hawkins
9fad2441a2 Mark arguments to jax.jit() other than the function as keyword-only.
This change is to prevent breakage when options are added or removed.
2021-04-08 10:32:35 -04:00
Jake VanderPlas
afce718eb1 Add ability to specify individual test targets 2020-06-29 11:08:57 -07:00
Peter Hawkins
fffdb2daa8
Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_… (#3280)
* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs.
Default to check_dtypes=True.

Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense.

No functional changes intended.

* Fix a number of lax reference implementations to preserve types.
2020-06-01 17:19:23 -04:00
Jake Vanderplas
bc30597780
Cleanup: remove unused imports in tests (#3276) 2020-06-01 11:49:35 -07:00
Peter Hawkins
b1bc841ae5
Replace np -> jnp, onp -> np in more places. (#2973)
* Replace np -> jnp, onp -> np in more places.

Context: #2370

* Fix typo in random_test.py
2020-05-05 16:40:41 -04:00
Stephan Hoyer
00140f07e2
Add jax.numpy.vectorize (#2146)
* Add jax.numpy.vectorize

This is basically a non-experimental version of the machinery in
`jax.experimental.vectorize`, except:
- It adds the `excluded` argument from NumPy, which works just like
  `static_argnums` in `jax.jit`.
- It doesn't include the `axis` argument yet (which NumPy doesn't have).

Eventually we might want want to consolidate the specification of signatures
with signatures used by shape-checking machinery, but it's nice to emulate
NumPy's existing interface, and this is already useful (e.g., for writing
vectorized linear algebra routines).

* Add deprecation warning to jax.experimental.vectorize

* improve implementation
2020-02-12 14:09:37 -08:00