14394 Commits

Author SHA1 Message Date
Sharad Vikram
c9a57e1b44 Delete jax.experimental.callback
PiperOrigin-RevId: 501760507
2023-01-12 22:58:31 -08:00
Yash Katariya
e21c29476d Add batch_jaxpr2 which tells the caller where batch dims are.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 501746795
2023-01-12 21:16:59 -08:00
Yash Katariya
94f0ccc54a Fix host_callback for pjit which was using REPLICATED which was a CanonicalizedParsedPspec
PiperOrigin-RevId: 501713533
2023-01-12 18:00:33 -08:00
Yash Katariya
936247a7e5 Fix debugging primitives for pjit. This came up during jit/pjit merge
PiperOrigin-RevId: 501710198
2023-01-12 17:40:35 -08:00
Yash Katariya
c8ad89e358 Make jit a thin wrapper around pjit which ignores the mesh context manager (just like how it is today)
Pass `None` as the resource_env via `jit` because `jit(pjit)` will ignore the outer mesh because `jit` will set the resource env to empty mesh.

This does not make `jit` and `pjit` the same API but it shares all the code between both the APIs (cpp and python) while preserving the current semantics of both `jit` and `pjit`.

PiperOrigin-RevId: 501707496
2023-01-12 17:24:32 -08:00
jax authors
7206cb5b7b Merge pull request #13940 from DPS0340:main
PiperOrigin-RevId: 501692167
2023-01-12 16:10:16 -08:00
Parker Schuh
30d64f38f1 Add 'hard xmap' support for pure_callback.
PiperOrigin-RevId: 501689068
2023-01-12 15:56:50 -08:00
jax authors
34e10e3495 Merge pull request #13892 from jakevdp:bcoo-sum-duplicates-batch
PiperOrigin-RevId: 501676606
2023-01-12 15:02:53 -08:00
jax authors
15ec37c581 Merge pull request #13977 from skye:pjrt_c_api_marker
PiperOrigin-RevId: 501671140
2023-01-12 14:41:43 -08:00
Jake VanderPlas
f314f2d504 [sparse] generalize batch rule for bcoo_sum_duplicates & improve tests 2023-01-12 14:28:12 -08:00
Skye Wanderman-Milne
f90b5eed52 Add pjrt_c_api_unimplemented pytest marker to skip unsupported tests.
Also adds `test_util.pytest_mark_if_available` helper function.
2023-01-12 22:17:23 +00:00
Dinghua Li
360589a8a1 Allow call_tf to accept shape-polymorphic inputs if the output shapes are fully static.
PiperOrigin-RevId: 501645235
2023-01-12 13:04:39 -08:00
jax authors
ba506cbfe2 Merge pull request #13887 from jakevdp:bcoo-dot-general-batch
PiperOrigin-RevId: 501637620
2023-01-12 12:34:39 -08:00
Jake VanderPlas
a6a3b59748 [sparse] generalize batch rule for bcoo_dot_general 2023-01-12 12:03:21 -08:00
jax authors
a18bfa57b2 Merge pull request #13889 from jakevdp:bcoo-spdot-general-batch
PiperOrigin-RevId: 501626784
2023-01-12 11:52:04 -08:00
Jake VanderPlas
eaf6179594 [sparse] generalize batch rule for bcoo_spdot_general 2023-01-12 10:59:28 -08:00
jax authors
f7c915e6a2 Merge pull request #13909 from jakevdp:sparse-nfold-vmap
PiperOrigin-RevId: 501602422
2023-01-12 10:25:47 -08:00
Jake VanderPlas
c0c347bbe7 [sparse] add nfold_vmap utility 2023-01-12 10:09:58 -08:00
Sharad Vikram
f729da4a36 Add shards for checkify_test on GPU
PiperOrigin-RevId: 501430172
2023-01-11 18:28:37 -08:00
jax authors
f0cca20a67 Merge pull request #13929 from nvcastet:add_ompi_init
PiperOrigin-RevId: 501429401
2023-01-11 18:20:51 -08:00
Yash Katariya
68c43e6c99 Update the non-contiguous error message to not say GDA anymore
PiperOrigin-RevId: 501396344
2023-01-11 15:35:15 -08:00
Yash Katariya
acd8dadc74 Make some api_test pass with jit/pjit merge
PiperOrigin-RevId: 501392938
2023-01-11 15:21:16 -08:00
Nicolas Castet
b86030d86f Add Open MPI automatic distributed initialization 2023-01-11 17:08:09 -06:00
jax authors
3cd069f2da Merge pull request #13967 from jakevdp:array-attrs
PiperOrigin-RevId: 501320927
2023-01-11 10:44:58 -08:00
jax authors
7a6c75339f Merge pull request #13958 from mattjj:pjit-partial-eval-2
PiperOrigin-RevId: 501319644
2023-01-11 10:36:39 -08:00
Jake VanderPlas
e1738af5b2 [typing] add missing attributes to jax.Array 2023-01-11 10:00:59 -08:00
Matthew Johnson
8b585302db add pjit partial_eval_jaxpr_custom rule
fix some issues with closed_call's partial_eval_jaxpr_custom rule

Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-01-11 09:30:49 -08:00
jax authors
ced7332587 Merge pull request #13872 from gnecula:tf_pad
PiperOrigin-RevId: 501268817
2023-01-11 06:49:05 -08:00
jax authors
45197ed8db Merge pull request #13545 from LenaMartens:check-initial
PiperOrigin-RevId: 501248830
2023-01-11 04:52:02 -08:00
George Necula
f7093955dc [jax2tf] Fixed the shape-polymorphic lowering for lax.pad and dynamic_slice
Generate DynamicPadOp instea of PadOp when the padding
sizes are not constant.

Fix the generation of RealDynamicSliceOp.

Exclude some tests that fail due to unimplemented support
for custom calls with polymorphic shapes.
2023-01-11 13:02:48 +01:00
Yash Katariya
857febcc15 Pass in the debug_info while we create jaxpr the first time so that the error messages are better
PiperOrigin-RevId: 501185867
2023-01-10 22:42:14 -08:00
Yash Katariya
66aafb6e16 Don't take the cpp dispatch path for pjit if it contains ordered effects just like jit.
PiperOrigin-RevId: 501141750
2023-01-10 18:07:23 -08:00
Yash Katariya
c447e987e1 Skip custom_object_test and dynamic_api_test for pjit/jit merge since it doesn't work with jax.Array's too.
PiperOrigin-RevId: 501129056
2023-01-10 16:55:51 -08:00
Yash Katariya
e02c1da4c7 Fix debug nans test after merging jit and pjit codepaths
PiperOrigin-RevId: 501122848
2023-01-10 16:27:00 -08:00
jax authors
3f712480c6 Merge pull request #13950 from jakevdp:fix-gpu-headings
PiperOrigin-RevId: 501111194
2023-01-10 15:35:50 -08:00
Ruoxin Sang
17be6fe1cd Use tf.ensure_shape instead of tensor.set_shape for hinting TF graph to set static shape output. The difference is that tensor.set_shape only changes the shape information in python, which can still confuse the compiler when dynamic padder is enabled. While tf.ensure_shape creates a EnsureShape op in the graph, which XLA compiler can propagate and set the dynamic dimension properly.
Re-enable the failing `gather_from_take_indices` test.

PiperOrigin-RevId: 501105068
2023-01-10 15:07:59 -08:00
Jake VanderPlas
140fcd1752 DOC: fix heading levels in GPU custom ops 2023-01-10 13:51:22 -08:00
jax authors
3e52c2d3fd Merge pull request #13946 from jakevdp:faq-tracer
PiperOrigin-RevId: 501073465
2023-01-10 13:07:56 -08:00
jax authors
fa111a22f0 Merge pull request #13944 from nvlcambier:nvlcambier/coordinator_address_doc
PiperOrigin-RevId: 501072355
2023-01-10 12:59:58 -08:00
Jake VanderPlas
1bb4de280e DOC: add FAQ entry on converting a tracer to an array 2023-01-10 12:28:16 -08:00
jax authors
6d2c157354 Merge pull request #13948 from jakevdp:fix-finfo-test
PiperOrigin-RevId: 501063882
2023-01-10 12:26:33 -08:00
Jake VanderPlas
264bc9531e Fix finfo test for older numpy versions 2023-01-10 12:17:31 -08:00
Leopold Cambier
d1edad6a68 Typos: suited -> suitable, node -> host 2023-01-10 11:01:54 -08:00
jax authors
e83621ae45 Merge pull request #13937 from jakevdp:finfo-tests
PiperOrigin-RevId: 501035949
2023-01-10 10:44:19 -08:00
Leopold Cambier
7e395c9bbe DOC: add note about localhost & friends in jax.distributed.initialize 2023-01-10 09:17:31 -08:00
Jake VanderPlas
31fd81f2d5 Add tests of dtypes.finfo properties 2023-01-10 09:02:00 -08:00
Marc van Zee
10847a9372 [jax2tf] Simplifies model testing file writing logic.
PiperOrigin-RevId: 500984975
2023-01-10 07:08:00 -08:00
jax authors
62f2b9680b Merge pull request #13917 from gnecula:tf_bug_gda
PiperOrigin-RevId: 500971319
2023-01-10 05:58:28 -08:00
Marc van Zee
4dd289d4c0 [jax2tf] Fixes a bug in converters test.
Ensures that models_test_main.py can be executed from OS as well.

PiperOrigin-RevId: 500968754
2023-01-10 05:42:01 -08:00
George Necula
21ebf9042d [jax2tf] Fixed the conversion of a function that contains an inner pjit
In experimental_native_lowering when we convert a function that is not
a jit or pjit, we wrap it with an implicit jit. We used to specify the
backend when doing this conversion, which bypassed some logic in jit
to handle the merging of jit and pjit code paths. We now drop the
backend parameter to the implicit jit.

We also moved some pjit tests from jax2tf to sharding_test and dropped
the old disabled test for teh GDAs, since GDAs are going away.
2023-01-10 14:12:41 +01:00