Yash Katariya
05e1ddd4ea
Make error_test
a jax_test so that we can test other configs and fix it with jit
/pjit
merge.
...
PiperOrigin-RevId: 502743523
2023-01-17 18:43:05 -08:00
Parker Schuh
b58dd3cbe1
Add support for __signature__ to PjitFunction.
...
PiperOrigin-RevId: 502731453
2023-01-17 17:28:14 -08:00
jax authors
6d80e5909b
Merge pull request #13423 from nvlcambier:lcambier/on_demand_ci_jax_gpu
...
PiperOrigin-RevId: 502727655
2023-01-17 17:11:18 -08:00
jax authors
e6e8350389
Merge pull request #14047 from jakevdp:doc-installation
...
PiperOrigin-RevId: 502727042
2023-01-17 17:03:43 -08:00
Yash Katariya
53fceab17c
pjit
allows nesting of pjits where the outer backend is None while the inner backend is something other than device_under_test()
. This is because the inner backend will take priority.
...
PiperOrigin-RevId: 502721834
2023-01-17 16:39:45 -08:00
Leopold Cambier
056702c1cb
Multinodes CICD on GPUs using on-demand cluster and e2e tests using T5X
2023-01-17 16:29:30 -08:00
Jake VanderPlas
5ebd1e79ff
DOC: add GPU & TPU installation tabs on index page
2023-01-17 16:19:57 -08:00
jax authors
8da6c89c7b
Merge pull request #13759 from sharadmv:io-callback
...
PiperOrigin-RevId: 502694690
2023-01-17 14:48:50 -08:00
Sharad Vikram
3de5c2b716
Add IO callback
2023-01-17 13:55:05 -08:00
Yash Katariya
cb9a9952fe
Check if the sharding input to ShapeDtypeStruct is an instance of Sharding
...
PiperOrigin-RevId: 502652848
2023-01-17 12:08:51 -08:00
Yash Katariya
85654ceeab
Default dynamic_api_test and custom_object_test to take the old jit path and not the merged path since there is no pjit support for it yet.
...
PiperOrigin-RevId: 502620662
2023-01-17 10:19:39 -08:00
George Necula
7e0041c903
Fix scatter in CLIP mode with uint32 and uint64 indices
...
Clipping uses np.iinfo(indices.dtype).max and those values
are too large to be converted to Python constants or C constants.
This is a second attempt, after https://github.com/google/jax/pull/13746 was rolled back due to
failures when jax_array=False. Since that use case will go away
soon we just enable the fix for when jax_array=True.
PiperOrigin-RevId: 502568204
2023-01-17 06:26:21 -08:00
jax authors
7ce9fa2f87
Merge pull request #14030 from gnecula:poly_vmap_error
...
PiperOrigin-RevId: 502546564
2023-01-17 04:19:16 -08:00
jax authors
74d6c219b5
Merge pull request #13312 from nouiz:jax_array_doc
...
PiperOrigin-RevId: 502534154
2023-01-17 03:00:19 -08:00
George Necula
cf4e568e21
[shape_poly] Improve error message from vmap axis size inconsistency
...
vmap tries hard to give nice error messages when the mapped axes
for different arguments have different sizes, but the code to
compute the error message can run into InconsistentDimensionOperation
in presence of dimension polynomials. Ensure that the comparisons
are done symbolically.
2023-01-17 10:45:12 +02:00
John QiangZhang
469a8eb520
Change args_tf_flat to tf.TensorSpec in jax2tf.call_tf.
...
PiperOrigin-RevId: 502492762
2023-01-16 22:43:13 -08:00
Sharad Vikram
a58e59d98f
Add in effects_barrier for the pmap unordered callback test
...
PiperOrigin-RevId: 502434258
2023-01-16 14:44:44 -08:00
jax authors
432b909ef8
Merge pull request #14027 from mattjj:issue14026
...
PiperOrigin-RevId: 502408481
2023-01-16 10:59:46 -08:00
Matthew Johnson
e516d41180
cond transpose, use UndefinedPrimal not linear
for transpose inputs
2023-01-16 10:39:19 -08:00
Yash Katariya
c4d21f97ea
Make xmap use dispatch.sharded_lowering as dispatch.lower_xla_callable is deprecated.
...
PiperOrigin-RevId: 502398366
2023-01-16 09:42:40 -08:00
jax authors
35820ef1e4
Merge pull request #13980 from gnecula:call_tf_effects2
...
PiperOrigin-RevId: 502346453
2023-01-16 03:54:11 -08:00
jax authors
78c9dd8104
Merge pull request #14020 from mattjj:issue13983
...
PiperOrigin-RevId: 502147420
2023-01-14 21:36:41 -08:00
Matthew Johnson
1da24b61fc
fix transcription error for initial_step_size (thanks @AlbertMitjans)
...
fixes #13983
2023-01-14 21:10:44 -08:00
Yash Katariya
8f538f95dc
Pass the proper api_name to debug_info
...
PiperOrigin-RevId: 502141425
2023-01-14 20:41:01 -08:00
Yash Katariya
1209ab17e4
Add abstracted axes to pjit to make jax2tf tests pass. abstracted_axes and dynamic_shapes is not supported by pjit yet.
...
PiperOrigin-RevId: 502138836
2023-01-14 20:17:30 -08:00
Yash Katariya
4601928277
Enable jit_pjit_api_merge by default "in tests" and disable the current failing tests.
...
PiperOrigin-RevId: 502088044
2023-01-14 11:15:03 -08:00
Yash Katariya
4c58ef3840
Add in_positional_semantics to new_params_known and new_params_staged otherwise it leads to length mismatch error down the stack. It is similar to donated_invars and in_shardings.
...
PiperOrigin-RevId: 502082828
2023-01-14 10:19:00 -08:00
George Necula
ade5691630
[call_tf] Add has_side_effects parameter
...
The CallTfEffect was added recently as an internal workaround for
DCE removing instances of call_tf. Here we add a parameter to
`call_tf` to be able to declare if the called computation is
effectful and should not be removed by DCE.
2023-01-14 08:12:29 +01:00
Yash Katariya
38f91bdaa5
Skip core tests which have nested pjits and DShapedArray.
...
PiperOrigin-RevId: 502013080
2023-01-13 22:39:31 -08:00
Qiao Zhang
d203926c16
Expose fp8 in jax dtypes and mlir builder.
...
PiperOrigin-RevId: 501980811
2023-01-13 18:12:12 -08:00
Pankaj Kanwar
8fcb5180b2
disable flaky tests on certain targets.
...
PiperOrigin-RevId: 501974439
2023-01-13 17:23:36 -08:00
jax authors
17785a4ded
Merge pull request #13285 from canyon289:draft_focused_landing_page
...
PiperOrigin-RevId: 501954451
2023-01-13 15:35:34 -08:00
Yash Katariya
7e8fe13c6a
jit
was the default name in name_stack in mlir.py. Fix that by taking the name as an optional argument (defaulting to jit
) so that nested pjits will show up as pjit
in the name stack.
...
PiperOrigin-RevId: 501946780
2023-01-13 15:00:22 -08:00
Jake VanderPlas
9aa8fe520f
Stricter input validation for shaped_abstractify
...
Fixes https://github.com/google/jax/issues/13976
PiperOrigin-RevId: 501944946
2023-01-13 14:53:08 -08:00
Yash Katariya
5eb23a7615
Fix name_stack
usage of pjit. Now all the metadata of transformations in hlo are correct.
...
PiperOrigin-RevId: 501918212
2023-01-13 12:54:12 -08:00
Ravin Kumar
2257e2075d
Update doc landing page
...
Co-authored-by: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com>
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
2023-01-13 12:45:28 -08:00
jax authors
86ba62d05e
Merge pull request #13991 from skye:pjrt_c_api_marker
...
PiperOrigin-RevId: 501909180
2023-01-13 12:15:00 -08:00
Yash Katariya
649ee1be34
Make pickle_test.py pass with jit/pjit api merge. Also rename and move some functions around
...
PiperOrigin-RevId: 501878555
2023-01-13 10:16:01 -08:00
Yash Katariya
eb887f42c4
Add type hints to some functions in pxla.py. Adding typehints to pjit
is a bit difficult because of functions like _is_unspecified
, etc. TypeGuard
is a nice solution for it but it doesn't work in the negative case.
...
PiperOrigin-RevId: 501858300
2023-01-13 08:45:49 -08:00
jax authors
d1593289a0
Merge pull request #13869 from jakevdp:bcoo-extract-api
...
PiperOrigin-RevId: 501842123
2023-01-13 07:27:19 -08:00
Jake VanderPlas
7a8781db1c
[sparse] add higher-level version of bcoo_extract & improve tests
2023-01-13 07:13:13 -08:00
jax authors
44a044f936
Merge pull request #13982 from Edenhofer:fix_zero_length_meshgrid
...
PiperOrigin-RevId: 501827615
2023-01-13 06:06:52 -08:00
jax authors
9c418e9399
Merge pull request #13862 from jakevdp:bcoo-extract
...
PiperOrigin-RevId: 501826918
2023-01-13 05:59:08 -08:00
Sharad Vikram
c9a57e1b44
Delete jax.experimental.callback
...
PiperOrigin-RevId: 501760507
2023-01-12 22:58:31 -08:00
Yash Katariya
e21c29476d
Add batch_jaxpr2 which tells the caller where batch dims are.
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 501746795
2023-01-12 21:16:59 -08:00
Yash Katariya
94f0ccc54a
Fix host_callback for pjit which was using REPLICATED which was a CanonicalizedParsedPspec
...
PiperOrigin-RevId: 501713533
2023-01-12 18:00:33 -08:00
Yash Katariya
936247a7e5
Fix debugging primitives for pjit. This came up during jit/pjit merge
...
PiperOrigin-RevId: 501710198
2023-01-12 17:40:35 -08:00
Yash Katariya
c8ad89e358
Make jit
a thin wrapper around pjit
which ignores the mesh context manager (just like how it is today)
...
Pass `None` as the resource_env via `jit` because `jit(pjit)` will ignore the outer mesh because `jit` will set the resource env to empty mesh.
This does not make `jit` and `pjit` the same API but it shares all the code between both the APIs (cpp and python) while preserving the current semantics of both `jit` and `pjit`.
PiperOrigin-RevId: 501707496
2023-01-12 17:24:32 -08:00
jax authors
7206cb5b7b
Merge pull request #13940 from DPS0340:main
...
PiperOrigin-RevId: 501692167
2023-01-12 16:10:16 -08:00
Parker Schuh
30d64f38f1
Add 'hard xmap' support for pure_callback.
...
PiperOrigin-RevId: 501689068
2023-01-12 15:56:50 -08:00