85 Commits

Author SHA1 Message Date
Peter Hawkins
51c69ac594 Tag several tests as optonly to prevent test timeouts in debug mode CI builds.
PiperOrigin-RevId: 488950972
2022-11-16 08:50:23 -08:00
Peter Hawkins
0548c2d23b Disable sanitizer builds that are timing out or that are incompatible with test targets.
PiperOrigin-RevId: 488919571
2022-11-16 06:00:37 -08:00
jax authors
726b2bc2ee Add JAX monitoring library that instruments code via events.
PiperOrigin-RevId: 488731805
2022-11-15 12:41:41 -08:00
Peter Hawkins
da130cb074 Disable more tests under tsan/asan.
PiperOrigin-RevId: 488406459
2022-11-14 10:34:30 -08:00
Peter Hawkins
aa658bde6f Disable asan/tsan for a number of slow tests.
PiperOrigin-RevId: 488356786
2022-11-14 07:12:16 -08:00
Yash Katariya
cc41ee85c4 Mark scipy_signal_test and sparse_test optonly because it times out under debug mode.
PiperOrigin-RevId: 487533356
2022-11-10 07:38:58 -08:00
Yash Katariya
71360edf90 Bump the shard count for TPU to avoid timeouts
PiperOrigin-RevId: 487421018
2022-11-09 20:32:12 -08:00
Peter Hawkins
e42e52d4aa Rename test flag --num_generated_cases to --jax_num_generated_cases.
parse_flags_with_absl() only parses flags that start with --jax_. Other flags are only parsed when absl.app's main function runs. But that's too late for test cases: test cases need to have the number of generated cases chosen at module initialization time. Hence the --num_generated_cases flag wasn't doing anything. Oops. By renaming it it works once again.

It might make sense to stop using flags for the number of generated cases and only use environment variables. We defer that to a future change.

Fix many test cases that were shown to be broken with a larger number of test cases enabled.

PiperOrigin-RevId: 487406670
2022-11-09 18:58:05 -08:00
Yash Katariya
2dc804371c Increase the shard count of scipy_stats_test because it is timing out in OSS builds.
PiperOrigin-RevId: 485659146
2022-11-02 12:12:54 -07:00
Yash Katariya
08fbbc3618 Skip checkify checks for pjrt C API because of a failing pjit test with jax.Array is switched on.
PiperOrigin-RevId: 485190494
2022-10-31 17:44:21 -07:00
Jake VanderPlas
f699acd931 Internal change
PiperOrigin-RevId: 484330913
2022-10-27 13:12:38 -07:00
Sharad Vikram
4d94908e6b Add more shards to for_loop_test_cpu
PiperOrigin-RevId: 484309176
2022-10-27 11:42:10 -07:00
Peter Hawkins
9ab88071a7 Avoid loading scipy eagerly.
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
2022-10-12 19:51:09 +00:00
Jake VanderPlas
439217644a Split parts of lax_numpy_test.py into separate test files.
Why? The main test file is getting too big and this hinders iteration on individual tests

PiperOrigin-RevId: 478130215
2022-09-30 19:38:11 -07:00
Yash Katariya
fb8558cfdd Add jax_array coverage to debug_nans_test
PiperOrigin-RevId: 478079509
2022-09-30 14:21:32 -07:00
Yash Katariya
3c7d927a2c Disable dynamic_api_test and custom_object_test.py with jax.Array. Enable it back when support for it is added. Also don't use xla_shape since its deprecated.
PiperOrigin-RevId: 477833061
2022-09-29 15:09:55 -07:00
Peter Hawkins
d63a9442bb Change jax_jit_test to be a jax_test() under Bazel that works across backends.
Make it pass under TPU if x64 types are enabled.

PiperOrigin-RevId: 476994286
2022-09-26 14:38:35 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Tyler Augustine
d52de206cb Disable tests that timeout in debug mode in CI
PiperOrigin-RevId: 476157051
2022-09-22 11:44:56 -07:00
Peter Hawkins
d0e1c3e684 Disable tests under sanitizers that are timing out in CI.
PiperOrigin-RevId: 475839926
2022-09-21 08:50:55 -07:00
Yash Katariya
eec1b4a017 Set the sharding of uncommitted single device sharding Arrays correctly and fix some miscellaneous tests with Array too. Enable pjit_test and xmap_test with Array too (all of them are mechanical changes).
PiperOrigin-RevId: 474858389
2022-09-16 11:16:27 -07:00
Yash Katariya
45e48b3a7d Mark multiprocess_gpu_test as manual to skip it in OSS
PiperOrigin-RevId: 474806518
2022-09-16 07:08:08 -07:00
Yash Katariya
28741b8e0d Some miscellaneous changes to make tests pass when jax.Array is enabled by default.
1. Add `device_buffer` and `device_buffers` fields to Array as a backwards compatible change for DA and SDA.
2. Support PartitionSpecs as input to in_axis_resources and out_axis_resources when jax_array is enabled as a backwards compatible change since all user code uses this currently. Create a MeshPspecSharding internally.
3. Some tests changes to make them pass

PiperOrigin-RevId: 474642889
2022-09-15 13:27:40 -07:00
Jake VanderPlas
13a7034e6a Internal change
PiperOrigin-RevId: 474331907
2022-09-14 10:39:38 -07:00
jax authors
b27d8c1267 Merge pull request #12342 from jakevdp:typing-test
PiperOrigin-RevId: 474154750
2022-09-13 16:36:49 -07:00
Jake VanderPlas
b3c31ebe7d Add typing_test.py 2022-09-13 12:43:51 -07:00
Sharad Vikram
ad326b99da Use cases_from_list to subsample enumerated cases in for_loop_test
PiperOrigin-RevId: 474093596
2022-09-13 12:34:10 -07:00
Sharad Vikram
dc4922fd08 Bump shards on for_loop_test
PiperOrigin-RevId: 474038276
2022-09-13 09:30:04 -07:00
Sharad Vikram
e5725f1df1 Split for_loop_test out of lax_control_flow_test
PiperOrigin-RevId: 473848277
2022-09-12 14:46:07 -07:00
Yash Katariya
864849d075 Increase the GPU shard count to 40 since its timing out every once in a while in kokoro.
PiperOrigin-RevId: 473617463
2022-09-11 14:30:45 -07:00
Sharad Vikram
b6c3b9df19 Split State effect into Read/Write/Accum effects and tie them to Ref avals 2022-09-08 08:04:13 -07:00
jax authors
c8ac6dbe6b [PJRT:C] Control whether to call wrapped C++ PJRT client directly with a global variable kPjRtCApiBypass.
Default value of `kPjRtCApiBypass` is false.

PiperOrigin-RevId: 472847147
2022-09-07 16:52:07 -07:00
Yash Katariya
6340952e2a Make jit == pjit. This means that the lowering and execution paths of jit and pjit are merged.
A fallback to `lower_xla_callable` is taken when pmap appears in the jaxpr during the jit lowering path.

Added support for `keep_unused`, `committed` and `core.Token` to pxla.py.

PiperOrigin-RevId: 470896270
2022-08-29 22:03:21 -07:00
Sudhakar
4b1a2eaaec combine gpu tests 2022-08-25 15:27:07 -07:00
jax authors
2a00533e3e Internal change
PiperOrigin-RevId: 469730523
2022-08-24 08:16:46 -07:00
Peter Hawkins
da4e79a625 Increase some test shardings to reduce CI timeouts under asan/tsan/msan.
PiperOrigin-RevId: 469173324
2022-08-22 06:52:04 -07:00
Yash Katariya
f905d989c1 Make eager pmap tests pass with Array. Also add a slow path for Array in pmap similar to what SDA has. This is required for eager pmap. Adding a slow path removes the need for doing sharding checks in api.py because SDA doesn't do those checks and if the sharding does not match with pmap sharding, then it just defaults to the slow path (exactly like SDA).
PiperOrigin-RevId: 468843310
2022-08-19 21:37:22 -07:00
Yash Katariya
314cf8a439 Use .device() to get the device and platform from the device and fix TODO to point to github issue
PiperOrigin-RevId: 468769708
2022-08-19 13:14:13 -07:00
Yash Katariya
d77848bcc9 Enable jax_array on CPU for the entire JAX test suite!
PiperOrigin-RevId: 468726200
2022-08-19 10:04:35 -07:00
Yash Katariya
2231b0c054 Add array test counterparts to dtypes_test.
PiperOrigin-RevId: 468568788
2022-08-18 15:58:39 -07:00
Yash Katariya
acdae7c237 Add weak type support to Array. Also make all api_test.py tests pass with Array. I have disabled the float0 test for now until I investigate.
PiperOrigin-RevId: 468264910
2022-08-17 12:25:49 -07:00
Sharad Vikram
8068e4638c Re-bump shard count for pmap_test
PiperOrigin-RevId: 468239588
2022-08-17 10:46:19 -07:00
jax authors
7721579700 Internal change
PiperOrigin-RevId: 468068879
2022-08-16 17:43:03 -07:00
Sharad Vikram
6ae46c3d69 Bump pmap_test size to handle new eager tests
PiperOrigin-RevId: 467967021
2022-08-16 10:45:50 -07:00
Peter Hawkins
cd62dbffb3 Increase Bazel sharding of scipy_signal_test on CPU. 2022-08-12 15:14:35 +00:00
Peter Hawkins
3a693e98c0 Increase Bazel sharding of long-running CPU tests. 2022-08-12 13:21:40 +00:00
Yash Katariya
33c4fc4fe2 Pmap should output SDA like Arrays to maintain the current behavior exactly. Split the shard_arg_handler for Array based on whether the mode is pmap or pjit. Why do this? The doc below explains more about the context.
PiperOrigin-RevId: 466849614
2022-08-10 20:11:37 -07:00
Sharad Vikram
2d72dc899b Enable receives in TPU callbacks and add tests
PiperOrigin-RevId: 466103306
2022-08-08 11:42:16 -07:00
Peter Hawkins
efe1b6b44c Disable tsan on linalg_test on TPU.
The test is too slow under tsan and times out in CI.

PiperOrigin-RevId: 465570229
2022-08-05 08:33:42 -07:00
Peter Hawkins
b865111996 Refactor BUILD files to avoid individually naming Python dependencies.
Add a parametric py_deps() macro for adding Python package dependencies for Bazel rules.

Fix build failure with dangling matplotlib reference.

PiperOrigin-RevId: 465562141
2022-08-05 07:49:20 -07:00