Peter Hawkins
51c69ac594
Tag several tests as optonly to prevent test timeouts in debug mode CI builds.
...
PiperOrigin-RevId: 488950972
2022-11-16 08:50:23 -08:00
Peter Hawkins
0548c2d23b
Disable sanitizer builds that are timing out or that are incompatible with test targets.
...
PiperOrigin-RevId: 488919571
2022-11-16 06:00:37 -08:00
jax authors
726b2bc2ee
Add JAX monitoring library that instruments code via events.
...
PiperOrigin-RevId: 488731805
2022-11-15 12:41:41 -08:00
Peter Hawkins
da130cb074
Disable more tests under tsan/asan.
...
PiperOrigin-RevId: 488406459
2022-11-14 10:34:30 -08:00
Peter Hawkins
aa658bde6f
Disable asan/tsan for a number of slow tests.
...
PiperOrigin-RevId: 488356786
2022-11-14 07:12:16 -08:00
Yash Katariya
cc41ee85c4
Mark scipy_signal_test and sparse_test optonly
because it times out under debug mode.
...
PiperOrigin-RevId: 487533356
2022-11-10 07:38:58 -08:00
Yash Katariya
71360edf90
Bump the shard count for TPU to avoid timeouts
...
PiperOrigin-RevId: 487421018
2022-11-09 20:32:12 -08:00
Peter Hawkins
e42e52d4aa
Rename test flag --num_generated_cases to --jax_num_generated_cases.
...
parse_flags_with_absl() only parses flags that start with --jax_. Other flags are only parsed when absl.app's main function runs. But that's too late for test cases: test cases need to have the number of generated cases chosen at module initialization time. Hence the --num_generated_cases flag wasn't doing anything. Oops. By renaming it it works once again.
It might make sense to stop using flags for the number of generated cases and only use environment variables. We defer that to a future change.
Fix many test cases that were shown to be broken with a larger number of test cases enabled.
PiperOrigin-RevId: 487406670
2022-11-09 18:58:05 -08:00
Yash Katariya
2dc804371c
Increase the shard count of scipy_stats_test because it is timing out in OSS builds.
...
PiperOrigin-RevId: 485659146
2022-11-02 12:12:54 -07:00
Yash Katariya
08fbbc3618
Skip checkify checks for pjrt C API because of a failing pjit test with jax.Array is switched on.
...
PiperOrigin-RevId: 485190494
2022-10-31 17:44:21 -07:00
Jake VanderPlas
f699acd931
Internal change
...
PiperOrigin-RevId: 484330913
2022-10-27 13:12:38 -07:00
Sharad Vikram
4d94908e6b
Add more shards to for_loop_test_cpu
...
PiperOrigin-RevId: 484309176
2022-10-27 11:42:10 -07:00
Peter Hawkins
9ab88071a7
Avoid loading scipy eagerly.
...
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
2022-10-12 19:51:09 +00:00
Jake VanderPlas
439217644a
Split parts of lax_numpy_test.py into separate test files.
...
Why? The main test file is getting too big and this hinders iteration on individual tests
PiperOrigin-RevId: 478130215
2022-09-30 19:38:11 -07:00
Yash Katariya
fb8558cfdd
Add jax_array coverage to debug_nans_test
...
PiperOrigin-RevId: 478079509
2022-09-30 14:21:32 -07:00
Yash Katariya
3c7d927a2c
Disable dynamic_api_test and custom_object_test.py with jax.Array. Enable it back when support for it is added. Also don't use xla_shape since its deprecated.
...
PiperOrigin-RevId: 477833061
2022-09-29 15:09:55 -07:00
Peter Hawkins
d63a9442bb
Change jax_jit_test to be a jax_test() under Bazel that works across backends.
...
Make it pass under TPU if x64 types are enabled.
PiperOrigin-RevId: 476994286
2022-09-26 14:38:35 -07:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Tyler Augustine
d52de206cb
Disable tests that timeout in debug mode in CI
...
PiperOrigin-RevId: 476157051
2022-09-22 11:44:56 -07:00
Peter Hawkins
d0e1c3e684
Disable tests under sanitizers that are timing out in CI.
...
PiperOrigin-RevId: 475839926
2022-09-21 08:50:55 -07:00
Yash Katariya
eec1b4a017
Set the sharding of uncommitted single device sharding Arrays correctly and fix some miscellaneous tests with Array too. Enable pjit_test and xmap_test with Array too (all of them are mechanical changes).
...
PiperOrigin-RevId: 474858389
2022-09-16 11:16:27 -07:00
Yash Katariya
45e48b3a7d
Mark multiprocess_gpu_test as manual to skip it in OSS
...
PiperOrigin-RevId: 474806518
2022-09-16 07:08:08 -07:00
Yash Katariya
28741b8e0d
Some miscellaneous changes to make tests pass when jax.Array is enabled by default.
...
1. Add `device_buffer` and `device_buffers` fields to Array as a backwards compatible change for DA and SDA.
2. Support PartitionSpecs as input to in_axis_resources and out_axis_resources when jax_array is enabled as a backwards compatible change since all user code uses this currently. Create a MeshPspecSharding internally.
3. Some tests changes to make them pass
PiperOrigin-RevId: 474642889
2022-09-15 13:27:40 -07:00
Jake VanderPlas
13a7034e6a
Internal change
...
PiperOrigin-RevId: 474331907
2022-09-14 10:39:38 -07:00
jax authors
b27d8c1267
Merge pull request #12342 from jakevdp:typing-test
...
PiperOrigin-RevId: 474154750
2022-09-13 16:36:49 -07:00
Jake VanderPlas
b3c31ebe7d
Add typing_test.py
2022-09-13 12:43:51 -07:00
Sharad Vikram
ad326b99da
Use cases_from_list to subsample enumerated cases in for_loop_test
...
PiperOrigin-RevId: 474093596
2022-09-13 12:34:10 -07:00
Sharad Vikram
dc4922fd08
Bump shards on for_loop_test
...
PiperOrigin-RevId: 474038276
2022-09-13 09:30:04 -07:00
Sharad Vikram
e5725f1df1
Split for_loop_test out of lax_control_flow_test
...
PiperOrigin-RevId: 473848277
2022-09-12 14:46:07 -07:00
Yash Katariya
864849d075
Increase the GPU shard count to 40 since its timing out every once in a while in kokoro.
...
PiperOrigin-RevId: 473617463
2022-09-11 14:30:45 -07:00
Sharad Vikram
b6c3b9df19
Split State effect into Read/Write/Accum effects and tie them to Ref avals
2022-09-08 08:04:13 -07:00
jax authors
c8ac6dbe6b
[PJRT:C] Control whether to call wrapped C++ PJRT client directly with a global variable kPjRtCApiBypass
.
...
Default value of `kPjRtCApiBypass` is false.
PiperOrigin-RevId: 472847147
2022-09-07 16:52:07 -07:00
Yash Katariya
6340952e2a
Make jit == pjit. This means that the lowering and execution paths of jit and pjit are merged.
...
A fallback to `lower_xla_callable` is taken when pmap appears in the jaxpr during the jit lowering path.
Added support for `keep_unused`, `committed` and `core.Token` to pxla.py.
PiperOrigin-RevId: 470896270
2022-08-29 22:03:21 -07:00
Sudhakar
4b1a2eaaec
combine gpu tests
2022-08-25 15:27:07 -07:00
jax authors
2a00533e3e
Internal change
...
PiperOrigin-RevId: 469730523
2022-08-24 08:16:46 -07:00
Peter Hawkins
da4e79a625
Increase some test shardings to reduce CI timeouts under asan/tsan/msan.
...
PiperOrigin-RevId: 469173324
2022-08-22 06:52:04 -07:00
Yash Katariya
f905d989c1
Make eager pmap tests pass with Array
. Also add a slow path for Array in pmap
similar to what SDA has. This is required for eager pmap. Adding a slow path removes the need for doing sharding checks in api.py because SDA doesn't do those checks and if the sharding does not match with pmap sharding, then it just defaults to the slow path (exactly like SDA).
...
PiperOrigin-RevId: 468843310
2022-08-19 21:37:22 -07:00
Yash Katariya
314cf8a439
Use .device()
to get the device and platform from the device and fix TODO to point to github issue
...
PiperOrigin-RevId: 468769708
2022-08-19 13:14:13 -07:00
Yash Katariya
d77848bcc9
Enable jax_array
on CPU for the entire JAX test suite!
...
PiperOrigin-RevId: 468726200
2022-08-19 10:04:35 -07:00
Yash Katariya
2231b0c054
Add array test counterparts to dtypes_test.
...
PiperOrigin-RevId: 468568788
2022-08-18 15:58:39 -07:00
Yash Katariya
acdae7c237
Add weak type support to Array. Also make all api_test.py tests pass with Array. I have disabled the float0
test for now until I investigate.
...
PiperOrigin-RevId: 468264910
2022-08-17 12:25:49 -07:00
Sharad Vikram
8068e4638c
Re-bump shard count for pmap_test
...
PiperOrigin-RevId: 468239588
2022-08-17 10:46:19 -07:00
jax authors
7721579700
Internal change
...
PiperOrigin-RevId: 468068879
2022-08-16 17:43:03 -07:00
Sharad Vikram
6ae46c3d69
Bump pmap_test size to handle new eager tests
...
PiperOrigin-RevId: 467967021
2022-08-16 10:45:50 -07:00
Peter Hawkins
cd62dbffb3
Increase Bazel sharding of scipy_signal_test on CPU.
2022-08-12 15:14:35 +00:00
Peter Hawkins
3a693e98c0
Increase Bazel sharding of long-running CPU tests.
2022-08-12 13:21:40 +00:00
Yash Katariya
33c4fc4fe2
Pmap should output SDA like Array
s to maintain the current behavior exactly. Split the shard_arg_handler for Array
based on whether the mode is pmap or pjit. Why do this? The doc below explains more about the context.
...
PiperOrigin-RevId: 466849614
2022-08-10 20:11:37 -07:00
Sharad Vikram
2d72dc899b
Enable receives in TPU callbacks and add tests
...
PiperOrigin-RevId: 466103306
2022-08-08 11:42:16 -07:00
Peter Hawkins
efe1b6b44c
Disable tsan on linalg_test on TPU.
...
The test is too slow under tsan and times out in CI.
PiperOrigin-RevId: 465570229
2022-08-05 08:33:42 -07:00
Peter Hawkins
b865111996
Refactor BUILD files to avoid individually naming Python dependencies.
...
Add a parametric py_deps() macro for adding Python package dependencies for Bazel rules.
Fix build failure with dangling matplotlib reference.
PiperOrigin-RevId: 465562141
2022-08-05 07:49:20 -07:00