jax authors
d8c487b5c7
Merge pull request #15956 from sharadmv:pure-callback-maximal
...
PiperOrigin-RevId: 531304370
2023-05-11 14:14:49 -07:00
jax authors
0037ab6240
[PJRT C API] Check whether the PJRT_Api* for the device type already exists before calling dlopen and dlsym.
...
PiperOrigin-RevId: 531295150
2023-05-11 13:43:17 -07:00
Sharad Vikram
61f22676b0
Add maximal sharding for pure_callback not inside of a shard_map
2023-05-11 13:28:37 -07:00
Parker Schuh
11b34a90fd
Skip stream-executor for aot_test.py.
...
PiperOrigin-RevId: 531248352
2023-05-11 10:51:32 -07:00
Peter Hawkins
9471bb3045
Disable sparsify_test on CPU under tsan.
...
Under tsan this test times out in CI.
PiperOrigin-RevId: 531210930
2023-05-11 08:33:35 -07:00
Matthew Johnson
f55de18933
[checkify] fix closed_call_p handling
...
Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-05-10 22:00:16 -07:00
Parker Schuh
261ff9e9ed
Stop passing CompileOptions when deserializing.
...
PiperOrigin-RevId: 531034200
2023-05-10 16:22:54 -07:00
jax authors
74df2d758a
Merge pull request #15603 from mattjj:shmap-call-lowering
...
PiperOrigin-RevId: 530996233
2023-05-10 13:49:51 -07:00
Matthew Johnson
8b66f073d1
[shard-map] experiment with lowering to a Call with attrs
...
Co-authored-by: Bart Chrzaszcz <bartchr@google.com>
2023-05-10 13:14:04 -07:00
Jake VanderPlas
b250c706b0
Allow opaque dtypes in grad with allow_int=True
2023-05-10 11:43:17 -07:00
Jake VanderPlas
6ada8785aa
PRNGKeyArray: fix dynamic slice index dtype
2023-05-10 09:24:18 -07:00
jax authors
538c680e04
Merge pull request #15943 from mattjj:custom-jvp-checkify-symzeros
...
PiperOrigin-RevId: 530907814
2023-05-10 07:56:40 -07:00
Yash Katariya
befa29b566
Fix the cache on to_gspmd_sharding
to depend on if device/backend is set on pjit/jit.
...
Before if a SingleDeviceSharding went via `to_gspmd_sharding` and then the same SingleDeviceSharding (created when device/backend is set) went via `to_gspmd_sharding`, we would hit the cache and return the first SingleDeviceSharding which didn't have the dynamic attribute on it.
This would eventually cause errors down the stack. The fix is to explicitly thread this argument through all the caches so we miss them and create the correct sharding.
PiperOrigin-RevId: 530712918
2023-05-09 14:24:21 -07:00
Matthew Johnson
391e95a683
fix checkify custom_jvp rule to handle symbolic zeros
...
likely broken in #15426 , or maybe not quite right before either
Co-authored-by: Roy Frostig <frostig@google.com>
2023-05-09 14:12:53 -07:00
jax authors
a2b5bd5230
Merge pull request #15931 from geraschenko:bcoo_reshape
...
PiperOrigin-RevId: 530657565
2023-05-09 10:58:53 -07:00
Anton Geraschenko
27aa5fb774
Make dimensions
argument of bcoo_reshape optional.
2023-05-09 10:38:18 -07:00
jax authors
cf4c1edafa
Merge pull request #15920 from froystig:issue15869
...
PiperOrigin-RevId: 530634021
2023-05-09 09:39:48 -07:00
Roy Frostig
051c5dda6e
delegate select
lowering to opaque dtype rule
...
... and implement it for PRNG key arrays
2023-05-08 19:02:42 -07:00
Peter Hawkins
00b75aff82
Add tests for negative inputs to top-k.
...
Make the top-k test inputs larger.
This test would have caught the top-k bug fixed by https://github.com/openxla/xla/pull/2809
PiperOrigin-RevId: 530398528
2023-05-08 13:51:24 -07:00
Yash Katariya
1629c6c76b
Make jax.jit
work with vmap(..., spmd_axis_name) when there is no mesh context manager.
...
This will only work if the input Array's sharding is a NamedSharding
Fixes https://github.com/google/jax/issues/15886
PiperOrigin-RevId: 529758233
2023-05-05 10:48:33 -07:00
Peter Hawkins
e8c735125c
Disable more tests that are flaky in CI.
...
PiperOrigin-RevId: 529724306
2023-05-05 08:33:33 -07:00
Yash Katariya
a6254c75e0
Improve the shape incompatible error message by adding the argument/result name path to it.
...
PiperOrigin-RevId: 529605855
2023-05-04 21:50:04 -07:00
Jake VanderPlas
4db717c52a
KeyArray: support make_array_from_* APIs
2023-05-04 16:32:49 -07:00
pizzud
40d730be49
aot_test: Stop forcing XLA to assume a certain number of devices.
...
Test cases are still frequently skipped due to lack of CompileOptions
support, but the skip/run behavior does not seem to meaningfully change
compared to a clean checkout. This was verified by inserting an exception
in place of unittest.SkipTest.
PiperOrigin-RevId: 529437419
2023-05-04 09:53:26 -07:00
jax authors
68614b4dcc
[XLA:TPU] Fix a bug in eigh that caused a slight loss of accuracy.
...
PiperOrigin-RevId: 529406623
2023-05-04 07:49:04 -07:00
Peter Hawkins
09fce87f54
Increase sharding of or disable some flaky CI tests.
...
PiperOrigin-RevId: 529405705
2023-05-04 07:41:56 -07:00
Yash Katariya
bffddf76cb
Improve the error raised when wsc is passed a PartitionSpec without a mesh context manager
...
PiperOrigin-RevId: 529260748
2023-05-03 19:35:51 -07:00
jax authors
d84d19b7d1
Merge pull request #15846 from jakevdp:deprecate-make-sharded
...
PiperOrigin-RevId: 529172585
2023-05-03 13:02:33 -07:00
Yash Katariya
9515ccf376
Fix pjit + vmap when device
is passed as an argument to pjit/jit
...
PiperOrigin-RevId: 529155035
2023-05-03 11:55:23 -07:00
Jake VanderPlas
9cfe77d5e1
Remove use of deprecated make_sharded_device_array
2023-05-03 10:11:29 -07:00
jax authors
5d143e6eea
Merge pull request #15818 from froystig:random-bits-direct
...
PiperOrigin-RevId: 529090390
2023-05-03 07:56:17 -07:00
Rahul Joshi
9d750ae97d
Fix pjit outfeed test avoid potential deadlocks.
...
PiperOrigin-RevId: 529076350
2023-05-03 06:51:26 -07:00
Benjamin Kramer
545c483e50
Re-enable testTruncNormPdf on CPU
...
Breaking change was reverted in LLVM 3b8bc83527
PiperOrigin-RevId: 529072697
2023-05-03 06:31:59 -07:00
Roy Frostig
ea3389205f
add jax.random.bits
2023-05-03 06:10:05 -07:00
Yash Katariya
7530ac1e09
Improve the error message for incompatible avals when the aval is a scalar
...
PiperOrigin-RevId: 528918215
2023-05-02 16:22:30 -07:00
Yash Katariya
6506ee2a40
Copybara import of the project:
...
--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:
[Rollback] Update required Python version to 3.9
PiperOrigin-RevId: 528905991
2023-05-02 15:33:29 -07:00
jax authors
162f09fc8d
Stop recursion in spectral bisection eigensolver when the remaining sub-matrix has norm less than epsilon times the input matrix norm, which means that it is pure numerical noise.
...
PiperOrigin-RevId: 528891206
2023-05-02 14:35:07 -07:00
Peter Hawkins
57e62ca03c
Reenable scipy_stats_test in CI.
...
Disable testTruncNormPdf on CPU, which is failing after an LLVM update.
PiperOrigin-RevId: 528884880
2023-05-02 14:11:08 -07:00
Yash Katariya
40349a8612
Normalize 1 length tuples to a string while getting PartitionSpec from array mapping.
...
Fixes https://github.com/google/jax/issues/15782
PiperOrigin-RevId: 528796985
2023-05-02 08:55:40 -07:00
jax authors
d5289e627f
Merge pull request #15804 from froystig:issue13949
...
PiperOrigin-RevId: 528790988
2023-05-02 08:30:46 -07:00
Yash Katariya
c52e48b6c0
Only return the same input Sharding object is the original aval's ndim and out_aval's ndim are the same.
...
This is because if both the OpShardings are replicated then the ndim is not encoded in the OpSharding and it will return True even if the Sharding is incompatible with the output's ndim. Concretely `NamedSharding({'x': 1, y: '2'}, P('x'))` is not compatible with a input with `ndim == 0`.
PiperOrigin-RevId: 528621971
2023-05-01 17:39:51 -07:00
jax authors
12e3db5fbc
Merge pull request #15813 from jakevdp:keyarray-device-put-sharded
...
PiperOrigin-RevId: 528578837
2023-05-01 14:41:55 -07:00
Jake VanderPlas
979aa3235b
KeyArray: implement sharded & replicated device_put
2023-05-01 14:17:01 -07:00
Skye Wanderman-Milne
70cac773f7
Exclude scipy_fft_test from msan as well as t/asan.
...
PiperOrigin-RevId: 528562775
2023-05-01 13:42:24 -07:00
Skye Wanderman-Milne
fa68c1f882
Bump up lax_test TPU sharding to avoid asan timeouts
...
PiperOrigin-RevId: 528559870
2023-05-01 13:31:22 -07:00
Yash Katariya
4a3fb238f6
Return the same sharding object if the output OpSharding matches the input OpSharding.
...
Fixes https://github.com/google/jax/issues/15782
PiperOrigin-RevId: 528531594
2023-05-01 11:46:57 -07:00
Jake VanderPlas
57af5360a1
Update required Python version to 3.9
2023-05-01 10:00:57 -07:00
Roy Frostig
8d4d520933
resolve opaque dtypes in MLIR callback lowering and in XLA shape translation
2023-05-01 08:21:54 -07:00
Skye Wanderman-Milne
c662fd216d
Disable tsan CI for random_test_with_custom_prng to avoid timeouts.
...
asan is already disabled, and the comment and "cpu" case indicates
that tsan should already have been disabled as well.
PiperOrigin-RevId: 528000458
2023-04-28 15:26:46 -07:00
Jake VanderPlas
054fca5cd4
KeyArray: define itemsize on opaque dtype
2023-04-27 15:59:57 -07:00