Sharad Vikram
c9a57e1b44
Delete jax.experimental.callback
...
PiperOrigin-RevId: 501760507
2023-01-12 22:58:31 -08:00
Yash Katariya
e21c29476d
Add batch_jaxpr2 which tells the caller where batch dims are.
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 501746795
2023-01-12 21:16:59 -08:00
Yash Katariya
94f0ccc54a
Fix host_callback for pjit which was using REPLICATED which was a CanonicalizedParsedPspec
...
PiperOrigin-RevId: 501713533
2023-01-12 18:00:33 -08:00
Yash Katariya
936247a7e5
Fix debugging primitives for pjit. This came up during jit/pjit merge
...
PiperOrigin-RevId: 501710198
2023-01-12 17:40:35 -08:00
Yash Katariya
c8ad89e358
Make jit
a thin wrapper around pjit
which ignores the mesh context manager (just like how it is today)
...
Pass `None` as the resource_env via `jit` because `jit(pjit)` will ignore the outer mesh because `jit` will set the resource env to empty mesh.
This does not make `jit` and `pjit` the same API but it shares all the code between both the APIs (cpp and python) while preserving the current semantics of both `jit` and `pjit`.
PiperOrigin-RevId: 501707496
2023-01-12 17:24:32 -08:00
jax authors
7206cb5b7b
Merge pull request #13940 from DPS0340:main
...
PiperOrigin-RevId: 501692167
2023-01-12 16:10:16 -08:00
Parker Schuh
30d64f38f1
Add 'hard xmap' support for pure_callback.
...
PiperOrigin-RevId: 501689068
2023-01-12 15:56:50 -08:00
jax authors
34e10e3495
Merge pull request #13892 from jakevdp:bcoo-sum-duplicates-batch
...
PiperOrigin-RevId: 501676606
2023-01-12 15:02:53 -08:00
jax authors
15ec37c581
Merge pull request #13977 from skye:pjrt_c_api_marker
...
PiperOrigin-RevId: 501671140
2023-01-12 14:41:43 -08:00
Jake VanderPlas
f314f2d504
[sparse] generalize batch rule for bcoo_sum_duplicates & improve tests
2023-01-12 14:28:12 -08:00
Skye Wanderman-Milne
f90b5eed52
Add pjrt_c_api_unimplemented
pytest marker to skip unsupported tests.
...
Also adds `test_util.pytest_mark_if_available` helper function.
2023-01-12 22:17:23 +00:00
Dinghua Li
360589a8a1
Allow call_tf to accept shape-polymorphic inputs if the output shapes are fully static.
...
PiperOrigin-RevId: 501645235
2023-01-12 13:04:39 -08:00
jax authors
ba506cbfe2
Merge pull request #13887 from jakevdp:bcoo-dot-general-batch
...
PiperOrigin-RevId: 501637620
2023-01-12 12:34:39 -08:00
Jake VanderPlas
a6a3b59748
[sparse] generalize batch rule for bcoo_dot_general
2023-01-12 12:03:21 -08:00
jax authors
a18bfa57b2
Merge pull request #13889 from jakevdp:bcoo-spdot-general-batch
...
PiperOrigin-RevId: 501626784
2023-01-12 11:52:04 -08:00
Jake VanderPlas
eaf6179594
[sparse] generalize batch rule for bcoo_spdot_general
2023-01-12 10:59:28 -08:00
jax authors
f7c915e6a2
Merge pull request #13909 from jakevdp:sparse-nfold-vmap
...
PiperOrigin-RevId: 501602422
2023-01-12 10:25:47 -08:00
Jake VanderPlas
c0c347bbe7
[sparse] add nfold_vmap utility
2023-01-12 10:09:58 -08:00
Sharad Vikram
f729da4a36
Add shards for checkify_test on GPU
...
PiperOrigin-RevId: 501430172
2023-01-11 18:28:37 -08:00
jax authors
f0cca20a67
Merge pull request #13929 from nvcastet:add_ompi_init
...
PiperOrigin-RevId: 501429401
2023-01-11 18:20:51 -08:00
Yash Katariya
68c43e6c99
Update the non-contiguous error message to not say GDA anymore
...
PiperOrigin-RevId: 501396344
2023-01-11 15:35:15 -08:00
Yash Katariya
acd8dadc74
Make some api_test pass with jit/pjit merge
...
PiperOrigin-RevId: 501392938
2023-01-11 15:21:16 -08:00
Nicolas Castet
b86030d86f
Add Open MPI automatic distributed initialization
2023-01-11 17:08:09 -06:00
jax authors
3cd069f2da
Merge pull request #13967 from jakevdp:array-attrs
...
PiperOrigin-RevId: 501320927
2023-01-11 10:44:58 -08:00
jax authors
7a6c75339f
Merge pull request #13958 from mattjj:pjit-partial-eval-2
...
PiperOrigin-RevId: 501319644
2023-01-11 10:36:39 -08:00
Jake VanderPlas
e1738af5b2
[typing] add missing attributes to jax.Array
2023-01-11 10:00:59 -08:00
Matthew Johnson
8b585302db
add pjit partial_eval_jaxpr_custom rule
...
fix some issues with closed_call's partial_eval_jaxpr_custom rule
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-01-11 09:30:49 -08:00
jax authors
ced7332587
Merge pull request #13872 from gnecula:tf_pad
...
PiperOrigin-RevId: 501268817
2023-01-11 06:49:05 -08:00
jax authors
45197ed8db
Merge pull request #13545 from LenaMartens:check-initial
...
PiperOrigin-RevId: 501248830
2023-01-11 04:52:02 -08:00
George Necula
f7093955dc
[jax2tf] Fixed the shape-polymorphic lowering for lax.pad and dynamic_slice
...
Generate DynamicPadOp instea of PadOp when the padding
sizes are not constant.
Fix the generation of RealDynamicSliceOp.
Exclude some tests that fail due to unimplemented support
for custom calls with polymorphic shapes.
2023-01-11 13:02:48 +01:00
Yash Katariya
857febcc15
Pass in the debug_info while we create jaxpr the first time so that the error messages are better
...
PiperOrigin-RevId: 501185867
2023-01-10 22:42:14 -08:00
Yash Katariya
66aafb6e16
Don't take the cpp dispatch path for pjit
if it contains ordered effects just like jit
.
...
PiperOrigin-RevId: 501141750
2023-01-10 18:07:23 -08:00
Yash Katariya
c447e987e1
Skip custom_object_test and dynamic_api_test for pjit/jit merge since it doesn't work with jax.Array's too.
...
PiperOrigin-RevId: 501129056
2023-01-10 16:55:51 -08:00
Yash Katariya
e02c1da4c7
Fix debug nans test after merging jit
and pjit
codepaths
...
PiperOrigin-RevId: 501122848
2023-01-10 16:27:00 -08:00
jax authors
3f712480c6
Merge pull request #13950 from jakevdp:fix-gpu-headings
...
PiperOrigin-RevId: 501111194
2023-01-10 15:35:50 -08:00
Ruoxin Sang
17be6fe1cd
Use tf.ensure_shape
instead of tensor.set_shape
for hinting TF graph to set static shape output. The difference is that tensor.set_shape
only changes the shape information in python, which can still confuse the compiler when dynamic padder is enabled. While tf.ensure_shape
creates a EnsureShape
op in the graph, which XLA compiler can propagate and set the dynamic dimension properly.
...
Re-enable the failing `gather_from_take_indices` test.
PiperOrigin-RevId: 501105068
2023-01-10 15:07:59 -08:00
Jake VanderPlas
140fcd1752
DOC: fix heading levels in GPU custom ops
2023-01-10 13:51:22 -08:00
jax authors
3e52c2d3fd
Merge pull request #13946 from jakevdp:faq-tracer
...
PiperOrigin-RevId: 501073465
2023-01-10 13:07:56 -08:00
jax authors
fa111a22f0
Merge pull request #13944 from nvlcambier:nvlcambier/coordinator_address_doc
...
PiperOrigin-RevId: 501072355
2023-01-10 12:59:58 -08:00
Jake VanderPlas
1bb4de280e
DOC: add FAQ entry on converting a tracer to an array
2023-01-10 12:28:16 -08:00
jax authors
6d2c157354
Merge pull request #13948 from jakevdp:fix-finfo-test
...
PiperOrigin-RevId: 501063882
2023-01-10 12:26:33 -08:00
Jake VanderPlas
264bc9531e
Fix finfo test for older numpy versions
2023-01-10 12:17:31 -08:00
Leopold Cambier
d1edad6a68
Typos: suited -> suitable, node -> host
2023-01-10 11:01:54 -08:00
jax authors
e83621ae45
Merge pull request #13937 from jakevdp:finfo-tests
...
PiperOrigin-RevId: 501035949
2023-01-10 10:44:19 -08:00
Leopold Cambier
7e395c9bbe
DOC: add note about localhost & friends in jax.distributed.initialize
2023-01-10 09:17:31 -08:00
Jake VanderPlas
31fd81f2d5
Add tests of dtypes.finfo properties
2023-01-10 09:02:00 -08:00
Marc van Zee
10847a9372
[jax2tf] Simplifies model testing file writing logic.
...
PiperOrigin-RevId: 500984975
2023-01-10 07:08:00 -08:00
jax authors
62f2b9680b
Merge pull request #13917 from gnecula:tf_bug_gda
...
PiperOrigin-RevId: 500971319
2023-01-10 05:58:28 -08:00
Marc van Zee
4dd289d4c0
[jax2tf] Fixes a bug in converters test.
...
Ensures that models_test_main.py can be executed from OS as well.
PiperOrigin-RevId: 500968754
2023-01-10 05:42:01 -08:00
George Necula
21ebf9042d
[jax2tf] Fixed the conversion of a function that contains an inner pjit
...
In experimental_native_lowering when we convert a function that is not
a jit or pjit, we wrap it with an implicit jit. We used to specify the
backend when doing this conversion, which bypassed some logic in jit
to handle the merging of jit and pjit code paths. We now drop the
backend parameter to the implicit jit.
We also moved some pjit tests from jax2tf to sharding_test and dropped
the old disabled test for teh GDAs, since GDAs are going away.
2023-01-10 14:12:41 +01:00