5974 Commits

Author SHA1 Message Date
jax authors
d8c487b5c7 Merge pull request #15956 from sharadmv:pure-callback-maximal
PiperOrigin-RevId: 531304370
2023-05-11 14:14:49 -07:00
jax authors
0037ab6240 [PJRT C API] Check whether the PJRT_Api* for the device type already exists before calling dlopen and dlsym.
PiperOrigin-RevId: 531295150
2023-05-11 13:43:17 -07:00
Sharad Vikram
61f22676b0 Add maximal sharding for pure_callback not inside of a shard_map 2023-05-11 13:28:37 -07:00
Parker Schuh
11b34a90fd Skip stream-executor for aot_test.py.
PiperOrigin-RevId: 531248352
2023-05-11 10:51:32 -07:00
Peter Hawkins
9471bb3045 Disable sparsify_test on CPU under tsan.
Under tsan this test times out in CI.

PiperOrigin-RevId: 531210930
2023-05-11 08:33:35 -07:00
Matthew Johnson
f55de18933 [checkify] fix closed_call_p handling
Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-05-10 22:00:16 -07:00
Parker Schuh
261ff9e9ed Stop passing CompileOptions when deserializing.
PiperOrigin-RevId: 531034200
2023-05-10 16:22:54 -07:00
jax authors
74df2d758a Merge pull request #15603 from mattjj:shmap-call-lowering
PiperOrigin-RevId: 530996233
2023-05-10 13:49:51 -07:00
Matthew Johnson
8b66f073d1 [shard-map] experiment with lowering to a Call with attrs
Co-authored-by: Bart Chrzaszcz <bartchr@google.com>
2023-05-10 13:14:04 -07:00
Jake VanderPlas
b250c706b0 Allow opaque dtypes in grad with allow_int=True 2023-05-10 11:43:17 -07:00
Jake VanderPlas
6ada8785aa PRNGKeyArray: fix dynamic slice index dtype 2023-05-10 09:24:18 -07:00
jax authors
538c680e04 Merge pull request #15943 from mattjj:custom-jvp-checkify-symzeros
PiperOrigin-RevId: 530907814
2023-05-10 07:56:40 -07:00
Yash Katariya
befa29b566 Fix the cache on to_gspmd_sharding to depend on if device/backend is set on pjit/jit.
Before if a SingleDeviceSharding went via `to_gspmd_sharding` and then the same SingleDeviceSharding (created when device/backend is set) went via `to_gspmd_sharding`, we would hit the cache and return the first SingleDeviceSharding which didn't have the dynamic attribute on it.

This would eventually cause errors down the stack. The fix is to explicitly thread this argument through all the caches so we miss them and create the correct sharding.

PiperOrigin-RevId: 530712918
2023-05-09 14:24:21 -07:00
Matthew Johnson
391e95a683 fix checkify custom_jvp rule to handle symbolic zeros
likely broken in #15426, or maybe not quite right before either

Co-authored-by: Roy Frostig <frostig@google.com>
2023-05-09 14:12:53 -07:00
jax authors
a2b5bd5230 Merge pull request #15931 from geraschenko:bcoo_reshape
PiperOrigin-RevId: 530657565
2023-05-09 10:58:53 -07:00
Anton Geraschenko
27aa5fb774 Make dimensions argument of bcoo_reshape optional. 2023-05-09 10:38:18 -07:00
jax authors
cf4c1edafa Merge pull request #15920 from froystig:issue15869
PiperOrigin-RevId: 530634021
2023-05-09 09:39:48 -07:00
Roy Frostig
051c5dda6e delegate select lowering to opaque dtype rule
... and implement it for PRNG key arrays
2023-05-08 19:02:42 -07:00
Peter Hawkins
00b75aff82 Add tests for negative inputs to top-k.
Make the top-k test inputs larger.

This test would have caught the top-k bug fixed by https://github.com/openxla/xla/pull/2809

PiperOrigin-RevId: 530398528
2023-05-08 13:51:24 -07:00
Yash Katariya
1629c6c76b Make jax.jit work with vmap(..., spmd_axis_name) when there is no mesh context manager.
This will only work if the input Array's sharding is a NamedSharding

Fixes https://github.com/google/jax/issues/15886

PiperOrigin-RevId: 529758233
2023-05-05 10:48:33 -07:00
Peter Hawkins
e8c735125c Disable more tests that are flaky in CI.
PiperOrigin-RevId: 529724306
2023-05-05 08:33:33 -07:00
Yash Katariya
a6254c75e0 Improve the shape incompatible error message by adding the argument/result name path to it.
PiperOrigin-RevId: 529605855
2023-05-04 21:50:04 -07:00
Jake VanderPlas
4db717c52a KeyArray: support make_array_from_* APIs 2023-05-04 16:32:49 -07:00
pizzud
40d730be49 aot_test: Stop forcing XLA to assume a certain number of devices.
Test cases are still frequently skipped due to lack of CompileOptions
support, but the skip/run behavior does not seem to meaningfully change
compared to a clean checkout. This was verified by inserting an exception
in place of unittest.SkipTest.

PiperOrigin-RevId: 529437419
2023-05-04 09:53:26 -07:00
jax authors
68614b4dcc [XLA:TPU] Fix a bug in eigh that caused a slight loss of accuracy.
PiperOrigin-RevId: 529406623
2023-05-04 07:49:04 -07:00
Peter Hawkins
09fce87f54 Increase sharding of or disable some flaky CI tests.
PiperOrigin-RevId: 529405705
2023-05-04 07:41:56 -07:00
Yash Katariya
bffddf76cb Improve the error raised when wsc is passed a PartitionSpec without a mesh context manager
PiperOrigin-RevId: 529260748
2023-05-03 19:35:51 -07:00
jax authors
d84d19b7d1 Merge pull request #15846 from jakevdp:deprecate-make-sharded
PiperOrigin-RevId: 529172585
2023-05-03 13:02:33 -07:00
Yash Katariya
9515ccf376 Fix pjit + vmap when device is passed as an argument to pjit/jit
PiperOrigin-RevId: 529155035
2023-05-03 11:55:23 -07:00
Jake VanderPlas
9cfe77d5e1 Remove use of deprecated make_sharded_device_array 2023-05-03 10:11:29 -07:00
jax authors
5d143e6eea Merge pull request #15818 from froystig:random-bits-direct
PiperOrigin-RevId: 529090390
2023-05-03 07:56:17 -07:00
Rahul Joshi
9d750ae97d Fix pjit outfeed test avoid potential deadlocks.
PiperOrigin-RevId: 529076350
2023-05-03 06:51:26 -07:00
Benjamin Kramer
545c483e50 Re-enable testTruncNormPdf on CPU
Breaking change was reverted in LLVM 3b8bc83527

PiperOrigin-RevId: 529072697
2023-05-03 06:31:59 -07:00
Roy Frostig
ea3389205f add jax.random.bits 2023-05-03 06:10:05 -07:00
Yash Katariya
7530ac1e09 Improve the error message for incompatible avals when the aval is a scalar
PiperOrigin-RevId: 528918215
2023-05-02 16:22:30 -07:00
Yash Katariya
6506ee2a40 Copybara import of the project:
--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:

[Rollback] Update required Python version to 3.9

PiperOrigin-RevId: 528905991
2023-05-02 15:33:29 -07:00
jax authors
162f09fc8d Stop recursion in spectral bisection eigensolver when the remaining sub-matrix has norm less than epsilon times the input matrix norm, which means that it is pure numerical noise.
PiperOrigin-RevId: 528891206
2023-05-02 14:35:07 -07:00
Peter Hawkins
57e62ca03c Reenable scipy_stats_test in CI.
Disable testTruncNormPdf on CPU, which is failing after an LLVM update.

PiperOrigin-RevId: 528884880
2023-05-02 14:11:08 -07:00
Yash Katariya
40349a8612 Normalize 1 length tuples to a string while getting PartitionSpec from array mapping.
Fixes https://github.com/google/jax/issues/15782

PiperOrigin-RevId: 528796985
2023-05-02 08:55:40 -07:00
jax authors
d5289e627f Merge pull request #15804 from froystig:issue13949
PiperOrigin-RevId: 528790988
2023-05-02 08:30:46 -07:00
Yash Katariya
c52e48b6c0 Only return the same input Sharding object is the original aval's ndim and out_aval's ndim are the same.
This is because if both the OpShardings are replicated then the ndim is not encoded in the OpSharding and it will return True even if the Sharding is incompatible with the output's ndim. Concretely `NamedSharding({'x': 1, y: '2'}, P('x'))` is not compatible with a input with `ndim == 0`.

PiperOrigin-RevId: 528621971
2023-05-01 17:39:51 -07:00
jax authors
12e3db5fbc Merge pull request #15813 from jakevdp:keyarray-device-put-sharded
PiperOrigin-RevId: 528578837
2023-05-01 14:41:55 -07:00
Jake VanderPlas
979aa3235b KeyArray: implement sharded & replicated device_put 2023-05-01 14:17:01 -07:00
Skye Wanderman-Milne
70cac773f7 Exclude scipy_fft_test from msan as well as t/asan.
PiperOrigin-RevId: 528562775
2023-05-01 13:42:24 -07:00
Skye Wanderman-Milne
fa68c1f882 Bump up lax_test TPU sharding to avoid asan timeouts
PiperOrigin-RevId: 528559870
2023-05-01 13:31:22 -07:00
Yash Katariya
4a3fb238f6 Return the same sharding object if the output OpSharding matches the input OpSharding.
Fixes https://github.com/google/jax/issues/15782

PiperOrigin-RevId: 528531594
2023-05-01 11:46:57 -07:00
Jake VanderPlas
57af5360a1 Update required Python version to 3.9 2023-05-01 10:00:57 -07:00
Roy Frostig
8d4d520933 resolve opaque dtypes in MLIR callback lowering and in XLA shape translation 2023-05-01 08:21:54 -07:00
Skye Wanderman-Milne
c662fd216d Disable tsan CI for random_test_with_custom_prng to avoid timeouts.
asan is already disabled, and the comment and "cpu" case indicates
that tsan should already have been disabled as well.

PiperOrigin-RevId: 528000458
2023-04-28 15:26:46 -07:00
Jake VanderPlas
054fca5cd4 KeyArray: define itemsize on opaque dtype 2023-04-27 15:59:57 -07:00