67 Commits

Author SHA1 Message Date
Tyler Augustine
d52de206cb Disable tests that timeout in debug mode in CI
PiperOrigin-RevId: 476157051
2022-09-22 11:44:56 -07:00
Peter Hawkins
d0e1c3e684 Disable tests under sanitizers that are timing out in CI.
PiperOrigin-RevId: 475839926
2022-09-21 08:50:55 -07:00
Yash Katariya
eec1b4a017 Set the sharding of uncommitted single device sharding Arrays correctly and fix some miscellaneous tests with Array too. Enable pjit_test and xmap_test with Array too (all of them are mechanical changes).
PiperOrigin-RevId: 474858389
2022-09-16 11:16:27 -07:00
Yash Katariya
45e48b3a7d Mark multiprocess_gpu_test as manual to skip it in OSS
PiperOrigin-RevId: 474806518
2022-09-16 07:08:08 -07:00
Yash Katariya
28741b8e0d Some miscellaneous changes to make tests pass when jax.Array is enabled by default.
1. Add `device_buffer` and `device_buffers` fields to Array as a backwards compatible change for DA and SDA.
2. Support PartitionSpecs as input to in_axis_resources and out_axis_resources when jax_array is enabled as a backwards compatible change since all user code uses this currently. Create a MeshPspecSharding internally.
3. Some tests changes to make them pass

PiperOrigin-RevId: 474642889
2022-09-15 13:27:40 -07:00
Jake VanderPlas
13a7034e6a Internal change
PiperOrigin-RevId: 474331907
2022-09-14 10:39:38 -07:00
jax authors
b27d8c1267 Merge pull request #12342 from jakevdp:typing-test
PiperOrigin-RevId: 474154750
2022-09-13 16:36:49 -07:00
Jake VanderPlas
b3c31ebe7d Add typing_test.py 2022-09-13 12:43:51 -07:00
Sharad Vikram
ad326b99da Use cases_from_list to subsample enumerated cases in for_loop_test
PiperOrigin-RevId: 474093596
2022-09-13 12:34:10 -07:00
Sharad Vikram
dc4922fd08 Bump shards on for_loop_test
PiperOrigin-RevId: 474038276
2022-09-13 09:30:04 -07:00
Sharad Vikram
e5725f1df1 Split for_loop_test out of lax_control_flow_test
PiperOrigin-RevId: 473848277
2022-09-12 14:46:07 -07:00
Yash Katariya
864849d075 Increase the GPU shard count to 40 since its timing out every once in a while in kokoro.
PiperOrigin-RevId: 473617463
2022-09-11 14:30:45 -07:00
Sharad Vikram
b6c3b9df19 Split State effect into Read/Write/Accum effects and tie them to Ref avals 2022-09-08 08:04:13 -07:00
jax authors
c8ac6dbe6b [PJRT:C] Control whether to call wrapped C++ PJRT client directly with a global variable kPjRtCApiBypass.
Default value of `kPjRtCApiBypass` is false.

PiperOrigin-RevId: 472847147
2022-09-07 16:52:07 -07:00
Yash Katariya
6340952e2a Make jit == pjit. This means that the lowering and execution paths of jit and pjit are merged.
A fallback to `lower_xla_callable` is taken when pmap appears in the jaxpr during the jit lowering path.

Added support for `keep_unused`, `committed` and `core.Token` to pxla.py.

PiperOrigin-RevId: 470896270
2022-08-29 22:03:21 -07:00
Sudhakar
4b1a2eaaec combine gpu tests 2022-08-25 15:27:07 -07:00
jax authors
2a00533e3e Internal change
PiperOrigin-RevId: 469730523
2022-08-24 08:16:46 -07:00
Peter Hawkins
da4e79a625 Increase some test shardings to reduce CI timeouts under asan/tsan/msan.
PiperOrigin-RevId: 469173324
2022-08-22 06:52:04 -07:00
Yash Katariya
f905d989c1 Make eager pmap tests pass with Array. Also add a slow path for Array in pmap similar to what SDA has. This is required for eager pmap. Adding a slow path removes the need for doing sharding checks in api.py because SDA doesn't do those checks and if the sharding does not match with pmap sharding, then it just defaults to the slow path (exactly like SDA).
PiperOrigin-RevId: 468843310
2022-08-19 21:37:22 -07:00
Yash Katariya
314cf8a439 Use .device() to get the device and platform from the device and fix TODO to point to github issue
PiperOrigin-RevId: 468769708
2022-08-19 13:14:13 -07:00
Yash Katariya
d77848bcc9 Enable jax_array on CPU for the entire JAX test suite!
PiperOrigin-RevId: 468726200
2022-08-19 10:04:35 -07:00
Yash Katariya
2231b0c054 Add array test counterparts to dtypes_test.
PiperOrigin-RevId: 468568788
2022-08-18 15:58:39 -07:00
Yash Katariya
acdae7c237 Add weak type support to Array. Also make all api_test.py tests pass with Array. I have disabled the float0 test for now until I investigate.
PiperOrigin-RevId: 468264910
2022-08-17 12:25:49 -07:00
Sharad Vikram
8068e4638c Re-bump shard count for pmap_test
PiperOrigin-RevId: 468239588
2022-08-17 10:46:19 -07:00
jax authors
7721579700 Internal change
PiperOrigin-RevId: 468068879
2022-08-16 17:43:03 -07:00
Sharad Vikram
6ae46c3d69 Bump pmap_test size to handle new eager tests
PiperOrigin-RevId: 467967021
2022-08-16 10:45:50 -07:00
Peter Hawkins
cd62dbffb3 Increase Bazel sharding of scipy_signal_test on CPU. 2022-08-12 15:14:35 +00:00
Peter Hawkins
3a693e98c0 Increase Bazel sharding of long-running CPU tests. 2022-08-12 13:21:40 +00:00
Yash Katariya
33c4fc4fe2 Pmap should output SDA like Arrays to maintain the current behavior exactly. Split the shard_arg_handler for Array based on whether the mode is pmap or pjit. Why do this? The doc below explains more about the context.
PiperOrigin-RevId: 466849614
2022-08-10 20:11:37 -07:00
Sharad Vikram
2d72dc899b Enable receives in TPU callbacks and add tests
PiperOrigin-RevId: 466103306
2022-08-08 11:42:16 -07:00
Peter Hawkins
efe1b6b44c Disable tsan on linalg_test on TPU.
The test is too slow under tsan and times out in CI.

PiperOrigin-RevId: 465570229
2022-08-05 08:33:42 -07:00
Peter Hawkins
b865111996 Refactor BUILD files to avoid individually naming Python dependencies.
Add a parametric py_deps() macro for adding Python package dependencies for Bazel rules.

Fix build failure with dangling matplotlib reference.

PiperOrigin-RevId: 465562141
2022-08-05 07:49:20 -07:00
Vlad Feinberg
269067e3e8 Make LOBPCG test plots compatible with bazel.
bazel test invocations would previously not work, because the lobpcg_test did not include the appropriate flag parsing and absl test invocations when run as a script. This change fixes that, and in addition shards tests and removes needless and redundant slow tests with larger matrix sizes to make the tests finish in a smaller amount of time. Now, generated pngs with debug information are properly reported via the undeclared outputs directory when the environment variable to emit them, LOBPCG_EMIT_DEBUG_PLOTS, is set to a non-falsy value.

PiperOrigin-RevId: 465465731
2022-08-04 20:05:53 -07:00
jax authors
0a8ca1982c Merge pull request #11721 from sudhakarsingh27:main
PiperOrigin-RevId: 465381834
2022-08-04 12:52:16 -07:00
jax authors
01819257f6 Merge pull request #11701 from sharadmv:state
PiperOrigin-RevId: 464658336
2022-08-01 17:05:43 -07:00
Yash Katariya
9a5af235da Delete sharded_jit
PiperOrigin-RevId: 464081692
2022-07-29 08:19:52 -07:00
Peter Hawkins
9e6254e058 Increase shard counts for TPU tests in an attempt to fix CI timeouts under asan.
PiperOrigin-RevId: 463830139
2022-07-28 07:14:36 -07:00
George Necula
07fcf79324 jax.mask and jax.shapecheck are being deprecated
Issue: #11557
PiperOrigin-RevId: 462315754
2022-07-21 00:09:31 -07:00
Kuangyuan Chen
c0ec3b33e6 Introduce jax.experimental.clear_backends to delete all JAX runtime backends.
In cases like unit tests, users may want to clean up all the backends along with the resources used in the end of the test, and reinitialize them in the next test.

PiperOrigin-RevId: 462239974
2022-07-20 15:10:27 -07:00
Adam Paszke
117da44712 Internal change
PiperOrigin-RevId: 462110048
2022-07-20 04:31:21 -07:00
George Necula
777c129dfb [dynamic-shapes] Split dynamic_api_test.py
PiperOrigin-RevId: 461109288
2022-07-14 20:18:53 -07:00
jax authors
3eff9d11d2 Internal change
PiperOrigin-RevId: 460434859
2022-07-12 05:21:20 -07:00
Peter Hawkins
64e0b5d801 Increase bazel sharding of GPU tests.
Reduces the maximum time for some test shards to avoid flaky timeouts.
2022-07-11 14:19:43 +00:00
Sharad Vikram
b666f665ec Rollback of HCB GPU custom call due to internal failures
PiperOrigin-RevId: 460079787
2022-07-10 13:05:27 -07:00
jax authors
66ab792fc0 Merge pull request #11383 from YouJiacheng:Enable-HCB-customCall-implementation-on-GPU
PiperOrigin-RevId: 459872063
2022-07-08 18:23:16 -07:00
YouJiacheng
7c707832aa Enable CustomCall implementation on GPU 2022-07-09 02:29:08 +08:00
Peter Hawkins
5a7bedca37 Increase shard_count for sparse_test_gpu to 20.
1918d39765 updated the wrong test!

This test is close to the timeout in the GPU CI and flakes sometimes.

PiperOrigin-RevId: 459762867
2022-07-08 08:30:26 -07:00
Peter Hawkins
1918d39765 Increase number of shards for GPU sparse_test to 20. 2022-07-07 21:14:25 -04:00
Sharad Vikram
6274b9ed39 Enable Python callbacks on TFRT TPU backend
PiperOrigin-RevId: 459415455
2022-07-06 20:52:50 -07:00
Peter Hawkins
95e79332c0 Add JAX_TEST_TARGETS and JAX_EXCLUDE_TEST_TARGETS environment variables to assist with skipping tests under Bazel.
Add "multiaccelerator" test tags to mark tests that would meaningfully run with more than one accelerator (e.g., GPU).

PiperOrigin-RevId: 459320212
2022-07-06 12:51:43 -07:00