32 Commits

Author SHA1 Message Date
Skye Wanderman-Milne
f90b5eed52 Add pjrt_c_api_unimplemented pytest marker to skip unsupported tests.
Also adds `test_util.pytest_mark_if_available` helper function.
2023-01-12 22:17:23 +00:00
Yash Katariya
7d4ef891af Add device and backend API to pjit but resolve them away in infer_params. This is to merge jit and pjit frontend API.
The semantics of mentioning `device` or `backend` on `pjit` is the same as doing a `device_put` i.e. no matter which device the arg is on, reshard it to the device mentioned.

PiperOrigin-RevId: 495437165
2022-12-14 15:41:59 -08:00
jax authors
dd902fde21 Merge pull request #13317 from google:xdist_tpu
PiperOrigin-RevId: 490366370
2022-11-22 16:40:00 -08:00
Skye Wanderman-Milne
120125f3dd Make pytest-xdist work on TPU and update Cloud TPU CI.
This change also marks multiaccelerator test files in a way pytest can
understand (if pytest is installed).

By running single-device tests on a single TPU chip, running the test
suite goes from 1hr 45m to 35m (both timings are running slow tests).

I tried using bazel at first, which already supported parallel
execution across TPU cores, but somehow it still takes 2h 20m! I'm not
sure why it's so slow. It appears that bazel creates many new test
processes over time, vs. pytest reuses the number of processes
initially specified, and starting and stopping the TPU runtime takes a
few seconds so that may be adding up. It also appears that
single-process bazel is slower than single-process pytest, which I
haven't looked into yet.
2022-11-18 22:05:13 +00:00
Yash Katariya
b6fa77cb60 Fix forward (Add deprecation warnings to DA, SDA and GDA): By raising the warnings in the hook of the jax_array config.
PiperOrigin-RevId: 489503583
2022-11-18 10:12:40 -08:00
Peter Hawkins
9f2a6acb61 Revert: Add deprecation warnings to DA, SDA and GDA.
This change is currently overly noisy for users.

PiperOrigin-RevId: 489455729
2022-11-18 06:06:13 -08:00
Yash Katariya
52a2428073 Add deprecation warnings to DA, SDA and GDA.
PiperOrigin-RevId: 489314189
2022-11-17 14:51:29 -08:00
Jingxin Ye
e6c88f2c58 update pytest.ini to print warning message for compilation_cache_test 2022-11-04 21:43:51 +00:00
Yash Katariya
c9a60f9410 Only raise the warning if jax_array is enabled and the code is coming from jit.
PiperOrigin-RevId: 482053435
2022-10-18 16:35:43 -07:00
Sudhakar
cbcd0cdd04 ignore UserWarning 2022-10-13 17:22:15 -07:00
Sudhakar
5f1858f533 Add pytest marker inside the test only if pytest is present in the env 2022-09-06 11:45:59 -07:00
Jake VanderPlas
a40fb76a51 pytest: remove obsolete warning filters 2022-07-25 10:47:06 -07:00
Yash Katariya
4ed06602d3 Add deprecation warning for sharded_jit.
PiperOrigin-RevId: 439926957
2022-04-06 13:54:06 -07:00
Peter Hawkins
541f762a31 Tolerate NumPy deprecation warnings when using older SciPy.
Simply importing scipy 1.2.3 with NumPy 1.21.5 leads to deprecation
warnings. Tolerate these in pytest.
2022-01-31 16:34:59 +00:00
Peter Hawkins
3c193613ce Fix test failures under Numpy 1.22. 2022-01-04 12:35:44 -05:00
Roy Frostig
623c201054 [JAX] move example libraries from jax.experimental into jax.example_libraries
The `jax.experimental.stax` and `jax.experimental.optimizers` modules are standalone examples libraries. By contrast, the remaining modules in `jax.experimental` are experimental features of the JAX core system. This change moves the two example libraries, and the README that describes them, to `jax.example_libraries` to reflect this distinction.

PiperOrigin-RevId: 404405186
2021-10-19 17:30:45 -07:00
Peter Hawkins
104a46594b Add DeprecationWarnings to jax.ops.index_... operators.
Remove uses of index_... in Common Gotchas notebook.
2021-10-05 20:47:22 -04:00
Skye Wanderman-Milne
3ff51bbb1d Use pytest's filterwarnings feature instead of filtering each test case.
We often forget to put the per-test-case decorators, resulting in test
failures in cases not covered by github CI (e.g. Cloud TPU
tests). This change filters the "experimental feature" warnings by
default.
2021-04-23 10:28:22 -07:00
George Necula
d9468c7513 Cleanup the API, and more documentation 2021-04-08 11:25:32 +03:00
Jake VanderPlas
8e789c7380 Run doctest on all source files except jax2tf 2021-04-05 10:39:59 -07:00
Jake VanderPlas
f74235cdae X32 tests: fail on dtype warnings 2020-12-08 13:03:30 -08:00
Peter Hawkins
195e13c14b Remove jax.experimental.optix.
optix has become its own Python package (optax). You should use optax instead.

PiperOrigin-RevId: 343291598
2020-11-19 08:03:15 -08:00
Peter Hawkins
2b8d840cc3 [JAX] Remove uses of the deprecated jax.experimental.vectorize.
jax.numpy.vectorize should be used instead.

PiperOrigin-RevId: 341836454
2020-11-11 08:34:48 -08:00
Peter Hawkins
a1c6831124 Add a deprecation warning to the optix package. 2020-11-09 09:18:03 -05:00
George Necula
bf97e47929
Make infeed_test and host_callback_test independent. (#3676)
* Make infeed_test and host_callback_test independent.

* the infeed_test will stop the outfeed receiver
* Remove the use of --dist=loadfile.
* Prevent logging on exit
2020-07-07 11:03:30 +03:00
George Necula
4f3011f320
Refactored host_callback to use the C++ runtime. (#3644)
* Refactored host_callback to use the C++ runtime.

* The new runtime makes it unnecessary to start the outfeed_receiver
  in the user's code
* We don't need msgpack anymore
* There is an interaction between host_callback and using lax.outfeed.
  I am trying to solve this by (a) making host_callback_test stop the
  outfeed receiver on finish and infeed_test on start, and (b)
  telling pytest-xdist to run all the tests from one file into
  a single worker.
2020-07-04 18:12:58 +03:00
Peter Hawkins
e680304dca
Remove warning suppression for tuple and list arguments to reductions. (#3545)
Fix callers.
2020-06-24 15:59:31 -04:00
joao guilherme
319eeaf5c9
Future warning about lists and tuples (#3369) 2020-06-24 10:54:06 -04:00
George Necula
27906ce2d8
Renamed experimental/jax_to_tf to experimental/jax2tf (#3404)
* Renamed experimental/jax_to_tf to experimental/jax2tf

* Leave a trampoline behind, for backwards compatibility
2020-06-11 11:52:09 +03:00
George Necula
8e0a012666
Initial import of jax2tf into JAX core (#3202)
* Initial import of jax2tf into JAX core

Renamed jax2tf.convert to jax_to_tf.
Added Travis test support.
Added OSS build configuration.

* Added support for squeeze
2020-05-29 09:56:32 +03:00
Jamie Townsend
670fab59cf
Test code in docs and api.py docstrings (#2994)
Also remove jaxpr doc tests from api_test.py.
2020-05-16 16:19:24 +03:00
Peter Hawkins
2dc81fb40c
Make pytest run over JAX tests warning clean, and error on warnings. (#2674)
* Make pytest run over JAX tests warning clean, and error on warnings.

Remove global warning suppression in travis.yml. Instead add a pytest.ini that converts warnings to errors, with the exception of a whitelist.
Either fix or locally suppress warnings in tests.

Also fix crashes on Mac related to a preexisting linear algebra bug.

* Fix some type errors in the FFT transpose rules revealed by the convert_element_type transpose rule change.
2020-04-12 15:35:35 -04:00