Yash Katariya
4601928277
Enable jit_pjit_api_merge by default "in tests" and disable the current failing tests.
...
PiperOrigin-RevId: 502088044
2023-01-14 11:15:03 -08:00
Qiao Zhang
d203926c16
Expose fp8 in jax dtypes and mlir builder.
...
PiperOrigin-RevId: 501980811
2023-01-13 18:12:12 -08:00
Yash Katariya
5eb23a7615
Fix name_stack
usage of pjit. Now all the metadata of transformations in hlo are correct.
...
PiperOrigin-RevId: 501918212
2023-01-13 12:54:12 -08:00
Yash Katariya
649ee1be34
Make pickle_test.py pass with jit/pjit api merge. Also rename and move some functions around
...
PiperOrigin-RevId: 501878555
2023-01-13 10:16:01 -08:00
Sharad Vikram
c9a57e1b44
Delete jax.experimental.callback
...
PiperOrigin-RevId: 501760507
2023-01-12 22:58:31 -08:00
Yash Katariya
e21c29476d
Add batch_jaxpr2 which tells the caller where batch dims are.
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 501746795
2023-01-12 21:16:59 -08:00
Yash Katariya
94f0ccc54a
Fix host_callback for pjit which was using REPLICATED which was a CanonicalizedParsedPspec
...
PiperOrigin-RevId: 501713533
2023-01-12 18:00:33 -08:00
Yash Katariya
936247a7e5
Fix debugging primitives for pjit. This came up during jit/pjit merge
...
PiperOrigin-RevId: 501710198
2023-01-12 17:40:35 -08:00
Yash Katariya
c8ad89e358
Make jit
a thin wrapper around pjit
which ignores the mesh context manager (just like how it is today)
...
Pass `None` as the resource_env via `jit` because `jit(pjit)` will ignore the outer mesh because `jit` will set the resource env to empty mesh.
This does not make `jit` and `pjit` the same API but it shares all the code between both the APIs (cpp and python) while preserving the current semantics of both `jit` and `pjit`.
PiperOrigin-RevId: 501707496
2023-01-12 17:24:32 -08:00
Sharad Vikram
f729da4a36
Add shards for checkify_test on GPU
...
PiperOrigin-RevId: 501430172
2023-01-11 18:28:37 -08:00
Yash Katariya
66aafb6e16
Don't take the cpp dispatch path for pjit
if it contains ordered effects just like jit
.
...
PiperOrigin-RevId: 501141750
2023-01-10 18:07:23 -08:00
Yash Katariya
c447e987e1
Skip custom_object_test and dynamic_api_test for pjit/jit merge since it doesn't work with jax.Array's too.
...
PiperOrigin-RevId: 501129056
2023-01-10 16:55:51 -08:00
Yash Katariya
e02c1da4c7
Fix debug nans test after merging jit
and pjit
codepaths
...
PiperOrigin-RevId: 501122848
2023-01-10 16:27:00 -08:00
Yash Katariya
849af498d1
Make jaxpr_util_test work with jit/pjit merge
...
PiperOrigin-RevId: 500841015
2023-01-09 16:50:04 -08:00
Adam Paszke
904cd4b98d
Internal change
...
PiperOrigin-RevId: 499812920
2023-01-05 04:13:34 -08:00
Yash Katariya
c3bb26050c
Add pjit
rule to sparse_rules to support pjit
. This is done to merge the jit and pjit API.
...
PiperOrigin-RevId: 499311841
2023-01-03 14:13:19 -08:00
Peter Hawkins
401fbb61a9
Disable xmap_test on TPU under asan due to CI timeouts.
...
PiperOrigin-RevId: 492994226
2022-12-05 06:52:09 -08:00
Peter Hawkins
7495a9e370
[JAX] Enable/disable tests that timed out in CI.
...
Reenable pmap_test since it was recently sped up.
PiperOrigin-RevId: 491650701
2022-11-29 09:02:16 -08:00
Qiao Zhang
4d1c4bc761
Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
...
PiperOrigin-RevId: 491445515
2022-11-28 14:31:48 -08:00
jax authors
d1fbdbc1cf
Rollback of "Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module."
...
PiperOrigin-RevId: 490499003
2022-11-23 07:48:05 -08:00
Adam Paszke
fe56a19904
Shard fft tests to avoid timeouts
...
PiperOrigin-RevId: 490486632
2022-11-23 06:33:13 -08:00
Qiao Zhang
78963b6020
Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
...
PiperOrigin-RevId: 490387796
2022-11-22 18:53:29 -08:00
Peter Hawkins
61aa415356
Disable sparse_test_cpu under msan due to CI timeouts.
...
PiperOrigin-RevId: 490312188
2022-11-22 12:48:34 -08:00
jax authors
518fe6656c
Pickling of Sharding classes: use module level functions when deserializing.
...
This avoids having to pickle the sharding class (which references the module and the Python source file) in the serialized bytes, which happens when deserializing using `classmethod`s.
PiperOrigin-RevId: 490249959
2022-11-22 08:31:16 -08:00
Peter Hawkins
42e367af9c
Fix typo in "nomsan" tag on pmap_test.
...
PiperOrigin-RevId: 489978468
2022-11-21 07:46:13 -08:00
Peter Hawkins
ebee4f4bfd
Disable test variants that time out in CI.
...
PiperOrigin-RevId: 489214464
2022-11-17 08:14:07 -08:00
Jake VanderPlas
66262901f0
[sparse] improve testing framework
2022-11-16 09:58:06 -08:00
Peter Hawkins
51c69ac594
Tag several tests as optonly to prevent test timeouts in debug mode CI builds.
...
PiperOrigin-RevId: 488950972
2022-11-16 08:50:23 -08:00
Peter Hawkins
0548c2d23b
Disable sanitizer builds that are timing out or that are incompatible with test targets.
...
PiperOrigin-RevId: 488919571
2022-11-16 06:00:37 -08:00
jax authors
726b2bc2ee
Add JAX monitoring library that instruments code via events.
...
PiperOrigin-RevId: 488731805
2022-11-15 12:41:41 -08:00
Peter Hawkins
da130cb074
Disable more tests under tsan/asan.
...
PiperOrigin-RevId: 488406459
2022-11-14 10:34:30 -08:00
Peter Hawkins
aa658bde6f
Disable asan/tsan for a number of slow tests.
...
PiperOrigin-RevId: 488356786
2022-11-14 07:12:16 -08:00
Yash Katariya
cc41ee85c4
Mark scipy_signal_test and sparse_test optonly
because it times out under debug mode.
...
PiperOrigin-RevId: 487533356
2022-11-10 07:38:58 -08:00
Yash Katariya
71360edf90
Bump the shard count for TPU to avoid timeouts
...
PiperOrigin-RevId: 487421018
2022-11-09 20:32:12 -08:00
Peter Hawkins
e42e52d4aa
Rename test flag --num_generated_cases to --jax_num_generated_cases.
...
parse_flags_with_absl() only parses flags that start with --jax_. Other flags are only parsed when absl.app's main function runs. But that's too late for test cases: test cases need to have the number of generated cases chosen at module initialization time. Hence the --num_generated_cases flag wasn't doing anything. Oops. By renaming it it works once again.
It might make sense to stop using flags for the number of generated cases and only use environment variables. We defer that to a future change.
Fix many test cases that were shown to be broken with a larger number of test cases enabled.
PiperOrigin-RevId: 487406670
2022-11-09 18:58:05 -08:00
Yash Katariya
2dc804371c
Increase the shard count of scipy_stats_test because it is timing out in OSS builds.
...
PiperOrigin-RevId: 485659146
2022-11-02 12:12:54 -07:00
Yash Katariya
08fbbc3618
Skip checkify checks for pjrt C API because of a failing pjit test with jax.Array is switched on.
...
PiperOrigin-RevId: 485190494
2022-10-31 17:44:21 -07:00
Jake VanderPlas
f699acd931
Internal change
...
PiperOrigin-RevId: 484330913
2022-10-27 13:12:38 -07:00
Sharad Vikram
4d94908e6b
Add more shards to for_loop_test_cpu
...
PiperOrigin-RevId: 484309176
2022-10-27 11:42:10 -07:00
Peter Hawkins
9ab88071a7
Avoid loading scipy eagerly.
...
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
2022-10-12 19:51:09 +00:00
Jake VanderPlas
439217644a
Split parts of lax_numpy_test.py into separate test files.
...
Why? The main test file is getting too big and this hinders iteration on individual tests
PiperOrigin-RevId: 478130215
2022-09-30 19:38:11 -07:00
Yash Katariya
fb8558cfdd
Add jax_array coverage to debug_nans_test
...
PiperOrigin-RevId: 478079509
2022-09-30 14:21:32 -07:00
Yash Katariya
3c7d927a2c
Disable dynamic_api_test and custom_object_test.py with jax.Array. Enable it back when support for it is added. Also don't use xla_shape since its deprecated.
...
PiperOrigin-RevId: 477833061
2022-09-29 15:09:55 -07:00
Peter Hawkins
d63a9442bb
Change jax_jit_test to be a jax_test() under Bazel that works across backends.
...
Make it pass under TPU if x64 types are enabled.
PiperOrigin-RevId: 476994286
2022-09-26 14:38:35 -07:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Tyler Augustine
d52de206cb
Disable tests that timeout in debug mode in CI
...
PiperOrigin-RevId: 476157051
2022-09-22 11:44:56 -07:00
Peter Hawkins
d0e1c3e684
Disable tests under sanitizers that are timing out in CI.
...
PiperOrigin-RevId: 475839926
2022-09-21 08:50:55 -07:00
Yash Katariya
eec1b4a017
Set the sharding of uncommitted single device sharding Arrays correctly and fix some miscellaneous tests with Array too. Enable pjit_test and xmap_test with Array too (all of them are mechanical changes).
...
PiperOrigin-RevId: 474858389
2022-09-16 11:16:27 -07:00
Yash Katariya
45e48b3a7d
Mark multiprocess_gpu_test as manual to skip it in OSS
...
PiperOrigin-RevId: 474806518
2022-09-16 07:08:08 -07:00
Yash Katariya
28741b8e0d
Some miscellaneous changes to make tests pass when jax.Array is enabled by default.
...
1. Add `device_buffer` and `device_buffers` fields to Array as a backwards compatible change for DA and SDA.
2. Support PartitionSpecs as input to in_axis_resources and out_axis_resources when jax_array is enabled as a backwards compatible change since all user code uses this currently. Create a MeshPspecSharding internally.
3. Some tests changes to make them pass
PiperOrigin-RevId: 474642889
2022-09-15 13:27:40 -07:00